feat: complete M1 — core engine with Mistral provider
Mistral provider adapter with streaming, tool calls (single-chunk pattern), stop reason inference, model listing, capabilities, and JSON output support. Tool system: bash (7 security checks, shell alias harvesting for bash/zsh/fish), file ops (read, write, edit, glob, grep, ls). Alias harvesting collects 300+ aliases from user's shell config. Engine agentic loop: stream → tool execution → re-query → until done. Tool gating on model capabilities. Max turns safety limit. CLI pipe mode: echo "prompt" | gnoma streams response to stdout. Flags: --provider, --model, --system, --api-key, --max-turns, --verbose, --version. Provider interface expanded: Models(), DefaultModel(), Capabilities (ToolUse, JSONOutput, Vision, Thinking, ContextWindow, MaxOutput), ResponseFormat with JSON schema support. Live verified: text streaming + tool calling with devstral-small. 117 tests across 8 packages, 10MB binary.
This commit is contained in:
@@ -1,7 +1,197 @@
|
||||
package main
|
||||
|
||||
import "fmt"
|
||||
import (
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strings"
|
||||
|
||||
"somegit.dev/Owlibou/gnoma/internal/engine"
|
||||
"somegit.dev/Owlibou/gnoma/internal/provider"
|
||||
"somegit.dev/Owlibou/gnoma/internal/provider/mistral"
|
||||
"somegit.dev/Owlibou/gnoma/internal/stream"
|
||||
"somegit.dev/Owlibou/gnoma/internal/tool"
|
||||
"somegit.dev/Owlibou/gnoma/internal/tool/bash"
|
||||
"somegit.dev/Owlibou/gnoma/internal/tool/fs"
|
||||
)
|
||||
|
||||
func main() {
|
||||
fmt.Println("gnoma — provider-agnostic agentic coding assistant")
|
||||
var (
|
||||
providerName = flag.String("provider", "mistral", "LLM provider")
|
||||
model = flag.String("model", "", "model name (empty = provider default)")
|
||||
system = flag.String("system", defaultSystem, "system prompt")
|
||||
apiKey = flag.String("api-key", "", "API key (or set MISTRAL_API_KEY env)")
|
||||
maxTurns = flag.Int("max-turns", 50, "max tool-calling rounds per turn")
|
||||
verbose = flag.Bool("verbose", false, "enable debug logging")
|
||||
version = flag.Bool("version", false, "print version and exit")
|
||||
)
|
||||
flag.Parse()
|
||||
|
||||
if *version {
|
||||
fmt.Println("gnoma v0.1.0-dev")
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
// Logger
|
||||
logLevel := slog.LevelWarn
|
||||
if *verbose {
|
||||
logLevel = slog.LevelDebug
|
||||
}
|
||||
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: logLevel}))
|
||||
|
||||
// Resolve API key
|
||||
key := *apiKey
|
||||
if key == "" {
|
||||
key = resolveAPIKey(*providerName)
|
||||
}
|
||||
if key == "" {
|
||||
fmt.Fprintf(os.Stderr, "error: no API key for provider %q\nSet %s environment variable or use --api-key\n",
|
||||
*providerName, envKeyFor(*providerName))
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Create provider
|
||||
prov, err := createProvider(*providerName, key, *model)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "error: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Create tool registry
|
||||
reg := buildToolRegistry()
|
||||
|
||||
// Harvest shell aliases
|
||||
aliases, err := bash.HarvestAliases(context.Background())
|
||||
if err != nil {
|
||||
logger.Debug("alias harvest failed (non-fatal)", "error", err)
|
||||
} else {
|
||||
logger.Debug("harvested aliases", "count", aliases.Len())
|
||||
}
|
||||
|
||||
// Re-register bash tool with aliases
|
||||
reg.Register(bash.New(bash.WithAliases(aliases)))
|
||||
|
||||
// Create engine
|
||||
eng, err := engine.New(engine.Config{
|
||||
Provider: prov,
|
||||
Tools: reg,
|
||||
System: *system,
|
||||
Model: *model,
|
||||
MaxTurns: *maxTurns,
|
||||
Logger: logger,
|
||||
})
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "error: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Read input
|
||||
input, err := readInput(flag.Args())
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "error: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
if input == "" {
|
||||
fmt.Fprintln(os.Stderr, "error: no input provided")
|
||||
fmt.Fprintln(os.Stderr, "usage: echo 'prompt' | gnoma")
|
||||
fmt.Fprintln(os.Stderr, " or: gnoma 'prompt'")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Context with signal handling
|
||||
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt)
|
||||
defer cancel()
|
||||
|
||||
// Callback: stream text deltas to stdout
|
||||
cb := func(evt stream.Event) {
|
||||
if evt.Type == stream.EventTextDelta && evt.Text != "" {
|
||||
fmt.Print(evt.Text)
|
||||
}
|
||||
}
|
||||
|
||||
// Submit and run
|
||||
_, err = eng.Submit(ctx, input, cb)
|
||||
fmt.Println() // final newline
|
||||
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
fmt.Fprintln(os.Stderr, "\ninterrupted")
|
||||
os.Exit(130)
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "error: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
// readInput returns the prompt text for this invocation.
//
// Positional arguments take precedence and are joined with single spaces.
// Otherwise, when stdin is not a character device (for example a pipe), all of
// stdin is read and returned with surrounding whitespace trimmed. With no
// arguments and a character-device stdin it returns "" and a nil error.
//
// An error is returned when stdin cannot be inspected or read.
func readInput(args []string) (string, error) {
	// Positional args
	if len(args) > 0 {
		return strings.Join(args, " "), nil
	}

	// Stdin (pipe mode). Stat can fail (for example when stdin is closed); the
	// returned FileInfo is nil in that case and must not be dereferenced.
	stat, err := os.Stdin.Stat()
	if err != nil {
		return "", fmt.Errorf("inspecting stdin: %w", err)
	}
	if stat.Mode()&os.ModeCharDevice != 0 {
		return "", nil
	}

	data, err := io.ReadAll(os.Stdin)
	if err != nil {
		return "", fmt.Errorf("reading stdin: %w", err)
	}
	return strings.TrimSpace(string(data)), nil
}
|
||||
|
||||
func resolveAPIKey(providerName string) string {
|
||||
envVar := envKeyFor(providerName)
|
||||
return os.Getenv(envVar)
|
||||
}
|
||||
|
||||
// envKeyFor returns the name of the environment variable that holds the API
// key for providerName.
//
// Known providers map to their fixed variable names. Any other name falls
// back to "<NAME>_API_KEY", where NAME is the upper-cased provider name with
// every character outside [A-Z0-9_] replaced by '_', so that names such as
// "azure-openai" still yield a variable that can be set from a shell.
func envKeyFor(providerName string) string {
	switch providerName {
	case "mistral":
		return "MISTRAL_API_KEY"
	case "anthropic":
		return "ANTHROPIC_API_KEY"
	case "openai":
		return "OPENAI_API_KEY"
	case "google":
		return "GEMINI_API_KEY"
	}

	sanitized := strings.Map(func(r rune) rune {
		switch {
		case r >= 'A' && r <= 'Z', r >= '0' && r <= '9', r == '_':
			return r
		default:
			return '_'
		}
	}, strings.ToUpper(providerName))
	return sanitized + "_API_KEY"
}
|
||||
|
||||
func createProvider(name, apiKey, model string) (provider.Provider, error) {
|
||||
cfg := provider.ProviderConfig{
|
||||
APIKey: apiKey,
|
||||
Model: model,
|
||||
}
|
||||
|
||||
switch name {
|
||||
case "mistral":
|
||||
return mistral.New(cfg)
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown provider %q (M1 supports: mistral)", name)
|
||||
}
|
||||
}
|
||||
|
||||
func buildToolRegistry() *tool.Registry {
|
||||
reg := tool.NewRegistry()
|
||||
reg.Register(bash.New())
|
||||
reg.Register(fs.NewReadTool())
|
||||
reg.Register(fs.NewWriteTool())
|
||||
reg.Register(fs.NewEditTool())
|
||||
reg.Register(fs.NewGlobTool())
|
||||
reg.Register(fs.NewGrepTool())
|
||||
reg.Register(fs.NewLSTool())
|
||||
return reg
|
||||
}
|
||||
|
||||
// defaultSystem is the system prompt used when the --system flag is not given.
const defaultSystem = `You are gnoma, a provider-agnostic agentic coding assistant.
You help users with software engineering tasks by reading files, writing code, and executing commands.
Be concise and direct. Use tools when needed to accomplish the task.`
|
||||
|
||||
@@ -116,9 +116,13 @@ depends_on: [vision]
|
||||
- [ ] Model picker overlay
|
||||
- [ ] In-app config editor (`/config` command)
|
||||
- [ ] Incognito toggle (`/incognito` command)
|
||||
- [ ] Interactive shell pane: `/shell` command or keybinding opens PTY-connected shell
|
||||
- For commands needing user input (sudo, ssh, git push with auth, passwd prompts)
|
||||
- Bash tool detects potentially interactive commands and suggests take-over
|
||||
- PTY-based execution for flagged commands
|
||||
- [ ] Session management (channel-based)
|
||||
|
||||
**Exit criteria:** Launch TUI, chat interactively, 6 permission modes work, config editable in-app, incognito toggleable.
|
||||
**Exit criteria:** Launch TUI, chat interactively, 6 permission modes work, config editable in-app, incognito toggleable, `/shell` opens interactive terminal for password prompts.
|
||||
|
||||
## M6: Context Intelligence
|
||||
|
||||
@@ -219,7 +223,7 @@ depends_on: [vision]
|
||||
|
||||
**Exit criteria:** gnoma suggests a persistent task after 3+ repetitions. `/task release v1.2.0` executes a saved workflow.
|
||||
|
||||
## M12: Thinking & Structured Output
|
||||
## M12: Thinking, Structured Output & Notebook
|
||||
|
||||
**Deliverables:**
|
||||
|
||||
@@ -227,6 +231,7 @@ depends_on: [vision]
|
||||
- [ ] Thinking block streaming and TUI display
|
||||
- [ ] Structured output with JSON schema validation
|
||||
- [ ] Retry logic for schema validation failures
|
||||
- [ ] NotebookEdit tool: read/write/edit Jupyter notebook cells (.ipynb)
|
||||
|
||||
## M13: Auth
|
||||
|
||||
|
||||
7
internal/engine/callback.go
Normal file
7
internal/engine/callback.go
Normal file
@@ -0,0 +1,7 @@
|
||||
package engine
|
||||
|
||||
import "somegit.dev/Owlibou/gnoma/internal/stream"
|
||||
|
||||
// Callback receives streaming events for real-time UI updates.
// It is called synchronously on the engine goroutine for each event, so a
// slow callback delays consumption of the stream.
// A nil Callback is allowed; the engine loop skips the call in that case.
type Callback func(stream.Event)
|
||||
123
internal/engine/engine.go
Normal file
123
internal/engine/engine.go
Normal file
@@ -0,0 +1,123 @@
|
||||
package engine
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
|
||||
"somegit.dev/Owlibou/gnoma/internal/message"
|
||||
"somegit.dev/Owlibou/gnoma/internal/provider"
|
||||
"somegit.dev/Owlibou/gnoma/internal/tool"
|
||||
)
|
||||
|
||||
// Config holds engine configuration.
// Provider and Tools are required (see validate); the rest are optional.
type Config struct {
	Provider provider.Provider // LLM backend used for every request; must be non-nil
	Tools    *tool.Registry    // tools offered to the model; must be non-nil
	System   string            // system prompt
	Model    string            // override model (empty = provider default)
	MaxTurns int               // safety limit on tool loops (0 = unlimited)
	Logger   *slog.Logger      // nil = slog.Default()
}
|
||||
|
||||
func (c Config) validate() error {
|
||||
if c.Provider == nil {
|
||||
return fmt.Errorf("engine: provider required")
|
||||
}
|
||||
if c.Tools == nil {
|
||||
return fmt.Errorf("engine: tool registry required")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Turn is the result of a complete agentic turn (may span multiple API calls).
// It covers only what one Submit/SubmitMessages call produced; the engine
// keeps the full conversation and cumulative usage separately.
type Turn struct {
	Messages []message.Message // all messages produced (assistant + tool results)
	Usage    message.Usage     // cumulative for all API calls in this turn
	Rounds   int               // number of API round-trips
}
|
||||
|
||||
// Engine orchestrates the conversation: it owns the message history and the
// cumulative token usage, and talks to the configured provider and tools.
// NOTE(review): no locking is visible on these fields; presumably an Engine
// is meant to be driven from one goroutine at a time — confirm.
type Engine struct {
	cfg     Config
	history []message.Message // full conversation, in order
	usage   message.Usage     // cumulative across all turns since the last Reset
	logger  *slog.Logger      // never nil after New

	// Cached model capabilities, resolved lazily
	modelCaps    *provider.Capabilities
	modelCapsFor string // model ID the cached caps are for
}
|
||||
|
||||
// New creates an engine.
|
||||
func New(cfg Config) (*Engine, error) {
|
||||
if err := cfg.validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
logger := cfg.Logger
|
||||
if logger == nil {
|
||||
logger = slog.Default()
|
||||
}
|
||||
return &Engine{
|
||||
cfg: cfg,
|
||||
logger: logger,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// resolveCapabilities returns the capabilities for the active model.
|
||||
// Caches the result — re-resolves if the model changes.
|
||||
func (e *Engine) resolveCapabilities(ctx context.Context) *provider.Capabilities {
|
||||
model := e.cfg.Model
|
||||
if model == "" {
|
||||
model = e.cfg.Provider.DefaultModel()
|
||||
}
|
||||
|
||||
// Return cached if same model
|
||||
if e.modelCaps != nil && e.modelCapsFor == model {
|
||||
return e.modelCaps
|
||||
}
|
||||
|
||||
// Query provider for model list
|
||||
models, err := e.cfg.Provider.Models(ctx)
|
||||
if err != nil {
|
||||
e.logger.Debug("failed to fetch model capabilities", "error", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, m := range models {
|
||||
if m.ID == model {
|
||||
e.modelCaps = &m.Capabilities
|
||||
e.modelCapsFor = model
|
||||
return e.modelCaps
|
||||
}
|
||||
}
|
||||
|
||||
e.logger.Debug("model not found in provider model list", "model", model)
|
||||
return nil
|
||||
}
|
||||
|
||||
// History returns the full conversation.
|
||||
func (e *Engine) History() []message.Message {
|
||||
return e.history
|
||||
}
|
||||
|
||||
// Usage returns cumulative token usage.
|
||||
func (e *Engine) Usage() message.Usage {
|
||||
return e.usage
|
||||
}
|
||||
|
||||
// SetProvider swaps the active provider (for dynamic switching).
|
||||
func (e *Engine) SetProvider(p provider.Provider) {
|
||||
e.cfg.Provider = p
|
||||
}
|
||||
|
||||
// SetModel changes the model within the current provider.
|
||||
func (e *Engine) SetModel(model string) {
|
||||
e.cfg.Model = model
|
||||
}
|
||||
|
||||
// Reset clears conversation history and usage.
|
||||
func (e *Engine) Reset() {
|
||||
e.history = nil
|
||||
e.usage = message.Usage{}
|
||||
}
|
||||
475
internal/engine/engine_test.go
Normal file
475
internal/engine/engine_test.go
Normal file
@@ -0,0 +1,475 @@
|
||||
package engine
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"somegit.dev/Owlibou/gnoma/internal/message"
|
||||
"somegit.dev/Owlibou/gnoma/internal/provider"
|
||||
"somegit.dev/Owlibou/gnoma/internal/stream"
|
||||
"somegit.dev/Owlibou/gnoma/internal/tool"
|
||||
)
|
||||
|
||||
// --- Mock Provider ---
|
||||
|
||||
// mockProvider returns pre-configured streams for each call.
|
||||
type mockProvider struct {
|
||||
name string
|
||||
calls int
|
||||
streams []stream.Stream // one per call, consumed in order
|
||||
}
|
||||
|
||||
func (m *mockProvider) Name() string { return m.name }
|
||||
func (m *mockProvider) DefaultModel() string { return "mock-model" }
|
||||
func (m *mockProvider) Models(_ context.Context) ([]provider.ModelInfo, error) {
|
||||
return []provider.ModelInfo{{
|
||||
ID: "mock-model", Name: "mock-model", Provider: m.name,
|
||||
Capabilities: provider.Capabilities{ToolUse: true},
|
||||
}}, nil
|
||||
}
|
||||
func (m *mockProvider) Stream(_ context.Context, _ provider.Request) (stream.Stream, error) {
|
||||
if m.calls >= len(m.streams) {
|
||||
return nil, fmt.Errorf("mock: no more streams (called %d times)", m.calls+1)
|
||||
}
|
||||
s := m.streams[m.calls]
|
||||
m.calls++
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// eventStream is a mock stream backed by a slice of events.
|
||||
type eventStream struct {
|
||||
events []stream.Event
|
||||
idx int
|
||||
stopReason message.StopReason
|
||||
model string
|
||||
}
|
||||
|
||||
func newEventStream(stopReason message.StopReason, model string, events ...stream.Event) *eventStream {
|
||||
// Append a final event with stop reason
|
||||
events = append(events, stream.Event{
|
||||
Type: stream.EventTextDelta,
|
||||
StopReason: stopReason,
|
||||
Model: model,
|
||||
})
|
||||
return &eventStream{events: events, stopReason: stopReason, model: model}
|
||||
}
|
||||
|
||||
func (s *eventStream) Next() bool {
|
||||
if s.idx >= len(s.events) {
|
||||
return false
|
||||
}
|
||||
s.idx++
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *eventStream) Current() stream.Event {
|
||||
return s.events[s.idx-1]
|
||||
}
|
||||
|
||||
func (s *eventStream) Err() error { return nil }
|
||||
func (s *eventStream) Close() error { return nil }
|
||||
|
||||
// --- Mock Tool ---
|
||||
|
||||
type mockTool struct {
|
||||
name string
|
||||
readOnly bool
|
||||
execFn func(ctx context.Context, args json.RawMessage) (tool.Result, error)
|
||||
}
|
||||
|
||||
func (m *mockTool) Name() string { return m.name }
|
||||
func (m *mockTool) Description() string { return "mock tool" }
|
||||
func (m *mockTool) Parameters() json.RawMessage { return json.RawMessage(`{"type":"object"}`) }
|
||||
func (m *mockTool) IsReadOnly() bool { return m.readOnly }
|
||||
func (m *mockTool) IsDestructive() bool { return false }
|
||||
func (m *mockTool) Execute(ctx context.Context, args json.RawMessage) (tool.Result, error) {
|
||||
if m.execFn != nil {
|
||||
return m.execFn(ctx, args)
|
||||
}
|
||||
return tool.Result{Output: "mock output"}, nil
|
||||
}
|
||||
|
||||
// --- Tests ---
|
||||
|
||||
func TestNew_ValidConfig(t *testing.T) {
|
||||
e, err := New(Config{
|
||||
Provider: &mockProvider{name: "test"},
|
||||
Tools: tool.NewRegistry(),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("New: %v", err)
|
||||
}
|
||||
if e == nil {
|
||||
t.Fatal("engine should not be nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNew_MissingProvider(t *testing.T) {
|
||||
_, err := New(Config{Tools: tool.NewRegistry()})
|
||||
if err == nil {
|
||||
t.Fatal("expected error for missing provider")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNew_MissingTools(t *testing.T) {
|
||||
_, err := New(Config{Provider: &mockProvider{name: "test"}})
|
||||
if err == nil {
|
||||
t.Fatal("expected error for missing tool registry")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubmit_SimpleTextResponse(t *testing.T) {
|
||||
mp := &mockProvider{
|
||||
name: "test",
|
||||
streams: []stream.Stream{
|
||||
newEventStream(message.StopEndTurn, "test-model",
|
||||
stream.Event{Type: stream.EventTextDelta, Text: "Hello "},
|
||||
stream.Event{Type: stream.EventTextDelta, Text: "world!"},
|
||||
stream.Event{Type: stream.EventUsage, Usage: &message.Usage{InputTokens: 10, OutputTokens: 5}},
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
e, _ := New(Config{Provider: mp, Tools: tool.NewRegistry()})
|
||||
|
||||
var events []stream.Event
|
||||
turn, err := e.Submit(context.Background(), "hi", func(evt stream.Event) {
|
||||
events = append(events, evt)
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Submit: %v", err)
|
||||
}
|
||||
|
||||
// Check turn
|
||||
if turn.Rounds != 1 {
|
||||
t.Errorf("Rounds = %d, want 1", turn.Rounds)
|
||||
}
|
||||
if len(turn.Messages) != 1 {
|
||||
t.Fatalf("len(Messages) = %d, want 1", len(turn.Messages))
|
||||
}
|
||||
if turn.Messages[0].TextContent() != "Hello world!" {
|
||||
t.Errorf("TextContent = %q", turn.Messages[0].TextContent())
|
||||
}
|
||||
if turn.Usage.InputTokens != 10 {
|
||||
t.Errorf("Usage.InputTokens = %d", turn.Usage.InputTokens)
|
||||
}
|
||||
|
||||
// Check history
|
||||
history := e.History()
|
||||
if len(history) != 2 {
|
||||
t.Fatalf("len(History) = %d, want 2 (user + assistant)", len(history))
|
||||
}
|
||||
if history[0].Role != message.RoleUser {
|
||||
t.Errorf("History[0].Role = %q", history[0].Role)
|
||||
}
|
||||
if history[1].Role != message.RoleAssistant {
|
||||
t.Errorf("History[1].Role = %q", history[1].Role)
|
||||
}
|
||||
|
||||
// Check events were forwarded
|
||||
if len(events) == 0 {
|
||||
t.Error("callback should have received events")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubmit_ToolCallLoop(t *testing.T) {
|
||||
reg := tool.NewRegistry()
|
||||
reg.Register(&mockTool{
|
||||
name: "bash",
|
||||
execFn: func(_ context.Context, args json.RawMessage) (tool.Result, error) {
|
||||
return tool.Result{Output: "file1.go\nfile2.go"}, nil
|
||||
},
|
||||
})
|
||||
|
||||
mp := &mockProvider{
|
||||
name: "test",
|
||||
streams: []stream.Stream{
|
||||
// Round 1: model calls a tool
|
||||
newEventStream(message.StopToolUse, "model-1",
|
||||
stream.Event{Type: stream.EventTextDelta, Text: "Let me list files."},
|
||||
stream.Event{Type: stream.EventToolCallStart, ToolCallID: "tc_1", ToolCallName: "bash"},
|
||||
stream.Event{Type: stream.EventToolCallDone, ToolCallID: "tc_1", Args: json.RawMessage(`{"command":"ls"}`)},
|
||||
),
|
||||
// Round 2: model responds with final answer
|
||||
newEventStream(message.StopEndTurn, "model-1",
|
||||
stream.Event{Type: stream.EventTextDelta, Text: "Found file1.go and file2.go."},
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
e, _ := New(Config{Provider: mp, Tools: reg})
|
||||
|
||||
turn, err := e.Submit(context.Background(), "list files", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Submit: %v", err)
|
||||
}
|
||||
|
||||
if turn.Rounds != 2 {
|
||||
t.Errorf("Rounds = %d, want 2", turn.Rounds)
|
||||
}
|
||||
|
||||
// Messages: assistant (tool call), tool results, assistant (final)
|
||||
if len(turn.Messages) != 3 {
|
||||
t.Fatalf("len(Messages) = %d, want 3", len(turn.Messages))
|
||||
}
|
||||
|
||||
// First message has tool call
|
||||
if !turn.Messages[0].HasToolCalls() {
|
||||
t.Error("Messages[0] should have tool calls")
|
||||
}
|
||||
|
||||
// Second message is tool results
|
||||
if turn.Messages[1].Role != message.RoleUser {
|
||||
t.Errorf("Messages[1].Role = %q, want user (tool results)", turn.Messages[1].Role)
|
||||
}
|
||||
|
||||
// Third message is final text
|
||||
if turn.Messages[2].TextContent() != "Found file1.go and file2.go." {
|
||||
t.Errorf("Messages[2].TextContent = %q", turn.Messages[2].TextContent())
|
||||
}
|
||||
|
||||
// History: user + assistant(tool call) + tool results + assistant(final)
|
||||
if len(e.History()) != 4 {
|
||||
t.Errorf("len(History) = %d, want 4", len(e.History()))
|
||||
}
|
||||
|
||||
// Provider called twice
|
||||
if mp.calls != 2 {
|
||||
t.Errorf("provider called %d times, want 2", mp.calls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubmit_UnknownTool(t *testing.T) {
|
||||
reg := tool.NewRegistry()
|
||||
// Don't register any tools
|
||||
|
||||
mp := &mockProvider{
|
||||
name: "test",
|
||||
streams: []stream.Stream{
|
||||
// Model calls a tool that doesn't exist
|
||||
newEventStream(message.StopToolUse, "",
|
||||
stream.Event{Type: stream.EventToolCallStart, ToolCallID: "tc_1", ToolCallName: "nonexistent"},
|
||||
stream.Event{Type: stream.EventToolCallDone, ToolCallID: "tc_1", Args: json.RawMessage(`{}`)},
|
||||
),
|
||||
// Model responds after seeing error
|
||||
newEventStream(message.StopEndTurn, "",
|
||||
stream.Event{Type: stream.EventTextDelta, Text: "Sorry, that tool doesn't exist."},
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
e, _ := New(Config{Provider: mp, Tools: reg})
|
||||
|
||||
turn, err := e.Submit(context.Background(), "do something", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Submit: %v", err)
|
||||
}
|
||||
|
||||
// Should still complete — unknown tool returns error result, model sees it
|
||||
if turn.Rounds != 2 {
|
||||
t.Errorf("Rounds = %d, want 2", turn.Rounds)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubmit_ToolExecutionError(t *testing.T) {
|
||||
reg := tool.NewRegistry()
|
||||
reg.Register(&mockTool{
|
||||
name: "failing",
|
||||
execFn: func(_ context.Context, _ json.RawMessage) (tool.Result, error) {
|
||||
return tool.Result{}, errors.New("disk full")
|
||||
},
|
||||
})
|
||||
|
||||
mp := &mockProvider{
|
||||
name: "test",
|
||||
streams: []stream.Stream{
|
||||
newEventStream(message.StopToolUse, "",
|
||||
stream.Event{Type: stream.EventToolCallStart, ToolCallID: "tc_1", ToolCallName: "failing"},
|
||||
stream.Event{Type: stream.EventToolCallDone, ToolCallID: "tc_1", Args: json.RawMessage(`{}`)},
|
||||
),
|
||||
newEventStream(message.StopEndTurn, "",
|
||||
stream.Event{Type: stream.EventTextDelta, Text: "The tool failed."},
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
e, _ := New(Config{Provider: mp, Tools: reg})
|
||||
|
||||
turn, err := e.Submit(context.Background(), "do it", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Submit: %v", err)
|
||||
}
|
||||
|
||||
// Tool error is returned as error result, not a fatal error
|
||||
if turn.Rounds != 2 {
|
||||
t.Errorf("Rounds = %d, want 2", turn.Rounds)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubmit_MaxTurnsLimit(t *testing.T) {
|
||||
reg := tool.NewRegistry()
|
||||
reg.Register(&mockTool{name: "bash"})
|
||||
|
||||
// Provider always returns tool calls — would loop forever
|
||||
mp := &mockProvider{
|
||||
name: "test",
|
||||
streams: []stream.Stream{
|
||||
newEventStream(message.StopToolUse, "",
|
||||
stream.Event{Type: stream.EventToolCallStart, ToolCallID: "tc_1", ToolCallName: "bash"},
|
||||
stream.Event{Type: stream.EventToolCallDone, ToolCallID: "tc_1", Args: json.RawMessage(`{}`)},
|
||||
),
|
||||
newEventStream(message.StopToolUse, "",
|
||||
stream.Event{Type: stream.EventToolCallStart, ToolCallID: "tc_2", ToolCallName: "bash"},
|
||||
stream.Event{Type: stream.EventToolCallDone, ToolCallID: "tc_2", Args: json.RawMessage(`{}`)},
|
||||
),
|
||||
newEventStream(message.StopToolUse, "",
|
||||
stream.Event{Type: stream.EventToolCallStart, ToolCallID: "tc_3", ToolCallName: "bash"},
|
||||
stream.Event{Type: stream.EventToolCallDone, ToolCallID: "tc_3", Args: json.RawMessage(`{}`)},
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
e, _ := New(Config{Provider: mp, Tools: reg, MaxTurns: 2})
|
||||
|
||||
_, err := e.Submit(context.Background(), "loop forever", nil)
|
||||
if err == nil {
|
||||
t.Fatal("expected error from max turns limit")
|
||||
}
|
||||
if mp.calls != 2 {
|
||||
t.Errorf("provider called %d times, want 2 (limited)", mp.calls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubmit_MultipleToolCalls(t *testing.T) {
|
||||
reg := tool.NewRegistry()
|
||||
reg.Register(&mockTool{
|
||||
name: "bash",
|
||||
execFn: func(_ context.Context, _ json.RawMessage) (tool.Result, error) {
|
||||
return tool.Result{Output: "bash output"}, nil
|
||||
},
|
||||
})
|
||||
reg.Register(&mockTool{
|
||||
name: "fs.read",
|
||||
readOnly: true,
|
||||
execFn: func(_ context.Context, _ json.RawMessage) (tool.Result, error) {
|
||||
return tool.Result{Output: "file content"}, nil
|
||||
},
|
||||
})
|
||||
|
||||
mp := &mockProvider{
|
||||
name: "test",
|
||||
streams: []stream.Stream{
|
||||
// Model calls two tools at once
|
||||
newEventStream(message.StopToolUse, "",
|
||||
stream.Event{Type: stream.EventToolCallStart, ToolCallID: "tc_1", ToolCallName: "bash"},
|
||||
stream.Event{Type: stream.EventToolCallDone, ToolCallID: "tc_1", Args: json.RawMessage(`{"command":"ls"}`)},
|
||||
stream.Event{Type: stream.EventToolCallStart, ToolCallID: "tc_2", ToolCallName: "fs.read"},
|
||||
stream.Event{Type: stream.EventToolCallDone, ToolCallID: "tc_2", Args: json.RawMessage(`{"path":"go.mod"}`)},
|
||||
),
|
||||
newEventStream(message.StopEndTurn, "",
|
||||
stream.Event{Type: stream.EventTextDelta, Text: "Done."},
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
e, _ := New(Config{Provider: mp, Tools: reg})
|
||||
|
||||
turn, err := e.Submit(context.Background(), "run both", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Submit: %v", err)
|
||||
}
|
||||
|
||||
if turn.Rounds != 2 {
|
||||
t.Errorf("Rounds = %d, want 2", turn.Rounds)
|
||||
}
|
||||
|
||||
// Tool results message should have 2 results
|
||||
toolMsg := turn.Messages[1] // assistant, tool_results, assistant
|
||||
if len(toolMsg.Content) != 2 {
|
||||
t.Errorf("tool results has %d content blocks, want 2", len(toolMsg.Content))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubmit_NilCallback(t *testing.T) {
|
||||
mp := &mockProvider{
|
||||
name: "test",
|
||||
streams: []stream.Stream{
|
||||
newEventStream(message.StopEndTurn, "",
|
||||
stream.Event{Type: stream.EventTextDelta, Text: "ok"},
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
e, _ := New(Config{Provider: mp, Tools: tool.NewRegistry()})
|
||||
|
||||
// nil callback should not panic
|
||||
turn, err := e.Submit(context.Background(), "test", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Submit: %v", err)
|
||||
}
|
||||
if turn.Rounds != 1 {
|
||||
t.Errorf("Rounds = %d", turn.Rounds)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEngine_Reset(t *testing.T) {
|
||||
mp := &mockProvider{
|
||||
name: "test",
|
||||
streams: []stream.Stream{
|
||||
newEventStream(message.StopEndTurn, "",
|
||||
stream.Event{Type: stream.EventTextDelta, Text: "first"},
|
||||
stream.Event{Type: stream.EventUsage, Usage: &message.Usage{InputTokens: 100}},
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
e, _ := New(Config{Provider: mp, Tools: tool.NewRegistry()})
|
||||
e.Submit(context.Background(), "hello", nil)
|
||||
|
||||
if len(e.History()) == 0 {
|
||||
t.Fatal("history should not be empty before reset")
|
||||
}
|
||||
if e.Usage().InputTokens == 0 {
|
||||
t.Fatal("usage should not be zero before reset")
|
||||
}
|
||||
|
||||
e.Reset()
|
||||
|
||||
if len(e.History()) != 0 {
|
||||
t.Errorf("history should be empty after reset, got %d", len(e.History()))
|
||||
}
|
||||
if e.Usage().InputTokens != 0 {
|
||||
t.Errorf("usage should be zero after reset, got %d", e.Usage().InputTokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubmit_CumulativeUsage(t *testing.T) {
|
||||
mp := &mockProvider{
|
||||
name: "test",
|
||||
streams: []stream.Stream{
|
||||
newEventStream(message.StopEndTurn, "",
|
||||
stream.Event{Type: stream.EventUsage, Usage: &message.Usage{InputTokens: 100, OutputTokens: 50}},
|
||||
stream.Event{Type: stream.EventTextDelta, Text: "first"},
|
||||
),
|
||||
newEventStream(message.StopEndTurn, "",
|
||||
stream.Event{Type: stream.EventUsage, Usage: &message.Usage{InputTokens: 200, OutputTokens: 80}},
|
||||
stream.Event{Type: stream.EventTextDelta, Text: "second"},
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
e, _ := New(Config{Provider: mp, Tools: tool.NewRegistry()})
|
||||
|
||||
e.Submit(context.Background(), "one", nil)
|
||||
e.Submit(context.Background(), "two", nil)
|
||||
|
||||
if e.Usage().InputTokens != 300 {
|
||||
t.Errorf("cumulative InputTokens = %d, want 300", e.Usage().InputTokens)
|
||||
}
|
||||
if e.Usage().OutputTokens != 130 {
|
||||
t.Errorf("cumulative OutputTokens = %d, want 130", e.Usage().OutputTokens)
|
||||
}
|
||||
}
|
||||
204
internal/engine/loop.go
Normal file
204
internal/engine/loop.go
Normal file
@@ -0,0 +1,204 @@
|
||||
package engine
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"somegit.dev/Owlibou/gnoma/internal/message"
|
||||
"somegit.dev/Owlibou/gnoma/internal/provider"
|
||||
"somegit.dev/Owlibou/gnoma/internal/stream"
|
||||
)
|
||||
|
||||
// Submit sends a user message and runs the agentic loop to completion.
|
||||
// The callback receives real-time streaming events.
|
||||
func (e *Engine) Submit(ctx context.Context, input string, cb Callback) (*Turn, error) {
|
||||
userMsg := message.NewUserText(input)
|
||||
e.history = append(e.history, userMsg)
|
||||
|
||||
return e.runLoop(ctx, cb)
|
||||
}
|
||||
|
||||
// SubmitMessages is like Submit but accepts pre-built messages.
|
||||
func (e *Engine) SubmitMessages(ctx context.Context, msgs []message.Message, cb Callback) (*Turn, error) {
|
||||
e.history = append(e.history, msgs...)
|
||||
|
||||
return e.runLoop(ctx, cb)
|
||||
}
|
||||
|
||||
func (e *Engine) runLoop(ctx context.Context, cb Callback) (*Turn, error) {
|
||||
turn := &Turn{}
|
||||
|
||||
for {
|
||||
turn.Rounds++
|
||||
if e.cfg.MaxTurns > 0 && turn.Rounds > e.cfg.MaxTurns {
|
||||
return turn, fmt.Errorf("safety limit: %d rounds exceeded", e.cfg.MaxTurns)
|
||||
}
|
||||
|
||||
// Build provider request (gates tools on model capabilities)
|
||||
req := e.buildRequest(ctx)
|
||||
|
||||
e.logger.Debug("streaming request",
|
||||
"provider", e.cfg.Provider.Name(),
|
||||
"model", req.Model,
|
||||
"messages", len(req.Messages),
|
||||
"tools", len(req.Tools),
|
||||
"round", turn.Rounds,
|
||||
)
|
||||
|
||||
// Stream from provider
|
||||
s, err := e.cfg.Provider.Stream(ctx, req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("provider stream: %w", err)
|
||||
}
|
||||
|
||||
// Consume stream, forwarding events to callback
|
||||
acc := stream.NewAccumulator()
|
||||
var stopReason message.StopReason
|
||||
var model string
|
||||
|
||||
for s.Next() {
|
||||
evt := s.Current()
|
||||
acc.Apply(evt)
|
||||
|
||||
// Capture stop reason and model from events
|
||||
if evt.StopReason != "" {
|
||||
stopReason = evt.StopReason
|
||||
}
|
||||
if evt.Model != "" {
|
||||
model = evt.Model
|
||||
}
|
||||
|
||||
if cb != nil {
|
||||
cb(evt)
|
||||
}
|
||||
}
|
||||
if err := s.Err(); err != nil {
|
||||
s.Close()
|
||||
return nil, fmt.Errorf("stream error: %w", err)
|
||||
}
|
||||
s.Close()
|
||||
|
||||
// Build response
|
||||
resp := acc.Response(stopReason, model)
|
||||
turn.Usage.Add(resp.Usage)
|
||||
turn.Messages = append(turn.Messages, resp.Message)
|
||||
e.history = append(e.history, resp.Message)
|
||||
e.usage.Add(resp.Usage)
|
||||
|
||||
e.logger.Debug("turn response",
|
||||
"stop_reason", resp.StopReason,
|
||||
"tool_calls", len(resp.Message.ToolCalls()),
|
||||
"round", turn.Rounds,
|
||||
)
|
||||
|
||||
// Decide next action
|
||||
switch resp.StopReason {
|
||||
case message.StopEndTurn, message.StopMaxTokens, message.StopSequence:
|
||||
return turn, nil
|
||||
|
||||
case message.StopToolUse:
|
||||
results, err := e.executeTools(ctx, resp.Message.ToolCalls(), cb)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("tool execution: %w", err)
|
||||
}
|
||||
toolMsg := message.NewToolResults(results...)
|
||||
turn.Messages = append(turn.Messages, toolMsg)
|
||||
e.history = append(e.history, toolMsg)
|
||||
// Continue loop — re-query provider with tool results
|
||||
|
||||
default:
|
||||
// Unknown stop reason or empty — treat as end of turn
|
||||
return turn, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (e *Engine) buildRequest(ctx context.Context) provider.Request {
|
||||
req := provider.Request{
|
||||
Model: e.cfg.Model,
|
||||
SystemPrompt: e.cfg.System,
|
||||
Messages: e.history,
|
||||
}
|
||||
|
||||
// Only include tools if the model supports them
|
||||
caps := e.resolveCapabilities(ctx)
|
||||
if caps == nil || caps.ToolUse {
|
||||
// nil caps = unknown model, include tools optimistically
|
||||
for _, t := range e.cfg.Tools.All() {
|
||||
req.Tools = append(req.Tools, provider.ToolDefinition{
|
||||
Name: t.Name(),
|
||||
Description: t.Description(),
|
||||
Parameters: t.Parameters(),
|
||||
})
|
||||
}
|
||||
} else {
|
||||
e.logger.Debug("tools omitted — model does not support tool use",
|
||||
"model", req.Model,
|
||||
)
|
||||
}
|
||||
|
||||
return req
|
||||
}
|
||||
|
||||
func (e *Engine) executeTools(ctx context.Context, calls []message.ToolCall, cb Callback) ([]message.ToolResult, error) {
|
||||
results := make([]message.ToolResult, 0, len(calls))
|
||||
|
||||
for _, call := range calls {
|
||||
t, ok := e.cfg.Tools.Get(call.Name)
|
||||
if !ok {
|
||||
e.logger.Warn("unknown tool", "name", call.Name)
|
||||
results = append(results, message.ToolResult{
|
||||
ToolCallID: call.ID,
|
||||
Content: fmt.Sprintf("unknown tool: %s", call.Name),
|
||||
IsError: true,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
e.logger.Debug("executing tool", "name", call.Name, "id", call.ID)
|
||||
|
||||
result, err := t.Execute(ctx, call.Arguments)
|
||||
if err != nil {
|
||||
e.logger.Error("tool execution failed", "name", call.Name, "error", err)
|
||||
results = append(results, message.ToolResult{
|
||||
ToolCallID: call.ID,
|
||||
Content: err.Error(),
|
||||
IsError: true,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
// Emit tool result as a text delta event so the UI can show it
|
||||
if cb != nil {
|
||||
cb(stream.Event{
|
||||
Type: stream.EventTextDelta,
|
||||
Text: fmt.Sprintf("\n[tool:%s] %s\n", call.Name, truncate(result.Output, 500)),
|
||||
})
|
||||
}
|
||||
|
||||
results = append(results, message.ToolResult{
|
||||
ToolCallID: call.ID,
|
||||
Content: result.Output,
|
||||
})
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// truncate shortens s to at most maxLen bytes and appends "..." when anything
// was cut off. Strings that already fit are returned unchanged.
//
// The cut point is moved back to the start of a UTF-8 character, so a
// multi-byte character is never split in half (which would yield invalid
// UTF-8 in the UI). A negative maxLen is treated as 0 instead of panicking.
func truncate(s string, maxLen int) string {
	if maxLen < 0 {
		maxLen = 0
	}
	if len(s) <= maxLen {
		return s
	}

	cut := maxLen
	// UTF-8 continuation bytes have the form 10xxxxxx. A character has at most
	// three of them, so stepping back is bounded even for non-UTF-8 input.
	for back := 0; back < 3 && cut > 0 && s[cut]&0xC0 == 0x80; back++ {
		cut--
	}
	return s[:cut] + "..."
}
|
||||
|
||||
// toolDefFromTool converts a tool.Tool to provider.ToolDefinition.
|
||||
// Unused currently but kept for reference when building tool definitions dynamically.
|
||||
func toolDefFromJSON(name, description string, params json.RawMessage) provider.ToolDefinition {
|
||||
return provider.ToolDefinition{
|
||||
Name: name,
|
||||
Description: description,
|
||||
Parameters: params,
|
||||
}
|
||||
}
|
||||
124
internal/provider/mistral/provider.go
Normal file
124
internal/provider/mistral/provider.go
Normal file
@@ -0,0 +1,124 @@
|
||||
package mistral
|
||||
|
||||
import (
	"context"
	"errors"
	"fmt"

	"somegit.dev/Owlibou/gnoma/internal/provider"
	"somegit.dev/Owlibou/gnoma/internal/stream"
	mistralgo "somegit.dev/vikingowl/mistral-go-sdk"
	"somegit.dev/vikingowl/mistral-go-sdk/model"
)
|
||||
|
||||
// defaultModel is the model New falls back to when the config names none.
const defaultModel = "mistral-large-latest"
|
||||
|
||||
// Provider implements provider.Provider for the Mistral API.
|
||||
type Provider struct {
|
||||
client *mistralgo.Client
|
||||
name string
|
||||
model string
|
||||
}
|
||||
|
||||
// New creates a Mistral provider from config.
|
||||
func New(cfg provider.ProviderConfig) (provider.Provider, error) {
|
||||
if cfg.APIKey == "" {
|
||||
return nil, fmt.Errorf("mistral: api key required")
|
||||
}
|
||||
|
||||
opts := []mistralgo.Option{}
|
||||
if cfg.BaseURL != "" {
|
||||
opts = append(opts, mistralgo.WithBaseURL(cfg.BaseURL))
|
||||
}
|
||||
|
||||
client := mistralgo.NewClient(cfg.APIKey, opts...)
|
||||
|
||||
m := cfg.Model
|
||||
if m == "" {
|
||||
m = defaultModel
|
||||
}
|
||||
|
||||
return &Provider{
|
||||
client: client,
|
||||
name: "mistral",
|
||||
model: m,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Stream initiates a streaming chat completion request.
|
||||
func (p *Provider) Stream(ctx context.Context, req provider.Request) (stream.Stream, error) {
|
||||
m := req.Model
|
||||
if m == "" {
|
||||
m = p.model
|
||||
}
|
||||
|
||||
cr := translateRequest(req)
|
||||
cr.Model = m
|
||||
|
||||
raw, err := p.client.ChatCompleteStream(ctx, cr)
|
||||
if err != nil {
|
||||
return nil, p.wrapError(err)
|
||||
}
|
||||
|
||||
return newMistralStream(raw), nil
|
||||
}
|
||||
|
||||
// Name returns "mistral".
|
||||
func (p *Provider) Name() string {
|
||||
return p.name
|
||||
}
|
||||
|
||||
// DefaultModel returns the configured default model.
|
||||
func (p *Provider) DefaultModel() string {
|
||||
return p.model
|
||||
}
|
||||
|
||||
// Models lists available models from the Mistral API with capability metadata.
|
||||
func (p *Provider) Models(ctx context.Context) ([]provider.ModelInfo, error) {
|
||||
resp, err := p.client.ListModels(ctx, &model.ListParams{})
|
||||
if err != nil {
|
||||
return nil, p.wrapError(err)
|
||||
}
|
||||
|
||||
var models []provider.ModelInfo
|
||||
for _, m := range resp.Data {
|
||||
models = append(models, provider.ModelInfo{
|
||||
ID: m.ID,
|
||||
Name: m.ID,
|
||||
Provider: p.name,
|
||||
Capabilities: inferCapabilities(m),
|
||||
})
|
||||
}
|
||||
return models, nil
|
||||
}
|
||||
|
||||
// inferCapabilities maps Mistral model metadata to gnoma capabilities.
|
||||
func inferCapabilities(m model.ModelCard) provider.Capabilities {
|
||||
caps := provider.Capabilities{
|
||||
ToolUse: m.Capabilities.FunctionCalling,
|
||||
Vision: m.Capabilities.Vision,
|
||||
JSONOutput: m.Capabilities.CompletionChat, // all chat models support JSON output via ResponseFormat
|
||||
ContextWindow: m.MaxContextLength,
|
||||
MaxOutput: 8192, // reasonable default
|
||||
}
|
||||
return caps
|
||||
}
|
||||
|
||||
func (p *Provider) wrapError(err error) error {
|
||||
if apiErr, ok := err.(*mistralgo.APIError); ok {
|
||||
kind, retryable := provider.ClassifyHTTPStatus(apiErr.StatusCode)
|
||||
return &provider.ProviderError{
|
||||
Kind: kind,
|
||||
Provider: p.name,
|
||||
StatusCode: apiErr.StatusCode,
|
||||
Message: apiErr.Message,
|
||||
Retryable: retryable,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
return &provider.ProviderError{
|
||||
Kind: provider.ErrTransient,
|
||||
Provider: p.name,
|
||||
Message: err.Error(),
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
248
internal/provider/mistral/stream.go
Normal file
248
internal/provider/mistral/stream.go
Normal file
@@ -0,0 +1,248 @@
|
||||
package mistral
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
|
||||
"somegit.dev/Owlibou/gnoma/internal/message"
|
||||
"somegit.dev/Owlibou/gnoma/internal/stream"
|
||||
mistralgo "somegit.dev/vikingowl/mistral-go-sdk"
|
||||
"somegit.dev/vikingowl/mistral-go-sdk/chat"
|
||||
)
|
||||
|
||||
// mistralStream adapts mistral's Stream[CompletionChunk] to gnoma's stream.Stream.
|
||||
type mistralStream struct {
|
||||
raw *mistralgo.Stream[chat.CompletionChunk]
|
||||
cur stream.Event
|
||||
err error
|
||||
model string
|
||||
|
||||
// Track active tool calls for delta assembly
|
||||
activeToolCalls map[int]*toolCallState // keyed by ToolCall.Index
|
||||
|
||||
// Deferred finish reason (when finish arrives on the same chunk as content)
|
||||
pendingFinish *chat.FinishReason
|
||||
pendingUsage *message.Usage // usage from a chunk that also had other data
|
||||
emittedStop bool // true after we've emitted the synthetic stop event
|
||||
hadToolCalls bool // true if any tool calls were emitted
|
||||
}
|
||||
|
||||
type toolCallState struct {
|
||||
id string
|
||||
name string
|
||||
args string // accumulated argument fragments
|
||||
}
|
||||
|
||||
func newMistralStream(raw *mistralgo.Stream[chat.CompletionChunk]) *mistralStream {
|
||||
return &mistralStream{
|
||||
raw: raw,
|
||||
activeToolCalls: make(map[int]*toolCallState),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *mistralStream) Next() bool {
|
||||
for s.raw.Next() {
|
||||
chunk := s.raw.Current()
|
||||
|
||||
// Capture model from first chunk
|
||||
if s.model == "" && chunk.Model != "" {
|
||||
s.model = chunk.Model
|
||||
}
|
||||
|
||||
// Store usage if present (may be on same chunk as tool calls or finish)
|
||||
if chunk.Usage != nil {
|
||||
s.pendingUsage = translateUsage(chunk.Usage)
|
||||
}
|
||||
|
||||
if len(chunk.Choices) == 0 {
|
||||
// Chunk with only usage and no choices — emit usage
|
||||
if s.pendingUsage != nil {
|
||||
s.cur = stream.Event{Type: stream.EventUsage, Usage: s.pendingUsage}
|
||||
s.pendingUsage = nil
|
||||
return true
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
choice := chunk.Choices[0]
|
||||
delta := choice.Delta
|
||||
|
||||
// Process text content first (even on chunks with finish reason)
|
||||
text := delta.Content.String()
|
||||
if text != "" {
|
||||
s.cur = stream.Event{
|
||||
Type: stream.EventTextDelta,
|
||||
Text: text,
|
||||
}
|
||||
// If this chunk also has a finish reason, store it for next iteration
|
||||
if choice.FinishReason != nil {
|
||||
s.pendingFinish = choice.FinishReason
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Tool call deltas
|
||||
if len(delta.ToolCalls) > 0 {
|
||||
// Store finish reason if present on same chunk as tool calls
|
||||
if choice.FinishReason != nil {
|
||||
s.pendingFinish = choice.FinishReason
|
||||
}
|
||||
|
||||
for _, tc := range delta.ToolCalls {
|
||||
existing, ok := s.activeToolCalls[tc.Index]
|
||||
if !ok {
|
||||
// New tool call
|
||||
s.activeToolCalls[tc.Index] = &toolCallState{
|
||||
id: tc.ID,
|
||||
name: tc.Function.Name,
|
||||
args: tc.Function.Arguments,
|
||||
}
|
||||
s.hadToolCalls = true
|
||||
|
||||
// If arguments are already complete (Mistral sends full args in one chunk),
|
||||
// emit ToolCallDone directly instead of Start
|
||||
if tc.Function.Arguments != "" && s.pendingFinish != nil {
|
||||
s.cur = stream.Event{
|
||||
Type: stream.EventToolCallDone,
|
||||
ToolCallID: tc.ID,
|
||||
ToolCallName: tc.Function.Name,
|
||||
Args: json.RawMessage(tc.Function.Arguments),
|
||||
}
|
||||
// Remove from active — it's already done
|
||||
delete(s.activeToolCalls, tc.Index)
|
||||
return true
|
||||
}
|
||||
|
||||
// Otherwise emit Start, accumulate deltas later
|
||||
s.cur = stream.Event{
|
||||
Type: stream.EventToolCallStart,
|
||||
ToolCallID: tc.ID,
|
||||
ToolCallName: tc.Function.Name,
|
||||
}
|
||||
return true
|
||||
}
|
||||
// Existing tool call — accumulate arguments, emit Delta
|
||||
existing.args += tc.Function.Arguments
|
||||
if tc.Function.Arguments != "" {
|
||||
s.cur = stream.Event{
|
||||
Type: stream.EventToolCallDelta,
|
||||
ToolCallID: existing.id,
|
||||
ArgDelta: tc.Function.Arguments,
|
||||
}
|
||||
return true
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Check finish reason (from this chunk or pending from previous)
|
||||
fr := choice.FinishReason
|
||||
if fr == nil {
|
||||
fr = s.pendingFinish
|
||||
s.pendingFinish = nil
|
||||
}
|
||||
|
||||
if fr != nil {
|
||||
// Flush any pending tool calls as Done events
|
||||
if *fr == chat.FinishReasonToolCalls {
|
||||
for idx, tc := range s.activeToolCalls {
|
||||
s.cur = stream.Event{
|
||||
Type: stream.EventToolCallDone,
|
||||
ToolCallID: tc.id,
|
||||
Args: json.RawMessage(tc.args),
|
||||
}
|
||||
delete(s.activeToolCalls, idx)
|
||||
s.pendingFinish = fr // re-store to flush remaining on next call
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Final event with stop reason
|
||||
s.cur = stream.Event{
|
||||
Type: stream.EventTextDelta,
|
||||
StopReason: translateFinishReason(fr),
|
||||
Model: s.model,
|
||||
}
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Drain any pending finish reason that was stored with the last content chunk
|
||||
if s.pendingFinish != nil {
|
||||
fr := s.pendingFinish
|
||||
s.pendingFinish = nil
|
||||
|
||||
// Flush pending tool calls
|
||||
if *fr == chat.FinishReasonToolCalls {
|
||||
for idx, tc := range s.activeToolCalls {
|
||||
s.cur = stream.Event{
|
||||
Type: stream.EventToolCallDone,
|
||||
ToolCallID: tc.id,
|
||||
Args: json.RawMessage(tc.args),
|
||||
}
|
||||
delete(s.activeToolCalls, idx)
|
||||
s.pendingFinish = fr
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
s.cur = stream.Event{
|
||||
Type: stream.EventTextDelta,
|
||||
StopReason: translateFinishReason(fr),
|
||||
Model: s.model,
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Emit any pending usage before the stop event
|
||||
if s.pendingUsage != nil {
|
||||
s.cur = stream.Event{Type: stream.EventUsage, Usage: s.pendingUsage}
|
||||
s.pendingUsage = nil
|
||||
return true
|
||||
}
|
||||
|
||||
// Stream ended — emit inferred stop reason.
|
||||
if !s.emittedStop {
|
||||
s.emittedStop = true
|
||||
|
||||
// If we have pending tool calls, they ended with the stream
|
||||
if len(s.activeToolCalls) > 0 {
|
||||
for idx, tc := range s.activeToolCalls {
|
||||
s.cur = stream.Event{
|
||||
Type: stream.EventToolCallDone,
|
||||
ToolCallID: tc.id,
|
||||
Args: json.RawMessage(tc.args),
|
||||
}
|
||||
delete(s.activeToolCalls, idx)
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Infer stop reason: if tool calls were emitted, it's ToolUse; otherwise EndTurn
|
||||
stopReason := message.StopEndTurn
|
||||
if s.hadToolCalls {
|
||||
stopReason = message.StopToolUse
|
||||
}
|
||||
|
||||
s.cur = stream.Event{
|
||||
Type: stream.EventTextDelta,
|
||||
StopReason: stopReason,
|
||||
Model: s.model,
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
s.err = s.raw.Err()
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *mistralStream) Current() stream.Event {
|
||||
return s.cur
|
||||
}
|
||||
|
||||
func (s *mistralStream) Err() error {
|
||||
return s.err
|
||||
}
|
||||
|
||||
func (s *mistralStream) Close() error {
|
||||
return s.raw.Close()
|
||||
}
|
||||
177
internal/provider/mistral/translate.go
Normal file
177
internal/provider/mistral/translate.go
Normal file
@@ -0,0 +1,177 @@
|
||||
package mistral
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
|
||||
"somegit.dev/Owlibou/gnoma/internal/message"
|
||||
"somegit.dev/Owlibou/gnoma/internal/provider"
|
||||
"somegit.dev/vikingowl/mistral-go-sdk/chat"
|
||||
)
|
||||
|
||||
// --- gnoma → Mistral ---
|
||||
|
||||
func translateMessages(msgs []message.Message) []chat.Message {
|
||||
out := make([]chat.Message, 0, len(msgs))
|
||||
for _, m := range msgs {
|
||||
out = append(out, translateMessage(m))
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func translateMessage(m message.Message) chat.Message {
|
||||
switch m.Role {
|
||||
case message.RoleSystem:
|
||||
return &chat.SystemMessage{Content: chat.TextContent(m.TextContent())}
|
||||
|
||||
case message.RoleUser:
|
||||
// Check if this is a tool results message
|
||||
if len(m.Content) > 0 && m.Content[0].Type == message.ContentToolResult {
|
||||
// Tool results must be sent as individual ToolMessages
|
||||
// Return only the first; caller handles multi-result expansion
|
||||
tr := m.Content[0].ToolResult
|
||||
return &chat.ToolMessage{
|
||||
ToolCallID: tr.ToolCallID,
|
||||
Content: chat.TextContent(tr.Content),
|
||||
}
|
||||
}
|
||||
return &chat.UserMessage{Content: chat.TextContent(m.TextContent())}
|
||||
|
||||
case message.RoleAssistant:
|
||||
am := chat.AssistantMessage{
|
||||
Content: chat.TextContent(m.TextContent()),
|
||||
}
|
||||
for _, tc := range m.ToolCalls() {
|
||||
am.ToolCalls = append(am.ToolCalls, chat.ToolCall{
|
||||
ID: tc.ID,
|
||||
Type: "function",
|
||||
Function: chat.FunctionCall{
|
||||
Name: tc.Name,
|
||||
Arguments: string(tc.Arguments),
|
||||
},
|
||||
})
|
||||
}
|
||||
return &am
|
||||
|
||||
default:
|
||||
return &chat.UserMessage{Content: chat.TextContent(m.TextContent())}
|
||||
}
|
||||
}
|
||||
|
||||
// expandToolResults handles the case where a gnoma Message contains
|
||||
// multiple ToolResults. Mistral expects one ToolMessage per result.
|
||||
func expandToolResults(msgs []message.Message) []chat.Message {
|
||||
out := make([]chat.Message, 0, len(msgs))
|
||||
for _, m := range msgs {
|
||||
if m.Role == message.RoleUser && len(m.Content) > 0 && m.Content[0].Type == message.ContentToolResult {
|
||||
for _, c := range m.Content {
|
||||
if c.Type == message.ContentToolResult && c.ToolResult != nil {
|
||||
out = append(out, &chat.ToolMessage{
|
||||
ToolCallID: c.ToolResult.ToolCallID,
|
||||
Content: chat.TextContent(c.ToolResult.Content),
|
||||
})
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
out = append(out, translateMessage(m))
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func translateTools(defs []provider.ToolDefinition) []chat.Tool {
|
||||
if len(defs) == 0 {
|
||||
return nil
|
||||
}
|
||||
tools := make([]chat.Tool, len(defs))
|
||||
for i, d := range defs {
|
||||
var params map[string]any
|
||||
if d.Parameters != nil {
|
||||
_ = json.Unmarshal(d.Parameters, ¶ms)
|
||||
}
|
||||
tools[i] = chat.Tool{
|
||||
Type: "function",
|
||||
Function: chat.Function{
|
||||
Name: d.Name,
|
||||
Description: d.Description,
|
||||
Parameters: params,
|
||||
},
|
||||
}
|
||||
}
|
||||
return tools
|
||||
}
|
||||
|
||||
func translateRequest(req provider.Request) *chat.CompletionRequest {
|
||||
cr := &chat.CompletionRequest{
|
||||
Model: req.Model,
|
||||
Messages: expandToolResults(req.Messages),
|
||||
Tools: translateTools(req.Tools),
|
||||
Stop: req.StopSequences,
|
||||
}
|
||||
if req.MaxTokens > 0 {
|
||||
mt := int(req.MaxTokens)
|
||||
cr.MaxTokens = &mt
|
||||
}
|
||||
if req.Temperature != nil {
|
||||
cr.Temperature = req.Temperature
|
||||
}
|
||||
if req.TopP != nil {
|
||||
cr.TopP = req.TopP
|
||||
}
|
||||
if req.ResponseFormat != nil {
|
||||
cr.ResponseFormat = translateResponseFormat(req.ResponseFormat)
|
||||
}
|
||||
return cr
|
||||
}
|
||||
|
||||
func translateResponseFormat(rf *provider.ResponseFormat) *chat.ResponseFormat {
|
||||
if rf == nil {
|
||||
return nil
|
||||
}
|
||||
out := &chat.ResponseFormat{
|
||||
Type: chat.ResponseFormatType(rf.Type),
|
||||
}
|
||||
if rf.JSONSchema != nil {
|
||||
var schema map[string]any
|
||||
if rf.JSONSchema.Schema != nil {
|
||||
_ = json.Unmarshal(rf.JSONSchema.Schema, &schema)
|
||||
}
|
||||
out.JsonSchema = &chat.JsonSchema{
|
||||
Name: rf.JSONSchema.Name,
|
||||
Schema: schema,
|
||||
Strict: rf.JSONSchema.Strict,
|
||||
}
|
||||
if rf.JSONSchema.Description != "" {
|
||||
desc := rf.JSONSchema.Description
|
||||
out.JsonSchema.Description = &desc
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// --- Mistral → gnoma ---
|
||||
|
||||
func translateFinishReason(fr *chat.FinishReason) message.StopReason {
|
||||
if fr == nil {
|
||||
return ""
|
||||
}
|
||||
switch *fr {
|
||||
case chat.FinishReasonStop:
|
||||
return message.StopEndTurn
|
||||
case chat.FinishReasonToolCalls:
|
||||
return message.StopToolUse
|
||||
case chat.FinishReasonLength, chat.FinishReasonModelLength:
|
||||
return message.StopMaxTokens
|
||||
default:
|
||||
return message.StopEndTurn
|
||||
}
|
||||
}
|
||||
|
||||
func translateUsage(u *chat.UsageInfo) *message.Usage {
|
||||
if u == nil {
|
||||
return nil
|
||||
}
|
||||
return &message.Usage{
|
||||
InputTokens: int64(u.PromptTokens),
|
||||
OutputTokens: int64(u.CompletionTokens),
|
||||
}
|
||||
}
|
||||
256
internal/provider/mistral/translate_test.go
Normal file
256
internal/provider/mistral/translate_test.go
Normal file
@@ -0,0 +1,256 @@
|
||||
package mistral
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"somegit.dev/Owlibou/gnoma/internal/message"
|
||||
"somegit.dev/Owlibou/gnoma/internal/provider"
|
||||
"somegit.dev/vikingowl/mistral-go-sdk/chat"
|
||||
)
|
||||
|
||||
func TestTranslateMessage_User(t *testing.T) {
|
||||
m := message.NewUserText("hello world")
|
||||
result := translateMessage(m)
|
||||
|
||||
um, ok := result.(*chat.UserMessage)
|
||||
if !ok {
|
||||
t.Fatalf("expected *UserMessage, got %T", result)
|
||||
}
|
||||
if um.Content.String() != "hello world" {
|
||||
t.Errorf("Content = %q, want %q", um.Content.String(), "hello world")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTranslateMessage_System(t *testing.T) {
|
||||
m := message.NewSystemText("you are a helper")
|
||||
result := translateMessage(m)
|
||||
|
||||
sm, ok := result.(*chat.SystemMessage)
|
||||
if !ok {
|
||||
t.Fatalf("expected *SystemMessage, got %T", result)
|
||||
}
|
||||
if sm.Content.String() != "you are a helper" {
|
||||
t.Errorf("Content = %q", sm.Content.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestTranslateMessage_AssistantText(t *testing.T) {
|
||||
m := message.NewAssistantText("here's the answer")
|
||||
result := translateMessage(m)
|
||||
|
||||
am, ok := result.(*chat.AssistantMessage)
|
||||
if !ok {
|
||||
t.Fatalf("expected *AssistantMessage, got %T", result)
|
||||
}
|
||||
if am.Content.String() != "here's the answer" {
|
||||
t.Errorf("Content = %q", am.Content.String())
|
||||
}
|
||||
if len(am.ToolCalls) != 0 {
|
||||
t.Errorf("ToolCalls should be empty, got %d", len(am.ToolCalls))
|
||||
}
|
||||
}
|
||||
|
||||
func TestTranslateMessage_AssistantWithToolCalls(t *testing.T) {
|
||||
m := message.NewAssistantContent(
|
||||
message.NewTextContent("running command"),
|
||||
message.NewToolCallContent(message.ToolCall{
|
||||
ID: "tc_1",
|
||||
Name: "bash",
|
||||
Arguments: json.RawMessage(`{"command":"ls"}`),
|
||||
}),
|
||||
)
|
||||
result := translateMessage(m)
|
||||
|
||||
am, ok := result.(*chat.AssistantMessage)
|
||||
if !ok {
|
||||
t.Fatalf("expected *AssistantMessage, got %T", result)
|
||||
}
|
||||
if len(am.ToolCalls) != 1 {
|
||||
t.Fatalf("len(ToolCalls) = %d, want 1", len(am.ToolCalls))
|
||||
}
|
||||
if am.ToolCalls[0].ID != "tc_1" {
|
||||
t.Errorf("ToolCalls[0].ID = %q", am.ToolCalls[0].ID)
|
||||
}
|
||||
if am.ToolCalls[0].Function.Name != "bash" {
|
||||
t.Errorf("ToolCalls[0].Function.Name = %q", am.ToolCalls[0].Function.Name)
|
||||
}
|
||||
if am.ToolCalls[0].Function.Arguments != `{"command":"ls"}` {
|
||||
t.Errorf("ToolCalls[0].Function.Arguments = %q", am.ToolCalls[0].Function.Arguments)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpandToolResults(t *testing.T) {
|
||||
msgs := []message.Message{
|
||||
message.NewUserText("run two commands"),
|
||||
message.NewAssistantContent(
|
||||
message.NewToolCallContent(message.ToolCall{ID: "tc_1", Name: "bash"}),
|
||||
message.NewToolCallContent(message.ToolCall{ID: "tc_2", Name: "bash"}),
|
||||
),
|
||||
message.NewToolResults(
|
||||
message.ToolResult{ToolCallID: "tc_1", Content: "output1"},
|
||||
message.ToolResult{ToolCallID: "tc_2", Content: "output2"},
|
||||
),
|
||||
}
|
||||
|
||||
expanded := expandToolResults(msgs)
|
||||
|
||||
// UserMessage, AssistantMessage, ToolMessage, ToolMessage
|
||||
if len(expanded) != 4 {
|
||||
t.Fatalf("len(expanded) = %d, want 4", len(expanded))
|
||||
}
|
||||
|
||||
// First: UserMessage
|
||||
if _, ok := expanded[0].(*chat.UserMessage); !ok {
|
||||
t.Errorf("expanded[0] = %T, want *UserMessage", expanded[0])
|
||||
}
|
||||
|
||||
// Second: AssistantMessage
|
||||
if _, ok := expanded[1].(*chat.AssistantMessage); !ok {
|
||||
t.Errorf("expanded[1] = %T, want *AssistantMessage", expanded[1])
|
||||
}
|
||||
|
||||
// Third and fourth: ToolMessages
|
||||
tm1, ok := expanded[2].(*chat.ToolMessage)
|
||||
if !ok {
|
||||
t.Fatalf("expanded[2] = %T, want *ToolMessage", expanded[2])
|
||||
}
|
||||
if tm1.ToolCallID != "tc_1" {
|
||||
t.Errorf("expanded[2].ToolCallID = %q, want tc_1", tm1.ToolCallID)
|
||||
}
|
||||
|
||||
tm2, ok := expanded[3].(*chat.ToolMessage)
|
||||
if !ok {
|
||||
t.Fatalf("expanded[3] = %T, want *ToolMessage", expanded[3])
|
||||
}
|
||||
if tm2.ToolCallID != "tc_2" {
|
||||
t.Errorf("expanded[3].ToolCallID = %q, want tc_2", tm2.ToolCallID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTranslateTools(t *testing.T) {
|
||||
defs := []provider.ToolDefinition{
|
||||
{
|
||||
Name: "bash",
|
||||
Description: "Run a bash command",
|
||||
Parameters: json.RawMessage(`{"type":"object","properties":{"command":{"type":"string"}},"required":["command"]}`),
|
||||
},
|
||||
{
|
||||
Name: "fs.read",
|
||||
Description: "Read a file",
|
||||
Parameters: json.RawMessage(`{"type":"object","properties":{"path":{"type":"string"}}}`),
|
||||
},
|
||||
}
|
||||
|
||||
tools := translateTools(defs)
|
||||
if len(tools) != 2 {
|
||||
t.Fatalf("len(tools) = %d, want 2", len(tools))
|
||||
}
|
||||
|
||||
if tools[0].Type != "function" {
|
||||
t.Errorf("tools[0].Type = %q, want function", tools[0].Type)
|
||||
}
|
||||
if tools[0].Function.Name != "bash" {
|
||||
t.Errorf("tools[0].Function.Name = %q", tools[0].Function.Name)
|
||||
}
|
||||
if tools[0].Function.Description != "Run a bash command" {
|
||||
t.Errorf("tools[0].Function.Description = %q", tools[0].Function.Description)
|
||||
}
|
||||
if tools[0].Function.Parameters == nil {
|
||||
t.Error("tools[0].Function.Parameters should not be nil")
|
||||
}
|
||||
// Verify the parameters were correctly unmarshaled
|
||||
if _, ok := tools[0].Function.Parameters["type"]; !ok {
|
||||
t.Error("tools[0].Function.Parameters missing 'type' key")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTranslateTools_Empty(t *testing.T) {
|
||||
tools := translateTools(nil)
|
||||
if tools != nil {
|
||||
t.Errorf("translateTools(nil) should return nil, got %v", tools)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTranslateFinishReason(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
reason *chat.FinishReason
|
||||
want message.StopReason
|
||||
}{
|
||||
{"nil", nil, ""},
|
||||
{"stop", ptr(chat.FinishReasonStop), message.StopEndTurn},
|
||||
{"tool_calls", ptr(chat.FinishReasonToolCalls), message.StopToolUse},
|
||||
{"length", ptr(chat.FinishReasonLength), message.StopMaxTokens},
|
||||
{"model_length", ptr(chat.FinishReasonModelLength), message.StopMaxTokens},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := translateFinishReason(tt.reason)
|
||||
if got != tt.want {
|
||||
t.Errorf("translateFinishReason() = %q, want %q", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTranslateUsage(t *testing.T) {
|
||||
u := &chat.UsageInfo{
|
||||
PromptTokens: 100,
|
||||
CompletionTokens: 50,
|
||||
TotalTokens: 150,
|
||||
}
|
||||
result := translateUsage(u)
|
||||
if result.InputTokens != 100 {
|
||||
t.Errorf("InputTokens = %d, want 100", result.InputTokens)
|
||||
}
|
||||
if result.OutputTokens != 50 {
|
||||
t.Errorf("OutputTokens = %d, want 50", result.OutputTokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTranslateUsage_Nil(t *testing.T) {
|
||||
result := translateUsage(nil)
|
||||
if result != nil {
|
||||
t.Error("translateUsage(nil) should return nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTranslateRequest(t *testing.T) {
|
||||
temp := 0.7
|
||||
req := provider.Request{
|
||||
Model: "mistral-large-latest",
|
||||
SystemPrompt: "you are helpful",
|
||||
Messages: []message.Message{
|
||||
message.NewSystemText("you are helpful"),
|
||||
message.NewUserText("hello"),
|
||||
},
|
||||
Tools: []provider.ToolDefinition{
|
||||
{Name: "bash", Description: "Run command", Parameters: json.RawMessage(`{"type":"object"}`)},
|
||||
},
|
||||
MaxTokens: 4096,
|
||||
Temperature: &temp,
|
||||
}
|
||||
|
||||
cr := translateRequest(req)
|
||||
|
||||
if cr.Model != "mistral-large-latest" {
|
||||
t.Errorf("Model = %q", cr.Model)
|
||||
}
|
||||
if len(cr.Messages) != 2 {
|
||||
t.Errorf("len(Messages) = %d, want 2", len(cr.Messages))
|
||||
}
|
||||
if len(cr.Tools) != 1 {
|
||||
t.Errorf("len(Tools) = %d, want 1", len(cr.Tools))
|
||||
}
|
||||
if cr.MaxTokens == nil || *cr.MaxTokens != 4096 {
|
||||
t.Errorf("MaxTokens = %v", cr.MaxTokens)
|
||||
}
|
||||
if cr.Temperature == nil || *cr.Temperature != 0.7 {
|
||||
t.Errorf("Temperature = %v", cr.Temperature)
|
||||
}
|
||||
}
|
||||
|
||||
// ptr returns a pointer to a fresh copy of v, for building pointer-typed
// fixtures from constants.
func ptr[T any](v T) *T {
	p := new(T)
	*p = v
	return p
}
|
||||
@@ -20,6 +20,7 @@ type Request struct {
|
||||
TopK *int64
|
||||
StopSequences []string
|
||||
Thinking *ThinkingConfig
|
||||
ResponseFormat *ResponseFormat
|
||||
}
|
||||
|
||||
// ToolDefinition is the provider-agnostic tool schema.
|
||||
@@ -34,6 +35,50 @@ type ThinkingConfig struct {
|
||||
BudgetTokens int64
|
||||
}
|
||||
|
||||
// ResponseFormat controls the output format.
|
||||
type ResponseFormat struct {
|
||||
Type ResponseFormatType
|
||||
JSONSchema *JSONSchema // only used when Type == ResponseJSON
|
||||
}
|
||||
|
||||
// ResponseFormatType names an output format that can be requested.
type ResponseFormatType string

// Supported response formats.
const (
	ResponseText ResponseFormatType = "text"
	ResponseJSON ResponseFormatType = "json_object"
)
|
||||
|
||||
// JSONSchema describes the schema that structured JSON output should follow.
type JSONSchema struct {
	// Name identifies the schema.
	Name string `json:"name"`
	// Description is optional and omitted from JSON when empty.
	Description string `json:"description,omitempty"`
	// Schema is the raw JSON schema document.
	Schema json.RawMessage `json:"schema"`
	// Strict is passed through to the provider; omitted from JSON when false.
	Strict bool `json:"strict,omitempty"`
}
|
||||
|
||||
// Capabilities lists the features and limits of a model.
type Capabilities struct {
	ToolUse       bool `json:"tool_use"`       // supports tool/function calling
	JSONOutput    bool `json:"json_output"`    // supports JSON output
	Thinking      bool `json:"thinking"`       // supports thinking
	Vision        bool `json:"vision"`         // supports vision input
	ContextWindow int  `json:"context_window"` // context window size
	MaxOutput     int  `json:"max_output"`     // maximum output size
}
|
||||
|
||||
// ModelInfo describes a model available from a provider.
|
||||
type ModelInfo struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Provider string `json:"provider"`
|
||||
Capabilities Capabilities `json:"capabilities"`
|
||||
}
|
||||
|
||||
// SupportsTools returns true if the model supports tool/function calling.
|
||||
func (m ModelInfo) SupportsTools() bool {
|
||||
return m.Capabilities.ToolUse
|
||||
}
|
||||
|
||||
// Provider is the core abstraction over all LLM backends.
|
||||
type Provider interface {
|
||||
// Stream initiates a streaming request and returns an event stream.
|
||||
@@ -41,4 +86,10 @@ type Provider interface {
|
||||
|
||||
// Name returns the provider identifier (e.g., "mistral", "anthropic").
|
||||
Name() string
|
||||
|
||||
// Models returns available models with their capabilities.
|
||||
Models(ctx context.Context) ([]ModelInfo, error)
|
||||
|
||||
// DefaultModel returns the default model ID for this provider.
|
||||
DefaultModel() string
|
||||
}
|
||||
|
||||
@@ -19,8 +19,10 @@ func (m *mockProvider) Stream(_ context.Context, _ Request) (stream.Stream, erro
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockProvider) Name() string {
|
||||
return m.name
|
||||
func (m *mockProvider) Name() string { return m.name }
|
||||
func (m *mockProvider) DefaultModel() string { return "mock-model" }
|
||||
func (m *mockProvider) Models(_ context.Context) ([]ModelInfo, error) {
|
||||
return []ModelInfo{{ID: "mock-model", Name: "mock-model", Provider: m.name}}, nil
|
||||
}
|
||||
|
||||
func TestRegistry_RegisterAndCreate(t *testing.T) {
|
||||
|
||||
@@ -68,7 +68,13 @@ func (a *Accumulator) Apply(e Event) {
|
||||
}
|
||||
|
||||
case EventToolCallDone:
|
||||
if tc, ok := a.toolCalls[e.ToolCallID]; ok {
|
||||
tc, ok := a.toolCalls[e.ToolCallID]
|
||||
if !ok {
|
||||
// Done without prior Start (e.g., Mistral sends complete tool calls in one chunk)
|
||||
tc = &toolCallAccum{id: e.ToolCallID, name: e.ToolCallName}
|
||||
a.toolCalls[e.ToolCallID] = tc
|
||||
a.toolCallOrder = append(a.toolCallOrder, e.ToolCallID)
|
||||
}
|
||||
if e.Args != nil {
|
||||
// Done event carries authoritative complete args
|
||||
tc.args = e.Args
|
||||
@@ -76,7 +82,6 @@ func (a *Accumulator) Apply(e Event) {
|
||||
// Fall back to accumulated deltas
|
||||
tc.args = []byte(tc.argsBuf.String())
|
||||
}
|
||||
}
|
||||
|
||||
case EventUsage:
|
||||
if e.Usage != nil {
|
||||
|
||||
231
internal/tool/bash/aliases.go
Normal file
231
internal/tool/bash/aliases.go
Normal file
@@ -0,0 +1,231 @@
|
||||
package bash
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const aliasHarvestTimeout = 5 * time.Second
|
||||
|
||||
// AliasMap is a lock-protected table of shell aliases (name → expansion).
type AliasMap struct {
	mu      sync.RWMutex
	aliases map[string]string // alias name → expansion
}

// NewAliasMap returns an empty AliasMap that is ready for use.
func NewAliasMap() *AliasMap {
	return &AliasMap{aliases: map[string]string{}}
}

// Get looks up an alias by name. The boolean reports whether the alias
// exists; the string is empty when it does not.
func (m *AliasMap) Get(name string) (string, bool) {
	m.mu.RLock()
	expansion, found := m.aliases[name]
	m.mu.RUnlock()
	return expansion, found
}

// Len reports how many aliases are stored.
func (m *AliasMap) Len() int {
	m.mu.RLock()
	count := len(m.aliases)
	m.mu.RUnlock()
	return count
}

// All returns a snapshot of every alias. The returned map is a copy, so
// callers may modify it without affecting the AliasMap.
func (m *AliasMap) All() map[string]string {
	m.mu.RLock()
	defer m.mu.RUnlock()

	snapshot := make(map[string]string, len(m.aliases))
	for name, expansion := range m.aliases {
		snapshot[name] = expansion
	}
	return snapshot
}

// ExpandCommand replaces the first word of cmd with its alias expansion when
// that word is a known alias. Only the first word is considered. On a match
// the result is built from the whitespace-trimmed command; when nothing
// matches (or cmd is blank) cmd is returned exactly as given.
func (m *AliasMap) ExpandCommand(cmd string) string {
	trimmed := strings.TrimSpace(cmd)
	if trimmed == "" {
		return cmd
	}

	// Split into the leading word and everything after it (separator included).
	head, tail := trimmed, ""
	if cut := strings.IndexAny(trimmed, " \t"); cut >= 0 {
		head, tail = trimmed[:cut], trimmed[cut:]
	}

	if expansion, known := m.Get(head); known {
		return expansion + tail
	}
	return cmd
}
|
||||
|
||||
// HarvestAliases spawns the user's shell once to collect alias definitions.
|
||||
// Supports bash, zsh, and fish. Falls back gracefully for unknown shells.
|
||||
// Safe: only reads alias text definitions, never sources them in execution context.
|
||||
func HarvestAliases(ctx context.Context) (*AliasMap, error) {
|
||||
shell := os.Getenv("SHELL")
|
||||
if shell == "" {
|
||||
shell = "/bin/bash"
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, aliasHarvestTimeout)
|
||||
defer cancel()
|
||||
|
||||
// Build the alias dump command based on shell type
|
||||
shellBase := shellBaseName(shell)
|
||||
aliasCmd := aliasCommandFor(shellBase)
|
||||
|
||||
// -i: interactive (loads rc files), -c: run command then exit
|
||||
cmd := exec.CommandContext(ctx, shell, "-ic", aliasCmd)
|
||||
// Prevent the interactive shell from reading actual stdin
|
||||
cmd.Stdin = nil
|
||||
// Suppress stderr (shell startup warnings like zsh's "can't change option: zle")
|
||||
cmd.Stderr = nil
|
||||
|
||||
// Use Output() but don't fail on non-zero exit — zsh often exits with
|
||||
// errors from zle/prompt setup while still producing valid alias output
|
||||
output, err := cmd.Output()
|
||||
if len(output) == 0 && err != nil {
|
||||
return NewAliasMap(), fmt.Errorf("alias harvest (%s): %w", shellBase, err)
|
||||
}
|
||||
// If we got output, parse it regardless of exit code
|
||||
|
||||
if shellBase == "fish" {
|
||||
return ParseFishAliases(string(output))
|
||||
}
|
||||
return ParseAliases(string(output))
|
||||
}
|
||||
|
||||
// shellBaseName returns whatever follows the last '/' in shell
// (e.g. "/bin/zsh" → "zsh"). A value without any '/' is returned unchanged,
// and a value ending in '/' yields the empty string.
func shellBaseName(shell string) string {
	// LastIndexByte yields -1 when there is no slash, so +1 lands on index 0.
	return shell[strings.LastIndexByte(shell, '/')+1:]
}
|
||||
|
||||
// aliasCommandFor picks the shell command used to dump alias definitions.
//
// bash, sh, dash and ash get `alias -p`. Every other shell gets a plain
// `alias`: that covers fish (prints `alias name 'expansion'`), zsh (where
// `alias -p` produces nothing and `alias` prints name=value) and, as a best
// effort, unknown shells. In all cases stderr is discarded and a trailing
// `true` forces a zero exit status.
func aliasCommandFor(shell string) string {
	switch shell {
	case "bash", "sh", "dash", "ash":
		return "alias -p 2>/dev/null; true"
	}
	return "alias 2>/dev/null; true"
}
|
||||
|
||||
// ParseFishAliases parses fish shell alias output.
|
||||
// Fish format: alias name 'expansion' or alias name "expansion"
|
||||
func ParseFishAliases(output string) (*AliasMap, error) {
|
||||
m := NewAliasMap()
|
||||
|
||||
for _, line := range strings.Split(output, "\n") {
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" || !strings.HasPrefix(line, "alias ") {
|
||||
continue
|
||||
}
|
||||
// Remove "alias " prefix
|
||||
rest := strings.TrimPrefix(line, "alias ")
|
||||
|
||||
// Split: name 'expansion' or name "expansion" or name expansion
|
||||
spaceIdx := strings.IndexByte(rest, ' ')
|
||||
if spaceIdx == -1 {
|
||||
continue
|
||||
}
|
||||
|
||||
name := rest[:spaceIdx]
|
||||
expansion := strings.TrimSpace(rest[spaceIdx+1:])
|
||||
expansion = stripQuotes(expansion)
|
||||
|
||||
if name == "" || expansion == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
if v := ValidateCommand(expansion); v != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
m.aliases[name] = expansion
|
||||
m.mu.Unlock()
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// ParseAliases parses the output of `alias -p` into an AliasMap.
|
||||
// Each line is: alias name='expansion' (bash) or name=expansion (zsh)
|
||||
func ParseAliases(output string) (*AliasMap, error) {
|
||||
m := NewAliasMap()
|
||||
|
||||
for _, line := range strings.Split(output, "\n") {
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Strip "alias " prefix if present (bash format)
|
||||
line = strings.TrimPrefix(line, "alias ")
|
||||
|
||||
// Split on first '='
|
||||
eqIdx := strings.Index(line, "=")
|
||||
if eqIdx == -1 {
|
||||
continue
|
||||
}
|
||||
|
||||
name := line[:eqIdx]
|
||||
expansion := line[eqIdx+1:]
|
||||
|
||||
// Strip surrounding quotes from expansion
|
||||
expansion = stripQuotes(expansion)
|
||||
|
||||
if name == "" || expansion == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Security: validate the expansion doesn't contain dangerous patterns
|
||||
if v := ValidateCommand(expansion); v != nil {
|
||||
// Skip aliases with dangerous expansions
|
||||
continue
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
m.aliases[name] = expansion
|
||||
m.mu.Unlock()
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// stripQuotes drops one pair of enclosing quotes from s when s both starts and
// ends with the same quote character (' or "). Anything else — including
// strings shorter than two bytes and mismatched quotes — is returned as is.
func stripQuotes(s string) string {
	if n := len(s); n >= 2 {
		first, last := s[0], s[n-1]
		if first == last && (first == '\'' || first == '"') {
			return s[1 : n-1]
		}
	}
	return s
}
|
||||
288
internal/tool/bash/aliases_test.go
Normal file
288
internal/tool/bash/aliases_test.go
Normal file
@@ -0,0 +1,288 @@
|
||||
package bash
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParseAliases_BashFormat(t *testing.T) {
|
||||
output := `alias gs='git status'
|
||||
alias ll='ls -la --color=auto'
|
||||
alias gco='git checkout'
|
||||
alias ..='cd ..'
|
||||
`
|
||||
m, err := ParseAliases(output)
|
||||
if err != nil {
|
||||
t.Fatalf("ParseAliases: %v", err)
|
||||
}
|
||||
|
||||
if m.Len() != 4 {
|
||||
t.Errorf("Len() = %d, want 4", m.Len())
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name, want string
|
||||
}{
|
||||
{"gs", "git status"},
|
||||
{"ll", "ls -la --color=auto"},
|
||||
{"gco", "git checkout"},
|
||||
{"..", "cd .."},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
got, ok := m.Get(tt.name)
|
||||
if !ok {
|
||||
t.Errorf("alias %q not found", tt.name)
|
||||
continue
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Errorf("alias %q = %q, want %q", tt.name, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseAliases_ZshFormat(t *testing.T) {
|
||||
// zsh alias -p may omit 'alias ' prefix
|
||||
output := `gs='git status'
|
||||
ll='ls -la'
|
||||
`
|
||||
m, err := ParseAliases(output)
|
||||
if err != nil {
|
||||
t.Fatalf("ParseAliases: %v", err)
|
||||
}
|
||||
|
||||
got, ok := m.Get("gs")
|
||||
if !ok || got != "git status" {
|
||||
t.Errorf("gs = %q, %v", got, ok)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseAliases_DoubleQuotes(t *testing.T) {
|
||||
output := `alias gs="git status"
|
||||
`
|
||||
m, _ := ParseAliases(output)
|
||||
|
||||
got, ok := m.Get("gs")
|
||||
if !ok || got != "git status" {
|
||||
t.Errorf("gs = %q, %v", got, ok)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseAliases_SkipsDangerousExpansions(t *testing.T) {
|
||||
output := `alias safe='ls -la'
|
||||
alias danger='echo $(whoami)'
|
||||
alias backtick='echo ` + "`" + `date` + "`" + `'
|
||||
alias ifshack='IFS=: read a b'
|
||||
`
|
||||
m, _ := ParseAliases(output)
|
||||
|
||||
if _, ok := m.Get("safe"); !ok {
|
||||
t.Error("safe alias should be kept")
|
||||
}
|
||||
if _, ok := m.Get("danger"); ok {
|
||||
t.Error("danger alias ($()) should be filtered")
|
||||
}
|
||||
if _, ok := m.Get("backtick"); ok {
|
||||
t.Error("backtick alias should be filtered")
|
||||
}
|
||||
if _, ok := m.Get("ifshack"); ok {
|
||||
t.Error("IFS alias should be filtered")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseAliases_EmptyAndMalformed(t *testing.T) {
|
||||
output := `
|
||||
alias gs='git status'
|
||||
|
||||
not a valid line
|
||||
alias =empty_name
|
||||
alias noequals
|
||||
`
|
||||
m, _ := ParseAliases(output)
|
||||
|
||||
if m.Len() != 1 {
|
||||
t.Errorf("Len() = %d, want 1 (only gs)", m.Len())
|
||||
}
|
||||
}
|
||||
|
||||
func TestAliasMap_ExpandCommand(t *testing.T) {
|
||||
m := NewAliasMap()
|
||||
m.mu.Lock()
|
||||
m.aliases["ll"] = "ls -la --color=auto"
|
||||
m.aliases["gs"] = "git status"
|
||||
m.aliases[".."] = "cd .."
|
||||
m.mu.Unlock()
|
||||
|
||||
tests := []struct {
|
||||
input string
|
||||
want string
|
||||
}{
|
||||
// Alias with args
|
||||
{"ll /tmp", "ls -la --color=auto /tmp"},
|
||||
// Alias without args
|
||||
{"gs", "git status"},
|
||||
// Alias with trailing whitespace (trimmed)
|
||||
{"gs ", "git status"},
|
||||
// No alias match — return unchanged
|
||||
{"echo hello", "echo hello"},
|
||||
// Dotdot alias
|
||||
{"..", "cd .."},
|
||||
// Empty command
|
||||
{"", ""},
|
||||
// Only whitespace
|
||||
{" ", " "},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
got := m.ExpandCommand(tt.input)
|
||||
if got != tt.want {
|
||||
t.Errorf("ExpandCommand(%q) = %q, want %q", tt.input, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestAliasMap_ExpandCommand_NoAliases(t *testing.T) {
|
||||
m := NewAliasMap()
|
||||
got := m.ExpandCommand("echo hello")
|
||||
if got != "echo hello" {
|
||||
t.Errorf("ExpandCommand = %q, want unchanged", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAliasMap_All(t *testing.T) {
|
||||
m := NewAliasMap()
|
||||
m.mu.Lock()
|
||||
m.aliases["a"] = "b"
|
||||
m.aliases["c"] = "d"
|
||||
m.mu.Unlock()
|
||||
|
||||
all := m.All()
|
||||
if len(all) != 2 {
|
||||
t.Errorf("len(All()) = %d, want 2", len(all))
|
||||
}
|
||||
// Verify it's a copy
|
||||
all["x"] = "y"
|
||||
if m.Len() != 2 {
|
||||
t.Error("All() should return a copy, not a reference")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripQuotes(t *testing.T) {
|
||||
tests := []struct {
|
||||
input, want string
|
||||
}{
|
||||
{"'hello'", "hello"},
|
||||
{`"hello"`, "hello"},
|
||||
{"hello", "hello"},
|
||||
{"'h'", "h"},
|
||||
{"''", ""},
|
||||
{`""`, ""},
|
||||
{"'mismatched\"", "'mismatched\""},
|
||||
{"x", "x"},
|
||||
{"", ""},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
got := stripQuotes(tt.input)
|
||||
if got != tt.want {
|
||||
t.Errorf("stripQuotes(%q) = %q, want %q", tt.input, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseFishAliases(t *testing.T) {
|
||||
output := `alias gs 'git status'
|
||||
alias ll 'ls -la'
|
||||
alias gco "git checkout"
|
||||
`
|
||||
m, err := ParseFishAliases(output)
|
||||
if err != nil {
|
||||
t.Fatalf("ParseFishAliases: %v", err)
|
||||
}
|
||||
|
||||
if m.Len() != 3 {
|
||||
t.Errorf("Len() = %d, want 3", m.Len())
|
||||
}
|
||||
|
||||
got, ok := m.Get("gs")
|
||||
if !ok || got != "git status" {
|
||||
t.Errorf("gs = %q, %v", got, ok)
|
||||
}
|
||||
got, ok = m.Get("gco")
|
||||
if !ok || got != "git checkout" {
|
||||
t.Errorf("gco = %q, %v", got, ok)
|
||||
}
|
||||
}
|
||||
|
||||
func TestShellBaseName(t *testing.T) {
|
||||
tests := []struct {
|
||||
input, want string
|
||||
}{
|
||||
{"/bin/bash", "bash"},
|
||||
{"/usr/bin/zsh", "zsh"},
|
||||
{"/usr/local/bin/fish", "fish"},
|
||||
{"bash", "bash"},
|
||||
{"/bin/sh", "sh"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
got := shellBaseName(tt.input)
|
||||
if got != tt.want {
|
||||
t.Errorf("shellBaseName(%q) = %q, want %q", tt.input, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestAliasCommandFor(t *testing.T) {
|
||||
tests := []struct {
|
||||
shell string
|
||||
want string
|
||||
}{
|
||||
{"bash", "alias -p 2>/dev/null; true"},
|
||||
{"zsh", "alias 2>/dev/null; true"},
|
||||
{"fish", "alias 2>/dev/null; true"},
|
||||
{"sh", "alias -p 2>/dev/null; true"},
|
||||
{"unknown", "alias 2>/dev/null; true"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
got := aliasCommandFor(tt.shell)
|
||||
if got != tt.want {
|
||||
t.Errorf("aliasCommandFor(%q) = %q, want %q", tt.shell, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestHarvestAliases_Integration(t *testing.T) {
|
||||
// This actually runs the user's shell — skip in CI
|
||||
if testing.Short() {
|
||||
t.Skip("skipping alias harvest in short mode")
|
||||
}
|
||||
|
||||
m, err := HarvestAliases(context.Background())
|
||||
if err != nil {
|
||||
// Non-fatal: harvesting may fail in some environments
|
||||
t.Logf("HarvestAliases: %v (non-fatal)", err)
|
||||
}
|
||||
t.Logf("Harvested %d aliases", m.Len())
|
||||
for name, exp := range m.All() {
|
||||
t.Logf(" %s → %s", name, exp)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBashTool_WithAliases(t *testing.T) {
|
||||
aliases := NewAliasMap()
|
||||
aliases.mu.Lock()
|
||||
aliases.aliases["ll"] = "ls -la"
|
||||
aliases.mu.Unlock()
|
||||
|
||||
b := New(WithAliases(aliases))
|
||||
|
||||
// "ll /tmp" should expand to "ls -la /tmp" and execute
|
||||
result, err := b.Execute(context.Background(), []byte(`{"command":"ll /tmp"}`))
|
||||
if err != nil {
|
||||
t.Fatalf("Execute: %v", err)
|
||||
}
|
||||
// Should produce output (ls -la /tmp lists files)
|
||||
if result.Output == "" {
|
||||
t.Error("expected output from expanded alias")
|
||||
}
|
||||
if result.Metadata["blocked"] == true {
|
||||
t.Error("expanded alias should not be blocked")
|
||||
}
|
||||
}
|
||||
140
internal/tool/bash/bash.go
Normal file
140
internal/tool/bash/bash.go
Normal file
@@ -0,0 +1,140 @@
|
||||
package bash
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"somegit.dev/Owlibou/gnoma/internal/tool"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultTimeout = 30 * time.Second
|
||||
toolName = "bash"
|
||||
)
|
||||
|
||||
var parameterSchema = json.RawMessage(`{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"command": {
|
||||
"type": "string",
|
||||
"description": "The bash command to execute"
|
||||
},
|
||||
"timeout": {
|
||||
"type": "integer",
|
||||
"description": "Timeout in seconds (default 30)"
|
||||
}
|
||||
},
|
||||
"required": ["command"]
|
||||
}`)
|
||||
|
||||
// Tool executes bash commands.
|
||||
type Tool struct {
|
||||
timeout time.Duration
|
||||
workingDir string
|
||||
aliases *AliasMap
|
||||
}
|
||||
|
||||
type Option func(*Tool)
|
||||
|
||||
func WithTimeout(d time.Duration) Option {
|
||||
return func(t *Tool) { t.timeout = d }
|
||||
}
|
||||
|
||||
func WithWorkingDir(dir string) Option {
|
||||
return func(t *Tool) { t.workingDir = dir }
|
||||
}
|
||||
|
||||
func WithAliases(aliases *AliasMap) Option {
|
||||
return func(t *Tool) { t.aliases = aliases }
|
||||
}
|
||||
|
||||
// New creates a bash tool.
|
||||
func New(opts ...Option) *Tool {
|
||||
t := &Tool{timeout: defaultTimeout}
|
||||
for _, opt := range opts {
|
||||
opt(t)
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
func (t *Tool) Name() string { return toolName }
|
||||
func (t *Tool) Description() string { return "Execute a bash command and return its output" }
|
||||
func (t *Tool) Parameters() json.RawMessage { return parameterSchema }
|
||||
func (t *Tool) IsReadOnly() bool { return false }
|
||||
func (t *Tool) IsDestructive() bool { return true }
|
||||
|
||||
type bashArgs struct {
|
||||
Command string `json:"command"`
|
||||
Timeout int `json:"timeout,omitempty"`
|
||||
}
|
||||
|
||||
func (t *Tool) Execute(ctx context.Context, args json.RawMessage) (tool.Result, error) {
|
||||
var a bashArgs
|
||||
if err := json.Unmarshal(args, &a); err != nil {
|
||||
return tool.Result{}, fmt.Errorf("bash: invalid args: %w", err)
|
||||
}
|
||||
|
||||
if a.Command == "" {
|
||||
return tool.Result{}, fmt.Errorf("bash: empty command")
|
||||
}
|
||||
|
||||
// Expand aliases (first word only, matching bash behavior)
|
||||
command := a.Command
|
||||
if t.aliases != nil {
|
||||
command = t.aliases.ExpandCommand(command)
|
||||
}
|
||||
|
||||
// Security validation runs on the expanded command
|
||||
if violation := ValidateCommand(command); violation != nil {
|
||||
return tool.Result{
|
||||
Output: fmt.Sprintf("Command blocked: %s", violation.Message),
|
||||
Metadata: map[string]any{"blocked": true, "check": int(violation.Check)},
|
||||
}, nil
|
||||
}
|
||||
|
||||
timeout := t.timeout
|
||||
if a.Timeout > 0 {
|
||||
timeout = time.Duration(a.Timeout) * time.Second
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
|
||||
cmd := exec.CommandContext(ctx, "bash", "-c", command)
|
||||
if t.workingDir != "" {
|
||||
cmd.Dir = t.workingDir
|
||||
}
|
||||
|
||||
output, err := cmd.CombinedOutput()
|
||||
exitCode := 0
|
||||
|
||||
if err != nil {
|
||||
// Check timeout first — context deadline may also produce an ExitError
|
||||
if ctx.Err() == context.DeadlineExceeded {
|
||||
return tool.Result{
|
||||
Output: fmt.Sprintf("Command timed out after %s\n%s", timeout, strings.TrimRight(string(output), "\n")),
|
||||
Metadata: map[string]any{"exit_code": -1, "timeout": true},
|
||||
}, nil
|
||||
}
|
||||
if exitErr, ok := err.(*exec.ExitError); ok {
|
||||
exitCode = exitErr.ExitCode()
|
||||
} else {
|
||||
return tool.Result{}, fmt.Errorf("bash: exec failed: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
result := tool.Result{
|
||||
Output: strings.TrimRight(string(output), "\n"),
|
||||
Metadata: map[string]any{"exit_code": exitCode},
|
||||
}
|
||||
|
||||
if exitCode != 0 {
|
||||
result.Output = fmt.Sprintf("Exit code %d\n%s", exitCode, result.Output)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
135
internal/tool/bash/bash_test.go
Normal file
135
internal/tool/bash/bash_test.go
Normal file
@@ -0,0 +1,135 @@
|
||||
package bash
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestBashTool_Interface(t *testing.T) {
|
||||
b := New()
|
||||
if b.Name() != "bash" {
|
||||
t.Errorf("Name() = %q", b.Name())
|
||||
}
|
||||
if b.IsReadOnly() {
|
||||
t.Error("bash should not be read-only")
|
||||
}
|
||||
if !b.IsDestructive() {
|
||||
t.Error("bash should be destructive")
|
||||
}
|
||||
if b.Parameters() == nil {
|
||||
t.Error("Parameters() should not be nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBashTool_Echo(t *testing.T) {
|
||||
b := New()
|
||||
result, err := b.Execute(context.Background(), json.RawMessage(`{"command":"echo hello world"}`))
|
||||
if err != nil {
|
||||
t.Fatalf("Execute: %v", err)
|
||||
}
|
||||
if result.Output != "hello world" {
|
||||
t.Errorf("Output = %q, want %q", result.Output, "hello world")
|
||||
}
|
||||
if result.Metadata["exit_code"] != 0 {
|
||||
t.Errorf("exit_code = %v, want 0", result.Metadata["exit_code"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestBashTool_ExitCode(t *testing.T) {
|
||||
b := New()
|
||||
result, err := b.Execute(context.Background(), json.RawMessage(`{"command":"exit 42"}`))
|
||||
if err != nil {
|
||||
t.Fatalf("Execute: %v", err)
|
||||
}
|
||||
if result.Metadata["exit_code"] != 42 {
|
||||
t.Errorf("exit_code = %v, want 42", result.Metadata["exit_code"])
|
||||
}
|
||||
if !strings.HasPrefix(result.Output, "Exit code 42") {
|
||||
t.Errorf("Output = %q, should start with exit code", result.Output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBashTool_Timeout(t *testing.T) {
|
||||
b := New(WithTimeout(100 * time.Millisecond))
|
||||
result, err := b.Execute(context.Background(), json.RawMessage(`{"command":"sleep 10"}`))
|
||||
if err != nil {
|
||||
t.Fatalf("Execute: %v", err)
|
||||
}
|
||||
if result.Metadata["timeout"] != true {
|
||||
t.Error("should have timed out")
|
||||
}
|
||||
if !strings.Contains(result.Output, "timed out") {
|
||||
t.Errorf("Output = %q, should mention timeout", result.Output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBashTool_CustomTimeout(t *testing.T) {
|
||||
b := New(WithTimeout(30 * time.Second))
|
||||
// Args override the default timeout
|
||||
result, err := b.Execute(context.Background(), json.RawMessage(`{"command":"sleep 10","timeout":1}`))
|
||||
if err != nil {
|
||||
t.Fatalf("Execute: %v", err)
|
||||
}
|
||||
if result.Metadata["timeout"] != true {
|
||||
t.Error("should have timed out with custom 1s timeout")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBashTool_InvalidArgs(t *testing.T) {
|
||||
b := New()
|
||||
_, err := b.Execute(context.Background(), json.RawMessage(`not json`))
|
||||
if err == nil {
|
||||
t.Error("expected error for invalid JSON")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBashTool_EmptyCommand(t *testing.T) {
|
||||
b := New()
|
||||
_, err := b.Execute(context.Background(), json.RawMessage(`{"command":""}`))
|
||||
if err == nil {
|
||||
t.Error("expected error for empty command")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBashTool_SecurityBlock(t *testing.T) {
|
||||
b := New()
|
||||
|
||||
// Command substitution should be blocked
|
||||
result, err := b.Execute(context.Background(), json.RawMessage(`{"command":"echo $(whoami)"}`))
|
||||
if err != nil {
|
||||
t.Fatalf("Execute: %v", err)
|
||||
}
|
||||
if result.Metadata["blocked"] != true {
|
||||
t.Error("command with $() should be blocked")
|
||||
}
|
||||
if !strings.Contains(result.Output, "blocked") {
|
||||
t.Errorf("Output = %q, should mention blocked", result.Output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBashTool_WorkingDir(t *testing.T) {
|
||||
b := New(WithWorkingDir(t.TempDir()))
|
||||
result, err := b.Execute(context.Background(), json.RawMessage(`{"command":"pwd"}`))
|
||||
if err != nil {
|
||||
t.Fatalf("Execute: %v", err)
|
||||
}
|
||||
if result.Output == "" {
|
||||
t.Error("pwd should produce output")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBashTool_ContextCancellation(t *testing.T) {
|
||||
b := New()
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel() // cancel immediately
|
||||
|
||||
_, err := b.Execute(ctx, json.RawMessage(`{"command":"echo hello"}`))
|
||||
// Should either return an error or a timeout result
|
||||
if err == nil {
|
||||
// That's ok too — context cancellation is best-effort for fast commands
|
||||
return
|
||||
}
|
||||
}
|
||||
206
internal/tool/bash/security.go
Normal file
206
internal/tool/bash/security.go
Normal file
@@ -0,0 +1,206 @@
|
||||
package bash
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
// SecurityCheck is the numeric identifier of one validation check.
type SecurityCheck int

// Check identifiers, numbered from 1.
const (
	CheckIncomplete       SecurityCheck = iota + 1 // fragments, trailing operators
	CheckMetacharacters                            // ; | & $ ` < >
	CheckCmdSubstitution                           // $(), ``, ${}
	CheckRedirection                               // < > >> etc.
	CheckDangerousVars                             // IFS, PATH manipulation
	CheckNewlineInjection                          // embedded newlines
	CheckControlChars                              // ASCII 00-1F (except \n \t \r)
)

// SecurityViolation records which check rejected a command and why.
type SecurityViolation struct {
	Check   SecurityCheck
	Message string
}

// Error renders the violation as "bash security check <n>: <message>".
func (v SecurityViolation) Error() string {
	id, reason := v.Check, v.Message
	return fmt.Sprintf("bash security check %d: %s", id, reason)
}
|
||||
|
||||
// ValidateCommand runs the 7 critical security checks against a command string.
|
||||
// Returns nil if all checks pass, or the first violation found.
|
||||
func ValidateCommand(cmd string) *SecurityViolation {
|
||||
if strings.TrimSpace(cmd) == "" {
|
||||
return &SecurityViolation{Check: CheckIncomplete, Message: "empty command"}
|
||||
}
|
||||
|
||||
// Check incomplete on raw command (before trimming) to catch tab-starts
|
||||
if v := checkIncomplete(cmd); v != nil {
|
||||
return v
|
||||
}
|
||||
|
||||
cmd = strings.TrimSpace(cmd)
|
||||
|
||||
if v := checkControlChars(cmd); v != nil {
|
||||
return v
|
||||
}
|
||||
if v := checkNewlineInjection(cmd); v != nil {
|
||||
return v
|
||||
}
|
||||
if v := checkCmdSubstitution(cmd); v != nil {
|
||||
return v
|
||||
}
|
||||
if v := checkDangerousVars(cmd); v != nil {
|
||||
return v
|
||||
}
|
||||
// Metacharacters and redirection are warnings, not blocks in M1.
|
||||
// The LLM legitimately uses pipes and redirects.
|
||||
// Full compound command parsing (mvdan.cc/sh) comes in M5.
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkIncomplete detects command fragments that shouldn't be executed.
|
||||
func checkIncomplete(cmd string) *SecurityViolation {
|
||||
// Starts with tab (likely a fragment from indented code)
|
||||
if cmd[0] == '\t' {
|
||||
return &SecurityViolation{Check: CheckIncomplete, Message: "command starts with tab (likely a code fragment)"}
|
||||
}
|
||||
// Starts with a flag (no command name)
|
||||
if cmd[0] == '-' {
|
||||
return &SecurityViolation{Check: CheckIncomplete, Message: "command starts with flag (no command name)"}
|
||||
}
|
||||
// Ends with a dangling operator
|
||||
trimmed := strings.TrimRight(cmd, " \t")
|
||||
if len(trimmed) > 0 {
|
||||
last := trimmed[len(trimmed)-1]
|
||||
if last == '|' || last == '&' || last == ';' {
|
||||
return &SecurityViolation{Check: CheckIncomplete, Message: "command ends with dangling operator"}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkControlChars blocks ASCII control characters (0x00-0x1F) except \n and \t.
|
||||
func checkControlChars(cmd string) *SecurityViolation {
|
||||
for i, r := range cmd {
|
||||
if r < 0x20 && r != '\n' && r != '\t' && r != '\r' {
|
||||
return &SecurityViolation{
|
||||
Check: CheckControlChars,
|
||||
Message: fmt.Sprintf("control character U+%04X at position %d", r, i),
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkNewlineInjection blocks commands with embedded newlines.
|
||||
// Newlines in quoted strings are legitimate but rare in single commands.
|
||||
// We allow them inside single/double quotes only.
|
||||
func checkNewlineInjection(cmd string) *SecurityViolation {
|
||||
inSingle := false
|
||||
inDouble := false
|
||||
escaped := false
|
||||
|
||||
for _, r := range cmd {
|
||||
if escaped {
|
||||
escaped = false
|
||||
continue
|
||||
}
|
||||
if r == '\\' && !inSingle {
|
||||
escaped = true
|
||||
continue
|
||||
}
|
||||
if r == '\'' && !inDouble {
|
||||
inSingle = !inSingle
|
||||
continue
|
||||
}
|
||||
if r == '"' && !inSingle {
|
||||
inDouble = !inDouble
|
||||
continue
|
||||
}
|
||||
if r == '\n' && !inSingle && !inDouble {
|
||||
return &SecurityViolation{
|
||||
Check: CheckNewlineInjection,
|
||||
Message: "unquoted newline (potential command injection)",
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkCmdSubstitution blocks $(), ``, and ${} command/variable substitution.
|
||||
// These allow arbitrary code execution within a command.
|
||||
func checkCmdSubstitution(cmd string) *SecurityViolation {
|
||||
inSingle := false
|
||||
escaped := false
|
||||
|
||||
for i, r := range cmd {
|
||||
if escaped {
|
||||
escaped = false
|
||||
continue
|
||||
}
|
||||
if r == '\\' && !inSingle {
|
||||
escaped = true
|
||||
continue
|
||||
}
|
||||
if r == '\'' {
|
||||
inSingle = !inSingle
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip checks inside single quotes (literal)
|
||||
if inSingle {
|
||||
continue
|
||||
}
|
||||
|
||||
if r == '`' {
|
||||
return &SecurityViolation{
|
||||
Check: CheckCmdSubstitution,
|
||||
Message: "backtick command substitution",
|
||||
}
|
||||
}
|
||||
|
||||
if r == '$' && i+1 < len(cmd) {
|
||||
next := rune(cmd[i+1])
|
||||
if next == '(' {
|
||||
return &SecurityViolation{
|
||||
Check: CheckCmdSubstitution,
|
||||
Message: "$() command substitution",
|
||||
}
|
||||
}
|
||||
if next == '{' {
|
||||
return &SecurityViolation{
|
||||
Check: CheckCmdSubstitution,
|
||||
Message: "${} variable expansion",
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkDangerousVars blocks attempts to manipulate IFS or PATH.
|
||||
func checkDangerousVars(cmd string) *SecurityViolation {
|
||||
upper := strings.ToUpper(cmd)
|
||||
dangerousPatterns := []struct {
|
||||
pattern string
|
||||
msg string
|
||||
}{
|
||||
{"IFS=", "IFS variable manipulation"},
|
||||
{"PATH=", "PATH variable manipulation"},
|
||||
}
|
||||
|
||||
for _, p := range dangerousPatterns {
|
||||
idx := strings.Index(upper, p.pattern)
|
||||
if idx == -1 {
|
||||
continue
|
||||
}
|
||||
// Only flag if it's at the start or preceded by whitespace/semicolon
|
||||
if idx == 0 || !unicode.IsLetter(rune(cmd[idx-1])) {
|
||||
return &SecurityViolation{Check: CheckDangerousVars, Message: p.msg}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
182
internal/tool/bash/security_test.go
Normal file
182
internal/tool/bash/security_test.go
Normal file
@@ -0,0 +1,182 @@
|
||||
package bash
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestValidateCommand_Valid(t *testing.T) {
|
||||
valid := []string{
|
||||
"echo hello",
|
||||
"ls -la",
|
||||
"cat /etc/hostname",
|
||||
"go test ./...",
|
||||
"git status",
|
||||
"echo 'hello world'",
|
||||
`echo "hello world"`,
|
||||
"grep -r 'pattern' .",
|
||||
"find . -name '*.go'",
|
||||
}
|
||||
for _, cmd := range valid {
|
||||
if v := ValidateCommand(cmd); v != nil {
|
||||
t.Errorf("ValidateCommand(%q) = %v, want nil", cmd, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateCommand_Empty(t *testing.T) {
|
||||
v := ValidateCommand("")
|
||||
if v == nil {
|
||||
t.Fatal("expected violation for empty command")
|
||||
}
|
||||
if v.Check != CheckIncomplete {
|
||||
t.Errorf("Check = %d, want %d (incomplete)", v.Check, CheckIncomplete)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckIncomplete(t *testing.T) {
|
||||
tests := []struct {
|
||||
cmd string
|
||||
want SecurityCheck
|
||||
}{
|
||||
{"\techo hello", CheckIncomplete}, // tab start
|
||||
{"-flag value", CheckIncomplete}, // flag start
|
||||
{"echo hello |", CheckIncomplete}, // trailing pipe
|
||||
{"echo hello &", CheckIncomplete}, // trailing ampersand
|
||||
{"echo hello ;", CheckIncomplete}, // trailing semicolon
|
||||
}
|
||||
for _, tt := range tests {
|
||||
v := ValidateCommand(tt.cmd)
|
||||
if v == nil {
|
||||
t.Errorf("ValidateCommand(%q) = nil, want check %d", tt.cmd, tt.want)
|
||||
continue
|
||||
}
|
||||
if v.Check != tt.want {
|
||||
t.Errorf("ValidateCommand(%q).Check = %d, want %d", tt.cmd, v.Check, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckControlChars(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
cmd string
|
||||
}{
|
||||
{"null byte", "echo hello\x00world"},
|
||||
{"bell", "echo \x07"},
|
||||
{"backspace", "echo \x08"},
|
||||
{"escape", "echo \x1b[31m"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
v := ValidateCommand(tt.cmd)
|
||||
if v == nil {
|
||||
t.Error("expected violation")
|
||||
return
|
||||
}
|
||||
if v.Check != CheckControlChars {
|
||||
t.Errorf("Check = %d, want %d (control chars)", v.Check, CheckControlChars)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckControlChars_AllowedChars(t *testing.T) {
|
||||
// Tabs and newlines inside quotes are allowed
|
||||
valid := []string{
|
||||
"echo 'hello\tworld'",
|
||||
}
|
||||
for _, cmd := range valid {
|
||||
if v := checkControlChars(cmd); v != nil {
|
||||
t.Errorf("checkControlChars(%q) = %v, want nil", cmd, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckNewlineInjection(t *testing.T) {
|
||||
// Unquoted newline
|
||||
v := checkNewlineInjection("echo hello\nrm -rf /")
|
||||
if v == nil {
|
||||
t.Fatal("expected violation for unquoted newline")
|
||||
}
|
||||
if v.Check != CheckNewlineInjection {
|
||||
t.Errorf("Check = %d, want %d", v.Check, CheckNewlineInjection)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckNewlineInjection_QuotedOK(t *testing.T) {
|
||||
// Newlines inside quotes are fine
|
||||
allowed := []string{
|
||||
"echo 'hello\nworld'",
|
||||
`echo "hello` + "\n" + `world"`,
|
||||
}
|
||||
for _, cmd := range allowed {
|
||||
if v := checkNewlineInjection(cmd); v != nil {
|
||||
t.Errorf("checkNewlineInjection(%q) = %v, want nil", cmd, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckCmdSubstitution(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
cmd string
|
||||
}{
|
||||
{"backtick", "echo `whoami`"},
|
||||
{"dollar paren", "echo $(whoami)"},
|
||||
{"dollar brace", "echo ${HOME}"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
v := ValidateCommand(tt.cmd)
|
||||
if v == nil {
|
||||
t.Error("expected violation")
|
||||
return
|
||||
}
|
||||
if v.Check != CheckCmdSubstitution {
|
||||
t.Errorf("Check = %d, want %d", v.Check, CheckCmdSubstitution)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckCmdSubstitution_SingleQuoteOK(t *testing.T) {
|
||||
// Inside single quotes, everything is literal
|
||||
safe := "echo '$(whoami) and `uname` and ${HOME}'"
|
||||
if v := checkCmdSubstitution(safe); v != nil {
|
||||
t.Errorf("checkCmdSubstitution(%q) = %v, want nil (single-quoted)", safe, v)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckDangerousVars(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
cmd string
|
||||
}{
|
||||
{"IFS at start", "IFS=: read a b"},
|
||||
{"PATH manipulation", "PATH=/tmp:$PATH command"},
|
||||
{"ifs with space prefix", " IFS=x echo hi"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
v := ValidateCommand(tt.cmd)
|
||||
if v == nil {
|
||||
t.Error("expected violation")
|
||||
return
|
||||
}
|
||||
if v.Check != CheckDangerousVars {
|
||||
t.Errorf("Check = %d, want %d", v.Check, CheckDangerousVars)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckDangerousVars_SafeSubstrings(t *testing.T) {
|
||||
// "SWIFT=..." should not trigger PATH check, "TARIFFS=..." should not trigger IFS
|
||||
safe := []string{
|
||||
"echo SWIFT=enabled",
|
||||
"TARIFFS=high echo test",
|
||||
}
|
||||
for _, cmd := range safe {
|
||||
if v := checkDangerousVars(cmd); v != nil {
|
||||
t.Errorf("checkDangerousVars(%q) = %v, want nil", cmd, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
109
internal/tool/fs/edit.go
Normal file
109
internal/tool/fs/edit.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package fs
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"somegit.dev/Owlibou/gnoma/internal/tool"
|
||||
)
|
||||
|
||||
// editToolName is the name reported by EditTool.Name.
const editToolName = "fs.edit"
|
||||
|
||||
// editParams is the JSON Schema returned by EditTool.Parameters. It declares
// three required string properties (path, old_string, new_string) and one
// optional boolean (replace_all).
var editParams = json.RawMessage(`{
	"type": "object",
	"properties": {
		"path": {
			"type": "string",
			"description": "Absolute path to the file to edit"
		},
		"old_string": {
			"type": "string",
			"description": "The exact text to find and replace"
		},
		"new_string": {
			"type": "string",
			"description": "The replacement text"
		},
		"replace_all": {
			"type": "boolean",
			"description": "Replace all occurrences (default false)"
		}
	},
	"required": ["path", "old_string", "new_string"]
}`)
|
||||
|
||||
// EditTool is the fs.edit tool. It carries no state.
type EditTool struct{}
|
||||
|
||||
// NewEditTool returns a pointer to a new EditTool.
func NewEditTool() *EditTool {
	return new(EditTool)
}
|
||||
|
||||
// Name returns the tool name, "fs.edit".
func (t *EditTool) Name() string {
	return editToolName
}
|
||||
// Description returns a one-line summary of what the tool does.
func (t *EditTool) Description() string {
	const summary = "Perform exact string replacement in a file"
	return summary
}
|
||||
// Parameters returns the JSON Schema held in editParams.
func (t *EditTool) Parameters() json.RawMessage {
	return editParams
}
|
||||
// IsReadOnly reports false: Execute rewrites the target file.
func (t *EditTool) IsReadOnly() bool {
	return false
}
|
||||
// IsDestructive always reports false for this tool.
func (t *EditTool) IsDestructive() bool {
	return false
}
|
||||
|
||||
// editArgs is the decoded form of the JSON arguments accepted by
// EditTool.Execute.
type editArgs struct {
	// Path is the file to edit.
	Path string `json:"path"`
	// OldString is the text searched for in the file.
	OldString string `json:"old_string"`
	// NewString is the text that replaces OldString.
	NewString string `json:"new_string"`
	// ReplaceAll replaces every occurrence rather than requiring exactly one.
	ReplaceAll bool `json:"replace_all,omitempty"`
}
|
||||
|
||||
func (t *EditTool) Execute(_ context.Context, args json.RawMessage) (tool.Result, error) {
|
||||
var a editArgs
|
||||
if err := json.Unmarshal(args, &a); err != nil {
|
||||
return tool.Result{}, fmt.Errorf("fs.edit: invalid args: %w", err)
|
||||
}
|
||||
if a.Path == "" {
|
||||
return tool.Result{}, fmt.Errorf("fs.edit: path required")
|
||||
}
|
||||
if a.OldString == a.NewString {
|
||||
return tool.Result{}, fmt.Errorf("fs.edit: old_string and new_string must differ")
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(a.Path)
|
||||
if err != nil {
|
||||
return tool.Result{Output: fmt.Sprintf("Error: %v", err)}, nil
|
||||
}
|
||||
|
||||
content := string(data)
|
||||
|
||||
count := strings.Count(content, a.OldString)
|
||||
if count == 0 {
|
||||
return tool.Result{
|
||||
Output: "Error: old_string not found in file",
|
||||
Metadata: map[string]any{"matches": 0},
|
||||
}, nil
|
||||
}
|
||||
|
||||
if !a.ReplaceAll && count > 1 {
|
||||
return tool.Result{
|
||||
Output: fmt.Sprintf("Error: old_string has %d matches (must be unique, or use replace_all)", count),
|
||||
Metadata: map[string]any{"matches": count},
|
||||
}, nil
|
||||
}
|
||||
|
||||
var newContent string
|
||||
if a.ReplaceAll {
|
||||
newContent = strings.ReplaceAll(content, a.OldString, a.NewString)
|
||||
} else {
|
||||
newContent = strings.Replace(content, a.OldString, a.NewString, 1)
|
||||
}
|
||||
|
||||
if err := os.WriteFile(a.Path, []byte(newContent), 0o644); err != nil {
|
||||
return tool.Result{Output: fmt.Sprintf("Error writing file: %v", err)}, nil
|
||||
}
|
||||
|
||||
replacements := 1
|
||||
if a.ReplaceAll {
|
||||
replacements = count
|
||||
}
|
||||
|
||||
return tool.Result{
|
||||
Output: fmt.Sprintf("Replaced %d occurrence(s) in %s", replacements, a.Path),
|
||||
Metadata: map[string]any{"replacements": replacements, "path": a.Path},
|
||||
}, nil
|
||||
}
|
||||
545
internal/tool/fs/fs_test.go
Normal file
545
internal/tool/fs/fs_test.go
Normal file
@@ -0,0 +1,545 @@
|
||||
package fs
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// --- Read ---
|
||||
|
||||
func TestReadTool_Interface(t *testing.T) {
|
||||
r := NewReadTool()
|
||||
if r.Name() != "fs.read" {
|
||||
t.Errorf("Name() = %q", r.Name())
|
||||
}
|
||||
if !r.IsReadOnly() {
|
||||
t.Error("should be read-only")
|
||||
}
|
||||
if r.IsDestructive() {
|
||||
t.Error("should not be destructive")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadTool_SimpleFile(t *testing.T) {
|
||||
path := writeTestFile(t, "hello\nworld\n")
|
||||
r := NewReadTool()
|
||||
|
||||
result, err := r.Execute(context.Background(), mustJSON(t, readArgs{Path: path}))
|
||||
if err != nil {
|
||||
t.Fatalf("Execute: %v", err)
|
||||
}
|
||||
if !strings.Contains(result.Output, "1\thello") {
|
||||
t.Errorf("Output should contain line-numbered content, got %q", result.Output)
|
||||
}
|
||||
if !strings.Contains(result.Output, "2\tworld") {
|
||||
t.Errorf("Output missing line 2, got %q", result.Output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadTool_WithOffset(t *testing.T) {
|
||||
path := writeTestFile(t, "line1\nline2\nline3\nline4\nline5\n")
|
||||
r := NewReadTool()
|
||||
|
||||
result, err := r.Execute(context.Background(), mustJSON(t, readArgs{Path: path, Offset: 2}))
|
||||
if err != nil {
|
||||
t.Fatalf("Execute: %v", err)
|
||||
}
|
||||
if !strings.Contains(result.Output, "3\tline3") {
|
||||
t.Errorf("Output should start at line 3, got %q", result.Output)
|
||||
}
|
||||
if strings.Contains(result.Output, "1\tline1") {
|
||||
t.Error("Output should not contain line 1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadTool_WithLimit(t *testing.T) {
|
||||
path := writeTestFile(t, "a\nb\nc\nd\ne\n")
|
||||
r := NewReadTool()
|
||||
|
||||
result, err := r.Execute(context.Background(), mustJSON(t, readArgs{Path: path, Limit: 2}))
|
||||
if err != nil {
|
||||
t.Fatalf("Execute: %v", err)
|
||||
}
|
||||
lines := strings.Split(result.Output, "\n")
|
||||
if len(lines) != 2 {
|
||||
t.Errorf("expected 2 lines, got %d: %q", len(lines), result.Output)
|
||||
}
|
||||
if result.Metadata["truncated"] != true {
|
||||
t.Error("should be truncated")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadTool_OffsetPastEnd(t *testing.T) {
|
||||
path := writeTestFile(t, "one\ntwo\n")
|
||||
r := NewReadTool()
|
||||
|
||||
result, err := r.Execute(context.Background(), mustJSON(t, readArgs{Path: path, Offset: 100}))
|
||||
if err != nil {
|
||||
t.Fatalf("Execute: %v", err)
|
||||
}
|
||||
if !strings.Contains(result.Output, "past end") {
|
||||
t.Errorf("Output = %q, should mention past end", result.Output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadTool_FileNotFound(t *testing.T) {
|
||||
r := NewReadTool()
|
||||
result, err := r.Execute(context.Background(), mustJSON(t, readArgs{Path: "/nonexistent/file.txt"}))
|
||||
if err != nil {
|
||||
t.Fatalf("Execute: %v", err)
|
||||
}
|
||||
if !strings.Contains(result.Output, "Error") {
|
||||
t.Errorf("Output = %q, should contain error", result.Output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadTool_EmptyPath(t *testing.T) {
|
||||
r := NewReadTool()
|
||||
_, err := r.Execute(context.Background(), mustJSON(t, readArgs{}))
|
||||
if err == nil {
|
||||
t.Error("expected error for empty path")
|
||||
}
|
||||
}
|
||||
|
||||
// --- Write ---
|
||||
|
||||
func TestWriteTool_Interface(t *testing.T) {
|
||||
w := NewWriteTool()
|
||||
if w.Name() != "fs.write" {
|
||||
t.Errorf("Name() = %q", w.Name())
|
||||
}
|
||||
if w.IsReadOnly() {
|
||||
t.Error("should not be read-only")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteTool_CreateFile(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "test.txt")
|
||||
|
||||
w := NewWriteTool()
|
||||
result, err := w.Execute(context.Background(), mustJSON(t, writeArgs{Path: path, Content: "hello world"}))
|
||||
if err != nil {
|
||||
t.Fatalf("Execute: %v", err)
|
||||
}
|
||||
if !strings.Contains(result.Output, "11 bytes") {
|
||||
t.Errorf("Output = %q", result.Output)
|
||||
}
|
||||
|
||||
data, _ := os.ReadFile(path)
|
||||
if string(data) != "hello world" {
|
||||
t.Errorf("file content = %q", string(data))
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteTool_CreatesParentDirs(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "a", "b", "c", "test.txt")
|
||||
|
||||
w := NewWriteTool()
|
||||
_, err := w.Execute(context.Background(), mustJSON(t, writeArgs{Path: path, Content: "nested"}))
|
||||
if err != nil {
|
||||
t.Fatalf("Execute: %v", err)
|
||||
}
|
||||
|
||||
data, _ := os.ReadFile(path)
|
||||
if string(data) != "nested" {
|
||||
t.Errorf("file content = %q", string(data))
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteTool_OverwriteExisting(t *testing.T) {
|
||||
path := writeTestFile(t, "old content")
|
||||
w := NewWriteTool()
|
||||
|
||||
_, err := w.Execute(context.Background(), mustJSON(t, writeArgs{Path: path, Content: "new content"}))
|
||||
if err != nil {
|
||||
t.Fatalf("Execute: %v", err)
|
||||
}
|
||||
|
||||
data, _ := os.ReadFile(path)
|
||||
if string(data) != "new content" {
|
||||
t.Errorf("file content = %q", string(data))
|
||||
}
|
||||
}
|
||||
|
||||
// --- Edit ---
|
||||
|
||||
func TestEditTool_Interface(t *testing.T) {
|
||||
e := NewEditTool()
|
||||
if e.Name() != "fs.edit" {
|
||||
t.Errorf("Name() = %q", e.Name())
|
||||
}
|
||||
}
|
||||
|
||||
func TestEditTool_SingleReplace(t *testing.T) {
|
||||
path := writeTestFile(t, "hello world")
|
||||
e := NewEditTool()
|
||||
|
||||
result, err := e.Execute(context.Background(), mustJSON(t, editArgs{
|
||||
Path: path, OldString: "world", NewString: "gnoma",
|
||||
}))
|
||||
if err != nil {
|
||||
t.Fatalf("Execute: %v", err)
|
||||
}
|
||||
if !strings.Contains(result.Output, "1 occurrence") {
|
||||
t.Errorf("Output = %q", result.Output)
|
||||
}
|
||||
|
||||
data, _ := os.ReadFile(path)
|
||||
if string(data) != "hello gnoma" {
|
||||
t.Errorf("file content = %q", string(data))
|
||||
}
|
||||
}
|
||||
|
||||
func TestEditTool_ReplaceAll(t *testing.T) {
|
||||
path := writeTestFile(t, "foo bar foo baz foo")
|
||||
e := NewEditTool()
|
||||
|
||||
result, err := e.Execute(context.Background(), mustJSON(t, editArgs{
|
||||
Path: path, OldString: "foo", NewString: "qux", ReplaceAll: true,
|
||||
}))
|
||||
if err != nil {
|
||||
t.Fatalf("Execute: %v", err)
|
||||
}
|
||||
if !strings.Contains(result.Output, "3 occurrence") {
|
||||
t.Errorf("Output = %q", result.Output)
|
||||
}
|
||||
|
||||
data, _ := os.ReadFile(path)
|
||||
if string(data) != "qux bar qux baz qux" {
|
||||
t.Errorf("file content = %q", string(data))
|
||||
}
|
||||
}
|
||||
|
||||
func TestEditTool_NonUniqueWithoutReplaceAll(t *testing.T) {
|
||||
path := writeTestFile(t, "foo foo foo")
|
||||
e := NewEditTool()
|
||||
|
||||
result, err := e.Execute(context.Background(), mustJSON(t, editArgs{
|
||||
Path: path, OldString: "foo", NewString: "bar",
|
||||
}))
|
||||
if err != nil {
|
||||
t.Fatalf("Execute: %v", err)
|
||||
}
|
||||
if !strings.Contains(result.Output, "3 matches") {
|
||||
t.Errorf("Output = %q, should mention multiple matches", result.Output)
|
||||
}
|
||||
|
||||
// File should be unchanged
|
||||
data, _ := os.ReadFile(path)
|
||||
if string(data) != "foo foo foo" {
|
||||
t.Errorf("file should be unchanged, got %q", string(data))
|
||||
}
|
||||
}
|
||||
|
||||
func TestEditTool_NotFound(t *testing.T) {
|
||||
path := writeTestFile(t, "hello world")
|
||||
e := NewEditTool()
|
||||
|
||||
result, err := e.Execute(context.Background(), mustJSON(t, editArgs{
|
||||
Path: path, OldString: "missing", NewString: "replaced",
|
||||
}))
|
||||
if err != nil {
|
||||
t.Fatalf("Execute: %v", err)
|
||||
}
|
||||
if !strings.Contains(result.Output, "not found") {
|
||||
t.Errorf("Output = %q, should mention not found", result.Output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEditTool_SameStrings(t *testing.T) {
|
||||
e := NewEditTool()
|
||||
_, err := e.Execute(context.Background(), mustJSON(t, editArgs{
|
||||
Path: "/tmp/x", OldString: "same", NewString: "same",
|
||||
}))
|
||||
if err == nil {
|
||||
t.Error("expected error when old_string == new_string")
|
||||
}
|
||||
}
|
||||
|
||||
// --- Glob ---
|
||||
|
||||
func TestGlobTool_Interface(t *testing.T) {
|
||||
g := NewGlobTool()
|
||||
if g.Name() != "fs.glob" {
|
||||
t.Errorf("Name() = %q", g.Name())
|
||||
}
|
||||
if !g.IsReadOnly() {
|
||||
t.Error("should be read-only")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGlobTool_MatchFiles(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
os.WriteFile(filepath.Join(dir, "main.go"), []byte("package main"), 0o644)
|
||||
os.WriteFile(filepath.Join(dir, "test.go"), []byte("package main"), 0o644)
|
||||
os.WriteFile(filepath.Join(dir, "readme.md"), []byte("# readme"), 0o644)
|
||||
|
||||
g := NewGlobTool()
|
||||
result, err := g.Execute(context.Background(), mustJSON(t, globArgs{Pattern: "*.go", Path: dir}))
|
||||
if err != nil {
|
||||
t.Fatalf("Execute: %v", err)
|
||||
}
|
||||
|
||||
if result.Metadata["count"] != 2 {
|
||||
t.Errorf("count = %v, want 2", result.Metadata["count"])
|
||||
}
|
||||
if !strings.Contains(result.Output, "main.go") {
|
||||
t.Errorf("Output missing main.go: %q", result.Output)
|
||||
}
|
||||
if strings.Contains(result.Output, "readme.md") {
|
||||
t.Error("Output should not contain readme.md")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGlobTool_NoMatches(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
g := NewGlobTool()
|
||||
|
||||
result, err := g.Execute(context.Background(), mustJSON(t, globArgs{Pattern: "*.xyz", Path: dir}))
|
||||
if err != nil {
|
||||
t.Fatalf("Execute: %v", err)
|
||||
}
|
||||
if !strings.Contains(result.Output, "no matches") {
|
||||
t.Errorf("Output = %q", result.Output)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Grep ---
|
||||
|
||||
func TestGrepTool_Interface(t *testing.T) {
|
||||
g := NewGrepTool()
|
||||
if g.Name() != "fs.grep" {
|
||||
t.Errorf("Name() = %q", g.Name())
|
||||
}
|
||||
if !g.IsReadOnly() {
|
||||
t.Error("should be read-only")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGrepTool_SingleFile(t *testing.T) {
|
||||
path := writeTestFile(t, "hello world\nfoo bar\nhello again\n")
|
||||
g := NewGrepTool()
|
||||
|
||||
result, err := g.Execute(context.Background(), mustJSON(t, grepArgs{Pattern: "hello", Path: path}))
|
||||
if err != nil {
|
||||
t.Fatalf("Execute: %v", err)
|
||||
}
|
||||
if result.Metadata["count"] != 2 {
|
||||
t.Errorf("count = %v, want 2", result.Metadata["count"])
|
||||
}
|
||||
if !strings.Contains(result.Output, "1:hello world") {
|
||||
t.Errorf("Output = %q", result.Output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGrepTool_Directory(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
os.WriteFile(filepath.Join(dir, "a.go"), []byte("func main() {}\nfunc helper() {}"), 0o644)
|
||||
os.WriteFile(filepath.Join(dir, "b.go"), []byte("func test() {}"), 0o644)
|
||||
os.WriteFile(filepath.Join(dir, "c.txt"), []byte("func ignored() {}"), 0o644)
|
||||
|
||||
g := NewGrepTool()
|
||||
|
||||
// Search all files for "func"
|
||||
result, err := g.Execute(context.Background(), mustJSON(t, grepArgs{Pattern: "func", Path: dir}))
|
||||
if err != nil {
|
||||
t.Fatalf("Execute: %v", err)
|
||||
}
|
||||
if result.Metadata["count"].(int) < 3 {
|
||||
t.Errorf("count = %v, want >= 3", result.Metadata["count"])
|
||||
}
|
||||
|
||||
// With glob filter
|
||||
result, err = g.Execute(context.Background(), mustJSON(t, grepArgs{Pattern: "func", Path: dir, Glob: "*.go"}))
|
||||
if err != nil {
|
||||
t.Fatalf("Execute: %v", err)
|
||||
}
|
||||
if strings.Contains(result.Output, "c.txt") {
|
||||
t.Error("should not match .txt files with *.go glob")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGrepTool_Regex(t *testing.T) {
|
||||
path := writeTestFile(t, "error: something failed\nwarning: be careful\nerror: another one\n")
|
||||
g := NewGrepTool()
|
||||
|
||||
result, err := g.Execute(context.Background(), mustJSON(t, grepArgs{Pattern: `^error:`, Path: path}))
|
||||
if err != nil {
|
||||
t.Fatalf("Execute: %v", err)
|
||||
}
|
||||
if result.Metadata["count"] != 2 {
|
||||
t.Errorf("count = %v, want 2", result.Metadata["count"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestGrepTool_InvalidRegex(t *testing.T) {
|
||||
g := NewGrepTool()
|
||||
result, err := g.Execute(context.Background(), mustJSON(t, grepArgs{Pattern: "[invalid", Path: "."}))
|
||||
if err != nil {
|
||||
t.Fatalf("Execute: %v", err)
|
||||
}
|
||||
if !strings.Contains(result.Output, "Invalid regex") {
|
||||
t.Errorf("Output = %q, should mention invalid regex", result.Output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGrepTool_NoMatches(t *testing.T) {
|
||||
path := writeTestFile(t, "hello world\n")
|
||||
g := NewGrepTool()
|
||||
|
||||
result, err := g.Execute(context.Background(), mustJSON(t, grepArgs{Pattern: "zzzzz", Path: path}))
|
||||
if err != nil {
|
||||
t.Fatalf("Execute: %v", err)
|
||||
}
|
||||
if !strings.Contains(result.Output, "no matches") {
|
||||
t.Errorf("Output = %q", result.Output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGrepTool_MaxResults(t *testing.T) {
|
||||
var lines strings.Builder
|
||||
for i := 0; i < 100; i++ {
|
||||
lines.WriteString("match line\n")
|
||||
}
|
||||
path := writeTestFile(t, lines.String())
|
||||
g := NewGrepTool()
|
||||
|
||||
result, err := g.Execute(context.Background(), mustJSON(t, grepArgs{Pattern: "match", Path: path, MaxResults: 5}))
|
||||
if err != nil {
|
||||
t.Fatalf("Execute: %v", err)
|
||||
}
|
||||
if result.Metadata["count"] != 5 {
|
||||
t.Errorf("count = %v, want 5", result.Metadata["count"])
|
||||
}
|
||||
if result.Metadata["truncated"] != true {
|
||||
t.Error("should be truncated")
|
||||
}
|
||||
}
|
||||
|
||||
// --- LS ---
|
||||
|
||||
func TestLSTool_Interface(t *testing.T) {
|
||||
l := NewLSTool()
|
||||
if l.Name() != "fs.ls" {
|
||||
t.Errorf("Name() = %q", l.Name())
|
||||
}
|
||||
if !l.IsReadOnly() {
|
||||
t.Error("should be read-only")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLSTool_ListDirectory(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
os.WriteFile(filepath.Join(dir, "hello.go"), []byte("package main"), 0o644)
|
||||
os.WriteFile(filepath.Join(dir, "readme.md"), []byte("# readme"), 0o644)
|
||||
os.MkdirAll(filepath.Join(dir, "subdir"), 0o755)
|
||||
|
||||
l := NewLSTool()
|
||||
result, err := l.Execute(context.Background(), mustJSON(t, lsArgs{Path: dir}))
|
||||
if err != nil {
|
||||
t.Fatalf("Execute: %v", err)
|
||||
}
|
||||
|
||||
if !strings.Contains(result.Output, "hello.go") {
|
||||
t.Errorf("Output missing hello.go: %q", result.Output)
|
||||
}
|
||||
if !strings.Contains(result.Output, "readme.md") {
|
||||
t.Errorf("Output missing readme.md: %q", result.Output)
|
||||
}
|
||||
if !strings.Contains(result.Output, "subdir") {
|
||||
t.Errorf("Output missing subdir: %q", result.Output)
|
||||
}
|
||||
|
||||
if result.Metadata["files"] != 2 {
|
||||
t.Errorf("files = %v, want 2", result.Metadata["files"])
|
||||
}
|
||||
if result.Metadata["dirs"] != 1 {
|
||||
t.Errorf("dirs = %v, want 1", result.Metadata["dirs"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestLSTool_EmptyDirectory(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
l := NewLSTool()
|
||||
|
||||
result, err := l.Execute(context.Background(), mustJSON(t, lsArgs{Path: dir}))
|
||||
if err != nil {
|
||||
t.Fatalf("Execute: %v", err)
|
||||
}
|
||||
if !strings.Contains(result.Output, "empty directory") {
|
||||
t.Errorf("Output = %q, should mention empty", result.Output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLSTool_DirectoryNotFound(t *testing.T) {
|
||||
l := NewLSTool()
|
||||
result, err := l.Execute(context.Background(), mustJSON(t, lsArgs{Path: "/nonexistent/dir"}))
|
||||
if err != nil {
|
||||
t.Fatalf("Execute: %v", err)
|
||||
}
|
||||
if !strings.Contains(result.Output, "Error") {
|
||||
t.Errorf("Output = %q, should contain error", result.Output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLSTool_ShowsSizes(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
os.WriteFile(filepath.Join(dir, "small.txt"), []byte("hi"), 0o644)
|
||||
|
||||
l := NewLSTool()
|
||||
result, err := l.Execute(context.Background(), mustJSON(t, lsArgs{Path: dir}))
|
||||
if err != nil {
|
||||
t.Fatalf("Execute: %v", err)
|
||||
}
|
||||
// Should show "2B" for a 2-byte file
|
||||
if !strings.Contains(result.Output, "2B") {
|
||||
t.Errorf("Output = %q, should show file size", result.Output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatSize(t *testing.T) {
|
||||
tests := []struct {
|
||||
bytes int64
|
||||
want string
|
||||
}{
|
||||
{0, "0B"},
|
||||
{42, "42B"},
|
||||
{1024, "1.0K"},
|
||||
{1536, "1.5K"},
|
||||
{1048576, "1.0M"},
|
||||
{1073741824, "1.0G"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
got := formatSize(tt.bytes)
|
||||
if got != tt.want {
|
||||
t.Errorf("formatSize(%d) = %q, want %q", tt.bytes, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --- Helpers ---
|
||||
|
||||
// writeTestFile writes content to "test.txt" inside a fresh t.TempDir and
// returns the file's path. The test fails immediately if the write fails.
func writeTestFile(t *testing.T, content string) string {
	t.Helper()

	target := filepath.Join(t.TempDir(), "test.txt")
	err := os.WriteFile(target, []byte(content), 0o644)
	if err != nil {
		t.Fatalf("writeTestFile: %v", err)
	}
	return target
}
|
||||
|
||||
// mustJSON marshals v to JSON and returns the encoded bytes, failing the
// test immediately if marshalling fails.
func mustJSON(t *testing.T, v any) json.RawMessage {
	t.Helper()

	encoded, err := json.Marshal(v)
	if err == nil {
		return encoded
	}
	t.Fatalf("json.Marshal: %v", err)
	return encoded // not reached: Fatalf stops the test
}
|
||||
117
internal/tool/fs/glob.go
Normal file
117
internal/tool/fs/glob.go
Normal file
@@ -0,0 +1,117 @@
|
||||
package fs
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"somegit.dev/Owlibou/gnoma/internal/tool"
|
||||
)
|
||||
|
||||
// globToolName is the name reported by GlobTool.Name.
const globToolName = "fs.glob"
|
||||
|
||||
// globParams is the JSON Schema returned by GlobTool.Parameters. It declares
// a required string property (pattern) and an optional one (path).
var globParams = json.RawMessage(`{
	"type": "object",
	"properties": {
		"pattern": {
			"type": "string",
			"description": "Glob pattern to match files (e.g. **/*.go, src/**/*.ts)"
		},
		"path": {
			"type": "string",
			"description": "Directory to search in (defaults to current directory)"
		}
	},
	"required": ["pattern"]
}`)
|
||||
|
||||
// GlobTool is the fs.glob tool. It carries no state.
type GlobTool struct{}
|
||||
|
||||
// NewGlobTool returns a pointer to a new GlobTool.
func NewGlobTool() *GlobTool {
	return new(GlobTool)
}
|
||||
|
||||
// Name returns the tool name, "fs.glob".
func (t *GlobTool) Name() string {
	return globToolName
}
|
||||
// Description returns a one-line summary of what the tool does.
func (t *GlobTool) Description() string {
	const summary = "Find files matching a glob pattern, sorted by modification time"
	return summary
}
|
||||
// Parameters returns the JSON Schema held in globParams.
func (t *GlobTool) Parameters() json.RawMessage {
	return globParams
}
|
||||
// IsReadOnly always reports true for this tool.
func (t *GlobTool) IsReadOnly() bool {
	return true
}
|
||||
// IsDestructive always reports false for this tool.
func (t *GlobTool) IsDestructive() bool {
	return false
}
|
||||
|
||||
// globArgs is the decoded form of the JSON arguments accepted by
// GlobTool.Execute.
type globArgs struct {
	// Pattern is the glob pattern files are matched against.
	Pattern string `json:"pattern"`
	// Path is the directory to search; empty means the working directory.
	Path string `json:"path,omitempty"`
}
|
||||
|
||||
func (t *GlobTool) Execute(_ context.Context, args json.RawMessage) (tool.Result, error) {
|
||||
var a globArgs
|
||||
if err := json.Unmarshal(args, &a); err != nil {
|
||||
return tool.Result{}, fmt.Errorf("fs.glob: invalid args: %w", err)
|
||||
}
|
||||
if a.Pattern == "" {
|
||||
return tool.Result{}, fmt.Errorf("fs.glob: pattern required")
|
||||
}
|
||||
|
||||
root := a.Path
|
||||
if root == "" {
|
||||
var err error
|
||||
root, err = os.Getwd()
|
||||
if err != nil {
|
||||
return tool.Result{}, fmt.Errorf("fs.glob: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
var matches []string
|
||||
err := filepath.WalkDir(root, func(path string, d os.DirEntry, err error) error {
|
||||
if err != nil {
|
||||
return nil // skip inaccessible entries
|
||||
}
|
||||
if d.IsDir() {
|
||||
// Skip hidden directories
|
||||
if d.Name() != "." && strings.HasPrefix(d.Name(), ".") {
|
||||
return filepath.SkipDir
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
rel, err := filepath.Rel(root, path)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
matched, err := filepath.Match(a.Pattern, rel)
|
||||
if err != nil {
|
||||
// Try matching just the filename for simple patterns
|
||||
matched, _ = filepath.Match(a.Pattern, d.Name())
|
||||
}
|
||||
|
||||
if matched {
|
||||
matches = append(matches, rel)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return tool.Result{Output: fmt.Sprintf("Error walking directory: %v", err)}, nil
|
||||
}
|
||||
|
||||
// Sort by modification time (most recent first)
|
||||
sort.Slice(matches, func(i, j int) bool {
|
||||
iInfo, _ := os.Stat(filepath.Join(root, matches[i]))
|
||||
jInfo, _ := os.Stat(filepath.Join(root, matches[j]))
|
||||
if iInfo == nil || jInfo == nil {
|
||||
return matches[i] < matches[j]
|
||||
}
|
||||
return iInfo.ModTime().After(jInfo.ModTime())
|
||||
})
|
||||
|
||||
output := strings.Join(matches, "\n")
|
||||
if output == "" {
|
||||
output = "(no matches)"
|
||||
}
|
||||
|
||||
return tool.Result{
|
||||
Output: output,
|
||||
Metadata: map[string]any{"count": len(matches), "pattern": a.Pattern},
|
||||
}, nil
|
||||
}
|
||||
184
internal/tool/fs/grep.go
Normal file
184
internal/tool/fs/grep.go
Normal file
@@ -0,0 +1,184 @@
|
||||
package fs
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"somegit.dev/Owlibou/gnoma/internal/tool"
|
||||
)
|
||||
|
||||
// Constants used by the grep tool.
const (
	// grepToolName is the name reported by GrepTool.Name.
	grepToolName = "fs.grep"

	// defaultMaxResults is the result cap applied when the caller supplies
	// no positive max_results.
	defaultMaxResults = 250
)
|
||||
|
||||
// grepParams is the JSON Schema returned by GrepTool.Parameters. It declares
// a required string property (pattern) and three optional ones: path and
// glob (strings) and max_results (integer).
var grepParams = json.RawMessage(`{
	"type": "object",
	"properties": {
		"pattern": {
			"type": "string",
			"description": "Regular expression pattern to search for"
		},
		"path": {
			"type": "string",
			"description": "File or directory to search in (defaults to current directory)"
		},
		"glob": {
			"type": "string",
			"description": "File glob filter (e.g. *.go, *.ts)"
		},
		"max_results": {
			"type": "integer",
			"description": "Maximum number of matching lines to return (default 250)"
		}
	},
	"required": ["pattern"]
}`)
|
||||
|
||||
// GrepTool is the fs.grep tool. It carries no state.
type GrepTool struct{}
|
||||
|
||||
// NewGrepTool returns a pointer to a new GrepTool.
func NewGrepTool() *GrepTool {
	return new(GrepTool)
}
|
||||
|
||||
// Name returns the tool name, "fs.grep".
func (t *GrepTool) Name() string {
	return grepToolName
}
|
||||
// Description returns a one-line summary of what the tool does.
func (t *GrepTool) Description() string {
	const summary = "Search file contents using a regular expression"
	return summary
}
|
||||
// Parameters returns the JSON Schema held in grepParams.
func (t *GrepTool) Parameters() json.RawMessage {
	return grepParams
}
|
||||
// IsReadOnly always reports true for this tool.
func (t *GrepTool) IsReadOnly() bool {
	return true
}
|
||||
// IsDestructive always reports false for this tool.
func (t *GrepTool) IsDestructive() bool {
	return false
}
|
||||
|
||||
// grepArgs is the decoded form of the JSON arguments passed to fs.grep.
type grepArgs struct {
	// Pattern is the regular expression to search for (required).
	Pattern string `json:"pattern"`
	// Path is the file or directory to search; empty means the working directory.
	Path string `json:"path,omitempty"`
	// Glob optionally restricts the search to files whose base name matches.
	Glob string `json:"glob,omitempty"`
	// MaxResults limits the number of matching lines; non-positive means the default.
	MaxResults int `json:"max_results,omitempty"`
}
|
||||
|
||||
// grepMatch records a single line that matched the search pattern.
type grepMatch struct {
	// File is the path shown in the output; it is empty when a single file was searched.
	File string
	// Line is the 1-based line number of the match.
	Line int
	// Text is the content of the matching line.
	Text string
}
|
||||
|
||||
func (t *GrepTool) Execute(_ context.Context, args json.RawMessage) (tool.Result, error) {
|
||||
var a grepArgs
|
||||
if err := json.Unmarshal(args, &a); err != nil {
|
||||
return tool.Result{}, fmt.Errorf("fs.grep: invalid args: %w", err)
|
||||
}
|
||||
if a.Pattern == "" {
|
||||
return tool.Result{}, fmt.Errorf("fs.grep: pattern required")
|
||||
}
|
||||
|
||||
re, err := regexp.Compile(a.Pattern)
|
||||
if err != nil {
|
||||
return tool.Result{Output: fmt.Sprintf("Invalid regex: %v", err)}, nil
|
||||
}
|
||||
|
||||
maxResults := a.MaxResults
|
||||
if maxResults <= 0 {
|
||||
maxResults = defaultMaxResults
|
||||
}
|
||||
|
||||
root := a.Path
|
||||
if root == "" {
|
||||
root, err = os.Getwd()
|
||||
if err != nil {
|
||||
return tool.Result{}, fmt.Errorf("fs.grep: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
info, err := os.Stat(root)
|
||||
if err != nil {
|
||||
return tool.Result{Output: fmt.Sprintf("Error: %v", err)}, nil
|
||||
}
|
||||
|
||||
var matches []grepMatch
|
||||
|
||||
if !info.IsDir() {
|
||||
matches = grepFile(root, "", re, maxResults)
|
||||
} else {
|
||||
filepath.WalkDir(root, func(path string, d os.DirEntry, err error) error {
|
||||
if err != nil || d.IsDir() {
|
||||
if d != nil && d.IsDir() && d.Name() != "." && strings.HasPrefix(d.Name(), ".") {
|
||||
return filepath.SkipDir
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Apply glob filter
|
||||
if a.Glob != "" {
|
||||
matched, _ := filepath.Match(a.Glob, d.Name())
|
||||
if !matched {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
rel, _ := filepath.Rel(root, path)
|
||||
fileMatches := grepFile(path, rel, re, maxResults-len(matches))
|
||||
matches = append(matches, fileMatches...)
|
||||
|
||||
if len(matches) >= maxResults {
|
||||
return filepath.SkipAll
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
if len(matches) == 0 {
|
||||
return tool.Result{
|
||||
Output: "(no matches)",
|
||||
Metadata: map[string]any{"count": 0},
|
||||
}, nil
|
||||
}
|
||||
|
||||
var b strings.Builder
|
||||
for _, m := range matches {
|
||||
if m.File != "" {
|
||||
fmt.Fprintf(&b, "%s:%d:%s\n", m.File, m.Line, m.Text)
|
||||
} else {
|
||||
fmt.Fprintf(&b, "%d:%s\n", m.Line, m.Text)
|
||||
}
|
||||
}
|
||||
|
||||
truncated := len(matches) >= maxResults
|
||||
|
||||
return tool.Result{
|
||||
Output: strings.TrimRight(b.String(), "\n"),
|
||||
Metadata: map[string]any{
|
||||
"count": len(matches),
|
||||
"truncated": truncated,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func grepFile(path, displayPath string, re *regexp.Regexp, limit int) []grepMatch {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
var matches []grepMatch
|
||||
scanner := bufio.NewScanner(f)
|
||||
lineNum := 0
|
||||
|
||||
for scanner.Scan() {
|
||||
lineNum++
|
||||
line := scanner.Text()
|
||||
if re.MatchString(line) {
|
||||
matches = append(matches, grepMatch{
|
||||
File: displayPath,
|
||||
Line: lineNum,
|
||||
Text: line,
|
||||
})
|
||||
if len(matches) >= limit {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return matches
|
||||
}
|
||||
123
internal/tool/fs/ls.go
Normal file
123
internal/tool/fs/ls.go
Normal file
@@ -0,0 +1,123 @@
|
||||
package fs
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"somegit.dev/Owlibou/gnoma/internal/tool"
|
||||
)
|
||||
|
||||
// lsToolName is the identifier the directory-listing tool is registered under.
const lsToolName = "fs.ls"

// lsParams is the JSON Schema describing the arguments accepted by fs.ls.
// The single "path" property is optional.
var lsParams = json.RawMessage(`{
	"type": "object",
	"properties": {
		"path": {
			"type": "string",
			"description": "Directory path to list (defaults to current directory)"
		}
	}
}`)
|
||||
|
||||
type LSTool struct{}
|
||||
|
||||
func NewLSTool() *LSTool { return &LSTool{} }
|
||||
|
||||
func (t *LSTool) Name() string { return lsToolName }
|
||||
func (t *LSTool) Description() string { return "List directory contents with file types and sizes" }
|
||||
func (t *LSTool) Parameters() json.RawMessage { return lsParams }
|
||||
func (t *LSTool) IsReadOnly() bool { return true }
|
||||
func (t *LSTool) IsDestructive() bool { return false }
|
||||
|
||||
// lsArgs is the decoded form of the JSON arguments passed to fs.ls.
type lsArgs struct {
	// Path is the directory to list; empty means the working directory.
	Path string `json:"path,omitempty"`
}
|
||||
|
||||
func (t *LSTool) Execute(_ context.Context, args json.RawMessage) (tool.Result, error) {
|
||||
var a lsArgs
|
||||
if err := json.Unmarshal(args, &a); err != nil {
|
||||
return tool.Result{}, fmt.Errorf("fs.ls: invalid args: %w", err)
|
||||
}
|
||||
|
||||
dir := a.Path
|
||||
if dir == "" {
|
||||
var err error
|
||||
dir, err = os.Getwd()
|
||||
if err != nil {
|
||||
return tool.Result{}, fmt.Errorf("fs.ls: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
return tool.Result{Output: fmt.Sprintf("Error: %v", err)}, nil
|
||||
}
|
||||
|
||||
var b strings.Builder
|
||||
dirCount, fileCount := 0, 0
|
||||
|
||||
for _, entry := range entries {
|
||||
info, err := entry.Info()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
prefix := " "
|
||||
if entry.IsDir() {
|
||||
prefix = "d"
|
||||
dirCount++
|
||||
} else {
|
||||
fileCount++
|
||||
}
|
||||
|
||||
size := formatSize(info.Size())
|
||||
if entry.IsDir() {
|
||||
size = "-"
|
||||
}
|
||||
|
||||
// Check for symlink
|
||||
if entry.Type()&fs.ModeSymlink != 0 {
|
||||
prefix = "l"
|
||||
target, err := os.Readlink(filepath.Join(dir, entry.Name()))
|
||||
if err == nil {
|
||||
fmt.Fprintf(&b, "%s %8s %s -> %s\n", prefix, size, entry.Name(), target)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Fprintf(&b, "%s %8s %s\n", prefix, size, entry.Name())
|
||||
}
|
||||
|
||||
output := strings.TrimRight(b.String(), "\n")
|
||||
if output == "" {
|
||||
output = "(empty directory)"
|
||||
}
|
||||
|
||||
return tool.Result{
|
||||
Output: output,
|
||||
Metadata: map[string]any{
|
||||
"directory": dir,
|
||||
"files": fileCount,
|
||||
"dirs": dirCount,
|
||||
"total": fileCount + dirCount,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// formatSize renders a byte count in a compact form: whole bytes below
// 1 KiB ("512B"), otherwise one decimal place in the largest binary unit that
// fits, up to gibibytes ("1.5K", "2.0M", "3.2G").
func formatSize(bytes int64) string {
	units := [...]struct {
		threshold int64
		suffix    string
	}{
		{1 << 30, "G"},
		{1 << 20, "M"},
		{1 << 10, "K"},
	}
	for _, u := range units {
		if bytes >= u.threshold {
			return fmt.Sprintf("%.1f%s", float64(bytes)/float64(u.threshold), u.suffix)
		}
	}
	return fmt.Sprintf("%dB", bytes)
}
|
||||
123
internal/tool/fs/read.go
Normal file
123
internal/tool/fs/read.go
Normal file
@@ -0,0 +1,123 @@
|
||||
package fs
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"somegit.dev/Owlibou/gnoma/internal/tool"
|
||||
)
|
||||
|
||||
// Identity and limits of the read tool.
const (
	// readToolName is the identifier the tool is registered under.
	readToolName = "fs.read"
	// defaultMaxLines is the line cap applied when the caller gives no limit.
	defaultMaxLines = 2000
)

// readParams is the JSON Schema describing the arguments accepted by fs.read.
// Only "path" is mandatory.
var readParams = json.RawMessage(`{
	"type": "object",
	"properties": {
		"path": {
			"type": "string",
			"description": "Absolute path to the file to read"
		},
		"offset": {
			"type": "integer",
			"description": "Line number to start reading from (0-based)"
		},
		"limit": {
			"type": "integer",
			"description": "Maximum number of lines to read"
		}
	},
	"required": ["path"]
}`)
|
||||
|
||||
type ReadTool struct {
|
||||
maxLines int
|
||||
}
|
||||
|
||||
type ReadOption func(*ReadTool)
|
||||
|
||||
func WithMaxLines(n int) ReadOption {
|
||||
return func(t *ReadTool) { t.maxLines = n }
|
||||
}
|
||||
|
||||
func NewReadTool(opts ...ReadOption) *ReadTool {
|
||||
t := &ReadTool{maxLines: defaultMaxLines}
|
||||
for _, opt := range opts {
|
||||
opt(t)
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
func (t *ReadTool) Name() string { return readToolName }
|
||||
func (t *ReadTool) Description() string { return "Read a file from the filesystem with optional offset and line limit" }
|
||||
func (t *ReadTool) Parameters() json.RawMessage { return readParams }
|
||||
func (t *ReadTool) IsReadOnly() bool { return true }
|
||||
func (t *ReadTool) IsDestructive() bool { return false }
|
||||
|
||||
// readArgs is the decoded form of the JSON arguments passed to fs.read.
type readArgs struct {
	// Path is the file to read (required).
	Path string `json:"path"`
	// Offset is the 0-based index of the first line to return.
	Offset int `json:"offset,omitempty"`
	// Limit is the maximum number of lines to return; non-positive means the tool default.
	Limit int `json:"limit,omitempty"`
}
|
||||
|
||||
func (t *ReadTool) Execute(_ context.Context, args json.RawMessage) (tool.Result, error) {
|
||||
var a readArgs
|
||||
if err := json.Unmarshal(args, &a); err != nil {
|
||||
return tool.Result{}, fmt.Errorf("fs.read: invalid args: %w", err)
|
||||
}
|
||||
if a.Path == "" {
|
||||
return tool.Result{}, fmt.Errorf("fs.read: path required")
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(a.Path)
|
||||
if err != nil {
|
||||
return tool.Result{Output: fmt.Sprintf("Error: %v", err)}, nil
|
||||
}
|
||||
|
||||
lines := strings.Split(string(data), "\n")
|
||||
totalLines := len(lines)
|
||||
|
||||
// Apply offset
|
||||
offset := a.Offset
|
||||
if offset < 0 {
|
||||
offset = 0
|
||||
}
|
||||
if offset >= totalLines {
|
||||
return tool.Result{
|
||||
Output: fmt.Sprintf("(file has %d lines, offset %d is past end)", totalLines, offset),
|
||||
Metadata: map[string]any{"total_lines": totalLines},
|
||||
}, nil
|
||||
}
|
||||
lines = lines[offset:]
|
||||
|
||||
// Apply limit
|
||||
limit := a.Limit
|
||||
if limit <= 0 {
|
||||
limit = t.maxLines
|
||||
}
|
||||
truncated := false
|
||||
if len(lines) > limit {
|
||||
lines = lines[:limit]
|
||||
truncated = true
|
||||
}
|
||||
|
||||
// Format with line numbers (1-based, matching cat -n)
|
||||
var b strings.Builder
|
||||
for i, line := range lines {
|
||||
fmt.Fprintf(&b, "%d\t%s\n", offset+i+1, line)
|
||||
}
|
||||
|
||||
output := strings.TrimRight(b.String(), "\n")
|
||||
|
||||
meta := map[string]any{"total_lines": totalLines}
|
||||
if truncated {
|
||||
meta["truncated"] = true
|
||||
meta["showing"] = fmt.Sprintf("lines %d-%d of %d", offset+1, offset+len(lines), totalLines)
|
||||
}
|
||||
|
||||
return tool.Result{Output: output, Metadata: meta}, nil
|
||||
}
|
||||
68
internal/tool/fs/write.go
Normal file
68
internal/tool/fs/write.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package fs
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"somegit.dev/Owlibou/gnoma/internal/tool"
|
||||
)
|
||||
|
||||
// writeToolName is the identifier the file-writing tool is registered under.
const writeToolName = "fs.write"

// writeParams is the JSON Schema describing the arguments accepted by
// fs.write. Both "path" and "content" are mandatory.
var writeParams = json.RawMessage(`{
	"type": "object",
	"properties": {
		"path": {
			"type": "string",
			"description": "Absolute path to the file to write"
		},
		"content": {
			"type": "string",
			"description": "Content to write to the file"
		}
	},
	"required": ["path", "content"]
}`)
|
||||
|
||||
type WriteTool struct{}
|
||||
|
||||
func NewWriteTool() *WriteTool { return &WriteTool{} }
|
||||
|
||||
func (t *WriteTool) Name() string { return writeToolName }
|
||||
func (t *WriteTool) Description() string { return "Write content to a file, creating parent directories as needed" }
|
||||
func (t *WriteTool) Parameters() json.RawMessage { return writeParams }
|
||||
func (t *WriteTool) IsReadOnly() bool { return false }
|
||||
func (t *WriteTool) IsDestructive() bool { return false }
|
||||
|
||||
// writeArgs is the decoded form of the JSON arguments passed to fs.write.
type writeArgs struct {
	// Path is the destination file (required).
	Path string `json:"path"`
	// Content is the text written to the file.
	Content string `json:"content"`
}
|
||||
|
||||
func (t *WriteTool) Execute(_ context.Context, args json.RawMessage) (tool.Result, error) {
|
||||
var a writeArgs
|
||||
if err := json.Unmarshal(args, &a); err != nil {
|
||||
return tool.Result{}, fmt.Errorf("fs.write: invalid args: %w", err)
|
||||
}
|
||||
if a.Path == "" {
|
||||
return tool.Result{}, fmt.Errorf("fs.write: path required")
|
||||
}
|
||||
|
||||
// Create parent directories
|
||||
dir := filepath.Dir(a.Path)
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
return tool.Result{Output: fmt.Sprintf("Error creating directory: %v", err)}, nil
|
||||
}
|
||||
|
||||
if err := os.WriteFile(a.Path, []byte(a.Content), 0o644); err != nil {
|
||||
return tool.Result{Output: fmt.Sprintf("Error writing file: %v", err)}, nil
|
||||
}
|
||||
|
||||
return tool.Result{
|
||||
Output: fmt.Sprintf("Wrote %d bytes to %s", len(a.Content), a.Path),
|
||||
Metadata: map[string]any{"bytes_written": len(a.Content), "path": a.Path},
|
||||
}, nil
|
||||
}
|
||||
77
internal/tool/registry.go
Normal file
77
internal/tool/registry.go
Normal file
@@ -0,0 +1,77 @@
|
||||
package tool
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Definition is the provider-agnostic description of a tool that is sent to
// the LLM.
type Definition struct {
	// Name is the tool's identifier.
	Name string `json:"name"`
	// Description tells the model what the tool does.
	Description string `json:"description"`
	// Parameters is the JSON Schema of the tool's input.
	Parameters json.RawMessage `json:"parameters"`
}
|
||||
|
||||
// Registry holds all available tools.
|
||||
type Registry struct {
|
||||
mu sync.RWMutex
|
||||
tools map[string]Tool
|
||||
}
|
||||
|
||||
func NewRegistry() *Registry {
|
||||
return &Registry{
|
||||
tools: make(map[string]Tool),
|
||||
}
|
||||
}
|
||||
|
||||
// Register adds a tool. Overwrites if name already exists.
|
||||
func (r *Registry) Register(t Tool) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.tools[t.Name()] = t
|
||||
}
|
||||
|
||||
// Get returns a tool by name.
|
||||
func (r *Registry) Get(name string) (Tool, bool) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
t, ok := r.tools[name]
|
||||
return t, ok
|
||||
}
|
||||
|
||||
// All returns all registered tools.
|
||||
func (r *Registry) All() []Tool {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
all := make([]Tool, 0, len(r.tools))
|
||||
for _, t := range r.tools {
|
||||
all = append(all, t)
|
||||
}
|
||||
return all
|
||||
}
|
||||
|
||||
// Definitions returns tool definitions for all registered tools,
|
||||
// suitable for sending to the LLM.
|
||||
func (r *Registry) Definitions() []Definition {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
defs := make([]Definition, 0, len(r.tools))
|
||||
for _, t := range r.tools {
|
||||
defs = append(defs, Definition{
|
||||
Name: t.Name(),
|
||||
Description: t.Description(),
|
||||
Parameters: t.Parameters(),
|
||||
})
|
||||
}
|
||||
return defs
|
||||
}
|
||||
|
||||
// MustGet returns a tool by name or panics. For use in tests.
|
||||
func (r *Registry) MustGet(name string) Tool {
|
||||
t, ok := r.Get(name)
|
||||
if !ok {
|
||||
panic(fmt.Sprintf("tool not found: %q", name))
|
||||
}
|
||||
return t
|
||||
}
|
||||
208
internal/tool/registry_test.go
Normal file
208
internal/tool/registry_test.go
Normal file
@@ -0,0 +1,208 @@
|
||||
package tool
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"slices"
|
||||
"sort"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// stubTool is a minimal Tool implementation for testing.
|
||||
type stubTool struct {
|
||||
name string
|
||||
description string
|
||||
params json.RawMessage
|
||||
readOnly bool
|
||||
destructive bool
|
||||
execFn func(ctx context.Context, args json.RawMessage) (Result, error)
|
||||
}
|
||||
|
||||
func (s *stubTool) Name() string { return s.name }
|
||||
func (s *stubTool) Description() string { return s.description }
|
||||
func (s *stubTool) Parameters() json.RawMessage { return s.params }
|
||||
func (s *stubTool) IsReadOnly() bool { return s.readOnly }
|
||||
func (s *stubTool) IsDestructive() bool { return s.destructive }
|
||||
func (s *stubTool) Execute(ctx context.Context, args json.RawMessage) (Result, error) {
|
||||
if s.execFn != nil {
|
||||
return s.execFn(ctx, args)
|
||||
}
|
||||
return Result{Output: "ok"}, nil
|
||||
}
|
||||
|
||||
func TestRegistry_RegisterAndGet(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
r.Register(&stubTool{name: "bash", description: "run commands"})
|
||||
|
||||
tool, ok := r.Get("bash")
|
||||
if !ok {
|
||||
t.Fatal("Get(bash) should find tool")
|
||||
}
|
||||
if tool.Name() != "bash" {
|
||||
t.Errorf("Name() = %q", tool.Name())
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistry_Get_NotFound(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
_, ok := r.Get("nonexistent")
|
||||
if ok {
|
||||
t.Error("Get(nonexistent) should return false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistry_Register_Overwrite(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
r.Register(&stubTool{name: "bash", description: "old"})
|
||||
r.Register(&stubTool{name: "bash", description: "new"})
|
||||
|
||||
tool, _ := r.Get("bash")
|
||||
if tool.Description() != "new" {
|
||||
t.Errorf("Description() = %q, want 'new' (overwritten)", tool.Description())
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistry_All(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
r.Register(&stubTool{name: "bash"})
|
||||
r.Register(&stubTool{name: "fs.read"})
|
||||
r.Register(&stubTool{name: "fs.write"})
|
||||
|
||||
all := r.All()
|
||||
if len(all) != 3 {
|
||||
t.Fatalf("len(All()) = %d, want 3", len(all))
|
||||
}
|
||||
|
||||
names := make([]string, len(all))
|
||||
for i, t := range all {
|
||||
names[i] = t.Name()
|
||||
}
|
||||
sort.Strings(names)
|
||||
|
||||
want := []string{"bash", "fs.read", "fs.write"}
|
||||
if !slices.Equal(names, want) {
|
||||
t.Errorf("All() names = %v, want %v", names, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistry_Definitions(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
r.Register(&stubTool{
|
||||
name: "bash",
|
||||
description: "Run a command",
|
||||
params: json.RawMessage(`{"type":"object","properties":{"command":{"type":"string"}}}`),
|
||||
})
|
||||
r.Register(&stubTool{
|
||||
name: "fs.read",
|
||||
description: "Read a file",
|
||||
params: json.RawMessage(`{"type":"object","properties":{"path":{"type":"string"}}}`),
|
||||
})
|
||||
|
||||
defs := r.Definitions()
|
||||
if len(defs) != 2 {
|
||||
t.Fatalf("len(Definitions()) = %d, want 2", len(defs))
|
||||
}
|
||||
|
||||
// Find bash definition
|
||||
var bashDef *Definition
|
||||
for i := range defs {
|
||||
if defs[i].Name == "bash" {
|
||||
bashDef = &defs[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
if bashDef == nil {
|
||||
t.Fatal("bash definition not found")
|
||||
}
|
||||
if bashDef.Description != "Run a command" {
|
||||
t.Errorf("bash.Description = %q", bashDef.Description)
|
||||
}
|
||||
if bashDef.Parameters == nil {
|
||||
t.Error("bash.Parameters should not be nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistry_MustGet_Panics(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
|
||||
defer func() {
|
||||
if r := recover(); r == nil {
|
||||
t.Error("MustGet should panic for missing tool")
|
||||
}
|
||||
}()
|
||||
|
||||
r.MustGet("nonexistent")
|
||||
}
|
||||
|
||||
func TestRegistry_MustGet_Success(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
r.Register(&stubTool{name: "bash"})
|
||||
|
||||
tool := r.MustGet("bash")
|
||||
if tool.Name() != "bash" {
|
||||
t.Errorf("Name() = %q", tool.Name())
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistry_Empty(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
|
||||
if len(r.All()) != 0 {
|
||||
t.Error("empty registry should return no tools")
|
||||
}
|
||||
if len(r.Definitions()) != 0 {
|
||||
t.Error("empty registry should return no definitions")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStubTool_Execute(t *testing.T) {
|
||||
called := false
|
||||
tool := &stubTool{
|
||||
name: "test",
|
||||
execFn: func(ctx context.Context, args json.RawMessage) (Result, error) {
|
||||
called = true
|
||||
var input struct{ Value string }
|
||||
json.Unmarshal(args, &input)
|
||||
return Result{
|
||||
Output: "processed: " + input.Value,
|
||||
Metadata: map[string]any{"key": "val"},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
result, err := tool.Execute(context.Background(), json.RawMessage(`{"Value":"hello"}`))
|
||||
if err != nil {
|
||||
t.Fatalf("Execute: %v", err)
|
||||
}
|
||||
if !called {
|
||||
t.Error("execFn should have been called")
|
||||
}
|
||||
if result.Output != "processed: hello" {
|
||||
t.Errorf("Output = %q", result.Output)
|
||||
}
|
||||
if result.Metadata["key"] != "val" {
|
||||
t.Errorf("Metadata = %v", result.Metadata)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolInterface_ReadOnlyDestructive(t *testing.T) {
|
||||
readTool := &stubTool{name: "fs.read", readOnly: true, destructive: false}
|
||||
writeTool := &stubTool{name: "fs.write", readOnly: false, destructive: false}
|
||||
deleteTool := &stubTool{name: "bash.rm", readOnly: false, destructive: true}
|
||||
|
||||
if !readTool.IsReadOnly() {
|
||||
t.Error("fs.read should be read-only")
|
||||
}
|
||||
if readTool.IsDestructive() {
|
||||
t.Error("fs.read should not be destructive")
|
||||
}
|
||||
if writeTool.IsReadOnly() {
|
||||
t.Error("fs.write should not be read-only")
|
||||
}
|
||||
if deleteTool.IsReadOnly() {
|
||||
t.Error("bash.rm should not be read-only")
|
||||
}
|
||||
if !deleteTool.IsDestructive() {
|
||||
t.Error("bash.rm should be destructive")
|
||||
}
|
||||
}
|
||||
9
internal/tool/result.go
Normal file
9
internal/tool/result.go
Normal file
@@ -0,0 +1,9 @@
|
||||
package tool
|
||||
|
||||
// Result is what a tool execution produces.
type Result struct {
	// Output is the text handed back to the LLM.
	Output string

	// Metadata holds optional structured details about the run, such as an
	// exit code, a file path or a match count.
	Metadata map[string]any
}
|
||||
22
internal/tool/tool.go
Normal file
22
internal/tool/tool.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package tool
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
// Tool is the interface every tool must implement.
|
||||
type Tool interface {
|
||||
// Name returns the tool's identifier (used in LLM tool schemas).
|
||||
Name() string
|
||||
// Description returns a human-readable description for the LLM.
|
||||
Description() string
|
||||
// Parameters returns the JSON Schema for the tool's input.
|
||||
Parameters() json.RawMessage
|
||||
// Execute runs the tool with the given JSON arguments.
|
||||
Execute(ctx context.Context, args json.RawMessage) (Result, error)
|
||||
// IsReadOnly returns true if the tool only reads (safe for concurrent execution).
|
||||
IsReadOnly() bool
|
||||
// IsDestructive returns true if the tool can cause irreversible changes.
|
||||
IsDestructive() bool
|
||||
}
|
||||
Reference in New Issue
Block a user