You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

120 lines
4.4 KiB

import torch
import torch.nn as nn
import torch.optim as optim
import os
import argparse
from tqdm import tqdm
from datetime import datetime # 시간 정보를 가져오기 위해 추가
# Import our custom modules
from model import get_face_model, ArcFace
from dataset import get_face_dataloaders
def main():
    """Train the face-recognition backbone with an ArcFace head.

    Command-line options (parsed from ``sys.argv``):
        --data_dir     dataset root handed to ``get_face_dataloaders``
        --project_dir  base results directory; each run gets its own
                       ``yy-mm-dd-HH-MM`` sub-directory
        --epochs       number of training epochs
        --batch_size   mini-batch size handed to ``get_face_dataloaders``
        --lr           initial SGD learning rate

    Side effects:
        Creates ``<project_dir>/<timestamp>/`` and, after every epoch, writes
        ``backbone_epoch_<k>.pt`` there. Only the backbone ``state_dict`` is
        saved; the ArcFace head is not checkpointed.

    Returns:
        None. If loading the data raises, the error is printed and the
        function returns before any model is built.
    """
    # --------------------------------------------------------
    # 1. Hyperparameters & Settings
    # --------------------------------------------------------
    parser = argparse.ArgumentParser(description='Face Recognition Training for Apache 6')
    # Dataset path (the default reflects the original author's environment).
    parser.add_argument('--data_dir', type=str, default='/home/cuuva/face_exp/datasets', help='Path to datasets')
    # Top-level results path; a per-run timestamped folder is created below it.
    parser.add_argument('--project_dir', type=str, default='./results', help='Base directory for results')
    parser.add_argument('--epochs', type=int, default=20, help='Number of epochs')
    parser.add_argument('--batch_size', type=int, default=64, help='Batch size')
    parser.add_argument('--lr', type=float, default=0.1, help='Learning rate')
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    # --------------------------------------------------------
    # Experiment directory (yy-mm-dd-hour-minute)
    # --------------------------------------------------------
    current_time = datetime.now().strftime("%y-%m-%d-%H-%M")
    # Final save path, e.g. ./results/23-12-12-15-30/
    save_dir = os.path.join(args.project_dir, current_time)
    # NOTE(review): the name has minute resolution and exist_ok=True, so two
    # runs started within the same minute share (and overwrite) one folder.
    os.makedirs(save_dir, exist_ok=True)
    print(f"✅ Experiment results will be saved to: {save_dir}")

    # --------------------------------------------------------
    # 2. Load Data
    # --------------------------------------------------------
    print("Loading Data...")
    try:
        train_loader, num_classes = get_face_dataloaders(args.data_dir, args.batch_size)
        print(f"Classes (People): {num_classes}, Batch Size: {args.batch_size}")
    except Exception as e:
        # Top-level boundary: report the failure and stop instead of crashing
        # with a traceback.
        print(f"❌ Error loading data: {e}")
        return

    # --------------------------------------------------------
    # 3. Initialize Model & Loss
    # --------------------------------------------------------
    # Backbone (the model that gets deployed to the board).
    backbone = get_face_model().to(device)
    # ArcFace head (used only during training to produce the logits).
    # NOTE(review): in_features=128 presumably matches the backbone's
    # embedding size — confirm against model.py.
    metric_fc = ArcFace(in_features=128, out_features=num_classes).to(device)
    # Loss
    criterion = nn.CrossEntropyLoss()
    # Optimizer: one SGD instance updates both the backbone and the head.
    optimizer = optim.SGD([
        {'params': backbone.parameters()},
        {'params': metric_fc.parameters()}
    ], lr=args.lr, momentum=0.9, weight_decay=5e-4)
    # Scheduler: multiply the learning rate by 0.1 at epochs 8, 14 and 18.
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[8, 14, 18], gamma=0.1)

    # --------------------------------------------------------
    # 4. Training Loop
    # --------------------------------------------------------
    print("🚀 Start Training...")
    for epoch in range(args.epochs):
        backbone.train()
        metric_fc.train()
        running_loss = 0.0
        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{args.epochs}")
        # Count batches explicitly. tqdm only syncs `pbar.n` when it redraws
        # the bar, so dividing by `pbar.n + 1` used a stale count and
        # overstated the running mean loss.
        for step, (images, labels) in enumerate(pbar, start=1):
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()

            # Forward pass
            # Original author's note: features are [N, 128, 1, 1] and outputs
            # are [N, Num_Classes] — not verifiable from this file.
            features = backbone(images)
            outputs = metric_fc(features, labels)

            # Loss calculation & backward pass
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            pbar.set_postfix({'loss': running_loss / step})

        # The scheduler is stepped once per epoch (milestones are epoch indices).
        scheduler.step()

        # Save a checkpoint (.pt) for this epoch inside the experiment folder.
        save_path = os.path.join(save_dir, f"backbone_epoch_{epoch+1}.pt")
        torch.save(backbone.state_dict(), save_path)

        # Announce the final checkpoint after the last epoch.
        if epoch == args.epochs - 1:
            print(f"🎉 Training Finished! Final model saved at: {save_path}")
# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()