You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

71 lines
3.1 KiB

import torch
import argparse
import os
from model import get_face_model
def export_to_onnx(pt_path, output_path):
print(f"🔄 Loading model from: {pt_path}")
# 1. 모델 구조 불러오기 및 가중치 로드
# 보드용 모델이므로 학습용 ArcFace 헤더는 버리고, Backbone만 가져옵니다.
model = get_face_model()
# CPU로 로드 (변환은 굳이 GPU 불필요)
checkpoint = torch.load(pt_path, map_location='cpu')
# state_dict 로드 (혹시 모를 키 불일치 방지를 위해 strict=False는 선택사항이나, 여기선 구조가 같으므로 True 권장)
try:
model.load_state_dict(checkpoint)
except RuntimeError as e:
print(f"⚠️ Key mismatch detected. Trying to load with strict=False...")
model.load_state_dict(checkpoint, strict=False)
# 2. Eval 모드 전환 (매우 중요)
# 이걸 안 하면 BatchNorm, Dropout 등이pip 학습 모드로 동작하여 결과가 이상해집니다.
model.eval()
# 3. Dummy Input 생성 (Static Shape: 128x128)
# 보드 사양에 맞춰 배치 사이즈는 1로 고정합니다. [1, 3, 128, 128]
dummy_input = torch.randn(1, 3, 128, 128)
print(f"Target ONNX Path: {output_path}")
# 4. ONNX Export
# external_data=False는 PyTorch export에서 기본적으로 2GB 미만 모델에 대해 적용되어 단일 파일로 나옵니다.
# dynamic_axes 옵션을 뺌으로써 Static Shape을 강제합니다.
torch.onnx.export(
model, # 실행될 모델
dummy_input, # 모델 입력값 (차원 체크용)
output_path, # 저장될 경로
# export_params=True, # 모델 파일 안에 웨이트 저장 (external_data=False 효과)
# opset_version=11, # 임베디드 보드에서 가장 호환성 좋은 버전 (11 추천)
# do_constant_folding=True, # 상수 폴딩 최적화
input_names=['input'], # 입력 노드 이름
output_names=['output'], # 출력 노드 이름
external_data=False
# dynamic_axes={...} <-- 이 옵션을 사용하지 않음으로써 Static Shape으로 고정됨!
)
print(f"✅ Conversion Completed! Model saved at: {output_path}")
print(f" Input Shape: {dummy_input.shape} (Static)")
print(f" Please check if '{output_path}' is a single file.")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Convert PyTorch model to ONNX')
# 입력받을 .pt 파일 경로
parser.add_argument('--input', type=str, required=True, help='Input .pt file path')
# 출력할 .onnx 파일 경로 (옵션)
parser.add_argument('--output', type=str, default=None, help='Output .onnx file path')
args = parser.parse_args()
# Output 경로가 없으면 Input 경로에서 확장자만 바꿔서 자동 지정
if args.output is None:
args.output = args.input.replace('.pt', '.onnx')
if not os.path.exists(args.input):
print(f"❌ Error: Input file not found: {args.input}")
else:
export_to_onnx(args.input, args.output)