@@ -123,7 +123,7 @@ def descendant_dataset(p, num, seed=0, device='cpu'):
123123 np .random .seed (seed )
124124
125125 N_sample = num
126- x = np .random .choice (range (1 ,p ), N_sample * 2 ).reshape (N_sample , 2 )
126+ x = np .random .choice (range (2 ,p ), N_sample * 2 ).reshape (N_sample , 2 )
127127
128128 # Check if b is a descendant of a
129129 # In a complete binary tree where two children of x is 2x and 2x+1
@@ -133,7 +133,7 @@ def is_desc(a, b):
133133 return True
134134 b //= 2 # Move up to the parent node
135135 return b == a
136- target = np .array ([( p + 1 ) if is_desc (x [i ,0 ], x [i ,1 ]) else p for i in range (N_sample )])
136+ target = np .array ([1 if is_desc (x [i ,0 ]- 1 , x [i ,1 ]- 1 ) else 0 for i in range (N_sample )])
137137
138138 data_id = torch .from_numpy (x ).to (device )
139139 labels = torch .from_numpy (target ).to (device )
@@ -145,4 +145,115 @@ def is_desc(a, b):
145145 dataset ['label' ] = labels
146146 dataset ['vocab_size' ] = vocab_size
147147
148+ return dataset
149+
def descendant_dataset_2(p, num, seed=0, device='cpu'):
    """Build a parent/child analogy dataset on a heap-numbered binary tree.

    Node ``n`` has children ``2n`` and ``2n + 1``.  For every sampled pair of
    base nodes, four examples of the form ``(s0, r(s0), s1) -> r(s1)`` are
    produced, one per relation ``r``: left child, right child, parent of a
    left child, and parent of a right child.

    Args:
        p: Size parameter; base nodes are drawn from ``range(1, (p - 1) // 2)``.
        num: Number of sampled node pairs; the dataset holds ``4 * num`` rows.
        seed: Seed applied to both the torch and NumPy global RNGs.
        device: Device the returned tensors are moved to.

    Returns:
        dict with keys
        ``'data_id'``: int32 tensor of shape ``(4 * num, 3)`` (the three inputs),
        ``'label'``: int32 tensor of shape ``(4 * num,)`` (the fourth node),
        ``'vocab_size'``: ``p + 1``.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)

    base = np.random.choice(range(1, (p - 1) // 2), num * 2).reshape(num, 2)
    left = 2 * base
    right = 2 * base + 1

    # One (source, image) pair per relation.  The original code filled the
    # third and fourth quarters with the same "right child -> parent" rows,
    # so "left child -> parent" was never generated; it now fills the third.
    relations = (
        (base, left),   # node -> left child
        (base, right),  # node -> right child
        (left, base),   # left child -> parent
        (right, base),  # right child -> parent
    )

    data = np.zeros((num * 4, 4), dtype=np.int32)
    for k, (src, dst) in enumerate(relations):
        rows = slice(k * num, (k + 1) * num)
        data[rows, 0] = src[:, 0]
        data[rows, 1] = dst[:, 0]
        data[rows, 2] = src[:, 1]
        data[rows, 3] = dst[:, 1]

    # Shuffles rows in place so the four relations are interleaved.
    np.random.shuffle(data)

    dataset = {
        'data_id': torch.from_numpy(data[:, :3]).to(device),
        'label': torch.from_numpy(data[:, 3]).to(device),
        'vocab_size': p + 1,
    }
    return dataset
192+
193+
def greater_than_dataset(p, num, seed=0, device='cpu'):
    """Build a pairwise comparison dataset over the integers ``0 .. p - 1``.

    Args:
        p: Exclusive upper bound of the sampled integers.
        num: Number of samples (rows) to generate.
        seed: Seed applied to both the torch and NumPy global RNGs.
        device: Device the returned tensors are moved to.

    Returns:
        dict with keys
        ``'data_id'``: tensor of shape ``(num, 2)`` holding the sampled pairs,
        ``'label'``: ``p + 1`` where the first entry is strictly greater than
        the second, otherwise ``p``,
        ``'vocab_size'``: ``p + 2``.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)

    pairs = np.random.choice(range(p), num * 2).reshape(num, 2)

    # The two answer tokens sit right after the number tokens: p = "no", p + 1 = "yes".
    no_token, yes_token = p, p + 1
    answers = np.array([yes_token if first > second else no_token
                        for first, second in pairs])

    return {
        'data_id': torch.from_numpy(pairs).to(device),
        'label': torch.from_numpy(answers).to(device),
        'vocab_size': p + 2,
    }
215+
216+
def xor_dataset(p, num, seed=0, device='cpu'):
    """Build a bitwise-XOR dataset over the integers ``0 .. p - 1``.

    Args:
        p: Exclusive upper bound of the sampled integers.
        num: Number of samples (rows) to generate.
        seed: Seed applied to both the torch and NumPy global RNGs.
        device: Device the returned tensors are moved to.

    Returns:
        dict with keys
        ``'data_id'``: tensor of shape ``(num, 2)`` holding the sampled pairs,
        ``'label'``: elementwise XOR of the two columns,
        ``'vocab_size'``: ``p + 2``, raised to the next power of two when
        ``p + 2`` would not cover every possible XOR result.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)

    x = np.random.choice(range(p), num * 2).reshape(num, 2)

    # Vectorised XOR of the two columns (same values as a per-row loop).
    target = x[:, 0] ^ x[:, 1]

    # XOR of two values below p can reach the next power of two minus one
    # (e.g. p = 10: 8 ^ 7 = 15), so p + 2 alone can be smaller than a label.
    # Keep p + 2 whenever it already covers the full XOR range.
    xor_range = 1 << (int(p) - 1).bit_length()
    vocab_size = max(p + 2, xor_range)

    dataset = {
        'data_id': torch.from_numpy(x).to(device),
        'label': torch.from_numpy(target).to(device),
        'vocab_size': vocab_size,
    }
    return dataset
238+
def multi_step_dataset(p, num, seed=0, device='cpu'):
    """Build a two-step modular arithmetic dataset: ``(a * b + c) mod p``.

    Args:
        p: Modulus; operands are sampled from ``0 .. p - 1``.
        num: Number of samples (rows) to generate.
        seed: Seed applied to both the torch and NumPy global RNGs.
        device: Device the returned tensors are moved to.

    Returns:
        dict with keys
        ``'data_id'``: tensor of shape ``(num, 3)`` holding the operands,
        ``'label'``: ``(a * b + c) % p`` for each operand row ``(a, b, c)``,
        ``'vocab_size'``: ``p``.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)

    operands = np.random.choice(range(p), num * 3).reshape(num, 3)

    # Multiply the first two operands, add the third, then reduce mod p.
    results = np.array([(a * b + c) % p for a, b, c in operands])

    return {
        'data_id': torch.from_numpy(operands).to(device),
        'label': torch.from_numpy(results).to(device),
        'vocab_size': p,
    }
0 commit comments