diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py b/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py index 83bc323..742222e 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py @@ -16,6 +16,7 @@ import json import os from datetime import datetime +import re from text_2_sql_core.payloads.interaction_payloads import ( QuestionPayload, @@ -108,49 +109,69 @@ def extract_disambiguation_request( self, messages: list ) -> DismabiguationRequestPayload: """Extract the disambiguation request from the answer.""" - disambiguation_request = messages[-1].content - - # TODO: Properly extract the disambiguation request return DismabiguationRequestPayload( disambiguation_request=disambiguation_request, ) + def parse_message_content(self, content): + """Parse different message content formats into a dictionary.""" + if isinstance(content, (list, dict)): + # If it's already a list or dict, convert to JSON string + return json.dumps(content) + + # Try to extract JSON from markdown-style code blocks + json_match = re.search(r"```json\s*(.*?)\s*```", content, re.DOTALL) + if json_match: + try: + return json.loads(json_match.group(1)) + except json.JSONDecodeError: + pass + + # Try parsing as regular JSON + try: + return json.loads(content) + except json.JSONDecodeError: + pass + + # If all parsing attempts fail, return the content as-is + return content + def extract_sources(self, messages: list) -> AnswerWithSourcesPayload: """Extract the sources from the answer.""" - answer = messages[-1].content - - sql_query_results = messages[-2].content + sql_query_results = self.parse_message_content(messages[-2].content) try: - sql_query_results = json.loads(sql_query_results) + if isinstance(sql_query_results, str): + sql_query_results = json.loads(sql_query_results) logging.info("SQL Query Results: %s", sql_query_results) - payload = AnswerWithSourcesPayload(answer=answer) - for question, sql_query_result_list in sql_query_results["results"].items(): - logging.info( - "SQL Query Result for question '%s': %s", - question, - sql_query_result_list, - ) - - for sql_query_result in sql_query_result_list: - logging.info("SQL Query Result: %s", sql_query_result) - # Instantiate Source and append to the payload's sources list - source = AnswerWithSourcesPayload.Body.Source( - sql_query=sql_query_result["sql_query"], - sql_rows=sql_query_result["sql_rows"], + if isinstance(sql_query_results, dict) and "results" in sql_query_results: + for question, sql_query_result_list in sql_query_results[ + "results" + ].items(): + logging.info( + "SQL Query Result for question '%s': %s", + question, + sql_query_result_list, ) - payload.body.sources.append(source) + + for sql_query_result in sql_query_result_list: + logging.info("SQL Query Result: %s", sql_query_result) + source = AnswerWithSourcesPayload.Body.Source( + sql_query=sql_query_result["sql_query"], + sql_rows=sql_query_result["sql_rows"], + ) + payload.body.sources.append(source) return payload - except json.JSONDecodeError: - logging.error("Could not load message: %s", sql_query_results) - raise ValueError("Could not load message") + except Exception as e: + logging.error("Error processing results: %s", str(e)) + return AnswerWithSourcesPayload(answer=answer) async def process_question( self, diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py b/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py index 2be891c..e05cbb0 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py @@ -38,6 +38,37 @@ async def on_messages( assert response is not None return response + def parse_inner_message(self, message): + """Parse inner message content into a structured format.""" + try: + if isinstance(message, (dict, list)): + return message + + if not isinstance(message, str): + message = str(message) + + # Try to parse as JSON first + try: + return json.loads(message) + except JSONDecodeError: + pass + + # Try to extract JSON from markdown code blocks + import re + + json_match = re.search(r"```json\s*(.*?)\s*```", message, re.DOTALL) + if json_match: + try: + return json.loads(json_match.group(1)) + except JSONDecodeError: + pass + + # If we can't parse it, return it as-is + return message + except Exception as e: + logging.warning(f"Error parsing message: {e}") + return message + async def on_messages_stream( self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken ) -> AsyncGenerator[AgentMessage | Response, None]: @@ -74,46 +105,42 @@ async def consume_inner_messages_from_agentic_flow( if isinstance(inner_message, TaskResult) is False: try: - inner_message = json.loads(inner_message.content) - logging.info(f"Inner Loaded: {inner_message}") + parsed_message = self.parse_inner_message(inner_message.content) + logging.info(f"Inner Loaded: {parsed_message}") # Search for specific message types and add them to the final output object - if ( - "type" in inner_message - and inner_message["type"] == "query_execution_with_limit" - ): - database_results[identifier].append( - { - "sql_query": inner_message["sql_query"].replace( - "\n", " " - ), - "sql_rows": inner_message["sql_rows"], - } - ) - - if ("contains_pre_run_results" in inner_message) and ( - inner_message["contains_pre_run_results"] is True - ): - for pre_run_sql_query, pre_run_result in inner_message[ - "cached_questions_and_schemas" - ].items(): + if isinstance(parsed_message, dict): + if ( + "type" in parsed_message + and parsed_message["type"] + == "query_execution_with_limit" + ): database_results[identifier].append( { - "sql_query": pre_run_sql_query.replace( - "\n", " " - ), - "sql_rows": pre_run_result["sql_rows"], + "sql_query": parsed_message[ + "sql_query" + ].replace("\n", " "), + "sql_rows": parsed_message["sql_rows"], } ) - except (JSONDecodeError, TypeError) as e: - logging.error("Could not load message: %s", inner_message) - logging.warning(f"Error processing message: {e}") + if ("contains_pre_run_results" in parsed_message) and ( + parsed_message["contains_pre_run_results"] is True + ): + for pre_run_sql_query, pre_run_result in parsed_message[ + "cached_questions_and_schemas" + ].items(): + database_results[identifier].append( + { + "sql_query": pre_run_sql_query.replace( + "\n", " " + ), + "sql_rows": pre_run_result["sql_rows"], + } + ) except Exception as e: - logging.error("Could not load message: %s", inner_message) - logging.error(f"Error processing message: {e}") - raise e + logging.warning(f"Error processing message: {e}") yield inner_message