Skip to content

Commit 0188d23

Browse files
authored
feat(model): Support Qwen2.5 VL models (#2743)
1 parent e04b0f9 commit 0188d23

File tree

5 files changed

+103
-36
lines changed

5 files changed

+103
-36
lines changed

packages/dbgpt-core/pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,9 @@ hf_kimi = [
157157
hf_qwen3 = [
158158
"transformers>=4.51.0",
159159
]
160+
hf_qwen2vl = [
161+
"transformers>=4.34.0",
162+
]
160163

161164
[build-system]
162165
requires = ["hatchling"]

packages/dbgpt-core/src/dbgpt/model/adapter/hf_adapter.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -553,6 +553,7 @@ def do_match(self, lower_model_name_or_path: Optional[str] = None):
553553
and "1.5" in lower_model_name_or_path
554554
and "moe" not in lower_model_name_or_path
555555
and "qwen2" not in lower_model_name_or_path
556+
and "vl" not in lower_model_name_or_path
556557
)
557558

558559

@@ -565,10 +566,12 @@ def do_match(self, lower_model_name_or_path: Optional[str] = None):
565566
(
566567
"qwen2" in lower_model_name_or_path
567568
and "instruct" in lower_model_name_or_path
569+
and "vl" not in lower_model_name_or_path
568570
)
569571
or (
570572
"qwen2.5" in lower_model_name_or_path
571573
and "instruct" in lower_model_name_or_path
574+
and "vl" not in lower_model_name_or_path
572575
)
573576
)
574577

@@ -608,6 +611,7 @@ def do_match(self, lower_model_name_or_path: Optional[str] = None):
608611
return lower_model_name_or_path and (
609612
"qwen3" in lower_model_name_or_path
610613
and "base" not in lower_model_name_or_path
614+
and "vl" not in lower_model_name_or_path
611615
)
612616

613617
def check_transformer_version(self, current_version: str) -> None:
@@ -664,6 +668,60 @@ def get_str_prompt(
664668
return str_prompt
665669

666670

671+
class Qwen2VLAdapter(NewHFChatModelAdapter):
672+
def check_transformer_version(self, current_version: str) -> None:
673+
if not current_version >= "4.37.0":
674+
raise ValueError(
675+
"Qwen2.5VL model require transformers.__version__>=4.37.0, please "
676+
"upgrade your transformers package."
677+
)
678+
679+
def do_match(self, lower_model_name_or_path: Optional[str] = None):
680+
return (
681+
lower_model_name_or_path
682+
and "qwen2" in lower_model_name_or_path
683+
and "vl" in lower_model_name_or_path
684+
and "instruct" in lower_model_name_or_path
685+
)
686+
687+
def load(self, model_path: str, from_pretrained_kwargs: dict):
688+
try:
689+
from transformers import (
690+
Qwen2_5_VLForConditionalGeneration,
691+
)
692+
except ImportError as exc:
693+
raise ValueError(
694+
"Could not import qwen2.5 vl model, please upgrade your "
695+
"transformers package to 4.37.0 or later."
696+
) from exc
697+
698+
logger.info(
699+
f"Load model from {model_path}, from_pretrained_kwargs: "
700+
f"{from_pretrained_kwargs}"
701+
)
702+
703+
revision = from_pretrained_kwargs.get("revision", "main")
704+
trust_remote_code = from_pretrained_kwargs.get(
705+
"trust_remote_code", self.trust_remote_code
706+
)
707+
low_cpu_mem_usage = from_pretrained_kwargs.get("low_cpu_mem_usage", False)
708+
if "trust_remote_code" not in from_pretrained_kwargs:
709+
from_pretrained_kwargs["trust_remote_code"] = trust_remote_code
710+
if "low_cpu_mem_usage" not in from_pretrained_kwargs:
711+
from_pretrained_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
712+
713+
tokenizer = self.load_tokenizer(
714+
model_path,
715+
revision,
716+
use_fast=self.use_fast_tokenizer(),
717+
trust_remote_code=trust_remote_code,
718+
)
719+
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
720+
model_path, **from_pretrained_kwargs
721+
)
722+
return model, tokenizer
723+
724+
667725
class QwenOmniAdapter(NewHFChatModelAdapter):
668726
def do_match(self, lower_model_name_or_path: Optional[str] = None):
669727
return lower_model_name_or_path and (
@@ -1075,6 +1133,7 @@ def is_reasoning_model(
10751133
register_model_adapter(GLM40414Adapter)
10761134
register_model_adapter(Codegeex4Adapter)
10771135
register_model_adapter(Qwen2Adapter, supported_models=COMMON_HF_QWEN25_MODELS)
1136+
register_model_adapter(Qwen2VLAdapter)
10781137
register_model_adapter(Internlm2Adapter)
10791138
register_model_adapter(DeepseekV3R1Adapter, supported_models=COMMON_HF_DEEPSEEK__MODELS)
10801139
register_model_adapter(KimiVLAdapter)

tests/intetration_tests/datasource/test_conn_mysql.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,21 @@
11
"""
2-
Run unit test with command: pytest dbgpt/datasource/rdbms/tests/test_conn_mysql.py
3-
docker run -itd --name mysql-test -p 3307:3306 -e MYSQL_ROOT_PASSWORD=12345678 mysql:5.7
4-
mysql -h 127.0.0.1 -uroot -p -P3307
5-
Enter password:
6-
Welcome to the MySQL monitor. Commands end with ; or \g.
7-
Your MySQL connection id is 2
8-
Server version: 5.7.41 MySQL Community Server (GPL)
9-
10-
Copyright (c) 2000, 2023, Oracle and/or its affiliates.
11-
12-
Oracle is a registered trademark of Oracle Corporation and/or its
13-
affiliates. Other names may be trademarks of their respective
14-
owners.
15-
16-
Type 'help;' or '\h' for help. Type '\c' to clear the current input statement.
17-
18-
> create database test;
2+
Run unit test with command: pytest dbgpt/datasource/rdbms/tests/test_conn_mysql.py
3+
docker run -itd --name mysql-test -p 3307:3306 -e MYSQL_ROOT_PASSWORD=12345678 mysql:5.7
4+
mysql -h 127.0.0.1 -uroot -p -P3307
5+
Enter password:
6+
Welcome to the MySQL monitor. Commands end with ; or \g.
7+
Your MySQL connection id is 2
8+
Server version: 5.7.41 MySQL Community Server (GPL)
9+
10+
Copyright (c) 2000, 2023, Oracle and/or its affiliates.
11+
12+
Oracle is a registered trademark of Oracle Corporation and/or its
13+
affiliates. Other names may be trademarks of their respective
14+
owners.
15+
16+
Type 'help;' or '\h' for help. Type '\c' to clear the current input statement.
17+
18+
> create database test;
1919
"""
2020

2121
import pytest

tests/intetration_tests/datasource/test_conn_oracle.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,4 +116,4 @@ def test_get_database_lists(db):
116116
pdb_names = [name[0] for name in databases[1:]]
117117
else:
118118
pdb_names = ["ORCL"]
119-
assert any(name in ("ORCLPDB", "ORCL") for name in pdb_names)
119+
assert any(name in ("ORCLPDB", "ORCL") for name in pdb_names)

0 commit comments

Comments
 (0)