-
Notifications
You must be signed in to change notification settings - Fork 141
V0 version of embedding ingestion core. #1964
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
53b551b
8df3d91
25a3406
880fa0d
b7e9ba9
9f23be8
844b92e
470c64b
ca07810
2a4a745
da71463
11465f2
1b071cf
fb91528
b8eee68
9eee198
da4f7a0
4f0dcbe
c061fa2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,10 @@ | ||
FROM python:3.10-slim

WORKDIR /app

# Copy and install dependencies before the source tree so this layer is
# cached across code-only changes.
COPY requirements.txt .
# --no-cache-dir keeps pip's download cache out of the final image.
RUN pip install --no-cache-dir -r requirements.txt

COPY . .

CMD ["python", "main.py"]
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
|
|
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,169 @@ | ||
| # Copyright 2026 Google LLC | ||
|
shixiao-coder marked this conversation as resolved.
|
||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| """Helper utilities for embedding workflows.""" | ||
|
|
||
| import itertools | ||
| import logging | ||
|
shixiao-coder marked this conversation as resolved.
|
||
| import time | ||
| from datetime import datetime | ||
| from google.cloud.spanner_v1.param_types import TIMESTAMP, STRING, Array, Struct, StructField | ||
|
|
||
|
|
||
| _BATCH_SIZE = 500 | ||
|
|
||
def get_latest_lock_timestamp(database):
    """Gets the latest AcquiredTimestamp from the IngestionLock table.

    Args:
      database: google.cloud.spanner.Database object.

    Returns:
      The latest AcquiredTimestamp as a datetime object, or None if no
      entries exist.

    Raises:
      Exception: Re-raises any error from the Spanner query after logging it.
    """
    time_lock_sql = "SELECT MAX(AcquiredTimestamp) FROM IngestionLock"
    try:
        with database.snapshot() as snapshot:
            results = snapshot.execute_sql(time_lock_sql)
            # MAX() over a non-empty table yields exactly one row, so the
            # first row carries the answer (it may itself be None/NULL).
            for row in results:
                return row[0]
    except Exception:
        # logging.exception records the stack trace before the error
        # propagates; lazy %-args avoid f-string formatting in log calls.
        logging.exception("Error fetching latest lock timestamp")
        raise
    # Reached only when the result stream produced no rows at all.
    return None
|
|
||
def get_updated_nodes(database, timestamp, node_types):
    """Streams subject_ids and names from the Node table updated after a time.

    Selects rows whose update_timestamp is later than the given timestamp
    (or all rows when no timestamp is provided), restricted to nodes that
    have at least one of the requested types. Yields results one at a time
    to avoid loading everything into memory.

    Args:
      database: google.cloud.spanner.Database object.
      timestamp: datetime object to filter by, or None to read all nodes.
      node_types: A list of strings representing the node types to filter by.

    Yields:
      Dictionaries mapping column names (subject_id, name, types) to values.

    Raises:
      Exception: Re-raises any error from the Spanner query after logging it.
    """
    # Only one of two fixed strings is interpolated below — no caller input
    # reaches the SQL text, so this f-string is injection-safe. The actual
    # values are bound via query parameters.
    timestamp_condition = "update_timestamp > @timestamp" if timestamp else "TRUE"

    updated_node_sql = f"""
    SELECT subject_id, name, types FROM Node
    WHERE name IS NOT NULL
    AND {timestamp_condition}
    AND EXISTS (
        SELECT 1 FROM UNNEST(types) AS t WHERE t IN UNNEST(@node_types)
    )
    """

    params = {"node_types": node_types}
    param_types = {"node_types": Array(STRING)}

    if timestamp:
        # Lazy %-args keep formatting cost out of disabled log levels.
        logging.info("Filtering valid nodes updated after %s", timestamp)
        params["timestamp"] = timestamp
        param_types["timestamp"] = TIMESTAMP
    else:
        logging.info("No timestamp provided, reading all valid nodes.")

    try:
        with database.snapshot() as snapshot:
            results = snapshot.execute_sql(
                updated_node_sql, params=params, param_types=param_types)
            # results.fields is only populated once the streamed result set
            # has started yielding, so column names are captured after the
            # first row arrives.
            fields = None
            for row in results:
                if fields is None:
                    fields = [field.name for field in results.fields]
                yield dict(zip(fields, row))
    except Exception:
        logging.exception("Error fetching updated nodes")
        raise
|
|
||
|
|
||
def filter_and_convert_nodes(nodes_generator):
    """Converts node dictionaries to tuples, dropping nodes without a name.

    Consumes the input lazily and yields results one at a time, so the full
    node set is never held in memory.

    Args:
      nodes_generator: An iterable of dictionaries with subject_id, name,
        and types keys.

    Yields:
      Tuples of (subject_id, embedding_content, types).
    """
    for record in nodes_generator:
        name = record.get("name")
        if not name:
            # Skip nodes whose name is missing, None, or empty.
            continue
        yield record.get("subject_id"), name, record.get("types")
