Skip to content

Commit 8952443

Browse files
committed
upload everything
1 parent f2e462a commit 8952443

File tree

14 files changed

+2215
-1
lines changed

14 files changed

+2215
-1
lines changed

README.md

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,17 @@
1212
Please download the dataset through this [link](https://drive.google.com/file/d/19jRx9WjLviCGpqa_ShXVSnGG0rfr2GeV/view?usp=sharing
1313
).
1414

15-
### Code will be available early October.
15+
### Training
16+
After downloading the dataset and extracting the I3D features using this [**repo**](https://github.com/Tushar-N/pytorch-resnet3d), simply run the following command:
17+
```shell
18+
python main_transformer.py
19+
```
1620

21+
### Inference
22+
For inference, set the path of the best checkpoint, then run the following command:
23+
```shell
24+
python inference.py
25+
```
1726

1827
### Citation
1928

comm.py

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
import pickle
2+
import torch
3+
import torch.distributed as dist
4+
5+
6+
class Comm(object):
    """Lightweight view of torch.distributed process-group state.

    Every accessor degrades gracefully to single-process values
    (world size 1, rank 0) when torch.distributed is unavailable or
    has not been initialized, so the same code path works for both
    single-GPU and distributed runs.
    """

    def __init__(self, local_rank=0):
        # BUG FIX: the original ignored the argument and always stored 0;
        # store the caller-supplied local rank instead.
        self.local_rank = local_rank

    @property
    def world_size(self):
        """Number of processes in the group (1 when not distributed)."""
        if not dist.is_available():
            return 1
        if not dist.is_initialized():
            return 1
        return dist.get_world_size()

    @property
    def rank(self):
        """Global rank of this process (0 when not distributed)."""
        if not dist.is_available():
            return 0
        if not dist.is_initialized():
            return 0
        return dist.get_rank()

    @property
    def local_rank(self):
        """Rank of this process on its node (0 when not distributed)."""
        if not dist.is_available():
            return 0
        if not dist.is_initialized():
            return 0
        return self._local_rank

    @local_rank.setter
    def local_rank(self, value):
        # BUG FIX: the original guards fell through, so the forced 0 was
        # immediately overwritten by `value`; return early instead.
        if not dist.is_available() or not dist.is_initialized():
            self._local_rank = 0
            return
        self._local_rank = value

    @property
    def head(self):
        """Log prefix identifying this process, e.g. 'Rank[0/1]'."""
        return 'Rank[{}/{}]'.format(self.rank, self.world_size)

    def is_main_process(self):
        """Return True only on global rank 0."""
        return self.rank == 0

    def synchronize(self):
        """
        Helper function to synchronize (barrier) among all processes when
        using distributed training
        """
        if self.world_size == 1:
            return
        dist.barrier()


# Module-level singleton used by the rest of the project.
comm = Comm()
60+
61+
62+
def all_gather(data):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors)
    Args:
        data: any picklable object
    Returns:
        list[data]: list of data gathered from each rank
    """
    num_ranks = comm.world_size
    if num_ranks == 1:
        return [data]

    # Serialize the payload into a CUDA byte tensor.
    serialized = pickle.dumps(data)
    storage = torch.ByteStorage.from_buffer(serialized)
    payload = torch.ByteTensor(storage).to("cuda")

    # Exchange payload sizes so every rank learns the maximum.
    local_size = torch.LongTensor([payload.numel()]).to("cuda")
    size_list = [torch.LongTensor([0]).to("cuda") for _ in range(num_ranks)]
    dist.all_gather(size_list, local_size)
    size_list = [int(s.item()) for s in size_list]
    max_size = max(size_list)

    # torch's all_gather requires equal shapes, so pad every payload
    # up to max_size before gathering.
    gathered = [torch.ByteTensor(size=(max_size,)).to("cuda") for _ in size_list]
    if local_size != max_size:
        pad = torch.ByteTensor(size=(max_size - local_size,)).to("cuda")
        payload = torch.cat((payload, pad), dim=0)
    dist.all_gather(gathered, payload)

    # Strip the padding and deserialize each rank's object.
    results = []
    for length, chunk in zip(size_list, gathered):
        raw = chunk.cpu().numpy().tobytes()[:length]
        results.append(pickle.loads(raw))

    return results
103+
104+
105+
def reduce_dict(input_dict, average=True):
    """
    Args:
        input_dict (dict): all the values will be reduced
        average (bool): whether to do average or sum
    Reduce the values in the dictionary from all processes so that process with rank
    0 has the averaged results. Returns a dict with the same fields as
    input_dict, after reduction.
    """
    world_size = comm.world_size
    if world_size < 2:
        return input_dict
    with torch.no_grad():
        # Sort keys so every process stacks the values in the same order.
        keys = sorted(input_dict.keys())
        stacked = torch.stack([input_dict[k] for k in keys], dim=0)
        dist.reduce(stacked, dst=0)
        if average and dist.get_rank() == 0:
            # Only the main process holds the accumulated sum, so only
            # it divides by world_size.
            stacked /= world_size
        return dict(zip(keys, stacked))
132+

config.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import numpy as np
2+
import os
3+
4+
class Config(object):
    """Run configuration derived from parsed command-line arguments."""

    def __init__(self, args):
        # NOTE(review): eval() on a CLI string executes arbitrary code;
        # consider ast.literal_eval if plain literals suffice.
        self.lr = eval(args.lr)
        # Keep the raw string form for logging / reproducibility.
        self.lr_str = args.lr

    def __str__(self):
        """Render every attribute except the evaluated `lr` as '- name: value' lines."""
        fields = vars(self)
        lines = ["- %s: %s" % (name, fields[name])
                 for name in sorted(fields.keys()) if name != 'lr']
        return '\n'.join(lines)
13+

config_cvt.py

Lines changed: 207 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,207 @@
1+
from __future__ import absolute_import
2+
from __future__ import division
3+
from __future__ import print_function
4+
5+
import os.path as op
6+
import yaml
7+
from yacs.config import CfgNode as CN
8+
9+
from comm import comm
10+
11+
12+
# Default configuration tree (yacs). Values here are overridden by the
# YAML file and CLI options via update_config() below.
_C = CN()

# general / runtime options
_C.BASE = ['']
_C.NAME = ''
_C.DATA_DIR = ''
_C.DIST_BACKEND = 'nccl'
_C.GPUS = (0,)
# _C.LOG_DIR = ''
_C.MULTIPROCESSING_DISTRIBUTED = True
_C.OUTPUT_DIR = ''
_C.PIN_MEMORY = True
_C.PRINT_FREQ = 20
_C.RANK = 0
_C.VERBOSE = True
_C.WORKERS = 4
_C.MODEL_SUMMARY = False

# automatic mixed precision
_C.AMP = CN()
_C.AMP.ENABLED = False
_C.AMP.MEMORY_FORMAT = 'nchw'

# Cudnn related params
_C.CUDNN = CN()
_C.CUDNN.BENCHMARK = True
_C.CUDNN.DETERMINISTIC = False
_C.CUDNN.ENABLED = True

# common params for NETWORK
_C.MODEL = CN()
_C.MODEL.NAME = 'cls_hrnet'
_C.MODEL.INIT_WEIGHTS = True
_C.MODEL.PRETRAINED = ''
_C.MODEL.PRETRAINED_LAYERS = ['*']
_C.MODEL.NUM_CLASSES = 1000
_C.MODEL.SPEC = CN(new_allowed=True)


# loss
_C.LOSS = CN(new_allowed=True)
_C.LOSS.LABEL_SMOOTHING = 0.0
_C.LOSS.LOSS = 'softmax'

# DATASET related params
_C.DATASET = CN()
_C.DATASET.ROOT = ''
_C.DATASET.DATASET = 'imagenet'
_C.DATASET.TRAIN_SET = 'train'
_C.DATASET.TEST_SET = 'val'
_C.DATASET.DATA_FORMAT = 'jpg'
_C.DATASET.LABELMAP = ''
_C.DATASET.TRAIN_TSV_LIST = []
_C.DATASET.TEST_TSV_LIST = []
_C.DATASET.SAMPLER = 'default'

_C.DATASET.TARGET_SIZE = -1

# training data augmentation
_C.INPUT = CN()
_C.INPUT.MEAN = [0.485, 0.456, 0.406]
_C.INPUT.STD = [0.229, 0.224, 0.225]

# data augmentation
_C.AUG = CN()
_C.AUG.SCALE = (0.08, 1.0)
_C.AUG.RATIO = (3.0/4.0, 4.0/3.0)
_C.AUG.COLOR_JITTER = [0.4, 0.4, 0.4, 0.1, 0.0]
_C.AUG.GRAY_SCALE = 0.0
_C.AUG.GAUSSIAN_BLUR = 0.0
_C.AUG.DROPBLOCK_LAYERS = [3, 4]
_C.AUG.DROPBLOCK_KEEP_PROB = 1.0
_C.AUG.DROPBLOCK_BLOCK_SIZE = 7
_C.AUG.MIXUP_PROB = 0.0
_C.AUG.MIXUP = 0.0
_C.AUG.MIXCUT = 0.0
_C.AUG.MIXCUT_MINMAX = []
_C.AUG.MIXUP_SWITCH_PROB = 0.5
_C.AUG.MIXUP_MODE = 'batch'
_C.AUG.MIXCUT_AND_MIXUP = False
_C.AUG.INTERPOLATION = 2
_C.AUG.TIMM_AUG = CN(new_allowed=True)
_C.AUG.TIMM_AUG.USE_LOADER = False
_C.AUG.TIMM_AUG.USE_TRANSFORM = False

# train
_C.TRAIN = CN()

_C.TRAIN.AUTO_RESUME = True
_C.TRAIN.CHECKPOINT = ''
_C.TRAIN.LR_SCHEDULER = CN(new_allowed=True)
# SCALE_LR multiplies TRAIN.LR by the world size in update_config().
_C.TRAIN.SCALE_LR = True
_C.TRAIN.LR = 0.001

_C.TRAIN.OPTIMIZER = 'sgd'
_C.TRAIN.OPTIMIZER_ARGS = CN(new_allowed=True)
_C.TRAIN.MOMENTUM = 0.9
_C.TRAIN.WD = 0.0001
_C.TRAIN.WITHOUT_WD_LIST = []
_C.TRAIN.NESTEROV = True
# for adam
_C.TRAIN.GAMMA1 = 0.99
_C.TRAIN.GAMMA2 = 0.0

_C.TRAIN.BEGIN_EPOCH = 0
_C.TRAIN.END_EPOCH = 100

_C.TRAIN.IMAGE_SIZE = [224, 224]  # width * height, ex: 192 * 256
_C.TRAIN.BATCH_SIZE_PER_GPU = 32
_C.TRAIN.SHUFFLE = True

_C.TRAIN.EVAL_BEGIN_EPOCH = 0

_C.TRAIN.DETECT_ANOMALY = False

_C.TRAIN.CLIP_GRAD_NORM = 0.0
_C.TRAIN.SAVE_ALL_MODELS = False

# testing
_C.TEST = CN()

# size of images for each device
_C.TEST.BATCH_SIZE_PER_GPU = 32
_C.TEST.CENTER_CROP = True
_C.TEST.IMAGE_SIZE = [224, 224]  # width * height, ex: 192 * 256
_C.TEST.INTERPOLATION = 2
_C.TEST.MODEL_FILE = ''
_C.TEST.REAL_LABELS = False
_C.TEST.VALID_LABELS = ''

# fine-tuning from a pretrained checkpoint
_C.FINETUNE = CN()
_C.FINETUNE.FINETUNE = False
_C.FINETUNE.USE_TRAIN_AUG = False
_C.FINETUNE.BASE_LR = 0.003
_C.FINETUNE.BATCH_SIZE = 512
_C.FINETUNE.EVAL_EVERY = 3000
_C.FINETUNE.TRAIN_MODE = True
# _C.FINETUNE.MODEL_FILE = ''
_C.FINETUNE.FROZEN_LAYERS = []
_C.FINETUNE.LR_SCHEDULER = CN(new_allowed=True)
_C.FINETUNE.LR_SCHEDULER.DECAY_TYPE = 'step'

# debug
_C.DEBUG = CN()
_C.DEBUG.DEBUG = False
157+
158+
159+
def _update_config_from_file(config, cfg_file):
    """Recursively merge a YAML config file (and its BASE parents) into config.

    Parent files listed under the 'BASE' key are merged first, so the
    child file's own values override them.
    """
    config.defrost()
    with open(cfg_file, 'r') as f:
        yaml_cfg = yaml.load(f, Loader=yaml.FullLoader)

    for base in yaml_cfg.setdefault('BASE', ['']):
        if base:
            # BASE paths are relative to the current config file.
            _update_config_from_file(config, op.join(op.dirname(cfg_file), base))
    print('=> merge config from {}'.format(cfg_file))
    config.merge_from_file(cfg_file)
    config.freeze()
172+
173+
174+
def update_config(config, args):
    """Finalize config: merge the YAML file from args.cfg, then CLI overrides.

    Also scales the learning rate by the world size when TRAIN.SCALE_LR
    is set, prefixes NAME with the config file's basename, records the
    rank, and propagates shared values into the timm sub-configs.
    """
    _update_config_from_file(config, args.cfg)

    config.defrost()
    config.merge_from_list(args.opts)

    # Linear LR scaling rule across distributed workers.
    if config.TRAIN.SCALE_LR:
        config.TRAIN.LR *= comm.world_size

    stem, _ = op.splitext(op.basename(args.cfg))
    config.NAME = stem + config.NAME
    config.RANK = comm.rank

    # Propagate shared values into the timm-specific sub-configs.
    if config.TRAIN.LR_SCHEDULER.METHOD == 'timm':
        config.TRAIN.LR_SCHEDULER.ARGS.epochs = config.TRAIN.END_EPOCH
    if config.TRAIN.OPTIMIZER == 'timm':
        config.TRAIN.OPTIMIZER_ARGS.lr = config.TRAIN.LR

    # Any mixup/cutmix option implies the mixup transform must fire.
    aug_cfg = config.AUG
    if aug_cfg.MIXUP > 0.0 or aug_cfg.MIXCUT > 0.0 or aug_cfg.MIXCUT_MINMAX:
        aug_cfg.MIXUP_PROB = 1.0

    config.freeze()
195+
196+
197+
def save_config(cfg, path):
    """Write the serialized config to *path*; only rank 0 writes."""
    if not comm.is_main_process():
        return
    with open(path, 'w') as f:
        f.write(cfg.dump())
201+
202+
203+
if __name__ == '__main__':
    # Dump the default configuration to the file named on the command line.
    import sys

    with open(sys.argv[1], 'w') as out:
        print(_C, file=out)
207+

0 commit comments

Comments
 (0)