feat: PromptExecutor — LLM-based hook evaluation via router

This commit is contained in:
2026-04-07 00:53:53 +02:00
parent 7d0b9c222f
commit 1aa1d83e9e
2 changed files with 358 additions and 0 deletions

135
internal/hook/prompt.go Normal file
View File

@@ -0,0 +1,135 @@
package hook
import (
"bytes"
"context"
"encoding/json"
"fmt"
"strings"
"text/template"
"time"
"somegit.dev/Owlibou/gnoma/internal/message"
"somegit.dev/Owlibou/gnoma/internal/stream"
)
// TemplateData holds the variables available in hook prompt templates.
type TemplateData struct {
Event string
Tool string
Args string
Result string
}
// renderTemplate executes a text/template with the given data.
func renderTemplate(tmpl string, data TemplateData) (string, error) {
t, err := template.New("hook").Parse(tmpl)
if err != nil {
return "", fmt.Errorf("hook: template parse error: %w", err)
}
var buf bytes.Buffer
if err := t.Execute(&buf, data); err != nil {
return "", fmt.Errorf("hook: template execute error: %w", err)
}
return buf.String(), nil
}
// parseDecision scans text for the first case-insensitive occurrence of
// "ALLOW" or "DENY". Returns Skip if neither is found.
func parseDecision(text string) Action {
upper := strings.ToUpper(text)
ai := strings.Index(upper, "ALLOW")
di := strings.Index(upper, "DENY")
switch {
case ai >= 0 && (di < 0 || ai < di):
return Allow
case di >= 0:
return Deny
default:
return Skip
}
}
// Streamer is the minimal interface PromptExecutor needs from the router.
// *router.Router satisfies this interface via an adapter in main.go.
type Streamer interface {
Stream(ctx context.Context, prompt string) (stream.Stream, error)
}
// PromptExecutor sends a templated prompt to an LLM and parses ALLOW/DENY
// from the response.
type PromptExecutor struct {
def HookDef
streamer Streamer
}
// NewPromptExecutor constructs a PromptExecutor.
func NewPromptExecutor(def HookDef, streamer Streamer) *PromptExecutor {
return &PromptExecutor{def: def, streamer: streamer}
}
// Execute renders the template, sends the prompt, and parses the response.
func (p *PromptExecutor) Execute(ctx context.Context, payload []byte) (HookResult, error) {
data := templateDataFromPayload(payload, p.def.Event)
prompt, err := renderTemplate(p.def.Exec, data)
if err != nil {
return HookResult{}, fmt.Errorf("hook %q: %w", p.def.Name, err)
}
start := time.Now()
s, err := p.streamer.Stream(ctx, prompt)
if err != nil {
return HookResult{}, fmt.Errorf("hook %q: stream error: %w", p.def.Name, err)
}
defer s.Close()
acc := stream.NewAccumulator()
var stopReason message.StopReason
var model string
for s.Next() {
evt := s.Current()
acc.Apply(evt)
if evt.StopReason != "" {
stopReason = evt.StopReason
model = evt.Model
}
}
if err := s.Err(); err != nil {
return HookResult{}, fmt.Errorf("hook %q: stream error: %w", p.def.Name, err)
}
resp := acc.Response(stopReason, model)
text := resp.Message.TextContent()
action := parseDecision(text)
return HookResult{
Action: action,
Output: []byte(text),
Duration: time.Since(start),
}, nil
}
// templateDataFromPayload builds TemplateData from a hook payload.
func templateDataFromPayload(payload []byte, event EventType) TemplateData {
data := TemplateData{Event: event.String()}
if event == PreToolUse || event == PostToolUse {
data.Tool = ExtractToolName(payload)
data.Args = extractRawField(payload, "args")
data.Result = extractRawField(payload, "result")
}
return data
}
// extractRawField returns the JSON-encoded value of a top-level field.
// Returns "" if absent or on error.
func extractRawField(payload []byte, field string) string {
var v map[string]json.RawMessage
if err := json.Unmarshal(payload, &v); err != nil {
return ""
}
raw, ok := v[field]
if !ok {
return ""
}
return string(raw)
}

