|
96 | 96 | get_quant_config, |
97 | 97 | gguf_quant_weights_iterator, |
98 | 98 | initialize_dummy_weights, |
| 99 | + maybe_add_mtp_safetensors, |
99 | 100 | multi_thread_pt_weights_iterator, |
100 | 101 | multi_thread_safetensors_weights_iterator, |
101 | 102 | np_cache_weights_iterator, |
@@ -321,13 +322,17 @@ class Source: |
321 | 322 | fall_back_to_pt: bool = True |
322 | 323 | """Whether .pt weights can be used.""" |
323 | 324 |
|
| 325 | + model_config: Optional["ModelConfig"] = None |
| 326 | + """The model configuration (for checking architecture, etc).""" |
| 327 | + |
324 | 328 | @classmethod |
325 | 329 | def init_new(cls, model_config: ModelConfig, model): |
326 | 330 | return cls( |
327 | 331 | model_config.model_path, |
328 | 332 | model_config.revision, |
329 | 333 | prefix="", |
330 | 334 | fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load", True), |
| 335 | + model_config=model_config, |
331 | 336 | ) |
332 | 337 |
|
333 | 338 | def __init__(self, load_config: LoadConfig): |
@@ -471,6 +476,15 @@ def _get_weights_iterator( |
471 | 476 | hf_folder, hf_weights_files, use_safetensors = self._prepare_weights( |
472 | 477 | source.model_or_path, source.revision, source.fall_back_to_pt |
473 | 478 | ) |
| 479 | + |
| 480 | + if use_safetensors and source.model_config is not None: |
| 481 | + hf_weights_files = maybe_add_mtp_safetensors( |
| 482 | + hf_weights_files, |
| 483 | + hf_folder, |
| 484 | + "model.safetensors.index.json", |
| 485 | + source.model_config.hf_config, |
| 486 | + ) |
| 487 | + |
474 | 488 | if self.load_config.load_format == LoadFormat.NPCACHE: |
475 | 489 | # Currently np_cache only support *.bin checkpoints |
476 | 490 | assert use_safetensors is False |
|
0 commit comments