feat(fs): enforce workspace boundary on fs tools
Adds a Guard that resolves every path against an allowlist of absolute roots (default: cwd) and rejects anything escaping via relative segments, absolute paths outside the root, or symlinks (including symlinked parents on writes). Closes audit finding C1: fs.read/fs.write/fs.edit/fs.glob/fs.grep/fs.ls previously accepted any absolute path; the only protection was a substring denylist (.env, .ssh/, ...) which missed /etc/shadow, kube configs, IDE secrets, and anything reachable via symlink.
This commit is contained in:
+41
-9
@@ -229,10 +229,23 @@ func main() {
|
||||
}
|
||||
}
|
||||
|
||||
cwd, cwdErr := os.Getwd()
|
||||
if cwdErr != nil {
|
||||
fmt.Fprintf(os.Stderr, "error: cannot resolve working directory: %v\n", cwdErr)
|
||||
os.Exit(1)
|
||||
}
|
||||
fsGuard, err := fs.NewGuard(cwd)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "error: workspace guard: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Create tool registry
|
||||
reg := buildToolRegistry()
|
||||
reg := buildToolRegistry(fsGuard)
|
||||
if cfg.Tools.MaxFileSize > 0 {
|
||||
reg.Register(fs.NewWriteTool(fs.WithMaxFileSize(cfg.Tools.MaxFileSize)))
|
||||
w := fs.NewWriteTool(fs.WithMaxFileSize(cfg.Tools.MaxFileSize))
|
||||
w.SetGuard(fsGuard)
|
||||
reg.Register(w)
|
||||
}
|
||||
|
||||
// Harvest aliases, inventory, CLI agents, and local models in parallel.
|
||||
@@ -991,15 +1004,34 @@ func createProvider(name, apiKey, model, baseURL string) (provider.Provider, err
|
||||
}
|
||||
}
|
||||
|
||||
func buildToolRegistry() *tool.Registry {
|
||||
func buildToolRegistry(guard *fs.Guard) *tool.Registry {
|
||||
reg := tool.NewRegistry()
|
||||
reg.Register(bash.New())
|
||||
reg.Register(fs.NewReadTool())
|
||||
reg.Register(fs.NewWriteTool())
|
||||
reg.Register(fs.NewEditTool())
|
||||
reg.Register(fs.NewGlobTool())
|
||||
reg.Register(fs.NewGrepTool())
|
||||
reg.Register(fs.NewLSTool())
|
||||
|
||||
read := fs.NewReadTool()
|
||||
read.SetGuard(guard)
|
||||
reg.Register(read)
|
||||
|
||||
write := fs.NewWriteTool()
|
||||
write.SetGuard(guard)
|
||||
reg.Register(write)
|
||||
|
||||
edit := fs.NewEditTool()
|
||||
edit.SetGuard(guard)
|
||||
reg.Register(edit)
|
||||
|
||||
glob := fs.NewGlobTool()
|
||||
glob.SetGuard(guard)
|
||||
reg.Register(glob)
|
||||
|
||||
grep := fs.NewGrepTool()
|
||||
grep.SetGuard(guard)
|
||||
reg.Register(grep)
|
||||
|
||||
ls := fs.NewLSTool()
|
||||
ls.SetGuard(guard)
|
||||
reg.Register(ls)
|
||||
|
||||
return reg
|
||||
}
|
||||
|
||||
|
||||
@@ -35,10 +35,14 @@ var editParams = json.RawMessage(`{
|
||||
"required": ["path", "old_string", "new_string"]
|
||||
}`)
|
||||
|
||||
type EditTool struct{}
|
||||
type EditTool struct {
|
||||
guard *Guard
|
||||
}
|
||||
|
||||
func NewEditTool() *EditTool { return &EditTool{} }
|
||||
|
||||
func (t *EditTool) SetGuard(g *Guard) { t.guard = g }
|
||||
|
||||
func (t *EditTool) Name() string { return editToolName }
|
||||
func (t *EditTool) Description() string { return "Perform exact string replacement in a file" }
|
||||
func (t *EditTool) Parameters() json.RawMessage { return editParams }
|
||||
@@ -72,7 +76,16 @@ func (t *EditTool) Execute(_ context.Context, args json.RawMessage) (tool.Result
|
||||
return tool.Result{}, fmt.Errorf("fs.edit: old_string and new_string must differ")
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(a.Path)
|
||||
path := a.Path
|
||||
if t.guard != nil {
|
||||
resolved, err := t.guard.ResolveRead(path)
|
||||
if err != nil {
|
||||
return tool.Result{Output: fmt.Sprintf("Error: %v", err)}, nil
|
||||
}
|
||||
path = resolved
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return tool.Result{Output: fmt.Sprintf("Error: %v", err)}, nil
|
||||
}
|
||||
@@ -101,7 +114,7 @@ func (t *EditTool) Execute(_ context.Context, args json.RawMessage) (tool.Result
|
||||
newContent = strings.Replace(content, a.OldString, a.NewString, 1)
|
||||
}
|
||||
|
||||
if err := os.WriteFile(a.Path, []byte(newContent), 0o644); err != nil {
|
||||
if err := os.WriteFile(path, []byte(newContent), 0o644); err != nil {
|
||||
return tool.Result{Output: fmt.Sprintf("Error writing file: %v", err)}, nil
|
||||
}
|
||||
|
||||
@@ -111,11 +124,11 @@ func (t *EditTool) Execute(_ context.Context, args json.RawMessage) (tool.Result
|
||||
}
|
||||
|
||||
// Generate diff-style output with context
|
||||
diff := buildEditDiff(content, a.OldString, a.NewString, a.Path, replacements)
|
||||
diff := buildEditDiff(content, a.OldString, a.NewString, path, replacements)
|
||||
|
||||
return tool.Result{
|
||||
Output: diff,
|
||||
Metadata: map[string]any{"replacements": replacements, "path": a.Path},
|
||||
Metadata: map[string]any{"replacements": replacements, "path": path},
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -30,10 +30,14 @@ var globParams = json.RawMessage(`{
|
||||
"required": ["pattern"]
|
||||
}`)
|
||||
|
||||
type GlobTool struct{}
|
||||
type GlobTool struct {
|
||||
guard *Guard
|
||||
}
|
||||
|
||||
func NewGlobTool() *GlobTool { return &GlobTool{} }
|
||||
|
||||
func (t *GlobTool) SetGuard(g *Guard) { t.guard = g }
|
||||
|
||||
func (t *GlobTool) Name() string { return globToolName }
|
||||
func (t *GlobTool) Description() string { return "Find files matching a glob pattern, sorted by modification time" }
|
||||
func (t *GlobTool) Parameters() json.RawMessage { return globParams }
|
||||
@@ -64,12 +68,23 @@ func (t *GlobTool) Execute(_ context.Context, args json.RawMessage) (tool.Result
|
||||
|
||||
root := a.Path
|
||||
if root == "" {
|
||||
var err error
|
||||
root, err = os.Getwd()
|
||||
if err != nil {
|
||||
return tool.Result{}, fmt.Errorf("fs.glob: %w", err)
|
||||
if t.guard != nil {
|
||||
root = t.guard.Roots()[0]
|
||||
} else {
|
||||
var err error
|
||||
root, err = os.Getwd()
|
||||
if err != nil {
|
||||
return tool.Result{}, fmt.Errorf("fs.glob: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
if t.guard != nil {
|
||||
resolved, err := t.guard.ResolveRead(root)
|
||||
if err != nil {
|
||||
return tool.Result{Output: fmt.Sprintf("Error: %v", err)}, nil
|
||||
}
|
||||
root = resolved
|
||||
}
|
||||
|
||||
var matches []string
|
||||
err := filepath.WalkDir(root, func(path string, d os.DirEntry, err error) error {
|
||||
|
||||
@@ -41,10 +41,14 @@ var grepParams = json.RawMessage(`{
|
||||
"required": ["pattern"]
|
||||
}`)
|
||||
|
||||
type GrepTool struct{}
|
||||
type GrepTool struct {
|
||||
guard *Guard
|
||||
}
|
||||
|
||||
func NewGrepTool() *GrepTool { return &GrepTool{} }
|
||||
|
||||
func (t *GrepTool) SetGuard(g *Guard) { t.guard = g }
|
||||
|
||||
func (t *GrepTool) Name() string { return grepToolName }
|
||||
func (t *GrepTool) Description() string { return "Search file contents using a regular expression" }
|
||||
func (t *GrepTool) Parameters() json.RawMessage { return grepParams }
|
||||
@@ -93,11 +97,22 @@ func (t *GrepTool) Execute(_ context.Context, args json.RawMessage) (tool.Result
|
||||
|
||||
root := a.Path
|
||||
if root == "" {
|
||||
root, err = os.Getwd()
|
||||
if err != nil {
|
||||
return tool.Result{}, fmt.Errorf("fs.grep: %w", err)
|
||||
if t.guard != nil {
|
||||
root = t.guard.Roots()[0]
|
||||
} else {
|
||||
root, err = os.Getwd()
|
||||
if err != nil {
|
||||
return tool.Result{}, fmt.Errorf("fs.grep: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
if t.guard != nil {
|
||||
resolved, err := t.guard.ResolveRead(root)
|
||||
if err != nil {
|
||||
return tool.Result{Output: fmt.Sprintf("Error: %v", err)}, nil
|
||||
}
|
||||
root = resolved
|
||||
}
|
||||
|
||||
info, err := os.Stat(root)
|
||||
if err != nil {
|
||||
|
||||
@@ -0,0 +1,123 @@
|
||||
package fs
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var ErrOutsideWorkspace = errors.New("path outside workspace")
|
||||
|
||||
type Guard struct {
|
||||
roots []string
|
||||
}
|
||||
|
||||
func NewGuard(roots ...string) (*Guard, error) {
|
||||
if len(roots) == 0 {
|
||||
return nil, errors.New("guard requires at least one root")
|
||||
}
|
||||
resolved := make([]string, 0, len(roots))
|
||||
for _, r := range roots {
|
||||
if !filepath.IsAbs(r) {
|
||||
return nil, fmt.Errorf("guard root %q must be absolute", r)
|
||||
}
|
||||
canonical, err := filepath.EvalSymlinks(r)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("guard root %q: %w", r, err)
|
||||
}
|
||||
info, err := os.Stat(canonical)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("guard root %q: %w", r, err)
|
||||
}
|
||||
if !info.IsDir() {
|
||||
return nil, fmt.Errorf("guard root %q is not a directory", r)
|
||||
}
|
||||
resolved = append(resolved, filepath.Clean(canonical))
|
||||
}
|
||||
return &Guard{roots: resolved}, nil
|
||||
}
|
||||
|
||||
func (g *Guard) Roots() []string {
|
||||
out := make([]string, len(g.roots))
|
||||
copy(out, g.roots)
|
||||
return out
|
||||
}
|
||||
|
||||
func (g *Guard) ResolveRead(path string) (string, error) {
|
||||
abs, err := g.absolutise(path)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
canonical, err := filepath.EvalSymlinks(abs)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("resolve %q: %w", path, err)
|
||||
}
|
||||
if !g.contains(canonical) {
|
||||
return "", fmt.Errorf("%w: %s", ErrOutsideWorkspace, path)
|
||||
}
|
||||
return canonical, nil
|
||||
}
|
||||
|
||||
// ResolveWrite canonicalises the deepest existing ancestor so a symlinked
|
||||
// parent escaping the workspace is rejected even when the leaf doesn't exist.
|
||||
func (g *Guard) ResolveWrite(path string) (string, error) {
|
||||
abs, err := g.absolutise(path)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
ancestor := abs
|
||||
tail := ""
|
||||
for {
|
||||
if _, err := os.Lstat(ancestor); err == nil {
|
||||
break
|
||||
}
|
||||
parent := filepath.Dir(ancestor)
|
||||
if parent == ancestor {
|
||||
return "", fmt.Errorf("resolve %q: no existing ancestor", path)
|
||||
}
|
||||
tail = filepath.Join(filepath.Base(ancestor), tail)
|
||||
ancestor = parent
|
||||
}
|
||||
|
||||
canonicalAncestor, err := filepath.EvalSymlinks(ancestor)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("resolve ancestor of %q: %w", path, err)
|
||||
}
|
||||
resolved := canonicalAncestor
|
||||
if tail != "" {
|
||||
resolved = filepath.Join(canonicalAncestor, tail)
|
||||
}
|
||||
if !g.contains(resolved) {
|
||||
return "", fmt.Errorf("%w: %s", ErrOutsideWorkspace, path)
|
||||
}
|
||||
return resolved, nil
|
||||
}
|
||||
|
||||
// absolutise anchors relative paths against the first root rather than process
|
||||
// cwd, which may drift over the lifetime of the agent.
|
||||
func (g *Guard) absolutise(path string) (string, error) {
|
||||
if path == "" {
|
||||
return "", errors.New("empty path")
|
||||
}
|
||||
if filepath.IsAbs(path) {
|
||||
return filepath.Clean(path), nil
|
||||
}
|
||||
return filepath.Clean(filepath.Join(g.roots[0], path)), nil
|
||||
}
|
||||
|
||||
// contains uses a separator boundary so "/ws-evil" is not considered inside "/ws".
|
||||
func (g *Guard) contains(canonical string) bool {
|
||||
for _, root := range g.roots {
|
||||
if canonical == root {
|
||||
return true
|
||||
}
|
||||
prefix := root + string(filepath.Separator)
|
||||
if strings.HasPrefix(canonical, prefix) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -0,0 +1,232 @@
|
||||
package fs
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewGuard_RejectsEmptyRoots(t *testing.T) {
|
||||
if _, err := NewGuard(); err == nil {
|
||||
t.Fatal("NewGuard() with no roots should error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewGuard_RejectsRelativeRoot(t *testing.T) {
|
||||
if _, err := NewGuard("relative/path"); err == nil {
|
||||
t.Fatal("NewGuard with relative root should error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewGuard_RejectsNonexistentRoot(t *testing.T) {
|
||||
if _, err := NewGuard("/definitely/does/not/exist/anywhere"); err == nil {
|
||||
t.Fatal("NewGuard with nonexistent root should error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGuard_ResolveInsideRoot(t *testing.T) {
|
||||
root := t.TempDir()
|
||||
g := mustGuard(t, root)
|
||||
|
||||
path := filepath.Join(root, "file.txt")
|
||||
if err := os.WriteFile(path, []byte("x"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
got, err := g.ResolveRead(path)
|
||||
if err != nil {
|
||||
t.Fatalf("ResolveRead inside root: %v", err)
|
||||
}
|
||||
if got != path {
|
||||
t.Errorf("got %q, want %q", got, path)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGuard_ResolveReadOutsideRootDenied(t *testing.T) {
|
||||
root := t.TempDir()
|
||||
outside := t.TempDir()
|
||||
g := mustGuard(t, root)
|
||||
|
||||
outsidePath := filepath.Join(outside, "secret")
|
||||
if err := os.WriteFile(outsidePath, []byte("x"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err := g.ResolveRead(outsidePath)
|
||||
if !errors.Is(err, ErrOutsideWorkspace) {
|
||||
t.Fatalf("want ErrOutsideWorkspace, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGuard_ResolveReadRelativeEscape(t *testing.T) {
|
||||
root := t.TempDir()
|
||||
g := mustGuard(t, root)
|
||||
|
||||
// Relative path with ../../../ should resolve relative to first root and
|
||||
// escape it; guard must deny.
|
||||
_, err := g.ResolveRead("../../../etc/passwd")
|
||||
if !errors.Is(err, ErrOutsideWorkspace) {
|
||||
t.Fatalf("want ErrOutsideWorkspace, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGuard_ResolveReadSymlinkEscapeDenied(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("symlink semantics differ on Windows")
|
||||
}
|
||||
root := t.TempDir()
|
||||
outside := t.TempDir()
|
||||
g := mustGuard(t, root)
|
||||
|
||||
target := filepath.Join(outside, "target")
|
||||
if err := os.WriteFile(target, []byte("secret"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
link := filepath.Join(root, "link")
|
||||
if err := os.Symlink(target, link); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err := g.ResolveRead(link)
|
||||
if !errors.Is(err, ErrOutsideWorkspace) {
|
||||
t.Fatalf("symlink escaping root should be denied; got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGuard_ResolveReadSymlinkWithinRootAllowed(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("symlink semantics differ on Windows")
|
||||
}
|
||||
root := t.TempDir()
|
||||
g := mustGuard(t, root)
|
||||
|
||||
target := filepath.Join(root, "target")
|
||||
if err := os.WriteFile(target, []byte("x"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
link := filepath.Join(root, "link")
|
||||
if err := os.Symlink(target, link); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
got, err := g.ResolveRead(link)
|
||||
if err != nil {
|
||||
t.Fatalf("symlink inside root: %v", err)
|
||||
}
|
||||
// Canonical form should be the target, not the link.
|
||||
if !strings.HasPrefix(got, root) {
|
||||
t.Errorf("canonical %q should be inside root %q", got, root)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGuard_ResolveWriteNewFileAllowed(t *testing.T) {
|
||||
root := t.TempDir()
|
||||
g := mustGuard(t, root)
|
||||
|
||||
newFile := filepath.Join(root, "newdir", "newfile.txt")
|
||||
got, err := g.ResolveWrite(newFile)
|
||||
if err != nil {
|
||||
t.Fatalf("ResolveWrite to new path inside root: %v", err)
|
||||
}
|
||||
if !strings.HasPrefix(got, root) {
|
||||
t.Errorf("canonical %q should be inside root %q", got, root)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGuard_ResolveWriteOutsideRootDenied(t *testing.T) {
|
||||
root := t.TempDir()
|
||||
outside := t.TempDir()
|
||||
g := mustGuard(t, root)
|
||||
|
||||
_, err := g.ResolveWrite(filepath.Join(outside, "evil.txt"))
|
||||
if !errors.Is(err, ErrOutsideWorkspace) {
|
||||
t.Fatalf("want ErrOutsideWorkspace, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGuard_ResolveWriteViaSymlinkedParentDenied(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("symlink semantics differ on Windows")
|
||||
}
|
||||
root := t.TempDir()
|
||||
outside := t.TempDir()
|
||||
g := mustGuard(t, root)
|
||||
|
||||
// Create a symlink inside root whose target is outside.
|
||||
linkedDir := filepath.Join(root, "escape")
|
||||
if err := os.Symlink(outside, linkedDir); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Writing under the symlinked dir lands outside root.
|
||||
_, err := g.ResolveWrite(filepath.Join(linkedDir, "evil.txt"))
|
||||
if !errors.Is(err, ErrOutsideWorkspace) {
|
||||
t.Fatalf("write via symlinked parent should be denied; got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGuard_MultipleRoots(t *testing.T) {
|
||||
rootA := t.TempDir()
|
||||
rootB := t.TempDir()
|
||||
g := mustGuard(t, rootA, rootB)
|
||||
|
||||
a := filepath.Join(rootA, "file")
|
||||
if err := os.WriteFile(a, []byte("x"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
b := filepath.Join(rootB, "file")
|
||||
if err := os.WriteFile(b, []byte("x"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if _, err := g.ResolveRead(a); err != nil {
|
||||
t.Errorf("rootA: %v", err)
|
||||
}
|
||||
if _, err := g.ResolveRead(b); err != nil {
|
||||
t.Errorf("rootB: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGuard_RootBoundaryNotPrefixMatch(t *testing.T) {
|
||||
// Catch the classic bug: /foo/bar must NOT be considered inside /foo/ba.
|
||||
parent := t.TempDir()
|
||||
rootShort := filepath.Join(parent, "ws")
|
||||
rootLongName := filepath.Join(parent, "ws-evil")
|
||||
for _, d := range []string{rootShort, rootLongName} {
|
||||
if err := os.MkdirAll(d, 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
g := mustGuard(t, rootShort)
|
||||
|
||||
evil := filepath.Join(rootLongName, "secret")
|
||||
if err := os.WriteFile(evil, []byte("x"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err := g.ResolveRead(evil)
|
||||
if !errors.Is(err, ErrOutsideWorkspace) {
|
||||
t.Fatalf("ws-evil/secret should not be considered inside ws; got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGuard_RootItselfAllowed(t *testing.T) {
|
||||
root := t.TempDir()
|
||||
g := mustGuard(t, root)
|
||||
|
||||
if _, err := g.ResolveRead(root); err != nil {
|
||||
t.Errorf("root path itself should be allowed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func mustGuard(t *testing.T, roots ...string) *Guard {
|
||||
t.Helper()
|
||||
g, err := NewGuard(roots...)
|
||||
if err != nil {
|
||||
t.Fatalf("NewGuard: %v", err)
|
||||
}
|
||||
return g
|
||||
}
|
||||
@@ -0,0 +1,184 @@
|
||||
package fs
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// These tests exercise each fs tool with a Guard installed, verifying that
|
||||
// paths outside the workspace are rejected at the tool boundary.
|
||||
|
||||
func TestReadTool_GuardDeniesOutsideRoot(t *testing.T) {
|
||||
root := t.TempDir()
|
||||
outside := t.TempDir()
|
||||
outsideFile := filepath.Join(outside, "secret")
|
||||
if err := os.WriteFile(outsideFile, []byte("secret"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
r := NewReadTool()
|
||||
r.SetGuard(mustGuard(t, root))
|
||||
|
||||
res, err := r.Execute(context.Background(), mustJSON(t, readArgs{Path: outsideFile}))
|
||||
if err != nil {
|
||||
t.Fatalf("Execute: %v", err)
|
||||
}
|
||||
if !strings.Contains(res.Output, "path outside workspace") {
|
||||
t.Errorf("expected workspace error, got %q", res.Output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadTool_GuardAllowsInsideRoot(t *testing.T) {
|
||||
root := t.TempDir()
|
||||
inside := filepath.Join(root, "ok.txt")
|
||||
if err := os.WriteFile(inside, []byte("hi"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
r := NewReadTool()
|
||||
r.SetGuard(mustGuard(t, root))
|
||||
|
||||
res, err := r.Execute(context.Background(), mustJSON(t, readArgs{Path: inside}))
|
||||
if err != nil {
|
||||
t.Fatalf("Execute: %v", err)
|
||||
}
|
||||
if !strings.Contains(res.Output, "hi") {
|
||||
t.Errorf("expected file content, got %q", res.Output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteTool_GuardDeniesOutsideRoot(t *testing.T) {
|
||||
root := t.TempDir()
|
||||
outside := t.TempDir()
|
||||
target := filepath.Join(outside, "evil.txt")
|
||||
|
||||
w := NewWriteTool()
|
||||
w.SetGuard(mustGuard(t, root))
|
||||
|
||||
res, err := w.Execute(context.Background(), mustJSON(t, writeArgs{Path: target, Content: "x"}))
|
||||
if err != nil {
|
||||
t.Fatalf("Execute: %v", err)
|
||||
}
|
||||
if !strings.Contains(res.Output, "path outside workspace") {
|
||||
t.Errorf("expected workspace error, got %q", res.Output)
|
||||
}
|
||||
if _, err := os.Stat(target); err == nil {
|
||||
t.Errorf("file was written despite guard: %s", target)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteTool_GuardAllowsInsideRoot(t *testing.T) {
|
||||
root := t.TempDir()
|
||||
target := filepath.Join(root, "sub", "ok.txt")
|
||||
|
||||
w := NewWriteTool()
|
||||
w.SetGuard(mustGuard(t, root))
|
||||
|
||||
res, err := w.Execute(context.Background(), mustJSON(t, writeArgs{Path: target, Content: "hi"}))
|
||||
if err != nil {
|
||||
t.Fatalf("Execute: %v", err)
|
||||
}
|
||||
if !strings.Contains(res.Output, "Wrote") {
|
||||
t.Errorf("expected write confirmation, got %q", res.Output)
|
||||
}
|
||||
if _, err := os.Stat(target); err != nil {
|
||||
t.Errorf("file missing after write: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEditTool_GuardDeniesOutsideRoot(t *testing.T) {
|
||||
root := t.TempDir()
|
||||
outside := t.TempDir()
|
||||
target := filepath.Join(outside, "f.txt")
|
||||
if err := os.WriteFile(target, []byte("hello"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
e := NewEditTool()
|
||||
e.SetGuard(mustGuard(t, root))
|
||||
|
||||
res, err := e.Execute(context.Background(), mustJSON(t, editArgs{Path: target, OldString: "hello", NewString: "hi"}))
|
||||
if err != nil {
|
||||
t.Fatalf("Execute: %v", err)
|
||||
}
|
||||
if !strings.Contains(res.Output, "path outside workspace") {
|
||||
t.Errorf("expected workspace error, got %q", res.Output)
|
||||
}
|
||||
// File must remain unchanged.
|
||||
data, _ := os.ReadFile(target)
|
||||
if string(data) != "hello" {
|
||||
t.Errorf("file mutated despite guard: %q", string(data))
|
||||
}
|
||||
}
|
||||
|
||||
func TestLSTool_GuardDeniesOutsideRoot(t *testing.T) {
|
||||
root := t.TempDir()
|
||||
outside := t.TempDir()
|
||||
|
||||
l := NewLSTool()
|
||||
l.SetGuard(mustGuard(t, root))
|
||||
|
||||
res, err := l.Execute(context.Background(), mustJSON(t, lsArgs{Path: outside}))
|
||||
if err != nil {
|
||||
t.Fatalf("Execute: %v", err)
|
||||
}
|
||||
if !strings.Contains(res.Output, "path outside workspace") {
|
||||
t.Errorf("expected workspace error, got %q", res.Output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLSTool_GuardEmptyPathDefaultsToRoot(t *testing.T) {
|
||||
root := t.TempDir()
|
||||
if err := os.WriteFile(filepath.Join(root, "marker.txt"), []byte("x"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
l := NewLSTool()
|
||||
l.SetGuard(mustGuard(t, root))
|
||||
|
||||
res, err := l.Execute(context.Background(), mustJSON(t, lsArgs{}))
|
||||
if err != nil {
|
||||
t.Fatalf("Execute: %v", err)
|
||||
}
|
||||
if !strings.Contains(res.Output, "marker.txt") {
|
||||
t.Errorf("expected to list root contents, got %q", res.Output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGlobTool_GuardDeniesOutsideRoot(t *testing.T) {
|
||||
root := t.TempDir()
|
||||
outside := t.TempDir()
|
||||
|
||||
g := NewGlobTool()
|
||||
g.SetGuard(mustGuard(t, root))
|
||||
|
||||
res, err := g.Execute(context.Background(), mustJSON(t, globArgs{Pattern: "*", Path: outside}))
|
||||
if err != nil {
|
||||
t.Fatalf("Execute: %v", err)
|
||||
}
|
||||
if !strings.Contains(res.Output, "path outside workspace") {
|
||||
t.Errorf("expected workspace error, got %q", res.Output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGrepTool_GuardDeniesOutsideRoot(t *testing.T) {
|
||||
root := t.TempDir()
|
||||
outside := t.TempDir()
|
||||
if err := os.WriteFile(filepath.Join(outside, "f.txt"), []byte("needle"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
g := NewGrepTool()
|
||||
g.SetGuard(mustGuard(t, root))
|
||||
|
||||
res, err := g.Execute(context.Background(), mustJSON(t, grepArgs{Pattern: "needle", Path: outside}))
|
||||
if err != nil {
|
||||
t.Fatalf("Execute: %v", err)
|
||||
}
|
||||
if !strings.Contains(res.Output, "path outside workspace") {
|
||||
t.Errorf("expected workspace error, got %q", res.Output)
|
||||
}
|
||||
}
|
||||
+20
-5
@@ -24,10 +24,14 @@ var lsParams = json.RawMessage(`{
|
||||
}
|
||||
}`)
|
||||
|
||||
type LSTool struct{}
|
||||
type LSTool struct {
|
||||
guard *Guard
|
||||
}
|
||||
|
||||
func NewLSTool() *LSTool { return &LSTool{} }
|
||||
|
||||
func (t *LSTool) SetGuard(g *Guard) { t.guard = g }
|
||||
|
||||
func (t *LSTool) Name() string { return lsToolName }
|
||||
func (t *LSTool) Description() string { return "List directory contents with file types and sizes" }
|
||||
func (t *LSTool) Parameters() json.RawMessage { return lsParams }
|
||||
@@ -54,12 +58,23 @@ func (t *LSTool) Execute(_ context.Context, args json.RawMessage) (tool.Result,
|
||||
|
||||
dir := a.Path
|
||||
if dir == "" {
|
||||
var err error
|
||||
dir, err = os.Getwd()
|
||||
if err != nil {
|
||||
return tool.Result{}, fmt.Errorf("fs.ls: %w", err)
|
||||
if t.guard != nil {
|
||||
dir = t.guard.Roots()[0]
|
||||
} else {
|
||||
var err error
|
||||
dir, err = os.Getwd()
|
||||
if err != nil {
|
||||
return tool.Result{}, fmt.Errorf("fs.ls: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
if t.guard != nil {
|
||||
resolved, err := t.guard.ResolveRead(dir)
|
||||
if err != nil {
|
||||
return tool.Result{Output: fmt.Sprintf("Error: %v", err)}, nil
|
||||
}
|
||||
dir = resolved
|
||||
}
|
||||
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
|
||||
@@ -36,8 +36,11 @@ var readParams = json.RawMessage(`{
|
||||
|
||||
type ReadTool struct {
|
||||
maxLines int
|
||||
guard *Guard
|
||||
}
|
||||
|
||||
func (t *ReadTool) SetGuard(g *Guard) { t.guard = g }
|
||||
|
||||
type ReadOption func(*ReadTool)
|
||||
|
||||
func WithMaxLines(n int) ReadOption {
|
||||
@@ -81,7 +84,16 @@ func (t *ReadTool) Execute(_ context.Context, args json.RawMessage) (tool.Result
|
||||
return tool.Result{}, fmt.Errorf("fs.read: path required")
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(a.Path)
|
||||
path := a.Path
|
||||
if t.guard != nil {
|
||||
resolved, err := t.guard.ResolveRead(path)
|
||||
if err != nil {
|
||||
return tool.Result{Output: fmt.Sprintf("Error: %v", err)}, nil
|
||||
}
|
||||
path = resolved
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return tool.Result{Output: fmt.Sprintf("Error: %v", err)}, nil
|
||||
}
|
||||
|
||||
@@ -36,8 +36,11 @@ func WithMaxFileSize(n int64) WriteOption {
|
||||
|
||||
type WriteTool struct {
|
||||
maxFileSize int64
|
||||
guard *Guard
|
||||
}
|
||||
|
||||
func (t *WriteTool) SetGuard(g *Guard) { t.guard = g }
|
||||
|
||||
func NewWriteTool(opts ...WriteOption) *WriteTool {
|
||||
t := &WriteTool{}
|
||||
for _, opt := range opts {
|
||||
@@ -78,18 +81,27 @@ func (t *WriteTool) Execute(_ context.Context, args json.RawMessage) (tool.Resul
|
||||
return tool.Result{Output: fmt.Sprintf("Error: content too large (%d bytes, limit %d bytes)", len(a.Content), t.maxFileSize)}, nil
|
||||
}
|
||||
|
||||
path := a.Path
|
||||
if t.guard != nil {
|
||||
resolved, err := t.guard.ResolveWrite(path)
|
||||
if err != nil {
|
||||
return tool.Result{Output: fmt.Sprintf("Error: %v", err)}, nil
|
||||
}
|
||||
path = resolved
|
||||
}
|
||||
|
||||
// Create parent directories
|
||||
dir := filepath.Dir(a.Path)
|
||||
dir := filepath.Dir(path)
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
return tool.Result{Output: fmt.Sprintf("Error creating directory: %v", err)}, nil
|
||||
}
|
||||
|
||||
if err := os.WriteFile(a.Path, []byte(a.Content), 0o644); err != nil {
|
||||
if err := os.WriteFile(path, []byte(a.Content), 0o644); err != nil {
|
||||
return tool.Result{Output: fmt.Sprintf("Error writing file: %v", err)}, nil
|
||||
}
|
||||
|
||||
return tool.Result{
|
||||
Output: fmt.Sprintf("Wrote %d bytes to %s", len(a.Content), a.Path),
|
||||
Metadata: map[string]any{"bytes_written": len(a.Content), "path": a.Path},
|
||||
Output: fmt.Sprintf("Wrote %d bytes to %s", len(a.Content), path),
|
||||
Metadata: map[string]any{"bytes_written": len(a.Content), "path": path},
|
||||
}, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user