Files
kanboard/mcp_wrapper.py

359 lines
13 KiB
Python

"""Helpers to build and manage MCPTools instances from JSON configuration."""
from __future__ import annotations
import json
import shlex
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
from agno.tools.mcp import MCPTools
class MCPConfigError(ValueError):
    """Signal that an MCP configuration source is malformed or inconsistent.

    Derives from ``ValueError`` so code that already catches ``ValueError``
    also catches configuration problems raised by this module.
    """
@dataclass
class MCPServerDefinition:
    """Normalized representation of one entry under ``mcpServers``.

    Build instances with :meth:`from_dict`, which validates and coerces the
    raw JSON values. Convert them into ``MCPTools`` keyword arguments with
    :meth:`to_kwargs`.
    """

    # Key of the entry under ``mcpServers``; also used as the server id.
    name: str
    # Executable to launch; required when the transport is "stdio".
    command: Optional[str] = None
    # Extra command-line arguments placed after ``command``.
    args: List[str] = field(default_factory=list)
    # Transport name forwarded to MCPTools.
    transport: str = "stdio"
    # Server URL, forwarded only when set.
    url: Optional[str] = None
    # Environment variables, forwarded only when non-empty.
    env: Dict[str, str] = field(default_factory=dict)
    # Tool allow-list; ``None`` omits it from to_kwargs(), ``[]`` is forwarded as-is.
    include_tools: Optional[List[str]] = None
    # Tool deny-list; ``None`` omits it from to_kwargs().
    exclude_tools: Optional[List[str]] = None
    # Forwarded as ``timeout_seconds`` when set.
    timeout_seconds: Optional[int] = None
    # Forwarded as ``refresh_connection`` when set.
    refresh_connection: Optional[bool] = None
    # MCPServerManager.build_tools skips disabled servers unless asked not to.
    enabled: bool = True

    @staticmethod
    def _coerce_bool(value: Any) -> bool:
        """Interpret *value* as a boolean, understanding common string spellings.

        ``bool("false")`` is ``True`` in Python, so string flags such as
        ``"false"``, ``"0"``, ``"no"`` and ``"off"`` are mapped explicitly.
        Every other value falls back to ``bool(value)``.
        """
        if isinstance(value, str):
            token = value.strip().lower()
            if token in ("false", "0", "no", "off"):
                return False
            if token in ("true", "1", "yes", "on"):
                return True
        return bool(value)

    @staticmethod
    def _optional_str_list(name: str, key: str, value: Any) -> Optional[List[str]]:
        """Validate an optional array field and stringify its items.

        ``None`` stays ``None``. An empty list stays an empty list, so an
        explicit ``[]`` is forwarded rather than dropped.

        Raises:
            MCPConfigError: If *value* is neither ``None`` nor a list.
        """
        if value is None:
            return None
        if not isinstance(value, list):
            raise MCPConfigError(
                f"The '{key}' field for MCP server '{name}' must be an array."
            )
        return [str(item) for item in value]

    @classmethod
    def from_dict(cls, name: str, data: Dict[str, Any]) -> "MCPServerDefinition":
        """Build a definition from one raw ``mcpServers`` entry.

        Args:
            name: Key of the entry under ``mcpServers``.
            data: The entry's JSON object.

        Returns:
            A normalized :class:`MCPServerDefinition`.

        Raises:
            MCPConfigError: If ``data`` is not an object or a field has an
                invalid type or value.
        """
        if not isinstance(data, dict):
            raise MCPConfigError(
                f"Configuration for MCP server '{name}' must be a JSON object."
            )

        command = data.get("command")
        if command is not None and not isinstance(command, str):
            # A non-string command (e.g. a list) would otherwise be stringified
            # into a meaningless command line.
            raise MCPConfigError(
                f"The 'command' field for MCP server '{name}' must be a string."
            )

        args = data.get("args", [])
        if args is None:
            args = []
        if not isinstance(args, list):
            raise MCPConfigError(
                f"The 'args' field for MCP server '{name}' must be an array of strings."
            )

        env = data.get("env") or {}
        if not isinstance(env, dict):
            raise MCPConfigError(
                f"The 'env' field for MCP server '{name}' must be an object of key/value pairs."
            )

        include_tools = cls._optional_str_list(name, "include_tools", data.get("include_tools"))
        exclude_tools = cls._optional_str_list(name, "exclude_tools", data.get("exclude_tools"))

        timeout_seconds = data.get("timeout_seconds")
        if timeout_seconds is not None:
            try:
                timeout_seconds = int(timeout_seconds)
            except (TypeError, ValueError) as exc:
                raise MCPConfigError(
                    f"The 'timeout_seconds' field for MCP server '{name}' must be an integer."
                ) from exc

        refresh_connection = data.get("refresh_connection")
        if refresh_connection is not None:
            refresh_connection = cls._coerce_bool(refresh_connection)

        return cls(
            name=name,
            command=command,
            args=[str(arg) for arg in args],
            transport=str(data.get("transport") or "stdio"),
            url=data.get("url"),
            env={str(key): str(value) for key, value in env.items()},
            include_tools=include_tools,
            exclude_tools=exclude_tools,
            timeout_seconds=timeout_seconds,
            refresh_connection=refresh_connection,
            enabled=cls._coerce_bool(data.get("enabled", True)),
        )

    def to_kwargs(self) -> Dict[str, Any]:
        """Return the keyword arguments used to construct ``MCPTools``.

        Optional settings are included only when they are set.

        Raises:
            MCPConfigError: If the transport is ``"stdio"`` and no command is set.
        """
        transport = self.transport or "stdio"
        kwargs: Dict[str, Any] = {"transport": transport}
        if self.command:
            # MCPTools receives the whole command line as one string. shlex.join
            # quotes each part, so arguments containing spaces, or empty
            # arguments, survive a shlex-style split. (Presumably that is how
            # MCPTools splits the string; confirm.) Only ``None`` parts are
            # skipped; empty strings are legitimate argv entries.
            parts = [self.command, *self.args]
            kwargs["command"] = shlex.join([str(part) for part in parts if part is not None])
        elif transport == "stdio":
            raise MCPConfigError(
                f"MCP server '{self.name}' must define a 'command' when using stdio transport."
            )
        if self.url:
            kwargs["url"] = self.url
        if self.env:
            kwargs["env"] = self.env
        if self.include_tools is not None:
            kwargs["include_tools"] = self.include_tools
        if self.exclude_tools is not None:
            kwargs["exclude_tools"] = self.exclude_tools
        if self.timeout_seconds is not None:
            kwargs["timeout_seconds"] = self.timeout_seconds
        if self.refresh_connection is not None:
            kwargs["refresh_connection"] = self.refresh_connection
        return kwargs
