import torch
from einops import rearrange
from torch import nn, Tensor
from torch.nn import LayerNorm, Linear, ModuleList

from .modules import Block, no_grad_trunc_normal_
from .positional_embedding import SinCosPositionalEmbedding


class MarlinDecoder(nn.Module):
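    """Transformer decoder for masked video reconstruction.

    Given the features of the visible (unmasked) tokens, the decoder appends a
    shared learnable mask token for every masked position, adds fixed
    sine-cosine positional embeddings, runs a stack of transformer blocks, and
    predicts the raw RGB pixels of each masked tubelet
    (3 * tubelet_size * patch_size * patch_size values per token).
    """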

    def __init__(self, img_size=224, patch_size=16, n_frames=16, embed_dim=384, depth=8,
        num_heads=6, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
        norm_layer="LayerNorm", init_values=1., tubelet_size=2
    ):
        super().__init__()
        # each decoded token predicts the raw RGB pixels of one tubelet
        output_dim = 3 * tubelet_size * patch_size * patch_size
        self.patch_size = patch_size
        self.tubelet_size = tubelet_size
        self.n_patch_h = img_size // patch_size
        self.n_patch_w = img_size // patch_size
        self.embed_dim = embed_dim
        if norm_layer == "LayerNorm":
            self.norm_layer = LayerNorm
            self.norm = self.norm_layer(embed_dim)
        else:
            raise NotImplementedError("Only LayerNorm is supported")

        # fixed sine-cosine positional embeddings over all space-time patches
        self.pos_embedding = SinCosPositionalEmbedding(
            (self.n_patch_h * self.n_patch_w * (n_frames // tubelet_size), embed_dim), dropout_rate=0.)
        # learnable token shared by all masked positions
        self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        self.blocks = ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, norm_layer=self.norm_layer,
                init_values=init_values
            ) for _ in range(depth)])

        self.head = Linear(embed_dim, output_dim)
        self.apply(self._init_weights)
        no_grad_trunc_normal_(self.mask_token, mean=0., std=0.02, a=-0.02, b=0.02)

    @staticmethod
    def _init_weights(m):
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def unpatch_to_img(self, x: Tensor) -> Tensor:
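        """Fold per-token pixel predictions back into a video tensor of shape (B, C, T, H, W)."""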
        # x: (B, N_patches, C * tubelet_size * patch_size * patch_size)
        x = rearrange(x, "b n (c p) -> b n p c", c=3)
        # x: (B, N_patches, tubelet_size * patch_size * patch_size, C)
        x = rearrange(x, "b (t h w) (p0 p1 p2) c -> b c (t p0) (h p1) (w p2)", p0=self.tubelet_size,
            p1=self.patch_size, p2=self.patch_size, h=self.n_patch_h, w=self.n_patch_w)
        # x: (B, C, T, H, W)
        return x

    def forward_features(self, x, return_token_num=0):
        for block in self.blocks:
            x = block(x)

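        # keep only the trailing tokens (the mask tokens appended in `forward`),
        # which are the ones whose pixels need to be reconstructed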
        if return_token_num > 0:
            x = x[:, -return_token_num:]

        x = self.norm(x)
        x = self.head(x)
        # x: (B, N_mask, 3 * tubelet_size * patch_size * patch_size)
        return x

    def forward(self, x, mask):
        # x: (B, N_visible, embed_dim) features of the visible tokens
        # mask: BoolTensor of shape (B, N_total); False/0 -> masked, True/1 -> visible
        b, n, c = x.shape
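        # add the fixed positional embeddings separately to the visible tokens and to the
        # shared mask token, then concatenate: visible tokens first, mask tokens last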
        expand_pos_embed = self.pos_embedding.emb.data.expand(b, -1, -1)
        pos_emb_vis = expand_pos_embed[mask].view(b, -1, c)
        pos_emb_mask = expand_pos_embed[~mask].view(b, -1, c)
        x = torch.cat([x + pos_emb_vis, self.mask_token + pos_emb_mask], dim=1)

        mask_num = pos_emb_mask.shape[1]

        x = self.forward_features(x, return_token_num=mask_num)
        return x
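
# Minimal smoke test (a sketch, not part of the original file): the shapes, masking
# ratio, and random inputs below are illustrative assumptions only. Run as a module
# (e.g. `python -m <package>.decoder`) so the relative imports resolve.
if __name__ == "__main__":
    decoder = MarlinDecoder(img_size=224, patch_size=16, n_frames=16, embed_dim=384)
    b = 2
    n_total = (224 // 16) ** 2 * (16 // 2)  # 14 * 14 spatial patches * 8 temporal tubelets = 1568
    n_visible = n_total // 4                # assume 25% of the tokens are visible
    mask = torch.zeros(b, n_total, dtype=torch.bool)
    mask[:, :n_visible] = True              # True -> visible, False -> masked
    visible_tokens = torch.randn(b, n_visible, 384)  # stand-in for the encoder output
    out = decoder(visible_tokens, mask)
    print(out.shape)  # torch.Size([2, 1176, 1536]): RGB tubelet prediction per masked token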