feat: complete M1 — core engine with Mistral provider

Mistral provider adapter with streaming, tool calls (single-chunk
pattern), stop reason inference, model listing, capabilities, and
JSON output support.

Tool system: bash (7 security checks, shell alias harvesting for
bash/zsh/fish), file ops (read, write, edit, glob, grep, ls).
Alias harvesting collects 300+ aliases from user's shell config.

Engine agentic loop: stream → tool execution → re-query → until
done. Tool gating on model capabilities. Max turns safety limit.

CLI pipe mode: echo "prompt" | gnoma streams response to stdout.
Flags: --provider, --model, --system, --api-key, --max-turns,
--verbose, --version.

Provider interface expanded: Models(), DefaultModel(), Capabilities
(ToolUse, JSONOutput, Vision, Thinking, ContextWindow, MaxOutput),
ResponseFormat with JSON schema support.

Live verified: text streaming + tool calling with devstral-small.
117 tests across 8 packages, 10MB binary.
This commit is contained in:
2026-04-03 12:01:55 +02:00
parent 85c643fdca
commit f0633d8ac6
30 changed files with 4658 additions and 24 deletions

View File

@@ -1,7 +1,197 @@
package main
import "fmt"
import (
"context"
"flag"
"fmt"
"io"
"log/slog"
"os"
"os/signal"
"strings"
"somegit.dev/Owlibou/gnoma/internal/engine"
"somegit.dev/Owlibou/gnoma/internal/provider"
"somegit.dev/Owlibou/gnoma/internal/provider/mistral"
"somegit.dev/Owlibou/gnoma/internal/stream"
"somegit.dev/Owlibou/gnoma/internal/tool"
"somegit.dev/Owlibou/gnoma/internal/tool/bash"
"somegit.dev/Owlibou/gnoma/internal/tool/fs"
)
func main() {
fmt.Println("gnoma — provider-agnostic agentic coding assistant")
var (
providerName = flag.String("provider", "mistral", "LLM provider")
model = flag.String("model", "", "model name (empty = provider default)")
system = flag.String("system", defaultSystem, "system prompt")
apiKey = flag.String("api-key", "", "API key (or set MISTRAL_API_KEY env)")
maxTurns = flag.Int("max-turns", 50, "max tool-calling rounds per turn")
verbose = flag.Bool("verbose", false, "enable debug logging")
version = flag.Bool("version", false, "print version and exit")
)
flag.Parse()
if *version {
fmt.Println("gnoma v0.1.0-dev")
os.Exit(0)
}
// Logger
logLevel := slog.LevelWarn
if *verbose {
logLevel = slog.LevelDebug
}
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: logLevel}))
// Resolve API key
key := *apiKey
if key == "" {
key = resolveAPIKey(*providerName)
}
if key == "" {
fmt.Fprintf(os.Stderr, "error: no API key for provider %q\nSet %s environment variable or use --api-key\n",
*providerName, envKeyFor(*providerName))
os.Exit(1)
}
// Create provider
prov, err := createProvider(*providerName, key, *model)
if err != nil {
fmt.Fprintf(os.Stderr, "error: %v\n", err)
os.Exit(1)
}
// Create tool registry
reg := buildToolRegistry()
// Harvest shell aliases
aliases, err := bash.HarvestAliases(context.Background())
if err != nil {
logger.Debug("alias harvest failed (non-fatal)", "error", err)
} else {
logger.Debug("harvested aliases", "count", aliases.Len())
}
// Re-register bash tool with aliases
reg.Register(bash.New(bash.WithAliases(aliases)))
// Create engine
eng, err := engine.New(engine.Config{
Provider: prov,
Tools: reg,
System: *system,
Model: *model,
MaxTurns: *maxTurns,
Logger: logger,
})
if err != nil {
fmt.Fprintf(os.Stderr, "error: %v\n", err)
os.Exit(1)
}
// Read input
input, err := readInput(flag.Args())
if err != nil {
fmt.Fprintf(os.Stderr, "error: %v\n", err)
os.Exit(1)
}
if input == "" {
fmt.Fprintln(os.Stderr, "error: no input provided")
fmt.Fprintln(os.Stderr, "usage: echo 'prompt' | gnoma")
fmt.Fprintln(os.Stderr, " or: gnoma 'prompt'")
os.Exit(1)
}
// Context with signal handling
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt)
defer cancel()
// Callback: stream text deltas to stdout
cb := func(evt stream.Event) {
if evt.Type == stream.EventTextDelta && evt.Text != "" {
fmt.Print(evt.Text)
}
}
// Submit and run
_, err = eng.Submit(ctx, input, cb)
fmt.Println() // final newline
if err != nil {
if ctx.Err() != nil {
fmt.Fprintln(os.Stderr, "\ninterrupted")
os.Exit(130)
}
fmt.Fprintf(os.Stderr, "error: %v\n", err)
os.Exit(1)
}
}
func readInput(args []string) (string, error) {
// Positional args
if len(args) > 0 {
return strings.Join(args, " "), nil
}
// Stdin (pipe mode)
stat, _ := os.Stdin.Stat()
if stat.Mode()&os.ModeCharDevice == 0 {
data, err := io.ReadAll(os.Stdin)
if err != nil {
return "", fmt.Errorf("reading stdin: %w", err)
}
return strings.TrimSpace(string(data)), nil
}
return "", nil
}
func resolveAPIKey(providerName string) string {
envVar := envKeyFor(providerName)
return os.Getenv(envVar)
}
func envKeyFor(providerName string) string {
switch providerName {
case "mistral":
return "MISTRAL_API_KEY"
case "anthropic":
return "ANTHROPIC_API_KEY"
case "openai":
return "OPENAI_API_KEY"
case "google":
return "GEMINI_API_KEY"
default:
return strings.ToUpper(providerName) + "_API_KEY"
}
}
func createProvider(name, apiKey, model string) (provider.Provider, error) {
cfg := provider.ProviderConfig{
APIKey: apiKey,
Model: model,
}
switch name {
case "mistral":
return mistral.New(cfg)
default:
return nil, fmt.Errorf("unknown provider %q (M1 supports: mistral)", name)
}
}
func buildToolRegistry() *tool.Registry {
reg := tool.NewRegistry()
reg.Register(bash.New())
reg.Register(fs.NewReadTool())
reg.Register(fs.NewWriteTool())
reg.Register(fs.NewEditTool())
reg.Register(fs.NewGlobTool())
reg.Register(fs.NewGrepTool())
reg.Register(fs.NewLSTool())
return reg
}
const defaultSystem = `You are gnoma, a provider-agnostic agentic coding assistant.
You help users with software engineering tasks by reading files, writing code, and executing commands.
Be concise and direct. Use tools when needed to accomplish the task.`

View File

@@ -116,9 +116,13 @@ depends_on: [vision]
- [ ] Model picker overlay
- [ ] In-app config editor (`/config` command)
- [ ] Incognito toggle (`/incognito` command)
- [ ] Interactive shell pane: `/shell` command or keybinding opens PTY-connected shell
- For commands needing user input (sudo, ssh, git push with auth, passwd prompts)
- Bash tool detects potentially interactive commands and suggests take-over
- PTY-based execution for flagged commands
- [ ] Session management (channel-based)
**Exit criteria:** Launch TUI, chat interactively, 6 permission modes work, config editable in-app, incognito toggleable.
**Exit criteria:** Launch TUI, chat interactively, 6 permission modes work, config editable in-app, incognito toggleable, `/shell` opens interactive terminal for password prompts.
## M6: Context Intelligence
@@ -219,7 +223,7 @@ depends_on: [vision]
**Exit criteria:** gnoma suggests a persistent task after 3+ repetitions. `/task release v1.2.0` executes a saved workflow.
## M12: Thinking & Structured Output
## M12: Thinking, Structured Output & Notebook
**Deliverables:**
@@ -227,6 +231,7 @@ depends_on: [vision]
- [ ] Thinking block streaming and TUI display
- [ ] Structured output with JSON schema validation
- [ ] Retry logic for schema validation failures
- [ ] NotebookEdit tool: read/write/edit Jupyter notebook cells (.ipynb)
## M13: Auth

View File

@@ -0,0 +1,7 @@
package engine
import "somegit.dev/Owlibou/gnoma/internal/stream"
// Callback receives streaming events for real-time UI updates.
// Called synchronously on the engine goroutine for each event.
type Callback func(stream.Event)

123
internal/engine/engine.go Normal file
View File

@@ -0,0 +1,123 @@
package engine
import (
"context"
"fmt"
"log/slog"
"somegit.dev/Owlibou/gnoma/internal/message"
"somegit.dev/Owlibou/gnoma/internal/provider"
"somegit.dev/Owlibou/gnoma/internal/tool"
)
// Config holds engine configuration.
type Config struct {
Provider provider.Provider
Tools *tool.Registry
System string // system prompt
Model string // override model (empty = provider default)
MaxTurns int // safety limit on tool loops (0 = unlimited)
Logger *slog.Logger
}
func (c Config) validate() error {
if c.Provider == nil {
return fmt.Errorf("engine: provider required")
}
if c.Tools == nil {
return fmt.Errorf("engine: tool registry required")
}
return nil
}
// Turn is the result of a complete agentic turn (may span multiple API calls).
type Turn struct {
Messages []message.Message // all messages produced (assistant + tool results)
Usage message.Usage // cumulative for all API calls in this turn
Rounds int // number of API round-trips
}
// Engine orchestrates the conversation.
type Engine struct {
cfg Config
history []message.Message
usage message.Usage
logger *slog.Logger
// Cached model capabilities, resolved lazily
modelCaps *provider.Capabilities
modelCapsFor string // model ID the cached caps are for
}
// New creates an engine.
func New(cfg Config) (*Engine, error) {
if err := cfg.validate(); err != nil {
return nil, err
}
logger := cfg.Logger
if logger == nil {
logger = slog.Default()
}
return &Engine{
cfg: cfg,
logger: logger,
}, nil
}
// resolveCapabilities returns the capabilities for the active model.
// Caches the result — re-resolves if the model changes.
func (e *Engine) resolveCapabilities(ctx context.Context) *provider.Capabilities {
model := e.cfg.Model
if model == "" {
model = e.cfg.Provider.DefaultModel()
}
// Return cached if same model
if e.modelCaps != nil && e.modelCapsFor == model {
return e.modelCaps
}
// Query provider for model list
models, err := e.cfg.Provider.Models(ctx)
if err != nil {
e.logger.Debug("failed to fetch model capabilities", "error", err)
return nil
}
for _, m := range models {
if m.ID == model {
e.modelCaps = &m.Capabilities
e.modelCapsFor = model
return e.modelCaps
}
}
e.logger.Debug("model not found in provider model list", "model", model)
return nil
}
// History returns the full conversation.
func (e *Engine) History() []message.Message {
return e.history
}
// Usage returns cumulative token usage.
func (e *Engine) Usage() message.Usage {
return e.usage
}
// SetProvider swaps the active provider (for dynamic switching).
func (e *Engine) SetProvider(p provider.Provider) {
e.cfg.Provider = p
}
// SetModel changes the model within the current provider.
func (e *Engine) SetModel(model string) {
e.cfg.Model = model
}
// Reset clears conversation history and usage.
func (e *Engine) Reset() {
e.history = nil
e.usage = message.Usage{}
}

View File

