Files
vikingowl eb0583f606 fix(router): unpin config-default provider + complexity floor by task type
Two routing bugs were keeping the SLM out of every real prompt and,
once it was eligible, pulling complex tasks into it as well.

Bug 1: ForceArm was called unconditionally when a primary provider was
configured (cmd/gnoma/main.go:378). That short-circuited the entire
router — every prompt went straight to whatever was set as
[provider].default, regardless of tier, score, or feasibility. The SLM
arm appeared in `gnoma router stats` registration logs but had zero
observations after dozens of prompts.

Fix: only pin when the user passed --provider on the command line.
Config defaults register the arm but don't force it; the router picks
freely. Verified end-to-end — trivial prompts now reach slm/ollama
via the tier-0 priority.

Bug 2: A short prompt like "refactor the SLM module" classifies as
TaskRefactor with complexity 0.015 — well under the SLM arm's 0.3
ceiling. The arm became eligible despite the task being inherently
non-trivial. Once eligible, tier-0 priority then pulled it in over
the CLI agents.

Fix: add MinComplexityForType, applied in both ClassifyTask
(heuristic path) and slm.Classifier.Classify (SLM-overlay path). The
floor is per-task-type:

  - TaskSecurityReview, TaskOrchestration  → 0.60
  - TaskRefactor, TaskPlanning, TaskDebug  → 0.40
  - TaskUnitTest, TaskReview               → 0.35

Tasks like Explain/Generation/Boilerplate keep their organic
complexity score so trivial knowledge prompts (≤0.15) still fall to
the SLM. Tasks that imply existing code or multi-step reasoning are
clamped above the SLM's MaxComplexity, naturally routing them to a
bigger arm.

After both fixes, observed routing in a clean run:

  What is 2+2?              → slm/ollama (complexity 0.015)
  Define a closure          → slm/ollama (complexity 0.015)
  What is HTTP?             → slm/ollama (complexity 0.015)
  Refactor the SLM module   → subprocess/gemini (complexity 0.40)
  Audit for race conditions → subprocess/gemini (complexity 0.35)
  Plan a migration          → subprocess/gemini (complexity 0.40)
2026-05-19 19:22:16 +02:00

