@@ -36,6 +36,7 @@ def semantic_to_waveform(
3636 history_prompt : Optional [str ] = None ,
3737 temp : float = 0.7 ,
3838 silent : bool = False ,
39+ output_full : bool = False ,
3940):
4041 """Generate audio array from semantic input.
4142
@@ -44,31 +45,49 @@ def semantic_to_waveform(
4445 history_prompt: history choice for audio cloning
4546 temp: generation temperature (1.0 more diverse, 0.0 more conservative)
4647 silent: disable progress bar
48+ output_full: return full generation to be used as a history prompt
4749
4850 Returns:
4951 numpy audio array at sample frequency 24khz
5052 """
51- x_coarse_gen = generate_coarse (
53+ coarse_tokens = generate_coarse (
5254 semantic_tokens ,
5355 history_prompt = history_prompt ,
5456 temp = temp ,
5557 silent = silent ,
5658 )
57- x_fine_gen = generate_fine (
58- x_coarse_gen ,
59+ fine_tokens = generate_fine (
60+ coarse_tokens ,
5961 history_prompt = history_prompt ,
6062 temp = 0.5 ,
6163 )
62- audio_arr = codec_decode (x_fine_gen )
64+ audio_arr = codec_decode (fine_tokens )
65+ if output_full :
66+ full_generation = {
67+ "semantic_prompt" : semantic_tokens ,
68+ "coarse_prompt" : coarse_tokens ,
69+ "fine_prompt" : fine_tokens ,
70+ }
71+ return full_generation , audio_arr
6372 return audio_arr
6473
6574
75+ def save_as_prompt (filepath , full_generation ):
76+ assert (filepath .endswith (".npz" ))
77+ assert (isinstance (full_generation , dict ))
78+ assert ("semantic_prompt" in full_generation )
79+ assert ("coarse_prompt" in full_generation )
80+ assert ("fine_prompt" in full_generation )
81+ np .savez (filepath , ** full_generation )
82+
83+
6684def generate_audio (
6785 text : str ,
6886 history_prompt : Optional [str ] = None ,
6987 text_temp : float = 0.7 ,
7088 waveform_temp : float = 0.7 ,
7189 silent : bool = False ,
90+ output_full : bool = False ,
7291):
7392 """Generate audio array from input text.
7493
@@ -78,14 +97,24 @@ def generate_audio(
7897 text_temp: generation temperature (1.0 more diverse, 0.0 more conservative)
7998 waveform_temp: generation temperature (1.0 more diverse, 0.0 more conservative)
8099 silent: disable progress bar
100+ output_full: return full generation to be used as a history prompt
81101
82102 Returns:
83103 numpy audio array at sample frequency 24khz
84104 """
85- x_semantic = text_to_semantic (
105+ semantic_tokens = text_to_semantic (
86106 text , history_prompt = history_prompt , temp = text_temp , silent = silent ,
87107 )
88- audio_arr = semantic_to_waveform (
89- x_semantic , history_prompt = history_prompt , temp = waveform_temp , silent = silent ,
108+ out = semantic_to_waveform (
109+ semantic_tokens ,
110+ history_prompt = history_prompt ,
111+ temp = waveform_temp ,
112+ silent = silent ,
113+ output_full = output_full ,
90114 )
115+ if output_full :
116+ full_generation , audio_arr = out
117+ return full_generation , audio_arr
118+ else :
119+ audio_arr = out
91120 return audio_arr
0 commit comments