209 lines
7.8 KiB
Python
209 lines
7.8 KiB
Python
"""
|
|
Configuration management for MCP Browser.
|
|
|
|
Handles loading and validation of MCP server configurations,
|
|
supporting hierarchical config loading and runtime overrides.
|
|
"""
|
|
|
|
import copy
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, Any, Optional, List

import yaml

from .default_configs import ConfigManager
|
|
|
|
|
|
@dataclass
class OAuthClientCredentialsConfig:
    """Configuration for OAuth client credentials flow.

    All fields are optional so partially-specified config files can still be
    represented. The secret-bearing fields (``client_secret`` and
    ``access_token``) are excluded from the generated ``__repr__``. Printing or
    logging a config object therefore does not leak credentials. They still take
    part in ``__eq__`` and are accessible as normal attributes.
    """

    # Token endpoint URL (loader also accepts the ``tokenUrl`` spelling).
    token_url: Optional[str] = None
    client_id: Optional[str] = None
    # Secret credential: hidden from repr() to avoid leaking it into logs.
    client_secret: Optional[str] = field(default=None, repr=False)
    scope: Optional[str] = None
    # Loader also fills this from a ``resource`` key when ``audience`` is absent.
    audience: Optional[str] = None
    # Additional parameters; presumably sent with the token request — confirm
    # against the consumer of this config.
    extra_params: Dict[str, str] = field(default_factory=dict)
    # Pre-supplied token; hidden from repr() because it is a live credential.
    access_token: Optional[str] = field(default=None, repr=False)
|
|
|
|
|
|
@dataclass
class TransportConfig:
    """Transport configuration for connecting to an MCP server.

    Describes *how* to reach a server. ``type`` defaults to ``"stdio"``. When
    ConfigLoader builds this object, ``type`` is ``"stdio"``,
    ``"streamable-http"``, or any other lower-cased, hyphenated string taken
    from the config file.
    """

    # Transport kind identifier.
    type: str = field(default="stdio")
    # Endpoint URL for network transports; None when not configured.
    url: Optional[str] = field(default=None)
    # Extra request headers; a fresh dict per instance.
    headers: Dict[str, str] = field(default_factory=dict)
    # Timeout values; units presumably seconds — confirm with the transport code.
    timeout: float = field(default=30.0)
    sse_timeout: float = field(default=300.0)
    # Optional OAuth client-credentials settings; None disables OAuth.
    oauth: Optional[OAuthClientCredentialsConfig] = field(default=None)
|
|
|
|
|
|
@dataclass
class MCPServerConfig:
    """Configuration for a single MCP server.

    ``command``/``args``/``env`` describe a process to launch. ``transport``
    says how to talk to the server. When ConfigLoader builds this object,
    ``name`` falls back to the server's key in the ``servers`` mapping.
    """

    # Command line to launch the server (list of tokens), if any.
    command: Optional[List[str]] = field(default=None)
    # Additional arguments; a fresh list per instance.
    args: List[str] = field(default_factory=list)
    # Environment variables for the server; a fresh dict per instance.
    env: Dict[str, str] = field(default_factory=dict)
    # Human-readable identity.
    name: Optional[str] = field(default=None)
    description: Optional[str] = field(default=None)
    # Whether this server is active.
    enabled: bool = field(default=True)
    # Connection details; each instance gets its own default TransportConfig.
    transport: TransportConfig = field(default_factory=TransportConfig)
|
|
|
|
|
|
@dataclass
class MCPBrowserConfig:
    """Main configuration for MCP Browser.

    Top-level settings plus the collection of configured servers, keyed by
    the names used in the config file's ``servers`` mapping.
    """

    # Configured servers, keyed by server name; a fresh dict per instance.
    servers: Dict[str, MCPServerConfig] = field(default_factory=dict)
    # Presumably a key into ``servers`` — not validated here.
    default_server: Optional[str] = field(default=None)
    # Feature flags.
    sparse_mode: bool = field(default=True)
    debug: bool = field(default=False)
    # Buffer size; units presumably bytes — confirm with the consumer.
    buffer_size: int = field(default=65536)
    # Global timeout; units presumably seconds — confirm with the consumer.
    timeout: float = field(default=30.0)
    # Whether built-in servers are enabled in addition to configured ones.
    enable_builtin_servers: bool = field(default=True)
