@@ -1483,3 +1483,141 @@ def forward(x):
14831483 x = torch .sin (x )
14841484 x = torch .cos (x )
14851485 return x
1486+
1487+
1488+ # =============================================================================
1489+ # Conditional Diffusion UNet
1490+ # =============================================================================
1491+ # Adapted from TeaPearce/Conditional_Diffusion_MNIST:
1492+ # https://github.com/TeaPearce/Conditional_Diffusion_MNIST
1493+
1494+
class _ResidualConvBlock(nn.Module):
    """Two (3x3 conv -> BatchNorm -> GELU) stages with an optional residual sum.

    With ``is_res=True`` the output is ``(skip + h2) / 1.414`` where the skip
    is the raw input when the channel counts match and the first stage's
    output otherwise; 1.414 (~sqrt(2)) keeps the variance of the sum roughly
    level. With ``is_res=False`` the two stages are applied back to back.
    """

    def __init__(self, in_channels, out_channels, is_res=False):
        super().__init__()
        self.same_channels = in_channels == out_channels
        self.is_res = is_res

        def conv_bn_gelu(c_in, c_out):
            # One conv stage: 3x3, stride 1, padding 1 (spatial size preserved).
            return nn.Sequential(
                nn.Conv2d(c_in, c_out, 3, 1, 1),
                nn.BatchNorm2d(c_out),
                nn.GELU(),
            )

        self.conv1 = conv_bn_gelu(in_channels, out_channels)
        self.conv2 = conv_bn_gelu(out_channels, out_channels)

    def forward(self, x):
        h1 = self.conv1(x)
        h2 = self.conv2(h1)
        if not self.is_res:
            return h2
        # Residual path: fall back to h1 as the skip when the input cannot
        # be added directly (channel mismatch).
        skip = x if self.same_channels else h1
        return (skip + h2) / 1.414
1524+
1525+
class _UnetDown(nn.Module):
    """Encoder stage: one residual conv block, then a 2x2 max-pool.

    Changes channels from ``in_channels`` to ``out_channels`` and halves the
    spatial resolution.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = [
            _ResidualConvBlock(in_channels, out_channels),
            nn.MaxPool2d(2),
        ]
        self.model = nn.Sequential(*stages)

    def forward(self, x):
        return self.model(x)
1533+
1534+
class _UnetUp(nn.Module):
    """Decoder stage: concat the skip connection, upsample 2x, refine.

    ``forward`` concatenates ``x`` and ``skip`` along the channel axis, so
    ``in_channels`` must equal the sum of their channel counts; a stride-2
    transpose conv then doubles the spatial size before two residual conv
    blocks refine the result.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = [
            nn.ConvTranspose2d(in_channels, out_channels, 2, 2),
            _ResidualConvBlock(out_channels, out_channels),
            _ResidualConvBlock(out_channels, out_channels),
        ]
        self.model = nn.Sequential(*stages)

    def forward(self, x, skip):
        merged = torch.cat((x, skip), 1)
        return self.model(merged)
