-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathquery_sanitize.py
More file actions
135 lines (117 loc) · 5.31 KB
/
query_sanitize.py
File metadata and controls
135 lines (117 loc) · 5.31 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
from llama_index.core.bridge.pydantic import BaseModel, Field
from typing import List, Tuple, Optional, Dict, Any
import json
import re
from sqlalchemy import MetaData
class TableInfo(BaseModel):
    """Information regarding a structured table."""

    # NOTE(review): the class docstring and the Field descriptions are kept
    # verbatim; pydantic may surface them in the generated schema -- confirm
    # before rewording.

    # Table identifier; the description asks for underscores and no spaces.
    table_name: str = Field(
        ...,
        description="table name (must be underscores and NO spaces)",
    )

    # Short caption describing the table.
    table_summary: str = Field(
        ...,
        description="short, concise summary/caption of the table",
    )
class QueryClarification(BaseModel):
    """Model for storing query clarification details."""

    # NOTE(review): docstring kept verbatim; pydantic may surface it in the
    # generated schema -- confirm before rewording.

    # The only required field: whether the query needs clarification.
    needs_clarification: bool

    # Optional detail lists; each one defaults to empty.
    questions: List[str] = []
    ambiguous_tables: List[str] = []
    ambiguous_columns: List[str] = []
    missing_context: List[str] = []

    # Capitalised on purpose: it matches the "Reasoning" key requested from
    # the LLM in QuerySanitizer.generate_clarifying_questions.
    Reasoning: List[str] = []
