#!/usr/bin/env python3
"""
한국 LP 합성 데이터 생성기 (PGNet 학습용 Step1 pretrain).

자산 출처: qjadud1994/Korean-license-plate-Generator (MIT)
출력 라벨 포맷: PaddleOCR PGNet (PGDataSet 호환)
  <상대경로>\t[{"transcription": "12가3456", "points": [[x1,y1],...,[x4,y4]]}, ...]

지원 plate 종류:
  Type 1: 신형 승용 가로 1줄  (520x110, 흰)
  Type 2: 구형 승용 가로 1줄  (355x155, 흰)
  Type 3: 영업용 두 줄        (336x170, 노랑) — 지역명 + 숫자2 / 한글 + 숫자4
  Type 4: 친환경/전기 두 줄   (336x170, 파랑·녹색)

주의 — REGION_MAP은 추정 매핑입니다. 합성 결과 PNG를 시각 확인 후 정정하세요.
"""

import argparse
import json
import random
import sys
from pathlib import Path

import cv2
import numpy as np


# 한영 자판 매핑 (자음·모음 → 영문 두 글자 코드 → 한글 글자)
# qjadud1994 자산의 char1/char1_g/char1_y 폴터 파일명 규칙과 일치 (40자).
# 추가: '하'(gk), '호'(gh), '배'(qo) — dict 누락 글자 보충.
HANGUL_CHAR_MAP = {
    'ah': '모', 'aj': '머', 'ak': '마', 'an': '무',
    'dh': '오', 'dj': '어', 'dk': '아', 'dn': '우',
    'eh': '도', 'ej': '더', 'ek': '다', 'en': '두',
    'fh': '로', 'fj': '러', 'fk': '라', 'fn': '루',
    'gh': '호', 'gj': '허', 'gk': '하',
    'qh': '보', 'qj': '버', 'qk': '바', 'qn': '부', 'qo': '배',
    'rh': '고', 'rj': '거', 'rk': '가', 'rn': '구',
    'sh': '노', 'sj': '너', 'sk': '나', 'sn': '누',
    'th': '소', 'tj': '서', 'tk': '사', 'tn': '수',
    'wh': '조', 'wj': '저', 'wk': '자', 'wn': '주',
}

# 자산 region_y / region_g 의 알파벳 코드 → 한국 광역지자체 (16종, 세종 제외 추정).
# 정확한 매핑은 PNG 시각 확인 필요. 학습 노이즈 방지 위해 합성 후 검증할 것.
REGION_MAP = {
    'A': '서울', 'B': '경기', 'C': '인천', 'D': '강원',
    'E': '충남', 'F': '대전', 'G': '충북', 'H': '부산',
    'I': '울산', 'J': '대구', 'K': '경북', 'L': '경남',
    'M': '전남', 'N': '광주', 'O': '전북', 'P': '제주',
}


def random_bright(img: np.ndarray, scale_range=(0.55, 1.45)) -> np.ndarray:
    """HSV V 채널 곱하기로 plate 전체 밝기 무작위 변경. 색상 다양성용."""
    hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV).astype(np.float32)
    hsv[:, :, 2] = np.clip(hsv[:, :, 2] * random.uniform(*scale_range), 0, 255)
    return cv2.cvtColor(hsv.astype(np.uint8), cv2.COLOR_HSV2BGR)


class StratifiedSampler:
    """완전 균등 샘플링: 모든 키를 셔플해서 한 사이클 동안 정확히 1번씩 등장."""
    def __init__(self, keys):
        self.keys = list(keys)
        self._pool = []

    def next(self):
        if not self._pool:
            self._pool = self.keys.copy()
            random.shuffle(self._pool)
        return self._pool.pop()


