forked from sgl-project/sglang
-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathlayer.py
More file actions
1403 lines (1256 loc) · 52.8 KB
/
layer.py
File metadata and controls
1403 lines (1256 loc) · 52.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/model_executor/layers/fused_moe/layer.py
import logging
from enum import Enum
from typing import List, Optional, Tuple
import torch
from sglang.srt.batch_overlap.single_batch_overlap import DownGemmOverlapArgs
from sglang.srt.batch_overlap.two_batch_overlap import MaybeTboDeepEPDispatcher
from sglang.srt.compilation.piecewise_context_manager import (
get_forward_context,
is_in_piecewise_cuda_graph,
)
from sglang.srt.distributed import (
get_moe_expert_parallel_rank,
get_moe_expert_parallel_world_size,
get_moe_tensor_parallel_rank,
get_moe_tensor_parallel_world_size,
get_tp_group,
tensor_model_parallel_all_reduce,
)
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
use_symmetric_memory,
)
from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
from sglang.srt.layers.dp_attention import is_allocation_symmetric
from sglang.srt.layers.moe import (
MoeRunnerConfig,
get_deepep_mode,
get_moe_a2a_backend,
get_moe_runner_backend,
)
from sglang.srt.layers.moe.kt_ep_wrapper import (
KTEPWrapperMethod,
create_kt_config_from_server_args,
)
from sglang.srt.layers.moe.token_dispatcher import CombineInput, DispatchOutput
from sglang.srt.layers.moe.token_dispatcher.base import BaseDispatcher
from sglang.srt.layers.moe.token_dispatcher.flashinfer import FlashinferDispatcher
from sglang.srt.layers.moe.token_dispatcher.standard import (
StandardDispatcher,
StandardDispatchOutput,
)
from sglang.srt.layers.moe.topk import (
BypassedTopKOutput,
StandardTopKOutput,
TopKConfig,
TopKOutput,
TopKOutputChecker,
)
from sglang.srt.layers.moe.utils import RoutingMethodType
from sglang.srt.layers.quantization.base_config import (
FusedMoEMethodBase,
QuantizationConfig,
)
from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors_moe import (
CompressedTensorsMxInt4MoEMethod,
)
from sglang.srt.layers.quantization.fp8 import Fp8MoEMethod
from sglang.srt.layers.quantization.modelopt_quant import ModelOptNvFp4FusedMoEMethod
from sglang.srt.layers.quantization.unquant import UnquantizedFusedMoEMethod
from sglang.srt.model_loader.weight_utils import narrow_padded_param_and_loaded_weight
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import (
cpu_has_amx_support,
get_bool_env_var,
is_cpu,
is_flashinfer_available,
is_hip,
next_power_of_2,
round_up,
)
from sglang.srt.utils.custom_op import register_custom_op
# Optional dependency: fp4_quantize is imported only when flashinfer is
# reported as available.
if is_flashinfer_available():
from flashinfer import fp4_quantize
# The FP4 TRT-LLM fused-MoE kernel is imported only when the selected MoE
# runner backend is flashinfer_trtllm. The name is None when that backend is
# not selected or when the import fails.
trtllm_fp4_block_scale_moe = None
if get_moe_runner_backend().is_flashinfer_trtllm():
try:
from flashinfer.fused_moe import trtllm_fp4_block_scale_moe
except ImportError:
trtllm_fp4_block_scale_moe = None
# Platform / feature flags, evaluated once at import time.
_is_hip = is_hip()
_is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu()
# AITER is enabled only when SGLANG_USE_AITER is set AND the platform is HIP.
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
def create_moe_dispatcher(moe_runner_config: MoeRunnerConfig) -> BaseDispatcher:
    """Build the token dispatcher for the configured MoE all-to-all backend.

    Args:
        moe_runner_config: Runner configuration. Its ``top_k``,
            ``num_experts``, ``num_local_experts``, ``hidden_size`` and
            ``params_dtype`` fields are forwarded to the dispatcher
            constructors (the standard dispatcher receives the whole config).

    Returns:
        ``StandardDispatcher`` when no a2a backend is set; otherwise the
        dispatcher for the DeepEP / Mooncake / Mori, Ascend FuseEP or
        FlashInfer backend.

    Raises:
        NotImplementedError: If the backend matches none of the above.
    """
    a2a_backend = get_moe_a2a_backend()

    # No all-to-all backend: plain dispatcher driven by the full config.
    if a2a_backend.is_none():
        return StandardDispatcher(moe_runner_config)

    if a2a_backend.is_deepep() or a2a_backend.is_mooncake() or a2a_backend.is_mori():
        # Mori is handed the TP group object itself; the other backends get
        # that group's ``device_group``.
        if a2a_backend.is_mori():
            comm_group = get_tp_group()
        else:
            comm_group = get_tp_group().device_group
        return MaybeTboDeepEPDispatcher(
            group=comm_group,
            router_topk=moe_runner_config.top_k,
            permute_fusion=True,
            num_experts=moe_runner_config.num_experts,
            num_local_experts=moe_runner_config.num_local_experts,
            hidden_size=moe_runner_config.hidden_size,
            params_dtype=moe_runner_config.params_dtype,
            deepep_mode=get_deepep_mode(),
            async_finish=True,
            return_recv_hook=True,
        )

    if a2a_backend.is_ascend_fuseep():
        # Imported lazily, only when this backend is actually selected.
        from sglang.srt.layers.moe.token_dispatcher import NpuFuseEPDispatcher

        return NpuFuseEPDispatcher(
            group=get_tp_group().device_group,
            router_topk=moe_runner_config.top_k,
            permute_fusion=True,
            num_experts=moe_runner_config.num_experts,
            num_local_experts=moe_runner_config.num_local_experts,
            hidden_size=moe_runner_config.hidden_size,
            params_dtype=moe_runner_config.params_dtype,
        )

    if a2a_backend.is_flashinfer():
        return FlashinferDispatcher(
            group=get_tp_group().device_group,
            router_topk=moe_runner_config.top_k,
            num_experts=moe_runner_config.num_experts,
            num_local_experts=moe_runner_config.num_local_experts,
            hidden_size=moe_runner_config.hidden_size,
        )

    raise NotImplementedError(f"Unsupported a2a backend: {a2a_backend}")