@@ -0,0 +1,475 @@
package engine
import (
"context"
"encoding/json"
"errors"
"fmt"
"testing"
"somegit.dev/Owlibou/gnoma/internal/message"
"somegit.dev/Owlibou/gnoma/internal/provider"
"somegit.dev/Owlibou/gnoma/internal/stream"
"somegit.dev/Owlibou/gnoma/internal/tool"
)
// --- Mock Provider ---
// mockProvider returns pre-configured streams for each call.
type mockProvider struct {
name string
calls int
streams []stream.Stream // one per call, consumed in order
}
func (m *mockProvider) Name() string { return m.name }
func (m *mockProvider) DefaultModel() string { return "mock-model" }
func (m *mockProvider) Models(_ context.Context) ([]provider.ModelInfo, error) {
return []provider.ModelInfo{{
ID: "mock-model", Name: "mock-model", Provider: m.name,
Capabilities: provider.Capabilities{ToolUse: true},
}}, nil
}
func (m *mockProvider) Stream(_ context.Context, _ provider.Request) (stream.Stream, error) {
if m.calls >= len(m.streams) {
return nil, fmt.Errorf("mock: no more streams (called %d times)", m.calls+1)
}
s := m.streams[m.calls]
m.calls++
return s, nil
}
// eventStream is a mock stream backed by a slice of events.
type eventStream struct {
events []stream.Event
idx int
stopReason message.StopReason
model string
}
func newEventStream(stopReason message.StopReason, model string, events ...stream.Event) *eventStream {
// Append a final event with stop reason
events = append(events, stream.Event{
Type: stream.EventTextDelta,
StopReason: stopReason,
Model: model,
})
return &eventStream{events: events, stopReason: stopReason, model: model}
}
func (s *eventStream) Next() bool {
if s.idx >= len(s.events) {
return false
}
s.idx++
return true
}
func (s *eventStream) Current() stream.Event {
return s.events[s.idx-1]
}
func (s *eventStream) Err() error { return nil }
func (s *eventStream) Close() error { return nil }
// --- Mock Tool ---
type mockTool struct {
name string
readOnly bool
execFn func(ctx context.Context, args json.RawMessage) (tool.Result, error)
}
func (m *mockTool) Name() string { return m.name }
func (m *mockTool) Description() string { return "mock tool" }
func (m *mockTool) Parameters() json.RawMessage { return json.RawMessage(`{"type":"object"}`) }
func (m *mockTool) IsReadOnly() bool { return m.readOnly }
func (m *mockTool) IsDestructive() bool { return false }
func (m *mockTool) Execute(ctx context.Context, args json.RawMessage) (tool.Result, error) {
if m.execFn != nil {
return m.execFn(ctx, args)
}
return tool.Result{Output: "mock output"}, nil
}
// --- Tests ---
func TestNew_ValidConfig(t *testing.T) {
e, err := New(Config{
Provider: &mockProvider{name: "test"},
Tools: tool.NewRegistry(),
})
if err != nil {
t.Fatalf("New: %v", err)
}
if e == nil {
t.Fatal("engine should not be nil")
}
}
func TestNew_MissingProvider(t *testing.T) {
_, err := New(Config{Tools: tool.NewRegistry()})
if err == nil {
t.Fatal("expected error for missing provider")
}
}
func TestNew_MissingTools(t *testing.T) {
_, err := New(Config{Provider: &mockProvider{name: "test"}})
if err == nil {
t.Fatal("expected error for missing tool registry")
}
}
func TestSubmit_SimpleTextResponse(t *testing.T) {
mp := &mockProvider{
name: "test",
streams: []stream.Stream{
newEventStream(message.StopEndTurn, "test-model",
stream.Event{Type: stream.EventTextDelta, Text: "Hello "},
stream.Event{Type: stream.EventTextDelta, Text: "world!"},
stream.Event{Type: stream.EventUsage, Usage: &message.Usage{InputTokens: 10, OutputTokens: 5}},
),
},
}
e, _ := New(Config{Provider: mp, Tools: tool.NewRegistry()})
var events []stream.Event
turn, err := e.Submit(context.Background(), "hi", func(evt stream.Event) {
events = append(events, evt)
})
if err != nil {
t.Fatalf("Submit: %v", err)
}
// Check turn
if turn.Rounds != 1 {
t.Errorf("Rounds = %d, want 1", turn.Rounds)
}
if len(turn.Messages) != 1 {
t.Fatalf("len(Messages) = %d, want 1", len(turn.Messages))
}
if turn.Messages[0].TextContent() != "Hello world!" {
t.Errorf("TextContent = %q", turn.Messages[0].TextContent())
}
if turn.Usage.InputTokens != 10 {
t.Errorf("Usage.InputTokens = %d", turn.Usage.InputTokens)
}
// Check history
history := e.History()
if len(history) != 2 {
t.Fatalf("len(History) = %d, want 2 (user + assistant)", len(history))
}
if history[0].Role != message.RoleUser {
t.Errorf("History[0].Role = %q", history[0].Role)
}
if history[1].Role != message.RoleAssistant {
t.Errorf("History[1].Role = %q", history[1].Role)
}
// Check events were forwarded
if len(events) == 0 {
t.Error("callback should have received events")
}
}
func TestSubmit_ToolCallLoop(t *testing.T) {
reg := tool.NewRegistry()
reg.Register(&mockTool{
name: "bash",
execFn: func(_ context.Context, args json.RawMessage) (tool.Result, error) {
return tool.Result{Output: "file1.go\nfile2.go"}, nil
},
})
mp := &mockProvider{
name: "test",
streams: []stream.Stream{
// Round 1: model calls a tool
newEventStream(message.StopToolUse, "model-1",
stream.Event{Type: stream.EventTextDelta, Text: "Let me list files."},
stream.Event{Type: stream.EventToolCallStart, ToolCallID: "tc_1", ToolCallName: "bash"},
stream.Event{Type: stream.EventToolCallDone, ToolCallID: "tc_1", Args: json.RawMessage(`{"command":"ls"}`)},
),
// Round 2: model responds with final answer
newEventStream(message.StopEndTurn, "model-1",
stream.Event{Type: stream.EventTextDelta, Text: "Found file1.go and file2.go."},
),
},
}
e, _ := New(Config{Provider: mp, Tools: reg})
turn, err := e.Submit(context.Background(), "list files", nil)
if err != nil {
t.Fatalf("Submit: %v", err)
}
if turn.Rounds != 2 {
t.Errorf("Rounds = %d, want 2", turn.Rounds)
}
// Messages: assistant (tool call), tool results, assistant (final)
if len(turn.Messages) != 3 {
t.Fatalf("len(Messages) = %d, want 3", len(turn.Messages))
}
// First message has tool call
if !turn.Messages[0].HasToolCalls() {
t.Error("Messages[0] should have tool calls")
}
// Second message is tool results
if turn.Messages[1].Role != message.RoleUser {
t.Errorf("Messages[1].Role = %q, want user (tool results)", turn.Messages[1].Role)
}
// Third message is final text
if turn.Messages[2].TextContent() != "Found file1.go and file2.go." {
t.Errorf("Messages[2].TextContent = %q", turn.Messages[2].TextContent())
}
// History: user + assistant(tool call) + tool results + assistant(final)
if len(e.History()) != 4 {
t.Errorf("len(History) = %d, want 4", len(e.History()))
}
// Provider called twice
if mp.calls != 2 {
t.Errorf("provider called %d times, want 2", mp.calls)
}
}
func TestSubmit_UnknownTool(t *testing.T) {
reg := tool.NewRegistry()
// Don't register any tools
mp := &mockProvider{
name: "test",
streams: []stream.Stream{
// Model calls a tool that doesn't exist
newEventStream(message.StopToolUse, "",
stream.Event{Type: stream.EventToolCallStart, ToolCallID: "tc_1", ToolCallName: "nonexistent"},
stream.Event{Type: stream.EventToolCallDone, ToolCallID: "tc_1", Args: json.RawMessage(`{}`)},
),
// Model responds after seeing error
newEventStream(message.StopEndTurn, "",
stream.Event{Type: stream.EventTextDelta, Text: "Sorry, that tool doesn't exist."},
),
},
}
e, _ := New(Config{Provider: mp, Tools: reg})
turn, err := e.Submit(context.Background(), "do something", nil)
if err != nil {
t.Fatalf("Submit: %v", err)
}
// Should still complete — unknown tool returns error result, model sees it
if turn.Rounds != 2 {
t.Errorf("Rounds = %d, want 2", turn.Rounds)
}
}
func TestSubmit_ToolExecutionError(t *testing.T) {
reg := tool.NewRegistry()
reg.Register(&mockTool{
name: "failing",
execFn: func(_ context.Context, _ json.RawMessage) (tool.Result, error) {
return tool.Result{}, errors.New("disk full")
},
})
mp := &mockProvider{
name: "test",
streams: []stream.Stream{
newEventStream(message.StopToolUse, "",
stream.Event{Type: stream.EventToolCallStart, ToolCallID: "tc_1", ToolCallName: "failing"},
stream.Event{Type: stream.EventToolCallDone, ToolCallID: "tc_1", Args: json.RawMessage(`{}`)},
),
newEventStream(message.StopEndTurn, "",
stream.Event{Type: stream.EventTextDelta, Text: "The tool failed."},
),
},
}
e, _ := New(Config{Provider: mp, Tools: reg})
turn, err := e.Submit(context.Background(), "do it", nil)
if err != nil {
t.Fatalf("Submit: %v", err)
}
// Tool error is returned as error result, not a fatal error
if turn.Rounds != 2 {
t.Errorf("Rounds = %d, want 2", turn.Rounds)
}
}
func TestSubmit_MaxTurnsLimit(t *testing.T) {
reg := tool.NewRegistry()
reg.Register(&mockTool{name: "bash"})
// Provider always returns tool calls — would loop forever
mp := &mockProvider{
name: "test",
streams: []stream.Stream{
newEventStream(message.StopToolUse, "",
stream.Event{Type: stream.EventToolCallStart, ToolCallID: "tc_1", ToolCallName: "bash"},
stream.Event{Type: stream.EventToolCallDone, ToolCallID: "tc_1", Args: json.RawMessage(`{}`)},
),
newEventStream(message.StopToolUse, "",
stream.Event{Type: stream.EventToolCallStart, ToolCallID: "tc_2", ToolCallName: "bash"},
stream.Event{Type: stream.EventToolCallDone, ToolCallID: "tc_2", Args: json.RawMessage(`{}`)},
),
newEventStream(message.StopToolUse, "",
stream.Event{Type: stream.EventToolCallStart, ToolCallID: "tc_3", ToolCallName: "bash"},
stream.Event{Type: stream.EventToolCallDone, ToolCallID: "tc_3", Args: json.RawMessage(`{}`)},
),
},
}
e, _ := New(Config{Provider: mp, Tools: reg, MaxTurns: 2})
_, err := e.Submit(context.Background(), "loop forever", nil)
if err == nil {
t.Fatal("expected error from max turns limit")
}
if mp.calls != 2 {
t.Errorf("provider called %d times, want 2 (limited)", mp.calls)
}
}
func TestSubmit_MultipleToolCalls(t *testing.T) {
reg := tool.NewRegistry()
reg.Register(&mockTool{
name: "bash",
execFn: func(_ context.Context, _ json.RawMessage) (tool.Result, error) {
return tool.Result{Output: "bash output"}, nil
},
})
reg.Register(&mockTool{
name: "fs.read",
readOnly: true,
execFn: func(_ context.Context, _ json.RawMessage) (tool.Result, error) {
return tool.Result{Output: "file content"}, nil
},
})
mp := &mockProvider{
name: "test",
streams: []stream.Stream{
// Model calls two tools at once
newEventStream(message.StopToolUse, "",
stream.Event{Type: stream.EventToolCallStart, ToolCallID: "tc_1", ToolCallName: "bash"},
stream.Event{Type: stream.EventToolCallDone, ToolCallID: "tc_1", Args: json.RawMessage(`{"command":"ls"}`)},
stream.Event{Type: stream.EventToolCallStart, ToolCallID: "tc_2", ToolCallName: "fs.read"},
stream.Event{Type: stream.EventToolCallDone, ToolCallID: "tc_2", Args: json.RawMessage(`{"path":"go.mod"}`)},
),
newEventStream(message.StopEndTurn, "",
stream.Event{Type: stream.EventTextDelta, Text: "Done."},
),
},
}
e, _ := New(Config{Provider: mp, Tools: reg})
turn, err := e.Submit(context.Background(), "run both", nil)
if err != nil {
t.Fatalf("Submit: %v", err)
}
if turn.Rounds != 2 {
t.Errorf("Rounds = %d, want 2", turn.Rounds)
}
// Tool results message should have 2 results
toolMsg := turn.Messages[1] // assistant, tool_results, assistant
if len(toolMsg.Content) != 2 {
t.Errorf("tool results has %d content blocks, want 2", len(toolMsg.Content))
}
}
func TestSubmit_NilCallback(t *testing.T) {
mp := &mockProvider{
name: "test",
streams: []stream.Stream{
newEventStream(message.StopEndTurn, "",
stream.Event{Type: stream.EventTextDelta, Text: "ok"},
),
},
}
e, _ := New(Config{Provider: mp, Tools: tool.NewRegistry()})
// nil callback should not panic
turn, err := e.Submit(context.Background(), "test", nil)
if err != nil {
t.Fatalf("Submit: %v", err)
}
if turn.Rounds != 1 {
t.Errorf("Rounds = %d", turn.Rounds)
}
}
func TestEngine_Reset(t *testing.T) {
mp := &mockProvider{
name: "test",
streams: []stream.Stream{
newEventStream(message.StopEndTurn, "",
stream.Event{Type: stream.EventTextDelta, Text: "first"},
stream.Event{Type: stream.EventUsage, Usage: &message.Usage{InputTokens: 100}},
),
},
}
e, _ := New(Config{Provider: mp, Tools: tool.NewRegistry()})
e.Submit(context.Background(), "hello", nil)
if len(e.History()) == 0 {
t.Fatal("history should not be empty before reset")
}
if e.Usage().InputTokens == 0 {
t.Fatal("usage should not be zero before reset")
}
e.Reset()
if len(e.History()) != 0 {
t.Errorf("history should be empty after reset, got %d", len(e.History()))
}
if e.Usage().InputTokens != 0 {
t.Errorf("usage should be zero after reset, got %d", e.Usage().InputTokens)
}
}
func TestSubmit_CumulativeUsage(t *testing.T) {
mp := &mockProvider{
name: "test",
streams: []stream.Stream{
newEventStream(message.StopEndTurn, "",
stream.Event{Type: stream.EventUsage, Usage: &message.Usage{InputTokens: 100, OutputTokens: 50}},
stream.Event{Type: stream.EventTextDelta, Text: "first"},
),
newEventStream(message.StopEndTurn, "",
stream.Event{Type: stream.EventUsage, Usage: &message.Usage{InputTokens: 200, OutputTokens: 80}},
stream.Event{Type: stream.EventTextDelta, Text: "second"},
),
},
}
e, _ := New(Config{Provider: mp, Tools: tool.NewRegistry()})
e.Submit(context.Background(), "one", nil)
e.Submit(context.Background(), "two", nil)
if e.Usage().InputTokens != 300 {
t.Errorf("cumulative InputTokens = %d, want 300", e.Usage().InputTokens)
}
if e.Usage().OutputTokens != 130 {
t.Errorf("cumulative OutputTokens = %d, want 130", e.Usage().OutputTokens)
}
}

204
internal/engine/loop.go Normal file
View File

@@ -0,0 +1,204 @@
package engine
import (
"context"
"encoding/json"
"fmt"
"somegit.dev/Owlibou/gnoma/internal/message"
"somegit.dev/Owlibou/gnoma/internal/provider"
"somegit.dev/Owlibou/gnoma/internal/stream"
)
// Submit sends a user message and runs the agentic loop to completion.
// The callback receives real-time streaming events.
func (e *Engine) Submit(ctx context.Context, input string, cb Callback) (*Turn, error) {
userMsg := message.NewUserText(input)
e.history = append(e.history, userMsg)
return e.runLoop(ctx, cb)
}
// SubmitMessages is like Submit but accepts pre-built messages.
func (e *Engine) SubmitMessages(ctx context.Context, msgs []message.Message, cb Callback) (*Turn, error) {
e.history = append(e.history, msgs...)
return e.runLoop(ctx, cb)
}
func (e *Engine) runLoop(ctx context.Context, cb Callback) (*Turn, error) {
turn := &Turn{}
for {
turn.Rounds++
if e.cfg.MaxTurns > 0 && turn.Rounds > e.cfg.MaxTurns {
return turn, fmt.Errorf("safety limit: %d rounds exceeded", e.cfg.MaxTurns)
}
// Build provider request (gates tools on model capabilities)
req := e.buildRequest(ctx)
e.logger.Debug("streaming request",
"provider", e.cfg.Provider.Name(),
"model", req.Model,
"messages", len(req.Messages),
"tools", len(req.Tools),
"round", turn.Rounds,
)
// Stream from provider
s, err := e.cfg.Provider.Stream(ctx, req)
if err != nil {
return nil, fmt.Errorf("provider stream: %w", err)
}
// Consume stream, forwarding events to callback
acc := stream.NewAccumulator()
var stopReason message.StopReason
var model string
for s.Next() {
evt := s.Current()
acc.Apply(evt)
// Capture stop reason and model from events
if evt.StopReason != "" {
stopReason = evt.StopReason
}
if evt.Model != "" {
model = evt.Model
}
if cb != nil {
cb(evt)
}
}
if err := s.Err(); err != nil {
s.Close()
return nil, fmt.Errorf("stream error: %w", err)
}
s.Close()
// Build response
resp := acc.Response(stopReason, model)
turn.Usage.Add(resp.Usage)
turn.Messages = append(turn.Messages, resp.Message)
e.history = append(e.history, resp.Message)
e.usage.Add(resp.Usage)
e.logger.Debug("turn response",
"stop_reason", resp.StopReason,
"tool_calls", len(resp.Message.ToolCalls()),
"round", turn.Rounds,
)
// Decide next action
switch resp.StopReason {
case message.StopEndTurn, message.StopMaxTokens, message.StopSequence:
return turn, nil
case message.StopToolUse:
results, err := e.executeTools(ctx, resp.Message.ToolCalls(), cb)
if err != nil {
return nil, fmt.Errorf("tool execution: %w", err)
}
toolMsg := message.NewToolResults(results...)
turn.Messages = append(turn.Messages, toolMsg)
e.history = append(e.history, toolMsg)
// Continue loop — re-query provider with tool results
default:
// Unknown stop reason or empty — treat as end of turn
return turn, nil
}
}
}
func (e *Engine) buildRequest(ctx context.Context) provider.Request {
req := provider.Request{
Model: e.cfg.Model,
SystemPrompt: e.cfg.System,
Messages: e.history,
}
// Only include tools if the model supports them
caps := e.resolveCapabilities(ctx)
if caps == nil || caps.ToolUse {
// nil caps = unknown model, include tools optimistically
for _, t := range e.cfg.Tools.All() {
req.Tools = append(req.Tools, provider.ToolDefinition{
Name: t.Name(),
Description: t.Description(),
Parameters: t.Parameters(),
})
}
} else {
e.logger.Debug("tools omitted — model does not support tool use",
"model", req.Model,
)
}
return req
}
func (e *Engine) executeTools(ctx context.Context, calls []message.ToolCall, cb Callback) ([]message.ToolResult, error) {
results := make([]message.ToolResult, 0, len(calls))
for _, call := range calls {
t, ok := e.cfg.Tools.Get(call.Name)
if !ok {
e.logger.Warn("unknown tool", "name", call.Name)
results = append(results, message.ToolResult{
ToolCallID: call.ID,
Content: fmt.Sprintf("unknown tool: %s", call.Name),
IsError: true,
})
continue
}
e.logger.Debug("executing tool", "name", call.Name, "id", call.ID)
result, err := t.Execute(ctx, call.Arguments)
if err != nil {
e.logger.Error("tool execution failed", "name", call.Name, "error", err)
results = append(results, message.ToolResult{
ToolCallID: call.ID,
Content: err.Error(),
IsError: true,
})
continue
}
// Emit tool result as a text delta event so the UI can show it
if cb != nil {
cb(stream.Event{
Type: stream.EventTextDelta,
Text: fmt.Sprintf("\n[tool:%s] %s\n", call.Name, truncate(result.Output, 500)),
})
}
results = append(results, message.ToolResult{
ToolCallID: call.ID,
Content: result.Output,
})
}
return results, nil
}
func truncate(s string, maxLen int) string {
if len(s) <= maxLen {
return s
}
return s[:maxLen] + "..."
}
// toolDefFromTool converts a tool.Tool to provider.ToolDefinition.
// Unused currently but kept for reference when building tool definitions dynamically.
func toolDefFromJSON(name, description string, params json.RawMessage) provider.ToolDefinition {
return provider.ToolDefinition{
Name: name,
Description: description,
Parameters: params,
}
}

View File