class MCPServerManager:
    """Load MCP server definitions from JSON configuration and build MCPTools.

    The configuration root must be an object with an ``mcpServers`` object
    that maps server names to per-server settings. Each entry is parsed with
    ``MCPServerDefinition.from_dict``.
    """

    def __init__(
        self,
        config_source: Union[str, bytes, Dict[str, Any], Path],
        *,
        logger=None,
    ) -> None:
        """Parse *config_source* and build the server definitions eagerly.

        Args:
            config_source: A ``Path`` to a JSON file, JSON text as ``str`` or
                ``bytes``, or an already-parsed ``dict``.
            logger: Optional structured logger. Its ``info``, ``warning`` and
                ``error`` methods are called with an ``add_fields`` keyword.

        Raises:
            MCPConfigError: If the configuration is malformed.
        """
        self.logger = logger
        self._raw_config = self._load_config(config_source)
        self._servers = self._build_servers(self._raw_config)

    @staticmethod
    def _load_config(config_source: Union[str, bytes, Dict[str, Any], Path]) -> Dict[str, Any]:
        """Return the parsed configuration after validating its root shape.

        Raises:
            MCPConfigError: If the source type is unsupported, the bytes are
                not UTF-8, the JSON is invalid, or the root is malformed.
        """
        if isinstance(config_source, Path):
            # Read raw bytes so decoding errors are reported like the bytes case.
            config_source = config_source.read_bytes()
        if isinstance(config_source, bytes):
            try:
                config_source = config_source.decode("utf-8")
            except UnicodeDecodeError as exc:
                raise MCPConfigError(f"MCP configuration is not valid UTF-8: {exc}") from exc
        if isinstance(config_source, str):
            # A leading UTF-8 byte-order mark (written by some editors) makes
            # json.loads fail, so drop a single leading BOM.
            config_text = config_source[1:] if config_source.startswith("\ufeff") else config_source
            try:
                loaded = json.loads(config_text)
            except json.JSONDecodeError as exc:
                raise MCPConfigError(f"Invalid MCP configuration JSON: {exc}") from exc
            return MCPServerManager._validate_root(loaded)
        if isinstance(config_source, dict):
            return MCPServerManager._validate_root(config_source)
        raise MCPConfigError("Unsupported MCP configuration source type.")

    @staticmethod
    def _validate_root(config: Dict[str, Any]) -> Dict[str, Any]:
        """Ensure *config* is an object holding an ``mcpServers`` object; return it unchanged."""
        if not isinstance(config, dict):
            raise MCPConfigError("MCP configuration root must be a JSON object.")
        if "mcpServers" not in config:
            raise MCPConfigError("MCP configuration must include a 'mcpServers' object.")
        if not isinstance(config["mcpServers"], dict):
            raise MCPConfigError("The 'mcpServers' field must be a JSON object.")
        return config

    @staticmethod
    def _build_servers(config: Dict[str, Any]) -> List[MCPServerDefinition]:
        """Parse every ``mcpServers`` entry, in configuration order."""
        return [
            MCPServerDefinition.from_dict(str(name), data)
            for name, data in config.get("mcpServers", {}).items()
        ]

    @staticmethod
    def _normalize_targets(only_servers: Optional[Iterable[str]]) -> Optional[set]:
        """Turn the ``only_servers`` filter into a set of names, or ``None`` for "all".

        A bare string is treated as a single server name instead of being
        iterated character by character.
        """
        if not only_servers:
            return None
        if isinstance(only_servers, str):
            return {only_servers}
        return {str(name) for name in only_servers}

    def build_tools(
        self,
        only_servers: Optional[Iterable[str]] = None,
        *,
        include_disabled: bool = False,
    ) -> Tuple[List[MCPTools], List[str]]:
        """Instantiate ``MCPTools`` for the selected server definitions.

        Args:
            only_servers: Optional server name, or iterable of names, to
                restrict the build to. Falsy means every configured server.
            include_disabled: Also build servers whose ``enabled`` flag is false.

        Returns:
            A ``(tools, active_servers)`` pair. ``tools`` holds the built
            toolkits, each tagged with ``mcp_server_id`` and
            ``mcp_server_definition`` attributes. ``active_servers`` holds their
            names in the same order.

        Raises:
            Exception: Whatever building a toolkit raised, re-raised after being logged.
        """
        target = self._normalize_targets(only_servers)
        if target and self.logger:
            # A typo in a requested name would otherwise silently yield no tools.
            unknown = sorted(target.difference(self.server_names))
            if unknown:
                self.logger.warning(
                    "⚠️ Requested MCP servers not found in configuration",
                    add_fields={
                        "unknown_servers": unknown,
                        "agent_call": {
                            "action": "load_mcp_server",
                            "requested_servers": sorted(target),
                        },
                        "agent_response": {
                            "status": "not_found",
                            "unknown_servers": unknown,
                        },
                    },
                )

        tools: List[MCPTools] = []
        active_servers: List[str] = []
        for definition in self._servers:
            if target and definition.name not in target:
                continue
            if not definition.enabled and not include_disabled:
                if self.logger:
                    self.logger.info(
                        "⏭️ Skipping MCP server because it is disabled",
                        add_fields={
                            "mcp_server": definition.name,
                            "agent_call": {
                                "action": "load_mcp_server",
                                "mcp_server": definition.name,
                            },
                            "agent_response": {
                                "status": "skipped",
                                "reason": "disabled",
                            },
                        },
                    )
                continue
            try:
                tool = MCPTools(**definition.to_kwargs())
            except Exception as exc:
                if self.logger:
                    self.logger.error(
                        "🚨 Failed to configure MCP server",
                        add_fields={
                            "mcp_server": definition.name,
                            "error": str(exc),
                            "agent_call": {
                                "action": "configure_mcp_server",
                                "mcp_server": definition.name,
                            },
                            "agent_response": {
                                "status": "error",
                                "error": str(exc),
                            },
                        },
                    )
                raise
            # Tag the toolkit so later helpers can report which server it belongs to.
            tool.mcp_server_id = definition.name
            tool.mcp_server_definition = definition
            tools.append(tool)
            active_servers.append(definition.name)
            if self.logger:
                self.logger.info(
                    "⚙️ Configured MCP server",
                    add_fields={
                        "mcp_server": definition.name,
                        "transport": definition.transport,
                        "has_url": bool(definition.url),
                        "agent_call": {
                            "action": "configure_mcp_server",
                            "mcp_server": definition.name,
                        },
                        "agent_response": {
                            "status": "configured",
                            "transport": definition.transport,
                            "has_url": bool(definition.url),
                        },
                    },
                )
        return tools, active_servers

    @property
    def server_names(self) -> List[str]:
        """Names of all configured servers, enabled or not, in configuration order."""
        return [definition.name for definition in self._servers]
