"""
Message filtering and transformation for sparse mode.

Intercepts and modifies JSON-RPC messages to implement sparse mode and
virtual tool injection.
"""

import json
from typing import Dict, Any, Optional, List, Callable, Union

from .registry import ToolRegistry


# JSON-RPC 2.0 reserved error codes used by this module.
_INVALID_PARAMS = -32602
_INTERNAL_ERROR = -32603


def _as_dict(value: Any) -> dict:
    """Return *value* if it is a dict, otherwise an empty dict.

    JSON-RPC fields such as ``result``, ``error`` and ``params`` may be
    present but ``null`` (or a non-object). Chaining ``.get`` on those would
    raise ``AttributeError``, so they are normalized to ``{}`` first.
    """
    return value if isinstance(value, dict) else {}


def _error_response(request_id: Any, code: int, message: str) -> dict:
    """Build a JSON-RPC 2.0 error response for *request_id*."""
    return {
        "jsonrpc": "2.0",
        "id": request_id,
        "error": {
            "code": code,
            "message": message,
        },
    }


class MessageFilter:
    """Filter and transform JSON-RPC messages for sparse mode."""

    # Tool names the proxy treats as virtual (not forwarded as real server tools).
    VIRTUAL_TOOLS = frozenset({"mcp_discover", "mcp_call", "onboarding"})

    def __init__(self, registry: ToolRegistry, sparse_mode: bool = True):
        """
        Args:
            registry: Registry that receives the full tool list and
                produces the sparse replacement list.
            sparse_mode: When True, tools/list responses are rewritten.
        """
        self.registry = registry
        self.sparse_mode = sparse_mode
        # Request ids answered locally; see filter_incoming for how they are used.
        self._handled_ids: set = set()

    def filter_outgoing(self, message: dict) -> Optional[dict]:
        """
        Filter messages going from client to server.

        Args:
            message: Outgoing JSON-RPC message

        Returns:
            Modified message or None to block (currently always the
            message unchanged).
        """
        # For now, pass through all outgoing messages
        return message

    def filter_incoming(self, message: dict) -> Optional[dict]:
        """
        Filter messages coming from server to client.

        Args:
            message: Incoming JSON-RPC message

        Returns:
            Modified message, or None to block it.
        """
        msg_id = message.get("id")
        error = _as_dict(message.get("error"))

        # The request was already answered locally, so the server's
        # internal-error reply for the same id is a duplicate: drop it.
        # NOTE(review): an id is only discarded here; if the server never
        # sends a -32603 error for it, the id stays in the set. Confirm
        # the server always replies this way for locally handled calls.
        if msg_id in self._handled_ids and error.get("code") == _INTERNAL_ERROR:
            self._handled_ids.discard(msg_id)
            return None

        # Intercept tools/list responses for sparse mode. JSON-RPC ids may
        # legitimately be 0, so test against None rather than truthiness.
        tools = _as_dict(message.get("result")).get("tools")
        if (self.sparse_mode
                and msg_id is not None
                and isinstance(tools, list)
                and tools):
            return self._filter_tools_response(message)

        return message

    def _filter_tools_response(self, message: dict) -> dict:
        """Apply sparse mode filtering to a tools/list response.

        The full tool list is handed to the registry first; the returned
        message is a shallow copy whose ``result.tools`` is replaced by the
        registry's sparse list. The input message is not mutated.
        """
        # Update registry with full tool list
        self.registry.update_tools(message["result"]["tools"])

        # Replace with sparse tools (copy message and result, not mutate)
        return {
            **message,
            "result": {
                **message["result"],
                "tools": self.registry.get_sparse_tools(),
            },
        }

    def mark_handled(self, request_id: Union[str, int]):
        """Mark a request ID as handled locally."""
        self._handled_ids.add(request_id)

    def is_virtual_tool(self, tool_name: str) -> bool:
        """Check if a tool is virtual (handled locally)."""
        return tool_name in self.VIRTUAL_TOOLS


class VirtualToolHandler:
    """Handles virtual tool calls that don't exist on the MCP server."""

    def __init__(self, registry: ToolRegistry, server_callback: Callable):
        """
        Args:
            registry: Registry used to answer mcp_discover queries.
            server_callback: Async callable awaited with a JSON-RPC request
                dict; its return value is passed back as the response to
                an mcp_call.
        """
        self.registry = registry
        self.server_callback = server_callback

    @staticmethod
    def _arguments(message: dict) -> dict:
        """Return ``params.arguments`` of a tools/call request, or ``{}``."""
        return _as_dict(_as_dict(message.get("params")).get("arguments"))

    async def handle_tool_call(self, message: dict) -> Optional[dict]:
        """
        Handle virtual tool calls.

        Args:
            message: Tool call request

        Returns:
            Response message, or None if the message is not a virtual tool
            call handled here.
        """
        if message.get("method") != "tools/call":
            return None

        tool_name = _as_dict(message.get("params")).get("name")
        if tool_name == "mcp_discover":
            return await self._handle_discover(message)
        if tool_name == "mcp_call":
            return await self._handle_call(message)
        # "onboarding" is handled specially in the proxy; any other name is
        # not a virtual tool. Either way it is not handled here.
        return None

    async def _handle_discover(self, message: dict) -> dict:
        """Handle an mcp_discover tool call.

        Runs ``registry.discover`` with the ``jsonpath`` argument (default
        ``$.tools[*]``) and returns the matches as pretty-printed JSON text.
        Any exception, including a non-JSON-serializable result, becomes a
        -32603 error response.
        """
        request_id = message.get("id")
        jsonpath = self._arguments(message).get("jsonpath", "$.tools[*]")

        try:
            result = self.registry.discover(jsonpath)
            text = json.dumps(result, indent=2) if result else "No matches found"
        except Exception as e:
            return _error_response(
                request_id, _INTERNAL_ERROR, f"Discovery error: {str(e)}"
            )

        return {
            "jsonrpc": "2.0",
            "id": request_id,
            "result": {
                "content": [{
                    "type": "text",
                    "text": text,
                }]
            },
        }

    async def _handle_call(self, message: dict) -> dict:
        """Handle an mcp_call tool call by forwarding a transformed request.

        The ``method`` and ``params`` arguments become a new JSON-RPC request
        that is awaited through ``server_callback``; its return value is
        returned as-is. A missing ``method`` yields a -32602 error; an
        exception from the callback yields a -32603 error.
        """
        request_id = message.get("id")
        args = self._arguments(message)

        # Extract method and params from the tool arguments
        method = args.get("method")
        if not method:
            return _error_response(
                request_id, _INVALID_PARAMS, "Missing 'method' parameter"
            )

        # JSON-RPC params must be structured if present; treat an explicit
        # null the same as a missing value.
        call_params = args.get("params")
        if call_params is None:
            call_params = {}

        forwarded_request = {
            "jsonrpc": "2.0",
            "id": request_id,  # Use same ID for response mapping
            "method": method,
            "params": call_params,
        }

        # The server_callback is responsible for sending and receiving.
        try:
            return await self.server_callback(forwarded_request)
        except Exception as e:
            return _error_response(request_id, _INTERNAL_ERROR, str(e))