Mistral provider adapter with streaming, tool calls (single-chunk pattern), stop reason inference, model listing, capabilities, and JSON output support. Tool system: bash (7 security checks, shell alias harvesting for bash/zsh/fish), file ops (read, write, edit, glob, grep, ls). Alias harvesting collects 300+ aliases from user's shell config. Engine agentic loop: stream → tool execution → re-query → until done. Tool gating on model capabilities. Max turns safety limit. CLI pipe mode: echo "prompt" | gnoma streams response to stdout. Flags: --provider, --model, --system, --api-key, --max-turns, --verbose, --version. Provider interface expanded: Models(), DefaultModel(), Capabilities (ToolUse, JSONOutput, Vision, Thinking, ContextWindow, MaxOutput), ResponseFormat with JSON schema support. Live verified: text streaming + tool calling with devstral-small. 117 tests across 8 packages, 10MB binary.
476 lines
14 KiB
Go
476 lines
14 KiB
Go
package engine
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"testing"
|
|
|
|
"somegit.dev/Owlibou/gnoma/internal/message"
|
|
"somegit.dev/Owlibou/gnoma/internal/provider"
|
|
"somegit.dev/Owlibou/gnoma/internal/stream"
|
|
"somegit.dev/Owlibou/gnoma/internal/tool"
|
|
)
|
|
|
|
// --- Mock Provider ---
|
|
|
|
// mockProvider returns pre-configured streams for each call.
|
|
type mockProvider struct {
|
|
name string
|
|
calls int
|
|
streams []stream.Stream // one per call, consumed in order
|
|
}
|
|
|
|
func (m *mockProvider) Name() string { return m.name }
|
|
func (m *mockProvider) DefaultModel() string { return "mock-model" }
|
|
func (m *mockProvider) Models(_ context.Context) ([]provider.ModelInfo, error) {
|
|
return []provider.ModelInfo{{
|
|
ID: "mock-model", Name: "mock-model", Provider: m.name,
|
|
Capabilities: provider.Capabilities{ToolUse: true},
|
|
}}, nil
|
|
}
|
|
func (m *mockProvider) Stream(_ context.Context, _ provider.Request) (stream.Stream, error) {
|
|
if m.calls >= len(m.streams) {
|
|
return nil, fmt.Errorf("mock: no more streams (called %d times)", m.calls+1)
|
|
}
|
|
s := m.streams[m.calls]
|
|
m.calls++
|
|
return s, nil
|
|
}
|
|
|
|
// eventStream is a mock stream backed by a slice of events.
|
|
type eventStream struct {
|
|
events []stream.Event
|
|
idx int
|
|
stopReason message.StopReason
|
|
model string
|
|
}
|
|
|
|
func newEventStream(stopReason message.StopReason, model string, events ...stream.Event) *eventStream {
|
|
// Append a final event with stop reason
|
|
events = append(events, stream.Event{
|
|
Type: stream.EventTextDelta,
|
|
StopReason: stopReason,
|
|
Model: model,
|
|
})
|
|
return &eventStream{events: events, stopReason: stopReason, model: model}
|
|
}
|
|
|
|
func (s *eventStream) Next() bool {
|
|
if s.idx >= len(s.events) {
|
|
return false
|
|
}
|
|
s.idx++
|
|
return true
|
|
}
|
|
|
|
func (s *eventStream) Current() stream.Event {
|
|
return s.events[s.idx-1]
|
|
}
|
|
|
|
func (s *eventStream) Err() error { return nil }
|
|
func (s *eventStream) Close() error { return nil }
|
|
|
|
// --- Mock Tool ---
|
|
|
|
type mockTool struct {
|
|
name string
|
|
readOnly bool
|
|
execFn func(ctx context.Context, args json.RawMessage) (tool.Result, error)
|
|
}
|
|
|
|
func (m *mockTool) Name() string { return m.name }
|
|
func (m *mockTool) Description() string { return "mock tool" }
|
|
func (m *mockTool) Parameters() json.RawMessage { return json.RawMessage(`{"type":"object"}`) }
|
|
func (m *mockTool) IsReadOnly() bool { return m.readOnly }
|
|
func (m *mockTool) IsDestructive() bool { return false }
|
|
func (m *mockTool) Execute(ctx context.Context, args json.RawMessage) (tool.Result, error) {
|
|
if m.execFn != nil {
|
|
return m.execFn(ctx, args)
|
|
}
|
|
return tool.Result{Output: "mock output"}, nil
|
|
}
|
|
|
|
// --- Tests ---
|
|
|
|
func TestNew_ValidConfig(t *testing.T) {
|
|
e, err := New(Config{
|
|
Provider: &mockProvider{name: "test"},
|
|
Tools: tool.NewRegistry(),
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("New: %v", err)
|
|
}
|
|
if e == nil {
|
|
t.Fatal("engine should not be nil")
|
|
}
|
|
}
|
|
|
|
func TestNew_MissingProvider(t *testing.T) {
|
|
_, err := New(Config{Tools: tool.NewRegistry()})
|
|
if err == nil {
|
|
t.Fatal("expected error for missing provider")
|
|
}
|
|
}
|
|
|
|
func TestNew_MissingTools(t *testing.T) {
|
|
_, err := New(Config{Provider: &mockProvider{name: "test"}})
|
|
if err == nil {
|
|
t.Fatal("expected error for missing tool registry")
|
|
}
|
|
}
|
|
|
|
func TestSubmit_SimpleTextResponse(t *testing.T) {
|
|
mp := &mockProvider{
|
|
name: "test",
|
|
streams: []stream.Stream{
|
|
newEventStream(message.StopEndTurn, "test-model",
|
|
stream.Event{Type: stream.EventTextDelta, Text: "Hello "},
|
|
stream.Event{Type: stream.EventTextDelta, Text: "world!"},
|
|
stream.Event{Type: stream.EventUsage, Usage: &message.Usage{InputTokens: 10, OutputTokens: 5}},
|
|
),
|
|
},
|
|
}
|
|
|
|
e, _ := New(Config{Provider: mp, Tools: tool.NewRegistry()})
|
|
|
|
var events []stream.Event
|
|
turn, err := e.Submit(context.Background(), "hi", func(evt stream.Event) {
|
|
events = append(events, evt)
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Submit: %v", err)
|
|
}
|
|
|
|
// Check turn
|
|
if turn.Rounds != 1 {
|
|
t.Errorf("Rounds = %d, want 1", turn.Rounds)
|
|
}
|
|
if len(turn.Messages) != 1 {
|
|
t.Fatalf("len(Messages) = %d, want 1", len(turn.Messages))
|
|
}
|
|
if turn.Messages[0].TextContent() != "Hello world!" {
|
|
t.Errorf("TextContent = %q", turn.Messages[0].TextContent())
|
|
}
|
|
if turn.Usage.InputTokens != 10 {
|
|
t.Errorf("Usage.InputTokens = %d", turn.Usage.InputTokens)
|
|
}
|
|
|
|
// Check history
|
|
history := e.History()
|
|
if len(history) != 2 {
|
|
t.Fatalf("len(History) = %d, want 2 (user + assistant)", len(history))
|
|
}
|
|
if history[0].Role != message.RoleUser {
|
|
t.Errorf("History[0].Role = %q", history[0].Role)
|
|
}
|
|
if history[1].Role != message.RoleAssistant {
|
|
t.Errorf("History[1].Role = %q", history[1].Role)
|
|
}
|
|
|
|
// Check events were forwarded
|
|
if len(events) == 0 {
|
|
t.Error("callback should have received events")
|
|
}
|
|
}
|
|
|
|
func TestSubmit_ToolCallLoop(t *testing.T) {
|
|
reg := tool.NewRegistry()
|
|
reg.Register(&mockTool{
|
|
name: "bash",
|
|
execFn: func(_ context.Context, args json.RawMessage) (tool.Result, error) {
|
|
return tool.Result{Output: "file1.go\nfile2.go"}, nil
|
|
},
|
|
})
|
|
|
|
mp := &mockProvider{
|
|
name: "test",
|
|
streams: []stream.Stream{
|
|
// Round 1: model calls a tool
|
|
newEventStream(message.StopToolUse, "model-1",
|
|
stream.Event{Type: stream.EventTextDelta, Text: "Let me list files."},
|
|
stream.Event{Type: stream.EventToolCallStart, ToolCallID: "tc_1", ToolCallName: "bash"},
|
|
stream.Event{Type: stream.EventToolCallDone, ToolCallID: "tc_1", Args: json.RawMessage(`{"command":"ls"}`)},
|
|
),
|
|
// Round 2: model responds with final answer
|
|
newEventStream(message.StopEndTurn, "model-1",
|
|
stream.Event{Type: stream.EventTextDelta, Text: "Found file1.go and file2.go."},
|
|
),
|
|
},
|
|
}
|
|
|
|
e, _ := New(Config{Provider: mp, Tools: reg})
|
|
|
|
turn, err := e.Submit(context.Background(), "list files", nil)
|
|
if err != nil {
|
|
t.Fatalf("Submit: %v", err)
|
|
}
|
|
|
|
if turn.Rounds != 2 {
|
|
t.Errorf("Rounds = %d, want 2", turn.Rounds)
|
|
}
|
|
|
|
// Messages: assistant (tool call), tool results, assistant (final)
|
|
if len(turn.Messages) != 3 {
|
|
t.Fatalf("len(Messages) = %d, want 3", len(turn.Messages))
|
|
}
|
|
|
|
// First message has tool call
|
|
if !turn.Messages[0].HasToolCalls() {
|
|
t.Error("Messages[0] should have tool calls")
|
|
}
|
|
|
|
// Second message is tool results
|
|
if turn.Messages[1].Role != message.RoleUser {
|
|
t.Errorf("Messages[1].Role = %q, want user (tool results)", turn.Messages[1].Role)
|
|
}
|
|
|
|
// Third message is final text
|
|
if turn.Messages[2].TextContent() != "Found file1.go and file2.go." {
|
|
t.Errorf("Messages[2].TextContent = %q", turn.Messages[2].TextContent())
|
|
}
|
|
|
|
// History: user + assistant(tool call) + tool results + assistant(final)
|
|
if len(e.History()) != 4 {
|
|
t.Errorf("len(History) = %d, want 4", len(e.History()))
|
|
}
|
|
|
|
// Provider called twice
|
|
if mp.calls != 2 {
|
|
t.Errorf("provider called %d times, want 2", mp.calls)
|
|
}
|
|
}
|
|
|
|
func TestSubmit_UnknownTool(t *testing.T) {
|
|
reg := tool.NewRegistry()
|
|
// Don't register any tools
|
|
|
|
mp := &mockProvider{
|
|
name: "test",
|
|
streams: []stream.Stream{
|
|
// Model calls a tool that doesn't exist
|
|
newEventStream(message.StopToolUse, "",
|
|
stream.Event{Type: stream.EventToolCallStart, ToolCallID: "tc_1", ToolCallName: "nonexistent"},
|
|
stream.Event{Type: stream.EventToolCallDone, ToolCallID: "tc_1", Args: json.RawMessage(`{}`)},
|
|
),
|
|
// Model responds after seeing error
|
|
newEventStream(message.StopEndTurn, "",
|
|
stream.Event{Type: stream.EventTextDelta, Text: "Sorry, that tool doesn't exist."},
|
|
),
|
|
},
|
|
}
|
|
|
|
e, _ := New(Config{Provider: mp, Tools: reg})
|
|
|
|
turn, err := e.Submit(context.Background(), "do something", nil)
|
|
if err != nil {
|
|
t.Fatalf("Submit: %v", err)
|
|
}
|
|
|
|
// Should still complete — unknown tool returns error result, model sees it
|
|
if turn.Rounds != 2 {
|
|
t.Errorf("Rounds = %d, want 2", turn.Rounds)
|
|
}
|
|
}
|
|
|
|
func TestSubmit_ToolExecutionError(t *testing.T) {
|
|
reg := tool.NewRegistry()
|
|
reg.Register(&mockTool{
|
|
name: "failing",
|
|
execFn: func(_ context.Context, _ json.RawMessage) (tool.Result, error) {
|
|
return tool.Result{}, errors.New("disk full")
|
|
},
|
|
})
|
|
|
|
mp := &mockProvider{
|
|
name: "test",
|
|
streams: []stream.Stream{
|
|
newEventStream(message.StopToolUse, "",
|
|
stream.Event{Type: stream.EventToolCallStart, ToolCallID: "tc_1", ToolCallName: "failing"},
|
|
stream.Event{Type: stream.EventToolCallDone, ToolCallID: "tc_1", Args: json.RawMessage(`{}`)},
|
|
),
|
|
newEventStream(message.StopEndTurn, "",
|
|
stream.Event{Type: stream.EventTextDelta, Text: "The tool failed."},
|
|
),
|
|
},
|
|
}
|
|
|
|
e, _ := New(Config{Provider: mp, Tools: reg})
|
|
|
|
turn, err := e.Submit(context.Background(), "do it", nil)
|
|
if err != nil {
|
|
t.Fatalf("Submit: %v", err)
|
|
}
|
|
|
|
// Tool error is returned as error result, not a fatal error
|
|
if turn.Rounds != 2 {
|
|
t.Errorf("Rounds = %d, want 2", turn.Rounds)
|
|
}
|
|
}
|
|
|
|
func TestSubmit_MaxTurnsLimit(t *testing.T) {
|
|
reg := tool.NewRegistry()
|
|
reg.Register(&mockTool{name: "bash"})
|
|
|
|
// Provider always returns tool calls — would loop forever
|
|
mp := &mockProvider{
|
|
name: "test",
|
|
streams: []stream.Stream{
|
|
newEventStream(message.StopToolUse, "",
|
|
stream.Event{Type: stream.EventToolCallStart, ToolCallID: "tc_1", ToolCallName: "bash"},
|
|
stream.Event{Type: stream.EventToolCallDone, ToolCallID: "tc_1", Args: json.RawMessage(`{}`)},
|
|
),
|
|
newEventStream(message.StopToolUse, "",
|
|
stream.Event{Type: stream.EventToolCallStart, ToolCallID: "tc_2", ToolCallName: "bash"},
|
|
stream.Event{Type: stream.EventToolCallDone, ToolCallID: "tc_2", Args: json.RawMessage(`{}`)},
|
|
),
|
|
newEventStream(message.StopToolUse, "",
|
|
stream.Event{Type: stream.EventToolCallStart, ToolCallID: "tc_3", ToolCallName: "bash"},
|
|
stream.Event{Type: stream.EventToolCallDone, ToolCallID: "tc_3", Args: json.RawMessage(`{}`)},
|
|
),
|
|
},
|
|
}
|
|
|
|
e, _ := New(Config{Provider: mp, Tools: reg, MaxTurns: 2})
|
|
|
|
_, err := e.Submit(context.Background(), "loop forever", nil)
|
|
if err == nil {
|
|
t.Fatal("expected error from max turns limit")
|
|
}
|
|
if mp.calls != 2 {
|
|
t.Errorf("provider called %d times, want 2 (limited)", mp.calls)
|
|
}
|
|
}
|
|
|
|
func TestSubmit_MultipleToolCalls(t *testing.T) {
|
|
reg := tool.NewRegistry()
|
|
reg.Register(&mockTool{
|
|
name: "bash",
|
|
execFn: func(_ context.Context, _ json.RawMessage) (tool.Result, error) {
|
|
return tool.Result{Output: "bash output"}, nil
|
|
},
|
|
})
|
|
reg.Register(&mockTool{
|
|
name: "fs.read",
|
|
readOnly: true,
|
|
execFn: func(_ context.Context, _ json.RawMessage) (tool.Result, error) {
|
|
return tool.Result{Output: "file content"}, nil
|
|
},
|
|
})
|
|
|
|
mp := &mockProvider{
|
|
name: "test",
|
|
streams: []stream.Stream{
|
|
// Model calls two tools at once
|
|
newEventStream(message.StopToolUse, "",
|
|
stream.Event{Type: stream.EventToolCallStart, ToolCallID: "tc_1", ToolCallName: "bash"},
|
|
stream.Event{Type: stream.EventToolCallDone, ToolCallID: "tc_1", Args: json.RawMessage(`{"command":"ls"}`)},
|
|
stream.Event{Type: stream.EventToolCallStart, ToolCallID: "tc_2", ToolCallName: "fs.read"},
|
|
stream.Event{Type: stream.EventToolCallDone, ToolCallID: "tc_2", Args: json.RawMessage(`{"path":"go.mod"}`)},
|
|
),
|
|
newEventStream(message.StopEndTurn, "",
|
|
stream.Event{Type: stream.EventTextDelta, Text: "Done."},
|
|
),
|
|
},
|
|
}
|
|
|
|
e, _ := New(Config{Provider: mp, Tools: reg})
|
|
|
|
turn, err := e.Submit(context.Background(), "run both", nil)
|
|
if err != nil {
|
|
t.Fatalf("Submit: %v", err)
|
|
}
|
|
|
|
if turn.Rounds != 2 {
|
|
t.Errorf("Rounds = %d, want 2", turn.Rounds)
|
|
}
|
|
|
|
// Tool results message should have 2 results
|
|
toolMsg := turn.Messages[1] // assistant, tool_results, assistant
|
|
if len(toolMsg.Content) != 2 {
|
|
t.Errorf("tool results has %d content blocks, want 2", len(toolMsg.Content))
|
|
}
|
|
}
|
|
|
|
func TestSubmit_NilCallback(t *testing.T) {
|
|
mp := &mockProvider{
|
|
name: "test",
|
|
streams: []stream.Stream{
|
|
newEventStream(message.StopEndTurn, "",
|
|
stream.Event{Type: stream.EventTextDelta, Text: "ok"},
|
|
),
|
|
},
|
|
}
|
|
|
|
e, _ := New(Config{Provider: mp, Tools: tool.NewRegistry()})
|
|
|
|
// nil callback should not panic
|
|
turn, err := e.Submit(context.Background(), "test", nil)
|
|
if err != nil {
|
|
t.Fatalf("Submit: %v", err)
|
|
}
|
|
if turn.Rounds != 1 {
|
|
t.Errorf("Rounds = %d", turn.Rounds)
|
|
}
|
|
}
|
|
|
|
func TestEngine_Reset(t *testing.T) {
|
|
mp := &mockProvider{
|
|
name: "test",
|
|
streams: []stream.Stream{
|
|
newEventStream(message.StopEndTurn, "",
|
|
stream.Event{Type: stream.EventTextDelta, Text: "first"},
|
|
stream.Event{Type: stream.EventUsage, Usage: &message.Usage{InputTokens: 100}},
|
|
),
|
|
},
|
|
}
|
|
|
|
e, _ := New(Config{Provider: mp, Tools: tool.NewRegistry()})
|
|
e.Submit(context.Background(), "hello", nil)
|
|
|
|
if len(e.History()) == 0 {
|
|
t.Fatal("history should not be empty before reset")
|
|
}
|
|
if e.Usage().InputTokens == 0 {
|
|
t.Fatal("usage should not be zero before reset")
|
|
}
|
|
|
|
e.Reset()
|
|
|
|
if len(e.History()) != 0 {
|
|
t.Errorf("history should be empty after reset, got %d", len(e.History()))
|
|
}
|
|
if e.Usage().InputTokens != 0 {
|
|
t.Errorf("usage should be zero after reset, got %d", e.Usage().InputTokens)
|
|
}
|
|
}
|
|
|
|
func TestSubmit_CumulativeUsage(t *testing.T) {
|
|
mp := &mockProvider{
|
|
name: "test",
|
|
streams: []stream.Stream{
|
|
newEventStream(message.StopEndTurn, "",
|
|
stream.Event{Type: stream.EventUsage, Usage: &message.Usage{InputTokens: 100, OutputTokens: 50}},
|
|
stream.Event{Type: stream.EventTextDelta, Text: "first"},
|
|
),
|
|
newEventStream(message.StopEndTurn, "",
|
|
stream.Event{Type: stream.EventUsage, Usage: &message.Usage{InputTokens: 200, OutputTokens: 80}},
|
|
stream.Event{Type: stream.EventTextDelta, Text: "second"},
|
|
),
|
|
},
|
|
}
|
|
|
|
e, _ := New(Config{Provider: mp, Tools: tool.NewRegistry()})
|
|
|
|
e.Submit(context.Background(), "one", nil)
|
|
e.Submit(context.Background(), "two", nil)
|
|
|
|
if e.Usage().InputTokens != 300 {
|
|
t.Errorf("cumulative InputTokens = %d, want 300", e.Usage().InputTokens)
|
|
}
|
|
if e.Usage().OutputTokens != 130 {
|
|
t.Errorf("cumulative OutputTokens = %d, want 130", e.Usage().OutputTokens)
|
|
}
|
|
}
|