Files
gnoma/internal/hook/prompt.go

136 lines
3.5 KiB
Go

package hook
import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"strings"
	"text/template"
	"time"
	"unicode"

	"somegit.dev/Owlibou/gnoma/internal/message"
	"somegit.dev/Owlibou/gnoma/internal/stream"
)
// TemplateData is the value a hook prompt template is executed against, so a
// template can reference {{.Event}}, {{.Tool}}, {{.Args}} and {{.Result}}.
type TemplateData struct {
	// Event is the name of the hook event that fired.
	Event string
	// Tool, Args and Result describe the tool call for tool-use events.
	// Args and Result carry raw JSON text; each is "" when not available.
	Tool, Args, Result string
}
// renderTemplate parses tmpl as a text/template named "hook" and executes it
// against data, returning the rendered text.
//
// text/template performs no escaping, so field values are inserted verbatim.
// Parse failures and execution failures (for example a reference to a field
// TemplateData does not have) are wrapped in distinct errors.
func renderTemplate(tmpl string, data TemplateData) (string, error) {
	parsed, parseErr := template.New("hook").Parse(tmpl)
	if parseErr != nil {
		return "", fmt.Errorf("hook: template parse error: %w", parseErr)
	}

	out := new(bytes.Buffer)
	if execErr := parsed.Execute(out, data); execErr != nil {
		return "", fmt.Errorf("hook: template execute error: %w", execErr)
	}
	return out.String(), nil
}
// parseDecision returns the Action named by the first word in text that
// begins with "ALLOW" or "DENY" (case-insensitive), or Skip if there is none.
//
// Matching is anchored at the start of a word. A plain substring search would
// treat words that merely contain a keyword ("disallow", "shallow",
// "swallow") as approval. Anchoring keeps a permission hook from failing open
// on such replies. Words that extend a keyword ("allowed", "denying") still
// match. A word is a maximal run of Unicode letters and digits, so
// punctuation or markup around the verdict ("**ALLOW**.") is ignored.
func parseDecision(text string) Action {
	notWordRune := func(r rune) bool {
		return !unicode.IsLetter(r) && !unicode.IsDigit(r)
	}
	// Words are visited in order, so the earliest keyword wins, as before.
	for _, word := range strings.FieldsFunc(strings.ToUpper(text), notWordRune) {
		switch {
		case strings.HasPrefix(word, "ALLOW"):
			return Allow
		case strings.HasPrefix(word, "DENY"):
			return Deny
		}
	}
	return Skip
}
// Streamer is the single router capability PromptExecutor depends on:
// opening a streamed completion for one prompt string. The production
// *router.Router is plugged in through an adapter in main.go rather than
// being passed in directly.
type Streamer interface {
	// Stream starts a completion for prompt and returns the event stream.
	Stream(ctx context.Context, prompt string) (stream.Stream, error)
}
// PromptExecutor is a hook executor backed by an LLM. It renders the hook's
// prompt template, sends the result through a Streamer, and reads an
// ALLOW/DENY verdict out of the reply.
type PromptExecutor struct {
	def      HookDef  // hook definition: Exec is the template; Name and Event are also read
	streamer Streamer // backend the rendered prompt is sent to
}
// NewPromptExecutor returns a PromptExecutor for def that sends its rendered
// prompts through streamer. Neither argument is validated here.
func NewPromptExecutor(def HookDef, streamer Streamer) *PromptExecutor {
	pe := &PromptExecutor{
		streamer: streamer,
		def:      def,
	}
	return pe
}
// Execute runs the prompt hook for one event payload.
//
// It builds template data from payload, renders the hook's Exec template,
// and streams the rendered prompt through the Streamer. The streamed events
// are folded into a single response, and the hook Action is derived from the
// response text via parseDecision. The raw response text is returned as
// Output. Duration is measured from just before the stream is opened, so
// template rendering is not included.
//
// Template, stream-open and mid-stream failures are returned with the hook
// name as a prefix, together with a zero HookResult.
func (p *PromptExecutor) Execute(ctx context.Context, payload []byte) (HookResult, error) {
	prompt, err := renderTemplate(p.def.Exec, templateDataFromPayload(payload, p.def.Event))
	if err != nil {
		return HookResult{}, fmt.Errorf("hook %q: %w", p.def.Name, err)
	}

	began := time.Now()
	st, err := p.streamer.Stream(ctx, prompt)
	if err != nil {
		return HookResult{}, fmt.Errorf("hook %q: stream error: %w", p.def.Name, err)
	}
	defer st.Close()

	var (
		finalStop  message.StopReason
		finalModel string
	)
	accum := stream.NewAccumulator()
	for st.Next() {
		ev := st.Current()
		accum.Apply(ev)
		// Only events carrying a stop reason update the recorded stop
		// reason and model; a later such event overwrites an earlier one.
		if ev.StopReason == "" {
			continue
		}
		finalStop, finalModel = ev.StopReason, ev.Model
	}
	if streamErr := st.Err(); streamErr != nil {
		return HookResult{}, fmt.Errorf("hook %q: stream error: %w", p.def.Name, streamErr)
	}

	// Kept in a variable so TextContent is called on an addressable value.
	full := accum.Response(finalStop, finalModel)
	reply := full.Message.TextContent()
	return HookResult{
		Action:   parseDecision(reply),
		Output:   []byte(reply),
		Duration: time.Since(began),
	}, nil
}
// templateDataFromPayload assembles the TemplateData for one hook invocation.
// Event is always set from event. Tool, Args and Result are filled only for
// PreToolUse and PostToolUse and stay empty for every other event. Args and
// Result are the raw JSON of the payload's top-level "args" and "result"
// members. Each is "" when the member is missing or the payload does not
// decode as a JSON object.
func templateDataFromPayload(payload []byte, event EventType) TemplateData {
	isToolEvent := event == PreToolUse || event == PostToolUse
	if !isToolEvent {
		return TemplateData{Event: event.String()}
	}
	return TemplateData{
		Event:  event.String(),
		Tool:   ExtractToolName(payload),
		Args:   extractRawField(payload, "args"),
		Result: extractRawField(payload, "result"),
	}
}
// extractRawField decodes payload as a JSON object and returns the raw,
// still-encoded JSON text of its top-level member named field. For example,
// it returns `{"a":1}` for an object value, or `"x"` (quotes included) for a
// string value. It returns "" when the member is missing or payload cannot
// be decoded into an object.
func extractRawField(payload []byte, field string) string {
	var members map[string]json.RawMessage
	if json.Unmarshal(payload, &members) != nil {
		return ""
	}
	// A missing key yields a nil RawMessage, which converts to "".
	return string(members[field])
}