class FusedMoeWeightScaleSupported(Enum):
    """Weight-scale kinds the FusedMoE weight loader knows how to load.

    The string values are compared against a parameter's ``quant_method``
    attribute when a scale / zero-point tensor is loaded.
    """

    # Loaded through the per-tensor scale path.
    TENSOR = "tensor"

    # Loaded through the per-channel scale path.
    CHANNEL = "channel"

    # GROUP and BLOCK share the model-weight / group-scale loading path.
    GROUP = "group"
    BLOCK = "block"
class FusedMoE(torch.nn.Module):
"""FusedMoE layer for MoE models.
This layer contains both MergedColumnParallel weights (gate_up_proj /
w13) and RowParallelLinear weights (down_proj/ w2).
Note: Mixtral names the gate, up, and down projections w1, w3, and w2
respectively. We copy that naming convention here and handle any
remapping in the load_weights function in each model implementation.
Args:
num_experts: Number of experts in the model
top_k: Number of experts selected for each token
hidden_size: Input hidden state size of the transformer
intermediate_size: Intermediate size of the experts
params_dtype: Data type for the parameters.
reduce_results: Whether to apply all_reduce on the output of the layer
quant_config: Quantization configuration.
inplace: suggestion to compute inplace (modify input activation).
"""
# Construct the layer: record the expert/tensor parallel layout, build the
# MoeRunnerConfig, choose the quantization method, create the weights, the MoE
# runner and the token dispatcher. Statement order matters: the quant method
# receives `self` and may read attributes assigned earlier in this function.
def __init__(
self,
num_experts: int,
hidden_size: int,
intermediate_size: int,
layer_id: int,
top_k: Optional[int] = None,
num_fused_shared_experts: int = 0,
params_dtype: Optional[torch.dtype] = None,
reduce_results: bool = False,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_presharded_weights: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
gemm1_alpha: Optional[float] = None,
gemm1_clamp_limit: Optional[float] = None,
use_weight_loader_fused: bool = False,
with_bias=False,
routing_method_type: Optional[RoutingMethodType] = None,
is_gated: bool = True,
):
super().__init__()
# Fall back to torch's current default dtype when none is given.
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.layer_id = layer_id
self.top_k = top_k
self.hidden_size = hidden_size
self.num_experts = num_experts
self.num_fused_shared_experts = num_fused_shared_experts
self.enable_flashinfer_cutlass_moe = (
get_moe_runner_backend().is_flashinfer_cutlass()
)
# Expert-parallel (EP) and tensor-parallel (TP) layout of the MoE layer.
self.moe_ep_size = get_moe_expert_parallel_world_size()
self.moe_ep_rank = get_moe_expert_parallel_rank()
self.moe_tp_size = get_moe_tensor_parallel_world_size()
self.moe_tp_rank = get_moe_tensor_parallel_rank()
# The routed (non-shared) experts must divide evenly across EP ranks.
assert (num_experts - num_fused_shared_experts) % self.moe_ep_size == 0
# Each rank holds its share of the routed experts plus every fused shared
# expert.
self.num_local_experts = (
num_experts - num_fused_shared_experts
) // self.moe_ep_size + num_fused_shared_experts
self.expert_mask_gpu = None
# The intermediate dimension is split evenly across TP ranks.
assert intermediate_size % self.moe_tp_size == 0
self.intermediate_size_per_partition = intermediate_size // self.moe_tp_size
self.reduce_results = reduce_results
self.use_presharded_weights = use_presharded_weights
self.use_triton_kernels = get_moe_runner_backend().is_triton_kernels()
self.use_flashinfer_trtllm_moe = get_moe_runner_backend().is_flashinfer_trtllm()
# flashinfer_trtllm kernel requires intermediate_size to be a multiple of 128
# Pad the intermediate_size_per_partition if necessary
if (
self.use_flashinfer_trtllm_moe
and self.intermediate_size_per_partition % 128 != 0
):
self.intermediate_size_per_partition = round_up(
self.intermediate_size_per_partition, 128
)
self.quant_config = quant_config
self.use_flashinfer_mxfp4_moe = get_moe_runner_backend().is_flashinfer_mxfp4()
# TODO maybe we should remove this `if`, since `Mxfp4MoEMethod` does another round-up logic
# For mxfp4 with the flashinfer mxfp4 backend the hidden size is rounded up to
# a multiple of 256; self.hidden_size is overwritten with the padded value.
if (
self.quant_config is not None
and self.quant_config.get_name() == "mxfp4"
and self.use_flashinfer_mxfp4_moe
):
hidden_size = round_up(hidden_size, 256)
self.hidden_size = hidden_size
# Snapshot of the (possibly padded) sizes and options handed to the runner
# and to the dispatcher.
self.moe_runner_config = MoeRunnerConfig(
num_experts=num_experts,
num_local_experts=self.num_local_experts,
hidden_size=hidden_size,
intermediate_size_per_partition=self.intermediate_size_per_partition,
layer_id=layer_id,
top_k=top_k,
num_fused_shared_experts=num_fused_shared_experts,
params_dtype=params_dtype,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
inplace=inplace,
no_combine=no_combine,
routed_scaling_factor=routed_scaling_factor,
gemm1_alpha=gemm1_alpha,
gemm1_clamp_limit=gemm1_clamp_limit,
is_gated=is_gated,
routing_method_type=routing_method_type,
)
self.quant_method: Optional[FusedMoEMethodBase] = None
server_args = get_global_server_args()
# When a KT config exists for this layer, the GPU quant method is wrapped in
# KTEPWrapperMethod; otherwise the quant method is used directly.
kt_config = create_kt_config_from_server_args(server_args, layer_id)
if kt_config is not None:
if quant_config is not None:
gpu_method = quant_config.get_quant_method(self, prefix)
else:
# NOTE(review): this unquantized fallback passes only use_triton_kernels,
# while the non-KT fallback below also passes use_flashinfer_trtllm_moe.
# Confirm the difference is intentional.
gpu_method = UnquantizedFusedMoEMethod(self.use_triton_kernels)
self.quant_method = KTEPWrapperMethod(gpu_method, kt_config)
else:
if quant_config is not None:
self.quant_method = quant_config.get_quant_method(self, prefix)
# No quant config, or the config returned no method: use the unquantized one.
if self.quant_method is None:
self.quant_method = UnquantizedFusedMoEMethod(
self.use_triton_kernels, self.use_flashinfer_trtllm_moe
)
# Weights are created for the local experts only; the weight loader attached
# to them is the fused or the per-expert one depending on
# use_weight_loader_fused.
self.quant_method.create_weights(
layer=self,
num_experts=self.num_local_experts,
hidden_size=hidden_size,
intermediate_size_per_partition=self.intermediate_size_per_partition,
params_dtype=params_dtype,
weight_loader=(
self.weight_loader
if not use_weight_loader_fused
else self.weight_loader_fused
),
with_bias=with_bias,
)
self.quant_method.create_moe_runner(self, self.moe_runner_config)
self.dispatcher = create_moe_dispatcher(self.moe_runner_config)
# True for the ModelOpt NVFP4 method, or for the FP8 method when the runner
# backend is cutlass.
self.should_fuse_routed_scaling_factor_in_topk = isinstance(
self.quant_method, ModelOptNvFp4FusedMoEMethod
) or (
isinstance(self.quant_method, Fp8MoEMethod)
and get_moe_runner_backend().is_cutlass()
)
self.routing_method_type = routing_method_type
# overlap args
self.down_gemm_overlap_args: Optional[DownGemmOverlapArgs] = None
self.meta_overlap_args: Optional[dict] = None
# Expose the quant method's runner on the layer when it has one.
if self.quant_method is not None and hasattr(self.quant_method, "runner"):
self.runner = self.quant_method.runner
def _load_per_tensor_weight_scale(
self,
shard_id: str,
param: torch.nn.Parameter,
loaded_weight: torch.Tensor,
expert_id: int,
):
param_data = param.data
# for per tensor weight quantization
if shard_id in ("w1", "w3"):
# We have to keep the weight scales of w1 and w3 because
# we need to re-quantize w1/w3 weights after weight loading.
idx = 0 if shard_id == "w1" else 1
if self.moe_runner_config.is_gated:
param_data[expert_id][idx] = loaded_weight
else:
param_data[expert_id] = loaded_weight
# If we are in the row parallel case (down_proj)
elif shard_id == "w2":
param_data[expert_id] = loaded_weight
def _load_model_weight_or_group_weight_scale(
self,
shard_dim: int,
expert_data: torch.Tensor,
shard_id: str,
loaded_weight: torch.Tensor,
tp_rank: int,
is_bias: bool = False,
):
# Load grouped weight scales for group quantization
# or model weights
if shard_id == "w2":
self._load_w2(
shard_id=shard_id,
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank,
is_bias=is_bias,
)
elif shard_id in ("w1", "w3", "w13"):
self._load_w13(
shard_id=shard_id,
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank,
is_bias=is_bias,
)
def _load_per_channel_weight_scale(
self,
expert_data: torch.Tensor,
shard_dim: int,
shard_id: str,
loaded_weight: torch.Tensor,
tp_rank: int,
):
# for per channel weight quantization
if shard_id == "w2":
expert_data.copy_(loaded_weight)
elif shard_id in ("w1", "w3"):
self._load_w13(
shard_id=shard_id,
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank,
)
def _load_w13(
    self,
    expert_data: torch.Tensor,
    shard_dim: int,
    shard_id: str,
    loaded_weight: torch.Tensor,
    tp_rank: int,
    is_bias: bool = False,
):
    """Copy a gate/up projection shard into its slot of the w13 tensor.

    Args:
        expert_data: Destination tensor being filled.
        shard_dim: Dimension along which the weight is sharded; replaced by
            ``-1`` when ``is_bias`` is true.
        shard_id: ``"w1"``, ``"w3"`` or the fused ``"w13"``.
        loaded_weight: Tensor read from the checkpoint.
        tp_rank: Tensor-parallel rank, selects the slice of ``loaded_weight``.
        is_bias: Whether the tensor is a bias.

    Raises:
        NotImplementedError: For an unknown ``shard_id`` (only reachable when
            assertions are disabled).
    """
    # gate_up_proj is "MergedColumnParallel", so TP sharding is on output_dim.
    assert shard_id in {"w1", "w3", "w13"}

    if is_bias:
        # For a bias, the last dimension must be the sharded dimension.
        shard_dim = -1

    if shard_id in {"w1", "w3"}:
        if self.moe_runner_config.is_gated:
            # Non-fused, gated: w1 and w3 each own half of the merged dim.
            shard_size = expert_data.shape[shard_dim] // 2
        else:
            # Non-gated: the single projection owns the whole dim.
            shard_size = expert_data.shape[shard_dim]
    elif shard_id in {"w13"}:
        # Fused version: the tensor covers the whole dim.
        shard_size = expert_data.shape[shard_dim]
    else:
        raise NotImplementedError

    # By default w1 (gate_proj) goes to the first logical weight of w13 and
    # w3 (up_proj) to the second. A quant method may set
    # ``load_up_proj_weight_first`` to swap them (trtllm cutlass kernel
    # assumes differently).
    up_proj_first = getattr(self.quant_method, "load_up_proj_weight_first", False)
    second_slot_id = "w1" if up_proj_first else "w3"
    start = shard_size if shard_id == second_slot_id else 0

    # The padded-aware helper is used on CPU (always) and on GPU with the
    # flashinfer_trtllm backend, where the loaded weights may be smaller than
    # the padded expert_data.
    if _is_cpu or self.use_flashinfer_trtllm_moe:
        expert_data, loaded_weight = narrow_padded_param_and_loaded_weight(
            expert_data,
            loaded_weight,
            start,
            shard_size * tp_rank,
            shard_dim,
            shard_size,
            not self.use_presharded_weights,
        )
    else:
        if not self.use_presharded_weights:
            if not is_bias and self.use_triton_kernels:
                # Biases are never transposed.
                loaded_weight = loaded_weight.transpose(-2, -1)
            tp_offset = shard_size * tp_rank
            loaded_weight = loaded_weight.narrow(shard_dim, tp_offset, shard_size)
        expert_data = expert_data.narrow(shard_dim, start, shard_size)

    expert_data.copy_(loaded_weight)
