diff --git a/app/src/ctxd/db.py b/app/src/ctxd/db.py index b4f108a..c368aa2 100644 --- a/app/src/ctxd/db.py +++ b/app/src/ctxd/db.py @@ -114,6 +114,22 @@ def _is_pg(conn) -> bool: return not isinstance(conn, sqlite3.Connection) +def is_integrity_error(exc: BaseException) -> bool: + """True if exc is a DB constraint violation (unique/primary-key, FK, check) + on either backend. Used to map duplicate-key errors to HTTP 409. + + SQLite raises sqlite3.IntegrityError; psycopg raises subclasses of + psycopg.errors.IntegrityError (UniqueViolation, ForeignKeyViolation, ...). + """ + if isinstance(exc, sqlite3.IntegrityError): + return True + try: + import psycopg + except ImportError: + return False + return isinstance(exc, psycopg.errors.IntegrityError) + + # ── Helpers ─────────────────────────────────────────────────────────────────── def _row_to_dict(row) -> dict | None: diff --git a/app/src/ctxd/server.py b/app/src/ctxd/server.py index 9486491..23536f3 100644 --- a/app/src/ctxd/server.py +++ b/app/src/ctxd/server.py @@ -796,111 +796,6 @@ def make_mcp_server(cfg: CtxConfig, readonly: bool = False, oauth_scoped: bool = return app -def make_write_mcp_server(cfg: CtxConfig): - """Create an MCP server that exposes only write tools (OAuth ctxd.write scope).""" - app = Server("ctxd-write") - - def _conn(): - return _db.init_db(cfg) - - @app.list_tools() - async def list_tools(): - tools = [ - types.Tool( - name="update_file", - description="Update a single context file in a project with optimistic version checking", - inputSchema={ - "type": "object", - "properties": { - "project_id": {"type": "string", "description": "Project slug"}, - "file_path": {"type": "string", "description": "File name e.g. CONTEXT.md"}, - "content": {"type": "string", "description": "New file content (markdown)"}, - "base_version": {"type": "integer", "description": "Current version of the file (for conflict detection)"}, - }, - "required": ["project_id", "file_path", "content", "base_version"], - }, - ), - types.Tool( - name="set_project_tags", - description="Set metadata tags for a project (replaces all tags)", - inputSchema={ - "type": "object", - "properties": { - "project_id": {"type": "string", "description": "Project slug"}, - "tags": {"type": "array", "items": {"type": "string"}, "description": "Uppercase metadata tags"}, - }, - "required": ["project_id", "tags"], - }, - ), - types.Tool( - name="sync_to_project", - description="Write current shared context as AGENTS.md + symlinks to the project root", - inputSchema={ - "type": "object", - "properties": { - "project_id": {"type": "string", "description": "Project slug"}, - }, - "required": ["project_id"], - }, - ), - types.Tool( - name="get_client_guide", - description="Return the CTXD client guide (LLM-CLIENT.MD) — always read this first. Covers OAuth, MCP tools, read/write endpoints, version-checked updates, and error handling.", - inputSchema={ - "type": "object", - "properties": {}, - }, - ), - ] - return tools - - @app.call_tool() - async def call_tool(name: str, arguments: dict): - conn = _conn() - try: - if name == "get_client_guide": - result = _db.file_read(conn, "ctxd-docs", "LLM-CLIENT.MD") - if result is None: - return [types.TextContent(type="text", text="Client guide not found — ctxd-docs/LLM-CLIENT.MD is missing")] - return [types.TextContent( - type="text", - text=json.dumps(result, indent=2), - )] - - elif name == "update_file": - pid = arguments["project_id"] - file_path = arguments["file_path"] - content = arguments["content"] - base_version = arguments.get("base_version", 0) - result = _db.file_update(conn, pid, file_path, content, base_version, updated_by="oauth-write") - conn.commit() - return [types.TextContent(type="text", text=json.dumps(result, indent=2))] - - elif name == "set_project_tags": - pid = arguments["project_id"] - tags = [str(t).upper().replace(" ", "-") for t in arguments.get("tags", [])] - _db.project_set_tags(conn, pid, tags) - _db.audit_log(conn, "oauth-write", "set_tags", - f"Set tags for {pid}: {', '.join(tags)}", - agent_id="oauth", project_id=pid, - entity_type="project", entity_id=pid) - conn.commit() - return [types.TextContent(type="text", text=json.dumps({"ok": True, "project_id": pid, "tags": tags}, indent=2))] - - elif name == "sync_to_project": - pid = arguments["project_id"] - result = _db.sync_to_project(conn, pid) - conn.commit() - return [types.TextContent(type="text", text=json.dumps(result, indent=2))] - - else: - return [types.TextContent(type="text", text=json.dumps({"error": "unknown tool", "tool": name}, indent=2))] - finally: - conn.close() - - return app - - # ── HTTP Server (stdlib-only, no dependencies) ──────────────────────────────── class HTTPServer: @@ -917,6 +812,17 @@ class HTTPServer: try: return self._route(method, path, body, auth or {}) except Exception as e: + # The shared PG connection is left in an aborted-transaction state by + # any failed statement; without this rollback every subsequent request + # 500s ("current transaction is aborted"). This is the single funnel + # that guarantees the connection is clean no matter which path failed. + try: + self._conn.rollback() + except Exception: + logger.exception("Rollback failed after request error") + if _db.is_integrity_error(e): + return (409, {"Content-Type": "application/json"}, + json.dumps({"error": "conflict", "detail": str(e)})) logger.exception("HTTP error") return (500, {"Content-Type": "text/plain"}, f"Internal error: {e}") @@ -1043,10 +949,9 @@ class HTTPServer: return (200, {"Content-Type": "application/json"}, json.dumps({"ok": True, "user_id": user_id})) except Exception as e: self._conn.rollback() - msg = str(e) - if "UNIQUE" in msg or "idx_users_lower" in msg: + if _db.is_integrity_error(e): return (409, {"Content-Type": "application/json"}, json.dumps({"error": "user already exists"})) - return (400, {"Content-Type": "application/json"}, json.dumps({"error": msg})) + return (400, {"Content-Type": "application/json"}, json.dumps({"error": str(e)})) # POST /users//password — set or reset password (admin only) if method == "POST" and path.startswith("/users/") and path.endswith("/password"): @@ -1136,6 +1041,8 @@ class HTTPServer: return (200, {"Content-Type": "application/json"}, json.dumps({"ok": True, "project_id": pid})) except Exception as e: self._conn.rollback() + if _db.is_integrity_error(e): + return (409, {"Content-Type": "application/json"}, json.dumps({"error": "project already exists"})) return (400, {"Content-Type": "application/json"}, json.dumps({"error": str(e)})) # DELETE /projects/ — delete a project (admin only) @@ -1464,11 +1371,9 @@ class CombinedApp: self.mcp_app = make_mcp_server(cfg) self.oauth_mcp_app = make_mcp_server(cfg, oauth_scoped=True) self.readonly_mcp_app = make_mcp_server(cfg, readonly=True) - self.write_mcp_app = make_write_mcp_server(cfg) self._mcp_init_opts = self.mcp_app.create_initialization_options() self._oauth_mcp_init_opts = self.oauth_mcp_app.create_initialization_options() self._readonly_mcp_init_opts = self.readonly_mcp_app.create_initialization_options() - self._write_mcp_init_opts = self.write_mcp_app.create_initialization_options() async def __call__(self, scope, receive, send): if scope["type"] == "http":