c0c2e4bff5
Two structural fixes for the SLM classifier's 100% failure rate: (1) Pass ResponseFormat=json_object + Temperature=0 + TopP=1 + MaxTokens=128 in the classifier Request. The provider type already supports these but callSLM was leaving them unset, which meant ollama (and any other backend) ran with default sampling and free-form text output. format=json mode in particular makes ollama emit only valid JSON at decoding time — eliminates the majority of parse failures. (2) Harden extractJSON to strip common thinking-block tags before hunting for the brace. Seen in the wild: <think>…</think> (Qwen3 distillations) and <Thought Process>…</Thought Process> (tiny3.5). Defensive list also covers <reasoning>, <thoughts>. Unterminated thinking blocks fall back to brace-search so we still have a shot. Table-driven tests cover all variants plus the no-tag and fenced-json paths to confirm no regression. Even with format=json on a capable provider, the extractor is the safety net for backends that don't enforce format strictly — same defence-in-depth shape as the existing fence stripping. Doesn't fix the deeper architecture question (encoder + bandit preferred over decoder-SLM as classifier — see plan doc landing in the same PR); fixes the immediate bug.
227 lines
7.6 KiB
Go
227 lines
7.6 KiB
Go
package slm
|
||
|
||
import (
|
||
"context"
|
||
"encoding/json"
|
||
"fmt"
|
||
"log/slog"
|
||
"strings"
|
||
"time"
|
||
|
||
"somegit.dev/Owlibou/gnoma/internal/message"
|
||
"somegit.dev/Owlibou/gnoma/internal/provider"
|
||
"somegit.dev/Owlibou/gnoma/internal/router"
|
||
"somegit.dev/Owlibou/gnoma/internal/stream"
|
||
)
|
||
|
||
// defaultClassifyTimeout is the classification deadline applied when the
// caller does not supply one. The 15 s budget has to cover two slow paths
// at once: a cold-start model load (ollama loads lazily on the first call,
// roughly 2-8 s for a 1.5B model on SSD) and thinking-mode first-token
// latency (Qwen3 distillations such as Tiny3.5 sometimes emit <think>
// tokens ahead of the JSON output even with /no_think). A warm,
// non-thinking model finishes in well under 1 s. Override it via
// [slm].classify_timeout in config.
const defaultClassifyTimeout time.Duration = 15 * time.Second
|
||
|
||
// classifySystemPrompt is the system prompt sent with every classification
// request. It asks the model for a single JSON object whose keys
// (task_type, complexity, requires_tools) match the json tags on
// classifyResponse, lists the accepted task type names, and gives a
// complexity rubric on a 0.0-1.0 scale. The trailing /no_think on the first
// line and the "no thinking tags" instruction ask the model not to emit
// reasoning output ahead of the JSON.
const classifySystemPrompt = `Classify the following coding request. /no_think
Respond with JSON only, no other text, no reasoning, no thinking tags.
Format: {"task_type": "<type>", "complexity": <0.0-1.0>, "requires_tools": <true|false>}

Task types: Debug, Explain, Generation, Refactor, UnitTest, Boilerplate, Planning, Orchestration, SecurityReview, Review

Complexity guide:
0.0–0.3: boilerplate, trivial edits, simple lookups, short explanations
0.4–0.6: new functions, refactors, unit tests, moderate analysis
0.7–1.0: architectural changes, multi-file edits, security review, planning`
|
||
|
||
// classifyResponse is the decoded form of the JSON object the
// classification prompt asks the model to emit.
type classifyResponse struct {
	// TaskType is the task category name exactly as the model reported it.
	TaskType string `json:"task_type"`
	// Complexity is the model's complexity estimate; the prompt requests a
	// value between 0.0 and 1.0, but decoding does not enforce that range.
	Complexity float64 `json:"complexity"`
	// RequiresTools is the model's answer to whether the request needs tools.
	RequiresTools bool `json:"requires_tools"`
}
|
||
|
||
// Classifier implements router.TaskClassifier using a llamafile-hosted SLM.
|
||
// On timeout or parse failure it falls back to router.HeuristicClassifier.
|
||
type Classifier struct {
|
||
provider provider.Provider
|
||
model string
|
||
timeout time.Duration
|
||
logger *slog.Logger
|
||
}
|
||
|
||
// NewClassifier creates a Classifier. model is the model name passed to the provider
|
||
// (llamafile ignores it but openaicompat requires a non-empty value).
|
||
// Pass timeout=0 to use the built-in default (defaultClassifyTimeout).
|
||
func NewClassifier(p provider.Provider, model string, timeout time.Duration, logger *slog.Logger) *Classifier {
|
||
if logger == nil {
|
||
logger = slog.Default()
|
||
}
|
||
if timeout <= 0 {
|
||
timeout = defaultClassifyTimeout
|
||
}
|
||
return &Classifier{
|
||
provider: p,
|
||
model: model,
|
||
timeout: timeout,
|
||
logger: logger,
|
||
}
|
||
}
|
||
|
||
// Classify calls the SLM and overlays the three SLM-authoritative fields
|
||
// (Type, ComplexityScore, RequiresTools) onto a heuristic baseline Task.
|
||
// This ensures Priority, EstimatedTokens, and RequiredEffort are always set.
|
||
func (c *Classifier) Classify(ctx context.Context, prompt string, history []message.Message) (router.Task, error) {
|
||
tctx, cancel := context.WithTimeout(ctx, c.timeout)
|
||
defer cancel()
|
||
|
||
resp, err := c.callSLM(tctx, prompt)
|
||
if err != nil {
|
||
// Warn-level so a first-time misconfiguration (timeout too tight,
|
||
// wrong endpoint, malformed JSON from the model) surfaces without
|
||
// requiring --verbose. The fallback path itself is benign; the
|
||
// signal is that the SLM isn't doing the work it was supposed to.
|
||
c.logger.Warn("slm classify fallback", "error", err, "timeout", c.timeout)
|
||
t, ferr := router.HeuristicClassifier{}.Classify(ctx, prompt, history)
|
||
t.ClassifierSource = router.ClassifierSLMFallback
|
||
return t, ferr
|
||
}
|
||
|
||
// Start from the heuristic baseline so Priority/EstimatedTokens/RequiredEffort are set.
|
||
task := router.ClassifyTask(prompt)
|
||
task.Type = router.ParseTaskType(resp.TaskType)
|
||
task.ComplexityScore = resp.Complexity
|
||
task.RequiresTools = resp.RequiresTools
|
||
task.ClassifierSource = router.ClassifierSLM
|
||
// Re-apply the per-task-type complexity floor after the SLM overlay.
|
||
// The SLM may have under-reported complexity for a Refactor-style
|
||
// task; the floor protects the SLM arm from being picked for its own
|
||
// kind of misclassification.
|
||
if floor := router.MinComplexityForType(task.Type); task.ComplexityScore < floor {
|
||
task.ComplexityScore = floor
|
||
}
|
||
return task, nil
|
||
}
|
||
|
||
func (c *Classifier) callSLM(ctx context.Context, prompt string) (*classifyResponse, error) {
|
||
// Constrain the model toward valid, deterministic JSON output. Without
|
||
// these settings small models routinely ignore the JSON-only system
|
||
// prompt, emit reasoning blocks (<think>, <Thought Process>) or just
|
||
// answer the user's prompt in prose. ResponseFormat=json_object asks
|
||
// the provider to enforce JSON at decoding time where supported
|
||
// (ollama 'format=json', llama.cpp grammar, OpenAI json_object). Even
|
||
// when the provider can't enforce, the explicit signal nudges the
|
||
// adapter to set the right backend flag.
|
||
temp := 0.0
|
||
topP := 1.0
|
||
req := provider.Request{
|
||
Model: c.model,
|
||
SystemPrompt: classifySystemPrompt,
|
||
Temperature: &temp,
|
||
TopP: &topP,
|
||
MaxTokens: 128, // classification output is ~50 tokens; cap to prevent runaway reasoning
|
||
ResponseFormat: &provider.ResponseFormat{
|
||
Type: provider.ResponseJSON,
|
||
},
|
||
Messages: []message.Message{
|
||
{
|
||
Role: message.RoleUser,
|
||
Content: []message.Content{{Type: message.ContentText, Text: prompt}},
|
||
},
|
||
},
|
||
}
|
||
|
||
strm, err := c.provider.Stream(ctx, req)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("stream: %w", err)
|
||
}
|
||
defer func() { _ = strm.Close() }()
|
||
|
||
var sb strings.Builder
|
||
for strm.Next() {
|
||
ev := strm.Current()
|
||
if ev.Type == stream.EventTextDelta {
|
||
sb.WriteString(ev.Text)
|
||
}
|
||
}
|
||
if err := strm.Err(); err != nil {
|
||
return nil, fmt.Errorf("stream error: %w", err)
|
||
}
|
||
|
||
text := extractJSON(sb.String())
|
||
var resp classifyResponse
|
||
if err := json.Unmarshal([]byte(text), &resp); err != nil {
|
||
return nil, fmt.Errorf("parse %q: %w", text, err)
|
||
}
|
||
return &resp, nil
|
||
}
|
||
|
||
// extractJSON pulls the first {...} substring from s, stripping markdown
|
||
// fences and known thinking-block tags. Small models routinely violate
|
||
// the JSON-only system prompt by emitting reasoning tokens first, so
|
||
// the extractor must tolerate prefixes the model wasn't asked to emit.
|
||
func extractJSON(s string) string {
|
||
s = strings.TrimSpace(s)
|
||
|
||
// Strip known thinking-block tags. Order matters: longer/more-
|
||
// specific names first so a partial match doesn't shadow a real
|
||
// one. Seen in the wild on Qwen3 (<think>) and tiny3.5
|
||
// (<Thought Process>); the others are defensive against similar
|
||
// fine-tunes.
|
||
for _, tag := range []string{"Thought Process", "thinking", "reasoning", "thoughts", "think"} {
|
||
s = stripTagBlock(s, tag)
|
||
}
|
||
|
||
// Strip ```json ... ``` fences.
|
||
if strings.HasPrefix(s, "```") {
|
||
end := strings.LastIndex(s, "```")
|
||
if end > 3 {
|
||
inner := s[3:end]
|
||
inner = strings.TrimPrefix(inner, "json")
|
||
s = strings.TrimSpace(inner)
|
||
}
|
||
}
|
||
|
||
// Extract first balanced {...} block.
|
||
start := strings.IndexByte(s, '{')
|
||
if start < 0 {
|
||
return s
|
||
}
|
||
depth := 0
|
||
for i := start; i < len(s); i++ {
|
||
switch s[i] {
|
||
case '{':
|
||
depth++
|
||
case '}':
|
||
depth--
|
||
if depth == 0 {
|
||
return s[start : i+1]
|
||
}
|
||
}
|
||
}
|
||
return s[start:]
|
||
}
|
||
|
||
// stripTagBlock removes one leading <tag>...</tag> block from s and returns
// the whitespace-trimmed remainder. The tag name is matched without regard
// to ASCII letter case. If s (ignoring leading whitespace) does not start
// with "<"+tag, s is returned unchanged.
//
// When the opening tag is present but no closing tag is found, everything
// before the first '{' is dropped so JSON that follows an unterminated
// thinking block can still be extracted; if there is no '{' either, s is
// returned unchanged.
//
// Matching uses a byte-for-byte ASCII fold instead of strings.ToLower.
// ToLower can change the byte length of non-ASCII or invalid UTF-8 input,
// which would make offsets found in the folded copy invalid for slicing the
// original text.
func stripTagBlock(s, tag string) string {
	trimmed := strings.TrimSpace(s)
	folded := lowerASCII(trimmed) // same length as trimmed, so offsets are shared
	if !strings.HasPrefix(folded, lowerASCII("<"+tag)) {
		return s
	}

	// Find the closing tag, ignoring ASCII case.
	closeTag := lowerASCII("</" + tag + ">")
	closeIdx := strings.Index(folded, closeTag)
	if closeIdx < 0 {
		// Unterminated thinking block — strip up to the first '{'
		// so we still have a shot at extracting JSON that follows.
		if braceIdx := strings.IndexByte(trimmed, '{'); braceIdx > 0 {
			return strings.TrimSpace(trimmed[braceIdx:])
		}
		return s
	}
	return strings.TrimSpace(trimmed[closeIdx+len(closeTag):])
}

// lowerASCII returns s with the bytes 'A'-'Z' lowered and every other byte
// left untouched, so the result always has the same length as s.
func lowerASCII(s string) string {
	b := []byte(s)
	for i, ch := range b {
		if 'A' <= ch && ch <= 'Z' {
			b[i] = ch + ('a' - 'A')
		}
	}
	return string(b)
}
|