import torch import torch.nn.functional as F from torch.utils.data import Dataset, DataLoader from torchvision import transforms from PIL import Image import os import numpy as np # -------------------------------------------------------- # 1. LFW Dataset Loader # -------------------------------------------------------- class LFWDataset(Dataset): def __init__(self, lfw_dir, pairs_path, transform=None): self.lfw_dir = lfw_dir self.pairs_path = pairs_path self.transform = transform self.validation_images = self.get_lfw_paths(lfw_dir) def get_lfw_paths(self, lfw_dir): # pairs.txt 파싱하여 이미지 경로 쌍과 정답(issame) 리스트 생성 pairs = [] with open(self.pairs_path, 'r') as f: lines = f.readlines()[1:] # 첫 줄(헤더) 건너뜀 for line in lines: p = line.strip().split('\t') if len(p) == 3: # 같은 사람 (name, img1_num, img2_num) name = p[0] img1 = os.path.join(lfw_dir, name, f"{name}_{int(p[1]):04d}.jpg") img2 = os.path.join(lfw_dir, name, f"{name}_{int(p[2]):04d}.jpg") issame = True pairs.append((img1, img2, issame)) elif len(p) == 4: # 다른 사람 (name1, img1_num, name2, img2_num) name1 = p[0] img1 = os.path.join(lfw_dir, name1, f"{name1}_{int(p[1]):04d}.jpg") name2 = p[2] img2 = os.path.join(lfw_dir, name2, f"{name2}_{int(p[3]):04d}.jpg") issame = False pairs.append((img1, img2, issame)) return pairs def __len__(self): return len(self.validation_images) def __getitem__(self, index): img1_path, img2_path, issame = self.validation_images[index] try: img1 = Image.open(img1_path).convert('RGB') img2 = Image.open(img2_path).convert('RGB') except Exception as e: # 혹시 파일이 없을 경우를 대비한 더미 (실제론 파일 확인 필요) print(f"File Load Error: {e}") img1 = Image.new('RGB', (128, 128)) img2 = Image.new('RGB', (128, 128)) if self.transform: img1 = self.transform(img1) img2 = self.transform(img2) return img1, img2, issame # -------------------------------------------------------- # 2. 
# Evaluation Function
# --------------------------------------------------------
def validate_lfw(model, lfw_loader, device):
    """Evaluate an embedding model on verification pairs.

    For every pair the cosine similarity between the two flattened model
    outputs is computed. A threshold sweep over ``np.arange(-1.0, 1.0, 0.01)``
    then picks the threshold giving the highest accuracy, where a pair is
    predicted "same" when its similarity is strictly greater than the
    threshold. This is a simplified search: the threshold is selected on the
    same pairs it is scored on.

    Args:
        model: Callable module mapping an image batch to features; it is
            switched to eval mode and left in that mode.
        lfw_loader: Iterable yielding ``(img1, img2, issame)`` batches.
        device: Device the image batches are moved to before the forward pass.

    Returns:
        ``(best_acc, best_th)`` as plain floats. ``(0.0, 0.0)`` is returned
        when the loader yields no pairs or no threshold scores above zero.
    """
    model.eval()
    similarity_batches = []
    label_batches = []

    print("🔍 Validating on LFW...")

    with torch.no_grad():
        for img1, img2, issame in lfw_loader:
            img1, img2 = img1.to(device), img2.to(device)

            # Feature extraction (the original note expects [B, 128, 1, 1]
            # outputs -- TODO confirm against the model definition).
            feat1 = model(img1)
            feat2 = model(img2)

            # Flatten to [B, D]. ``reshape`` is used instead of ``view`` so
            # non-contiguous model outputs are handled as well.
            feat1 = feat1.reshape(feat1.size(0), -1)
            feat2 = feat2.reshape(feat2.size(0), -1)

            # Cosine similarity of the flattened features, in [-1, 1].
            cos_sim = F.cosine_similarity(feat1, feat2)

            # Keep one array per batch; they are concatenated once below.
            similarity_batches.append(cos_sim.cpu().numpy())
            label_batches.append(torch.as_tensor(issame).cpu().numpy())

    if not similarity_batches:
        # Nothing to evaluate; avoids a 0/0 accuracy computation.
        return 0.0, 0.0

    similarities = np.concatenate(similarity_batches)
    actual_issame = np.concatenate(label_batches).astype(bool)
    total = len(actual_issame)
    if total == 0:
        return 0.0, 0.0

    # ----------------------------------------------------
    # Best Threshold Search (simplified version)
    # ----------------------------------------------------
    best_acc = 0.0
    best_th = 0.0

    # Sweep the threshold from -1.0 up to (not including) 1.0 in 0.01 steps.
    thresholds = np.arange(-1.0, 1.0, 0.01)

    for th in thresholds:
        # Above the threshold -> predicted same person; otherwise different.
        predict_issame = np.greater(similarities, th)

        # Correct predictions = true positives + true negatives.
        correct = np.count_nonzero(predict_issame == actual_issame)
        acc = correct / total

        # Strict ">" keeps the lowest threshold among ties.
        if acc > best_acc:
            best_acc = acc
            best_th = th

    return float(best_acc), float(best_th)