fix: OAuth write MCP via SSE; require container recreate after build
- Route public write at GET /write/sse and POST /write/messages (ctxd.write) - Always require write token on /write/messages (was optional) - Remove debug tracing; document SSE write surface in SKILL.md - Add scripts/test_write_mcp.py for local OAuth write smoke test
This commit is contained in:
+111
-18
@@ -1261,7 +1261,7 @@ class HTTPServer:
|
||||
self._conn.commit()
|
||||
issuer = (self.cfg.oauth_issuer or "").rstrip("/") or "https://ctxd.cubecraftcreations.com"
|
||||
out = {k: v for k, v in client.items()}
|
||||
out["connector_url"] = f"{issuer}/readonly/mcp"
|
||||
out["connector_url"] = f"{issuer}/write/sse"
|
||||
out["authorization_server"] = issuer
|
||||
return (201, {"Content-Type": "application/json"}, json.dumps(out))
|
||||
|
||||
@@ -1590,20 +1590,38 @@ class CombinedApp:
|
||||
await self._serve_streamable_mcp(
|
||||
scope, receive, send,
|
||||
self.readonly_mcp_app, self._readonly_mcp_init_opts,
|
||||
surface_name="readonly",
|
||||
)
|
||||
return
|
||||
|
||||
# Public write MCP (Streamable HTTP) — OAuth bearer with ctxd.write scope.
|
||||
if path == "/write/mcp":
|
||||
# Public write MCP — SSE transport (OAuth ctxd.write scope)
|
||||
# Note: Using SSE instead of Streamable HTTP because MCP SDK 1.28's
|
||||
# Streamable HTTP transport has a race condition where EventSourceResponse
|
||||
# is killed by the task group before sending headers.
|
||||
if method == "GET" and path == "/write/sse":
|
||||
token = _request_token()
|
||||
valid, _ = self.oauth_store.validate_write_token(token) if self.cfg.oauth_enabled else (False, None)
|
||||
valid, _ = (
|
||||
self.oauth_store.validate_write_token(token)
|
||||
if self.cfg.oauth_enabled
|
||||
else (False, None)
|
||||
)
|
||||
if not valid:
|
||||
await _auth_error()
|
||||
return
|
||||
await self._serve_streamable_mcp(
|
||||
scope, receive, send,
|
||||
self.write_mcp_app, self._write_mcp_init_opts,
|
||||
await self._serve_write_mcp_sse(scope, receive, send)
|
||||
return
|
||||
|
||||
if method == "POST" and path in ("/write/messages", "/write/messages/"):
|
||||
token = _request_token()
|
||||
valid, _ = (
|
||||
self.oauth_store.validate_write_token(token)
|
||||
if self.cfg.oauth_enabled
|
||||
else (False, None)
|
||||
)
|
||||
if not valid:
|
||||
await _auth_error()
|
||||
return
|
||||
await self._serve_write_mcp_sse(scope, receive, send)
|
||||
return
|
||||
|
||||
# Internal full MCP (Streamable HTTP) — shared API key only.
|
||||
@@ -1615,6 +1633,7 @@ class CombinedApp:
|
||||
await self._serve_streamable_mcp(
|
||||
scope, receive, send,
|
||||
self.mcp_app, self._mcp_init_opts,
|
||||
surface_name="internal",
|
||||
)
|
||||
return
|
||||
token = _request_token()
|
||||
@@ -1701,20 +1720,94 @@ class CombinedApp:
|
||||
"body": body_str.encode(),
|
||||
})
|
||||
|
||||
async def _serve_streamable_mcp(self, scope, receive, send, mcp_app, init_opts):
|
||||
"""Streamable HTTP transport for MCP (single endpoint handles POST/GET/DELETE)."""
|
||||
import anyio
|
||||
async def _serve_write_mcp_sse(self, scope, receive, send):
|
||||
"""Write MCP over SSE transport for OAuth ctxd.write clients."""
|
||||
from mcp.server.sse import SseServerTransport
|
||||
|
||||
if not hasattr(self, "_write_sse_transport") or self._write_sse_transport is None:
|
||||
self._write_sse_transport = SseServerTransport("/write/messages")
|
||||
sse = self._write_sse_transport
|
||||
|
||||
method = scope.get("method", "GET")
|
||||
path = scope.get("path", "/")
|
||||
|
||||
if method == "GET" and path == "/write/sse":
|
||||
async with sse.connect_sse(scope, receive, send) as streams:
|
||||
await self.write_mcp_app.run(streams[0], streams[1], self._write_mcp_init_opts)
|
||||
return
|
||||
|
||||
if method == "POST" and path in ("/write/messages", "/write/messages/"):
|
||||
await sse.handle_post_message(scope, receive, send)
|
||||
return
|
||||
|
||||
await send({
|
||||
"type": "http.response.start",
|
||||
"status": 404,
|
||||
"headers": [(b"content-type", b"text/plain")],
|
||||
})
|
||||
await send({"type": "http.response.body", "body": b"Not found"})
|
||||
|
||||
async def _serve_streamable_mcp(self, scope, receive, send, mcp_app, init_opts, surface_name="mcp"):
|
||||
"""Streamable HTTP transport for MCP with persistent session handling.
|
||||
|
||||
Transports are cached by session ID. The first request creates a
|
||||
transport, starts the MCP server, and handles the request. Subsequent
|
||||
requests with the same session ID reuse the transport.
|
||||
"""
|
||||
from mcp.server.streamable_http import StreamableHTTPServerTransport
|
||||
|
||||
transport = StreamableHTTPServerTransport(mcp_session_id=None)
|
||||
async with transport.connect() as (read_stream, write_stream):
|
||||
async def run_server():
|
||||
await mcp_app.run(read_stream, write_stream, init_opts)
|
||||
# Check for existing session ID from request headers
|
||||
req_headers = {k.decode().lower(): v.decode() for k, v in scope.get("headers", [])}
|
||||
session_id = req_headers.get("mcp-session-id", "")
|
||||
|
||||
async with anyio.create_task_group() as tg:
|
||||
tg.start_soon(run_server)
|
||||
await transport.handle_request(scope, receive, send)
|
||||
tg.cancel_scope.cancel()
|
||||
# Get or create session cache for this surface
|
||||
cache_attr = f"_streamable_{surface_name}_sessions"
|
||||
if not hasattr(self, cache_attr):
|
||||
setattr(self, cache_attr, {})
|
||||
sessions = getattr(self, cache_attr)
|
||||
|
||||
# Reuse existing transport for this session
|
||||
if session_id and session_id in sessions:
|
||||
transport, server_task = sessions[session_id]
|
||||
await transport.handle_request(scope, receive, send)
|
||||
return
|
||||
|
||||
# New session — generate session ID, create transport, start server
|
||||
import secrets as _secrets
|
||||
import asyncio
|
||||
new_session_id = _secrets.token_hex(16)
|
||||
transport = StreamableHTTPServerTransport(
|
||||
mcp_session_id=new_session_id,
|
||||
is_json_response_enabled=True,
|
||||
)
|
||||
|
||||
# Connect the transport (creates read/write streams) — keep open for session lifetime
|
||||
cm = transport.connect()
|
||||
read_stream, write_stream = await cm.__aenter__()
|
||||
|
||||
async def run_server():
|
||||
await mcp_app.run(read_stream, write_stream, init_opts)
|
||||
|
||||
server_task = asyncio.create_task(run_server())
|
||||
|
||||
# Wrap send to inject Mcp-Session-Id header
|
||||
original_send = send
|
||||
header_injected = False
|
||||
|
||||
async def send_with_session_id(message):
|
||||
nonlocal header_injected
|
||||
if message.get("type") == "http.response.start" and not header_injected:
|
||||
existing = list(message.get("headers", []))
|
||||
existing.append((b"mcp-session-id", new_session_id.encode()))
|
||||
message["headers"] = existing
|
||||
header_injected = True
|
||||
await original_send(message)
|
||||
|
||||
# Handle the request (this blocks until the HTTP response is complete)
|
||||
await transport.handle_request(scope, receive, send_with_session_id)
|
||||
|
||||
# Store transport and server task for reuse on subsequent requests
|
||||
sessions[new_session_id] = (transport, server_task)
|
||||
|
||||
|
||||
# ── Entry point ────────────────────────────────────────────────────────────────
|
||||
|
||||
Reference in New Issue
Block a user