From d2d406994c08287a3e57c5f2764f8408a2300065 Mon Sep 17 00:00:00 2001 From: Eric Hofesmann Date: Thu, 24 Oct 2019 20:27:13 -0400 Subject: [PATCH 1/2] Add i3d model and weights links and train/test scripts --- .gitignore | 1 + README.md | 2 + models/i3d/config_test.yaml | 27 +++ models/i3d/config_train.yaml | 37 +++ models/i3d/i3d.py | 447 +++++++++++++++++++++++++++++++++++ weights/download_weights.sh | 6 + 6 files changed, 520 insertions(+) create mode 100644 models/i3d/config_test.yaml create mode 100644 models/i3d/config_train.yaml create mode 100644 models/i3d/i3d.py diff --git a/.gitignore b/.gitignore index c70c4d2..0e4172d 100644 --- a/.gitignore +++ b/.gitignore @@ -13,3 +13,4 @@ runs/* models/HGC3D *.json pbs/* +weights/* diff --git a/README.md b/README.md index 3231468..74bc01e 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,7 @@ Check out our [wiki!](https://github.com/MichiganCOG/ViP/wiki) ### Recognition | Model Architecture | Dataset | ViP Accuracy (%) | |:--------------------:|:------------------:|:---------------------:| +| I3D | HMDB51 (Split 1) | 72.75 | | C3D | HMDB51 (Split 1) | 50.14 ± 0.777 | | C3D | UCF101 (Split 1) | 80.40 ± 0.399 | @@ -43,6 +44,7 @@ Check out our [wiki!](https://github.com/MichiganCOG/ViP/wiki) | Model | Task(s) | |:------------------------------------------------:|:--------------------:| |[C3D](https://github.com/jfzhang95/pytorch-video-recognition/blob/master/network/C3D_model.py) | Activity Recognition | +|[I3D](https://github.com/piergiaj/pytorch-i3d) | Activity Recognition | |[SSD300](https://github.com/amdegroot/ssd.pytorch) | Object Detection | ## Requirements diff --git a/models/i3d/config_test.yaml b/models/i3d/config_test.yaml new file mode 100644 index 0000000..72ae825 --- /dev/null +++ b/models/i3d/config_test.yaml @@ -0,0 +1,27 @@ +# Preprocessing +clip_length: 64 # Number of frames within a clip +clip_offset: 0 # Frame offset between beginning of video and clip (1st clip only) +clip_stride: 0 # Frame offset 
between successive frames +crop_shape: [224,224] # (Height, Width) of frame +crop_type: Center # Type of cropping operation (Random, Center and None) +final_shape: [224,224] # (Height, Width) of input to be given to CNN +num_clips: -1 # Number of clips to be generated from a video (<0: uniform sampling, 0: Divide entire video into clips, >0: Defines number of clips) +random_offset: 0 # Boolean switch to generate a clip-length-sized clip from a video +resize_shape: [230,250] # (Height, Width) to resize original data +subtract_mean: [123,117,104] # Subtract mean (R,G,B) from all frames during preprocessing + +# Experiment Setup +acc_metric: Accuracy # Accuracy metric +batch_size: 1 # Number of videos in a mini-batch +dataset: HMDB51 # Name of dataset +exp: I3D # Experiment name +json_path: /z/dat/HMDB51/ # Path to the json file for the given dataset +labels: 51 # Number of total classes in the dataset +load_type: train_val # Environment selection, to include only training/training and validation/testing dataset +model: I3D # Name of model to be loaded +num_workers: 5 # Number of CPU workers used to load data +preprocess: default # String argument to select preprocessing type +pretrained: 'weights/i3d_rgb_imagenet_then_HMDB51_30epochs.pkl' # Load pretrained network +save_dir: './results' # Path to results directory +seed: 999 # Seed for reproducibility +loss_type: M_XENTROPY # Loss function diff --git a/models/i3d/config_train.yaml b/models/i3d/config_train.yaml new file mode 100644 index 0000000..c15abe9 --- /dev/null +++ b/models/i3d/config_train.yaml @@ -0,0 +1,37 @@ +# Preprocessing +clip_length: 64 # Number of frames within a clip +clip_offset: 0 # Frame offset between beginning of video and clip (1st clip only) +clip_stride: 1 # Frame offset between successive frames +crop_shape: [224,224] # (Height, Width) of frame +crop_type: Center # Type of cropping operation (Random, Center and None) +final_shape: [224,224] # (Height, Width) of input to be given to CNN
+num_clips: -1 # Number of clips to be generated from a video (<0: uniform sampling, 0: Divide entire video into clips, >0: Defines number of clips) +random_offset: 0 # Boolean switch to generate a clip-length-sized clip from a video +resize_shape: [230,250] # (Height, Width) to resize original data +subtract_mean: [123,117,104] # Subtract mean (R,G,B) from all frames during preprocessing + +# Experiment Setup +acc_metric: Accuracy # Accuracy metric +batch_size: 5 # Number of videos in a mini-batch +pseudo_batch_loop: 10 # Pseudo-batch size multiplier to mimic large minibatches +dataset: HMDB51 # Name of dataset +epoch: 30 # Total number of epochs +exp: I3D # Experiment name +gamma: 0.1 # Multiplier with which to change learning rate +json_path: /z/dat/HMDB51/ # Path to the json file for the given dataset +labels: 51 # Number of total classes in the dataset +load_type: train # Environment selection, to include only training/training and validation/testing dataset +loss_type: M_XENTROPY # Loss function +lr: 0.01 # Learning rate +milestones: [10, 20] # Epoch values at which to change the learning rate +model: I3D # Name of model to be loaded +momentum: 0.9 # Momentum value in optimizer +num_workers: 5 # Number of CPU workers used to load data +opt: sgd # Name of optimizer +preprocess: default # String argument to select preprocessing type +pretrained: 1 # Load pretrained network +rerun: 1 # Number of trials to repeat an experiment +save_dir: './results' # Path to results directory +seed: 999 # Seed for reproducibility +weight_decay: 0.0005 # Weight decay +grad_max_norm: 100 diff --git a/models/i3d/i3d.py b/models/i3d/i3d.py new file mode 100644 index 0000000..4e901b3 --- /dev/null +++ b/models/i3d/i3d.py @@ -0,0 +1,447 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import Variable +import datasets.preprocessing_transforms as pt + +import numpy as np + +import os +import sys +from collections import OrderedDict + + +""" +Code from the
implementation of i3d by AJ Piergiovanni: https://github.com/piergiaj/pytorch-i3d +""" + +class MaxPool3dSamePadding(nn.MaxPool3d): + + def compute_pad(self, dim, s): + if s % self.stride[dim] == 0: + return max(self.kernel_size[dim] - self.stride[dim], 0) + else: + return max(self.kernel_size[dim] - (s % self.stride[dim]), 0) + + def forward(self, x): + # compute 'same' padding + (batch, channel, t, h, w) = x.size() + # t, h, w are the three trailing dimensions padded below + out_t = np.ceil(float(t) / float(self.stride[0])) + out_h = np.ceil(float(h) / float(self.stride[1])) + out_w = np.ceil(float(w) / float(self.stride[2])) + # NOTE: out_t, out_h and out_w are computed but not used below + pad_t = self.compute_pad(0, t) + pad_h = self.compute_pad(1, h) + pad_w = self.compute_pad(2, w) + # split each total pad into front and back parts; the back part gets the extra unit when the total is odd + + pad_t_f = pad_t // 2 + pad_t_b = pad_t - pad_t_f + pad_h_f = pad_h // 2 + pad_h_b = pad_h - pad_h_f + pad_w_f = pad_w // 2 + pad_w_b = pad_w - pad_w_f + + pad = (pad_w_f, pad_w_b, pad_h_f, pad_h_b, pad_t_f, pad_t_b) + x = F.pad(x, pad) + return super(MaxPool3dSamePadding, self).forward(x) + + +class Unit3D(nn.Module): + + def __init__(self, in_channels, + output_channels, + kernel_shape=(1, 1, 1), + stride=(1, 1, 1), + padding=0, + activation_fn=F.relu, + use_batch_norm=True, + use_bias=False, + name='unit_3d', + dilation=1): + + """Initializes Unit3D module.""" + super(Unit3D, self).__init__() + + self._output_channels = output_channels + self._kernel_shape = kernel_shape + self._stride = stride + self._use_batch_norm = use_batch_norm + self._activation_fn = activation_fn + self._use_bias = use_bias + self.name = name + self.padding = padding + + self.conv3d = nn.Conv3d(in_channels=in_channels, + out_channels=self._output_channels, + kernel_size=self._kernel_shape, + stride=self._stride, + padding=0, # we always want padding to be 0 here.
We will dynamically pad based on input size in forward function + bias=self._use_bias, + dilation=dilation) + + if self._use_batch_norm: + self.bn = nn.BatchNorm3d(self._output_channels, eps=0.001, momentum=0.01) + + def compute_pad(self, dim, s): + if s % self._stride[dim] == 0: + return max(self._kernel_shape[dim] - self._stride[dim], 0) + else: + return max(self._kernel_shape[dim] - (s % self._stride[dim]), 0) + + + def forward(self, x): + # compute 'same' padding + (batch, channel, t, h, w) = x.size() + # t, h, w are the three trailing dimensions padded below + out_t = np.ceil(float(t) / float(self._stride[0])) + out_h = np.ceil(float(h) / float(self._stride[1])) + out_w = np.ceil(float(w) / float(self._stride[2])) + # NOTE: out_t, out_h and out_w are computed but not used below + pad_t = self.compute_pad(0, t) + pad_h = self.compute_pad(1, h) + pad_w = self.compute_pad(2, w) + # split each total pad into front and back parts; the back part gets the extra unit when the total is odd + + pad_t_f = pad_t // 2 + pad_t_b = pad_t - pad_t_f + pad_h_f = pad_h // 2 + pad_h_b = pad_h - pad_h_f + pad_w_f = pad_w // 2 + pad_w_b = pad_w - pad_w_f + + pad = (pad_w_f, pad_w_b, pad_h_f, pad_h_b, pad_t_f, pad_t_b) + x = F.pad(x, pad) + + x = self.conv3d(x) + if self._use_batch_norm: + x = self.bn(x) + if self._activation_fn is not None: + x = self._activation_fn(x) + return x + + + +class InceptionModule(nn.Module): + def __init__(self, in_channels, out_channels, name): + super(InceptionModule, self).__init__() + + self.b0 = Unit3D(in_channels=in_channels, output_channels=out_channels[0], kernel_shape=[1, 1, 1], padding=0, + name=name+'/Branch_0/Conv3d_0a_1x1') + self.b1a = Unit3D(in_channels=in_channels, output_channels=out_channels[1], kernel_shape=[1, 1, 1], padding=0, + name=name+'/Branch_1/Conv3d_0a_1x1') + self.b1b = Unit3D(in_channels=out_channels[1], output_channels=out_channels[2], kernel_shape=[3, 3, 3], + name=name+'/Branch_1/Conv3d_0b_3x3') + self.b2a = Unit3D(in_channels=in_channels, output_channels=out_channels[3], kernel_shape=[1, 1, 1], padding=0, + name=name+'/Branch_2/Conv3d_0a_1x1') + self.b2b =
Unit3D(in_channels=out_channels[3], output_channels=out_channels[4], kernel_shape=[3, 3, 3], + name=name+'/Branch_2/Conv3d_0b_3x3') + self.b3a = MaxPool3dSamePadding(kernel_size=[3, 3, 3], + stride=(1, 1, 1), padding=0) + self.b3b = Unit3D(in_channels=in_channels, output_channels=out_channels[5], kernel_shape=[1, 1, 1], padding=0, + name=name+'/Branch_3/Conv3d_0b_1x1') + self.name = name + + def forward(self, x): + b0 = self.b0(x) + b1 = self.b1b(self.b1a(x)) + b2 = self.b2b(self.b2a(x)) + b3 = self.b3b(self.b3a(x)) + return torch.cat([b0,b1,b2,b3], dim=1) + + +class I3D(nn.Module): + """Inception-v1 I3D architecture. + The model is introduced in: + Quo Vadis, Action Recognition? A New Model and the Kinetics Dataset + Joao Carreira, Andrew Zisserman + https://arxiv.org/pdf/1705.07750v1.pdf. + See also the Inception architecture, introduced in: + Going deeper with convolutions + Christian Szegedy, Wei Liu, Yangqing Jia, Pierre Sermanet, Scott Reed, + Dragomir Anguelov, Dumitru Erhan, Vincent Vanhoucke, Andrew Rabinovich. + http://arxiv.org/pdf/1409.4842v1.pdf. + """ + + # Endpoints of the model in order. During construction, all the endpoints up + # to a designated `final_endpoint` are stored in the `self.end_points` + # dictionary. + VALID_ENDPOINTS = ( + 'Conv3d_1a_7x7', + 'MaxPool3d_2a_3x3', + 'Conv3d_2b_1x1', + 'Conv3d_2c_3x3', + 'MaxPool3d_3a_3x3', + 'Mixed_3b', + 'Mixed_3c', + 'MaxPool3d_4a_3x3', + 'Mixed_4b', + 'Mixed_4c', + 'Mixed_4d', + 'Mixed_4e', + 'Mixed_4f', + 'MaxPool3d_5a_2x2', + 'Mixed_5b', + 'Mixed_5c', + 'Logits', + 'Predictions', + ) + + def __init__(self, spatial_squeeze=True, + final_endpoint='Logits', name='inception_i3d', in_channels=3, dropout_keep_prob=0.5, **kwargs): + """Initializes I3D model instance. + Args: + labels: The number of outputs in the logit layer. Required; read from + `kwargs['labels']`. + spatial_squeeze: Whether to squeeze the spatial dimensions for the logits + before returning (default True).
+ final_endpoint: The model contains many possible endpoints. + `final_endpoint` specifies the last endpoint for the model to be built + up to. The modules for all endpoints up to `final_endpoint` are kept + in the `self.end_points` dictionary; `forward` returns only the + logits. `final_endpoint` must be one of + I3D.VALID_ENDPOINTS (default 'Logits'). + name: A string (optional). The name of this module. + Raises: + ValueError: if `final_endpoint` is not recognized. + """ + + if final_endpoint not in self.VALID_ENDPOINTS: + raise ValueError('Unknown final endpoint %s' % final_endpoint) + + super(I3D, self).__init__() + self._num_classes = kwargs['labels'] + self._spatial_squeeze = spatial_squeeze + self._final_endpoint = final_endpoint + self.logits = None + + self.train_transforms = PreprocessTrain(**kwargs) + self.test_transforms = PreprocessEval(**kwargs) + + + if self._final_endpoint not in self.VALID_ENDPOINTS: + raise ValueError('Unknown final endpoint %s' % self._final_endpoint) + + self.end_points = {} + end_point = 'Conv3d_1a_7x7' + self.end_points[end_point] = Unit3D(in_channels=in_channels, output_channels=64, kernel_shape=[7, 7, 7], + stride=(2, 2, 2), padding=(3,3,3), name=name+end_point) + if self._final_endpoint == end_point: return # NOTE(review): this and every later early return happen before self.build(), so end_points are not registered as submodules and avg_pool/dropout are never created + + end_point = 'MaxPool3d_2a_3x3' + self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[1, 3, 3], stride=(1, 2, 2), + padding=0) + if self._final_endpoint == end_point: return + + end_point = 'Conv3d_2b_1x1' + self.end_points[end_point] = Unit3D(in_channels=64, output_channels=64, kernel_shape=[1, 1, 1], padding=0, + name=name+end_point) + if self._final_endpoint == end_point: return + + end_point = 'Conv3d_2c_3x3' + self.end_points[end_point] = Unit3D(in_channels=64, output_channels=192, kernel_shape=[3, 3, 3], padding=1, + name=name+end_point) + if self._final_endpoint == end_point: return + + end_point = 'MaxPool3d_3a_3x3' + self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[1,
3, 3], stride=(1, 2, 2), + padding=0) + if self._final_endpoint == end_point: return + + end_point = 'Mixed_3b' + self.end_points[end_point] = InceptionModule(192, [64,96,128,16,32,32], name+end_point) + if self._final_endpoint == end_point: return + + end_point = 'Mixed_3c' + self.end_points[end_point] = InceptionModule(256, [128,128,192,32,96,64], name+end_point) + if self._final_endpoint == end_point: return + + end_point = 'MaxPool3d_4a_3x3' + self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[3, 3, 3], stride=(2, 2, 2), + padding=0) + if self._final_endpoint == end_point: return + + end_point = 'Mixed_4b' + self.end_points[end_point] = InceptionModule(128+192+96+64, [192,96,208,16,48,64], name+end_point) + if self._final_endpoint == end_point: return + + end_point = 'Mixed_4c' + self.end_points[end_point] = InceptionModule(192+208+48+64, [160,112,224,24,64,64], name+end_point) + if self._final_endpoint == end_point: return + + end_point = 'Mixed_4d' + self.end_points[end_point] = InceptionModule(160+224+64+64, [128,128,256,24,64,64], name+end_point) + if self._final_endpoint == end_point: return + + end_point = 'Mixed_4e' + self.end_points[end_point] = InceptionModule(128+256+64+64, [112,144,288,32,64,64], name+end_point) + if self._final_endpoint == end_point: return + + end_point = 'Mixed_4f' + self.end_points[end_point] = InceptionModule(112+288+64+64, [256,160,320,32,128,128], name+end_point) + if self._final_endpoint == end_point: return + + end_point = 'MaxPool3d_5a_2x2' + self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[2, 2, 2], stride=(2, 2, 2), + padding=0) + if self._final_endpoint == end_point: return + + end_point = 'Mixed_5b' + self.end_points[end_point] = InceptionModule(256+320+128+128, [256,160,320,32,128,128], name+end_point) + if self._final_endpoint == end_point: return + + end_point = 'Mixed_5c' + self.end_points[end_point] = InceptionModule(256+320+128+128, [384,192,384,48,128,128], name+end_point) + if 
self._final_endpoint == end_point: return + + end_point = 'Logits' + self.avg_pool = nn.AvgPool3d(kernel_size=[2, 7, 7], + stride=(1, 1, 1)) + self.dropout = nn.Dropout(dropout_keep_prob) # NOTE(review): nn.Dropout takes the probability of zeroing an element, so this value acts as a drop rate, not a keep rate + self.logits = Unit3D(in_channels=384+384+128+128, output_channels=self._num_classes, + kernel_shape=[1, 1, 1], + padding=0, + activation_fn=None, + use_batch_norm=False, + use_bias=True, + name='logits') + + + + + + self.build() + + if 'pretrained' in kwargs.keys() and kwargs['pretrained']: + if 'i3d_pretrained' in kwargs.keys(): + self._load_checkpoint(kwargs['i3d_pretrained']) + else: + self._load_pretrained_weights() + + def _load_pretrained_weights(self): + p_dict = torch.load('weights/i3d_rgb_imagenet.pt') # path is relative to the current working directory + s_dict = self.state_dict() + for name in p_dict: + if name in s_dict.keys(): + if p_dict[name].shape == s_dict[name].shape: # entries whose shape differs are skipped silently + s_dict[name] = p_dict[name] + + self.load_state_dict(s_dict) + + def _load_checkpoint(self, saved_weights): + p_dict = torch.load(saved_weights)['state_dict'] # the checkpoint must hold its weights under the 'state_dict' key + s_dict = self.state_dict() + for name in p_dict: + if name in s_dict.keys(): + if p_dict[name].shape == s_dict[name].shape: # entries whose shape differs are skipped silently + s_dict[name] = p_dict[name] + + self.load_state_dict(s_dict) + + + + def replace_logits(self, num_classes): + self._num_classes = num_classes + self.logits = Unit3D(in_channels=384+384+128+128, output_channels=self._num_classes, + kernel_shape=[1, 1, 1], + padding=0, + activation_fn=None, + use_batch_norm=False, + use_bias=True, + name='logits') + + + def build(self): + for k in self.end_points.keys(): + self.add_module(k, self.end_points[k]) + + def forward(self, x): + for end_point in self.VALID_ENDPOINTS: + if end_point in self.end_points: + x = self._modules[end_point](x) # use _modules to work with dataparallel + + x = self.logits(self.dropout(self.avg_pool(x))) + + if self._spatial_squeeze: + logits = x.squeeze(3).squeeze(3) + # logits is batch X classes X time, which is what we want to work with; NOTE(review): `logits` is only assigned when self._spatial_squeeze is truthy + + logits = torch.mean(logits, dim=2) + return logits + +
def extract_features(self, x): + for end_point in self.VALID_ENDPOINTS: + if end_point in self.end_points: + x = self._modules[end_point](x) + return self.avg_pool(x) + +class PreprocessTrain(object): + """ + Container for all transforms used to preprocess clips for training in this dataset. + """ + def __init__(self, **kwargs): + """ + Initialize preprocessing class for training set + Args: + preprocess (String): Keyword to select different preprocessing types + crop_type (String): 'Random' selects a random crop; any other value selects a center crop + + Return: + None + """ + + self.transforms = [] + self.transforms1 = [] # NOTE: not used elsewhere in this class + self.preprocess = kwargs['preprocess'] + crop_type = kwargs['crop_type'] + + + self.transforms.append(pt.ResizeClip(**kwargs)) + + if crop_type == 'Random': + self.transforms.append(pt.RandomCropClip(**kwargs)) + + else: + self.transforms.append(pt.CenterCropClip(**kwargs)) + + self.transforms.append(pt.SubtractRGBMean(**kwargs)) + self.transforms.append(pt.RandomFlipClip(direction='h', p=0.5, **kwargs)) + self.transforms.append(pt.ToTensorClip(**kwargs)) + + def __call__(self, input_data): + for transform in self.transforms: + input_data = transform(input_data) + + return input_data + + +class PreprocessEval(object): + """ + Container for all transforms used to preprocess clips for evaluation in this dataset.
+ """ + def __init__(self, **kwargs): + """ + Initialize preprocessing class for evaluation set + Args: + preprocess (String): Keyword to select different preprocessing types (not read directly here) + crop_type (String): Not read directly here; CenterCropClip is always applied + + Return: + None + """ + + self.transforms = [] + + self.transforms.append(pt.ResizeClip(**kwargs)) + self.transforms.append(pt.CenterCropClip(**kwargs)) + self.transforms.append(pt.SubtractRGBMean(**kwargs)) + self.transforms.append(pt.ToTensorClip(**kwargs)) + + + def __call__(self, input_data): + for transform in self.transforms: + input_data = transform(input_data) + + return input_data diff --git a/weights/download_weights.sh b/weights/download_weights.sh index 3cbfec6..aef001d 100755 --- a/weights/download_weights.sh +++ b/weights/download_weights.sh @@ -13,3 +13,9 @@ wget -O ./weights/c3d-pretrained.pth https://umich.box.com/shared/static/znmyt8u #C3D Mean wget -O ./weights/sport1m_train16_128_mean.npy https://umich.box.com/shared/static/ppbnldsa5rty615osdjh2yi8fqcx0a3b.npy + +#I3D pretrained on ImageNet and then Kinetics by original authors +wget -O ./weights/i3d_rgb_imagenet.pt https://umich.box.com/shared/static/5m6dwwepzdcw3kjhx7s0peb59lbcde0s.pt + +#I3D pretrained on ImageNet, Kinetics, then on HMDB51 in ViP +wget -O ./weights/i3d_rgb_imagenet_then_HMDB51_30epochs.pkl https://umich.box.com/shared/static/x8x83sw4htidxsxgtus9nt00f383mmm7.pkl From ba7c110fb399474505ce7160d54f707c8ceddced Mon Sep 17 00:00:00 2001 From: Eric Hofesmann Date: Fri, 25 Oct 2019 16:14:31 -0400 Subject: [PATCH 2/2] Ignore .pt files --- .gitignore | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 0e4172d..5dbdb9f 100644 --- a/.gitignore +++ b/.gitignore @@ -13,4 +13,4 @@ runs/* models/HGC3D *.json pbs/* -weights/* +*.pt