import torch
from torch.utils.data import Dataset
from PIL import Image
import numpy as np
from datasets import load_dataset


# ----------------------------
# Shared preprocessing
# ----------------------------
def _to_rgb_pil(image):
    """Return *image* as an RGB PIL image.

    PIL images are converted directly. Routing them through NumPy first
    would lose information: for "P" (palette) images np.array yields only
    the palette indices, and CMYK images come back as 4-channel arrays that
    Image.fromarray misreads as RGBA. Anything that is not a PIL image is
    passed through np.asarray / Image.fromarray, as before.
    """
    if not isinstance(image, Image.Image):
        image = Image.fromarray(np.asarray(image))
    return image.convert("RGB")


def _preprocess(image, size=(128, 128)):
    """Turn one face image into a normalized float32 tensor of shape (3, H, W).

    The image is converted to RGB *before* resizing (PIL always uses
    nearest-neighbour resampling for "1"/"P" mode images). Pixels are
    then mapped from [0, 255] to about [-1, 1] via (x - 127.5) / 128.

    Args:
        image: A PIL image, or anything Image.fromarray accepts after
            np.asarray.
        size: Target (width, height) passed to PIL's Image.resize.

    Returns:
        A contiguous float32 tensor laid out as (channels, height, width).
    """
    rgb = _to_rgb_pil(image).resize(size)
    pixels = np.asarray(rgb, dtype=np.float32)  # (H, W, 3)
    pixels = (pixels - 127.5) / 128.0
    return torch.from_numpy(pixels).permute(2, 0, 1).contiguous()


# ----------------------------
# Train Dataset: CASIA Web Face
# ----------------------------
class CASIA_HF(Dataset):
    """CASIA-WebFace train split loaded from the Hugging Face hub.

    Each item is ``(img, label)``: ``img`` is a float32 (3, H, W) tensor
    produced by ``_preprocess``, and ``label`` is an int64 scalar tensor.
    """

    def __init__(self, size=(128, 128)):
        """Load the dataset.

        Args:
            size: Output (width, height) of every image. Defaults to 128x128.
        """
        self.dataset = load_dataset("SaffalPoosh/casia_web_face", split="train")
        self.size = size

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        img = _preprocess(item["image"], self.size)
        # NOTE(review): assumes the "label" column holds an integer identity id — confirm.
        label = torch.tensor(int(item["label"]))
        return img, label


# ----------------------------
# Test Dataset: LFW Pairs
# ----------------------------
class LFW_Pairs(Dataset):
    """LFW verification pairs (test split) with horizontal-flip copies.

    Each item is ``(imgs, label)``:

    - ``imgs`` is ``[left, left_flipped, right, right_flipped]``. Each entry
      is a float32 (3, H, W) tensor, preprocessed exactly like the training
      images.
    - ``label`` is ``torch.tensor(item["pair"])``.
    """

    def __init__(self, size=(128, 128)):
        """Load the dataset.

        Args:
            size: Output (width, height) of every image. Defaults to 128x128.
        """
        self.dataset = load_dataset("logasja/lfw", "pairs", split="test")
        self.size = size

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        imgs = []
        for key in ("img_0", "img_1"):
            face = _preprocess(item[key], self.size)
            # Mirror along the width axis (the last dim of a CHW tensor).
            # torch.flip returns a copy, so there are no negative strides.
            imgs.extend((face, torch.flip(face, dims=[-1])))
        # NOTE(review): "pair" is presumably the same/different-identity flag — confirm its dtype.
        label = torch.tensor(item["pair"])
        return imgs, label