163 lines
4.8 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package slm
import (
"context"
"encoding/json"
"fmt"
"log/slog"
"strings"
"time"
"somegit.dev/Owlibou/gnoma/internal/message"
"somegit.dev/Owlibou/gnoma/internal/provider"
"somegit.dev/Owlibou/gnoma/internal/router"
"somegit.dev/Owlibou/gnoma/internal/stream"
)
// defaultClassifyTimeout is the deadline applied to one SLM classification
// call. Five seconds leaves headroom for thinking-mode models such as Qwen3
// distillations (Tiny3.5), which emit reasoning tokens before the answer;
// non-thinking models finish in well under one second.
const defaultClassifyTimeout time.Duration = 5 * time.Second
// classifySystemPrompt is the system prompt sent with every classification
// request. It asks for a single JSON object whose keys match classifyResponse
// and lists the task-type names the model may answer with. The complexity
// guide uses plain ASCII hyphens for its ranges, matching the "<0.0-1.0>"
// placeholder on the Format line.
const classifySystemPrompt = `Classify the following coding request. /no_think
Respond with JSON only, no other text, no reasoning, no thinking tags.
Format: {"task_type": "<type>", "complexity": <0.0-1.0>, "requires_tools": <true|false>}
Task types: Debug, Explain, Generation, Refactor, UnitTest, Boilerplate, Planning, Orchestration, SecurityReview, Review
Complexity guide:
0.0-0.3: boilerplate, trivial edits, simple lookups, short explanations
0.4-0.6: new functions, refactors, unit tests, moderate analysis
0.7-1.0: architectural changes, multi-file edits, security review, planning`
// classifyResponse is the JSON object the SLM is asked to reply with; the
// field tags match the keys named in classifySystemPrompt.
type classifyResponse struct {
	// TaskType is the task-type name as written by the model.
	TaskType string `json:"task_type"`
	// Complexity is the model's complexity estimate (the prompt asks for 0.0-1.0).
	Complexity float64 `json:"complexity"`
	// RequiresTools reports whether the model thinks the request needs tools.
	RequiresTools bool `json:"requires_tools"`
}
// Classifier implements router.TaskClassifier by asking a small language
// model (for example one hosted by llamafile) to classify the prompt. When the
// model call times out or its reply cannot be parsed, classification falls
// back to router.HeuristicClassifier.
type Classifier struct {
	// provider streams the classification request to the model.
	provider provider.Provider
	// model is the model name sent with every request.
	model string
	// timeout bounds a single SLM classification call.
	timeout time.Duration
	// logger receives the debug record written when falling back.
	logger *slog.Logger
}
// NewClassifier returns a Classifier that talks to p using the default
// classification timeout.
//
// model is forwarded to the provider as the model name (llamafile ignores it,
// but openaicompat requires a non-empty value). A nil logger is replaced with
// slog.Default().
func NewClassifier(p provider.Provider, model string, logger *slog.Logger) *Classifier {
	c := &Classifier{
		provider: p,
		model:    model,
		timeout:  defaultClassifyTimeout,
		logger:   logger,
	}
	if c.logger == nil {
		c.logger = slog.Default()
	}
	return c
}
// Classify asks the SLM to classify prompt and overlays the three
// SLM-authoritative fields (Type, ComplexityScore, RequiresTools) onto the
// heuristic baseline produced by router.ClassifyTask, so the remaining Task
// fields (Priority, EstimatedTokens, RequiredEffort) are always populated.
//
// The SLM call is bounded by c.timeout. If it fails for any reason (stream
// error, timeout, unparseable reply) the failure is logged at debug level and
// the result of router.HeuristicClassifier is returned instead, tagged with
// router.ClassifierSLMFallback. history is only passed along on that path.
//
// The complexity reported by the model is clamped to [0, 1], the range the
// system prompt asks for, and then raised to the per-task-type floor from
// router.MinComplexityForType.
func (c *Classifier) Classify(ctx context.Context, prompt string, history []message.Message) (router.Task, error) {
	slmCtx, cancel := context.WithTimeout(ctx, c.timeout)
	defer cancel()

	resp, err := c.callSLM(slmCtx, prompt)
	if err != nil {
		c.logger.Debug("slm classify fallback", "error", err)
		// Use the caller's ctx here, not slmCtx: the SLM deadline may be
		// exactly what just expired.
		t, ferr := router.HeuristicClassifier{}.Classify(ctx, prompt, history)
		t.ClassifierSource = router.ClassifierSLMFallback
		return t, ferr
	}

	// Model output is untrusted: keep the score inside the documented range
	// so an out-of-range reply (e.g. 7 or -2) cannot leak into routing.
	score := resp.Complexity
	switch {
	case score < 0:
		score = 0
	case score > 1:
		score = 1
	}

	// Start from the heuristic baseline so Priority/EstimatedTokens/RequiredEffort are set.
	task := router.ClassifyTask(prompt)
	task.Type = router.ParseTaskType(resp.TaskType)
	task.ComplexityScore = score
	task.RequiresTools = resp.RequiresTools
	task.ClassifierSource = router.ClassifierSLM

	// Re-apply the per-task-type complexity floor after the overlay. The SLM
	// may under-report complexity for a Refactor-style task; the floor keeps
	// that misclassification from making the SLM arm eligible for it.
	if floor := router.MinComplexityForType(task.Type); task.ComplexityScore < floor {
		task.ComplexityScore = floor
	}
	return task, nil
}
// callSLM sends prompt to the provider as a single user message under
// classifySystemPrompt, concatenates the text deltas of the streamed reply,
// and decodes the JSON object found in it.
//
// It returns an error when the stream cannot be opened, when the stream
// reports an error after iteration, or when the extracted text is not valid
// JSON for classifyResponse.
func (c *Classifier) callSLM(ctx context.Context, prompt string) (*classifyResponse, error) {
	userMsg := message.Message{
		Role:    message.RoleUser,
		Content: []message.Content{{Type: message.ContentText, Text: prompt}},
	}

	events, err := c.provider.Stream(ctx, provider.Request{
		Model:        c.model,
		SystemPrompt: classifySystemPrompt,
		Messages:     []message.Message{userMsg},
	})
	if err != nil {
		return nil, fmt.Errorf("stream: %w", err)
	}
	// The reply is already buffered by the time Close runs, so its error is
	// deliberately ignored.
	defer func() { _ = events.Close() }()

	// Only text deltas contribute to the reply; other event types are skipped.
	var raw strings.Builder
	for events.Next() {
		if ev := events.Current(); ev.Type == stream.EventTextDelta {
			raw.WriteString(ev.Text)
		}
	}
	if err := events.Err(); err != nil {
		return nil, fmt.Errorf("stream error: %w", err)
	}

	payload := extractJSON(raw.String())
	parsed := new(classifyResponse)
	if err := json.Unmarshal([]byte(payload), parsed); err != nil {
		return nil, fmt.Errorf("parse %q: %w", payload, err)
	}
	return parsed, nil
}
// extractJSON returns the first balanced {...} object found in s.
//
// Surrounding whitespace is trimmed first, and if s starts with a markdown
// code fence the text between the opening and the last fence is used, with an
// optional "json" language tag removed. Braces are then counted from the
// first '{'; braces that appear inside JSON string literals (including after
// escaped quotes) are ignored so a value such as "a}b" does not end the
// object early.
//
// If s contains no '{', the trimmed (and un-fenced) text is returned as is.
// If the object is never closed, everything from the first '{' to the end is
// returned, leaving the caller's JSON decoder to report the error.
func extractJSON(s string) string {
	s = strings.TrimSpace(s)

	// Strip ```json ... ``` fences.
	if strings.HasPrefix(s, "```") {
		if end := strings.LastIndex(s, "```"); end > 3 {
			inner := strings.TrimPrefix(s[3:end], "json")
			s = strings.TrimSpace(inner)
		}
	}

	start := strings.IndexByte(s, '{')
	if start < 0 {
		return s
	}

	depth := 0
	inString := false // inside a "..." literal
	escaped := false  // previous byte inside the literal was a backslash
	for i := start; i < len(s); i++ {
		ch := s[i]
		if inString {
			switch {
			case escaped:
				escaped = false
			case ch == '\\':
				escaped = true
			case ch == '"':
				inString = false
			}
			continue
		}
		switch ch {
		case '"':
			inString = true
		case '{':
			depth++
		case '}':
			depth--
			if depth == 0 {
				return s[start : i+1]
			}
		}
	}
	return s[start:]
}