@@ -0,0 +1,124 @@
package mistral
import (
"context"
"fmt"
"somegit.dev/Owlibou/gnoma/internal/provider"
"somegit.dev/Owlibou/gnoma/internal/stream"
mistralgo "somegit.dev/vikingowl/mistral-go-sdk"
"somegit.dev/vikingowl/mistral-go-sdk/model"
)
const defaultModel = "mistral-large-latest"
// Provider implements provider.Provider for the Mistral API.
type Provider struct {
client *mistralgo.Client
name string
model string
}
// New creates a Mistral provider from config.
func New(cfg provider.ProviderConfig) (provider.Provider, error) {
if cfg.APIKey == "" {
return nil, fmt.Errorf("mistral: api key required")
}
opts := []mistralgo.Option{}
if cfg.BaseURL != "" {
opts = append(opts, mistralgo.WithBaseURL(cfg.BaseURL))
}
client := mistralgo.NewClient(cfg.APIKey, opts...)
m := cfg.Model
if m == "" {
m = defaultModel
}
return &Provider{
client: client,
name: "mistral",
model: m,
}, nil
}
// Stream initiates a streaming chat completion request.
func (p *Provider) Stream(ctx context.Context, req provider.Request) (stream.Stream, error) {
m := req.Model
if m == "" {
m = p.model
}
cr := translateRequest(req)
cr.Model = m
raw, err := p.client.ChatCompleteStream(ctx, cr)
if err != nil {
return nil, p.wrapError(err)
}
return newMistralStream(raw), nil
}
// Name returns "mistral".
func (p *Provider) Name() string {
return p.name
}
// DefaultModel returns the configured default model.
func (p *Provider) DefaultModel() string {
return p.model
}
// Models lists available models from the Mistral API with capability metadata.
func (p *Provider) Models(ctx context.Context) ([]provider.ModelInfo, error) {
resp, err := p.client.ListModels(ctx, &model.ListParams{})
if err != nil {
return nil, p.wrapError(err)
}
var models []provider.ModelInfo
for _, m := range resp.Data {
models = append(models, provider.ModelInfo{
ID: m.ID,
Name: m.ID,
Provider: p.name,
Capabilities: inferCapabilities(m),
})
}
return models, nil
}
// inferCapabilities maps Mistral model metadata to gnoma capabilities.
func inferCapabilities(m model.ModelCard) provider.Capabilities {
caps := provider.Capabilities{
ToolUse: m.Capabilities.FunctionCalling,
Vision: m.Capabilities.Vision,
JSONOutput: m.Capabilities.CompletionChat, // all chat models support JSON output via ResponseFormat
ContextWindow: m.MaxContextLength,
MaxOutput: 8192, // reasonable default
}
return caps
}
func (p *Provider) wrapError(err error) error {
if apiErr, ok := err.(*mistralgo.APIError); ok {
kind, retryable := provider.ClassifyHTTPStatus(apiErr.StatusCode)
return &provider.ProviderError{
Kind: kind,
Provider: p.name,
StatusCode: apiErr.StatusCode,
Message: apiErr.Message,
Retryable: retryable,
Err: err,
}
}
return &provider.ProviderError{
Kind: provider.ErrTransient,
Provider: p.name,
Message: err.Error(),
Err: err,
}
}

View File

@@ -0,0 +1,248 @@
package mistral
import (
"encoding/json"
"somegit.dev/Owlibou/gnoma/internal/message"
"somegit.dev/Owlibou/gnoma/internal/stream"
mistralgo "somegit.dev/vikingowl/mistral-go-sdk"
"somegit.dev/vikingowl/mistral-go-sdk/chat"
)
// mistralStream adapts mistral's Stream[CompletionChunk] to gnoma's stream.Stream.
type mistralStream struct {
raw *mistralgo.Stream[chat.CompletionChunk]
cur stream.Event
err error
model string
// Track active tool calls for delta assembly
activeToolCalls map[int]*toolCallState // keyed by ToolCall.Index
// Deferred finish reason (when finish arrives on the same chunk as content)
pendingFinish *chat.FinishReason
pendingUsage *message.Usage // usage from a chunk that also had other data
emittedStop bool // true after we've emitted the synthetic stop event
hadToolCalls bool // true if any tool calls were emitted
}
type toolCallState struct {
id string
name string
args string // accumulated argument fragments
}
func newMistralStream(raw *mistralgo.Stream[chat.CompletionChunk]) *mistralStream {
return &mistralStream{
raw: raw,
activeToolCalls: make(map[int]*toolCallState),
}
}
func (s *mistralStream) Next() bool {
for s.raw.Next() {
chunk := s.raw.Current()
// Capture model from first chunk
if s.model == "" && chunk.Model != "" {
s.model = chunk.Model
}
// Store usage if present (may be on same chunk as tool calls or finish)
if chunk.Usage != nil {
s.pendingUsage = translateUsage(chunk.Usage)
}
if len(chunk.Choices) == 0 {
// Chunk with only usage and no choices — emit usage
if s.pendingUsage != nil {
s.cur = stream.Event{Type: stream.EventUsage, Usage: s.pendingUsage}
s.pendingUsage = nil
return true
}
continue
}
choice := chunk.Choices[0]
delta := choice.Delta
// Process text content first (even on chunks with finish reason)
text := delta.Content.String()
if text != "" {
s.cur = stream.Event{
Type: stream.EventTextDelta,
Text: text,
}
// If this chunk also has a finish reason, store it for next iteration
if choice.FinishReason != nil {
s.pendingFinish = choice.FinishReason
}
return true
}
// Tool call deltas
if len(delta.ToolCalls) > 0 {
// Store finish reason if present on same chunk as tool calls
if choice.FinishReason != nil {
s.pendingFinish = choice.FinishReason
}
for _, tc := range delta.ToolCalls {
existing, ok := s.activeToolCalls[tc.Index]
if !ok {
// New tool call
s.activeToolCalls[tc.Index] = &toolCallState{
id: tc.ID,
name: tc.Function.Name,
args: tc.Function.Arguments,
}
s.hadToolCalls = true
// If arguments are already complete (Mistral sends full args in one chunk),
// emit ToolCallDone directly instead of Start
if tc.Function.Arguments != "" && s.pendingFinish != nil {
s.cur = stream.Event{
Type: stream.EventToolCallDone,
ToolCallID: tc.ID,
ToolCallName: tc.Function.Name,
Args: json.RawMessage(tc.Function.Arguments),
}
// Remove from active — it's already done
delete(s.activeToolCalls, tc.Index)
return true
}
// Otherwise emit Start, accumulate deltas later
s.cur = stream.Event{
Type: stream.EventToolCallStart,
ToolCallID: tc.ID,
ToolCallName: tc.Function.Name,
}
return true
}
// Existing tool call — accumulate arguments, emit Delta
existing.args += tc.Function.Arguments
if tc.Function.Arguments != "" {
s.cur = stream.Event{
Type: stream.EventToolCallDelta,
ToolCallID: existing.id,
ArgDelta: tc.Function.Arguments,
}
return true
}
}
continue
}
// Check finish reason (from this chunk or pending from previous)
fr := choice.FinishReason
if fr == nil {
fr = s.pendingFinish
s.pendingFinish = nil
}
if fr != nil {
// Flush any pending tool calls as Done events
if *fr == chat.FinishReasonToolCalls {
for idx, tc := range s.activeToolCalls {
s.cur = stream.Event{
Type: stream.EventToolCallDone,
ToolCallID: tc.id,
Args: json.RawMessage(tc.args),
}
delete(s.activeToolCalls, idx)
s.pendingFinish = fr // re-store to flush remaining on next call
return true
}
}
// Final event with stop reason
s.cur = stream.Event{
Type: stream.EventTextDelta,
StopReason: translateFinishReason(fr),
Model: s.model,
}
return true
}
}
// Drain any pending finish reason that was stored with the last content chunk
if s.pendingFinish != nil {
fr := s.pendingFinish
s.pendingFinish = nil
// Flush pending tool calls
if *fr == chat.FinishReasonToolCalls {
for idx, tc := range s.activeToolCalls {
s.cur = stream.Event{
Type: stream.EventToolCallDone,
ToolCallID: tc.id,
Args: json.RawMessage(tc.args),
}
delete(s.activeToolCalls, idx)
s.pendingFinish = fr
return true
}
}
s.cur = stream.Event{
Type: stream.EventTextDelta,
StopReason: translateFinishReason(fr),
Model: s.model,
}
return true
}
// Emit any pending usage before the stop event
if s.pendingUsage != nil {
s.cur = stream.Event{Type: stream.EventUsage, Usage: s.pendingUsage}
s.pendingUsage = nil
return true
}
// Stream ended — emit inferred stop reason.
if !s.emittedStop {
s.emittedStop = true
// If we have pending tool calls, they ended with the stream
if len(s.activeToolCalls) > 0 {
for idx, tc := range s.activeToolCalls {
s.cur = stream.Event{
Type: stream.EventToolCallDone,
ToolCallID: tc.id,
Args: json.RawMessage(tc.args),
}
delete(s.activeToolCalls, idx)
return true
}
}
// Infer stop reason: if tool calls were emitted, it's ToolUse; otherwise EndTurn
stopReason := message.StopEndTurn
if s.hadToolCalls {
stopReason = message.StopToolUse
}
s.cur = stream.Event{
Type: stream.EventTextDelta,
StopReason: stopReason,
Model: s.model,
}
return true
}
s.err = s.raw.Err()
return false
}
func (s *mistralStream) Current() stream.Event {
return s.cur
}
func (s *mistralStream) Err() error {
return s.err
}
func (s *mistralStream) Close() error {
return s.raw.Close()
}

View File

@@ -0,0 +1,177 @@
package mistral
import (
"encoding/json"
"somegit.dev/Owlibou/gnoma/internal/message"
"somegit.dev/Owlibou/gnoma/internal/provider"
"somegit.dev/vikingowl/mistral-go-sdk/chat"
)
// --- gnoma → Mistral ---
func translateMessages(msgs []message.Message) []chat.Message {
out := make([]chat.Message, 0, len(msgs))
for _, m := range msgs {
out = append(out, translateMessage(m))
}
return out
}
func translateMessage(m message.Message) chat.Message {
switch m.Role {
case message.RoleSystem:
return &chat.SystemMessage{Content: chat.TextContent(m.TextContent())}
case message.RoleUser:
// Check if this is a tool results message
if len(m.Content) > 0 && m.Content[0].Type == message.ContentToolResult {
// Tool results must be sent as individual ToolMessages
// Return only the first; caller handles multi-result expansion
tr := m.Content[0].ToolResult
return &chat.ToolMessage{
ToolCallID: tr.ToolCallID,
Content: chat.TextContent(tr.Content),
}
}
return &chat.UserMessage{Content: chat.TextContent(m.TextContent())}
case message.RoleAssistant:
am := chat.AssistantMessage{
Content: chat.TextContent(m.TextContent()),
}
for _, tc := range m.ToolCalls() {
am.ToolCalls = append(am.ToolCalls, chat.ToolCall{
ID: tc.ID,
Type: "function",
Function: chat.FunctionCall{
Name: tc.Name,
Arguments: string(tc.Arguments),
},
})
}
return &am
default:
return &chat.UserMessage{Content: chat.TextContent(m.TextContent())}
}
}
// expandToolResults handles the case where a gnoma Message contains
// multiple ToolResults. Mistral expects one ToolMessage per result.
func expandToolResults(msgs []message.Message) []chat.Message {
out := make([]chat.Message, 0, len(msgs))
for _, m := range msgs {
if m.Role == message.RoleUser && len(m.Content) > 0 && m.Content[0].Type == message.ContentToolResult {
for _, c := range m.Content {
if c.Type == message.ContentToolResult && c.ToolResult != nil {
out = append(out, &chat.ToolMessage{
ToolCallID: c.ToolResult.ToolCallID,
Content: chat.TextContent(c.ToolResult.Content),
})
}
}
continue
}
out = append(out, translateMessage(m))
}
return out
}
func translateTools(defs []provider.ToolDefinition) []chat.Tool {
if len(defs) == 0 {
return nil
}
tools := make([]chat.Tool, len(defs))
for i, d := range defs {
var params map[string]any
if d.Parameters != nil {
_ = json.Unmarshal(d.Parameters, &params)
}
tools[i] = chat.Tool{
Type: "function",
Function: chat.Function{
Name: d.Name,
Description: d.Description,
Parameters: params,
},
}
}
return tools
}
func translateRequest(req provider.Request) *chat.CompletionRequest {
cr := &chat.CompletionRequest{
Model: req.Model,
Messages: expandToolResults(req.Messages),
Tools: translateTools(req.Tools),
Stop: req.StopSequences,
}
if req.MaxTokens > 0 {
mt := int(req.MaxTokens)
cr.MaxTokens = &mt
}
if req.Temperature != nil {
cr.Temperature = req.Temperature
}
if req.TopP != nil {
cr.TopP = req.TopP
}
if req.ResponseFormat != nil {
cr.ResponseFormat = translateResponseFormat(req.ResponseFormat)
}
return cr
}
func translateResponseFormat(rf *provider.ResponseFormat) *chat.ResponseFormat {
if rf == nil {
return nil
}
out := &chat.ResponseFormat{
Type: chat.ResponseFormatType(rf.Type),
}
if rf.JSONSchema != nil {
var schema map[string]any
if rf.JSONSchema.Schema != nil {
_ = json.Unmarshal(rf.JSONSchema.Schema, &schema)
}
out.JsonSchema = &chat.JsonSchema{
Name: rf.JSONSchema.Name,
Schema: schema,
Strict: rf.JSONSchema.Strict,
}
if rf.JSONSchema.Description != "" {
desc := rf.JSONSchema.Description
out.JsonSchema.Description = &desc
}
}
return out
}
// --- Mistral → gnoma ---
func translateFinishReason(fr *chat.FinishReason) message.StopReason {
if fr == nil {
return ""
}
switch *fr {
case chat.FinishReasonStop:
return message.StopEndTurn
case chat.FinishReasonToolCalls:
return message.StopToolUse
case chat.FinishReasonLength, chat.FinishReasonModelLength:
return message.StopMaxTokens
default:
return message.StopEndTurn
}
}
func translateUsage(u *chat.UsageInfo) *message.Usage {
if u == nil {
return nil
}
return &message.Usage{
InputTokens: int64(u.PromptTokens),
OutputTokens: int64(u.CompletionTokens),
}
}

View File

