@@ -67,10 +67,10 @@ def train_single_model(param_dict: dict):
     elif data_id == "greater":
         dataset = greater_than_dataset(p=30, num=data_size, seed=seed, device=device)
     elif data_id == "family_tree":
-        dataset = family_tree_dataset_2(p=127, num=data_size, seed=seed, device=device)
+        dataset = family_tree_dataset_2(p=255, num=data_size, seed=seed, device=device)
     elif data_id == "equivalence":
         input_token = 1
-        dataset = mod_classification_dataset(p=100, num=data_size, seed=seed, device=device)
+        dataset = mod_equiv_dataset(p=50, num=data_size, seed=seed, device=device)
     elif data_id == "circle":
         dataset = modular_addition_dataset(p=31, num=data_size, seed=seed, device=device)
     elif data_id == "permutation":
@@ -87,15 +87,15 @@ def train_single_model(param_dict: dict):
         weight_tied = True
         hidden_size = 100
         shp = [input_token * embd_dim, hidden_size, embd_dim, vocab_size]
-        model = MLP_HS(shp=shp, vocab_size=vocab_size, embd_dim=embd_dim, input_token=input_token, weight_tied=weight_tied, seed=seed, n=np.sqrt(embd_dim), init_scale=1).to(device)
+        model = MLP_HS(shp=shp, vocab_size=vocab_size, embd_dim=embd_dim, input_token=input_token, weight_tied=weight_tied, seed=seed, n=embd_dim, init_scale=1).to(device)
     elif model_id == "standard_MLP":
         unembd = True
         weight_tied = True
         hidden_size = 100
         shp = [input_token * embd_dim, hidden_size, embd_dim, vocab_size]
         model = MLP(shp=shp, vocab_size=vocab_size, embd_dim=embd_dim, input_token=input_token, unembd=unembd, weight_tied=weight_tied, seed=seed, init_scale=1).to(device)
     elif model_id == "H_transformer":
-        model = ToyTransformer(vocab_size=vocab_size, d_model=embd_dim, nhead=2, num_layers=2, n_dist=np.sqrt(embd_dim), seq_len=input_token, seed=seed, use_dist_layer=True, init_scale=1).to(device)
+        model = ToyTransformer(vocab_size=vocab_size, d_model=embd_dim, nhead=2, num_layers=2, n_dist=embd_dim, seq_len=input_token, seed=seed, use_dist_layer=True, init_scale=1).to(device)
     elif model_id == "standard_transformer":
         model = ToyTransformer(vocab_size=vocab_size, d_model=embd_dim, nhead=2, num_layers=2, seq_len=input_token, seed=seed, use_dist_layer=False, init_scale=1).to(device)
     else: