feat: hook payload marshal/unmarshal helpers
This commit is contained in:
152
internal/hook/payload.go
Normal file
152
internal/hook/payload.go
Normal file
@@ -0,0 +1,152 @@
|
||||
package hook
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// MarshalPreToolPayload builds the stdin payload for a PreToolUse hook.
|
||||
func MarshalPreToolPayload(tool string, args json.RawMessage) []byte {
|
||||
b, _ := json.Marshal(map[string]any{
|
||||
"event": "pre_tool_use",
|
||||
"tool": tool,
|
||||
"args": args,
|
||||
})
|
||||
return b
|
||||
}
|
||||
|
||||
// MarshalPostToolPayload builds the stdin payload for a PostToolUse hook.
|
||||
func MarshalPostToolPayload(tool string, args json.RawMessage, output string, metadata map[string]any) []byte {
|
||||
b, _ := json.Marshal(map[string]any{
|
||||
"event": "post_tool_use",
|
||||
"tool": tool,
|
||||
"args": args,
|
||||
"result": map[string]any{
|
||||
"output": output,
|
||||
"metadata": metadata,
|
||||
},
|
||||
})
|
||||
return b
|
||||
}
|
||||
|
||||
// MarshalSessionStartPayload builds the stdin payload for a SessionStart hook.
|
||||
func MarshalSessionStartPayload(sessionID, mode string) []byte {
|
||||
b, _ := json.Marshal(map[string]any{
|
||||
"event": "session_start",
|
||||
"session_id": sessionID,
|
||||
"mode": mode,
|
||||
})
|
||||
return b
|
||||
}
|
||||
|
||||
// MarshalSessionEndPayload builds the stdin payload for a SessionEnd hook.
|
||||
func MarshalSessionEndPayload(sessionID string, turns int) []byte {
|
||||
b, _ := json.Marshal(map[string]any{
|
||||
"event": "session_end",
|
||||
"session_id": sessionID,
|
||||
"turns": turns,
|
||||
})
|
||||
return b
|
||||
}
|
||||
|
||||
// MarshalPreCompactPayload builds the stdin payload for a PreCompact hook.
|
||||
func MarshalPreCompactPayload(messageCount, tokenEstimate int) []byte {
|
||||
b, _ := json.Marshal(map[string]any{
|
||||
"event": "pre_compact",
|
||||
"message_count": messageCount,
|
||||
"token_estimate": tokenEstimate,
|
||||
})
|
||||
return b
|
||||
}
|
||||
|
||||
// MarshalStopPayload builds the stdin payload for a Stop hook.
|
||||
func MarshalStopPayload(reason string) []byte {
|
||||
b, _ := json.Marshal(map[string]any{
|
||||
"event": "stop",
|
||||
"reason": reason,
|
||||
})
|
||||
return b
|
||||
}
|
||||
|
||||
// ExtractToolName extracts the "tool" field from a hook payload.
|
||||
// Returns "" for non-tool events or malformed payloads.
|
||||
func ExtractToolName(payload []byte) string {
|
||||
var v struct {
|
||||
Tool string `json:"tool"`
|
||||
}
|
||||
if err := json.Unmarshal(payload, &v); err != nil {
|
||||
return ""
|
||||
}
|
||||
return v.Tool
|
||||
}
|
||||
|
||||
// hookOutput is the JSON structure a hook may write to stdout.
|
||||
type hookOutput struct {
|
||||
Action string `json:"action"`
|
||||
Transformed json.RawMessage `json:"transformed"`
|
||||
}
|
||||
|
||||
// ParseHookOutput parses hook stdout and exit code into an Action and optional
|
||||
// transformed payload. JSON "action" field overrides the exit code when present.
|
||||
// Empty stdout falls back to exit code alone.
|
||||
func ParseHookOutput(stdout []byte, exitCode int) (Action, json.RawMessage, error) {
|
||||
if len(stdout) == 0 {
|
||||
action, err := ParseAction(exitCode)
|
||||
return action, nil, err
|
||||
}
|
||||
|
||||
var out hookOutput
|
||||
if err := json.Unmarshal(stdout, &out); err != nil {
|
||||
return 0, nil, fmt.Errorf("hook: invalid stdout JSON: %w", err)
|
||||
}
|
||||
|
||||
var action Action
|
||||
if out.Action != "" {
|
||||
var err error
|
||||
action, err = parseActionString(out.Action)
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
} else {
|
||||
var err error
|
||||
action, err = ParseAction(exitCode)
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
var transformed json.RawMessage
|
||||
if len(out.Transformed) > 0 {
|
||||
transformed = out.Transformed
|
||||
}
|
||||
return action, transformed, nil
|
||||
}
|
||||
|
||||
// parseActionString maps a JSON "action" string to an Action.
|
||||
func parseActionString(s string) (Action, error) {
|
||||
switch s {
|
||||
case "allow":
|
||||
return Allow, nil
|
||||
case "deny":
|
||||
return Deny, nil
|
||||
case "skip":
|
||||
return Skip, nil
|
||||
default:
|
||||
return 0, fmt.Errorf("hook: unknown action string %q", s)
|
||||
}
|
||||
}
|
||||
|
||||
// ExtractTransformedOutput extracts the "output" string from a PostToolUse
|
||||
// transformed payload. Returns "" if the payload is nil or malformed.
|
||||
func ExtractTransformedOutput(transformed json.RawMessage) string {
|
||||
if transformed == nil {
|
||||
return ""
|
||||
}
|
||||
var v struct {
|
||||
Output string `json:"output"`
|
||||
}
|
||||
if err := json.Unmarshal(transformed, &v); err != nil {
|
||||
return ""
|
||||
}
|
||||
return v.Output
|
||||
}
|
||||
231
internal/hook/payload_test.go
Normal file
231
internal/hook/payload_test.go
Normal file
@@ -0,0 +1,231 @@
|
||||
package hook
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestMarshalPreToolPayload(t *testing.T) {
|
||||
args := json.RawMessage(`{"command":"ls -la"}`)
|
||||
payload := MarshalPreToolPayload("bash", args)
|
||||
|
||||
var got map[string]any
|
||||
if err := json.Unmarshal(payload, &got); err != nil {
|
||||
t.Fatalf("invalid JSON: %v", err)
|
||||
}
|
||||
if got["event"] != "pre_tool_use" {
|
||||
t.Errorf("event = %q, want %q", got["event"], "pre_tool_use")
|
||||
}
|
||||
if got["tool"] != "bash" {
|
||||
t.Errorf("tool = %q, want %q", got["tool"], "bash")
|
||||
}
|
||||
if got["args"] == nil {
|
||||
t.Error("args field missing")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarshalPostToolPayload(t *testing.T) {
|
||||
args := json.RawMessage(`{"command":"ls"}`)
|
||||
payload := MarshalPostToolPayload("bash", args, "file1\nfile2", nil)
|
||||
|
||||
var got map[string]any
|
||||
if err := json.Unmarshal(payload, &got); err != nil {
|
||||
t.Fatalf("invalid JSON: %v", err)
|
||||
}
|
||||
if got["event"] != "post_tool_use" {
|
||||
t.Errorf("event = %q, want %q", got["event"], "post_tool_use")
|
||||
}
|
||||
if got["tool"] != "bash" {
|
||||
t.Errorf("tool = %q, want %q", got["tool"], "bash")
|
||||
}
|
||||
result, ok := got["result"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("result field missing or wrong type")
|
||||
}
|
||||
if result["output"] != "file1\nfile2" {
|
||||
t.Errorf("result.output = %q, want %q", result["output"], "file1\nfile2")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarshalSessionStartPayload(t *testing.T) {
|
||||
payload := MarshalSessionStartPayload("abc-123", "tui")
|
||||
|
||||
var got map[string]any
|
||||
if err := json.Unmarshal(payload, &got); err != nil {
|
||||
t.Fatalf("invalid JSON: %v", err)
|
||||
}
|
||||
if got["event"] != "session_start" {
|
||||
t.Errorf("event = %q, want %q", got["event"], "session_start")
|
||||
}
|
||||
if got["session_id"] != "abc-123" {
|
||||
t.Errorf("session_id = %q, want %q", got["session_id"], "abc-123")
|
||||
}
|
||||
if got["mode"] != "tui" {
|
||||
t.Errorf("mode = %q, want %q", got["mode"], "tui")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarshalSessionEndPayload(t *testing.T) {
|
||||
payload := MarshalSessionEndPayload("abc-123", 42)
|
||||
|
||||
var got map[string]any
|
||||
if err := json.Unmarshal(payload, &got); err != nil {
|
||||
t.Fatalf("invalid JSON: %v", err)
|
||||
}
|
||||
if got["event"] != "session_end" {
|
||||
t.Errorf("event = %q, want %q", got["event"], "session_end")
|
||||
}
|
||||
if got["session_id"] != "abc-123" {
|
||||
t.Errorf("session_id = %q, want %q", got["session_id"], "abc-123")
|
||||
}
|
||||
if int(got["turns"].(float64)) != 42 {
|
||||
t.Errorf("turns = %v, want 42", got["turns"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarshalPreCompactPayload(t *testing.T) {
|
||||
payload := MarshalPreCompactPayload(87, 120000)
|
||||
|
||||
var got map[string]any
|
||||
if err := json.Unmarshal(payload, &got); err != nil {
|
||||
t.Fatalf("invalid JSON: %v", err)
|
||||
}
|
||||
if got["event"] != "pre_compact" {
|
||||
t.Errorf("event = %q, want %q", got["event"], "pre_compact")
|
||||
}
|
||||
if int(got["message_count"].(float64)) != 87 {
|
||||
t.Errorf("message_count = %v, want 87", got["message_count"])
|
||||
}
|
||||
if int(got["token_estimate"].(float64)) != 120000 {
|
||||
t.Errorf("token_estimate = %v, want 120000", got["token_estimate"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarshalStopPayload(t *testing.T) {
|
||||
payload := MarshalStopPayload("max_turns")
|
||||
|
||||
var got map[string]any
|
||||
if err := json.Unmarshal(payload, &got); err != nil {
|
||||
t.Fatalf("invalid JSON: %v", err)
|
||||
}
|
||||
if got["event"] != "stop" {
|
||||
t.Errorf("event = %q, want %q", got["event"], "stop")
|
||||
}
|
||||
if got["reason"] != "max_turns" {
|
||||
t.Errorf("reason = %q, want %q", got["reason"], "max_turns")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractToolName(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
payload []byte
|
||||
want string
|
||||
}{
|
||||
{
|
||||
"pre_tool_use payload",
|
||||
[]byte(`{"event":"pre_tool_use","tool":"bash","args":{}}`),
|
||||
"bash",
|
||||
},
|
||||
{
|
||||
"post_tool_use payload",
|
||||
[]byte(`{"event":"post_tool_use","tool":"fs.read","args":{}}`),
|
||||
"fs.read",
|
||||
},
|
||||
{
|
||||
"session_start has no tool",
|
||||
[]byte(`{"event":"session_start","session_id":"x","mode":"tui"}`),
|
||||
"",
|
||||
},
|
||||
{
|
||||
"empty payload",
|
||||
[]byte(`{}`),
|
||||
"",
|
||||
},
|
||||
{
|
||||
"malformed JSON",
|
||||
[]byte(`not json`),
|
||||
"",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := ExtractToolName(tt.payload); got != tt.want {
|
||||
t.Errorf("ExtractToolName() = %q, want %q", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseHookOutput_JSONActionOverridesExitCode(t *testing.T) {
|
||||
// stdout says deny, exit code 0 — JSON wins
|
||||
stdout := []byte(`{"action":"deny","transformed":{"command":"safe"}}`)
|
||||
action, transformed, err := ParseHookOutput(stdout, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if action != Deny {
|
||||
t.Errorf("action = %v, want Deny", action)
|
||||
}
|
||||
if transformed == nil {
|
||||
t.Error("transformed should not be nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseHookOutput_EmptyStdoutFallsBackToExitCode(t *testing.T) {
|
||||
tests := []struct {
|
||||
exitCode int
|
||||
want Action
|
||||
}{
|
||||
{0, Allow},
|
||||
{1, Skip},
|
||||
{2, Deny},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
action, transformed, err := ParseHookOutput(nil, tt.exitCode)
|
||||
if err != nil {
|
||||
t.Errorf("exit %d: unexpected error: %v", tt.exitCode, err)
|
||||
continue
|
||||
}
|
||||
if action != tt.want {
|
||||
t.Errorf("exit %d: action = %v, want %v", tt.exitCode, action, tt.want)
|
||||
}
|
||||
if transformed != nil {
|
||||
t.Errorf("exit %d: expected nil transformed", tt.exitCode)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseHookOutput_MalformedJSON(t *testing.T) {
|
||||
// non-empty stdout that isn't valid JSON falls back to exit code
|
||||
_, _, err := ParseHookOutput([]byte("not json"), 0)
|
||||
if err == nil {
|
||||
t.Error("expected error for malformed JSON stdout")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseHookOutput_AllowString(t *testing.T) {
|
||||
stdout := []byte(`{"action":"allow"}`)
|
||||
action, _, err := ParseHookOutput(stdout, 2) // exit 2 but JSON says allow
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if action != Allow {
|
||||
t.Errorf("action = %v, want Allow", action)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractTransformedOutput(t *testing.T) {
|
||||
transformed := json.RawMessage(`{"output":"rewritten result","metadata":{"key":"val"}}`)
|
||||
got := ExtractTransformedOutput(transformed)
|
||||
if got != "rewritten result" {
|
||||
t.Errorf("ExtractTransformedOutput() = %q, want %q", got, "rewritten result")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractTransformedOutput_Empty(t *testing.T) {
|
||||
got := ExtractTransformedOutput(nil)
|
||||
if got != "" {
|
||||
t.Errorf("ExtractTransformedOutput(nil) = %q, want %q", got, "")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user