Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions pyobvector/client/ob_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,19 @@ def _is_seekdb(self) -> bool:
logger.warning(f"Failed to query version: {e}")
return is_seekdb

def _flush_seekdb_index(self) -> None:
"""Flush async HNSW index builds in embedded seekdb after insert.

No-op when not using embedded seekdb or when the server does not expose
a ``refresh_index`` method.
"""
Comment thread
whhe marked this conversation as resolved.
server = self.engine.get_execution_options().get("seekdb_server")
if server is not None and hasattr(server, "refresh_index"):
try:
server.refresh_index()
except Exception as e:
logger.warning("seekdb index refresh failed after insert: %s", e)

def _insert_partition_hint_for_query_sql(self, sql: str, partition_hint: str):
from_index = sql.find("FROM")
assert from_index != -1
Expand Down Expand Up @@ -282,6 +295,7 @@ def insert(
.with_hint(f"PARTITION({partition_name})")
.values(data)
)
self._flush_seekdb_index()

def upsert(
self,
Expand Down
12 changes: 6 additions & 6 deletions pyobvector/client/ob_vec_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,11 +408,11 @@ def ann_search(
)

if partition_names is None:
return conn.execute(text(stmt_str))
return conn.execute(text(stmt_str)).freeze()()
stmt_str = self._insert_partition_hint_for_query_sql(
stmt_str, f"PARTITION({', '.join(partition_names)})"
)
return conn.execute(text(stmt_str))
return conn.execute(text(stmt_str)).freeze()()

def post_ann_search(
self,
Expand Down Expand Up @@ -487,7 +487,7 @@ def post_ann_search(
)
)
)
return conn.execute(stmt)
return conn.execute(stmt).freeze()()
stmt_str = str(
stmt.compile(
dialect=self.engine.dialect,
Expand All @@ -499,7 +499,7 @@ def post_ann_search(
)
if str_list is not None:
str_list.append(stmt_str)
return conn.execute(text(stmt_str))
return conn.execute(text(stmt_str)).freeze()()

def precise_search(
self,
Expand Down Expand Up @@ -537,7 +537,7 @@ def precise_search(
stmt = stmt.where(*where_clause)
with self.engine.connect() as conn:
with conn.begin():
return conn.execute(stmt)
return conn.execute(stmt).freeze()()
else:
stmt = (
select(table)
Expand All @@ -548,4 +548,4 @@ def precise_search(
stmt = stmt.where(*where_clause)
with self.engine.connect() as conn:
with conn.begin():
return conn.execute(stmt)
return conn.execute(stmt).freeze()()
45 changes: 43 additions & 2 deletions pyobvector/client/seekdb_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,45 @@
Requires optional dependency: pip install pyobvector[pyseekdb]
"""

import logging
import re
from collections.abc import Mapping, Sequence
from typing import Any

from sqlalchemy import create_engine
from sqlalchemy.pool import NullPool

logger = logging.getLogger(__name__)

_QUERY_SQL_PREFIXES = ("SELECT", "SHOW", "DESCRIBE", "DESC")


def _is_query_sql(sql: str) -> bool:
"""Return True if sql is a row-returning statement (SELECT/SHOW/DESCRIBE/DESC)."""
return sql.strip().upper().startswith(_QUERY_SQL_PREFIXES)


def _description_from_select(sql: str) -> list[tuple]:
"""Extract DBAPI description tuples from SELECT column list using sqlglot.

Falls back to [] when sqlglot cannot parse the statement (e.g. SHOW/DESCRIBE).
"""
try:
import sqlglot
import sqlglot.expressions as exp

parsed = sqlglot.parse_one(sql)
if parsed is None or not hasattr(parsed, "selects"):
return []
cols = []
for sel in parsed.selects:
name = sel.alias or (sel.name if isinstance(sel, exp.Column) else sel.sql())
cols.append(name)
return [(c, None, None, None, None, None, None) for c in cols]
except Exception as e:
logger.debug("_description_from_select could not parse SQL: %s", e)
return []


def _pyformat_to_format(sql: str, params: Any) -> tuple[str, list[Any]]:
"""Convert SQLAlchemy pyformat (%(name)s) + dict params to %s + list for pyseekdb."""
Expand Down Expand Up @@ -50,7 +82,14 @@ def __init__(self, client: Any) -> None:
def execute(self, operation: str, parameters: Sequence[Any] | None = None) -> None:
result = _execute_via_pyseekdb(self._client, operation, parameters or ())
if not result:
self._description = None
# For SELECT/SHOW/DESCRIBE: populate description so SQLAlchemy treats this
# as a row-returning result with 0 rows rather than a non-returning statement.
# _NoResultMetaData (from None) would cause ResourceClosedError on fetchall().
self._description = (
_description_from_select(operation)
if _is_query_sql(operation)
else None
)
self._rows = []
self.rowcount = 0
return
Expand Down Expand Up @@ -135,12 +174,14 @@ def create_engine_from_client(pyseekdb_client: Any, **kwargs: Any):
def creator() -> _SeekdbConnection:
return _SeekdbConnection(server)

return create_engine(
engine = create_engine(
"mysql+oceanbase://root:@127.0.0.1:2881/" + database,
creator=creator,
poolclass=NullPool,
**kwargs,
)
engine.update_execution_options(seekdb_server=server)
return engine


def create_embedded_engine(path: str, database: str = "test", **kwargs: Any):
Expand Down
4 changes: 2 additions & 2 deletions tests/test_fts_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,7 +559,7 @@ def test_fts_analyzer_sql_with_properties(self):
idx = self._build_index(props)
sql = compile_create_fts_index(CreateFtsIndex(idx), _MockCompiler())
self.assertIn("WITH PARSER analyzer", sql)
self.assertIn(f"PARSER_PROPERTIES = ({props})", sql)
self.assertIn(f"PARSER_PROPERTIES=({props})", sql)

def test_fts_analyzer_requires_parser_properties(self):
param = FtsIndexParam(
Expand Down Expand Up @@ -587,7 +587,7 @@ def test_fts_parser_properties_not_tied_to_analyzer_type(self):
)
sql = compile_create_fts_index(CreateFtsIndex(idx), _MockCompiler())
self.assertIn("WITH PARSER ngram", sql)
self.assertIn("PARSER_PROPERTIES = (token_size = 2)", sql)
self.assertIn("PARSER_PROPERTIES=(token_size = 2)", sql)

def test_make_analyzer_properties_default(self):
result = make_analyzer_properties()
Expand Down
Loading