feat: add router foundation with task classification and arm selection

internal/router/ — core routing layer:
- Task classification: 10 types (boilerplate, generation, refactor,
  review, unit_test, planning, orchestration, security_review, debug,
  explain) with keyword heuristics and complexity scoring
- Arm registry: provider+model pairs with capabilities and cost
- Limit pools: shared resource budgets with scarcity multipliers,
  optimistic reservation, use-it-or-lose-it discounting
- Heuristic selector: score = (quality × value) / effective_cost
  Prefers tool-capable arms, favors thinking models for planning, and
  penalizes local models on complex tasks
- Router: Select() picks best feasible arm, ForceArm() for CLI override

Engine now routes through router.Select() when configured.
Wired into CLI — arm registered per --provider/--model flags.

20 router tests. 173 tests total across 13 packages.
This commit is contained in:
2026-04-03 14:23:15 +02:00
parent 33dec722b8
commit b9faa30ea8
9 changed files with 1114 additions and 10 deletions

View File

@@ -12,6 +12,7 @@ import (
"somegit.dev/Owlibou/gnoma/internal/engine"
"somegit.dev/Owlibou/gnoma/internal/provider"
"somegit.dev/Owlibou/gnoma/internal/router"
"somegit.dev/Owlibou/gnoma/internal/security"
anthropicprov "somegit.dev/Owlibou/gnoma/internal/provider/anthropic"
"somegit.dev/Owlibou/gnoma/internal/provider/mistral"
@@ -81,6 +82,23 @@ func main() {
// Re-register bash tool with aliases
reg.Register(bash.New(bash.WithAliases(aliases)))
// Create router and register the provider as a single arm
// (M4 foundation: one provider from CLI. Multi-provider routing comes with config.)
rtr := router.New(router.Config{Logger: logger})
armModel := *model
if armModel == "" {
armModel = prov.DefaultModel()
}
armID := router.NewArmID(*providerName, armModel)
rtr.RegisterArm(&router.Arm{
ID: armID,
Provider: prov,
ModelName: armModel,
IsLocal: localProviders[*providerName],
Capabilities: provider.Capabilities{ToolUse: true}, // trust CLI provider
})
rtr.ForceArm(armID)
// Create firewall
fw := security.NewFirewall(security.FirewallConfig{
ScanOutgoing: true,
@@ -92,6 +110,7 @@ func main() {
// Create engine
eng, err := engine.New(engine.Config{
Provider: prov,
Router: rtr,
Tools: reg,
Firewall: fw,
System: *system,

View File

@@ -7,13 +7,15 @@ import (
"somegit.dev/Owlibou/gnoma/internal/message"
"somegit.dev/Owlibou/gnoma/internal/provider"
"somegit.dev/Owlibou/gnoma/internal/router"
"somegit.dev/Owlibou/gnoma/internal/security"
"somegit.dev/Owlibou/gnoma/internal/tool"
)
// Config holds engine configuration.
type Config struct {
Provider provider.Provider
Provider provider.Provider // direct provider (used if Router is nil)
Router *router.Router // nil = use Provider directly
Tools *tool.Registry
Firewall *security.Firewall // nil = no scanning
System string // system prompt

View File

@@ -7,6 +7,7 @@ import (
"somegit.dev/Owlibou/gnoma/internal/message"
"somegit.dev/Owlibou/gnoma/internal/provider"
"somegit.dev/Owlibou/gnoma/internal/router"
"somegit.dev/Owlibou/gnoma/internal/stream"
)
@@ -38,16 +39,50 @@ func (e *Engine) runLoop(ctx context.Context, cb Callback) (*Turn, error) {
// Build provider request (gates tools on model capabilities)
req := e.buildRequest(ctx)
e.logger.Debug("streaming request",
"provider", e.cfg.Provider.Name(),
"model", req.Model,
"messages", len(req.Messages),
"tools", len(req.Tools),
"round", turn.Rounds,
)
// Route and stream
var s stream.Stream
var err error
// Stream from provider
s, err := e.cfg.Provider.Stream(ctx, req)
if e.cfg.Router != nil {
// Classify task from the latest user message
prompt := ""
for i := len(e.history) - 1; i >= 0; i-- {
if e.history[i].Role == message.RoleUser {
prompt = e.history[i].TextContent()
break
}
}
task := router.ClassifyTask(prompt)
task.EstimatedTokens = 4000 // rough default
e.logger.Debug("routing request",
"task_type", task.Type,
"complexity", task.ComplexityScore,
"round", turn.Rounds,
)
var arm *router.Arm
s, arm, err = e.cfg.Router.Stream(ctx, task, req)
if arm != nil {
e.logger.Debug("streaming request",
"provider", arm.Provider.Name(),
"model", arm.ModelName,
"arm", arm.ID,
"messages", len(req.Messages),
"tools", len(req.Tools),
"round", turn.Rounds,
)
}
} else {
e.logger.Debug("streaming request",
"provider", e.cfg.Provider.Name(),
"model", req.Model,
"messages", len(req.Messages),
"tools", len(req.Tools),
"round", turn.Rounds,
)
s, err = e.cfg.Provider.Stream(ctx, req)
}
if err != nil {
return nil, fmt.Errorf("provider stream: %w", err)
}

47
internal/router/arm.go Normal file
View File

@@ -0,0 +1,47 @@
package router
import (
"somegit.dev/Owlibou/gnoma/internal/provider"
)
// ArmID uniquely identifies a model+provider pair.
// NewArmID builds IDs of the form "<provider>/<model>".
type ArmID string
// Arm represents a provider+model pair available for routing.
type Arm struct {
// ID is the key the router stores and looks this arm up by.
ID ArmID
// Provider is the backend whose Stream method serves requests for this arm.
Provider provider.Provider
// ModelName is copied into the request's Model field before streaming.
ModelName string
// IsLocal marks a locally hosted model; the heuristic scorer reads it.
IsLocal bool
// Capabilities describes model features (tool use, thinking, context
// window) consulted by the feasibility filter and the scorer.
Capabilities provider.Capabilities
// Pools lists the shared limit pools this arm draws budget from.
Pools []*LimitPool
// Cost per 1k tokens (EUR, estimated)
CostPer1kInput float64
CostPer1kOutput float64
}
// NewArmID builds the arm identifier for a provider/model pair by joining
// the provider name and the model name with a single slash.
func NewArmID(providerName, model string) ArmID {
	id := providerName + "/" + model
	return ArmID(id)
}
// EstimateCost returns the estimated cost in EUR of spending estimatedTokens
// on this arm. The budget is split 60% input / 40% output (a rough
// assumption) and each share is priced at the arm's per-1k-token rate.
func (a *Arm) EstimateCost(estimatedTokens int) float64 {
	total := float64(estimatedTokens)
	inputCost := (total * 0.6 / 1000) * a.CostPer1kInput
	outputCost := (total * 0.4 / 1000) * a.CostPer1kOutput
	return inputCost + outputCost
}
// SupportsTools returns true if this arm's model supports function calling.
// It simply reports the ToolUse flag of the arm's Capabilities.
func (a *Arm) SupportsTools() bool {
return a.Capabilities.ToolUse
}
// ArmPerf holds live performance metrics for an arm.
// NOTE(review): nothing in the visible router code populates or reads this
// type yet — presumably reserved for later latency-aware routing; confirm.
type ArmPerf struct {
TTFT_P50_ms float64 // time to first token, p50
TTFT_P95_ms float64 // time to first token, p95
ToksPerSec float64 // tokens per second throughput
}

170
internal/router/pool.go Normal file
View File

@@ -0,0 +1,170 @@
package router
import (
"math"
"sync"
"time"
)
// PoolKind identifies the type of resource a pool tracks.
type PoolKind int
// Kinds of budget a LimitPool can represent.
const (
PoolRPM PoolKind = iota // requests per minute
PoolRPD // requests per day
PoolTPD // tokens per day
PoolCostEUR // monetary cost cap
PoolCustom // arbitrary units
)
// LimitPool tracks a shared resource budget that arms draw from.
// Its methods take mu before reading or writing the counters below.
type LimitPool struct {
mu sync.Mutex
ID string
Kind PoolKind
// TotalLimit is the budget for one window. When it is <= 0,
// RemainingFraction reports 0 and ScarcityMultiplier reports +Inf.
TotalLimit float64
// Used is the consumption committed so far in the current window.
Used float64
Reserved float64 // optimistically reserved for in-flight requests
// ResetPeriod is the length of one budget window; CheckReset uses it to
// move ResetAt forward.
ResetPeriod time.Duration
// ResetAt is when the current window ends. The zero time disables resets.
ResetAt time.Time
// Per-arm consumption rates (units per 1k tokens or per request)
// An arm with no entry, or a rate of 0, is treated as unlimited by
// CanAfford and Reserve.
ArmRates map[ArmID]float64
// Scarcity curve aggressiveness. k=2 gentle, k=4 aggressive hoarding.
// Values <= 0 fall back to 2.
ScarcityK float64
}
// RemainingFraction reports how much of the pool's budget is still free,
// as 1 - (Used+Reserved)/TotalLimit. A pool whose TotalLimit is not
// positive reports 0.
func (p *LimitPool) RemainingFraction() float64 {
	p.mu.Lock()
	defer p.mu.Unlock()

	if p.TotalLimit > 0 {
		consumed := p.Used + p.Reserved
		return 1.0 - consumed/p.TotalLimit
	}
	return 0
}
// ScarcityMultiplier returns the factor by which an arm's cost is scaled
// for this pool. It grows as the budget depletes, so arms drawing on a
// nearly empty pool look more expensive to the selector.
func (p *LimitPool) ScarcityMultiplier() float64 {
	p.mu.Lock()
	m := p.scarcityMultiplierLocked()
	p.mu.Unlock()
	return m
}

// scarcityMultiplierLocked computes the multiplier; the caller must hold p.mu.
//
// Results, in order of precedence:
//   - +Inf when the pool has no positive limit or nothing remains;
//   - 0.5 when a reset is less than an hour away and more than 30% of the
//     budget is still unused (use-it-or-lose-it discount);
//   - otherwise 1 / remaining^k, with k defaulting to 2 when ScarcityK <= 0.
func (p *LimitPool) scarcityMultiplierLocked() float64 {
	if p.TotalLimit <= 0 {
		return math.Inf(1)
	}
	remaining := 1.0 - (p.Used+p.Reserved)/p.TotalLimit
	if remaining <= 0 {
		return math.Inf(1) // exhausted
	}

	// Budget that would expire unused is cheap to spend.
	if !p.ResetAt.IsZero() && remaining > 0.3 {
		if h := time.Until(p.ResetAt).Hours(); h > 0 && h < 1.0 {
			return 0.5
		}
	}

	exponent := p.ScarcityK
	if exponent <= 0 {
		exponent = 2.0 // gentle default
	}
	return 1.0 / math.Pow(remaining, exponent)
}
// Exhausted reports whether the pool has no capacity left, i.e. whether
// RemainingFraction is zero or negative.
func (p *LimitPool) Exhausted() bool {
	left := p.RemainingFraction()
	return left <= 0
}
// CanAfford reports whether the pool's unreserved budget covers the
// consumption projected for armID spending estimatedTokens. An arm with no
// configured rate (or a rate of 0) is not limited by this pool, so the
// answer is always true for it.
func (p *LimitPool) CanAfford(armID ArmID, estimatedTokens int) bool {
	p.mu.Lock()
	defer p.mu.Unlock()

	perThousand, known := p.ArmRates[armID]
	if !known || perThousand == 0 {
		return true // no rate defined = no limit
	}
	need := perThousand * float64(estimatedTokens) / 1000.0
	headroom := p.TotalLimit - p.Used - p.Reserved
	return need <= headroom
}
// Reservation represents an optimistic resource reservation.
// It is created by LimitPool.Reserve and settled by Commit or Rollback.
type Reservation struct {
// pool is the pool the units were reserved from.
pool *LimitPool
// armID selects the consumption rate used when committing actual usage.
armID ArmID
// projected is the amount added to the pool's Reserved counter.
projected float64
// committed is set once the reservation is settled, making later
// Commit/Rollback calls no-ops.
committed bool
}
// Reserve optimistically sets aside the budget projected for armID spending
// estimatedTokens. On success it returns a Reservation to settle later with
// Commit (actual usage) or Rollback (failure). It returns (nil, false) when
// the pool cannot cover the projection. Arms without a configured rate are
// unlimited and receive a reservation that holds nothing.
func (p *LimitPool) Reserve(armID ArmID, estimatedTokens int) (*Reservation, bool) {
	p.mu.Lock()
	defer p.mu.Unlock()

	perThousand, known := p.ArmRates[armID]
	if !known || perThousand == 0 {
		return &Reservation{pool: p}, true // no limit
	}

	need := perThousand * float64(estimatedTokens) / 1000.0
	if headroom := p.TotalLimit - p.Used - p.Reserved; need > headroom {
		return nil, false
	}

	p.Reserved += need
	held := &Reservation{pool: p, armID: armID, projected: need}
	return held, true
}
// Commit finalizes the reservation with actual consumption: the projected
// amount is released from the pool's Reserved counter and the real usage
// (arm rate × actualTokens / 1000) is added to Used.
//
// Commit is a no-op on a nil reservation, on one that has already been
// settled, or on one without a pool.
func (r *Reservation) Commit(actualTokens int) {
	if r == nil || r.committed || r.pool == nil {
		return
	}
	r.committed = true

	r.pool.mu.Lock()
	defer r.pool.mu.Unlock()

	rate := r.pool.ArmRates[r.armID]
	r.pool.Used += rate * float64(actualTokens) / 1000.0

	// The pool may have been reset (Reserved zeroed) while this request was
	// in flight. Never let the release push Reserved below zero, which
	// would make the pool report more free budget than its limit.
	r.pool.Reserved -= r.projected
	if r.pool.Reserved < 0 {
		r.pool.Reserved = 0
	}
}
// Rollback releases the reservation without recording any consumption.
//
// Rollback is a no-op on a nil reservation, on one that has already been
// settled, on one without a pool, or on one that holds nothing.
func (r *Reservation) Rollback() {
	if r == nil || r.committed || r.pool == nil || r.projected == 0 {
		return
	}
	r.committed = true

	r.pool.mu.Lock()
	defer r.pool.mu.Unlock()

	// The pool may have been reset (Reserved zeroed) while this request was
	// in flight. Never let the release push Reserved below zero, which
	// would make the pool report more free budget than its limit.
	r.pool.Reserved -= r.projected
	if r.pool.Reserved < 0 {
		r.pool.Reserved = 0
	}
}
// CheckReset starts a fresh budget window if the current one has ended:
// Used and Reserved are zeroed and ResetAt is moved to the end of the window
// that contains the current time. Pools with a zero ResetAt never reset.
//
// ResetAt is advanced by as many whole periods as have elapsed, so a pool
// that sat idle across several windows is reset once rather than on every
// subsequent call. With a non-positive ResetPeriod there is no next window
// to schedule, so the reset is applied once and ResetAt is cleared.
func (p *LimitPool) CheckReset() {
	p.mu.Lock()
	defer p.mu.Unlock()

	if p.ResetAt.IsZero() {
		return
	}
	now := time.Now()
	if !now.After(p.ResetAt) {
		return
	}

	p.Used = 0
	p.Reserved = 0

	if p.ResetPeriod <= 0 {
		// NOTE(review): treating a missing period as a one-shot reset;
		// confirm this matches how such pools are configured.
		p.ResetAt = time.Time{}
		return
	}

	// Skip every window that has already fully elapsed.
	missed := now.Sub(p.ResetAt)/p.ResetPeriod + 1
	p.ResetAt = p.ResetAt.Add(missed * p.ResetPeriod)
}

160
internal/router/router.go Normal file
View File

@@ -0,0 +1,160 @@
package router
import (
"context"
"fmt"
"log/slog"
"sync"
"somegit.dev/Owlibou/gnoma/internal/provider"
"somegit.dev/Owlibou/gnoma/internal/stream"
)
// Router selects the best arm for a given task.
// M4: heuristic selection. M9: bandit learning.
type Router struct {
// mu guards arms and forcedArm.
mu sync.RWMutex
// arms holds every registered arm, keyed by its ID.
arms map[ArmID]*Arm
logger *slog.Logger
// Optional: force a specific arm (--provider flag override)
// The empty string means no arm is forced.
forcedArm ArmID
}
// Config holds the options accepted by New.
type Config struct {
// Logger receives the router's debug output; New substitutes
// slog.Default() when it is nil.
Logger *slog.Logger
}
// New returns an empty Router with no arms registered and no forced arm.
// When cfg.Logger is nil the router logs through slog.Default().
func New(cfg Config) *Router {
	r := &Router{
		arms:   make(map[ArmID]*Arm),
		logger: cfg.Logger,
	}
	if r.logger == nil {
		r.logger = slog.Default()
	}
	return r
}
// RegisterArm stores arm under its ID, replacing any arm previously
// registered with the same ID, and logs the registration at debug level.
func (r *Router) RegisterArm(arm *Arm) {
	r.mu.Lock()
	defer r.mu.Unlock()

	r.arms[arm.ID] = arm
	r.logger.Debug("arm registered",
		slog.Any("id", arm.ID),
		slog.Bool("local", arm.IsLocal),
		slog.Bool("tools", arm.SupportsTools()),
	)
}
// ForceArm overrides routing to always select a specific arm.
// Used for --provider CLI flag.
// The ID is not validated here: Select reports an error if no arm with this
// ID is registered at selection time. Passing "" removes the override.
func (r *Router) ForceArm(id ArmID) {
r.mu.Lock()
defer r.mu.Unlock()
r.forcedArm = id
}
// Select picks the arm that should serve task.
//
// A forced arm, when set, is returned without any feasibility check (or an
// error if it is not registered). Otherwise the registered arms are filtered
// down to the feasible ones and the highest-scoring arm is chosen. Failures
// are reported through the Error field of the returned decision.
func (r *Router) Select(task Task) RoutingDecision {
	r.mu.RLock()
	defer r.mu.RUnlock()

	// CLI override short-circuits all scoring.
	if forced := r.forcedArm; forced != "" {
		if arm, found := r.arms[forced]; found {
			return RoutingDecision{Strategy: StrategySingleArm, Arm: arm}
		}
		return RoutingDecision{Error: fmt.Errorf("forced arm %q not found", forced)}
	}

	if len(r.arms) == 0 {
		return RoutingDecision{Error: fmt.Errorf("no arms registered")}
	}
	candidates := make([]*Arm, 0, len(r.arms))
	for _, arm := range r.arms {
		candidates = append(candidates, arm)
	}

	usable := filterFeasible(candidates, task)
	if len(usable) == 0 {
		return RoutingDecision{Error: fmt.Errorf("no feasible arm for task type %s", task.Type)}
	}

	winner := selectBest(usable, task)
	if winner == nil {
		return RoutingDecision{Error: fmt.Errorf("selection failed")}
	}

	r.logger.Debug("arm selected",
		"arm", winner.ID,
		"task_type", task.Type,
		"complexity", task.ComplexityScore,
	)
	return RoutingDecision{Strategy: StrategySingleArm, Arm: winner}
}
// Arms returns a snapshot of all registered arms in no particular order.
// The slice is freshly allocated; the arms it points to are shared with the
// router.
func (r *Router) Arms() []*Arm {
	r.mu.RLock()
	defer r.mu.RUnlock()

	snapshot := make([]*Arm, len(r.arms))
	next := 0
	for _, arm := range r.arms {
		snapshot[next] = arm
		next++
	}
	return snapshot
}
// RegisterProvider registers all models from a provider as arms.
//
// costs maps a model ID to its {input, output} price per 1k tokens; models
// without an entry keep a zero cost. When the provider cannot list its
// models, or lists none, the provider's default model is registered instead
// (optimistically assumed to support tools) so the provider always
// contributes at least one arm.
func (r *Router) RegisterProvider(ctx context.Context, prov provider.Provider, isLocal bool, costs map[string][2]float64) {
	models, err := prov.Models(ctx)
	if err != nil {
		r.logger.Debug("failed to list models", "provider", prov.Name(), "error", err)
	}

	if err != nil || len(models) == 0 {
		// Register at least the default model
		name := prov.DefaultModel()
		arm := &Arm{
			ID:           NewArmID(prov.Name(), name),
			Provider:     prov,
			ModelName:    name,
			IsLocal:      isLocal,
			Capabilities: provider.Capabilities{ToolUse: true}, // optimistic
		}
		// Price the fallback arm like any listed model.
		if c, ok := costs[name]; ok {
			arm.CostPer1kInput = c[0]
			arm.CostPer1kOutput = c[1]
		}
		r.RegisterArm(arm)
		return
	}

	for _, m := range models {
		arm := &Arm{
			ID:           NewArmID(prov.Name(), m.ID),
			Provider:     prov,
			ModelName:    m.ID,
			IsLocal:      isLocal,
			Capabilities: m.Capabilities,
		}
		if c, ok := costs[m.ID]; ok {
			arm.CostPer1kInput = c[0]
			arm.CostPer1kOutput = c[1]
		}
		r.RegisterArm(arm)
	}
}
// Stream selects an arm for task, points req at that arm's model and opens
// a stream on the arm's provider.
//
// The chosen arm is returned alongside the stream. If selection fails both
// the stream and the arm are nil; if the provider fails the arm is still
// returned so the caller can tell which arm was tried.
func (r *Router) Stream(ctx context.Context, task Task, req provider.Request) (stream.Stream, *Arm, error) {
	decision := r.Select(task)
	if err := decision.Error; err != nil {
		return nil, nil, err
	}

	chosen := decision.Arm
	req.Model = chosen.ModelName // req is a copy; the caller's request is untouched

	out, err := chosen.Provider.Stream(ctx, req)
	if err != nil {
		return nil, chosen, err
	}
	return out, chosen, nil
}

View File

@@ -0,0 +1,305 @@
package router
import (
"math"
"testing"
"time"
"somegit.dev/Owlibou/gnoma/internal/provider"
)
// --- Task Classification ---
// TestClassifyTask checks one representative prompt per task type, plus a
// prompt with no matching keyword, which must fall back to TaskGeneration.
func TestClassifyTask(t *testing.T) {
tests := []struct {
prompt string
want TaskType
}{
{"fix the bug in auth module", TaskDebug},
{"review this pull request", TaskReview},
{"refactor the database layer", TaskRefactor},
{"write unit tests for the handler", TaskUnitTest},
{"explain how the router works", TaskExplain},
{"create a new REST endpoint", TaskGeneration},
{"plan the migration strategy", TaskPlanning},
{"audit security of the API", TaskSecurityReview},
{"scaffold a new service", TaskBoilerplate},
{"coordinate the deployment pipeline", TaskOrchestration},
{"hello", TaskGeneration}, // default
}
for _, tt := range tests {
task := ClassifyTask(tt.prompt)
if task.Type != tt.want {
t.Errorf("ClassifyTask(%q).Type = %s, want %s", tt.prompt, task.Type, tt.want)
}
}
}
// TestClassifyTask_RequiresTools checks that RequiresTools is cleared for
// explain prompts and left set for debug prompts.
func TestClassifyTask_RequiresTools(t *testing.T) {
// Explain tasks don't require tools
task := ClassifyTask("explain how generics work")
if task.RequiresTools {
t.Error("explain task should not require tools")
}
// Debug tasks require tools
task = ClassifyTask("debug the failing test")
if !task.RequiresTools {
t.Error("debug task should require tools")
}
}
// TestTaskValueScore checks that a low-priority boilerplate task scores
// strictly below a critical security review.
func TestTaskValueScore(t *testing.T) {
low := Task{Type: TaskBoilerplate, Priority: PriorityLow}
high := Task{Type: TaskSecurityReview, Priority: PriorityCritical}
if low.ValueScore() >= high.ValueScore() {
t.Errorf("low priority boilerplate (%f) should score less than critical security (%f)",
low.ValueScore(), high.ValueScore())
}
}
// TestEstimateComplexity checks that a short "rename" prompt scores lower
// than a long prompt containing several complexity keywords.
func TestEstimateComplexity(t *testing.T) {
simple := estimateComplexity("rename the variable")
// Note: this local shadows the builtin complex within the function.
complex := estimateComplexity("design and implement a distributed caching system with migration support and integration testing across multiple environments")
if simple >= complex {
t.Errorf("simple (%f) should be less than complex (%f)", simple, complex)
}
}
// --- Arm ---
// TestArmEstimateCost checks the 60/40 input/output cost split for a priced
// arm, allowing a small tolerance around the expected 0.078.
func TestArmEstimateCost(t *testing.T) {
arm := &Arm{
CostPer1kInput: 0.003,
CostPer1kOutput: 0.015,
}
cost := arm.EstimateCost(10000)
// 6000 input tokens * 0.003/1k + 4000 output tokens * 0.015/1k
// = 0.018 + 0.060 = 0.078
if cost < 0.07 || cost > 0.09 {
t.Errorf("EstimateCost(10000) = %f, want ~0.078", cost)
}
}
// TestArmEstimateCost_Free checks that an arm with no configured prices
// estimates a cost of exactly zero.
func TestArmEstimateCost_Free(t *testing.T) {
arm := &Arm{} // local model, zero cost
cost := arm.EstimateCost(10000)
if cost != 0 {
t.Errorf("free model should have zero cost, got %f", cost)
}
}
// --- Pool ---
// TestLimitPool_RemainingFraction checks that both used and reserved units
// count against the remaining fraction (30 + 20 of 100 leaves 0.5).
func TestLimitPool_RemainingFraction(t *testing.T) {
p := &LimitPool{TotalLimit: 100, Used: 30, Reserved: 20}
f := p.RemainingFraction()
if f != 0.5 {
t.Errorf("RemainingFraction = %f, want 0.5", f)
}
}
// TestLimitPool_ScarcityMultiplier checks the 1/f^k curve at f=0.5, k=2.
func TestLimitPool_ScarcityMultiplier(t *testing.T) {
// Half remaining, k=2: 1/0.5^2 = 4
p := &LimitPool{TotalLimit: 100, Used: 50, ScarcityK: 2}
m := p.ScarcityMultiplier()
if m < 3.9 || m > 4.1 {
t.Errorf("ScarcityMultiplier = %f, want ~4.0", m)
}
}
// TestLimitPool_ScarcityMultiplier_Exhausted checks that a fully used pool
// yields an infinite multiplier.
func TestLimitPool_ScarcityMultiplier_Exhausted(t *testing.T) {
p := &LimitPool{TotalLimit: 100, Used: 100}
m := p.ScarcityMultiplier()
if !math.IsInf(m, 1) {
t.Errorf("exhausted pool should return +Inf, got %f", m)
}
}
// TestLimitPool_ScarcityMultiplier_UseItOrLoseIt checks that a pool with
// ample headroom and a reset under an hour away returns the 0.5 discount.
func TestLimitPool_ScarcityMultiplier_UseItOrLoseIt(t *testing.T) {
p := &LimitPool{
TotalLimit: 100, Used: 30, // 70% remaining
ScarcityK: 2,
ResetAt: time.Now().Add(30 * time.Minute), // reset in 30 min
}
m := p.ScarcityMultiplier()
if m != 0.5 {
t.Errorf("use-it-or-lose-it discount: ScarcityMultiplier = %f, want 0.5", m)
}
}
// TestLimitPool_ReserveAndCommit checks that Reserve holds the projected
// amount and that Commit releases it while charging the actual usage.
func TestLimitPool_ReserveAndCommit(t *testing.T) {
p := &LimitPool{
TotalLimit: 1000,
ArmRates: map[ArmID]float64{"test/model": 5.0}, // 5 units per 1k tokens
ScarcityK: 2,
}
res, ok := p.Reserve("test/model", 10000) // 5 * 10 = 50 units
if !ok {
t.Fatal("reservation should succeed")
}
if p.Reserved != 50 {
t.Errorf("Reserved = %f, want 50", p.Reserved)
}
res.Commit(8000) // actual: 5 * 8 = 40 units
if p.Used != 40 {
t.Errorf("Used = %f, want 40", p.Used)
}
if p.Reserved != 0 {
t.Errorf("Reserved = %f, want 0 after commit", p.Reserved)
}
}
// TestLimitPool_ReserveExhausted checks that Reserve refuses a projection
// larger than the remaining budget.
func TestLimitPool_ReserveExhausted(t *testing.T) {
p := &LimitPool{
TotalLimit: 100,
Used: 90,
ArmRates: map[ArmID]float64{"test/model": 5.0},
ScarcityK: 2,
}
_, ok := p.Reserve("test/model", 10000) // needs 50, only 10 available
if ok {
t.Error("reservation should fail when exhausted")
}
}
// TestLimitPool_Rollback checks that Rollback releases the reserved amount
// without charging any usage.
func TestLimitPool_Rollback(t *testing.T) {
p := &LimitPool{
TotalLimit: 1000,
ArmRates: map[ArmID]float64{"test/model": 5.0},
ScarcityK: 2,
}
res, _ := p.Reserve("test/model", 10000)
if p.Reserved != 50 {
t.Fatalf("Reserved = %f, want 50", p.Reserved)
}
res.Rollback()
if p.Reserved != 0 {
t.Errorf("Reserved = %f, want 0 after rollback", p.Reserved)
}
if p.Used != 0 {
t.Errorf("Used = %f, want 0 after rollback", p.Used)
}
}
// TestLimitPool_CheckReset checks that a pool whose ResetAt lies in the
// past has both Used and Reserved cleared by CheckReset.
func TestLimitPool_CheckReset(t *testing.T) {
p := &LimitPool{
TotalLimit: 1000,
Used: 500,
Reserved: 100,
ResetPeriod: time.Hour,
ResetAt: time.Now().Add(-time.Minute), // already passed
ScarcityK: 2,
}
p.CheckReset()
if p.Used != 0 {
t.Errorf("Used = %f after reset, want 0", p.Used)
}
if p.Reserved != 0 {
t.Errorf("Reserved = %f after reset, want 0", p.Reserved)
}
}
// --- Selector ---
// TestSelectBest_PrefersToolSupport checks that, for a task requiring
// tools, the tool-capable arm wins even when it is listed second.
func TestSelectBest_PrefersToolSupport(t *testing.T) {
withTools := &Arm{
ID: "a/with-tools", ModelName: "with-tools",
Capabilities: provider.Capabilities{ToolUse: true, ContextWindow: 128000},
}
withoutTools := &Arm{
ID: "b/no-tools", ModelName: "no-tools",
Capabilities: provider.Capabilities{ToolUse: false, ContextWindow: 128000},
}
task := Task{Type: TaskGeneration, RequiresTools: true, Priority: PriorityNormal}
best := selectBest([]*Arm{withoutTools, withTools}, task)
if best.ID != "a/with-tools" {
t.Errorf("should prefer arm with tool support, got %s", best.ID)
}
}
// TestSelectBest_PrefersThinkingForPlanning checks that, at equal price, a
// thinking-capable arm is chosen over a basic one for a planning task.
func TestSelectBest_PrefersThinkingForPlanning(t *testing.T) {
thinking := &Arm{
ID: "a/thinking", ModelName: "thinking",
Capabilities: provider.Capabilities{ToolUse: true, Thinking: true, ContextWindow: 200000},
CostPer1kInput: 0.01, CostPer1kOutput: 0.05,
}
noThinking := &Arm{
ID: "b/basic", ModelName: "basic",
Capabilities: provider.Capabilities{ToolUse: true, ContextWindow: 128000},
CostPer1kInput: 0.01, CostPer1kOutput: 0.05,
}
task := Task{Type: TaskPlanning, RequiresTools: true, Priority: PriorityNormal, EstimatedTokens: 5000}
best := selectBest([]*Arm{noThinking, thinking}, task)
if best.ID != "a/thinking" {
t.Errorf("should prefer thinking model for planning, got %s", best.ID)
}
}
// TestFilterFeasible_ExcludesExhausted checks that an arm whose only pool
// is fully used is filtered out, leaving no feasible arm.
func TestFilterFeasible_ExcludesExhausted(t *testing.T) {
pool := &LimitPool{
TotalLimit: 100,
Used: 100, // exhausted
ArmRates: map[ArmID]float64{"a/model": 1.0},
ScarcityK: 2,
}
arm := &Arm{
ID: "a/model", ModelName: "model",
Capabilities: provider.Capabilities{ToolUse: true},
Pools: []*LimitPool{pool},
}
task := Task{Type: TaskGeneration, RequiresTools: true, EstimatedTokens: 1000}
feasible := filterFeasible([]*Arm{arm}, task)
if len(feasible) != 0 {
t.Error("exhausted arm should not be feasible")
}
}
// --- Router ---
// TestRouter_SelectForced checks that a forced arm is returned by Select
// regardless of the other registered arms.
func TestRouter_SelectForced(t *testing.T) {
r := New(Config{})
r.RegisterArm(&Arm{ID: "a/model1", Capabilities: provider.Capabilities{ToolUse: true}})
r.RegisterArm(&Arm{ID: "b/model2", Capabilities: provider.Capabilities{ToolUse: true}})
r.ForceArm("b/model2")
decision := r.Select(Task{Type: TaskGeneration})
if decision.Error != nil {
t.Fatalf("Select: %v", decision.Error)
}
if decision.Arm.ID != "b/model2" {
t.Errorf("forced arm should be selected, got %s", decision.Arm.ID)
}
}
// TestRouter_SelectNoArms checks that Select reports an error on a router
// with no registered arms.
func TestRouter_SelectNoArms(t *testing.T) {
r := New(Config{})
decision := r.Select(Task{Type: TaskGeneration})
if decision.Error == nil {
t.Error("should error with no arms")
}
}
// TestRouter_SelectForcedNotFound checks that forcing an unregistered arm
// makes Select report an error.
func TestRouter_SelectForcedNotFound(t *testing.T) {
r := New(Config{})
r.ForceArm("nonexistent/model")
decision := r.Select(Task{Type: TaskGeneration})
if decision.Error == nil {
t.Error("should error when forced arm not found")
}
}

167
internal/router/selector.go Normal file
View File

@@ -0,0 +1,167 @@
package router
import (
"math"
)
// Strategy identifies how a task should be executed.
type Strategy int
// Execution strategies. Only single-arm execution exists so far.
const (
StrategySingleArm Strategy = iota
// Future (M9): StrategyCascade, StrategyParallelEnsemble, StrategyMultiRound
)
// RoutingDecision is the result of arm selection.
// Router.Select sets either Arm (with Strategy) or Error, never both.
type RoutingDecision struct {
Strategy Strategy
Arm *Arm // primary arm
// Error is non-nil when no arm could be selected.
Error error
}
// selectBest picks the highest-scoring feasible arm using heuristic scoring.
// No bandit learning — that's M9. Just smart defaults based on model size,
// locality, task type, cost, and pool scarcity.
//
// Arms with equal scores are ordered by ID (smallest wins), so the result
// does not depend on the order of arms. The router builds that slice from a
// map, whose iteration order is random. It returns nil when arms is empty.
func selectBest(arms []*Arm, task Task) *Arm {
	var best *Arm
	bestScore := math.Inf(-1)
	for _, arm := range arms {
		score := scoreArm(arm, task)
		switch {
		case score > bestScore:
			// strictly better
		case best != nil && score == bestScore && arm.ID < best.ID:
			// tie: deterministic preference for the smaller ID
		default:
			continue
		}
		best, bestScore = arm, score
	}
	return best
}
// scoreArm computes the heuristic routing score of arm for task:
//
//	score = (quality × value) / effective_cost
//
// A non-positive cost is replaced by 0.001 so free arms do not divide by zero.
func scoreArm(arm *Arm, task Task) float64 {
	benefit := heuristicQuality(arm, task) * task.ValueScore()

	denom := effectiveCost(arm, task)
	if denom <= 0 {
		denom = 0.001 // prevent division by zero for free local models
	}
	return benefit / denom
}
// heuristicQuality estimates, without historical data, how well arm suits
// task. It starts from a base of 0.5, applies the adjustments below in
// order, and clamps the result to [0, 1].
func heuristicQuality(arm *Arm, task Task) float64 {
	caps := arm.Capabilities
	q := 0.5

	// Bigger context windows help with complex tasks: +0.1 from 100k
	// tokens, a further +0.05 from 200k.
	if caps.ContextWindow >= 100000 {
		q += 0.1
		if caps.ContextWindow >= 200000 {
			q += 0.05
		}
	}

	// Thinking models shine on planning-style work, less so on debugging
	// and refactoring, and get no bonus elsewhere.
	if caps.Thinking {
		if task.Type == TaskPlanning || task.Type == TaskOrchestration || task.Type == TaskSecurityReview {
			q += 0.2
		} else if task.Type == TaskDebug || task.Type == TaskRefactor {
			q += 0.1
		}
	}

	// A task that needs tools is nearly worthless on a tool-less arm.
	if task.RequiresTools && !arm.SupportsTools() {
		q *= 0.1
	}

	// Local arms: small bonus (no network latency, privacy), but a
	// penalty when the task is complex.
	if arm.IsLocal {
		q += 0.05
		if task.ComplexityScore > 0.7 {
			q *= 0.7
		}
	}

	switch {
	case q > 1.0:
		return 1.0
	case q < 0.0:
		return 0.0
	}
	return q
}
// effectiveCost returns the arm's estimated cost for task scaled by pool
// scarcity.
//
// The scale factor is the largest ScarcityMultiplier across the arm's pools,
// so the most constrained pool dominates. It is deliberately not floored at
// 1: when every pool reports the use-it-or-lose-it discount (a multiplier
// below 1), the arm becomes cheaper than its base cost. Arms without pools
// are charged the base cost. A non-positive base cost (free local models) is
// replaced by 0.001 so the arm still takes part in scoring.
func effectiveCost(arm *Arm, task Task) float64 {
	base := arm.EstimateCost(task.EstimatedTokens)
	if base <= 0 {
		base = 0.001 // local models are ~free but not zero for scoring
	}

	// Apply maximum scarcity multiplier across all pools.
	worst := 0.0
	for _, pool := range arm.Pools {
		if m := pool.ScarcityMultiplier(); m > worst {
			worst = m
		}
	}
	if worst <= 0 {
		// No pools (or no usable multiplier): charge the plain base cost.
		return base
	}
	return base * worst
}
// filterFeasible returns the arms that can handle the task: they must
// support tools when the task requires them, and every one of their pools
// must be able to afford the task's estimated tokens.
//
// If the task requires tools but no tool-capable arm is feasible, the
// result falls back to the tool-less arms whose pools can afford the task
// (tool-less is better than nothing). Input order is preserved. Pools are
// rolled over to a fresh window before being checked on both paths.
func filterFeasible(arms []*Arm, task Task) []*Arm {
	var feasible, fallback []*Arm
	for _, arm := range arms {
		if !poolsCanAfford(arm, task) {
			continue
		}
		if task.RequiresTools && !arm.SupportsTools() {
			fallback = append(fallback, arm)
			continue
		}
		feasible = append(feasible, arm)
	}
	if len(feasible) == 0 {
		// fallback is only ever populated when the task requires tools.
		return fallback
	}
	return feasible
}

// poolsCanAfford reports whether every pool of arm can cover the task's
// estimated tokens, resetting pools whose window has elapsed first.
func poolsCanAfford(arm *Arm, task Task) bool {
	for _, pool := range arm.Pools {
		pool.CheckReset()
		if !pool.CanAfford(arm.ID, task.EstimatedTokens) {
			return false
		}
	}
	return true
}

199
internal/router/task.go Normal file
View File

@@ -0,0 +1,199 @@
package router
import (
"fmt"
"strings"
)
// TaskType classifies a task for routing purposes.
type TaskType int
// Task types recognised by the router. The zero value is TaskBoilerplate.
const (
TaskBoilerplate TaskType = iota // simple scaffolding, templates
TaskGeneration // new code creation
TaskRefactor // restructuring existing code
TaskReview // code review, analysis
TaskUnitTest // writing tests
TaskPlanning // architecture, design
TaskOrchestration // multi-step coordination
TaskSecurityReview // security-focused analysis
TaskDebug // finding and fixing bugs
TaskExplain // explaining code or concepts
)
// String returns the snake_case name of the task type, or "unknown(N)" for
// a value outside the declared constants.
func (t TaskType) String() string {
	names := [...]string{
		TaskBoilerplate:    "boilerplate",
		TaskGeneration:     "generation",
		TaskRefactor:       "refactor",
		TaskReview:         "review",
		TaskUnitTest:       "unit_test",
		TaskPlanning:       "planning",
		TaskOrchestration:  "orchestration",
		TaskSecurityReview: "security_review",
		TaskDebug:          "debug",
		TaskExplain:        "explain",
	}
	if t >= 0 && int(t) < len(names) {
		return names[t]
	}
	return fmt.Sprintf("unknown(%d)", t)
}
// Priority indicates task importance for routing decisions.
type Priority int
// Priority levels in ascending order of importance. The zero value is
// PriorityLow.
const (
PriorityLow Priority = iota
PriorityNormal
PriorityHigh
PriorityCritical
)
// Task represents a classified unit of work for routing.
type Task struct {
Type TaskType
Priority Priority
// EstimatedTokens is the token budget used for cost estimates and pool
// capacity checks.
EstimatedTokens int
// RequiresTools restricts routing to tool-capable arms where possible.
RequiresTools bool
ComplexityScore float64 // 0-1
}
// ValueScore computes the routing value of the task: a weight derived from
// its priority multiplied by the multiplier of its type. An unrecognised
// priority or type contributes a factor of 0.
func (t Task) ValueScore() float64 {
	var weight float64
	switch t.Priority {
	case PriorityLow:
		weight = 0.5
	case PriorityNormal:
		weight = 1.0
	case PriorityHigh:
		weight = 2.0
	case PriorityCritical:
		weight = 5.0
	}
	return weight * taskTypeMultiplier[t.Type]
}
// taskTypeMultiplier weights each task type in Task.ValueScore; generation
// is the 1.0 baseline. A type missing from the map yields 0.
var taskTypeMultiplier = map[TaskType]float64{
TaskBoilerplate: 0.6,
TaskGeneration: 1.0,
TaskRefactor: 0.9,
TaskReview: 1.1,
TaskUnitTest: 0.8,
TaskPlanning: 1.4,
TaskOrchestration: 1.5,
TaskSecurityReview: 2.0,
TaskDebug: 1.2,
TaskExplain: 0.7,
}
// QualityThreshold defines minimum acceptable quality for a task type.
// NOTE(review): the scale is presumably the same 0-1 range as the heuristic
// quality score — confirm against the consumer once one exists.
type QualityThreshold struct {
Minimum float64 // below → output is harmful, never accept
Acceptable float64 // good enough
Target float64 // ideal
}
// DefaultThresholds lists the default {Minimum, Acceptable, Target} quality
// thresholds for every task type.
// NOTE(review): not referenced by the selector code visible here.
var DefaultThresholds = map[TaskType]QualityThreshold{
TaskBoilerplate: {0.50, 0.70, 0.80},
TaskGeneration: {0.60, 0.75, 0.88},
TaskRefactor: {0.65, 0.78, 0.90},
TaskReview: {0.70, 0.82, 0.92},
TaskUnitTest: {0.60, 0.75, 0.85},
TaskPlanning: {0.75, 0.88, 0.95},
TaskOrchestration: {0.80, 0.90, 0.96},
TaskSecurityReview: {0.88, 0.94, 0.99},
TaskDebug: {0.65, 0.80, 0.90},
TaskExplain: {0.55, 0.72, 0.85},
}
// ClassifyTask infers a TaskType from the user's prompt using keyword
// heuristics and returns a Task with its type, priority, tool requirement
// and complexity score filled in. EstimatedTokens is left at zero for the
// caller to set.
//
// Matching is case-insensitive and substring-based; the first matching case
// wins, so cases are ordered from most to least specific. Prompts matching
// no keyword are classified as TaskGeneration. Tasks default to normal
// priority and to requiring tools; only explain tasks clear RequiresTools.
func ClassifyTask(prompt string) Task {
	lower := strings.ToLower(prompt)
	task := Task{
		Priority:      PriorityNormal,
		RequiresTools: true, // assume tools needed by default
	}

	// Check for task type keywords (order matters — more specific first)
	switch {
	case containsAny(lower, "security", "vulnerability", "cve", "owasp", "xss", "injection"):
		task.Type = TaskSecurityReview
		task.Priority = PriorityHigh
	case containsAny(lower, "plan", "architect", "design", "strategy", "roadmap"):
		task.Type = TaskPlanning
	case containsAny(lower, "orchestrat", "coordinate", "dispatch", "pipeline"):
		task.Type = TaskOrchestration
		task.Priority = PriorityHigh
	case containsAny(lower, "debug", "fix", "troubleshoot", "not working", "error", "crash", "failing", "bug"):
		task.Type = TaskDebug
	case containsAny(lower, "review", "check", "analyze", "audit", "inspect"):
		task.Type = TaskReview
	case containsAny(lower, "refactor", "restructure", "reorganize", "clean up", "simplify"):
		task.Type = TaskRefactor
	case containsAny(lower, "test", "spec", "coverage", "assert"):
		task.Type = TaskUnitTest
	case containsAny(lower, "explain", "what is", "how does", "describe", "tell me about"):
		task.Type = TaskExplain
		task.RequiresTools = false
	// Boilerplate nouns must be tested before the generic creation verbs:
	// "generate boilerplate" or "create a skeleton" would otherwise always
	// be swallowed by the generation case below.
	case containsAny(lower, "scaffold", "boilerplate", "template", "stub", "skeleton"):
		task.Type = TaskBoilerplate
	case containsAny(lower, "create", "implement", "build", "add", "write", "generate", "make"):
		task.Type = TaskGeneration
	default:
		task.Type = TaskGeneration // default
	}

	// Estimate complexity from prompt length and keywords
	task.ComplexityScore = estimateComplexity(lower)
	return task
}
// containsAny reports whether s contains at least one of keywords as a
// substring. With no keywords it returns false.
func containsAny(s string, keywords ...string) bool {
	for i := 0; i < len(keywords); i++ {
		if strings.Index(s, keywords[i]) >= 0 {
			return true
		}
	}
	return false
}
// estimateComplexity scores a prompt's complexity in [0, 1].
//
// The score starts at words/200 (so a 200-word prompt reaches 1.0), gains
// 0.15 for each complexity keyword present and loses 0.15 for each
// simplicity keyword present, then is clamped. Keyword matching is
// case-sensitive, so callers are expected to pass a lower-cased prompt.
func estimateComplexity(prompt string) float64 {
	harder := [...]string{"implement", "design", "architect", "system", "integration", "migrate", "optimize"}
	easier := [...]string{"rename", "format", "add field", "change name", "typo", "simple"}

	// Length contributes to complexity.
	score := float64(len(strings.Fields(prompt))) / 200.0

	for i := range harder {
		if strings.Contains(prompt, harder[i]) {
			score += 0.15
		}
	}
	for i := range easier {
		if strings.Contains(prompt, easier[i]) {
			score -= 0.15
		}
	}

	switch {
	case score < 0:
		return 0
	case score > 1:
		return 1
	}
	return score
}