feat: add streamable-http transport support

This commit is contained in:
gpt-5-codex 2025-10-10 22:56:32 +02:00 committed by Andre Heinecke
parent 568fd53ad2
commit 6199f28e9e
7 changed files with 555 additions and 17 deletions

View File

@ -326,7 +326,8 @@ async def handle_mcp_command(args):
browser = MCPBrowser(
server_name=args.server,
config_path=config_path,
enable_builtin_servers=not args.no_builtin
enable_builtin_servers=not args.no_builtin,
transport_override=getattr(args, "transport_override", None)
)
try:
@ -409,7 +410,8 @@ async def start_daemon_background(args):
browser = MCPBrowser(
server_name=args.server,
config_path=config_path,
enable_builtin_servers=not args.no_builtin
enable_builtin_servers=not args.no_builtin,
transport_override=getattr(args, "transport_override", None)
)
# Run daemon
@ -687,6 +689,28 @@ Environment:
parser.add_argument("--version", "-v", action="version",
version=f"%(prog)s {__version__}",
help="Show program version and exit")
parser.add_argument("--transport", choices=["stdio", "streamable-http"],
help="Override transport type for the target server")
parser.add_argument("--transport-url",
help="Endpoint URL for streamable-http transport override")
parser.add_argument("--transport-header", action="append", metavar="KEY=VALUE",
help="Additional HTTP headers for streamable-http transport (repeatable)")
parser.add_argument("--transport-timeout", type=float,
help="Override HTTP request timeout in seconds")
parser.add_argument("--transport-sse-timeout", type=float,
help="Override SSE read timeout in seconds")
parser.add_argument("--oauth-token-url",
help="OAuth token endpoint for client credentials flow")
parser.add_argument("--oauth-client-id", help="OAuth client identifier")
parser.add_argument("--oauth-client-secret", help="OAuth client secret")
parser.add_argument("--oauth-scope",
help="OAuth scopes (space-separated)")
parser.add_argument("--oauth-audience",
help="OAuth audience/resource parameter")
parser.add_argument("--oauth-extra-param", action="append", metavar="KEY=VALUE",
help="Additional form parameters for OAuth token requests (repeatable)")
parser.add_argument("--oauth-token",
help="Static OAuth bearer token (applies Authorization header)")
# MCP method commands
subparsers = parser.add_subparsers(dest="command", help="MCP methods")
@ -725,6 +749,70 @@ Environment:
args = parser.parse_args()
def parse_key_value_pairs(pairs):
result = {}
if not pairs:
return result
for item in pairs:
if "=" not in item:
print(f"Invalid KEY=VALUE format: {item}", file=sys.stderr)
sys.exit(2)
key, value = item.split("=", 1)
key = key.strip()
value = value.strip()
if not key:
print(f"Header key is empty for entry: {item}", file=sys.stderr)
sys.exit(2)
result[key] = value
return result
transport_headers = parse_key_value_pairs(args.transport_header)
oauth_extra = parse_key_value_pairs(args.oauth_extra_param)
if args.oauth_token:
transport_headers.setdefault("Authorization", f"Bearer {args.oauth_token}")
oauth_override = {}
if args.oauth_token_url:
oauth_override["token_url"] = args.oauth_token_url
if args.oauth_client_id:
oauth_override["client_id"] = args.oauth_client_id
if args.oauth_client_secret is not None:
oauth_override["client_secret"] = args.oauth_client_secret
if args.oauth_scope is not None:
oauth_override["scope"] = args.oauth_scope
if args.oauth_audience is not None:
oauth_override["audience"] = args.oauth_audience
if oauth_extra:
oauth_override["extra_params"] = oauth_extra
if args.oauth_token:
oauth_override["access_token"] = args.oauth_token
transport_override = None
if any([
args.transport,
args.transport_url,
transport_headers,
args.transport_timeout is not None,
args.transport_sse_timeout is not None,
oauth_override,
]):
transport_override = {}
if args.transport:
transport_override["type"] = args.transport
if args.transport_url:
transport_override["url"] = args.transport_url
if transport_headers:
transport_override["headers"] = transport_headers
if args.transport_timeout is not None:
transport_override["timeout"] = args.transport_timeout
if args.transport_sse_timeout is not None:
transport_override["sse_timeout"] = args.transport_sse_timeout
if oauth_override:
transport_override["oauth"] = oauth_override
setattr(args, "transport_override", transport_override)
# Handle special commands first
if args.list_servers:
show_available_servers(args.config)
@ -765,7 +853,8 @@ Environment:
browser = MCPBrowser(
server_name=args.server,
config_path=config_path,
enable_builtin_servers=not args.no_builtin
enable_builtin_servers=not args.no_builtin,
transport_override=transport_override
)
# Handle test mode
@ -805,4 +894,4 @@ async def async_main(browser: MCPBrowser):
if __name__ == "__main__":
main()
main()

