-
Notifications
You must be signed in to change notification settings - Fork 8
Expand file tree
/
Copy pathtrain_probe.py
More file actions
157 lines (111 loc) · 5.09 KB
/
train_probe.py
File metadata and controls
157 lines (111 loc) · 5.09 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
"""
Trains a probe given a load_dir (looks for a model.pth checkpoint in load_dir).
The script also saves the probe (in load_dir/probe_{layer}.pth), and prints the accuracy (both square/cell accuracy and board accuracy) during training as well as at the end.
You can modify some parameters at the top of the script.
If you didn't change any directories and left as default in train.py and create_data_probing.py, you don't need to modify dir_activations and dir_boards.
You can also set the load_dir and layer via command line when launching the script :
python train_probe.py --load_dir=runs/celestial-frog-68 --layer=6
"""
import os
import json
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
from data import ProbingDataset
from models.lm import LM
from models.transformer.transformer import TransformerConfig
from models.mamba.mamba import MambaConfig
from models.mamba.jamba import JambaConfig
from eval import eval_probe_accuracy
# -------------------------------------------------------
# --- what to probe -------------------------------------
# Both values can also be given on the command line (--load_dir / --layer).
load_dir = "" # run directory
layer = None # layer whose activations the probe is trained on

# --- where data is read from / where the probe is written
dir_activations = None # None -> load_dir/data_probing/layer_{layer}
dir_boards = None # None -> load_dir/data_probing
save_dir = None # None -> load_dir/probe_{layer}.pth

# --- run size ------------------------------------------
batch_size = 256
num_iters = 20000
n_games = 500 # number of games to compute final acc
print_interval = 1000

# --- probe optimizer (AdamW) ---------------------------
lr = 1e-4
weight_decay = 0.01
adam_b1, adam_b2 = 0.9, 0.99

# Use the GPU when one is visible, otherwise stay on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
# -------------------------------------------------------
# Command-line overrides: anything passed here takes precedence over the
# values hard-coded in the settings section of this file.
parser = argparse.ArgumentParser()
parser.add_argument("--load_dir", type=str, default=None, help="something like runs/name_run/ (will look for a model.pth in this dir)")
parser.add_argument("--layer", type=int, default=None)
args = parser.parse_args()

if args.layer is not None:
    layer = args.layer
if args.load_dir is not None:
    load_dir = args.load_dir

# The in-file default for load_dir is an empty string, so an `is not None`
# test would never trigger; check for emptiness instead. parser.error() is
# used rather than assert so the check survives `python -O` and the user
# gets the usage message.
if not load_dir:
    parser.error("Please provide the run path (either as an argument or in the load_dir variable in the file)")

# Without a layer the script would go on to build "layer_None" /
# "probe_None.pth" paths, so fail early with a clear message.
if layer is None:
    parser.error("Please provide the layer to probe (either with --layer or in the layer variable in the file)")
# Fall back to the default probing-data layout inside the run directory.
if dir_activations is None:
    dir_activations = os.path.join(load_dir, "data_probing", f"layer_{layer}")
if dir_boards is None:
    dir_boards = os.path.join(load_dir, "data_probing")

ds = ProbingDataset(dir_activations=dir_activations, dir_boards=dir_boards)
loader = torch.utils.data.DataLoader(ds, batch_size=batch_size, num_workers=0, pin_memory=True)

# Despite the "_dir" suffix these are file paths inside the run directory.
config_dir = os.path.join(load_dir, 'config.json')
checkpoint_dir = os.path.join(load_dir, 'model.pth')

# Read the run configuration; the context manager closes the file handle
# even if parsing fails.
with open(config_dir) as config_file:
    config_json = json.load(config_file)

# "architecture" selects the config class and is not itself a config field,
# so it is removed before the remaining keys are passed as keyword arguments.
architecture = config_json.pop('architecture')

if architecture == "Transformer":
    config = TransformerConfig(**config_json)
elif architecture == "Mamba":
    config = MambaConfig(**config_json)
elif architecture == "Jamba":
    config = JambaConfig(**config_json)
