eb0583f606
Two routing bugs were keeping the SLM out of every real prompt and, once it was eligible, pulling complex tasks into it as well. Bug 1: ForceArm was called unconditionally when a primary provider was configured (cmd/gnoma/main.go:378). That short-circuited the entire router — every prompt went straight to whatever was set as [provider].default, regardless of tier, score, or feasibility. The SLM arm appeared in `gnoma router stats` registration logs but had zero observations after dozens of prompts. Fix: only pin when the user passed --provider on the command line. Config defaults register the arm but don't force it; the router picks freely. Verified end-to-end — trivial prompts now reach slm/ollama via the tier-0 priority. Bug 2: A short prompt like "refactor the SLM module" classifies as TaskRefactor with complexity 0.015 — well under the SLM arm's 0.3 ceiling. The arm became eligible despite the task being inherently non-trivial. Once eligible, tier-0 priority then pulled it in over the CLI agents. Fix: add MinComplexityForType, applied in both ClassifyTask (heuristic path) and slm.Classifier.Classify (SLM-overlay path). The floor is per-task-type: - TaskSecurityReview, TaskOrchestration → 0.60 - TaskRefactor, TaskPlanning, TaskDebug → 0.40 - TaskUnitTest, TaskReview → 0.35 Tasks like Explain/Generation/Boilerplate keep their organic complexity score so trivial knowledge prompts (≤0.15) still fall to the SLM. Tasks that imply existing code or multi-step reasoning are clamped above the SLM's MaxComplexity, naturally routing them to a bigger arm. After both fixes, observed routing in a clean run: What is 2+2? → slm/ollama (complexity 0.015) Define a closure → slm/ollama (complexity 0.015) What is HTTP? → slm/ollama (complexity 0.015) Refactor the SLM module → subprocess/gemini (complexity 0.40) Audit for race conditions → subprocess/gemini (complexity 0.35) Plan a migration → subprocess/gemini (complexity 0.40)
163 lines
4.8 KiB
Go
163 lines
4.8 KiB
Go
package slm
|
||
|
||
import (
|
||
"context"
|
||
"encoding/json"
|
||
"fmt"
|
||
"log/slog"
|
||
"strings"
|
||
"time"
|
||
|
||
"somegit.dev/Owlibou/gnoma/internal/message"
|
||
"somegit.dev/Owlibou/gnoma/internal/provider"
|
||
"somegit.dev/Owlibou/gnoma/internal/router"
|
||
"somegit.dev/Owlibou/gnoma/internal/stream"
|
||
)
|
||
|
||
// defaultClassifyTimeout is the deadline given to a Classifier built by
// NewClassifier. Five seconds leaves headroom for thinking-mode models such
// as Qwen3 distillations (Tiny3.5), which emit reasoning tokens before their
// answer; non-thinking models finish in well under one second.
const defaultClassifyTimeout time.Duration = 5 * time.Second
|
||
|
||
// classifySystemPrompt is the system prompt sent with every classification
// request. It asks the model for a single JSON object whose keys
// (task_type, complexity, requires_tools) match the json tags on
// classifyResponse, lists the accepted task-type names, and gives a rubric
// for the complexity score.
//
// NOTE(review): "/no_think" is presumably the soft switch that disables
// reasoning output on Qwen3-style models — confirm against the model docs.
//
// The literal is a raw string, so every byte between the backticks is sent
// to the model verbatim; keep stray separator or formatting lines out of it.
const classifySystemPrompt = `Classify the following coding request. /no_think
Respond with JSON only, no other text, no reasoning, no thinking tags.
Format: {"task_type": "<type>", "complexity": <0.0-1.0>, "requires_tools": <true|false>}

Task types: Debug, Explain, Generation, Refactor, UnitTest, Boilerplate, Planning, Orchestration, SecurityReview, Review

Complexity guide:
0.0–0.3: boilerplate, trivial edits, simple lookups, short explanations
0.4–0.6: new functions, refactors, unit tests, moderate analysis
0.7–1.0: architectural changes, multi-file edits, security review, planning`
|
||
|
||
// classifyResponse is the JSON object the SLM is asked to return for a
// classification request. Keys absent from the model's answer leave the
// corresponding field at its zero value.
type classifyResponse struct {
	// TaskType is the task-type name chosen by the model ("task_type").
	TaskType string `json:"task_type"`
	// Complexity is the model's complexity estimate ("complexity"); the
	// prompt asks for a value between 0.0 and 1.0.
	Complexity float64 `json:"complexity"`
	// RequiresTools reports whether the model thinks the request needs
	// tool use ("requires_tools").
	RequiresTools bool `json:"requires_tools"`
}
|
||
|
||
// Classifier implements router.TaskClassifier using a llamafile-hosted SLM.
|
||
// On timeout or parse failure it falls back to router.HeuristicClassifier.
|
||
type Classifier struct {
|
||
provider provider.Provider
|
||
model string
|
||
timeout time.Duration
|
||
logger *slog.Logger
|
||
}
|
||
|
||
// NewClassifier creates a Classifier. model is the model name passed to the provider
|
||
// (llamafile ignores it but openaicompat requires a non-empty value).
|
||
func NewClassifier(p provider.Provider, model string, logger *slog.Logger) *Classifier {
|
||
if logger == nil {
|
||
logger = slog.Default()
|
||
}
|
||
return &Classifier{
|
||
provider: p,
|
||
model: model,
|
||
timeout: defaultClassifyTimeout,
|
||
logger: logger,
|
||
}
|
||
}
|
||
|
||
// Classify calls the SLM and overlays the three SLM-authoritative fields
|
||
// (Type, ComplexityScore, RequiresTools) onto a heuristic baseline Task.
|
||
// This ensures Priority, EstimatedTokens, and RequiredEffort are always set.
|
||
func (c *Classifier) Classify(ctx context.Context, prompt string, history []message.Message) (router.Task, error) {
|
||
tctx, cancel := context.WithTimeout(ctx, c.timeout)
|
||
defer cancel()
|
||
|
||
resp, err := c.callSLM(tctx, prompt)
|
||
if err != nil {
|
||
c.logger.Debug("slm classify fallback", "error", err)
|
||
t, ferr := router.HeuristicClassifier{}.Classify(ctx, prompt, history)
|
||
t.ClassifierSource = router.ClassifierSLMFallback
|
||
return t, ferr
|
||
}
|
||
|
||
// Start from the heuristic baseline so Priority/EstimatedTokens/RequiredEffort are set.
|
||
task := router.ClassifyTask(prompt)
|
||
task.Type = router.ParseTaskType(resp.TaskType)
|
||
task.ComplexityScore = resp.Complexity
|
||
task.RequiresTools = resp.RequiresTools
|
||
task.ClassifierSource = router.ClassifierSLM
|
||
// Re-apply the per-task-type complexity floor after the SLM overlay.
|
||
// The SLM may have under-reported complexity for a Refactor-style
|
||
// task; the floor protects the SLM arm from being picked for its own
|
||
// kind of misclassification.
|
||
if floor := router.MinComplexityForType(task.Type); task.ComplexityScore < floor {
|
||
task.ComplexityScore = floor
|
||
}
|
||
return task, nil
|
||
}
|
||
|
||
func (c *Classifier) callSLM(ctx context.Context, prompt string) (*classifyResponse, error) {
|
||
req := provider.Request{
|
||
Model: c.model,
|
||
SystemPrompt: classifySystemPrompt,
|
||
Messages: []message.Message{
|
||
{
|
||
Role: message.RoleUser,
|
||
Content: []message.Content{{Type: message.ContentText, Text: prompt}},
|
||
},
|
||
},
|
||
}
|
||
|
||
strm, err := c.provider.Stream(ctx, req)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("stream: %w", err)
|
||
}
|
||
defer func() { _ = strm.Close() }()
|
||
|
||
var sb strings.Builder
|
||
for strm.Next() {
|
||
ev := strm.Current()
|
||
if ev.Type == stream.EventTextDelta {
|
||
sb.WriteString(ev.Text)
|
||
}
|
||
}
|
||
if err := strm.Err(); err != nil {
|
||
return nil, fmt.Errorf("stream error: %w", err)
|
||
}
|
||
|
||
text := extractJSON(sb.String())
|
||
var resp classifyResponse
|
||
if err := json.Unmarshal([]byte(text), &resp); err != nil {
|
||
return nil, fmt.Errorf("parse %q: %w", text, err)
|
||
}
|
||
return &resp, nil
|
||
}
|
||
|
||
// extractJSON returns the first balanced {...} object found in s.
//
// Before scanning, it trims surrounding whitespace, discards everything up
// to and including the last "</think>" tag (reasoning emitted by
// thinking-mode models may itself contain braces), and unwraps a leading
// ```json ... ``` markdown fence.
//
// Braces inside JSON string literals are ignored, including strings that
// contain escaped quotes, so a value such as "a}b" does not end the object
// early. If s contains no '{', the cleaned-up s is returned unchanged; if
// the object is never closed, everything from the first '{' onward is
// returned. In both cases the caller's JSON decoder reports the problem.
func extractJSON(s string) string {
	s = strings.TrimSpace(s)

	// Drop a reasoning preamble: only text after the final closing tag can
	// be the answer.
	const thinkClose = "</think>"
	if i := strings.LastIndex(s, thinkClose); i >= 0 {
		s = strings.TrimSpace(s[i+len(thinkClose):])
	}

	// Strip ```json ... ``` fences.
	if strings.HasPrefix(s, "```") {
		if end := strings.LastIndex(s, "```"); end > 3 {
			inner := strings.TrimPrefix(s[3:end], "json")
			s = strings.TrimSpace(inner)
		}
	}

	start := strings.IndexByte(s, '{')
	if start < 0 {
		return s
	}

	// Scan for the matching close brace, tracking whether we are inside a
	// string literal so that braces in string values are not counted.
	depth := 0
	inString := false
	escaped := false
	for i := start; i < len(s); i++ {
		ch := s[i]
		if inString {
			switch {
			case escaped:
				// Character following a backslash is consumed as-is.
				escaped = false
			case ch == '\\':
				escaped = true
			case ch == '"':
				inString = false
			}
			continue
		}
		switch ch {
		case '"':
			inString = true
		case '{':
			depth++
		case '}':
			depth--
			if depth == 0 {
				return s[start : i+1]
			}
		}
	}
	return s[start:]
}
|