1212import logging
1313import math
1414import os
15+ import re
1516import socket
1617import threading
1718import time
@@ -304,6 +305,8 @@ class DefaultModelLoader(BaseModelLoader):
304305 # default number of thread when enable multithread weight loading
305306 DEFAULT_NUM_THREADS = 8
306307
308+ _MTP_PATTERN = re .compile (r"model\.mtp\.layers\.(\d+)\." )
309+
307310 @dataclasses .dataclass
308311 class Source :
309312 """A source for weights."""
@@ -351,11 +354,11 @@ def __init__(self, load_config: LoadConfig):
351354
352355 def _maybe_download_from_modelscope (
353356 self , model : str , revision : Optional [str ]
354- ) -> Optional [ str ] :
357+ ) -> str :
355358 """Download model from ModelScope hub if SGLANG_USE_MODELSCOPE is True.
356359
357- Returns the path to the downloaded model, or None if the model is not
358- downloaded from ModelScope."""
360+ Returns the path to the downloaded model, or the original model path if
361+ not downloaded from ModelScope."""
359362 if get_bool_env_var ("SGLANG_USE_MODELSCOPE" ):
360363 # download model from ModelScope hub,
361364 # lazy import so that modelscope is not required for normal use.
@@ -373,17 +376,16 @@ def _maybe_download_from_modelscope(
373376 else :
374377 model_path = model
375378 return model_path
376- return None
379+ return model
377380
378381 def _prepare_weights (
379382 self , model_name_or_path : str , revision : Optional [str ], fall_back_to_pt : bool
380383 ) -> Tuple [str , List [str ], bool ]:
381384 """Prepare weights for the model.
382385
383386 If the model is not local, it will be downloaded."""
384- model_name_or_path = (
385- self ._maybe_download_from_modelscope (model_name_or_path , revision )
386- or model_name_or_path
387+ model_name_or_path = self ._maybe_download_from_modelscope (
388+ model_name_or_path , revision
387389 )
388390
389391 is_local = os .path .isdir (model_name_or_path )
@@ -474,6 +476,7 @@ def _get_weights_iterator(
474476 ) -> Generator [Tuple [str , torch .Tensor ], None , None ]:
475477 """Get an iterator for the model weights based on the load format."""
476478 extra_config = self .load_config .model_loader_extra_config
479+ use_multithread = extra_config .get ("enable_multithread_load" , False )
477480 hf_folder , hf_weights_files , use_safetensors = self ._prepare_weights (
478481 source .model_or_path , source .revision , source .fall_back_to_pt
479482 )
@@ -504,7 +507,7 @@ def _get_weights_iterator(
504507 weights_iterator = fastsafetensors_weights_iterator (
505508 hf_weights_files ,
506509 )
507- elif extra_config . get ( "enable_multithread_load" ) :
510+ elif use_multithread :
508511 weights_iterator = multi_thread_safetensors_weights_iterator (
509512 hf_weights_files ,
510513 max_workers = extra_config .get (
@@ -518,7 +521,7 @@ def _get_weights_iterator(
518521 )
519522
520523 else :
521- if extra_config . get ( "enable_multithread_load" ) :
524+ if use_multithread :
522525 weights_iterator = multi_thread_pt_weights_iterator (
523526 hf_weights_files ,
524527 max_workers = extra_config .get (
@@ -529,28 +532,34 @@ def _get_weights_iterator(
529532 weights_iterator = pt_weights_iterator (hf_weights_files )
530533
531534 if self .load_config .draft_model_idx is not None :
532- import re
533-
534- pattern = r"model.mtp.layers.(\d+)."
535- filtered_weights = []
536- for name , tensor in weights_iterator :
537- group = re .match (pattern , name )
538- if group is not None :
539- idx = int (group .group (1 ))
540- if idx != self .load_config .draft_model_idx :
541- continue
542- new_name = name .replace (group .group (), "model.mtp.layers.0." )
543- else :
544- new_name = name
545- filtered_weights .append ((source .prefix + new_name , tensor ))
546- return tuple (filtered_weights )
535+ return self ._filter_mtp_weights (
536+ weights_iterator , source .prefix , self .load_config .draft_model_idx
537+ )
547538
548539 if self .counter_before_loading_weights == 0.0 :
549- logger .info ("Beginning to load weights" )
550540 self .counter_before_loading_weights = time .perf_counter ()
551541 # Apply the prefix.
552542 return ((source .prefix + name , tensor ) for (name , tensor ) in weights_iterator )
553543
@classmethod
def _filter_mtp_weights(
    cls, weights_iterator, prefix: str, draft_model_idx: int
) -> Tuple[Tuple[str, torch.Tensor], ...]:
    """Keep only one MTP (Multi-Token Prediction) layer and remap it to layer 0.

    Names that start with ``model.mtp.layers.<idx>.`` (as matched by
    ``cls._MTP_PATTERN``) are dropped unless ``<idx>`` equals
    ``draft_model_idx``; the surviving ones have that leading segment
    rewritten to ``model.mtp.layers.0.``. Names that do not start with the
    MTP pattern pass through unchanged.

    Args:
        weights_iterator: Iterable of ``(name, tensor)`` pairs.
        prefix: String prepended to every returned weight name.
        draft_model_idx: Index of the MTP layer to keep.

    Returns:
        A tuple of ``(prefix + name, tensor)`` pairs. The input iterable is
        fully consumed before returning.
    """
    filtered_weights = []
    for name, tensor in weights_iterator:
        match = cls._MTP_PATTERN.match(name)
        if match is not None:
            if int(match.group(1)) != draft_model_idx:
                continue
            # The match is anchored at the start of the name, so splice the
            # remainder after it. ``str.replace`` would also rewrite any
            # later occurrence of the same text inside the name.
            name = "model.mtp.layers.0." + name[match.end() :]
        filtered_weights.append((prefix + name, tensor))
    return tuple(filtered_weights)
562+
554563 def _get_all_weights (
555564 self ,
556565 model_config : ModelConfig ,
@@ -670,10 +679,6 @@ def load_model(
670679 )
671680
672681 self .counter_after_loading_weights = time .perf_counter ()
673- logger .info (
674- "Loading weights took %.2f seconds" ,
675- self .counter_after_loading_weights - self .counter_before_loading_weights ,
676- )
677682 return model .eval ()
678683
679684 @staticmethod
0 commit comments