else:
    raise NotImplementedError(f"Unsupported architecture in {config_dir}: {architecture!r}")

model = LM(config, vocab_size=65).to(device)

# NOTE(review): torch.load unpickles the file, so only load checkpoints
# from a trusted source.
checkpoint = torch.load(checkpoint_dir, map_location=device)
model.load_state_dict(checkpoint['model'])
print(f"Successfully loaded checkpoint from {load_dir}.")

# The language model is only used for evaluation here; only the probe is trained.
model.eval()
class Probe(nn.Module):
    """Linear probe mapping an activation vector to per-cell board logits.

    The output has ``n_cell_types * board_size * board_size`` entries per
    input vector. With the defaults this is ``3 * 8 * 8``: 3 cell types
    (empty=0, yours=1, mine=2) for each cell of an 8x8 board.

    Args:
        d_model: size of the activation vectors fed to the probe.
        n_cell_types: number of classes predicted for each cell.
        board_size: side length of the (square) board.
    """

    def __init__(self, d_model: int, n_cell_types: int = 3, board_size: int = 8):
        super().__init__()

        self.n_cell_types = n_cell_types
        self.board_size = board_size

        # The layer keeps the name `fc` so that state_dict keys remain
        # "fc.weight" / "fc.bias" and previously saved probes still load.
        self.fc = nn.Linear(d_model, n_cell_types * board_size * board_size, bias=True)

    def forward(self, x):
        # x : (B, d_model) -> y : (B, n_cell_types * board_size * board_size)
        return self.fc(x)
probe = Probe(config.d_model).to(device)
optim = torch.optim.AdamW(probe.parameters(), lr=lr, weight_decay=weight_decay, betas=(adam_b1, adam_b2))

print("Starting training...")

# Width used to zero-pad the step counter in log lines (loop-invariant).
num_digits = len(str(num_iters))

# zip() with range() first caps training at exactly num_iters optimizer steps
# and does not fetch an extra batch once the budget is used up. The loop also
# ends early if the loader runs out of batches before num_iters is reached.
for step, (activations, boards) in zip(range(num_iters), loader):
    activations, boards = activations.to(device), boards.to(device)
    boards = boards.long() # cross_entropy expects integer class targets

    logits = probe(activations)
    # One row of 3 class logits per board cell; targets equal to -100 are ignored.
    loss = F.cross_entropy(logits.view(-1, 3), boards.view(-1), ignore_index=-100)

    optim.zero_grad()
    loss.backward()
    optim.step()

    # printing
    if step % print_interval == 0:
        # Quick estimate on a small number of games to keep logging cheap.
        cell_acc, board_acc = eval_probe_accuracy(model, probe, layer, device, n_games=10)
        formatted_step = f"{step:0{num_digits}d}"
        print(f"Step {formatted_step}/{num_iters}. train loss = {loss.item():.3f}. mean cell acc = {100*cell_acc:.2f}%. mean board acc = {100*board_acc:.2f}%")
print("Training done.")

# Respect a user-provided save_dir; only fall back to the default location
# inside the run directory when none was set.
if save_dir is None:
    save_dir = os.path.join(load_dir, f"probe_{layer}.pth")

checkpoint = {"probe": probe.state_dict()}
torch.save(checkpoint, save_dir)
print(f"Sucessfully saved trained probe in {save_dir}")

# Final evaluation uses the n_games setting instead of a hard-coded count.
cell_acc, board_acc = eval_probe_accuracy(model, probe, layer, device, n_games=n_games)
print(f"Mean cell accuracy: {100*cell_acc:.2f}% (vs {66}% for an untrained model)")
print(f"Mean board accuracy: {100*board_acc:.2f}% (vs {0}% for an untrained model)")

# It is important to compare these results with the "trained probe on an
# untrained model" baseline (cell accuracy, board accuracy):
# 1) untrained probe on trained model : 33%, 0%
# 2) untrained probe on untrained model : 33%, 0%
# 3) trained probe on untrained model : 66%, 0% (most important)