You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

141 lines
5.2 KiB

import torch
import torch.nn as nn
import torch.optim as optim
import os
import argparse
from tqdm import tqdm
from datetime import datetime
from torchvision import transforms
from torch.utils.data import DataLoader
# Import Custom Modules
from model import get_face_model, ArcFace
from dataset import get_face_dataloaders
from validation import LFWDataset, validate_lfw
def _build_lfw_loader(data_dir):
    """Create the LFW validation loader, or return ``None`` when data is missing.

    ``pairs.txt`` is expected directly under ``<data_dir>/LFW``. Images are
    read from ``<data_dir>/LFW/lfw`` when that sub-directory exists, otherwise
    from ``<data_dir>/LFW`` itself.

    Args:
        data_dir: Root dataset directory that contains the ``LFW`` folder.

    Returns:
        A ``DataLoader`` over ``LFWDataset``, or ``None`` if ``pairs.txt`` or
        the image directory cannot be found (validation is then skipped).
    """
    print("Loading LFW Data...")
    lfw_root = os.path.join(data_dir, 'LFW')
    pairs_path = os.path.join(lfw_root, 'pairs.txt')

    # Locate the actual image directory:
    #   1st choice: <data_dir>/LFW/lfw (nested layout)
    #   2nd choice: <data_dir>/LFW     (flat layout)
    nested_img_dir = os.path.join(lfw_root, 'lfw')
    lfw_img_dir = nested_img_dir if os.path.isdir(nested_img_dir) else lfw_root

    print(f" Pairs path: {pairs_path}")
    print(f" Images path: {lfw_img_dir}")

    # isfile/isdir (rather than exists) so that a stray file named like the
    # image directory, or a directory named pairs.txt, is not accepted.
    if not (os.path.isfile(pairs_path) and os.path.isdir(lfw_img_dir)):
        print("⚠️ Warning: 'pairs.txt' or Image dir not found. Skipping Validation.")
        return None

    lfw_transform = transforms.Compose([
        transforms.Resize((128, 128), interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])
    # The image directory (not the LFW root) is what the dataset must receive.
    lfw_dataset = LFWDataset(lfw_img_dir, pairs_path, transform=lfw_transform)
    lfw_loader = DataLoader(
        lfw_dataset,
        batch_size=64,
        shuffle=False,
        num_workers=4,
        drop_last=False,
    )
    print(f"✅ LFW Loaded. Pairs: {len(lfw_dataset)}")
    return lfw_loader


def _train_one_epoch(backbone, metric_fc, train_loader, criterion, optimizer, device, desc):
    """Run one optimisation pass over ``train_loader`` and return the mean batch loss.

    Args:
        backbone: Feature extractor being trained.
        metric_fc: Margin head called as ``metric_fc(features, labels)``.
        train_loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Loss applied to the head's outputs and the labels.
        optimizer: Optimizer stepping both modules' parameters.
        device: Device the batches are moved to.
        desc: Label shown on the progress bar.

    Returns:
        The mean of the per-batch losses, or ``0.0`` if the loader was empty.
    """
    backbone.train()
    metric_fc.train()

    running_loss = 0.0
    step = 0
    pbar = tqdm(train_loader, desc=desc)
    # Count batches explicitly: tqdm only syncs ``pbar.n`` when it redraws the
    # bar (throttled by miniters/mininterval), so dividing by ``pbar.n + 1``
    # can use a stale count and overstate the running average.
    for step, (images, labels) in enumerate(pbar, start=1):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        features = backbone(images)
        outputs = metric_fc(features, labels)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        pbar.set_postfix({'loss': f"{running_loss / step:.4f}"})

    return running_loss / step if step else 0.0


def main(argv=None):
    """Train the face-recognition backbone with ArcFace and optional LFW validation.

    Parses command-line options, creates a time-stamped output directory,
    trains for ``--epochs`` epochs with SGD and a multi-step LR schedule, and
    saves checkpoints:

    * with LFW validation available: ``best_model/best_backbone.pt`` whenever
      the LFW accuracy improves;
    * without it: ``backbone_epoch_<n>.pt`` after every epoch;
    * always: ``last_backbone.pt`` once training ends.

    Args:
        argv: Optional list of command-line arguments. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``.

    Returns:
        None. Returns early, after printing the error, if the training data
        cannot be loaded.
    """
    # 1. Hyperparameters
    parser = argparse.ArgumentParser(description='Face Recognition Training with LFW Validation')
    parser.add_argument('--data_dir', type=str, default='/home/cuuva/face_exp/datasets', help='Root dataset dir')
    parser.add_argument('--project_dir', type=str, default='./results', help='Save dir')
    parser.add_argument('--epochs', type=int, default=20)
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--lr', type=float, default=0.1)
    args = parser.parse_args(argv)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    # Output directories: <project_dir>/<timestamp>/ and its best_model/ child.
    current_time = datetime.now().strftime("%y-%m-%d-%H-%M")
    base_save_dir = os.path.join(args.project_dir, current_time)
    best_save_dir = os.path.join(base_save_dir, 'best_model')
    os.makedirs(best_save_dir, exist_ok=True)
    print(f"✅ Save path: {base_save_dir}")

    # 2. Train data loader
    print("Loading Train Data...")
    try:
        train_loader, num_classes = get_face_dataloaders(args.data_dir, args.batch_size)
    except Exception as e:
        # Top-level boundary: report the problem and stop instead of crashing.
        print(f"❌ Train Data Error: {e}")
        return

    # 3. LFW validation loader (None means validation is skipped)
    lfw_loader = _build_lfw_loader(args.data_dir)
    # Compare against None explicitly: an empty DataLoader is falsy.
    do_validation = lfw_loader is not None

    # 4. Model & loss
    backbone = get_face_model().to(device)
    metric_fc = ArcFace(in_features=128, out_features=num_classes).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(
        [
            {'params': backbone.parameters()},
            {'params': metric_fc.parameters()},
        ],
        lr=args.lr,
        momentum=0.9,
        weight_decay=5e-4,
    )
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[8, 14, 18], gamma=0.1)

    # 5. Training loop
    print("🚀 Start Training...")
    best_acc = 0.0
    for epoch in range(args.epochs):
        # --- TRAIN ---
        _train_one_epoch(
            backbone,
            metric_fc,
            train_loader,
            criterion,
            optimizer,
            device,
            desc=f"Epoch {epoch+1}/{args.epochs}",
        )
        scheduler.step()

        # --- VALIDATION (LFW) ---
        if do_validation:
            acc, th = validate_lfw(backbone, lfw_loader, device)
            print(f"📊 LFW Acc: {acc*100:.2f}% (Threshold: {th:.2f})")
            if acc > best_acc:
                best_acc = acc
                save_path = os.path.join(best_save_dir, "best_backbone.pt")
                torch.save(backbone.state_dict(), save_path)
                print(f"🏆 Best Model Updated! Saved to {save_path}")
        else:
            # No validation signal: keep a checkpoint for every epoch instead.
            save_path = os.path.join(base_save_dir, f"backbone_epoch_{epoch+1}.pt")
            torch.save(backbone.state_dict(), save_path)

    torch.save(backbone.state_dict(), os.path.join(base_save_dir, "last_backbone.pt"))
    print("\n🎉 Training Finished!")
# Script entry point: start training only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()