diff --git a/pyobvector/client/ob_client.py b/pyobvector/client/ob_client.py index 1254b04..8a0c010 100644 --- a/pyobvector/client/ob_client.py +++ b/pyobvector/client/ob_client.py @@ -158,6 +158,19 @@ def _is_seekdb(self) -> bool: logger.warning(f"Failed to query version: {e}") return is_seekdb + def _flush_seekdb_index(self) -> None: + """Flush async HNSW index builds in embedded seekdb after insert. + + No-op when not using embedded seekdb or when the server does not expose + a ``refresh_index`` method. + """ + server = self.engine.get_execution_options().get("seekdb_server") + if server is not None and hasattr(server, "refresh_index"): + try: + server.refresh_index() + except Exception as e: + logger.warning("seekdb index refresh failed after insert: %s", e) + def _insert_partition_hint_for_query_sql(self, sql: str, partition_hint: str): from_index = sql.find("FROM") assert from_index != -1 @@ -282,6 +295,7 @@ def insert( .with_hint(f"PARTITION({partition_name})") .values(data) ) + self._flush_seekdb_index() def upsert( self, diff --git a/pyobvector/client/ob_vec_client.py b/pyobvector/client/ob_vec_client.py index be1f192..8deece0 100644 --- a/pyobvector/client/ob_vec_client.py +++ b/pyobvector/client/ob_vec_client.py @@ -408,11 +408,11 @@ def ann_search( ) if partition_names is None: - return conn.execute(text(stmt_str)) + return conn.execute(text(stmt_str)).freeze()() stmt_str = self._insert_partition_hint_for_query_sql( stmt_str, f"PARTITION({', '.join(partition_names)})" ) - return conn.execute(text(stmt_str)) + return conn.execute(text(stmt_str)).freeze()() def post_ann_search( self, @@ -487,7 +487,7 @@ def post_ann_search( ) ) ) - return conn.execute(stmt) + return conn.execute(stmt).freeze()() stmt_str = str( stmt.compile( dialect=self.engine.dialect, @@ -499,7 +499,7 @@ def post_ann_search( ) if str_list is not None: str_list.append(stmt_str) - return conn.execute(text(stmt_str)) + return conn.execute(text(stmt_str)).freeze()() def precise_search( self, @@ -537,7 +537,7 @@ def precise_search( stmt = stmt.where(*where_clause) with self.engine.connect() as conn: with conn.begin(): - return conn.execute(stmt) + return conn.execute(stmt).freeze()() else: stmt = ( select(table) @@ -548,4 +548,4 @@ def precise_search( stmt = stmt.where(*where_clause) with self.engine.connect() as conn: with conn.begin(): - return conn.execute(stmt) + return conn.execute(stmt).freeze()() diff --git a/pyobvector/client/seekdb_engine.py b/pyobvector/client/seekdb_engine.py index d6188a1..aa82a92 100644 --- a/pyobvector/client/seekdb_engine.py +++ b/pyobvector/client/seekdb_engine.py @@ -5,6 +5,7 @@ Requires optional dependency: pip install pyobvector[pyseekdb] """ +import logging import re from collections.abc import Mapping, Sequence from typing import Any @@ -12,6 +13,37 @@ from sqlalchemy import create_engine from sqlalchemy.pool import NullPool +logger = logging.getLogger(__name__) + +_QUERY_SQL_PREFIXES = ("SELECT", "SHOW", "DESCRIBE", "DESC") + + +def _is_query_sql(sql: str) -> bool: + """Return True if sql is a row-returning statement (SELECT/SHOW/DESCRIBE/DESC).""" + return sql.strip().upper().startswith(_QUERY_SQL_PREFIXES) + + +def _description_from_select(sql: str) -> list[tuple]: + """Extract DBAPI description tuples from SELECT column list using sqlglot. + + Falls back to [] when sqlglot cannot parse the statement (e.g. SHOW/DESCRIBE). + """ + try: + import sqlglot + import sqlglot.expressions as exp + + parsed = sqlglot.parse_one(sql) + if parsed is None or not hasattr(parsed, "selects"): + return [] + cols = [] + for sel in parsed.selects: + name = sel.alias or (sel.name if isinstance(sel, exp.Column) else sel.sql()) + cols.append(name) + return [(c, None, None, None, None, None, None) for c in cols] + except Exception as e: + logger.debug("_description_from_select could not parse SQL: %s", e) + return [] + def _pyformat_to_format(sql: str, params: Any) -> tuple[str, list[Any]]: """Convert SQLAlchemy pyformat (%(name)s) + dict params to %s + list for pyseekdb.""" @@ -50,7 +82,14 @@ def __init__(self, client: Any) -> None: def execute(self, operation: str, parameters: Sequence[Any] | None = None) -> None: result = _execute_via_pyseekdb(self._client, operation, parameters or ()) if not result: - self._description = None + # For SELECT/SHOW/DESCRIBE: populate description so SQLAlchemy treats this + # as a row-returning result with 0 rows rather than a non-returning statement. + # _NoResultMetaData (from None) would cause ResourceClosedError on fetchall(). + self._description = ( + _description_from_select(operation) + if _is_query_sql(operation) + else None + ) self._rows = [] self.rowcount = 0 return @@ -135,12 +174,14 @@ def create_engine_from_client(pyseekdb_client: Any, **kwargs: Any): def creator() -> _SeekdbConnection: return _SeekdbConnection(server) - return create_engine( + engine = create_engine( "mysql+oceanbase://root:@127.0.0.1:2881/" + database, creator=creator, poolclass=NullPool, **kwargs, ) + engine.update_execution_options(seekdb_server=server) + return engine def create_embedded_engine(path: str, database: str = "test", **kwargs: Any): diff --git a/tests/test_fts_index.py b/tests/test_fts_index.py index 10bed93..98245f4 100644 --- a/tests/test_fts_index.py +++ b/tests/test_fts_index.py @@ -559,7 +559,7 @@ def test_fts_analyzer_sql_with_properties(self): idx = self._build_index(props) sql = compile_create_fts_index(CreateFtsIndex(idx), _MockCompiler()) self.assertIn("WITH PARSER analyzer", sql) - self.assertIn(f"PARSER_PROPERTIES = ({props})", sql) + self.assertIn(f"PARSER_PROPERTIES=({props})", sql) def test_fts_analyzer_requires_parser_properties(self): param = FtsIndexParam( @@ -587,7 +587,7 @@ def test_fts_parser_properties_not_tied_to_analyzer_type(self): ) sql = compile_create_fts_index(CreateFtsIndex(idx), _MockCompiler()) self.assertIn("WITH PARSER ngram", sql) - self.assertIn("PARSER_PROPERTIES = (token_size = 2)", sql) + self.assertIn("PARSER_PROPERTIES=(token_size = 2)", sql) def test_make_analyzer_properties_default(self): result = make_analyzer_properties()