@@ -324,7 +324,7 @@ class ModuleParamInfo(NamedTuple):
324324
325325
326326def _build_module_param_info (
327- self : "ModelLog" , address : str , mbd : dict , _buffer_layers_by_module : dict = None
327+ self : "ModelLog" , address : str , mbd : dict , _buffer_layers_by_module : Optional [ dict ] = None
328328) -> ModuleParamInfo :
329329 """Gather parameter counts, sizes, and buffer layers for a single module."""
330330 from ..data_classes .param_log import ParamAccessor
@@ -406,9 +406,9 @@ def _build_module_logs(self: "ModelLog") -> None:
406406 _buffer_layers_by_module = defaultdict (list )
407407 for bl in self .buffer_layers :
408408 if bl in self .layer_dict_all_keys :
409- entry = self .layer_dict_all_keys [bl ]
410- if hasattr (entry , "buffer_address" ) and entry .buffer_address is not None :
411- module_addr = entry .buffer_address .rsplit ("." , 1 )[0 ]
409+ bl_entry = self .layer_dict_all_keys [bl ]
410+ if hasattr (bl_entry , "buffer_address" ) and bl_entry .buffer_address is not None :
411+ module_addr = bl_entry .buffer_address .rsplit ("." , 1 )[0 ]
412412 _buffer_layers_by_module [module_addr ].append (bl )
413413
414414 # --- Build ModuleLogs for each submodule ---
0 commit comments