Skip to content

Commit d54e9a9

Browse files
fix(capture): fix mypy type errors in output_tensors field dict
Annotate fields_dict as Dict[str, Any] and extract param_shapes with proper type to satisfy mypy strict inference. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent f0d7452 commit d54e9a9

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

torchlens/capture/output_tensors.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ def _build_shared_fields_dict(
290290
parent_layer_labels = get_attr_values_from_tensor_list(arg_tensors, "tl_tensor_label_raw")
291291
parent_layer_entries = [self[label] for label in parent_layer_labels]
292292

293-
fields_dict = {}
293+
fields_dict: Dict[str, Any] = {}
294294

295295
# General info
296296
fields_dict["layer_type"] = layer_type
@@ -329,12 +329,13 @@ def _build_shared_fields_dict(
329329
_build_module_context_fields(self, fields_dict, arg_tensors, parent_layer_entries)
330330

331331
# Function config — lightweight hyperparameter extraction, always on.
332+
param_shapes: Optional[List[Tuple]] = fields_dict.get("parent_param_shapes")
332333
fields_dict["func_config"] = extract_salient_args(
333334
layer_type,
334335
func_name,
335336
args,
336337
kwargs,
337-
fields_dict.get("parent_param_shapes", []),
338+
param_shapes,
338339
)
339340

340341
return fields_dict, parent_layer_entries, arg_tensors, parent_param_passes

0 commit comments

Comments
 (0)