You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
87 lines
3.5 KiB
87 lines
3.5 KiB
import torch
|
|
import os
|
|
from core import model_bak # 학습할 때 썼던 model 파일을 불러와야 합니다.
|
|
|
|
# ----------------------------
|
|
# 1. 설정 (경로 및 입력 사이즈)
|
|
# ----------------------------
|
|
# 사용자님이 알려주신 ckpt 경로
|
|
# ckpt_path = '/home/cuuva/face_exp/MobileFaceNet_Pytorch/model/CASIA_B512_v2_20251124_175829/best_model/best_104.ckpt'
|
|
ckpt_path = '/home/cuuva/face_exp/MobileFaceNet_Pytorch/model/MODEL_BAK20251127_171730/best_model/best_001.ckpt'
|
|
onnx_path = 'best_104.onnx' # 저장될 파일 이름
|
|
|
|
# [중요] 학습할 때 사용한 이미지 해상도와 일치해야 합니다.
|
|
# 아까 코드에서 128x128로 수정하신 것을 확인했으므로 128로 설정합니다.
|
|
input_size = (1, 3, 128, 128)
|
|
|
|
def convert():
|
|
print(f"Loading checkpoint from: {ckpt_path}")
|
|
|
|
# ----------------------------
|
|
# 2. 모델 구조 정의
|
|
# ----------------------------
|
|
# 학습 코드와 동일한 모델 클래스를 인스턴스화 합니다.
|
|
net = model_bak.MobileFacenet()
|
|
|
|
# ----------------------------
|
|
# 3. 가중치(Weights) 로드
|
|
# ----------------------------
|
|
checkpoint = torch.load(ckpt_path, map_location='cpu', weights_only=False) # GPU가 없어도 돌 수 있게 cpu로 로드
|
|
|
|
# 저장된 ckpt 구조에 따라 state_dict를 가져옵니다.
|
|
if 'net_state_dict' in checkpoint:
|
|
state_dict = checkpoint['net_state_dict']
|
|
else:
|
|
state_dict = checkpoint
|
|
|
|
# [핵심] DataParallel로 학습했다면 키(key) 앞에 'module.'이 붙어있습니다.
|
|
# 이를 제거해줘야 단일 모델에 로드할 수 있습니다.
|
|
new_state_dict = {}
|
|
for k, v in state_dict.items():
|
|
name = k.replace("module.", "") # 'module.conv1.weight' -> 'conv1.weight'
|
|
new_state_dict[name] = v
|
|
|
|
# 가중치 덮어씌우기
|
|
net.load_state_dict(new_state_dict)
|
|
|
|
# ----------------------------
|
|
# 4. 평가 모드 전환 (필수!)
|
|
# ----------------------------
|
|
# Dropout이나 Batch Norm이 학습 모드가 아닌 추론 모드로 동작하게 합니다.
|
|
net.eval()
|
|
|
|
# ----------------------------
|
|
# 5. ONNX Export
|
|
# ----------------------------
|
|
print("Exporting to ONNX...")
|
|
|
|
# 모델 추적(Trace)을 위한 더미 입력 데이터 생성
|
|
dummy_input = torch.randn(*input_size)
|
|
|
|
torch.onnx.export(
|
|
net, # 실행할 모델
|
|
dummy_input, # 더미 입력값
|
|
onnx_path, # 저장할 경로
|
|
verbose=True, # 변환 과정 로그 출력
|
|
input_names=['input'], # 입력 노드 이름 (나중에 추론할 때 씀)
|
|
output_names=['output'], # 출력 노드 이름
|
|
external_data=False
|
|
#opset_version=11 # ONNX 버전 (보통 11이나 12가 호환성이 좋음)
|
|
# batch size를 가변적으로 쓰고 싶다면 아래 dynamic_axes 사용 (고정하려면 주석 처리)
|
|
#dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}}
|
|
)
|
|
# torch.onnx.export(
|
|
# net,
|
|
# dummy_input,
|
|
# onnx_path,
|
|
# verbose=True,
|
|
# input_names=['input'],
|
|
# output_names=['output'],
|
|
# do_constant_folding=True, # 고정 상수 연산 미리 계산
|
|
# use_external_data_format=False # external .data 파일 없이 export
|
|
# )
|
|
|
|
print(f"Success! Model saved to: {os.path.abspath(onnx_path)}")
|
|
|
|
if __name__ == "__main__":
|
|
convert() |