@@ -68,6 +68,11 @@ def get_options():
6868 type = 'float' ,
6969 default = config .train ['control_ratio' ])
7070
71+ parser .add_option ('-T' , '--teacher-forcing-ratio' ,
72+ dest = 'teacher_forcing_ratio' ,
73+ type = 'float' ,
74+ default = config .train ['teacher_forcing_ratio' ])
75+
7176 parser .add_option ('-t' , '--use-transposition' ,
7277 dest = 'use_transposition' ,
7378 action = 'store_true' ,
@@ -99,6 +104,7 @@ def get_options():
99104stride_size = options .stride_size
100105use_transposition = options .use_transposition
101106control_ratio = options .control_ratio
107+ teacher_forcing_ratio = options .teacher_forcing_ratio
102108reset_optimizer = options .reset_optimizer
103109
104110event_dim = EventSeq .dim ()
@@ -121,6 +127,7 @@ def get_options():
121127print ('Window size:' , window_size )
122128print ('Stride size:' , stride_size )
123129print ('Control ratio:' , control_ratio )
130+ print ('Teacher forcing ratio:' , teacher_forcing_ratio )
124131print ('Random transposition:' , use_transposition )
125132print ('Reset optimizer:' , reset_optimizer )
126133print ('Device:' , device )
@@ -201,7 +208,8 @@ def save_model():
201208 controls = None
202209
203210 init = torch .randn (batch_size , model .init_dim ).to (device )
204- outputs = model .generate (init , window_size , events [:- 1 ], controls , output_type = 'logit' )
211+ outputs = model .generate (init , window_size , events = events [:- 1 ], controls = controls ,
212+ teacher_forcing_ratio = teacher_forcing_ratio , output_type = 'logit' )
205213 assert outputs .shape [:2 ] == events .shape [:2 ]
206214
207215 loss = loss_function (outputs .view (- 1 , event_dim ), events .view (- 1 ))
0 commit comments