feat(classifier): Wave A — TaskClassifier interface + HeuristicClassifier
- internal/router/classifier.go: TaskClassifier interface with Classify(ctx, prompt, history) signature. HeuristicClassifier wraps the existing ClassifyTask() with zero behavior change. - engine.Config.Classifier: injectable TaskClassifier; nil defaults to HeuristicClassifier. Engine.classify() helper handles nil + error fallback transparently. - loop.go: all four router.ClassifyTask() call sites replaced with e.classify(ctx, prompt). SLMClassifier slots in without further changes to the engine.
This commit is contained in:
+28
-12
@@ -6,30 +6,31 @@ import (
|
||||
"log/slog"
|
||||
|
||||
gnomactx "somegit.dev/Owlibou/gnoma/internal/context"
|
||||
"somegit.dev/Owlibou/gnoma/internal/hook"
|
||||
"somegit.dev/Owlibou/gnoma/internal/message"
|
||||
"somegit.dev/Owlibou/gnoma/internal/permission"
|
||||
"somegit.dev/Owlibou/gnoma/internal/provider"
|
||||
"somegit.dev/Owlibou/gnoma/internal/router"
|
||||
"somegit.dev/Owlibou/gnoma/internal/security"
|
||||
"somegit.dev/Owlibou/gnoma/internal/tool"
|
||||
"somegit.dev/Owlibou/gnoma/internal/hook"
|
||||
"somegit.dev/Owlibou/gnoma/internal/tool/persist"
|
||||
)
|
||||
|
||||
// Config holds engine configuration.
|
||||
type Config struct {
|
||||
Provider provider.Provider // direct provider (used if Router is nil)
|
||||
Router *router.Router // nil = use Provider directly
|
||||
Provider provider.Provider // direct provider (used if Router is nil)
|
||||
Router *router.Router // nil = use Provider directly
|
||||
Classifier router.TaskClassifier // nil = HeuristicClassifier
|
||||
Tools *tool.Registry
|
||||
Firewall *security.Firewall // nil = no scanning
|
||||
Permissions *permission.Checker // nil = allow all
|
||||
Context *gnomactx.Window // nil = no compaction
|
||||
System string // system prompt
|
||||
Model string // override model (empty = provider default)
|
||||
Temperature *float64 // nil = provider default
|
||||
MaxTurns int // safety limit on tool loops (0 = unlimited)
|
||||
Store *persist.Store // nil = no result persistence
|
||||
Hooks *hook.Dispatcher // nil = no hooks
|
||||
Firewall *security.Firewall // nil = no scanning
|
||||
Permissions *permission.Checker // nil = allow all
|
||||
Context *gnomactx.Window // nil = no compaction
|
||||
System string // system prompt
|
||||
Model string // override model (empty = provider default)
|
||||
Temperature *float64 // nil = provider default
|
||||
MaxTurns int // safety limit on tool loops (0 = unlimited)
|
||||
Store *persist.Store // nil = no result persistence
|
||||
Hooks *hook.Dispatcher // nil = no hooks
|
||||
Logger *slog.Logger
|
||||
}
|
||||
|
||||
@@ -228,6 +229,21 @@ func (e *Engine) SetActivatedTools(tools map[string]bool) {
|
||||
e.activatedTools = tools
|
||||
}
|
||||
|
||||
// classify returns a Task for the given prompt using the configured classifier.
|
||||
// Falls back to HeuristicClassifier if none is configured or if classification fails.
|
||||
func (e *Engine) classify(ctx context.Context, prompt string) router.Task {
|
||||
cls := e.cfg.Classifier
|
||||
if cls == nil {
|
||||
cls = router.HeuristicClassifier{}
|
||||
}
|
||||
task, err := cls.Classify(ctx, prompt, e.history)
|
||||
if err != nil {
|
||||
e.logger.Debug("classifier error, falling back to heuristic", "error", err)
|
||||
return router.ClassifyTask(prompt)
|
||||
}
|
||||
return task
|
||||
}
|
||||
|
||||
// Reset clears conversation history and usage.
|
||||
func (e *Engine) Reset() {
|
||||
e.history = nil
|
||||
|
||||
@@ -579,6 +579,94 @@ func TestSubmit_CumulativeUsage(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// spyClassifier records calls and delegates to HeuristicClassifier.
|
||||
type spyClassifier struct {
|
||||
calls int
|
||||
result *router.Task // when non-nil, return this instead of heuristic result
|
||||
}
|
||||
|
||||
func (s *spyClassifier) Classify(ctx context.Context, prompt string, history []message.Message) (router.Task, error) {
|
||||
s.calls++
|
||||
if s.result != nil {
|
||||
return *s.result, nil
|
||||
}
|
||||
return router.HeuristicClassifier{}.Classify(ctx, prompt, history)
|
||||
}
|
||||
|
||||
func TestSubmit_UsesInjectedClassifier(t *testing.T) {
|
||||
rtr := router.New(router.Config{})
|
||||
armID := router.NewArmID("test", "mock-model")
|
||||
mp := &mockProvider{
|
||||
name: "test",
|
||||
streams: []stream.Stream{
|
||||
newEventStream(message.StopEndTurn, "mock-model",
|
||||
stream.Event{Type: stream.EventTextDelta, Text: "ok"},
|
||||
),
|
||||
},
|
||||
}
|
||||
rtr.RegisterArm(&router.Arm{
|
||||
ID: armID,
|
||||
Provider: mp,
|
||||
ModelName: "mock-model",
|
||||
Capabilities: provider.Capabilities{ToolUse: true},
|
||||
})
|
||||
rtr.ForceArm(armID)
|
||||
|
||||
spy := &spyClassifier{}
|
||||
e, err := New(Config{
|
||||
Provider: mp,
|
||||
Router: rtr,
|
||||
Tools: tool.NewRegistry(),
|
||||
Classifier: spy,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("New: %v", err)
|
||||
}
|
||||
|
||||
if _, err := e.Submit(context.Background(), "implement a parser", nil); err != nil {
|
||||
t.Fatalf("Submit: %v", err)
|
||||
}
|
||||
|
||||
if spy.calls == 0 {
|
||||
t.Error("expected Classify to be called at least once, got 0 calls")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubmit_NilClassifierFallsBackToHeuristic(t *testing.T) {
|
||||
rtr := router.New(router.Config{})
|
||||
armID := router.NewArmID("test", "mock-model")
|
||||
mp := &mockProvider{
|
||||
name: "test",
|
||||
streams: []stream.Stream{
|
||||
newEventStream(message.StopEndTurn, "mock-model",
|
||||
stream.Event{Type: stream.EventTextDelta, Text: "ok"},
|
||||
),
|
||||
},
|
||||
}
|
||||
rtr.RegisterArm(&router.Arm{
|
||||
ID: armID,
|
||||
Provider: mp,
|
||||
ModelName: "mock-model",
|
||||
Capabilities: provider.Capabilities{ToolUse: true},
|
||||
})
|
||||
rtr.ForceArm(armID)
|
||||
|
||||
// No Classifier set — should not panic, should use heuristic
|
||||
e, err := New(Config{
|
||||
Provider: mp,
|
||||
Router: rtr,
|
||||
Tools: tool.NewRegistry(),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("New: %v", err)
|
||||
}
|
||||
|
||||
_, err = e.Submit(context.Background(), "debug the server crash", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Submit with nil Classifier: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubmit_ReportsOutcomeToRouter(t *testing.T) {
|
||||
rtr := router.New(router.Config{})
|
||||
armID := router.NewArmID("test", "mock-model")
|
||||
|
||||
@@ -97,7 +97,7 @@ func (e *Engine) runLoop(ctx context.Context, cb Callback) (*Turn, error) {
|
||||
break
|
||||
}
|
||||
}
|
||||
task := router.ClassifyTask(prompt)
|
||||
task := e.classify(ctx, prompt)
|
||||
if e.cfg.Context != nil {
|
||||
task.EstimatedTokens = int(e.cfg.Context.Tracker().CountTokens(prompt))
|
||||
} else {
|
||||
@@ -151,7 +151,7 @@ func (e *Engine) runLoop(ctx context.Context, cb Callback) (*Turn, error) {
|
||||
break
|
||||
}
|
||||
}
|
||||
task := router.ClassifyTask(prompt)
|
||||
task := e.classify(ctx, prompt)
|
||||
if e.cfg.Context != nil {
|
||||
task.EstimatedTokens = int(e.cfg.Context.Tracker().CountTokens(prompt))
|
||||
} else {
|
||||
@@ -376,7 +376,7 @@ func (e *Engine) buildRequest(ctx context.Context) provider.Request {
|
||||
break
|
||||
}
|
||||
}
|
||||
if router.ClassifyTask(prompt).Type == router.TaskOrchestration {
|
||||
if e.classify(ctx, prompt).Type == router.TaskOrchestration {
|
||||
req.SystemPrompt = coordinatorPrompt() + "\n\n" + req.SystemPrompt
|
||||
}
|
||||
}
|
||||
@@ -596,7 +596,7 @@ func (e *Engine) handleRequestTooLarge(ctx context.Context, origErr error, req p
|
||||
break
|
||||
}
|
||||
}
|
||||
task := router.ClassifyTask(prompt)
|
||||
task := e.classify(ctx, prompt)
|
||||
if e.cfg.Context != nil {
|
||||
task.EstimatedTokens = int(e.cfg.Context.Tracker().CountTokens(prompt))
|
||||
} else {
|
||||
|
||||
@@ -0,0 +1,22 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"somegit.dev/Owlibou/gnoma/internal/message"
|
||||
)
|
||||
|
||||
// TaskClassifier classifies a user prompt into a Task for routing decisions.
|
||||
// The history slice provides prior conversation context; implementations may
|
||||
// ignore it (HeuristicClassifier) or use it for richer inference (SLMClassifier).
|
||||
type TaskClassifier interface {
|
||||
Classify(ctx context.Context, prompt string, history []message.Message) (Task, error)
|
||||
}
|
||||
|
||||
// HeuristicClassifier is the default classifier. It wraps the keyword-based
|
||||
// ClassifyTask function and ignores conversation history.
|
||||
type HeuristicClassifier struct{}
|
||||
|
||||
func (HeuristicClassifier) Classify(_ context.Context, prompt string, _ []message.Message) (Task, error) {
|
||||
return ClassifyTask(prompt), nil
|
||||
}
|
||||
@@ -0,0 +1,68 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"somegit.dev/Owlibou/gnoma/internal/message"
|
||||
)
|
||||
|
||||
// TestHeuristicClassifier_ParityWithClassifyTask verifies that
|
||||
// HeuristicClassifier.Classify produces identical results to ClassifyTask.
|
||||
func TestHeuristicClassifier_ParityWithClassifyTask(t *testing.T) {
|
||||
prompts := []string{
|
||||
"debug the failing test",
|
||||
"explain how generics work",
|
||||
"implement a new HTTP handler",
|
||||
"refactor the auth middleware",
|
||||
"security audit the login flow",
|
||||
"write unit tests for the parser",
|
||||
"scaffold a new service",
|
||||
"plan the migration strategy",
|
||||
"orchestrate the deployment pipeline",
|
||||
"review the pull request",
|
||||
}
|
||||
|
||||
cls := HeuristicClassifier{}
|
||||
ctx := context.Background()
|
||||
var noHistory []message.Message
|
||||
|
||||
for _, p := range prompts {
|
||||
want := ClassifyTask(p)
|
||||
got, err := cls.Classify(ctx, p, noHistory)
|
||||
if err != nil {
|
||||
t.Errorf("Classify(%q) unexpected error: %v", p, err)
|
||||
continue
|
||||
}
|
||||
if got.Type != want.Type {
|
||||
t.Errorf("Classify(%q).Type = %s, want %s", p, got.Type, want.Type)
|
||||
}
|
||||
if got.ComplexityScore != want.ComplexityScore {
|
||||
t.Errorf("Classify(%q).ComplexityScore = %v, want %v", p, got.ComplexityScore, want.ComplexityScore)
|
||||
}
|
||||
if got.RequiresTools != want.RequiresTools {
|
||||
t.Errorf("Classify(%q).RequiresTools = %v, want %v", p, got.RequiresTools, want.RequiresTools)
|
||||
}
|
||||
if got.Priority != want.Priority {
|
||||
t.Errorf("Classify(%q).Priority = %v, want %v", p, got.Priority, want.Priority)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestHeuristicClassifier_IgnoresHistory verifies that history has no effect
|
||||
// on the heuristic classifier (it operates only on the prompt).
|
||||
func TestHeuristicClassifier_IgnoresHistory(t *testing.T) {
|
||||
cls := HeuristicClassifier{}
|
||||
ctx := context.Background()
|
||||
prompt := "implement a binary search function"
|
||||
|
||||
withoutHistory, _ := cls.Classify(ctx, prompt, nil)
|
||||
withHistory, _ := cls.Classify(ctx, prompt, []message.Message{
|
||||
{Role: message.RoleUser, Content: []message.Content{{Type: message.ContentText, Text: "previous message"}}},
|
||||
})
|
||||
|
||||
if withoutHistory.Type != withHistory.Type {
|
||||
t.Errorf("history should not affect HeuristicClassifier: got %s vs %s",
|
||||
withoutHistory.Type, withHistory.Type)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user