Compare commits
19 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 49d80cf847 | |||
| ea1a5361e2 | |||
| 246997c4be | |||
| 0975bf7118 | |||
| afc31b0af4 | |||
| 1717f9f567 | |||
| f83ace7ad6 | |||
| 7491a36bb7 | |||
| bd41d76e32 | |||
| c5cc98ed8a | |||
| bc137182d4 | |||
| a2b7f8eb3f | |||
| d37cc2dad3 | |||
| e38cce5f1f | |||
| 12a6b83cc9 | |||
| 244ecd97e5 | |||
| 7d0e35b0f4 | |||
| 8d6e66533b | |||
| 69fda263f3 |
@@ -33,7 +33,14 @@ Thumbs.db
|
||||
# Session data
|
||||
.gnoma/sessions/
|
||||
|
||||
# Pasted-image artifacts. New images go to the user cache dir
|
||||
# (~/.cache/gnoma/pasted-images/); the pattern covers legacy
|
||||
# files written into .gnoma/ before that change.
|
||||
.gnoma/pasted_image_*
|
||||
|
||||
# Debug
|
||||
__debug_bin*
|
||||
.env
|
||||
.claude/
|
||||
log.txt
|
||||
codex_out.jsonl
|
||||
|
||||
@@ -95,7 +95,7 @@ learning); `/help` lists slash commands; `Esc` cancels an in-flight turn.
|
||||
| Mistral | `MISTRAL_API_KEY` | `mistral-large-latest` (Mistral Large 3) | `mistral-medium-3.5`, `magistral-medium-2509` |
|
||||
| Ollama (local) | — | `qwen3:8b` (override with `--model`) | any model on your Ollama instance |
|
||||
| llama.cpp (local) | — | reported by `/v1/models` | n/a |
|
||||
| Subprocess (`claude`, `gemini`, `agy` CLIs) | provider-specific | binary name | configurable via `[cli_agents]` |
|
||||
| Subprocess (`claude`, `gemini`, `agy`, `codex`, `vibe` CLIs) | provider-specific | binary name | configurable via `[cli_agents]` |
|
||||
|
||||
Override per-invocation:
|
||||
|
||||
|
||||
@@ -4,13 +4,37 @@ Active work, newest first.
|
||||
|
||||
## In flight
|
||||
|
||||
- **Distribution** — `.goreleaser.yml` is configured for
|
||||
`linux`/`darwin`/`windows` × `amd64`/`arm64`. Still pending: first
|
||||
tag + release pipeline trigger, optional Homebrew tap and Docker
|
||||
image, mirror release publishing to GitHub.
|
||||
- **Entropy FP reduction (post-SLM Phase F)** — F-1 (format-aware
|
||||
pre-extractor) shipped 2026-05-22: `[security].entropy_safelist`
|
||||
with `uuid`, `sha_hex`, `iso8601`, `url`; default empty so
|
||||
pre-F-1 behaviour is unchanged. F-2 (SLM-assisted classifier for
|
||||
ambiguous entropy hits) remains gated on F-1 FP-rate telemetry
|
||||
from real workloads plus ≥50 SLM observations. Surfaced from the
|
||||
r/ollama launch thread (2026-05-20); external validation from
|
||||
alterlab.io on the same tiered approach. See
|
||||
[`docs/superpowers/plans/2026-05-19-post-slm-unlock.md`](docs/superpowers/plans/2026-05-19-post-slm-unlock.md).
|
||||
- **Compound tools (post-SLM Phase E)** — held until ≥50 SLM
|
||||
observations inform which primitives are worth adding. See
|
||||
[`docs/superpowers/plans/2026-05-19-post-slm-unlock.md`](docs/superpowers/plans/2026-05-19-post-slm-unlock.md).
|
||||
- **Sensitive-content handling — unified policy.** Three input paths
|
||||
can introduce sensitive content into the context: pasted images
|
||||
(screenshots may contain secrets, API keys, PII), pasted text (often
|
||||
copied straight from a terminal with credentials), and tool-read
|
||||
files (`.env`, key files, etc.). Today these are handled
|
||||
inconsistently: incognito gates persistence but content still flows
|
||||
to providers; outgoing-scan firewall covers some patterns but is
|
||||
format-aware only for text. Need a single policy/UI: at-paste
|
||||
warning when the content matches sensitive heuristics, a
|
||||
consent-gated review step, and consistent treatment across the
|
||||
three paths. Cross-cuts with Phase F entropy work and the
|
||||
outgoing-scan firewall.
|
||||
- **Distribution — follow-ups.** v0.1.0 shipped (archives on
|
||||
github.com/VikingOwl91/gnoma/releases, multi-arch images on
|
||||
ghcr.io/vikingowl91/gnoma). Still optional: Homebrew tap,
|
||||
`curl | sh` installer script, signed checksums (cosign/sigstore),
|
||||
release note automation, Windows process-tree kill via
|
||||
golang.org/x/sys/windows job objects (currently `os.Process.Kill`
|
||||
only — see `internal/mcp/transport_windows.go`).
|
||||
|
||||
## Stable backlog (not in active phases)
|
||||
|
||||
@@ -30,6 +54,12 @@ Active work, newest first.
|
||||
|
||||
Completed initiatives, kept here as pointers to their plan files:
|
||||
|
||||
- **v0.1.0 release** — 2026-05-20. First tagged release. GoReleaser
|
||||
pipeline produces six static archives (linux/darwin/windows ×
|
||||
amd64/arm64) on the GitHub mirror plus multi-arch Docker images on
|
||||
GHCR. History was rewritten on the same day to migrate authorship to
|
||||
a noreply identity and strip co-author attribution.
|
||||
|
||||
- **Post-audit security hardening** — complete 2026-05-19. Three waves
|
||||
+ one ADR closed all 14 findings from the external review:
|
||||
- [Wave 1 — SafeProvider boundary](docs/superpowers/plans/2026-05-19-security-wave1-safeprovider.md)
|
||||
|
||||
@@ -529,6 +529,7 @@ func main() {
|
||||
ScanOutgoing: true,
|
||||
ScanToolResults: true,
|
||||
EntropyThreshold: entropyThreshold,
|
||||
EntropySafelist: cfg.Security.EntropySafelist,
|
||||
Logger: logger,
|
||||
})
|
||||
// Install into the ref so every SafeProvider wrapper sees scanning
|
||||
@@ -1018,6 +1019,7 @@ func main() {
|
||||
var switchTarget string
|
||||
|
||||
m := tui.New(sess, tui.Config{
|
||||
AppConfig: cfg,
|
||||
Firewall: fw,
|
||||
Engine: eng,
|
||||
Permissions: permChecker,
|
||||
|
||||
@@ -399,6 +399,136 @@ No tasks scoped until that trigger fires.
|
||||
|
||||
---
|
||||
|
||||
## Phase F: Entropy False-Positive Reduction
|
||||
|
||||
Surfaced from the r/ollama launch thread (2026-05-20). Commenter
|
||||
`SharpRule4025` suggested two layered improvements to the firewall's
|
||||
entropy detector; both compose with the existing scanner in
|
||||
`internal/security/scanner.go` without changing its model.
|
||||
|
||||
Empirically the current default already keeps known safe formats well
|
||||
under the 4.5 threshold (UUID4 measured at 3.54–3.72, SHA-256 hex at
|
||||
3.94, SHA-1 at 3.57–3.79), so this is FP-rate *refinement* rather
|
||||
than a correctness fix. The wins are for strict configs that lower
|
||||
the threshold, log-noise reduction in normal use, and a credible
|
||||
story for "we thought about the long tail."
|
||||
|
||||
Public commitment: see the OP reply on r/ollama (2026-05-20). The
|
||||
sequencing committed there is F-1 first (deterministic), F-2 second
|
||||
(SLM-assisted, design work needed on prompt-injection).
|
||||
|
||||
**External validation (2026-05-20).** `SharpRule4025` followed up
|
||||
with production experience from alterlab.io running a similar
|
||||
tiered approach on web-page extraction: deterministic parsers first
|
||||
to strip envelope structure, then targeted smaller models for the
|
||||
residual unstructured text. Reported token-usage reduction in their
|
||||
pipeline: **80–95%**. This isn't a benchmark on gnoma's specific
|
||||
entropy path, but it corroborates the F-1 → F-2 architecture
|
||||
(deterministic first, classifier second) at scale outside this
|
||||
project. Their framing of the SLM step —
|
||||
*"a smart regex that handles the ambiguity without risking a leak
|
||||
to the upstream provider"* — captures the design intent concisely;
|
||||
worth preserving for downstream docs and release notes.
|
||||
|
||||
### F-1: Format-aware pre-extractor (deterministic, low risk)
|
||||
|
||||
**Problem.** `Scanner.scanEntropy()` tokenises by character class
|
||||
(`entropyTokenize`, alphabet `[a-zA-Z0-9_-/]`) but doesn't recognise
|
||||
specific known-safe shapes. Under default thresholds this is fine;
|
||||
under `redactHighEntropy = true` or a lowered threshold it can produce
|
||||
noise on payloads that are mostly structured data.
|
||||
|
||||
**Approach.** Before entropy calculation, extract tokens matching a
|
||||
small allow-list of known-safe patterns (UUID4/5, SHA-1/256 hex,
|
||||
ISO-8601 timestamps, RFC-3986 URLs). Entropy is then computed only
|
||||
on the remaining unstructured residue.
|
||||
|
||||
#### Tasks (F-1)
|
||||
|
||||
- [x] `internal/security/safelist.go` — compiled regex list for the
|
||||
known-safe shapes (`uuid`, `sha_hex`, `iso8601`, `url`) with
|
||||
per-pattern naming so the trace path matches the existing `pattern`
|
||||
log field.
|
||||
- [x] `Scanner.scanEntropy()` consults the safelist first; tokens
|
||||
contained in any safelist span are skipped (not scored).
|
||||
- [x] Config knob `[security].entropy_safelist = ["uuid", "sha_hex",
|
||||
"iso8601", "url"]` so users can curate which formats are auto-skipped.
|
||||
Empty / unset preserves current behaviour exactly. (TOML key lives
|
||||
under `[security]` to match the existing `entropy_threshold` and
|
||||
`redact_high_entropy` knobs, not under a new `[firewall.entropy]`
|
||||
table.)
|
||||
- [x] Tests: UUID skipped, SHA-1/256 skipped, mixed payload with secret
|
||||
preserved, secret-adjacent-to-UUID regression guard, empty safelist
|
||||
preserves pre-F-1 behaviour, unknown name silently dropped.
|
||||
- [ ] Measurement of FP-rate delta on a synthetic corpus — deferred
|
||||
until telemetry from a real workload is available (the synthetic
|
||||
corpus would just measure the unit tests).
|
||||
|
||||
**Effort estimate:** ~150 LOC + tests.
|
||||
|
||||
**Status:** shipped 2026-05-22. Default config remains empty; users
|
||||
opt in by adding `entropy_safelist` to `[security]`. F-2 gating still
|
||||
requires real-world FP-rate observations.
|
||||
|
||||
### F-2: SLM-assisted classifier for ambiguous entropy hits
|
||||
|
||||
**Problem.** After the F-1 deterministic layer, the remaining
|
||||
entropy-flagged tokens are genuinely ambiguous — secrets and
|
||||
application-specific structured strings both look similar to a
|
||||
regex + entropy scorer.
|
||||
|
||||
**Approach.** When the SLM tier is enabled (`[slm] enabled = true`),
|
||||
optionally feed each entropy-flagged token to the existing SLM arm
|
||||
for a binary classification ("credential" / "benign") before
|
||||
deciding whether to redact. The same model that already handles
|
||||
prompt routing in `internal/slm/classifier.go` does double duty as
|
||||
a security-judge.
|
||||
|
||||
**Trust-boundary caveat.** Putting an LLM inside the security
|
||||
decision path adds a prompt-injection surface that doesn't exist
|
||||
today: an entropy-flagged token may contain attacker-controlled bytes
|
||||
(from a tool result), and a sufficiently crafted payload could
|
||||
manipulate the classifier's verdict. Two modes shake out:
|
||||
|
||||
- **Strict** — SLM disabled, or SLM enabled with
|
||||
`block_ambiguous = true`. Treat ambiguous entropy hits as redacts;
|
||||
no model consultation. This must remain the default.
|
||||
- **Assisted** — SLM enabled with `ask_slm = true`. Feed the flagged
|
||||
token (plus minimal anchoring context) to the SLM, accept its
|
||||
verdict above a confidence floor, log every classification for
|
||||
audit.
|
||||
|
||||
#### Tasks (F-2)
|
||||
|
||||
- [ ] `internal/slm/security_classifier.go` — wraps the existing SLM
|
||||
Provider with a credential-classification prompt. Output:
|
||||
`{verdict: "credential" | "benign", confidence: 0..1}`.
|
||||
- [ ] `Firewall.ScanWithSLM()` consults the classifier on ambiguous
|
||||
hits; falls back to the strict path if SLM is disabled, errors,
|
||||
or returns below the confidence floor.
|
||||
- [ ] Audit log for every classifier call — input token *hashed*,
|
||||
not raw; verdict; confidence; source boundary.
|
||||
- [ ] Config: `[firewall.entropy].slm_assist = false` (default),
|
||||
`slm_confidence_floor = 0.7`.
|
||||
- [ ] Adversarial test: prompt-injection payload crafted to flip
|
||||
the verdict must still be redacted at strict / floor settings.
|
||||
|
||||
**Hold this until:**
|
||||
|
||||
- F-1 has shipped and produced FP-rate measurements that quantify
|
||||
how large the residual ambiguous set actually is. If F-1 already
|
||||
closes the gap on real workloads, F-2 may not be worth the new
|
||||
trust boundary.
|
||||
- The SLM arm has ≥50 observations (same telemetry bar as Phase E)
|
||||
so its behaviour under arbitrary input is understood.
|
||||
|
||||
**Effort estimate:** ~300 LOC + tests + adversarial suite. Revise
|
||||
after F-1 telemetry lands.
|
||||
|
||||
**Status:** scoped, blocked on F-1 and SLM telemetry.
|
||||
|
||||
---
|
||||
|
||||
## Out of scope
|
||||
|
||||
Items previously considered and explicitly dropped:
|
||||
@@ -432,6 +562,12 @@ Items previously considered and explicitly dropped:
|
||||
profiles can express per-task arm preferences).
|
||||
5. **Phase E (compound tools)** — re-evaluate once the SLM arm has
|
||||
produced enough telemetry to justify specific primitives.
|
||||
6. **Phase F-1 (format-aware entropy pre-extractor)** — deterministic,
|
||||
no new trust boundary, can ship independently of the SLM-telemetry
|
||||
gating that holds E and F-2. Concrete next-up item if a small
|
||||
self-contained piece of work is needed.
|
||||
7. **Phase F-2 (SLM-assisted entropy classifier)** — blocked on F-1
|
||||
shipping plus the same ≥50-SLM-observation bar as E.
|
||||
|
||||
Or pause and let SLM data accumulate before committing to any of the
|
||||
larger phases (D, C).
|
||||
@@ -442,3 +578,9 @@ larger phases (D, C).
|
||||
|
||||
- 2026-05-19: Initial. Captures outstanding work after the SLM
|
||||
unlock session.
|
||||
- 2026-05-20: Added Phase F (entropy false-positive reduction).
|
||||
Surfaced from the r/ollama launch thread — `SharpRule4025`
|
||||
proposed a format-aware pre-extractor (F-1, deterministic,
|
||||
shippable) and an SLM-assisted classifier for ambiguous hits
|
||||
(F-2, blocked on F-1 + SLM telemetry). Sequencing matches the
|
||||
public OP reply.
|
||||
|
||||
@@ -22,6 +22,7 @@ type Config struct {
|
||||
Hooks []HookConfig `toml:"hooks"`
|
||||
MCPServers []MCPServerConfig `toml:"mcp_servers"`
|
||||
Plugins PluginsSection `toml:"plugins"`
|
||||
TUI TUISection `toml:"tui"`
|
||||
}
|
||||
|
||||
// SLMSection configures the optional small language model used for task
|
||||
@@ -169,14 +170,19 @@ type SessionSection struct {
|
||||
//
|
||||
// [security]
|
||||
// entropy_threshold = 4.5
|
||||
// entropy_safelist = ["uuid", "sha_hex", "iso8601", "url"]
|
||||
//
|
||||
// [[security.patterns]]
|
||||
// name = "internal_token"
|
||||
// regex = "mycompany_[a-zA-Z0-9]{32}"
|
||||
// action = "redact"
|
||||
//
|
||||
// entropy_safelist names known-safe shapes that bypass the entropy scorer
|
||||
// (Phase F-1 FP reduction). Empty / unset preserves pre-F-1 behavior.
|
||||
type SecuritySection struct {
|
||||
EntropyThreshold float64 `toml:"entropy_threshold"`
|
||||
RedactHighEntropy bool `toml:"redact_high_entropy"`
|
||||
EntropySafelist []string `toml:"entropy_safelist"`
|
||||
Patterns []PatternConfig `toml:"patterns"`
|
||||
}
|
||||
|
||||
@@ -201,14 +207,14 @@ type ProviderSection struct {
|
||||
Default string `toml:"default"`
|
||||
Model string `toml:"model"`
|
||||
MaxTokens int64 `toml:"max_tokens"`
|
||||
Temperature *float64 `toml:"temperature"` // TODO(M8): wire to provider.Request.Temperature
|
||||
Temperature *float64 `toml:"temperature"`
|
||||
APIKeys map[string]string `toml:"api_keys"`
|
||||
Endpoints map[string]string `toml:"endpoints"`
|
||||
}
|
||||
|
||||
type ToolsSection struct {
|
||||
BashTimeout Duration `toml:"bash_timeout"`
|
||||
MaxFileSize int64 `toml:"max_file_size"` // TODO(M8): wire to fs tool WithMaxFileSize option
|
||||
MaxFileSize int64 `toml:"max_file_size"`
|
||||
}
|
||||
|
||||
// RateLimitSection allows overriding default rate limits per provider.
|
||||
@@ -254,3 +260,8 @@ func (d *Duration) UnmarshalText(text []byte) error {
|
||||
func (d Duration) Duration() time.Duration {
|
||||
return time.Duration(d)
|
||||
}
|
||||
|
||||
type TUISection struct {
|
||||
Theme string `toml:"theme"`
|
||||
Vim bool `toml:"vim"`
|
||||
}
|
||||
|
||||
@@ -22,5 +22,9 @@ func Defaults() Config {
|
||||
SLM: SLMSection{
|
||||
StartupTimeout: Duration(5 * time.Second),
|
||||
},
|
||||
TUI: TUISection{
|
||||
Theme: "catppuccin",
|
||||
Vim: false,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -29,6 +29,8 @@ func setConfig(path, key, value string) error {
|
||||
"slm.model_url": true,
|
||||
"slm.enabled": true,
|
||||
"slm.data_dir": true,
|
||||
"tui.theme": true,
|
||||
"tui.vim": true,
|
||||
}
|
||||
if !allowed[key] {
|
||||
return fmt.Errorf("unknown config key %q (supported: %s)", key, strings.Join(allowedKeys(), ", "))
|
||||
@@ -60,6 +62,10 @@ func setConfig(path, key, value string) error {
|
||||
cfg.SLM.Enabled = value == "true"
|
||||
case "slm.data_dir":
|
||||
cfg.SLM.DataDir = value
|
||||
case "tui.theme":
|
||||
cfg.TUI.Theme = value
|
||||
case "tui.vim":
|
||||
cfg.TUI.Vim = value == "true"
|
||||
}
|
||||
|
||||
// Ensure directory exists
|
||||
@@ -88,5 +94,6 @@ func allowedKeys() []string {
|
||||
return []string{
|
||||
"provider.default", "provider.model", "permission.mode",
|
||||
"slm.model_url", "slm.enabled", "slm.data_dir",
|
||||
"tui.theme", "tui.vim",
|
||||
}
|
||||
}
|
||||
|
||||
@@ -343,6 +343,20 @@ func (e *Engine) latestUserPrompt() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// latestUserHasImages reports whether the most recent user message carries
|
||||
// any inline image content. Used by the routing path to enforce vision
|
||||
// capability when selecting an arm.
|
||||
func (e *Engine) latestUserHasImages() bool {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
for i := len(e.history) - 1; i >= 0; i-- {
|
||||
if e.history[i].Role == message.RoleUser {
|
||||
return e.history[i].HasImages()
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// historySnapshot returns a copy of the current history slice.
|
||||
func (e *Engine) historySnapshot() []message.Message {
|
||||
e.mu.Lock()
|
||||
|
||||
@@ -0,0 +1,100 @@
|
||||
package engine
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"somegit.dev/Owlibou/gnoma/internal/message"
|
||||
)
|
||||
|
||||
// imageMarkerRe matches the `[Image: /absolute/path/to/file.ext]` form that
|
||||
// the TUI emits when expanding pasted image placeholders.
|
||||
var imageMarkerRe = regexp.MustCompile(`\[Image:\s*([^\]]+?)\]`)
|
||||
|
||||
// imageMaxBytes caps how big an inline image is allowed to be. Larger files
|
||||
// are skipped (the marker stays as plain text). 10 MiB roughly matches what
|
||||
// vision providers accept inline; bigger payloads almost always indicate a
|
||||
// misclick (e.g. a screen recording) rather than an actual screenshot.
|
||||
const imageMaxBytes = 10 << 20
|
||||
|
||||
// parseImageMarkers splits a user input string into a sequence of content
|
||||
// blocks. Each `[Image: /path]` marker is replaced by an ImageContent block
|
||||
// carrying the file bytes; the surrounding text is preserved as ContentText
|
||||
// blocks. If a marker references a file that can't be read or whose bytes
|
||||
// exceed imageMaxBytes, the marker is left as literal text and a warning
|
||||
// is appended to warnings — the turn still proceeds.
|
||||
//
|
||||
// When no markers are present, the result is a single text block matching
|
||||
// the legacy NewUserText behavior.
|
||||
func parseImageMarkers(input string) (content []message.Content, warnings []string) {
|
||||
indices := imageMarkerRe.FindAllStringSubmatchIndex(input, -1)
|
||||
if len(indices) == 0 {
|
||||
return []message.Content{message.NewTextContent(input)}, nil
|
||||
}
|
||||
|
||||
var blocks []message.Content
|
||||
cursor := 0
|
||||
for _, idx := range indices {
|
||||
matchStart, matchEnd := idx[0], idx[1]
|
||||
pathStart, pathEnd := idx[2], idx[3]
|
||||
path := strings.TrimSpace(input[pathStart:pathEnd])
|
||||
|
||||
// Emit any preceding text as a text block.
|
||||
if matchStart > cursor {
|
||||
if pre := input[cursor:matchStart]; pre != "" {
|
||||
blocks = append(blocks, message.NewTextContent(pre))
|
||||
}
|
||||
}
|
||||
|
||||
img, warn := loadImage(path)
|
||||
if warn != "" {
|
||||
warnings = append(warnings, warn)
|
||||
// Fall back to literal text so the model still sees the reference.
|
||||
blocks = append(blocks, message.NewTextContent(input[matchStart:matchEnd]))
|
||||
} else {
|
||||
blocks = append(blocks, message.NewImageContent(img))
|
||||
}
|
||||
cursor = matchEnd
|
||||
}
|
||||
if cursor < len(input) {
|
||||
if tail := input[cursor:]; tail != "" {
|
||||
blocks = append(blocks, message.NewTextContent(tail))
|
||||
}
|
||||
}
|
||||
if len(blocks) == 0 {
|
||||
blocks = []message.Content{message.NewTextContent("")}
|
||||
}
|
||||
return blocks, warnings
|
||||
}
|
||||
|
||||
func loadImage(path string) (message.Image, string) {
|
||||
if path == "" {
|
||||
return message.Image{}, "image marker had empty path"
|
||||
}
|
||||
if !filepath.IsAbs(path) {
|
||||
return message.Image{}, fmt.Sprintf("image path %q must be absolute; skipping", path)
|
||||
}
|
||||
info, err := os.Stat(path)
|
||||
if err != nil {
|
||||
return message.Image{}, fmt.Sprintf("image %q: %v", path, err)
|
||||
}
|
||||
if info.IsDir() {
|
||||
return message.Image{}, fmt.Sprintf("image %q is a directory", path)
|
||||
}
|
||||
if info.Size() > imageMaxBytes {
|
||||
return message.Image{}, fmt.Sprintf("image %q is %d bytes, exceeds %d limit", path, info.Size(), imageMaxBytes)
|
||||
}
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return message.Image{}, fmt.Sprintf("image %q read failed: %v", path, err)
|
||||
}
|
||||
mediaType := http.DetectContentType(data)
|
||||
if !strings.HasPrefix(mediaType, "image/") {
|
||||
return message.Image{}, fmt.Sprintf("image %q has unsupported media type %q", path, mediaType)
|
||||
}
|
||||
return message.Image{Data: data, MediaType: mediaType, Path: path}, ""
|
||||
}
|
||||
@@ -0,0 +1,155 @@
|
||||
package engine
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"somegit.dev/Owlibou/gnoma/internal/message"
|
||||
)
|
||||
|
||||
// pngOnePixel is the minimum valid 1x1 PNG. Used so http.DetectContentType
|
||||
// returns "image/png" and the parser accepts the file.
|
||||
var pngOnePixel = []byte{
|
||||
0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A,
|
||||
0x00, 0x00, 0x00, 0x0D, 0x49, 0x48, 0x44, 0x52,
|
||||
0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01,
|
||||
0x08, 0x02, 0x00, 0x00, 0x00, 0x90, 0x77, 0x53,
|
||||
0xDE, 0x00, 0x00, 0x00, 0x0C, 0x49, 0x44, 0x41,
|
||||
0x54, 0x08, 0x99, 0x63, 0xF8, 0xCF, 0xC0, 0x00,
|
||||
0x00, 0x00, 0x03, 0x00, 0x01, 0x5B, 0x3E, 0xBA,
|
||||
0xD6, 0x00, 0x00, 0x00, 0x00, 0x49, 0x45, 0x4E,
|
||||
0x44, 0xAE, 0x42, 0x60, 0x82,
|
||||
}
|
||||
|
||||
func writeTempPNG(t *testing.T) string {
|
||||
t.Helper()
|
||||
p := filepath.Join(t.TempDir(), "test.png")
|
||||
if err := os.WriteFile(p, pngOnePixel, 0o600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
func TestParseImageMarkers_NoMarkers(t *testing.T) {
|
||||
got, warns := parseImageMarkers("just plain text")
|
||||
if len(got) != 1 || got[0].Type != message.ContentText || got[0].Text != "just plain text" {
|
||||
t.Errorf("got %+v, want single text block", got)
|
||||
}
|
||||
if len(warns) != 0 {
|
||||
t.Errorf("unexpected warnings: %v", warns)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseImageMarkers_SingleImage(t *testing.T) {
|
||||
path := writeTempPNG(t)
|
||||
got, warns := parseImageMarkers("[Image: " + path + "] what is this?")
|
||||
if len(warns) != 0 {
|
||||
t.Fatalf("unexpected warnings: %v", warns)
|
||||
}
|
||||
if len(got) != 2 {
|
||||
t.Fatalf("got %d blocks, want 2", len(got))
|
||||
}
|
||||
if got[0].Type != message.ContentImage {
|
||||
t.Errorf("block 0 type = %v, want ContentImage", got[0].Type)
|
||||
}
|
||||
if got[0].Image == nil || !bytes.Equal(got[0].Image.Data, pngOnePixel) {
|
||||
t.Error("image bytes not captured into Content.Image.Data")
|
||||
}
|
||||
if got[0].Image.MediaType != "image/png" {
|
||||
t.Errorf("MediaType = %q, want image/png", got[0].Image.MediaType)
|
||||
}
|
||||
if got[1].Type != message.ContentText || got[1].Text != " what is this?" {
|
||||
t.Errorf("block 1 = %+v, want trailing text", got[1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseImageMarkers_MissingFileWarnsAndFallsBackToText(t *testing.T) {
|
||||
got, warns := parseImageMarkers("see [Image: /nonexistent/path.png] please")
|
||||
if len(warns) != 1 {
|
||||
t.Fatalf("got %d warnings, want 1", len(warns))
|
||||
}
|
||||
if !strings.Contains(warns[0], "/nonexistent/path.png") {
|
||||
t.Errorf("warning %q should mention path", warns[0])
|
||||
}
|
||||
// Marker stays as literal text so subprocess CLIs that auto-ingest paths still work.
|
||||
var joined string
|
||||
for _, c := range got {
|
||||
if c.Type == message.ContentText {
|
||||
joined += c.Text
|
||||
}
|
||||
if c.Type == message.ContentImage {
|
||||
t.Error("missing file should not produce image content")
|
||||
}
|
||||
}
|
||||
if !strings.Contains(joined, "[Image: /nonexistent/path.png]") {
|
||||
t.Errorf("joined text %q should keep literal marker", joined)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseImageMarkers_RelativePathRejected(t *testing.T) {
|
||||
_, warns := parseImageMarkers("[Image: relative/path.png]")
|
||||
if len(warns) != 1 {
|
||||
t.Fatalf("got %d warnings, want 1", len(warns))
|
||||
}
|
||||
if !strings.Contains(warns[0], "absolute") {
|
||||
t.Errorf("warning %q should explain absolute-path requirement", warns[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseImageMarkers_OversizedRejected(t *testing.T) {
|
||||
p := filepath.Join(t.TempDir(), "big.png")
|
||||
// Write a >10MiB file (header still says PNG so media type detect passes).
|
||||
big := make([]byte, imageMaxBytes+1)
|
||||
copy(big, pngOnePixel)
|
||||
if err := os.WriteFile(p, big, 0o600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, warns := parseImageMarkers("[Image: " + p + "]")
|
||||
if len(warns) != 1 {
|
||||
t.Fatalf("got %d warnings, want 1", len(warns))
|
||||
}
|
||||
if !strings.Contains(warns[0], "exceeds") {
|
||||
t.Errorf("warning %q should explain size limit", warns[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseImageMarkers_NonImageFileRejected(t *testing.T) {
|
||||
p := filepath.Join(t.TempDir(), "not_an_image.txt")
|
||||
if err := os.WriteFile(p, []byte("plain text, not an image"), 0o600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, warns := parseImageMarkers("[Image: " + p + "]")
|
||||
if len(warns) != 1 {
|
||||
t.Fatalf("got %d warnings, want 1", len(warns))
|
||||
}
|
||||
if !strings.Contains(warns[0], "unsupported media type") {
|
||||
t.Errorf("warning %q should mention media type", warns[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseImageMarkers_MultipleImagesAndText(t *testing.T) {
|
||||
p1 := writeTempPNG(t)
|
||||
p2 := writeTempPNG(t)
|
||||
input := "before [Image: " + p1 + "] between [Image: " + p2 + "] after"
|
||||
got, warns := parseImageMarkers(input)
|
||||
if len(warns) != 0 {
|
||||
t.Fatalf("unexpected warnings: %v", warns)
|
||||
}
|
||||
// Expected order: text, image, text, image, text
|
||||
wantTypes := []message.ContentType{
|
||||
message.ContentText, message.ContentImage,
|
||||
message.ContentText, message.ContentImage,
|
||||
message.ContentText,
|
||||
}
|
||||
if len(got) != len(wantTypes) {
|
||||
t.Fatalf("got %d blocks, want %d", len(got), len(wantTypes))
|
||||
}
|
||||
for i, want := range wantTypes {
|
||||
if got[i].Type != want {
|
||||
t.Errorf("block %d type = %v, want %v", i, got[i].Type, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
+27
-1
@@ -29,9 +29,10 @@ func (e *Engine) Submit(ctx context.Context, input string, cb Callback) (*Turn,
|
||||
|
||||
// SubmitWithOptions is like Submit but applies per-turn overrides (e.g. ToolChoice).
|
||||
func (e *Engine) SubmitWithOptions(ctx context.Context, input string, opts TurnOptions, cb Callback) (*Turn, error) {
|
||||
userMsg := e.buildUserMessage(ctx, input, cb)
|
||||
|
||||
e.mu.Lock()
|
||||
e.turnOpts = opts
|
||||
userMsg := message.NewUserText(input)
|
||||
e.history = append(e.history, userMsg)
|
||||
e.mu.Unlock()
|
||||
defer func() {
|
||||
@@ -47,6 +48,29 @@ func (e *Engine) SubmitWithOptions(ctx context.Context, input string, opts TurnO
|
||||
return e.runLoop(ctx, cb)
|
||||
}
|
||||
|
||||
// buildUserMessage wraps the raw input into a message.Message. When the
|
||||
// active model advertises Vision capability and the input contains
|
||||
// `[Image: /path]` markers, the markers are inlined as ImageContent blocks
|
||||
// carrying the file bytes; otherwise the input is wrapped as a single
|
||||
// text block (legacy behavior). Marker-parse warnings are forwarded to cb
|
||||
// as system events so the user sees why a paste fell back to text.
|
||||
func (e *Engine) buildUserMessage(ctx context.Context, input string, cb Callback) message.Message {
|
||||
if !imageMarkerRe.MatchString(input) {
|
||||
return message.NewUserText(input)
|
||||
}
|
||||
caps := e.resolveCapabilities(ctx)
|
||||
if caps == nil || !caps.Vision {
|
||||
// Active model can't see images; leave markers as text so any
|
||||
// downstream subprocess CLI that auto-ingests paths still works.
|
||||
return message.NewUserText(input)
|
||||
}
|
||||
content, warnings := parseImageMarkers(input)
|
||||
for _, w := range warnings {
|
||||
e.logger.Warn("image marker parse", "warning", w)
|
||||
}
|
||||
return message.Message{Role: message.RoleUser, Content: content}
|
||||
}
|
||||
|
||||
// SubmitMessages is like Submit but accepts pre-built messages.
|
||||
func (e *Engine) SubmitMessages(ctx context.Context, msgs []message.Message, cb Callback) (*Turn, error) {
|
||||
e.mu.Lock()
|
||||
@@ -142,6 +166,7 @@ func (e *Engine) runLoop(ctx context.Context, cb Callback) (*Turn, error) {
|
||||
task.EstimatedTokens = int(gnomactx.EstimateTokens(prompt))
|
||||
}
|
||||
task.ExcludedArms = failedArms
|
||||
task.RequiresVision = e.latestUserHasImages()
|
||||
|
||||
e.logger.Debug("routing request",
|
||||
"task_type", task.Type,
|
||||
@@ -212,6 +237,7 @@ func (e *Engine) runLoop(ctx context.Context, cb Callback) (*Turn, error) {
|
||||
}
|
||||
|
||||
task.ExcludedArms = failedArms
|
||||
task.RequiresVision = e.latestUserHasImages()
|
||||
var retryDecision router.RoutingDecision
|
||||
s, retryDecision, err = e.cfg.Router.Stream(ctx, task, req)
|
||||
if err == nil {
|
||||
|
||||
@@ -13,6 +13,7 @@ const (
|
||||
ContentToolCall
|
||||
ContentToolResult
|
||||
ContentThinking
|
||||
ContentImage
|
||||
)
|
||||
|
||||
func (ct ContentType) String() string {
|
||||
@@ -25,6 +26,8 @@ func (ct ContentType) String() string {
|
||||
return "tool_result"
|
||||
case ContentThinking:
|
||||
return "thinking"
|
||||
case ContentImage:
|
||||
return "image"
|
||||
default:
|
||||
return fmt.Sprintf("unknown(%d)", ct)
|
||||
}
|
||||
@@ -37,6 +40,7 @@ type Content struct {
|
||||
ToolCall *ToolCall // ContentToolCall
|
||||
ToolResult *ToolResult // ContentToolResult
|
||||
Thinking *Thinking // ContentThinking
|
||||
Image *Image // ContentImage
|
||||
}
|
||||
|
||||
// ToolCall represents the model's request to invoke a tool.
|
||||
@@ -61,6 +65,17 @@ type Thinking struct {
|
||||
Redacted bool `json:"redacted,omitempty"`
|
||||
}
|
||||
|
||||
// Image carries inline image bytes for vision-capable models. Data is the
|
||||
// raw image bytes captured at user-input time so the message snapshot is
|
||||
// self-contained (file deletion or rename after the turn does not break
|
||||
// translation). MediaType is the IANA media type (e.g. "image/png").
|
||||
// Path is retained for human-readable display and logging only.
|
||||
type Image struct {
|
||||
Data []byte `json:"data"`
|
||||
MediaType string `json:"media_type"`
|
||||
Path string `json:"path,omitempty"`
|
||||
}
|
||||
|
||||
func NewTextContent(text string) Content {
|
||||
return Content{Type: ContentText, Text: text}
|
||||
}
|
||||
@@ -76,3 +91,7 @@ func NewToolResultContent(tr ToolResult) Content {
|
||||
func NewThinkingContent(th Thinking) Content {
|
||||
return Content{Type: ContentThinking, Thinking: &th}
|
||||
}
|
||||
|
||||
func NewImageContent(img Image) Content {
|
||||
return Content{Type: ContentImage, Image: &img}
|
||||
}
|
||||
|
||||
@@ -87,3 +87,15 @@ func (m Message) TextContent() string {
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// HasImages reports whether any content block in the message is an inline
|
||||
// image. Providers that don't support vision can use this to decide whether
|
||||
// to fall back to a text-only representation.
|
||||
func (m Message) HasImages() bool {
|
||||
for _, c := range m.Content {
|
||||
if c.Type == ContentImage {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -2,14 +2,29 @@ package google
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"somegit.dev/Owlibou/gnoma/internal/provider"
|
||||
"somegit.dev/Owlibou/gnoma/internal/stream"
|
||||
|
||||
"cloud.google.com/go/auth"
|
||||
"cloud.google.com/go/auth/credentials"
|
||||
"google.golang.org/genai"
|
||||
)
|
||||
|
||||
// cloudPlatformScope is the standard OAuth scope used for Vertex AI and
|
||||
// the Gemini API on Google Cloud. credentials.DetectDefault REQUIRES at
|
||||
// least Scopes or Audience to be set — calling it with nil options
|
||||
// returns "credentials: options must be provided" and the ADC branch
|
||||
// becomes dead code.
|
||||
const cloudPlatformScope = "https://www.googleapis.com/auth/cloud-platform"
|
||||
|
||||
const defaultModel = "gemini-3.5-flash"
|
||||
|
||||
// Provider implements provider.Provider for Google's Gemini API.
|
||||
@@ -19,18 +34,284 @@ type Provider struct {
|
||||
model string
|
||||
}
|
||||
|
||||
// New creates a Google GenAI provider from config.
|
||||
func New(cfg provider.ProviderConfig) (provider.Provider, error) {
|
||||
if cfg.APIKey == "" {
|
||||
return nil, fmt.Errorf("google: api key required")
|
||||
type oauthCreds struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
AccessToken2 string `json:"accessToken"`
|
||||
ExpiryDate int64 `json:"expiry_date"`
|
||||
ExpiresAt int64 `json:"expiresAt"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
RefreshToken2 string `json:"refreshToken"`
|
||||
TokenType string `json:"token_type"`
|
||||
TokenType2 string `json:"tokenType"`
|
||||
}
|
||||
|
||||
func (c *oauthCreds) Token() string {
|
||||
if c.AccessToken != "" {
|
||||
return c.AccessToken
|
||||
}
|
||||
return c.AccessToken2
|
||||
}
|
||||
|
||||
func (c *oauthCreds) Expiry() time.Time {
|
||||
val := c.ExpiryDate
|
||||
if val == 0 {
|
||||
val = c.ExpiresAt
|
||||
}
|
||||
if val > 0 {
|
||||
if val > 9999999999 {
|
||||
return time.UnixMilli(val)
|
||||
}
|
||||
return time.Unix(val, 0)
|
||||
}
|
||||
return time.Time{}
|
||||
}
|
||||
|
||||
type fileTokenProvider struct {
|
||||
filePath string
|
||||
}
|
||||
|
||||
func (tp *fileTokenProvider) Token(ctx context.Context) (*auth.Token, error) {
|
||||
data, err := os.ReadFile(tp.filePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read oauth credentials: %w", err)
|
||||
}
|
||||
|
||||
client, err := genai.NewClient(context.Background(), &genai.ClientConfig{
|
||||
APIKey: cfg.APIKey,
|
||||
Backend: genai.BackendGeminiAPI,
|
||||
})
|
||||
var creds oauthCreds
|
||||
if err := json.Unmarshal(data, &creds); err != nil {
|
||||
return nil, fmt.Errorf("parse oauth credentials: %w", err)
|
||||
}
|
||||
|
||||
tokVal := creds.Token()
|
||||
if tokVal == "" {
|
||||
return nil, fmt.Errorf("no access token in credentials file")
|
||||
}
|
||||
|
||||
// We don't perform an OAuth refresh exchange ourselves; the upstream
|
||||
// CLI (gemini / antigravity) refreshes the file out-of-band. If we're
|
||||
// asked for a token after expiry and the file hasn't been refreshed,
|
||||
// fail loudly with an actionable message instead of sending a known-
|
||||
// dead bearer that the API would reject with a confusing 401.
|
||||
expiry := creds.Expiry()
|
||||
if !expiry.IsZero() && time.Now().After(expiry) {
|
||||
return nil, fmt.Errorf("oauth token at %s is expired (re-run the upstream CLI to refresh)", tp.filePath)
|
||||
}
|
||||
|
||||
tokenType := creds.TokenType
|
||||
if tokenType == "" {
|
||||
tokenType = creds.TokenType2
|
||||
}
|
||||
if tokenType == "" {
|
||||
tokenType = "Bearer"
|
||||
}
|
||||
|
||||
return &auth.Token{
|
||||
Value: tokVal,
|
||||
Type: tokenType,
|
||||
Expiry: expiry,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func expandHome(path string) string {
|
||||
if len(path) == 0 || path[0] != '~' {
|
||||
return path
|
||||
}
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("google: create client: %w", err)
|
||||
return path
|
||||
}
|
||||
if len(path) == 1 {
|
||||
return home
|
||||
}
|
||||
if path[1] == '/' || path[1] == '\\' {
|
||||
return filepath.Join(home, path[2:])
|
||||
}
|
||||
return path
|
||||
}
|
||||
|
||||
// errCredentialMissing wraps os.ErrNotExist for the precedence walker so
|
||||
// the "file isn't there" case is silent while permission / parse / empty-
|
||||
// token failures get a slog.Warn (they typically indicate a misconfigured
|
||||
// install — chmod 0600 on the wrong file, half-written JSON, etc.).
|
||||
var errCredentialMissing = errors.New("credential file not present")
|
||||
|
||||
func tryLoadOAuthCredentials(filePath string) (*auth.Credentials, error) {
|
||||
expanded := expandHome(filePath)
|
||||
if _, err := os.Stat(expanded); err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil, errCredentialMissing
|
||||
}
|
||||
slog.Warn("google oauth: stat failed", "path", expanded, "err", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(expanded)
|
||||
if err != nil {
|
||||
slog.Warn("google oauth: read failed", "path", expanded, "err", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var creds oauthCreds
|
||||
if err := json.Unmarshal(data, &creds); err != nil {
|
||||
slog.Warn("google oauth: parse failed", "path", expanded, "err", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tokVal := creds.Token()
|
||||
if tokVal == "" {
|
||||
slog.Warn("google oauth: empty access token", "path", expanded)
|
||||
return nil, fmt.Errorf("empty access token in %s", expanded)
|
||||
}
|
||||
|
||||
expiry := creds.Expiry()
|
||||
if !expiry.IsZero() && time.Now().After(expiry) {
|
||||
slog.Warn("google oauth: token expired", "path", expanded, "expired_at", expiry)
|
||||
return nil, fmt.Errorf("token in %s expired at %s", expanded, expiry.Format(time.RFC3339))
|
||||
}
|
||||
|
||||
tp := &fileTokenProvider{filePath: expanded}
|
||||
return auth.NewCredentials(&auth.CredentialsOptions{
|
||||
TokenProvider: tp,
|
||||
}), nil
|
||||
}
|
||||
|
||||
// CredentialSource labels the origin of the auth credential returned by
|
||||
// selectOAuthCredentials. Used by tests and diagnostics.
|
||||
type CredentialSource string
|
||||
|
||||
const (
|
||||
CredentialSourceNone CredentialSource = ""
|
||||
CredentialSourceAgy CredentialSource = "agy"
|
||||
CredentialSourceGemini CredentialSource = "gemini"
|
||||
CredentialSourceADC CredentialSource = "adc"
|
||||
)
|
||||
|
||||
// agyCredentialPaths lists the OAuth credential file locations that the
|
||||
// agy / antigravity CLIs are known to write to. First match wins.
|
||||
var agyCredentialPaths = []string{
|
||||
"~/.config/google-antigravity/session.json",
|
||||
"~/.config/google-antigravity/oauth_creds.json",
|
||||
"~/.config/antigravity/session.json",
|
||||
"~/.config/antigravity/oauth_creds.json",
|
||||
"~/.config/antigravity-cli/session.json",
|
||||
"~/.config/antigravity-cli/oauth_creds.json",
|
||||
"~/.gemini/antigravity-cli/oauth_creds.json",
|
||||
}
|
||||
|
||||
// geminiCredentialPaths lists the locations the official gemini CLI uses.
|
||||
var geminiCredentialPaths = []string{
|
||||
"~/.gemini/oauth_creds.json",
|
||||
"~/.config/gemini-cli/oauth_creds.json",
|
||||
}
|
||||
|
||||
// selectOAuthCredentials walks the precedence chain (agy → gemini → ADC)
|
||||
// and returns the first usable credential plus a tag identifying which
|
||||
// source it came from. Tests use the tag to verify precedence; the New()
|
||||
// builder discards it.
|
||||
func selectOAuthCredentials() (*auth.Credentials, CredentialSource, error) {
|
||||
for _, path := range agyCredentialPaths {
|
||||
if c, err := tryLoadOAuthCredentials(path); err == nil {
|
||||
return c, CredentialSourceAgy, nil
|
||||
}
|
||||
}
|
||||
for _, path := range geminiCredentialPaths {
|
||||
if c, err := tryLoadOAuthCredentials(path); err == nil {
|
||||
return c, CredentialSourceGemini, nil
|
||||
}
|
||||
}
|
||||
// Application Default Credentials. DetectDefault REQUIRES scopes —
|
||||
// passing nil makes the call always error, leaving ADC unreachable.
|
||||
c, err := credentials.DetectDefault(&credentials.DetectOptions{
|
||||
Scopes: []string{cloudPlatformScope},
|
||||
})
|
||||
if err == nil {
|
||||
return c, CredentialSourceADC, nil
|
||||
}
|
||||
slog.Debug("google adc: DetectDefault failed", "err", err)
|
||||
return nil, CredentialSourceNone, fmt.Errorf("no google credentials found (tried agy session, gemini session, and ADC)")
|
||||
}
|
||||
|
||||
// New creates a Google GenAI provider from config.
|
||||
func New(cfg provider.ProviderConfig) (provider.Provider, error) {
|
||||
var client *genai.Client
|
||||
var err error
|
||||
|
||||
if cfg.APIKey != "" {
|
||||
client, err = genai.NewClient(context.Background(), &genai.ClientConfig{
|
||||
APIKey: cfg.APIKey,
|
||||
Backend: genai.BackendGeminiAPI,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("google: create client (Gemini API): %w", err)
|
||||
}
|
||||
} else {
|
||||
creds, source, selErr := selectOAuthCredentials()
|
||||
if selErr != nil {
|
||||
return nil, fmt.Errorf("google: %w", selErr)
|
||||
}
|
||||
slog.Debug("google auth: credential selected", "source", source)
|
||||
|
||||
// Resolve Project ID
|
||||
var projectID string
|
||||
if projectVal, ok := cfg.Options["project"]; ok {
|
||||
if s, ok := projectVal.(string); ok {
|
||||
projectID = s
|
||||
}
|
||||
}
|
||||
if projectID == "" {
|
||||
if projectIDVal, ok := cfg.Options["project_id"]; ok {
|
||||
if s, ok := projectIDVal.(string); ok {
|
||||
projectID = s
|
||||
}
|
||||
}
|
||||
}
|
||||
if projectID == "" && creds != nil {
|
||||
if pid, err := creds.ProjectID(context.Background()); err == nil && pid != "" {
|
||||
projectID = pid
|
||||
}
|
||||
}
|
||||
if projectID == "" {
|
||||
projectID = os.Getenv("GOOGLE_CLOUD_PROJECT")
|
||||
}
|
||||
if projectID == "" {
|
||||
projectID = os.Getenv("GOOGLE_PROJECT")
|
||||
}
|
||||
if projectID == "" {
|
||||
return nil, fmt.Errorf("google: project id is required for Vertex AI backend")
|
||||
}
|
||||
|
||||
// Resolve Location
|
||||
var location string
|
||||
if locVal, ok := cfg.Options["location"]; ok {
|
||||
if s, ok := locVal.(string); ok {
|
||||
location = s
|
||||
}
|
||||
}
|
||||
if location == "" {
|
||||
if regVal, ok := cfg.Options["region"]; ok {
|
||||
if s, ok := regVal.(string); ok {
|
||||
location = s
|
||||
}
|
||||
}
|
||||
}
|
||||
if location == "" {
|
||||
location = os.Getenv("GOOGLE_CLOUD_LOCATION")
|
||||
}
|
||||
if location == "" {
|
||||
location = os.Getenv("GOOGLE_CLOUD_REGION")
|
||||
}
|
||||
if location == "" {
|
||||
location = "us-central1"
|
||||
}
|
||||
|
||||
client, err = genai.NewClient(context.Background(), &genai.ClientConfig{
|
||||
Backend: genai.BackendVertexAI,
|
||||
Credentials: creds,
|
||||
Project: projectID,
|
||||
Location: location,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("google: create client (Vertex AI): %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
model := cfg.Model
|
||||
|
||||
@@ -0,0 +1,228 @@
|
||||
package google
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"cloud.google.com/go/auth"
|
||||
|
||||
_ "somegit.dev/Owlibou/gnoma/internal/provider"
|
||||
)
|
||||
|
||||
func TestTryLoadOAuthCredentials_Formats(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
data interface{}
|
||||
expectError bool
|
||||
checkToken string
|
||||
checkExpiry time.Time
|
||||
}{
|
||||
{
|
||||
name: "snake_case and seconds expiry",
|
||||
data: oauthCreds{
|
||||
AccessToken: "token-snake",
|
||||
ExpiryDate: time.Now().Add(1 * time.Hour).Unix(),
|
||||
TokenType: "Bearer",
|
||||
},
|
||||
expectError: false,
|
||||
checkToken: "token-snake",
|
||||
},
|
||||
{
|
||||
name: "camelCase and milliseconds expiry",
|
||||
data: oauthCreds{
|
||||
AccessToken2: "token-camel",
|
||||
ExpiresAt: time.Now().Add(1 * time.Hour).UnixNano() / 1e6,
|
||||
TokenType2: "Bearer",
|
||||
},
|
||||
expectError: false,
|
||||
checkToken: "token-camel",
|
||||
},
|
||||
{
|
||||
name: "expired token",
|
||||
data: oauthCreds{
|
||||
AccessToken: "token-expired",
|
||||
ExpiryDate: time.Now().Add(-1 * time.Hour).Unix(),
|
||||
},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "missing access token",
|
||||
data: oauthCreds{
|
||||
ExpiryDate: time.Now().Add(1 * time.Hour).Unix(),
|
||||
},
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
filePath := filepath.Join(tmpDir, "creds.json")
|
||||
bz, err := json.Marshal(tc.data)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal failed: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(filePath, bz, 0644); err != nil {
|
||||
t.Fatalf("write file failed: %v", err)
|
||||
}
|
||||
|
||||
creds, err := tryLoadOAuthCredentials(filePath)
|
||||
if tc.expectError {
|
||||
if err == nil {
|
||||
t.Fatalf("expected error but got nil")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
tok, err := creds.Token(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get token: %v", err)
|
||||
}
|
||||
|
||||
if tok.Value != tc.checkToken {
|
||||
t.Errorf("expected token %q, got %q", tc.checkToken, tok.Value)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectOAuthCredentials_Precedence(t *testing.T) {
|
||||
// Override HOME so expandHome() resolves into a sandbox dir.
|
||||
tmpHome := t.TempDir()
|
||||
t.Setenv("HOME", tmpHome)
|
||||
|
||||
writeCreds := func(relPath, tokenVal string) {
|
||||
absPath := filepath.Join(tmpHome, relPath)
|
||||
if err := os.MkdirAll(filepath.Dir(absPath), 0755); err != nil {
|
||||
t.Fatalf("mkdir: %v", err)
|
||||
}
|
||||
data := oauthCreds{
|
||||
AccessToken: tokenVal,
|
||||
ExpiryDate: time.Now().Add(1 * time.Hour).Unix(),
|
||||
}
|
||||
bz, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(absPath, bz, 0600); err != nil {
|
||||
t.Fatalf("write: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
tokenOf := func(c *auth.Credentials) string {
|
||||
t.Helper()
|
||||
tok, err := c.Token(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("Token: %v", err)
|
||||
}
|
||||
return tok.Value
|
||||
}
|
||||
|
||||
t.Run("agy beats gemini when both present", func(t *testing.T) {
|
||||
// Fresh sandbox per subtest to avoid leftover files.
|
||||
sub := t.TempDir()
|
||||
t.Setenv("HOME", sub)
|
||||
// Use the first agy path and the first gemini path.
|
||||
writeAt := func(rel, tok string) {
|
||||
abs := filepath.Join(sub, rel)
|
||||
if err := os.MkdirAll(filepath.Dir(abs), 0755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
bz, _ := json.Marshal(oauthCreds{
|
||||
AccessToken: tok,
|
||||
ExpiryDate: time.Now().Add(time.Hour).Unix(),
|
||||
})
|
||||
if err := os.WriteFile(abs, bz, 0600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
writeAt(filepath.Join(".config", "google-antigravity", "session.json"), "token-agy")
|
||||
writeAt(filepath.Join(".gemini", "oauth_creds.json"), "token-gemini")
|
||||
|
||||
creds, source, err := selectOAuthCredentials()
|
||||
if err != nil {
|
||||
t.Fatalf("selectOAuthCredentials: %v", err)
|
||||
}
|
||||
if source != CredentialSourceAgy {
|
||||
t.Errorf("source = %q, want %q", source, CredentialSourceAgy)
|
||||
}
|
||||
if got := tokenOf(creds); got != "token-agy" {
|
||||
t.Errorf("loaded token = %q, want token-agy (agy precedence violated)", got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("falls back to gemini when agy missing", func(t *testing.T) {
|
||||
sub := t.TempDir()
|
||||
t.Setenv("HOME", sub)
|
||||
// Only gemini file present.
|
||||
geminiPath := filepath.Join(sub, ".gemini", "oauth_creds.json")
|
||||
if err := os.MkdirAll(filepath.Dir(geminiPath), 0755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
bz, _ := json.Marshal(oauthCreds{
|
||||
AccessToken: "token-gemini-only",
|
||||
ExpiryDate: time.Now().Add(time.Hour).Unix(),
|
||||
})
|
||||
if err := os.WriteFile(geminiPath, bz, 0600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
creds, source, err := selectOAuthCredentials()
|
||||
if err != nil {
|
||||
t.Fatalf("selectOAuthCredentials: %v", err)
|
||||
}
|
||||
if source != CredentialSourceGemini {
|
||||
t.Errorf("source = %q, want %q", source, CredentialSourceGemini)
|
||||
}
|
||||
if got := tokenOf(creds); got != "token-gemini-only" {
|
||||
t.Errorf("loaded token = %q, want token-gemini-only", got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing files are not warning-worthy", func(t *testing.T) {
|
||||
// Sanity check: empty home directory walks the chain without
|
||||
// failing in unexpected ways (only ADC would remain, which we
|
||||
// don't assert on here because the test host may or may not have
|
||||
// gcloud configured).
|
||||
sub := t.TempDir()
|
||||
t.Setenv("HOME", sub)
|
||||
_, _, err := selectOAuthCredentials()
|
||||
// Either ADC works on this host (no error) or no creds anywhere
|
||||
// (returns our specific "no google credentials" error). Both are
|
||||
// fine; the point is we don't panic or report a misconfiguration.
|
||||
if err != nil && !strings.Contains(err.Error(), "no google credentials") {
|
||||
t.Errorf("unexpected error shape: %v", err)
|
||||
}
|
||||
})
|
||||
_ = writeCreds // keep helper available if extended in future
|
||||
}
|
||||
|
||||
func TestFileTokenProvider_RejectsExpired(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "creds.json")
|
||||
bz, _ := json.Marshal(oauthCreds{
|
||||
AccessToken: "stale",
|
||||
ExpiryDate: time.Now().Add(-time.Hour).Unix(),
|
||||
})
|
||||
if err := os.WriteFile(path, bz, 0600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
tp := &fileTokenProvider{filePath: path}
|
||||
tok, err := tp.Token(context.Background())
|
||||
if err == nil {
|
||||
t.Errorf("expected error for expired token, got token %+v", tok)
|
||||
}
|
||||
if err != nil && !strings.Contains(err.Error(), "expired") {
|
||||
t.Errorf("error %q should mention expiry", err)
|
||||
}
|
||||
}
|
||||
@@ -1,7 +1,9 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"somegit.dev/Owlibou/gnoma/internal/message"
|
||||
@@ -39,6 +41,37 @@ func unsanitizeToolName(name string) string {
|
||||
return name
|
||||
}
|
||||
|
||||
// buildUserContentParts converts a heterogeneous user-content slice into
|
||||
// OpenAI content-parts. Adjacent text blocks are concatenated. Each Image
|
||||
// block is emitted as an image_url part carrying a base64 data URL.
|
||||
func buildUserContentParts(blocks []message.Content) []oai.ChatCompletionContentPartUnionParam {
|
||||
parts := make([]oai.ChatCompletionContentPartUnionParam, 0, len(blocks))
|
||||
var textBuf strings.Builder
|
||||
flushText := func() {
|
||||
if textBuf.Len() > 0 {
|
||||
parts = append(parts, oai.TextContentPart(textBuf.String()))
|
||||
textBuf.Reset()
|
||||
}
|
||||
}
|
||||
for _, c := range blocks {
|
||||
switch c.Type {
|
||||
case message.ContentText:
|
||||
textBuf.WriteString(c.Text)
|
||||
case message.ContentImage:
|
||||
if c.Image == nil || len(c.Image.Data) == 0 {
|
||||
continue
|
||||
}
|
||||
flushText()
|
||||
dataURL := fmt.Sprintf("data:%s;base64,%s", c.Image.MediaType, base64.StdEncoding.EncodeToString(c.Image.Data))
|
||||
parts = append(parts, oai.ImageContentPart(oai.ChatCompletionContentPartImageImageURLParam{
|
||||
URL: dataURL,
|
||||
}))
|
||||
}
|
||||
}
|
||||
flushText()
|
||||
return parts
|
||||
}
|
||||
|
||||
// --- gnoma → OpenAI ---
|
||||
|
||||
func translateMessages(msgs []message.Message) []oai.ChatCompletionMessageParamUnion {
|
||||
@@ -67,6 +100,12 @@ func translateMessage(m message.Message) []oai.ChatCompletionMessageParamUnion {
|
||||
}
|
||||
return msgs
|
||||
}
|
||||
// Inline images → content parts array; pure text → plain string.
|
||||
if m.HasImages() {
|
||||
return []oai.ChatCompletionMessageParamUnion{
|
||||
oai.UserMessage(buildUserContentParts(m.Content)),
|
||||
}
|
||||
}
|
||||
return []oai.ChatCompletionMessageParamUnion{
|
||||
oai.UserMessage(m.TextContent()),
|
||||
}
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"somegit.dev/Owlibou/gnoma/internal/message"
|
||||
@@ -10,6 +12,85 @@ import (
|
||||
"github.com/openai/openai-go/packages/param"
|
||||
)
|
||||
|
||||
func TestTranslateMessage_UserTextOnly_UsesStringContent(t *testing.T) {
|
||||
m := message.NewUserText("hello")
|
||||
out := translateMessage(m)
|
||||
if len(out) != 1 {
|
||||
t.Fatalf("got %d messages, want 1", len(out))
|
||||
}
|
||||
user := out[0].OfUser
|
||||
if user == nil {
|
||||
t.Fatal("expected OfUser to be set")
|
||||
}
|
||||
if user.Content.OfString.Value != "hello" {
|
||||
t.Errorf("OfString = %q, want %q", user.Content.OfString.Value, "hello")
|
||||
}
|
||||
if len(user.Content.OfArrayOfContentParts) != 0 {
|
||||
t.Errorf("OfArrayOfContentParts should be empty when no image, got %d parts", len(user.Content.OfArrayOfContentParts))
|
||||
}
|
||||
}
|
||||
|
||||
func TestTranslateMessage_UserWithImage_EmitsContentParts(t *testing.T) {
|
||||
pngBytes := []byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}
|
||||
m := message.Message{
|
||||
Role: message.RoleUser,
|
||||
Content: []message.Content{
|
||||
message.NewTextContent("what is this?"),
|
||||
message.NewImageContent(message.Image{
|
||||
Data: pngBytes,
|
||||
MediaType: "image/png",
|
||||
Path: "/tmp/x.png",
|
||||
}),
|
||||
},
|
||||
}
|
||||
out := translateMessage(m)
|
||||
if len(out) != 1 {
|
||||
t.Fatalf("got %d messages, want 1", len(out))
|
||||
}
|
||||
user := out[0].OfUser
|
||||
if user == nil {
|
||||
t.Fatal("expected OfUser to be set")
|
||||
}
|
||||
parts := user.Content.OfArrayOfContentParts
|
||||
if len(parts) != 2 {
|
||||
t.Fatalf("got %d content parts, want 2 (text + image)", len(parts))
|
||||
}
|
||||
gotText := parts[0].GetText()
|
||||
if gotText == nil || *gotText != "what is this?" {
|
||||
t.Errorf("first part should be text %q, got %v", "what is this?", gotText)
|
||||
}
|
||||
gotImg := parts[1].GetImageURL()
|
||||
if gotImg == nil {
|
||||
t.Fatal("second part should be image")
|
||||
}
|
||||
wantPrefix := "data:image/png;base64,"
|
||||
if !strings.HasPrefix(gotImg.URL, wantPrefix) {
|
||||
t.Errorf("image URL %q should start with %q", gotImg.URL, wantPrefix)
|
||||
}
|
||||
decoded, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(gotImg.URL, wantPrefix))
|
||||
if err != nil {
|
||||
t.Fatalf("base64 decode: %v", err)
|
||||
}
|
||||
if string(decoded) != string(pngBytes) {
|
||||
t.Error("decoded image bytes do not match original")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildUserContentParts_DropsEmptyImage(t *testing.T) {
|
||||
blocks := []message.Content{
|
||||
message.NewTextContent("a"),
|
||||
{Type: message.ContentImage, Image: nil},
|
||||
message.NewTextContent("b"),
|
||||
}
|
||||
parts := buildUserContentParts(blocks)
|
||||
if len(parts) != 1 {
|
||||
t.Fatalf("got %d parts, want 1 (adjacent text concatenated, nil image dropped)", len(parts))
|
||||
}
|
||||
if got := parts[0].GetText(); got == nil || *got != "ab" {
|
||||
t.Errorf("merged text = %v, want %q", got, "ab")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTranslateMessage_AssistantToolCallNames_Sanitized(t *testing.T) {
|
||||
msg := message.Message{
|
||||
Role: message.RoleAssistant,
|
||||
|
||||
@@ -4,8 +4,10 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/exec"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -25,6 +27,7 @@ const (
|
||||
FormatGeminiStreamJSON StreamFormat = "gemini-stream-json"
|
||||
FormatVibeStreaming StreamFormat = "vibe-streaming"
|
||||
FormatAgyText StreamFormat = "agy-text"
|
||||
FormatCodexStreamJSON StreamFormat = "codex-stream-json"
|
||||
)
|
||||
|
||||
// CLIAgent describes a known CLI agent binary.
|
||||
@@ -100,13 +103,8 @@ var knownAgents = []CLIAgent{
|
||||
Name: "agy",
|
||||
DisplayName: "Antigravity",
|
||||
ProbeArgs: []string{"--version"},
|
||||
PromptArgs: func(p string) []string {
|
||||
// --dangerously-skip-permissions parallels gemini's --yolo and
|
||||
// vibe's --trust: required for non-interactive runs since stdin
|
||||
// is closed and we cannot answer permission prompts.
|
||||
return []string{"--print", p, "--dangerously-skip-permissions"}
|
||||
},
|
||||
Format: FormatAgyText,
|
||||
PromptArgs: agyPromptArgs,
|
||||
Format: FormatAgyText,
|
||||
// JSONOutput / Vision left false: agy v1.0.0 has no native
|
||||
// structured-output flag and no image-input mechanism. JSON support
|
||||
// is faked via PromptResponseFormat (best-effort, model-dependent);
|
||||
@@ -117,6 +115,66 @@ var knownAgents = []CLIAgent{
|
||||
},
|
||||
PromptResponseFormat: true,
|
||||
},
|
||||
{
|
||||
Name: "codex",
|
||||
DisplayName: "Codex CLI",
|
||||
ProbeArgs: []string{"--version"},
|
||||
PromptArgs: codexPromptArgs,
|
||||
Format: FormatCodexStreamJSON,
|
||||
Capabilities: provider.Capabilities{
|
||||
ToolUse: true,
|
||||
ContextWindow: 200000,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// agySandboxBypassEnv toggles the --dangerously-skip-permissions flag passed
|
||||
// to agy. Defaults to "on" because agy's stdin is closed in our
|
||||
// non-interactive invocation; without the flag the CLI blocks on permission
|
||||
// prompts that nobody can answer. Mirrors the codex env in shape and
|
||||
// default for consistency.
|
||||
const agySandboxBypassEnv = "GNOMA_AGY_BYPASS_PERMISSIONS"
|
||||
|
||||
func agyBypassPermissions() bool {
|
||||
switch strings.ToLower(strings.TrimSpace(os.Getenv(agySandboxBypassEnv))) {
|
||||
case "0", "false", "no", "off":
|
||||
return false
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
func agyPromptArgs(p string) []string {
|
||||
args := []string{"--print", p}
|
||||
if agyBypassPermissions() {
|
||||
args = append(args, "--dangerously-skip-permissions")
|
||||
}
|
||||
return args
|
||||
}
|
||||
|
||||
// codexSandboxBypassEnv toggles the --dangerously-bypass-approvals-and-sandbox
|
||||
// flag passed to codex. Defaults to "on" because codex's stdin is closed in
|
||||
// the non-interactive `exec` mode we use; without the bypass the CLI blocks
|
||||
// waiting for an approval prompt that nobody can answer and the turn hangs.
|
||||
// Operators who pre-approve via codex's own config (e.g. a workspace-level
|
||||
// trust file) can set this to "0", "false", or "no" to drop the flag.
|
||||
const codexSandboxBypassEnv = "GNOMA_CODEX_BYPASS_SANDBOX"
|
||||
|
||||
func codexBypassSandbox() bool {
|
||||
switch strings.ToLower(strings.TrimSpace(os.Getenv(codexSandboxBypassEnv))) {
|
||||
case "0", "false", "no", "off":
|
||||
return false
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
func codexPromptArgs(p string) []string {
|
||||
args := []string{"exec", p, "--json"}
|
||||
if codexBypassSandbox() {
|
||||
args = append(args, "--dangerously-bypass-approvals-and-sandbox")
|
||||
}
|
||||
return args
|
||||
}
|
||||
|
||||
// newParser returns a FormatParser for the given format.
|
||||
@@ -130,6 +188,8 @@ func newParser(f StreamFormat, rf *provider.ResponseFormat) FormatParser {
|
||||
return newVibeParser()
|
||||
case FormatAgyText:
|
||||
return newAgyParser(rf)
|
||||
case FormatCodexStreamJSON:
|
||||
return newCodexParser()
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -54,6 +54,7 @@ func TestKnownAgents_ValidFormats(t *testing.T) {
|
||||
FormatGeminiStreamJSON: true,
|
||||
FormatVibeStreaming: true,
|
||||
FormatAgyText: true,
|
||||
FormatCodexStreamJSON: true,
|
||||
}
|
||||
for _, a := range knownAgents {
|
||||
if !valid[a.Format] {
|
||||
@@ -84,7 +85,7 @@ func TestNewParser_ReturnsParserForKnownFormats(t *testing.T) {
|
||||
FormatClaudeStreamJSON,
|
||||
FormatGeminiStreamJSON,
|
||||
FormatVibeStreaming,
|
||||
FormatAgyText,
|
||||
FormatCodexStreamJSON,
|
||||
}
|
||||
for _, f := range formats {
|
||||
p := newParser(f, nil)
|
||||
|
||||
@@ -0,0 +1,224 @@
|
||||
package subprocess
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"somegit.dev/Owlibou/gnoma/internal/message"
|
||||
"somegit.dev/Owlibou/gnoma/internal/stream"
|
||||
)
|
||||
|
||||
func TestCodexPromptArgs_BypassDefaultsOn(t *testing.T) {
|
||||
t.Setenv("GNOMA_CODEX_BYPASS_SANDBOX", "")
|
||||
args := codexPromptArgs("hi")
|
||||
if !slices.Contains(args, "--dangerously-bypass-approvals-and-sandbox") {
|
||||
t.Errorf("default args should include sandbox bypass; got %v", args)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexPromptArgs_BypassOptOut(t *testing.T) {
|
||||
for _, val := range []string{"0", "false", "no", "off", "FALSE"} {
|
||||
t.Run(val, func(t *testing.T) {
|
||||
t.Setenv("GNOMA_CODEX_BYPASS_SANDBOX", val)
|
||||
args := codexPromptArgs("hi")
|
||||
if slices.Contains(args, "--dangerously-bypass-approvals-and-sandbox") {
|
||||
t.Errorf("env=%q should drop bypass flag; got %v", val, args)
|
||||
}
|
||||
if !slices.Contains(args, "exec") || !slices.Contains(args, "--json") {
|
||||
t.Errorf("required base args missing; got %v", args)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexPromptArgs_UnknownValueDefaultsOn(t *testing.T) {
|
||||
t.Setenv("GNOMA_CODEX_BYPASS_SANDBOX", "maybe")
|
||||
args := codexPromptArgs("hi")
|
||||
if !slices.Contains(args, "--dangerously-bypass-approvals-and-sandbox") {
|
||||
t.Errorf("non-falsy value should keep bypass on; got %v", args)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexParser_ExtractsTextDelta(t *testing.T) {
|
||||
p := newCodexParser()
|
||||
line := []byte(`{"type":"item.completed","item":{"type":"agent_message","text":"hello world"}}`)
|
||||
|
||||
evts, err := p.ParseLine(line)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(evts) == 0 {
|
||||
t.Fatal("expected at least one event")
|
||||
}
|
||||
if evts[0].Type != stream.EventTextDelta {
|
||||
t.Errorf("got type %v, want EventTextDelta", evts[0].Type)
|
||||
}
|
||||
if evts[0].Text != "hello world" {
|
||||
t.Errorf("got text %q, want %q", evts[0].Text, "hello world")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexParser_ExtractsUsageFromTurnCompleted(t *testing.T) {
|
||||
p := newCodexParser()
|
||||
line := []byte(`{"type":"turn.completed","usage":{"input_tokens":123,"output_tokens":45}}`)
|
||||
|
||||
evts, err := p.ParseLine(line)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var usageEvt *stream.Event
|
||||
for i := range evts {
|
||||
if evts[i].Type == stream.EventUsage {
|
||||
usageEvt = &evts[i]
|
||||
}
|
||||
}
|
||||
if usageEvt == nil {
|
||||
t.Fatal("no EventUsage emitted")
|
||||
}
|
||||
if usageEvt.Usage.InputTokens != 123 {
|
||||
t.Errorf("input_tokens: got %d, want 123", usageEvt.Usage.InputTokens)
|
||||
}
|
||||
if usageEvt.Usage.OutputTokens != 45 {
|
||||
t.Errorf("output_tokens: got %d, want 45", usageEvt.Usage.OutputTokens)
|
||||
}
|
||||
if usageEvt.StopReason != message.StopEndTurn {
|
||||
t.Errorf("stop_reason: got %v, want StopEndTurn", usageEvt.StopReason)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexParser_ExtractsUsageFromPromptCompletionTokens(t *testing.T) {
|
||||
p := newCodexParser()
|
||||
line := []byte(`{"type":"turn.completed","usage":{"prompt_tokens":123,"completion_tokens":45}}`)
|
||||
|
||||
evts, err := p.ParseLine(line)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var usageEvt *stream.Event
|
||||
for i := range evts {
|
||||
if evts[i].Type == stream.EventUsage {
|
||||
usageEvt = &evts[i]
|
||||
}
|
||||
}
|
||||
if usageEvt == nil {
|
||||
t.Fatal("no EventUsage emitted")
|
||||
}
|
||||
if usageEvt.Usage.InputTokens != 123 {
|
||||
t.Errorf("input_tokens: got %d, want 123", usageEvt.Usage.InputTokens)
|
||||
}
|
||||
if usageEvt.Usage.OutputTokens != 45 {
|
||||
t.Errorf("output_tokens: got %d, want 45", usageEvt.Usage.OutputTokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexParser_IgnoresOtherItemsAndTypes(t *testing.T) {
|
||||
p := newCodexParser()
|
||||
lines := [][]byte{
|
||||
[]byte(`{"type":"item.completed","item":{"type":"tool_call","text":"something"}}`),
|
||||
[]byte(`{"type":"other_type"}`),
|
||||
}
|
||||
|
||||
for _, line := range lines {
|
||||
evts, err := p.ParseLine(line)
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
if len(evts) != 0 {
|
||||
t.Errorf("expected 0 events, got %d", len(evts))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexParser_SkipsNonJSONBanners(t *testing.T) {
|
||||
p := newCodexParser()
|
||||
// Real codex output interleaves banner lines, blank lines, and
|
||||
// human-readable warnings with the JSON event stream. None of
|
||||
// these may abort the turn — only the JSON events matter.
|
||||
lines := [][]byte{
|
||||
[]byte(""),
|
||||
[]byte(" "),
|
||||
[]byte("codex v1.2.3 starting"),
|
||||
[]byte(`WARNING: sandbox bypass enabled`),
|
||||
[]byte(`{"type":"item.completed","item":{"type":"agent_message","text":"ok"}}`),
|
||||
[]byte("trailing diagnostics: 42ms"),
|
||||
}
|
||||
var sawText bool
|
||||
for _, line := range lines {
|
||||
evts, err := p.ParseLine(line)
|
||||
if err != nil {
|
||||
t.Errorf("non-JSON line %q caused error: %v", string(line), err)
|
||||
continue
|
||||
}
|
||||
for _, e := range evts {
|
||||
if e.Type == stream.EventTextDelta {
|
||||
sawText = true
|
||||
}
|
||||
}
|
||||
}
|
||||
if !sawText {
|
||||
t.Error("legitimate JSON line was swallowed by banner-skip logic")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexParser_MalformedJSONSkippedNotFatal(t *testing.T) {
|
||||
p := newCodexParser()
|
||||
// Starts with `{` so the banner-skip heuristic doesn't filter it,
|
||||
// but is not valid JSON — must skip silently, not return an error.
|
||||
bad := []byte(`{"type":"item.completed",`)
|
||||
evts, err := p.ParseLine(bad)
|
||||
if err != nil {
|
||||
t.Errorf("malformed JSON should be skipped, got error: %v", err)
|
||||
}
|
||||
if len(evts) != 0 {
|
||||
t.Errorf("expected 0 events from malformed JSON, got %d", len(evts))
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexParser_UsageMaxOfPaths(t *testing.T) {
|
||||
// Both input_tokens and prompt_tokens present with different values
|
||||
// — accounting must not silently undercount by always preferring
|
||||
// one field.
|
||||
p := newCodexParser()
|
||||
line := []byte(`{"type":"turn.completed","usage":{"input_tokens":100,"prompt_tokens":120,"output_tokens":30,"completion_tokens":35}}`)
|
||||
evts, err := p.ParseLine(line)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(evts) != 1 || evts[0].Type != stream.EventUsage {
|
||||
t.Fatalf("expected single EventUsage, got %+v", evts)
|
||||
}
|
||||
if evts[0].Usage.InputTokens != 120 {
|
||||
t.Errorf("input tokens = %d, want max(100, 120) = 120", evts[0].Usage.InputTokens)
|
||||
}
|
||||
if evts[0].Usage.OutputTokens != 35 {
|
||||
t.Errorf("output tokens = %d, want max(30, 35) = 35", evts[0].Usage.OutputTokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexParser_FixtureFile(t *testing.T) {
|
||||
lines := loadFixture(t, "codex")
|
||||
p := newCodexParser()
|
||||
evts := collectEvents(t, p, lines)
|
||||
|
||||
var textEvts, usageEvts int
|
||||
for _, e := range evts {
|
||||
switch e.Type {
|
||||
case stream.EventTextDelta:
|
||||
textEvts++
|
||||
if e.Text != "hello" {
|
||||
t.Errorf("expected text 'hello', got %q", e.Text)
|
||||
}
|
||||
case stream.EventUsage:
|
||||
usageEvts++
|
||||
if e.Usage.InputTokens != 10 || e.Usage.OutputTokens != 5 {
|
||||
t.Errorf("expected 10/5 tokens, got %d/%d", e.Usage.InputTokens, e.Usage.OutputTokens)
|
||||
}
|
||||
}
|
||||
}
|
||||
if textEvts != 1 {
|
||||
t.Errorf("expected 1 EventTextDelta, got %d", textEvts)
|
||||
}
|
||||
if usageEvts != 1 {
|
||||
t.Errorf("expected 1 EventUsage, got %d", usageEvts)
|
||||
}
|
||||
}
|
||||
@@ -1,8 +1,10 @@
|
||||
package subprocess
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
|
||||
"somegit.dev/Owlibou/gnoma/internal/message"
|
||||
"somegit.dev/Owlibou/gnoma/internal/provider"
|
||||
@@ -249,3 +251,87 @@ func (p *agyParser) ParseLine(line []byte) ([]stream.Event, error) {
|
||||
}
|
||||
|
||||
func (p *agyParser) Done() []stream.Event { return nil }
|
||||
|
||||
// --- codex-stream-json ---
|
||||
// Format emitted by: codex exec "..." --json --dangerously-bypass-approvals-and-sandbox
|
||||
//
|
||||
// Relevant event types:
|
||||
// type=item.completed, item.type=agent_message → EventTextDelta (using item.text)
|
||||
// type=turn.completed → EventUsage (using usage)
|
||||
|
||||
type codexParser struct{}
|
||||
|
||||
func newCodexParser() FormatParser { return &codexParser{} }
|
||||
|
||||
type codexEvent struct {
|
||||
Type string `json:"type"`
|
||||
Item *codexItem `json:"item,omitempty"`
|
||||
Usage *codexUsage `json:"usage,omitempty"`
|
||||
}
|
||||
|
||||
type codexItem struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text"`
|
||||
}
|
||||
|
||||
type codexUsage struct {
|
||||
InputTokens int64 `json:"input_tokens"`
|
||||
OutputTokens int64 `json:"output_tokens"`
|
||||
PromptTokens int64 `json:"prompt_tokens"`
|
||||
CompletionTokens int64 `json:"completion_tokens"`
|
||||
}
|
||||
|
||||
func (p *codexParser) ParseLine(line []byte) ([]stream.Event, error) {
|
||||
// Codex emits banner/debug lines to stdout interleaved with the JSON
|
||||
// event stream (version notes, sandbox warnings, "starting turn" log
|
||||
// lines, etc.). Skip anything that isn't a JSON object so a stray
|
||||
// banner can't abort the turn — subprocessStream.Next treats a
|
||||
// parser error as terminal.
|
||||
trimmed := bytes.TrimSpace(line)
|
||||
if len(trimmed) == 0 || trimmed[0] != '{' {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var ev codexEvent
|
||||
if err := json.Unmarshal(trimmed, &ev); err != nil {
|
||||
// Looks like JSON but won't parse — log and skip rather than
|
||||
// killing the stream; codex JSON-line output is the only path
|
||||
// we have to recover from a malformed line.
|
||||
slog.Debug("codex: skipping unparseable JSON line", "err", err, "line", string(trimmed))
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
switch ev.Type {
|
||||
case "item.completed":
|
||||
if ev.Item != nil && ev.Item.Type == "agent_message" && ev.Item.Text != "" {
|
||||
return []stream.Event{{Type: stream.EventTextDelta, Text: ev.Item.Text}}, nil
|
||||
}
|
||||
case "turn.completed":
|
||||
if ev.Usage != nil {
|
||||
// Some codex builds emit input_tokens, others (older) emit
|
||||
// prompt_tokens; new builds occasionally include both with
|
||||
// slightly different values. max() prevents silent
|
||||
// undercounting when both are non-zero.
|
||||
input := ev.Usage.InputTokens
|
||||
if ev.Usage.PromptTokens > input {
|
||||
input = ev.Usage.PromptTokens
|
||||
}
|
||||
output := ev.Usage.OutputTokens
|
||||
if ev.Usage.CompletionTokens > output {
|
||||
output = ev.Usage.CompletionTokens
|
||||
}
|
||||
return []stream.Event{{
|
||||
Type: stream.EventUsage,
|
||||
Usage: &message.Usage{
|
||||
InputTokens: input,
|
||||
OutputTokens: output,
|
||||
},
|
||||
StopReason: message.StopEndTurn,
|
||||
}}, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (p *codexParser) Done() []stream.Event { return nil }
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
// Package subprocess provides a provider.Provider that delegates to CLI agents
|
||||
// (claude, gemini, vibe, agy) by spawning them as subprocesses.
|
||||
// (claude, gemini, vibe, codex) by spawning them as subprocesses.
|
||||
//
|
||||
// Impedance mismatch: these CLI agents are full agentic loops, not LLM endpoints.
|
||||
// Only the latest user message is passed as a prompt. The following provider.Request
|
||||
// fields are intentionally ignored: Tools, SystemPrompt, Messages (history),
|
||||
// Temperature, TopP, TopK, Thinking, ToolChoice, MaxTokens.
|
||||
// ResponseFormat is partially supported via prompt augmentation for agy.
|
||||
// Internal tool calls executed by the CLI are surfaced as EventTextDelta (opaque).
|
||||
//
|
||||
// SECURITY WARNING: These CLI agents are external trust boundaries. They run
|
||||
@@ -38,7 +37,7 @@ func New(agent DiscoveredAgent) *Provider {
|
||||
// Name returns "subprocess" — all CLI agents share this provider namespace.
|
||||
func (p *Provider) Name() string { return "subprocess" }
|
||||
|
||||
// DefaultModel returns the CLI binary name (e.g., "claude", "gemini", "vibe", "agy").
|
||||
// DefaultModel returns the CLI binary name (e.g., "claude", "gemini", "vibe", "codex").
|
||||
func (p *Provider) DefaultModel() string { return p.agent.Name }
|
||||
|
||||
// Models returns a single ModelInfo describing this CLI agent.
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
{"type":"item.completed", "item":{"type":"agent_message", "text":"hello"}}
|
||||
{"type":"item.completed", "item":{"type":"tool_call", "text":"ignored"}}
|
||||
{"type":"turn.completed", "usage":{"input_tokens": 10, "output_tokens": 5}}
|
||||
+112
-37
@@ -25,19 +25,31 @@ const (
|
||||
|
||||
// DiscoveredModel represents a model found via discovery.
|
||||
type DiscoveredModel struct {
|
||||
ID string
|
||||
Name string
|
||||
Provider string // "ollama" or "llamacpp"
|
||||
Size int64 // bytes, if available
|
||||
SupportsTools bool // whether the model supports function/tool calling
|
||||
ContextSize int // context window in tokens (always populated; provider-specific default if probe was inconclusive)
|
||||
ID string
|
||||
Name string
|
||||
Provider string // "ollama" or "llamacpp"
|
||||
Size int64 // bytes, if available
|
||||
SupportsTools bool // whether the model supports function/tool calling
|
||||
SupportsVision bool // whether the model accepts image inputs (multimodal)
|
||||
ContextSize int // context window in tokens (always populated; provider-specific default if probe was inconclusive)
|
||||
}
|
||||
|
||||
// OllamaProbeResult bundles the capabilities probed from a single
|
||||
// /api/show call. Cached per model name so discovery cycles don't re-probe
|
||||
// every model. SupportsVision was added alongside SupportsTools; older
|
||||
// callers using `map[string]bool` should migrate to `map[string]OllamaProbeResult`.
|
||||
type OllamaProbeResult struct {
|
||||
SupportsTools bool
|
||||
SupportsVision bool
|
||||
ContextSize int
|
||||
}
|
||||
|
||||
// DiscoverOllama polls the local Ollama instance for available models.
|
||||
// toolCache caches /api/show probe results per model name to avoid N requests
|
||||
// per discovery cycle. Pass nil to probe every model unconditionally.
|
||||
// The caller owns the cache and should pass the same map across cycles.
|
||||
func DiscoverOllama(ctx context.Context, baseURL string, toolCache map[string]bool) ([]DiscoveredModel, error) {
|
||||
// probeCache caches /api/show probe results per model name to avoid N
|
||||
// requests per discovery cycle. Pass nil to probe every model
|
||||
// unconditionally. The caller owns the cache and should pass the same
|
||||
// map across cycles.
|
||||
func DiscoverOllama(ctx context.Context, baseURL string, probeCache map[string]OllamaProbeResult) ([]DiscoveredModel, error) {
|
||||
if baseURL == "" {
|
||||
baseURL = "http://localhost:11434"
|
||||
}
|
||||
@@ -81,17 +93,15 @@ func DiscoverOllama(ctx context.Context, baseURL string, toolCache map[string]bo
|
||||
Size: m.Size,
|
||||
}
|
||||
|
||||
// Try to probe capabilities if we have a cache or if we want to probe
|
||||
if toolCache != nil {
|
||||
if supported, ok := toolCache[m.Name]; ok {
|
||||
dm.SupportsTools = supported
|
||||
} else {
|
||||
// Probe once
|
||||
supported, contextSize := probeOllamaModel(ctx, baseURL, m.Name)
|
||||
toolCache[m.Name] = supported
|
||||
dm.SupportsTools = supported
|
||||
dm.ContextSize = contextSize
|
||||
if probeCache != nil {
|
||||
result, ok := probeCache[m.Name]
|
||||
if !ok {
|
||||
result = probeOllamaModel(ctx, baseURL, m.Name)
|
||||
probeCache[m.Name] = result
|
||||
}
|
||||
dm.SupportsTools = result.SupportsTools
|
||||
dm.SupportsVision = result.SupportsVision
|
||||
dm.ContextSize = result.ContextSize
|
||||
}
|
||||
|
||||
if dm.ContextSize == 0 {
|
||||
@@ -103,43 +113,75 @@ func DiscoverOllama(ctx context.Context, baseURL string, toolCache map[string]bo
|
||||
|
||||
// Prune cache entries for models that have disappeared since the last
|
||||
// poll. Without this, the cache grows unbounded and stale entries linger
|
||||
// (a reappearing model would replay an out-of-date tool-support verdict).
|
||||
for name := range toolCache {
|
||||
// (a reappearing model would replay an out-of-date probe verdict).
|
||||
for name := range probeCache {
|
||||
if !currentModels[name] {
|
||||
delete(toolCache, name)
|
||||
delete(probeCache, name)
|
||||
}
|
||||
}
|
||||
return discovered, nil
|
||||
}
|
||||
|
||||
func probeOllamaModel(ctx context.Context, baseURL, model string) (bool, int) {
|
||||
func probeOllamaModel(ctx context.Context, baseURL, model string) OllamaProbeResult {
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", baseURL+"/api/show", strings.NewReader(fmt.Sprintf(`{"name":"%s"}`, model)))
|
||||
if err != nil {
|
||||
return false, 0
|
||||
return OllamaProbeResult{}
|
||||
}
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return false, 0
|
||||
return OllamaProbeResult{}
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
if resp.StatusCode != 200 {
|
||||
return false, 0
|
||||
return OllamaProbeResult{}
|
||||
}
|
||||
var data struct {
|
||||
Template string `json:"template"`
|
||||
Parameters string `json:"parameters"`
|
||||
Details struct {
|
||||
Families []string `json:"families"`
|
||||
Family string `json:"family"`
|
||||
} `json:"details"`
|
||||
Capabilities []string `json:"capabilities"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&data); err != nil {
|
||||
return false, 0
|
||||
return OllamaProbeResult{}
|
||||
}
|
||||
|
||||
// Heuristic for tool support: many modern models that support tools
|
||||
// have "call" or "tool" or "json" in their template or system prompt
|
||||
// logic. More specifically, Ollama's own tool-calling models often
|
||||
// include specific jinja templates.
|
||||
supported := strings.Contains(data.Template, ".Tool") ||
|
||||
// include specific jinja templates. Newer Ollama versions also
|
||||
// advertise capabilities via the "capabilities" field.
|
||||
supportsTools := strings.Contains(data.Template, ".Tool") ||
|
||||
strings.Contains(data.Template, "tools") ||
|
||||
strings.Contains(data.Template, "json")
|
||||
for _, cap := range data.Capabilities {
|
||||
if cap == "tools" {
|
||||
supportsTools = true
|
||||
}
|
||||
}
|
||||
|
||||
// Vision detection: CLIP/vision encoder families show up in
|
||||
// details.families (e.g. "clip", "mllama"); newer Ollama also lists
|
||||
// "vision" in the capabilities array. Fall back to a name-pattern
|
||||
// match for releases that predate the capabilities field.
|
||||
supportsVision := false
|
||||
for _, fam := range data.Details.Families {
|
||||
f := strings.ToLower(fam)
|
||||
if f == "clip" || f == "mllama" || strings.HasSuffix(f, "vl") {
|
||||
supportsVision = true
|
||||
break
|
||||
}
|
||||
}
|
||||
for _, cap := range data.Capabilities {
|
||||
if cap == "vision" {
|
||||
supportsVision = true
|
||||
}
|
||||
}
|
||||
if !supportsVision && isKnownVisionModelName(model) {
|
||||
supportsVision = true
|
||||
}
|
||||
|
||||
// Context size heuristic from parameters
|
||||
contextSize := 0
|
||||
@@ -154,7 +196,39 @@ func probeOllamaModel(ctx context.Context, baseURL, model string) (bool, int) {
|
||||
}
|
||||
}
|
||||
|
||||
return supported, contextSize
|
||||
return OllamaProbeResult{
|
||||
SupportsTools: supportsTools,
|
||||
SupportsVision: supportsVision,
|
||||
ContextSize: contextSize,
|
||||
}
|
||||
}
|
||||
|
||||
// knownVisionModelPrefixes lists Ollama model name prefixes that ship as
|
||||
// multimodal models. Used as a fallback when the /api/show response is
|
||||
// missing details.families or the capabilities array (older Ollama).
|
||||
var knownVisionModelPrefixes = []string{
|
||||
"llava",
|
||||
"bakllava",
|
||||
"moondream",
|
||||
"qwen2-vl",
|
||||
"qwen2.5-vl",
|
||||
"qwen3-vl",
|
||||
"llama3.2-vision",
|
||||
"llama4-vision",
|
||||
"minicpm-v",
|
||||
"cogvlm",
|
||||
"pixtral",
|
||||
"gemma3", // gemma3 multimodal variants
|
||||
}
|
||||
|
||||
func isKnownVisionModelName(model string) bool {
|
||||
low := strings.ToLower(model)
|
||||
for _, p := range knownVisionModelPrefixes {
|
||||
if strings.HasPrefix(low, p) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// DiscoverLlamaCPP enumerates models served by a llama.cpp server.
|
||||
@@ -261,10 +335,10 @@ func fetchLlamaCppContextSize(ctx context.Context, baseURL string) int {
|
||||
}
|
||||
|
||||
// DiscoverLocalModels polls all known local providers.
|
||||
func DiscoverLocalModels(ctx context.Context, logger *slog.Logger, ollamaURL, llamacppURL string, ollamaToolCache map[string]bool) []DiscoveredModel {
|
||||
func DiscoverLocalModels(ctx context.Context, logger *slog.Logger, ollamaURL, llamacppURL string, ollamaProbeCache map[string]OllamaProbeResult) []DiscoveredModel {
|
||||
var all []DiscoveredModel
|
||||
|
||||
if models, err := DiscoverOllama(ctx, ollamaURL, ollamaToolCache); err != nil {
|
||||
if models, err := DiscoverOllama(ctx, ollamaURL, ollamaProbeCache); err != nil {
|
||||
logger.Debug("ollama discovery skipped", "error", err)
|
||||
} else {
|
||||
all = append(all, models...)
|
||||
@@ -288,7 +362,7 @@ func StartDiscoveryLoop(ctx context.Context, r *Router, logger *slog.Logger,
|
||||
onReconcile func(ArmID),
|
||||
) {
|
||||
go func() {
|
||||
ollamaToolCache := make(map[string]bool)
|
||||
ollamaProbeCache := make(map[string]OllamaProbeResult)
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
@@ -296,7 +370,7 @@ func StartDiscoveryLoop(ctx context.Context, r *Router, logger *slog.Logger,
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
models := DiscoverLocalModels(ctx, logger, ollamaURL, llamacppURL, ollamaToolCache)
|
||||
models := DiscoverLocalModels(ctx, logger, ollamaURL, llamacppURL, ollamaProbeCache)
|
||||
reconcileArms(r, models, providerFactory, logger, onReconcile)
|
||||
}
|
||||
}
|
||||
@@ -390,9 +464,10 @@ func RegisterDiscoveredModels(r *Router, models []DiscoveredModel, providerFacto
|
||||
// Many small local models (phi, etc.) don't support
|
||||
// function calling and will produce confused output if selected
|
||||
// for tool-requiring tasks. Larger known models (mistral, llama3,
|
||||
// qwen2.5-coder, tiny3.5) support tools. Callers can update the arm's
|
||||
// Capabilities after probing the model template.
|
||||
// qwen2.5-coder, tiny3.5) support tools. Vision is set from the
|
||||
// /api/show probe (capabilities/families/name fallback).
|
||||
ToolUse: m.SupportsTools,
|
||||
Vision: m.SupportsVision,
|
||||
ContextWindow: m.ContextSize,
|
||||
},
|
||||
})
|
||||
|
||||
@@ -270,7 +270,7 @@ func TestDiscoverOllama_AppliesDefaultContextSize(t *testing.T) {
|
||||
srv := stub.server()
|
||||
defer srv.Close()
|
||||
|
||||
cache := map[string]bool{}
|
||||
cache := map[string]OllamaProbeResult{}
|
||||
models, err := DiscoverOllama(context.Background(), srv.URL, cache)
|
||||
if err != nil {
|
||||
t.Fatalf("DiscoverOllama: %v", err)
|
||||
@@ -296,10 +296,10 @@ func TestDiscoverOllama_PrunesCacheOnDisappearance(t *testing.T) {
|
||||
srv := stub.server()
|
||||
defer srv.Close()
|
||||
|
||||
cache := map[string]bool{
|
||||
"alive:latest": true,
|
||||
"ghost:latest": true, // not in tags response — must be pruned
|
||||
"another-ghost": false,
|
||||
cache := map[string]OllamaProbeResult{
|
||||
"alive:latest": {SupportsTools: true},
|
||||
"ghost:latest": {SupportsTools: true}, // not in tags response — must be pruned
|
||||
"another-ghost": {},
|
||||
}
|
||||
if _, err := DiscoverOllama(context.Background(), srv.URL, cache); err != nil {
|
||||
t.Fatalf("DiscoverOllama: %v", err)
|
||||
|
||||
@@ -236,6 +236,14 @@ func filterFeasible(arms []*Arm, task Task) []*Arm {
|
||||
continue
|
||||
}
|
||||
|
||||
// Must support vision if task carries inline image content.
|
||||
// No tools/quality fallback for vision: a non-vision arm physically
|
||||
// cannot consume the image bytes, so degrading to it would silently
|
||||
// drop the image and confuse the model.
|
||||
if task.RequiresVision && !arm.Capabilities.Vision {
|
||||
continue
|
||||
}
|
||||
|
||||
// Must support the required effort level (EffortAuto always passes)
|
||||
if !arm.Capabilities.SupportsEffort(task.RequiredEffort) {
|
||||
continue
|
||||
@@ -274,6 +282,12 @@ func filterFeasible(arms []*Arm, task Task) []*Arm {
|
||||
if !arm.Capabilities.ToolUse {
|
||||
continue
|
||||
}
|
||||
// Vision requirement is hard: a non-vision arm cannot
|
||||
// consume image bytes, so even the last-resort fallback
|
||||
// must respect it.
|
||||
if task.RequiresVision && !arm.Capabilities.Vision {
|
||||
continue
|
||||
}
|
||||
poolsOK := true
|
||||
for _, pool := range arm.Pools {
|
||||
if !pool.CanAfford(arm.ID, task.EstimatedTokens) {
|
||||
|
||||
@@ -91,6 +91,7 @@ type Task struct {
|
||||
Priority Priority
|
||||
EstimatedTokens int
|
||||
RequiresTools bool
|
||||
RequiresVision bool // input includes inline image content; arm must advertise Capabilities.Vision
|
||||
ComplexityScore float64 // 0-1
|
||||
RequiredEffort provider.EffortLevel // EffortAuto = no constraint on thinking
|
||||
ExcludedArms []ArmID // Arms to avoid (e.g. due to recent 429 errors)
|
||||
|
||||
@@ -0,0 +1,71 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"somegit.dev/Owlibou/gnoma/internal/provider"
|
||||
)
|
||||
|
||||
func TestFilterFeasible_RequiresVision_FiltersNonVisionArms(t *testing.T) {
|
||||
textOnly := &Arm{
|
||||
ID: NewArmID("ollama", "qwen2.5-coder:7b"),
|
||||
Capabilities: provider.Capabilities{
|
||||
ToolUse: true,
|
||||
Vision: false,
|
||||
ContextWindow: 32768,
|
||||
},
|
||||
}
|
||||
visionArm := &Arm{
|
||||
ID: NewArmID("ollama", "llava:7b"),
|
||||
Capabilities: provider.Capabilities{
|
||||
ToolUse: true,
|
||||
Vision: true,
|
||||
ContextWindow: 4096,
|
||||
},
|
||||
}
|
||||
arms := []*Arm{textOnly, visionArm}
|
||||
|
||||
t.Run("no image: both arms feasible", func(t *testing.T) {
|
||||
task := Task{Type: TaskGeneration, RequiresTools: true, RequiresVision: false}
|
||||
got := filterFeasible(arms, task)
|
||||
if len(got) != 2 {
|
||||
t.Errorf("got %d arms, want 2", len(got))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("image present: only vision arm feasible", func(t *testing.T) {
|
||||
task := Task{Type: TaskGeneration, RequiresTools: true, RequiresVision: true}
|
||||
got := filterFeasible(arms, task)
|
||||
if len(got) != 1 {
|
||||
t.Fatalf("got %d arms, want 1", len(got))
|
||||
}
|
||||
if got[0].ID != visionArm.ID {
|
||||
t.Errorf("selected arm = %s, want %s", got[0].ID, visionArm.ID)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestFilterFeasible_RequiresVision_FallbackAlsoFilters(t *testing.T) {
|
||||
// All arms unavailable for normal quality path; fallback path must
|
||||
// still respect RequiresVision (can't degrade to a text-only arm
|
||||
// when the model literally cannot see the image).
|
||||
textOnly := &Arm{
|
||||
ID: NewArmID("ollama", "qwen2.5:0.5b"), // tiny → low quality
|
||||
Capabilities: provider.Capabilities{
|
||||
ToolUse: true,
|
||||
Vision: false,
|
||||
ContextWindow: 4096,
|
||||
},
|
||||
}
|
||||
arms := []*Arm{textOnly}
|
||||
|
||||
task := Task{
|
||||
Type: TaskGeneration,
|
||||
RequiresTools: true,
|
||||
RequiresVision: true,
|
||||
}
|
||||
got := filterFeasible(arms, task)
|
||||
if len(got) != 0 {
|
||||
t.Errorf("got %d arms, want 0 — non-vision arm must not be selected even as fallback", len(got))
|
||||
}
|
||||
}
|
||||
@@ -25,6 +25,7 @@ type FirewallConfig struct {
|
||||
ScanToolResults bool
|
||||
RedactHighEntropy bool
|
||||
EntropyThreshold float64
|
||||
EntropySafelist []string
|
||||
Logger *slog.Logger
|
||||
}
|
||||
|
||||
@@ -33,8 +34,20 @@ func NewFirewall(cfg FirewallConfig) *Firewall {
|
||||
if logger == nil {
|
||||
logger = slog.Default()
|
||||
}
|
||||
scanner := NewScanner(cfg.EntropyThreshold, cfg.RedactHighEntropy)
|
||||
scanner.SetLogger(logger)
|
||||
// Validate safelist names at the config boundary so a typo surfaces
|
||||
// loudly instead of silently disabling FP reduction.
|
||||
entries, unknown := splitSafelistNames(cfg.EntropySafelist)
|
||||
for _, name := range unknown {
|
||||
logger.Warn("ignoring unknown entropy safelist name",
|
||||
"name", name,
|
||||
"hint", "valid names: uuid, sha_hex, iso8601, url",
|
||||
)
|
||||
}
|
||||
scanner.safelist = entries
|
||||
return &Firewall{
|
||||
scanner: NewScanner(cfg.EntropyThreshold, cfg.RedactHighEntropy),
|
||||
scanner: scanner,
|
||||
incognito: NewIncognitoMode(),
|
||||
logger: logger,
|
||||
scanOutgoing: cfg.ScanOutgoing,
|
||||
|
||||
@@ -0,0 +1,97 @@
|
||||
package security
|
||||
|
||||
import "regexp"
|
||||
|
||||
// safelistEntry pairs a user-facing pattern name (the TOML knob value) with
|
||||
// its compiled regex. The name flows through to log fields so operators can
|
||||
// measure per-pattern FP-rate deltas — the data F-2's go/no-go decision
|
||||
// depends on.
|
||||
type safelistEntry struct {
|
||||
name string
|
||||
re *regexp.Regexp
|
||||
}
|
||||
|
||||
// safelistSpan is a half-open byte range [start, end) in the scanned content
|
||||
// that the user has declared as a known-safe shape (UUID, hash, URL, timestamp).
|
||||
// Tokens contained inside any span are skipped by scanEntropy — they never
|
||||
// reach the entropy scorer, so they cannot produce false positives under
|
||||
// lowered thresholds or redact_high_entropy = true.
|
||||
type safelistSpan struct {
|
||||
start int
|
||||
end int
|
||||
name string
|
||||
}
|
||||
|
||||
// defaultSafelistPatterns returns the curated allow-list of known-safe shapes,
|
||||
// keyed by the user-facing name accepted in [security].entropy_safelist.
|
||||
//
|
||||
// Adding a key here exposes a new opt-in name to user configs. Removing or
|
||||
// renaming a key is a breaking change.
|
||||
func defaultSafelistPatterns() map[string]*regexp.Regexp {
|
||||
return map[string]*regexp.Regexp{
|
||||
// UUID v1–5: 8-4-4-4-12 hex with hyphens. Case-insensitive.
|
||||
"uuid": regexp.MustCompile(`(?i)\b[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\b`),
|
||||
|
||||
// SHA-1 / SHA-256 / SHA-384 / SHA-512 hex digests.
|
||||
"sha_hex": regexp.MustCompile(`(?i)\b(?:[0-9a-f]{40}|[0-9a-f]{64}|[0-9a-f]{96}|[0-9a-f]{128})\b`),
|
||||
|
||||
// ISO-8601 timestamp (date + time, optional fractional seconds, optional zone).
|
||||
"iso8601": regexp.MustCompile(`\b\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}(?:\.\d+)?(?:Z|[+-]\d{2}:?\d{2})?\b`),
|
||||
|
||||
// RFC-3986-ish HTTP(S) URL. Greedy up to whitespace or quoting.
|
||||
"url": regexp.MustCompile(`\bhttps?://[^\s'"<>` + "`" + `]+`),
|
||||
}
|
||||
}
|
||||
|
||||
// splitSafelistNames partitions user-supplied names into resolved entries and
|
||||
// the list of unknown names. Callers (NewFirewall) surface unknowns so a typo
|
||||
// like "uid" instead of "uuid" doesn't silently disable the safelist.
|
||||
func splitSafelistNames(names []string) (entries []safelistEntry, unknown []string) {
|
||||
if len(names) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
defaults := defaultSafelistPatterns()
|
||||
for _, name := range names {
|
||||
if re, ok := defaults[name]; ok {
|
||||
entries = append(entries, safelistEntry{name: name, re: re})
|
||||
} else {
|
||||
unknown = append(unknown, name)
|
||||
}
|
||||
}
|
||||
return entries, unknown
|
||||
}
|
||||
|
||||
// buildSafelist resolves names to entries, dropping unknowns silently. Used
|
||||
// where the caller doesn't need to report typos (e.g. test setup).
|
||||
func buildSafelist(names []string) []safelistEntry {
|
||||
entries, _ := splitSafelistNames(names)
|
||||
return entries
|
||||
}
|
||||
|
||||
// safelistSpansFor returns every safelist match in content, tagged with the
|
||||
// pattern name that produced it. Spans may overlap; containment is checked
|
||||
// per-token in scanEntropy.
|
||||
func safelistSpansFor(content string, entries []safelistEntry) []safelistSpan {
|
||||
if len(entries) == 0 {
|
||||
return nil
|
||||
}
|
||||
var spans []safelistSpan
|
||||
for _, e := range entries {
|
||||
for _, loc := range e.re.FindAllStringIndex(content, -1) {
|
||||
spans = append(spans, safelistSpan{start: loc[0], end: loc[1], name: e.name})
|
||||
}
|
||||
}
|
||||
return spans
|
||||
}
|
||||
|
||||
// inAnySpan reports whether [start, end) lies fully inside any safelist span.
|
||||
// Returns the matching pattern name so the skip can be logged for FP-rate
|
||||
// telemetry — the data F-2 gates on.
|
||||
func inAnySpan(spans []safelistSpan, start, end int) (string, bool) {
|
||||
for _, s := range spans {
|
||||
if start >= s.start && end <= s.end {
|
||||
return s.name, true
|
||||
}
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
@@ -0,0 +1,294 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// A real high-entropy token (random base64-ish) used as the "secret"
|
||||
// in mixed-payload tests. Confirmed to score >= 4.5 with the default
|
||||
// alphabet and to be long enough (>=20 chars) to enter scanEntropy.
|
||||
const secretToken = "x9KqLm2pNvBz3RtYwH7Xj4QsDc8Fa6Vu"
|
||||
|
||||
// loweredThreshold sits below typical UUID/hash entropy (UUID v4 ≈ 3.4,
|
||||
// SHA hex ≈ 3.9). The plan flags this regime — lowered threshold or
|
||||
// redact_high_entropy = true — as where FPs bite. F-1 must remove them.
|
||||
const loweredThreshold = 3.0
|
||||
|
||||
func TestSafelist_UUIDIsSkipped(t *testing.T) {
|
||||
s := NewScanner(loweredThreshold, true)
|
||||
s.SetSafelist([]string{"uuid"})
|
||||
|
||||
matches := s.Scan("trace_id=550e8400-e29b-41d4-a716-446655440000 done")
|
||||
for _, m := range matches {
|
||||
if m.Pattern == "high_entropy" {
|
||||
t.Errorf("UUID should not be flagged as high_entropy: %+v", m)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSafelist_SHA256IsSkipped(t *testing.T) {
|
||||
s := NewScanner(4.5, true)
|
||||
s.SetSafelist([]string{"sha_hex"})
|
||||
|
||||
sha256 := "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
|
||||
matches := s.Scan("commit " + sha256)
|
||||
for _, m := range matches {
|
||||
if m.Pattern == "high_entropy" {
|
||||
t.Errorf("SHA-256 should not be flagged as high_entropy: %+v", m)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSafelist_SHA1IsSkipped(t *testing.T) {
|
||||
s := NewScanner(4.5, true)
|
||||
s.SetSafelist([]string{"sha_hex"})
|
||||
|
||||
sha1 := "356a192b7913b04c54574d18c28d46e6395428ab"
|
||||
matches := s.Scan("blob " + sha1)
|
||||
for _, m := range matches {
|
||||
if m.Pattern == "high_entropy" {
|
||||
t.Errorf("SHA-1 should not be flagged as high_entropy: %+v", m)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSafelist_MixedPayload_SecretStillCaught(t *testing.T) {
|
||||
s := NewScanner(loweredThreshold, true)
|
||||
s.SetSafelist([]string{"uuid", "sha_hex"})
|
||||
|
||||
uuid := "550e8400-e29b-41d4-a716-446655440000"
|
||||
content := "id=" + uuid + " secret=" + secretToken
|
||||
|
||||
matches := s.Scan(content)
|
||||
|
||||
var entropyHits []SecretMatch
|
||||
for _, m := range matches {
|
||||
if m.Pattern == "high_entropy" {
|
||||
entropyHits = append(entropyHits, m)
|
||||
}
|
||||
}
|
||||
if len(entropyHits) != 1 {
|
||||
t.Fatalf("want 1 entropy hit (the actual secret), got %d: %+v", len(entropyHits), entropyHits)
|
||||
}
|
||||
// Confirm the hit covers the secret, not the UUID.
|
||||
hit := content[entropyHits[0].Start:entropyHits[0].End]
|
||||
if hit != secretToken {
|
||||
t.Errorf("entropy hit covered %q, want %q", hit, secretToken)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSafelist_EmptyPreservesCurrentBehavior(t *testing.T) {
|
||||
// No safelist configured — under a lowered threshold the UUID trips
|
||||
// entropy. This is the pre-F-1 false positive the safelist removes;
|
||||
// here we lock in that pre-F-1 behaviour is unchanged when no safelist
|
||||
// is supplied.
|
||||
s := NewScanner(loweredThreshold, true) // SetSafelist intentionally not called
|
||||
|
||||
uuid := "550e8400-e29b-41d4-a716-446655440000"
|
||||
matches := s.Scan(uuid)
|
||||
|
||||
var entropyHits int
|
||||
for _, m := range matches {
|
||||
if m.Pattern == "high_entropy" {
|
||||
entropyHits++
|
||||
}
|
||||
}
|
||||
if entropyHits == 0 {
|
||||
t.Error("with no safelist + lowered threshold, UUID should still trigger entropy (pre-F-1 baseline)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSafelist_UnknownNameIgnored(t *testing.T) {
|
||||
s := NewScanner(loweredThreshold, true)
|
||||
// "made_up" is not a known pattern — must be silently dropped, not panic.
|
||||
s.SetSafelist([]string{"uuid", "made_up", "sha_hex"})
|
||||
|
||||
uuid := "550e8400-e29b-41d4-a716-446655440000"
|
||||
matches := s.Scan(uuid)
|
||||
for _, m := range matches {
|
||||
if m.Pattern == "high_entropy" {
|
||||
t.Errorf("uuid should still be skipped despite unknown name in list: %+v", m)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSafelist_URLPathNotFlagged(t *testing.T) {
|
||||
s := NewScanner(4.5, true)
|
||||
s.SetSafelist([]string{"url"})
|
||||
|
||||
// A high-entropy URL path — a real-world false positive shape.
|
||||
url := "https://example.com/" + secretToken
|
||||
matches := s.Scan(url)
|
||||
for _, m := range matches {
|
||||
if m.Pattern == "high_entropy" {
|
||||
hit := url[m.Start:m.End]
|
||||
t.Errorf("URL substring %q should be covered by url safelist", hit)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSafelist_ISO8601Span(t *testing.T) {
|
||||
// ISO-8601 timestamps don't survive entropy tokenization as a single
|
||||
// 20+-char token (':' splits them), so this is mostly a sanity check
|
||||
// that declaring iso8601 doesn't break anything.
|
||||
s := NewScanner(4.5, true)
|
||||
s.SetSafelist([]string{"iso8601"})
|
||||
|
||||
ts := "2026-05-22T10:30:00.123Z"
|
||||
matches := s.Scan(ts)
|
||||
for _, m := range matches {
|
||||
if m.Pattern == "high_entropy" {
|
||||
t.Errorf("ISO-8601 timestamp should not trip entropy: %+v", m)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSafelist_SecretAdjacentToUUIDStillRedacted(t *testing.T) {
|
||||
// Regression guard: a real secret that happens to abut a UUID must
|
||||
// not be swallowed by the UUID's safelist span.
|
||||
s := NewScanner(loweredThreshold, true)
|
||||
s.SetSafelist([]string{"uuid"})
|
||||
|
||||
uuid := "550e8400-e29b-41d4-a716-446655440000"
|
||||
content := uuid + " " + secretToken
|
||||
|
||||
matches := s.Scan(content)
|
||||
var foundSecret bool
|
||||
for _, m := range matches {
|
||||
if m.Pattern == "high_entropy" && content[m.Start:m.End] == secretToken {
|
||||
foundSecret = true
|
||||
}
|
||||
}
|
||||
if !foundSecret {
|
||||
t.Errorf("secret adjacent to UUID was not detected; matches=%+v", matches)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSafelist_KnownPatternNamesMatchPlan(t *testing.T) {
|
||||
// Plan-locked names that the user-facing TOML knob accepts.
|
||||
// Changing these breaks user configs — bump with care.
|
||||
want := []string{"uuid", "sha_hex", "iso8601", "url"}
|
||||
got := defaultSafelistPatterns()
|
||||
if len(got) != len(want) {
|
||||
t.Fatalf("default safelist size = %d, want %d", len(got), len(want))
|
||||
}
|
||||
for _, name := range want {
|
||||
if _, ok := got[name]; !ok {
|
||||
t.Errorf("missing safelist pattern %q (have %v)", name, safelistKeys(got))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func safelistKeys[V any](m map[string]V) []string {
|
||||
out := make([]string, 0, len(m))
|
||||
for k := range m {
|
||||
out = append(out, k)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func TestFirewall_EntropySafelistEndToEnd(t *testing.T) {
|
||||
// End-to-end: FirewallConfig.EntropySafelist must flow through to
|
||||
// the scanner's runtime behavior. A SHA-256 in tool output should
|
||||
// survive an entropy-redacting firewall when sha_hex is safelisted.
|
||||
sha256 := "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
|
||||
content := "commit " + sha256 + " landed"
|
||||
|
||||
withSafelist := NewFirewall(FirewallConfig{
|
||||
ScanToolResults: true,
|
||||
RedactHighEntropy: true,
|
||||
EntropyThreshold: loweredThreshold,
|
||||
EntropySafelist: []string{"sha_hex"},
|
||||
})
|
||||
if got := withSafelist.ScanToolResult(content); !strings.Contains(got, sha256) {
|
||||
t.Errorf("safelisted SHA-256 should pass through, got %q", got)
|
||||
}
|
||||
|
||||
withoutSafelist := NewFirewall(FirewallConfig{
|
||||
ScanToolResults: true,
|
||||
RedactHighEntropy: true,
|
||||
EntropyThreshold: loweredThreshold,
|
||||
})
|
||||
if got := withoutSafelist.ScanToolResult(content); strings.Contains(got, sha256) {
|
||||
t.Errorf("without safelist the SHA-256 should be redacted at threshold %.1f, got %q", loweredThreshold, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFirewall_UnknownSafelistNameWarns(t *testing.T) {
|
||||
// A typo like "uid" instead of "uuid" must surface as a Warn so the
|
||||
// operator notices, rather than silently disabling FP reduction.
|
||||
var buf bytes.Buffer
|
||||
logger := slog.New(slog.NewTextHandler(&buf, &slog.HandlerOptions{Level: slog.LevelWarn}))
|
||||
|
||||
_ = NewFirewall(FirewallConfig{
|
||||
EntropySafelist: []string{"uuid", "uid"}, // "uid" is the typo
|
||||
Logger: logger,
|
||||
})
|
||||
|
||||
logs := buf.String()
|
||||
if !strings.Contains(logs, "unknown entropy safelist name") {
|
||||
t.Errorf("expected warning about unknown name, got logs: %q", logs)
|
||||
}
|
||||
if !strings.Contains(logs, "uid") {
|
||||
t.Errorf("warning should name the unknown entry, got logs: %q", logs)
|
||||
}
|
||||
if strings.Contains(logs, "name=uuid ") || strings.Contains(logs, "name=uuid\n") {
|
||||
t.Errorf("known name 'uuid' should not be warned about, got logs: %q", logs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFirewall_AllKnownSafelistNamesQuiet(t *testing.T) {
|
||||
// No warnings for any of the canonical names — guards against a
|
||||
// future code change that accidentally renames a default pattern.
|
||||
var buf bytes.Buffer
|
||||
logger := slog.New(slog.NewTextHandler(&buf, &slog.HandlerOptions{Level: slog.LevelWarn}))
|
||||
|
||||
_ = NewFirewall(FirewallConfig{
|
||||
EntropySafelist: []string{"uuid", "sha_hex", "iso8601", "url"},
|
||||
Logger: logger,
|
||||
})
|
||||
|
||||
if logs := buf.String(); logs != "" {
|
||||
t.Errorf("known safelist names should not warn, got: %q", logs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSafelist_SkipIsLogged(t *testing.T) {
|
||||
// Per-pattern telemetry is the data F-2's go/no-go gate depends on.
|
||||
// Verify a skip emits a Debug log carrying the pattern name.
|
||||
var buf bytes.Buffer
|
||||
logger := slog.New(slog.NewTextHandler(&buf, &slog.HandlerOptions{Level: slog.LevelDebug}))
|
||||
|
||||
s := NewScanner(loweredThreshold, true)
|
||||
s.SetLogger(logger)
|
||||
s.SetSafelist([]string{"uuid"})
|
||||
|
||||
uuid := "550e8400-e29b-41d4-a716-446655440000"
|
||||
_ = s.Scan(uuid)
|
||||
|
||||
logs := buf.String()
|
||||
if !strings.Contains(logs, "entropy candidate skipped by safelist") {
|
||||
t.Errorf("expected debug log on skip, got: %q", logs)
|
||||
}
|
||||
if !strings.Contains(logs, "pattern=uuid") {
|
||||
t.Errorf("debug log should carry pattern name, got: %q", logs)
|
||||
}
|
||||
}
|
||||
|
||||
// Sanity check the helper that powers other tests: the secret token
|
||||
// we use really is high-entropy and long enough for the scanner.
|
||||
func TestSafelist_SecretTokenIsHighEntropy(t *testing.T) {
|
||||
if len(secretToken) < 20 {
|
||||
t.Fatalf("secretToken too short: %d", len(secretToken))
|
||||
}
|
||||
if e := shannonEntropy(secretToken); e < 4.5 {
|
||||
t.Fatalf("secretToken entropy = %.2f, want >= 4.5 (test corpus drift)", e)
|
||||
}
|
||||
// And confirm it's stripped of any characters that would split the token.
|
||||
if strings.ContainsAny(secretToken, " .:") {
|
||||
t.Fatalf("secretToken contains a tokenizer split char")
|
||||
}
|
||||
}
|
||||
@@ -44,6 +44,10 @@ func shouldStrip(r rune) bool {
|
||||
if unicode.Is(unicode.Co, r) {
|
||||
return true
|
||||
}
|
||||
// Strip unassigned characters (Cn) — unregistered characters
|
||||
if unicode.Is(unicode.Cn, r) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Strip specific dangerous ranges
|
||||
switch {
|
||||
|
||||
@@ -2,6 +2,7 @@ package security
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"math"
|
||||
"regexp"
|
||||
)
|
||||
@@ -35,6 +36,8 @@ type Scanner struct {
|
||||
patterns []SecretPattern
|
||||
entropyThreshold float64
|
||||
redactHighEntropy bool
|
||||
safelist []safelistEntry
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
func NewScanner(entropyThreshold float64, redactHighEntropy bool) *Scanner {
|
||||
@@ -48,6 +51,30 @@ func NewScanner(entropyThreshold float64, redactHighEntropy bool) *Scanner {
|
||||
}
|
||||
}
|
||||
|
||||
// SetSafelist configures the format-aware entropy pre-extractor (Phase F-1).
|
||||
// Names are looked up in defaultSafelistPatterns; unknown names are silently
|
||||
// dropped (callers that want to surface typos should use splitSafelistNames
|
||||
// directly — NewFirewall does this). Calling with an empty or nil slice
|
||||
// clears the safelist and restores pre-F-1 behavior (every long token is
|
||||
// entropy-scored).
|
||||
func (s *Scanner) SetSafelist(names []string) {
|
||||
s.safelist = buildSafelist(names)
|
||||
}
|
||||
|
||||
// SetLogger swaps the logger used for safelist-skip telemetry. The Scanner
|
||||
// otherwise logs nothing; if unset it falls back to slog.Default() so tests
|
||||
// stay quiet.
|
||||
func (s *Scanner) SetLogger(logger *slog.Logger) {
|
||||
s.logger = logger
|
||||
}
|
||||
|
||||
func (s *Scanner) log() *slog.Logger {
|
||||
if s.logger != nil {
|
||||
return s.logger
|
||||
}
|
||||
return slog.Default()
|
||||
}
|
||||
|
||||
// AddPattern adds a custom detection pattern.
|
||||
func (s *Scanner) AddPattern(name, regex string, action ScanAction) error {
|
||||
re, err := regexp.Compile(regex)
|
||||
@@ -98,12 +125,23 @@ func (s *Scanner) HasSecrets(content string) bool {
|
||||
// scanEntropy detects high-entropy strings that might be secrets.
|
||||
func (s *Scanner) scanEntropy(content string) []SecretMatch {
|
||||
var matches []SecretMatch
|
||||
safeSpans := safelistSpansFor(content, s.safelist)
|
||||
// Check each word-like token that's long enough to be a secret
|
||||
words := entropyTokenize(content)
|
||||
for _, w := range words {
|
||||
if len(w.text) < 20 { // secrets are typically 20+ chars
|
||||
continue
|
||||
}
|
||||
if name, ok := inAnySpan(safeSpans, w.start, w.start+len(w.text)); ok {
|
||||
// Per-pattern telemetry for FP-rate measurement. Token bytes
|
||||
// stay out of the log — only length + the safelist name that
|
||||
// covered it. F-2's go/no-go hinges on this data.
|
||||
s.log().Debug("entropy candidate skipped by safelist",
|
||||
"pattern", name,
|
||||
"token_len", len(w.text),
|
||||
)
|
||||
continue
|
||||
}
|
||||
entropy := shannonEntropy(w.text)
|
||||
if entropy >= s.entropyThreshold {
|
||||
action := ActionWarn
|
||||
|
||||
@@ -360,6 +360,15 @@ func TestSanitizeUnicode_PreservesEmoji(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeUnicode_StripsUnassigned(t *testing.T) {
|
||||
// Unassigned character (Cn) e.g., U+0378
|
||||
unassigned := "Hello\u0378world"
|
||||
result := SanitizeUnicode(unassigned)
|
||||
if result != "Helloworld" {
|
||||
t.Errorf("should strip unassigned characters, got %q", result)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Incognito ---
|
||||
|
||||
func TestIncognito_DefaultOff(t *testing.T) {
|
||||
|
||||
@@ -201,6 +201,13 @@ func (s *Local) SetModel(model string) {
|
||||
s.model = model
|
||||
}
|
||||
|
||||
// SetProvider updates the displayed provider name.
|
||||
func (s *Local) SetProvider(provider string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.provider = provider
|
||||
}
|
||||
|
||||
func (s *Local) Status() Status {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
@@ -2,8 +2,11 @@ package bash
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"mvdan.cc/sh/v3/syntax"
|
||||
)
|
||||
|
||||
// SecurityCheck identifies a specific validation check.
|
||||
@@ -251,7 +254,7 @@ func checkStandaloneSemicolon(cmd string) *SecurityViolation {
|
||||
}
|
||||
|
||||
// checkSensitiveRedirection blocks output redirection to sensitive paths.
|
||||
// Detects: >, >>, fd redirects (2>), and no-space variants (>/etc/passwd).
|
||||
// Uses a POSIX shell parser to reliably identify all output redirections.
|
||||
func checkSensitiveRedirection(cmd string) *SecurityViolation {
|
||||
sensitiveTargets := []string{
|
||||
"/etc/passwd", "/etc/shadow", "/etc/sudoers",
|
||||
@@ -260,22 +263,90 @@ func checkSensitiveRedirection(cmd string) *SecurityViolation {
|
||||
".env",
|
||||
}
|
||||
|
||||
for _, target := range sensitiveTargets {
|
||||
// Match any form: >, >>, 2>, 2>>, &> followed by optional whitespace then target
|
||||
idx := strings.Index(cmd, target)
|
||||
if idx <= 0 {
|
||||
continue
|
||||
}
|
||||
// Check what precedes the target (skip whitespace backwards)
|
||||
pre := strings.TrimRight(cmd[:idx], " \t")
|
||||
if len(pre) > 0 && (pre[len(pre)-1] == '>' || strings.HasSuffix(pre, ">>")) {
|
||||
return &SecurityViolation{
|
||||
Check: CheckRedirection,
|
||||
Message: fmt.Sprintf("redirection to sensitive path: %s", target),
|
||||
}
|
||||
reader := strings.NewReader(cmd)
|
||||
parser := syntax.NewParser()
|
||||
file, err := parser.Parse(reader, "")
|
||||
if err != nil {
|
||||
return &SecurityViolation{
|
||||
Check: CheckIncomplete,
|
||||
Message: fmt.Sprintf("invalid command syntax: %v", err),
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
||||
var violation *SecurityViolation
|
||||
printer := syntax.NewPrinter()
|
||||
|
||||
syntax.Walk(file, func(node syntax.Node) bool {
|
||||
if violation != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if stmt, ok := node.(*syntax.Stmt); ok {
|
||||
for _, redir := range stmt.Redirs {
|
||||
op := redir.Op
|
||||
// Check all redirection operators that write or modify files:
|
||||
// Skip read-only/heredoc operators: RdrIn (<), DplIn (<&), Hdoc (<<), DashHdoc (<<-), WordHdoc (<<<)
|
||||
if op == syntax.RdrIn || op == syntax.DplIn || op == syntax.Hdoc || op == syntax.DashHdoc || op == syntax.WordHdoc {
|
||||
continue
|
||||
}
|
||||
|
||||
if redir.Word == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
var b strings.Builder
|
||||
_ = printer.Print(&b, redir.Word)
|
||||
targetPath := b.String()
|
||||
|
||||
// Strip single/double quotes around the target word if present
|
||||
targetPath = strings.TrimSpace(targetPath)
|
||||
if (strings.HasPrefix(targetPath, "\"") && strings.HasSuffix(targetPath, "\"")) ||
|
||||
(strings.HasPrefix(targetPath, "'") && strings.HasSuffix(targetPath, "'")) {
|
||||
if len(targetPath) >= 2 {
|
||||
targetPath = targetPath[1 : len(targetPath)-1]
|
||||
}
|
||||
}
|
||||
|
||||
cleaned := filepath.Clean(targetPath)
|
||||
|
||||
for _, target := range sensitiveTargets {
|
||||
if strings.HasPrefix(target, "/") {
|
||||
// Absolute targets: exact match
|
||||
if cleaned == target {
|
||||
violation = &SecurityViolation{
|
||||
Check: CheckRedirection,
|
||||
Message: fmt.Sprintf("redirection to sensitive path: %s", target),
|
||||
}
|
||||
return false
|
||||
}
|
||||
} else {
|
||||
// Relative targets: suffix/base match
|
||||
if target == ".env" || target == ".bashrc" || target == ".zshrc" || target == ".profile" || target == ".bash_profile" {
|
||||
if filepath.Base(cleaned) == target {
|
||||
violation = &SecurityViolation{
|
||||
Check: CheckRedirection,
|
||||
Message: fmt.Sprintf("redirection to sensitive path: %s", target),
|
||||
}
|
||||
return false
|
||||
}
|
||||
} else {
|
||||
// Relative paths with directory components (e.g. .ssh/config)
|
||||
if strings.HasSuffix(cleaned, "/"+target) || cleaned == target {
|
||||
violation = &SecurityViolation{
|
||||
Check: CheckRedirection,
|
||||
Message: fmt.Sprintf("redirection to sensitive path: %s", target),
|
||||
}
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
return violation
|
||||
}
|
||||
|
||||
// checkJQInjection detects jq commands with embedded shell metacharacters in the filter.
|
||||
|
||||
@@ -229,6 +229,12 @@ func TestCheckSensitiveRedirection_Blocked(t *testing.T) {
|
||||
"echo evil > /etc/passwd",
|
||||
"echo evil>>/etc/shadow",
|
||||
"echo evil >> /etc/shadow",
|
||||
"echo evil >\\\n.env",
|
||||
"echo evil > \".env\"",
|
||||
"echo evil > '.env'",
|
||||
"echo evil > ./.env",
|
||||
"echo evil > sub/.env",
|
||||
"echo evil > /home/user/workspace/.env",
|
||||
}
|
||||
for _, cmd := range blocked {
|
||||
t.Run(cmd, func(t *testing.T) {
|
||||
@@ -240,6 +246,17 @@ func TestCheckSensitiveRedirection_Blocked(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckSensitiveRedirection_SyntaxError(t *testing.T) {
|
||||
v := ValidateCommand("echo hello > \"unclosed quote")
|
||||
if v == nil {
|
||||
t.Error("expected violation for invalid syntax")
|
||||
return
|
||||
}
|
||||
if v.Check != CheckIncomplete {
|
||||
t.Errorf("expected CheckIncomplete, got %d", v.Check)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckProcessSubstitution_Allowed(t *testing.T) {
|
||||
// Process substitution <() and >() should NOT be blocked
|
||||
allowed := []string{
|
||||
|
||||
@@ -79,7 +79,7 @@ func (t *EditTool) Execute(_ context.Context, args json.RawMessage) (tool.Result
|
||||
|
||||
path := a.Path
|
||||
if t.guard != nil {
|
||||
resolved, err := t.guard.ResolveRead(path)
|
||||
resolved, err := t.guard.ResolveWrite(path)
|
||||
if err != nil {
|
||||
return tool.Result{Output: fmt.Sprintf("Error: %v", err)}, nil
|
||||
}
|
||||
|
||||
+845
-175
File diff suppressed because it is too large
Load Diff
@@ -18,6 +18,7 @@ var builtinCommands = []cmdEntry{
|
||||
{"/clear", "clear conversation history"},
|
||||
{"/compact", "summarize and compact conversation context"},
|
||||
{"/config", "open settings panel"},
|
||||
{"/copy", "copy the latest assistant response to the clipboard"},
|
||||
{"/exit", "exit gnoma"},
|
||||
{"/help", "show available commands and shortcuts"},
|
||||
{"/incognito", "toggle incognito mode (no persistence, local-only routing)"},
|
||||
@@ -34,8 +35,10 @@ var builtinCommands = []cmdEntry{
|
||||
{"/replay", "replay last assistant response"},
|
||||
{"/resume", "browse and resume a saved session"},
|
||||
{"/shell", "open interactive shell"},
|
||||
{"/theme", "list themes or set active theme"},
|
||||
{"/skills", "list available skills"},
|
||||
{"/usage", "show token usage for this session"},
|
||||
{"/vim", "toggle Vim keybindings in the input composer"},
|
||||
}
|
||||
|
||||
// permissionModes lists valid modes for /permission completion.
|
||||
@@ -81,14 +84,14 @@ func matchSuggestions(input string, commands []cmdEntry) []cmdEntry {
|
||||
|
||||
// matchCompletion returns the unique ghost-text completion, or "".
|
||||
// Used for Tab acceptance of a single unambiguous match. profileNames
|
||||
// is the dynamic completion source for `/profile <name>` — pass nil
|
||||
// when none are known.
|
||||
func matchCompletion(input string, commands []cmdEntry, profileNames []string) string {
|
||||
// is the dynamic completion source for `/profile <name>`, and providerNames
|
||||
// is for `/provider <name>` — pass nil when none are known.
|
||||
func matchCompletion(input string, commands []cmdEntry, profileNames []string, providerNames []string) string {
|
||||
if !strings.HasPrefix(input, "/") || len(input) < 2 {
|
||||
return ""
|
||||
}
|
||||
if strings.Contains(input, " ") {
|
||||
return matchArgCompletion(input, profileNames)
|
||||
return matchArgCompletion(input, profileNames, providerNames)
|
||||
}
|
||||
suggestions := matchSuggestions(input, commands)
|
||||
if len(suggestions) == 1 && suggestions[0].name != input {
|
||||
@@ -126,9 +129,9 @@ func fuzzyMatchCommands(query string, commands []cmdEntry) []cmdEntry {
|
||||
}
|
||||
|
||||
// matchArgCompletion handles second-level completion for commands with args.
|
||||
// profileNames is the dynamic source for `/profile <name>`; pass nil when
|
||||
// profile mode isn't engaged.
|
||||
func matchArgCompletion(input string, profileNames []string) string {
|
||||
// profileNames is the dynamic source for `/profile <name>`, and providerNames
|
||||
// is for `/provider <name>`; pass nil when not available.
|
||||
func matchArgCompletion(input string, profileNames []string, providerNames []string) string {
|
||||
parts := strings.SplitN(input, " ", 2)
|
||||
if len(parts) != 2 {
|
||||
return ""
|
||||
@@ -157,6 +160,16 @@ func matchArgCompletion(input string, profileNames []string) string {
|
||||
return cmd + " " + name
|
||||
}
|
||||
}
|
||||
case "/provider":
|
||||
if arg == "" || len(providerNames) == 0 {
|
||||
return ""
|
||||
}
|
||||
lower := strings.ToLower(arg)
|
||||
for _, name := range providerNames {
|
||||
if strings.HasPrefix(strings.ToLower(name), lower) && name != arg {
|
||||
return cmd + " " + name
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -34,7 +34,7 @@ func TestMatchCompletion(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
got := matchCompletion(tt.input, cmds, nil)
|
||||
got := matchCompletion(tt.input, cmds, nil, nil)
|
||||
if got != tt.want {
|
||||
t.Errorf("matchCompletion(%q) = %q, want %q", tt.input, got, tt.want)
|
||||
}
|
||||
@@ -113,7 +113,7 @@ func TestMatchArgCompletion(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
got := matchArgCompletion(tt.input, nil)
|
||||
got := matchArgCompletion(tt.input, nil, nil)
|
||||
if got != tt.want {
|
||||
t.Errorf("matchArgCompletion(%q) = %q, want %q", tt.input, got, tt.want)
|
||||
}
|
||||
@@ -134,7 +134,7 @@ func TestMatchArgCompletion_Profile(t *testing.T) {
|
||||
{"/profile ", ""}, // empty arg — wait for input
|
||||
}
|
||||
for _, tt := range tests {
|
||||
got := matchArgCompletion(tt.input, profiles)
|
||||
got := matchArgCompletion(tt.input, profiles, nil)
|
||||
if got != tt.want {
|
||||
t.Errorf("matchArgCompletion(%q, profiles) = %q, want %q", tt.input, got, tt.want)
|
||||
}
|
||||
@@ -145,7 +145,7 @@ func TestMatchCompletion_DispatchesToProfileArgCompletion(t *testing.T) {
|
||||
// End-to-end: matchCompletion sees "/profile w", forwards to
|
||||
// matchArgCompletion with profileNames, gets back "/profile work".
|
||||
cmds := []cmdEntry{{"/profile", "profiles"}}
|
||||
got := matchCompletion("/profile w", cmds, []string{"work", "private"})
|
||||
got := matchCompletion("/profile w", cmds, []string{"work", "private"}, nil)
|
||||
if got != "/profile work" {
|
||||
t.Errorf("matchCompletion(/profile w) = %q, want /profile work", got)
|
||||
}
|
||||
@@ -154,8 +154,37 @@ func TestMatchCompletion_DispatchesToProfileArgCompletion(t *testing.T) {
|
||||
func TestMatchArgCompletion_ProfileNoNamesAvailable(t *testing.T) {
|
||||
// When profile mode isn't engaged, profileNames is nil/empty and the
|
||||
// completer must not try to suggest anything.
|
||||
got := matchArgCompletion("/profile w", nil)
|
||||
got := matchArgCompletion("/profile w", nil, nil)
|
||||
if got != "" {
|
||||
t.Errorf("matchArgCompletion(profile, nil) = %q, want empty", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMatchArgCompletion_Provider(t *testing.T) {
|
||||
providers := []string{"anthropic", "openai", "google"}
|
||||
tests := []struct {
|
||||
input string
|
||||
want string
|
||||
}{
|
||||
{"/provider a", "/provider anthropic"},
|
||||
{"/provider o", "/provider openai"},
|
||||
{"/provider openai", ""}, // already complete
|
||||
{"/provider g", "/provider google"},
|
||||
{"/provider z", ""}, // no match
|
||||
{"/provider ", ""}, // empty arg — wait for input
|
||||
}
|
||||
for _, tt := range tests {
|
||||
got := matchArgCompletion(tt.input, nil, providers)
|
||||
if got != tt.want {
|
||||
t.Errorf("matchArgCompletion(%q, providers) = %q, want %q", tt.input, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestMatchCompletion_DispatchesToProviderArgCompletion(t *testing.T) {
|
||||
cmds := []cmdEntry{{"/provider", "providers"}}
|
||||
got := matchCompletion("/provider a", cmds, nil, []string{"anthropic", "openai"})
|
||||
if got != "/provider anthropic" {
|
||||
t.Errorf("matchCompletion(/provider a) = %q, want /provider anthropic", got)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,125 @@
|
||||
package tui
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// stageHistoryDir redirects GlobalConfigDir() to t.TempDir() by overriding
|
||||
// XDG_CONFIG_HOME. Returns the resolved ~/.config/gnoma path.
|
||||
func stageHistoryDir(t *testing.T) string {
|
||||
t.Helper()
|
||||
root := t.TempDir()
|
||||
t.Setenv("XDG_CONFIG_HOME", root)
|
||||
return filepath.Join(root, "gnoma")
|
||||
}
|
||||
|
||||
func TestSavePromptHistory_WritesFileWithRestrictivePerms(t *testing.T) {
|
||||
dir := stageHistoryDir(t)
|
||||
|
||||
savePromptHistory("first prompt")
|
||||
|
||||
path := filepath.Join(dir, "history.txt")
|
||||
info, err := os.Stat(path)
|
||||
if err != nil {
|
||||
t.Fatalf("history file not created: %v", err)
|
||||
}
|
||||
if mode := info.Mode().Perm(); mode != 0o600 {
|
||||
t.Errorf("history file mode = %o, want 0600", mode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSavePromptHistory_RewritesExistingFileTo0600(t *testing.T) {
|
||||
dir := stageHistoryDir(t)
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
path := filepath.Join(dir, "history.txt")
|
||||
if err := os.WriteFile(path, []byte("old entry\n"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
savePromptHistory("new entry")
|
||||
|
||||
info, err := os.Stat(path)
|
||||
if err != nil {
|
||||
t.Fatalf("stat failed: %v", err)
|
||||
}
|
||||
if mode := info.Mode().Perm(); mode != 0o600 {
|
||||
t.Errorf("history file mode = %o, want 0600 after rewrite", mode)
|
||||
}
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !strings.Contains(string(data), "old entry") {
|
||||
t.Error("rewrite dropped previously stored entry")
|
||||
}
|
||||
if !strings.Contains(string(data), "new entry") {
|
||||
t.Error("rewrite missing newly appended entry")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSavePromptHistory_TruncatesToLast500Entries(t *testing.T) {
|
||||
dir := stageHistoryDir(t)
|
||||
|
||||
// Save 600 entries.
|
||||
for i := 0; i < 600; i++ {
|
||||
savePromptHistory(fmt.Sprintf("entry-%d", i))
|
||||
}
|
||||
|
||||
// On-disk file must also be capped (not just the loaded view).
|
||||
data, err := os.ReadFile(filepath.Join(dir, "history.txt"))
|
||||
if err != nil {
|
||||
t.Fatalf("read failed: %v", err)
|
||||
}
|
||||
onDiskLines := strings.Count(strings.TrimRight(string(data), "\n"), "\n") + 1
|
||||
if onDiskLines > 500 {
|
||||
t.Errorf("on-disk history has %d lines, want ≤500", onDiskLines)
|
||||
}
|
||||
|
||||
got := loadPromptHistory()
|
||||
if len(got) > 500 {
|
||||
t.Errorf("history length = %d, want ≤500 after 600 writes", len(got))
|
||||
}
|
||||
if len(got) == 0 {
|
||||
t.Fatal("history unexpectedly empty")
|
||||
}
|
||||
// Most recent entry should be the last one written.
|
||||
if got[len(got)-1] != "entry-599" {
|
||||
t.Errorf("last entry = %q, want entry-599", got[len(got)-1])
|
||||
}
|
||||
// Oldest retained entry should be entry-100 (600-500).
|
||||
if got[0] != "entry-100" {
|
||||
t.Errorf("first entry = %q, want entry-100", got[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestSavePromptHistory_IgnoresBlankInput(t *testing.T) {
|
||||
dir := stageHistoryDir(t)
|
||||
|
||||
savePromptHistory("")
|
||||
savePromptHistory(" \n\t ")
|
||||
|
||||
path := filepath.Join(dir, "history.txt")
|
||||
if _, err := os.Stat(path); err == nil {
|
||||
t.Error("blank input should not create history file")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSavePromptHistory_NewlinesFlattenedToSpace(t *testing.T) {
|
||||
stageHistoryDir(t)
|
||||
|
||||
savePromptHistory("line one\nline two")
|
||||
|
||||
got := loadPromptHistory()
|
||||
if len(got) != 1 {
|
||||
t.Fatalf("history length = %d, want 1", len(got))
|
||||
}
|
||||
if got[0] != "line one line two" {
|
||||
t.Errorf("got %q, want 'line one line two'", got[0])
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,102 @@
|
||||
package tui
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// stagePastedImageCache redirects os.UserCacheDir() to a temp dir by
|
||||
// overriding XDG_CACHE_HOME. Returns the resolved cache root.
|
||||
func stagePastedImageCache(t *testing.T) string {
|
||||
t.Helper()
|
||||
root := t.TempDir()
|
||||
t.Setenv("XDG_CACHE_HOME", root)
|
||||
return filepath.Join(root, "gnoma", "pasted-images")
|
||||
}
|
||||
|
||||
func TestStorePastedImage_WritesToUserCacheWithRestrictivePerms(t *testing.T) {
|
||||
cacheDir := stagePastedImageCache(t)
|
||||
|
||||
path, err := storePastedImage([]byte("png-bytes"), ".png")
|
||||
if err != nil {
|
||||
t.Fatalf("storePastedImage: %v", err)
|
||||
}
|
||||
if filepath.Dir(path) != cacheDir {
|
||||
t.Errorf("path dir = %q, want %q", filepath.Dir(path), cacheDir)
|
||||
}
|
||||
if filepath.Ext(path) != ".png" {
|
||||
t.Errorf("path ext = %q, want .png", filepath.Ext(path))
|
||||
}
|
||||
info, err := os.Stat(path)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if mode := info.Mode().Perm(); mode != 0o600 {
|
||||
t.Errorf("file mode = %o, want 0600", mode)
|
||||
}
|
||||
if dirInfo, _ := os.Stat(cacheDir); dirInfo != nil {
|
||||
if mode := dirInfo.Mode().Perm(); mode != 0o700 {
|
||||
t.Errorf("dir mode = %o, want 0700", mode)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestStorePastedImage_DoesNotPolluteProjectRoot(t *testing.T) {
|
||||
// Make sure the cache dir lookup doesn't fall back to cwd / the
|
||||
// project root for any reason. Stage XDG_CACHE_HOME and verify
|
||||
// the returned path is under it, not under cwd.
|
||||
cacheRoot := t.TempDir()
|
||||
t.Setenv("XDG_CACHE_HOME", cacheRoot)
|
||||
|
||||
cwd, _ := os.Getwd()
|
||||
path, err := storePastedImage([]byte("x"), ".png")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
rel, err := filepath.Rel(cwd, path)
|
||||
if err == nil && !filepath.IsAbs(rel) && rel[0] != '.' {
|
||||
// path is inside cwd — that would mean we polluted the workdir
|
||||
t.Errorf("storePastedImage wrote under cwd at %q", path)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPruneStalePastedImages_RemovesOldKeepsFresh(t *testing.T) {
|
||||
cacheDir := stagePastedImageCache(t)
|
||||
|
||||
// Manually create one stale + one fresh file (mtime via os.Chtimes).
|
||||
stale := filepath.Join(cacheDir, "pasted_image_stale.png")
|
||||
fresh := filepath.Join(cacheDir, "pasted_image_fresh.png")
|
||||
if err := os.MkdirAll(cacheDir, 0o700); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(stale, []byte("old"), 0o600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(fresh, []byte("new"), 0o600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
old := time.Now().Add(-pastedImageStaleAfter - time.Minute)
|
||||
if err := os.Chtimes(stale, old, old); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
pruneStalePastedImages(cacheDir)
|
||||
|
||||
if _, err := os.Stat(stale); !os.IsNotExist(err) {
|
||||
t.Errorf("stale file should be pruned, stat err = %v", err)
|
||||
}
|
||||
if _, err := os.Stat(fresh); err != nil {
|
||||
t.Errorf("fresh file should survive, stat err = %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPruneStalePastedImages_MissingDirIsNoOp(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Errorf("prune panicked on missing dir: %v", r)
|
||||
}
|
||||
}()
|
||||
pruneStalePastedImages(filepath.Join(t.TempDir(), "does", "not", "exist"))
|
||||
}
|
||||
@@ -0,0 +1,82 @@
|
||||
package tui
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestExpandPlaceholders_BracketFormExpandsToStoredText(t *testing.T) {
|
||||
m := Model{
|
||||
pastedTexts: map[string]string{"#p1": "hello world"},
|
||||
}
|
||||
got := m.expandPlaceholders("see [Pasted text #p1 +0 lines] end")
|
||||
want := "see hello world end"
|
||||
if got != want {
|
||||
t.Errorf("got %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpandPlaceholders_RawFormExpandsToStoredText(t *testing.T) {
|
||||
m := Model{
|
||||
pastedTexts: map[string]string{"#p1": "hello"},
|
||||
}
|
||||
got := m.expandPlaceholders("ref #p1 here")
|
||||
want := "ref hello here"
|
||||
if got != want {
|
||||
t.Errorf("got %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpandPlaceholders_UnknownIDsAreLeftAlone(t *testing.T) {
|
||||
m := Model{
|
||||
pastedTexts: map[string]string{"#p1": "hello"},
|
||||
}
|
||||
got := m.expandPlaceholders("ref #p9 here")
|
||||
if got != "ref #p9 here" {
|
||||
t.Errorf("unknown id should be left intact, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
// Regression: the bug was that after the bracket form was inlined, a second
|
||||
// pass scanned the resulting string for raw `#p\d+`. If the pasted content
|
||||
// itself contained `#p2`, that token was silently corrupted into whatever
|
||||
// `pastedTexts["#p2"]` mapped to (or stripped if absent).
|
||||
func TestExpandPlaceholders_PastedContentContainingPlaceholderSyntaxSurvives(t *testing.T) {
|
||||
m := Model{
|
||||
pastedTexts: map[string]string{
|
||||
"#p1": "look at #p2 in this snippet",
|
||||
"#p2": "SHOULD_NOT_APPEAR",
|
||||
},
|
||||
}
|
||||
got := m.expandPlaceholders("here: [Pasted text #p1 +0 lines]")
|
||||
want := "here: look at #p2 in this snippet"
|
||||
if got != want {
|
||||
t.Errorf("pasted content was re-expanded:\n got %q\n want %q", got, want)
|
||||
}
|
||||
if strings.Contains(got, "SHOULD_NOT_APPEAR") {
|
||||
t.Error("nested #p2 inside pasted content was wrongly expanded")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpandPlaceholders_ImageBracketFormExpandsToPath(t *testing.T) {
|
||||
m := Model{
|
||||
pastedImages: map[string]string{"#img1": "/tmp/x.png"},
|
||||
}
|
||||
got := m.expandPlaceholders("see [Pasted image #img1] end")
|
||||
want := "see [Image: /tmp/x.png] end"
|
||||
if got != want {
|
||||
t.Errorf("got %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpandPlaceholders_MultiplePlaceholdersInOneInput(t *testing.T) {
|
||||
m := Model{
|
||||
pastedTexts: map[string]string{"#p1": "AAA", "#p2": "BBB"},
|
||||
pastedImages: map[string]string{"#img1": "/tmp/x.png"},
|
||||
}
|
||||
got := m.expandPlaceholders("[Pasted text #p1 +0 lines] then #p2 then [Pasted image #img1]")
|
||||
want := "AAA then BBB then [Image: /tmp/x.png]"
|
||||
if got != want {
|
||||
t.Errorf("got %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,326 @@
|
||||
package tui
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"somegit.dev/Owlibou/gnoma/internal/engine"
|
||||
"somegit.dev/Owlibou/gnoma/internal/provider"
|
||||
"somegit.dev/Owlibou/gnoma/internal/router"
|
||||
"somegit.dev/Owlibou/gnoma/internal/security"
|
||||
"somegit.dev/Owlibou/gnoma/internal/session"
|
||||
"somegit.dev/Owlibou/gnoma/internal/stream"
|
||||
"somegit.dev/Owlibou/gnoma/internal/tool"
|
||||
)
|
||||
|
||||
type mockProvider struct {
|
||||
name string
|
||||
defaultModel string
|
||||
}
|
||||
|
||||
func (m *mockProvider) Stream(ctx context.Context, req provider.Request) (stream.Stream, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockProvider) Name() string {
|
||||
return m.name
|
||||
}
|
||||
|
||||
func (m *mockProvider) Models(ctx context.Context) ([]provider.ModelInfo, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockProvider) DefaultModel() string {
|
||||
return m.defaultModel
|
||||
}
|
||||
|
||||
func newTestRouterAndEngine() (*router.Router, *engine.Engine, router.SecureProvider, router.SecureProvider) {
|
||||
rtr := router.New(router.Config{})
|
||||
p1 := security.WrapProvider(&mockProvider{name: "anthropic", defaultModel: "claude-3-5-sonnet"}, nil)
|
||||
p2 := security.WrapProvider(&mockProvider{name: "openai", defaultModel: "gpt-4o"}, nil)
|
||||
|
||||
rtr.RegisterArm(&router.Arm{
|
||||
ID: router.NewArmID("anthropic", "claude-3-5-sonnet"),
|
||||
Provider: p1,
|
||||
ModelName: "claude-3-5-sonnet",
|
||||
Capabilities: provider.Capabilities{ToolUse: true},
|
||||
})
|
||||
rtr.RegisterArm(&router.Arm{
|
||||
ID: router.NewArmID("openai", "gpt-4o"),
|
||||
Provider: p2,
|
||||
ModelName: "gpt-4o",
|
||||
Capabilities: provider.Capabilities{ToolUse: true},
|
||||
})
|
||||
rtr.RegisterArm(&router.Arm{
|
||||
ID: router.NewArmID("openai", "gpt-3.5-turbo"),
|
||||
Provider: p2,
|
||||
ModelName: "gpt-3.5-turbo",
|
||||
Capabilities: provider.Capabilities{ToolUse: true},
|
||||
})
|
||||
|
||||
eng, err := engine.New(engine.Config{
|
||||
Provider: p1,
|
||||
Model: "claude-3-5-sonnet",
|
||||
Tools: tool.NewRegistry(),
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return rtr, eng, p1, p2
|
||||
}
|
||||
|
||||
func TestGetAvailableProviders(t *testing.T) {
|
||||
rtr, _, _, _ := newTestRouterAndEngine()
|
||||
m := Model{
|
||||
config: Config{
|
||||
Router: rtr,
|
||||
},
|
||||
}
|
||||
|
||||
provs := m.getAvailableProviders()
|
||||
if len(provs) != 2 {
|
||||
t.Fatalf("expected 2 providers, got %d", len(provs))
|
||||
}
|
||||
if provs[0] != "anthropic" || provs[1] != "openai" {
|
||||
t.Errorf("expected [anthropic, openai], got %v", provs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindBestArmForProvider(t *testing.T) {
|
||||
rtr, _, _, _ := newTestRouterAndEngine()
|
||||
m := Model{
|
||||
config: Config{
|
||||
Router: rtr,
|
||||
},
|
||||
}
|
||||
|
||||
// Should match the default model
|
||||
arm1 := m.findBestArmForProvider("openai")
|
||||
if arm1 == nil {
|
||||
t.Fatal("expected arm for openai")
|
||||
}
|
||||
if arm1.ModelName != "gpt-4o" {
|
||||
t.Errorf("expected gpt-4o, got %s", arm1.ModelName)
|
||||
}
|
||||
|
||||
// Should fallback to first arm if default model not found
|
||||
rtr.RegisterArm(&router.Arm{
|
||||
ID: router.NewArmID("unknown", "weird-model"),
|
||||
Provider: security.WrapProvider(&mockProvider{name: "unknown", defaultModel: "missing"}, nil),
|
||||
ModelName: "weird-model",
|
||||
})
|
||||
arm2 := m.findBestArmForProvider("unknown")
|
||||
if arm2 == nil {
|
||||
t.Fatal("expected arm for unknown")
|
||||
}
|
||||
if arm2.ModelName != "weird-model" {
|
||||
t.Errorf("expected weird-model, got %s", arm2.ModelName)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCloseAllPickersResetsProvider(t *testing.T) {
|
||||
m := Model{providerPickerOpen: true}
|
||||
m = m.closeAllPickers()
|
||||
if m.providerPickerOpen {
|
||||
t.Error("providerPickerOpen should be false after closeAllPickers")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetPickerItemCount_Provider(t *testing.T) {
|
||||
rtr, _, _, _ := newTestRouterAndEngine()
|
||||
m := Model{
|
||||
providerPickerOpen: true,
|
||||
config: Config{
|
||||
Router: rtr,
|
||||
},
|
||||
}
|
||||
count := m.getPickerItemCount()
|
||||
if count != 2 {
|
||||
t.Errorf("expected picker item count 2, got %d", count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleProviderCommand_ArgsEmptyOpensPicker(t *testing.T) {
|
||||
rtr, eng, _, _ := newTestRouterAndEngine()
|
||||
sess := session.NewLocal(session.LocalConfig{
|
||||
Engine: eng,
|
||||
Provider: "anthropic",
|
||||
Model: "claude-3-5-sonnet",
|
||||
})
|
||||
m := Model{
|
||||
session: sess,
|
||||
config: Config{
|
||||
Router: rtr,
|
||||
Engine: eng,
|
||||
},
|
||||
}
|
||||
|
||||
res, err := m.handleCommand("/provider")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
newM, ok := res.(Model)
|
||||
if !ok {
|
||||
t.Fatalf("expected Model type, got %T", res)
|
||||
}
|
||||
if !newM.providerPickerOpen {
|
||||
t.Error("expected provider picker to be open")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleProviderCommand_ArgsNotEmptySwitchesProvider(t *testing.T) {
|
||||
rtr, eng, _, _ := newTestRouterAndEngine()
|
||||
sess := session.NewLocal(session.LocalConfig{
|
||||
Engine: eng,
|
||||
Provider: "anthropic",
|
||||
Model: "claude-3-5-sonnet",
|
||||
})
|
||||
m := Model{
|
||||
session: sess,
|
||||
config: Config{
|
||||
Router: rtr,
|
||||
Engine: eng,
|
||||
},
|
||||
}
|
||||
|
||||
res, err := m.handleCommand("/provider openai")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
newM, ok := res.(Model)
|
||||
if !ok {
|
||||
t.Fatalf("expected Model type, got %T", res)
|
||||
}
|
||||
if newM.providerPickerOpen {
|
||||
t.Error("expected provider picker to be closed")
|
||||
}
|
||||
|
||||
status := newM.session.Status()
|
||||
if status.Provider != "openai" {
|
||||
t.Errorf("expected provider to switch to openai, got %s", status.Provider)
|
||||
}
|
||||
if status.Model != "gpt-4o" {
|
||||
t.Errorf("expected model to switch to gpt-4o, got %s", status.Model)
|
||||
}
|
||||
|
||||
// Check messages contain switch system log
|
||||
found := false
|
||||
for _, msg := range newM.messages {
|
||||
if msg.role == "system" && strings.Contains(msg.content, "provider switched to: openai") {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("expected switch system message in history")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigPanelTransitions(t *testing.T) {
|
||||
rtr, eng, _, _ := newTestRouterAndEngine()
|
||||
sess := session.NewLocal(session.LocalConfig{
|
||||
Engine: eng,
|
||||
Provider: "anthropic",
|
||||
Model: "claude-3-5-sonnet",
|
||||
})
|
||||
m := Model{
|
||||
session: sess,
|
||||
configPanelOpen: true,
|
||||
config: Config{
|
||||
Router: rtr,
|
||||
Engine: eng,
|
||||
},
|
||||
}
|
||||
|
||||
// 1. Select Provider (index 0)
|
||||
m.configSelected = 0
|
||||
m = m.applyConfigSetting()
|
||||
if m.configPanelOpen {
|
||||
t.Error("expected config panel to close when opening provider picker")
|
||||
}
|
||||
if !m.providerPickerOpen {
|
||||
t.Error("expected provider picker to open")
|
||||
}
|
||||
|
||||
// Reset state
|
||||
m.configPanelOpen = true
|
||||
m.providerPickerOpen = false
|
||||
|
||||
// 2. Select Model (index 1)
|
||||
m.configSelected = 1
|
||||
m = m.applyConfigSetting()
|
||||
if m.configPanelOpen {
|
||||
t.Error("expected config panel to close when opening model picker")
|
||||
}
|
||||
if !m.modelPickerOpen {
|
||||
t.Error("expected model picker to open")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigPanelTransitionsWithSLM(t *testing.T) {
|
||||
rtr, eng, _, _ := newTestRouterAndEngine()
|
||||
sess := session.NewLocal(session.LocalConfig{
|
||||
Engine: eng,
|
||||
Provider: "anthropic",
|
||||
Model: "claude-3-5-sonnet",
|
||||
})
|
||||
m := Model{
|
||||
session: sess,
|
||||
configPanelOpen: true,
|
||||
config: Config{
|
||||
Router: rtr,
|
||||
Engine: eng,
|
||||
SLM: SLMInfo{
|
||||
Active: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// 1. Verify getActiveSettings only has permission and incognito
|
||||
settings := m.getActiveSettings()
|
||||
if len(settings) != 2 {
|
||||
t.Fatalf("expected 2 settings when SLM is active, got %d", len(settings))
|
||||
}
|
||||
if settings[0] != "permission" || settings[1] != "incognito" {
|
||||
t.Errorf("expected settings to be [permission, incognito], got %v", settings)
|
||||
}
|
||||
|
||||
// 2. Try handling /model slash command — it should add a system message and not open picker
|
||||
retM, _ := m.handleCommand("/model")
|
||||
m2 := retM.(Model)
|
||||
if m2.modelPickerOpen {
|
||||
t.Error("expected model picker not to open when SLM is active")
|
||||
}
|
||||
if len(m2.messages) == 0 || m2.messages[len(m2.messages)-1].role != "system" {
|
||||
t.Error("expected system warning message for blocked model switch")
|
||||
}
|
||||
|
||||
// 3. Try handling /provider slash command — it should add a system message and not open picker
|
||||
retP, _ := m.handleCommand("/provider")
|
||||
m3 := retP.(Model)
|
||||
if m3.providerPickerOpen {
|
||||
t.Error("expected provider picker not to open when SLM is active")
|
||||
}
|
||||
if len(m3.messages) == 0 || m3.messages[len(m3.messages)-1].role != "system" {
|
||||
t.Error("expected system warning message for blocked provider switch")
|
||||
}
|
||||
|
||||
// 4. Verify rendering output mentions "router" instead of anthropic/claude-3-5-sonnet
|
||||
statusStr := m.renderStatus()
|
||||
if !strings.Contains(statusStr, "router") {
|
||||
t.Errorf("expected status bar to contain 'router' when SLM is active, got: %q", statusStr)
|
||||
}
|
||||
if strings.Contains(statusStr, "anthropic") {
|
||||
t.Errorf("expected status bar to hide 'anthropic' when SLM is active, got: %q", statusStr)
|
||||
}
|
||||
|
||||
chatStr := m.renderChat(80)
|
||||
if !strings.Contains(chatStr, "router (slm:") {
|
||||
t.Errorf("expected header to contain 'router (slm:' when SLM is active, got: %q", chatStr)
|
||||
}
|
||||
if strings.Contains(chatStr, "anthropic") {
|
||||
t.Errorf("expected header to hide 'anthropic' when SLM is active, got: %q", chatStr)
|
||||
}
|
||||
}
|
||||
+675
-144
File diff suppressed because it is too large
Load Diff
+225
-106
@@ -2,126 +2,245 @@ package tui
|
||||
|
||||
import (
|
||||
"image/color"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
|
||||
"charm.land/lipgloss/v2"
|
||||
"somegit.dev/Owlibou/gnoma/internal/permission"
|
||||
)
|
||||
|
||||
// Color palette — catppuccin mocha inspired
|
||||
var (
|
||||
cPurple = lipgloss.Color("#CBA6F7") // mauve
|
||||
cBlue = lipgloss.Color("#89B4FA") // blue
|
||||
cGreen = lipgloss.Color("#A6E3A1") // green
|
||||
cRed = lipgloss.Color("#F38BA8") // red
|
||||
cYellow = lipgloss.Color("#F9E2AF") // yellow
|
||||
cPeach = lipgloss.Color("#FAB387") // peach
|
||||
cTeal = lipgloss.Color("#94E2D5") // teal
|
||||
cText = lipgloss.Color("#CDD6F4") // text
|
||||
cSubtext = lipgloss.Color("#A6ADC8") // subtext0
|
||||
cOverlay = lipgloss.Color("#6C7086") // overlay0
|
||||
cSurface = lipgloss.Color("#313244") // surface0
|
||||
cMantle = lipgloss.Color("#181825") // mantle
|
||||
)
|
||||
|
||||
// Permission mode colors — each mode has a distinct color
|
||||
var modeColors = map[permission.Mode]color.Color{
|
||||
permission.ModeBypass: cGreen, // green = all allowed
|
||||
permission.ModeDefault: cBlue, // blue = prompting
|
||||
permission.ModePlan: cTeal, // teal = read-only
|
||||
permission.ModeAcceptEdits: cPurple, // purple = edits ok
|
||||
permission.ModeAuto: cPeach, // peach = smart
|
||||
permission.ModeDeny: cRed, // red = locked down
|
||||
// Theme represents a custom color palette for the TUI.
|
||||
type Theme struct {
|
||||
Name string
|
||||
Purple color.Color
|
||||
Blue color.Color
|
||||
Green color.Color
|
||||
Red color.Color
|
||||
Yellow color.Color
|
||||
Peach color.Color
|
||||
Teal color.Color
|
||||
Text color.Color
|
||||
Subtext color.Color
|
||||
Overlay color.Color
|
||||
Surface color.Color
|
||||
Mantle color.Color
|
||||
}
|
||||
|
||||
// ModeColor returns the color for a permission mode.
|
||||
// Predefined themes
|
||||
var Themes = []Theme{
|
||||
{
|
||||
Name: "catppuccin",
|
||||
Purple: lipgloss.Color("#CBA6F7"),
|
||||
Blue: lipgloss.Color("#89B4FA"),
|
||||
Green: lipgloss.Color("#A6E3A1"),
|
||||
Red: lipgloss.Color("#F38BA8"),
|
||||
Yellow: lipgloss.Color("#F9E2AF"),
|
||||
Peach: lipgloss.Color("#FAB387"),
|
||||
Teal: lipgloss.Color("#94E2D5"),
|
||||
Text: lipgloss.Color("#CDD6F4"),
|
||||
Subtext: lipgloss.Color("#A6ADC8"),
|
||||
Overlay: lipgloss.Color("#6C7086"),
|
||||
Surface: lipgloss.Color("#313244"),
|
||||
Mantle: lipgloss.Color("#181825"),
|
||||
},
|
||||
{
|
||||
Name: "nord",
|
||||
Purple: lipgloss.Color("#B48EAD"),
|
||||
Blue: lipgloss.Color("#81A1C1"),
|
||||
Green: lipgloss.Color("#A3BE8C"),
|
||||
Red: lipgloss.Color("#BF616A"),
|
||||
Yellow: lipgloss.Color("#EBCB8B"),
|
||||
Peach: lipgloss.Color("#D08770"),
|
||||
Teal: lipgloss.Color("#88C0D0"),
|
||||
Text: lipgloss.Color("#D8DEE9"),
|
||||
Subtext: lipgloss.Color("#E5E9F0"),
|
||||
Overlay: lipgloss.Color("#4C566A"),
|
||||
Surface: lipgloss.Color("#3B4252"),
|
||||
Mantle: lipgloss.Color("#2E3440"),
|
||||
},
|
||||
{
|
||||
Name: "gruvbox",
|
||||
Purple: lipgloss.Color("#d3869b"),
|
||||
Blue: lipgloss.Color("#83a598"),
|
||||
Green: lipgloss.Color("#b8bb26"),
|
||||
Red: lipgloss.Color("#fb4934"),
|
||||
Yellow: lipgloss.Color("#fabd2f"),
|
||||
Peach: lipgloss.Color("#fe8019"),
|
||||
Teal: lipgloss.Color("#8ec07c"),
|
||||
Text: lipgloss.Color("#ebdbb2"),
|
||||
Subtext: lipgloss.Color("#a89984"),
|
||||
Overlay: lipgloss.Color("#928374"),
|
||||
Surface: lipgloss.Color("#3c3836"),
|
||||
Mantle: lipgloss.Color("#282828"),
|
||||
},
|
||||
{
|
||||
Name: "monokai",
|
||||
Purple: lipgloss.Color("#ae81ff"),
|
||||
Blue: lipgloss.Color("#66d9ef"),
|
||||
Green: lipgloss.Color("#a6e22e"),
|
||||
Red: lipgloss.Color("#f92672"),
|
||||
Yellow: lipgloss.Color("#e6db74"),
|
||||
Peach: lipgloss.Color("#fd971f"),
|
||||
Teal: lipgloss.Color("#a1efe4"),
|
||||
Text: lipgloss.Color("#f8f8f2"),
|
||||
Subtext: lipgloss.Color("#cfcfc2"),
|
||||
Overlay: lipgloss.Color("#75715e"),
|
||||
Surface: lipgloss.Color("#272822"),
|
||||
Mantle: lipgloss.Color("#1e1f1c"),
|
||||
},
|
||||
{
|
||||
Name: "solarized_light",
|
||||
Purple: lipgloss.Color("#6c71c4"),
|
||||
Blue: lipgloss.Color("#268bd2"),
|
||||
Green: lipgloss.Color("#859900"),
|
||||
Red: lipgloss.Color("#dc322f"),
|
||||
Yellow: lipgloss.Color("#b58900"),
|
||||
Peach: lipgloss.Color("#cb4b16"),
|
||||
Teal: lipgloss.Color("#2aa198"),
|
||||
Text: lipgloss.Color("#586e75"),
|
||||
Subtext: lipgloss.Color("#657b83"),
|
||||
Overlay: lipgloss.Color("#93a1a1"),
|
||||
Surface: lipgloss.Color("#eee8d5"),
|
||||
Mantle: lipgloss.Color("#fdf6e3"),
|
||||
},
|
||||
}
|
||||
|
||||
// themeStyles is the immutable snapshot of the active palette and the
|
||||
// pre-built lipgloss styles derived from it. ApplyTheme builds a fresh
|
||||
// snapshot and stores it atomically; readers Load() the pointer once and
|
||||
// see a coherent view, so no mutex is needed even if rendering ever moves
|
||||
// off the bubbletea event-loop goroutine.
|
||||
type themeStyles struct {
|
||||
name string
|
||||
|
||||
cPurple color.Color
|
||||
cBlue color.Color
|
||||
cGreen color.Color
|
||||
cRed color.Color
|
||||
cYellow color.Color
|
||||
cPeach color.Color
|
||||
cTeal color.Color
|
||||
cText color.Color
|
||||
cSubtext color.Color
|
||||
cOverlay color.Color
|
||||
cSurface color.Color
|
||||
cMantle color.Color
|
||||
|
||||
modeColors map[permission.Mode]color.Color
|
||||
|
||||
sHeaderBrand lipgloss.Style
|
||||
sHeaderModel lipgloss.Style
|
||||
sHeaderDim lipgloss.Style
|
||||
sUserLabel lipgloss.Style
|
||||
styleAssistantLabel lipgloss.Style
|
||||
sToolOutput lipgloss.Style
|
||||
sToolResult lipgloss.Style
|
||||
sSystem lipgloss.Style
|
||||
sError lipgloss.Style
|
||||
sHint lipgloss.Style
|
||||
sCursor lipgloss.Style
|
||||
sDiffAdd lipgloss.Style
|
||||
sDiffRemove lipgloss.Style
|
||||
sText lipgloss.Style
|
||||
sThinkingLabel lipgloss.Style
|
||||
sThinkingBody lipgloss.Style
|
||||
sStatusBar lipgloss.Style
|
||||
sStatusHighlight lipgloss.Style
|
||||
sStatusDim lipgloss.Style
|
||||
sStatusStreaming lipgloss.Style
|
||||
sStatusBranch lipgloss.Style
|
||||
sStatusIncognito lipgloss.Style
|
||||
}
|
||||
|
||||
var activeStyles atomic.Pointer[themeStyles]
|
||||
|
||||
// theme returns the currently-active style snapshot. The returned pointer
|
||||
// must be treated as read-only; ApplyTheme never mutates an existing
|
||||
// snapshot in place.
|
||||
func theme() *themeStyles {
|
||||
return activeStyles.Load()
|
||||
}
|
||||
|
||||
// ModeColor returns the color for a permission mode under the active theme.
|
||||
func ModeColor(mode permission.Mode) color.Color {
|
||||
if c, ok := modeColors[mode]; ok {
|
||||
t := theme()
|
||||
if c, ok := t.modeColors[mode]; ok {
|
||||
return c
|
||||
}
|
||||
return cOverlay
|
||||
return t.cOverlay
|
||||
}
|
||||
|
||||
// Header
|
||||
var (
|
||||
sHeaderBrand = lipgloss.NewStyle().
|
||||
Background(cPurple).
|
||||
Foreground(cMantle).
|
||||
Bold(true).
|
||||
Padding(0, 1)
|
||||
// Initialize with catppuccin on package load.
|
||||
func init() {
|
||||
ApplyTheme("catppuccin")
|
||||
}
|
||||
|
||||
sHeaderModel = lipgloss.NewStyle().
|
||||
Foreground(cGreen).
|
||||
Bold(true)
|
||||
// ApplyTheme builds a fresh themeStyles snapshot for the named theme and
|
||||
// atomically swaps it in as the active one. Concurrent reads via theme()
|
||||
// see either the previous snapshot or the new one — never a half-built
|
||||
// state. Returns false if name does not match a known theme.
|
||||
func ApplyTheme(name string) bool {
|
||||
var src *Theme
|
||||
for i := range Themes {
|
||||
tName := strings.ReplaceAll(strings.ToLower(Themes[i].Name), "_", "-")
|
||||
sName := strings.ReplaceAll(strings.ToLower(name), "_", "-")
|
||||
if tName == sName {
|
||||
src = &Themes[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
if src == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
sHeaderDim = lipgloss.NewStyle().
|
||||
Foreground(cOverlay)
|
||||
)
|
||||
t := &themeStyles{
|
||||
name: src.Name,
|
||||
cPurple: src.Purple,
|
||||
cBlue: src.Blue,
|
||||
cGreen: src.Green,
|
||||
cRed: src.Red,
|
||||
cYellow: src.Yellow,
|
||||
cPeach: src.Peach,
|
||||
cTeal: src.Teal,
|
||||
cText: src.Text,
|
||||
cSubtext: src.Subtext,
|
||||
cOverlay: src.Overlay,
|
||||
cSurface: src.Surface,
|
||||
cMantle: src.Mantle,
|
||||
}
|
||||
|
||||
// Chat
|
||||
var (
|
||||
sUserLabel = lipgloss.NewStyle().
|
||||
Foreground(cBlue).
|
||||
Bold(true)
|
||||
t.modeColors = map[permission.Mode]color.Color{
|
||||
permission.ModeBypass: t.cGreen,
|
||||
permission.ModeDefault: t.cBlue,
|
||||
permission.ModePlan: t.cTeal,
|
||||
permission.ModeAcceptEdits: t.cPurple,
|
||||
permission.ModeAuto: t.cPeach,
|
||||
permission.ModeDeny: t.cRed,
|
||||
}
|
||||
|
||||
styleAssistantLabel = lipgloss.NewStyle().
|
||||
Foreground(cPurple).
|
||||
Bold(true)
|
||||
t.sHeaderBrand = lipgloss.NewStyle().Background(t.cPurple).Foreground(t.cMantle).Bold(true).Padding(0, 1)
|
||||
t.sHeaderModel = lipgloss.NewStyle().Foreground(t.cGreen).Bold(true)
|
||||
t.sHeaderDim = lipgloss.NewStyle().Foreground(t.cOverlay)
|
||||
t.sUserLabel = lipgloss.NewStyle().Foreground(t.cBlue).Bold(true)
|
||||
t.styleAssistantLabel = lipgloss.NewStyle().Foreground(t.cPurple).Bold(true)
|
||||
t.sToolOutput = lipgloss.NewStyle().Foreground(t.cGreen)
|
||||
t.sToolResult = lipgloss.NewStyle().Foreground(t.cOverlay)
|
||||
t.sSystem = lipgloss.NewStyle().Foreground(t.cYellow)
|
||||
t.sError = lipgloss.NewStyle().Foreground(t.cRed)
|
||||
t.sHint = lipgloss.NewStyle().Foreground(t.cOverlay)
|
||||
t.sCursor = lipgloss.NewStyle().Foreground(t.cPurple)
|
||||
t.sDiffAdd = lipgloss.NewStyle().Foreground(t.cGreen)
|
||||
t.sDiffRemove = lipgloss.NewStyle().Foreground(t.cRed)
|
||||
t.sText = lipgloss.NewStyle().Foreground(t.cText)
|
||||
t.sThinkingLabel = lipgloss.NewStyle().Foreground(t.cOverlay).Italic(true)
|
||||
t.sThinkingBody = lipgloss.NewStyle().Foreground(t.cOverlay).Italic(true)
|
||||
t.sStatusBar = lipgloss.NewStyle().Foreground(t.cSubtext)
|
||||
t.sStatusHighlight = lipgloss.NewStyle().Foreground(t.cPurple).Bold(true)
|
||||
t.sStatusDim = lipgloss.NewStyle().Foreground(t.cOverlay)
|
||||
t.sStatusStreaming = lipgloss.NewStyle().Foreground(t.cYellow).Bold(true)
|
||||
t.sStatusBranch = lipgloss.NewStyle().Foreground(t.cGreen)
|
||||
t.sStatusIncognito = lipgloss.NewStyle().Foreground(t.cYellow)
|
||||
|
||||
sToolOutput = lipgloss.NewStyle().
|
||||
Foreground(cGreen)
|
||||
|
||||
sToolResult = lipgloss.NewStyle().
|
||||
Foreground(cOverlay)
|
||||
|
||||
sSystem = lipgloss.NewStyle().
|
||||
Foreground(cYellow)
|
||||
|
||||
sError = lipgloss.NewStyle().
|
||||
Foreground(cRed)
|
||||
|
||||
sHint = lipgloss.NewStyle().
|
||||
Foreground(cOverlay)
|
||||
|
||||
sCursor = lipgloss.NewStyle().
|
||||
Foreground(cPurple)
|
||||
|
||||
sDiffAdd = lipgloss.NewStyle().
|
||||
Foreground(cGreen)
|
||||
|
||||
sDiffRemove = lipgloss.NewStyle().
|
||||
Foreground(cRed)
|
||||
|
||||
sText = lipgloss.NewStyle().
|
||||
Foreground(cText)
|
||||
|
||||
sThinkingLabel = lipgloss.NewStyle().
|
||||
Foreground(cOverlay).
|
||||
Italic(true)
|
||||
|
||||
sThinkingBody = lipgloss.NewStyle().
|
||||
Foreground(cOverlay).
|
||||
Italic(true)
|
||||
)
|
||||
|
||||
// Status bar
|
||||
var (
|
||||
sStatusBar = lipgloss.NewStyle().
|
||||
Foreground(cSubtext)
|
||||
|
||||
sStatusHighlight = lipgloss.NewStyle().
|
||||
Foreground(cPurple).
|
||||
Bold(true)
|
||||
|
||||
sStatusDim = lipgloss.NewStyle().
|
||||
Foreground(cOverlay)
|
||||
|
||||
sStatusStreaming = lipgloss.NewStyle().
|
||||
Foreground(cYellow).
|
||||
Bold(true)
|
||||
|
||||
sStatusBranch = lipgloss.NewStyle().
|
||||
Foreground(cGreen)
|
||||
|
||||
sStatusIncognito = lipgloss.NewStyle().
|
||||
Foreground(cYellow)
|
||||
)
|
||||
activeStyles.Store(t)
|
||||
return true
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user