You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
120 lines
4.4 KiB
120 lines
4.4 KiB
|
6 months ago
|
import torch
|
||
|
|
import torch.nn as nn
|
||
|
|
import torch.optim as optim
|
||
|
|
import os
|
||
|
|
import argparse
|
||
|
|
from tqdm import tqdm
|
||
|
|
from datetime import datetime # 시간 정보를 가져오기 위해 추가
|
||
|
|
|
||
|
|
# Import our custom modules
|
||
|
|
from model import get_face_model, ArcFace
|
||
|
|
from dataset import get_face_dataloaders
|
||
|
|
|
||
|
|
def main():
    """Train the face-recognition backbone with an ArcFace head.

    Parses command-line options, creates a timestamped experiment directory
    under ``--project_dir``, loads the training data, then trains the backbone
    and ArcFace head jointly with SGD and a MultiStepLR schedule. The
    backbone's ``state_dict`` is written to the experiment directory after
    every epoch as ``backbone_epoch_<N>.pt``.

    Returns:
        None. If the data loader cannot be created, the error is printed and
        the function returns early without training.
    """
    # --------------------------------------------------------
    # 1. Hyperparameters & Settings
    # --------------------------------------------------------
    parser = argparse.ArgumentParser(description='Face Recognition Training for Apache 6')

    # Dataset path (the default is set for the original author's environment).
    parser.add_argument('--data_dir', type=str, default='/home/cuuva/face_exp/datasets', help='Path to datasets')

    # Top-level results directory (a timestamped sub-folder is created below it).
    parser.add_argument('--project_dir', type=str, default='./results', help='Base directory for results')

    parser.add_argument('--epochs', type=int, default=20, help='Number of epochs')
    parser.add_argument('--batch_size', type=int, default=64, help='Batch size')
    parser.add_argument('--lr', type=float, default=0.1, help='Learning rate')

    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    # --------------------------------------------------------
    # Experiment directory creation (yy-mm-dd-hour-minute)
    # --------------------------------------------------------
    # NOTE: the timestamp has minute resolution and the directory is created
    # with exist_ok=True, so two runs started within the same minute share
    # (and overwrite) the same folder.
    current_time = datetime.now().strftime("%y-%m-%d-%H-%M")

    # Final save path, e.g. ./results/23-12-12-15-30/
    save_dir = os.path.join(args.project_dir, current_time)

    os.makedirs(save_dir, exist_ok=True)
    print(f"✅ Experiment results will be saved to: {save_dir}")

    # --------------------------------------------------------
    # 2. Load Data
    # --------------------------------------------------------
    print("Loading Data...")
    try:
        train_loader, num_classes = get_face_dataloaders(args.data_dir, args.batch_size)
    except Exception as e:
        # Top-level boundary: report the problem and stop instead of training
        # with an undefined loader.
        print(f"❌ Error loading data: {e}")
        return
    else:
        print(f"Classes (People): {num_classes}, Batch Size: {args.batch_size}")

    # --------------------------------------------------------
    # 3. Initialize Model & Loss
    # --------------------------------------------------------
    # Backbone (the model that will be deployed on the board).
    backbone = get_face_model().to(device)

    # ArcFace head (used only during training to compute the loss logits).
    # in_features=128 presumably matches the backbone's embedding size -
    # confirm against model.py.
    metric_fc = ArcFace(in_features=128, out_features=num_classes).to(device)

    # Loss
    criterion = nn.CrossEntropyLoss()

    # Optimizer: backbone and head are optimized jointly with shared settings.
    optimizer = optim.SGD(
        [
            {'params': backbone.parameters()},
            {'params': metric_fc.parameters()},
        ],
        lr=args.lr,
        momentum=0.9,
        weight_decay=5e-4,
    )

    # Scheduler: milestones are fixed epoch indices, tuned for the default
    # of 20 epochs; they do not scale with --epochs.
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[8, 14, 18], gamma=0.1)

    # --------------------------------------------------------
    # 4. Training Loop
    # --------------------------------------------------------
    print("🚀 Start Training...")

    for epoch in range(args.epochs):
        backbone.train()
        metric_fc.train()

        running_loss = 0.0
        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{args.epochs}")

        # Count batches explicitly: tqdm only refreshes `pbar.n` when the bar
        # is redrawn, so using it as the divisor can report an inflated mean.
        for batch_idx, (images, labels) in enumerate(pbar, start=1):
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()

            # Forward pass
            # Original author's shape notes: features [N, 128, 1, 1],
            # outputs [N, Num_Classes] - confirm against model.py.
            features = backbone(images)
            outputs = metric_fc(features, labels)

            # Loss calculation & backward pass
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            pbar.set_postfix({'loss': running_loss / batch_idx})

        scheduler.step()

        # Save a checkpoint (.pt extension) inside this run's experiment folder.
        save_path = os.path.join(save_dir, f"backbone_epoch_{epoch+1}.pt")
        torch.save(backbone.state_dict(), save_path)

        # Print the completion message after the last epoch.
        if epoch == args.epochs - 1:
            print(f"🎉 Training Finished! Final model saved at: {save_path}")
|
||
|
|
|
||
|
|
# Script entry point: run training only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|