Skip to content

Commit 721fa05

Browse files
authored
Update V1.5.2
Эксперимент с Ё-омографами. Переделка механизма скачивания моделей. Обновление моделей.
1 parent 1227bd4 commit 721fa05

File tree

6 files changed

+104
-26
lines changed

6 files changed

+104
-26
lines changed

README.md

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,15 @@ RUAccent - это библиотека для автоматической ра
1313
```
1414
## Параметры работы
1515

16-
load(omograph_model_size='big', use_dictionary=False, custom_dict={}, custom_homographs={}
16+
load(omograph_model_size='big', use_dictionary=False, custom_dict={}, custom_homographs={}, load_yo_homographs_model=False)
1717

1818

19-
- На данный момент доступны две модели: **big** (рекомендуется к использованию) и **small**.
20-
- Модель **big** имеет 178 миллионов параметров, а **small** 10 миллионов
21-
- Переменная **use_dict** отвечает за загрузку всего словаря (требуется больше ОЗУ), иначе все ударения расставляет нейросеть.
19+
- На данный момент доступны три модели: **big** (рекомендуется к использованию), **medium** и **small**.
20+
- Модель **big** имеет 178 миллионов параметров, **medium** 85 миллионов, а **small** 42 миллиона
21+
- Переменная **use_dictionary** отвечает за загрузку всего словаря (требуется больше ОЗУ), иначе все ударения расставляет нейросеть.
2222
- Переменная **custom_homographs** отвечает за добавление своих омографов. Формат такой: `{'слово-омограф': ['вариант ударения 1', 'вариант ударения 2']}`.
2323
- Функция **custom_dict** отвечает за добавление своих вариантов ударений в словарь. Формат такой: `{'слово': 'сл+ово с удар+ением'}`
24-
24+
- Также вы можете протестировать **beta-функцию** разрешения Ё-омографов, установив `load_yo_homographs_model=True` в `load()`, а также `accentizer.process_all(text, process_yo_omographs=True)` или `accentizer.process_yo(text, process_yo_omographs=True)`.
2525

2626

2727
## Пример использования

ruaccent/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Russian accentizer"""
22

3-
__version__ = "1.5.1"
3+
__version__ = "1.5.2"
44

55

66
from .ruaccent import RUAccent

ruaccent/omograph_model.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ def softmax(self, x):
1818
def classify(self, text, hypotheses):
1919
hypotheses_probs = []
2020
text = re.sub(r'\s+(?=(?:[,.?!:;…]))', r'', text)
21-
2221
for h in hypotheses:
2322
inputs = self.tokenizer(text, h, return_tensors="np")
2423
inputs = {k: v.astype(np.int64) for k, v in inputs.items()}

ruaccent/ruaccent.py

Lines changed: 68 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,51 +1,80 @@
11
import json
22
import pathlib
3-
from huggingface_hub import snapshot_download
3+
from huggingface_hub import HfFileSystem, hf_hub_download
44
import os
55
from os.path import join as join_path
66
from .omograph_model import OmographModel
77
from .accent_model import AccentModel
8+
from .yo_omograph_model import YomographModel
89
from .text_split import split_by_sentences
910
import re
1011

1112

1213
class RUAccent:
1314
def __init__(self, workdir=None):
1415
self.omograph_model = OmographModel()
16+
self.yo_omograph_model = YomographModel()
1517
self.accent_model = AccentModel()
18+
self.fs = HfFileSystem()
19+
self.omograph_models_paths = {'big': '/nn/nn_omograph/big', 'medium': '/nn/nn_omograph/medium', 'small': '/nn/nn_omograph/small'}
20+
self.accentuator_paths = ['/nn/nn_accent', '/dictionary']
21+
self.yo_omograph_path = ['/nn/nn_yo_omograph']
1622
if not workdir:
1723
self.workdir = str(pathlib.Path(__file__).resolve().parent)
1824
else:
1925
self.workdir = workdir
2026

27+
2128
def load(
2229
self,
2330
omograph_model_size="big",
2431
use_dictionary=False,
2532
custom_dict={},
2633
custom_homographs={},
34+
load_yo_homographs_model=False,
2735
repo="TeraTTS/accentuator",
2836
):
2937

38+
self.load_yo_homographs_model = load_yo_homographs_model
3039
self.custom_dict = custom_dict
3140
self.accents = {}
3241
if not os.path.exists(
3342
join_path(self.workdir, "dictionary")
34-
) or not os.path.exists(join_path(self.workdir, "nn")):
35-
snapshot_download(
36-
repo_id=repo,
37-
ignore_patterns=["*.md", "*.gitattributes"],
38-
local_dir=self.workdir,
39-
local_dir_use_symlinks=False,
40-
)
43+
):
44+
for path in self.accentuator_paths:
45+
files = self.fs.ls(repo + path)
46+
for file in files:
47+
hf_hub_download(repo_id=repo, local_dir_use_symlinks=False, local_dir=self.workdir, filename=file['name'].replace(repo+'/', ''))
48+
49+
if not os.path.exists(join_path(self.workdir, "nn")):
50+
os.mkdir(join_path(self.workdir, "nn"))
51+
52+
if not os.path.exists(join_path(self.workdir, "nn", "nn_omograph", omograph_model_size)):
53+
model_path = self.omograph_models_paths.get(omograph_model_size, None)
54+
if model_path:
55+
files = self.fs.ls(repo + model_path)
56+
for file in files:
57+
hf_hub_download(repo_id=repo, local_dir_use_symlinks=False, local_dir=self.workdir, filename=file['name'].replace(repo+'/', ''))
58+
else:
59+
raise FileNotFoundError
60+
4161
self.omographs = json.load(
4262
open(join_path(self.workdir, "dictionary/omographs.json"), encoding='utf-8')
4363
)
44-
#self.yo_omographs = json.load(
45-
# open(join_path(self.workdir, "dictionary/yo_omographs.json"), encoding='utf-8')
46-
#)
47-
#self.omographs.update(self.yo_omographs)
4864
self.omographs.update(custom_homographs)
65+
66+
if load_yo_homographs_model:
67+
if not os.path.exists(join_path(self.workdir, "nn", "nn_yo_omograph")):
68+
for path in self.yo_omograph_path:
69+
files = self.fs.ls(repo + path)
70+
for file in files:
71+
hf_hub_download(repo_id=repo, local_dir_use_symlinks=False, local_dir=self.workdir, filename=file['name'].replace(repo+'/', ''))
72+
73+
self.yo_omographs = json.load(
74+
open(join_path(self.workdir, "dictionary/yo_omographs.json"), encoding='utf-8')
75+
)
76+
self.yo_omograph_model.load(join_path(self.workdir, "nn/nn_yo_omograph/"))
77+
4978
self.yo_words = json.load(
5079
open(join_path(self.workdir, "dictionary/yo_words.json"), encoding='utf-8')
5180
)
@@ -57,16 +86,13 @@ def load(
5786

5887
self.accents.update(self.custom_dict)
5988

60-
if omograph_model_size not in ["small", "big"]:
61-
raise NotImplementedError
62-
6389
self.omograph_model.load(
6490
join_path(self.workdir, f"nn/nn_omograph/{omograph_model_size}/")
65-
6691
)
6792
self.accent_model.load(join_path(self.workdir, "nn/nn_accent/"))
6893

6994

95+
7096
def split_by_words(self, string):
7197
result = re.findall(r"\w*(?:\+\w+)*|[^\w\s]+", string.lower())
7298
return [res for res in result if res]
@@ -115,6 +141,26 @@ def _process_omographs(self, text):
115141
splitted_text[omograph["position"]] = cls
116142
return splitted_text
117143

144+
def _process_yo_omographs(self, text):
145+
splitted_text = text
146+
147+
founded_omographs = []
148+
for i, word in enumerate(splitted_text):
149+
variants = self.yo_omographs.get(word)
150+
if variants:
151+
founded_omographs.append(
152+
{"word": word, "variants": variants, "position": i}
153+
)
154+
for omograph in founded_omographs:
155+
splitted_text[
156+
omograph["position"]
157+
] = f"<w>{splitted_text[omograph['position']]}</w>"
158+
cls = self.yo_omograph_model.classify(
159+
" ".join(splitted_text), omograph["variants"]
160+
)
161+
splitted_text[omograph["position"]] = cls
162+
return splitted_text
163+
118164
def _process_accent(self, text):
119165
splitted_text = text
120166

@@ -126,23 +172,27 @@ def _process_accent(self, text):
126172
splitted_text[i] = stressed_word
127173
return splitted_text
128174

129-
def process_yo(self, text):
175+
def process_yo(self, text, process_yo_omographs=False):
130176
sentences = split_by_sentences(text)
131177
outputs = []
132178
for sentence in sentences:
133179
text = self.split_by_words(sentence)
134180
processed_text = self._process_yo(text)
181+
if process_yo_omographs:
182+
processed_text = self._process_yo_omographs(processed_text)
135183
processed_text = " ".join(processed_text)
136184
processed_text = self.delete_spaces_before_punc(processed_text)
137185
outputs.append(processed_text)
138186
return " ".join(outputs)
139187

140-
def process_all(self, text):
188+
def process_all(self, text, process_yo_omographs=False):
141189
sentences = split_by_sentences(text)
142190
outputs = []
143191
for sentence in sentences:
144192
text = self.split_by_words(sentence)
145193
processed_text = self._process_yo(text)
194+
if process_yo_omographs:
195+
processed_text = self._process_yo_omographs(processed_text)
146196
processed_text = self._process_omographs(processed_text)
147197
processed_text = self._process_accent(processed_text)
148198
processed_text = " ".join(processed_text)

ruaccent/yo_omograph_model.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import numpy as np
2+
from onnxruntime import InferenceSession
3+
from transformers import AutoTokenizer
4+
import re
5+
6+
class YomographModel:
    """Picks the most plausible yo-homograph variant for a sentence.

    Scores each candidate hypothesis against the text with an ONNX
    sequence-pair classifier and returns the highest-scoring hypothesis.
    Call ``load()`` before ``classify()``.
    """

    def __init__(self):
        # Tokenizer and session are attached lazily by load().
        pass

    def load(self, path):
        """Load the tokenizer and the ONNX inference session from *path*."""
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.session = InferenceSession(f"{path}/model.onnx", providers=['CPUExecutionProvider'])

    def softmax(self, x):
        """Numerically stable softmax over all elements of *x*."""
        shifted = np.exp(x - np.max(x))
        return shifted / shifted.sum()

    def classify(self, text, hypotheses):
        """Return the hypothesis from *hypotheses* scored highest for *text*.

        The text/hypothesis pair is fed to the model; the probability of
        label index 1 is used as the score for that hypothesis.
        """
        # Normalize spacing: drop whitespace that precedes punctuation.
        text = re.sub(r'\s+(?=(?:[,.?!:;…]))', r'', text)
        scores = []
        for hypothesis in hypotheses:
            encoded = self.tokenizer(text, hypothesis, return_tensors="np")
            encoded = {name: tensor.astype(np.int64) for name, tensor in encoded.items()}
            logits = self.session.run(None, encoded)[0]
            probs = self.softmax(logits)
            # probs has one row per input pair; take P(label == 1) of the first.
            scores.append(float(probs[0][1]))
        best = scores.index(max(scores))
        return hypotheses[best]

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
setup(
44
name='ruaccent',
5-
version='1.5.1',
5+
version='1.5.2',
66
author='Denis Petrov',
77
author_email='arduino4b@gmail.com',
88
description='A Russian text accentuation tool',

0 commit comments

Comments
 (0)