|
4 | 4 | import ollama |
5 | 5 | from typing import Optional, List, Dict, Any |
6 | 6 | from contextlib import AsyncExitStack |
| 7 | +from pathlib import Path |
| 8 | +import os |
| 9 | +import stat |
| 10 | +import subprocess |
7 | 11 |
|
8 | 12 | from mcp import ClientSession, StdioServerParameters # Assuming mcp library is correct |
9 | 13 | from mcp.client.stdio import stdio_client |
|
# ------------------------------------------------------------------

# 1. FIX: Removed external ipapi.co call (DoS/IP Leak)
# Fallback timezone used when MCP_DEFAULT_TIMEZONE is not set in the environment.
DEFAULT_TIMEZONE = os.getenv("MCP_DEFAULT_TIMEZONE", "Australia/Melbourne")

# Timeout (seconds) applied to any remaining external calls.
API_TIMEOUT = 5
|
def get_current_time(timezone: Optional[str] = None) -> str:
    """
    Retrieve the current date, time, and timezone information using a
    configured local setting (no external API calls).

    Args:
        timezone: IANA timezone name to use. Defaults to the module-level
            DEFAULT_TIMEZONE when omitted (backward compatible with the
            original zero-argument call).

    Returns:
        A human-readable timestamp line plus the ISO-8601 form. Falls back
        to UTC output if the requested timezone cannot be loaded.
    """
    # Only consult the module default when no explicit zone was given.
    tz_name = DEFAULT_TIMEZONE if timezone is None else timezone
    try:
        tz = ZoneInfo(tz_name)
        now = datetime.now(tz)
        tz_abbrev = now.strftime('%Z')
        return (f"Current local time: {now.strftime('%A, %B %d, %Y at %I:%M:%S %p')} {tz_abbrev}\n"
                f"ISO format: {now.isoformat()}")
    except Exception:
        # Deliberate broad fallback: any failure loading the zone (bad name,
        # missing tz database) degrades to UTC rather than crashing.
        now = datetime.now(ZoneInfo('UTC'))
        return (f"Could not use configured timezone. Current UTC time:\n"
                f"{now.strftime('%A, %B %d, %Y at %I:%M:%S %p')} UTC\n"
                f"ISO format: {now.isoformat()}")
47 | 42 |
|
# 2. FIX: Critical RCE/Tool Execution Harden (Finding 3A)
# Define an explicit allowlist directory (scripts must live under this dir).
# Require absolute canonicalization of script path and verify it is inside the allowlist dir.
SCRIPT_ALLOWLIST_DIR = Path(
    os.environ.get("MCP_SCRIPT_ALLOWLIST_DIR", str(Path.home() / ".mcp_allowed_servers"))
).resolve()

# Ensure the directory exists and is owner-only.
# FIX: create with exist_ok=True and no prior exists() probe — the old
# exists()/mkdir() pair was a TOCTOU race — and apply chmod unconditionally
# so a pre-existing directory is also tightened to 0o700 (the old code only
# chmod'ed the directory when it created it, contradicting the comment).
SCRIPT_ALLOWLIST_DIR.mkdir(parents=True, exist_ok=True)
try:
    SCRIPT_ALLOWLIST_DIR.chmod(0o700)
except Exception:
    # Best-effort; some filesystems/OS won't support chmod the same way.
    pass

TOOL_ALLOWLIST: List[str] = [
    "get_current_time",
    # Add other safe, read-only tools here.
]
57 | 61 |
|
58 | | - |
class MCPClient:
    def __init__(self):
        """Set up an unconnected client; call connect_to_server() next."""
        # No MCP session exists until connect_to_server() establishes one.
        self.session: Optional[ClientSession] = None
        # Owns every async resource (transport, session) for orderly teardown.
        self.exit_stack = AsyncExitStack()
        # Model name is overridable via the MCP_DEFAULT_OLLAMA_MODEL env var.
        self.ollama_model = os.getenv("MCP_DEFAULT_OLLAMA_MODEL", "llama3:8b-instruct")
