From 98b3c6846fe8b0e9fdec91d2e46b129f35169f29 Mon Sep 17 00:00:00 2001 From: Mig <104501046+minhyeong112@users.noreply.github.com> Date: Thu, 12 Dec 2024 22:31:36 +0000 Subject: [PATCH] Fix #90: Implement Query Rewrite Agent for comprehensive preprocessing - Handles relative date disambiguation (e.g., 'last month' to actual dates) and question decomposition in a single preprocessing step before cache lookup - Replaces previous question_decomposition_agent with more capable query_rewrite_agent - Updates documentation to reflect current processing flow --- text_2_sql/autogen/README.md | 21 +-- .../autogen_text_2_sql/autogen_text_2_sql.py | 135 +++++++++++------- .../custom_agents/sql_query_cache_agent.py | 49 +++++-- .../prompts/query_rewrite_agent.yaml | 46 ++++++ 4 files changed, 180 insertions(+), 71 deletions(-) create mode 100644 text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/query_rewrite_agent.yaml diff --git a/text_2_sql/autogen/README.md b/text_2_sql/autogen/README.md index bc6e876..189f84e 100644 --- a/text_2_sql/autogen/README.md +++ b/text_2_sql/autogen/README.md @@ -8,7 +8,7 @@ The implementation is written for [AutoGen](https://github.com/microsoft/autogen ## Full Logical Flow for Agentic Vector Based Approach -The following diagram shows the logical flow within mutlti agent system. In an ideal scenario, the questions will follow the _Pre-Fetched Cache Results Path** which leads to the quickest answer generation. In cases where the question is not known, the group chat selector will fall back to the other agents accordingly and generate the SQL query using the LLMs. The cache is then updated with the newly generated query and schemas. +The following diagram shows the logical flow within multi agent system. The flow begins with query rewriting to preprocess questions - this includes resolving relative dates (e.g., "last month" to "November 2024") and breaking down complex queries into simpler components. For each preprocessed question, if query cache is enabled, the system checks the cache for previously asked similar questions. In an ideal scenario, the preprocessed questions will be found in the cache, leading to the quickest answer generation. In cases where the question is not known, the group chat selector will fall back to the other agents accordingly and generate the SQL query using the LLMs. The cache is then updated with the newly generated query and schemas. Unlike the previous approaches, **gpt4o-mini** can be used as each agent's prompt is small and focuses on a single simple task. @@ -24,26 +24,31 @@ As the query cache is shared between users (no data is stored in the cache), a n ## Agents -This approach builds on the the Vector Based SQL Plugin approach, but adds a agentic approach to the solution. +This approach builds on the Vector Based SQL Plugin approach, but adds a agentic approach to the solution. This agentic system contains the following agents: -- **Query Cache Agent:** Responsible for checking the cache for previously asked questions. -- **Query Decomposition Agent:** Responsible for decomposing complex questions, into sub questions that can be answered with SQL. -- **Schema Selection Agent:** Responsible for extracting key terms from the question and checking the index store for the queries. +- **Query Rewrite Agent:** The first agent in the flow, responsible for two key preprocessing tasks: + 1. Resolving relative dates to absolute dates (e.g., "last month" → "November 2024") + 2. Decomposing complex questions into simpler sub-questions + This preprocessing happens before cache lookup to maximize cache effectiveness. +- **Query Cache Agent:** Responsible for checking the cache for previously asked questions. After preprocessing, each sub-question is checked against the cache if caching is enabled. +- **Schema Selection Agent:** Responsible for extracting key terms from the question and checking the index store for the queries. This agent is used when a cache miss occurs. - **SQL Query Generation Agent:** Responsible for using the previously extracted schemas and generated SQL queries to answer the question. This agent can request more schemas if needed. This agent will run the query. - **SQL Query Verification Agent:** Responsible for verifying that the SQL query and results question will answer the question. - **Answer Generation Agent:** Responsible for taking the database results and generating the final answer for the user. -The combination of this agent allows the system to answer complex questions, whilst staying under the token limits when including the database schemas. The query cache ensures that previously asked questions, can be answered quickly to avoid degrading user experience. +The combination of these agents allows the system to answer complex questions, whilst staying under the token limits when including the database schemas. The query cache ensures that previously asked questions can be answered quickly to avoid degrading user experience. All agents can be found in `/agents/`. ## agentic_text_2_sql.py -This is the main entry point for the agentic system. In here, the `Selector Group Chat` is configured with the termination conditions to orchestrate the agents within the system. +This is the main entry point for the agentic system. In here, the system is configured with the following processing flow: -A customer transition selector is used to automatically transition between agents dependent on the last one that was used. In some cases, this choice is delegated to an LLM to decide on the most appropriate action. This mixed approach allows for speed when needed (e.g. always calling Query Cache Agent first), but will allow the system to react dynamically to the events. +The preprocessed questions from the Query Rewrite Agent are processed sequentially through the rest of the agent pipeline. A custom transition selector automatically transitions between agents dependent on the last one that was used. The flow starts with the Query Rewrite Agent for preprocessing, followed by cache checking for each sub-question if caching is enabled. In some cases, this choice is delegated to an LLM to decide on the most appropriate action. This mixed approach allows for speed when needed (e.g. cache hits for known questions), but will allow the system to react dynamically to the events. + +Note: Future development aims to implement independent processing where each preprocessed question would run in its own isolated context to prevent confusion between different parts of complex queries. ## Utils diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py b/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py index a7f36db..629be98 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py @@ -13,9 +13,27 @@ from autogen_text_2_sql.custom_agents.sql_schema_selection_agent import ( SqlSchemaSelectionAgent, ) +from autogen_agentchat.agents import UserProxyAgent +from autogen_agentchat.messages import TextMessage +from autogen_agentchat.base import Response import json import os - +import asyncio +from datetime import datetime + +class EmptyResponseUserProxyAgent(UserProxyAgent): + """UserProxyAgent that automatically responds with empty messages.""" + def __init__(self, name): + super().__init__(name=name) + self._has_responded = False + + async def on_messages_stream(self, messages, sender=None, config=None): + """Auto-respond with empty message and return Response object.""" + message = TextMessage(content="", source=self.name) + if not self._has_responded: + self._has_responded = True + yield message + yield Response(chat_message=message) class AutoGenText2Sql: def __init__(self, engine_specific_rules: str, **kwargs: dict): @@ -43,45 +61,58 @@ def set_mode(self): os.environ.get("Text2Sql__UseColumnValueStore", "False").lower() == "true" ) - @property - def agents(self): - """Define the agents for the chat.""" + def get_all_agents(self): + """Get all agents for the complete flow.""" + # Get current datetime for the Query Rewrite Agent + current_datetime = datetime.now() + + QUERY_REWRITE_AGENT = LLMAgentCreator.create( + "query_rewrite_agent", + current_datetime=current_datetime + ) + SQL_QUERY_GENERATION_AGENT = LLMAgentCreator.create( "sql_query_generation_agent", target_engine=self.target_engine, engine_specific_rules=self.engine_specific_rules, **self.kwargs, ) + SQL_SCHEMA_SELECTION_AGENT = SqlSchemaSelectionAgent( target_engine=self.target_engine, engine_specific_rules=self.engine_specific_rules, **self.kwargs, ) + SQL_QUERY_CORRECTION_AGENT = LLMAgentCreator.create( "sql_query_correction_agent", target_engine=self.target_engine, engine_specific_rules=self.engine_specific_rules, **self.kwargs, ) + SQL_DISAMBIGUATION_AGENT = LLMAgentCreator.create( "sql_disambiguation_agent", target_engine=self.target_engine, engine_specific_rules=self.engine_specific_rules, **self.kwargs, ) - + ANSWER_AGENT = LLMAgentCreator.create("answer_agent") - QUESTION_DECOMPOSITION_AGENT = LLMAgentCreator.create( - "question_decomposition_agent" + + # Auto-responding UserProxyAgent + USER_PROXY = EmptyResponseUserProxyAgent( + name="user_proxy" ) agents = [ + USER_PROXY, + QUERY_REWRITE_AGENT, SQL_QUERY_GENERATION_AGENT, SQL_SCHEMA_SELECTION_AGENT, SQL_QUERY_CORRECTION_AGENT, - ANSWER_AGENT, - QUESTION_DECOMPOSITION_AGENT, SQL_DISAMBIGUATION_AGENT, + ANSWER_AGENT, ] if self.use_query_cache: @@ -101,67 +132,65 @@ def termination_condition(self): return termination @staticmethod - def selector(messages): + def unified_selector(messages): + """Unified selector for the complete flow.""" logging.info("Messages: %s", messages) - decision = None # Initialize decision variable + decision = None + # If this is the first message, start with query_rewrite_agent if len(messages) == 1: - decision = "sql_query_cache_agent" - - elif ( - messages[-1].source == "sql_query_cache_agent" - and messages[-1].content is not None - ): - cache_result = json.loads(messages[-1].content) - if cache_result.get( - "cached_questions_and_schemas" - ) is not None and cache_result.get("contains_pre_run_results"): - decision = "sql_query_correction_agent" - if ( - cache_result.get("cached_questions_and_schemas") is not None - and cache_result.get("contains_pre_run_results") is False - ): - decision = "sql_query_generation_agent" - else: - decision = "question_decomposition_agent" - - elif messages[-1].source == "question_decomposition_agent": - decision = "sql_schema_selection_agent" + return "query_rewrite_agent" + # Handle transition after query rewriting + if messages[-1].source == "query_rewrite_agent": + # Keep the array structure but process sequentially + if os.environ.get("Text2Sql__UseQueryCache", "False").lower() == "true": + decision = "sql_query_cache_agent" + else: + decision = "sql_schema_selection_agent" + # Handle subsequent agent transitions + elif messages[-1].source == "sql_query_cache_agent": + try: + cache_result = json.loads(messages[-1].content) + if cache_result.get("cached_questions_and_schemas") is not None: + if cache_result.get("contains_pre_run_results"): + decision = "sql_query_correction_agent" + else: + decision = "sql_query_generation_agent" + else: + decision = "sql_schema_selection_agent" + except json.JSONDecodeError: + decision = "sql_schema_selection_agent" elif messages[-1].source == "sql_schema_selection_agent": decision = "sql_disambiguation_agent" - elif messages[-1].source == "sql_disambiguation_agent": - # This would be user proxy agent tbc decision = "sql_query_generation_agent" - - elif ( - messages[-1].source == "sql_query_correction_agent" - and messages[-1].content == "VALIDATED" - ): - decision = "answer_agent" - - elif messages[-1].source == "sql_query_correction_agent": + elif messages[-1].source == "sql_query_generation_agent": decision = "sql_query_correction_agent" + elif messages[-1].source == "sql_query_correction_agent": + if messages[-1].content == "VALIDATED": + decision = "answer_agent" + else: + decision = "sql_query_correction_agent" + elif messages[-1].source == "answer_agent": + return "user_proxy" # Let user_proxy send TERMINATE - # Log the decision logging.info("Decision: %s", decision) - return decision @property def agentic_flow(self): - """Run the agentic flow for the given question. - - Args: - ---- - question (str): The question to run the agentic flow on.""" - agentic_flow = SelectorGroupChat( - self.agents, + """Create the unified flow for the complete process.""" + flow = SelectorGroupChat( + self.get_all_agents(), allow_repeated_speaker=False, model_client=LLMModelCreator.get_model("4o-mini"), termination_condition=self.termination_condition, - selector_func=AutoGenText2Sql.selector, + selector_func=AutoGenText2Sql.unified_selector, ) + return flow - return agentic_flow + async def process_question(self, task: str): + """Process the complete question through the unified system.""" + result = await self.agentic_flow.run_stream(task=task) + return result diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_query_cache_agent.py b/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_query_cache_agent.py index 49e0730..f6d2f04 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_query_cache_agent.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_query_cache_agent.py @@ -38,20 +38,49 @@ async def on_messages( async def on_messages_stream( self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken ) -> AsyncGenerator[AgentMessage | Response, None]: - user_question = messages[0].content + # Get the decomposed questions from the query_rewrite_agent + last_response = messages[-1].content + try: + user_questions = json.loads(last_response) + logging.info(f"Processing questions: {user_questions}") - # Fetch the queries from the cache based on the user question. - logging.info("Fetching queries from cache based on the user question...") + # Initialize results dictionary + cached_results = { + "cached_questions_and_schemas": [], + "contains_pre_run_results": False + } - cached_queries = await self.sql_connector.fetch_queries_from_cache( - user_question - ) + # Process each question sequentially + for question in user_questions: + # Fetch the queries from the cache based on the question + logging.info(f"Fetching queries from cache for question: {question}") + cached_query = await self.sql_connector.fetch_queries_from_cache(question) + + # If any question has pre-run results, set the flag + if cached_query.get("contains_pre_run_results", False): + cached_results["contains_pre_run_results"] = True + + # Add the cached results for this question + if cached_query.get("cached_questions_and_schemas"): + cached_results["cached_questions_and_schemas"].extend( + cached_query["cached_questions_and_schemas"] + ) - yield Response( - chat_message=TextMessage( - content=json.dumps(cached_queries), source=self.name + logging.info(f"Final cached results: {cached_results}") + yield Response( + chat_message=TextMessage( + content=json.dumps(cached_results), source=self.name + ) + ) + except json.JSONDecodeError: + # If not JSON array, process as single question + logging.info(f"Processing single question: {last_response}") + cached_queries = await self.sql_connector.fetch_queries_from_cache(last_response) + yield Response( + chat_message=TextMessage( + content=json.dumps(cached_queries), source=self.name + ) ) - ) async def on_reset(self, cancellation_token: CancellationToken) -> None: pass diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/query_rewrite_agent.yaml b/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/query_rewrite_agent.yaml new file mode 100644 index 0000000..bb1f9d5 --- /dev/null +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/query_rewrite_agent.yaml @@ -0,0 +1,46 @@ +model: + 4o-mini +description: + "An agent that preprocesses user questions by decomposing complex queries and resolving relative dates. This preprocessing happens before cache lookup to maximize cache utility." +system_message: + "You are a helpful AI Assistant that specializes in preprocessing user questions for SQL query generation. You have two main responsibilities: + + 1. Decompose complex questions into simpler parts + 2. Resolve any relative date references to absolute dates + + Current date/time is: {{ current_datetime }} + + For date resolution: + - Use the current date/time above as reference point + - Replace relative dates like 'last month', 'this year', 'previous quarter' with absolute dates + - Maintain consistency in date formats (YYYY-MM-DD) + + Examples of date resolution (assuming current date is {{ current_datetime }}): + - 'last month' -> specific month name and year + - 'this year' -> {{ current_datetime.year }} + - 'last 3 months' -> specific date range + - 'yesterday' -> specific date + + Rules: + 1. ALWAYS resolve relative dates before decomposing questions + 2. If a question contains multiple parts AND relative dates, resolve dates first, then decompose + 3. Each decomposed question should be self-contained and not depend on context from other parts + 4. Do not reference the original question in decomposed parts + 5. Ensure each decomposed question includes its full context + + Output Format: + Return an array of rewritten questions in valid, loadable JSON: + [\"\", \"\"] + + If the question is simple and doesn't need decomposition (but might need date resolution): + [\"\"] + + Examples: + Input: 'How much did we make in sales last month and what were our top products?' + Output: [\"How much did we make in sales in November 2024?\", \"What were our top products in November 2024?\"] + + Input: 'What were total sales last quarter?' + Output: [\"What were total sales in Q4 2024 (October 2024 to December 2024)?\"] + + Input: 'Show me customer details' + Output: [\"Show me customer details\"]"