Files
registry_mcp/tool_create_function.go
T
2026-05-09 13:29:32 +02:00

265 lines
9.2 KiB
Go

package main
import (
"bytes"
"context"
"encoding/json"
"fmt"
"os"
"os/exec"
"path/filepath"
"strings"
"github.com/mark3labs/mcp-go/mcp"
)
// createFunctionArgs is the decoded argument payload of the fn_create_function
// tool. The JSON keys mirror the names declared in the tool's input schema.
type createFunctionArgs struct {
	// Identity and classification. Name, Lang and Domain select the canonical
	// file location and make up the returned ID. An empty Kind is written to
	// the frontmatter as "function".
	Name   string `json:"name"`
	Lang   string `json:"lang"`
	Domain string `json:"domain"`
	Kind   string `json:"kind,omitempty"`
	Purity string `json:"purity"`

	// Metadata copied into the markdown frontmatter.
	Signature     string   `json:"signature,omitempty"`
	Description   string   `json:"description"`
	Tags          []string `json:"tags,omitempty"`
	UsesFunctions []string `json:"uses_functions,omitempty"`
	UsesTypes     []string `json:"uses_types,omitempty"`
	Returns       []string `json:"returns,omitempty"`
	ErrorType     string   `json:"error_type,omitempty"`

	// File contents: Code becomes the source file verbatim; the remaining
	// fields feed the generated markdown document.
	Code         string  `json:"code"`
	MarkdownBody string  `json:"markdown_body,omitempty"`
	Example      string  `json:"example,omitempty"`
	Params       []param `json:"params,omitempty"`
	Output       string  `json:"output,omitempty"`

	// Behaviour switches: Overwrite disables the "file exists" guard and
	// SkipIndex suppresses the `fn index` run after writing.
	Overwrite bool `json:"overwrite,omitempty"`
	SkipIndex bool `json:"skip_index,omitempty"`
}
// param is one entry of the tool's "params" array: the name of a function
// parameter together with a description of what it means.
type param struct {
	Name string `json:"name"` // parameter name, emitted as "name" in the frontmatter
	Desc string `json:"desc"` // description, emitted as a quoted "desc" value
}
// createFunctionTool builds the MCP tool definition for "fn_create_function":
// its description and the input schema (required strings, optional string
// arrays, the structured params array and the two boolean switches).
func createFunctionTool() mcp.Tool {
	// stringItems returns a fresh schema fragment for arrays whose elements
	// are plain strings; a new map per call keeps the properties independent.
	stringItems := func() map[string]any {
		return map[string]any{"type": "string"}
	}

	// paramItem is the schema of one entry of the "params" array.
	paramItem := map[string]any{
		"type": "object",
		"properties": map[string]any{
			"name": map[string]any{"type": "string"},
			"desc": map[string]any{"type": "string"},
		},
		"required": []string{"name", "desc"},
	}

	return mcp.NewTool(
		"fn_create_function",
		mcp.WithDescription("Create a new registry function: writes the source file (.go/.py/.sh/.ts) and the .md (frontmatter + docs) at the canonical path for {lang,domain,name}, then runs `fn index`. Returns the new ID. Off by default — server must be launched with --enable-write. Use this to iterate as a fn-constructor: pass spec + code, get registry entry."),

		// Identity and classification.
		mcp.WithString("name", mcp.Required(), mcp.Description("snake_case function name.")),
		mcp.WithString("lang",
			mcp.Required(),
			mcp.Description("Language: go, py, bash, ts."),
			mcp.Enum("go", "py", "bash", "ts"),
		),
		mcp.WithString("domain", mcp.Required(), mcp.Description("Registry domain (core, infra, finance, datascience, cybersecurity, shell, tui, pipelines, browser, ...).")),
		mcp.WithString("kind", mcp.Description("function (default), pipeline or component.")),
		mcp.WithString("purity",
			mcp.Required(),
			mcp.Description("pure or impure. Pipelines must be impure."),
			mcp.Enum("pure", "impure"),
		),

		// Frontmatter metadata.
		mcp.WithString("signature", mcp.Description("Function signature (e.g. 'func FilterSlice[T any](xs []T, pred func(T) bool) []T').")),
		mcp.WithString("description", mcp.Required(), mcp.Description("One-line description for the registry index.")),
		mcp.WithArray("tags", mcp.Description("Tags."), mcp.Items(stringItems())),
		mcp.WithArray("uses_functions", mcp.Description("Registry IDs this function calls."), mcp.Items(stringItems())),
		mcp.WithArray("uses_types", mcp.Description("Registry types this function uses."), mcp.Items(stringItems())),
		mcp.WithArray("returns", mcp.Description("Registry type IDs returned (NOT native types)."), mcp.Items(stringItems())),
		mcp.WithString("error_type", mcp.Description("Error type ID (impure only). Usually 'error_go_core'.")),

		// File contents.
		mcp.WithString("code", mcp.Required(), mcp.Description("Source code of the function (full file contents).")),
		mcp.WithString("markdown_body", mcp.Description("Documentation body appended after frontmatter.")),
		mcp.WithString("example", mcp.Description("Inline example shown in markdown.")),
		mcp.WithArray("params",
			mcp.Description("Param semantics: [{name, desc}]."),
			mcp.Items(paramItem),
		),
		mcp.WithString("output", mcp.Description("Output semantics for params_schema.")),

		// Behaviour switches.
		mcp.WithBoolean("overwrite", mcp.Description("If true, overwrite existing files. Default false (errors if files exist).")),
		mcp.WithBoolean("skip_index", mcp.Description("If true, skip the `fn index` invocation (caller will run it).")),
	)
}
// handleCreateFunction implements the fn_create_function tool. It validates
// the arguments, writes the source file and its markdown companion under the
// canonical directory for {lang, domain, name}, and, unless SkipIndex is set,
// runs "<fn binary> index" with the registry root as working directory.
//
// Every failure is returned as a tool-level error result; the Go error is
// always nil. A failing index run is not a tool error: it is reported through
// the "index_error" and "index_stderr" keys of the JSON result.
//
// If a write fails, files that did not exist before this call are removed
// again, so a retry without overwrite is not rejected with "file exists"
// because of a leftover from the failed attempt.
func (d *deps) handleCreateFunction(ctx context.Context, _ mcp.CallToolRequest, a createFunctionArgs) (*mcp.CallToolResult, error) {
	if err := validateCreateFunctionArgs(&a); err != nil {
		return mcp.NewToolResultError(err.Error()), nil
	}

	srcDir, srcFile, mdFile := canonicalPaths(d.root, a)
	if err := os.MkdirAll(srcDir, 0o755); err != nil {
		return mcp.NewToolResultError("mkdir: " + err.Error()), nil
	}
	srcPath := filepath.Join(srcDir, srcFile)
	mdPath := filepath.Join(srcDir, mdFile)

	// Probe both targets once: the results drive the overwrite guard and
	// tell us which files this call is creating (and may have to roll back).
	_, srcStatErr := os.Stat(srcPath)
	_, mdStatErr := os.Stat(mdPath)
	if !a.Overwrite {
		if srcStatErr == nil {
			return mcp.NewToolResultError("file exists (set overwrite=true): " + srcPath), nil
		}
		if mdStatErr == nil {
			return mcp.NewToolResultError("file exists (set overwrite=true): " + mdPath), nil
		}
	}
	// Only a definite "does not exist" marks a file as ours to delete.
	srcIsNew := os.IsNotExist(srcStatErr)
	mdIsNew := os.IsNotExist(mdStatErr)

	if err := os.WriteFile(srcPath, []byte(a.Code), 0o644); err != nil {
		if srcIsNew {
			// Best effort: drop a partially written file; the write error is what matters.
			_ = os.Remove(srcPath)
		}
		return mcp.NewToolResultError("write code: " + err.Error()), nil
	}

	md := buildFrontmatter(&a, srcPath, d.root)
	if err := os.WriteFile(mdPath, []byte(md), 0o644); err != nil {
		// Best effort rollback so the registry is not left with a source
		// file that has no markdown companion.
		if mdIsNew {
			_ = os.Remove(mdPath)
		}
		if srcIsNew {
			_ = os.Remove(srcPath)
		}
		return mcp.NewToolResultError("write md: " + err.Error()), nil
	}

	out := map[string]any{
		"id":            fmt.Sprintf("%s_%s_%s", a.Name, a.Lang, a.Domain),
		"source_file":   relTo(d.root, srcPath),
		"markdown_file": relTo(d.root, mdPath),
		"overwrote":     a.Overwrite,
	}

	if !a.SkipIndex {
		cmd := exec.CommandContext(ctx, d.fnBin(), "index")
		cmd.Dir = d.root
		var stdout, stderr bytes.Buffer
		cmd.Stdout = &stdout
		cmd.Stderr = &stderr
		if err := cmd.Run(); err != nil {
			// The files are already written, so surface the index failure
			// inside the result instead of failing the whole call.
			out["index_error"] = err.Error()
			out["index_stderr"] = stderr.String()
		} else {
			out["indexed"] = true
		}
	}

	b, err := json.MarshalIndent(out, "", " ")
	if err != nil {
		return mcp.NewToolResultError("encode result: " + err.Error()), nil
	}
	return mcp.NewToolResultText(string(b)), nil
}
// validateCreateFunctionArgs checks the tool arguments before anything is
// written to disk and returns a descriptive error for the first violation.
//
// Rules enforced:
//   - name, lang, domain, description and code must be non-empty;
//   - name must be snake_case;
//   - domain must use the same character set as name, because it becomes a
//     directory name below the registry root and must not contain path
//     separators or ".." segments;
//   - kind, when given, must use that character set too, because it is
//     written into the frontmatter;
//   - lang must be one of go, py, bash, ts;
//   - purity must be pure (no error_type allowed) or impure (error_type
//     required);
//   - a pipeline must be impure.
func validateCreateFunctionArgs(a *createFunctionArgs) error {
	if a.Name == "" || a.Lang == "" || a.Domain == "" || a.Description == "" || a.Code == "" {
		return fmt.Errorf("name, lang, domain, description, code are required")
	}
	if !isSnakeCase(a.Name) {
		return fmt.Errorf("name must be snake_case (lowercase + digits + underscores), got %q", a.Name)
	}
	// Domain is joined into the target path; rejecting anything but plain
	// identifier characters keeps the write inside the registry tree.
	if !isSnakeCase(a.Domain) {
		return fmt.Errorf("domain must be a single directory name (lowercase + digits + underscores), got %q", a.Domain)
	}
	if a.Kind != "" && !isSnakeCase(a.Kind) {
		return fmt.Errorf("kind must be lowercase + digits + underscores, got %q", a.Kind)
	}

	switch a.Lang {
	case "go", "py", "bash", "ts":
	default:
		return fmt.Errorf("unsupported lang: %s", a.Lang)
	}

	switch a.Purity {
	case "pure":
		if a.ErrorType != "" {
			return fmt.Errorf("pure functions must not have error_type")
		}
	case "impure":
		if a.ErrorType == "" {
			return fmt.Errorf("impure functions must declare error_type (e.g. error_go_core)")
		}
	default:
		return fmt.Errorf("purity must be pure or impure")
	}

	if a.Kind == "pipeline" && a.Purity != "impure" {
		return fmt.Errorf("pipeline must be impure")
	}
	return nil
}
// isSnakeCase reports whether s is non-empty and made up exclusively of ASCII
// lowercase letters, ASCII digits and underscores.
func isSnakeCase(s string) bool {
	for _, r := range s {
		switch {
		case 'a' <= r && r <= 'z':
		case '0' <= r && r <= '9':
		case r == '_':
		default:
			// Any other rune (uppercase, punctuation, non-ASCII) disqualifies s.
			return false
		}
	}
	// The loop accepts the empty string vacuously; reject it here.
	return s != ""
}
// canonicalPaths maps {lang, domain, name} to the registry location of a
// function. It returns the directory, the source file name and the markdown
// file name (always name + ".md").
//
// Go sources live in <root>/functions/<domain>; the other languages use the
// same layout below a language directory (python, bash, frontend). For an
// unrecognised lang the root itself and the bare name are returned.
func canonicalPaths(root string, a createFunctionArgs) (string, string, string) {
	mdFile := a.Name + ".md"

	// top is the optional language directory, ext the source file extension.
	var top, ext string
	switch a.Lang {
	case "go":
		ext = ".go"
	case "py":
		top, ext = "python", ".py"
	case "bash":
		top, ext = "bash", ".sh"
	case "ts":
		top, ext = "frontend", ".ts"
	default:
		return root, a.Name, mdFile
	}

	// filepath.Join ignores empty elements, so the empty top used for Go
	// yields <root>/functions/<domain>.
	return filepath.Join(root, top, "functions", a.Domain), a.Name + ext, mdFile
}
// relTo expresses p relative to root. When filepath.Rel cannot relate the two
// paths, p is returned unchanged.
func relTo(root, p string) string {
	if rel, err := filepath.Rel(root, p); err == nil {
		return rel
	}
	return p
}
// buildFrontmatter renders the markdown companion of a function: a YAML
// frontmatter block followed by the documentation body.
//
// The frontmatter holds name, kind ("function" when a.Kind is empty), lang,
// domain, a fixed version 1.0.0, purity, the optional signature, the
// description, the four lists (tags, uses_functions, uses_types, returns),
// error_type, the optional params and output entries, and file_path, which is
// srcPath relative to root. Pure functions additionally get
// "returns_optional: false" and an empty error_type.
//
// The body is a.MarkdownBody, or the description when no body was supplied,
// optionally followed by an "example" section fenced with langFence(a.Lang).
//
// Free-text values are written as double-quoted scalars. Identifier-like
// values are written verbatim while they are plain tokens and are quoted
// otherwise, so a stray newline, colon or '#' cannot add or break keys.
func buildFrontmatter(a *createFunctionArgs, srcPath, root string) string {
	// scalar returns s unchanged when it is a non-empty token of letters,
	// digits and "_-./\" that does not start with '-'; anything else is
	// rendered as a double-quoted, escaped string.
	scalar := func(s string) string {
		if s == "" || s[0] == '-' {
			return fmt.Sprintf("%q", s)
		}
		for _, r := range s {
			switch {
			case r >= 'a' && r <= 'z', r >= 'A' && r <= 'Z', r >= '0' && r <= '9':
			case r == '_', r == '-', r == '.', r == '/', r == '\\':
			default:
				return fmt.Sprintf("%q", s)
			}
		}
		return s
	}

	var b strings.Builder
	b.WriteString("---\n")
	fmt.Fprintf(&b, "name: %s\n", scalar(a.Name))
	if a.Kind != "" {
		fmt.Fprintf(&b, "kind: %s\n", scalar(a.Kind))
	} else {
		b.WriteString("kind: function\n")
	}
	fmt.Fprintf(&b, "lang: %s\n", scalar(a.Lang))
	fmt.Fprintf(&b, "domain: %s\n", scalar(a.Domain))
	b.WriteString("version: 1.0.0\n")
	fmt.Fprintf(&b, "purity: %s\n", scalar(a.Purity))
	if a.Signature != "" {
		fmt.Fprintf(&b, "signature: %q\n", a.Signature)
	}
	fmt.Fprintf(&b, "description: %q\n", a.Description)

	writeYAMLList(&b, "tags", a.Tags)
	writeYAMLList(&b, "uses_functions", a.UsesFunctions)
	writeYAMLList(&b, "uses_types", a.UsesTypes)
	writeYAMLList(&b, "returns", a.Returns)

	if a.Purity == "pure" {
		b.WriteString("returns_optional: false\n")
		b.WriteString("error_type: \"\"\n")
	} else {
		fmt.Fprintf(&b, "error_type: %s\n", scalar(a.ErrorType))
	}

	if len(a.Params) > 0 {
		b.WriteString("params:\n")
		for _, p := range a.Params {
			// desc must line up with name (two columns past the dash) to
			// belong to the same mapping entry.
			fmt.Fprintf(&b, " - name: %s\n   desc: %q\n", scalar(p.Name), p.Desc)
		}
	}
	if a.Output != "" {
		fmt.Fprintf(&b, "output: %q\n", a.Output)
	}
	fmt.Fprintf(&b, "file_path: %s\n", scalar(relTo(root, srcPath)))
	b.WriteString("---\n\n")

	if a.MarkdownBody != "" {
		b.WriteString(a.MarkdownBody)
		b.WriteString("\n")
	} else {
		fmt.Fprintf(&b, "%s\n", a.Description)
	}
	if a.Example != "" {
		fmt.Fprintf(&b, "\n## example\n\n```%s\n%s\n```\n", langFence(a.Lang), a.Example)
	}
	return b.String()
}
// writeYAMLList appends a YAML sequence named key to b. An empty or nil items
// slice is written as the flow-style empty list "key: []"; otherwise one
// "- item" line is written per element.
//
// Items that are plain tokens (letters, digits and "_-./\", not starting with
// '-') are written verbatim. Any other item, including the empty string, is
// written as a double-quoted, escaped scalar so that newlines, ": ", '#' or
// brackets inside an item cannot break the list or add keys.
func writeYAMLList(b *strings.Builder, key string, items []string) {
	if len(items) == 0 {
		fmt.Fprintf(b, "%s: []\n", key)
		return
	}

	// plain reports whether s can be emitted without quoting.
	plain := func(s string) bool {
		if s == "" || s[0] == '-' {
			return false
		}
		for _, r := range s {
			switch {
			case r >= 'a' && r <= 'z', r >= 'A' && r <= 'Z', r >= '0' && r <= '9':
			case r == '_', r == '-', r == '.', r == '/', r == '\\':
			default:
				return false
			}
		}
		return true
	}

	fmt.Fprintf(b, "%s:\n", key)
	for _, it := range items {
		if plain(it) {
			fmt.Fprintf(b, " - %s\n", it)
			continue
		}
		fmt.Fprintf(b, " - %q\n", it)
	}
}