Files
gnoma/internal/engine/engine_test.go
vikingowl f0633d8ac6 feat: complete M1 — core engine with Mistral provider
Mistral provider adapter with streaming, tool calls (single-chunk
pattern), stop reason inference, model listing, capabilities, and
JSON output support.

Tool system: bash (7 security checks, shell alias harvesting for
bash/zsh/fish), file ops (read, write, edit, glob, grep, ls).
Alias harvesting collects 300+ aliases from user's shell config.

Engine agentic loop: stream → tool execution → re-query → until
done. Tool gating on model capabilities. Max turns safety limit.

CLI pipe mode: echo "prompt" | gnoma streams response to stdout.
Flags: --provider, --model, --system, --api-key, --max-turns,
--verbose, --version.

Provider interface expanded: Models(), DefaultModel(), Capabilities
(ToolUse, JSONOutput, Vision, Thinking, ContextWindow, MaxOutput),
ResponseFormat with JSON schema support.

Live verified: text streaming + tool calling with devstral-small.
117 tests across 8 packages, 10MB binary.
2026-04-03 12:01:55 +02:00

476 lines
14 KiB
Go

package engine
import (
"context"
"encoding/json"
"errors"
"fmt"
"testing"
"somegit.dev/Owlibou/gnoma/internal/message"
"somegit.dev/Owlibou/gnoma/internal/provider"
"somegit.dev/Owlibou/gnoma/internal/stream"
"somegit.dev/Owlibou/gnoma/internal/tool"
)
// --- Mock Provider ---
// mockProvider is a scripted provider test double: each Stream call hands out
// the next pre-configured entry of streams.
type mockProvider struct {
// name is returned by Name and reported as the Provider of the mock model.
name string
// calls counts the streams handed out so far and doubles as the index of the
// next entry of streams to return.
calls int
streams []stream.Stream // one per call, consumed in order
}
// Name reports the name the mock was configured with.
func (m *mockProvider) Name() string {
	return m.name
}

// DefaultModel always reports the fixed identifier "mock-model".
func (m *mockProvider) DefaultModel() string {
	const defaultModel = "mock-model"
	return defaultModel
}
// Models returns a one-entry catalogue describing "mock-model", attributed to
// this provider and flagged as supporting tool use. The context is ignored and
// the returned error is always nil.
func (m *mockProvider) Models(_ context.Context) ([]provider.ModelInfo, error) {
	info := provider.ModelInfo{
		ID:           "mock-model",
		Name:         "mock-model",
		Provider:     m.name,
		Capabilities: provider.Capabilities{ToolUse: true},
	}
	return []provider.ModelInfo{info}, nil
}
// Stream hands out the next scripted stream and advances the call counter.
// The context and request are ignored. Once every scripted stream has been
// consumed it returns an error naming the ordinal of the surplus call; the
// counter is not advanced in that case.
func (m *mockProvider) Stream(_ context.Context, _ provider.Request) (stream.Stream, error) {
	idx := m.calls
	if idx >= len(m.streams) {
		return nil, fmt.Errorf("mock: no more streams (called %d times)", idx+1)
	}
	m.calls++
	return m.streams[idx], nil
}
// eventStream is a mock stream backed by a slice of events.
// Next advances through events and Current returns the most recently
// advanced-to event.
type eventStream struct {
// events is the full replay sequence, including the terminal event appended
// by newEventStream.
events []stream.Event
// idx is the number of events consumed so far; Current reads events[idx-1].
idx int
// stopReason and model are recorded by newEventStream; no method in this file
// reads them back.
stopReason message.StopReason
model string
}
// newEventStream builds an eventStream that replays events in order and then
// yields one extra terminal event carrying stopReason and model. The terminal
// event is typed EventTextDelta and leaves Text at its zero value.
//
// The events are copied into a freshly allocated slice before the terminal
// event is appended. Appending directly to the variadic parameter could write
// into a backing array owned by the caller when a slice with spare capacity is
// spread into the call (newEventStream(r, m, evts...)).
func newEventStream(stopReason message.StopReason, model string, events ...stream.Event) *eventStream {
	all := make([]stream.Event, 0, len(events)+1)
	all = append(all, events...)
	// Final event with the stop reason.
	all = append(all, stream.Event{
		Type:       stream.EventTextDelta,
		StopReason: stopReason,
		Model:      model,
	})
	return &eventStream{events: all, stopReason: stopReason, model: model}
}
// Next advances to the following event and reports whether one was available.
// After the last event it keeps returning false without moving the cursor.
func (s *eventStream) Next() bool {
	hasMore := s.idx < len(s.events)
	if hasMore {
		s.idx++
	}
	return hasMore
}
// Current returns the event most recently advanced to by Next. Calling it
// before the first successful Next indexes events at -1 and panics.
func (s *eventStream) Current() stream.Event {
	last := s.idx - 1
	return s.events[last]
}
// Err always returns nil; the mock stream never fails.
func (s *eventStream) Err() error {
	return nil
}

