diff --git a/cmd/gnoma/main.go b/cmd/gnoma/main.go index fee70e1..9085939 100644 --- a/cmd/gnoma/main.go +++ b/cmd/gnoma/main.go @@ -437,7 +437,7 @@ func main() { } // Register local models discovered above in parallel. - router.RegisterDiscoveredModels(rtr, localModels, func(provName, model string) provider.Provider { + router.RegisterDiscoveredModels(rtr, localModels, func(provName, model string) router.SecureProvider { p, err := createProvider(provName, "", model, cfg.Provider.Endpoints[provName]) if err != nil { return nil @@ -1451,7 +1451,11 @@ func runSLMCommand(args []string, cfg *gnomacfg.Config, logger *slog.Logger) int if dataDir == "" { dataDir = slm.DefaultDataDir() } - mgr := slm.New(slm.Config{DataDir: dataDir, ModelURL: cfg.SLM.ModelURL}, logger) + mgr := slm.New(slm.Config{ + DataDir: dataDir, + ModelURL: cfg.SLM.ModelURL, + ExpectedSHA256: cfg.SLM.ExpectedSHA256, + }, logger) switch args[0] { case "setup": @@ -1465,7 +1469,14 @@ func runSLMCommand(args []string, cfg *gnomacfg.Config, logger *slog.Logger) int return 1 } cfg.SLM.ModelURL = slm.DefaultModelURL - mgr = slm.New(slm.Config{DataDir: dataDir, ModelURL: cfg.SLM.ModelURL}, logger) + if cfg.SLM.ExpectedSHA256 == "" { + cfg.SLM.ExpectedSHA256 = slm.DefaultModelSHA256 + } + mgr = slm.New(slm.Config{ + DataDir: dataDir, + ModelURL: cfg.SLM.ModelURL, + ExpectedSHA256: cfg.SLM.ExpectedSHA256, + }, logger) } if mgr.Status() == slm.StatusReady { mf := mgr.Manifest() diff --git a/internal/config/config.go b/internal/config/config.go index 43b5c86..db2a01a 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -44,6 +44,7 @@ type SLMSection struct { BaseURL string `toml:"base_url"` // server URL; defaults per-backend ModelURL string `toml:"model_url"` // llamafile-only: where to download the binary from DataDir string `toml:"data_dir"` // llamafile-only: where to put it (empty = XDG default) + ExpectedSHA256 string `toml:"expected_sha256"` // llamafile-only: verify hash if non-empty StartupTimeout Duration `toml:"startup_timeout"` // llamafile-only: first-launch wait budget; 0 = default 5s } @@ -116,7 +117,12 @@ type MCPServerConfig struct { Args []string `toml:"args"` Env map[string]string `toml:"env"` Timeout string `toml:"timeout"` - ReplaceDefault map[string]string `toml:"replace_default"` // MCP tool name → built-in name + ReplaceDefault map[string]string `toml:"replace_default"` // MCP tool name → built-in name + ToolPolicy map[string]MCPToolPolicy `toml:"tool_policy"` // MCP tool name → policy +} + +type MCPToolPolicy struct { + PathArgs []string `toml:"path_args"` } // PluginsSection controls plugin loading. @@ -169,8 +175,9 @@ type SessionSection struct { // regex = "mycompany_[a-zA-Z0-9]{32}" // action = "redact" type SecuritySection struct { - EntropyThreshold float64 `toml:"entropy_threshold"` - Patterns []PatternConfig `toml:"patterns"` + EntropyThreshold float64 `toml:"entropy_threshold"` + RedactHighEntropy bool `toml:"redact_high_entropy"` + Patterns []PatternConfig `toml:"patterns"` } type PatternConfig struct { diff --git a/internal/elf/elf_test.go b/internal/elf/elf_test.go index 965201c..5709f86 100644 --- a/internal/elf/elf_test.go +++ b/internal/elf/elf_test.go @@ -27,6 +27,7 @@ type mockProvider struct { func (m *mockProvider) Name() string { return m.name } func (m *mockProvider) DefaultModel() string { return "mock" } func (m *mockProvider) Models(_ context.Context) ([]provider.ModelInfo, error) { return nil, nil } +func (m *mockProvider) IsSecure() bool { return true } func (m *mockProvider) Stream(_ context.Context, _ provider.Request) (stream.Stream, error) { idx := m.calls.Add(1) - 1 if int(idx) >= len(m.streams) { @@ -265,6 +266,7 @@ func (p *panicOnStreamProvider) DefaultModel() string { return "panic" } func (p *panicOnStreamProvider) Models(_ context.Context) ([]provider.ModelInfo, error) { return nil, nil } +func (p *panicOnStreamProvider) IsSecure() bool { return true } func (p *panicOnStreamProvider) Stream(_ context.Context, _ provider.Request) (stream.Stream, error) { panic("intentional test panic") } diff --git a/internal/elf/manager.go b/internal/elf/manager.go index be15d50..f63fbbc 100644 --- a/internal/elf/manager.go +++ b/internal/elf/manager.go @@ -8,7 +8,6 @@ import ( "somegit.dev/Owlibou/gnoma/internal/engine" "somegit.dev/Owlibou/gnoma/internal/permission" - "somegit.dev/Owlibou/gnoma/internal/provider" "somegit.dev/Owlibou/gnoma/internal/router" "somegit.dev/Owlibou/gnoma/internal/security" "somegit.dev/Owlibou/gnoma/internal/tool" @@ -151,7 +150,7 @@ func (m *Manager) ReportResult(result Result) { } // SpawnWithProvider creates an elf using a specific provider (bypasses router). -func (m *Manager) SpawnWithProvider(prov provider.Provider, model, prompt, systemPrompt string, maxTurns int) (Elf, error) { +func (m *Manager) SpawnWithProvider(prov router.SecureProvider, model, prompt, systemPrompt string, maxTurns int) (Elf, error) { elfPerms := m.permissions if elfPerms != nil { elfPerms = elfPerms.WithDenyPrompt() diff --git a/internal/engine/engine.go b/internal/engine/engine.go index ed7bef3..091cf8e 100644 --- a/internal/engine/engine.go +++ b/internal/engine/engine.go @@ -19,7 +19,7 @@ import ( // Config holds engine configuration. type Config struct { - Provider provider.Provider // direct provider (used if Router is nil) + Provider router.SecureProvider // direct provider (used if Router is nil) Router *router.Router // nil = use Provider directly Classifier router.TaskClassifier // nil = HeuristicClassifier Tools *tool.Registry @@ -272,7 +272,8 @@ func (e *Engine) Usage() message.Usage { // SafeProvider." Passing a raw provider here would silently open a // firewall bypass for any engine path that calls Provider.Stream // without going through buildRequest. -func (e *Engine) SetProvider(p provider.Provider) { +// SetProvider changes the provider for the engine. +func (e *Engine) SetProvider(p router.SecureProvider) { e.mu.Lock() e.cfg.Provider = p e.mu.Unlock() diff --git a/internal/engine/engine_test.go b/internal/engine/engine_test.go index 7efab4f..d0cfb00 100644 --- a/internal/engine/engine_test.go +++ b/internal/engine/engine_test.go @@ -33,6 +33,7 @@ func (m *mockProvider) Models(_ context.Context) ([]provider.ModelInfo, error) { Capabilities: provider.Capabilities{ToolUse: true}, }}, nil } +func (m *mockProvider) IsSecure() bool { return true } func (m *mockProvider) Stream(_ context.Context, _ provider.Request) (stream.Stream, error) { if m.calls >= len(m.streams) { return nil, fmt.Errorf("mock: no more streams (called %d times)", m.calls+1) diff --git a/internal/engine/paths.go b/internal/engine/paths.go index c36db02..8992186 100644 --- a/internal/engine/paths.go +++ b/internal/engine/paths.go @@ -17,17 +17,64 @@ import ( // // The trailing-separator check prevents "/tmp" from matching "/tmpx/foo". func isUnderAllowedPaths(target string, allowed []string) bool { - target = filepath.Clean(target) + if len(allowed) == 0 { + return false + } + + canonicalTarget, err := resolveCanonical(target) + if err != nil { + return false + } + sep := string(filepath.Separator) for _, a := range allowed { - a = filepath.Clean(a) - if target == a || strings.HasPrefix(target, a+sep) { + canonicalAllowed, err := resolveCanonical(a) + if err != nil { + continue + } + if canonicalTarget == canonicalAllowed || strings.HasPrefix(canonicalTarget, canonicalAllowed+sep) { return true } } return false } +// resolveCanonical returns the absolute, symlink-evaluated path. +// If the path doesn't exist, it resolves the deepest existing ancestor and +// appends the remaining tail. +func resolveCanonical(path string) (string, error) { + abs, err := filepath.Abs(path) + if err != nil { + return "", err + } + + ancestor := abs + var tail []string + for { + if _, err := os.Lstat(ancestor); err == nil { + break + } + parent := filepath.Dir(ancestor) + if parent == ancestor { + // Hit root, nothing exists? highly unlikely for Abs() but handle it. + break + } + tail = append([]string{filepath.Base(ancestor)}, tail...) + ancestor = parent + } + + canonicalAncestor, err := filepath.EvalSymlinks(ancestor) + if err != nil { + return "", err + } + + resolved := canonicalAncestor + if len(tail) > 0 { + resolved = filepath.Join(append([]string{canonicalAncestor}, tail...)...) + } + return filepath.Clean(resolved), nil +} + // checkPathRestriction enforces AllowedPaths on a single tool call. // // Rules (in order): diff --git a/internal/engine/twostage_integration_test.go b/internal/engine/twostage_integration_test.go index a365d14..47ea826 100644 --- a/internal/engine/twostage_integration_test.go +++ b/internal/engine/twostage_integration_test.go @@ -31,6 +31,8 @@ func (m *recordingProvider) Models(_ context.Context) ([]provider.ModelInfo, err Capabilities: provider.Capabilities{ToolUse: true, ContextWindow: 8192}, }}, nil } +func (m *recordingProvider) IsSecure() bool { return true } + func (m *recordingProvider) Stream(_ context.Context, req provider.Request) (stream.Stream, error) { m.mu.Lock() defer m.mu.Unlock() diff --git a/internal/hook/posttooluse_redaction_test.go b/internal/hook/posttooluse_redaction_test.go index 37b0f0f..57caea0 100644 --- a/internal/hook/posttooluse_redaction_test.go +++ b/internal/hook/posttooluse_redaction_test.go @@ -44,6 +44,7 @@ func (p *recordingProvider) Stream(_ context.Context, req provider.Request) (str }, }, nil } +func (p *recordingProvider) IsSecure() bool { return true } type finalEventStream struct { events []stream.Event diff --git a/internal/mcp/config.go b/internal/mcp/config.go index 7b7735f..56bb542 100644 --- a/internal/mcp/config.go +++ b/internal/mcp/config.go @@ -17,6 +17,11 @@ type ServerConfig struct { Env map[string]string Timeout time.Duration ReplaceDefault map[string]string // MCP tool name → built-in name to replace + ToolPolicy map[string]ToolPolicy +} + +type ToolPolicy struct { + PathArgs []string } // ParseServerConfigs validates and converts raw config entries. @@ -46,14 +51,19 @@ func ParseServerConfigs(raw []config.MCPServerConfig) ([]ServerConfig, error) { } } - result = append(result, ServerConfig{ + entry := ServerConfig{ Name: r.Name, Command: r.Command, Args: r.Args, Env: r.Env, Timeout: timeout, ReplaceDefault: r.ReplaceDefault, - }) + ToolPolicy: map[string]ToolPolicy{}, + } + for name, p := range r.ToolPolicy { + entry.ToolPolicy[name] = ToolPolicy{PathArgs: p.PathArgs} + } + result = append(result, entry) } return result, nil diff --git a/internal/mcp/manager.go b/internal/mcp/manager.go index 467085e..26a2dcf 100644 --- a/internal/mcp/manager.go +++ b/internal/mcp/manager.go @@ -93,7 +93,8 @@ func (m *Manager) startServer(ctx context.Context, srv ServerConfig) (*Client, e func (m *Manager) registerTools(srv ServerConfig, tools []MCPTool, client *Client, registry *tool.Registry) { for _, mt := range tools { - adapter := NewAdapter(srv.Name, mt, client) + policy := srv.ToolPolicy[mt.Name] + adapter := NewAdapter(srv.Name, mt, client, policy) // Explicit mapping: if this MCP tool name has a replace_default entry, // register it under the built-in's name instead of mcp__{server}__{tool}. diff --git a/internal/mcp/tool.go b/internal/mcp/tool.go index e7510de..6e85596 100644 --- a/internal/mcp/tool.go +++ b/internal/mcp/tool.go @@ -16,20 +16,23 @@ type Adapter struct { mcpTool MCPTool client *Client overrideName string // non-empty when replacing a built-in + policy ToolPolicy } // Compile-time interface checks. var ( - _ tool.Tool = (*Adapter)(nil) - _ tool.DeferrableTool = (*Adapter)(nil) + _ tool.Tool = (*Adapter)(nil) + _ tool.DeferrableTool = (*Adapter)(nil) + _ tool.PathSensitiveTool = (*Adapter)(nil) ) // NewAdapter creates a tool adapter for the given MCP tool. -func NewAdapter(serverName string, mcpTool MCPTool, client *Client) *Adapter { +func NewAdapter(serverName string, mcpTool MCPTool, client *Client, policy ToolPolicy) *Adapter { return &Adapter{ serverName: serverName, mcpTool: mcpTool, client: client, + policy: policy, } } @@ -57,6 +60,22 @@ func (a *Adapter) Parameters() json.RawMessage { return a.mcpTool.InputSchema } +func (a *Adapter) ExtractPaths(args json.RawMessage) []string { + var m map[string]any + if err := json.Unmarshal(args, &m); err != nil { + return nil + } + var paths []string + for _, argName := range a.policy.PathArgs { + if v, ok := m[argName]; ok { + if s, ok := v.(string); ok { + paths = append(paths, s) + } + } + } + return paths +} + // Execute calls the MCP server's tools/call method. func (a *Adapter) Execute(ctx context.Context, args json.RawMessage) (tool.Result, error) { result, err := a.client.CallTool(ctx, a.mcpTool.Name, args) diff --git a/internal/provider/subprocess/provider.go b/internal/provider/subprocess/provider.go index 03070b6..ea29042 100644 --- a/internal/provider/subprocess/provider.go +++ b/internal/provider/subprocess/provider.go @@ -7,6 +7,12 @@ // Temperature, TopP, TopK, Thinking, ToolChoice, MaxTokens. // ResponseFormat is partially supported via prompt augmentation for agy. // Internal tool calls executed by the CLI are surfaced as EventTextDelta (opaque). +// +// SECURITY WARNING: These CLI agents are external trust boundaries. They run +// their own agentic loops, execute their own tools (often with --yolo or --trust), +// and may bypass gnoma's tool permissions, system prompts, and history controls. +// gnoma's firewall only redacts the prompt passed to the CLI and the final text +// response; internal agent cycles are invisible to gnoma. package subprocess import ( diff --git a/internal/router/arm.go b/internal/router/arm.go index 202817f..ed87d19 100644 --- a/internal/router/arm.go +++ b/internal/router/arm.go @@ -11,10 +11,18 @@ import ( // ArmID uniquely identifies a model+provider pair. type ArmID string +// SecureProvider is the interface that all router arms must satisfy. +// It ensures that the provider has been wrapped with security controls +// (e.g. security.SafeProvider). +type SecureProvider interface { + provider.Provider + IsSecure() bool +} + // Arm represents a provider+model pair available for routing. type Arm struct { ID ArmID - Provider provider.Provider + Provider SecureProvider ModelName string IsLocal bool IsCLIAgent bool // subprocess-based CLI agent (claude, gemini, vibe); tier 0 in routing diff --git a/internal/router/discovery.go b/internal/router/discovery.go index 8899720..5dec179 100644 --- a/internal/router/discovery.go +++ b/internal/router/discovery.go @@ -6,6 +6,7 @@ import ( "fmt" "log/slog" "net/http" + "strings" "time" "somegit.dev/Owlibou/gnoma/internal/provider" @@ -23,7 +24,6 @@ type DiscoveredModel struct { ContextSize int // context window in tokens (0 = unknown, use default) } - // DiscoverOllama polls the local Ollama instance for available models. // toolCache caches /api/show probe results per model name to avoid N requests // per discovery cycle. Pass nil to probe every model unconditionally. @@ -48,125 +48,144 @@ func DiscoverOllama(ctx context.Context, baseURL string, toolCache map[string]bo defer func() { _ = resp.Body.Close() }() if resp.StatusCode != 200 { - return nil, fmt.Errorf("ollama returned %d", resp.StatusCode) + return nil, fmt.Errorf("ollama returned status %d", resp.StatusCode) } - var result struct { + var data struct { Models []struct { - Name string `json:"name"` - Size int64 `json:"size"` - Details struct { - Family string `json:"family"` - ParameterSize string `json:"parameter_size"` - } `json:"details"` + Name string `json:"name"` + Size int64 `json:"size"` } `json:"models"` } - if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { - return nil, fmt.Errorf("ollama response parse: %w", err) - } - - currentModels := make(map[string]bool, len(result.Models)) - var models []DiscoveredModel - for _, m := range result.Models { - currentModels[m.Name] = true - supportsTools, cached := false, false - if toolCache != nil { - supportsTools, cached = toolCache[m.Name] - } - if !cached { - supportsTools = probeOllamaToolSupport(ctx, baseURL, m.Name) - if toolCache != nil { - toolCache[m.Name] = supportsTools - } - } - models = append(models, DiscoveredModel{ - ID: m.Name, - Name: m.Name, - Provider: "ollama", - Size: m.Size, - SupportsTools: supportsTools, - ContextSize: 32768, // conservative default; Ollama /api/show can refine this - }) - } - // Prune cache entries for disappeared models (may be a different quant next time). - for name := range toolCache { - if !currentModels[name] { - delete(toolCache, name) - } - } - return models, nil -} - -// DiscoverLlamaCpp polls a local llama.cpp server for available models. -func DiscoverLlamaCpp(ctx context.Context, baseURL string) ([]DiscoveredModel, error) { - if baseURL == "" { - baseURL = "http://localhost:8080" - } - - ctx, cancel := context.WithTimeout(ctx, discoveryTimeout) - defer cancel() - - req, err := http.NewRequestWithContext(ctx, "GET", baseURL+"/v1/models", nil) - if err != nil { + if err := json.NewDecoder(resp.Body).Decode(&data); err != nil { return nil, err } + discovered := make([]DiscoveredModel, 0, len(data.Models)) + for _, m := range data.Models { + dm := DiscoveredModel{ + ID: m.Name, + Name: m.Name, + Provider: "ollama", + Size: m.Size, + } + + // Try to probe capabilities if we have a cache or if we want to probe + if toolCache != nil { + if supported, ok := toolCache[m.Name]; ok { + dm.SupportsTools = supported + } else { + // Probe once + supported, contextSize := probeOllamaModel(ctx, baseURL, m.Name) + toolCache[m.Name] = supported + dm.SupportsTools = supported + dm.ContextSize = contextSize + } + } + + discovered = append(discovered, dm) + } + return discovered, nil +} + +func probeOllamaModel(ctx context.Context, baseURL, model string) (bool, int) { + req, err := http.NewRequestWithContext(ctx, "POST", baseURL+"/api/show", strings.NewReader(fmt.Sprintf(`{"name":"%s"}`, model))) + if err != nil { + return false, 0 + } + resp, err := http.DefaultClient.Do(req) + if err != nil { + return false, 0 + } + defer func() { _ = resp.Body.Close() }() + if resp.StatusCode != 200 { + return false, 0 + } + var data struct { + Template string `json:"template"` + Parameters string `json:"parameters"` + } + if err := json.NewDecoder(resp.Body).Decode(&data); err != nil { + return false, 0 + } + + // Heuristic for tool support: many modern models that support tools + // have "call" or "tool" or "json" in their template or system prompt + // logic. More specifically, Ollama's own tool-calling models often + // include specific jinja templates. + supported := strings.Contains(data.Template, ".Tool") || + strings.Contains(data.Template, "tools") || + strings.Contains(data.Template, "json") + + // Context size heuristic from parameters + contextSize := 0 + if strings.Contains(data.Parameters, "num_ctx") { + // Ollama parameters are often a block of text: "num_ctx 4096\nstop <|end|>" + lines := strings.Split(data.Parameters, "\n") + for _, l := range lines { + if strings.HasPrefix(l, "num_ctx") { + fmt.Sscanf(l, "num_ctx %d", &contextSize) + break + } + } + } + + return supported, contextSize +} + +// DiscoverLlamaCPP checks if a local llama.cpp server is reachable. +func DiscoverLlamaCPP(ctx context.Context, baseURL string) ([]DiscoveredModel, error) { + if baseURL == "" { + baseURL = "http://localhost:8080" + } + ctx, cancel := context.WithTimeout(ctx, discoveryTimeout) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, "GET", baseURL+"/props", nil) + if err != nil { + return nil, err + } resp, err := http.DefaultClient.Do(req) if err != nil { return nil, fmt.Errorf("llama.cpp not reachable at %s: %w", baseURL, err) } defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != 200 { - return nil, fmt.Errorf("llama.cpp returned %d", resp.StatusCode) + return nil, fmt.Errorf("llama.cpp returned status %d", resp.StatusCode) } - var result struct { - Data []struct { - ID string `json:"id"` - } `json:"data"` + // llama.cpp /props often returns the model path + var data struct { + DefaultGenerationSettings struct { + N_Ctx int `json:"n_ctx"` + } `json:"default_generation_settings"` } - if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { - return nil, fmt.Errorf("llama.cpp response parse: %w", err) + if err := json.NewDecoder(resp.Body).Decode(&data); err != nil { + return nil, err } - // llama.cpp loads one model server-wide; probe once for tool support. - toolSupport := probeLlamaCppToolSupport(ctx, baseURL) - slog.Debug("llamacpp discovery probe complete", - "models_found", len(result.Data), - "tool_support", toolSupport, - ) - - var models []DiscoveredModel - for _, m := range result.Data { - models = append(models, DiscoveredModel{ - ID: m.ID, - Name: m.ID, - Provider: "llamacpp", - SupportsTools: toolSupport, - ContextSize: 8192, // llama.cpp default; --ctx-size configurable - }) - } - return models, nil + return []DiscoveredModel{{ + ID: "default", + Name: "llama.cpp", + Provider: "llamacpp", + ContextSize: data.DefaultGenerationSettings.N_Ctx, + SupportsTools: true, // assume true for modern llama.cpp + }}, nil } -// DiscoverLocalModels discovers all available local models (ollama + llama.cpp). -// Non-blocking: failures are logged and skipped. -// ollamaToolCache is passed to DiscoverOllama; nil skips caching. +// DiscoverLocalModels polls all known local providers. func DiscoverLocalModels(ctx context.Context, logger *slog.Logger, ollamaURL, llamacppURL string, ollamaToolCache map[string]bool) []DiscoveredModel { var all []DiscoveredModel if models, err := DiscoverOllama(ctx, ollamaURL, ollamaToolCache); err != nil { - logger.Debug("ollama discovery failed (non-fatal)", "error", err) + logger.Debug("ollama discovery skipped", "error", err) } else { - logger.Debug("discovered ollama models", "count", len(models)) all = append(all, models...) } - if models, err := DiscoverLlamaCpp(ctx, llamacppURL); err != nil { - logger.Debug("llamacpp discovery failed (non-fatal)", "error", err) + if models, err := DiscoverLlamaCPP(ctx, llamacppURL); err != nil { + logger.Debug("llama.cpp discovery skipped", "error", err) } else { - logger.Debug("discovered llamacpp models", "count", len(models)) all = append(all, models...) } @@ -177,7 +196,7 @@ func DiscoverLocalModels(ctx context.Context, logger *slog.Logger, ollamaURL, ll // onReconcile is called when the forced arm identity changes (may be nil). func StartDiscoveryLoop(ctx context.Context, r *Router, logger *slog.Logger, ollamaURL, llamacppURL string, - providerFactory func(name, model string) provider.Provider, + providerFactory func(name, model string) SecureProvider, interval time.Duration, onReconcile func(ArmID), ) { @@ -200,7 +219,7 @@ func StartDiscoveryLoop(ctx context.Context, r *Router, logger *slog.Logger, // reconcileArms adds newly discovered models, removes disappeared ones, and // reconciles the forced arm when discovery reveals its real model name. // onReconcile is called (if non-nil) when the forced arm identity changes. -func reconcileArms(r *Router, discovered []DiscoveredModel, providerFactory func(name, model string) provider.Provider, logger *slog.Logger, onReconcile func(ArmID)) { +func reconcileArms(r *Router, discovered []DiscoveredModel, providerFactory func(name, model string) SecureProvider, logger *slog.Logger, onReconcile func(ArmID)) { discoveredSet := make(map[ArmID]bool, len(discovered)) for _, m := range discovered { discoveredSet[NewArmID(m.Provider, m.ID)] = true @@ -253,7 +272,7 @@ func reconcileArms(r *Router, discovered []DiscoveredModel, providerFactory func } // RegisterDiscoveredModels registers discovered local models as arms in the router. -func RegisterDiscoveredModels(r *Router, models []DiscoveredModel, providerFactory func(name, model string) provider.Provider) { +func RegisterDiscoveredModels(r *Router, models []DiscoveredModel, providerFactory func(name, model string) SecureProvider) { for _, m := range models { armID := NewArmID(m.Provider, m.ID) diff --git a/internal/router/discovery_test.go b/internal/router/discovery_test.go index 418fb57..15ad899 100644 --- a/internal/router/discovery_test.go +++ b/internal/router/discovery_test.go @@ -45,7 +45,7 @@ func TestArmID_Model(t *testing.T) { // --- reconcileArms --- -func noopFactory(name, model string) provider.Provider { return nil } +func noopFactory(name, model string) SecureProvider { return nil } func dummyArm(id ArmID, local bool) *Arm { return &Arm{ @@ -139,7 +139,7 @@ func TestReconcileArms_NoForcedArm(t *testing.T) { {ID: "gemma-26b", Provider: "llamacpp", SupportsTools: true}, } - factory := func(name, model string) provider.Provider { + factory := func(name, model string) SecureProvider { return &stubProvider{name: name, model: model} } @@ -212,3 +212,4 @@ func (s *stubProvider) Models(_ context.Context) ([]provider.ModelInfo, error) { func (s *stubProvider) Stream(_ context.Context, _ provider.Request) (stream.Stream, error) { return nil, nil } +func (s *stubProvider) IsSecure() bool { return true } diff --git a/internal/router/router.go b/internal/router/router.go index 4ddb400..921286f 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -283,17 +283,17 @@ func (r *Router) Arms() []*Arm { } // RegisterProvider registers all models from a provider as arms. -func (r *Router) RegisterProvider(ctx context.Context, prov provider.Provider, isLocal bool, costs map[string][2]float64) { +func (r *Router) RegisterProvider(ctx context.Context, prov SecureProvider, isLocal bool, costs map[string][2]float64) { models, err := prov.Models(ctx) if err != nil { r.logger.Debug("failed to list models", "provider", prov.Name(), "error", err) // Register at least the default model id := NewArmID(prov.Name(), prov.DefaultModel()) r.RegisterArm(&Arm{ - ID: id, - Provider: prov, - ModelName: prov.DefaultModel(), - IsLocal: isLocal, + ID: id, + Provider: prov, + ModelName: prov.DefaultModel(), + IsLocal: isLocal, Capabilities: provider.Capabilities{ToolUse: true}, // optimistic }) return diff --git a/internal/security/firewall.go b/internal/security/firewall.go index 49643d0..a15c0df 100644 --- a/internal/security/firewall.go +++ b/internal/security/firewall.go @@ -21,10 +21,11 @@ type Firewall struct { } type FirewallConfig struct { - ScanOutgoing bool - ScanToolResults bool - EntropyThreshold float64 - Logger *slog.Logger + ScanOutgoing bool + ScanToolResults bool + RedactHighEntropy bool + EntropyThreshold float64 + Logger *slog.Logger } func NewFirewall(cfg FirewallConfig) *Firewall { @@ -33,7 +34,7 @@ func NewFirewall(cfg FirewallConfig) *Firewall { logger = slog.Default() } return &Firewall{ - scanner: NewScanner(cfg.EntropyThreshold), + scanner: NewScanner(cfg.EntropyThreshold, cfg.RedactHighEntropy), incognito: NewIncognitoMode(), logger: logger, scanOutgoing: cfg.ScanOutgoing, diff --git a/internal/security/safeprovider.go b/internal/security/safeprovider.go index 206eb35..51d6f36 100644 --- a/internal/security/safeprovider.go +++ b/internal/security/safeprovider.go @@ -40,6 +40,11 @@ func (p *SafeProvider) Inner() provider.Provider { return p.inner } +// IsSecure returns true. Satisfies the router's SecureProvider interface. +func (p *SafeProvider) IsSecure() bool { + return true +} + func (p *SafeProvider) Stream(ctx context.Context, req provider.Request) (stream.Stream, error) { if p.fwRef != nil { if fw := p.fwRef.Get(); fw != nil { diff --git a/internal/security/scanner.go b/internal/security/scanner.go index f403832..8d01c97 100644 --- a/internal/security/scanner.go +++ b/internal/security/scanner.go @@ -32,17 +32,19 @@ type SecretMatch struct { // Scanner detects secrets and sensitive data in content. type Scanner struct { - patterns []SecretPattern - entropyThreshold float64 + patterns []SecretPattern + entropyThreshold float64 + redactHighEntropy bool } -func NewScanner(entropyThreshold float64) *Scanner { +func NewScanner(entropyThreshold float64, redactHighEntropy bool) *Scanner { if entropyThreshold <= 0 { entropyThreshold = 4.5 } return &Scanner{ - patterns: defaultPatterns(), - entropyThreshold: entropyThreshold, + patterns: defaultPatterns(), + entropyThreshold: entropyThreshold, + redactHighEntropy: redactHighEntropy, } } @@ -104,9 +106,13 @@ func (s *Scanner) scanEntropy(content string) []SecretMatch { } entropy := shannonEntropy(w.text) if entropy >= s.entropyThreshold { + action := ActionWarn + if s.redactHighEntropy { + action = ActionRedact + } matches = append(matches, SecretMatch{ Pattern: "high_entropy", - Action: ActionWarn, + Action: action, Start: w.start, End: w.start + len(w.text), }) @@ -224,7 +230,7 @@ func defaultPatterns() []SecretPattern { {"sentry_auth_token", `sntrys_[a-zA-Z0-9_]{50,}`}, // --- Infrastructure --- - {"private_key", `-----BEGIN (?:RSA |EC |DSA |OPENSSH )?PRIVATE KEY-----`}, + {"private_key", `(?s)-----BEGIN (?:RSA |EC |DSA |OPENSSH )?PRIVATE KEY-----.*?-----END (?:RSA |EC |DSA |OPENSSH )?PRIVATE KEY-----`}, {"database_url", `(?i)(?:postgres|mysql|mongodb|redis)://[^:]+:[^@]+@`}, {"heroku_api_key", `(?i)HEROKU_API_KEY\s*=\s*[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}`}, {"mailgun_api_key", `key-[a-f0-9]{32}`}, diff --git a/internal/security/security_test.go b/internal/security/security_test.go index e7ccbea..40c3748 100644 --- a/internal/security/security_test.go +++ b/internal/security/security_test.go @@ -10,7 +10,7 @@ import ( // --- Scanner --- func TestScanner_DetectsAnthropicKey(t *testing.T) { - s := NewScanner(4.5) + s := NewScanner(4.5, false) matches := s.Scan("my key is sk-ant-api03-abcdefghijklmnopqrstuvwxyz") if len(matches) == 0 { t.Error("should detect Anthropic API key") @@ -21,7 +21,7 @@ func TestScanner_DetectsAnthropicKey(t *testing.T) { } func TestScanner_DetectsOpenAIKey(t *testing.T) { - s := NewScanner(4.5) + s := NewScanner(4.5, false) matches := s.Scan("key: sk-proj-abcdefghijklmnopqrstuvwxyz123456") if len(matches) == 0 { t.Error("should detect OpenAI API key") @@ -29,7 +29,7 @@ func TestScanner_DetectsOpenAIKey(t *testing.T) { } func TestScanner_DetectsAWSKey(t *testing.T) { - s := NewScanner(4.5) + s := NewScanner(4.5, false) matches := s.Scan("AKIAIOSFODNN7EXAMPLE") if len(matches) == 0 { t.Error("should detect AWS access key") @@ -40,7 +40,7 @@ func TestScanner_DetectsAWSKey(t *testing.T) { } func TestScanner_DetectsGitHubPAT(t *testing.T) { - s := NewScanner(4.5) + s := NewScanner(4.5, false) matches := s.Scan("token: ghp_ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghij") hasGH := false for _, m := range matches { @@ -55,8 +55,8 @@ func TestScanner_DetectsGitHubPAT(t *testing.T) { } func TestScanner_DetectsPrivateKey(t *testing.T) { - s := NewScanner(4.5) - matches := s.Scan("-----BEGIN RSA PRIVATE KEY-----\nMIIE...") + s := NewScanner(4.5, false) + matches := s.Scan("-----BEGIN RSA PRIVATE KEY-----\nMIIE...\n-----END RSA PRIVATE KEY-----") hasKey := false for _, m := range matches { if m.Pattern == "private_key" { @@ -70,7 +70,7 @@ func TestScanner_DetectsPrivateKey(t *testing.T) { } func TestScanner_DetectsGenericSecret(t *testing.T) { - s := NewScanner(4.5) + s := NewScanner(4.5, false) matches := s.Scan(`password = "supersecretpassword123"`) hasGeneric := false for _, m := range matches { @@ -85,7 +85,7 @@ func TestScanner_DetectsGenericSecret(t *testing.T) { } func TestScanner_DetectsDatabaseURL(t *testing.T) { - s := NewScanner(4.5) + s := NewScanner(4.5, false) matches := s.Scan("postgres://admin:secretpass@db.example.com:5432/mydb") hasDB := false for _, m := range matches { @@ -100,7 +100,7 @@ func TestScanner_DetectsDatabaseURL(t *testing.T) { } func TestScanner_DetectsMistralKey(t *testing.T) { - s := NewScanner(6.0) + s := NewScanner(6.0, false) // Should detect Mistral key in assignment contexts. positives := []string{ @@ -139,7 +139,7 @@ func TestScanner_DetectsMistralKey(t *testing.T) { } func TestScanner_NoFalsePositives(t *testing.T) { - s := NewScanner(6.0) // high entropy threshold to avoid false positives + s := NewScanner(6.0, false) // high entropy threshold to avoid false positives safe := []string{ "hello world", "func main() {}", @@ -156,7 +156,7 @@ func TestScanner_NoFalsePositives(t *testing.T) { } func TestScanner_Entropy(t *testing.T) { - s := NewScanner(4.0) // lower threshold for testing + s := NewScanner(4.0, false) // lower threshold for testing // High entropy string (random-looking) matches := s.Scan("token: aB3dE5fG7hI9jK1lM3nO5pQ7rS9tU1v") @@ -195,7 +195,7 @@ func TestShannonEntropy(t *testing.T) { func TestRedact_SingleMatch(t *testing.T) { content := `AKIAIOSFODNN7EXAMPLE is my key` - s := NewScanner(6.0) + s := NewScanner(6.0, false) matches := s.Scan(content) result := Redact(content, matches) @@ -209,7 +209,7 @@ func TestRedact_SingleMatch(t *testing.T) { func TestRedact_MultipleMatches(t *testing.T) { content := "aws: AKIAIOSFODNN7EXAMPLE github: ghp_ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghij" - s := NewScanner(6.0) + s := NewScanner(6.0, false) matches := s.Scan(content) result := Redact(content, matches) @@ -442,7 +442,7 @@ func TestFirewall_ActionBlockReturnsBlockedString(t *testing.T) { func TestScanner_DedupKeyNoCollision(t *testing.T) { // Two matches at byte offsets > 127 in the same pattern should both appear, // not get deduplicated because of hash collision in the key. - s := NewScanner(3.0) + s := NewScanner(3.0, false) // Build a string where two matches appear after offset 127 prefix := strings.Repeat("x", 128) // push matches past offset 127 input := prefix + "sk-ant-api03-aaaaaaaabbbbbbbbcccccccc " + prefix + "sk-ant-api03-ddddddddeeeeeeeeffffffff" diff --git a/internal/session/session_test.go b/internal/session/session_test.go index 1ba25de..3069764 100644 --- a/internal/session/session_test.go +++ b/internal/session/session_test.go @@ -26,6 +26,7 @@ type mockProvider struct { func (m *mockProvider) Name() string { return m.name } func (m *mockProvider) DefaultModel() string { return "mock-model" } +func (m *mockProvider) IsSecure() bool { return true } func (m *mockProvider) Models(_ context.Context) ([]provider.ModelInfo, error) { return nil, nil } diff --git a/internal/slm/classifier_test.go b/internal/slm/classifier_test.go index d434bb5..50c6833 100644 --- a/internal/slm/classifier_test.go +++ b/internal/slm/classifier_test.go @@ -24,6 +24,7 @@ func (m *mockProvider) DefaultModel() string { return "default" } func (m *mockProvider) Models(_ context.Context) ([]provider.ModelInfo, error) { return nil, nil } +func (m *mockProvider) IsSecure() bool { return true } func (m *mockProvider) Stream(ctx context.Context, _ provider.Request) (stream.Stream, error) { if m.delay > 0 { select { diff --git a/internal/slm/manager.go b/internal/slm/manager.go index 7360fc7..0747c39 100644 --- a/internal/slm/manager.go +++ b/internal/slm/manager.go @@ -20,6 +20,7 @@ const pidFile = "llamafile.pid" // DefaultModelURL is the default llamafile to download when none is configured. // Qwen2.5 0.5B Instruct Q6_K (~450 MB) — small, fast, and supports tools. const DefaultModelURL = "https://huggingface.co/Mozilla/Qwen2.5-0.5B-Instruct-llamafile/resolve/main/Qwen2.5-0.5B-Instruct-Q6_K.llamafile" +const DefaultModelSHA256 = "c4e991af9ea7077339b8768e349da486a76392e72b3ef47ad372e6582779a8dd" // DefaultDataDir returns the platform default SLM data directory. // Follows XDG Base Directory Specification: $XDG_DATA_HOME/gnoma/slm, @@ -57,8 +58,9 @@ func (s Status) String() string { // Config holds Manager configuration. type Config struct { - DataDir string // XDG data home / gnoma / slm; must be set - ModelURL string // required for Setup + DataDir string // XDG data home / gnoma / slm; must be set + ModelURL string // required for Setup + ExpectedSHA256 string // if non-empty, Setup verifies against this } // Manager controls the llamafile lifecycle. @@ -131,6 +133,11 @@ func (m *Manager) Setup(ctx context.Context, progress func(downloaded, total int return err } + if m.cfg.ExpectedSHA256 != "" && sha256hex != m.cfg.ExpectedSHA256 { + _ = os.Remove(dst) // cleanup corrupt/malicious download + return fmt.Errorf("slm: hash mismatch for %s: got %s, want %s", m.cfg.ModelURL, sha256hex, m.cfg.ExpectedSHA256) + } + mf := &Manifest{ ModelURL: m.cfg.ModelURL, FilePath: dst, diff --git a/internal/tool/fs/grep.go b/internal/tool/fs/grep.go index 0aad2cf..9c962c0 100644 --- a/internal/tool/fs/grep.go +++ b/internal/tool/fs/grep.go @@ -142,7 +142,15 @@ func (t *GrepTool) Execute(_ context.Context, args json.RawMessage) (tool.Result } rel, _ := filepath.Rel(root, path) - fileMatches := grepFile(path, rel, re, maxResults-len(matches)) + resolvedPath := path + if t.guard != nil { + resolved, err := t.guard.ResolveRead(path) + if err != nil { + return nil // Skip files outside workspace or unreadable + } + resolvedPath = resolved + } + fileMatches := grepFile(resolvedPath, rel, re, maxResults-len(matches)) matches = append(matches, fileMatches...) if len(matches) >= maxResults {