You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
87 lines
3.4 KiB
87 lines
3.4 KiB
import torch
|
|
import os
|
|
from core import model2 # 학습할 때 썼던 model 파일을 불러와야 합니다.
|
|
|
|
# ----------------------------
|
|
# 1. 설정 (경로 및 입력 사이즈)
|
|
# ----------------------------
|
|
# 사용자님이 알려주신 ckpt 경로
|
|
ckpt_path = '/home/cuuva/face_exp/MobileFaceNet_Pytorch/model/MODEL_2_20251127_174006/best_model/best_004.ckpt'
|
|
onnx_path = 'best_104.onnx' # 저장될 파일 이름
|
|
|
|
# [중요] 학습할 때 사용한 이미지 해상도와 일치해야 합니다.
|
|
# 아까 코드에서 128x128로 수정하신 것을 확인했으므로 128로 설정합니다.
|
|
input_size = (1, 3, 128, 128)
|
|
|
|
def convert():
|
|
print(f"Loading checkpoint from: {ckpt_path}")
|
|
|
|
# ----------------------------
|
|
# 2. 모델 구조 정의
|
|
# ----------------------------
|
|
# 학습 코드와 동일한 모델 클래스를 인스턴스화 합니다.
|
|
net = model2.MobileFacenet()
|
|
|
|
# ----------------------------
|
|
# 3. 가중치(Weights) 로드
|
|
# ----------------------------
|
|
checkpoint = torch.load(ckpt_path, map_location='cpu', weights_only=False) # GPU가 없어도 돌 수 있게 cpu로 로드
|
|
|
|
# 저장된 ckpt 구조에 따라 state_dict를 가져옵니다.
|
|
if 'net_state_dict' in checkpoint:
|
|
state_dict = checkpoint['net_state_dict']
|
|
else:
|
|
state_dict = checkpoint
|
|
|
|
# [핵심] DataParallel로 학습했다면 키(key) 앞에 'module.'이 붙어있습니다.
|
|
# 이를 제거해줘야 단일 모델에 로드할 수 있습니다.
|
|
new_state_dict = {}
|
|
for k, v in state_dict.items():
|
|
name = k.replace("module.", "") # 'module.conv1.weight' -> 'conv1.weight'
|
|
new_state_dict[name] = v
|
|
|
|
# 가중치 덮어씌우기
|
|
net.load_state_dict(new_state_dict)
|
|
|
|
# ----------------------------
|
|
# 4. 평가 모드 전환 (필수!)
|
|
# ----------------------------
|
|
# Dropout이나 Batch Norm이 학습 모드가 아닌 추론 모드로 동작하게 합니다.
|
|
net.eval()
|
|
# ----------------------------
|
|
# 4. ONNX 폴더 경로 생성
|
|
# ----------------------------
|
|
# ckpt_path 상위 폴더 이름 추출
|
|
experiment_folder_name = os.path.basename(os.path.dirname(os.path.dirname(ckpt_path)))
|
|
# 모델 최상위 경로
|
|
model_root = '/home/cuuva/face_exp/MobileFaceNet_Pytorch/model'
|
|
# 최종 ONNX 경로
|
|
onnx_dir = os.path.join(model_root, 'ONNX', experiment_folder_name)
|
|
os.makedirs(onnx_dir, exist_ok=True)
|
|
|
|
# ckpt 이름 기반으로 onnx 파일 이름 생성
|
|
onnx_name = os.path.splitext(os.path.basename(ckpt_path))[0] + '.onnx'
|
|
onnx_path = os.path.join(onnx_dir, onnx_name)
|
|
|
|
# ----------------------------
|
|
# 5. ONNX Export
|
|
# ----------------------------
|
|
print("Exporting to ONNX...")
|
|
|
|
# 모델 추적(Trace)을 위한 더미 입력 데이터 생성
|
|
dummy_input = torch.randn(*input_size)
|
|
|
|
torch.onnx.export(
|
|
net, # 실행할 모델
|
|
dummy_input, # 더미 입력값
|
|
onnx_path, # 저장할 경로
|
|
verbose=True, # 변환 과정 로그 출력
|
|
input_names=['input'], # 입력 노드 이름 (나중에 추론할 때 씀)
|
|
output_names=['output'], # 출력 노드 이름
|
|
external_data=False
|
|
)
|
|
|
|
print(f"Success! Model saved to: {os.path.abspath(onnx_path)}")
|
|
|
|
if __name__ == "__main__":
|
|
convert() |