""" Minimal Streamable HTTP transport support for MCP Browser. Provides a lightweight client that can communicate with MCP servers using the streamable-http transport defined by the MCP specification. """ from __future__ import annotations import asyncio import json import time from typing import Any, Awaitable, Callable, Dict, Optional import httpx from .config import OAuthClientCredentialsConfig from .logging_config import get_logger MCP_SESSION_HEADER = "mcp-session-id" MCP_PROTOCOL_VERSION_HEADER = "mcp-protocol-version" LAST_EVENT_ID_HEADER = "last-event-id" class StreamableHTTPError(Exception): """Base error for streamable HTTP transport failures.""" class StreamableHTTPClient: """Minimal client for the MCP streamable-http transport.""" def __init__( self, url: str, *, headers: Optional[Dict[str, str]] = None, timeout: float = 30.0, sse_timeout: float = 300.0, oauth_config: Optional[OAuthClientCredentialsConfig] = None, logger=None, ) -> None: self.url = url self.headers = headers.copy() if headers else {} self.timeout = timeout self.sse_timeout = sse_timeout self.oauth_config = oauth_config self.logger = logger or get_logger(__name__) self._client: Optional[httpx.AsyncClient] = None self._session_id: Optional[str] = None self._protocol_version: Optional[str] = None self._last_event_id: Optional[str] = None self._token: Optional[str] = oauth_config.access_token if oauth_config else None self._token_expires_at: Optional[float] = None self._token_lock = asyncio.Lock() async def start(self) -> None: """Initialise underlying HTTP client.""" if self._client is not None: return self._client = httpx.AsyncClient(timeout=self.timeout) if self.oauth_config and self.oauth_config.token_url: await self._ensure_token() async def stop(self) -> None: """Shutdown HTTP client.""" if self._client: await self._client.aclose() self._client = None async def send( self, request: Dict[str, Any], message_callback: Callable[[Dict[str, Any]], Awaitable[None]], ) -> None: """Send a JSON-RPC request and stream responses to callback.""" if not self._client: raise StreamableHTTPError("Streamable HTTP client not started") await self._ensure_token() headers = self._build_headers() try: response = await self._client.post(self.url, json=request, headers=headers) except httpx.HTTPError as exc: raise StreamableHTTPError(f"Streamable HTTP request failed: {exc}") from exc self._store_session_headers(response) try: response.raise_for_status() except httpx.HTTPStatusError as exc: body = await self._safe_read_text(response) message = body or exc.response.text raise StreamableHTTPError(f"HTTP {exc.response.status_code}: {message}") from exc content_type = response.headers.get("content-type", "") if "text/event-stream" in content_type: await self._consume_sse(response, message_callback) else: await self._consume_json(response, message_callback) def _build_headers(self) -> Dict[str, str]: headers = { "accept": "application/json, text/event-stream", "content-type": "application/json", **self.headers, } if self._session_id and MCP_SESSION_HEADER not in headers: headers[MCP_SESSION_HEADER] = self._session_id if self._protocol_version and MCP_PROTOCOL_VERSION_HEADER not in headers: headers[MCP_PROTOCOL_VERSION_HEADER] = self._protocol_version if self._last_event_id and LAST_EVENT_ID_HEADER not in headers: headers[LAST_EVENT_ID_HEADER] = self._last_event_id if self._token and "authorization" not in {k.lower(): v for k, v in headers.items()}: headers["Authorization"] = f"Bearer {self._token}" return headers async def _ensure_token(self) -> None: if not self.oauth_config: return if self.oauth_config.access_token and not self.oauth_config.token_url: self._token = self.oauth_config.access_token return if not self.oauth_config.token_url or not self.oauth_config.client_id: return async with self._token_lock: if self._token and self._token_expires_at: if time.time() < self._token_expires_at - 30: return await self._refresh_token() async def _refresh_token(self) -> None: assert self._client is not None data = { "grant_type": "client_credentials", } if self.oauth_config.scope: data["scope"] = self.oauth_config.scope if self.oauth_config.audience: data["audience"] = self.oauth_config.audience if self.oauth_config.extra_params: data.update(self.oauth_config.extra_params) auth = None if self.oauth_config.client_secret: auth = (self.oauth_config.client_id or "", self.oauth_config.client_secret) else: data["client_id"] = self.oauth_config.client_id response = await self._client.post( self.oauth_config.token_url, # type: ignore[arg-type] data=data, headers={"content-type": "application/x-www-form-urlencoded"}, auth=auth, ) try: response.raise_for_status() except httpx.HTTPError as exc: body = await self._safe_read_text(response) raise StreamableHTTPError( f"Failed to refresh OAuth token: {body or response.text}" ) from exc token_data = response.json() access_token = token_data.get("access_token") if not access_token: raise StreamableHTTPError("OAuth token response missing access_token") self._token = access_token expires_in = token_data.get("expires_in") if expires_in: try: self._token_expires_at = time.time() + float(expires_in) except (TypeError, ValueError): self._token_expires_at = None else: self._token_expires_at = None async def _consume_json( self, response: httpx.Response, message_callback: Callable[[Dict[str, Any]], Awaitable[None]], ) -> None: text = await response.aread() if not text: return try: payload = json.loads(text) except json.JSONDecodeError as exc: raise StreamableHTTPError(f"Invalid JSON response: {exc}") from exc messages = payload if isinstance(payload, list) else [payload] for message in messages: if isinstance(message, dict): self._update_protocol_version(message) await message_callback(message) async def _consume_sse( self, response: httpx.Response, message_callback: Callable[[Dict[str, Any]], Awaitable[None]], ) -> None: event: Dict[str, Any] = {"data": []} async for line in response.aiter_lines(): if line == "": await self._flush_event(event, message_callback) event = {"data": []} continue if line.startswith(":"): continue field, _, raw_value = line.partition(":") value = raw_value.lstrip(" ") if field == "data": event.setdefault("data", []).append(value) elif field == "event": event["event"] = value elif field == "id": event["id"] = value elif field == "retry": continue await self._flush_event(event, message_callback) async def _flush_event( self, event: Dict[str, Any], message_callback: Callable[[Dict[str, Any]], Awaitable[None]], ) -> None: data_lines = event.get("data") or [] if not data_lines: return data = "\n".join(data_lines) try: message = json.loads(data) except json.JSONDecodeError: self.logger.warning("Failed to parse SSE event payload") return if not isinstance(message, dict): return if event.get("id"): self._last_event_id = event["id"] self._update_protocol_version(message) await message_callback(message) def _update_protocol_version(self, message: Dict[str, Any]) -> None: result = message.get("result") if isinstance(result, dict) and result.get("protocolVersion"): self._protocol_version = str(result["protocolVersion"]) def _store_session_headers(self, response: httpx.Response) -> None: session_id = response.headers.get(MCP_SESSION_HEADER) if session_id: self._session_id = session_id async def _safe_read_text(self, response: httpx.Response) -> Optional[str]: try: raw = await response.aread() except Exception: return None if isinstance(raw, str): return raw try: charset = response.charset or "utf-8" except Exception: charset = "utf-8" try: return raw.decode(charset) except Exception: try: return raw.decode("utf-8", errors="replace") except Exception: return None