import os

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms


def get_face_dataloaders(data_dir, batch_size=64, num_workers=4):
    """Build the training DataLoader for the CASIA-WebFace dataset.

    Args:
        data_dir: Root directory that contains the ``CASIA-WebFace`` folder
            (e.g. ``~/face_exp/datasets``). A leading ``~`` is expanded to the
            user's home directory. Expected layout:
            ``data_dir/CASIA-WebFace/<ID>/<image>.jpg``. That is one sub-folder
            per identity, which ``ImageFolder`` turns into one class each.
        batch_size: Number of samples per batch. The final incomplete batch
            is dropped (``drop_last=True``).
        num_workers: Number of DataLoader worker processes.

    Returns:
        A tuple ``(train_loader, num_classes)``: the shuffled training
        DataLoader and the number of class folders found by ``ImageFolder``.

    Raises:
        FileNotFoundError: If ``data_dir/CASIA-WebFace`` is not an existing
            directory.
    """
    # 1. Path setup. Validate before building anything else so a bad path
    #    fails fast. expanduser() is required because neither os.path nor
    #    ImageFolder expand "~" on their own.
    train_dir = os.path.join(os.path.expanduser(data_dir), 'CASIA-WebFace')
    if not os.path.isdir(train_dir):
        raise FileNotFoundError(f"데이터셋 경로를 찾을 수 없습니다: {train_dir}\n'CASIA-WebFace' 폴더가 해당 위치에 있는지 확인해주세요.")

    # 2. Train transform: resize to 128x128 plus light augmentation.
    train_transform = transforms.Compose([
        transforms.Resize((128, 128), interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
        transforms.ToTensor(),
        # ToTensor yields values in [0, 1]; (x - 0.5) / 0.5 maps them to [-1, 1].
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

    # 3. Dataset and loader.
    train_dataset = datasets.ImageFolder(root=train_dir, transform=train_transform)
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        # Pinned memory only speeds up host-to-GPU copies. Enabling it without
        # CUDA gains nothing, and PyTorch emits a warning in that case.
        pin_memory=torch.cuda.is_available(),
        # NOTE: with drop_last=True, a dataset smaller than batch_size yields
        # zero batches.
        drop_last=True,
    )
    return train_loader, len(train_dataset.classes)