feat: add streamable-http transport support
This commit is contained in:
parent
568fd53ad2
commit
6199f28e9e
|
|
@ -326,7 +326,8 @@ async def handle_mcp_command(args):
|
||||||
browser = MCPBrowser(
|
browser = MCPBrowser(
|
||||||
server_name=args.server,
|
server_name=args.server,
|
||||||
config_path=config_path,
|
config_path=config_path,
|
||||||
enable_builtin_servers=not args.no_builtin
|
enable_builtin_servers=not args.no_builtin,
|
||||||
|
transport_override=getattr(args, "transport_override", None)
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
@ -409,7 +410,8 @@ async def start_daemon_background(args):
|
||||||
browser = MCPBrowser(
|
browser = MCPBrowser(
|
||||||
server_name=args.server,
|
server_name=args.server,
|
||||||
config_path=config_path,
|
config_path=config_path,
|
||||||
enable_builtin_servers=not args.no_builtin
|
enable_builtin_servers=not args.no_builtin,
|
||||||
|
transport_override=getattr(args, "transport_override", None)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Run daemon
|
# Run daemon
|
||||||
|
|
@ -687,6 +689,28 @@ Environment:
|
||||||
parser.add_argument("--version", "-v", action="version",
|
parser.add_argument("--version", "-v", action="version",
|
||||||
version=f"%(prog)s {__version__}",
|
version=f"%(prog)s {__version__}",
|
||||||
help="Show program version and exit")
|
help="Show program version and exit")
|
||||||
|
parser.add_argument("--transport", choices=["stdio", "streamable-http"],
|
||||||
|
help="Override transport type for the target server")
|
||||||
|
parser.add_argument("--transport-url",
|
||||||
|
help="Endpoint URL for streamable-http transport override")
|
||||||
|
parser.add_argument("--transport-header", action="append", metavar="KEY=VALUE",
|
||||||
|
help="Additional HTTP headers for streamable-http transport (repeatable)")
|
||||||
|
parser.add_argument("--transport-timeout", type=float,
|
||||||
|
help="Override HTTP request timeout in seconds")
|
||||||
|
parser.add_argument("--transport-sse-timeout", type=float,
|
||||||
|
help="Override SSE read timeout in seconds")
|
||||||
|
parser.add_argument("--oauth-token-url",
|
||||||
|
help="OAuth token endpoint for client credentials flow")
|
||||||
|
parser.add_argument("--oauth-client-id", help="OAuth client identifier")
|
||||||
|
parser.add_argument("--oauth-client-secret", help="OAuth client secret")
|
||||||
|
parser.add_argument("--oauth-scope",
|
||||||
|
help="OAuth scopes (space-separated)")
|
||||||
|
parser.add_argument("--oauth-audience",
|
||||||
|
help="OAuth audience/resource parameter")
|
||||||
|
parser.add_argument("--oauth-extra-param", action="append", metavar="KEY=VALUE",
|
||||||
|
help="Additional form parameters for OAuth token requests (repeatable)")
|
||||||
|
parser.add_argument("--oauth-token",
|
||||||
|
help="Static OAuth bearer token (applies Authorization header)")
|
||||||
|
|
||||||
# MCP method commands
|
# MCP method commands
|
||||||
subparsers = parser.add_subparsers(dest="command", help="MCP methods")
|
subparsers = parser.add_subparsers(dest="command", help="MCP methods")
|
||||||
|
|
@ -725,6 +749,70 @@ Environment:
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
def parse_key_value_pairs(pairs):
|
||||||
|
result = {}
|
||||||
|
if not pairs:
|
||||||
|
return result
|
||||||
|
for item in pairs:
|
||||||
|
if "=" not in item:
|
||||||
|
print(f"Invalid KEY=VALUE format: {item}", file=sys.stderr)
|
||||||
|
sys.exit(2)
|
||||||
|
key, value = item.split("=", 1)
|
||||||
|
key = key.strip()
|
||||||
|
value = value.strip()
|
||||||
|
if not key:
|
||||||
|
print(f"Header key is empty for entry: {item}", file=sys.stderr)
|
||||||
|
sys.exit(2)
|
||||||
|
result[key] = value
|
||||||
|
return result
|
||||||
|
|
||||||
|
transport_headers = parse_key_value_pairs(args.transport_header)
|
||||||
|
oauth_extra = parse_key_value_pairs(args.oauth_extra_param)
|
||||||
|
|
||||||
|
if args.oauth_token:
|
||||||
|
transport_headers.setdefault("Authorization", f"Bearer {args.oauth_token}")
|
||||||
|
|
||||||
|
oauth_override = {}
|
||||||
|
if args.oauth_token_url:
|
||||||
|
oauth_override["token_url"] = args.oauth_token_url
|
||||||
|
if args.oauth_client_id:
|
||||||
|
oauth_override["client_id"] = args.oauth_client_id
|
||||||
|
if args.oauth_client_secret is not None:
|
||||||
|
oauth_override["client_secret"] = args.oauth_client_secret
|
||||||
|
if args.oauth_scope is not None:
|
||||||
|
oauth_override["scope"] = args.oauth_scope
|
||||||
|
if args.oauth_audience is not None:
|
||||||
|
oauth_override["audience"] = args.oauth_audience
|
||||||
|
if oauth_extra:
|
||||||
|
oauth_override["extra_params"] = oauth_extra
|
||||||
|
if args.oauth_token:
|
||||||
|
oauth_override["access_token"] = args.oauth_token
|
||||||
|
|
||||||
|
transport_override = None
|
||||||
|
if any([
|
||||||
|
args.transport,
|
||||||
|
args.transport_url,
|
||||||
|
transport_headers,
|
||||||
|
args.transport_timeout is not None,
|
||||||
|
args.transport_sse_timeout is not None,
|
||||||
|
oauth_override,
|
||||||
|
]):
|
||||||
|
transport_override = {}
|
||||||
|
if args.transport:
|
||||||
|
transport_override["type"] = args.transport
|
||||||
|
if args.transport_url:
|
||||||
|
transport_override["url"] = args.transport_url
|
||||||
|
if transport_headers:
|
||||||
|
transport_override["headers"] = transport_headers
|
||||||
|
if args.transport_timeout is not None:
|
||||||
|
transport_override["timeout"] = args.transport_timeout
|
||||||
|
if args.transport_sse_timeout is not None:
|
||||||
|
transport_override["sse_timeout"] = args.transport_sse_timeout
|
||||||
|
if oauth_override:
|
||||||
|
transport_override["oauth"] = oauth_override
|
||||||
|
|
||||||
|
setattr(args, "transport_override", transport_override)
|
||||||
|
|
||||||
# Handle special commands first
|
# Handle special commands first
|
||||||
if args.list_servers:
|
if args.list_servers:
|
||||||
show_available_servers(args.config)
|
show_available_servers(args.config)
|
||||||
|
|
@ -765,7 +853,8 @@ Environment:
|
||||||
browser = MCPBrowser(
|
browser = MCPBrowser(
|
||||||
server_name=args.server,
|
server_name=args.server,
|
||||||
config_path=config_path,
|
config_path=config_path,
|
||||||
enable_builtin_servers=not args.no_builtin
|
enable_builtin_servers=not args.no_builtin,
|
||||||
|
transport_override=transport_override
|
||||||
)
|
)
|
||||||
|
|
||||||
# Handle test mode
|
# Handle test mode
|
||||||
|
|
|
||||||
|
|
@ -14,15 +14,39 @@ from pathlib import Path
|
||||||
from .default_configs import ConfigManager
|
from .default_configs import ConfigManager
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class OAuthClientCredentialsConfig:
|
||||||
|
"""Configuration for OAuth client credentials flow."""
|
||||||
|
token_url: Optional[str] = None
|
||||||
|
client_id: Optional[str] = None
|
||||||
|
client_secret: Optional[str] = None
|
||||||
|
scope: Optional[str] = None
|
||||||
|
audience: Optional[str] = None
|
||||||
|
extra_params: Dict[str, str] = field(default_factory=dict)
|
||||||
|
access_token: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TransportConfig:
|
||||||
|
"""Transport configuration for connecting to an MCP server."""
|
||||||
|
type: str = "stdio"
|
||||||
|
url: Optional[str] = None
|
||||||
|
headers: Dict[str, str] = field(default_factory=dict)
|
||||||
|
timeout: float = 30.0
|
||||||
|
sse_timeout: float = 300.0
|
||||||
|
oauth: Optional[OAuthClientCredentialsConfig] = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MCPServerConfig:
|
class MCPServerConfig:
|
||||||
"""Configuration for a single MCP server."""
|
"""Configuration for a single MCP server."""
|
||||||
command: List[str]
|
command: Optional[List[str]] = None
|
||||||
args: List[str] = field(default_factory=list)
|
args: List[str] = field(default_factory=list)
|
||||||
env: Dict[str, str] = field(default_factory=dict)
|
env: Dict[str, str] = field(default_factory=dict)
|
||||||
name: Optional[str] = None
|
name: Optional[str] = None
|
||||||
description: Optional[str] = None
|
description: Optional[str] = None
|
||||||
enabled: bool = True
|
enabled: bool = True
|
||||||
|
transport: TransportConfig = field(default_factory=TransportConfig)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
@ -84,13 +108,35 @@ class ConfigLoader:
|
||||||
# Convert to dataclass instances
|
# Convert to dataclass instances
|
||||||
servers = {}
|
servers = {}
|
||||||
for name, server_config in config_data.get("servers", {}).items():
|
for name, server_config in config_data.get("servers", {}).items():
|
||||||
|
transport_data = server_config.get("transport", {}) or {}
|
||||||
|
oauth_data = transport_data.get("oauth") or {}
|
||||||
|
oauth_config = None
|
||||||
|
if oauth_data:
|
||||||
|
oauth_config = OAuthClientCredentialsConfig(
|
||||||
|
token_url=oauth_data.get("token_url") or oauth_data.get("tokenUrl"),
|
||||||
|
client_id=oauth_data.get("client_id") or oauth_data.get("clientId"),
|
||||||
|
client_secret=oauth_data.get("client_secret") or oauth_data.get("clientSecret"),
|
||||||
|
scope=oauth_data.get("scope"),
|
||||||
|
audience=oauth_data.get("audience") or oauth_data.get("resource"),
|
||||||
|
extra_params=oauth_data.get("extra_params") or oauth_data.get("extraParams", {}),
|
||||||
|
access_token=oauth_data.get("access_token") or oauth_data.get("accessToken"),
|
||||||
|
)
|
||||||
|
transport_config = TransportConfig(
|
||||||
|
type=transport_data.get("type", "stdio"),
|
||||||
|
url=transport_data.get("url"),
|
||||||
|
headers=transport_data.get("headers", {}),
|
||||||
|
timeout=transport_data.get("timeout", 30.0),
|
||||||
|
sse_timeout=transport_data.get("sse_timeout") or transport_data.get("sseTimeout", 300.0),
|
||||||
|
oauth=oauth_config
|
||||||
|
)
|
||||||
servers[name] = MCPServerConfig(
|
servers[name] = MCPServerConfig(
|
||||||
command=server_config["command"],
|
command=server_config.get("command"),
|
||||||
args=server_config.get("args", []),
|
args=server_config.get("args", []),
|
||||||
env=server_config.get("env", {}),
|
env=server_config.get("env", {}),
|
||||||
name=server_config.get("name", name),
|
name=server_config.get("name", name),
|
||||||
description=server_config.get("description"),
|
description=server_config.get("description"),
|
||||||
enabled=server_config.get("enabled", True)
|
enabled=server_config.get("enabled", True),
|
||||||
|
transport=transport_config
|
||||||
)
|
)
|
||||||
|
|
||||||
self._config = MCPBrowserConfig(
|
self._config = MCPBrowserConfig(
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,12 @@ import asyncio
|
||||||
from typing import Dict, Any, Optional, Union
|
from typing import Dict, Any, Optional, Union
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from .config import ConfigLoader, MCPBrowserConfig
|
from .config import (
|
||||||
|
ConfigLoader,
|
||||||
|
MCPBrowserConfig,
|
||||||
|
MCPServerConfig,
|
||||||
|
OAuthClientCredentialsConfig,
|
||||||
|
)
|
||||||
from .server import MCPServer
|
from .server import MCPServer
|
||||||
from .multi_server import MultiServerManager
|
from .multi_server import MultiServerManager
|
||||||
from .registry import ToolRegistry
|
from .registry import ToolRegistry
|
||||||
|
|
@ -29,7 +34,7 @@ class MCPBrowser:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config_path: Optional[Path] = None, server_name: Optional[str] = None,
|
def __init__(self, config_path: Optional[Path] = None, server_name: Optional[str] = None,
|
||||||
enable_builtin_servers: bool = True):
|
enable_builtin_servers: bool = True, transport_override: Optional[Dict[str, Any]] = None):
|
||||||
"""
|
"""
|
||||||
Initialize MCP Browser.
|
Initialize MCP Browser.
|
||||||
|
|
||||||
|
|
@ -47,6 +52,7 @@ class MCPBrowser:
|
||||||
self.virtual_handler: Optional[VirtualToolHandler] = None
|
self.virtual_handler: Optional[VirtualToolHandler] = None
|
||||||
self._server_name = server_name
|
self._server_name = server_name
|
||||||
self._enable_builtin_servers = enable_builtin_servers
|
self._enable_builtin_servers = enable_builtin_servers
|
||||||
|
self._transport_override = transport_override
|
||||||
self._initialized = False
|
self._initialized = False
|
||||||
self._response_buffer: Dict[Union[str, int], asyncio.Future] = {}
|
self._response_buffer: Dict[Union[str, int], asyncio.Future] = {}
|
||||||
self._next_id = 1
|
self._next_id = 1
|
||||||
|
|
@ -55,6 +61,42 @@ class MCPBrowser:
|
||||||
self._server_configs = {}
|
self._server_configs = {}
|
||||||
self._config_mtime = None
|
self._config_mtime = None
|
||||||
|
|
||||||
|
def _apply_transport_override(self, server_config: MCPServerConfig):
|
||||||
|
"""Apply transport overrides from CLI arguments."""
|
||||||
|
override = self._transport_override or {}
|
||||||
|
transport = server_config.transport
|
||||||
|
|
||||||
|
transport.type = override.get("type", transport.type)
|
||||||
|
|
||||||
|
if override.get("url"):
|
||||||
|
transport.url = override["url"]
|
||||||
|
if override.get("headers"):
|
||||||
|
transport.headers.update(override["headers"])
|
||||||
|
if "timeout" in override and override["timeout"] is not None:
|
||||||
|
transport.timeout = override["timeout"]
|
||||||
|
if "sse_timeout" in override and override["sse_timeout"] is not None:
|
||||||
|
transport.sse_timeout = override["sse_timeout"]
|
||||||
|
|
||||||
|
oauth_override = override.get("oauth")
|
||||||
|
if oauth_override:
|
||||||
|
if not transport.oauth:
|
||||||
|
transport.oauth = OAuthClientCredentialsConfig()
|
||||||
|
oauth = transport.oauth
|
||||||
|
if oauth_override.get("token_url"):
|
||||||
|
oauth.token_url = oauth_override["token_url"]
|
||||||
|
if oauth_override.get("client_id"):
|
||||||
|
oauth.client_id = oauth_override["client_id"]
|
||||||
|
if "client_secret" in oauth_override:
|
||||||
|
oauth.client_secret = oauth_override["client_secret"]
|
||||||
|
if "scope" in oauth_override:
|
||||||
|
oauth.scope = oauth_override["scope"]
|
||||||
|
if "audience" in oauth_override:
|
||||||
|
oauth.audience = oauth_override["audience"]
|
||||||
|
if "extra_params" in oauth_override and oauth_override["extra_params"] is not None:
|
||||||
|
oauth.extra_params.update(oauth_override["extra_params"])
|
||||||
|
if "access_token" in oauth_override:
|
||||||
|
oauth.access_token = oauth_override["access_token"]
|
||||||
|
|
||||||
async def __aenter__(self):
|
async def __aenter__(self):
|
||||||
"""Async context manager entry."""
|
"""Async context manager entry."""
|
||||||
await self.initialize()
|
await self.initialize()
|
||||||
|
|
@ -78,6 +120,8 @@ class MCPBrowser:
|
||||||
raise ValueError(f"Server '{server_name}' not found in configuration")
|
raise ValueError(f"Server '{server_name}' not found in configuration")
|
||||||
|
|
||||||
server_config = self.config.servers[server_name]
|
server_config = self.config.servers[server_name]
|
||||||
|
if self._transport_override:
|
||||||
|
self._apply_transport_override(server_config)
|
||||||
|
|
||||||
# Create multi-server manager if using built-in servers
|
# Create multi-server manager if using built-in servers
|
||||||
if self._enable_builtin_servers:
|
if self._enable_builtin_servers:
|
||||||
|
|
|
||||||
|
|
@ -9,12 +9,11 @@ import os
|
||||||
import json
|
import json
|
||||||
import asyncio
|
import asyncio
|
||||||
import subprocess
|
import subprocess
|
||||||
from typing import Optional, Dict, Any, Callable, List
|
from typing import Optional, Dict, Any, Callable, List, Union
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from .buffer import JsonRpcBuffer
|
from .buffer import JsonRpcBuffer
|
||||||
from .config import MCPServerConfig
|
from .config import MCPServerConfig
|
||||||
from .logging_config import get_logger, TRACE
|
from .logging_config import get_logger, TRACE
|
||||||
|
from .streamable_http import StreamableHTTPClient, StreamableHTTPError
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -32,9 +31,30 @@ class MCPServer:
|
||||||
self._pending_requests: Dict[Union[str, int], asyncio.Future] = {}
|
self._pending_requests: Dict[Union[str, int], asyncio.Future] = {}
|
||||||
self._last_error_time: Optional[float] = None
|
self._last_error_time: Optional[float] = None
|
||||||
self._offline_since: Optional[float] = None
|
self._offline_since: Optional[float] = None
|
||||||
|
self._http_client: Optional[StreamableHTTPClient] = None
|
||||||
|
|
||||||
async def start(self):
|
async def start(self):
|
||||||
"""Start the MCP server process."""
|
"""Start the MCP server process."""
|
||||||
|
if self.config.transport.type == "streamable-http":
|
||||||
|
if self._http_client:
|
||||||
|
return
|
||||||
|
if not self.config.transport.url:
|
||||||
|
raise ValueError("Streamable HTTP transport requires 'url' option")
|
||||||
|
|
||||||
|
self.logger.info(f"Connecting to streamable HTTP endpoint: {self.config.transport.url}")
|
||||||
|
self._http_client = StreamableHTTPClient(
|
||||||
|
self.config.transport.url,
|
||||||
|
headers=self.config.transport.headers,
|
||||||
|
timeout=self.config.transport.timeout,
|
||||||
|
sse_timeout=self.config.transport.sse_timeout,
|
||||||
|
oauth_config=self.config.transport.oauth,
|
||||||
|
logger=self.logger,
|
||||||
|
)
|
||||||
|
await self._http_client.start()
|
||||||
|
self._running = True
|
||||||
|
self._offline_since = None
|
||||||
|
return
|
||||||
|
|
||||||
if self.process:
|
if self.process:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
@ -46,6 +66,9 @@ class MCPServer:
|
||||||
self.logger.warning(f"Server has been offline for {offline_duration:.0f}s, skipping start")
|
self.logger.warning(f"Server has been offline for {offline_duration:.0f}s, skipping start")
|
||||||
raise RuntimeError(f"Server marked as offline since {offline_duration:.0f}s ago")
|
raise RuntimeError(f"Server marked as offline since {offline_duration:.0f}s ago")
|
||||||
|
|
||||||
|
if not self.config.command:
|
||||||
|
raise ValueError("Server command not configured for stdio transport")
|
||||||
|
|
||||||
# Prepare environment
|
# Prepare environment
|
||||||
env = os.environ.copy()
|
env = os.environ.copy()
|
||||||
env.update({
|
env.update({
|
||||||
|
|
@ -85,6 +108,13 @@ class MCPServer:
|
||||||
|
|
||||||
async def stop(self):
|
async def stop(self):
|
||||||
"""Stop the MCP server process."""
|
"""Stop the MCP server process."""
|
||||||
|
if self.config.transport.type == "streamable-http":
|
||||||
|
self._running = False
|
||||||
|
if self._http_client:
|
||||||
|
await self._http_client.stop()
|
||||||
|
self._http_client = None
|
||||||
|
return
|
||||||
|
|
||||||
self._running = False
|
self._running = False
|
||||||
|
|
||||||
if self.process:
|
if self.process:
|
||||||
|
|
@ -116,6 +146,9 @@ class MCPServer:
|
||||||
Returns:
|
Returns:
|
||||||
Response result or raises exception on error
|
Response result or raises exception on error
|
||||||
"""
|
"""
|
||||||
|
if self.config.transport.type == "streamable-http":
|
||||||
|
return await self._send_request_http(method, params or {})
|
||||||
|
|
||||||
if not self.process:
|
if not self.process:
|
||||||
raise RuntimeError("MCP server not started")
|
raise RuntimeError("MCP server not started")
|
||||||
|
|
||||||
|
|
@ -152,6 +185,41 @@ class MCPServer:
|
||||||
self._mark_offline()
|
self._mark_offline()
|
||||||
raise TimeoutError(f"No response for request {request_id}")
|
raise TimeoutError(f"No response for request {request_id}")
|
||||||
|
|
||||||
|
async def _send_request_http(self, method: str, params: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""Send request via streamable HTTP transport."""
|
||||||
|
if not self._http_client:
|
||||||
|
raise RuntimeError("Streamable HTTP client not started")
|
||||||
|
|
||||||
|
request_id = self._next_id
|
||||||
|
self._next_id += 1
|
||||||
|
|
||||||
|
request = {
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"id": request_id,
|
||||||
|
"method": method,
|
||||||
|
"params": params
|
||||||
|
}
|
||||||
|
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
future: asyncio.Future = loop.create_future()
|
||||||
|
self._pending_requests[request_id] = future
|
||||||
|
|
||||||
|
timeout = self.config.transport.timeout or 30.0
|
||||||
|
|
||||||
|
try:
|
||||||
|
await self._http_client.send(request, self._handle_message)
|
||||||
|
result = await asyncio.wait_for(future, timeout=timeout)
|
||||||
|
return result
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
raise TimeoutError(f"No response for request {request_id}")
|
||||||
|
except StreamableHTTPError:
|
||||||
|
raise
|
||||||
|
except Exception as exc:
|
||||||
|
self.logger.error(f"Streamable HTTP request failed: {exc}")
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
self._pending_requests.pop(request_id, None)
|
||||||
|
|
||||||
def send_raw(self, message: str):
|
def send_raw(self, message: str):
|
||||||
"""Send raw message to MCP server (for pass-through)."""
|
"""Send raw message to MCP server (for pass-through)."""
|
||||||
if not self.process:
|
if not self.process:
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,289 @@
|
||||||
|
"""
|
||||||
|
Minimal Streamable HTTP transport support for MCP Browser.
|
||||||
|
|
||||||
|
Provides a lightweight client that can communicate with MCP servers using
|
||||||
|
the streamable-http transport defined by the MCP specification.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
from typing import Any, Awaitable, Callable, Dict, Optional
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from .config import OAuthClientCredentialsConfig
|
||||||
|
from .logging_config import get_logger
|
||||||
|
|
||||||
|
MCP_SESSION_HEADER = "mcp-session-id"
|
||||||
|
MCP_PROTOCOL_VERSION_HEADER = "mcp-protocol-version"
|
||||||
|
LAST_EVENT_ID_HEADER = "last-event-id"
|
||||||
|
|
||||||
|
|
||||||
|
class StreamableHTTPError(Exception):
|
||||||
|
"""Base error for streamable HTTP transport failures."""
|
||||||
|
|
||||||
|
|
||||||
|
class StreamableHTTPClient:
|
||||||
|
"""Minimal client for the MCP streamable-http transport."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
url: str,
|
||||||
|
*,
|
||||||
|
headers: Optional[Dict[str, str]] = None,
|
||||||
|
timeout: float = 30.0,
|
||||||
|
sse_timeout: float = 300.0,
|
||||||
|
oauth_config: Optional[OAuthClientCredentialsConfig] = None,
|
||||||
|
logger=None,
|
||||||
|
) -> None:
|
||||||
|
self.url = url
|
||||||
|
self.headers = headers.copy() if headers else {}
|
||||||
|
self.timeout = timeout
|
||||||
|
self.sse_timeout = sse_timeout
|
||||||
|
self.oauth_config = oauth_config
|
||||||
|
self.logger = logger or get_logger(__name__)
|
||||||
|
|
||||||
|
self._client: Optional[httpx.AsyncClient] = None
|
||||||
|
self._session_id: Optional[str] = None
|
||||||
|
self._protocol_version: Optional[str] = None
|
||||||
|
self._last_event_id: Optional[str] = None
|
||||||
|
|
||||||
|
self._token: Optional[str] = oauth_config.access_token if oauth_config else None
|
||||||
|
self._token_expires_at: Optional[float] = None
|
||||||
|
self._token_lock = asyncio.Lock()
|
||||||
|
|
||||||
|
async def start(self) -> None:
|
||||||
|
"""Initialise underlying HTTP client."""
|
||||||
|
if self._client is not None:
|
||||||
|
return
|
||||||
|
|
||||||
|
self._client = httpx.AsyncClient(timeout=self.timeout)
|
||||||
|
if self.oauth_config and self.oauth_config.token_url:
|
||||||
|
await self._ensure_token()
|
||||||
|
|
||||||
|
async def stop(self) -> None:
|
||||||
|
"""Shutdown HTTP client."""
|
||||||
|
if self._client:
|
||||||
|
await self._client.aclose()
|
||||||
|
self._client = None
|
||||||
|
|
||||||
|
async def send(
|
||||||
|
self,
|
||||||
|
request: Dict[str, Any],
|
||||||
|
message_callback: Callable[[Dict[str, Any]], Awaitable[None]],
|
||||||
|
) -> None:
|
||||||
|
"""Send a JSON-RPC request and stream responses to callback."""
|
||||||
|
if not self._client:
|
||||||
|
raise StreamableHTTPError("Streamable HTTP client not started")
|
||||||
|
|
||||||
|
await self._ensure_token()
|
||||||
|
|
||||||
|
headers = self._build_headers()
|
||||||
|
try:
|
||||||
|
response = await self._client.post(self.url, json=request, headers=headers)
|
||||||
|
except httpx.HTTPError as exc:
|
||||||
|
raise StreamableHTTPError(f"Streamable HTTP request failed: {exc}") from exc
|
||||||
|
|
||||||
|
self._store_session_headers(response)
|
||||||
|
|
||||||
|
try:
|
||||||
|
response.raise_for_status()
|
||||||
|
except httpx.HTTPStatusError as exc:
|
||||||
|
body = await self._safe_read_text(response)
|
||||||
|
message = body or exc.response.text
|
||||||
|
raise StreamableHTTPError(f"HTTP {exc.response.status_code}: {message}") from exc
|
||||||
|
|
||||||
|
content_type = response.headers.get("content-type", "")
|
||||||
|
|
||||||
|
if "text/event-stream" in content_type:
|
||||||
|
await self._consume_sse(response, message_callback)
|
||||||
|
else:
|
||||||
|
await self._consume_json(response, message_callback)
|
||||||
|
|
||||||
|
def _build_headers(self) -> Dict[str, str]:
|
||||||
|
headers = {
|
||||||
|
"accept": "application/json, text/event-stream",
|
||||||
|
"content-type": "application/json",
|
||||||
|
**self.headers,
|
||||||
|
}
|
||||||
|
|
||||||
|
if self._session_id and MCP_SESSION_HEADER not in headers:
|
||||||
|
headers[MCP_SESSION_HEADER] = self._session_id
|
||||||
|
if self._protocol_version and MCP_PROTOCOL_VERSION_HEADER not in headers:
|
||||||
|
headers[MCP_PROTOCOL_VERSION_HEADER] = self._protocol_version
|
||||||
|
if self._last_event_id and LAST_EVENT_ID_HEADER not in headers:
|
||||||
|
headers[LAST_EVENT_ID_HEADER] = self._last_event_id
|
||||||
|
if self._token and "authorization" not in {k.lower(): v for k, v in headers.items()}:
|
||||||
|
headers["Authorization"] = f"Bearer {self._token}"
|
||||||
|
|
||||||
|
return headers
|
||||||
|
|
||||||
|
async def _ensure_token(self) -> None:
|
||||||
|
if not self.oauth_config:
|
||||||
|
return
|
||||||
|
if self.oauth_config.access_token and not self.oauth_config.token_url:
|
||||||
|
self._token = self.oauth_config.access_token
|
||||||
|
return
|
||||||
|
if not self.oauth_config.token_url or not self.oauth_config.client_id:
|
||||||
|
return
|
||||||
|
|
||||||
|
async with self._token_lock:
|
||||||
|
if self._token and self._token_expires_at:
|
||||||
|
if time.time() < self._token_expires_at - 30:
|
||||||
|
return
|
||||||
|
await self._refresh_token()
|
||||||
|
|
||||||
|
async def _refresh_token(self) -> None:
|
||||||
|
assert self._client is not None
|
||||||
|
data = {
|
||||||
|
"grant_type": "client_credentials",
|
||||||
|
}
|
||||||
|
if self.oauth_config.scope:
|
||||||
|
data["scope"] = self.oauth_config.scope
|
||||||
|
if self.oauth_config.audience:
|
||||||
|
data["audience"] = self.oauth_config.audience
|
||||||
|
if self.oauth_config.extra_params:
|
||||||
|
data.update(self.oauth_config.extra_params)
|
||||||
|
|
||||||
|
auth = None
|
||||||
|
if self.oauth_config.client_secret:
|
||||||
|
auth = (self.oauth_config.client_id or "", self.oauth_config.client_secret)
|
||||||
|
else:
|
||||||
|
data["client_id"] = self.oauth_config.client_id
|
||||||
|
|
||||||
|
response = await self._client.post(
|
||||||
|
self.oauth_config.token_url, # type: ignore[arg-type]
|
||||||
|
data=data,
|
||||||
|
headers={"content-type": "application/x-www-form-urlencoded"},
|
||||||
|
auth=auth,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
response.raise_for_status()
|
||||||
|
except httpx.HTTPError as exc:
|
||||||
|
body = await self._safe_read_text(response)
|
||||||
|
raise StreamableHTTPError(
|
||||||
|
f"Failed to refresh OAuth token: {body or response.text}"
|
||||||
|
) from exc
|
||||||
|
|
||||||
|
token_data = response.json()
|
||||||
|
access_token = token_data.get("access_token")
|
||||||
|
if not access_token:
|
||||||
|
raise StreamableHTTPError("OAuth token response missing access_token")
|
||||||
|
|
||||||
|
self._token = access_token
|
||||||
|
expires_in = token_data.get("expires_in")
|
||||||
|
if expires_in:
|
||||||
|
try:
|
||||||
|
self._token_expires_at = time.time() + float(expires_in)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
self._token_expires_at = None
|
||||||
|
else:
|
||||||
|
self._token_expires_at = None
|
||||||
|
|
||||||
|
async def _consume_json(
|
||||||
|
self,
|
||||||
|
response: httpx.Response,
|
||||||
|
message_callback: Callable[[Dict[str, Any]], Awaitable[None]],
|
||||||
|
) -> None:
|
||||||
|
text = await response.aread()
|
||||||
|
if not text:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
payload = json.loads(text)
|
||||||
|
except json.JSONDecodeError as exc:
|
||||||
|
raise StreamableHTTPError(f"Invalid JSON response: {exc}") from exc
|
||||||
|
|
||||||
|
messages = payload if isinstance(payload, list) else [payload]
|
||||||
|
for message in messages:
|
||||||
|
if isinstance(message, dict):
|
||||||
|
self._update_protocol_version(message)
|
||||||
|
await message_callback(message)
|
||||||
|
|
||||||
|
async def _consume_sse(
|
||||||
|
self,
|
||||||
|
response: httpx.Response,
|
||||||
|
message_callback: Callable[[Dict[str, Any]], Awaitable[None]],
|
||||||
|
) -> None:
|
||||||
|
event: Dict[str, Any] = {"data": []}
|
||||||
|
async for line in response.aiter_lines():
|
||||||
|
if line == "":
|
||||||
|
await self._flush_event(event, message_callback)
|
||||||
|
event = {"data": []}
|
||||||
|
continue
|
||||||
|
|
||||||
|
if line.startswith(":"):
|
||||||
|
continue
|
||||||
|
|
||||||
|
field, _, raw_value = line.partition(":")
|
||||||
|
value = raw_value.lstrip(" ")
|
||||||
|
if field == "data":
|
||||||
|
event.setdefault("data", []).append(value)
|
||||||
|
elif field == "event":
|
||||||
|
event["event"] = value
|
||||||
|
elif field == "id":
|
||||||
|
event["id"] = value
|
||||||
|
elif field == "retry":
|
||||||
|
continue
|
||||||
|
|
||||||
|
await self._flush_event(event, message_callback)
|
||||||
|
|
||||||
|
async def _flush_event(
|
||||||
|
self,
|
||||||
|
event: Dict[str, Any],
|
||||||
|
message_callback: Callable[[Dict[str, Any]], Awaitable[None]],
|
||||||
|
) -> None:
|
||||||
|
data_lines = event.get("data") or []
|
||||||
|
if not data_lines:
|
||||||
|
return
|
||||||
|
|
||||||
|
data = "\n".join(data_lines)
|
||||||
|
try:
|
||||||
|
message = json.loads(data)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
self.logger.warning("Failed to parse SSE event payload")
|
||||||
|
return
|
||||||
|
|
||||||
|
if not isinstance(message, dict):
|
||||||
|
return
|
||||||
|
|
||||||
|
if event.get("id"):
|
||||||
|
self._last_event_id = event["id"]
|
||||||
|
|
||||||
|
self._update_protocol_version(message)
|
||||||
|
await message_callback(message)
|
||||||
|
|
||||||
|
def _update_protocol_version(self, message: Dict[str, Any]) -> None:
|
||||||
|
result = message.get("result")
|
||||||
|
if isinstance(result, dict) and result.get("protocolVersion"):
|
||||||
|
self._protocol_version = str(result["protocolVersion"])
|
||||||
|
|
||||||
|
def _store_session_headers(self, response: httpx.Response) -> None:
|
||||||
|
session_id = response.headers.get(MCP_SESSION_HEADER)
|
||||||
|
if session_id:
|
||||||
|
self._session_id = session_id
|
||||||
|
|
||||||
|
async def _safe_read_text(self, response: httpx.Response) -> Optional[str]:
|
||||||
|
try:
|
||||||
|
raw = await response.aread()
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if isinstance(raw, str):
|
||||||
|
return raw
|
||||||
|
|
||||||
|
try:
|
||||||
|
charset = response.charset or "utf-8"
|
||||||
|
except Exception:
|
||||||
|
charset = "utf-8"
|
||||||
|
|
||||||
|
try:
|
||||||
|
return raw.decode(charset)
|
||||||
|
except Exception:
|
||||||
|
try:
|
||||||
|
return raw.decode("utf-8", errors="replace")
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
@ -1,2 +1,3 @@
|
||||||
pyyaml>=6.0
|
pyyaml>=6.0
|
||||||
jsonpath-ng>=1.5.3
|
jsonpath-ng>=1.5.3
|
||||||
|
httpx>=0.27
|
||||||
|
|
|
||||||
1
setup.py
1
setup.py
|
|
@ -317,6 +317,7 @@ setup(
|
||||||
"aiofiles>=23.0.0",
|
"aiofiles>=23.0.0",
|
||||||
"jsonpath-ng>=1.6.0",
|
"jsonpath-ng>=1.6.0",
|
||||||
"pyyaml>=6.0",
|
"pyyaml>=6.0",
|
||||||
|
"httpx>=0.27",
|
||||||
"typing-extensions>=4.0.0;python_version<'3.11'",
|
"typing-extensions>=4.0.0;python_version<'3.11'",
|
||||||
],
|
],
|
||||||
extras_require={
|
extras_require={
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue