import torch import os from core import model # 학습할 때 썼던 model 파일을 불러와야 합니다. # ---------------------------- # 1. 설정 (경로 및 입력 사이즈) # ---------------------------- # 사용자님이 알려주신 ckpt 경로 ckpt_path = '/home/cuuva/face_exp/MobileFaceNet_Pytorch/model/MODEL_2_20251127_174006/best_model/best_004.ckpt' onnx_path = 'best_104.onnx' # 저장될 파일 이름 # [중요] 학습할 때 사용한 이미지 해상도와 일치해야 합니다. # 아까 코드에서 128x128로 수정하신 것을 확인했으므로 128로 설정합니다. input_size = (1, 3, 128, 128) def convert(): print(f"Loading checkpoint from: {ckpt_path}") # ---------------------------- # 2. 모델 구조 정의 # ---------------------------- # 학습 코드와 동일한 모델 클래스를 인스턴스화 합니다. net = model.MobileFacenet() # ---------------------------- # 3. 가중치(Weights) 로드 # ---------------------------- checkpoint = torch.load(ckpt_path, map_location='cpu', weights_only=False) # GPU가 없어도 돌 수 있게 cpu로 로드 # 저장된 ckpt 구조에 따라 state_dict를 가져옵니다. if 'net_state_dict' in checkpoint: state_dict = checkpoint['net_state_dict'] else: state_dict = checkpoint # [핵심] DataParallel로 학습했다면 키(key) 앞에 'module.'이 붙어있습니다. # 이를 제거해줘야 단일 모델에 로드할 수 있습니다. new_state_dict = {} for k, v in state_dict.items(): name = k.replace("module.", "") # 'module.conv1.weight' -> 'conv1.weight' new_state_dict[name] = v # 가중치 덮어씌우기 net.load_state_dict(new_state_dict) # ---------------------------- # 4. 평가 모드 전환 (필수!) # ---------------------------- # Dropout이나 Batch Norm이 학습 모드가 아닌 추론 모드로 동작하게 합니다. net.eval() # ---------------------------- # 4. ONNX 폴더 경로 생성 # ---------------------------- # ckpt_path 상위 폴더 이름 추출 experiment_folder_name = os.path.basename(os.path.dirname(os.path.dirname(ckpt_path))) # 모델 최상위 경로 model_root = '/home/cuuva/face_exp/MobileFaceNet_Pytorch/model' # 최종 ONNX 경로 onnx_dir = os.path.join(model_root, 'ONNX', experiment_folder_name) os.makedirs(onnx_dir, exist_ok=True) # ckpt 이름 기반으로 onnx 파일 이름 생성 onnx_name = os.path.splitext(os.path.basename(ckpt_path))[0] + '.onnx' onnx_path = os.path.join(onnx_dir, onnx_name) # ---------------------------- # 5. ONNX Export # ---------------------------- print("Exporting to ONNX...") # 모델 추적(Trace)을 위한 더미 입력 데이터 생성 dummy_input = torch.randn(*input_size) torch.onnx.export( net, # 실행할 모델 dummy_input, # 더미 입력값 onnx_path, # 저장할 경로 verbose=True, # 변환 과정 로그 출력 input_names=['input'], # 입력 노드 이름 (나중에 추론할 때 씀) output_names=['output'], # 출력 노드 이름 external_data=False ) print(f"Success! Model saved to: {os.path.abspath(onnx_path)}") if __name__ == "__main__": convert()