View File

@@ -0,0 +1,223 @@
package hook
import (
"context"
"errors"
"testing"
"somegit.dev/Owlibou/gnoma/internal/message"
"somegit.dev/Owlibou/gnoma/internal/stream"
)
// mockStreamer implements Streamer by returning a pre-built stream or an error.
type mockStreamer struct {
s stream.Stream
err error
}
func (m *mockStreamer) Stream(_ context.Context, prompt string) (stream.Stream, error) {
return m.s, m.err
}
// textStream returns a single-event stream with the given text.
func textStream(text string) stream.Stream {
events := []stream.Event{
{Type: stream.EventTextDelta, Text: text},
{Type: stream.EventTextDelta, StopReason: message.StopEndTurn},
}
return &sliceStream{events: events}
}
type sliceStream struct {
events []stream.Event
idx int
}
func (s *sliceStream) Next() bool { s.idx++; return s.idx <= len(s.events) }
func (s *sliceStream) Current() stream.Event { return s.events[s.idx-1] }
func (s *sliceStream) Err() error { return nil }
func (s *sliceStream) Close() error { return nil }
// --- Template rendering tests ---
func TestRenderTemplate_EventVar(t *testing.T) {
got, err := renderTemplate("event is {{.Event}}", TemplateData{Event: "pre_tool_use"})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got != "event is pre_tool_use" {
t.Errorf("got %q", got)
}
}
func TestRenderTemplate_AllVars(t *testing.T) {
tmpl := "{{.Event}} {{.Tool}} {{.Args}} {{.Result}}"
data := TemplateData{Event: "pre_tool_use", Tool: "bash", Args: `{"cmd":"ls"}`, Result: ""}
got, err := renderTemplate(tmpl, data)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got != `pre_tool_use bash {"cmd":"ls"} ` {
t.Errorf("got %q", got)
}
}
func TestRenderTemplate_NonToolEvent_EmptyToolFields(t *testing.T) {
tmpl := "[{{.Tool}}][{{.Args}}][{{.Result}}]"
data := TemplateData{Event: "session_start"} // Tool/Args/Result are zero values
got, err := renderTemplate(tmpl, data)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got != "[][][]" {
t.Errorf("got %q", got)
}
}
func TestRenderTemplate_InvalidTemplate(t *testing.T) {
_, err := renderTemplate("{{.Unknown field", TemplateData{})
if err == nil {
t.Error("expected error for invalid template")
}
}
// --- parseDecision tests ---
func TestParseDecision_ALLOW(t *testing.T) {
if got := parseDecision("The action is ALLOW."); got != Allow {
t.Errorf("got %v, want Allow", got)
}
}
func TestParseDecision_DENY(t *testing.T) {
if got := parseDecision("I must DENY this request."); got != Deny {
t.Errorf("got %v, want Deny", got)
}
}
func TestParseDecision_NoMatch(t *testing.T) {
if got := parseDecision("I don't know."); got != Skip {
t.Errorf("got %v, want Skip", got)
}
}
func TestParseDecision_CaseInsensitive(t *testing.T) {
cases := []struct {
text string
want Action
}{
{"allow", Allow},
{"Allow", Allow},
{"ALLOW", Allow},
{"deny", Deny},
{"Deny", Deny},
{"DENY", Deny},
}
for _, tt := range cases {
if got := parseDecision(tt.text); got != tt.want {
t.Errorf("parseDecision(%q) = %v, want %v", tt.text, got, tt.want)
}
}
}
func TestParseDecision_FirstMatchWins(t *testing.T) {
// "DENY" appears before "ALLOW" → Deny
if got := parseDecision("I will DENY this, not ALLOW."); got != Deny {
t.Errorf("got %v, want Deny (first match)", got)
}
}
// --- PromptExecutor tests ---
func TestPromptExecutor_ResponseALLOW(t *testing.T) {
def := HookDef{Name: "test", Event: PreToolUse, Command: CommandTypePrompt, Exec: "Is this safe? ALLOW or DENY."}
ex := NewPromptExecutor(def, &mockStreamer{s: textStream("This is safe. ALLOW.")})
result, err := ex.Execute(context.Background(), MarshalPreToolPayload("bash", nil))
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.Action != Allow {
t.Errorf("action = %v, want Allow", result.Action)
}
}
func TestPromptExecutor_ResponseDENY(t *testing.T) {
def := HookDef{Name: "test", Event: PreToolUse, Command: CommandTypePrompt, Exec: "Is this safe? ALLOW or DENY."}
ex := NewPromptExecutor(def, &mockStreamer{s: textStream("This is dangerous. DENY.")})
result, err := ex.Execute(context.Background(), MarshalPreToolPayload("bash", nil))
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.Action != Deny {
t.Errorf("action = %v, want Deny", result.Action)
}
}
func TestPromptExecutor_ResponseNoMatch_Skip(t *testing.T) {
def := HookDef{Name: "test", Event: PreToolUse, Command: CommandTypePrompt, Exec: "Review this."}
ex := NewPromptExecutor(def, &mockStreamer{s: textStream("I'm not sure what to do.")})
result, err := ex.Execute(context.Background(), MarshalPreToolPayload("bash", nil))
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.Action != Skip {
t.Errorf("action = %v, want Skip", result.Action)
}
}
func TestPromptExecutor_TemplateRendered(t *testing.T) {
// Verify template vars are substituted — use a streamer that captures the prompt.
var capturedPrompt string
capturingStreamer := &capturingStreamer{response: "ALLOW"}
def := HookDef{
Name: "test",
Event: PreToolUse,
Command: CommandTypePrompt,
Exec: "Tool={{.Tool}} Event={{.Event}}",
}
ex := NewPromptExecutor(def, capturingStreamer)
ex.Execute(context.Background(), MarshalPreToolPayload("bash", nil))
capturedPrompt = capturingStreamer.prompt
if capturedPrompt == "" {
t.Fatal("prompt not captured")
}
if capturedPrompt != "Tool=bash Event=pre_tool_use" {
t.Errorf("prompt = %q", capturedPrompt)
}
}
func TestPromptExecutor_OutputIsFullResponse(t *testing.T) {
def := HookDef{Name: "test", Event: PreToolUse, Command: CommandTypePrompt, Exec: "Review."}
response := "After analysis, ALLOW this operation."
ex := NewPromptExecutor(def, &mockStreamer{s: textStream(response)})
result, _ := ex.Execute(context.Background(), MarshalPreToolPayload("bash", nil))
if result.Error != nil {
t.Fatalf("unexpected error: %v", result.Error)
}
// Output field carries the full LLM response text (for observability)
if string(result.Output) != response {
t.Errorf("Output = %q, want %q", result.Output, response)
}
}
func TestPromptExecutor_StreamerError(t *testing.T) {
def := HookDef{Name: "test", Event: PreToolUse, Command: CommandTypePrompt, Exec: "Review."}
ex := NewPromptExecutor(def, &mockStreamer{err: errors.New("provider unavailable")})
result, err := ex.Execute(context.Background(), MarshalPreToolPayload("bash", nil))
if err == nil {
t.Fatal("expected error")
}
// fail_open=false (default) → Deny on error; but error is returned, caller (Dispatcher) applies policy
_ = result
}
// capturingStreamer records the prompt it was called with.
type capturingStreamer struct {
prompt string
response string
}
func (c *capturingStreamer) Stream(_ context.Context, prompt string) (stream.Stream, error) {
c.prompt = prompt
return textStream(c.response), nil
}