359 lines
13 KiB
Python
359 lines
13 KiB
Python
"""Helpers to build and manage MCPTools instances from JSON configuration."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import shlex
|
|
from dataclasses import dataclass, field
|
|
from pathlib import Path
|
|
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
|
|
|
|
from agno.tools.mcp import MCPTools
|
|
|
|
|
|
class MCPConfigError(ValueError):
    """Signal that an MCP configuration source is malformed or inconsistent.

    Because it derives from ``ValueError``, it can also be caught as a
    ``ValueError``.
    """
|
|
|
|
|
|
@dataclass
class MCPServerDefinition:
    """Normalized representation of one entry of the ``mcpServers`` object.

    Build instances with :meth:`from_dict` and convert them into ``MCPTools``
    keyword arguments with :meth:`to_kwargs`.
    """

    name: str
    # Executable (first argv element) used with the stdio transport.
    command: Optional[str] = None
    # Extra argv elements appended after ``command``.
    args: List[str] = field(default_factory=list)
    transport: str = "stdio"
    url: Optional[str] = None
    env: Dict[str, str] = field(default_factory=dict)
    # ``None`` means "no filter configured"; a list -- even an empty one -- is
    # forwarded unchanged so an explicit choice is never silently widened.
    include_tools: Optional[List[str]] = None
    exclude_tools: Optional[List[str]] = None
    timeout_seconds: Optional[int] = None
    # ``None`` means "not configured": the key is omitted from to_kwargs().
    refresh_connection: Optional[bool] = None
    enabled: bool = True

    @staticmethod
    def _parse_bool(server: str, key: str, value: Any) -> bool:
        """Interpret a JSON value as a boolean.

        Booleans and numbers keep Python truthiness. Strings are parsed by
        meaning, because ``bool("false")`` would otherwise be ``True``.

        Raises:
            MCPConfigError: If *value* is a string that is not a recognized
                boolean spelling.
        """
        if isinstance(value, str):
            token = value.strip().lower()
            if token in ("true", "1", "yes", "on"):
                return True
            if token in ("false", "0", "no", "off", ""):
                return False
            raise MCPConfigError(
                f"The '{key}' field for MCP server '{server}' must be a boolean."
            )
        return bool(value)

    @staticmethod
    def _parse_tool_list(server: str, key: str, value: Any) -> Optional[List[str]]:
        """Validate an optional tool-name array, preserving ``None`` vs ``[]``."""
        if value is None:
            return None
        if not isinstance(value, list):
            raise MCPConfigError(
                f"The '{key}' field for MCP server '{server}' must be an array."
            )
        return [str(item) for item in value]

    @classmethod
    def from_dict(cls, name: str, data: Dict[str, Any]) -> "MCPServerDefinition":
        """Validate and normalize the raw JSON object for server *name*.

        Args:
            name: Key of the entry inside ``mcpServers``.
            data: The entry's JSON object.

        Returns:
            A normalized :class:`MCPServerDefinition`.

        Raises:
            MCPConfigError: If a field has the wrong type or an invalid value.
        """
        if not isinstance(data, dict):
            raise MCPConfigError(
                f"Configuration for MCP server '{name}' must be a JSON object."
            )

        command = data.get("command")
        if command is not None and not isinstance(command, str):
            raise MCPConfigError(
                f"The 'command' field for MCP server '{name}' must be a string."
            )

        url = data.get("url")
        if url is not None and not isinstance(url, str):
            raise MCPConfigError(
                f"The 'url' field for MCP server '{name}' must be a string."
            )

        transport = data.get("transport", "stdio")

        args = data.get("args")
        if args is None:
            args = []
        if not isinstance(args, list):
            raise MCPConfigError(
                f"The 'args' field for MCP server '{name}' must be an array of strings."
            )

        env = data.get("env") or {}
        if not isinstance(env, dict):
            raise MCPConfigError(
                f"The 'env' field for MCP server '{name}' must be an object of key/value pairs."
            )

        include_tools = cls._parse_tool_list(name, "include_tools", data.get("include_tools"))
        exclude_tools = cls._parse_tool_list(name, "exclude_tools", data.get("exclude_tools"))

        timeout_seconds = data.get("timeout_seconds")
        if timeout_seconds is not None:
            try:
                timeout_seconds = int(timeout_seconds)
            except (TypeError, ValueError, OverflowError) as exc:
                raise MCPConfigError(
                    f"The 'timeout_seconds' field for MCP server '{name}' must be an integer."
                ) from exc
            if timeout_seconds <= 0:
                raise MCPConfigError(
                    f"The 'timeout_seconds' field for MCP server '{name}' must be positive."
                )

        refresh_connection = data.get("refresh_connection")
        if refresh_connection is not None:
            refresh_connection = cls._parse_bool(name, "refresh_connection", refresh_connection)

        # A missing key and an explicit null both mean "use the default"
        # (enabled), matching how a null refresh_connection means "unset".
        enabled_raw = data.get("enabled")
        enabled = True if enabled_raw is None else cls._parse_bool(name, "enabled", enabled_raw)

        return cls(
            name=name,
            command=command,
            args=[str(arg) for arg in args],
            transport=str(transport or "stdio"),
            url=url,
            env={str(key): str(value) for key, value in env.items()},
            include_tools=include_tools,
            exclude_tools=exclude_tools,
            timeout_seconds=timeout_seconds,
            refresh_connection=refresh_connection,
            enabled=enabled,
        )

    def to_kwargs(self) -> Dict[str, Any]:
        """Build the keyword arguments passed to ``MCPTools(...)``.

        Optional settings are only included when configured, so anything left
        unset falls back to the ``MCPTools`` defaults. Containers are copied so
        the returned dict never aliases this definition's state.

        Raises:
            MCPConfigError: If a stdio server has no ``command``, or another
                transport has neither a ``command`` nor a ``url``.
        """
        transport = self.transport or "stdio"
        kwargs: Dict[str, Any] = {"transport": transport}

        if self.command:
            # MCPTools takes the whole command line as one string. shlex.join
            # quotes each part, so arguments containing spaces, metacharacters
            # or empty strings keep their boundaries if the string is split back
            # with shell-like rules (presumably shlex.split -- confirm). Empty
            # arguments are kept on purpose: dropping them would alter argv.
            kwargs["command"] = shlex.join(
                [str(part) for part in (self.command, *self.args)]
            )
        elif transport == "stdio":
            raise MCPConfigError(
                f"MCP server '{self.name}' must define a 'command' when using stdio transport."
            )
        elif not self.url:
            raise MCPConfigError(
                f"MCP server '{self.name}' must define a 'url' when using '{transport}' transport."
            )

        if self.url:
            kwargs["url"] = self.url
        if self.env:
            kwargs["env"] = dict(self.env)
        if self.include_tools is not None:
            kwargs["include_tools"] = list(self.include_tools)
        if self.exclude_tools is not None:
            kwargs["exclude_tools"] = list(self.exclude_tools)
        if self.timeout_seconds is not None:
            kwargs["timeout_seconds"] = self.timeout_seconds
        if self.refresh_connection is not None:
            kwargs["refresh_connection"] = self.refresh_connection

        return kwargs
|
|
|
|
|
|
class MCPServerManager:
    """Load MCP server definitions and instantiate ``MCPTools`` for them.

    The configuration must be a JSON object with an ``mcpServers`` object that
    maps server names to server entries. All entries are parsed and validated
    eagerly when the manager is constructed.
    """

    def __init__(
        self,
        config_source: Union[str, bytes, Dict[str, Any], Path],
        *,
        logger=None,
    ) -> None:
        """Parse and validate *config_source*.

        Args:
            config_source: JSON text (``str`` or ``bytes``), a ``Path`` to a
                JSON file, or an already-parsed ``dict``.
            logger: Optional logger; its ``info``/``warning``/``error`` methods
                are called with an ``add_fields`` keyword argument.

        Raises:
            MCPConfigError: If the configuration is malformed.
        """
        self.logger = logger
        self._raw_config = self._load_config(config_source)
        self._servers = self._build_servers(self._raw_config)

    @staticmethod
    def _load_config(config_source: Union[str, bytes, Dict[str, Any], Path]) -> Dict[str, Any]:
        """Turn *config_source* into a validated configuration dict.

        A leading UTF-8 byte-order mark is tolerated for every text source;
        ``json.loads`` rejects it, but some editors write one.

        Raises:
            MCPConfigError: For unsupported source types, undecodable bytes,
                invalid JSON, or a root without an ``mcpServers`` object.
        """
        if isinstance(config_source, dict):
            return MCPServerManager._validate_root(config_source)

        if isinstance(config_source, Path):
            # "utf-8-sig" decodes UTF-8 and drops a BOM if one is present.
            config_text = config_source.read_text(encoding="utf-8-sig")
        elif isinstance(config_source, bytes):
            try:
                config_text = config_source.decode("utf-8-sig")
            except UnicodeDecodeError as exc:
                raise MCPConfigError(f"MCP configuration is not valid UTF-8: {exc}") from exc
        elif isinstance(config_source, str):
            config_text = config_source
        else:
            raise MCPConfigError("Unsupported MCP configuration source type.")

        # A str may still carry a BOM if it was decoded as plain UTF-8 upstream.
        if config_text.startswith("\ufeff"):
            config_text = config_text[1:]

        try:
            loaded = json.loads(config_text)
        except json.JSONDecodeError as exc:
            raise MCPConfigError(f"Invalid MCP configuration JSON: {exc}") from exc
        return MCPServerManager._validate_root(loaded)

    @staticmethod
    def _validate_root(config: Dict[str, Any]) -> Dict[str, Any]:
        """Check the root shape and return *config* unchanged.

        Raises:
            MCPConfigError: If the root is not an object or lacks an
                ``mcpServers`` object.
        """
        if not isinstance(config, dict):
            raise MCPConfigError("MCP configuration root must be a JSON object.")
        if "mcpServers" not in config:
            raise MCPConfigError("MCP configuration must include a 'mcpServers' object.")
        if not isinstance(config["mcpServers"], dict):
            raise MCPConfigError("The 'mcpServers' field must be a JSON object.")
        return config

    @staticmethod
    def _build_servers(config: Dict[str, Any]) -> List[MCPServerDefinition]:
        """Parse every ``mcpServers`` entry, keeping configuration order."""
        return [
            MCPServerDefinition.from_dict(str(name), data)
            for name, data in config.get("mcpServers", {}).items()
        ]

    def build_tools(
        self,
        only_servers: Optional[Iterable[str]] = None,
        *,
        include_disabled: bool = False,
    ) -> Tuple[List[MCPTools], List[str]]:
        """Instantiate ``MCPTools`` for the selected servers.

        Args:
            only_servers: Server names to build. ``None`` or an empty
                collection selects every configured server. Requested names
                that are not configured are reported via a warning log.
            include_disabled: Also build servers whose ``enabled`` flag is false.

        Returns:
            ``(tools, active_servers)``: the toolkits in configuration order and
            the names of the servers they were built for.

        Raises:
            MCPConfigError: If a selected definition cannot produce kwargs.
            Exception: Anything raised while building a toolkit is logged and
                re-raised.
        """
        target = {str(name) for name in only_servers} if only_servers else None
        tools: List[MCPTools] = []
        active_servers: List[str] = []

        if target and self.logger:
            # Surface typos instead of silently building nothing for them.
            unknown = target.difference(definition.name for definition in self._servers)
            if unknown:
                self.logger.warning(
                    "⚠️ Requested MCP servers are not configured",
                    add_fields={
                        "mcp_servers": sorted(unknown),
                        "agent_call": {
                            "action": "load_mcp_server",
                            "mcp_servers": sorted(unknown),
                        },
                        "agent_response": {
                            "status": "skipped",
                            "reason": "not_configured",
                        },
                    },
                )

        for definition in self._servers:
            if target and definition.name not in target:
                continue
            if not definition.enabled and not include_disabled:
                if self.logger:
                    self.logger.info(
                        "⏭️ Skipping MCP server because it is disabled",
                        add_fields={
                            "mcp_server": definition.name,
                            "agent_call": {
                                "action": "load_mcp_server",
                                "mcp_server": definition.name,
                            },
                            "agent_response": {
                                "status": "skipped",
                                "reason": "disabled",
                            },
                        },
                    )
                continue
            try:
                tool = MCPTools(**definition.to_kwargs())
                # Tag the toolkit so later lifecycle helpers can identify it.
                tool.mcp_server_id = definition.name
                tool.mcp_server_definition = definition
                tools.append(tool)
                active_servers.append(definition.name)
                if self.logger:
                    self.logger.info(
                        "⚙️ Configured MCP server",
                        add_fields={
                            "mcp_server": definition.name,
                            "transport": definition.transport,
                            "has_url": bool(definition.url),
                            "agent_call": {
                                "action": "configure_mcp_server",
                                "mcp_server": definition.name,
                            },
                            "agent_response": {
                                "status": "configured",
                                "transport": definition.transport,
                                "has_url": bool(definition.url),
                            },
                        },
                    )
            except Exception as exc:
                if self.logger:
                    self.logger.error(
                        "🚨 Failed to configure MCP server",
                        add_fields={
                            "mcp_server": definition.name,
                            "error": str(exc),
                            "agent_call": {
                                "action": "configure_mcp_server",
                                "mcp_server": definition.name,
                            },
                            "agent_response": {
                                "status": "error",
                                "error": str(exc),
                            },
                        },
                    )
                raise

        return tools, active_servers

    @property
    def server_names(self) -> List[str]:
        """Names of all configured servers, in configuration order (disabled included)."""
        return [definition.name for definition in self._servers]
|
|
|
|
|
|
def load_mcp_tools(
    config_source: Union[str, bytes, Dict[str, Any], Path],
    *,
    logger=None,
    only_servers: Optional[Iterable[str]] = None,
    include_disabled: bool = False,
) -> Tuple[List[MCPTools], List[str]]:
    """Parse *config_source* and build its ``MCPTools`` in one call.

    Shortcut for ``MCPServerManager(config_source, logger=logger).build_tools(...)``.
    See ``MCPServerManager.build_tools`` for the meaning of the filters and of
    the returned ``(tools, active_servers)`` pair.
    """
    selection = {"only_servers": only_servers, "include_disabled": include_disabled}
    return MCPServerManager(config_source, logger=logger).build_tools(**selection)
|
|
|
|
|
|
async def initialize_mcp_tools(
    tools: Iterable[MCPTools],
    *,
    logger=None,
) -> Dict[str, List[str]]:
    """Connect every toolkit in *tools* and report the tools each one exposes.

    Toolkits are connected one after another, in iteration order.

    Args:
        tools: Toolkits to connect.
        logger: Optional logger whose ``info``/``error`` methods accept an
            ``add_fields`` keyword argument.

    Returns:
        Mapping from server id (the toolkit's ``mcp_server_id`` attribute,
        falling back to its ``name``) to the sorted names in its ``functions``.

    Raises:
        Exception: Any failure for a toolkit is logged and re-raised, aborting
            the remaining connections.
            NOTE(review): toolkits connected earlier in the loop are left
            connected; presumably the caller cleans up with close_mcp_tools --
            confirm.
    """

    def connect_call(server_id: str) -> Dict[str, str]:
        # Fresh dict for every log record.
        return {"action": "connect_mcp_server", "mcp_server": server_id}

    tool_map: Dict[str, List[str]] = {}

    for toolkit in tools:
        server_id = getattr(toolkit, "mcp_server_id", toolkit.name)
        try:
            await toolkit.connect()
            function_names = sorted(toolkit.functions.keys())
            tool_map[server_id] = function_names
            if not logger:
                continue
            logger.info(
                "🔌 MCP server conectado",
                add_fields={
                    "mcp_server": server_id,
                    "available_tools": function_names,
                    "tool_count": len(function_names),
                    "agent_call": connect_call(server_id),
                    "agent_response": {
                        "status": "connected",
                        "available_tools": function_names,
                        "tool_count": len(function_names),
                    },
                },
            )
        except Exception as exc:
            if logger:
                message = str(exc)
                logger.error(
                    "🚨 Error conectando con MCP",
                    add_fields={
                        "mcp_server": server_id,
                        "error": message,
                        "agent_call": connect_call(server_id),
                        "agent_response": {"status": "error", "error": message},
                    },
                )
            raise

    return tool_map
|
|
|
|
|
|
async def close_mcp_tools(tools: Iterable[MCPTools], *, logger=None) -> None:
    """Close MCP connections gracefully, last-opened first.

    Toolkits are closed in reverse iteration order, the same LIFO order that
    ``contextlib.AsyncExitStack`` uses. The assumption is that each ``connect()``
    enters async context managers in the calling task; unwinding them out of
    nesting order can fail (e.g. anyio cancel scopes), so callers are assumed to
    pass toolkits in the order they were connected -- confirm. Closing is
    best-effort: a failure is logged as a warning and the remaining toolkits
    are still closed.

    Args:
        tools: Toolkits to close. Any iterable is accepted; it is consumed once.
        logger: Optional logger whose ``debug``/``warning`` methods accept an
            ``add_fields`` keyword argument.
    """
    # Materialize first: ``tools`` may be a one-shot iterator, and reversing
    # requires a sequence.
    for toolkit in reversed(list(tools)):
        server_id = getattr(toolkit, "mcp_server_id", toolkit.name)
        try:
            await toolkit.close()
            if logger:
                logger.debug(
                    "🔻 MCP server desconectado",
                    add_fields={
                        "mcp_server": server_id,
                        "agent_call": {
                            "action": "close_mcp_server",
                            "mcp_server": server_id,
                        },
                        "agent_response": {
                            "status": "disconnected",
                        },
                    },
                )
        except Exception as exc:
            # Deliberately swallowed: one bad shutdown must not prevent the
            # remaining toolkits from being closed.
            if logger:
                logger.warning(
                    "⚠️ Error cerrando conexión MCP",
                    add_fields={
                        "mcp_server": server_id,
                        "error": str(exc),
                        "agent_call": {
                            "action": "close_mcp_server",
                            "mcp_server": server_id,
                        },
                        "agent_response": {
                            "status": "error",
                            "error": str(exc),
                        },
                    },
                )
|