You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

51 lines
1.5 KiB

import numpy as np
import scipy.misc
import os
import torch
class CASIA_Face(torch.utils.data.Dataset):
    """Map-style dataset over the aligned CASIA-WebFace 112x96 face crops.

    The index file ``<root>/CASIA-WebFace-112X96.txt`` lists one sample per
    line as ``<relative image path> <integer label>``; image paths are
    resolved under ``<root>/CASIA-WebFace-112X96/``.

    Attributes:
        root: Dataset root directory passed to the constructor.
        image_list: Image paths joined under ``<root>/CASIA-WebFace-112X96``,
            in index-file order.
        label_list: Integer labels aligned element-wise with ``image_list``.
        class_nums: Number of distinct labels in ``label_list``.
    """

    def __init__(self, root):
        """Read the index file under ``root`` and build the sample lists.

        Raises:
            OSError: if the index file cannot be opened.
            ValueError: if a non-blank line lacks a label or the label is
                not an integer.
        """
        self.root = root
        list_path = os.path.join(root, 'CASIA-WebFace-112X96.txt')
        image_dir = os.path.join(root, 'CASIA-WebFace-112X96')
        image_list = []
        label_list = []
        with open(list_path) as f:
            for line in f:
                line = line.strip()
                if not line:
                    # Tolerate blank lines instead of failing to unpack them.
                    continue
                # Split on the last whitespace run so the label is always the
                # final field, even with trailing spaces or tab separators.
                rel_path, label_name = line.rsplit(maxsplit=1)
                image_list.append(os.path.join(image_dir, rel_path))
                label_list.append(int(label_name))
        self.image_list = image_list
        self.label_list = label_list
        self.class_nums = len(set(label_list))

    def __getitem__(self, index):
        """Return ``(image_tensor, label)`` for sample ``index``.

        The image is mirrored horizontally with probability 0.5 and then
        normalized by ``_preprocess``.
        """
        # NOTE(review): scipy.misc.imread was deprecated in SciPy 1.0 and
        # removed in 1.2; this requires an old SciPy (with Pillow) or a swap
        # to another image reader.
        img = scipy.misc.imread(self.image_list[index])
        # Draw the flip from torch's RNG: DataLoader reseeds it in every worker
        # process, whereas NumPy's global RNG was copied unchanged into forked
        # workers on older PyTorch releases, making all workers flip alike.
        flip = bool(torch.randint(0, 2, (1,)).item())
        return self._preprocess(img, flip), self.label_list[index]

    def __len__(self):
        """Return the number of samples listed in the index file."""
        return len(self.image_list)

    @staticmethod
    def _preprocess(img, flip):
        """Convert a decoded image array into a normalized CHW float tensor.

        Args:
            img: (H, W) grayscale or (H, W, C) array; presumably uint8 pixel
                values from the image reader — confirm for other readers.
            flip: if true, mirror the image along its width axis.

        Returns:
            float32 tensor of shape (C, H, W) holding (img - 127.5) / 128,
            with C == 3 for grayscale input.
        """
        if img.ndim == 2:
            # Replicate a grayscale image into three identical channels.
            img = np.stack([img] * 3, axis=2)
        if flip:
            img = img[:, ::-1, :]
        img = (img - 127.5) / 128.0
        return torch.from_numpy(img.transpose(2, 0, 1)).float()
if __name__ == '__main__':
    # Smoke test: build the dataset, report its size, then iterate over it
    # once, printing each batch's image-tensor shape. The dataset root may be
    # passed as the first command-line argument; it defaults to the
    # previously hard-coded path.
    import sys

    data_dir = sys.argv[1] if len(sys.argv) > 1 else '/home/brl/USER/fzc/dataset/CASIA'
    dataset = CASIA_Face(root=data_dir)
    trainloader = torch.utils.data.DataLoader(
        dataset, batch_size=32, shuffle=True, num_workers=8, drop_last=False
    )
    print(len(dataset))
    for images, _labels in trainloader:
        print(images.shape)