@@ -0,0 +1,256 @@
package mistral
import (
"encoding/json"
"testing"
"somegit.dev/Owlibou/gnoma/internal/message"
"somegit.dev/Owlibou/gnoma/internal/provider"
"somegit.dev/vikingowl/mistral-go-sdk/chat"
)
func TestTranslateMessage_User(t *testing.T) {
m := message.NewUserText("hello world")
result := translateMessage(m)
um, ok := result.(*chat.UserMessage)
if !ok {
t.Fatalf("expected *UserMessage, got %T", result)
}
if um.Content.String() != "hello world" {
t.Errorf("Content = %q, want %q", um.Content.String(), "hello world")
}
}
func TestTranslateMessage_System(t *testing.T) {
m := message.NewSystemText("you are a helper")
result := translateMessage(m)
sm, ok := result.(*chat.SystemMessage)
if !ok {
t.Fatalf("expected *SystemMessage, got %T", result)
}
if sm.Content.String() != "you are a helper" {
t.Errorf("Content = %q", sm.Content.String())
}
}
func TestTranslateMessage_AssistantText(t *testing.T) {
m := message.NewAssistantText("here's the answer")
result := translateMessage(m)
am, ok := result.(*chat.AssistantMessage)
if !ok {
t.Fatalf("expected *AssistantMessage, got %T", result)
}
if am.Content.String() != "here's the answer" {
t.Errorf("Content = %q", am.Content.String())
}
if len(am.ToolCalls) != 0 {
t.Errorf("ToolCalls should be empty, got %d", len(am.ToolCalls))
}
}
func TestTranslateMessage_AssistantWithToolCalls(t *testing.T) {
m := message.NewAssistantContent(
message.NewTextContent("running command"),
message.NewToolCallContent(message.ToolCall{
ID: "tc_1",
Name: "bash",
Arguments: json.RawMessage(`{"command":"ls"}`),
}),
)
result := translateMessage(m)
am, ok := result.(*chat.AssistantMessage)
if !ok {
t.Fatalf("expected *AssistantMessage, got %T", result)
}
if len(am.ToolCalls) != 1 {
t.Fatalf("len(ToolCalls) = %d, want 1", len(am.ToolCalls))
}
if am.ToolCalls[0].ID != "tc_1" {
t.Errorf("ToolCalls[0].ID = %q", am.ToolCalls[0].ID)
}
if am.ToolCalls[0].Function.Name != "bash" {
t.Errorf("ToolCalls[0].Function.Name = %q", am.ToolCalls[0].Function.Name)
}
if am.ToolCalls[0].Function.Arguments != `{"command":"ls"}` {
t.Errorf("ToolCalls[0].Function.Arguments = %q", am.ToolCalls[0].Function.Arguments)
}
}
func TestExpandToolResults(t *testing.T) {
msgs := []message.Message{
message.NewUserText("run two commands"),
message.NewAssistantContent(
message.NewToolCallContent(message.ToolCall{ID: "tc_1", Name: "bash"}),
message.NewToolCallContent(message.ToolCall{ID: "tc_2", Name: "bash"}),
),
message.NewToolResults(
message.ToolResult{ToolCallID: "tc_1", Content: "output1"},
message.ToolResult{ToolCallID: "tc_2", Content: "output2"},
),
}
expanded := expandToolResults(msgs)
// UserMessage, AssistantMessage, ToolMessage, ToolMessage
if len(expanded) != 4 {
t.Fatalf("len(expanded) = %d, want 4", len(expanded))
}
// First: UserMessage
if _, ok := expanded[0].(*chat.UserMessage); !ok {
t.Errorf("expanded[0] = %T, want *UserMessage", expanded[0])
}
// Second: AssistantMessage
if _, ok := expanded[1].(*chat.AssistantMessage); !ok {
t.Errorf("expanded[1] = %T, want *AssistantMessage", expanded[1])
}
// Third and fourth: ToolMessages
tm1, ok := expanded[2].(*chat.ToolMessage)
if !ok {
t.Fatalf("expanded[2] = %T, want *ToolMessage", expanded[2])
}
if tm1.ToolCallID != "tc_1" {
t.Errorf("expanded[2].ToolCallID = %q, want tc_1", tm1.ToolCallID)
}
tm2, ok := expanded[3].(*chat.ToolMessage)
if !ok {
t.Fatalf("expanded[3] = %T, want *ToolMessage", expanded[3])
}
if tm2.ToolCallID != "tc_2" {
t.Errorf("expanded[3].ToolCallID = %q, want tc_2", tm2.ToolCallID)
}
}
func TestTranslateTools(t *testing.T) {
defs := []provider.ToolDefinition{
{
Name: "bash",
Description: "Run a bash command",
Parameters: json.RawMessage(`{"type":"object","properties":{"command":{"type":"string"}},"required":["command"]}`),
},
{
Name: "fs.read",
Description: "Read a file",
Parameters: json.RawMessage(`{"type":"object","properties":{"path":{"type":"string"}}}`),
},
}
tools := translateTools(defs)
if len(tools) != 2 {
t.Fatalf("len(tools) = %d, want 2", len(tools))
}
if tools[0].Type != "function" {
t.Errorf("tools[0].Type = %q, want function", tools[0].Type)
}
if tools[0].Function.Name != "bash" {
t.Errorf("tools[0].Function.Name = %q", tools[0].Function.Name)
}
if tools[0].Function.Description != "Run a bash command" {
t.Errorf("tools[0].Function.Description = %q", tools[0].Function.Description)
}
if tools[0].Function.Parameters == nil {
t.Error("tools[0].Function.Parameters should not be nil")
}
// Verify the parameters were correctly unmarshaled
if _, ok := tools[0].Function.Parameters["type"]; !ok {
t.Error("tools[0].Function.Parameters missing 'type' key")
}
}
func TestTranslateTools_Empty(t *testing.T) {
tools := translateTools(nil)
if tools != nil {
t.Errorf("translateTools(nil) should return nil, got %v", tools)
}
}
func TestTranslateFinishReason(t *testing.T) {
tests := []struct {
name string
reason *chat.FinishReason
want message.StopReason
}{
{"nil", nil, ""},
{"stop", ptr(chat.FinishReasonStop), message.StopEndTurn},
{"tool_calls", ptr(chat.FinishReasonToolCalls), message.StopToolUse},
{"length", ptr(chat.FinishReasonLength), message.StopMaxTokens},
{"model_length", ptr(chat.FinishReasonModelLength), message.StopMaxTokens},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := translateFinishReason(tt.reason)
if got != tt.want {
t.Errorf("translateFinishReason() = %q, want %q", got, tt.want)
}
})
}
}
func TestTranslateUsage(t *testing.T) {
u := &chat.UsageInfo{
PromptTokens: 100,
CompletionTokens: 50,
TotalTokens: 150,
}
result := translateUsage(u)
if result.InputTokens != 100 {
t.Errorf("InputTokens = %d, want 100", result.InputTokens)
}
if result.OutputTokens != 50 {
t.Errorf("OutputTokens = %d, want 50", result.OutputTokens)
}
}
func TestTranslateUsage_Nil(t *testing.T) {
result := translateUsage(nil)
if result != nil {
t.Error("translateUsage(nil) should return nil")
}
}
func TestTranslateRequest(t *testing.T) {
temp := 0.7
req := provider.Request{
Model: "mistral-large-latest",
SystemPrompt: "you are helpful",
Messages: []message.Message{
message.NewSystemText("you are helpful"),
message.NewUserText("hello"),
},
Tools: []provider.ToolDefinition{
{Name: "bash", Description: "Run command", Parameters: json.RawMessage(`{"type":"object"}`)},
},
MaxTokens: 4096,
Temperature: &temp,
}
cr := translateRequest(req)
if cr.Model != "mistral-large-latest" {
t.Errorf("Model = %q", cr.Model)
}
if len(cr.Messages) != 2 {
t.Errorf("len(Messages) = %d, want 2", len(cr.Messages))
}
if len(cr.Tools) != 1 {
t.Errorf("len(Tools) = %d, want 1", len(cr.Tools))
}
if cr.MaxTokens == nil || *cr.MaxTokens != 4096 {
t.Errorf("MaxTokens = %v", cr.MaxTokens)
}
if cr.Temperature == nil || *cr.Temperature != 0.7 {
t.Errorf("Temperature = %v", cr.Temperature)
}
}
func ptr[T any](v T) *T {
return &v
}

View File

