-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtinymodel_runtime.py
More file actions
133 lines (113 loc) · 4.44 KB
/
tinymodel_runtime.py
File metadata and controls
133 lines (113 loc) · 4.44 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
#!/usr/bin/env python3
"""General-purpose TinyModel runtime utilities.
This module extends usage beyond plain classification by exposing:
- class probabilities
- sentence embeddings from the encoder
- semantic similarity scoring
- nearest-neighbor retrieval over a candidate set
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Sequence
import torch
import torch.nn.functional as F
from transformers import AutoModelForSequenceClassification, AutoTokenizer
@dataclass
class RetrievalHit:
    """One nearest-neighbor match returned by ``TinyModelRuntime.retrieve``."""

    # Candidate text as passed in by the caller.
    text: str
    # Cosine similarity to the query (embeddings are L2-normalized upstream).
    score: float
    # Position of this candidate in the original ``candidates`` sequence.
    index: int
class TinyModelRuntime:
"""Inference helper around TinyModel classification checkpoints."""
def __init__(
self,
model_id_or_path: str,
*,
device: str | None = None,
max_length: int = 128,
) -> None:
self.tokenizer = AutoTokenizer.from_pretrained(model_id_or_path)
self.model = AutoModelForSequenceClassification.from_pretrained(model_id_or_path)
self.model.eval()
self.max_length = max_length
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
self.device = torch.device(device)
self.model.to(self.device)
def _encoder_backbone(self):
"""Return the base encoder (BERT, DistilBERT, RoBERTa, etc.)."""
m = self.model
for name in ("bert", "distilbert", "roberta", "electra", "camembert", "xlm_roberta"):
if hasattr(m, name):
return getattr(m, name)
raise AttributeError(
"Could not find a supported encoder backbone on this model; "
"embeddings need BERT/DistilBERT/RoBERTa-style checkpoints."
)
def classify(self, texts: Sequence[str]) -> list[dict[str, float]]:
"""Return per-label probabilities for each input text."""
encoded = self.tokenizer(
list(texts),
truncation=True,
padding=True,
max_length=self.max_length,
return_tensors="pt",
)
encoded = {k: v.to(self.device) for k, v in encoded.items()}
with torch.inference_mode():
logits = self.model(**encoded).logits
probs = F.softmax(logits, dim=-1).cpu()
id2label = self.model.config.id2label
out: list[dict[str, float]] = []
for row in probs:
item = {id2label[i]: float(row[i]) for i in range(row.shape[0])}
out.append(item)
return out
def embed(self, texts: Sequence[str], *, normalize: bool = True) -> torch.Tensor:
"""Generate pooled sentence embeddings from the transformer encoder ([CLS] / first token)."""
encoded = self.tokenizer(
list(texts),
truncation=True,
padding=True,
max_length=self.max_length,
return_tensors="pt",
)
encoded = {k: v.to(self.device) for k, v in encoded.items()}
with torch.inference_mode():
backbone = self._encoder_backbone()
# Only pass ids/mask so DistilBERT and BERT both accept the call.
hidden = backbone(
input_ids=encoded["input_ids"],
attention_mask=encoded["attention_mask"],
return_dict=True,
).last_hidden_state
cls = hidden[:, 0, :]
if normalize:
cls = F.normalize(cls, p=2, dim=1)
return cls.cpu()
def similarity(self, text_a: str, text_b: str) -> float:
"""Cosine similarity between two texts using encoder embeddings."""
embs = self.embed([text_a, text_b], normalize=True)
score = F.cosine_similarity(embs[0].unsqueeze(0), embs[1].unsqueeze(0))
return float(score.item())
def retrieve(
self,
query: str,
candidates: Sequence[str],
*,
top_k: int = 3,
) -> list[RetrievalHit]:
"""Return top-k semantically closest candidates to query."""
if not candidates:
return []
texts = [query, *candidates]
embs = self.embed(texts, normalize=True)
query_emb = embs[0:1]
cand_embs = embs[1:]
scores = (query_emb @ cand_embs.T).squeeze(0)
top_k = max(1, min(top_k, scores.shape[0]))
vals, idxs = torch.topk(scores, k=top_k)
hits: list[RetrievalHit] = []
for score, idx in zip(vals.tolist(), idxs.tolist()):
hits.append(RetrievalHit(text=candidates[idx], score=float(score), index=idx))
return hits