Skip to content

Commit 34ba472

Browse files
authored
1.2.0 Update
1. Переход на ORT (onnxruntime) 2. Обновлена модель расстановки ударений в неизвестных словах 3. Расширен набор обучающих данных 1.2M -> 3.3M 4. Добавлена возможность отключить словарь
1 parent f21ee4a commit 34ba472

File tree

4 files changed

+61
-42
lines changed

4 files changed

+61
-42
lines changed

README.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ RUAccent - это библиотека для автоматической ра
1111

1212
RUAccent предоставляет следующие методы:
1313

14-
- `load(omograph_model_size='medium', dict_load_startup=False)`: Загрузка моделей и словарей. На данные момент доступны две модели: medium (рекомендуется к использованию) и small. Переменная dict_load_startup отвечает за загрузку всего словаря (требуется больше ОЗУ), либо во время работы для необходимых слов (экономит ОЗУ, но требует быстрые ЖД и работает медленее)
14+
- `load(omograph_model_size='medium', dict_load_startup=False), disable_accent_dict=False`: Загрузка моделей и словарей. На данные момент доступны две модели: medium (рекомендуется к использованию) и small. Переменная dict_load_startup отвечает за загрузку всего словаря (требуется больше ОЗУ), либо во время работы для необходимых слов (экономит ОЗУ, но требует быстрыq ЖД и работает медленее). Переменная disable_accent_dict отключает использование словаря (все ударения расставляет нейросеть). Данная функция экономит ОЗУ, по скорости работы сопоставима со всем словарём в ОЗУ.
1515

1616
- `process_all(text)`: Обрабатывает текст всем сразу (ёфикация, расстановка ударений и расстановка ударений в словах-омографах)
1717

@@ -24,13 +24,13 @@ RUAccent предоставляет следующие методы:
2424
from ruaccent import RUAccent
2525

2626
accentizer = RUAccent()
27-
accentizer.load(omograph_model_size='medium', dict_load_startup=False)
27+
accentizer.load(omograph_model_size='medium', dict_load_startup=False, disable_accent_dict=False)
2828

2929
text = 'на двери висит замок'
30-
print(accentizer.process_all(text)) # на двер+и вис+ит зам+ок
30+
print(text_processor.process_all(text))
3131

