"""Train a face-recognition backbone with an ArcFace head, validating on LFW."""

import argparse
import os
from datetime import datetime

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm

# Import Custom Modules
from model import get_face_model, ArcFace
from dataset import get_face_dataloaders
from validation import LFWDataset, validate_lfw


def _parse_args():
    """Define the command-line options and return the parsed namespace."""
    parser = argparse.ArgumentParser(description='Face Recognition Training with LFW Validation')
    parser.add_argument('--data_dir', type=str, default='/home/cuuva/face_exp/datasets', help='Root dataset dir')
    parser.add_argument('--project_dir', type=str, default='./results', help='Save dir')
    parser.add_argument('--epochs', type=int, default=20)
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--lr', type=float, default=0.1)
    return parser.parse_args()


def _build_lfw_loader(data_dir):
    """Create the LFW validation loader, or return None when the data is missing.

    Args:
        data_dir: Root dataset directory; LFW is looked up under ``<data_dir>/LFW``.

    Returns:
        A ``DataLoader`` over ``LFWDataset``, or ``None`` if ``pairs.txt`` or
        the image directory does not exist (a warning is printed in that case).
    """
    print("Loading LFW Data...")

    # Root that holds pairs.txt: <data_dir>/LFW
    lfw_root = os.path.join(data_dir, 'LFW')
    pairs_path = os.path.join(lfw_root, 'pairs.txt')

    # Locate the directory that actually holds the images.
    #   1st choice: <data_dir>/LFW/lfw (nested layout)
    #   2nd choice: <data_dir>/LFW     (flat layout)
    nested_img_dir = os.path.join(lfw_root, 'lfw')
    lfw_img_dir = nested_img_dir if os.path.exists(nested_img_dir) else lfw_root

    print(f"ℹ️ Pairs path: {pairs_path}")
    print(f"ℹ️ Images path: {lfw_img_dir}")

    if not (os.path.exists(pairs_path) and os.path.exists(lfw_img_dir)):
        print("⚠️ Warning: 'pairs.txt' or Image dir not found. Skipping Validation.")
        return None

    lfw_transform = transforms.Compose([
        transforms.Resize((128, 128), interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

    # The dataset must be given the resolved image directory, not the LFW root.
    lfw_dataset = LFWDataset(lfw_img_dir, pairs_path, transform=lfw_transform)
    lfw_loader = DataLoader(lfw_dataset, batch_size=64, shuffle=False, num_workers=4, drop_last=False)
    print(f"✅ LFW Loaded. Pairs: {len(lfw_dataset)}")
    return lfw_loader


def _train_one_epoch(backbone, metric_fc, criterion, optimizer, train_loader, device, desc):
    """Run one optimisation pass over ``train_loader`` with a progress bar.

    Args:
        backbone: Feature extractor being trained.
        metric_fc: Margin head called as ``metric_fc(features, labels)``.
        criterion: Loss applied to the head's outputs and the labels.
        optimizer: Optimizer holding the parameters of both modules.
        train_loader: Iterable yielding ``(images, labels)`` batches.
        device: Device the batches are moved to.
        desc: Progress-bar description.
    """
    backbone.train()
    metric_fc.train()
    running_loss = 0.0

    pbar = tqdm(train_loader, desc=desc)
    # Count batches ourselves: tqdm only refreshes ``pbar.n`` when it redraws
    # the bar, so using it as the divisor yields an inflated running average.
    for step, (images, labels) in enumerate(pbar, start=1):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        features = backbone(images)
        outputs = metric_fc(features, labels)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        pbar.set_postfix({'loss': f"{running_loss / step:.4f}"})


def main():
    """Parse CLI options, train the backbone, and save checkpoints.

    Checkpoints are written under ``<project_dir>/<timestamp>/``:
    ``best_model/best_backbone.pt`` whenever LFW accuracy improves,
    ``backbone_epoch_<n>.pt`` per epoch when LFW validation is unavailable,
    and ``last_backbone.pt`` with the most recent weights.

    Returns early (after printing the error) if the training data cannot be
    loaded.
    """
    # 1. Hyperparameters
    args = _parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    # Output directories, keyed by a minute-resolution timestamp.
    current_time = datetime.now().strftime("%y-%m-%d-%H-%M")
    base_save_dir = os.path.join(args.project_dir, current_time)
    best_save_dir = os.path.join(base_save_dir, 'best_model')
    os.makedirs(best_save_dir, exist_ok=True)
    print(f"✅ Save path: {base_save_dir}")

    # 2. Train Data Loader
    print("Loading Train Data...")
    try:
        train_loader, num_classes = get_face_dataloaders(args.data_dir, args.batch_size)
    except Exception as e:
        # Top-level boundary: report the failure and stop instead of training
        # without data.
        print(f"❌ Train Data Error: {e}")
        return

    # 3. LFW Validation Loader (None when the LFW files are not present)
    lfw_loader = _build_lfw_loader(args.data_dir)
    do_validation = lfw_loader is not None

    # 4. Model & Loss
    backbone = get_face_model().to(device)
    # assumes the backbone emits 128-dimensional embeddings — TODO confirm
    # against get_face_model
    metric_fc = ArcFace(in_features=128, out_features=num_classes).to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(
        [
            {'params': backbone.parameters()},
            {'params': metric_fc.parameters()},
        ],
        lr=args.lr,
        momentum=0.9,
        weight_decay=5e-4,
    )
    # LR is multiplied by 0.1 at epochs 8, 14 and 18; the milestones are fixed
    # and do not scale with --epochs.
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[8, 14, 18], gamma=0.1)

    # 5. Training Loop
    print("🚀 Start Training...")
    best_acc = 0.0

    for epoch in range(args.epochs):
        # --- TRAIN ---
        _train_one_epoch(
            backbone,
            metric_fc,
            criterion,
            optimizer,
            train_loader,
            device,
            desc=f"Epoch {epoch+1}/{args.epochs}",
        )
        scheduler.step()

        # --- VALIDATION (LFW) ---
        if do_validation:
            acc, th = validate_lfw(backbone, lfw_loader, device)
            print(f"📊 LFW Acc: {acc*100:.2f}% (Threshold: {th:.2f})")

            if acc > best_acc:
                best_acc = acc
                save_path = os.path.join(best_save_dir, "best_backbone.pt")
                torch.save(backbone.state_dict(), save_path)
                print(f"🏆 Best Model Updated! Saved to {save_path}")
        else:
            # No accuracy to rank checkpoints by, so keep one per epoch.
            save_path = os.path.join(base_save_dir, f"backbone_epoch_{epoch+1}.pt")
            torch.save(backbone.state_dict(), save_path)

        # Refreshed every epoch so an interrupted run still leaves the most
        # recent weights on disk.
        torch.save(backbone.state_dict(), os.path.join(base_save_dir, "last_backbone.pt"))

    print("\n🎉 Training Finished!")


if __name__ == "__main__":
    main()