diff --git a/cmd/gnoma/main.go b/cmd/gnoma/main.go index b4884c1..eeb8856 100644 --- a/cmd/gnoma/main.go +++ b/cmd/gnoma/main.go @@ -750,9 +750,10 @@ func main() { Model: *model, Temperature: cfg.Provider.Temperature, MaxTurns: *maxTurns, - Store: store, - Hooks: dispatcher, - Logger: logger, + Store: store, + Hooks: dispatcher, + Logger: logger, + ForceTwoStageTools: cfg.Router.ForceTwoStage, }) if err != nil { fmt.Fprintf(os.Stderr, "error: %v\n", err) diff --git a/docs/superpowers/plans/2026-05-19-post-slm-unlock.md b/docs/superpowers/plans/2026-05-19-post-slm-unlock.md index 84044ef..66e31fe 100644 --- a/docs/superpowers/plans/2026-05-19-post-slm-unlock.md +++ b/docs/superpowers/plans/2026-05-19-post-slm-unlock.md @@ -43,26 +43,49 @@ only that category's real schemas. ### Tasks -- [ ] `Category()` method on `tool.Tool` (or a registry-side mapping). +- [x] `Category()` method on `tool.Tool` (or a registry-side mapping). Default category for unspecified tools: `meta`. -- [ ] New `engine.useTwoStageTools(arm)` predicate. Gates on - `arm.IsLocal && arm.Capabilities.ContextWindow <= 16384`, with an - optional `[router].force_two_stage = true` config override. -- [ ] Synthetic `select_category` tool definition emitted in +- [x] New `engine.useTwoStageTools()` predicate. Gates on + `arm.IsLocal && arm.Capabilities.ContextWindow <= 16384`, with the + `[router].force_two_stage = true` config override. +- [x] Synthetic `select_category` tool definition emitted in `buildRequest()` when the predicate is true and we're in the first - round of a turn. -- [ ] Engine recognises a `select_category` tool result and filters + round of a turn. Round 1 also forces `ToolChoiceRequired` so SLMs + don't fall back to prose instead of calling the tool. +- [x] Engine recognises a `select_category` tool call and filters the next round's tool schemas. The selection itself doesn't run a - real tool — it's consumed internally. 
-- [ ] Integration test covering the round-trip with a mocked - openaicompat arm. + real tool — it's consumed internally; `select_category` remains + available in round 2+ so the model can switch categories mid-turn. +- [x] Integration test covering the round-trip with a recording + mock provider. -**Exit criteria:** for an SLM-arm turn with ≤16 k context, the first -request contains exactly one synthetic tool schema; the second -contains only the schemas of the selected category. Real tool -selection still works (the second-round tool call executes normally). +**Status: shipped.** Phase A landed via the two-stage routing commit +on `main`. Module map: +- `internal/tool/category.go` — `Category` type, optional + `Categorized` interface, `CategoryOf()` helper with `meta` fallback. +- Real tools now declare categories: `fs.read`/`fs.ls` → read, + `fs.write`/`fs.edit` → write, `fs.glob`/`fs.grep` → search, + `bash` → exec. Agent/sysinfo fall through to meta. +- `internal/engine/twostage.go` — synthetic tool definition, intercept + helper, per-turn state (`selectedCategory`). +- `internal/engine/engine.go` — `Config.ForceTwoStageTools`, + `useTwoStageTools()` predicate, `twoStageContextLimit = 16384`. +- `internal/config/config.go` — `[router].force_two_stage` TOML key + wired through `cmd/gnoma/main.go`. -**Effort:** ~150 LOC + tests. +**Exit criteria — met:** for an SLM-arm turn with ≤16 k context, the +first request contains exactly one synthetic tool schema with +`ToolChoiceRequired`; the second contains only schemas of the +selected category plus `select_category`. Real tool selection still +works end-to-end (verified by `TestTwoStage_FullRoundTrip`). + +**Effort:** ~250 LOC + tests (including 4 test files). + +**Deferred for follow-up (not Phase A blockers):** +- Elf engines spawned from `internal/elf/manager.go` don't pass + through `ForceTwoStageTools` — small local elves still get the full + tool catalogue. 
Add per-elf two-stage detection mirroring the main + engine's auto-activation when telemetry shows it's worth it. --- diff --git a/internal/config/config.go b/internal/config/config.go index 5b2d922..76582f5 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -11,6 +11,7 @@ type Config struct { Security SecuritySection `toml:"security"` Session SessionSection `toml:"session"` SLM SLMSection `toml:"slm"` + Router RouterSection `toml:"router"` Hooks []HookConfig `toml:"hooks"` MCPServers []MCPServerConfig `toml:"mcp_servers"` Plugins PluginsSection `toml:"plugins"` @@ -39,6 +40,17 @@ type SLMSection struct { StartupTimeout Duration `toml:"startup_timeout"` // llamafile-only: first-launch wait budget; 0 = default 5s } +// RouterSection holds router-level overrides. Most routing decisions are +// driven automatically by arm capabilities and the bandit; this section +// exists for the rare overrides that don't fit elsewhere. +type RouterSection struct { + // ForceTwoStage forces the two-stage tool-routing path regardless of + // arm context window. Useful for debugging or for forcing the behavior + // on a large local model. Defaults to false: two-stage activates + // automatically on local arms with context window <= 16k. + ForceTwoStage bool `toml:"force_two_stage"` +} + // MCPServerConfig defines an MCP server to start and connect to. // // Example: diff --git a/internal/engine/buildrequest_test.go b/internal/engine/buildrequest_test.go index b56418c..f45119b 100644 --- a/internal/engine/buildrequest_test.go +++ b/internal/engine/buildrequest_test.go @@ -95,7 +95,9 @@ func TestBuildRequest_ForcedArmWithToolSupport_IncludesTools(t *testing.T) { Provider: &mockProvider{name: "llamacpp"}, ModelName: "qwen3", IsLocal: true, - Capabilities: provider.Capabilities{ToolUse: true}, + // ContextWindow > 16384 keeps two-stage routing inactive so this + // test exercises the plain "tools included" path. 
+		Capabilities: provider.Capabilities{ToolUse: true, ContextWindow: 32768},
 	})
 	rtr.ForceArm("llamacpp/qwen3")
 
diff --git a/internal/engine/engine.go b/internal/engine/engine.go
index 25c42b6..eca316c 100644
--- a/internal/engine/engine.go
+++ b/internal/engine/engine.go
@@ -33,8 +33,18 @@ type Config struct {
 	Store  *persist.Store   // nil = no result persistence
 	Hooks  *hook.Dispatcher // nil = no hooks
 	Logger *slog.Logger
+
+	// ForceTwoStageTools switches two-stage tool routing on unconditionally;
+	// useTwoStageTools checks it before the router or arm. When false, two-stage
+	// turns on only for a forced local arm with ContextWindow <= twoStageContextLimit.
+	ForceTwoStageTools bool
 }
 
+// twoStageContextLimit is the upper bound on arm context window (in tokens)
+// under which two-stage tool routing kicks in automatically. Models bigger
+// than this can afford the full tool catalogue in every request.
+const twoStageContextLimit = 16384
+
 func (c Config) validate() error {
 	if c.Provider == nil {
 		return fmt.Errorf("engine: provider required")
@@ -77,6 +87,12 @@ type Engine struct {
 	activatedTools map[string]bool
 
 	turnOpts TurnOptions
+
+	// selectedCategory is set when the model picks a category via the
+	// synthetic select_category tool under two-stage routing. Empty string
+	// means "round 1 of two-stage" (or two-stage inactive). runLoop clears it
+	// on entry and again, deferred, on exit.
+	selectedCategory tool.Category
 }
 
 // ToolsAvailable reports whether the current model supports tool calling.
@@ -128,6 +144,40 @@ func (e *Engine) isLocalArm() bool {
 	return arm.IsLocal
 }
 
+// useTwoStageTools reports whether the current turn should use the two-stage
+// tool-routing path. True when:
+//   - cfg.ForceTwoStageTools is set (checked first, so it applies even with
+//     no router, no forced arm, or a cloud arm), OR
+//   - a forced arm exists, is local, and its ContextWindow is at most
+//     twoStageContextLimit (a non-positive, i.e. unknown, window counts as small).
+//
+// Without the override, multi-arm routing (no forced arm) and cloud arms do not trigger two-stage.
+func (e *Engine) useTwoStageTools() bool {
+	if e.cfg.ForceTwoStageTools {
+		return true
+	}
+	if e.cfg.Router == nil {
+		return false
+	}
+	id := e.cfg.Router.ForcedArm()
+	if id == "" {
+		return false
+	}
+	arm, ok := e.cfg.Router.LookupArm(id)
+	if !ok {
+		return false
+	}
+	if !arm.IsLocal {
+		return false
+	}
+	cw := arm.Capabilities.ContextWindow
+	if cw <= 0 {
+		// Unknown context window on a local arm — assume small.
+		return true
+	}
+	return cw <= twoStageContextLimit
+}
+
 // New creates an engine.
 func New(cfg Config) (*Engine, error) {
 	if err := cfg.validate(); err != nil {
diff --git a/internal/engine/loop.go b/internal/engine/loop.go
index a08cdc7..c520051 100644
--- a/internal/engine/loop.go
+++ b/internal/engine/loop.go
@@ -62,6 +62,11 @@ func (e *Engine) SubmitMessages(ctx context.Context, msgs []message.Message, cb
 }
 
 func (e *Engine) runLoop(ctx context.Context, cb Callback) (*Turn, error) {
+	// Two-stage tool-routing state is per-turn; clear it so an aborted turn
+	// can't leak its last category selection into the next one.
+	e.resetTwoStageState()
+	defer e.resetTwoStageState()
+
 	turn := &Turn{}
 	loopStart := time.Now()
 	var lastArmID router.ArmID
@@ -481,26 +486,51 @@ func (e *Engine) buildRequest(ctx context.Context) provider.Request {
 		includeTools = caps == nil || caps.ToolUse
 	}
 	if includeTools {
-		allowed := turnOpts.AllowedTools
-		for _, t := range e.cfg.Tools.All() {
-			// Skip deferred tools until the model requests them
-			if dt, ok := t.(tool.DeferrableTool); ok && dt.ShouldDefer() && !e.isToolActivated(t.Name()) {
-				continue
+		twoStage := e.useTwoStageTools()
+		selected := e.snapshotSelectedCategory()
+
+		if twoStage && selected == "" {
+			// Round 1 of two-stage: send only the synthetic select_category tool
+			// and force the model to call it. Small SLMs given a single optional
+			// tool will often emit prose instead of calling it.
+ req.Tools = []provider.ToolDefinition{buildSelectCategoryDef()} + req.ToolChoice = provider.ToolChoiceRequired + e.logger.Debug("two-stage: round 1 — emitting select_category only", + "model", req.Model, + ) + } else { + allowed := turnOpts.AllowedTools + for _, t := range e.cfg.Tools.All() { + // Skip deferred tools until the model requests them + if dt, ok := t.(tool.DeferrableTool); ok && dt.ShouldDefer() && !e.isToolActivated(t.Name()) { + continue + } + // Filter to allowed tools when a restrict list is set + if allowed != nil && !slices.Contains(allowed, t.Name()) { + continue + } + // Under two-stage round 2+, only schemas in the selected category. + if twoStage && tool.CategoryOf(t) != selected { + continue + } + req.Tools = append(req.Tools, provider.ToolDefinition{ + Name: t.Name(), + Description: t.Description(), + Parameters: t.Parameters(), + }) } - // Filter to allowed tools when a restrict list is set - if allowed != nil && !slices.Contains(allowed, t.Name()) { - continue + // Keep select_category available while two-stage is active so the + // model can switch categories without aborting the turn. + if twoStage { + req.Tools = append(req.Tools, buildSelectCategoryDef()) } - req.Tools = append(req.Tools, provider.ToolDefinition{ - Name: t.Name(), - Description: t.Description(), - Parameters: t.Parameters(), - }) + e.logger.Debug("tools included in request", + "model", req.Model, + "count", len(req.Tools), + "two_stage", twoStage, + "category", string(selected), + ) } - e.logger.Debug("tools included in request", - "model", req.Model, - "count", len(req.Tools), - ) } else { e.logger.Debug("tools omitted — model does not support tool use", "model", req.Model, @@ -542,6 +572,10 @@ Synthesis: } func (e *Engine) executeTools(ctx context.Context, calls []message.ToolCall, cb Callback) ([]message.ToolResult, error) { + // Intercept the synthetic select_category tool first — it never reaches + // the registry and produces its own synthetic tool result. 
+ calls, syntheticResults := e.interceptSelectCategoryCalls(calls) + // Partition into read-only (parallel) and write (serial) batches type toolCallWithTool struct { call message.ToolCall @@ -577,7 +611,8 @@ func (e *Engine) executeTools(ctx context.Context, calls []message.ToolCall, cb } } - results := make([]message.ToolResult, 0, len(calls)) + results := make([]message.ToolResult, 0, len(calls)+len(syntheticResults)) + results = append(results, syntheticResults...) results = append(results, unknownResults...) // Execute read-only tools in parallel diff --git a/internal/engine/twostage.go b/internal/engine/twostage.go new file mode 100644 index 0000000..8c2b3dd --- /dev/null +++ b/internal/engine/twostage.go @@ -0,0 +1,139 @@ +package engine + +import ( + "encoding/json" + "fmt" + "strings" + + "somegit.dev/Owlibou/gnoma/internal/message" + "somegit.dev/Owlibou/gnoma/internal/provider" + "somegit.dev/Owlibou/gnoma/internal/tool" +) + +// SyntheticSelectCategoryName is the tool name used by the two-stage routing +// path to let the model pick a category before real tool schemas are sent. +// The name is exported for tests that need to assert against it. +const SyntheticSelectCategoryName = "select_category" + +// buildSelectCategoryDef constructs the synthetic select_category tool +// definition. Categories in the enum match tool.AllCategories() so the +// schema stays in sync if categories are added. +func buildSelectCategoryDef() provider.ToolDefinition { + cats := tool.AllCategories() + quoted := make([]string, len(cats)) + for i, c := range cats { + quoted[i] = `"` + string(c) + `"` + } + params := json.RawMessage(`{ + "type": "object", + "properties": { + "category": { + "type": "string", + "enum": [` + strings.Join(quoted, ", ") + `], + "description": "Tool category to load schemas for. Pick one based on what you intend to do next." 
+ } + }, + "required": ["category"] +}`) + return provider.ToolDefinition{ + Name: SyntheticSelectCategoryName, + Description: "Select the category of tools to load for the next round. " + + "Use 'read' for file reads, 'write' to modify files, 'search' to search the codebase, " + + "'exec' to run commands, 'meta' for agent orchestration and introspection. " + + "You can call this again later to switch categories.", + Parameters: params, + } +} + +// snapshotSelectedCategory returns the category chosen by the model so far +// in this turn (or empty string if none). +func (e *Engine) snapshotSelectedCategory() tool.Category { + e.mu.Lock() + defer e.mu.Unlock() + return e.selectedCategory +} + +// setSelectedCategory records the model's category choice under lock. +func (e *Engine) setSelectedCategory(c tool.Category) { + e.mu.Lock() + e.selectedCategory = c + e.mu.Unlock() +} + +// resetTwoStageState clears any per-turn two-stage state. Called at the start +// of every runLoop so an aborted previous turn cannot leak state forward. +func (e *Engine) resetTwoStageState() { + e.mu.Lock() + e.selectedCategory = "" + e.mu.Unlock() +} + +// interceptSelectCategoryCalls splits the incoming tool calls into two +// buckets: real calls that need actual tool execution, and synthetic +// select_category calls that the engine handles internally. It updates +// e.selectedCategory as a side effect and returns synthetic tool results +// that satisfy the provider's "every tool_call needs a tool_result" contract. 
+func (e *Engine) interceptSelectCategoryCalls(calls []message.ToolCall) (realCalls []message.ToolCall, syntheticResults []message.ToolResult) {
+	for _, call := range calls {
+		if call.Name == SyntheticSelectCategoryName {
+			// Consumed by the engine; its result still answers the model's tool_call.
+			syntheticResults = append(syntheticResults, e.handleSelectCategory(call))
+			continue
+		}
+		realCalls = append(realCalls, call)
+	}
+	return realCalls, syntheticResults
+}
+
+// handleSelectCategory decodes a select_category call and records the chosen category. Malformed
+// or unknown input clears the selection and returns an IsError result listing the valid categories.
+func (e *Engine) handleSelectCategory(call message.ToolCall) message.ToolResult {
+	// Same source as the schema enum, so the hint cannot drift when categories change.
+	valid := make([]string, 0, len(tool.AllCategories()))
+	for _, c := range tool.AllCategories() {
+		valid = append(valid, string(c))
+	}
+	reject := func(msg string) message.ToolResult {
+		e.setSelectedCategory("") // empty selection makes buildRequest treat the next round as round 1
+		return message.ToolResult{ToolCallID: call.ID, Content: msg + strings.Join(valid, ", ") + ".", IsError: true}
+	}
+
+	var args struct {
+		Category string `json:"category"`
+	}
+	if err := json.Unmarshal(call.Arguments, &args); err != nil {
+		return reject(fmt.Sprintf("invalid arguments for select_category: %v. Please pick one of: ", err))
+	}
+	cat := tool.Category(strings.ToLower(strings.TrimSpace(args.Category)))
+	if !tool.IsValidCategory(cat) {
+		return reject(fmt.Sprintf("unknown category %q. Pick one of: ", args.Category))
+	}
+
+	e.setSelectedCategory(cat)
+	available := e.toolNamesForCategory(cat)
+	content := fmt.Sprintf("Category %q selected. Tools now available: %s. Call them directly on the next turn.",
+		cat, strings.Join(available, ", "))
+	if len(available) == 0 {
+		// An empty list would render as "Tools now available: ." — tell the model to re-select instead.
+		content = fmt.Sprintf("Category %q selected, but no tools are registered in it. Call select_category again to pick another category.", cat)
+	}
+	if e.logger != nil {
+		e.logger.Debug("two-stage: category selected",
+			"category", cat,
+			"tools", available,
+		)
+	}
+	return message.ToolResult{ToolCallID: call.ID, Content: content}
+}
+
+// toolNamesForCategory lists the registered tools whose category is cat, in e.cfg.Tools.All()
+// order. Unlike buildRequest it applies no deferred-tool or AllowedTools filter.
+func (e *Engine) toolNamesForCategory(cat tool.Category) []string {
+	var names []string // stays nil when no registered tool is in cat
+	for _, t := range e.cfg.Tools.All() {
+		if tool.CategoryOf(t) == cat {
+			names = append(names, t.Name())
+		}
+	}
+	return names
+}
diff --git a/internal/engine/twostage_integration_test.go b/internal/engine/twostage_integration_test.go
new file mode 100644
index 0000000..a365d14
--- /dev/null
+++ b/internal/engine/twostage_integration_test.go
@@ -0,0 +1,185 @@
+package engine
+
+import (
+	"context"
+	"encoding/json"
+	"fmt"
+	"slices"
+	"sync"
+	"testing"
+
+	"somegit.dev/Owlibou/gnoma/internal/message"
+	"somegit.dev/Owlibou/gnoma/internal/provider"
+	"somegit.dev/Owlibou/gnoma/internal/stream"
+	"somegit.dev/Owlibou/gnoma/internal/tool"
+)
+
+// recordingProvider is a test double: Stream records every Request it receives
+// and replays the pre-loaded streams in order, one per call.
+type recordingProvider struct {
+	mu       sync.Mutex         // held by Stream while it touches the fields below
+	requests []provider.Request // one captured copy per Stream call, in call order
+	streams  []stream.Stream    // scripted responses; Stream returns streams[calls]
+	calls    int                // number of scripted streams handed out so far
+}
+
+func (m *recordingProvider) Name() string         { return "recording" }
+func (m *recordingProvider) DefaultModel() string { return "mock-model" }
+func (m *recordingProvider) Models(_ context.Context) ([]provider.ModelInfo, error) {
+	return []provider.ModelInfo{{
+		ID: "mock-model", Name: "mock-model", Provider: "recording",
+		Capabilities: provider.Capabilities{ToolUse: true, ContextWindow: 8192},
+	}}, nil
+}
+func (m *recordingProvider) Stream(_ context.Context, req provider.Request) (stream.Stream, error) {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+	// Copy the request and clone its Tools slice so later writes to the caller's slice can't alter the capture.
+	clone := req
+	clone.Tools = slices.Clone(req.Tools)
+	m.requests = append(m.requests, clone)
+	if m.calls >= len(m.streams) { // script exhausted; the request was still recorded above
+		return nil, fmt.Errorf("recording: no more streams (called %d times)", m.calls+1)
+	}
+	s := m.streams[m.calls]
+	m.calls++
+	return s, nil
+}
+
+// singleToolCallStream returns a stream that emits a single tool call.
+func singleToolCallStream(callID, name, args string) stream.Stream { + return newEventStream(message.StopToolUse, "mock-model", + stream.Event{Type: stream.EventToolCallStart, ToolCallID: callID, ToolCallName: name}, + stream.Event{Type: stream.EventToolCallDone, ToolCallID: callID, ToolCallName: name, Args: json.RawMessage(args)}, + ) +} + +// endTurnTextStream returns a stream that emits text and ends the turn. +func endTurnTextStream(text string) stream.Stream { + return newEventStream(message.StopEndTurn, "mock-model", + stream.Event{Type: stream.EventTextDelta, Text: text}, + ) +} + +func TestTwoStage_FullRoundTrip(t *testing.T) { + // Tool registry: mixed categories so we can verify round 2 filtering. + writeCalled := false + reg := tool.NewRegistry() + reg.Register(&categorizedMockTool{mockTool: mockTool{name: "fs.read", readOnly: true}, cat: tool.CategoryRead}) + reg.Register(&categorizedMockTool{ + mockTool: mockTool{ + name: "fs.write", + execFn: func(_ context.Context, _ json.RawMessage) (tool.Result, error) { + writeCalled = true + return tool.Result{Output: "wrote file"}, nil + }, + }, + cat: tool.CategoryWrite, + }) + reg.Register(&categorizedMockTool{mockTool: mockTool{name: "bash"}, cat: tool.CategoryExec}) + + // Three rounds: select_category → fs.write → end turn. 
+ mp := &recordingProvider{ + streams: []stream.Stream{ + singleToolCallStream("c1", SyntheticSelectCategoryName, `{"category":"write"}`), + singleToolCallStream("c2", "fs.write", `{"path":"/tmp/x","content":"hi"}`), + endTurnTextStream("done."), + }, + } + + e, err := New(Config{ + Provider: mp, + Tools: reg, + ForceTwoStageTools: true, // no router needed; just force the path + }) + if err != nil { + t.Fatalf("New: %v", err) + } + + turn, err := e.Submit(context.Background(), "write a file", nil) + if err != nil { + t.Fatalf("Submit: %v", err) + } + + if turn.Rounds != 3 { + t.Errorf("Rounds = %d, want 3", turn.Rounds) + } + if !writeCalled { + t.Error("fs.write tool was not executed") + } + + if len(mp.requests) != 3 { + t.Fatalf("captured %d requests, want 3", len(mp.requests)) + } + + // Round 1: only synthetic select_category, ToolChoice = Required. + r1 := mp.requests[0] + if len(r1.Tools) != 1 { + t.Errorf("round 1 tool count = %d, want 1; tools=%v", len(r1.Tools), toolNamesIn(r1.Tools)) + } + if len(r1.Tools) >= 1 && r1.Tools[0].Name != SyntheticSelectCategoryName { + t.Errorf("round 1 tool[0] = %q, want %q", r1.Tools[0].Name, SyntheticSelectCategoryName) + } + if r1.ToolChoice != provider.ToolChoiceRequired { + t.Errorf("round 1 ToolChoice = %q, want %q", r1.ToolChoice, provider.ToolChoiceRequired) + } + + // Round 2: write tools + select_category, no read/exec. + r2names := toolNamesIn(mp.requests[1].Tools) + if !slices.Contains(r2names, "fs.write") { + t.Errorf("round 2 missing fs.write: %v", r2names) + } + if !slices.Contains(r2names, SyntheticSelectCategoryName) { + t.Errorf("round 2 missing select_category (re-selection should remain available): %v", r2names) + } + if slices.Contains(r2names, "fs.read") || slices.Contains(r2names, "bash") { + t.Errorf("round 2 leaked non-write tools: %v", r2names) + } + + // Round 3: same filter still applied (selection persists for the turn). 
+ r3names := toolNamesIn(mp.requests[2].Tools) + if slices.Contains(r3names, "fs.read") || slices.Contains(r3names, "bash") { + t.Errorf("round 3 leaked non-write tools: %v", r3names) + } +} + +func TestTwoStage_InvalidCategoryFallsBackToRoundOne(t *testing.T) { + reg := tool.NewRegistry() + reg.Register(&categorizedMockTool{mockTool: mockTool{name: "fs.write"}, cat: tool.CategoryWrite}) + + mp := &recordingProvider{ + streams: []stream.Stream{ + // Round 1: model picks an invalid category. + singleToolCallStream("c1", SyntheticSelectCategoryName, `{"category":"bogus"}`), + // Round 2: should see select_category-only again (round 1 mode). + singleToolCallStream("c2", SyntheticSelectCategoryName, `{"category":"write"}`), + // Round 3: filtered to write. + endTurnTextStream("done."), + }, + } + + e, err := New(Config{ + Provider: mp, + Tools: reg, + ForceTwoStageTools: true, + }) + if err != nil { + t.Fatalf("New: %v", err) + } + if _, err := e.Submit(context.Background(), "do stuff", nil); err != nil { + t.Fatalf("Submit: %v", err) + } + + if len(mp.requests) < 2 { + t.Fatalf("expected at least 2 requests, got %d", len(mp.requests)) + } + + // After invalid category, round 2 should be back to round-1 mode (only synthetic). 
+ r2 := mp.requests[1] + if len(r2.Tools) != 1 || r2.Tools[0].Name != SyntheticSelectCategoryName { + t.Errorf("after invalid category, expected round-1 mode again; got tools=%v", toolNamesIn(r2.Tools)) + } + if r2.ToolChoice != provider.ToolChoiceRequired { + t.Errorf("after invalid category, expected ToolChoiceRequired; got %q", r2.ToolChoice) + } +} diff --git a/internal/engine/twostage_test.go b/internal/engine/twostage_test.go new file mode 100644 index 0000000..4c4282c --- /dev/null +++ b/internal/engine/twostage_test.go @@ -0,0 +1,419 @@ +package engine + +import ( + "context" + "encoding/json" + "slices" + "sort" + "testing" + + "somegit.dev/Owlibou/gnoma/internal/message" + "somegit.dev/Owlibou/gnoma/internal/provider" + "somegit.dev/Owlibou/gnoma/internal/router" + "somegit.dev/Owlibou/gnoma/internal/tool" +) + +// categorizedMockTool extends mockTool with a Category() method so tests can +// exercise the two-stage category filter. +type categorizedMockTool struct { + mockTool + cat tool.Category +} + +func (c *categorizedMockTool) Category() tool.Category { return c.cat } + +// twoStageEngine builds an engine wired to a small local forced arm so +// useTwoStageTools() returns true. 
+func twoStageEngine(t *testing.T, reg *tool.Registry) *Engine { + t.Helper() + rtr := router.New(router.Config{}) + rtr.RegisterArm(&router.Arm{ + ID: "llamacpp/qwen3-1b", + Provider: &mockProvider{name: "llamacpp"}, + ModelName: "qwen3-1b", + IsLocal: true, + Capabilities: provider.Capabilities{ToolUse: true, ContextWindow: 8192}, + }) + rtr.ForceArm("llamacpp/qwen3-1b") + e, err := New(Config{ + Provider: &mockProvider{name: "llamacpp"}, + Router: rtr, + Tools: reg, + }) + if err != nil { + t.Fatalf("New: %v", err) + } + return e +} + +func TestUseTwoStageTools(t *testing.T) { + smallLocal := &router.Arm{ + ID: "llamacpp/qwen3-1b", + Provider: &mockProvider{name: "llamacpp"}, + ModelName: "qwen3-1b", + IsLocal: true, + Capabilities: provider.Capabilities{ToolUse: true, ContextWindow: 8192}, + } + bigLocal := &router.Arm{ + ID: "llamacpp/qwen3-30b", + Provider: &mockProvider{name: "llamacpp"}, + ModelName: "qwen3-30b", + IsLocal: true, + Capabilities: provider.Capabilities{ToolUse: true, ContextWindow: 32768}, + } + cloud := &router.Arm{ + ID: "anthropic/sonnet", + Provider: &mockProvider{name: "anthropic"}, + ModelName: "sonnet", + IsLocal: false, + Capabilities: provider.Capabilities{ToolUse: true, ContextWindow: 200000}, + } + localUnknownCtx := &router.Arm{ + ID: "ollama/mystery", + Provider: &mockProvider{name: "ollama"}, + ModelName: "mystery", + IsLocal: true, + Capabilities: provider.Capabilities{ToolUse: true, ContextWindow: 0}, + } + + cases := []struct { + name string + arm *router.Arm + forced bool + want bool + message string + }{ + { + name: "small local arm triggers two-stage", + arm: smallLocal, + want: true, + }, + { + name: "large local arm does not trigger two-stage", + arm: bigLocal, + want: false, + }, + { + name: "cloud arm never triggers two-stage", + arm: cloud, + want: false, + }, + { + name: "local arm with unknown context window triggers two-stage", + arm: localUnknownCtx, + want: true, + }, + { + name: "ForceTwoStageTools overrides 
cloud arm", + arm: cloud, + forced: true, + want: true, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + rtr := router.New(router.Config{}) + rtr.RegisterArm(tc.arm) + rtr.ForceArm(tc.arm.ID) + + e, err := New(Config{ + Provider: &mockProvider{name: string(tc.arm.ID.Provider())}, + Router: rtr, + Tools: tool.NewRegistry(), + ForceTwoStageTools: tc.forced, + }) + if err != nil { + t.Fatalf("New: %v", err) + } + + if got := e.useTwoStageTools(); got != tc.want { + t.Errorf("useTwoStageTools() = %v, want %v", got, tc.want) + } + }) + } +} + +func TestUseTwoStageTools_NoRouter(t *testing.T) { + e, err := New(Config{ + Provider: &mockProvider{name: "anthropic"}, + Tools: tool.NewRegistry(), + }) + if err != nil { + t.Fatalf("New: %v", err) + } + if e.useTwoStageTools() { + t.Error("useTwoStageTools() = true, want false when no router") + } +} + +func TestUseTwoStageTools_NoRouter_ForcedOverride(t *testing.T) { + e, err := New(Config{ + Provider: &mockProvider{name: "anthropic"}, + Tools: tool.NewRegistry(), + ForceTwoStageTools: true, + }) + if err != nil { + t.Fatalf("New: %v", err) + } + if !e.useTwoStageTools() { + t.Error("useTwoStageTools() = false, want true when ForceTwoStageTools is set even without router") + } +} + +func TestUseTwoStageTools_NoForcedArm(t *testing.T) { + rtr := router.New(router.Config{}) + rtr.RegisterArm(&router.Arm{ + ID: "llamacpp/qwen3-1b", + Provider: &mockProvider{name: "llamacpp"}, + ModelName: "qwen3-1b", + IsLocal: true, + Capabilities: provider.Capabilities{ToolUse: true, ContextWindow: 8192}, + }) + // No ForceArm called — multi-arm routing + e, err := New(Config{ + Provider: &mockProvider{name: "llamacpp"}, + Router: rtr, + Tools: tool.NewRegistry(), + }) + if err != nil { + t.Fatalf("New: %v", err) + } + if e.useTwoStageTools() { + t.Error("useTwoStageTools() = true, want false for multi-arm routing") + } +} + +func TestBuildRequest_TwoStage_Round1_EmitsSyntheticOnly(t *testing.T) { + reg := 
tool.NewRegistry() + reg.Register(&categorizedMockTool{mockTool: mockTool{name: "fs.read", readOnly: true}, cat: tool.CategoryRead}) + reg.Register(&categorizedMockTool{mockTool: mockTool{name: "fs.write"}, cat: tool.CategoryWrite}) + reg.Register(&categorizedMockTool{mockTool: mockTool{name: "bash"}, cat: tool.CategoryExec}) + + e := twoStageEngine(t, reg) + + req := e.buildRequest(context.Background()) + + if len(req.Tools) != 1 { + t.Fatalf("round 1 should emit exactly one tool (synthetic); got %d: %+v", len(req.Tools), req.Tools) + } + if req.Tools[0].Name != SyntheticSelectCategoryName { + t.Errorf("round 1 tool name = %q, want %q", req.Tools[0].Name, SyntheticSelectCategoryName) + } + if req.ToolChoice != provider.ToolChoiceRequired { + t.Errorf("round 1 ToolChoice = %q, want %q", req.ToolChoice, provider.ToolChoiceRequired) + } +} + +func TestBuildRequest_TwoStage_Round1_SyntheticEnumMatchesAllCategories(t *testing.T) { + reg := tool.NewRegistry() + reg.Register(&categorizedMockTool{mockTool: mockTool{name: "fs.read", readOnly: true}, cat: tool.CategoryRead}) + e := twoStageEngine(t, reg) + + req := e.buildRequest(context.Background()) + if len(req.Tools) != 1 { + t.Fatalf("expected 1 tool, got %d", len(req.Tools)) + } + + var schema struct { + Properties struct { + Category struct { + Enum []string `json:"enum"` + } `json:"category"` + } `json:"properties"` + } + if err := json.Unmarshal(req.Tools[0].Parameters, &schema); err != nil { + t.Fatalf("unmarshal synthetic params: %v", err) + } + want := make([]string, 0, len(tool.AllCategories())) + for _, c := range tool.AllCategories() { + want = append(want, string(c)) + } + got := slices.Clone(schema.Properties.Category.Enum) + sort.Strings(got) + sort.Strings(want) + if !slices.Equal(got, want) { + t.Errorf("category enum = %v, want %v", got, want) + } +} + +func TestBuildRequest_TwoStage_Round2_FiltersByCategory(t *testing.T) { + reg := tool.NewRegistry() + reg.Register(&categorizedMockTool{mockTool: 
mockTool{name: "fs.read", readOnly: true}, cat: tool.CategoryRead}) + reg.Register(&categorizedMockTool{mockTool: mockTool{name: "fs.write"}, cat: tool.CategoryWrite}) + reg.Register(&categorizedMockTool{mockTool: mockTool{name: "fs.edit"}, cat: tool.CategoryWrite}) + reg.Register(&categorizedMockTool{mockTool: mockTool{name: "bash"}, cat: tool.CategoryExec}) + + e := twoStageEngine(t, reg) + e.setSelectedCategory(tool.CategoryWrite) + + req := e.buildRequest(context.Background()) + + names := toolNamesIn(req.Tools) + sort.Strings(names) + want := []string{SyntheticSelectCategoryName, "fs.edit", "fs.write"} + sort.Strings(want) + if !slices.Equal(names, want) { + t.Errorf("round 2 tools = %v, want %v", names, want) + } + + // ToolChoice should not be forced in round 2 — let the model decide. + if req.ToolChoice == provider.ToolChoiceRequired { + t.Errorf("round 2 ToolChoice = %q, should not be Required", req.ToolChoice) + } +} + +func TestBuildRequest_TwoStage_UncategorizedToolDefaultsToMeta(t *testing.T) { + reg := tool.NewRegistry() + reg.Register(&mockTool{name: "agent"}) // no Category() method → meta + reg.Register(&categorizedMockTool{mockTool: mockTool{name: "fs.read", readOnly: true}, cat: tool.CategoryRead}) + + e := twoStageEngine(t, reg) + e.setSelectedCategory(tool.CategoryMeta) + + req := e.buildRequest(context.Background()) + names := toolNamesIn(req.Tools) + if !slices.Contains(names, "agent") { + t.Errorf("meta filter should include uncategorized 'agent'; got %v", names) + } + if slices.Contains(names, "fs.read") { + t.Errorf("meta filter should exclude 'fs.read'; got %v", names) + } +} + +func TestBuildRequest_NonTwoStage_UnchangedBehavior(t *testing.T) { + // Large local arm — two-stage should not activate. 
+ rtr := router.New(router.Config{}) + rtr.RegisterArm(&router.Arm{ + ID: "llamacpp/qwen3-30b", + Provider: &mockProvider{name: "llamacpp"}, + ModelName: "qwen3-30b", + IsLocal: true, + Capabilities: provider.Capabilities{ToolUse: true, ContextWindow: 32768}, + }) + rtr.ForceArm("llamacpp/qwen3-30b") + + reg := tool.NewRegistry() + reg.Register(&categorizedMockTool{mockTool: mockTool{name: "fs.read", readOnly: true}, cat: tool.CategoryRead}) + reg.Register(&categorizedMockTool{mockTool: mockTool{name: "fs.write"}, cat: tool.CategoryWrite}) + + e, err := New(Config{ + Provider: &mockProvider{name: "llamacpp"}, + Router: rtr, + Tools: reg, + }) + if err != nil { + t.Fatalf("New: %v", err) + } + + req := e.buildRequest(context.Background()) + if len(req.Tools) != 2 { + t.Errorf("non-two-stage path: got %d tools, want 2", len(req.Tools)) + } + for _, td := range req.Tools { + if td.Name == SyntheticSelectCategoryName { + t.Errorf("non-two-stage path should not emit synthetic select_category") + } + } +} + +func TestInterceptSelectCategoryCalls_UpdatesStateAndReturnsResult(t *testing.T) { + reg := tool.NewRegistry() + reg.Register(&categorizedMockTool{mockTool: mockTool{name: "fs.write"}, cat: tool.CategoryWrite}) + + e := twoStageEngine(t, reg) + + call := message.ToolCall{ + ID: "call-1", + Name: SyntheticSelectCategoryName, + Arguments: json.RawMessage(`{"category": "write"}`), + } + real, synth := e.interceptSelectCategoryCalls([]message.ToolCall{call}) + + if len(real) != 0 { + t.Errorf("real calls = %d, want 0", len(real)) + } + if len(synth) != 1 { + t.Fatalf("synthetic results = %d, want 1", len(synth)) + } + if synth[0].IsError { + t.Errorf("synthetic result should not be error: %s", synth[0].Content) + } + if synth[0].ToolCallID != "call-1" { + t.Errorf("ToolCallID = %q, want call-1", synth[0].ToolCallID) + } + if e.snapshotSelectedCategory() != tool.CategoryWrite { + t.Errorf("selectedCategory = %q, want write", e.snapshotSelectedCategory()) + } +} + +func 
TestInterceptSelectCategoryCalls_InvalidCategoryReturnsError(t *testing.T) { + e := twoStageEngine(t, tool.NewRegistry()) + + call := message.ToolCall{ + ID: "call-1", + Name: SyntheticSelectCategoryName, + Arguments: json.RawMessage(`{"category": "bogus"}`), + } + _, synth := e.interceptSelectCategoryCalls([]message.ToolCall{call}) + if len(synth) != 1 { + t.Fatalf("synthetic results = %d, want 1", len(synth)) + } + if !synth[0].IsError { + t.Errorf("invalid category should yield error result") + } + if e.snapshotSelectedCategory() != "" { + t.Errorf("invalid category should clear selectedCategory; got %q", e.snapshotSelectedCategory()) + } +} + +func TestInterceptSelectCategoryCalls_InvalidJSONReturnsError(t *testing.T) { + e := twoStageEngine(t, tool.NewRegistry()) + + call := message.ToolCall{ + ID: "call-1", + Name: SyntheticSelectCategoryName, + Arguments: json.RawMessage(`not-json`), + } + _, synth := e.interceptSelectCategoryCalls([]message.ToolCall{call}) + if len(synth) != 1 || !synth[0].IsError { + t.Fatalf("invalid JSON should yield single error result") + } +} + +func TestInterceptSelectCategoryCalls_MixedRealAndSynthetic(t *testing.T) { + e := twoStageEngine(t, tool.NewRegistry()) + + calls := []message.ToolCall{ + {ID: "real-1", Name: "fs.read", Arguments: json.RawMessage(`{}`)}, + {ID: "synth-1", Name: SyntheticSelectCategoryName, Arguments: json.RawMessage(`{"category":"read"}`)}, + {ID: "real-2", Name: "bash", Arguments: json.RawMessage(`{}`)}, + } + real, synth := e.interceptSelectCategoryCalls(calls) + if len(real) != 2 { + t.Errorf("real calls = %d, want 2", len(real)) + } + if len(synth) != 1 { + t.Errorf("synthetic results = %d, want 1", len(synth)) + } + if e.snapshotSelectedCategory() != tool.CategoryRead { + t.Errorf("selectedCategory = %q, want read", e.snapshotSelectedCategory()) + } +} + +func TestResetTwoStageState_ClearsCategory(t *testing.T) { + e := twoStageEngine(t, tool.NewRegistry()) + e.setSelectedCategory(tool.CategoryExec) + 
e.resetTwoStageState() + if e.snapshotSelectedCategory() != "" { + t.Errorf("resetTwoStageState should clear category; got %q", e.snapshotSelectedCategory()) + } +} + +func toolNamesIn(defs []provider.ToolDefinition) []string { + names := make([]string, len(defs)) + for i, d := range defs { + names[i] = d.Name + } + return names +} diff --git a/internal/tool/bash/bash.go b/internal/tool/bash/bash.go index 9d75047..91dab9a 100644 --- a/internal/tool/bash/bash.go +++ b/internal/tool/bash/bash.go @@ -66,6 +66,7 @@ func (t *Tool) Description() string { return "Execute a bash command and func (t *Tool) Parameters() json.RawMessage { return parameterSchema } func (t *Tool) IsReadOnly() bool { return false } func (t *Tool) IsDestructive() bool { return true } +func (t *Tool) Category() tool.Category { return tool.CategoryExec } type bashArgs struct { Command string `json:"command"` diff --git a/internal/tool/category.go b/internal/tool/category.go new file mode 100644 index 0000000..2321938 --- /dev/null +++ b/internal/tool/category.go @@ -0,0 +1,52 @@ +package tool + +// Category groups tools by what they do. Used by the two-stage tool routing +// path for small local models: round 1 picks a category, round 2 only sees +// schemas in that category. +type Category string + +const ( + // CategoryRead — read filesystem state (fs.read, fs.ls). + CategoryRead Category = "read" + // CategoryWrite — modify filesystem state (fs.write, fs.edit). + CategoryWrite Category = "write" + // CategorySearch — search filesystem content (fs.grep, fs.glob). + CategorySearch Category = "search" + // CategoryExec — execute external commands (bash). + CategoryExec Category = "exec" + // CategoryMeta — agent orchestration, introspection, and result handling. + // Default for tools that don't declare a category. + CategoryMeta Category = "meta" +) + +// AllCategories returns the canonical category list in stable order. 
+func AllCategories() []Category { + return []Category{CategoryRead, CategoryWrite, CategorySearch, CategoryExec, CategoryMeta} +} + +// IsValidCategory reports whether c is one of the known categories. +func IsValidCategory(c Category) bool { + switch c { + case CategoryRead, CategoryWrite, CategorySearch, CategoryExec, CategoryMeta: + return true + } + return false +} + +// Categorized is the optional interface a tool implements to declare its +// category. Tools that don't implement it fall back to CategoryMeta. +type Categorized interface { + Category() Category +} + +// CategoryOf returns the tool's declared category, or CategoryMeta if the +// tool does not implement Categorized. +func CategoryOf(t Tool) Category { + if c, ok := t.(Categorized); ok { + cat := c.Category() + if IsValidCategory(cat) { + return cat + } + } + return CategoryMeta +} diff --git a/internal/tool/category_test.go b/internal/tool/category_test.go new file mode 100644 index 0000000..1145268 --- /dev/null +++ b/internal/tool/category_test.go @@ -0,0 +1,70 @@ +package tool + +import ( + "slices" + "testing" +) + +type categorizedStub struct { + stubTool + cat Category +} + +func (c *categorizedStub) Category() Category { return c.cat } + +func TestCategoryOf_DefaultIsMeta(t *testing.T) { + plain := &stubTool{name: "plain"} + if got := CategoryOf(plain); got != CategoryMeta { + t.Errorf("CategoryOf(plain) = %q, want %q", got, CategoryMeta) + } +} + +func TestCategoryOf_DeclaredCategory(t *testing.T) { + cases := []struct { + name string + cat Category + }{ + {"read", CategoryRead}, + {"write", CategoryWrite}, + {"search", CategorySearch}, + {"exec", CategoryExec}, + {"meta", CategoryMeta}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + s := &categorizedStub{stubTool: stubTool{name: tc.name}, cat: tc.cat} + if got := CategoryOf(s); got != tc.cat { + t.Errorf("CategoryOf(%s) = %q, want %q", tc.name, got, tc.cat) + } + }) + } +} + +func 
TestCategoryOf_InvalidFallsBackToMeta(t *testing.T) { + s := &categorizedStub{stubTool: stubTool{name: "bogus"}, cat: Category("not-a-real-category")} + if got := CategoryOf(s); got != CategoryMeta { + t.Errorf("CategoryOf(invalid) = %q, want %q", got, CategoryMeta) + } +} + +func TestIsValidCategory(t *testing.T) { + for _, c := range AllCategories() { + if !IsValidCategory(c) { + t.Errorf("IsValidCategory(%q) = false, want true", c) + } + } + if IsValidCategory(Category("")) { + t.Error("empty category should be invalid") + } + if IsValidCategory(Category("nope")) { + t.Error("unknown category should be invalid") + } +} + +func TestAllCategoriesStable(t *testing.T) { + want := []Category{CategoryRead, CategoryWrite, CategorySearch, CategoryExec, CategoryMeta} + got := AllCategories() + if !slices.Equal(got, want) { + t.Errorf("AllCategories() = %v, want %v", got, want) + } +} diff --git a/internal/tool/fs/edit.go b/internal/tool/fs/edit.go index fb6d49d..1b89278 100644 --- a/internal/tool/fs/edit.go +++ b/internal/tool/fs/edit.go @@ -48,6 +48,7 @@ func (t *EditTool) Description() string { return "Perform exact string r func (t *EditTool) Parameters() json.RawMessage { return editParams } func (t *EditTool) IsReadOnly() bool { return false } func (t *EditTool) IsDestructive() bool { return false } +func (t *EditTool) Category() tool.Category { return tool.CategoryWrite } func (t *EditTool) ExtractPaths(args json.RawMessage) []string { var a editArgs diff --git a/internal/tool/fs/glob.go b/internal/tool/fs/glob.go index baa0c5a..0495896 100644 --- a/internal/tool/fs/glob.go +++ b/internal/tool/fs/glob.go @@ -43,6 +43,7 @@ func (t *GlobTool) Description() string { return "Find files matching a func (t *GlobTool) Parameters() json.RawMessage { return globParams } func (t *GlobTool) IsReadOnly() bool { return true } func (t *GlobTool) IsDestructive() bool { return false } +func (t *GlobTool) Category() tool.Category { return tool.CategorySearch } func (t *GlobTool) 
ExtractPaths(args json.RawMessage) []string { var a globArgs diff --git a/internal/tool/fs/grep.go b/internal/tool/fs/grep.go index 9fb46c2..0aad2cf 100644 --- a/internal/tool/fs/grep.go +++ b/internal/tool/fs/grep.go @@ -54,6 +54,7 @@ func (t *GrepTool) Description() string { return "Search file contents u func (t *GrepTool) Parameters() json.RawMessage { return grepParams } func (t *GrepTool) IsReadOnly() bool { return true } func (t *GrepTool) IsDestructive() bool { return false } +func (t *GrepTool) Category() tool.Category { return tool.CategorySearch } func (t *GrepTool) ExtractPaths(args json.RawMessage) []string { var a grepArgs diff --git a/internal/tool/fs/ls.go b/internal/tool/fs/ls.go index dc71954..9609b15 100644 --- a/internal/tool/fs/ls.go +++ b/internal/tool/fs/ls.go @@ -37,6 +37,7 @@ func (t *LSTool) Description() string { return "List directory contents func (t *LSTool) Parameters() json.RawMessage { return lsParams } func (t *LSTool) IsReadOnly() bool { return true } func (t *LSTool) IsDestructive() bool { return false } +func (t *LSTool) Category() tool.Category { return tool.CategoryRead } func (t *LSTool) ExtractPaths(args json.RawMessage) []string { var a lsArgs diff --git a/internal/tool/fs/read.go b/internal/tool/fs/read.go index 2e64bb8..665248f 100644 --- a/internal/tool/fs/read.go +++ b/internal/tool/fs/read.go @@ -55,11 +55,12 @@ func NewReadTool(opts ...ReadOption) *ReadTool { return t } -func (t *ReadTool) Name() string { return readToolName } -func (t *ReadTool) Description() string { return "Read a file from the filesystem with optional offset and line limit" } -func (t *ReadTool) Parameters() json.RawMessage { return readParams } -func (t *ReadTool) IsReadOnly() bool { return true } -func (t *ReadTool) IsDestructive() bool { return false } +func (t *ReadTool) Name() string { return readToolName } +func (t *ReadTool) Description() string { return "Read a file from the filesystem with optional offset and line limit" } +func (t 
*ReadTool) Parameters() json.RawMessage { return readParams } +func (t *ReadTool) IsReadOnly() bool { return true } +func (t *ReadTool) IsDestructive() bool { return false } +func (t *ReadTool) Category() tool.Category { return tool.CategoryRead } func (t *ReadTool) ExtractPaths(args json.RawMessage) []string { var a readArgs diff --git a/internal/tool/fs/write.go b/internal/tool/fs/write.go index 255920b..134d1fe 100644 --- a/internal/tool/fs/write.go +++ b/internal/tool/fs/write.go @@ -54,6 +54,7 @@ func (t *WriteTool) Description() string { return "Write content to a fi func (t *WriteTool) Parameters() json.RawMessage { return writeParams } func (t *WriteTool) IsReadOnly() bool { return false } func (t *WriteTool) IsDestructive() bool { return false } +func (t *WriteTool) Category() tool.Category { return tool.CategoryWrite } func (t *WriteTool) ExtractPaths(args json.RawMessage) []string { var a writeArgs