You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

71 lines
3.1 KiB

This file contains invisible Unicode characters!

This file contains invisible Unicode characters that may be processed differently from what appears below. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to reveal hidden characters.

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import torch
import argparse
import os
from model import get_face_model
def export_to_onnx(pt_path, output_path):
print(f"🔄 Loading model from: {pt_path}")
# 1. 모델 구조 불러오기 및 가중치 로드
# 보드용 모델이므로 학습용 ArcFace 헤더는 버리고, Backbone만 가져옵니다.
model = get_face_model()
# CPU로 로드 (변환은 굳이 GPU 불필요)
checkpoint = torch.load(pt_path, map_location='cpu')
# state_dict 로드 (혹시 모를 키 불일치 방지를 위해 strict=False는 선택사항이나, 여기선 구조가 같으므로 True 권장)
try:
model.load_state_dict(checkpoint)
except RuntimeError as e:
print(f"⚠️ Key mismatch detected. Trying to load with strict=False...")
model.load_state_dict(checkpoint, strict=False)
# 2. Eval 모드 전환 (매우 중요)
# 이걸 안 하면 BatchNorm, Dropout 등이pip 학습 모드로 동작하여 결과가 이상해집니다.
model.eval()
# 3. Dummy Input 생성 (Static Shape: 128x128)
# 보드 사양에 맞춰 배치 사이즈는 1로 고정합니다. [1, 3, 128, 128]
dummy_input = torch.randn(1, 3, 128, 128)
print(f"Target ONNX Path: {output_path}")
# 4. ONNX Export
# external_data=False는 PyTorch export에서 기본적으로 2GB 미만 모델에 대해 적용되어 단일 파일로 나옵니다.
# dynamic_axes 옵션을 뺌으로써 Static Shape을 강제합니다.
torch.onnx.export(
model, # 실행될 모델
dummy_input, # 모델 입력값 (차원 체크용)
output_path, # 저장될 경로
# export_params=True, # 모델 파일 안에 웨이트 저장 (external_data=False 효과)
# opset_version=11, # 임베디드 보드에서 가장 호환성 좋은 버전 (11 추천)
# do_constant_folding=True, # 상수 폴딩 최적화
input_names=['input'], # 입력 노드 이름
output_names=['output'], # 출력 노드 이름
external_data=False
# dynamic_axes={...} <-- 이 옵션을 사용하지 않음으로써 Static Shape으로 고정됨!
)
print(f"✅ Conversion Completed! Model saved at: {output_path}")
print(f" Input Shape: {dummy_input.shape} (Static)")
print(f" Please check if '{output_path}' is a single file.")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Convert PyTorch model to ONNX')
# 입력받을 .pt 파일 경로
parser.add_argument('--input', type=str, required=True, help='Input .pt file path')
# 출력할 .onnx 파일 경로 (옵션)
parser.add_argument('--output', type=str, default=None, help='Output .onnx file path')
args = parser.parse_args()
# Output 경로가 없으면 Input 경로에서 확장자만 바꿔서 자동 지정
if args.output is None:
args.output = args.input.replace('.pt', '.onnx')
if not os.path.exists(args.input):
print(f"❌ Error: Input file not found: {args.input}")
else:
export_to_onnx(args.input, args.output)