Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
"source": [
"import dotenv\n",
"import logging\n",
"from autogen_text_2_sql import AutoGenText2Sql, AgentRequestBody"
"from autogen_text_2_sql import AutoGenText2Sql, QuestionPayload"
]
},
{
Expand Down Expand Up @@ -100,16 +100,9 @@
"metadata": {},
"outputs": [],
"source": [
"async for message in agentic_text_2_sql.process_question(AgentRequestBody(question=\"What total number of orders in June 2008?\")):\n",
"async for message in agentic_text_2_sql.process_question(QuestionPayload(question=\"What total number of orders in June 2008?\")):\n",
" logging.info(\"Received %s Message from Text2SQL System\", message)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand Down
4 changes: 2 additions & 2 deletions text_2_sql/autogen/src/autogen_text_2_sql/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from autogen_text_2_sql.autogen_text_2_sql import AutoGenText2Sql
from text_2_sql_core.payloads.agent_request_response_pair import AgentRequestBody
from text_2_sql_core.payloads.interaction_payloads import QuestionPayload

__all__ = ["AutoGenText2Sql", "AgentRequestBody"]
__all__ = ["AutoGenText2Sql", "QuestionPayload"]
91 changes: 38 additions & 53 deletions text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,13 @@
import os
from datetime import datetime

from text_2_sql_core.payloads.agent_request_response_pair import (
AgentRequestResponsePair,
AgentRequestBody,
AnswerWithSources,
Source,
DismabiguationRequests,
)
from text_2_sql_core.payloads.chat_history import ChatHistoryItem
from text_2_sql_core.payloads.processing_update import (
ProcessingUpdateBody,
ProcessingUpdate,
from text_2_sql_core.payloads.interaction_payloads import (
QuestionPayload,
AnswerWithSourcesPayload,
DismabiguationRequestPayload,
ProcessingUpdatePayload,
InteractionPayload,
PayloadType,
)
from autogen_agentchat.base import TaskResult
from typing import AsyncGenerator
Expand Down Expand Up @@ -108,17 +104,19 @@ def agentic_flow(self):
)
return flow

def extract_disambiguation_request(self, messages: list) -> DismabiguationRequests:
def extract_disambiguation_request(
self, messages: list
) -> DismabiguationRequestPayload:
"""Extract the disambiguation request from the answer."""

disambiguation_request = messages[-1].content

# TODO: Properly extract the disambiguation request
return DismabiguationRequests(
return DismabiguationRequestPayload(
disambiguation_request=disambiguation_request,
)

def extract_sources(self, messages: list) -> AnswerWithSources:
def extract_sources(self, messages: list) -> AnswerWithSourcesPayload:
"""Extract the sources from the answer."""

answer = messages[-1].content
Expand All @@ -130,7 +128,7 @@ def extract_sources(self, messages: list) -> AnswerWithSources:

logging.info("SQL Query Results: %s", sql_query_results)

sources = []
payload = AnswerWithSourcesPayload(answer=answer)

for question, sql_query_result_list in sql_query_results["results"].items():
logging.info(
Expand All @@ -141,27 +139,24 @@ def extract_sources(self, messages: list) -> AnswerWithSources:

for sql_query_result in sql_query_result_list:
logging.info("SQL Query Result: %s", sql_query_result)
sources.append(
Source(
sql_query=sql_query_result["sql_query"],
sql_rows=sql_query_result["sql_rows"],
)
# Instantiate Source and append to the payload's sources list
source = AnswerWithSourcesPayload.Body.Source(
sql_query=sql_query_result["sql_query"],
sql_rows=sql_query_result["sql_rows"],
)
payload.body.sources.append(source)

return payload

except json.JSONDecodeError:
logging.error("Could not load message: %s", sql_query_results)
raise ValueError("Could not load message")

return AnswerWithSources(
answer=answer,
sources=sources,
)

async def process_question(
self,
request: AgentRequestBody,
chat_history: list[ChatHistoryItem] = None,
) -> AsyncGenerator[AgentRequestResponsePair | ProcessingUpdate, None]:
question_payload: QuestionPayload,
chat_history: list[InteractionPayload] = None,
) -> AsyncGenerator[InteractionPayload, None]:
"""Process the complete question through the unified system.

Args:
Expand All @@ -174,65 +169,55 @@ async def process_question(
-------
dict: The response from the system.
"""
logging.info("Processing question: %s", request.question)
logging.info("Processing question: %s", question_payload.body.question)
logging.info("Chat history: %s", chat_history)

agent_input = {
"question": request.question,
"question": question_payload.body.question,
"chat_history": {},
"injected_parameters": request.injected_parameters,
"injected_parameters": question_payload.body.injected_parameters,
}

if chat_history is not None:
# Update input
for idx, chat in enumerate(chat_history):
# For now only consider the user query
chat_history_key = f"chat_{idx}"
agent_input[
chat_history_key
] = chat.request_response_pair.request.question
if chat.root.payload_type == PayloadType.QUESTION:
# For now only consider the user query
chat_history_key = f"chat_{idx}"
agent_input[chat_history_key] = chat.root.body.question

async for message in self.agentic_flow.run_stream(task=json.dumps(agent_input)):
logging.debug("Message: %s", message)

payload = None

if isinstance(message, TextMessage):
processing_update = None
if message.source == "query_rewrite_agent":
processing_update = ProcessingUpdateBody(
payload = ProcessingUpdatePayload(
message="Rewriting the query...",
)
elif message.source == "parallel_query_solving_agent":
processing_update = ProcessingUpdateBody(
payload = ProcessingUpdatePayload(
message="Solving the query...",
)
elif message.source == "answer_agent":
processing_update = ProcessingUpdateBody(
payload = ProcessingUpdatePayload(
message="Generating the answer...",
)

if processing_update is not None:
payload = ProcessingUpdate(
processing_update=processing_update,
)

elif isinstance(message, TaskResult):
# Now we need to return the final answer or the disambiguation request
logging.info("TaskResult: %s", message)

response = None
if message.messages[-1].source == "answer_agent":
# If the message is from the answer_agent, we need to return the final answer
response = self.extract_sources(message.messages)
payload = self.extract_sources(message.messages)
elif message.messages[-1].source == "parallel_query_solving_agent":
# Load into disambiguation request
response = self.extract_disambiguation_request(message.messages)
else:
logging.error("Unexpected TaskResult: %s", message)
raise ValueError("Unexpected TaskResult")

payload = AgentRequestResponsePair(request=request, response=response)
payload = self.extract_disambiguation_request(message.messages)
else:
logging.error("Unexpected TaskResult: %s", message)
raise ValueError("Unexpected TaskResult")

if payload is not None:
logging.debug("Payload: %s", payload)
Expand Down

This file was deleted.

This file was deleted.

Loading