 from dbgpt.model.adapter.loader import ModelLoader
 from dbgpt.model.adapter.model_adapter import get_llm_model_adapter
 from dbgpt.model.cluster.worker_base import ModelWorker
+from dbgpt.model.proxy.base import TiktokenProxyTokenizer
+from dbgpt.util.executor_utils import blocking_func_to_async_no_executor
 from dbgpt.util.model_utils import _clear_model_cache, _get_current_cuda_memory
 from dbgpt.util.parameter_utils import _get_dict_from_obj
 from dbgpt.util.system_utils import get_system_info
@@ -43,6 +45,8 @@ def __init__(self) -> None:
         self._support_generate_func = False
         self.context_len = 4096
         self._device = get_device()
+        # Use tiktoken to count tokens if the model doesn't support it
+        self._tiktoken = TiktokenProxyTokenizer()
 
     def load_worker(
         self, model_name: str, deploy_model_params: BaseDeployModelParameters, **kwargs
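The worker now holds a single TiktokenProxyTokenizer, created once in __init__ and threaded through every counting path below as a last-resort fallback. Judging only from the call site later in this diff, tiktoken.count_token("cl100k_base", [prompt])[0], the class seems to expose a batch-style interface that takes an encoding name and a list of texts and returns one count per text. A minimal sketch of such a wrapper, assuming it delegates to the tiktoken package (the real class in dbgpt.model.proxy.base may differ):

    from typing import List

    import tiktoken

    class TiktokenCounterSketch:
        """Hypothetical stand-in for TiktokenProxyTokenizer; the interface
        is inferred from the call site in this diff, not from the source."""

        def count_token(self, model_name: str, prompts: List[str]) -> List[int]:
            # Resolve the named encoding (e.g. "cl100k_base") and count the
            # BPE tokens of each prompt independently.
            encoding = tiktoken.get_encoding(model_name)
            return [len(encoding.encode(p)) for p in prompts]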
@@ -241,18 +245,20 @@ def generate(self, params: Dict) -> ModelOutput:
         return output
 
     def count_token(self, prompt: str) -> int:
-        return _try_to_count_token(prompt, self.tokenizer, self.model)
+        return _try_to_count_token(prompt, self.tokenizer, self.model, self._tiktoken)
 
     async def async_count_token(self, prompt: str) -> int:
-        # TODO if we deploy the model by vllm, it can't work, we should run
-        # transformer _try_to_count_token to async
         from dbgpt.model.proxy.llms.proxy_model import ProxyModel
 
         if isinstance(self.model, ProxyModel) and self.model.proxy_llm_client:
             return await self.model.proxy_llm_client.count_token(
                 self.model.proxy_llm_client.default_model, prompt
             )
-        raise NotImplementedError
+
+        cnt = await blocking_func_to_async_no_executor(
+            _try_to_count_token, prompt, self.tokenizer, self.model, self._tiktoken
+        )
+        return cnt
 
     def get_model_metadata(self, params: Dict) -> ModelMetadata:
         ext_metadata = ModelExtraMedata(
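async_count_token no longer raises NotImplementedError for non-proxy models: the blocking tokenizer call is handed off to blocking_func_to_async_no_executor so it does not stall the event loop. As a rough mental model only (an assumption about the helper's behavior, not the actual dbgpt.util.executor_utils implementation), a function with this name would typically offload the call to asyncio's default thread pool:

    import asyncio
    from functools import partial
    from typing import Any, Callable

    async def blocking_func_to_async_no_executor(
        func: Callable[..., Any], *args: Any, **kwargs: Any
    ) -> Any:
        # Sketch: run a blocking callable in the event loop's default
        # ThreadPoolExecutor (executor=None) and await its result.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, partial(func, *args, **kwargs))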
@@ -594,7 +600,9 @@ def _new_metrics_from_model_output(
     return metrics
 
 
-def _try_to_count_token(prompt: str, tokenizer, model) -> int:
+def _try_to_count_token(
+    prompt: str, tokenizer, model, tiktoken: TiktokenProxyTokenizer
+) -> int:
     """Try to count token of prompt
 
     Args:
@@ -612,11 +620,11 @@ def _try_to_count_token(prompt: str, tokenizer, model) -> int:
 
         if isinstance(model, ProxyModel):
             return model.count_token(prompt)
-        # Only support huggingface model now
-        return len(tokenizer(prompt).input_ids[0])
-    except Exception as e:
-        logger.warning(f"Count token error, detail: {e}, return -1")
-        return -1
+        # Only huggingface and vllm models are supported now
+        return len(tokenizer([prompt]).input_ids[0])
+    except Exception as _e:
+        logger.warning("Failed to count tokens, falling back to tiktoken")
+        return tiktoken.count_token("cl100k_base", [prompt])[0]
 
 
 def _try_import_torch():
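Token counting now falls back through three strategies: a proxy model's own client, the local Hugging Face/vLLM tokenizer, and finally tiktoken with the cl100k_base encoding as an approximation. Note that the prompt is now wrapped in a list before tokenizing; a Hugging Face tokenizer called on a batch returns one id sequence per input, hence the trailing [0]. A short usage sketch of that batched path (the "gpt2" checkpoint here is illustrative only):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    # A batched call returns input_ids as a list with one id sequence per
    # prompt, so the count for a single prompt is the length of entry 0.
    batch = tokenizer(["What is DB-GPT?"])
    print(len(batch.input_ids[0]))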