diff --git a/src/fetch/src/mcp_server_fetch/server.py b/src/fetch/src/mcp_server_fetch/server.py index b42c7b1f6b..182dcbbb86 100644 --- a/src/fetch/src/mcp_server_fetch/server.py +++ b/src/fetch/src/mcp_server_fetch/server.py @@ -1,5 +1,7 @@ +import ipaddress +import socket from typing import Annotated, Tuple -from urllib.parse import urlparse, urlunparse +from urllib.parse import urljoin, urlparse, urlunparse import markdownify import readabilipy.simple_json @@ -22,6 +24,91 @@ DEFAULT_USER_AGENT_AUTONOMOUS = "ModelContextProtocol/1.0 (Autonomous; +https://github.com/modelcontextprotocol/servers)" DEFAULT_USER_AGENT_MANUAL = "ModelContextProtocol/1.0 (User-Specified; +https://github.com/modelcontextprotocol/servers)" +MAX_REDIRECTS = 10 + + +def _is_blocked_ip_address(address: str) -> bool: + ip = ipaddress.ip_address(address) + if isinstance(ip, ipaddress.IPv6Address) and ip.ipv4_mapped: + ip = ip.ipv4_mapped + return not ip.is_global + + +def resolve_hostname_addresses(hostname: str, port: int | None) -> set[str]: + return { + result[4][0] + for result in socket.getaddrinfo(hostname, port, type=socket.SOCK_STREAM) + } + + +def validate_public_url(url: str) -> None: + parsed = urlparse(url) + if parsed.scheme not in {"http", "https"}: + raise McpError(ErrorData( + code=INVALID_PARAMS, + message="URL must use the http or https scheme", + )) + + hostname = parsed.hostname + if not hostname: + raise McpError(ErrorData(code=INVALID_PARAMS, message="URL must include a hostname")) + + normalized_hostname = hostname.rstrip(".").lower() + if normalized_hostname == "localhost" or normalized_hostname.endswith(".localhost"): + raise McpError(ErrorData( + code=INVALID_PARAMS, + message="Fetching localhost URLs is not allowed", + )) + + try: + if _is_blocked_ip_address(normalized_hostname): + raise McpError(ErrorData( + code=INVALID_PARAMS, + message="Fetching private or non-public IP addresses is not allowed", + )) + return + except ValueError: + pass + + try: + addresses = resolve_hostname_addresses(normalized_hostname, parsed.port) + except OSError as exc: + raise McpError(ErrorData( + code=INTERNAL_ERROR, + message=f"Failed to resolve hostname {hostname}: {exc}", + )) + + if not addresses: + raise McpError(ErrorData( + code=INTERNAL_ERROR, + message=f"Failed to resolve hostname {hostname}", + )) + + for address in addresses: + if _is_blocked_ip_address(address): + raise McpError(ErrorData( + code=INVALID_PARAMS, + message="Fetching private or non-public IP addresses is not allowed", + )) + + +async def get_url_safely(client, url: str, **kwargs): + current_url = url + for _ in range(MAX_REDIRECTS + 1): + validate_public_url(current_url) + response = await client.get(current_url, follow_redirects=False, **kwargs) + if response.status_code not in {301, 302, 303, 307, 308}: + return response + + location = response.headers.get("location") + if not location: + return response + current_url = urljoin(str(getattr(response, "url", current_url)), location) + + raise McpError(ErrorData( + code=INTERNAL_ERROR, + message=f"Too many redirects while fetching {url}", + )) def extract_content_from_html(html: str) -> str: @@ -74,9 +161,9 @@ async def check_may_autonomously_fetch_url(url: str, user_agent: str, proxy_url: async with AsyncClient(proxy=proxy_url) as client: try: - response = await client.get( + response = await get_url_safely( + client, robot_txt_url, - follow_redirects=True, headers={"User-Agent": user_agent}, ) except HTTPError: @@ -118,9 +205,9 @@ async def fetch_url( async with AsyncClient(proxy=proxy_url) as client: try: - response = await client.get( + response = await get_url_safely( + client, url, - follow_redirects=True, headers={"User-Agent": user_agent}, timeout=30, ) diff --git a/src/fetch/tests/test_server.py b/src/fetch/tests/test_server.py index 96c1cb38c7..2816b38ca7 100644 --- a/src/fetch/tests/test_server.py +++ b/src/fetch/tests/test_server.py @@ -10,9 +10,19 @@ check_may_autonomously_fetch_url, fetch_url, DEFAULT_USER_AGENT_AUTONOMOUS, + validate_public_url, ) +@pytest.fixture(autouse=True) +def mock_public_dns(monkeypatch): + """Keep URL safety checks deterministic in unit tests.""" + monkeypatch.setattr( + "mcp_server_fetch.server.resolve_hostname_addresses", + lambda hostname, port: {"93.184.216.34"}, + ) + + class TestGetRobotsTxtUrl: """Tests for get_robots_txt_url function.""" @@ -47,6 +57,37 @@ def test_http_url(self): assert result == "http://example.com/robots.txt" +class TestValidatePublicUrl: + """Tests for SSRF URL validation.""" + + @pytest.mark.parametrize( + "url", + [ + "file:///etc/passwd", + "http://localhost/admin", + "http://127.0.0.1/admin", + "http://[::1]/admin", + "http://169.254.169.254/latest/meta-data/", + "http://10.0.0.4/admin", + ], + ) + def test_blocks_non_public_urls(self, url): + with pytest.raises(McpError): + validate_public_url(url) + + def test_blocks_private_dns_resolution(self, monkeypatch): + monkeypatch.setattr( + "mcp_server_fetch.server.resolve_hostname_addresses", + lambda hostname, port: {"10.0.0.4"}, + ) + + with pytest.raises(McpError): + validate_public_url("https://internal.example.com/page") + + def test_allows_public_https_urls(self): + validate_public_url("https://example.com/page") + + class TestExtractContentFromHtml: """Tests for extract_content_from_html function.""" @@ -266,6 +307,44 @@ async def test_fetch_json_returns_raw(self): assert content == json_content assert "cannot be simplified" in prefix + @pytest.mark.asyncio + async def test_blocks_private_url_before_request(self): + """Test that private IP URLs are rejected before any HTTP request is sent.""" + with patch("httpx.AsyncClient") as mock_client_class: + mock_client = AsyncMock() + mock_client_class.return_value.__aenter__ = AsyncMock(return_value=mock_client) + mock_client_class.return_value.__aexit__ = AsyncMock(return_value=None) + + with pytest.raises(McpError): + await fetch_url( + "http://127.0.0.1/admin", + DEFAULT_USER_AGENT_AUTONOMOUS, + ) + + mock_client.get.assert_not_called() + + @pytest.mark.asyncio + async def test_blocks_redirect_to_private_url(self): + """Test that redirects are validated before following them.""" + redirect_response = MagicMock() + redirect_response.status_code = 302 + redirect_response.headers = {"location": "http://127.0.0.1/admin"} + redirect_response.url = "https://example.com/start" + + with patch("httpx.AsyncClient") as mock_client_class: + mock_client = AsyncMock() + mock_client.get = AsyncMock(return_value=redirect_response) + mock_client_class.return_value.__aenter__ = AsyncMock(return_value=mock_client) + mock_client_class.return_value.__aexit__ = AsyncMock(return_value=None) + + with pytest.raises(McpError): + await fetch_url( + "https://example.com/start", + DEFAULT_USER_AGENT_AUTONOMOUS, + ) + + assert mock_client.get.call_count == 1 + @pytest.mark.asyncio async def test_fetch_404_raises_error(self): """Test that 404 response raises McpError."""