You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

87 lines
3.5 KiB

import torch
import os
from core import model_bak # 학습할 때 썼던 model 파일을 불러와야 합니다.
# ----------------------------
# 1. 설정 (경로 및 입력 사이즈)
# ----------------------------
# 사용자님이 알려주신 ckpt 경로
# ckpt_path = '/home/cuuva/face_exp/MobileFaceNet_Pytorch/model/CASIA_B512_v2_20251124_175829/best_model/best_104.ckpt'
ckpt_path = '/home/cuuva/face_exp/MobileFaceNet_Pytorch/model/MODEL_BAK20251127_171730/best_model/best_001.ckpt'
onnx_path = 'best_104.onnx' # 저장될 파일 이름
# [중요] 학습할 때 사용한 이미지 해상도와 일치해야 합니다.
# 아까 코드에서 128x128로 수정하신 것을 확인했으므로 128로 설정합니다.
input_size = (1, 3, 128, 128)
def convert():
print(f"Loading checkpoint from: {ckpt_path}")
# ----------------------------
# 2. 모델 구조 정의
# ----------------------------
# 학습 코드와 동일한 모델 클래스를 인스턴스화 합니다.
net = model_bak.MobileFacenet()
# ----------------------------
# 3. 가중치(Weights) 로드
# ----------------------------
checkpoint = torch.load(ckpt_path, map_location='cpu', weights_only=False) # GPU가 없어도 돌 수 있게 cpu로 로드
# 저장된 ckpt 구조에 따라 state_dict를 가져옵니다.
if 'net_state_dict' in checkpoint:
state_dict = checkpoint['net_state_dict']
else:
state_dict = checkpoint
# [핵심] DataParallel로 학습했다면 키(key) 앞에 'module.'이 붙어있습니다.
# 이를 제거해줘야 단일 모델에 로드할 수 있습니다.
new_state_dict = {}
for k, v in state_dict.items():
name = k.replace("module.", "") # 'module.conv1.weight' -> 'conv1.weight'
new_state_dict[name] = v
# 가중치 덮어씌우기
net.load_state_dict(new_state_dict)
# ----------------------------
# 4. 평가 모드 전환 (필수!)
# ----------------------------
# Dropout이나 Batch Norm이 학습 모드가 아닌 추론 모드로 동작하게 합니다.
net.eval()
# ----------------------------
# 5. ONNX Export
# ----------------------------
print("Exporting to ONNX...")
# 모델 추적(Trace)을 위한 더미 입력 데이터 생성
dummy_input = torch.randn(*input_size)
torch.onnx.export(
net, # 실행할 모델
dummy_input, # 더미 입력값
onnx_path, # 저장할 경로
verbose=True, # 변환 과정 로그 출력
input_names=['input'], # 입력 노드 이름 (나중에 추론할 때 씀)
output_names=['output'], # 출력 노드 이름
external_data=False
#opset_version=11 # ONNX 버전 (보통 11이나 12가 호환성이 좋음)
# batch size를 가변적으로 쓰고 싶다면 아래 dynamic_axes 사용 (고정하려면 주석 처리)
#dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}}
)
# torch.onnx.export(
# net,
# dummy_input,
# onnx_path,
# verbose=True,
# input_names=['input'],
# output_names=['output'],
# do_constant_folding=True, # 고정 상수 연산 미리 계산
# use_external_data_format=False # external .data 파일 없이 export
# )
print(f"Success! Model saved to: {os.path.abspath(onnx_path)}")
if __name__ == "__main__":
convert()