feat: add pytorch_engine_qwen2_5vl_sm120 #3750

Merged
lvhan028 merged 6 commits into InternLM:main from kolmogorov-quyet:feature/pytorch_engine_qwen2_5vl_sm120
Jul 24, 2025

Conversation

@kolmogorov-quyet
Contributor

Motivation

  • Silence spurious CancelledError logs from asyncio.run_coroutine_threadsafe when the future is cancelled.
  • Improve FlashAttention kernel meta-tuning for newer GPUs (adds an SM 120 / Blackwell path).
  • Keep the repository clean with extra .gitignore entries.
  • Add a minimal smoke test for the new engine to ensure CI coverage.

Modifications

  • .gitignore — Ignore builder/, lmdeploy/lib/, IDE caches, etc.
  • lmdeploy/pytorch/kernels/cuda/flashattention.py — Added an SM 120 kernel-meta helper and refactored the meta-selection logic (~40 LoC).
  • lmdeploy/serve/async_engine.py — Replaced lambda f: f.result() with a safe callback: lambda f: None if f.cancelled() else f.result()
  • tests/test.py — New smoke test: init the engine, run one caption generation, assert non-empty output.

Backward Compatibility

No API breaks; existing engines and interfaces continue to work.

Use Cases

  • Cleaner logs under heavy load: no more CancelledError spam.
  • Better kernel parameters for Blackwell GPUs (RTX 50xx, etc.).
  • Quick regression guard via the new unit test.

Checklist

  • Code formatted / linted
  • Unit tests added & pass (pytest -q)
  • Documentation updated if necessary
  • Ready for review


num_warps = 4
if _nv_cap[0] < 8:
if _nv_cap[0] >= 12: # Blackwell (sm_120 etc.)
Collaborator

Would you please follow the original style to enable Blackwell support when _nv_cap[0] < 13 and put this branch to the end of the if-elif-else block? I think this will help the community better understand the code.

@windreamer
Collaborator

And you can try to install pre-commit to help you fix lint errors.

Thank you for your efforts!
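For reference, the standard pre-commit setup being suggested (run from the repository root; the hooks themselves come from the repo's .pre-commit-config.yaml):

```shell
pip install pre-commit
pre-commit install          # run the hooks automatically on each git commit
pre-commit run --all-files  # lint/format the whole tree once
```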

@kolmogorov-quyet
Contributor Author

Thank you for your helpful review and suggestions 🙏
I've updated the code to follow the original style for better clarity, as you mentioned.
Also, I’ve set up pre-commit and made sure it passes all hooks, to avoid common PEP8 issues.

Please take a look when you have time — really appreciate your support!

@windreamer windreamer requested review from grimoire and lvhan028 July 23, 2025 17:14
test.py Outdated
from lmdeploy import PytorchEngineConfig, pipeline
from lmdeploy.vl import load_image

backend_config = PytorchEngineConfig(session_len=16384)
Collaborator

@lvhan028 any advice about the tests?

Collaborator

Hi, @kolmogorov-quyet
We appreciate your contribution. Just a quick note—the lmdeploy/tests directory is intended for unit test cases rather than functional testing.
We've already integrated Qwen2.5-VL model testing into lmdeploy's functional test suite, so you can safely remove this test file.


@grimoire grimoire left a comment

LGTM

@lvhan028 lvhan028 added the enhancement (New feature or request) label Jul 24, 2025
@lvhan028 lvhan028 merged commit 5bea8f5 into InternLM:main Jul 24, 2025
5 checks passed

4 participants