Files
gnoma/internal/slm/classifier.go
T
vikingowl c0c2e4bff5 fix(slm): enforce JSON output + strip thinking-block prefixes
Two structural fixes for the SLM classifier's 100% failure rate:

(1) Pass ResponseFormat=json_object + Temperature=0 + TopP=1 +
MaxTokens=128 in the classifier Request. The provider type already
supports these but callSLM was leaving them unset, which meant ollama
(and any other backend) ran with default sampling and free-form text
output. format=json mode in particular makes ollama emit only valid
JSON at decoding time — eliminates the majority of parse failures.

(2) Harden extractJSON to strip common thinking-block tags before
hunting for the brace. Seen in the wild: <think>…</think> (Qwen3
distillations) and <Thought Process>…</Thought Process> (tiny3.5).
Defensive list also covers <reasoning>, <thoughts>. Unterminated
thinking blocks fall back to brace-search so we still have a shot.
Table-driven tests cover all variants plus the no-tag and
fenced-json paths to confirm no regression.

Even with format=json on a capable provider, the extractor is the
safety net for backends that don't enforce format strictly — same
defence-in-depth shape as the existing fence stripping.

Doesn't fix the deeper architecture question (encoder + bandit
preferred over decoder-SLM as classifier — see plan doc landing in
the same PR); fixes the immediate bug.
2026-05-25 01:19:51 +02:00

