Skip to content

Commit dfd2d7b

Browse files
fix(validation): handle bool and complex tensor perturbation properly
- Generate proper complex perturbations using torch.complex() instead of casting away imaginary part - Fix bool tensor crash by reordering .float().abs() (bool doesn't support abs, but float conversion handles it) - Add ContextUnet diffusion model to example_models.py for self-contained stable_diffusion test - Update test_stable_diffusion to use example_models.ContextUnet Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 33e0f22 commit dfd2d7b

File tree

3 files changed

+151
-11
lines changed

3 files changed

+151
-11
lines changed

tests/example_models.py

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1483,3 +1483,141 @@ def forward(x):
14831483
x = torch.sin(x)
14841484
x = torch.cos(x)
14851485
return x
1486+
1487+
1488+
# =============================================================================
1489+
# Conditional Diffusion UNet
1490+
# =============================================================================
1491+
# Adapted from TeaPearce/Conditional_Diffusion_MNIST:
1492+
# https://github.com/TeaPearce/Conditional_Diffusion_MNIST
1493+
1494+
1495+
class _ResidualConvBlock(nn.Module):
1496+
def __init__(self, in_channels, out_channels, is_res=False):
1497+
super().__init__()
1498+
self.same_channels = in_channels == out_channels
1499+
self.is_res = is_res
1500+
self.conv1 = nn.Sequential(
1501+
nn.Conv2d(in_channels, out_channels, 3, 1, 1),
1502+
nn.BatchNorm2d(out_channels),
1503+
nn.GELU(),
1504+
)
1505+
self.conv2 = nn.Sequential(
1506+
nn.Conv2d(out_channels, out_channels, 3, 1, 1),
1507+
nn.BatchNorm2d(out_channels),
1508+
nn.GELU(),
1509+
)
1510+
1511+
def forward(self, x):
1512+
if self.is_res:
1513+
x1 = self.conv1(x)
1514+
x2 = self.conv2(x1)
1515+
if self.same_channels:
1516+
out = x + x2
1517+
else:
1518+
out = x1 + x2
1519+
return out / 1.414
1520+
else:
1521+
x1 = self.conv1(x)
1522+
x2 = self.conv2(x1)
1523+
return x2
1524+
1525+
1526+
class _UnetDown(nn.Module):
1527+
def __init__(self, in_channels, out_channels):
1528+
super().__init__()
1529+
self.model = nn.Sequential(_ResidualConvBlock(in_channels, out_channels), nn.MaxPool2d(2))
1530+
1531+
def forward(self, x):
1532+
return self.model(x)
1533+
1534+
1535+
class _UnetUp(nn.Module):
1536+
def __init__(self, in_channels, out_channels):
1537+
super().__init__()
1538+
self.model = nn.Sequential(
1539+
nn.ConvTranspose2d(in_channels, out_channels, 2, 2),
1540+
_ResidualConvBlock(out_channels, out_channels),
1541+
_ResidualConvBlock(out_channels, out_channels),
1542+
)
1543+
1544+
def forward(self, x, skip):
1545+
x = torch.cat((x, skip), 1)
1546+
x = self.model(x)
1547+
return x
1548+
1549+
1550+
class _EmbedFC(nn.Module):
1551+
def __init__(self, input_dim, emb_dim):
1552+
super().__init__()
1553+
self.input_dim = input_dim
1554+
self.model = nn.Sequential(
1555+
nn.Linear(input_dim, emb_dim),
1556+
nn.GELU(),
1557+
nn.Linear(emb_dim, emb_dim),
1558+
)
1559+
1560+
def forward(self, x):
1561+
x = x.view(-1, self.input_dim)
1562+
return self.model(x)
1563+
1564+
1565+
class ContextUnet(nn.Module):
1566+
"""Conditional UNet for diffusion models."""
1567+
1568+
def __init__(self, in_channels, n_feat=256, n_classes=10):
1569+
super().__init__()
1570+
self.in_channels = in_channels
1571+
self.n_feat = n_feat
1572+
self.n_classes = n_classes
1573+
1574+
self.init_conv = _ResidualConvBlock(in_channels, n_feat, is_res=True)
1575+
1576+
self.down1 = _UnetDown(n_feat, n_feat)
1577+
self.down2 = _UnetDown(n_feat, 2 * n_feat)
1578+
1579+
self.to_vec = nn.Sequential(nn.AvgPool2d(7), nn.GELU())
1580+
1581+
self.timeembed1 = _EmbedFC(1, 2 * n_feat)
1582+
self.timeembed2 = _EmbedFC(1, 1 * n_feat)
1583+
self.contextembed1 = _EmbedFC(n_classes, 2 * n_feat)
1584+
self.contextembed2 = _EmbedFC(n_classes, 1 * n_feat)
1585+
1586+
self.up0 = nn.Sequential(
1587+
nn.ConvTranspose2d(2 * n_feat, 2 * n_feat, 7, 7),
1588+
nn.GroupNorm(8, 2 * n_feat),
1589+
nn.ReLU(),
1590+
)
1591+
1592+
self.up1 = _UnetUp(4 * n_feat, n_feat)
1593+
self.up2 = _UnetUp(2 * n_feat, n_feat)
1594+
self.out = nn.Sequential(
1595+
nn.Conv2d(2 * n_feat, n_feat, 3, 1, 1),
1596+
nn.GroupNorm(8, n_feat),
1597+
nn.ReLU(),
1598+
nn.Conv2d(n_feat, self.in_channels, 3, 1, 1),
1599+
)
1600+
1601+
def forward(self, x, c, t, context_mask):
1602+
x = self.init_conv(x)
1603+
down1 = self.down1(x)
1604+
down2 = self.down2(down1)
1605+
hiddenvec = self.to_vec(down2)
1606+
1607+
c = nn.functional.one_hot(c, num_classes=self.n_classes).type(torch.float)
1608+
1609+
context_mask = context_mask[:, None]
1610+
context_mask = context_mask.repeat(1, self.n_classes)
1611+
context_mask = -1 * (1 - context_mask)
1612+
c = c * context_mask
1613+
1614+
cemb1 = self.contextembed1(c).view(-1, self.n_feat * 2, 1, 1)
1615+
temb1 = self.timeembed1(t).view(-1, self.n_feat * 2, 1, 1)
1616+
cemb2 = self.contextembed2(c).view(-1, self.n_feat, 1, 1)
1617+
temb2 = self.timeembed2(t).view(-1, self.n_feat, 1, 1)
1618+
1619+
up1 = self.up0(hiddenvec)
1620+
up2 = self.up1(cemb1 * up1 + temb1, down2)
1621+
up3 = self.up2(cemb2 * up2 + temb2, down1)
1622+
out = self.out(torch.cat((up3, x), 1))
1623+
return out

tests/test_real_world_models.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -331,14 +331,9 @@ def test_gpt2():
331331

332332

333333
def test_stable_diffusion():
334-
try:
335-
import UNet
336-
except ModuleNotFoundError:
337-
pytest.skip("UNet not available")
338-
339-
model = UNet(3, 16, 10)
334+
model = example_models.ContextUnet(3, 16, 10)
340335
model_inputs = (
341-
torch.rand(6, 3, 224, 224),
336+
torch.rand(1, 3, 28, 28),
342337
torch.tensor([1]),
343338
torch.tensor([1.0]),
344339
torch.tensor([3.0]),

torchlens/validation.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -671,10 +671,17 @@ def _perturb_layer_activations(
671671
mean_output += torch.rand(mean_output.shape, device=mean_output.device) * 100
672672
mean_output *= torch.rand(mean_output.shape, device=mean_output.device)
673673
mean_output.requires_grad = False
674-
perturbed_activations = torch.randn_like(
675-
parent_activations.float(), device=device
676-
) * mean_output.to(device)
677-
perturbed_activations = perturbed_activations.type(parent_activations.dtype)
674+
scale = mean_output.to(device)
675+
if parent_activations.is_complex():
676+
perturbed_activations = torch.complex(
677+
torch.randn(parent_activations.shape, device=device) * scale,
678+
torch.randn(parent_activations.shape, device=device) * scale,
679+
).type(parent_activations.dtype)
680+
else:
681+
perturbed_activations = (
682+
torch.randn_like(parent_activations.float(), device=device) * scale
683+
)
684+
perturbed_activations = perturbed_activations.type(parent_activations.dtype)
678685

679686
return perturbed_activations
680687

0 commit comments

Comments
 (0)