feat: add router foundation with task classification and arm selection
internal/router/ — core routing layer:
- Task classification: 10 types (boilerplate, generation, refactor, review, unit_test, planning, orchestration, security_review, debug, explain) with keyword heuristics and complexity scoring.
- Arm registry: provider+model pairs with capabilities and cost.
- Limit pools: shared resource budgets with scarcity multipliers, optimistic reservation, and use-it-or-lose-it discounting.
- Heuristic selector: score = (quality × value) / effective_cost. Prefers tools, prefers thinking models for planning, and penalizes small models on complex tasks.
- Router: Select() picks the best feasible arm; ForceArm() is the CLI override.

The engine now routes through router.Select() when a router is configured. Wired into the CLI — one arm is registered per the --provider/--model flags. 20 router tests; 173 tests total across 13 packages.
This commit is contained in:
@@ -12,6 +12,7 @@ import (
|
||||
|
||||
"somegit.dev/Owlibou/gnoma/internal/engine"
|
||||
"somegit.dev/Owlibou/gnoma/internal/provider"
|
||||
"somegit.dev/Owlibou/gnoma/internal/router"
|
||||
"somegit.dev/Owlibou/gnoma/internal/security"
|
||||
anthropicprov "somegit.dev/Owlibou/gnoma/internal/provider/anthropic"
|
||||
"somegit.dev/Owlibou/gnoma/internal/provider/mistral"
|
||||
@@ -81,6 +82,23 @@ func main() {
|
||||
// Re-register bash tool with aliases
|
||||
reg.Register(bash.New(bash.WithAliases(aliases)))
|
||||
|
||||
// Create router and register the provider as a single arm
|
||||
// (M4 foundation: one provider from CLI. Multi-provider routing comes with config.)
|
||||
rtr := router.New(router.Config{Logger: logger})
|
||||
armModel := *model
|
||||
if armModel == "" {
|
||||
armModel = prov.DefaultModel()
|
||||
}
|
||||
armID := router.NewArmID(*providerName, armModel)
|
||||
rtr.RegisterArm(&router.Arm{
|
||||
ID: armID,
|
||||
Provider: prov,
|
||||
ModelName: armModel,
|
||||
IsLocal: localProviders[*providerName],
|
||||
Capabilities: provider.Capabilities{ToolUse: true}, // trust CLI provider
|
||||
})
|
||||
rtr.ForceArm(armID)
|
||||
|
||||
// Create firewall
|
||||
fw := security.NewFirewall(security.FirewallConfig{
|
||||
ScanOutgoing: true,
|
||||
@@ -92,6 +110,7 @@ func main() {
|
||||
// Create engine
|
||||
eng, err := engine.New(engine.Config{
|
||||
Provider: prov,
|
||||
Router: rtr,
|
||||
Tools: reg,
|
||||
Firewall: fw,
|
||||
System: *system,
|
||||
|
||||
@@ -7,13 +7,15 @@ import (
|
||||
|
||||
"somegit.dev/Owlibou/gnoma/internal/message"
|
||||
"somegit.dev/Owlibou/gnoma/internal/provider"
|
||||
"somegit.dev/Owlibou/gnoma/internal/router"
|
||||
"somegit.dev/Owlibou/gnoma/internal/security"
|
||||
"somegit.dev/Owlibou/gnoma/internal/tool"
|
||||
)
|
||||
|
||||
// Config holds engine configuration.
|
||||
type Config struct {
|
||||
Provider provider.Provider
|
||||
Provider provider.Provider // direct provider (used if Router is nil)
|
||||
Router *router.Router // nil = use Provider directly
|
||||
Tools *tool.Registry
|
||||
Firewall *security.Firewall // nil = no scanning
|
||||
System string // system prompt
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
|
||||
"somegit.dev/Owlibou/gnoma/internal/message"
|
||||
"somegit.dev/Owlibou/gnoma/internal/provider"
|
||||
"somegit.dev/Owlibou/gnoma/internal/router"
|
||||
"somegit.dev/Owlibou/gnoma/internal/stream"
|
||||
)
|
||||
|
||||
@@ -38,16 +39,50 @@ func (e *Engine) runLoop(ctx context.Context, cb Callback) (*Turn, error) {
|
||||
// Build provider request (gates tools on model capabilities)
|
||||
req := e.buildRequest(ctx)
|
||||
|
||||
e.logger.Debug("streaming request",
|
||||
"provider", e.cfg.Provider.Name(),
|
||||
"model", req.Model,
|
||||
"messages", len(req.Messages),
|
||||
"tools", len(req.Tools),
|
||||
"round", turn.Rounds,
|
||||
)
|
||||
// Route and stream
|
||||
var s stream.Stream
|
||||
var err error
|
||||
|
||||
// Stream from provider
|
||||
s, err := e.cfg.Provider.Stream(ctx, req)
|
||||
if e.cfg.Router != nil {
|
||||
// Classify task from the latest user message
|
||||
prompt := ""
|
||||
for i := len(e.history) - 1; i >= 0; i-- {
|
||||
if e.history[i].Role == message.RoleUser {
|
||||
prompt = e.history[i].TextContent()
|
||||
break
|
||||
}
|
||||
}
|
||||
task := router.ClassifyTask(prompt)
|
||||
task.EstimatedTokens = 4000 // rough default
|
||||
|
||||
e.logger.Debug("routing request",
|
||||
"task_type", task.Type,
|
||||
"complexity", task.ComplexityScore,
|
||||
"round", turn.Rounds,
|
||||
)
|
||||
|
||||
var arm *router.Arm
|
||||
s, arm, err = e.cfg.Router.Stream(ctx, task, req)
|
||||
if arm != nil {
|
||||
e.logger.Debug("streaming request",
|
||||
"provider", arm.Provider.Name(),
|
||||
"model", arm.ModelName,
|
||||
"arm", arm.ID,
|
||||
"messages", len(req.Messages),
|
||||
"tools", len(req.Tools),
|
||||
"round", turn.Rounds,
|
||||
)
|
||||
}
|
||||
} else {
|
||||
e.logger.Debug("streaming request",
|
||||
"provider", e.cfg.Provider.Name(),
|
||||
"model", req.Model,
|
||||
"messages", len(req.Messages),
|
||||
"tools", len(req.Tools),
|
||||
"round", turn.Rounds,
|
||||
)
|
||||
s, err = e.cfg.Provider.Stream(ctx, req)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("provider stream: %w", err)
|
||||
}
|
||||
|
||||
47
internal/router/arm.go
Normal file
47
internal/router/arm.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"somegit.dev/Owlibou/gnoma/internal/provider"
|
||||
)
|
||||
|
||||
// ArmID is the unique key of a provider+model pair.
type ArmID string
|
||||
|
||||
// Arm represents a provider+model pair available for routing.
|
||||
type Arm struct {
|
||||
ID ArmID
|
||||
Provider provider.Provider
|
||||
ModelName string
|
||||
IsLocal bool
|
||||
Capabilities provider.Capabilities
|
||||
Pools []*LimitPool
|
||||
|
||||
// Cost per 1k tokens (EUR, estimated)
|
||||
CostPer1kInput float64
|
||||
CostPer1kOutput float64
|
||||
}
|
||||
|
||||
// NewArmID creates an arm ID from provider name and model.
|
||||
func NewArmID(providerName, model string) ArmID {
|
||||
return ArmID(providerName + "/" + model)
|
||||
}
|
||||
|
||||
// EstimateCost returns estimated cost in EUR for a task.
|
||||
func (a *Arm) EstimateCost(estimatedTokens int) float64 {
|
||||
// Rough estimate: 60% input, 40% output
|
||||
inputTokens := float64(estimatedTokens) * 0.6
|
||||
outputTokens := float64(estimatedTokens) * 0.4
|
||||
return (inputTokens/1000)*a.CostPer1kInput + (outputTokens/1000)*a.CostPer1kOutput
|
||||
}
|
||||
|
||||
// SupportsTools returns true if this arm's model supports function calling.
|
||||
func (a *Arm) SupportsTools() bool {
|
||||
return a.Capabilities.ToolUse
|
||||
}
|
||||
|
||||
// ArmPerf carries live performance measurements for an arm.
type ArmPerf struct {
	// Time to first token, 50th percentile, in milliseconds.
	TTFT_P50_ms float64
	// Time to first token, 95th percentile, in milliseconds.
	TTFT_P95_ms float64
	// Throughput in tokens per second.
	ToksPerSec float64
}
|
||||
170
internal/router/pool.go
Normal file
170
internal/router/pool.go
Normal file
@@ -0,0 +1,170 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"math"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// PoolKind names the kind of resource a limit pool accounts for.
type PoolKind int

// Supported pool kinds.
const (
	// PoolRPM limits requests per minute.
	PoolRPM PoolKind = iota
	// PoolRPD limits requests per day.
	PoolRPD
	// PoolTPD limits tokens per day.
	PoolTPD
	// PoolCostEUR caps monetary cost.
	PoolCostEUR
	// PoolCustom tracks arbitrary units.
	PoolCustom
)
|
||||
|
||||
// LimitPool tracks a shared resource budget that arms draw from.
|
||||
type LimitPool struct {
|
||||
mu sync.Mutex
|
||||
|
||||
ID string
|
||||
Kind PoolKind
|
||||
TotalLimit float64
|
||||
Used float64
|
||||
Reserved float64 // optimistically reserved for in-flight requests
|
||||
ResetPeriod time.Duration
|
||||
ResetAt time.Time
|
||||
|
||||
// Per-arm consumption rates (units per 1k tokens or per request)
|
||||
ArmRates map[ArmID]float64
|
||||
|
||||
// Scarcity curve aggressiveness. k=2 gentle, k=4 aggressive hoarding.
|
||||
ScarcityK float64
|
||||
}
|
||||
|
||||
// RemainingFraction returns the fraction of budget still available.
|
||||
func (p *LimitPool) RemainingFraction() float64 {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
if p.TotalLimit <= 0 {
|
||||
return 0
|
||||
}
|
||||
return 1.0 - (p.Used+p.Reserved)/p.TotalLimit
|
||||
}
|
||||
|
||||
// ScarcityMultiplier returns a cost inflation factor based on remaining budget.
|
||||
// As resources deplete, the multiplier increases, making the arm more expensive.
|
||||
func (p *LimitPool) ScarcityMultiplier() float64 {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
return p.scarcityMultiplierLocked()
|
||||
}
|
||||
|
||||
func (p *LimitPool) scarcityMultiplierLocked() float64 {
|
||||
if p.TotalLimit <= 0 {
|
||||
return math.Inf(1)
|
||||
}
|
||||
|
||||
f := 1.0 - (p.Used+p.Reserved)/p.TotalLimit
|
||||
if f <= 0 {
|
||||
return math.Inf(1) // exhausted
|
||||
}
|
||||
|
||||
// Use-it-or-lose-it: if reset is imminent and headroom exists, discount
|
||||
hoursToReset := time.Until(p.ResetAt).Hours()
|
||||
if !p.ResetAt.IsZero() && hoursToReset > 0 && hoursToReset < 1.0 && f > 0.3 {
|
||||
return 0.5
|
||||
}
|
||||
|
||||
k := p.ScarcityK
|
||||
if k <= 0 {
|
||||
k = 2.0 // gentle default
|
||||
}
|
||||
return 1.0 / math.Pow(f, k)
|
||||
}
|
||||
|
||||
// Exhausted returns true if the pool has no remaining capacity.
|
||||
func (p *LimitPool) Exhausted() bool {
|
||||
return p.RemainingFraction() <= 0
|
||||
}
|
||||
|
||||
// CanAfford returns true if the pool can cover the projected consumption.
|
||||
func (p *LimitPool) CanAfford(armID ArmID, estimatedTokens int) bool {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
rate := p.ArmRates[armID]
|
||||
if rate == 0 {
|
||||
return true // no rate defined = no limit
|
||||
}
|
||||
projected := rate * float64(estimatedTokens) / 1000.0
|
||||
available := p.TotalLimit - p.Used - p.Reserved
|
||||
return projected <= available
|
||||
}
|
||||
|
||||
// Reservation represents an optimistic resource reservation.
|
||||
type Reservation struct {
|
||||
pool *LimitPool
|
||||
armID ArmID
|
||||
projected float64
|
||||
committed bool
|
||||
}
|
||||
|
||||
// Reserve creates an optimistic reservation. Call Commit() with actual usage
|
||||
// on completion, or Rollback() on failure.
|
||||
func (p *LimitPool) Reserve(armID ArmID, estimatedTokens int) (*Reservation, bool) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
rate := p.ArmRates[armID]
|
||||
if rate == 0 {
|
||||
return &Reservation{pool: p}, true // no limit
|
||||
}
|
||||
|
||||
projected := rate * float64(estimatedTokens) / 1000.0
|
||||
available := p.TotalLimit - p.Used - p.Reserved
|
||||
if projected > available {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
p.Reserved += projected
|
||||
return &Reservation{
|
||||
pool: p,
|
||||
armID: armID,
|
||||
projected: projected,
|
||||
}, true
|
||||
}
|
||||
|
||||
// Commit finalizes the reservation with actual consumption.
|
||||
func (r *Reservation) Commit(actualTokens int) {
|
||||
if r.committed || r.pool == nil {
|
||||
return
|
||||
}
|
||||
r.committed = true
|
||||
r.pool.mu.Lock()
|
||||
defer r.pool.mu.Unlock()
|
||||
|
||||
rate := r.pool.ArmRates[r.armID]
|
||||
actual := rate * float64(actualTokens) / 1000.0
|
||||
|
||||
r.pool.Reserved -= r.projected
|
||||
r.pool.Used += actual
|
||||
}
|
||||
|
||||
// Rollback releases the reservation without consumption.
|
||||
func (r *Reservation) Rollback() {
|
||||
if r.committed || r.pool == nil || r.projected == 0 {
|
||||
return
|
||||
}
|
||||
r.committed = true
|
||||
r.pool.mu.Lock()
|
||||
defer r.pool.mu.Unlock()
|
||||
|
||||
r.pool.Reserved -= r.projected
|
||||
}
|
||||
|
||||
// CheckReset resets usage if the reset period has elapsed.
|
||||
func (p *LimitPool) CheckReset() {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
if !p.ResetAt.IsZero() && time.Now().After(p.ResetAt) {
|
||||
p.Used = 0
|
||||
p.Reserved = 0
|
||||
p.ResetAt = p.ResetAt.Add(p.ResetPeriod)
|
||||
}
|
||||
}
|
||||
160
internal/router/router.go
Normal file
160
internal/router/router.go
Normal file
@@ -0,0 +1,160 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sync"
|
||||
|
||||
"somegit.dev/Owlibou/gnoma/internal/provider"
|
||||
"somegit.dev/Owlibou/gnoma/internal/stream"
|
||||
)
|
||||
|
||||
// Router selects the best arm for a given task.
|
||||
// M4: heuristic selection. M9: bandit learning.
|
||||
type Router struct {
|
||||
mu sync.RWMutex
|
||||
arms map[ArmID]*Arm
|
||||
logger *slog.Logger
|
||||
|
||||
// Optional: force a specific arm (--provider flag override)
|
||||
forcedArm ArmID
|
||||
}
|
||||
|
||||
// Config holds the settings used by New to build a Router.
type Config struct {
	// Logger receives the router's debug output; New substitutes
	// slog.Default() when it is nil.
	Logger *slog.Logger
}
|
||||
|
||||
func New(cfg Config) *Router {
|
||||
logger := cfg.Logger
|
||||
if logger == nil {
|
||||
logger = slog.Default()
|
||||
}
|
||||
return &Router{
|
||||
arms: make(map[ArmID]*Arm),
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterArm adds an arm to the router.
|
||||
func (r *Router) RegisterArm(arm *Arm) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.arms[arm.ID] = arm
|
||||
r.logger.Debug("arm registered", "id", arm.ID, "local", arm.IsLocal, "tools", arm.SupportsTools())
|
||||
}
|
||||
|
||||
// ForceArm overrides routing to always select a specific arm.
|
||||
// Used for --provider CLI flag.
|
||||
func (r *Router) ForceArm(id ArmID) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.forcedArm = id
|
||||
}
|
||||
|
||||
// Select picks the best arm for the given task.
|
||||
func (r *Router) Select(task Task) RoutingDecision {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
// If an arm is forced, use it directly
|
||||
if r.forcedArm != "" {
|
||||
arm, ok := r.arms[r.forcedArm]
|
||||
if !ok {
|
||||
return RoutingDecision{Error: fmt.Errorf("forced arm %q not found", r.forcedArm)}
|
||||
}
|
||||
return RoutingDecision{Strategy: StrategySingleArm, Arm: arm}
|
||||
}
|
||||
|
||||
// Collect all arms
|
||||
allArms := make([]*Arm, 0, len(r.arms))
|
||||
for _, arm := range r.arms {
|
||||
allArms = append(allArms, arm)
|
||||
}
|
||||
|
||||
if len(allArms) == 0 {
|
||||
return RoutingDecision{Error: fmt.Errorf("no arms registered")}
|
||||
}
|
||||
|
||||
// Filter to feasible arms
|
||||
feasible := filterFeasible(allArms, task)
|
||||
if len(feasible) == 0 {
|
||||
return RoutingDecision{Error: fmt.Errorf("no feasible arm for task type %s", task.Type)}
|
||||
}
|
||||
|
||||
// Select best
|
||||
best := selectBest(feasible, task)
|
||||
if best == nil {
|
||||
return RoutingDecision{Error: fmt.Errorf("selection failed")}
|
||||
}
|
||||
|
||||
r.logger.Debug("arm selected",
|
||||
"arm", best.ID,
|
||||
"task_type", task.Type,
|
||||
"complexity", task.ComplexityScore,
|
||||
)
|
||||
|
||||
return RoutingDecision{Strategy: StrategySingleArm, Arm: best}
|
||||
}
|
||||
|
||||
// Arms returns all registered arms.
|
||||
func (r *Router) Arms() []*Arm {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
arms := make([]*Arm, 0, len(r.arms))
|
||||
for _, a := range r.arms {
|
||||
arms = append(arms, a)
|
||||
}
|
||||
return arms
|
||||
}
|
||||
|
||||
// RegisterProvider registers all models from a provider as arms.
|
||||
func (r *Router) RegisterProvider(ctx context.Context, prov provider.Provider, isLocal bool, costs map[string][2]float64) {
|
||||
models, err := prov.Models(ctx)
|
||||
if err != nil {
|
||||
r.logger.Debug("failed to list models", "provider", prov.Name(), "error", err)
|
||||
// Register at least the default model
|
||||
id := NewArmID(prov.Name(), prov.DefaultModel())
|
||||
r.RegisterArm(&Arm{
|
||||
ID: id,
|
||||
Provider: prov,
|
||||
ModelName: prov.DefaultModel(),
|
||||
IsLocal: isLocal,
|
||||
Capabilities: provider.Capabilities{ToolUse: true}, // optimistic
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
for _, m := range models {
|
||||
id := NewArmID(prov.Name(), m.ID)
|
||||
arm := &Arm{
|
||||
ID: id,
|
||||
Provider: prov,
|
||||
ModelName: m.ID,
|
||||
IsLocal: isLocal,
|
||||
Capabilities: m.Capabilities,
|
||||
}
|
||||
if c, ok := costs[m.ID]; ok {
|
||||
arm.CostPer1kInput = c[0]
|
||||
arm.CostPer1kOutput = c[1]
|
||||
}
|
||||
r.RegisterArm(arm)
|
||||
}
|
||||
}
|
||||
|
||||
// Stream is a convenience that selects an arm and streams from it.
|
||||
func (r *Router) Stream(ctx context.Context, task Task, req provider.Request) (stream.Stream, *Arm, error) {
|
||||
decision := r.Select(task)
|
||||
if decision.Error != nil {
|
||||
return nil, nil, decision.Error
|
||||
}
|
||||
|
||||
arm := decision.Arm
|
||||
req.Model = arm.ModelName
|
||||
|
||||
s, err := arm.Provider.Stream(ctx, req)
|
||||
if err != nil {
|
||||
return nil, arm, err
|
||||
}
|
||||
return s, arm, nil
|
||||
}
|
||||
305
internal/router/router_test.go
Normal file
305
internal/router/router_test.go
Normal file
@@ -0,0 +1,305 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"somegit.dev/Owlibou/gnoma/internal/provider"
|
||||
)
|
||||
|
||||
// --- Task Classification ---
|
||||
|
||||
func TestClassifyTask(t *testing.T) {
|
||||
tests := []struct {
|
||||
prompt string
|
||||
want TaskType
|
||||
}{
|
||||
{"fix the bug in auth module", TaskDebug},
|
||||
{"review this pull request", TaskReview},
|
||||
{"refactor the database layer", TaskRefactor},
|
||||
{"write unit tests for the handler", TaskUnitTest},
|
||||
{"explain how the router works", TaskExplain},
|
||||
{"create a new REST endpoint", TaskGeneration},
|
||||
{"plan the migration strategy", TaskPlanning},
|
||||
{"audit security of the API", TaskSecurityReview},
|
||||
{"scaffold a new service", TaskBoilerplate},
|
||||
{"coordinate the deployment pipeline", TaskOrchestration},
|
||||
{"hello", TaskGeneration}, // default
|
||||
}
|
||||
for _, tt := range tests {
|
||||
task := ClassifyTask(tt.prompt)
|
||||
if task.Type != tt.want {
|
||||
t.Errorf("ClassifyTask(%q).Type = %s, want %s", tt.prompt, task.Type, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyTask_RequiresTools(t *testing.T) {
|
||||
// Explain tasks don't require tools
|
||||
task := ClassifyTask("explain how generics work")
|
||||
if task.RequiresTools {
|
||||
t.Error("explain task should not require tools")
|
||||
}
|
||||
|
||||
// Debug tasks require tools
|
||||
task = ClassifyTask("debug the failing test")
|
||||
if !task.RequiresTools {
|
||||
t.Error("debug task should require tools")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTaskValueScore(t *testing.T) {
|
||||
low := Task{Type: TaskBoilerplate, Priority: PriorityLow}
|
||||
high := Task{Type: TaskSecurityReview, Priority: PriorityCritical}
|
||||
|
||||
if low.ValueScore() >= high.ValueScore() {
|
||||
t.Errorf("low priority boilerplate (%f) should score less than critical security (%f)",
|
||||
low.ValueScore(), high.ValueScore())
|
||||
}
|
||||
}
|
||||
|
||||
func TestEstimateComplexity(t *testing.T) {
|
||||
simple := estimateComplexity("rename the variable")
|
||||
complex := estimateComplexity("design and implement a distributed caching system with migration support and integration testing across multiple environments")
|
||||
|
||||
if simple >= complex {
|
||||
t.Errorf("simple (%f) should be less than complex (%f)", simple, complex)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Arm ---
|
||||
|
||||
func TestArmEstimateCost(t *testing.T) {
|
||||
arm := &Arm{
|
||||
CostPer1kInput: 0.003,
|
||||
CostPer1kOutput: 0.015,
|
||||
}
|
||||
cost := arm.EstimateCost(10000)
|
||||
// 6000 input tokens * 0.003/1k + 4000 output tokens * 0.015/1k
|
||||
// = 0.018 + 0.060 = 0.078
|
||||
if cost < 0.07 || cost > 0.09 {
|
||||
t.Errorf("EstimateCost(10000) = %f, want ~0.078", cost)
|
||||
}
|
||||
}
|
||||
|
||||
func TestArmEstimateCost_Free(t *testing.T) {
|
||||
arm := &Arm{} // local model, zero cost
|
||||
cost := arm.EstimateCost(10000)
|
||||
if cost != 0 {
|
||||
t.Errorf("free model should have zero cost, got %f", cost)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Pool ---
|
||||
|
||||
func TestLimitPool_RemainingFraction(t *testing.T) {
|
||||
p := &LimitPool{TotalLimit: 100, Used: 30, Reserved: 20}
|
||||
f := p.RemainingFraction()
|
||||
if f != 0.5 {
|
||||
t.Errorf("RemainingFraction = %f, want 0.5", f)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLimitPool_ScarcityMultiplier(t *testing.T) {
|
||||
// Half remaining, k=2: 1/0.5^2 = 4
|
||||
p := &LimitPool{TotalLimit: 100, Used: 50, ScarcityK: 2}
|
||||
m := p.ScarcityMultiplier()
|
||||
if m < 3.9 || m > 4.1 {
|
||||
t.Errorf("ScarcityMultiplier = %f, want ~4.0", m)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLimitPool_ScarcityMultiplier_Exhausted(t *testing.T) {
|
||||
p := &LimitPool{TotalLimit: 100, Used: 100}
|
||||
m := p.ScarcityMultiplier()
|
||||
if !math.IsInf(m, 1) {
|
||||
t.Errorf("exhausted pool should return +Inf, got %f", m)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLimitPool_ScarcityMultiplier_UseItOrLoseIt(t *testing.T) {
|
||||
p := &LimitPool{
|
||||
TotalLimit: 100, Used: 30, // 70% remaining
|
||||
ScarcityK: 2,
|
||||
ResetAt: time.Now().Add(30 * time.Minute), // reset in 30 min
|
||||
}
|
||||
m := p.ScarcityMultiplier()
|
||||
if m != 0.5 {
|
||||
t.Errorf("use-it-or-lose-it discount: ScarcityMultiplier = %f, want 0.5", m)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLimitPool_ReserveAndCommit(t *testing.T) {
|
||||
p := &LimitPool{
|
||||
TotalLimit: 1000,
|
||||
ArmRates: map[ArmID]float64{"test/model": 5.0}, // 5 units per 1k tokens
|
||||
ScarcityK: 2,
|
||||
}
|
||||
|
||||
res, ok := p.Reserve("test/model", 10000) // 5 * 10 = 50 units
|
||||
if !ok {
|
||||
t.Fatal("reservation should succeed")
|
||||
}
|
||||
if p.Reserved != 50 {
|
||||
t.Errorf("Reserved = %f, want 50", p.Reserved)
|
||||
}
|
||||
|
||||
res.Commit(8000) // actual: 5 * 8 = 40 units
|
||||
if p.Used != 40 {
|
||||
t.Errorf("Used = %f, want 40", p.Used)
|
||||
}
|
||||
if p.Reserved != 0 {
|
||||
t.Errorf("Reserved = %f, want 0 after commit", p.Reserved)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLimitPool_ReserveExhausted(t *testing.T) {
|
||||
p := &LimitPool{
|
||||
TotalLimit: 100,
|
||||
Used: 90,
|
||||
ArmRates: map[ArmID]float64{"test/model": 5.0},
|
||||
ScarcityK: 2,
|
||||
}
|
||||
|
||||
_, ok := p.Reserve("test/model", 10000) // needs 50, only 10 available
|
||||
if ok {
|
||||
t.Error("reservation should fail when exhausted")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLimitPool_Rollback(t *testing.T) {
|
||||
p := &LimitPool{
|
||||
TotalLimit: 1000,
|
||||
ArmRates: map[ArmID]float64{"test/model": 5.0},
|
||||
ScarcityK: 2,
|
||||
}
|
||||
|
||||
res, _ := p.Reserve("test/model", 10000)
|
||||
if p.Reserved != 50 {
|
||||
t.Fatalf("Reserved = %f, want 50", p.Reserved)
|
||||
}
|
||||
|
||||
res.Rollback()
|
||||
if p.Reserved != 0 {
|
||||
t.Errorf("Reserved = %f, want 0 after rollback", p.Reserved)
|
||||
}
|
||||
if p.Used != 0 {
|
||||
t.Errorf("Used = %f, want 0 after rollback", p.Used)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLimitPool_CheckReset(t *testing.T) {
|
||||
p := &LimitPool{
|
||||
TotalLimit: 1000,
|
||||
Used: 500,
|
||||
Reserved: 100,
|
||||
ResetPeriod: time.Hour,
|
||||
ResetAt: time.Now().Add(-time.Minute), // already passed
|
||||
ScarcityK: 2,
|
||||
}
|
||||
|
||||
p.CheckReset()
|
||||
if p.Used != 0 {
|
||||
t.Errorf("Used = %f after reset, want 0", p.Used)
|
||||
}
|
||||
if p.Reserved != 0 {
|
||||
t.Errorf("Reserved = %f after reset, want 0", p.Reserved)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Selector ---
|
||||
|
||||
func TestSelectBest_PrefersToolSupport(t *testing.T) {
|
||||
withTools := &Arm{
|
||||
ID: "a/with-tools", ModelName: "with-tools",
|
||||
Capabilities: provider.Capabilities{ToolUse: true, ContextWindow: 128000},
|
||||
}
|
||||
withoutTools := &Arm{
|
||||
ID: "b/no-tools", ModelName: "no-tools",
|
||||
Capabilities: provider.Capabilities{ToolUse: false, ContextWindow: 128000},
|
||||
}
|
||||
|
||||
task := Task{Type: TaskGeneration, RequiresTools: true, Priority: PriorityNormal}
|
||||
best := selectBest([]*Arm{withoutTools, withTools}, task)
|
||||
|
||||
if best.ID != "a/with-tools" {
|
||||
t.Errorf("should prefer arm with tool support, got %s", best.ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectBest_PrefersThinkingForPlanning(t *testing.T) {
|
||||
thinking := &Arm{
|
||||
ID: "a/thinking", ModelName: "thinking",
|
||||
Capabilities: provider.Capabilities{ToolUse: true, Thinking: true, ContextWindow: 200000},
|
||||
CostPer1kInput: 0.01, CostPer1kOutput: 0.05,
|
||||
}
|
||||
noThinking := &Arm{
|
||||
ID: "b/basic", ModelName: "basic",
|
||||
Capabilities: provider.Capabilities{ToolUse: true, ContextWindow: 128000},
|
||||
CostPer1kInput: 0.01, CostPer1kOutput: 0.05,
|
||||
}
|
||||
|
||||
task := Task{Type: TaskPlanning, RequiresTools: true, Priority: PriorityNormal, EstimatedTokens: 5000}
|
||||
best := selectBest([]*Arm{noThinking, thinking}, task)
|
||||
|
||||
if best.ID != "a/thinking" {
|
||||
t.Errorf("should prefer thinking model for planning, got %s", best.ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilterFeasible_ExcludesExhausted(t *testing.T) {
|
||||
pool := &LimitPool{
|
||||
TotalLimit: 100,
|
||||
Used: 100, // exhausted
|
||||
ArmRates: map[ArmID]float64{"a/model": 1.0},
|
||||
ScarcityK: 2,
|
||||
}
|
||||
arm := &Arm{
|
||||
ID: "a/model", ModelName: "model",
|
||||
Capabilities: provider.Capabilities{ToolUse: true},
|
||||
Pools: []*LimitPool{pool},
|
||||
}
|
||||
|
||||
task := Task{Type: TaskGeneration, RequiresTools: true, EstimatedTokens: 1000}
|
||||
feasible := filterFeasible([]*Arm{arm}, task)
|
||||
|
||||
if len(feasible) != 0 {
|
||||
t.Error("exhausted arm should not be feasible")
|
||||
}
|
||||
}
|
||||
|
||||
// --- Router ---
|
||||
|
||||
func TestRouter_SelectForced(t *testing.T) {
|
||||
r := New(Config{})
|
||||
r.RegisterArm(&Arm{ID: "a/model1", Capabilities: provider.Capabilities{ToolUse: true}})
|
||||
r.RegisterArm(&Arm{ID: "b/model2", Capabilities: provider.Capabilities{ToolUse: true}})
|
||||
|
||||
r.ForceArm("b/model2")
|
||||
|
||||
decision := r.Select(Task{Type: TaskGeneration})
|
||||
if decision.Error != nil {
|
||||
t.Fatalf("Select: %v", decision.Error)
|
||||
}
|
||||
if decision.Arm.ID != "b/model2" {
|
||||
t.Errorf("forced arm should be selected, got %s", decision.Arm.ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_SelectNoArms(t *testing.T) {
|
||||
r := New(Config{})
|
||||
decision := r.Select(Task{Type: TaskGeneration})
|
||||
if decision.Error == nil {
|
||||
t.Error("should error with no arms")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_SelectForcedNotFound(t *testing.T) {
|
||||
r := New(Config{})
|
||||
r.ForceArm("nonexistent/model")
|
||||
decision := r.Select(Task{Type: TaskGeneration})
|
||||
if decision.Error == nil {
|
||||
t.Error("should error when forced arm not found")
|
||||
}
|
||||
}
|
||||
167
internal/router/selector.go
Normal file
167
internal/router/selector.go
Normal file
@@ -0,0 +1,167 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"math"
|
||||
)
|
||||
|
||||
// Strategy describes how a task is to be executed.
type Strategy int

const (
	// StrategySingleArm runs the task on one selected arm.
	StrategySingleArm Strategy = iota
	// Future (M9): StrategyCascade, StrategyParallelEnsemble, StrategyMultiRound
)
|
||||
|
||||
// RoutingDecision is the result of arm selection.
|
||||
type RoutingDecision struct {
|
||||
Strategy Strategy
|
||||
Arm *Arm // primary arm
|
||||
Error error
|
||||
}
|
||||
|
||||
// selectBest picks the highest-scoring feasible arm using heuristic scoring.
|
||||
// No bandit learning — that's M9. Just smart defaults based on model size,
|
||||
// locality, task type, cost, and pool scarcity.
|
||||
func selectBest(arms []*Arm, task Task) *Arm {
|
||||
if len(arms) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var best *Arm
|
||||
bestScore := math.Inf(-1)
|
||||
|
||||
for _, arm := range arms {
|
||||
score := scoreArm(arm, task)
|
||||
if score > bestScore {
|
||||
bestScore = score
|
||||
best = arm
|
||||
}
|
||||
}
|
||||
|
||||
return best
|
||||
}
|
||||
|
||||
// scoreArm computes a heuristic quality/cost score for an arm.
|
||||
// Score = (quality × value) / effective_cost
|
||||
func scoreArm(arm *Arm, task Task) float64 {
|
||||
quality := heuristicQuality(arm, task)
|
||||
value := task.ValueScore()
|
||||
cost := effectiveCost(arm, task)
|
||||
|
||||
if cost <= 0 {
|
||||
cost = 0.001 // prevent division by zero for free local models
|
||||
}
|
||||
|
||||
return (quality * value) / cost
|
||||
}
|
||||
|
||||
// heuristicQuality estimates arm quality without historical data.
|
||||
func heuristicQuality(arm *Arm, task Task) float64 {
|
||||
score := 0.5 // base
|
||||
|
||||
// Larger context window = better for complex tasks
|
||||
if arm.Capabilities.ContextWindow >= 100000 {
|
||||
score += 0.1
|
||||
}
|
||||
if arm.Capabilities.ContextWindow >= 200000 {
|
||||
score += 0.05
|
||||
}
|
||||
|
||||
// Thinking capability valuable for planning/orchestration/security
|
||||
if arm.Capabilities.Thinking {
|
||||
switch task.Type {
|
||||
case TaskPlanning, TaskOrchestration, TaskSecurityReview:
|
||||
score += 0.2
|
||||
case TaskDebug, TaskRefactor:
|
||||
score += 0.1
|
||||
}
|
||||
}
|
||||
|
||||
// Tool support required — arm without tools gets heavy penalty
|
||||
if task.RequiresTools && !arm.SupportsTools() {
|
||||
score *= 0.1
|
||||
}
|
||||
|
||||
// Local models get a small boost (no network latency, privacy)
|
||||
if arm.IsLocal {
|
||||
score += 0.05
|
||||
}
|
||||
|
||||
// Complexity adjustment — complex tasks penalize small/local models
|
||||
if task.ComplexityScore > 0.7 && arm.IsLocal {
|
||||
score *= 0.7
|
||||
}
|
||||
|
||||
// Clamp
|
||||
if score > 1.0 {
|
||||
score = 1.0
|
||||
}
|
||||
if score < 0.0 {
|
||||
score = 0.0
|
||||
}
|
||||
return score
|
||||
}
|
||||
|
||||
// effectiveCost returns the base cost inflated by pool scarcity.
|
||||
func effectiveCost(arm *Arm, task Task) float64 {
|
||||
base := arm.EstimateCost(task.EstimatedTokens)
|
||||
if base <= 0 {
|
||||
base = 0.001 // local models are ~free but not zero for scoring
|
||||
}
|
||||
|
||||
// Apply maximum scarcity multiplier across all pools
|
||||
maxMultiplier := 1.0
|
||||
for _, pool := range arm.Pools {
|
||||
m := pool.ScarcityMultiplier()
|
||||
if m > maxMultiplier {
|
||||
maxMultiplier = m
|
||||
}
|
||||
}
|
||||
|
||||
return base * maxMultiplier
|
||||
}
|
||||
|
||||
// filterFeasible returns arms that can handle the task (tools, pool capacity).
|
||||
func filterFeasible(arms []*Arm, task Task) []*Arm {
|
||||
var feasible []*Arm
|
||||
for _, arm := range arms {
|
||||
// Must support tools if task requires them
|
||||
if task.RequiresTools && !arm.SupportsTools() {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check all pools have capacity
|
||||
poolsOK := true
|
||||
for _, pool := range arm.Pools {
|
||||
pool.CheckReset()
|
||||
if !pool.CanAfford(arm.ID, task.EstimatedTokens) {
|
||||
poolsOK = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if !poolsOK {
|
||||
continue
|
||||
}
|
||||
|
||||
feasible = append(feasible, arm)
|
||||
}
|
||||
|
||||
// If no arm with tools is feasible but task requires them,
|
||||
// fall back to any available arm (tool-less is better than nothing)
|
||||
if len(feasible) == 0 && task.RequiresTools {
|
||||
for _, arm := range arms {
|
||||
poolsOK := true
|
||||
for _, pool := range arm.Pools {
|
||||
if !pool.CanAfford(arm.ID, task.EstimatedTokens) {
|
||||
poolsOK = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if poolsOK {
|
||||
feasible = append(feasible, arm)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return feasible
|
||||
}
|
||||
199
internal/router/task.go
Normal file
199
internal/router/task.go
Normal file
@@ -0,0 +1,199 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// TaskType classifies a task for routing purposes.
type TaskType int

// Known task types. Values are assigned by iota in declaration order,
// starting at zero with TaskBoilerplate.
const (
	TaskBoilerplate    TaskType = iota // simple scaffolding, templates
	TaskGeneration                     // new code creation
	TaskRefactor                       // restructuring existing code
	TaskReview                         // code review, analysis
	TaskUnitTest                       // writing tests
	TaskPlanning                       // architecture, design
	TaskOrchestration                  // multi-step coordination
	TaskSecurityReview                 // security-focused analysis
	TaskDebug                          // finding and fixing bugs
	TaskExplain                        // explaining code or concepts
)
|
||||
|
||||
func (t TaskType) String() string {
|
||||
switch t {
|
||||
case TaskBoilerplate:
|
||||
return "boilerplate"
|
||||
case TaskGeneration:
|
||||
return "generation"
|
||||
case TaskRefactor:
|
||||
return "refactor"
|
||||
case TaskReview:
|
||||
return "review"
|
||||
case TaskUnitTest:
|
||||
return "unit_test"
|
||||
case TaskPlanning:
|
||||
return "planning"
|
||||
case TaskOrchestration:
|
||||
return "orchestration"
|
||||
case TaskSecurityReview:
|
||||
return "security_review"
|
||||
case TaskDebug:
|
||||
return "debug"
|
||||
case TaskExplain:
|
||||
return "explain"
|
||||
default:
|
||||
return fmt.Sprintf("unknown(%d)", t)
|
||||
}
|
||||
}
|
||||
|
||||
// Priority indicates task importance for routing decisions.
type Priority int

// Priority levels in ascending order of importance. Values are assigned by
// iota, starting at zero with PriorityLow.
const (
	PriorityLow Priority = iota
	PriorityNormal
	PriorityHigh
	PriorityCritical
)
|
||||
|
||||
// Task represents a classified unit of work for routing.
type Task struct {
	// Type is the kind of work the task was classified as.
	Type TaskType
	// Priority is the task's importance; it sets the base weight in ValueScore.
	Priority Priority
	// EstimatedTokens is the token estimate passed to arm cost estimation and
	// pool capacity checks. ClassifyTask leaves it at zero.
	EstimatedTokens int
	// RequiresTools reports whether the chosen arm should support tool use.
	RequiresTools bool
	// ComplexityScore is a heuristic difficulty estimate.
	ComplexityScore float64 // 0-1
}
|
||||
|
||||
// ValueScore computes a routing value based on priority and type.
|
||||
func (t Task) ValueScore() float64 {
|
||||
base := map[Priority]float64{
|
||||
PriorityLow: 0.5,
|
||||
PriorityNormal: 1.0,
|
||||
PriorityHigh: 2.0,
|
||||
PriorityCritical: 5.0,
|
||||
}[t.Priority]
|
||||
|
||||
return base * taskTypeMultiplier[t.Type]
|
||||
}
|
||||
|
||||
// taskTypeMultiplier scales a task's priority weight by its type in
// Task.ValueScore. A type with no entry here multiplies the score by zero.
var taskTypeMultiplier = map[TaskType]float64{
	TaskBoilerplate:    0.6,
	TaskGeneration:     1.0,
	TaskRefactor:       0.9,
	TaskReview:         1.1,
	TaskUnitTest:       0.8,
	TaskPlanning:       1.4,
	TaskOrchestration:  1.5,
	TaskSecurityReview: 2.0,
	TaskDebug:          1.2,
	TaskExplain:        0.7,
}
|
||||
|
||||
// QualityThreshold defines minimum acceptable quality for a task type.
// The three levels are listed from least to most demanding.
type QualityThreshold struct {
	Minimum    float64 // below this, output is harmful: never accept
	Acceptable float64 // good enough
	Target     float64 // ideal
}
|
||||
|
||||
var DefaultThresholds = map[TaskType]QualityThreshold{
|
||||
TaskBoilerplate: {0.50, 0.70, 0.80},
|
||||
TaskGeneration: {0.60, 0.75, 0.88},
|
||||
TaskRefactor: {0.65, 0.78, 0.90},
|
||||
TaskReview: {0.70, 0.82, 0.92},
|
||||
TaskUnitTest: {0.60, 0.75, 0.85},
|
||||
TaskPlanning: {0.75, 0.88, 0.95},
|
||||
TaskOrchestration: {0.80, 0.90, 0.96},
|
||||
TaskSecurityReview: {0.88, 0.94, 0.99},
|
||||
TaskDebug: {0.65, 0.80, 0.90},
|
||||
TaskExplain: {0.55, 0.72, 0.85},
|
||||
}
|
||||
|
||||
// ClassifyTask infers a TaskType from the user's prompt using keyword heuristics.
|
||||
func ClassifyTask(prompt string) Task {
|
||||
lower := strings.ToLower(prompt)
|
||||
|
||||
task := Task{
|
||||
Priority: PriorityNormal,
|
||||
RequiresTools: true, // assume tools needed by default
|
||||
}
|
||||
|
||||
// Check for task type keywords (order matters — more specific first)
|
||||
switch {
|
||||
case containsAny(lower, "security", "vulnerability", "cve", "owasp", "xss", "injection", "audit security"):
|
||||
task.Type = TaskSecurityReview
|
||||
task.Priority = PriorityHigh
|
||||
case containsAny(lower, "plan", "architect", "design", "strategy", "roadmap"):
|
||||
task.Type = TaskPlanning
|
||||
case containsAny(lower, "orchestrat", "coordinate", "dispatch", "pipeline"):
|
||||
task.Type = TaskOrchestration
|
||||
task.Priority = PriorityHigh
|
||||
case containsAny(lower, "debug", "fix", "troubleshoot", "not working", "error", "crash", "failing", "bug"):
|
||||
task.Type = TaskDebug
|
||||
case containsAny(lower, "review", "check", "analyze", "audit", "inspect"):
|
||||
task.Type = TaskReview
|
||||
case containsAny(lower, "refactor", "restructure", "reorganize", "clean up", "simplify"):
|
||||
task.Type = TaskRefactor
|
||||
case containsAny(lower, "test", "spec", "coverage", "assert"):
|
||||
task.Type = TaskUnitTest
|
||||
case containsAny(lower, "explain", "what is", "how does", "describe", "tell me about"):
|
||||
task.Type = TaskExplain
|
||||
task.RequiresTools = false
|
||||
case containsAny(lower, "create", "implement", "build", "add", "write", "generate", "make"):
|
||||
task.Type = TaskGeneration
|
||||
case containsAny(lower, "scaffold", "boilerplate", "template", "stub", "skeleton"):
|
||||
task.Type = TaskBoilerplate
|
||||
default:
|
||||
task.Type = TaskGeneration // default
|
||||
}
|
||||
|
||||
// Estimate complexity from prompt length and keywords
|
||||
task.ComplexityScore = estimateComplexity(lower)
|
||||
|
||||
return task
|
||||
}
|
||||
|
||||
// containsAny reports whether s contains at least one of keywords as a
// substring. Matching is case-sensitive. With no keywords it returns false.
func containsAny(s string, keywords ...string) bool {
	for i := range keywords {
		if !strings.Contains(s, keywords[i]) {
			continue
		}
		return true
	}
	return false
}
|
||||
|
||||
// estimateComplexity scores how demanding prompt looks, in the range [0, 1].
//
// The score starts from the word count (200 whitespace-separated words map
// to 1.0), gains 0.15 for each complexity keyword found and loses 0.15 for
// each simplicity keyword found, and is then clamped to [0, 1]. Keywords are
// matched case-sensitively as substrings, and each keyword counts at most
// once.
func estimateComplexity(prompt string) float64 {
	// Length contribution.
	score := float64(len(strings.Fields(prompt))) / 200.0

	// Keywords hinting at larger, more involved work.
	for _, kw := range []string{"implement", "design", "architect", "system", "integration", "migrate", "optimize"} {
		if strings.Contains(prompt, kw) {
			score += 0.15
		}
	}

	// Keywords hinting at small, mechanical changes.
	for _, kw := range []string{"rename", "format", "add field", "change name", "typo", "simple"} {
		if strings.Contains(prompt, kw) {
			score -= 0.15
		}
	}

	// Clamp to [0, 1].
	switch {
	case score < 0:
		return 0
	case score > 1:
		return 1
	}
	return score
}
|
||||
Reference in New Issue
Block a user