parent
dda56b30e5
commit
745d7516b3
@ -1,2 +1,55 @@
|
||||
# Face_recognition
|
||||
# MobileFaceNet
|
||||
|
||||
## Introduction
|
||||
* This repository is the pytorch implement of the paper: [MobileFaceNets: Efficient CNNs for Accurate Real-Time Face Verification on Mobile Devices](https://arxiv.org/pdf/1804.07573.pdf) and I almost follow the implement details of the paper.
|
||||
* I train the model on CASIA-WebFace dataset, and evaluate on LFW dataset.
|
||||
|
||||
## Requirements
|
||||
|
||||
* Python 3.5
|
||||
* pytorch 0.4+
|
||||
* GPU memory
|
||||
|
||||
## Usage
|
||||
|
||||
### Part 1: Preprocessing
|
||||
|
||||
* All images of dataset are preprocessed following the [SphereFace](https://github.com/wy1iu/sphereface) and you can download the aligned images at [Align-CASIA-WebFace@BaiduDrive](https://pan.baidu.com/s/1k3Cel2wSHQxHO9NkNi3rkg) and [Align-LFW@BaiduDrive](https://pan.baidu.com/s/1r6BQxzlFza8FM8Z8C_OCBg).
|
||||
|
||||
### Part 2: Train
|
||||
|
||||
1. Change the **CAISIA_DATA_DIR** and **LFW_DATA_DAR** in `config.py` to your data path.
|
||||
|
||||
2. Train the mobilefacenet model.
|
||||
|
||||
**Note:** The default settings set the batch size of 512, use 2 gpus and train the model on 70 epochs. You can change the settings in `config.py`
|
||||
```
|
||||
python3 train.py
|
||||
```
|
||||
|
||||
### Part 3: Test
|
||||
|
||||
1. Test the model on LFW.
|
||||
|
||||
**Note:** I have tested `lfw_eval.py` on the caffe model at [SphereFace](https://github.com/wy1iu/sphereface), it gets the same result.
|
||||
|
||||
```
|
||||
python3 lfw_eval.py --resume --feature_save_dir
|
||||
```
|
||||
* `--resume:` path of saved model
|
||||
* `--feature_save_dir:` path to save the extracted features (must be .mat file)
|
||||
|
||||
## Results
|
||||
|
||||
* You can just run the `lfw_eval.py` to get the result, the accuracy on LFW like this:
|
||||
|
||||
| Fold | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | AVE(ours) | Paper(112x96) |
|
||||
| ------ |------|------|------|------|------|------|------|------|------|------| ------ | ------ |
|
||||
| ACC | 99.00 | 99.00 | 99.00 | 98.67 | 99.33 | 99.67 | 99.17 | 99.50 | 100.00 | 99.67| **99.30** | 99.18 |
|
||||
|
||||
|
||||
## Reference resources
|
||||
|
||||
* [arcface-pytorch](https://github.com/ronghuaiyang/arcface-pytorch)
|
||||
* [SphereFace](https://github.com/wy1iu/sphereface)
|
||||
* [Insightface](https://github.com/deepinsight/insightface)
|
||||
|
||||
@ -0,0 +1,243 @@
|
||||
from torch import nn
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.autograd import Variable
|
||||
import math
|
||||
from torch.nn import Parameter
|
||||
|
||||
class Bottleneck(nn.Module):
|
||||
def __init__(self, inp, oup, stride, expansion):
|
||||
super(Bottleneck, self).__init__()
|
||||
self.connect = stride == 1 and inp == oup
|
||||
self.conv = nn.Sequential(
|
||||
#pw
|
||||
nn.Conv2d(inp, inp * expansion, 1, 1, 0, bias=False),
|
||||
nn.BatchNorm2d(inp * expansion),
|
||||
nn.ReLU(inplace=True),
|
||||
|
||||
#dw
|
||||
nn.Conv2d(inp * expansion, inp * expansion, 3, stride, 1, groups=inp * expansion, bias=False),
|
||||
nn.BatchNorm2d(inp * expansion),
|
||||
nn.ReLU(inplace=True),
|
||||
|
||||
#pw-linear
|
||||
nn.Conv2d(inp * expansion, oup, 1, 1, 0, bias=False),
|
||||
nn.BatchNorm2d(oup),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
if self.connect:
|
||||
return x + self.conv(x)
|
||||
else:
|
||||
return self.conv(x)
|
||||
|
||||
# class ConvBlock(nn.Module): # prelu 버전
|
||||
# def __init__(self, inp, oup, k, s, p, dw=False, linear=False):
|
||||
# super(ConvBlock, self).__init__()
|
||||
# self.linear = linear
|
||||
# if dw:
|
||||
# self.conv = nn.Conv2d(inp, oup, k, s, p, groups=inp, bias=False)
|
||||
# else:
|
||||
# self.conv = nn.Conv2d(inp, oup, k, s, p, bias=False)
|
||||
# self.bn = nn.BatchNorm2d(oup)
|
||||
# if not linear:
|
||||
# self.prelu = nn.PReLU(oup)
|
||||
# def forward(self, x):
|
||||
# x = self.conv(x)
|
||||
# x = self.bn(x)
|
||||
# if self.linear:
|
||||
# return x
|
||||
# else:
|
||||
# return self.prelu(x)
|
||||
|
||||
class ConvBlock(nn.Module):
|
||||
def __init__(self, inp, oup, k, s, p, dw=False, linear=False):
|
||||
super(ConvBlock, self).__init__()
|
||||
self.linear = linear
|
||||
if dw:
|
||||
self.conv = nn.Conv2d(inp, oup, k, s, p, groups=inp, bias=False)
|
||||
else:
|
||||
self.conv = nn.Conv2d(inp, oup, k, s, p, bias=False)
|
||||
self.bn = nn.BatchNorm2d(oup)
|
||||
if not linear:
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
x = self.bn(x)
|
||||
if self.linear:
|
||||
return x
|
||||
else:
|
||||
return self.relu(x)
|
||||
|
||||
Mobilefacenet_bottleneck_setting = [
|
||||
# t, c , n ,s
|
||||
[2, 64, 5, 2],
|
||||
[4, 128, 1, 2],
|
||||
[2, 128, 6, 1],
|
||||
[4, 128, 1, 2],
|
||||
[2, 128, 2, 1]
|
||||
]
|
||||
|
||||
Mobilenetv2_bottleneck_setting = [
|
||||
# t, c, n, s
|
||||
[1, 16, 1, 1],
|
||||
[6, 24, 2, 2],
|
||||
[6, 32, 3, 2],
|
||||
[6, 64, 4, 2],
|
||||
[6, 96, 3, 1],
|
||||
[6, 160, 3, 2],
|
||||
[6, 320, 1, 1],
|
||||
]
|
||||
|
||||
class MobileFacenet(nn.Module):
|
||||
def __init__(self, bottleneck_setting=Mobilefacenet_bottleneck_setting):
|
||||
super(MobileFacenet, self).__init__()
|
||||
|
||||
self.conv1 = ConvBlock(3, 64, 3, 2, 1)
|
||||
|
||||
self.dw_conv1 = ConvBlock(64, 64, 3, 1, 1, dw=True)
|
||||
|
||||
self.inplanes = 64
|
||||
block = Bottleneck
|
||||
self.blocks = self._make_layer(block, bottleneck_setting)
|
||||
|
||||
self.conv2 = ConvBlock(128, 512, 1, 1, 0)
|
||||
|
||||
self.linear7 = nn.Sequential(
|
||||
nn.Conv2d(512, 512, kernel_size=1, stride=1, padding=0, bias=False),
|
||||
nn.BatchNorm2d(512),
|
||||
nn.ReLU(inplace=True),
|
||||
)
|
||||
|
||||
self.gap = nn.AvgPool2d(kernel_size=7)
|
||||
|
||||
self.linear1 = ConvBlock(512, 128, 1, 1, 0, linear=True)
|
||||
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
||||
m.weight.data.normal_(0, math.sqrt(2. / n))
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
m.weight.data.fill_(1)
|
||||
m.bias.data.zero_()
|
||||
|
||||
def _make_layer(self, block, setting):
|
||||
layers = []
|
||||
for t, c, n, s in setting:
|
||||
for i in range(n):
|
||||
if i == 0:
|
||||
layers.append(block(self.inplanes, c, s, t))
|
||||
else:
|
||||
layers.append(block(self.inplanes, c, 1, t))
|
||||
self.inplanes = c
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
x = x[:, :, 8:120, 8:120]
|
||||
x = self.conv1(x)
|
||||
x = self.dw_conv1(x)
|
||||
x = self.blocks(x)
|
||||
x = self.conv2(x)
|
||||
x = self.linear7(x)
|
||||
x = self.gap(x)
|
||||
x = self.linear1(x)
|
||||
x = x.view(x.size(0), -1)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
|
||||
# class MobileFacenet(nn.Module):
|
||||
# def __init__(self, bottleneck_setting=Mobilefacenet_bottleneck_setting):
|
||||
# super(MobileFacenet, self).__init__()
|
||||
|
||||
# self.conv1 = ConvBlock(3, 64, 3, 2, 1)
|
||||
|
||||
# self.dw_conv1 = ConvBlock(64, 64, 3, 1, 1, dw=True)
|
||||
|
||||
# self.inplanes = 64
|
||||
# block = Bottleneck
|
||||
# self.blocks = self._make_layer(block, bottleneck_setting)
|
||||
|
||||
# self.conv2 = ConvBlock(128, 512, 1, 1, 0)
|
||||
|
||||
# self.linear7 = ConvBlock(512, 512, (7, 6), 1, 0, dw=True, linear=True)
|
||||
|
||||
# self.linear1 = ConvBlock(512, 128, 1, 1, 0, linear=True)
|
||||
|
||||
# for m in self.modules():
|
||||
# if isinstance(m, nn.Conv2d):
|
||||
# n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
||||
# m.weight.data.normal_(0, math.sqrt(2. / n))
|
||||
# elif isinstance(m, nn.BatchNorm2d):
|
||||
# m.weight.data.fill_(1)
|
||||
# m.bias.data.zero_()
|
||||
|
||||
# def _make_layer(self, block, setting):
|
||||
# layers = []
|
||||
# for t, c, n, s in setting:
|
||||
# for i in range(n):
|
||||
# if i == 0:
|
||||
# layers.append(block(self.inplanes, c, s, t))
|
||||
# else:
|
||||
# layers.append(block(self.inplanes, c, 1, t))
|
||||
# self.inplanes = c
|
||||
|
||||
# return nn.Sequential(*layers)
|
||||
|
||||
# def forward(self, x):
|
||||
# x = self.conv1(x)
|
||||
# x = self.dw_conv1(x)
|
||||
# x = self.blocks(x)
|
||||
# x = self.conv2(x)
|
||||
# x = self.linear7(x)
|
||||
# x = self.linear1(x)
|
||||
# x = x.view(x.size(0), -1)
|
||||
|
||||
# return x
|
||||
|
||||
|
||||
class ArcMarginProduct(nn.Module):
|
||||
def __init__(self, in_features=128, out_features=200, s=32.0, m=0.50, easy_margin=False):
|
||||
super(ArcMarginProduct, self).__init__()
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
self.s = s
|
||||
self.m = m
|
||||
self.weight = Parameter(torch.Tensor(out_features, in_features))
|
||||
nn.init.xavier_uniform_(self.weight)
|
||||
# init.kaiming_uniform_()
|
||||
# self.weight.data.normal_(std=0.001)
|
||||
|
||||
self.easy_margin = easy_margin
|
||||
self.cos_m = math.cos(m)
|
||||
self.sin_m = math.sin(m)
|
||||
# make the function cos(theta+m) monotonic decreasing while theta in [0°,180°]
|
||||
self.th = math.cos(math.pi - m)
|
||||
self.mm = math.sin(math.pi - m) * m
|
||||
|
||||
def forward(self, x, label):
|
||||
cosine = F.linear(F.normalize(x), F.normalize(self.weight))
|
||||
sine = torch.sqrt(1.0 - torch.pow(cosine, 2))
|
||||
phi = cosine * self.cos_m - sine * self.sin_m
|
||||
if self.easy_margin:
|
||||
phi = torch.where(cosine > 0, phi, cosine)
|
||||
else:
|
||||
phi = torch.where((cosine - self.th) > 0, phi, cosine - self.mm)
|
||||
|
||||
one_hot = torch.zeros(cosine.size(), device='cuda')
|
||||
one_hot.scatter_(1, label.view(-1, 1).long(), 1)
|
||||
output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
|
||||
output *= self.s
|
||||
return output
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# input = Variable(torch.FloatTensor(2, 3, 112, 96))
|
||||
input = Variable(torch.FloatTensor(2, 3, 128, 128)) # 해상도 128x128 수정 진행.
|
||||
net = MobileFacenet()
|
||||
print(net)
|
||||
x = net(input)
|
||||
print(x.shape)
|
||||
@ -0,0 +1,88 @@
|
||||
import torch
|
||||
import os
|
||||
from core import model_song # 학습할 때 썼던 model 파일을 불러와야 합니다.
|
||||
|
||||
# ----------------------------
|
||||
# 1. 설정 (경로 및 입력 사이즈)
|
||||
# ----------------------------
|
||||
# 사용자님이 알려주신 ckpt 경로
|
||||
ckpt_path = '/home/cuuva/face_exp/MobileFaceNet_Pytorch/result(MODEL_SONG)/MODEL_2_20251203_173900/best_model/best_001.ckpt'
|
||||
onnx_path = 'model_2_test.onnx' # 저장될 파일 이름
|
||||
|
||||
# [중요] 학습할 때 사용한 이미지 해상도와 일치해야 합니다.
|
||||
# 아까 코드에서 128x128로 수정하신 것을 확인했으므로 128로 설정합니다.
|
||||
input_size = (1, 3, 128, 128)
|
||||
|
||||
def convert():
|
||||
print(f"Loading checkpoint from: {ckpt_path}")
|
||||
|
||||
# ----------------------------
|
||||
# 2. 모델 구조 정의
|
||||
# ----------------------------
|
||||
# 학습 코드와 동일한 모델 클래스를 인스턴스화 합니다.
|
||||
net = model_song.MobileFacenet()
|
||||
|
||||
# ----------------------------
|
||||
# 3. 가중치(Weights) 로드
|
||||
# ----------------------------
|
||||
checkpoint = torch.load(ckpt_path, map_location='cpu', weights_only=False) # GPU가 없어도 돌 수 있게 cpu로 로드
|
||||
|
||||
# 저장된 ckpt 구조에 따라 state_dict를 가져옵니다.
|
||||
if 'net_state_dict' in checkpoint:
|
||||
state_dict = checkpoint['net_state_dict']
|
||||
else:
|
||||
state_dict = checkpoint
|
||||
|
||||
# [핵심] DataParallel로 학습했다면 키(key) 앞에 'module.'이 붙어있습니다.
|
||||
# 이를 제거해줘야 단일 모델에 로드할 수 있습니다.
|
||||
new_state_dict = {}
|
||||
for k, v in state_dict.items():
|
||||
name = k.replace("module.", "") # 'module.conv1.weight' -> 'conv1.weight'
|
||||
new_state_dict[name] = v
|
||||
|
||||
# 가중치 덮어씌우기
|
||||
net.load_state_dict(new_state_dict)
|
||||
|
||||
# ----------------------------
|
||||
# 4. 평가 모드 전환 (필수!)
|
||||
# ----------------------------
|
||||
# Dropout이나 Batch Norm이 학습 모드가 아닌 추론 모드로 동작하게 합니다.
|
||||
net.eval()
|
||||
# ----------------------------
|
||||
# 4. ONNX 폴더 경로 생성
|
||||
# ----------------------------
|
||||
# ckpt_path 상위 폴더 이름 추출
|
||||
experiment_folder_name = os.path.basename(os.path.dirname(os.path.dirname(ckpt_path)))
|
||||
# 모델 최상위 경로
|
||||
model_root = '/home/cuuva/face_exp/MobileFaceNet_Pytorch/result(MODEL_SONG)'
|
||||
# 최종 ONNX 경로
|
||||
onnx_dir = os.path.join(model_root, 'ONNX', experiment_folder_name)
|
||||
os.makedirs(onnx_dir, exist_ok=True)
|
||||
|
||||
# ckpt 이름 기반으로 onnx 파일 이름 생성
|
||||
# onnx_name = os.path.splitext(os.path.basename(ckpt_path))[0] + '.onnx'
|
||||
onnx_name = 'model_2_test.onnx'
|
||||
onnx_path = os.path.join(onnx_dir, onnx_name)
|
||||
|
||||
# ----------------------------
|
||||
# 5. ONNX Export
|
||||
# ----------------------------
|
||||
print("Exporting to ONNX...")
|
||||
|
||||
# 모델 추적(Trace)을 위한 더미 입력 데이터 생성
|
||||
dummy_input = torch.randn(*input_size)
|
||||
|
||||
torch.onnx.export(
|
||||
net, # 실행할 모델
|
||||
dummy_input, # 더미 입력값
|
||||
onnx_path, # 저장할 경로
|
||||
verbose=True, # 변환 과정 로그 출력
|
||||
input_names=['input'], # 입력 노드 이름 (나중에 추론할 때 씀)
|
||||
output_names=['output'], # 출력 노드 이름
|
||||
external_data=False
|
||||
)
|
||||
|
||||
print(f"Success! Model saved to: {os.path.abspath(onnx_path)}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
convert()
|
||||
@ -0,0 +1,201 @@
|
||||
import torch
|
||||
import torch.optim as optim
|
||||
from torch.optim import lr_scheduler
|
||||
from torch.nn import DataParallel, CrossEntropyLoss
|
||||
from dataloader.MyHF_loader import CASIA_HF, LFW_Pairs
|
||||
from core import model_song
|
||||
from core.utils import init_log
|
||||
import os, time, numpy as np, scipy.io
|
||||
from datetime import datetime
|
||||
from config import BATCH_SIZE, SAVE_FREQ, RESUME, SAVE_DIR, TEST_FREQ, TOTAL_EPOCH, MODEL_PRE, GPU
|
||||
from sklearn.metrics.pairwise import cosine_similarity # [추가] 정확도 계산용
|
||||
|
||||
# ----------------------------
|
||||
# [추가] 간단한 LFW 정확도 계산 함수
|
||||
# ----------------------------
|
||||
def calculate_accuracy(featureLs, featureRs, flags, thresholds=np.arange(0, 1, 0.01)):
|
||||
# 1. 특징 벡터 정규화 (Normalize)
|
||||
featureLs = featureLs / np.linalg.norm(featureLs, axis=1, keepdims=True)
|
||||
featureRs = featureRs / np.linalg.norm(featureRs, axis=1, keepdims=True)
|
||||
|
||||
# 2. 코사인 유사도 계산 (Dot Product)
|
||||
scores = np.sum(featureLs * featureRs, axis=1)
|
||||
|
||||
# 3. 최적의 임계값(Threshold) 찾기 및 정확도 계산
|
||||
best_acc = 0
|
||||
for t in thresholds:
|
||||
# 유사도가 t보다 크면 '같은 사람(1)', 작으면 '다른 사람(0)'
|
||||
preds = (scores > t).astype(int)
|
||||
acc = np.mean(preds == flags)
|
||||
if acc > best_acc:
|
||||
best_acc = acc
|
||||
return best_acc
|
||||
|
||||
# ----------------------------
|
||||
# GPU 및 초기 설정 (기존 동일)
|
||||
# ----------------------------
|
||||
gpu_list = ''
|
||||
multi_gpus = False
|
||||
if isinstance(GPU, int):
|
||||
gpu_list = str(GPU)
|
||||
else:
|
||||
multi_gpus = True
|
||||
gpu_list = ','.join(map(str, GPU))
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list
|
||||
|
||||
start_epoch = 1
|
||||
save_dir = os.path.join('./result(MODEL_SONG)', 'MODEL_2_' + datetime.now().strftime('%Y%m%d_%H%M%S'))
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
logging = init_log(save_dir)
|
||||
_print = logging.info
|
||||
|
||||
# ----------------------------
|
||||
# Dataloader (기존 동일)
|
||||
# ----------------------------
|
||||
trainset = CASIA_HF()
|
||||
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE,
|
||||
shuffle=True, num_workers=8, drop_last=False)
|
||||
|
||||
testset = LFW_Pairs()
|
||||
testloader = torch.utils.data.DataLoader(testset, batch_size=32,
|
||||
shuffle=False, num_workers=8, drop_last=False)
|
||||
|
||||
# ----------------------------
|
||||
# Model & Optimizer (기존 동일)
|
||||
# ----------------------------
|
||||
net = model_song.MobileFacenet()
|
||||
ArcMargin = model_song.ArcMarginProduct(128, trainset.dataset.features['label'].num_classes)
|
||||
|
||||
if RESUME:
|
||||
ckpt = torch.load(RESUME)
|
||||
net.load_state_dict(ckpt['net_state_dict'])
|
||||
start_epoch = ckpt['epoch'] + 1
|
||||
|
||||
net = net.cuda()
|
||||
ArcMargin = ArcMargin.cuda()
|
||||
if multi_gpus:
|
||||
net = DataParallel(net)
|
||||
ArcMargin = DataParallel(ArcMargin)
|
||||
|
||||
criterion = CrossEntropyLoss()
|
||||
|
||||
ignored_params = list(map(id, ArcMargin.weight))
|
||||
# prelu_params = [p for m in net.modules() if isinstance(m, torch.nn.PReLU) for p in m.parameters()]
|
||||
base_params = filter(lambda p: id(p) not in ignored_params, net.parameters())
|
||||
|
||||
# 기존 아키텍처에서 prelu 삭제했었으니까 아래 optim에서도 삭제 처리
|
||||
optimizer_ft = optim.SGD([
|
||||
{'params': base_params, 'weight_decay': 4e-5},
|
||||
{'params': ArcMargin.weight, 'weight_decay': 4e-4}
|
||||
], lr=0.1, momentum=0.9, nesterov=True)
|
||||
|
||||
# optimizer_ft = optim.SGD([
|
||||
# {'params': base_params, 'weight_decay': 4e-5},
|
||||
# {'params': net.linear1.parameters(), 'weight_decay': 4e-4},
|
||||
# {'params': ArcMargin.weight, 'weight_decay': 4e-4},
|
||||
# {'params': prelu_params, 'weight_decay': 0.0}
|
||||
# ], lr=0.1, momentum=0.9, nesterov=True)
|
||||
|
||||
# 여기도 Config에서 Epoch 숫자 수정할때마다 milestone도 같이 수정해줘야함.
|
||||
exp_lr_scheduler = lr_scheduler.MultiStepLR(optimizer_ft, milestones=[240, 310, 400], gamma=0.1)
|
||||
|
||||
# ----------------------------
|
||||
# [추가] Best Accuracy 기록 변수
|
||||
# ----------------------------
|
||||
best_lfw_acc = 0.0
|
||||
|
||||
# ----------------------------
|
||||
# Training Loop
|
||||
# ----------------------------
|
||||
for epoch in range(start_epoch, TOTAL_EPOCH + 1):
|
||||
net.train()
|
||||
train_total_loss, total = 0, 0
|
||||
since = time.time()
|
||||
_print(f"Train Epoch: {epoch}/{TOTAL_EPOCH} ...")
|
||||
|
||||
for data in trainloader:
|
||||
img, label = data[0].cuda(), data[1].cuda()
|
||||
optimizer_ft.zero_grad()
|
||||
raw_logits = net(img)
|
||||
output = ArcMargin(raw_logits, label)
|
||||
loss = criterion(output, label)
|
||||
loss.backward()
|
||||
optimizer_ft.step()
|
||||
train_total_loss += loss.item() * img.size(0)
|
||||
total += img.size(0)
|
||||
|
||||
train_total_loss /= total
|
||||
time_elapsed = time.time() - since
|
||||
_print(f" total_loss: {train_total_loss:.4f} time: {time_elapsed//60:.0f}m {time_elapsed%60:.0f}s")
|
||||
|
||||
exp_lr_scheduler.step()
|
||||
|
||||
# ----------------------------
|
||||
# Test & Best Model Save
|
||||
# ----------------------------
|
||||
if epoch % TEST_FREQ == 0:
|
||||
net.eval()
|
||||
featureLs, featureRs = None, None
|
||||
flags = [] # [추가] 정답(Label)을 저장할 리스트
|
||||
|
||||
_print(" Testing LFW...")
|
||||
with torch.no_grad(): # [추가] 테스트 땐 기울기 계산 끔 (메모리 절약)
|
||||
for data in testloader:
|
||||
# data 구조: [images_list, label(flag)]라고 가정
|
||||
# LFW_Pairs의 경우 data[1]이 보통 정답(1:같은사람, 0:다른사람)
|
||||
|
||||
# 이미지 GPU 이동
|
||||
imgs = [d.cuda() for d in data[0]]
|
||||
|
||||
# 정답 라벨 수집 (numpy로 변환)
|
||||
flags.append(data[1].numpy())
|
||||
|
||||
# 특징 추출
|
||||
res = [net(d).data.cpu().numpy() for d in imgs]
|
||||
|
||||
featureL = np.concatenate((res[0], res[1]), 1)
|
||||
featureR = np.concatenate((res[2], res[3]), 1)
|
||||
|
||||
featureLs = featureL if featureLs is None else np.concatenate((featureLs, featureL), 0)
|
||||
featureRs = featureR if featureRs is None else np.concatenate((featureRs, featureR), 0)
|
||||
|
||||
# [추가] 정답 리스트 합치기
|
||||
flags = np.concatenate(flags, 0)
|
||||
|
||||
# [추가] 정확도 계산
|
||||
# 만약 scipy.io.savemat은 필요하면 유지, 아니면 삭제해도 됨
|
||||
# result = {'fl': featureLs, 'fr': featureRs}
|
||||
# scipy.io.savemat('./result/tmp_result.mat', result)
|
||||
|
||||
# 직접 정확도 계산 (함수 호출)
|
||||
current_acc = calculate_accuracy(featureLs, featureRs, flags)
|
||||
_print(f" LFW Acc: {current_acc*100:.2f}% (Best: {best_lfw_acc*100:.2f}%)")
|
||||
|
||||
# [핵심] Best Model 저장 (Loss가 아닌 Acc 기준)
|
||||
if current_acc > best_lfw_acc:
|
||||
best_lfw_acc = current_acc
|
||||
state_dict = net.module.state_dict() if multi_gpus else net.state_dict()
|
||||
|
||||
best_dir = os.path.join(save_dir, 'best_model')
|
||||
os.makedirs(best_dir, exist_ok=True)
|
||||
|
||||
best_path = os.path.join(best_dir, f'best_{epoch:03d}.ckpt')
|
||||
torch.save(
|
||||
{
|
||||
'epoch': epoch,
|
||||
'net_state_dict': state_dict,
|
||||
'acc': best_lfw_acc
|
||||
},
|
||||
best_path
|
||||
)
|
||||
_print(f" ==> Best Model Saved! (Acc: {best_lfw_acc*100:.2f}%, Epoch: {epoch}))")
|
||||
|
||||
# ----------------------------
|
||||
# Regular Save (백업용)
|
||||
# ----------------------------
|
||||
if epoch % SAVE_FREQ == 0:
|
||||
state_dict = net.module.state_dict() if multi_gpus else net.state_dict()
|
||||
torch.save({'epoch': epoch, 'net_state_dict': state_dict},
|
||||
os.path.join(save_dir, f'{epoch:03d}.ckpt'))
|
||||
|
||||
_print("finishing training")
|
||||
Loading…
Reference in new issue