-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtext_2_sql_agentic.py
More file actions
300 lines (241 loc) · 12.8 KB
/
text_2_sql_agentic.py
File metadata and controls
300 lines (241 loc) · 12.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from langchain_core.output_parsers.string import StrOutputParser
import sqlite3
from typing import List, Any, Dict
from typing_extensions import TypedDict
from langgraph.graph import StateGraph, END
from urllib.parse import quote_plus
import pyodbc
import os
from langchain_community.callbacks import get_openai_callback
from langchain_core.messages import ToolMessage
from langchain_core.runnables import RunnableLambda, RunnableWithFallbacks
from langgraph.prebuilt import ToolNode
import pandas as pd
import numpy as np
import os
from openai import OpenAI
from langchain_core.tools import tool
from dotenv import load_dotenv
import os
# Populate os.environ from a local .env file (if one exists) before reading keys.
load_dotenv()

# Shared OpenAI client, used directly by the nodes that call the o3-mini model.
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
class InputState(TypedDict):
    """Working state of the text-to-SQL graph.

    The caller supplies ``question``, ``database`` and ``uri``; the other keys
    are written by the graph's nodes as it runs.
    """

    question: str  # natural-language question from the user
    database: str  # database name; not read by any node in this file
    uri: str  # pyodbc connection string used by get_schema / run_sql_server_query
    schema: Dict[str, Any]  # table name -> {"columns": [...], "sample_rows": [...]}
    sql_query: str  # generated SQL, later replaced by the checked version
    result: List[Any]  # column-name tuple followed by the fetched rows, or [("Error",)]
    # Error message written by run_sql_server_query / run_python_code. Declared here
    # so the update has a state channel to land in instead of being rejected/dropped.
    error: str
class OutputState(TypedDict):
    """Keys returned to the caller of the compiled graph."""

    sql_query: str  # the checked SQL query that was executed
    result: List[Any]  # column-name tuple followed by the fetched rows
    visualization_code: str  # model-generated Python plotting code
def get_schema(state) -> dict:
    """Retrieve table columns and a few sample rows from a SQL Server database.

    Args:
        state: Graph state; ``state["uri"]`` is the pyodbc connection string.

    Returns:
        ``{"schema": {table_name: {"columns": [...], "sample_rows": [[...], ...]}}}``
        covering every base table, with up to three sample rows each. Keys are
        bare table names, so same-named tables in different schemas collide
        (the last one read wins).
    """
    conn = pyodbc.connect(state["uri"])
    try:
        cursor = conn.cursor()
        cursor.execute(
            "SELECT TABLE_SCHEMA, TABLE_NAME FROM INFORMATION_SCHEMA.TABLES "
            "WHERE TABLE_TYPE='BASE TABLE';"
        )
        tables = cursor.fetchall()
        schema = {}
        for table_schema, table_name in tables:
            # Bind the names as parameters: a quote inside a table name would
            # otherwise break (or inject into) the statement.
            cursor.execute(
                "SELECT COLUMN_NAME FROM INFORMATION_SCHEMA.COLUMNS "
                "WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ? ORDER BY ORDINAL_POSITION",
                table_schema,
                table_name,
            )
            columns = [row[0] for row in cursor.fetchall()]
            # Identifiers cannot be parameters, so bracket-quote them (escaping "]"
            # as "]]") and qualify with the owning schema so names with spaces,
            # reserved words, or non-default schemas still resolve.
            qualified = ".".join(
                "[" + part.replace("]", "]]") + "]" for part in (table_schema, table_name)
            )
            cursor.execute(f"SELECT TOP 3 * FROM {qualified}")
            schema[table_name] = {
                "columns": columns,
                "sample_rows": [list(row) for row in cursor.fetchall()],
            }
    finally:
        # Always release the connection, even if one of the queries fails.
        conn.close()
    return {"schema": schema}
