Skip to content

Commit 511a43b

Browse files
authored
refactor: Refactor storage and new serve template (#947)
1 parent 22d95b4 commit 511a43b

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

63 files changed

+1891
-229
lines changed

.env.template

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,15 @@ WEB_SERVER_PORT=7860
2121
#*******************************************************************#
2222
#** LLM MODELS **#
2323
#*******************************************************************#
24-
# LLM_MODEL, see /pilot/configs/model_config.LLM_MODEL_CONFIG
24+
# LLM_MODEL, see dbgpt/configs/model_config.LLM_MODEL_CONFIG
2525
LLM_MODEL=vicuna-13b-v1.5
2626
## LLM model path, by default, DB-GPT will read the model path from LLM_MODEL_CONFIG based on the LLM_MODEL.
2727
## Of course you can specify your model path according to LLM_MODEL_PATH
2828
## In DB-GPT, the priority from high to low to read model path:
2929
## 1. environment variable with key: {LLM_MODEL}_MODEL_PATH (Avoid multi-model conflicts)
3030
## 2. environment variable with key: MODEL_PATH
3131
## 3. environment variable with key: LLM_MODEL_PATH
32-
## 4. the config in /pilot/configs/model_config.LLM_MODEL_CONFIG
32+
## 4. the config in dbgpt/configs/model_config.LLM_MODEL_CONFIG
3333
# LLM_MODEL_PATH=/app/models/vicuna-13b-v1.5
3434
# LLM_PROMPT_TEMPLATE=vicuna_v1.1
3535
MODEL_SERVER=http://127.0.0.1:8000
@@ -47,7 +47,7 @@ QUANTIZE_8bit=True
4747
# PROXYLLM_BACKEND=
4848

4949
### You can configure parameters for a specific model with {model name}_{config key}=xxx
50-
### See /pilot/model/parameter.py
50+
### See dbgpt/model/parameter.py
5151
## prompt template for current model
5252
# llama_cpp_prompt_template=vicuna_v1.1
5353
## llama-2-70b must be 8
@@ -86,7 +86,7 @@ KNOWLEDGE_SEARCH_REWRITE=False
8686
# EMBEDDING_TOKENIZER=all-MiniLM-L6-v2
8787
# EMBEDDING_TOKEN_LIMIT=8191
8888

89-
## Openai embedding model, See /pilot/model/parameter.py
89+
## Openai embedding model, See dbgpt/model/parameter.py
9090
# EMBEDDING_MODEL=proxy_openai
9191
# proxy_openai_proxy_server_url=https://api.openai.com/v1
9292
# proxy_openai_proxy_api_key={your-openai-sk}

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ __pycache__/
88

99
message/
1010

11-
.env
11+
.env*
1212
.vscode
1313
.idea
1414
.chroma

dbgpt/agent/db/my_plugin_db.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,6 @@
77

88
class MyPluginEntity(Model):
99
__tablename__ = "my_plugin"
10-
__table_args__ = {
11-
"mysql_charset": "utf8mb4",
12-
"mysql_collate": "utf8mb4_unicode_ci",
13-
}
1410
id = Column(Integer, primary_key=True, comment="autoincrement id")
1511
tenant = Column(String(255), nullable=True, comment="user's tenant")
1612
user_code = Column(String(255), nullable=False, comment="user code")
@@ -32,7 +28,7 @@ class MyPluginEntity(Model):
3228
UniqueConstraint("user_code", "name", name="uk_name")
3329

3430

35-
class MyPluginDao(BaseDao[MyPluginEntity]):
31+
class MyPluginDao(BaseDao):
3632
def add(self, engity: MyPluginEntity):
3733
session = self.get_raw_session()
3834
my_plugin = MyPluginEntity(
@@ -53,7 +49,7 @@ def add(self, engity: MyPluginEntity):
5349
session.close()
5450
return id
5551

56-
def update(self, entity: MyPluginEntity):
52+
def raw_update(self, entity: MyPluginEntity):
5753
session = self.get_raw_session()
5854
updated = session.merge(entity)
5955
session.commit()
@@ -128,7 +124,7 @@ def count(self, query: MyPluginEntity):
128124
session.close()
129125
return count
130126

131-
def delete(self, plugin_id: int):
127+
def raw_delete(self, plugin_id: int):
132128
session = self.get_raw_session()
133129
if plugin_id is None:
134130
raise Exception("plugin_id is None")

dbgpt/agent/db/plugin_hub_db.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,6 @@
1111

1212
class PluginHubEntity(Model):
1313
__tablename__ = "plugin_hub"
14-
__table_args__ = {
15-
"mysql_charset": "utf8mb4",
16-
"mysql_collate": "utf8mb4_unicode_ci",
17-
}
1814
id = Column(
1915
Integer, primary_key=True, autoincrement=True, comment="autoincrement id"
2016
)
@@ -36,7 +32,7 @@ class PluginHubEntity(Model):
3632
Index("idx_q_type", "type")
3733

3834

39-
class PluginHubDao(BaseDao[PluginHubEntity]):
35+
class PluginHubDao(BaseDao):
4036
def add(self, engity: PluginHubEntity):
4137
session = self.get_raw_session()
4238
timezone = pytz.timezone("Asia/Shanghai")
@@ -56,7 +52,7 @@ def add(self, engity: PluginHubEntity):
5652
session.close()
5753
return id
5854

59-
def update(self, entity: PluginHubEntity):
55+
def raw_update(self, entity: PluginHubEntity):
6056
session = self.get_raw_session()
6157
try:
6258
updated = session.merge(entity)
@@ -131,7 +127,7 @@ def count(self, query: PluginHubEntity):
131127
session.close()
132128
return count
133129

134-
def delete(self, plugin_id: int):
130+
def raw_delete(self, plugin_id: int):
135131
session = self.get_raw_session()
136132
if plugin_id is None:
137133
raise Exception("plugin_id is None")

dbgpt/agent/hub/agent_hub.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ def refresh_hub_from_git(
159159
plugin_hub_info.name = git_plugin._name
160160
plugin_hub_info.version = git_plugin._version
161161
plugin_hub_info.description = git_plugin._description
162-
self.hub_dao.update(plugin_hub_info)
162+
self.hub_dao.raw_update(plugin_hub_info)
163163
except Exception as e:
164164
raise ValueError(f"Update Agent Hub Db Info Faild!{str(e)}")
165165

@@ -194,7 +194,7 @@ async def upload_my_plugin(self, doc_file: UploadFile, user: Any = Default_User)
194194
my_plugin_entiy.user_name = user
195195
my_plugin_entiy.tenant = ""
196196
my_plugin_entiy.file_name = doc_file.filename
197-
self.my_plugin_dao.update(my_plugin_entiy)
197+
self.my_plugin_dao.raw_update(my_plugin_entiy)
198198

199199
def reload_my_plugins(self):
200200
logger.info(f"load_plugins start!")

dbgpt/app/_cli.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -108,12 +108,24 @@ def migrate(alembic_ini_path: str, script_location: str, message: str):
108108

109109
@migration.command()
110110
@add_migration_options
111-
def upgrade(alembic_ini_path: str, script_location: str):
111+
@click.option(
112+
"--sql-output",
113+
type=str,
114+
default=None,
115+
help="Generate SQL script for migration instead of applying it. ex: --sql-output=upgrade.sql",
116+
)
117+
def upgrade(alembic_ini_path: str, script_location: str, sql_output: str):
112118
"""Upgrade database to target version"""
113-
from dbgpt.util._db_migration_utils import upgrade_database
119+
from dbgpt.util._db_migration_utils import (
120+
upgrade_database,
121+
generate_sql_for_upgrade,
122+
)
114123

115124
alembic_cfg, db_manager = _get_migration_config(alembic_ini_path, script_location)
116-
upgrade_database(alembic_cfg, db_manager.engine)
125+
if sql_output:
126+
generate_sql_for_upgrade(alembic_cfg, db_manager.engine, output_file=sql_output)
127+
else:
128+
upgrade_database(alembic_cfg, db_manager.engine)
117129

118130

119131
@migration.command()
@@ -199,6 +211,7 @@ def clean(
199211
def list(alembic_ini_path: str, script_location: str):
200212
"""List all versions in the migration history, marking the current one"""
201213
from alembic.script import ScriptDirectory
214+
202215
from alembic.runtime.migration import MigrationContext
203216

204217
alembic_cfg, db_manager = _get_migration_config(alembic_ini_path, script_location)
@@ -259,8 +272,8 @@ def _get_migration_config(
259272
from dbgpt.storage.metadata.db_manager import db as db_manager
260273
from dbgpt.util._db_migration_utils import create_alembic_config
261274

262-
# Must import dbgpt_server for initialize db metadata
263-
from dbgpt.app.dbgpt_server import initialize_app as _
275+
# Import all models to make sure they are registered with SQLAlchemy.
276+
from dbgpt.app.initialization.db_model_initialization import _MODELS
264277
from dbgpt.app.base import _initialize_db
265278

266279
# initialize db

dbgpt/app/base.py

Lines changed: 34 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
from dbgpt.component import SystemApp
1111
from dbgpt.util.parameter_utils import BaseParameters
1212

13-
from dbgpt.util._db_migration_utils import _ddl_init_and_upgrade
1413

1514
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
1615
sys.path.append(ROOT_PATH)
@@ -92,10 +91,27 @@ def _initialize_db_storage(param: "WebServerParameters"):
9291
9392
Currently only sqlite and mysql are supported. If the db type is sqlite, the db path is `pilot/meta_data/{db_name}.db`.
9493
"""
95-
default_meta_data_path = _initialize_db(
96-
try_to_create_db=not param.disable_alembic_upgrade
97-
)
98-
_ddl_init_and_upgrade(default_meta_data_path, param.disable_alembic_upgrade)
94+
_initialize_db(try_to_create_db=not param.disable_alembic_upgrade)
95+
96+
97+
def _migration_db_storage(param: "WebServerParameters"):
98+
"""Migrate the db storage."""
99+
# Import all models to make sure they are registered with SQLAlchemy.
100+
from dbgpt.app.initialization.db_model_initialization import _MODELS
101+
102+
from dbgpt.configs.model_config import PILOT_PATH
103+
104+
default_meta_data_path = os.path.join(PILOT_PATH, "meta_data")
105+
if not param.disable_alembic_upgrade:
106+
from dbgpt.util._db_migration_utils import _ddl_init_and_upgrade
107+
from dbgpt.storage.metadata.db_manager import db
108+
109+
# try to create all tables
110+
try:
111+
db.create_all()
112+
except Exception as e:
113+
logger.warning(f"Create all tables stored in this metadata error: {str(e)}")
114+
_ddl_init_and_upgrade(default_meta_data_path, param.disable_alembic_upgrade)
99115

100116

101117
def _initialize_db(try_to_create_db: Optional[bool] = False) -> str:
@@ -112,7 +128,13 @@ def _initialize_db(try_to_create_db: Optional[bool] = False) -> str:
112128
default_meta_data_path = os.path.join(PILOT_PATH, "meta_data")
113129
os.makedirs(default_meta_data_path, exist_ok=True)
114130
if CFG.LOCAL_DB_TYPE == "mysql":
115-
db_url = f"mysql+pymysql://{quote(CFG.LOCAL_DB_USER)}:{urlquote(CFG.LOCAL_DB_PASSWORD)}@{CFG.LOCAL_DB_HOST}:{str(CFG.LOCAL_DB_PORT)}/{db_name}"
131+
db_url = (
132+
f"mysql+pymysql://{quote(CFG.LOCAL_DB_USER)}:"
133+
f"{urlquote(CFG.LOCAL_DB_PASSWORD)}@"
134+
f"{CFG.LOCAL_DB_HOST}:"
135+
f"{str(CFG.LOCAL_DB_PORT)}/"
136+
f"{db_name}?charset=utf8mb4&collation=utf8mb4_unicode_ci"
137+
)
116138
# Try to create database, if failed, will raise exception
117139
_create_mysql_database(db_name, db_url, try_to_create_db)
118140
else:
@@ -125,7 +147,7 @@ def _initialize_db(try_to_create_db: Optional[bool] = False) -> str:
125147
"pool_recycle": 3600,
126148
"pool_pre_ping": True,
127149
}
128-
initialize_db(db_url, db_name, engine_args, try_to_create_db=try_to_create_db)
150+
initialize_db(db_url, db_name, engine_args)
129151
return default_meta_data_path
130152

131153

@@ -161,7 +183,11 @@ def _create_mysql_database(db_name: str, db_url: str, try_to_create_db: bool = F
161183
no_db_name_url = db_url.rsplit("/", 1)[0]
162184
engine_no_db = create_engine(no_db_name_url)
163185
with engine_no_db.connect() as conn:
164-
conn.execute(DDL(f"CREATE DATABASE {db_name}"))
186+
conn.execute(
187+
DDL(
188+
f"CREATE DATABASE {db_name} CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci"
189+
)
190+
)
165191
logger.info(f"Database {db_name} successfully created")
166192
except SQLAlchemyError as e:
167193
logger.error(f"Failed to create database {db_name}: {e}")

dbgpt/app/component_configs.py

Lines changed: 6 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,13 @@
11
from __future__ import annotations
22

33
import logging
4-
from typing import TYPE_CHECKING, Any, Type
54

6-
from dbgpt.component import ComponentType, SystemApp
5+
from dbgpt.component import SystemApp
76
from dbgpt._private.config import Config
87
from dbgpt.configs.model_config import MODEL_DISK_CACHE_DIR
98
from dbgpt.util.executor_utils import DefaultExecutorFactory
10-
from dbgpt.rag.embedding_engine.embedding_factory import EmbeddingFactory
119
from dbgpt.app.base import WebServerParameters
1210

13-
if TYPE_CHECKING:
14-
from langchain.embeddings.base import Embeddings
1511

1612
logger = logging.getLogger(__name__)
1713

@@ -24,7 +20,10 @@ def initialize_components(
2420
embedding_model_name: str,
2521
embedding_model_path: str,
2622
):
23+
# Lazy import to avoid high time cost
2724
from dbgpt.model.cluster.controller.controller import controller
25+
from dbgpt.app.initialization.embedding_component import _initialize_embedding_model
26+
from dbgpt.app.initialization.serve_initialization import register_serve_apps
2827

2928
# Register global default executor factory first
3029
system_app.register(DefaultExecutorFactory)
@@ -44,97 +43,8 @@ def initialize_components(
4443
)
4544
_initialize_model_cache(system_app)
4645
_initialize_awel(system_app)
47-
48-
49-
def _initialize_embedding_model(
50-
param: WebServerParameters,
51-
system_app: SystemApp,
52-
embedding_model_name: str,
53-
embedding_model_path: str,
54-
):
55-
if param.remote_embedding:
56-
logger.info("Register remote RemoteEmbeddingFactory")
57-
system_app.register(RemoteEmbeddingFactory, model_name=embedding_model_name)
58-
else:
59-
logger.info(f"Register local LocalEmbeddingFactory")
60-
system_app.register(
61-
LocalEmbeddingFactory,
62-
default_model_name=embedding_model_name,
63-
default_model_path=embedding_model_path,
64-
)
65-
66-
67-
class RemoteEmbeddingFactory(EmbeddingFactory):
68-
def __init__(self, system_app, model_name: str = None, **kwargs: Any) -> None:
69-
super().__init__(system_app=system_app)
70-
self._default_model_name = model_name
71-
self.kwargs = kwargs
72-
self.system_app = system_app
73-
74-
def init_app(self, system_app):
75-
self.system_app = system_app
76-
77-
def create(
78-
self, model_name: str = None, embedding_cls: Type = None
79-
) -> "Embeddings":
80-
from dbgpt.model.cluster import WorkerManagerFactory
81-
from dbgpt.model.cluster.embedding.remote_embedding import RemoteEmbeddings
82-
83-
if embedding_cls:
84-
raise NotImplementedError
85-
worker_manager = self.system_app.get_component(
86-
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
87-
).create()
88-
# Ignore model_name args
89-
return RemoteEmbeddings(self._default_model_name, worker_manager)
90-
91-
92-
class LocalEmbeddingFactory(EmbeddingFactory):
93-
def __init__(
94-
self,
95-
system_app,
96-
default_model_name: str = None,
97-
default_model_path: str = None,
98-
**kwargs: Any,
99-
) -> None:
100-
super().__init__(system_app=system_app)
101-
self._default_model_name = default_model_name
102-
self._default_model_path = default_model_path
103-
self._kwargs = kwargs
104-
self._model = self._load_model()
105-
106-
def init_app(self, system_app):
107-
pass
108-
109-
def create(
110-
self, model_name: str = None, embedding_cls: Type = None
111-
) -> "Embeddings":
112-
if embedding_cls:
113-
raise NotImplementedError
114-
return self._model
115-
116-
def _load_model(self) -> "Embeddings":
117-
from dbgpt.model.cluster.embedding.loader import EmbeddingLoader
118-
from dbgpt.model.cluster.worker.embedding_worker import _parse_embedding_params
119-
from dbgpt.model.parameter import (
120-
EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG,
121-
BaseEmbeddingModelParameters,
122-
EmbeddingModelParameters,
123-
)
124-
125-
param_cls = EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG.get(
126-
self._default_model_name, EmbeddingModelParameters
127-
)
128-
model_params: BaseEmbeddingModelParameters = _parse_embedding_params(
129-
model_name=self._default_model_name,
130-
model_path=self._default_model_path,
131-
param_cls=param_cls,
132-
**self._kwargs,
133-
)
134-
logger.info(model_params)
135-
loader = EmbeddingLoader()
136-
# Ignore model_name args
137-
return loader.load(self._default_model_name, model_params)
46+
# Register serve apps
47+
register_serve_apps(system_app)
13848

13949

14050
def _initialize_model_cache(system_app: SystemApp):

0 commit comments

Comments
 (0)