You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

88 lines
3.5 KiB

import torch
import os
from core import model_song # 학습할 때 썼던 model 파일을 불러와야 합니다.
# ----------------------------
# 1. 설정 (경로 및 입력 사이즈)
# ----------------------------
# 사용자님이 알려주신 ckpt 경로
ckpt_path = '/home/cuuva/face_exp/MobileFaceNet_Pytorch/result(MODEL_SONG)/MODEL_2_20251203_173900/best_model/best_001.ckpt'
onnx_path = 'model_2_test.onnx' # 저장될 파일 이름
# [중요] 학습할 때 사용한 이미지 해상도와 일치해야 합니다.
# 아까 코드에서 128x128로 수정하신 것을 확인했으므로 128로 설정합니다.
input_size = (1, 3, 128, 128)
def convert():
print(f"Loading checkpoint from: {ckpt_path}")
# ----------------------------
# 2. 모델 구조 정의
# ----------------------------
# 학습 코드와 동일한 모델 클래스를 인스턴스화 합니다.
net = model_song.MobileFacenet()
# ----------------------------
# 3. 가중치(Weights) 로드
# ----------------------------
checkpoint = torch.load(ckpt_path, map_location='cpu', weights_only=False) # GPU가 없어도 돌 수 있게 cpu로 로드
# 저장된 ckpt 구조에 따라 state_dict를 가져옵니다.
if 'net_state_dict' in checkpoint:
state_dict = checkpoint['net_state_dict']
else:
state_dict = checkpoint
# [핵심] DataParallel로 학습했다면 키(key) 앞에 'module.'이 붙어있습니다.
# 이를 제거해줘야 단일 모델에 로드할 수 있습니다.
new_state_dict = {}
for k, v in state_dict.items():
name = k.replace("module.", "") # 'module.conv1.weight' -> 'conv1.weight'
new_state_dict[name] = v
# 가중치 덮어씌우기
net.load_state_dict(new_state_dict)
# ----------------------------
# 4. 평가 모드 전환 (필수!)
# ----------------------------
# Dropout이나 Batch Norm이 학습 모드가 아닌 추론 모드로 동작하게 합니다.
net.eval()
# ----------------------------
# 4. ONNX 폴더 경로 생성
# ----------------------------
# ckpt_path 상위 폴더 이름 추출
experiment_folder_name = os.path.basename(os.path.dirname(os.path.dirname(ckpt_path)))
# 모델 최상위 경로
model_root = '/home/cuuva/face_exp/MobileFaceNet_Pytorch/result(MODEL_SONG)'
# 최종 ONNX 경로
onnx_dir = os.path.join(model_root, 'ONNX', experiment_folder_name)
os.makedirs(onnx_dir, exist_ok=True)
# ckpt 이름 기반으로 onnx 파일 이름 생성
# onnx_name = os.path.splitext(os.path.basename(ckpt_path))[0] + '.onnx'
onnx_name = 'model_2_test.onnx'
onnx_path = os.path.join(onnx_dir, onnx_name)
# ----------------------------
# 5. ONNX Export
# ----------------------------
print("Exporting to ONNX...")
# 모델 추적(Trace)을 위한 더미 입력 데이터 생성
dummy_input = torch.randn(*input_size)
torch.onnx.export(
net, # 실행할 모델
dummy_input, # 더미 입력값
onnx_path, # 저장할 경로
verbose=True, # 변환 과정 로그 출력
input_names=['input'], # 입력 노드 이름 (나중에 추론할 때 씀)
output_names=['output'], # 출력 노드 이름
external_data=False
)
print(f"Success! Model saved to: {os.path.abspath(onnx_path)}")
if __name__ == "__main__":
convert()