You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
51 lines
1.5 KiB
51 lines
1.5 KiB
import numpy as np
|
|
import scipy.misc
|
|
import os
|
|
import torch
|
|
|
|
class CASIA_Face(object):
    """Map-style dataset over the CASIA-WebFace 112x96 image list.

    The index file ``<root>/CASIA-WebFace-112X96.txt`` is read once at
    construction. Each non-blank line holds an image path (relative to
    ``<root>/CASIA-WebFace-112X96``) followed by an integer class label,
    separated by whitespace.

    Args:
        root: dataset root containing the index file and the
            ``CASIA-WebFace-112X96`` image folder.
        loader: optional callable ``path -> array`` used to decode an image.
            When ``None`` (the default) ``scipy.misc.imread`` is used, which
            only exists in SciPy < 1.2. On newer SciPy, pass a replacement
            reader here.

    Attributes:
        image_list: absolute image paths, in index-file order.
        label_list: integer labels, parallel to ``image_list``.
        class_nums: number of distinct labels in ``label_list``.
    """

    def __init__(self, root, loader=None):
        self.root = root
        self.loader = loader

        img_txt_dir = os.path.join(root, 'CASIA-WebFace-112X96.txt')
        image_root = os.path.join(root, 'CASIA-WebFace-112X96')

        image_list = []
        label_list = []
        with open(img_txt_dir) as f:
            for line_no, line in enumerate(f, start=1):
                info = line.strip()
                if not info:
                    # Tolerate blank lines, e.g. an extra newline at EOF.
                    continue
                # Split on the LAST whitespace run so image paths that contain
                # spaces still parse; the label is always the final field.
                parts = info.rsplit(None, 1)
                if len(parts) != 2:
                    raise ValueError(
                        '%s:%d: expected "<image path> <label>", got %r'
                        % (img_txt_dir, line_no, info)
                    )
                image_dir, label_name = parts
                image_list.append(os.path.join(image_root, image_dir))
                label_list.append(int(label_name))

        self.image_list = image_list
        self.label_list = label_list
        self.class_nums = len(set(label_list))

    def _read_image(self, path):
        """Decode *path* with the configured loader (default: scipy.misc.imread)."""
        if self.loader is not None:
            return self.loader(path)
        try:
            imread = scipy.misc.imread
        except AttributeError as err:
            raise AttributeError(
                'scipy.misc.imread is unavailable (removed in SciPy 1.2); '
                'construct CASIA_Face with loader=<image reader function>'
            ) from err
        return imread(path)

    def __getitem__(self, index):
        """Return ``(img, target)`` for sample *index*.

        ``img`` is a float32 tensor laid out channels-first (C, H, W) and
        normalised as ``(pixel - 127.5) / 128``. 2-D (grayscale) images are
        replicated to three channels. With probability 0.5 the image is
        mirrored along its second (width) axis.
        """
        img_path = self.image_list[index]
        target = self.label_list[index]

        img = np.asarray(self._read_image(img_path))
        if img.ndim == 2:
            img = np.stack([img] * 3, 2)

        # Draw the flip decision from torch's RNG: the DataLoader reseeds it in
        # every worker, while NumPy's global RNG may be duplicated across
        # workers (which would give identical "random" flips in each worker).
        if torch.randint(2, (1,)).item():
            img = img[:, ::-1, :]

        # The arithmetic allocates a fresh array, so the negative-stride view
        # from the flip never reaches torch.from_numpy.
        img = (img - 127.5) / 128.0
        img = img.transpose(2, 0, 1)  # HWC -> CHW
        return torch.from_numpy(img).float(), target

    def __len__(self):
        """Number of samples listed in the index file."""
        return len(self.image_list)
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Smoke test: build the dataset, print its size, then iterate over it once
    # and print the shape of every image batch. The dataset root may be given
    # as the first command-line argument; otherwise the original hard-coded
    # path is used.
    import sys

    data_dir = sys.argv[1] if len(sys.argv) > 1 else '/home/brl/USER/fzc/dataset/CASIA'
    dataset = CASIA_Face(root=data_dir)
    trainloader = torch.utils.data.DataLoader(
        dataset, batch_size=32, shuffle=True, num_workers=8, drop_last=False
    )
    print(len(dataset))
    for data in trainloader:
        print(data[0].shape)
|