|
|
||
|
|
||
def generate_embeddings_partitioned(database, nodes_generator):
    """Generates embeddings in batches using standard transactions.

    Processes nodes in chunks of _BATCH_SIZE to stay under transaction size
    limits, and accepts a generator so the full node set is never
    materialized in memory.

    Args:
      database: google.cloud.spanner.Database object.
      nodes_generator: A generator yielding (subject_id, embedding_content,
        types) tuples.

    Returns:
      The total number of affected rows across all batches.

    Raises:
      Exception: Re-raises any error from a batch transaction after logging it.
    """
    # _BATCH_SIZE is only read here, so no `global` declaration is needed
    # (the original declared one needlessly).
    total_rows_affected = 0

    logging.info("Generating embeddings in batches of %d.", _BATCH_SIZE)

    embeddings_sql = """
    INSERT OR UPDATE INTO NodeEmbeddings (subject_id, embedding_content, embeddings, types)
    SELECT subject_id, content, embeddings.values, types
    FROM ML.PREDICT(
        MODEL text_embeddings,
        (SELECT subject_id, embedding_content AS content, types, "RETRIEVAL_QUERY" AS task_type FROM UNNEST(@nodes))
    )
    """

    struct_type = Struct([
        StructField("subject_id", STRING),
        StructField("embedding_content", STRING),
        StructField("types", Array(STRING)),
    ])

    def chunked(iterable, n):
        # Yields successive lists of up to n items from any iterable.
        it = iter(iterable)
        while chunk := list(itertools.islice(it, n)):
            yield chunk

    for batch in chunked(nodes_generator, _BATCH_SIZE):
        params = {"nodes": batch}
        param_types = {"nodes": Array(struct_type)}

        # Default-argument binding freezes this batch's params for the
        # callback: run_in_transaction may retry it, and a late-binding
        # closure would otherwise see whatever the loop variables hold.
        def _execute_dml(transaction, params=params, param_types=param_types):
            return transaction.execute_update(
                embeddings_sql, params=params, param_types=param_types,
                timeout=300)

        try:
            row_count = database.run_in_transaction(_execute_dml)
            total_rows_affected += row_count
            logging.info("Processed batch of %d nodes. Affected %d rows.",
                         len(batch), row_count)
            # Brief pause between batches to avoid hammering the backend.
            time.sleep(0.5)
        except Exception:
            logging.exception("Error executing batch transaction")
            raise

    logging.info("Completed batch processing. Total affected rows: %d",
                 total_rows_affected)
    return total_rows_affected
|
|
||
|
|
||
|
|
||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,55 @@ | ||
| # Copyright 2026 Google LLC | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
import logging
import os
import sys

from google.cloud import spanner

from embedding_utils import (
    filter_and_convert_nodes,
    generate_embeddings_partitioned,
    get_latest_lock_timestamp,
    get_updated_nodes,
)
|
|
||
| logging.basicConfig(level=logging.INFO) | ||
|
|
||
def main():
    """Entry point: reads config from the environment and runs the ingestion job."""
    # Read configuration from environment variables.
    instance_id = os.environ.get("SPANNER_INSTANCE")
    database_id = os.environ.get("SPANNER_DATABASE")
    project_id = os.environ.get("SPANNER_PROJECT")

    if not instance_id or not database_id:
        logging.error(
            "SPANNER_INSTANCE or SPANNER_DATABASE environment variables not set.")
        # sys.exit is preferred over the site-provided exit() builtin in
        # scripts; it raises SystemExit with the given status code.
        sys.exit(1)

    logging.info(
        "Connecting to Spanner instance: %s, database: %s, project: %s",
        instance_id, database_id, project_id)

    # project_id may be None, in which case the client falls back to the
    # environment's default project — presumably intentional; TODO confirm.
    spanner_client = spanner.Client(project=project_id)
    instance = spanner_client.instance(instance_id)
    database = instance.database(database_id)

    # Only these node types receive embeddings in the V0 ingestion flow.
    node_types = ["StatisticalVariable", "Topic"]

    try:
        logging.info("Job started. Fetching all nodes for types: %s", node_types)
        timestamp = get_latest_lock_timestamp(database)
        nodes = get_updated_nodes(database, timestamp, node_types)
        converted_nodes = filter_and_convert_nodes(nodes)
        affected_rows = generate_embeddings_partitioned(database, converted_nodes)
        logging.info("Job completed successfully. Total affected rows: %s",
                     affected_rows)
    except Exception:
        # logging.exception records the traceback before the non-zero exit;
        # the original logging.error(f"...") discarded the stack trace.
        logging.exception("Job failed")
        sys.exit(1)


if __name__ == "__main__":
    main()
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,3 @@ | ||
| functions-framework==3.* | ||
| google-cloud-spanner | ||
| google-auth |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Curious why we have this Dockerfile — do we want to deploy this as a Cloud Run server instead of a Cloud Function?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The runtime is fairly long (50–60 minutes), and processing embedding data is really a batch job, which might be a bit heavy for a Cloud Function.
I'm open to switching to a function if needed.
I'm open to switch to function if needed.