Files
gnoma/internal/engine/engine_test.go
vikingowl cb2d63d06f feat: Ollama/gemma4 compat — /init flow, stream filter, safety fixes
provider/openai:
- Fix doubled tool call args (argsComplete flag): Ollama sends complete
  args in the first streaming chunk then repeats them as delta, causing
  doubled JSON and 400 errors in elfs
- Handle fs: prefix (gemma4 uses fs:grep instead of fs.grep)
- Add Reasoning field support for Ollama thinking output

cmd/gnoma:
- Early TTY detection so logger is created with correct destination
  before any component gets a reference to it (fixes slog WARN bleed
  into TUI textarea)

permission:
- Exempt spawn_elfs and agent tools from safety scanner: elf prompt
  text may legitimately mention .env/.ssh/credentials patterns and
  should not be blocked

tui/app:
- /init retry chain: no-tool-calls → spawn_elfs nudge → write nudge
  (ask for plain text output) → TUI fallback write from streamBuf
- looksLikeAgentsMD + extractMarkdownDoc: validate and clean fallback
  content before writing (reject refusals, strip narrative preambles)
- Collapse thinking output to 3 lines; ctrl+o to expand (live stream
  and committed messages)
- Stream-level filter for model pseudo-tool-call blocks: suppresses
  <<tool_code>>...</tool_code>> and <<function_call>>...<tool_call|>
  from entering streamBuf across chunk boundaries
- sanitizeAssistantText regex covers both block formats
- Reset streamFilterClose at every turn start
2026-04-05 19:24:51 +02:00

580 lines
17 KiB
Go

package engine
import (
"context"
"encoding/json"
"errors"
"fmt"
"testing"
gnomactx "somegit.dev/Owlibou/gnoma/internal/context"
"somegit.dev/Owlibou/gnoma/internal/message"
"somegit.dev/Owlibou/gnoma/internal/provider"
"somegit.dev/Owlibou/gnoma/internal/stream"
"somegit.dev/Owlibou/gnoma/internal/tool"
)
// --- Mock Provider ---

// mockProvider is a test double whose Stream method hands out the
// pre-configured streams one at a time, in slice order.
type mockProvider struct {
	// name is the value reported by Name and used as the Provider of the
	// model listed by Models.
	name string
	// calls counts the streams handed out so far; it doubles as the index
	// of the next stream to return.
	calls int
	// streams holds one stream per expected Stream call, consumed in order.
	streams []stream.Stream
}
// Name reports the name this mock was configured with.
func (m *mockProvider) Name() string {
	return m.name
}
// DefaultModel always reports the fixed model ID "mock-model".
func (m *mockProvider) DefaultModel() string {
	const defaultModel = "mock-model"
	return defaultModel
}
// Models lists a single tool-capable model, "mock-model", attributed to this
// provider's name. It never returns an error.
func (m *mockProvider) Models(_ context.Context) ([]provider.ModelInfo, error) {
	info := provider.ModelInfo{
		ID:           "mock-model",
		Name:         "mock-model",
		Provider:     m.name,
		Capabilities: provider.Capabilities{ToolUse: true},
	}
	return []provider.ModelInfo{info}, nil
}
// Stream returns the next pre-configured stream and bumps the call counter.
// Once every configured stream has been handed out it returns an error
// instead, leaving the counter untouched.
func (m *mockProvider) Stream(_ context.Context, _ provider.Request) (stream.Stream, error) {
	if m.calls < len(m.streams) {
		next := m.streams[m.calls]
		m.calls++
		return next, nil
	}
	return nil, fmt.Errorf("mock: no more streams (called %d times)", m.calls+1)
}
// eventStream is a mock stream that replays a fixed slice of events.
type eventStream struct {
	// events is the full sequence to replay.
	events []stream.Event
	// idx is the number of events already consumed by Next; Current reads
	// the element at idx-1.
	idx int
	// stopReason and model are recorded by newEventStream; none of the
	// methods visible in this file read them back.
	stopReason message.StopReason
	model      string
}
// newEventStream builds an eventStream that replays events followed by one
// extra terminal event carrying stopReason and model.
//
// The events are copied into a fresh slice before the terminal event is
// added: when a caller spreads an existing slice (evts...), Go passes that
// slice through unchanged, so appending directly to the parameter could
// overwrite spare capacity in the caller's backing array.
func newEventStream(stopReason message.StopReason, model string, events ...stream.Event) *eventStream {
	replay := make([]stream.Event, 0, len(events)+1)
	replay = append(replay, events...)
	// Append a final event with stop reason
	replay = append(replay, stream.Event{
		Type:       stream.EventTextDelta,
		StopReason: stopReason,
		Model:      model,
	})
	return &eventStream{events: replay, stopReason: stopReason, model: model}
}
// Next advances to the following event and reports whether one was
// available. Once every event has been consumed it keeps returning false
// without moving the cursor.
func (s *eventStream) Next() bool {
	exhausted := s.idx >= len(s.events)
	if !exhausted {
		s.idx++
	}
	return !exhausted
}
// Current returns the event most recently advanced to by Next. It indexes
// events at idx-1, so calling it before the first Next panics.
func (s *eventStream) Current() stream.Event {
	last := s.idx - 1
	return s.events[last]
}
// Err always returns nil; this mock stream never reports a failure.
func (s *eventStream) Err() error {
	return nil
}
// Close is a no-op that always returns nil.
func (s *eventStream) Close() error {
	return nil
}
// --- Mock Tool ---

