You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
119 lines
4.3 KiB
119 lines
4.3 KiB
|
6 months ago
|
import torch
|
||
|
|
import torch.nn.functional as F
|
||
|
|
from torch.utils.data import Dataset, DataLoader
|
||
|
|
from torchvision import transforms
|
||
|
|
from PIL import Image
|
||
|
|
import os
|
||
|
|
import numpy as np
|
||
|
|
|
||
|
|
# --------------------------------------------------------
|
||
|
|
# 1. LFW Dataset Loader
|
||
|
|
# --------------------------------------------------------
|
||
|
|
class LFWDataset(Dataset):
    """LFW face-verification pairs read from a ``pairs.txt``-style file.

    Each item is ``(img1, img2, issame)``: two RGB images (passed through
    ``transform`` when one is given) and a bool that is True for a
    same-person pair and False for a different-person pair.

    Args:
        lfw_dir: Root directory holding one sub-directory per person.
        pairs_path: Pairs file. Its first line is treated as a header and
            skipped.
        transform: Optional callable applied to each loaded PIL image.
        file_ext: Image file extension without the dot (default ``"jpg"``).
    """

    def __init__(self, lfw_dir, pairs_path, transform=None, file_ext="jpg"):
        self.lfw_dir = lfw_dir
        self.pairs_path = pairs_path
        self.transform = transform
        self.file_ext = file_ext
        # List of (img1_path, img2_path, issame) tuples.
        self.validation_images = self.get_lfw_paths(lfw_dir)

    def _image_path(self, lfw_dir, name, image_num):
        """Return ``<lfw_dir>/<name>/<name>_<NNNN>.<file_ext>`` for one image."""
        return os.path.join(
            lfw_dir, name, f"{name}_{int(image_num):04d}.{self.file_ext}"
        )

    def get_lfw_paths(self, lfw_dir):
        """Parse ``self.pairs_path`` into ``(img1_path, img2_path, issame)`` tuples.

        After the header line, each line is either ``name n1 n2`` (same
        person, ``issame=True``) or ``name1 n1 name2 n2`` (different people,
        ``issame=False``). Fields may be separated by tabs or spaces; lines
        with any other number of fields (e.g. blank lines) are skipped.

        Raises:
            ValueError: if an image-number field is not an integer.
        """
        pairs = []
        with open(self.pairs_path, "r", encoding="utf-8") as f:
            next(f, None)  # skip the header line
            for line in f:
                # split() rather than split('\t') so space-separated variants
                # of the pairs file also parse. Assumes person names contain
                # no whitespace.
                fields = line.split()
                if len(fields) == 3:
                    name, num1, num2 = fields
                    pairs.append((
                        self._image_path(lfw_dir, name, num1),
                        self._image_path(lfw_dir, name, num2),
                        True,
                    ))
                elif len(fields) == 4:
                    name1, num1, name2, num2 = fields
                    pairs.append((
                        self._image_path(lfw_dir, name1, num1),
                        self._image_path(lfw_dir, name2, num2),
                        False,
                    ))
        return pairs

    def __len__(self):
        """Number of pairs parsed from the pairs file."""
        return len(self.validation_images)

    @staticmethod
    def _load_rgb(path):
        """Open ``path``, return it converted to RGB, and close the file."""
        with Image.open(path) as img:
            return img.convert('RGB')

    def __getitem__(self, index):
        """Return ``(img1, img2, issame)`` for pair ``index``.

        If either image cannot be loaded, the error is printed and both images
        are replaced by black 128x128 RGB placeholders so evaluation keeps
        running.
        """
        img1_path, img2_path, issame = self.validation_images[index]

        try:
            img1 = self._load_rgb(img1_path)
            img2 = self._load_rgb(img2_path)
        except Exception as e:
            # Deliberate best-effort fallback for missing/unreadable files.
            # NOTE(review): placeholder pairs still count toward accuracy, so a
            # missing file silently skews the result — verify the dataset.
            print(f"File Load Error: {e}")
            img1 = Image.new('RGB', (128, 128))
            img2 = Image.new('RGB', (128, 128))

        if self.transform is not None:
            img1 = self.transform(img1)
            img2 = self.transform(img2)

        return img1, img2, issame
|
||
|
|
|
||
|
|
# --------------------------------------------------------
|
||
|
|
# 2. Evaluation Function
|
||
|
|
# --------------------------------------------------------
|
||
|
|
def validate_lfw(model, lfw_loader, device):
    """Measure LFW verification accuracy using cosine similarity of embeddings.

    Both images of every pair are embedded with ``model`` under
    ``torch.no_grad()``, and each pair is scored by the cosine similarity of
    its flattened embeddings. Thresholds from -1.0 up to (not including) 1.0
    are swept in steps of 0.01. A pair is predicted "same person" when its
    similarity is strictly greater than the threshold.

    NOTE: the threshold is picked on the very pairs it is scored on (no
    k-fold split), so the accuracy is an optimistic, simplified estimate.

    Args:
        model: ``torch.nn.Module`` producing one embedding per input image;
            all dimensions after the batch dimension are flattened.
        lfw_loader: Iterable of ``(img1, img2, issame)`` batches, e.g. a
            ``DataLoader`` over ``LFWDataset``.
        device: Device the image batches are moved to.

    Returns:
        ``(best_acc, best_th)``: the highest accuracy found and the lowest
        swept threshold achieving it, or ``(0.0, 0.0)`` if no threshold gives
        a positive accuracy (including when the loader yields no pairs).

    The train/eval mode of ``model`` and each of its submodules is restored
    on exit, even if evaluation raises.
    """
    # Remember each submodule's mode so evaluating a model that is mid-training
    # does not leave it stuck in eval mode.
    previous_modes = [(module, module.training) for module in model.modules()]
    model.eval()

    similarities = []
    actual_issame = []

    print("🔍 Validating on LFW...")
    try:
        with torch.no_grad():
            for img1, img2, issame in lfw_loader:
                img1, img2 = img1.to(device), img2.to(device)

                # flatten(1) collapses everything after the batch dim (e.g. a
                # [B, 128, 1, 1] output becomes [B, 128]); unlike view() it
                # also accepts non-contiguous outputs.
                feat1 = model(img1).flatten(start_dim=1)
                feat2 = model(img2).flatten(start_dim=1)

                # Per-pair cosine similarity in [-1, 1].
                cos_sim = F.cosine_similarity(feat1, feat2)
                similarities.extend(cos_sim.cpu().numpy())

                # Accept a label tensor on any device, or a plain sequence.
                if torch.is_tensor(issame):
                    issame = issame.cpu().numpy()
                actual_issame.extend(np.asarray(issame, dtype=bool))
    finally:
        for module, was_training in previous_modes:
            module.training = was_training

    if not actual_issame:
        # Nothing to score; avoid a 0/0 accuracy.
        return 0.0, 0.0

    similarities = np.asarray(similarities)
    actual_issame = np.asarray(actual_issame, dtype=bool)

    # ----------------------------------------------------
    # Best threshold search (simplified: no cross-validation)
    # ----------------------------------------------------
    best_acc = 0.0
    best_th = 0.0
    for th in np.arange(-1.0, 1.0, 0.01):
        predict_issame = similarities > th
        # Fraction of pairs where the prediction matches the label (TP + TN) / N.
        acc = np.mean(predict_issame == actual_issame)
        # Strict '>' keeps the first (lowest) threshold among ties.
        if acc > best_acc:
            best_acc = acc
            best_th = th

    return best_acc, best_th
|