diff --git a/text_2_sql/autogen/Iteration 5 - Agentic Vector Based Text2SQL.ipynb b/text_2_sql/autogen/Iteration 5 - Agentic Vector Based Text2SQL.ipynb index 4c4d516..4e49cdc 100644 --- a/text_2_sql/autogen/Iteration 5 - Agentic Vector Based Text2SQL.ipynb +++ b/text_2_sql/autogen/Iteration 5 - Agentic Vector Based Text2SQL.ipynb @@ -9,6 +9,20 @@ "Licensed under the MIT License." ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# This is only needed for this notebook to work\n", + "import sys\n", + "from pathlib import Path\n", + "\n", + "# Add the parent directory of `src` to the path\n", + "sys.path.append(str(Path.cwd() / \"src\"))" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -28,20 +42,6 @@ "`uv add --editable text_2_sql_core`" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# This is only needed for this notebook to work\n", - "import sys\n", - "from pathlib import Path\n", - "\n", - "# Add the parent directory of `src` to the path\n", - "sys.path.append(str(Path.cwd() / \"src\"))" - ] - }, { "cell_type": "code", "execution_count": null, @@ -50,8 +50,7 @@ "source": [ "import dotenv\n", "import logging\n", - "from autogen_agentchat.ui import Console\n", - "from autogen_text_2_sql.autogen_text_2_sql import AutoGenText2Sql" + "from autogen_text_2_sql import AutoGenText2Sql" ] }, { @@ -101,8 +100,8 @@ "metadata": {}, "outputs": [], "source": [ - "result = await agentic_text_2_sql.process_question(task=\"What total number of orders in June 2008?\")\n", - "await Console(result)\n" + "async for message in agentic_text_2_sql.process_question(question=\"What total number of orders in June 2008?\"):\n", + " logging.info(\"Received %s Message from Text2SQL System\", message)" ] }, { diff --git a/text_2_sql/autogen/pyproject.toml b/text_2_sql/autogen/pyproject.toml index 721b9f9..c72e5d6 100644 --- a/text_2_sql/autogen/pyproject.toml +++ 
b/text_2_sql/autogen/pyproject.toml @@ -5,6 +5,7 @@ description = "AutoGen Based Implementation" readme = "README.md" requires-python = ">=3.12" dependencies = [ + "aiostream>=0.6.4", "autogen-agentchat==0.4.0.dev11", "autogen-core==0.4.0.dev11", "autogen-ext[azure,openai]==0.4.0.dev11", diff --git a/text_2_sql/autogen/src/__init__.py b/text_2_sql/autogen/src/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/__init__.py b/text_2_sql/autogen/src/autogen_text_2_sql/__init__.py index e69de29..defc348 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/__init__.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/__init__.py @@ -0,0 +1,3 @@ +from autogen_text_2_sql.autogen_text_2_sql import AutoGenText2Sql + +__all__ = ["AutoGenText2Sql"] diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py b/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py index f6d4e98..b021b01 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py @@ -1,31 +1,32 @@ -""" -Copyright (c) Microsoft Corporation. -Licensed under the MIT License. -""" +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
from autogen_agentchat.conditions import ( TextMentionTermination, MaxMessageTermination, + SourceMatchTermination, ) from autogen_agentchat.teams import SelectorGroupChat from autogen_text_2_sql.creators.llm_model_creator import LLMModelCreator from autogen_text_2_sql.creators.llm_agent_creator import LLMAgentCreator import logging -from autogen_text_2_sql.custom_agents.sql_query_cache_agent import ( - SqlQueryCacheAgent, -) -from autogen_text_2_sql.custom_agents.sql_schema_selection_agent import ( - SqlSchemaSelectionAgent, -) -from autogen_text_2_sql.custom_agents.answer_and_sources_agent import ( - AnswerAndSourcesAgent, +from autogen_text_2_sql.custom_agents.parallel_query_solving_agent import ( + ParallelQuerySolvingAgent, ) from autogen_agentchat.agents import UserProxyAgent from autogen_agentchat.messages import TextMessage -from autogen_agentchat.base import Response import json import os from datetime import datetime +from text_2_sql_core.payloads import ( + AnswerWithSources, + UserInformationRequest, + ProcessingUpdate, + ChatHistoryItem, +) +from autogen_agentchat.base import Response, TaskResult +from typing import AsyncGenerator + class EmptyResponseUserProxyAgent(UserProxyAgent): """UserProxyAgent that automatically responds with empty messages.""" @@ -45,23 +46,9 @@ async def on_messages_stream(self, messages, sender=None, config=None): class AutoGenText2Sql: def __init__(self, engine_specific_rules: str, **kwargs: dict): - self.pre_run_query_cache = False self.target_engine = os.environ["Text2Sql__DatabaseEngine"].upper() self.engine_specific_rules = engine_specific_rules self.kwargs = kwargs - self.set_mode() - - def set_mode(self): - """Set the mode of the plugin based on the environment variables.""" - self.pre_run_query_cache = ( - os.environ.get("Text2Sql__PreRunQueryCache", "True").lower() == "true" - ) - self.use_column_value_store = ( - os.environ.get("Text2Sql__UseColumnValueStore", "True").lower() == "true" - ) - self.use_query_cache = ( 
- os.environ.get("Text2Sql__UseQueryCache", "True").lower() == "true" - ) def get_all_agents(self): """Get all agents for the complete flow.""" @@ -72,46 +59,11 @@ def get_all_agents(self): "query_rewrite_agent", current_datetime=current_datetime ) - self.sql_query_generation_agent = LLMAgentCreator.create( - "sql_query_generation_agent", - target_engine=self.target_engine, - engine_specific_rules=self.engine_specific_rules, - **self.kwargs, - ) - - # If relationship_paths not provided, use a generic template - if "relationship_paths" not in self.kwargs: - self.kwargs[ - "relationship_paths" - ] = """ - Common relationship paths to consider: - - Transaction → Related Dimensions (for basic analysis) - - Geographic → Location hierarchies (for geographic analysis) - - Temporal → Date hierarchies (for time-based analysis) - - Entity → Attributes (for entity-specific analysis) - """ - - self.sql_schema_selection_agent = SqlSchemaSelectionAgent( - target_engine=self.target_engine, - engine_specific_rules=self.engine_specific_rules, - **self.kwargs, - ) - - self.sql_query_correction_agent = LLMAgentCreator.create( - "sql_query_correction_agent", - target_engine=self.target_engine, - engine_specific_rules=self.engine_specific_rules, - **self.kwargs, - ) - - self.sql_disambiguation_agent = LLMAgentCreator.create( - "sql_disambiguation_agent", - target_engine=self.target_engine, - engine_specific_rules=self.engine_specific_rules, - **self.kwargs, + self.parallel_query_solving_agent = ParallelQuerySolvingAgent( + engine_specific_rules=self.engine_specific_rules, **self.kwargs ) - self.answer_and_sources_agent = AnswerAndSourcesAgent() + self.answer_agent = LLMAgentCreator.create("answer_agent") # Auto-responding UserProxyAgent self.user_proxy = EmptyResponseUserProxyAgent(name="user_proxy") @@ -119,17 +71,10 @@ def get_all_agents(self): agents = [ self.user_proxy, self.query_rewrite_agent, - self.sql_query_generation_agent, - self.sql_schema_selection_agent, - 
self.sql_query_correction_agent, - self.sql_disambiguation_agent, - self.answer_and_sources_agent, + self.parallel_query_solving_agent, + self.answer_agent, ] - if self.use_query_cache: - self.query_cache_agent = SqlQueryCacheAgent() - agents.append(self.query_cache_agent) - return agents @property @@ -137,63 +82,32 @@ def termination_condition(self): """Define the termination condition for the chat.""" termination = ( TextMentionTermination("TERMINATE") - | (TextMentionTermination("answer") & TextMentionTermination("sources")) - | MaxMessageTermination(20) + | SourceMatchTermination("answer_agent") + | TextMentionTermination("requires_user_information_request") + | MaxMessageTermination(5) ) return termination def unified_selector(self, messages): """Unified selector for the complete flow.""" logging.info("Messages: %s", messages) - current_agent = messages[-1].source if messages else "start" + current_agent = messages[-1].source if messages else "user" decision = None # If this is the first message start with query_rewrite_agent - if len(messages) == 1: + if current_agent == "user": decision = "query_rewrite_agent" # Handle transition after query rewriting elif current_agent == "query_rewrite_agent": - decision = ( - "sql_query_cache_agent" - if self.use_query_cache - else "sql_schema_selection_agent" - ) - # Handle subsequent agent transitions - elif current_agent == "sql_query_cache_agent": - # Always go through schema selection after cache check - decision = "sql_schema_selection_agent" - elif current_agent == "sql_schema_selection_agent": - decision = "sql_disambiguation_agent" - elif current_agent == "sql_disambiguation_agent": - decision = "sql_query_generation_agent" - elif current_agent == "sql_query_generation_agent": - decision = "sql_query_correction_agent" - elif current_agent == "sql_query_correction_agent": - try: - correction_result = json.loads(messages[-1].content) - if isinstance(correction_result, dict): - if "answer" in correction_result and 
"sources" in correction_result: - decision = "answer_and_sources_agent" - elif "corrected_query" in correction_result: - if correction_result.get("executing", False): - decision = "sql_query_correction_agent" - else: - decision = "sql_query_generation_agent" - elif "error" in correction_result: - decision = "sql_query_generation_agent" - elif isinstance(correction_result, list) and len(correction_result) > 0: - if "requested_fix" in correction_result[0]: - decision = "sql_query_generation_agent" - - if decision is None: - decision = "sql_query_generation_agent" - except json.JSONDecodeError: - decision = "sql_query_generation_agent" - elif current_agent == "answer_and_sources_agent": - decision = "user_proxy" # Let user_proxy send TERMINATE + decision = "parallel_query_solving_agent" + # Handle transition after parallel query solving + elif current_agent == "parallel_query_solving_agent": + decision = "answer_agent" if decision: logging.info(f"Agent transition: {current_agent} -> {decision}") + else: + logging.info(f"No agent transition defined from {current_agent}") return decision @@ -209,36 +123,115 @@ def agentic_flow(self): ) return flow + def extract_sources(self, messages: list) -> AnswerWithSources: + """Extract the sources from the answer.""" + + answer = messages[-1].content + + sql_query_results = messages[-2].content + + try: + sql_query_results = json.loads(sql_query_results) + + logging.info("SQL Query Results: %s", sql_query_results) + + sources = [] + + for question, sql_query_result_list in sql_query_results["results"].items(): + logging.info( + "SQL Query Result for question '%s': %s", + question, + sql_query_result_list, + ) + + for sql_query_result in sql_query_result_list: + logging.info("SQL Query Result: %s", sql_query_result) + sources.append( + { + "sql_query": sql_query_result["sql_query"], + "sql_rows": sql_query_result["sql_rows"], + } + ) + + except json.JSONDecodeError: + logging.error("Could not load message: %s", sql_query_results) + 
raise ValueError("Could not load message") + + return AnswerWithSources( + answer=answer, + sources=sources, + ) + async def process_question( self, - task: str, - chat_history: list[str] = None, - parameters: dict = None, - ): + question: str, + chat_history: list[ChatHistoryItem] = None, + injected_parameters: dict = None, + ) -> AsyncGenerator[AnswerWithSources | UserInformationRequest, None]: """Process the complete question through the unified system. Args: ---- task (str): The user question to process. chat_history (list[str], optional): The chat history. Defaults to None. - parameters (dict, optional): Parameters to pass to agents. Defaults to None. + injected_parameters (dict, optional): Parameters to pass to agents. Defaults to None. Returns: ------- dict: The response from the system. """ - logging.info("Processing question: %s", task) + logging.info("Processing question: %s", question) logging.info("Chat history: %s", chat_history) agent_input = { - "user_question": task, + "question": question, "chat_history": {}, - "parameters": parameters, + "injected_parameters": injected_parameters, } if chat_history is not None: # Update input for idx, chat in enumerate(chat_history): - agent_input[f"chat_{idx}"] = chat - - return self.agentic_flow.run_stream(task=json.dumps(agent_input)) + # For now only consider the user query + agent_input[f"chat_{idx}"] = chat.user_query + + async for message in self.agentic_flow.run_stream(task=json.dumps(agent_input)): + logging.debug("Message: %s", message) + + payload = None + + if isinstance(message, TextMessage): + if message.source == "query_rewrite_agent": + # If the message is from the query_rewrite_agent, we need to update the chat history + payload = ProcessingUpdate( + message="Rewriting the query...", + ) + elif message.source == "parallel_query_solving_agent": + # If the message is from the parallel_query_solving_agent, we need to update the chat history + payload = ProcessingUpdate( + message="Solving the 
query...", + ) + elif message.source == "answer_agent": + # If the message is from the answer_agent, we need to update the chat history + payload = ProcessingUpdate( + message="Generating the answer...", + ) + + elif isinstance(message, TaskResult): + # Now we need to return the final answer or the disambiguation request + logging.info("TaskResult: %s", message) + + if message.messages[-1].source == "answer_agent": + # If the message is from the answer_agent, we need to return the final answer + payload = self.extract_sources(message.messages) + elif message.messages[-1].source == "parallel_query_solving_agent": + payload = UserInformationRequest( + **json.loads(message.messages[-1].content), + ) + else: + logging.error("Unexpected TaskResult: %s", message) + raise ValueError("Unexpected TaskResult") + + if payload is not None: + logging.debug("Payload: %s", payload) + yield payload diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/answer_and_sources_agent.py b/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/answer_and_sources_agent.py deleted file mode 100644 index 3c1afcc..0000000 --- a/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/answer_and_sources_agent.py +++ /dev/null @@ -1,81 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. 
-from typing import AsyncGenerator, List, Sequence - -from autogen_agentchat.agents import BaseChatAgent -from autogen_agentchat.base import Response -from autogen_agentchat.messages import AgentMessage, ChatMessage, TextMessage -from autogen_core import CancellationToken -import json -from json import JSONDecodeError -import logging -import pandas as pd - - -class AnswerAndSourcesAgent(BaseChatAgent): - def __init__(self): - super().__init__( - "answer_and_sources_agent", - "An agent that formats the final answer and sources.", - ) - - @property - def produced_message_types(self) -> List[type[ChatMessage]]: - return [TextMessage] - - async def on_messages( - self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken - ) -> Response: - # Calls the on_messages_stream. - response: Response | None = None - async for message in self.on_messages_stream(messages, cancellation_token): - if isinstance(message, Response): - response = message - assert response is not None - return response - - async def on_messages_stream( - self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken - ) -> AsyncGenerator[AgentMessage | Response, None]: - last_response = messages[-1].content - - # Load the json of the last message to populate the final output object - final_output_object = json.loads(last_response) - final_output_object["sources"] = [] - - for message in messages: - # Load the message content if it is a json object and was a query execution - try: - message = json.loads(message.content) - logging.info(f"Loaded: {message}") - - # Search for specific message types and add them to the final output object - if ( - "type" in message - and message["type"] == "query_execution_with_limit" - ): - dataframe = pd.DataFrame(message["sql_rows"]) - final_output_object["sources"].append( - { - "sql_query": message["sql_query"].replace("\n", " "), - "sql_rows": message["sql_rows"], - "markdown_table": dataframe.to_markdown(index=False), - } - ) - - except 
JSONDecodeError: - logging.info(f"Could not load message: {message}") - continue - - except Exception as e: - logging.error(f"Error processing message: {e}") - raise e - - yield Response( - chat_message=TextMessage( - content=json.dumps(final_output_object), source=self.name - ) - ) - - async def on_reset(self, cancellation_token: CancellationToken) -> None: - pass diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py b/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py new file mode 100644 index 0000000..5a5c707 --- /dev/null +++ b/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py @@ -0,0 +1,153 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from typing import AsyncGenerator, List, Sequence + +from autogen_agentchat.agents import BaseChatAgent +from autogen_agentchat.base import Response, TaskResult +from autogen_agentchat.messages import AgentMessage, ChatMessage, TextMessage +from autogen_core import CancellationToken +import json +import logging +from autogen_text_2_sql.inner_autogen_text_2_sql import InnerAutoGenText2Sql +from aiostream import stream +from json import JSONDecodeError + + +class ParallelQuerySolvingAgent(BaseChatAgent): + def __init__(self, engine_specific_rules: str, **kwargs: dict): + super().__init__( + "parallel_query_solving_agent", + "An agent that solves each query in parallel.", + ) + + self.engine_specific_rules = engine_specific_rules + self.kwargs = kwargs + + @property + def produced_message_types(self) -> List[type[ChatMessage]]: + return [TextMessage] + + async def on_messages( + self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken + ) -> Response: + # Calls the on_messages_stream. 
+ response: Response | None = None + async for message in self.on_messages_stream(messages, cancellation_token): + if isinstance(message, Response): + response = message + assert response is not None + return response + + async def on_messages_stream( + self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken + ) -> AsyncGenerator[AgentMessage | Response, None]: + last_response = messages[-1].content + parameter_input = messages[0].content + try: + user_parameters = json.loads(parameter_input)["parameters"] + except json.JSONDecodeError: + logging.error("Error decoding the user parameters.") + user_parameters = {} + + # Load the json of the last message to populate the final output object + query_rewrites = json.loads(last_response) + + logging.info(f"Query Rewrites: {query_rewrites}") + + async def consume_inner_messages_from_agentic_flow( + agentic_flow, identifier, database_results + ): + """ + Consume the inner messages and append them to the specified list. + + Args: + ---- + agentic_flow: The async generator to consume messages from. + messages_list: The list to which messages should be added. 
+ """ + async for inner_message in agentic_flow: + # Add message to results dictionary, tagged by the function name + if identifier not in database_results: + database_results[identifier] = [] + + logging.info(f"Checking Inner Message: {inner_message}") + + if isinstance(inner_message, TaskResult) is False: + try: + inner_message = json.loads(inner_message.content) + logging.info(f"Loaded: {inner_message}") + + # Search for specific message types and add them to the final output object + if ( + "type" in inner_message + and inner_message["type"] == "query_execution_with_limit" + ): + database_results[identifier].append( + { + "sql_query": inner_message["sql_query"].replace( + "\n", " " + ), + "sql_rows": inner_message["sql_rows"], + } + ) + + except (JSONDecodeError, TypeError) as e: + logging.error("Could not load message: %s", inner_message) + logging.warning(f"Error processing message: {e}") + + except Exception as e: + logging.error("Could not load message: %s", inner_message) + logging.error(f"Error processing message: {e}") + raise e + + yield inner_message + + inner_solving_generators = [] + database_results = {} + + # Start processing sub-queries + for query_rewrite in query_rewrites["sub_queries"]: + logging.info(f"Processing sub-query: {query_rewrite}") + # Create an instance of the InnerAutoGenText2Sql class + inner_autogen_text_2_sql = InnerAutoGenText2Sql( + self.engine_specific_rules, **self.kwargs + ) + + # Launch tasks for each sub-query + inner_solving_generators.append( + consume_inner_messages_from_agentic_flow( + inner_autogen_text_2_sql.process_question( + question=query_rewrite, parameters=user_parameters + ), + query_rewrite, + database_results, + ) + ) + + logging.info( + "Created %i Inner Solving Generators", len(inner_solving_generators) + ) + logging.info("Starting Inner Solving Generators") + combined_message_streams = stream.merge(*inner_solving_generators) + + async with combined_message_streams.stream() as streamer: + async for 
inner_message in streamer: + if isinstance(inner_message, TextMessage): + logging.debug(f"Inner Solving Message: {inner_message}") + yield inner_message + + # Log final results for debugging or auditing + logging.info(f"Database Results: {database_results}") + + # Final response + yield Response( + chat_message=TextMessage( + content=json.dumps( + {"contains_results": True, "results": database_results} + ), + source=self.name, + ), + ) + + async def on_reset(self, cancellation_token: CancellationToken) -> None: + pass diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_query_cache_agent.py b/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_query_cache_agent.py index 0a96932..83fb13f 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_query_cache_agent.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_query_cache_agent.py @@ -43,9 +43,9 @@ async def on_messages_stream( last_response = messages[-1].content try: user_questions = json.loads(last_response) - user_parameters = json.loads(parameter_input)["parameters"] + injected_parameters = json.loads(parameter_input)["injected_parameters"] logging.info(f"Processing questions: {user_questions}") - logging.info(f"Input Parameters: {user_parameters}") + logging.info(f"Input Parameters: {injected_parameters}") # Initialize results dictionary cached_results = { @@ -58,7 +58,7 @@ async def on_messages_stream( # Fetch the queries from the cache based on the question logging.info(f"Fetching queries from cache for question: {question}") cached_query = await self.sql_connector.fetch_queries_from_cache( - question, parameters=user_parameters + question, injected_parameters=injected_parameters ) # If any question has pre-run results, set the flag diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/inner_autogen_text_2_sql.py b/text_2_sql/autogen/src/autogen_text_2_sql/inner_autogen_text_2_sql.py new file mode 100644 index 0000000..9c70a1b --- /dev/null +++ 
b/text_2_sql/autogen/src/autogen_text_2_sql/inner_autogen_text_2_sql.py @@ -0,0 +1,213 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from autogen_agentchat.conditions import ( + TextMentionTermination, + MaxMessageTermination, +) +from autogen_agentchat.teams import SelectorGroupChat +from autogen_text_2_sql.creators.llm_model_creator import LLMModelCreator +from autogen_text_2_sql.creators.llm_agent_creator import LLMAgentCreator +import logging +from autogen_text_2_sql.custom_agents.sql_query_cache_agent import ( + SqlQueryCacheAgent, +) +from autogen_text_2_sql.custom_agents.sql_schema_selection_agent import ( + SqlSchemaSelectionAgent, +) +from autogen_agentchat.agents import UserProxyAgent +from autogen_agentchat.messages import TextMessage +from autogen_agentchat.base import Response +import json +import os + + +class EmptyResponseUserProxyAgent(UserProxyAgent): + """UserProxyAgent that automatically responds with empty messages.""" + + def __init__(self, name): + super().__init__(name=name) + self._has_responded = False + + async def on_messages_stream(self, messages, sender=None, config=None): + """Auto-respond with empty message and return Response object.""" + message = TextMessage(content="", source=self.name) + if not self._has_responded: + self._has_responded = True + yield message + yield Response(chat_message=message) + + +class InnerAutoGenText2Sql: + def __init__(self, engine_specific_rules: str, **kwargs: dict): + self.pre_run_query_cache = False + self.target_engine = os.environ["Text2Sql__DatabaseEngine"].upper() + self.engine_specific_rules = engine_specific_rules + self.kwargs = kwargs + self.set_mode() + + def set_mode(self): + """Set the mode of the plugin based on the environment variables.""" + self.pre_run_query_cache = ( + os.environ.get("Text2Sql__PreRunQueryCache", "True").lower() == "true" + ) + self.use_column_value_store = ( + os.environ.get("Text2Sql__UseColumnValueStore", "True").lower() == "true" 
+ ) + self.use_query_cache = ( + os.environ.get("Text2Sql__UseQueryCache", "True").lower() == "true" + ) + + def get_all_agents(self): + """Get all agents for the complete flow.""" + # Get current datetime for the Query Rewrite Agent + self.sql_query_generation_agent = LLMAgentCreator.create( + "sql_query_generation_agent", + target_engine=self.target_engine, + engine_specific_rules=self.engine_specific_rules, + **self.kwargs, + ) + + # If relationship_paths not provided, use a generic template + if "relationship_paths" not in self.kwargs: + self.kwargs[ + "relationship_paths" + ] = """ + Common relationship paths to consider: + - Transaction → Related Dimensions (for basic analysis) + - Geographic → Location hierarchies (for geographic analysis) + - Temporal → Date hierarchies (for time-based analysis) + - Entity → Attributes (for entity-specific analysis) + """ + + self.sql_schema_selection_agent = SqlSchemaSelectionAgent( + target_engine=self.target_engine, + engine_specific_rules=self.engine_specific_rules, + **self.kwargs, + ) + + self.sql_query_correction_agent = LLMAgentCreator.create( + "sql_query_correction_agent", + target_engine=self.target_engine, + engine_specific_rules=self.engine_specific_rules, + **self.kwargs, + ) + + self.sql_disambiguation_agent = LLMAgentCreator.create( + "sql_disambiguation_agent", + target_engine=self.target_engine, + engine_specific_rules=self.engine_specific_rules, + **self.kwargs, + ) + + # Auto-responding UserProxyAgent + self.user_proxy = EmptyResponseUserProxyAgent(name="user_proxy") + + agents = [ + self.user_proxy, + self.sql_query_generation_agent, + self.sql_schema_selection_agent, + self.sql_query_correction_agent, + self.sql_disambiguation_agent, + ] + + if self.use_query_cache: + self.query_cache_agent = SqlQueryCacheAgent() + agents.append(self.query_cache_agent) + + return agents + + @property + def termination_condition(self): + """Define the termination condition for the chat.""" + termination = 
TextMentionTermination("TERMINATE") | MaxMessageTermination(10) + return termination + + def unified_selector(self, messages): + """Unified selector for the complete flow.""" + """Unified selector for the complete flow.""" + logging.info("Messages: %s", messages) + current_agent = messages[-1].source if messages else "user" + decision = None + + if current_agent == "user": + decision = ( + "sql_query_cache_agent" + if self.use_query_cache + else "sql_schema_selection_agent" + ) + # Handle subsequent agent transitions + elif current_agent == "sql_query_cache_agent": + # Always go through schema selection after cache check + decision = "sql_schema_selection_agent" + elif current_agent == "sql_schema_selection_agent": + decision = "sql_disambiguation_agent" + elif current_agent == "sql_disambiguation_agent": + decision = "sql_query_generation_agent" + elif current_agent == "sql_query_generation_agent": + decision = "sql_query_correction_agent" + elif current_agent == "sql_query_correction_agent": + try: + correction_result = json.loads(messages[-1].content) + if isinstance(correction_result, dict): + if "answer" in correction_result and "sources" in correction_result: + decision = "user_proxy" + elif "corrected_query" in correction_result: + if correction_result.get("executing", False): + decision = "sql_query_correction_agent" + else: + decision = "sql_query_generation_agent" + elif "error" in correction_result: + decision = "sql_query_generation_agent" + elif isinstance(correction_result, list) and len(correction_result) > 0: + if "requested_fix" in correction_result[0]: + decision = "sql_query_generation_agent" + + if decision is None: + decision = "sql_query_generation_agent" + except json.JSONDecodeError: + decision = "sql_query_generation_agent" + + if decision: + logging.info(f"Agent transition: {current_agent} -> {decision}") + else: + logging.info(f"No agent transition defined from {current_agent}") + + return decision + + @property + def agentic_flow(self): 
+ """Create the unified flow for the complete process.""" + flow = SelectorGroupChat( + self.get_all_agents(), + allow_repeated_speaker=False, + model_client=LLMModelCreator.get_model("4o-mini"), + termination_condition=self.termination_condition, + selector_func=self.unified_selector, + ) + return flow + + def process_question( + self, + question: str, + injected_parameters: dict = None, + ): + """Process the complete question through the unified system. + + Args: + ---- + task (str): The user question to process. + injected_parameters (dict, optional): Parameters to pass to agents. Defaults to None. + + Returns: + ------- + dict: The response from the system. + """ + logging.info("Processing question: %s", question) + + agent_input = { + "question": question, + "chat_history": {}, + "injected_parameters": injected_parameters, + } + + return self.agentic_flow.run_stream(task=json.dumps(agent_input)) diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py index 2309128..0ee6ead 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py @@ -143,7 +143,7 @@ async def query_validation( sqlglot.transpile( sql_query, read=self.database_engine.value.lower(), - error_level=sqlglot.ErrorLevel.ERROR, + error_level=sqlglot.ErrorLevel.RAISE, ) except sqlglot.errors.ParseError as e: logging.error("SQL Query is invalid: %s", e.errors) @@ -153,7 +153,7 @@ async def query_validation( return True async def fetch_queries_from_cache( - self, question: str, parameters: dict = None + self, question: str, injected_parameters: dict = None ) -> str: """Fetch the queries from the cache based on the question. @@ -166,21 +166,21 @@ async def fetch_queries_from_cache( str: The formatted string of the queries fetched from the cache. This is injected into the prompt. 
""" - if parameters is None: - parameters = {} + if injected_parameters is None: + injected_parameters = {} - # Populate the parameters - if "date" not in parameters: - parameters["date"] = self.get_current_date() + # Populate the injected_parameters + if "date" not in injected_parameters: + injected_parameters["date"] = self.get_current_date() - if "time" not in parameters: - parameters["time"] = self.get_current_time() + if "time" not in injected_parameters: + injected_parameters["time"] = self.get_current_time() - if "datetime" not in parameters: - parameters["datetime"] = self.get_current_datetime() + if "datetime" not in injected_parameters: + injected_parameters["datetime"] = self.get_current_datetime() - if "unix_timestamp" not in parameters: - parameters["unix_timestamp"] = self.get_current_unix_timestamp() + if "unix_timestamp" not in injected_parameters: + injected_parameters["unix_timestamp"] = self.get_current_unix_timestamp() cached_schemas = await self.ai_search_connector.run_ai_search_query( question, @@ -204,7 +204,7 @@ async def fetch_queries_from_cache( sql_queries = schema["SqlQueryDecomposition"] for sql_query in sql_queries: sql_query["SqlQuery"] = Template(sql_query["SqlQuery"]).render( - **parameters + **injected_parameters ) logging.info("Cached schemas: %s", cached_schemas) diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/__init__.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/__init__.py new file mode 100644 index 0000000..e3d590c --- /dev/null +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/__init__.py @@ -0,0 +1,12 @@ +from text_2_sql_core.payloads.answer_with_sources import AnswerWithSources, Source +from text_2_sql_core.payloads.user_information_request import UserInformationRequest +from text_2_sql_core.payloads.processing_update import ProcessingUpdate +from text_2_sql_core.payloads.chat_history import ChatHistoryItem + +__all__ = [ + "AnswerWithSources", + "Source", + 
"UserInformationRequest", + "ProcessingUpdate", + "ChatHistoryItem", +] diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/answer_with_sources.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/answer_with_sources.py new file mode 100644 index 0000000..650b0d4 --- /dev/null +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/answer_with_sources.py @@ -0,0 +1,11 @@ +from pydantic import BaseModel, Field + + +class Source(BaseModel): + sql_query: str + sql_rows: list[dict] + + +class AnswerWithSources(BaseModel): + answer: str + sources: list[Source] = Field(default_factory=list) diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/chat_history.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/chat_history.py new file mode 100644 index 0000000..06b27cb --- /dev/null +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/chat_history.py @@ -0,0 +1,9 @@ +from pydantic import BaseModel +from text_2_sql_core.payloads.answer_with_sources import AnswerWithSources + + +class ChatHistoryItem(BaseModel): + """Chat history item with user message and agent response.""" + + user_query: str + agent_response: AnswerWithSources diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/processing_update.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/processing_update.py new file mode 100644 index 0000000..3950508 --- /dev/null +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/processing_update.py @@ -0,0 +1,6 @@ +from pydantic import BaseModel, Field + + +class ProcessingUpdate(BaseModel): + title: str | None = Field(default="Processing...") + message: str | None = Field(default="Processing...") diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/user_information_request.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/user_information_request.py new file mode 100644 index 0000000..1aac4c0 --- /dev/null +++ 
b/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/user_information_request.py @@ -0,0 +1,26 @@ +from pydantic import BaseModel, RootModel, Field +from enum import StrEnum +from typing import Literal + + +class RequestType(StrEnum): + DISAMBIGUATION = "disambiguation" + CLARIFICATION = "clarification" + + +class ClarificationRequest(BaseModel): + request_type: Literal[RequestType.CLARIFICATION] + question: str + other_user_choices: list[str] + + +class DismabiguationRequest(BaseModel): + request_type: Literal[RequestType.DISAMBIGUATION] + question: str + matching_columns: list[str] + matching_filter_values: list[str] + other_user_choices: list[str] + + +class UserInformationRequest(RootModel): + root: ClarificationRequest | DismabiguationRequest = Field(..., discriminator="request_type") diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/answer_agent.yaml b/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/answer_agent.yaml new file mode 100644 index 0000000..8a4a797 --- /dev/null +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/answer_agent.yaml @@ -0,0 +1,12 @@ +model: "4o-mini" +description: "An agent that generates a response to a user's question." +system_message: | + + You are a helpful AI Assistant specializing in answering a user's question about {{ use_case }}. + + + Use the information obtained to generate a response to the user's question. The question has been broken down into a series of SQL queries and you need to generate a response based on the results of these queries. + + Do not use any external resources to generate the response. The response should be based solely on the information provided in the SQL queries and their results. + + You can use Markdown and Markdown tables to format the response. 
diff --git a/uv.lock b/uv.lock index e2acbb2..01d5125 100644 --- a/uv.lock +++ b/uv.lock @@ -164,6 +164,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/76/ac/a7305707cb852b7e16ff80eaf5692309bde30e2b1100a1fcacdc8f731d97/aiosignal-1.3.1-py3-none-any.whl", hash = "sha256:f8376fb07dd1e86a584e4fcdec80b36b7f81aac666ebc724e2c090300dd83b17", size = 7617 }, ] +[[package]] +name = "aiostream" +version = "0.6.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a5/fe/aa14603fcd5cc4333f81791bbdf58cd4cb8677c8e21e3cc691d27c00173f/aiostream-0.6.4.tar.gz", hash = "sha256:f99bc6b1b9cea3e70885dc235a233523597555fe4a585ed21d65264b3f1ff3d2", size = 67983 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/71/5c/639dc59441df1d5cec49a09c36eb89af651476760f950b7d018bdf0ec4a7/aiostream-0.6.4-py3-none-any.whl", hash = "sha256:bd8c6a8b90a52c0325a3b19406f0f2a131448e596c06398886f5be1c73b4cea9", size = 53665 }, +] + [[package]] name = "annotated-types" version = "0.7.0" @@ -347,6 +359,7 @@ name = "autogen-text-2-sql" version = "0.1.0" source = { virtual = "text_2_sql/autogen" } dependencies = [ + { name = "aiostream" }, { name = "autogen-agentchat" }, { name = "autogen-core" }, { name = "autogen-ext", extra = ["azure", "openai"] }, @@ -369,6 +382,7 @@ dev = [ [package.metadata] requires-dist = [ + { name = "aiostream", specifier = ">=0.6.4" }, { name = "autogen-agentchat", specifier = "==0.4.0.dev11" }, { name = "autogen-core", specifier = "==0.4.0.dev11" }, { name = "autogen-ext", extras = ["azure", "openai"], specifier = "==0.4.0.dev11" }, @@ -3301,6 +3315,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f1/7b/ce1eafaf1a76852e2ec9b22edecf1daa58175c090266e9f6c64afcd81d91/stack_data-0.6.3-py3-none-any.whl", hash = "sha256:d5558e0c25a4cb0853cddad3d77da9891a08cb85dd9f9f91b9f8cd66e511e695", size = 24521 }, ] +[[package]] +name = "tabulate" 
+version = "0.9.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ec/fe/802052aecb21e3797b8f7902564ab6ea0d60ff8ca23952079064155d1ae1/tabulate-0.9.0.tar.gz", hash = "sha256:0095b12bf5966de529c0feb1fa08671671b3368eec77d7ef7ab114be2c068b3c", size = 81090 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/40/44/4a5f08c96eb108af5cb50b41f76142f0afa346dfa99d5296fe7202a11854/tabulate-0.9.0-py3-none-any.whl", hash = "sha256:024ca478df22e9340661486f85298cff5f6dcdba14f3813e8830015b9ed1948f", size = 35252 }, +] + [[package]] name = "tenacity" version = "9.0.0" @@ -3337,11 +3360,13 @@ dependencies = [ { name = "networkx" }, { name = "numpy" }, { name = "openai" }, + { name = "pandas" }, { name = "pydantic" }, { name = "python-dotenv" }, { name = "pyyaml" }, { name = "rich" }, { name = "sqlglot", extra = ["rs"] }, + { name = "tabulate" }, { name = "tenacity" }, { name = "typer" }, ] @@ -3377,6 +3402,7 @@ requires-dist = [ { name = "networkx", specifier = ">=3.4.2" }, { name = "numpy", specifier = "<2.0.0" }, { name = "openai", specifier = ">=1.55.3" }, + { name = "pandas", specifier = ">=2.2.3" }, { name = "pyarrow", marker = "extra == 'databricks'", specifier = ">=14.0.2,<17" }, { name = "pydantic", specifier = ">=2.10.2" }, { name = "python-dotenv", specifier = ">=1.0.1" }, @@ -3384,6 +3410,7 @@ requires-dist = [ { name = "rich", specifier = ">=13.9.4" }, { name = "snowflake-connector-python", marker = "extra == 'snowflake'", specifier = ">=3.12.3" }, { name = "sqlglot", extras = ["rs"], specifier = ">=25.32.1" }, + { name = "tabulate", specifier = ">=0.9.0" }, { name = "tenacity", specifier = ">=9.0.0" }, { name = "typer", specifier = ">=0.14.0" }, ]