227 lines
7.6 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package slm
import (
"context"
"encoding/json"
"fmt"
"log/slog"
"strings"
"time"
"somegit.dev/Owlibou/gnoma/internal/message"
"somegit.dev/Owlibou/gnoma/internal/provider"
"somegit.dev/Owlibou/gnoma/internal/router"
"somegit.dev/Owlibou/gnoma/internal/stream"
)
const (
	// defaultClassifyTimeout — 15 s accommodates cold-start model loads
	// (ollama lazily loads on first call, ~2-8s for a 1.5B model on SSD)
	// combined with thinking-mode first-token latency (Qwen3 distillations
	// like Tiny3.5 sometimes emit <think> tokens before the JSON output
	// even with /no_think). Non-thinking warm models complete in well
	// under 1 s. Tune via [slm].classify_timeout in config.
	defaultClassifyTimeout = 15 * time.Second

	// classifySystemPrompt is the system prompt sent with every
	// classification request. Its Format line defines the JSON keys that
	// classifyResponse decodes. The complexity ranges use an ASCII hyphen
	// so the prompt reads unambiguously (a range written as "0.00.3" gives
	// the model no usable guidance).
	classifySystemPrompt = `Classify the following coding request. /no_think
Respond with JSON only, no other text, no reasoning, no thinking tags.
Format: {"task_type": "<type>", "complexity": <0.0-1.0>, "requires_tools": <true|false>}
Task types: Debug, Explain, Generation, Refactor, UnitTest, Boilerplate, Planning, Orchestration, SecurityReview, Review
Complexity guide:
0.0-0.3: boilerplate, trivial edits, simple lookups, short explanations
0.4-0.6: new functions, refactors, unit tests, moderate analysis
0.7-1.0: architectural changes, multi-file edits, security review, planning`
)
// classifyResponse is the JSON object the classifier's system prompt asks
// the SLM to emit. The struct tags match the keys in the prompt's Format line.
type classifyResponse struct {
	TaskType      string  `json:"task_type"`      // free-form label, mapped via router.ParseTaskType
	Complexity    float64 `json:"complexity"`     // the prompt requests a value in 0.0-1.0
	RequiresTools bool    `json:"requires_tools"` // whether the model thinks tools are needed
}
// Classifier implements router.TaskClassifier by asking a small language
// model, reached through a provider.Provider, to label each request.
// Whenever the SLM call fails (timeout, stream error, or unparseable
// output), Classify falls back to router.HeuristicClassifier.
type Classifier struct {
	provider provider.Provider // backend that serves the SLM
	model    string            // model name placed in every request
	timeout  time.Duration     // deadline applied to each SLM call
	logger   *slog.Logger      // receives the fallback warnings
}
// NewClassifier returns a Classifier that sends requests for model through p.
// Some backends ignore model (llamafile), but openaicompat requires a
// non-empty value, so pass a real name. A timeout of zero or less selects
// defaultClassifyTimeout, and a nil logger selects slog.Default().
func NewClassifier(p provider.Provider, model string, timeout time.Duration, logger *slog.Logger) *Classifier {
	c := &Classifier{
		provider: p,
		model:    model,
		timeout:  timeout,
		logger:   logger,
	}
	if c.timeout <= 0 {
		c.timeout = defaultClassifyTimeout
	}
	if c.logger == nil {
		c.logger = slog.Default()
	}
	return c
}
// Classify asks the SLM to label prompt. It overlays the three
// SLM-authoritative fields (Type, ComplexityScore, RequiresTools) onto the
// heuristic baseline from router.ClassifyTask, which ensures Priority,
// EstimatedTokens and RequiredEffort are always set.
//
// The SLM call is bounded by c.timeout. On any error from the SLM path it
// logs a warning and returns router.HeuristicClassifier's result (and
// error), tagged with router.ClassifierSLMFallback.
//
// The SLM-reported complexity is clamped to [0, 1] before the per-type
// floor is applied. The system prompt asks for that range, but small models
// do not reliably honor it, and an out-of-range score would leak into
// routing.
func (c *Classifier) Classify(ctx context.Context, prompt string, history []message.Message) (router.Task, error) {
	tctx, cancel := context.WithTimeout(ctx, c.timeout)
	defer cancel()

	resp, err := c.callSLM(tctx, prompt)
	if err != nil {
		// Warn-level so a first-time misconfiguration (timeout too tight,
		// wrong endpoint, malformed JSON from the model) surfaces without
		// requiring --verbose. The fallback path itself is benign; the
		// signal is that the SLM isn't doing the work it was supposed to.
		c.logger.Warn("slm classify fallback", "error", err, "timeout", c.timeout)
		// The heuristic runs under the caller's ctx rather than tctx, which
		// may already have expired if the SLM call timed out.
		t, ferr := router.HeuristicClassifier{}.Classify(ctx, prompt, history)
		t.ClassifierSource = router.ClassifierSLMFallback
		return t, ferr
	}

	// Start from the heuristic baseline so Priority/EstimatedTokens/RequiredEffort are set.
	task := router.ClassifyTask(prompt)
	task.Type = router.ParseTaskType(resp.TaskType)
	task.ComplexityScore = min(max(resp.Complexity, 0), 1)
	task.RequiresTools = resp.RequiresTools
	task.ClassifierSource = router.ClassifierSLM

	// Re-apply the per-task-type complexity floor after the SLM overlay.
	// The SLM may have under-reported complexity for a Refactor-style
	// task; the floor protects the SLM arm from being picked for its own
	// kind of misclassification.
	if floor := router.MinComplexityForType(task.Type); task.ComplexityScore < floor {
		task.ComplexityScore = floor
	}
	return task, nil
}
// callSLM sends prompt to the SLM together with classifySystemPrompt and
// decodes the reply into a classifyResponse.
//
// The request is constrained toward valid, deterministic JSON. Without
// these settings, small models routinely ignore the JSON-only system
// prompt, emit reasoning blocks (<think>, <Thought Process>), or just answer
// the user's prompt in prose. ResponseFormat=json_object asks the provider
// to enforce JSON at decoding time where supported (ollama 'format=json',
// llama.cpp grammar, OpenAI json_object). Even when the provider can't
// enforce it, the explicit signal nudges the adapter to set the right
// backend flag.
//
// Only text-delta events are collected. The accumulated text goes through
// extractJSON before json.Unmarshal.
func (c *Classifier) callSLM(ctx context.Context, prompt string) (*classifyResponse, error) {
	var (
		temperature = 0.0 // greedy, deterministic decoding
		topP        = 1.0
	)
	userMsg := message.Message{
		Role:    message.RoleUser,
		Content: []message.Content{{Type: message.ContentText, Text: prompt}},
	}
	req := provider.Request{
		Model:          c.model,
		SystemPrompt:   classifySystemPrompt,
		Messages:       []message.Message{userMsg},
		Temperature:    &temperature,
		TopP:           &topP,
		MaxTokens:      128, // classification output is ~50 tokens; cap to prevent runaway reasoning
		ResponseFormat: &provider.ResponseFormat{Type: provider.ResponseJSON},
	}

	strm, err := c.provider.Stream(ctx, req)
	if err != nil {
		return nil, fmt.Errorf("stream: %w", err)
	}
	// The Close error is intentionally discarded.
	defer func() { _ = strm.Close() }()

	var reply strings.Builder
	for strm.Next() {
		if ev := strm.Current(); ev.Type == stream.EventTextDelta {
			reply.WriteString(ev.Text)
		}
	}
	if serr := strm.Err(); serr != nil {
		return nil, fmt.Errorf("stream error: %w", serr)
	}

	payload := extractJSON(reply.String())
	resp := new(classifyResponse)
	if uerr := json.Unmarshal([]byte(payload), resp); uerr != nil {
		return nil, fmt.Errorf("parse %q: %w", payload, uerr)
	}
	return resp, nil
}
// extractJSON pulls the first balanced {...} object out of s, after
// stripping leading thinking-block tags and a surrounding markdown fence.
// Small models routinely violate the JSON-only system prompt by emitting
// reasoning tokens first, so the extractor must tolerate prefixes the model
// wasn't asked to emit.
//
// Brace matching skips over JSON string literals (honoring backslash
// escapes), so a '{' or '}' inside a string value does not end the object
// early. If s contains no '{', the stripped text is returned as-is. If the
// object never balances, everything from the first '{' onward is returned,
// and json.Unmarshal is left to reject it.
func extractJSON(s string) string {
	s = strings.TrimSpace(s)

	// Strip known thinking-block tags. Seen in the wild on Qwen3 (<think>)
	// and tiny3.5 (<Thought Process>); the others are defensive against
	// similar fine-tunes. Passes repeat until nothing changes, so
	// consecutive blocks (e.g. <think>…</think><thinking>…</thinking>) are
	// all removed. Every successful strip shortens s, so the loop terminates.
	tags := []string{"Thought Process", "thinking", "reasoning", "thoughts", "think"}
	for {
		before := s
		for _, tag := range tags {
			s = stripTagBlock(s, tag)
		}
		if s == before {
			break
		}
	}

	// Strip ```json ... ``` fences.
	if strings.HasPrefix(s, "```") {
		if end := strings.LastIndex(s, "```"); end > 3 {
			inner := strings.TrimPrefix(s[3:end], "json")
			s = strings.TrimSpace(inner)
		}
	}

	// Extract the first balanced {...} block, ignoring braces inside strings.
	start := strings.IndexByte(s, '{')
	if start < 0 {
		return s
	}
	depth := 0
	inString, escaped := false, false
	for i := start; i < len(s); i++ {
		c := s[i]
		if inString {
			switch {
			case escaped:
				escaped = false
			case c == '\\':
				escaped = true
			case c == '"':
				inString = false
			}
			continue
		}
		switch c {
		case '"':
			inString = true
		case '{':
			depth++
		case '}':
			depth--
			if depth == 0 {
				return s[start : i+1]
			}
		}
	}
	return s[start:]
}

