feat: PromptExecutor — LLM-based hook evaluation via router
This commit is contained in:
135
internal/hook/prompt.go
Normal file
135
internal/hook/prompt.go
Normal file
@@ -0,0 +1,135 @@
|
||||
package hook
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"text/template"
|
||||
"time"
|
||||
|
||||
"somegit.dev/Owlibou/gnoma/internal/message"
|
||||
"somegit.dev/Owlibou/gnoma/internal/stream"
|
||||
)
|
||||
|
||||
// TemplateData holds the variables available in hook prompt templates.
|
||||
type TemplateData struct {
|
||||
Event string
|
||||
Tool string
|
||||
Args string
|
||||
Result string
|
||||
}
|
||||
|
||||
// renderTemplate executes a text/template with the given data.
|
||||
func renderTemplate(tmpl string, data TemplateData) (string, error) {
|
||||
t, err := template.New("hook").Parse(tmpl)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("hook: template parse error: %w", err)
|
||||
}
|
||||
var buf bytes.Buffer
|
||||
if err := t.Execute(&buf, data); err != nil {
|
||||
return "", fmt.Errorf("hook: template execute error: %w", err)
|
||||
}
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
// parseDecision scans text for the first case-insensitive occurrence of
|
||||
// "ALLOW" or "DENY". Returns Skip if neither is found.
|
||||
func parseDecision(text string) Action {
|
||||
upper := strings.ToUpper(text)
|
||||
ai := strings.Index(upper, "ALLOW")
|
||||
di := strings.Index(upper, "DENY")
|
||||
switch {
|
||||
case ai >= 0 && (di < 0 || ai < di):
|
||||
return Allow
|
||||
case di >= 0:
|
||||
return Deny
|
||||
default:
|
||||
return Skip
|
||||
}
|
||||
}
|
||||
|
||||
// Streamer is the minimal interface PromptExecutor needs from the router.
|
||||
// *router.Router satisfies this interface via an adapter in main.go.
|
||||
type Streamer interface {
|
||||
Stream(ctx context.Context, prompt string) (stream.Stream, error)
|
||||
}
|
||||
|
||||
// PromptExecutor sends a templated prompt to an LLM and parses ALLOW/DENY
|
||||
// from the response.
|
||||
type PromptExecutor struct {
|
||||
def HookDef
|
||||
streamer Streamer
|
||||
}
|
||||
|
||||
// NewPromptExecutor constructs a PromptExecutor.
|
||||
func NewPromptExecutor(def HookDef, streamer Streamer) *PromptExecutor {
|
||||
return &PromptExecutor{def: def, streamer: streamer}
|
||||
}
|
||||
|
||||
// Execute renders the template, sends the prompt, and parses the response.
|
||||
func (p *PromptExecutor) Execute(ctx context.Context, payload []byte) (HookResult, error) {
|
||||
data := templateDataFromPayload(payload, p.def.Event)
|
||||
prompt, err := renderTemplate(p.def.Exec, data)
|
||||
if err != nil {
|
||||
return HookResult{}, fmt.Errorf("hook %q: %w", p.def.Name, err)
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
s, err := p.streamer.Stream(ctx, prompt)
|
||||
if err != nil {
|
||||
return HookResult{}, fmt.Errorf("hook %q: stream error: %w", p.def.Name, err)
|
||||
}
|
||||
defer s.Close()
|
||||
|
||||
acc := stream.NewAccumulator()
|
||||
var stopReason message.StopReason
|
||||
var model string
|
||||
for s.Next() {
|
||||
evt := s.Current()
|
||||
acc.Apply(evt)
|
||||
if evt.StopReason != "" {
|
||||
stopReason = evt.StopReason
|
||||
model = evt.Model
|
||||
}
|
||||
}
|
||||
if err := s.Err(); err != nil {
|
||||
return HookResult{}, fmt.Errorf("hook %q: stream error: %w", p.def.Name, err)
|
||||
}
|
||||
|
||||
resp := acc.Response(stopReason, model)
|
||||
text := resp.Message.TextContent()
|
||||
|
||||
action := parseDecision(text)
|
||||
return HookResult{
|
||||
Action: action,
|
||||
Output: []byte(text),
|
||||
Duration: time.Since(start),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// templateDataFromPayload builds TemplateData from a hook payload.
|
||||
func templateDataFromPayload(payload []byte, event EventType) TemplateData {
|
||||
data := TemplateData{Event: event.String()}
|
||||
if event == PreToolUse || event == PostToolUse {
|
||||
data.Tool = ExtractToolName(payload)
|
||||
data.Args = extractRawField(payload, "args")
|
||||
data.Result = extractRawField(payload, "result")
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
// extractRawField returns the JSON-encoded value of a top-level field.
|
||||
// Returns "" if absent or on error.
|
||||
func extractRawField(payload []byte, field string) string {
|
||||
var v map[string]json.RawMessage
|
||||
if err := json.Unmarshal(payload, &v); err != nil {
|
||||
return ""
|
||||
}
|
||||
raw, ok := v[field]
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return string(raw)
|
||||
}
|
||||
223
internal/hook/prompt_test.go
Normal file
223
internal/hook/prompt_test.go
Normal file
@@ -0,0 +1,223 @@
|
||||
package hook
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"somegit.dev/Owlibou/gnoma/internal/message"
|
||||
"somegit.dev/Owlibou/gnoma/internal/stream"
|
||||
)
|
||||
|
||||
// mockStreamer implements Streamer by returning a pre-built stream or an error.
|
||||
type mockStreamer struct {
|
||||
s stream.Stream
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *mockStreamer) Stream(_ context.Context, prompt string) (stream.Stream, error) {
|
||||
return m.s, m.err
|
||||
}
|
||||
|
||||
// textStream returns a single-event stream with the given text.
|
||||
func textStream(text string) stream.Stream {
|
||||
events := []stream.Event{
|
||||
{Type: stream.EventTextDelta, Text: text},
|
||||
{Type: stream.EventTextDelta, StopReason: message.StopEndTurn},
|
||||
}
|
||||
return &sliceStream{events: events}
|
||||
}
|
||||
|
||||
type sliceStream struct {
|
||||
events []stream.Event
|
||||
idx int
|
||||
}
|
||||
|
||||
func (s *sliceStream) Next() bool { s.idx++; return s.idx <= len(s.events) }
|
||||
func (s *sliceStream) Current() stream.Event { return s.events[s.idx-1] }
|
||||
func (s *sliceStream) Err() error { return nil }
|
||||
func (s *sliceStream) Close() error { return nil }
|
||||
|
||||
// --- Template rendering tests ---
|
||||
|
||||
func TestRenderTemplate_EventVar(t *testing.T) {
|
||||
got, err := renderTemplate("event is {{.Event}}", TemplateData{Event: "pre_tool_use"})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got != "event is pre_tool_use" {
|
||||
t.Errorf("got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRenderTemplate_AllVars(t *testing.T) {
|
||||
tmpl := "{{.Event}} {{.Tool}} {{.Args}} {{.Result}}"
|
||||
data := TemplateData{Event: "pre_tool_use", Tool: "bash", Args: `{"cmd":"ls"}`, Result: ""}
|
||||
got, err := renderTemplate(tmpl, data)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got != `pre_tool_use bash {"cmd":"ls"} ` {
|
||||
t.Errorf("got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRenderTemplate_NonToolEvent_EmptyToolFields(t *testing.T) {
|
||||
tmpl := "[{{.Tool}}][{{.Args}}][{{.Result}}]"
|
||||
data := TemplateData{Event: "session_start"} // Tool/Args/Result are zero values
|
||||
got, err := renderTemplate(tmpl, data)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got != "[][][]" {
|
||||
t.Errorf("got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRenderTemplate_InvalidTemplate(t *testing.T) {
|
||||
_, err := renderTemplate("{{.Unknown field", TemplateData{})
|
||||
if err == nil {
|
||||
t.Error("expected error for invalid template")
|
||||
}
|
||||
}
|
||||
|
||||
// --- parseDecision tests ---
|
||||
|
||||
func TestParseDecision_ALLOW(t *testing.T) {
|
||||
if got := parseDecision("The action is ALLOW."); got != Allow {
|
||||
t.Errorf("got %v, want Allow", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseDecision_DENY(t *testing.T) {
|
||||
if got := parseDecision("I must DENY this request."); got != Deny {
|
||||
t.Errorf("got %v, want Deny", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseDecision_NoMatch(t *testing.T) {
|
||||
if got := parseDecision("I don't know."); got != Skip {
|
||||
t.Errorf("got %v, want Skip", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseDecision_CaseInsensitive(t *testing.T) {
|
||||
cases := []struct {
|
||||
text string
|
||||
want Action
|
||||
}{
|
||||
{"allow", Allow},
|
||||
{"Allow", Allow},
|
||||
{"ALLOW", Allow},
|
||||
{"deny", Deny},
|
||||
{"Deny", Deny},
|
||||
{"DENY", Deny},
|
||||
}
|
||||
for _, tt := range cases {
|
||||
if got := parseDecision(tt.text); got != tt.want {
|
||||
t.Errorf("parseDecision(%q) = %v, want %v", tt.text, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseDecision_FirstMatchWins(t *testing.T) {
|
||||
// "DENY" appears before "ALLOW" → Deny
|
||||
if got := parseDecision("I will DENY this, not ALLOW."); got != Deny {
|
||||
t.Errorf("got %v, want Deny (first match)", got)
|
||||
}
|
||||
}
|
||||
|
||||
// --- PromptExecutor tests ---
|
||||
|
||||
func TestPromptExecutor_ResponseALLOW(t *testing.T) {
|
||||
def := HookDef{Name: "test", Event: PreToolUse, Command: CommandTypePrompt, Exec: "Is this safe? ALLOW or DENY."}
|
||||
ex := NewPromptExecutor(def, &mockStreamer{s: textStream("This is safe. ALLOW.")})
|
||||
result, err := ex.Execute(context.Background(), MarshalPreToolPayload("bash", nil))
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if result.Action != Allow {
|
||||
t.Errorf("action = %v, want Allow", result.Action)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPromptExecutor_ResponseDENY(t *testing.T) {
|
||||
def := HookDef{Name: "test", Event: PreToolUse, Command: CommandTypePrompt, Exec: "Is this safe? ALLOW or DENY."}
|
||||
ex := NewPromptExecutor(def, &mockStreamer{s: textStream("This is dangerous. DENY.")})
|
||||
result, err := ex.Execute(context.Background(), MarshalPreToolPayload("bash", nil))
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if result.Action != Deny {
|
||||
t.Errorf("action = %v, want Deny", result.Action)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPromptExecutor_ResponseNoMatch_Skip(t *testing.T) {
|
||||
def := HookDef{Name: "test", Event: PreToolUse, Command: CommandTypePrompt, Exec: "Review this."}
|
||||
ex := NewPromptExecutor(def, &mockStreamer{s: textStream("I'm not sure what to do.")})
|
||||
result, err := ex.Execute(context.Background(), MarshalPreToolPayload("bash", nil))
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if result.Action != Skip {
|
||||
t.Errorf("action = %v, want Skip", result.Action)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPromptExecutor_TemplateRendered(t *testing.T) {
|
||||
// Verify template vars are substituted — use a streamer that captures the prompt.
|
||||
var capturedPrompt string
|
||||
capturingStreamer := &capturingStreamer{response: "ALLOW"}
|
||||
def := HookDef{
|
||||
Name: "test",
|
||||
Event: PreToolUse,
|
||||
Command: CommandTypePrompt,
|
||||
Exec: "Tool={{.Tool}} Event={{.Event}}",
|
||||
}
|
||||
ex := NewPromptExecutor(def, capturingStreamer)
|
||||
ex.Execute(context.Background(), MarshalPreToolPayload("bash", nil))
|
||||
capturedPrompt = capturingStreamer.prompt
|
||||
if capturedPrompt == "" {
|
||||
t.Fatal("prompt not captured")
|
||||
}
|
||||
if capturedPrompt != "Tool=bash Event=pre_tool_use" {
|
||||
t.Errorf("prompt = %q", capturedPrompt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPromptExecutor_OutputIsFullResponse(t *testing.T) {
|
||||
def := HookDef{Name: "test", Event: PreToolUse, Command: CommandTypePrompt, Exec: "Review."}
|
||||
response := "After analysis, ALLOW this operation."
|
||||
ex := NewPromptExecutor(def, &mockStreamer{s: textStream(response)})
|
||||
result, _ := ex.Execute(context.Background(), MarshalPreToolPayload("bash", nil))
|
||||
if result.Error != nil {
|
||||
t.Fatalf("unexpected error: %v", result.Error)
|
||||
}
|
||||
// Output field carries the full LLM response text (for observability)
|
||||
if string(result.Output) != response {
|
||||
t.Errorf("Output = %q, want %q", result.Output, response)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPromptExecutor_StreamerError(t *testing.T) {
|
||||
def := HookDef{Name: "test", Event: PreToolUse, Command: CommandTypePrompt, Exec: "Review."}
|
||||
ex := NewPromptExecutor(def, &mockStreamer{err: errors.New("provider unavailable")})
|
||||
result, err := ex.Execute(context.Background(), MarshalPreToolPayload("bash", nil))
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
// fail_open=false (default) → Deny on error; but error is returned, caller (Dispatcher) applies policy
|
||||
_ = result
|
||||
}
|
||||
|
||||
// capturingStreamer records the prompt it was called with.
|
||||
type capturingStreamer struct {
|
||||
prompt string
|
||||
response string
|
||||
}
|
||||
|
||||
func (c *capturingStreamer) Stream(_ context.Context, prompt string) (stream.Stream, error) {
|
||||
c.prompt = prompt
|
||||
return textStream(c.response), nil
|
||||
}
|
||||
Reference in New Issue
Block a user