11import json
22import pathlib
3- from huggingface_hub import snapshot_download
3+ from huggingface_hub import HfFileSystem , hf_hub_download
44import os
55from os .path import join as join_path
66from .omograph_model import OmographModel
77from .accent_model import AccentModel
8+ from .yo_omograph_model import YomographModel
89from .text_split import split_by_sentences
910import re
1011
1112
1213class RUAccent :
1314 def __init__ (self , workdir = None ):
1415 self .omograph_model = OmographModel ()
16+ self .yo_omograph_model = YomographModel ()
1517 self .accent_model = AccentModel ()
18+ self .fs = HfFileSystem ()
19+ self .omograph_models_paths = {'big' : '/nn/nn_omograph/big' , 'medium' : '/nn/nn_omograph/medium' , 'small' : '/nn/nn_omograph/small' }
20+ self .accentuator_paths = ['/nn/nn_accent' , '/dictionary' ]
21+ self .yo_omograph_path = ['/nn/nn_yo_omograph' ]
1622 if not workdir :
1723 self .workdir = str (pathlib .Path (__file__ ).resolve ().parent )
1824 else :
1925 self .workdir = workdir
2026
27+
2128 def load (
2229 self ,
2330 omograph_model_size = "big" ,
2431 use_dictionary = False ,
2532 custom_dict = {},
2633 custom_homographs = {},
34+ load_yo_homographs_model = False ,
2735 repo = "TeraTTS/accentuator" ,
2836 ):
2937
38+ self .load_yo_homographs_model = load_yo_homographs_model
3039 self .custom_dict = custom_dict
3140 self .accents = {}
3241 if not os .path .exists (
3342 join_path (self .workdir , "dictionary" )
34- ) or not os .path .exists (join_path (self .workdir , "nn" )):
35- snapshot_download (
36- repo_id = repo ,
37- ignore_patterns = ["*.md" , "*.gitattributes" ],
38- local_dir = self .workdir ,
39- local_dir_use_symlinks = False ,
40- )
43+ ):
44+ for path in self .accentuator_paths :
45+ files = self .fs .ls (repo + path )
46+ for file in files :
47+ hf_hub_download (repo_id = repo , local_dir_use_symlinks = False , local_dir = self .workdir , filename = file ['name' ].replace (repo + '/' , '' ))
48+
49+ if not os .path .exists (join_path (self .workdir , "nn" )):
50+ os .mkdir (join_path (self .workdir , "nn" ))
51+
52+ if not os .path .exists (join_path (self .workdir , "nn" , "nn_omograph" , omograph_model_size )):
53+ model_path = self .omograph_models_paths .get (omograph_model_size , None )
54+ if model_path :
55+ files = self .fs .ls (repo + model_path )
56+ for file in files :
57+ hf_hub_download (repo_id = repo , local_dir_use_symlinks = False , local_dir = self .workdir , filename = file ['name' ].replace (repo + '/' , '' ))
58+ else :
59+ raise FileNotFoundError
60+
4161 self .omographs = json .load (
4262 open (join_path (self .workdir , "dictionary/omographs.json" ), encoding = 'utf-8' )
4363 )
44- #self.yo_omographs = json.load(
45- # open(join_path(self.workdir, "dictionary/yo_omographs.json"), encoding='utf-8')
46- #)
47- #self.omographs.update(self.yo_omographs)
4864 self .omographs .update (custom_homographs )
65+
66+ if load_yo_homographs_model :
67+ if not os .path .exists (join_path (self .workdir , "nn" , "nn_yo_omograph" )):
68+ for path in self .yo_omograph_path :
69+ files = self .fs .ls (repo + path )
70+ for file in files :
71+ hf_hub_download (repo_id = repo , local_dir_use_symlinks = False , local_dir = self .workdir , filename = file ['name' ].replace (repo + '/' , '' ))
72+
73+ self .yo_omographs = json .load (
74+ open (join_path (self .workdir , "dictionary/yo_omographs.json" ), encoding = 'utf-8' )
75+ )
76+ self .yo_omograph_model .load (join_path (self .workdir , "nn/nn_yo_omograph/" ))
77+
4978 self .yo_words = json .load (
5079 open (join_path (self .workdir , "dictionary/yo_words.json" ), encoding = 'utf-8' )
5180 )
@@ -57,16 +86,13 @@ def load(
5786
5887 self .accents .update (self .custom_dict )
5988
60- if omograph_model_size not in ["small" , "big" ]:
61- raise NotImplementedError
62-
6389 self .omograph_model .load (
6490 join_path (self .workdir , f"nn/nn_omograph/{ omograph_model_size } /" )
65-
6691 )
6792 self .accent_model .load (join_path (self .workdir , "nn/nn_accent/" ))
6893
6994
95+
7096 def split_by_words (self , string ):
7197 result = re .findall (r"\w*(?:\+\w+)*|[^\w\s]+" , string .lower ())
7298 return [res for res in result if res ]
@@ -115,6 +141,26 @@ def _process_omographs(self, text):
115141 splitted_text [omograph ["position" ]] = cls
116142 return splitted_text
117143
144+ def _process_yo_omographs (self , text ):
145+ splitted_text = text
146+
147+ founded_omographs = []
148+ for i , word in enumerate (splitted_text ):
149+ variants = self .yo_omographs .get (word )
150+ if variants :
151+ founded_omographs .append (
152+ {"word" : word , "variants" : variants , "position" : i }
153+ )
154+ for omograph in founded_omographs :
155+ splitted_text [
156+ omograph ["position" ]
157+ ] = f"<w>{ splitted_text [omograph ['position' ]]} </w>"
158+ cls = self .yo_omograph_model .classify (
159+ " " .join (splitted_text ), omograph ["variants" ]
160+ )
161+ splitted_text [omograph ["position" ]] = cls
162+ return splitted_text
163+
118164 def _process_accent (self , text ):
119165 splitted_text = text
120166
@@ -126,23 +172,27 @@ def _process_accent(self, text):
126172 splitted_text [i ] = stressed_word
127173 return splitted_text
128174
129- def process_yo (self , text ):
175+ def process_yo (self , text , process_yo_omographs = False ):
130176 sentences = split_by_sentences (text )
131177 outputs = []
132178 for sentence in sentences :
133179 text = self .split_by_words (sentence )
134180 processed_text = self ._process_yo (text )
181+ if process_yo_omographs :
182+ processed_text = self ._process_yo_omographs (processed_text )
135183 processed_text = " " .join (processed_text )
136184 processed_text = self .delete_spaces_before_punc (processed_text )
137185 outputs .append (processed_text )
138186 return " " .join (outputs )
139187
140- def process_all (self , text ):
188+ def process_all (self , text , process_yo_omographs = False ):
141189 sentences = split_by_sentences (text )
142190 outputs = []
143191 for sentence in sentences :
144192 text = self .split_by_words (sentence )
145193 processed_text = self ._process_yo (text )
194+ if process_yo_omographs :
195+ processed_text = self ._process_yo_omographs (processed_text )
146196 processed_text = self ._process_omographs (processed_text )
147197 processed_text = self ._process_accent (processed_text )
148198 processed_text = " " .join (processed_text )
0 commit comments