|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
import torch.optim as optim
|
|
|
import os
|
|
|
import argparse
|
|
|
from tqdm import tqdm
|
|
|
from datetime import datetime
|
|
|
from torchvision import transforms
|
|
|
from torch.utils.data import DataLoader
|
|
|
|
|
|
# Import Custom Modules
|
|
|
from model import get_face_model, ArcFace
|
|
|
from dataset import get_face_dataloaders
|
|
|
from validation import LFWDataset, validate_lfw
|
|
|
|
|
|
def _parse_args(argv=None):
    """Parse the training hyperparameters from the command line.

    Args:
        argv: Optional list of argument strings. ``None`` makes argparse read
            ``sys.argv[1:]``, so running the script from a shell is unchanged.

    Returns:
        An ``argparse.Namespace`` with ``data_dir``, ``project_dir``, ``epochs``,
        ``batch_size``, ``lr`` and ``milestones``.
    """
    parser = argparse.ArgumentParser(description='Face Recognition Training with LFW Validation')
    parser.add_argument('--data_dir', type=str, default='/home/cuuva/face_exp/datasets', help='Root dataset dir')
    parser.add_argument('--project_dir', type=str, default='./results', help='Save dir')
    parser.add_argument('--epochs', type=int, default=20)
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--lr', type=float, default=0.1)
    # Epochs at which MultiStepLR multiplies the LR by 0.1. The defaults all fall
    # inside the default 20 epochs; change them together with --epochs.
    parser.add_argument('--milestones', type=int, nargs='+', default=[8, 14, 18],
                        help='Epochs at which the LR is multiplied by 0.1')
    return parser.parse_args(argv)