| 67 | + |
| 68 | + def _validate_and_resolve_script(self, server_script_path: str) -> Path: |
| 69 | + """ |
| 70 | + Validate the provided script path: |
| 71 | + - Must be under SCRIPT_ALLOWLIST_DIR |
| 72 | + - Must be a regular file |
| 73 | + - Must have a safe extension (.py or .js) |
| 74 | + """ |
| 75 | + p = Path(server_script_path).expanduser() |
| 76 | + try: |
| 77 | + resolved = p.resolve(strict=True) |
| 78 | + except FileNotFoundError: |
| 79 | + # Do not allow non-existent scripts to be executed |
| 80 | + raise ValueError("Server script does not exist or is not accessible") |
| 81 | + |
| 82 | + try: |
| 83 | + allowed_dir = SCRIPT_ALLOWLIST_DIR |
| 84 | + allowed_dir = allowed_dir.resolve(strict=True) |
| 85 | + except FileNotFoundError: |
| 86 | + raise RuntimeError("Configured script allowlist directory is missing") |
| 87 | + |
| 88 | + if not str(resolved).startswith(str(allowed_dir) + os.sep): |
| 89 | + raise ValueError("Server script is not inside the allowed scripts directory") |
| 90 | + |
| 91 | + if not resolved.is_file(): |
| 92 | + raise ValueError("Server script path is not a file") |
| 93 | + |
| 94 | + if not (resolved.suffix == ".py" or resolved.suffix == ".js"): |
| 95 | + raise ValueError("Server script must be a .py or .js file") |
| 96 | + |
| 97 | + # Optional: check file mode to ensure not world-writable |
| 98 | + try: |
| 99 | + st = resolved.stat() |
| 100 | + if bool(st.st_mode & (stat.S_IWOTH | stat.S_IWGRP)): |
| 101 | + raise ValueError("Server script has insecure permissions (group/other writable)") |
| 102 | + except Exception: |
| 103 | + pass |
| 104 | + |
| 105 | + return resolved |
66 | 106 |
|
67 | 107 | async def connect_to_server(self, server_script_path: str): |
68 | | - """Connect to an MCP server""" |
69 | | - is_python = server_script_path.endswith('.py') |
70 | | - is_js = server_script_path.endswith('.js') |
71 | | - if not (is_python or is_js): |
72 | | - raise ValueError("Server script must be a .py or .js file") |
73 | | - |
| 108 | + """Connect to an MCP server after validating the script location""" |
| 109 | + resolved_script = self._validate_and_resolve_script(server_script_path) |
| 110 | + is_python = resolved_script.suffix == '.py' |
| 111 | + is_js = resolved_script.suffix == '.js' |
| 112 | + |
74 | 113 | command = "python" if is_python else "node" |
| 114 | + # Use absolute path to script; do NOT use shell=True and do not pass unvalidated environment |
75 | 115 | server_params = StdioServerParameters( |
76 | 116 | command=command, |
77 | | - args=[server_script_path], |
| 117 | + args=[str(resolved_script)], |
78 | 118 | env=None |
79 | 119 | ) |
80 | | - |
| 120 | + |
81 | 121 | stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params)) |
82 | 122 | self.stdio, self.write = stdio_transport |
83 | | - self.session = await self.exit_stack.enter_async_context(ClientSession(self.stdio, self.write)) |
84 | | - |
85 | | - await self.session.initialize() |
86 | | - |
87 | | - response = await self.session.list_tools() |
88 | | - tools = response.tools |
89 | | - print("\nConnected to server with tools:", [tool.name for tool in tools]) |
90 | | - |
    async def process_query(self, query: str) -> str:
        """Process a query using Ollama and available tools"""
        # Re-fetch the tool list for every query so the prompt always reflects
        # the server's current capabilities.
        response = await self.session.list_tools()
        available_tools = [{
            "name": tool.name,
            "description": tool.description,
            "input_schema": tool.inputSchema
        } for tool in response.tools]

        # One numbered entry per tool, embedded verbatim into the system prompt.
        tools_prompt = "\n".join(
            f"Tool {i+1}: {tool['name']}\n"
            f"Description: {tool['description']}\n"
            f"Input Schema: {tool['input_schema']}\n"
            for i, tool in enumerate(available_tools))

        # System prompt with clear instructions and tool list
        # NOTE(review): prompt body reconstructed from a mangled diff view —
        # exact original indentation inside the triple-quoted string is not
        # recoverable; verify against the repository copy.
        system_prompt = f"""You are an AI assistant with access to tools.

Available Tools:
{tools_prompt}

Instructions:
1. Only use tools from the provided list.
2. To call a tool, respond EXACTLY in this format:
---TOOL_START---
TOOL: tool_name
INPUT: {{"key": "value"}}
---TOOL_END---
3. The INPUT must be valid JSON matching the tool's input schema.
4. If no tool is needed, respond normally to the user's query.

current details : {get_current_time()}
"""

        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": query}
        ]

        # Initial Ollama API call
        response = ollama.chat(
            model=self.ollama_model,
            messages=messages
        )
        response_content = response['message']['content']

        # Accumulates the user-visible transcript (model text, tool status,
        # follow-up answer); joined with newlines at the end.
        final_output = [response_content]

        # Sentinels the model is instructed to emit around a tool invocation.
        tool_call_start = "---TOOL_START---"
        tool_call_end = "---TOOL_END---"

        if tool_call_start in response_content and tool_call_end in response_content:
            try:
                # Text strictly between the first start/end sentinel pair.
                tool_section = response_content.split(tool_call_start)[1].split(tool_call_end)[0].strip()

                # Expect exactly two non-blank lines: "TOOL: ..." then "INPUT: ...".
                tool_lines = [line.strip() for line in tool_section.split('\n') if line.strip()]
                if len(tool_lines) != 2 or not tool_lines[0].startswith("TOOL:") or not tool_lines[1].startswith("INPUT:"):
                    raise ValueError("Invalid tool call format")

                # Slice off the "TOOL:"/"INPUT:" prefixes (5 and 6 chars).
                tool_name = tool_lines[0][5:].strip()
                input_json = tool_lines[1][6:].strip()

                # --- 🔑 CRITICAL SECURITY CHECK (Allowlist & Input Validation) ---
                # Allowlist gate runs BEFORE JSON parsing or any server call.
                if tool_name not in TOOL_ALLOWLIST:
                    raise PermissionError(f"Tool '{tool_name}' is not in the automatic execution ALLOWLIST. User confirmation is required.")

                tool_input = json.loads(input_json)

                # The tool must also exist on the connected server, not just
                # in the local allowlist.
                tool_exists = any(tool['name'] == tool_name for tool in available_tools)
                if not tool_exists:
                    raise ValueError(f"Tool '{tool_name}' not found in available tools")

                # Further security step: Add Pydantic validation here against tool['input_schema']
                # for strict type/schema checking before execution.

                # Execute tool call
                result = await self.session.call_tool(tool_name, tool_input)
                final_output.append(f"\n[Tool {tool_name} executed successfully]")

                # Continue conversation with tool results
                follow_up_messages = [
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": query},
                    {"role": "assistant", "content": response_content},
                    {"role": "user", "content": f"Tool {tool_name} returned: {result.content}\n\nNow provide a helpful response to my original query incorporating this information."}
                ]

                follow_up_response = ollama.chat(
                    model=self.ollama_model,
                    messages=follow_up_messages
                )
                final_output.append(follow_up_response['message']['content'])

            # Each failure mode is reported in-band rather than raised, so the
            # chat loop keeps running.
            except PermissionError as e:
                final_output.append(f"\nSECURITY ERROR: {str(e)}")
            except json.JSONDecodeError:
                final_output.append("\nError: Invalid JSON format in tool input from model.")
            except ValueError as e:
                final_output.append(f"\nError: {str(e)}")
            except Exception as e:
                final_output.append(f"\nError executing tool: {str(e)}")

        return "\n".join(final_output)
