diff --git a/data_gen/generate_synthetic.py b/data_gen/generate_synthetic.py
index c8de532..9dd7cc0 100644
--- a/data_gen/generate_synthetic.py
+++ b/data_gen/generate_synthetic.py
@@ -175,6 +175,11 @@ def make_label_two_line(plate, top, bot):
     ]
 
 
+# Estimated Korean road license-plate distribution (private 92% + commercial 7.5%, new-style horizontal ~98%, etc.).
+# Synthetic weights adjusted for asset limitations; override with --type_weights when running generate_synthetic.
+TYPE_DEFAULT_WEIGHTS = {'1': 0.80, '2': 0.05, '3': 0.10, '4': 0.05}
+
+
 def main():
     p = argparse.ArgumentParser(description=__doc__)
     p.add_argument("--asset_dir", default=str(Path(__file__).parent / "Korean-license-plate-Generator"))
@@ -182,6 +187,8 @@ def main():
     p.add_argument("--num", type=int, default=200, help="총 이미지 개수")
     p.add_argument("--test_ratio", type=float, default=0.05)
     p.add_argument("--types", default="1,2,3,4", help="포함할 type (콤마 구분)")
+    p.add_argument("--type_weights", default=None,
+                   help="--types 순서대로 콤마 구분 가중치. 미지정 시 한국 도로 분포 기본값 사용.")
     p.add_argument("--dict", default=None, help="검증용 dict 경로 (선택)")
     p.add_argument("--seed", type=int, default=42)
     args = p.parse_args()
@@ -196,10 +203,32 @@ def main():
         '3': ('two', gen.gen_type3),
         '4': ('two', gen.gen_type4),
     }
-    chosen = [type_funcs[t.strip()] for t in args.types.split(',') if t.strip() in type_funcs]
-    if not chosen:
+    selected_keys = [t.strip() for t in args.types.split(',') if t.strip() in type_funcs]
+    if not selected_keys:
         sys.exit("No valid types selected.")
-
+    chosen = [type_funcs[k] for k in selected_keys]
+
+    if args.type_weights:
+        # Exit with a readable message instead of a raw traceback on malformed input.
+        try:
+            weights = [float(w) for w in args.type_weights.split(',')]
+        except ValueError:
+            sys.exit(f"--type_weights must be comma-separated numbers, got {args.type_weights!r}")
+        if len(weights) != len(selected_keys):
+            sys.exit(f"--type_weights ({len(weights)}) length mismatch with --types ({len(selected_keys)})")
+    else:
+        weights = [TYPE_DEFAULT_WEIGHTS[k] for k in selected_keys]
+    # random.choices() needs finite, non-negative weights with a positive sum;
+    # the chained comparison also rejects NaN, and a zero sum would divide by zero below.
+    if not all(0 <= w < float('inf') for w in weights):
+        sys.exit("--type_weights must be finite and non-negative")
+    total = sum(weights)
+    if total <= 0:
+        sys.exit("--type_weights must sum to a positive value")
+    weights = [w / total for w in weights]
+    print(f"Type sampling weights: {dict(zip(selected_keys, [round(w, 3) for w in weights]))}")
+
+    type_count = {k: 0 for k in selected_keys}
     seen_chars = set()
     out = Path(args.out_dir)
     n_test = max(1, int(args.num * args.test_ratio))
@@ -210,7 +239,9 @@ def main():
         img_dir.mkdir(parents=True, exist_ok=True)
         records = []
         for i in range(count):
-            kind, fn = random.choice(chosen)
+            idx = random.choices(range(len(chosen)), weights=weights, k=1)[0]
+            kind, fn = chosen[idx]
+            type_count[selected_keys[idx]] += 1
             if kind == 'one':
                 plate, text = fn()
                 label = make_label_one_line(plate, text)
@@ -228,6 +259,8 @@ def main():
                 f.write(f"{path}\t{lab}\n")
         print(f" {split}: {len(records)} images → {out / split}")
 
+    print(f"Type counts: {type_count}")
+
     # dict 검증
     if args.dict:
         with open(args.dict, encoding="utf-8") as f: