diff --git a/cmd/gnoma/main.go b/cmd/gnoma/main.go index 043e2b2..38004f8 100644 --- a/cmd/gnoma/main.go +++ b/cmd/gnoma/main.go @@ -420,6 +420,26 @@ func main() { logger.Debug("CLI agents discovered", "count", len(cliAgents)) } + // Apply [[arms]] overrides (strengths, cost_weight) now that all initial + // arms are registered. Late-discovered arms (background polling) won't + // pick these up — by design: overrides target arms the user knows exist. + if len(cfg.Arms) > 0 { + overrides := make([]router.ArmOverride, 0, len(cfg.Arms)) + for _, ac := range cfg.Arms { + overrides = append(overrides, router.ArmOverride{ + ID: ac.ID, + Strengths: ac.Strengths, + CostWeight: ac.CostWeight, + }) + } + if unknown := rtr.ApplyArmOverrides(overrides); len(unknown) > 0 { + logger.Warn("[[arms]] config references unregistered arm IDs", + "ids", unknown, + "hint", "run `gnoma providers` to see registered arms", + ) + } + } + // Start background discovery polling (30s interval). // modelUpdater is set after the session is created so the discovery loop // can update the displayed model name when it reconciles the forced arm. diff --git a/docs/superpowers/plans/2026-05-19-post-slm-unlock.md b/docs/superpowers/plans/2026-05-19-post-slm-unlock.md index bc15e81..b652b2b 100644 --- a/docs/superpowers/plans/2026-05-19-post-slm-unlock.md +++ b/docs/superpowers/plans/2026-05-19-post-slm-unlock.md @@ -276,22 +276,55 @@ shouldn't dominate (e.g. SecurityReview). ### Tasks -- [ ] Add `Strengths` and `CostWeight` to `router.Arm`. -- [ ] Config schema for per-arm overrides — likely - `[arms..strengths] = ["planning", "orchestration"]`. -- [ ] `scoreArm` consults both fields. -- [ ] Bandit signal feeds back into a per-arm-per-task affinity over - time (≥10 observations needed). Currently `QualityTracker` already - tracks per-arm × per-task EMA; what's missing is letting that - signal *promote* an arm out of its default tier. -- [ ] Tests that show Opus winning over Gemini for SecurityReview - when `arms.anthropic_opus.strengths = ["security_review"]`. +- [x] Add `Strengths []TaskType` and `CostWeight float64` to + `router.Arm`. Zero values preserve current behavior. +- [x] Config schema: `[[arms]]` array of tables — `id`, `strengths` + (string list, parsed via new `ParseTaskTypeStrict`), `cost_weight`. +- [x] `scoreArm` consults both fields: strength match adds a tunable + bonus (`strengthScoreBonus = 0.15`); `CostWeight` linearly dampens + cost via `effectiveCost = 1 + CostWeight*(cost-1)` — monotone on + both sides of cost=1. +- [x] `selectBest` cross-tier promotion: arms whose `Strengths` + contain `task.Type` are evaluated as one set before falling through + to default tier order. Strengths are a preference, not a pin — + backoff/feasibility filtering at the router level removes promoted + arms when unavailable, and selection falls through. +- [x] `Router.ApplyArmOverrides()` applies config overrides post + arm-registration. Unknown arm IDs surfaced via return value; main + logs a warning. Unknown strength names skipped with per-strength + warning. +- [x] Tests: Opus with `Strengths=[security_review]` beats CLI-agent + tier-1 arm; empty Strengths preserves tier order; promoted arm in + backoff falls through (via full `Router.Select` path); two + strength-tagged arms decided by observed quality; CostWeight + direction across two arms; linear-formula monotonicity regression + test for the cost^weight bug avoided. -**Exit criteria:** with explicit per-task strengths set, the router -picks the strongest available arm for that task type, not the -lowest-tier one. +**Status: shipped (static portion).** Module map: +- `internal/router/arm.go` — `Strengths`, `CostWeight`, + `HasStrength()`, `ResolvedCostWeight()`. +- `internal/router/selector.go` — `scoreArm` updated, `selectBest` + cross-tier promotion path. +- `internal/router/router.go` — `ArmOverride` type and + `ApplyArmOverrides()`. +- `internal/router/task.go` — `ParseTaskTypeStrict()` (returns ok + bool) for typo-resistant config parsing. +- `internal/config/config.go` — `ArmConfig` struct and `[[arms]]` + TOML wiring. +- `cmd/gnoma/main.go` — applies overrides after all initial arms + register; warns on unknown IDs. -**Effort:** ~300 LOC + tests. Touches `selector.go`, `arm.go`, config. +**Exit criteria — met:** with `[[arms]] id="anthropic/..." +strengths=["security_review"]`, the router picks Opus over a +higher-tier CLI agent for that task type. Verified by +`TestSelectBest_StrengthPromotedArmBeatsCLIAgent`. + +**Effort:** ~350 LOC + tests. + +**Deferred to D-2:** dynamic bandit-driven promotion (≥10 observations +threshold + per-arm × per-task affinity that overrides tier order +without static config). Holding until telemetry from real workloads +informs the quality bar — same rationale as Phase E. --- diff --git a/internal/config/config.go b/internal/config/config.go index 3d7a467..0f0ad7e 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -13,6 +13,7 @@ type Config struct { SLM SLMSection `toml:"slm"` Router RouterSection `toml:"router"` CLIAgents CLIAgentsSection `toml:"cli_agents"` + Arms []ArmConfig `toml:"arms"` Hooks []HookConfig `toml:"hooks"` MCPServers []MCPServerConfig `toml:"mcp_servers"` Plugins PluginsSection `toml:"plugins"` @@ -41,6 +42,30 @@ type SLMSection struct { StartupTimeout Duration `toml:"startup_timeout"` // llamafile-only: first-launch wait budget; 0 = default 5s } +// ArmConfig tunes routing for a single registered arm. Multiple [[arms]] +// blocks may appear; each is matched by ID against the runtime arm +// registry. An ID that doesn't match any registered arm logs a warning at +// startup — typos here are otherwise silent. +// +// Example: +// +// [[arms]] +// id = "anthropic/claude-opus-4-7" +// strengths = ["security_review", "planning"] # task types this arm is preferred for +// cost_weight = 0.3 # 1.0 = full cost penalty, 0 = ignore cost +// +// [[arms]] +// id = "subprocess/claude" +// strengths = ["orchestration"] +// +// Strength names map to router.TaskType via router.ParseTaskType — same +// names the SLM classifier emits (snake_case or no separator both work). +type ArmConfig struct { + ID string `toml:"id"` + Strengths []string `toml:"strengths"` + CostWeight float64 `toml:"cost_weight"` +} + // CLIAgentsSection maps canonical CLI agent names to override binary names. // // Useful when a user has aliased the canonical binary — e.g. `claude-priv` diff --git a/internal/config/config_test.go b/internal/config/config_test.go index d037293..949fb33 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -249,6 +249,44 @@ gemini = "" } } +func TestArmConfig_TOML_RoundTrip(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "config.toml") + _ = os.WriteFile(path, []byte(` +[[arms]] +id = "anthropic/claude-opus-4-7" +strengths = ["security_review", "planning"] +cost_weight = 0.3 + +[[arms]] +id = "subprocess/claude" +strengths = ["orchestration"] +`), 0o644) + + cfg := Defaults() + if err := loadTOML(&cfg, path); err != nil { + t.Fatalf("loadTOML: %v", err) + } + if len(cfg.Arms) != 2 { + t.Fatalf("len(Arms) = %d, want 2", len(cfg.Arms)) + } + if cfg.Arms[0].ID != "anthropic/claude-opus-4-7" { + t.Errorf("Arms[0].ID = %q", cfg.Arms[0].ID) + } + if len(cfg.Arms[0].Strengths) != 2 || cfg.Arms[0].Strengths[0] != "security_review" { + t.Errorf("Arms[0].Strengths = %v", cfg.Arms[0].Strengths) + } + if cfg.Arms[0].CostWeight != 0.3 { + t.Errorf("Arms[0].CostWeight = %v, want 0.3", cfg.Arms[0].CostWeight) + } + if cfg.Arms[1].ID != "subprocess/claude" { + t.Errorf("Arms[1].ID = %q", cfg.Arms[1].ID) + } + if cfg.Arms[1].CostWeight != 0 { + t.Errorf("Arms[1].CostWeight = %v, want 0 (default)", cfg.Arms[1].CostWeight) + } +} + func TestCLIAgentsSection_Absent_NilMap(t *testing.T) { dir := t.TempDir() path := filepath.Join(dir, "config.toml") diff --git a/internal/router/arm.go b/internal/router/arm.go index 25a9780..202817f 100644 --- a/internal/router/arm.go +++ b/internal/router/arm.go @@ -30,6 +30,24 @@ type Arm struct { // Zero means no ceiling (default for all existing arms). MaxComplexity float64 + // Strengths lists task types where this arm is preferred. When any + // listed task type matches an incoming task, the arm crosses tier + // boundaries during selection — Opus tagged with TaskSecurityReview + // can beat a CLI-agent tier-1 arm for that task type, for example. + // Strengths are a preference, not a pin: if no strength-matching arm + // is feasible (rate-limited, backoff), selection falls back to the + // default tier order. + Strengths []TaskType + + // CostWeight scales how much per-arm cost matters during scoring. + // effectiveCost = 1 + CostWeight*(cost-1): + // - 1.0 (or zero, which is normalized to 1.0): current behavior. + // - 0.5: half-weight cost — pricey arms penalized less. + // - 0.0: cost ignored, pure quality wins. + // Use sub-1.0 values for task types where being right matters more + // than being cheap (e.g. SecurityReview). + CostWeight float64 + // Cost per 1k tokens (EUR, estimated) CostPer1kInput float64 CostPer1kOutput float64 @@ -72,6 +90,29 @@ func (a *Arm) SupportsTools() bool { return a.Capabilities.ToolUse } +// HasStrength reports whether the arm is tagged as strong at the given task +// type. Used by the selector to consider cross-tier promotion. +func (a *Arm) HasStrength(t TaskType) bool { + for _, s := range a.Strengths { + if s == t { + return true + } + } + return false +} + +// ResolvedCostWeight normalizes the CostWeight field. A zero value means +// "unset" and is treated as 1.0 (current full-cost behavior). Users who +// want minimal cost influence set a small positive value like 0.05 — no +// real use case wants exactly zero ("ignore cost entirely") and 0 doubles +// as the Go zero value for arms registered before this field existed. +func (a *Arm) ResolvedCostWeight() float64 { + if a.CostWeight == 0 { + return 1.0 + } + return a.CostWeight +} + // perfAlpha is the EMA smoothing factor for ArmPerf updates (0.3 = ~3-sample memory). const perfAlpha = 0.3 diff --git a/internal/router/router.go b/internal/router/router.go index 131cd64..edaf9e4 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -219,6 +219,47 @@ func (r *Router) QualityTracker() *QualityTracker { return r.quality } +// ArmOverride is a config-supplied tweak to a registered arm. Use it to +// declare per-task strengths and a CostWeight override. +type ArmOverride struct { + ID string // ArmID as registered (e.g. "anthropic/claude-opus-4-7") + Strengths []string // task-type names, parsed via ParseTaskType + CostWeight float64 // 0 leaves arm's current CostWeight untouched +} + +// ApplyArmOverrides walks the override list, locates each by ID, and +// applies the requested Strengths/CostWeight in place. Returns the list of +// IDs that did not match a registered arm so the caller can warn about +// typos. Apply after all arms have been registered. +func (r *Router) ApplyArmOverrides(overrides []ArmOverride) (unknownIDs []string) { + r.mu.Lock() + defer r.mu.Unlock() + for _, ov := range overrides { + arm, ok := r.arms[ArmID(ov.ID)] + if !ok { + unknownIDs = append(unknownIDs, ov.ID) + continue + } + if len(ov.Strengths) > 0 { + parsed := make([]TaskType, 0, len(ov.Strengths)) + for _, s := range ov.Strengths { + t, ok := ParseTaskTypeStrict(s) + if !ok { + r.logger.Warn("unknown strength task-type; skipping", + "arm", ov.ID, "strength", s) + continue + } + parsed = append(parsed, t) + } + arm.Strengths = parsed + } + if ov.CostWeight != 0 { + arm.CostWeight = ov.CostWeight + } + } + return unknownIDs +} + // Arms returns all registered arms. func (r *Router) Arms() []*Arm { r.mu.RLock() diff --git a/internal/router/selector.go b/internal/router/selector.go index b6dd91c..a394ce7 100644 --- a/internal/router/selector.go +++ b/internal/router/selector.go @@ -56,13 +56,32 @@ func armTier(arm *Arm, task Task) int { return 3 } -// selectBest picks the best arm, preferring lower-tier arms first. -// Within a tier, the highest-scoring arm (by quality/cost) wins. +// selectBest picks the best arm. +// +// Step 1: arms whose Strengths list contains task.Type cross all tier +// boundaries — Opus tagged with SecurityReview beats a CLI-agent tier-1 +// arm for that task. Strengths are a preference, not a pin: if no +// strength-matching arm is in the input set (filterFeasible already +// removed arms in backoff, lacking tool support, or out of pool capacity), +// selection falls through to the default tier order. +// +// Step 2 (fallback): walk tiers low→high. Within a tier, highest-scoring +// arm wins. func selectBest(qt *QualityTracker, arms []*Arm, task Task) *Arm { if len(arms) == 0 { return nil } + var promoted []*Arm + for _, arm := range arms { + if arm.HasStrength(task.Type) { + promoted = append(promoted, arm) + } + } + if len(promoted) > 0 { + return bestScored(qt, promoted, task) + } + for tier := 0; tier <= 3; tier++ { var inTier []*Arm for _, arm := range arms { @@ -91,10 +110,23 @@ func bestScored(qt *QualityTracker, arms []*Arm, task Task) *Arm { return best } +// strengthScoreBonus is added to quality when an arm's Strengths list +// matches the incoming task type. Tunable in one place. +const strengthScoreBonus = 0.15 + // scoreArm computes a quality/cost score for an arm. // When the quality tracker has sufficient observations, blends observed EMA // (70%) with heuristic (30%). Falls back to pure heuristic otherwise. -// Score = (quality × value) / effective_cost +// +// Strengths add a fixed bonus to quality when matching task.Type. CostWeight +// dampens the cost penalty linearly: +// +// effectiveCost = 1 + CostWeight * (cost - 1) +// +// With CostWeight=1.0 (or unset → resolved to 1.0) the formula collapses to +// the original effectiveCost == cost. With CostWeight=0 cost is fully +// ignored (effectiveCost = 1.0). Local arms with sub-1 raw costs are not +// amplified by fractional weights (the linear formula stays monotone). func scoreArm(qt *QualityTracker, arm *Arm, task Task) float64 { hq := heuristicQuality(arm, task) quality := hq @@ -103,12 +135,19 @@ func scoreArm(qt *QualityTracker, arm *Arm, task Task) float64 { quality = 0.7*observed + 0.3*hq } } - value := task.ValueScore() - cost := effectiveCost(arm, task) - if cost <= 0 { - cost = 0.001 + if arm.HasStrength(task.Type) { + quality += strengthScoreBonus } - return (quality * value) / cost + value := task.ValueScore() + rawCost := effectiveCost(arm, task) + if rawCost <= 0 { + rawCost = 0.001 + } + weighted := 1.0 + arm.ResolvedCostWeight()*(rawCost-1.0) + if weighted <= 0 { + weighted = 0.001 + } + return (quality * value) / weighted } // heuristicQuality estimates arm quality without historical data. diff --git a/internal/router/strengths_test.go b/internal/router/strengths_test.go new file mode 100644 index 0000000..3acfee0 --- /dev/null +++ b/internal/router/strengths_test.go @@ -0,0 +1,349 @@ +package router + +import ( + "math" + "testing" + "time" + + "somegit.dev/Owlibou/gnoma/internal/provider" +) + +func timeFuture() time.Time { return time.Now().Add(1 * time.Hour) } + +func TestArm_HasStrength(t *testing.T) { + a := &Arm{Strengths: []TaskType{TaskSecurityReview, TaskPlanning}} + if !a.HasStrength(TaskSecurityReview) { + t.Error("HasStrength(SecurityReview) = false, want true") + } + if !a.HasStrength(TaskPlanning) { + t.Error("HasStrength(Planning) = false, want true") + } + if a.HasStrength(TaskDebug) { + t.Error("HasStrength(Debug) = true, want false") + } + empty := &Arm{} + if empty.HasStrength(TaskSecurityReview) { + t.Error("empty Strengths should never match") + } +} + +func TestArm_ResolvedCostWeight(t *testing.T) { + cases := []struct { + in, want float64 + }{ + {0, 1.0}, // unset → 1.0 + {1.0, 1.0}, // explicit 1.0 → 1.0 + {0.5, 0.5}, + {0.05, 0.05}, + } + for _, tc := range cases { + a := &Arm{CostWeight: tc.in} + if got := a.ResolvedCostWeight(); got != tc.want { + t.Errorf("CostWeight=%v: ResolvedCostWeight() = %v, want %v", tc.in, got, tc.want) + } + } +} + +func TestScoreArm_CostWeightAffectsArmComparison(t *testing.T) { + // The semantically meaningful test: two arms with different costs but + // otherwise identical. At CostWeight=1.0 (current behavior), the cheap + // arm wins. At CostWeight=0.0 (cost ignored), they tie on quality — + // and the slightly-higher-quality one wins. + cheap := &Arm{ + ID: NewArmID("provA", "small"), + Capabilities: provider.Capabilities{ToolUse: true, ContextWindow: 100000}, + CostPer1kInput: 0.0005, + CostPer1kOutput: 0.0015, + } + expensive := &Arm{ + ID: NewArmID("provB", "big"), + Capabilities: provider.Capabilities{ToolUse: true, ContextWindow: 200000}, // slight quality edge + CostPer1kInput: 0.015, + CostPer1kOutput: 0.075, + } + task := Task{Type: TaskDebug, EstimatedTokens: 5000, RequiresTools: true, Priority: PriorityNormal} + + // CostWeight=1.0: cost dominates, cheap arm wins. + cheap.CostWeight, expensive.CostWeight = 1.0, 1.0 + if scoreArm(nil, cheap, task) <= scoreArm(nil, expensive, task) { + t.Errorf("CostWeight=1.0: cheap arm should beat expensive arm; cheap=%v expensive=%v", + scoreArm(nil, cheap, task), scoreArm(nil, expensive, task)) + } + + // CostWeight=0.0: cost ignored, quality alone decides → expensive (better + // context window) wins. + cheap.CostWeight, expensive.CostWeight = 0.001, 0.001 + if scoreArm(nil, expensive, task) <= scoreArm(nil, cheap, task) { + t.Errorf("CostWeight~0: higher-quality expensive arm should beat cheap arm; expensive=%v cheap=%v", + scoreArm(nil, expensive, task), scoreArm(nil, cheap, task)) + } +} + +func TestScoreArm_LinearFormulaMonotone(t *testing.T) { + // Regression: the original draft used cost^CostWeight, which inverts + // direction when cost<1 (local arms). The linear formula + // effectiveCost = 1 + CostWeight*(cost-1) + // is monotone: increasing CostWeight monotonically pulls effectiveCost + // toward the raw cost regardless of whether cost is above or below 1. + // + // Verify monotonicity on both sides of cost=1. + cheap := &Arm{ // cost < 1 + CostPer1kInput: 0.001, + CostPer1kOutput: 0.001, + } + expensive := &Arm{ // cost > 1 for big tasks + CostPer1kInput: 0.05, + CostPer1kOutput: 0.15, + } + task := Task{Type: TaskDebug, EstimatedTokens: 20000} + + weights := []float64{0.05, 0.25, 0.5, 0.75, 1.0} + for _, name := range []string{"cheap", "expensive"} { + var prev float64 + for i, w := range weights { + arm := cheap + if name == "expensive" { + arm = expensive + } + arm.CostWeight = w + raw := effectiveCost(arm, task) + weighted := 1.0 + arm.ResolvedCostWeight()*(raw-1.0) + if i == 0 { + prev = weighted + continue + } + // As w increases, weighted should move toward raw. + // For cheap (raw<1), weighted should DECREASE. + // For expensive (raw>1), weighted should INCREASE. + if raw < 1 && weighted > prev { + t.Errorf("%s arm w=%v: weighted (%v) increased from prev (%v); raw=%v", + name, w, weighted, prev, raw) + } + if raw > 1 && weighted < prev { + t.Errorf("%s arm w=%v: weighted (%v) decreased from prev (%v); raw=%v", + name, w, weighted, prev, raw) + } + prev = weighted + } + } +} + +func TestScoreArm_StrengthBonus(t *testing.T) { + withoutStrength := &Arm{ + ID: NewArmID("anthropic", "opus"), + Capabilities: provider.Capabilities{ToolUse: true, ContextWindow: 200000}, + } + withStrength := &Arm{ + ID: NewArmID("anthropic", "opus"), + Capabilities: provider.Capabilities{ToolUse: true, ContextWindow: 200000}, + Strengths: []TaskType{TaskSecurityReview}, + } + task := Task{Type: TaskSecurityReview, EstimatedTokens: 5000, RequiresTools: true, Priority: PriorityNormal} + + a := scoreArm(nil, withoutStrength, task) + b := scoreArm(nil, withStrength, task) + if !(b > a) { + t.Errorf("strength-tagged arm score (%v) should exceed plain arm score (%v)", b, a) + } +} + +func TestScoreArm_StrengthBonusDoesNotApplyToOtherTasks(t *testing.T) { + // Strengths apply only to listed task types. + tagged := &Arm{ + ID: NewArmID("anthropic", "opus"), + Capabilities: provider.Capabilities{ToolUse: true, ContextWindow: 200000}, + Strengths: []TaskType{TaskSecurityReview}, + } + plain := &Arm{ + ID: NewArmID("anthropic", "opus"), + Capabilities: provider.Capabilities{ToolUse: true, ContextWindow: 200000}, + } + task := Task{Type: TaskDebug, EstimatedTokens: 5000, RequiresTools: true, Priority: PriorityNormal} + + a := scoreArm(nil, plain, task) + b := scoreArm(nil, tagged, task) + if math.Abs(a-b) > 1e-9 { + t.Errorf("non-matching task should ignore Strengths: plain=%v tagged=%v", a, b) + } +} + +func TestSelectBest_StrengthPromotedArmBeatsCLIAgent(t *testing.T) { + // Plan exit criteria: with Strengths set, Opus (tier 3) wins over a CLI + // agent (tier 1) for SecurityReview. + cliAgent := &Arm{ + ID: NewArmID("subprocess", "claude"), + IsCLIAgent: true, + Capabilities: provider.Capabilities{ToolUse: true, ContextWindow: 200000}, + } + opus := &Arm{ + ID: NewArmID("anthropic", "opus"), + Capabilities: provider.Capabilities{ToolUse: true, ContextWindow: 200000}, + Strengths: []TaskType{TaskSecurityReview}, + CostPer1kInput: 0.015, + CostPer1kOutput: 0.075, + } + + task := Task{Type: TaskSecurityReview, EstimatedTokens: 5000, RequiresTools: true, Priority: PriorityNormal} + got := selectBest(nil, []*Arm{cliAgent, opus}, task) + if got == nil { + t.Fatal("selectBest returned nil") + } + if got.ID != opus.ID { + t.Errorf("selectBest = %s, want %s (strength-promoted arm should beat tier-1 CLI agent)", got.ID, opus.ID) + } +} + +func TestSelectBest_EmptyStrengthsPreservesTierOrder(t *testing.T) { + // Regression: without Strengths, CLI-agent tier-1 still wins over API tier-3. + cliAgent := &Arm{ + ID: NewArmID("subprocess", "claude"), + IsCLIAgent: true, + Capabilities: provider.Capabilities{ToolUse: true, ContextWindow: 200000}, + } + opus := &Arm{ + ID: NewArmID("anthropic", "opus"), + Capabilities: provider.Capabilities{ToolUse: true, ContextWindow: 200000}, + CostPer1kInput: 0.015, + CostPer1kOutput: 0.075, + } + + task := Task{Type: TaskSecurityReview, EstimatedTokens: 5000, RequiresTools: true, Priority: PriorityNormal} + got := selectBest(nil, []*Arm{cliAgent, opus}, task) + if got.ID != cliAgent.ID { + t.Errorf("without Strengths, CLI-agent tier-1 should win; got %s", got.ID) + } +} + +func TestRouter_Select_PromotedArmInBackoffFallsThroughToTierOrder(t *testing.T) { + // Strengths are preference, not pin. Full Router.Select path: backoff + // filtering removes the promoted arm; selectBest then falls through to + // the default tier order and picks the CLI agent. + cliAgent := &Arm{ + ID: NewArmID("subprocess", "claude"), + IsCLIAgent: true, + Capabilities: provider.Capabilities{ToolUse: true, ContextWindow: 200000}, + } + opus := &Arm{ + ID: NewArmID("anthropic", "opus"), + Capabilities: provider.Capabilities{ToolUse: true, ContextWindow: 200000}, + Strengths: []TaskType{TaskSecurityReview}, + } + opus.SetBackoff(timeFuture()) + + r := New(Config{}) + r.RegisterArm(cliAgent) + r.RegisterArm(opus) + + task := Task{Type: TaskSecurityReview, EstimatedTokens: 5000, RequiresTools: true, Priority: PriorityNormal} + decision := r.Select(task) + if decision.Error != nil { + t.Fatalf("Select: %v", decision.Error) + } + if decision.Arm.ID != cliAgent.ID { + t.Errorf("promoted arm in backoff should fall through to CLI agent; got %s", decision.Arm.ID) + } +} + +func TestApplyArmOverrides_ApplyStrengthsAndCostWeight(t *testing.T) { + r := New(Config{}) + opus := &Arm{ + ID: NewArmID("anthropic", "opus"), + Capabilities: provider.Capabilities{ToolUse: true, ContextWindow: 200000}, + } + r.RegisterArm(opus) + + unknown := r.ApplyArmOverrides([]ArmOverride{ + { + ID: "anthropic/opus", + Strengths: []string{"security_review", "planning"}, + CostWeight: 0.3, + }, + }) + if len(unknown) != 0 { + t.Errorf("unknown = %v, want empty", unknown) + } + + got, _ := r.LookupArm(NewArmID("anthropic", "opus")) + if !got.HasStrength(TaskSecurityReview) { + t.Error("opus should have SecurityReview strength after override") + } + if !got.HasStrength(TaskPlanning) { + t.Error("opus should have Planning strength after override") + } + if got.CostWeight != 0.3 { + t.Errorf("opus.CostWeight = %v, want 0.3", got.CostWeight) + } +} + +func TestApplyArmOverrides_UnknownIDReported(t *testing.T) { + r := New(Config{}) + r.RegisterArm(&Arm{ + ID: NewArmID("anthropic", "opus"), + Capabilities: provider.Capabilities{ToolUse: true}, + }) + + unknown := r.ApplyArmOverrides([]ArmOverride{ + {ID: "anthropic/opus", Strengths: []string{"debug"}}, + {ID: "anthropic/typo-here", Strengths: []string{"refactor"}}, + }) + if len(unknown) != 1 || unknown[0] != "anthropic/typo-here" { + t.Errorf("unknown = %v, want [anthropic/typo-here]", unknown) + } +} + +func TestApplyArmOverrides_UnknownStrengthSkipped(t *testing.T) { + r := New(Config{}) + arm := &Arm{ + ID: NewArmID("anthropic", "opus"), + Capabilities: provider.Capabilities{ToolUse: true}, + } + r.RegisterArm(arm) + + r.ApplyArmOverrides([]ArmOverride{ + {ID: "anthropic/opus", Strengths: []string{"security_review", "bogus-type"}}, + }) + + got, _ := r.LookupArm(NewArmID("anthropic", "opus")) + if !got.HasStrength(TaskSecurityReview) { + t.Error("security_review should be applied") + } + if len(got.Strengths) != 1 { + t.Errorf("got.Strengths = %v, want [security_review] only (bogus skipped)", got.Strengths) + } +} + +func TestSelectBest_MultiplePromotedArmsBestQualityWins(t *testing.T) { + // Tunability check: when two arms both have Strengths for the same task, + // observed quality (via QualityTracker) should determine the winner, not + // the static strength bonus alone. + armA := &Arm{ + ID: NewArmID("provA", "model"), + Capabilities: provider.Capabilities{ToolUse: true, ContextWindow: 200000}, + Strengths: []TaskType{TaskSecurityReview}, + } + armB := &Arm{ + ID: NewArmID("provB", "model"), + Capabilities: provider.Capabilities{ToolUse: true, ContextWindow: 200000}, + Strengths: []TaskType{TaskSecurityReview}, + } + + qt := NewQualityTracker() + // armB has consistently succeeded — minObservations=3 is enough to flip + // the score blend. + for i := 0; i < 5; i++ { + qt.Record(armB.ID, TaskSecurityReview, true) + } + // armA fails consistently. + for i := 0; i < 5; i++ { + qt.Record(armA.ID, TaskSecurityReview, false) + } + + task := Task{Type: TaskSecurityReview, EstimatedTokens: 5000, RequiresTools: true, Priority: PriorityNormal} + got := selectBest(qt, []*Arm{armA, armB}, task) + if got == nil { + t.Fatal("selectBest returned nil") + } + if got.ID != armB.ID { + t.Errorf("observed-quality winner should beat tied-strength loser; got %s", got.ID) + } +} diff --git a/internal/router/task.go b/internal/router/task.go index f50d854..c6c21bb 100644 --- a/internal/router/task.go +++ b/internal/router/task.go @@ -347,31 +347,40 @@ func estimateComplexity(prompt string) float64 { return score } -// ParseTaskType converts a string from an SLM JSON response to a TaskType. -// Matching is case-insensitive. Unknown strings fall back to TaskGeneration. -func ParseTaskType(s string) TaskType { +// ParseTaskTypeStrict is like ParseTaskType but reports whether the input +// matched a known type. Used by config wiring to surface typos in +// user-supplied task-type names instead of silently falling back to +// TaskGeneration. +func ParseTaskTypeStrict(s string) (TaskType, bool) { switch strings.ToLower(strings.ReplaceAll(s, "_", "")) { case "debug": - return TaskDebug + return TaskDebug, true case "explain": - return TaskExplain + return TaskExplain, true case "generation": - return TaskGeneration + return TaskGeneration, true case "refactor": - return TaskRefactor + return TaskRefactor, true case "unittest": - return TaskUnitTest + return TaskUnitTest, true case "boilerplate": - return TaskBoilerplate + return TaskBoilerplate, true case "planning": - return TaskPlanning + return TaskPlanning, true case "orchestration": - return TaskOrchestration + return TaskOrchestration, true case "securityreview": - return TaskSecurityReview + return TaskSecurityReview, true case "review": - return TaskReview - default: - return TaskGeneration + return TaskReview, true } + return TaskGeneration, false +} + +// ParseTaskType converts a string from an SLM JSON response to a TaskType. +// Matching is case-insensitive. Unknown strings fall back to TaskGeneration. +// Use ParseTaskTypeStrict when you need to detect typos. +func ParseTaskType(s string) TaskType { + t, _ := ParseTaskTypeStrict(s) + return t }