You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

202 lines
7.7 KiB

import torch
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.nn import DataParallel, CrossEntropyLoss
from dataloader.MyHF_loader import CASIA_HF, LFW_Pairs
from core import model_bak
from core.utils import init_log
import os, time, numpy as np, scipy.io
from datetime import datetime
from config import BATCH_SIZE, SAVE_FREQ, RESUME, SAVE_DIR, TEST_FREQ, TOTAL_EPOCH, MODEL_PRE, GPU
from sklearn.metrics.pairwise import cosine_similarity # [추가] 정확도 계산용
# ----------------------------
# Simple LFW verification-accuracy helper
# ----------------------------
def calculate_accuracy(featureLs, featureRs, flags, thresholds=None):
    """Return the best pair-verification accuracy over a sweep of thresholds.

    Row ``i`` of ``featureLs`` is compared with row ``i`` of ``featureRs`` by
    cosine similarity; a pair is predicted "same person" when its similarity
    is strictly greater than the threshold. The best threshold is chosen on
    the very pairs being scored, so the result is an optimistic estimate
    compared with the standard 10-fold LFW protocol.

    Args:
        featureLs: (N, D) array of left-hand embeddings.
        featureRs: (N, D) array of right-hand embeddings.
        flags: N ground-truth labels. Any value > 0 means "same person", so
            both the 1/0 and the 1/-1 label conventions are accepted. Extra
            singleton dimensions such as (N, 1) are flattened.
        thresholds: candidate similarity thresholds; defaults to
            ``np.arange(0, 1, 0.01)``.

    Returns:
        The best accuracy in [0, 1] as a float, or 0.0 when there are no
        pairs or no thresholds to evaluate.
    """
    if thresholds is None:
        thresholds = np.arange(0, 1, 0.01)
    thresholds = np.asarray(thresholds).ravel()
    featureLs = np.asarray(featureLs)
    featureRs = np.asarray(featureRs)
    # Flatten so (N, 1) labels do not broadcast to an (N, N) comparison, and
    # map labels to booleans so negatives encoded as 0 or -1 both work.
    same = np.asarray(flags).ravel() > 0

    # 1. L2-normalise each embedding; the floor keeps an all-zero vector from
    #    turning into NaN through a division by zero.
    eps = 1e-12
    featureLs = featureLs / np.maximum(
        np.linalg.norm(featureLs, axis=1, keepdims=True), eps)
    featureRs = featureRs / np.maximum(
        np.linalg.norm(featureRs, axis=1, keepdims=True), eps)

    # 2. Cosine similarity of each pair (dot product of unit vectors).
    scores = np.sum(featureLs * featureRs, axis=1)
    if scores.size == 0 or thresholds.size == 0:
        return 0.0

    # 3. Evaluate every threshold at once: preds[i, j] is True when pair j is
    #    predicted "same person" at thresholds[i].
    preds = scores[np.newaxis, :] > thresholds[:, np.newaxis]
    accs = np.mean(preds == same[np.newaxis, :], axis=1)
    return float(accs.max())
# ----------------------------
# GPU selection and run-directory setup (unchanged behaviour)
# ----------------------------
# GPU from config is either a single device index (int) or a collection of
# indices; the collection case turns on DataParallel later in the script.
multi_gpus = not isinstance(GPU, int)
gpu_list = ','.join(map(str, GPU)) if multi_gpus else str(GPU)
# CUDA_VISIBLE_DEVICES only takes effect if CUDA has not been initialised yet.
os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list

start_epoch = 1

# Every run writes into its own timestamped directory under SAVE_DIR.
run_stamp = datetime.now().strftime('%Y%m%d_%H%M%S')
save_dir = os.path.join(SAVE_DIR, 'MODEL_BAK_' + run_stamp)
os.makedirs(save_dir, exist_ok=True)
logging = init_log(save_dir)  # presumably a logger writing into save_dir — confirm in core.utils
_print = logging.info
# ----------------------------
# Dataloaders: CASIA (HF) for training, LFW pairs for evaluation
# ----------------------------
# Options shared by both loaders.
_loader_opts = {'num_workers': 8, 'drop_last': False}

trainset = CASIA_HF()
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=BATCH_SIZE, shuffle=True, **_loader_opts)

# Evaluation keeps the dataset order and uses a fixed batch size of 32.
testset = LFW_Pairs()
testloader = torch.utils.data.DataLoader(
    testset, batch_size=32, shuffle=False, **_loader_opts)
# ----------------------------
# Model, ArcFace head, loss, optimizer and LR schedule
# ----------------------------
net = model_bak.MobileFacenet()
ArcMargin = model_bak.ArcMarginProduct(128, trainset.dataset.features['label'].num_classes)
if RESUME:
    # Load onto the CPU first so a checkpoint written on any GPU can be
    # restored; the modules are moved to the GPU right below.
    ckpt = torch.load(RESUME, map_location='cpu')
    net.load_state_dict(ckpt['net_state_dict'])
    # Restore the ArcFace head when the checkpoint carries it; checkpoints
    # that hold only the backbone leave the head freshly initialised.
    if 'arc_state_dict' in ckpt:
        ArcMargin.load_state_dict(ckpt['arc_state_dict'])
    start_epoch = ckpt['epoch'] + 1
net = net.cuda()
ArcMargin = ArcMargin.cuda()
if multi_gpus:
    net = DataParallel(net)
    ArcMargin = DataParallel(ArcMargin)
criterion = CrossEntropyLoss()

