|
| 1 | +import json |
| 2 | +import pathlib |
| 3 | +from huggingface_hub import snapshot_download |
| 4 | +import os |
| 5 | +from .omograph_model import OmographModel |
| 6 | +from .accent_model import AccentModel |
| 7 | +import re |
| 8 | + |
class RUAccent:
    """Text pipeline: "yo" word replacement, omograph resolution, accent placement.

    Typical use is ``load()`` once, then ``process_all(text)``.  All
    ``process_*`` methods lowercase the text and drop punctuation, because
    they tokenize through :meth:`split_by_words`.

    Attributes set by :meth:`load` (they do not exist before it is called):
        omographs: dict loaded from ``dictionary/omographs.json``.
        yo_words: dict loaded from ``dictionary/yo_words.json``.
        dict_load_startup: whether the full accent dictionary is kept in memory.
        accents: dict loaded from ``dictionary/accents.json`` (only when
            ``dict_load_startup`` is true; otherwise rebuilt per call by
            :meth:`process_accent`).
    """

    # Characters that split_by_words turns into spaces.  '+' is not part of
    # the set, so it survives tokenization and stays inside the words.
    _SEPARATOR_RE = re.compile('[' + '!"#$%&\'()*,-./:;<=>?@[\\]^_`{|}~\r\n\xa0«»\t—…' + ']')
    # Lowercase Cyrillic letters removed by process_punc.
    _CYRILLIC_RE = re.compile('[' + 'абвгдеёжзийклмнопрстухфцчшщъыьэюя' + ']')
    # Vowels counted by count_vowels (both cases).
    _VOWELS = frozenset('аеёиоуыэюяАЕЁИОУЫЭЮЯ')
    # Model sizes accepted by load().
    _OMOGRAPH_MODEL_SIZES = ('small', 'medium')

    def __init__(self):
        self.omograph_model = OmographModel()
        self.accent_model = AccentModel()
        # Data files live next to this module.
        self.workdir = str(pathlib.Path(__file__).resolve().parent)

    @staticmethod
    def _read_json(path):
        """Parse the JSON file at *path*, read as UTF-8, closing it afterwards."""
        with open(path, encoding='utf-8') as handle:
            return json.load(handle)

    def load(self, omograph_model_size='medium', dict_load_startup=False, repo="TeraTTS/accentuator"):
        """Fetch the data files if needed and load dictionaries and models.

        Args:
            omograph_model_size: ``'small'`` or ``'medium'``; selects the
                ``nn/nn_omograph/<size>/`` directory handed to the omograph model.
            dict_load_startup: when true, ``dictionary/accents.json`` is loaded
                now and reused; when false, :meth:`process_accent` loads
                per-letter dictionaries on every call instead.
            repo: Hugging Face repo id downloaded into the module directory
                when ``dictionary`` or ``nn`` is missing there.

        Raises:
            NotImplementedError: if *omograph_model_size* is not supported.
        """
        # Validate first so a bad argument does not trigger a download.
        if omograph_model_size not in self._OMOGRAPH_MODEL_SIZES:
            raise NotImplementedError(
                f'unsupported omograph_model_size: {omograph_model_size!r}'
            )

        data_present = (
            os.path.exists(self.workdir + '/dictionary')
            and os.path.exists(self.workdir + '/nn')
        )
        if not data_present:
            snapshot_download(
                repo_id=repo,
                ignore_patterns=["*.md", '*.gitattributes'],
                local_dir=self.workdir,
            )

        self.omographs = self._read_json(self.workdir + '/dictionary/omographs.json')
        self.yo_words = self._read_json(self.workdir + '/dictionary/yo_words.json')
        self.dict_load_startup = dict_load_startup
        if dict_load_startup:
            self.accents = self._read_json(self.workdir + '/dictionary/accents.json')

        # The path strings (including the trailing slash) are passed through
        # unchanged; how the models use them is not visible from this class.
        self.omograph_model.load(self.workdir + f'/nn/nn_omograph/{omograph_model_size}/')
        self.accent_model.load(self.workdir + '/nn/nn_accent/')

    def split_by_words(self, text: str) -> list:
        """Lowercase *text*, replace separator characters with spaces, split on whitespace."""
        # str.split() already collapses runs of whitespace and drops empty
        # tokens, so no separate space-squeezing pass is needed.
        return self._SEPARATOR_RE.sub(' ', text.lower()).split()

    def extract_initial_letters(self, text: str) -> list:
        """Return the first character of every word longer than two characters.

        The result keeps word order and may contain repeats.
        """
        return [word[0] for word in self.split_by_words(text) if len(word) > 2]

    def load_dict(self, text: str) -> dict:
        """Merge the per-letter accent dictionaries needed for *text*.

        For each initial letter from :meth:`extract_initial_letters` the file
        ``dictionary/letter_accent/<letter>.json`` is merged into the result.
        Each file is read from disk at most once per call, and letters that
        have no dictionary file are skipped instead of raising.
        """
        per_letter = {}  # letter -> parsed dict, or None when the file is missing
        merged = {}
        for char in self.extract_initial_letters(text):
            if char not in per_letter:
                try:
                    per_letter[char] = self._read_json(
                        f'{self.workdir}/dictionary/letter_accent/{char}.json'
                    )
                except FileNotFoundError:
                    per_letter[char] = None
            entries = per_letter[char]
            if entries is not None:
                # Updating in word order keeps the same "last one wins"
                # precedence as reloading the file for every word would.
                merged.update(entries)
        return merged

    def process_punc(self, original_text: str, processed_text: str) -> str:
        """Append the non-Cyrillic remainder of each original word to the processed word.

        Both texts are tokenized with :meth:`split_by_words` and paired by
        position.  From each original word the lowercase Cyrillic letters are
        blanked out; what is left (stripped of surrounding spaces) is appended
        to the processed word at the same index.

        Raises:
            IndexError: if *original_text* has more words than *processed_text*.
        """
        source_words = self.split_by_words(original_text)
        result_words = self.split_by_words(processed_text)
        for index, source_word in enumerate(source_words):
            remainder = self._CYRILLIC_RE.sub(' ', source_word).strip()
            result_words[index] = result_words[index] + remainder
        return ' '.join(result_words)

    def count_vowels(self, text: str) -> int:
        """Count the Russian vowel characters (either case) in *text*."""
        return sum(1 for char in text if char in self._VOWELS)

    def process_omographs(self, text: str) -> str:
        """Replace every word found in ``self.omographs`` with the classifier's choice.

        Each matching word is wrapped in ``<w>...</w>``, the whole sentence is
        passed to ``self.omograph_model.classify`` together with that word's
        variants, and the returned value takes the word's place.  Words are
        handled left to right, so later calls see earlier replacements.
        """
        words = self.split_by_words(text)
        for position, word in enumerate(words):
            variants = self.omographs.get(word)
            if not variants:
                continue
            words[position] = f"<w>{word}</w>"
            words[position] = self.omograph_model.classify(' '.join(words), variants)
        return ' '.join(words)

    def process_yo(self, text: str) -> str:
        """Replace every word that is a key of ``self.yo_words`` with its mapped value."""
        return ' '.join(
            self.yo_words.get(word, word) for word in self.split_by_words(text)
        )

    def process_accent(self, text: str) -> str:
        """Place accents using the dictionary first and the accent model as fallback.

        When the full dictionary was not loaded at startup, the per-letter
        dictionaries for this text replace ``self.accents``.  A word whose
        dictionary form (or the word itself, when absent from the dictionary)
        contains no ``'+'`` and that has more than one vowel is sent to
        ``self.accent_model.put_accent``; every other word uses the
        dictionary form.
        """
        if not self.dict_load_startup:
            self.accents = self.load_dict(text)

        words = self.split_by_words(text)
        for position, word in enumerate(words):
            from_dict = self.accents.get(word, word)
            needs_model = '+' not in from_dict and self.count_vowels(word) > 1
            words[position] = self.accent_model.put_accent(word) if needs_model else from_dict
        return ' '.join(words)

    def process_all(self, text: str) -> str:
        """Run the full pipeline: ``process_yo``, ``process_omographs``, ``process_accent``."""
        with_yo = self.process_yo(text)
        with_omographs = self.process_omographs(with_yo)
        return self.process_accent(with_omographs)
0 commit comments