Files
vikingowl eea26a262e feat(router): surface bandit knobs as [router.bandit] config
Four hardcoded constants in the selector and feedback tracker are now
user-tunable via [router.bandit]:

- quality_alpha    (EMA smoothing, default 0.3)
- min_observations (samples before observed overrides heuristic, default 3)
- observed_weight  (observed/heuristic blend ratio, default 0.7)
- strength_bonus   (quality bonus for Strengths-tagged arms, default 0.15)

Each field treats 0 as 'use default', so an empty TOML block behaves
exactly like the pre-config code. BanditParams is plumbed via
router.Config{Bandit: ...} and resolveBanditParams() centralises the
fallback so every call site shares the same defaults.

QualityTracker, scoreArm, bestScored, and selectBest signatures now
take the configured values directly rather than reaching for package-
level constants. Tests updated to pass BanditParams{} (defaults) or
explicit overrides where they validate the new tuning paths.

Tracks item #3 from the 'Bandit selector — design decisions deferred'
TODO entry — ships independently of the EMA vs SLM strategic decision.
2026-05-24 22:42:34 +02:00

114 lines
3.3 KiB
Go

package router
import "sync"
// Fallback values for the tunable bandit knobs. Each can be overridden
// through a [router.bandit] config key (see BanditParams in router.go).
// They are declared here so a QualityTracker built without explicit
// parameters (tests, ad-hoc callers) still gets sensible settings.
const (
	// defaultQualityAlpha is the EMA smoothing factor: every new outcome
	// contributes 30% of the updated score (roughly a 3-sample memory).
	defaultQualityAlpha = 0.3

	// defaultMinObservations is how many outcomes an arm/task pair needs
	// before its observed score is allowed to override the heuristic.
	defaultMinObservations = 3

	// defaultObservedWeight is the share given to the observed score when
	// it is blended with the heuristic score.
	defaultObservedWeight = 0.7

	// defaultStrengthBonus is the quality bonus granted to arms tagged
	// with Strengths (applied by the selector, not in this file).
	defaultStrengthBonus = 0.15
)
// EMAScore holds an exponential moving average of outcome quality
// together with the number of samples that have been folded into it.
type EMAScore struct {
	// Value is the current moving average.
	Value float64
	// Count is the number of observations recorded so far.
	Count int
}
// QualityTracker keeps, for every arm and task type, an EMA of elf outcome
// quality, plus per-classifier-source counters that feed Phase 4 routing
// decisions. Its methods lock mu around every map access, so one tracker
// can be shared between goroutines. Construct it with NewQualityTracker:
// the zero value has nil maps, and Record would panic writing to them.
type QualityTracker struct {
	mu sync.RWMutex

	// scores maps arm -> task type -> running EMA for that pair.
	scores map[ArmID]map[TaskType]*EMAScore
	// classifierCount tallies how often each classifier source was seen.
	classifierCount map[ClassifierSource]int

	// Tuning knobs fixed by NewQualityTracker; passing 0 there selects
	// the built-in default.
	alpha           float64 // EMA smoothing factor
	minObservations int     // samples required before Quality reports data
}
// NewQualityTracker returns an empty QualityTracker.
//
// alpha is the EMA smoothing factor applied by Record and is meaningful
// only in (0, 1]. minObs is the number of samples an arm/task pair needs
// before Quality reports a score. Passing 0 for either argument selects
// the built-in default (alpha=0.3, minObs=3).
//
// Values that cannot yield a sensible average also fall back to the
// default rather than being stored:
//   - alpha <= 0, alpha > 1, or NaN: with alpha > 1 the update
//     a*obs + (1-a)*v oscillates and diverges outside [0, 1], and NaN
//     poisons every score.
//   - minObs < 0: this would silently disable the minimum-sample gate.
func NewQualityTracker(alpha float64, minObs int) *QualityTracker {
	// Written as a negated range check so NaN (for which every comparison
	// is false) is rejected too; this also covers the documented 0 case.
	if !(alpha > 0 && alpha <= 1) {
		alpha = defaultQualityAlpha
	}
	if minObs <= 0 {
		minObs = defaultMinObservations
	}
	return &QualityTracker{
		scores:          make(map[ArmID]map[TaskType]*EMAScore),
		classifierCount: make(map[ClassifierSource]int),
		alpha:           alpha,
		minObservations: minObs,
	}
}
// RecordClassifier counts one classification attributed to src. The
// counters answer "how often did the SLM actually classify vs fall back?",
// a Phase 4 trust signal. ClassifierUnknown is skipped on purpose: it marks
// pre-classification or forced routing and would only skew the tallies.
func (qt *QualityTracker) RecordClassifier(src ClassifierSource) {
	switch src {
	case ClassifierUnknown:
		// Nothing to count.
	default:
		qt.mu.Lock()
		defer qt.mu.Unlock()
		qt.classifierCount[src]++
	}
}
// ClassifierCounts returns a snapshot of the per-source classification
// counts. The result is a freshly allocated map, so the caller may modify
// it without affecting the tracker.
func (qt *QualityTracker) ClassifierCounts() map[ClassifierSource]int {
	qt.mu.RLock()
	defer qt.mu.RUnlock()

	snapshot := make(map[ClassifierSource]int, len(qt.classifierCount))
	for source, count := range qt.classifierCount {
		snapshot[source] = count
	}
	return snapshot
}
// Record folds one outcome (success counts as 1, failure as 0) into the
// EMA for the given arm and task type, creating the entry on first use.
// The first sample seeds the average directly; later samples are blended
// in using the tracker's alpha.
func (qt *QualityTracker) Record(armID ArmID, taskType TaskType, success bool) {
	var obs float64
	if success {
		obs = 1
	}

	qt.mu.Lock()
	defer qt.mu.Unlock()

	byTask := qt.scores[armID]
	if byTask == nil {
		byTask = make(map[TaskType]*EMAScore)
		qt.scores[armID] = byTask
	}
	ema := byTask[taskType]
	if ema == nil {
		ema = new(EMAScore)
		byTask[taskType] = ema
	}

	switch ema.Count {
	case 0:
		// No history yet: start from this outcome instead of blending it
		// against the zero value.
		ema.Value = obs
	default:
		ema.Value = qt.alpha*obs + (1-qt.alpha)*ema.Value
	}
	ema.Count++
}
// Quality reports the EMA score recorded for an arm/task pair. It returns
// (0, false) until at least minObservations outcomes have been recorded
// for that pair.
func (qt *QualityTracker) Quality(armID ArmID, taskType TaskType) (score float64, hasData bool) {
	qt.mu.RLock()
	defer qt.mu.RUnlock()

	// Indexing a nil inner map is safe and just reports "not found", so an
	// unknown arm falls through to the same early return as an unknown task.
	byTask := qt.scores[armID]
	ema, found := byTask[taskType]
	if !found || ema.Count < qt.minObservations {
		return 0, false
	}
	return ema.Value, true
}