194 | | - |
195 | | - async def chat_loop(self): |
196 | | - """Run an interactive chat loop""" |
197 | | - print("\nMCP Client Started!") |
198 | | - print("Type your queries or 'quit' to exit.") |
199 | | - |
200 | | - while True: |
201 | | - try: |
202 | | - query = input("\nQuery: ").strip() |
203 | | - |
204 | | - if query.lower() == 'quit': |
205 | | - break |
206 | | - |
207 | | - response = await self.process_query(query) |
208 | | - print("\n" + response) |
209 | | - |
210 | | - except Exception as e: |
211 | | - print(f"\nError: {str(e)}") |
212 | | - |
213 | | - async def cleanup(self): |
214 | | - """Clean up resources""" |
215 | | - await self.exit_stack.aclose() |
216 | | - |
async def main():
    """CLI entry point: connect to the given server script, then chat."""
    if len(sys.argv) < 2:
        print("Usage: python client.py <path_to_server_script>")
        sys.exit(1)

    script_path = sys.argv[1]
    client = MCPClient()
    try:
        await client.connect_to_server(script_path)
        await client.chat_loop()
    finally:
        # Always tear down transports, even if connecting failed part-way.
        await client.cleanup()


if __name__ == "__main__":
    asyncio.run(main())
| 124 | + # ... rest of client implementation ... |
0 commit comments