c0c2e4bff5
Two structural fixes for the SLM classifier's 100% failure rate: (1) Pass ResponseFormat=json_object + Temperature=0 + TopP=1 + MaxTokens=128 in the classifier Request. The provider type already supports these but callSLM was leaving them unset, which meant ollama (and any other backend) ran with default sampling and free-form text output. format=json mode in particular makes ollama emit only valid JSON at decoding time — eliminates the majority of parse failures. (2) Harden extractJSON to strip common thinking-block tags before hunting for the brace. Seen in the wild: <think>…</think> (Qwen3 distillations) and <Thought Process>…</Thought Process> (tiny3.5). Defensive list also covers <reasoning>, <thoughts>. Unterminated thinking blocks fall back to brace-search so we still have a shot. Table-driven tests cover all variants plus the no-tag and fenced-json paths to confirm no regression. Even with format=json on a capable provider, the extractor is the safety net for backends that don't enforce format strictly — same defence-in-depth shape as the existing fence stripping. Doesn't fix the deeper architecture question (encoder + bandit preferred over decoder-SLM as classifier — see plan doc landing in the same PR); fixes the immediate bug.
260 lines
8.7 KiB
Go
260 lines
8.7 KiB
Go
package slm
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"testing"
|
|
"time"
|
|
|
|
"somegit.dev/Owlibou/gnoma/internal/message"
|
|
"somegit.dev/Owlibou/gnoma/internal/provider"
|
|
"somegit.dev/Owlibou/gnoma/internal/router"
|
|
"somegit.dev/Owlibou/gnoma/internal/stream"
|
|
)
|
|
|
|
// mockProvider implements provider.Provider for classifier tests.
//
// Each field scripts one aspect of the Stream call; the zero value
// answers immediately with a single empty text delta.
type mockProvider struct {
	// text is emitted as the payload of the single text-delta event.
	text string
	// delay, when positive, is how long Stream waits before answering
	// (the wait is abandoned if the context is cancelled first).
	delay time.Duration
	// err, when non-nil, is returned by Stream instead of a stream.
	err error
}
|
|
|
|
// Name reports the fixed provider identifier used by the mock.
func (m *mockProvider) Name() string {
	return "mock"
}
|
|
// DefaultModel reports the fixed model name used by the mock.
func (m *mockProvider) DefaultModel() string {
	return "default"
}
|
|
func (m *mockProvider) Models(_ context.Context) ([]provider.ModelInfo, error) {
|
|
return nil, nil
|
|
}
|
|
func (m *mockProvider) Stream(ctx context.Context, _ provider.Request) (stream.Stream, error) {
|
|
if m.delay > 0 {
|
|
select {
|
|
case <-time.After(m.delay):
|
|
case <-ctx.Done():
|
|
return nil, ctx.Err()
|
|
}
|
|
}
|
|
if m.err != nil {
|
|
return nil, m.err
|
|
}
|
|
return &mockStream{events: []stream.Event{
|
|
{Type: stream.EventTextDelta, Text: m.text},
|
|
}}, nil
|
|
}
|
|
|
|
type mockStream struct {
|
|
events []stream.Event
|
|
idx int
|
|
}
|
|
|
|
// Next advances the cursor by one and reports whether it still points
// at an event. The cursor is incremented on every call, including
// calls made after the events are exhausted.
func (s *mockStream) Next() bool {
	s.idx++
	return len(s.events) >= s.idx
}
|
|
// Current returns the event at the cursor position (events[idx-1]).
// It performs no bounds check of its own, so it must only be called
// after Next has returned true.
func (s *mockStream) Current() stream.Event {
	pos := s.idx - 1
	return s.events[pos]
}
|
|
// Err always returns nil; the mock stream never fails mid-iteration.
func (s *mockStream) Err() error {
	return nil
}
|
|
// Close is a no-op and always returns nil.
func (s *mockStream) Close() error {
	return nil
}
|
|
|
|
func TestClassifier_HappyPath(t *testing.T) {
|
|
// SLM complexity 0.55 stays above the Debug floor (0.4), so the SLM
|
|
// value is preserved verbatim.
|
|
p := &mockProvider{text: `{"task_type":"Debug","complexity":0.55,"requires_tools":false}`}
|
|
cls := NewClassifier(p, "default", 0, nil)
|
|
|
|
task, err := cls.Classify(context.Background(), "fix the failing test", nil)
|
|
if err != nil {
|
|
t.Fatalf("Classify: %v", err)
|
|
}
|
|
if task.Type != router.TaskDebug {
|
|
t.Errorf("Type = %s, want Debug", task.Type)
|
|
}
|
|
if task.ComplexityScore != 0.55 {
|
|
t.Errorf("ComplexityScore = %v, want 0.55 (SLM value preserved above floor)", task.ComplexityScore)
|
|
}
|
|
if task.RequiresTools != false {
|
|
t.Errorf("RequiresTools = true, want false")
|
|
}
|
|
}
|
|
|
|
func TestClassifier_AppliesTaskTypeFloor(t *testing.T) {
|
|
// Debug floor is 0.4; SLM under-reports at 0.25. The classifier should
|
|
// bump ComplexityScore up to the floor so the SLM arm can't be picked
|
|
// for its own kind of misclassification.
|
|
p := &mockProvider{text: `{"task_type":"Debug","complexity":0.25,"requires_tools":false}`}
|
|
cls := NewClassifier(p, "default", 0, nil)
|
|
|
|
task, err := cls.Classify(context.Background(), "fix the failing test", nil)
|
|
if err != nil {
|
|
t.Fatalf("Classify: %v", err)
|
|
}
|
|
floor := router.MinComplexityForType(router.TaskDebug)
|
|
if task.ComplexityScore != floor {
|
|
t.Errorf("ComplexityScore = %v, want floor %v", task.ComplexityScore, floor)
|
|
}
|
|
}
|
|
|
|
func TestClassifier_BlendHeuristic(t *testing.T) {
|
|
// SLM returns one type; other Task fields should come from heuristic.
|
|
p := &mockProvider{text: `{"task_type":"Boilerplate","complexity":0.1,"requires_tools":false}`}
|
|
cls := NewClassifier(p, "default", 0, nil)
|
|
|
|
task, err := cls.Classify(context.Background(), "scaffold a new HTTP handler", nil)
|
|
if err != nil {
|
|
t.Fatalf("Classify: %v", err)
|
|
}
|
|
if task.Type != router.TaskBoilerplate {
|
|
t.Errorf("Type = %s, want Boilerplate", task.Type)
|
|
}
|
|
// Priority must come from the heuristic baseline (PriorityNormal = 1, not zero).
|
|
if task.Priority < router.PriorityNormal {
|
|
t.Errorf("Priority = %v, want at least PriorityNormal from heuristic baseline", task.Priority)
|
|
}
|
|
}
|
|
|
|
func TestClassifier_FallbackOnBadJSON(t *testing.T) {
|
|
p := &mockProvider{text: "I cannot classify that."}
|
|
cls := NewClassifier(p, "default", 0, nil)
|
|
|
|
// Should not error — falls back to heuristic.
|
|
task, err := cls.Classify(context.Background(), "write unit tests for the parser", nil)
|
|
if err != nil {
|
|
t.Fatalf("Classify should not error on bad JSON: %v", err)
|
|
}
|
|
// Heuristic would return UnitTest for "write unit tests".
|
|
if task.Type != router.TaskUnitTest {
|
|
t.Errorf("heuristic fallback: Type = %s, want UnitTest", task.Type)
|
|
}
|
|
}
|
|
|
|
func TestClassifier_FallbackOnProviderError(t *testing.T) {
|
|
p := &mockProvider{err: errors.New("connection refused")}
|
|
cls := NewClassifier(p, "default", 0, nil)
|
|
|
|
task, err := cls.Classify(context.Background(), "explain how generics work", nil)
|
|
if err != nil {
|
|
t.Fatalf("Classify should not error on provider error: %v", err)
|
|
}
|
|
// Heuristic fallback: "explain" → TaskExplain
|
|
if task.Type != router.TaskExplain {
|
|
t.Errorf("heuristic fallback: Type = %s, want Explain", task.Type)
|
|
}
|
|
}
|
|
|
|
func TestClassifier_FallbackOnTimeout(t *testing.T) {
|
|
p := &mockProvider{delay: 500 * time.Millisecond}
|
|
cls := NewClassifier(p, "default", 0, nil)
|
|
cls.timeout = 50 * time.Millisecond // force timeout
|
|
|
|
task, err := cls.Classify(context.Background(), "debug the failing test", nil)
|
|
if err != nil {
|
|
t.Fatalf("Classify should not error on timeout: %v", err)
|
|
}
|
|
// Falls back to heuristic: "debug" → TaskDebug
|
|
if task.Type != router.TaskDebug {
|
|
t.Errorf("heuristic fallback: Type = %s, want Debug", task.Type)
|
|
}
|
|
}
|
|
|
|
func TestClassifier_FenceStripping(t *testing.T) {
|
|
fenced := "```json\n{\"task_type\":\"Refactor\",\"complexity\":0.5,\"requires_tools\":true}\n```"
|
|
p := &mockProvider{text: fenced}
|
|
cls := NewClassifier(p, "default", 0, nil)
|
|
|
|
task, err := cls.Classify(context.Background(), "refactor the auth middleware", nil)
|
|
if err != nil {
|
|
t.Fatalf("Classify: %v", err)
|
|
}
|
|
if task.Type != router.TaskRefactor {
|
|
t.Errorf("Type = %s, want Refactor", task.Type)
|
|
}
|
|
}
|
|
|
|
func TestClassifier_UnknownTaskType_FallsBackToHeuristic(t *testing.T) {
|
|
p := &mockProvider{text: `{"task_type":"FooBar","complexity":0.3,"requires_tools":false}`}
|
|
cls := NewClassifier(p, "default", 0, nil)
|
|
|
|
task, err := cls.Classify(context.Background(), "implement a binary search function", nil)
|
|
if err != nil {
|
|
t.Fatalf("Classify: %v", err)
|
|
}
|
|
// "implement" → heuristic should give Generation or Boilerplate; SLM gave FooBar → Generation fallback
|
|
_ = task // just verify no panic and no error
|
|
}
|
|
|
|
func TestClassifier_SetsClassifierSource_OnSuccess(t *testing.T) {
|
|
p := &mockProvider{text: `{"task_type":"Debug","complexity":0.3,"requires_tools":true}`}
|
|
cls := NewClassifier(p, "default", 0, nil)
|
|
task, err := cls.Classify(context.Background(), "fix the failing test", nil)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if task.ClassifierSource != router.ClassifierSLM {
|
|
t.Errorf("ClassifierSource = %v, want ClassifierSLM", task.ClassifierSource)
|
|
}
|
|
}
|
|
|
|
func TestClassifier_SetsClassifierSource_OnFallback(t *testing.T) {
|
|
p := &mockProvider{err: errors.New("backend unreachable")}
|
|
cls := NewClassifier(p, "default", 0, nil)
|
|
task, err := cls.Classify(context.Background(), "fix the failing test", nil)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if task.ClassifierSource != router.ClassifierSLMFallback {
|
|
t.Errorf("ClassifierSource = %v, want ClassifierSLMFallback", task.ClassifierSource)
|
|
}
|
|
}
|
|
|
|
func TestClassifier_ContextPassedToHistory(t *testing.T) {
|
|
p := &mockProvider{text: `{"task_type":"Explain","complexity":0.2,"requires_tools":false}`}
|
|
cls := NewClassifier(p, "default", 0, nil)
|
|
|
|
history := []message.Message{
|
|
{Role: message.RoleUser, Content: []message.Content{{Type: message.ContentText, Text: "prior"}}},
|
|
}
|
|
task, err := cls.Classify(context.Background(), "explain this code", history)
|
|
if err != nil {
|
|
t.Fatalf("Classify: %v", err)
|
|
}
|
|
if task.Type != router.TaskExplain {
|
|
t.Errorf("Type = %s, want Explain", task.Type)
|
|
}
|
|
}
|
|
|
|
func TestExtractJSON_StripsThinkingTags(t *testing.T) {
|
|
cases := []struct {
|
|
name string
|
|
in string
|
|
want string
|
|
}{
|
|
{
|
|
name: "qwen-think-block",
|
|
in: `<think>Let me decide</think>{"task_type":"Debug","complexity":0.5,"requires_tools":true}`,
|
|
want: `{"task_type":"Debug","complexity":0.5,"requires_tools":true}`,
|
|
},
|
|
{
|
|
name: "tiny3.5-thought-process",
|
|
in: "<Thought Process>\nUser wants debugging help.\n</Thought Process>\n{\"task_type\":\"Debug\",\"complexity\":0.4,\"requires_tools\":true}",
|
|
want: `{"task_type":"Debug","complexity":0.4,"requires_tools":true}`,
|
|
},
|
|
{
|
|
name: "unterminated-think-falls-back-to-brace",
|
|
in: `<think>incomplete reasoning {"task_type":"Explain","complexity":0.2,"requires_tools":false}`,
|
|
want: `{"task_type":"Explain","complexity":0.2,"requires_tools":false}`,
|
|
},
|
|
{
|
|
name: "no-tags-still-works",
|
|
in: `{"task_type":"Generation","complexity":0.6,"requires_tools":false}`,
|
|
want: `{"task_type":"Generation","complexity":0.6,"requires_tools":false}`,
|
|
},
|
|
{
|
|
name: "fenced-json-still-works",
|
|
in: "```json\n{\"task_type\":\"Refactor\",\"complexity\":0.5,\"requires_tools\":true}\n```",
|
|
want: `{"task_type":"Refactor","complexity":0.5,"requires_tools":true}`,
|
|
},
|
|
}
|
|
for _, tc := range cases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
got := extractJSON(tc.in)
|
|
if got != tc.want {
|
|
t.Errorf("extractJSON(...)\n got: %q\n want: %q", got, tc.want)
|
|
}
|
|
})
|
|
}
|
|
}
|