import torch import os from core import model # 학습할 때 썼던 model 파일을 불러와야 합니다. # ---------------------------- # 1. 설정 (경로 및 입력 사이즈) # ---------------------------- # 사용자님이 알려주신 ckpt 경로 # ckpt_path = '/home/cuuva/face_exp/MobileFaceNet_Pytorch/model/CASIA_B512_v2_20251124_175829/best_model/best_104.ckpt' ckpt_path = '/home/cuuva/face_exp/MobileFaceNet_Pytorch/model/CASIA_B512_v2_20251126_173236/best_model/best_063.ckpt' onnx_path = 'best_104.onnx' # 저장될 파일 이름 # [중요] 학습할 때 사용한 이미지 해상도와 일치해야 합니다. # 아까 코드에서 128x128로 수정하신 것을 확인했으므로 128로 설정합니다. input_size = (1, 3, 128, 128) def convert(): print(f"Loading checkpoint from: {ckpt_path}") # ---------------------------- # 2. 모델 구조 정의 # ---------------------------- # 학습 코드와 동일한 모델 클래스를 인스턴스화 합니다. net = model.MobileFacenet() # ---------------------------- # 3. 가중치(Weights) 로드 # ---------------------------- checkpoint = torch.load(ckpt_path, map_location='cpu', weights_only=False) # GPU가 없어도 돌 수 있게 cpu로 로드 # 저장된 ckpt 구조에 따라 state_dict를 가져옵니다. if 'net_state_dict' in checkpoint: state_dict = checkpoint['net_state_dict'] else: state_dict = checkpoint # [핵심] DataParallel로 학습했다면 키(key) 앞에 'module.'이 붙어있습니다. # 이를 제거해줘야 단일 모델에 로드할 수 있습니다. new_state_dict = {} for k, v in state_dict.items(): name = k.replace("module.", "") # 'module.conv1.weight' -> 'conv1.weight' new_state_dict[name] = v # 가중치 덮어씌우기 net.load_state_dict(new_state_dict) # ---------------------------- # 4. 평가 모드 전환 (필수!) # ---------------------------- # Dropout이나 Batch Norm이 학습 모드가 아닌 추론 모드로 동작하게 합니다. net.eval() # ---------------------------- # 5. ONNX Export # ---------------------------- print("Exporting to ONNX...") # 모델 추적(Trace)을 위한 더미 입력 데이터 생성 dummy_input = torch.randn(*input_size) torch.onnx.export( net, # 실행할 모델 dummy_input, # 더미 입력값 onnx_path, # 저장할 경로 verbose=True, # 변환 과정 로그 출력 input_names=['input'], # 입력 노드 이름 (나중에 추론할 때 씀) output_names=['output'], # 출력 노드 이름 external_data=False #opset_version=11 # ONNX 버전 (보통 11이나 12가 호환성이 좋음) # batch size를 가변적으로 쓰고 싶다면 아래 dynamic_axes 사용 (고정하려면 주석 처리) #dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}} ) # torch.onnx.export( # net, # dummy_input, # onnx_path, # verbose=True, # input_names=['input'], # output_names=['output'], # do_constant_folding=True, # 고정 상수 연산 미리 계산 # use_external_data_format=False # external .data 파일 없이 export # ) print(f"Success! Model saved to: {os.path.abspath(onnx_path)}") if __name__ == "__main__": convert()