Skip to content
This repository was archived by the owner on Jan 26, 2024. It is now read-only.

Commit 3c3f1b9

Browse files
committed
Add teacher forcing ratio
1 parent ca87729 commit 3c3f1b9

File tree

3 files changed

+14
-4
lines changed

3 files changed

+14
-4
lines changed

config.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,5 +20,6 @@
2020
'window_size': 200,
2121
'stride_size': 10,
2222
'use_transposition': False,
23-
'control_ratio': 1.0
23+
'control_ratio': 1.0,
24+
'teacher_forcing_ratio': 1.0
2425
}

model.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def expand_controls(self, controls, steps):
9898
return controls.repeat(steps, 1, 1)
9999

100100
def generate(self, init, steps, events=None, controls=None, greedy=1.0,
101-
temperature=1.0, output_type='index', verbose=False):
101+
temperature=1.0, teacher_forcing_ratio=1.0, output_type='index', verbose=False):
102102
# init [batch_size, init_dim]
103103
# events [steps, batch_size] indeces
104104
# controls [1 or steps, batch_size, control_dim]
@@ -142,7 +142,8 @@ def generate(self, init, steps, events=None, controls=None, greedy=1.0,
142142
assert False
143143

144144
if use_teacher_forcing and step < steps - 1: # avoid last one
145-
event = events[step].unsqueeze(0)
145+
if np.random.random() <= teacher_forcing_ratio:
146+
event = events[step].unsqueeze(0)
146147

147148
return torch.cat(outputs, 0)
148149

train.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,11 @@ def get_options():
6868
type='float',
6969
default=config.train['control_ratio'])
7070

71+
parser.add_option('-T', '--teacher-forcing-ratio',
72+
dest='teacher_forcing_ratio',
73+
type='float',
74+
default=config.train['teacher_forcing_ratio'])
75+
7176
parser.add_option('-t', '--use-transposition',
7277
dest='use_transposition',
7378
action='store_true',
@@ -99,6 +104,7 @@ def get_options():
99104
stride_size = options.stride_size
100105
use_transposition = options.use_transposition
101106
control_ratio = options.control_ratio
107+
teacher_forcing_ratio = options.teacher_forcing_ratio
102108
reset_optimizer = options.reset_optimizer
103109

104110
event_dim = EventSeq.dim()
@@ -121,6 +127,7 @@ def get_options():
121127
print('Window size:', window_size)
122128
print('Stride size:', stride_size)
123129
print('Control ratio:', control_ratio)
130+
print('Teacher forcing ratio:', teacher_forcing_ratio)
124131
print('Random transposition:', use_transposition)
125132
print('Reset optimizer:', reset_optimizer)
126133
print('Device:', device)
@@ -201,7 +208,8 @@ def save_model():
201208
controls = None
202209

203210
init = torch.randn(batch_size, model.init_dim).to(device)
204-
outputs = model.generate(init, window_size, events[:-1], controls, output_type='logit')
211+
outputs = model.generate(init, window_size, events=events[:-1], controls=controls,
212+
teacher_forcing_ratio=teacher_forcing_ratio, output_type='logit')
205213
assert outputs.shape[:2] == events.shape[:2]
206214

207215
loss = loss_function(outputs.view(-1, event_dim), events.view(-1))

0 commit comments

Comments
 (0)