import torch import argparse import os from model import get_face_model def export_to_onnx(pt_path, output_path): print(f"๐Ÿ”„ Loading model from: {pt_path}") # 1. ๋ชจ๋ธ ๊ตฌ์กฐ ๋ถˆ๋Ÿฌ์˜ค๊ธฐ ๋ฐ ๊ฐ€์ค‘์น˜ ๋กœ๋“œ # ๋ณด๋“œ์šฉ ๋ชจ๋ธ์ด๋ฏ€๋กœ ํ•™์Šต์šฉ ArcFace ํ—ค๋”๋Š” ๋ฒ„๋ฆฌ๊ณ , Backbone๋งŒ ๊ฐ€์ ธ์˜ต๋‹ˆ๋‹ค. model = get_face_model() # CPU๋กœ ๋กœ๋“œ (๋ณ€ํ™˜์€ ๊ตณ์ด GPU ๋ถˆํ•„์š”) checkpoint = torch.load(pt_path, map_location='cpu') # state_dict ๋กœ๋“œ (ํ˜น์‹œ ๋ชจ๋ฅผ ํ‚ค ๋ถˆ์ผ์น˜ ๋ฐฉ์ง€๋ฅผ ์œ„ํ•ด strict=False๋Š” ์„ ํƒ์‚ฌํ•ญ์ด๋‚˜, ์—ฌ๊ธฐ์„  ๊ตฌ์กฐ๊ฐ€ ๊ฐ™์œผ๋ฏ€๋กœ True ๊ถŒ์žฅ) try: model.load_state_dict(checkpoint) except RuntimeError as e: print(f"โš ๏ธ Key mismatch detected. Trying to load with strict=False...") model.load_state_dict(checkpoint, strict=False) # 2. Eval ๋ชจ๋“œ ์ „ํ™˜ (๋งค์šฐ ์ค‘์š”) # ์ด๊ฑธ ์•ˆ ํ•˜๋ฉด BatchNorm, Dropout ๋“ฑ์ดpip ํ•™์Šต ๋ชจ๋“œ๋กœ ๋™์ž‘ํ•˜์—ฌ ๊ฒฐ๊ณผ๊ฐ€ ์ด์ƒํ•ด์ง‘๋‹ˆ๋‹ค. model.eval() # 3. Dummy Input ์ƒ์„ฑ (Static Shape: 128x128) # ๋ณด๋“œ ์‚ฌ์–‘์— ๋งž์ถฐ ๋ฐฐ์น˜ ์‚ฌ์ด์ฆˆ๋Š” 1๋กœ ๊ณ ์ •ํ•ฉ๋‹ˆ๋‹ค. [1, 3, 128, 128] dummy_input = torch.randn(1, 3, 128, 128) print(f"Target ONNX Path: {output_path}") # 4. ONNX Export # external_data=False๋Š” PyTorch export์—์„œ ๊ธฐ๋ณธ์ ์œผ๋กœ 2GB ๋ฏธ๋งŒ ๋ชจ๋ธ์— ๋Œ€ํ•ด ์ ์šฉ๋˜์–ด ๋‹จ์ผ ํŒŒ์ผ๋กœ ๋‚˜์˜ต๋‹ˆ๋‹ค. # dynamic_axes ์˜ต์…˜์„ ๋บŒ์œผ๋กœ์จ Static Shape์„ ๊ฐ•์ œํ•ฉ๋‹ˆ๋‹ค. torch.onnx.export( model, # ์‹คํ–‰๋  ๋ชจ๋ธ dummy_input, # ๋ชจ๋ธ ์ž…๋ ฅ๊ฐ’ (์ฐจ์› ์ฒดํฌ์šฉ) output_path, # ์ €์žฅ๋  ๊ฒฝ๋กœ # export_params=True, # ๋ชจ๋ธ ํŒŒ์ผ ์•ˆ์— ์›จ์ดํŠธ ์ €์žฅ (external_data=False ํšจ๊ณผ) # opset_version=11, # ์ž„๋ฒ ๋””๋“œ ๋ณด๋“œ์—์„œ ๊ฐ€์žฅ ํ˜ธํ™˜์„ฑ ์ข‹์€ ๋ฒ„์ „ (11 ์ถ”์ฒœ) # do_constant_folding=True, # ์ƒ์ˆ˜ ํด๋”ฉ ์ตœ์ ํ™” input_names=['input'], # ์ž…๋ ฅ ๋…ธ๋“œ ์ด๋ฆ„ output_names=['output'], # ์ถœ๋ ฅ ๋…ธ๋“œ ์ด๋ฆ„ external_data=False # dynamic_axes={...} <-- ์ด ์˜ต์…˜์„ ์‚ฌ์šฉํ•˜์ง€ ์•Š์Œ์œผ๋กœ์จ Static Shape์œผ๋กœ ๊ณ ์ •๋จ! ) print(f"โœ… Conversion Completed! Model saved at: {output_path}") print(f"โ„น๏ธ Input Shape: {dummy_input.shape} (Static)") print(f"โ„น๏ธ Please check if '{output_path}' is a single file.") if __name__ == "__main__": parser = argparse.ArgumentParser(description='Convert PyTorch model to ONNX') # ์ž…๋ ฅ๋ฐ›์„ .pt ํŒŒ์ผ ๊ฒฝ๋กœ parser.add_argument('--input', type=str, required=True, help='Input .pt file path') # ์ถœ๋ ฅํ•  .onnx ํŒŒ์ผ ๊ฒฝ๋กœ (์˜ต์…˜) parser.add_argument('--output', type=str, default=None, help='Output .onnx file path') args = parser.parse_args() # Output ๊ฒฝ๋กœ๊ฐ€ ์—†์œผ๋ฉด Input ๊ฒฝ๋กœ์—์„œ ํ™•์žฅ์ž๋งŒ ๋ฐ”๊ฟ”์„œ ์ž๋™ ์ง€์ • if args.output is None: args.output = args.input.replace('.pt', '.onnx') if not os.path.exists(args.input): print(f"โŒ Error: Input file not found: {args.input}") else: export_to_onnx(args.input, args.output)