chore: sync from fn-registry agent
This commit is contained in:
@@ -0,0 +1,201 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/mark3labs/mcp-go/server"
|
||||
|
||||
"fn-registry/registry"
|
||||
)
|
||||
|
||||
// findRegistryRoot walks up from cwd looking for registry.db. Tests skip if
|
||||
// the registry isn't reachable (e.g. running outside the workspace).
|
||||
func findRegistryRoot(t *testing.T) string {
|
||||
t.Helper()
|
||||
dir, err := os.Getwd()
|
||||
if err != nil {
|
||||
t.Fatalf("getwd: %v", err)
|
||||
}
|
||||
for {
|
||||
// Treat as the root only if both registry.db AND the registry/ pkg
|
||||
// dir are present — avoids picking up stray *.db inside apps/.
|
||||
_, dbErr := os.Stat(filepath.Join(dir, "registry.db"))
|
||||
_, pkgErr := os.Stat(filepath.Join(dir, "registry", "models.go"))
|
||||
if dbErr == nil && pkgErr == nil {
|
||||
return dir
|
||||
}
|
||||
parent := filepath.Dir(dir)
|
||||
if parent == dir {
|
||||
break
|
||||
}
|
||||
dir = parent
|
||||
}
|
||||
t.Skip("registry.db not found in any parent — skipping integration test")
|
||||
return ""
|
||||
}
|
||||
|
||||
// pipePair returns a pair of (clientWriter, serverReader, serverWriter, clientReader).
|
||||
type pipePair struct {
|
||||
clientToServer *io.PipeWriter
|
||||
serverFromClient *io.PipeReader
|
||||
serverToClient *io.PipeWriter
|
||||
clientFromServer *io.PipeReader
|
||||
}
|
||||
|
||||
func newPipePair() *pipePair {
|
||||
c2sR, c2sW := io.Pipe()
|
||||
s2cR, s2cW := io.Pipe()
|
||||
return &pipePair{
|
||||
clientToServer: c2sW,
|
||||
serverFromClient: c2sR,
|
||||
serverToClient: s2cW,
|
||||
clientFromServer: s2cR,
|
||||
}
|
||||
}
|
||||
|
||||
// startServer wires the MCP server to the in-memory pipes and runs it in a
|
||||
// goroutine. Returns a wait-fn so tests can shut it down deterministically.
|
||||
func startServer(t *testing.T, root string, p *pipePair) func() {
|
||||
t.Helper()
|
||||
dbPath := filepath.Join(root, "registry.db")
|
||||
t.Logf("opening registry at %s", dbPath)
|
||||
db, err := registry.Open(dbPath)
|
||||
if err != nil {
|
||||
t.Fatalf("open registry: %v", err)
|
||||
}
|
||||
var count int
|
||||
if err := db.Conn().QueryRow("SELECT COUNT(*) FROM functions").Scan(&count); err != nil {
|
||||
t.Fatalf("count functions: %v", err)
|
||||
}
|
||||
t.Logf("functions in db: %d", count)
|
||||
srv := server.NewMCPServer("registry_mcp_test", "0.0.0", server.WithToolCapabilities(true))
|
||||
registerTools(srv, db, root, config{})
|
||||
|
||||
stdio := server.NewStdioServer(srv)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_ = stdio.Listen(ctx, p.serverFromClient, p.serverToClient)
|
||||
}()
|
||||
|
||||
return func() {
|
||||
cancel()
|
||||
_ = p.clientToServer.Close()
|
||||
_ = p.serverToClient.Close()
|
||||
wg.Wait()
|
||||
db.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func sendJSON(t *testing.T, w io.Writer, msg map[string]any) {
|
||||
t.Helper()
|
||||
b, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal: %v", err)
|
||||
}
|
||||
if _, err := w.Write(append(b, '\n')); err != nil {
|
||||
t.Fatalf("write: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func recvJSONUntilID(t *testing.T, r *bufio.Reader, id int) map[string]any {
|
||||
t.Helper()
|
||||
for {
|
||||
line, err := r.ReadBytes('\n')
|
||||
if err != nil {
|
||||
t.Fatalf("read: %v", err)
|
||||
}
|
||||
var msg map[string]any
|
||||
if err := json.Unmarshal(line, &msg); err != nil {
|
||||
t.Fatalf("unmarshal %s: %v", line, err)
|
||||
}
|
||||
if got, ok := msg["id"].(float64); ok && int(got) == id {
|
||||
return msg
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegration_StdioListSearchShow(t *testing.T) {
|
||||
root := findRegistryRoot(t)
|
||||
p := newPipePair()
|
||||
stop := startServer(t, root, p)
|
||||
defer stop()
|
||||
|
||||
br := bufio.NewReader(p.clientFromServer)
|
||||
|
||||
sendJSON(t, p.clientToServer, map[string]any{
|
||||
"jsonrpc": "2.0", "id": 1, "method": "initialize",
|
||||
"params": map[string]any{
|
||||
"protocolVersion": "2025-06-18",
|
||||
"capabilities": map[string]any{},
|
||||
"clientInfo": map[string]any{"name": "test", "version": "0"},
|
||||
},
|
||||
})
|
||||
_ = recvJSONUntilID(t, br, 1)
|
||||
|
||||
sendJSON(t, p.clientToServer, map[string]any{
|
||||
"jsonrpc": "2.0", "method": "notifications/initialized",
|
||||
})
|
||||
|
||||
sendJSON(t, p.clientToServer, map[string]any{
|
||||
"jsonrpc": "2.0", "id": 2, "method": "tools/list",
|
||||
})
|
||||
listMsg := recvJSONUntilID(t, br, 2)
|
||||
tools := listMsg["result"].(map[string]any)["tools"].([]any)
|
||||
if len(tools) < 6 {
|
||||
t.Errorf("expected >=6 read-only tools, got %d", len(tools))
|
||||
}
|
||||
names := map[string]bool{}
|
||||
for _, tt := range tools {
|
||||
names[tt.(map[string]any)["name"].(string)] = true
|
||||
}
|
||||
for _, want := range []string{"fn_search", "fn_show", "fn_code", "fn_list_domains", "fn_uses", "fn_doctor"} {
|
||||
if !names[want] {
|
||||
t.Errorf("tools list missing %q", want)
|
||||
}
|
||||
}
|
||||
|
||||
sendJSON(t, p.clientToServer, map[string]any{
|
||||
"jsonrpc": "2.0", "id": 3, "method": "tools/call",
|
||||
"params": map[string]any{
|
||||
"name": "fn_search",
|
||||
"arguments": map[string]any{"query": "filter slice", "lang": "go", "purity": "pure", "limit": 5},
|
||||
},
|
||||
})
|
||||
searchMsg := recvJSONUntilID(t, br, 3)
|
||||
text := searchMsg["result"].(map[string]any)["content"].([]any)[0].(map[string]any)["text"].(string)
|
||||
if !strings.Contains(text, "filter_slice_go_core") {
|
||||
t.Errorf("expected filter_slice_go_core in search results:\n%s", text)
|
||||
}
|
||||
|
||||
sendJSON(t, p.clientToServer, map[string]any{
|
||||
"jsonrpc": "2.0", "id": 4, "method": "tools/call",
|
||||
"params": map[string]any{
|
||||
"name": "fn_show",
|
||||
"arguments": map[string]any{"id": "filter_slice_go_core"},
|
||||
},
|
||||
})
|
||||
showMsg := recvJSONUntilID(t, br, 4)
|
||||
text = showMsg["result"].(map[string]any)["content"].([]any)[0].(map[string]any)["text"].(string)
|
||||
if !strings.Contains(text, "filter_slice_go_core") || !strings.Contains(text, "```go") {
|
||||
t.Errorf("show result missing markdown:\n%s", text[:min(400, len(text))])
|
||||
}
|
||||
}
|
||||
|
||||
func min(a, b int) int {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
Reference in New Issue
Block a user