diff --git a/internal/engine/engine.go b/internal/engine/engine.go index 06dd8c0..55e292d 100644 --- a/internal/engine/engine.go +++ b/internal/engine/engine.go @@ -6,30 +6,31 @@ import ( "log/slog" gnomactx "somegit.dev/Owlibou/gnoma/internal/context" + "somegit.dev/Owlibou/gnoma/internal/hook" "somegit.dev/Owlibou/gnoma/internal/message" "somegit.dev/Owlibou/gnoma/internal/permission" "somegit.dev/Owlibou/gnoma/internal/provider" "somegit.dev/Owlibou/gnoma/internal/router" "somegit.dev/Owlibou/gnoma/internal/security" "somegit.dev/Owlibou/gnoma/internal/tool" - "somegit.dev/Owlibou/gnoma/internal/hook" "somegit.dev/Owlibou/gnoma/internal/tool/persist" ) // Config holds engine configuration. type Config struct { - Provider provider.Provider // direct provider (used if Router is nil) - Router *router.Router // nil = use Provider directly + Provider provider.Provider // direct provider (used if Router is nil) + Router *router.Router // nil = use Provider directly + Classifier router.TaskClassifier // classifies prompts into router tasks; nil = router.HeuristicClassifier Tools *tool.Registry - Firewall *security.Firewall // nil = no scanning - Permissions *permission.Checker // nil = allow all - Context *gnomactx.Window // nil = no compaction - System string // system prompt - Model string // override model (empty = provider default) - Temperature *float64 // nil = provider default - MaxTurns int // safety limit on tool loops (0 = unlimited) - Store *persist.Store // nil = no result persistence - Hooks *hook.Dispatcher // nil = no hooks + Firewall *security.Firewall // nil = no scanning + Permissions *permission.Checker // nil = allow all + Context *gnomactx.Window // nil = no compaction + System string // system prompt + Model string // override model (empty = provider default) + Temperature *float64 // nil = provider default + MaxTurns int // safety limit on tool loops (0 = unlimited) + Store *persist.Store // nil = no result persistence + Hooks *hook.Dispatcher // nil = no hooks Logger *slog.Logger } @@ 
-228,6 +229,21 @@ func (e *Engine) SetActivatedTools(tools map[string]bool) { e.activatedTools = tools } +// classify returns the router.Task for prompt, passing the engine's current history (e.history) to the configured classifier. +// A nil e.cfg.Classifier means router.HeuristicClassifier; if Classify returns an error it is logged at debug level and router.ClassifyTask(prompt) is returned instead. NOTE(review): results are not cached and loop.go calls this from runLoop, buildRequest and handleRequestTooLarge; confirm that is acceptable for classifiers more expensive than the heuristic. +func (e *Engine) classify(ctx context.Context, prompt string) router.Task { + cls := e.cfg.Classifier + if cls == nil { + cls = router.HeuristicClassifier{} + } + task, err := cls.Classify(ctx, prompt, e.history) + if err != nil { + e.logger.Debug("classifier error, falling back to heuristic", "error", err) + return router.ClassifyTask(prompt) + } + return task +} + // Reset clears conversation history and usage. func (e *Engine) Reset() { e.history = nil diff --git a/internal/engine/engine_test.go b/internal/engine/engine_test.go index 62e3dad..f33ce63 100644 --- a/internal/engine/engine_test.go +++ b/internal/engine/engine_test.go @@ -579,6 +579,94 @@ func TestSubmit_CumulativeUsage(t *testing.T) { } } +// spyClassifier counts Classify calls; it returns *result when result is non-nil and otherwise delegates to router.HeuristicClassifier. 
+type spyClassifier struct { + calls int + result *router.Task // when non-nil, Classify returns *result instead of the heuristic result +} + +func (s *spyClassifier) Classify(ctx context.Context, prompt string, history []message.Message) (router.Task, error) { + s.calls++ + if s.result != nil { + return *s.result, nil + } + return router.HeuristicClassifier{}.Classify(ctx, prompt, history) +} + +func TestSubmit_UsesInjectedClassifier(t *testing.T) { + rtr := router.New(router.Config{}) + armID := router.NewArmID("test", "mock-model") + mp := &mockProvider{ + name: "test", + streams: []stream.Stream{ + newEventStream(message.StopEndTurn, "mock-model", + stream.Event{Type: stream.EventTextDelta, Text: "ok"}, + ), + }, + } + rtr.RegisterArm(&router.Arm{ + ID: armID, + Provider: mp, + ModelName: "mock-model", + Capabilities: provider.Capabilities{ToolUse: true}, + }) + rtr.ForceArm(armID) + + spy := &spyClassifier{} + e, err := New(Config{ + Provider: mp, + Router: rtr, + Tools: tool.NewRegistry(), + Classifier: spy, + }) + if err != nil { + t.Fatalf("New: %v", err) + } + + if _, err := e.Submit(context.Background(), "implement a parser", nil); err != nil { + t.Fatalf("Submit: %v", err) + } + + if spy.calls == 0 { + t.Error("expected Classify to be called at least once, got 0 calls") + } +} + +func TestSubmit_NilClassifierFallsBackToHeuristic(t *testing.T) { + rtr := router.New(router.Config{}) + armID := router.NewArmID("test", "mock-model") + mp := &mockProvider{ + name: "test", + streams: []stream.Stream{ + newEventStream(message.StopEndTurn, "mock-model", + stream.Event{Type: stream.EventTextDelta, Text: "ok"}, + ), + }, + } + rtr.RegisterArm(&router.Arm{ + ID: armID, + Provider: mp, + ModelName: "mock-model", + Capabilities: provider.Capabilities{ToolUse: true}, + }) + rtr.ForceArm(armID) + + // Config.Classifier is left nil — Submit must not panic; the engine is expected to use the heuristic classifier + e, err := New(Config{ + Provider: mp, + Router: rtr, + Tools: tool.NewRegistry(), + }) + if err != nil { + t.Fatalf("New: %v", 
err) + } + + _, err = e.Submit(context.Background(), "debug the server crash", nil) + if err != nil { + t.Fatalf("Submit with nil Classifier: %v", err) + } +} + func TestSubmit_ReportsOutcomeToRouter(t *testing.T) { rtr := router.New(router.Config{}) armID := router.NewArmID("test", "mock-model") diff --git a/internal/engine/loop.go b/internal/engine/loop.go index 3f400c3..d4c114d 100644 --- a/internal/engine/loop.go +++ b/internal/engine/loop.go @@ -97,7 +97,7 @@ func (e *Engine) runLoop(ctx context.Context, cb Callback) (*Turn, error) { break } } - task := router.ClassifyTask(prompt) + task := e.classify(ctx, prompt) if e.cfg.Context != nil { task.EstimatedTokens = int(e.cfg.Context.Tracker().CountTokens(prompt)) } else { @@ -151,7 +151,7 @@ func (e *Engine) runLoop(ctx context.Context, cb Callback) (*Turn, error) { break } } - task := router.ClassifyTask(prompt) + task := e.classify(ctx, prompt) if e.cfg.Context != nil { task.EstimatedTokens = int(e.cfg.Context.Tracker().CountTokens(prompt)) } else { @@ -376,7 +376,7 @@ func (e *Engine) buildRequest(ctx context.Context) provider.Request { break } } - if router.ClassifyTask(prompt).Type == router.TaskOrchestration { + if e.classify(ctx, prompt).Type == router.TaskOrchestration { req.SystemPrompt = coordinatorPrompt() + "\n\n" + req.SystemPrompt } } @@ -596,7 +596,7 @@ func (e *Engine) handleRequestTooLarge(ctx context.Context, origErr error, req p break } } - task := router.ClassifyTask(prompt) + task := e.classify(ctx, prompt) if e.cfg.Context != nil { task.EstimatedTokens = int(e.cfg.Context.Tracker().CountTokens(prompt)) } else { diff --git a/internal/router/classifier.go b/internal/router/classifier.go new file mode 100644 index 0000000..46a196d --- /dev/null +++ b/internal/router/classifier.go @@ -0,0 +1,22 @@ +package router + +import ( + "context" + + "somegit.dev/Owlibou/gnoma/internal/message" +) + +// TaskClassifier classifies a user prompt into a Task for routing decisions. 
+// The history slice provides prior conversation context; implementations may +// ignore it (HeuristicClassifier) or use it for richer inference (SLMClassifier). +type TaskClassifier interface { + Classify(ctx context.Context, prompt string, history []message.Message) (Task, error) +} + +// HeuristicClassifier is the default classifier. It wraps the keyword-based +// ClassifyTask function; its Classify method ignores both the context and the conversation history and always returns a nil error. +type HeuristicClassifier struct{} + +func (HeuristicClassifier) Classify(_ context.Context, prompt string, _ []message.Message) (Task, error) { + return ClassifyTask(prompt), nil +} diff --git a/internal/router/classifier_test.go b/internal/router/classifier_test.go new file mode 100644 index 0000000..ed79b7d --- /dev/null +++ b/internal/router/classifier_test.go @@ -0,0 +1,68 @@ +package router + +import ( + "context" + "testing" + + "somegit.dev/Owlibou/gnoma/internal/message" +) + +// TestHeuristicClassifier_ParityWithClassifyTask verifies that +// HeuristicClassifier.Classify matches ClassifyTask on Type, ComplexityScore, RequiresTools and Priority for a set of sample prompts. 
+func TestHeuristicClassifier_ParityWithClassifyTask(t *testing.T) { + prompts := []string{ + "debug the failing test", + "explain how generics work", + "implement a new HTTP handler", + "refactor the auth middleware", + "security audit the login flow", + "write unit tests for the parser", + "scaffold a new service", + "plan the migration strategy", + "orchestrate the deployment pipeline", + "review the pull request", + } + + cls := HeuristicClassifier{} + ctx := context.Background() + var noHistory []message.Message + + for _, p := range prompts { + want := ClassifyTask(p) + got, err := cls.Classify(ctx, p, noHistory) + if err != nil { + t.Errorf("Classify(%q) unexpected error: %v", p, err) + continue + } + if got.Type != want.Type { + t.Errorf("Classify(%q).Type = %s, want %s", p, got.Type, want.Type) + } + if got.ComplexityScore != want.ComplexityScore { + t.Errorf("Classify(%q).ComplexityScore = %v, want %v", p, got.ComplexityScore, want.ComplexityScore) + } + if got.RequiresTools != want.RequiresTools { + t.Errorf("Classify(%q).RequiresTools = %v, want %v", p, got.RequiresTools, want.RequiresTools) + } + if got.Priority != want.Priority { + t.Errorf("Classify(%q).Priority = %v, want %v", p, got.Priority, want.Priority) + } + } +} + +// TestHeuristicClassifier_IgnoresHistory verifies that passing a non-empty history does not change +// the Task.Type returned by the heuristic classifier (only Type is compared; the Classify errors are discarded). +func TestHeuristicClassifier_IgnoresHistory(t *testing.T) { + cls := HeuristicClassifier{} + ctx := context.Background() + prompt := "implement a binary search function" + + withoutHistory, _ := cls.Classify(ctx, prompt, nil) + withHistory, _ := cls.Classify(ctx, prompt, []message.Message{ + {Role: message.RoleUser, Content: []message.Content{{Type: message.ContentText, Text: "previous message"}}}, + }) + + if withoutHistory.Type != withHistory.Type { + t.Errorf("history should not affect HeuristicClassifier: got %s vs %s", + withoutHistory.Type, withHistory.Type) + } +}