Skip to content

Commit 11ea006

Browse files
fix(postprocess): fix mypy type errors in _build_module_param_info
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 99f4102 commit 11ea006

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

torchlens/postprocess/finalization.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,7 @@ class ModuleParamInfo(NamedTuple):
324324

325325

326326
def _build_module_param_info(
327-
self: "ModelLog", address: str, mbd: dict, _buffer_layers_by_module: dict = None
327+
self: "ModelLog", address: str, mbd: dict, _buffer_layers_by_module: Optional[dict] = None
328328
) -> ModuleParamInfo:
329329
"""Gather parameter counts, sizes, and buffer layers for a single module."""
330330
from ..data_classes.param_log import ParamAccessor
@@ -406,9 +406,9 @@ def _build_module_logs(self: "ModelLog") -> None:
406406
_buffer_layers_by_module = defaultdict(list)
407407
for bl in self.buffer_layers:
408408
if bl in self.layer_dict_all_keys:
409-
entry = self.layer_dict_all_keys[bl]
410-
if hasattr(entry, "buffer_address") and entry.buffer_address is not None:
411-
module_addr = entry.buffer_address.rsplit(".", 1)[0]
409+
bl_entry = self.layer_dict_all_keys[bl]
410+
if hasattr(bl_entry, "buffer_address") and bl_entry.buffer_address is not None:
411+
module_addr = bl_entry.buffer_address.rsplit(".", 1)[0]
412412
_buffer_layers_by_module[module_addr].append(bl)
413413

414414
# --- Build ModuleLogs for each submodule ---

0 commit comments

Comments
 (0)