Clean state

This commit is contained in:
Johnny Fernandes
2026-04-30 01:25:39 +01:00
commit bb3dfb92d5
266 changed files with 37043 additions and 0 deletions
+74
View File
@@ -0,0 +1,74 @@
import random
from pathlib import Path
from PIL import Image
import torchvision.transforms as T
from torch.utils.data import Dataset
class GeneratorDataset(Dataset):
    """Unlabeled image dataset for generative model training.

    Collects every ``*.jpg`` file found exactly one directory level below
    each ``<root>/<source>`` folder and returns the (optionally transformed)
    image only -- no labels, since generation is unsupervised.

    Args:
        data_dir: A single dataset root (``str`` or ``Path``) or an iterable
            of roots whose images are pooled into one dataset.
        sources: Names of the folders to read under each root. A single name
            may be given as a plain string. Defaults to ``["wiki"]``.
        subsample: Fraction of the collected images to keep. Values below
            ``1.0`` draw a seeded random subset (at least one image whenever
            any image was found); otherwise every image is kept in discovery
            order.
        transform: Optional callable applied to each RGB ``PIL`` image in
            ``__getitem__``.
        seed: Seed of the RNG used for subsampling, so the same arguments
            always select the same subset.

    Raises:
        FileNotFoundError: If a root, or a ``<root>/<source>`` folder, does
            not exist.
    """

    def __init__(self, data_dir, sources=None, subsample=1.0, transform=None, seed=42):
        self.transform = transform
        self.samples = []

        # Accept either a single root or a list of roots (used by 1D to mix
        # raw + aligned crops in one dataset).
        roots = [data_dir] if isinstance(data_dir, (str, Path)) else list(data_dir)

        if sources is None:
            sources = ["wiki"]
        elif isinstance(sources, str):
            # A bare string would otherwise be iterated character by
            # character and fail with a misleading "missing directory" error.
            sources = [sources]

        for root in roots:
            root = Path(root)
            if not root.exists():
                raise FileNotFoundError(f"Dataset root not found: {root}")
            for source in sources:
                source_dir = root / source
                if not source_dir.exists():
                    raise FileNotFoundError(f"Missing source directory: {source_dir}")
                # Sorting both levels keeps the sample order (and therefore
                # the seeded subsample below) independent of filesystem order.
                for subdir in sorted(source_dir.iterdir()):
                    if subdir.is_dir():
                        self.samples.extend(sorted(subdir.glob("*.jpg")))

        # Skip subsampling when nothing was found: random.sample() raises
        # ValueError when asked for one item from an empty population, whereas
        # subsample=1.0 already yields an empty dataset without error.
        if subsample < 1.0 and self.samples:
            rng = random.Random(seed)
            n = max(1, int(len(self.samples) * subsample))
            self.samples = rng.sample(self.samples, n)

    def __len__(self):
        """Return the number of image paths held by the dataset."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB and return it, transformed if configured."""
        # The context manager releases the file handle deterministically,
        # including when decoding fails; convert() returns an independent
        # in-memory image, so it stays usable after the file is closed.
        with Image.open(self.samples[idx]) as raw:
            img = raw.convert("RGB")
        if self.transform is not None:
            img = self.transform(img)
        return img
def get_transform(image_size: int, augment: bool = False) -> T.Compose:
    """Assemble the preprocessing pipeline used for generator training.

    The pipeline resizes and center-crops to ``image_size``, optionally
    applies light augmentation, then converts to a tensor normalized with
    mean 0.5 / std 0.5 per channel, i.e. values in [-1, 1].

    Args:
        image_size: Size passed to both ``Resize`` and ``CenterCrop``.
        augment: When true, insert a random horizontal flip, a mild rotation
            and mild color jitter before tensor conversion. Leave false for
            validation / FID real-image sets.

    Returns:
        The composed transform.
    """
    geometry = [T.Resize(image_size), T.CenterCrop(image_size)]

    augmentation = []
    if augment:
        augmentation = [
            T.RandomHorizontalFlip(p=0.5),
            T.RandomRotation(degrees=5, interpolation=T.InterpolationMode.BILINEAR),
            T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.05),
        ]

    # ToTensor yields [0, 1]; (x - 0.5) / 0.5 then maps that to [-1, 1].
    to_tensor = [
        T.ToTensor(),
        T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ]

    return T.Compose(geometry + augmentation + to_tensor)