feat(engine): two-stage tool routing for small local arms
Plan A from docs/superpowers/plans/2026-05-19-post-slm-unlock.md. Small local SLMs (<=16k context) waste ~1500 tokens per turn on the full tool catalogue. Two-stage routing replaces round-1 tools with a single synthetic select_category schema; round-2+ sends only the selected category's real tool schemas plus select_category for re-selection. - internal/tool/category.go: Category type, optional Categorized interface, CategoryOf() with meta fallback. fs.read/fs.ls -> read, fs.write/fs.edit -> write, fs.glob/fs.grep -> search, bash -> exec. - internal/engine/twostage.go: synthetic select_category tool, intercept helper, per-turn selectedCategory state under e.mu. - Engine round 1 forces ToolChoiceRequired so SLMs don't fall back to prose. State resets at the top and end of every runLoop. - Activates automatically on a forced local arm with ContextWindow <=16384, or via [router].force_two_stage TOML key. - Integration test drives a 3-round trip and asserts: round 1 emits exactly one schema (synthetic) with ToolChoiceRequired, round 2 contains only write-category schemas + select_category, real fs.write executes. Invalid-category fallback round-trips back to round-1 mode.
This commit is contained in:
+4
-3
@@ -750,9 +750,10 @@ func main() {
|
||||
Model: *model,
|
||||
Temperature: cfg.Provider.Temperature,
|
||||
MaxTurns: *maxTurns,
|
||||
Store: store,
|
||||
Hooks: dispatcher,
|
||||
Logger: logger,
|
||||
Store: store,
|
||||
Hooks: dispatcher,
|
||||
Logger: logger,
|
||||
ForceTwoStageTools: cfg.Router.ForceTwoStage,
|
||||
})
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "error: %v\n", err)
|
||||
|
||||
@@ -43,26 +43,49 @@ only that category's real schemas.
|
||||
|
||||
### Tasks
|
||||
|
||||
- [ ] `Category()` method on `tool.Tool` (or a registry-side mapping).
|
||||
- [x] `Category()` method on `tool.Tool` (or a registry-side mapping).
|
||||
Default category for unspecified tools: `meta`.
|
||||
- [ ] New `engine.useTwoStageTools(arm)` predicate. Gates on
|
||||
`arm.IsLocal && arm.Capabilities.ContextWindow <= 16384`, with an
|
||||
optional `[router].force_two_stage = true` config override.
|
||||
- [ ] Synthetic `select_category` tool definition emitted in
|
||||
- [x] New `engine.useTwoStageTools()` predicate. Gates on
|
||||
`arm.IsLocal && arm.Capabilities.ContextWindow <= 16384`, with the
|
||||
`[router].force_two_stage = true` config override.
|
||||
- [x] Synthetic `select_category` tool definition emitted in
|
||||
`buildRequest()` when the predicate is true and we're in the first
|
||||
round of a turn.
|
||||
- [ ] Engine recognises a `select_category` tool result and filters
|
||||
round of a turn. Round 1 also forces `ToolChoiceRequired` so SLMs
|
||||
don't fall back to prose instead of calling the tool.
|
||||
- [x] Engine recognises a `select_category` tool call and filters
|
||||
the next round's tool schemas. The selection itself doesn't run a
|
||||
real tool — it's consumed internally.
|
||||
- [ ] Integration test covering the round-trip with a mocked
|
||||
openaicompat arm.
|
||||
real tool — it's consumed internally; `select_category` remains
|
||||
available in round 2+ so the model can switch categories mid-turn.
|
||||
- [x] Integration test covering the round-trip with a recording
|
||||
mock provider.
|
||||
|
||||
**Exit criteria:** for an SLM-arm turn with ≤16 k context, the first
|
||||
request contains exactly one synthetic tool schema; the second
|
||||
contains only the schemas of the selected category. Real tool
|
||||
selection still works (the second-round tool call executes normally).
|
||||
**Status: shipped.** Phase A landed via the two-stage routing commit
|
||||
on `main`. Module map:
|
||||
- `internal/tool/category.go` — `Category` type, optional
|
||||
`Categorized` interface, `CategoryOf()` helper with `meta` fallback.
|
||||
- Real tools now declare categories: `fs.read`/`fs.ls` → read,
|
||||
`fs.write`/`fs.edit` → write, `fs.glob`/`fs.grep` → search,
|
||||
`bash` → exec. Agent/sysinfo fall through to meta.
|
||||
- `internal/engine/twostage.go` — synthetic tool definition, intercept
|
||||
helper, per-turn state (`selectedCategory`).
|
||||
- `internal/engine/engine.go` — `Config.ForceTwoStageTools`,
|
||||
`useTwoStageTools()` predicate, `twoStageContextLimit = 16384`.
|
||||
- `internal/config/config.go` — `[router].force_two_stage` TOML key
|
||||
wired through `cmd/gnoma/main.go`.
|
||||
|
||||
**Effort:** ~150 LOC + tests.
|
||||
**Exit criteria — met:** for an SLM-arm turn with ≤16 k context, the
|
||||
first request contains exactly one synthetic tool schema with
|
||||
`ToolChoiceRequired`; the second contains only schemas of the
|
||||
selected category plus `select_category`. Real tool selection still
|
||||
works end-to-end (verified by `TestTwoStage_FullRoundTrip`).
|
||||
|
||||
**Effort:** ~250 LOC + tests (including 4 test files).
|
||||
|
||||
**Deferred for follow-up (not Phase A blockers):**
|
||||
- Elf engines spawned from `internal/elf/manager.go` don't pass
|
||||
through `ForceTwoStageTools` — small local elves still get the full
|
||||
tool catalogue. Add per-elf two-stage detection mirroring the main
|
||||
engine's auto-activation when telemetry shows it's worth it.
|
||||
|
||||
---
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ type Config struct {
|
||||
Security SecuritySection `toml:"security"`
|
||||
Session SessionSection `toml:"session"`
|
||||
SLM SLMSection `toml:"slm"`
|
||||
Router RouterSection `toml:"router"`
|
||||
Hooks []HookConfig `toml:"hooks"`
|
||||
MCPServers []MCPServerConfig `toml:"mcp_servers"`
|
||||
Plugins PluginsSection `toml:"plugins"`
|
||||
@@ -39,6 +40,17 @@ type SLMSection struct {
|
||||
StartupTimeout Duration `toml:"startup_timeout"` // llamafile-only: first-launch wait budget; 0 = default 5s
|
||||
}
|
||||
|
||||
// RouterSection holds router-level overrides. Most routing decisions are
|
||||
// driven automatically by arm capabilities and the bandit; this section
|
||||
// exists for the rare overrides that don't fit elsewhere.
|
||||
type RouterSection struct {
|
||||
// ForceTwoStage forces the two-stage tool-routing path regardless of
|
||||
// arm context window. Useful for debugging or for forcing the behavior
|
||||
// on a large local model. Defaults to false: two-stage activates
|
||||
// automatically on local arms with context window <= 16k.
|
||||
ForceTwoStage bool `toml:"force_two_stage"`
|
||||
}
|
||||
|
||||
// MCPServerConfig defines an MCP server to start and connect to.
|
||||
//
|
||||
// Example:
|
||||
|
||||
@@ -95,7 +95,9 @@ func TestBuildRequest_ForcedArmWithToolSupport_IncludesTools(t *testing.T) {
|
||||
Provider: &mockProvider{name: "llamacpp"},
|
||||
ModelName: "qwen3",
|
||||
IsLocal: true,
|
||||
Capabilities: provider.Capabilities{ToolUse: true},
|
||||
// ContextWindow > 16384 keeps two-stage routing inactive so this
|
||||
// test exercises the plain "tools included" path.
|
||||
Capabilities: provider.Capabilities{ToolUse: true, ContextWindow: 32768},
|
||||
})
|
||||
rtr.ForceArm("llamacpp/qwen3")
|
||||
|
||||
|
||||
@@ -33,8 +33,18 @@ type Config struct {
|
||||
Store *persist.Store // nil = no result persistence
|
||||
Hooks *hook.Dispatcher // nil = no hooks
|
||||
Logger *slog.Logger
|
||||
|
||||
// ForceTwoStageTools forces the two-stage tool-routing path on the forced
|
||||
// arm regardless of its context window. When false, two-stage is enabled
|
||||
// automatically for local arms with ContextWindow <= twoStageContextLimit.
|
||||
ForceTwoStageTools bool
|
||||
}
|
||||
|
||||
// twoStageContextLimit is the upper bound on arm context window (in tokens)
|
||||
// under which two-stage tool routing kicks in automatically. Models bigger
|
||||
// than this can afford the full tool catalogue in every request.
|
||||
const twoStageContextLimit = 16384
|
||||
|
||||
func (c Config) validate() error {
|
||||
if c.Provider == nil {
|
||||
return fmt.Errorf("engine: provider required")
|
||||
@@ -77,6 +87,12 @@ type Engine struct {
|
||||
activatedTools map[string]bool
|
||||
|
||||
turnOpts TurnOptions
|
||||
|
||||
// selectedCategory is set when the model picks a category via the
|
||||
// synthetic select_category tool under two-stage routing. Empty string
|
||||
// means "round 1 of two-stage" (or two-stage inactive). Reset at the end
|
||||
// of each turn together with turnOpts.
|
||||
selectedCategory tool.Category
|
||||
}
|
||||
|
||||
// ToolsAvailable reports whether the current model supports tool calling.
|
||||
@@ -128,6 +144,40 @@ func (e *Engine) isLocalArm() bool {
|
||||
return arm.IsLocal
|
||||
}
|
||||
|
||||
// useTwoStageTools reports whether the current turn should use the two-stage
|
||||
// tool-routing path. True when:
|
||||
// - cfg.ForceTwoStageTools is set, OR
|
||||
// - a forced arm exists, is local, and its ContextWindow is small enough
|
||||
// that the full tool catalogue would burn a non-trivial fraction of the
|
||||
// prompt budget.
|
||||
//
|
||||
// Multi-arm routing (no forced arm) and cloud arms do not trigger two-stage.
|
||||
func (e *Engine) useTwoStageTools() bool {
|
||||
if e.cfg.ForceTwoStageTools {
|
||||
return true
|
||||
}
|
||||
if e.cfg.Router == nil {
|
||||
return false
|
||||
}
|
||||
id := e.cfg.Router.ForcedArm()
|
||||
if id == "" {
|
||||
return false
|
||||
}
|
||||
arm, ok := e.cfg.Router.LookupArm(id)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
if !arm.IsLocal {
|
||||
return false
|
||||
}
|
||||
cw := arm.Capabilities.ContextWindow
|
||||
if cw <= 0 {
|
||||
// Unknown context window on a local arm — assume small.
|
||||
return true
|
||||
}
|
||||
return cw <= twoStageContextLimit
|
||||
}
|
||||
|
||||
// New creates an engine.
|
||||
func New(cfg Config) (*Engine, error) {
|
||||
if err := cfg.validate(); err != nil {
|
||||
|
||||
+53
-18
@@ -62,6 +62,11 @@ func (e *Engine) SubmitMessages(ctx context.Context, msgs []message.Message, cb
|
||||
}
|
||||
|
||||
func (e *Engine) runLoop(ctx context.Context, cb Callback) (*Turn, error) {
|
||||
// Two-stage tool-routing state is per-turn; clear it so an aborted turn
|
||||
// can't leak its last category selection into the next one.
|
||||
e.resetTwoStageState()
|
||||
defer e.resetTwoStageState()
|
||||
|
||||
turn := &Turn{}
|
||||
loopStart := time.Now()
|
||||
var lastArmID router.ArmID
|
||||
@@ -481,26 +486,51 @@ func (e *Engine) buildRequest(ctx context.Context) provider.Request {
|
||||
includeTools = caps == nil || caps.ToolUse
|
||||
}
|
||||
if includeTools {
|
||||
allowed := turnOpts.AllowedTools
|
||||
for _, t := range e.cfg.Tools.All() {
|
||||
// Skip deferred tools until the model requests them
|
||||
if dt, ok := t.(tool.DeferrableTool); ok && dt.ShouldDefer() && !e.isToolActivated(t.Name()) {
|
||||
continue
|
||||
twoStage := e.useTwoStageTools()
|
||||
selected := e.snapshotSelectedCategory()
|
||||
|
||||
if twoStage && selected == "" {
|
||||
// Round 1 of two-stage: send only the synthetic select_category tool
|
||||
// and force the model to call it. Small SLMs given a single optional
|
||||
// tool will often emit prose instead of calling it.
|
||||
req.Tools = []provider.ToolDefinition{buildSelectCategoryDef()}
|
||||
req.ToolChoice = provider.ToolChoiceRequired
|
||||
e.logger.Debug("two-stage: round 1 — emitting select_category only",
|
||||
"model", req.Model,
|
||||
)
|
||||
} else {
|
||||
allowed := turnOpts.AllowedTools
|
||||
for _, t := range e.cfg.Tools.All() {
|
||||
// Skip deferred tools until the model requests them
|
||||
if dt, ok := t.(tool.DeferrableTool); ok && dt.ShouldDefer() && !e.isToolActivated(t.Name()) {
|
||||
continue
|
||||
}
|
||||
// Filter to allowed tools when a restrict list is set
|
||||
if allowed != nil && !slices.Contains(allowed, t.Name()) {
|
||||
continue
|
||||
}
|
||||
// Under two-stage round 2+, only schemas in the selected category.
|
||||
if twoStage && tool.CategoryOf(t) != selected {
|
||||
continue
|
||||
}
|
||||
req.Tools = append(req.Tools, provider.ToolDefinition{
|
||||
Name: t.Name(),
|
||||
Description: t.Description(),
|
||||
Parameters: t.Parameters(),
|
||||
})
|
||||
}
|
||||
// Filter to allowed tools when a restrict list is set
|
||||
if allowed != nil && !slices.Contains(allowed, t.Name()) {
|
||||
continue
|
||||
// Keep select_category available while two-stage is active so the
|
||||
// model can switch categories without aborting the turn.
|
||||
if twoStage {
|
||||
req.Tools = append(req.Tools, buildSelectCategoryDef())
|
||||
}
|
||||
req.Tools = append(req.Tools, provider.ToolDefinition{
|
||||
Name: t.Name(),
|
||||
Description: t.Description(),
|
||||
Parameters: t.Parameters(),
|
||||
})
|
||||
e.logger.Debug("tools included in request",
|
||||
"model", req.Model,
|
||||
"count", len(req.Tools),
|
||||
"two_stage", twoStage,
|
||||
"category", string(selected),
|
||||
)
|
||||
}
|
||||
e.logger.Debug("tools included in request",
|
||||
"model", req.Model,
|
||||
"count", len(req.Tools),
|
||||
)
|
||||
} else {
|
||||
e.logger.Debug("tools omitted — model does not support tool use",
|
||||
"model", req.Model,
|
||||
@@ -542,6 +572,10 @@ Synthesis:
|
||||
}
|
||||
|
||||
func (e *Engine) executeTools(ctx context.Context, calls []message.ToolCall, cb Callback) ([]message.ToolResult, error) {
|
||||
// Intercept the synthetic select_category tool first — it never reaches
|
||||
// the registry and produces its own synthetic tool result.
|
||||
calls, syntheticResults := e.interceptSelectCategoryCalls(calls)
|
||||
|
||||
// Partition into read-only (parallel) and write (serial) batches
|
||||
type toolCallWithTool struct {
|
||||
call message.ToolCall
|
||||
@@ -577,7 +611,8 @@ func (e *Engine) executeTools(ctx context.Context, calls []message.ToolCall, cb
|
||||
}
|
||||
}
|
||||
|
||||
results := make([]message.ToolResult, 0, len(calls))
|
||||
results := make([]message.ToolResult, 0, len(calls)+len(syntheticResults))
|
||||
results = append(results, syntheticResults...)
|
||||
results = append(results, unknownResults...)
|
||||
|
||||
// Execute read-only tools in parallel
|
||||
|
||||
@@ -0,0 +1,139 @@
|
||||
package engine
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"somegit.dev/Owlibou/gnoma/internal/message"
|
||||
"somegit.dev/Owlibou/gnoma/internal/provider"
|
||||
"somegit.dev/Owlibou/gnoma/internal/tool"
|
||||
)
|
||||
|
||||
// SyntheticSelectCategoryName is the tool name used by the two-stage routing
|
||||
// path to let the model pick a category before real tool schemas are sent.
|
||||
// The name is exported for tests that need to assert against it.
|
||||
const SyntheticSelectCategoryName = "select_category"
|
||||
|
||||
// buildSelectCategoryDef constructs the synthetic select_category tool
|
||||
// definition. Categories in the enum match tool.AllCategories() so the
|
||||
// schema stays in sync if categories are added.
|
||||
func buildSelectCategoryDef() provider.ToolDefinition {
|
||||
cats := tool.AllCategories()
|
||||
quoted := make([]string, len(cats))
|
||||
for i, c := range cats {
|
||||
quoted[i] = `"` + string(c) + `"`
|
||||
}
|
||||
params := json.RawMessage(`{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"category": {
|
||||
"type": "string",
|
||||
"enum": [` + strings.Join(quoted, ", ") + `],
|
||||
"description": "Tool category to load schemas for. Pick one based on what you intend to do next."
|
||||
}
|
||||
},
|
||||
"required": ["category"]
|
||||
}`)
|
||||
return provider.ToolDefinition{
|
||||
Name: SyntheticSelectCategoryName,
|
||||
Description: "Select the category of tools to load for the next round. " +
|
||||
"Use 'read' for file reads, 'write' to modify files, 'search' to search the codebase, " +
|
||||
"'exec' to run commands, 'meta' for agent orchestration and introspection. " +
|
||||
"You can call this again later to switch categories.",
|
||||
Parameters: params,
|
||||
}
|
||||
}
|
||||
|
||||
// snapshotSelectedCategory returns the category chosen by the model so far
|
||||
// in this turn (or empty string if none).
|
||||
func (e *Engine) snapshotSelectedCategory() tool.Category {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
return e.selectedCategory
|
||||
}
|
||||
|
||||
// setSelectedCategory records the model's category choice under lock.
|
||||
func (e *Engine) setSelectedCategory(c tool.Category) {
|
||||
e.mu.Lock()
|
||||
e.selectedCategory = c
|
||||
e.mu.Unlock()
|
||||
}
|
||||
|
||||
// resetTwoStageState clears any per-turn two-stage state. Called at the start
|
||||
// of every runLoop so an aborted previous turn cannot leak state forward.
|
||||
func (e *Engine) resetTwoStageState() {
|
||||
e.mu.Lock()
|
||||
e.selectedCategory = ""
|
||||
e.mu.Unlock()
|
||||
}
|
||||
|
||||
// interceptSelectCategoryCalls splits the incoming tool calls into two
|
||||
// buckets: real calls that need actual tool execution, and synthetic
|
||||
// select_category calls that the engine handles internally. It updates
|
||||
// e.selectedCategory as a side effect and returns synthetic tool results
|
||||
// that satisfy the provider's "every tool_call needs a tool_result" contract.
|
||||
func (e *Engine) interceptSelectCategoryCalls(calls []message.ToolCall) (realCalls []message.ToolCall, syntheticResults []message.ToolResult) {
|
||||
for _, call := range calls {
|
||||
if call.Name != SyntheticSelectCategoryName {
|
||||
realCalls = append(realCalls, call)
|
||||
continue
|
||||
}
|
||||
result := e.handleSelectCategory(call)
|
||||
syntheticResults = append(syntheticResults, result)
|
||||
}
|
||||
return realCalls, syntheticResults
|
||||
}
|
||||
|
||||
// handleSelectCategory parses the synthetic tool call, updates engine state,
|
||||
// and returns the tool result the model will see in the next round.
|
||||
func (e *Engine) handleSelectCategory(call message.ToolCall) message.ToolResult {
|
||||
var args struct {
|
||||
Category string `json:"category"`
|
||||
}
|
||||
if err := json.Unmarshal(call.Arguments, &args); err != nil {
|
||||
e.setSelectedCategory("")
|
||||
return message.ToolResult{
|
||||
ToolCallID: call.ID,
|
||||
Content: fmt.Sprintf("invalid arguments for select_category: %v. Please pick one of: read, write, search, exec, meta.", err),
|
||||
IsError: true,
|
||||
}
|
||||
}
|
||||
|
||||
cat := tool.Category(strings.ToLower(strings.TrimSpace(args.Category)))
|
||||
if !tool.IsValidCategory(cat) {
|
||||
e.setSelectedCategory("")
|
||||
return message.ToolResult{
|
||||
ToolCallID: call.ID,
|
||||
Content: fmt.Sprintf("unknown category %q. Pick one of: read, write, search, exec, meta.", args.Category),
|
||||
IsError: true,
|
||||
}
|
||||
}
|
||||
|
||||
e.setSelectedCategory(cat)
|
||||
available := e.toolNamesForCategory(cat)
|
||||
content := fmt.Sprintf("Category %q selected. Tools now available: %s. Call them directly on the next turn.",
|
||||
cat, strings.Join(available, ", "))
|
||||
if e.logger != nil {
|
||||
e.logger.Debug("two-stage: category selected",
|
||||
"category", cat,
|
||||
"tools", available,
|
||||
)
|
||||
}
|
||||
return message.ToolResult{
|
||||
ToolCallID: call.ID,
|
||||
Content: content,
|
||||
}
|
||||
}
|
||||
|
||||
// toolNamesForCategory returns the registered tool names whose category
|
||||
// matches the argument, in deterministic order.
|
||||
func (e *Engine) toolNamesForCategory(cat tool.Category) []string {
|
||||
var names []string
|
||||
for _, t := range e.cfg.Tools.All() {
|
||||
if tool.CategoryOf(t) == cat {
|
||||
names = append(names, t.Name())
|
||||
}
|
||||
}
|
||||
return names
|
||||
}
|
||||
@@ -0,0 +1,185 @@
|
||||
package engine
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"slices"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"somegit.dev/Owlibou/gnoma/internal/message"
|
||||
"somegit.dev/Owlibou/gnoma/internal/provider"
|
||||
"somegit.dev/Owlibou/gnoma/internal/stream"
|
||||
"somegit.dev/Owlibou/gnoma/internal/tool"
|
||||
)
|
||||
|
||||
// recordingProvider captures every Request it receives so tests can assert
|
||||
// on the tool catalogue per round.
|
||||
type recordingProvider struct {
|
||||
mu sync.Mutex
|
||||
requests []provider.Request
|
||||
streams []stream.Stream
|
||||
calls int
|
||||
}
|
||||
|
||||
func (m *recordingProvider) Name() string { return "recording" }
|
||||
func (m *recordingProvider) DefaultModel() string { return "mock-model" }
|
||||
func (m *recordingProvider) Models(_ context.Context) ([]provider.ModelInfo, error) {
|
||||
return []provider.ModelInfo{{
|
||||
ID: "mock-model", Name: "mock-model", Provider: "recording",
|
||||
Capabilities: provider.Capabilities{ToolUse: true, ContextWindow: 8192},
|
||||
}}, nil
|
||||
}
|
||||
func (m *recordingProvider) Stream(_ context.Context, req provider.Request) (stream.Stream, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
// Clone request so subsequent rounds don't mutate captured state.
|
||||
clone := req
|
||||
clone.Tools = slices.Clone(req.Tools)
|
||||
m.requests = append(m.requests, clone)
|
||||
if m.calls >= len(m.streams) {
|
||||
return nil, fmt.Errorf("recording: no more streams (called %d times)", m.calls+1)
|
||||
}
|
||||
s := m.streams[m.calls]
|
||||
m.calls++
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// singleToolCallStream returns a stream that emits a single tool call.
|
||||
func singleToolCallStream(callID, name, args string) stream.Stream {
|
||||
return newEventStream(message.StopToolUse, "mock-model",
|
||||
stream.Event{Type: stream.EventToolCallStart, ToolCallID: callID, ToolCallName: name},
|
||||
stream.Event{Type: stream.EventToolCallDone, ToolCallID: callID, ToolCallName: name, Args: json.RawMessage(args)},
|
||||
)
|
||||
}
|
||||
|
||||
// endTurnTextStream returns a stream that emits text and ends the turn.
|
||||
func endTurnTextStream(text string) stream.Stream {
|
||||
return newEventStream(message.StopEndTurn, "mock-model",
|
||||
stream.Event{Type: stream.EventTextDelta, Text: text},
|
||||
)
|
||||
}
|
||||
|
||||
func TestTwoStage_FullRoundTrip(t *testing.T) {
|
||||
// Tool registry: mixed categories so we can verify round 2 filtering.
|
||||
writeCalled := false
|
||||
reg := tool.NewRegistry()
|
||||
reg.Register(&categorizedMockTool{mockTool: mockTool{name: "fs.read", readOnly: true}, cat: tool.CategoryRead})
|
||||
reg.Register(&categorizedMockTool{
|
||||
mockTool: mockTool{
|
||||
name: "fs.write",
|
||||
execFn: func(_ context.Context, _ json.RawMessage) (tool.Result, error) {
|
||||
writeCalled = true
|
||||
return tool.Result{Output: "wrote file"}, nil
|
||||
},
|
||||
},
|
||||
cat: tool.CategoryWrite,
|
||||
})
|
||||
reg.Register(&categorizedMockTool{mockTool: mockTool{name: "bash"}, cat: tool.CategoryExec})
|
||||
|
||||
// Three rounds: select_category → fs.write → end turn.
|
||||
mp := &recordingProvider{
|
||||
streams: []stream.Stream{
|
||||
singleToolCallStream("c1", SyntheticSelectCategoryName, `{"category":"write"}`),
|
||||
singleToolCallStream("c2", "fs.write", `{"path":"/tmp/x","content":"hi"}`),
|
||||
endTurnTextStream("done."),
|
||||
},
|
||||
}
|
||||
|
||||
e, err := New(Config{
|
||||
Provider: mp,
|
||||
Tools: reg,
|
||||
ForceTwoStageTools: true, // no router needed; just force the path
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("New: %v", err)
|
||||
}
|
||||
|
||||
turn, err := e.Submit(context.Background(), "write a file", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Submit: %v", err)
|
||||
}
|
||||
|
||||
if turn.Rounds != 3 {
|
||||
t.Errorf("Rounds = %d, want 3", turn.Rounds)
|
||||
}
|
||||
if !writeCalled {
|
||||
t.Error("fs.write tool was not executed")
|
||||
}
|
||||
|
||||
if len(mp.requests) != 3 {
|
||||
t.Fatalf("captured %d requests, want 3", len(mp.requests))
|
||||
}
|
||||
|
||||
// Round 1: only synthetic select_category, ToolChoice = Required.
|
||||
r1 := mp.requests[0]
|
||||
if len(r1.Tools) != 1 {
|
||||
t.Errorf("round 1 tool count = %d, want 1; tools=%v", len(r1.Tools), toolNamesIn(r1.Tools))
|
||||
}
|
||||
if len(r1.Tools) >= 1 && r1.Tools[0].Name != SyntheticSelectCategoryName {
|
||||
t.Errorf("round 1 tool[0] = %q, want %q", r1.Tools[0].Name, SyntheticSelectCategoryName)
|
||||
}
|
||||
if r1.ToolChoice != provider.ToolChoiceRequired {
|
||||
t.Errorf("round 1 ToolChoice = %q, want %q", r1.ToolChoice, provider.ToolChoiceRequired)
|
||||
}
|
||||
|
||||
// Round 2: write tools + select_category, no read/exec.
|
||||
r2names := toolNamesIn(mp.requests[1].Tools)
|
||||
if !slices.Contains(r2names, "fs.write") {
|
||||
t.Errorf("round 2 missing fs.write: %v", r2names)
|
||||
}
|
||||
if !slices.Contains(r2names, SyntheticSelectCategoryName) {
|
||||
t.Errorf("round 2 missing select_category (re-selection should remain available): %v", r2names)
|
||||
}
|
||||
if slices.Contains(r2names, "fs.read") || slices.Contains(r2names, "bash") {
|
||||
t.Errorf("round 2 leaked non-write tools: %v", r2names)
|
||||
}
|
||||
|
||||
// Round 3: same filter still applied (selection persists for the turn).
|
||||
r3names := toolNamesIn(mp.requests[2].Tools)
|
||||
if slices.Contains(r3names, "fs.read") || slices.Contains(r3names, "bash") {
|
||||
t.Errorf("round 3 leaked non-write tools: %v", r3names)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTwoStage_InvalidCategoryFallsBackToRoundOne(t *testing.T) {
|
||||
reg := tool.NewRegistry()
|
||||
reg.Register(&categorizedMockTool{mockTool: mockTool{name: "fs.write"}, cat: tool.CategoryWrite})
|
||||
|
||||
mp := &recordingProvider{
|
||||
streams: []stream.Stream{
|
||||
// Round 1: model picks an invalid category.
|
||||
singleToolCallStream("c1", SyntheticSelectCategoryName, `{"category":"bogus"}`),
|
||||
// Round 2: should see select_category-only again (round 1 mode).
|
||||
singleToolCallStream("c2", SyntheticSelectCategoryName, `{"category":"write"}`),
|
||||
// Round 3: filtered to write.
|
||||
endTurnTextStream("done."),
|
||||
},
|
||||
}
|
||||
|
||||
e, err := New(Config{
|
||||
Provider: mp,
|
||||
Tools: reg,
|
||||
ForceTwoStageTools: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("New: %v", err)
|
||||
}
|
||||
if _, err := e.Submit(context.Background(), "do stuff", nil); err != nil {
|
||||
t.Fatalf("Submit: %v", err)
|
||||
}
|
||||
|
||||
if len(mp.requests) < 2 {
|
||||
t.Fatalf("expected at least 2 requests, got %d", len(mp.requests))
|
||||
}
|
||||
|
||||
// After invalid category, round 2 should be back to round-1 mode (only synthetic).
|
||||
r2 := mp.requests[1]
|
||||
if len(r2.Tools) != 1 || r2.Tools[0].Name != SyntheticSelectCategoryName {
|
||||
t.Errorf("after invalid category, expected round-1 mode again; got tools=%v", toolNamesIn(r2.Tools))
|
||||
}
|
||||
if r2.ToolChoice != provider.ToolChoiceRequired {
|
||||
t.Errorf("after invalid category, expected ToolChoiceRequired; got %q", r2.ToolChoice)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,419 @@
|
||||
package engine
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"slices"
|
||||
"sort"
|
||||
"testing"
|
||||
|
||||
"somegit.dev/Owlibou/gnoma/internal/message"
|
||||
"somegit.dev/Owlibou/gnoma/internal/provider"
|
||||
"somegit.dev/Owlibou/gnoma/internal/router"
|
||||
"somegit.dev/Owlibou/gnoma/internal/tool"
|
||||
)
|
||||
|
||||
// categorizedMockTool extends mockTool with a Category() method so tests can
|
||||
// exercise the two-stage category filter.
|
||||
type categorizedMockTool struct {
|
||||
mockTool
|
||||
cat tool.Category
|
||||
}
|
||||
|
||||
func (c *categorizedMockTool) Category() tool.Category { return c.cat }
|
||||
|
||||
// twoStageEngine builds an engine wired to a small local forced arm so
|
||||
// useTwoStageTools() returns true.
|
||||
func twoStageEngine(t *testing.T, reg *tool.Registry) *Engine {
|
||||
t.Helper()
|
||||
rtr := router.New(router.Config{})
|
||||
rtr.RegisterArm(&router.Arm{
|
||||
ID: "llamacpp/qwen3-1b",
|
||||
Provider: &mockProvider{name: "llamacpp"},
|
||||
ModelName: "qwen3-1b",
|
||||
IsLocal: true,
|
||||
Capabilities: provider.Capabilities{ToolUse: true, ContextWindow: 8192},
|
||||
})
|
||||
rtr.ForceArm("llamacpp/qwen3-1b")
|
||||
e, err := New(Config{
|
||||
Provider: &mockProvider{name: "llamacpp"},
|
||||
Router: rtr,
|
||||
Tools: reg,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("New: %v", err)
|
||||
}
|
||||
return e
|
||||
}
|
||||
|
||||
func TestUseTwoStageTools(t *testing.T) {
|
||||
smallLocal := &router.Arm{
|
||||
ID: "llamacpp/qwen3-1b",
|
||||
Provider: &mockProvider{name: "llamacpp"},
|
||||
ModelName: "qwen3-1b",
|
||||
IsLocal: true,
|
||||
Capabilities: provider.Capabilities{ToolUse: true, ContextWindow: 8192},
|
||||
}
|
||||
bigLocal := &router.Arm{
|
||||
ID: "llamacpp/qwen3-30b",
|
||||
Provider: &mockProvider{name: "llamacpp"},
|
||||
ModelName: "qwen3-30b",
|
||||
IsLocal: true,
|
||||
Capabilities: provider.Capabilities{ToolUse: true, ContextWindow: 32768},
|
||||
}
|
||||
cloud := &router.Arm{
|
||||
ID: "anthropic/sonnet",
|
||||
Provider: &mockProvider{name: "anthropic"},
|
||||
ModelName: "sonnet",
|
||||
IsLocal: false,
|
||||
Capabilities: provider.Capabilities{ToolUse: true, ContextWindow: 200000},
|
||||
}
|
||||
localUnknownCtx := &router.Arm{
|
||||
ID: "ollama/mystery",
|
||||
Provider: &mockProvider{name: "ollama"},
|
||||
ModelName: "mystery",
|
||||
IsLocal: true,
|
||||
Capabilities: provider.Capabilities{ToolUse: true, ContextWindow: 0},
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
arm *router.Arm
|
||||
forced bool
|
||||
want bool
|
||||
message string
|
||||
}{
|
||||
{
|
||||
name: "small local arm triggers two-stage",
|
||||
arm: smallLocal,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "large local arm does not trigger two-stage",
|
||||
arm: bigLocal,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "cloud arm never triggers two-stage",
|
||||
arm: cloud,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "local arm with unknown context window triggers two-stage",
|
||||
arm: localUnknownCtx,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "ForceTwoStageTools overrides cloud arm",
|
||||
arm: cloud,
|
||||
forced: true,
|
||||
want: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
rtr := router.New(router.Config{})
|
||||
rtr.RegisterArm(tc.arm)
|
||||
rtr.ForceArm(tc.arm.ID)
|
||||
|
||||
e, err := New(Config{
|
||||
Provider: &mockProvider{name: string(tc.arm.ID.Provider())},
|
||||
Router: rtr,
|
||||
Tools: tool.NewRegistry(),
|
||||
ForceTwoStageTools: tc.forced,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("New: %v", err)
|
||||
}
|
||||
|
||||
if got := e.useTwoStageTools(); got != tc.want {
|
||||
t.Errorf("useTwoStageTools() = %v, want %v", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUseTwoStageTools_NoRouter(t *testing.T) {
|
||||
e, err := New(Config{
|
||||
Provider: &mockProvider{name: "anthropic"},
|
||||
Tools: tool.NewRegistry(),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("New: %v", err)
|
||||
}
|
||||
if e.useTwoStageTools() {
|
||||
t.Error("useTwoStageTools() = true, want false when no router")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUseTwoStageTools_NoRouter_ForcedOverride(t *testing.T) {
|
||||
e, err := New(Config{
|
||||
Provider: &mockProvider{name: "anthropic"},
|
||||
Tools: tool.NewRegistry(),
|
||||
ForceTwoStageTools: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("New: %v", err)
|
||||
}
|
||||
if !e.useTwoStageTools() {
|
||||
t.Error("useTwoStageTools() = false, want true when ForceTwoStageTools is set even without router")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUseTwoStageTools_NoForcedArm(t *testing.T) {
|
||||
rtr := router.New(router.Config{})
|
||||
rtr.RegisterArm(&router.Arm{
|
||||
ID: "llamacpp/qwen3-1b",
|
||||
Provider: &mockProvider{name: "llamacpp"},
|
||||
ModelName: "qwen3-1b",
|
||||
IsLocal: true,
|
||||
Capabilities: provider.Capabilities{ToolUse: true, ContextWindow: 8192},
|
||||
})
|
||||
// No ForceArm called — multi-arm routing
|
||||
e, err := New(Config{
|
||||
Provider: &mockProvider{name: "llamacpp"},
|
||||
Router: rtr,
|
||||
Tools: tool.NewRegistry(),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("New: %v", err)
|
||||
}
|
||||
if e.useTwoStageTools() {
|
||||
t.Error("useTwoStageTools() = true, want false for multi-arm routing")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildRequest_TwoStage_Round1_EmitsSyntheticOnly(t *testing.T) {
|
||||
reg := tool.NewRegistry()
|
||||
reg.Register(&categorizedMockTool{mockTool: mockTool{name: "fs.read", readOnly: true}, cat: tool.CategoryRead})
|
||||
reg.Register(&categorizedMockTool{mockTool: mockTool{name: "fs.write"}, cat: tool.CategoryWrite})
|
||||
reg.Register(&categorizedMockTool{mockTool: mockTool{name: "bash"}, cat: tool.CategoryExec})
|
||||
|
||||
e := twoStageEngine(t, reg)
|
||||
|
||||
req := e.buildRequest(context.Background())
|
||||
|
||||
if len(req.Tools) != 1 {
|
||||
t.Fatalf("round 1 should emit exactly one tool (synthetic); got %d: %+v", len(req.Tools), req.Tools)
|
||||
}
|
||||
if req.Tools[0].Name != SyntheticSelectCategoryName {
|
||||
t.Errorf("round 1 tool name = %q, want %q", req.Tools[0].Name, SyntheticSelectCategoryName)
|
||||
}
|
||||
if req.ToolChoice != provider.ToolChoiceRequired {
|
||||
t.Errorf("round 1 ToolChoice = %q, want %q", req.ToolChoice, provider.ToolChoiceRequired)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildRequest_TwoStage_Round1_SyntheticEnumMatchesAllCategories(t *testing.T) {
|
||||
reg := tool.NewRegistry()
|
||||
reg.Register(&categorizedMockTool{mockTool: mockTool{name: "fs.read", readOnly: true}, cat: tool.CategoryRead})
|
||||
e := twoStageEngine(t, reg)
|
||||
|
||||
req := e.buildRequest(context.Background())
|
||||
if len(req.Tools) != 1 {
|
||||
t.Fatalf("expected 1 tool, got %d", len(req.Tools))
|
||||
}
|
||||
|
||||
var schema struct {
|
||||
Properties struct {
|
||||
Category struct {
|
||||
Enum []string `json:"enum"`
|
||||
} `json:"category"`
|
||||
} `json:"properties"`
|
||||
}
|
||||
if err := json.Unmarshal(req.Tools[0].Parameters, &schema); err != nil {
|
||||
t.Fatalf("unmarshal synthetic params: %v", err)
|
||||
}
|
||||
want := make([]string, 0, len(tool.AllCategories()))
|
||||
for _, c := range tool.AllCategories() {
|
||||
want = append(want, string(c))
|
||||
}
|
||||
got := slices.Clone(schema.Properties.Category.Enum)
|
||||
sort.Strings(got)
|
||||
sort.Strings(want)
|
||||
if !slices.Equal(got, want) {
|
||||
t.Errorf("category enum = %v, want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildRequest_TwoStage_Round2_FiltersByCategory(t *testing.T) {
|
||||
reg := tool.NewRegistry()
|
||||
reg.Register(&categorizedMockTool{mockTool: mockTool{name: "fs.read", readOnly: true}, cat: tool.CategoryRead})
|
||||
reg.Register(&categorizedMockTool{mockTool: mockTool{name: "fs.write"}, cat: tool.CategoryWrite})
|
||||
reg.Register(&categorizedMockTool{mockTool: mockTool{name: "fs.edit"}, cat: tool.CategoryWrite})
|
||||
reg.Register(&categorizedMockTool{mockTool: mockTool{name: "bash"}, cat: tool.CategoryExec})
|
||||
|
||||
e := twoStageEngine(t, reg)
|
||||
e.setSelectedCategory(tool.CategoryWrite)
|
||||
|
||||
req := e.buildRequest(context.Background())
|
||||
|
||||
names := toolNamesIn(req.Tools)
|
||||
sort.Strings(names)
|
||||
want := []string{SyntheticSelectCategoryName, "fs.edit", "fs.write"}
|
||||
sort.Strings(want)
|
||||
if !slices.Equal(names, want) {
|
||||
t.Errorf("round 2 tools = %v, want %v", names, want)
|
||||
}
|
||||
|
||||
// ToolChoice should not be forced in round 2 — let the model decide.
|
||||
if req.ToolChoice == provider.ToolChoiceRequired {
|
||||
t.Errorf("round 2 ToolChoice = %q, should not be Required", req.ToolChoice)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildRequest_TwoStage_UncategorizedToolDefaultsToMeta(t *testing.T) {
|
||||
reg := tool.NewRegistry()
|
||||
reg.Register(&mockTool{name: "agent"}) // no Category() method → meta
|
||||
reg.Register(&categorizedMockTool{mockTool: mockTool{name: "fs.read", readOnly: true}, cat: tool.CategoryRead})
|
||||
|
||||
e := twoStageEngine(t, reg)
|
||||
e.setSelectedCategory(tool.CategoryMeta)
|
||||
|
||||
req := e.buildRequest(context.Background())
|
||||
names := toolNamesIn(req.Tools)
|
||||
if !slices.Contains(names, "agent") {
|
||||
t.Errorf("meta filter should include uncategorized 'agent'; got %v", names)
|
||||
}
|
||||
if slices.Contains(names, "fs.read") {
|
||||
t.Errorf("meta filter should exclude 'fs.read'; got %v", names)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildRequest_NonTwoStage_UnchangedBehavior(t *testing.T) {
|
||||
// Large local arm — two-stage should not activate.
|
||||
rtr := router.New(router.Config{})
|
||||
rtr.RegisterArm(&router.Arm{
|
||||
ID: "llamacpp/qwen3-30b",
|
||||
Provider: &mockProvider{name: "llamacpp"},
|
||||
ModelName: "qwen3-30b",
|
||||
IsLocal: true,
|
||||
Capabilities: provider.Capabilities{ToolUse: true, ContextWindow: 32768},
|
||||
})
|
||||
rtr.ForceArm("llamacpp/qwen3-30b")
|
||||
|
||||
reg := tool.NewRegistry()
|
||||
reg.Register(&categorizedMockTool{mockTool: mockTool{name: "fs.read", readOnly: true}, cat: tool.CategoryRead})
|
||||
reg.Register(&categorizedMockTool{mockTool: mockTool{name: "fs.write"}, cat: tool.CategoryWrite})
|
||||
|
||||
e, err := New(Config{
|
||||
Provider: &mockProvider{name: "llamacpp"},
|
||||
Router: rtr,
|
||||
Tools: reg,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("New: %v", err)
|
||||
}
|
||||
|
||||
req := e.buildRequest(context.Background())
|
||||
if len(req.Tools) != 2 {
|
||||
t.Errorf("non-two-stage path: got %d tools, want 2", len(req.Tools))
|
||||
}
|
||||
for _, td := range req.Tools {
|
||||
if td.Name == SyntheticSelectCategoryName {
|
||||
t.Errorf("non-two-stage path should not emit synthetic select_category")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestInterceptSelectCategoryCalls_UpdatesStateAndReturnsResult(t *testing.T) {
|
||||
reg := tool.NewRegistry()
|
||||
reg.Register(&categorizedMockTool{mockTool: mockTool{name: "fs.write"}, cat: tool.CategoryWrite})
|
||||
|
||||
e := twoStageEngine(t, reg)
|
||||
|
||||
call := message.ToolCall{
|
||||
ID: "call-1",
|
||||
Name: SyntheticSelectCategoryName,
|
||||
Arguments: json.RawMessage(`{"category": "write"}`),
|
||||
}
|
||||
real, synth := e.interceptSelectCategoryCalls([]message.ToolCall{call})
|
||||
|
||||
if len(real) != 0 {
|
||||
t.Errorf("real calls = %d, want 0", len(real))
|
||||
}
|
||||
if len(synth) != 1 {
|
||||
t.Fatalf("synthetic results = %d, want 1", len(synth))
|
||||
}
|
||||
if synth[0].IsError {
|
||||
t.Errorf("synthetic result should not be error: %s", synth[0].Content)
|
||||
}
|
||||
if synth[0].ToolCallID != "call-1" {
|
||||
t.Errorf("ToolCallID = %q, want call-1", synth[0].ToolCallID)
|
||||
}
|
||||
if e.snapshotSelectedCategory() != tool.CategoryWrite {
|
||||
t.Errorf("selectedCategory = %q, want write", e.snapshotSelectedCategory())
|
||||
}
|
||||
}
|
||||
|
||||
func TestInterceptSelectCategoryCalls_InvalidCategoryReturnsError(t *testing.T) {
|
||||
e := twoStageEngine(t, tool.NewRegistry())
|
||||
|
||||
call := message.ToolCall{
|
||||
ID: "call-1",
|
||||
Name: SyntheticSelectCategoryName,
|
||||
Arguments: json.RawMessage(`{"category": "bogus"}`),
|
||||
}
|
||||
_, synth := e.interceptSelectCategoryCalls([]message.ToolCall{call})
|
||||
if len(synth) != 1 {
|
||||
t.Fatalf("synthetic results = %d, want 1", len(synth))
|
||||
}
|
||||
if !synth[0].IsError {
|
||||
t.Errorf("invalid category should yield error result")
|
||||
}
|
||||
if e.snapshotSelectedCategory() != "" {
|
||||
t.Errorf("invalid category should clear selectedCategory; got %q", e.snapshotSelectedCategory())
|
||||
}
|
||||
}
|
||||
|
||||
func TestInterceptSelectCategoryCalls_InvalidJSONReturnsError(t *testing.T) {
|
||||
e := twoStageEngine(t, tool.NewRegistry())
|
||||
|
||||
call := message.ToolCall{
|
||||
ID: "call-1",
|
||||
Name: SyntheticSelectCategoryName,
|
||||
Arguments: json.RawMessage(`not-json`),
|
||||
}
|
||||
_, synth := e.interceptSelectCategoryCalls([]message.ToolCall{call})
|
||||
if len(synth) != 1 || !synth[0].IsError {
|
||||
t.Fatalf("invalid JSON should yield single error result")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInterceptSelectCategoryCalls_MixedRealAndSynthetic(t *testing.T) {
|
||||
e := twoStageEngine(t, tool.NewRegistry())
|
||||
|
||||
calls := []message.ToolCall{
|
||||
{ID: "real-1", Name: "fs.read", Arguments: json.RawMessage(`{}`)},
|
||||
{ID: "synth-1", Name: SyntheticSelectCategoryName, Arguments: json.RawMessage(`{"category":"read"}`)},
|
||||
{ID: "real-2", Name: "bash", Arguments: json.RawMessage(`{}`)},
|
||||
}
|
||||
real, synth := e.interceptSelectCategoryCalls(calls)
|
||||
if len(real) != 2 {
|
||||
t.Errorf("real calls = %d, want 2", len(real))
|
||||
}
|
||||
if len(synth) != 1 {
|
||||
t.Errorf("synthetic results = %d, want 1", len(synth))
|
||||
}
|
||||
if e.snapshotSelectedCategory() != tool.CategoryRead {
|
||||
t.Errorf("selectedCategory = %q, want read", e.snapshotSelectedCategory())
|
||||
}
|
||||
}
|
||||
|
||||
func TestResetTwoStageState_ClearsCategory(t *testing.T) {
|
||||
e := twoStageEngine(t, tool.NewRegistry())
|
||||
e.setSelectedCategory(tool.CategoryExec)
|
||||
e.resetTwoStageState()
|
||||
if e.snapshotSelectedCategory() != "" {
|
||||
t.Errorf("resetTwoStageState should clear category; got %q", e.snapshotSelectedCategory())
|
||||
}
|
||||
}
|
||||
|
||||
func toolNamesIn(defs []provider.ToolDefinition) []string {
|
||||
names := make([]string, len(defs))
|
||||
for i, d := range defs {
|
||||
names[i] = d.Name
|
||||
}
|
||||
return names
|
||||
}
|
||||
@@ -66,6 +66,7 @@ func (t *Tool) Description() string { return "Execute a bash command and
|
||||
func (t *Tool) Parameters() json.RawMessage { return parameterSchema }
|
||||
func (t *Tool) IsReadOnly() bool { return false }
|
||||
func (t *Tool) IsDestructive() bool { return true }
|
||||
func (t *Tool) Category() tool.Category { return tool.CategoryExec }
|
||||
|
||||
type bashArgs struct {
|
||||
Command string `json:"command"`
|
||||
|
||||
@@ -0,0 +1,52 @@
|
||||
package tool
|
||||
|
||||
// Category groups tools by what they do. Used by the two-stage tool routing
|
||||
// path for small local models: round 1 picks a category, round 2 only sees
|
||||
// schemas in that category.
|
||||
type Category string
|
||||
|
||||
const (
|
||||
// CategoryRead — read filesystem state (fs.read, fs.ls).
|
||||
CategoryRead Category = "read"
|
||||
// CategoryWrite — modify filesystem state (fs.write, fs.edit).
|
||||
CategoryWrite Category = "write"
|
||||
// CategorySearch — search filesystem content (fs.grep, fs.glob).
|
||||
CategorySearch Category = "search"
|
||||
// CategoryExec — execute external commands (bash).
|
||||
CategoryExec Category = "exec"
|
||||
// CategoryMeta — agent orchestration, introspection, and result handling.
|
||||
// Default for tools that don't declare a category.
|
||||
CategoryMeta Category = "meta"
|
||||
)
|
||||
|
||||
// AllCategories returns the canonical category list in stable order.
|
||||
func AllCategories() []Category {
|
||||
return []Category{CategoryRead, CategoryWrite, CategorySearch, CategoryExec, CategoryMeta}
|
||||
}
|
||||
|
||||
// IsValidCategory reports whether c is one of the known categories.
|
||||
func IsValidCategory(c Category) bool {
|
||||
switch c {
|
||||
case CategoryRead, CategoryWrite, CategorySearch, CategoryExec, CategoryMeta:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Categorized is the optional interface a tool implements to declare its
|
||||
// category. Tools that don't implement it fall back to CategoryMeta.
|
||||
type Categorized interface {
|
||||
Category() Category
|
||||
}
|
||||
|
||||
// CategoryOf returns the tool's declared category, or CategoryMeta if the
|
||||
// tool does not implement Categorized.
|
||||
func CategoryOf(t Tool) Category {
|
||||
if c, ok := t.(Categorized); ok {
|
||||
cat := c.Category()
|
||||
if IsValidCategory(cat) {
|
||||
return cat
|
||||
}
|
||||
}
|
||||
return CategoryMeta
|
||||
}
|
||||
@@ -0,0 +1,70 @@
|
||||
package tool
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type categorizedStub struct {
|
||||
stubTool
|
||||
cat Category
|
||||
}
|
||||
|
||||
func (c *categorizedStub) Category() Category { return c.cat }
|
||||
|
||||
func TestCategoryOf_DefaultIsMeta(t *testing.T) {
|
||||
plain := &stubTool{name: "plain"}
|
||||
if got := CategoryOf(plain); got != CategoryMeta {
|
||||
t.Errorf("CategoryOf(plain) = %q, want %q", got, CategoryMeta)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCategoryOf_DeclaredCategory(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
cat Category
|
||||
}{
|
||||
{"read", CategoryRead},
|
||||
{"write", CategoryWrite},
|
||||
{"search", CategorySearch},
|
||||
{"exec", CategoryExec},
|
||||
{"meta", CategoryMeta},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
s := &categorizedStub{stubTool: stubTool{name: tc.name}, cat: tc.cat}
|
||||
if got := CategoryOf(s); got != tc.cat {
|
||||
t.Errorf("CategoryOf(%s) = %q, want %q", tc.name, got, tc.cat)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCategoryOf_InvalidFallsBackToMeta(t *testing.T) {
|
||||
s := &categorizedStub{stubTool: stubTool{name: "bogus"}, cat: Category("not-a-real-category")}
|
||||
if got := CategoryOf(s); got != CategoryMeta {
|
||||
t.Errorf("CategoryOf(invalid) = %q, want %q", got, CategoryMeta)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsValidCategory(t *testing.T) {
|
||||
for _, c := range AllCategories() {
|
||||
if !IsValidCategory(c) {
|
||||
t.Errorf("IsValidCategory(%q) = false, want true", c)
|
||||
}
|
||||
}
|
||||
if IsValidCategory(Category("")) {
|
||||
t.Error("empty category should be invalid")
|
||||
}
|
||||
if IsValidCategory(Category("nope")) {
|
||||
t.Error("unknown category should be invalid")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAllCategoriesStable(t *testing.T) {
|
||||
want := []Category{CategoryRead, CategoryWrite, CategorySearch, CategoryExec, CategoryMeta}
|
||||
got := AllCategories()
|
||||
if !slices.Equal(got, want) {
|
||||
t.Errorf("AllCategories() = %v, want %v", got, want)
|
||||
}
|
||||
}
|
||||
@@ -48,6 +48,7 @@ func (t *EditTool) Description() string { return "Perform exact string r
|
||||
func (t *EditTool) Parameters() json.RawMessage { return editParams }
|
||||
func (t *EditTool) IsReadOnly() bool { return false }
|
||||
func (t *EditTool) IsDestructive() bool { return false }
|
||||
func (t *EditTool) Category() tool.Category { return tool.CategoryWrite }
|
||||
|
||||
func (t *EditTool) ExtractPaths(args json.RawMessage) []string {
|
||||
var a editArgs
|
||||
|
||||
@@ -43,6 +43,7 @@ func (t *GlobTool) Description() string { return "Find files matching a
|
||||
func (t *GlobTool) Parameters() json.RawMessage { return globParams }
|
||||
func (t *GlobTool) IsReadOnly() bool { return true }
|
||||
func (t *GlobTool) IsDestructive() bool { return false }
|
||||
func (t *GlobTool) Category() tool.Category { return tool.CategorySearch }
|
||||
|
||||
func (t *GlobTool) ExtractPaths(args json.RawMessage) []string {
|
||||
var a globArgs
|
||||
|
||||
@@ -54,6 +54,7 @@ func (t *GrepTool) Description() string { return "Search file contents u
|
||||
func (t *GrepTool) Parameters() json.RawMessage { return grepParams }
|
||||
func (t *GrepTool) IsReadOnly() bool { return true }
|
||||
func (t *GrepTool) IsDestructive() bool { return false }
|
||||
func (t *GrepTool) Category() tool.Category { return tool.CategorySearch }
|
||||
|
||||
func (t *GrepTool) ExtractPaths(args json.RawMessage) []string {
|
||||
var a grepArgs
|
||||
|
||||
@@ -37,6 +37,7 @@ func (t *LSTool) Description() string { return "List directory contents
|
||||
func (t *LSTool) Parameters() json.RawMessage { return lsParams }
|
||||
func (t *LSTool) IsReadOnly() bool { return true }
|
||||
func (t *LSTool) IsDestructive() bool { return false }
|
||||
func (t *LSTool) Category() tool.Category { return tool.CategoryRead }
|
||||
|
||||
func (t *LSTool) ExtractPaths(args json.RawMessage) []string {
|
||||
var a lsArgs
|
||||
|
||||
@@ -55,11 +55,12 @@ func NewReadTool(opts ...ReadOption) *ReadTool {
|
||||
return t
|
||||
}
|
||||
|
||||
func (t *ReadTool) Name() string { return readToolName }
|
||||
func (t *ReadTool) Description() string { return "Read a file from the filesystem with optional offset and line limit" }
|
||||
func (t *ReadTool) Parameters() json.RawMessage { return readParams }
|
||||
func (t *ReadTool) IsReadOnly() bool { return true }
|
||||
func (t *ReadTool) IsDestructive() bool { return false }
|
||||
func (t *ReadTool) Name() string { return readToolName }
|
||||
func (t *ReadTool) Description() string { return "Read a file from the filesystem with optional offset and line limit" }
|
||||
func (t *ReadTool) Parameters() json.RawMessage { return readParams }
|
||||
func (t *ReadTool) IsReadOnly() bool { return true }
|
||||
func (t *ReadTool) IsDestructive() bool { return false }
|
||||
func (t *ReadTool) Category() tool.Category { return tool.CategoryRead }
|
||||
|
||||
func (t *ReadTool) ExtractPaths(args json.RawMessage) []string {
|
||||
var a readArgs
|
||||
|
||||
@@ -54,6 +54,7 @@ func (t *WriteTool) Description() string { return "Write content to a fi
|
||||
func (t *WriteTool) Parameters() json.RawMessage { return writeParams }
|
||||
func (t *WriteTool) IsReadOnly() bool { return false }
|
||||
func (t *WriteTool) IsDestructive() bool { return false }
|
||||
func (t *WriteTool) Category() tool.Category { return tool.CategoryWrite }
|
||||
|
||||
func (t *WriteTool) ExtractPaths(args json.RawMessage) []string {
|
||||
var a writeArgs
|
||||
|
||||
Reference in New Issue
Block a user