|
|
import torch
|
|
|
import argparse
|
|
|
import os
|
|
|
from model import get_face_model
|
|
|
|
|
|
def export_to_onnx(pt_path, output_path):
|
|
|
print(f"🔄 Loading model from: {pt_path}")
|
|
|
|
|
|
# 1. 모델 구조 불러오기 및 가중치 로드
|
|
|
# 보드용 모델이므로 학습용 ArcFace 헤더는 버리고, Backbone만 가져옵니다.
|
|
|
model = get_face_model()
|
|
|
|
|
|
# CPU로 로드 (변환은 굳이 GPU 불필요)
|
|
|
checkpoint = torch.load(pt_path, map_location='cpu')
|
|
|
|
|
|
# state_dict 로드 (혹시 모를 키 불일치 방지를 위해 strict=False는 선택사항이나, 여기선 구조가 같으므로 True 권장)
|
|
|
try:
|
|
|
model.load_state_dict(checkpoint)
|
|
|
except RuntimeError as e:
|
|
|
print(f"⚠️ Key mismatch detected. Trying to load with strict=False...")
|
|
|
model.load_state_dict(checkpoint, strict=False)
|
|
|
|
|
|
# 2. Eval 모드 전환 (매우 중요)
|
|
|
# 이걸 안 하면 BatchNorm, Dropout 등이pip 학습 모드로 동작하여 결과가 이상해집니다.
|
|
|
model.eval()
|
|
|
|
|
|
# 3. Dummy Input 생성 (Static Shape: 128x128)
|
|
|
# 보드 사양에 맞춰 배치 사이즈는 1로 고정합니다. [1, 3, 128, 128]
|
|
|
dummy_input = torch.randn(1, 3, 128, 128)
|
|
|
|
|
|
print(f"Target ONNX Path: {output_path}")
|
|
|
|
|
|
# 4. ONNX Export
|
|
|
# external_data=False는 PyTorch export에서 기본적으로 2GB 미만 모델에 대해 적용되어 단일 파일로 나옵니다.
|
|
|
# dynamic_axes 옵션을 뺌으로써 Static Shape을 강제합니다.
|
|
|
torch.onnx.export(
|
|
|
model, # 실행될 모델
|
|
|
dummy_input, # 모델 입력값 (차원 체크용)
|
|
|
output_path, # 저장될 경로
|
|
|
# export_params=True, # 모델 파일 안에 웨이트 저장 (external_data=False 효과)
|
|
|
# opset_version=11, # 임베디드 보드에서 가장 호환성 좋은 버전 (11 추천)
|
|
|
# do_constant_folding=True, # 상수 폴딩 최적화
|
|
|
input_names=['input'], # 입력 노드 이름
|
|
|
output_names=['output'], # 출력 노드 이름
|
|
|
external_data=False
|
|
|
# dynamic_axes={...} <-- 이 옵션을 사용하지 않음으로써 Static Shape으로 고정됨!
|
|
|
)
|
|
|
|
|
|
print(f"✅ Conversion Completed! Model saved at: {output_path}")
|
|
|
print(f"ℹ️ Input Shape: {dummy_input.shape} (Static)")
|
|
|
print(f"ℹ️ Please check if '{output_path}' is a single file.")
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
parser = argparse.ArgumentParser(description='Convert PyTorch model to ONNX')
|
|
|
|
|
|
# 입력받을 .pt 파일 경로
|
|
|
parser.add_argument('--input', type=str, required=True, help='Input .pt file path')
|
|
|
|
|
|
# 출력할 .onnx 파일 경로 (옵션)
|
|
|
parser.add_argument('--output', type=str, default=None, help='Output .onnx file path')
|
|
|
|
|
|
args = parser.parse_args()
|
|
|
|
|
|
# Output 경로가 없으면 Input 경로에서 확장자만 바꿔서 자동 지정
|
|
|
if args.output is None:
|
|
|
args.output = args.input.replace('.pt', '.onnx')
|
|
|
|
|
|
if not os.path.exists(args.input):
|
|
|
print(f"❌ Error: Input file not found: {args.input}")
|
|
|
else:
|
|
|
export_to_onnx(args.input, args.output) |