diff --git a/backend/apps/datasource/crud/datasource.py b/backend/apps/datasource/crud/datasource.py index fc7afb7c..033a9607 100644 --- a/backend/apps/datasource/crud/datasource.py +++ b/backend/apps/datasource/crud/datasource.py @@ -329,7 +329,7 @@ def preview(session: SessionDep, current_user: CurrentUser, id: int, data: Table conf = DatasourceConf(**json.loads(aes_decrypt(ds.configuration))) if ds.type != "excel" else get_engine_config() sql: str = "" - if ds.type == "mysql" or ds.type == "doris" or ds.type == "starrocks": + if ds.type == "mysql" or ds.type == "doris" or ds.type == "starrocks" or ds.type == "hive": sql = f"""SELECT `{"`, `".join(fields)}` FROM `{data.table.table_name}` {where} LIMIT 100""" diff --git a/backend/apps/db/db.py b/backend/apps/db/db.py index 99e5911b..58d4a71a 100644 --- a/backend/apps/db/db.py +++ b/backend/apps/db/db.py @@ -38,7 +38,6 @@ from sqlalchemy.pool import NullPool from pyhive import hive - try: if os.path.exists(settings.ORACLE_CLIENT_PATH): oracledb.init_oracle_client( @@ -159,9 +158,10 @@ def get_engine(ds: CoreDatasource, timeout: int = 0) -> Engine: poolclass=NullPool) elif equals_ignore_case(ds.type, 'oracle'): engine = create_engine(get_uri(ds), poolclass=NullPool) - elif equals_ignore_case(ds.type, 'mysql'): # mysql + elif equals_ignore_case(ds.type, 'mysql'): # mysql ssl_mode = {"require": True} if conf.ssl else None - engine = create_engine(get_uri(ds), connect_args={"connect_timeout": conf.timeout, "ssl": ssl_mode}, poolclass=NullPool) + engine = create_engine(get_uri(ds), connect_args={"connect_timeout": conf.timeout, "ssl": ssl_mode}, + poolclass=NullPool) elif equals_ignore_case(ds.type, 'sqlite'): engine = create_engine(get_uri(ds), connect_args={"check_same_thread": False}, poolclass=NullPool) else: # ck @@ -271,7 +271,7 @@ def check_connection(trans: Optional[Trans], ds: CoreDatasource | AssistantOutDs if is_raise: raise HTTPException(status_code=500, detail=trans('i18n_ds_invalid') + f': {e.args}') return False - + elif equals_ignore_case(ds.type, 'es'): es_conn = get_es_connect(conf) if es_conn.ping(): @@ -314,7 +314,7 @@ def get_version(ds: CoreDatasource | AssistantOutDsSchema): # conf.timeout = 10 db = DB.get_db(ds.type) sql = get_version_sql(ds, conf) - if equals_ignore_case(ds.type, 'sqlite'): + if not sql: return '' try: if db.connect_type == ConnectType.sqlalchemy: @@ -397,30 +397,6 @@ def get_schema(ds: CoreDatasource): res = cursor.fetchall() res_list = [item[0] for item in res] return res_list - elif equals_ignore_case(ds.type, 'hive'): - conn = hive.connect(host=conf.host, port=conf.port, username=conf.username, - database=conf.database, **extra_config_dict) - cursor = conn.cursor() - cursor.execute('SHOW DATABASES') - res = cursor.fetchall() - res_list = [item[0] for item in res] - cursor.close() - conn.close() - return res_list - elif equals_ignore_case(ds.type, 'doris', 'starrocks'): - with pymysql.connect(user=conf.username, passwd=conf.password, host=conf.host, - port=conf.port, db=conf.database, connect_timeout=10, - read_timeout=10, **extra_config_dict) as conn, conn.cursor() as cursor: - cursor.execute('SHOW DATABASES') - res = cursor.fetchall() - res_list = [item[0] for item in res] - return res_list - elif equals_ignore_case(ds.type, 'ck'): - with get_session(ds) as session: - with session.execute(text('SHOW DATABASES')) as result: - res = result.fetchall() - res_list = [item[0] for item in res] - return res_list def get_tables(ds: CoreDatasource): @@ -456,7 +432,8 @@ def get_tables(ds: CoreDatasource): ssl_args = {'ssl': {'ssl_mode': 'REQUIRE'}} if conf.ssl else {} with pymysql.connect(user=conf.username, passwd=conf.password, host=conf.host, port=conf.port, db=conf.database, connect_timeout=conf.timeout, - read_timeout=conf.timeout, **extra_config_dict, **ssl_args) as conn, conn.cursor() as cursor: + read_timeout=conf.timeout, **extra_config_dict, + **ssl_args) as conn, conn.cursor() as cursor: cursor.execute(sql, (sql_param,)) res = cursor.fetchall() res_list = [TableSchema(*item) for item in res] @@ -527,7 +504,8 @@ def get_fields(ds: CoreDatasource, table_name: str = None): ssl_args = {'ssl': {'ssl_mode': 'REQUIRE'}} if conf.ssl else {} with pymysql.connect(user=conf.username, passwd=conf.password, host=conf.host, port=conf.port, db=conf.database, connect_timeout=conf.timeout, - read_timeout=conf.timeout, **extra_config_dict, **ssl_args) as conn, conn.cursor() as cursor: + read_timeout=conf.timeout, **extra_config_dict, + **ssl_args) as conn, conn.cursor() as cursor: cursor.execute(sql, (p1, p2)) res = cursor.fetchall() res_list = [ColumnSchema(*item) for item in res] @@ -684,7 +662,8 @@ def exec_sql(ds: CoreDatasource | AssistantOutDsSchema, sql: str, origin_column= ssl_args = {'ssl': {'ssl_mode': 'REQUIRE'}} if conf.ssl else {} with pymysql.connect(user=conf.username, passwd=conf.password, host=conf.host, port=conf.port, db=conf.database, connect_timeout=conf.timeout, - read_timeout=conf.timeout, **extra_config_dict, **ssl_args) as conn, conn.cursor() as cursor: + read_timeout=conf.timeout, **extra_config_dict, + **ssl_args) as conn, conn.cursor() as cursor: try: cursor.execute(sql) res = cursor.fetchall() diff --git a/backend/apps/db/db_sql.py b/backend/apps/db/db_sql.py index d4c3c74f..fa790c8a 100644 --- a/backend/apps/db/db_sql.py +++ b/backend/apps/db/db_sql.py @@ -29,9 +29,7 @@ def get_version_sql(ds: CoreDatasource, conf: DatasourceConf): return """ SELECT * FROM v$version """ - elif equals_ignore_case(ds.type, "redshift"): - return '' - elif equals_ignore_case(ds.type, "sqlite"): + elif equals_ignore_case(ds.type, "redshift", "sqlite", "hive"): return '' diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 727efde3..e0f345f1 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -54,7 +54,7 @@ dependencies = [ "ldap3>=2.9.1", "sqlglot>=28.6.0", "numpy==2.3.5", - "pyhive[hive]>=0.7.0", + "pyhive[hive_pure_sasl]>=0.7.0", "thrift-sasl" ]