// Close is a no-op that always returns nil.
func (s *eventStream) Close() error {
	return nil
}
// --- Mock Tool ---
// mockTool is a configurable tool test double registered with the tool
// registry in the engine tests.
type mockTool struct {
// name is returned by Name.
name string
// readOnly is returned by IsReadOnly.
readOnly bool
// execFn, when non-nil, is invoked by Execute; when nil Execute returns a
// fixed "mock output" result.
execFn func(ctx context.Context, args json.RawMessage) (tool.Result, error)
}
// Name reports the name the mock tool was configured with.
func (m *mockTool) Name() string {
	return m.name
}

// Description always returns the fixed text "mock tool".
func (m *mockTool) Description() string {
	return "mock tool"
}

// Parameters returns a minimal JSON schema that describes a bare object.
func (m *mockTool) Parameters() json.RawMessage {
	const schema = `{"type":"object"}`
	return json.RawMessage(schema)
}

// IsReadOnly reports the configured readOnly flag.
func (m *mockTool) IsReadOnly() bool {
	return m.readOnly
}

// IsDestructive always reports false.
func (m *mockTool) IsDestructive() bool {
	return false
}
// Execute delegates to execFn when one is configured, forwarding ctx and args
// unchanged. Without an execFn it succeeds with the fixed output "mock output".
func (m *mockTool) Execute(ctx context.Context, args json.RawMessage) (tool.Result, error) {
	if m.execFn == nil {
		return tool.Result{Output: "mock output"}, nil
	}
	return m.execFn(ctx, args)
}
// --- Tests ---
func TestNew_ValidConfig(t *testing.T) {
e, err := New(Config{
Provider: &mockProvider{name: "test"},
Tools: tool.NewRegistry(),
})
if err != nil {
t.Fatalf("New: %v", err)
}
if e == nil {
t.Fatal("engine should not be nil")
}
}
func TestNew_MissingProvider(t *testing.T) {
_, err := New(Config{Tools: tool.NewRegistry()})
if err == nil {
t.Fatal("expected error for missing provider")
}
}
func TestNew_MissingTools(t *testing.T) {
_, err := New(Config{Provider: &mockProvider{name: "test"}})
if err == nil {
t.Fatal("expected error for missing tool registry")
}
}
// TestSubmit_SimpleTextResponse covers a single-round exchange: the provider
// streams two text deltas and a usage event, then stops with StopEndTurn.
// It checks the returned turn, the engine history and that stream events were
// forwarded to the callback.
func TestSubmit_SimpleTextResponse(t *testing.T) {
	mp := &mockProvider{
		name: "test",
		streams: []stream.Stream{
			newEventStream(message.StopEndTurn, "test-model",
				stream.Event{Type: stream.EventTextDelta, Text: "Hello "},
				stream.Event{Type: stream.EventTextDelta, Text: "world!"},
				stream.Event{Type: stream.EventUsage, Usage: &message.Usage{InputTokens: 10, OutputTokens: 5}},
			),
		},
	}
	// Fail with a clear message on construction errors rather than
	// dereferencing a nil engine below.
	e, err := New(Config{Provider: mp, Tools: tool.NewRegistry()})
	if err != nil {
		t.Fatalf("New: %v", err)
	}

	var events []stream.Event
	turn, err := e.Submit(context.Background(), "hi", func(evt stream.Event) {
		events = append(events, evt)
	})
	if err != nil {
		t.Fatalf("Submit: %v", err)
	}

	// Check turn
	if turn.Rounds != 1 {
		t.Errorf("Rounds = %d, want 1", turn.Rounds)
	}
	if len(turn.Messages) != 1 {
		t.Fatalf("len(Messages) = %d, want 1", len(turn.Messages))
	}
	if got, want := turn.Messages[0].TextContent(), "Hello world!"; got != want {
		t.Errorf("TextContent = %q, want %q", got, want)
	}
	if turn.Usage.InputTokens != 10 {
		t.Errorf("Usage.InputTokens = %d, want 10", turn.Usage.InputTokens)
	}

	// Check history
	history := e.History()
	if len(history) != 2 {
		t.Fatalf("len(History) = %d, want 2 (user + assistant)", len(history))
	}
	if history[0].Role != message.RoleUser {
		t.Errorf("History[0].Role = %q, want %q", history[0].Role, message.RoleUser)
	}
	if history[1].Role != message.RoleAssistant {
		t.Errorf("History[1].Role = %q, want %q", history[1].Role, message.RoleAssistant)
	}

	// Check events were forwarded
	if len(events) == 0 {
		t.Error("callback should have received events")
	}
}
// TestSubmit_ToolCallLoop covers one full tool round-trip: the first scripted
// stream stops with StopToolUse after requesting the "bash" tool, the second
// stops with StopEndTurn. It checks the round count, the shape of the turn's
// messages, the history length and the number of provider calls.
func TestSubmit_ToolCallLoop(t *testing.T) {
	reg := tool.NewRegistry()
	reg.Register(&mockTool{
		name: "bash",
		execFn: func(_ context.Context, _ json.RawMessage) (tool.Result, error) {
			return tool.Result{Output: "file1.go\nfile2.go"}, nil
		},
	})

	mp := &mockProvider{
		name: "test",
		streams: []stream.Stream{
			// Round 1: model calls a tool
			newEventStream(message.StopToolUse, "model-1",
				stream.Event{Type: stream.EventTextDelta, Text: "Let me list files."},
				stream.Event{Type: stream.EventToolCallStart, ToolCallID: "tc_1", ToolCallName: "bash"},
				stream.Event{Type: stream.EventToolCallDone, ToolCallID: "tc_1", Args: json.RawMessage(`{"command":"ls"}`)},
			),
			// Round 2: model responds with final answer
			newEventStream(message.StopEndTurn, "model-1",
				stream.Event{Type: stream.EventTextDelta, Text: "Found file1.go and file2.go."},
			),
		},
	}
	// Fail with a clear message on construction errors rather than
	// dereferencing a nil engine below.
	e, err := New(Config{Provider: mp, Tools: reg})
	if err != nil {
		t.Fatalf("New: %v", err)
	}

	turn, err := e.Submit(context.Background(), "list files", nil)
	if err != nil {
		t.Fatalf("Submit: %v", err)
	}
	if turn.Rounds != 2 {
		t.Errorf("Rounds = %d, want 2", turn.Rounds)
	}

	// Messages: assistant (tool call), tool results, assistant (final)
	if len(turn.Messages) != 3 {
		t.Fatalf("len(Messages) = %d, want 3", len(turn.Messages))
	}
	// First message has tool call
	if !turn.Messages[0].HasToolCalls() {
		t.Error("Messages[0] should have tool calls")
	}
	// Second message is tool results
	if turn.Messages[1].Role != message.RoleUser {
		t.Errorf("Messages[1].Role = %q, want user (tool results)", turn.Messages[1].Role)
	}
	// Third message is final text
	if got, want := turn.Messages[2].TextContent(), "Found file1.go and file2.go."; got != want {
		t.Errorf("Messages[2].TextContent = %q, want %q", got, want)
	}

	// History: user + assistant(tool call) + tool results + assistant(final)
	if len(e.History()) != 4 {
		t.Errorf("len(History) = %d, want 4", len(e.History()))
	}
	// Provider called twice
	if mp.calls != 2 {
		t.Errorf("provider called %d times, want 2", mp.calls)
	}
}
// TestSubmit_UnknownTool scripts a call to a tool that is not registered and
// checks that Submit still completes in two rounds instead of failing.
func TestSubmit_UnknownTool(t *testing.T) {
	// Deliberately empty registry: no tools are registered.
	reg := tool.NewRegistry()

	mp := &mockProvider{
		name: "test",
		streams: []stream.Stream{
			// Model calls a tool that doesn't exist
			newEventStream(message.StopToolUse, "",
				stream.Event{Type: stream.EventToolCallStart, ToolCallID: "tc_1", ToolCallName: "nonexistent"},
				stream.Event{Type: stream.EventToolCallDone, ToolCallID: "tc_1", Args: json.RawMessage(`{}`)},
			),
			// Model responds after seeing error
			newEventStream(message.StopEndTurn, "",
				stream.Event{Type: stream.EventTextDelta, Text: "Sorry, that tool doesn't exist."},
			),
		},
	}
	// Fail with a clear message on construction errors rather than
	// dereferencing a nil engine below.
	e, err := New(Config{Provider: mp, Tools: reg})
	if err != nil {
		t.Fatalf("New: %v", err)
	}

	turn, err := e.Submit(context.Background(), "do something", nil)
	if err != nil {
		t.Fatalf("Submit: %v", err)
	}
	// Should still complete — unknown tool returns error result, model sees it
	if turn.Rounds != 2 {
		t.Errorf("Rounds = %d, want 2", turn.Rounds)
	}
}
// TestSubmit_ToolExecutionError registers a tool whose Execute returns an
// error and checks that Submit still completes in two rounds, i.e. the tool
// failure is not surfaced as a fatal Submit error.
func TestSubmit_ToolExecutionError(t *testing.T) {
	reg := tool.NewRegistry()
	reg.Register(&mockTool{
		name: "failing",
		execFn: func(_ context.Context, _ json.RawMessage) (tool.Result, error) {
			return tool.Result{}, errors.New("disk full")
		},
	})

	mp := &mockProvider{
		name: "test",
		streams: []stream.Stream{
			newEventStream(message.StopToolUse, "",
				stream.Event{Type: stream.EventToolCallStart, ToolCallID: "tc_1", ToolCallName: "failing"},
				stream.Event{Type: stream.EventToolCallDone, ToolCallID: "tc_1", Args: json.RawMessage(`{}`)},
			),
			newEventStream(message.StopEndTurn, "",
				stream.Event{Type: stream.EventTextDelta, Text: "The tool failed."},
			),
		},
	}
	// Fail with a clear message on construction errors rather than
	// dereferencing a nil engine below.
	e, err := New(Config{Provider: mp, Tools: reg})
	if err != nil {
		t.Fatalf("New: %v", err)
	}

	turn, err := e.Submit(context.Background(), "do it", nil)
	if err != nil {
		t.Fatalf("Submit: %v", err)
	}
	// Tool error is returned as error result, not a fatal error
	if turn.Rounds != 2 {
		t.Errorf("Rounds = %d, want 2", turn.Rounds)
	}
}
// TestSubmit_MaxTurnsLimit scripts three consecutive tool-use rounds against
// an engine configured with MaxTurns: 2. It checks that Submit returns an
// error and that the provider was called only twice.
func TestSubmit_MaxTurnsLimit(t *testing.T) {
	reg := tool.NewRegistry()
	reg.Register(&mockTool{name: "bash"})

	// Provider always returns tool calls — would loop forever
	mp := &mockProvider{
		name: "test",
		streams: []stream.Stream{
			newEventStream(message.StopToolUse, "",
				stream.Event{Type: stream.EventToolCallStart, ToolCallID: "tc_1", ToolCallName: "bash"},
				stream.Event{Type: stream.EventToolCallDone, ToolCallID: "tc_1", Args: json.RawMessage(`{}`)},
			),
			newEventStream(message.StopToolUse, "",
				stream.Event{Type: stream.EventToolCallStart, ToolCallID: "tc_2", ToolCallName: "bash"},
				stream.Event{Type: stream.EventToolCallDone, ToolCallID: "tc_2", Args: json.RawMessage(`{}`)},
			),
			newEventStream(message.StopToolUse, "",
				stream.Event{Type: stream.EventToolCallStart, ToolCallID: "tc_3", ToolCallName: "bash"},
				stream.Event{Type: stream.EventToolCallDone, ToolCallID: "tc_3", Args: json.RawMessage(`{}`)},
			),
		},
	}
	// Fail with a clear message on construction errors rather than
	// dereferencing a nil engine below.
	e, err := New(Config{Provider: mp, Tools: reg, MaxTurns: 2})
	if err != nil {
		t.Fatalf("New: %v", err)
	}

	if _, err := e.Submit(context.Background(), "loop forever", nil); err == nil {
		t.Fatal("expected error from max turns limit")
	}
	if mp.calls != 2 {
		t.Errorf("provider called %d times, want 2 (limited)", mp.calls)
	}
}
// TestSubmit_MultipleToolCalls scripts a single round in which the model
// requests two tools ("bash" and "fs.read") and checks that the resulting
// tool-results message carries two content blocks.
func TestSubmit_MultipleToolCalls(t *testing.T) {
	reg := tool.NewRegistry()
	reg.Register(&mockTool{
		name: "bash",
		execFn: func(_ context.Context, _ json.RawMessage) (tool.Result, error) {
			return tool.Result{Output: "bash output"}, nil
		},
	})
	reg.Register(&mockTool{
		name:     "fs.read",
		readOnly: true,
		execFn: func(_ context.Context, _ json.RawMessage) (tool.Result, error) {
			return tool.Result{Output: "file content"}, nil
		},
	})

	mp := &mockProvider{
		name: "test",
		streams: []stream.Stream{
			// Model calls two tools at once
			newEventStream(message.StopToolUse, "",
				stream.Event{Type: stream.EventToolCallStart, ToolCallID: "tc_1", ToolCallName: "bash"},
				stream.Event{Type: stream.EventToolCallDone, ToolCallID: "tc_1", Args: json.RawMessage(`{"command":"ls"}`)},
				stream.Event{Type: stream.EventToolCallStart, ToolCallID: "tc_2", ToolCallName: "fs.read"},
				stream.Event{Type: stream.EventToolCallDone, ToolCallID: "tc_2", Args: json.RawMessage(`{"path":"go.mod"}`)},
			),
			newEventStream(message.StopEndTurn, "",
				stream.Event{Type: stream.EventTextDelta, Text: "Done."},
			),
		},
	}
	// Fail with a clear message on construction errors rather than
	// dereferencing a nil engine below.
	e, err := New(Config{Provider: mp, Tools: reg})
	if err != nil {
		t.Fatalf("New: %v", err)
	}

	turn, err := e.Submit(context.Background(), "run both", nil)
	if err != nil {
		t.Fatalf("Submit: %v", err)
	}
	if turn.Rounds != 2 {
		t.Errorf("Rounds = %d, want 2", turn.Rounds)
	}

	// Guard the index below so a short result fails the test instead of
	// panicking with an out-of-range access.
	if len(turn.Messages) < 2 {
		t.Fatalf("len(Messages) = %d, want at least 2", len(turn.Messages))
	}
	// Tool results message should have 2 results
	toolMsg := turn.Messages[1] // assistant, tool_results, assistant
	if len(toolMsg.Content) != 2 {
		t.Errorf("tool results has %d content blocks, want 2", len(toolMsg.Content))
	}
}
// TestSubmit_NilCallback checks that Submit accepts a nil event callback:
// a single text round must complete without error and report one round.
func TestSubmit_NilCallback(t *testing.T) {
	mp := &mockProvider{
		name: "test",
		streams: []stream.Stream{
			newEventStream(message.StopEndTurn, "",
				stream.Event{Type: stream.EventTextDelta, Text: "ok"},
			),
		},
	}
	// Fail with a clear message on construction errors rather than
	// dereferencing a nil engine below.
	e, err := New(Config{Provider: mp, Tools: tool.NewRegistry()})
	if err != nil {
		t.Fatalf("New: %v", err)
	}

	// nil callback should not panic
	turn, err := e.Submit(context.Background(), "test", nil)
	if err != nil {
		t.Fatalf("Submit: %v", err)
	}
	if turn.Rounds != 1 {
		t.Errorf("Rounds = %d, want 1", turn.Rounds)
	}
}
// TestEngine_Reset runs one exchange so that history and usage are populated,
// then checks that Reset empties the history and zeroes the input-token count.
func TestEngine_Reset(t *testing.T) {
	mp := &mockProvider{
		name: "test",
		streams: []stream.Stream{
			newEventStream(message.StopEndTurn, "",
				stream.Event{Type: stream.EventTextDelta, Text: "first"},
				stream.Event{Type: stream.EventUsage, Usage: &message.Usage{InputTokens: 100}},
			),
		},
	}
	// Fail with a clear message on construction errors rather than
	// dereferencing a nil engine below.
	e, err := New(Config{Provider: mp, Tools: tool.NewRegistry()})
	if err != nil {
		t.Fatalf("New: %v", err)
	}

	// Surface a failed setup exchange directly instead of letting it show up
	// as a confusing "history should not be empty" failure.
	if _, err := e.Submit(context.Background(), "hello", nil); err != nil {
		t.Fatalf("Submit: %v", err)
	}

	if len(e.History()) == 0 {
		t.Fatal("history should not be empty before reset")
	}
	if e.Usage().InputTokens == 0 {
		t.Fatal("usage should not be zero before reset")
	}

	e.Reset()

	if len(e.History()) != 0 {
		t.Errorf("history should be empty after reset, got %d", len(e.History()))
	}
	if e.Usage().InputTokens != 0 {
		t.Errorf("usage should be zero after reset, got %d", e.Usage().InputTokens)
	}
}
// TestSubmit_CumulativeUsage submits two prompts whose streams report
// 100/50 and 200/80 input/output tokens and checks that the engine's Usage
// totals add up to 300 and 130.
func TestSubmit_CumulativeUsage(t *testing.T) {
	mp := &mockProvider{
		name: "test",
		streams: []stream.Stream{
			newEventStream(message.StopEndTurn, "",
				stream.Event{Type: stream.EventUsage, Usage: &message.Usage{InputTokens: 100, OutputTokens: 50}},
				stream.Event{Type: stream.EventTextDelta, Text: "first"},
			),
			newEventStream(message.StopEndTurn, "",
				stream.Event{Type: stream.EventUsage, Usage: &message.Usage{InputTokens: 200, OutputTokens: 80}},
				stream.Event{Type: stream.EventTextDelta, Text: "second"},
			),
		},
	}
	// Fail with a clear message on construction errors rather than
	// dereferencing a nil engine below.
	e, err := New(Config{Provider: mp, Tools: tool.NewRegistry()})
	if err != nil {
		t.Fatalf("New: %v", err)
	}

	// Check each exchange so a failed Submit is reported as such rather than
	// as a wrong token total.
	for _, prompt := range []string{"one", "two"} {
		if _, err := e.Submit(context.Background(), prompt, nil); err != nil {
			t.Fatalf("Submit(%q): %v", prompt, err)
		}
	}

	if e.Usage().InputTokens != 300 {
		t.Errorf("cumulative InputTokens = %d, want 300", e.Usage().InputTokens)
	}
	if e.Usage().OutputTokens != 130 {
		t.Errorf("cumulative OutputTokens = %d, want 130", e.Usage().OutputTokens)
	}
}