|
14 | 14 | import torch.nn.functional as F |
15 | 15 | import tqdm |
16 | 16 | from transformers import BertTokenizer |
| 17 | +from huggingface_hub import hf_hub_download |
17 | 18 |
|
18 | 19 | from .model import GPTConfig, GPT |
19 | 20 | from .model_fine import FineGPT, FineGPTConfig |
@@ -89,31 +90,64 @@ def autocast(): |
89 | 90 | GLOBAL_ENABLE_MPS = os.environ.get("SUNO_ENABLE_MPS", False) |
90 | 91 | OFFLOAD_CPU = os.environ.get("SUNO_OFFLOAD_CPU", False) |
91 | 92 |
|
92 | | -REMOTE_BASE_URL = "https://dl.suno-models.io/bark/models/v0/" |
| 93 | +# REMOTE_BASE_URL = "https://dl.suno-models.io/bark/models/v0/" |
| 94 | + |
| 95 | +# REMOTE_MODEL_PATHS = { |
| 96 | +# "text_small": { |
| 97 | +# "path": os.path.join(REMOTE_BASE_URL, "text.pt"), |
| 98 | +# "checksum": "b3e42bcbab23b688355cd44128c4cdd3", |
| 99 | +# }, |
| 100 | +# "coarse_small": { |
| 101 | +# "path": os.path.join(REMOTE_BASE_URL, "coarse.pt"), |
| 102 | +# "checksum": "5fe964825e3b0321f9d5f3857b89194d", |
| 103 | +# }, |
| 104 | +# "fine_small": { |
| 105 | +# "path": os.path.join(REMOTE_BASE_URL, "fine.pt"), |
| 106 | +# "checksum": "5428d1befe05be2ba32195496e58dc90", |
| 107 | +# }, |
| 108 | +# "text": { |
| 109 | +# "path": os.path.join(REMOTE_BASE_URL, "text_2.pt"), |
| 110 | +# "checksum": "54afa89d65e318d4f5f80e8e8799026a", |
| 111 | +# }, |
| 112 | +# "coarse": { |
| 113 | +# "path": os.path.join(REMOTE_BASE_URL, "coarse_2.pt"), |
| 114 | +# "checksum": "8a98094e5e3a255a5c9c0ab7efe8fd28", |
| 115 | +# }, |
| 116 | +# "fine": { |
| 117 | +# "path": os.path.join(REMOTE_BASE_URL, "fine_2.pt"), |
| 118 | +# "checksum": "59d184ed44e3650774a2f0503a48a97b", |
| 119 | +# }, |
| 120 | +# } |
93 | 121 |
|
94 | 122 | REMOTE_MODEL_PATHS = { |
95 | 123 | "text_small": { |
96 | | - "path": os.path.join(REMOTE_BASE_URL, "text.pt"), |
| 124 | + "repo_id": "reach-vb/bark-small", |
| 125 | + "file_name": "text.pt", |
97 | 126 | "checksum": "b3e42bcbab23b688355cd44128c4cdd3", |
98 | 127 | }, |
99 | 128 | "coarse_small": { |
100 | | - "path": os.path.join(REMOTE_BASE_URL, "coarse.pt"), |
| 129 | + "repo_id": "reach-vb/bark-small", |
| 130 | + "file_name": "coarse.pt", |
101 | 131 | "checksum": "5fe964825e3b0321f9d5f3857b89194d", |
102 | 132 | }, |
103 | 133 | "fine_small": { |
104 | | - "path": os.path.join(REMOTE_BASE_URL, "fine.pt"), |
| 134 | + "repo_id": "reach-vb/bark-small", |
| 135 | + "file_name": "fine.pt", |
105 | 136 | "checksum": "5428d1befe05be2ba32195496e58dc90", |
106 | 137 | }, |
107 | 138 | "text": { |
108 | | - "path": os.path.join(REMOTE_BASE_URL, "text_2.pt"), |
| 139 | + "repo_id": "reach-vb/bark", |
| 140 | + "file_name": "text_2.pt", |
109 | 141 | "checksum": "54afa89d65e318d4f5f80e8e8799026a", |
110 | 142 | }, |
111 | 143 | "coarse": { |
112 | | - "path": os.path.join(REMOTE_BASE_URL, "coarse_2.pt"), |
| 144 | + "repo_id": "reach-vb/bark", |
| 145 | + "file_name": "coarse_2.pt", |
113 | 146 | "checksum": "8a98094e5e3a255a5c9c0ab7efe8fd28", |
114 | 147 | }, |
115 | 148 | "fine": { |
116 | | - "path": os.path.join(REMOTE_BASE_URL, "fine_2.pt"), |
 |  | 149 | +        "repo_id": "reach-vb/bark", |
| 150 | + "file_name": "fine_2.pt", |
117 | 151 | "checksum": "59d184ed44e3650774a2f0503a48a97b", |
118 | 152 | }, |
119 | 153 | } |
@@ -165,20 +199,24 @@ def _parse_s3_filepath(s3_filepath): |
165 | 199 | return bucket_name, rel_s3_filepath |
166 | 200 |
|
167 | 201 |
|
168 | | -def _download(from_s3_path, to_local_path): |
169 | | - os.makedirs(CACHE_DIR, exist_ok=True) |
170 | | - response = requests.get(from_s3_path, stream=True) |
171 | | - total_size_in_bytes = int(response.headers.get("content-length", 0)) |
172 | | - block_size = 1024 |
173 | | - progress_bar = tqdm.tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True) |
174 | | - with open(to_local_path, "wb") as file: |
175 | | - for data in response.iter_content(block_size): |
176 | | - progress_bar.update(len(data)) |
177 | | - file.write(data) |
178 | | - progress_bar.close() |
179 | | - if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes: |
180 | | - raise ValueError("ERROR, something went wrong") |
| 202 | +# def _download(from_s3_path, to_local_path): |
| 203 | +# os.makedirs(CACHE_DIR, exist_ok=True) |
| 204 | +# response = requests.get(from_s3_path, stream=True) |
| 205 | +# total_size_in_bytes = int(response.headers.get("content-length", 0)) |
| 206 | +# block_size = 1024 |
| 207 | +# progress_bar = tqdm.tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True) |
| 208 | +# with open(to_local_path, "wb") as file: |
| 209 | +# for data in response.iter_content(block_size): |
| 210 | +# progress_bar.update(len(data)) |
| 211 | +# file.write(data) |
| 212 | +# progress_bar.close() |
| 213 | +# if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes: |
| 214 | +# raise ValueError("ERROR, something went wrong") |
| 215 | + |
181 | 216 |
|
 |  | 217 | +def _download(from_hf_path, file_name, to_local_path): |
 |  | 218 | +    os.makedirs(CACHE_DIR, exist_ok=True) |
 |  | 219 | +    hf_hub_download(repo_id=from_hf_path, filename=file_name, local_dir=os.path.dirname(to_local_path)) |
182 | 220 |
|
183 | 221 | class InferenceContext: |
184 | 222 | def __init__(self, benchmark=False): |
@@ -243,7 +281,7 @@ def _load_model(ckpt_path, device, use_small=False, model_type="text"): |
243 | 281 | os.remove(ckpt_path) |
244 | 282 | if not os.path.exists(ckpt_path): |
245 | 283 | logger.info(f"{model_type} model not found, downloading into `{CACHE_DIR}`.") |
246 | | - _download(model_info["path"], ckpt_path) |
| 284 | + _download(model_info["repo_id"], model_info["file_name"], ckpt_path) |
247 | 285 | checkpoint = torch.load(ckpt_path, map_location=device) |
248 | 286 | # this is a hack |
249 | 287 | model_args = checkpoint["model_args"] |
|
0 commit comments