-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbench_utils.py
More file actions
107 lines (81 loc) · 3.93 KB
/
bench_utils.py
File metadata and controls
107 lines (81 loc) · 3.93 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import time
import torch
# NOTE(review): DataLoader appears unused in this file — verify before removing.
from torch.utils.data import DataLoader
import random
import nltk
# Fetch the 'punkt' tokenizer models required by nltk.sent_tokenize() below
# (no-op if already downloaded; runs at import time of this module).
nltk.download('punkt')
def load_and_sample_sentences(num_pairs=1024, min_sentences=2, max_sentences=15,
                              base_seed=100, data_path='data/sample_data.txt'):
    """Build a reproducible list of random sentence-pair inputs from a text file.

    The file at ``data_path`` is split into sentences with NLTK, then each pair
    is assembled from two independent random samples of whole sentences.

    Args:
        num_pairs: number of [text1, text2] pairs to generate.
        min_sentences: minimum sentences joined into each text.
        max_sentences: maximum sentences joined into each text.
        base_seed: RNG base seed; pair ``i`` is generated under seed
            ``base_seed + i`` so each pair is reproducible independently of
            how many pairs precede it.
        data_path: path of the source text file (default preserves the
            previous hard-coded location).

    Returns:
        list of ``num_pairs`` two-element lists ``[text1, text2]``.

    Raises:
        ValueError: from random.sample if max_sentences exceeds the number of
            sentences available in the file.
    """
    # Explicit encoding: the previous open() used the locale default, which
    # can mis-decode the corpus on non-UTF-8 systems.
    with open(data_path, 'r', encoding='utf-8') as f:
        text = f.read()
    sentences = nltk.sent_tokenize(text)

    pairs = []
    for i in range(num_pairs):
        # Re-seed per pair so every pair is deterministic on its own.
        random.seed(base_seed + i)
        n1 = random.randint(min_sentences, max_sentences)
        text1 = ' '.join(random.sample(sentences, n1))
        n2 = random.randint(min_sentences, max_sentences)
        text2 = ' '.join(random.sample(sentences, n2))
        pairs.append([text1, text2])
    return pairs
def inference(model, sentence_pairs, batch_size=64):
    """Score sentence pairs with the given cross-encoder-style model.

    Args:
        model: object exposing ``predict(pairs, convert_to_tensor=..., batch_size=...)``
            (e.g. a sentence-transformers CrossEncoder — presumably; confirm at call site).
        sentence_pairs: list of [text1, text2] pairs to score.
        batch_size: prediction batch size; default 64 preserves the previous
            hard-coded value.

    Returns:
        torch tensor of scores as produced by ``model.predict``.
    """
    scores = model.predict(sentence_pairs, convert_to_tensor=True, batch_size=batch_size)
    return scores
def test_model(model):
    """Smoke-test a model: score a small fixed batch and print the first scores.

    Uses a deterministic 64-pair sample (seed 42) so repeated calls print the
    same values for the same model.
    """
    pairs = load_and_sample_sentences(num_pairs=64, base_seed=42)
    with torch.inference_mode():
        predictions = inference(model, pairs)
    print(predictions.tolist()[:10])
def benchmark(model, print_scores=False, num_runs=10, trace=None, seed=100, on_sorted_inputs=False):
    """Benchmark inference latency of ``model`` over repeated 1024-pair batches.

    Args:
        model: object with ``predict(...)`` (via inference()) and a ``tokenizer``
            with an ``encode`` method.
        print_scores: if True, print the first 10 scores of the final run.
        num_runs: number of timed runs; summary stats use only the LAST 5.
        trace: if set, additionally capture one profiled pass and export a
            Chrome trace to ``trace/<trace>.json``.
        seed: base seed; warmup data uses seeds derived from ``seed``, timed
            data from ``2*seed``.
        on_sorted_inputs: if True, sort each batch by total token length
            (longest first) before inference, then restore original order.
    """
    # These two sets feed only the max-token-length prints below; the warmup
    # and timed loops regenerate their own data each iteration.
    sentence_pairs_warmup = load_and_sample_sentences(num_pairs=512, base_seed=seed)
    sentence_pairs = load_and_sample_sentences(num_pairs=1024, base_seed=2*seed)
    # NOTE(review): encode() is called on each pair (a 2-element list); whether
    # the tokenizer concatenates the two texts depends on its implementation — confirm.
    print(max([len(model.tokenizer.encode(x)) for x in sentence_pairs]))
    print(max([len(model.tokenizer.encode(x)) for x in sentence_pairs_warmup]))
    # print([len(model.tokenizer.encode(x)) for x in sentence_pairs])
    print(f"Total pairs: {len(sentence_pairs)}")
    with torch.inference_mode():
        # Warmup: 10 rounds of 2048 fresh pairs (rebinding the 512-pair set above).
        print("Warming up...")
        for i in range(10):
            sentence_pairs_warmup = load_and_sample_sentences(num_pairs=2048, base_seed=seed + i)
            _ = inference(model, sentence_pairs_warmup)
        # Multiple benchmark runs, each on a fresh deterministic 1024-pair sample.
        print("Benchmarking...")
        times = []
        for i in range(num_runs):
            sentence_pairs = load_and_sample_sentences(num_pairs=1024, base_seed=2*seed + i)
            if on_sorted_inputs:
                # Sort by combined token length of each pair, longest first.
                # (The comprehension's `i` does not clobber the loop index:
                # comprehensions have their own scope in Python 3.)
                lengths = [(len(model.tokenizer.encode(p[0])) + len(model.tokenizer.encode(p[1])), i)
                           for i, p in enumerate(sentence_pairs)]
                sorted_indices = [i for _, i in sorted(lengths, reverse=True)]
                sentence_pairs_sorted = [sentence_pairs[i] for i in sorted_indices]
            else:
                sentence_pairs_sorted = sentence_pairs
                sorted_indices = None
            # Drain pending GPU work so it is not attributed to this run.
            torch.cuda.synchronize()
            start_time = time.perf_counter()
            scores = inference(model, sentence_pairs_sorted)
            if on_sorted_inputs:
                # Scatter scores back to the original (unsorted) pair order.
                # NOTE(review): this un-permutation happens inside the timed
                # region, so its cost is included in the latency — confirm intended.
                original_scores = torch.empty_like(scores)
                for new_idx, orig_idx in enumerate(sorted_indices):
                    original_scores[orig_idx] = scores[new_idx]
                scores = original_scores
            # Wait for inference to finish before stopping the clock.
            torch.cuda.synchronize()
            end_time = time.perf_counter()
            times.append(end_time - start_time)
    if trace is not None:
        # Capture one profiled inference pass and export a Chrome trace.
        with torch.profiler.profile() as prof:
            for i in range(1, 2):  # range(1, 2) -> exactly one iteration
                scores = inference(model, sentence_pairs)
                prof.step()
        prof.export_chrome_trace(f"trace/{trace}.json")
    if print_scores:
        print(scores.tolist()[:10])
    # Stats over the last 5 runs only — presumably to exclude early-run noise; confirm.
    mean_time = sum(times[-5:]) / len(times[-5:])
    std_time = (sum((t - mean_time) ** 2 for t in times[-5:]) / len(times[-5:])) ** 0.5
    print(f"Inference times: {[f'{t:.4f}' for t in times]}")
    print(f"Mean time: {mean_time:.4f} ± {std_time:.4f} seconds")