def generate_sql_query(state) -> dict:
    """Translate the user's question into a SQL Server query with gpt-4o-mini.

    Args:
        state: Graph state; reads ``state["question"]`` and ``state["schema"]``
            (the dict produced by ``get_schema``).

    Returns:
        ``{"sql_query": text}`` with the raw model output; it is not validated
        here (the checker node reviews it next).
    """
    user_question = state["question"]
    db_schema = state["schema"]

    llm = ChatOpenAI(model="gpt-4o-mini")
    system_text = (
        "\n"
        "You are an SQL expert. Generate an SQL server query based on the database schema and question.\n"
        "Always generate syntactically correct SQL Server queries. "
        "Avoid including unncessary columns in the query. "
        "Only use the columns that are relevant to the question. "
        "If the user asks to exclude something from the results, follow the instructions completely.\n"
        "Ensure the output query is executable. "
        "Do not put ```sql in the beginning of the query. "
        "Just provide raw executable query as a string. "
        "Use appropriate column names for the results.\n"
    )
    human_text = "===Database Schema:\n{schema}\n\n===User Question:\n{question}\n\nSQL Server Query:"
    prompt = ChatPromptTemplate.from_messages([("system", system_text), ("human", human_text)])

    chain = prompt | llm | StrOutputParser()
    sql_text = chain.invoke({"question": user_question, "schema": db_schema})
    print(f"Original SQL query: {sql_text}")
    return {"sql_query": sql_text}
## Query checker: a reasoning model reviews the generated SQL and rewrites it if needed before it is executed.
def sql_server_query_checker(state) -> dict:
    """Have o3-mini review the generated query and rewrite it if it is wrong.

    Args:
        state: Graph state; reads ``state["sql_query"]`` and ``state["question"]``.

    Returns:
        ``{"sql_query": checked}``. The model's reply is stripped of surrounding
        whitespace and of a Markdown code fence (e.g. ```sql ... ```), because the
        next node executes it verbatim. If the reply is empty, the original query
        is kept.
    """
    query = state["sql_query"]
    question = state["question"]
    query_check_system_prefix = (
        "You are a SQL Server expert with a strong attention to detail.\n"
        "You will be provided a SQL Server query and a question. Double check the SQL Server query "
        "for common mistakes and verify if it answers the question properly.\n"
        "If there are any mistakes, rewrite the query. If there are no mistakes, just reproduce the "
        "original query. Always produce the query as raw string which the user can execute directly.\n"
        "Do not output anything else other then the query itself. No explanations. "
        "Just provide the relevant query."
    )
    template = (
        "\n"
        "Check if the SQL query answers the user question properly. Specifically, review if the query "
        "satisfies all the information asked by the user and is syntactically correct.\n"
        f"Question: {question}\n"
        f"SQL Server Query: {query}\n"
    )
    completion = client.chat.completions.create(
        model="o3-mini",  # reasoning model used for the review
        messages=[
            {"role": "developer", "content": query_check_system_prefix},
            {"role": "user", "content": template},
        ],
    )
    # message.content can be None, hence the `or ""`.
    checked = (completion.choices[0].message.content or "").strip()
    if checked.startswith("```"):
        # Drop the opening fence line (it may carry a tag such as ```sql) and a closing fence.
        checked = checked.split("\n", 1)[1] if "\n" in checked else ""
        checked = checked.rstrip()
        if checked.endswith("```"):
            checked = checked[:-3]
        checked = checked.strip()
    if not checked:
        checked = query
    print(f"Corrected SQL query: {checked}")
    return {"sql_query": checked}
def run_sql_server_query(state) -> dict:
    """Execute the SQL query and return its rows preceded by a header row.

    Args:
        state: Graph state; reads ``state["uri"]`` (pyodbc connection string) and
            ``state["sql_query"]``.

    Returns:
        ``{"result": [column_names, row1, row2, ...]}``, where the header and each
        row are tuples. On any failure, including failure to connect, returns
        ``{"result": [("Error",)], "error": "Error executing query: ..."}``.
    """
    conn_str = state["uri"]
    query = state["sql_query"]
    conn = None  # bound up front so `finally` is safe when connect() itself fails
    try:
        conn = pyodbc.connect(conn_str)
        cursor = conn.cursor()
        cursor.execute(query)
        if cursor.description is None:
            # The statement produced no result set, so there is nothing to fetch.
            return {
                "result": [("Error",)],
                "error": "Error executing query: the statement returned no result set",
            }
        columns = tuple(column[0] for column in cursor.description)
        # Plain tuples (instead of driver row objects) keep the state easy to
        # print, serialize and load into pandas.
        rows = [tuple(row) for row in cursor.fetchall()]
        return {"result": [columns] + rows}
    except Exception as e:
        return {"result": [("Error",)], "error": f"Error executing query: {e}"}
    finally:
        if conn is not None:
            conn.close()
