Skip to content

fix: decoders and pipeline parity gaps of linen to nnx migrations #4288

Open
mesakhcienet wants to merge 1 commit into
AI-Hypercomputer:mainfrom
CIeNET-International:fix/nnx-linen-decoders-pipeline-parity-gaps
Open

fix: decoders and pipeline parity gaps of linen to nnx migrations #4288
mesakhcienet wants to merge 1 commit into
AI-Hypercomputer:mainfrom
CIeNET-International:fix/nnx-linen-decoders-pipeline-parity-gaps

Conversation

@mesakhcienet

@mesakhcienet mesakhcienet commented Jun 29, 2026

Copy link
Copy Markdown
Collaborator

Description

Closes the remaining behavioral gaps between the pure-NNX decoder/pipeline path and the Linen reference.
The NNX path now reproduces Linen for DeepSeek-V4, per-stage remat, and the pipeline's non-trainable / repeat-level-remat handlning.

What's included

DeepSeek-V4 NNX decoder port

  • Registered DEEPSEEK4 in NNXDecoder.get_decoder_layer(was missing → ValueError at construction) and added full decoder-level handling: norm dispatch (RMSNorm), scanned + non-scanned init, _apply_deepseek4_scanned_blocks (prefix first_num_hash_layers unroll + paired HCA/CSA scan), global layer_idx, and decoder_input_tokens threading — matching Linen _apply_deepseek4_scanned_blocks.

Per-stage pipeline remat parity (set_remat_policy_on_layers_per_stage)

  • The flag was a no-op in the NNX pipeline after the Linen→NNX migration. Restored per-stage remat (jax.checkpoint) + params-only host-offload in NNXSequentialPipelineStage / NNXScannedPipelineStage, wired from both stage builders, incl. num_layers_per_pipeline_stage == 1.
  • Fix: decoupled "apply remat" from the policy value. remat_policy='full' resolves to None (== full remat, as Linen nn.remat(policy=None)); the old if policy is not None gate silently dropped remat for the default 'full' policy. Now gated on the flag via an explicit apply_remat argument.

Pipeline Linen→NNX migration parity (pipeline.py)

  • non_trainable collection: the migration asserted the iteration-scan catch-all was RngState-only, crashing any pipelined model with a non-trainable variable (e.g. the DeepSeek-V4 hash-routing table). Non-circular now broadcasts non_trainable as a loop-invariant constant (4-way state split); circular carries it via
    carry_state.
  • circular repeat-level remat: made unconditional to match the Linen reference (whose flag-check was dead code, always rematting). Default flag path unchanged; only flag=False configs regain the dropped rematerialization.

Unit Tests

  • tests/unit/nnx_decoders_test.py (+ DeepSeek-V4 construct/forward/scan parity, pipeline-stage forward + remat transparency, per-stage-remat-applied guards, layer_map registration guard).
  • tests/unit/nnx_pipeline_test.py (new): first CPU coverage for NNXPipeline / NNXCircularPipeline — non-circular + circular forward, non_trainable partitioning, repeat-remat output transparency.

Tests

combination of set_remat_policy_on_layers_per_stage flag.

Llama2 7b

  • Linen Decoder + normal pipeline + remat=false : log / xprofile
  • Linen Decoder + normal pipeline + remat=true : log / xprofile
  • Linen Decoder + circular pipeline + remat=false : log / xprofile
  • Linen Decoder + circular pipeline + remat=true : log / xprofile
  • NNX Decoder + normal pipeline+ remat=false : log / xprofile
  • NNX Decoder + normal pipeline+ remat=true : log / xprofile
  • NNXDecoder + circular pipeline+ remat=false : log / xprofile
  • NNXDecoder + circular pipeline+ remat=true : log / xprofile

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@codecov

codecov Bot commented Jun 29, 2026

Copy link
Copy Markdown

Codecov Report

❌ Patch coverage is 82.50000% with 14 lines in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
src/maxtext/layers/nnx_decoders.py 89.06% 3 Missing and 4 partials ⚠️
src/maxtext/layers/decoders.py 0.00% 5 Missing ⚠️
src/maxtext/layers/pipeline.py 81.81% 2 Missing ⚠️

📢 Thoughts on this report? Let us know!

@mesakhcienet mesakhcienet force-pushed the fix/nnx-linen-decoders-pipeline-parity-gaps branch from 783a66a to 7064594 Compare June 29, 2026 04:44
@mesakhcienet mesakhcienet changed the title fix: update nnx decoders deepseek4 and pipeline implementation fix: parity gaps of linen to nnx for decoders and pipeline Jun 29, 2026
Comment on lines -598 to -630
def get_layer_to_pipeline(blocks, cfg):
if cfg.decoder_block == DecoderBlockType.DEEPSEEK:
return blocks[1] # return the sparse block
else:
return blocks[0]

cfg = self.config
base_stage = get_layer_to_pipeline(decoder_blocks, cfg)
if cfg.set_remat_policy_on_layers_per_stage:
policy = self.get_remat_policy()
base_stage = self.set_remat_policy([base_stage], policy)[0]
if cfg.num_layers_per_pipeline_stage == 1:
stage_module = base_stage(config=cfg, mesh=self.mesh, quant=self.quant, model_mode=self.model_mode)
elif cfg.scan_layers_per_stage:
stage_module = self.scan_decoder_layers(
cfg,
base_stage,
base_stage_cls,
cfg.num_layers_per_pipeline_stage,
"layers_per_stage",
cfg,
self.mesh,
in_axes_tuple=(nn.broadcast,) * 4,
model_mode=self.model_mode,
)
else:
stage_module = SequentialBlockDecoderLayers(
decoder_layer=base_stage,
num_decoder_layers=cfg.num_layers_per_pipeline_stage,
config=cfg,
mesh=self.mesh,
quant=self.quant,
model_mode=self.model_mode,
self.quant,
self.model_mode,
rngs=rngs,
remat_policy=per_stage_remat,
apply_remat=apply_per_stage_remat,
)
return stage_module

@mesakhcienet mesakhcienet Jun 29, 2026

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove get_layer_to_pipeline dead code (unused anymore)

@mesakhcienet mesakhcienet force-pushed the fix/nnx-linen-decoders-pipeline-parity-gaps branch from 262fa8e to bc467c2 Compare June 29, 2026 06:58
@mesakhcienet mesakhcienet marked this pull request as ready for review June 29, 2026 07:20
@mesakhcienet mesakhcienet changed the title fix: parity gaps of linen to nnx for decoders and pipeline fix: decoders and pipeline parity gaps of linen to nnx migrations Jun 29, 2026
@ecnal-cienet ecnal-cienet force-pushed the fix/nnx-linen-decoders-pipeline-parity-gaps branch from c8206f0 to b2663d2 Compare June 29, 2026 22:00
@mesakhcienet mesakhcienet force-pushed the fix/nnx-linen-decoders-pipeline-parity-gaps branch 3 times, most recently from 50c83c4 to a776654 Compare June 30, 2026 04:05
@mesakhcienet mesakhcienet force-pushed the fix/nnx-linen-decoders-pipeline-parity-gaps branch from 71b97ed to cf1a449 Compare June 30, 2026 04:08
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant