Merged
Changes from 1 commit
Commits (37)
07c667c  refactor SchedulerSequence (grimoire, Aug 26, 2025)
af01586  block sparse attn (grimoire, Aug 26, 2025)
e6e440d  Merge branch 'refactor-seqs' into support-SDAR (grimoire, Aug 27, 2025)
4301864  Merge branch 'block-sparse-attn' into support-SDAR (grimoire, Aug 27, 2025)
e328c5d  support SDAR (grimoire, Sep 1, 2025)
63efa34  Merge branch 'main' into support-SDAR (grimoire, Sep 1, 2025)
48a0137  fix max_new_tokens; update profiler (grimoire, Sep 1, 2025)
6e8f4c5  add args (grimoire, Sep 1, 2025)
42f4582  fix multiround stop words (grimoire, Sep 1, 2025)
9a68f1a  fix sampling step (grimoire, Sep 2, 2025)
0fa2e7e  optimize position_ids (grimoire, Sep 2, 2025)
85255d2  fix long context (grimoire, Sep 2, 2025)
b65afc5  fix vlm (grimoire, Sep 2, 2025)
da2f403  fix stopping (grimoire, Sep 2, 2025)
e6b5bdd  move args into logitsprocessor (grimoire, Sep 2, 2025)
2b0e607  rename (grimoire, Sep 3, 2025)
f7c7cd8  Merge branch 'main' into support-SDAR (grimoire, Sep 3, 2025)
a660a43  fix pd (grimoire, Sep 3, 2025)
b23d962  rename (grimoire, Sep 3, 2025)
34e41aa  strategy + abstruct factory (grimoire, Sep 5, 2025)
de49bb5  update seqs (grimoire, Sep 5, 2025)
3890cfe  add moe support (grimoire, Sep 8, 2025)
c1e4cde  bind block length (grimoire, Sep 8, 2025)
d9d688c  solve conflict (grimoire, Sep 11, 2025)
26f4c2d  fix num loops (grimoire, Sep 12, 2025)
11674bf  enum unmasking type (grimoire, Sep 15, 2025)
8fce74a  typo fixing (grimoire, Sep 15, 2025)
94c3013  warning (grimoire, Sep 15, 2025)
c74b535  fix metric (grimoire, Sep 16, 2025)
bbd1489  limit batch size (grimoire, Sep 16, 2025)
11d3c2e  merge main (grimoire, Sep 17, 2025)
cc67ff6  merge main (grimoire, Sep 18, 2025)
e8771be  rename field; comment unmasking strategy (grimoire, Sep 18, 2025)
59c7c62  suppression warning (grimoire, Sep 18, 2025)
c0165df  solve conflict (grimoire, Sep 18, 2025)
1e47c31  colored vis (grimoire, Sep 18, 2025)
ee71d91  fix dummy (grimoire, Sep 18, 2025)
support SDAR
grimoire committed Sep 1, 2025
commit e328c5dffd858d97ebdf94fbeef5a74e1b1ca96c
2 changes: 2 additions & 0 deletions lmdeploy/messages.py
@@ -333,6 +333,7 @@ class PytorchEngineConfig:
         It can be used to override the default config of the model,
         disable_vision_encoder (bool): Whether to disable loading vision
             encoder. Default to False.
+        block_sparse_size (int): Block size of block diffusion model.
         logprobs_mode (str): The mode of logprob, options: ['raw_logits', 'raw_logprobs']
     """
     dtype: str = 'auto'
@@ -367,6 +368,7 @@ class PytorchEngineConfig:
     enable_metrics: bool = False
     hf_overrides: Optional[Dict[str, Any]] = None
     disable_vision_encoder: bool = False
+    block_sparse_size: int = 1
     logprobs_mode: str = None

     role: EngineRole = EngineRole.Hybrid
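As a usage note, here is a minimal sketch of switching the new knob on when building a pipeline. The `pipeline`/`PytorchEngineConfig` APIs are lmdeploy's own; the model path and the value 4 are placeholder assumptions:

```python
from lmdeploy import pipeline, PytorchEngineConfig

# block_sparse_size > 1 only makes sense for a block diffusion (dllm) model;
# the default of 1 keeps ordinary autoregressive decoding.
engine_config = PytorchEngineConfig(block_sparse_size=4)
pipe = pipeline('path/to/sdar-model', backend_config=engine_config)
print(pipe(['Hello, world!']))
```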
1 change: 1 addition & 0 deletions lmdeploy/pytorch/backends/attention.py
@@ -93,6 +93,7 @@ def build(
         causal: bool = True,
         use_flash_mla: bool = False,
         learnable_sink: bool = False,
+        block_sparse_size: int = 1,
         **kwargs,
     ) -> AttentionImpl[T]:
         """build."""
10 changes: 8 additions & 2 deletions lmdeploy/pytorch/backends/cuda/attention.py
@@ -62,6 +62,7 @@ def __init__(
         sliding_window: int = None,
         logit_softcapping: float = None,
         causal: bool = True,
+        block_sparse_size: int = 1,
         **kwargs,
     ):
         super().__init__(
@@ -91,6 +92,7 @@ def __init__(
         world_size, rank = get_tp_world_rank()
         self.alibi_head_offset = self.num_heads * rank
         self.alibi_num_heads = self.num_heads * world_size
+        self.block_sparse_size = block_sparse_size

     def forward(
         self,
@@ -116,7 +118,7 @@ def forward(
         kv_flatten_size = attn_metadata.kv_flatten_size
         quant_policy = attn_metadata.quant_policy
         if attn_metadata.is_decoding:
-            max_q_seqlen = 1
+            max_q_seqlen = self.block_sparse_size
         else:
             max_q_seqlen = query.numel() // (query.size(-1) * query.size(-2))
         fill_max_q_seqlen = max_q_seqlen
@@ -213,6 +215,7 @@ def forward(
             logit_softcapping=self.logit_softcapping,
             sinks=learnable_sink,
             causal=self.causal,
+            block_sparse_size=self.block_sparse_size,
         )

         return attn_output
@@ -528,9 +531,11 @@ def build(
         causal: bool = True,
         use_flash_mla: bool = False,
         learnable_sink: bool = False,
+        block_sparse_size: int = 1,
         **kwargs,
     ) -> TritonAttentionImpl:
         """build."""
+        enable_fa3 = use_fa3 and not alibi and not learnable_sink and block_sparse_size == 1
         if use_flash_mla is True:
             return FlashMLAImpl(num_heads,
                                 head_size,
@@ -542,7 +547,7 @@ def build(
                                 logical_softcapping=logical_softcapping,
                                 causal=causal,
                                 **kwargs)
-        elif use_fa3 and not alibi and not learnable_sink:
+        elif enable_fa3:
             return FA3Impl(num_heads,
                            head_size,
                            scale=scale,
@@ -563,4 +568,5 @@ def build(
                        sliding_window=sliding_window,
                        logical_softcapping=logical_softcapping,
                        causal=causal,
+                       block_sparse_size=block_sparse_size,
                        **kwargs)
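For intuition about the kernel change above: an autoregressive decode step carries one query token per sequence, while a block diffusion decode step carries `block_sparse_size` query tokens that attend to the full cached prefix and bidirectionally within their own block. A hedged sketch of that mask shape (my illustration of the pattern, not the actual Triton/FA3 kernel code in this diff):

```python
import torch

def decode_mask(prefix_len: int, block_size: int, causal: bool) -> torch.Tensor:
    """Boolean mask [block_size, prefix_len + block_size]; True = may attend."""
    mask = torch.ones(block_size, prefix_len + block_size, dtype=torch.bool)
    if causal:
        # AR-style decode: in-block token i only sees in-block positions <= i.
        mask[:, prefix_len:] = torch.tril(
            torch.ones(block_size, block_size, dtype=torch.bool))
    # dllm-style decode (causal=False): the block is fully bidirectional,
    # so the all-ones mask stands as-is.
    return mask
```

This also appears to be why the build path gates FA3 off whenever `block_sparse_size != 1`: presumably the FA3 fast path only covers the standard single-token causal decode shape.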
27 changes: 20 additions & 7 deletions lmdeploy/pytorch/backends/cuda/graph_runner.py
@@ -12,6 +12,7 @@
 from lmdeploy.utils import get_logger

 from ..graph_runner import GraphRunner
+from .attention import TritonAttentionMetadata

 logger = get_logger('lmdeploy')

@@ -173,18 +174,30 @@ def _get_capture_tokens(self, batch_size: int):
         assert False, f'Unsupported batch_size={batch_size}'

     def get_graph_key(self, input_ids: torch.Tensor, position_ids: torch.Tensor, past_key_values: List,
-                      attn_metadata: Any, inputs_embeds: torch.Tensor, **kwargs):
+                      attn_metadata: TritonAttentionMetadata, inputs_embeds: torch.Tensor, **kwargs):
         """Get graph key."""
         context = self.ctx_mgr.current_context()
         is_decoding = context.is_decoding
-        num_tokens = input_ids.numel()
+        batch_size = attn_metadata.q_seqlens.size(0)
         meta = self.get_meta()
         enable_microbatch = get_step_ctx_manager().current_context().enable_microbatch
         if meta.padding_batch_size is None:
-            new_num_tokens = self._get_capture_tokens(num_tokens)
+            batch_size = self._get_capture_tokens(batch_size)
         else:
-            new_num_tokens = self._get_capture_tokens(meta.padding_batch_size)
-        return (new_num_tokens, is_decoding, enable_microbatch)
+            batch_size = self._get_capture_tokens(meta.padding_batch_size)
+        return (batch_size, is_decoding, enable_microbatch)
+
+    def _get_max_tokens(self, graph_key: tuple):
+        max_batches = graph_key[0]
+        is_decoding = graph_key[1]
+        assert is_decoding
+        model_paradigm = self.model_config.model_paradigm
+        if model_paradigm == 'dllm':
+            step_mgr = get_step_ctx_manager()
+            build_ctx = step_mgr.build_ctx
+            block_sparse_size = build_ctx.block_sparse_size
+            return max_batches * block_sparse_size
+        return max_batches

     def __call__(self, **kwargs):
         """call."""
@@ -198,10 +211,10 @@ def __call__(self, **kwargs):
             return self.model(**kwargs)

         graph_key = self.get_graph_key(**kwargs)
-        max_tokens = graph_key[0]
+        max_batches = graph_key[0]
         is_decoding = graph_key[1]
         if graph_key not in self._runner_map:
-            max_batches = max_tokens if is_decoding else self.max_batches
+            max_tokens = self._get_max_tokens(graph_key)
             runner = CUDASingleGraphRunner(self.model,
                                            max_batches=max_batches,
                                            max_tokens=max_tokens,
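The new `_get_max_tokens` encodes the capture budget: a CUDA graph captured for `max_batches` decoding sequences must hold `max_batches * block_sparse_size` query tokens when the model is a dllm, since every sequence advances a whole block per step. A standalone restatement of that arithmetic (the function name is invented for the sketch):

```python
def graph_token_budget(max_batches: int, model_paradigm: str,
                       block_sparse_size: int) -> int:
    """Query tokens one captured decode graph must accommodate per step."""
    if model_paradigm == 'dllm':
        # e.g. 64 sequences x a block of 4 = 256 query tokens per decode step
        return max_batches * block_sparse_size
    return max_batches  # autoregressive: one query token per sequence

assert graph_token_budget(64, 'dllm', 4) == 256
assert graph_token_budget(64, 'llm', 1) == 64
```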
4 changes: 4 additions & 0 deletions lmdeploy/pytorch/config.py
@@ -200,6 +200,8 @@ class ModelConfig:
     cogvlm_style: bool = False
     custom_module_map: Dict[str, setattr] = None
     use_flash_mla: bool = False
+    model_paradigm: str = 'llm'
+    dllm_mask_token: int = 0

     def get_head_size(self):
         """Get head size."""
@@ -294,6 +296,7 @@ class MiscConfig:
     hf_overrides: Dict[str, Any] = None
     disable_vision_encoder: bool = False
     logprobs_mode: str = None
+    block_sparse_size: int = 1

     @classmethod
     def from_engine_config(cls, engine_config: PytorchEngineConfig):
@@ -304,5 +307,6 @@ def from_engine_config(cls, engine_config: PytorchEngineConfig):
             model_format=engine_config.model_format,
             hf_overrides=engine_config.hf_overrides,
             disable_vision_encoder=engine_config.disable_vision_encoder,
+            block_sparse_size=engine_config.block_sparse_size,
             logprobs_mode=engine_config.logprobs_mode)
         return misc_config
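A quick round-trip of the plumbing this hunk adds, using only classes that appear in the diff (a sanity sketch, not a test from the repo):

```python
from lmdeploy.messages import PytorchEngineConfig
from lmdeploy.pytorch.config import MiscConfig

engine_cfg = PytorchEngineConfig(block_sparse_size=4)
misc_cfg = MiscConfig.from_engine_config(engine_cfg)
assert misc_cfg.block_sparse_size == 4  # engine flag reaches the misc config
```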
18 changes: 18 additions & 0 deletions lmdeploy/pytorch/configurations/sdar.py
@@ -0,0 +1,18 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .default import AutoModelConfigBuilder, DefaultModelConfigBuilder
+
+
+class SDARModelConfigBuilder(AutoModelConfigBuilder):
+
+    @classmethod
+    def condition(cls, hf_config):
+        """config."""
+        return hf_config.model_type == 'sdar'
+
+    @classmethod
+    def build(cls, hf_config, model_path: str = None, **kwargs):
+        """build."""
+        cfg = DefaultModelConfigBuilder.build(hf_config, model_path, **kwargs)
+        cfg.dllm_mask_token = 151669
+        cfg.model_paradigm = 'dllm'
+        return cfg
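For orientation, this builder presumably participates in an auto-dispatch: a registry walks the `AutoModelConfigBuilder` subclasses and uses the first whose `condition(hf_config)` returns True, so an HF config with `model_type == 'sdar'` gets stamped with `model_paradigm='dllm'` and `dllm_mask_token=151669`. A self-contained sketch of that pattern (the registry below is invented for illustration; lmdeploy's actual lookup may differ):

```python
from types import SimpleNamespace

class AutoBuilder:
    registry = []

    @classmethod
    def register(cls, builder):
        cls.registry.append(builder)
        return builder

    @classmethod
    def dispatch(cls, hf_config):
        # Pick the first registered builder whose condition matches.
        for builder in cls.registry:
            if builder.condition(hf_config):
                return builder
        raise ValueError(f'no builder for model_type={hf_config.model_type!r}')

@AutoBuilder.register
class SdarBuilderSketch:
    @classmethod
    def condition(cls, hf_config):
        return hf_config.model_type == 'sdar'

# An hf_config whose model_type is 'sdar' selects the SDAR builder.
assert AutoBuilder.dispatch(SimpleNamespace(model_type='sdar')) is SdarBuilderSketch
```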