### Alternatively, the OpenAI Assistants API can be used to both write and execute the code.
## More information: https://platform.openai.com/docs/assistants/quickstart
def generate_code_using_openai_call(state) -> dict:
    """Ask o3-mini (via the OpenAI client) for Plotly code that visualizes the result.

    Args:
        state: Graph state; reads ``state["question"]`` and ``state["result"]``
            (a column-name tuple followed by the fetched rows).

    Returns:
        ``{"visualization_code": code}``. Surrounding whitespace and a Markdown code
        fence are stripped from the reply, because the code is later run with ``exec``.
    """
    question = state["question"]
    data = state["result"]
    system_prefix = (
        "You are an agent designed to generate Python code for creating visualizations based on the data provided.\n"
        "Given an input question and data, create a Python code snippet that uses the Plotly library to generate "
        "the appropriate visualization. You might get a list of tuples as the data. First flatten the tuple, and "
        "then create a DataFrame using the data. Use the DataFrame to create the visualization.\n"
        "Ensure the code is syntactically correct and includes necessary imports. "
        "Ensure that the code won't throw any errors when executed.\n"
        "Do not include the code inside ```python``` tags. Just give the code snippet that the user can run "
        "directly using exec(). Therefore, produce the code as raw string.\n"
        "Do not include any additional explanations or comments in the code. Do not add any additional logic "
        "on the dataframe since the data provided to you already answers the question.\n"
        # The first element of the data is the header tuple of column names, so it
        # must become the DataFrame columns rather than a data row.
        'After importing the necessary libraries, you can start your code directly with '
        '"df = pd.DataFrame(result[1:], columns=result[0])", because the first element of result '
        "holds the column names.\n"
        "Always use the right syntax while using functions like df.pivot_table() or df.groupby().sum(). "
        "Do not use invalid syntax that will throw errors.\n"
        # Plotly sizing example; the earlier example was matplotlib with unbalanced parentheses.
        'Use a figure size like "fig.update_layout(width=1000, height=500)" since I have to render this '
        "visualization on a web page. Use appropriate colors and labels for the visualization.\n"
        "If you do not receive any data, return an empty visualization with a message indicating that the "
        "data is not available.\n"
    )
    template = (
        "\n"
        "Based on the data provided below, write a Python code snippet that would create a appealing "
        "visualization based on the data provided. You would be optionally given a chart type in the question. "
        "If it is not provided, think of an appropriate chart type based on the data.\n"
        f"Question: {question}\n"
        f"Data: {data}\n"
        "Visualization Code:\n"
    )
    completion = client.chat.completions.create(
        model="o3-mini",
        messages=[
            {"role": "developer", "content": system_prefix},
            {"role": "user", "content": template},
        ],
    )
    # message.content can be None, hence the `or ""`.
    code = (completion.choices[0].message.content or "").strip()
    if code.startswith("```"):
        # Drop the opening fence line (it may carry a tag such as ```python) and a closing fence.
        code = code.split("\n", 1)[1] if "\n" in code else ""
        code = code.rstrip()
        if code.endswith("```"):
            code = code[:-3]
        code = code.strip()
    return {"visualization_code": code}
