Skip to content

Commit e9ad2d5

Browse files
committed
initial commit
1 parent 2c12023 commit e9ad2d5

File tree

1 file changed

+59
-21
lines changed

1 file changed

+59
-21
lines changed

bark/generation.py

Lines changed: 59 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import torch.nn.functional as F
1515
import tqdm
1616
from transformers import BertTokenizer
17+
from huggingface_hub import hf_hub_download
1718

1819
from .model import GPTConfig, GPT
1920
from .model_fine import FineGPT, FineGPTConfig
@@ -89,31 +90,64 @@ def autocast():
8990
GLOBAL_ENABLE_MPS = os.environ.get("SUNO_ENABLE_MPS", False)
9091
OFFLOAD_CPU = os.environ.get("SUNO_OFFLOAD_CPU", False)
9192

92-
REMOTE_BASE_URL = "https://dl.suno-models.io/bark/models/v0/"
93+
# REMOTE_BASE_URL = "https://dl.suno-models.io/bark/models/v0/"
94+
95+
# REMOTE_MODEL_PATHS = {
96+
# "text_small": {
97+
# "path": os.path.join(REMOTE_BASE_URL, "text.pt"),
98+
# "checksum": "b3e42bcbab23b688355cd44128c4cdd3",
99+
# },
100+
# "coarse_small": {
101+
# "path": os.path.join(REMOTE_BASE_URL, "coarse.pt"),
102+
# "checksum": "5fe964825e3b0321f9d5f3857b89194d",
103+
# },
104+
# "fine_small": {
105+
# "path": os.path.join(REMOTE_BASE_URL, "fine.pt"),
106+
# "checksum": "5428d1befe05be2ba32195496e58dc90",
107+
# },
108+
# "text": {
109+
# "path": os.path.join(REMOTE_BASE_URL, "text_2.pt"),
110+
# "checksum": "54afa89d65e318d4f5f80e8e8799026a",
111+
# },
112+
# "coarse": {
113+
# "path": os.path.join(REMOTE_BASE_URL, "coarse_2.pt"),
114+
# "checksum": "8a98094e5e3a255a5c9c0ab7efe8fd28",
115+
# },
116+
# "fine": {
117+
# "path": os.path.join(REMOTE_BASE_URL, "fine_2.pt"),
118+
# "checksum": "59d184ed44e3650774a2f0503a48a97b",
119+
# },
120+
# }
93121

94122
# Locations of the Bark model checkpoints on the Hugging Face Hub.
# Small variants live in the "reach-vb/bark-small" repo; the full-size
# checkpoints (the *_2.pt files) live in "reach-vb/bark".
# "checksum" is the expected MD5 of the downloaded file.
REMOTE_MODEL_PATHS = {
    "text_small": {
        "repo_id": "reach-vb/bark-small",
        "file_name": "text.pt",
        "checksum": "b3e42bcbab23b688355cd44128c4cdd3",
    },
    "coarse_small": {
        "repo_id": "reach-vb/bark-small",
        "file_name": "coarse.pt",
        "checksum": "5fe964825e3b0321f9d5f3857b89194d",
    },
    "fine_small": {
        "repo_id": "reach-vb/bark-small",
        "file_name": "fine.pt",
        "checksum": "5428d1befe05be2ba32195496e58dc90",
    },
    "text": {
        "repo_id": "reach-vb/bark",
        "file_name": "text_2.pt",
        "checksum": "54afa89d65e318d4f5f80e8e8799026a",
    },
    "coarse": {
        "repo_id": "reach-vb/bark",
        "file_name": "coarse_2.pt",
        "checksum": "8a98094e5e3a255a5c9c0ab7efe8fd28",
    },
    "fine": {
        # FIX: fine_2.pt is the full-size fine checkpoint — it lives in the
        # "bark" repo like text_2.pt and coarse_2.pt, not in "bark-small".
        "repo_id": "reach-vb/bark",
        "file_name": "fine_2.pt",
        "checksum": "59d184ed44e3650774a2f0503a48a97b",
    },
}
@@ -165,20 +199,24 @@ def _parse_s3_filepath(s3_filepath):
165199
return bucket_name, rel_s3_filepath
166200

167201

168-
def _download(from_s3_path, to_local_path):
169-
os.makedirs(CACHE_DIR, exist_ok=True)
170-
response = requests.get(from_s3_path, stream=True)
171-
total_size_in_bytes = int(response.headers.get("content-length", 0))
172-
block_size = 1024
173-
progress_bar = tqdm.tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
174-
with open(to_local_path, "wb") as file:
175-
for data in response.iter_content(block_size):
176-
progress_bar.update(len(data))
177-
file.write(data)
178-
progress_bar.close()
179-
if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
180-
raise ValueError("ERROR, something went wrong")
202+
# def _download(from_s3_path, to_local_path):
203+
# os.makedirs(CACHE_DIR, exist_ok=True)
204+
# response = requests.get(from_s3_path, stream=True)
205+
# total_size_in_bytes = int(response.headers.get("content-length", 0))
206+
# block_size = 1024
207+
# progress_bar = tqdm.tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
208+
# with open(to_local_path, "wb") as file:
209+
# for data in response.iter_content(block_size):
210+
# progress_bar.update(len(data))
211+
# file.write(data)
212+
# progress_bar.close()
213+
# if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
214+
# raise ValueError("ERROR, something went wrong")
215+
181216

217+
def _download(from_hf_path, file_name, to_local_path):
    """Download ``file_name`` from the HF Hub repo ``from_hf_path`` to ``to_local_path``.

    The caller (`_load_model`) immediately does ``torch.load(to_local_path)``,
    so the checkpoint must end up at exactly that path.  The previous version
    passed ``to_local_path`` (a *file* path) as ``cache_dir=``, which made
    ``hf_hub_download`` create a hub-cache directory tree rooted at that path
    and the caller could never find the file.
    """
    import shutil  # function-scope import: keeps the file-level import block unchanged

    os.makedirs(CACHE_DIR, exist_ok=True)
    # hf_hub_download returns the actual on-disk path of the fetched file
    # inside the hub cache; copy it to where the caller expects it.
    cached_path = hf_hub_download(repo_id=from_hf_path, filename=file_name, cache_dir=CACHE_DIR)
    if os.path.abspath(cached_path) != os.path.abspath(to_local_path):
        shutil.copyfile(cached_path, to_local_path)
182220

183221
class InferenceContext:
184222
def __init__(self, benchmark=False):
@@ -243,7 +281,7 @@ def _load_model(ckpt_path, device, use_small=False, model_type="text"):
243281
os.remove(ckpt_path)
244282
if not os.path.exists(ckpt_path):
245283
logger.info(f"{model_type} model not found, downloading into `{CACHE_DIR}`.")
246-
_download(model_info["path"], ckpt_path)
284+
_download(model_info["repo_id"], model_info["file_name"], ckpt_path)
247285
checkpoint = torch.load(ckpt_path, map_location=device)
248286
# this is a hack
249287
model_args = checkpoint["model_args"]

0 commit comments

Comments
 (0)