@@ -29,9 +29,9 @@ def __init__(self):
2929 'small_poetry' : '/nn/nn_omograph/small_poetry' ,
3030 'turbo' : '/nn/nn_omograph/turbo' }
3131
32- self .accentuator_paths = ['/nn/nn_accent' , '/nn/nn_stress_usage_predictor' ,'/nn/nn_yo_homograph_resolver' , '/dictionary' , '/dictionary/rule_engine' , "/koziev/rulemma" , "/koziev/rupostagger" , "/koziev/rupostagger/database" ]
32+ self .accentuator_paths = ['/nn/nn_accent' , '/nn/nn_stress_usage_predictor' ,'/nn/nn_yo_homograph_resolver' , '/dictionary' , '/dictionary/rule_engine' ]
3333 self .letters_accent = {'о' : '+о' , 'О' : '+О' }
34-
34+ self . koziev_paths = [ "/koziev/rulemma" , "/koziev/rupostagger" , "/koziev/rupostagger/database" ]
3535 def load (
3636 self ,
3737 omograph_model_size = "big_poetry" ,
@@ -46,7 +46,7 @@ def load(
4646 self .workdir = workdir
4747 else :
4848 self .workdir = str (pathlib .Path (__file__ ).resolve ().parent )
49-
49+ self . module_path = str ( pathlib . Path ( __file__ ). resolve (). parent )
5050 self .custom_dict = custom_dict
5151 self .accents = {}
5252 if not os .path .exists (
@@ -66,10 +66,16 @@ def load(
6666 if model_path :
6767 files = self .fs .ls (repo + model_path )
6868 for file in files :
69- hf_hub_download (repo_id = repo , local_dir_use_symlinks = False , local_dir = self .workdir , filename = file ['name' ].replace (repo + '/' , '' ))
69+ if file ["type" ] == "file" :
70+ hf_hub_download (repo_id = repo , local_dir_use_symlinks = False , local_dir = self .workdir , filename = file ['name' ].replace (repo + '/' , '' ))
7071 else :
7172 raise FileNotFoundError
72-
73+ if not os .path .exists (join_path (self .module_path , "koziev" )):
74+ for path in self .koziev_paths :
75+ files = self .fs .ls (repo + path )
76+ for file in files :
77+ if file ["type" ] == "file" :
78+ hf_hub_download (repo_id = repo , local_dir_use_symlinks = False , local_dir = self .module_path , filename = file ['name' ].replace (repo + '/' , '' ))
7379 from .rule_accent_engine import RuleEngine
7480 self .rule_accent = RuleEngine ()
7581 self .omographs = json .load (
@@ -91,17 +97,14 @@ def load(
9197 ))
9298 self .accents .update (self .custom_dict )
9399 self .accents .update (self .letters_accent )
94- self .omograph_model .load (
95- join_path (self .workdir , f"nn/nn_omograph/{ omograph_model_size } /" )
96-
97- #"../../../pretrain_ruaccent_turbo/onnx_deberta"
98- )
100+ self .omograph_model .load (join_path (self .workdir , f"nn/nn_omograph/{ omograph_model_size } /" ), device = device )
101+
99102 self .yo_homographs = json .load (
100103 gzip .open (join_path (self .workdir , "dictionary" ,"yo_homographs.json.gz" ))
101104 )
102- self .accent_model .load (join_path (self .workdir , "nn" ,"nn_accent/" ))
103- self .stress_usage_predictor .load (join_path (self .workdir , "nn" ,"nn_stress_usage_predictor/" ))
104- self .yo_homograph_model .load (join_path (self .workdir , "nn" ,"nn_yo_homograph_resolver" ))
105+ self .accent_model .load (join_path (self .workdir , "nn" ,"nn_accent/" ), device = device )
106+ self .stress_usage_predictor .load (join_path (self .workdir , "nn" ,"nn_stress_usage_predictor/" ), device = device )
107+ self .yo_homograph_model .load (join_path (self .workdir , "nn" ,"nn_yo_homograph_resolver" ), device = device )
105108 self .rule_accent .load (join_path (self .workdir , "dictionary" ,"rule_engine" ))
106109
107110 def split_by_words (self , string ):
0 commit comments