Skip to content

Commit d35fc40

Browse files
authored
Merge pull request #4 from or4k2l/copilot/fix-batch-unpacking-error
Fix 7 critical training pipeline bugs blocking Colab execution
2 parents 00de221 + 6779ca7 commit d35fc40

File tree

6 files changed

+541
-38
lines changed

6 files changed

+541
-38
lines changed

scripts/train.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -101,9 +101,9 @@ def main():
101101
augment=True
102102
)
103103

104-
# Load datasets
105-
train_ds = data_loader.get_train_loader()
106-
eval_ds = data_loader.get_test_loader()
104+
# Load datasets (TF datasets, not iterators, so they can be reused)
105+
train_ds = data_loader.create_dataset('train', repeat=False)
106+
eval_ds = data_loader.create_dataset('test', repeat=False)
107107

108108
logger.info(f"Training dataset loaded")
109109
logger.info(f"Evaluation dataset loaded")

src/robust_vision/data/__init__.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
"""Data loading and augmentation utilities."""
2+
3+
from .loaders import ScalableDataLoader
4+
from .noise import NoiseLibrary
5+
6+
__all__ = [
7+
"ScalableDataLoader",
8+
"NoiseLibrary",
9+
]

src/robust_vision/data/loaders.py

Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
"""Scalable data loaders for vision datasets."""
2+
3+
from typing import Optional, Tuple
4+
import tensorflow as tf
5+
import tensorflow_datasets as tfds
6+
import numpy as np
7+
8+
9+
class ScalableDataLoader:
    """
    Scalable data loader built on TensorFlow Datasets.

    Features:
    - Handles multiple TFDS datasets (CIFAR-10, CIFAR-100, ImageNet, etc.)
    - Automatic batching and prefetching
    - Optional data augmentation (training split only)
    - Optional caching for performance

    Note: the tf.data pipeline caches the *raw* examples before shuffling
    and augmentation, so random transforms are re-drawn every epoch instead
    of being frozen into the cache.
    """

    def __init__(
        self,
        dataset_name: str = "cifar10",
        batch_size: int = 32,
        image_size: Tuple[int, int] = (32, 32),
        cache: bool = True,
        prefetch: bool = True,
        augment: bool = False
    ):
        """
        Initialize data loader.

        Args:
            dataset_name: Name of TFDS dataset (passed to ``tfds.load``).
            batch_size: Number of examples per batch.
            image_size: Target image size as (height, width).
            cache: Whether to cache the decoded examples in memory.
            prefetch: Whether to prefetch batches (overlaps I/O and compute).
            augment: Whether to apply data augmentation to the train split.
        """
        self.dataset_name = dataset_name
        self.batch_size = batch_size
        self.image_size = image_size
        self.cache = cache
        self.prefetch = prefetch
        self.augment = augment

    def preprocess(self, image, label, training: bool = False):
        """
        Preprocess a single example.

        Args:
            image: Input image tensor (assumed HWC, uint8 in [0, 255] as
                yielded by TFDS image datasets — confirm for custom datasets).
            label: Label tensor (passed through unchanged).
            training: Whether in training mode (enables augmentation).

        Returns:
            Preprocessed (image, label) tuple; image is float32 in [0, 1].
        """
        # Resize only when the static shape differs from the target size;
        # avoids a no-op resize for datasets already at image_size.
        if image.shape[:2] != self.image_size:
            image = tf.image.resize(image, self.image_size)

        # Normalize to [0, 1].
        image = tf.cast(image, tf.float32) / 255.0

        # Light augmentation for training only; applied after caching so the
        # random draws differ every epoch.
        if training and self.augment:
            image = tf.image.random_flip_left_right(image)
            image = tf.image.random_brightness(image, 0.1)
            image = tf.image.random_contrast(image, 0.9, 1.1)

        return image, label

    def load_dataset(self, split: str):
        """
        Load a raw TFDS dataset split.

        Args:
            split: Dataset split ('train', 'test', etc.).

        Returns:
            TensorFlow dataset yielding (image, label) tuples
            (``as_supervised=True``).
        """
        # shuffle_files only matters for multi-file splits; enable it for
        # training so file order does not bias shuffling.
        ds = tfds.load(
            self.dataset_name,
            split=split,
            as_supervised=True,
            shuffle_files=(split == 'train')
        )

        return ds

    def create_dataset(self, split: str, repeat: bool = False):
        """
        Create a preprocessed and batched dataset.

        Pipeline order is load -> cache -> shuffle -> map -> repeat ->
        batch -> prefetch. Caching BEFORE shuffle/map is deliberate:
        caching after random augmentation would freeze the first epoch's
        random transforms and replay them every epoch.

        Args:
            split: Dataset split ('train' or 'test').
            repeat: Whether to repeat the dataset infinitely.

        Returns:
            Batched TensorFlow dataset yielding (images, labels).
        """
        ds = self.load_dataset(split)
        training = (split == 'train')

        # Cache raw (image, label) examples so later random ops still vary.
        if self.cache:
            ds = ds.cache()

        # Shuffle per epoch (reshuffle_each_iteration defaults to True).
        if training:
            ds = ds.shuffle(10000)

        # With as_supervised=True the dataset yields (image, label) tuples.
        ds = ds.map(
            lambda image, label: self.preprocess(image, label, training=training),
            num_parallel_calls=tf.data.AUTOTUNE
        )

        # Repeat before batching so every batch is full when repeating.
        if repeat:
            ds = ds.repeat()

        ds = ds.batch(self.batch_size)

        # Overlap preprocessing of the next batch with training on this one.
        if self.prefetch:
            ds = ds.prefetch(tf.data.AUTOTUNE)

        return ds

    def to_numpy_iterator(self, dataset):
        """
        Convert a TensorFlow dataset to a numpy iterator.

        Args:
            dataset: Batched TensorFlow dataset yielding (images, labels).

        Yields:
            Dict with 'image' and 'label' keys holding numpy arrays.
        """
        for images, labels in dataset:
            yield {
                'image': images.numpy(),
                'label': labels.numpy()
            }

    def get_train_loader(self):
        """Get a fresh training data iterator (single pass, dict batches)."""
        train_ds = self.create_dataset('train', repeat=False)
        return self.to_numpy_iterator(train_ds)

    def get_test_loader(self):
        """Get a fresh validation/test data iterator (single pass, dict batches)."""
        test_ds = self.create_dataset('test', repeat=False)
        return self.to_numpy_iterator(test_ds)

0 commit comments

Comments
 (0)