You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

94 lines
3.3 KiB

6 months ago
import os
import glob
from tqdm.notebook import tqdm # 주피터 노트북용 진행바
# ==========================================
# 1. 절대 경로 설정 (수정된 부분)
# ==========================================
# 사용자 홈 디렉토리(/home/cuuva)를 포함한 전체 경로를 입력합니다.
SRC_ROOT = '/home/cuuva/git/Detection_Experiment/datasets/fashionpedia_yolo/labels_all_bak'
DST_ROOT = '/home/cuuva/git/Detection_Experiment/datasets/fashionpedia_yolo/labels_reduced'
# ==========================================
# 2. 클래스 매핑 규칙 (Old ID -> New ID)
# ==========================================
# 0(shirt), 1(top), 2(sweater) -> 0 (shirt)
# 3(cardigan), 4(jacket) -> 1 (jacket)
# 6(pants) -> 2 (pants)
# 13(glasses) -> 3 (glasses)
class_mapping = {
0: 0, 1: 0, 2: 0,
3: 1, 4: 1,
6: 2,
13: 3
}
print(f"원본 경로: {SRC_ROOT}")
print(f"저장 경로: {DST_ROOT}")
print("-" * 30)
# ==========================================
# 3. 데이터 변환 로직
# ==========================================
def process_yolo_labels(src_root, dst_root, mapping):
subsets = ['train', 'val']
total_files = 0
total_objects_kept = 0
for subset in subsets:
src_dir = os.path.join(src_root, subset)
dst_dir = os.path.join(dst_root, subset)
# 소스 디렉토리 존재 확인
if not os.path.exists(src_dir):
print(f"⚠️ 에러: 소스 폴더를 찾을 수 없습니다 -> {src_dir}")
continue
# 타겟 디렉토리 생성
os.makedirs(dst_dir, exist_ok=True)
# 파일 목록 로드
txt_files = glob.glob(os.path.join(src_dir, '*.txt'))
total_files += len(txt_files)
print(f"🚀 Processing [{subset}]: {len(txt_files)} files found.")
# 변환 시작
for file_path in tqdm(txt_files, desc=f"{subset} Converting"):
file_name = os.path.basename(file_path)
dst_path = os.path.join(dst_dir, file_name)
new_lines = []
with open(file_path, 'r') as f:
lines = f.readlines()
for line in lines:
parts = line.strip().split()
if not parts: continue
old_cls = int(parts[0])
coords = parts[1:] # x, y, w, h
# 매핑 규칙에 있는 클래스만 남김
if old_cls in mapping:
new_cls = mapping[old_cls]
new_line = f"{new_cls} {' '.join(coords)}\n"
new_lines.append(new_line)
total_objects_kept += 1
# 파일 쓰기 (빈 파일이라도 생성하여 구조 유지)
with open(dst_path, 'w') as f:
f.writelines(new_lines)
return total_files, total_objects_kept
# 실행
processed_cnt, kept_cnt = process_yolo_labels(SRC_ROOT, DST_ROOT, class_mapping)
print("="*30)
print("✅ 변환 완료")
print(f"총 처리된 파일: {processed_cnt}")
print(f"남은 객체 수: {kept_cnt}")
print(f"저장된 위치: {DST_ROOT}")