// stripTagBlock removes one <tag>...</tag> block from the start of s, after
// trimming surrounding whitespace. The tag name is matched
// case-insensitively for ASCII letters. The opening tag must be followed by
// '>', '/', whitespace, or the end of input, so tag "think" does not match
// "<thinking>".
//
// If s does not start with the tag, s is returned unchanged. If the block
// has no closing tag, everything before the first '{' is dropped so JSON
// following the reasoning can still be found; with no '{', s is returned
// unchanged. Each call strips at most one block; call it repeatedly to
// remove consecutive blocks.
func stripTagBlock(s, tag string) string {
	trimmed := strings.TrimSpace(s)

	// Fold only ASCII A-Z. strings.ToLower can change the byte length of
	// some non-ASCII runes (e.g. U+212A KELVIN SIGN becomes "k"), which
	// would make offsets found in the folded copy invalid for trimmed.
	foldASCII := func(x string) string {
		b := []byte(x)
		for i, ch := range b {
			if 'A' <= ch && ch <= 'Z' {
				b[i] = ch + ('a' - 'A')
			}
		}
		return string(b)
	}

	folded := foldASCII(trimmed)
	openTag := foldASCII("<" + tag)
	if !strings.HasPrefix(folded, openTag) {
		return s
	}
	// Require a tag-name boundary so a shorter tag can't match a longer one.
	if rest := folded[len(openTag):]; rest != "" {
		switch rest[0] {
		case '>', '/', ' ', '\t', '\n', '\r':
		default:
			return s
		}
	}

	closeTag := foldASCII("</" + tag + ">")
	closeIdx := strings.Index(folded, closeTag)
	if closeIdx < 0 {
		// Unterminated thinking block: strip up to the first '{' so we
		// still have a shot at extracting JSON that follows.
		if braceIdx := strings.IndexByte(trimmed, '{'); braceIdx > 0 {
			return strings.TrimSpace(trimmed[braceIdx:])
		}
		return s
	}
	return strings.TrimSpace(trimmed[closeIdx+len(closeTag):])
}