feat(engine): two-stage tool routing for small local arms

Plan A from docs/superpowers/plans/2026-05-19-post-slm-unlock.md.

Small local SLMs (<=16k context) waste ~1500 tokens per turn on the
full tool catalogue. Two-stage routing replaces round-1 tools with a
single synthetic select_category schema; round-2+ sends only the
selected category's real tool schemas plus select_category for
re-selection.

- internal/tool/category.go: Category type, optional Categorized
  interface, CategoryOf() with meta fallback. fs.read/fs.ls -> read,
  fs.write/fs.edit -> write, fs.glob/fs.grep -> search, bash -> exec.
- internal/engine/twostage.go: synthetic select_category tool,
  intercept helper, per-turn selectedCategory state under e.mu.
- Engine round 1 forces ToolChoiceRequired so SLMs don't fall back to
  prose. State resets at the top and end of every runLoop.
- Activates automatically on a forced local arm with ContextWindow
  <=16384, or via [router].force_two_stage TOML key.
- Integration test drives a 3-round trip and asserts: round 1 emits
  exactly one schema (synthetic) with ToolChoiceRequired, round 2
  contains only write-category schemas + select_category, real
  fs.write executes. Invalid-category fallback round-trips back to
  round-1 mode.
This commit is contained in:
2026-05-19 20:53:21 +02:00
parent 21da29e73e
commit 43ea2e562d
18 changed files with 1037 additions and 42 deletions
+4 -3
View File
@@ -750,9 +750,10 @@ func main() {
Model: *model,
Temperature: cfg.Provider.Temperature,
MaxTurns: *maxTurns,
Store: store,
Hooks: dispatcher,
Logger: logger,
Store: store,
Hooks: dispatcher,
Logger: logger,
ForceTwoStageTools: cfg.Router.ForceTwoStage,
})
if err != nil {
fmt.Fprintf(os.Stderr, "error: %v\n", err)
@@ -43,26 +43,49 @@ only that category's real schemas.
### Tasks
- [ ] `Category()` method on `tool.Tool` (or a registry-side mapping).
- [x] `Category()` method on `tool.Tool` (or a registry-side mapping).
Default category for unspecified tools: `meta`.
- [ ] New `engine.useTwoStageTools(arm)` predicate. Gates on
`arm.IsLocal && arm.Capabilities.ContextWindow <= 16384`, with an
optional `[router].force_two_stage = true` config override.
- [ ] Synthetic `select_category` tool definition emitted in
- [x] New `engine.useTwoStageTools()` predicate. Gates on
`arm.IsLocal && arm.Capabilities.ContextWindow <= 16384`, with the
`[router].force_two_stage = true` config override.
- [x] Synthetic `select_category` tool definition emitted in
`buildRequest()` when the predicate is true and we're in the first
round of a turn.
- [ ] Engine recognises a `select_category` tool result and filters
round of a turn. Round 1 also forces `ToolChoiceRequired` so SLMs
don't fall back to prose instead of calling the tool.
- [x] Engine recognises a `select_category` tool call and filters
the next round's tool schemas. The selection itself doesn't run a
real tool — it's consumed internally.
- [ ] Integration test covering the round-trip with a mocked
openaicompat arm.
real tool — it's consumed internally; `select_category` remains
available in round 2+ so the model can switch categories mid-turn.
- [x] Integration test covering the round-trip with a recording
mock provider.
**Exit criteria:** for an SLM-arm turn with ≤16 k context, the first
request contains exactly one synthetic tool schema; the second
contains only the schemas of the selected category. Real tool
selection still works (the second-round tool call executes normally).
**Status: shipped.** Phase A landed via the two-stage routing commit
on `main`. Module map:
- `internal/tool/category.go``Category` type, optional
`Categorized` interface, `CategoryOf()` helper with `meta` fallback.
- Real tools now declare categories: `fs.read`/`fs.ls` → read,
`fs.write`/`fs.edit` → write, `fs.glob`/`fs.grep` → search,
`bash` → exec. Agent/sysinfo fall through to meta.
- `internal/engine/twostage.go` — synthetic tool definition, intercept
helper, per-turn state (`selectedCategory`).
- `internal/engine/engine.go``Config.ForceTwoStageTools`,
`useTwoStageTools()` predicate, `twoStageContextLimit = 16384`.
- `internal/config/config.go``[router].force_two_stage` TOML key
wired through `cmd/gnoma/main.go`.
**Effort:** ~150 LOC + tests.
**Exit criteria — met:** for an SLM-arm turn with ≤16 k context, the
first request contains exactly one synthetic tool schema with
`ToolChoiceRequired`; the second contains only schemas of the
selected category plus `select_category`. Real tool selection still
works end-to-end (verified by `TestTwoStage_FullRoundTrip`).
**Effort:** ~250 LOC + tests (including 4 test files).
**Deferred for follow-up (not Phase A blockers):**
- Elf engines spawned from `internal/elf/manager.go` don't pass
through `ForceTwoStageTools` — small local elves still get the full
tool catalogue. Add per-elf two-stage detection mirroring the main
engine's auto-activation when telemetry shows it's worth it.
---
+12
View File
@@ -11,6 +11,7 @@ type Config struct {
Security SecuritySection `toml:"security"`
Session SessionSection `toml:"session"`
SLM SLMSection `toml:"slm"`
Router RouterSection `toml:"router"`
Hooks []HookConfig `toml:"hooks"`
MCPServers []MCPServerConfig `toml:"mcp_servers"`
Plugins PluginsSection `toml:"plugins"`
@@ -39,6 +40,17 @@ type SLMSection struct {
StartupTimeout Duration `toml:"startup_timeout"` // llamafile-only: first-launch wait budget; 0 = default 5s
}
// RouterSection holds router-level overrides. Most routing decisions are
// driven automatically by arm capabilities and the bandit; this section
// exists for the rare overrides that don't fit elsewhere.
type RouterSection struct {
// ForceTwoStage forces the two-stage tool-routing path regardless of
// arm context window. Useful for debugging or for forcing the behavior
// on a large local model. Defaults to false: two-stage activates
// automatically on local arms with context window <= 16k.
ForceTwoStage bool `toml:"force_two_stage"`
}
// MCPServerConfig defines an MCP server to start and connect to.
//
// Example:
+3 -1
View File
@@ -95,7 +95,9 @@ func TestBuildRequest_ForcedArmWithToolSupport_IncludesTools(t *testing.T) {
Provider: &mockProvider{name: "llamacpp"},
ModelName: "qwen3",
IsLocal: true,
Capabilities: provider.Capabilities{ToolUse: true},
// ContextWindow > 16384 keeps two-stage routing inactive so this
// test exercises the plain "tools included" path.
Capabilities: provider.Capabilities{ToolUse: true, ContextWindow: 32768},
})
rtr.ForceArm("llamacpp/qwen3")
+50
View File
@@ -33,8 +33,18 @@ type Config struct {
Store *persist.Store // nil = no result persistence
Hooks *hook.Dispatcher // nil = no hooks
Logger *slog.Logger
// ForceTwoStageTools forces the two-stage tool-routing path on the forced
// arm regardless of its context window. When false, two-stage is enabled
// automatically for local arms with ContextWindow <= twoStageContextLimit.
ForceTwoStageTools bool
}
// twoStageContextLimit is the upper bound on arm context window (in tokens)
// under which two-stage tool routing kicks in automatically. Models bigger
// than this can afford the full tool catalogue in every request.
const twoStageContextLimit = 16384
func (c Config) validate() error {
if c.Provider == nil {
return fmt.Errorf("engine: provider required")
@@ -77,6 +87,12 @@ type Engine struct {
activatedTools map[string]bool
turnOpts TurnOptions
// selectedCategory is set when the model picks a category via the
// synthetic select_category tool under two-stage routing. Empty string
// means "round 1 of two-stage" (or two-stage inactive). Reset at the end
// of each turn together with turnOpts.
selectedCategory tool.Category
}
// ToolsAvailable reports whether the current model supports tool calling.
@@ -128,6 +144,40 @@ func (e *Engine) isLocalArm() bool {
return arm.IsLocal
}
// useTwoStageTools reports whether the current turn should use the two-stage
// tool-routing path. True when:
// - cfg.ForceTwoStageTools is set, OR
// - a forced arm exists, is local, and its ContextWindow is small enough
// that the full tool catalogue would burn a non-trivial fraction of the
// prompt budget.
//
// Multi-arm routing (no forced arm) and cloud arms do not trigger two-stage.
func (e *Engine) useTwoStageTools() bool {
if e.cfg.ForceTwoStageTools {
return true
}
if e.cfg.Router == nil {
return false
}
id := e.cfg.Router.ForcedArm()
if id == "" {
return false
}
arm, ok := e.cfg.Router.LookupArm(id)
if !ok {
return false
}
if !arm.IsLocal {
return false
}
cw := arm.Capabilities.ContextWindow
if cw <= 0 {
// Unknown context window on a local arm — assume small.
return true
}
return cw <= twoStageContextLimit
}
// New creates an engine.
func New(cfg Config) (*Engine, error) {
if err := cfg.validate(); err != nil {
+53 -18
View File
@@ -62,6 +62,11 @@ func (e *Engine) SubmitMessages(ctx context.Context, msgs []message.Message, cb
}
func (e *Engine) runLoop(ctx context.Context, cb Callback) (*Turn, error) {
// Two-stage tool-routing state is per-turn; clear it so an aborted turn
// can't leak its last category selection into the next one.
e.resetTwoStageState()
defer e.resetTwoStageState()
turn := &Turn{}
loopStart := time.Now()
var lastArmID router.ArmID
@@ -481,26 +486,51 @@ func (e *Engine) buildRequest(ctx context.Context) provider.Request {
includeTools = caps == nil || caps.ToolUse
}
if includeTools {
allowed := turnOpts.AllowedTools
for _, t := range e.cfg.Tools.All() {
// Skip deferred tools until the model requests them
if dt, ok := t.(tool.DeferrableTool); ok && dt.ShouldDefer() && !e.isToolActivated(t.Name()) {
continue
twoStage := e.useTwoStageTools()
selected := e.snapshotSelectedCategory()
if twoStage && selected == "" {
// Round 1 of two-stage: send only the synthetic select_category tool
// and force the model to call it. Small SLMs given a single optional
// tool will often emit prose instead of calling it.
req.Tools = []provider.ToolDefinition{buildSelectCategoryDef()}
req.ToolChoice = provider.ToolChoiceRequired
e.logger.Debug("two-stage: round 1 — emitting select_category only",
"model", req.Model,
)
} else {
allowed := turnOpts.AllowedTools
for _, t := range e.cfg.Tools.All() {
// Skip deferred tools until the model requests them
if dt, ok := t.(tool.DeferrableTool); ok && dt.ShouldDefer() && !e.isToolActivated(t.Name()) {
continue
}
// Filter to allowed tools when a restrict list is set
if allowed != nil && !slices.Contains(allowed, t.Name()) {
continue
}
// Under two-stage round 2+, only schemas in the selected category.
if twoStage && tool.CategoryOf(t) != selected {
continue
}
req.Tools = append(req.Tools, provider.ToolDefinition{
Name: t.Name(),
Description: t.Description(),
Parameters: t.Parameters(),
})
}
// Filter to allowed tools when a restrict list is set
if allowed != nil && !slices.Contains(allowed, t.Name()) {
continue
// Keep select_category available while two-stage is active so the
// model can switch categories without aborting the turn.
if twoStage {
req.Tools = append(req.Tools, buildSelectCategoryDef())
}
req.Tools = append(req.Tools, provider.ToolDefinition{
Name: t.Name(),
Description: t.Description(),
Parameters: t.Parameters(),
})
e.logger.Debug("tools included in request",
"model", req.Model,
"count", len(req.Tools),
"two_stage", twoStage,
"category", string(selected),
)
}
e.logger.Debug("tools included in request",
"model", req.Model,
"count", len(req.Tools),
)
} else {
e.logger.Debug("tools omitted — model does not support tool use",
"model", req.Model,
@@ -542,6 +572,10 @@ Synthesis:
}
func (e *Engine) executeTools(ctx context.Context, calls []message.ToolCall, cb Callback) ([]message.ToolResult, error) {
// Intercept the synthetic select_category tool first — it never reaches
// the registry and produces its own synthetic tool result.
calls, syntheticResults := e.interceptSelectCategoryCalls(calls)
// Partition into read-only (parallel) and write (serial) batches
type toolCallWithTool struct {
call message.ToolCall
@@ -577,7 +611,8 @@ func (e *Engine) executeTools(ctx context.Context, calls []message.ToolCall, cb
}
}
results := make([]message.ToolResult, 0, len(calls))
results := make([]message.ToolResult, 0, len(calls)+len(syntheticResults))
results = append(results, syntheticResults...)
results = append(results, unknownResults...)
// Execute read-only tools in parallel
+139
View File
@@ -0,0 +1,139 @@
package engine
import (
"encoding/json"
"fmt"
"strings"
"somegit.dev/Owlibou/gnoma/internal/message"
"somegit.dev/Owlibou/gnoma/internal/provider"
"somegit.dev/Owlibou/gnoma/internal/tool"
)
// SyntheticSelectCategoryName is the tool name used by the two-stage routing
// path to let the model pick a category before real tool schemas are sent.
// The name is exported for tests that need to assert against it.
const SyntheticSelectCategoryName = "select_category"
// buildSelectCategoryDef constructs the synthetic select_category tool
// definition. Categories in the enum match tool.AllCategories() so the
// schema stays in sync if categories are added.
func buildSelectCategoryDef() provider.ToolDefinition {
cats := tool.AllCategories()
quoted := make([]string, len(cats))
for i, c := range cats {
quoted[i] = `"` + string(c) + `"`
}
params := json.RawMessage(`{
"type": "object",
"properties": {
"category": {
"type": "string",
"enum": [` + strings.Join(quoted, ", ") + `],
"description": "Tool category to load schemas for. Pick one based on what you intend to do next."
}
},
"required": ["category"]
}`)
return provider.ToolDefinition{
Name: SyntheticSelectCategoryName,
Description: "Select the category of tools to load for the next round. " +
"Use 'read' for file reads, 'write' to modify files, 'search' to search the codebase, " +
"'exec' to run commands, 'meta' for agent orchestration and introspection. " +
"You can call this again later to switch categories.",
Parameters: params,
}
}
// snapshotSelectedCategory returns the category chosen by the model so far
// in this turn (or empty string if none).
func (e *Engine) snapshotSelectedCategory() tool.Category {
e.mu.Lock()
defer e.mu.Unlock()
return e.selectedCategory
}
// setSelectedCategory records the model's category choice under lock.
func (e *Engine) setSelectedCategory(c tool.Category) {
e.mu.Lock()
e.selectedCategory = c
e.mu.Unlock()
}
// resetTwoStageState clears any per-turn two-stage state. Called at the start
// of every runLoop so an aborted previous turn cannot leak state forward.
func (e *Engine) resetTwoStageState() {
e.mu.Lock()
e.selectedCategory = ""
e.mu.Unlock()
}
// interceptSelectCategoryCalls splits the incoming tool calls into two
// buckets: real calls that need actual tool execution, and synthetic
// select_category calls that the engine handles internally. It updates
// e.selectedCategory as a side effect and returns synthetic tool results
// that satisfy the provider's "every tool_call needs a tool_result" contract.
func (e *Engine) interceptSelectCategoryCalls(calls []message.ToolCall) (realCalls []message.ToolCall, syntheticResults []message.ToolResult) {
for _, call := range calls {
if call.Name != SyntheticSelectCategoryName {
realCalls = append(realCalls, call)
continue
}
result := e.handleSelectCategory(call)
syntheticResults = append(syntheticResults, result)
}
return realCalls, syntheticResults
}
// handleSelectCategory parses the synthetic tool call, updates engine state,
// and returns the tool result the model will see in the next round.
func (e *Engine) handleSelectCategory(call message.ToolCall) message.ToolResult {
var args struct {
Category string `json:"category"`
}
if err := json.Unmarshal(call.Arguments, &args); err != nil {
e.setSelectedCategory("")
return message.ToolResult{
ToolCallID: call.ID,
Content: fmt.Sprintf("invalid arguments for select_category: %v. Please pick one of: read, write, search, exec, meta.", err),
IsError: true,
}
}
cat := tool.Category(strings.ToLower(strings.TrimSpace(args.Category)))
if !tool.IsValidCategory(cat) {
e.setSelectedCategory("")
return message.ToolResult{
ToolCallID: call.ID,
Content: fmt.Sprintf("unknown category %q. Pick one of: read, write, search, exec, meta.", args.Category),
IsError: true,
}
}
e.setSelectedCategory(cat)
available := e.toolNamesForCategory(cat)
content := fmt.Sprintf("Category %q selected. Tools now available: %s. Call them directly on the next turn.",
cat, strings.Join(available, ", "))
if e.logger != nil {
e.logger.Debug("two-stage: category selected",
"category", cat,
"tools", available,
)
}
return message.ToolResult{
ToolCallID: call.ID,
Content: content,
}
}
// toolNamesForCategory returns the registered tool names whose category
// matches the argument, in deterministic order.
func (e *Engine) toolNamesForCategory(cat tool.Category) []string {
var names []string
for _, t := range e.cfg.Tools.All() {
if tool.CategoryOf(t) == cat {
names = append(names, t.Name())
}
}
return names
}
@@ -0,0 +1,185 @@
package engine
import (
"context"
"encoding/json"
"fmt"
"slices"
"sync"
"testing"
"somegit.dev/Owlibou/gnoma/internal/message"
"somegit.dev/Owlibou/gnoma/internal/provider"
"somegit.dev/Owlibou/gnoma/internal/stream"
"somegit.dev/Owlibou/gnoma/internal/tool"
)
// recordingProvider captures every Request it receives so tests can assert
// on the tool catalogue per round.
type recordingProvider struct {
mu sync.Mutex
requests []provider.Request
streams []stream.Stream
calls int
}
func (m *recordingProvider) Name() string { return "recording" }
func (m *recordingProvider) DefaultModel() string { return "mock-model" }
func (m *recordingProvider) Models(_ context.Context) ([]provider.ModelInfo, error) {
return []provider.ModelInfo{{
ID: "mock-model", Name: "mock-model", Provider: "recording",
Capabilities: provider.Capabilities{ToolUse: true, ContextWindow: 8192},
}}, nil
}
func (m *recordingProvider) Stream(_ context.Context, req provider.Request) (stream.Stream, error) {
m.mu.Lock()
defer m.mu.Unlock()
// Clone request so subsequent rounds don't mutate captured state.
clone := req
clone.Tools = slices.Clone(req.Tools)
m.requests = append(m.requests, clone)
if m.calls >= len(m.streams) {
return nil, fmt.Errorf("recording: no more streams (called %d times)", m.calls+1)
}
s := m.streams[m.calls]
m.calls++
return s, nil
}
// singleToolCallStream returns a stream that emits a single tool call.
func singleToolCallStream(callID, name, args string) stream.Stream {
return newEventStream(message.StopToolUse, "mock-model",
stream.Event{Type: stream.EventToolCallStart, ToolCallID: callID, ToolCallName: name},
stream.Event{Type: stream.EventToolCallDone, ToolCallID: callID, ToolCallName: name, Args: json.RawMessage(args)},
)
}
// endTurnTextStream returns a stream that emits text and ends the turn.
func endTurnTextStream(text string) stream.Stream {
return newEventStream(message.StopEndTurn, "mock-model",
stream.Event{Type: stream.EventTextDelta, Text: text},
)
}
func TestTwoStage_FullRoundTrip(t *testing.T) {
// Tool registry: mixed categories so we can verify round 2 filtering.
writeCalled := false
reg := tool.NewRegistry()
reg.Register(&categorizedMockTool{mockTool: mockTool{name: "fs.read", readOnly: true}, cat: tool.CategoryRead})
reg.Register(&categorizedMockTool{
mockTool: mockTool{
name: "fs.write",
execFn: func(_ context.Context, _ json.RawMessage) (tool.Result, error) {
writeCalled = true
return tool.Result{Output: "wrote file"}, nil
},
},
cat: tool.CategoryWrite,
})
reg.Register(&categorizedMockTool{mockTool: mockTool{name: "bash"}, cat: tool.CategoryExec})
// Three rounds: select_category → fs.write → end turn.
mp := &recordingProvider{
streams: []stream.Stream{
singleToolCallStream("c1", SyntheticSelectCategoryName, `{"category":"write"}`),
singleToolCallStream("c2", "fs.write", `{"path":"/tmp/x","content":"hi"}`),
endTurnTextStream("done."),
},
}
e, err := New(Config{
Provider: mp,
Tools: reg,
ForceTwoStageTools: true, // no router needed; just force the path
})
if err != nil {
t.Fatalf("New: %v", err)
}
turn, err := e.Submit(context.Background(), "write a file", nil)
if err != nil {
t.Fatalf("Submit: %v", err)
}
if turn.Rounds != 3 {
t.Errorf("Rounds = %d, want 3", turn.Rounds)
}
if !writeCalled {
t.Error("fs.write tool was not executed")
}
if len(mp.requests) != 3 {
t.Fatalf("captured %d requests, want 3", len(mp.requests))
}
// Round 1: only synthetic select_category, ToolChoice = Required.
r1 := mp.requests[0]
if len(r1.Tools) != 1 {
t.Errorf("round 1 tool count = %d, want 1; tools=%v", len(r1.Tools), toolNamesIn(r1.Tools))
}
if len(r1.Tools) >= 1 && r1.Tools[0].Name != SyntheticSelectCategoryName {
t.Errorf("round 1 tool[0] = %q, want %q", r1.Tools[0].Name, SyntheticSelectCategoryName)
}
if r1.ToolChoice != provider.ToolChoiceRequired {
t.Errorf("round 1 ToolChoice = %q, want %q", r1.ToolChoice, provider.ToolChoiceRequired)
}
// Round 2: write tools + select_category, no read/exec.
r2names := toolNamesIn(mp.requests[1].Tools)
if !slices.Contains(r2names, "fs.write") {
t.Errorf("round 2 missing fs.write: %v", r2names)
}
if !slices.Contains(r2names, SyntheticSelectCategoryName) {
t.Errorf("round 2 missing select_category (re-selection should remain available): %v", r2names)
}
if slices.Contains(r2names, "fs.read") || slices.Contains(r2names, "bash") {
t.Errorf("round 2 leaked non-write tools: %v", r2names)
}
// Round 3: same filter still applied (selection persists for the turn).
r3names := toolNamesIn(mp.requests[2].Tools)
if slices.Contains(r3names, "fs.read") || slices.Contains(r3names, "bash") {
t.Errorf("round 3 leaked non-write tools: %v", r3names)
}
}
func TestTwoStage_InvalidCategoryFallsBackToRoundOne(t *testing.T) {
reg := tool.NewRegistry()
reg.Register(&categorizedMockTool{mockTool: mockTool{name: "fs.write"}, cat: tool.CategoryWrite})
mp := &recordingProvider{
streams: []stream.Stream{
// Round 1: model picks an invalid category.
singleToolCallStream("c1", SyntheticSelectCategoryName, `{"category":"bogus"}`),
// Round 2: should see select_category-only again (round 1 mode).
singleToolCallStream("c2", SyntheticSelectCategoryName, `{"category":"write"}`),
// Round 3: filtered to write.
endTurnTextStream("done."),
},
}
e, err := New(Config{
Provider: mp,
Tools: reg,
ForceTwoStageTools: true,
})
if err != nil {
t.Fatalf("New: %v", err)
}
if _, err := e.Submit(context.Background(), "do stuff", nil); err != nil {
t.Fatalf("Submit: %v", err)
}
if len(mp.requests) < 2 {
t.Fatalf("expected at least 2 requests, got %d", len(mp.requests))
}
// After invalid category, round 2 should be back to round-1 mode (only synthetic).
r2 := mp.requests[1]
if len(r2.Tools) != 1 || r2.Tools[0].Name != SyntheticSelectCategoryName {
t.Errorf("after invalid category, expected round-1 mode again; got tools=%v", toolNamesIn(r2.Tools))
}
if r2.ToolChoice != provider.ToolChoiceRequired {
t.Errorf("after invalid category, expected ToolChoiceRequired; got %q", r2.ToolChoice)
}
}
+419
View File
@@ -0,0 +1,419 @@
package engine
import (
"context"
"encoding/json"
"slices"
"sort"
"testing"
"somegit.dev/Owlibou/gnoma/internal/message"
"somegit.dev/Owlibou/gnoma/internal/provider"
"somegit.dev/Owlibou/gnoma/internal/router"
"somegit.dev/Owlibou/gnoma/internal/tool"
)
// categorizedMockTool extends mockTool with a Category() method so tests can
// exercise the two-stage category filter.
type categorizedMockTool struct {
mockTool
cat tool.Category
}
func (c *categorizedMockTool) Category() tool.Category { return c.cat }
// twoStageEngine builds an engine wired to a small local forced arm so
// useTwoStageTools() returns true.
func twoStageEngine(t *testing.T, reg *tool.Registry) *Engine {
t.Helper()
rtr := router.New(router.Config{})
rtr.RegisterArm(&router.Arm{
ID: "llamacpp/qwen3-1b",
Provider: &mockProvider{name: "llamacpp"},
ModelName: "qwen3-1b",
IsLocal: true,
Capabilities: provider.Capabilities{ToolUse: true, ContextWindow: 8192},
})
rtr.ForceArm("llamacpp/qwen3-1b")
e, err := New(Config{
Provider: &mockProvider{name: "llamacpp"},
Router: rtr,
Tools: reg,
})
if err != nil {
t.Fatalf("New: %v", err)
}
return e
}
func TestUseTwoStageTools(t *testing.T) {
smallLocal := &router.Arm{
ID: "llamacpp/qwen3-1b",
Provider: &mockProvider{name: "llamacpp"},
ModelName: "qwen3-1b",
IsLocal: true,
Capabilities: provider.Capabilities{ToolUse: true, ContextWindow: 8192},
}
bigLocal := &router.Arm{
ID: "llamacpp/qwen3-30b",
Provider: &mockProvider{name: "llamacpp"},
ModelName: "qwen3-30b",
IsLocal: true,
Capabilities: provider.Capabilities{ToolUse: true, ContextWindow: 32768},
}
cloud := &router.Arm{
ID: "anthropic/sonnet",
Provider: &mockProvider{name: "anthropic"},
ModelName: "sonnet",
IsLocal: false,
Capabilities: provider.Capabilities{ToolUse: true, ContextWindow: 200000},
}
localUnknownCtx := &router.Arm{
ID: "ollama/mystery",
Provider: &mockProvider{name: "ollama"},
ModelName: "mystery",
IsLocal: true,
Capabilities: provider.Capabilities{ToolUse: true, ContextWindow: 0},
}
cases := []struct {
name string
arm *router.Arm
forced bool
want bool
message string
}{
{
name: "small local arm triggers two-stage",
arm: smallLocal,
want: true,
},
{
name: "large local arm does not trigger two-stage",
arm: bigLocal,
want: false,
},
{
name: "cloud arm never triggers two-stage",
arm: cloud,
want: false,
},
{
name: "local arm with unknown context window triggers two-stage",
arm: localUnknownCtx,
want: true,
},
{
name: "ForceTwoStageTools overrides cloud arm",
arm: cloud,
forced: true,
want: true,
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
rtr := router.New(router.Config{})
rtr.RegisterArm(tc.arm)
rtr.ForceArm(tc.arm.ID)
e, err := New(Config{
Provider: &mockProvider{name: string(tc.arm.ID.Provider())},
Router: rtr,
Tools: tool.NewRegistry(),
ForceTwoStageTools: tc.forced,
})
if err != nil {
t.Fatalf("New: %v", err)
}
if got := e.useTwoStageTools(); got != tc.want {
t.Errorf("useTwoStageTools() = %v, want %v", got, tc.want)
}
})
}
}
func TestUseTwoStageTools_NoRouter(t *testing.T) {
e, err := New(Config{
Provider: &mockProvider{name: "anthropic"},
Tools: tool.NewRegistry(),
})
if err != nil {
t.Fatalf("New: %v", err)
}
if e.useTwoStageTools() {
t.Error("useTwoStageTools() = true, want false when no router")
}
}
func TestUseTwoStageTools_NoRouter_ForcedOverride(t *testing.T) {
e, err := New(Config{
Provider: &mockProvider{name: "anthropic"},
Tools: tool.NewRegistry(),
ForceTwoStageTools: true,
})
if err != nil {
t.Fatalf("New: %v", err)
}
if !e.useTwoStageTools() {
t.Error("useTwoStageTools() = false, want true when ForceTwoStageTools is set even without router")
}
}
func TestUseTwoStageTools_NoForcedArm(t *testing.T) {
rtr := router.New(router.Config{})
rtr.RegisterArm(&router.Arm{
ID: "llamacpp/qwen3-1b",
Provider: &mockProvider{name: "llamacpp"},
ModelName: "qwen3-1b",
IsLocal: true,
Capabilities: provider.Capabilities{ToolUse: true, ContextWindow: 8192},
})
// No ForceArm called — multi-arm routing
e, err := New(Config{
Provider: &mockProvider{name: "llamacpp"},
Router: rtr,
Tools: tool.NewRegistry(),
})
if err != nil {
t.Fatalf("New: %v", err)
}
if e.useTwoStageTools() {
t.Error("useTwoStageTools() = true, want false for multi-arm routing")
}
}
func TestBuildRequest_TwoStage_Round1_EmitsSyntheticOnly(t *testing.T) {
reg := tool.NewRegistry()
reg.Register(&categorizedMockTool{mockTool: mockTool{name: "fs.read", readOnly: true}, cat: tool.CategoryRead})
reg.Register(&categorizedMockTool{mockTool: mockTool{name: "fs.write"}, cat: tool.CategoryWrite})
reg.Register(&categorizedMockTool{mockTool: mockTool{name: "bash"}, cat: tool.CategoryExec})
e := twoStageEngine(t, reg)
req := e.buildRequest(context.Background())
if len(req.Tools) != 1 {
t.Fatalf("round 1 should emit exactly one tool (synthetic); got %d: %+v", len(req.Tools), req.Tools)
}
if req.Tools[0].Name != SyntheticSelectCategoryName {
t.Errorf("round 1 tool name = %q, want %q", req.Tools[0].Name, SyntheticSelectCategoryName)
}
if req.ToolChoice != provider.ToolChoiceRequired {
t.Errorf("round 1 ToolChoice = %q, want %q", req.ToolChoice, provider.ToolChoiceRequired)
}
}
func TestBuildRequest_TwoStage_Round1_SyntheticEnumMatchesAllCategories(t *testing.T) {
reg := tool.NewRegistry()
reg.Register(&categorizedMockTool{mockTool: mockTool{name: "fs.read", readOnly: true}, cat: tool.CategoryRead})
e := twoStageEngine(t, reg)
req := e.buildRequest(context.Background())
if len(req.Tools) != 1 {
t.Fatalf("expected 1 tool, got %d", len(req.Tools))
}
var schema struct {
Properties struct {
Category struct {
Enum []string `json:"enum"`
} `json:"category"`
} `json:"properties"`
}
if err := json.Unmarshal(req.Tools[0].Parameters, &schema); err != nil {
t.Fatalf("unmarshal synthetic params: %v", err)
}
want := make([]string, 0, len(tool.AllCategories()))
for _, c := range tool.AllCategories() {
want = append(want, string(c))
}
got := slices.Clone(schema.Properties.Category.Enum)
sort.Strings(got)
sort.Strings(want)
if !slices.Equal(got, want) {
t.Errorf("category enum = %v, want %v", got, want)
}
}
func TestBuildRequest_TwoStage_Round2_FiltersByCategory(t *testing.T) {
reg := tool.NewRegistry()
reg.Register(&categorizedMockTool{mockTool: mockTool{name: "fs.read", readOnly: true}, cat: tool.CategoryRead})
reg.Register(&categorizedMockTool{mockTool: mockTool{name: "fs.write"}, cat: tool.CategoryWrite})
reg.Register(&categorizedMockTool{mockTool: mockTool{name: "fs.edit"}, cat: tool.CategoryWrite})
reg.Register(&categorizedMockTool{mockTool: mockTool{name: "bash"}, cat: tool.CategoryExec})
e := twoStageEngine(t, reg)
e.setSelectedCategory(tool.CategoryWrite)
req := e.buildRequest(context.Background())
names := toolNamesIn(req.Tools)
sort.Strings(names)
want := []string{SyntheticSelectCategoryName, "fs.edit", "fs.write"}
sort.Strings(want)
if !slices.Equal(names, want) {
t.Errorf("round 2 tools = %v, want %v", names, want)
}
// ToolChoice should not be forced in round 2 — let the model decide.
if req.ToolChoice == provider.ToolChoiceRequired {
t.Errorf("round 2 ToolChoice = %q, should not be Required", req.ToolChoice)
}
}
func TestBuildRequest_TwoStage_UncategorizedToolDefaultsToMeta(t *testing.T) {
reg := tool.NewRegistry()
reg.Register(&mockTool{name: "agent"}) // no Category() method → meta
reg.Register(&categorizedMockTool{mockTool: mockTool{name: "fs.read", readOnly: true}, cat: tool.CategoryRead})
e := twoStageEngine(t, reg)
e.setSelectedCategory(tool.CategoryMeta)
req := e.buildRequest(context.Background())
names := toolNamesIn(req.Tools)
if !slices.Contains(names, "agent") {
t.Errorf("meta filter should include uncategorized 'agent'; got %v", names)
}
if slices.Contains(names, "fs.read") {
t.Errorf("meta filter should exclude 'fs.read'; got %v", names)
}
}
func TestBuildRequest_NonTwoStage_UnchangedBehavior(t *testing.T) {
// Large local arm — two-stage should not activate.
rtr := router.New(router.Config{})
rtr.RegisterArm(&router.Arm{
ID: "llamacpp/qwen3-30b",
Provider: &mockProvider{name: "llamacpp"},
ModelName: "qwen3-30b",
IsLocal: true,
Capabilities: provider.Capabilities{ToolUse: true, ContextWindow: 32768},
})
rtr.ForceArm("llamacpp/qwen3-30b")
reg := tool.NewRegistry()
reg.Register(&categorizedMockTool{mockTool: mockTool{name: "fs.read", readOnly: true}, cat: tool.CategoryRead})
reg.Register(&categorizedMockTool{mockTool: mockTool{name: "fs.write"}, cat: tool.CategoryWrite})
e, err := New(Config{
Provider: &mockProvider{name: "llamacpp"},
Router: rtr,
Tools: reg,
})
if err != nil {
t.Fatalf("New: %v", err)
}
req := e.buildRequest(context.Background())
if len(req.Tools) != 2 {
t.Errorf("non-two-stage path: got %d tools, want 2", len(req.Tools))
}
for _, td := range req.Tools {
if td.Name == SyntheticSelectCategoryName {
t.Errorf("non-two-stage path should not emit synthetic select_category")
}
}
}
func TestInterceptSelectCategoryCalls_UpdatesStateAndReturnsResult(t *testing.T) {
reg := tool.NewRegistry()
reg.Register(&categorizedMockTool{mockTool: mockTool{name: "fs.write"}, cat: tool.CategoryWrite})
e := twoStageEngine(t, reg)
call := message.ToolCall{
ID: "call-1",
Name: SyntheticSelectCategoryName,
Arguments: json.RawMessage(`{"category": "write"}`),
}
real, synth := e.interceptSelectCategoryCalls([]message.ToolCall{call})
if len(real) != 0 {
t.Errorf("real calls = %d, want 0", len(real))
}
if len(synth) != 1 {
t.Fatalf("synthetic results = %d, want 1", len(synth))
}
if synth[0].IsError {
t.Errorf("synthetic result should not be error: %s", synth[0].Content)
}
if synth[0].ToolCallID != "call-1" {
t.Errorf("ToolCallID = %q, want call-1", synth[0].ToolCallID)
}
if e.snapshotSelectedCategory() != tool.CategoryWrite {
t.Errorf("selectedCategory = %q, want write", e.snapshotSelectedCategory())
}
}
func TestInterceptSelectCategoryCalls_InvalidCategoryReturnsError(t *testing.T) {
e := twoStageEngine(t, tool.NewRegistry())
call := message.ToolCall{
ID: "call-1",
Name: SyntheticSelectCategoryName,
Arguments: json.RawMessage(`{"category": "bogus"}`),
}
_, synth := e.interceptSelectCategoryCalls([]message.ToolCall{call})
if len(synth) != 1 {
t.Fatalf("synthetic results = %d, want 1", len(synth))
}
if !synth[0].IsError {
t.Errorf("invalid category should yield error result")
}
if e.snapshotSelectedCategory() != "" {
t.Errorf("invalid category should clear selectedCategory; got %q", e.snapshotSelectedCategory())
}
}
func TestInterceptSelectCategoryCalls_InvalidJSONReturnsError(t *testing.T) {
e := twoStageEngine(t, tool.NewRegistry())
call := message.ToolCall{
ID: "call-1",
Name: SyntheticSelectCategoryName,
Arguments: json.RawMessage(`not-json`),
}
_, synth := e.interceptSelectCategoryCalls([]message.ToolCall{call})
if len(synth) != 1 || !synth[0].IsError {
t.Fatalf("invalid JSON should yield single error result")
}
}
func TestInterceptSelectCategoryCalls_MixedRealAndSynthetic(t *testing.T) {
e := twoStageEngine(t, tool.NewRegistry())
calls := []message.ToolCall{
{ID: "real-1", Name: "fs.read", Arguments: json.RawMessage(`{}`)},
{ID: "synth-1", Name: SyntheticSelectCategoryName, Arguments: json.RawMessage(`{"category":"read"}`)},
{ID: "real-2", Name: "bash", Arguments: json.RawMessage(`{}`)},
}
real, synth := e.interceptSelectCategoryCalls(calls)
if len(real) != 2 {
t.Errorf("real calls = %d, want 2", len(real))
}
if len(synth) != 1 {
t.Errorf("synthetic results = %d, want 1", len(synth))
}
if e.snapshotSelectedCategory() != tool.CategoryRead {
t.Errorf("selectedCategory = %q, want read", e.snapshotSelectedCategory())
}
}
func TestResetTwoStageState_ClearsCategory(t *testing.T) {
e := twoStageEngine(t, tool.NewRegistry())
e.setSelectedCategory(tool.CategoryExec)
e.resetTwoStageState()
if e.snapshotSelectedCategory() != "" {
t.Errorf("resetTwoStageState should clear category; got %q", e.snapshotSelectedCategory())
}
}
func toolNamesIn(defs []provider.ToolDefinition) []string {
names := make([]string, len(defs))
for i, d := range defs {
names[i] = d.Name
}
return names
}
+1
View File
@@ -66,6 +66,7 @@ func (t *Tool) Description() string { return "Execute a bash command and
func (t *Tool) Parameters() json.RawMessage { return parameterSchema }
func (t *Tool) IsReadOnly() bool { return false }
func (t *Tool) IsDestructive() bool { return true }
func (t *Tool) Category() tool.Category { return tool.CategoryExec }
type bashArgs struct {
Command string `json:"command"`
+52
View File
@@ -0,0 +1,52 @@
package tool
// Category groups tools by what they do. Used by the two-stage tool routing
// path for small local models: round 1 picks a category, round 2 only sees
// schemas in that category.
type Category string
const (
// CategoryRead — read filesystem state (fs.read, fs.ls).
CategoryRead Category = "read"
// CategoryWrite — modify filesystem state (fs.write, fs.edit).
CategoryWrite Category = "write"
// CategorySearch — search filesystem content (fs.grep, fs.glob).
CategorySearch Category = "search"
// CategoryExec — execute external commands (bash).
CategoryExec Category = "exec"
// CategoryMeta — agent orchestration, introspection, and result handling.
// Default for tools that don't declare a category.
CategoryMeta Category = "meta"
)
// AllCategories returns the canonical category list in stable order.
func AllCategories() []Category {
return []Category{CategoryRead, CategoryWrite, CategorySearch, CategoryExec, CategoryMeta}
}
// IsValidCategory reports whether c is one of the known categories.
func IsValidCategory(c Category) bool {
switch c {
case CategoryRead, CategoryWrite, CategorySearch, CategoryExec, CategoryMeta:
return true
}
return false
}
// Categorized is the optional interface a tool implements to declare its
// category. Tools that don't implement it fall back to CategoryMeta.
type Categorized interface {
Category() Category
}
// CategoryOf returns the tool's declared category, or CategoryMeta if the
// tool does not implement Categorized.
func CategoryOf(t Tool) Category {
if c, ok := t.(Categorized); ok {
cat := c.Category()
if IsValidCategory(cat) {
return cat
}
}
return CategoryMeta
}
+70
View File
@@ -0,0 +1,70 @@
package tool
import (
"slices"
"testing"
)
type categorizedStub struct {
stubTool
cat Category
}
func (c *categorizedStub) Category() Category { return c.cat }
func TestCategoryOf_DefaultIsMeta(t *testing.T) {
plain := &stubTool{name: "plain"}
if got := CategoryOf(plain); got != CategoryMeta {
t.Errorf("CategoryOf(plain) = %q, want %q", got, CategoryMeta)
}
}
func TestCategoryOf_DeclaredCategory(t *testing.T) {
cases := []struct {
name string
cat Category
}{
{"read", CategoryRead},
{"write", CategoryWrite},
{"search", CategorySearch},
{"exec", CategoryExec},
{"meta", CategoryMeta},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
s := &categorizedStub{stubTool: stubTool{name: tc.name}, cat: tc.cat}
if got := CategoryOf(s); got != tc.cat {
t.Errorf("CategoryOf(%s) = %q, want %q", tc.name, got, tc.cat)
}
})
}
}
func TestCategoryOf_InvalidFallsBackToMeta(t *testing.T) {
s := &categorizedStub{stubTool: stubTool{name: "bogus"}, cat: Category("not-a-real-category")}
if got := CategoryOf(s); got != CategoryMeta {
t.Errorf("CategoryOf(invalid) = %q, want %q", got, CategoryMeta)
}
}
func TestIsValidCategory(t *testing.T) {
for _, c := range AllCategories() {
if !IsValidCategory(c) {
t.Errorf("IsValidCategory(%q) = false, want true", c)
}
}
if IsValidCategory(Category("")) {
t.Error("empty category should be invalid")
}
if IsValidCategory(Category("nope")) {
t.Error("unknown category should be invalid")
}
}
func TestAllCategoriesStable(t *testing.T) {
want := []Category{CategoryRead, CategoryWrite, CategorySearch, CategoryExec, CategoryMeta}
got := AllCategories()
if !slices.Equal(got, want) {
t.Errorf("AllCategories() = %v, want %v", got, want)
}
}
+1
View File
@@ -48,6 +48,7 @@ func (t *EditTool) Description() string { return "Perform exact string r
func (t *EditTool) Parameters() json.RawMessage { return editParams }
func (t *EditTool) IsReadOnly() bool { return false }
func (t *EditTool) IsDestructive() bool { return false }
func (t *EditTool) Category() tool.Category { return tool.CategoryWrite }
func (t *EditTool) ExtractPaths(args json.RawMessage) []string {
var a editArgs
+1
View File
@@ -43,6 +43,7 @@ func (t *GlobTool) Description() string { return "Find files matching a
func (t *GlobTool) Parameters() json.RawMessage { return globParams }
func (t *GlobTool) IsReadOnly() bool { return true }
func (t *GlobTool) IsDestructive() bool { return false }
func (t *GlobTool) Category() tool.Category { return tool.CategorySearch }
func (t *GlobTool) ExtractPaths(args json.RawMessage) []string {
var a globArgs
+1
View File
@@ -54,6 +54,7 @@ func (t *GrepTool) Description() string { return "Search file contents u
func (t *GrepTool) Parameters() json.RawMessage { return grepParams }
func (t *GrepTool) IsReadOnly() bool { return true }
func (t *GrepTool) IsDestructive() bool { return false }
func (t *GrepTool) Category() tool.Category { return tool.CategorySearch }
func (t *GrepTool) ExtractPaths(args json.RawMessage) []string {
var a grepArgs
+1
View File
@@ -37,6 +37,7 @@ func (t *LSTool) Description() string { return "List directory contents
func (t *LSTool) Parameters() json.RawMessage { return lsParams }
func (t *LSTool) IsReadOnly() bool { return true }
func (t *LSTool) IsDestructive() bool { return false }
func (t *LSTool) Category() tool.Category { return tool.CategoryRead }
func (t *LSTool) ExtractPaths(args json.RawMessage) []string {
var a lsArgs
+6 -5
View File
@@ -55,11 +55,12 @@ func NewReadTool(opts ...ReadOption) *ReadTool {
return t
}
func (t *ReadTool) Name() string { return readToolName }
func (t *ReadTool) Description() string { return "Read a file from the filesystem with optional offset and line limit" }
func (t *ReadTool) Parameters() json.RawMessage { return readParams }
func (t *ReadTool) IsReadOnly() bool { return true }
func (t *ReadTool) IsDestructive() bool { return false }
func (t *ReadTool) Name() string { return readToolName }
func (t *ReadTool) Description() string { return "Read a file from the filesystem with optional offset and line limit" }
func (t *ReadTool) Parameters() json.RawMessage { return readParams }
func (t *ReadTool) IsReadOnly() bool { return true }
func (t *ReadTool) IsDestructive() bool { return false }
func (t *ReadTool) Category() tool.Category { return tool.CategoryRead }
func (t *ReadTool) ExtractPaths(args json.RawMessage) []string {
var a readArgs
+1
View File
@@ -54,6 +54,7 @@ func (t *WriteTool) Description() string { return "Write content to a fi
func (t *WriteTool) Parameters() json.RawMessage { return writeParams }
func (t *WriteTool) IsReadOnly() bool { return false }
func (t *WriteTool) IsDestructive() bool { return false }
func (t *WriteTool) Category() tool.Category { return tool.CategoryWrite }
func (t *WriteTool) ExtractPaths(args json.RawMessage) []string {
var a writeArgs