feat: unified OAuth MCP connector (read+write on /readonly/mcp)

Scope-gated tools on one Streamable HTTP URL; /oauth/mcp alias.
Reconnect Claude once with ctxd.read ctxd.write.
This commit is contained in:
2026-06-25 12:44:01 +00:00
parent e3567f649f
commit 12b60ee8c7
3 changed files with 301 additions and 74 deletions
+143 -66
View File
@@ -3,6 +3,7 @@ ctxd server — dual-protocol daemon serving context over MCP + HTTP.
"""
import asyncio
import base64
import contextvars
import hashlib
import html
import json
@@ -44,6 +45,14 @@ WRITE_MCP_TOOLS = {
"sync_to_project",
}
# Per-request OAuth scopes for the unified public MCP connector (/readonly/mcp).
_oauth_mcp_ctx: contextvars.ContextVar[dict] = contextvars.ContextVar(
"oauth_mcp_ctx",
default={"has_read": False, "has_write": False, "user_id": "oauth"},
)
PUBLIC_OAUTH_MCP_PATHS = frozenset({"/readonly/mcp", "/oauth/mcp"})
def _b64url_sha256(value: str) -> str:
digest = hashlib.sha256(value.encode()).digest()
@@ -243,6 +252,20 @@ class OAuthStore:
return False, None
return True, record.get("approved_by_user_id")
def resolve_public_oauth(self, token: str) -> tuple[bool, bool, bool, str | None]:
"""Valid OAuth access token for public MCP. Returns (ok, has_read, has_write, user_id)."""
if not token:
return False, False, False, None
record = self.state.get("access_tokens", {}).get(token)
if not record or record.get("expires_at", 0) <= _now():
return False, False, False, None
scope = record.get("scope", "")
has_read = "ctxd.read" in scope
has_write = "ctxd.write" in scope
if not has_read and not has_write:
return False, False, False, None
return True, has_read, has_write, record.get("approved_by_user_id")
class WebSessionStore:
"""Opaque bearer sessions for CTXD Web UI per-user login."""
@@ -331,9 +354,12 @@ def _public_user_row(user: dict) -> dict:
# ── MCP Server ─────────────────────────────────────────────────────────────────
def make_mcp_server(cfg: CtxConfig, readonly: bool = False):
def make_mcp_server(cfg: CtxConfig, readonly: bool = False, oauth_scoped: bool = False):
"""Create an MCP server instance wired to our database."""
app = Server("ctxd-readonly" if readonly else "context-dossier")
if oauth_scoped:
app = Server("ctxd-oauth")
else:
app = Server("ctxd-readonly" if readonly else "context-dossier")
def _conn():
"""Short-lived connection per request (WAL allows concurrent reads)."""
@@ -503,17 +529,50 @@ def make_mcp_server(cfg: CtxConfig, readonly: bool = False):
},
),
]
if oauth_scoped:
ctx = _oauth_mcp_ctx.get()
allowed: set[str] = set()
if ctx.get("has_read"):
allowed |= READONLY_MCP_TOOLS
if ctx.get("has_write"):
allowed |= WRITE_MCP_TOOLS
if not allowed:
allowed = {"get_client_guide"}
return [tool for tool in tools if tool.name in allowed]
if readonly:
return [tool for tool in tools if tool.name in READONLY_MCP_TOOLS]
return tools
@app.call_tool()
async def call_tool(name: str, arguments: dict):
if readonly and name not in READONLY_MCP_TOOLS:
if oauth_scoped:
ctx = _oauth_mcp_ctx.get()
if name in WRITE_MCP_TOOLS and not ctx.get("has_write"):
return [types.TextContent(
type="text",
text=json.dumps({
"error": "forbidden",
"tool": name,
"message": "ctxd.write scope required",
}, indent=2),
)]
if name in READONLY_MCP_TOOLS and not ctx.get("has_read"):
return [types.TextContent(
type="text",
text=json.dumps({
"error": "forbidden",
"tool": name,
"message": "ctxd.read scope required",
}, indent=2),
)]
actor = ctx.get("user_id") or "oauth"
elif readonly and name not in READONLY_MCP_TOOLS:
return [types.TextContent(
type="text",
text=json.dumps({"error": "forbidden", "tool": name, "message": "read-only endpoint"}, indent=2),
)]
else:
actor = "hermes-gateway"
conn = _conn()
try:
@@ -536,7 +595,7 @@ def make_mcp_server(cfg: CtxConfig, readonly: bool = False):
"updated_at": ctx.get("updated_at"),
"content": ctx["content"],
}
_db.audit_log(conn, "hermes-gateway", "read",
_db.audit_log(conn, actor, "read",
f"MCP read context for {pid}",
agent_id="hermes", project_id=pid,
entity_type="project", entity_id=pid)
@@ -554,7 +613,7 @@ def make_mcp_server(cfg: CtxConfig, readonly: bool = False):
query = arguments["query"]
limit = arguments.get("limit", 10)
results = _db.search(conn, query, limit=limit)
_db.audit_log(conn, "hermes-gateway", "search",
_db.audit_log(conn, actor, "search",
f"MCP search: {query[:80]}",
agent_id="hermes")
conn.commit()
@@ -584,7 +643,7 @@ def make_mcp_server(cfg: CtxConfig, readonly: bool = False):
pid = arguments["project_id"]
tags = [str(t).upper().replace(" ", "-") for t in arguments.get("tags", [])]
_db.project_set_tags(conn, pid, tags)
_db.audit_log(conn, "hermes-gateway", "set_tags",
_db.audit_log(conn, actor, "set_tags",
f"Set tags for {pid}: {', '.join(tags)}",
agent_id="hermes", project_id=pid,
entity_type="project", entity_id=pid)
@@ -625,7 +684,7 @@ def make_mcp_server(cfg: CtxConfig, readonly: bool = False):
result = _db.file_read(conn, pid, file_path)
if result is None:
return [types.TextContent(type="text", text=f"File '{file_path}' not found in project '{pid}'")]
_db.audit_log(conn, "hermes-gateway", "read",
_db.audit_log(conn, actor, "read",
f"MCP read file {file_path} for {pid}",
agent_id="hermes", project_id=pid,
entity_type="file", entity_id=file_path)
@@ -649,7 +708,7 @@ def make_mcp_server(cfg: CtxConfig, readonly: bool = False):
file_path = arguments["file_path"]
content = arguments["content"]
base_version = arguments["base_version"]
result = _db.file_update(conn, pid, file_path, content, "hermes-gateway", base_version)
result = _db.file_update(conn, pid, file_path, content, actor, base_version)
if result.get("ok"):
conn.commit()
return [types.TextContent(
@@ -1261,7 +1320,7 @@ class HTTPServer:
self._conn.commit()
issuer = (self.cfg.oauth_issuer or "").rstrip("/") or "https://ctxd.cubecraftcreations.com"
out = {k: v for k, v in client.items()}
out["connector_url"] = f"{issuer}/write/sse"
out["connector_url"] = f"{issuer}/readonly/mcp"
out["authorization_server"] = issuer
return (201, {"Content-Type": "application/json"}, json.dumps(out))
@@ -1300,9 +1359,11 @@ class CombinedApp:
self.oauth_store = OAuthStore(cfg)
self.http_handler = HTTPServer(cfg, self.session_store, self.oauth_store)
self.mcp_app = make_mcp_server(cfg)
self.oauth_mcp_app = make_mcp_server(cfg, oauth_scoped=True)
self.readonly_mcp_app = make_mcp_server(cfg, readonly=True)
self.write_mcp_app = make_write_mcp_server(cfg)
self._mcp_init_opts = self.mcp_app.create_initialization_options()
self._oauth_mcp_init_opts = self.oauth_mcp_app.create_initialization_options()
self._readonly_mcp_init_opts = self.readonly_mcp_app.create_initialization_options()
self._write_mcp_init_opts = self.write_mcp_app.create_initialization_options()
@@ -1569,59 +1630,36 @@ class CombinedApp:
await self._serve_oauth(scope, receive, send, oauth_body)
return
# Public read-only MCP endpoint for Claude Desktop/web connectors.
# GET accepts OAuth bearer tokens. The legacy external_readonly_key query
# fallback remains temporarily for migration away from ?key= URLs.
def _readonly_token_valid(token: str) -> bool:
if self.cfg.api_key and token == self.cfg.api_key:
return True
if self.cfg.external_readonly_key and token == self.cfg.external_readonly_key:
return True
if self.cfg.oauth_enabled and self.oauth_store.validate_access_token(token):
return True
return False
# Public read-only MCP (Streamable HTTP) — OAuth bearer or legacy readonly key.
if path == "/readonly/mcp":
# Public OAuth MCP (Streamable HTTP) — single connector for read + write.
if path in PUBLIC_OAUTH_MCP_PATHS:
token = _request_token()
if not _readonly_token_valid(token):
ok, ctx = self._public_mcp_auth_context(token)
if not ok:
await _auth_error()
return
await self._serve_streamable_mcp(
scope, receive, send,
self.readonly_mcp_app, self._readonly_mcp_init_opts,
surface_name="readonly",
)
await self._serve_oauth_mcp_streamable(scope, receive, send, ctx)
return
# Public write MCP — SSE transport (OAuth ctxd.write scope)
# Legacy write-only SSE (same tools as unified connector; prefer /readonly/mcp).
# Note: Using SSE instead of Streamable HTTP because MCP SDK 1.28's
# Streamable HTTP transport has a race condition where EventSourceResponse
# is killed by the task group before sending headers.
if method == "GET" and path == "/write/sse":
token = _request_token()
valid, _ = (
self.oauth_store.validate_write_token(token)
if self.cfg.oauth_enabled
else (False, None)
)
if not valid:
ok, ctx = self._public_mcp_auth_context(token)
if not ok:
await _auth_error()
return
await self._serve_write_mcp_sse(scope, receive, send)
await self._serve_write_mcp_sse(scope, receive, send, ctx)
return
if method == "POST" and path in ("/write/messages", "/write/messages/"):
token = _request_token()
valid, _ = (
self.oauth_store.validate_write_token(token)
if self.cfg.oauth_enabled
else (False, None)
)
if not valid:
ok, ctx = self._public_mcp_auth_context(token)
if not ok:
await _auth_error()
return
await self._serve_write_mcp_sse(scope, receive, send)
await self._serve_write_mcp_sse(scope, receive, send, ctx)
return
# Internal full MCP (Streamable HTTP) — shared API key only.
@@ -1720,32 +1758,70 @@ class CombinedApp:
"body": body_str.encode(),
})
async def _serve_write_mcp_sse(self, scope, receive, send):
"""Write MCP over SSE transport for OAuth ctxd.write clients."""
def _public_mcp_auth_context(self, token: str) -> tuple[bool, dict]:
"""Resolve auth for the unified public MCP connector."""
if self.cfg.api_key and token == self.cfg.api_key:
return True, {"has_read": True, "has_write": False, "user_id": "api-key"}
if self.cfg.external_readonly_key and token == self.cfg.external_readonly_key:
return True, {"has_read": True, "has_write": False, "user_id": "readonly-key"}
if self.cfg.oauth_enabled:
ok, has_read, has_write, uid = self.oauth_store.resolve_public_oauth(token)
if ok:
return True, {
"has_read": has_read,
"has_write": has_write,
"user_id": uid or "oauth",
}
return False, {}
async def _serve_oauth_mcp_streamable(self, scope, receive, send, ctx: dict):
reset = _oauth_mcp_ctx.set(ctx)
try:
await self._serve_streamable_mcp(
scope, receive, send,
self.oauth_mcp_app, self._oauth_mcp_init_opts,
surface_name="oauth",
)
finally:
_oauth_mcp_ctx.reset(reset)
async def _serve_write_mcp_sse(self, scope, receive, send, ctx: dict | None = None):
"""Write MCP over SSE transport (legacy; same app as /readonly/mcp)."""
from mcp.server.sse import SseServerTransport
if not hasattr(self, "_write_sse_transport") or self._write_sse_transport is None:
self._write_sse_transport = SseServerTransport("/write/messages")
sse = self._write_sse_transport
if ctx is not None:
reset = _oauth_mcp_ctx.set(ctx)
else:
reset = None
method = scope.get("method", "GET")
path = scope.get("path", "/")
try:
if not hasattr(self, "_write_sse_transport") or self._write_sse_transport is None:
self._write_sse_transport = SseServerTransport("/write/messages")
sse = self._write_sse_transport
if method == "GET" and path == "/write/sse":
async with sse.connect_sse(scope, receive, send) as streams:
await self.write_mcp_app.run(streams[0], streams[1], self._write_mcp_init_opts)
return
method = scope.get("method", "GET")
path = scope.get("path", "/")
if method == "POST" and path in ("/write/messages", "/write/messages/"):
await sse.handle_post_message(scope, receive, send)
return
if method == "GET" and path == "/write/sse":
async with sse.connect_sse(scope, receive, send) as streams:
await self.oauth_mcp_app.run(
streams[0], streams[1], self._oauth_mcp_init_opts,
)
return
await send({
"type": "http.response.start",
"status": 404,
"headers": [(b"content-type", b"text/plain")],
})
await send({"type": "http.response.body", "body": b"Not found"})
if method == "POST" and path in ("/write/messages", "/write/messages/"):
await sse.handle_post_message(scope, receive, send)
return
await send({
"type": "http.response.start",
"status": 404,
"headers": [(b"content-type", b"text/plain")],
})
await send({"type": "http.response.body", "body": b"Not found"})
finally:
if reset is not None:
_oauth_mcp_ctx.reset(reset)
async def _serve_streamable_mcp(self, scope, receive, send, mcp_app, init_opts, surface_name="mcp"):
"""Streamable HTTP transport for MCP with persistent session handling.
@@ -1798,8 +1874,9 @@ class CombinedApp:
nonlocal header_injected
if message.get("type") == "http.response.start" and not header_injected:
existing = list(message.get("headers", []))
existing.append((b"mcp-session-id", new_session_id.encode()))
message["headers"] = existing
if not any(k.lower() == b"mcp-session-id" for k, _ in existing):
existing.append((b"mcp-session-id", new_session_id.encode()))
message["headers"] = existing
header_injected = True
await original_send(message)