Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 24 additions & 1 deletion src/mcp/server/auth/handlers/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Any
from uuid import uuid4

from pydantic import BaseModel, ValidationError
from pydantic import AnyUrl, BaseModel, ValidationError
from starlette.requests import Request
from starlette.responses import Response

Expand All @@ -18,12 +18,22 @@
# provider from what we use in the HTTP handler
RegistrationRequest = OAuthClientMetadata

_LOOPBACK_HOSTS = {"localhost", "127.0.0.1", "[::1]"}


class RegistrationErrorResponse(BaseModel):
error: RegistrationErrorCode
error_description: str | None


def _validate_redirect_uri(url: AnyUrl) -> None:
if url.scheme != "https" and not (url.scheme == "http" and url.host in _LOOPBACK_HOSTS):
raise ValueError("redirect_uris must use HTTPS unless they are HTTP loopback URLs")

if url.fragment is not None:
raise ValueError("redirect_uris must not include a fragment")


@dataclass
class RegistrationHandler:
provider: OAuthAuthorizationServerProvider[Any, Any, Any]
Expand All @@ -45,6 +55,19 @@ async def handle(self, request: Request) -> Response:
status_code=400,
)

if client_metadata.redirect_uris is not None:
try:
for redirect_uri in client_metadata.redirect_uris:
_validate_redirect_uri(redirect_uri)
except ValueError as error:
return PydanticJSONResponse(
content=RegistrationErrorResponse(
error="invalid_redirect_uri",
error_description=str(error),
),
status_code=400,
)

client_id = str(uuid4())

# If auth method is None, default to client_secret_post
Expand Down
58 changes: 58 additions & 0 deletions tests/server/mcpserver/auth/test_auth_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -656,6 +656,64 @@ async def test_client_registration(self, test_client: httpx.AsyncClient, mock_oa
# client_info["client_id"]
# ) is not None

@pytest.mark.anyio
async def test_client_registration_allows_loopback_redirect_uri(self, test_client: httpx.AsyncClient):
"""Test client registration with an HTTP loopback redirect URI."""
client_metadata = {
"redirect_uris": ["http://localhost:3030/callback"],
"client_name": "Loopback Client",
}

response = await test_client.post(
"/register",
json=client_metadata,
)
assert response.status_code == 201, response.content
assert response.json()["redirect_uris"] == ["http://localhost:3030/callback"]

@pytest.mark.anyio
async def test_client_registration_allows_null_redirect_uris(self, test_client: httpx.AsyncClient):
client_metadata = {
"redirect_uris": None,
"client_name": "No Redirect Client",
}

response = await test_client.post(
"/register",
json=client_metadata,
)
assert response.status_code == 201, response.content
assert "redirect_uris" not in response.json()

@pytest.mark.anyio
@pytest.mark.parametrize(
"redirect_uri",
[
"http://client.example.com/callback",
"javascript:alert(1)",
"data:text/html,<script>alert(1)</script>",
"file:///tmp/callback",
"https://client.example.com/callback#fragment",
"https://client.example.com/callback#",
],
)
async def test_client_registration_rejects_unsafe_redirect_uris(
self, test_client: httpx.AsyncClient, redirect_uri: str
):
"""Test client registration rejects unsafe redirect URI schemes and fragments."""
client_metadata = {
"redirect_uris": [redirect_uri],
"client_name": "Test Client",
}

response = await test_client.post(
"/register",
json=client_metadata,
)
assert response.status_code == 400
error_data = response.json()
assert error_data["error"] == "invalid_redirect_uri"

@pytest.mark.anyio
async def test_client_registration_missing_required_fields(self, test_client: httpx.AsyncClient):
"""Test client registration with missing required fields."""
Expand Down
Loading