# DataParallel does not forward attribute access (net.linear1,
# ArcMargin.weight) to the wrapped module, so go through .module when wrapped.
# The parameter objects are shared, so the optimizer still updates the
# wrapped models.
net_core = net.module if multi_gpus else net
arc_core = ArcMargin.module if multi_gpus else ArcMargin

# linear1 gets its own (stronger) weight decay; every other backbone parameter
# goes into the base group.
linear1_ids = {id(p) for p in net_core.linear1.parameters()}
base_params = [p for p in net_core.parameters() if id(p) not in linear1_ids]
# PReLU was removed from the backbone architecture, so there is no separate
# zero-weight-decay PReLU parameter group any more.
optimizer_ft = optim.SGD([
    {'params': base_params, 'weight_decay': 4e-5},
    {'params': list(net_core.linear1.parameters()), 'weight_decay': 4e-4},
    {'params': [arc_core.weight], 'weight_decay': 4e-4},
], lr=0.1, momentum=0.9, nesterov=True)

# NOTE: milestones are absolute epoch counts; update them together with
# TOTAL_EPOCH in config.
exp_lr_scheduler = lr_scheduler.MultiStepLR(optimizer_ft, milestones=[240, 310, 400], gamma=0.1)
# The training loop steps the scheduler once per finished epoch. When resuming,
# replay the steps already taken so the LR matches epoch `start_epoch` instead
# of restarting the schedule. (PyTorch may warn that scheduler.step() ran
# before optimizer.step(); that is expected here.)
for _ in range(start_epoch - 1):
    exp_lr_scheduler.step()
# ----------------------------
# Best LFW accuracy seen so far (drives best-model checkpointing)
# ----------------------------
best_lfw_acc = 0.0


def _extract_lfw_features(model, loader):
    """Embed every LFW pair from *loader* and collect the pair labels.

    Each batch is read as ``batch[0]`` = list of four image tensors and
    ``batch[1]`` = pair labels (assumed 1 = same person, 0 = different —
    TODO confirm against LFW_Pairs). Embeddings of images 0 and 1 are
    concatenated into the left feature, images 2 and 3 into the right one.

    Returns:
        (featureLs, featureRs, flags) as numpy arrays, concatenated over all
        batches in one step to avoid the quadratic cost of growing arrays.
    """
    left_parts, right_parts, flag_parts = [], [], []
    with torch.no_grad():  # no gradients needed at test time (saves memory)
        for batch in loader:
            imgs = [d.cuda() for d in batch[0]]
            flag_parts.append(batch[1].numpy())
            res = [model(d).cpu().numpy() for d in imgs]
            left_parts.append(np.concatenate((res[0], res[1]), 1))
            right_parts.append(np.concatenate((res[2], res[3]), 1))
    return (np.concatenate(left_parts, 0),
            np.concatenate(right_parts, 0),
            np.concatenate(flag_parts, 0))


# ----------------------------
# Training loop
# ----------------------------
for epoch in range(start_epoch, TOTAL_EPOCH + 1):
    net.train()
    train_total_loss, total = 0, 0
    since = time.time()
    _print(f"Train Epoch: {epoch}/{TOTAL_EPOCH} ...")
    for data in trainloader:
        img, label = data[0].cuda(), data[1].cuda()
        optimizer_ft.zero_grad()
        raw_logits = net(img)
        output = ArcMargin(raw_logits, label)
        loss = criterion(output, label)
        loss.backward()
        optimizer_ft.step()
        # Weight by batch size so the epoch mean is per-sample even when the
        # last batch is smaller.
        train_total_loss += loss.item() * img.size(0)
        total += img.size(0)
    train_total_loss /= total
    time_elapsed = time.time() - since
    _print(f" total_loss: {train_total_loss:.4f} time: {time_elapsed//60:.0f}m {time_elapsed%60:.0f}s")
    exp_lr_scheduler.step()

    # ----------------------------
    # LFW evaluation & best-model save (selected by accuracy, not loss)
    # ----------------------------
    # NOTE(review): the model is selected on LFW itself, and
    # calculate_accuracy also tunes its threshold there, so the reported best
    # accuracy is optimistic.
    if epoch % TEST_FREQ == 0:
        net.eval()
        _print(" Testing LFW...")
        featureLs, featureRs, flags = _extract_lfw_features(net, testloader)
        current_acc = calculate_accuracy(featureLs, featureRs, flags)
        _print(f" LFW Acc: {current_acc*100:.2f}% (Best: {best_lfw_acc*100:.2f}%)")
        if current_acc > best_lfw_acc:
            best_lfw_acc = current_acc
            state_dict = net.module.state_dict() if multi_gpus else net.state_dict()
            best_dir = os.path.join(save_dir, 'best_model')
            os.makedirs(best_dir, exist_ok=True)
            best_path = os.path.join(best_dir, f'best_{epoch:03d}.ckpt')
            torch.save(
                {
                    'epoch': epoch,
                    'net_state_dict': state_dict,
                    'acc': best_lfw_acc
                },
                best_path
            )
            _print(f" ==> Best Model Saved! (Acc: {best_lfw_acc*100:.2f}%, Epoch: {epoch})")

    # ----------------------------
    # Regular backup checkpoint
    # ----------------------------
    if epoch % SAVE_FREQ == 0:
        state_dict = net.module.state_dict() if multi_gpus else net.state_dict()
        # Store the ArcFace head too so a resumed run does not restart it from
        # random weights.
        arc_state_dict = ArcMargin.module.state_dict() if multi_gpus else ArcMargin.state_dict()
        torch.save({'epoch': epoch,
                    'net_state_dict': state_dict,
                    'arc_state_dict': arc_state_dict},
                   os.path.join(save_dir, f'{epoch:03d}.ckpt'))
_print("finishing training")