Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 27 additions & 26 deletions docker/Dockerfile
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
ARG CUDA_VERSION=12.8.0
ARG CUDA_VERSION=13.0.0
FROM nvidia/cuda:${CUDA_VERSION}-cudnn-devel-ubuntu22.04

ARG PYTHON_VERSION=3.10
ARG MAMBA_VERSION=24.7.1-0
ARG VLLM_VERSION=0.16.0
ARG VLLM_VERSION=0.21.0
ARG NIXL_REF=v1.1.0
ARG FLASH_MLA_REF=47c35a7
ARG DEEPGEMM_REF=891d57b4db1071624b5c8fa0d1e51cb317fa709f
ARG TARGETPLATFORM
ARG ENABLE_DEEPEP=1
ARG ENABLE_NIXL=1
ARG ENABLE_CACHE=1
ARG ENABLE_SM100=0

ENV PATH=/opt/conda/bin:$PATH \
CONDA_PREFIX=/opt/conda
Expand Down Expand Up @@ -44,13 +47,18 @@ WORKDIR /root

COPY ./requirements.txt /lightllm/requirements.txt
RUN pip install -U pip
RUN pip install -r /lightllm/requirements.txt --no-cache-dir
RUN pip install --no-cache-dir vllm==${VLLM_VERSION}
RUN git clone https://github.com/deepseek-ai/FlashMLA.git /root/FlashMLA && \
RUN pip install --no-cache-dir \
--extra-index-url https://download.pytorch.org/whl/cu130 \
vllm==${VLLM_VERSION}
RUN pip install -r /lightllm/requirements.txt --no-cache-dir \
--extra-index-url https://download.pytorch.org/whl/cu130
RUN export CPATH=/usr/local/cuda/targets/x86_64-linux/include/cccl:/usr/local/cuda/targets/x86_64-linux/include${CPATH:+:${CPATH}} && \
git clone https://github.com/deepseek-ai/FlashMLA.git /root/FlashMLA && \
cd /root/FlashMLA && \
git checkout ${FLASH_MLA_REF} && \
git submodule update --init --recursive && \
FLASH_MLA_DISABLE_SM100=1 pip install --no-cache-dir .
FLASH_MLA_DISABLE_SM100="$(if [ "${ENABLE_SM100}" = "1" ]; then echo 0; else echo 1; fi)" \
pip install --no-cache-dir .