def generate_visualization_code(state) -> dict:
    """Generate Plotly code for the query result using LangChain and gpt-4o.

    This is an alternative to ``generate_code_using_openai_call``; the graph below
    registers that function for the "generate_visualization" node instead.

    Args:
        state: Graph state; reads ``state["question"]`` and ``state["result"]``
            (a column-name tuple followed by the fetched rows).

    Returns:
        ``{"visualization_code": code}``. Surrounding whitespace and a Markdown code
        fence are stripped from the reply, because the code is meant to be run with ``exec``.
    """
    question = state["question"]
    data = state["result"]
    llm = ChatOpenAI(model="gpt-4o", temperature=0)
    template = """
Based on the data provided below, write a Python code snippet that would create a appealing visualization based on the data provided. You would be optionally given a chart type in the question. If it is not provided, think of an appropriate chart type based on the data.
Question: {question}
Data: {data}
Visualization Code:
"""
    system_prefix = (
        "You are an agent designed to generate Python code for creating visualizations based on the data provided.\n"
        "Given an input question and data, create a Python code snippet that uses the Plotly library to generate "
        "the appropriate visualization. You might get a list of tuples as the data. First flatten the tuple, and "
        "then create a DataFrame using the data. Use the DataFrame to create the visualization.\n"
        "Ensure the code is syntactically correct and includes necessary imports. "
        "Ensure that the code won't throw any errors when executed.\n"
        "Do not include the code inside ```python``` tags. Just give the code snippet that the user can run "
        "directly using exec(). Therefore, produce the code as raw string.\n"
        "Do not include any additional explanations or comments in the code. Do not add any additional logic "
        "on the dataframe since the data provided to you already answers the question.\n"
        # The first element of the data is the header tuple of column names, so it
        # must become the DataFrame columns rather than a data row.
        'After importing the necessary libraries, you can start your code directly with '
        '"df = pd.DataFrame(result[1:], columns=result[0])", because the first element of result '
        "holds the column names.\n"
        "Always use the right syntax while using functions like df.pivot_table() or df.groupby().sum(). "
        "Do not use invalid syntax that will throw errors.\n"
        # Plotly sizing example; the earlier example was matplotlib with unbalanced parentheses.
        'Use a reasonably big figure size like "fig.update_layout(width=1000, height=500)" since I have to '
        "render this visualization on a web page. Use appropriate colors and labels for the visualization.\n"
        "If you do not receive any data, return an empty visualization with a message indicating that the "
        "data is not available.\n"
    )
    final_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_prefix),
            ("human", template),
        ]
    )
    generate_code_chain = final_prompt | llm | StrOutputParser()
    # Print the token usage / cost reported by the callback.
    with get_openai_callback() as cb:
        response = generate_code_chain.invoke({"question": question, "data": data})
        print("GPT -4o callback :::::")
        print(cb)
    code = response.strip()
    if code.startswith("```"):
        # Drop the opening fence line (it may carry a tag such as ```python) and a closing fence.
        code = code.split("\n", 1)[1] if "\n" in code else ""
        code = code.rstrip()
        if code.endswith("```"):
            code = code[:-3]
        code = code.strip()
    return {"visualization_code": code}
### Ideally, the flow should be: generate code -> run it -> if it fails, go to a fallback step that fixes the code.
### If it succeeds, continue with the rest of the flow.
### The SQL query should be handled the same way.
def run_python_code(state) -> dict:
    """Execute the generated visualization code and report any failure.

    Args:
        state: Graph state; reads ``state["visualization_code"]`` and, if present,
            ``state["result"]``.

    Returns:
        ``None`` on success (no state update), or
        ``{"error": "Your code failed with the following error: ..."}`` if the code raises.
    """
    visualization_code = state["visualization_code"]
    # One dict serves as both globals and locals, so functions/lambdas defined by the
    # generated code can see its top-level names. `result` is pre-bound because the
    # prompts tell the model to build its DataFrame from a variable named `result`.
    namespace = {"pd": pd, "np": np, "result": state.get("result")}
    try:
        # NOTE(security): this runs model-generated code with full interpreter
        # privileges; only do so in a trusted or sandboxed environment.
        exec(visualization_code, namespace)
    except Exception as e:
        return {"error": f"Your code failed with the following error: {str(e)}"}