3232
text = 'ежик нашел в лесу ягоды'
33-
print(accentizer.process_yo(text)) # ёжик нашел в лесу ягоды
33+
print(text_processor.process_yo(text))
3434
```
3535

3636

ruaccent/accent_model.py

Lines changed: 30 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,37 @@
1-
import torch
1+
import numpy as np
2+
import json
3+
from onnxruntime import InferenceSession
24
from .char_tokenizer import CharTokenizer
3-
from transformers import AutoModelForTokenClassification
45

56
class AccentModel:
67
def __init__(self) -> None:
7-
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
8+
pass
9+
810
def load(self, path):
9-
self.model = AutoModelForTokenClassification.from_pretrained(path).to(self.device)
11+
self.session = InferenceSession(f"{path}/model.onnx", providers=["CPUExecutionProvider"])
12+
13+
with open(f"{path}/config.json", "r") as f:
14+
self.id2label = json.load(f)["id2label"]
1015
self.tokenizer = CharTokenizer.from_pretrained(path)
11-
12-
def render_stress(self, word, token_classes):
13-
if 'STRESS' in token_classes:
14-
index = token_classes.index('STRESS')
15-
word = list(word)
16-
word[index-1] = '+' + word[index-1]
17-
return ''.join(word)
18-
else:
19-
return word
20-
16+
self.tokenizer.model_input_names = ["input_ids", "attention_mask"]
17+
18+
def render_stress(self, text, pred):
19+
text = list(text)
20+
i = 0
21+
for chunk in pred:
22+
if chunk != "NO":
23+
text[i - 1] = "+" + text[i - 1]
24+
i += 1
25+
text = "".join(text)
26+
return text
27+
2128
def put_accent(self, word):
22-
inputs = self.tokenizer(word, return_tensors="pt").to(self.device)
23-
with torch.no_grad():
24-
logits = self.model(**inputs).logits
25-
predictions = torch.argmax(logits, dim=2)
26-
predicted_token_class = [self.model.config.id2label[t.item()] for t in predictions[0]]
27-
return self.render_stress(word, predicted_token_class)
29+
inputs = self.tokenizer(word, return_tensors="np")
30+
inputs = {k: v.astype(np.int64) for k, v in inputs.items()}
31+
outputs = self.session.run(None, inputs)
32+
output_names = {output_key.name: idx for idx, output_key in enumerate(self.session.get_outputs())}
33+
logits = outputs[output_names["logits"]]
34+
labels = np.argmax(logits, axis=-1)[0]
35+
labels = [self.id2label[str(label)] for label in labels]
36+
stressed_word = self.render_stress(word, labels)
37+
return stressed_word

ruaccent/omograph_model.py

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,29 @@
1-
from transformers import AutoModelForSequenceClassification, AutoTokenizer
2-
import torch
1+
import numpy as np
2+
from onnxruntime import InferenceSession
3+
from transformers import AutoTokenizer
34

45
class OmographModel:
5-
def __init__(self) -> None:
6-
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
7-
6+
def __init__(self):
7+
pass
8+
89
def load(self, path):
9-
self.nli_model = AutoModelForSequenceClassification.from_pretrained(path, torch_dtype=torch.bfloat16).to(self.device)
1010
self.tokenizer = AutoTokenizer.from_pretrained(path)
11-
11+
self.session = InferenceSession(f"{path}/model.onnx", providers=['CPUExecutionProvider'])
12+
13+
def softmax(self, x):
14+
e_x = np.exp(x - np.max(x))
15+
return e_x / e_x.sum()
16+
1217
def classify(self, text, hypotheses):
13-
encodings = self.tokenizer.batch_encode_plus([(text, hyp) for hyp in hypotheses], return_tensors='pt', padding=True)
14-
input_ids = encodings['input_ids'].to(self.device)
15-
with torch.no_grad():
16-
logits = self.nli_model(input_ids)[0]
17-
entail_contradiction_logits = logits[:,[0,2]]
18-
probs = entail_contradiction_logits.softmax(dim=1)
19-
prob_label_is_true = [float(p[1]) for p in probs]
18+
hypotheses_probs = []
19+
for h in hypotheses:
20+
inputs = self.tokenizer(text, h, return_tensors="np")
21+
inputs = {k: v.astype(np.int64) for k, v in inputs.items()}
22+
outputs = self.session.run(None, inputs)[0]
23+
entail_contradiction_logits = outputs[:, [0, 2]]
24+
probs = self.softmax(entail_contradiction_logits)
25+
prob_label_is_true = [float(p[1]) for p in probs][0]
26+
hypotheses_probs.append(prob_label_is_true)
27+
return hypotheses[hypotheses_probs.index(max(hypotheses_probs))]
28+
2029

21-
return hypotheses[prob_label_is_true.index(max(prob_label_is_true))]

setup.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
setup(
44
name='ruaccent',
5-
version='1.0.0',
5+
version='1.2.0',
66
author='Denis Petrov',
77
author_email='arduino4b@gmail.com',
88
description='A Russian text accentuation tool',
@@ -11,9 +11,10 @@
1111
packages=find_packages(),
1212
install_requires=[
1313
'huggingface_hub',
14-
'torch==1.13.1',
14+
'onnxruntime',
1515
'transformers',
16-
'sentencepiece'
16+
'sentencepiece',
17+
'numpy'
1718
],
1819
classifiers=[
1920
'Development Status :: 5 - Production/Stable',

0 commit comments

Comments
 (0)