@@ -20,6 +20,7 @@ type Request struct {
TopK *int64
StopSequences []string
Thinking *ThinkingConfig
ResponseFormat *ResponseFormat
}
// ToolDefinition is the provider-agnostic tool schema.
@@ -34,6 +35,50 @@ type ThinkingConfig struct {
BudgetTokens int64
}
// ResponseFormat controls the output format.
type ResponseFormat struct {
Type ResponseFormatType
JSONSchema *JSONSchema // only used when Type == ResponseJSON
}
type ResponseFormatType string
const (
ResponseText ResponseFormatType = "text"
ResponseJSON ResponseFormatType = "json_object"
)
// JSONSchema defines a schema for structured JSON output.
type JSONSchema struct {
Name string `json:"name"`
Description string `json:"description,omitempty"`
Schema json.RawMessage `json:"schema"`
Strict bool `json:"strict,omitempty"`
}
// Capabilities describes what a model can do.
type Capabilities struct {
ToolUse bool `json:"tool_use"`
JSONOutput bool `json:"json_output"`
Thinking bool `json:"thinking"`
Vision bool `json:"vision"`
ContextWindow int `json:"context_window"`
MaxOutput int `json:"max_output"`
}
// ModelInfo describes a model available from a provider.
type ModelInfo struct {
ID string `json:"id"`
Name string `json:"name"`
Provider string `json:"provider"`
Capabilities Capabilities `json:"capabilities"`
}
// SupportsTools returns true if the model supports tool/function calling.
func (m ModelInfo) SupportsTools() bool {
return m.Capabilities.ToolUse
}
// Provider is the core abstraction over all LLM backends.
type Provider interface {
// Stream initiates a streaming request and returns an event stream.
@@ -41,4 +86,10 @@ type Provider interface {
// Name returns the provider identifier (e.g., "mistral", "anthropic").
Name() string
// Models returns available models with their capabilities.
Models(ctx context.Context) ([]ModelInfo, error)
// DefaultModel returns the default model ID for this provider.
DefaultModel() string
}

View File

@@ -19,8 +19,10 @@ func (m *mockProvider) Stream(_ context.Context, _ Request) (stream.Stream, erro
return nil, nil
}
func (m *mockProvider) Name() string {
return m.name
func (m *mockProvider) Name() string { return m.name }
func (m *mockProvider) DefaultModel() string { return "mock-model" }
func (m *mockProvider) Models(_ context.Context) ([]ModelInfo, error) {
return []ModelInfo{{ID: "mock-model", Name: "mock-model", Provider: m.name}}, nil
}
func TestRegistry_RegisterAndCreate(t *testing.T) {

View File

@@ -68,7 +68,13 @@ func (a *Accumulator) Apply(e Event) {
}
case EventToolCallDone:
if tc, ok := a.toolCalls[e.ToolCallID]; ok {
tc, ok := a.toolCalls[e.ToolCallID]
if !ok {
// Done without prior Start (e.g., Mistral sends complete tool calls in one chunk)
tc = &toolCallAccum{id: e.ToolCallID, name: e.ToolCallName}
a.toolCalls[e.ToolCallID] = tc
a.toolCallOrder = append(a.toolCallOrder, e.ToolCallID)
}
if e.Args != nil {
// Done event carries authoritative complete args
tc.args = e.Args
@@ -76,7 +82,6 @@ func (a *Accumulator) Apply(e Event) {
// Fall back to accumulated deltas
tc.args = []byte(tc.argsBuf.String())
}
}
case EventUsage:
if e.Usage != nil {

View File

@@ -0,0 +1,231 @@
package bash
import (
"context"
"fmt"
"os"
"os/exec"
"strings"
"sync"
"time"
)
const aliasHarvestTimeout = 5 * time.Second
// AliasMap holds harvested shell aliases.
type AliasMap struct {
mu sync.RWMutex
aliases map[string]string // alias name → expansion
}
func NewAliasMap() *AliasMap {
return &AliasMap{aliases: make(map[string]string)}
}
// Get returns the expansion for an alias, or empty string if not found.
func (m *AliasMap) Get(name string) (string, bool) {
m.mu.RLock()
defer m.mu.RUnlock()
exp, ok := m.aliases[name]
return exp, ok
}
// Len returns the number of harvested aliases.
func (m *AliasMap) Len() int {
m.mu.RLock()
defer m.mu.RUnlock()
return len(m.aliases)
}
// All returns a copy of all aliases.
func (m *AliasMap) All() map[string]string {
m.mu.RLock()
defer m.mu.RUnlock()
cp := make(map[string]string, len(m.aliases))
for k, v := range m.aliases {
cp[k] = v
}
return cp
}
// ExpandCommand expands the first word of a command if it's a known alias.
// Only the first word is expanded (matching bash alias behavior).
// Returns the original command unchanged if no alias matches.
func (m *AliasMap) ExpandCommand(cmd string) string {
trimmed := strings.TrimSpace(cmd)
if trimmed == "" {
return cmd
}
// Extract first word
firstWord := trimmed
rest := ""
if idx := strings.IndexAny(trimmed, " \t"); idx != -1 {
firstWord = trimmed[:idx]
rest = trimmed[idx:]
}
m.mu.RLock()
expansion, ok := m.aliases[firstWord]
m.mu.RUnlock()
if !ok {
return cmd
}
return expansion + rest
}
// HarvestAliases spawns the user's shell once to collect alias definitions.
// Supports bash, zsh, and fish. Falls back gracefully for unknown shells.
// Safe: only reads alias text definitions, never sources them in execution context.
func HarvestAliases(ctx context.Context) (*AliasMap, error) {
shell := os.Getenv("SHELL")
if shell == "" {
shell = "/bin/bash"
}
ctx, cancel := context.WithTimeout(ctx, aliasHarvestTimeout)
defer cancel()
// Build the alias dump command based on shell type
shellBase := shellBaseName(shell)
aliasCmd := aliasCommandFor(shellBase)
// -i: interactive (loads rc files), -c: run command then exit
cmd := exec.CommandContext(ctx, shell, "-ic", aliasCmd)
// Prevent the interactive shell from reading actual stdin
cmd.Stdin = nil
// Suppress stderr (shell startup warnings like zsh's "can't change option: zle")
cmd.Stderr = nil
// Use Output() but don't fail on non-zero exit — zsh often exits with
// errors from zle/prompt setup while still producing valid alias output
output, err := cmd.Output()
if len(output) == 0 && err != nil {
return NewAliasMap(), fmt.Errorf("alias harvest (%s): %w", shellBase, err)
}
// If we got output, parse it regardless of exit code
if shellBase == "fish" {
return ParseFishAliases(string(output))
}
return ParseAliases(string(output))
}
// shellBaseName extracts the shell name from a path (e.g., "/bin/zsh" → "zsh").
func shellBaseName(shell string) string {
parts := strings.Split(shell, "/")
return parts[len(parts)-1]
}
// aliasCommandFor returns the alias dump command for a given shell.
func aliasCommandFor(shell string) string {
switch shell {
case "fish":
// fish uses `alias` without -p, outputs: alias name 'expansion'
return "alias 2>/dev/null; true"
case "zsh":
// zsh: `alias -p` produces nothing; `alias` outputs name=value (no quotes)
return "alias 2>/dev/null; true"
case "bash", "sh", "dash", "ash":
// POSIX shells use `alias -p`
return "alias -p 2>/dev/null; true"
default:
// Best effort for unknown shells
return "alias 2>/dev/null; true"
}
}
// ParseFishAliases parses fish shell alias output.
// Fish format: alias name 'expansion' or alias name "expansion"
func ParseFishAliases(output string) (*AliasMap, error) {
m := NewAliasMap()
for _, line := range strings.Split(output, "\n") {
line = strings.TrimSpace(line)
if line == "" || !strings.HasPrefix(line, "alias ") {
continue
}
// Remove "alias " prefix
rest := strings.TrimPrefix(line, "alias ")
// Split: name 'expansion' or name "expansion" or name expansion
spaceIdx := strings.IndexByte(rest, ' ')
if spaceIdx == -1 {
continue
}
name := rest[:spaceIdx]
expansion := strings.TrimSpace(rest[spaceIdx+1:])
expansion = stripQuotes(expansion)
if name == "" || expansion == "" {
continue
}
if v := ValidateCommand(expansion); v != nil {
continue
}
m.mu.Lock()
m.aliases[name] = expansion
m.mu.Unlock()
}
return m, nil
}
// ParseAliases parses the output of `alias -p` into an AliasMap.
// Each line is: alias name='expansion' (bash) or name=expansion (zsh)
func ParseAliases(output string) (*AliasMap, error) {
m := NewAliasMap()
for _, line := range strings.Split(output, "\n") {
line = strings.TrimSpace(line)
if line == "" {
continue
}
// Strip "alias " prefix if present (bash format)
line = strings.TrimPrefix(line, "alias ")
// Split on first '='
eqIdx := strings.Index(line, "=")
if eqIdx == -1 {
continue
}
name := line[:eqIdx]
expansion := line[eqIdx+1:]
// Strip surrounding quotes from expansion
expansion = stripQuotes(expansion)
if name == "" || expansion == "" {
continue
}
// Security: validate the expansion doesn't contain dangerous patterns
if v := ValidateCommand(expansion); v != nil {
// Skip aliases with dangerous expansions
continue
}
m.mu.Lock()
m.aliases[name] = expansion
m.mu.Unlock()
}
return m, nil
}
// stripQuotes removes matching surrounding single or double quotes.
func stripQuotes(s string) string {
if len(s) < 2 {
return s
}
if (s[0] == '\'' && s[len(s)-1] == '\'') || (s[0] == '"' && s[len(s)-1] == '"') {
return s[1 : len(s)-1]
}
return s
}

View File

@@ -0,0 +1,288 @@
package bash
import (
"context"
"testing"
)
func TestParseAliases_BashFormat(t *testing.T) {
output := `alias gs='git status'
alias ll='ls -la --color=auto'
alias gco='git checkout'
alias ..='cd ..'
`
m, err := ParseAliases(output)
if err != nil {
t.Fatalf("ParseAliases: %v", err)
}
if m.Len() != 4 {
t.Errorf("Len() = %d, want 4", m.Len())
}
tests := []struct {
name, want string
}{
{"gs", "git status"},
{"ll", "ls -la --color=auto"},
{"gco", "git checkout"},
{"..", "cd .."},
}
for _, tt := range tests {
got, ok := m.Get(tt.name)
if !ok {
t.Errorf("alias %q not found", tt.name)
continue
}
if got != tt.want {
t.Errorf("alias %q = %q, want %q", tt.name, got, tt.want)
}
}
}
func TestParseAliases_ZshFormat(t *testing.T) {
// zsh alias -p may omit 'alias ' prefix
output := `gs='git status'
ll='ls -la'
`
m, err := ParseAliases(output)
if err != nil {
t.Fatalf("ParseAliases: %v", err)
}
got, ok := m.Get("gs")
if !ok || got != "git status" {
t.Errorf("gs = %q, %v", got, ok)
}
}
func TestParseAliases_DoubleQuotes(t *testing.T) {
output := `alias gs="git status"
`
m, _ := ParseAliases(output)
got, ok := m.Get("gs")
if !ok || got != "git status" {
t.Errorf("gs = %q, %v", got, ok)
}
}
func TestParseAliases_SkipsDangerousExpansions(t *testing.T) {
output := `alias safe='ls -la'
alias danger='echo $(whoami)'
alias backtick='echo ` + "`" + `date` + "`" + `'
alias ifshack='IFS=: read a b'
`
m, _ := ParseAliases(output)
if _, ok := m.Get("safe"); !ok {
t.Error("safe alias should be kept")
}
if _, ok := m.Get("danger"); ok {
t.Error("danger alias ($()) should be filtered")
}
if _, ok := m.Get("backtick"); ok {
t.Error("backtick alias should be filtered")
}
if _, ok := m.Get("ifshack"); ok {
t.Error("IFS alias should be filtered")
}
}
func TestParseAliases_EmptyAndMalformed(t *testing.T) {
output := `
alias gs='git status'
not a valid line
alias =empty_name
alias noequals
`
m, _ := ParseAliases(output)
if m.Len() != 1 {
t.Errorf("Len() = %d, want 1 (only gs)", m.Len())
}
}
func TestAliasMap_ExpandCommand(t *testing.T) {
m := NewAliasMap()
m.mu.Lock()
m.aliases["ll"] = "ls -la --color=auto"
m.aliases["gs"] = "git status"
m.aliases[".."] = "cd .."
m.mu.Unlock()
tests := []struct {
input string
want string
}{
// Alias with args
{"ll /tmp", "ls -la --color=auto /tmp"},
// Alias without args
{"gs", "git status"},
// Alias with trailing whitespace (trimmed)
{"gs ", "git status"},
// No alias match — return unchanged
{"echo hello", "echo hello"},
// Dotdot alias
{"..", "cd .."},
// Empty command
{"", ""},
// Only whitespace
{" ", " "},
}
for _, tt := range tests {
got := m.ExpandCommand(tt.input)
if got != tt.want {
t.Errorf("ExpandCommand(%q) = %q, want %q", tt.input, got, tt.want)
}
}
}
func TestAliasMap_ExpandCommand_NoAliases(t *testing.T) {
m := NewAliasMap()
got := m.ExpandCommand("echo hello")
if got != "echo hello" {
t.Errorf("ExpandCommand = %q, want unchanged", got)
}
}
func TestAliasMap_All(t *testing.T) {
m := NewAliasMap()
m.mu.Lock()
m.aliases["a"] = "b"
m.aliases["c"] = "d"
m.mu.Unlock()
all := m.All()
if len(all) != 2 {
t.Errorf("len(All()) = %d, want 2", len(all))
}
// Verify it's a copy
all["x"] = "y"
if m.Len() != 2 {
t.Error("All() should return a copy, not a reference")
}
}
func TestStripQuotes(t *testing.T) {
tests := []struct {
input, want string
}{
{"'hello'", "hello"},
{`"hello"`, "hello"},
{"hello", "hello"},
{"'h'", "h"},
{"''", ""},
{`""`, ""},
{"'mismatched\"", "'mismatched\""},
{"x", "x"},
{"", ""},
}
for _, tt := range tests {
got := stripQuotes(tt.input)
if got != tt.want {
t.Errorf("stripQuotes(%q) = %q, want %q", tt.input, got, tt.want)
}
}
}
func TestParseFishAliases(t *testing.T) {
output := `alias gs 'git status'
alias ll 'ls -la'
alias gco "git checkout"
`
m, err := ParseFishAliases(output)
if err != nil {
t.Fatalf("ParseFishAliases: %v", err)
}
if m.Len() != 3 {
t.Errorf("Len() = %d, want 3", m.Len())
}
got, ok := m.Get("gs")
if !ok || got != "git status" {
t.Errorf("gs = %q, %v", got, ok)
}
got, ok = m.Get("gco")
if !ok || got != "git checkout" {
t.Errorf("gco = %q, %v", got, ok)
}
}
func TestShellBaseName(t *testing.T) {
tests := []struct {
input, want string
}{
{"/bin/bash", "bash"},
{"/usr/bin/zsh", "zsh"},
{"/usr/local/bin/fish", "fish"},
{"bash", "bash"},
{"/bin/sh", "sh"},
}
for _, tt := range tests {
got := shellBaseName(tt.input)
if got != tt.want {
t.Errorf("shellBaseName(%q) = %q, want %q", tt.input, got, tt.want)
}
}
}
func TestAliasCommandFor(t *testing.T) {
tests := []struct {
shell string
want string
}{
{"bash", "alias -p 2>/dev/null; true"},
{"zsh", "alias 2>/dev/null; true"},
{"fish", "alias 2>/dev/null; true"},
{"sh", "alias -p 2>/dev/null; true"},
{"unknown", "alias 2>/dev/null; true"},
}
for _, tt := range tests {
got := aliasCommandFor(tt.shell)
if got != tt.want {
t.Errorf("aliasCommandFor(%q) = %q, want %q", tt.shell, got, tt.want)
}
}
}
func TestHarvestAliases_Integration(t *testing.T) {
// This actually runs the user's shell — skip in CI
if testing.Short() {
t.Skip("skipping alias harvest in short mode")
}
m, err := HarvestAliases(context.Background())
if err != nil {
// Non-fatal: harvesting may fail in some environments
t.Logf("HarvestAliases: %v (non-fatal)", err)
}
t.Logf("Harvested %d aliases", m.Len())
for name, exp := range m.All() {
t.Logf(" %s → %s", name, exp)
}
}
func TestBashTool_WithAliases(t *testing.T) {
aliases := NewAliasMap()
aliases.mu.Lock()
aliases.aliases["ll"] = "ls -la"
aliases.mu.Unlock()
b := New(WithAliases(aliases))
// "ll /tmp" should expand to "ls -la /tmp" and execute
result, err := b.Execute(context.Background(), []byte(`{"command":"ll /tmp"}`))
if err != nil {
t.Fatalf("Execute: %v", err)
}
// Should produce output (ls -la /tmp lists files)
if result.Output == "" {
t.Error("expected output from expanded alias")
}
if result.Metadata["blocked"] == true {
t.Error("expanded alias should not be blocked")
}
}

140
internal/tool/bash/bash.go Normal file
View File

@@ -0,0 +1,140 @@
package bash
import (
"context"
"encoding/json"
"fmt"
"os/exec"
"strings"
"time"
"somegit.dev/Owlibou/gnoma/internal/tool"
)
const (
defaultTimeout = 30 * time.Second
toolName = "bash"
)
var parameterSchema = json.RawMessage(`{
"type": "object",
"properties": {
"command": {
"type": "string",
"description": "The bash command to execute"
},
"timeout": {
"type": "integer",
"description": "Timeout in seconds (default 30)"
}
},
"required": ["command"]
}`)
// Tool executes bash commands.
type Tool struct {
timeout time.Duration
workingDir string
aliases *AliasMap
}
type Option func(*Tool)
func WithTimeout(d time.Duration) Option {
return func(t *Tool) { t.timeout = d }
}
func WithWorkingDir(dir string) Option {
return func(t *Tool) { t.workingDir = dir }
}
func WithAliases(aliases *AliasMap) Option {
return func(t *Tool) { t.aliases = aliases }
}
// New creates a bash tool.
func New(opts ...Option) *Tool {
t := &Tool{timeout: defaultTimeout}
for _, opt := range opts {
opt(t)
}
return t
}
func (t *Tool) Name() string { return toolName }
func (t *Tool) Description() string { return "Execute a bash command and return its output" }
func (t *Tool) Parameters() json.RawMessage { return parameterSchema }
func (t *Tool) IsReadOnly() bool { return false }
func (t *Tool) IsDestructive() bool { return true }
type bashArgs struct {
Command string `json:"command"`
Timeout int `json:"timeout,omitempty"`
}
func (t *Tool) Execute(ctx context.Context, args json.RawMessage) (tool.Result, error) {
var a bashArgs
if err := json.Unmarshal(args, &a); err != nil {
return tool.Result{}, fmt.Errorf("bash: invalid args: %w", err)
}
if a.Command == "" {
return tool.Result{}, fmt.Errorf("bash: empty command")
}
// Expand aliases (first word only, matching bash behavior)
command := a.Command
if t.aliases != nil {
command = t.aliases.ExpandCommand(command)
}
// Security validation runs on the expanded command
if violation := ValidateCommand(command); violation != nil {
return tool.Result{
Output: fmt.Sprintf("Command blocked: %s", violation.Message),
Metadata: map[string]any{"blocked": true, "check": int(violation.Check)},
}, nil
}
timeout := t.timeout
if a.Timeout > 0 {
timeout = time.Duration(a.Timeout) * time.Second
}
ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
cmd := exec.CommandContext(ctx, "bash", "-c", command)
if t.workingDir != "" {
cmd.Dir = t.workingDir
}
output, err := cmd.CombinedOutput()
exitCode := 0
if err != nil {
// Check timeout first — context deadline may also produce an ExitError
if ctx.Err() == context.DeadlineExceeded {
return tool.Result{
Output: fmt.Sprintf("Command timed out after %s\n%s", timeout, strings.TrimRight(string(output), "\n")),
Metadata: map[string]any{"exit_code": -1, "timeout": true},
}, nil
}
if exitErr, ok := err.(*exec.ExitError); ok {
exitCode = exitErr.ExitCode()
} else {
return tool.Result{}, fmt.Errorf("bash: exec failed: %w", err)
}
}
result := tool.Result{
Output: strings.TrimRight(string(output), "\n"),
Metadata: map[string]any{"exit_code": exitCode},
}
if exitCode != 0 {
result.Output = fmt.Sprintf("Exit code %d\n%s", exitCode, result.Output)
}
return result, nil
}

View File

@@ -0,0 +1,135 @@
package bash
import (
"context"
"encoding/json"
"strings"
"testing"
"time"
)
func TestBashTool_Interface(t *testing.T) {
b := New()
if b.Name() != "bash" {
t.Errorf("Name() = %q", b.Name())
}
if b.IsReadOnly() {
t.Error("bash should not be read-only")
}
if !b.IsDestructive() {
t.Error("bash should be destructive")
}
if b.Parameters() == nil {
t.Error("Parameters() should not be nil")
}
}
func TestBashTool_Echo(t *testing.T) {
b := New()
result, err := b.Execute(context.Background(), json.RawMessage(`{"command":"echo hello world"}`))
if err != nil {
t.Fatalf("Execute: %v", err)
}
if result.Output != "hello world" {
t.Errorf("Output = %q, want %q", result.Output, "hello world")
}
if result.Metadata["exit_code"] != 0 {
t.Errorf("exit_code = %v, want 0", result.Metadata["exit_code"])
}
}
func TestBashTool_ExitCode(t *testing.T) {
b := New()
result, err := b.Execute(context.Background(), json.RawMessage(`{"command":"exit 42"}`))
if err != nil {
t.Fatalf("Execute: %v", err)
}
if result.Metadata["exit_code"] != 42 {
t.Errorf("exit_code = %v, want 42", result.Metadata["exit_code"])
}
if !strings.HasPrefix(result.Output, "Exit code 42") {
t.Errorf("Output = %q, should start with exit code", result.Output)
}
}
func TestBashTool_Timeout(t *testing.T) {
b := New(WithTimeout(100 * time.Millisecond))
result, err := b.Execute(context.Background(), json.RawMessage(`{"command":"sleep 10"}`))
if err != nil {
t.Fatalf("Execute: %v", err)
}
if result.Metadata["timeout"] != true {
t.Error("should have timed out")
}
if !strings.Contains(result.Output, "timed out") {
t.Errorf("Output = %q, should mention timeout", result.Output)
}
}
func TestBashTool_CustomTimeout(t *testing.T) {
b := New(WithTimeout(30 * time.Second))
// Args override the default timeout
result, err := b.Execute(context.Background(), json.RawMessage(`{"command":"sleep 10","timeout":1}`))
if err != nil {
t.Fatalf("Execute: %v", err)
}
if result.Metadata["timeout"] != true {
t.Error("should have timed out with custom 1s timeout")
}
}
func TestBashTool_InvalidArgs(t *testing.T) {
b := New()
_, err := b.Execute(context.Background(), json.RawMessage(`not json`))
if err == nil {
t.Error("expected error for invalid JSON")
}
}
func TestBashTool_EmptyCommand(t *testing.T) {
b := New()
_, err := b.Execute(context.Background(), json.RawMessage(`{"command":""}`))
if err == nil {
t.Error("expected error for empty command")
}
}
func TestBashTool_SecurityBlock(t *testing.T) {
b := New()
// Command substitution should be blocked
result, err := b.Execute(context.Background(), json.RawMessage(`{"command":"echo $(whoami)"}`))
if err != nil {
t.Fatalf("Execute: %v", err)
}
if result.Metadata["blocked"] != true {
t.Error("command with $() should be blocked")
}
if !strings.Contains(result.Output, "blocked") {
t.Errorf("Output = %q, should mention blocked", result.Output)
}
}
func TestBashTool_WorkingDir(t *testing.T) {
b := New(WithWorkingDir(t.TempDir()))
result, err := b.Execute(context.Background(), json.RawMessage(`{"command":"pwd"}`))
if err != nil {
t.Fatalf("Execute: %v", err)
}
if result.Output == "" {
t.Error("pwd should produce output")
}
}
func TestBashTool_ContextCancellation(t *testing.T) {
b := New()
ctx, cancel := context.WithCancel(context.Background())
cancel() // cancel immediately
_, err := b.Execute(ctx, json.RawMessage(`{"command":"echo hello"}`))
// Should either return an error or a timeout result
if err == nil {
// That's ok too — context cancellation is best-effort for fast commands
return
}
}

View File

@@ -0,0 +1,206 @@
package bash
import (
"fmt"
"strings"
"unicode"
)
// SecurityCheck identifies a specific validation check.
type SecurityCheck int
const (
CheckIncomplete SecurityCheck = iota + 1 // fragments, trailing operators
CheckMetacharacters // ; | & $ ` < >
CheckCmdSubstitution // $(), ``, ${}
CheckRedirection // < > >> etc.
CheckDangerousVars // IFS, PATH manipulation
CheckNewlineInjection // embedded newlines
CheckControlChars // ASCII 00-1F (except \n \t)
)
// SecurityViolation describes a failed security check.
type SecurityViolation struct {
Check SecurityCheck
Message string
}
func (v SecurityViolation) Error() string {
return fmt.Sprintf("bash security check %d: %s", v.Check, v.Message)
}
// ValidateCommand runs the 7 critical security checks against a command string.
// Returns nil if all checks pass, or the first violation found.
func ValidateCommand(cmd string) *SecurityViolation {
if strings.TrimSpace(cmd) == "" {
return &SecurityViolation{Check: CheckIncomplete, Message: "empty command"}
}
// Check incomplete on raw command (before trimming) to catch tab-starts
if v := checkIncomplete(cmd); v != nil {
return v
}
cmd = strings.TrimSpace(cmd)
if v := checkControlChars(cmd); v != nil {
return v
}
if v := checkNewlineInjection(cmd); v != nil {
return v
}
if v := checkCmdSubstitution(cmd); v != nil {
return v
}
if v := checkDangerousVars(cmd); v != nil {
return v
}
// Metacharacters and redirection are warnings, not blocks in M1.
// The LLM legitimately uses pipes and redirects.
// Full compound command parsing (mvdan.cc/sh) comes in M5.
return nil
}
// checkIncomplete detects command fragments that shouldn't be executed.
func checkIncomplete(cmd string) *SecurityViolation {
// Starts with tab (likely a fragment from indented code)
if cmd[0] == '\t' {
return &SecurityViolation{Check: CheckIncomplete, Message: "command starts with tab (likely a code fragment)"}
}
// Starts with a flag (no command name)
if cmd[0] == '-' {
return &SecurityViolation{Check: CheckIncomplete, Message: "command starts with flag (no command name)"}
}
// Ends with a dangling operator
trimmed := strings.TrimRight(cmd, " \t")
if len(trimmed) > 0 {
last := trimmed[len(trimmed)-1]
if last == '|' || last == '&' || last == ';' {
return &SecurityViolation{Check: CheckIncomplete, Message: "command ends with dangling operator"}
}
}
return nil
}
// checkControlChars blocks ASCII control characters (0x00-0x1F) except \n and \t.
func checkControlChars(cmd string) *SecurityViolation {
for i, r := range cmd {
if r < 0x20 && r != '\n' && r != '\t' && r != '\r' {
return &SecurityViolation{
Check: CheckControlChars,
Message: fmt.Sprintf("control character U+%04X at position %d", r, i),
}
}
}
return nil
}
// checkNewlineInjection blocks commands with embedded newlines.
// Newlines in quoted strings are legitimate but rare in single commands.
// We allow them inside single/double quotes only.
func checkNewlineInjection(cmd string) *SecurityViolation {
inSingle := false
inDouble := false
escaped := false
for _, r := range cmd {
if escaped {
escaped = false
continue
}
if r == '\\' && !inSingle {
escaped = true
continue
}
if r == '\'' && !inDouble {
inSingle = !inSingle
continue
}
if r == '"' && !inSingle {
inDouble = !inDouble
continue
}
if r == '\n' && !inSingle && !inDouble {
return &SecurityViolation{
Check: CheckNewlineInjection,
Message: "unquoted newline (potential command injection)",
}
}
}
return nil
}
// checkCmdSubstitution blocks $(), ``, and ${} command/variable substitution.
// These allow arbitrary code execution within a command.
func checkCmdSubstitution(cmd string) *SecurityViolation {
inSingle := false
escaped := false
for i, r := range cmd {
if escaped {
escaped = false
continue
}
if r == '\\' && !inSingle {
escaped = true
continue
}
if r == '\'' {
inSingle = !inSingle
continue
}
// Skip checks inside single quotes (literal)
if inSingle {
continue
}
if r == '`' {
return &SecurityViolation{
Check: CheckCmdSubstitution,
Message: "backtick command substitution",
}
}
if r == '$' && i+1 < len(cmd) {
next := rune(cmd[i+1])
if next == '(' {
return &SecurityViolation{
Check: CheckCmdSubstitution,
Message: "$() command substitution",
}
}
if next == '{' {
return &SecurityViolation{
Check: CheckCmdSubstitution,
Message: "${} variable expansion",
}
}
}
}
return nil
}
// checkDangerousVars blocks attempts to manipulate IFS or PATH.
func checkDangerousVars(cmd string) *SecurityViolation {
upper := strings.ToUpper(cmd)
dangerousPatterns := []struct {
pattern string
msg string
}{
{"IFS=", "IFS variable manipulation"},
{"PATH=", "PATH variable manipulation"},
}
for _, p := range dangerousPatterns {
idx := strings.Index(upper, p.pattern)
if idx == -1 {
continue
}
// Only flag if it's at the start or preceded by whitespace/semicolon
if idx == 0 || !unicode.IsLetter(rune(cmd[idx-1])) {
return &SecurityViolation{Check: CheckDangerousVars, Message: p.msg}
}
}
return nil
}

View File

@@ -0,0 +1,182 @@
package bash
import "testing"
func TestValidateCommand_Valid(t *testing.T) {
valid := []string{
"echo hello",
"ls -la",
"cat /etc/hostname",
"go test ./...",
"git status",
"echo 'hello world'",
`echo "hello world"`,
"grep -r 'pattern' .",
"find . -name '*.go'",
}
for _, cmd := range valid {
if v := ValidateCommand(cmd); v != nil {
t.Errorf("ValidateCommand(%q) = %v, want nil", cmd, v)
}
}
}
func TestValidateCommand_Empty(t *testing.T) {
v := ValidateCommand("")
if v == nil {
t.Fatal("expected violation for empty command")
}
if v.Check != CheckIncomplete {
t.Errorf("Check = %d, want %d (incomplete)", v.Check, CheckIncomplete)
}
}
func TestCheckIncomplete(t *testing.T) {
tests := []struct {
cmd string
want SecurityCheck
}{
{"\techo hello", CheckIncomplete}, // tab start
{"-flag value", CheckIncomplete}, // flag start
{"echo hello |", CheckIncomplete}, // trailing pipe
{"echo hello &", CheckIncomplete}, // trailing ampersand
{"echo hello ;", CheckIncomplete}, // trailing semicolon
}
for _, tt := range tests {
v := ValidateCommand(tt.cmd)
if v == nil {
t.Errorf("ValidateCommand(%q) = nil, want check %d", tt.cmd, tt.want)
continue
}
if v.Check != tt.want {
t.Errorf("ValidateCommand(%q).Check = %d, want %d", tt.cmd, v.Check, tt.want)
}
}
}
func TestCheckControlChars(t *testing.T) {
tests := []struct {
name string
cmd string
}{
{"null byte", "echo hello\x00world"},
{"bell", "echo \x07"},
{"backspace", "echo \x08"},
{"escape", "echo \x1b[31m"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
v := ValidateCommand(tt.cmd)
if v == nil {
t.Error("expected violation")
return
}
if v.Check != CheckControlChars {
t.Errorf("Check = %d, want %d (control chars)", v.Check, CheckControlChars)
}
})
}
}
func TestCheckControlChars_AllowedChars(t *testing.T) {
// Tabs and newlines inside quotes are allowed
valid := []string{
"echo 'hello\tworld'",
}
for _, cmd := range valid {
if v := checkControlChars(cmd); v != nil {
t.Errorf("checkControlChars(%q) = %v, want nil", cmd, v)
}
}
}
func TestCheckNewlineInjection(t *testing.T) {
// Unquoted newline
v := checkNewlineInjection("echo hello\nrm -rf /")
if v == nil {
t.Fatal("expected violation for unquoted newline")
}
if v.Check != CheckNewlineInjection {
t.Errorf("Check = %d, want %d", v.Check, CheckNewlineInjection)
}
}
func TestCheckNewlineInjection_QuotedOK(t *testing.T) {
// Newlines inside quotes are fine
allowed := []string{
"echo 'hello\nworld'",
`echo "hello` + "\n" + `world"`,
}
for _, cmd := range allowed {
if v := checkNewlineInjection(cmd); v != nil {
t.Errorf("checkNewlineInjection(%q) = %v, want nil", cmd, v)
}
}
}
func TestCheckCmdSubstitution(t *testing.T) {
tests := []struct {
name string
cmd string
}{
{"backtick", "echo `whoami`"},
{"dollar paren", "echo $(whoami)"},
{"dollar brace", "echo ${HOME}"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
v := ValidateCommand(tt.cmd)
if v == nil {
t.Error("expected violation")
return
}
if v.Check != CheckCmdSubstitution {
t.Errorf("Check = %d, want %d", v.Check, CheckCmdSubstitution)
}
})
}
}
func TestCheckCmdSubstitution_SingleQuoteOK(t *testing.T) {
// Inside single quotes, everything is literal
safe := "echo '$(whoami) and `uname` and ${HOME}'"
if v := checkCmdSubstitution(safe); v != nil {
t.Errorf("checkCmdSubstitution(%q) = %v, want nil (single-quoted)", safe, v)
}
}
func TestCheckDangerousVars(t *testing.T) {
tests := []struct {
name string
cmd string
}{
{"IFS at start", "IFS=: read a b"},
{"PATH manipulation", "PATH=/tmp:$PATH command"},
{"ifs with space prefix", " IFS=x echo hi"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
v := ValidateCommand(tt.cmd)
if v == nil {
t.Error("expected violation")
return
}
if v.Check != CheckDangerousVars {
t.Errorf("Check = %d, want %d", v.Check, CheckDangerousVars)
}
})
}
}
func TestCheckDangerousVars_SafeSubstrings(t *testing.T) {
// "SWIFT=..." should not trigger PATH check, "TARIFFS=..." should not trigger IFS
safe := []string{
"echo SWIFT=enabled",
"TARIFFS=high echo test",
}
for _, cmd := range safe {
if v := checkDangerousVars(cmd); v != nil {
t.Errorf("checkDangerousVars(%q) = %v, want nil", cmd, v)
}
}
}

109
internal/tool/fs/edit.go Normal file
View File

@@ -0,0 +1,109 @@
package fs
import (
"context"
"encoding/json"
"fmt"
"os"
"strings"
"somegit.dev/Owlibou/gnoma/internal/tool"
)
const editToolName = "fs.edit"
var editParams = json.RawMessage(`{
"type": "object",
"properties": {
"path": {
"type": "string",
"description": "Absolute path to the file to edit"
},
"old_string": {
"type": "string",
"description": "The exact text to find and replace"
},
"new_string": {
"type": "string",
"description": "The replacement text"
},
"replace_all": {
"type": "boolean",
"description": "Replace all occurrences (default false)"
}
},
"required": ["path", "old_string", "new_string"]
}`)
type EditTool struct{}
func NewEditTool() *EditTool { return &EditTool{} }
func (t *EditTool) Name() string { return editToolName }
func (t *EditTool) Description() string { return "Perform exact string replacement in a file" }
func (t *EditTool) Parameters() json.RawMessage { return editParams }
func (t *EditTool) IsReadOnly() bool { return false }
func (t *EditTool) IsDestructive() bool { return false }
type editArgs struct {
Path string `json:"path"`
OldString string `json:"old_string"`
NewString string `json:"new_string"`
ReplaceAll bool `json:"replace_all,omitempty"`
}
func (t *EditTool) Execute(_ context.Context, args json.RawMessage) (tool.Result, error) {
var a editArgs
if err := json.Unmarshal(args, &a); err != nil {
return tool.Result{}, fmt.Errorf("fs.edit: invalid args: %w", err)
}
if a.Path == "" {
return tool.Result{}, fmt.Errorf("fs.edit: path required")
}
if a.OldString == a.NewString {
return tool.Result{}, fmt.Errorf("fs.edit: old_string and new_string must differ")
}
data, err := os.ReadFile(a.Path)
if err != nil {
return tool.Result{Output: fmt.Sprintf("Error: %v", err)}, nil
}
content := string(data)
count := strings.Count(content, a.OldString)
if count == 0 {
return tool.Result{
Output: "Error: old_string not found in file",
Metadata: map[string]any{"matches": 0},
}, nil
}
if !a.ReplaceAll && count > 1 {
return tool.Result{
Output: fmt.Sprintf("Error: old_string has %d matches (must be unique, or use replace_all)", count),
Metadata: map[string]any{"matches": count},
}, nil
}
var newContent string
if a.ReplaceAll {
newContent = strings.ReplaceAll(content, a.OldString, a.NewString)
} else {
newContent = strings.Replace(content, a.OldString, a.NewString, 1)
}
if err := os.WriteFile(a.Path, []byte(newContent), 0o644); err != nil {
return tool.Result{Output: fmt.Sprintf("Error writing file: %v", err)}, nil
}
replacements := 1
if a.ReplaceAll {
replacements = count
}
return tool.Result{
Output: fmt.Sprintf("Replaced %d occurrence(s) in %s", replacements, a.Path),
Metadata: map[string]any{"replacements": replacements, "path": a.Path},
}, nil
}

545
internal/tool/fs/fs_test.go Normal file
View File

@@ -0,0 +1,545 @@
package fs
import (
"context"
"encoding/json"
"os"
"path/filepath"
"strings"
"testing"
)
// --- Read ---
func TestReadTool_Interface(t *testing.T) {
r := NewReadTool()
if r.Name() != "fs.read" {
t.Errorf("Name() = %q", r.Name())
}
if !r.IsReadOnly() {
t.Error("should be read-only")
}
if r.IsDestructive() {
t.Error("should not be destructive")
}
}
func TestReadTool_SimpleFile(t *testing.T) {
path := writeTestFile(t, "hello\nworld\n")
r := NewReadTool()
result, err := r.Execute(context.Background(), mustJSON(t, readArgs{Path: path}))
if err != nil {
t.Fatalf("Execute: %v", err)
}
if !strings.Contains(result.Output, "1\thello") {
t.Errorf("Output should contain line-numbered content, got %q", result.Output)
}
if !strings.Contains(result.Output, "2\tworld") {
t.Errorf("Output missing line 2, got %q", result.Output)
}
}
func TestReadTool_WithOffset(t *testing.T) {
path := writeTestFile(t, "line1\nline2\nline3\nline4\nline5\n")
r := NewReadTool()
result, err := r.Execute(context.Background(), mustJSON(t, readArgs{Path: path, Offset: 2}))
if err != nil {
t.Fatalf("Execute: %v", err)
}
if !strings.Contains(result.Output, "3\tline3") {
t.Errorf("Output should start at line 3, got %q", result.Output)
}
if strings.Contains(result.Output, "1\tline1") {
t.Error("Output should not contain line 1")
}
}
func TestReadTool_WithLimit(t *testing.T) {
path := writeTestFile(t, "a\nb\nc\nd\ne\n")
r := NewReadTool()
result, err := r.Execute(context.Background(), mustJSON(t, readArgs{Path: path, Limit: 2}))
if err != nil {
t.Fatalf("Execute: %v", err)
}
lines := strings.Split(result.Output, "\n")
if len(lines) != 2 {
t.Errorf("expected 2 lines, got %d: %q", len(lines), result.Output)
}
if result.Metadata["truncated"] != true {
t.Error("should be truncated")
}
}
func TestReadTool_OffsetPastEnd(t *testing.T) {
path := writeTestFile(t, "one\ntwo\n")
r := NewReadTool()
result, err := r.Execute(context.Background(), mustJSON(t, readArgs{Path: path, Offset: 100}))
if err != nil {
t.Fatalf("Execute: %v", err)
}
if !strings.Contains(result.Output, "past end") {
t.Errorf("Output = %q, should mention past end", result.Output)
}
}
func TestReadTool_FileNotFound(t *testing.T) {
r := NewReadTool()
result, err := r.Execute(context.Background(), mustJSON(t, readArgs{Path: "/nonexistent/file.txt"}))
if err != nil {
t.Fatalf("Execute: %v", err)
}
if !strings.Contains(result.Output, "Error") {
t.Errorf("Output = %q, should contain error", result.Output)
}
}
func TestReadTool_EmptyPath(t *testing.T) {
r := NewReadTool()
_, err := r.Execute(context.Background(), mustJSON(t, readArgs{}))
if err == nil {
t.Error("expected error for empty path")
}
}
// --- Write ---
func TestWriteTool_Interface(t *testing.T) {
w := NewWriteTool()
if w.Name() != "fs.write" {
t.Errorf("Name() = %q", w.Name())
}
if w.IsReadOnly() {
t.Error("should not be read-only")
}
}
func TestWriteTool_CreateFile(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "test.txt")
w := NewWriteTool()
result, err := w.Execute(context.Background(), mustJSON(t, writeArgs{Path: path, Content: "hello world"}))
if err != nil {
t.Fatalf("Execute: %v", err)
}
if !strings.Contains(result.Output, "11 bytes") {
t.Errorf("Output = %q", result.Output)
}
data, _ := os.ReadFile(path)
if string(data) != "hello world" {
t.Errorf("file content = %q", string(data))
}
}
func TestWriteTool_CreatesParentDirs(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "a", "b", "c", "test.txt")
w := NewWriteTool()
_, err := w.Execute(context.Background(), mustJSON(t, writeArgs{Path: path, Content: "nested"}))
if err != nil {
t.Fatalf("Execute: %v", err)
}
data, _ := os.ReadFile(path)
if string(data) != "nested" {
t.Errorf("file content = %q", string(data))
}
}
func TestWriteTool_OverwriteExisting(t *testing.T) {
path := writeTestFile(t, "old content")
w := NewWriteTool()
_, err := w.Execute(context.Background(), mustJSON(t, writeArgs{Path: path, Content: "new content"}))
if err != nil {
t.Fatalf("Execute: %v", err)
}
data, _ := os.ReadFile(path)
if string(data) != "new content" {
t.Errorf("file content = %q", string(data))
}
}
// --- Edit ---
func TestEditTool_Interface(t *testing.T) {
e := NewEditTool()
if e.Name() != "fs.edit" {
t.Errorf("Name() = %q", e.Name())
}
}
func TestEditTool_SingleReplace(t *testing.T) {
path := writeTestFile(t, "hello world")
e := NewEditTool()
result, err := e.Execute(context.Background(), mustJSON(t, editArgs{
Path: path, OldString: "world", NewString: "gnoma",
}))
if err != nil {
t.Fatalf("Execute: %v", err)
}
if !strings.Contains(result.Output, "1 occurrence") {
t.Errorf("Output = %q", result.Output)
}
data, _ := os.ReadFile(path)
if string(data) != "hello gnoma" {
t.Errorf("file content = %q", string(data))
}
}
func TestEditTool_ReplaceAll(t *testing.T) {
path := writeTestFile(t, "foo bar foo baz foo")
e := NewEditTool()
result, err := e.Execute(context.Background(), mustJSON(t, editArgs{
Path: path, OldString: "foo", NewString: "qux", ReplaceAll: true,
}))
if err != nil {
t.Fatalf("Execute: %v", err)
}
if !strings.Contains(result.Output, "3 occurrence") {
t.Errorf("Output = %q", result.Output)
}
data, _ := os.ReadFile(path)
if string(data) != "qux bar qux baz qux" {
t.Errorf("file content = %q", string(data))
}
}
func TestEditTool_NonUniqueWithoutReplaceAll(t *testing.T) {
path := writeTestFile(t, "foo foo foo")
e := NewEditTool()
result, err := e.Execute(context.Background(), mustJSON(t, editArgs{
Path: path, OldString: "foo", NewString: "bar",
}))
if err != nil {
t.Fatalf("Execute: %v", err)
}
if !strings.Contains(result.Output, "3 matches") {
t.Errorf("Output = %q, should mention multiple matches", result.Output)
}
// File should be unchanged
data, _ := os.ReadFile(path)
if string(data) != "foo foo foo" {
t.Errorf("file should be unchanged, got %q", string(data))
}
}
func TestEditTool_NotFound(t *testing.T) {
path := writeTestFile(t, "hello world")
e := NewEditTool()
result, err := e.Execute(context.Background(), mustJSON(t, editArgs{
Path: path, OldString: "missing", NewString: "replaced",
}))
if err != nil {
t.Fatalf("Execute: %v", err)
}
if !strings.Contains(result.Output, "not found") {
t.Errorf("Output = %q, should mention not found", result.Output)
}
}
func TestEditTool_SameStrings(t *testing.T) {
e := NewEditTool()
_, err := e.Execute(context.Background(), mustJSON(t, editArgs{
Path: "/tmp/x", OldString: "same", NewString: "same",
}))
if err == nil {
t.Error("expected error when old_string == new_string")
}
}
// --- Glob ---
func TestGlobTool_Interface(t *testing.T) {
g := NewGlobTool()
if g.Name() != "fs.glob" {
t.Errorf("Name() = %q", g.Name())
}
if !g.IsReadOnly() {
t.Error("should be read-only")
}
}
func TestGlobTool_MatchFiles(t *testing.T) {
dir := t.TempDir()
os.WriteFile(filepath.Join(dir, "main.go"), []byte("package main"), 0o644)
os.WriteFile(filepath.Join(dir, "test.go"), []byte("package main"), 0o644)
os.WriteFile(filepath.Join(dir, "readme.md"), []byte("# readme"), 0o644)
g := NewGlobTool()
result, err := g.Execute(context.Background(), mustJSON(t, globArgs{Pattern: "*.go", Path: dir}))
if err != nil {
t.Fatalf("Execute: %v", err)
}
if result.Metadata["count"] != 2 {
t.Errorf("count = %v, want 2", result.Metadata["count"])
}
if !strings.Contains(result.Output, "main.go") {
t.Errorf("Output missing main.go: %q", result.Output)
}
if strings.Contains(result.Output, "readme.md") {
t.Error("Output should not contain readme.md")
}
}
func TestGlobTool_NoMatches(t *testing.T) {
dir := t.TempDir()
g := NewGlobTool()
result, err := g.Execute(context.Background(), mustJSON(t, globArgs{Pattern: "*.xyz", Path: dir}))
if err != nil {
t.Fatalf("Execute: %v", err)
}
if !strings.Contains(result.Output, "no matches") {
t.Errorf("Output = %q", result.Output)
}
}
// --- Grep ---
func TestGrepTool_Interface(t *testing.T) {
g := NewGrepTool()
if g.Name() != "fs.grep" {
t.Errorf("Name() = %q", g.Name())
}
if !g.IsReadOnly() {
t.Error("should be read-only")
}
}
func TestGrepTool_SingleFile(t *testing.T) {
path := writeTestFile(t, "hello world\nfoo bar\nhello again\n")
g := NewGrepTool()
result, err := g.Execute(context.Background(), mustJSON(t, grepArgs{Pattern: "hello", Path: path}))
if err != nil {
t.Fatalf("Execute: %v", err)
}
if result.Metadata["count"] != 2 {
t.Errorf("count = %v, want 2", result.Metadata["count"])
}
if !strings.Contains(result.Output, "1:hello world") {
t.Errorf("Output = %q", result.Output)
}
}
func TestGrepTool_Directory(t *testing.T) {
dir := t.TempDir()
os.WriteFile(filepath.Join(dir, "a.go"), []byte("func main() {}\nfunc helper() {}"), 0o644)
os.WriteFile(filepath.Join(dir, "b.go"), []byte("func test() {}"), 0o644)
os.WriteFile(filepath.Join(dir, "c.txt"), []byte("func ignored() {}"), 0o644)
g := NewGrepTool()
// Search all files for "func"
result, err := g.Execute(context.Background(), mustJSON(t, grepArgs{Pattern: "func", Path: dir}))
if err != nil {
t.Fatalf("Execute: %v", err)
}
if result.Metadata["count"].(int) < 3 {
t.Errorf("count = %v, want >= 3", result.Metadata["count"])
}
// With glob filter
result, err = g.Execute(context.Background(), mustJSON(t, grepArgs{Pattern: "func", Path: dir, Glob: "*.go"}))
if err != nil {
t.Fatalf("Execute: %v", err)
}
if strings.Contains(result.Output, "c.txt") {
t.Error("should not match .txt files with *.go glob")
}
}
func TestGrepTool_Regex(t *testing.T) {
path := writeTestFile(t, "error: something failed\nwarning: be careful\nerror: another one\n")
g := NewGrepTool()
result, err := g.Execute(context.Background(), mustJSON(t, grepArgs{Pattern: `^error:`, Path: path}))
if err != nil {
t.Fatalf("Execute: %v", err)
}
if result.Metadata["count"] != 2 {
t.Errorf("count = %v, want 2", result.Metadata["count"])
}
}
func TestGrepTool_InvalidRegex(t *testing.T) {
g := NewGrepTool()
result, err := g.Execute(context.Background(), mustJSON(t, grepArgs{Pattern: "[invalid", Path: "."}))
if err != nil {
t.Fatalf("Execute: %v", err)
}
if !strings.Contains(result.Output, "Invalid regex") {
t.Errorf("Output = %q, should mention invalid regex", result.Output)
}
}
func TestGrepTool_NoMatches(t *testing.T) {
path := writeTestFile(t, "hello world\n")
g := NewGrepTool()
result, err := g.Execute(context.Background(), mustJSON(t, grepArgs{Pattern: "zzzzz", Path: path}))
if err != nil {
t.Fatalf("Execute: %v", err)
}
if !strings.Contains(result.Output, "no matches") {
t.Errorf("Output = %q", result.Output)
}
}
func TestGrepTool_MaxResults(t *testing.T) {
var lines strings.Builder
for i := 0; i < 100; i++ {
lines.WriteString("match line\n")
}
path := writeTestFile(t, lines.String())
g := NewGrepTool()
result, err := g.Execute(context.Background(), mustJSON(t, grepArgs{Pattern: "match", Path: path, MaxResults: 5}))
if err != nil {
t.Fatalf("Execute: %v", err)
}
if result.Metadata["count"] != 5 {
t.Errorf("count = %v, want 5", result.Metadata["count"])
}
if result.Metadata["truncated"] != true {
t.Error("should be truncated")
}
}
// --- LS ---
func TestLSTool_Interface(t *testing.T) {
l := NewLSTool()
if l.Name() != "fs.ls" {
t.Errorf("Name() = %q", l.Name())
}
if !l.IsReadOnly() {
t.Error("should be read-only")
}
}
func TestLSTool_ListDirectory(t *testing.T) {
dir := t.TempDir()
os.WriteFile(filepath.Join(dir, "hello.go"), []byte("package main"), 0o644)
os.WriteFile(filepath.Join(dir, "readme.md"), []byte("# readme"), 0o644)
os.MkdirAll(filepath.Join(dir, "subdir"), 0o755)
l := NewLSTool()
result, err := l.Execute(context.Background(), mustJSON(t, lsArgs{Path: dir}))
if err != nil {
t.Fatalf("Execute: %v", err)
}
if !strings.Contains(result.Output, "hello.go") {
t.Errorf("Output missing hello.go: %q", result.Output)
}
if !strings.Contains(result.Output, "readme.md") {
t.Errorf("Output missing readme.md: %q", result.Output)
}
if !strings.Contains(result.Output, "subdir") {
t.Errorf("Output missing subdir: %q", result.Output)
}
if result.Metadata["files"] != 2 {
t.Errorf("files = %v, want 2", result.Metadata["files"])
}
if result.Metadata["dirs"] != 1 {
t.Errorf("dirs = %v, want 1", result.Metadata["dirs"])
}
}
func TestLSTool_EmptyDirectory(t *testing.T) {
dir := t.TempDir()
l := NewLSTool()
result, err := l.Execute(context.Background(), mustJSON(t, lsArgs{Path: dir}))
if err != nil {
t.Fatalf("Execute: %v", err)
}
if !strings.Contains(result.Output, "empty directory") {
t.Errorf("Output = %q, should mention empty", result.Output)
}
}
func TestLSTool_DirectoryNotFound(t *testing.T) {
l := NewLSTool()
result, err := l.Execute(context.Background(), mustJSON(t, lsArgs{Path: "/nonexistent/dir"}))
if err != nil {
t.Fatalf("Execute: %v", err)
}
if !strings.Contains(result.Output, "Error") {
t.Errorf("Output = %q, should contain error", result.Output)
}
}
func TestLSTool_ShowsSizes(t *testing.T) {
dir := t.TempDir()
os.WriteFile(filepath.Join(dir, "small.txt"), []byte("hi"), 0o644)
l := NewLSTool()
result, err := l.Execute(context.Background(), mustJSON(t, lsArgs{Path: dir}))
if err != nil {
t.Fatalf("Execute: %v", err)
}
// Should show "2B" for a 2-byte file
if !strings.Contains(result.Output, "2B") {
t.Errorf("Output = %q, should show file size", result.Output)
}
}
func TestFormatSize(t *testing.T) {
tests := []struct {
bytes int64
want string
}{
{0, "0B"},
{42, "42B"},
{1024, "1.0K"},
{1536, "1.5K"},
{1048576, "1.0M"},
{1073741824, "1.0G"},
}
for _, tt := range tests {
got := formatSize(tt.bytes)
if got != tt.want {
t.Errorf("formatSize(%d) = %q, want %q", tt.bytes, got, tt.want)
}
}
}
// --- Helpers ---
func writeTestFile(t *testing.T, content string) string {
t.Helper()
dir := t.TempDir()
path := filepath.Join(dir, "test.txt")
if err := os.WriteFile(path, []byte(content), 0o644); err != nil {
t.Fatalf("writeTestFile: %v", err)
}
return path
}
func mustJSON(t *testing.T, v any) json.RawMessage {
t.Helper()
data, err := json.Marshal(v)
if err != nil {
t.Fatalf("json.Marshal: %v", err)
}
return data
}

