diff --git a/MaxText/layers/attentions.py b/MaxText/layers/attentions.py index 4f9bb8c8ef..233e927d6e 100644 --- a/MaxText/layers/attentions.py +++ b/MaxText/layers/attentions.py @@ -1479,8 +1479,8 @@ def attention_as_linen( max_target_length: int, mesh: Mesh, attention_kernel: str, - inputs_q: Array, - inputs_kv: Array, + inputs_q_shape: Tuple, + inputs_kv_shape: Tuple, dtype: DType = jnp.float32, weight_dtype: DType = jnp.float32, max_prefill_predict_length: int = -1, @@ -1540,8 +1540,8 @@ def attention_as_linen( max_target_length=max_target_length, mesh=mesh, attention_kernel=attention_kernel, - inputs_q=inputs_q, - inputs_kv=inputs_kv, + inputs_q_shape=inputs_q_shape, + inputs_kv_shape=inputs_kv_shape, dtype=dtype, weight_dtype=weight_dtype, max_prefill_predict_length=max_prefill_predict_length, @@ -1603,8 +1603,8 @@ class Attention(nnx.Module): max_target_length: Maximum sequence length. mesh: The device mesh. attention_kernel: The attention kernel to use (e.g., 'dot_product', 'flash'). - inputs_q: Dummy query inputs for initialization, required by NNX. - inputs_kv: Dummy key/value inputs for initialization, required by NNX. + inputs_q_shape: Query inputs shape for initialization, required by NNX. + inputs_kv_shape: Key/value inputs shape for initialization, required by NNX. dtype: The data type for computation. weight_dtype: The data type for weights. max_prefill_predict_length: Maximum length for prefill. @@ -1628,8 +1628,8 @@ def __init__( max_target_length: int, mesh: Mesh, attention_kernel: str, - inputs_q: Array, - inputs_kv: Array, + inputs_q_shape: Tuple, + inputs_kv_shape: Tuple, dtype: DType = jnp.float32, weight_dtype: DType = jnp.float32, max_prefill_predict_length: int = -1, @@ -1687,8 +1687,8 @@ def __init__( max_target_length: Maximum sequence length. mesh: The device mesh. attention_kernel: The attention kernel to use (e.g., 'dot_product', 'flash'). - inputs_q: Dummy query inputs for initialization, required by NNX. - inputs_kv: Dummy key/value inputs for initialization, required by NNX. + inputs_q_shape: Query inputs shape for initialization, required by NNX. + inputs_kv_shape: Key/value inputs shape for initialization, required by NNX. dtype: The data type for computation. weight_dtype: The data type for weights. max_prefill_predict_length: Maximum length for prefill. @@ -1767,7 +1767,7 @@ def __init__( # Module attribute names must match names previously passed to Linen for checkpointing self.KVCache_0 = ( - self.init_kv_caches(inputs_kv=inputs_kv) + self.init_kv_caches(inputs_kv_shape=inputs_kv_shape) if self.model_mode != MODEL_MODE_TRAIN and base_kv_cache else None ) @@ -1818,13 +1818,13 @@ def __init__( ) if self.config.fused_qkv: - self.qkv_proj = self.init_qkv_w(inputs=inputs_q, proj_name="qkv_proj") + self.qkv_proj = self.init_qkv_w(inputs_shape=inputs_q_shape) else: - self.query = self.init_query_w(inputs_q=inputs_q) - self.key = self.init_kv_w(inputs_kv=inputs_kv, proj_name="key") - self.value = self.init_kv_w(inputs_kv=inputs_kv, proj_name="value") + self.query = self.init_query_w(inputs_q_shape=inputs_q_shape) + self.key = self.init_kv_w(inputs_kv_shape=inputs_kv_shape) + self.value = self.init_kv_w(inputs_kv_shape=inputs_kv_shape) - self.out = self.init_out_w(output_dim=inputs_q.shape[-1], out=inputs_q) + self.out = self.init_out_w(output_dim=inputs_q_shape[-1]) is_llama4_decoder_block = self.config.decoder_block == DecoderBlockType.LLAMA4 if self.use_qk_norm and not is_llama4_decoder_block: @@ -1849,7 +1849,7 @@ def __init__( self.key_norm = None - def init_query_w(self, inputs_q: Array) -> nnx.Module: + def init_query_w(self, inputs_q_shape: Tuple) -> nnx.Module: """Query projection initialization.""" # NOTE: T5 does not explicitly rescale the attention logits by @@ -1865,7 +1865,7 @@ def query_init(*args): (None, None, None) if self.config.ici_context_autoregressive_parallelism > 1 else ("embed", "q_heads", "kv") ) return DenseGeneral( - in_features_shape=self.convert_dense_general_inputs_shape(inputs_q.shape), + in_features_shape=self.convert_dense_general_inputs_shape(inputs_q_shape), out_features_shape=(self.num_query_heads, self.head_dim), axis=-1, kernel_init=query_init, @@ -1883,12 +1883,11 @@ def query_projection(self, inputs_q: Array) -> Array: return self.query(inputs_q) - def init_kv_w(self, inputs_kv: Array, proj_name: str) -> nnx.Module: + def init_kv_w(self, inputs_kv_shape: Tuple) -> nnx.Module: """Initializes the key or value projection. Args: - inputs_kv: Dummy key/value inputs for initialization. - proj_name: The name of the projection ("key" or "value"). + inputs_kv_shape: Key/value inputs shape for initialization. Returns: A DenseGeneral module that performs the key or value projection. @@ -1906,7 +1905,7 @@ def init_kv_w(self, inputs_kv: Array, proj_name: str) -> nnx.Module: ) return DenseGeneral( - in_features_shape=self.convert_dense_general_inputs_shape(inputs_kv.shape), + in_features_shape=self.convert_dense_general_inputs_shape(inputs_kv_shape), out_features_shape=(self.num_kv_heads, self.head_dim), axis=-1, kernel_init=self.kernel_init, @@ -1941,9 +1940,9 @@ def kv_projection(self, inputs_kv: Array, proj_name: str) -> nnx.Module: else: raise ValueError(f"proj_name must be 'key' or 'value', but got {proj_name}") - def init_qkv_w(self, inputs: Array, proj_name: str) -> nnx.Module: + def init_qkv_w(self, inputs_shape: Tuple) -> nnx.Module: return DenseGeneral( - in_features_shape=self.convert_dense_general_inputs_shape(inputs.shape), + in_features_shape=self.convert_dense_general_inputs_shape(inputs_shape), out_features_shape=(3, self.num_query_heads, self.head_dim), axis=-1, kernel_init=self.kernel_init, @@ -1964,7 +1963,7 @@ def qkv_projection(self, inputs: Array, proj_name: str): query, key, value = qkv_proj[:, :, 0, ...], qkv_proj[:, :, 1, ...], qkv_proj[:, :, 2, ...] return query, key, value - def init_out_w(self, output_dim: int, out: Array) -> nnx.Module: + def init_out_w(self, output_dim: int) -> nnx.Module: """out projection""" out_kernel_axis = ( (None, None, None) if self.config.ici_context_autoregressive_parallelism > 1 else ("heads", "kv", "embed") @@ -2068,17 +2067,17 @@ def apply_rotary_embedding(self, inputs: Array, inputs_positions: Optional[Array """ return self.rotary_embedding(inputs, inputs_positions) - def init_kv_caches(self, inputs_kv: Array): + def init_kv_caches(self, inputs_kv_shape: Tuple): """Initializes KVCache. Args: - inputs_kv: Dummy key/value inputs for initialization. + inputs_kv_shape: Key/value inputs shape for initialization. Returns: A KVCache module instance. """ - batch_size, _, _ = inputs_kv.shape + batch_size, _, _ = inputs_kv_shape # During initialization, seq_len of inputs_kv is max_target_length, # which is not always correct for some functions in KVCache. # However, KVCache internal cache shapes are based on max_prefill_length @@ -2278,8 +2277,8 @@ def mla_as_linen( max_target_length: int, mesh: Mesh, attention_kernel: str, - inputs_q: Array, - inputs_kv: Array, + inputs_q_shape: Tuple, + inputs_kv_shape: Tuple, dtype: DType = jnp.float32, weight_dtype: DType = jnp.float32, max_prefill_predict_length: int = -1, @@ -2354,8 +2353,8 @@ def mla_as_linen( max_target_length=max_target_length, mesh=mesh, attention_kernel=attention_kernel, - inputs_q=inputs_q, - inputs_kv=inputs_kv, + inputs_q_shape=inputs_q_shape, + inputs_kv_shape=inputs_kv_shape, dtype=dtype, weight_dtype=weight_dtype, max_prefill_predict_length=max_prefill_predict_length, @@ -2421,8 +2420,8 @@ def __init__( max_target_length: int, mesh: Mesh, attention_kernel: str, - inputs_q: Array, - inputs_kv: Array, + inputs_q_shape: Tuple, + inputs_kv_shape: Tuple, dtype: DType = jnp.float32, weight_dtype: DType = jnp.float32, max_prefill_predict_length: int = -1, @@ -2508,8 +2507,8 @@ def __init__( max_target_length=max_target_length, mesh=mesh, attention_kernel=attention_kernel, - inputs_q=inputs_q, - inputs_kv=inputs_kv, + inputs_q_shape=inputs_q_shape, + inputs_kv_shape=inputs_kv_shape, dtype=dtype, weight_dtype=weight_dtype, max_prefill_predict_length=max_prefill_predict_length, @@ -2554,7 +2553,7 @@ def __init__( ) # Module attribute names must match names previously passed to Linen for checkpointing - self.MlaKVCache_0 = self.init_mla_kv_caches(inputs_kv) if model_mode != MODEL_MODE_TRAIN else None + self.MlaKVCache_0 = self.init_mla_kv_caches(inputs_kv_shape) if model_mode != MODEL_MODE_TRAIN else None # Assert required configuration parameters for MLA attention. assert ( @@ -2731,11 +2730,11 @@ def mla_get_key_value(self, low_rank_main, key_rope, model_mode): value = nn.with_logical_constraint(value, self.value_axis_names) return key, value - def init_mla_kv_caches(self, inputs_kv: Array): + def init_mla_kv_caches(self, inputs_kv_shape: Tuple): """Initializes MlaKVCache. Args: - inputs_kv: Dummy key/value inputs for initialization. + inputs_kv_shape: Key/value inputs shape for initialization. Returns: An MlaKVCache module instance. @@ -2744,7 +2743,7 @@ def init_mla_kv_caches(self, inputs_kv: Array): ValueError: If the configuration is invalid. """ - batch_size, _, _ = inputs_kv.shape + batch_size, _, _ = inputs_kv_shape # During initialization, seq_len of inputs_kv is max_target_length, # which is not always correct for some functions in MlaKVCache. # However, MlaKVCache internal cache shapes are based on max_prefill_length diff --git a/MaxText/layers/decoders.py b/MaxText/layers/decoders.py index 92f2828a4a..e203fcd4dd 100644 --- a/MaxText/layers/decoders.py +++ b/MaxText/layers/decoders.py @@ -116,8 +116,8 @@ def __call__( max_target_length=cfg.max_target_length, max_prefill_predict_length=cfg.max_prefill_predict_length, attention_kernel=cfg.attention, - inputs_q=lnx, - inputs_kv=lnx, + inputs_q_shape=lnx.shape, + inputs_kv_shape=lnx.shape, mesh=mesh, dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, diff --git a/MaxText/layers/deepseek.py b/MaxText/layers/deepseek.py index f60f772127..f59e68146c 100644 --- a/MaxText/layers/deepseek.py +++ b/MaxText/layers/deepseek.py @@ -82,8 +82,8 @@ def self_attention_with_norm( max_target_length=cfg.max_target_length, max_prefill_predict_length=cfg.max_prefill_predict_length, attention_kernel=cfg.attention, - inputs_q=lnx, - inputs_kv=lnx, + inputs_q_shape=lnx.shape, + inputs_kv_shape=lnx.shape, mesh=mesh, dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, diff --git a/MaxText/layers/gemma.py b/MaxText/layers/gemma.py index 7f1ea718c0..4a11ab3b91 100644 --- a/MaxText/layers/gemma.py +++ b/MaxText/layers/gemma.py @@ -74,8 +74,8 @@ def __call__( max_target_length=cfg.max_target_length, max_prefill_predict_length=cfg.max_prefill_predict_length, attention_kernel=cfg.attention, - inputs_q=lnx, - inputs_kv=lnx, + inputs_q_shape=lnx.shape, + inputs_kv_shape=lnx.shape, mesh=mesh, dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, diff --git a/MaxText/layers/gemma2.py b/MaxText/layers/gemma2.py index 0a5595680c..764c873a20 100644 --- a/MaxText/layers/gemma2.py +++ b/MaxText/layers/gemma2.py @@ -78,8 +78,8 @@ def __call__( max_target_length=cfg.max_target_length, max_prefill_predict_length=cfg.max_prefill_predict_length, attention_kernel=cfg.attention, - inputs_q=lnx, - inputs_kv=lnx, + inputs_q_shape=lnx.shape, + inputs_kv_shape=lnx.shape, mesh=mesh, dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, @@ -176,8 +176,8 @@ def __call__( max_target_length=cfg.max_target_length, max_prefill_predict_length=cfg.max_prefill_predict_length, attention_kernel=cfg.attention, - inputs_q=lnx, - inputs_kv=lnx, + inputs_q_shape=lnx.shape, + inputs_kv_shape=lnx.shape, mesh=mesh, dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, diff --git a/MaxText/layers/gemma3.py b/MaxText/layers/gemma3.py index 43c81644ab..dfc53f4f5b 100644 --- a/MaxText/layers/gemma3.py +++ b/MaxText/layers/gemma3.py @@ -103,8 +103,8 @@ def __call__( max_target_length=cfg.max_target_length, max_prefill_predict_length=cfg.max_prefill_predict_length, attention_kernel=cfg.attention, - inputs_q=lnx, - inputs_kv=lnx, + inputs_q_shape=lnx.shape, + inputs_kv_shape=lnx.shape, mesh=mesh, dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, diff --git a/MaxText/layers/llama2.py b/MaxText/layers/llama2.py index 283038c2e0..79c01cce01 100644 --- a/MaxText/layers/llama2.py +++ b/MaxText/layers/llama2.py @@ -92,8 +92,8 @@ def __call__( max_target_length=cfg.max_target_length, max_prefill_predict_length=cfg.max_prefill_predict_length, attention_kernel=cfg.attention, - inputs_q=lnx, - inputs_kv=lnx, + inputs_q_shape=lnx.shape, + inputs_kv_shape=lnx.shape, mesh=mesh, dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, diff --git a/MaxText/layers/llama4.py b/MaxText/layers/llama4.py index 415c8859d6..256f3eb8ff 100644 --- a/MaxText/layers/llama4.py +++ b/MaxText/layers/llama4.py @@ -408,8 +408,8 @@ def __call__( max_target_length=cfg.max_target_length, max_prefill_predict_length=cfg.max_prefill_predict_length, attention_kernel=cfg.attention, - inputs_q=lnx, - inputs_kv=lnx, + inputs_q_shape=lnx.shape, + inputs_kv_shape=lnx.shape, mesh=mesh, dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, @@ -628,8 +628,8 @@ def __call__( head_dim=self.config.hidden_size_for_vit // self.config.num_attention_heads_for_vit, max_target_length=(self.config.image_size_for_vit // self.config.patch_size_for_vit) ** 2 + 1, attention_kernel="dot_product", - inputs_q=hidden_states, - inputs_kv=hidden_states, + inputs_q_shape=hidden_states.shape, + inputs_kv_shape=hidden_states.shape, float32_qk_product=self.config.float32_qk_product, float32_logits=self.config.float32_logits, mesh=self.mesh, diff --git a/MaxText/layers/mistral.py b/MaxText/layers/mistral.py index d85f276c0d..8a8276eccb 100644 --- a/MaxText/layers/mistral.py +++ b/MaxText/layers/mistral.py @@ -85,8 +85,8 @@ def __call__( max_target_length=cfg.max_target_length, max_prefill_predict_length=cfg.max_prefill_predict_length, attention_kernel=cfg.attention, - inputs_q=lnx, - inputs_kv=lnx, + inputs_q_shape=lnx.shape, + inputs_kv_shape=lnx.shape, mesh=mesh, dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, diff --git a/MaxText/layers/mixtral.py b/MaxText/layers/mixtral.py index 1d3c580603..d5dc65b23a 100644 --- a/MaxText/layers/mixtral.py +++ b/MaxText/layers/mixtral.py @@ -86,8 +86,8 @@ def __call__( max_target_length=cfg.max_target_length, max_prefill_predict_length=cfg.max_prefill_predict_length, attention_kernel=cfg.attention, - inputs_q=lnx, - inputs_kv=lnx, + inputs_q_shape=lnx.shape, + inputs_kv_shape=lnx.shape, mesh=mesh, dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, diff --git a/MaxText/layers/qwen3.py b/MaxText/layers/qwen3.py index 1e2de23faf..9bb6b502c1 100644 --- a/MaxText/layers/qwen3.py +++ b/MaxText/layers/qwen3.py @@ -82,8 +82,8 @@ def __call__( max_target_length=cfg.max_target_length, max_prefill_predict_length=cfg.max_prefill_predict_length, attention_kernel=cfg.attention, - inputs_q=lnx, - inputs_kv=lnx, + inputs_q_shape=lnx.shape, + inputs_kv_shape=lnx.shape, mesh=mesh, dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, diff --git a/MaxText/tests/attention_test.py b/MaxText/tests/attention_test.py index ca6f0fed6d..2e51c72b52 100644 --- a/MaxText/tests/attention_test.py +++ b/MaxText/tests/attention_test.py @@ -322,8 +322,8 @@ def setUp(self): head_dim=self.head_dim, max_target_length=self.max_target_length, max_prefill_predict_length=self.max_prefill_predict_length, - inputs_q=dummy_inputs_q, - inputs_kv=dummy_inputs_kv, + inputs_q_shape=dummy_inputs_q.shape, + inputs_kv_shape=dummy_inputs_kv.shape, mesh=self.mesh, attention_kernel="dot_product", dtype=self.dtype, @@ -439,8 +439,8 @@ def _test_model_mode_prefill_dtype(self, dtype): head_dim=self.head_dim, max_target_length=self.max_target_length, max_prefill_predict_length=self.cfg.max_prefill_predict_length, - inputs_q=dummy_inputs_q, - inputs_kv=dummy_inputs_kv, + inputs_q_shape=dummy_inputs_q.shape, + inputs_kv_shape=dummy_inputs_kv.shape, mesh=self.mesh, attention_kernel="dot_product", dtype=dtype, @@ -486,8 +486,8 @@ def tpu_kernel_attention_helper(self, num_kv_heads): head_dim=self.head_dim, max_target_length=self.max_target_length, max_prefill_predict_length=self.cfg.max_prefill_predict_length, - inputs_q=dummy_inputs_q, - inputs_kv=dummy_inputs_kv, + inputs_q_shape=dummy_inputs_q.shape, + inputs_kv_shape=dummy_inputs_kv.shape, mesh=self.mesh, attention_kernel="dot_product", dtype=self.dtype, @@ -515,8 +515,8 @@ def tpu_kernel_attention_helper(self, num_kv_heads): head_dim=self.head_dim, max_target_length=self.max_target_length, max_prefill_predict_length=self.cfg.max_prefill_predict_length, - inputs_q=dummy_inputs_q, - inputs_kv=dummy_inputs_kv, + inputs_q_shape=dummy_inputs_q.shape, + inputs_kv_shape=dummy_inputs_kv.shape, mesh=self.mesh, attention_kernel="flash", dtype=self.dtype, @@ -546,8 +546,8 @@ def tpu_kernel_attention_helper(self, num_kv_heads): head_dim=self.cfg_cp.head_dim, max_target_length=self.cfg_cp.max_target_length, max_prefill_predict_length=self.cfg_cp.max_prefill_predict_length, - inputs_q=dummy_inputs_q, - inputs_kv=dummy_inputs_kv, + inputs_q_shape=dummy_inputs_q.shape, + inputs_kv_shape=dummy_inputs_kv.shape, mesh=self.mesh_cp, attention_kernel="flash", dtype=self.dtype, @@ -629,8 +629,8 @@ def _dot_product_attention( num_query_heads=config.num_query_heads, num_kv_heads=config.num_kv_heads, head_dim=config.head_dim, - inputs_q=dummy_inputs_q, - inputs_kv=dummy_inputs_kv, + inputs_q_shape=dummy_inputs_q.shape, + inputs_kv_shape=dummy_inputs_kv.shape, max_target_length=config.max_target_length, max_prefill_predict_length=config.max_prefill_predict_length, attention_kernel=config.attention, @@ -722,8 +722,8 @@ def _dot_product_attention_reshape_q(self, compute_axis_order): head_dim=config.head_dim, max_target_length=config.max_target_length, max_prefill_predict_length=config.max_prefill_predict_length, - inputs_q=dummy_inputs_q, - inputs_kv=dummy_inputs_kv, + inputs_q_shape=dummy_inputs_q.shape, + inputs_kv_shape=dummy_inputs_kv.shape, attention_kernel=config.attention, dtype=config.dtype, compute_axis_order=compute_axis_order, @@ -740,8 +740,8 @@ def _dot_product_attention_reshape_q(self, compute_axis_order): head_dim=config.head_dim, max_target_length=config.max_target_length, max_prefill_predict_length=config.max_prefill_predict_length, - inputs_q=dummy_inputs_q, - inputs_kv=dummy_inputs_kv, + inputs_q_shape=dummy_inputs_q.shape, + inputs_kv_shape=dummy_inputs_kv.shape, attention_kernel=config.attention, dtype=config.dtype, compute_axis_order=compute_axis_order, @@ -863,8 +863,8 @@ def test_sliding_window_attention(self): max_target_length=self.max_target_length, max_prefill_predict_length=self.max_prefill_predict_length, mesh=self.mesh, - inputs_q=dummy_inputs_q, - inputs_kv=dummy_inputs_kv, + inputs_q_shape=dummy_inputs_q.shape, + inputs_kv_shape=dummy_inputs_kv.shape, attention_kernel="dot_product", dtype=self.dtype, dropout_rate=self.cfg.dropout_rate, @@ -882,8 +882,8 @@ def test_sliding_window_attention(self): max_target_length=self.max_target_length, max_prefill_predict_length=self.max_prefill_predict_length, mesh=self.mesh, - inputs_q=dummy_inputs_q, - inputs_kv=dummy_inputs_kv, + inputs_q_shape=dummy_inputs_q.shape, + inputs_kv_shape=dummy_inputs_kv.shape, attention_kernel="dot_product", dtype=self.dtype, dropout_rate=self.cfg.dropout_rate, @@ -932,8 +932,8 @@ def test_sliding_window_attention(self): max_target_length=self.max_target_length, max_prefill_predict_length=self.max_prefill_predict_length, mesh=self.mesh, - inputs_q=dummy_inputs_q, - inputs_kv=dummy_inputs_kv, + inputs_q_shape=dummy_inputs_q.shape, + inputs_kv_shape=dummy_inputs_kv.shape, attention_kernel="dot_product", dtype=self.dtype, dropout_rate=self.cfg.dropout_rate, @@ -1003,8 +1003,8 @@ def init_mla(self, rope_type): num_query_heads=num_query_heads, num_kv_heads=num_kv_heads, head_dim=192, - inputs_q=dummy_inputs_q, - inputs_kv=dummy_inputs_kv, + inputs_q_shape=dummy_inputs_q.shape, + inputs_kv_shape=dummy_inputs_kv.shape, max_target_length=max_target_length, max_prefill_predict_length=max_prefill_predict_length, mesh=mesh, diff --git a/MaxText/tests/check_llama4_layers.py b/MaxText/tests/check_llama4_layers.py index 020a1f0a7e..d3fc988f80 100644 --- a/MaxText/tests/check_llama4_layers.py +++ b/MaxText/tests/check_llama4_layers.py @@ -648,8 +648,8 @@ def test_vision_attention(self): head_dim=self.cfg.hidden_size_for_vit // self.cfg.num_attention_heads_for_vit, max_target_length=self.seq_len_for_vit, attention_kernel="dot_product", # TODO aireenmei: support flash attention - inputs_q=lnx, - inputs_kv=lnx, + inputs_q_shape=lnx.shape, + inputs_kv_shape=lnx.shape, mesh=self.mesh, dropout_rate=0, name="self_attention_vision",