feat: add streamable-http transport support
This commit is contained in:
parent
568fd53ad2
commit
6199f28e9e
|
|
@ -326,7 +326,8 @@ async def handle_mcp_command(args):
|
|||
browser = MCPBrowser(
|
||||
server_name=args.server,
|
||||
config_path=config_path,
|
||||
enable_builtin_servers=not args.no_builtin
|
||||
enable_builtin_servers=not args.no_builtin,
|
||||
transport_override=getattr(args, "transport_override", None)
|
||||
)
|
||||
|
||||
try:
|
||||
|
|
@ -409,7 +410,8 @@ async def start_daemon_background(args):
|
|||
browser = MCPBrowser(
|
||||
server_name=args.server,
|
||||
config_path=config_path,
|
||||
enable_builtin_servers=not args.no_builtin
|
||||
enable_builtin_servers=not args.no_builtin,
|
||||
transport_override=getattr(args, "transport_override", None)
|
||||
)
|
||||
|
||||
# Run daemon
|
||||
|
|
@ -687,6 +689,28 @@ Environment:
|
|||
parser.add_argument("--version", "-v", action="version",
|
||||
version=f"%(prog)s {__version__}",
|
||||
help="Show program version and exit")
|
||||
parser.add_argument("--transport", choices=["stdio", "streamable-http"],
|
||||
help="Override transport type for the target server")
|
||||
parser.add_argument("--transport-url",
|
||||
help="Endpoint URL for streamable-http transport override")
|
||||
parser.add_argument("--transport-header", action="append", metavar="KEY=VALUE",
|
||||
help="Additional HTTP headers for streamable-http transport (repeatable)")
|
||||
parser.add_argument("--transport-timeout", type=float,
|
||||
help="Override HTTP request timeout in seconds")
|
||||
parser.add_argument("--transport-sse-timeout", type=float,
|
||||
help="Override SSE read timeout in seconds")
|
||||
parser.add_argument("--oauth-token-url",
|
||||
help="OAuth token endpoint for client credentials flow")
|
||||
parser.add_argument("--oauth-client-id", help="OAuth client identifier")
|
||||
parser.add_argument("--oauth-client-secret", help="OAuth client secret")
|
||||
parser.add_argument("--oauth-scope",
|
||||
help="OAuth scopes (space-separated)")
|
||||
parser.add_argument("--oauth-audience",
|
||||
help="OAuth audience/resource parameter")
|
||||
parser.add_argument("--oauth-extra-param", action="append", metavar="KEY=VALUE",
|
||||
help="Additional form parameters for OAuth token requests (repeatable)")
|
||||
parser.add_argument("--oauth-token",
|
||||
help="Static OAuth bearer token (applies Authorization header)")
|
||||
|
||||
# MCP method commands
|
||||
subparsers = parser.add_subparsers(dest="command", help="MCP methods")
|
||||
|
|
@ -725,6 +749,70 @@ Environment:
|
|||
|
||||
args = parser.parse_args()
|
||||
|
||||
def parse_key_value_pairs(pairs):
|
||||
result = {}
|
||||
if not pairs:
|
||||
return result
|
||||
for item in pairs:
|
||||
if "=" not in item:
|
||||
print(f"Invalid KEY=VALUE format: {item}", file=sys.stderr)
|
||||
sys.exit(2)
|
||||
key, value = item.split("=", 1)
|
||||
key = key.strip()
|
||||
value = value.strip()
|
||||
if not key:
|
||||
print(f"Header key is empty for entry: {item}", file=sys.stderr)
|
||||
sys.exit(2)
|
||||
result[key] = value
|
||||
return result
|
||||
|
||||
transport_headers = parse_key_value_pairs(args.transport_header)
|
||||
oauth_extra = parse_key_value_pairs(args.oauth_extra_param)
|
||||
|
||||
if args.oauth_token:
|
||||
transport_headers.setdefault("Authorization", f"Bearer {args.oauth_token}")
|
||||
|
||||
oauth_override = {}
|
||||
if args.oauth_token_url:
|
||||
oauth_override["token_url"] = args.oauth_token_url
|
||||
if args.oauth_client_id:
|
||||
oauth_override["client_id"] = args.oauth_client_id
|
||||
if args.oauth_client_secret is not None:
|
||||
oauth_override["client_secret"] = args.oauth_client_secret
|
||||
if args.oauth_scope is not None:
|
||||
oauth_override["scope"] = args.oauth_scope
|
||||
if args.oauth_audience is not None:
|
||||
oauth_override["audience"] = args.oauth_audience
|
||||
if oauth_extra:
|
||||
oauth_override["extra_params"] = oauth_extra
|
||||
if args.oauth_token:
|
||||
oauth_override["access_token"] = args.oauth_token
|
||||
|
||||
transport_override = None
|
||||
if any([
|
||||
args.transport,
|
||||
args.transport_url,
|
||||
transport_headers,
|
||||
args.transport_timeout is not None,
|
||||
args.transport_sse_timeout is not None,
|
||||
oauth_override,
|
||||
]):
|
||||
transport_override = {}
|
||||
if args.transport:
|
||||
transport_override["type"] = args.transport
|
||||
if args.transport_url:
|
||||
transport_override["url"] = args.transport_url
|
||||
if transport_headers:
|
||||
transport_override["headers"] = transport_headers
|
||||
if args.transport_timeout is not None:
|
||||
transport_override["timeout"] = args.transport_timeout
|
||||
if args.transport_sse_timeout is not None:
|
||||
transport_override["sse_timeout"] = args.transport_sse_timeout
|
||||
if oauth_override:
|
||||
transport_override["oauth"] = oauth_override
|
||||
|
||||
setattr(args, "transport_override", transport_override)
|
||||
|
||||
# Handle special commands first
|
||||
if args.list_servers:
|
||||
show_available_servers(args.config)
|
||||
|
|
@ -765,7 +853,8 @@ Environment:
|
|||
browser = MCPBrowser(
|
||||
server_name=args.server,
|
||||
config_path=config_path,
|
||||
enable_builtin_servers=not args.no_builtin
|
||||
enable_builtin_servers=not args.no_builtin,
|
||||
transport_override=transport_override
|
||||
)
|
||||
|
||||
# Handle test mode
|
||||
|
|
@ -805,4 +894,4 @@ async def async_main(browser: MCPBrowser):
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
main()
|
||||
|
|
|
|||
|
|
@ -14,15 +14,39 @@ from pathlib import Path
|
|||
from .default_configs import ConfigManager
|
||||
|
||||
|
||||
@dataclass
|
||||
class OAuthClientCredentialsConfig:
|
||||
"""Configuration for OAuth client credentials flow."""
|
||||
token_url: Optional[str] = None
|
||||
client_id: Optional[str] = None
|
||||
client_secret: Optional[str] = None
|
||||
scope: Optional[str] = None
|
||||
audience: Optional[str] = None
|
||||
extra_params: Dict[str, str] = field(default_factory=dict)
|
||||
access_token: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class TransportConfig:
|
||||
"""Transport configuration for connecting to an MCP server."""
|
||||
type: str = "stdio"
|
||||
url: Optional[str] = None
|
||||
headers: Dict[str, str] = field(default_factory=dict)
|
||||
timeout: float = 30.0
|
||||
sse_timeout: float = 300.0
|
||||
oauth: Optional[OAuthClientCredentialsConfig] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class MCPServerConfig:
|
||||
"""Configuration for a single MCP server."""
|
||||
command: List[str]
|
||||
command: Optional[List[str]] = None
|
||||
args: List[str] = field(default_factory=list)
|
||||
env: Dict[str, str] = field(default_factory=dict)
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
enabled: bool = True
|
||||
transport: TransportConfig = field(default_factory=TransportConfig)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -84,13 +108,35 @@ class ConfigLoader:
|
|||
# Convert to dataclass instances
|
||||
servers = {}
|
||||
for name, server_config in config_data.get("servers", {}).items():
|
||||
transport_data = server_config.get("transport", {}) or {}
|
||||
oauth_data = transport_data.get("oauth") or {}
|
||||
oauth_config = None
|
||||
if oauth_data:
|
||||
oauth_config = OAuthClientCredentialsConfig(
|
||||
token_url=oauth_data.get("token_url") or oauth_data.get("tokenUrl"),
|
||||
client_id=oauth_data.get("client_id") or oauth_data.get("clientId"),
|
||||
client_secret=oauth_data.get("client_secret") or oauth_data.get("clientSecret"),
|
||||
scope=oauth_data.get("scope"),
|
||||
audience=oauth_data.get("audience") or oauth_data.get("resource"),
|
||||
extra_params=oauth_data.get("extra_params") or oauth_data.get("extraParams", {}),
|
||||
access_token=oauth_data.get("access_token") or oauth_data.get("accessToken"),
|
||||
)
|
||||
transport_config = TransportConfig(
|
||||
type=transport_data.get("type", "stdio"),
|
||||
url=transport_data.get("url"),
|
||||
headers=transport_data.get("headers", {}),
|
||||
timeout=transport_data.get("timeout", 30.0),
|
||||
sse_timeout=transport_data.get("sse_timeout") or transport_data.get("sseTimeout", 300.0),
|
||||
oauth=oauth_config
|
||||
)
|
||||
servers[name] = MCPServerConfig(
|
||||
command=server_config["command"],
|
||||
command=server_config.get("command"),
|
||||
args=server_config.get("args", []),
|
||||
env=server_config.get("env", {}),
|
||||
name=server_config.get("name", name),
|
||||
description=server_config.get("description"),
|
||||
enabled=server_config.get("enabled", True)
|
||||
enabled=server_config.get("enabled", True),
|
||||
transport=transport_config
|
||||
)
|
||||
|
||||
self._config = MCPBrowserConfig(
|
||||
|
|
@ -111,4 +157,4 @@ class ConfigLoader:
|
|||
if key in base and isinstance(base[key], dict) and isinstance(value, dict):
|
||||
self._merge_configs(base[key], value)
|
||||
else:
|
||||
base[key] = value
|
||||
base[key] = value
|
||||
|
|
|
|||
|
|
@ -10,7 +10,12 @@ import asyncio
|
|||
from typing import Dict, Any, Optional, Union
|
||||
from pathlib import Path
|
||||
|
||||
from .config import ConfigLoader, MCPBrowserConfig
|
||||
from .config import (
|
||||
ConfigLoader,
|
||||
MCPBrowserConfig,
|
||||
MCPServerConfig,
|
||||
OAuthClientCredentialsConfig,
|
||||
)
|
||||
from .server import MCPServer
|
||||
from .multi_server import MultiServerManager
|
||||
from .registry import ToolRegistry
|
||||
|
|
@ -29,7 +34,7 @@ class MCPBrowser:
|
|||
"""
|
||||
|
||||
def __init__(self, config_path: Optional[Path] = None, server_name: Optional[str] = None,
|
||||
enable_builtin_servers: bool = True):
|
||||
enable_builtin_servers: bool = True, transport_override: Optional[Dict[str, Any]] = None):
|
||||
"""
|
||||
Initialize MCP Browser.
|
||||
|
||||
|
|
@ -47,6 +52,7 @@ class MCPBrowser:
|
|||
self.virtual_handler: Optional[VirtualToolHandler] = None
|
||||
self._server_name = server_name
|
||||
self._enable_builtin_servers = enable_builtin_servers
|
||||
self._transport_override = transport_override
|
||||
self._initialized = False
|
||||
self._response_buffer: Dict[Union[str, int], asyncio.Future] = {}
|
||||
self._next_id = 1
|
||||
|
|
@ -54,6 +60,42 @@ class MCPBrowser:
|
|||
self._config_watcher = None
|
||||
self._server_configs = {}
|
||||
self._config_mtime = None
|
||||
|
||||
def _apply_transport_override(self, server_config: MCPServerConfig):
|
||||
"""Apply transport overrides from CLI arguments."""
|
||||
override = self._transport_override or {}
|
||||
transport = server_config.transport
|
||||
|
||||
transport.type = override.get("type", transport.type)
|
||||
|
||||
if override.get("url"):
|
||||
transport.url = override["url"]
|
||||
if override.get("headers"):
|
||||
transport.headers.update(override["headers"])
|
||||
if "timeout" in override and override["timeout"] is not None:
|
||||
transport.timeout = override["timeout"]
|
||||
if "sse_timeout" in override and override["sse_timeout"] is not None:
|
||||
transport.sse_timeout = override["sse_timeout"]
|
||||
|
||||
oauth_override = override.get("oauth")
|
||||
if oauth_override:
|
||||
if not transport.oauth:
|
||||
transport.oauth = OAuthClientCredentialsConfig()
|
||||
oauth = transport.oauth
|
||||
if oauth_override.get("token_url"):
|
||||
oauth.token_url = oauth_override["token_url"]
|
||||
if oauth_override.get("client_id"):
|
||||
oauth.client_id = oauth_override["client_id"]
|
||||
if "client_secret" in oauth_override:
|
||||
oauth.client_secret = oauth_override["client_secret"]
|
||||
if "scope" in oauth_override:
|
||||
oauth.scope = oauth_override["scope"]
|
||||
if "audience" in oauth_override:
|
||||
oauth.audience = oauth_override["audience"]
|
||||
if "extra_params" in oauth_override and oauth_override["extra_params"] is not None:
|
||||
oauth.extra_params.update(oauth_override["extra_params"])
|
||||
if "access_token" in oauth_override:
|
||||
oauth.access_token = oauth_override["access_token"]
|
||||
|
||||
async def __aenter__(self):
|
||||
"""Async context manager entry."""
|
||||
|
|
@ -78,6 +120,8 @@ class MCPBrowser:
|
|||
raise ValueError(f"Server '{server_name}' not found in configuration")
|
||||
|
||||
server_config = self.config.servers[server_name]
|
||||
if self._transport_override:
|
||||
self._apply_transport_override(server_config)
|
||||
|
||||
# Create multi-server manager if using built-in servers
|
||||
if self._enable_builtin_servers:
|
||||
|
|
@ -469,4 +513,4 @@ async def create_browser(config_path: Optional[Path] = None,
|
|||
"""Create and initialize an MCP Browser instance."""
|
||||
browser = MCPBrowser(config_path, server_name)
|
||||
await browser.initialize()
|
||||
return browser
|
||||
return browser
|
||||
|
|
|
|||
|
|
@ -9,12 +9,11 @@ import os
|
|||
import json
|
||||
import asyncio
|
||||
import subprocess
|
||||
from typing import Optional, Dict, Any, Callable, List
|
||||
from pathlib import Path
|
||||
|
||||
from typing import Optional, Dict, Any, Callable, List, Union
|
||||
from .buffer import JsonRpcBuffer
|
||||
from .config import MCPServerConfig
|
||||
from .logging_config import get_logger, TRACE
|
||||
from .streamable_http import StreamableHTTPClient, StreamableHTTPError
|
||||
import logging
|
||||
|
||||
|
||||
|
|
@ -32,9 +31,30 @@ class MCPServer:
|
|||
self._pending_requests: Dict[Union[str, int], asyncio.Future] = {}
|
||||
self._last_error_time: Optional[float] = None
|
||||
self._offline_since: Optional[float] = None
|
||||
self._http_client: Optional[StreamableHTTPClient] = None
|
||||
|
||||
async def start(self):
|
||||
"""Start the MCP server process."""
|
||||
if self.config.transport.type == "streamable-http":
|
||||
if self._http_client:
|
||||
return
|
||||
if not self.config.transport.url:
|
||||
raise ValueError("Streamable HTTP transport requires 'url' option")
|
||||
|
||||
self.logger.info(f"Connecting to streamable HTTP endpoint: {self.config.transport.url}")
|
||||
self._http_client = StreamableHTTPClient(
|
||||
self.config.transport.url,
|
||||
headers=self.config.transport.headers,
|
||||
timeout=self.config.transport.timeout,
|
||||
sse_timeout=self.config.transport.sse_timeout,
|
||||
oauth_config=self.config.transport.oauth,
|
||||
logger=self.logger,
|
||||
)
|
||||
await self._http_client.start()
|
||||
self._running = True
|
||||
self._offline_since = None
|
||||
return
|
||||
|
||||
if self.process:
|
||||
return
|
||||
|
||||
|
|
@ -46,6 +66,9 @@ class MCPServer:
|
|||
self.logger.warning(f"Server has been offline for {offline_duration:.0f}s, skipping start")
|
||||
raise RuntimeError(f"Server marked as offline since {offline_duration:.0f}s ago")
|
||||
|
||||
if not self.config.command:
|
||||
raise ValueError("Server command not configured for stdio transport")
|
||||
|
||||
# Prepare environment
|
||||
env = os.environ.copy()
|
||||
env.update({
|
||||
|
|
@ -85,6 +108,13 @@ class MCPServer:
|
|||
|
||||
async def stop(self):
|
||||
"""Stop the MCP server process."""
|
||||
if self.config.transport.type == "streamable-http":
|
||||
self._running = False
|
||||
if self._http_client:
|
||||
await self._http_client.stop()
|
||||
self._http_client = None
|
||||
return
|
||||
|
||||
self._running = False
|
||||
|
||||
if self.process:
|
||||
|
|
@ -116,6 +146,9 @@ class MCPServer:
|
|||
Returns:
|
||||
Response result or raises exception on error
|
||||
"""
|
||||
if self.config.transport.type == "streamable-http":
|
||||
return await self._send_request_http(method, params or {})
|
||||
|
||||
if not self.process:
|
||||
raise RuntimeError("MCP server not started")
|
||||
|
||||
|
|
@ -151,6 +184,41 @@ class MCPServer:
|
|||
self.logger.error(f"Timeout waiting for response to {method} (timeout={timeout}s)")
|
||||
self._mark_offline()
|
||||
raise TimeoutError(f"No response for request {request_id}")
|
||||
|
||||
async def _send_request_http(self, method: str, params: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Send request via streamable HTTP transport."""
|
||||
if not self._http_client:
|
||||
raise RuntimeError("Streamable HTTP client not started")
|
||||
|
||||
request_id = self._next_id
|
||||
self._next_id += 1
|
||||
|
||||
request = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": request_id,
|
||||
"method": method,
|
||||
"params": params
|
||||
}
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
future: asyncio.Future = loop.create_future()
|
||||
self._pending_requests[request_id] = future
|
||||
|
||||
timeout = self.config.transport.timeout or 30.0
|
||||
|
||||
try:
|
||||
await self._http_client.send(request, self._handle_message)
|
||||
result = await asyncio.wait_for(future, timeout=timeout)
|
||||
return result
|
||||
except asyncio.TimeoutError:
|
||||
raise TimeoutError(f"No response for request {request_id}")
|
||||
except StreamableHTTPError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
self.logger.error(f"Streamable HTTP request failed: {exc}")
|
||||
raise
|
||||
finally:
|
||||
self._pending_requests.pop(request_id, None)
|
||||
|
||||
def send_raw(self, message: str):
|
||||
"""Send raw message to MCP server (for pass-through)."""
|
||||
|
|
@ -219,4 +287,4 @@ class MCPServer:
|
|||
try:
|
||||
handler(message)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Handler error: {e}")
|
||||
self.logger.error(f"Handler error: {e}")
|
||||
|
|
|
|||
|
|
@ -0,0 +1,289 @@
|
|||
"""
|
||||
Minimal Streamable HTTP transport support for MCP Browser.
|
||||
|
||||
Provides a lightweight client that can communicate with MCP servers using
|
||||
the streamable-http transport defined by the MCP specification.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from typing import Any, Awaitable, Callable, Dict, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from .config import OAuthClientCredentialsConfig
|
||||
from .logging_config import get_logger
|
||||
|
||||
MCP_SESSION_HEADER = "mcp-session-id"
|
||||
MCP_PROTOCOL_VERSION_HEADER = "mcp-protocol-version"
|
||||
LAST_EVENT_ID_HEADER = "last-event-id"
|
||||
|
||||
|
||||
class StreamableHTTPError(Exception):
|
||||
"""Base error for streamable HTTP transport failures."""
|
||||
|
||||
|
||||
class StreamableHTTPClient:
|
||||
"""Minimal client for the MCP streamable-http transport."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
url: str,
|
||||
*,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
timeout: float = 30.0,
|
||||
sse_timeout: float = 300.0,
|
||||
oauth_config: Optional[OAuthClientCredentialsConfig] = None,
|
||||
logger=None,
|
||||
) -> None:
|
||||
self.url = url
|
||||
self.headers = headers.copy() if headers else {}
|
||||
self.timeout = timeout
|
||||
self.sse_timeout = sse_timeout
|
||||
self.oauth_config = oauth_config
|
||||
self.logger = logger or get_logger(__name__)
|
||||
|
||||
self._client: Optional[httpx.AsyncClient] = None
|
||||
self._session_id: Optional[str] = None
|
||||
self._protocol_version: Optional[str] = None
|
||||
self._last_event_id: Optional[str] = None
|
||||
|
||||
self._token: Optional[str] = oauth_config.access_token if oauth_config else None
|
||||
self._token_expires_at: Optional[float] = None
|
||||
self._token_lock = asyncio.Lock()
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Initialise underlying HTTP client."""
|
||||
if self._client is not None:
|
||||
return
|
||||
|
||||
self._client = httpx.AsyncClient(timeout=self.timeout)
|
||||
if self.oauth_config and self.oauth_config.token_url:
|
||||
await self._ensure_token()
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Shutdown HTTP client."""
|
||||
if self._client:
|
||||
await self._client.aclose()
|
||||
self._client = None
|
||||
|
||||
async def send(
|
||||
self,
|
||||
request: Dict[str, Any],
|
||||
message_callback: Callable[[Dict[str, Any]], Awaitable[None]],
|
||||
) -> None:
|
||||
"""Send a JSON-RPC request and stream responses to callback."""
|
||||
if not self._client:
|
||||
raise StreamableHTTPError("Streamable HTTP client not started")
|
||||
|
||||
await self._ensure_token()
|
||||
|
||||
headers = self._build_headers()
|
||||
try:
|
||||
response = await self._client.post(self.url, json=request, headers=headers)
|
||||
except httpx.HTTPError as exc:
|
||||
raise StreamableHTTPError(f"Streamable HTTP request failed: {exc}") from exc
|
||||
|
||||
self._store_session_headers(response)
|
||||
|
||||
try:
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPStatusError as exc:
|
||||
body = await self._safe_read_text(response)
|
||||
message = body or exc.response.text
|
||||
raise StreamableHTTPError(f"HTTP {exc.response.status_code}: {message}") from exc
|
||||
|
||||
content_type = response.headers.get("content-type", "")
|
||||
|
||||
if "text/event-stream" in content_type:
|
||||
await self._consume_sse(response, message_callback)
|
||||
else:
|
||||
await self._consume_json(response, message_callback)
|
||||
|
||||
def _build_headers(self) -> Dict[str, str]:
|
||||
headers = {
|
||||
"accept": "application/json, text/event-stream",
|
||||
"content-type": "application/json",
|
||||
**self.headers,
|
||||
}
|
||||
|
||||
if self._session_id and MCP_SESSION_HEADER not in headers:
|
||||
headers[MCP_SESSION_HEADER] = self._session_id
|
||||
if self._protocol_version and MCP_PROTOCOL_VERSION_HEADER not in headers:
|
||||
headers[MCP_PROTOCOL_VERSION_HEADER] = self._protocol_version
|
||||
if self._last_event_id and LAST_EVENT_ID_HEADER not in headers:
|
||||
headers[LAST_EVENT_ID_HEADER] = self._last_event_id
|
||||
if self._token and "authorization" not in {k.lower(): v for k, v in headers.items()}:
|
||||
headers["Authorization"] = f"Bearer {self._token}"
|
||||
|
||||
return headers
|
||||
|
||||
async def _ensure_token(self) -> None:
|
||||
if not self.oauth_config:
|
||||
return
|
||||
if self.oauth_config.access_token and not self.oauth_config.token_url:
|
||||
self._token = self.oauth_config.access_token
|
||||
return
|
||||
if not self.oauth_config.token_url or not self.oauth_config.client_id:
|
||||
return
|
||||
|
||||
async with self._token_lock:
|
||||
if self._token and self._token_expires_at:
|
||||
if time.time() < self._token_expires_at - 30:
|
||||
return
|
||||
await self._refresh_token()
|
||||
|
||||
async def _refresh_token(self) -> None:
|
||||
assert self._client is not None
|
||||
data = {
|
||||
"grant_type": "client_credentials",
|
||||
}
|
||||
if self.oauth_config.scope:
|
||||
data["scope"] = self.oauth_config.scope
|
||||
if self.oauth_config.audience:
|
||||
data["audience"] = self.oauth_config.audience
|
||||
if self.oauth_config.extra_params:
|
||||
data.update(self.oauth_config.extra_params)
|
||||
|
||||
auth = None
|
||||
if self.oauth_config.client_secret:
|
||||
auth = (self.oauth_config.client_id or "", self.oauth_config.client_secret)
|
||||
else:
|
||||
data["client_id"] = self.oauth_config.client_id
|
||||
|
||||
response = await self._client.post(
|
||||
self.oauth_config.token_url, # type: ignore[arg-type]
|
||||
data=data,
|
||||
headers={"content-type": "application/x-www-form-urlencoded"},
|
||||
auth=auth,
|
||||
)
|
||||
try:
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPError as exc:
|
||||
body = await self._safe_read_text(response)
|
||||
raise StreamableHTTPError(
|
||||
f"Failed to refresh OAuth token: {body or response.text}"
|
||||
) from exc
|
||||
|
||||
token_data = response.json()
|
||||
access_token = token_data.get("access_token")
|
||||
if not access_token:
|
||||
raise StreamableHTTPError("OAuth token response missing access_token")
|
||||
|
||||
self._token = access_token
|
||||
expires_in = token_data.get("expires_in")
|
||||
if expires_in:
|
||||
try:
|
||||
self._token_expires_at = time.time() + float(expires_in)
|
||||
except (TypeError, ValueError):
|
||||
self._token_expires_at = None
|
||||
else:
|
||||
self._token_expires_at = None
|
||||
|
||||
async def _consume_json(
|
||||
self,
|
||||
response: httpx.Response,
|
||||
message_callback: Callable[[Dict[str, Any]], Awaitable[None]],
|
||||
) -> None:
|
||||
text = await response.aread()
|
||||
if not text:
|
||||
return
|
||||
|
||||
try:
|
||||
payload = json.loads(text)
|
||||
except json.JSONDecodeError as exc:
|
||||
raise StreamableHTTPError(f"Invalid JSON response: {exc}") from exc
|
||||
|
||||
messages = payload if isinstance(payload, list) else [payload]
|
||||
for message in messages:
|
||||
if isinstance(message, dict):
|
||||
self._update_protocol_version(message)
|
||||
await message_callback(message)
|
||||
|
||||
async def _consume_sse(
|
||||
self,
|
||||
response: httpx.Response,
|
||||
message_callback: Callable[[Dict[str, Any]], Awaitable[None]],
|
||||
) -> None:
|
||||
event: Dict[str, Any] = {"data": []}
|
||||
async for line in response.aiter_lines():
|
||||
if line == "":
|
||||
await self._flush_event(event, message_callback)
|
||||
event = {"data": []}
|
||||
continue
|
||||
|
||||
if line.startswith(":"):
|
||||
continue
|
||||
|
||||
field, _, raw_value = line.partition(":")
|
||||
value = raw_value.lstrip(" ")
|
||||
if field == "data":
|
||||
event.setdefault("data", []).append(value)
|
||||
elif field == "event":
|
||||
event["event"] = value
|
||||
elif field == "id":
|
||||
event["id"] = value
|
||||
elif field == "retry":
|
||||
continue
|
||||
|
||||
await self._flush_event(event, message_callback)
|
||||
|
||||
async def _flush_event(
|
||||
self,
|
||||
event: Dict[str, Any],
|
||||
message_callback: Callable[[Dict[str, Any]], Awaitable[None]],
|
||||
) -> None:
|
||||
data_lines = event.get("data") or []
|
||||
if not data_lines:
|
||||
return
|
||||
|
||||
data = "\n".join(data_lines)
|
||||
try:
|
||||
message = json.loads(data)
|
||||
except json.JSONDecodeError:
|
||||
self.logger.warning("Failed to parse SSE event payload")
|
||||
return
|
||||
|
||||
if not isinstance(message, dict):
|
||||
return
|
||||
|
||||
if event.get("id"):
|
||||
self._last_event_id = event["id"]
|
||||
|
||||
self._update_protocol_version(message)
|
||||
await message_callback(message)
|
||||
|
||||
def _update_protocol_version(self, message: Dict[str, Any]) -> None:
|
||||
result = message.get("result")
|
||||
if isinstance(result, dict) and result.get("protocolVersion"):
|
||||
self._protocol_version = str(result["protocolVersion"])
|
||||
|
||||
def _store_session_headers(self, response: httpx.Response) -> None:
|
||||
session_id = response.headers.get(MCP_SESSION_HEADER)
|
||||
if session_id:
|
||||
self._session_id = session_id
|
||||
|
||||
async def _safe_read_text(self, response: httpx.Response) -> Optional[str]:
|
||||
try:
|
||||
raw = await response.aread()
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
if isinstance(raw, str):
|
||||
return raw
|
||||
|
||||
try:
|
||||
charset = response.charset or "utf-8"
|
||||
except Exception:
|
||||
charset = "utf-8"
|
||||
|
||||
try:
|
||||
return raw.decode(charset)
|
||||
except Exception:
|
||||
try:
|
||||
return raw.decode("utf-8", errors="replace")
|
||||
except Exception:
|
||||
return None
|
||||
|
|
@ -1,2 +1,3 @@
|
|||
pyyaml>=6.0
|
||||
jsonpath-ng>=1.5.3
|
||||
jsonpath-ng>=1.5.3
|
||||
httpx>=0.27
|
||||
|
|
|
|||
3
setup.py
3
setup.py
|
|
@ -317,6 +317,7 @@ setup(
|
|||
"aiofiles>=23.0.0",
|
||||
"jsonpath-ng>=1.6.0",
|
||||
"pyyaml>=6.0",
|
||||
"httpx>=0.27",
|
||||
"typing-extensions>=4.0.0;python_version<'3.11'",
|
||||
],
|
||||
extras_require={
|
||||
|
|
@ -359,4 +360,4 @@ setup(
|
|||
"Programming Language :: Python :: 3.12",
|
||||
],
|
||||
keywords="mcp model-context-protocol ai llm tools json-rpc",
|
||||
)
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in New Issue