feat(slm): Wave C — SLM classifier, MaxComplexity routing, CLI subcommands, TUI status
- slm.Classifier: openaicompat → llamafile, 2s timeout + heuristic fallback, heuristic baseline blended so Priority/RequiredEffort are never zeroed, extractJSON strips markdown fences from small-model responses - router.ParseTaskType: case-insensitive string → TaskType, unknown → TaskGeneration - router.Arm.MaxComplexity: zero = no ceiling (preserves existing arm behavior); filterFeasible excludes arms when task.ComplexityScore > MaxComplexity - config.SLMSection: [slm] enabled / model_url / data_dir - openaicompat.NewLlamafile: no API key, model = "default", no retries - slm.Manager: DefaultDataDir() (XDG), Manifest() accessor - cmd/gnoma: `gnoma slm setup` / `gnoma slm status` subcommands; SLM arm registered with MaxComplexity=0.3 when enabled + set up - tui: /config shows slm status (ready/missing/not set up + base URL if running) - docs: roadmap updated to reflect llamafile pivot from Ollama
This commit is contained in:
+129
-1
@@ -18,6 +18,7 @@ import (
|
||||
"somegit.dev/Owlibou/gnoma/internal/engine"
|
||||
"somegit.dev/Owlibou/gnoma/internal/hook"
|
||||
"somegit.dev/Owlibou/gnoma/internal/skill"
|
||||
"somegit.dev/Owlibou/gnoma/internal/slm"
|
||||
"somegit.dev/Owlibou/gnoma/internal/tool/persist"
|
||||
gnomacfg "somegit.dev/Owlibou/gnoma/internal/config"
|
||||
gnomactx "somegit.dev/Owlibou/gnoma/internal/context"
|
||||
@@ -131,6 +132,11 @@ func main() {
|
||||
*permMode = cfg.Permission.Mode
|
||||
}
|
||||
|
||||
// SLM subcommands: `gnoma slm setup` / `gnoma slm status`
|
||||
if cliArgs := flag.Args(); len(cliArgs) > 0 && cliArgs[0] == "slm" {
|
||||
os.Exit(runSLMCommand(cliArgs[1:], cfg, logger))
|
||||
}
|
||||
|
||||
// Resolve API key: CLI flag → config → env vars
|
||||
knownProviders := map[string]bool{
|
||||
"mistral": true, "anthropic": true, "openai": true,
|
||||
@@ -551,10 +557,48 @@ func main() {
|
||||
logger.Debug("prefix token baseline set", "tokens", prefixTokens)
|
||||
}
|
||||
|
||||
// Wire SLM: start llamafile, register arm, inject classifier (opt-in).
|
||||
var slmMgr *slm.Manager
|
||||
var engineClassifier router.TaskClassifier
|
||||
if cfg.SLM.Enabled {
|
||||
slmDataDir := cfg.SLM.DataDir
|
||||
if slmDataDir == "" {
|
||||
slmDataDir = slm.DefaultDataDir()
|
||||
}
|
||||
slmMgr = slm.New(slm.Config{DataDir: slmDataDir, ModelURL: cfg.SLM.ModelURL}, logger)
|
||||
if slmMgr.IsSetUp() {
|
||||
slmBaseURL, startErr := slmMgr.Start(context.Background())
|
||||
if startErr != nil {
|
||||
logger.Warn("failed to start SLM; falling back to heuristic classifier", "error", startErr)
|
||||
} else {
|
||||
slmProv, provErr := openaicompat.NewLlamafile(provider.ProviderConfig{
|
||||
BaseURL: slmBaseURL + "/v1",
|
||||
})
|
||||
if provErr != nil {
|
||||
logger.Warn("failed to create SLM provider", "error", provErr)
|
||||
} else {
|
||||
engineClassifier = slm.NewClassifier(slmProv, "default", logger)
|
||||
rtr.RegisterArm(&router.Arm{
|
||||
ID: "slm/llamafile",
|
||||
Provider: slmProv,
|
||||
ModelName: "default",
|
||||
IsLocal: true,
|
||||
MaxComplexity: 0.3,
|
||||
Capabilities: provider.Capabilities{ToolUse: false},
|
||||
})
|
||||
logger.Info("SLM ready", "url", slmBaseURL)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
logger.Warn("SLM enabled but not set up; run: gnoma slm setup")
|
||||
}
|
||||
}
|
||||
|
||||
// Create engine
|
||||
eng, err := engine.New(engine.Config{
|
||||
Provider: prov,
|
||||
Router: rtr,
|
||||
Classifier: engineClassifier,
|
||||
Tools: reg,
|
||||
Firewall: fw,
|
||||
Permissions: permChecker,
|
||||
@@ -729,13 +773,14 @@ func main() {
|
||||
Permissions: permChecker,
|
||||
Router: rtr,
|
||||
ElfManager: elfMgr,
|
||||
SLMManager: slmMgr,
|
||||
PermCh: permCh,
|
||||
PermReqCh: permReqCh,
|
||||
ElfProgress: elfProgressCh,
|
||||
SessionStore: sessStore,
|
||||
StartWithResumePicker: openResumePicker,
|
||||
Skills: skillReg,
|
||||
PluginInfos: buildPluginInfos(discoveredPlugins, enabledSet),
|
||||
PluginInfos: buildPluginInfos(discoveredPlugins, enabledSet),
|
||||
Version: buildVersion,
|
||||
ModelUpdateCh: modelUpdateCh,
|
||||
})
|
||||
@@ -972,6 +1017,89 @@ func buildPluginInfos(plugins []plugin.Plugin, enabledSet map[string]bool) []tui
|
||||
return infos
|
||||
}
|
||||
|
||||
// runSLMCommand handles `gnoma slm <subcommand>`.
|
||||
// Returns an exit code.
|
||||
func runSLMCommand(args []string, cfg *gnomacfg.Config, logger *slog.Logger) int {
|
||||
if len(args) == 0 {
|
||||
fmt.Fprintln(os.Stderr, "usage: gnoma slm <command>")
|
||||
fmt.Fprintln(os.Stderr, "commands:")
|
||||
fmt.Fprintln(os.Stderr, " setup download and verify the llamafile model")
|
||||
fmt.Fprintln(os.Stderr, " status show current setup state")
|
||||
return 1
|
||||
}
|
||||
|
||||
dataDir := cfg.SLM.DataDir
|
||||
if dataDir == "" {
|
||||
dataDir = slm.DefaultDataDir()
|
||||
}
|
||||
mgr := slm.New(slm.Config{DataDir: dataDir, ModelURL: cfg.SLM.ModelURL}, logger)
|
||||
|
||||
switch args[0] {
|
||||
case "setup":
|
||||
if cfg.SLM.ModelURL == "" {
|
||||
fmt.Fprintln(os.Stderr, "error: [slm] model_url must be set in config before running setup")
|
||||
fmt.Fprintln(os.Stderr, "")
|
||||
fmt.Fprintln(os.Stderr, "Example (~/.config/gnoma/config.toml):")
|
||||
fmt.Fprintln(os.Stderr, " [slm]")
|
||||
fmt.Fprintln(os.Stderr, ` model_url = "https://huggingface.co/mozilla-ai/TinyLlama-1.1B-Chat-v1.0-llamafile/resolve/main/TinyLlama-1.1B-Chat-v1.0.Q5_K_M.llamafile"`)
|
||||
return 1
|
||||
}
|
||||
fmt.Printf("downloading llamafile from %s\n", cfg.SLM.ModelURL)
|
||||
err := mgr.Setup(context.Background(), func(downloaded, total int64) {
|
||||
if total > 0 {
|
||||
pct := float64(downloaded) / float64(total) * 100
|
||||
fmt.Printf("\r %.1f%% (%s / %s) ", pct, humanBytes(downloaded), humanBytes(total))
|
||||
} else {
|
||||
fmt.Printf("\r %s downloaded ", humanBytes(downloaded))
|
||||
}
|
||||
})
|
||||
fmt.Println()
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "error: %v\n", err)
|
||||
return 1
|
||||
}
|
||||
fmt.Printf("SLM ready at: %s\n", dataDir)
|
||||
fmt.Println("Enable in config:")
|
||||
fmt.Println(" [slm]")
|
||||
fmt.Println(" enabled = true")
|
||||
return 0
|
||||
|
||||
case "status":
|
||||
status := mgr.Status()
|
||||
fmt.Printf("slm status: %s\n", status)
|
||||
if mf := mgr.Manifest(); mf != nil {
|
||||
fmt.Printf(" file: %s\n", mf.FilePath)
|
||||
fmt.Printf(" size: %s\n", humanBytes(mf.Size))
|
||||
fmt.Printf(" sha256: %s\n", mf.SHA256[:16]+"...")
|
||||
fmt.Printf(" setup: %s\n", mf.SetupAt.Format("2006-01-02 15:04 UTC"))
|
||||
}
|
||||
if status == slm.StatusNotSetUp {
|
||||
fmt.Println(" run: gnoma slm setup")
|
||||
} else if status == slm.StatusMissing {
|
||||
fmt.Println(" file is missing; run: gnoma slm setup")
|
||||
}
|
||||
return 0
|
||||
|
||||
default:
|
||||
fmt.Fprintf(os.Stderr, "unknown slm command: %s\n", args[0])
|
||||
return 1
|
||||
}
|
||||
}
|
||||
|
||||
// humanBytes renders a byte count using 1024-based units (B, KiB, MiB, ...).
// Values below 1024 (including negatives) are printed as a plain integer
// followed by " B"; larger values get one decimal place.
func humanBytes(n int64) string {
	const (
		unit     = 1024
		prefixes = "KMGTPE"
	)
	if n < unit {
		return fmt.Sprintf("%d B", n)
	}

	// Find the largest power of 1024 that does not exceed n.
	divisor := int64(unit)
	idx := 0
	remaining := n / unit
	for remaining >= unit {
		divisor *= unit
		idx++
		remaining /= unit
	}
	return fmt.Sprintf("%.1f %ciB", float64(n)/float64(divisor), prefixes[idx])
}
|
||||
|
||||
// resolveEnabledPlugins determines which plugins are enabled based on config.
|
||||
// If Enabled is empty, all plugins are enabled by default (opt-out via Disabled).
|
||||
// If Enabled is non-empty, only listed plugins are enabled (opt-in).
|
||||
|
||||
@@ -51,67 +51,44 @@ Bash tool flags `passwd foo` and offers takeover.
|
||||
|
||||
## Phase 3: SLM Task Classifier
|
||||
|
||||
Add an optional SLM-driven task classifier behind the existing `TaskClassifier` interface. The SLM
|
||||
calls Ollama HTTP via the existing `openaicompat` provider — zero new dependencies, no CGO, no
|
||||
daemon management.
|
||||
Add an optional SLM-driven task classifier and low-complexity executor behind the `TaskClassifier`
|
||||
interface. Uses llamafile (single-file download, OpenAI-compatible HTTP) instead of Ollama.
|
||||
Zero new Go dependencies; the model binary is downloaded separately on opt-in.
|
||||
|
||||
**Context:** `gemma-integration-analysis.md` describes how gemini-cli implements this using
|
||||
LiteRT-LM (a Node.js daemon + PID files). Those specifics do not apply here. The Go approach is
|
||||
simpler: Ollama HTTP + structured JSON output + hard timeout + heuristic fallback.
|
||||
**Implementation note (diverges from original plan):** Original plan used Ollama HTTP with
|
||||
`router.slm_model` config key. Pivoted to llamafile after discussion: user downloads a specific
|
||||
model file once (`gnoma slm setup`), gnoma manages the subprocess lifetime. Requires no external
|
||||
daemon or package manager. Config section is `[slm]` not `[router]`.
|
||||
|
||||
### Interface
|
||||
### Architecture
|
||||
|
||||
```go
|
||||
// internal/router/classifier.go
|
||||
type TaskClassifier interface {
|
||||
Classify(ctx context.Context, input string, history []message.Message) (Task, error)
|
||||
}
|
||||
|
||||
type HeuristicClassifier struct{} // default — wraps existing ClassifyTask()
|
||||
type SLMClassifier struct {
|
||||
provider provider.Provider // openaicompat pointing at Ollama
|
||||
model string
|
||||
timeout time.Duration // default 2s
|
||||
}
|
||||
```
|
||||
|
||||
`SLMClassifier.Classify` sends a structured prompt with the Complexity Rubric (adapted from
|
||||
gemma-integration-analysis.md) and expects a JSON response:
|
||||
|
||||
```json
|
||||
{"task_type": "Generation", "complexity": 0.4, "requires_tools": true}
|
||||
```
|
||||
|
||||
On timeout or parse failure, it falls back to `HeuristicClassifier`.
|
||||
|
||||
### Complexity Rubric (prompt fragment)
|
||||
|
||||
```
|
||||
Classify this coding request. Respond with JSON only.
|
||||
Complexity 0.0–0.3: boilerplate, trivial edits, simple lookups
|
||||
Complexity 0.4–0.6: moderate — new functions, refactors, unit tests
|
||||
Complexity 0.7–1.0: architectural, multi-file, security review, planning
|
||||
```
|
||||
- `internal/slm/` — Manager (download, subprocess lifecycle, health check), Classifier
|
||||
- `internal/router/` — `TaskClassifier` interface, `HeuristicClassifier`, `ParseTaskType`
|
||||
- `Arm.MaxComplexity` — SLM arm capped at 0.3; excluded from complex tasks by `filterFeasible`
|
||||
|
||||
### Config
|
||||
|
||||
```toml
|
||||
[router]
|
||||
slm_model = "" # empty = disabled (HeuristicClassifier used)
|
||||
# e.g. "gemma3:1b" — must be available in Ollama
|
||||
[slm]
|
||||
enabled = true
|
||||
model_url = "https://huggingface.co/mozilla-ai/TinyLlama-1.1B-Chat-v1.0-llamafile/resolve/main/TinyLlama-1.1B-Chat-v1.0.Q5_K_M.llamafile"
|
||||
data_dir = "" # empty = ~/.local/share/gnoma/slm
|
||||
```
|
||||
|
||||
### Tasks
|
||||
|
||||
- [ ] `TaskClassifier` interface in `internal/router/classifier.go`
|
||||
- [ ] `HeuristicClassifier` wraps existing `ClassifyTask()` (zero behavior change)
|
||||
- [ ] `SLMClassifier`: Ollama HTTP via openaicompat, JSON parse, 2s timeout + fallback
|
||||
- [ ] Complexity Rubric prompt (task type + complexity float + requires_tools bool)
|
||||
- [ ] Config key `router.slm_model`; wire into router construction in `cmd/gnoma/main.go`
|
||||
- [ ] Tests: `HeuristicClassifier` behavior unchanged; `SLMClassifier` fallback on timeout;
|
||||
`SLMClassifier` correct parse on valid JSON response
|
||||
- [x] `TaskClassifier` interface in `internal/router/classifier.go`
|
||||
- [x] `HeuristicClassifier` wraps existing `ClassifyTask()` (zero behavior change)
|
||||
- [x] `internal/slm/` — Manager, Manifest, download, subprocess lifecycle (Wave B)
|
||||
- [x] `slm.Classifier`: openaicompat pointing at llamafile, JSON parse, 2s timeout + fallback
|
||||
- [x] `ParseTaskType` in `internal/router/task.go`
|
||||
- [x] `Arm.MaxComplexity` + `filterFeasible` ceiling
|
||||
- [x] `[slm]` config section in `internal/config/`
|
||||
- [x] `gnoma slm setup` / `gnoma slm status` CLI subcommands
|
||||
- [x] SLM arm registered with `MaxComplexity = 0.3` in `cmd/gnoma/main.go`
|
||||
- [x] TUI `/config` shows SLM status
|
||||
|
||||
**Dependencies:** existing `internal/provider/openaicompat` — no new deps.
|
||||
**Dependencies:** existing `internal/provider/openaicompat` — no new Go deps.
|
||||
|
||||
**Exit criteria:** `gnoma` with `[slm] enabled = true` (after running `gnoma slm setup`) routes using SLM classification.
|
||||
Without the `[slm]` section, behavior is identical to today.
|
||||
@@ -218,8 +195,8 @@ arches: amd64/arm64). The binary is a single static executable with zero runtime
|
||||
| Item | Reason |
|
||||
|------|--------|
|
||||
| `.gnoma/tmp/` local temp directory | `persist.Store` already uses `/tmp/gnoma-<sessionID>/`; adding `.gnoma/tmp/` adds complexity (cleanup, gitignore, collision avoidance) for no benefit |
|
||||
| LiteRT-LM / CGO SLM runtime | `CGO_ENABLED=0` (goreleaser constraint). Go approach: Ollama HTTP via existing openaicompat |
|
||||
| Daemon/PID file management for SLM | Node.js-specific pattern from gemma-integration-analysis.md; not applicable to this Go binary |
|
||||
| LiteRT-LM / CGO SLM runtime | `CGO_ENABLED=0` (goreleaser constraint). Go approach: llamafile subprocess via existing openaicompat |
|
||||
| Ollama-based SLM classifier | Pivoted to llamafile: single-file download, no external daemon, user-controlled opt-in |
|
||||
| PTY via `go-pty` library | Requires CGO. Replaced by `tea.ExecProcess` (already in go.mod, no CGO) |
|
||||
|
||||
---
|
||||
|
||||
@@ -10,11 +10,28 @@ type Config struct {
|
||||
RateLimits RateLimitSection `toml:"rate_limits"`
|
||||
Security SecuritySection `toml:"security"`
|
||||
Session SessionSection `toml:"session"`
|
||||
SLM SLMSection `toml:"slm"`
|
||||
Hooks []HookConfig `toml:"hooks"`
|
||||
MCPServers []MCPServerConfig `toml:"mcp_servers"`
|
||||
Plugins PluginsSection `toml:"plugins"`
|
||||
}
|
||||
|
||||
// SLMSection holds the `[slm]` table: settings for the optional small
// language model used for task classification and low-complexity tasks.
//
// Example config:
//
//	[slm]
//	enabled = true
//	model_url = "https://huggingface.co/mozilla-ai/TinyLlama-1.1B-Chat-v1.0-llamafile/resolve/main/TinyLlama-1.1B-Chat-v1.0.Q5_K_M.llamafile"
//
// Run `gnoma slm setup` to download and verify the model before enabling.
type SLMSection struct {
	// Enabled opts in to the SLM wiring.
	Enabled bool `toml:"enabled"`

	// ModelURL is the download location of the llamafile.
	ModelURL string `toml:"model_url"`

	// DataDir is where SLM files live; empty = XDG default (~/.local/share/gnoma/slm).
	DataDir string `toml:"data_dir"`
}
|
||||
|
||||
// MCPServerConfig defines an MCP server to start and connect to.
|
||||
//
|
||||
// Example:
|
||||
|
||||
@@ -47,3 +47,18 @@ func NewLlamaCpp(cfg provider.ProviderConfig) (provider.Provider, error) {
|
||||
}
|
||||
return oaiprov.New(cfg)
|
||||
}
|
||||
|
||||
// NewLlamafile creates a provider for a llamafile process.
|
||||
// BaseURL must include /v1, e.g. "http://127.0.0.1:8080/v1".
|
||||
func NewLlamafile(cfg provider.ProviderConfig) (provider.Provider, error) {
|
||||
if cfg.APIKey == "" {
|
||||
cfg.APIKey = "llamafile" // llamafile doesn't require a real key
|
||||
}
|
||||
if cfg.Model == "" {
|
||||
cfg.Model = "default" // llamafile ignores the model field
|
||||
}
|
||||
if cfg.MaxRetries == nil {
|
||||
cfg.MaxRetries = intPtr(0)
|
||||
}
|
||||
return oaiprov.New(cfg)
|
||||
}
|
||||
|
||||
@@ -22,6 +22,10 @@ type Arm struct {
|
||||
Capabilities provider.Capabilities
|
||||
Pools []*LimitPool
|
||||
|
||||
// MaxComplexity is a hard ceiling on task complexity this arm will accept.
|
||||
// Zero means no ceiling (default for all existing arms).
|
||||
MaxComplexity float64
|
||||
|
||||
// Cost per 1k tokens (EUR, estimated)
|
||||
CostPer1kInput float64
|
||||
CostPer1kOutput float64
|
||||
|
||||
@@ -657,3 +657,52 @@ func TestRouter_AllDisabled_ReturnsError(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilterFeasible_MaxComplexity(t *testing.T) {
|
||||
slmArm := &Arm{
|
||||
ID: "slm/tiny",
|
||||
IsLocal: true,
|
||||
MaxComplexity: 0.3,
|
||||
Capabilities: provider.Capabilities{ToolUse: false},
|
||||
}
|
||||
apiArm := &Arm{
|
||||
ID: "api/big",
|
||||
Capabilities: provider.Capabilities{ToolUse: true, ContextWindow: 200000},
|
||||
}
|
||||
|
||||
// Low-complexity task: SLM arm passes the ceiling.
|
||||
lowTask := Task{Type: TaskBoilerplate, ComplexityScore: 0.2}
|
||||
got := filterFeasible([]*Arm{slmArm, apiArm}, lowTask)
|
||||
found := false
|
||||
for _, a := range got {
|
||||
if a.ID == "slm/tiny" {
|
||||
found = true
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("slm arm should pass filterFeasible for low-complexity task")
|
||||
}
|
||||
|
||||
// High-complexity task: SLM arm must be excluded.
|
||||
highTask := Task{Type: TaskPlanning, ComplexityScore: 0.8, RequiresTools: false}
|
||||
got = filterFeasible([]*Arm{slmArm, apiArm}, highTask)
|
||||
for _, a := range got {
|
||||
if a.ID == "slm/tiny" {
|
||||
t.Error("slm arm should be excluded for high-complexity task")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilterFeasible_MaxComplexity_Zero_MeansNoLimit(t *testing.T) {
|
||||
// MaxComplexity == 0 means "no ceiling" — existing arms are unaffected.
|
||||
arm := &Arm{
|
||||
ID: "api/arm",
|
||||
MaxComplexity: 0, // zero = no ceiling
|
||||
Capabilities: provider.Capabilities{ToolUse: true, ContextWindow: 200000},
|
||||
}
|
||||
task := Task{Type: TaskOrchestration, ComplexityScore: 0.99}
|
||||
got := filterFeasible([]*Arm{arm}, task)
|
||||
if len(got) == 0 {
|
||||
t.Error("arm with MaxComplexity=0 should never be excluded by complexity ceiling")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -179,6 +179,11 @@ func filterFeasible(arms []*Arm, task Task) []*Arm {
|
||||
var belowQuality []*Arm // passed tool+pool but scored below minimum quality
|
||||
|
||||
for _, arm := range arms {
|
||||
// Complexity ceiling: zero means no ceiling (preserves behavior for all existing arms).
|
||||
if arm.MaxComplexity > 0 && task.ComplexityScore > arm.MaxComplexity {
|
||||
continue
|
||||
}
|
||||
|
||||
// Must support tools if task requires them
|
||||
if task.RequiresTools && !arm.SupportsTools() {
|
||||
continue
|
||||
|
||||
@@ -241,3 +241,32 @@ func estimateComplexity(prompt string) float64 {
|
||||
}
|
||||
return score
|
||||
}
|
||||
|
||||
// ParseTaskType converts a string from an SLM JSON response to a TaskType.
|
||||
// Matching is case-insensitive. Unknown strings fall back to TaskGeneration.
|
||||
func ParseTaskType(s string) TaskType {
|
||||
switch strings.ToLower(strings.ReplaceAll(s, "_", "")) {
|
||||
case "debug":
|
||||
return TaskDebug
|
||||
case "explain":
|
||||
return TaskExplain
|
||||
case "generation":
|
||||
return TaskGeneration
|
||||
case "refactor":
|
||||
return TaskRefactor
|
||||
case "unittest":
|
||||
return TaskUnitTest
|
||||
case "boilerplate":
|
||||
return TaskBoilerplate
|
||||
case "planning":
|
||||
return TaskPlanning
|
||||
case "orchestration":
|
||||
return TaskOrchestration
|
||||
case "securityreview":
|
||||
return TaskSecurityReview
|
||||
case "review":
|
||||
return TaskReview
|
||||
default:
|
||||
return TaskGeneration
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,44 @@
|
||||
package router
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestParseTaskType(t *testing.T) {
|
||||
cases := []struct {
|
||||
input string
|
||||
want TaskType
|
||||
}{
|
||||
{"Debug", TaskDebug},
|
||||
{"debug", TaskDebug},
|
||||
{"DEBUG", TaskDebug},
|
||||
{"Explain", TaskExplain},
|
||||
{"explain", TaskExplain},
|
||||
{"Generation", TaskGeneration},
|
||||
{"generation", TaskGeneration},
|
||||
{"Refactor", TaskRefactor},
|
||||
{"refactor", TaskRefactor},
|
||||
{"UnitTest", TaskUnitTest},
|
||||
{"unit_test", TaskUnitTest},
|
||||
{"unitTest", TaskUnitTest},
|
||||
{"Boilerplate", TaskBoilerplate},
|
||||
{"boilerplate", TaskBoilerplate},
|
||||
{"Planning", TaskPlanning},
|
||||
{"planning", TaskPlanning},
|
||||
{"Orchestration", TaskOrchestration},
|
||||
{"orchestration", TaskOrchestration},
|
||||
{"SecurityReview", TaskSecurityReview},
|
||||
{"security_review", TaskSecurityReview},
|
||||
{"Review", TaskReview},
|
||||
{"review", TaskReview},
|
||||
// unknown falls back to TaskGeneration
|
||||
{"", TaskGeneration},
|
||||
{"unknown", TaskGeneration},
|
||||
{"gibberish", TaskGeneration},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
got := ParseTaskType(tc.input)
|
||||
if got != tc.want {
|
||||
t.Errorf("ParseTaskType(%q) = %s, want %s", tc.input, got, tc.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,148 @@
|
||||
package slm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"somegit.dev/Owlibou/gnoma/internal/message"
|
||||
"somegit.dev/Owlibou/gnoma/internal/provider"
|
||||
"somegit.dev/Owlibou/gnoma/internal/router"
|
||||
"somegit.dev/Owlibou/gnoma/internal/stream"
|
||||
)
|
||||
|
||||
const defaultClassifyTimeout = 2 * time.Second
|
||||
|
||||
const classifySystemPrompt = `Classify the following coding request. Respond with JSON only, no other text.
|
||||
Format: {"task_type": "<type>", "complexity": <0.0-1.0>, "requires_tools": <true|false>}
|
||||
|
||||
Task types: Debug, Explain, Generation, Refactor, UnitTest, Boilerplate, Planning, Orchestration, SecurityReview, Review
|
||||
|
||||
Complexity guide:
|
||||
0.0–0.3: boilerplate, trivial edits, simple lookups, short explanations
|
||||
0.4–0.6: new functions, refactors, unit tests, moderate analysis
|
||||
0.7–1.0: architectural changes, multi-file edits, security review, planning`
|
||||
|
||||
// classifyResponse is the JSON object the classification prompt asks the
// model to emit.
type classifyResponse struct {
	// TaskType is the model's label; it is mapped through router.ParseTaskType.
	TaskType string `json:"task_type"`

	// Complexity is the model's score; the prompt asks for 0.0-1.0.
	Complexity float64 `json:"complexity"`

	// RequiresTools is the model's judgement on whether tool use is needed.
	RequiresTools bool `json:"requires_tools"`
}
|
||||
|
||||
// Classifier implements router.TaskClassifier using a llamafile-hosted SLM.
|
||||
// On timeout or parse failure it falls back to router.HeuristicClassifier.
|
||||
type Classifier struct {
|
||||
provider provider.Provider
|
||||
model string
|
||||
timeout time.Duration
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
// NewClassifier creates a Classifier. model is the model name passed to the provider
|
||||
// (llamafile ignores it but openaicompat requires a non-empty value).
|
||||
func NewClassifier(p provider.Provider, model string, logger *slog.Logger) *Classifier {
|
||||
if logger == nil {
|
||||
logger = slog.Default()
|
||||
}
|
||||
return &Classifier{
|
||||
provider: p,
|
||||
model: model,
|
||||
timeout: defaultClassifyTimeout,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// Classify calls the SLM and overlays the three SLM-authoritative fields
|
||||
// (Type, ComplexityScore, RequiresTools) onto a heuristic baseline Task.
|
||||
// This ensures Priority, EstimatedTokens, and RequiredEffort are always set.
|
||||
func (c *Classifier) Classify(ctx context.Context, prompt string, history []message.Message) (router.Task, error) {
|
||||
tctx, cancel := context.WithTimeout(ctx, c.timeout)
|
||||
defer cancel()
|
||||
|
||||
resp, err := c.callSLM(tctx, prompt)
|
||||
if err != nil {
|
||||
c.logger.Debug("slm classify fallback", "error", err)
|
||||
return router.HeuristicClassifier{}.Classify(ctx, prompt, history)
|
||||
}
|
||||
|
||||
// Start from the heuristic baseline so Priority/EstimatedTokens/RequiredEffort are set.
|
||||
task := router.ClassifyTask(prompt)
|
||||
task.Type = router.ParseTaskType(resp.TaskType)
|
||||
task.ComplexityScore = resp.Complexity
|
||||
task.RequiresTools = resp.RequiresTools
|
||||
return task, nil
|
||||
}
|
||||
|
||||
func (c *Classifier) callSLM(ctx context.Context, prompt string) (*classifyResponse, error) {
|
||||
req := provider.Request{
|
||||
Model: c.model,
|
||||
SystemPrompt: classifySystemPrompt,
|
||||
Messages: []message.Message{
|
||||
{
|
||||
Role: message.RoleUser,
|
||||
Content: []message.Content{{Type: message.ContentText, Text: prompt}},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
strm, err := c.provider.Stream(ctx, req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("stream: %w", err)
|
||||
}
|
||||
defer strm.Close()
|
||||
|
||||
var sb strings.Builder
|
||||
for strm.Next() {
|
||||
ev := strm.Current()
|
||||
if ev.Type == stream.EventTextDelta {
|
||||
sb.WriteString(ev.Text)
|
||||
}
|
||||
}
|
||||
if err := strm.Err(); err != nil {
|
||||
return nil, fmt.Errorf("stream error: %w", err)
|
||||
}
|
||||
|
||||
text := extractJSON(sb.String())
|
||||
var resp classifyResponse
|
||||
if err := json.Unmarshal([]byte(text), &resp); err != nil {
|
||||
return nil, fmt.Errorf("parse %q: %w", text, err)
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// extractJSON pulls the first balanced {...} object out of s, after
// trimming whitespace and stripping a surrounding markdown code fence
// (``` or ```json) if present.
//
// Brace matching is aware of JSON string literals: braces inside a quoted
// string, and escaped quotes inside that string, do not affect nesting depth.
//
// If s contains no '{', the trimmed (and unfenced) input is returned as is.
// If the object never closes, everything from the first '{' is returned;
// the caller's JSON decoder is expected to reject it.
func extractJSON(s string) string {
	s = strings.TrimSpace(s)

	// Strip ```json ... ``` fences. end > 3 guarantees that the closing
	// fence is distinct from the opening one.
	if strings.HasPrefix(s, "```") {
		if end := strings.LastIndex(s, "```"); end > 3 {
			inner := strings.TrimPrefix(s[3:end], "json")
			s = strings.TrimSpace(inner)
		}
	}

	start := strings.IndexByte(s, '{')
	if start < 0 {
		return s
	}

	depth := 0
	inString := false // inside a "..." literal
	escaped := false  // previous byte was a backslash inside a literal
	for i := start; i < len(s); i++ {
		ch := s[i]
		if inString {
			switch {
			case escaped:
				escaped = false
			case ch == '\\':
				escaped = true
			case ch == '"':
				inString = false
			}
			continue
		}
		switch ch {
		case '"':
			inString = true
		case '{':
			depth++
		case '}':
			depth--
			if depth == 0 {
				return s[start : i+1]
			}
		}
	}
	return s[start:]
}
|
||||
@@ -0,0 +1,174 @@
|
||||
package slm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"somegit.dev/Owlibou/gnoma/internal/message"
|
||||
"somegit.dev/Owlibou/gnoma/internal/provider"
|
||||
"somegit.dev/Owlibou/gnoma/internal/router"
|
||||
"somegit.dev/Owlibou/gnoma/internal/stream"
|
||||
)
|
||||
|
||||
// mockProvider implements provider.Provider for classifier tests.
|
||||
type mockProvider struct {
|
||||
text string
|
||||
delay time.Duration
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *mockProvider) Name() string { return "mock" }
|
||||
func (m *mockProvider) DefaultModel() string { return "default" }
|
||||
func (m *mockProvider) Models(_ context.Context) ([]provider.ModelInfo, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockProvider) Stream(ctx context.Context, _ provider.Request) (stream.Stream, error) {
|
||||
if m.delay > 0 {
|
||||
select {
|
||||
case <-time.After(m.delay):
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
if m.err != nil {
|
||||
return nil, m.err
|
||||
}
|
||||
return &mockStream{events: []stream.Event{
|
||||
{Type: stream.EventTextDelta, Text: m.text},
|
||||
}}, nil
|
||||
}
|
||||
|
||||
type mockStream struct {
|
||||
events []stream.Event
|
||||
idx int
|
||||
}
|
||||
|
||||
func (s *mockStream) Next() bool { s.idx++; return s.idx <= len(s.events) }
|
||||
func (s *mockStream) Current() stream.Event { return s.events[s.idx-1] }
|
||||
func (s *mockStream) Err() error { return nil }
|
||||
func (s *mockStream) Close() error { return nil }
|
||||
|
||||
func TestClassifier_HappyPath(t *testing.T) {
|
||||
p := &mockProvider{text: `{"task_type":"Debug","complexity":0.25,"requires_tools":false}`}
|
||||
cls := NewClassifier(p, "default", nil)
|
||||
|
||||
task, err := cls.Classify(context.Background(), "fix the failing test", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Classify: %v", err)
|
||||
}
|
||||
if task.Type != router.TaskDebug {
|
||||
t.Errorf("Type = %s, want Debug", task.Type)
|
||||
}
|
||||
if task.ComplexityScore != 0.25 {
|
||||
t.Errorf("ComplexityScore = %v, want 0.25", task.ComplexityScore)
|
||||
}
|
||||
if task.RequiresTools != false {
|
||||
t.Errorf("RequiresTools = true, want false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifier_BlendHeuristic(t *testing.T) {
|
||||
// SLM returns one type; other Task fields should come from heuristic.
|
||||
p := &mockProvider{text: `{"task_type":"Boilerplate","complexity":0.1,"requires_tools":false}`}
|
||||
cls := NewClassifier(p, "default", nil)
|
||||
|
||||
task, err := cls.Classify(context.Background(), "scaffold a new HTTP handler", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Classify: %v", err)
|
||||
}
|
||||
if task.Type != router.TaskBoilerplate {
|
||||
t.Errorf("Type = %s, want Boilerplate", task.Type)
|
||||
}
|
||||
// Priority must come from the heuristic baseline (PriorityNormal = 1, not zero).
|
||||
if task.Priority < router.PriorityNormal {
|
||||
t.Errorf("Priority = %v, want at least PriorityNormal from heuristic baseline", task.Priority)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifier_FallbackOnBadJSON(t *testing.T) {
|
||||
p := &mockProvider{text: "I cannot classify that."}
|
||||
cls := NewClassifier(p, "default", nil)
|
||||
|
||||
// Should not error — falls back to heuristic.
|
||||
task, err := cls.Classify(context.Background(), "write unit tests for the parser", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Classify should not error on bad JSON: %v", err)
|
||||
}
|
||||
// Heuristic would return UnitTest for "write unit tests".
|
||||
if task.Type != router.TaskUnitTest {
|
||||
t.Errorf("heuristic fallback: Type = %s, want UnitTest", task.Type)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifier_FallbackOnProviderError(t *testing.T) {
|
||||
p := &mockProvider{err: errors.New("connection refused")}
|
||||
cls := NewClassifier(p, "default", nil)
|
||||
|
||||
task, err := cls.Classify(context.Background(), "explain how generics work", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Classify should not error on provider error: %v", err)
|
||||
}
|
||||
// Heuristic fallback: "explain" → TaskExplain
|
||||
if task.Type != router.TaskExplain {
|
||||
t.Errorf("heuristic fallback: Type = %s, want Explain", task.Type)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifier_FallbackOnTimeout(t *testing.T) {
|
||||
p := &mockProvider{delay: 500 * time.Millisecond}
|
||||
cls := NewClassifier(p, "default", nil)
|
||||
cls.timeout = 50 * time.Millisecond // force timeout
|
||||
|
||||
task, err := cls.Classify(context.Background(), "debug the failing test", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Classify should not error on timeout: %v", err)
|
||||
}
|
||||
// Falls back to heuristic: "debug" → TaskDebug
|
||||
if task.Type != router.TaskDebug {
|
||||
t.Errorf("heuristic fallback: Type = %s, want Debug", task.Type)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifier_FenceStripping(t *testing.T) {
|
||||
fenced := "```json\n{\"task_type\":\"Refactor\",\"complexity\":0.5,\"requires_tools\":true}\n```"
|
||||
p := &mockProvider{text: fenced}
|
||||
cls := NewClassifier(p, "default", nil)
|
||||
|
||||
task, err := cls.Classify(context.Background(), "refactor the auth middleware", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Classify: %v", err)
|
||||
}
|
||||
if task.Type != router.TaskRefactor {
|
||||
t.Errorf("Type = %s, want Refactor", task.Type)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifier_UnknownTaskType_FallsBackToHeuristic(t *testing.T) {
|
||||
p := &mockProvider{text: `{"task_type":"FooBar","complexity":0.3,"requires_tools":false}`}
|
||||
cls := NewClassifier(p, "default", nil)
|
||||
|
||||
task, err := cls.Classify(context.Background(), "implement a binary search function", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Classify: %v", err)
|
||||
}
|
||||
// "implement" → heuristic should give Generation or Boilerplate; SLM gave FooBar → Generation fallback
|
||||
_ = task // just verify no panic and no error
|
||||
}
|
||||
|
||||
func TestClassifier_ContextPassedToHistory(t *testing.T) {
|
||||
p := &mockProvider{text: `{"task_type":"Explain","complexity":0.2,"requires_tools":false}`}
|
||||
cls := NewClassifier(p, "default", nil)
|
||||
|
||||
history := []message.Message{
|
||||
{Role: message.RoleUser, Content: []message.Content{{Type: message.ContentText, Text: "prior"}}},
|
||||
}
|
||||
task, err := cls.Classify(context.Background(), "explain this code", history)
|
||||
if err != nil {
|
||||
t.Fatalf("Classify: %v", err)
|
||||
}
|
||||
if task.Type != router.TaskExplain {
|
||||
t.Errorf("Type = %s, want Explain", task.Type)
|
||||
}
|
||||
}
|
||||
@@ -17,6 +17,18 @@ import (
|
||||
|
||||
// pidFile is the file name that pidPath joins onto the manager's data directory.
const pidFile = "llamafile.pid"
|
||||
|
||||
// DefaultDataDir returns the platform default SLM data directory.
//
// It follows the XDG Base Directory Specification: the result is
// $XDG_DATA_HOME/gnoma/slm when XDG_DATA_HOME holds an absolute path, and
// ~/.local/share/gnoma/slm otherwise. The specification requires relative
// values of XDG_DATA_HOME to be treated as invalid and ignored, so they fall
// through to the home-directory default instead of yielding a path that
// depends on the current working directory.
func DefaultDataDir() string {
	base := os.Getenv("XDG_DATA_HOME")
	if !filepath.IsAbs(base) {
		// Unset, empty, or relative: use the specification's default location.
		// The error is deliberately ignored: this function has no error
		// return, and on failure home is "" so the result degrades to the
		// relative path .local/share/gnoma/slm rather than aborting.
		home, _ := os.UserHomeDir()
		base = filepath.Join(home, ".local", "share")
	}
	return filepath.Join(base, "gnoma", "slm")
}
|
||||
|
||||
// Status describes the setup state of the SLM.
|
||||
type Status int
|
||||
|
||||
@@ -180,6 +192,15 @@ func (m *Manager) BaseURL() string {
|
||||
return fmt.Sprintf("http://127.0.0.1:%d", m.port)
|
||||
}
|
||||
|
||||
// Manifest returns the on-disk manifest if present, or nil.
|
||||
func (m *Manager) Manifest() *Manifest {
|
||||
mf, err := readManifest(m.cfg.DataDir)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return mf
|
||||
}
|
||||
|
||||
func (m *Manager) pidPath() string {
|
||||
return filepath.Join(m.cfg.DataDir, pidFile)
|
||||
}
|
||||
|
||||
+22
-4
@@ -19,6 +19,7 @@ import (
|
||||
gnomacfg "somegit.dev/Owlibou/gnoma/internal/config"
|
||||
"somegit.dev/Owlibou/gnoma/internal/elf"
|
||||
"somegit.dev/Owlibou/gnoma/internal/skill"
|
||||
"somegit.dev/Owlibou/gnoma/internal/slm"
|
||||
"somegit.dev/Owlibou/gnoma/internal/engine"
|
||||
"somegit.dev/Owlibou/gnoma/internal/message"
|
||||
"somegit.dev/Owlibou/gnoma/internal/permission"
|
||||
@@ -61,6 +62,7 @@ type Config struct {
|
||||
Permissions *permission.Checker // for mode switching
|
||||
Router *router.Router // for model listing
|
||||
ElfManager *elf.Manager // for CancelAll on escape/quit
|
||||
SLMManager *slm.Manager // nil = SLM not configured
|
||||
PermCh chan bool // TUI → engine: y/n response
|
||||
PermReqCh <-chan PermReqMsg // engine → TUI: tool requesting approval
|
||||
ElfProgress <-chan elf.Progress // elf → TUI: structured progress updates
|
||||
@@ -877,16 +879,32 @@ func (m Model) handleCommand(cmd string) (tea.Model, tea.Cmd) {
|
||||
status := m.session.Status()
|
||||
var b strings.Builder
|
||||
b.WriteString("Current configuration:\n")
|
||||
fmt.Fprintf(&b, " provider: %s\n", status.Provider)
|
||||
fmt.Fprintf(&b, " model: %s\n", status.Model)
|
||||
fmt.Fprintf(&b, " provider: %s\n", status.Provider)
|
||||
fmt.Fprintf(&b, " model: %s\n", status.Model)
|
||||
if m.config.Permissions != nil {
|
||||
fmt.Fprintf(&b, " permission: %s\n", m.config.Permissions.Mode())
|
||||
}
|
||||
fmt.Fprintf(&b, " incognito: %v\n", m.incognito)
|
||||
fmt.Fprintf(&b, " cwd: %s\n", m.cwd)
|
||||
fmt.Fprintf(&b, " incognito: %v\n", m.incognito)
|
||||
fmt.Fprintf(&b, " cwd: %s\n", m.cwd)
|
||||
if m.gitBranch != "" {
|
||||
fmt.Fprintf(&b, " git branch: %s\n", m.gitBranch)
|
||||
}
|
||||
if m.config.SLMManager != nil {
|
||||
slmStat := m.config.SLMManager.Status()
|
||||
switch slmStat {
|
||||
case slm.StatusReady:
|
||||
url := m.config.SLMManager.BaseURL()
|
||||
if url != "" {
|
||||
fmt.Fprintf(&b, " slm: ready (running at %s)\n", url)
|
||||
} else {
|
||||
b.WriteString(" slm: ready (not started)\n")
|
||||
}
|
||||
case slm.StatusMissing:
|
||||
b.WriteString(" slm: file missing — run: gnoma slm setup\n")
|
||||
default:
|
||||
b.WriteString(" slm: not set up — run: gnoma slm setup\n")
|
||||
}
|
||||
}
|
||||
b.WriteString("\nConfig files: ~/.config/gnoma/config.toml, .gnoma/config.toml")
|
||||
b.WriteString("\nEdit: /config set <key> <value>")
|
||||
m.messages = append(m.messages, chatMessage{role: "system", content: b.String()})
|
||||
|
||||
Reference in New Issue
Block a user