Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion src/server/webStandardStreamableHttp.ts
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,7 @@ export class WebStandardStreamableHTTPServerTransport implements Transport {
// The client MUST include an Accept header, listing text/event-stream as a supported content type.
const acceptHeader = req.headers.get('accept');
if (!acceptHeader?.includes('text/event-stream')) {
this.onerror?.(new Error('Not Acceptable: Client must accept text/event-stream'));
return this.createJsonErrorResponse(406, -32000, 'Not Acceptable: Client must accept text/event-stream');
}

Expand All @@ -409,6 +410,7 @@ export class WebStandardStreamableHTTPServerTransport implements Transport {
// Check if there's already an active standalone SSE stream for this session
if (this._streamMapping.get(this._standaloneSseStreamId) !== undefined) {
// Only one GET SSE stream is allowed per session
this.onerror?.(new Error('Conflict: Only one SSE stream is allowed per session'));
return this.createJsonErrorResponse(409, -32000, 'Conflict: Only one SSE stream is allowed per session');
}

Expand Down Expand Up @@ -460,6 +462,7 @@ export class WebStandardStreamableHTTPServerTransport implements Transport {
*/
private async replayEvents(lastEventId: string): Promise<Response> {
if (!this._eventStore) {
this.onerror?.(new Error('Event store not configured'));
return this.createJsonErrorResponse(400, -32000, 'Event store not configured');
}

Expand All @@ -470,11 +473,13 @@ export class WebStandardStreamableHTTPServerTransport implements Transport {
streamId = await this._eventStore.getStreamIdForEventId(lastEventId);

if (!streamId) {
this.onerror?.(new Error('Invalid event ID format'));
return this.createJsonErrorResponse(400, -32000, 'Invalid event ID format');
}

// Check conflict with the SAME streamId we'll use for mapping
if (this._streamMapping.get(streamId) !== undefined) {
this.onerror?.(new Error('Conflict: Stream already has an active connection'));
return this.createJsonErrorResponse(409, -32000, 'Conflict: Stream already has an active connection');
}
}
Expand Down Expand Up @@ -556,7 +561,8 @@ export class WebStandardStreamableHTTPServerTransport implements Transport {
eventData += `data: ${JSON.stringify(message)}\n\n`;
controller.enqueue(encoder.encode(eventData));
return true;
} catch {
} catch (error) {
this.onerror?.(error as Error);
return false;
}
}
Expand All @@ -565,6 +571,7 @@ export class WebStandardStreamableHTTPServerTransport implements Transport {
* Handles unsupported requests (PUT, PATCH, etc.)
*/
private handleUnsupportedRequest(): Response {
this.onerror?.(new Error('Method not allowed.'));
return new Response(
JSON.stringify({
jsonrpc: '2.0',
Expand Down Expand Up @@ -593,6 +600,7 @@ export class WebStandardStreamableHTTPServerTransport implements Transport {
const acceptHeader = req.headers.get('accept');
// The client MUST include an Accept header, listing both application/json and text/event-stream as supported content types.
if (!acceptHeader?.includes('application/json') || !acceptHeader.includes('text/event-stream')) {
this.onerror?.(new Error('Not Acceptable: Client must accept both application/json and text/event-stream'));
return this.createJsonErrorResponse(
406,
-32000,
Expand All @@ -602,6 +610,7 @@ export class WebStandardStreamableHTTPServerTransport implements Transport {

const ct = req.headers.get('content-type');
if (!ct || !ct.includes('application/json')) {
this.onerror?.(new Error('Unsupported Media Type: Content-Type must be application/json'));
return this.createJsonErrorResponse(415, -32000, 'Unsupported Media Type: Content-Type must be application/json');
}

Expand All @@ -618,6 +627,7 @@ export class WebStandardStreamableHTTPServerTransport implements Transport {
try {
rawMessage = await req.json();
} catch {
this.onerror?.(new Error('Parse error: Invalid JSON'));
return this.createJsonErrorResponse(400, -32700, 'Parse error: Invalid JSON');
}
}
Expand All @@ -632,6 +642,7 @@ export class WebStandardStreamableHTTPServerTransport implements Transport {
messages = [JSONRPCMessageSchema.parse(rawMessage)];
}
} catch {
this.onerror?.(new Error('Parse error: Invalid JSON-RPC message'));
return this.createJsonErrorResponse(400, -32700, 'Parse error: Invalid JSON-RPC message');
}

Expand All @@ -642,9 +653,11 @@ export class WebStandardStreamableHTTPServerTransport implements Transport {
// If it's a server with session management and the session ID is already set we should reject the request
// to avoid re-initialization.
if (this._initialized && this.sessionId !== undefined) {
this.onerror?.(new Error('Invalid Request: Server already initialized'));
return this.createJsonErrorResponse(400, -32600, 'Invalid Request: Server already initialized');
}
if (messages.length > 1) {
this.onerror?.(new Error('Invalid Request: Only one initialization request is allowed'));
return this.createJsonErrorResponse(400, -32600, 'Invalid Request: Only one initialization request is allowed');
}
this.sessionId = this.sessionIdGenerator?.();
Expand Down Expand Up @@ -824,18 +837,21 @@ export class WebStandardStreamableHTTPServerTransport implements Transport {
}
if (!this._initialized) {
// If the server has not been initialized yet, reject all requests
this.onerror?.(new Error('Bad Request: Server not initialized'));
return this.createJsonErrorResponse(400, -32000, 'Bad Request: Server not initialized');
}

const sessionId = req.headers.get('mcp-session-id');

if (!sessionId) {
// Non-initialization requests without a session ID should return 400 Bad Request
this.onerror?.(new Error('Bad Request: Mcp-Session-Id header is required'));
return this.createJsonErrorResponse(400, -32000, 'Bad Request: Mcp-Session-Id header is required');
}

if (sessionId !== this.sessionId) {
// Reject requests with invalid session ID with 404 Not Found
this.onerror?.(new Error('Session not found'));
return this.createJsonErrorResponse(404, -32001, 'Session not found');
}

Expand All @@ -859,6 +875,12 @@ export class WebStandardStreamableHTTPServerTransport implements Transport {
const protocolVersion = req.headers.get('mcp-protocol-version');

if (protocolVersion !== null && !SUPPORTED_PROTOCOL_VERSIONS.includes(protocolVersion)) {
this.onerror?.(
new Error(
`Bad Request: Unsupported protocol version: ${protocolVersion}` +
` (supported versions: ${SUPPORTED_PROTOCOL_VERSIONS.join(', ')})`
)
);
return this.createJsonErrorResponse(
400,
-32000,
Expand Down
160 changes: 160 additions & 0 deletions test/server/streamableHttp.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import { createServer, type Server, IncomingMessage, ServerResponse } from 'node
import { AddressInfo, createServer as netCreateServer } from 'node:net';
import { randomUUID } from 'node:crypto';
import { EventStore, StreamableHTTPServerTransport, EventId, StreamId } from '../../src/server/streamableHttp.js';
import { WebStandardStreamableHTTPServerTransport } from '../../src/server/webStandardStreamableHttp.js';
import { McpServer } from '../../src/server/mcp.js';
import { CallToolResult, JSONRPCMessage } from '../../src/types.js';
import { AuthInfo } from '../../src/server/auth/types.js';
Expand Down Expand Up @@ -3112,3 +3113,162 @@ async function createTestServerWithDnsProtection(config: {
baseUrl: serverUrl
};
}

describe('WebStandardStreamableHTTPServerTransport - onerror callback', () => {
let transport: WebStandardStreamableHTTPServerTransport;
let mcpServer: McpServer;
let onerrorSpy: ReturnType<typeof vi.fn<(error: Error) => void>>;

/** Shorthand to build a Web Standard Request for direct transport testing. */
function req(method: string, opts?: { body?: unknown; headers?: Record<string, string> }): Request {
const headers: Record<string, string> = { ...opts?.headers };
if (method === 'POST') {
headers['Accept'] ??= 'application/json, text/event-stream';
headers['Content-Type'] ??= 'application/json';
} else if (method === 'GET') {
headers['Accept'] ??= 'text/event-stream';
}
return new Request('http://localhost/mcp', {
method,
headers,
body: opts?.body !== undefined ? (typeof opts.body === 'string' ? opts.body : JSON.stringify(opts.body)) : undefined
});
}

function withSession(sessionId: string, extra?: Record<string, string>): Record<string, string> {
return { 'mcp-session-id': sessionId, 'mcp-protocol-version': '2025-11-25', ...extra };
}

beforeEach(async () => {
onerrorSpy = vi.fn<(error: Error) => void>();
mcpServer = new McpServer({ name: 'test-server', version: '1.0.0' });
transport = new WebStandardStreamableHTTPServerTransport({ sessionIdGenerator: () => randomUUID() });
transport.onerror = onerrorSpy;
await mcpServer.connect(transport);
});

afterEach(async () => {
await transport.close();
});

async function initializeServer(): Promise<string> {
onerrorSpy.mockClear();
const response = await transport.handleRequest(req('POST', { body: TEST_MESSAGES.initialize }));
expect(response.status).toBe(200);
return response.headers.get('mcp-session-id') as string;
}

it('should call onerror for invalid JSON in POST', async () => {
await initializeServer();
await transport.handleRequest(req('POST', { body: 'not valid json' }));
expect(onerrorSpy).toHaveBeenCalled();
expect(onerrorSpy.mock.calls[0]![0]!.message).toMatch(/Invalid JSON/);
});

it('should call onerror for invalid JSON-RPC message', async () => {
const sid = await initializeServer();
await transport.handleRequest(req('POST', { body: { not: 'valid' }, headers: withSession(sid) }));
expect(onerrorSpy).toHaveBeenCalled();
expect(onerrorSpy.mock.calls[0]![0]!.message).toMatch(/Invalid JSON-RPC message/);
});

it('should call onerror for missing Accept header on POST', async () => {
await transport.handleRequest(
req('POST', { body: TEST_MESSAGES.initialize, headers: { Accept: 'application/json', 'Content-Type': 'application/json' } })
);
expect(onerrorSpy).toHaveBeenCalled();
expect(onerrorSpy.mock.calls[0]![0]!.message).toMatch(/Not Acceptable/);
});

it('should call onerror for unsupported Content-Type', async () => {
await transport.handleRequest(
req('POST', {
body: TEST_MESSAGES.initialize,
headers: { Accept: 'application/json, text/event-stream', 'Content-Type': 'text/plain' }
})
);
expect(onerrorSpy).toHaveBeenCalled();
expect(onerrorSpy.mock.calls[0]![0]!.message).toMatch(/Unsupported Media Type/);
});

it('should call onerror when server is not initialized', async () => {
await transport.handleRequest(req('POST', { body: TEST_MESSAGES.toolsList }));
expect(onerrorSpy).toHaveBeenCalledTimes(1);
expect(onerrorSpy.mock.calls[0]![0]!.message).toMatch(/Server not initialized/);
});

it('should call onerror for invalid session ID', async () => {
await initializeServer();
await transport.handleRequest(req('POST', { body: TEST_MESSAGES.toolsList, headers: withSession('invalid-session-id') }));
expect(onerrorSpy).toHaveBeenCalled();
expect(onerrorSpy.mock.calls[0]![0]!.message).toMatch(/Session not found/);
});

it('should call onerror for re-initialization attempt', async () => {
await initializeServer();
await transport.handleRequest(req('POST', { body: TEST_MESSAGES.initialize }));
expect(onerrorSpy).toHaveBeenCalled();
expect(onerrorSpy.mock.calls[0]![0]!.message).toMatch(/Server already initialized/);
});

it('should call onerror for missing Accept header on GET', async () => {
const sid = await initializeServer();
await transport.handleRequest(req('GET', { headers: { Accept: 'application/json', ...withSession(sid) } }));
expect(onerrorSpy).toHaveBeenCalled();
expect(onerrorSpy.mock.calls[0]![0]!.message).toMatch(/Not Acceptable/);
});

it('should call onerror for concurrent SSE streams', async () => {
const sid = await initializeServer();
const response1 = await transport.handleRequest(req('GET', { headers: withSession(sid) }));
expect(response1.status).toBe(200);
await transport.handleRequest(req('GET', { headers: withSession(sid) }));
expect(onerrorSpy).toHaveBeenCalled();
expect(onerrorSpy.mock.calls[0]![0]!.message).toMatch(/Only one SSE stream/);
});

it('should call onerror for unsupported protocol version', async () => {
const sid = await initializeServer();
await transport.handleRequest(
req('POST', { body: TEST_MESSAGES.toolsList, headers: withSession(sid, { 'mcp-protocol-version': 'unsupported-version' }) })
);
expect(onerrorSpy).toHaveBeenCalled();
expect(onerrorSpy.mock.calls[0]![0]!.message).toMatch(/Unsupported protocol version/);
});

it('should call onerror for unsupported HTTP methods', async () => {
await transport.handleRequest(req('PUT'));
expect(onerrorSpy).toHaveBeenCalledTimes(1);
expect(onerrorSpy.mock.calls[0]![0]!.message).toMatch(/Method not allowed/);
});

it('should call onerror for invalid event ID in replay', async () => {
const eventStore: EventStore = {
async storeEvent(): Promise<EventId> {
return 'evt-1';
},
async getStreamIdForEventId(): Promise<StreamId | undefined> {
return undefined;
},
async replayEventsAfter(): Promise<StreamId> {
return 'stream-1';
}
};
const storeTransport = new WebStandardStreamableHTTPServerTransport({ sessionIdGenerator: () => randomUUID(), eventStore });
const storeSpy = vi.fn<(error: Error) => void>();
storeTransport.onerror = storeSpy;
await new McpServer({ name: 'test', version: '1.0.0' }).connect(storeTransport);

const initResp = await storeTransport.handleRequest(req('POST', { body: TEST_MESSAGES.initialize }));
const sid = initResp.headers.get('mcp-session-id') as string;
storeSpy.mockClear();

const response = await storeTransport.handleRequest(
req('GET', { headers: { ...withSession(sid), 'Last-Event-ID': 'unknown-event-id' } })
);
expect(response.status).toBe(400);
expect(storeSpy).toHaveBeenCalledTimes(1);
expect(storeSpy.mock.calls[0]![0]!.message).toMatch(/Invalid event ID format/);
await storeTransport.close();
});
});
Loading