You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
95 lines
3.3 KiB
95 lines
3.3 KiB
|
7 months ago
|
import torch
|
||
|
|
from torch.utils.data import Dataset
|
||
|
|
from PIL import Image
|
||
|
|
import numpy as np
|
||
|
|
from datasets import load_dataset
|
||
|
|
|
||
|
|
# ----------------------------
|
||
|
|
# Train Dataset: CASIA Web Face
|
||
|
|
# ----------------------------
|
||
|
|
class CASIA_HF(Dataset):
    """CASIA-WebFace training samples loaded through ``datasets.load_dataset``.

    Indexing yields ``(img, label)``:

    * ``img``   -- float32 tensor of shape ``(3, 128, 128)`` (channels first),
      with 8-bit pixel values mapped to ``(p - 127.5) / 128``.
    * ``label`` -- scalar integer tensor built from the ``label`` column.
    """

    # (width, height) every image is resized to before normalisation.
    IMAGE_SIZE = (128, 128)

    def __init__(self):
        # Hugging Face train split.
        self.dataset = load_dataset("SaffalPoosh/casia_web_face", split="train")

    def __len__(self):
        """Return the number of samples in the loaded split."""
        return len(self.dataset)

    @staticmethod
    def _to_tensor(pixels):
        """Convert an HWC pixel array into a normalised CHW float32 tensor."""
        # For 8-bit input this float32 arithmetic is exact: p - 127.5 is
        # representable and the divisor is a power of two.
        scaled = (np.asarray(pixels, dtype=np.float32) - 127.5) / 128.0
        # Materialise the transposed layout so the tensor is contiguous.
        return torch.from_numpy(np.ascontiguousarray(scaled.transpose(2, 0, 1)))

    def __getitem__(self, idx):
        """Return the ``(img, label)`` pair stored at position ``idx``."""
        item = self.dataset[idx]

        # Hugging Face image column.
        image = item['image']
        if not hasattr(image, "convert"):
            # Not a PIL-style image (e.g. a raw pixel array): wrap it first.
            image = Image.fromarray(np.array(image))
        # Converting the decoded image directly keeps its mode information;
        # a NumPy round trip would drop palettes before the RGB conversion.
        image = image.convert("RGB").resize(self.IMAGE_SIZE)

        img = self._to_tensor(np.array(image))
        # NOTE(review): assumes the class id lives in a column named 'label'
        # -- confirm against the dataset schema.
        label = torch.tensor(int(item['label']))
        return img, label
|
||
|
|
|
||
|
|
|
||
|
|
# ----------------------------
|
||
|
|
# Test Dataset: LFW Pairs
|
||
|
|
# ----------------------------
|
||
|
|
# class LFW_Pairs(Dataset):
|
||
|
|
# def __init__(self):
|
||
|
|
# self.dataset = load_dataset("logasja/lfw", "pairs", split="test")
|
||
|
|
|
||
|
|
# def __len__(self):
|
||
|
|
# return len(self.dataset)
|
||
|
|
|
||
|
|
# def __getitem__(self, idx):
|
||
|
|
# item = self.dataset[idx]
|
||
|
|
# imgl = np.array(item['image1'])
|
||
|
|
# imgr = np.array(item['image2'])
|
||
|
|
|
||
|
|
# imgl = Image.fromarray(imgl).convert("RGB").resize((128,128))
|
||
|
|
# imgr = Image.fromarray(imgr).convert("RGB").resize((128,128))
|
||
|
|
|
||
|
|
# imglist = [imgl, imgl[:, ::-1, :], imgr, imgr[:, ::-1, :]] # original + flip
|
||
|
|
# for i in range(len(imglist)):
|
||
|
|
# imglist[i] = (imglist[i] - 127.5) / 128.0
|
||
|
|
# imglist[i] = imglist[i].transpose(2,0,1)
|
||
|
|
# imgs = [torch.from_numpy(i).float() for i in imglist]
|
||
|
|
|
||
|
|
# label = torch.tensor(item['label'])
|
||
|
|
# return imgs, label
|
||
|
|
class LFW_Pairs(Dataset):
    """LFW verification pairs loaded through ``datasets.load_dataset``.

    Indexing yields ``(imgs, label)``:

    * ``imgs``  -- list of four float32 tensors of shape ``(3, 128, 128)`` in
      the order ``[left, left mirrored, right, right mirrored]``, with 8-bit
      pixel values mapped to ``(p - 127.5) / 128``.
    * ``label`` -- tensor built from the ``pair`` column.
    """

    # (width, height) every image is resized to before normalisation.
    IMAGE_SIZE = (128, 128)

    def __init__(self):
        self.dataset = load_dataset("logasja/lfw", "pairs", split="test")

    def __len__(self):
        """Return the number of pairs in the loaded split."""
        return len(self.dataset)

    @staticmethod
    def _to_tensor(pixels):
        """Convert an HWC pixel array into a normalised CHW float32 tensor."""
        # For 8-bit input this float32 arithmetic is exact: p - 127.5 is
        # representable and the divisor is a power of two.
        scaled = (np.asarray(pixels, dtype=np.float32) - 127.5) / 128.0
        # The mirrored inputs are negative-stride views, which
        # torch.from_numpy rejects; force a contiguous copy explicitly.
        return torch.from_numpy(np.ascontiguousarray(scaled.transpose(2, 0, 1)))

    def __getitem__(self, idx):
        """Return the four image views and the label for pair ``idx``."""
        item = self.dataset[idx]

        views = []
        for column in ('img_0', 'img_1'):
            # Fetch the PIL image. Convert to RGB *before* resizing: Pillow
            # always uses nearest-neighbour when resizing palette / 1-bit
            # images, so resizing first would degrade those inputs.
            image = item[column].convert("RGB").resize(self.IMAGE_SIZE)
            # Convert to a NumPy array (height, width, channel).
            pixels = np.array(image)
            # Original followed by its horizontal mirror (width axis reversed).
            views.append(pixels)
            views.append(pixels[:, ::-1, :])

        imgs = [self._to_tensor(view) for view in views]
        label = torch.tensor(item['pair'])
        return imgs, label
|