diff --git a/docker/Dockerfile b/docker/Dockerfile index 439ecddb34..bba404c965 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -1,14 +1,17 @@ -ARG CUDA_VERSION=12.8.0 +ARG CUDA_VERSION=13.0.0 FROM nvidia/cuda:${CUDA_VERSION}-cudnn-devel-ubuntu22.04 ARG PYTHON_VERSION=3.10 ARG MAMBA_VERSION=24.7.1-0 -ARG VLLM_VERSION=0.16.0 +ARG VLLM_VERSION=0.21.0 +ARG NIXL_REF=v1.1.0 ARG FLASH_MLA_REF=47c35a7 +ARG DEEPGEMM_REF=891d57b4db1071624b5c8fa0d1e51cb317fa709f ARG TARGETPLATFORM ARG ENABLE_DEEPEP=1 ARG ENABLE_NIXL=1 ARG ENABLE_CACHE=1 +ARG ENABLE_SM100=0 ENV PATH=/opt/conda/bin:$PATH \ CONDA_PREFIX=/opt/conda @@ -44,13 +47,18 @@ WORKDIR /root COPY ./requirements.txt /lightllm/requirements.txt RUN pip install -U pip -RUN pip install -r /lightllm/requirements.txt --no-cache-dir -RUN pip install --no-cache-dir vllm==${VLLM_VERSION} -RUN git clone https://github.com/deepseek-ai/FlashMLA.git /root/FlashMLA && \ +RUN pip install --no-cache-dir \ + --extra-index-url https://download.pytorch.org/whl/cu130 \ + vllm==${VLLM_VERSION} +RUN pip install -r /lightllm/requirements.txt --no-cache-dir \ + --extra-index-url https://download.pytorch.org/whl/cu130 +RUN export CPATH=/usr/local/cuda/targets/x86_64-linux/include/cccl:/usr/local/cuda/targets/x86_64-linux/include${CPATH:+:${CPATH}} && \ + git clone https://github.com/deepseek-ai/FlashMLA.git /root/FlashMLA && \ cd /root/FlashMLA && \ git checkout ${FLASH_MLA_REF} && \ git submodule update --init --recursive && \ - FLASH_MLA_DISABLE_SM100=1 pip install --no-cache-dir . + FLASH_MLA_DISABLE_SM100="$(if [ "${ENABLE_SM100}" = "1" ]; then echo 0; else echo 1; fi)" \ + pip install --no-cache-dir . RUN apt-get update && apt-get install -y libnuma-dev && rm -rf /var/lib/apt/lists/* @@ -78,27 +86,20 @@ RUN if [ "${ENABLE_NIXL}" = "1" ] || [ "${ENABLE_DEEPEP}" = "1" ]; then \ RUN if [ "${ENABLE_DEEPEP}" = "1" ]; then \ set -e; \ ln -sf /usr/lib/x86_64-linux-gnu/libmlx5.so.1 /usr/lib/x86_64-linux-gnu/libmlx5.so; \ - NVSHMEM_VERSION=3.3.9; \ - CUDA_ARCHS=90; \ - wget https://developer.download.nvidia.com/compute/redist/nvshmem/${NVSHMEM_VERSION}/source/nvshmem_src_cuda12-all-all-${NVSHMEM_VERSION}.tar.gz \ - && tar -xf nvshmem_src_cuda12-all-all-${NVSHMEM_VERSION}.tar.gz && mv nvshmem_src nvshmem \ - && cd nvshmem \ - && rm -f /root/nvshmem_src_cuda12-all-all-${NVSHMEM_VERSION}.tar.gz \ - && NVSHMEM_SHMEM_SUPPORT=0 \ - NVSHMEM_UCX_SUPPORT=0 \ - NVSHMEM_USE_NCCL=0 \ - NVSHMEM_MPI_SUPPORT=0 \ - NVSHMEM_IBGDA_SUPPORT=1 \ - NVSHMEM_PMIX_SUPPORT=0 \ - NVSHMEM_TIMEOUT_DEVICE_POLLING=0 \ - NVSHMEM_USE_GDRCOPY=1 \ - cmake -S . -B build/ -DCMAKE_INSTALL_PREFIX=/root/nvshmem/install -DCMAKE_CUDA_ARCHITECTURES=${CUDA_ARCHS} \ - && cmake --build build --target install -j64; \ - DEEPEP_COMMIT=b6ce310bb0b75079682d09bc2ebc063a074fbd58; \ - cd /root && git clone https://github.com/deepseek-ai/DeepEP.git && cd DeepEP && git checkout ${DEEPEP_COMMIT} && cd ..; \ - cd /root/DeepEP && NVSHMEM_DIR=/root/nvshmem/install python setup.py install; \ + python -m pip install --upgrade --no-deps \ + "nvidia-nccl-cu13==2.30.4" \ + "nvidia-nvshmem-cu13==3.6.5"; \ + cd /root && git clone https://github.com/deepseek-ai/DeepEP.git && cd DeepEP && git checkout b306af06afd412c88e51e71802951606e40b7358; \ + ln -sf /opt/conda/lib/python${PYTHON_VERSION}/site-packages/nvidia/nvshmem/lib/libnvshmem_host.so.3 /opt/conda/lib/python${PYTHON_VERSION}/site-packages/nvidia/nvshmem/lib/libnvshmem_host.so; \ + ln -sf /opt/conda/lib/python${PYTHON_VERSION}/site-packages/nvidia/nccl/lib/libnccl.so.2 /opt/conda/lib/python${PYTHON_VERSION}/site-packages/nvidia/nccl/lib/libnccl.so; \ + pip install --no-build-isolation .; \ fi +RUN cd /root && git clone https://github.com/deepseek-ai/DeepGEMM.git && \ + cd DeepGEMM && git checkout ${DEEPGEMM_REF} && \ + git submodule update --init --recursive && \ + pip install --no-build-isolation . + RUN if [ "${ENABLE_NIXL}" = "1" ]; then \ apt-get update && apt-get install -y cmake automake autotools-dev libtool libz-dev && \ DEBIAN_FRONTEND=noninteractive apt-get -y install --reinstall libibverbs-dev rdma-core ibverbs-utils libibumad-dev; \ @@ -126,7 +127,7 @@ RUN if [ "${ENABLE_NIXL}" = "1" ]; then \ apt-get update && apt-get install -y pkg-config tmux net-tools && \ cd /usr/local/src; \ pip install --upgrade meson pybind11 patchelf; \ - git clone https://github.com/ai-dynamo/nixl.git -b main && \ + git clone https://github.com/ai-dynamo/nixl.git -b ${NIXL_REF} && \ cd nixl && \ rm -rf build && \ mkdir build && \ diff --git a/docker/scripts/build.sh b/docker/scripts/build.sh index 355d6c65b3..bc1fd73da3 100644 --- a/docker/scripts/build.sh +++ b/docker/scripts/build.sh @@ -18,21 +18,23 @@ set -euo pipefail # --no-nixl Disable NIXL (default: enabled) # --no-cache Disable cache (default: enabled) # --lite Disable DEEPEP, NIXL and cache in one shot -# --cuda-version CUDA version (default: 12.8.0) +# --cuda-version CUDA version (default: 13.0.0) # --image-prefix Image prefix (default: lightllm) # --image-tag Image tag (default: generated from enabled features) +# --enable-sm100 Enable SM100 support (default: disabled) # -h / --help Show help ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/../.." && pwd)" cd "${ROOT_DIR}" IMAGE_PREFIX="${IMAGE_PREFIX:-lightllm}" -CUDA_VERSION="${CUDA_VERSION:-12.8.0}" +CUDA_VERSION="${CUDA_VERSION:-13.0.0}" IMAGE_TAG="${IMAGE_TAG:-}" ENABLE_DEEPEP="${ENABLE_DEEPEP:-1}" ENABLE_NIXL="${ENABLE_NIXL:-1}" ENABLE_CACHE="${ENABLE_CACHE:-1}" +ENABLE_SM100="${ENABLE_SM100:-0}" print_help() { sed -n '1,80p' "$0" | sed 's/^# \{0,1\}//' @@ -43,6 +45,7 @@ while [[ $# -gt 0 ]]; do --no-deepep) ENABLE_DEEPEP=0 ;; --no-nixl) ENABLE_NIXL=0 ;; --no-cache) ENABLE_CACHE=0 ;; + --enable-sm100) ENABLE_SM100=1 ;; --lite) ENABLE_DEEPEP=0 ENABLE_NIXL=0 @@ -78,13 +81,16 @@ done # - Other combos: composed from enabled feature names if [[ -z "${IMAGE_TAG}" ]]; then tag_parts=() + if [[ "${ENABLE_SM100}" -eq 1 ]]; then + tag_parts+=("sm100") + fi if [[ "${ENABLE_NIXL}" -eq 1 ]]; then tag_parts+=("nixl") fi if [[ "${ENABLE_DEEPEP}" -eq 1 ]]; then tag_parts+=("deepep") fi - if [[ "${ENABLE_NIXL}" -eq 1 && "${ENABLE_DEEPEP}" -eq 1 && "${ENABLE_CACHE}" -eq 1 ]]; then + if [[ "${ENABLE_SM100}" -eq 0 && "${ENABLE_NIXL}" -eq 1 && "${ENABLE_DEEPEP}" -eq 1 && "${ENABLE_CACHE}" -eq 1 ]]; then IMAGE_TAG="cuda${CUDA_VERSION}" else prefix="" @@ -100,6 +106,6 @@ DOCKER_BUILDKIT=1 docker build -f docker/Dockerfile \ --build-arg ENABLE_DEEPEP="${ENABLE_DEEPEP}" \ --build-arg ENABLE_NIXL="${ENABLE_NIXL}" \ --build-arg ENABLE_CACHE="${ENABLE_CACHE}" \ + --build-arg ENABLE_SM100="${ENABLE_SM100}" \ --progress=plain \ -t "${IMAGE_PREFIX}:${IMAGE_TAG}" . - diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index c46617ae98..efb4bde86c 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -23,6 +23,7 @@ from lightllm.common.basemodel.prefill_cuda_graph import PrefillCudaGraph from lightllm.common.quantization import Quantcfg from lightllm.common.basemodel.triton_kernel.gather_token_id import gather_token, gather_token_prefill_decode_mixed +from lightllm.utils.config_utils import _derive_max_req_total_len_from_model_config from lightllm.utils.log_utils import init_logger from lightllm.utils.dist_utils import get_dp_world_size from lightllm.utils.envs_utils import get_env_start_args, get_llm_data_type, get_added_mtp_kv_layer_num @@ -105,8 +106,8 @@ def __init__(self, kvargs): self._init_quant() self._init_weights() - self._init_req_manager() self._init_mem_manager() + self._init_req_manager() # 因为类似 qwen3.5 的linear 架构的模型,其 req_manager 会存储运行时使用的大量 linear state # 这可能会占用大量的显存,所以,req_manger 中保存的 mem_manger 是mem manager 初始化后再赋值 self.req_manager.mem_manager = self.mem_manager @@ -210,6 +211,26 @@ def _init_kv_move_buffer(self): if self.run_mode in ["prefill", "decode"]: self.mem_manager.alloc_kv_move_buffer(self.mem_manager.size) + # 推导出的max_req_total_len如果显存预算支持不了,需要进一步截断到可支持的长度 + def _safe_clamp_auto_max_req_total_len(self): + max_total_token_num = self.mem_manager.size + if self.max_seq_length is None or self.max_seq_length <= max_total_token_num: + return + + # 只截断推导生成的max_req_total_len + old_max_req_total_len = self.max_seq_length - 8 + derived_max_req_total_len = _derive_max_req_total_len_from_model_config(self.weight_dir_) + if derived_max_req_total_len is None or old_max_req_total_len != derived_max_req_total_len: + return + + supported_max_req_total_len = max(max_total_token_num - 8, 1) + self.args.max_req_total_len = supported_max_req_total_len + self.max_seq_length = supported_max_req_total_len + 8 + + if self.graph_max_len_in_batch == old_max_req_total_len: + self.args.graph_max_len_in_batch = min(self.args.graph_max_len_in_batch, supported_max_req_total_len) + self.graph_max_len_in_batch = self.args.graph_max_len_in_batch + def _check_mem_size(self): self.max_total_token_num = self.mem_manager.size @@ -232,6 +253,7 @@ def _check_mem_size(self): return def _init_req_manager(self): + self._safe_clamp_auto_max_req_total_len() create_max_seq_len = 0 if self.batch_max_tokens is not None: diff --git a/lightllm/common/basemodel/layer_infer/cache_tensor_manager.py b/lightllm/common/basemodel/layer_infer/cache_tensor_manager.py index 7889e8090e..8bcf99b992 100644 --- a/lightllm/common/basemodel/layer_infer/cache_tensor_manager.py +++ b/lightllm/common/basemodel/layer_infer/cache_tensor_manager.py @@ -33,6 +33,7 @@ class BufNode: inner_tensor: torch.Tensor shape_key: Tuple[int, torch.dtype] storage_weak_ptr: int + free_use_count_bias: int = 0 shape_to_tensor: Dict[Union[torch.Size, Iterable[int]], torch.Tensor] = field(default_factory=dict) def __del__(self): @@ -99,7 +100,8 @@ def alloc_tensor( # 回收可能消亡的 tensor for ptr in self.changed_ptr: t_buf_node = self.ptr_to_bufnode[ptr] - if self.use_count(ptr) == 1 + len(t_buf_node.shape_to_tensor): + free_use_count = t_buf_node.free_use_count_bias + 1 + len(t_buf_node.shape_to_tensor) + if self.use_count(ptr) <= free_use_count: self.free_shape_dtype_to_bufs[t_buf_node.shape_key].append(t_buf_node) self.changed_ptr.clear() @@ -131,6 +133,7 @@ def alloc_tensor( self.ptr_to_bufnode[storage_weak_ptr] = buf_node if shape not in buf_node.shape_to_tensor: buf_node.shape_to_tensor[shape] = buf_node.inner_tensor.view(shape) + buf_node.free_use_count_bias = self.use_count(storage_weak_ptr) - (1 + len(buf_node.shape_to_tensor)) mark_tensor = buf_node.shape_to_tensor[shape] ans = mark_tensor.data # 返回一个新的引用, 否则引用计数会无法判断 ans.storage_weak_ptr = buf_node.storage_weak_ptr diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py index fca9b80fcf..375725d124 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py @@ -11,6 +11,7 @@ from lightllm.common.basemodel.layer_weights.meta_weights.fused_moe.impl import select_fuse_moe_impl from lightllm.common.quantization.quantize_method import QuantizationMethod from lightllm.utils.envs_utils import get_redundancy_expert_ids, get_redundancy_expert_num, get_env_start_args +from lightllm.utils.device_utils import is_sm100_gpu from lightllm.utils.dist_utils import get_global_world_size, get_global_rank from lightllm.utils.log_utils import init_logger @@ -52,6 +53,7 @@ def __init__( self.quant_method = quant_method assert num_fused_shared_experts in [0, 1], "num_fused_shared_experts can only support 0 or 1 now." self.enable_ep_moe = get_env_start_args().enable_ep_moe + self.quant_method = self._maybe_upgrade_quant_method_for_ep_moe(self.quant_method) self.n_routed_experts = n_routed_experts self.num_fused_shared_experts = num_fused_shared_experts self._init_config(network_config) @@ -70,6 +72,28 @@ def __init__( self.lock = threading.Lock() self._create_weight() + def _maybe_upgrade_quant_method_for_ep_moe(self, quant_method: QuantizationMethod) -> QuantizationMethod: + if not self.enable_ep_moe: + return quant_method + + target_method = "deepgemm-fp8fp4-b32" if is_sm100_gpu() else "deepgemm-fp8w8a8-b128" + if quant_method.method_name == "none": + from lightllm.common.quantization.registry import QUANTMETHODS + + logger.info( + f"enable_ep_moe requires DeepGEMM MoE expert weights; " + f"auto-upgrading fused_moe quantization from `none` to `{target_method}`." + ) + quant_method = QUANTMETHODS.get(target_method) + + if quant_method.method_name != target_method: + raise ValueError( + f"enable_ep_moe currently requires `{target_method}` for fused_moe on this GPU, " + f"but got `{quant_method.method_name}`." + ) + + return quant_method + def _init_config(self, network_config: Dict[str, Any]): self.n_group = network_config.get("n_group", 0) self.use_grouped_topk = self.n_group > 0 @@ -152,6 +176,9 @@ def experts( per_expert_scale=self.per_expert_scale, ) + def use_sm100_mega_moe(self) -> bool: + return bool(getattr(self.fuse_moe_impl, "_use_sm100_fp4_moe", lambda: False)()) + def low_latency_dispatch( self, hidden_states: torch.Tensor, diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py index c9b8cfa3eb..2adc4343e2 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py @@ -4,11 +4,14 @@ from lightllm.distributed import dist_group_manager from lightllm.common.triton_utils.autotuner import Autotuner from lightllm.common.quantization.quantize_method import WeightPack -from lightllm.utils.envs_utils import get_deepep_num_max_dispatch_tokens_per_rank +from lightllm.utils.envs_utils import ( + get_deepep_num_max_dispatch_tokens_per_rank_prefill, + get_deepep_num_max_dispatch_tokens_per_rank_decode, +) from lightllm.common.basemodel.triton_kernel.fused_moe.grouped_fused_moe_ep import ( fused_experts_impl, masked_group_gemm, - _deepgemm_grouped_fp8_nt_contiguous, + deepgemm_grouped_fp8_nt_contiguous, ) from lightllm.common.basemodel.triton_kernel.quantization.fp8act_quant_kernel import ( per_token_group_quant_fp8, @@ -17,9 +20,84 @@ from lightllm.common.basemodel.triton_kernel.fused_moe.deepep_scatter_gather import ep_scatter, ep_gather from lightllm.common.basemodel.triton_kernel.fused_moe.moe_silu_and_mul import silu_and_mul_fwd from lightllm.common.basemodel.triton_kernel.redundancy_topk_ids_repair import redundancy_topk_ids_repair +from lightllm.utils.device_utils import is_sm100_gpu class FuseMoeDeepGEMM(FuseMoeTriton): + def _get_ep_num_sms(self) -> int: + return getattr(dist_group_manager, "ep_num_sms", None) or 0 + + def _use_sm100_fp4_moe(self) -> bool: + return is_sm100_gpu() and self.quant_method.method_name == "deepgemm-fp8fp4-b32" + + def _get_mega_moe_weights(self, w13: WeightPack, w2: WeightPack): + cache_key = ( + w13.weight.data_ptr(), + w13.weight_scale.data_ptr(), + w2.weight.data_ptr(), + w2.weight_scale.data_ptr(), + ) + if getattr(self, "_mega_moe_weight_cache_key", None) != cache_key: + import deep_gemm + + self._mega_moe_weight_cache = deep_gemm.transform_weights_for_mega_moe( + (w13.weight, w13.weight_scale), + (w2.weight, w2.weight_scale), + ) + self._mega_moe_weight_cache_key = cache_key + return self._mega_moe_weight_cache + + def _get_mega_moe_stats(self, num_local_experts: int, device: torch.device): + stats = getattr(self, "_mega_moe_stats", None) + if stats is None or stats.numel() != num_local_experts or stats.device != device: + stats = torch.zeros((num_local_experts,), device=device, dtype=torch.int32) + self._mega_moe_stats = stats + return stats + + def _mega_moe( + self, + hidden_states: torch.Tensor, + w13: WeightPack, + w2: WeightPack, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + ) -> torch.Tensor: + import deep_gemm + from deep_gemm.utils import per_token_cast_to_fp8 + + buffer = getattr(dist_group_manager, "ep_mega_moe_buffer", None) + if buffer is None: + raise RuntimeError("SM100 Mega MoE requires dist_group_manager.ep_mega_moe_buffer to be initialized") + + num_tokens = hidden_states.shape[0] + if num_tokens > buffer.num_max_tokens_per_rank: + raise RuntimeError( + f"Mega MoE got {num_tokens} tokens, exceeding num_max_tokens_per_rank={buffer.num_max_tokens_per_rank}" + ) + + qinput_tensor = per_token_cast_to_fp8( + hidden_states, + use_ue8m0=True, + gran_k=self.quant_method.block_size, + use_packed_ue8m0=True, + ) + l1_weights, l2_weights = self._get_mega_moe_weights(w13, w2) + cumulative_stats = self._get_mega_moe_stats(w13.weight.shape[0], hidden_states.device) + buffer.x[:num_tokens].copy_(qinput_tensor[0]) + buffer.x_sf[:num_tokens].copy_(qinput_tensor[1]) + buffer.topk_idx[:num_tokens].copy_(topk_ids) + buffer.topk_weights[:num_tokens].copy_(topk_weights) + + output = torch.empty_like(hidden_states) + deep_gemm.fp8_fp4_mega_moe( + output, + l1_weights, + l2_weights, + buffer, + cumulative_local_expert_recv_stats=cumulative_stats, + ) + return output + def _select_experts( self, input_tensor: torch.Tensor, @@ -74,7 +152,11 @@ def _fused_experts( ): w13_weight, w13_scale = w13.weight, w13.weight_scale w2_weight, w2_scale = w2.weight, w2.weight_scale + if self._use_sm100_fp4_moe(): + return self._mega_moe(input_tensor, w13, w2, topk_weights, topk_ids.to(torch.long)) + use_fp8_w8a8 = self.quant_method.method_name != "none" + buffer = dist_group_manager.ep_buffer if is_prefill else dist_group_manager.ep_low_latency_buffer output = fused_experts_impl( hidden_states=input_tensor, w1=w13_weight, @@ -82,7 +164,7 @@ def _fused_experts( topk_weights=topk_weights, topk_idx=topk_ids.to(torch.long), num_experts=self.total_expert_num_contain_redundancy, # number of all experts contain redundancy - buffer=dist_group_manager.ep_buffer, + buffer=buffer, is_prefill=is_prefill, use_fp8_w8a8=use_fp8_w8a8, use_fp8_all2all=use_fp8_w8a8, @@ -118,13 +200,13 @@ def low_latency_dispatch( ) topk_idx = topk_idx.to(torch.long) - num_max_dispatch_tokens_per_rank = get_deepep_num_max_dispatch_tokens_per_rank() + num_max_dispatch_tokens_per_rank = get_deepep_num_max_dispatch_tokens_per_rank_decode() use_fp8_w8a8 = self.quant_method.method_name != "none" - recv_x, masked_m, handle, event, hook = dist_group_manager.ep_buffer.low_latency_dispatch( - hidden_states, - topk_idx, - num_max_dispatch_tokens_per_rank, - self.total_expert_num_contain_redundancy, + recv_x, masked_m, handle, event, hook = dist_group_manager.ep_low_latency_buffer.low_latency_dispatch( + topk_idx=topk_idx, + x=hidden_states, + num_max_dispatch_tokens_per_rank=num_max_dispatch_tokens_per_rank, + num_experts=self.total_expert_num_contain_redundancy, use_fp8=use_fp8_w8a8, async_finish=False, return_recv_hook=True, @@ -156,6 +238,17 @@ def select_experts_and_quant_input( scoring_func=scoring_func, ) w13_weight, w13_scale = w13.weight, w13.weight_scale + if self._use_sm100_fp4_moe(): + from deep_gemm.utils import per_token_cast_to_fp8 + + qinput_tensor = per_token_cast_to_fp8( + hidden_states, + use_ue8m0=True, + gran_k=self.quant_method.block_size, + use_packed_ue8m0=True, + ) + return topk_weights, topk_idx.to(torch.long), qinput_tensor + block_size_k = 0 if w13_weight.ndim == 3: block_size_k = w13_weight.shape[2] // w13_scale.shape[2] @@ -171,38 +264,26 @@ def dispatch( overlap_event: Optional[Any] = None, ): buffer = dist_group_manager.ep_buffer - # get_dispatch_layout - ( - num_tokens_per_rank, - num_tokens_per_rdma_rank, - num_tokens_per_expert, - is_token_in_rank, - previous_event, - ) = buffer.get_dispatch_layout( - topk_idx, - self.total_expert_num_contain_redundancy, - previous_event=overlap_event, - async_finish=True, - allocate_on_comm_stream=True, - ) - recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, event = buffer.dispatch( + num_max_tokens_per_rank = get_deepep_num_max_dispatch_tokens_per_rank_prefill() + recv_x, recv_topk_idx, recv_topk_weights, handle, event = buffer.dispatch( qinput_tensor, topk_idx=topk_idx, topk_weights=topk_weights, - num_tokens_per_rank=num_tokens_per_rank, - num_tokens_per_rdma_rank=num_tokens_per_rdma_rank, - is_token_in_rank=is_token_in_rank, - num_tokens_per_expert=num_tokens_per_expert, - previous_event=previous_event, - async_finish=True, - allocate_on_comm_stream=True, + num_experts=self.total_expert_num_contain_redundancy, + num_max_tokens_per_rank=num_max_tokens_per_rank, expert_alignment=128, + num_sms=self._get_ep_num_sms(), + previous_event=overlap_event, + async_with_compute_stream=True, + allocate_on_comm_stream=True, + do_cpu_sync=True, + do_handle_copy=False, ) def hook(): event.current_stream_wait() - return recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, hook + return recv_x, recv_topk_idx, recv_topk_weights, handle.num_recv_tokens_per_expert_list, handle, hook def masked_group_gemm( self, @@ -281,7 +362,7 @@ def prefilled_group_gemm( # groupgemm (contiguous layout) gemm_out_a = torch.empty((all_tokens, N), device=device, dtype=hidden_dtype) - _deepgemm_grouped_fp8_nt_contiguous(input_tensor, (w13_weight, w13_scale), gemm_out_a, m_indices) + deepgemm_grouped_fp8_nt_contiguous(input_tensor, (w13_weight, w13_scale), gemm_out_a, m_indices) # silu_and_mul_fwd + qaunt # TODO fused kernel @@ -295,7 +376,7 @@ def prefilled_group_gemm( # groupgemm (contiguous layout) gemm_out_b = torch.empty((all_tokens, K), device=device, dtype=hidden_dtype) - _deepgemm_grouped_fp8_nt_contiguous( + deepgemm_grouped_fp8_nt_contiguous( (qsilu_out, qsilu_out_scale), (w2_weight, w2_scale), gemm_out_b, m_indices ) # gather and local reduce @@ -319,7 +400,7 @@ def low_latency_combine( topk_weights: torch.Tensor, handle: Any, ): - combined_x, event_overlap, hook = dist_group_manager.ep_buffer.low_latency_combine( + combined_x, event_overlap, hook = dist_group_manager.ep_low_latency_buffer.low_latency_combine( gemm_out_b, topk_idx, topk_weights, handle, async_finish=False, return_recv_hook=True ) return combined_x, hook @@ -335,8 +416,9 @@ def combine( gemm_out_b, handle, topk_weights=None, - async_finish=True, + num_sms=self._get_ep_num_sms(), previous_event=overlap_event, + async_with_compute_stream=True, allocate_on_comm_stream=True, ) diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe_ep.py b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe_ep.py index 2c6d013bd5..77705b1755 100644 --- a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe_ep.py +++ b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe_ep.py @@ -1,10 +1,7 @@ """Fused MoE kernel.""" -import os import torch import triton -import triton.language as tl from typing import Any, Callable, Dict, Optional, Tuple -import torch.distributed as dist from lightllm.utils.log_utils import init_logger from lightllm.common.basemodel.triton_kernel.fused_moe.moe_silu_and_mul import silu_and_mul_fwd from lightllm.common.basemodel.triton_kernel.fused_moe.moe_silu_and_mul_mix_quant_ep import ( @@ -15,9 +12,11 @@ tma_align_input_scale, ) from lightllm.common.basemodel.triton_kernel.fused_moe.deepep_scatter_gather import ep_scatter, ep_gather -from lightllm.utils.envs_utils import get_deepep_num_max_dispatch_tokens_per_rank +from lightllm.utils.envs_utils import ( + get_deepep_num_max_dispatch_tokens_per_rank_prefill, + get_deepep_num_max_dispatch_tokens_per_rank_decode, +) from lightllm.common.triton_utils.autotuner import Autotuner -import numpy as np logger = init_logger(__name__) @@ -66,14 +65,14 @@ def fused_experts_impl( topk_weights: torch.Tensor, # [M, topk] topk_idx: torch.Tensor, # [M, topk] num_experts: int, - buffer: "Buffer", + buffer: Any, is_prefill: bool, use_fp8_w8a8: bool = False, use_fp8_all2all: bool = False, use_int8_w8a16: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, - previous_event: Optional["EventOverlap"] = None, + previous_event: Optional[Any] = None, ): # Check constraints. assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" @@ -99,39 +98,27 @@ def fused_experts_impl( combined_x = None if is_prefill: qinput_tensor, input_scale = per_token_group_quant_fp8(hidden_states, block_size_k, dtype=w1.dtype) - - # get_dispatch_layout - ( - num_tokens_per_rank, - num_tokens_per_rdma_rank, - num_tokens_per_expert, - is_token_in_rank, - previous_event, - ) = buffer.get_dispatch_layout( - topk_idx, num_experts, previous_event=previous_event, async_finish=False, allocate_on_comm_stream=False - ) - + allocate_on_comm_stream = previous_event is not None # normal dispatch # recv_x [recive_num_tokens, hidden] recv_x_scale [recive_num_tokens, hidden // block_size] # recv_topk_idx [recive_num_tokens, topk_num] # recv_topk_weights [recive_num_tokens, topk_num] # num_recv_tokens_per_expert_list list [cur_node_expert_num] padding with expert_alignment=128 - recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, event = buffer.dispatch( + recv_x, recv_topk_idx, recv_topk_weights, handle, _ = buffer.dispatch( (qinput_tensor, input_scale), topk_idx=topk_idx, topk_weights=topk_weights, - num_tokens_per_rank=num_tokens_per_rank, - num_tokens_per_rdma_rank=num_tokens_per_rdma_rank, - is_token_in_rank=is_token_in_rank, - num_tokens_per_expert=num_tokens_per_expert, - previous_event=previous_event, - async_finish=False, - allocate_on_comm_stream=False, + num_experts=num_experts, + num_max_tokens_per_rank=get_deepep_num_max_dispatch_tokens_per_rank_prefill(), expert_alignment=128, + previous_event=previous_event, + allocate_on_comm_stream=allocate_on_comm_stream, + do_cpu_sync=True, + do_handle_copy=False, ) # scatter - all_tokens = sum(num_recv_tokens_per_expert_list) # calcu padding all nums. + all_tokens = sum(handle.num_recv_tokens_per_expert_list) # calcu padding all nums. # gather_out shape [recive_num_tokens, hidden] gather_out = torch.empty_like(recv_x[0], device=hidden_states.device, dtype=hidden_states.dtype) if all_tokens > 0: @@ -149,7 +136,7 @@ def fused_experts_impl( output_index = torch.empty_like(recv_topk_idx) num_recv_tokens_per_expert = torch.tensor( - num_recv_tokens_per_expert_list, dtype=torch.int32, pin_memory=True, device="cpu" + handle.num_recv_tokens_per_expert_list, dtype=torch.int32, pin_memory=True, device="cpu" ).cuda(non_blocking=True) expert_start_loc = torch.empty_like(num_recv_tokens_per_expert) @@ -169,7 +156,7 @@ def fused_experts_impl( # groupgemm (contiguous layout) gemm_out_a = torch.empty((all_tokens, N), device=hidden_states.device, dtype=hidden_states.dtype) input_tensor[1] = tma_align_input_scale(input_tensor[1]) - _deepgemm_grouped_fp8_nt_contiguous(input_tensor, (w1, w1_scale), gemm_out_a, m_indices) + deepgemm_grouped_fp8_nt_contiguous(input_tensor, (w1, w1_scale), gemm_out_a, m_indices) # silu_and_mul_fwd + qaunt # TODO fused kernel @@ -183,7 +170,7 @@ def fused_experts_impl( # groupgemm (contiguous layout) gemm_out_b = torch.empty((all_tokens, K), device=hidden_states.device, dtype=hidden_states.dtype) - _deepgemm_grouped_fp8_nt_contiguous((qsilu_out, qsilu_out_scale), (w2, w2_scale), gemm_out_b, m_indices) + deepgemm_grouped_fp8_nt_contiguous((qsilu_out, qsilu_out_scale), (w2, w2_scale), gemm_out_b, m_indices) # gather and local reduce ep_gather(gemm_out_b, recv_topk_idx, recv_topk_weights, output_index, gather_out) @@ -202,13 +189,12 @@ def fused_experts_impl( gather_out, handle, topk_weights=None, - async_finish=False, previous_event=previous_event, - allocate_on_comm_stream=False, + allocate_on_comm_stream=allocate_on_comm_stream, ) else: # low latency dispatch - num_max_dispatch_tokens_per_rank = get_deepep_num_max_dispatch_tokens_per_rank() + num_max_dispatch_tokens_per_rank = get_deepep_num_max_dispatch_tokens_per_rank_decode() expected_m = triton.cdiv(hidden_states.shape[0] * buffer.group_size * topk_idx.shape[1], num_experts) recv_x, masked_m, handle, event, hook = buffer.low_latency_dispatch( hidden_states, @@ -228,7 +214,7 @@ def fused_experts_impl( return combined_x -def _deepgemm_grouped_fp8_nt_contiguous( +def deepgemm_grouped_fp8_nt_contiguous( input_tuple: Tuple[torch.Tensor, torch.Tensor], w_tuple: Tuple[torch.Tensor, torch.Tensor], out: torch.Tensor, @@ -255,3 +241,22 @@ def _deepgemm_grouped_fp8_nt_masked( if hasattr(deep_gemm, "m_grouped_gemm_fp8_fp8_bf16_nt_masked"): return deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(input_tuple, w_tuple, out, masked_m, expected_m) raise RuntimeError("deep_gemm does not provide grouped_gemm_fp8 NT contiguous GEMM kernel in this version") + + +def deepgemm_grouped_fp8_fp4_nt_contiguous( + input_tuple: Tuple[torch.Tensor, torch.Tensor], + w_tuple: Tuple[torch.Tensor, torch.Tensor], + out: torch.Tensor, + grouped_layout: torch.Tensor, + use_psum_layout: bool = False, +): + if HAS_DEEPGEMM and hasattr(deep_gemm, "m_grouped_fp8_fp4_gemm_nt_contiguous"): + return deep_gemm.m_grouped_fp8_fp4_gemm_nt_contiguous( + input_tuple, + w_tuple, + out, + grouped_layout, + use_psum_layout=use_psum_layout, + recipe=(1, 1, 32), + ) + raise RuntimeError("deep_gemm does not provide grouped fp8-fp4 NT contiguous GEMM kernel") diff --git a/lightllm/common/kv_cache_mem_manager/mem_manager.py b/lightllm/common/kv_cache_mem_manager/mem_manager.py index 0454c86628..12d89b9e25 100755 --- a/lightllm/common/kv_cache_mem_manager/mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/mem_manager.py @@ -14,7 +14,7 @@ from lightllm.utils.envs_utils import get_unique_server_name, get_env_start_args from lightllm.distributed.pynccl import PyNcclCommunicator from lightllm.utils.dist_utils import get_current_device_id -from lightllm.utils.config_utils import get_num_key_value_heads +from lightllm.utils.config_utils import get_num_key_value_heads, get_vocab_size from lightllm.common.kv_trans_kernel.nixl_kv_trans import page_io from lightllm.utils.device_utils import kv_trans_use_p2p from lightllm.utils.shm_utils import create_or_link_shm @@ -61,14 +61,28 @@ def get_att_input_params(self, layer_index: int) -> Tuple[Any, Any]: def get_cell_size(self): return 2 * self.head_num * self.head_dim * self.layer_num * torch._utils._element_size(self.dtype) + def get_req_manager_reserve_bytes(self): + args = get_env_start_args() + max_request_num = args.running_max_req_size + 8 + max_sequence_length = max(args.batch_max_tokens or 0, (args.max_req_total_len or 0) + 8) + req_state_num = max_request_num + 1 + + reserve_bytes = req_state_num * max_sequence_length * torch._utils._element_size(torch.int32) + reserve_bytes += req_state_num * 4 * torch._utils._element_size(torch.float32) + reserve_bytes += req_state_num * 8 * torch._utils._element_size(torch.int64) + if args.penalty_counter_mode == "gpu_counter": + reserve_bytes += req_state_num * get_vocab_size(args.model_dir) * torch._utils._element_size(torch.int32) + return reserve_bytes + def profile_size(self, mem_fraction): if self.size is not None: return torch.cuda.empty_cache() world_size = dist.get_world_size() - - available_memory = get_available_gpu_memory(world_size) * mem_fraction + available_memory = get_available_gpu_memory(world_size) - get_total_gpu_memory() * (1 - mem_fraction) + req_manager_reserve_gb = self.get_req_manager_reserve_bytes() / (1024 ** 3) + available_memory -= req_manager_reserve_gb cell_size = self.get_cell_size() self.size = int(available_memory * 1024 ** 3 / cell_size) if world_size > 1: @@ -76,7 +90,8 @@ def profile_size(self, mem_fraction): dist.all_reduce(tensor, op=dist.ReduceOp.MIN) self.size = tensor.item() logger.info( - f"{str(available_memory)} GB space is available after load the model weight\n" + f"{str(available_memory)} GB space is available after load the model weight " + f"and reserve {req_manager_reserve_gb} GB for req_manager\n" f"{str(cell_size / 1024 ** 2)} MB is the size of one token kv cache\n" f"{self.size} is the profiled max_total_token_num with the mem_fraction {mem_fraction}\n" ) diff --git a/lightllm/common/kv_cache_mem_manager/qwen3next_mem_manager.py b/lightllm/common/kv_cache_mem_manager/qwen3next_mem_manager.py index caca4bb621..b26b7a7004 100644 --- a/lightllm/common/kv_cache_mem_manager/qwen3next_mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/qwen3next_mem_manager.py @@ -1,8 +1,11 @@ import torch +import torch.distributed as dist import triton from lightllm.utils.log_utils import init_logger from lightllm.common.kv_cache_mem_manager.mem_manager import MemoryManager from lightllm.utils.envs_utils import get_env_start_args +from lightllm.utils.dist_utils import get_current_device_id +from lightllm.utils.profile_max_tokens import get_available_gpu_memory, get_total_gpu_memory from lightllm.common.linear_att_cache_manager import LinearAttCacheConfig, LinearAttCacheManager from .operator import LinearAttMemOperator from typing import Tuple, Any @@ -32,6 +35,38 @@ def get_att_input_params(self, layer_index: int) -> Tuple[Any, Any]: layer_index = layer_index // self.linear_config.full_attention_interval return super().get_att_input_params(layer_index) + def profile_size(self, mem_fraction): + if self.size is not None: + return + + torch.cuda.empty_cache() + args = get_env_start_args() + reserve_bytes = self.get_req_manager_reserve_bytes() + req_state_num = (args.running_max_req_size + 8 + 1) * (args.mtp_step + 1) + reserve_bytes += ( + req_state_num + * self.linear_config.linear_layer_num + * (self.linear_config.get_conv_state_bytes_per_layer() + self.linear_config.get_ssm_state_bytes_per_layer()) + ) + reserve_gb = reserve_bytes / (1024 ** 3) + + world_size = dist.get_world_size() + available_memory = get_available_gpu_memory(world_size) - get_total_gpu_memory() * (1 - mem_fraction) + available_memory -= reserve_gb + cell_size = self.get_cell_size() + self.size = max(int(available_memory * 1024 ** 3 / cell_size), 1) + if world_size > 1: + tensor = torch.tensor(self.size, dtype=torch.int64, device=f"cuda:{get_current_device_id()}") + dist.all_reduce(tensor, op=dist.ReduceOp.MIN) + self.size = tensor.item() + logger.info( + f"{str(available_memory)} GB space is available after load the model weight " + f"and reserve {reserve_gb} GB for qwen3next req_manager\n" + f"{str(cell_size / 1024 ** 2)} MB is the size of one token kv cache\n" + f"{self.size} is the profiled max_total_token_num with the mem_fraction {mem_fraction}\n" + ) + return + def _init_buffers(self, size, dtype, head_num, head_dim, layer_num): super()._init_buffers(size, dtype, head_num, head_dim, layer_num) # TODO 初始化线性 att 对应的部分 buffer. diff --git a/lightllm/common/quantization/deepgemm.py b/lightllm/common/quantization/deepgemm.py index 137455a821..3b29951f28 100644 --- a/lightllm/common/quantization/deepgemm.py +++ b/lightllm/common/quantization/deepgemm.py @@ -126,6 +126,78 @@ def _create_weight( return mm_param, mm_param_list +@QUANTMETHODS.register(["deepgemm-fp8fp4-b32"], platform="cuda") +class DeepGEMMFP8FP4B32QuantizationMethod(DeepGEMMBaseQuantizationMethod): + def __init__(self): + super().__init__() + self.block_size = 32 + self.weight_suffix = "weight" + self.weight_zero_point_suffix = None + self.weight_scale_suffix = None + self.has_weight_scale = True + self.has_weight_zero_point = False + + @property + def method_name(self): + return "deepgemm-fp8fp4-b32" + + def quantize(self, weight: torch.Tensor, output: WeightPack): + from deep_gemm.utils import per_token_cast_to_fp4 + import deep_gemm + + weight = weight.cuda(output.weight.device) + if weight.dim() == 2: + n, k = weight.shape + packed_weight, weight_scale = per_token_cast_to_fp4(weight, use_ue8m0=True, gran_k=self.block_size) + weight_scale = deep_gemm.transform_sf_into_required_layout(weight_scale, n, k, (1, self.block_size), None) + else: + num_groups, n, k = weight.shape + packed_weight = torch.empty((num_groups, n, k // 2), device=weight.device, dtype=torch.int8) + weight_scale = torch.empty((num_groups, n, k // self.block_size), device=weight.device, dtype=torch.float32) + for i in range(num_groups): + packed_weight[i], weight_scale[i] = per_token_cast_to_fp4( + weight[i], use_ue8m0=True, gran_k=self.block_size + ) + weight_scale = deep_gemm.transform_sf_into_required_layout( + weight_scale, n, k, (1, self.block_size), num_groups + ) + output.weight.copy_(packed_weight) + output.weight_scale.copy_(weight_scale) + return + + def apply( + self, + input_tensor: torch.Tensor, + weight_pack: "WeightPack", + out: Optional[torch.Tensor] = None, + workspace: Optional[torch.Tensor] = None, + use_custom_tensor_mananger: bool = True, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + raise NotImplementedError("deepgemm-fp8fp4-b32 is only implemented for fused MoE expert weights") + + def _create_weight( + self, out_dims: Union[int, List[int]], in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1 + ) -> Tuple[WeightPack, List[WeightPack]]: + out_dim = sum(out_dims) if isinstance(out_dims, list) else out_dims + assert in_dim % 2 == 0, "FP4 packed weight requires even input dimension" + assert in_dim % self.block_size == 0, "FP4 scale dimension must be divisible by block_size" + expert_prefix = (num_experts,) if num_experts > 1 else () + weight = torch.empty(expert_prefix + (out_dim, in_dim // 2), dtype=torch.int8).cuda(device_id) + weight_scale = torch.empty(expert_prefix + (out_dim, in_dim // self.block_size), dtype=torch.int32).cuda( + device_id + ) + mm_param = WeightPack(weight=weight, weight_scale=weight_scale) + mm_param_list = self._split_weight_pack( + mm_param, + weight_out_dims=out_dims, + weight_split_dim=-2, + weight_scale_out_dims=out_dims, + weight_scale_split_dim=-2, + ) + return mm_param, mm_param_list + + def _deepgemm_fp8_nt(a_tuple, b_tuple, out): if HAS_DEEPGEMM: if hasattr(deep_gemm, "gemm_fp8_fp8_bf16_nt"): diff --git a/lightllm/distributed/communication_op.py b/lightllm/distributed/communication_op.py index f01f1c87f7..f15badde25 100644 --- a/lightllm/distributed/communication_op.py +++ b/lightllm/distributed/communication_op.py @@ -27,7 +27,8 @@ from lightllm.utils.device_utils import has_nvlink from lightllm.utils.envs_utils import ( get_env_start_args, - get_deepep_num_max_dispatch_tokens_per_rank, + get_deepep_num_max_dispatch_tokens_per_rank_prefill, + get_deepep_num_max_dispatch_tokens_per_rank_decode, get_redundancy_expert_num, ) from lightllm.utils.dist_utils import ( @@ -36,7 +37,7 @@ create_new_group_for_current_dp, create_dp_special_inter_group, ) -from lightllm.utils.device_utils import get_device_sm_count +from lightllm.utils.device_utils import get_device_sm_count, is_sm100_gpu from lightllm.utils.torch_dtype_utils import get_torch_dtype logger = init_logger(__name__) @@ -106,6 +107,10 @@ def all_gather_into_tensor(self, output_: torch.Tensor, input_: torch.Tensor, as class DistributeGroupManager: def __init__(self): self.groups = [] + self.ep_buffer = None + self.ep_low_latency_buffer = None + self.ep_mega_moe_buffer = None + self.ep_num_sms = None def __len__(self): return len(self.groups) @@ -127,52 +132,92 @@ def get_default_group(self) -> CustomProcessGroup: def get_group(self, group_index: int) -> CustomProcessGroup: return self.groups[group_index] - def new_deepep_group(self, n_routed_experts, hidden_size): + def new_deepep_group( + self, + n_routed_experts, + hidden_size, + num_experts_per_tok: int = 1, + moe_intermediate_size: Optional[int] = None, + ): enable_ep_moe = get_env_start_args().enable_ep_moe - num_max_dispatch_tokens_per_rank = get_deepep_num_max_dispatch_tokens_per_rank() + prefill_num_max_dispatch_tokens_per_rank = get_deepep_num_max_dispatch_tokens_per_rank_prefill() + decode_num_max_dispatch_tokens_per_rank = get_deepep_num_max_dispatch_tokens_per_rank_decode() if not enable_ep_moe: self.ep_buffer = None + self.ep_low_latency_buffer = None + self.ep_mega_moe_buffer = None + self.ep_num_sms = None return assert HAS_DEEPEP, "deep_ep is required for expert parallelism" - self._set_num_sms_for_deep_gemm() global_world_size = get_global_world_size() deepep_group = dist.new_group(list(range(global_world_size))) - low_latency_mode, num_rdma_bytes = True, 0 - if low_latency_mode: - self.ll_num_tokens, self.ll_hidden = num_max_dispatch_tokens_per_rank, hidden_size - self.ll_num_experts = n_routed_experts + get_redundancy_expert_num() * global_world_size - num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint( - self.ll_num_tokens, self.ll_hidden, global_world_size, self.ll_num_experts - ) - self.ep_buffer = deep_ep.Buffer( + self.ll_num_tokens = prefill_num_max_dispatch_tokens_per_rank + self.ll_decode_num_tokens = decode_num_max_dispatch_tokens_per_rank + self.ll_hidden = hidden_size + self.ll_num_experts = n_routed_experts + get_redundancy_expert_num() * global_world_size + self.ep_buffer = deep_ep.ElasticBuffer( deepep_group, - int(1e9), - num_rdma_bytes, - low_latency_mode=low_latency_mode, - num_qps_per_rank=(self.ll_num_experts // global_world_size if low_latency_mode else 1), + num_max_tokens_per_rank=self.ll_num_tokens, + hidden=self.ll_hidden, + num_topk=num_experts_per_tok, + use_fp8_dispatch=True, + allow_multiple_reduction=False, ) + self.ep_mega_moe_buffer = None + self.ep_low_latency_buffer = None + if not is_sm100_gpu(): + num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint( + self.ll_decode_num_tokens, self.ll_hidden, global_world_size, self.ll_num_experts + ) + self.ep_low_latency_buffer = deep_ep.Buffer( + deepep_group, + int(1e9), + num_rdma_bytes, + low_latency_mode=True, + num_qps_per_rank=(self.ll_num_experts // global_world_size), + ) + else: + if moe_intermediate_size is None: + raise ValueError("SM100 Mega MoE requires moe_intermediate_size or intermediate_size in model config") + + import deep_gemm + + self.ep_mega_moe_buffer = deep_gemm.get_symm_buffer_for_mega_moe( + deepep_group, + self.ll_num_experts, + self.ll_num_tokens, + num_experts_per_tok, + self.ll_hidden, + moe_intermediate_size, + ) + theoretical_sms = self.ep_buffer.get_theoretical_num_sms(self.ll_num_experts, num_experts_per_tok) + self._set_num_sms_for_deep_gemm(theoretical_sms) - def _set_num_sms_for_deep_gemm(self): + def _set_num_sms_for_deep_gemm(self, deepep_sms: int): try: try: from deep_gemm.jit_kernels.utils import set_num_sms except: from deep_gemm import set_num_sms - deepep_sms = int(os.getenv("DEEPEP_SMS", deep_ep.Buffer.num_sms)) device_sms = get_device_sm_count() - deep_ep.Buffer.set_num_sms(deepep_sms) - set_num_sms(device_sms - deepep_sms) + deepep_sms = max(0, min(deepep_sms, max(device_sms - 2, 0))) + self.ep_num_sms = deepep_sms + if self.ep_low_latency_buffer is not None: + deep_ep.Buffer.set_num_sms(deepep_sms - deepep_sms % 2) + set_num_sms(max(device_sms - deepep_sms, 2)) except BaseException as e: logger.warning(f"set num sms for deep_gemm failed: {e}") def clear_deepep_buffer(self): """ - prefill 之后需要clean 一下,ep buffer 才能正常执行 decode。 + Prefill after using ElasticBuffer may leave the legacy low-latency buffer dirty for decode. """ - if hasattr(self, "ep_buffer") and self.ep_buffer is not None: - self.ep_buffer.clean_low_latency_buffer(self.ll_num_tokens, self.ll_hidden, self.ll_num_experts) + if self.ep_low_latency_buffer is not None: + self.ep_low_latency_buffer.clean_low_latency_buffer( + self.ll_decode_num_tokens, self.ll_hidden, self.ll_num_experts + ) def all_reduce( diff --git a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py index fa2dee444f..4547ad529a 100644 --- a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py @@ -295,7 +295,7 @@ def overlap_tpsp_token_forward( infer_state1: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, ): - if not self.is_moe: + if not self.is_moe or layer_weight.experts.use_sm100_mega_moe(): return super().overlap_tpsp_token_forward( input_embdings, input_embdings1, infer_state, infer_state1, layer_weight ) @@ -421,7 +421,7 @@ def overlap_tpsp_context_forward( infer_state1: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, ): - if not self.is_moe: + if not self.is_moe or layer_weight.experts.use_sm100_mega_moe(): return super().overlap_tpsp_context_forward( input_embdings, input_embdings1, infer_state, infer_state1, layer_weight ) @@ -447,9 +447,9 @@ def overlap_tpsp_context_forward( _0_topk_weight, _0_topk_idx, _0_qinput_tensor = layer_weight.experts.select_experts_and_quant_input( _0_input1, _0_router_logits ) - from deep_ep import Buffer + from deep_ep import ElasticBuffer - _0_overlap_event = Buffer.capture() + _0_overlap_event = ElasticBuffer.capture() # 1 attention _1_input1 = self._att_norm(input_embdings1, infer_state1, layer_weight) @@ -486,8 +486,7 @@ def overlap_tpsp_context_forward( _1_topk_weight, _1_topk_idx, _1_qinput_tensor = layer_weight.experts.select_experts_and_quant_input( _1_input1, _1_router_logits ) - - _1_overlap_event = Buffer.capture() + _1_overlap_event = ElasticBuffer.capture() # 0 shared expert if self.n_shared_experts is not None: @@ -518,7 +517,7 @@ def overlap_tpsp_context_forward( infer_state1.hook() infer_state1.hook = None - _0_combine_event = Buffer.capture() + _0_combine_event = ElasticBuffer.capture() # 0 combine execute _0_ffn_out, _0_hook = layer_weight.experts.combine(_0_moe_out, _0_handle, _0_combine_event) infer_state.hook = _0_hook @@ -533,7 +532,7 @@ def overlap_tpsp_context_forward( infer_state.hook() infer_state.hook = None - _1_combine_event = Buffer.capture() + _1_combine_event = ElasticBuffer.capture() if self.n_shared_experts is not None: _0_ffn_out.add_(_0_shared_output) diff --git a/lightllm/models/deepseek2/model.py b/lightllm/models/deepseek2/model.py index e596eed97c..ea6620b4e4 100644 --- a/lightllm/models/deepseek2/model.py +++ b/lightllm/models/deepseek2/model.py @@ -48,7 +48,12 @@ def _init_some_value(self): def _init_custom(self): self._init_to_get_yarn_rotary() - dist_group_manager.new_deepep_group(self.config["n_routed_experts"], self.config["hidden_size"]) + dist_group_manager.new_deepep_group( + self.config["n_routed_experts"], + self.config["hidden_size"], + self.config.get("num_experts_per_tok", 1), + self.config.get("moe_intermediate_size", self.config.get("intermediate_size")), + ) def _verify_params(self): return super()._verify_params() diff --git a/lightllm/models/gemma4/model.py b/lightllm/models/gemma4/model.py index e1df1ec7fd..10b1958b0e 100644 --- a/lightllm/models/gemma4/model.py +++ b/lightllm/models/gemma4/model.py @@ -130,7 +130,12 @@ def _init_att_backend1(self): def _init_custom(self): self._init_to_get_rotary_gemma4() if self.config.get("enable_moe_block", False): - dist_group_manager.new_deepep_group(self.config["num_experts"], self.config["hidden_size"]) + dist_group_manager.new_deepep_group( + self.config["num_experts"], + self.config["hidden_size"], + self.config.get("num_experts_per_tok", self.config.get("top_k_experts", 1)), + self.config.get("moe_intermediate_size", self.config.get("intermediate_size")), + ) self._init_ple_static_buffer() def _init_ple_static_buffer(self): diff --git a/lightllm/models/glm4_moe_lite/model.py b/lightllm/models/glm4_moe_lite/model.py index a8fe49ac5e..1e31306aea 100644 --- a/lightllm/models/glm4_moe_lite/model.py +++ b/lightllm/models/glm4_moe_lite/model.py @@ -25,7 +25,12 @@ def _init_config(self): def _init_custom(self): self._init_to_get_yarn_rotary() - dist_group_manager.new_deepep_group(self.config["n_routed_experts"], self.config["hidden_size"]) + dist_group_manager.new_deepep_group( + self.config["n_routed_experts"], + self.config["hidden_size"], + self.config.get("num_experts_per_tok", 1), + self.config.get("moe_intermediate_size", self.config.get("intermediate_size")), + ) def _init_to_get_yarn_rotary(self): rope_scaling = self.config.get("rope_scaling") diff --git a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py index 54e4373652..a39d2f9297 100644 --- a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py @@ -133,7 +133,7 @@ def overlap_tpsp_token_forward( infer_state1: LlamaInferStateInfo, layer_weight: Qwen3MOETransformerLayerWeight, ): - if not self.is_moe: + if not self.is_moe or layer_weight.experts.use_sm100_mega_moe(): return super().overlap_tpsp_token_forward( input_embdings, input_embdings1, infer_state, infer_state1, layer_weight ) @@ -245,7 +245,7 @@ def overlap_tpsp_context_forward( infer_state1: LlamaInferStateInfo, layer_weight: Qwen3MOETransformerLayerWeight, ): - if not self.is_moe: + if not self.is_moe or layer_weight.experts.use_sm100_mega_moe(): return super().overlap_tpsp_context_forward( input_embdings, input_embdings1, infer_state, infer_state1, layer_weight ) @@ -270,9 +270,9 @@ def overlap_tpsp_context_forward( _0_topk_weight, _0_topk_idx, _0_qinput_tensor = layer_weight.experts.select_experts_and_quant_input( _0_input1, _0_router_logits ) - from deep_ep import Buffer + from deep_ep import ElasticBuffer - _0_overlap_event = Buffer.capture() + _0_overlap_event = ElasticBuffer.capture() # 1 attention _1_input1 = self._att_norm(input_embdings1, infer_state1, layer_weight) @@ -308,8 +308,7 @@ def overlap_tpsp_context_forward( _1_topk_weight, _1_topk_idx, _1_qinput_tensor = layer_weight.experts.select_experts_and_quant_input( _1_input1, _1_router_logits ) - - _1_overlap_event = Buffer.capture() + _1_overlap_event = ElasticBuffer.capture() # 0 moe calu _0_moe_out = layer_weight.experts.prefilled_group_gemm( @@ -332,7 +331,7 @@ def overlap_tpsp_context_forward( infer_state1.hook() infer_state1.hook = None - _0_combine_event = Buffer.capture() + _0_combine_event = ElasticBuffer.capture() # 0 combine execute _0_ffn_out, _0_hook = layer_weight.experts.combine(_0_moe_out, _0_handle, _0_combine_event) infer_state.hook = _0_hook @@ -347,7 +346,7 @@ def overlap_tpsp_context_forward( infer_state.hook() infer_state.hook = None - _1_combine_event = Buffer.capture() + _1_combine_event = ElasticBuffer.capture() input_embdings.add_(_0_ffn_out.view(-1, self.embed_dim_)) diff --git a/lightllm/models/qwen3_moe/model.py b/lightllm/models/qwen3_moe/model.py index b71d7f4878..0d4b45bfe6 100644 --- a/lightllm/models/qwen3_moe/model.py +++ b/lightllm/models/qwen3_moe/model.py @@ -27,4 +27,9 @@ def _init_custom(self): super()._init_custom() # Only initialize DeepEP group for MoE models with num_experts if "num_experts" in self.config and self.config["num_experts"] > 0: - dist_group_manager.new_deepep_group(self.config["num_experts"], self.config["hidden_size"]) + dist_group_manager.new_deepep_group( + self.config["num_experts"], + self.config["hidden_size"], + self.config.get("num_experts_per_tok", 1), + self.config.get("moe_intermediate_size", self.config.get("intermediate_size")), + ) diff --git a/lightllm/models/qwen3next/model.py b/lightllm/models/qwen3next/model.py index 4a8ee80a46..c1266e9df9 100644 --- a/lightllm/models/qwen3next/model.py +++ b/lightllm/models/qwen3next/model.py @@ -12,7 +12,6 @@ ) from lightllm.models.qwen3next.infer_struct import Qwen3NextInferStateInfo from lightllm.utils.log_utils import init_logger -from lightllm.distributed.communication_op import dist_group_manager from lightllm.utils.envs_utils import get_env_start_args from lightllm.common.kv_cache_mem_manager.qwen3next_mem_manager import Qwen3NextMemManager from lightllm.server.core.objs.start_args_type import StartArgs @@ -56,12 +55,6 @@ def _init_config(self): super()._init_config() self.num_kv_heads = max(self.config["num_key_value_heads"] // self.tp_world_size_, 1) - def _init_custom(self): - super()._init_custom() - # Only initialize DeepEP group for MoE models with num_experts - if "num_experts" in self.config and self.config["num_experts"] > 0: - dist_group_manager.new_deepep_group(self.config["num_experts"], self.config["hidden_size"]) - def _init_mem_manager(self): assert self.config["num_attention_heads"] % self.tp_world_size_ == 0 start_args: StartArgs = get_env_start_args() @@ -96,6 +89,7 @@ def _init_mem_manager(self): ) def _init_req_manager(self): + self._safe_clamp_auto_max_req_total_len() create_max_seq_len = 0 if self.batch_max_tokens is not None: diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index 249839b0a7..654ba0f3e5 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -4,6 +4,7 @@ import uuid import subprocess import signal +import math from lightllm.utils.net_utils import alloc_can_use_network_port, PortLocker from lightllm.utils.start_utils import process_manager, kill_recursive from .metrics.manager import start_metric_manager @@ -291,7 +292,10 @@ def normal_or_p_d_start(args): # linear att cache 参数自动设置 if args.linear_att_cache_size is None: # linear_att_cache_size 只会在 qwen3.5 等混合线性层模型中生效。 - args.linear_att_cache_size = args.running_max_req_size * 2 + default_cache_size = args.running_max_req_size * 2 + dp_size_in_node = max(1, args.dp // args.nnodes) + per_dp_cache_size = max(1, math.ceil(args.running_max_req_size / dp_size_in_node) * 2) + args.linear_att_cache_size = min(default_cache_size, per_dp_cache_size) if args.enable_cpu_cache and is_linear_att_mixed_model(args.model_dir): args.cpu_cache_token_page_size = args.linear_att_hash_page_size * args.linear_att_page_block_num diff --git a/lightllm/utils/device_utils.py b/lightllm/utils/device_utils.py index 43b10ec88b..58bff90560 100644 --- a/lightllm/utils/device_utils.py +++ b/lightllm/utils/device_utils.py @@ -40,6 +40,11 @@ def get_device_sm_count(): return properties["multiprocessor_count"] +@lru_cache(maxsize=None) +def is_sm100_gpu(): + return torch.cuda.get_device_capability()[0] == 10 + + @lru_cache(maxsize=None) def get_device_sm_regs_num(): import triton diff --git a/lightllm/utils/dist_check_utils.py b/lightllm/utils/dist_check_utils.py index e11da07c8c..12b0b81993 100644 --- a/lightllm/utils/dist_check_utils.py +++ b/lightllm/utils/dist_check_utils.py @@ -17,7 +17,7 @@ logger = init_logger(__name__) _CUSTOM_ALLREDUCE_WORLD_SIZES = (2, 4, 6, 8) -_TWO_GPU_CHECK_TIMEOUT_SECONDS = 60.0 +_TWO_GPU_CHECK_TIMEOUT_SECONDS = 600.0 def _start_two_gpu_check_timeout_watchdog(backend_name: str) -> threading.Event: @@ -84,6 +84,8 @@ def _flashinfer_two_gpu_check_worker(process_rank: int, init_tcp_port: int) -> N input_tensor = torch.zeros(2, 64, device=cuda_device, dtype=torch.bfloat16) else: input_tensor = torch.ones(2, 64, device=cuda_device, dtype=torch.bfloat16) + if not flashinfer_all_reduce.should_use(input_tensor): + raise RuntimeError("FlashInferAllReduce unsupported for probe tensor") output_tensor = flashinfer_all_reduce.all_reduce(input_tensor) dist.barrier() expected_reduced = torch.ones(2, 64, device=cuda_device, dtype=torch.bfloat16) diff --git a/lightllm/utils/envs_utils.py b/lightllm/utils/envs_utils.py index 350507e897..2bdd4005fa 100644 --- a/lightllm/utils/envs_utils.py +++ b/lightllm/utils/envs_utils.py @@ -69,9 +69,22 @@ def enable_env_vars(args): @lru_cache(maxsize=None) -def get_deepep_num_max_dispatch_tokens_per_rank(): +def get_deepep_num_max_dispatch_tokens_per_rank_prefill(): + # 该参数需要大于单卡最大batch size,且是8的倍数。该参数与显存占用直接相关,值越大,显存占用越大。 + # 如果未显式配置,则默认至少覆盖当前进程的 `batch_max_tokens`,避免 DeepEP V2 在 autotune + # warmup 或大 prefill batch 时因为 buffer 上界过小而报错。 + configured = os.getenv("NUM_MAX_DISPATCH_TOKENS_PER_RANK_PREFILL", None) + if configured is not None: + return int(configured) + + batch_max_tokens = get_env_start_args().batch_max_tokens or 256 + return ((int(batch_max_tokens) + 7) // 8) * 8 + + +@lru_cache(maxsize=None) +def get_deepep_num_max_dispatch_tokens_per_rank_decode(): # 该参数需要大于单卡最大batch size,且是8的倍数。该参数与显存占用直接相关,值越大,显存占用越大,如果出现显存不足,可以尝试调小该值 - return int(os.getenv("NUM_MAX_DISPATCH_TOKENS_PER_RANK", 256)) + return int(os.getenv("NUM_MAX_DISPATCH_TOKENS_PER_RANK_DECODE", 256)) def get_lightllm_gunicorn_keep_alive(): diff --git a/requirements.txt b/requirements.txt index d37ae05690..31a88629ed 100644 --- a/requirements.txt +++ b/requirements.txt @@ -33,7 +33,7 @@ mpmath==1.3.0 multiprocessing-logging==0.3.4 networkx==3.1 ninja==1.11.1 -numpy==1.25.1 +numpy==2.1.3 packaging==24.2 pip==23.0.1 pluggy==1.2.0 @@ -59,7 +59,7 @@ six==1.16.0 sniffio==1.3.0 sortedcontainers==2.4.0 toolz==0.12.0 -torch==2.9.1 +torch==2.11.0 tqdm==4.65.0 transformers==4.57.1 tokenizers==0.22.1 @@ -71,7 +71,7 @@ zstandard==0.23.0 safetensors==0.4.5 Pillow==10.4.0 tiktoken==0.7.0 -matplotlib==3.8.2 +matplotlib==3.10.0 psutil==5.9.4 prometheus_client==0.20.0 cchardet==2.1.7 @@ -81,19 +81,21 @@ atomics==1.0.3 easydict==1.13 hypercorn==0.18.0 flashinfer-python==0.6.8.post1 -sgl-kernel==0.3.21 +flashinfer-cubin==0.6.8.post1 +sglang-kernel==0.4.2.post1 httpx==0.28.1 librosa==0.11.0 -cuda_bindings==12.9.0 +cuda_bindings==13.2.0 orjson==3.11.2 setproctitle==1.3.6 xxhash==3.6.0 -torchvision==0.24.1 +torchvision==0.26.0 interegular==0.3.3 partial_json_parser==0.2.1.1.post6 websockets==15.0.1 -cupy-cuda12x==13.6.0 -nixl==0.8.0 -xformers==0.0.33.post2 +cupy-cuda13x==14.0.1 +nixl==1.1.0 +xformers==0.0.35 redis==7.3.0 litellm>=1.52.0,<1.85 +flash-attn-4[13]==4.0.0b14 diff --git a/test/benchmark/service/benchmark_client.py b/test/benchmark/service/benchmark_client.py index 09009fc9e1..3f55bcab1e 100644 --- a/test/benchmark/service/benchmark_client.py +++ b/test/benchmark/service/benchmark_client.py @@ -27,6 +27,13 @@ def get_tokenizer( return tokenizer +def normalize_model_name(model_name: str) -> str: + if not model_name: + return model_name + normalized = model_name.rstrip("/\\") + return normalized or model_name + + def get_output_length(input_num: int, output_len: int) -> List[int]: min_len, max_len = 2, output_len * 2 mean = (min_len + max_len) * 0.5 @@ -162,7 +169,7 @@ def main(): return assert args.tokenizer_path is not None - model_name.append(args.tokenizer_path) + model_name.append(normalize_model_name(args.tokenizer_path)) seed_all(args.seed) url = args.url tokenizer = get_tokenizer(args.tokenizer_path) diff --git a/test/benchmark/service/benchmark_multiturn.py b/test/benchmark/service/benchmark_multiturn.py index c1c87b0f5c..01863b37b6 100644 --- a/test/benchmark/service/benchmark_multiturn.py +++ b/test/benchmark/service/benchmark_multiturn.py @@ -39,12 +39,25 @@ import os import random import time +import urllib.parse +import urllib.request from typing import Dict, List, Optional, Tuple, Union import aiohttp import numpy as np from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast +_STREAM_READ_BUFSIZE = 1 << 20 +_STREAM_MAX_LINE_SIZE = 1 << 20 +_DEFAULT_TRANSIENT_RETRIES = 2 +_PROMPT_LEN_OVERLAP_CHARS = 512 +_TRANSIENT_STREAM_ERRORS = ( + aiohttp.ServerDisconnectedError, + aiohttp.ClientPayloadError, + aiohttp.ClientOSError, + asyncio.TimeoutError, +) + def seed_all(seed: int) -> None: if not seed: @@ -58,6 +71,85 @@ def get_tokenizer(tokenizer_name: str) -> Union[PreTrainedTokenizer, PreTrainedT return AutoTokenizer.from_pretrained(tokenizer_name, trust_remote_code=True) +def normalize_model_name(model_name: str) -> str: + if not model_name: + return model_name + normalized = model_name.rstrip("/\\") + return normalized or model_name + + +def get_models_url(completions_url: str) -> str: + parsed = urllib.parse.urlsplit(completions_url) + path = parsed.path.rstrip("/") + for suffix in ("/chat/completions", "/completions"): + if path.endswith(suffix): + path = path[: -len(suffix)] + "/models" + return urllib.parse.urlunsplit(parsed._replace(path=path, query="", fragment="")) + return urllib.parse.urlunsplit(parsed._replace(path="/v1/models", query="", fragment="")) + + +def fetch_served_model_names(completions_url: str, timeout_s: int = 10) -> List[str]: + models_url = get_models_url(completions_url) + request = urllib.request.Request(models_url, headers={"Accept": "application/json"}) + with urllib.request.urlopen(request, timeout=timeout_s) as response: + payload = json.loads(response.read().decode("utf-8")) + return [item["id"] for item in payload.get("data", []) if item.get("id")] + + +def resolve_model_name( + completions_url: str, + requested_model_name: str, + explicit_model_name: bool, +) -> Tuple[str, Optional[str]]: + normalized_name = normalize_model_name(requested_model_name) + if normalized_name != requested_model_name: + note = f"Normalized model name from `{requested_model_name}` to `{normalized_name}`." + else: + note = None + + try: + served_model_names = fetch_served_model_names(completions_url) + except Exception as exc: + if note is not None: + note = f"{note} Failed to query served models: {exc}." + return normalized_name, note + + if requested_model_name in served_model_names: + return requested_model_name, note + if normalized_name in served_model_names: + if normalized_name != requested_model_name: + return normalized_name, ( + f"Normalized model name from `{requested_model_name}` to `{normalized_name}` " "to match `/v1/models`." + ) + return normalized_name, note + + requested_basename = os.path.basename(normalized_name) + basename_matches = [ + served_name + for served_name in served_model_names + if os.path.basename(normalize_model_name(served_name)) == requested_basename + ] + if len(basename_matches) == 1: + matched_name = basename_matches[0] + return matched_name, ( + f"Resolved model name `{requested_model_name}` to served model `{matched_name}` " "via `/v1/models`." + ) + + if not explicit_model_name and len(served_model_names) == 1: + matched_name = served_model_names[0] + return matched_name, ( + f"Using the only served model `{matched_name}` returned by `/v1/models` " + f"instead of `{requested_model_name}`." + ) + + if note is not None: + note = ( + f"{note} Available served models: {', '.join(served_model_names) or '(none)'}. " + f"Using `{normalized_name}`." + ) + return normalized_name, note + + def gen_random_token_ids(tokenizer, n: int, rng: random.Random) -> List[int]: vocab = tokenizer.vocab_size return [rng.randint(0, vocab - 1) for _ in range(n)] @@ -86,6 +178,7 @@ def gen_session_initial_prompt( def append_turn_input( tokenizer, prompt: str, + prompt_token_len: int, generated_text: str, turn_input_increment: int, rng: random.Random, @@ -97,17 +190,34 @@ def append_turn_input( new_text = decode_ids(tokenizer, new_ids) else: new_text = "" - new_prompt = prompt + generated_text + new_text - new_len = len(tokenizer.encode(new_prompt, add_special_tokens=False)) + + appended_text = generated_text + new_text + new_prompt = prompt + appended_text + if not appended_text: + return new_prompt, prompt_token_len + + # Token merges only depend on a small boundary window, so avoid + # re-encoding the entire prompt on every turn. + overlap_text = prompt[-_PROMPT_LEN_OVERLAP_CHARS:] + if overlap_text: + overlap_token_len = len(tokenizer.encode(overlap_text, add_special_tokens=False)) + merged_token_len = len(tokenizer.encode(overlap_text + appended_text, add_special_tokens=False)) + appended_token_len = max(merged_token_len - overlap_token_len, 0) + else: + appended_token_len = len(tokenizer.encode(appended_text, add_special_tokens=False)) + new_len = prompt_token_len + appended_token_len return new_prompt, new_len async def stream_one_turn( session: aiohttp.ClientSession, + tokenizer, url: str, model_name: str, prompt: str, + prompt_token_len: int, max_new_tokens: int, + max_retries: int = _DEFAULT_TRANSIENT_RETRIES, ) -> Optional[Dict]: """Send one streaming completion request, return per-turn stats: { @@ -116,6 +226,8 @@ async def stream_one_turn( "prompt_tokens": int, "completion_tokens": int, "cached_tokens": int, + "cached_tokens_reported": bool, + "usage_estimated": bool, "generated_text": str, } Returns None on failure.""" @@ -130,74 +242,111 @@ async def stream_one_turn( } headers = {"Content-Type": "application/json"} - start_time = time.time() - first_token_time: Optional[float] = None - last_token_time: Optional[float] = None - decode_times: List[float] = [] - generated_text_parts: List[str] = [] - prompt_tokens = 0 - completion_tokens = 0 - cached_tokens = 0 - - try: - async with session.post(url, headers=headers, json=payload) as response: - if response.status != 200: - err = await response.text() - print(f"\n[turn failed] status={response.status} body={err[:200]}") - return None - - async for raw in response.content: - line = raw.strip() - if not line or not line.startswith(b"data:"): - continue - data_str = line[len(b"data:") :].strip() - if data_str == b"[DONE]": - break - try: - chunk = json.loads(data_str) - except Exception: - continue - - # Final usage-only chunk: choices == [] and usage present - usage = chunk.get("usage") - choices = chunk.get("choices") or [] - if usage is not None and not choices: - prompt_tokens = usage.get("prompt_tokens", prompt_tokens) - completion_tokens = usage.get("completion_tokens", completion_tokens) - details = usage.get("prompt_tokens_details") or {} - cached_tokens = details.get("cached_tokens", cached_tokens) - continue - - # Token-bearing chunk - if not choices: - continue - text_piece = choices[0].get("text", "") - if text_piece == "" and choices[0].get("finish_reason") is None: - continue - - now = time.time() - if first_token_time is None: - first_token_time = now - else: - decode_times.append(now - last_token_time) - last_token_time = now - if text_piece: - generated_text_parts.append(text_piece) - except Exception as e: - print(f"\n[turn exception] {e}") - return None - - if first_token_time is None: - return None - - return { - "ttft": first_token_time - start_time, - "decode_times": decode_times, - "prompt_tokens": prompt_tokens, - "completion_tokens": completion_tokens, - "cached_tokens": cached_tokens, - "generated_text": "".join(generated_text_parts), - } + for attempt in range(max_retries + 1): + start_time = time.time() + first_token_time: Optional[float] = None + last_token_time: Optional[float] = None + decode_times: List[float] = [] + generated_text_parts: List[str] = [] + prompt_tokens = 0 + completion_tokens = 0 + cached_tokens = 0 + cached_tokens_reported = False + + try: + async with session.post(url, headers=headers, json=payload) as response: + if response.status != 200: + err = await response.text() + if response.status >= 500 and attempt < max_retries: + await asyncio.sleep(0.2 * (attempt + 1)) + continue + print(f"\n[turn failed] status={response.status} body={err[:200]}") + return None + + async for raw in response.content: + line = raw.strip() + if not line or not line.startswith(b"data:"): + continue + data_str = line[len(b"data:") :].strip() + if data_str == b"[DONE]": + break + try: + chunk = json.loads(data_str) + except Exception: + continue + + # Final usage-only chunk: choices == [] and usage present + usage = chunk.get("usage") + choices = chunk.get("choices") or [] + if usage is not None and not choices: + prompt_tokens = usage.get("prompt_tokens", prompt_tokens) + completion_tokens = usage.get("completion_tokens", completion_tokens) + details = usage.get("prompt_tokens_details") + if isinstance(details, dict) and details.get("cached_tokens") is not None: + cached_tokens = details["cached_tokens"] + cached_tokens_reported = True + continue + + # Token-bearing chunk + if not choices: + continue + text_piece = choices[0].get("text", "") + if text_piece == "" and choices[0].get("finish_reason") is None: + continue + + now = time.time() + if first_token_time is None: + first_token_time = now + else: + decode_times.append(now - last_token_time) + last_token_time = now + if text_piece: + generated_text_parts.append(text_piece) + except _TRANSIENT_STREAM_ERRORS as e: + if first_token_time is None and attempt < max_retries: + await asyncio.sleep(0.2 * (attempt + 1)) + continue + + if first_token_time is not None: + generated_text = "".join(generated_text_parts) + estimated_completion_tokens = len(tokenizer.encode(generated_text, add_special_tokens=False)) + estimated_completion_tokens = max(estimated_completion_tokens, len(generated_text_parts)) + print(f"\n[turn warning] {e}; keeping partial turn with estimated usage " f"(attempt={attempt + 1})") + return { + "ttft": first_token_time - start_time, + "decode_times": decode_times, + "prompt_tokens": prompt_tokens or prompt_token_len, + "completion_tokens": completion_tokens or estimated_completion_tokens, + "cached_tokens": cached_tokens, + "cached_tokens_reported": cached_tokens_reported, + "usage_estimated": completion_tokens == 0 or prompt_tokens == 0, + "generated_text": generated_text, + } + + print(f"\n[turn exception] {e}") + return None + except Exception as e: + print(f"\n[turn exception] {e}") + return None + + if first_token_time is None: + if attempt < max_retries: + await asyncio.sleep(0.2 * (attempt + 1)) + continue + return None + + return { + "ttft": first_token_time - start_time, + "decode_times": decode_times, + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "cached_tokens": cached_tokens, + "cached_tokens_reported": cached_tokens_reported, + "usage_estimated": False, + "generated_text": "".join(generated_text_parts), + } + + return None async def run_session( @@ -219,13 +368,18 @@ async def run_session( """Run a single multi-turn dialogue session. Returns a list of per-turn stat dicts (same schema as stream_one_turn output).""" rng = random.Random(base_seed + session_id) - prompt, prompt_len = gen_session_initial_prompt(tokenizer, start_input_len, base_seed + session_id) + prompt, prompt_len = await asyncio.to_thread( + gen_session_initial_prompt, + tokenizer, + start_input_len, + base_seed + session_id, + ) per_turn: List[Dict] = [] turn_idx = 0 while turn_idx < max_turns and prompt_len < max_input_len: turn_output_len = rng.randint(min_output_len, output_len) - result = await stream_one_turn(session, url, model_name, prompt, turn_output_len) + result = await stream_one_turn(session, tokenizer, url, model_name, prompt, prompt_len, turn_output_len) if result is None: break per_turn.append(result) @@ -237,9 +391,11 @@ async def run_session( end="", ) turn_input_len = rng.randint(min_turn_input_increment, turn_input_increment) - prompt, prompt_len = append_turn_input( + prompt, prompt_len = await asyncio.to_thread( + append_turn_input, tokenizer, prompt, + result["prompt_tokens"] or prompt_len, result["generated_text"], turn_input_len, rng, @@ -267,7 +423,7 @@ async def run_concurrency_level( ) -> Dict: """Run one concurrency level. Returns the aggregated stats dict.""" timeout = aiohttp.ClientTimeout(total=request_timeout_s) - connector = aiohttp.TCPConnector(limit=max(concurrency * 2, 32)) + connector = aiohttp.TCPConnector(limit=max(concurrency * 2, 32), enable_cleanup_closed=True) progress_state = { "concurrency": concurrency, "finished_turns": 0, @@ -275,7 +431,12 @@ async def run_concurrency_level( } wall_start = time.time() - async with aiohttp.ClientSession(connector=connector, timeout=timeout) as session: + async with aiohttp.ClientSession( + connector=connector, + timeout=timeout, + read_bufsize=_STREAM_READ_BUFSIZE, + max_line_size=_STREAM_MAX_LINE_SIZE, + ) as session: tasks = [ asyncio.create_task( run_session( @@ -341,13 +502,14 @@ def summarize( prompt_tokens = sum(t["prompt_tokens"] for t in turns) completion_tokens = sum(t["completion_tokens"] for t in turns) cached_tokens = sum(t["cached_tokens"] for t in turns) + cached_tokens_reported_turns = sum(1 for t in turns if t.get("cached_tokens_reported")) + usage_estimated_turns = sum(1 for t in turns if t.get("usage_estimated")) total_tokens = prompt_tokens + completion_tokens qps = len(turns) / wall_time tpm_total = total_tokens / wall_time * 60.0 tpm_prompt = prompt_tokens / wall_time * 60.0 tpm_completion = completion_tokens / wall_time * 60.0 - cache_hit_ratio = cached_tokens / prompt_tokens if prompt_tokens else 0.0 out["QPS"] = round(qps, 4) out["TPM_total"] = round(tpm_total, 2) @@ -356,7 +518,18 @@ def summarize( out["total_prompt_tokens"] = prompt_tokens out["total_completion_tokens"] = completion_tokens out["total_cached_prompt_tokens"] = cached_tokens - out["cache_hit_ratio"] = round(cache_hit_ratio, 6) + out["cached_tokens_reported_turns"] = cached_tokens_reported_turns + out["usage_estimated_turns"] = usage_estimated_turns + if cached_tokens_reported_turns > 0: + cache_hit_ratio = cached_tokens / prompt_tokens if prompt_tokens else 0.0 + out["cache_hit_ratio"] = round(cache_hit_ratio, 6) + else: + out["cache_hit_ratio"] = None + out["cache_hit_ratio_note"] = ( + "Server did not return usage.prompt_tokens_details.cached_tokens. " + "For vLLM OpenAI-compatible APIs, start the server with " + "--enable-prompt-tokens-details to expose cache-hit stats." + ) out["avg_prompt_tokens_per_turn"] = round(prompt_tokens / len(turns), 2) out["avg_completion_tokens_per_turn"] = round(completion_tokens / len(turns), 2) @@ -389,10 +562,16 @@ def print_summary(summary: Dict) -> None: print(f" TPM (total) : {summary['TPM_total']}") print(f" TPM (prompt) : {summary['TPM_prompt']}") print(f" TPM (completion) : {summary['TPM_completion']}") - print( - f" Cache hit ratio : {summary['cache_hit_ratio'] * 100:.2f}% " - f"({summary['total_cached_prompt_tokens']} / {summary['total_prompt_tokens']})" - ) + if summary["cache_hit_ratio"] is None: + print(" Cache hit ratio : n/a") + print(f" Cache hit note : {summary['cache_hit_ratio_note']}") + else: + print( + f" Cache hit ratio : {summary['cache_hit_ratio'] * 100:.2f}% " + f"({summary['total_cached_prompt_tokens']} / {summary['total_prompt_tokens']})" + ) + if summary.get("usage_estimated_turns"): + print(f" Usage estimated : {summary['usage_estimated_turns']} turns") print(f" Avg prompt tokens : {summary['avg_prompt_tokens_per_turn']}") print(f" Avg output tokens : {summary['avg_completion_tokens_per_turn']}") ttft = summary["TTFT_ms"] @@ -415,7 +594,7 @@ def main() -> None: parser.add_argument( "--url", type=str, - default="http://127.0.0.1:8088/v1/completions", + default="http://127.0.0.1:8000/v1/completions", help="Streaming OpenAI completion endpoint. The benchmark relies on " "the final SSE `usage` chunk to obtain cached_tokens.", ) @@ -482,12 +661,19 @@ def main() -> None: return seed_all(args.seed) - model_name = args.model_name or args.tokenizer_path + requested_model_name = args.model_name or args.tokenizer_path + model_name, model_name_note = resolve_model_name( + args.url, + requested_model_name, + explicit_model_name=args.model_name is not None, + ) tokenizer = get_tokenizer(args.tokenizer_path) concurrency_levels = [int(x) for x in args.concurrency_levels.split(",") if x.strip()] print(f"URL : {args.url}") print(f"Model : {model_name}") + if model_name_note: + print(f"Model note : {model_name_note}") print(f"Concurrency levels : {concurrency_levels}") print(f"start_input_len : {args.start_input_len}") print(f"max_input_len : {args.max_input_len}") @@ -528,6 +714,7 @@ def main() -> None: "config": { "url": args.url, "model_name": model_name, + "requested_model_name": requested_model_name, "tokenizer_path": args.tokenizer_path, "concurrency_levels": concurrency_levels, "start_input_len": args.start_input_len, diff --git a/test/benchmark/service/benchmark_qps.py b/test/benchmark/service/benchmark_qps.py index 8249ae2c49..3249ebcbda 100644 --- a/test/benchmark/service/benchmark_qps.py +++ b/test/benchmark/service/benchmark_qps.py @@ -31,6 +31,13 @@ def get_tokenizer( return tokenizer +def normalize_model_name(model_name: str) -> str: + if not model_name: + return model_name + normalized = model_name.rstrip("/\\") + return normalized or model_name + + def get_random_length(reqs_num: int, length: int, range_ratio: float) -> List[int]: lens = [] lens = np.random.randint( @@ -429,7 +436,7 @@ def main(): return assert args.tokenizer_path is not None - model_name.append(args.tokenizer_path) + model_name.append(normalize_model_name(args.tokenizer_path)) seed_all(args.seed) url = args.url tokenizer = get_tokenizer(args.tokenizer_path)