def _load_w2(
self,
expert_data: torch.Tensor,
shard_dim: int,
shard_id: str,
loaded_weight: torch.Tensor,
tp_rank: int,
is_bias: bool = False,
):
"""Load w2 weights for down projection.
Args:
expert_data: The expert data tensor to load into
shard_dim: The dimension to shard along
shard_id: The shard ID (must be "w2")
loaded_weight: The weight tensor to load from
tp_rank: The tensor parallel rank
"""
if not isinstance(expert_data, torch.Tensor) or not isinstance(
loaded_weight, torch.Tensor
):
raise ValueError("expert_data and loaded_weight must be torch.Tensor")
if (
self.quant_config is not None
and "modelopt" in self.quant_config.get_name()
and (expert_data.dim() != 2 or loaded_weight.dim() != 2)
):
raise ValueError(
f"Expected 2D tensors, got expert_data shape {expert_data.shape} and loaded_weight shape {loaded_weight.shape}"
)
if shard_id != "w2":
raise ValueError(f"shard_id must be 'w2', got {shard_id}")
# Index the loaded weight for tp sharding.
# down_proj: "RowParallel" so tp sharding on input_dim
# Narrow parameter and load.
if is_bias:
# this expert_data is a bias, not weight,
# for w2_weight_bias in TP, it does not need to be sharded
shard_size = expert_data.shape[-1]
else:
# this parameter is a weight matrix
# for w2 in TP, it shards the input_features, i.e., shard_dim=2
shard_size = expert_data.shape[shard_dim]
# Use narrow_padded_param_and_loaded_weight for:
# 1. CPU (always)
# 2. GPU with flashinfer_trtllm padding (when intermediate_size is padded to 128)
# This handles the case where the loaded weights are smaller than the padded expert_data
use_padded_loading = _is_cpu or self.use_flashinfer_trtllm_moe
if use_padded_loading:
expert_data, loaded_weight = narrow_padded_param_and_loaded_weight(
expert_data,
loaded_weight,
0, # param_data_start
shard_size * tp_rank,
shard_dim,
shard_size,
not self.use_presharded_weights,
)
else:
if not is_bias and not self.use_presharded_weights:
if self.use_triton_kernels:
loaded_weight = loaded_weight.transpose(-2, -1)
loaded_weight = loaded_weight.narrow(
shard_dim, shard_size * tp_rank, shard_size
)
# w2, down_proj: Load into only logical weight of w2.
expert_data.copy_(loaded_weight)
def _load_single_value(
self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, expert_id: int
):
param_data = param.data
# Input scales can be loaded directly and should be equal.
param_data[expert_id] = loaded_weight
def _load_g_idx(
self,
shard_id: str,
expert_data: torch.Tensor,
shard_dim: int,
loaded_weight: torch.Tensor,
tp_rank: int,
):
if shard_id == "w2":
self._load_w2(
shard_id=shard_id,
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank,
)
else:
assert shard_id in ("w1", "w3")
expert_data.copy_(loaded_weight)
def _map_global_expert_id_to_local_expert_id(self, expert_id: int) -> int:
num_global_routed_experts = self.num_experts - self.num_fused_shared_experts
num_local_routed_experts = (
self.num_local_experts - self.num_fused_shared_experts
)
start_idx = self.moe_ep_rank * num_local_routed_experts
end_idx = (self.moe_ep_rank + 1) * num_local_routed_experts
if start_idx <= expert_id < end_idx:
return expert_id - start_idx
elif (
self.num_fused_shared_experts > 0 and expert_id >= num_global_routed_experts
):
return expert_id - num_global_routed_experts + num_local_routed_experts
else:
return -1
def weight_loader(
    self,
    param: torch.nn.Parameter,
    loaded_weight: torch.Tensor,
    weight_name: str,
    shard_id: str,
    expert_id: Optional[int],
) -> None:
    """Load one checkpoint tensor into ``param`` (per-expert entry point).

    Args:
        param: Destination parameter.
        loaded_weight: Tensor read from the checkpoint.
        weight_name: Checkpoint name of the tensor; substrings such as
            ``"bias"`` select the loading path.
        shard_id: Which projection the tensor belongs to.
        expert_id: Global (logical) expert id, or ``None`` when the tensor
            holds all experts at once.
    """
    # ``expert_id is None`` means all experts are loaded at the same time.
    # This must be an identity test: expert 0 is a valid per-expert id and
    # must take the regular per-expert path below, not this one.
    if (
        expert_id is None
        and self.quant_config is not None
        and self.quant_config.get_name() == "mxfp4"
        and self.quant_config.is_static_cfg()
    ):
        # The destination may be larger than the checkpoint tensor, so only
        # the leading region matching the loaded shape is overwritten.
        if "bias" in weight_name:
            dim1 = loaded_weight.shape[1]
            param.data[:, :dim1].copy_(loaded_weight)
        else:
            dim1 = loaded_weight.shape[1]
            dim2 = loaded_weight.shape[2]
            param.data[:, :dim1, :dim2].copy_(loaded_weight)
        return

    global_expert_location_metadata = get_global_expert_location_metadata()
    if global_expert_location_metadata is None:
        # No expert-location metadata: map the global id straight to this
        # rank's local index, unless the parameter is flagged as holding
        # global experts.
        if not getattr(param, "_sglang_require_global_experts", False):
            expert_id = self._map_global_expert_id_to_local_expert_id(expert_id)
            if expert_id == -1:
                # Expert is not held by this rank.
                return

        self._weight_loader_impl(
            param=param,
            loaded_weight=loaded_weight,
            weight_name=weight_name,
            shard_id=shard_id,
            expert_id=expert_id,
        )
        return

    if expert_id >= self.num_experts - self.num_fused_shared_experts:
        # This is a shared expert.
        physical_expert_ids = [expert_id]
    else:
        # A routed (logical) expert may map to several physical experts.
        require_global_experts = getattr(
            param, "_sglang_require_global_experts", False
        )
        physical_expert_ids = (
            global_expert_location_metadata.logical_to_all_physical(
                self.layer_id, expert_id, require_global_experts
            )
        )

    for physical_expert_id in physical_expert_ids:
        self._weight_loader_physical(
            param=param,
            loaded_weight=loaded_weight,
            weight_name=weight_name,
            shard_id=shard_id,
            expert_id=physical_expert_id,
        )
