290 lines
9.6 KiB
Python
290 lines
9.6 KiB
Python
"""
|
|
Minimal Streamable HTTP transport support for MCP Browser.
|
|
|
|
Provides a lightweight client that can communicate with MCP servers using
|
|
the streamable-http transport defined by the MCP specification.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
import time
|
|
from typing import Any, Awaitable, Callable, Dict, Optional
|
|
|
|
import httpx
|
|
|
|
from .config import OAuthClientCredentialsConfig
|
|
from .logging_config import get_logger
|
|
|
|
# Names of the HTTP headers this module reads from responses and/or echoes
# back on later requests.  They are kept lowercase.
MCP_SESSION_HEADER = "mcp-session-id"
MCP_PROTOCOL_VERSION_HEADER = "mcp-protocol-version"
LAST_EVENT_ID_HEADER = "last-event-id"
|
|
|
|
|
|
class StreamableHTTPError(Exception):
    """Base error for streamable HTTP transport failures."""
|
|
|
|
|
|
class StreamableHTTPClient:
    """Minimal client for the MCP streamable-http transport.

    Each call to :meth:`send` POSTs one JSON-RPC payload to ``url`` and hands
    every JSON object found in the reply (a plain JSON body or a
    ``text/event-stream`` body) to a caller-supplied coroutine.  The client
    remembers the session id, protocol version and last SSE event id it has
    seen and echoes them as headers on later requests.  When an OAuth
    client-credentials configuration is supplied, a bearer token is obtained
    and attached automatically.
    """

    # Seconds before the recorded expiry at which a token is treated as stale.
    _TOKEN_EXPIRY_LEEWAY_SECONDS = 30

    def __init__(
        self,
        url: str,
        *,
        headers: Optional[Dict[str, str]] = None,
        timeout: float = 30.0,
        sse_timeout: float = 300.0,
        oauth_config: Optional[OAuthClientCredentialsConfig] = None,
        logger=None,
    ) -> None:
        """Store the configuration; no network activity happens here.

        Args:
            url: Endpoint every request is POSTed to.
            headers: Extra headers merged into each request.  The mapping is
                copied, so later changes by the caller have no effect.
            timeout: Timeout handed to the ``httpx.AsyncClient`` in ``start``.
            sse_timeout: Stored on the instance; none of the methods of this
                class currently consult it.
            oauth_config: Optional client-credentials settings.  A
                pre-supplied ``access_token`` seeds the bearer token.
            logger: Logger to use; defaults to ``get_logger(__name__)``.
        """
        self.url = url
        self.headers = dict(headers) if headers else {}
        self.timeout = timeout
        self.sse_timeout = sse_timeout
        self.oauth_config = oauth_config
        self.logger = logger or get_logger(__name__)

        self._client: Optional[httpx.AsyncClient] = None
        self._session_id: Optional[str] = None
        self._protocol_version: Optional[str] = None
        self._last_event_id: Optional[str] = None

        self._token: Optional[str] = oauth_config.access_token if oauth_config else None
        self._token_expires_at: Optional[float] = None
        self._token_lock = asyncio.Lock()

    async def start(self) -> None:
        """Initialise underlying HTTP client.

        Calling this on an already-started client is a no-op.  When a token
        endpoint is configured, a token is fetched eagerly.
        """
        if self._client is not None:
            return

        self._client = httpx.AsyncClient(timeout=self.timeout)
        if self.oauth_config and self.oauth_config.token_url:
            await self._ensure_token()

    async def stop(self) -> None:
        """Shutdown HTTP client."""
        # Drop the reference first so the instance never keeps pointing at a
        # client whose ``aclose()`` raised part-way through.
        client, self._client = self._client, None
        if client is not None:
            await client.aclose()

    async def send(
        self,
        request: Dict[str, Any],
        message_callback: Callable[[Dict[str, Any]], Awaitable[None]],
    ) -> None:
        """Send a JSON-RPC request and stream responses to callback.

        Args:
            request: JSON-serialisable payload sent as the POST body.
            message_callback: Coroutine awaited once per JSON object found in
                the response.

        Raises:
            StreamableHTTPError: If the client was not started, the request
                fails at the HTTP layer, the server answers with an error
                status, or a JSON body cannot be parsed.
        """
        if not self._client:
            raise StreamableHTTPError("Streamable HTTP client not started")

        await self._ensure_token()

        headers = self._build_headers()
        # NOTE(review): this looks like a buffered request, so SSE events are
        # presumably only parsed once the server has closed the stream, and
        # ``sse_timeout`` is never applied.  Confirm whether
        # ``self._client.stream("POST", ...)`` was intended here.
        try:
            response = await self._client.post(self.url, json=request, headers=headers)
        except httpx.HTTPError as exc:
            raise StreamableHTTPError(f"Streamable HTTP request failed: {exc}") from exc

        # Recorded before the status check, so an error response can still
        # update the session id.
        self._store_session_headers(response)

        try:
            response.raise_for_status()
        except httpx.HTTPStatusError as exc:
            body = await self._safe_read_text(response)
            message = body or exc.response.text
            raise StreamableHTTPError(f"HTTP {exc.response.status_code}: {message}") from exc

        # Media types are case-insensitive, so compare in lowercase.
        content_type = response.headers.get("content-type", "").lower()
        if "text/event-stream" in content_type:
            await self._consume_sse(response, message_callback)
        else:
            await self._consume_json(response, message_callback)

    def _build_headers(self) -> Dict[str, str]:
        """Return the headers for the next request.

        Starts from default ``accept``/``content-type`` values overlaid with
        the user-supplied headers, then adds the session id, protocol
        version, last event id and bearer token when they are known.  A
        header the caller already supplied (compared case-insensitively) is
        never overwritten.
        """
        headers = {
            "accept": "application/json, text/event-stream",
            "content-type": "application/json",
            **self.headers,
        }

        # HTTP header names are case-insensitive; use one lowercase view for
        # every "did the caller already set this?" check.
        present = {name.lower() for name in headers}

        optional_headers = (
            (MCP_SESSION_HEADER, self._session_id),
            (MCP_PROTOCOL_VERSION_HEADER, self._protocol_version),
            (LAST_EVENT_ID_HEADER, self._last_event_id),
        )
        for name, value in optional_headers:
            if value and name.lower() not in present:
                headers[name] = value

        if self._token and "authorization" not in present:
            headers["Authorization"] = f"Bearer {self._token}"

        return headers

    async def _ensure_token(self) -> None:
        """Make sure ``self._token`` is usable, refreshing it if necessary.

        Does nothing without an OAuth configuration.  A static
        ``access_token`` with no ``token_url`` is used as-is.  Without both
        ``token_url`` and ``client_id`` no refresh is attempted.
        """
        config = self.oauth_config
        if not config:
            return
        if config.access_token and not config.token_url:
            self._token = config.access_token
            return
        if not config.token_url or not config.client_id:
            return

        async with self._token_lock:
            # A token without a recorded expiry falls through to a refresh on
            # every call, as does one inside the leeway window.
            if self._token and self._token_expires_at:
                deadline = self._token_expires_at - self._TOKEN_EXPIRY_LEEWAY_SECONDS
                if time.time() < deadline:
                    return
            await self._refresh_token()

    async def _refresh_token(self) -> None:
        """Fetch a new access token with the client-credentials grant.

        The client id and secret are passed as the request's ``auth`` pair
        when a secret is configured; otherwise only ``client_id`` is put in
        the form body.

        Raises:
            StreamableHTTPError: If the client is not started or OAuth is not
                configured, the token request fails, or the response does not
                contain a usable ``access_token``.
        """
        client = self._client
        config = self.oauth_config
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if client is None or config is None or not config.token_url:
            raise StreamableHTTPError(
                "Cannot refresh OAuth token: client not started or OAuth not configured"
            )

        data: Dict[str, Any] = {"grant_type": "client_credentials"}
        if config.scope:
            data["scope"] = config.scope
        if config.audience:
            data["audience"] = config.audience
        if config.extra_params:
            data.update(config.extra_params)

        auth = None
        if config.client_secret:
            auth = (config.client_id or "", config.client_secret)
        else:
            data["client_id"] = config.client_id

        # Transport failures are wrapped like the ones in ``send`` so callers
        # only have to handle StreamableHTTPError.
        try:
            response = await client.post(
                config.token_url,
                data=data,
                headers={"content-type": "application/x-www-form-urlencoded"},
                auth=auth,
            )
        except httpx.HTTPError as exc:
            raise StreamableHTTPError(f"OAuth token request failed: {exc}") from exc

        try:
            response.raise_for_status()
        except httpx.HTTPError as exc:
            body = await self._safe_read_text(response)
            raise StreamableHTTPError(
                f"Failed to refresh OAuth token: {body or response.text}"
            ) from exc

        try:
            token_data = response.json()
        except ValueError as exc:
            raise StreamableHTTPError("OAuth token response is not valid JSON") from exc

        access_token = token_data.get("access_token") if isinstance(token_data, dict) else None
        if not access_token:
            raise StreamableHTTPError("OAuth token response missing access_token")

        expires_at: Optional[float] = None
        expires_in = token_data.get("expires_in")
        if expires_in:
            try:
                expires_at = time.time() + float(expires_in)
            except (TypeError, ValueError):
                # Unusable lifetime: treat the expiry as unknown.
                expires_at = None

        self._token = access_token
        self._token_expires_at = expires_at

    async def _consume_json(
        self,
        response: httpx.Response,
        message_callback: Callable[[Dict[str, Any]], Awaitable[None]],
    ) -> None:
        """Parse a JSON body and deliver each JSON object to the callback.

        An empty body is ignored.  A top-level array is treated as a batch;
        members that are not JSON objects are skipped.

        Raises:
            StreamableHTTPError: If the body is not valid JSON.
        """
        raw = await response.aread()
        if not raw:
            return

        try:
            payload = json.loads(raw)
        except ValueError as exc:
            # ValueError covers json.JSONDecodeError and also the
            # UnicodeDecodeError raised for bytes that cannot be decoded.
            raise StreamableHTTPError(f"Invalid JSON response: {exc}") from exc

        messages = payload if isinstance(payload, list) else [payload]
        for message in messages:
            if isinstance(message, dict):
                self._update_protocol_version(message)
                await message_callback(message)

    async def _consume_sse(
        self,
        response: httpx.Response,
        message_callback: Callable[[Dict[str, Any]], Awaitable[None]],
    ) -> None:
        """Parse a ``text/event-stream`` body and deliver its events.

        Lines are accumulated into an event until a blank line, at which
        point the event is flushed.  Comment lines (leading ``:``) are
        skipped; only the ``data``, ``event`` and ``id`` fields are kept.  An
        event still pending when the stream ends is flushed as well.
        """
        event: Dict[str, Any] = {"data": []}
        async for raw_line in response.aiter_lines():
            # Defensive: tolerate line iterators that leave the terminator on.
            line = raw_line.rstrip("\r\n")

            if not line:
                await self._flush_event(event, message_callback)
                event = {"data": []}
                continue

            if line.startswith(":"):
                continue

            field, _, value = line.partition(":")
            # The SSE parsing rules remove exactly one leading space from the
            # value; further spaces belong to the value itself.
            if value.startswith(" "):
                value = value[1:]

            if field == "data":
                event["data"].append(value)
            elif field in ("event", "id"):
                event[field] = value
            # "retry" and unknown fields are ignored.

        await self._flush_event(event, message_callback)

    async def _flush_event(
        self,
        event: Dict[str, Any],
        message_callback: Callable[[Dict[str, Any]], Awaitable[None]],
    ) -> None:
        """Decode one accumulated SSE event and deliver its message(s).

        The ``data`` lines are joined with newlines and parsed as JSON.  A
        JSON object is delivered as one message; a JSON array is treated as a
        batch, as in ``_consume_json``.  Events without data, with unparsable
        data, or without any JSON object are dropped.  The event id is
        recorded only when at least one message is delivered.
        """
        data_lines = event.get("data") or []
        if not data_lines:
            return

        try:
            payload = json.loads("\n".join(data_lines))
        except json.JSONDecodeError:
            self.logger.warning("Failed to parse SSE event payload")
            return

        candidates = payload if isinstance(payload, list) else [payload]
        messages = [item for item in candidates if isinstance(item, dict)]
        if not messages:
            return

        if event.get("id"):
            self._last_event_id = event["id"]

        for message in messages:
            self._update_protocol_version(message)
            await message_callback(message)

    def _update_protocol_version(self, message: Dict[str, Any]) -> None:
        """Remember ``result.protocolVersion`` from a message, if present."""
        result = message.get("result")
        if isinstance(result, dict) and result.get("protocolVersion"):
            self._protocol_version = str(result["protocolVersion"])

    def _store_session_headers(self, response: httpx.Response) -> None:
        """Remember the session id header of a response, if it carries one."""
        session_id = response.headers.get(MCP_SESSION_HEADER)
        if session_id:
            self._session_id = session_id

    async def _safe_read_text(self, response: httpx.Response) -> Optional[str]:
        """Best-effort read of a response body as text.

        Used when building error messages, so it never raises: it returns
        ``None`` when the body cannot be read, and falls back to UTF-8 with
        replacement characters when the declared encoding does not work.
        """
        try:
            raw = await response.aread()
        except Exception:
            # Deliberately broad: a failed read must not mask the error that
            # is being reported.
            return None

        if isinstance(raw, str):
            return raw

        # Prefer ``encoding``; ``charset`` is kept as a fallback for
        # response-like objects that expose that name instead.
        try:
            charset = (
                getattr(response, "encoding", None)
                or getattr(response, "charset", None)
                or "utf-8"
            )
        except Exception:
            charset = "utf-8"

        try:
            return raw.decode(charset)
        except (LookupError, UnicodeError, TypeError):
            pass

        try:
            return raw.decode("utf-8", errors="replace")
        except Exception:
            return None
|