117
internal/tool/fs/glob.go Normal file
View File

@@ -0,0 +1,117 @@
package fs
import (
"context"
"encoding/json"
"fmt"
"os"
"path/filepath"
"sort"
"strings"
"somegit.dev/Owlibou/gnoma/internal/tool"
)
const globToolName = "fs.glob"
var globParams = json.RawMessage(`{
"type": "object",
"properties": {
"pattern": {
"type": "string",
"description": "Glob pattern to match files (e.g. **/*.go, src/**/*.ts)"
},
"path": {
"type": "string",
"description": "Directory to search in (defaults to current directory)"
}
},
"required": ["pattern"]
}`)
type GlobTool struct{}
func NewGlobTool() *GlobTool { return &GlobTool{} }
func (t *GlobTool) Name() string { return globToolName }
func (t *GlobTool) Description() string { return "Find files matching a glob pattern, sorted by modification time" }
func (t *GlobTool) Parameters() json.RawMessage { return globParams }
func (t *GlobTool) IsReadOnly() bool { return true }
func (t *GlobTool) IsDestructive() bool { return false }
type globArgs struct {
Pattern string `json:"pattern"`
Path string `json:"path,omitempty"`
}
func (t *GlobTool) Execute(_ context.Context, args json.RawMessage) (tool.Result, error) {
var a globArgs
if err := json.Unmarshal(args, &a); err != nil {
return tool.Result{}, fmt.Errorf("fs.glob: invalid args: %w", err)
}
if a.Pattern == "" {
return tool.Result{}, fmt.Errorf("fs.glob: pattern required")
}
root := a.Path
if root == "" {
var err error
root, err = os.Getwd()
if err != nil {
return tool.Result{}, fmt.Errorf("fs.glob: %w", err)
}
}
var matches []string
err := filepath.WalkDir(root, func(path string, d os.DirEntry, err error) error {
if err != nil {
return nil // skip inaccessible entries
}
if d.IsDir() {
// Skip hidden directories
if d.Name() != "." && strings.HasPrefix(d.Name(), ".") {
return filepath.SkipDir
}
return nil
}
rel, err := filepath.Rel(root, path)
if err != nil {
return nil
}
matched, err := filepath.Match(a.Pattern, rel)
if err != nil {
// Try matching just the filename for simple patterns
matched, _ = filepath.Match(a.Pattern, d.Name())
}
if matched {
matches = append(matches, rel)
}
return nil
})
if err != nil {
return tool.Result{Output: fmt.Sprintf("Error walking directory: %v", err)}, nil
}
// Sort by modification time (most recent first)
sort.Slice(matches, func(i, j int) bool {
iInfo, _ := os.Stat(filepath.Join(root, matches[i]))
jInfo, _ := os.Stat(filepath.Join(root, matches[j]))
if iInfo == nil || jInfo == nil {
return matches[i] < matches[j]
}
return iInfo.ModTime().After(jInfo.ModTime())
})
output := strings.Join(matches, "\n")
if output == "" {
output = "(no matches)"
}
return tool.Result{
Output: output,
Metadata: map[string]any{"count": len(matches), "pattern": a.Pattern},
}, nil
}

