|
| 1 | +"""Tests for ModuleLog, ModulePassLog, and ModuleAccessor.""" |
| 2 | + |
| 3 | +import pytest |
| 4 | +import torch |
| 5 | +import torch.nn as nn |
| 6 | + |
| 7 | +import example_models |
| 8 | +from torchlens import ModuleLog, ModulePassLog, log_forward_pass |
| 9 | +from torchlens.data_classes import ModuleAccessor, ParamAccessor |
| 10 | + |
| 11 | + |
| 12 | +# --------------------------------------------------------------------------- |
| 13 | +# Helpers |
| 14 | +# --------------------------------------------------------------------------- |
| 15 | + |
| 16 | + |
| 17 | +def _make_simple_model(): |
| 18 | + return nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 2)) |
| 19 | + |
| 20 | + |
| 21 | +def _simple_input(): |
| 22 | + return torch.randn(1, 10) |
| 23 | + |
| 24 | + |
| 25 | +def _make_nested_model(): |
| 26 | + """Model with nested submodules for hierarchy tests.""" |
| 27 | + return nn.Sequential( |
| 28 | + nn.Sequential(nn.Linear(10, 8), nn.ReLU()), |
| 29 | + nn.Sequential(nn.Linear(8, 4), nn.Sigmoid()), |
| 30 | + nn.Linear(4, 2), |
| 31 | + ) |
| 32 | + |
| 33 | + |
| 34 | +def _nested_input(): |
| 35 | + return torch.randn(1, 10) |
| 36 | + |
| 37 | + |
| 38 | +# --------------------------------------------------------------------------- |
| 39 | +# TestModuleLogBasic |
| 40 | +# --------------------------------------------------------------------------- |
| 41 | + |
| 42 | + |
| 43 | +class TestModuleLogBasic: |
| 44 | + def test_modules_accessor_exists(self): |
| 45 | + log = log_forward_pass(_make_simple_model(), _simple_input()) |
| 46 | + assert isinstance(log.modules, ModuleAccessor) |
| 47 | + |
| 48 | + def test_root_module_exists(self): |
| 49 | + log = log_forward_pass(_make_simple_model(), _simple_input()) |
| 50 | + root = log.modules["self"] |
| 51 | + assert isinstance(root, ModuleLog) |
| 52 | + assert root.address == "self" |
| 53 | + |
| 54 | + def test_root_module_alias(self): |
| 55 | + log = log_forward_pass(_make_simple_model(), _simple_input()) |
| 56 | + assert log.modules[""] is log.modules["self"] |
| 57 | + |
| 58 | + def test_root_module_property(self): |
| 59 | + log = log_forward_pass(_make_simple_model(), _simple_input()) |
| 60 | + assert log.root_module is log.modules["self"] |
| 61 | + |
| 62 | + def test_module_count(self): |
| 63 | + log = log_forward_pass(_make_simple_model(), _simple_input()) |
| 64 | + # Sequential has 3 children: Linear, ReLU, Linear → 3 submodules + root = 4 |
| 65 | + assert len(log.modules) >= 4 |
| 66 | + |
| 67 | + def test_access_by_address(self): |
| 68 | + log = log_forward_pass(_make_simple_model(), _simple_input()) |
| 69 | + ml = log.modules["0"] |
| 70 | + assert isinstance(ml, ModuleLog) |
| 71 | + assert ml.address == "0" |
| 72 | + |
| 73 | + def test_access_by_index(self): |
| 74 | + log = log_forward_pass(_make_simple_model(), _simple_input()) |
| 75 | + ml = log.modules[0] |
| 76 | + assert isinstance(ml, ModuleLog) |
| 77 | + |
| 78 | + def test_access_by_pass_notation(self): |
| 79 | + log = log_forward_pass(_make_simple_model(), _simple_input()) |
| 80 | + # All modules have 1 pass in a non-recurrent model |
| 81 | + addresses = [ml.address for ml in log.modules if ml.address != "self"] |
| 82 | + if addresses: |
| 83 | + addr = addresses[0] |
| 84 | + mpl = log.modules[f"{addr}:1"] |
| 85 | + assert isinstance(mpl, ModulePassLog) |
| 86 | + |
| 87 | + def test_contains(self): |
| 88 | + log = log_forward_pass(_make_simple_model(), _simple_input()) |
| 89 | + assert "0" in log.modules |
| 90 | + assert "self" in log.modules |
| 91 | + assert "nonexistent" not in log.modules |
| 92 | + |
| 93 | + def test_iter(self): |
| 94 | + log = log_forward_pass(_make_simple_model(), _simple_input()) |
| 95 | + modules_list = list(log.modules) |
| 96 | + assert len(modules_list) == len(log.modules) |
| 97 | + assert all(isinstance(ml, ModuleLog) for ml in modules_list) |
| 98 | + |
| 99 | + def test_getitem_multi_pass_returns_module_log(self, input_2d): |
| 100 | + """log["fc1"] for a multi-pass module should return ModuleLog (instead of error).""" |
| 101 | + model = example_models.RecurrentParamsSimple() |
| 102 | + log = log_forward_pass(model, input_2d) |
| 103 | + result = log["fc1"] |
| 104 | + assert isinstance(result, ModuleLog) |
| 105 | + |
| 106 | + |
| 107 | +# --------------------------------------------------------------------------- |
| 108 | +# TestModuleLogFields |
| 109 | +# --------------------------------------------------------------------------- |
| 110 | + |
| 111 | + |
| 112 | +class TestModuleLogFields: |
| 113 | + def test_identity_fields(self): |
| 114 | + log = log_forward_pass(_make_simple_model(), _simple_input()) |
| 115 | + ml = log.modules["0"] |
| 116 | + assert ml.address == "0" |
| 117 | + assert ml.name == "0" |
| 118 | + assert ml.module_class_name == "Linear" |
| 119 | + |
| 120 | + def test_source_info(self): |
| 121 | + log = log_forward_pass(_make_simple_model(), _simple_input()) |
| 122 | + ml = log.modules["0"] |
| 123 | + assert ml.source_file is not None # nn.Linear has inspectable source |
| 124 | + assert ml.forward_signature is not None |
| 125 | + |
| 126 | + def test_hierarchy_address(self): |
| 127 | + log = log_forward_pass(_make_nested_model(), _nested_input()) |
| 128 | + # "0.0" is Linear inside first Sequential |
| 129 | + ml = log.modules["0.0"] |
| 130 | + assert ml.address_parent == "0" |
| 131 | + assert ml.address_depth == 2 |
| 132 | + |
| 133 | + # "0" is the first Sequential |
| 134 | + parent = log.modules["0"] |
| 135 | + assert parent.address_parent == "self" |
| 136 | + assert "0.0" in parent.address_children |
| 137 | + |
| 138 | + def test_hierarchy_call(self): |
| 139 | + log = log_forward_pass(_make_simple_model(), _simple_input()) |
| 140 | + root = log.modules["self"] |
| 141 | + # Root's call_children should include top-level modules |
| 142 | + assert len(root.call_children) > 0 |
| 143 | + |
| 144 | + def test_nesting_depth(self): |
| 145 | + log = log_forward_pass(_make_nested_model(), _nested_input()) |
| 146 | + root = log.modules["self"] |
| 147 | + assert root.nesting_depth == 0 |
| 148 | + |
| 149 | + # Top-level module should be depth 1 |
| 150 | + top = log.modules["0"] |
| 151 | + assert top.nesting_depth == 1 |
| 152 | + |
| 153 | + # Nested inside "0" should be depth 2 |
| 154 | + nested = log.modules["0.0"] |
| 155 | + assert nested.nesting_depth == 2 |
| 156 | + |
| 157 | + def test_address_depth(self): |
| 158 | + log = log_forward_pass(_make_nested_model(), _nested_input()) |
| 159 | + assert log.modules["self"].address_depth == 0 |
| 160 | + assert log.modules["0"].address_depth == 1 |
| 161 | + assert log.modules["0.0"].address_depth == 2 |
| 162 | + |
| 163 | + def test_layers_populated(self): |
| 164 | + log = log_forward_pass(_make_simple_model(), _simple_input()) |
| 165 | + ml = log.modules["0"] |
| 166 | + assert len(ml.all_layers) > 0 |
| 167 | + assert ml.num_layers == len(ml.all_layers) |
| 168 | + |
| 169 | + def test_params_accessor(self): |
| 170 | + log = log_forward_pass(_make_simple_model(), _simple_input()) |
| 171 | + ml = log.modules["0"] # Linear layer |
| 172 | + assert isinstance(ml.params, ParamAccessor) |
| 173 | + assert len(ml.params) == 2 # weight + bias |
| 174 | + |
| 175 | + def test_training_mode(self): |
| 176 | + model = _make_simple_model() |
| 177 | + model.eval() |
| 178 | + log = log_forward_pass(model, _simple_input()) |
| 179 | + for ml in log.modules: |
| 180 | + if ml.address != "self": |
| 181 | + assert ml.training_mode is False |
| 182 | + |
| 183 | + def test_hooks_detected_false(self): |
| 184 | + log = log_forward_pass(_make_simple_model(), _simple_input()) |
| 185 | + for ml in log.modules: |
| 186 | + assert ml.has_forward_hooks is False |
| 187 | + assert ml.has_backward_hooks is False |
| 188 | + |
| 189 | + def test_repr(self): |
| 190 | + log = log_forward_pass(_make_simple_model(), _simple_input()) |
| 191 | + ml = log.modules["0"] |
| 192 | + r = repr(ml) |
| 193 | + assert "ModuleLog" in r |
| 194 | + assert ml.address in r |
| 195 | + assert ml.module_class_name in r |
| 196 | + |
| 197 | + |
| 198 | +# --------------------------------------------------------------------------- |
| 199 | +# TestModulePassLog |
| 200 | +# --------------------------------------------------------------------------- |
| 201 | + |
| 202 | + |
| 203 | +class TestModulePassLog: |
| 204 | + def test_pass_layers(self): |
| 205 | + log = log_forward_pass(_make_simple_model(), _simple_input()) |
| 206 | + ml = log.modules["0"] |
| 207 | + mpl = ml.passes[1] |
| 208 | + assert isinstance(mpl, ModulePassLog) |
| 209 | + # Pass layers should be a subset of parent all_layers |
| 210 | + assert all(label in ml.all_layers for label in mpl.layers) |
| 211 | + |
| 212 | + def test_input_output_layers(self): |
| 213 | + log = log_forward_pass(_make_simple_model(), _simple_input()) |
| 214 | + ml = log.modules["0"] |
| 215 | + mpl = ml.passes[1] |
| 216 | + assert isinstance(mpl.input_layers, list) |
| 217 | + assert isinstance(mpl.output_layers, list) |
| 218 | + |
| 219 | + def test_call_children(self): |
| 220 | + log = log_forward_pass(_make_nested_model(), _nested_input()) |
| 221 | + # "0" contains "0.0" and "0.1" as submodules |
| 222 | + ml = log.modules["0"] |
| 223 | + mpl = ml.passes[1] |
| 224 | + assert isinstance(mpl.call_children, list) |
| 225 | + |
| 226 | + def test_repr(self): |
| 227 | + log = log_forward_pass(_make_simple_model(), _simple_input()) |
| 228 | + ml = log.modules["0"] |
| 229 | + mpl = ml.passes[1] |
| 230 | + r = repr(mpl) |
| 231 | + assert "ModulePassLog" in r |
| 232 | + assert len(r) > 0 |
| 233 | + |
| 234 | + |
| 235 | +# --------------------------------------------------------------------------- |
| 236 | +# TestMultiPassModules |
| 237 | +# --------------------------------------------------------------------------- |
| 238 | + |
| 239 | + |
| 240 | +class TestMultiPassModules: |
| 241 | + def test_num_passes_gt_1(self, input_2d): |
| 242 | + model = example_models.RecurrentParamsSimple() |
| 243 | + log = log_forward_pass(model, input_2d) |
| 244 | + # fc1 is used 4 times |
| 245 | + ml = log.modules["fc1"] |
| 246 | + assert ml.num_passes >= 2 |
| 247 | + |
| 248 | + def test_per_call_field_raises(self, input_2d): |
| 249 | + model = example_models.RecurrentParamsSimple() |
| 250 | + log = log_forward_pass(model, input_2d) |
| 251 | + ml = log.modules["fc1"] |
| 252 | + assert ml.num_passes > 1 |
| 253 | + with pytest.raises(AttributeError, match="passes"): |
| 254 | + _ = ml.layers |
| 255 | + |
| 256 | + def test_pass_access(self, input_2d): |
| 257 | + model = example_models.RecurrentParamsSimple() |
| 258 | + log = log_forward_pass(model, input_2d) |
| 259 | + ml = log.modules["fc1"] |
| 260 | + assert 1 in ml.passes |
| 261 | + assert 2 in ml.passes |
| 262 | + assert isinstance(ml.passes[1], ModulePassLog) |
| 263 | + assert isinstance(ml.passes[2], ModulePassLog) |
| 264 | + |
| 265 | + def test_pass_notation_accessor(self, input_2d): |
| 266 | + model = example_models.RecurrentParamsSimple() |
| 267 | + log = log_forward_pass(model, input_2d) |
| 268 | + mpl = log.modules["fc1:2"] |
| 269 | + assert isinstance(mpl, ModulePassLog) |
| 270 | + assert mpl.pass_num == 2 |
| 271 | + |
| 272 | + |
| 273 | +# --------------------------------------------------------------------------- |
| 274 | +# TestSinglePassDelegation |
| 275 | +# --------------------------------------------------------------------------- |
| 276 | + |
| 277 | + |
| 278 | +class TestSinglePassDelegation: |
| 279 | + def test_layers_delegates(self): |
| 280 | + log = log_forward_pass(_make_simple_model(), _simple_input()) |
| 281 | + ml = log.modules["0"] |
| 282 | + assert ml.num_passes == 1 |
| 283 | + # Should delegate to passes[1].layers |
| 284 | + assert ml.layers == ml.passes[1].layers |
| 285 | + |
| 286 | + def test_forward_args_delegates(self): |
| 287 | + log = log_forward_pass(_make_simple_model(), _simple_input()) |
| 288 | + ml = log.modules["0"] |
| 289 | + # forward_args should be accessible for single-pass |
| 290 | + _ = ml.forward_args # should not raise |
| 291 | + |
| 292 | + |
| 293 | +# --------------------------------------------------------------------------- |
| 294 | +# TestModuleAccessorSummary |
| 295 | +# --------------------------------------------------------------------------- |
| 296 | + |
| 297 | + |
| 298 | +class TestModuleAccessorSummary: |
| 299 | + def test_to_pandas(self): |
| 300 | + log = log_forward_pass(_make_simple_model(), _simple_input()) |
| 301 | + df = log.modules.to_pandas() |
| 302 | + assert len(df) == len(log.modules) |
| 303 | + assert "address" in df.columns |
| 304 | + assert "module_class_name" in df.columns |
| 305 | + assert "nesting_depth" in df.columns |
| 306 | + assert "num_params" in df.columns |
| 307 | + |
| 308 | + def test_summary(self): |
| 309 | + log = log_forward_pass(_make_simple_model(), _simple_input()) |
| 310 | + s = log.modules.summary() |
| 311 | + assert isinstance(s, str) |
| 312 | + assert len(s) > 0 |
| 313 | + assert "Address" in s |
| 314 | + |
| 315 | + def test_repr(self): |
| 316 | + log = log_forward_pass(_make_simple_model(), _simple_input()) |
| 317 | + r = repr(log.modules) |
| 318 | + assert "ModuleAccessor" in r |
| 319 | + |
| 320 | + |
| 321 | +# --------------------------------------------------------------------------- |
| 322 | +# TestModuleLogIntegration |
| 323 | +# --------------------------------------------------------------------------- |
| 324 | + |
| 325 | + |
| 326 | +class TestModuleLogIntegration: |
| 327 | + def test_root_all_layers_equals_model_layers(self): |
| 328 | + log = log_forward_pass(_make_simple_model(), _simple_input()) |
| 329 | + root = log.root_module |
| 330 | + assert root.all_layers == log.layer_labels |
| 331 | + |
| 332 | + def test_root_params_count(self): |
| 333 | + log = log_forward_pass(_make_simple_model(), _simple_input()) |
| 334 | + root = log.root_module |
| 335 | + assert root.num_params == log.total_params |
| 336 | + |
| 337 | + def test_module_class_name_matches(self): |
| 338 | + log = log_forward_pass(_make_simple_model(), _simple_input()) |
| 339 | + # Module "0" is Linear |
| 340 | + assert log.modules["0"].module_class_name == "Linear" |
| 341 | + # Module "1" is ReLU |
| 342 | + assert log.modules["1"].module_class_name == "ReLU" |
| 343 | + # Module "2" is Linear |
| 344 | + assert log.modules["2"].module_class_name == "Linear" |
| 345 | + |
| 346 | + def test_nested_model_hierarchy(self): |
| 347 | + log = log_forward_pass(_make_nested_model(), _nested_input()) |
| 348 | + # Check that nesting is consistent |
| 349 | + for ml in log.modules: |
| 350 | + if ml.address == "self": |
| 351 | + continue |
| 352 | + # address_parent should be a valid module |
| 353 | + assert ml.address_parent in log.modules |
| 354 | + |
| 355 | + def test_module_log_to_pandas(self): |
| 356 | + log = log_forward_pass(_make_simple_model(), _simple_input()) |
| 357 | + ml = log.modules["0"] |
| 358 | + df = ml.to_pandas() |
| 359 | + assert len(df) == ml.num_layers |
| 360 | + |
| 361 | + def test_module_log_iter(self): |
| 362 | + log = log_forward_pass(_make_simple_model(), _simple_input()) |
| 363 | + ml = log.modules["0"] |
| 364 | + entries = list(ml) |
| 365 | + assert len(entries) == ml.num_layers |
| 366 | + |
| 367 | + def test_module_log_getitem(self): |
| 368 | + log = log_forward_pass(_make_simple_model(), _simple_input()) |
| 369 | + ml = log.modules["0"] |
| 370 | + if ml.num_layers > 0: |
| 371 | + entry = ml[0] |
| 372 | + assert entry.layer_label == ml.all_layers[0] |
| 373 | + |
| 374 | + def test_old_module_dicts_still_exist(self): |
| 375 | + """Old module_* dicts should still exist for vis.py compatibility.""" |
| 376 | + log = log_forward_pass(_make_simple_model(), _simple_input()) |
| 377 | + assert hasattr(log, "module_types") |
| 378 | + assert hasattr(log, "module_addresses") |
| 379 | + assert hasattr(log, "module_layers") |
| 380 | + assert hasattr(log, "module_pass_layers") |
| 381 | + assert hasattr(log, "module_pass_children") |
| 382 | + |
| 383 | + def test_nested_modules_model(self, input_2d): |
| 384 | + """Integration test with the NestedModules example model.""" |
| 385 | + model = example_models.NestedModules() |
| 386 | + log = log_forward_pass(model, input_2d) |
| 387 | + assert len(log.modules) > 1 |
| 388 | + root = log.root_module |
| 389 | + assert root.address == "self" |
| 390 | + # Should have nested hierarchy |
| 391 | + max_depth = max(ml.nesting_depth for ml in log.modules) |
| 392 | + assert max_depth >= 2 # At least 3 levels of nesting |
0 commit comments