Skip to content

Commit 7d39f48

Browse files
committed
allow using unconditional generations as history prompts
1 parent c372430 commit 7d39f48

File tree

2 files changed

+57
-19
lines changed

2 files changed

+57
-19
lines changed

bark/api.py

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ def semantic_to_waveform(
3636
history_prompt: Optional[str] = None,
3737
temp: float = 0.7,
3838
silent: bool = False,
39+
output_full: bool = False,
3940
):
4041
"""Generate audio array from semantic input.
4142
@@ -44,31 +45,49 @@ def semantic_to_waveform(
4445
history_prompt: history choice for audio cloning
4546
temp: generation temperature (1.0 more diverse, 0.0 more conservative)
4647
silent: disable progress bar
48+
output_full: return full generation to be used as a history prompt
4749
4850
Returns:
4951
numpy audio array at sample frequency 24khz
5052
"""
51-
x_coarse_gen = generate_coarse(
53+
coarse_tokens = generate_coarse(
5254
semantic_tokens,
5355
history_prompt=history_prompt,
5456
temp=temp,
5557
silent=silent,
5658
)
57-
x_fine_gen = generate_fine(
58-
x_coarse_gen,
59+
fine_tokens = generate_fine(
60+
coarse_tokens,
5961
history_prompt=history_prompt,
6062
temp=0.5,
6163
)
62-
audio_arr = codec_decode(x_fine_gen)
64+
audio_arr = codec_decode(fine_tokens)
65+
if output_full:
66+
full_generation = {
67+
"semantic_prompt": semantic_tokens,
68+
"coarse_prompt": coarse_tokens,
69+
"fine_prompt": fine_tokens,
70+
}
71+
return full_generation, audio_arr
6372
return audio_arr
6473

6574

75+
def save_as_prompt(filepath, full_generation):
    """Save a full generation to an ``.npz`` file usable as a history prompt.

    Args:
        filepath: destination path (``str`` or ``os.PathLike``); must end
            with ".npz"
        full_generation: dict containing at least the "semantic_prompt",
            "coarse_prompt" and "fine_prompt" entries; every entry is
            forwarded to ``np.savez`` as a named array

    Raises:
        ValueError: if ``filepath`` does not end with ".npz" or a required
            key is missing from ``full_generation``
        TypeError: if ``full_generation`` is not a dict
    """
    import os  # local import so this function does not rely on file-level imports

    # Normalize path-like objects (e.g. pathlib.Path) to a plain string so the
    # suffix check below works for them too.
    filepath = os.fspath(filepath)

    # Explicit raises instead of `assert`: asserts are removed under
    # `python -O`, which would let invalid input reach np.savez unchecked.
    if not filepath.endswith(".npz"):
        raise ValueError(f"filepath must end with '.npz', got {filepath!r}")
    if not isinstance(full_generation, dict):
        raise TypeError(
            f"full_generation must be a dict, got {type(full_generation).__name__}"
        )
    required_keys = ("semantic_prompt", "coarse_prompt", "fine_prompt")
    missing = [key for key in required_keys if key not in full_generation]
    if missing:
        raise ValueError(f"full_generation is missing required keys: {missing}")

    np.savez(filepath, **full_generation)
82+
83+
6684
def generate_audio(
6785
text: str,
6886
history_prompt: Optional[str] = None,
6987
text_temp: float = 0.7,
7088
waveform_temp: float = 0.7,
7189
silent: bool = False,
90+
output_full: bool = False,
7291
):
7392
"""Generate audio array from input text.
7493
@@ -78,14 +97,24 @@ def generate_audio(
7897
text_temp: generation temperature (1.0 more diverse, 0.0 more conservative)
7998
waveform_temp: generation temperature (1.0 more diverse, 0.0 more conservative)
8099
silent: disable progress bar
100+
output_full: return full generation to be used as a history prompt
81101
82102
Returns:
83103
numpy audio array at sample frequency 24khz
84104
"""
85-
x_semantic = text_to_semantic(
105+
semantic_tokens = text_to_semantic(
86106
text, history_prompt=history_prompt, temp=text_temp, silent=silent,
87107
)
88-
audio_arr = semantic_to_waveform(
89-
x_semantic, history_prompt=history_prompt, temp=waveform_temp, silent=silent,
108+
out = semantic_to_waveform(
109+
semantic_tokens,
110+
history_prompt=history_prompt,
111+
temp=waveform_temp,
112+
silent=silent,
113+
output_full=output_full,
90114
)
115+
if output_full:
116+
full_generation, audio_arr = out
117+
return full_generation, audio_arr
118+
else:
119+
audio_arr = out
91120
return audio_arr

bark/generation.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -365,10 +365,13 @@ def generate_text_semantic(
365365
text = _normalize_whitespace(text)
366366
assert len(text.strip()) > 0
367367
if history_prompt is not None:
368-
assert (history_prompt in ALLOWED_PROMPTS)
369-
semantic_history = np.load(
370-
os.path.join(CUR_PATH, "assets", "prompts", f"{history_prompt}.npz")
371-
)["semantic_prompt"]
368+
if history_prompt.endswith(".npz"):
369+
semantic_history = np.load(history_prompt)["semantic_prompt"]
370+
else:
371+
assert (history_prompt in ALLOWED_PROMPTS)
372+
semantic_history = np.load(
373+
os.path.join(CUR_PATH, "assets", "prompts", f"{history_prompt}.npz")
374+
)["semantic_prompt"]
372375
assert (
373376
isinstance(semantic_history, np.ndarray)
374377
and len(semantic_history.shape) == 1
@@ -509,10 +512,13 @@ def generate_coarse(
509512
semantic_to_coarse_ratio = COARSE_RATE_HZ / SEMANTIC_RATE_HZ * N_COARSE_CODEBOOKS
510513
max_semantic_history = int(np.floor(max_coarse_history / semantic_to_coarse_ratio))
511514
if history_prompt is not None:
512-
assert (history_prompt in ALLOWED_PROMPTS)
513-
x_history = np.load(
514-
os.path.join(CUR_PATH, "assets", "prompts", f"{history_prompt}.npz")
515-
)
515+
if history_prompt.endswith(".npz"):
516+
x_history = np.load(history_prompt)
517+
else:
518+
assert (history_prompt in ALLOWED_PROMPTS)
519+
x_history = np.load(
520+
os.path.join(CUR_PATH, "assets", "prompts", f"{history_prompt}.npz")
521+
)
516522
x_semantic_history = x_history["semantic_prompt"]
517523
x_coarse_history = x_history["coarse_prompt"]
518524
assert (
@@ -652,10 +658,13 @@ def generate_fine(
652658
and x_coarse_gen.max() <= CODEBOOK_SIZE - 1
653659
)
654660
if history_prompt is not None:
655-
assert (history_prompt in ALLOWED_PROMPTS)
656-
x_fine_history = np.load(
657-
os.path.join(CUR_PATH, "assets", "prompts", f"{history_prompt}.npz")
658-
)["fine_prompt"]
661+
if history_prompt.endswith(".npz"):
662+
x_fine_history = np.load(history_prompt)["fine_prompt"]
663+
else:
664+
assert (history_prompt in ALLOWED_PROMPTS)
665+
x_fine_history = np.load(
666+
os.path.join(CUR_PATH, "assets", "prompts", f"{history_prompt}.npz")
667+
)["fine_prompt"]
659668
assert (
660669
isinstance(x_fine_history, np.ndarray)
661670
and len(x_fine_history.shape) == 2

0 commit comments

Comments
 (0)