Training pipeline for a PixArt-$\alpha$-style Diffusion Transformer (DiT) in JAX/Equinox, using flow matching on precomputed VAE latents.
This repo is intentionally practical and “systems-first”: it uses Grain for data loading, Orbax for checkpointing (including dataset iterator state), W&B logging, and JAX sharding for multi-device training.
- Stage 1 (class-conditional): trains a DiT on VAE latents conditioned on ImageNet-style class labels. This stage has been run to completion in this repo (see checkpoints/300000/); the image above is from Stage 1 sampling.
- Stage 2 (text-conditional): trains a DiT conditioned on text embeddings from a frozen T5Gemma encoder. The pipeline is implemented, but training is still in progress and not yet validated to the same level as Stage 1.
- Flow matching objective: the model predicts a vector field $v(x_t, t, c)$ and samples are generated by solving an ODE with Diffrax.
- Attention improvements: Q/K RMSNorm + `jax.nn.dot_product_attention` in the core attention module.
- EMA weights: maintains an exponential moving average of trainable parameters for (typically) better sampling stability.
- Mixed precision pattern: keeps FP32 master params and casts to bfloat16 for the forward/grad path.
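The flow matching objective above can be sketched in a few lines of JAX. This assumes the linear-interpolation (rectified-flow) convention $x_t = (1-t)\,x_0 + t\,\epsilon$ with target velocity $\epsilon - x_0$; the repo's exact parameterization may differ, and sampling here uses fixed-step Euler rather than the Diffrax solver the repo actually uses:

```python
import jax
import jax.numpy as jnp

def flow_matching_loss(model, x0, cond, key):
    """Flow-matching loss under the linear-interpolation convention:
    x_t = (1 - t) * x0 + t * eps, with target velocity v* = eps - x0.
    A sketch; the repo's exact schedule may differ."""
    k_t, k_eps = jax.random.split(key)
    t = jax.random.uniform(k_t, (x0.shape[0],))      # one t per sample
    eps = jax.random.normal(k_eps, x0.shape)         # Gaussian noise endpoint
    t_b = t.reshape(-1, *([1] * (x0.ndim - 1)))      # broadcast t over dims
    x_t = (1.0 - t_b) * x0 + t_b * eps               # noisy interpolant
    target = eps - x0                                # constant velocity field
    pred = model(x_t, t, cond)                       # v(x_t, t, c)
    return jnp.mean((pred - target) ** 2)

def euler_sample(model, x1, cond, steps=50):
    """Integrate dx/dt = v(x, t, c) from t=1 (noise) to t=0 (data) with
    fixed-step Euler; the repo solves this ODE with Diffrax instead."""
    dt = 1.0 / steps
    x = x1
    for i in range(steps):
        t = jnp.full((x.shape[0],), 1.0 - i * dt)
        x = x - dt * model(x, t, cond)
    return x
```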
- `train_stage1.py`: class-conditional training + sampling + checkpointing.
- `train_stage2.py`: text-conditional training + gradient accumulation + optional init from a reparameterized model.
- `models/mmDiT/`: DiT implementation (patchify, AdaLN-style conditioning, self-attn + cross-attn).
- `data/data.py`: Grain `ArrayRecordDataSource` loaders + record parsing.
- `configs/`: Hydra configs (data/model/train/optim/wandb).
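The patchify step mentioned above turns a spatial latent into a token sequence for the transformer. A minimal sketch, assuming a `(C, H, W)` layout and patch size 2 (the repo's actual layout and patch size may differ):

```python
import jax.numpy as jnp

def patchify(latent, patch_size=2):
    """Split a (C, H, W) latent into a sequence of flattened patches of
    shape (H*W / p^2, p*p*C). A sketch of the DiT patchify step."""
    c, h, w = latent.shape
    p = patch_size
    assert h % p == 0 and w % p == 0, "spatial dims must be divisible by p"
    x = latent.reshape(c, h // p, p, w // p, p)
    x = x.transpose(1, 3, 2, 4, 0)                    # (H/p, W/p, p, p, C)
    return x.reshape((h // p) * (w // p), p * p * c)  # (tokens, token_dim)
```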
Training uses Grain ArrayRecord files where each record is a pickled dict.
- Stage 1 records contain:
  - `latent`: VAE latent tensor
  - `label`: integer class label
- Stage 2 records contain:
  - `latent`: VAE latent tensor
  - `short_caption`: string
  - `long_caption`: string
See data/data.py for the exact parsing logic.
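Since each record is a pickled dict, parsing a Stage 1 record amounts to an unpickle plus key lookups. A hedged sketch (key names follow the README; `data/data.py` is the source of truth):

```python
import pickle
import numpy as np

def parse_stage1_record(record: bytes):
    """Decode one ArrayRecord payload into (latent, label). Records are
    pickled dicts; Stage 2 records would instead carry 'short_caption'
    and 'long_caption' string fields alongside 'latent'."""
    d = pickle.loads(record)
    return np.asarray(d["latent"]), int(d["label"])
```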
This project assumes you already have a JAX GPU environment working (CUDA/ROCm as appropriate).
Install Python dependencies:
```
pip install -r requirements.txt
```

JAX/jaxlib are intentionally not pinned in requirements.txt because the correct wheel depends on your platform (CPU vs CUDA vs ROCm). Install JAX following the official installation instructions for your setup.
Recommended environment variables (already set inside the training scripts):

```
export XLA_PYTHON_CLIENT_PREALLOCATE=false
export TF_GPU_ALLOCATOR=cuda_malloc_async
```

Run Stage 1:

```
python train_stage1.py
```

Run Stage 2:

```
python train_stage2.py
```

Hydra overrides work as usual:

```
python train_stage2.py train.batch_size=256 train.gradient_accum=8 wandb.enabled=false
```

The default entrypoint config is configs/config.yaml, which composes:

- configs/train/{stage1,stage2}.yaml
- configs/data/{stage1,stage2}.yaml
- configs/model/{mmdit,ccdit}.yaml
- configs/optim/adamw.yaml
- configs/wandb/{stage1,stage2}.yaml
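Given those config groups, the Hydra defaults list in configs/config.yaml plausibly looks like the following. This is a hypothetical sketch for orientation only; the actual file is the source of truth:

```yaml
# Hypothetical shape of configs/config.yaml (check the repo's actual file)
defaults:
  - train: stage1
  - data: stage1
  - model: ccdit
  - optim: adamw
  - wandb: stage1
```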
If your dataset paths differ, update them via config overrides (or by editing the YAMLs).
Checkpoints are stored under checkpoints/ by default and typically include:
- `state`: (model, optimizer state)
- `model_ema`: EMA model weights
- `dataset`: Grain iterator checkpoint (so resuming doesn’t reshuffle/restart unexpectedly)
To resume, set `train.is_restore=true` (e.g. `python train_stage1.py train.is_restore=true`).
The repo uses the Stable Diffusion 3 VAE (via diffusers) primarily for decoding latents for logging/samples.
To avoid GPU memory issues, the training scripts instantiate the VAE on CPU (.to("cpu")).
Recent work (see development_log.md) has focused on convergence and stability improvements:
- FP32 master params + BF16 compute
- dataset latent mean/shift adjustments (`train.latent_mean`)
- Q/K RMSNorm attention
- EMA weighting
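The EMA and FP32-master/BF16-compute patterns above can be sketched over a pytree of parameters. The decay value here is an assumption, not necessarily what the repo uses:

```python
import jax
import jax.numpy as jnp

def ema_update(ema_params, params, decay=0.999):
    """One EMA step over a pytree of trainable parameters:
    ema <- decay * ema + (1 - decay) * params.
    The decay value is an assumption."""
    return jax.tree_util.tree_map(
        lambda e, p: decay * e + (1.0 - decay) * p, ema_params, params
    )

def cast_for_compute(params):
    """FP32-master / BF16-compute pattern: keep stored params in float32
    and cast a copy to bfloat16 for the forward/grad path."""
    return jax.tree_util.tree_map(
        lambda p: p.astype(jnp.bfloat16)
        if jnp.issubdtype(p.dtype, jnp.floating) else p,
        params,
    )
```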
Some other ideas may be explored next, but aren’t necessarily implemented yet.
If you use this repo or derive work from it, please cite the relevant papers in bib.bib.