You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
33 lines
927 B
33 lines
927 B
|
7 months ago
|
import numpy as np
|
||
|
|
import scipy.misc
|
||
|
|
|
||
|
|
import torch
|
||
|
|
class LFW(object):
|
||
|
|
def __init__(self, imgl, imgr):
|
||
|
|
|
||
|
|
self.imgl_list = imgl
|
||
|
|
self.imgr_list = imgr
|
||
|
|
|
||
|
|
def __getitem__(self, index):
|
||
|
|
imgl = scipy.misc.imread(self.imgl_list[index])
|
||
|
|
if len(imgl.shape) == 2:
|
||
|
|
imgl = np.stack([imgl] * 3, 2)
|
||
|
|
imgr = scipy.misc.imread(self.imgr_list[index])
|
||
|
|
if len(imgr.shape) == 2:
|
||
|
|
imgr = np.stack([imgr] * 3, 2)
|
||
|
|
|
||
|
|
# imgl = imgl[:, :, ::-1]
|
||
|
|
# imgr = imgr[:, :, ::-1]
|
||
|
|
imglist = [imgl, imgl[:, ::-1, :], imgr, imgr[:, ::-1, :]]
|
||
|
|
for i in range(len(imglist)):
|
||
|
|
imglist[i] = (imglist[i] - 127.5) / 128.0
|
||
|
|
imglist[i] = imglist[i].transpose(2, 0, 1)
|
||
|
|
imgs = [torch.from_numpy(i).float() for i in imglist]
|
||
|
|
return imgs
|
||
|
|
|
||
|
|
def __len__(self):
|
||
|
|
return len(self.imgl_list)
|
||
|
|
|
||
|
|
|
||
|
|
if __name__ == '__main__':
|
||
|
|
pass
|