from __future__ import print_function

import argparse
import os
import random
import time
import warnings

# Pin the visible GPU before any CUDA context can be created.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import accuracy_score, classification_report
from torch.autograd import Variable
from torch.optim.lr_scheduler import StepLR

from dataloader import *
from models_weakly import *

warnings.filterwarnings("ignore")
random.seed(3344)
19+
parser = argparse.ArgumentParser(description='AVE')

# All ten data-path options share type=str and the same help text, so
# register them from one table: (flag, default HDF5 path).
for _flag, _default in [
    ('--dir_video', 'data/visual_feature.h5'),
    ('--dir_video_bg', 'data/video_feature_noisy.h5'),
    ('--dir_audio', 'data/audio_feature.h5'),
    ('--dir_audio_bg', 'data/audio_feature_noisy.h5'),
    ('--dir_labels', 'data/mil_labels.h5'),
    ('--dir_labels_bg', 'data/labels_noisy.h5'),
    ('--dir_labels_gt', 'data/labels.h5'),
    ('--dir_order_train', 'data/train_order.h5'),
    ('--dir_order_val', 'data/val_order.h5'),
    ('--dir_order_test', 'data/test_order.h5'),
]:
    parser.add_argument(_flag, type=str, default=_default,
                        help='dataset directory')

# Training hyper-parameters and mode switch.
parser.add_argument('--nb_epoch', type=int, default=250,
                    help='number of epoch')
parser.add_argument('--batch_size', type=int, default=64,
                    help='number of batch size')
parser.add_argument('--train', action='store_true', default=False,
                    help='train a new model')

args = parser.parse_args()
59+
# model
model_name = 'AV_att_weak'
# Attention network over audio/video features; the final argument (29)
# matches the 29 class columns read in val()/test().
net_model = att_Net(128, 128, 512, 29)
net_model.cuda()  # fix: the original called .cuda() twice in a row; once suffices

# Multi-label loss over the weak (video-level) labels.
loss_function = nn.MultiLabelSoftMarginLoss()
optimizer = optim.Adam(net_model.parameters(), lr=1e-3)
# step_size=15000 is in *scheduler steps*; train() steps it once per batch,
# so the LR decays 10x every 15000 iterations, not epochs.
scheduler = StepLR(optimizer, step_size=15000, gamma=0.1)
69+
70+
def train(args):
    """Train `net_model` on the weakly-labelled AVE training set.

    Iterates `args.nb_epoch` epochs over `AVE_weak_Dataset`, and every 5
    epochs runs `val(args)`; the checkpoint is written to
    saved_models/<model_name>.pt only when validation accuracy improves.
    """
    AVEData = AVE_weak_Dataset(video_dir=args.dir_video, video_dir_bg=args.dir_video_bg,
                               audio_dir=args.dir_audio, audio_dir_bg=args.dir_audio_bg,
                               label_dir=args.dir_labels, label_dir_bg=args.dir_labels_bg,
                               label_dir_gt=args.dir_labels_gt,
                               order_dir=args.dir_order_train,
                               batch_size=args.batch_size, status="train")
    nb_batch = len(AVEData) // args.batch_size
    print(len(AVEData))
    epoch_l = []
    best_val_acc = 0
    for epoch in range(args.nb_epoch):
        # Fix: val() switches the model to eval mode and the original never
        # switched it back, so all epochs after the first validation ran in
        # eval mode (dropout/BN frozen).
        net_model.train()
        epoch_loss = 0
        n = 0
        start = time.time()
        for i in range(nb_batch):
            audio_inputs, video_inputs, labels = AVEData.get_batch(i)
            audio_inputs = Variable(audio_inputs.cuda(), requires_grad=False)
            video_inputs = Variable(video_inputs.cuda(), requires_grad=False)
            labels = Variable(labels.cuda(), requires_grad=False)
            net_model.zero_grad()
            scores, _ = net_model(audio_inputs, video_inputs)
            loss = loss_function(scores, labels)
            epoch_loss += loss.cpu().data.numpy()
            loss.backward()
            optimizer.step()
            # Fix: in PyTorch >= 1.1 the scheduler must step *after* the
            # optimizer; the original stepped it before, skipping the first
            # LR value.
            scheduler.step()
            n = n + 1

        end = time.time()
        epoch_l.append(epoch_loss)
        print("=== Epoch {%s} Loss: {%.4f} Running time: {%4f}"
              % (str(epoch), (epoch_loss) / n, end - start))
        if epoch % 5 == 0:
            val_acc = val(args)
            if val_acc > best_val_acc:
                # Fix: best_val_acc was never updated, so any positive
                # accuracy overwrote the checkpoint even when worse than a
                # previous one.
                best_val_acc = val_acc
                torch.save(net_model, 'saved_models/' + model_name + ".pt")
