Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
6957708
Add Stage-2 RAE DiT model, pipeline, and tooling
plugyawn Mar 8, 2026
d5bf0d9
Add RAE DiT resume-order verifier
plugyawn Mar 8, 2026
f422041
Add RAE DiT training smoke test
plugyawn Mar 8, 2026
626a89d
Sync RAE DiT stack with diffusers quality checks
plugyawn Mar 9, 2026
9c8f052
Add RAE DiT API docs
plugyawn Mar 9, 2026
8314f0e
Rename RAEDiTTransformer2DModel to RAEDiT2DModel
plugyawn Mar 9, 2026
aa47b26
Fix RAE DiT review regressions
plugyawn Mar 9, 2026
a2506ff
Remove RAE DiT validation helper scripts from PR
plugyawn Mar 9, 2026
bc1b237
Add RAE DiT training validation sampling
plugyawn Mar 9, 2026
054ae73
Align RAE DiT with diffusers patterns
plugyawn Mar 9, 2026
c871658
Localize RAE loading and drop unused guidance transformer
plugyawn Mar 11, 2026
fa2d665
Localize RAE nested init loading
plugyawn Mar 30, 2026
34d6351
Fix RAEDiT projector init and dtype handling
plugyawn Mar 30, 2026
afc2db7
Harden RAE DiT training schedule helpers
plugyawn Mar 30, 2026
b9d46ca
Harden RAE DiT conversion and pipeline helpers
plugyawn Mar 30, 2026
9354b55
Merge branch 'main' into rae-dit-training
kashif Mar 30, 2026
b16c31f
Merge branch 'main' into rae-dit-training
kashif Apr 19, 2026
c9770f6
Fix RAEDiT scheduler and conditioning edge cases
plugyawn Apr 20, 2026
a98c667
Add RAEDiT attention fusion support
plugyawn May 1, 2026
b6aad13
Merge branch 'main' into rae-dit-training
plugyawn May 2, 2026
49465e6
Fix RAEDiT pipeline export pos embed preservation
plugyawn May 18, 2026
337737a
Merge branch 'main' into rae-dit-training
plugyawn May 18, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs/source/en/_toctree.yml
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,8 @@
title: PriorTransformer
- local: api/models/qwenimage_transformer2d
title: QwenImageTransformer2DModel
- local: api/models/rae_dit_transformer2d
title: RAEDiT2DModel
- local: api/models/sana_transformer2d
title: SanaTransformer2DModel
- local: api/models/sana_video_transformer3d
Expand Down Expand Up @@ -606,6 +608,8 @@
title: PRX
- local: api/pipelines/qwenimage
title: QwenImage
- local: api/pipelines/rae_dit
title: RAE DiT
- local: api/pipelines/sana
title: Sana
- local: api/pipelines/sana_sprint
Expand Down
32 changes: 32 additions & 0 deletions docs/source/en/api/models/rae_dit_transformer2d.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
<!-- Copyright 2026 The NYU Vision-X and HuggingFace Teams. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# RAEDiT2DModel

