From e716b8dbb71d9c98f4a0e941e62d918cf138e56b Mon Sep 17 00:00:00 2001 From: Ben Constable Date: Fri, 27 Dec 2024 15:53:37 +0000 Subject: [PATCH 01/15] Update lock file --- uv.lock | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/uv.lock b/uv.lock index e2acbb2..1130f44 100644 --- a/uv.lock +++ b/uv.lock @@ -3301,6 +3301,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f1/7b/ce1eafaf1a76852e2ec9b22edecf1daa58175c090266e9f6c64afcd81d91/stack_data-0.6.3-py3-none-any.whl", hash = "sha256:d5558e0c25a4cb0853cddad3d77da9891a08cb85dd9f9f91b9f8cd66e511e695", size = 24521 }, ] +[[package]] +name = "tabulate" +version = "0.9.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ec/fe/802052aecb21e3797b8f7902564ab6ea0d60ff8ca23952079064155d1ae1/tabulate-0.9.0.tar.gz", hash = "sha256:0095b12bf5966de529c0feb1fa08671671b3368eec77d7ef7ab114be2c068b3c", size = 81090 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/40/44/4a5f08c96eb108af5cb50b41f76142f0afa346dfa99d5296fe7202a11854/tabulate-0.9.0-py3-none-any.whl", hash = "sha256:024ca478df22e9340661486f85298cff5f6dcdba14f3813e8830015b9ed1948f", size = 35252 }, +] + [[package]] name = "tenacity" version = "9.0.0" @@ -3337,11 +3346,13 @@ dependencies = [ { name = "networkx" }, { name = "numpy" }, { name = "openai" }, + { name = "pandas" }, { name = "pydantic" }, { name = "python-dotenv" }, { name = "pyyaml" }, { name = "rich" }, { name = "sqlglot", extra = ["rs"] }, + { name = "tabulate" }, { name = "tenacity" }, { name = "typer" }, ] @@ -3377,6 +3388,7 @@ requires-dist = [ { name = "networkx", specifier = ">=3.4.2" }, { name = "numpy", specifier = "<2.0.0" }, { name = "openai", specifier = ">=1.55.3" }, + { name = "pandas", specifier = ">=2.2.3" }, { name = "pyarrow", marker = "extra == 'databricks'", specifier = ">=14.0.2,<17" }, { name = "pydantic", specifier = ">=2.10.2" }, { name = "python-dotenv", specifier = ">=1.0.1" }, @@ -3384,6 +3396,7 @@ requires-dist = [ { name = "rich", specifier = ">=13.9.4" }, { name = "snowflake-connector-python", marker = "extra == 'snowflake'", specifier = ">=3.12.3" }, { name = "sqlglot", extras = ["rs"], specifier = ">=25.32.1" }, + { name = "tabulate", specifier = ">=0.9.0" }, { name = "tenacity", specifier = ">=9.0.0" }, { name = "typer", specifier = ">=0.14.0" }, ] From 1132cd0374d7284bfa305170f5aed3c55a7b6958 Mon Sep 17 00:00:00 2001 From: Ben Constable Date: Fri, 27 Dec 2024 16:25:15 +0000 Subject: [PATCH 02/15] Parallel solving --- .../autogen_text_2_sql/autogen_text_2_sql.py | 121 ++--------- .../parallel_query_solving_agent.py | 74 +++++++ .../inner_autogen_text_2_sql.py | 189 ++++++++++++++++++ 3 files changed, 277 insertions(+), 107 deletions(-) create mode 100644 text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py create mode 100644 text_2_sql/autogen/src/autogen_text_2_sql/inner_autogen_text_2_sql.py diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py b/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py index f6d4e98..0c9e1ba 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py @@ -1,7 +1,5 @@ -""" -Copyright (c) Microsoft Corporation. -Licensed under the MIT License. -""" +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. from autogen_agentchat.conditions import ( TextMentionTermination, MaxMessageTermination, @@ -10,11 +8,8 @@ from autogen_text_2_sql.creators.llm_model_creator import LLMModelCreator from autogen_text_2_sql.creators.llm_agent_creator import LLMAgentCreator import logging -from autogen_text_2_sql.custom_agents.sql_query_cache_agent import ( - SqlQueryCacheAgent, -) -from autogen_text_2_sql.custom_agents.sql_schema_selection_agent import ( - SqlSchemaSelectionAgent, +from autogen_text_2_sql.custom_agents.parallel_query_solving_agent import ( + ParallelQuerySolvingAgent, ) from autogen_text_2_sql.custom_agents.answer_and_sources_agent import ( AnswerAndSourcesAgent, @@ -45,23 +40,9 @@ async def on_messages_stream(self, messages, sender=None, config=None): class AutoGenText2Sql: def __init__(self, engine_specific_rules: str, **kwargs: dict): - self.pre_run_query_cache = False self.target_engine = os.environ["Text2Sql__DatabaseEngine"].upper() self.engine_specific_rules = engine_specific_rules self.kwargs = kwargs - self.set_mode() - - def set_mode(self): - """Set the mode of the plugin based on the environment variables.""" - self.pre_run_query_cache = ( - os.environ.get("Text2Sql__PreRunQueryCache", "True").lower() == "true" - ) - self.use_column_value_store = ( - os.environ.get("Text2Sql__UseColumnValueStore", "True").lower() == "true" - ) - self.use_query_cache = ( - os.environ.get("Text2Sql__UseQueryCache", "True").lower() == "true" - ) def get_all_agents(self): """Get all agents for the complete flow.""" @@ -72,43 +53,8 @@ def get_all_agents(self): "query_rewrite_agent", current_datetime=current_datetime ) - self.sql_query_generation_agent = LLMAgentCreator.create( - "sql_query_generation_agent", - target_engine=self.target_engine, - engine_specific_rules=self.engine_specific_rules, - **self.kwargs, - ) - - # If relationship_paths not provided, use a generic template - if "relationship_paths" not in self.kwargs: - self.kwargs[ - "relationship_paths" - ] = """ - Common relationship paths to consider: - - Transaction → Related Dimensions (for basic analysis) - - Geographic → Location hierarchies (for geographic analysis) - - Temporal → Date hierarchies (for time-based analysis) - - Entity → Attributes (for entity-specific analysis) - """ - - self.sql_schema_selection_agent = SqlSchemaSelectionAgent( - target_engine=self.target_engine, - engine_specific_rules=self.engine_specific_rules, - **self.kwargs, - ) - - self.sql_query_correction_agent = LLMAgentCreator.create( - "sql_query_correction_agent", - target_engine=self.target_engine, - engine_specific_rules=self.engine_specific_rules, - **self.kwargs, - ) - - self.sql_disambiguation_agent = LLMAgentCreator.create( - "sql_disambiguation_agent", - target_engine=self.target_engine, - engine_specific_rules=self.engine_specific_rules, - **self.kwargs, + self.parallel_query_solving_agent = ParallelQuerySolvingAgent( + engine_specific_rules=self.engine_specific_rules, **self.kwargs ) self.answer_and_sources_agent = AnswerAndSourcesAgent() @@ -119,17 +65,10 @@ def get_all_agents(self): agents = [ self.user_proxy, self.query_rewrite_agent, - self.sql_query_generation_agent, - self.sql_schema_selection_agent, - self.sql_query_correction_agent, - self.sql_disambiguation_agent, + self.parallel_query_solving_agent, self.answer_and_sources_agent, ] - if self.use_query_cache: - self.query_cache_agent = SqlQueryCacheAgent() - agents.append(self.query_cache_agent) - return agents @property @@ -149,51 +88,19 @@ def unified_selector(self, messages): decision = None # If this is the first message start with query_rewrite_agent - if len(messages) == 1: + if current_agent == "start": decision = "query_rewrite_agent" # Handle transition after query rewriting elif current_agent == "query_rewrite_agent": - decision = ( - "sql_query_cache_agent" - if self.use_query_cache - else "sql_schema_selection_agent" - ) - # Handle subsequent agent transitions - elif current_agent == "sql_query_cache_agent": - # Always go through schema selection after cache check - decision = "sql_schema_selection_agent" - elif current_agent == "sql_schema_selection_agent": - decision = "sql_disambiguation_agent" - elif current_agent == "sql_disambiguation_agent": - decision = "sql_query_generation_agent" - elif current_agent == "sql_query_generation_agent": - decision = "sql_query_correction_agent" - elif current_agent == "sql_query_correction_agent": - try: - correction_result = json.loads(messages[-1].content) - if isinstance(correction_result, dict): - if "answer" in correction_result and "sources" in correction_result: - decision = "answer_and_sources_agent" - elif "corrected_query" in correction_result: - if correction_result.get("executing", False): - decision = "sql_query_correction_agent" - else: - decision = "sql_query_generation_agent" - elif "error" in correction_result: - decision = "sql_query_generation_agent" - elif isinstance(correction_result, list) and len(correction_result) > 0: - if "requested_fix" in correction_result[0]: - decision = "sql_query_generation_agent" - - if decision is None: - decision = "sql_query_generation_agent" - except json.JSONDecodeError: - decision = "sql_query_generation_agent" - elif current_agent == "answer_and_sources_agent": - decision = "user_proxy" # Let user_proxy send TERMINATE + decision = "parallel_query_solving_agent" + # Handle transition after parallel query solving + elif current_agent == "parallel_query_solving_agent": + decision = "answer_and_sources_agent" if decision: logging.info(f"Agent transition: {current_agent} -> {decision}") + else: + logging.info(f"No agent transition defined from {current_agent}") return decision diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py b/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py new file mode 100644 index 0000000..e50c79b --- /dev/null +++ b/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py @@ -0,0 +1,74 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from typing import AsyncGenerator, List, Sequence + +from autogen_agentchat.agents import BaseChatAgent +from autogen_agentchat.base import Response +from autogen_agentchat.messages import AgentMessage, ChatMessage, TextMessage +from autogen_core import CancellationToken +import json +import logging +import asyncio +from autogen_text_2_sql.inner_autogen_text_2_sql import InnerAutoGenText2Sql + + +class ParallelQuerySolvingAgent(BaseChatAgent): + def __init__(self, engine_specific_rules: str, **kwargs: dict): + super().__init__( + "parallel_query_solving_agent", + "An agent that solves each query in parallel.", + ) + + self.engine_specific_rules = engine_specific_rules + self.kwargs = kwargs + + @property + def produced_message_types(self) -> List[type[ChatMessage]]: + return [TextMessage] + + async def on_messages( + self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken + ) -> Response: + # Calls the on_messages_stream. + response: Response | None = None + async for message in self.on_messages_stream(messages, cancellation_token): + if isinstance(message, Response): + response = message + assert response is not None + return response + + async def on_messages_stream( + self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken + ) -> AsyncGenerator[AgentMessage | Response, None]: + last_response = messages[-1].content + + # Load the json of the last message to populate the final output object + query_rewrites = json.loads(last_response) + + logging.info(f"Query Rewrite: {query_rewrites}") + + inner_solving_tasks = [] + + for query_rewrite in query_rewrites: + # Create an instance of the InnerAutoGenText2Sql class + inner_autogen_text_2_sql = InnerAutoGenText2Sql( + self.engine_specific_rules, **self.kwargs + ) + + inner_solving_tasks.append( + inner_autogen_text_2_sql.run_stream(task=query_rewrite) + ) + + # Wait for all the inner solving tasks to complete + inner_solving_results = await asyncio.gather(*inner_solving_tasks) + + logging.info(f"Inner Solving Results: {inner_solving_results}") + + yield Response( + chat_message=TextMessage( + content=json.dumps(inner_solving_results), source=self.name + ) + ) + + async def on_reset(self, cancellation_token: CancellationToken) -> None: + pass diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/inner_autogen_text_2_sql.py b/text_2_sql/autogen/src/autogen_text_2_sql/inner_autogen_text_2_sql.py new file mode 100644 index 0000000..3b4926e --- /dev/null +++ b/text_2_sql/autogen/src/autogen_text_2_sql/inner_autogen_text_2_sql.py @@ -0,0 +1,189 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from autogen_agentchat.conditions import ( + TextMentionTermination, + MaxMessageTermination, +) +from autogen_agentchat.teams import SelectorGroupChat +from autogen_text_2_sql.creators.llm_model_creator import LLMModelCreator +from autogen_text_2_sql.creators.llm_agent_creator import LLMAgentCreator +import logging +from autogen_text_2_sql.custom_agents.sql_query_cache_agent import ( + SqlQueryCacheAgent, +) +from autogen_text_2_sql.custom_agents.sql_schema_selection_agent import ( + SqlSchemaSelectionAgent, +) +from autogen_agentchat.agents import UserProxyAgent +from autogen_agentchat.messages import TextMessage +from autogen_agentchat.base import Response +import json +import os + + +class EmptyResponseUserProxyAgent(UserProxyAgent): + """UserProxyAgent that automatically responds with empty messages.""" + + def __init__(self, name): + super().__init__(name=name) + self._has_responded = False + + async def on_messages_stream(self, messages, sender=None, config=None): + """Auto-respond with empty message and return Response object.""" + message = TextMessage(content="", source=self.name) + if not self._has_responded: + self._has_responded = True + yield message + yield Response(chat_message=message) + + +class InnerAutoGenText2Sql: + def __init__(self, engine_specific_rules: str, **kwargs: dict): + self.pre_run_query_cache = False + self.target_engine = os.environ["Text2Sql__DatabaseEngine"].upper() + self.engine_specific_rules = engine_specific_rules + self.kwargs = kwargs + self.set_mode() + + def set_mode(self): + """Set the mode of the plugin based on the environment variables.""" + self.pre_run_query_cache = ( + os.environ.get("Text2Sql__PreRunQueryCache", "True").lower() == "true" + ) + self.use_column_value_store = ( + os.environ.get("Text2Sql__UseColumnValueStore", "True").lower() == "true" + ) + self.use_query_cache = ( + os.environ.get("Text2Sql__UseQueryCache", "True").lower() == "true" + ) + + def get_all_agents(self): + """Get all agents for the complete flow.""" + # Get current datetime for the Query Rewrite Agent + self.sql_query_generation_agent = LLMAgentCreator.create( + "sql_query_generation_agent", + target_engine=self.target_engine, + engine_specific_rules=self.engine_specific_rules, + **self.kwargs, + ) + + # If relationship_paths not provided, use a generic template + if "relationship_paths" not in self.kwargs: + self.kwargs[ + "relationship_paths" + ] = """ + Common relationship paths to consider: + - Transaction → Related Dimensions (for basic analysis) + - Geographic → Location hierarchies (for geographic analysis) + - Temporal → Date hierarchies (for time-based analysis) + - Entity → Attributes (for entity-specific analysis) + """ + + self.sql_schema_selection_agent = SqlSchemaSelectionAgent( + target_engine=self.target_engine, + engine_specific_rules=self.engine_specific_rules, + **self.kwargs, + ) + + self.sql_query_correction_agent = LLMAgentCreator.create( + "sql_query_correction_agent", + target_engine=self.target_engine, + engine_specific_rules=self.engine_specific_rules, + **self.kwargs, + ) + + self.sql_disambiguation_agent = LLMAgentCreator.create( + "sql_disambiguation_agent", + target_engine=self.target_engine, + engine_specific_rules=self.engine_specific_rules, + **self.kwargs, + ) + + # Auto-responding UserProxyAgent + self.user_proxy = EmptyResponseUserProxyAgent(name="user_proxy") + + agents = [ + self.user_proxy, + self.sql_query_generation_agent, + self.sql_schema_selection_agent, + self.sql_query_correction_agent, + self.sql_disambiguation_agent, + ] + + if self.use_query_cache: + self.query_cache_agent = SqlQueryCacheAgent() + agents.append(self.query_cache_agent) + + return agents + + @property + def termination_condition(self): + """Define the termination condition for the chat.""" + termination = TextMentionTermination("TERMINATE") | MaxMessageTermination(10) + return termination + + def unified_selector(self, messages): + """Unified selector for the complete flow.""" + """Unified selector for the complete flow.""" + logging.info("Messages: %s", messages) + current_agent = messages[-1].source if messages else "start" + decision = None + + if current_agent == "start": + decision = ( + "sql_query_cache_agent" + if self.use_query_cache + else "sql_schema_selection_agent" + ) + # Handle subsequent agent transitions + elif current_agent == "sql_query_cache_agent": + # Always go through schema selection after cache check + decision = "sql_schema_selection_agent" + elif current_agent == "sql_schema_selection_agent": + decision = "sql_disambiguation_agent" + elif current_agent == "sql_disambiguation_agent": + decision = "sql_query_generation_agent" + elif current_agent == "sql_query_generation_agent": + decision = "sql_query_correction_agent" + elif current_agent == "sql_query_correction_agent": + try: + correction_result = json.loads(messages[-1].content) + if isinstance(correction_result, dict): + if "answer" in correction_result and "sources" in correction_result: + decision = "answer_and_sources_agent" + elif "corrected_query" in correction_result: + if correction_result.get("executing", False): + decision = "sql_query_correction_agent" + else: + decision = "sql_query_generation_agent" + elif "error" in correction_result: + decision = "sql_query_generation_agent" + elif isinstance(correction_result, list) and len(correction_result) > 0: + if "requested_fix" in correction_result[0]: + decision = "sql_query_generation_agent" + + if decision is None: + decision = "sql_query_generation_agent" + except json.JSONDecodeError: + decision = "sql_query_generation_agent" + elif current_agent == "answer_and_sources_agent": + decision = "user_proxy" # Let user_proxy send TERMINATE + + if decision: + logging.info(f"Agent transition: {current_agent} -> {decision}") + else: + logging.info(f"No agent transition defined from {current_agent}") + + return decision + + @property + def agentic_flow(self): + """Create the unified flow for the complete process.""" + flow = SelectorGroupChat( + self.get_all_agents(), + allow_repeated_speaker=False, + model_client=LLMModelCreator.get_model("4o-mini"), + termination_condition=self.termination_condition, + selector_func=self.unified_selector, + ) + return flow From 124c491fe9a88f851846796090a14f409427ae82 Mon Sep 17 00:00:00 2001 From: Ben Constable Date: Fri, 27 Dec 2024 16:36:54 +0000 Subject: [PATCH 03/15] Separate out sources and answer agent --- text_2_sql/autogen/src/__init__.py | 0 .../autogen_text_2_sql/autogen_text_2_sql.py | 20 ++++++++------ .../parallel_query_solving_agent.py | 13 ++++++++-- ..._and_sources_agent.py => sources_agent.py} | 4 +-- .../inner_autogen_text_2_sql.py | 26 +++++++++++++++++++ 5 files changed, 51 insertions(+), 12 deletions(-) delete mode 100644 text_2_sql/autogen/src/__init__.py rename text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/{answer_and_sources_agent.py => sources_agent.py} (97%) diff --git a/text_2_sql/autogen/src/__init__.py b/text_2_sql/autogen/src/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py b/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py index 0c9e1ba..5c17c8c 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py @@ -11,8 +11,8 @@ from autogen_text_2_sql.custom_agents.parallel_query_solving_agent import ( ParallelQuerySolvingAgent, ) -from autogen_text_2_sql.custom_agents.answer_and_sources_agent import ( - AnswerAndSourcesAgent, +from text_2_sql.autogen.src.autogen_text_2_sql.custom_agents.sources_agent import ( + SourcesAgent, ) from autogen_agentchat.agents import UserProxyAgent from autogen_agentchat.messages import TextMessage @@ -57,7 +57,9 @@ def get_all_agents(self): engine_specific_rules=self.engine_specific_rules, **self.kwargs ) - self.answer_and_sources_agent = AnswerAndSourcesAgent() + self.answer_agent = LLMAgentCreator.create("answer_agent") + + self.sources_agent = SourcesAgent() # Auto-responding UserProxyAgent self.user_proxy = EmptyResponseUserProxyAgent(name="user_proxy") @@ -66,7 +68,7 @@ def get_all_agents(self): self.user_proxy, self.query_rewrite_agent, self.parallel_query_solving_agent, - self.answer_and_sources_agent, + self.sources_agent, ] return agents @@ -95,7 +97,9 @@ def unified_selector(self, messages): decision = "parallel_query_solving_agent" # Handle transition after parallel query solving elif current_agent == "parallel_query_solving_agent": - decision = "answer_and_sources_agent" + decision = "answer_agent" + elif current_agent == "answer_agent": + decision = "sources_agent" if decision: logging.info(f"Agent transition: {current_agent} -> {decision}") @@ -118,7 +122,7 @@ def agentic_flow(self): async def process_question( self, - task: str, + question: str, chat_history: list[str] = None, parameters: dict = None, ): @@ -134,11 +138,11 @@ async def process_question( ------- dict: The response from the system. """ - logging.info("Processing question: %s", task) + logging.info("Processing question: %s", question) logging.info("Chat history: %s", chat_history) agent_input = { - "user_question": task, + "question": question, "chat_history": {}, "parameters": parameters, } diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py b/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py index e50c79b..75ea16a 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py @@ -41,6 +41,13 @@ async def on_messages_stream( self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken ) -> AsyncGenerator[AgentMessage | Response, None]: last_response = messages[-1].content + parameter_input = messages[0].content + last_response = messages[-1].content + try: + user_parameters = json.loads(parameter_input)["parameters"] + except json.JSONDecodeError: + logging.error("Error decoding the user parameters.") + user_parameters = {} # Load the json of the last message to populate the final output object query_rewrites = json.loads(last_response) @@ -49,14 +56,16 @@ async def on_messages_stream( inner_solving_tasks = [] - for query_rewrite in query_rewrites: + for query_rewrite in query_rewrites["sub_queries"]: # Create an instance of the InnerAutoGenText2Sql class inner_autogen_text_2_sql = InnerAutoGenText2Sql( self.engine_specific_rules, **self.kwargs ) inner_solving_tasks.append( - inner_autogen_text_2_sql.run_stream(task=query_rewrite) + inner_autogen_text_2_sql.process_question( + question=query_rewrite, parameters=user_parameters + ) ) # Wait for all the inner solving tasks to complete diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/answer_and_sources_agent.py b/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sources_agent.py similarity index 97% rename from text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/answer_and_sources_agent.py rename to text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sources_agent.py index 3c1afcc..e1e8e20 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/answer_and_sources_agent.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sources_agent.py @@ -12,10 +12,10 @@ import pandas as pd -class AnswerAndSourcesAgent(BaseChatAgent): +class SourcesAgent(BaseChatAgent): def __init__(self): super().__init__( - "answer_and_sources_agent", + "sources_agent", "An agent that formats the final answer and sources.", ) diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/inner_autogen_text_2_sql.py b/text_2_sql/autogen/src/autogen_text_2_sql/inner_autogen_text_2_sql.py index 3b4926e..2152e64 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/inner_autogen_text_2_sql.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/inner_autogen_text_2_sql.py @@ -187,3 +187,29 @@ def agentic_flow(self): selector_func=self.unified_selector, ) return flow + + async def process_question( + self, + question: str, + parameters: dict = None, + ): + """Process the complete question through the unified system. + + Args: + ---- + task (str): The user question to process. + parameters (dict, optional): Parameters to pass to agents. Defaults to None. + + Returns: + ------- + dict: The response from the system. + """ + logging.info("Processing question: %s", question) + + agent_input = { + "question": question, + "chat_history": {}, + "parameters": parameters, + } + + return self.agentic_flow.run_stream(task=json.dumps(agent_input)) From 01b0c8e64c1bf1da6e267d63272c5cbd9a61f94a Mon Sep 17 00:00:00 2001 From: Ben Constable Date: Fri, 27 Dec 2024 16:38:31 +0000 Subject: [PATCH 04/15] Add inner message --- .../parallel_query_solving_agent.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py b/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py index 75ea16a..a1a3015 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py @@ -4,7 +4,12 @@ from autogen_agentchat.agents import BaseChatAgent from autogen_agentchat.base import Response -from autogen_agentchat.messages import AgentMessage, ChatMessage, TextMessage +from autogen_agentchat.messages import ( + AgentEvent, + AgentMessage, + ChatMessage, + TextMessage, +) from autogen_core import CancellationToken import json import logging @@ -40,6 +45,8 @@ async def on_messages( async def on_messages_stream( self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken ) -> AsyncGenerator[AgentMessage | Response, None]: + inner_messages: List[AgentEvent | ChatMessage] = [] + last_response = messages[-1].content parameter_input = messages[0].content last_response = messages[-1].content @@ -76,8 +83,9 @@ async def on_messages_stream( yield Response( chat_message=TextMessage( content=json.dumps(inner_solving_results), source=self.name - ) + ), + inner_messages=inner_messages, ) - async def on_reset(self, cancellation_token: CancellationToken) -> None: - pass + async def on_reset(self, cancellation_token: CancellationToken) -> None: + pass From 9beacdea731d42d8af5cee40c11347acc5275c42 Mon Sep 17 00:00:00 2001 From: Ben Constable Date: Fri, 27 Dec 2024 17:59:04 +0000 Subject: [PATCH 05/15] Add aiostream --- text_2_sql/autogen/pyproject.toml | 1 + .../autogen_text_2_sql/autogen_text_2_sql.py | 4 +- .../parallel_query_solving_agent.py | 65 +++++++++++++++---- .../inner_autogen_text_2_sql.py | 2 +- .../text_2_sql_core/prompts/answer_agent.yaml | 6 ++ uv.lock | 14 ++++ 6 files changed, 76 insertions(+), 16 deletions(-) create mode 100644 text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/answer_agent.yaml diff --git a/text_2_sql/autogen/pyproject.toml b/text_2_sql/autogen/pyproject.toml index 721b9f9..c72e5d6 100644 --- a/text_2_sql/autogen/pyproject.toml +++ b/text_2_sql/autogen/pyproject.toml @@ -5,6 +5,7 @@ description = "AutoGen Based Implementation" readme = "README.md" requires-python = ">=3.12" dependencies = [ + "aiostream>=0.6.4", "autogen-agentchat==0.4.0.dev11", "autogen-core==0.4.0.dev11", "autogen-ext[azure,openai]==0.4.0.dev11", diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py b/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py index 5c17c8c..3305e49 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py @@ -11,7 +11,7 @@ from autogen_text_2_sql.custom_agents.parallel_query_solving_agent import ( ParallelQuerySolvingAgent, ) -from text_2_sql.autogen.src.autogen_text_2_sql.custom_agents.sources_agent import ( +from autogen_text_2_sql.custom_agents.sources_agent import ( SourcesAgent, ) from autogen_agentchat.agents import UserProxyAgent @@ -120,7 +120,7 @@ def agentic_flow(self): ) return flow - async def process_question( + def process_question( self, question: str, chat_history: list[str] = None, diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py b/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py index a1a3015..64b52c8 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py @@ -5,7 +5,6 @@ from autogen_agentchat.agents import BaseChatAgent from autogen_agentchat.base import Response from autogen_agentchat.messages import ( - AgentEvent, AgentMessage, ChatMessage, TextMessage, @@ -13,9 +12,10 @@ from autogen_core import CancellationToken import json import logging -import asyncio from autogen_text_2_sql.inner_autogen_text_2_sql import InnerAutoGenText2Sql +from aiostream import stream + class ParallelQuerySolvingAgent(BaseChatAgent): def __init__(self, engine_specific_rules: str, **kwargs: dict): @@ -45,11 +45,10 @@ async def on_messages( async def on_messages_stream( self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken ) -> AsyncGenerator[AgentMessage | Response, None]: - inner_messages: List[AgentEvent | ChatMessage] = [] + inner_messages: List[AgentMessage | ChatMessage] = [] last_response = messages[-1].content parameter_input = messages[0].content - last_response = messages[-1].content try: user_parameters = json.loads(parameter_input)["parameters"] except json.JSONDecodeError: @@ -61,31 +60,71 @@ async def on_messages_stream( logging.info(f"Query Rewrite: {query_rewrites}") - inner_solving_tasks = [] + inner_solving_generators = [] + + async def consume_inner_messages_from_agentic_flow( + agentic_flow, identifier, complete_inner_messages + ): + """ + Consume the inner messages and append them to the specified list. + + Args: + ---- + agentic_flow: The async generator to consume messages from. + messages_list: The list to which messages should be added. + """ + async for inner_message in agentic_flow: + # Add message to results dictionary, tagged by the function name + if identifier not in complete_inner_messages: + complete_inner_messages[identifier] = [] + complete_inner_messages[identifier].append(inner_message) + + yield {"source": identifier, "message": inner_message} + complete_inner_messages = {} + + # Start processing sub-queries for query_rewrite in query_rewrites["sub_queries"]: # Create an instance of the InnerAutoGenText2Sql class inner_autogen_text_2_sql = InnerAutoGenText2Sql( self.engine_specific_rules, **self.kwargs ) - inner_solving_tasks.append( + # Launch tasks for each sub-query + inner_solving_generators.append( inner_autogen_text_2_sql.process_question( question=query_rewrite, parameters=user_parameters ) ) - # Wait for all the inner solving tasks to complete - inner_solving_results = await asyncio.gather(*inner_solving_tasks) + combined_message_streams = stream.merge(*inner_solving_generators) + + async with combined_message_streams.stream() as streamer: + async for inner_message in streamer: + print(inner_message) + yield inner_message + + # # Process the results as they are yielded + # for completed in asyncio.as_completed(inner_solving_generators): + # async for inner_message in completed: + # # Yield the result as soon as it's available + # yield inner_message + + # # Wait for all tasks to complete + # await asyncio.gather(*inner_solving_generators, return_exceptions=True) + + # # Log final results for debugging or auditing + # logging.info(f"Formatted Results: {complete_inner_messages}") - logging.info(f"Inner Solving Results: {inner_solving_results}") + # TODO: Trim out unnecessary information from the final response + # Final response yield Response( chat_message=TextMessage( - content=json.dumps(inner_solving_results), source=self.name + content=json.dumps(complete_inner_messages), source=self.name ), - inner_messages=inner_messages, + inner_messages=complete_inner_messages, ) - async def on_reset(self, cancellation_token: CancellationToken) -> None: - pass + async def on_reset(self, cancellation_token: CancellationToken) -> None: + pass diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/inner_autogen_text_2_sql.py b/text_2_sql/autogen/src/autogen_text_2_sql/inner_autogen_text_2_sql.py index 2152e64..a0fb60a 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/inner_autogen_text_2_sql.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/inner_autogen_text_2_sql.py @@ -188,7 +188,7 @@ def agentic_flow(self): ) return flow - async def process_question( + def process_question( self, question: str, parameters: dict = None, diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/answer_agent.yaml b/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/answer_agent.yaml new file mode 100644 index 0000000..d777c3a --- /dev/null +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/answer_agent.yaml @@ -0,0 +1,6 @@ +model: "4o-mini" +description: "An agent that generates a response to a user's question." +system_message: | + + You are a helpful AI Assistant specializing in answering a user's question. + diff --git a/uv.lock b/uv.lock index 1130f44..01d5125 100644 --- a/uv.lock +++ b/uv.lock @@ -164,6 +164,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/76/ac/a7305707cb852b7e16ff80eaf5692309bde30e2b1100a1fcacdc8f731d97/aiosignal-1.3.1-py3-none-any.whl", hash = "sha256:f8376fb07dd1e86a584e4fcdec80b36b7f81aac666ebc724e2c090300dd83b17", size = 7617 }, ] +[[package]] +name = "aiostream" +version = "0.6.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a5/fe/aa14603fcd5cc4333f81791bbdf58cd4cb8677c8e21e3cc691d27c00173f/aiostream-0.6.4.tar.gz", hash = "sha256:f99bc6b1b9cea3e70885dc235a233523597555fe4a585ed21d65264b3f1ff3d2", size = 67983 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/71/5c/639dc59441df1d5cec49a09c36eb89af651476760f950b7d018bdf0ec4a7/aiostream-0.6.4-py3-none-any.whl", hash = "sha256:bd8c6a8b90a52c0325a3b19406f0f2a131448e596c06398886f5be1c73b4cea9", size = 53665 }, +] + [[package]] name = "annotated-types" version = "0.7.0" @@ -347,6 +359,7 @@ name = "autogen-text-2-sql" version = "0.1.0" source = { virtual = "text_2_sql/autogen" } dependencies = [ + { name = "aiostream" }, { name = "autogen-agentchat" }, { name = "autogen-core" }, { name = "autogen-ext", extra = ["azure", "openai"] }, @@ -369,6 +382,7 @@ dev = [ [package.metadata] requires-dist = [ + { name = "aiostream", specifier = ">=0.6.4" }, { name = "autogen-agentchat", specifier = "==0.4.0.dev11" }, { name = "autogen-core", specifier = "==0.4.0.dev11" }, { name = "autogen-ext", extras = ["azure", "openai"], specifier = "==0.4.0.dev11" }, From 40e4fc31a4da6e054ff744d785ad202f6544503c Mon Sep 17 00:00:00 2001 From: Ben Constable Date: Fri, 27 Dec 2024 18:04:06 +0000 Subject: [PATCH 06/15] Add aiostream --- .../parallel_query_solving_agent.py | 38 +++++++++---------- 1 file changed, 17 insertions(+), 21 deletions(-) diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py b/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py index 64b52c8..9af4df2 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py @@ -45,8 +45,6 @@ async def on_messages( async def on_messages_stream( self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken ) -> AsyncGenerator[AgentMessage | Response, None]: - inner_messages: List[AgentMessage | ChatMessage] = [] - last_response = messages[-1].content parameter_input = messages[0].content try: @@ -58,9 +56,7 @@ async def on_messages_stream( # Load the json of the last message to populate the final output object query_rewrites = json.loads(last_response) - logging.info(f"Query Rewrite: {query_rewrites}") - - inner_solving_generators = [] + logging.info(f"Query Rewrites: {query_rewrites}") async def consume_inner_messages_from_agentic_flow( agentic_flow, identifier, complete_inner_messages @@ -81,10 +77,12 @@ async def consume_inner_messages_from_agentic_flow( yield {"source": identifier, "message": inner_message} + inner_solving_generators = [] complete_inner_messages = {} # Start processing sub-queries for query_rewrite in query_rewrites["sub_queries"]: + logging.info(f"Processing sub-query: {query_rewrite}") # Create an instance of the InnerAutoGenText2Sql class inner_autogen_text_2_sql = InnerAutoGenText2Sql( self.engine_specific_rules, **self.kwargs @@ -92,38 +90,36 @@ async def consume_inner_messages_from_agentic_flow( # Launch tasks for each sub-query inner_solving_generators.append( - inner_autogen_text_2_sql.process_question( - question=query_rewrite, parameters=user_parameters + consume_inner_messages_from_agentic_flow( + inner_autogen_text_2_sql.process_question( + question=query_rewrite, parameters=user_parameters + ), + query_rewrite, + complete_inner_messages, ) ) + logging.info("Starting Inner Solving Generators") combined_message_streams = stream.merge(*inner_solving_generators) async with combined_message_streams.stream() as streamer: async for inner_message in streamer: - print(inner_message) + logging.info(f"Inner Solving Message: {inner_message}") yield inner_message - # # Process the results as they are yielded - # for completed in asyncio.as_completed(inner_solving_generators): - # async for inner_message in completed: - # # Yield the result as soon as it's available - # yield inner_message - - # # Wait for all tasks to complete - # await asyncio.gather(*inner_solving_generators, return_exceptions=True) - - # # Log final results for debugging or auditing - # logging.info(f"Formatted Results: {complete_inner_messages}") + # Log final results for debugging or auditing + logging.info(f"Formatted Results: {complete_inner_messages}") # TODO: Trim out unnecessary information from the final response - # Final response yield Response( chat_message=TextMessage( content=json.dumps(complete_inner_messages), source=self.name ), - inner_messages=complete_inner_messages, + inner_messages=[ + complete_inner_message["message"] + for complete_inner_message in complete_inner_messages + ], ) async def on_reset(self, cancellation_token: CancellationToken) -> None: From ee41cbeeb2136737c74583d40ceb52899ef4dd98 Mon Sep 17 00:00:00 2001 From: Ben Constable Date: Fri, 27 Dec 2024 18:16:58 +0000 Subject: [PATCH 07/15] Update agentic flow --- .../autogen/src/autogen_text_2_sql/autogen_text_2_sql.py | 4 ++-- .../custom_agents/parallel_query_solving_agent.py | 5 ++++- .../src/autogen_text_2_sql/inner_autogen_text_2_sql.py | 4 ++-- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py b/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py index 3305e49..3b732fe 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py @@ -86,11 +86,11 @@ def termination_condition(self): def unified_selector(self, messages): """Unified selector for the complete flow.""" logging.info("Messages: %s", messages) - current_agent = messages[-1].source if messages else "start" + current_agent = messages[-1].source if messages else "user" decision = None # If this is the first message start with query_rewrite_agent - if current_agent == "start": + if current_agent == "user": decision = "query_rewrite_agent" # Handle transition after query rewriting elif current_agent == "query_rewrite_agent": diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py b/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py index 9af4df2..58afe77 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py @@ -75,7 +75,10 @@ async def consume_inner_messages_from_agentic_flow( complete_inner_messages[identifier] = [] complete_inner_messages[identifier].append(inner_message) - yield {"source": identifier, "message": inner_message} + yield TextMessage( + content=json.dumps(inner_message), + source=f"{self.name}-{identifier}", + ) inner_solving_generators = [] complete_inner_messages = {} diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/inner_autogen_text_2_sql.py b/text_2_sql/autogen/src/autogen_text_2_sql/inner_autogen_text_2_sql.py index a0fb60a..38e3897 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/inner_autogen_text_2_sql.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/inner_autogen_text_2_sql.py @@ -126,10 +126,10 @@ def unified_selector(self, messages): """Unified selector for the complete flow.""" """Unified selector for the complete flow.""" logging.info("Messages: %s", messages) - current_agent = messages[-1].source if messages else "start" + current_agent = messages[-1].source if messages else "user" decision = None - if current_agent == "start": + if current_agent == "user": decision = ( "sql_query_cache_agent" if self.use_query_cache From 977454a42f67c1013b14001bf39d2bbbca27eb5c Mon Sep 17 00:00:00 2001 From: Ben Constable Date: Fri, 27 Dec 2024 18:28:10 +0000 Subject: [PATCH 08/15] Yield original message type --- .../custom_agents/parallel_query_solving_agent.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py b/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py index 58afe77..c124c74 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py @@ -75,10 +75,7 @@ async def consume_inner_messages_from_agentic_flow( complete_inner_messages[identifier] = [] complete_inner_messages[identifier].append(inner_message) - yield TextMessage( - content=json.dumps(inner_message), - source=f"{self.name}-{identifier}", - ) + yield inner_message inner_solving_generators = [] complete_inner_messages = {} @@ -102,6 +99,7 @@ async def consume_inner_messages_from_agentic_flow( ) ) + logging.info("Created %i Inner Solving Generators", inner_solving_generators) logging.info("Starting Inner Solving Generators") combined_message_streams = stream.merge(*inner_solving_generators) From 1977699a3c382125406c7f20380d413cee878436 Mon Sep 17 00:00:00 2001 From: Ben Constable Date: Mon, 30 Dec 2024 15:47:31 +0000 Subject: [PATCH 09/15] Add enums for structured data mode --- .../src/text_2_sql_core/payloads/__init__.py | 0 .../payloads/answer_with_sources.py | 12 +++++++++ .../payloads/user_information_request.py | 26 +++++++++++++++++++ 3 files changed, 38 insertions(+) create mode 100644 text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/__init__.py create mode 100644 text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/answer_with_sources.py create mode 100644 text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/user_information_request.py diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/__init__.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/answer_with_sources.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/answer_with_sources.py new file mode 100644 index 0000000..6064162 --- /dev/null +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/answer_with_sources.py @@ -0,0 +1,12 @@ +from pydantic import BaseModel, Field + + +class Source(BaseModel): + sql_query: str + sql_rows: list[dict] + markdown_table: str + + +class AnswerWithSources(BaseModel): + answer: str + sources: list[str] = Field(default_factory=list) diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/user_information_request.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/user_information_request.py new file mode 100644 index 0000000..a24496c --- /dev/null +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/user_information_request.py @@ -0,0 +1,26 @@ +from pydantic import BaseModel, RootModel, Field +from enum import StrEnum +from typing import Literal + + +class RequestType(StrEnum): + DISAMBIGUATION = "disambiguation" + CLARIFICATION = "clarification" + + +class ClarificationRequest(BaseModel): + type: Literal[RequestType.CLARIFICATION] + question: str + other_user_choices: list[str] + + +class DismabiguationRequest(BaseModel): + type: Literal[RequestType.DISAMBIGUATION] + question: str + matching_columns: list[str] + matching_filter_values: list[str] + other_user_choices: list[str] + + +class UserInformationRequest(RootModel): + root: DismabiguationRequest = Field(..., discriminator="type") From e846e9e00aa1ae881c129f7f465f59615573951a Mon Sep 17 00:00:00 2001 From: Ben Constable Date: Mon, 30 Dec 2024 15:59:02 +0000 Subject: [PATCH 10/15] Update --- text_2_sql/autogen/src/autogen_text_2_sql/__init__.py | 3 +++ .../text_2_sql_core/src/text_2_sql_core/payloads/__init__.py | 4 ++++ 2 files changed, 7 insertions(+) diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/__init__.py b/text_2_sql/autogen/src/autogen_text_2_sql/__init__.py index e69de29..defc348 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/__init__.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/__init__.py @@ -0,0 +1,3 @@ +from autogen_text_2_sql.autogen_text_2_sql import AutoGenText2Sql + +__all__ = ["AutoGenText2Sql"] diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/__init__.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/__init__.py index e69de29..57e3e6b 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/__init__.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/__init__.py @@ -0,0 +1,4 @@ +from text_2_sql_core.payloads.answer_with_sources import AnswerWithSources, Source +from text_2_sql_core.payloads.user_information_request import UserInformationRequest + +__all__ = ["AnswerWithSources", "Source", "UserInformationRequest"] From c0c630c9c7e8929df370eb5daf7aee112a44009a Mon Sep 17 00:00:00 2001 From: Ben Constable Date: Mon, 30 Dec 2024 16:14:08 +0000 Subject: [PATCH 11/15] Update query validation --- .../autogen_text_2_sql/autogen_text_2_sql.py | 53 +++++++++++++++++-- .../src/text_2_sql_core/connectors/sql.py | 2 +- .../src/text_2_sql_core/payloads/__init__.py | 3 +- .../payloads/processing_update.py | 6 +++ .../payloads/user_information_request.py | 6 +-- 5 files changed, 61 insertions(+), 9 deletions(-) create mode 100644 text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/processing_update.py diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py b/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py index 3b732fe..5b54c49 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py @@ -16,11 +16,18 @@ ) from autogen_agentchat.agents import UserProxyAgent from autogen_agentchat.messages import TextMessage -from autogen_agentchat.base import Response import json import os from datetime import datetime +from text_2_sql_core.payloads import ( + AnswerWithSources, + UserInformationRequest, + ProcessingUpdate, +) +from autogen_agentchat.base import Response, TaskResult +from asyncio import AsyncGenerator + class EmptyResponseUserProxyAgent(UserProxyAgent): """UserProxyAgent that automatically responds with empty messages.""" @@ -120,12 +127,12 @@ def agentic_flow(self): ) return flow - def process_question( + async def process_question( self, question: str, chat_history: list[str] = None, parameters: dict = None, - ): + ) -> AsyncGenerator[AnswerWithSources | UserInformationRequest]: """Process the complete question through the unified system. Args: @@ -152,4 +159,42 @@ def process_question( for idx, chat in enumerate(chat_history): agent_input[f"chat_{idx}"] = chat - return self.agentic_flow.run_stream(task=json.dumps(agent_input)) + async for message in self.agentic_flow.run_stream(task=json.dumps(agent_input)): + logging.info("Message: %s", message) + logging.info("Message type: %s", type(message)) + + payload = None + + if isinstance(message, TextMessage): + if message.source == "query_rewrite_agent": + # If the message is from the query_rewrite_agent, we need to update the chat history + payload = ProcessingUpdate( + title="Rewriting the query...", + ) + elif message.source == "parallel_query_solving_agent": + # If the message is from the parallel_query_solving_agent, we need to update the chat history + payload = ProcessingUpdate( + title="Solving the query...", + ) + elif message.source == "answer_agent": + # If the message is from the answer_agent, we need to update the chat history + payload = ProcessingUpdate( + title="Generating the answer...", + ) + + elif isinstance(message, TaskResult): + # Now we need to return the final answer or the disambiguation request + + if message.task == "answer_agent": + # If the message is from the answer_agent, we need to return the final answer + payload = AnswerWithSources( + **json.loads(message.content), + ) + else: + payload = UserInformationRequest( + **json.loads(message.content), + ) + + if payload is not None: + logging.info("Payload: %s", payload) + yield payload diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py index 2309128..8b2d256 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py @@ -143,7 +143,7 @@ async def query_validation( sqlglot.transpile( sql_query, read=self.database_engine.value.lower(), - error_level=sqlglot.ErrorLevel.ERROR, + error_level=sqlglot.ErrorLevel.RAISE, ) except sqlglot.errors.ParseError as e: logging.error("SQL Query is invalid: %s", e.errors) diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/__init__.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/__init__.py index 57e3e6b..b7f9116 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/__init__.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/__init__.py @@ -1,4 +1,5 @@ from text_2_sql_core.payloads.answer_with_sources import AnswerWithSources, Source from text_2_sql_core.payloads.user_information_request import UserInformationRequest +from text_2_sql_core.payloads.processing_update import ProcessingUpdate -__all__ = ["AnswerWithSources", "Source", "UserInformationRequest"] +__all__ = ["AnswerWithSources", "Source", "UserInformationRequest", "ProcessingUpdate"] diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/processing_update.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/processing_update.py new file mode 100644 index 0000000..1e4eb85 --- /dev/null +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/processing_update.py @@ -0,0 +1,6 @@ +from pydantic import BaseModel + + +class ProcessingUpdate(BaseModel): + title: str + message: str diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/user_information_request.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/user_information_request.py index a24496c..1aac4c0 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/user_information_request.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/user_information_request.py @@ -9,13 +9,13 @@ class RequestType(StrEnum): class ClarificationRequest(BaseModel): - type: Literal[RequestType.CLARIFICATION] + request_type: Literal[RequestType.CLARIFICATION] question: str other_user_choices: list[str] class DismabiguationRequest(BaseModel): - type: Literal[RequestType.DISAMBIGUATION] + request_type: Literal[RequestType.DISAMBIGUATION] question: str matching_columns: list[str] matching_filter_values: list[str] @@ -23,4 +23,4 @@ class DismabiguationRequest(BaseModel): class UserInformationRequest(RootModel): - root: DismabiguationRequest = Field(..., discriminator="type") + root: DismabiguationRequest = Field(..., discriminator="request_type") From 492f41b120abf54468dca97dbf44ffebfff1f3f2 Mon Sep 17 00:00:00 2001 From: Ben Constable Date: Mon, 30 Dec 2024 16:34:44 +0000 Subject: [PATCH 12/15] Processing update --- .../src/autogen_text_2_sql/autogen_text_2_sql.py | 12 ++++++------ .../text_2_sql_core/payloads/processing_update.py | 6 +++--- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py b/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py index 5b54c49..f5ecbde 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py @@ -26,7 +26,7 @@ ProcessingUpdate, ) from autogen_agentchat.base import Response, TaskResult -from asyncio import AsyncGenerator +from typing import AsyncGenerator class EmptyResponseUserProxyAgent(UserProxyAgent): @@ -132,7 +132,7 @@ async def process_question( question: str, chat_history: list[str] = None, parameters: dict = None, - ) -> AsyncGenerator[AnswerWithSources | UserInformationRequest]: + ) -> AsyncGenerator[AnswerWithSources | UserInformationRequest, None]: """Process the complete question through the unified system. Args: @@ -169,23 +169,23 @@ async def process_question( if message.source == "query_rewrite_agent": # If the message is from the query_rewrite_agent, we need to update the chat history payload = ProcessingUpdate( - title="Rewriting the query...", + message="Rewriting the query...", ) elif message.source == "parallel_query_solving_agent": # If the message is from the parallel_query_solving_agent, we need to update the chat history payload = ProcessingUpdate( - title="Solving the query...", + message="Solving the query...", ) elif message.source == "answer_agent": # If the message is from the answer_agent, we need to update the chat history payload = ProcessingUpdate( - title="Generating the answer...", + message="Generating the answer...", ) elif isinstance(message, TaskResult): # Now we need to return the final answer or the disambiguation request - if message.task == "answer_agent": + if message.source == "answer_agent": # If the message is from the answer_agent, we need to return the final answer payload = AnswerWithSources( **json.loads(message.content), diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/processing_update.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/processing_update.py index 1e4eb85..3950508 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/processing_update.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/processing_update.py @@ -1,6 +1,6 @@ -from pydantic import BaseModel +from pydantic import BaseModel, Field class ProcessingUpdate(BaseModel): - title: str - message: str + title: str | None = Field(default="Processing...") + message: str | None = Field(default="Processing...") From 4520da9a264868882160527841795c9bbf9e392f Mon Sep 17 00:00:00 2001 From: Ben Constable Date: Mon, 30 Dec 2024 16:38:53 +0000 Subject: [PATCH 13/15] Update notebook --- ...on 5 - Agentic Vector Based Text2SQL.ipynb | 35 +++++++++---------- 1 file changed, 17 insertions(+), 18 deletions(-) diff --git a/text_2_sql/autogen/Iteration 5 - Agentic Vector Based Text2SQL.ipynb b/text_2_sql/autogen/Iteration 5 - Agentic Vector Based Text2SQL.ipynb index 4c4d516..4e49cdc 100644 --- a/text_2_sql/autogen/Iteration 5 - Agentic Vector Based Text2SQL.ipynb +++ b/text_2_sql/autogen/Iteration 5 - Agentic Vector Based Text2SQL.ipynb @@ -9,6 +9,20 @@ "Licensed under the MIT License." ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# This is only needed for this notebook to work\n", + "import sys\n", + "from pathlib import Path\n", + "\n", + "# Add the parent directory of `src` to the path\n", + "sys.path.append(str(Path.cwd() / \"src\"))" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -28,20 +42,6 @@ "`uv add --editable text_2_sql_core`" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# This is only needed for this notebook to work\n", - "import sys\n", - "from pathlib import Path\n", - "\n", - "# Add the parent directory of `src` to the path\n", - "sys.path.append(str(Path.cwd() / \"src\"))" - ] - }, { "cell_type": "code", "execution_count": null, @@ -50,8 +50,7 @@ "source": [ "import dotenv\n", "import logging\n", - "from autogen_agentchat.ui import Console\n", - "from autogen_text_2_sql.autogen_text_2_sql import AutoGenText2Sql" + "from autogen_text_2_sql import AutoGenText2Sql" ] }, { @@ -101,8 +100,8 @@ "metadata": {}, "outputs": [], "source": [ - "result = await agentic_text_2_sql.process_question(task=\"What total number of orders in June 2008?\")\n", - "await Console(result)\n" + "async for message in agentic_text_2_sql.process_question(question=\"What total number of orders in June 2008?\"):\n", + " logging.info(\"Received %s Message from Text2SQL System\", message)" ] }, { From 5efa871403d1568e32b6d7b8bd8ffb5d4459243a Mon Sep 17 00:00:00 2001 From: Ben Constable Date: Mon, 30 Dec 2024 17:49:06 +0000 Subject: [PATCH 14/15] Update answer --- .../autogen_text_2_sql/autogen_text_2_sql.py | 73 ++++++++++++----- .../parallel_query_solving_agent.py | 72 +++++++++++------ .../custom_agents/sources_agent.py | 81 ------------------- .../inner_autogen_text_2_sql.py | 4 +- .../payloads/answer_with_sources.py | 3 +- .../text_2_sql_core/prompts/answer_agent.yaml | 6 ++ 6 files changed, 111 insertions(+), 128 deletions(-) delete mode 100644 text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sources_agent.py diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py b/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py index f5ecbde..b319715 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py @@ -3,6 +3,7 @@ from autogen_agentchat.conditions import ( TextMentionTermination, MaxMessageTermination, + SourceMatchTermination, ) from autogen_agentchat.teams import SelectorGroupChat from autogen_text_2_sql.creators.llm_model_creator import LLMModelCreator @@ -11,9 +12,6 @@ from autogen_text_2_sql.custom_agents.parallel_query_solving_agent import ( ParallelQuerySolvingAgent, ) -from autogen_text_2_sql.custom_agents.sources_agent import ( - SourcesAgent, -) from autogen_agentchat.agents import UserProxyAgent from autogen_agentchat.messages import TextMessage import json @@ -66,8 +64,6 @@ def get_all_agents(self): self.answer_agent = LLMAgentCreator.create("answer_agent") - self.sources_agent = SourcesAgent() - # Auto-responding UserProxyAgent self.user_proxy = EmptyResponseUserProxyAgent(name="user_proxy") @@ -75,7 +71,7 @@ def get_all_agents(self): self.user_proxy, self.query_rewrite_agent, self.parallel_query_solving_agent, - self.sources_agent, + self.answer_agent, ] return agents @@ -85,8 +81,9 @@ def termination_condition(self): """Define the termination condition for the chat.""" termination = ( TextMentionTermination("TERMINATE") - | (TextMentionTermination("answer") & TextMentionTermination("sources")) - | MaxMessageTermination(20) + | SourceMatchTermination("answer_agent") + | TextMentionTermination("requires_user_information_request") + | MaxMessageTermination(5) ) return termination @@ -105,8 +102,6 @@ def unified_selector(self, messages): # Handle transition after parallel query solving elif current_agent == "parallel_query_solving_agent": decision = "answer_agent" - elif current_agent == "answer_agent": - decision = "sources_agent" if decision: logging.info(f"Agent transition: {current_agent} -> {decision}") @@ -127,6 +122,45 @@ def agentic_flow(self): ) return flow + def extract_sources(self, messages: list) -> AnswerWithSources: + """Extract the sources from the answer.""" + + answer = messages[-1].content + + sql_query_results = messages[-2].content + + try: + sql_query_results = json.loads(sql_query_results) + + logging.info("SQL Query Results: %s", sql_query_results) + + sources = [] + + for question, sql_query_result_list in sql_query_results["results"].items(): + logging.info( + "SQL Query Result for question '%s': %s", + question, + sql_query_result_list, + ) + + for sql_query_result in sql_query_result_list: + logging.info("SQL Query Result: %s", sql_query_result) + sources.append( + { + "sql_query": sql_query_result["sql_query"], + "sql_rows": sql_query_result["sql_rows"], + } + ) + + except json.JSONDecodeError: + logging.error("Could not load message: %s", sql_query_results) + raise ValueError("Could not load message") + + return AnswerWithSources( + answer=answer, + sources=sources, + ) + async def process_question( self, question: str, @@ -160,8 +194,7 @@ async def process_question( agent_input[f"chat_{idx}"] = chat async for message in self.agentic_flow.run_stream(task=json.dumps(agent_input)): - logging.info("Message: %s", message) - logging.info("Message type: %s", type(message)) + logging.debug("Message: %s", message) payload = None @@ -184,17 +217,19 @@ async def process_question( elif isinstance(message, TaskResult): # Now we need to return the final answer or the disambiguation request + logging.info("TaskResult: %s", message) - if message.source == "answer_agent": + if message.messages[-1].source == "answer_agent": # If the message is from the answer_agent, we need to return the final answer - payload = AnswerWithSources( - **json.loads(message.content), - ) - else: + payload = self.extract_sources(message.messages) + elif message.messages[-1].source == "parallel_query_solving_agent": payload = UserInformationRequest( - **json.loads(message.content), + **json.loads(message.messages[-1].content), ) + else: + logging.error("Unexpected TaskResult: %s", message) + raise ValueError("Unexpected TaskResult") if payload is not None: - logging.info("Payload: %s", payload) + logging.debug("Payload: %s", payload) yield payload diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py b/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py index c124c74..5a5c707 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py @@ -3,18 +3,14 @@ from typing import AsyncGenerator, List, Sequence from autogen_agentchat.agents import BaseChatAgent -from autogen_agentchat.base import Response -from autogen_agentchat.messages import ( - AgentMessage, - ChatMessage, - TextMessage, -) +from autogen_agentchat.base import Response, TaskResult +from autogen_agentchat.messages import AgentMessage, ChatMessage, TextMessage from autogen_core import CancellationToken import json import logging from autogen_text_2_sql.inner_autogen_text_2_sql import InnerAutoGenText2Sql - from aiostream import stream +from json import JSONDecodeError class ParallelQuerySolvingAgent(BaseChatAgent): @@ -59,7 +55,7 @@ async def on_messages_stream( logging.info(f"Query Rewrites: {query_rewrites}") async def consume_inner_messages_from_agentic_flow( - agentic_flow, identifier, complete_inner_messages + agentic_flow, identifier, database_results ): """ Consume the inner messages and append them to the specified list. @@ -71,14 +67,43 @@ async def consume_inner_messages_from_agentic_flow( """ async for inner_message in agentic_flow: # Add message to results dictionary, tagged by the function name - if identifier not in complete_inner_messages: - complete_inner_messages[identifier] = [] - complete_inner_messages[identifier].append(inner_message) + if identifier not in database_results: + database_results[identifier] = [] + + logging.info(f"Checking Inner Message: {inner_message}") + + if isinstance(inner_message, TaskResult) is False: + try: + inner_message = json.loads(inner_message.content) + logging.info(f"Loaded: {inner_message}") + + # Search for specific message types and add them to the final output object + if ( + "type" in inner_message + and inner_message["type"] == "query_execution_with_limit" + ): + database_results[identifier].append( + { + "sql_query": inner_message["sql_query"].replace( + "\n", " " + ), + "sql_rows": inner_message["sql_rows"], + } + ) + + except (JSONDecodeError, TypeError) as e: + logging.error("Could not load message: %s", inner_message) + logging.warning(f"Error processing message: {e}") + + except Exception as e: + logging.error("Could not load message: %s", inner_message) + logging.error(f"Error processing message: {e}") + raise e yield inner_message inner_solving_generators = [] - complete_inner_messages = {} + database_results = {} # Start processing sub-queries for query_rewrite in query_rewrites["sub_queries"]: @@ -95,32 +120,33 @@ async def consume_inner_messages_from_agentic_flow( question=query_rewrite, parameters=user_parameters ), query_rewrite, - complete_inner_messages, + database_results, ) ) - logging.info("Created %i Inner Solving Generators", inner_solving_generators) + logging.info( + "Created %i Inner Solving Generators", len(inner_solving_generators) + ) logging.info("Starting Inner Solving Generators") combined_message_streams = stream.merge(*inner_solving_generators) async with combined_message_streams.stream() as streamer: async for inner_message in streamer: - logging.info(f"Inner Solving Message: {inner_message}") - yield inner_message + if isinstance(inner_message, TextMessage): + logging.debug(f"Inner Solving Message: {inner_message}") + yield inner_message # Log final results for debugging or auditing - logging.info(f"Formatted Results: {complete_inner_messages}") + logging.info(f"Database Results: {database_results}") - # TODO: Trim out unnecessary information from the final response # Final response yield Response( chat_message=TextMessage( - content=json.dumps(complete_inner_messages), source=self.name + content=json.dumps( + {"contains_results": True, "results": database_results} + ), + source=self.name, ), - inner_messages=[ - complete_inner_message["message"] - for complete_inner_message in complete_inner_messages - ], ) async def on_reset(self, cancellation_token: CancellationToken) -> None: diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sources_agent.py b/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sources_agent.py deleted file mode 100644 index e1e8e20..0000000 --- a/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sources_agent.py +++ /dev/null @@ -1,81 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -from typing import AsyncGenerator, List, Sequence - -from autogen_agentchat.agents import BaseChatAgent -from autogen_agentchat.base import Response -from autogen_agentchat.messages import AgentMessage, ChatMessage, TextMessage -from autogen_core import CancellationToken -import json -from json import JSONDecodeError -import logging -import pandas as pd - - -class SourcesAgent(BaseChatAgent): - def __init__(self): - super().__init__( - "sources_agent", - "An agent that formats the final answer and sources.", - ) - - @property - def produced_message_types(self) -> List[type[ChatMessage]]: - return [TextMessage] - - async def on_messages( - self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken - ) -> Response: - # Calls the on_messages_stream. - response: Response | None = None - async for message in self.on_messages_stream(messages, cancellation_token): - if isinstance(message, Response): - response = message - assert response is not None - return response - - async def on_messages_stream( - self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken - ) -> AsyncGenerator[AgentMessage | Response, None]: - last_response = messages[-1].content - - # Load the json of the last message to populate the final output object - final_output_object = json.loads(last_response) - final_output_object["sources"] = [] - - for message in messages: - # Load the message content if it is a json object and was a query execution - try: - message = json.loads(message.content) - logging.info(f"Loaded: {message}") - - # Search for specific message types and add them to the final output object - if ( - "type" in message - and message["type"] == "query_execution_with_limit" - ): - dataframe = pd.DataFrame(message["sql_rows"]) - final_output_object["sources"].append( - { - "sql_query": message["sql_query"].replace("\n", " "), - "sql_rows": message["sql_rows"], - "markdown_table": dataframe.to_markdown(index=False), - } - ) - - except JSONDecodeError: - logging.info(f"Could not load message: {message}") - continue - - except Exception as e: - logging.error(f"Error processing message: {e}") - raise e - - yield Response( - chat_message=TextMessage( - content=json.dumps(final_output_object), source=self.name - ) - ) - - async def on_reset(self, cancellation_token: CancellationToken) -> None: - pass diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/inner_autogen_text_2_sql.py b/text_2_sql/autogen/src/autogen_text_2_sql/inner_autogen_text_2_sql.py index 38e3897..6564b9a 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/inner_autogen_text_2_sql.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/inner_autogen_text_2_sql.py @@ -150,7 +150,7 @@ def unified_selector(self, messages): correction_result = json.loads(messages[-1].content) if isinstance(correction_result, dict): if "answer" in correction_result and "sources" in correction_result: - decision = "answer_and_sources_agent" + decision = "user_proxy" elif "corrected_query" in correction_result: if correction_result.get("executing", False): decision = "sql_query_correction_agent" @@ -166,8 +166,6 @@ def unified_selector(self, messages): decision = "sql_query_generation_agent" except json.JSONDecodeError: decision = "sql_query_generation_agent" - elif current_agent == "answer_and_sources_agent": - decision = "user_proxy" # Let user_proxy send TERMINATE if decision: logging.info(f"Agent transition: {current_agent} -> {decision}") diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/answer_with_sources.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/answer_with_sources.py index 6064162..650b0d4 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/answer_with_sources.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/answer_with_sources.py @@ -4,9 +4,8 @@ class Source(BaseModel): sql_query: str sql_rows: list[dict] - markdown_table: str class AnswerWithSources(BaseModel): answer: str - sources: list[str] = Field(default_factory=list) + sources: list[Source] = Field(default_factory=list) diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/answer_agent.yaml b/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/answer_agent.yaml index d777c3a..22ac01e 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/answer_agent.yaml +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/answer_agent.yaml @@ -4,3 +4,9 @@ system_message: | You are a helpful AI Assistant specializing in answering a user's question. + + Use the information obtained to generate a response to the user's question. The question has been broken down into a series of SQL queries and you need to generate a response based on the results of these queries. + + Do not use any external resources to generate the response. The response should be based solely on the information provided in the SQL queries and their results. + + You can use Markdown and Markdown tables to format the response. From 8593902eedd5529a7d5dfe05bd525f7ffee7c767 Mon Sep 17 00:00:00 2001 From: Ben Constable Date: Mon, 30 Dec 2024 17:58:55 +0000 Subject: [PATCH 15/15] Update parameters --- .../autogen_text_2_sql/autogen_text_2_sql.py | 12 +++++---- .../custom_agents/sql_query_cache_agent.py | 6 ++--- .../inner_autogen_text_2_sql.py | 6 ++--- .../src/text_2_sql_core/connectors/sql.py | 26 +++++++++---------- .../src/text_2_sql_core/payloads/__init__.py | 9 ++++++- .../text_2_sql_core/payloads/chat_history.py | 9 +++++++ .../text_2_sql_core/prompts/answer_agent.yaml | 2 +- 7 files changed, 44 insertions(+), 26 deletions(-) create mode 100644 text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/chat_history.py diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py b/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py index b319715..b021b01 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py @@ -22,6 +22,7 @@ AnswerWithSources, UserInformationRequest, ProcessingUpdate, + ChatHistoryItem, ) from autogen_agentchat.base import Response, TaskResult from typing import AsyncGenerator @@ -164,8 +165,8 @@ def extract_sources(self, messages: list) -> AnswerWithSources: async def process_question( self, question: str, - chat_history: list[str] = None, - parameters: dict = None, + chat_history: list[ChatHistoryItem] = None, + injected_parameters: dict = None, ) -> AsyncGenerator[AnswerWithSources | UserInformationRequest, None]: """Process the complete question through the unified system. @@ -173,7 +174,7 @@ async def process_question( ---- task (str): The user question to process. chat_history (list[str], optional): The chat history. Defaults to None. - parameters (dict, optional): Parameters to pass to agents. Defaults to None. + injected_parameters (dict, optional): Parameters to pass to agents. Defaults to None. Returns: ------- @@ -185,13 +186,14 @@ async def process_question( agent_input = { "question": question, "chat_history": {}, - "parameters": parameters, + "injected_parameters": injected_parameters, } if chat_history is not None: # Update input for idx, chat in enumerate(chat_history): - agent_input[f"chat_{idx}"] = chat + # For now only consider the user query + agent_input[f"chat_{idx}"] = chat.user_query async for message in self.agentic_flow.run_stream(task=json.dumps(agent_input)): logging.debug("Message: %s", message) diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_query_cache_agent.py b/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_query_cache_agent.py index 0a96932..83fb13f 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_query_cache_agent.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_query_cache_agent.py @@ -43,9 +43,9 @@ async def on_messages_stream( last_response = messages[-1].content try: user_questions = json.loads(last_response) - user_parameters = json.loads(parameter_input)["parameters"] + injected_parameters = json.loads(parameter_input)["injected_parameters"] logging.info(f"Processing questions: {user_questions}") - logging.info(f"Input Parameters: {user_parameters}") + logging.info(f"Input Parameters: {injected_parameters}") # Initialize results dictionary cached_results = { @@ -58,7 +58,7 @@ async def on_messages_stream( # Fetch the queries from the cache based on the question logging.info(f"Fetching queries from cache for question: {question}") cached_query = await self.sql_connector.fetch_queries_from_cache( - question, parameters=user_parameters + question, injected_parameters=injected_parameters ) # If any question has pre-run results, set the flag diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/inner_autogen_text_2_sql.py b/text_2_sql/autogen/src/autogen_text_2_sql/inner_autogen_text_2_sql.py index 6564b9a..9c70a1b 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/inner_autogen_text_2_sql.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/inner_autogen_text_2_sql.py @@ -189,14 +189,14 @@ def agentic_flow(self): def process_question( self, question: str, - parameters: dict = None, + injected_parameters: dict = None, ): """Process the complete question through the unified system. Args: ---- task (str): The user question to process. - parameters (dict, optional): Parameters to pass to agents. Defaults to None. + injected_parameters (dict, optional): Parameters to pass to agents. Defaults to None. Returns: ------- @@ -207,7 +207,7 @@ def process_question( agent_input = { "question": question, "chat_history": {}, - "parameters": parameters, + "injected_parameters": injected_parameters, } return self.agentic_flow.run_stream(task=json.dumps(agent_input)) diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py index 8b2d256..0ee6ead 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py @@ -153,7 +153,7 @@ async def query_validation( return True async def fetch_queries_from_cache( - self, question: str, parameters: dict = None + self, question: str, injected_parameters: dict = None ) -> str: """Fetch the queries from the cache based on the question. @@ -166,21 +166,21 @@ async def fetch_queries_from_cache( str: The formatted string of the queries fetched from the cache. This is injected into the prompt. """ - if parameters is None: - parameters = {} + if injected_parameters is None: + injected_parameters = {} - # Populate the parameters - if "date" not in parameters: - parameters["date"] = self.get_current_date() + # Populate the injected_parameters + if "date" not in injected_parameters: + injected_parameters["date"] = self.get_current_date() - if "time" not in parameters: - parameters["time"] = self.get_current_time() + if "time" not in injected_parameters: + injected_parameters["time"] = self.get_current_time() - if "datetime" not in parameters: - parameters["datetime"] = self.get_current_datetime() + if "datetime" not in injected_parameters: + injected_parameters["datetime"] = self.get_current_datetime() - if "unix_timestamp" not in parameters: - parameters["unix_timestamp"] = self.get_current_unix_timestamp() + if "unix_timestamp" not in injected_parameters: + injected_parameters["unix_timestamp"] = self.get_current_unix_timestamp() cached_schemas = await self.ai_search_connector.run_ai_search_query( question, @@ -204,7 +204,7 @@ async def fetch_queries_from_cache( sql_queries = schema["SqlQueryDecomposition"] for sql_query in sql_queries: sql_query["SqlQuery"] = Template(sql_query["SqlQuery"]).render( - **parameters + **injected_parameters ) logging.info("Cached schemas: %s", cached_schemas) diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/__init__.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/__init__.py index b7f9116..e3d590c 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/__init__.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/__init__.py @@ -1,5 +1,12 @@ from text_2_sql_core.payloads.answer_with_sources import AnswerWithSources, Source from text_2_sql_core.payloads.user_information_request import UserInformationRequest from text_2_sql_core.payloads.processing_update import ProcessingUpdate +from text_2_sql_core.payloads.chat_history import ChatHistoryItem -__all__ = ["AnswerWithSources", "Source", "UserInformationRequest", "ProcessingUpdate"] +__all__ = [ + "AnswerWithSources", + "Source", + "UserInformationRequest", + "ProcessingUpdate", + "ChatHistoryItem", +] diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/chat_history.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/chat_history.py new file mode 100644 index 0000000..06b27cb --- /dev/null +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/chat_history.py @@ -0,0 +1,9 @@ +from pydantic import BaseModel +from text_2_sql_core.payloads.answer_with_sources import AnswerWithSources + + +class ChatHistoryItem(BaseModel): + """Chat history item with user message and agent response.""" + + user_query: str + agent_response: AnswerWithSources diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/answer_agent.yaml b/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/answer_agent.yaml index 22ac01e..8a4a797 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/answer_agent.yaml +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/answer_agent.yaml @@ -2,7 +2,7 @@ model: "4o-mini" description: "An agent that generates a response to a user's question." system_message: | - You are a helpful AI Assistant specializing in answering a user's question. + You are a helpful AI Assistant specializing in answering a user's question about {{ use_case }}. Use the information obtained to generate a response to the user's question. The question has been broken down into a series of SQL queries and you need to generate a response based on the results of these queries.