200 changes: 124 additions & 76 deletions src/maxdiffusion/models/wan/transformers/transformer_wan_vace.py
@@ -13,6 +13,7 @@
limitations under the License.
"""

import contextlib
import math
from typing import Any, Dict, Optional, Tuple

@@ -21,7 +22,6 @@
import jax
from jax.ad_checkpoint import checkpoint_name
import jax.numpy as jnp
from jax.sharding import PartitionSpec

from .... import common_types
from ....configuration_utils import register_to_config
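The checkpoint_name import above is applied to control_hidden_states later in this file: it tags an intermediate so a rematerialization policy can choose to save that tensor by name instead of recomputing it on the backward pass. A minimal standalone sketch of that mechanism (block and its toy arrays are illustrative, not from this file):

import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

def block(x):
    h = jnp.sin(x) * 2.0
    # Tag this intermediate; a named remat policy can then choose to save it.
    h = checkpoint_name(h, "control_hidden_states")
    return jnp.cos(h).sum()

# Save only tensors tagged "control_hidden_states"; recompute everything else.
policy = jax.checkpoint_policies.save_only_these_names("control_hidden_states")
grad_fn = jax.grad(jax.checkpoint(block, policy=policy))
print(grad_fn(jnp.ones(4)))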
@@ -62,8 +62,12 @@ def __init__(
precision: jax.lax.Precision | None = None,
attention: str = "dot_product",
dropout: float = 0.0,
mask_padding_tokens: bool = True,
enable_jax_named_scopes: bool = False,
apply_input_projection: bool = False,
apply_output_projection: bool = False,
use_base2_exp: bool = False,
use_experimental_scheduler: bool = False,
):
"""Sets up the model.

@@ -90,7 +94,7 @@ def __init__(
apply_output_projection: Whether to apply an output projection before
returning the result.
"""

self.enable_jax_named_scopes = enable_jax_named_scopes
self.apply_input_projection = apply_input_projection
self.apply_output_projection = apply_output_projection

@@ -124,7 +128,12 @@ def __init__(
precision=precision,
attention_kernel=attention,
dropout=dropout,
is_self_attention=True,
mask_padding_tokens=mask_padding_tokens,
residual_checkpoint_name="self_attn",
enable_jax_named_scopes=enable_jax_named_scopes,
use_base2_exp=use_base2_exp,
use_experimental_scheduler=use_experimental_scheduler,
)

# 3. Cross-attention
@@ -143,7 +152,12 @@ def __init__(
precision=precision,
attention_kernel=attention,
dropout=dropout,
is_self_attention=False,
mask_padding_tokens=mask_padding_tokens,
residual_checkpoint_name="cross_attn",
enable_jax_named_scopes=enable_jax_named_scopes,
use_base2_exp=use_base2_exp,
use_experimental_scheduler=use_experimental_scheduler,
)
assert cross_attn_norm is True, "cross_attn_norm must be True"
self.norm2 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=True)
@@ -158,6 +172,7 @@ def __init__(
weights_dtype=weights_dtype,
precision=precision,
dropout=dropout,
enable_jax_named_scopes=enable_jax_named_scopes,
)

self.norm3 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=False)
@@ -180,6 +195,10 @@ def __init__(
jax.random.normal(key, (1, 6, dim)) / dim**0.5,
)

def conditional_named_scope(self, name: str):
"""Return a JAX named scope if enabled, otherwise a null context."""
return jax.named_scope(name) if self.enable_jax_named_scopes else contextlib.nullcontext()
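This helper makes profiler annotations zero-cost when disabled: jax.named_scope(name) labels every operation traced inside the with block, while contextlib.nullcontext() turns the same with statement into a no-op. A standalone sketch of the pattern on a toy function (not from this file):

import contextlib
import jax
import jax.numpy as jnp

def conditional_named_scope(name, enabled):
    # Real scope when profiling annotations are wanted, no-op context otherwise.
    return jax.named_scope(name) if enabled else contextlib.nullcontext()

def double(x, enabled=True):
    with conditional_named_scope("double", enabled):
        return x * 2.0

# The scope name appears in jaxprs, HLO metadata, and profiler traces.
print(jax.jit(double)(jnp.ones(3)))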

def __call__(
self,
*,
@@ -191,65 +210,74 @@ def __call__(
deterministic: bool = True,
rngs: nnx.Rngs | None = None,
) -> Tuple[jax.Array, jax.Array]:
if self.apply_input_projection:
control_hidden_states = self.proj_in(control_hidden_states)
control_hidden_states = control_hidden_states + hidden_states

shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split(
(self.adaln_scale_shift_table + temb.astype(jnp.float32)), 6, axis=1
)

control_hidden_states = jax.lax.with_sharding_constraint(
control_hidden_states,
PartitionSpec("data", "fsdp", "tensor"),
)
control_hidden_states = checkpoint_name(control_hidden_states, "control_hidden_states")
encoder_hidden_states = jax.lax.with_sharding_constraint(
encoder_hidden_states,
PartitionSpec("data", "fsdp", None),
)

# 1. Self-attention
with jax.named_scope("attn1"):
norm_hidden_states = (self.norm1(control_hidden_states.astype(jnp.float32)) * (1 + scale_msa) + shift_msa).astype(
control_hidden_states.dtype
with self.conditional_named_scope("vace_transformer_block"):
with self.conditional_named_scope("input_projection"):
if self.apply_input_projection:
control_hidden_states = self.proj_in(control_hidden_states)
control_hidden_states = control_hidden_states + hidden_states

shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split(
(self.adaln_scale_shift_table + temb.astype(jnp.float32)), 6, axis=1
)
attn_output = self.attn1(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_hidden_states,
rotary_emb=rotary_emb,
deterministic=deterministic,
rngs=rngs,
)
control_hidden_states = (control_hidden_states.astype(jnp.float32) + attn_output * gate_msa).astype(
control_hidden_states.dtype
)

# 2. Cross-attention
with jax.named_scope("attn2"):
norm_hidden_states = self.norm2(control_hidden_states.astype(jnp.float32)).astype(control_hidden_states.dtype)
attn_output = self.attn2(
hidden_states=norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
deterministic=deterministic,
rngs=rngs,
)
control_hidden_states = control_hidden_states + attn_output

# 3. Feed-forward
with jax.named_scope("ffn"):
norm_hidden_states = (self.norm3(control_hidden_states.astype(jnp.float32)) * (1 + c_scale_msa) + c_shift_msa).astype(
control_hidden_states.dtype
)
ff_output = self.ffn(norm_hidden_states, deterministic=deterministic, rngs=rngs)
control_hidden_states = (
control_hidden_states.astype(jnp.float32) + ff_output.astype(jnp.float32) * c_gate_msa
).astype(control_hidden_states.dtype)
conditioning_states = None
if self.apply_output_projection:
conditioning_states = self.proj_out(control_hidden_states)
axis_names = nn.logical_to_mesh_axes(("activation_batch", "activation_length", "activation_heads"))
control_hidden_states = jax.lax.with_sharding_constraint(control_hidden_states, axis_names)
control_hidden_states = checkpoint_name(control_hidden_states, "control_hidden_states")
axis_names = nn.logical_to_mesh_axes(("activation_batch", "activation_length", "activation_kv"))
encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, axis_names)

# 1. Self-attention
with self.conditional_named_scope("self_attn"):
with self.conditional_named_scope("self_attn_norm"):
norm_hidden_states = (self.norm1(control_hidden_states.astype(jnp.float32)) * (1 + scale_msa) + shift_msa).astype(
control_hidden_states.dtype
)
with self.conditional_named_scope("self_attn_attn"):
attn_output = self.attn1(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_hidden_states,
rotary_emb=rotary_emb,
deterministic=deterministic,
rngs=rngs,
)
with self.conditional_named_scope("self_attn_residual"):
control_hidden_states = (control_hidden_states.astype(jnp.float32) + attn_output * gate_msa).astype(
control_hidden_states.dtype
)

return conditioning_states, control_hidden_states
# 2. Cross-attention
with self.conditional_named_scope("cross_attn"):
with self.conditional_named_scope("cross_attn_norm"):
norm_hidden_states = self.norm2(control_hidden_states.astype(jnp.float32)).astype(control_hidden_states.dtype)
with self.conditional_named_scope("cross_attn_attn"):
attn_output = self.attn2(
hidden_states=norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
deterministic=deterministic,
rngs=rngs,
)
with self.conditional_named_scope("cross_attn_residual"):
control_hidden_states = control_hidden_states + attn_output

# 3. Feed-forward
with self.conditional_named_scope("mlp"):
with self.conditional_named_scope("mlp_norm"):
norm_hidden_states = (
self.norm3(control_hidden_states.astype(jnp.float32)) * (1 + c_scale_msa) + c_shift_msa
).astype(control_hidden_states.dtype)
with self.conditional_named_scope("mlp_ffn"):
ff_output = self.ffn(norm_hidden_states, deterministic=deterministic, rngs=rngs)
with self.conditional_named_scope("mlp_residual"):
control_hidden_states = (
control_hidden_states.astype(jnp.float32) + ff_output.astype(jnp.float32) * c_gate_msa
).astype(control_hidden_states.dtype)

with self.conditional_named_scope("output_projection"):
conditioning_states = None
if self.apply_output_projection:
conditioning_states = self.proj_out(control_hidden_states)

return conditioning_states, control_hidden_states
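The adaptive layer-norm arithmetic in this block is compact; a worked sketch of the split-and-modulate step, with toy shapes chosen purely for illustration:

import jax
import jax.numpy as jnp

batch, seq, dim = 2, 16, 8
table = jax.random.normal(jax.random.PRNGKey(0), (1, 6, dim)) / dim**0.5  # adaln_scale_shift_table
temb = jnp.zeros((batch, 6, dim))                                         # per-timestep projection

# Broadcast-add, then split axis 1 into six (batch, 1, dim) modulation tensors:
# shift/scale/gate for self-attention, then shift/scale/gate for the FFN.
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split(table + temb, 6, axis=1)

x = jnp.ones((batch, seq, dim))            # block activations
normed = x * (1 + scale_msa) + shift_msa   # modulate the normalized activations
x = x + normed * gate_msa                  # gated residual, as in the self-attention branch
print(x.shape)                             # (2, 16, 8)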

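The block also switches from a hard-coded PartitionSpec to flax logical axis names resolved through nn.logical_to_mesh_axes. A standalone sketch of how logical names map to mesh axes (the rules tuple below is an illustrative assumption, not this project's actual configuration):

import flax.linen as nn

# Map logical activation axes to physical mesh axes; None leaves an axis unsharded.
rules = (
    ("activation_batch", "data"),
    ("activation_length", "fsdp"),
    ("activation_heads", "tensor"),
    ("activation_kv", None),
)
with nn.logical_axis_rules(rules):
    spec = nn.logical_to_mesh_axes(("activation_batch", "activation_length", "activation_heads"))
print(spec)  # PartitionSpec('data', 'fsdp', 'tensor')

The resulting PartitionSpec is what jax.lax.with_sharding_constraint receives inside the jitted model, so sharding decisions live in one table of rules rather than in scattered literal specs.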

class WanVACEModel(WanModel):
@@ -289,7 +317,11 @@ def __init__(
remat_policy: str = "None",
names_which_can_be_saved: list[str] = [],
names_which_can_be_offloaded: list[str] = [],
mask_padding_tokens: bool = True,
scan_layers: bool = True,
enable_jax_named_scopes: bool = False,
use_base2_exp: bool = False,
use_experimental_scheduler: bool = False,
):
"""Initializes the VACE model.

@@ -302,6 +334,7 @@ def __init__(
out_channels = out_channels or in_channels
self.num_layers = num_layers
self.scan_layers = scan_layers
self.enable_jax_named_scopes = enable_jax_named_scopes

# 1. Patch & position embedding
self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)
@@ -329,6 +362,7 @@ def __init__(
text_embed_dim=text_dim,
image_embed_dim=image_dim,
pos_embed_seq_len=pos_embed_seq_len,
flash_min_seq_length=flash_min_seq_length,
)

self.gradient_checkpoint = GradientCheckpointType.from_str(remat_policy)
@@ -358,6 +392,10 @@ def __init__(
precision=precision,
attention=attention,
dropout=dropout,
mask_padding_tokens=mask_padding_tokens,
enable_jax_named_scopes=enable_jax_named_scopes,
use_base2_exp=use_base2_exp,
use_experimental_scheduler=use_experimental_scheduler,
)
blocks.append(block)
self.blocks = blocks
@@ -384,8 +422,12 @@ def __init__(
precision=precision,
attention=attention,
dropout=dropout,
mask_padding_tokens=mask_padding_tokens,
enable_jax_named_scopes=enable_jax_named_scopes,
apply_input_projection=vace_block_id == 0,
apply_output_projection=True,
use_base2_exp=use_base2_exp,
use_experimental_scheduler=use_experimental_scheduler,
)
vace_blocks.append(vace_block)
self.vace_blocks = vace_blocks
@@ -421,6 +463,10 @@ def __init__(
kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), (None, None, "embed")),
)

def conditional_named_scope(self, name: str):
"""Return a JAX named scope if enabled, otherwise a null context."""
return jax.named_scope(name) if self.enable_jax_named_scopes else contextlib.nullcontext()

@jax.named_scope("WanVACEModel")
def __call__(
self,
@@ -436,7 +482,7 @@ def __call__(
rngs: nnx.Rngs | None = None,
) -> jax.Array:
hidden_states = nn.with_logical_constraint(hidden_states, ("batch", None, None, None, None))
batch_size, num_channels, num_frames, height, width = hidden_states.shape
batch_size, _, num_frames, height, width = hidden_states.shape
p_t, p_h, p_w = self.config.patch_size
post_patch_num_frames = num_frames // p_t
post_patch_height = height // p_h
@@ -453,13 +499,14 @@ def __call__(

hidden_states = jnp.transpose(hidden_states, (0, 2, 3, 4, 1))
control_hidden_states = jnp.transpose(control_hidden_states, (0, 2, 3, 4, 1))
rotary_emb = self.rope(hidden_states)

hidden_states = self.patch_embedding(hidden_states)
hidden_states = jax.lax.collapse(hidden_states, 1, -1)

control_hidden_states = self.vace_patch_embedding(control_hidden_states)
control_hidden_states = jax.lax.collapse(control_hidden_states, 1, -1)
with self.conditional_named_scope("rotary_embedding"):
rotary_emb = self.rope(hidden_states)
with self.conditional_named_scope("patch_embedding"):
hidden_states = self.patch_embedding(hidden_states)
hidden_states = jax.lax.collapse(hidden_states, 1, -1)

control_hidden_states = self.vace_patch_embedding(control_hidden_states)
control_hidden_states = jax.lax.collapse(control_hidden_states, 1, -1)
control_hidden_states_padding = jnp.zeros((
batch_size,
control_hidden_states.shape[1],
@@ -469,16 +516,17 @@ def __call__(
control_hidden_states = jnp.concatenate([control_hidden_states, control_hidden_states_padding], axis=2)
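Here patch embedding yields a (batch, frames, height, width, channels) grid, and jax.lax.collapse merges the three spatio-temporal axes into a single token axis. A toy sketch with made-up sizes:

import jax
import jax.numpy as jnp

x = jnp.ones((2, 5, 6, 8, 16))       # (batch, frames', height', width', channels) after patch embedding
tokens = jax.lax.collapse(x, 1, -1)  # merge dimensions 1..3 into one sequence axis
print(tokens.shape)                  # (2, 240, 16): 5 * 6 * 8 = 240 tokens per sample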

# The condition embedder is a fully connected (FC) layer.
(
temb,
timestep_proj,
encoder_hidden_states,
encoder_hidden_states_image,
_,
) = self.condition_embedder( # We will need to mask out the text embedding.
timestep, encoder_hidden_states, encoder_hidden_states_image
)
timestep_proj = timestep_proj.reshape(timestep_proj.shape[0], 6, -1)
with self.conditional_named_scope("condition_embedder"):
(
temb,
timestep_proj,
encoder_hidden_states,
encoder_hidden_states_image,
_,
) = self.condition_embedder( # We will need to mask out the text embedding.
timestep, encoder_hidden_states, encoder_hidden_states_image
)
timestep_proj = timestep_proj.reshape(timestep_proj.shape[0], 6, -1)

if encoder_hidden_states_image is not None:
raise NotImplementedError("img2vid is not yet implemented.")
1 change: 0 additions & 1 deletion src/maxdiffusion/pipelines/wan/wan_pipeline.py
@@ -630,7 +630,6 @@ def _create_common_components(cls, config, vae_only=False, i2v=False):
vae_devices_array = flat_devices.reshape(total_devices // vae_spatial, vae_spatial)

vae_mesh = Mesh(vae_devices_array, ("redundant", "vae_spatial"))
vae_mesh.vae_spatial_axis_name = "vae_spatial"
max_logging.log(
f"Created VAE specific mesh with axes ('redundant', 'vae_spatial') to support spatial sharding of {vae_spatial}."
)
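For context on the deleted line above: jax.sharding.Mesh already exposes its axis names and sizes, which is presumably why the ad-hoc attribute was dropped. A minimal sketch:

import numpy as np
import jax
from jax.sharding import Mesh

devices = np.array(jax.devices()).reshape(-1, 1)
mesh = Mesh(devices, ("redundant", "vae_spatial"))
print(mesh.axis_names)            # ('redundant', 'vae_spatial')
print(mesh.shape["vae_spatial"])  # size of the spatial-sharding axis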