Skip to content

Commit d737828

Browse files
fix(postprocess): backward-only flood in conditional branch detection + THEN labeling (#88)
Bug #88: _mark_conditional_branches flooded bidirectionally (parents + children), causing non-conditional nodes' children to be falsely marked as in_cond_branch. Fix restricts flooding to parent_layers only. Additionally adds THEN branch detection via AST analysis when save_source_context=True, with IF/THEN edge labels in visualization. Includes 8 new test models, 22 new tests, and fixes missing 'verbose' in MODEL_LOG_FIELD_ORDER. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 3fe6a84 commit d737828

File tree

12 files changed

+633
-10
lines changed

12 files changed

+633
-10
lines changed

tests/example_models.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -496,6 +496,116 @@ def forward(x):
496496
return x
497497

498498

499+
class ConditionalAlwaysTrue(nn.Module):
500+
"""Condition always True — only THEN branch executes."""
501+
502+
@staticmethod
503+
def forward(x):
504+
if torch.mean(torch.abs(x)) >= 0: # always true (abs >= 0)
505+
x = torch.sin(x)
506+
x = x + 1
507+
else:
508+
x = torch.cos(x)
509+
return x
510+
511+
512+
class ConditionalAlwaysFalse(nn.Module):
513+
"""Condition always False — only ELSE branch executes."""
514+
515+
@staticmethod
516+
def forward(x):
517+
if torch.mean(torch.abs(x)) < 0: # always false (abs >= 0)
518+
x = torch.sin(x)
519+
else:
520+
x = torch.cos(x)
521+
x = x + 1
522+
return x
523+
524+
525+
class ConditionalNested(nn.Module):
526+
"""Nested if-then (condition inside condition)."""
527+
528+
@staticmethod
529+
def forward(x):
530+
if torch.mean(x) > -1000:
531+
x = x + 1
532+
if torch.sum(x) > -1000:
533+
x = x * 2
534+
else:
535+
x = x * 3
536+
else:
537+
x = x - 1
538+
return x
539+
540+
541+
class ConditionalChainedBools(nn.Module):
542+
"""Two boolean conditions checked before branching."""
543+
544+
@staticmethod
545+
def forward(x):
546+
cond1 = torch.mean(x) > 0
547+
cond2 = torch.sum(x) > 0
548+
if cond1 and cond2:
549+
x = torch.sin(x)
550+
else:
551+
x = torch.cos(x)
552+
return x
553+
554+
555+
class ConditionalNoBranch(nn.Module):
556+
"""Bool computed but never used for branching (no THEN -> should clear IF)."""
557+
558+
@staticmethod
559+
def forward(x):
560+
_ = torch.mean(x) > 0 # computed but not used for control flow
561+
x = torch.sin(x) + 1
562+
return x
563+
564+
565+
class ConditionalMultipleBranches(nn.Module):
566+
"""Two separate if-then blocks in sequence."""
567+
568+
@staticmethod
569+
def forward(x):
570+
if torch.mean(x) > 0:
571+
x = torch.sin(x)
572+
else:
573+
x = torch.cos(x)
574+
if torch.sum(x) > 0:
575+
x = x + 1
576+
else:
577+
x = x - 1
578+
return x
579+
580+
581+
class ConditionalWithModules(nn.Module):
582+
"""Branches using nn.Linear layers."""
583+
584+
def __init__(self):
585+
super().__init__()
586+
self.linear1 = nn.Linear(5, 5, bias=False)
587+
self.linear2 = nn.Linear(5, 5, bias=False)
588+
589+
def forward(self, x):
590+
if torch.mean(x) > 0:
591+
x = self.linear1(x)
592+
else:
593+
x = self.linear2(x)
594+
return x
595+
596+
597+
class ConditionalIdentity(nn.Module):
598+
"""Condition but both branches do same thing (still valid IF/THEN)."""
599+
600+
@staticmethod
601+
def forward(x):
602+
if torch.mean(x) > 0:
603+
x = x + 1
604+
else:
605+
x = x + 1
606+
return x
607+
608+
499609
class RepeatedModule(nn.Module):
500610
def __init__(self):
501611
super().__init__()

0 commit comments

Comments
 (0)