181 lines
5.9 KiB
Python
181 lines
5.9 KiB
Python
"""
|
|
Message filtering and transformation for sparse mode.
|
|
|
|
Intercepts and modifies JSON-RPC messages to implement
|
|
sparse mode and virtual tool injection.
|
|
"""
|
|
|
|
import json
|
|
from typing import Dict, Any, Optional, List, Callable
|
|
from .registry import ToolRegistry
|
|
|
|
|
|
class MessageFilter:
|
|
"""Filter and transform JSON-RPC messages for sparse mode."""
|
|
|
|
def __init__(self, registry: ToolRegistry, sparse_mode: bool = True):
|
|
self.registry = registry
|
|
self.sparse_mode = sparse_mode
|
|
self._handled_ids: set = set()
|
|
|
|
def filter_outgoing(self, message: dict) -> Optional[dict]:
|
|
"""
|
|
Filter messages going from client to server.
|
|
|
|
Args:
|
|
message: Outgoing JSON-RPC message
|
|
|
|
Returns:
|
|
Modified message or None to block
|
|
"""
|
|
# For now, pass through all outgoing messages
|
|
return message
|
|
|
|
def filter_incoming(self, message: dict) -> Optional[dict]:
|
|
"""
|
|
Filter messages coming from server to client.
|
|
|
|
Args:
|
|
message: Incoming JSON-RPC message
|
|
|
|
Returns:
|
|
Modified message or None to block
|
|
"""
|
|
# Check if this is a duplicate error for a handled request
|
|
if (message.get("id") in self._handled_ids and
|
|
message.get("error", {}).get("code") == -32603):
|
|
# Block duplicate error
|
|
self._handled_ids.discard(message.get("id"))
|
|
return None
|
|
|
|
# Intercept tools/list responses for sparse mode
|
|
if (self.sparse_mode and
|
|
message.get("id") and
|
|
message.get("result", {}).get("tools")):
|
|
return self._filter_tools_response(message)
|
|
|
|
return message
|
|
|
|
def _filter_tools_response(self, message: dict) -> dict:
|
|
"""Apply sparse mode filtering to tools/list response."""
|
|
tools = message["result"]["tools"]
|
|
|
|
# Update registry with full tool list
|
|
self.registry.update_tools(tools)
|
|
|
|
# Replace with sparse tools
|
|
message = message.copy()
|
|
message["result"] = message["result"].copy()
|
|
message["result"]["tools"] = self.registry.get_sparse_tools()
|
|
|
|
return message
|
|
|
|
def mark_handled(self, request_id: Union[str, int]):
|
|
"""Mark a request ID as handled locally."""
|
|
self._handled_ids.add(request_id)
|
|
|
|
def is_virtual_tool(self, tool_name: str) -> bool:
|
|
"""Check if a tool is virtual (handled locally)."""
|
|
return tool_name in ["mcp_discover", "mcp_call", "onboarding"]
|
|
|
|
|
|
class VirtualToolHandler:
|
|
"""Handles virtual tool calls that don't exist on the MCP server."""
|
|
|
|
def __init__(self, registry: ToolRegistry, server_callback: Callable):
|
|
self.registry = registry
|
|
self.server_callback = server_callback
|
|
|
|
async def handle_tool_call(self, message: dict) -> Optional[dict]:
|
|
"""
|
|
Handle virtual tool calls.
|
|
|
|
Args:
|
|
message: Tool call request
|
|
|
|
Returns:
|
|
Response message or None if not handled
|
|
"""
|
|
if message.get("method") != "tools/call":
|
|
return None
|
|
|
|
tool_name = message.get("params", {}).get("name")
|
|
|
|
if tool_name == "mcp_discover":
|
|
return await self._handle_discover(message)
|
|
elif tool_name == "mcp_call":
|
|
return await self._handle_call(message)
|
|
elif tool_name == "onboarding":
|
|
# Onboarding is handled specially in the proxy
|
|
return None
|
|
|
|
return None
|
|
|
|
async def _handle_discover(self, message: dict) -> dict:
|
|
"""Handle mcp_discover tool call."""
|
|
params = message.get("params", {}).get("arguments", {})
|
|
jsonpath = params.get("jsonpath", "$.tools[*]")
|
|
|
|
try:
|
|
result = self.registry.discover(jsonpath)
|
|
|
|
return {
|
|
"jsonrpc": "2.0",
|
|
"id": message.get("id"),
|
|
"result": {
|
|
"content": [{
|
|
"type": "text",
|
|
"text": json.dumps(result, indent=2) if result else "No matches found"
|
|
}]
|
|
}
|
|
}
|
|
except Exception as e:
|
|
return {
|
|
"jsonrpc": "2.0",
|
|
"id": message.get("id"),
|
|
"error": {
|
|
"code": -32603,
|
|
"message": f"Discovery error: {str(e)}"
|
|
}
|
|
}
|
|
|
|
async def _handle_call(self, message: dict) -> dict:
|
|
"""Handle mcp_call tool - forward transformed request."""
|
|
params = message.get("params", {}).get("arguments", {})
|
|
|
|
# Extract method and params from the tool arguments
|
|
method = params.get("method")
|
|
call_params = params.get("params", {})
|
|
|
|
if not method:
|
|
return {
|
|
"jsonrpc": "2.0",
|
|
"id": message.get("id"),
|
|
"error": {
|
|
"code": -32602,
|
|
"message": "Missing 'method' parameter"
|
|
}
|
|
}
|
|
|
|
# Create the actual JSON-RPC call
|
|
forwarded_request = {
|
|
"jsonrpc": "2.0",
|
|
"id": message.get("id"), # Use same ID for response mapping
|
|
"method": method,
|
|
"params": call_params
|
|
}
|
|
|
|
# Forward to server and get response
|
|
try:
|
|
# The server_callback should handle sending and receiving
|
|
response = await self.server_callback(forwarded_request)
|
|
return response
|
|
except Exception as e:
|
|
return {
|
|
"jsonrpc": "2.0",
|
|
"id": message.get("id"),
|
|
"error": {
|
|
"code": -32603,
|
|
"message": str(e)
|
|
}
|
|
} |