## TODO: add a code checker before running the generated code.
## TODO: handle errors from the tools: take the error returned by any of the generative functions and ask the model to fix it.
def handle_tool_error(state) -> dict:
    """Convert a tool failure into ToolMessages asking the agent to fix its mistake.

    Args:
        state: Input to the fallback; ``state.get("error")`` is the captured
            exception (see ``exception_key="error"`` in
            ``create_tool_node_with_fallback``), and ``state["messages"][-1]`` is
            the message whose ``tool_calls`` were being executed.

    Returns:
        ``{"messages": [...]}`` with one ToolMessage per tool call, each echoing the error.
    """
    failure = state.get("error")
    pending_calls = state["messages"][-1].tool_calls
    replies = []
    for call in pending_calls:
        replies.append(
            ToolMessage(
                content=f"Error: {failure!r}\n please fix your mistakes.",
                tool_call_id=call["id"],
            )
        )
    return {"messages": replies}
def create_tool_node_with_fallback(tools: list) -> RunnableWithFallbacks[Any, dict]:
    """Wrap ``tools`` in a ToolNode whose failures are routed to ``handle_tool_error``.

    ``exception_key="error"`` makes the fallback receive the raised exception
    under the ``"error"`` key of its input, which is where ``handle_tool_error``
    looks for it.
    """
    tool_node = ToolNode(tools)
    fallbacks = [RunnableLambda(handle_tool_error)]
    return tool_node.with_fallbacks(fallbacks, exception_key="error")
# Pipeline: schema -> SQL generation -> SQL check -> execution -> visualization code.
# InputState is passed explicitly as the state schema. Omitting it (input/output only)
# relies on LangGraph's deprecated shorthand, which falls back to the input schema anyway.
workflow = StateGraph(InputState, input=InputState, output=OutputState)
workflow.add_node("get_schema", get_schema)
workflow.add_node("generate_sql", generate_sql_query)
workflow.add_node("check_sql_query", sql_server_query_checker)
workflow.add_node("run_sql", run_sql_server_query)
# Alternative visualization node (LangChain + gpt-4o): generate_visualization_code.
workflow.add_node("generate_visualization", generate_code_using_openai_call)
workflow.add_edge("get_schema", "generate_sql")
workflow.add_edge("generate_sql", "check_sql_query")
workflow.add_edge("check_sql_query", "run_sql")
workflow.add_edge("run_sql", "generate_visualization")
workflow.add_edge("generate_visualization", END)
workflow.set_entry_point("get_schema")
if __name__ == "__main__":
    # Demo run against SQL Server using Windows (trusted) authentication.
    # The server name (host or host\instance) comes from the environment / .env file.
    SERVER_NAME = os.environ.get("SQL_SERVER_NAME")
    if not SERVER_NAME:
        raise SystemExit(
            "Set SQL_SERVER_NAME (e.g. in your .env file) to the SQL Server host or instance name."
        )
    # The driver name is brace-wrapped, the form used in pyodbc/Microsoft connection-string examples.
    CONN_STR = (
        f"DRIVER={{ODBC Driver 17 for SQL Server}};SERVER={SERVER_NAME};"
        "DATABASE=GENAIDB;Trusted_Connection=yes;"
    )
    # SQLAlchemy-style URL, kept for reference (not accepted by pyodbc.connect):
    # CONN_STR = f"mssql+pyodbc://{SQL_USERNAME}:{SQL_PASSWORD}@{SQL_SERVER}/{SQL_DATABASE}?driver={SQL_DRIVER}&trusted_connection=yes"

    # Initial graph state; the nodes fill in the remaining InputState keys.
    state = {
        "question": "Show the weekly occurance trend of top 10 defects in October 2024.",
        "database": "GENAIDB",
        "uri": CONN_STR,
    }
    app = workflow.compile()
    output = app.invoke(state)
    print(output["result"])
    # The prompts tell the model to build its DataFrame from `result`, so bind it to the
    # query rows (not this script's globals, where it would be the whole graph output).
    # NOTE(security): this executes model-generated code; run only in a trusted environment.
    exec(output["visualization_code"], {"pd": pd, "np": np, "result": output["result"]})