feat: engine hook integration — PreToolUse, PostToolUse, Stop
This commit is contained in:
@@ -12,6 +12,7 @@ import (
|
||||
"somegit.dev/Owlibou/gnoma/internal/router"
|
||||
"somegit.dev/Owlibou/gnoma/internal/security"
|
||||
"somegit.dev/Owlibou/gnoma/internal/tool"
|
||||
"somegit.dev/Owlibou/gnoma/internal/hook"
|
||||
"somegit.dev/Owlibou/gnoma/internal/tool/persist"
|
||||
)
|
||||
|
||||
@@ -27,6 +28,7 @@ type Config struct {
|
||||
Model string // override model (empty = provider default)
|
||||
MaxTurns int // safety limit on tool loops (0 = unlimited)
|
||||
Store *persist.Store // nil = no result persistence
|
||||
Hooks *hook.Dispatcher // nil = no hooks
|
||||
Logger *slog.Logger
|
||||
}
|
||||
|
||||
|
||||
346
internal/engine/hook_integration_test.go
Normal file
346
internal/engine/hook_integration_test.go
Normal file
@@ -0,0 +1,346 @@
|
||||
package engine
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"somegit.dev/Owlibou/gnoma/internal/hook"
|
||||
"somegit.dev/Owlibou/gnoma/internal/message"
|
||||
"somegit.dev/Owlibou/gnoma/internal/stream"
|
||||
"somegit.dev/Owlibou/gnoma/internal/tool"
|
||||
)
|
||||
|
||||
// --- test executors ---
|
||||
|
||||
type blockingExecutor struct{}
|
||||
|
||||
func (b *blockingExecutor) Execute(_ context.Context, _ []byte) (hook.HookResult, error) {
|
||||
return hook.HookResult{Action: hook.Deny}, nil
|
||||
}
|
||||
|
||||
type allowingExecutor struct{}
|
||||
|
||||
func (a *allowingExecutor) Execute(_ context.Context, _ []byte) (hook.HookResult, error) {
|
||||
return hook.HookResult{Action: hook.Allow}, nil
|
||||
}
|
||||
|
||||
// argTransformExecutor replaces the "args" field in the payload.
|
||||
type argTransformExecutor struct{ newArgs json.RawMessage }
|
||||
|
||||
func (t *argTransformExecutor) Execute(_ context.Context, payload []byte) (hook.HookResult, error) {
|
||||
out, _ := json.Marshal(map[string]any{
|
||||
"tool": hook.ExtractToolName(payload),
|
||||
"args": t.newArgs,
|
||||
})
|
||||
return hook.HookResult{Action: hook.Allow, Output: out}, nil
|
||||
}
|
||||
|
||||
// resultTransformExecutor replaces the tool output.
|
||||
type resultTransformExecutor struct{ newOutput string }
|
||||
|
||||
func (r *resultTransformExecutor) Execute(_ context.Context, _ []byte) (hook.HookResult, error) {
|
||||
out, _ := json.Marshal(map[string]any{"output": r.newOutput})
|
||||
return hook.HookResult{Action: hook.Allow, Output: out}, nil
|
||||
}
|
||||
|
||||
// recordingExecutor records whether it was called and the payload.
|
||||
type recordingExecutor struct {
|
||||
called bool
|
||||
payload []byte
|
||||
}
|
||||
|
||||
func (r *recordingExecutor) Execute(_ context.Context, payload []byte) (hook.HookResult, error) {
|
||||
r.called = true
|
||||
r.payload = append([]byte(nil), payload...)
|
||||
return hook.HookResult{Action: hook.Allow}, nil
|
||||
}
|
||||
|
||||
// --- helpers ---
|
||||
|
||||
func hookDispatcher(event hook.EventType, ex hook.Executor) *hook.Dispatcher {
|
||||
def := hook.HookDef{Name: "test", Event: event, Command: hook.CommandTypeShell, Exec: "x"}
|
||||
d := &hook.Dispatcher{}
|
||||
d.SetChain(event, []hook.Handler{hook.NewHandler(def, ex)})
|
||||
return d
|
||||
}
|
||||
|
||||
// toolCallStream builds a stream that emits a single tool call then stops.
|
||||
func toolCallStream(callID, toolName, args string, stopReason message.StopReason, model string) stream.Stream {
|
||||
events := []stream.Event{
|
||||
{Type: stream.EventToolCallDone, ToolCallID: callID, ToolCallName: toolName, Args: json.RawMessage(args)},
|
||||
{Type: stream.EventTextDelta, StopReason: stopReason, Model: model},
|
||||
}
|
||||
return &eventStream{events: events}
|
||||
}
|
||||
|
||||
// --- tests ---
|
||||
|
||||
func TestHook_NilDispatcher_NoChange(t *testing.T) {
|
||||
mp := &mockProvider{
|
||||
streams: []stream.Stream{
|
||||
newEventStream(message.StopEndTurn, "m",
|
||||
stream.Event{Type: stream.EventTextDelta, Text: "hello"},
|
||||
),
|
||||
},
|
||||
}
|
||||
eng, err := New(Config{Provider: mp, Tools: tool.NewRegistry()})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
turn, err := eng.Submit(context.Background(), "hi", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if turn.Rounds != 1 {
|
||||
t.Errorf("rounds = %d, want 1", turn.Rounds)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHook_PreToolUse_Deny(t *testing.T) {
|
||||
executed := false
|
||||
reg := tool.NewRegistry()
|
||||
reg.Register(&mockTool{
|
||||
name: "bash",
|
||||
execFn: func(_ context.Context, _ json.RawMessage) (tool.Result, error) {
|
||||
executed = true
|
||||
return tool.Result{Output: "should not run"}, nil
|
||||
},
|
||||
})
|
||||
|
||||
mp := &mockProvider{
|
||||
streams: []stream.Stream{
|
||||
toolCallStream("c1", "bash", `{"command":"rm -rf /"}`, message.StopToolUse, "m"),
|
||||
newEventStream(message.StopEndTurn, "m",
|
||||
stream.Event{Type: stream.EventTextDelta, Text: "ok"},
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
eng, _ := New(Config{
|
||||
Provider: mp,
|
||||
Tools: reg,
|
||||
Hooks: hookDispatcher(hook.PreToolUse, &blockingExecutor{}),
|
||||
})
|
||||
eng.Submit(context.Background(), "run", nil)
|
||||
|
||||
if executed {
|
||||
t.Error("tool was executed despite PreToolUse deny")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHook_PreToolUse_Allow(t *testing.T) {
|
||||
executed := false
|
||||
reg := tool.NewRegistry()
|
||||
reg.Register(&mockTool{
|
||||
name: "bash",
|
||||
execFn: func(_ context.Context, _ json.RawMessage) (tool.Result, error) {
|
||||
executed = true
|
||||
return tool.Result{Output: "ran"}, nil
|
||||
},
|
||||
})
|
||||
|
||||
mp := &mockProvider{
|
||||
streams: []stream.Stream{
|
||||
toolCallStream("c1", "bash", `{}`, message.StopToolUse, "m"),
|
||||
newEventStream(message.StopEndTurn, "m",
|
||||
stream.Event{Type: stream.EventTextDelta, Text: "ok"},
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
eng, _ := New(Config{
|
||||
Provider: mp,
|
||||
Tools: reg,
|
||||
Hooks: hookDispatcher(hook.PreToolUse, &allowingExecutor{}),
|
||||
})
|
||||
eng.Submit(context.Background(), "run", nil)
|
||||
|
||||
if !executed {
|
||||
t.Error("tool was not executed despite PreToolUse allow")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHook_PreToolUse_DenyMessage(t *testing.T) {
|
||||
reg := tool.NewRegistry()
|
||||
reg.Register(&mockTool{
|
||||
name: "bash",
|
||||
execFn: func(_ context.Context, _ json.RawMessage) (tool.Result, error) {
|
||||
return tool.Result{Output: "should not run"}, nil
|
||||
},
|
||||
})
|
||||
|
||||
mp := &mockProvider{
|
||||
streams: []stream.Stream{
|
||||
toolCallStream("c1", "bash", `{}`, message.StopToolUse, "m"),
|
||||
newEventStream(message.StopEndTurn, "m",
|
||||
stream.Event{Type: stream.EventTextDelta, Text: "ok"},
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
eng, _ := New(Config{
|
||||
Provider: mp,
|
||||
Tools: reg,
|
||||
Hooks: hookDispatcher(hook.PreToolUse, &blockingExecutor{}),
|
||||
})
|
||||
eng.Submit(context.Background(), "run", nil)
|
||||
|
||||
for _, msg := range eng.History() {
|
||||
for _, c := range msg.Content {
|
||||
if c.Type == message.ContentToolResult && c.ToolResult != nil {
|
||||
if !strings.HasPrefix(c.ToolResult.Content, "denied by hook") {
|
||||
t.Errorf("denied result = %q, want prefix 'denied by hook'", c.ToolResult.Content)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
t.Error("no tool result found in history")
|
||||
}
|
||||
|
||||
func TestHook_PreToolUse_Transform(t *testing.T) {
|
||||
var receivedArgs json.RawMessage
|
||||
reg := tool.NewRegistry()
|
||||
reg.Register(&mockTool{
|
||||
name: "bash",
|
||||
execFn: func(_ context.Context, args json.RawMessage) (tool.Result, error) {
|
||||
receivedArgs = args
|
||||
return tool.Result{Output: "ok"}, nil
|
||||
},
|
||||
})
|
||||
|
||||
mp := &mockProvider{
|
||||
streams: []stream.Stream{
|
||||
toolCallStream("c1", "bash", `{"command":"original"}`, message.StopToolUse, "m"),
|
||||
newEventStream(message.StopEndTurn, "m",
|
||||
stream.Event{Type: stream.EventTextDelta, Text: "done"},
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
eng, _ := New(Config{
|
||||
Provider: mp,
|
||||
Tools: reg,
|
||||
Hooks: hookDispatcher(hook.PreToolUse,
|
||||
&argTransformExecutor{newArgs: json.RawMessage(`{"command":"safe-replacement"}`)}),
|
||||
})
|
||||
eng.Submit(context.Background(), "run", nil)
|
||||
|
||||
var got map[string]string
|
||||
json.Unmarshal(receivedArgs, &got)
|
||||
if got["command"] != "safe-replacement" {
|
||||
t.Errorf("tool args = %s, want safe-replacement", receivedArgs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHook_PostToolUse_Transform(t *testing.T) {
|
||||
reg := tool.NewRegistry()
|
||||
reg.Register(&mockTool{
|
||||
name: "bash",
|
||||
execFn: func(_ context.Context, _ json.RawMessage) (tool.Result, error) {
|
||||
return tool.Result{Output: "original output"}, nil
|
||||
},
|
||||
})
|
||||
|
||||
mp := &mockProvider{
|
||||
streams: []stream.Stream{
|
||||
toolCallStream("c1", "bash", `{}`, message.StopToolUse, "m"),
|
||||
newEventStream(message.StopEndTurn, "m",
|
||||
stream.Event{Type: stream.EventTextDelta, Text: "done"},
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
eng, _ := New(Config{
|
||||
Provider: mp,
|
||||
Tools: reg,
|
||||
Hooks: hookDispatcher(hook.PostToolUse,
|
||||
&resultTransformExecutor{newOutput: "transformed output"}),
|
||||
})
|
||||
eng.Submit(context.Background(), "run", nil)
|
||||
|
||||
for _, msg := range eng.History() {
|
||||
for _, c := range msg.Content {
|
||||
if c.Type == message.ContentToolResult && c.ToolResult != nil {
|
||||
if c.ToolResult.Content != "transformed output" {
|
||||
t.Errorf("tool result = %q, want 'transformed output'", c.ToolResult.Content)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
t.Error("no tool result found in history")
|
||||
}
|
||||
|
||||
func TestHook_PostToolUse_DenyTreatedAsSkip(t *testing.T) {
|
||||
reg := tool.NewRegistry()
|
||||
reg.Register(&mockTool{
|
||||
name: "bash",
|
||||
execFn: func(_ context.Context, _ json.RawMessage) (tool.Result, error) {
|
||||
return tool.Result{Output: "tool ran"}, nil
|
||||
},
|
||||
})
|
||||
|
||||
mp := &mockProvider{
|
||||
streams: []stream.Stream{
|
||||
toolCallStream("c1", "bash", `{}`, message.StopToolUse, "m"),
|
||||
newEventStream(message.StopEndTurn, "m",
|
||||
stream.Event{Type: stream.EventTextDelta, Text: "done"},
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
eng, _ := New(Config{
|
||||
Provider: mp,
|
||||
Tools: reg,
|
||||
Hooks: hookDispatcher(hook.PostToolUse, &blockingExecutor{}),
|
||||
})
|
||||
turn, err := eng.Submit(context.Background(), "run", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// 2 rounds = tool call + end turn, confirming the result reached the LLM.
|
||||
if turn.Rounds != 2 {
|
||||
t.Errorf("rounds = %d, want 2 (result reached LLM despite PostToolUse deny)", turn.Rounds)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHook_Stop_MaxTurns(t *testing.T) {
|
||||
// Stop hook fires when MaxTurns is exceeded.
|
||||
stopRecorder := &recordingExecutor{}
|
||||
reg := tool.NewRegistry()
|
||||
reg.Register(&mockTool{
|
||||
name: "bash",
|
||||
execFn: func(_ context.Context, _ json.RawMessage) (tool.Result, error) {
|
||||
return tool.Result{Output: "ok"}, nil
|
||||
},
|
||||
})
|
||||
|
||||
mp := &mockProvider{
|
||||
streams: []stream.Stream{
|
||||
// Round 1: tool call → will loop to round 2
|
||||
toolCallStream("c1", "bash", `{}`, message.StopToolUse, "m"),
|
||||
// Round 2: MaxTurns=1 check triggers before this, so it's never consumed
|
||||
},
|
||||
}
|
||||
|
||||
d := &hook.Dispatcher{}
|
||||
d.SetChain(hook.Stop, []hook.Handler{
|
||||
hook.NewHandler(
|
||||
hook.HookDef{Name: "stop-rec", Event: hook.Stop, Command: hook.CommandTypeShell, Exec: "x"},
|
||||
stopRecorder,
|
||||
),
|
||||
})
|
||||
|
||||
eng, _ := New(Config{Provider: mp, Tools: reg, Hooks: d, MaxTurns: 1})
|
||||
_, err := eng.Submit(context.Background(), "run", nil)
|
||||
// MaxTurns exceeded returns an error
|
||||
if err == nil {
|
||||
t.Fatal("expected error for MaxTurns exceeded")
|
||||
}
|
||||
if !stopRecorder.called {
|
||||
t.Error("Stop hook was not fired on MaxTurns exceeded")
|
||||
}
|
||||
}
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"time"
|
||||
|
||||
gnomactx "somegit.dev/Owlibou/gnoma/internal/context"
|
||||
"somegit.dev/Owlibou/gnoma/internal/hook"
|
||||
"somegit.dev/Owlibou/gnoma/internal/message"
|
||||
"somegit.dev/Owlibou/gnoma/internal/permission"
|
||||
"somegit.dev/Owlibou/gnoma/internal/provider"
|
||||
@@ -55,6 +56,7 @@ func (e *Engine) runLoop(ctx context.Context, cb Callback) (*Turn, error) {
|
||||
for {
|
||||
turn.Rounds++
|
||||
if e.cfg.MaxTurns > 0 && turn.Rounds > e.cfg.MaxTurns {
|
||||
e.cfg.Hooks.Fire(hook.Stop, hook.MarshalStopPayload("max_turns")) //nolint:errcheck
|
||||
return turn, fmt.Errorf("safety limit: %d rounds exceeded", e.cfg.MaxTurns)
|
||||
}
|
||||
|
||||
@@ -227,6 +229,7 @@ func (e *Engine) runLoop(ctx context.Context, cb Callback) (*Turn, error) {
|
||||
// Decide next action
|
||||
switch resp.StopReason {
|
||||
case message.StopEndTurn, message.StopSequence:
|
||||
e.cfg.Hooks.Fire(hook.Stop, hook.MarshalStopPayload("end_turn")) //nolint:errcheck
|
||||
return turn, nil
|
||||
|
||||
case message.StopMaxTokens:
|
||||
@@ -254,6 +257,7 @@ func (e *Engine) runLoop(ctx context.Context, cb Callback) (*Turn, error) {
|
||||
|
||||
default:
|
||||
// Unknown stop reason or empty — treat as end of turn
|
||||
e.cfg.Hooks.Fire(hook.Stop, hook.MarshalStopPayload("unknown")) //nolint:errcheck
|
||||
return turn, nil
|
||||
}
|
||||
}
|
||||
@@ -411,9 +415,26 @@ func (e *Engine) executeSingleTool(ctx context.Context, call message.ToolCall, t
|
||||
}
|
||||
}
|
||||
|
||||
// PreToolUse hook: can deny execution or transform args.
|
||||
args := call.Arguments
|
||||
if e.cfg.Hooks != nil {
|
||||
payload := hook.MarshalPreToolPayload(call.Name, args)
|
||||
transformed, action, _ := e.cfg.Hooks.Fire(hook.PreToolUse, payload)
|
||||
if action == hook.Deny {
|
||||
return message.ToolResult{
|
||||
ToolCallID: call.ID,
|
||||
Content: "denied by hook",
|
||||
IsError: true,
|
||||
}
|
||||
}
|
||||
if newArgs := hook.ExtractTransformedArgs(transformed); newArgs != nil {
|
||||
args = newArgs
|
||||
}
|
||||
}
|
||||
|
||||
e.logger.Debug("executing tool", "name", call.Name, "id", call.ID)
|
||||
|
||||
result, err := t.Execute(ctx, call.Arguments)
|
||||
result, err := t.Execute(ctx, args)
|
||||
if err != nil {
|
||||
e.logger.Error("tool execution failed", "name", call.Name, "error", err)
|
||||
return message.ToolResult{
|
||||
@@ -423,8 +444,17 @@ func (e *Engine) executeSingleTool(ctx context.Context, call message.ToolCall, t
|
||||
}
|
||||
}
|
||||
|
||||
// Scan tool result through firewall
|
||||
// PostToolUse hook: can transform result (Deny treated as Skip).
|
||||
output := result.Output
|
||||
if e.cfg.Hooks != nil {
|
||||
payload := hook.MarshalPostToolPayload(call.Name, args, output, result.Metadata)
|
||||
transformed, _, _ := e.cfg.Hooks.Fire(hook.PostToolUse, payload)
|
||||
if s := hook.ExtractTransformedOutput(transformed); s != "" {
|
||||
output = s
|
||||
}
|
||||
}
|
||||
|
||||
// Scan tool result through firewall
|
||||
if e.cfg.Firewall != nil {
|
||||
output = e.cfg.Firewall.ScanToolResult(output)
|
||||
}
|
||||
|
||||
@@ -4,25 +4,21 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"somegit.dev/Owlibou/gnoma/internal/elf"
|
||||
"somegit.dev/Owlibou/gnoma/internal/router"
|
||||
)
|
||||
|
||||
// ElfSpawner is the minimal interface AgentExecutor needs from elf.Manager.
|
||||
type ElfSpawner interface {
|
||||
Spawn(ctx context.Context, taskType router.TaskType, prompt, systemPrompt string, maxTurns int) (elf.Elf, error)
|
||||
}
|
||||
// ElfSpawnFn spawns an elf with the given prompt and returns its output text.
|
||||
// This is satisfied by a closure wrapping elf.Manager.Spawn in main.go.
|
||||
type ElfSpawnFn func(ctx context.Context, prompt string) (output string, err error)
|
||||
|
||||
// AgentExecutor spawns an elf and parses ALLOW/DENY from its output.
|
||||
type AgentExecutor struct {
|
||||
def HookDef
|
||||
spawner ElfSpawner
|
||||
spawnFn ElfSpawnFn
|
||||
}
|
||||
|
||||
// NewAgentExecutor constructs an AgentExecutor.
|
||||
func NewAgentExecutor(def HookDef, spawner ElfSpawner) *AgentExecutor {
|
||||
return &AgentExecutor{def: def, spawner: spawner}
|
||||
func NewAgentExecutor(def HookDef, spawnFn ElfSpawnFn) *AgentExecutor {
|
||||
return &AgentExecutor{def: def, spawnFn: spawnFn}
|
||||
}
|
||||
|
||||
// Execute renders the hook template, spawns an elf, waits for its result,
|
||||
@@ -35,19 +31,13 @@ func (a *AgentExecutor) Execute(ctx context.Context, payload []byte) (HookResult
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
e, err := a.spawner.Spawn(ctx, router.TaskReview, prompt, "", 5)
|
||||
if err != nil {
|
||||
return HookResult{}, fmt.Errorf("hook %q: spawn elf: %w", a.def.Name, err)
|
||||
}
|
||||
|
||||
result := e.Wait()
|
||||
output, err := a.spawnFn(ctx, prompt)
|
||||
duration := time.Since(start)
|
||||
|
||||
if result.Error != nil {
|
||||
return HookResult{Duration: duration}, fmt.Errorf("hook %q: elf failed: %w", a.def.Name, result.Error)
|
||||
if err != nil {
|
||||
return HookResult{Duration: duration}, fmt.Errorf("hook %q: elf failed: %w", a.def.Name, err)
|
||||
}
|
||||
|
||||
action := parseDecision(result.Output)
|
||||
action := parseDecision(output)
|
||||
return HookResult{
|
||||
Action: action,
|
||||
Duration: duration,
|
||||
|
||||
@@ -5,45 +5,32 @@ import (
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"somegit.dev/Owlibou/gnoma/internal/elf"
|
||||
"somegit.dev/Owlibou/gnoma/internal/router"
|
||||
"somegit.dev/Owlibou/gnoma/internal/stream"
|
||||
)
|
||||
|
||||
// mockElfSpawner satisfies ElfSpawner. Records calls and returns configurable results.
|
||||
type mockElfSpawner struct {
|
||||
result elf.Result
|
||||
err error
|
||||
// Captures
|
||||
lastPrompt string
|
||||
lastTask router.TaskType
|
||||
func spawnFnOK(output string) ElfSpawnFn {
|
||||
return func(_ context.Context, _ string) (string, error) {
|
||||
return output, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockElfSpawner) Spawn(ctx context.Context, taskType router.TaskType, prompt, systemPrompt string, maxTurns int) (elf.Elf, error) {
|
||||
m.lastPrompt = prompt
|
||||
m.lastTask = taskType
|
||||
if m.err != nil {
|
||||
return nil, m.err
|
||||
func spawnFnErr(err error) ElfSpawnFn {
|
||||
return func(_ context.Context, _ string) (string, error) {
|
||||
return "", err
|
||||
}
|
||||
return &immediateElf{result: m.result}, nil
|
||||
}
|
||||
|
||||
// immediateElf returns a pre-computed result immediately.
|
||||
type immediateElf struct {
|
||||
result elf.Result
|
||||
func capturingSpawnFn(output string) (ElfSpawnFn, *string) {
|
||||
captured := new(string)
|
||||
fn := func(_ context.Context, prompt string) (string, error) {
|
||||
*captured = prompt
|
||||
return output, nil
|
||||
}
|
||||
return fn, captured
|
||||
}
|
||||
|
||||
func (e *immediateElf) ID() string { return "test-elf" }
|
||||
func (e *immediateElf) Status() elf.Status { return e.result.Status }
|
||||
func (e *immediateElf) Events() <-chan stream.Event { return nil }
|
||||
func (e *immediateElf) Wait() elf.Result { return e.result }
|
||||
func (e *immediateElf) Cancel() {}
|
||||
|
||||
func TestAgentExecutor_OutputALLOW(t *testing.T) {
|
||||
def := HookDef{Name: "test", Event: PreToolUse, Command: CommandTypeAgent, Exec: "Review this tool call."}
|
||||
spawner := &mockElfSpawner{result: elf.Result{Output: "After analysis, ALLOW this.", Status: elf.StatusCompleted}}
|
||||
ex := NewAgentExecutor(def, spawner)
|
||||
ex := NewAgentExecutor(def, spawnFnOK("After analysis, ALLOW this."))
|
||||
result, err := ex.Execute(context.Background(), MarshalPreToolPayload("bash", nil))
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
@@ -55,8 +42,7 @@ func TestAgentExecutor_OutputALLOW(t *testing.T) {
|
||||
|
||||
func TestAgentExecutor_OutputDENY(t *testing.T) {
|
||||
def := HookDef{Name: "test", Event: PreToolUse, Command: CommandTypeAgent, Exec: "Review this."}
|
||||
spawner := &mockElfSpawner{result: elf.Result{Output: "This is dangerous. DENY.", Status: elf.StatusCompleted}}
|
||||
ex := NewAgentExecutor(def, spawner)
|
||||
ex := NewAgentExecutor(def, spawnFnOK("This is dangerous. DENY."))
|
||||
result, err := ex.Execute(context.Background(), MarshalPreToolPayload("bash", nil))
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
@@ -68,8 +54,7 @@ func TestAgentExecutor_OutputDENY(t *testing.T) {
|
||||
|
||||
func TestAgentExecutor_OutputNoMatch_Skip(t *testing.T) {
|
||||
def := HookDef{Name: "test", Event: PreToolUse, Command: CommandTypeAgent, Exec: "Review this."}
|
||||
spawner := &mockElfSpawner{result: elf.Result{Output: "I'm unsure.", Status: elf.StatusCompleted}}
|
||||
ex := NewAgentExecutor(def, spawner)
|
||||
ex := NewAgentExecutor(def, spawnFnOK("I'm unsure."))
|
||||
result, err := ex.Execute(context.Background(), MarshalPreToolPayload("bash", nil))
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
@@ -79,24 +64,9 @@ func TestAgentExecutor_OutputNoMatch_Skip(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentExecutor_ElfFailure_Error(t *testing.T) {
|
||||
def := HookDef{Name: "test", Event: PreToolUse, Command: CommandTypeAgent, Exec: "Review."}
|
||||
spawner := &mockElfSpawner{result: elf.Result{
|
||||
Output: "",
|
||||
Status: elf.StatusFailed,
|
||||
Error: errors.New("elf crashed"),
|
||||
}}
|
||||
ex := NewAgentExecutor(def, spawner)
|
||||
_, err := ex.Execute(context.Background(), MarshalPreToolPayload("bash", nil))
|
||||
if err == nil {
|
||||
t.Error("expected error for failed elf")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentExecutor_SpawnError(t *testing.T) {
|
||||
def := HookDef{Name: "test", Event: PreToolUse, Command: CommandTypeAgent, Exec: "Review."}
|
||||
spawner := &mockElfSpawner{err: errors.New("no arms available")}
|
||||
ex := NewAgentExecutor(def, spawner)
|
||||
ex := NewAgentExecutor(def, spawnFnErr(errors.New("no arms available")))
|
||||
_, err := ex.Execute(context.Background(), MarshalPreToolPayload("bash", nil))
|
||||
if err == nil {
|
||||
t.Error("expected error when spawn fails")
|
||||
@@ -110,30 +80,23 @@ func TestAgentExecutor_TemplateRendered(t *testing.T) {
|
||||
Command: CommandTypeAgent,
|
||||
Exec: "Tool={{.Tool}} Event={{.Event}}",
|
||||
}
|
||||
spawner := &mockElfSpawner{result: elf.Result{Output: "ALLOW", Status: elf.StatusCompleted}}
|
||||
ex := NewAgentExecutor(def, spawner)
|
||||
fn, captured := capturingSpawnFn("ALLOW")
|
||||
ex := NewAgentExecutor(def, fn)
|
||||
ex.Execute(context.Background(), MarshalPreToolPayload("bash", nil))
|
||||
if spawner.lastPrompt != "Tool=bash Event=pre_tool_use" {
|
||||
t.Errorf("prompt = %q", spawner.lastPrompt)
|
||||
if *captured != "Tool=bash Event=pre_tool_use" {
|
||||
t.Errorf("prompt = %q", *captured)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentExecutor_Duration(t *testing.T) {
|
||||
def := HookDef{Name: "test", Event: PreToolUse, Command: CommandTypeAgent, Exec: "Review."}
|
||||
spawner := &mockElfSpawner{result: elf.Result{Output: "ALLOW", Status: elf.StatusCompleted, Duration: 100 * time.Millisecond}}
|
||||
ex := NewAgentExecutor(def, spawner)
|
||||
fn := func(_ context.Context, _ string) (string, error) {
|
||||
time.Sleep(1 * time.Millisecond)
|
||||
return "ALLOW", nil
|
||||
}
|
||||
ex := NewAgentExecutor(def, fn)
|
||||
result, _ := ex.Execute(context.Background(), MarshalPreToolPayload("bash", nil))
|
||||
if result.Duration <= 0 {
|
||||
t.Error("expected Duration > 0")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentExecutor_TaskTypeIsReview(t *testing.T) {
|
||||
def := HookDef{Name: "test", Event: PreToolUse, Command: CommandTypeAgent, Exec: "Review."}
|
||||
spawner := &mockElfSpawner{result: elf.Result{Output: "ALLOW", Status: elf.StatusCompleted}}
|
||||
ex := NewAgentExecutor(def, spawner)
|
||||
ex.Execute(context.Background(), MarshalPreToolPayload("bash", nil))
|
||||
if spawner.lastTask != router.TaskReview {
|
||||
t.Errorf("task type = %v, want TaskReview", spawner.lastTask)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,9 +13,23 @@ type Dispatcher struct {
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
// SetChain replaces the handler chain for an event. Primarily for testing.
|
||||
func (d *Dispatcher) SetChain(event EventType, handlers []Handler) {
|
||||
if d.chains == nil {
|
||||
d.chains = make(map[EventType][]Handler)
|
||||
}
|
||||
d.chains[event] = handlers
|
||||
}
|
||||
|
||||
// NewHandler constructs a Handler from a definition and executor.
|
||||
func NewHandler(def HookDef, ex Executor) Handler {
|
||||
return Handler{def: def, executor: ex}
|
||||
}
|
||||
|
||||
// NewDispatcher validates defs, constructs the appropriate executor per
|
||||
// CommandType, and groups handlers by EventType.
|
||||
func NewDispatcher(defs []HookDef, logger *slog.Logger, executorFn func(HookDef) (Executor, error)) (*Dispatcher, error) {
|
||||
// streamer and spawnFn may be nil if no prompt/agent hooks are configured.
|
||||
func NewDispatcher(defs []HookDef, streamer Streamer, spawnFn ElfSpawnFn, logger *slog.Logger) (*Dispatcher, error) {
|
||||
if logger == nil {
|
||||
logger = slog.Default()
|
||||
}
|
||||
@@ -27,7 +41,7 @@ func NewDispatcher(defs []HookDef, logger *slog.Logger, executorFn func(HookDef)
|
||||
if err := def.Validate(); err != nil {
|
||||
return nil, fmt.Errorf("hook.NewDispatcher: %w", err)
|
||||
}
|
||||
ex, err := executorFn(def)
|
||||
ex, err := buildExecutor(def, streamer, spawnFn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("hook.NewDispatcher: building executor for %q: %w", def.Name, err)
|
||||
}
|
||||
@@ -36,6 +50,26 @@ func NewDispatcher(defs []HookDef, logger *slog.Logger, executorFn func(HookDef)
|
||||
return d, nil
|
||||
}
|
||||
|
||||
// buildExecutor constructs the right Executor for a HookDef.
|
||||
func buildExecutor(def HookDef, streamer Streamer, spawnFn ElfSpawnFn) (Executor, error) {
|
||||
switch def.Command {
|
||||
case CommandTypeShell:
|
||||
return NewCommandExecutor(def), nil
|
||||
case CommandTypePrompt:
|
||||
if streamer == nil {
|
||||
return nil, fmt.Errorf("prompt hook %q requires a Streamer (no router configured)", def.Name)
|
||||
}
|
||||
return NewPromptExecutor(def, streamer), nil
|
||||
case CommandTypeAgent:
|
||||
if spawnFn == nil {
|
||||
return nil, fmt.Errorf("agent hook %q requires an ElfSpawnFn (no elf manager configured)", def.Name)
|
||||
}
|
||||
return NewAgentExecutor(def, spawnFn), nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown command type %v", def.Command)
|
||||
}
|
||||
}
|
||||
|
||||
// Fire runs all handlers registered for event, in order.
|
||||
// Returns the (possibly transformed) payload, the aggregate Action, and the first error.
|
||||
// Safe to call on a nil *Dispatcher — returns (payload, Allow, nil).
|
||||
|
||||
@@ -136,6 +136,21 @@ func parseActionString(s string) (Action, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// ExtractTransformedArgs extracts the "args" field from a transformed PreToolUse payload.
|
||||
// Returns nil if the field is absent or the payload is malformed.
|
||||
func ExtractTransformedArgs(payload []byte) json.RawMessage {
|
||||
if payload == nil {
|
||||
return nil
|
||||
}
|
||||
var v struct {
|
||||
Args json.RawMessage `json:"args"`
|
||||
}
|
||||
if err := json.Unmarshal(payload, &v); err != nil {
|
||||
return nil
|
||||
}
|
||||
return v.Args
|
||||
}
|
||||
|
||||
// ExtractTransformedOutput extracts the "output" string from a PostToolUse
|
||||
// transformed payload. Returns "" if the payload is nil or malformed.
|
||||
func ExtractTransformedOutput(transformed json.RawMessage) string {
|
||||
|
||||
Reference in New Issue
Block a user