// mockTool is a configurable tool test double.
type mockTool struct {
	// name is returned by Name.
	name string
	// readOnly is returned by IsReadOnly.
	readOnly bool
	// execFn, when non-nil, is what Execute delegates to; when nil, Execute
	// falls back to a fixed result.
	execFn func(ctx context.Context, args json.RawMessage) (tool.Result, error)
}
// Name reports the name this mock tool was configured with.
func (m *mockTool) Name() string {
	return m.name
}
// Description always returns the fixed text "mock tool".
func (m *mockTool) Description() string {
	const description = "mock tool"
	return description
}
// Parameters returns a fixed JSON document declaring only an object type.
func (m *mockTool) Parameters() json.RawMessage {
	schema := json.RawMessage(`{"type":"object"}`)
	return schema
}
// IsReadOnly reports the readOnly flag this mock tool was configured with.
func (m *mockTool) IsReadOnly() bool {
	return m.readOnly
}
// IsDestructive always reports false for the mock tool.
func (m *mockTool) IsDestructive() bool {
	return false
}
// Execute delegates to execFn when one is configured, passing ctx and args
// through unchanged. Without an execFn it returns a result whose Output is
// "mock output" and a nil error.
func (m *mockTool) Execute(ctx context.Context, args json.RawMessage) (tool.Result, error) {
	if m.execFn == nil {
		return tool.Result{Output: "mock output"}, nil
	}
	return m.execFn(ctx, args)
}
// --- Tests ---
func TestNew_ValidConfig(t *testing.T) {
e, err := New(Config{
Provider: &mockProvider{name: "test"},
Tools: tool.NewRegistry(),
})
if err != nil {
t.Fatalf("New: %v", err)
}
if e == nil {
t.Fatal("engine should not be nil")
}
}
func TestNew_MissingProvider(t *testing.T) {
_, err := New(Config{Tools: tool.NewRegistry()})
if err == nil {
t.Fatal("expected error for missing provider")
}
}
func TestNew_MissingTools(t *testing.T) {
_, err := New(Config{Provider: &mockProvider{name: "test"}})
if err == nil {
t.Fatal("expected error for missing tool registry")
}
}
// TestSubmit_SimpleTextResponse feeds the engine one stream containing two
// text deltas and a usage event, then checks the returned turn (one round,
// one message with the concatenated text, input-token usage), the engine
// history (user followed by assistant), and that the callback received at
// least one event.
func TestSubmit_SimpleTextResponse(t *testing.T) {
	mp := &mockProvider{
		name: "test",
		streams: []stream.Stream{
			newEventStream(message.StopEndTurn, "test-model",
				stream.Event{Type: stream.EventTextDelta, Text: "Hello "},
				stream.Event{Type: stream.EventTextDelta, Text: "world!"},
				stream.Event{Type: stream.EventUsage, Usage: &message.Usage{InputTokens: 10, OutputTokens: 5}},
			),
		},
	}
	// Fail fast on a construction error rather than dereferencing a nil engine.
	e, err := New(Config{Provider: mp, Tools: tool.NewRegistry()})
	if err != nil {
		t.Fatalf("New: %v", err)
	}

	var events []stream.Event
	turn, err := e.Submit(context.Background(), "hi", func(evt stream.Event) {
		events = append(events, evt)
	})
	if err != nil {
		t.Fatalf("Submit: %v", err)
	}

	// Check turn
	if turn.Rounds != 1 {
		t.Errorf("Rounds = %d, want 1", turn.Rounds)
	}
	if len(turn.Messages) != 1 {
		t.Fatalf("len(Messages) = %d, want 1", len(turn.Messages))
	}
	if turn.Messages[0].TextContent() != "Hello world!" {
		t.Errorf("TextContent = %q", turn.Messages[0].TextContent())
	}
	if turn.Usage.InputTokens != 10 {
		t.Errorf("Usage.InputTokens = %d", turn.Usage.InputTokens)
	}

	// Check history
	history := e.History()
	if len(history) != 2 {
		t.Fatalf("len(History) = %d, want 2 (user + assistant)", len(history))
	}
	if history[0].Role != message.RoleUser {
		t.Errorf("History[0].Role = %q", history[0].Role)
	}
	if history[1].Role != message.RoleAssistant {
		t.Errorf("History[1].Role = %q", history[1].Role)
	}

	// Check events were forwarded
	if len(events) == 0 {
		t.Error("callback should have received events")
	}
}
// TestSubmit_ToolCallLoop scripts two provider rounds: the first ends in a
// "bash" tool call, the second in a final text answer. It checks the round
// count, the three messages of the turn (assistant tool call, tool results
// with user role, final assistant text), the four-entry history, and that
// the provider was called exactly twice.
func TestSubmit_ToolCallLoop(t *testing.T) {
	reg := tool.NewRegistry()
	reg.Register(&mockTool{
		name: "bash",
		execFn: func(_ context.Context, _ json.RawMessage) (tool.Result, error) {
			return tool.Result{Output: "file1.go\nfile2.go"}, nil
		},
	})

	mp := &mockProvider{
		name: "test",
		streams: []stream.Stream{
			// Round 1: model calls a tool
			newEventStream(message.StopToolUse, "model-1",
				stream.Event{Type: stream.EventTextDelta, Text: "Let me list files."},
				stream.Event{Type: stream.EventToolCallStart, ToolCallID: "tc_1", ToolCallName: "bash"},
				stream.Event{Type: stream.EventToolCallDone, ToolCallID: "tc_1", Args: json.RawMessage(`{"command":"ls"}`)},
			),
			// Round 2: model responds with final answer
			newEventStream(message.StopEndTurn, "model-1",
				stream.Event{Type: stream.EventTextDelta, Text: "Found file1.go and file2.go."},
			),
		},
	}
	// Fail fast on a construction error rather than dereferencing a nil engine.
	e, err := New(Config{Provider: mp, Tools: reg})
	if err != nil {
		t.Fatalf("New: %v", err)
	}

	turn, err := e.Submit(context.Background(), "list files", nil)
	if err != nil {
		t.Fatalf("Submit: %v", err)
	}
	if turn.Rounds != 2 {
		t.Errorf("Rounds = %d, want 2", turn.Rounds)
	}

	// Messages: assistant (tool call), tool results, assistant (final)
	if len(turn.Messages) != 3 {
		t.Fatalf("len(Messages) = %d, want 3", len(turn.Messages))
	}
	// First message has tool call
	if !turn.Messages[0].HasToolCalls() {
		t.Error("Messages[0] should have tool calls")
	}
	// Second message is tool results
	if turn.Messages[1].Role != message.RoleUser {
		t.Errorf("Messages[1].Role = %q, want user (tool results)", turn.Messages[1].Role)
	}
	// Third message is final text
	if turn.Messages[2].TextContent() != "Found file1.go and file2.go." {
		t.Errorf("Messages[2].TextContent = %q", turn.Messages[2].TextContent())
	}

	// History: user + assistant(tool call) + tool results + assistant(final)
	if len(e.History()) != 4 {
		t.Errorf("len(History) = %d, want 4", len(e.History()))
	}
	// Provider called twice
	if mp.calls != 2 {
		t.Errorf("provider called %d times, want 2", mp.calls)
	}
}
// TestSubmit_UnknownTool scripts a tool call to a name that is not in the
// (empty) registry, followed by a plain text round, and expects Submit to
// succeed with two rounds rather than fail.
func TestSubmit_UnknownTool(t *testing.T) {
	reg := tool.NewRegistry()
	// Don't register any tools

	mp := &mockProvider{
		name: "test",
		streams: []stream.Stream{
			// Model calls a tool that doesn't exist
			newEventStream(message.StopToolUse, "",
				stream.Event{Type: stream.EventToolCallStart, ToolCallID: "tc_1", ToolCallName: "nonexistent"},
				stream.Event{Type: stream.EventToolCallDone, ToolCallID: "tc_1", Args: json.RawMessage(`{}`)},
			),
			// Model responds after seeing error
			newEventStream(message.StopEndTurn, "",
				stream.Event{Type: stream.EventTextDelta, Text: "Sorry, that tool doesn't exist."},
			),
		},
	}
	// Fail fast on a construction error rather than dereferencing a nil engine.
	e, err := New(Config{Provider: mp, Tools: reg})
	if err != nil {
		t.Fatalf("New: %v", err)
	}

	turn, err := e.Submit(context.Background(), "do something", nil)
	if err != nil {
		t.Fatalf("Submit: %v", err)
	}
	// Should still complete — unknown tool returns error result, model sees it
	if turn.Rounds != 2 {
		t.Errorf("Rounds = %d, want 2", turn.Rounds)
	}
}
// TestSubmit_ToolExecutionError registers a tool whose Execute returns an
// error, scripts a call to it followed by a plain text round, and expects
// Submit to succeed with two rounds.
func TestSubmit_ToolExecutionError(t *testing.T) {
	reg := tool.NewRegistry()
	reg.Register(&mockTool{
		name: "failing",
		execFn: func(_ context.Context, _ json.RawMessage) (tool.Result, error) {
			return tool.Result{}, errors.New("disk full")
		},
	})

	mp := &mockProvider{
		name: "test",
		streams: []stream.Stream{
			newEventStream(message.StopToolUse, "",
				stream.Event{Type: stream.EventToolCallStart, ToolCallID: "tc_1", ToolCallName: "failing"},
				stream.Event{Type: stream.EventToolCallDone, ToolCallID: "tc_1", Args: json.RawMessage(`{}`)},
			),
			newEventStream(message.StopEndTurn, "",
				stream.Event{Type: stream.EventTextDelta, Text: "The tool failed."},
			),
		},
	}
	// Fail fast on a construction error rather than dereferencing a nil engine.
	e, err := New(Config{Provider: mp, Tools: reg})
	if err != nil {
		t.Fatalf("New: %v", err)
	}

	turn, err := e.Submit(context.Background(), "do it", nil)
	if err != nil {
		t.Fatalf("Submit: %v", err)
	}
	// Tool error is returned as error result, not a fatal error
	if turn.Rounds != 2 {
		t.Errorf("Rounds = %d, want 2", turn.Rounds)
	}
}
// TestSubmit_MaxTurnsLimit configures MaxTurns: 2 against a provider whose
// three scripted streams all end in a tool call. It expects Submit to return
// an error and the provider to have been called exactly twice.
func TestSubmit_MaxTurnsLimit(t *testing.T) {
	reg := tool.NewRegistry()
	reg.Register(&mockTool{name: "bash"})

	// Provider always returns tool calls — would loop forever
	mp := &mockProvider{
		name: "test",
		streams: []stream.Stream{
			newEventStream(message.StopToolUse, "",
				stream.Event{Type: stream.EventToolCallStart, ToolCallID: "tc_1", ToolCallName: "bash"},
				stream.Event{Type: stream.EventToolCallDone, ToolCallID: "tc_1", Args: json.RawMessage(`{}`)},
			),
			newEventStream(message.StopToolUse, "",
				stream.Event{Type: stream.EventToolCallStart, ToolCallID: "tc_2", ToolCallName: "bash"},
				stream.Event{Type: stream.EventToolCallDone, ToolCallID: "tc_2", Args: json.RawMessage(`{}`)},
			),
			newEventStream(message.StopToolUse, "",
				stream.Event{Type: stream.EventToolCallStart, ToolCallID: "tc_3", ToolCallName: "bash"},
				stream.Event{Type: stream.EventToolCallDone, ToolCallID: "tc_3", Args: json.RawMessage(`{}`)},
			),
		},
	}
	// Fail fast on a construction error rather than dereferencing a nil engine.
	e, err := New(Config{Provider: mp, Tools: reg, MaxTurns: 2})
	if err != nil {
		t.Fatalf("New: %v", err)
	}

	if _, err := e.Submit(context.Background(), "loop forever", nil); err == nil {
		t.Fatal("expected error from max turns limit")
	}
	if mp.calls != 2 {
		t.Errorf("provider called %d times, want 2 (limited)", mp.calls)
	}
}
// TestSubmit_MultipleToolCalls scripts one round containing two tool calls
// ("bash" and the read-only "fs.read") followed by a final text round. It
// expects two rounds and a tool-results message with two content blocks.
func TestSubmit_MultipleToolCalls(t *testing.T) {
	reg := tool.NewRegistry()
	reg.Register(&mockTool{
		name: "bash",
		execFn: func(_ context.Context, _ json.RawMessage) (tool.Result, error) {
			return tool.Result{Output: "bash output"}, nil
		},
	})
	reg.Register(&mockTool{
		name:     "fs.read",
		readOnly: true,
		execFn: func(_ context.Context, _ json.RawMessage) (tool.Result, error) {
			return tool.Result{Output: "file content"}, nil
		},
	})

	mp := &mockProvider{
		name: "test",
		streams: []stream.Stream{
			// Model calls two tools at once
			newEventStream(message.StopToolUse, "",
				stream.Event{Type: stream.EventToolCallStart, ToolCallID: "tc_1", ToolCallName: "bash"},
				stream.Event{Type: stream.EventToolCallDone, ToolCallID: "tc_1", Args: json.RawMessage(`{"command":"ls"}`)},
				stream.Event{Type: stream.EventToolCallStart, ToolCallID: "tc_2", ToolCallName: "fs.read"},
				stream.Event{Type: stream.EventToolCallDone, ToolCallID: "tc_2", Args: json.RawMessage(`{"path":"go.mod"}`)},
			),
			newEventStream(message.StopEndTurn, "",
				stream.Event{Type: stream.EventTextDelta, Text: "Done."},
			),
		},
	}
	// Fail fast on a construction error rather than dereferencing a nil engine.
	e, err := New(Config{Provider: mp, Tools: reg})
	if err != nil {
		t.Fatalf("New: %v", err)
	}

	turn, err := e.Submit(context.Background(), "run both", nil)
	if err != nil {
		t.Fatalf("Submit: %v", err)
	}
	if turn.Rounds != 2 {
		t.Errorf("Rounds = %d, want 2", turn.Rounds)
	}

	// Guard the index below so a short turn fails the test instead of
	// panicking with an out-of-range error.
	if len(turn.Messages) < 2 {
		t.Fatalf("len(Messages) = %d, want at least 2", len(turn.Messages))
	}
	// Tool results message should have 2 results
	toolMsg := turn.Messages[1] // assistant, tool_results, assistant
	if len(toolMsg.Content) != 2 {
		t.Errorf("tool results has %d content blocks, want 2", len(toolMsg.Content))
	}
}
// TestSubmit_NilCallback submits with a nil event callback and expects the
// call to complete without error in a single round.
func TestSubmit_NilCallback(t *testing.T) {
	mp := &mockProvider{
		name: "test",
		streams: []stream.Stream{
			newEventStream(message.StopEndTurn, "",
				stream.Event{Type: stream.EventTextDelta, Text: "ok"},
			),
		},
	}
	// Fail fast on a construction error rather than dereferencing a nil engine.
	e, err := New(Config{Provider: mp, Tools: tool.NewRegistry()})
	if err != nil {
		t.Fatalf("New: %v", err)
	}

	// nil callback should not panic
	turn, err := e.Submit(context.Background(), "test", nil)
	if err != nil {
		t.Fatalf("Submit: %v", err)
	}
	if turn.Rounds != 1 {
		t.Errorf("Rounds = %d", turn.Rounds)
	}
}
// TestEngine_Reset runs one submit so that history and input-token usage are
// non-empty, calls Reset, and then expects both to be back to zero.
func TestEngine_Reset(t *testing.T) {
	mp := &mockProvider{
		name: "test",
		streams: []stream.Stream{
			newEventStream(message.StopEndTurn, "",
				stream.Event{Type: stream.EventTextDelta, Text: "first"},
				stream.Event{Type: stream.EventUsage, Usage: &message.Usage{InputTokens: 100}},
			),
		},
	}
	// Fail fast on a construction error rather than dereferencing a nil engine.
	e, err := New(Config{Provider: mp, Tools: tool.NewRegistry()})
	if err != nil {
		t.Fatalf("New: %v", err)
	}
	// A failed submit would make the precondition failures below misleading.
	if _, err := e.Submit(context.Background(), "hello", nil); err != nil {
		t.Fatalf("Submit: %v", err)
	}

	if len(e.History()) == 0 {
		t.Fatal("history should not be empty before reset")
	}
	if e.Usage().InputTokens == 0 {
		t.Fatal("usage should not be zero before reset")
	}

	e.Reset()

	if len(e.History()) != 0 {
		t.Errorf("history should be empty after reset, got %d", len(e.History()))
	}
	if e.Usage().InputTokens != 0 {
		t.Errorf("usage should be zero after reset, got %d", e.Usage().InputTokens)
	}
}
// TestEngine_Reset_ClearsContextWindow attaches a context window to the
// engine, runs one submit so the window holds messages, calls Reset, and
// expects the window's Messages to be empty afterwards.
func TestEngine_Reset_ClearsContextWindow(t *testing.T) {
	ctxWindow := gnomactx.NewWindow(gnomactx.WindowConfig{MaxTokens: 200_000})
	mp := &mockProvider{
		name: "test",
		streams: []stream.Stream{
			newEventStream(message.StopEndTurn, "",
				stream.Event{Type: stream.EventTextDelta, Text: "hi"},
			),
		},
	}
	// Fail fast on a construction error rather than dereferencing a nil engine.
	e, err := New(Config{
		Provider: mp,
		Tools:    tool.NewRegistry(),
		Context:  ctxWindow,
	})
	if err != nil {
		t.Fatalf("New: %v", err)
	}
	// A failed submit would make the precondition failure below misleading.
	if _, err := e.Submit(context.Background(), "hello", nil); err != nil {
		t.Fatalf("Submit: %v", err)
	}

	if len(ctxWindow.Messages()) == 0 {
		t.Fatal("context window should have messages before reset")
	}

	e.Reset()

	if len(ctxWindow.Messages()) != 0 {
		t.Errorf("context window should be empty after reset, got %d messages", len(ctxWindow.Messages()))
	}
}
// TestSubmit_ContextWindowTracksUserAndToolMessages runs a tool-call round
// followed by a final text round with a context window attached. It expects
// the window's AllMessages to hold at least four messages and the first one
// to have the user role; on a short window it logs every message it found.
func TestSubmit_ContextWindowTracksUserAndToolMessages(t *testing.T) {
	reg := tool.NewRegistry()
	reg.Register(&mockTool{
		name: "bash",
		execFn: func(_ context.Context, _ json.RawMessage) (tool.Result, error) {
			return tool.Result{Output: "output"}, nil
		},
	})

	mp := &mockProvider{
		name: "test",
		streams: []stream.Stream{
			newEventStream(message.StopToolUse, "model",
				stream.Event{Type: stream.EventToolCallStart, ToolCallID: "tc1", ToolCallName: "bash"},
				stream.Event{Type: stream.EventToolCallDone, ToolCallID: "tc1", Args: json.RawMessage(`{"command":"ls"}`)},
				stream.Event{Type: stream.EventUsage, Usage: &message.Usage{InputTokens: 100, OutputTokens: 20}},
			),
			newEventStream(message.StopEndTurn, "model",
				stream.Event{Type: stream.EventTextDelta, Text: "Done."},
			),
		},
	}

	ctxWindow := gnomactx.NewWindow(gnomactx.WindowConfig{MaxTokens: 200_000})
	// Fail fast on a construction error rather than dereferencing a nil engine.
	e, err := New(Config{
		Provider: mp,
		Tools:    reg,
		Context:  ctxWindow,
	})
	if err != nil {
		t.Fatalf("New: %v", err)
	}

	if _, err := e.Submit(context.Background(), "list files", nil); err != nil {
		t.Fatalf("Submit: %v", err)
	}

	allMsgs := ctxWindow.AllMessages()
	// Expect: user msg, assistant (tool call), tool results, assistant (final)
	if len(allMsgs) < 4 {
		t.Errorf("context window has %d messages, want at least 4 (user+assistant+tool_results+assistant)", len(allMsgs))
		for i, m := range allMsgs {
			t.Logf("  [%d] role=%s content=%s", i, m.Role, m.TextContent())
		}
	}
	// First message should be user
	if len(allMsgs) > 0 && allMsgs[0].Role != message.RoleUser {
		t.Errorf("allMsgs[0].Role = %q, want user", allMsgs[0].Role)
	}
}
// TestSubmit_TrackerReflectsInputTokens runs a single round whose usage event
// reports InputTokens: 100 and OutputTokens: 50, then expects the context
// window's tracker to report exactly 150 used tokens.
//
// NOTE(review): the stated intent is that the tracker is set from the latest
// usage rather than accumulated across rounds, but only one round is driven
// here, so this test alone cannot tell those two behaviours apart.
func TestSubmit_TrackerReflectsInputTokens(t *testing.T) {
	ctxWindow := gnomactx.NewWindow(gnomactx.WindowConfig{MaxTokens: 200_000})

	mp := &mockProvider{
		name: "test",
		streams: []stream.Stream{
			newEventStream(message.StopEndTurn, "",
				stream.Event{Type: stream.EventUsage, Usage: &message.Usage{InputTokens: 100, OutputTokens: 50}},
				stream.Event{Type: stream.EventTextDelta, Text: "a"},
			),
		},
	}

	// Fail fast on a construction error rather than dereferencing a nil engine.
	e, err := New(Config{Provider: mp, Tools: tool.NewRegistry(), Context: ctxWindow})
	if err != nil {
		t.Fatalf("New: %v", err)
	}
	// A failed submit would otherwise show up only as a confusing tracker value.
	if _, err := e.Submit(context.Background(), "hi", nil); err != nil {
		t.Fatalf("Submit: %v", err)
	}

	// Tracker should be InputTokens + OutputTokens = 150, not more
	used := ctxWindow.Tracker().Used()
	if used != 150 {
		t.Errorf("tracker = %d, want 150 (InputTokens+OutputTokens, not cumulative)", used)
	}
}
// TestSubmit_CumulativeUsage submits twice against streams reporting
// 100/50 and 200/80 input/output tokens, then expects the engine's Usage to
// report the sums: 300 input tokens and 130 output tokens.
func TestSubmit_CumulativeUsage(t *testing.T) {
	mp := &mockProvider{
		name: "test",
		streams: []stream.Stream{
			newEventStream(message.StopEndTurn, "",
				stream.Event{Type: stream.EventUsage, Usage: &message.Usage{InputTokens: 100, OutputTokens: 50}},
				stream.Event{Type: stream.EventTextDelta, Text: "first"},
			),
			newEventStream(message.StopEndTurn, "",
				stream.Event{Type: stream.EventUsage, Usage: &message.Usage{InputTokens: 200, OutputTokens: 80}},
				stream.Event{Type: stream.EventTextDelta, Text: "second"},
			),
		},
	}

	// Fail fast on a construction error rather than dereferencing a nil engine.
	e, err := New(Config{Provider: mp, Tools: tool.NewRegistry()})
	if err != nil {
		t.Fatalf("New: %v", err)
	}

	// Surface a failed submit directly instead of as a wrong token total.
	for _, prompt := range []string{"one", "two"} {
		if _, err := e.Submit(context.Background(), prompt, nil); err != nil {
			t.Fatalf("Submit(%q): %v", prompt, err)
		}
	}

	if e.Usage().InputTokens != 300 {
		t.Errorf("cumulative InputTokens = %d, want 300", e.Usage().InputTokens)
	}
	if e.Usage().OutputTokens != 130 {
		t.Errorf("cumulative OutputTokens = %d, want 130", e.Usage().OutputTokens)
	}
}