
JaxFM (JAX Flow Matching)

Training pipeline for a PixArt-$\alpha$-style Diffusion Transformer (DiT) in JAX/Equinox, using flow matching on precomputed VAE latents.

This repo is intentionally practical and “systems-first”: it uses Grain for data loading, Orbax for checkpointing (including dataset iterator state), W&B logging, and JAX sharding for multi-device training.

Example sample (Stage 1)

What’s in here (status)

  • Stage 1 (class-conditional): trains a DiT on VAE latents conditioned on ImageNet-style class labels. This stage has been run to completion in this repo (see checkpoints/300000/); the image above is from Stage 1 sampling.
  • Stage 2 (text-conditional): trains a DiT conditioned on text embeddings from a frozen T5Gemma encoder. The pipeline is implemented, but training is still in progress / not yet validated to the same level as Stage 1.
  • Flow matching objective: model predicts a vector field $v(x_t, t, c)$ and samples are generated by solving an ODE with Diffrax.
  • Attention improvements: Q/K RMSNorm + jax.nn.dot_product_attention in the core attention module.
  • EMA weights: maintains an exponential moving average of trainable parameters for (typically) better sampling stability.
  • Mixed precision pattern: keeps FP32 master params and casts to bfloat16 for the forward/grad path.
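The flow-matching objective above can be sketched in a few lines. This is a minimal NumPy illustration of the standard rectified-flow formulation (linear interpolation path, MSE velocity target, Euler integration), not code from this repo; the actual training uses JAX/Equinox and solves the sampling ODE with Diffrax.

```python
import numpy as np

def flow_matching_pair(x0, x1, t):
    """Linear interpolation path x_t and its regression target.

    x0: noise sample, x1: data (e.g. a VAE latent), t in [0, 1].
    The model v(x_t, t, c) is trained to predict the target x1 - x0.
    """
    x_t = (1.0 - t) * x0 + t * x1
    v_target = x1 - x0
    return x_t, v_target

def euler_sample(v_fn, x0, num_steps=50):
    """Generate a sample by integrating dx/dt = v(x, t) from t=0 to t=1
    with explicit Euler steps (the repo uses a Diffrax ODE solver instead)."""
    x, dt = x0, 1.0 / num_steps
    for i in range(num_steps):
        t = i * dt
        x = x + dt * v_fn(x, t)
    return x
```

A convenient sanity check: with the exact velocity field for a fixed pair, `v(x, t) = x1 - x0`, Euler integration from `x0` recovers `x1` exactly.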

Repository map

  • train_stage1.py: class-conditional training + sampling + checkpointing.
  • train_stage2.py: text-conditional training + gradient accumulation + optional init from a reparameterized model.
  • models/mmDiT/: DiT implementation (patchify, AdaLN-style conditioning, self-attn + cross-attn).
  • data/data.py: Grain ArrayRecordDataSource loaders + record parsing.
  • configs/: Hydra configs (data/model/train/optim/wandb).

Data format

Training uses Grain ArrayRecord files where each record is a pickled dict.

  • Stage 1 records contain:
    • latent: VAE latent tensor
    • label: integer class label
  • Stage 2 records contain:
    • latent: VAE latent tensor
    • short_caption: string
    • long_caption: string

See data/data.py for the exact parsing logic.
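As a rough illustration of the record layout (the exact parsing lives in data/data.py; the field names follow the lists above, and the latent shape here is made up for the example):

```python
import pickle

import numpy as np

def parse_stage1_record(raw: bytes):
    """Unpickle one Stage 1 ArrayRecord payload into (latent, label)."""
    record = pickle.loads(raw)
    return np.asarray(record["latent"]), int(record["label"])

# Round-trip example with a toy latent:
payload = pickle.dumps({"latent": np.zeros((4, 32, 32), np.float32), "label": 7})
latent, label = parse_stage1_record(payload)
```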

Quickstart

This project assumes you already have a JAX GPU environment working (CUDA/ROCm as appropriate).

Install Python dependencies:

pip install -r requirements.txt

JAX/JAXlib are intentionally not pinned in requirements.txt because the correct wheel depends on your platform (CPU vs CUDA vs ROCm). Install JAX following the official installation instructions for your setup.

These environment variables help with GPU memory behavior (the training scripts already set them internally, so exporting them manually is optional):

export XLA_PYTHON_CLIENT_PREALLOCATE=false
export TF_GPU_ALLOCATOR=cuda_malloc_async

Run Stage 1:

python train_stage1.py

Run Stage 2:

python train_stage2.py

Hydra overrides work as usual:

python train_stage2.py train.batch_size=256 train.gradient_accum=8 wandb.enabled=false

Configuration

The default entrypoint config is configs/config.yaml, which composes:

  • configs/train/{stage1,stage2}.yaml
  • configs/data/{stage1,stage2}.yaml
  • configs/model/{mmdit,ccdit}.yaml
  • configs/optim/adamw.yaml
  • configs/wandb/{stage1,stage2}.yaml

If your dataset paths differ, update them via config overrides (or by editing the YAMLs).
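For orientation, the composition above is the usual Hydra defaults-list pattern; a hypothetical sketch of what configs/config.yaml might look like (see the actual file for the real group choices):

```yaml
# Hypothetical sketch of a Hydra defaults list composing the groups listed
# above -- the real configs/config.yaml is authoritative.
defaults:
  - train: stage1
  - data: stage1
  - model: ccdit
  - optim: adamw
  - wandb: stage1
  - _self_
```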

Checkpointing (Orbax)

Checkpoints are stored under checkpoints/ by default and typically include:

  • state: (model, optimizer state)
  • model_ema: EMA model weights
  • dataset: Grain iterator checkpoint (so resuming doesn’t reshuffle/restart unexpectedly)

To resume, set train.is_restore=true.

VAE note (PyTorch / Diffusers)

The repo uses the Stable Diffusion 3 VAE (via diffusers) primarily for decoding latents for logging/samples. To avoid GPU memory issues, the training scripts instantiate the VAE on CPU (.to("cpu")).

Notes on current experimentation

Recent work (see development_log.md) has focused on convergence and stability improvements:

  • FP32 master params + BF16 compute
  • dataset latent mean/shift adjustments (train.latent_mean)
  • Q/K RMSNorm attention
  • EMA weighting

Additional ideas noted in the log may be explored next, but are not necessarily implemented yet.
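The EMA weighting mentioned above is the usual exponential-moving-average update applied to every trainable leaf after each optimizer step; a minimal NumPy sketch (the decay value is illustrative, not necessarily this repo's setting):

```python
import numpy as np

def ema_update(ema_params, params, decay=0.999):
    """One EMA step: ema <- decay * ema + (1 - decay) * params.

    Applied elementwise per parameter leaf after each optimizer step;
    sampling then uses the EMA copy for (typically) smoother results.
    """
    return [decay * e + (1.0 - decay) * p for e, p in zip(ema_params, params)]
```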

Citations / credit

If you use this repo or derive work from it, please cite the relevant papers in bib.bib.
