Train a model with the baseline configuration:

```bash
python scripts/train.py --config configs/baseline.yaml
```

Configuration files are in YAML format. See `configs/` for examples:

- `baseline.yaml` - Standard training with label smoothing
- `margin_loss.yaml` - Training with margin loss for higher confidence
```yaml
model:
  n_classes: 10              # Number of output classes
  features: [64, 128, 256]   # Feature dimensions per layer
  dropout_rate: 0.3          # Dropout rate
  use_residual: true         # Use residual connections

training:
  batch_size: 128
  epochs: 30
  learning_rate: 0.001
  weight_decay: 0.0001

  # Loss configuration
  loss_type: "label_smoothing"  # Options: label_smoothing, margin, focal, combined
  label_smoothing: 0.1
  margin: 1.0
  margin_weight: 0.5

  # EMA
  ema_enabled: true
  ema_decay: 0.99

  # Data
  dataset_name: "cifar10"  # Options: cifar10, cifar100, imagenet2012
  image_size: [32, 32]

  # Paths
  checkpoint_dir: "./checkpoints"
  log_dir: "./logs"
```

```bash
python scripts/train.py --config configs/baseline.yaml
```

To train on CIFAR-100, create a new config with `n_classes: 100` and `dataset_name: "cifar100"`.
```yaml
model:
  n_classes: 1000
  features: [64, 128, 256, 512]

training:
  dataset_name: "imagenet2012"
  image_size: [224, 224]
  batch_size: 64  # Reduce for large images
```

To use a custom dataset, modify `src/robust_vision/data/loaders.py` to add support for your dataset format.
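As a starting point for such a loader, here is a minimal batch iterator over in-memory arrays. This is a sketch: the exact interface expected by `loaders.py` is an assumption, and `iter_custom_batches` is a hypothetical name, not part of the repository.

```python
import numpy as np

def iter_custom_batches(images, labels, batch_size=128, shuffle=True, seed=0):
    """Yield (images, labels) batches from in-memory NumPy arrays.

    Hypothetical helper: the real loaders in
    src/robust_vision/data/loaders.py may expect a different interface.
    """
    images = np.asarray(images, dtype=np.float32) / 255.0  # scale pixels to [0, 1]
    labels = np.asarray(labels, dtype=np.int32)
    order = np.arange(len(images))
    if shuffle:
        np.random.default_rng(seed).shuffle(order)
    for start in range(0, len(order), batch_size):
        idx = order[start:start + batch_size]
        yield images[idx], labels[idx]
```

Whatever the real interface turns out to be, the key points carry over: convert to the dtypes the model expects, normalize pixel values, and shuffle deterministically from a seed.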
Resume training from a saved checkpoint:

```bash
python scripts/train.py \
  --config configs/baseline.yaml \
  --checkpoint ./checkpoints/checkpoint_15
```

Name an experiment:

```bash
python scripts/train.py \
  --config configs/baseline.yaml \
  --experiment-name my_experiment
```

Set a random seed for reproducibility:

```bash
python scripts/train.py \
  --config configs/baseline.yaml \
  --seed 1234
```

The trainer automatically detects and uses all available GPUs via JAX's pmap. To restrict training to specific GPUs:

```bash
CUDA_VISIBLE_DEVICES=0,1 python scripts/train.py --config configs/baseline.yaml
```

Best for general use:
```yaml
loss_type: "label_smoothing"
label_smoothing: 0.1  # 0 = no smoothing, 0.2 = more smoothing
```

Encourages confident predictions:
```yaml
loss_type: "margin"
margin: 2.0  # Higher = more separation between classes
```

Good for imbalanced datasets:
```yaml
loss_type: "focal"
alpha: 0.25
gamma: 2.0
```

Best overall performance:
```yaml
loss_type: "combined"
label_smoothing: 0.1
margin: 2.0
margin_weight: 1.0
```

Adjust hyperparameters in your config file and retrain.
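For intuition, the four loss options above can be sketched in plain NumPy. This is illustrative only: the repository's actual JAX implementations may differ in details such as reduction, masking, and numerical stabilization.

```python
import numpy as np

def log_softmax(logits):
    z = logits - logits.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))

def label_smoothing_loss(logits, labels, smoothing=0.1):
    """Cross-entropy against targets softened toward the other classes."""
    n = logits.shape[-1]
    targets = np.full_like(logits, smoothing / (n - 1))
    targets[np.arange(len(labels)), labels] = 1.0 - smoothing
    return -(targets * log_softmax(logits)).sum(-1).mean()

def margin_loss(logits, labels, margin=2.0):
    """Hinge penalty unless the true logit beats the best other logit by `margin`."""
    true = logits[np.arange(len(labels)), labels]
    masked = logits.copy()
    masked[np.arange(len(labels)), labels] = -np.inf
    return np.maximum(0.0, margin - (true - masked.max(-1))).mean()

def focal_loss(logits, labels, alpha=0.25, gamma=2.0):
    """Down-weights easy examples, which helps with class imbalance."""
    logp = log_softmax(logits)[np.arange(len(labels)), labels]
    p = np.exp(logp)
    return (-alpha * (1.0 - p) ** gamma * logp).mean()

def combined_loss(logits, labels, smoothing=0.1, margin=2.0, margin_weight=1.0):
    """Weighted sum of smoothed cross-entropy and the margin hinge."""
    return (label_smoothing_loss(logits, labels, smoothing)
            + margin_weight * margin_loss(logits, labels, margin))
```

The `margin_weight` config key corresponds to the weight on the hinge term in the combined variant.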
Run a hyperparameter sweep:

```bash
python scripts/hyperparameter_sweep.py \
  --output ./sweep_results \
  --dataset cifar10 \
  --epochs 10
```

This will:

- Try different combinations of hyperparameters
- Train models for each configuration
- Save results to `sweep_results/`
- Report the best configuration
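Conceptually, the sweep enumerates a grid of settings and keeps the best result. A minimal sketch of that idea follows; the actual script's search space and result format are not shown here, and `grid_configs`/`best_config` are illustrative names.

```python
import itertools

def grid_configs(grid):
    """Expand a dict of candidate lists into one config dict per combination."""
    keys = list(grid)
    return [dict(zip(keys, vals))
            for vals in itertools.product(*(grid[k] for k in keys))]

def best_config(results):
    """Pick the (config, val_accuracy) pair with the highest accuracy."""
    return max(results, key=lambda r: r[1])

# 2 x 2 grid -> 4 candidate configurations
grid = {"learning_rate": [1e-3, 3e-4], "label_smoothing": [0.0, 0.1]}
configs = grid_configs(grid)
```

Grid size grows multiplicatively with each added hyperparameter, which is why the sweep uses a short `--epochs 10` budget per configuration.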
Logs are saved to the `log_dir` specified in the config:

```
logs/
├── experiment_name.log            # Training log
└── experiment_name_metrics.jsonl  # Metrics (JSONL format)
```
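Because the metrics file is JSONL (one JSON object per line), it is easy to load for ad-hoc analysis. A small reader sketch; the exact field names inside each record depend on what the trainer logs and are not documented here.

```python
import json

def read_metrics(path):
    """Return the list of metric records from a JSONL file."""
    with open(path) as f:
        return [json.loads(line) for line in f if line.strip()]
```

Each record is a plain dict, so records can be filtered or plotted by whatever keys the trainer emits (e.g. an epoch counter or loss value, if present).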
Checkpoints are saved periodically:

```
checkpoints/
├── checkpoint_5
├── checkpoint_10
├── best_checkpoint_18   # Best model by validation accuracy
└── final_checkpoint_30
```
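To pick a checkpoint programmatically, one reasonable policy is to prefer the `best_checkpoint_*` entry and otherwise fall back to the highest step number. A sketch against the naming scheme shown above (the helper name is hypothetical, not part of the repository):

```python
import re
from pathlib import Path

def pick_checkpoint(ckpt_dir):
    """Prefer best_checkpoint_*, else the highest-numbered checkpoint."""
    paths = list(Path(ckpt_dir).iterdir())
    best = [p for p in paths if p.name.startswith("best_checkpoint_")]
    if best:
        return best[0]
    numbered = []
    for p in paths:
        m = re.fullmatch(r"(?:final_)?checkpoint_(\d+)", p.name)
        if m:
            numbered.append((int(m.group(1)), p))
    return max(numbered)[1] if numbered else None
```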
```python
from robust_vision.evaluation.visualization import plot_training_history

plot_training_history(
    "logs/experiment_metrics.jsonl",
    output_path="training_curves.png",
)
```

Always start with `configs/baseline.yaml` and adjust from there.
Always enable EMA for better generalization:

```yaml
ema_enabled: true
ema_decay: 0.99
```

Use a label smoothing of 0.1 for most tasks:

```yaml
label_smoothing: 0.1
```

Start with a learning rate of 1e-3 and adjust based on convergence:

- Too high: loss oscillates or increases
- Too low: very slow convergence

Batch size trade-offs:

- Larger batch size: faster training, more stable gradients
- Smaller batch size: better generalization, less memory

Rule of thumb: as large as fits in memory.
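The EMA recommended above keeps a decayed running copy of the parameters and uses it at evaluation time. As a sketch of the update rule (the trainer's actual implementation lives in the repository and may differ):

```python
import numpy as np

def ema_update(ema_params, new_params, decay=0.99):
    """One EMA step: ema <- decay * ema + (1 - decay) * new."""
    return [decay * e + (1.0 - decay) * p
            for e, p in zip(ema_params, new_params)]
```

Over many steps the EMA tracks a smoothed trajectory of the weights, which typically generalizes slightly better than the raw final weights; with `ema_decay: 0.99` each step moves the copy 1% toward the current parameters.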
For overfitting, try:

```yaml
dropout_rate: 0.4     # Increase dropout
weight_decay: 0.001   # Increase weight decay
label_smoothing: 0.2  # Increase smoothing
```

For slow training:

- Check GPU utilization with `nvidia-smi`
- Increase batch size if possible
- Enable prefetching (already enabled by default)
- Use smaller model for debugging
For unstable training:

- Reduce learning rate
- Check for data issues (NaNs, extreme values)
- Use gradient clipping
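Gradient clipping rescales the whole gradient whenever its global norm exceeds a threshold, which caps the size of any single update. A NumPy sketch of the global-norm variant; in a JAX training loop one would typically reach for `optax.clip_by_global_norm` instead.

```python
import numpy as np

def clip_by_global_norm(grads, max_norm):
    """Scale all gradients so their joint L2 norm is at most max_norm."""
    global_norm = np.sqrt(sum(np.sum(g ** 2) for g in grads))
    scale = min(1.0, max_norm / (global_norm + 1e-6))
    return [g * scale for g in grads], global_norm
```

Note that all parameter gradients are scaled by the same factor, so the update direction is preserved; only its magnitude is capped.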
For poor accuracy:

- Visualize training curves
- Check if the model is learning (loss should decrease)
- Try a different learning rate
- Add more regularization if overfitting
Next steps:

- Run a robustness evaluation: see the Evaluation Guide
- Deploy your model: see DEPLOYMENT.md
- Try different architectures