def _weight_loader_physical(
    self,
    param: torch.nn.Parameter,
    loaded_weight: torch.Tensor,
    weight_name: str,
    shard_id: str,
    expert_id: int,
) -> None:
    """Load a tensor for one physical expert, skipping experts not held here.

    Args:
        param: Destination parameter.
        loaded_weight: Tensor read from the checkpoint.
        weight_name: Checkpoint name of the tensor.
        shard_id: Which projection the tensor belongs to.
        expert_id: Physical expert id; converted to a local index unless the
            parameter is flagged ``_sglang_require_global_experts``.
    """
    # WARN: This makes the `expert_id` mean "local" and "global" in different cases
    if not getattr(param, "_sglang_require_global_experts", False):
        expert_id = self._map_global_expert_id_to_local_expert_id(expert_id)
        if not (0 <= expert_id < self.num_local_experts):
            return

    # With the KT wrapper, experts at or beyond ``num_gpu_experts`` are not
    # loaded here; ``-1`` disables that limit.
    if (
        isinstance(self.quant_method, KTEPWrapperMethod)
        and self.quant_method.num_gpu_experts != -1
        and expert_id >= self.quant_method.num_gpu_experts
    ):
        return

    self._weight_loader_impl(
        param=param,
        loaded_weight=loaded_weight,
        weight_name=weight_name,
        shard_id=shard_id,
        expert_id=expert_id,
    )
