diff --git a/internal/hook/payload.go b/internal/hook/payload.go new file mode 100644 index 0000000..27a40a4 --- /dev/null +++ b/internal/hook/payload.go @@ -0,0 +1,152 @@ +package hook + +import ( + "encoding/json" + "fmt" +) + +// MarshalPreToolPayload builds the stdin payload for a PreToolUse hook. +func MarshalPreToolPayload(tool string, args json.RawMessage) []byte { + b, _ := json.Marshal(map[string]any{ + "event": "pre_tool_use", + "tool": tool, + "args": args, + }) + return b +} + +// MarshalPostToolPayload builds the stdin payload for a PostToolUse hook. +func MarshalPostToolPayload(tool string, args json.RawMessage, output string, metadata map[string]any) []byte { + b, _ := json.Marshal(map[string]any{ + "event": "post_tool_use", + "tool": tool, + "args": args, + "result": map[string]any{ + "output": output, + "metadata": metadata, + }, + }) + return b +} + +// MarshalSessionStartPayload builds the stdin payload for a SessionStart hook. +func MarshalSessionStartPayload(sessionID, mode string) []byte { + b, _ := json.Marshal(map[string]any{ + "event": "session_start", + "session_id": sessionID, + "mode": mode, + }) + return b +} + +// MarshalSessionEndPayload builds the stdin payload for a SessionEnd hook. +func MarshalSessionEndPayload(sessionID string, turns int) []byte { + b, _ := json.Marshal(map[string]any{ + "event": "session_end", + "session_id": sessionID, + "turns": turns, + }) + return b +} + +// MarshalPreCompactPayload builds the stdin payload for a PreCompact hook. +func MarshalPreCompactPayload(messageCount, tokenEstimate int) []byte { + b, _ := json.Marshal(map[string]any{ + "event": "pre_compact", + "message_count": messageCount, + "token_estimate": tokenEstimate, + }) + return b +} + +// MarshalStopPayload builds the stdin payload for a Stop hook. +func MarshalStopPayload(reason string) []byte { + b, _ := json.Marshal(map[string]any{ + "event": "stop", + "reason": reason, + }) + return b +} + +// ExtractToolName extracts the "tool" field from a hook payload. +// Returns "" for non-tool events or malformed payloads. +func ExtractToolName(payload []byte) string { + var v struct { + Tool string `json:"tool"` + } + if err := json.Unmarshal(payload, &v); err != nil { + return "" + } + return v.Tool +} + +// hookOutput is the JSON structure a hook may write to stdout. +type hookOutput struct { + Action string `json:"action"` + Transformed json.RawMessage `json:"transformed"` +} + +// ParseHookOutput parses hook stdout and exit code into an Action and optional +// transformed payload. JSON "action" field overrides the exit code when present. +// Empty stdout falls back to exit code alone. +func ParseHookOutput(stdout []byte, exitCode int) (Action, json.RawMessage, error) { + if len(stdout) == 0 { + action, err := ParseAction(exitCode) + return action, nil, err + } + + var out hookOutput + if err := json.Unmarshal(stdout, &out); err != nil { + return 0, nil, fmt.Errorf("hook: invalid stdout JSON: %w", err) + } + + var action Action + if out.Action != "" { + var err error + action, err = parseActionString(out.Action) + if err != nil { + return 0, nil, err + } + } else { + var err error + action, err = ParseAction(exitCode) + if err != nil { + return 0, nil, err + } + } + + var transformed json.RawMessage + if len(out.Transformed) > 0 { + transformed = out.Transformed + } + return action, transformed, nil +} + +// parseActionString maps a JSON "action" string to an Action. +func parseActionString(s string) (Action, error) { + switch s { + case "allow": + return Allow, nil + case "deny": + return Deny, nil + case "skip": + return Skip, nil + default: + return 0, fmt.Errorf("hook: unknown action string %q", s) + } +} + +// ExtractTransformedOutput extracts the "output" string from a PostToolUse +// transformed payload. Returns "" if the payload is nil or malformed. +func ExtractTransformedOutput(transformed json.RawMessage) string { + if transformed == nil { + return "" + } + var v struct { + Output string `json:"output"` + } + if err := json.Unmarshal(transformed, &v); err != nil { + return "" + } + return v.Output +} diff --git a/internal/hook/payload_test.go b/internal/hook/payload_test.go new file mode 100644 index 0000000..845da88 --- /dev/null +++ b/internal/hook/payload_test.go @@ -0,0 +1,231 @@ +package hook + +import ( + "encoding/json" + "testing" +) + +func TestMarshalPreToolPayload(t *testing.T) { + args := json.RawMessage(`{"command":"ls -la"}`) + payload := MarshalPreToolPayload("bash", args) + + var got map[string]any + if err := json.Unmarshal(payload, &got); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + if got["event"] != "pre_tool_use" { + t.Errorf("event = %q, want %q", got["event"], "pre_tool_use") + } + if got["tool"] != "bash" { + t.Errorf("tool = %q, want %q", got["tool"], "bash") + } + if got["args"] == nil { + t.Error("args field missing") + } +} + +func TestMarshalPostToolPayload(t *testing.T) { + args := json.RawMessage(`{"command":"ls"}`) + payload := MarshalPostToolPayload("bash", args, "file1\nfile2", nil) + + var got map[string]any + if err := json.Unmarshal(payload, &got); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + if got["event"] != "post_tool_use" { + t.Errorf("event = %q, want %q", got["event"], "post_tool_use") + } + if got["tool"] != "bash" { + t.Errorf("tool = %q, want %q", got["tool"], "bash") + } + result, ok := got["result"].(map[string]any) + if !ok { + t.Fatal("result field missing or wrong type") + } + if result["output"] != "file1\nfile2" { + t.Errorf("result.output = %q, want %q", result["output"], "file1\nfile2") + } +} + +func TestMarshalSessionStartPayload(t *testing.T) { + payload := MarshalSessionStartPayload("abc-123", "tui") + + var got map[string]any + if err := json.Unmarshal(payload, &got); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + if got["event"] != "session_start" { + t.Errorf("event = %q, want %q", got["event"], "session_start") + } + if got["session_id"] != "abc-123" { + t.Errorf("session_id = %q, want %q", got["session_id"], "abc-123") + } + if got["mode"] != "tui" { + t.Errorf("mode = %q, want %q", got["mode"], "tui") + } +} + +func TestMarshalSessionEndPayload(t *testing.T) { + payload := MarshalSessionEndPayload("abc-123", 42) + + var got map[string]any + if err := json.Unmarshal(payload, &got); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + if got["event"] != "session_end" { + t.Errorf("event = %q, want %q", got["event"], "session_end") + } + if got["session_id"] != "abc-123" { + t.Errorf("session_id = %q, want %q", got["session_id"], "abc-123") + } + if int(got["turns"].(float64)) != 42 { + t.Errorf("turns = %v, want 42", got["turns"]) + } +} + +func TestMarshalPreCompactPayload(t *testing.T) { + payload := MarshalPreCompactPayload(87, 120000) + + var got map[string]any + if err := json.Unmarshal(payload, &got); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + if got["event"] != "pre_compact" { + t.Errorf("event = %q, want %q", got["event"], "pre_compact") + } + if int(got["message_count"].(float64)) != 87 { + t.Errorf("message_count = %v, want 87", got["message_count"]) + } + if int(got["token_estimate"].(float64)) != 120000 { + t.Errorf("token_estimate = %v, want 120000", got["token_estimate"]) + } +} + +func TestMarshalStopPayload(t *testing.T) { + payload := MarshalStopPayload("max_turns") + + var got map[string]any + if err := json.Unmarshal(payload, &got); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + if got["event"] != "stop" { + t.Errorf("event = %q, want %q", got["event"], "stop") + } + if got["reason"] != "max_turns" { + t.Errorf("reason = %q, want %q", got["reason"], "max_turns") + } +} + +func TestExtractToolName(t *testing.T) { + tests := []struct { + name string + payload []byte + want string + }{ + { + "pre_tool_use payload", + []byte(`{"event":"pre_tool_use","tool":"bash","args":{}}`), + "bash", + }, + { + "post_tool_use payload", + []byte(`{"event":"post_tool_use","tool":"fs.read","args":{}}`), + "fs.read", + }, + { + "session_start has no tool", + []byte(`{"event":"session_start","session_id":"x","mode":"tui"}`), + "", + }, + { + "empty payload", + []byte(`{}`), + "", + }, + { + "malformed JSON", + []byte(`not json`), + "", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := ExtractToolName(tt.payload); got != tt.want { + t.Errorf("ExtractToolName() = %q, want %q", got, tt.want) + } + }) + } +} + +func TestParseHookOutput_JSONActionOverridesExitCode(t *testing.T) { + // stdout says deny, exit code 0 — JSON wins + stdout := []byte(`{"action":"deny","transformed":{"command":"safe"}}`) + action, transformed, err := ParseHookOutput(stdout, 0) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if action != Deny { + t.Errorf("action = %v, want Deny", action) + } + if transformed == nil { + t.Error("transformed should not be nil") + } +} + +func TestParseHookOutput_EmptyStdoutFallsBackToExitCode(t *testing.T) { + tests := []struct { + exitCode int + want Action + }{ + {0, Allow}, + {1, Skip}, + {2, Deny}, + } + for _, tt := range tests { + action, transformed, err := ParseHookOutput(nil, tt.exitCode) + if err != nil { + t.Errorf("exit %d: unexpected error: %v", tt.exitCode, err) + continue + } + if action != tt.want { + t.Errorf("exit %d: action = %v, want %v", tt.exitCode, action, tt.want) + } + if transformed != nil { + t.Errorf("exit %d: expected nil transformed", tt.exitCode) + } + } +} + +func TestParseHookOutput_MalformedJSON(t *testing.T) { + // non-empty stdout that isn't valid JSON falls back to exit code + _, _, err := ParseHookOutput([]byte("not json"), 0) + if err == nil { + t.Error("expected error for malformed JSON stdout") + } +} + +func TestParseHookOutput_AllowString(t *testing.T) { + stdout := []byte(`{"action":"allow"}`) + action, _, err := ParseHookOutput(stdout, 2) // exit 2 but JSON says allow + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if action != Allow { + t.Errorf("action = %v, want Allow", action) + } +} + +func TestExtractTransformedOutput(t *testing.T) { + transformed := json.RawMessage(`{"output":"rewritten result","metadata":{"key":"val"}}`) + got := ExtractTransformedOutput(transformed) + if got != "rewritten result" { + t.Errorf("ExtractTransformedOutput() = %q, want %q", got, "rewritten result") + } +} + +func TestExtractTransformedOutput_Empty(t *testing.T) { + got := ExtractTransformedOutput(nil) + if got != "" { + t.Errorf("ExtractTransformedOutput(nil) = %q, want %q", got, "") + } +}