def load_mcp_tools(
    config_source: Union[str, bytes, Dict[str, Any], Path],
    *,
    logger=None,
    only_servers: Optional[Iterable[str]] = None,
    include_disabled: bool = False,
) -> Tuple[List[MCPTools], List[str]]:
    """Parse *config_source* and build its MCPTools in a single call.

    This is shorthand for creating an ``MCPServerManager`` and calling its
    ``build_tools`` method. The arguments are forwarded unchanged.

    Returns:
        The ``(tools, active_servers)`` pair produced by ``build_tools``.
    """
    return MCPServerManager(config_source, logger=logger).build_tools(
        only_servers,
        include_disabled=include_disabled,
    )
async def initialize_mcp_tools(
    tools: Iterable[MCPTools],
    *,
    logger=None,
) -> Dict[str, List[str]]:
    """Connect each toolkit in turn and collect the names of its tools.

    Args:
        tools: Toolkits to connect, e.g. those returned by ``load_mcp_tools``.
        logger: Optional structured logger. ``info`` is called after each
            successful connection and ``error`` when one fails.

    Returns:
        A dict mapping each server id to its sorted tool names. The server id
        is the toolkit's ``mcp_server_id`` attribute, or its ``name`` when
        that attribute is absent.

    Raises:
        Exception: The first error raised while connecting, re-raised after
            being logged. Toolkits connected before the failure are not closed
            here; the caller is responsible for that (e.g. via ``close_mcp_tools``).
    """
    tool_names_by_server: Dict[str, List[str]] = {}
    for toolkit in tools:
        # The getattr default is evaluated eagerly, so ``toolkit.name`` is always read.
        server_id = getattr(toolkit, "mcp_server_id", toolkit.name)
        try:
            await toolkit.connect()
            names = sorted(toolkit.functions.keys())
            count = len(names)
            tool_names_by_server[server_id] = names
            if logger:
                logger.info(
                    "🔌 MCP server conectado",
                    add_fields={
                        "mcp_server": server_id,
                        "available_tools": names,
                        "tool_count": count,
                        "agent_call": {
                            "action": "connect_mcp_server",
                            "mcp_server": server_id,
                        },
                        "agent_response": {
                            "status": "connected",
                            "available_tools": names,
                            "tool_count": count,
                        },
                    },
                )
        except Exception as error:
            if logger:
                message = str(error)
                logger.error(
                    "🚨 Error conectando con MCP",
                    add_fields={
                        "mcp_server": server_id,
                        "error": message,
                        "agent_call": {
                            "action": "connect_mcp_server",
                            "mcp_server": server_id,
                        },
                        "agent_response": {
                            "status": "error",
                            "error": message,
                        },
                    },
                )
            raise
    return tool_names_by_server