class QuerySanitizer:
    """Handles query sanitization and clarification generation.

    Provides a regex-based screen for SQL-injection-looking input and an
    LLM-backed check that decides whether a natural language query needs
    clarifying questions before SQL can be written for it.
    """

    # Prefix prepended to a table name to form its key in ``MetaData.tables``.
    _TABLE_KEY_PREFIX = 'csv_data.'

    # Confidence strictly above this value suppresses clarification questions.
    _CONFIDENCE_THRESHOLD = 0.8

    def __init__(self, bed_llm):
        """Store the LLM client and the list of injection regexes.

        Args:
            bed_llm: Object exposing ``complete(prompt)``; it is called by
                ``generate_clarifying_questions``.
        """
        self.llm = bed_llm
        self.sql_injection_patterns = [
            r'--',  # Comments
            r';.*?$',  # Multiple statements
            r'/\*.*?\*/',  # Multi-line comments
            r'union.*?select',  # UNION based
            r'exec.*?\(',  # Code execution
            r'xp_.*?\(',  # Extended stored procedures
        ]

    def check_sql_injection(self, query: str) -> bool:
        """Check for potential SQL injection patterns.

        Args:
            query: Text to screen.

        Returns:
            True if any pattern in ``self.sql_injection_patterns`` matches,
            False otherwise (including for an empty or ``None`` query).
        """
        if not query:
            return False
        # DOTALL lets '.' cross newlines; without it a payload split over
        # several lines ("...;\nDROP ...", "UNION\nSELECT", a /* */ comment
        # spanning lines) slips past the patterns.
        flags = re.IGNORECASE | re.DOTALL
        return any(
            re.search(pattern, query, flags) is not None
            for pattern in self.sql_injection_patterns
        )

    def generate_clarifying_questions(
        self,
        query: str,
        table_infos: "List[TableInfo]",
        metadata_obj: "MetaData",
    ) -> "QueryClarification":
        """Generate clarifying questions based on the query and schema.

        Builds a prompt from the query and the formatted table information,
        sends it to ``self.llm.complete`` and parses the reply as JSON.

        Args:
            query: Natural language query to analyse.
            table_infos: Tables (name + summary) to describe in the prompt.
            metadata_obj: Metadata whose ``tables`` mapping supplies columns.

        Returns:
            A ``QueryClarification`` built from the LLM reply. When the
            reported confidence is above 0.8, or the reply cannot be parsed
            or validated, a result with ``needs_clarification=False`` and
            empty lists is returned instead.
        """
        analysis_prompt = f"""
Analyze this natural language query: "{query}"
Available tables and their descriptions:
{self._format_table_info(table_infos, metadata_obj)}
First, determine if this query is clear and can be answered with the available schema.
A query needs clarification ONLY if:
1. It references tables or columns that could map to multiple options
2. It's missing critical information needed to write SQL
3. It contains ambiguous terms or relationships
Do NOT ask for clarification if:
1. The meaning is clear from context
2. Standard time periods are mentioned (e.g., years, months)
3. The relationships between tables are obvious from schema
4. The query is in some way or form asking for data analysis of the database
Format response as JSON dict with these fields:
- "needs_clarification": 'False' if query is clear enough to process, 'True' if ambiguous,
- "confidence": 0-1 score of how clear the query intent is,
- "questions": [],
- "ambiguous_tables": [],
- "ambiguous_columns": [],
- "missing_context": []
- "Reasoning": []
"""
        response = self.llm.complete(analysis_prompt)
        try:
            response_text = response.text if hasattr(response, 'text') else str(response)
            clarification_dict = self._extract_json_object(response_text)
            print("Query Clarification Index: \n", clarification_dict)
            # Only return clarification questions if confidence is low
            if self._confidence_of(clarification_dict) > self._CONFIDENCE_THRESHOLD:
                return self._no_clarification()
            return QueryClarification(**clarification_dict)
        except Exception as e:
            # Deliberate best-effort: an unusable reply must not block the
            # caller, so report it and fall back to "no clarification".
            print(f"Error parsing LLM response: {str(e)}")
            return self._no_clarification()

    @staticmethod
    def _extract_json_object(text: str) -> Dict[str, Any]:
        """Parse *text* as a JSON object, tolerating surrounding noise.

        Tries the whole text first; if that fails, retries on the span from
        the first '{' to the last '}' so replies wrapped in a Markdown code
        fence or in explanatory prose still parse.

        Raises:
            json.JSONDecodeError: No parseable JSON was found.
            ValueError: The parsed JSON is not an object.
        """
        try:
            parsed = json.loads(text)
        except json.JSONDecodeError:
            start = text.find("{")
            end = text.rfind("}")
            if start == -1 or end <= start:
                raise
            parsed = json.loads(text[start:end + 1])
        if not isinstance(parsed, dict):
            raise ValueError("LLM response is not a JSON object")
        return parsed

    @staticmethod
    def _confidence_of(clarification_dict: Dict[str, Any]) -> float:
        """Return the reply's "confidence" as a float, or 0.0 if unusable."""
        try:
            return float(clarification_dict.get("confidence", 0))
        except (TypeError, ValueError):
            # A missing/null/non-numeric score is treated as low confidence
            # so the rest of the reply is still used.
            return 0.0

    @staticmethod
    def _no_clarification() -> "QueryClarification":
        """Build the result meaning "the query can be processed as is"."""
        return QueryClarification(
            needs_clarification=False,
            questions=[],
            ambiguous_tables=[],
            ambiguous_columns=[],
            missing_context=[],
            Reasoning=[],
        )

    def _format_table_info(
        self,
        table_infos: "List[TableInfo]",
        metadata_obj: "MetaData",
    ) -> str:
        """Format table information for LLM prompt.

        Args:
            table_infos: Tables to describe.
            metadata_obj: Metadata whose ``tables`` mapping is searched under
                the key ``'csv_data.' + table_name``.

        Returns:
            One section per table found in the metadata (name, description
            and a "- column (type)" line per column), joined by newlines.
            Tables missing from the metadata are skipped; the result is an
            empty string when nothing is found.
        """
        sections = []
        for table_info in table_infos:
            table = metadata_obj.tables.get(
                self._TABLE_KEY_PREFIX + table_info.table_name
            )
            if table is None:
                continue
            column_str = "\n".join(
                f"- {col.name} ({str(col.type)})" for col in table.columns
            )
            sections.append(f"""
Table: {table_info.table_name}
Description: {table_info.table_summary}
Columns:
{column_str}""")
        return "\n".join(sections)