feat(router): per-arm strengths + cost weight (Phase D)
Plan D from docs/superpowers/plans/2026-05-19-post-slm-unlock.md (static portion; dynamic bandit-driven promotion deferred to D-2). Routing previously let tier ordering (CLI > local > API) dominate selection — Opus, in tier 3, would lose to a tier-1 CLI agent for SecurityReview even though Opus is empirically stronger at that task. This change introduces explicit per-arm overrides: [[arms]] id = "anthropic/claude-opus-4-7" strengths = ["security_review", "planning"] cost_weight = 0.3 Strengths gate cross-tier promotion: arms matching task.Type bypass the tier loop and compete with each other directly. Promotion is a preference, not a pin — if no strength-tagged arm is feasible (backoff, pool capacity, tool support), selection falls through to the default tier order. CostWeight linearly dampens the cost penalty in scoreArm via effectiveCost = 1 + CostWeight * (cost - 1) CostWeight=1.0 (or unset) preserves current behavior; lower values trade cheapness for quality. The earlier draft used cost^CostWeight which inverts direction for sub-1 local-arm costs (raising a fraction <1 to a fractional power makes it bigger, not smaller); a monotonicity regression test prevents that drift. - internal/router/arm.go: Strengths []TaskType, CostWeight float64, HasStrength(), ResolvedCostWeight() (zero → 1.0). - internal/router/selector.go: scoreArm strength bonus const (strengthScoreBonus = 0.15) + linear cost dampening; selectBest cross-tier promotion before tier loop. - internal/router/router.go: ArmOverride type + ApplyArmOverrides() returns unknown IDs; unknown strength names skipped with per-name warning via slog. - internal/router/task.go: ParseTaskTypeStrict() returns ok bool; ParseTaskType now delegates so the two switches stay in sync. - internal/config/config.go: ArmConfig + [[arms]] TOML wiring. - cmd/gnoma/main.go: applies overrides after all initial arms register; logs a warning when an [[arms]] id has no matching registered arm. Tests cover: predicate helpers, scoring direction across two arms, linear-formula monotonicity on both sides of cost=1, cross-tier promotion, empty-Strengths preserves tier order, promoted arm in backoff falls through via full Router.Select path, observed-quality tiebreak between two strength-tagged arms, ApplyArmOverrides happy path + unknown-ID reporting + unknown-strength skipping.
This commit is contained in:
@@ -420,6 +420,26 @@ func main() {
|
||||
logger.Debug("CLI agents discovered", "count", len(cliAgents))
|
||||
}
|
||||
|
||||
// Apply [[arms]] overrides (strengths, cost_weight) now that all initial
|
||||
// arms are registered. Late-discovered arms (background polling) won't
|
||||
// pick these up — by design: overrides target arms the user knows exist.
|
||||
if len(cfg.Arms) > 0 {
|
||||
overrides := make([]router.ArmOverride, 0, len(cfg.Arms))
|
||||
for _, ac := range cfg.Arms {
|
||||
overrides = append(overrides, router.ArmOverride{
|
||||
ID: ac.ID,
|
||||
Strengths: ac.Strengths,
|
||||
CostWeight: ac.CostWeight,
|
||||
})
|
||||
}
|
||||
if unknown := rtr.ApplyArmOverrides(overrides); len(unknown) > 0 {
|
||||
logger.Warn("[[arms]] config references unregistered arm IDs",
|
||||
"ids", unknown,
|
||||
"hint", "run `gnoma providers` to see registered arms",
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Start background discovery polling (30s interval).
|
||||
// modelUpdater is set after the session is created so the discovery loop
|
||||
// can update the displayed model name when it reconciles the forced arm.
|
||||
|
||||
@@ -276,22 +276,55 @@ shouldn't dominate (e.g. SecurityReview).
|
||||
|
||||
### Tasks
|
||||
|
||||
- [ ] Add `Strengths` and `CostWeight` to `router.Arm`.
|
||||
- [ ] Config schema for per-arm overrides — likely
|
||||
`[arms.<id>.strengths] = ["planning", "orchestration"]`.
|
||||
- [ ] `scoreArm` consults both fields.
|
||||
- [ ] Bandit signal feeds back into a per-arm-per-task affinity over
|
||||
time (≥10 observations needed). Currently `QualityTracker` already
|
||||
tracks per-arm × per-task EMA; what's missing is letting that
|
||||
signal *promote* an arm out of its default tier.
|
||||
- [ ] Tests that show Opus winning over Gemini for SecurityReview
|
||||
when `arms.anthropic_opus.strengths = ["security_review"]`.
|
||||
- [x] Add `Strengths []TaskType` and `CostWeight float64` to
|
||||
`router.Arm`. Zero values preserve current behavior.
|
||||
- [x] Config schema: `[[arms]]` array of tables — `id`, `strengths`
|
||||
(string list, parsed via new `ParseTaskTypeStrict`), `cost_weight`.
|
||||
- [x] `scoreArm` consults both fields: strength match adds a tunable
|
||||
bonus (`strengthScoreBonus = 0.15`); `CostWeight` linearly dampens
|
||||
cost via `effectiveCost = 1 + CostWeight*(cost-1)` — monotone on
|
||||
both sides of cost=1.
|
||||
- [x] `selectBest` cross-tier promotion: arms whose `Strengths`
|
||||
contain `task.Type` are evaluated as one set before falling through
|
||||
to default tier order. Strengths are a preference, not a pin —
|
||||
backoff/feasibility filtering at the router level removes promoted
|
||||
arms when unavailable, and selection falls through.
|
||||
- [x] `Router.ApplyArmOverrides()` applies config overrides post
|
||||
arm-registration. Unknown arm IDs surfaced via return value; main
|
||||
logs a warning. Unknown strength names skipped with per-strength
|
||||
warning.
|
||||
- [x] Tests: Opus with `Strengths=[security_review]` beats CLI-agent
|
||||
tier-1 arm; empty Strengths preserves tier order; promoted arm in
|
||||
backoff falls through (via full `Router.Select` path); two
|
||||
strength-tagged arms decided by observed quality; CostWeight
|
||||
direction across two arms; linear-formula monotonicity regression
|
||||
test for the cost^weight bug avoided.
|
||||
|
||||
**Exit criteria:** with explicit per-task strengths set, the router
|
||||
picks the strongest available arm for that task type, not the
|
||||
lowest-tier one.
|
||||
**Status: shipped (static portion).** Module map:
|
||||
- `internal/router/arm.go` — `Strengths`, `CostWeight`,
|
||||
`HasStrength()`, `ResolvedCostWeight()`.
|
||||
- `internal/router/selector.go` — `scoreArm` updated, `selectBest`
|
||||
cross-tier promotion path.
|
||||
- `internal/router/router.go` — `ArmOverride` type and
|
||||
`ApplyArmOverrides()`.
|
||||
- `internal/router/task.go` — `ParseTaskTypeStrict()` (returns ok
|
||||
bool) for typo-resistant config parsing.
|
||||
- `internal/config/config.go` — `ArmConfig` struct and `[[arms]]`
|
||||
TOML wiring.
|
||||
- `cmd/gnoma/main.go` — applies overrides after all initial arms
|
||||
register; warns on unknown IDs.
|
||||
|
||||
**Effort:** ~300 LOC + tests. Touches `selector.go`, `arm.go`, config.
|
||||
**Exit criteria — met:** with `[[arms]] id="anthropic/..."
|
||||
strengths=["security_review"]`, the router picks Opus over a
|
||||
higher-tier CLI agent for that task type. Verified by
|
||||
`TestSelectBest_StrengthPromotedArmBeatsCLIAgent`.
|
||||
|
||||
**Effort:** ~350 LOC + tests.
|
||||
|
||||
**Deferred to D-2:** dynamic bandit-driven promotion (≥10 observations
|
||||
threshold + per-arm × per-task affinity that overrides tier order
|
||||
without static config). Holding until telemetry from real workloads
|
||||
informs the quality bar — same rationale as Phase E.
|
||||
|
||||
---
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ type Config struct {
|
||||
SLM SLMSection `toml:"slm"`
|
||||
Router RouterSection `toml:"router"`
|
||||
CLIAgents CLIAgentsSection `toml:"cli_agents"`
|
||||
Arms []ArmConfig `toml:"arms"`
|
||||
Hooks []HookConfig `toml:"hooks"`
|
||||
MCPServers []MCPServerConfig `toml:"mcp_servers"`
|
||||
Plugins PluginsSection `toml:"plugins"`
|
||||
@@ -41,6 +42,30 @@ type SLMSection struct {
|
||||
StartupTimeout Duration `toml:"startup_timeout"` // llamafile-only: first-launch wait budget; 0 = default 5s
|
||||
}
|
||||
|
||||
// ArmConfig tunes routing for a single registered arm. Multiple [[arms]]
|
||||
// blocks may appear; each is matched by ID against the runtime arm
|
||||
// registry. An ID that doesn't match any registered arm logs a warning at
|
||||
// startup — typos here are otherwise silent.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// [[arms]]
|
||||
// id = "anthropic/claude-opus-4-7"
|
||||
// strengths = ["security_review", "planning"] # task types this arm is preferred for
|
||||
// cost_weight = 0.3 # 1.0 = full cost penalty, 0 = ignore cost
|
||||
//
|
||||
// [[arms]]
|
||||
// id = "subprocess/claude"
|
||||
// strengths = ["orchestration"]
|
||||
//
|
||||
// Strength names map to router.TaskType via router.ParseTaskType — same
|
||||
// names the SLM classifier emits (snake_case or no separator both work).
|
||||
type ArmConfig struct {
|
||||
ID string `toml:"id"`
|
||||
Strengths []string `toml:"strengths"`
|
||||
CostWeight float64 `toml:"cost_weight"`
|
||||
}
|
||||
|
||||
// CLIAgentsSection maps canonical CLI agent names to override binary names.
|
||||
//
|
||||
// Useful when a user has aliased the canonical binary — e.g. `claude-priv`
|
||||
|
||||
@@ -249,6 +249,44 @@ gemini = ""
|
||||
}
|
||||
}
|
||||
|
||||
func TestArmConfig_TOML_RoundTrip(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "config.toml")
|
||||
_ = os.WriteFile(path, []byte(`
|
||||
[[arms]]
|
||||
id = "anthropic/claude-opus-4-7"
|
||||
strengths = ["security_review", "planning"]
|
||||
cost_weight = 0.3
|
||||
|
||||
[[arms]]
|
||||
id = "subprocess/claude"
|
||||
strengths = ["orchestration"]
|
||||
`), 0o644)
|
||||
|
||||
cfg := Defaults()
|
||||
if err := loadTOML(&cfg, path); err != nil {
|
||||
t.Fatalf("loadTOML: %v", err)
|
||||
}
|
||||
if len(cfg.Arms) != 2 {
|
||||
t.Fatalf("len(Arms) = %d, want 2", len(cfg.Arms))
|
||||
}
|
||||
if cfg.Arms[0].ID != "anthropic/claude-opus-4-7" {
|
||||
t.Errorf("Arms[0].ID = %q", cfg.Arms[0].ID)
|
||||
}
|
||||
if len(cfg.Arms[0].Strengths) != 2 || cfg.Arms[0].Strengths[0] != "security_review" {
|
||||
t.Errorf("Arms[0].Strengths = %v", cfg.Arms[0].Strengths)
|
||||
}
|
||||
if cfg.Arms[0].CostWeight != 0.3 {
|
||||
t.Errorf("Arms[0].CostWeight = %v, want 0.3", cfg.Arms[0].CostWeight)
|
||||
}
|
||||
if cfg.Arms[1].ID != "subprocess/claude" {
|
||||
t.Errorf("Arms[1].ID = %q", cfg.Arms[1].ID)
|
||||
}
|
||||
if cfg.Arms[1].CostWeight != 0 {
|
||||
t.Errorf("Arms[1].CostWeight = %v, want 0 (default)", cfg.Arms[1].CostWeight)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCLIAgentsSection_Absent_NilMap(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "config.toml")
|
||||
|
||||
@@ -30,6 +30,24 @@ type Arm struct {
|
||||
// Zero means no ceiling (default for all existing arms).
|
||||
MaxComplexity float64
|
||||
|
||||
// Strengths lists task types where this arm is preferred. When any
|
||||
// listed task type matches an incoming task, the arm crosses tier
|
||||
// boundaries during selection — Opus tagged with TaskSecurityReview
|
||||
// can beat a CLI-agent tier-1 arm for that task type, for example.
|
||||
// Strengths are a preference, not a pin: if no strength-matching arm
|
||||
// is feasible (rate-limited, backoff), selection falls back to the
|
||||
// default tier order.
|
||||
Strengths []TaskType
|
||||
|
||||
// CostWeight scales how much per-arm cost matters during scoring.
|
||||
// effectiveCost = 1 + CostWeight*(cost-1):
|
||||
// - 1.0 (or zero, which is normalized to 1.0): current behavior.
|
||||
// - 0.5: half-weight cost — pricey arms penalized less.
|
||||
// - 0.0: cost ignored, pure quality wins.
|
||||
// Use sub-1.0 values for task types where being right matters more
|
||||
// than being cheap (e.g. SecurityReview).
|
||||
CostWeight float64
|
||||
|
||||
// Cost per 1k tokens (EUR, estimated)
|
||||
CostPer1kInput float64
|
||||
CostPer1kOutput float64
|
||||
@@ -72,6 +90,29 @@ func (a *Arm) SupportsTools() bool {
|
||||
return a.Capabilities.ToolUse
|
||||
}
|
||||
|
||||
// HasStrength reports whether the arm is tagged as strong at the given task
|
||||
// type. Used by the selector to consider cross-tier promotion.
|
||||
func (a *Arm) HasStrength(t TaskType) bool {
|
||||
for _, s := range a.Strengths {
|
||||
if s == t {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ResolvedCostWeight normalizes the CostWeight field. A zero value means
|
||||
// "unset" and is treated as 1.0 (current full-cost behavior). Users who
|
||||
// want minimal cost influence set a small positive value like 0.05 — no
|
||||
// real use case wants exactly zero ("ignore cost entirely") and 0 doubles
|
||||
// as the Go zero value for arms registered before this field existed.
|
||||
func (a *Arm) ResolvedCostWeight() float64 {
|
||||
if a.CostWeight == 0 {
|
||||
return 1.0
|
||||
}
|
||||
return a.CostWeight
|
||||
}
|
||||
|
||||
// perfAlpha is the EMA smoothing factor for ArmPerf updates (0.3 = ~3-sample memory).
|
||||
const perfAlpha = 0.3
|
||||
|
||||
|
||||
@@ -219,6 +219,47 @@ func (r *Router) QualityTracker() *QualityTracker {
|
||||
return r.quality
|
||||
}
|
||||
|
||||
// ArmOverride is a config-supplied tweak to a registered arm. Use it to
|
||||
// declare per-task strengths and a CostWeight override.
|
||||
type ArmOverride struct {
|
||||
ID string // ArmID as registered (e.g. "anthropic/claude-opus-4-7")
|
||||
Strengths []string // task-type names, parsed via ParseTaskType
|
||||
CostWeight float64 // 0 leaves arm's current CostWeight untouched
|
||||
}
|
||||
|
||||
// ApplyArmOverrides walks the override list, locates each by ID, and
|
||||
// applies the requested Strengths/CostWeight in place. Returns the list of
|
||||
// IDs that did not match a registered arm so the caller can warn about
|
||||
// typos. Apply after all arms have been registered.
|
||||
func (r *Router) ApplyArmOverrides(overrides []ArmOverride) (unknownIDs []string) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
for _, ov := range overrides {
|
||||
arm, ok := r.arms[ArmID(ov.ID)]
|
||||
if !ok {
|
||||
unknownIDs = append(unknownIDs, ov.ID)
|
||||
continue
|
||||
}
|
||||
if len(ov.Strengths) > 0 {
|
||||
parsed := make([]TaskType, 0, len(ov.Strengths))
|
||||
for _, s := range ov.Strengths {
|
||||
t, ok := ParseTaskTypeStrict(s)
|
||||
if !ok {
|
||||
r.logger.Warn("unknown strength task-type; skipping",
|
||||
"arm", ov.ID, "strength", s)
|
||||
continue
|
||||
}
|
||||
parsed = append(parsed, t)
|
||||
}
|
||||
arm.Strengths = parsed
|
||||
}
|
||||
if ov.CostWeight != 0 {
|
||||
arm.CostWeight = ov.CostWeight
|
||||
}
|
||||
}
|
||||
return unknownIDs
|
||||
}
|
||||
|
||||
// Arms returns all registered arms.
|
||||
func (r *Router) Arms() []*Arm {
|
||||
r.mu.RLock()
|
||||
|
||||
@@ -56,13 +56,32 @@ func armTier(arm *Arm, task Task) int {
|
||||
return 3
|
||||
}
|
||||
|
||||
// selectBest picks the best arm, preferring lower-tier arms first.
|
||||
// Within a tier, the highest-scoring arm (by quality/cost) wins.
|
||||
// selectBest picks the best arm.
|
||||
//
|
||||
// Step 1: arms whose Strengths list contains task.Type cross all tier
|
||||
// boundaries — Opus tagged with SecurityReview beats a CLI-agent tier-1
|
||||
// arm for that task. Strengths are a preference, not a pin: if no
|
||||
// strength-matching arm is in the input set (filterFeasible already
|
||||
// removed arms in backoff, lacking tool support, or out of pool capacity),
|
||||
// selection falls through to the default tier order.
|
||||
//
|
||||
// Step 2 (fallback): walk tiers low→high. Within a tier, highest-scoring
|
||||
// arm wins.
|
||||
func selectBest(qt *QualityTracker, arms []*Arm, task Task) *Arm {
|
||||
if len(arms) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var promoted []*Arm
|
||||
for _, arm := range arms {
|
||||
if arm.HasStrength(task.Type) {
|
||||
promoted = append(promoted, arm)
|
||||
}
|
||||
}
|
||||
if len(promoted) > 0 {
|
||||
return bestScored(qt, promoted, task)
|
||||
}
|
||||
|
||||
for tier := 0; tier <= 3; tier++ {
|
||||
var inTier []*Arm
|
||||
for _, arm := range arms {
|
||||
@@ -91,10 +110,23 @@ func bestScored(qt *QualityTracker, arms []*Arm, task Task) *Arm {
|
||||
return best
|
||||
}
|
||||
|
||||
// strengthScoreBonus is added to quality when an arm's Strengths list
|
||||
// matches the incoming task type. Tunable in one place.
|
||||
const strengthScoreBonus = 0.15
|
||||
|
||||
// scoreArm computes a quality/cost score for an arm.
|
||||
// When the quality tracker has sufficient observations, blends observed EMA
|
||||
// (70%) with heuristic (30%). Falls back to pure heuristic otherwise.
|
||||
// Score = (quality × value) / effective_cost
|
||||
//
|
||||
// Strengths add a fixed bonus to quality when matching task.Type. CostWeight
|
||||
// dampens the cost penalty linearly:
|
||||
//
|
||||
// effectiveCost = 1 + CostWeight * (cost - 1)
|
||||
//
|
||||
// With CostWeight=1.0 (or unset → resolved to 1.0) the formula collapses to
|
||||
// the original effectiveCost == cost. With CostWeight=0 cost is fully
|
||||
// ignored (effectiveCost = 1.0). Local arms with sub-1 raw costs are not
|
||||
// amplified by fractional weights (the linear formula stays monotone).
|
||||
func scoreArm(qt *QualityTracker, arm *Arm, task Task) float64 {
|
||||
hq := heuristicQuality(arm, task)
|
||||
quality := hq
|
||||
@@ -103,12 +135,19 @@ func scoreArm(qt *QualityTracker, arm *Arm, task Task) float64 {
|
||||
quality = 0.7*observed + 0.3*hq
|
||||
}
|
||||
}
|
||||
value := task.ValueScore()
|
||||
cost := effectiveCost(arm, task)
|
||||
if cost <= 0 {
|
||||
cost = 0.001
|
||||
if arm.HasStrength(task.Type) {
|
||||
quality += strengthScoreBonus
|
||||
}
|
||||
return (quality * value) / cost
|
||||
value := task.ValueScore()
|
||||
rawCost := effectiveCost(arm, task)
|
||||
if rawCost <= 0 {
|
||||
rawCost = 0.001
|
||||
}
|
||||
weighted := 1.0 + arm.ResolvedCostWeight()*(rawCost-1.0)
|
||||
if weighted <= 0 {
|
||||
weighted = 0.001
|
||||
}
|
||||
return (quality * value) / weighted
|
||||
}
|
||||
|
||||
// heuristicQuality estimates arm quality without historical data.
|
||||
|
||||
@@ -0,0 +1,349 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"somegit.dev/Owlibou/gnoma/internal/provider"
|
||||
)
|
||||
|
||||
func timeFuture() time.Time { return time.Now().Add(1 * time.Hour) }
|
||||
|
||||
func TestArm_HasStrength(t *testing.T) {
|
||||
a := &Arm{Strengths: []TaskType{TaskSecurityReview, TaskPlanning}}
|
||||
if !a.HasStrength(TaskSecurityReview) {
|
||||
t.Error("HasStrength(SecurityReview) = false, want true")
|
||||
}
|
||||
if !a.HasStrength(TaskPlanning) {
|
||||
t.Error("HasStrength(Planning) = false, want true")
|
||||
}
|
||||
if a.HasStrength(TaskDebug) {
|
||||
t.Error("HasStrength(Debug) = true, want false")
|
||||
}
|
||||
empty := &Arm{}
|
||||
if empty.HasStrength(TaskSecurityReview) {
|
||||
t.Error("empty Strengths should never match")
|
||||
}
|
||||
}
|
||||
|
||||
func TestArm_ResolvedCostWeight(t *testing.T) {
|
||||
cases := []struct {
|
||||
in, want float64
|
||||
}{
|
||||
{0, 1.0}, // unset → 1.0
|
||||
{1.0, 1.0}, // explicit 1.0 → 1.0
|
||||
{0.5, 0.5},
|
||||
{0.05, 0.05},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
a := &Arm{CostWeight: tc.in}
|
||||
if got := a.ResolvedCostWeight(); got != tc.want {
|
||||
t.Errorf("CostWeight=%v: ResolvedCostWeight() = %v, want %v", tc.in, got, tc.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestScoreArm_CostWeightAffectsArmComparison(t *testing.T) {
|
||||
// The semantically meaningful test: two arms with different costs but
|
||||
// otherwise identical. At CostWeight=1.0 (current behavior), the cheap
|
||||
// arm wins. At CostWeight=0.0 (cost ignored), they tie on quality —
|
||||
// and the slightly-higher-quality one wins.
|
||||
cheap := &Arm{
|
||||
ID: NewArmID("provA", "small"),
|
||||
Capabilities: provider.Capabilities{ToolUse: true, ContextWindow: 100000},
|
||||
CostPer1kInput: 0.0005,
|
||||
CostPer1kOutput: 0.0015,
|
||||
}
|
||||
expensive := &Arm{
|
||||
ID: NewArmID("provB", "big"),
|
||||
Capabilities: provider.Capabilities{ToolUse: true, ContextWindow: 200000}, // slight quality edge
|
||||
CostPer1kInput: 0.015,
|
||||
CostPer1kOutput: 0.075,
|
||||
}
|
||||
task := Task{Type: TaskDebug, EstimatedTokens: 5000, RequiresTools: true, Priority: PriorityNormal}
|
||||
|
||||
// CostWeight=1.0: cost dominates, cheap arm wins.
|
||||
cheap.CostWeight, expensive.CostWeight = 1.0, 1.0
|
||||
if scoreArm(nil, cheap, task) <= scoreArm(nil, expensive, task) {
|
||||
t.Errorf("CostWeight=1.0: cheap arm should beat expensive arm; cheap=%v expensive=%v",
|
||||
scoreArm(nil, cheap, task), scoreArm(nil, expensive, task))
|
||||
}
|
||||
|
||||
// CostWeight=0.0: cost ignored, quality alone decides → expensive (better
|
||||
// context window) wins.
|
||||
cheap.CostWeight, expensive.CostWeight = 0.001, 0.001
|
||||
if scoreArm(nil, expensive, task) <= scoreArm(nil, cheap, task) {
|
||||
t.Errorf("CostWeight~0: higher-quality expensive arm should beat cheap arm; expensive=%v cheap=%v",
|
||||
scoreArm(nil, expensive, task), scoreArm(nil, cheap, task))
|
||||
}
|
||||
}
|
||||
|
||||
func TestScoreArm_LinearFormulaMonotone(t *testing.T) {
|
||||
// Regression: the original draft used cost^CostWeight, which inverts
|
||||
// direction when cost<1 (local arms). The linear formula
|
||||
// effectiveCost = 1 + CostWeight*(cost-1)
|
||||
// is monotone: increasing CostWeight monotonically pulls effectiveCost
|
||||
// toward the raw cost regardless of whether cost is above or below 1.
|
||||
//
|
||||
// Verify monotonicity on both sides of cost=1.
|
||||
cheap := &Arm{ // cost < 1
|
||||
CostPer1kInput: 0.001,
|
||||
CostPer1kOutput: 0.001,
|
||||
}
|
||||
expensive := &Arm{ // cost > 1 for big tasks
|
||||
CostPer1kInput: 0.05,
|
||||
CostPer1kOutput: 0.15,
|
||||
}
|
||||
task := Task{Type: TaskDebug, EstimatedTokens: 20000}
|
||||
|
||||
weights := []float64{0.05, 0.25, 0.5, 0.75, 1.0}
|
||||
for _, name := range []string{"cheap", "expensive"} {
|
||||
var prev float64
|
||||
for i, w := range weights {
|
||||
arm := cheap
|
||||
if name == "expensive" {
|
||||
arm = expensive
|
||||
}
|
||||
arm.CostWeight = w
|
||||
raw := effectiveCost(arm, task)
|
||||
weighted := 1.0 + arm.ResolvedCostWeight()*(raw-1.0)
|
||||
if i == 0 {
|
||||
prev = weighted
|
||||
continue
|
||||
}
|
||||
// As w increases, weighted should move toward raw.
|
||||
// For cheap (raw<1), weighted should DECREASE.
|
||||
// For expensive (raw>1), weighted should INCREASE.
|
||||
if raw < 1 && weighted > prev {
|
||||
t.Errorf("%s arm w=%v: weighted (%v) increased from prev (%v); raw=%v",
|
||||
name, w, weighted, prev, raw)
|
||||
}
|
||||
if raw > 1 && weighted < prev {
|
||||
t.Errorf("%s arm w=%v: weighted (%v) decreased from prev (%v); raw=%v",
|
||||
name, w, weighted, prev, raw)
|
||||
}
|
||||
prev = weighted
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestScoreArm_StrengthBonus(t *testing.T) {
|
||||
withoutStrength := &Arm{
|
||||
ID: NewArmID("anthropic", "opus"),
|
||||
Capabilities: provider.Capabilities{ToolUse: true, ContextWindow: 200000},
|
||||
}
|
||||
withStrength := &Arm{
|
||||
ID: NewArmID("anthropic", "opus"),
|
||||
Capabilities: provider.Capabilities{ToolUse: true, ContextWindow: 200000},
|
||||
Strengths: []TaskType{TaskSecurityReview},
|
||||
}
|
||||
task := Task{Type: TaskSecurityReview, EstimatedTokens: 5000, RequiresTools: true, Priority: PriorityNormal}
|
||||
|
||||
a := scoreArm(nil, withoutStrength, task)
|
||||
b := scoreArm(nil, withStrength, task)
|
||||
if !(b > a) {
|
||||
t.Errorf("strength-tagged arm score (%v) should exceed plain arm score (%v)", b, a)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScoreArm_StrengthBonusDoesNotApplyToOtherTasks(t *testing.T) {
|
||||
// Strengths apply only to listed task types.
|
||||
tagged := &Arm{
|
||||
ID: NewArmID("anthropic", "opus"),
|
||||
Capabilities: provider.Capabilities{ToolUse: true, ContextWindow: 200000},
|
||||
Strengths: []TaskType{TaskSecurityReview},
|
||||
}
|
||||
plain := &Arm{
|
||||
ID: NewArmID("anthropic", "opus"),
|
||||
Capabilities: provider.Capabilities{ToolUse: true, ContextWindow: 200000},
|
||||
}
|
||||
task := Task{Type: TaskDebug, EstimatedTokens: 5000, RequiresTools: true, Priority: PriorityNormal}
|
||||
|
||||
a := scoreArm(nil, plain, task)
|
||||
b := scoreArm(nil, tagged, task)
|
||||
if math.Abs(a-b) > 1e-9 {
|
||||
t.Errorf("non-matching task should ignore Strengths: plain=%v tagged=%v", a, b)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectBest_StrengthPromotedArmBeatsCLIAgent(t *testing.T) {
|
||||
// Plan exit criteria: with Strengths set, Opus (tier 3) wins over a CLI
|
||||
// agent (tier 1) for SecurityReview.
|
||||
cliAgent := &Arm{
|
||||
ID: NewArmID("subprocess", "claude"),
|
||||
IsCLIAgent: true,
|
||||
Capabilities: provider.Capabilities{ToolUse: true, ContextWindow: 200000},
|
||||
}
|
||||
opus := &Arm{
|
||||
ID: NewArmID("anthropic", "opus"),
|
||||
Capabilities: provider.Capabilities{ToolUse: true, ContextWindow: 200000},
|
||||
Strengths: []TaskType{TaskSecurityReview},
|
||||
CostPer1kInput: 0.015,
|
||||
CostPer1kOutput: 0.075,
|
||||
}
|
||||
|
||||
task := Task{Type: TaskSecurityReview, EstimatedTokens: 5000, RequiresTools: true, Priority: PriorityNormal}
|
||||
got := selectBest(nil, []*Arm{cliAgent, opus}, task)
|
||||
if got == nil {
|
||||
t.Fatal("selectBest returned nil")
|
||||
}
|
||||
if got.ID != opus.ID {
|
||||
t.Errorf("selectBest = %s, want %s (strength-promoted arm should beat tier-1 CLI agent)", got.ID, opus.ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectBest_EmptyStrengthsPreservesTierOrder(t *testing.T) {
|
||||
// Regression: without Strengths, CLI-agent tier-1 still wins over API tier-3.
|
||||
cliAgent := &Arm{
|
||||
ID: NewArmID("subprocess", "claude"),
|
||||
IsCLIAgent: true,
|
||||
Capabilities: provider.Capabilities{ToolUse: true, ContextWindow: 200000},
|
||||
}
|
||||
opus := &Arm{
|
||||
ID: NewArmID("anthropic", "opus"),
|
||||
Capabilities: provider.Capabilities{ToolUse: true, ContextWindow: 200000},
|
||||
CostPer1kInput: 0.015,
|
||||
CostPer1kOutput: 0.075,
|
||||
}
|
||||
|
||||
task := Task{Type: TaskSecurityReview, EstimatedTokens: 5000, RequiresTools: true, Priority: PriorityNormal}
|
||||
got := selectBest(nil, []*Arm{cliAgent, opus}, task)
|
||||
if got.ID != cliAgent.ID {
|
||||
t.Errorf("without Strengths, CLI-agent tier-1 should win; got %s", got.ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_Select_PromotedArmInBackoffFallsThroughToTierOrder(t *testing.T) {
|
||||
// Strengths are preference, not pin. Full Router.Select path: backoff
|
||||
// filtering removes the promoted arm; selectBest then falls through to
|
||||
// the default tier order and picks the CLI agent.
|
||||
cliAgent := &Arm{
|
||||
ID: NewArmID("subprocess", "claude"),
|
||||
IsCLIAgent: true,
|
||||
Capabilities: provider.Capabilities{ToolUse: true, ContextWindow: 200000},
|
||||
}
|
||||
opus := &Arm{
|
||||
ID: NewArmID("anthropic", "opus"),
|
||||
Capabilities: provider.Capabilities{ToolUse: true, ContextWindow: 200000},
|
||||
Strengths: []TaskType{TaskSecurityReview},
|
||||
}
|
||||
opus.SetBackoff(timeFuture())
|
||||
|
||||
r := New(Config{})
|
||||
r.RegisterArm(cliAgent)
|
||||
r.RegisterArm(opus)
|
||||
|
||||
task := Task{Type: TaskSecurityReview, EstimatedTokens: 5000, RequiresTools: true, Priority: PriorityNormal}
|
||||
decision := r.Select(task)
|
||||
if decision.Error != nil {
|
||||
t.Fatalf("Select: %v", decision.Error)
|
||||
}
|
||||
if decision.Arm.ID != cliAgent.ID {
|
||||
t.Errorf("promoted arm in backoff should fall through to CLI agent; got %s", decision.Arm.ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyArmOverrides_ApplyStrengthsAndCostWeight(t *testing.T) {
|
||||
r := New(Config{})
|
||||
opus := &Arm{
|
||||
ID: NewArmID("anthropic", "opus"),
|
||||
Capabilities: provider.Capabilities{ToolUse: true, ContextWindow: 200000},
|
||||
}
|
||||
r.RegisterArm(opus)
|
||||
|
||||
unknown := r.ApplyArmOverrides([]ArmOverride{
|
||||
{
|
||||
ID: "anthropic/opus",
|
||||
Strengths: []string{"security_review", "planning"},
|
||||
CostWeight: 0.3,
|
||||
},
|
||||
})
|
||||
if len(unknown) != 0 {
|
||||
t.Errorf("unknown = %v, want empty", unknown)
|
||||
}
|
||||
|
||||
got, _ := r.LookupArm(NewArmID("anthropic", "opus"))
|
||||
if !got.HasStrength(TaskSecurityReview) {
|
||||
t.Error("opus should have SecurityReview strength after override")
|
||||
}
|
||||
if !got.HasStrength(TaskPlanning) {
|
||||
t.Error("opus should have Planning strength after override")
|
||||
}
|
||||
if got.CostWeight != 0.3 {
|
||||
t.Errorf("opus.CostWeight = %v, want 0.3", got.CostWeight)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyArmOverrides_UnknownIDReported(t *testing.T) {
|
||||
r := New(Config{})
|
||||
r.RegisterArm(&Arm{
|
||||
ID: NewArmID("anthropic", "opus"),
|
||||
Capabilities: provider.Capabilities{ToolUse: true},
|
||||
})
|
||||
|
||||
unknown := r.ApplyArmOverrides([]ArmOverride{
|
||||
{ID: "anthropic/opus", Strengths: []string{"debug"}},
|
||||
{ID: "anthropic/typo-here", Strengths: []string{"refactor"}},
|
||||
})
|
||||
if len(unknown) != 1 || unknown[0] != "anthropic/typo-here" {
|
||||
t.Errorf("unknown = %v, want [anthropic/typo-here]", unknown)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyArmOverrides_UnknownStrengthSkipped(t *testing.T) {
|
||||
r := New(Config{})
|
||||
arm := &Arm{
|
||||
ID: NewArmID("anthropic", "opus"),
|
||||
Capabilities: provider.Capabilities{ToolUse: true},
|
||||
}
|
||||
r.RegisterArm(arm)
|
||||
|
||||
r.ApplyArmOverrides([]ArmOverride{
|
||||
{ID: "anthropic/opus", Strengths: []string{"security_review", "bogus-type"}},
|
||||
})
|
||||
|
||||
got, _ := r.LookupArm(NewArmID("anthropic", "opus"))
|
||||
if !got.HasStrength(TaskSecurityReview) {
|
||||
t.Error("security_review should be applied")
|
||||
}
|
||||
if len(got.Strengths) != 1 {
|
||||
t.Errorf("got.Strengths = %v, want [security_review] only (bogus skipped)", got.Strengths)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectBest_MultiplePromotedArmsBestQualityWins(t *testing.T) {
|
||||
// Tunability check: when two arms both have Strengths for the same task,
|
||||
// observed quality (via QualityTracker) should determine the winner, not
|
||||
// the static strength bonus alone.
|
||||
armA := &Arm{
|
||||
ID: NewArmID("provA", "model"),
|
||||
Capabilities: provider.Capabilities{ToolUse: true, ContextWindow: 200000},
|
||||
Strengths: []TaskType{TaskSecurityReview},
|
||||
}
|
||||
armB := &Arm{
|
||||
ID: NewArmID("provB", "model"),
|
||||
Capabilities: provider.Capabilities{ToolUse: true, ContextWindow: 200000},
|
||||
Strengths: []TaskType{TaskSecurityReview},
|
||||
}
|
||||
|
||||
qt := NewQualityTracker()
|
||||
// armB has consistently succeeded — minObservations=3 is enough to flip
|
||||
// the score blend.
|
||||
for i := 0; i < 5; i++ {
|
||||
qt.Record(armB.ID, TaskSecurityReview, true)
|
||||
}
|
||||
// armA fails consistently.
|
||||
for i := 0; i < 5; i++ {
|
||||
qt.Record(armA.ID, TaskSecurityReview, false)
|
||||
}
|
||||
|
||||
task := Task{Type: TaskSecurityReview, EstimatedTokens: 5000, RequiresTools: true, Priority: PriorityNormal}
|
||||
got := selectBest(qt, []*Arm{armA, armB}, task)
|
||||
if got == nil {
|
||||
t.Fatal("selectBest returned nil")
|
||||
}
|
||||
if got.ID != armB.ID {
|
||||
t.Errorf("observed-quality winner should beat tied-strength loser; got %s", got.ID)
|
||||
}
|
||||
}
|
||||
+24
-15
@@ -347,31 +347,40 @@ func estimateComplexity(prompt string) float64 {
|
||||
return score
|
||||
}
|
||||
|
||||
// ParseTaskType converts a string from an SLM JSON response to a TaskType.
|
||||
// Matching is case-insensitive. Unknown strings fall back to TaskGeneration.
|
||||
func ParseTaskType(s string) TaskType {
|
||||
// ParseTaskTypeStrict is like ParseTaskType but reports whether the input
|
||||
// matched a known type. Used by config wiring to surface typos in
|
||||
// user-supplied task-type names instead of silently falling back to
|
||||
// TaskGeneration.
|
||||
func ParseTaskTypeStrict(s string) (TaskType, bool) {
|
||||
switch strings.ToLower(strings.ReplaceAll(s, "_", "")) {
|
||||
case "debug":
|
||||
return TaskDebug
|
||||
return TaskDebug, true
|
||||
case "explain":
|
||||
return TaskExplain
|
||||
return TaskExplain, true
|
||||
case "generation":
|
||||
return TaskGeneration
|
||||
return TaskGeneration, true
|
||||
case "refactor":
|
||||
return TaskRefactor
|
||||
return TaskRefactor, true
|
||||
case "unittest":
|
||||
return TaskUnitTest
|
||||
return TaskUnitTest, true
|
||||
case "boilerplate":
|
||||
return TaskBoilerplate
|
||||
return TaskBoilerplate, true
|
||||
case "planning":
|
||||
return TaskPlanning
|
||||
return TaskPlanning, true
|
||||
case "orchestration":
|
||||
return TaskOrchestration
|
||||
return TaskOrchestration, true
|
||||
case "securityreview":
|
||||
return TaskSecurityReview
|
||||
return TaskSecurityReview, true
|
||||
case "review":
|
||||
return TaskReview
|
||||
default:
|
||||
return TaskGeneration
|
||||
return TaskReview, true
|
||||
}
|
||||
return TaskGeneration, false
|
||||
}
|
||||
|
||||
// ParseTaskType converts a string from an SLM JSON response to a TaskType.
|
||||
// Matching is case-insensitive. Unknown strings fall back to TaskGeneration.
|
||||
// Use ParseTaskTypeStrict when you need to detect typos.
|
||||
func ParseTaskType(s string) TaskType {
|
||||
t, _ := ParseTaskTypeStrict(s)
|
||||
return t
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user