
Commit c9d5698

feat: hugging face publish (#33)
* feat: hugging face publish
* edit docs

Co-authored-by: ControlNet <smczx@hotmail.com>
1 parent 398060f commit c9d5698

File tree: 12 files changed, +736 −1 lines changed

README.md

Lines changed: 28 additions & 1 deletion
@@ -21,7 +21,7 @@
   <a href="https://arxiv.org/abs/2211.06627">
     <img src="https://img.shields.io/badge/arXiv-2211.06627-b31b1b.svg?style=flat-square">
   </a>
-  <a href="https://huggingface.co/ControlNet/MARLIN">
+  <a href="https://huggingface.co/collections/ControlNet/marlin-67e79296284080c98d95e3d9">
     <img src="https://img.shields.io/badge/huggingface-model-FFD21E?style=flat-square&logo=huggingface">
   </a>
 </div>
@@ -50,6 +50,7 @@ This repo is the official PyTorch implementation for the paper

 The repository contains 2 parts:
 - `marlin-pytorch`: The PyPI package for MARLIN used for inference.
+- The HuggingFace wrapper for MARLIN used for inference.
 - The implementation for the paper including training and evaluation scripts.

 ```
@@ -70,6 +71,9 @@ The repository contains 2 parts:
 ├── init.py
 ├── version.txt

+# below is for the huggingface wrapper
+├── hf_src
+
 # below is for the paper implementation
 ├── configs        # Configs for experiments settings
 ├── model          # Marlin models
@@ -150,6 +154,29 @@ features = model.extract_features(x)  # torch.Size([B, k, 768])
 features = model.extract_features(x, keep_seq=False)  # torch.Size([B, 768])
 ```

+## Use `transformers` (HuggingFace) for Feature Extraction
+
+Requirements:
+- Python
+- PyTorch
+- transformers
+- einops
+
+Currently, the HuggingFace model only supports direct feature extraction, without any video pre-processing (e.g. face detection, cropping, or strided windowing).
+
+```python
+import torch
+from transformers import AutoModel
+
+model = AutoModel.from_pretrained(
+    "ControlNet/marlin_vit_base_ytf",  # or other variants
+    trust_remote_code=True
+)
+tensor = torch.rand([1, 3, 16, 224, 224])  # (B, C, T, H, W)
+output = model(tensor)  # torch.Size([1, 1568, 768])
+```
+
 ## Paper Implementation

 ### Requirements
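
A usage note on the new README snippet: the wrapper consumes a raw (B, C, T, H, W) float tensor, so face cropping and frame sampling must happen upstream. Below is a minimal sketch of building that tensor from an already face-cropped clip; the `clip.mp4` path, the uniform 16-frame sampling, and the [0, 1] scaling are illustrative assumptions, not part of this commit.

```python
import torch
import torch.nn.functional as F
from torchvision.io import read_video

# Decode an already face-cropped clip; read_video returns (T, H, W, C) uint8.
frames, _, _ = read_video("clip.mp4", pts_unit="sec")

# Uniformly sample 16 frames across the clip (n_frames=16 in all configs).
idx = torch.linspace(0, frames.shape[0] - 1, steps=16).long()
clip = frames[idx].permute(0, 3, 1, 2).float() / 255.0  # (16, 3, H, W), [0, 1]

# Resize to 224x224, then reorder to the (B, C, T, H, W) layout the model expects.
clip = F.interpolate(clip, size=(224, 224), mode="bilinear", align_corners=False)
tensor = clip.permute(1, 0, 2, 3).unsqueeze(0)  # (1, 3, 16, 224, 224)
```

The [0, 1] scaling mirrors the `torch.rand` example above; check the model card for the exact normalization the published checkpoints expect.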

hf_src/marlin_configs/vit_base.py

Lines changed: 22 additions & 0 deletions
```python
from marlin_huggingface import MarlinConfig


vit_base_config = MarlinConfig(
    img_size=224,
    patch_size=16,
    n_frames=16,
    mlp_ratio=4.0,
    qkv_bias=True,
    qk_scale=None,
    drop_rate=0.0,
    attn_drop_rate=0.0,
    norm_layer="LayerNorm",
    init_values=0.0,
    tubelet_size=2,
    encoder_embed_dim=768,
    encoder_depth=12,
    encoder_num_heads=12,
    decoder_embed_dim=384,
    decoder_depth=4,
    decoder_num_heads=6,
)
```

hf_src/marlin_configs/vit_large.py

Lines changed: 22 additions & 0 deletions
```python
from marlin_huggingface import MarlinConfig


vit_large_config = MarlinConfig(
    img_size=224,
    patch_size=16,
    n_frames=16,
    mlp_ratio=4.0,
    qkv_bias=True,
    qk_scale=None,
    drop_rate=0.0,
    attn_drop_rate=0.0,
    norm_layer="LayerNorm",
    init_values=0.0,
    tubelet_size=2,
    encoder_embed_dim=1024,
    encoder_depth=24,
    encoder_num_heads=16,
    decoder_embed_dim=512,
    decoder_depth=12,
    decoder_num_heads=8,
)
```

hf_src/marlin_configs/vit_small.py

Lines changed: 21 additions & 0 deletions
```python
from marlin_huggingface import MarlinConfig


vit_small_config = MarlinConfig(
    img_size=224,
    patch_size=16,
    n_frames=16,
    mlp_ratio=4.0,
    qkv_bias=True,
    qk_scale=None,
    drop_rate=0.0,
    attn_drop_rate=0.0,
    norm_layer="LayerNorm",
    init_values=0.0,
    tubelet_size=2,
    encoder_embed_dim=384,
    encoder_depth=12,
    encoder_num_heads=6,
    decoder_embed_dim=192,
    decoder_depth=4,
    decoder_num_heads=3,
)
```
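
All three configs share the same patch geometry and differ only in encoder/decoder width and depth. As a sanity check that a config materializes a model, here is a minimal sketch, assuming `MarlinModel` follows the standard `transformers` convention of being constructed from its config (the convention `AutoModel.register` in the package `__init__.py` below relies on) and that its forward returns the raw feature tensor as in the README example:

```python
import torch
from marlin_huggingface import MarlinModel
from marlin_configs.vit_base import vit_base_config

# Randomly initialised ViT-Base wrapper; real weights would come from
# AutoModel.from_pretrained("ControlNet/marlin_vit_base_ytf", ...) instead.
model = MarlinModel(vit_base_config)
output = model(torch.rand(1, 3, 16, 224, 224))  # expected (1, 1568, 768)
```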
hf_src/marlin_huggingface/__init__.py

Lines changed: 11 additions & 0 deletions

```python
from transformers import AutoModel, AutoConfig

from .config import MarlinConfig
from .marlin import Marlin, MarlinModel

MarlinConfig.register_for_auto_class()
MarlinModel.register_for_auto_class()
AutoConfig.register("marlin", MarlinConfig)
AutoModel.register(MarlinConfig, MarlinModel)

__all__ = ["Marlin", "MarlinModel", "MarlinConfig"]
```
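
These registration calls are what make the `Auto*` factories work: `register_for_auto_class` stamps an `auto_map` entry into the exported `config.json` so Hub checkpoints load with `trust_remote_code=True`, while `AutoConfig.register`/`AutoModel.register` wire the `"marlin"` model type into the local `transformers` installation. A quick local check (a sketch; it only exercises the registration, not any weights):

```python
import marlin_huggingface  # importing the package runs the registrations above
from transformers import AutoConfig

# AutoConfig can now resolve the "marlin" model type without remote code.
cfg = AutoConfig.for_model("marlin")
print(type(cfg).__name__, cfg.model_type)  # MarlinConfig marlin
```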
hf_src/marlin_huggingface/config.py

Lines changed: 27 additions & 0 deletions

```python
from transformers import PretrainedConfig


class MarlinConfig(PretrainedConfig):
    model_type = "marlin"

    def __init__(self, **kwargs):
        self.img_size = kwargs.pop("img_size", None)
        self.patch_size = kwargs.pop("patch_size", None)
        self.n_frames = kwargs.pop("n_frames", None)
        self.encoder_embed_dim = kwargs.pop("encoder_embed_dim", None)
        self.encoder_depth = kwargs.pop("encoder_depth", None)
        self.encoder_num_heads = kwargs.pop("encoder_num_heads", None)
        self.decoder_embed_dim = kwargs.pop("decoder_embed_dim", None)
        self.decoder_depth = kwargs.pop("decoder_depth", None)
        self.decoder_num_heads = kwargs.pop("decoder_num_heads", None)
        self.mlp_ratio = kwargs.pop("mlp_ratio", None)
        self.qkv_bias = kwargs.pop("qkv_bias", None)
        self.qk_scale = kwargs.pop("qk_scale", None)
        self.drop_rate = kwargs.pop("drop_rate", None)
        self.attn_drop_rate = kwargs.pop("attn_drop_rate", None)
        self.norm_layer = kwargs.pop("norm_layer", None)
        self.init_values = kwargs.pop("init_values", None)
        self.tubelet_size = kwargs.pop("tubelet_size", None)
        self.as_feature_extractor = kwargs.pop("as_feature_extractor", True)

        super().__init__(**kwargs)
```
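
Because every hyper-parameter is popped from `**kwargs` with a `None` default, a `MarlinConfig` round-trips through the standard `PretrainedConfig` serialization machinery. A minimal sketch (the `./marlin_vit_base` directory name is illustrative):

```python
from marlin_huggingface import MarlinConfig

cfg = MarlinConfig(img_size=224, patch_size=16, n_frames=16, tubelet_size=2,
                   encoder_embed_dim=768, encoder_depth=12, encoder_num_heads=12)

# Standard PretrainedConfig round-trip: writes and re-reads config.json.
cfg.save_pretrained("./marlin_vit_base")
restored = MarlinConfig.from_pretrained("./marlin_vit_base")
assert restored.encoder_embed_dim == 768
```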
hf_src/marlin_huggingface/decoder.py

Lines changed: 87 additions & 0 deletions

```python
import torch
from einops import rearrange
from torch import nn, Tensor
from torch.nn import LayerNorm, Linear, ModuleList

from .modules import Block, no_grad_trunc_normal_
from .positional_embedding import SinCosPositionalEmbedding


class MarlinDecoder(nn.Module):

    def __init__(self, img_size=224, patch_size=16, n_frames=16, embed_dim=384, depth=8,
        num_heads=6, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
        norm_layer="LayerNorm", init_values=1., tubelet_size=2
    ):
        super().__init__()
        output_dim = 3 * tubelet_size * patch_size * patch_size
        self.patch_size = patch_size
        self.tubelet_size = tubelet_size
        self.n_patch_h = img_size // patch_size
        self.n_patch_w = img_size // patch_size
        self.embed_dim = embed_dim
        if norm_layer == "LayerNorm":
            self.norm_layer = LayerNorm
            self.norm = self.norm_layer(embed_dim)
        else:
            raise NotImplementedError("Only LayerNorm is supported")

        # sine-cosine positional embeddings
        self.pos_embedding = SinCosPositionalEmbedding(
            (self.n_patch_h * self.n_patch_w * (n_frames // tubelet_size), embed_dim), dropout_rate=0.)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        self.blocks = ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, norm_layer=self.norm_layer,
                init_values=init_values
            ) for _ in range(depth)])

        self.head = Linear(embed_dim, output_dim)
        self.apply(self._init_weights)
        no_grad_trunc_normal_(self.mask_token, mean=0., std=0.02, a=-0.02, b=0.02)

    @staticmethod
    def _init_weights(m):
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def unpatch_to_img(self, x: Tensor) -> Tensor:
        # x: (B, N_tokens, C * prod of cube size)
        x = rearrange(x, "b n (c p) -> b n p c", c=3)
        # x: (B, N_tokens, prod of cube size, C)
        x = rearrange(x, "b (t h w) (p0 p1 p2) c -> b c (t p0) (h p1) (w p2)", p0=self.tubelet_size,
            p1=self.patch_size, p2=self.patch_size, h=self.n_patch_h, w=self.n_patch_w)
        # x: (B, C, T, H, W)
        return x

    def forward_features(self, x, return_token_num=0):
        for block in self.blocks:
            x = block(x)

        if return_token_num > 0:
            x = x[:, -return_token_num:]

        x = self.norm(x)
        x = self.head(x)
        # x: (B, N_mask, C)
        return x

    def forward(self, x, mask):
        # mask: 0 -> masked, 1 -> visible
        b, n, c = x.shape
        expand_pos_embed = self.pos_embedding.emb.data.expand(b, -1, -1)
        pos_emb_vis = expand_pos_embed[mask].view(b, -1, c)
        pos_emb_mask = expand_pos_embed[~mask].view(b, -1, c)
        x = torch.cat([x + pos_emb_vis, self.mask_token + pos_emb_mask], dim=1)

        mask_num = pos_emb_mask.shape[1]

        x = self.forward_features(x, return_token_num=mask_num)
        return x
```
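
For the default ViT-Base geometry (224×224 input, 16×16 patches, 16 frames, tubelet size 2) the decoder operates on 14 × 14 × 8 = 1568 tokens, and `head` projects each reconstructed token to 3 × 2 × 16 × 16 = 1536 pixel values. A shape walk-through of `forward` (a sketch; the `marlin_huggingface.decoder` module path and the half-and-half mask are assumptions for illustration):

```python
import torch
from marlin_huggingface.decoder import MarlinDecoder  # module path assumed

dec = MarlinDecoder(img_size=224, patch_size=16, n_frames=16, embed_dim=384)
b, n = 2, 14 * 14 * 8                      # 1568 tokens per 16-frame clip
mask = torch.zeros(b, n, dtype=torch.bool)
mask[:, : n // 2] = True                   # True -> visible, False -> masked
visible_tokens = torch.rand(b, n // 2, 384)

out = dec(visible_tokens, mask)            # reconstructs masked tokens only
print(out.shape)                           # (2, 784, 1536)
```

`unpatch_to_img` would then fold those 1536-value tokens back into a (B, 3, T, H, W) pixel volume.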
hf_src/marlin_huggingface/encoder.py

Lines changed: 78 additions & 0 deletions

```python
from torch import nn, Tensor
from torch.nn import ModuleList, LayerNorm

from .modules import PatchEmbedding3d, Block
from .positional_embedding import SinCosPositionalEmbedding


class MarlinEncoder(nn.Module):

    def __init__(self, img_size=224, patch_size=16, n_frames=16, embed_dim=768, depth=12,
        num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
        norm_layer="LayerNorm", init_values=0., tubelet_size=2
    ):
        super().__init__()

        self.embed_dim = embed_dim
        self.patch_embedding = PatchEmbedding3d(
            input_size=(3, n_frames, img_size, img_size),
            patch_size=(tubelet_size, patch_size, patch_size),
            embedding=embed_dim
        )
        num_patches = (img_size // patch_size) * (img_size // patch_size) * (n_frames // tubelet_size)

        # sine-cosine positional embeddings
        self.pos_embedding = SinCosPositionalEmbedding((num_patches, embed_dim), dropout_rate=0.)

        if norm_layer == "LayerNorm":
            self.norm_layer = LayerNorm
            self.norm = self.norm_layer(embed_dim)
        else:
            raise NotImplementedError("Only LayerNorm is supported")

        self.blocks = ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, norm_layer=self.norm_layer,
                init_values=init_values)
            for _ in range(depth)
        ])

        self.apply(self._init_weights)

    @staticmethod
    def _init_weights(m):
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward_features(self, x):
        for block in self.blocks:
            x = block(x)
        x = self.norm(x)
        return x

    def forward(self, x: Tensor, mask: Tensor) -> Tensor:
        # mask: (B, N) boolean token mask over the flattened patch sequence,
        # 0 -> masked, 1 -> visible
        assert len(x.shape) == 5, "x must be 5D"
        emb = self.patch_embedding(x)
        emb = self.pos_embedding(emb)
        b, _, c = emb.shape
        emb = emb[mask].view(b, -1, c)  # only visible patches are used
        emb = self.forward_features(emb)
        return emb

    def extract_features(self, x: Tensor, seq_mean_pool: bool) -> Tensor:
        x = self.patch_embedding(x)
        x = self.pos_embedding(x)
        for block in self.blocks:
            x = block(x)

        if seq_mean_pool:
            x = x.mean(dim=1)
        x = self.norm(x)
        return x
```
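
`forward` implements the pre-training path (only visible tokens enter the transformer blocks), while `extract_features` embeds the full 1568-token sequence and optionally mean-pools it. A minimal sketch of both calls, assuming the `marlin_huggingface.encoder` module path; note that boolean-mask indexing followed by `.view(b, -1, c)` requires the same visible-token count in every sample:

```python
import torch
from marlin_huggingface.encoder import MarlinEncoder  # module path assumed

enc = MarlinEncoder(img_size=224, patch_size=16, n_frames=16, embed_dim=768)
video = torch.rand(2, 3, 16, 224, 224)  # (B, C, T, H, W)

# Inference path: features for all 1568 tokens, mean-pooled to one vector.
pooled = enc.extract_features(video, seq_mean_pool=True)  # (2, 768)

# Pre-training path: encode only ~10% visible tokens (same count per sample).
n = 14 * 14 * 8
mask = torch.zeros(2, n, dtype=torch.bool)
mask[:, torch.randperm(n)[: n // 10]] = True  # True -> visible
visible = enc(video, mask)  # (2, 156, 768)
```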
