-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathextract_hidden_states.py
More file actions
68 lines (51 loc) · 2.72 KB
/
extract_hidden_states.py
File metadata and controls
68 lines (51 loc) · 2.72 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
from utils.trainer import load_model, extract_last_states
from argparse import ArgumentParser
import os
from utils.data import get_dataloaders, has_token_types
from utils.tokenizers import STR2TOKENIZER
from transformers.models.bert.modeling_bert import BertPooler
import torch
from utils.path import get_default_path
from utils.modules import parse_hf_name_to_head
import torch._dynamo
# If torch.compile / dynamo tracing fails anywhere downstream, fall back to
# eager execution instead of raising — extraction should not die on a
# compilation corner case.
torch._dynamo.config.suppress_errors = True
def save(states, labels, output_path, prefix):
    """Persist extracted states and labels under *output_path*.

    Writes two files, ``{prefix}_states.pt`` and ``{prefix}_labels.pt``,
    using ``torch.save``.
    """
    states_file = os.path.join(output_path, f"{prefix}_states.pt")
    labels_file = os.path.join(output_path, f"{prefix}_labels.pt")
    torch.save(states, states_file)
    torch.save(labels, labels_file)
def extract_and_save(model, datasets, split_name, output_path):
    """Extract the last hidden states for one dataset split and save them.

    Creates *output_path* if needed, runs ``extract_last_states`` over
    ``datasets[split_name]``, and writes the resulting states/labels via
    ``save`` using the split name as the file prefix.
    """
    os.makedirs(output_path, exist_ok=True)
    extracted, targets = extract_last_states(model, datasets[split_name], split_name)
    save(extracted, targets, output_path, split_name)
def extract_states(model_name, dataset, task, device, calibration_train_split_size):
    """Extract hidden states for every non-train split and save them to disk.

    Loads *model_name* with hidden-state output enabled, tokenizes *dataset*/
    *task*, and for each split except ``"train"`` writes ``{split}_states.pt``
    and ``{split}_labels.pt`` under a sub-directory named after the calibration
    split size (as an integer percentage). The model's classification-head
    weights are additionally saved as ``classifier.pt`` in the root output
    directory.

    Args:
        model_name: HuggingFace model identifier.
        dataset: Dataset key into ``STR2TOKENIZER``.
        task: Task key within the dataset.
        device: Torch device string (e.g. ``"cuda"``).
        calibration_train_split_size: Fraction of train data used for
            calibration; also determines the output sub-directory name.
    """
    output_path = get_default_path(model_name, dataset, task)
    os.makedirs(output_path, exist_ok=True)
    model = load_model(model_name, output_hidden_states=True)
    model.eval()
    model.to(device)
    tokenizer = STR2TOKENIZER[dataset][task](model_name, max_length=512, truncation=True)
    datasets, _ = get_dataloaders(
        dataset,
        task,
        tokenizer,
        model.device,
        batch_size=16,
        calibration_train_split_size=calibration_train_split_size,
        requires_token_type_ids=has_token_types(model_name),
    )
    # The train split is consumed elsewhere (fine-tuning/calibration); only
    # evaluation splits need their states dumped here.
    split_names = [x for x in datasets.keys() if x != "train"]
    # e.g. calibration_train_split_size=0.2 -> sub-directory "20".
    split_dir = os.path.join(output_path, f"{int(calibration_train_split_size * 100)}")
    for split_name in split_names:
        extract_and_save(model, datasets, split_name, split_dir)
    # Save the classification head separately so the stored states can be
    # re-classified without reloading the full model.
    head = parse_hf_name_to_head(model_name)
    state_dict = head.convert_model_head_to_states(model)
    torch.save(state_dict, os.path.join(output_path, "classifier.pt"))
def main():
    """CLI entry point: parse command-line options and run state extraction."""
    parser = ArgumentParser()
    parser.add_argument("--model_name", type=str, default="yoshitomo-matsubara/bert-large-uncased-qnli")
    parser.add_argument("--dataset", type=str, default="sentiment")
    parser.add_argument("--task", type=str, default="twitter")
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--calibration_train_split_size", type=float, default=0.2)
    opts = parser.parse_args()
    extract_states(
        opts.model_name,
        opts.dataset,
        opts.task,
        opts.device,
        opts.calibration_train_split_size,
    )
# Run the extraction CLI only when executed as a script, not when imported.
if __name__ == '__main__':
    main()