|
|
|
|
|
|
class ConfigLoader:
    """Loads and manages MCP Browser configuration.

    The optional YAML file at ``config_path`` is deep-merged over
    :attr:`DEFAULT_CONFIG`. The merged plain-dict data is then converted into
    the dataclasses defined in this module. The converted result is cached on
    the instance, so repeated :meth:`load` calls return the same object.
    """

    # Built-in defaults. Treated as read-only: load() always works on a deep
    # copy so merging a user file can never mutate this shared class attribute.
    DEFAULT_CONFIG = {
        "servers": {
            "default": {
                "command": ["npx", "-y", "@modelcontextprotocol/server-memory"],
                "name": "memory",
                "description": "Default in-memory MCP server"
            }
        },
        "default_server": "default",
        "sparse_mode": True,
        "debug": False,
        "enable_builtin_servers": True
    }

    def __init__(self, config_path: Optional[Path] = None):
        """Create a loader.

        Args:
            config_path: Config file to read. A ``str`` is accepted and
                converted to ``Path``. When falsy, the path comes from
                ``ConfigManager.get_config_path()``. Before that,
                ``ConfigManager.ensure_config_directory()`` is called
                (presumably so the default location exists).
        """
        self.config_manager = ConfigManager()

        if config_path:
            # Normalise to Path: load() relies on Path.exists().
            self.config_path = Path(config_path)
        else:
            self.config_manager.ensure_config_directory()
            self.config_path = self.config_manager.get_config_path()

        # Cache populated by the first successful load().
        self._config: Optional[MCPBrowserConfig] = None

    def load(self) -> MCPBrowserConfig:
        """Load configuration from file, merged over the defaults.

        The result is cached; later calls return the same object without
        re-reading the file.

        Returns:
            The fully populated MCPBrowserConfig.

        Raises:
            ValueError: If the YAML file's top level is not a mapping, or a
                server entry is not a mapping.
            OSError: Propagated from opening the config file.
            Errors raised by ``yaml.safe_load`` (e.g. malformed YAML)
            propagate unchanged.
        """
        if self._config is not None:
            return self._config

        # Deep copy is essential: _merge_configs mutates nested dicts in place.
        # A shallow copy would write user values into the class-level defaults.
        config_data = copy.deepcopy(self.DEFAULT_CONFIG)

        if self.config_path and self.config_path.exists():
            with open(self.config_path, encoding="utf-8") as f:
                file_config = yaml.safe_load(f)
            if file_config:
                if not isinstance(file_config, dict):
                    raise ValueError(
                        f"Config file {self.config_path} must contain a mapping at "
                        f"the top level, got {type(file_config).__name__}"
                    )
                self._merge_configs(config_data, file_config)

        # `or {}` tolerates an explicit `servers: null` in the YAML file.
        raw_servers = config_data.get("servers") or {}
        servers = {
            name: self._build_server_config(name, server_config)
            for name, server_config in raw_servers.items()
        }

        self._config = MCPBrowserConfig(
            servers=servers,
            default_server=config_data.get("default_server"),
            sparse_mode=config_data.get("sparse_mode", True),
            debug=config_data.get("debug", False),
            buffer_size=config_data.get("buffer_size", 65536),
            timeout=config_data.get("timeout", 30.0),
            enable_builtin_servers=config_data.get("enable_builtin_servers", True)
        )

        return self._config

    def _build_server_config(self, name: str, server_config: Any) -> MCPServerConfig:
        """Convert one raw ``servers`` entry into an MCPServerConfig.

        Raises:
            ValueError: If ``server_config`` is not a mapping.
        """
        if not isinstance(server_config, dict):
            raise ValueError(
                f"Configuration for server {name!r} must be a mapping, "
                f"got {type(server_config).__name__}"
            )
        return MCPServerConfig(
            command=server_config.get("command"),
            # `or` guards against explicit nulls, which would otherwise
            # store None where a list/dict is expected.
            args=server_config.get("args") or [],
            env=server_config.get("env") or {},
            name=server_config.get("name", name),
            description=server_config.get("description"),
            enabled=server_config.get("enabled", True),
            transport=self._build_transport_config(server_config),
        )

    def _build_transport_config(self, server_config: Dict[str, Any]) -> TransportConfig:
        """Build a TransportConfig from ``transport`` plus legacy top-level keys.

        Values under the nested ``transport`` mapping take precedence over
        legacy keys on the server entry itself. Because lookups are chained with
        ``or``, falsy values (e.g. a timeout of 0) fall through to the next
        source or the default.
        """
        transport_data = server_config.get("transport") or {}

        raw_transport_type = (
            transport_data.get("type")
            or server_config.get("transportType")
            or server_config.get("type")
        )

        transport_url = transport_data.get("url") or server_config.get("url")
        if not transport_url:
            legacy_url = self._url_from_legacy_args(server_config.get("args", []))
            if legacy_url is not None:
                transport_url = legacy_url

        headers = transport_data.get("headers") or server_config.get("headers") or {}
        timeout = (
            transport_data.get("timeout")
            or server_config.get("timeout")
            or 30.0
        )
        sse_timeout = (
            transport_data.get("sse_timeout")
            or transport_data.get("sseTimeout")
            or server_config.get("sse_timeout")
            or server_config.get("sseTimeout")
            or 300.0
        )

        return TransportConfig(
            type=self._normalize_transport_type(raw_transport_type),
            url=transport_url,
            headers=headers,
            timeout=timeout,
            sse_timeout=sse_timeout,
            oauth=self._build_oauth_config(transport_data.get("oauth")),
        )

    @staticmethod
    def _normalize_transport_type(raw_transport_type: Any) -> str:
        """Map user-supplied transport spellings onto canonical names.

        Non-string/missing values mean ``"stdio"``. Underscores become hyphens
        and case is folded. ``streamable-http``/``streamablehttp`` map to
        ``"streamable-http"``, and ``stdio``/``standard`` map to ``"stdio"``.
        Any other string is returned in its normalised form.
        """
        if not isinstance(raw_transport_type, str):
            return "stdio"
        lowered = raw_transport_type.replace("_", "-").lower()
        if lowered in {"streamable-http", "streamablehttp"}:
            return "streamable-http"
        if lowered in {"stdio", "standard"}:
            return "stdio"
        return lowered

    @staticmethod
    def _url_from_legacy_args(args: Any) -> Optional[Any]:
        """Extract a URL from legacy CLI-style ``args``, or return None.

        Recognises ``["url", X]`` / ``["--url", X]`` / ``["--transport-url", X]``
        (returning ``X`` as-is) and ``"url=X"`` / ``"--url=X"`` /
        ``"--transport-url=X"`` (returning the text after the first ``=``).
        Non-list ``args`` and non-string elements are ignored.
        """
        if not isinstance(args, list):
            return None
        for idx, arg in enumerate(args):
            if not isinstance(arg, str):
                continue
            stripped = arg.strip()
            if stripped in {"url", "--url", "--transport-url"}:
                if idx + 1 < len(args):
                    return args[idx + 1]
            elif stripped.startswith(("url=", "--url=", "--transport-url=")):
                return stripped.split("=", 1)[1]
        return None

    @staticmethod
    def _build_oauth_config(oauth_data: Any) -> Optional[OAuthClientCredentialsConfig]:
        """Build OAuth settings, accepting snake_case and camelCase keys.

        Returns None when ``oauth_data`` is missing or empty.
        """
        if not oauth_data:
            return None
        return OAuthClientCredentialsConfig(
            token_url=oauth_data.get("token_url") or oauth_data.get("tokenUrl"),
            client_id=oauth_data.get("client_id") or oauth_data.get("clientId"),
            client_secret=oauth_data.get("client_secret") or oauth_data.get("clientSecret"),
            scope=oauth_data.get("scope"),
            audience=oauth_data.get("audience") or oauth_data.get("resource"),
            # Trailing `or {}` so an explicit null still yields a dict.
            extra_params=oauth_data.get("extra_params") or oauth_data.get("extraParams") or {},
            access_token=oauth_data.get("access_token") or oauth_data.get("accessToken"),
        )

    def _merge_configs(self, base: Dict[str, Any], override: Dict[str, Any]):
        """Recursively merge ``override`` into ``base``, mutating ``base``.

        When both sides hold a dict for the same key, the dicts are merged
        recursively. Otherwise the override value replaces the base value
        (by reference, not copied).
        """
        for key, value in override.items():
            if key in base and isinstance(base[key], dict) and isinstance(value, dict):
                self._merge_configs(base[key], value)
            else:
                base[key] = value
|