def _build_lfw_loader(data_dir):
    """Build the LFW pair-verification DataLoader, or return None if LFW is missing.

    ``pairs.txt`` is expected at ``<data_dir>/LFW/pairs.txt``. The images are
    looked up in ``<data_dir>/LFW/lfw`` first (nested layout) and otherwise in
    ``<data_dir>/LFW`` itself.

    Args:
        data_dir: Root dataset directory (the same one used for training data).

    Returns:
        A non-shuffled ``DataLoader`` over ``LFWDataset``, or ``None`` when
        ``pairs.txt`` or the image directory does not exist.
    """
    lfw_root = os.path.join(data_dir, 'LFW')
    pairs_path = os.path.join(lfw_root, 'pairs.txt')

    # Prefer the nested LFW/lfw image directory; fall back to LFW/ itself.
    nested_img_dir = os.path.join(lfw_root, 'lfw')
    lfw_img_dir = nested_img_dir if os.path.isdir(nested_img_dir) else lfw_root

    print(f"ℹ️ Pairs path: {pairs_path}")
    print(f"ℹ️ Images path: {lfw_img_dir}")

    if not (os.path.exists(pairs_path) and os.path.exists(lfw_img_dir)):
        print("⚠️ Warning: 'pairs.txt' or Image dir not found. Skipping Validation.")
        return None

    # Resize to 128x128 and normalize each channel with mean 0.5 / std 0.5.
    lfw_transform = transforms.Compose([
        transforms.Resize((128, 128), interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

    # Key point: pass the image directory (lfw_img_dir), not necessarily lfw_root.
    lfw_dataset = LFWDataset(lfw_img_dir, pairs_path, transform=lfw_transform)
    lfw_loader = DataLoader(lfw_dataset, batch_size=64, shuffle=False, num_workers=4, drop_last=False)
    print(f"✅ LFW Loaded. Pairs: {len(lfw_dataset)}")
    return lfw_loader


def _train_one_epoch(backbone, metric_fc, criterion, optimizer, train_loader, device, desc):
    """Run one training pass over ``train_loader`` and return the mean batch loss.

    Both ``backbone`` and ``metric_fc`` are put in train mode and updated by
    ``optimizer``. The progress bar shows the running mean loss.

    Returns:
        Mean of the per-batch losses, or 0.0 if the loader yielded no batches.
    """
    backbone.train()
    metric_fc.train()

    running_loss = 0.0
    num_batches = 0
    pbar = tqdm(train_loader, desc=desc)
    # Count batches ourselves: tqdm only updates ``pbar.n`` when it refreshes
    # the display, so it can lag behind the number of processed batches.
    for num_batches, (images, labels) in enumerate(pbar, start=1):
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()

        features = backbone(images)
        # The metric head takes the labels as well as the features (presumably
        # to apply the ArcFace margin to the target class; see model.ArcFace).
        outputs = metric_fc(features, labels)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        pbar.set_postfix({'loss': f"{running_loss / num_batches:.4f}"})

    return running_loss / num_batches if num_batches else 0.0


def main(argv=None):
    """Train a face-embedding backbone with an ArcFace head, validating on LFW.

    Steps:
        1. Parse the hyperparameters.
        2. Create a timestamped output directory under ``--project_dir``.
        3. Load the training data.
        4. Build the LFW loader if LFW is present.
        5. Train with SGD + MultiStepLR.

    After every epoch the backbone ``state_dict`` is written to
    ``last_backbone.pt``. With LFW available, the highest-accuracy backbone is
    kept in ``best_model/best_backbone.pt``. Without it, a
    ``backbone_epoch_{N}.pt`` snapshot is written each epoch. The ArcFace head
    is not saved.

    Args:
        argv: Optional argument list forwarded to argparse (``None`` = command line).

    Returns:
        0 on success, 1 if the training data could not be loaded.
    """
    import sys

    # 1. Hyperparameters
    args = _parse_args(argv)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    # Output directories: one timestamped directory per run (minute resolution).
    current_time = datetime.now().strftime("%y-%m-%d-%H-%M")
    base_save_dir = os.path.join(args.project_dir, current_time)
    best_save_dir = os.path.join(base_save_dir, 'best_model')
    os.makedirs(best_save_dir, exist_ok=True)
    print(f"✅ Save path: {base_save_dir}")

    # 2. Train data loader
    print("Loading Train Data...")
    try:
        train_loader, num_classes = get_face_dataloaders(args.data_dir, args.batch_size)
    except Exception as e:  # top-level boundary: report and signal failure
        print(f"❌ Train Data Error: {e}", file=sys.stderr)
        return 1

    # 3. LFW validation loader (None when LFW data is unavailable)
    print("Loading LFW Data...")
    lfw_loader = _build_lfw_loader(args.data_dir)

    # 4. Model, loss, optimizer, scheduler
    backbone = get_face_model().to(device)
    # NOTE(review): in_features=128 assumes the backbone emits 128-d embeddings;
    # confirm against model.get_face_model.
    metric_fc = ArcFace(in_features=128, out_features=num_classes).to(device)
    criterion = nn.CrossEntropyLoss()

    optimizer = optim.SGD([
        {'params': backbone.parameters()},
        {'params': metric_fc.parameters()},
    ], lr=args.lr, momentum=0.9, weight_decay=5e-4)

    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.milestones, gamma=0.1)

    # 5. Training loop
    print("🚀 Start Training...")
    best_acc = 0.0

    for epoch in range(args.epochs):
        _train_one_epoch(backbone, metric_fc, criterion, optimizer, train_loader, device,
                         desc=f"Epoch {epoch+1}/{args.epochs}")
        # Epoch-based schedule: step once per completed epoch.
        scheduler.step()

        if lfw_loader is not None:
            # validate_lfw returns (accuracy fraction, threshold) as printed below.
            acc, th = validate_lfw(backbone, lfw_loader, device)
            print(f"📊 LFW Acc: {acc*100:.2f}% (Threshold: {th:.2f})")

            if acc > best_acc:
                best_acc = acc
                save_path = os.path.join(best_save_dir, "best_backbone.pt")
                torch.save(backbone.state_dict(), save_path)
                print(f"🏆 Best Model Updated! Saved to {save_path}")
        else:
            # No validation signal means no "best" model, so keep every epoch.
            save_path = os.path.join(base_save_dir, f"backbone_epoch_{epoch+1}.pt")
            torch.save(backbone.state_dict(), save_path)

        torch.save(backbone.state_dict(), os.path.join(base_save_dir, "last_backbone.pt"))

    print("\n🎉 Training Finished!")
    return 0
|
|
|
|
|
|
if __name__ == "__main__":
    # Use main()'s return value as the process exit status (None or 0 -> success).
    raise SystemExit(main())