View File

@ -14,15 +14,39 @@ from pathlib import Path
from .default_configs import ConfigManager
@dataclass
class OAuthClientCredentialsConfig:
"""Configuration for OAuth client credentials flow."""
token_url: Optional[str] = None
client_id: Optional[str] = None
client_secret: Optional[str] = None
scope: Optional[str] = None
audience: Optional[str] = None
extra_params: Dict[str, str] = field(default_factory=dict)
access_token: Optional[str] = None
@dataclass
class TransportConfig:
"""Transport configuration for connecting to an MCP server."""
type: str = "stdio"
url: Optional[str] = None
headers: Dict[str, str] = field(default_factory=dict)
timeout: float = 30.0
sse_timeout: float = 300.0
oauth: Optional[OAuthClientCredentialsConfig] = None
@dataclass
class MCPServerConfig:
"""Configuration for a single MCP server."""
command: List[str]
command: Optional[List[str]] = None
args: List[str] = field(default_factory=list)
env: Dict[str, str] = field(default_factory=dict)
name: Optional[str] = None
description: Optional[str] = None
enabled: bool = True
transport: TransportConfig = field(default_factory=TransportConfig)
@dataclass
@ -84,13 +108,35 @@ class ConfigLoader:
# Convert to dataclass instances
servers = {}
for name, server_config in config_data.get("servers", {}).items():
transport_data = server_config.get("transport", {}) or {}
oauth_data = transport_data.get("oauth") or {}
oauth_config = None
if oauth_data:
oauth_config = OAuthClientCredentialsConfig(
token_url=oauth_data.get("token_url") or oauth_data.get("tokenUrl"),
client_id=oauth_data.get("client_id") or oauth_data.get("clientId"),
client_secret=oauth_data.get("client_secret") or oauth_data.get("clientSecret"),
scope=oauth_data.get("scope"),
audience=oauth_data.get("audience") or oauth_data.get("resource"),
extra_params=oauth_data.get("extra_params") or oauth_data.get("extraParams", {}),
access_token=oauth_data.get("access_token") or oauth_data.get("accessToken"),
)
transport_config = TransportConfig(
type=transport_data.get("type", "stdio"),
url=transport_data.get("url"),
headers=transport_data.get("headers", {}),
timeout=transport_data.get("timeout", 30.0),
sse_timeout=transport_data.get("sse_timeout") or transport_data.get("sseTimeout", 300.0),
oauth=oauth_config
)
servers[name] = MCPServerConfig(
command=server_config["command"],
command=server_config.get("command"),
args=server_config.get("args", []),
env=server_config.get("env", {}),
name=server_config.get("name", name),
description=server_config.get("description"),
enabled=server_config.get("enabled", True)
enabled=server_config.get("enabled", True),
transport=transport_config
)
self._config = MCPBrowserConfig(
@ -111,4 +157,4 @@ class ConfigLoader:
if key in base and isinstance(base[key], dict) and isinstance(value, dict):
self._merge_configs(base[key], value)
else:
base[key] = value
base[key] = value

View File

@ -10,7 +10,12 @@ import asyncio
from typing import Dict, Any, Optional, Union
from pathlib import Path
from .config import ConfigLoader, MCPBrowserConfig
from .config import (
ConfigLoader,
MCPBrowserConfig,
MCPServerConfig,
OAuthClientCredentialsConfig,
)
from .server import MCPServer
from .multi_server import MultiServerManager
from .registry import ToolRegistry
@ -29,7 +34,7 @@ class MCPBrowser:
"""
def __init__(self, config_path: Optional[Path] = None, server_name: Optional[str] = None,
enable_builtin_servers: bool = True):
enable_builtin_servers: bool = True, transport_override: Optional[Dict[str, Any]] = None):
"""
Initialize MCP Browser.
@ -47,6 +52,7 @@ class MCPBrowser:
self.virtual_handler: Optional[VirtualToolHandler] = None
self._server_name = server_name
self._enable_builtin_servers = enable_builtin_servers
self._transport_override = transport_override
self._initialized = False
self._response_buffer: Dict[Union[str, int], asyncio.Future] = {}
self._next_id = 1
@ -54,6 +60,42 @@ class MCPBrowser:
self._config_watcher = None
self._server_configs = {}
self._config_mtime = None
def _apply_transport_override(self, server_config: MCPServerConfig):
"""Apply transport overrides from CLI arguments."""
override = self._transport_override or {}
transport = server_config.transport
transport.type = override.get("type", transport.type)
if override.get("url"):
transport.url = override["url"]
if override.get("headers"):
transport.headers.update(override["headers"])
if "timeout" in override and override["timeout"] is not None:
transport.timeout = override["timeout"]
if "sse_timeout" in override and override["sse_timeout"] is not None:
transport.sse_timeout = override["sse_timeout"]
oauth_override = override.get("oauth")
if oauth_override:
if not transport.oauth:
transport.oauth = OAuthClientCredentialsConfig()
oauth = transport.oauth
if oauth_override.get("token_url"):
oauth.token_url = oauth_override["token_url"]
if oauth_override.get("client_id"):
oauth.client_id = oauth_override["client_id"]
if "client_secret" in oauth_override:
oauth.client_secret = oauth_override["client_secret"]
if "scope" in oauth_override:
oauth.scope = oauth_override["scope"]
if "audience" in oauth_override:
oauth.audience = oauth_override["audience"]
if "extra_params" in oauth_override and oauth_override["extra_params"] is not None:
oauth.extra_params.update(oauth_override["extra_params"])
if "access_token" in oauth_override:
oauth.access_token = oauth_override["access_token"]
async def __aenter__(self):
"""Async context manager entry."""
@ -78,6 +120,8 @@ class MCPBrowser:
raise ValueError(f"Server '{server_name}' not found in configuration")
server_config = self.config.servers[server_name]
if self._transport_override:
self._apply_transport_override(server_config)
# Create multi-server manager if using built-in servers
if self._enable_builtin_servers:
@ -469,4 +513,4 @@ async def create_browser(config_path: Optional[Path] = None,
"""Create and initialize an MCP Browser instance."""
browser = MCPBrowser(config_path, server_name)
await browser.initialize()
return browser
return browser

View File

@ -9,12 +9,11 @@ import os
import json
import asyncio
import subprocess
from typing import Optional, Dict, Any, Callable, List
from pathlib import Path
from typing import Optional, Dict, Any, Callable, List, Union
from .buffer import JsonRpcBuffer
from .config import MCPServerConfig
from .logging_config import get_logger, TRACE
from .streamable_http import StreamableHTTPClient, StreamableHTTPError
import logging
@ -32,9 +31,30 @@ class MCPServer:
self._pending_requests: Dict[Union[str, int], asyncio.Future] = {}
self._last_error_time: Optional[float] = None
self._offline_since: Optional[float] = None
self._http_client: Optional[StreamableHTTPClient] = None
async def start(self):
"""Start the MCP server process."""
if self.config.transport.type == "streamable-http":
if self._http_client:
return
if not self.config.transport.url:
raise ValueError("Streamable HTTP transport requires 'url' option")
self.logger.info(f"Connecting to streamable HTTP endpoint: {self.config.transport.url}")
self._http_client = StreamableHTTPClient(
self.config.transport.url,
headers=self.config.transport.headers,
timeout=self.config.transport.timeout,
sse_timeout=self.config.transport.sse_timeout,
oauth_config=self.config.transport.oauth,
logger=self.logger,
)
await self._http_client.start()
self._running = True
self._offline_since = None
return
if self.process:
return
@ -46,6 +66,9 @@ class MCPServer:
self.logger.warning(f"Server has been offline for {offline_duration:.0f}s, skipping start")
raise RuntimeError(f"Server marked as offline since {offline_duration:.0f}s ago")
if not self.config.command:
raise ValueError("Server command not configured for stdio transport")
# Prepare environment
env = os.environ.copy()
env.update({
@ -85,6 +108,13 @@ class MCPServer:
async def stop(self):
"""Stop the MCP server process."""
if self.config.transport.type == "streamable-http":
self._running = False
if self._http_client:
await self._http_client.stop()
self._http_client = None
return
self._running = False
if self.process:
@ -116,6 +146,9 @@ class MCPServer:
Returns:
Response result or raises exception on error
"""
if self.config.transport.type == "streamable-http":
return await self._send_request_http(method, params or {})
if not self.process:
raise RuntimeError("MCP server not started")
@ -151,6 +184,41 @@ class MCPServer:
self.logger.error(f"Timeout waiting for response to {method} (timeout={timeout}s)")
self._mark_offline()
raise TimeoutError(f"No response for request {request_id}")
async def _send_request_http(self, method: str, params: Dict[str, Any]) -> Dict[str, Any]:
"""Send request via streamable HTTP transport."""
if not self._http_client:
raise RuntimeError("Streamable HTTP client not started")
request_id = self._next_id
self._next_id += 1
request = {
"jsonrpc": "2.0",
"id": request_id,
"method": method,
"params": params
}
loop = asyncio.get_running_loop()
future: asyncio.Future = loop.create_future()
self._pending_requests[request_id] = future
timeout = self.config.transport.timeout or 30.0
try:
await self._http_client.send(request, self._handle_message)
result = await asyncio.wait_for(future, timeout=timeout)
return result
except asyncio.TimeoutError:
raise TimeoutError(f"No response for request {request_id}")
except StreamableHTTPError:
raise
except Exception as exc:
self.logger.error(f"Streamable HTTP request failed: {exc}")
raise
finally:
self._pending_requests.pop(request_id, None)
def send_raw(self, message: str):
"""Send raw message to MCP server (for pass-through)."""
@ -219,4 +287,4 @@ class MCPServer:
try:
handler(message)
except Exception as e:
self.logger.error(f"Handler error: {e}")
self.logger.error(f"Handler error: {e}")

View File

@ -0,0 +1,289 @@
"""
Minimal Streamable HTTP transport support for MCP Browser.
Provides a lightweight client that can communicate with MCP servers using
the streamable-http transport defined by the MCP specification.
"""
from __future__ import annotations
import asyncio
import json
import time
from typing import Any, Awaitable, Callable, Dict, Optional
import httpx
from .config import OAuthClientCredentialsConfig
from .logging_config import get_logger
MCP_SESSION_HEADER = "mcp-session-id"
MCP_PROTOCOL_VERSION_HEADER = "mcp-protocol-version"
LAST_EVENT_ID_HEADER = "last-event-id"
class StreamableHTTPError(Exception):
"""Base error for streamable HTTP transport failures."""
class StreamableHTTPClient:
"""Minimal client for the MCP streamable-http transport."""
def __init__(
self,
url: str,
*,
headers: Optional[Dict[str, str]] = None,
timeout: float = 30.0,
sse_timeout: float = 300.0,
oauth_config: Optional[OAuthClientCredentialsConfig] = None,
logger=None,
) -> None:
self.url = url
self.headers = headers.copy() if headers else {}
self.timeout = timeout
self.sse_timeout = sse_timeout
self.oauth_config = oauth_config
self.logger = logger or get_logger(__name__)
self._client: Optional[httpx.AsyncClient] = None
self._session_id: Optional[str] = None
self._protocol_version: Optional[str] = None
self._last_event_id: Optional[str] = None
self._token: Optional[str] = oauth_config.access_token if oauth_config else None
self._token_expires_at: Optional[float] = None
self._token_lock = asyncio.Lock()
async def start(self) -> None:
"""Initialise underlying HTTP client."""
if self._client is not None:
return
self._client = httpx.AsyncClient(timeout=self.timeout)
if self.oauth_config and self.oauth_config.token_url:
await self._ensure_token()
async def stop(self) -> None:
"""Shutdown HTTP client."""
if self._client:
await self._client.aclose()
self._client = None
async def send(
self,
request: Dict[str, Any],
message_callback: Callable[[Dict[str, Any]], Awaitable[None]],
) -> None:
"""Send a JSON-RPC request and stream responses to callback."""
if not self._client:
raise StreamableHTTPError("Streamable HTTP client not started")
await self._ensure_token()
headers = self._build_headers()
try:
response = await self._client.post(self.url, json=request, headers=headers)
except httpx.HTTPError as exc:
raise StreamableHTTPError(f"Streamable HTTP request failed: {exc}") from exc
self._store_session_headers(response)
try:
response.raise_for_status()
except httpx.HTTPStatusError as exc:
body = await self._safe_read_text(response)
message = body or exc.response.text
raise StreamableHTTPError(f"HTTP {exc.response.status_code}: {message}") from exc
content_type = response.headers.get("content-type", "")
if "text/event-stream" in content_type:
await self._consume_sse(response, message_callback)
else:
await self._consume_json(response, message_callback)
def _build_headers(self) -> Dict[str, str]:
headers = {
"accept": "application/json, text/event-stream",
"content-type": "application/json",
**self.headers,
}
if self._session_id and MCP_SESSION_HEADER not in headers:
headers[MCP_SESSION_HEADER] = self._session_id
if self._protocol_version and MCP_PROTOCOL_VERSION_HEADER not in headers:
headers[MCP_PROTOCOL_VERSION_HEADER] = self._protocol_version
if self._last_event_id and LAST_EVENT_ID_HEADER not in headers:
headers[LAST_EVENT_ID_HEADER] = self._last_event_id
if self._token and "authorization" not in {k.lower(): v for k, v in headers.items()}:
headers["Authorization"] = f"Bearer {self._token}"
return headers
async def _ensure_token(self) -> None:
if not self.oauth_config:
return
if self.oauth_config.access_token and not self.oauth_config.token_url:
self._token = self.oauth_config.access_token
return
if not self.oauth_config.token_url or not self.oauth_config.client_id:
return
async with self._token_lock:
if self._token and self._token_expires_at:
if time.time() < self._token_expires_at - 30:
return
await self._refresh_token()
async def _refresh_token(self) -> None:
assert self._client is not None
data = {
"grant_type": "client_credentials",
}
if self.oauth_config.scope:
data["scope"] = self.oauth_config.scope
if self.oauth_config.audience:
data["audience"] = self.oauth_config.audience
if self.oauth_config.extra_params:
data.update(self.oauth_config.extra_params)
auth = None
if self.oauth_config.client_secret:
auth = (self.oauth_config.client_id or "", self.oauth_config.client_secret)
else:
data["client_id"] = self.oauth_config.client_id
response = await self._client.post(
self.oauth_config.token_url, # type: ignore[arg-type]
data=data,
headers={"content-type": "application/x-www-form-urlencoded"},
auth=auth,
)
try:
response.raise_for_status()
except httpx.HTTPError as exc:
body = await self._safe_read_text(response)
raise StreamableHTTPError(
f"Failed to refresh OAuth token: {body or response.text}"
) from exc
token_data = response.json()
access_token = token_data.get("access_token")
if not access_token:
raise StreamableHTTPError("OAuth token response missing access_token")
self._token = access_token
expires_in = token_data.get("expires_in")
if expires_in:
try:
self._token_expires_at = time.time() + float(expires_in)
except (TypeError, ValueError):
self._token_expires_at = None
else:
self._token_expires_at = None
async def _consume_json(
self,
response: httpx.Response,
message_callback: Callable[[Dict[str, Any]], Awaitable[None]],
) -> None:
text = await response.aread()
if not text:
return
try:
payload = json.loads(text)
except json.JSONDecodeError as exc:
raise StreamableHTTPError(f"Invalid JSON response: {exc}") from exc
messages = payload if isinstance(payload, list) else [payload]
for message in messages:
if isinstance(message, dict):
self._update_protocol_version(message)
await message_callback(message)
async def _consume_sse(
self,
response: httpx.Response,
message_callback: Callable[[Dict[str, Any]], Awaitable[None]],
) -> None:
event: Dict[str, Any] = {"data": []}
async for line in response.aiter_lines():
if line == "":
await self._flush_event(event, message_callback)
event = {"data": []}
continue
if line.startswith(":"):
continue
field, _, raw_value = line.partition(":")
value = raw_value.lstrip(" ")
if field == "data":
event.setdefault("data", []).append(value)
elif field == "event":
event["event"] = value
elif field == "id":
event["id"] = value
elif field == "retry":
continue
await self._flush_event(event, message_callback)
async def _flush_event(
self,
event: Dict[str, Any],
message_callback: Callable[[Dict[str, Any]], Awaitable[None]],
) -> None:
data_lines = event.get("data") or []
if not data_lines:
return
data = "\n".join(data_lines)
try:
message = json.loads(data)
except json.JSONDecodeError:
self.logger.warning("Failed to parse SSE event payload")
return
if not isinstance(message, dict):
return
if event.get("id"):
self._last_event_id = event["id"]
self._update_protocol_version(message)
await message_callback(message)
def _update_protocol_version(self, message: Dict[str, Any]) -> None:
result = message.get("result")
if isinstance(result, dict) and result.get("protocolVersion"):
self._protocol_version = str(result["protocolVersion"])
def _store_session_headers(self, response: httpx.Response) -> None:
session_id = response.headers.get(MCP_SESSION_HEADER)
if session_id:
self._session_id = session_id
async def _safe_read_text(self, response: httpx.Response) -> Optional[str]:
try:
raw = await response.aread()
except Exception:
return None
if isinstance(raw, str):
return raw
try:
charset = response.charset or "utf-8"
except Exception:
charset = "utf-8"
try:
return raw.decode(charset)
except Exception:
try:
return raw.decode("utf-8", errors="replace")
except Exception:
return None

View File

@ -1,2 +1,3 @@
pyyaml>=6.0
jsonpath-ng>=1.5.3
jsonpath-ng>=1.5.3
httpx>=0.27

View File

@ -317,6 +317,7 @@ setup(
"aiofiles>=23.0.0",
"jsonpath-ng>=1.6.0",
"pyyaml>=6.0",
"httpx>=0.27",
"typing-extensions>=4.0.0;python_version<'3.11'",
],
extras_require={
@ -359,4 +360,4 @@ setup(
"Programming Language :: Python :: 3.12",
],
keywords="mcp model-context-protocol ai llm tools json-rpc",
)
)