Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 92 additions & 5 deletions src/fetch/src/mcp_server_fetch/server.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import ipaddress
import socket
from typing import Annotated, Tuple
from urllib.parse import urlparse, urlunparse
from urllib.parse import urljoin, urlparse, urlunparse

import markdownify
import readabilipy.simple_json
Expand All @@ -22,6 +24,91 @@

DEFAULT_USER_AGENT_AUTONOMOUS = "ModelContextProtocol/1.0 (Autonomous; +https://github.com/modelcontextprotocol/servers)"
DEFAULT_USER_AGENT_MANUAL = "ModelContextProtocol/1.0 (User-Specified; +https://github.com/modelcontextprotocol/servers)"
MAX_REDIRECTS = 10


def _is_blocked_ip_address(address: str) -> bool:
ip = ipaddress.ip_address(address)
if isinstance(ip, ipaddress.IPv6Address) and ip.ipv4_mapped:
ip = ip.ipv4_mapped
return not ip.is_global


def resolve_hostname_addresses(hostname: str, port: int | None) -> set[str]:
return {
result[4][0]
for result in socket.getaddrinfo(hostname, port, type=socket.SOCK_STREAM)
}


def validate_public_url(url: str) -> None:
parsed = urlparse(url)
if parsed.scheme not in {"http", "https"}:
raise McpError(ErrorData(
code=INVALID_PARAMS,
message="URL must use the http or https scheme",
))

hostname = parsed.hostname
if not hostname:
raise McpError(ErrorData(code=INVALID_PARAMS, message="URL must include a hostname"))

normalized_hostname = hostname.rstrip(".").lower()
if normalized_hostname == "localhost" or normalized_hostname.endswith(".localhost"):
raise McpError(ErrorData(
code=INVALID_PARAMS,
message="Fetching localhost URLs is not allowed",
))

try:
if _is_blocked_ip_address(normalized_hostname):
raise McpError(ErrorData(
code=INVALID_PARAMS,
message="Fetching private or non-public IP addresses is not allowed",
))
return
except ValueError:
pass

try:
addresses = resolve_hostname_addresses(normalized_hostname, parsed.port)
except OSError as exc:
raise McpError(ErrorData(
code=INTERNAL_ERROR,
message=f"Failed to resolve hostname {hostname}: {exc}",
))

if not addresses:
raise McpError(ErrorData(
code=INTERNAL_ERROR,
message=f"Failed to resolve hostname {hostname}",
))

for address in addresses:
if _is_blocked_ip_address(address):
raise McpError(ErrorData(
code=INVALID_PARAMS,
message="Fetching private or non-public IP addresses is not allowed",
))


async def get_url_safely(client, url: str, **kwargs):
current_url = url
for _ in range(MAX_REDIRECTS + 1):
validate_public_url(current_url)
response = await client.get(current_url, follow_redirects=False, **kwargs)
if response.status_code not in {301, 302, 303, 307, 308}:
return response

location = response.headers.get("location")
if not location:
return response
current_url = urljoin(str(getattr(response, "url", current_url)), location)

raise McpError(ErrorData(
code=INTERNAL_ERROR,
message=f"Too many redirects while fetching {url}",
))


def extract_content_from_html(html: str) -> str:
Expand Down Expand Up @@ -74,9 +161,9 @@ async def check_may_autonomously_fetch_url(url: str, user_agent: str, proxy_url:

async with AsyncClient(proxy=proxy_url) as client:
try:
response = await client.get(
response = await get_url_safely(
client,
robot_txt_url,
follow_redirects=True,
headers={"User-Agent": user_agent},
)
except HTTPError:
Expand Down Expand Up @@ -118,9 +205,9 @@ async def fetch_url(

async with AsyncClient(proxy=proxy_url) as client:
try:
response = await client.get(
response = await get_url_safely(
client,
url,
follow_redirects=True,
headers={"User-Agent": user_agent},
timeout=30,
)
Expand Down
79 changes: 79 additions & 0 deletions src/fetch/tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,19 @@
check_may_autonomously_fetch_url,
fetch_url,
DEFAULT_USER_AGENT_AUTONOMOUS,
validate_public_url,
)


@pytest.fixture(autouse=True)
def mock_public_dns(monkeypatch):
"""Keep URL safety checks deterministic in unit tests."""
monkeypatch.setattr(
"mcp_server_fetch.server.resolve_hostname_addresses",
lambda hostname, port: {"93.184.216.34"},
)


class TestGetRobotsTxtUrl:
"""Tests for get_robots_txt_url function."""

Expand Down Expand Up @@ -47,6 +57,37 @@ def test_http_url(self):
assert result == "http://example.com/robots.txt"


class TestValidatePublicUrl:
"""Tests for SSRF URL validation."""

@pytest.mark.parametrize(
"url",
[
"file:///etc/passwd",
"http://localhost/admin",
"http://127.0.0.1/admin",
"http://[::1]/admin",
"http://169.254.169.254/latest/meta-data/",
"http://10.0.0.4/admin",
],
)
def test_blocks_non_public_urls(self, url):
with pytest.raises(McpError):
validate_public_url(url)

def test_blocks_private_dns_resolution(self, monkeypatch):
monkeypatch.setattr(
"mcp_server_fetch.server.resolve_hostname_addresses",
lambda hostname, port: {"10.0.0.4"},
)

with pytest.raises(McpError):
validate_public_url("https://internal.example.com/page")

def test_allows_public_https_urls(self):
validate_public_url("https://example.com/page")


class TestExtractContentFromHtml:
"""Tests for extract_content_from_html function."""

Expand Down Expand Up @@ -266,6 +307,44 @@ async def test_fetch_json_returns_raw(self):
assert content == json_content
assert "cannot be simplified" in prefix

@pytest.mark.asyncio
async def test_blocks_private_url_before_request(self):
"""Test that private IP URLs are rejected before any HTTP request is sent."""
with patch("httpx.AsyncClient") as mock_client_class:
mock_client = AsyncMock()
mock_client_class.return_value.__aenter__ = AsyncMock(return_value=mock_client)
mock_client_class.return_value.__aexit__ = AsyncMock(return_value=None)

with pytest.raises(McpError):
await fetch_url(
"http://127.0.0.1/admin",
DEFAULT_USER_AGENT_AUTONOMOUS,
)

mock_client.get.assert_not_called()

@pytest.mark.asyncio
async def test_blocks_redirect_to_private_url(self):
"""Test that redirects are validated before following them."""
redirect_response = MagicMock()
redirect_response.status_code = 302
redirect_response.headers = {"location": "http://127.0.0.1/admin"}
redirect_response.url = "https://example.com/start"

with patch("httpx.AsyncClient") as mock_client_class:
mock_client = AsyncMock()
mock_client.get = AsyncMock(return_value=redirect_response)
mock_client_class.return_value.__aenter__ = AsyncMock(return_value=mock_client)
mock_client_class.return_value.__aexit__ = AsyncMock(return_value=None)

with pytest.raises(McpError):
await fetch_url(
"https://example.com/start",
DEFAULT_USER_AGENT_AUTONOMOUS,
)

assert mock_client.get.call_count == 1

@pytest.mark.asyncio
async def test_fetch_404_raises_error(self):
"""Test that 404 response raises McpError."""
Expand Down
Loading