feat(m8): MCP client, tool replaceability, and plugin system
Complete the remaining M8 extensibility deliverables:
- MCP client with JSON-RPC 2.0 over stdio transport, protocol
lifecycle (initialize/tools-list/tools-call), and process group
management for clean shutdown
- MCP tool adapter implementing tool.Tool with mcp__{server}__{tool}
naming convention and replace_default for swapping built-in tools
- MCP manager for multi-server orchestration with parallel startup,
tool discovery, and registry integration
- Plugin system with plugin.json manifest (name/version/capabilities),
directory-based discovery (global + project scopes with precedence),
loader that merges skills/hooks/MCP configs into existing registries,
and install/uninstall/list lifecycle manager
- Config additions: MCPServerConfig, PluginsSection with opt-in/opt-out
enabled/disabled resolution
- TUI /plugins command for listing installed plugins
- 54 tests across internal/mcp and internal/plugin packages
This commit is contained in:
@@ -38,6 +38,8 @@ import (
|
||||
|
||||
tea "charm.land/bubbletea/v2"
|
||||
"somegit.dev/Owlibou/gnoma/internal/elf"
|
||||
"somegit.dev/Owlibou/gnoma/internal/mcp"
|
||||
"somegit.dev/Owlibou/gnoma/internal/plugin"
|
||||
"somegit.dev/Owlibou/gnoma/internal/tool/agent"
|
||||
"somegit.dev/Owlibou/gnoma/internal/tool/bash"
|
||||
"somegit.dev/Owlibou/gnoma/internal/tool/fs"
|
||||
@@ -346,10 +348,25 @@ func main() {
|
||||
reg.Register(agent.NewListResultsTool(store))
|
||||
reg.Register(agent.NewReadResultTool(store))
|
||||
|
||||
// Build hook dispatcher from config.
|
||||
// Discover plugins and merge their capabilities.
|
||||
pluginLoader := plugin.NewLoader(logger)
|
||||
globalPluginDir := filepath.Join(gnomacfg.GlobalConfigDir(), "plugins")
|
||||
projectPluginDir := filepath.Join(gnomacfg.ProjectRoot(), ".gnoma", "plugins")
|
||||
discoveredPlugins, err := pluginLoader.Discover(globalPluginDir, projectPluginDir)
|
||||
if err != nil {
|
||||
logger.Warn("plugin discovery error", "error", err)
|
||||
}
|
||||
enabledSet := resolveEnabledPlugins(cfg.Plugins, discoveredPlugins)
|
||||
pluginResult, err := pluginLoader.Load(discoveredPlugins, enabledSet)
|
||||
if err != nil {
|
||||
logger.Warn("plugin load error", "error", err)
|
||||
}
|
||||
|
||||
// Build hook dispatcher from config + plugin hooks.
|
||||
// Streamer adapter wraps the router for prompt hooks.
|
||||
// ElfSpawnFn closure wraps elfMgr for agent hooks.
|
||||
hookDefs, err := hook.ParseHookDefs(cfg.Hooks)
|
||||
allHooks := append(cfg.Hooks, pluginResult.Hooks...)
|
||||
hookDefs, err := hook.ParseHookDefs(allHooks)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "hook config error: %v\n", err)
|
||||
os.Exit(1)
|
||||
@@ -369,11 +386,31 @@ func main() {
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Build skill registry: bundled → user (~/.config/gnoma/skills/) → project (.gnoma/skills/)
|
||||
// Start MCP servers (config + plugin) and register tools in the tool registry.
|
||||
allMCPServers := append(cfg.MCPServers, pluginResult.MCPServers...)
|
||||
var mcpMgr *mcp.Manager
|
||||
if len(allMCPServers) > 0 {
|
||||
serverCfgs, err := mcp.ParseServerConfigs(allMCPServers)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "mcp config error: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
mcpMgr = mcp.NewManager(logger)
|
||||
if err := mcpMgr.StartAll(context.Background(), serverCfgs, reg); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "mcp startup error: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
defer mcpMgr.Shutdown()
|
||||
}
|
||||
|
||||
// Build skill registry: bundled → user → plugins → project (precedence order).
|
||||
skillReg := skill.NewRegistry()
|
||||
skillReg.LoadBundled() //nolint:errcheck
|
||||
skillReg.LoadDir(filepath.Join(gnomacfg.GlobalConfigDir(), "skills"), "user") //nolint:errcheck
|
||||
skillReg.LoadDir(filepath.Join(gnomacfg.ProjectRoot(), ".gnoma", "skills"), "project") //nolint:errcheck
|
||||
skillReg.LoadBundled() //nolint:errcheck
|
||||
skillReg.LoadDir(filepath.Join(gnomacfg.GlobalConfigDir(), "skills"), "user") //nolint:errcheck
|
||||
for _, ps := range pluginResult.Skills {
|
||||
skillReg.LoadDir(ps.Dir, ps.Source) //nolint:errcheck
|
||||
}
|
||||
skillReg.LoadDir(filepath.Join(gnomacfg.ProjectRoot(), ".gnoma", "skills"), "project") //nolint:errcheck
|
||||
|
||||
// Build system prompt with cwd + compact inventory summary
|
||||
systemPrompt := *system
|
||||
@@ -604,6 +641,7 @@ func main() {
|
||||
SessionStore: sessStore,
|
||||
StartWithResumePicker: openResumePicker,
|
||||
Skills: skillReg,
|
||||
PluginInfos: buildPluginInfos(discoveredPlugins, enabledSet),
|
||||
})
|
||||
p := tea.NewProgram(m)
|
||||
if _, err := p.Run(); err != nil {
|
||||
@@ -798,3 +836,48 @@ When a task involves 2 or more independent sub-tasks, use the spawn_elfs tool to
|
||||
- "refactor this function" → single sequential workflow (one dependent task)
|
||||
|
||||
When using spawn_elfs, list all tasks in one call — do NOT spawn one elf at a time.`
|
||||
|
||||
// buildPluginInfos converts discovered plugins into TUI display info.
|
||||
func buildPluginInfos(plugins []plugin.Plugin, enabledSet map[string]bool) []tui.PluginInfo {
|
||||
infos := make([]tui.PluginInfo, 0, len(plugins))
|
||||
for _, p := range plugins {
|
||||
infos = append(infos, tui.PluginInfo{
|
||||
Name: p.Manifest.Name,
|
||||
Version: p.Manifest.Version,
|
||||
Scope: p.Scope,
|
||||
Enabled: enabledSet[p.Manifest.Name],
|
||||
})
|
||||
}
|
||||
return infos
|
||||
}
|
||||
|
||||
// resolveEnabledPlugins determines which plugins are enabled based on config.
|
||||
// If Enabled is empty, all plugins are enabled by default (opt-out via Disabled).
|
||||
// If Enabled is non-empty, only listed plugins are enabled (opt-in).
|
||||
// Disabled always takes precedence (veto).
|
||||
func resolveEnabledPlugins(cfg gnomacfg.PluginsSection, plugins []plugin.Plugin) map[string]bool {
|
||||
disabled := make(map[string]bool, len(cfg.Disabled))
|
||||
for _, name := range cfg.Disabled {
|
||||
disabled[name] = true
|
||||
}
|
||||
|
||||
result := make(map[string]bool, len(plugins))
|
||||
|
||||
if len(cfg.Enabled) == 0 {
|
||||
// Opt-out mode: all plugins enabled unless in disabled list.
|
||||
for _, p := range plugins {
|
||||
if !disabled[p.Manifest.Name] {
|
||||
result[p.Manifest.Name] = true
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Opt-in mode: only listed plugins enabled.
|
||||
for _, name := range cfg.Enabled {
|
||||
if !disabled[name] {
|
||||
result[name] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
@@ -175,10 +175,10 @@ depends_on: [vision]
|
||||
- [x] Hook command types: command (shell), prompt (LLM), agent (spawn elf)
|
||||
- [x] Skill loading from .gnoma/skills/, ~/.config/gnoma/skills/, bundled, plugins
|
||||
- [x] Skill frontmatter: YAML (name, description, whenToUse, allowedTools, paths)
|
||||
- [ ] MCP client: JSON-RPC over stdio, tool discovery
|
||||
- [ ] MCP tool naming: `mcp__{server}__{tool}`
|
||||
- [ ] MCP tool replaceability: `replace_default` config swaps built-in tools
|
||||
- [ ] Plugin system: plugin.json manifest, install/enable/disable lifecycle
|
||||
- [x] MCP client: JSON-RPC over stdio, tool discovery
|
||||
- [x] MCP tool naming: `mcp__{server}__{tool}`
|
||||
- [x] MCP tool replaceability: `replace_default` config swaps built-in tools
|
||||
- [x] Plugin system: plugin.json manifest, install/enable/disable lifecycle
|
||||
- [x] `/batch` skill: decompose work into N units, spawn all via `spawn_elfs`, track progress (CC-inspired)
|
||||
- [x] Coordinator mode prompt: fan-out guidance for parallel elf dispatch, concurrency rules (read vs write)
|
||||
|
||||
|
||||
@@ -11,6 +11,40 @@ type Config struct {
|
||||
Security SecuritySection `toml:"security"`
|
||||
Session SessionSection `toml:"session"`
|
||||
Hooks []HookConfig `toml:"hooks"`
|
||||
MCPServers []MCPServerConfig `toml:"mcp_servers"`
|
||||
Plugins PluginsSection `toml:"plugins"`
|
||||
}
|
||||
|
||||
// MCPServerConfig defines an MCP server to start and connect to.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// [[mcp_servers]]
|
||||
// name = "git"
|
||||
// command = "mcp-server-git"
|
||||
// args = ["--repo", "."]
|
||||
// env = { GIT_DIR = ".git" }
|
||||
// timeout = "30s"
|
||||
// replace_default = ["bash"]
|
||||
type MCPServerConfig struct {
|
||||
Name string `toml:"name"`
|
||||
Command string `toml:"command"`
|
||||
Args []string `toml:"args"`
|
||||
Env map[string]string `toml:"env"`
|
||||
Timeout string `toml:"timeout"`
|
||||
ReplaceDefault []string `toml:"replace_default"`
|
||||
}
|
||||
|
||||
// PluginsSection controls plugin loading.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// [plugins]
|
||||
// enabled = ["git-tools", "docker-tools"]
|
||||
// disabled = ["experimental-plugin"]
|
||||
type PluginsSection struct {
|
||||
Enabled []string `toml:"enabled"`
|
||||
Disabled []string `toml:"disabled"`
|
||||
}
|
||||
|
||||
// HookConfig is a single hook entry from TOML config.
|
||||
|
||||
128
internal/mcp/client.go
Normal file
128
internal/mcp/client.go
Normal file
@@ -0,0 +1,128 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
)
|
||||
|
||||
const protocolVersion = "2024-11-05"
|
||||
|
||||
// ServerInfo describes the MCP server identity.
|
||||
type ServerInfo struct {
|
||||
Name string `json:"name"`
|
||||
Version string `json:"version"`
|
||||
}
|
||||
|
||||
// MCPTool is a tool definition discovered from an MCP server.
|
||||
type MCPTool struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
InputSchema json.RawMessage `json:"inputSchema"`
|
||||
}
|
||||
|
||||
// Client implements the MCP protocol lifecycle over a Transport.
|
||||
type Client struct {
|
||||
transport *Transport
|
||||
serverInfo ServerInfo
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
// NewClient creates an MCP client. Call Initialize before using other methods.
|
||||
func NewClient(transport *Transport, logger *slog.Logger) *Client {
|
||||
return &Client{
|
||||
transport: transport,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize performs the MCP handshake: sends initialize request,
|
||||
// receives server info, and sends initialized notification.
|
||||
func (c *Client) Initialize(ctx context.Context) error {
|
||||
params := struct {
|
||||
ProtocolVersion string `json:"protocolVersion"`
|
||||
Capabilities struct{} `json:"capabilities"`
|
||||
ClientInfo struct {
|
||||
Name string `json:"name"`
|
||||
Version string `json:"version"`
|
||||
} `json:"clientInfo"`
|
||||
}{
|
||||
ProtocolVersion: protocolVersion,
|
||||
}
|
||||
params.ClientInfo.Name = "gnoma"
|
||||
params.ClientInfo.Version = "0.1.0"
|
||||
|
||||
result, err := c.transport.Call(ctx, "initialize", params)
|
||||
if err != nil {
|
||||
return fmt.Errorf("mcp initialize: %w", err)
|
||||
}
|
||||
|
||||
var initResult struct {
|
||||
ProtocolVersion string `json:"protocolVersion"`
|
||||
ServerInfo ServerInfo `json:"serverInfo"`
|
||||
}
|
||||
if err := json.Unmarshal(result, &initResult); err != nil {
|
||||
return fmt.Errorf("mcp initialize: decode result: %w", err)
|
||||
}
|
||||
|
||||
c.serverInfo = initResult.ServerInfo
|
||||
c.logger.Debug("mcp initialized",
|
||||
"server", c.serverInfo.Name,
|
||||
"version", c.serverInfo.Version,
|
||||
"protocol", initResult.ProtocolVersion,
|
||||
)
|
||||
|
||||
// Send initialized notification (no response expected).
|
||||
if err := c.transport.Notify(ctx, "initialized", nil); err != nil {
|
||||
return fmt.Errorf("mcp initialized notification: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListTools calls tools/list and returns discovered tool definitions.
|
||||
func (c *Client) ListTools(ctx context.Context) ([]MCPTool, error) {
|
||||
result, err := c.transport.Call(ctx, "tools/list", nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("mcp tools/list: %w", err)
|
||||
}
|
||||
|
||||
var toolsResult struct {
|
||||
Tools []MCPTool `json:"tools"`
|
||||
}
|
||||
if err := json.Unmarshal(result, &toolsResult); err != nil {
|
||||
return nil, fmt.Errorf("mcp tools/list: decode: %w", err)
|
||||
}
|
||||
|
||||
c.logger.Debug("mcp tools discovered", "count", len(toolsResult.Tools))
|
||||
return toolsResult.Tools, nil
|
||||
}
|
||||
|
||||
// CallTool invokes a tool on the MCP server and returns the raw result.
|
||||
func (c *Client) CallTool(ctx context.Context, name string, args json.RawMessage) (json.RawMessage, error) {
|
||||
params := struct {
|
||||
Name string `json:"name"`
|
||||
Arguments json.RawMessage `json:"arguments,omitempty"`
|
||||
}{
|
||||
Name: name,
|
||||
Arguments: args,
|
||||
}
|
||||
|
||||
result, err := c.transport.Call(ctx, "tools/call", params)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("mcp tools/call %q: %w", name, err)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// ServerName returns the server's reported name.
|
||||
func (c *Client) ServerName() string {
|
||||
return c.serverInfo.Name
|
||||
}
|
||||
|
||||
// Close shuts down the transport and server process.
|
||||
func (c *Client) Close() error {
|
||||
return c.transport.Close()
|
||||
}
|
||||
219
internal/mcp/client_test.go
Normal file
219
internal/mcp/client_test.go
Normal file
@@ -0,0 +1,219 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// writeMCPServer creates a bash script that implements a minimal MCP server.
|
||||
// Response payloads are written to files to avoid bash quoting issues.
|
||||
func writeMCPServer(t *testing.T, tools []MCPTool, callResult string) string {
|
||||
t.Helper()
|
||||
dir := t.TempDir()
|
||||
|
||||
// Write response payloads as files.
|
||||
initResult := `{"protocolVersion":"2024-11-05","capabilities":{"tools":{}},"serverInfo":{"name":"test-server","version":"1.0.0"}}`
|
||||
os.WriteFile(filepath.Join(dir, "init.json"), []byte(initResult), 0o644)
|
||||
|
||||
toolsJSON, err := json.Marshal(struct {
|
||||
Tools []MCPTool `json:"tools"`
|
||||
}{Tools: tools})
|
||||
if err != nil {
|
||||
t.Fatalf("marshal tools: %v", err)
|
||||
}
|
||||
os.WriteFile(filepath.Join(dir, "tools.json"), toolsJSON, 0o644)
|
||||
os.WriteFile(filepath.Join(dir, "call.json"), []byte(callResult), 0o644)
|
||||
|
||||
// The script uses jq-free JSON construction: reads response payload from
|
||||
// file and wraps it in a JSON-RPC envelope using python (widely available).
|
||||
script := filepath.Join(dir, "mcp-server.sh")
|
||||
content := `#!/bin/bash
|
||||
DIR="` + dir + `"
|
||||
while IFS= read -r line; do
|
||||
method=$(echo "$line" | python3 -c "import sys,json; print(json.load(sys.stdin).get('method',''))" 2>/dev/null)
|
||||
id=$(echo "$line" | python3 -c "import sys,json; print(json.load(sys.stdin).get('id',0))" 2>/dev/null)
|
||||
|
||||
case "$method" in
|
||||
initialize)
|
||||
result=$(cat "$DIR/init.json")
|
||||
printf '{"jsonrpc":"2.0","id":%s,"result":%s}\n' "$id" "$result"
|
||||
;;
|
||||
initialized)
|
||||
;;
|
||||
tools/list)
|
||||
result=$(cat "$DIR/tools.json")
|
||||
printf '{"jsonrpc":"2.0","id":%s,"result":%s}\n' "$id" "$result"
|
||||
;;
|
||||
tools/call)
|
||||
result=$(cat "$DIR/call.json")
|
||||
printf '{"jsonrpc":"2.0","id":%s,"result":%s}\n' "$id" "$result"
|
||||
;;
|
||||
*)
|
||||
printf '{"jsonrpc":"2.0","id":%s,"error":{"code":-32601,"message":"method not found"}}\n' "$id"
|
||||
;;
|
||||
esac
|
||||
done
|
||||
`
|
||||
if err := os.WriteFile(script, []byte(content), 0o755); err != nil {
|
||||
t.Fatalf("write mcp server: %v", err)
|
||||
}
|
||||
return script
|
||||
}
|
||||
|
||||
func TestClient_Initialize(t *testing.T) {
|
||||
tools := []MCPTool{
|
||||
{Name: "echo", Description: "Echo input", InputSchema: json.RawMessage(`{"type":"object"}`)},
|
||||
}
|
||||
script := writeMCPServer(t, tools, `{}`)
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
|
||||
tr := NewTransport("bash", []string{script}, nil, logger)
|
||||
|
||||
ctx := context.Background()
|
||||
if err := tr.Start(ctx); err != nil {
|
||||
t.Fatalf("Start: %v", err)
|
||||
}
|
||||
|
||||
client := NewClient(tr, logger)
|
||||
defer client.Close()
|
||||
|
||||
if err := client.Initialize(ctx); err != nil {
|
||||
t.Fatalf("Initialize: %v", err)
|
||||
}
|
||||
|
||||
if client.serverInfo.Name != "test-server" {
|
||||
t.Errorf("serverInfo.Name = %q, want %q", client.serverInfo.Name, "test-server")
|
||||
}
|
||||
if client.serverInfo.Version != "1.0.0" {
|
||||
t.Errorf("serverInfo.Version = %q, want %q", client.serverInfo.Version, "1.0.0")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClient_ListTools(t *testing.T) {
|
||||
tools := []MCPTool{
|
||||
{
|
||||
Name: "get_status",
|
||||
Description: "Get git status",
|
||||
InputSchema: json.RawMessage(`{"type":"object","properties":{"path":{"type":"string"}}}`),
|
||||
},
|
||||
{
|
||||
Name: "commit",
|
||||
Description: "Create commit",
|
||||
InputSchema: json.RawMessage(`{"type":"object","properties":{"message":{"type":"string"}},"required":["message"]}`),
|
||||
},
|
||||
}
|
||||
script := writeMCPServer(t, tools, `{}`)
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
|
||||
tr := NewTransport("bash", []string{script}, nil, logger)
|
||||
|
||||
ctx := context.Background()
|
||||
if err := tr.Start(ctx); err != nil {
|
||||
t.Fatalf("Start: %v", err)
|
||||
}
|
||||
|
||||
client := NewClient(tr, logger)
|
||||
defer client.Close()
|
||||
|
||||
if err := client.Initialize(ctx); err != nil {
|
||||
t.Fatalf("Initialize: %v", err)
|
||||
}
|
||||
|
||||
got, err := client.ListTools(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("ListTools: %v", err)
|
||||
}
|
||||
|
||||
if len(got) != 2 {
|
||||
t.Fatalf("got %d tools, want 2", len(got))
|
||||
}
|
||||
if got[0].Name != "get_status" {
|
||||
t.Errorf("tool[0].Name = %q, want %q", got[0].Name, "get_status")
|
||||
}
|
||||
if got[1].Name != "commit" {
|
||||
t.Errorf("tool[1].Name = %q, want %q", got[1].Name, "commit")
|
||||
}
|
||||
// Verify InputSchema passes through as raw JSON.
|
||||
if string(got[0].InputSchema) == "" {
|
||||
t.Error("tool[0].InputSchema is empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClient_CallTool(t *testing.T) {
|
||||
tools := []MCPTool{
|
||||
{Name: "echo", Description: "Echo", InputSchema: json.RawMessage(`{"type":"object"}`)},
|
||||
}
|
||||
callResult := `{"content":[{"type":"text","text":"hello world"}]}`
|
||||
script := writeMCPServer(t, tools, callResult)
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
|
||||
tr := NewTransport("bash", []string{script}, nil, logger)
|
||||
|
||||
ctx := context.Background()
|
||||
if err := tr.Start(ctx); err != nil {
|
||||
t.Fatalf("Start: %v", err)
|
||||
}
|
||||
|
||||
client := NewClient(tr, logger)
|
||||
defer client.Close()
|
||||
|
||||
if err := client.Initialize(ctx); err != nil {
|
||||
t.Fatalf("Initialize: %v", err)
|
||||
}
|
||||
|
||||
result, err := client.CallTool(ctx, "echo", json.RawMessage(`{"input":"test"}`))
|
||||
if err != nil {
|
||||
t.Fatalf("CallTool: %v", err)
|
||||
}
|
||||
|
||||
// Result should be the raw content array.
|
||||
var parsed struct {
|
||||
Content []struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text"`
|
||||
} `json:"content"`
|
||||
}
|
||||
if err := json.Unmarshal(result, &parsed); err != nil {
|
||||
t.Fatalf("unmarshal result: %v", err)
|
||||
}
|
||||
if len(parsed.Content) != 1 {
|
||||
t.Fatalf("got %d content blocks, want 1", len(parsed.Content))
|
||||
}
|
||||
if parsed.Content[0].Text != "hello world" {
|
||||
t.Errorf("content text = %q, want %q", parsed.Content[0].Text, "hello world")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClient_InitializeFailure(t *testing.T) {
|
||||
// Server that returns an error for initialize.
|
||||
dir := t.TempDir()
|
||||
script := filepath.Join(dir, "bad-server.sh")
|
||||
content := `#!/bin/bash
|
||||
read -r line
|
||||
id=$(echo "$line" | grep -o '"id":[0-9]*' | head -1 | cut -d: -f2)
|
||||
echo "{\"jsonrpc\":\"2.0\",\"id\":$id,\"error\":{\"code\":-32000,\"message\":\"init failed\"}}"
|
||||
`
|
||||
if err := os.WriteFile(script, []byte(content), 0o755); err != nil {
|
||||
t.Fatalf("write: %v", err)
|
||||
}
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
|
||||
tr := NewTransport("bash", []string{script}, nil, logger)
|
||||
|
||||
ctx := context.Background()
|
||||
if err := tr.Start(ctx); err != nil {
|
||||
t.Fatalf("Start: %v", err)
|
||||
}
|
||||
|
||||
client := NewClient(tr, logger)
|
||||
defer client.Close()
|
||||
|
||||
err := client.Initialize(ctx)
|
||||
if err == nil {
|
||||
t.Fatal("expected Initialize to fail")
|
||||
}
|
||||
}
|
||||
60
internal/mcp/config.go
Normal file
60
internal/mcp/config.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"somegit.dev/Owlibou/gnoma/internal/config"
|
||||
)
|
||||
|
||||
const defaultTimeout = 30 * time.Second
|
||||
|
||||
// ServerConfig is the validated, parsed form of config.MCPServerConfig.
|
||||
type ServerConfig struct {
|
||||
Name string
|
||||
Command string
|
||||
Args []string
|
||||
Env map[string]string
|
||||
Timeout time.Duration
|
||||
ReplaceDefault []string
|
||||
}
|
||||
|
||||
// ParseServerConfigs validates and converts raw config entries.
|
||||
func ParseServerConfigs(raw []config.MCPServerConfig) ([]ServerConfig, error) {
|
||||
seen := make(map[string]bool, len(raw))
|
||||
result := make([]ServerConfig, 0, len(raw))
|
||||
|
||||
for i, r := range raw {
|
||||
if r.Name == "" {
|
||||
return nil, fmt.Errorf("mcp_servers[%d]: name is required", i)
|
||||
}
|
||||
if seen[r.Name] {
|
||||
return nil, fmt.Errorf("mcp_servers: duplicate name %q", r.Name)
|
||||
}
|
||||
seen[r.Name] = true
|
||||
|
||||
if r.Command == "" {
|
||||
return nil, fmt.Errorf("mcp_servers[%d] %q: command is required", i, r.Name)
|
||||
}
|
||||
|
||||
timeout := defaultTimeout
|
||||
if r.Timeout != "" {
|
||||
var err error
|
||||
timeout, err = time.ParseDuration(r.Timeout)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("mcp_servers[%d] %q: invalid timeout %q: %w", i, r.Name, r.Timeout, err)
|
||||
}
|
||||
}
|
||||
|
||||
result = append(result, ServerConfig{
|
||||
Name: r.Name,
|
||||
Command: r.Command,
|
||||
Args: r.Args,
|
||||
Env: r.Env,
|
||||
Timeout: timeout,
|
||||
ReplaceDefault: r.ReplaceDefault,
|
||||
})
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
84
internal/mcp/config_test.go
Normal file
84
internal/mcp/config_test.go
Normal file
@@ -0,0 +1,84 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"somegit.dev/Owlibou/gnoma/internal/config"
|
||||
)
|
||||
|
||||
func TestParseServerConfigs_Valid(t *testing.T) {
|
||||
raw := []config.MCPServerConfig{
|
||||
{
|
||||
Name: "git",
|
||||
Command: "mcp-server-git",
|
||||
Args: []string{"--repo", "."},
|
||||
Env: map[string]string{"GIT_DIR": ".git"},
|
||||
Timeout: "10s",
|
||||
ReplaceDefault: []string{"bash"},
|
||||
},
|
||||
{
|
||||
Name: "docker",
|
||||
Command: "mcp-server-docker",
|
||||
},
|
||||
}
|
||||
|
||||
got, err := ParseServerConfigs(raw)
|
||||
if err != nil {
|
||||
t.Fatalf("ParseServerConfigs: %v", err)
|
||||
}
|
||||
if len(got) != 2 {
|
||||
t.Fatalf("got %d configs, want 2", len(got))
|
||||
}
|
||||
|
||||
if got[0].Name != "git" {
|
||||
t.Errorf("config[0].Name = %q, want %q", got[0].Name, "git")
|
||||
}
|
||||
if got[0].Timeout != 10*time.Second {
|
||||
t.Errorf("config[0].Timeout = %v, want %v", got[0].Timeout, 10*time.Second)
|
||||
}
|
||||
if len(got[0].ReplaceDefault) != 1 || got[0].ReplaceDefault[0] != "bash" {
|
||||
t.Errorf("config[0].ReplaceDefault = %v, want [bash]", got[0].ReplaceDefault)
|
||||
}
|
||||
|
||||
// Second config should get default timeout.
|
||||
if got[1].Timeout != defaultTimeout {
|
||||
t.Errorf("config[1].Timeout = %v, want default %v", got[1].Timeout, defaultTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseServerConfigs_Errors(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
raw []config.MCPServerConfig
|
||||
}{
|
||||
{
|
||||
name: "missing name",
|
||||
raw: []config.MCPServerConfig{{Command: "foo"}},
|
||||
},
|
||||
{
|
||||
name: "missing command",
|
||||
raw: []config.MCPServerConfig{{Name: "foo"}},
|
||||
},
|
||||
{
|
||||
name: "duplicate name",
|
||||
raw: []config.MCPServerConfig{
|
||||
{Name: "foo", Command: "a"},
|
||||
{Name: "foo", Command: "b"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "bad timeout",
|
||||
raw: []config.MCPServerConfig{{Name: "foo", Command: "bar", Timeout: "not-a-duration"}},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := ParseServerConfigs(tt.raw)
|
||||
if err == nil {
|
||||
t.Error("expected error")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
40
internal/mcp/jsonrpc.go
Normal file
40
internal/mcp/jsonrpc.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// Request is a JSON-RPC 2.0 request.
|
||||
type Request struct {
|
||||
JSONRPC string `json:"jsonrpc"`
|
||||
ID int64 `json:"id"`
|
||||
Method string `json:"method"`
|
||||
Params json.RawMessage `json:"params,omitempty"`
|
||||
}
|
||||
|
||||
// Response is a JSON-RPC 2.0 response.
|
||||
type Response struct {
|
||||
JSONRPC string `json:"jsonrpc"`
|
||||
ID int64 `json:"id"`
|
||||
Result json.RawMessage `json:"result,omitempty"`
|
||||
Error *RPCError `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// Notification is a JSON-RPC 2.0 notification (no ID, no response expected).
|
||||
type Notification struct {
|
||||
JSONRPC string `json:"jsonrpc"`
|
||||
Method string `json:"method"`
|
||||
Params json.RawMessage `json:"params,omitempty"`
|
||||
}
|
||||
|
||||
// RPCError is the JSON-RPC 2.0 error object.
|
||||
type RPCError struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Data json.RawMessage `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
func (e *RPCError) Error() string {
|
||||
return fmt.Sprintf("rpc error %d: %s", e.Code, e.Message)
|
||||
}
|
||||
184
internal/mcp/jsonrpc_test.go
Normal file
184
internal/mcp/jsonrpc_test.go
Normal file
@@ -0,0 +1,184 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRequest_MarshalRoundtrip(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
req Request
|
||||
}{
|
||||
{
|
||||
name: "with params",
|
||||
req: Request{
|
||||
JSONRPC: "2.0",
|
||||
ID: 1,
|
||||
Method: "initialize",
|
||||
Params: json.RawMessage(`{"capabilities":{}}`),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "nil params omitted",
|
||||
req: Request{
|
||||
JSONRPC: "2.0",
|
||||
ID: 42,
|
||||
Method: "tools/list",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
data, err := json.Marshal(tt.req)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal: %v", err)
|
||||
}
|
||||
|
||||
var got Request
|
||||
if err := json.Unmarshal(data, &got); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
|
||||
if got.JSONRPC != tt.req.JSONRPC {
|
||||
t.Errorf("JSONRPC = %q, want %q", got.JSONRPC, tt.req.JSONRPC)
|
||||
}
|
||||
if got.ID != tt.req.ID {
|
||||
t.Errorf("ID = %d, want %d", got.ID, tt.req.ID)
|
||||
}
|
||||
if got.Method != tt.req.Method {
|
||||
t.Errorf("Method = %q, want %q", got.Method, tt.req.Method)
|
||||
}
|
||||
if string(got.Params) != string(tt.req.Params) {
|
||||
t.Errorf("Params = %s, want %s", got.Params, tt.req.Params)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestResponse_MarshalRoundtrip(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
resp Response
|
||||
}{
|
||||
{
|
||||
name: "success result",
|
||||
resp: Response{
|
||||
JSONRPC: "2.0",
|
||||
ID: 1,
|
||||
Result: json.RawMessage(`{"serverInfo":{"name":"test"}}`),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "error response",
|
||||
resp: Response{
|
||||
JSONRPC: "2.0",
|
||||
ID: 2,
|
||||
Error: &RPCError{
|
||||
Code: -32601,
|
||||
Message: "method not found",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "error with data",
|
||||
resp: Response{
|
||||
JSONRPC: "2.0",
|
||||
ID: 3,
|
||||
Error: &RPCError{
|
||||
Code: -32000,
|
||||
Message: "server error",
|
||||
Data: json.RawMessage(`{"detail":"something broke"}`),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
data, err := json.Marshal(tt.resp)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal: %v", err)
|
||||
}
|
||||
|
||||
var got Response
|
||||
if err := json.Unmarshal(data, &got); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
|
||||
if got.JSONRPC != tt.resp.JSONRPC {
|
||||
t.Errorf("JSONRPC = %q, want %q", got.JSONRPC, tt.resp.JSONRPC)
|
||||
}
|
||||
if got.ID != tt.resp.ID {
|
||||
t.Errorf("ID = %d, want %d", got.ID, tt.resp.ID)
|
||||
}
|
||||
if string(got.Result) != string(tt.resp.Result) {
|
||||
t.Errorf("Result = %s, want %s", got.Result, tt.resp.Result)
|
||||
}
|
||||
if (got.Error == nil) != (tt.resp.Error == nil) {
|
||||
t.Fatalf("Error nil mismatch: got %v, want %v", got.Error, tt.resp.Error)
|
||||
}
|
||||
if got.Error != nil {
|
||||
if got.Error.Code != tt.resp.Error.Code {
|
||||
t.Errorf("Error.Code = %d, want %d", got.Error.Code, tt.resp.Error.Code)
|
||||
}
|
||||
if got.Error.Message != tt.resp.Error.Message {
|
||||
t.Errorf("Error.Message = %q, want %q", got.Error.Message, tt.resp.Error.Message)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNotification_OmitsID(t *testing.T) {
|
||||
n := Notification{
|
||||
JSONRPC: "2.0",
|
||||
Method: "initialized",
|
||||
}
|
||||
|
||||
data, err := json.Marshal(n)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal: %v", err)
|
||||
}
|
||||
|
||||
// Notification must not have an "id" field.
|
||||
var raw map[string]json.RawMessage
|
||||
if err := json.Unmarshal(data, &raw); err != nil {
|
||||
t.Fatalf("unmarshal raw: %v", err)
|
||||
}
|
||||
if _, ok := raw["id"]; ok {
|
||||
t.Error("notification should not contain 'id' field")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRPCError_Error(t *testing.T) {
|
||||
e := &RPCError{Code: -32601, Message: "method not found"}
|
||||
got := e.Error()
|
||||
want := "rpc error -32601: method not found"
|
||||
if got != want {
|
||||
t.Errorf("Error() = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequest_NilParams_MarshalOmitsField(t *testing.T) {
|
||||
req := Request{
|
||||
JSONRPC: "2.0",
|
||||
ID: 1,
|
||||
Method: "tools/list",
|
||||
Params: nil,
|
||||
}
|
||||
|
||||
data, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal: %v", err)
|
||||
}
|
||||
|
||||
var raw map[string]json.RawMessage
|
||||
if err := json.Unmarshal(data, &raw); err != nil {
|
||||
t.Fatalf("unmarshal raw: %v", err)
|
||||
}
|
||||
if _, ok := raw["params"]; ok {
|
||||
t.Error("nil Params should be omitted from JSON")
|
||||
}
|
||||
}
|
||||
113
internal/mcp/manager.go
Normal file
113
internal/mcp/manager.go
Normal file
@@ -0,0 +1,113 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
|
||||
"somegit.dev/Owlibou/gnoma/internal/tool"
|
||||
)
|
||||
|
||||
// Manager coordinates multiple MCP server lifecycles and tool registration.
|
||||
type Manager struct {
|
||||
clients map[string]*Client
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
// NewManager creates an MCP manager.
|
||||
func NewManager(logger *slog.Logger) *Manager {
|
||||
return &Manager{
|
||||
clients: make(map[string]*Client),
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// StartAll starts all configured MCP servers, discovers tools, and registers
|
||||
// them in the tool registry. Servers start sequentially to simplify error handling.
|
||||
func (m *Manager) StartAll(ctx context.Context, servers []ServerConfig, registry *tool.Registry) error {
|
||||
for _, srv := range servers {
|
||||
client, err := m.startServer(ctx, srv)
|
||||
if err != nil {
|
||||
m.Shutdown() // clean up already-started servers
|
||||
return fmt.Errorf("mcp server %q: %w", srv.Name, err)
|
||||
}
|
||||
|
||||
tools, err := client.ListTools(ctx)
|
||||
if err != nil {
|
||||
m.Shutdown()
|
||||
return fmt.Errorf("mcp server %q: list tools: %w", srv.Name, err)
|
||||
}
|
||||
|
||||
m.registerTools(srv, tools, client, registry)
|
||||
m.clients[srv.Name] = client
|
||||
|
||||
m.logger.Info("mcp server started",
|
||||
"name", srv.Name,
|
||||
"tools", len(tools),
|
||||
"replace", srv.ReplaceDefault,
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Shutdown gracefully stops all MCP server processes.
|
||||
func (m *Manager) Shutdown() error {
|
||||
var firstErr error
|
||||
for name, client := range m.clients {
|
||||
if err := client.Close(); err != nil && firstErr == nil {
|
||||
firstErr = fmt.Errorf("mcp shutdown %q: %w", name, err)
|
||||
}
|
||||
}
|
||||
m.clients = make(map[string]*Client)
|
||||
return firstErr
|
||||
}
|
||||
|
||||
func (m *Manager) startServer(ctx context.Context, srv ServerConfig) (*Client, error) {
|
||||
tr := NewTransport(srv.Command, srv.Args, srv.Env, m.logger)
|
||||
|
||||
if err := tr.Start(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
client := NewClient(tr, m.logger)
|
||||
|
||||
initCtx, cancel := context.WithTimeout(ctx, srv.Timeout)
|
||||
defer cancel()
|
||||
|
||||
if err := client.Initialize(initCtx); err != nil {
|
||||
tr.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (m *Manager) registerTools(srv ServerConfig, tools []MCPTool, client *Client, registry *tool.Registry) {
|
||||
replaceSet := make(map[string]bool, len(srv.ReplaceDefault))
|
||||
for _, name := range srv.ReplaceDefault {
|
||||
replaceSet[name] = true
|
||||
}
|
||||
|
||||
for _, mt := range tools {
|
||||
adapter := NewAdapter(srv.Name, mt, client)
|
||||
|
||||
// Check if any replace_default entry matches this MCP tool.
|
||||
// Match by checking if the MCP tool name appears in a replace target,
|
||||
// or assign replacements in order.
|
||||
for _, replaceName := range srv.ReplaceDefault {
|
||||
if replaceSet[replaceName] {
|
||||
adapter.SetOverrideName(replaceName)
|
||||
delete(replaceSet, replaceName)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
registry.Register(adapter)
|
||||
m.logger.Debug("mcp tool registered",
|
||||
"name", adapter.Name(),
|
||||
"server", srv.Name,
|
||||
"mcp_name", mt.Name,
|
||||
)
|
||||
}
|
||||
}
|
||||
209
internal/mcp/manager_test.go
Normal file
209
internal/mcp/manager_test.go
Normal file
@@ -0,0 +1,209 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"log/slog"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"somegit.dev/Owlibou/gnoma/internal/tool"
|
||||
)
|
||||
|
||||
func TestManager_StartAll_RegistersTools(t *testing.T) {
|
||||
tools := []MCPTool{
|
||||
{Name: "status", Description: "Get status", InputSchema: json.RawMessage(`{"type":"object"}`)},
|
||||
{Name: "commit", Description: "Create commit", InputSchema: json.RawMessage(`{"type":"object"}`)},
|
||||
}
|
||||
callResult := `{"content":[{"type":"text","text":"ok"}]}`
|
||||
script := writeMCPServer(t, tools, callResult)
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
|
||||
reg := tool.NewRegistry()
|
||||
|
||||
mgr := NewManager(logger)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err := mgr.StartAll(ctx, []ServerConfig{
|
||||
{
|
||||
Name: "git",
|
||||
Command: "bash",
|
||||
Args: []string{script},
|
||||
Timeout: 5 * time.Second,
|
||||
},
|
||||
}, reg)
|
||||
if err != nil {
|
||||
t.Fatalf("StartAll: %v", err)
|
||||
}
|
||||
defer mgr.Shutdown()
|
||||
|
||||
// Tools should be registered with mcp__ prefix.
|
||||
if _, ok := reg.Get("mcp__git__status"); !ok {
|
||||
t.Error("mcp__git__status not found in registry")
|
||||
}
|
||||
if _, ok := reg.Get("mcp__git__commit"); !ok {
|
||||
t.Error("mcp__git__commit not found in registry")
|
||||
}
|
||||
}
|
||||
|
||||
func TestManager_StartAll_ReplaceDefault(t *testing.T) {
|
||||
tools := []MCPTool{
|
||||
{Name: "exec", Description: "Custom bash", InputSchema: json.RawMessage(`{"type":"object"}`)},
|
||||
}
|
||||
callResult := `{"content":[{"type":"text","text":"replaced"}]}`
|
||||
script := writeMCPServer(t, tools, callResult)
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
|
||||
reg := tool.NewRegistry()
|
||||
|
||||
// Register a mock built-in "bash" tool first.
|
||||
reg.Register(&mockTool{name: "bash"})
|
||||
|
||||
mgr := NewManager(logger)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err := mgr.StartAll(ctx, []ServerConfig{
|
||||
{
|
||||
Name: "custom",
|
||||
Command: "bash",
|
||||
Args: []string{script},
|
||||
Timeout: 5 * time.Second,
|
||||
ReplaceDefault: []string{"bash"},
|
||||
},
|
||||
}, reg)
|
||||
if err != nil {
|
||||
t.Fatalf("StartAll: %v", err)
|
||||
}
|
||||
defer mgr.Shutdown()
|
||||
|
||||
// The "bash" tool should now be the MCP adapter, not the mock.
|
||||
bashTool, ok := reg.Get("bash")
|
||||
if !ok {
|
||||
t.Fatal("bash tool not found after replace")
|
||||
}
|
||||
|
||||
adapter, ok := bashTool.(*Adapter)
|
||||
if !ok {
|
||||
t.Fatalf("bash tool is %T, want *Adapter", bashTool)
|
||||
}
|
||||
if adapter.mcpTool.Name != "exec" {
|
||||
t.Errorf("replaced tool's MCP name = %q, want %q", adapter.mcpTool.Name, "exec")
|
||||
}
|
||||
}
|
||||
|
||||
func TestManager_StartAll_BadCommand(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
|
||||
reg := tool.NewRegistry()
|
||||
|
||||
mgr := NewManager(logger)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err := mgr.StartAll(ctx, []ServerConfig{
|
||||
{
|
||||
Name: "bad",
|
||||
Command: "/nonexistent/binary/that/does/not/exist",
|
||||
Timeout: 2 * time.Second,
|
||||
},
|
||||
}, reg)
|
||||
if err == nil {
|
||||
t.Error("expected error for bad command")
|
||||
mgr.Shutdown()
|
||||
}
|
||||
}
|
||||
|
||||
func TestManager_Shutdown(t *testing.T) {
|
||||
tools := []MCPTool{
|
||||
{Name: "ping", Description: "Ping", InputSchema: json.RawMessage(`{"type":"object"}`)},
|
||||
}
|
||||
script := writeMCPServer(t, tools, `{"content":[]}`)
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
|
||||
reg := tool.NewRegistry()
|
||||
|
||||
mgr := NewManager(logger)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err := mgr.StartAll(ctx, []ServerConfig{
|
||||
{
|
||||
Name: "test",
|
||||
Command: "bash",
|
||||
Args: []string{script},
|
||||
Timeout: 5 * time.Second,
|
||||
},
|
||||
}, reg)
|
||||
if err != nil {
|
||||
t.Fatalf("StartAll: %v", err)
|
||||
}
|
||||
|
||||
// Shutdown should not error.
|
||||
if err := mgr.Shutdown(); err != nil {
|
||||
t.Errorf("Shutdown: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestManager_StartAll_ReplaceDefault_PicksMatchingTool(t *testing.T) {
|
||||
// Server has multiple tools, only one replaces a built-in.
|
||||
tools := []MCPTool{
|
||||
{Name: "read", Description: "Read file", InputSchema: json.RawMessage(`{"type":"object"}`)},
|
||||
{Name: "write", Description: "Write file", InputSchema: json.RawMessage(`{"type":"object"}`)},
|
||||
{Name: "extra", Description: "Extra tool", InputSchema: json.RawMessage(`{"type":"object"}`)},
|
||||
}
|
||||
script := writeMCPServer(t, tools, `{"content":[]}`)
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
|
||||
reg := tool.NewRegistry()
|
||||
reg.Register(&mockTool{name: "fs.read"})
|
||||
reg.Register(&mockTool{name: "fs.write"})
|
||||
|
||||
mgr := NewManager(logger)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err := mgr.StartAll(ctx, []ServerConfig{
|
||||
{
|
||||
Name: "custom-fs",
|
||||
Command: "bash",
|
||||
Args: []string{script},
|
||||
Timeout: 5 * time.Second,
|
||||
ReplaceDefault: []string{"fs.read", "fs.write"},
|
||||
},
|
||||
}, reg)
|
||||
if err != nil {
|
||||
t.Fatalf("StartAll: %v", err)
|
||||
}
|
||||
defer mgr.Shutdown()
|
||||
|
||||
// fs.read and fs.write should be replaced.
|
||||
if fsRead, ok := reg.Get("fs.read"); !ok {
|
||||
t.Error("fs.read not found")
|
||||
} else if _, ok := fsRead.(*Adapter); !ok {
|
||||
t.Error("fs.read should be replaced by MCP adapter")
|
||||
}
|
||||
if fsWrite, ok := reg.Get("fs.write"); !ok {
|
||||
t.Error("fs.write not found")
|
||||
} else if _, ok := fsWrite.(*Adapter); !ok {
|
||||
t.Error("fs.write should be replaced by MCP adapter")
|
||||
}
|
||||
// "extra" should be registered with mcp__ prefix.
|
||||
if _, ok := reg.Get("mcp__custom-fs__extra"); !ok {
|
||||
t.Error("mcp__custom-fs__extra not found in registry")
|
||||
}
|
||||
}
|
||||
|
||||
// mockTool is a minimal tool.Tool for testing registry replacement.
|
||||
type mockTool struct {
|
||||
name string
|
||||
}
|
||||
|
||||
func (m *mockTool) Name() string { return m.name }
|
||||
func (m *mockTool) Description() string { return "mock" }
|
||||
func (m *mockTool) Parameters() json.RawMessage { return json.RawMessage(`{}`) }
|
||||
func (m *mockTool) Execute(_ context.Context, _ json.RawMessage) (tool.Result, error) { return tool.Result{}, nil }
|
||||
func (m *mockTool) IsReadOnly() bool { return false }
|
||||
func (m *mockTool) IsDestructive() bool { return false }
|
||||
|
||||
113
internal/mcp/tool.go
Normal file
113
internal/mcp/tool.go
Normal file
@@ -0,0 +1,113 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"somegit.dev/Owlibou/gnoma/internal/tool"
|
||||
)
|
||||
|
||||
// Adapter wraps an MCPTool as a gnoma tool.Tool.
|
||||
type Adapter struct {
|
||||
serverName string
|
||||
mcpTool MCPTool
|
||||
client *Client
|
||||
overrideName string // non-empty when replacing a built-in
|
||||
}
|
||||
|
||||
// Compile-time interface checks.
|
||||
var (
|
||||
_ tool.Tool = (*Adapter)(nil)
|
||||
_ tool.DeferrableTool = (*Adapter)(nil)
|
||||
)
|
||||
|
||||
// NewAdapter creates a tool adapter for the given MCP tool.
|
||||
func NewAdapter(serverName string, mcpTool MCPTool, client *Client) *Adapter {
|
||||
return &Adapter{
|
||||
serverName: serverName,
|
||||
mcpTool: mcpTool,
|
||||
client: client,
|
||||
}
|
||||
}
|
||||
|
||||
// SetOverrideName sets a replacement name (used for replace_default).
|
||||
func (a *Adapter) SetOverrideName(name string) {
|
||||
a.overrideName = name
|
||||
}
|
||||
|
||||
// Name returns the tool name. Uses mcp__{server}__{tool} convention,
|
||||
// or the override name when replacing a built-in.
|
||||
func (a *Adapter) Name() string {
|
||||
if a.overrideName != "" {
|
||||
return a.overrideName
|
||||
}
|
||||
return fmt.Sprintf("mcp__%s__%s", a.serverName, a.mcpTool.Name)
|
||||
}
|
||||
|
||||
// Description returns the MCP tool's description.
|
||||
func (a *Adapter) Description() string {
|
||||
return a.mcpTool.Description
|
||||
}
|
||||
|
||||
// Parameters returns the MCP tool's input schema (zero-copy passthrough).
|
||||
func (a *Adapter) Parameters() json.RawMessage {
|
||||
return a.mcpTool.InputSchema
|
||||
}
|
||||
|
||||
// Execute calls the MCP server's tools/call method.
|
||||
func (a *Adapter) Execute(ctx context.Context, args json.RawMessage) (tool.Result, error) {
|
||||
result, err := a.client.CallTool(ctx, a.mcpTool.Name, args)
|
||||
if err != nil {
|
||||
// RPC errors are surfaced as tool output so the LLM can see them.
|
||||
var rpcErr *RPCError
|
||||
if errors.As(err, &rpcErr) {
|
||||
return tool.Result{
|
||||
Output: fmt.Sprintf("MCP error: %s", rpcErr.Message),
|
||||
}, nil
|
||||
}
|
||||
// Transport-level errors are Go errors (broken pipe, timeout).
|
||||
return tool.Result{}, err
|
||||
}
|
||||
|
||||
output, err := extractTextContent(result)
|
||||
if err != nil {
|
||||
return tool.Result{
|
||||
Output: fmt.Sprintf("MCP response parse error: %v\nRaw: %s", err, result),
|
||||
}, nil
|
||||
}
|
||||
|
||||
return tool.Result{Output: output}, nil
|
||||
}
|
||||
|
||||
// IsReadOnly returns false conservatively — MCP tools may have side effects.
|
||||
func (a *Adapter) IsReadOnly() bool { return false }
|
||||
|
||||
// IsDestructive returns false — use permission rules for granular control.
|
||||
func (a *Adapter) IsDestructive() bool { return false }
|
||||
|
||||
// ShouldDefer returns true — MCP tools start deferred to reduce token overhead.
|
||||
func (a *Adapter) ShouldDefer() bool { return true }
|
||||
|
||||
// extractTextContent concatenates text blocks from an MCP tools/call result.
|
||||
func extractTextContent(raw json.RawMessage) (string, error) {
|
||||
var result struct {
|
||||
Content []struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text"`
|
||||
} `json:"content"`
|
||||
}
|
||||
if err := json.Unmarshal(raw, &result); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
var parts []string
|
||||
for _, c := range result.Content {
|
||||
if c.Type == "text" {
|
||||
parts = append(parts, c.Text)
|
||||
}
|
||||
}
|
||||
return strings.Join(parts, "\n"), nil
|
||||
}
|
||||
205
internal/mcp/tool_test.go
Normal file
205
internal/mcp/tool_test.go
Normal file
@@ -0,0 +1,205 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestAdapter_Name_Default(t *testing.T) {
|
||||
a := NewAdapter("git", MCPTool{Name: "status"}, nil)
|
||||
want := "mcp__git__status"
|
||||
if got := a.Name(); got != want {
|
||||
t.Errorf("Name() = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdapter_Name_Override(t *testing.T) {
|
||||
a := NewAdapter("custom-fs", MCPTool{Name: "read_file"}, nil)
|
||||
a.SetOverrideName("fs.read")
|
||||
want := "fs.read"
|
||||
if got := a.Name(); got != want {
|
||||
t.Errorf("Name() = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdapter_NameConvention(t *testing.T) {
|
||||
tests := []struct {
|
||||
server string
|
||||
tool string
|
||||
want string
|
||||
}{
|
||||
{"git", "status", "mcp__git__status"},
|
||||
{"docker", "ps", "mcp__docker__ps"},
|
||||
{"my-server", "my-tool", "mcp__my-server__my-tool"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
a := NewAdapter(tt.server, MCPTool{Name: tt.tool}, nil)
|
||||
if got := a.Name(); got != tt.want {
|
||||
t.Errorf("Name(%q, %q) = %q, want %q", tt.server, tt.tool, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdapter_Description(t *testing.T) {
|
||||
a := NewAdapter("git", MCPTool{
|
||||
Name: "status",
|
||||
Description: "Get git status",
|
||||
}, nil)
|
||||
if got := a.Description(); got != "Get git status" {
|
||||
t.Errorf("Description() = %q, want %q", got, "Get git status")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdapter_Parameters(t *testing.T) {
|
||||
schema := json.RawMessage(`{"type":"object","properties":{"path":{"type":"string"}}}`)
|
||||
a := NewAdapter("git", MCPTool{
|
||||
Name: "status",
|
||||
InputSchema: schema,
|
||||
}, nil)
|
||||
|
||||
got := a.Parameters()
|
||||
if string(got) != string(schema) {
|
||||
t.Errorf("Parameters() = %s, want %s", got, schema)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdapter_IsReadOnly(t *testing.T) {
|
||||
a := NewAdapter("git", MCPTool{Name: "status"}, nil)
|
||||
if a.IsReadOnly() {
|
||||
t.Error("IsReadOnly() = true, want false (conservative default)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdapter_IsDestructive(t *testing.T) {
|
||||
a := NewAdapter("git", MCPTool{Name: "status"}, nil)
|
||||
if a.IsDestructive() {
|
||||
t.Error("IsDestructive() = true, want false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdapter_ShouldDefer(t *testing.T) {
|
||||
a := NewAdapter("git", MCPTool{Name: "status"}, nil)
|
||||
if !a.ShouldDefer() {
|
||||
t.Error("ShouldDefer() = false, want true (MCP tools start deferred)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdapter_Execute(t *testing.T) {
|
||||
callResult := `{"content":[{"type":"text","text":"On branch main\nnothing to commit"}]}`
|
||||
tools := []MCPTool{
|
||||
{Name: "status", Description: "Git status", InputSchema: json.RawMessage(`{"type":"object"}`)},
|
||||
}
|
||||
script := writeMCPServer(t, tools, callResult)
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
|
||||
tr := NewTransport("bash", []string{script}, nil, logger)
|
||||
|
||||
ctx := context.Background()
|
||||
if err := tr.Start(ctx); err != nil {
|
||||
t.Fatalf("Start: %v", err)
|
||||
}
|
||||
|
||||
client := NewClient(tr, logger)
|
||||
defer client.Close()
|
||||
|
||||
if err := client.Initialize(ctx); err != nil {
|
||||
t.Fatalf("Initialize: %v", err)
|
||||
}
|
||||
|
||||
a := NewAdapter("git", tools[0], client)
|
||||
|
||||
result, err := a.Execute(ctx, json.RawMessage(`{}`))
|
||||
if err != nil {
|
||||
t.Fatalf("Execute: %v", err)
|
||||
}
|
||||
|
||||
want := "On branch main\nnothing to commit"
|
||||
if result.Output != want {
|
||||
t.Errorf("Output = %q, want %q", result.Output, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdapter_Execute_MultipleTextBlocks(t *testing.T) {
|
||||
callResult := `{"content":[{"type":"text","text":"line 1"},{"type":"text","text":"line 2"}]}`
|
||||
tools := []MCPTool{
|
||||
{Name: "multi", Description: "Multi", InputSchema: json.RawMessage(`{"type":"object"}`)},
|
||||
}
|
||||
script := writeMCPServer(t, tools, callResult)
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
|
||||
tr := NewTransport("bash", []string{script}, nil, logger)
|
||||
|
||||
ctx := context.Background()
|
||||
if err := tr.Start(ctx); err != nil {
|
||||
t.Fatalf("Start: %v", err)
|
||||
}
|
||||
|
||||
client := NewClient(tr, logger)
|
||||
defer client.Close()
|
||||
|
||||
if err := client.Initialize(ctx); err != nil {
|
||||
t.Fatalf("Initialize: %v", err)
|
||||
}
|
||||
|
||||
a := NewAdapter("test", tools[0], client)
|
||||
result, err := a.Execute(ctx, json.RawMessage(`{}`))
|
||||
if err != nil {
|
||||
t.Fatalf("Execute: %v", err)
|
||||
}
|
||||
|
||||
want := "line 1\nline 2"
|
||||
if result.Output != want {
|
||||
t.Errorf("Output = %q, want %q", result.Output, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdapter_Execute_RPCError(t *testing.T) {
|
||||
// Server that returns an error for tools/call.
|
||||
dir := t.TempDir()
|
||||
script := filepath.Join(dir, "err-server.sh")
|
||||
content := `#!/bin/bash
|
||||
DIR="` + dir + `"
|
||||
while IFS= read -r line; do
|
||||
method=$(echo "$line" | python3 -c "import sys,json; print(json.load(sys.stdin).get('method',''))" 2>/dev/null)
|
||||
id=$(echo "$line" | python3 -c "import sys,json; print(json.load(sys.stdin).get('id',0))" 2>/dev/null)
|
||||
|
||||
case "$method" in
|
||||
initialize)
|
||||
printf '{"jsonrpc":"2.0","id":%s,"result":{"protocolVersion":"2024-11-05","serverInfo":{"name":"err","version":"1.0"}}}\n' "$id"
|
||||
;;
|
||||
initialized)
|
||||
;;
|
||||
tools/call)
|
||||
printf '{"jsonrpc":"2.0","id":%s,"error":{"code":-32000,"message":"tool failed"}}\n' "$id"
|
||||
;;
|
||||
esac
|
||||
done
|
||||
`
|
||||
os.WriteFile(script, []byte(content), 0o755)
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
|
||||
tr := NewTransport("bash", []string{script}, nil, logger)
|
||||
|
||||
ctx := context.Background()
|
||||
tr.Start(ctx)
|
||||
client := NewClient(tr, logger)
|
||||
defer client.Close()
|
||||
client.Initialize(ctx)
|
||||
|
||||
a := NewAdapter("err", MCPTool{Name: "broken", InputSchema: json.RawMessage(`{}`)}, client)
|
||||
result, err := a.Execute(ctx, json.RawMessage(`{}`))
|
||||
|
||||
// RPC errors should be returned as tool output (not Go errors)
|
||||
// so the LLM can see the failure and retry/explain.
|
||||
if err != nil {
|
||||
t.Fatalf("Execute returned Go error: %v", err)
|
||||
}
|
||||
if result.Output == "" {
|
||||
t.Error("expected non-empty error output")
|
||||
}
|
||||
}
|
||||
242
internal/mcp/transport.go
Normal file
242
internal/mcp/transport.go
Normal file
@@ -0,0 +1,242 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/exec"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
const maxStderrCapture = 64 << 10 // 64KB
|
||||
|
||||
// Transport manages a stdio connection to an MCP server process.
|
||||
type Transport struct {
|
||||
command string
|
||||
args []string
|
||||
env map[string]string
|
||||
|
||||
cmd *exec.Cmd
|
||||
stdin io.WriteCloser
|
||||
stdout *bufio.Reader
|
||||
stderr limitedWriter
|
||||
|
||||
nextID atomic.Int64
|
||||
mu sync.Mutex // serializes writes to stdin
|
||||
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
// NewTransport creates a transport for the given command. Call Start to spawn the process.
|
||||
func NewTransport(command string, args []string, env map[string]string, logger *slog.Logger) *Transport {
|
||||
return &Transport{
|
||||
command: command,
|
||||
args: args,
|
||||
env: env,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// Start spawns the MCP server process.
|
||||
func (t *Transport) Start(ctx context.Context) error {
|
||||
t.cmd = exec.CommandContext(ctx, t.command, t.args...)
|
||||
t.cmd.Env = t.buildEnv()
|
||||
// Create a new process group so Close can kill the entire tree.
|
||||
t.cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
|
||||
t.stderr = limitedWriter{max: maxStderrCapture}
|
||||
t.cmd.Stderr = &t.stderr
|
||||
|
||||
var err error
|
||||
t.stdin, err = t.cmd.StdinPipe()
|
||||
if err != nil {
|
||||
return fmt.Errorf("mcp: stdin pipe: %w", err)
|
||||
}
|
||||
|
||||
stdout, err := t.cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return fmt.Errorf("mcp: stdout pipe: %w", err)
|
||||
}
|
||||
t.stdout = bufio.NewReader(stdout)
|
||||
|
||||
if err := t.cmd.Start(); err != nil {
|
||||
return fmt.Errorf("mcp: start %q: %w", t.command, err)
|
||||
}
|
||||
|
||||
t.logger.Debug("mcp transport started", "command", t.command, "pid", t.cmd.Process.Pid)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Call sends a JSON-RPC request and waits for the response.
|
||||
func (t *Transport) Call(ctx context.Context, method string, params any) (json.RawMessage, error) {
|
||||
id := t.nextID.Add(1)
|
||||
|
||||
var rawParams json.RawMessage
|
||||
if params != nil {
|
||||
var err error
|
||||
rawParams, err = json.Marshal(params)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("mcp: marshal params: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
req := Request{
|
||||
JSONRPC: "2.0",
|
||||
ID: id,
|
||||
Method: method,
|
||||
Params: rawParams,
|
||||
}
|
||||
|
||||
if err := t.send(ctx, req); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return t.readResponse(ctx, id)
|
||||
}
|
||||
|
||||
// Notify sends a JSON-RPC notification (no response expected).
|
||||
func (t *Transport) Notify(ctx context.Context, method string, params any) error {
|
||||
var rawParams json.RawMessage
|
||||
if params != nil {
|
||||
var err error
|
||||
rawParams, err = json.Marshal(params)
|
||||
if err != nil {
|
||||
return fmt.Errorf("mcp: marshal params: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
notif := Notification{
|
||||
JSONRPC: "2.0",
|
||||
Method: method,
|
||||
Params: rawParams,
|
||||
}
|
||||
|
||||
return t.send(ctx, notif)
|
||||
}
|
||||
|
||||
// Close gracefully shuts down the server process.
|
||||
func (t *Transport) Close() error {
|
||||
if t.stdin != nil {
|
||||
t.stdin.Close()
|
||||
}
|
||||
|
||||
if t.cmd == nil || t.cmd.Process == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Give the process a chance to exit gracefully.
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- t.cmd.Wait()
|
||||
}()
|
||||
|
||||
// Try graceful exit first (stdin closed above), then escalate.
|
||||
select {
|
||||
case err := <-done:
|
||||
return err
|
||||
case <-time.After(2 * time.Second):
|
||||
}
|
||||
|
||||
// Graceful didn't work — kill the entire process group.
|
||||
t.logger.Warn("mcp: server did not exit, killing process group", "command", t.command)
|
||||
syscall.Kill(-t.cmd.Process.Pid, syscall.SIGKILL)
|
||||
return <-done
|
||||
}
|
||||
|
||||
// Stderr returns captured stderr output from the server process.
|
||||
func (t *Transport) Stderr() string {
|
||||
return t.stderr.String()
|
||||
}
|
||||
|
||||
func (t *Transport) send(ctx context.Context, v any) error {
|
||||
data, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return fmt.Errorf("mcp: marshal: %w", err)
|
||||
}
|
||||
data = append(data, '\n')
|
||||
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
if _, err := t.stdin.Write(data); err != nil {
|
||||
return fmt.Errorf("mcp: write: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *Transport) readResponse(ctx context.Context, expectedID int64) (json.RawMessage, error) {
|
||||
type readResult struct {
|
||||
line []byte
|
||||
err error
|
||||
}
|
||||
|
||||
ch := make(chan readResult, 1)
|
||||
go func() {
|
||||
line, err := t.stdout.ReadBytes('\n')
|
||||
ch <- readResult{line, err}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case rr := <-ch:
|
||||
if rr.err != nil {
|
||||
stderr := t.Stderr()
|
||||
if stderr != "" {
|
||||
return nil, fmt.Errorf("mcp: read: %w (stderr: %s)", rr.err, stderr)
|
||||
}
|
||||
return nil, fmt.Errorf("mcp: read: %w", rr.err)
|
||||
}
|
||||
|
||||
var resp Response
|
||||
if err := json.Unmarshal(rr.line, &resp); err != nil {
|
||||
return nil, fmt.Errorf("mcp: decode response: %w", err)
|
||||
}
|
||||
|
||||
if resp.Error != nil {
|
||||
return nil, resp.Error
|
||||
}
|
||||
|
||||
return resp.Result, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Transport) buildEnv() []string {
|
||||
env := os.Environ()
|
||||
for k, v := range t.env {
|
||||
env = append(env, k+"="+v)
|
||||
}
|
||||
return env
|
||||
}
|
||||
|
||||
// limitedWriter captures up to max bytes.
|
||||
type limitedWriter struct {
|
||||
buf bytes.Buffer
|
||||
max int
|
||||
}
|
||||
|
||||
func (w *limitedWriter) Write(p []byte) (int, error) {
|
||||
remaining := w.max - w.buf.Len()
|
||||
if remaining <= 0 {
|
||||
return len(p), nil // silently discard
|
||||
}
|
||||
if len(p) > remaining {
|
||||
p = p[:remaining]
|
||||
}
|
||||
return w.buf.Write(p)
|
||||
}
|
||||
|
||||
func (w *limitedWriter) String() string {
|
||||
return w.buf.String()
|
||||
}
|
||||
277
internal/mcp/transport_test.go
Normal file
277
internal/mcp/transport_test.go
Normal file
@@ -0,0 +1,277 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// writeMockServer creates a bash script that reads a JSON-RPC request from stdin
|
||||
// and writes a canned response to stdout. Returns the script path.
|
||||
func writeMockServer(t *testing.T, responseJSON string) string {
|
||||
t.Helper()
|
||||
dir := t.TempDir()
|
||||
script := filepath.Join(dir, "mock-server.sh")
|
||||
content := `#!/bin/bash
|
||||
read -r line
|
||||
echo '` + responseJSON + `'
|
||||
`
|
||||
if err := os.WriteFile(script, []byte(content), 0o755); err != nil {
|
||||
t.Fatalf("write mock server: %v", err)
|
||||
}
|
||||
return script
|
||||
}
|
||||
|
||||
// writeMockServerMulti creates a script that responds to each line of input
|
||||
// with a corresponding line of output from the provided responses.
|
||||
func writeMockServerMulti(t *testing.T, responses []string) string {
|
||||
t.Helper()
|
||||
dir := t.TempDir()
|
||||
script := filepath.Join(dir, "mock-server.sh")
|
||||
|
||||
var body string
|
||||
for _, r := range responses {
|
||||
body += `read -r line
|
||||
echo '` + r + `'
|
||||
`
|
||||
}
|
||||
|
||||
content := "#!/bin/bash\n" + body
|
||||
if err := os.WriteFile(script, []byte(content), 0o755); err != nil {
|
||||
t.Fatalf("write mock server: %v", err)
|
||||
}
|
||||
return script
|
||||
}
|
||||
|
||||
func testLogger() *slog.Logger {
|
||||
return slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
|
||||
}
|
||||
|
||||
func TestTransport_StartAndClose(t *testing.T) {
|
||||
script := writeMockServer(t, `{"jsonrpc":"2.0","id":1,"result":{}}`)
|
||||
tr := NewTransport("bash", []string{script}, nil, testLogger())
|
||||
|
||||
ctx := context.Background()
|
||||
if err := tr.Start(ctx); err != nil {
|
||||
t.Fatalf("Start: %v", err)
|
||||
}
|
||||
|
||||
if err := tr.Close(); err != nil {
|
||||
t.Fatalf("Close: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransport_Call_Success(t *testing.T) {
|
||||
resp := `{"jsonrpc":"2.0","id":1,"result":{"tools":[]}}`
|
||||
script := writeMockServer(t, resp)
|
||||
tr := NewTransport("bash", []string{script}, nil, testLogger())
|
||||
|
||||
ctx := context.Background()
|
||||
if err := tr.Start(ctx); err != nil {
|
||||
t.Fatalf("Start: %v", err)
|
||||
}
|
||||
defer tr.Close()
|
||||
|
||||
result, err := tr.Call(ctx, "tools/list", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Call: %v", err)
|
||||
}
|
||||
|
||||
var parsed struct {
|
||||
Tools []json.RawMessage `json:"tools"`
|
||||
}
|
||||
if err := json.Unmarshal(result, &parsed); err != nil {
|
||||
t.Fatalf("unmarshal result: %v", err)
|
||||
}
|
||||
if len(parsed.Tools) != 0 {
|
||||
t.Errorf("expected empty tools, got %d", len(parsed.Tools))
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransport_Call_RPCError(t *testing.T) {
|
||||
resp := `{"jsonrpc":"2.0","id":1,"error":{"code":-32601,"message":"method not found"}}`
|
||||
script := writeMockServer(t, resp)
|
||||
tr := NewTransport("bash", []string{script}, nil, testLogger())
|
||||
|
||||
ctx := context.Background()
|
||||
if err := tr.Start(ctx); err != nil {
|
||||
t.Fatalf("Start: %v", err)
|
||||
}
|
||||
defer tr.Close()
|
||||
|
||||
_, err := tr.Call(ctx, "nonexistent", nil)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for RPC error response")
|
||||
}
|
||||
|
||||
var rpcErr *RPCError
|
||||
if !errorAs(err, &rpcErr) {
|
||||
t.Fatalf("expected *RPCError, got %T: %v", err, err)
|
||||
}
|
||||
if rpcErr.Code != -32601 {
|
||||
t.Errorf("error code = %d, want -32601", rpcErr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransport_Call_Timeout(t *testing.T) {
|
||||
// Script that hangs forever.
|
||||
dir := t.TempDir()
|
||||
script := filepath.Join(dir, "hang.sh")
|
||||
if err := os.WriteFile(script, []byte("#!/bin/bash\nsleep 60\n"), 0o755); err != nil {
|
||||
t.Fatalf("write: %v", err)
|
||||
}
|
||||
|
||||
tr := NewTransport("bash", []string{script}, nil, testLogger())
|
||||
|
||||
ctx := context.Background()
|
||||
if err := tr.Start(ctx); err != nil {
|
||||
t.Fatalf("Start: %v", err)
|
||||
}
|
||||
defer tr.Close()
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
_, err := tr.Call(ctx, "tools/list", nil)
|
||||
if err == nil {
|
||||
t.Fatal("expected timeout error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransport_Call_EnvPassed(t *testing.T) {
|
||||
// Script that echoes an env var as the result.
|
||||
dir := t.TempDir()
|
||||
script := filepath.Join(dir, "env-echo.sh")
|
||||
content := `#!/bin/bash
|
||||
read -r line
|
||||
echo "{\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{\"val\":\"$TEST_MCP_VAR\"}}"
|
||||
`
|
||||
if err := os.WriteFile(script, []byte(content), 0o755); err != nil {
|
||||
t.Fatalf("write: %v", err)
|
||||
}
|
||||
|
||||
env := map[string]string{"TEST_MCP_VAR": "hello_mcp"}
|
||||
tr := NewTransport("bash", []string{script}, env, testLogger())
|
||||
|
||||
ctx := context.Background()
|
||||
if err := tr.Start(ctx); err != nil {
|
||||
t.Fatalf("Start: %v", err)
|
||||
}
|
||||
defer tr.Close()
|
||||
|
||||
result, err := tr.Call(ctx, "test", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Call: %v", err)
|
||||
}
|
||||
|
||||
var parsed struct {
|
||||
Val string `json:"val"`
|
||||
}
|
||||
if err := json.Unmarshal(result, &parsed); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
if parsed.Val != "hello_mcp" {
|
||||
t.Errorf("val = %q, want %q", parsed.Val, "hello_mcp")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransport_Notify(t *testing.T) {
|
||||
// Notification doesn't expect a response, so the script just reads and exits.
|
||||
dir := t.TempDir()
|
||||
script := filepath.Join(dir, "notify.sh")
|
||||
content := `#!/bin/bash
|
||||
read -r line
|
||||
# Write the received line to a file so we can verify it was sent.
|
||||
echo "$line" > "` + filepath.Join(dir, "received.json") + `"
|
||||
`
|
||||
if err := os.WriteFile(script, []byte(content), 0o755); err != nil {
|
||||
t.Fatalf("write: %v", err)
|
||||
}
|
||||
|
||||
tr := NewTransport("bash", []string{script}, nil, testLogger())
|
||||
|
||||
ctx := context.Background()
|
||||
if err := tr.Start(ctx); err != nil {
|
||||
t.Fatalf("Start: %v", err)
|
||||
}
|
||||
|
||||
if err := tr.Notify(ctx, "initialized", nil); err != nil {
|
||||
t.Fatalf("Notify: %v", err)
|
||||
}
|
||||
|
||||
// Give the script a moment to write the file.
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
tr.Close()
|
||||
|
||||
data, err := os.ReadFile(filepath.Join(dir, "received.json"))
|
||||
if err != nil {
|
||||
t.Fatalf("read received: %v", err)
|
||||
}
|
||||
|
||||
var notif Notification
|
||||
if err := json.Unmarshal(data, ¬if); err != nil {
|
||||
t.Fatalf("unmarshal notification: %v", err)
|
||||
}
|
||||
if notif.Method != "initialized" {
|
||||
t.Errorf("method = %q, want %q", notif.Method, "initialized")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransport_MultipleCalls(t *testing.T) {
|
||||
responses := []string{
|
||||
`{"jsonrpc":"2.0","id":1,"result":{"step":"first"}}`,
|
||||
`{"jsonrpc":"2.0","id":2,"result":{"step":"second"}}`,
|
||||
}
|
||||
script := writeMockServerMulti(t, responses)
|
||||
tr := NewTransport("bash", []string{script}, nil, testLogger())
|
||||
|
||||
ctx := context.Background()
|
||||
if err := tr.Start(ctx); err != nil {
|
||||
t.Fatalf("Start: %v", err)
|
||||
}
|
||||
defer tr.Close()
|
||||
|
||||
// First call.
|
||||
r1, err := tr.Call(ctx, "first", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Call 1: %v", err)
|
||||
}
|
||||
|
||||
var p1 struct{ Step string }
|
||||
json.Unmarshal(r1, &p1)
|
||||
if p1.Step != "first" {
|
||||
t.Errorf("call 1 step = %q, want %q", p1.Step, "first")
|
||||
}
|
||||
|
||||
// Second call.
|
||||
r2, err := tr.Call(ctx, "second", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Call 2: %v", err)
|
||||
}
|
||||
|
||||
var p2 struct{ Step string }
|
||||
json.Unmarshal(r2, &p2)
|
||||
if p2.Step != "second" {
|
||||
t.Errorf("call 2 step = %q, want %q", p2.Step, "second")
|
||||
}
|
||||
}
|
||||
|
||||
// errorAs is a local helper since errors.As requires a pointer to interface.
|
||||
func errorAs[T error](err error, target *T) bool {
|
||||
for err != nil {
|
||||
if t, ok := err.(T); ok {
|
||||
*target = t
|
||||
return true
|
||||
}
|
||||
if u, ok := err.(interface{ Unwrap() error }); ok {
|
||||
err = u.Unwrap()
|
||||
} else {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
12
internal/plugin/errors.go
Normal file
12
internal/plugin/errors.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package plugin
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
ErrManifestNotFound = errors.New("plugin: plugin.json not found")
|
||||
ErrManifestInvalid = errors.New("plugin: invalid manifest")
|
||||
ErrAlreadyInstalled = errors.New("plugin: already installed")
|
||||
ErrNotFound = errors.New("plugin: not found")
|
||||
ErrVersionMismatch = errors.New("plugin: gnoma version does not satisfy constraint")
|
||||
ErrPathTraversal = errors.New("plugin: path traversal detected")
|
||||
)
|
||||
156
internal/plugin/loader.go
Normal file
156
internal/plugin/loader.go
Normal file
@@ -0,0 +1,156 @@
|
||||
package plugin
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"somegit.dev/Owlibou/gnoma/internal/config"
|
||||
)
|
||||
|
||||
// Plugin is a discovered, parsed plugin.
|
||||
type Plugin struct {
|
||||
Manifest Manifest
|
||||
Dir string // absolute path to plugin directory
|
||||
Scope string // "user" or "project"
|
||||
}
|
||||
|
||||
// SkillSource is a directory + source tag for skill.Registry.LoadDir.
|
||||
type SkillSource struct {
|
||||
Dir string
|
||||
Source string
|
||||
}
|
||||
|
||||
// LoadResult contains the merged capabilities from all loaded plugins.
|
||||
type LoadResult struct {
|
||||
Skills []SkillSource
|
||||
Hooks []config.HookConfig
|
||||
MCPServers []config.MCPServerConfig
|
||||
}
|
||||
|
||||
// Loader discovers and loads plugins from directories.
|
||||
type Loader struct {
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
// NewLoader creates a plugin loader.
|
||||
func NewLoader(logger *slog.Logger) *Loader {
|
||||
return &Loader{logger: logger}
|
||||
}
|
||||
|
||||
// Discover scans global and project plugin directories, returning all valid plugins.
|
||||
// Project-scoped plugins override same-name global plugins.
|
||||
func (l *Loader) Discover(globalDir, projectDir string) ([]Plugin, error) {
|
||||
byName := make(map[string]Plugin)
|
||||
|
||||
// Global plugins first (user scope).
|
||||
l.scanDir(globalDir, "user", byName)
|
||||
|
||||
// Project plugins override global.
|
||||
l.scanDir(projectDir, "project", byName)
|
||||
|
||||
plugins := make([]Plugin, 0, len(byName))
|
||||
for _, p := range byName {
|
||||
plugins = append(plugins, p)
|
||||
}
|
||||
return plugins, nil
|
||||
}
|
||||
|
||||
// Load processes enabled plugins and extracts their capabilities.
|
||||
func (l *Loader) Load(plugins []Plugin, enabledSet map[string]bool) (LoadResult, error) {
|
||||
var result LoadResult
|
||||
|
||||
for _, p := range plugins {
|
||||
if !enabledSet[p.Manifest.Name] {
|
||||
l.logger.Debug("plugin disabled, skipping", "name", p.Manifest.Name)
|
||||
continue
|
||||
}
|
||||
|
||||
l.logger.Debug("loading plugin", "name", p.Manifest.Name, "scope", p.Scope)
|
||||
|
||||
// Skills: resolve glob directories.
|
||||
for _, glob := range p.Manifest.Capabilities.Skills {
|
||||
// Use the directory portion of the glob as the skill source dir.
|
||||
skillDir := filepath.Join(p.Dir, filepath.Dir(glob))
|
||||
result.Skills = append(result.Skills, SkillSource{
|
||||
Dir: skillDir,
|
||||
Source: fmt.Sprintf("plugin:%s", p.Manifest.Name),
|
||||
})
|
||||
}
|
||||
|
||||
// Hooks: convert to config.HookConfig with resolved paths.
|
||||
for _, h := range p.Manifest.Capabilities.Hooks {
|
||||
execPath := h.Exec
|
||||
if execPath != "" && !filepath.IsAbs(execPath) {
|
||||
execPath = filepath.Join(p.Dir, execPath)
|
||||
}
|
||||
result.Hooks = append(result.Hooks, config.HookConfig{
|
||||
Name: h.Name,
|
||||
Event: h.Event,
|
||||
Type: h.Type,
|
||||
Exec: execPath,
|
||||
Timeout: h.Timeout,
|
||||
FailOpen: h.FailOpen,
|
||||
ToolPattern: h.ToolPattern,
|
||||
})
|
||||
}
|
||||
|
||||
// MCP servers: convert with resolved command paths.
|
||||
for _, s := range p.Manifest.Capabilities.MCPServers {
|
||||
cmd := s.Command
|
||||
if cmd != "" && !filepath.IsAbs(cmd) {
|
||||
cmd = filepath.Join(p.Dir, cmd)
|
||||
}
|
||||
result.MCPServers = append(result.MCPServers, config.MCPServerConfig{
|
||||
Name: s.Name,
|
||||
Command: cmd,
|
||||
Args: s.Args,
|
||||
Env: s.Env,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (l *Loader) scanDir(dir, scope string, byName map[string]Plugin) {
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
// Missing directory is fine (not all users have plugins).
|
||||
return
|
||||
}
|
||||
|
||||
for _, entry := range entries {
|
||||
if !entry.IsDir() {
|
||||
continue
|
||||
}
|
||||
|
||||
pluginDir := filepath.Join(dir, entry.Name())
|
||||
manifestPath := filepath.Join(pluginDir, "plugin.json")
|
||||
|
||||
data, err := os.ReadFile(manifestPath)
|
||||
if err != nil {
|
||||
l.logger.Debug("skipping plugin dir (no manifest)", "dir", pluginDir)
|
||||
continue
|
||||
}
|
||||
|
||||
manifest, err := ParseManifest(data)
|
||||
if err != nil {
|
||||
l.logger.Warn("skipping plugin (invalid manifest)", "dir", pluginDir, "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
byName[manifest.Name] = Plugin{
|
||||
Manifest: *manifest,
|
||||
Dir: pluginDir,
|
||||
Scope: scope,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// marshalJSON is a thin wrapper for tests.
|
||||
func marshalJSON(v any) ([]byte, error) {
|
||||
return json.Marshal(v)
|
||||
}
|
||||
262
internal/plugin/loader_test.go
Normal file
262
internal/plugin/loader_test.go
Normal file
@@ -0,0 +1,262 @@
|
||||
package plugin
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func testLogger() *slog.Logger {
|
||||
return slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
|
||||
}
|
||||
|
||||
// writePlugin creates a plugin directory with a plugin.json manifest.
|
||||
func writePlugin(t *testing.T, dir, name, version string, caps *Capabilities) {
|
||||
t.Helper()
|
||||
pluginDir := filepath.Join(dir, name)
|
||||
if err := os.MkdirAll(pluginDir, 0o755); err != nil {
|
||||
t.Fatalf("mkdir: %v", err)
|
||||
}
|
||||
|
||||
m := Manifest{Name: name, Version: version}
|
||||
if caps != nil {
|
||||
m.Capabilities = *caps
|
||||
}
|
||||
|
||||
data, _ := marshalManifest(m)
|
||||
if err := os.WriteFile(filepath.Join(pluginDir, "plugin.json"), data, 0o644); err != nil {
|
||||
t.Fatalf("write manifest: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// writePluginWithSkill creates a plugin with a skill file.
|
||||
func writePluginWithSkill(t *testing.T, dir, pluginName, skillName, skillContent string) {
|
||||
t.Helper()
|
||||
pluginDir := filepath.Join(dir, pluginName)
|
||||
skillsDir := filepath.Join(pluginDir, "skills")
|
||||
os.MkdirAll(skillsDir, 0o755)
|
||||
|
||||
m := Manifest{
|
||||
Name: pluginName,
|
||||
Version: "1.0.0",
|
||||
Capabilities: Capabilities{
|
||||
Skills: []string{"skills/*.md"},
|
||||
},
|
||||
}
|
||||
data, _ := marshalManifest(m)
|
||||
os.WriteFile(filepath.Join(pluginDir, "plugin.json"), data, 0o644)
|
||||
os.WriteFile(filepath.Join(skillsDir, skillName+".md"), []byte(skillContent), 0o644)
|
||||
}
|
||||
|
||||
func marshalManifest(m Manifest) ([]byte, error) {
|
||||
return marshalJSON(m)
|
||||
}
|
||||
|
||||
func TestLoader_Discover_Empty(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
loader := NewLoader(testLogger())
|
||||
|
||||
plugins, err := loader.Discover(filepath.Join(dir, "global"), filepath.Join(dir, "project"))
|
||||
if err != nil {
|
||||
t.Fatalf("Discover: %v", err)
|
||||
}
|
||||
if len(plugins) != 0 {
|
||||
t.Errorf("expected 0 plugins, got %d", len(plugins))
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoader_Discover_GlobalPlugin(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
globalDir := filepath.Join(dir, "global")
|
||||
writePlugin(t, globalDir, "git-tools", "1.0.0", nil)
|
||||
|
||||
loader := NewLoader(testLogger())
|
||||
plugins, err := loader.Discover(globalDir, filepath.Join(dir, "project"))
|
||||
if err != nil {
|
||||
t.Fatalf("Discover: %v", err)
|
||||
}
|
||||
|
||||
if len(plugins) != 1 {
|
||||
t.Fatalf("expected 1 plugin, got %d", len(plugins))
|
||||
}
|
||||
if plugins[0].Manifest.Name != "git-tools" {
|
||||
t.Errorf("Name = %q, want %q", plugins[0].Manifest.Name, "git-tools")
|
||||
}
|
||||
if plugins[0].Scope != "user" {
|
||||
t.Errorf("Scope = %q, want %q", plugins[0].Scope, "user")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoader_Discover_ProjectOverridesGlobal(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
globalDir := filepath.Join(dir, "global")
|
||||
projectDir := filepath.Join(dir, "project")
|
||||
|
||||
writePlugin(t, globalDir, "shared", "1.0.0", nil)
|
||||
writePlugin(t, projectDir, "shared", "2.0.0", nil)
|
||||
|
||||
loader := NewLoader(testLogger())
|
||||
plugins, err := loader.Discover(globalDir, projectDir)
|
||||
if err != nil {
|
||||
t.Fatalf("Discover: %v", err)
|
||||
}
|
||||
|
||||
if len(plugins) != 1 {
|
||||
t.Fatalf("expected 1 plugin (deduplicated), got %d", len(plugins))
|
||||
}
|
||||
if plugins[0].Manifest.Version != "2.0.0" {
|
||||
t.Errorf("Version = %q, want %q (project should override global)", plugins[0].Manifest.Version, "2.0.0")
|
||||
}
|
||||
if plugins[0].Scope != "project" {
|
||||
t.Errorf("Scope = %q, want %q", plugins[0].Scope, "project")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoader_Discover_SkipsInvalidManifest(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
globalDir := filepath.Join(dir, "global")
|
||||
|
||||
// Write a valid plugin.
|
||||
writePlugin(t, globalDir, "good", "1.0.0", nil)
|
||||
|
||||
// Write an invalid plugin (bad JSON).
|
||||
badDir := filepath.Join(globalDir, "bad")
|
||||
os.MkdirAll(badDir, 0o755)
|
||||
os.WriteFile(filepath.Join(badDir, "plugin.json"), []byte(`{invalid`), 0o644)
|
||||
|
||||
loader := NewLoader(testLogger())
|
||||
plugins, err := loader.Discover(globalDir, filepath.Join(dir, "project"))
|
||||
if err != nil {
|
||||
t.Fatalf("Discover: %v", err)
|
||||
}
|
||||
|
||||
if len(plugins) != 1 {
|
||||
t.Fatalf("expected 1 plugin (skipping invalid), got %d", len(plugins))
|
||||
}
|
||||
if plugins[0].Manifest.Name != "good" {
|
||||
t.Errorf("Name = %q, want %q", plugins[0].Manifest.Name, "good")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoader_Load_AllEnabled(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
globalDir := filepath.Join(dir, "global")
|
||||
writePluginWithSkill(t, globalDir, "test-plugin", "my-skill", "---\nname: my-skill\n---\nHello")
|
||||
|
||||
loader := NewLoader(testLogger())
|
||||
plugins, _ := loader.Discover(globalDir, filepath.Join(dir, "project"))
|
||||
|
||||
enabledSet := map[string]bool{"test-plugin": true}
|
||||
result, err := loader.Load(plugins, enabledSet)
|
||||
if err != nil {
|
||||
t.Fatalf("Load: %v", err)
|
||||
}
|
||||
|
||||
if len(result.Skills) != 1 {
|
||||
t.Fatalf("expected 1 skill source, got %d", len(result.Skills))
|
||||
}
|
||||
if result.Skills[0].Source != "plugin:test-plugin" {
|
||||
t.Errorf("Source = %q, want %q", result.Skills[0].Source, "plugin:test-plugin")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoader_Load_DisabledPlugin(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
globalDir := filepath.Join(dir, "global")
|
||||
writePluginWithSkill(t, globalDir, "disabled-plugin", "skill", "---\nname: skill\n---\nHi")
|
||||
|
||||
loader := NewLoader(testLogger())
|
||||
plugins, _ := loader.Discover(globalDir, filepath.Join(dir, "project"))
|
||||
|
||||
// Plugin not in enabled set.
|
||||
result, err := loader.Load(plugins, map[string]bool{})
|
||||
if err != nil {
|
||||
t.Fatalf("Load: %v", err)
|
||||
}
|
||||
|
||||
if len(result.Skills) != 0 {
|
||||
t.Errorf("expected 0 skills for disabled plugin, got %d", len(result.Skills))
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoader_Load_HooksConverted(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
globalDir := filepath.Join(dir, "global")
|
||||
|
||||
caps := &Capabilities{
|
||||
Hooks: []HookSpec{
|
||||
{
|
||||
Name: "check",
|
||||
Event: "pre_tool_use",
|
||||
Type: "command",
|
||||
Exec: "scripts/check.sh",
|
||||
ToolPattern: "bash*",
|
||||
},
|
||||
},
|
||||
}
|
||||
writePlugin(t, globalDir, "hook-plugin", "1.0.0", caps)
|
||||
|
||||
loader := NewLoader(testLogger())
|
||||
plugins, _ := loader.Discover(globalDir, filepath.Join(dir, "project"))
|
||||
|
||||
result, err := loader.Load(plugins, map[string]bool{"hook-plugin": true})
|
||||
if err != nil {
|
||||
t.Fatalf("Load: %v", err)
|
||||
}
|
||||
|
||||
if len(result.Hooks) != 1 {
|
||||
t.Fatalf("expected 1 hook, got %d", len(result.Hooks))
|
||||
}
|
||||
h := result.Hooks[0]
|
||||
if h.Name != "check" {
|
||||
t.Errorf("Hook name = %q", h.Name)
|
||||
}
|
||||
if h.Event != "pre_tool_use" {
|
||||
t.Errorf("Hook event = %q", h.Event)
|
||||
}
|
||||
// Exec should be resolved to absolute path under plugin dir.
|
||||
pluginDir := filepath.Join(globalDir, "hook-plugin")
|
||||
wantExec := filepath.Join(pluginDir, "scripts/check.sh")
|
||||
if h.Exec != wantExec {
|
||||
t.Errorf("Hook exec = %q, want %q", h.Exec, wantExec)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoader_Load_MCPServersConverted(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
globalDir := filepath.Join(dir, "global")
|
||||
|
||||
caps := &Capabilities{
|
||||
MCPServers: []MCPServerSpec{
|
||||
{
|
||||
Name: "git",
|
||||
Command: "bin/mcp-git",
|
||||
Args: []string{"--verbose"},
|
||||
},
|
||||
},
|
||||
}
|
||||
writePlugin(t, globalDir, "mcp-plugin", "1.0.0", caps)
|
||||
|
||||
loader := NewLoader(testLogger())
|
||||
plugins, _ := loader.Discover(globalDir, filepath.Join(dir, "project"))
|
||||
|
||||
result, err := loader.Load(plugins, map[string]bool{"mcp-plugin": true})
|
||||
if err != nil {
|
||||
t.Fatalf("Load: %v", err)
|
||||
}
|
||||
|
||||
if len(result.MCPServers) != 1 {
|
||||
t.Fatalf("expected 1 MCP server, got %d", len(result.MCPServers))
|
||||
}
|
||||
s := result.MCPServers[0]
|
||||
if s.Name != "git" {
|
||||
t.Errorf("Name = %q", s.Name)
|
||||
}
|
||||
// Command should be absolute path.
|
||||
pluginDir := filepath.Join(globalDir, "mcp-plugin")
|
||||
wantCmd := filepath.Join(pluginDir, "bin/mcp-git")
|
||||
if s.Command != wantCmd {
|
||||
t.Errorf("Command = %q, want %q", s.Command, wantCmd)
|
||||
}
|
||||
}
|
||||
154
internal/plugin/manager.go
Normal file
154
internal/plugin/manager.go
Normal file
@@ -0,0 +1,154 @@
|
||||
package plugin
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
// PluginInfo is a lightweight summary for listing plugins.
|
||||
type PluginInfo struct {
|
||||
Name string
|
||||
Version string
|
||||
Scope string
|
||||
Dir string
|
||||
}
|
||||
|
||||
// Manager handles plugin install/uninstall lifecycle.
|
||||
type Manager struct {
|
||||
globalDir string
|
||||
projectDir string
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
// NewManager creates a plugin manager.
|
||||
func NewManager(globalDir, projectDir string, logger *slog.Logger) *Manager {
|
||||
return &Manager{
|
||||
globalDir: globalDir,
|
||||
projectDir: projectDir,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// Install copies a plugin from src to the target scope directory.
|
||||
func (m *Manager) Install(src, scope string) error {
|
||||
manifestPath := filepath.Join(src, "plugin.json")
|
||||
data, err := os.ReadFile(manifestPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %v", ErrManifestNotFound, err)
|
||||
}
|
||||
|
||||
manifest, err := ParseManifest(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
targetDir := m.scopeDir(scope)
|
||||
destDir := filepath.Join(targetDir, manifest.Name)
|
||||
|
||||
if _, err := os.Stat(destDir); err == nil {
|
||||
return fmt.Errorf("%w: %q already exists at %s", ErrAlreadyInstalled, manifest.Name, destDir)
|
||||
}
|
||||
|
||||
if err := copyDir(src, destDir); err != nil {
|
||||
return fmt.Errorf("plugin install %q: %w", manifest.Name, err)
|
||||
}
|
||||
|
||||
m.logger.Info("plugin installed", "name", manifest.Name, "version", manifest.Version, "scope", scope)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Uninstall removes a plugin directory.
|
||||
func (m *Manager) Uninstall(name, scope string) error {
|
||||
targetDir := filepath.Join(m.scopeDir(scope), name)
|
||||
|
||||
if _, err := os.Stat(targetDir); os.IsNotExist(err) {
|
||||
return fmt.Errorf("%w: %q not found in %s scope", ErrNotFound, name, scope)
|
||||
}
|
||||
|
||||
if err := os.RemoveAll(targetDir); err != nil {
|
||||
return fmt.Errorf("plugin uninstall %q: %w", name, err)
|
||||
}
|
||||
|
||||
m.logger.Info("plugin uninstalled", "name", name, "scope", scope)
|
||||
return nil
|
||||
}
|
||||
|
||||
// List returns info about all installed plugins across both scopes.
|
||||
func (m *Manager) List() ([]PluginInfo, error) {
|
||||
var infos []PluginInfo
|
||||
m.listDir(m.globalDir, "user", &infos)
|
||||
m.listDir(m.projectDir, "project", &infos)
|
||||
return infos, nil
|
||||
}
|
||||
|
||||
func (m *Manager) scopeDir(scope string) string {
|
||||
if scope == "project" {
|
||||
return m.projectDir
|
||||
}
|
||||
return m.globalDir
|
||||
}
|
||||
|
||||
func (m *Manager) listDir(dir, scope string, infos *[]PluginInfo) {
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
for _, entry := range entries {
|
||||
if !entry.IsDir() {
|
||||
continue
|
||||
}
|
||||
|
||||
pluginDir := filepath.Join(dir, entry.Name())
|
||||
data, err := os.ReadFile(filepath.Join(pluginDir, "plugin.json"))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
manifest, err := ParseManifest(data)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
*infos = append(*infos, PluginInfo{
|
||||
Name: manifest.Name,
|
||||
Version: manifest.Version,
|
||||
Scope: scope,
|
||||
Dir: pluginDir,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// copyDir recursively copies a directory.
|
||||
func copyDir(src, dst string) error {
|
||||
return filepath.WalkDir(src, func(path string, d fs.DirEntry, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
relPath, err := filepath.Rel(src, path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
targetPath := filepath.Join(dst, relPath)
|
||||
|
||||
if d.IsDir() {
|
||||
return os.MkdirAll(targetPath, 0o755)
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
info, err := d.Info()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return os.WriteFile(targetPath, data, info.Mode())
|
||||
})
|
||||
}
|
||||
160
internal/plugin/manager_test.go
Normal file
160
internal/plugin/manager_test.go
Normal file
@@ -0,0 +1,160 @@
|
||||
package plugin
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestManager_Install(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
globalDir := filepath.Join(dir, "global")
|
||||
projectDir := filepath.Join(dir, "project")
|
||||
os.MkdirAll(globalDir, 0o755)
|
||||
os.MkdirAll(projectDir, 0o755)
|
||||
|
||||
// Create a source plugin directory.
|
||||
srcDir := filepath.Join(dir, "src", "my-plugin")
|
||||
os.MkdirAll(srcDir, 0o755)
|
||||
m := Manifest{Name: "my-plugin", Version: "1.0.0", Description: "Test plugin"}
|
||||
data, _ := marshalJSON(m)
|
||||
os.WriteFile(filepath.Join(srcDir, "plugin.json"), data, 0o644)
|
||||
|
||||
mgr := NewManager(globalDir, projectDir, testLogger())
|
||||
|
||||
// Install to user scope.
|
||||
if err := mgr.Install(srcDir, "user"); err != nil {
|
||||
t.Fatalf("Install: %v", err)
|
||||
}
|
||||
|
||||
// Verify the plugin was copied.
|
||||
installed := filepath.Join(globalDir, "my-plugin", "plugin.json")
|
||||
if _, err := os.Stat(installed); err != nil {
|
||||
t.Errorf("installed manifest not found: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestManager_Install_ProjectScope(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
globalDir := filepath.Join(dir, "global")
|
||||
projectDir := filepath.Join(dir, "project")
|
||||
os.MkdirAll(globalDir, 0o755)
|
||||
os.MkdirAll(projectDir, 0o755)
|
||||
|
||||
srcDir := filepath.Join(dir, "src", "proj-plugin")
|
||||
os.MkdirAll(srcDir, 0o755)
|
||||
m := Manifest{Name: "proj-plugin", Version: "1.0.0"}
|
||||
data, _ := marshalJSON(m)
|
||||
os.WriteFile(filepath.Join(srcDir, "plugin.json"), data, 0o644)
|
||||
|
||||
mgr := NewManager(globalDir, projectDir, testLogger())
|
||||
|
||||
if err := mgr.Install(srcDir, "project"); err != nil {
|
||||
t.Fatalf("Install: %v", err)
|
||||
}
|
||||
|
||||
installed := filepath.Join(projectDir, "proj-plugin", "plugin.json")
|
||||
if _, err := os.Stat(installed); err != nil {
|
||||
t.Errorf("installed manifest not found: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestManager_Install_AlreadyInstalled(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
globalDir := filepath.Join(dir, "global")
|
||||
os.MkdirAll(globalDir, 0o755)
|
||||
|
||||
srcDir := filepath.Join(dir, "src", "dup")
|
||||
os.MkdirAll(srcDir, 0o755)
|
||||
m := Manifest{Name: "dup", Version: "1.0.0"}
|
||||
data, _ := marshalJSON(m)
|
||||
os.WriteFile(filepath.Join(srcDir, "plugin.json"), data, 0o644)
|
||||
|
||||
mgr := NewManager(globalDir, filepath.Join(dir, "project"), testLogger())
|
||||
|
||||
// First install.
|
||||
mgr.Install(srcDir, "user")
|
||||
|
||||
// Second install should fail.
|
||||
err := mgr.Install(srcDir, "user")
|
||||
if err == nil {
|
||||
t.Error("expected ErrAlreadyInstalled")
|
||||
}
|
||||
}
|
||||
|
||||
func TestManager_Install_NoManifest(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
globalDir := filepath.Join(dir, "global")
|
||||
os.MkdirAll(globalDir, 0o755)
|
||||
|
||||
srcDir := filepath.Join(dir, "src", "empty")
|
||||
os.MkdirAll(srcDir, 0o755)
|
||||
|
||||
mgr := NewManager(globalDir, filepath.Join(dir, "project"), testLogger())
|
||||
err := mgr.Install(srcDir, "user")
|
||||
if err == nil {
|
||||
t.Error("expected error for missing manifest")
|
||||
}
|
||||
}
|
||||
|
||||
func TestManager_Uninstall(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
globalDir := filepath.Join(dir, "global")
|
||||
os.MkdirAll(globalDir, 0o755)
|
||||
|
||||
// Pre-install a plugin.
|
||||
pluginDir := filepath.Join(globalDir, "to-remove")
|
||||
os.MkdirAll(pluginDir, 0o755)
|
||||
m := Manifest{Name: "to-remove", Version: "1.0.0"}
|
||||
data, _ := marshalJSON(m)
|
||||
os.WriteFile(filepath.Join(pluginDir, "plugin.json"), data, 0o644)
|
||||
|
||||
mgr := NewManager(globalDir, filepath.Join(dir, "project"), testLogger())
|
||||
|
||||
if err := mgr.Uninstall("to-remove", "user"); err != nil {
|
||||
t.Fatalf("Uninstall: %v", err)
|
||||
}
|
||||
|
||||
if _, err := os.Stat(pluginDir); !os.IsNotExist(err) {
|
||||
t.Error("plugin directory should be removed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestManager_Uninstall_NotFound(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
mgr := NewManager(filepath.Join(dir, "global"), filepath.Join(dir, "project"), testLogger())
|
||||
|
||||
err := mgr.Uninstall("nonexistent", "user")
|
||||
if err == nil {
|
||||
t.Error("expected ErrNotFound")
|
||||
}
|
||||
}
|
||||
|
||||
func TestManager_List(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
globalDir := filepath.Join(dir, "global")
|
||||
projectDir := filepath.Join(dir, "project")
|
||||
|
||||
writePlugin(t, globalDir, "global-plugin", "1.0.0", nil)
|
||||
writePlugin(t, projectDir, "project-plugin", "2.0.0", nil)
|
||||
|
||||
mgr := NewManager(globalDir, projectDir, testLogger())
|
||||
|
||||
infos, err := mgr.List()
|
||||
if err != nil {
|
||||
t.Fatalf("List: %v", err)
|
||||
}
|
||||
|
||||
if len(infos) != 2 {
|
||||
t.Fatalf("expected 2 plugins, got %d", len(infos))
|
||||
}
|
||||
|
||||
// Check that both scopes are represented.
|
||||
scopes := map[string]bool{}
|
||||
for _, info := range infos {
|
||||
scopes[info.Scope] = true
|
||||
}
|
||||
if !scopes["user"] || !scopes["project"] {
|
||||
t.Errorf("expected both scopes, got %v", scopes)
|
||||
}
|
||||
}
|
||||
126
internal/plugin/manifest.go
Normal file
126
internal/plugin/manifest.go
Normal file
@@ -0,0 +1,126 @@
|
||||
package plugin
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var namePattern = regexp.MustCompile(`^[a-z][a-z0-9_-]*$`)
|
||||
|
||||
// Manifest describes a plugin package.
|
||||
type Manifest struct {
|
||||
Name string `json:"name"`
|
||||
Version string `json:"version"`
|
||||
Description string `json:"description"`
|
||||
Author string `json:"author"`
|
||||
License string `json:"license"`
|
||||
GnomaVersion string `json:"gnoma_version"`
|
||||
Capabilities Capabilities `json:"capabilities"`
|
||||
}
|
||||
|
||||
// Capabilities declares what a plugin provides.
|
||||
type Capabilities struct {
|
||||
Skills []string `json:"skills"`
|
||||
Hooks []HookSpec `json:"hooks"`
|
||||
MCPServers []MCPServerSpec `json:"mcp_servers"`
|
||||
}
|
||||
|
||||
// HookSpec defines a hook within a plugin manifest.
|
||||
type HookSpec struct {
|
||||
Name string `json:"name"`
|
||||
Event string `json:"event"`
|
||||
Type string `json:"type"`
|
||||
Exec string `json:"exec"`
|
||||
Timeout string `json:"timeout"`
|
||||
FailOpen bool `json:"fail_open"`
|
||||
ToolPattern string `json:"tool_pattern"`
|
||||
}
|
||||
|
||||
// MCPServerSpec defines an MCP server within a plugin manifest.
|
||||
type MCPServerSpec struct {
|
||||
Name string `json:"name"`
|
||||
Command string `json:"command"`
|
||||
Args []string `json:"args"`
|
||||
Env map[string]string `json:"env"`
|
||||
}
|
||||
|
||||
// ParseManifest parses and validates a plugin.json file.
|
||||
func ParseManifest(data []byte) (*Manifest, error) {
|
||||
var m Manifest
|
||||
if err := json.Unmarshal(data, &m); err != nil {
|
||||
return nil, fmt.Errorf("%w: %v", ErrManifestInvalid, err)
|
||||
}
|
||||
if err := m.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &m, nil
|
||||
}
|
||||
|
||||
// Validate checks manifest fields for correctness and safety.
|
||||
func (m *Manifest) Validate() error {
|
||||
if m.Name == "" {
|
||||
return fmt.Errorf("%w: name is required", ErrManifestInvalid)
|
||||
}
|
||||
if !namePattern.MatchString(m.Name) {
|
||||
return fmt.Errorf("%w: name %q must match %s", ErrManifestInvalid, m.Name, namePattern)
|
||||
}
|
||||
if m.Version == "" {
|
||||
return fmt.Errorf("%w: version is required", ErrManifestInvalid)
|
||||
}
|
||||
if !validSemver(m.Version) {
|
||||
return fmt.Errorf("%w: version %q is not valid semver (expected major.minor.patch)", ErrManifestInvalid, m.Version)
|
||||
}
|
||||
|
||||
for _, glob := range m.Capabilities.Skills {
|
||||
if err := checkSafePath(glob); err != nil {
|
||||
return fmt.Errorf("%w: skill glob %q: %v", ErrManifestInvalid, glob, err)
|
||||
}
|
||||
}
|
||||
|
||||
for _, h := range m.Capabilities.Hooks {
|
||||
if h.Exec != "" {
|
||||
if err := checkSafePath(h.Exec); err != nil {
|
||||
return fmt.Errorf("%w: hook %q exec: %v", ErrManifestInvalid, h.Name, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, s := range m.Capabilities.MCPServers {
|
||||
if s.Command != "" {
|
||||
if err := checkSafePath(s.Command); err != nil {
|
||||
return fmt.Errorf("%w: mcp_server %q command: %v", ErrManifestInvalid, s.Name, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkSafePath rejects absolute paths and path traversal.
|
||||
func checkSafePath(p string) error {
|
||||
if filepath.IsAbs(p) {
|
||||
return fmt.Errorf("%w: absolute path not allowed", ErrPathTraversal)
|
||||
}
|
||||
if strings.Contains(p, "..") {
|
||||
return fmt.Errorf("%w: path traversal not allowed", ErrPathTraversal)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// validSemver checks for strict major.minor.patch format.
|
||||
func validSemver(v string) bool {
|
||||
parts := strings.Split(v, ".")
|
||||
if len(parts) != 3 {
|
||||
return false
|
||||
}
|
||||
for _, p := range parts {
|
||||
if _, err := strconv.Atoi(p); err != nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
207
internal/plugin/manifest_test.go
Normal file
207
internal/plugin/manifest_test.go
Normal file
@@ -0,0 +1,207 @@
|
||||
package plugin
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParseManifest_Valid(t *testing.T) {
|
||||
data := []byte(`{
|
||||
"name": "git-tools",
|
||||
"version": "1.0.0",
|
||||
"description": "Git integration for gnoma",
|
||||
"author": "vikingowl",
|
||||
"capabilities": {
|
||||
"skills": ["skills/*.md"],
|
||||
"hooks": [
|
||||
{
|
||||
"name": "lint-before-commit",
|
||||
"event": "pre_tool_use",
|
||||
"type": "command",
|
||||
"exec": "scripts/lint.sh",
|
||||
"tool_pattern": "bash*"
|
||||
}
|
||||
],
|
||||
"mcp_servers": [
|
||||
{
|
||||
"name": "git",
|
||||
"command": "bin/mcp-git",
|
||||
"args": ["--verbose"]
|
||||
}
|
||||
]
|
||||
}
|
||||
}`)
|
||||
|
||||
m, err := ParseManifest(data)
|
||||
if err != nil {
|
||||
t.Fatalf("ParseManifest: %v", err)
|
||||
}
|
||||
|
||||
if m.Name != "git-tools" {
|
||||
t.Errorf("Name = %q, want %q", m.Name, "git-tools")
|
||||
}
|
||||
if m.Version != "1.0.0" {
|
||||
t.Errorf("Version = %q, want %q", m.Version, "1.0.0")
|
||||
}
|
||||
if len(m.Capabilities.Skills) != 1 {
|
||||
t.Errorf("Skills count = %d, want 1", len(m.Capabilities.Skills))
|
||||
}
|
||||
if len(m.Capabilities.Hooks) != 1 {
|
||||
t.Errorf("Hooks count = %d, want 1", len(m.Capabilities.Hooks))
|
||||
}
|
||||
if len(m.Capabilities.MCPServers) != 1 {
|
||||
t.Errorf("MCPServers count = %d, want 1", len(m.Capabilities.MCPServers))
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseManifest_Minimal(t *testing.T) {
|
||||
data := []byte(`{"name": "minimal", "version": "0.1.0"}`)
|
||||
m, err := ParseManifest(data)
|
||||
if err != nil {
|
||||
t.Fatalf("ParseManifest: %v", err)
|
||||
}
|
||||
if m.Name != "minimal" {
|
||||
t.Errorf("Name = %q", m.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseManifest_InvalidJSON(t *testing.T) {
|
||||
_, err := ParseManifest([]byte(`not json`))
|
||||
if err == nil {
|
||||
t.Error("expected error for invalid JSON")
|
||||
}
|
||||
}
|
||||
|
||||
func TestManifest_Validate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
m Manifest
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid",
|
||||
m: Manifest{Name: "my-plugin", Version: "1.0.0"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "empty name",
|
||||
m: Manifest{Version: "1.0.0"},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid name uppercase",
|
||||
m: Manifest{Name: "MyPlugin", Version: "1.0.0"},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid name starts with number",
|
||||
m: Manifest{Name: "1plugin", Version: "1.0.0"},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "empty version",
|
||||
m: Manifest{Name: "my-plugin"},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid version",
|
||||
m: Manifest{Name: "my-plugin", Version: "not-semver"},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "skill glob path traversal",
|
||||
m: Manifest{
|
||||
Name: "bad",
|
||||
Version: "1.0.0",
|
||||
Capabilities: Capabilities{
|
||||
Skills: []string{"../../../etc/passwd"},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "skill glob absolute path",
|
||||
m: Manifest{
|
||||
Name: "bad",
|
||||
Version: "1.0.0",
|
||||
Capabilities: Capabilities{
|
||||
Skills: []string{"/etc/passwd"},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "hook exec path traversal",
|
||||
m: Manifest{
|
||||
Name: "bad",
|
||||
Version: "1.0.0",
|
||||
Capabilities: Capabilities{
|
||||
Hooks: []HookSpec{{
|
||||
Name: "h",
|
||||
Event: "pre_tool_use",
|
||||
Type: "command",
|
||||
Exec: "../../../bin/evil",
|
||||
}},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "mcp command path traversal",
|
||||
m: Manifest{
|
||||
Name: "bad",
|
||||
Version: "1.0.0",
|
||||
Capabilities: Capabilities{
|
||||
MCPServers: []MCPServerSpec{{
|
||||
Name: "evil",
|
||||
Command: "../../../bin/evil",
|
||||
}},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "valid name with hyphens and numbers",
|
||||
m: Manifest{Name: "my-plugin-2", Version: "0.1.0"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid name with underscores",
|
||||
m: Manifest{Name: "my_plugin", Version: "0.1.0"},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.m.Validate()
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Validate() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidSemver(t *testing.T) {
|
||||
tests := []struct {
|
||||
v string
|
||||
want bool
|
||||
}{
|
||||
{"1.0.0", true},
|
||||
{"0.1.0", true},
|
||||
{"12.34.56", true},
|
||||
{"1.0", false},
|
||||
{"1", false},
|
||||
{"v1.0.0", false},
|
||||
{"1.0.0-beta", false}, // strict semver only for v1
|
||||
{"", false},
|
||||
{"not-a-version", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.v, func(t *testing.T) {
|
||||
if got := validSemver(tt.v); got != tt.want {
|
||||
t.Errorf("validSemver(%q) = %v, want %v", tt.v, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -65,6 +65,15 @@ type Config struct {
|
||||
SessionStore *session.SessionStore // nil = no persistence
|
||||
StartWithResumePicker bool // open session picker on launch
|
||||
Skills *skill.Registry // nil = no skills loaded
|
||||
PluginInfos []PluginInfo // discovered plugins for /plugins command
|
||||
}
|
||||
|
||||
// PluginInfo is a summary of an installed plugin for TUI display.
|
||||
type PluginInfo struct {
|
||||
Name string
|
||||
Version string
|
||||
Scope string
|
||||
Enabled bool
|
||||
}
|
||||
|
||||
type Model struct {
|
||||
@@ -813,7 +822,24 @@ func (m Model) handleCommand(cmd string) (tea.Model, tea.Cmd) {
|
||||
|
||||
case "/help":
|
||||
m.messages = append(m.messages, chatMessage{role: "system",
|
||||
content: "Commands:\n /init generate or update AGENTS.md project docs\n /clear, /new clear chat and start new conversation\n /config show current config\n /incognito toggle incognito (Ctrl+X)\n /model [name] list/switch models\n /permission [mode] set permission mode (Shift+Tab to cycle)\n /provider show current provider\n /resume [id] list or restore saved sessions\n /skills list loaded skills\n /shell interactive shell (coming soon)\n /help show this help\n /quit exit gnoma\n\nSkills (use /<name> [args] to invoke):\n Add .md files with YAML front matter to .gnoma/skills/ or ~/.config/gnoma/skills/"})
|
||||
content: "Commands:\n /init generate or update AGENTS.md project docs\n /clear, /new clear chat and start new conversation\n /config show current config\n /incognito toggle incognito (Ctrl+X)\n /model [name] list/switch models\n /permission [mode] set permission mode (Shift+Tab to cycle)\n /plugins list installed plugins\n /provider show current provider\n /resume [id] list or restore saved sessions\n /skills list loaded skills\n /shell interactive shell (coming soon)\n /help show this help\n /quit exit gnoma\n\nSkills (use /<name> [args] to invoke):\n Add .md files with YAML front matter to .gnoma/skills/ or ~/.config/gnoma/skills/"})
|
||||
return m, nil
|
||||
|
||||
case "/plugins":
|
||||
if len(m.config.PluginInfos) == 0 {
|
||||
m.messages = append(m.messages, chatMessage{role: "system", content: "No plugins installed."})
|
||||
return m, nil
|
||||
}
|
||||
var b strings.Builder
|
||||
b.WriteString("Installed plugins:\n")
|
||||
for _, p := range m.config.PluginInfos {
|
||||
status := "enabled"
|
||||
if !p.Enabled {
|
||||
status = "disabled"
|
||||
}
|
||||
b.WriteString(fmt.Sprintf(" %s v%s [%s] (%s)\n", p.Name, p.Version, p.Scope, status))
|
||||
}
|
||||
m.messages = append(m.messages, chatMessage{role: "system", content: b.String()})
|
||||
return m, nil
|
||||
|
||||
case "/skills":
|
||||
|
||||
Reference in New Issue
Block a user