184
internal/tool/fs/grep.go Normal file
View File

@@ -0,0 +1,184 @@
package fs
import (
"bufio"
"context"
"encoding/json"
"fmt"
"os"
"path/filepath"
"regexp"
"strings"
"somegit.dev/Owlibou/gnoma/internal/tool"
)
const (
grepToolName = "fs.grep"
defaultMaxResults = 250
)
var grepParams = json.RawMessage(`{
"type": "object",
"properties": {
"pattern": {
"type": "string",
"description": "Regular expression pattern to search for"
},
"path": {
"type": "string",
"description": "File or directory to search in (defaults to current directory)"
},
"glob": {
"type": "string",
"description": "File glob filter (e.g. *.go, *.ts)"
},
"max_results": {
"type": "integer",
"description": "Maximum number of matching lines to return (default 250)"
}
},
"required": ["pattern"]
}`)
type GrepTool struct{}
func NewGrepTool() *GrepTool { return &GrepTool{} }
func (t *GrepTool) Name() string { return grepToolName }
func (t *GrepTool) Description() string { return "Search file contents using a regular expression" }
func (t *GrepTool) Parameters() json.RawMessage { return grepParams }
func (t *GrepTool) IsReadOnly() bool { return true }
func (t *GrepTool) IsDestructive() bool { return false }
type grepArgs struct {
Pattern string `json:"pattern"`
Path string `json:"path,omitempty"`
Glob string `json:"glob,omitempty"`
MaxResults int `json:"max_results,omitempty"`
}
type grepMatch struct {
File string
Line int
Text string
}
func (t *GrepTool) Execute(_ context.Context, args json.RawMessage) (tool.Result, error) {
var a grepArgs
if err := json.Unmarshal(args, &a); err != nil {
return tool.Result{}, fmt.Errorf("fs.grep: invalid args: %w", err)
}
if a.Pattern == "" {
return tool.Result{}, fmt.Errorf("fs.grep: pattern required")
}
re, err := regexp.Compile(a.Pattern)
if err != nil {
return tool.Result{Output: fmt.Sprintf("Invalid regex: %v", err)}, nil
}
maxResults := a.MaxResults
if maxResults <= 0 {
maxResults = defaultMaxResults
}
root := a.Path
if root == "" {
root, err = os.Getwd()
if err != nil {
return tool.Result{}, fmt.Errorf("fs.grep: %w", err)
}
}
info, err := os.Stat(root)
if err != nil {
return tool.Result{Output: fmt.Sprintf("Error: %v", err)}, nil
}
var matches []grepMatch
if !info.IsDir() {
matches = grepFile(root, "", re, maxResults)
} else {
filepath.WalkDir(root, func(path string, d os.DirEntry, err error) error {
if err != nil || d.IsDir() {
if d != nil && d.IsDir() && d.Name() != "." && strings.HasPrefix(d.Name(), ".") {
return filepath.SkipDir
}
return nil
}
// Apply glob filter
if a.Glob != "" {
matched, _ := filepath.Match(a.Glob, d.Name())
if !matched {
return nil
}
}
rel, _ := filepath.Rel(root, path)
fileMatches := grepFile(path, rel, re, maxResults-len(matches))
matches = append(matches, fileMatches...)
if len(matches) >= maxResults {
return filepath.SkipAll
}
return nil
})
}
if len(matches) == 0 {
return tool.Result{
Output: "(no matches)",
Metadata: map[string]any{"count": 0},
}, nil
}
var b strings.Builder
for _, m := range matches {
if m.File != "" {
fmt.Fprintf(&b, "%s:%d:%s\n", m.File, m.Line, m.Text)
} else {
fmt.Fprintf(&b, "%d:%s\n", m.Line, m.Text)
}
}
truncated := len(matches) >= maxResults
return tool.Result{
Output: strings.TrimRight(b.String(), "\n"),
Metadata: map[string]any{
"count": len(matches),
"truncated": truncated,
},
}, nil
}
func grepFile(path, displayPath string, re *regexp.Regexp, limit int) []grepMatch {
f, err := os.Open(path)
if err != nil {
return nil
}
defer f.Close()
var matches []grepMatch
scanner := bufio.NewScanner(f)
lineNum := 0
for scanner.Scan() {
lineNum++
line := scanner.Text()
if re.MatchString(line) {
matches = append(matches, grepMatch{
File: displayPath,
Line: lineNum,
Text: line,
})
if len(matches) >= limit {
break
}
}
}
return matches
}