1548+
1549+
class _EmbedFC(nn.Module):
    """Two-layer MLP embedding: ``input_dim -> emb_dim -> emb_dim`` with GELU.

    The input is first flattened to ``(-1, input_dim)``, so any leading
    batch-like dimensions are collapsed before the linear layers.
    """

    def __init__(self, input_dim, emb_dim):
        super().__init__()
        self.input_dim = input_dim
        layers = [
            nn.Linear(input_dim, emb_dim),
            nn.GELU(),
            nn.Linear(emb_dim, emb_dim),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        flat = x.view(-1, self.input_dim)
        return self.model(flat)
1563+
1564+
class ContextUnet(nn.Module):
    """Conditional UNet for diffusion models.

    Encoder: an initial residual conv block, then two _UnetDown stages that
    each halve the spatial size; ``to_vec`` average-pools a 7x7 map to a
    1x1 bottleneck vector. Decoder: ``up0`` re-expands the vector to 7x7,
    then two _UnetUp stages upsample while consuming encoder skips; each
    decoder stage is modulated FiLM-style as ``cemb * h + temb`` with class
    and timestep embeddings.

    NOTE(review): the fixed 7x7 kernels in ``to_vec`` and ``up0`` require the
    feature map to be exactly 7x7 after two downsamples, i.e. 28x28 inputs
    (MNIST-sized) — confirm against callers before reusing elsewhere.
    """

    def __init__(self, in_channels, n_feat=256, n_classes=10):
        super().__init__()
        self.in_channels = in_channels
        self.n_feat = n_feat  # base channel width; deeper stages use 2*n_feat
        self.n_classes = n_classes

        self.init_conv = _ResidualConvBlock(in_channels, n_feat, is_res=True)

        self.down1 = _UnetDown(n_feat, n_feat)
        self.down2 = _UnetDown(n_feat, 2 * n_feat)

        # Collapse the (assumed) 7x7 feature map to one vector per sample.
        self.to_vec = nn.Sequential(nn.AvgPool2d(7), nn.GELU())

        # Scalar timestep -> per-stage feature vector.
        self.timeembed1 = _EmbedFC(1, 2 * n_feat)
        self.timeembed2 = _EmbedFC(1, 1 * n_feat)
        # One-hot class context -> per-stage feature vector.
        self.contextembed1 = _EmbedFC(n_classes, 2 * n_feat)
        self.contextembed2 = _EmbedFC(n_classes, 1 * n_feat)

        # Re-expand the 1x1 bottleneck back to a 7x7 map in one stride-7 step.
        self.up0 = nn.Sequential(
            nn.ConvTranspose2d(2 * n_feat, 2 * n_feat, 7, 7),
            nn.GroupNorm(8, 2 * n_feat),
            nn.ReLU(),
        )

        # Input channels double at each up stage: skip tensors are
        # concatenated inside _UnetUp.forward.
        self.up1 = _UnetUp(4 * n_feat, n_feat)
        self.up2 = _UnetUp(2 * n_feat, n_feat)
        self.out = nn.Sequential(
            nn.Conv2d(2 * n_feat, n_feat, 3, 1, 1),
            nn.GroupNorm(8, n_feat),
            nn.ReLU(),
            nn.Conv2d(n_feat, self.in_channels, 3, 1, 1),
        )

    def forward(self, x, c, t, context_mask):
        """Run the conditional UNet.

        Args:
            x: input image batch, NCHW (spatial size must reach 7x7 after two
                downsamples — presumably 28x28; verify against caller).
            c: integer class labels, shape (batch,), values in [0, n_classes).
            t: diffusion timestep values fed to the 1-d time embeddings.
            context_mask: per-sample 0/1 tensor; samples where the mask is 1
                have their class conditioning zeroed out (see remap below).

        Returns:
            Tensor with ``in_channels`` channels and the spatial size of ``x``.
        """
        x = self.init_conv(x)
        down1 = self.down1(x)
        down2 = self.down2(down1)
        hiddenvec = self.to_vec(down2)

        # One-hot encode the labels so they can be masked arithmetically.
        c = nn.functional.one_hot(c, num_classes=self.n_classes).type(torch.float)

        # Expand the mask to (batch, n_classes), then remap 1 -> 0 and
        # 0 -> -1: masked samples lose their label entirely, kept samples end
        # up with a *negated* one-hot. This mirrors the upstream TeaPearce
        # implementation — quirk preserved deliberately.
        context_mask = context_mask[:, None]
        context_mask = context_mask.repeat(1, self.n_classes)
        context_mask = -1 * (1 - context_mask)
        c = c * context_mask

        # Per-stage conditioning vectors, shaped to broadcast over H and W.
        cemb1 = self.contextembed1(c).view(-1, self.n_feat * 2, 1, 1)
        temb1 = self.timeembed1(t).view(-1, self.n_feat * 2, 1, 1)
        cemb2 = self.contextembed2(c).view(-1, self.n_feat, 1, 1)
        temb2 = self.timeembed2(t).view(-1, self.n_feat, 1, 1)

        up1 = self.up0(hiddenvec)
        # FiLM-style modulation: scale by class context, shift by timestep.
        up2 = self.up1(cemb1 * up1 + temb1, down2)
        up3 = self.up2(cemb2 * up2 + temb2, down1)
        # Outermost skip: final convs see the last decoder features
        # concatenated with the initial conv output.
        out = self.out(torch.cat((up3, x), 1))
        return out
0 commit comments