105+
106+
107+
def val(args):
    """Evaluate `net_model` on the validation split.

    Loads the whole split as one batch (batch_size=402), predicts a class
    per temporal segment (x_labels.shape[1] segments, 10 in this dataset),
    and returns segment-level accuracy against the ground-truth labels.
    """
    net_model.eval()
    AVEData = AVE_weak_Dataset(video_dir=args.dir_video, video_dir_bg=args.dir_video_bg,
                               audio_dir=args.dir_audio, audio_dir_bg=args.dir_audio_bg,
                               label_dir=args.dir_labels, label_dir_bg=args.dir_labels_bg,
                               label_dir_gt=args.dir_labels_gt,
                               order_dir=args.dir_order_val, batch_size=402, status="val")
    nb_batch = len(AVEData)
    audio_inputs, video_inputs, labels = AVEData.get_batch(0)
    audio_inputs = Variable(audio_inputs.cuda(), requires_grad=False)
    video_inputs = Variable(video_inputs.cuda(), requires_grad=False)
    labels = labels.numpy()
    _, x_labels = net_model(audio_inputs, video_inputs)
    x_labels = x_labels.cpu().data.numpy()

    # Flatten per-segment argmax predictions and targets.
    N = int(nb_batch * 10)
    pre_labels = np.zeros(N)
    real_labels = np.zeros(N)
    c = 0
    for i in range(nb_batch):
        for j in range(x_labels.shape[1]):  # 10 segments per video
            pre_labels[c] = np.argmax(x_labels[i, j, :])
            real_labels[c] = np.argmax(labels[i, j, :])
            c += 1
    # Fix: compute the score once (the original called accuracy_score twice)
    # and drop the unused target_names list.
    acc = accuracy_score(real_labels, pre_labels)
    print(acc)
    return acc
136+
137+
def test(args):
    """Evaluate a saved checkpoint on the test split.

    Mirrors val(): one full-split batch, per-segment argmax predictions,
    prints and returns segment-level accuracy.
    """
    # Fix: train() saves to 'saved_models/' but this loaded from 'model/',
    # so the freshly trained checkpoint could never be found.
    model = torch.load('saved_models/' + model_name + ".pt")
    model.eval()
    AVEData = AVE_weak_Dataset(video_dir=args.dir_video, video_dir_bg=args.dir_video_bg,
                               audio_dir=args.dir_audio, audio_dir_bg=args.dir_audio_bg,
                               label_dir=args.dir_labels, label_dir_bg=args.dir_labels_bg,
                               label_dir_gt=args.dir_labels_gt,
                               order_dir=args.dir_order_test, batch_size=402, status="test")
    nb_batch = len(AVEData)
    print(nb_batch)
    audio_inputs, video_inputs, labels = AVEData.get_batch(0)
    audio_inputs = Variable(audio_inputs.cuda(), requires_grad=False)
    video_inputs = Variable(video_inputs.cuda(), requires_grad=False)
    labels = labels.numpy()
    _, x_labels = model(audio_inputs, video_inputs)
    x_labels = x_labels.cpu().data.numpy()

    # Flatten per-segment argmax predictions and targets.
    N = int(nb_batch * 10)
    pre_labels = np.zeros(N)
    real_labels = np.zeros(N)
    c = 0
    for i in range(nb_batch):
        for j in range(x_labels.shape[1]):  # 10 segments per video
            pre_labels[c] = np.argmax(x_labels[i, j, :])
            real_labels[c] = np.argmax(labels[i, j, :])
            c += 1
    acc = accuracy_score(real_labels, pre_labels)
    print(acc)
    # Returned for symmetry with val(); existing callers that ignore the
    # return value are unaffected.
    return acc
168+
169+
# Entry point: --train trains a new model; otherwise evaluate a checkpoint.
(train if args.train else test)(args)