[Enhance] Support internal metrics for gdn A_log and norm & vl model in general#1615
Conversation
@claude add PR description
@claude review
xtuner/v1/utils/internal_metrics.py (Outdated)

```diff
 def __init__(self, internal_metrics_cfg: InternalMetricsConfig, model: XTunerBaseModel):
     self.internal_metrics_cfg = internal_metrics_cfg
-    self.model = model
+    self.model = model.language_model if hasattr(model, "language_model") else model
```
Claude: This change silently replaces self.model for all VL models, which means every other metric method (calculate_module_weight_rms, register_attn_output_hook, pop_metrics' named_modules() loop, MoE checks, etc.) now operates on language_model instead of the full composite model.
This is a behavioral change for all existing VL model users, not just GDN metrics. Consider whether this is intentional. If the goal is only to iterate over language-model modules for GDN stats, it would be safer to scope the unwrapping narrowly:
```diff
-self.model = model.language_model if hasattr(model, "language_model") else model
+self.model = model
+self._language_model = model.language_model if hasattr(model, "language_model") else model
```
Then use self._language_model only in the GDN-specific iteration in pop_metrics.
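A minimal sketch of that scoped unwrapping, assuming a collector class shaped like the one in this file (`InternalMetricsCollector` and `iter_gdn_modules` are illustrative names, not the PR's actual API):

```python
# Sketch of the narrowly scoped unwrapping suggested above.
# The composite model is kept as `self.model` so existing hooks and
# weight-RMS logic are unchanged; only GDN iteration uses the unwrapped model.
import torch.nn as nn


class InternalMetricsCollector:
    def __init__(self, internal_metrics_cfg, model: nn.Module):
        self.internal_metrics_cfg = internal_metrics_cfg
        # Unchanged: all pre-existing metric methods still see the full model.
        self.model = model
        # VL models expose the text backbone as `language_model`;
        # plain LLMs fall through to the model itself.
        self._language_model = (
            model.language_model if hasattr(model, "language_model") else model
        )

    def iter_gdn_modules(self):
        # GDN-specific stats iterate only the language-model submodules.
        yield from self._language_model.named_modules()
```

This keeps the behavioral change local to GDN statistics instead of silently redirecting every existing metric to `language_model`.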
Claude:

## Review Summary
This PR adds internal metrics monitoring for GDN (Gated DeltaNet).

### Issues
**Critical**

**Warning**

**Nit**

### Verdict
REQUEST_CHANGES
Summary
This PR adds internal metrics monitoring support for Gated DeltaNet (GDN) architecture, specifically tracking min/max statistics for A_log parameters and FusedRMSNormGated modules. It also adds general support for vision-language models by properly extracting the language model component.
Implementation Details
New metrics fields:
- FusedRMSNormGated modules (tracks min/max of gating parameters)
- A_log attributes (tracks min/max of log-space state transition parameters)

Key changes:
- Handles DTensor inputs (converts to the local tensor before computation)
- Distributed reduction across all ranks using dist.all_reduce
- Supports both nn.Module and torch.Tensor inputs
- Extracts the language-model component on multi-modal architectures
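The key changes above can be sketched as a single helper; this is a hedged illustration, not the PR's actual code — the function name `collect_min_max` and the exact guards are assumptions:

```python
# Sketch: global min/max of a parameter, covering the cases listed above
# (nn.Module vs. raw Tensor input, DTensor shards, distributed reduction).
import torch
import torch.distributed as dist
import torch.nn as nn


def collect_min_max(obj) -> tuple[float, float]:
    """Return the global (min, max) of a parameter across all ranks."""
    # Accept both nn.Module (e.g. a FusedRMSNormGated: use its weight)
    # and torch.Tensor (e.g. a raw A_log parameter).
    tensor = obj.weight if isinstance(obj, nn.Module) else obj
    # DTensor shards must be materialized locally before reduction
    # (duck-typed here via `to_local`; an assumption, not the PR's check).
    if hasattr(tensor, "to_local"):
        tensor = tensor.to_local()
    t_min = tensor.detach().min()
    t_max = tensor.detach().max()
    # Reduce across ranks only when a process group is actually running.
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(t_min, op=dist.ReduceOp.MIN)
        dist.all_reduce(t_max, op=dist.ReduceOp.MAX)
    return t_min.item(), t_max.item()
```

In single-process use the `all_reduce` calls are skipped, so the helper degrades to a plain local min/max.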