class LPGenerator:
    def __init__(self, asset_dir: Path):
        self.asset = Path(asset_dir)
        if not self.asset.is_dir():
            raise FileNotFoundError(f"asset dir not found: {self.asset}. Run data_gen/setup_assets.sh first.")

        self.plate_w = cv2.imread(str(self.asset / "plate.jpg"))
        self.plate_y = cv2.imread(str(self.asset / "plate_y.jpg"))
        self.plate_g = cv2.imread(str(self.asset / "plate_g.jpg"))

        self.num_w = self._load("num")
        self.num_y = self._load("num_y")
        self.num_g = self._load("num_g")
        self.char_w = self._load("char1")
        self.char_y = self._load("char1_y")
        self.char_g = self._load("char1_g")
        self.region_y_imgs = self._load("region_y")
        self.region_g_imgs = self._load("region_g")

        # 한글/지역명은 클래스 수가 적고 plate당 1개만 등장 → 균등 샘플링 필수
        self.hangul_sampler = StratifiedSampler(HANGUL_CHAR_MAP)
        self.region_y_sampler = StratifiedSampler(self.region_y_imgs)
        self.region_g_sampler = StratifiedSampler(self.region_g_imgs)

    def _load(self, sub: str) -> dict:
        out = {}
        for fp in sorted((self.asset / sub).iterdir()):
            if fp.suffix.lower() in {'.jpg', '.png'}:
                out[fp.stem] = cv2.imread(str(fp))
        return out

    @staticmethod
    def _resize_dict(d: dict, w: int, h: int) -> dict:
        return {k: cv2.resize(v, (w, h)) for k, v in d.items()}

    def gen_type1(self):
        """신형 승용 가로 1줄 (520x110). 글자 실제 위치로 tight polygon 생성."""
        plate = cv2.resize(self.plate_w, (520, 110))
        num = self._resize_dict(self.num_w, 56, 83)
        char = self._resize_dict(self.char_w, 60, 83)

        d = [random.choice('0123456789') for _ in range(2)]
        ch = self.hangul_sampler.next()
        e = [random.choice('0123456789') for _ in range(4)]

        row, col = 13, 35
        x0 = col
        for x in d:
            plate[row:row+83, col:col+56] = num[x]; col += 56
        plate[row:row+83, col:col+60] = char[ch]; col += 60
        for x in e:
            plate[row:row+83, col:col+56] = num[x]; col += 56
        x1 = col

        text = ''.join(d) + HANGUL_CHAR_MAP[ch] + ''.join(e)
        poly = [[x0, row], [x1, row], [x1, row + 83], [x0, row + 83]]
        return plate, [{"transcription": text, "points": poly}]

    def gen_type2(self):
        """구형 승용 가로 1줄 (355x155). 글자 실제 위치로 tight polygon 생성."""
        plate = cv2.resize(self.plate_w, (355, 155))
        num = self._resize_dict(self.num_w, 45, 83)
        char = self._resize_dict(self.char_w, 49, 70)

        d = [random.choice('0123456789') for _ in range(2)]
        ch = self.hangul_sampler.next()
        e = [random.choice('0123456789') for _ in range(4)]

        row, col = 46, 10
        x0 = col
        plate[row:row+83, col:col+45] = num[d[0]]; col += 45
        plate[row:row+83, col:col+45] = num[d[1]]; col += 45
        plate[row+12:row+82, col+2:col+51] = char[ch]; col += 51
        plate[row:row+83, col+2:col+47] = num[e[0]]; col += 47
        for x in e[1:]:
            plate[row:row+83, col:col+45] = num[x]; col += 45
        x1 = col

        text = ''.join(d) + HANGUL_CHAR_MAP[ch] + ''.join(e)
        poly = [[x0, row], [x1, row], [x1, row + 83], [x0, row + 83]]
        return plate, [{"transcription": text, "points": poly}]

    def _gen_two_line(self, plate_bg, num_src, char_src, region_src, region_sampler):
        """두 줄 LP (336x170). 위·아래 줄 각각 tight polygon 생성."""
        plate = cv2.resize(plate_bg, (336, 170))
        num1 = self._resize_dict(num_src, 44, 60)
        num2 = self._resize_dict(num_src, 64, 90)
        region = self._resize_dict(region_src, 88, 60)
        char = self._resize_dict(char_src, 64, 62)

        rkey = region_sampler.next()
        d = [random.choice('0123456789') for _ in range(2)]
        ch = self.hangul_sampler.next()
        e = [random.choice('0123456789') for _ in range(4)]

        # 위 줄: region + 숫자2
        row, col = 8, 76
        tx0 = col
        plate[row:row+60, col:col+88] = region[rkey]; col += 88 + 8
        for x in d:
            plate[row:row+60, col:col+44] = num1[x]; col += 44
        tx1 = col
        top_poly = [[tx0, row], [tx1, row], [tx1, row + 60], [tx0, row + 60]]

        # 아래 줄: 한글 + 숫자4
        row, col = 72, 8
        bx0 = col
        plate[row:row+62, col:col+64] = char[ch]; col += 64
        for x in e:
            plate[row:row+90, col:col+64] = num2[x]; col += 64
        bx1 = col
        bot_poly = [[bx0, row], [bx1, row], [bx1, row + 90], [bx0, row + 90]]

        top = REGION_MAP.get(rkey, '?') + ''.join(d)
        bot = HANGUL_CHAR_MAP[ch] + ''.join(e)
        return plate, [
            {"transcription": top, "points": top_poly},
            {"transcription": bot, "points": bot_poly},
        ]

    def gen_type3(self):
        return self._gen_two_line(self.plate_y, self.num_y, self.char_y, self.region_y_imgs, self.region_y_sampler)

    def gen_type4(self):
        return self._gen_two_line(self.plate_g, self.num_g, self.char_g, self.region_g_imgs, self.region_g_sampler)


