Skip to content

Commit 3b338b9

Browse files
authored
Add files via upload
0 parents  commit 3b338b9

File tree

7 files changed

+321
-0
lines changed

7 files changed

+321
-0
lines changed

README.md

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# RUAccent
2+
3+
RUAccent - это библиотека для автоматической расстановки ударений на русском языке.
4+
5+
## Установка
6+
**Требуется установленный GIT**
7+
```
8+
pip install git+https://github.com/Den4ikAI/ruaccent.git
9+
```
10+
## Методы
11+
12+
RUAccent предоставляет следующие методы:
13+
14+
- `load(omograph_model_size='medium', dict_load_startup=False)`: Загрузка моделей и словарей. На данный момент доступны две модели: medium (рекомендуется к использованию) и small. Переменная dict_load_startup отвечает за загрузку всего словаря (требуется больше ОЗУ), либо во время работы для необходимых слов (экономит ОЗУ, но требует быстрые ЖД и работает медленнее)
15+
16+
- `process_all(text)`: Обрабатывает текст всем сразу (ёфикация, расстановка ударений и расстановка ударений в словах-омографах)
17+
18+
- `process_omographs(text)`: Расстановка ударений только в омографах.
19+
20+
- `process_yo(text)`: Ёфикация текста.
21+
22+
## Пример использования
23+
```python
24+
from ruaccent import RUAccent
25+
26+
accentizer = RUAccent()
27+
accentizer.load(omograph_model_size='medium', dict_load_startup=False)
28+
29+
text = 'на двери висит замок'
30+
print(accentizer.process_all(text))
31+
32+
text = 'ежик нашел в лесу ягоды'
33+
print(accentizer.process_yo(text))
34+
```
35+
36+
37+
Файлы моделей и словарей располагаются по [ссылке](https://huggingface.co/TeraTTS/accentuator). Датасеты будут скоро опубликованы. Мы будем признательны, если вы будете расширять словари и загружать их в репозиторий. Это поможет улучшать данный проект.

ruaccent/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .ruaccent import RUAccent

ruaccent/accent_model.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import torch
2+
from .char_tokenizer import CharTokenizer
3+
from transformers import AlbertForTokenClassification
4+
5+
class AccentModel:
    """Character-level stress-placement model (ALBERT token classifier)."""

    def __init__(self) -> None:
        # Run on GPU when one is present, otherwise fall back to CPU.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def load(self, path):
        """Load the token-classification model and its character tokenizer."""
        model = AlbertForTokenClassification.from_pretrained(path)
        self.model = model.to(self.device)
        self.tokenizer = CharTokenizer.from_pretrained(path)

    def render_stress(self, word, index):
        """Insert '+' before the character at 1-based position *index*."""
        chars = list(word)
        chars[index - 1] = '+' + chars[index - 1]
        return ''.join(chars)

    def put_accent(self, word):
        """Return *word* with '+' inserted before the predicted stressed vowel."""
        inputs = self.tokenizer(word, return_tensors="pt").to(self.device)
        with torch.no_grad():
            logits = self.model(**inputs).logits
        label_ids = torch.argmax(logits, dim=2)[0]
        labels = [self.model.config.id2label[t.item()] for t in label_ids]
        # Position 0 of the token sequence is the [bos] special token, which
        # is presumably why render_stress treats the index as 1-based.
        return self.render_stress(word, labels.index('STRESS'))

ruaccent/char_tokenizer.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
import os
2+
from typing import Optional, Tuple, List
3+
from collections import OrderedDict
4+
5+
from transformers import PreTrainedTokenizer
6+
7+
8+
def load_vocab(vocab_file):
    """Read a newline-delimited vocabulary file into a token -> index mapping."""
    vocab = OrderedDict()
    with open(vocab_file, "r", encoding="utf-8") as handle:
        # Line number == token id; only the trailing newline is stripped so
        # tokens containing spaces survive intact.
        for idx, line in enumerate(handle):
            vocab[line.rstrip("\n")] = idx
    return vocab
16+
17+
18+
class CharTokenizer(PreTrainedTokenizer):
    """Character-level tokenizer: every character of the input is one token.

    Special tokens are [pad], [unk], [bos], [eos]; single sequences are
    encoded as ``[bos] c1 ... cn [eos]``. Sequence pairs are not supported
    (``token_ids_1`` is ignored throughout).
    """

    vocab_files_names = {"vocab_file": "vocab.txt"}

    def __init__(
        self,
        vocab_file=None,
        pad_token="[pad]",
        unk_token="[unk]",
        bos_token="[bos]",
        eos_token="[eos]",
        do_lower_case=False,
        *args,
        **kwargs
    ):
        self.do_lower_case = do_lower_case

        # Build the vocabulary BEFORE calling the base constructor: recent
        # transformers versions (>=4.34) call get_vocab()/vocab_size from
        # PreTrainedTokenizer.__init__, which crashed with the original
        # ordering where self.vocab was assigned only after super().__init__.
        if not vocab_file or not os.path.isfile(vocab_file):
            self.vocab = OrderedDict()
            self.ids_to_tokens = OrderedDict()
        else:
            self.vocab = load_vocab(vocab_file)
            self.ids_to_tokens = OrderedDict(
                (ids, tok) for tok, ids in self.vocab.items()
            )

        super().__init__(
            pad_token=pad_token,
            unk_token=unk_token,
            bos_token=bos_token,
            eos_token=eos_token,
            do_lower_case=do_lower_case,
            **kwargs
        )

    @property
    def vocab_size(self):
        """Number of entries in the vocabulary."""
        return len(self.vocab)

    def get_vocab(self):
        """Return the token -> id mapping."""
        return self.vocab

    def _convert_token_to_id(self, token):
        """Map a single character to its id; unknown characters map to [unk]."""
        if self.do_lower_case:
            token = token.lower()
        # NOTE(review): raises KeyError if the vocab is empty (no [unk] entry)
        # — the original behaved the same way.
        return self.vocab.get(token, self.vocab[self.unk_token])

    def _convert_id_to_token(self, index):
        """Map an id back to its character token."""
        return self.ids_to_tokens[index]

    def _tokenize(self, text):
        """Split *text* into a list of single characters."""
        if self.do_lower_case:
            text = text.lower()
        return list(text)

    def convert_tokens_to_string(self, tokens):
        """Concatenate character tokens back into a string."""
        return "".join(tokens)

    def build_inputs_with_special_tokens(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """Wrap a single sequence as [bos] ... [eos]; *token_ids_1* is ignored."""
        bos = [self.bos_token_id]
        eos = [self.eos_token_id]
        return bos + token_ids_0 + eos

    def get_special_tokens_mask(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """Return 1 for the [bos]/[eos] positions, 0 for regular characters."""
        return [1] + ([0] * len(token_ids_0)) + [1]

    def create_token_type_ids_from_sequences(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """Single-segment model: every position (incl. bos/eos) gets type id 0."""
        return (len(token_ids_0) + 2) * [0]

    def save_vocabulary(
        self,
        save_directory: str,
        filename_prefix: Optional[str] = None
    ) -> Tuple[str]:
        """Write the vocabulary, one token per line, in id order.

        Raises:
            ValueError: if *save_directory* is not a directory, or the
                vocabulary ids are not consecutive starting at 0.
        """
        # `assert` is stripped under `python -O`; validate explicitly instead.
        if not os.path.isdir(save_directory):
            raise ValueError(
                f"Vocabulary path ({save_directory}) should be a directory"
            )
        vocab_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "") +
            self.vocab_files_names["vocab_file"]
        )
        index = 0
        with open(vocab_file, "w", encoding="utf-8") as writer:
            for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
                if token_index != index:
                    raise ValueError(
                        f"Non-consecutive vocabulary index {token_index} "
                        f"(expected {index}); saved vocabulary would be corrupt"
                    )
                writer.write(token + "\n")
                index += 1
        return (vocab_file,)

ruaccent/omograph_model.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
from transformers import AutoModelForSequenceClassification, AutoTokenizer
2+
import torch
3+
4+
class OmographModel:
    """Zero-shot NLI classifier that picks the best stress variant for a homograph."""

    def __init__(self) -> None:
        # Prefer GPU when available, otherwise CPU.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def load(self, path):
        """Load the NLI sequence-classification model and tokenizer from *path*."""
        model = AutoModelForSequenceClassification.from_pretrained(
            path, torch_dtype=torch.bfloat16
        )
        self.nli_model = model.to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(path)

    def classify(self, text, hypotheses):
        """Return the hypothesis with the highest entailment probability for *text*."""
        pairs = [(text, hyp) for hyp in hypotheses]
        encodings = self.tokenizer.batch_encode_plus(
            pairs, return_tensors='pt', padding=True
        )
        batch = encodings['input_ids'].to(self.device)
        with torch.no_grad():
            logits = self.nli_model(batch)[0]
        # Keep only the contradiction (0) and entailment (2) columns and
        # renormalize; column 1 of the result is P(entailment).
        two_way = logits[:, [0, 2]]
        entail_probs = two_way.softmax(dim=1)[:, 1]
        scores = [float(p) for p in entail_probs]
        return hypotheses[scores.index(max(scores))]

ruaccent/ruaccent.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
import json
2+
import pathlib
3+
from huggingface_hub import snapshot_download
4+
import os
5+
from .omograph_model import OmographModel
6+
from .accent_model import AccentModel
7+
import re
8+
9+
class RUAccent:
    """Automatic stress placement ('+' before the stressed vowel) for Russian text.

    Combines dictionary lookups with two neural models: an omograph
    disambiguator and a character-level accent model. Call :meth:`load`
    once before using any ``process_*`` method.
    """

    def __init__(self):
        self.omograph_model = OmographModel()
        self.accent_model = AccentModel()
        # Dictionaries and model weights are stored next to this file.
        self.workdir = str(pathlib.Path(__file__).resolve().parent)

    def load(self, omograph_model_size='medium', dict_load_startup=False, repo="TeraTTS/accentuator"):
        """Download (on first use) and load models and dictionaries.

        Args:
            omograph_model_size: 'small' or 'medium'.
            dict_load_startup: load the full accent dictionary up front
                (more RAM) instead of per-letter files on demand.
            repo: Hugging Face repository to download assets from.

        Raises:
            NotImplementedError: for an unknown ``omograph_model_size``.
        """
        # Fail fast on a bad model size before downloading or loading anything
        # (the original validated only after the dictionaries were loaded).
        if omograph_model_size not in ('small', 'medium'):
            raise NotImplementedError

        if not os.path.exists(self.workdir + '/dictionary') or not os.path.exists(self.workdir + '/nn'):
            snapshot_download(repo_id=repo, ignore_patterns=["*.md", '*.gitattributes'], local_dir=self.workdir)
        self.omographs = self._load_json(self.workdir + '/dictionary/omographs.json')
        self.yo_words = self._load_json(self.workdir + '/dictionary/yo_words.json')
        self.dict_load_startup = dict_load_startup
        if dict_load_startup:
            self.accents = self._load_json(self.workdir + '/dictionary/accents.json')
        self.omograph_model.load(self.workdir + f'/nn/nn_omograph/{omograph_model_size}/')
        self.accent_model.load(self.workdir + '/nn/nn_accent/')

    @staticmethod
    def _load_json(path):
        """Read a UTF-8 JSON file and close the handle.

        The original used bare ``json.load(open(...))`` which leaked the file
        handle and relied on the platform default encoding (cp1251 on Windows,
        which breaks Cyrillic dictionaries).
        """
        with open(path, encoding='utf-8') as fp:
            return json.load(fp)

    def split_by_words(self, text):
        """Lower-case *text*, replace punctuation with spaces and split into words."""
        text = text.lower()
        spec_chars = '!"#$%&\'()*,-./:;<=>?@[\\]^_`{|}~\r\n\xa0«»\t—…'
        text = re.sub('[' + spec_chars + ']', ' ', text)
        text = re.sub(' +', ' ', text)
        return text.split()

    def extract_initial_letters(self, text):
        """Return the first letter of every word longer than two characters."""
        words = self.split_by_words(text)
        return [word[0] for word in words if len(word) > 2]

    def load_dict(self, text):
        """Load only the per-letter accent dictionaries needed for *text*."""
        chars = self.extract_initial_letters(text)
        out_dict = {}
        # Deduplicate while preserving order so each per-letter file is read
        # once (the original re-opened the same file for every occurrence).
        for char in dict.fromkeys(chars):
            out_dict.update(self._load_json(f'{self.workdir}/dictionary/letter_accent/{char}.json'))
        return out_dict

    def process_punc(self, original_text, processed_text):
        """Append the non-Cyrillic remainder of each original word to the processed one.

        NOTE(review): ``split_by_words`` already strips most punctuation, so the
        appended remainder is usually empty; presumably this targets digits or
        Latin characters — confirm against callers.
        """
        original_text = self.split_by_words(original_text)
        processed_text = self.split_by_words(processed_text)
        for i, word_to_process in enumerate(original_text):
            spec_chars = 'абвгдеёжзийклмнопрстухфцчшщъыьэюя'
            word_to_append = re.sub('[' + spec_chars + ']', ' ', word_to_process)
            processed_text[i] = processed_text[i] + word_to_append.strip()
        return ' '.join(processed_text)

    def count_vowels(self, text):
        """Count Russian vowels in *text*."""
        vowels = 'аеёиоуыэюяАЕЁИОУЫЭЮЯ'
        return sum(1 for char in text if char in vowels)

    def process_omographs(self, text):
        """Resolve stress in homographs using the NLI omograph model."""
        splitted_text = self.split_by_words(text)
        founded_omographs = []
        for i, word in enumerate(splitted_text):
            variants = self.omographs.get(word)
            if variants:
                founded_omographs.append({'word': word, 'variants': variants, 'position': i})
        for omograph in founded_omographs:
            # Mark the ambiguous word so the classifier knows which token to resolve.
            splitted_text[omograph['position']] = f"<w>{splitted_text[omograph['position']]}</w>"
            cls = self.omograph_model.classify(' '.join(splitted_text), omograph['variants'])
            splitted_text[omograph['position']] = cls
        return ' '.join(splitted_text)

    def process_yo(self, text):
        """Restore 'ё' in words that are conventionally spelled with 'е'."""
        splitted_text = self.split_by_words(text)
        for i, word in enumerate(splitted_text):
            splitted_text[i] = self.yo_words.get(word, word)
        return ' '.join(splitted_text)

    def process_accent(self, text):
        """Place stress marks via the dictionary, falling back to the neural model."""
        if not self.dict_load_startup:
            self.accents = self.load_dict(text)
        splitted_text = self.split_by_words(text)
        for i, word in enumerate(splitted_text):
            stressed_word = self.accents.get(word, word)
            # Words with a single vowel need no mark; unknown multi-vowel
            # words go through the neural accent model.
            if '+' not in stressed_word and self.count_vowels(word) > 1:
                splitted_text[i] = self.accent_model.put_accent(word)
            else:
                splitted_text[i] = stressed_word
        return ' '.join(splitted_text)

    def process_all(self, text):
        """Full pipeline: yo-restoration, homograph resolution, then accentuation."""
        processed_text = self.process_yo(text)
        processed_text = self.process_omographs(processed_text)
        return self.process_accent(processed_text)

setup.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# Packaging configuration for the `ruaccent` library.
from setuptools import setup, find_packages

setup(
    name='ruaccent',
    version='1.0.0',
    author='Denis Petrov',
    author_email='arduino4b@gmail.com',
    description='A Russian text accentuation tool',
    license='MIT',
    url='https://github.com/Den4ikAI/ruaccent',
    # Automatically discover the `ruaccent` package directory.
    packages=find_packages(),
    install_requires=[
        'huggingface_hub',
        # NOTE(review): torch is pinned to an exact version; presumably for
        # model compatibility — confirm whether a version range would be safer.
        'torch==1.13.1',
        'transformers',
        'sentencepiece'
    ],
    classifiers=[
        'Development Status :: 5 - Production/Stable',
        'Intended Audience :: Developers',
        'License :: OSI Approved :: MIT License',
        'Programming Language :: Python :: 3',
        'Operating System :: Microsoft :: Windows',
        'Operating System :: Unix',
        'Operating System :: MacOS',
    ],
)

0 commit comments

Comments
 (0)