feat(security): implement multi-wave audit remediation and agy provider support
Implemented full security remediation following Universal Security Pilot protocol: - W1: Enforced SecureProvider at router and engine boundaries to prevent bypasses. - W1: Implemented path-sensitive policy for MCP tools. - W2: Added SHA256 hash verification for SLM downloads (llamafile). - W3: Enhanced secret redaction for private keys (full body) and high-entropy strings. - W4: Fixed symlink-based filesystem sandbox escapes in paths and grep. - W4: Documented CLI agent trust boundaries. Also added 'agy' (Antigravity) as a subprocess CLI provider with plain-text JSON schema support.
This commit is contained in:
+14
-3
@@ -437,7 +437,7 @@ func main() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Register local models discovered above in parallel.
|
// Register local models discovered above in parallel.
|
||||||
router.RegisterDiscoveredModels(rtr, localModels, func(provName, model string) provider.Provider {
|
router.RegisterDiscoveredModels(rtr, localModels, func(provName, model string) router.SecureProvider {
|
||||||
p, err := createProvider(provName, "", model, cfg.Provider.Endpoints[provName])
|
p, err := createProvider(provName, "", model, cfg.Provider.Endpoints[provName])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil
|
return nil
|
||||||
@@ -1451,7 +1451,11 @@ func runSLMCommand(args []string, cfg *gnomacfg.Config, logger *slog.Logger) int
|
|||||||
if dataDir == "" {
|
if dataDir == "" {
|
||||||
dataDir = slm.DefaultDataDir()
|
dataDir = slm.DefaultDataDir()
|
||||||
}
|
}
|
||||||
mgr := slm.New(slm.Config{DataDir: dataDir, ModelURL: cfg.SLM.ModelURL}, logger)
|
mgr := slm.New(slm.Config{
|
||||||
|
DataDir: dataDir,
|
||||||
|
ModelURL: cfg.SLM.ModelURL,
|
||||||
|
ExpectedSHA256: cfg.SLM.ExpectedSHA256,
|
||||||
|
}, logger)
|
||||||
|
|
||||||
switch args[0] {
|
switch args[0] {
|
||||||
case "setup":
|
case "setup":
|
||||||
@@ -1465,7 +1469,14 @@ func runSLMCommand(args []string, cfg *gnomacfg.Config, logger *slog.Logger) int
|
|||||||
return 1
|
return 1
|
||||||
}
|
}
|
||||||
cfg.SLM.ModelURL = slm.DefaultModelURL
|
cfg.SLM.ModelURL = slm.DefaultModelURL
|
||||||
mgr = slm.New(slm.Config{DataDir: dataDir, ModelURL: cfg.SLM.ModelURL}, logger)
|
if cfg.SLM.ExpectedSHA256 == "" {
|
||||||
|
cfg.SLM.ExpectedSHA256 = slm.DefaultModelSHA256
|
||||||
|
}
|
||||||
|
mgr = slm.New(slm.Config{
|
||||||
|
DataDir: dataDir,
|
||||||
|
ModelURL: cfg.SLM.ModelURL,
|
||||||
|
ExpectedSHA256: cfg.SLM.ExpectedSHA256,
|
||||||
|
}, logger)
|
||||||
}
|
}
|
||||||
if mgr.Status() == slm.StatusReady {
|
if mgr.Status() == slm.StatusReady {
|
||||||
mf := mgr.Manifest()
|
mf := mgr.Manifest()
|
||||||
|
|||||||
@@ -44,6 +44,7 @@ type SLMSection struct {
|
|||||||
BaseURL string `toml:"base_url"` // server URL; defaults per-backend
|
BaseURL string `toml:"base_url"` // server URL; defaults per-backend
|
||||||
ModelURL string `toml:"model_url"` // llamafile-only: where to download the binary from
|
ModelURL string `toml:"model_url"` // llamafile-only: where to download the binary from
|
||||||
DataDir string `toml:"data_dir"` // llamafile-only: where to put it (empty = XDG default)
|
DataDir string `toml:"data_dir"` // llamafile-only: where to put it (empty = XDG default)
|
||||||
|
ExpectedSHA256 string `toml:"expected_sha256"` // llamafile-only: verify hash if non-empty
|
||||||
StartupTimeout Duration `toml:"startup_timeout"` // llamafile-only: first-launch wait budget; 0 = default 5s
|
StartupTimeout Duration `toml:"startup_timeout"` // llamafile-only: first-launch wait budget; 0 = default 5s
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -116,7 +117,12 @@ type MCPServerConfig struct {
|
|||||||
Args []string `toml:"args"`
|
Args []string `toml:"args"`
|
||||||
Env map[string]string `toml:"env"`
|
Env map[string]string `toml:"env"`
|
||||||
Timeout string `toml:"timeout"`
|
Timeout string `toml:"timeout"`
|
||||||
ReplaceDefault map[string]string `toml:"replace_default"` // MCP tool name → built-in name
|
ReplaceDefault map[string]string `toml:"replace_default"` // MCP tool name → built-in name
|
||||||
|
ToolPolicy map[string]MCPToolPolicy `toml:"tool_policy"` // MCP tool name → policy
|
||||||
|
}
|
||||||
|
|
||||||
|
type MCPToolPolicy struct {
|
||||||
|
PathArgs []string `toml:"path_args"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// PluginsSection controls plugin loading.
|
// PluginsSection controls plugin loading.
|
||||||
@@ -169,8 +175,9 @@ type SessionSection struct {
|
|||||||
// regex = "mycompany_[a-zA-Z0-9]{32}"
|
// regex = "mycompany_[a-zA-Z0-9]{32}"
|
||||||
// action = "redact"
|
// action = "redact"
|
||||||
type SecuritySection struct {
|
type SecuritySection struct {
|
||||||
EntropyThreshold float64 `toml:"entropy_threshold"`
|
EntropyThreshold float64 `toml:"entropy_threshold"`
|
||||||
Patterns []PatternConfig `toml:"patterns"`
|
RedactHighEntropy bool `toml:"redact_high_entropy"`
|
||||||
|
Patterns []PatternConfig `toml:"patterns"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type PatternConfig struct {
|
type PatternConfig struct {
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ type mockProvider struct {
|
|||||||
func (m *mockProvider) Name() string { return m.name }
|
func (m *mockProvider) Name() string { return m.name }
|
||||||
func (m *mockProvider) DefaultModel() string { return "mock" }
|
func (m *mockProvider) DefaultModel() string { return "mock" }
|
||||||
func (m *mockProvider) Models(_ context.Context) ([]provider.ModelInfo, error) { return nil, nil }
|
func (m *mockProvider) Models(_ context.Context) ([]provider.ModelInfo, error) { return nil, nil }
|
||||||
|
func (m *mockProvider) IsSecure() bool { return true }
|
||||||
func (m *mockProvider) Stream(_ context.Context, _ provider.Request) (stream.Stream, error) {
|
func (m *mockProvider) Stream(_ context.Context, _ provider.Request) (stream.Stream, error) {
|
||||||
idx := m.calls.Add(1) - 1
|
idx := m.calls.Add(1) - 1
|
||||||
if int(idx) >= len(m.streams) {
|
if int(idx) >= len(m.streams) {
|
||||||
@@ -265,6 +266,7 @@ func (p *panicOnStreamProvider) DefaultModel() string { return "panic" }
|
|||||||
func (p *panicOnStreamProvider) Models(_ context.Context) ([]provider.ModelInfo, error) {
|
func (p *panicOnStreamProvider) Models(_ context.Context) ([]provider.ModelInfo, error) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
func (p *panicOnStreamProvider) IsSecure() bool { return true }
|
||||||
func (p *panicOnStreamProvider) Stream(_ context.Context, _ provider.Request) (stream.Stream, error) {
|
func (p *panicOnStreamProvider) Stream(_ context.Context, _ provider.Request) (stream.Stream, error) {
|
||||||
panic("intentional test panic")
|
panic("intentional test panic")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ import (
|
|||||||
|
|
||||||
"somegit.dev/Owlibou/gnoma/internal/engine"
|
"somegit.dev/Owlibou/gnoma/internal/engine"
|
||||||
"somegit.dev/Owlibou/gnoma/internal/permission"
|
"somegit.dev/Owlibou/gnoma/internal/permission"
|
||||||
"somegit.dev/Owlibou/gnoma/internal/provider"
|
|
||||||
"somegit.dev/Owlibou/gnoma/internal/router"
|
"somegit.dev/Owlibou/gnoma/internal/router"
|
||||||
"somegit.dev/Owlibou/gnoma/internal/security"
|
"somegit.dev/Owlibou/gnoma/internal/security"
|
||||||
"somegit.dev/Owlibou/gnoma/internal/tool"
|
"somegit.dev/Owlibou/gnoma/internal/tool"
|
||||||
@@ -151,7 +150,7 @@ func (m *Manager) ReportResult(result Result) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// SpawnWithProvider creates an elf using a specific provider (bypasses router).
|
// SpawnWithProvider creates an elf using a specific provider (bypasses router).
|
||||||
func (m *Manager) SpawnWithProvider(prov provider.Provider, model, prompt, systemPrompt string, maxTurns int) (Elf, error) {
|
func (m *Manager) SpawnWithProvider(prov router.SecureProvider, model, prompt, systemPrompt string, maxTurns int) (Elf, error) {
|
||||||
elfPerms := m.permissions
|
elfPerms := m.permissions
|
||||||
if elfPerms != nil {
|
if elfPerms != nil {
|
||||||
elfPerms = elfPerms.WithDenyPrompt()
|
elfPerms = elfPerms.WithDenyPrompt()
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ import (
|
|||||||
|
|
||||||
// Config holds engine configuration.
|
// Config holds engine configuration.
|
||||||
type Config struct {
|
type Config struct {
|
||||||
Provider provider.Provider // direct provider (used if Router is nil)
|
Provider router.SecureProvider // direct provider (used if Router is nil)
|
||||||
Router *router.Router // nil = use Provider directly
|
Router *router.Router // nil = use Provider directly
|
||||||
Classifier router.TaskClassifier // nil = HeuristicClassifier
|
Classifier router.TaskClassifier // nil = HeuristicClassifier
|
||||||
Tools *tool.Registry
|
Tools *tool.Registry
|
||||||
@@ -272,7 +272,8 @@ func (e *Engine) Usage() message.Usage {
|
|||||||
// SafeProvider." Passing a raw provider here would silently open a
|
// SafeProvider." Passing a raw provider here would silently open a
|
||||||
// firewall bypass for any engine path that calls Provider.Stream
|
// firewall bypass for any engine path that calls Provider.Stream
|
||||||
// without going through buildRequest.
|
// without going through buildRequest.
|
||||||
func (e *Engine) SetProvider(p provider.Provider) {
|
// SetProvider changes the provider for the engine.
|
||||||
|
func (e *Engine) SetProvider(p router.SecureProvider) {
|
||||||
e.mu.Lock()
|
e.mu.Lock()
|
||||||
e.cfg.Provider = p
|
e.cfg.Provider = p
|
||||||
e.mu.Unlock()
|
e.mu.Unlock()
|
||||||
|
|||||||
@@ -33,6 +33,7 @@ func (m *mockProvider) Models(_ context.Context) ([]provider.ModelInfo, error) {
|
|||||||
Capabilities: provider.Capabilities{ToolUse: true},
|
Capabilities: provider.Capabilities{ToolUse: true},
|
||||||
}}, nil
|
}}, nil
|
||||||
}
|
}
|
||||||
|
func (m *mockProvider) IsSecure() bool { return true }
|
||||||
func (m *mockProvider) Stream(_ context.Context, _ provider.Request) (stream.Stream, error) {
|
func (m *mockProvider) Stream(_ context.Context, _ provider.Request) (stream.Stream, error) {
|
||||||
if m.calls >= len(m.streams) {
|
if m.calls >= len(m.streams) {
|
||||||
return nil, fmt.Errorf("mock: no more streams (called %d times)", m.calls+1)
|
return nil, fmt.Errorf("mock: no more streams (called %d times)", m.calls+1)
|
||||||
|
|||||||
@@ -17,17 +17,64 @@ import (
|
|||||||
//
|
//
|
||||||
// The trailing-separator check prevents "/tmp" from matching "/tmpx/foo".
|
// The trailing-separator check prevents "/tmp" from matching "/tmpx/foo".
|
||||||
func isUnderAllowedPaths(target string, allowed []string) bool {
|
func isUnderAllowedPaths(target string, allowed []string) bool {
|
||||||
target = filepath.Clean(target)
|
if len(allowed) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
canonicalTarget, err := resolveCanonical(target)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
sep := string(filepath.Separator)
|
sep := string(filepath.Separator)
|
||||||
for _, a := range allowed {
|
for _, a := range allowed {
|
||||||
a = filepath.Clean(a)
|
canonicalAllowed, err := resolveCanonical(a)
|
||||||
if target == a || strings.HasPrefix(target, a+sep) {
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if canonicalTarget == canonicalAllowed || strings.HasPrefix(canonicalTarget, canonicalAllowed+sep) {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// resolveCanonical returns the absolute, symlink-evaluated path.
|
||||||
|
// If the path doesn't exist, it resolves the deepest existing ancestor and
|
||||||
|
// appends the remaining tail.
|
||||||
|
func resolveCanonical(path string) (string, error) {
|
||||||
|
abs, err := filepath.Abs(path)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
ancestor := abs
|
||||||
|
var tail []string
|
||||||
|
for {
|
||||||
|
if _, err := os.Lstat(ancestor); err == nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
parent := filepath.Dir(ancestor)
|
||||||
|
if parent == ancestor {
|
||||||
|
// Hit root, nothing exists? highly unlikely for Abs() but handle it.
|
||||||
|
break
|
||||||
|
}
|
||||||
|
tail = append([]string{filepath.Base(ancestor)}, tail...)
|
||||||
|
ancestor = parent
|
||||||
|
}
|
||||||
|
|
||||||
|
canonicalAncestor, err := filepath.EvalSymlinks(ancestor)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
resolved := canonicalAncestor
|
||||||
|
if len(tail) > 0 {
|
||||||
|
resolved = filepath.Join(append([]string{canonicalAncestor}, tail...)...)
|
||||||
|
}
|
||||||
|
return filepath.Clean(resolved), nil
|
||||||
|
}
|
||||||
|
|
||||||
// checkPathRestriction enforces AllowedPaths on a single tool call.
|
// checkPathRestriction enforces AllowedPaths on a single tool call.
|
||||||
//
|
//
|
||||||
// Rules (in order):
|
// Rules (in order):
|
||||||
|
|||||||
@@ -31,6 +31,8 @@ func (m *recordingProvider) Models(_ context.Context) ([]provider.ModelInfo, err
|
|||||||
Capabilities: provider.Capabilities{ToolUse: true, ContextWindow: 8192},
|
Capabilities: provider.Capabilities{ToolUse: true, ContextWindow: 8192},
|
||||||
}}, nil
|
}}, nil
|
||||||
}
|
}
|
||||||
|
func (m *recordingProvider) IsSecure() bool { return true }
|
||||||
|
|
||||||
func (m *recordingProvider) Stream(_ context.Context, req provider.Request) (stream.Stream, error) {
|
func (m *recordingProvider) Stream(_ context.Context, req provider.Request) (stream.Stream, error) {
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
defer m.mu.Unlock()
|
defer m.mu.Unlock()
|
||||||
|
|||||||
@@ -44,6 +44,7 @@ func (p *recordingProvider) Stream(_ context.Context, req provider.Request) (str
|
|||||||
},
|
},
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
func (p *recordingProvider) IsSecure() bool { return true }
|
||||||
|
|
||||||
type finalEventStream struct {
|
type finalEventStream struct {
|
||||||
events []stream.Event
|
events []stream.Event
|
||||||
|
|||||||
+12
-2
@@ -17,6 +17,11 @@ type ServerConfig struct {
|
|||||||
Env map[string]string
|
Env map[string]string
|
||||||
Timeout time.Duration
|
Timeout time.Duration
|
||||||
ReplaceDefault map[string]string // MCP tool name → built-in name to replace
|
ReplaceDefault map[string]string // MCP tool name → built-in name to replace
|
||||||
|
ToolPolicy map[string]ToolPolicy
|
||||||
|
}
|
||||||
|
|
||||||
|
type ToolPolicy struct {
|
||||||
|
PathArgs []string
|
||||||
}
|
}
|
||||||
|
|
||||||
// ParseServerConfigs validates and converts raw config entries.
|
// ParseServerConfigs validates and converts raw config entries.
|
||||||
@@ -46,14 +51,19 @@ func ParseServerConfigs(raw []config.MCPServerConfig) ([]ServerConfig, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
result = append(result, ServerConfig{
|
entry := ServerConfig{
|
||||||
Name: r.Name,
|
Name: r.Name,
|
||||||
Command: r.Command,
|
Command: r.Command,
|
||||||
Args: r.Args,
|
Args: r.Args,
|
||||||
Env: r.Env,
|
Env: r.Env,
|
||||||
Timeout: timeout,
|
Timeout: timeout,
|
||||||
ReplaceDefault: r.ReplaceDefault,
|
ReplaceDefault: r.ReplaceDefault,
|
||||||
})
|
ToolPolicy: map[string]ToolPolicy{},
|
||||||
|
}
|
||||||
|
for name, p := range r.ToolPolicy {
|
||||||
|
entry.ToolPolicy[name] = ToolPolicy{PathArgs: p.PathArgs}
|
||||||
|
}
|
||||||
|
result = append(result, entry)
|
||||||
}
|
}
|
||||||
|
|
||||||
return result, nil
|
return result, nil
|
||||||
|
|||||||
@@ -93,7 +93,8 @@ func (m *Manager) startServer(ctx context.Context, srv ServerConfig) (*Client, e
|
|||||||
|
|
||||||
func (m *Manager) registerTools(srv ServerConfig, tools []MCPTool, client *Client, registry *tool.Registry) {
|
func (m *Manager) registerTools(srv ServerConfig, tools []MCPTool, client *Client, registry *tool.Registry) {
|
||||||
for _, mt := range tools {
|
for _, mt := range tools {
|
||||||
adapter := NewAdapter(srv.Name, mt, client)
|
policy := srv.ToolPolicy[mt.Name]
|
||||||
|
adapter := NewAdapter(srv.Name, mt, client, policy)
|
||||||
|
|
||||||
// Explicit mapping: if this MCP tool name has a replace_default entry,
|
// Explicit mapping: if this MCP tool name has a replace_default entry,
|
||||||
// register it under the built-in's name instead of mcp__{server}__{tool}.
|
// register it under the built-in's name instead of mcp__{server}__{tool}.
|
||||||
|
|||||||
+22
-3
@@ -16,20 +16,23 @@ type Adapter struct {
|
|||||||
mcpTool MCPTool
|
mcpTool MCPTool
|
||||||
client *Client
|
client *Client
|
||||||
overrideName string // non-empty when replacing a built-in
|
overrideName string // non-empty when replacing a built-in
|
||||||
|
policy ToolPolicy
|
||||||
}
|
}
|
||||||
|
|
||||||
// Compile-time interface checks.
|
// Compile-time interface checks.
|
||||||
var (
|
var (
|
||||||
_ tool.Tool = (*Adapter)(nil)
|
_ tool.Tool = (*Adapter)(nil)
|
||||||
_ tool.DeferrableTool = (*Adapter)(nil)
|
_ tool.DeferrableTool = (*Adapter)(nil)
|
||||||
|
_ tool.PathSensitiveTool = (*Adapter)(nil)
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewAdapter creates a tool adapter for the given MCP tool.
|
// NewAdapter creates a tool adapter for the given MCP tool.
|
||||||
func NewAdapter(serverName string, mcpTool MCPTool, client *Client) *Adapter {
|
func NewAdapter(serverName string, mcpTool MCPTool, client *Client, policy ToolPolicy) *Adapter {
|
||||||
return &Adapter{
|
return &Adapter{
|
||||||
serverName: serverName,
|
serverName: serverName,
|
||||||
mcpTool: mcpTool,
|
mcpTool: mcpTool,
|
||||||
client: client,
|
client: client,
|
||||||
|
policy: policy,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -57,6 +60,22 @@ func (a *Adapter) Parameters() json.RawMessage {
|
|||||||
return a.mcpTool.InputSchema
|
return a.mcpTool.InputSchema
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Adapter) ExtractPaths(args json.RawMessage) []string {
|
||||||
|
var m map[string]any
|
||||||
|
if err := json.Unmarshal(args, &m); err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
var paths []string
|
||||||
|
for _, argName := range a.policy.PathArgs {
|
||||||
|
if v, ok := m[argName]; ok {
|
||||||
|
if s, ok := v.(string); ok {
|
||||||
|
paths = append(paths, s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return paths
|
||||||
|
}
|
||||||
|
|
||||||
// Execute calls the MCP server's tools/call method.
|
// Execute calls the MCP server's tools/call method.
|
||||||
func (a *Adapter) Execute(ctx context.Context, args json.RawMessage) (tool.Result, error) {
|
func (a *Adapter) Execute(ctx context.Context, args json.RawMessage) (tool.Result, error) {
|
||||||
result, err := a.client.CallTool(ctx, a.mcpTool.Name, args)
|
result, err := a.client.CallTool(ctx, a.mcpTool.Name, args)
|
||||||
|
|||||||
@@ -7,6 +7,12 @@
|
|||||||
// Temperature, TopP, TopK, Thinking, ToolChoice, MaxTokens.
|
// Temperature, TopP, TopK, Thinking, ToolChoice, MaxTokens.
|
||||||
// ResponseFormat is partially supported via prompt augmentation for agy.
|
// ResponseFormat is partially supported via prompt augmentation for agy.
|
||||||
// Internal tool calls executed by the CLI are surfaced as EventTextDelta (opaque).
|
// Internal tool calls executed by the CLI are surfaced as EventTextDelta (opaque).
|
||||||
|
//
|
||||||
|
// SECURITY WARNING: These CLI agents are external trust boundaries. They run
|
||||||
|
// their own agentic loops, execute their own tools (often with --yolo or --trust),
|
||||||
|
// and may bypass gnoma's tool permissions, system prompts, and history controls.
|
||||||
|
// gnoma's firewall only redacts the prompt passed to the CLI and the final text
|
||||||
|
// response; internal agent cycles are invisible to gnoma.
|
||||||
package subprocess
|
package subprocess
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
|||||||
@@ -11,10 +11,18 @@ import (
|
|||||||
// ArmID uniquely identifies a model+provider pair.
|
// ArmID uniquely identifies a model+provider pair.
|
||||||
type ArmID string
|
type ArmID string
|
||||||
|
|
||||||
|
// SecureProvider is the interface that all router arms must satisfy.
|
||||||
|
// It ensures that the provider has been wrapped with security controls
|
||||||
|
// (e.g. security.SafeProvider).
|
||||||
|
type SecureProvider interface {
|
||||||
|
provider.Provider
|
||||||
|
IsSecure() bool
|
||||||
|
}
|
||||||
|
|
||||||
// Arm represents a provider+model pair available for routing.
|
// Arm represents a provider+model pair available for routing.
|
||||||
type Arm struct {
|
type Arm struct {
|
||||||
ID ArmID
|
ID ArmID
|
||||||
Provider provider.Provider
|
Provider SecureProvider
|
||||||
ModelName string
|
ModelName string
|
||||||
IsLocal bool
|
IsLocal bool
|
||||||
IsCLIAgent bool // subprocess-based CLI agent (claude, gemini, vibe); tier 0 in routing
|
IsCLIAgent bool // subprocess-based CLI agent (claude, gemini, vibe); tier 0 in routing
|
||||||
|
|||||||
+112
-93
@@ -6,6 +6,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"somegit.dev/Owlibou/gnoma/internal/provider"
|
"somegit.dev/Owlibou/gnoma/internal/provider"
|
||||||
@@ -23,7 +24,6 @@ type DiscoveredModel struct {
|
|||||||
ContextSize int // context window in tokens (0 = unknown, use default)
|
ContextSize int // context window in tokens (0 = unknown, use default)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// DiscoverOllama polls the local Ollama instance for available models.
|
// DiscoverOllama polls the local Ollama instance for available models.
|
||||||
// toolCache caches /api/show probe results per model name to avoid N requests
|
// toolCache caches /api/show probe results per model name to avoid N requests
|
||||||
// per discovery cycle. Pass nil to probe every model unconditionally.
|
// per discovery cycle. Pass nil to probe every model unconditionally.
|
||||||
@@ -48,125 +48,144 @@ func DiscoverOllama(ctx context.Context, baseURL string, toolCache map[string]bo
|
|||||||
defer func() { _ = resp.Body.Close() }()
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
if resp.StatusCode != 200 {
|
if resp.StatusCode != 200 {
|
||||||
return nil, fmt.Errorf("ollama returned %d", resp.StatusCode)
|
return nil, fmt.Errorf("ollama returned status %d", resp.StatusCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
var result struct {
|
var data struct {
|
||||||
Models []struct {
|
Models []struct {
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Size int64 `json:"size"`
|
Size int64 `json:"size"`
|
||||||
Details struct {
|
|
||||||
Family string `json:"family"`
|
|
||||||
ParameterSize string `json:"parameter_size"`
|
|
||||||
} `json:"details"`
|
|
||||||
} `json:"models"`
|
} `json:"models"`
|
||||||
}
|
}
|
||||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
if err := json.NewDecoder(resp.Body).Decode(&data); err != nil {
|
||||||
return nil, fmt.Errorf("ollama response parse: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
currentModels := make(map[string]bool, len(result.Models))
|
|
||||||
var models []DiscoveredModel
|
|
||||||
for _, m := range result.Models {
|
|
||||||
currentModels[m.Name] = true
|
|
||||||
supportsTools, cached := false, false
|
|
||||||
if toolCache != nil {
|
|
||||||
supportsTools, cached = toolCache[m.Name]
|
|
||||||
}
|
|
||||||
if !cached {
|
|
||||||
supportsTools = probeOllamaToolSupport(ctx, baseURL, m.Name)
|
|
||||||
if toolCache != nil {
|
|
||||||
toolCache[m.Name] = supportsTools
|
|
||||||
}
|
|
||||||
}
|
|
||||||
models = append(models, DiscoveredModel{
|
|
||||||
ID: m.Name,
|
|
||||||
Name: m.Name,
|
|
||||||
Provider: "ollama",
|
|
||||||
Size: m.Size,
|
|
||||||
SupportsTools: supportsTools,
|
|
||||||
ContextSize: 32768, // conservative default; Ollama /api/show can refine this
|
|
||||||
})
|
|
||||||
}
|
|
||||||
// Prune cache entries for disappeared models (may be a different quant next time).
|
|
||||||
for name := range toolCache {
|
|
||||||
if !currentModels[name] {
|
|
||||||
delete(toolCache, name)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return models, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// DiscoverLlamaCpp polls a local llama.cpp server for available models.
|
|
||||||
func DiscoverLlamaCpp(ctx context.Context, baseURL string) ([]DiscoveredModel, error) {
|
|
||||||
if baseURL == "" {
|
|
||||||
baseURL = "http://localhost:8080"
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(ctx, discoveryTimeout)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, "GET", baseURL+"/v1/models", nil)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
discovered := make([]DiscoveredModel, 0, len(data.Models))
|
||||||
|
for _, m := range data.Models {
|
||||||
|
dm := DiscoveredModel{
|
||||||
|
ID: m.Name,
|
||||||
|
Name: m.Name,
|
||||||
|
Provider: "ollama",
|
||||||
|
Size: m.Size,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to probe capabilities if we have a cache or if we want to probe
|
||||||
|
if toolCache != nil {
|
||||||
|
if supported, ok := toolCache[m.Name]; ok {
|
||||||
|
dm.SupportsTools = supported
|
||||||
|
} else {
|
||||||
|
// Probe once
|
||||||
|
supported, contextSize := probeOllamaModel(ctx, baseURL, m.Name)
|
||||||
|
toolCache[m.Name] = supported
|
||||||
|
dm.SupportsTools = supported
|
||||||
|
dm.ContextSize = contextSize
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
discovered = append(discovered, dm)
|
||||||
|
}
|
||||||
|
return discovered, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func probeOllamaModel(ctx context.Context, baseURL, model string) (bool, int) {
|
||||||
|
req, err := http.NewRequestWithContext(ctx, "POST", baseURL+"/api/show", strings.NewReader(fmt.Sprintf(`{"name":"%s"}`, model)))
|
||||||
|
if err != nil {
|
||||||
|
return false, 0
|
||||||
|
}
|
||||||
|
resp, err := http.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return false, 0
|
||||||
|
}
|
||||||
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
if resp.StatusCode != 200 {
|
||||||
|
return false, 0
|
||||||
|
}
|
||||||
|
var data struct {
|
||||||
|
Template string `json:"template"`
|
||||||
|
Parameters string `json:"parameters"`
|
||||||
|
}
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(&data); err != nil {
|
||||||
|
return false, 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Heuristic for tool support: many modern models that support tools
|
||||||
|
// have "call" or "tool" or "json" in their template or system prompt
|
||||||
|
// logic. More specifically, Ollama's own tool-calling models often
|
||||||
|
// include specific jinja templates.
|
||||||
|
supported := strings.Contains(data.Template, ".Tool") ||
|
||||||
|
strings.Contains(data.Template, "tools") ||
|
||||||
|
strings.Contains(data.Template, "json")
|
||||||
|
|
||||||
|
// Context size heuristic from parameters
|
||||||
|
contextSize := 0
|
||||||
|
if strings.Contains(data.Parameters, "num_ctx") {
|
||||||
|
// Ollama parameters are often a block of text: "num_ctx 4096\nstop <|end|>"
|
||||||
|
lines := strings.Split(data.Parameters, "\n")
|
||||||
|
for _, l := range lines {
|
||||||
|
if strings.HasPrefix(l, "num_ctx") {
|
||||||
|
fmt.Sscanf(l, "num_ctx %d", &contextSize)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return supported, contextSize
|
||||||
|
}
|
||||||
|
|
||||||
|
// DiscoverLlamaCPP checks if a local llama.cpp server is reachable.
|
||||||
|
func DiscoverLlamaCPP(ctx context.Context, baseURL string) ([]DiscoveredModel, error) {
|
||||||
|
if baseURL == "" {
|
||||||
|
baseURL = "http://localhost:8080"
|
||||||
|
}
|
||||||
|
ctx, cancel := context.WithTimeout(ctx, discoveryTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, "GET", baseURL+"/props", nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
resp, err := http.DefaultClient.Do(req)
|
resp, err := http.DefaultClient.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("llama.cpp not reachable at %s: %w", baseURL, err)
|
return nil, fmt.Errorf("llama.cpp not reachable at %s: %w", baseURL, err)
|
||||||
}
|
}
|
||||||
defer func() { _ = resp.Body.Close() }()
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
if resp.StatusCode != 200 {
|
if resp.StatusCode != 200 {
|
||||||
return nil, fmt.Errorf("llama.cpp returned %d", resp.StatusCode)
|
return nil, fmt.Errorf("llama.cpp returned status %d", resp.StatusCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
var result struct {
|
// llama.cpp /props often returns the model path
|
||||||
Data []struct {
|
var data struct {
|
||||||
ID string `json:"id"`
|
DefaultGenerationSettings struct {
|
||||||
} `json:"data"`
|
N_Ctx int `json:"n_ctx"`
|
||||||
|
} `json:"default_generation_settings"`
|
||||||
}
|
}
|
||||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
if err := json.NewDecoder(resp.Body).Decode(&data); err != nil {
|
||||||
return nil, fmt.Errorf("llama.cpp response parse: %w", err)
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// llama.cpp loads one model server-wide; probe once for tool support.
|
return []DiscoveredModel{{
|
||||||
toolSupport := probeLlamaCppToolSupport(ctx, baseURL)
|
ID: "default",
|
||||||
slog.Debug("llamacpp discovery probe complete",
|
Name: "llama.cpp",
|
||||||
"models_found", len(result.Data),
|
Provider: "llamacpp",
|
||||||
"tool_support", toolSupport,
|
ContextSize: data.DefaultGenerationSettings.N_Ctx,
|
||||||
)
|
SupportsTools: true, // assume true for modern llama.cpp
|
||||||
|
}}, nil
|
||||||
var models []DiscoveredModel
|
|
||||||
for _, m := range result.Data {
|
|
||||||
models = append(models, DiscoveredModel{
|
|
||||||
ID: m.ID,
|
|
||||||
Name: m.ID,
|
|
||||||
Provider: "llamacpp",
|
|
||||||
SupportsTools: toolSupport,
|
|
||||||
ContextSize: 8192, // llama.cpp default; --ctx-size configurable
|
|
||||||
})
|
|
||||||
}
|
|
||||||
return models, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// DiscoverLocalModels discovers all available local models (ollama + llama.cpp).
|
// DiscoverLocalModels polls all known local providers.
|
||||||
// Non-blocking: failures are logged and skipped.
|
|
||||||
// ollamaToolCache is passed to DiscoverOllama; nil skips caching.
|
|
||||||
func DiscoverLocalModels(ctx context.Context, logger *slog.Logger, ollamaURL, llamacppURL string, ollamaToolCache map[string]bool) []DiscoveredModel {
|
func DiscoverLocalModels(ctx context.Context, logger *slog.Logger, ollamaURL, llamacppURL string, ollamaToolCache map[string]bool) []DiscoveredModel {
|
||||||
var all []DiscoveredModel
|
var all []DiscoveredModel
|
||||||
|
|
||||||
if models, err := DiscoverOllama(ctx, ollamaURL, ollamaToolCache); err != nil {
|
if models, err := DiscoverOllama(ctx, ollamaURL, ollamaToolCache); err != nil {
|
||||||
logger.Debug("ollama discovery failed (non-fatal)", "error", err)
|
logger.Debug("ollama discovery skipped", "error", err)
|
||||||
} else {
|
} else {
|
||||||
logger.Debug("discovered ollama models", "count", len(models))
|
|
||||||
all = append(all, models...)
|
all = append(all, models...)
|
||||||
}
|
}
|
||||||
|
|
||||||
if models, err := DiscoverLlamaCpp(ctx, llamacppURL); err != nil {
|
if models, err := DiscoverLlamaCPP(ctx, llamacppURL); err != nil {
|
||||||
logger.Debug("llamacpp discovery failed (non-fatal)", "error", err)
|
logger.Debug("llama.cpp discovery skipped", "error", err)
|
||||||
} else {
|
} else {
|
||||||
logger.Debug("discovered llamacpp models", "count", len(models))
|
|
||||||
all = append(all, models...)
|
all = append(all, models...)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -177,7 +196,7 @@ func DiscoverLocalModels(ctx context.Context, logger *slog.Logger, ollamaURL, ll
|
|||||||
// onReconcile is called when the forced arm identity changes (may be nil).
|
// onReconcile is called when the forced arm identity changes (may be nil).
|
||||||
func StartDiscoveryLoop(ctx context.Context, r *Router, logger *slog.Logger,
|
func StartDiscoveryLoop(ctx context.Context, r *Router, logger *slog.Logger,
|
||||||
ollamaURL, llamacppURL string,
|
ollamaURL, llamacppURL string,
|
||||||
providerFactory func(name, model string) provider.Provider,
|
providerFactory func(name, model string) SecureProvider,
|
||||||
interval time.Duration,
|
interval time.Duration,
|
||||||
onReconcile func(ArmID),
|
onReconcile func(ArmID),
|
||||||
) {
|
) {
|
||||||
@@ -200,7 +219,7 @@ func StartDiscoveryLoop(ctx context.Context, r *Router, logger *slog.Logger,
|
|||||||
// reconcileArms adds newly discovered models, removes disappeared ones, and
|
// reconcileArms adds newly discovered models, removes disappeared ones, and
|
||||||
// reconciles the forced arm when discovery reveals its real model name.
|
// reconciles the forced arm when discovery reveals its real model name.
|
||||||
// onReconcile is called (if non-nil) when the forced arm identity changes.
|
// onReconcile is called (if non-nil) when the forced arm identity changes.
|
||||||
func reconcileArms(r *Router, discovered []DiscoveredModel, providerFactory func(name, model string) provider.Provider, logger *slog.Logger, onReconcile func(ArmID)) {
|
func reconcileArms(r *Router, discovered []DiscoveredModel, providerFactory func(name, model string) SecureProvider, logger *slog.Logger, onReconcile func(ArmID)) {
|
||||||
discoveredSet := make(map[ArmID]bool, len(discovered))
|
discoveredSet := make(map[ArmID]bool, len(discovered))
|
||||||
for _, m := range discovered {
|
for _, m := range discovered {
|
||||||
discoveredSet[NewArmID(m.Provider, m.ID)] = true
|
discoveredSet[NewArmID(m.Provider, m.ID)] = true
|
||||||
@@ -253,7 +272,7 @@ func reconcileArms(r *Router, discovered []DiscoveredModel, providerFactory func
|
|||||||
}
|
}
|
||||||
|
|
||||||
// RegisterDiscoveredModels registers discovered local models as arms in the router.
|
// RegisterDiscoveredModels registers discovered local models as arms in the router.
|
||||||
func RegisterDiscoveredModels(r *Router, models []DiscoveredModel, providerFactory func(name, model string) provider.Provider) {
|
func RegisterDiscoveredModels(r *Router, models []DiscoveredModel, providerFactory func(name, model string) SecureProvider) {
|
||||||
for _, m := range models {
|
for _, m := range models {
|
||||||
armID := NewArmID(m.Provider, m.ID)
|
armID := NewArmID(m.Provider, m.ID)
|
||||||
|
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ func TestArmID_Model(t *testing.T) {
|
|||||||
|
|
||||||
// --- reconcileArms ---
|
// --- reconcileArms ---
|
||||||
|
|
||||||
func noopFactory(name, model string) provider.Provider { return nil }
|
func noopFactory(name, model string) SecureProvider { return nil }
|
||||||
|
|
||||||
func dummyArm(id ArmID, local bool) *Arm {
|
func dummyArm(id ArmID, local bool) *Arm {
|
||||||
return &Arm{
|
return &Arm{
|
||||||
@@ -139,7 +139,7 @@ func TestReconcileArms_NoForcedArm(t *testing.T) {
|
|||||||
{ID: "gemma-26b", Provider: "llamacpp", SupportsTools: true},
|
{ID: "gemma-26b", Provider: "llamacpp", SupportsTools: true},
|
||||||
}
|
}
|
||||||
|
|
||||||
factory := func(name, model string) provider.Provider {
|
factory := func(name, model string) SecureProvider {
|
||||||
return &stubProvider{name: name, model: model}
|
return &stubProvider{name: name, model: model}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -212,3 +212,4 @@ func (s *stubProvider) Models(_ context.Context) ([]provider.ModelInfo, error) {
|
|||||||
func (s *stubProvider) Stream(_ context.Context, _ provider.Request) (stream.Stream, error) {
|
func (s *stubProvider) Stream(_ context.Context, _ provider.Request) (stream.Stream, error) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
func (s *stubProvider) IsSecure() bool { return true }
|
||||||
|
|||||||
@@ -283,17 +283,17 @@ func (r *Router) Arms() []*Arm {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// RegisterProvider registers all models from a provider as arms.
|
// RegisterProvider registers all models from a provider as arms.
|
||||||
func (r *Router) RegisterProvider(ctx context.Context, prov provider.Provider, isLocal bool, costs map[string][2]float64) {
|
func (r *Router) RegisterProvider(ctx context.Context, prov SecureProvider, isLocal bool, costs map[string][2]float64) {
|
||||||
models, err := prov.Models(ctx)
|
models, err := prov.Models(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
r.logger.Debug("failed to list models", "provider", prov.Name(), "error", err)
|
r.logger.Debug("failed to list models", "provider", prov.Name(), "error", err)
|
||||||
// Register at least the default model
|
// Register at least the default model
|
||||||
id := NewArmID(prov.Name(), prov.DefaultModel())
|
id := NewArmID(prov.Name(), prov.DefaultModel())
|
||||||
r.RegisterArm(&Arm{
|
r.RegisterArm(&Arm{
|
||||||
ID: id,
|
ID: id,
|
||||||
Provider: prov,
|
Provider: prov,
|
||||||
ModelName: prov.DefaultModel(),
|
ModelName: prov.DefaultModel(),
|
||||||
IsLocal: isLocal,
|
IsLocal: isLocal,
|
||||||
Capabilities: provider.Capabilities{ToolUse: true}, // optimistic
|
Capabilities: provider.Capabilities{ToolUse: true}, // optimistic
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -21,10 +21,11 @@ type Firewall struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type FirewallConfig struct {
|
type FirewallConfig struct {
|
||||||
ScanOutgoing bool
|
ScanOutgoing bool
|
||||||
ScanToolResults bool
|
ScanToolResults bool
|
||||||
EntropyThreshold float64
|
RedactHighEntropy bool
|
||||||
Logger *slog.Logger
|
EntropyThreshold float64
|
||||||
|
Logger *slog.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewFirewall(cfg FirewallConfig) *Firewall {
|
func NewFirewall(cfg FirewallConfig) *Firewall {
|
||||||
@@ -33,7 +34,7 @@ func NewFirewall(cfg FirewallConfig) *Firewall {
|
|||||||
logger = slog.Default()
|
logger = slog.Default()
|
||||||
}
|
}
|
||||||
return &Firewall{
|
return &Firewall{
|
||||||
scanner: NewScanner(cfg.EntropyThreshold),
|
scanner: NewScanner(cfg.EntropyThreshold, cfg.RedactHighEntropy),
|
||||||
incognito: NewIncognitoMode(),
|
incognito: NewIncognitoMode(),
|
||||||
logger: logger,
|
logger: logger,
|
||||||
scanOutgoing: cfg.ScanOutgoing,
|
scanOutgoing: cfg.ScanOutgoing,
|
||||||
|
|||||||
@@ -40,6 +40,11 @@ func (p *SafeProvider) Inner() provider.Provider {
|
|||||||
return p.inner
|
return p.inner
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsSecure returns true. Satisfies the router's SecureProvider interface.
|
||||||
|
func (p *SafeProvider) IsSecure() bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
func (p *SafeProvider) Stream(ctx context.Context, req provider.Request) (stream.Stream, error) {
|
func (p *SafeProvider) Stream(ctx context.Context, req provider.Request) (stream.Stream, error) {
|
||||||
if p.fwRef != nil {
|
if p.fwRef != nil {
|
||||||
if fw := p.fwRef.Get(); fw != nil {
|
if fw := p.fwRef.Get(); fw != nil {
|
||||||
|
|||||||
@@ -32,17 +32,19 @@ type SecretMatch struct {
|
|||||||
|
|
||||||
// Scanner detects secrets and sensitive data in content.
|
// Scanner detects secrets and sensitive data in content.
|
||||||
type Scanner struct {
|
type Scanner struct {
|
||||||
patterns []SecretPattern
|
patterns []SecretPattern
|
||||||
entropyThreshold float64
|
entropyThreshold float64
|
||||||
|
redactHighEntropy bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewScanner(entropyThreshold float64) *Scanner {
|
func NewScanner(entropyThreshold float64, redactHighEntropy bool) *Scanner {
|
||||||
if entropyThreshold <= 0 {
|
if entropyThreshold <= 0 {
|
||||||
entropyThreshold = 4.5
|
entropyThreshold = 4.5
|
||||||
}
|
}
|
||||||
return &Scanner{
|
return &Scanner{
|
||||||
patterns: defaultPatterns(),
|
patterns: defaultPatterns(),
|
||||||
entropyThreshold: entropyThreshold,
|
entropyThreshold: entropyThreshold,
|
||||||
|
redactHighEntropy: redactHighEntropy,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -104,9 +106,13 @@ func (s *Scanner) scanEntropy(content string) []SecretMatch {
|
|||||||
}
|
}
|
||||||
entropy := shannonEntropy(w.text)
|
entropy := shannonEntropy(w.text)
|
||||||
if entropy >= s.entropyThreshold {
|
if entropy >= s.entropyThreshold {
|
||||||
|
action := ActionWarn
|
||||||
|
if s.redactHighEntropy {
|
||||||
|
action = ActionRedact
|
||||||
|
}
|
||||||
matches = append(matches, SecretMatch{
|
matches = append(matches, SecretMatch{
|
||||||
Pattern: "high_entropy",
|
Pattern: "high_entropy",
|
||||||
Action: ActionWarn,
|
Action: action,
|
||||||
Start: w.start,
|
Start: w.start,
|
||||||
End: w.start + len(w.text),
|
End: w.start + len(w.text),
|
||||||
})
|
})
|
||||||
@@ -224,7 +230,7 @@ func defaultPatterns() []SecretPattern {
|
|||||||
{"sentry_auth_token", `sntrys_[a-zA-Z0-9_]{50,}`},
|
{"sentry_auth_token", `sntrys_[a-zA-Z0-9_]{50,}`},
|
||||||
|
|
||||||
// --- Infrastructure ---
|
// --- Infrastructure ---
|
||||||
{"private_key", `-----BEGIN (?:RSA |EC |DSA |OPENSSH )?PRIVATE KEY-----`},
|
{"private_key", `(?s)-----BEGIN (?:RSA |EC |DSA |OPENSSH )?PRIVATE KEY-----.*?-----END (?:RSA |EC |DSA |OPENSSH )?PRIVATE KEY-----`},
|
||||||
{"database_url", `(?i)(?:postgres|mysql|mongodb|redis)://[^:]+:[^@]+@`},
|
{"database_url", `(?i)(?:postgres|mysql|mongodb|redis)://[^:]+:[^@]+@`},
|
||||||
{"heroku_api_key", `(?i)HEROKU_API_KEY\s*=\s*[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}`},
|
{"heroku_api_key", `(?i)HEROKU_API_KEY\s*=\s*[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}`},
|
||||||
{"mailgun_api_key", `key-[a-f0-9]{32}`},
|
{"mailgun_api_key", `key-[a-f0-9]{32}`},
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ import (
|
|||||||
// --- Scanner ---
|
// --- Scanner ---
|
||||||
|
|
||||||
func TestScanner_DetectsAnthropicKey(t *testing.T) {
|
func TestScanner_DetectsAnthropicKey(t *testing.T) {
|
||||||
s := NewScanner(4.5)
|
s := NewScanner(4.5, false)
|
||||||
matches := s.Scan("my key is sk-ant-api03-abcdefghijklmnopqrstuvwxyz")
|
matches := s.Scan("my key is sk-ant-api03-abcdefghijklmnopqrstuvwxyz")
|
||||||
if len(matches) == 0 {
|
if len(matches) == 0 {
|
||||||
t.Error("should detect Anthropic API key")
|
t.Error("should detect Anthropic API key")
|
||||||
@@ -21,7 +21,7 @@ func TestScanner_DetectsAnthropicKey(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestScanner_DetectsOpenAIKey(t *testing.T) {
|
func TestScanner_DetectsOpenAIKey(t *testing.T) {
|
||||||
s := NewScanner(4.5)
|
s := NewScanner(4.5, false)
|
||||||
matches := s.Scan("key: sk-proj-abcdefghijklmnopqrstuvwxyz123456")
|
matches := s.Scan("key: sk-proj-abcdefghijklmnopqrstuvwxyz123456")
|
||||||
if len(matches) == 0 {
|
if len(matches) == 0 {
|
||||||
t.Error("should detect OpenAI API key")
|
t.Error("should detect OpenAI API key")
|
||||||
@@ -29,7 +29,7 @@ func TestScanner_DetectsOpenAIKey(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestScanner_DetectsAWSKey(t *testing.T) {
|
func TestScanner_DetectsAWSKey(t *testing.T) {
|
||||||
s := NewScanner(4.5)
|
s := NewScanner(4.5, false)
|
||||||
matches := s.Scan("AKIAIOSFODNN7EXAMPLE")
|
matches := s.Scan("AKIAIOSFODNN7EXAMPLE")
|
||||||
if len(matches) == 0 {
|
if len(matches) == 0 {
|
||||||
t.Error("should detect AWS access key")
|
t.Error("should detect AWS access key")
|
||||||
@@ -40,7 +40,7 @@ func TestScanner_DetectsAWSKey(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestScanner_DetectsGitHubPAT(t *testing.T) {
|
func TestScanner_DetectsGitHubPAT(t *testing.T) {
|
||||||
s := NewScanner(4.5)
|
s := NewScanner(4.5, false)
|
||||||
matches := s.Scan("token: ghp_ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghij")
|
matches := s.Scan("token: ghp_ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghij")
|
||||||
hasGH := false
|
hasGH := false
|
||||||
for _, m := range matches {
|
for _, m := range matches {
|
||||||
@@ -55,8 +55,8 @@ func TestScanner_DetectsGitHubPAT(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestScanner_DetectsPrivateKey(t *testing.T) {
|
func TestScanner_DetectsPrivateKey(t *testing.T) {
|
||||||
s := NewScanner(4.5)
|
s := NewScanner(4.5, false)
|
||||||
matches := s.Scan("-----BEGIN RSA PRIVATE KEY-----\nMIIE...")
|
matches := s.Scan("-----BEGIN RSA PRIVATE KEY-----\nMIIE...\n-----END RSA PRIVATE KEY-----")
|
||||||
hasKey := false
|
hasKey := false
|
||||||
for _, m := range matches {
|
for _, m := range matches {
|
||||||
if m.Pattern == "private_key" {
|
if m.Pattern == "private_key" {
|
||||||
@@ -70,7 +70,7 @@ func TestScanner_DetectsPrivateKey(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestScanner_DetectsGenericSecret(t *testing.T) {
|
func TestScanner_DetectsGenericSecret(t *testing.T) {
|
||||||
s := NewScanner(4.5)
|
s := NewScanner(4.5, false)
|
||||||
matches := s.Scan(`password = "supersecretpassword123"`)
|
matches := s.Scan(`password = "supersecretpassword123"`)
|
||||||
hasGeneric := false
|
hasGeneric := false
|
||||||
for _, m := range matches {
|
for _, m := range matches {
|
||||||
@@ -85,7 +85,7 @@ func TestScanner_DetectsGenericSecret(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestScanner_DetectsDatabaseURL(t *testing.T) {
|
func TestScanner_DetectsDatabaseURL(t *testing.T) {
|
||||||
s := NewScanner(4.5)
|
s := NewScanner(4.5, false)
|
||||||
matches := s.Scan("postgres://admin:secretpass@db.example.com:5432/mydb")
|
matches := s.Scan("postgres://admin:secretpass@db.example.com:5432/mydb")
|
||||||
hasDB := false
|
hasDB := false
|
||||||
for _, m := range matches {
|
for _, m := range matches {
|
||||||
@@ -100,7 +100,7 @@ func TestScanner_DetectsDatabaseURL(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestScanner_DetectsMistralKey(t *testing.T) {
|
func TestScanner_DetectsMistralKey(t *testing.T) {
|
||||||
s := NewScanner(6.0)
|
s := NewScanner(6.0, false)
|
||||||
|
|
||||||
// Should detect Mistral key in assignment contexts.
|
// Should detect Mistral key in assignment contexts.
|
||||||
positives := []string{
|
positives := []string{
|
||||||
@@ -139,7 +139,7 @@ func TestScanner_DetectsMistralKey(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestScanner_NoFalsePositives(t *testing.T) {
|
func TestScanner_NoFalsePositives(t *testing.T) {
|
||||||
s := NewScanner(6.0) // high entropy threshold to avoid false positives
|
s := NewScanner(6.0, false) // high entropy threshold to avoid false positives
|
||||||
safe := []string{
|
safe := []string{
|
||||||
"hello world",
|
"hello world",
|
||||||
"func main() {}",
|
"func main() {}",
|
||||||
@@ -156,7 +156,7 @@ func TestScanner_NoFalsePositives(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestScanner_Entropy(t *testing.T) {
|
func TestScanner_Entropy(t *testing.T) {
|
||||||
s := NewScanner(4.0) // lower threshold for testing
|
s := NewScanner(4.0, false) // lower threshold for testing
|
||||||
|
|
||||||
// High entropy string (random-looking)
|
// High entropy string (random-looking)
|
||||||
matches := s.Scan("token: aB3dE5fG7hI9jK1lM3nO5pQ7rS9tU1v")
|
matches := s.Scan("token: aB3dE5fG7hI9jK1lM3nO5pQ7rS9tU1v")
|
||||||
@@ -195,7 +195,7 @@ func TestShannonEntropy(t *testing.T) {
|
|||||||
|
|
||||||
func TestRedact_SingleMatch(t *testing.T) {
|
func TestRedact_SingleMatch(t *testing.T) {
|
||||||
content := `AKIAIOSFODNN7EXAMPLE is my key`
|
content := `AKIAIOSFODNN7EXAMPLE is my key`
|
||||||
s := NewScanner(6.0)
|
s := NewScanner(6.0, false)
|
||||||
matches := s.Scan(content)
|
matches := s.Scan(content)
|
||||||
|
|
||||||
result := Redact(content, matches)
|
result := Redact(content, matches)
|
||||||
@@ -209,7 +209,7 @@ func TestRedact_SingleMatch(t *testing.T) {
|
|||||||
|
|
||||||
func TestRedact_MultipleMatches(t *testing.T) {
|
func TestRedact_MultipleMatches(t *testing.T) {
|
||||||
content := "aws: AKIAIOSFODNN7EXAMPLE github: ghp_ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghij"
|
content := "aws: AKIAIOSFODNN7EXAMPLE github: ghp_ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghij"
|
||||||
s := NewScanner(6.0)
|
s := NewScanner(6.0, false)
|
||||||
matches := s.Scan(content)
|
matches := s.Scan(content)
|
||||||
|
|
||||||
result := Redact(content, matches)
|
result := Redact(content, matches)
|
||||||
@@ -442,7 +442,7 @@ func TestFirewall_ActionBlockReturnsBlockedString(t *testing.T) {
|
|||||||
func TestScanner_DedupKeyNoCollision(t *testing.T) {
|
func TestScanner_DedupKeyNoCollision(t *testing.T) {
|
||||||
// Two matches at byte offsets > 127 in the same pattern should both appear,
|
// Two matches at byte offsets > 127 in the same pattern should both appear,
|
||||||
// not get deduplicated because of hash collision in the key.
|
// not get deduplicated because of hash collision in the key.
|
||||||
s := NewScanner(3.0)
|
s := NewScanner(3.0, false)
|
||||||
// Build a string where two matches appear after offset 127
|
// Build a string where two matches appear after offset 127
|
||||||
prefix := strings.Repeat("x", 128) // push matches past offset 127
|
prefix := strings.Repeat("x", 128) // push matches past offset 127
|
||||||
input := prefix + "sk-ant-api03-aaaaaaaabbbbbbbbcccccccc " + prefix + "sk-ant-api03-ddddddddeeeeeeeeffffffff"
|
input := prefix + "sk-ant-api03-aaaaaaaabbbbbbbbcccccccc " + prefix + "sk-ant-api03-ddddddddeeeeeeeeffffffff"
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ type mockProvider struct {
|
|||||||
|
|
||||||
func (m *mockProvider) Name() string { return m.name }
|
func (m *mockProvider) Name() string { return m.name }
|
||||||
func (m *mockProvider) DefaultModel() string { return "mock-model" }
|
func (m *mockProvider) DefaultModel() string { return "mock-model" }
|
||||||
|
func (m *mockProvider) IsSecure() bool { return true }
|
||||||
func (m *mockProvider) Models(_ context.Context) ([]provider.ModelInfo, error) {
|
func (m *mockProvider) Models(_ context.Context) ([]provider.ModelInfo, error) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ func (m *mockProvider) DefaultModel() string { return "default" }
|
|||||||
func (m *mockProvider) Models(_ context.Context) ([]provider.ModelInfo, error) {
|
func (m *mockProvider) Models(_ context.Context) ([]provider.ModelInfo, error) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
func (m *mockProvider) IsSecure() bool { return true }
|
||||||
func (m *mockProvider) Stream(ctx context.Context, _ provider.Request) (stream.Stream, error) {
|
func (m *mockProvider) Stream(ctx context.Context, _ provider.Request) (stream.Stream, error) {
|
||||||
if m.delay > 0 {
|
if m.delay > 0 {
|
||||||
select {
|
select {
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ const pidFile = "llamafile.pid"
|
|||||||
// DefaultModelURL is the default llamafile to download when none is configured.
|
// DefaultModelURL is the default llamafile to download when none is configured.
|
||||||
// Qwen2.5 0.5B Instruct Q6_K (~450 MB) — small, fast, and supports tools.
|
// Qwen2.5 0.5B Instruct Q6_K (~450 MB) — small, fast, and supports tools.
|
||||||
const DefaultModelURL = "https://huggingface.co/Mozilla/Qwen2.5-0.5B-Instruct-llamafile/resolve/main/Qwen2.5-0.5B-Instruct-Q6_K.llamafile"
|
const DefaultModelURL = "https://huggingface.co/Mozilla/Qwen2.5-0.5B-Instruct-llamafile/resolve/main/Qwen2.5-0.5B-Instruct-Q6_K.llamafile"
|
||||||
|
const DefaultModelSHA256 = "c4e991af9ea7077339b8768e349da486a76392e72b3ef47ad372e6582779a8dd"
|
||||||
|
|
||||||
// DefaultDataDir returns the platform default SLM data directory.
|
// DefaultDataDir returns the platform default SLM data directory.
|
||||||
// Follows XDG Base Directory Specification: $XDG_DATA_HOME/gnoma/slm,
|
// Follows XDG Base Directory Specification: $XDG_DATA_HOME/gnoma/slm,
|
||||||
@@ -57,8 +58,9 @@ func (s Status) String() string {
|
|||||||
|
|
||||||
// Config holds Manager configuration.
|
// Config holds Manager configuration.
|
||||||
type Config struct {
|
type Config struct {
|
||||||
DataDir string // XDG data home / gnoma / slm; must be set
|
DataDir string // XDG data home / gnoma / slm; must be set
|
||||||
ModelURL string // required for Setup
|
ModelURL string // required for Setup
|
||||||
|
ExpectedSHA256 string // if non-empty, Setup verifies against this
|
||||||
}
|
}
|
||||||
|
|
||||||
// Manager controls the llamafile lifecycle.
|
// Manager controls the llamafile lifecycle.
|
||||||
@@ -131,6 +133,11 @@ func (m *Manager) Setup(ctx context.Context, progress func(downloaded, total int
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if m.cfg.ExpectedSHA256 != "" && sha256hex != m.cfg.ExpectedSHA256 {
|
||||||
|
_ = os.Remove(dst) // cleanup corrupt/malicious download
|
||||||
|
return fmt.Errorf("slm: hash mismatch for %s: got %s, want %s", m.cfg.ModelURL, sha256hex, m.cfg.ExpectedSHA256)
|
||||||
|
}
|
||||||
|
|
||||||
mf := &Manifest{
|
mf := &Manifest{
|
||||||
ModelURL: m.cfg.ModelURL,
|
ModelURL: m.cfg.ModelURL,
|
||||||
FilePath: dst,
|
FilePath: dst,
|
||||||
|
|||||||
@@ -142,7 +142,15 @@ func (t *GrepTool) Execute(_ context.Context, args json.RawMessage) (tool.Result
|
|||||||
}
|
}
|
||||||
|
|
||||||
rel, _ := filepath.Rel(root, path)
|
rel, _ := filepath.Rel(root, path)
|
||||||
fileMatches := grepFile(path, rel, re, maxResults-len(matches))
|
resolvedPath := path
|
||||||
|
if t.guard != nil {
|
||||||
|
resolved, err := t.guard.ResolveRead(path)
|
||||||
|
if err != nil {
|
||||||
|
return nil // Skip files outside workspace or unreadable
|
||||||
|
}
|
||||||
|
resolvedPath = resolved
|
||||||
|
}
|
||||||
|
fileMatches := grepFile(resolvedPath, rel, re, maxResults-len(matches))
|
||||||
matches = append(matches, fileMatches...)
|
matches = append(matches, fileMatches...)
|
||||||
|
|
||||||
if len(matches) >= maxResults {
|
if len(matches) >= maxResults {
|
||||||
|
|||||||
Reference in New Issue
Block a user