feat(slm): Wave C — SLM classifier, MaxComplexity routing, CLI subcommands, TUI status

- slm.Classifier: openaicompat → llamafile, 2s timeout + heuristic fallback,
  heuristic baseline blended so Priority/RequiredEffort are never zeroed,
  extractJSON strips markdown fences from small-model responses
- router.ParseTaskType: case-insensitive string → TaskType, unknown → TaskGeneration
- router.Arm.MaxComplexity: zero = no ceiling (preserves existing arm behavior);
  filterFeasible excludes arms when task.ComplexityScore > MaxComplexity
- config.SLMSection: [slm] enabled / model_url / data_dir
- openaicompat.NewLlamafile: no API key, model = "default", no retries
- slm.Manager: DefaultDataDir() (XDG), Manifest() accessor
- cmd/gnoma: `gnoma slm setup` / `gnoma slm status` subcommands; SLM arm
  registered with MaxComplexity=0.3 when enabled + set up
- tui: /config shows slm status (ready/missing/not set up + base URL if running)
- docs: roadmap updated to reflect llamafile pivot from Ollama
This commit is contained in:
2026-05-07 16:44:32 +02:00
parent d1a5c79fa4
commit a9213ec382
13 changed files with 685 additions and 56 deletions
+129 -1
View File
@@ -18,6 +18,7 @@ import (
"somegit.dev/Owlibou/gnoma/internal/engine"
"somegit.dev/Owlibou/gnoma/internal/hook"
"somegit.dev/Owlibou/gnoma/internal/skill"
"somegit.dev/Owlibou/gnoma/internal/slm"
"somegit.dev/Owlibou/gnoma/internal/tool/persist"
gnomacfg "somegit.dev/Owlibou/gnoma/internal/config"
gnomactx "somegit.dev/Owlibou/gnoma/internal/context"
@@ -131,6 +132,11 @@ func main() {
*permMode = cfg.Permission.Mode
}
// SLM subcommands: `gnoma slm setup` / `gnoma slm status`
if cliArgs := flag.Args(); len(cliArgs) > 0 && cliArgs[0] == "slm" {
os.Exit(runSLMCommand(cliArgs[1:], cfg, logger))
}
// Resolve API key: CLI flag → config → env vars
knownProviders := map[string]bool{
"mistral": true, "anthropic": true, "openai": true,
@@ -551,10 +557,48 @@ func main() {
logger.Debug("prefix token baseline set", "tokens", prefixTokens)
}
// Wire SLM: start llamafile, register arm, inject classifier (opt-in).
var slmMgr *slm.Manager
var engineClassifier router.TaskClassifier
if cfg.SLM.Enabled {
slmDataDir := cfg.SLM.DataDir
if slmDataDir == "" {
slmDataDir = slm.DefaultDataDir()
}
slmMgr = slm.New(slm.Config{DataDir: slmDataDir, ModelURL: cfg.SLM.ModelURL}, logger)
if slmMgr.IsSetUp() {
slmBaseURL, startErr := slmMgr.Start(context.Background())
if startErr != nil {
logger.Warn("failed to start SLM; falling back to heuristic classifier", "error", startErr)
} else {
slmProv, provErr := openaicompat.NewLlamafile(provider.ProviderConfig{
BaseURL: slmBaseURL + "/v1",
})
if provErr != nil {
logger.Warn("failed to create SLM provider", "error", provErr)
} else {
engineClassifier = slm.NewClassifier(slmProv, "default", logger)
rtr.RegisterArm(&router.Arm{
ID: "slm/llamafile",
Provider: slmProv,
ModelName: "default",
IsLocal: true,
MaxComplexity: 0.3,
Capabilities: provider.Capabilities{ToolUse: false},
})
logger.Info("SLM ready", "url", slmBaseURL)
}
}
} else {
logger.Warn("SLM enabled but not set up; run: gnoma slm setup")
}
}
// Create engine
eng, err := engine.New(engine.Config{
Provider: prov,
Router: rtr,
Classifier: engineClassifier,
Tools: reg,
Firewall: fw,
Permissions: permChecker,
@@ -729,13 +773,14 @@ func main() {
Permissions: permChecker,
Router: rtr,
ElfManager: elfMgr,
SLMManager: slmMgr,
PermCh: permCh,
PermReqCh: permReqCh,
ElfProgress: elfProgressCh,
SessionStore: sessStore,
StartWithResumePicker: openResumePicker,
Skills: skillReg,
PluginInfos: buildPluginInfos(discoveredPlugins, enabledSet),
PluginInfos: buildPluginInfos(discoveredPlugins, enabledSet),
Version: buildVersion,
ModelUpdateCh: modelUpdateCh,
})
@@ -972,6 +1017,89 @@ func buildPluginInfos(plugins []plugin.Plugin, enabledSet map[string]bool) []tui
return infos
}
// runSLMCommand handles `gnoma slm <subcommand>`.
//
// args holds the words following "slm" on the command line. cfg supplies the
// [slm] section (data directory and model URL); an empty data directory is
// replaced by slm.DefaultDataDir(). logger is handed to the slm.Manager.
//
// Supported subcommands:
//   - setup:  downloads the model named by [slm] model_url, printing progress.
//   - status: prints the setup state and, when a manifest exists, its details.
//
// Returns the process exit code: 0 on success, 1 on usage errors, a missing
// model_url, a failed setup, or an unknown subcommand.
func runSLMCommand(args []string, cfg *gnomacfg.Config, logger *slog.Logger) int {
	if len(args) == 0 {
		fmt.Fprintln(os.Stderr, "usage: gnoma slm <command>")
		fmt.Fprintln(os.Stderr, "commands:")
		fmt.Fprintln(os.Stderr, " setup download and verify the llamafile model")
		fmt.Fprintln(os.Stderr, " status show current setup state")
		return 1
	}

	dataDir := cfg.SLM.DataDir
	if dataDir == "" {
		dataDir = slm.DefaultDataDir()
	}
	mgr := slm.New(slm.Config{DataDir: dataDir, ModelURL: cfg.SLM.ModelURL}, logger)

	switch args[0] {
	case "setup":
		if cfg.SLM.ModelURL == "" {
			fmt.Fprintln(os.Stderr, "error: [slm] model_url must be set in config before running setup")
			fmt.Fprintln(os.Stderr, "")
			fmt.Fprintln(os.Stderr, "Example (~/.config/gnoma/config.toml):")
			fmt.Fprintln(os.Stderr, " [slm]")
			fmt.Fprintln(os.Stderr, ` model_url = "https://huggingface.co/mozilla-ai/TinyLlama-1.1B-Chat-v1.0-llamafile/resolve/main/TinyLlama-1.1B-Chat-v1.0.Q5_K_M.llamafile"`)
			return 1
		}
		fmt.Printf("downloading llamafile from %s\n", cfg.SLM.ModelURL)
		err := mgr.Setup(context.Background(), func(downloaded, total int64) {
			// total <= 0 means the size is unknown, so only the running
			// byte count can be shown. "\r" redraws the same terminal line.
			if total > 0 {
				pct := float64(downloaded) / float64(total) * 100
				fmt.Printf("\r %.1f%% (%s / %s) ", pct, humanBytes(downloaded), humanBytes(total))
			} else {
				fmt.Printf("\r %s downloaded ", humanBytes(downloaded))
			}
		})
		// Terminate the in-place progress line before printing anything else.
		fmt.Println()
		if err != nil {
			fmt.Fprintf(os.Stderr, "error: %v\n", err)
			return 1
		}
		fmt.Printf("SLM ready at: %s\n", dataDir)
		fmt.Println("Enable in config:")
		fmt.Println(" [slm]")
		fmt.Println(" enabled = true")
		return 0

	case "status":
		status := mgr.Status()
		fmt.Printf("slm status: %s\n", status)
		if mf := mgr.Manifest(); mf != nil {
			// Abbreviate the digest for display, but never slice past its end:
			// a truncated or hand-edited manifest must not crash `status`.
			digest := mf.SHA256
			if len(digest) > 16 {
				digest = digest[:16] + "..."
			}
			fmt.Printf(" file: %s\n", mf.FilePath)
			fmt.Printf(" size: %s\n", humanBytes(mf.Size))
			fmt.Printf(" sha256: %s\n", digest)
			// The layout prints a literal "UTC" suffix, so convert first;
			// otherwise a local-zone timestamp would be mislabelled.
			fmt.Printf(" setup: %s\n", mf.SetupAt.UTC().Format("2006-01-02 15:04 UTC"))
		}
		switch status {
		case slm.StatusNotSetUp:
			fmt.Println(" run: gnoma slm setup")
		case slm.StatusMissing:
			fmt.Println(" file is missing; run: gnoma slm setup")
		}
		return 0

	default:
		fmt.Fprintf(os.Stderr, "unknown slm command: %s\n", args[0])
		return 1
	}
}
// humanBytes renders a byte count using binary (1024-based) units.
// Values below 1024 are printed as whole bytes ("512 B"); larger values are
// scaled to the biggest fitting unit and shown with one decimal ("1.5 KiB").
func humanBytes(n int64) string {
	const unit = 1024
	if n < unit {
		return fmt.Sprintf("%d B", n)
	}

	const prefixes = "KMGTPE"
	scale := int64(unit) // divisor for the currently selected unit
	idx := 0             // index into prefixes
	remaining := n / unit
	for remaining >= unit {
		scale *= unit
		idx++
		remaining /= unit
	}
	return fmt.Sprintf("%.1f %ciB", float64(n)/float64(scale), prefixes[idx])
}
// resolveEnabledPlugins determines which plugins are enabled based on config.
// If Enabled is empty, all plugins are enabled by default (opt-out via Disabled).
// If Enabled is non-empty, only listed plugins are enabled (opt-in).
@@ -51,67 +51,44 @@ Bash tool flags `passwd foo` and offers takeover.
## Phase 3: SLM Task Classifier
Add an optional SLM-driven task classifier behind the existing `TaskClassifier` interface. The SLM
calls Ollama HTTP via the existing `openaicompat` provider — zero new dependencies, no CGO, no
daemon management.
Add an optional SLM-driven task classifier and low-complexity executor behind the `TaskClassifier`
interface. Uses llamafile (single-file download, OpenAI-compatible HTTP) instead of Ollama.
Zero new Go dependencies; the model binary is downloaded separately on opt-in.
**Context:** `gemma-integration-analysis.md` describes how gemini-cli implements this using
LiteRT-LM (a Node.js daemon + PID files). Those specifics do not apply here. The Go approach is
simpler: Ollama HTTP + structured JSON output + hard timeout + heuristic fallback.
**Implementation note (diverges from original plan):** Original plan used Ollama HTTP with
`router.slm_model` config key. Pivoted to llamafile after discussion: user downloads a specific
model file once (`gnoma slm setup`), gnoma manages the subprocess lifetime. Requires no external
daemon or package manager. Config section is `[slm]` not `[router]`.
### Interface
### Architecture
```go
// internal/router/classifier.go
type TaskClassifier interface {
Classify(ctx context.Context, input string, history []message.Message) (Task, error)
}
type HeuristicClassifier struct{} // default — wraps existing ClassifyTask()
type SLMClassifier struct {
provider provider.Provider // openaicompat pointing at Ollama
model string
timeout time.Duration // default 2s
}
```
`SLMClassifier.Classify` sends a structured prompt with the Complexity Rubric (adapted from
gemma-integration-analysis.md) and expects a JSON response:
```json
{"task_type": "Generation", "complexity": 0.4, "requires_tools": true}
```
On timeout or parse failure, it falls back to `HeuristicClassifier`.
### Complexity Rubric (prompt fragment)
```
Classify this coding request. Respond with JSON only.
Complexity 0.0–0.3: boilerplate, trivial edits, simple lookups
Complexity 0.4–0.6: moderate — new functions, refactors, unit tests
Complexity 0.7–1.0: architectural, multi-file, security review, planning
```
- `internal/slm/` — Manager (download, subprocess lifecycle, health check), Classifier
- `internal/router/` — `TaskClassifier` interface, `HeuristicClassifier`, `ParseTaskType`
- `Arm.MaxComplexity` — SLM arm capped at 0.3; excluded from complex tasks by `filterFeasible`
### Config
```toml
[router]
slm_model = "" # empty = disabled (HeuristicClassifier used)
# e.g. "gemma3:1b" — must be available in Ollama
[slm]
enabled = true
model_url = "https://huggingface.co/mozilla-ai/TinyLlama-1.1B-Chat-v1.0-llamafile/resolve/main/TinyLlama-1.1B-Chat-v1.0.Q5_K_M.llamafile"
data_dir = "" # empty = ~/.local/share/gnoma/slm
```
### Tasks
- [ ] `TaskClassifier` interface in `internal/router/classifier.go`
- [ ] `HeuristicClassifier` wraps existing `ClassifyTask()` (zero behavior change)
- [ ] `SLMClassifier`: Ollama HTTP via openaicompat, JSON parse, 2s timeout + fallback
- [ ] Complexity Rubric prompt (task type + complexity float + requires_tools bool)
- [ ] Config key `router.slm_model`; wire into router construction in `cmd/gnoma/main.go`
- [ ] Tests: `HeuristicClassifier` behavior unchanged; `SLMClassifier` fallback on timeout;
`SLMClassifier` correct parse on valid JSON response
- [x] `TaskClassifier` interface in `internal/router/classifier.go`
- [x] `HeuristicClassifier` wraps existing `ClassifyTask()` (zero behavior change)
- [x] `internal/slm/` — Manager, Manifest, download, subprocess lifecycle (Wave B)
- [x] `slm.Classifier`: openaicompat pointing at llamafile, JSON parse, 2s timeout + fallback
- [x] `ParseTaskType` in `internal/router/task.go`
- [x] `Arm.MaxComplexity` + `filterFeasible` ceiling
- [x] `[slm]` config section in `internal/config/`
- [x] `gnoma slm setup` / `gnoma slm status` CLI subcommands
- [x] SLM arm registered with `MaxComplexity = 0.3` in `cmd/gnoma/main.go`
- [x] TUI `/config` shows SLM status
**Dependencies:** existing `internal/provider/openaicompat` — no new deps.
**Dependencies:** existing `internal/provider/openaicompat` — no new Go deps.
**Exit criteria:** `gnoma` with `[slm] enabled = true` (after running `gnoma slm setup`) routes using SLM classification.
Without the `[slm]` section enabled, behavior is identical to today.
@@ -218,8 +195,8 @@ arches: amd64/arm64). The binary is a single static executable with zero runtime
| Item | Reason |
|------|--------|
| `.gnoma/tmp/` local temp directory | `persist.Store` already uses `/tmp/gnoma-<sessionID>/`; adding `.gnoma/tmp/` adds complexity (cleanup, gitignore, collision avoidance) for no benefit |
| LiteRT-LM / CGO SLM runtime | `CGO_ENABLED=0` (goreleaser constraint). Go approach: Ollama HTTP via existing openaicompat |
| Daemon/PID file management for SLM | Node.js-specific pattern from gemma-integration-analysis.md; not applicable to this Go binary |
| LiteRT-LM / CGO SLM runtime | `CGO_ENABLED=0` (goreleaser constraint). Go approach: llamafile subprocess via existing openaicompat |
| Ollama-based SLM classifier | Pivoted to llamafile: single-file download, no external daemon, user-controlled opt-in |
| PTY via `go-pty` library | Requires CGO. Replaced by `tea.ExecProcess` (already in go.mod, no CGO) |
---
+17
View File
@@ -10,11 +10,28 @@ type Config struct {
RateLimits RateLimitSection `toml:"rate_limits"`
Security SecuritySection `toml:"security"`
Session SessionSection `toml:"session"`
SLM SLMSection `toml:"slm"`
Hooks []HookConfig `toml:"hooks"`
MCPServers []MCPServerConfig `toml:"mcp_servers"`
Plugins PluginsSection `toml:"plugins"`
}
// SLMSection configures the optional small language model for task classification
// and low-complexity task execution. It is decoded from the [slm] table of the
// TOML config.
//
// Example config:
//
//	[slm]
//	enabled = true
//	model_url = "https://huggingface.co/mozilla-ai/TinyLlama-1.1B-Chat-v1.0-llamafile/resolve/main/TinyLlama-1.1B-Chat-v1.0.Q5_K_M.llamafile"
//
// Run `gnoma slm setup` to download and verify the model before enabling.
type SLMSection struct {
	// Enabled is the opt-in switch; the zero value leaves the SLM unused.
	Enabled bool `toml:"enabled"`
	// ModelURL is the download location of the llamafile model.
	// `gnoma slm setup` refuses to run while it is empty.
	ModelURL string `toml:"model_url"`
	DataDir  string `toml:"data_dir"` // empty = XDG default (~/.local/share/gnoma/slm)
}
// MCPServerConfig defines an MCP server to start and connect to.
//
// Example:
@@ -47,3 +47,18 @@ func NewLlamaCpp(cfg provider.ProviderConfig) (provider.Provider, error) {
}
return oaiprov.New(cfg)
}
// NewLlamafile builds an OpenAI-compatible provider that talks to a llamafile
// process. The caller's BaseURL has to carry the /v1 suffix already, for
// example "http://127.0.0.1:8080/v1".
//
// Unset fields receive llamafile-friendly defaults: a placeholder API key,
// the model name "default", and zero retries. cfg is received by value, so
// filling in these defaults never modifies the caller's struct.
func NewLlamafile(cfg provider.ProviderConfig) (provider.Provider, error) {
	if cfg.MaxRetries == nil {
		cfg.MaxRetries = intPtr(0)
	}
	if cfg.Model == "" {
		// llamafile ignores the model field
		cfg.Model = "default"
	}
	if cfg.APIKey == "" {
		// llamafile doesn't require a real key
		cfg.APIKey = "llamafile"
	}
	return oaiprov.New(cfg)
}
+4
View File
@@ -22,6 +22,10 @@ type Arm struct {
Capabilities provider.Capabilities
Pools []*LimitPool
// MaxComplexity is a hard ceiling on task complexity this arm will accept.
// Zero means no ceiling (default for all existing arms).
MaxComplexity float64
// Cost per 1k tokens (EUR, estimated)
CostPer1kInput float64
CostPer1kOutput float64
+49
View File
@@ -657,3 +657,52 @@ func TestRouter_AllDisabled_ReturnsError(t *testing.T) {
}
}
// TestFilterFeasible_MaxComplexity checks the per-arm complexity ceiling:
// an arm capped at 0.3 stays eligible for a 0.2 task and is dropped for a
// 0.8 task.
func TestFilterFeasible_MaxComplexity(t *testing.T) {
	const slmID = "slm/tiny"

	slmArm := &Arm{
		ID:            slmID,
		IsLocal:       true,
		MaxComplexity: 0.3,
		Capabilities:  provider.Capabilities{ToolUse: false},
	}
	apiArm := &Arm{
		ID:           "api/big",
		Capabilities: provider.Capabilities{ToolUse: true, ContextWindow: 200000},
	}

	// containsSLM reports whether the capped arm survived filtering.
	containsSLM := func(feasible []*Arm) bool {
		for _, candidate := range feasible {
			if candidate.ID == slmID {
				return true
			}
		}
		return false
	}

	// Low-complexity task: SLM arm passes the ceiling.
	low := Task{Type: TaskBoilerplate, ComplexityScore: 0.2}
	if !containsSLM(filterFeasible([]*Arm{slmArm, apiArm}, low)) {
		t.Error("slm arm should pass filterFeasible for low-complexity task")
	}

	// High-complexity task: SLM arm must be excluded.
	high := Task{Type: TaskPlanning, ComplexityScore: 0.8, RequiresTools: false}
	if containsSLM(filterFeasible([]*Arm{slmArm, apiArm}, high)) {
		t.Error("slm arm should be excluded for high-complexity task")
	}
}
// TestFilterFeasible_MaxComplexity_Zero_MeansNoLimit guards backward
// compatibility: an arm whose MaxComplexity is left at zero has no ceiling and
// must remain feasible even for a near-maximal complexity score.
func TestFilterFeasible_MaxComplexity_Zero_MeansNoLimit(t *testing.T) {
	uncapped := &Arm{
		ID:            "api/arm",
		MaxComplexity: 0, // zero = no ceiling
		Capabilities:  provider.Capabilities{ToolUse: true, ContextWindow: 200000},
	}
	hardTask := Task{Type: TaskOrchestration, ComplexityScore: 0.99}

	if feasible := filterFeasible([]*Arm{uncapped}, hardTask); len(feasible) == 0 {
		t.Error("arm with MaxComplexity=0 should never be excluded by complexity ceiling")
	}
}
+5
View File
@@ -179,6 +179,11 @@ func filterFeasible(arms []*Arm, task Task) []*Arm {
var belowQuality []*Arm // passed tool+pool but scored below minimum quality
for _, arm := range arms {
// Complexity ceiling: zero means no ceiling (preserves behavior for all existing arms).
if arm.MaxComplexity > 0 && task.ComplexityScore > arm.MaxComplexity {
continue
}
// Must support tools if task requires them
if task.RequiresTools && !arm.SupportsTools() {
continue
+29
View File
@@ -241,3 +241,32 @@ func estimateComplexity(prompt string) float64 {
}
return score
}
// ParseTaskType converts a string from an SLM JSON response to a TaskType.
//
// Matching is case-insensitive and tolerant of the separator styles small
// models tend to produce: surrounding whitespace is trimmed, and underscores,
// hyphens and spaces are dropped before comparison, so "UnitTest",
// "unit_test", "unit-test" and " Unit Test " all yield TaskUnitTest.
//
// Unknown or empty strings fall back to TaskGeneration.
func ParseTaskType(s string) TaskType {
	// Normalize to a separator-free lowercase key.
	key := strings.NewReplacer("_", "", "-", "", " ", "").Replace(strings.TrimSpace(s))

	switch strings.ToLower(key) {
	case "debug":
		return TaskDebug
	case "explain":
		return TaskExplain
	case "generation":
		return TaskGeneration
	case "refactor":
		return TaskRefactor
	case "unittest":
		return TaskUnitTest
	case "boilerplate":
		return TaskBoilerplate
	case "planning":
		return TaskPlanning
	case "orchestration":
		return TaskOrchestration
	case "securityreview":
		return TaskSecurityReview
	case "review":
		return TaskReview
	default:
		return TaskGeneration
	}
}
+44
View File
@@ -0,0 +1,44 @@
package router
import "testing"
// TestParseTaskType is a table-driven check of ParseTaskType covering every
// task name in several spellings (mixed case, snake_case) plus the fallback
// to TaskGeneration for empty or unrecognized input.
func TestParseTaskType(t *testing.T) {
	type testCase struct {
		input string
		want  TaskType
	}
	table := []testCase{
		{input: "Debug", want: TaskDebug},
		{input: "debug", want: TaskDebug},
		{input: "DEBUG", want: TaskDebug},
		{input: "Explain", want: TaskExplain},
		{input: "explain", want: TaskExplain},
		{input: "Generation", want: TaskGeneration},
		{input: "generation", want: TaskGeneration},
		{input: "Refactor", want: TaskRefactor},
		{input: "refactor", want: TaskRefactor},
		{input: "UnitTest", want: TaskUnitTest},
		{input: "unit_test", want: TaskUnitTest},
		{input: "unitTest", want: TaskUnitTest},
		{input: "Boilerplate", want: TaskBoilerplate},
		{input: "boilerplate", want: TaskBoilerplate},
		{input: "Planning", want: TaskPlanning},
		{input: "planning", want: TaskPlanning},
		{input: "Orchestration", want: TaskOrchestration},
		{input: "orchestration", want: TaskOrchestration},
		{input: "SecurityReview", want: TaskSecurityReview},
		{input: "security_review", want: TaskSecurityReview},
		{input: "Review", want: TaskReview},
		{input: "review", want: TaskReview},
		// unknown falls back to TaskGeneration
		{input: "", want: TaskGeneration},
		{input: "unknown", want: TaskGeneration},
		{input: "gibberish", want: TaskGeneration},
	}

	for i := range table {
		entry := table[i]
		if got := ParseTaskType(entry.input); got != entry.want {
			t.Errorf("ParseTaskType(%q) = %s, want %s", entry.input, got, entry.want)
		}
	}
}
+148
View File
@@ -0,0 +1,148 @@
package slm
import (
"context"
"encoding/json"
"fmt"
"log/slog"
"strings"
"time"
"somegit.dev/Owlibou/gnoma/internal/message"
"somegit.dev/Owlibou/gnoma/internal/provider"
"somegit.dev/Owlibou/gnoma/internal/router"
"somegit.dev/Owlibou/gnoma/internal/stream"
)
// defaultClassifyTimeout is the timeout NewClassifier installs on every
// Classifier; Classify applies it as the deadline for a single SLM call.
const defaultClassifyTimeout = 2 * time.Second
// classifySystemPrompt is the system prompt sent with every classification
// request. It asks the model for a single JSON object whose keys match the
// fields of classifyResponse, lists the accepted task type names, and gives a
// three-band complexity rubric. The bands are written with plain ASCII
// hyphens, matching the "<0.0-1.0>" notation on the Format line.
const classifySystemPrompt = `Classify the following coding request. Respond with JSON only, no other text.
Format: {"task_type": "<type>", "complexity": <0.0-1.0>, "requires_tools": <true|false>}
Task types: Debug, Explain, Generation, Refactor, UnitTest, Boilerplate, Planning, Orchestration, SecurityReview, Review
Complexity guide:
0.0-0.3: boilerplate, trivial edits, simple lookups, short explanations
0.4-0.6: new functions, refactors, unit tests, moderate analysis
0.7-1.0: architectural changes, multi-file edits, security review, planning`
// classifyResponse is the JSON object the SLM is asked to return; its keys
// mirror the "Format" line of classifySystemPrompt.
type classifyResponse struct {
	// TaskType is the model's task label; Classify converts it with router.ParseTaskType.
	TaskType string `json:"task_type"`
	// Complexity is the model's complexity estimate (the prompt asks for 0.0-1.0).
	Complexity float64 `json:"complexity"`
	// RequiresTools is the model's guess at whether the request needs tool use.
	RequiresTools bool `json:"requires_tools"`
}
// Classifier implements router.TaskClassifier using a llamafile-hosted SLM.
// On timeout or parse failure it falls back to router.HeuristicClassifier.
type Classifier struct {
	provider provider.Provider // streaming backend the classification request is sent to
	model    string            // model name placed in every request
	timeout  time.Duration     // deadline applied to each SLM call in Classify
	logger   *slog.Logger      // receives a debug record whenever the heuristic fallback is taken
}
// NewClassifier returns a Classifier that sends classification requests to p.
// model is forwarded as the request's model name (llamafile ignores it but
// openaicompat requires a non-empty value). A nil logger is replaced by
// slog.Default(). The per-call timeout starts at defaultClassifyTimeout.
func NewClassifier(p provider.Provider, model string, logger *slog.Logger) *Classifier {
	cls := &Classifier{
		provider: p,
		model:    model,
		timeout:  defaultClassifyTimeout,
		logger:   logger,
	}
	if cls.logger == nil {
		cls.logger = slog.Default()
	}
	return cls
}
// Classify calls the SLM and overlays the three SLM-authoritative fields
// (Type, ComplexityScore, RequiresTools) onto a heuristic baseline Task.
// This ensures Priority, EstimatedTokens, and RequiredEffort are always set.
//
// The SLM call is bounded by c.timeout. If it fails for any reason (provider
// error, timeout, unparsable reply) the failure is logged at debug level and
// the result of router.HeuristicClassifier is returned instead, using the
// caller's original ctx rather than the expired one.
//
// The model-reported complexity is clamped to [0, 1] before use: the value
// comes from untrusted model output, and anything outside that range would
// distort complexity-based routing decisions.
func (c *Classifier) Classify(ctx context.Context, prompt string, history []message.Message) (router.Task, error) {
	tctx, cancel := context.WithTimeout(ctx, c.timeout)
	defer cancel()

	resp, err := c.callSLM(tctx, prompt)
	if err != nil {
		c.logger.Debug("slm classify fallback", "error", err)
		return router.HeuristicClassifier{}.Classify(ctx, prompt, history)
	}

	complexity := resp.Complexity
	switch {
	case complexity < 0:
		complexity = 0
	case complexity > 1:
		complexity = 1
	}

	// Start from the heuristic baseline so Priority/EstimatedTokens/RequiredEffort are set.
	task := router.ClassifyTask(prompt)
	task.Type = router.ParseTaskType(resp.TaskType)
	task.ComplexityScore = complexity
	task.RequiresTools = resp.RequiresTools
	return task, nil
}
// callSLM sends prompt as a single user message (with classifySystemPrompt as
// the system prompt), concatenates every text delta of the streamed reply,
// and decodes the JSON object found in it.
//
// It returns an error when the stream cannot be opened, when the stream
// reports an error after iteration, or when the extracted text is not valid
// JSON for classifyResponse.
func (c *Classifier) callSLM(ctx context.Context, prompt string) (*classifyResponse, error) {
	userMsg := message.Message{
		Role:    message.RoleUser,
		Content: []message.Content{{Type: message.ContentText, Text: prompt}},
	}

	strm, err := c.provider.Stream(ctx, provider.Request{
		Model:        c.model,
		SystemPrompt: classifySystemPrompt,
		Messages:     []message.Message{userMsg},
	})
	if err != nil {
		return nil, fmt.Errorf("stream: %w", err)
	}
	defer strm.Close()

	// Only text deltas contribute to the reply; other event types are skipped.
	var reply strings.Builder
	for strm.Next() {
		if ev := strm.Current(); ev.Type == stream.EventTextDelta {
			reply.WriteString(ev.Text)
		}
	}
	if err := strm.Err(); err != nil {
		return nil, fmt.Errorf("stream error: %w", err)
	}

	payload := extractJSON(reply.String())
	parsed := new(classifyResponse)
	if err := json.Unmarshal([]byte(payload), parsed); err != nil {
		return nil, fmt.Errorf("parse %q: %w", payload, err)
	}
	return parsed, nil
}
// extractJSON pulls the first balanced {...} object out of s, stripping a
// surrounding markdown code fence (``` or ```json) if present.
//
// Braces that appear inside JSON string literals are not counted, and
// backslash escapes inside strings are honored, so a value such as
// {"note":"use } here"} is returned whole instead of being cut at the
// embedded brace.
//
// If s contains no '{', the trimmed (and de-fenced) input is returned
// unchanged. If the object is never closed, everything from the first '{' to
// the end is returned. In both cases the caller's JSON decoder reports the
// problem.
func extractJSON(s string) string {
	s = strings.TrimSpace(s)

	// Strip ```json ... ``` fences. end > 3 guarantees a distinct closing fence.
	if strings.HasPrefix(s, "```") {
		if end := strings.LastIndex(s, "```"); end > 3 {
			inner := strings.TrimPrefix(s[3:end], "json")
			s = strings.TrimSpace(inner)
		}
	}

	start := strings.IndexByte(s, '{')
	if start < 0 {
		return s
	}

	depth := 0
	inString := false // currently inside a "..." literal
	escaped := false  // previous byte inside a string was a backslash
	for i := start; i < len(s); i++ {
		ch := s[i]
		if inString {
			switch {
			case escaped:
				escaped = false
			case ch == '\\':
				escaped = true
			case ch == '"':
				inString = false
			}
			continue
		}
		switch ch {
		case '"':
			inString = true
		case '{':
			depth++
		case '}':
			depth--
			if depth == 0 {
				return s[start : i+1]
			}
		}
	}
	return s[start:]
}
+174
View File
@@ -0,0 +1,174 @@
package slm
import (
"context"
"errors"
"testing"
"time"
"somegit.dev/Owlibou/gnoma/internal/message"
"somegit.dev/Owlibou/gnoma/internal/provider"
"somegit.dev/Owlibou/gnoma/internal/router"
"somegit.dev/Owlibou/gnoma/internal/stream"
)
// mockProvider is a scripted provider.Provider used by the classifier tests.
// Stream optionally waits for delay, then either fails with err or yields a
// single text delta containing text.
type mockProvider struct {
	text  string        // reply body delivered as one text delta
	delay time.Duration // artificial latency before Stream answers
	err   error         // when non-nil, returned by Stream instead of a stream
}

// Name returns the fixed provider name "mock".
func (m *mockProvider) Name() string {
	return "mock"
}

// DefaultModel returns the fixed model name "default".
func (m *mockProvider) DefaultModel() string {
	return "default"
}

// Models reports no models and no error.
func (m *mockProvider) Models(_ context.Context) ([]provider.ModelInfo, error) {
	return nil, nil
}

// Stream simulates a provider call. While waiting out delay it also watches
// ctx and returns ctx.Err() if the context finishes first.
func (m *mockProvider) Stream(ctx context.Context, _ provider.Request) (stream.Stream, error) {
	if m.delay > 0 {
		wait := time.NewTimer(m.delay)
		defer wait.Stop()
		select {
		case <-wait.C:
		case <-ctx.Done():
			return nil, ctx.Err()
		}
	}
	if m.err != nil {
		return nil, m.err
	}
	reply := stream.Event{Type: stream.EventTextDelta, Text: m.text}
	return &mockStream{events: []stream.Event{reply}}, nil
}
// mockStream replays a fixed slice of events through the stream.Stream
// iteration methods. idx counts how many times Next has been called.
type mockStream struct {
	events []stream.Event
	idx    int
}

// Next advances the cursor and reports whether an event is available.
func (s *mockStream) Next() bool {
	s.idx++
	return s.idx <= len(s.events)
}

// Current returns the event selected by the most recent Next call.
func (s *mockStream) Current() stream.Event {
	return s.events[s.idx-1]
}

// Err always reports success.
func (s *mockStream) Err() error {
	return nil
}

// Close is a no-op.
func (s *mockStream) Close() error {
	return nil
}
func TestClassifier_HappyPath(t *testing.T) {
p := &mockProvider{text: `{"task_type":"Debug","complexity":0.25,"requires_tools":false}`}
cls := NewClassifier(p, "default", nil)
task, err := cls.Classify(context.Background(), "fix the failing test", nil)
if err != nil {
t.Fatalf("Classify: %v", err)
}
if task.Type != router.TaskDebug {
t.Errorf("Type = %s, want Debug", task.Type)
}
if task.ComplexityScore != 0.25 {
t.Errorf("ComplexityScore = %v, want 0.25", task.ComplexityScore)
}
if task.RequiresTools != false {
t.Errorf("RequiresTools = true, want false")
}
}
func TestClassifier_BlendHeuristic(t *testing.T) {
// SLM returns one type; other Task fields should come from heuristic.
p := &mockProvider{text: `{"task_type":"Boilerplate","complexity":0.1,"requires_tools":false}`}
cls := NewClassifier(p, "default", nil)
task, err := cls.Classify(context.Background(), "scaffold a new HTTP handler", nil)
if err != nil {
t.Fatalf("Classify: %v", err)
}
if task.Type != router.TaskBoilerplate {
t.Errorf("Type = %s, want Boilerplate", task.Type)
}
// Priority must come from the heuristic baseline (PriorityNormal = 1, not zero).
if task.Priority < router.PriorityNormal {
t.Errorf("Priority = %v, want at least PriorityNormal from heuristic baseline", task.Priority)
}
}
func TestClassifier_FallbackOnBadJSON(t *testing.T) {
p := &mockProvider{text: "I cannot classify that."}
cls := NewClassifier(p, "default", nil)
// Should not error — falls back to heuristic.
task, err := cls.Classify(context.Background(), "write unit tests for the parser", nil)
if err != nil {
t.Fatalf("Classify should not error on bad JSON: %v", err)
}
// Heuristic would return UnitTest for "write unit tests".
if task.Type != router.TaskUnitTest {
t.Errorf("heuristic fallback: Type = %s, want UnitTest", task.Type)
}
}
// TestClassifier_FallbackOnProviderError verifies that a provider failure is
// absorbed and the heuristic classification is returned.
func TestClassifier_FallbackOnProviderError(t *testing.T) {
	failing := &mockProvider{err: errors.New("connection refused")}
	classifier := NewClassifier(failing, "default", nil)

	got, err := classifier.Classify(context.Background(), "explain how generics work", nil)
	if err != nil {
		t.Fatalf("Classify should not error on provider error: %v", err)
	}

	// Heuristic fallback: "explain" → TaskExplain
	if got.Type != router.TaskExplain {
		t.Errorf("heuristic fallback: Type = %s, want Explain", got.Type)
	}
}
// TestClassifier_FallbackOnTimeout verifies that a provider slower than the
// classifier's timeout triggers the heuristic fallback rather than an error.
func TestClassifier_FallbackOnTimeout(t *testing.T) {
	slow := &mockProvider{delay: 500 * time.Millisecond}
	classifier := NewClassifier(slow, "default", nil)
	// force timeout: the provider needs ten times longer than allowed
	classifier.timeout = 50 * time.Millisecond

	got, err := classifier.Classify(context.Background(), "debug the failing test", nil)
	if err != nil {
		t.Fatalf("Classify should not error on timeout: %v", err)
	}

	// Falls back to heuristic: "debug" → TaskDebug
	if got.Type != router.TaskDebug {
		t.Errorf("heuristic fallback: Type = %s, want Debug", got.Type)
	}
}
func TestClassifier_FenceStripping(t *testing.T) {
fenced := "```json\n{\"task_type\":\"Refactor\",\"complexity\":0.5,\"requires_tools\":true}\n```"
p := &mockProvider{text: fenced}
cls := NewClassifier(p, "default", nil)
task, err := cls.Classify(context.Background(), "refactor the auth middleware", nil)
if err != nil {
t.Fatalf("Classify: %v", err)
}
if task.Type != router.TaskRefactor {
t.Errorf("Type = %s, want Refactor", task.Type)
}
}
// TestClassifier_UnknownTaskType_FallsBackToHeuristic verifies how a parsable
// reply carrying an unrecognized task_type is handled: Classify must not
// error, the type resolves to TaskGeneration (ParseTaskType's fallback for
// unknown names), and the SLM-reported complexity is still applied.
func TestClassifier_UnknownTaskType_FallsBackToHeuristic(t *testing.T) {
	p := &mockProvider{text: `{"task_type":"FooBar","complexity":0.3,"requires_tools":false}`}
	cls := NewClassifier(p, "default", nil)

	task, err := cls.Classify(context.Background(), "implement a binary search function", nil)
	if err != nil {
		t.Fatalf("Classify: %v", err)
	}

	// The reply parsed successfully, so the SLM path is taken and the unknown
	// name is mapped by ParseTaskType rather than by the heuristic classifier.
	if task.Type != router.TaskGeneration {
		t.Errorf("Type = %s, want Generation for unknown task_type", task.Type)
	}
	if task.ComplexityScore != 0.3 {
		t.Errorf("ComplexityScore = %v, want 0.3", task.ComplexityScore)
	}
	if task.RequiresTools {
		t.Errorf("RequiresTools = true, want false")
	}
}
func TestClassifier_ContextPassedToHistory(t *testing.T) {
p := &mockProvider{text: `{"task_type":"Explain","complexity":0.2,"requires_tools":false}`}
cls := NewClassifier(p, "default", nil)
history := []message.Message{
{Role: message.RoleUser, Content: []message.Content{{Type: message.ContentText, Text: "prior"}}},
}
task, err := cls.Classify(context.Background(), "explain this code", history)
if err != nil {
t.Fatalf("Classify: %v", err)
}
if task.Type != router.TaskExplain {
t.Errorf("Type = %s, want Explain", task.Type)
}
}
+21
View File
@@ -17,6 +17,18 @@ import (
const pidFile = "llamafile.pid"
// DefaultDataDir returns the platform default SLM data directory.
// Follows XDG Base Directory Specification: $XDG_DATA_HOME/gnoma/slm,
// falling back to ~/.local/share/gnoma/slm.
//
// Per the specification, $XDG_DATA_HOME is honored only when it holds an
// absolute path; an empty or relative value is treated as unset.
func DefaultDataDir() string {
	base := os.Getenv("XDG_DATA_HOME")
	if !filepath.IsAbs(base) {
		// Best effort: the signature has no error return. If the home
		// directory cannot be determined, home is "" and the result is the
		// relative path .local/share/gnoma/slm under the working directory.
		home, _ := os.UserHomeDir()
		base = filepath.Join(home, ".local", "share")
	}
	return filepath.Join(base, "gnoma", "slm")
}
// Status describes the setup state of the SLM.
type Status int
@@ -180,6 +192,15 @@ func (m *Manager) BaseURL() string {
return fmt.Sprintf("http://127.0.0.1:%d", m.port)
}
// Manifest loads the manifest stored in the manager's data directory.
// It returns nil when the manifest cannot be read; the read error itself is
// intentionally not exposed to the caller.
func (m *Manager) Manifest() *Manifest {
	if manifest, err := readManifest(m.cfg.DataDir); err == nil {
		return manifest
	}
	return nil
}
func (m *Manager) pidPath() string {
return filepath.Join(m.cfg.DataDir, pidFile)
}
+22 -4
View File
@@ -19,6 +19,7 @@ import (
gnomacfg "somegit.dev/Owlibou/gnoma/internal/config"
"somegit.dev/Owlibou/gnoma/internal/elf"
"somegit.dev/Owlibou/gnoma/internal/skill"
"somegit.dev/Owlibou/gnoma/internal/slm"
"somegit.dev/Owlibou/gnoma/internal/engine"
"somegit.dev/Owlibou/gnoma/internal/message"
"somegit.dev/Owlibou/gnoma/internal/permission"
@@ -61,6 +62,7 @@ type Config struct {
Permissions *permission.Checker // for mode switching
Router *router.Router // for model listing
ElfManager *elf.Manager // for CancelAll on escape/quit
SLMManager *slm.Manager // nil = SLM not configured
PermCh chan bool // TUI → engine: y/n response
PermReqCh <-chan PermReqMsg // engine → TUI: tool requesting approval
ElfProgress <-chan elf.Progress // elf → TUI: structured progress updates
@@ -877,16 +879,32 @@ func (m Model) handleCommand(cmd string) (tea.Model, tea.Cmd) {
status := m.session.Status()
var b strings.Builder
b.WriteString("Current configuration:\n")
fmt.Fprintf(&b, " provider: %s\n", status.Provider)
fmt.Fprintf(&b, " model: %s\n", status.Model)
fmt.Fprintf(&b, " provider: %s\n", status.Provider)
fmt.Fprintf(&b, " model: %s\n", status.Model)
if m.config.Permissions != nil {
fmt.Fprintf(&b, " permission: %s\n", m.config.Permissions.Mode())
}
fmt.Fprintf(&b, " incognito: %v\n", m.incognito)
fmt.Fprintf(&b, " cwd: %s\n", m.cwd)
fmt.Fprintf(&b, " incognito: %v\n", m.incognito)
fmt.Fprintf(&b, " cwd: %s\n", m.cwd)
if m.gitBranch != "" {
fmt.Fprintf(&b, " git branch: %s\n", m.gitBranch)
}
if m.config.SLMManager != nil {
slmStat := m.config.SLMManager.Status()
switch slmStat {
case slm.StatusReady:
url := m.config.SLMManager.BaseURL()
if url != "" {
fmt.Fprintf(&b, " slm: ready (running at %s)\n", url)
} else {
b.WriteString(" slm: ready (not started)\n")
}
case slm.StatusMissing:
b.WriteString(" slm: file missing — run: gnoma slm setup\n")
default:
b.WriteString(" slm: not set up — run: gnoma slm setup\n")
}
}
b.WriteString("\nConfig files: ~/.config/gnoma/config.toml, .gnoma/config.toml")
b.WriteString("\nEdit: /config set <key> <value>")
m.messages = append(m.messages, chatMessage{role: "system", content: b.String()})