Skip to content

Commit 7c7cb72

Browse files
authored
Updated to deal with Gradio 4.xx versions
1 parent 3a105ce commit 7c7cb72

File tree

1 file changed

+14
-15
lines changed

1 file changed

+14
-15
lines changed

finetune.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -778,15 +778,15 @@ def get_available_voices(minimum_size_kb=1200):
778778

779779
def find_best_models(directory):
780780
"""Find files named 'best_model.pth' in the given directory."""
781-
return [file for file in Path(directory).rglob("best_model.pth")]
781+
return [str(file) for file in Path(directory).rglob("best_model.pth")]
782782

783783
def find_models(directory, extension):
784784
"""Find files with a specific extension in the given directory."""
785-
return [file for file in Path(directory).rglob(f"*.{extension}")]
785+
return [str(file) for file in Path(directory).rglob(f"*.{extension}")]
786786

787787
def find_jsons(directory, filename):
788788
"""Find files with a specific filename in the given directory."""
789-
return list(Path(directory).rglob(filename))
789+
return [str(file) for file in Path(directory).rglob(filename)]
790790

791791
# Your main directory
792792
main_directory = Path(this_dir) / "finetune" / "tmp-trn"
@@ -809,22 +809,21 @@ def find_latest_best_model(folder_path):
809809

810810
def compact_model(xtts_checkpoint_copy):
811811
this_dir = Path(__file__).parent.resolve()
812-
best_model_path_str = xtts_checkpoint_copy
812+
print("THIS DIR:", this_dir)
813+
best_model_path_str = str(xtts_checkpoint_copy) # Convert to string
814+
print("best_model_path_str", best_model_path_str)
813815

814816
# Check if the best model file exists
815-
if best_model_path_str is None:
817+
if not best_model_path_str:
816818
print("[FINETUNE] No trained model was found.")
817819
return "No trained model was found."
818820

819821
print(f"[FINETUNE] Best model path: {best_model_path_str}")
820822

821-
# Convert model_path_str to Path
822-
best_model_path = Path(best_model_path_str)
823-
824823
# Attempt to load the model
825824
try:
826-
checkpoint = torch.load(best_model_path, map_location=torch.device("cpu"))
827-
print(f"[FINETUNE] Checkpoint loaded: {best_model_path}")
825+
checkpoint = torch.load(best_model_path_str, map_location=torch.device("cpu"))
826+
print(f"[FINETUNE] Checkpoint loaded: {best_model_path_str}")
828827
except Exception as e:
829828
print("[FINETUNE] Error loading checkpoint:", e)
830829
raise
@@ -842,15 +841,15 @@ def compact_model(xtts_checkpoint_copy):
842841
del checkpoint["model"][key]
843842

844843
# Save the modified checkpoint in the target directory
845-
torch.save(checkpoint, target_dir / "model.pth")
844+
torch.save(checkpoint, str(target_dir / "model.pth")) # Convert to string
846845

847846
# Specify the files you want to copy
848-
files_to_copy = ["vocab.json", "config.json", "speakers_xtts.pth", "mel_stats.pth", "dvae.pth",]
847+
files_to_copy = ["vocab.json", "config.json", "speakers_xtts.pth", "mel_stats.pth", "dvae.pth"]
849848

850849
for file_name in files_to_copy:
851850
src_path = this_dir / base_path / base_model_path / file_name
852851
dest_path = target_dir / file_name
853-
shutil.copy(str(src_path), str(dest_path))
852+
shutil.copy(str(src_path), str(dest_path)) # Convert to string
854853

855854
source_wavs_dir = this_dir / "finetune" / "tmp-trn" / "wavs"
856855
target_wavs_dir = target_dir / "wavs"
@@ -861,7 +860,7 @@ def compact_model(xtts_checkpoint_copy):
861860
# Check if it's a file and larger than 1000 KB
862861
if file_path.is_file() and file_path.stat().st_size > 1000 * 1024:
863862
# Copy the file to the target directory
864-
shutil.copy(str(file_path), str(target_wavs_dir / file_path.name))
863+
shutil.copy(str(file_path), str(target_wavs_dir / file_path.name)) # Convert to string
865864

866865
print("[FINETUNE] Model copied to '/models/trainedmodel/'")
867866
return "Model copied to '/models/trainedmodel/'"
@@ -1764,7 +1763,7 @@ def train_model(language, train_csv, eval_csv, learning_rates, num_epochs, batch
17641763
)
17651764
# Create refresh button
17661765
refresh_button = create_refresh_button(
1767-
[xtts_checkpoint,],
1766+
[xtts_checkpoint_copy,],
17681767
[
17691768
lambda: {"choices": find_best_models(main_directory), "value": ""},
17701769
],

0 commit comments

Comments
 (0)