RUN apt-get update && apt-get install -y libnuma-dev && rm -rf /var/lib/apt/lists/*

Expand Down Expand Up @@ -78,27 +86,20 @@ RUN if [ "${ENABLE_NIXL}" = "1" ] || [ "${ENABLE_DEEPEP}" = "1" ]; then \
RUN if [ "${ENABLE_DEEPEP}" = "1" ]; then \
set -e; \
ln -sf /usr/lib/x86_64-linux-gnu/libmlx5.so.1 /usr/lib/x86_64-linux-gnu/libmlx5.so; \
NVSHMEM_VERSION=3.3.9; \
CUDA_ARCHS=90; \
wget https://developer.download.nvidia.com/compute/redist/nvshmem/${NVSHMEM_VERSION}/source/nvshmem_src_cuda12-all-all-${NVSHMEM_VERSION}.tar.gz \
&& tar -xf nvshmem_src_cuda12-all-all-${NVSHMEM_VERSION}.tar.gz && mv nvshmem_src nvshmem \
&& cd nvshmem \
&& rm -f /root/nvshmem_src_cuda12-all-all-${NVSHMEM_VERSION}.tar.gz \
&& NVSHMEM_SHMEM_SUPPORT=0 \
NVSHMEM_UCX_SUPPORT=0 \
NVSHMEM_USE_NCCL=0 \
NVSHMEM_MPI_SUPPORT=0 \
NVSHMEM_IBGDA_SUPPORT=1 \
NVSHMEM_PMIX_SUPPORT=0 \
NVSHMEM_TIMEOUT_DEVICE_POLLING=0 \
NVSHMEM_USE_GDRCOPY=1 \
cmake -S . -B build/ -DCMAKE_INSTALL_PREFIX=/root/nvshmem/install -DCMAKE_CUDA_ARCHITECTURES=${CUDA_ARCHS} \
&& cmake --build build --target install -j64; \
DEEPEP_COMMIT=b6ce310bb0b75079682d09bc2ebc063a074fbd58; \
cd /root && git clone https://github.com/deepseek-ai/DeepEP.git && cd DeepEP && git checkout ${DEEPEP_COMMIT} && cd ..; \
cd /root/DeepEP && NVSHMEM_DIR=/root/nvshmem/install python setup.py install; \
python -m pip install --upgrade --no-deps \
"nvidia-nccl-cu13==2.30.4" \
"nvidia-nvshmem-cu13==3.6.5"; \
cd /root && git clone https://github.com/deepseek-ai/DeepEP.git && cd DeepEP && git checkout b306af06afd412c88e51e71802951606e40b7358; \
ln -sf /opt/conda/lib/python${PYTHON_VERSION}/site-packages/nvidia/nvshmem/lib/libnvshmem_host.so.3 /opt/conda/lib/python${PYTHON_VERSION}/site-packages/nvidia/nvshmem/lib/libnvshmem_host.so; \
ln -sf /opt/conda/lib/python${PYTHON_VERSION}/site-packages/nvidia/nccl/lib/libnccl.so.2 /opt/conda/lib/python${PYTHON_VERSION}/site-packages/nvidia/nccl/lib/libnccl.so; \
pip install --no-build-isolation .; \
fi

RUN cd /root && git clone https://github.com/deepseek-ai/DeepGEMM.git && \
cd DeepGEMM && git checkout ${DEEPGEMM_REF} && \
git submodule update --init --recursive && \
pip install --no-build-isolation .

RUN if [ "${ENABLE_NIXL}" = "1" ]; then \
apt-get update && apt-get install -y cmake automake autotools-dev libtool libz-dev && \
DEBIAN_FRONTEND=noninteractive apt-get -y install --reinstall libibverbs-dev rdma-core ibverbs-utils libibumad-dev; \
Expand Down Expand Up @@ -126,7 +127,7 @@ RUN if [ "${ENABLE_NIXL}" = "1" ]; then \
apt-get update && apt-get install -y pkg-config tmux net-tools && \
cd /usr/local/src; \
pip install --upgrade meson pybind11 patchelf; \
git clone https://github.com/ai-dynamo/nixl.git -b main && \
git clone https://github.com/ai-dynamo/nixl.git -b ${NIXL_REF} && \
cd nixl && \
rm -rf build && \
mkdir build && \
Expand Down
14 changes: 10 additions & 4 deletions docker/scripts/build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,23 @@ set -euo pipefail
# --no-nixl Disable NIXL (default: enabled)
# --no-cache Disable cache (default: enabled)
# --lite Disable DEEPEP, NIXL and cache in one shot
# --cuda-version <ver> CUDA version (default: 12.8.0)
# --cuda-version <ver> CUDA version (default: 13.0.0)
# --image-prefix <name> Image prefix (default: lightllm)
# --image-tag <tag> Image tag (default: generated from enabled features)
# --enable-sm100 Enable SM100 support (default: disabled)
# -h / --help Show help

ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/../.." && pwd)"
cd "${ROOT_DIR}"

IMAGE_PREFIX="${IMAGE_PREFIX:-lightllm}"
CUDA_VERSION="${CUDA_VERSION:-12.8.0}"
CUDA_VERSION="${CUDA_VERSION:-13.0.0}"
IMAGE_TAG="${IMAGE_TAG:-}"

ENABLE_DEEPEP="${ENABLE_DEEPEP:-1}"
ENABLE_NIXL="${ENABLE_NIXL:-1}"
ENABLE_CACHE="${ENABLE_CACHE:-1}"
ENABLE_SM100="${ENABLE_SM100:-0}"

print_help() {
sed -n '1,80p' "$0" | sed 's/^# \{0,1\}//'
Expand All @@ -43,6 +45,7 @@ while [[ $# -gt 0 ]]; do
--no-deepep) ENABLE_DEEPEP=0 ;;
--no-nixl) ENABLE_NIXL=0 ;;
--no-cache) ENABLE_CACHE=0 ;;
--enable-sm100) ENABLE_SM100=1 ;;
--lite)
ENABLE_DEEPEP=0
ENABLE_NIXL=0
Expand Down Expand Up @@ -78,13 +81,16 @@ done
# - Other combos: composed from enabled feature names
if [[ -z "${IMAGE_TAG}" ]]; then
tag_parts=()
if [[ "${ENABLE_SM100}" -eq 1 ]]; then
tag_parts+=("sm100")
fi
if [[ "${ENABLE_NIXL}" -eq 1 ]]; then
tag_parts+=("nixl")
fi
if [[ "${ENABLE_DEEPEP}" -eq 1 ]]; then
tag_parts+=("deepep")
fi
if [[ "${ENABLE_NIXL}" -eq 1 && "${ENABLE_DEEPEP}" -eq 1 && "${ENABLE_CACHE}" -eq 1 ]]; then
if [[ "${ENABLE_SM100}" -eq 0 && "${ENABLE_NIXL}" -eq 1 && "${ENABLE_DEEPEP}" -eq 1 && "${ENABLE_CACHE}" -eq 1 ]]; then
IMAGE_TAG="cuda${CUDA_VERSION}"
else
prefix=""
Expand All @@ -100,6 +106,6 @@ DOCKER_BUILDKIT=1 docker build -f docker/Dockerfile \
--build-arg ENABLE_DEEPEP="${ENABLE_DEEPEP}" \
--build-arg ENABLE_NIXL="${ENABLE_NIXL}" \
--build-arg ENABLE_CACHE="${ENABLE_CACHE}" \
--build-arg ENABLE_SM100="${ENABLE_SM100}" \
--progress=plain \
-t "${IMAGE_PREFIX}:${IMAGE_TAG}" .

24 changes: 23 additions & 1 deletion lightllm/common/basemodel/basemodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from lightllm.common.basemodel.prefill_cuda_graph import PrefillCudaGraph
from lightllm.common.quantization import Quantcfg
from lightllm.common.basemodel.triton_kernel.gather_token_id import gather_token, gather_token_prefill_decode_mixed
from lightllm.utils.config_utils import _derive_max_req_total_len_from_model_config
from lightllm.utils.log_utils import init_logger
from lightllm.utils.dist_utils import get_dp_world_size
from lightllm.utils.envs_utils import get_env_start_args, get_llm_data_type, get_added_mtp_kv_layer_num
Expand Down Expand Up @@ -105,8 +106,8 @@ def __init__(self, kvargs):
self._init_quant()

self._init_weights()
self._init_req_manager()
self._init_mem_manager()
self._init_req_manager()
# 因为类似 qwen3.5 的linear 架构的模型,其 req_manager 会存储运行时使用的大量 linear state
# 这可能会占用大量的显存,所以,req_manger 中保存的 mem_manger 是mem manager 初始化后再赋值
self.req_manager.mem_manager = self.mem_manager
Expand Down Expand Up @@ -210,6 +211,26 @@ def _init_kv_move_buffer(self):
if self.run_mode in ["prefill", "decode"]:
self.mem_manager.alloc_kv_move_buffer(self.mem_manager.size)

# 推导出的max_req_total_len如果显存预算支持不了,需要进一步截断到可支持的长度
def _safe_clamp_auto_max_req_total_len(self):
max_total_token_num = self.mem_manager.size
if self.max_seq_length is None or self.max_seq_length <= max_total_token_num:
return

# 只截断推导生成的max_req_total_len
old_max_req_total_len = self.max_seq_length - 8
derived_max_req_total_len = _derive_max_req_total_len_from_model_config(self.weight_dir_)
if derived_max_req_total_len is None or old_max_req_total_len != derived_max_req_total_len:
return

supported_max_req_total_len = max(max_total_token_num - 8, 1)
self.args.max_req_total_len = supported_max_req_total_len
self.max_seq_length = supported_max_req_total_len + 8

if self.graph_max_len_in_batch == old_max_req_total_len:
self.args.graph_max_len_in_batch = min(self.args.graph_max_len_in_batch, supported_max_req_total_len)
self.graph_max_len_in_batch = self.args.graph_max_len_in_batch

def _check_mem_size(self):
self.max_total_token_num = self.mem_manager.size

Expand All @@ -232,6 +253,7 @@ def _check_mem_size(self):
return

def _init_req_manager(self):
self._safe_clamp_auto_max_req_total_len()
create_max_seq_len = 0

if self.batch_max_tokens is not None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class BufNode:
inner_tensor: torch.Tensor
shape_key: Tuple[int, torch.dtype]
storage_weak_ptr: int
free_use_count_bias: int = 0
shape_to_tensor: Dict[Union[torch.Size, Iterable[int]], torch.Tensor] = field(default_factory=dict)

def __del__(self):
Expand Down Expand Up @@ -99,7 +100,8 @@ def alloc_tensor(
# 回收可能消亡的 tensor
for ptr in self.changed_ptr:
t_buf_node = self.ptr_to_bufnode[ptr]
if self.use_count(ptr) == 1 + len(t_buf_node.shape_to_tensor):
free_use_count = t_buf_node.free_use_count_bias + 1 + len(t_buf_node.shape_to_tensor)
if self.use_count(ptr) <= free_use_count:
self.free_shape_dtype_to_bufs[t_buf_node.shape_key].append(t_buf_node)
self.changed_ptr.clear()

Expand Down Expand Up @@ -131,6 +133,7 @@ def alloc_tensor(
self.ptr_to_bufnode[storage_weak_ptr] = buf_node
if shape not in buf_node.shape_to_tensor:
buf_node.shape_to_tensor[shape] = buf_node.inner_tensor.view(shape)
buf_node.free_use_count_bias = self.use_count(storage_weak_ptr) - (1 + len(buf_node.shape_to_tensor))
mark_tensor = buf_node.shape_to_tensor[shape]
ans = mark_tensor.data # 返回一个新的引用, 否则引用计数会无法判断
ans.storage_weak_ptr = buf_node.storage_weak_ptr
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from lightllm.common.basemodel.layer_weights.meta_weights.fused_moe.impl import select_fuse_moe_impl
from lightllm.common.quantization.quantize_method import QuantizationMethod
from lightllm.utils.envs_utils import get_redundancy_expert_ids, get_redundancy_expert_num, get_env_start_args
from lightllm.utils.device_utils import is_sm100_gpu
from lightllm.utils.dist_utils import get_global_world_size, get_global_rank
from lightllm.utils.log_utils import init_logger

Expand Down Expand Up @@ -52,6 +53,7 @@ def __init__(
self.quant_method = quant_method
assert num_fused_shared_experts in [0, 1], "num_fused_shared_experts can only support 0 or 1 now."
self.enable_ep_moe = get_env_start_args().enable_ep_moe
self.quant_method = self._maybe_upgrade_quant_method_for_ep_moe(self.quant_method)
self.n_routed_experts = n_routed_experts
self.num_fused_shared_experts = num_fused_shared_experts
self._init_config(network_config)
Expand All @@ -70,6 +72,28 @@ def __init__(
self.lock = threading.Lock()
self._create_weight()

def _maybe_upgrade_quant_method_for_ep_moe(self, quant_method: QuantizationMethod) -> QuantizationMethod:
if not self.enable_ep_moe:
return quant_method

target_method = "deepgemm-fp8fp4-b32" if is_sm100_gpu() else "deepgemm-fp8w8a8-b128"
if quant_method.method_name == "none":
from lightllm.common.quantization.registry import QUANTMETHODS

logger.info(
f"enable_ep_moe requires DeepGEMM MoE expert weights; "
f"auto-upgrading fused_moe quantization from `none` to `{target_method}`."
)
quant_method = QUANTMETHODS.get(target_method)

if quant_method.method_name != target_method:
raise ValueError(
f"enable_ep_moe currently requires `{target_method}` for fused_moe on this GPU, "
f"but got `{quant_method.method_name}`."
)

return quant_method

def _init_config(self, network_config: Dict[str, Any]):
self.n_group = network_config.get("n_group", 0)
self.use_grouped_topk = self.n_group > 0
Expand Down Expand Up @@ -152,6 +176,9 @@ def experts(
per_expert_scale=self.per_expert_scale,
)

def use_sm100_mega_moe(self) -> bool:
return bool(getattr(self.fuse_moe_impl, "_use_sm100_fp4_moe", lambda: False)())

def low_latency_dispatch(
self,
hidden_states: torch.Tensor,
Expand Down
Loading
Loading