# 학습용 가중치 — 검증 단계에서는 두줄(type3+4)을 도로 분포보다 의도적으로 늘려
# 모델이 윗줄/아랫줄 동시 검출을 충분히 학습하게 함. 추론 시 도로 분포로 평가됨.
TYPE_DEFAULT_WEIGHTS = {'1': 0.50, '2': 0.05, '3': 0.30, '4': 0.15}


def main():
    p = argparse.ArgumentParser(description=__doc__)
    p.add_argument("--asset_dir", default=str(Path(__file__).parent / "Korean-license-plate-Generator"))
    p.add_argument("--out_dir", required=True, help="합성 데이터셋 출력 루트")
    p.add_argument("--num", type=int, default=200, help="총 이미지 개수")
    p.add_argument("--test_ratio", type=float, default=0.05)
    p.add_argument("--types", default="1,2,3,4", help="포함할 type (콤마 구분)")
    p.add_argument("--type_weights", default=None,
                   help="--types 순서대로 콤마 구분 가중치. 미지정 시 한국 도로 분포 기본값 사용.")
    p.add_argument("--no_bright", action="store_true", help="random_bright 끄기 (디버깅용)")
    p.add_argument("--dict", default=None, help="검증용 dict 경로 (선택)")
    p.add_argument("--seed", type=int, default=42)
    args = p.parse_args()

    random.seed(args.seed)
    np.random.seed(args.seed)

    gen = LPGenerator(Path(args.asset_dir))
    type_funcs = {
        '1': gen.gen_type1,
        '2': gen.gen_type2,
        '3': gen.gen_type3,
        '4': gen.gen_type4,
    }
    selected_keys = [t.strip() for t in args.types.split(',') if t.strip() in type_funcs]
    if not selected_keys:
        sys.exit("No valid types selected.")
    chosen = [type_funcs[k] for k in selected_keys]

    if args.type_weights:
        weights = [float(w) for w in args.type_weights.split(',')]
        if len(weights) != len(selected_keys):
            sys.exit(f"--type_weights ({len(weights)}) length mismatch with --types ({len(selected_keys)})")
    else:
        weights = [TYPE_DEFAULT_WEIGHTS[k] for k in selected_keys]
    total = sum(weights)
    weights = [w / total for w in weights]
    print(f"Type sampling weights: {dict(zip(selected_keys, [round(w, 3) for w in weights]))}")

    type_count = {k: 0 for k in selected_keys}
    seen_chars = set()
    out = Path(args.out_dir)
    n_test = max(1, int(args.num * args.test_ratio))
    n_train = args.num - n_test

    for split, count in [("train", n_train), ("test", n_test)]:
        img_dir = out / split / "images"
        img_dir.mkdir(parents=True, exist_ok=True)
        records = []
        for i in range(count):
            idx = random.choices(range(len(chosen)), weights=weights, k=1)[0]
            fn = chosen[idx]
            type_count[selected_keys[idx]] += 1
            plate, label = fn()
            for entry in label:
                seen_chars.update(entry["transcription"])
            if not args.no_bright:
                plate = random_bright(plate)
            fname = f"{i:06d}.jpg"
            cv2.imwrite(str(img_dir / fname), plate)
            records.append((f"images/{fname}", json.dumps(label, ensure_ascii=False)))

        with open(out / split / f"{split}.txt", "w", encoding="utf-8") as f:
            for path, lab in records:
                f.write(f"{path}\t{lab}\n")
        print(f"  {split}: {len(records)} images → {out / split}")

    print(f"Type counts: {type_count}")

    # dict 검증
    if args.dict:
        with open(args.dict, encoding="utf-8") as f:
            dict_chars = {ln.strip() for ln in f if ln.strip()}
        missing = seen_chars - dict_chars - {'?'}
        unused = dict_chars - seen_chars
        print()
        print(f"dict 검증 (in {args.dict}):")
        print(f"  생성에 등장한 글자 {len(seen_chars)}자")
        print(f"  dict 누락 (라벨에 있는데 dict에 없음): {sorted(missing) or 'none'}")
        print(f"  미등장 (dict에 있는데 합성 데이터에 없음): {sorted(unused) or 'none'}")


if __name__ == "__main__":
    main()
