diff --git a/cmd/gnoma/main.go b/cmd/gnoma/main.go index c562ac4..9486e21 100644 --- a/cmd/gnoma/main.go +++ b/cmd/gnoma/main.go @@ -397,7 +397,17 @@ func main() { // Create router and register the provider as a single arm // (M4 foundation: one provider from CLI. Multi-provider routing comes with config.) - rtr := router.New(router.Config{Logger: logger}) + // BanditParams come from [router.bandit] config keys; zero values + // resolve to built-in defaults inside the router package. + rtr := router.New(router.Config{ + Logger: logger, + Bandit: router.BanditParams{ + QualityAlpha: cfg.Router.Bandit.QualityAlpha, + MinObservations: cfg.Router.Bandit.MinObservations, + ObservedWeight: cfg.Router.Bandit.ObservedWeight, + StrengthBonus: cfg.Router.Bandit.StrengthBonus, + }, + }) // Apply the prefer-routing-policy from config (default: auto). // Invalid values are rejected here with an actionable error rather diff --git a/internal/config/config.go b/internal/config/config.go index 3fd3bf8..2c7ed48 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -157,6 +157,40 @@ type RouterSection struct { // and incognito take priority over this knob. See // docs/superpowers/plans/2026-05-23-prefer-routing-policy.md. Prefer string `toml:"prefer"` + + // Bandit exposes the selector's tuning knobs. Defaults preserve + // previous hard-coded behaviour exactly; only set these when you + // need to tune the EMA quality tracker for an unusual workload. + Bandit BanditSection `toml:"bandit"` +} + +// BanditSection holds the scoring knobs for the EMA quality tracker +// and the score blend used by the selector. Each field has a sentinel +// zero value that means "use the built-in default" so an empty TOML +// block is byte-identical to pre-config behaviour. See +// internal/router/feedback.go and internal/router/selector.go for the +// formulas these knobs feed into. +type BanditSection struct { + // QualityAlpha is the EMA smoothing factor for arm-quality + // observations. Larger values weight recent observations more. + // Default: 0.3 (~3-sample memory). 0.0 here means "use default". + QualityAlpha float64 `toml:"quality_alpha"` + + // MinObservations is the minimum number of samples required + // before observed EMA overrides the heuristic fallback. Default: + // 3. 0 here means "use default". + MinObservations int `toml:"min_observations"` + + // ObservedWeight is the weight of the observed EMA in the + // observed/heuristic blend inside scoreArm: the final quality is + // `observed*W + heuristic*(1-W)`. Default: 0.7. 0.0 here means + // "use default". + ObservedWeight float64 `toml:"observed_weight"` + + // StrengthBonus is the quality bonus added when an arm declares + // the current task type in its Strengths list. Default: 0.15. + // 0.0 here means "use default". + StrengthBonus float64 `toml:"strength_bonus"` } // MCPServerConfig defines an MCP server to start and connect to. diff --git a/internal/router/bench_test.go b/internal/router/bench_test.go index 6466f83..52df89b 100644 --- a/internal/router/bench_test.go +++ b/internal/router/bench_test.go @@ -57,12 +57,12 @@ func benchTasks() []Task { func BenchmarkSelectBest(b *testing.B) { arms := benchArms() tasks := benchTasks() - qt := NewQualityTracker() + qt := NewQualityTracker(0, 0) b.ResetTimer() for b.Loop() { for _, task := range tasks { - selectBest(qt, arms, task, PreferAuto) + selectBest(qt, BanditParams{}, arms, task, PreferAuto) } } } @@ -99,13 +99,13 @@ func BenchmarkRouterSelect(b *testing.B) { func BenchmarkScoreArm(b *testing.B) { arms := benchArms() - qt := NewQualityTracker() + qt := NewQualityTracker(0, 0) task := Task{Type: TaskGeneration, Priority: PriorityNormal, EstimatedTokens: 2000, RequiresTools: true, ComplexityScore: 0.5} b.ResetTimer() for b.Loop() { for _, arm := range arms { - scoreArm(qt, arm, task) + scoreArm(qt, BanditParams{}, arm, task) } } } diff --git a/internal/router/feedback.go b/internal/router/feedback.go index 5efddb9..14ee39e 100644 --- a/internal/router/feedback.go +++ b/internal/router/feedback.go @@ -2,9 +2,15 @@ package router import "sync" +// Built-in defaults for the bandit knobs. Surfaced via +// [router.bandit] config keys; see BanditParams in router.go. Kept +// here so the QualityTracker has a sensible fallback when constructed +// without explicit parameters (tests, ad-hoc callers). const ( - qualityAlpha = 0.3 // EMA smoothing factor (~3-sample memory) - minObservations = 3 // min samples before observed score overrides heuristic + defaultQualityAlpha = 0.3 // EMA smoothing factor (~3-sample memory) + defaultMinObservations = 3 // min samples before observed score overrides heuristic + defaultObservedWeight = 0.7 // weight of observed score in observed/heuristic blend + defaultStrengthBonus = 0.15 ) // EMAScore tracks an exponential moving average quality score. @@ -19,13 +25,27 @@ type QualityTracker struct { mu sync.RWMutex scores map[ArmID]map[TaskType]*EMAScore classifierCount map[ClassifierSource]int + + // Configurable knobs — set via NewQualityTracker. Pass 0 for any + // argument to keep the built-in default. + alpha float64 + minObservations int } -// NewQualityTracker returns an empty QualityTracker. -func NewQualityTracker() *QualityTracker { +// NewQualityTracker returns an empty QualityTracker. Pass 0 for any +// argument to keep the built-in default (alpha=0.3, minObs=3). +func NewQualityTracker(alpha float64, minObs int) *QualityTracker { + if alpha == 0 { + alpha = defaultQualityAlpha + } + if minObs == 0 { + minObs = defaultMinObservations + } return &QualityTracker{ scores: make(map[ArmID]map[TaskType]*EMAScore), classifierCount: make(map[ClassifierSource]int), + alpha: alpha, + minObservations: minObs, } } @@ -71,7 +91,7 @@ func (qt *QualityTracker) Record(armID ArmID, taskType TaskType, success bool) { if s.Count == 0 { s.Value = observation } else { - s.Value = qualityAlpha*observation + (1-qualityAlpha)*s.Value + s.Value = qt.alpha*observation + (1-qt.alpha)*s.Value } s.Count++ } @@ -86,7 +106,7 @@ func (qt *QualityTracker) Quality(armID ArmID, taskType TaskType) (score float64 return 0, false } s, ok := m[taskType] - if !ok || s.Count < minObservations { + if !ok || s.Count < qt.minObservations { return 0, false } return s.Value, true diff --git a/internal/router/feedback_test.go b/internal/router/feedback_test.go index a0e6e4d..83accd4 100644 --- a/internal/router/feedback_test.go +++ b/internal/router/feedback_test.go @@ -8,7 +8,7 @@ import ( ) func TestQualityTracker_NoDataReturnsHeuristic(t *testing.T) { - qt := router.NewQualityTracker() + qt := router.NewQualityTracker(0, 0) _, hasData := qt.Quality("arm:model", router.TaskGeneration) if hasData { t.Error("expected no data for unobserved arm") @@ -16,7 +16,7 @@ func TestQualityTracker_NoDataReturnsHeuristic(t *testing.T) { } func TestQualityTracker_RecordUpdatesEMA(t *testing.T) { - qt := router.NewQualityTracker() + qt := router.NewQualityTracker(0, 0) for i := 0; i < 3; i++ { qt.Record("arm:model", router.TaskGeneration, true) } @@ -30,7 +30,7 @@ func TestQualityTracker_RecordUpdatesEMA(t *testing.T) { } func TestQualityTracker_AllFailuresLowScore(t *testing.T) { - qt := router.NewQualityTracker() + qt := router.NewQualityTracker(0, 0) for i := 0; i < 5; i++ { qt.Record("arm:model", router.TaskDebug, false) } @@ -41,7 +41,7 @@ func TestQualityTracker_AllFailuresLowScore(t *testing.T) { } func TestQualityTracker_ConcurrentSafe(t *testing.T) { - qt := router.NewQualityTracker() + qt := router.NewQualityTracker(0, 0) done := make(chan struct{}) for i := 0; i < 10; i++ { go func(success bool) { @@ -113,3 +113,45 @@ func TestQualityTracker_InsufficientDataFallsBackToHeuristic(t *testing.T) { } decision.Rollback() } + +func TestQualityTracker_CustomAlphaShortensMemory(t *testing.T) { + // alpha=0.9 weights the latest sample heavily; after a single + // failure the score should drop further than with the default 0.3. + fast := router.NewQualityTracker(0.9, 0) + slow := router.NewQualityTracker(0.0, 0) // 0 → default 0.3 + + for _, qt := range []*router.QualityTracker{fast, slow} { + // Build up history at the high end with 5 successes. + for i := 0; i < 5; i++ { + qt.Record("arm:m", router.TaskGeneration, true) + } + // One failure. + qt.Record("arm:m", router.TaskGeneration, false) + } + + fastScore, _ := fast.Quality("arm:m", router.TaskGeneration) + slowScore, _ := slow.Quality("arm:m", router.TaskGeneration) + + if !(fastScore < slowScore) { + t.Errorf("expected fast alpha (0.9) to drop quality faster than default (0.3): fast=%f slow=%f", fastScore, slowScore) + } +} + +func TestQualityTracker_CustomMinObservationsGatesScore(t *testing.T) { + // minObs=10 means Quality should return hasData=false until 10 + // observations are recorded, even though the default would say + // "yes" after 3. + qt := router.NewQualityTracker(0, 10) + for i := 0; i < 5; i++ { + qt.Record("arm:m", router.TaskGeneration, true) + } + if _, hasData := qt.Quality("arm:m", router.TaskGeneration); hasData { + t.Error("expected hasData=false at 5 observations with minObs=10") + } + for i := 0; i < 5; i++ { + qt.Record("arm:m", router.TaskGeneration, true) + } + if _, hasData := qt.Quality("arm:m", router.TaskGeneration); !hasData { + t.Error("expected hasData=true after 10 observations with minObs=10") + } +} diff --git a/internal/router/quality_json_test.go b/internal/router/quality_json_test.go index 06829ab..24fd523 100644 --- a/internal/router/quality_json_test.go +++ b/internal/router/quality_json_test.go @@ -8,7 +8,7 @@ import ( ) func TestQualityTracker_SnapshotRestore_RoundTrip(t *testing.T) { - qt := router.NewQualityTracker() + qt := router.NewQualityTracker(0, 0) // Record some outcomes qt.Record("anthropic/claude-3-5-sonnet", router.TaskGeneration, true) qt.Record("anthropic/claude-3-5-sonnet", router.TaskGeneration, true) @@ -33,7 +33,7 @@ func TestQualityTracker_SnapshotRestore_RoundTrip(t *testing.T) { } // Restore into a fresh tracker - qt2 := router.NewQualityTracker() + qt2 := router.NewQualityTracker(0, 0) qt2.Restore(restored) // After restore, Quality() should return data (Count >= minObservations=3) @@ -47,7 +47,7 @@ func TestQualityTracker_SnapshotRestore_RoundTrip(t *testing.T) { } func TestQualityTracker_Snapshot_Empty(t *testing.T) { - qt := router.NewQualityTracker() + qt := router.NewQualityTracker(0, 0) snap := qt.Snapshot() if snap.Scores == nil { t.Error("scores map should be initialized (not nil)") @@ -58,7 +58,7 @@ func TestQualityTracker_Snapshot_Empty(t *testing.T) { } func TestQualityTracker_ClassifierCounts_RecordAndSnapshot(t *testing.T) { - qt := router.NewQualityTracker() + qt := router.NewQualityTracker(0, 0) qt.RecordClassifier(router.ClassifierHeuristic) qt.RecordClassifier(router.ClassifierSLM) qt.RecordClassifier(router.ClassifierSLM) @@ -92,7 +92,7 @@ func TestQualityTracker_ClassifierCounts_RecordAndSnapshot(t *testing.T) { if err := json.Unmarshal(data, &restored); err != nil { t.Fatal(err) } - qt2 := router.NewQualityTracker() + qt2 := router.NewQualityTracker(0, 0) qt2.Restore(restored) if qt2.ClassifierCounts()[router.ClassifierSLM] != 2 { t.Errorf("restored slm count = %d, want 2", qt2.ClassifierCounts()[router.ClassifierSLM]) @@ -107,7 +107,7 @@ func TestQualityTracker_Restore_BackCompat_NoClassifierCounts(t *testing.T) { if err := json.Unmarshal(legacy, &snap); err != nil { t.Fatal(err) } - qt := router.NewQualityTracker() + qt := router.NewQualityTracker(0, 0) qt.Restore(snap) if qt.ClassifierCounts() == nil { t.Error("ClassifierCounts() must return a non-nil map after restoring old snapshot") @@ -122,7 +122,7 @@ func TestQualityTracker_Restore_BackCompat_NoClassifierCounts(t *testing.T) { } func TestQualityTracker_Restore_Replaces(t *testing.T) { - qt := router.NewQualityTracker() + qt := router.NewQualityTracker(0, 0) qt.Record("arm-a", router.TaskDebug, true) qt.Record("arm-a", router.TaskDebug, true) qt.Record("arm-a", router.TaskDebug, true) diff --git a/internal/router/router.go b/internal/router/router.go index d140a6c..4f21304 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -27,6 +27,7 @@ type Router struct { preferPolicy PreferPolicy quality *QualityTracker + bandit BanditParams } // PreferPolicy biases the scoring step toward local or cloud arms. @@ -77,6 +78,41 @@ func (p PreferPolicy) String() string { type Config struct { Logger *slog.Logger + // Bandit tunes the selector's scoring knobs. Pass a zero value to + // keep all pre-config behaviour byte-identical; set individual + // fields to override the corresponding default. + Bandit BanditParams +} + +// BanditParams controls the EMA quality tracker and score blend used +// by the selector. Each field has a "use default" sentinel (0 for +// floats and ints) so a zero-valued BanditParams is byte-identical to +// the pre-config hardcoded constants. Defaults are defined in +// resolveBanditParams below. +type BanditParams struct { + QualityAlpha float64 + MinObservations int + ObservedWeight float64 + StrengthBonus float64 +} + +// resolveBanditParams fills in the built-in defaults for any field +// left at its zero value. Centralised so the same defaults apply +// across NewQualityTracker, scoreArm, and any future caller. +func resolveBanditParams(p BanditParams) BanditParams { + if p.QualityAlpha == 0 { + p.QualityAlpha = defaultQualityAlpha + } + if p.MinObservations == 0 { + p.MinObservations = defaultMinObservations + } + if p.ObservedWeight == 0 { + p.ObservedWeight = defaultObservedWeight + } + if p.StrengthBonus == 0 { + p.StrengthBonus = defaultStrengthBonus + } + return p } func New(cfg Config) *Router { @@ -84,10 +120,12 @@ func New(cfg Config) *Router { if logger == nil { logger = slog.Default() } + params := resolveBanditParams(cfg.Bandit) return &Router{ arms: make(map[ArmID]*Arm), logger: logger, - quality: NewQualityTracker(), + quality: NewQualityTracker(params.QualityAlpha, params.MinObservations), + bandit: params, } } @@ -172,7 +210,7 @@ func (r *Router) Select(task Task) RoutingDecision { } // Select best - best := selectBest(r.quality, feasible, task, r.preferPolicy) + best := selectBest(r.quality, r.bandit, feasible, task, r.preferPolicy) if best == nil { return RoutingDecision{Error: fmt.Errorf("selection failed")} } diff --git a/internal/router/router_test.go b/internal/router/router_test.go index f7d8e28..5964735 100644 --- a/internal/router/router_test.go +++ b/internal/router/router_test.go @@ -262,7 +262,7 @@ func TestSelectBest_PrefersToolSupport(t *testing.T) { } task := Task{Type: TaskGeneration, RequiresTools: true, Priority: PriorityNormal} - best := selectBest(nil, []*Arm{withoutTools, withTools}, task, PreferAuto) + best := selectBest(nil, BanditParams{}, []*Arm{withoutTools, withTools}, task, PreferAuto) if best.ID != "a/with-tools" { t.Errorf("should prefer arm with tool support, got %s", best.ID) @@ -282,7 +282,7 @@ func TestSelectBest_PrefersThinkingForPlanning(t *testing.T) { } task := Task{Type: TaskPlanning, RequiresTools: true, Priority: PriorityNormal, EstimatedTokens: 5000} - best := selectBest(nil, []*Arm{noThinking, thinking}, task, PreferAuto) + best := selectBest(nil, BanditParams{}, []*Arm{noThinking, thinking}, task, PreferAuto) if best.ID != "a/thinking" { t.Errorf("should prefer thinking model for planning, got %s", best.ID) @@ -625,7 +625,7 @@ func TestSelectBest_SmallArmWinsTrivialTask(t *testing.T) { Capabilities: provider.Capabilities{ToolUse: false}, } task := Task{Type: TaskExplain, ComplexityScore: 0.05, RequiresTools: false} - got := selectBest(nil, []*Arm{cliArm, smallArm}, task, PreferAuto) + got := selectBest(nil, BanditParams{}, []*Arm{cliArm, smallArm}, task, PreferAuto) if got != smallArm { t.Errorf("selectBest = %v, want smallArm", got) } @@ -647,7 +647,7 @@ func TestSelectBest_CLIAgentWinsComplexTask(t *testing.T) { Capabilities: provider.Capabilities{ToolUse: false}, } task := Task{Type: TaskRefactor, ComplexityScore: 0.7, RequiresTools: true} - got := selectBest(nil, []*Arm{cliArm, smallArm}, task, PreferAuto) + got := selectBest(nil, BanditParams{}, []*Arm{cliArm, smallArm}, task, PreferAuto) if got != cliArm { t.Errorf("selectBest = %v, want cliArm", got) } @@ -672,21 +672,21 @@ func TestSelectBest_TierPreference(t *testing.T) { task := Task{Type: TaskGeneration, Priority: PriorityNormal, EstimatedTokens: 1000} t.Run("CLI beats local and API", func(t *testing.T) { - best := selectBest(nil, []*Arm{apiArm, localArm, cliArm}, task, PreferAuto) + best := selectBest(nil, BanditParams{}, []*Arm{apiArm, localArm, cliArm}, task, PreferAuto) if best.ID != "subprocess/claude" { t.Errorf("want subprocess/claude (tier 0), got %s", best.ID) } }) t.Run("local beats API when no CLI", func(t *testing.T) { - best := selectBest(nil, []*Arm{apiArm, localArm}, task, PreferAuto) + best := selectBest(nil, BanditParams{}, []*Arm{apiArm, localArm}, task, PreferAuto) if best.ID != "ollama/llama3" { t.Errorf("want ollama/llama3 (tier 1), got %s", best.ID) } }) t.Run("API selected when only option", func(t *testing.T) { - best := selectBest(nil, []*Arm{apiArm}, task, PreferAuto) + best := selectBest(nil, BanditParams{}, []*Arm{apiArm}, task, PreferAuto) if best == nil || best.ID != "mistral/mistral-large" { t.Errorf("want mistral/mistral-large (tier 2), got %v", best) } diff --git a/internal/router/selector.go b/internal/router/selector.go index 177226b..14085bf 100644 --- a/internal/router/selector.go +++ b/internal/router/selector.go @@ -98,7 +98,7 @@ func armBaseTier(arm *Arm, task Task) int { // // Step 2 (fallback): walk tiers low→high. Within a tier, highest-scoring // arm wins. -func selectBest(qt *QualityTracker, arms []*Arm, task Task, prefer PreferPolicy) *Arm { +func selectBest(qt *QualityTracker, params BanditParams, arms []*Arm, task Task, prefer PreferPolicy) *Arm { if len(arms) == 0 { return nil } @@ -110,7 +110,7 @@ func selectBest(qt *QualityTracker, arms []*Arm, task Task, prefer PreferPolicy) } } if len(promoted) > 0 { - return bestScored(qt, promoted, task, prefer) + return bestScored(qt, params, promoted, task, prefer) } // Walk tiers low→high. armTier returns up to 5 when prefer is set @@ -124,18 +124,18 @@ func selectBest(qt *QualityTracker, arms []*Arm, task Task, prefer PreferPolicy) } } if len(inTier) > 0 { - return bestScored(qt, inTier, task, prefer) + return bestScored(qt, params, inTier, task, prefer) } } return nil } // bestScored returns the highest-scoring arm within a set. -func bestScored(qt *QualityTracker, arms []*Arm, task Task, prefer PreferPolicy) *Arm { +func bestScored(qt *QualityTracker, params BanditParams, arms []*Arm, task Task, prefer PreferPolicy) *Arm { var best *Arm bestScore := math.Inf(-1) for _, arm := range arms { - score := scoreArm(qt, arm, task) * policyMultiplier(arm, prefer) + score := scoreArm(qt, params, arm, task) * policyMultiplier(arm, prefer) if score > bestScore { bestScore = score best = arm @@ -172,13 +172,12 @@ func policyMultiplier(arm *Arm, p PreferPolicy) float64 { } } -// strengthScoreBonus is added to quality when an arm's Strengths list -// matches the incoming task type. Tunable in one place. -const strengthScoreBonus = 0.15 - // scoreArm computes a quality/cost score for an arm. // When the quality tracker has sufficient observations, blends observed EMA -// (70%) with heuristic (30%). Falls back to pure heuristic otherwise. +// (default 70%) with heuristic (default 30%). Falls back to pure heuristic +// otherwise. The blend ratio and strength bonus are tunable via +// BanditParams (config: [router.bandit]); a zero-valued params falls back +// to the built-in defaults. // // Strengths add a fixed bonus to quality when matching task.Type. CostWeight // dampens the cost penalty linearly: @@ -189,16 +188,17 @@ const strengthScoreBonus = 0.15 // the original effectiveCost == cost. With CostWeight=0 cost is fully // ignored (effectiveCost = 1.0). Local arms with sub-1 raw costs are not // amplified by fractional weights (the linear formula stays monotone). -func scoreArm(qt *QualityTracker, arm *Arm, task Task) float64 { +func scoreArm(qt *QualityTracker, params BanditParams, arm *Arm, task Task) float64 { + params = resolveBanditParams(params) hq := heuristicQuality(arm, task) quality := hq if qt != nil { if observed, hasData := qt.Quality(arm.ID, task.Type); hasData { - quality = 0.7*observed + 0.3*hq + quality = params.ObservedWeight*observed + (1-params.ObservedWeight)*hq } } if arm.HasStrength(task.Type) { - quality += strengthScoreBonus + quality += params.StrengthBonus } value := task.ValueScore() rawCost := effectiveCost(arm, task) diff --git a/internal/router/strengths_test.go b/internal/router/strengths_test.go index 13eb395..64a1fba 100644 --- a/internal/router/strengths_test.go +++ b/internal/router/strengths_test.go @@ -65,17 +65,17 @@ func TestScoreArm_CostWeightAffectsArmComparison(t *testing.T) { // CostWeight=1.0: cost dominates, cheap arm wins. cheap.CostWeight, expensive.CostWeight = 1.0, 1.0 - if scoreArm(nil, cheap, task) <= scoreArm(nil, expensive, task) { + if scoreArm(nil, BanditParams{}, cheap, task) <= scoreArm(nil, BanditParams{}, expensive, task) { t.Errorf("CostWeight=1.0: cheap arm should beat expensive arm; cheap=%v expensive=%v", - scoreArm(nil, cheap, task), scoreArm(nil, expensive, task)) + scoreArm(nil, BanditParams{}, cheap, task), scoreArm(nil, BanditParams{}, expensive, task)) } // CostWeight=0.0: cost ignored, quality alone decides → expensive (better // context window) wins. cheap.CostWeight, expensive.CostWeight = 0.001, 0.001 - if scoreArm(nil, expensive, task) <= scoreArm(nil, cheap, task) { + if scoreArm(nil, BanditParams{}, expensive, task) <= scoreArm(nil, BanditParams{}, cheap, task) { t.Errorf("CostWeight~0: higher-quality expensive arm should beat cheap arm; expensive=%v cheap=%v", - scoreArm(nil, expensive, task), scoreArm(nil, cheap, task)) + scoreArm(nil, BanditParams{}, expensive, task), scoreArm(nil, BanditParams{}, cheap, task)) } } @@ -140,8 +140,8 @@ func TestScoreArm_StrengthBonus(t *testing.T) { } task := Task{Type: TaskSecurityReview, EstimatedTokens: 5000, RequiresTools: true, Priority: PriorityNormal} - a := scoreArm(nil, withoutStrength, task) - b := scoreArm(nil, withStrength, task) + a := scoreArm(nil, BanditParams{}, withoutStrength, task) + b := scoreArm(nil, BanditParams{}, withStrength, task) if !(b > a) { t.Errorf("strength-tagged arm score (%v) should exceed plain arm score (%v)", b, a) } @@ -160,8 +160,8 @@ func TestScoreArm_StrengthBonusDoesNotApplyToOtherTasks(t *testing.T) { } task := Task{Type: TaskDebug, EstimatedTokens: 5000, RequiresTools: true, Priority: PriorityNormal} - a := scoreArm(nil, plain, task) - b := scoreArm(nil, tagged, task) + a := scoreArm(nil, BanditParams{}, plain, task) + b := scoreArm(nil, BanditParams{}, tagged, task) if math.Abs(a-b) > 1e-9 { t.Errorf("non-matching task should ignore Strengths: plain=%v tagged=%v", a, b) } @@ -184,7 +184,7 @@ func TestSelectBest_StrengthPromotedArmBeatsCLIAgent(t *testing.T) { } task := Task{Type: TaskSecurityReview, EstimatedTokens: 5000, RequiresTools: true, Priority: PriorityNormal} - got := selectBest(nil, []*Arm{cliAgent, opus}, task, PreferAuto) + got := selectBest(nil, BanditParams{}, []*Arm{cliAgent, opus}, task, PreferAuto) if got == nil { t.Fatal("selectBest returned nil") } @@ -208,7 +208,7 @@ func TestSelectBest_EmptyStrengthsPreservesTierOrder(t *testing.T) { } task := Task{Type: TaskSecurityReview, EstimatedTokens: 5000, RequiresTools: true, Priority: PriorityNormal} - got := selectBest(nil, []*Arm{cliAgent, opus}, task, PreferAuto) + got := selectBest(nil, BanditParams{}, []*Arm{cliAgent, opus}, task, PreferAuto) if got.ID != cliAgent.ID { t.Errorf("without Strengths, CLI-agent tier-1 should win; got %s", got.ID) } @@ -327,7 +327,7 @@ func TestSelectBest_MultiplePromotedArmsBestQualityWins(t *testing.T) { Strengths: []TaskType{TaskSecurityReview}, } - qt := NewQualityTracker() + qt := NewQualityTracker(0, 0) // armB has consistently succeeded — minObservations=3 is enough to flip // the score blend. for i := 0; i < 5; i++ { @@ -339,7 +339,7 @@ func TestSelectBest_MultiplePromotedArmsBestQualityWins(t *testing.T) { } task := Task{Type: TaskSecurityReview, EstimatedTokens: 5000, RequiresTools: true, Priority: PriorityNormal} - got := selectBest(qt, []*Arm{armA, armB}, task, PreferAuto) + got := selectBest(qt, BanditParams{}, []*Arm{armA, armB}, task, PreferAuto) if got == nil { t.Fatal("selectBest returned nil") }