Skip to content

Commit 47d0159

Browse files
committed
Added Datasets
1 parent e3292c5 commit 47d0159

File tree

8 files changed

+122
-661
lines changed

8 files changed

+122
-661
lines changed

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@ __pycache__
22
results
33

44
scratch.ipynb
5-
slurm*.out
5+
*/slurm*.out

dataset.py

Lines changed: 113 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ def descendant_dataset(p, num, seed=0, device='cpu'):
123123
np.random.seed(seed)
124124

125125
N_sample = num
126-
x = np.random.choice(range(1,p), N_sample*2).reshape(N_sample, 2)
126+
x = np.random.choice(range(2,p), N_sample*2).reshape(N_sample, 2)
127127

128128
# Check if b is a descendant of a
129129
# In a complete binary tree where two children of x is 2x and 2x+1
@@ -133,7 +133,7 @@ def is_desc(a, b):
133133
return True
134134
b //= 2 # Move up to the parent node
135135
return b == a
136-
target = np.array([(p+1) if is_desc(x[i,0], x[i,1]) else p for i in range(N_sample)])
136+
target = np.array([1 if is_desc(x[i,0]-1, x[i,1]-1) else 0 for i in range(N_sample)])
137137

138138
data_id = torch.from_numpy(x).to(device)
139139
labels = torch.from_numpy(target).to(device)
@@ -145,4 +145,115 @@ def is_desc(a, b):
145145
dataset['label'] = labels
146146
dataset['vocab_size'] = vocab_size
147147

148+
return dataset
149+
150+
def descendant_dataset_2(p, num, seed=0, device='cpu'):
    """Build an analogy-style dataset over a complete binary tree.

    Each example is three tokens (a, b, c) with a single target token d,
    arranged so that a : b :: c : d under the parent/child relation
    (children of node x are 2x and 2x+1):

      * forward:  (x, 2x,   y) -> 2y      and  (x, 2x+1, y) -> 2y+1
      * reversed: (2x, x,  2y) -> y       and  (2x+1, x, 2y+1) -> y

    Args:
        p: token bound; parent ids are drawn from [1, (p-1)//2) so every
           child id 2x+1 stays strictly below p.
        num: number of base parent pairs; the dataset has 4*num rows.
        seed: RNG seed for both torch and numpy.
        device: torch device for the returned tensors.

    Returns:
        dict with 'data_id' (4*num, 3) int tensor, 'label' (4*num,) int
        tensor, and 'vocab_size' = p + 1.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)

    N_sample = num * 4
    x = np.random.choice(range(1, (p - 1) // 2), num * 2).reshape(num, 2)

    data = np.zeros((N_sample, 4), dtype=np.int32)

    # Forward analogies: parent -> left child.
    data[:num, 0] = x[:, 0]
    data[:num, 1] = 2 * x[:, 0]
    data[:num, 2] = x[:, 1]
    data[:num, 3] = 2 * x[:, 1]

    # Forward analogies: parent -> right child.
    data[num:(2 * num), 0] = x[:, 0]
    data[num:(2 * num), 1] = 2 * x[:, 0] + 1
    data[num:(2 * num), 2] = x[:, 1]
    data[num:(2 * num), 3] = 2 * x[:, 1]  + 1

    # Reversed analogies: left child -> parent.
    # BUGFIX: this slice previously duplicated the right-child slice below
    # (2*x + 1 everywhere), so left-child reversals were never generated
    # and a quarter of the dataset was duplicate rows.
    data[2 * num:(3 * num), 0] = 2 * x[:, 0]
    data[2 * num:(3 * num), 1] = x[:, 0]
    data[2 * num:(3 * num), 2] = 2 * x[:, 1]
    data[2 * num:(3 * num), 3] = x[:, 1]

    # Reversed analogies: right child -> parent.
    data[3 * num:(4 * num), 0] = 2 * x[:, 0] + 1
    data[3 * num:(4 * num), 1] = x[:, 0]
    data[3 * num:(4 * num), 2] = 2 * x[:, 1] + 1
    data[3 * num:(4 * num), 3] = x[:, 1]

    # Shuffle rows so the four analogy types are interleaved.
    np.random.shuffle(data)

    data_id = torch.from_numpy(data[:, :3]).to(device)
    labels = torch.from_numpy(data[:, 3]).to(device)

    vocab_size = p + 1

    dataset = {}
    dataset['data_id'] = data_id
    dataset['label'] = labels
    dataset['vocab_size'] = vocab_size

    return dataset
192+
193+
194+
def greater_than_dataset(p, num, seed=0, device='cpu'):
    """Binary comparison task over token pairs.

    For a pair (a, b) drawn uniformly from [0, p), the label is the
    reserved token p+1 when a > b, and p otherwise.

    Args:
        p: operands are sampled from [0, p); p and p+1 are class tokens.
        num: number of samples.
        seed: RNG seed for both torch and numpy.
        device: torch device for the returned tensors.

    Returns:
        dict with 'data_id' (num, 2), 'label' (num,), and
        'vocab_size' = p + 2.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)

    pairs = np.random.choice(range(p), num * 2).reshape(num, 2)

    # Vectorized equivalent of the per-row comparison loop.
    target = np.where(pairs[:, 0] > pairs[:, 1], p + 1, p)

    dataset = {
        'data_id': torch.from_numpy(pairs).to(device),
        'label': torch.from_numpy(target).to(device),
        'vocab_size': p + 2,
    }
    return dataset
