diff --git a/text_2_sql/autogen/Iteration 5 - Agentic Vector Based Text2SQL.ipynb b/text_2_sql/autogen/Iteration 5 - Agentic Vector Based Text2SQL.ipynb index d4f22d4..23b2a91 100644 --- a/text_2_sql/autogen/Iteration 5 - Agentic Vector Based Text2SQL.ipynb +++ b/text_2_sql/autogen/Iteration 5 - Agentic Vector Based Text2SQL.ipynb @@ -50,7 +50,7 @@ "source": [ "import dotenv\n", "import logging\n", - "from autogen_text_2_sql import AutoGenText2Sql, AgentRequestBody" + "from autogen_text_2_sql import AutoGenText2Sql, QuestionPayload" ] }, { @@ -100,16 +100,9 @@ "metadata": {}, "outputs": [], "source": [ - "async for message in agentic_text_2_sql.process_question(AgentRequestBody(question=\"What total number of orders in June 2008?\")):\n", + "async for message in agentic_text_2_sql.process_question(QuestionPayload(question=\"What total number of orders in June 2008?\")):\n", " logging.info(\"Received %s Message from Text2SQL System\", message)" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/__init__.py b/text_2_sql/autogen/src/autogen_text_2_sql/__init__.py index 3c3046b..cda320f 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/__init__.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/__init__.py @@ -1,6 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. from autogen_text_2_sql.autogen_text_2_sql import AutoGenText2Sql -from text_2_sql_core.payloads.agent_request_response_pair import AgentRequestBody +from text_2_sql_core.payloads.interaction_payloads import QuestionPayload -__all__ = ["AutoGenText2Sql", "AgentRequestBody"] +__all__ = ["AutoGenText2Sql", "QuestionPayload"] diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py b/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py index 40a0160..83bc323 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py @@ -17,17 +17,13 @@ import os from datetime import datetime -from text_2_sql_core.payloads.agent_request_response_pair import ( - AgentRequestResponsePair, - AgentRequestBody, - AnswerWithSources, - Source, - DismabiguationRequests, -) -from text_2_sql_core.payloads.chat_history import ChatHistoryItem -from text_2_sql_core.payloads.processing_update import ( - ProcessingUpdateBody, - ProcessingUpdate, +from text_2_sql_core.payloads.interaction_payloads import ( + QuestionPayload, + AnswerWithSourcesPayload, + DismabiguationRequestPayload, + ProcessingUpdatePayload, + InteractionPayload, + PayloadType, ) from autogen_agentchat.base import TaskResult from typing import AsyncGenerator @@ -108,17 +104,19 @@ def agentic_flow(self): ) return flow - def extract_disambiguation_request(self, messages: list) -> DismabiguationRequests: + def extract_disambiguation_request( + self, messages: list + ) -> DismabiguationRequestPayload: """Extract the disambiguation request from the answer.""" disambiguation_request = messages[-1].content # TODO: Properly extract the disambiguation request - return DismabiguationRequests( + return DismabiguationRequestPayload( disambiguation_request=disambiguation_request, ) - def extract_sources(self, messages: list) -> AnswerWithSources: + def extract_sources(self, messages: list) -> AnswerWithSourcesPayload: """Extract the sources from the answer.""" answer = messages[-1].content @@ -130,7 +128,7 @@ def extract_sources(self, messages: list) -> AnswerWithSources: logging.info("SQL Query Results: %s", sql_query_results) - sources = [] + payload = AnswerWithSourcesPayload(answer=answer) for question, sql_query_result_list in sql_query_results["results"].items(): logging.info( @@ -141,27 +139,24 @@ def extract_sources(self, messages: list) -> AnswerWithSources: for sql_query_result in sql_query_result_list: logging.info("SQL Query Result: %s", sql_query_result) - sources.append( - Source( - sql_query=sql_query_result["sql_query"], - sql_rows=sql_query_result["sql_rows"], - ) + # Instantiate Source and append to the payload's sources list + source = AnswerWithSourcesPayload.Body.Source( + sql_query=sql_query_result["sql_query"], + sql_rows=sql_query_result["sql_rows"], ) + payload.body.sources.append(source) + + return payload except json.JSONDecodeError: logging.error("Could not load message: %s", sql_query_results) raise ValueError("Could not load message") - return AnswerWithSources( - answer=answer, - sources=sources, - ) - async def process_question( self, - request: AgentRequestBody, - chat_history: list[ChatHistoryItem] = None, - ) -> AsyncGenerator[AgentRequestResponsePair | ProcessingUpdate, None]: + question_payload: QuestionPayload, + chat_history: list[InteractionPayload] = None, + ) -> AsyncGenerator[InteractionPayload, None]: """Process the complete question through the unified system. Args: @@ -174,23 +169,22 @@ async def process_question( ------- dict: The response from the system. """ - logging.info("Processing question: %s", request.question) + logging.info("Processing question: %s", question_payload.body.question) logging.info("Chat history: %s", chat_history) agent_input = { - "question": request.question, + "question": question_payload.body.question, "chat_history": {}, - "injected_parameters": request.injected_parameters, + "injected_parameters": question_payload.body.injected_parameters, } if chat_history is not None: # Update input for idx, chat in enumerate(chat_history): - # For now only consider the user query - chat_history_key = f"chat_{idx}" - agent_input[ - chat_history_key - ] = chat.request_response_pair.request.question + if chat.root.payload_type == PayloadType.QUESTION: + # For now only consider the user query + chat_history_key = f"chat_{idx}" + agent_input[chat_history_key] = chat.root.body.question async for message in self.agentic_flow.run_stream(task=json.dumps(agent_input)): logging.debug("Message: %s", message) @@ -198,41 +192,32 @@ async def process_question( payload = None if isinstance(message, TextMessage): - processing_update = None if message.source == "query_rewrite_agent": - processing_update = ProcessingUpdateBody( + payload = ProcessingUpdatePayload( message="Rewriting the query...", ) elif message.source == "parallel_query_solving_agent": - processing_update = ProcessingUpdateBody( + payload = ProcessingUpdatePayload( message="Solving the query...", ) elif message.source == "answer_agent": - processing_update = ProcessingUpdateBody( + payload = ProcessingUpdatePayload( message="Generating the answer...", ) - if processing_update is not None: - payload = ProcessingUpdate( - processing_update=processing_update, - ) - elif isinstance(message, TaskResult): # Now we need to return the final answer or the disambiguation request logging.info("TaskResult: %s", message) - response = None if message.messages[-1].source == "answer_agent": # If the message is from the answer_agent, we need to return the final answer - response = self.extract_sources(message.messages) + payload = self.extract_sources(message.messages) elif message.messages[-1].source == "parallel_query_solving_agent": # Load into disambiguation request - response = self.extract_disambiguation_request(message.messages) - else: - logging.error("Unexpected TaskResult: %s", message) - raise ValueError("Unexpected TaskResult") - - payload = AgentRequestResponsePair(request=request, response=response) + payload = self.extract_disambiguation_request(message.messages) + else: + logging.error("Unexpected TaskResult: %s", message) + raise ValueError("Unexpected TaskResult") if payload is not None: logging.debug("Payload: %s", payload) diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/agent_request_response_pair.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/agent_request_response_pair.py deleted file mode 100644 index 0126eff..0000000 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/agent_request_response_pair.py +++ /dev/null @@ -1,89 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -from pydantic import BaseModel, RootModel, Field, model_validator -from enum import StrEnum - -from typing import Literal -from datetime import datetime, timezone - - -class AgentRequestResponseHeader(BaseModel): - prompt_tokens: int - completion_tokens: int - timestamp: datetime = Field( - ..., - description="Timestamp in UTC", - default_factory=lambda: datetime.now(timezone.utc), - ) - - -class AgentResponseType(StrEnum): - ANSWER_WITH_SOURCES = "answer_with_sources" - DISAMBIGUATION = "disambiguation" - - -class DismabiguationRequest(BaseModel): - question: str - matching_columns: list[str] - matching_filter_values: list[str] - other_user_choices: list[str] - - -class DismabiguationRequests(BaseModel): - response_type: Literal[AgentResponseType.DISAMBIGUATION] = Field( - default=AgentResponseType.DISAMBIGUATION - ) - requests: list[DismabiguationRequest] - - -class Source(BaseModel): - sql_query: str - sql_rows: list[dict] - - -class AnswerWithSources(BaseModel): - response_type: Literal[AgentResponseType.ANSWER_WITH_SOURCES] = Field( - default=AgentResponseType.ANSWER_WITH_SOURCES - ) - answer: str - sources: list[Source] = Field(default_factory=list) - - -class AgentResponseBody(RootModel): - root: DismabiguationRequests | AnswerWithSources = Field( - ..., discriminator="response_type" - ) - - -class AgentRequestBody(BaseModel): - question: str - injected_parameters: dict = Field(default_factory=dict) - - @model_validator(mode="before") - def add_defaults_to_injected_parameters(cls, values): - if "injected_parameters" not in values: - values["injected_parameters"] = {} - - if "date" not in values["injected_parameters"]: - values["injected_parameters"]["date"] = datetime.now().strftime("%d/%m/%Y") - - if "time" not in values["injected_parameters"]: - values["injected_parameters"]["time"] = datetime.now().strftime("%H:%M:%S") - - if "datetime" not in values["injected_parameters"]: - values["injected_parameters"]["datetime"] = datetime.now().strftime( - "%d/%m/%Y, %H:%M:%S" - ) - - if "unix_timestamp" not in values["injected_parameters"]: - values["injected_parameters"]["unix_timestamp"] = int( - datetime.now().timestamp() - ) - - return values - - -class AgentRequestResponsePair(BaseModel): - header: AgentRequestResponseHeader | None = Field(default=None) - request: AgentRequestBody - response: AgentResponseBody diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/chat_history.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/chat_history.py deleted file mode 100644 index db846c7..0000000 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/chat_history.py +++ /dev/null @@ -1,18 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -from pydantic import BaseModel, Field -from text_2_sql_core.payloads.agent_request_response_pair import ( - AgentRequestResponsePair, -) -from datetime import datetime, timezone - - -class ChatHistoryItem(BaseModel): - """Chat history item with user message and agent response.""" - - timestamp: datetime = Field( - ..., - description="Timestamp in UTC", - default_factory=lambda: datetime.now(timezone.utc), - ) - request_response_pair: AgentRequestResponsePair diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/interaction_payloads.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/interaction_payloads.py new file mode 100644 index 0000000..6ef64c2 --- /dev/null +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/interaction_payloads.py @@ -0,0 +1,121 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from pydantic import BaseModel, RootModel, Field, model_validator +from enum import StrEnum + +from typing import Literal +from datetime import datetime, timezone + + +class PayloadBase(BaseModel): + prompt_tokens: int | None = None + completion_tokens: int | None = None + timestamp: datetime = Field( + default_factory=lambda: datetime.now(timezone.utc), + description="Timestamp in UTC", + ) + payload_type: str + payload_source: str + + +class PayloadSource(StrEnum): + USER = "user" + AGENT = "agent" + + +class PayloadType(StrEnum): + ANSWER_WITH_SOURCES = "answer_with_sources" + DISAMBIGUATION_REQUEST = "disambiguation_request" + PROCESSING_UPDATE = "processing_update" + QUESTION = "question" + + +class DismabiguationRequestPayload(PayloadBase): + class Body(BaseModel): + class DismabiguationRequest(BaseModel): + question: str + matching_columns: list[str] + matching_filter_values: list[str] + other_user_choices: list[str] + + disambiguation_requests: list[DismabiguationRequest] + + payload_type: Literal[ + PayloadType.DISAMBIGUATION_REQUEST + ] = PayloadType.DISAMBIGUATION_REQUEST + payload_source: Literal[PayloadSource.AGENT] = PayloadSource.AGENT + body: Body | None = Field(default=None) + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + self.body = self.Body(**kwargs) + + +class AnswerWithSourcesPayload(PayloadBase): + class Body(BaseModel): + class Source(BaseModel): + sql_query: str + sql_rows: list[dict] + + answer: str + sources: list[Source] = Field(default_factory=list) + + payload_type: Literal[ + PayloadType.ANSWER_WITH_SOURCES + ] = PayloadType.ANSWER_WITH_SOURCES + payload_source: Literal[PayloadSource.AGENT] = PayloadSource.AGENT + body: Body | None = Field(default=None) + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + self.body = self.Body(**kwargs) + + +class ProcessingUpdatePayload(PayloadBase): + class Body(BaseModel): + title: str | None = "Processing..." + message: str | None = "Processing..." + + payload_type: Literal[PayloadType.PROCESSING_UPDATE] = PayloadType.PROCESSING_UPDATE + payload_source: Literal[PayloadSource.AGENT] = PayloadSource.AGENT + body: Body | None = Field(default=None) + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + self.body = self.Body(**kwargs) + + +class QuestionPayload(PayloadBase): + class Body(BaseModel): + question: str + injected_parameters: dict = Field(default_factory=dict) + + @model_validator(mode="before") + def add_defaults(cls, values): + defaults = { + "date": datetime.now().strftime("%d/%m/%Y"), + "time": datetime.now().strftime("%H:%M:%S"), + "datetime": datetime.now().strftime("%d/%m/%Y, %H:%M:%S"), + "unix_timestamp": int(datetime.now().timestamp()), + } + injected = values.get("injected_parameters", {}) + values["injected_parameters"] = {**defaults, **injected} + return values + + payload_type: Literal[PayloadType.QUESTION] = PayloadType.QUESTION + payload_source: Literal[PayloadSource.USER] = PayloadSource.USER + body: Body | None = Field(default=None) + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + self.body = self.Body(**kwargs) + + +class InteractionPayload(RootModel): + root: QuestionPayload | ProcessingUpdatePayload | DismabiguationRequestPayload | AnswerWithSourcesPayload = Field( + ..., discriminator="payload_type" + ) diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/processing_update.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/processing_update.py deleted file mode 100644 index 500b3b2..0000000 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/processing_update.py +++ /dev/null @@ -1,24 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -from pydantic import BaseModel, Field -from datetime import datetime, timezone - - -class ProcessingUpdateHeader(BaseModel): - timestamp: datetime = Field( - ..., - description="Timestamp in UTC", - default_factory=lambda: datetime.now(timezone.utc), - ) - - -class ProcessingUpdateBody(BaseModel): - title: str | None = Field(default="Processing...") - message: str | None = Field(default="Processing...") - - -class ProcessingUpdate(BaseModel): - header: ProcessingUpdateHeader | None = Field( - default_factory=ProcessingUpdateHeader - ) - processing_update: ProcessingUpdateBody