You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

39 lines
1.4 KiB

import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import os
def get_face_dataloaders(data_dir, batch_size=64, num_workers=4, image_size=128):
    """Build the shuffled CASIA-WebFace training DataLoader.

    Expected layout (``torchvision.datasets.ImageFolder`` treats every
    sub-folder as one class, i.e. one identity)::

        data_dir/CASIA-WebFace/<identity>/<image>.jpg

    Args:
        data_dir: Directory that contains the ``CASIA-WebFace`` folder,
            e.g. ``~/face_exp/datasets``. A leading ``~`` is expanded.
        batch_size: Number of samples per batch.
        num_workers: Number of DataLoader worker processes.
        image_size: Target size after resizing. An int ``s`` means ``s x s``;
            a ``(height, width)`` pair is used as-is. Defaults to 128
            (128x128), the previously hard-coded size.

    Returns:
        ``(train_loader, num_classes)``. ``num_classes`` is the number of
        identity folders ImageFolder discovered.

    Raises:
        FileNotFoundError: If ``data_dir/CASIA-WebFace`` is not an existing
            directory.
    """
    # 1. Path setup. Expand "~" ourselves: os.path.join/isdir treat it as a
    # literal character, so the documented example path would never resolve.
    train_dir = os.path.join(os.path.expanduser(data_dir), 'CASIA-WebFace')
    if not os.path.isdir(train_dir):
        # Message (Korean): "Dataset path not found: <path> / please check that
        # the 'CASIA-WebFace' folder exists at that location."
        raise FileNotFoundError(f"데이터셋 경로를 찾을 수 없습니다: {train_dir}\n'CASIA-WebFace' 폴더가 해당 위치에 있는지 확인해주세요.")

    # 2. Train transform: resize + light augmentation.
    # Resize with a bare int would only scale the shorter side, so an int is
    # expanded to an explicit square (h, w) to keep fixed-size batches.
    if isinstance(image_size, int):
        size = (image_size, image_size)
    else:
        size = tuple(image_size)
    train_transform = transforms.Compose([
        transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
        transforms.ToTensor(),
        # (x - 0.5) / 0.5 maps ToTensor's [0, 1] range to [-1, 1].
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

    # 3. Dataset & loader.
    train_dataset = datasets.ImageFolder(root=train_dir, transform=train_transform)
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        # Pinned host memory only speeds up host->GPU copies; without CUDA it
        # is useless and PyTorch warns about it.
        pin_memory=torch.cuda.is_available(),
        # Discard the trailing partial batch so every batch has batch_size samples.
        drop_last=True,
    )
    return train_loader, len(train_dataset.classes)