You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

95 lines
3.3 KiB

import torch
from torch.utils.data import Dataset
from PIL import Image
import numpy as np
from datasets import load_dataset
# ----------------------------
# Train Dataset: CASIA Web Face
# ----------------------------
class CASIA_HF(Dataset):
    """CASIA-WebFace training set loaded from the Hugging Face Hub.

    Each item is ``(img, label)``:

    * ``img`` is a float32 tensor of shape ``(3, H, W)`` whose pixel values
      are mapped with ``(x - 127.5) / 128``.
    * ``label`` is ``torch.tensor(int(item["label"]))``, a 0-d int64 tensor.

    Args:
        split: Split passed to ``load_dataset`` (default ``"train"``).
        size: Target ``(width, height)`` passed to ``PIL.Image.resize``
            (default ``(128, 128)``).
    """

    def __init__(self, split="train", size=(128, 128)):
        self.dataset = load_dataset("SaffalPoosh/casia_web_face", split=split)
        self.size = tuple(size)

    def __len__(self):
        return len(self.dataset)

    @staticmethod
    def _to_rgb(image, size):
        """Return *image* as an RGB PIL image resized to *size*.

        PIL images are converted directly so that their mode (palette, CMYK,
        ...) is honoured. Round-tripping through ``np.array`` would drop the
        palette or misread CMYK as RGBA. Anything else is treated as an
        array-like.
        """
        if not isinstance(image, Image.Image):
            image = Image.fromarray(np.asarray(image))
        # Convert first so resizing happens in RGB with the default filter.
        return image.convert("RGB").resize(size)

    @staticmethod
    def _to_tensor(arr):
        """Map an (H, W, 3) uint8 array to a contiguous (3, H, W) float32 tensor."""
        arr = (np.asarray(arr, dtype=np.float32) - 127.5) / 128.0
        return torch.from_numpy(np.ascontiguousarray(arr.transpose(2, 0, 1)))

    def __getitem__(self, idx):
        item = self.dataset[idx]
        img = self._to_tensor(np.asarray(self._to_rgb(item["image"], self.size)))
        # NOTE(review): assumes the "label" column holds integer identity ids;
        # the original author also flagged this column as needing a check.
        label = torch.tensor(int(item["label"]))
        return img, label
# ----------------------------
# Test Dataset: LFW Pairs
# ----------------------------
class LFW_Pairs(Dataset):
    """LFW verification pairs loaded from the Hugging Face Hub.

    Each item is ``(imgs, label)``:

    * ``imgs`` is a list of four float32 tensors of shape ``(3, H, W)``: the
      left image, the left image mirrored horizontally, the right image, and
      the right image mirrored. Pixel values are mapped with
      ``(x - 127.5) / 128``.
    * ``label`` is ``torch.tensor(item["pair"])``.

    Args:
        split: Split passed to ``load_dataset`` (default ``"test"``).
        size: Target ``(width, height)`` passed to ``PIL.Image.resize``
            (default ``(128, 128)``).
    """

    def __init__(self, split="test", size=(128, 128)):
        self.dataset = load_dataset("logasja/lfw", "pairs", split=split)
        self.size = tuple(size)

    def __len__(self):
        return len(self.dataset)

    @staticmethod
    def _to_rgb(image, size):
        """Return *image* as an RGB PIL image resized to *size*.

        The image is converted to RGB *before* resizing. Pillow forces
        NEAREST resampling for "P"/"1" mode images, so resizing first would
        degrade them. Non-PIL inputs are treated as array-likes.
        """
        if not isinstance(image, Image.Image):
            image = Image.fromarray(np.asarray(image))
        return image.convert("RGB").resize(size)

    @staticmethod
    def _to_tensor(arr):
        """Map an (H, W, 3) uint8 array to a contiguous (3, H, W) float32 tensor."""
        arr = (np.asarray(arr, dtype=np.float32) - 127.5) / 128.0
        # ascontiguousarray also guarantees no negative strides from flipped
        # views, which torch.from_numpy rejects.
        return torch.from_numpy(np.ascontiguousarray(arr.transpose(2, 0, 1)))

    def __getitem__(self, idx):
        item = self.dataset[idx]
        left = np.asarray(self._to_rgb(item["img_0"], self.size))
        right = np.asarray(self._to_rgb(item["img_1"], self.size))
        # Original + horizontally flipped copy of each image
        # ([:, ::-1] reverses the width axis of an (H, W, 3) array).
        views = [left, left[:, ::-1], right, right[:, ::-1]]
        imgs = [self._to_tensor(view) for view in views]
        # NOTE(review): assumes "pair" is the same/different-identity label;
        # confirm against the dataset card.
        label = torch.tensor(item["pair"])
        return imgs, label