# Load one tensor for one expert. The kind of tensor is inferred from
# substrings of `weight_name`, checked in a fixed order (input_scale, g_idx,
# ModelOpt handling, scale/zero/offset, weight_shape, weight, bias); the first
# matching case loads the tensor and returns. Raises ValueError for an unknown
# shard_id, for mismatching w1/w3 input scales, or for an unsupported
# `quant_method` attribute on a scale parameter.
def _weight_loader_impl(
self,
param: torch.nn.Parameter,
loaded_weight: torch.Tensor,
weight_name: str,
shard_id: str,
expert_id: int,
) -> None:
tp_rank = self.moe_tp_rank
# compressed-tensors checkpoints with packed weights are stored flipped
# TODO (mgoin): check self.quant_method.quant_config.quant_format
# against known CompressionFormat enum values that have this quality
# Look through the KT wrapper so the class-name check below sees the wrapped
# GPU method.
# NOTE(review): the isinstance checks further down use self.quant_method (the
# wrapper), not `method` — confirm that is intended.
method = self.quant_method
if method.__class__.__name__ == "KTEPWrapperMethod":
method = method.gpu_method
loaded_weight = (
loaded_weight.t().contiguous()
if (
method.__class__.__name__
in [
"CompressedTensorsWNA16MarlinMoEMethod",
"CompressedTensorsWNA16MoEMethod",
"CompressedTensorsWNA16TritonMoEMethod",
]
)
else loaded_weight
)
if shard_id not in ("w1", "w2", "w3"):
raise ValueError(f"shard_id must be ['w1','w2','w3'] but got {shard_id}.")
# Flashinfer assumes w31 format for w13_weight. Same for the scales.
# Implemented by swapping the w1 and w3 shard ids before loading.
if self.use_flashinfer_trtllm_moe and (
isinstance(self.quant_method, ModelOptNvFp4FusedMoEMethod)
or isinstance(self.quant_method, Fp8MoEMethod)
or isinstance(self.quant_method, UnquantizedFusedMoEMethod)
or isinstance(self.quant_method, CompressedTensorsMxInt4MoEMethod)
):
shard_id = {"w1": "w3", "w3": "w1", "w2": "w2"}[shard_id]
# Only used in the error message of the scale case below.
WEIGHT_SCALE_SUPPORTED = [e.value for e in FusedMoeWeightScaleSupported]
# Fetch the dim to shard the parameter/loaded weight
# based on the shard id. This will be whatever
# dimension intermediate_size is used.
SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0}
expert_data = param.data[expert_id]
# is_transposed: if the dim to shard the weight
# should be flipped. Required by GPTQ, compressed-tensors
# should be whatever dimension intermediate_size is
is_transposed = getattr(param, "is_transposed", False)
shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id]
# The triton-kernels backend always uses the transposed layout.
if self.use_triton_kernels:
is_transposed = True
# Flip between dim 0 and dim 1.
if is_transposed:
shard_dim = int(not shard_dim)
# Case input scale: input_scale loading is only supported for fp8
if "input_scale" in weight_name:
# INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust input_scale for e4m3fnuz (AMD)
if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
loaded_weight = loaded_weight * 2.0
# this is needed for compressed-tensors only
loaded_weight = loaded_weight.to(param.data.device)
# Reject a second input scale that differs (by more than 1e-5) from the one
# already stored, unless the stored entry is still all ones.
# NOTE(review): self.quant_config is dereferenced without a None check when
# the class name does not contain "compressed" — confirm input_scale tensors
# only occur with a quant_config set.
if (
(
"compressed" in self.quant_method.__class__.__name__.lower()
or "w4afp8" in self.quant_config.get_name()
)
and (param.data[expert_id] != 1).any()
and ((param.data[expert_id] - loaded_weight).abs() > 1e-5).any()
):
raise ValueError(
"input_scales of w1 and w3 of a layer "
f"must be equal. But got {param.data[expert_id]} "
f"vs. {loaded_weight}"
)
self._load_single_value(
param=param, loaded_weight=loaded_weight, expert_id=expert_id
)
return
# Case g_idx
if "g_idx" in weight_name:
self._load_g_idx(
shard_dim=0,
shard_id=shard_id,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank,
)
return
# Case ModelOpt quant methods (matched by class name)
if "ModelOpt" in self.quant_method.__class__.__name__:
# Determine per-tensor weight scale patterns based on variant
is_fp4_variant = isinstance(self.quant_method, ModelOptNvFp4FusedMoEMethod)
# FP4 uses "weight_scale_2" for per-tensor, FP8 uses "weight_scale" for per-tensor
# Names containing "input_scale" already returned above, so only the
# weight-scale test can be true here.
per_tensor_conditions = (
"weight_scale_2" in weight_name
if is_fp4_variant
else "weight_scale" in weight_name
) or "input_scale" in weight_name
if per_tensor_conditions:
self._load_per_tensor_weight_scale(
shard_id=shard_id,
param=param,
loaded_weight=loaded_weight,
expert_id=expert_id,
)
elif "weight" in weight_name:
self._load_model_weight_or_group_weight_scale(
shard_id=shard_id,
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank,
)
return
# Case weight scales and zero_points
if "scale" in weight_name or "zero" in weight_name or "offset" in weight_name:
# load the weight scales and zp based on the quantization scheme
# supported weight scales/zp can be found in
# FusedMoeWeightScaleSupported
# TODO @dsikka: once hardened, refactor to use vLLM Parameters
# specific to each case
# The scheme is read from the parameter's own `quant_method` attribute (a
# string), not from self.quant_method.
quant_method = getattr(param, "quant_method", None)
if quant_method == FusedMoeWeightScaleSupported.CHANNEL.value:
# INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust INT4 column-wise scaling number to e4m3fnuz (AMD)
if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
loaded_weight = loaded_weight * 0.5
self._load_per_channel_weight_scale(
shard_id=shard_id,
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank,
)
elif quant_method in [
FusedMoeWeightScaleSupported.GROUP.value,
FusedMoeWeightScaleSupported.BLOCK.value,
]:
self._load_model_weight_or_group_weight_scale(
shard_id=shard_id,
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank,
)
elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value:
# INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust FP8 per-tensor scaling number for e4m3fnuz (AMD)
if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
loaded_weight = loaded_weight * 2.0
self._load_per_tensor_weight_scale(
shard_id=shard_id,
param=param,
loaded_weight=loaded_weight,
expert_id=expert_id,
)
else:
raise ValueError(
f"quant method must be one of {WEIGHT_SCALE_SUPPORTED}"
)
return
# Case weight_shape
if "weight_shape" in weight_name:
# only required by compressed-tensors
self._load_single_value(
param=param, loaded_weight=loaded_weight, expert_id=expert_id
)
return
# Case model weights
if "weight" in weight_name:
self._load_model_weight_or_group_weight_scale(
shard_id=shard_id,
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank,
)
return
# Case bias: only loaded when the quant description says "modelslim"; it goes
# through the per-channel scale path.
# NOTE(review): self.quant_config is dereferenced without a None check and
# `quant_description` is assumed to exist on it — a "bias" tensor reaching
# this point with another (or no) quant config would raise; confirm.
if (
"bias" in weight_name
and self.quant_config.quant_description["quant_method"] == "modelslim"
):
self._load_per_channel_weight_scale(
shard_id=shard_id,
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank,
)
def weight_loader_fused(
    self,
    param: torch.nn.Parameter,
    loaded_weight: torch.Tensor,
    weight_name: str,
    shard_id: str,
) -> None:
    """Load a checkpoint tensor that holds all experts at once (fused layout).

    Args:
        param: Destination parameter.
        loaded_weight: Tensor read from the checkpoint.
        weight_name: Checkpoint name of the tensor; substrings such as
            ``"bias"``, ``"scale"`` and ``"weight"`` select the loading path.
        shard_id: ``"w13"`` or ``"w2"``.

    Raises:
        ValueError: If ``shard_id`` is neither ``"w13"`` nor ``"w2"`` (not
            checked on the static-mxfp4 path).
    """
    tp_rank = self.moe_tp_rank

    if (
        self.quant_config is not None
        and self.quant_config.get_name() == "mxfp4"
        and self.quant_config.is_static_cfg()
    ):
        # The destination may be larger than the checkpoint tensor, so only
        # the leading region matching the loaded shape is overwritten. Scales
        # are copied whole.
        if "bias" in weight_name:
            dim1 = loaded_weight.shape[1]
            param.data[:, :dim1].copy_(loaded_weight)
        elif "scale" in weight_name:
            param.data.copy_(loaded_weight)
        else:
            dim1 = loaded_weight.shape[1]
            dim2 = loaded_weight.shape[2]
            param.data[:, :dim1, :dim2].copy_(loaded_weight)
        return

    # compressed-tensors checkpoints with packed weights are stored flipped
    # TODO: check self.quant_method.quant_config.quant_format
    # against known CompressionFormat enum values that have this quality
    if self.quant_method.__class__.__name__ in [
        "CompressedTensorsWNA16MoEMethod",
        "CompressedTensorsWNA16TritonMoEMethod",
    ]:
        loaded_weight = loaded_weight.t().contiguous()

    if shard_id not in ("w13", "w2"):
        raise ValueError(f"shard_id must be ['w13','w2'] but got {shard_id}.")

    # Fetch the dim to shard the parameter/loaded weight based on the shard
    # id. This will be whatever dimension intermediate_size is used.
    SHARD_ID_TO_SHARDED_DIM = {"w13": 1, "w2": 2}
    SHARD_ID_TO_SHARDED_DIM_TRANSPOSE = {"w13": 2, "w2": 1}

    expert_data = param.data
    # A 2D parameter is treated as a bias.
    is_bias = expert_data.dim() == 2

    # is_transposed: if the dim to shard the weight should be flipped.
    # Required by GPTQ, compressed-tensors. The triton-kernels backend always
    # uses the transposed layout.
    is_transposed = getattr(param, "is_transposed", False)
    if self.use_triton_kernels:
        is_transposed = True
    shard_dim = (
        SHARD_ID_TO_SHARDED_DIM[shard_id]
        if not is_transposed
        else SHARD_ID_TO_SHARDED_DIM_TRANSPOSE[shard_id]
    )

    # Case model weights
    if "weight" in weight_name:
        self._load_model_weight_or_group_weight_scale(
            shard_id=shard_id,
            shard_dim=shard_dim,
            loaded_weight=loaded_weight,
            expert_data=expert_data,
            tp_rank=tp_rank,
            is_bias=is_bias,
        )
        return

    # Report through this module's logger (not the root logger) so the record
    # carries the module name; %-style arguments defer formatting.
    logger.warning(
        "Unsupported weight_name %s for FusedMoE weight_loader_fused. Nothing is loaded.",
        weight_name,
    )