215+
216+
217+
def xor_dataset(p, num, seed=0, device='cpu'):
    """Bitwise-XOR task: the target for a pair (a, b) is a ^ b.

    Args:
        p: operands are drawn uniformly from [0, p).
        num: number of samples.
        seed: RNG seed for both torch and numpy.
        device: torch device for the returned tensors.

    Returns:
        dict with 'data_id' (num, 2), 'label' (num,), and 'vocab_size'.

    Note:
        When p is not a power of two, a ^ b can exceed p + 1 (e.g. p=10:
        9 ^ 6 = 15), so a classifier head of the old fixed size p + 2
        would be indexed out of range by some labels. vocab_size now
        covers the full XOR range; for power-of-two p it is unchanged
        (still p + 2), so existing configurations behave identically.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)

    N_sample = num
    x = np.random.choice(range(p), N_sample * 2).reshape(N_sample, 2)

    # Element-wise XOR, vectorized (same values as a per-row loop).
    target = x[:, 0] ^ x[:, 1]

    data_id = torch.from_numpy(x).to(device)
    labels = torch.from_numpy(target).to(device)

    # Smallest power of two covering every possible XOR value, but never
    # below the historical p + 2 (which is already sufficient — and kept
    # for backward compatibility — when p is a power of two).
    vocab_size = max(p + 2, 1 << (p - 1).bit_length())

    dataset = {}
    dataset['data_id'] = data_id
    dataset['label'] = labels
    dataset['vocab_size'] = vocab_size

    return dataset
238+
239+
def multi_step_dataset(p, num, seed=0, device='cpu'):
    """Composed-arithmetic task: target of a triple (a, b, c) is (a*b + c) mod p.

    Args:
        p: operands are drawn uniformly from [0, p); targets are also in
           [0, p), so vocab_size is exactly p.
        num: number of samples.
        seed: RNG seed for both torch and numpy.
        device: torch device for the returned tensors.

    Returns:
        dict with 'data_id' (num, 3), 'label' (num,), and 'vocab_size' = p.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)

    triples = np.random.choice(range(p), num * 3).reshape(num, 3)

    # Vectorized form of (a*b + c) % p applied row by row.
    target = (triples[:, 0] * triples[:, 1] + triples[:, 2]) % p

    dataset = {
        'data_id': torch.from_numpy(triples).to(device),
        'label': torch.from_numpy(target).to(device),
        'vocab_size': p,
    }
    return dataset

model.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,7 @@ def __init__(self, vocab_size, d_model, nhead, num_layers, seq_len = 16, use_dis
155155
if use_dist_layer:
156156
self.dist = DistLayer(d_model, vocab_size, n=1., eps=1e-4, bias=False)
157157
self.fc = nn.Linear(d_model, vocab_size)
158+
self.vocab_size = vocab_size
158159

159160
def forward(self, x):
160161
embedded = self.embedding(x) + self.positional_encoding
@@ -179,16 +180,21 @@ def train(self, param_dict: dict):
179180
learning_rate = param_dict['learning_rate']
180181
dataloader = param_dict['dataloader']
181182
device = param_dict['device']
182-
criterion = nn.CrossEntropyLoss()
183+
183184

184185
optimizer = optim.AdamW(self.parameters(), lr=learning_rate)
185186
for epoch in tqdm(range(num_epochs)):
186187
total_loss = 0
187188
for batch_inputs, batch_targets in dataloader:
188189
batch_inputs = batch_inputs.to(device)
189-
batch_targets = batch_targets.to(device)
190+
batch_targets = batch_targets.type(torch.LongTensor).to(device)
190191
optimizer.zero_grad()
191192
logits = self.forward(batch_inputs)
193+
194+
# class_counts = torch.bincount(batch_targets.squeeze(), minlength=self.vocab_size).double() + 1e-8
195+
# class_weights = 1 / class_counts.cuda()
196+
197+
criterion = nn.CrossEntropyLoss()#weight=class_weights)
192198

193199
loss = criterion(logits, batch_targets.squeeze())
194200
loss.backward()

0 commit comments

Comments
 (0)