11from __future__ import annotations
22
33import logging
4- from typing import TYPE_CHECKING , Any , Type
54
6- from dbgpt .component import ComponentType , SystemApp
5+ from dbgpt .component import SystemApp
76from dbgpt ._private .config import Config
87from dbgpt .configs .model_config import MODEL_DISK_CACHE_DIR
98from dbgpt .util .executor_utils import DefaultExecutorFactory
10- from dbgpt .rag .embedding_engine .embedding_factory import EmbeddingFactory
119from dbgpt .app .base import WebServerParameters
1210
13- if TYPE_CHECKING :
14- from langchain .embeddings .base import Embeddings
1511
1612logger = logging .getLogger (__name__ )
1713
@@ -24,7 +20,10 @@ def initialize_components(
2420 embedding_model_name : str ,
2521 embedding_model_path : str ,
2622):
23+ # Lazy import to avoid high time cost
2724 from dbgpt .model .cluster .controller .controller import controller
25+ from dbgpt .app .initialization .embedding_component import _initialize_embedding_model
26+ from dbgpt .app .initialization .serve_initialization import register_serve_apps
2827
2928 # Register global default executor factory first
3029 system_app .register (DefaultExecutorFactory )
@@ -44,97 +43,8 @@ def initialize_components(
4443 )
4544 _initialize_model_cache (system_app )
4645 _initialize_awel (system_app )
47-
48-
49- def _initialize_embedding_model (
50- param : WebServerParameters ,
51- system_app : SystemApp ,
52- embedding_model_name : str ,
53- embedding_model_path : str ,
54- ):
55- if param .remote_embedding :
56- logger .info ("Register remote RemoteEmbeddingFactory" )
57- system_app .register (RemoteEmbeddingFactory , model_name = embedding_model_name )
58- else :
59- logger .info (f"Register local LocalEmbeddingFactory" )
60- system_app .register (
61- LocalEmbeddingFactory ,
62- default_model_name = embedding_model_name ,
63- default_model_path = embedding_model_path ,
64- )
65-
66-
67- class RemoteEmbeddingFactory (EmbeddingFactory ):
68- def __init__ (self , system_app , model_name : str = None , ** kwargs : Any ) -> None :
69- super ().__init__ (system_app = system_app )
70- self ._default_model_name = model_name
71- self .kwargs = kwargs
72- self .system_app = system_app
73-
74- def init_app (self , system_app ):
75- self .system_app = system_app
76-
77- def create (
78- self , model_name : str = None , embedding_cls : Type = None
79- ) -> "Embeddings" :
80- from dbgpt .model .cluster import WorkerManagerFactory
81- from dbgpt .model .cluster .embedding .remote_embedding import RemoteEmbeddings
82-
83- if embedding_cls :
84- raise NotImplementedError
85- worker_manager = self .system_app .get_component (
86- ComponentType .WORKER_MANAGER_FACTORY , WorkerManagerFactory
87- ).create ()
88- # Ignore model_name args
89- return RemoteEmbeddings (self ._default_model_name , worker_manager )
90-
91-
92- class LocalEmbeddingFactory (EmbeddingFactory ):
93- def __init__ (
94- self ,
95- system_app ,
96- default_model_name : str = None ,
97- default_model_path : str = None ,
98- ** kwargs : Any ,
99- ) -> None :
100- super ().__init__ (system_app = system_app )
101- self ._default_model_name = default_model_name
102- self ._default_model_path = default_model_path
103- self ._kwargs = kwargs
104- self ._model = self ._load_model ()
105-
106- def init_app (self , system_app ):
107- pass
108-
109- def create (
110- self , model_name : str = None , embedding_cls : Type = None
111- ) -> "Embeddings" :
112- if embedding_cls :
113- raise NotImplementedError
114- return self ._model
115-
116- def _load_model (self ) -> "Embeddings" :
117- from dbgpt .model .cluster .embedding .loader import EmbeddingLoader
118- from dbgpt .model .cluster .worker .embedding_worker import _parse_embedding_params
119- from dbgpt .model .parameter import (
120- EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG ,
121- BaseEmbeddingModelParameters ,
122- EmbeddingModelParameters ,
123- )
124-
125- param_cls = EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG .get (
126- self ._default_model_name , EmbeddingModelParameters
127- )
128- model_params : BaseEmbeddingModelParameters = _parse_embedding_params (
129- model_name = self ._default_model_name ,
130- model_path = self ._default_model_path ,
131- param_cls = param_cls ,
132- ** self ._kwargs ,
133- )
134- logger .info (model_params )
135- loader = EmbeddingLoader ()
136- # Ignore model_name args
137- return loader .load (self ._default_model_name , model_params )
46+ # Register serve apps
47+ register_serve_apps (system_app )
13848
13949
14050def _initialize_model_cache (system_app : SystemApp ):
0 commit comments