async def close_mcp_tools(tools: Iterable[MCPTools], *, logger=None) -> None:
    """Close every toolkit's MCP connection on a best-effort basis.

    Toolkits are closed one after another in iteration order. An exception
    from ``close()`` is logged as a warning (when a logger is given) and then
    suppressed, so one failing toolkit does not stop the rest from closing.

    Args:
        tools: Toolkits to close.
        logger: Optional structured logger. ``debug`` is called on success and
            ``warning`` on failure, both with an ``add_fields`` keyword.
    """
    for toolkit in tools:
        # The getattr default is evaluated eagerly, so ``toolkit.name`` is always read.
        server_id = getattr(toolkit, "mcp_server_id", toolkit.name)
        try:
            await toolkit.close()
            if logger:
                logger.debug(
                    "🔻 MCP server desconectado",
                    add_fields={
                        "mcp_server": server_id,
                        "agent_call": {
                            "action": "close_mcp_server",
                            "mcp_server": server_id,
                        },
                        "agent_response": {
                            "status": "disconnected",
                        },
                    },
                )
        except Exception as error:
            # Deliberately swallowed: shutdown should keep going for the other servers.
            if logger:
                message = str(error)
                logger.warning(
                    "⚠️ Error cerrando conexión MCP",
                    add_fields={
                        "mcp_server": server_id,
                        "error": message,
                        "agent_call": {
                            "action": "close_mcp_server",
                            "mcp_server": server_id,
                        },
                        "agent_response": {
                            "status": "error",
                            "error": message,
                        },
                    },
                )