You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
167 lines
5.2 KiB
167 lines
5.2 KiB
|
7 months ago
|
import os
|
||
|
|
import torch.utils.data
|
||
|
|
from torch import nn
|
||
|
|
from torch.nn import DataParallel
|
||
|
|
from datetime import datetime
|
||
|
|
from config import BATCH_SIZE, SAVE_FREQ, RESUME, SAVE_DIR, TEST_FREQ, TOTAL_EPOCH, MODEL_PRE, GPU
|
||
|
|
from config import CASIA_DATA_DIR, LFW_DATA_DIR
|
||
|
|
from core import model
|
||
|
|
from core.utils import init_log
|
||
|
|
from dataloader.CASIA_Face_loader import CASIA_Face
|
||
|
|
from dataloader.LFW_loader import LFW
|
||
|
|
from torch.optim import lr_scheduler
|
||
|
|
import torch.optim as optim
|
||
|
|
import time
|
||
|
|
from lfw_eval import parseList, evaluation_10_fold
|
||
|
|
import numpy as np
|
||
|
|
import scipy.io
|
||
|
|
|
||
|
|
# gpu init
|
||
|
|
gpu_list = ''
|
||
|
|
multi_gpus = False
|
||
|
|
if isinstance(GPU, int):
|
||
|
|
gpu_list = str(GPU)
|
||
|
|
else:
|
||
|
|
multi_gpus = True
|
||
|
|
for i, gpu_id in enumerate(GPU):
|
||
|
|
gpu_list += str(gpu_id)
|
||
|
|
if i != len(GPU) - 1:
|
||
|
|
gpu_list += ','
|
||
|
|
os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list
|
||
|
|
|
||
|
|
# other init
|
||
|
|
start_epoch = 1
|
||
|
|
save_dir = os.path.join(SAVE_DIR, MODEL_PRE + 'v2_' + datetime.now().strftime('%Y%m%d_%H%M%S'))
|
||
|
|
if os.path.exists(save_dir):
|
||
|
|
raise NameError('model dir exists!')
|
||
|
|
os.makedirs(save_dir)
|
||
|
|
logging = init_log(save_dir)
|
||
|
|
_print = logging.info
|
||
|
|
|
||
|
|
|
||
|
|
# define trainloader and testloader
|
||
|
|
trainset = CASIA_Face(root=CASIA_DATA_DIR)
|
||
|
|
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE,
|
||
|
|
shuffle=True, num_workers=8, drop_last=False)
|
||
|
|
|
||
|
|
# nl: left_image_path
|
||
|
|
# nr: right_image_path
|
||
|
|
nl, nr, folds, flags = parseList(root=LFW_DATA_DIR)
|
||
|
|
testdataset = LFW(nl, nr)
|
||
|
|
testloader = torch.utils.data.DataLoader(testdataset, batch_size=32,
|
||
|
|
shuffle=False, num_workers=8, drop_last=False)
|
||
|
|
|
||
|
|
# define model
|
||
|
|
net = model.MobileFacenet()
|
||
|
|
ArcMargin = model.ArcMarginProduct(128, trainset.class_nums)
|
||
|
|
|
||
|
|
if RESUME:
|
||
|
|
ckpt = torch.load(RESUME)
|
||
|
|
net.load_state_dict(ckpt['net_state_dict'])
|
||
|
|
start_epoch = ckpt['epoch'] + 1
|
||
|
|
|
||
|
|
|
||
|
|
# define optimizers
|
||
|
|
ignored_params = list(map(id, net.linear1.parameters()))
|
||
|
|
ignored_params += list(map(id, ArcMargin.weight))
|
||
|
|
prelu_params_id = []
|
||
|
|
prelu_params = []
|
||
|
|
for m in net.modules():
|
||
|
|
if isinstance(m, nn.PReLU):
|
||
|
|
ignored_params += list(map(id, m.parameters()))
|
||
|
|
prelu_params += m.parameters()
|
||
|
|
base_params = filter(lambda p: id(p) not in ignored_params, net.parameters())
|
||
|
|
|
||
|
|
optimizer_ft = optim.SGD([
|
||
|
|
{'params': base_params, 'weight_decay': 4e-5},
|
||
|
|
{'params': net.linear1.parameters(), 'weight_decay': 4e-4},
|
||
|
|
{'params': ArcMargin.weight, 'weight_decay': 4e-4},
|
||
|
|
{'params': prelu_params, 'weight_decay': 0.0}
|
||
|
|
], lr=0.1, momentum=0.9, nesterov=True)
|
||
|
|
|
||
|
|
exp_lr_scheduler = lr_scheduler.MultiStepLR(optimizer_ft, milestones=[36, 52, 58], gamma=0.1)
|
||
|
|
|
||
|
|
|
||
|
|
net = net.cuda()
|
||
|
|
ArcMargin = ArcMargin.cuda()
|
||
|
|
if multi_gpus:
|
||
|
|
net = DataParallel(net)
|
||
|
|
ArcMargin = DataParallel(ArcMargin)
|
||
|
|
criterion = torch.nn.CrossEntropyLoss()
|
||
|
|
|
||
|
|
|
||
|
|
best_acc = 0.0
|
||
|
|
best_epoch = 0
|
||
|
|
for epoch in range(start_epoch, TOTAL_EPOCH+1):
|
||
|
|
exp_lr_scheduler.step()
|
||
|
|
# train model
|
||
|
|
_print('Train Epoch: {}/{} ...'.format(epoch, TOTAL_EPOCH))
|
||
|
|
net.train()
|
||
|
|
|
||
|
|
train_total_loss = 0.0
|
||
|
|
total = 0
|
||
|
|
since = time.time()
|
||
|
|
for data in trainloader:
|
||
|
|
img, label = data[0].cuda(), data[1].cuda()
|
||
|
|
batch_size = img.size(0)
|
||
|
|
optimizer_ft.zero_grad()
|
||
|
|
|
||
|
|
raw_logits = net(img)
|
||
|
|
|
||
|
|
output = ArcMargin(raw_logits, label)
|
||
|
|
total_loss = criterion(output, label)
|
||
|
|
total_loss.backward()
|
||
|
|
optimizer_ft.step()
|
||
|
|
|
||
|
|
train_total_loss += total_loss.item() * batch_size
|
||
|
|
total += batch_size
|
||
|
|
|
||
|
|
train_total_loss = train_total_loss / total
|
||
|
|
time_elapsed = time.time() - since
|
||
|
|
loss_msg = ' total_loss: {:.4f} time: {:.0f}m {:.0f}s'\
|
||
|
|
.format(train_total_loss, time_elapsed // 60, time_elapsed % 60)
|
||
|
|
_print(loss_msg)
|
||
|
|
|
||
|
|
# test model on lfw
|
||
|
|
if epoch % TEST_FREQ == 0:
|
||
|
|
net.eval()
|
||
|
|
featureLs = None
|
||
|
|
featureRs = None
|
||
|
|
_print('Test Epoch: {} ...'.format(epoch))
|
||
|
|
for data in testloader:
|
||
|
|
for i in range(len(data)):
|
||
|
|
data[i] = data[i].cuda()
|
||
|
|
res = [net(d).data.cpu().numpy() for d in data]
|
||
|
|
featureL = np.concatenate((res[0], res[1]), 1)
|
||
|
|
featureR = np.concatenate((res[2], res[3]), 1)
|
||
|
|
if featureLs is None:
|
||
|
|
featureLs = featureL
|
||
|
|
else:
|
||
|
|
featureLs = np.concatenate((featureLs, featureL), 0)
|
||
|
|
if featureRs is None:
|
||
|
|
featureRs = featureR
|
||
|
|
else:
|
||
|
|
featureRs = np.concatenate((featureRs, featureR), 0)
|
||
|
|
|
||
|
|
result = {'fl': featureLs, 'fr': featureRs, 'fold': folds, 'flag': flags}
|
||
|
|
# save tmp_result
|
||
|
|
scipy.io.savemat('./result/tmp_result.mat', result)
|
||
|
|
accs = evaluation_10_fold('./result/tmp_result.mat')
|
||
|
|
_print(' ave: {:.4f}'.format(np.mean(accs) * 100))
|
||
|
|
|
||
|
|
# save model
|
||
|
|
if epoch % SAVE_FREQ == 0:
|
||
|
|
msg = 'Saving checkpoint: {}'.format(epoch)
|
||
|
|
_print(msg)
|
||
|
|
if multi_gpus:
|
||
|
|
net_state_dict = net.module.state_dict()
|
||
|
|
else:
|
||
|
|
net_state_dict = net.state_dict()
|
||
|
|
if not os.path.exists(save_dir):
|
||
|
|
os.mkdir(save_dir)
|
||
|
|
torch.save({
|
||
|
|
'epoch': epoch,
|
||
|
|
'net_state_dict': net_state_dict},
|
||
|
|
os.path.join(save_dir, '%03d.ckpt' % epoch))
|
||
|
|
print('finishing training')
|