Commit 0bdff7a

csunny and fangyinc authored
fix: fix vllm tokenizer count error (#2555)
Co-authored-by: Fangyin Cheng <staneyffer@gmail.com>
1 parent c68332b commit 0bdff7a

File tree

1 file changed: +18 −10 lines

packages/dbgpt-core/src/dbgpt/model/cluster/worker/default_worker.py

Lines changed: 18 additions & 10 deletions
@@ -19,6 +19,8 @@
 from dbgpt.model.adapter.loader import ModelLoader
 from dbgpt.model.adapter.model_adapter import get_llm_model_adapter
 from dbgpt.model.cluster.worker_base import ModelWorker
+from dbgpt.model.proxy.base import TiktokenProxyTokenizer
+from dbgpt.util.executor_utils import blocking_func_to_async_no_executor
 from dbgpt.util.model_utils import _clear_model_cache, _get_current_cuda_memory
 from dbgpt.util.parameter_utils import _get_dict_from_obj
 from dbgpt.util.system_utils import get_system_info
@@ -43,6 +45,8 @@ def __init__(self) -> None:
         self._support_generate_func = False
         self.context_len = 4096
         self._device = get_device()
+        # Use tiktoken to count token if model doesn't support
+        self._tiktoken = TiktokenProxyTokenizer()

     def load_worker(
         self, model_name: str, deploy_model_params: BaseDeployModelParameters, **kwargs
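
Note: the worker now builds one `TiktokenProxyTokenizer` up front instead of per request. Its implementation is not part of this diff; below is a minimal sketch of the shape its `count_token(encoding_name, prompts)` call implies, using the `tiktoken` package directly (the class name and caching strategy here are assumptions, not DB-GPT's actual code):

```python
import tiktoken


class CachedTiktokenCounter:
    """Illustrative stand-in for TiktokenProxyTokenizer: caches encodings by name."""

    def __init__(self) -> None:
        self._encodings: dict[str, tiktoken.Encoding] = {}

    def count_token(self, encoding_name: str, prompts: list[str]) -> list[int]:
        # get_encoding builds the BPE ranks on first use, so keep one
        # Encoding per name for the worker's lifetime.
        if encoding_name not in self._encodings:
            self._encodings[encoding_name] = tiktoken.get_encoding(encoding_name)
        enc = self._encodings[encoding_name]
        return [len(enc.encode(p)) for p in prompts]


counter = CachedTiktokenCounter()
print(counter.count_token("cl100k_base", ["hello world"]))  # e.g. [2]
```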
@@ -241,18 +245,20 @@ def generate(self, params: Dict) -> ModelOutput:
         return output

     def count_token(self, prompt: str) -> int:
-        return _try_to_count_token(prompt, self.tokenizer, self.model)
+        return _try_to_count_token(prompt, self.tokenizer, self.model, self._tiktoken)

     async def async_count_token(self, prompt: str) -> int:
-        # TODO if we deploy the model by vllm, it can't work, we should run
-        # transformer _try_to_count_token to async
         from dbgpt.model.proxy.llms.proxy_model import ProxyModel

         if isinstance(self.model, ProxyModel) and self.model.proxy_llm_client:
             return await self.model.proxy_llm_client.count_token(
                 self.model.proxy_llm_client.default_model, prompt
             )
-        raise NotImplementedError
+
+        cnt = await blocking_func_to_async_no_executor(
+            _try_to_count_token, prompt, self.tokenizer, self.model, self._tiktoken
+        )
+        return cnt

     def get_model_metadata(self, params: Dict) -> ModelMetadata:
         ext_metadata = ModelExtraMedata(
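
The removed TODO is resolved by pushing the blocking count off the event loop. `blocking_func_to_async_no_executor` itself is not shown in this commit; a minimal sketch of the same pattern using only the standard library (the helper names below are illustrative, not DB-GPT's API):

```python
import asyncio


def _blocking_count(prompt: str) -> int:
    # Stand-in for _try_to_count_token: any CPU-bound tokenizer call.
    return len(prompt.split())


async def async_count_token(prompt: str) -> int:
    # Run the blocking function in the default thread pool so the
    # event loop stays responsive while the tokenizer works.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, _blocking_count, prompt)


print(asyncio.run(async_count_token("hello vllm tokenizer")))  # -> 3
```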
@@ -594,7 +600,9 @@ def _new_metrics_from_model_output(
     return metrics


-def _try_to_count_token(prompt: str, tokenizer, model) -> int:
+def _try_to_count_token(
+    prompt: str, tokenizer, model, tiktoken: TiktokenProxyTokenizer
+) -> int:
     """Try to count token of prompt

     Args:
@@ -612,11 +620,11 @@ def _try_to_count_token(prompt: str, tokenizer, model) -> int:

         if isinstance(model, ProxyModel):
             return model.count_token(prompt)
-        # Only support huggingface model now
-        return len(tokenizer(prompt).input_ids[0])
-    except Exception as e:
-        logger.warning(f"Count token error, detail: {e}, return -1")
-        return -1
+        # Only support huggingface and vllm model now
+        return len(tokenizer([prompt]).input_ids[0])
+    except Exception as _e:
+        logger.warning("Failed to count token, try tiktoken")
+        return tiktoken.count_token("cl100k_base", [prompt])[0]


def _try_import_torch():
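
Why the batched call matters: with a plain Hugging Face tokenizer call, `tokenizer(prompt).input_ids` is typically a flat list of ids, so `input_ids[0]` is one id rather than a sequence; `tokenizer([prompt]).input_ids[0]` is the full id list for the prompt, which per the commit message is what makes counting work under vLLM as well. A minimal sketch of the resulting fallback chain, substituting a direct `tiktoken` call for `TiktokenProxyTokenizer` (whose internals are not shown here):

```python
import logging

import tiktoken  # assumed available, as in the patch's fallback path

logger = logging.getLogger(__name__)


def count_token_with_fallback(prompt: str, tokenizer=None) -> int:
    """Count tokens; fall back to tiktoken when the model tokenizer fails."""
    try:
        # Batched call: input_ids is a list of id lists (one per prompt),
        # so input_ids[0] is the whole sequence for our single prompt.
        return len(tokenizer([prompt]).input_ids[0])
    except Exception:
        logger.warning("Failed to count token, try tiktoken")
        # cl100k_base matches the encoding name the patch passes along.
        return len(tiktoken.get_encoding("cl100k_base").encode(prompt))


# With no tokenizer, the call raises and the tiktoken fallback answers.
print(count_token_with_fallback("DB-GPT counts tokens"))
```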
