mcp-browser/mcp_browser/streamable_http.py

290 lines
9.6 KiB
Python

"""
Minimal Streamable HTTP transport support for MCP Browser.
Provides a lightweight client that can communicate with MCP servers using
the streamable-http transport defined by the MCP specification.
"""
from __future__ import annotations

import asyncio
import json
import time
from typing import Any, Awaitable, Callable, Dict, Optional, Tuple

import httpx

from .config import OAuthClientCredentialsConfig
from .logging_config import get_logger
# Header names used by the MCP streamable-http transport.  HTTP header names
# are case-insensitive; the lowercase spelling is used throughout this module.
MCP_SESSION_HEADER = "mcp-session-id"  # server-assigned session identifier
MCP_PROTOCOL_VERSION_HEADER = "mcp-protocol-version"  # negotiated protocol version
LAST_EVENT_ID_HEADER = "last-event-id"  # id of the last SSE event seen
class StreamableHTTPError(Exception):
    """Raised when the streamable HTTP transport cannot complete an operation."""
class StreamableHTTPClient:
    """Minimal client for the MCP streamable-http transport.

    Each :meth:`send` POSTs one JSON-RPC payload to ``url`` and forwards every
    JSON-RPC message found in the reply (a plain JSON body or a
    ``text/event-stream``) to an async callback. The session id, the
    negotiated protocol version and the last SSE event id are remembered and
    echoed back as headers on later requests. Bearer tokens can be supplied
    up front or obtained with the OAuth 2.0 client-credentials grant.
    """

    # Seconds before the declared expiry at which a fetched token is renewed.
    _TOKEN_REFRESH_MARGIN = 30.0

    def __init__(
        self,
        url: str,
        *,
        headers: Optional[Dict[str, str]] = None,
        timeout: float = 30.0,
        sse_timeout: float = 300.0,
        oauth_config: Optional[OAuthClientCredentialsConfig] = None,
        logger=None,
    ) -> None:
        """Store configuration; no network I/O happens until :meth:`start`.

        Args:
            url: MCP endpoint that receives the JSON-RPC POST requests.
            headers: Extra headers sent with every request (a copy is kept).
                They override the default ``accept``/``content-type`` headers.
            timeout: Default HTTP timeout in seconds; also bounds token requests.
            sse_timeout: Read timeout in seconds for MCP responses (never lower
                than ``timeout``). It covers idle gaps between SSE events.
            oauth_config: Optional OAuth settings. Its ``access_token``, if any,
                seeds the bearer token.
            logger: Logger to use; defaults to this module's logger.
        """
        self.url = url
        self.headers = dict(headers) if headers else {}
        self.timeout = timeout
        self.sse_timeout = sse_timeout
        self.oauth_config = oauth_config
        self.logger = logger or get_logger(__name__)
        self._client: Optional[httpx.AsyncClient] = None
        self._session_id: Optional[str] = None
        self._protocol_version: Optional[str] = None
        self._last_event_id: Optional[str] = None
        self._token: Optional[str] = oauth_config.access_token if oauth_config else None
        self._token_expires_at: Optional[float] = None
        # True once ``_token`` was issued by the token endpoint, as opposed to
        # the value seeded from ``oauth_config.access_token``.
        self._token_fetched = False
        self._token_lock = asyncio.Lock()

    async def start(self) -> None:
        """Create the HTTP client and fetch an initial OAuth token if configured.

        Calling this again while started is a no-op. If the initial token
        fetch fails, the client is closed again and the error propagates.
        """
        if self._client is not None:
            return
        self._client = httpx.AsyncClient(timeout=self.timeout)
        if self.oauth_config and self.oauth_config.token_url:
            try:
                await self._ensure_token()
            except BaseException:
                # Also covers cancellation: never leak the open client.
                await self.stop()
                raise

    async def stop(self) -> None:
        """Close the HTTP client if one is open."""
        client, self._client = self._client, None
        if client is not None:
            await client.aclose()

    async def send(
        self,
        request: Dict[str, Any],
        message_callback: Callable[[Dict[str, Any]], Awaitable[None]],
    ) -> None:
        """POST a JSON-RPC payload and forward each reply message to the callback.

        The response is streamed, so SSE messages reach ``message_callback``
        as they arrive instead of after the server closes the stream.

        Raises:
            StreamableHTTPError: if the client is not started, the request or
                the response stream fails, the server answers with an error
                status, or a JSON body cannot be parsed.
        """
        if self._client is None:
            raise StreamableHTTPError("Streamable HTTP client not started")
        await self._ensure_token()
        headers = self._build_headers()
        # The content type is only known once headers arrive, so the read
        # timeout must already tolerate idle gaps between SSE events.
        timeout = httpx.Timeout(self.timeout, read=max(self.timeout, self.sse_timeout))
        try:
            async with self._client.stream(
                "POST", self.url, json=request, headers=headers, timeout=timeout
            ) as response:
                await self._handle_response(response, message_callback)
        except httpx.HTTPError as exc:
            raise StreamableHTTPError(f"Streamable HTTP request failed: {exc}") from exc

    async def _handle_response(
        self,
        response: httpx.Response,
        message_callback: Callable[[Dict[str, Any]], Awaitable[None]],
    ) -> None:
        """Check the status of a streamed response and dispatch its body."""
        self._store_session_headers(response)
        if response.status_code == 401:
            # The cached token was rejected; make the next request refresh it.
            self._invalidate_token()
        try:
            response.raise_for_status()
        except httpx.HTTPStatusError as exc:
            body = await self._safe_read_text(response)
            # The body of a streamed response may be unreadable here, so fall
            # back to the reason phrase rather than ``response.text``.
            message = body or exc.response.reason_phrase
            raise StreamableHTTPError(f"HTTP {exc.response.status_code}: {message}") from exc
        content_type = response.headers.get("content-type", "").lower()
        if "text/event-stream" in content_type:
            await self._consume_sse(response, message_callback)
        else:
            await self._consume_json(response, message_callback)

    def _build_headers(self) -> Dict[str, str]:
        """Assemble request headers: defaults, user headers, then session state.

        Header names are compared case-insensitively, so a user-supplied
        header always wins over the value this client would add.
        """
        defaults = {
            "accept": "application/json, text/event-stream",
            "content-type": "application/json",
        }
        user_names = {name.lower() for name in self.headers}
        headers = {name: value for name, value in defaults.items() if name not in user_names}
        headers.update(self.headers)
        present = {name.lower() for name in headers}
        session_state = (
            (MCP_SESSION_HEADER, self._session_id),
            (MCP_PROTOCOL_VERSION_HEADER, self._protocol_version),
            (LAST_EVENT_ID_HEADER, self._last_event_id),
        )
        for name, value in session_state:
            if value and name not in present:
                headers[name] = value
        if self._token and "authorization" not in present:
            headers["Authorization"] = f"Bearer {self._token}"
        return headers

    async def _ensure_token(self) -> None:
        """Make ``_token`` usable: a static token as-is, otherwise fetch/refresh."""
        config = self.oauth_config
        if not config:
            return
        if config.access_token and not config.token_url:
            self._token = config.access_token
            return
        if not config.token_url or not config.client_id:
            return
        async with self._token_lock:
            # Re-check under the lock: a concurrent caller may have refreshed.
            if self._token_is_fresh():
                return
            await self._refresh_token()

    def _token_is_fresh(self) -> bool:
        """Return True if an endpoint-issued token exists and is not near expiry."""
        if not self._token or not self._token_fetched:
            return False
        if self._token_expires_at is None:
            # No lifetime was declared: reuse the token until a 401 invalidates it.
            return True
        return time.time() < self._token_expires_at - self._TOKEN_REFRESH_MARGIN

    def _invalidate_token(self) -> None:
        """Mark an endpoint-issued token as stale so the next request refreshes it."""
        self._token_fetched = False
        self._token_expires_at = None

    async def _refresh_token(self) -> None:
        """Obtain a new access token with the OAuth client-credentials grant.

        The client id/secret go in HTTP Basic auth when a secret is set;
        otherwise the client id is sent in the form body.

        Raises:
            StreamableHTTPError: if the client is not started, the token
                endpoint is unreachable or returns an error, or the response
                is not a JSON object that contains ``access_token``.
        """
        config = self.oauth_config
        if self._client is None:
            raise StreamableHTTPError("Streamable HTTP client not started")
        if config is None or not config.token_url:
            raise StreamableHTTPError("OAuth token_url is not configured")
        data: Dict[str, Any] = {"grant_type": "client_credentials"}
        if config.scope:
            data["scope"] = config.scope
        if config.audience:
            data["audience"] = config.audience
        if config.extra_params:
            data.update(config.extra_params)
        auth = None
        if config.client_secret:
            auth = (config.client_id or "", config.client_secret)
        else:
            data["client_id"] = config.client_id
        try:
            response = await self._client.post(
                config.token_url,
                data=data,
                headers={"content-type": "application/x-www-form-urlencoded"},
                auth=auth,
            )
        except httpx.HTTPError as exc:
            raise StreamableHTTPError(f"Failed to refresh OAuth token: {exc}") from exc
        try:
            response.raise_for_status()
        except httpx.HTTPStatusError as exc:
            body = await self._safe_read_text(response)
            raise StreamableHTTPError(
                f"Failed to refresh OAuth token: {body or response.text}"
            ) from exc
        try:
            token_data = response.json()
        except ValueError as exc:
            raise StreamableHTTPError("OAuth token response is not valid JSON") from exc
        access_token = token_data.get("access_token") if isinstance(token_data, dict) else None
        if not access_token:
            raise StreamableHTTPError("OAuth token response missing access_token")
        self._token = access_token
        self._token_fetched = True
        self._token_expires_at = None
        expires_in = token_data.get("expires_in")
        if expires_in is not None:
            try:
                self._token_expires_at = time.time() + float(expires_in)
            except (TypeError, ValueError):
                self._token_expires_at = None

    async def _consume_json(
        self,
        response: httpx.Response,
        message_callback: Callable[[Dict[str, Any]], Awaitable[None]],
    ) -> None:
        """Parse a JSON body (single message or batch) and dispatch it.

        An empty or whitespace-only body (e.g. 202 Accepted) is ignored.

        Raises:
            StreamableHTTPError: if the body is not valid JSON or not UTF-8.
        """
        body = await response.aread()
        if not body.strip():
            return
        try:
            payload = json.loads(body)
        except ValueError as exc:  # JSONDecodeError and UnicodeDecodeError
            raise StreamableHTTPError(f"Invalid JSON response: {exc}") from exc
        await self._dispatch_payload(payload, message_callback)

    async def _dispatch_payload(
        self,
        payload: Any,
        message_callback: Callable[[Dict[str, Any]], Awaitable[None]],
    ) -> None:
        """Forward a decoded message, or each object in a batch, to the callback.

        Non-object entries are skipped.
        """
        messages = payload if isinstance(payload, list) else [payload]
        for message in messages:
            if isinstance(message, dict):
                self._update_protocol_version(message)
                await message_callback(message)

    @staticmethod
    def _parse_sse_line(line: str) -> Tuple[str, str]:
        """Split an SSE line into (field, value), dropping one leading space."""
        field, _, value = line.partition(":")
        if value.startswith(" "):
            value = value[1:]
        return field, value

    async def _consume_sse(
        self,
        response: httpx.Response,
        message_callback: Callable[[Dict[str, Any]], Awaitable[None]],
    ) -> None:
        """Parse a ``text/event-stream`` body and dispatch each event's JSON."""
        event: Dict[str, Any] = {"data": []}
        async for line in response.aiter_lines():
            if not line:
                # A blank line terminates the current event.
                await self._flush_event(event, message_callback)
                event = {"data": []}
                continue
            if line.startswith(":"):
                continue  # comment / keep-alive
            field, value = self._parse_sse_line(line)
            if field == "data":
                event["data"].append(value)
            elif field == "event":
                event["event"] = value
            elif field == "id" and "\0" not in value:
                event["id"] = value
            # "retry" and unknown fields are ignored.
        # Be lenient: dispatch a trailing event that lacks a final blank line.
        await self._flush_event(event, message_callback)

    async def _flush_event(
        self,
        event: Dict[str, Any],
        message_callback: Callable[[Dict[str, Any]], Awaitable[None]],
    ) -> None:
        """Record the event id, then decode and dispatch the event's data."""
        # Per the SSE spec the id is tracked even for events without data
        # (e.g. priming events), so resumption can refer to it.
        event_id = event.get("id")
        if event_id:
            self._last_event_id = event_id
        data = "\n".join(event.get("data") or [])
        if not data.strip():
            return
        try:
            payload = json.loads(data)
        except json.JSONDecodeError:
            self.logger.warning("Failed to parse SSE event payload")
            return
        await self._dispatch_payload(payload, message_callback)

    def _update_protocol_version(self, message: Dict[str, Any]) -> None:
        """Remember ``result.protocolVersion`` when a message carries one."""
        result = message.get("result")
        if isinstance(result, dict) and result.get("protocolVersion"):
            self._protocol_version = str(result["protocolVersion"])

    def _store_session_headers(self, response: httpx.Response) -> None:
        """Remember the session id header returned by the server, if any."""
        session_id = response.headers.get(MCP_SESSION_HEADER)
        if session_id:
            self._session_id = session_id

    async def _safe_read_text(self, response: httpx.Response) -> Optional[str]:
        """Best-effort read of a response body as text, for error messages.

        Returns None if the body cannot be read. The declared charset is
        used when it is known; otherwise the body is decoded as UTF-8 with
        replacement characters.
        """
        try:
            raw = await response.aread()
        except Exception:  # deliberate: error reporting must not raise
            return None
        if isinstance(raw, str):
            return raw
        charset = getattr(response, "charset_encoding", None) or "utf-8"
        try:
            return raw.decode(charset)
        except (LookupError, UnicodeDecodeError):
            return raw.decode("utf-8", errors="replace")