def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
    """Run the MoE layer, going through the piecewise-CUDA-graph op when possible.

    The graph entry point is used only while inside a piecewise CUDA graph
    and only when ``topk_output`` is in the standard format; every other case
    calls ``forward_impl`` directly.
    """
    use_graph_entry = is_in_piecewise_cuda_graph() and (
        TopKOutputChecker.format_is_standard(topk_output)
    )
    if not use_graph_entry:
        # Outside a piecewise CUDA graph, or with a non-standard top-k output
        # (the original note for that case: make sure there is torch lib op
        # registration for the whole moe layer).
        return self.forward_impl(hidden_states, topk_output)

    # The graph entry point receives the unpacked top-k fields and the layer
    # id instead of the layer object.
    return moe_forward_piecewise_cuda_graph_impl(
        hidden_states,
        topk_output.topk_weights,
        topk_output.topk_ids,
        topk_output.router_logits,
        self.layer_id,
    )
# Eager forward pass: dispatch tokens, run the expert computation, combine the
# results, crop the output back to the input hidden size and optionally
# all-reduce it. Returns the final hidden states tensor.
def forward_impl(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
# Remember the caller's hidden size so the output can be cropped back to it
# below.
origin_hidden_states_dim = hidden_states.shape[-1]
assert self.quant_method is not None
dispatch_output = self.dispatcher.dispatch(
hidden_states=hidden_states, topk_output=topk_output
)
# With AITER and a dispatcher-provided local expert mapping, build an int32
# mask that is 1 where the mapping value lies in [0, num_local_experts).
# NOTE(review): the target device is hard-coded to "cuda" rather than taken
# from hidden_states — confirm.
if _use_aiter and self.dispatcher.local_expert_mapping is not None:
self.expert_mask_gpu = (
(
(self.dispatcher.local_expert_mapping >= 0)
& (self.dispatcher.local_expert_mapping < self.num_local_experts)
)
.to(torch.int32)
.to(device="cuda")
)
combine_input = self.run_moe_core(
dispatch_output=dispatch_output,
)
# The symmetric-memory context is disabled unless is_allocation_symmetric()
# returns true.
with use_symmetric_memory(
get_tp_group(), disabled=not is_allocation_symmetric()
):
final_hidden_states = self.dispatcher.combine(combine_input=combine_input)
# TODO: should we add some conditions here?
# Keep only the first origin_hidden_states_dim entries of the last dimension
# and make the result contiguous.
final_hidden_states = final_hidden_states[
..., :origin_hidden_states_dim
].contiguous()
# All-reduce only when requested and when more than one TP or EP rank exists.
if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1):
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
return final_hidden_states