feat(router): surface bandit knobs as [router.bandit] config
Four hardcoded constants in the selector and feedback tracker are now
user-tunable via [router.bandit]:
- quality_alpha (EMA smoothing, default 0.3)
- min_observations (samples before observed overrides heuristic, default 3)
- observed_weight (observed/heuristic blend ratio, default 0.7)
- strength_bonus (quality bonus for Strengths-tagged arms, default 0.15)
Each field treats 0 as 'use default', so an empty TOML block behaves
exactly as it did before this change. The trade-off is that an explicit
0 cannot be configured for any knob: it is always replaced by the
default. BanditParams is plumbed via
router.Config{Bandit: ...} and resolveBanditParams() centralises the
fallback so every call site shares the same defaults.
NewQualityTracker, scoreArm, bestScored, and selectBest signatures now
take the configured values directly rather than reaching for package-
level constants. Tests updated to pass BanditParams{} (defaults) or
explicit overrides where they validate the new tuning paths.
Tracks item #3 from the 'Bandit selector — design decisions deferred'
TODO entry — ships independently of the EMA vs SLM strategic decision.
This commit is contained in:
+11
-1
@@ -397,7 +397,17 @@ func main() {
|
||||
|
||||
// Create router and register the provider as a single arm
|
||||
// (M4 foundation: one provider from CLI. Multi-provider routing comes with config.)
|
||||
rtr := router.New(router.Config{Logger: logger})
|
||||
// BanditParams come from [router.bandit] config keys; zero values
|
||||
// resolve to built-in defaults inside the router package.
|
||||
rtr := router.New(router.Config{
|
||||
Logger: logger,
|
||||
Bandit: router.BanditParams{
|
||||
QualityAlpha: cfg.Router.Bandit.QualityAlpha,
|
||||
MinObservations: cfg.Router.Bandit.MinObservations,
|
||||
ObservedWeight: cfg.Router.Bandit.ObservedWeight,
|
||||
StrengthBonus: cfg.Router.Bandit.StrengthBonus,
|
||||
},
|
||||
})
|
||||
|
||||
// Apply the prefer-routing-policy from config (default: auto).
|
||||
// Invalid values are rejected here with an actionable error rather
|
||||
|
||||
@@ -157,6 +157,40 @@ type RouterSection struct {
|
||||
// and incognito take priority over this knob. See
|
||||
// docs/superpowers/plans/2026-05-23-prefer-routing-policy.md.
|
||||
Prefer string `toml:"prefer"`
|
||||
|
||||
// Bandit exposes the selector's tuning knobs. Defaults preserve
|
||||
// previous hard-coded behaviour exactly; only set these when you
|
||||
// need to tune the EMA quality tracker for an unusual workload.
|
||||
Bandit BanditSection `toml:"bandit"`
|
||||
}
|
||||
|
||||
// BanditSection holds the scoring knobs for the EMA quality tracker
|
||||
// and the score blend used by the selector. Each field has a sentinel
|
||||
// zero value that means "use the built-in default" so an empty TOML
|
||||
// block is byte-identical to pre-config behaviour. See
|
||||
// internal/router/feedback.go and internal/router/selector.go for the
|
||||
// formulas these knobs feed into.
|
||||
type BanditSection struct {
|
||||
// QualityAlpha is the EMA smoothing factor for arm-quality
|
||||
// observations. Larger values weight recent observations more.
|
||||
// Default: 0.3 (~3-sample memory). 0.0 here means "use default".
|
||||
QualityAlpha float64 `toml:"quality_alpha"`
|
||||
|
||||
// MinObservations is the minimum number of samples required
|
||||
// before observed EMA overrides the heuristic fallback. Default:
|
||||
// 3. 0 here means "use default".
|
||||
MinObservations int `toml:"min_observations"`
|
||||
|
||||
// ObservedWeight is the weight of the observed EMA in the
|
||||
// observed/heuristic blend inside scoreArm: the final quality is
|
||||
// `observed*W + heuristic*(1-W)`. Default: 0.7. 0.0 here means
|
||||
// "use default".
|
||||
ObservedWeight float64 `toml:"observed_weight"`
|
||||
|
||||
// StrengthBonus is the quality bonus added when an arm declares
|
||||
// the current task type in its Strengths list. Default: 0.15.
|
||||
// 0.0 here means "use default".
|
||||
StrengthBonus float64 `toml:"strength_bonus"`
|
||||
}
|
||||
|
||||
// MCPServerConfig defines an MCP server to start and connect to.
|
||||
|
||||
@@ -57,12 +57,12 @@ func benchTasks() []Task {
|
||||
func BenchmarkSelectBest(b *testing.B) {
|
||||
arms := benchArms()
|
||||
tasks := benchTasks()
|
||||
qt := NewQualityTracker()
|
||||
qt := NewQualityTracker(0, 0)
|
||||
|
||||
b.ResetTimer()
|
||||
for b.Loop() {
|
||||
for _, task := range tasks {
|
||||
selectBest(qt, arms, task, PreferAuto)
|
||||
selectBest(qt, BanditParams{}, arms, task, PreferAuto)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -99,13 +99,13 @@ func BenchmarkRouterSelect(b *testing.B) {
|
||||
|
||||
func BenchmarkScoreArm(b *testing.B) {
|
||||
arms := benchArms()
|
||||
qt := NewQualityTracker()
|
||||
qt := NewQualityTracker(0, 0)
|
||||
task := Task{Type: TaskGeneration, Priority: PriorityNormal, EstimatedTokens: 2000, RequiresTools: true, ComplexityScore: 0.5}
|
||||
|
||||
b.ResetTimer()
|
||||
for b.Loop() {
|
||||
for _, arm := range arms {
|
||||
scoreArm(qt, arm, task)
|
||||
scoreArm(qt, BanditParams{}, arm, task)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,9 +2,15 @@ package router
|
||||
|
||||
import "sync"
|
||||
|
||||
// Built-in defaults for the bandit knobs. Surfaced via
|
||||
// [router.bandit] config keys; see BanditParams in router.go. Kept
|
||||
// here so the QualityTracker has a sensible fallback when constructed
|
||||
// without explicit parameters (tests, ad-hoc callers).
|
||||
const (
|
||||
qualityAlpha = 0.3 // EMA smoothing factor (~3-sample memory)
|
||||
minObservations = 3 // min samples before observed score overrides heuristic
|
||||
defaultQualityAlpha = 0.3 // EMA smoothing factor (~3-sample memory)
|
||||
defaultMinObservations = 3 // min samples before observed score overrides heuristic
|
||||
defaultObservedWeight = 0.7 // weight of observed score in observed/heuristic blend
|
||||
defaultStrengthBonus = 0.15
|
||||
)
|
||||
|
||||
// EMAScore tracks an exponential moving average quality score.
|
||||
@@ -19,13 +25,27 @@ type QualityTracker struct {
|
||||
mu sync.RWMutex
|
||||
scores map[ArmID]map[TaskType]*EMAScore
|
||||
classifierCount map[ClassifierSource]int
|
||||
|
||||
// Configurable knobs — set via NewQualityTracker. Pass 0 for any
|
||||
// argument to keep the built-in default.
|
||||
alpha float64
|
||||
minObservations int
|
||||
}
|
||||
|
||||
// NewQualityTracker returns an empty QualityTracker.
|
||||
func NewQualityTracker() *QualityTracker {
|
||||
// NewQualityTracker returns an empty QualityTracker. Pass 0 for any
|
||||
// argument to keep the built-in default (alpha=0.3, minObs=3).
|
||||
func NewQualityTracker(alpha float64, minObs int) *QualityTracker {
|
||||
if alpha == 0 {
|
||||
alpha = defaultQualityAlpha
|
||||
}
|
||||
if minObs == 0 {
|
||||
minObs = defaultMinObservations
|
||||
}
|
||||
return &QualityTracker{
|
||||
scores: make(map[ArmID]map[TaskType]*EMAScore),
|
||||
classifierCount: make(map[ClassifierSource]int),
|
||||
alpha: alpha,
|
||||
minObservations: minObs,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -71,7 +91,7 @@ func (qt *QualityTracker) Record(armID ArmID, taskType TaskType, success bool) {
|
||||
if s.Count == 0 {
|
||||
s.Value = observation
|
||||
} else {
|
||||
s.Value = qualityAlpha*observation + (1-qualityAlpha)*s.Value
|
||||
s.Value = qt.alpha*observation + (1-qt.alpha)*s.Value
|
||||
}
|
||||
s.Count++
|
||||
}
|
||||
@@ -86,7 +106,7 @@ func (qt *QualityTracker) Quality(armID ArmID, taskType TaskType) (score float64
|
||||
return 0, false
|
||||
}
|
||||
s, ok := m[taskType]
|
||||
if !ok || s.Count < minObservations {
|
||||
if !ok || s.Count < qt.minObservations {
|
||||
return 0, false
|
||||
}
|
||||
return s.Value, true
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
)
|
||||
|
||||
func TestQualityTracker_NoDataReturnsHeuristic(t *testing.T) {
|
||||
qt := router.NewQualityTracker()
|
||||
qt := router.NewQualityTracker(0, 0)
|
||||
_, hasData := qt.Quality("arm:model", router.TaskGeneration)
|
||||
if hasData {
|
||||
t.Error("expected no data for unobserved arm")
|
||||
@@ -16,7 +16,7 @@ func TestQualityTracker_NoDataReturnsHeuristic(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestQualityTracker_RecordUpdatesEMA(t *testing.T) {
|
||||
qt := router.NewQualityTracker()
|
||||
qt := router.NewQualityTracker(0, 0)
|
||||
for i := 0; i < 3; i++ {
|
||||
qt.Record("arm:model", router.TaskGeneration, true)
|
||||
}
|
||||
@@ -30,7 +30,7 @@ func TestQualityTracker_RecordUpdatesEMA(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestQualityTracker_AllFailuresLowScore(t *testing.T) {
|
||||
qt := router.NewQualityTracker()
|
||||
qt := router.NewQualityTracker(0, 0)
|
||||
for i := 0; i < 5; i++ {
|
||||
qt.Record("arm:model", router.TaskDebug, false)
|
||||
}
|
||||
@@ -41,7 +41,7 @@ func TestQualityTracker_AllFailuresLowScore(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestQualityTracker_ConcurrentSafe(t *testing.T) {
|
||||
qt := router.NewQualityTracker()
|
||||
qt := router.NewQualityTracker(0, 0)
|
||||
done := make(chan struct{})
|
||||
for i := 0; i < 10; i++ {
|
||||
go func(success bool) {
|
||||
@@ -113,3 +113,45 @@ func TestQualityTracker_InsufficientDataFallsBackToHeuristic(t *testing.T) {
|
||||
}
|
||||
decision.Rollback()
|
||||
}
|
||||
|
||||
func TestQualityTracker_CustomAlphaShortensMemory(t *testing.T) {
|
||||
// alpha=0.9 weights the latest sample heavily; after a single
|
||||
// failure the score should drop further than with the default 0.3.
|
||||
fast := router.NewQualityTracker(0.9, 0)
|
||||
slow := router.NewQualityTracker(0.0, 0) // 0 → default 0.3
|
||||
|
||||
for _, qt := range []*router.QualityTracker{fast, slow} {
|
||||
// Build up history at the high end with 5 successes.
|
||||
for i := 0; i < 5; i++ {
|
||||
qt.Record("arm:m", router.TaskGeneration, true)
|
||||
}
|
||||
// One failure.
|
||||
qt.Record("arm:m", router.TaskGeneration, false)
|
||||
}
|
||||
|
||||
fastScore, _ := fast.Quality("arm:m", router.TaskGeneration)
|
||||
slowScore, _ := slow.Quality("arm:m", router.TaskGeneration)
|
||||
|
||||
if !(fastScore < slowScore) {
|
||||
t.Errorf("expected fast alpha (0.9) to drop quality faster than default (0.3): fast=%f slow=%f", fastScore, slowScore)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQualityTracker_CustomMinObservationsGatesScore(t *testing.T) {
|
||||
// minObs=10 means Quality should return hasData=false until 10
|
||||
// observations are recorded, even though the default would say
|
||||
// "yes" after 3.
|
||||
qt := router.NewQualityTracker(0, 10)
|
||||
for i := 0; i < 5; i++ {
|
||||
qt.Record("arm:m", router.TaskGeneration, true)
|
||||
}
|
||||
if _, hasData := qt.Quality("arm:m", router.TaskGeneration); hasData {
|
||||
t.Error("expected hasData=false at 5 observations with minObs=10")
|
||||
}
|
||||
for i := 0; i < 5; i++ {
|
||||
qt.Record("arm:m", router.TaskGeneration, true)
|
||||
}
|
||||
if _, hasData := qt.Quality("arm:m", router.TaskGeneration); !hasData {
|
||||
t.Error("expected hasData=true after 10 observations with minObs=10")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
)
|
||||
|
||||
func TestQualityTracker_SnapshotRestore_RoundTrip(t *testing.T) {
|
||||
qt := router.NewQualityTracker()
|
||||
qt := router.NewQualityTracker(0, 0)
|
||||
// Record some outcomes
|
||||
qt.Record("anthropic/claude-3-5-sonnet", router.TaskGeneration, true)
|
||||
qt.Record("anthropic/claude-3-5-sonnet", router.TaskGeneration, true)
|
||||
@@ -33,7 +33,7 @@ func TestQualityTracker_SnapshotRestore_RoundTrip(t *testing.T) {
|
||||
}
|
||||
|
||||
// Restore into a fresh tracker
|
||||
qt2 := router.NewQualityTracker()
|
||||
qt2 := router.NewQualityTracker(0, 0)
|
||||
qt2.Restore(restored)
|
||||
|
||||
// After restore, Quality() should return data (Count >= minObservations=3)
|
||||
@@ -47,7 +47,7 @@ func TestQualityTracker_SnapshotRestore_RoundTrip(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestQualityTracker_Snapshot_Empty(t *testing.T) {
|
||||
qt := router.NewQualityTracker()
|
||||
qt := router.NewQualityTracker(0, 0)
|
||||
snap := qt.Snapshot()
|
||||
if snap.Scores == nil {
|
||||
t.Error("scores map should be initialized (not nil)")
|
||||
@@ -58,7 +58,7 @@ func TestQualityTracker_Snapshot_Empty(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestQualityTracker_ClassifierCounts_RecordAndSnapshot(t *testing.T) {
|
||||
qt := router.NewQualityTracker()
|
||||
qt := router.NewQualityTracker(0, 0)
|
||||
qt.RecordClassifier(router.ClassifierHeuristic)
|
||||
qt.RecordClassifier(router.ClassifierSLM)
|
||||
qt.RecordClassifier(router.ClassifierSLM)
|
||||
@@ -92,7 +92,7 @@ func TestQualityTracker_ClassifierCounts_RecordAndSnapshot(t *testing.T) {
|
||||
if err := json.Unmarshal(data, &restored); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
qt2 := router.NewQualityTracker()
|
||||
qt2 := router.NewQualityTracker(0, 0)
|
||||
qt2.Restore(restored)
|
||||
if qt2.ClassifierCounts()[router.ClassifierSLM] != 2 {
|
||||
t.Errorf("restored slm count = %d, want 2", qt2.ClassifierCounts()[router.ClassifierSLM])
|
||||
@@ -107,7 +107,7 @@ func TestQualityTracker_Restore_BackCompat_NoClassifierCounts(t *testing.T) {
|
||||
if err := json.Unmarshal(legacy, &snap); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
qt := router.NewQualityTracker()
|
||||
qt := router.NewQualityTracker(0, 0)
|
||||
qt.Restore(snap)
|
||||
if qt.ClassifierCounts() == nil {
|
||||
t.Error("ClassifierCounts() must return a non-nil map after restoring old snapshot")
|
||||
@@ -122,7 +122,7 @@ func TestQualityTracker_Restore_BackCompat_NoClassifierCounts(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestQualityTracker_Restore_Replaces(t *testing.T) {
|
||||
qt := router.NewQualityTracker()
|
||||
qt := router.NewQualityTracker(0, 0)
|
||||
qt.Record("arm-a", router.TaskDebug, true)
|
||||
qt.Record("arm-a", router.TaskDebug, true)
|
||||
qt.Record("arm-a", router.TaskDebug, true)
|
||||
|
||||
@@ -27,6 +27,7 @@ type Router struct {
|
||||
preferPolicy PreferPolicy
|
||||
|
||||
quality *QualityTracker
|
||||
bandit BanditParams
|
||||
}
|
||||
|
||||
// PreferPolicy biases the scoring step toward local or cloud arms.
|
||||
@@ -77,6 +78,41 @@ func (p PreferPolicy) String() string {
|
||||
|
||||
type Config struct {
|
||||
Logger *slog.Logger
|
||||
// Bandit tunes the selector's scoring knobs. Pass a zero value to
|
||||
// keep all pre-config behaviour byte-identical; set individual
|
||||
// fields to override the corresponding default.
|
||||
Bandit BanditParams
|
||||
}
|
||||
|
||||
// BanditParams controls the EMA quality tracker and score blend used
|
||||
// by the selector. Each field has a "use default" sentinel (0 for
|
||||
// floats and ints) so a zero-valued BanditParams is byte-identical to
|
||||
// the pre-config hardcoded constants. Defaults are applied by
|
||||
// resolveBanditParams below.
|
||||
type BanditParams struct {
|
||||
QualityAlpha float64
|
||||
MinObservations int
|
||||
ObservedWeight float64
|
||||
StrengthBonus float64
|
||||
}
|
||||
|
||||
// resolveBanditParams fills in the built-in defaults for any field
|
||||
// left at its zero value. Centralised so the same defaults apply
|
||||
// across NewQualityTracker, scoreArm, and any future caller.
|
||||
func resolveBanditParams(p BanditParams) BanditParams {
|
||||
if p.QualityAlpha == 0 {
|
||||
p.QualityAlpha = defaultQualityAlpha
|
||||
}
|
||||
if p.MinObservations == 0 {
|
||||
p.MinObservations = defaultMinObservations
|
||||
}
|
||||
if p.ObservedWeight == 0 {
|
||||
p.ObservedWeight = defaultObservedWeight
|
||||
}
|
||||
if p.StrengthBonus == 0 {
|
||||
p.StrengthBonus = defaultStrengthBonus
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
func New(cfg Config) *Router {
|
||||
@@ -84,10 +120,12 @@ func New(cfg Config) *Router {
|
||||
if logger == nil {
|
||||
logger = slog.Default()
|
||||
}
|
||||
params := resolveBanditParams(cfg.Bandit)
|
||||
return &Router{
|
||||
arms: make(map[ArmID]*Arm),
|
||||
logger: logger,
|
||||
quality: NewQualityTracker(),
|
||||
quality: NewQualityTracker(params.QualityAlpha, params.MinObservations),
|
||||
bandit: params,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -172,7 +210,7 @@ func (r *Router) Select(task Task) RoutingDecision {
|
||||
}
|
||||
|
||||
// Select best
|
||||
best := selectBest(r.quality, feasible, task, r.preferPolicy)
|
||||
best := selectBest(r.quality, r.bandit, feasible, task, r.preferPolicy)
|
||||
if best == nil {
|
||||
return RoutingDecision{Error: fmt.Errorf("selection failed")}
|
||||
}
|
||||
|
||||
@@ -262,7 +262,7 @@ func TestSelectBest_PrefersToolSupport(t *testing.T) {
|
||||
}
|
||||
|
||||
task := Task{Type: TaskGeneration, RequiresTools: true, Priority: PriorityNormal}
|
||||
best := selectBest(nil, []*Arm{withoutTools, withTools}, task, PreferAuto)
|
||||
best := selectBest(nil, BanditParams{}, []*Arm{withoutTools, withTools}, task, PreferAuto)
|
||||
|
||||
if best.ID != "a/with-tools" {
|
||||
t.Errorf("should prefer arm with tool support, got %s", best.ID)
|
||||
@@ -282,7 +282,7 @@ func TestSelectBest_PrefersThinkingForPlanning(t *testing.T) {
|
||||
}
|
||||
|
||||
task := Task{Type: TaskPlanning, RequiresTools: true, Priority: PriorityNormal, EstimatedTokens: 5000}
|
||||
best := selectBest(nil, []*Arm{noThinking, thinking}, task, PreferAuto)
|
||||
best := selectBest(nil, BanditParams{}, []*Arm{noThinking, thinking}, task, PreferAuto)
|
||||
|
||||
if best.ID != "a/thinking" {
|
||||
t.Errorf("should prefer thinking model for planning, got %s", best.ID)
|
||||
@@ -625,7 +625,7 @@ func TestSelectBest_SmallArmWinsTrivialTask(t *testing.T) {
|
||||
Capabilities: provider.Capabilities{ToolUse: false},
|
||||
}
|
||||
task := Task{Type: TaskExplain, ComplexityScore: 0.05, RequiresTools: false}
|
||||
got := selectBest(nil, []*Arm{cliArm, smallArm}, task, PreferAuto)
|
||||
got := selectBest(nil, BanditParams{}, []*Arm{cliArm, smallArm}, task, PreferAuto)
|
||||
if got != smallArm {
|
||||
t.Errorf("selectBest = %v, want smallArm", got)
|
||||
}
|
||||
@@ -647,7 +647,7 @@ func TestSelectBest_CLIAgentWinsComplexTask(t *testing.T) {
|
||||
Capabilities: provider.Capabilities{ToolUse: false},
|
||||
}
|
||||
task := Task{Type: TaskRefactor, ComplexityScore: 0.7, RequiresTools: true}
|
||||
got := selectBest(nil, []*Arm{cliArm, smallArm}, task, PreferAuto)
|
||||
got := selectBest(nil, BanditParams{}, []*Arm{cliArm, smallArm}, task, PreferAuto)
|
||||
if got != cliArm {
|
||||
t.Errorf("selectBest = %v, want cliArm", got)
|
||||
}
|
||||
@@ -672,21 +672,21 @@ func TestSelectBest_TierPreference(t *testing.T) {
|
||||
task := Task{Type: TaskGeneration, Priority: PriorityNormal, EstimatedTokens: 1000}
|
||||
|
||||
t.Run("CLI beats local and API", func(t *testing.T) {
|
||||
best := selectBest(nil, []*Arm{apiArm, localArm, cliArm}, task, PreferAuto)
|
||||
best := selectBest(nil, BanditParams{}, []*Arm{apiArm, localArm, cliArm}, task, PreferAuto)
|
||||
if best.ID != "subprocess/claude" {
|
||||
t.Errorf("want subprocess/claude (tier 0), got %s", best.ID)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("local beats API when no CLI", func(t *testing.T) {
|
||||
best := selectBest(nil, []*Arm{apiArm, localArm}, task, PreferAuto)
|
||||
best := selectBest(nil, BanditParams{}, []*Arm{apiArm, localArm}, task, PreferAuto)
|
||||
if best.ID != "ollama/llama3" {
|
||||
t.Errorf("want ollama/llama3 (tier 1), got %s", best.ID)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("API selected when only option", func(t *testing.T) {
|
||||
best := selectBest(nil, []*Arm{apiArm}, task, PreferAuto)
|
||||
best := selectBest(nil, BanditParams{}, []*Arm{apiArm}, task, PreferAuto)
|
||||
if best == nil || best.ID != "mistral/mistral-large" {
|
||||
t.Errorf("want mistral/mistral-large (tier 2), got %v", best)
|
||||
}
|
||||
|
||||
+13
-13
@@ -98,7 +98,7 @@ func armBaseTier(arm *Arm, task Task) int {
|
||||
//
|
||||
// Step 2 (fallback): walk tiers low→high. Within a tier, highest-scoring
|
||||
// arm wins.
|
||||
func selectBest(qt *QualityTracker, arms []*Arm, task Task, prefer PreferPolicy) *Arm {
|
||||
func selectBest(qt *QualityTracker, params BanditParams, arms []*Arm, task Task, prefer PreferPolicy) *Arm {
|
||||
if len(arms) == 0 {
|
||||
return nil
|
||||
}
|
||||
@@ -110,7 +110,7 @@ func selectBest(qt *QualityTracker, arms []*Arm, task Task, prefer PreferPolicy)
|
||||
}
|
||||
}
|
||||
if len(promoted) > 0 {
|
||||
return bestScored(qt, promoted, task, prefer)
|
||||
return bestScored(qt, params, promoted, task, prefer)
|
||||
}
|
||||
|
||||
// Walk tiers low→high. armTier returns up to 5 when prefer is set
|
||||
@@ -124,18 +124,18 @@ func selectBest(qt *QualityTracker, arms []*Arm, task Task, prefer PreferPolicy)
|
||||
}
|
||||
}
|
||||
if len(inTier) > 0 {
|
||||
return bestScored(qt, inTier, task, prefer)
|
||||
return bestScored(qt, params, inTier, task, prefer)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// bestScored returns the highest-scoring arm within a set.
|
||||
func bestScored(qt *QualityTracker, arms []*Arm, task Task, prefer PreferPolicy) *Arm {
|
||||
func bestScored(qt *QualityTracker, params BanditParams, arms []*Arm, task Task, prefer PreferPolicy) *Arm {
|
||||
var best *Arm
|
||||
bestScore := math.Inf(-1)
|
||||
for _, arm := range arms {
|
||||
score := scoreArm(qt, arm, task) * policyMultiplier(arm, prefer)
|
||||
score := scoreArm(qt, params, arm, task) * policyMultiplier(arm, prefer)
|
||||
if score > bestScore {
|
||||
bestScore = score
|
||||
best = arm
|
||||
@@ -172,13 +172,12 @@ func policyMultiplier(arm *Arm, p PreferPolicy) float64 {
|
||||
}
|
||||
}
|
||||
|
||||
// strengthScoreBonus is added to quality when an arm's Strengths list
|
||||
// matches the incoming task type. Tunable in one place.
|
||||
const strengthScoreBonus = 0.15
|
||||
|
||||
// scoreArm computes a quality/cost score for an arm.
|
||||
// When the quality tracker has sufficient observations, blends observed EMA
|
||||
// (70%) with heuristic (30%). Falls back to pure heuristic otherwise.
|
||||
// (default 70%) with heuristic (default 30%). Falls back to pure heuristic
|
||||
// otherwise. The blend ratio and strength bonus are tunable via
|
||||
// BanditParams (config: [router.bandit]); a zero-valued params falls back
|
||||
// to the built-in defaults.
|
||||
//
|
||||
// Strengths add a fixed bonus to quality when matching task.Type. CostWeight
|
||||
// dampens the cost penalty linearly:
|
||||
@@ -189,16 +188,17 @@ const strengthScoreBonus = 0.15
|
||||
// the original effectiveCost == cost. With CostWeight=0 cost is fully
|
||||
// ignored (effectiveCost = 1.0). Local arms with sub-1 raw costs are not
|
||||
// amplified by fractional weights (the linear formula stays monotone).
|
||||
func scoreArm(qt *QualityTracker, arm *Arm, task Task) float64 {
|
||||
func scoreArm(qt *QualityTracker, params BanditParams, arm *Arm, task Task) float64 {
|
||||
params = resolveBanditParams(params)
|
||||
hq := heuristicQuality(arm, task)
|
||||
quality := hq
|
||||
if qt != nil {
|
||||
if observed, hasData := qt.Quality(arm.ID, task.Type); hasData {
|
||||
quality = 0.7*observed + 0.3*hq
|
||||
quality = params.ObservedWeight*observed + (1-params.ObservedWeight)*hq
|
||||
}
|
||||
}
|
||||
if arm.HasStrength(task.Type) {
|
||||
quality += strengthScoreBonus
|
||||
quality += params.StrengthBonus
|
||||
}
|
||||
value := task.ValueScore()
|
||||
rawCost := effectiveCost(arm, task)
|
||||
|
||||
@@ -65,17 +65,17 @@ func TestScoreArm_CostWeightAffectsArmComparison(t *testing.T) {
|
||||
|
||||
// CostWeight=1.0: cost dominates, cheap arm wins.
|
||||
cheap.CostWeight, expensive.CostWeight = 1.0, 1.0
|
||||
if scoreArm(nil, cheap, task) <= scoreArm(nil, expensive, task) {
|
||||
if scoreArm(nil, BanditParams{}, cheap, task) <= scoreArm(nil, BanditParams{}, expensive, task) {
|
||||
t.Errorf("CostWeight=1.0: cheap arm should beat expensive arm; cheap=%v expensive=%v",
|
||||
scoreArm(nil, cheap, task), scoreArm(nil, expensive, task))
|
||||
scoreArm(nil, BanditParams{}, cheap, task), scoreArm(nil, BanditParams{}, expensive, task))
|
||||
}
|
||||
|
||||
// CostWeight~0 (0.001): cost nearly ignored, quality alone decides → expensive (better
|
||||
// context window) wins.
|
||||
cheap.CostWeight, expensive.CostWeight = 0.001, 0.001
|
||||
if scoreArm(nil, expensive, task) <= scoreArm(nil, cheap, task) {
|
||||
if scoreArm(nil, BanditParams{}, expensive, task) <= scoreArm(nil, BanditParams{}, cheap, task) {
|
||||
t.Errorf("CostWeight~0: higher-quality expensive arm should beat cheap arm; expensive=%v cheap=%v",
|
||||
scoreArm(nil, expensive, task), scoreArm(nil, cheap, task))
|
||||
scoreArm(nil, BanditParams{}, expensive, task), scoreArm(nil, BanditParams{}, cheap, task))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -140,8 +140,8 @@ func TestScoreArm_StrengthBonus(t *testing.T) {
|
||||
}
|
||||
task := Task{Type: TaskSecurityReview, EstimatedTokens: 5000, RequiresTools: true, Priority: PriorityNormal}
|
||||
|
||||
a := scoreArm(nil, withoutStrength, task)
|
||||
b := scoreArm(nil, withStrength, task)
|
||||
a := scoreArm(nil, BanditParams{}, withoutStrength, task)
|
||||
b := scoreArm(nil, BanditParams{}, withStrength, task)
|
||||
if !(b > a) {
|
||||
t.Errorf("strength-tagged arm score (%v) should exceed plain arm score (%v)", b, a)
|
||||
}
|
||||
@@ -160,8 +160,8 @@ func TestScoreArm_StrengthBonusDoesNotApplyToOtherTasks(t *testing.T) {
|
||||
}
|
||||
task := Task{Type: TaskDebug, EstimatedTokens: 5000, RequiresTools: true, Priority: PriorityNormal}
|
||||
|
||||
a := scoreArm(nil, plain, task)
|
||||
b := scoreArm(nil, tagged, task)
|
||||
a := scoreArm(nil, BanditParams{}, plain, task)
|
||||
b := scoreArm(nil, BanditParams{}, tagged, task)
|
||||
if math.Abs(a-b) > 1e-9 {
|
||||
t.Errorf("non-matching task should ignore Strengths: plain=%v tagged=%v", a, b)
|
||||
}
|
||||
@@ -184,7 +184,7 @@ func TestSelectBest_StrengthPromotedArmBeatsCLIAgent(t *testing.T) {
|
||||
}
|
||||
|
||||
task := Task{Type: TaskSecurityReview, EstimatedTokens: 5000, RequiresTools: true, Priority: PriorityNormal}
|
||||
got := selectBest(nil, []*Arm{cliAgent, opus}, task, PreferAuto)
|
||||
got := selectBest(nil, BanditParams{}, []*Arm{cliAgent, opus}, task, PreferAuto)
|
||||
if got == nil {
|
||||
t.Fatal("selectBest returned nil")
|
||||
}
|
||||
@@ -208,7 +208,7 @@ func TestSelectBest_EmptyStrengthsPreservesTierOrder(t *testing.T) {
|
||||
}
|
||||
|
||||
task := Task{Type: TaskSecurityReview, EstimatedTokens: 5000, RequiresTools: true, Priority: PriorityNormal}
|
||||
got := selectBest(nil, []*Arm{cliAgent, opus}, task, PreferAuto)
|
||||
got := selectBest(nil, BanditParams{}, []*Arm{cliAgent, opus}, task, PreferAuto)
|
||||
if got.ID != cliAgent.ID {
|
||||
t.Errorf("without Strengths, CLI-agent tier-1 should win; got %s", got.ID)
|
||||
}
|
||||
@@ -327,7 +327,7 @@ func TestSelectBest_MultiplePromotedArmsBestQualityWins(t *testing.T) {
|
||||
Strengths: []TaskType{TaskSecurityReview},
|
||||
}
|
||||
|
||||
qt := NewQualityTracker()
|
||||
qt := NewQualityTracker(0, 0)
|
||||
// armB has consistently succeeded — minObservations=3 is enough to flip
|
||||
// the score blend.
|
||||
for i := 0; i < 5; i++ {
|
||||
@@ -339,7 +339,7 @@ func TestSelectBest_MultiplePromotedArmsBestQualityWins(t *testing.T) {
|
||||
}
|
||||
|
||||
task := Task{Type: TaskSecurityReview, EstimatedTokens: 5000, RequiresTools: true, Priority: PriorityNormal}
|
||||
got := selectBest(qt, []*Arm{armA, armB}, task, PreferAuto)
|
||||
got := selectBest(qt, BanditParams{}, []*Arm{armA, armB}, task, PreferAuto)
|
||||
if got == nil {
|
||||
t.Fatal("selectBest returned nil")
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user