diff --git a/src/maxdiffusion/models/wan/transformers/transformer_wan_vace.py b/src/maxdiffusion/models/wan/transformers/transformer_wan_vace.py
index b8bfeb92b..f3548602b 100644
--- a/src/maxdiffusion/models/wan/transformers/transformer_wan_vace.py
+++ b/src/maxdiffusion/models/wan/transformers/transformer_wan_vace.py
@@ -13,6 +13,7 @@
 limitations under the License.
 """
 
+import contextlib
 import math
 from typing import Any, Dict, Optional, Tuple
 
@@ -21,7 +22,6 @@
 import jax
 from jax.ad_checkpoint import checkpoint_name
 import jax.numpy as jnp
-from jax.sharding import PartitionSpec
 
 from .... import common_types
 from ....configuration_utils import register_to_config
@@ -62,8 +62,12 @@ def __init__(
       precision: jax.lax.Precision | None = None,
       attention: str = "dot_product",
       dropout: float = 0.0,
+      mask_padding_tokens: bool = True,
+      enable_jax_named_scopes: bool = False,
       apply_input_projection: bool = False,
       apply_output_projection: bool = False,
+      use_base2_exp: bool = False,
+      use_experimental_scheduler: bool = False,
   ):
     """Sets up the model.
 
@@ -90,7 +94,7 @@ def __init__(
       apply_output_projection: Whether to apply an output projection before outputting the result.
     """
-
+    self.enable_jax_named_scopes = enable_jax_named_scopes
     self.apply_input_projection = apply_input_projection
     self.apply_output_projection = apply_output_projection
 
@@ -124,7 +128,12 @@ def __init__(
         precision=precision,
         attention_kernel=attention,
         dropout=dropout,
+        is_self_attention=True,
+        mask_padding_tokens=mask_padding_tokens,
         residual_checkpoint_name="self_attn",
+        enable_jax_named_scopes=enable_jax_named_scopes,
+        use_base2_exp=use_base2_exp,
+        use_experimental_scheduler=use_experimental_scheduler,
     )
 
     # 3. Cross-attention
@@ -143,7 +152,12 @@ def __init__(
         precision=precision,
         attention_kernel=attention,
         dropout=dropout,
+        is_self_attention=False,
+        mask_padding_tokens=mask_padding_tokens,
         residual_checkpoint_name="cross_attn",
+        enable_jax_named_scopes=enable_jax_named_scopes,
+        use_base2_exp=use_base2_exp,
+        use_experimental_scheduler=use_experimental_scheduler,
     )
     assert cross_attn_norm is True, "cross_attn_norm must be True"
     self.norm2 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=True)
@@ -158,6 +172,7 @@ def __init__(
         weights_dtype=weights_dtype,
         precision=precision,
         dropout=dropout,
+        enable_jax_named_scopes=enable_jax_named_scopes,
     )
     self.norm3 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=False)
@@ -180,6 +195,10 @@ def __init__(
         jax.random.normal(key, (1, 6, dim)) / dim**0.5,
     )
 
+  def conditional_named_scope(self, name: str):
+    """Return a JAX named scope if enabled, otherwise a null context."""
+    return jax.named_scope(name) if self.enable_jax_named_scopes else contextlib.nullcontext()
+
   def __call__(
       self,
       *,
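The helper added above gates profiler annotations behind a flag, so the `with` blocks become no-ops when scoping is disabled. A minimal standalone sketch of the pattern (the class and names here are illustrative, not from this PR):

import contextlib

import jax


class ScopedModule:
  """Toy stand-in for the transformer block; only the scoping logic is real."""

  def __init__(self, enable_jax_named_scopes: bool = False):
    self.enable_jax_named_scopes = enable_jax_named_scopes

  def conditional_named_scope(self, name: str):
    # jax.named_scope labels the ops created under it in profiler traces;
    # contextlib.nullcontext makes the `with` block free when disabled.
    return jax.named_scope(name) if self.enable_jax_named_scopes else contextlib.nullcontext()


module = ScopedModule(enable_jax_named_scopes=True)
with module.conditional_named_scope("self_attn"):
  y = jax.numpy.ones((2, 2)) * 2.0  # ops created here are grouped under "self_attn"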
PartitionSpec("data", "fsdp", None), - ) - - # 1. Self-attention - with jax.named_scope("attn1"): - norm_hidden_states = (self.norm1(control_hidden_states.astype(jnp.float32)) * (1 + scale_msa) + shift_msa).astype( - control_hidden_states.dtype + with self.conditional_named_scope("vace_transformer_block"): + with self.conditional_named_scope("input_projection"): + if self.apply_input_projection: + control_hidden_states = self.proj_in(control_hidden_states) + control_hidden_states = control_hidden_states + hidden_states + + shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split( + (self.adaln_scale_shift_table + temb.astype(jnp.float32)), 6, axis=1 ) - attn_output = self.attn1( - hidden_states=norm_hidden_states, - encoder_hidden_states=norm_hidden_states, - rotary_emb=rotary_emb, - deterministic=deterministic, - rngs=rngs, - ) - control_hidden_states = (control_hidden_states.astype(jnp.float32) + attn_output * gate_msa).astype( - control_hidden_states.dtype - ) - - # 2. Cross-attention - with jax.named_scope("attn2"): - norm_hidden_states = self.norm2(control_hidden_states.astype(jnp.float32)).astype(control_hidden_states.dtype) - attn_output = self.attn2( - hidden_states=norm_hidden_states, - encoder_hidden_states=encoder_hidden_states, - deterministic=deterministic, - rngs=rngs, - ) - control_hidden_states = control_hidden_states + attn_output - # 3. Feed-forward - with jax.named_scope("ffn"): - norm_hidden_states = (self.norm3(control_hidden_states.astype(jnp.float32)) * (1 + c_scale_msa) + c_shift_msa).astype( - control_hidden_states.dtype - ) - ff_output = self.ffn(norm_hidden_states, deterministic=deterministic, rngs=rngs) - control_hidden_states = ( - control_hidden_states.astype(jnp.float32) + ff_output.astype(jnp.float32) * c_gate_msa - ).astype(control_hidden_states.dtype) - conditioning_states = None - if self.apply_output_projection: - conditioning_states = self.proj_out(control_hidden_states) + axis_names = nn.logical_to_mesh_axes(("activation_batch", "activation_length", "activation_heads")) + control_hidden_states = jax.lax.with_sharding_constraint(control_hidden_states, axis_names) + control_hidden_states = checkpoint_name(control_hidden_states, "control_hidden_states") + axis_names = nn.logical_to_mesh_axes(("activation_batch", "activation_length", "activation_kv")) + encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, axis_names) + + # 1. Self-attention + with self.conditional_named_scope("self_attn"): + with self.conditional_named_scope("self_attn_norm"): + norm_hidden_states = (self.norm1(control_hidden_states.astype(jnp.float32)) * (1 + scale_msa) + shift_msa).astype( + control_hidden_states.dtype + ) + with self.conditional_named_scope("self_attn_attn"): + attn_output = self.attn1( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_hidden_states, + rotary_emb=rotary_emb, + deterministic=deterministic, + rngs=rngs, + ) + with self.conditional_named_scope("self_attn_residual"): + control_hidden_states = (control_hidden_states.astype(jnp.float32) + attn_output * gate_msa).astype( + control_hidden_states.dtype + ) - return conditioning_states, control_hidden_states + # 2. 
@@ -289,7 +317,11 @@
       remat_policy: str = "None",
       names_which_can_be_saved: list[str] = [],
       names_which_can_be_offloaded: list[str] = [],
+      mask_padding_tokens: bool = True,
       scan_layers: bool = True,
+      enable_jax_named_scopes: bool = False,
+      use_base2_exp: bool = False,
+      use_experimental_scheduler: bool = False,
   ):
     """Initializes the VACE model.
 
@@ -302,6 +334,7 @@ def __init__(
     out_channels = out_channels or in_channels
     self.num_layers = num_layers
     self.scan_layers = scan_layers
+    self.enable_jax_named_scopes = enable_jax_named_scopes
 
     # 1. Patch & position embedding
     self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)
@@ -329,6 +362,7 @@ def __init__(
         text_embed_dim=text_dim,
         image_embed_dim=image_dim,
         pos_embed_seq_len=pos_embed_seq_len,
+        flash_min_seq_length=flash_min_seq_length,
     )
 
     self.gradient_checkpoint = GradientCheckpointType.from_str(remat_policy)
@@ -358,6 +392,10 @@ def __init__(
           precision=precision,
           attention=attention,
           dropout=dropout,
+          mask_padding_tokens=mask_padding_tokens,
+          enable_jax_named_scopes=enable_jax_named_scopes,
+          use_base2_exp=use_base2_exp,
+          use_experimental_scheduler=use_experimental_scheduler,
       )
       blocks.append(block)
     self.blocks = blocks
@@ -384,8 +422,12 @@ def __init__(
           precision=precision,
           attention=attention,
           dropout=dropout,
+          mask_padding_tokens=mask_padding_tokens,
+          enable_jax_named_scopes=enable_jax_named_scopes,
          apply_input_projection=vace_block_id == 0,
           apply_output_projection=True,
+          use_base2_exp=use_base2_exp,
+          use_experimental_scheduler=use_experimental_scheduler,
       )
       vace_blocks.append(vace_block)
     self.vace_blocks = vace_blocks
@@ -421,6 +463,10 @@ def __init__(
         kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), (None, None, "embed")),
     )
 
+  def conditional_named_scope(self, name: str):
+    """Return a JAX named scope if enabled, otherwise a null context."""
+    return jax.named_scope(name) if self.enable_jax_named_scopes else contextlib.nullcontext()
+
   @jax.named_scope("WanVACEModel")
   def __call__(
       self,
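The same three-line helper now exists on both WanVACETransformerBlock and WanVACEModel. A shared mixin (a sketch, not part of this diff) would express the pattern once:

import contextlib

import jax


class ConditionalNamedScopeMixin:
  """Hypothetical shared helper; subclasses set enable_jax_named_scopes in __init__."""

  enable_jax_named_scopes: bool = False

  def conditional_named_scope(self, name: str):
    return jax.named_scope(name) if self.enable_jax_named_scopes else contextlib.nullcontext()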
Patch & position embedding self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len) @@ -329,6 +362,7 @@ def __init__( text_embed_dim=text_dim, image_embed_dim=image_dim, pos_embed_seq_len=pos_embed_seq_len, + flash_min_seq_length=flash_min_seq_length, ) self.gradient_checkpoint = GradientCheckpointType.from_str(remat_policy) @@ -358,6 +392,10 @@ def __init__( precision=precision, attention=attention, dropout=dropout, + mask_padding_tokens=mask_padding_tokens, + enable_jax_named_scopes=enable_jax_named_scopes, + use_base2_exp=use_base2_exp, + use_experimental_scheduler=use_experimental_scheduler, ) blocks.append(block) self.blocks = blocks @@ -384,8 +422,12 @@ def __init__( precision=precision, attention=attention, dropout=dropout, + mask_padding_tokens=mask_padding_tokens, + enable_jax_named_scopes=enable_jax_named_scopes, apply_input_projection=vace_block_id == 0, apply_output_projection=True, + use_base2_exp=use_base2_exp, + use_experimental_scheduler=use_experimental_scheduler, ) vace_blocks.append(vace_block) self.vace_blocks = vace_blocks @@ -421,6 +463,10 @@ def __init__( kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), (None, None, "embed")), ) + def conditional_named_scope(self, name: str): + """Return a JAX named scope if enabled, otherwise a null context.""" + return jax.named_scope(name) if self.enable_jax_named_scopes else contextlib.nullcontext() + @jax.named_scope("WanVACEModel") def __call__( self, @@ -436,7 +482,7 @@ def __call__( rngs: nnx.Rngs = None, ) -> jax.Array: hidden_states = nn.with_logical_constraint(hidden_states, ("batch", None, None, None, None)) - batch_size, num_channels, num_frames, height, width = hidden_states.shape + batch_size, _, num_frames, height, width = hidden_states.shape p_t, p_h, p_w = self.config.patch_size post_patch_num_frames = num_frames // p_t post_patch_height = height // p_h @@ -453,13 +499,14 @@ def __call__( hidden_states = jnp.transpose(hidden_states, (0, 2, 3, 4, 1)) control_hidden_states = jnp.transpose(control_hidden_states, (0, 2, 3, 4, 1)) - rotary_emb = self.rope(hidden_states) - - hidden_states = self.patch_embedding(hidden_states) - hidden_states = jax.lax.collapse(hidden_states, 1, -1) - - control_hidden_states = self.vace_patch_embedding(control_hidden_states) - control_hidden_states = jax.lax.collapse(control_hidden_states, 1, -1) + with self.conditional_named_scope("rotary_embedding"): + rotary_emb = self.rope(hidden_states) + with self.conditional_named_scope("patch_embedding"): + hidden_states = self.patch_embedding(hidden_states) + hidden_states = jax.lax.collapse(hidden_states, 1, -1) + + control_hidden_states = self.vace_patch_embedding(control_hidden_states) + control_hidden_states = jax.lax.collapse(control_hidden_states, 1, -1) control_hidden_states_padding = jnp.zeros(( batch_size, control_hidden_states.shape[1], @@ -469,16 +516,17 @@ def __call__( control_hidden_states = jnp.concatenate([control_hidden_states, control_hidden_states_padding], axis=2) # Condition embedder is a FC layer. - ( - temb, - timestep_proj, - encoder_hidden_states, - encoder_hidden_states_image, - _, - ) = self.condition_embedder( # We will need to mask out the text embedding. 
diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py
index 031fe2fe0..50a82607b 100644
--- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py
+++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py
@@ -630,7 +630,6 @@ def _create_common_components(cls, config, vae_only=False, i2v=False):
       vae_devices_array = flat_devices.reshape(total_devices // vae_spatial, vae_spatial)
       vae_mesh = Mesh(vae_devices_array, ("redundant", "vae_spatial"))
-      vae_mesh.vae_spatial_axis_name = "vae_spatial"
       max_logging.log(
           f"Created VAE specific mesh with axes ('redundant', 'vae_spatial') to support spatial sharding of {vae_spatial}."
       )
 
diff --git a/src/maxdiffusion/pipelines/wan/wan_vace_pipeline_2_1.py b/src/maxdiffusion/pipelines/wan/wan_vace_pipeline_2_1.py
index aeaf7b617..ac721189c 100644
--- a/src/maxdiffusion/pipelines/wan/wan_vace_pipeline_2_1.py
+++ b/src/maxdiffusion/pipelines/wan/wan_vace_pipeline_2_1.py
@@ -24,7 +24,6 @@
 from flax.linen import partitioning as nn_partitioning
 from ...pyconfig import HyperParameters
 from ... import max_logging
-from ... import max_utils
 from ...image_processor import PipelineImageInput
 from ...max_utils import get_flash_block_sizes, get_precision, device_put_replicated
 from ...models.wan.wan_utils import load_wan_transformer
@@ -81,6 +80,11 @@ def create_model(rngs: nnx.Rngs, wan_config: dict):
   wan_config["names_which_can_be_offloaded"] = config.names_which_can_be_offloaded
   wan_config["flash_min_seq_length"] = config.flash_min_seq_length
   wan_config["dropout"] = config.dropout
+  wan_config["mask_padding_tokens"] = config.mask_padding_tokens
+  wan_config["enable_jax_named_scopes"] = config.enable_jax_named_scopes
+  wan_config["use_base2_exp"] = config.use_base2_exp
+  wan_config["use_experimental_scheduler"] = config.use_experimental_scheduler
+  wan_config["scan_layers"] = False
 
   # 2. eval_shape - will not use flops or create weights on device
@@ -342,51 +346,33 @@ def _load_and_init(
       restored_checkpoint=None,
       vae_only=False,
       load_transformer=True,
-      load_common_components=True,
   ):
-    devices_array = max_utils.create_device_mesh(config)
-    mesh = Mesh(devices_array, config.mesh_axes)
-    rng = jax.random.key(config.seed)
-    rngs = nnx.Rngs(rng)
+    common_components = cls._create_common_components(config, vae_only)
 
     transformer = None
-    tokenizer = None
-    scheduler = None
-    scheduler_state = None
-    text_encoder = None
-    wan_vae = None
-    vae_cache = None
 
     if not vae_only:
       if load_transformer:
-        with mesh:
-          transformer = cls.load_transformer(
-              devices_array=devices_array,
-              mesh=mesh,
-              rngs=rngs,
-              config=config,
-              restored_checkpoint=restored_checkpoint,
-              subfolder="transformer",
-          )
-      if load_common_components:
-        text_encoder = cls.load_text_encoder(config=config)
-        tokenizer = cls.load_tokenizer(config=config)
-
-      scheduler, scheduler_state = cls.load_scheduler(config=config)
-
-    if load_common_components:
-      with mesh:
-        wan_vae, vae_cache = cls.load_vae(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config)
+        transformer = cls.load_transformer(
+            devices_array=common_components["devices_array"],
+            mesh=common_components["mesh"],
+            rngs=common_components["rngs"],
+            config=config,
+            restored_checkpoint=restored_checkpoint,
+            subfolder="transformer",
+        )
 
     pipeline = cls(
-        tokenizer=tokenizer,
-        text_encoder=text_encoder,
+        tokenizer=common_components["tokenizer"],
+        text_encoder=common_components["text_encoder"],
         transformer=transformer,
-        vae=wan_vae,
-        vae_cache=vae_cache,
-        scheduler=scheduler,
-        scheduler_state=scheduler_state,
-        devices_array=devices_array,
-        mesh=mesh,
+        vae=common_components["vae"],
+        vae_cache=common_components["vae_cache"],
+        scheduler=common_components["scheduler"],
+        scheduler_state=common_components["scheduler_state"],
+        devices_array=common_components["devices_array"],
+        mesh=common_components["mesh"],
+        vae_mesh=common_components["vae_mesh"],
+        vae_logical_axis_rules=common_components["vae_logical_axis_rules"],
         config=config,
     )
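`_load_and_init` now leans entirely on the dict returned by `_create_common_components` (defined in wan_pipeline.py, first hunk above). A sketch of the contract it assumes, with the keys taken from the call sites in this diff and every value elided:

def _create_common_components_contract():
  """Hypothetical outline of the dict shape _load_and_init consumes; the real
  builder lives in wan_pipeline.py and fills in each entry."""
  return {
      "devices_array": ...,            # device mesh array built from config
      "mesh": ...,                     # jax.sharding.Mesh over config.mesh_axes
      "rngs": ...,                     # nnx.Rngs seeded from config.seed
      "tokenizer": ...,
      "text_encoder": ...,
      "vae": ...,
      "vae_cache": ...,
      "scheduler": ...,
      "scheduler_state": ...,
      "vae_mesh": ...,                 # the ('redundant', 'vae_spatial') mesh
      "vae_logical_axis_rules": ...,
  }

Centralizing this in one helper removes the `load_common_components` flag and the per-call-site mesh/rng setup that the old body duplicated.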
@@ -398,9 +384,8 @@ def from_pretrained(
       config: HyperParameters,
       vae_only=False,
       load_transformer=True,
-      load_common_components=True,
   ):
-    pipeline = cls._load_and_init(config, None, vae_only, load_transformer, load_common_components)
+    pipeline = cls._load_and_init(config, None, vae_only, load_transformer)
     pipeline.transformer = cls.quantize_transformer(config, pipeline.transformer, pipeline, pipeline.mesh)
     return pipeline
 
@@ -411,14 +396,12 @@ def from_checkpoint(
       config: HyperParameters,
       restored_checkpoint=None,
       vae_only=False,
       load_transformer=True,
-      load_common_components=True,
   ):
     pipeline = cls._load_and_init(
         config,
         restored_checkpoint,
         vae_only,
         load_transformer,
-        load_common_components,
     )
     pipeline.transformer = cls.quantize_transformer(config, pipeline.transformer, pipeline, pipeline.mesh)
     return pipeline
@@ -512,9 +495,9 @@ def __call__(
       guidance_scale: float = 5.0,
       num_videos_per_prompt: Optional[int] = 1,
       max_sequence_length: int = 512,
-      latents: jax.Array = None,
-      prompt_embeds: jax.Array = None,
-      negative_prompt_embeds: jax.Array = None,
+      latents: jax.Array | None = None,
+      prompt_embeds: jax.Array | None = None,
+      negative_prompt_embeds: jax.Array | None = None,
       vae_only: bool = False,
   ):
     """Runs the VACE model for the given inputs.
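With `load_common_components` removed, callers only choose whether to skip the transformer or run VAE-only. A hedged usage sketch, using a hypothetical `WanVacePipeline` handle for whatever class this module actually exports:

# Assumes `config` is an already-built HyperParameters instance.
pipeline = WanVacePipeline.from_pretrained(config)                        # full pipeline
vae_pipeline = WanVacePipeline.from_pretrained(config, vae_only=True)     # VAE path only
no_transformer = WanVacePipeline.from_pretrained(config, load_transformer=False)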