123
internal/tool/fs/ls.go Normal file
View File

@@ -0,0 +1,123 @@
package fs
import (
"context"
"encoding/json"
"fmt"
"io/fs"
"os"
"path/filepath"
"strings"
"somegit.dev/Owlibou/gnoma/internal/tool"
)
const lsToolName = "fs.ls"
var lsParams = json.RawMessage(`{
"type": "object",
"properties": {
"path": {
"type": "string",
"description": "Directory path to list (defaults to current directory)"
}
}
}`)
type LSTool struct{}
func NewLSTool() *LSTool { return &LSTool{} }
func (t *LSTool) Name() string { return lsToolName }
func (t *LSTool) Description() string { return "List directory contents with file types and sizes" }
func (t *LSTool) Parameters() json.RawMessage { return lsParams }
func (t *LSTool) IsReadOnly() bool { return true }
func (t *LSTool) IsDestructive() bool { return false }
type lsArgs struct {
Path string `json:"path,omitempty"`
}
func (t *LSTool) Execute(_ context.Context, args json.RawMessage) (tool.Result, error) {
var a lsArgs
if err := json.Unmarshal(args, &a); err != nil {
return tool.Result{}, fmt.Errorf("fs.ls: invalid args: %w", err)
}
dir := a.Path
if dir == "" {
var err error
dir, err = os.Getwd()
if err != nil {
return tool.Result{}, fmt.Errorf("fs.ls: %w", err)
}
}
entries, err := os.ReadDir(dir)
if err != nil {
return tool.Result{Output: fmt.Sprintf("Error: %v", err)}, nil
}
var b strings.Builder
dirCount, fileCount := 0, 0
for _, entry := range entries {
info, err := entry.Info()
if err != nil {
continue
}
prefix := " "
if entry.IsDir() {
prefix = "d"
dirCount++
} else {
fileCount++
}
size := formatSize(info.Size())
if entry.IsDir() {
size = "-"
}
// Check for symlink
if entry.Type()&fs.ModeSymlink != 0 {
prefix = "l"
target, err := os.Readlink(filepath.Join(dir, entry.Name()))
if err == nil {
fmt.Fprintf(&b, "%s %8s %s -> %s\n", prefix, size, entry.Name(), target)
continue
}
}
fmt.Fprintf(&b, "%s %8s %s\n", prefix, size, entry.Name())
}
output := strings.TrimRight(b.String(), "\n")
if output == "" {
output = "(empty directory)"
}
return tool.Result{
Output: output,
Metadata: map[string]any{
"directory": dir,
"files": fileCount,
"dirs": dirCount,
"total": fileCount + dirCount,
},
}, nil
}
func formatSize(bytes int64) string {
switch {
case bytes >= 1<<30:
return fmt.Sprintf("%.1fG", float64(bytes)/(1<<30))
case bytes >= 1<<20:
return fmt.Sprintf("%.1fM", float64(bytes)/(1<<20))
case bytes >= 1<<10:
return fmt.Sprintf("%.1fK", float64(bytes)/(1<<10))
default:
return fmt.Sprintf("%dB", bytes)
}
}

123
internal/tool/fs/read.go Normal file
View File

@@ -0,0 +1,123 @@
package fs
import (
"context"
"encoding/json"
"fmt"
"os"
"strings"
"somegit.dev/Owlibou/gnoma/internal/tool"
)
const (
readToolName = "fs.read"
defaultMaxLines = 2000
)
var readParams = json.RawMessage(`{
"type": "object",
"properties": {
"path": {
"type": "string",
"description": "Absolute path to the file to read"
},
"offset": {
"type": "integer",
"description": "Line number to start reading from (0-based)"
},
"limit": {
"type": "integer",
"description": "Maximum number of lines to read"
}
},
"required": ["path"]
}`)
type ReadTool struct {
maxLines int
}
type ReadOption func(*ReadTool)
func WithMaxLines(n int) ReadOption {
return func(t *ReadTool) { t.maxLines = n }
}
func NewReadTool(opts ...ReadOption) *ReadTool {
t := &ReadTool{maxLines: defaultMaxLines}
for _, opt := range opts {
opt(t)
}
return t
}
func (t *ReadTool) Name() string { return readToolName }
func (t *ReadTool) Description() string { return "Read a file from the filesystem with optional offset and line limit" }
func (t *ReadTool) Parameters() json.RawMessage { return readParams }
func (t *ReadTool) IsReadOnly() bool { return true }
func (t *ReadTool) IsDestructive() bool { return false }
type readArgs struct {
Path string `json:"path"`
Offset int `json:"offset,omitempty"`
Limit int `json:"limit,omitempty"`
}
func (t *ReadTool) Execute(_ context.Context, args json.RawMessage) (tool.Result, error) {
var a readArgs
if err := json.Unmarshal(args, &a); err != nil {
return tool.Result{}, fmt.Errorf("fs.read: invalid args: %w", err)
}
if a.Path == "" {
return tool.Result{}, fmt.Errorf("fs.read: path required")
}
data, err := os.ReadFile(a.Path)
if err != nil {
return tool.Result{Output: fmt.Sprintf("Error: %v", err)}, nil
}
lines := strings.Split(string(data), "\n")
totalLines := len(lines)
// Apply offset
offset := a.Offset
if offset < 0 {
offset = 0
}
if offset >= totalLines {
return tool.Result{
Output: fmt.Sprintf("(file has %d lines, offset %d is past end)", totalLines, offset),
Metadata: map[string]any{"total_lines": totalLines},
}, nil
}
lines = lines[offset:]
// Apply limit
limit := a.Limit
if limit <= 0 {
limit = t.maxLines
}
truncated := false
if len(lines) > limit {
lines = lines[:limit]
truncated = true
}
// Format with line numbers (1-based, matching cat -n)
var b strings.Builder
for i, line := range lines {
fmt.Fprintf(&b, "%d\t%s\n", offset+i+1, line)
}
output := strings.TrimRight(b.String(), "\n")
meta := map[string]any{"total_lines": totalLines}
if truncated {
meta["truncated"] = true
meta["showing"] = fmt.Sprintf("lines %d-%d of %d", offset+1, offset+len(lines), totalLines)
}
return tool.Result{Output: output, Metadata: meta}, nil
}

68
internal/tool/fs/write.go Normal file
View File

@@ -0,0 +1,68 @@
package fs
import (
"context"
"encoding/json"
"fmt"
"os"
"path/filepath"
"somegit.dev/Owlibou/gnoma/internal/tool"
)
const writeToolName = "fs.write"
var writeParams = json.RawMessage(`{
"type": "object",
"properties": {
"path": {
"type": "string",
"description": "Absolute path to the file to write"
},
"content": {
"type": "string",
"description": "Content to write to the file"
}
},
"required": ["path", "content"]
}`)
type WriteTool struct{}
func NewWriteTool() *WriteTool { return &WriteTool{} }
func (t *WriteTool) Name() string { return writeToolName }
func (t *WriteTool) Description() string { return "Write content to a file, creating parent directories as needed" }
func (t *WriteTool) Parameters() json.RawMessage { return writeParams }
func (t *WriteTool) IsReadOnly() bool { return false }
func (t *WriteTool) IsDestructive() bool { return false }
type writeArgs struct {
Path string `json:"path"`
Content string `json:"content"`
}
func (t *WriteTool) Execute(_ context.Context, args json.RawMessage) (tool.Result, error) {
var a writeArgs
if err := json.Unmarshal(args, &a); err != nil {
return tool.Result{}, fmt.Errorf("fs.write: invalid args: %w", err)
}
if a.Path == "" {
return tool.Result{}, fmt.Errorf("fs.write: path required")
}
// Create parent directories
dir := filepath.Dir(a.Path)
if err := os.MkdirAll(dir, 0o755); err != nil {
return tool.Result{Output: fmt.Sprintf("Error creating directory: %v", err)}, nil
}
if err := os.WriteFile(a.Path, []byte(a.Content), 0o644); err != nil {
return tool.Result{Output: fmt.Sprintf("Error writing file: %v", err)}, nil
}
return tool.Result{
Output: fmt.Sprintf("Wrote %d bytes to %s", len(a.Content), a.Path),
Metadata: map[string]any{"bytes_written": len(a.Content), "path": a.Path},
}, nil
}

77
internal/tool/registry.go Normal file
View File

@@ -0,0 +1,77 @@
package tool
import (
"encoding/json"
"fmt"
"sync"
)
// Definition is the provider-agnostic tool schema sent to the LLM.
type Definition struct {
Name string `json:"name"`
Description string `json:"description"`
Parameters json.RawMessage `json:"parameters"`
}
// Registry holds all available tools.
type Registry struct {
mu sync.RWMutex
tools map[string]Tool
}
func NewRegistry() *Registry {
return &Registry{
tools: make(map[string]Tool),
}
}
// Register adds a tool. Overwrites if name already exists.
func (r *Registry) Register(t Tool) {
r.mu.Lock()
defer r.mu.Unlock()
r.tools[t.Name()] = t
}
// Get returns a tool by name.
func (r *Registry) Get(name string) (Tool, bool) {
r.mu.RLock()
defer r.mu.RUnlock()
t, ok := r.tools[name]
return t, ok
}
// All returns all registered tools.
func (r *Registry) All() []Tool {
r.mu.RLock()
defer r.mu.RUnlock()
all := make([]Tool, 0, len(r.tools))
for _, t := range r.tools {
all = append(all, t)
}
return all
}
// Definitions returns tool definitions for all registered tools,
// suitable for sending to the LLM.
func (r *Registry) Definitions() []Definition {
r.mu.RLock()
defer r.mu.RUnlock()
defs := make([]Definition, 0, len(r.tools))
for _, t := range r.tools {
defs = append(defs, Definition{
Name: t.Name(),
Description: t.Description(),
Parameters: t.Parameters(),
})
}
return defs
}
// MustGet returns a tool by name or panics. For use in tests.
func (r *Registry) MustGet(name string) Tool {
t, ok := r.Get(name)
if !ok {
panic(fmt.Sprintf("tool not found: %q", name))
}
return t
}

View File

@@ -0,0 +1,208 @@
package tool
import (
"context"
"encoding/json"
"slices"
"sort"
"testing"
)
// stubTool is a minimal Tool implementation for testing.
type stubTool struct {
name string
description string
params json.RawMessage
readOnly bool
destructive bool
execFn func(ctx context.Context, args json.RawMessage) (Result, error)
}
func (s *stubTool) Name() string { return s.name }
func (s *stubTool) Description() string { return s.description }
func (s *stubTool) Parameters() json.RawMessage { return s.params }
func (s *stubTool) IsReadOnly() bool { return s.readOnly }
func (s *stubTool) IsDestructive() bool { return s.destructive }
func (s *stubTool) Execute(ctx context.Context, args json.RawMessage) (Result, error) {
if s.execFn != nil {
return s.execFn(ctx, args)
}
return Result{Output: "ok"}, nil
}
func TestRegistry_RegisterAndGet(t *testing.T) {
r := NewRegistry()
r.Register(&stubTool{name: "bash", description: "run commands"})
tool, ok := r.Get("bash")
if !ok {
t.Fatal("Get(bash) should find tool")
}
if tool.Name() != "bash" {
t.Errorf("Name() = %q", tool.Name())
}
}
func TestRegistry_Get_NotFound(t *testing.T) {
r := NewRegistry()
_, ok := r.Get("nonexistent")
if ok {
t.Error("Get(nonexistent) should return false")
}
}
func TestRegistry_Register_Overwrite(t *testing.T) {
r := NewRegistry()
r.Register(&stubTool{name: "bash", description: "old"})
r.Register(&stubTool{name: "bash", description: "new"})
tool, _ := r.Get("bash")
if tool.Description() != "new" {
t.Errorf("Description() = %q, want 'new' (overwritten)", tool.Description())
}
}
func TestRegistry_All(t *testing.T) {
r := NewRegistry()
r.Register(&stubTool{name: "bash"})
r.Register(&stubTool{name: "fs.read"})
r.Register(&stubTool{name: "fs.write"})
all := r.All()
if len(all) != 3 {
t.Fatalf("len(All()) = %d, want 3", len(all))
}
names := make([]string, len(all))
for i, t := range all {
names[i] = t.Name()
}
sort.Strings(names)
want := []string{"bash", "fs.read", "fs.write"}
if !slices.Equal(names, want) {
t.Errorf("All() names = %v, want %v", names, want)
}
}
func TestRegistry_Definitions(t *testing.T) {
r := NewRegistry()
r.Register(&stubTool{
name: "bash",
description: "Run a command",
params: json.RawMessage(`{"type":"object","properties":{"command":{"type":"string"}}}`),
})
r.Register(&stubTool{
name: "fs.read",
description: "Read a file",
params: json.RawMessage(`{"type":"object","properties":{"path":{"type":"string"}}}`),
})
defs := r.Definitions()
if len(defs) != 2 {
t.Fatalf("len(Definitions()) = %d, want 2", len(defs))
}
// Find bash definition
var bashDef *Definition
for i := range defs {
if defs[i].Name == "bash" {
bashDef = &defs[i]
break
}
}
if bashDef == nil {
t.Fatal("bash definition not found")
}
if bashDef.Description != "Run a command" {
t.Errorf("bash.Description = %q", bashDef.Description)
}
if bashDef.Parameters == nil {
t.Error("bash.Parameters should not be nil")
}
}
func TestRegistry_MustGet_Panics(t *testing.T) {
r := NewRegistry()
defer func() {
if r := recover(); r == nil {
t.Error("MustGet should panic for missing tool")
}
}()
r.MustGet("nonexistent")
}
func TestRegistry_MustGet_Success(t *testing.T) {
r := NewRegistry()
r.Register(&stubTool{name: "bash"})
tool := r.MustGet("bash")
if tool.Name() != "bash" {
t.Errorf("Name() = %q", tool.Name())
}
}
func TestRegistry_Empty(t *testing.T) {
r := NewRegistry()
if len(r.All()) != 0 {
t.Error("empty registry should return no tools")
}
if len(r.Definitions()) != 0 {
t.Error("empty registry should return no definitions")
}
}
func TestStubTool_Execute(t *testing.T) {
called := false
tool := &stubTool{
name: "test",
execFn: func(ctx context.Context, args json.RawMessage) (Result, error) {
called = true
var input struct{ Value string }
json.Unmarshal(args, &input)
return Result{
Output: "processed: " + input.Value,
Metadata: map[string]any{"key": "val"},
}, nil
},
}
result, err := tool.Execute(context.Background(), json.RawMessage(`{"Value":"hello"}`))
if err != nil {
t.Fatalf("Execute: %v", err)
}
if !called {
t.Error("execFn should have been called")
}
if result.Output != "processed: hello" {
t.Errorf("Output = %q", result.Output)
}
if result.Metadata["key"] != "val" {
t.Errorf("Metadata = %v", result.Metadata)
}
}
func TestToolInterface_ReadOnlyDestructive(t *testing.T) {
readTool := &stubTool{name: "fs.read", readOnly: true, destructive: false}
writeTool := &stubTool{name: "fs.write", readOnly: false, destructive: false}
deleteTool := &stubTool{name: "bash.rm", readOnly: false, destructive: true}
if !readTool.IsReadOnly() {
t.Error("fs.read should be read-only")
}
if readTool.IsDestructive() {
t.Error("fs.read should not be destructive")
}
if writeTool.IsReadOnly() {
t.Error("fs.write should not be read-only")
}
if deleteTool.IsReadOnly() {
t.Error("bash.rm should not be read-only")
}
if !deleteTool.IsDestructive() {
t.Error("bash.rm should be destructive")
}
}

9
internal/tool/result.go Normal file
View File

@@ -0,0 +1,9 @@
package tool
// Result is the output of a tool execution.
type Result struct {
// Output is the text content returned to the LLM.
Output string
// Metadata carries optional structured data (exit code, file path, match count, etc.).
Metadata map[string]any
}

22
internal/tool/tool.go Normal file
View File

@@ -0,0 +1,22 @@
package tool
import (
"context"
"encoding/json"
)
// Tool is the interface every tool must implement.
type Tool interface {
// Name returns the tool's identifier (used in LLM tool schemas).
Name() string
// Description returns a human-readable description for the LLM.
Description() string
// Parameters returns the JSON Schema for the tool's input.
Parameters() json.RawMessage
// Execute runs the tool with the given JSON arguments.
Execute(ctx context.Context, args json.RawMessage) (Result, error)
// IsReadOnly returns true if the tool only reads (safe for concurrent execution).
IsReadOnly() bool
// IsDestructive returns true if the tool can cause irreversible changes.
IsDestructive() bool
}