The `RAEDiT2DModel` is the Stage-2 latent diffusion transformer introduced in
[Diffusion Transformers with Representation Autoencoders](https://huggingface.co/papers/2510.11690).

Unlike DiT models that operate on VAE latents, this transformer denoises the latent space learned by
[`AutoencoderRAE`](./autoencoder_rae). It is designed to be used with [`FlowMatchEulerDiscreteScheduler`] and
decoded back to RGB with [`AutoencoderRAE`].

## Loading a pretrained transformer

```python
from diffusers import RAEDiT2DModel

transformer = RAEDiT2DModel.from_pretrained("path/to/converted-stage2-transformer")
```

## RAEDiT2DModel

[[autodoc]] RAEDiT2DModel
59 changes: 59 additions & 0 deletions docs/source/en/api/pipelines/rae_dit.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
<!-- Copyright 2026 The NYU Vision-X and HuggingFace Teams. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# RAE DiT

[Diffusion Transformers with Representation Autoencoders](https://huggingface.co/papers/2510.11690) introduces a
two-stage recipe: first train a representation autoencoder (RAE), then train a diffusion transformer on the resulting
latent space.

[`RAEDiTPipeline`] implements the Stage-2 class-conditional generator in Diffusers. It combines:

- [`RAEDiT2DModel`] for latent denoising
- [`FlowMatchEulerDiscreteScheduler`] for the denoising trajectory
- [`AutoencoderRAE`] for decoding latent samples to RGB images

> [!TIP]
> [`RAEDiTPipeline`] expects a Stage-2 checkpoint converted to Diffusers format together with a compatible
> [`AutoencoderRAE`] checkpoint.

## Loading a converted pipeline

```python
import torch
from diffusers import RAEDiTPipeline

pipe = RAEDiTPipeline.from_pretrained(
"path/to/converted-rae-dit-imagenet256",
torch_dtype=torch.bfloat16,
).to("cuda")

image = pipe(class_labels=[207], num_inference_steps=25).images[0]
image.save("golden_retriever.png")
```

If the converted pipeline includes an `id2label` mapping, you can also look up class ids by name:

```python
class_id = pipe.get_label_ids("golden retriever")[0]
image = pipe(class_labels=[class_id], num_inference_steps=25).images[0]
```

## RAEDiTPipeline

[[autodoc]] RAEDiTPipeline
- all
- __call__

## RAEDiTPipelineOutput

[[autodoc]] RAEDiTPipelineOutput
94 changes: 94 additions & 0 deletions examples/research_projects/rae_dit/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
# Training RAEDiT Stage 2

This folder contains the minimal Stage-2 follow-up for the RAE integration: training `RAEDiT2DModel` on top of a frozen `AutoencoderRAE`.

It is intentionally placed under `examples/research_projects/rae_dit/` rather than the top-level `examples/` trainers because this is still an experimental follow-up to the new RAE support.

## Current scope

This is a minimal full-finetuning scaffold, not a paper-complete training stack. It currently does the following:

- loads a frozen pretrained `AutoencoderRAE`
- encodes RGB images to normalized Stage-1 latents on the fly
- trains only the Stage-2 `RAEDiT2DModel`
- uses `FlowMatchEulerDiscreteScheduler` with the same shifted-sigma schedule shape already used elsewhere in `diffusers`
- consumes ImageFolder class ids as `class_labels`
- can generate validation samples through `RAEDiTPipeline` during training
- saves the trained transformer under `output_dir/transformer`
- saves the scheduler config under `output_dir/scheduler`
- writes `id2label.json` from the ImageFolder class mapping

It intentionally does not yet include:

- a latent-caching path
- autoguidance or the broader upstream transport stack
- exact upstream distributed training/runtime features

## Dataset format

The script expects an `ImageFolder`-compatible dataset:

```text
train_data_dir/
n01440764/
img_0001.jpeg
n01443537/
img_0002.jpeg
```

The folder names define the class labels used during Stage-2 training.

## Quickstart

```bash
accelerate launch examples/research_projects/rae_dit/train_rae_dit.py \
--pretrained_rae_model_name_or_path nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08 \
--train_data_dir /path/to/imagenet_like_folder \
--output_dir /tmp/rae-dit \
--resolution 256 \
--train_batch_size 8 \
--gradient_accumulation_steps 1 \
--gradient_checkpointing \
--learning_rate 1e-4 \
--lr_scheduler cosine \
--lr_warmup_steps 1000 \
--max_train_steps 200000 \
--mixed_precision bf16 \
--report_to wandb \
--allow_tf32
```

To emit validation samples during training, add:

```bash
--validation_steps 1000 \
--validation_class_label 207 \
--num_validation_images 4 \
--validation_num_inference_steps 25 \
--validation_guidance_scale 1.0
```

Validation images are written to `output_dir/validation/step-<global_step>/`.

If you already have a converted or partially trained Stage-2 checkpoint, resume from it with:

```bash
accelerate launch examples/research_projects/rae_dit/train_rae_dit.py \
--pretrained_rae_model_name_or_path nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08 \
--pretrained_transformer_model_name_or_path /path/to/previous \
--train_data_dir /path/to/imagenet_like_folder \
--output_dir /tmp/rae-dit-finetune \
--resolution 256 \
--train_batch_size 8 \
--max_train_steps 50000
```

The preferred input is the stage-2 root that contains sibling `transformer/` and `scheduler/` folders. A local
`.../transformer` path still works when there is a sibling `scheduler/` directory next to it.

## Notes

- The script derives a default flow shift from the latent dimensionality as `sqrt(latent_dim / time_shift_base)`, matching the upstream Stage-2 heuristic at a high level.
- The trainer assumes the selected `AutoencoderRAE` uses `reshape_to_2d=True`, because `RAEDiT2DModel` operates on 2D latent feature maps.
- Validation sampling uses a fresh scheduler cloned from the training config so sampling does not mutate the in-flight training scheduler state.
- This example is meant to land first as a training scaffold that matches the new Stage-2 model and export layout. A later follow-up can add cached latents and other training conveniences.
Loading
Loading