You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
39 lines
1.4 KiB
39 lines
1.4 KiB
|
6 months ago
|
import torch
|
||
|
|
from torchvision import datasets, transforms
|
||
|
|
from torch.utils.data import DataLoader
|
||
|
|
import os
|
||
|
|
|
||
|
|
def get_face_dataloaders(data_dir, batch_size=64, num_workers=4, image_size=128):
    """
    Build a shuffled training DataLoader over the CASIA-WebFace image folder.

    Input: data_dir (e.g., ~/face_exp/datasets); a leading "~" is expanded.
    Structure assumed: data_dir/CASIA-WebFace/ID/images.jpg

    Args:
        data_dir: Directory that contains the ``CASIA-WebFace`` folder
            (str or path-like).
        batch_size: Samples per batch. A trailing partial batch is dropped
            (``drop_last=True``), so a dataset smaller than ``batch_size``
            yields zero batches.
        num_workers: Number of loader worker processes passed to ``DataLoader``.
        image_size: Side length in pixels that every image is resized to
            (square). Defaults to 128, the original hard-coded size.

    Returns:
        A tuple ``(train_loader, num_classes)``, where ``num_classes`` is the
        number of class sub-folders (identities) found by ``ImageFolder``.

    Raises:
        FileNotFoundError: If ``data_dir/CASIA-WebFace`` is not an existing
            directory.
    """
    # 1. Path Setup -- done first so a bad path fails fast, before any
    # transform/dataset objects are created. expanduser() is required because
    # os.path.join / os.path.isdir do not interpret "~" on their own.
    train_dir = os.path.join(os.path.expanduser(data_dir), 'CASIA-WebFace')

    # isdir (not exists): a regular file with this name is just as unusable
    # and would otherwise surface as a confusing error inside ImageFolder.
    if not os.path.isdir(train_dir):
        raise FileNotFoundError(f"데이터셋 경로를 찾을 수 없습니다: {train_dir}\n'CASIA-WebFace' 폴더가 해당 위치에 있는지 확인해주세요.")

    # 2. Train Transform (Augmentation + Resize to image_size x image_size)
    train_transform = transforms.Compose([
        transforms.Resize((image_size, image_size), interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
        transforms.ToTensor(),
        # ToTensor yields values in [0, 1]; mean/std of 0.5 rescale them to [-1, 1].
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

    # 3. Dataset & Loader
    train_dataset = datasets.ImageFolder(root=train_dir, transform=train_transform)

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        # Pinned host memory only helps when a CUDA device is present;
        # without one, PyTorch warns and ignores the flag.
        pin_memory=torch.cuda.is_available(),
        drop_last=True,
    )

    return train_loader, len(train_dataset.classes)
|