Files
gnoma/internal/hook/payload_test.go

232 lines
6.0 KiB
Go

package hook
import (
"encoding/json"
"testing"
)
// TestMarshalPreToolPayload checks that the payload built for the "bash" tool
// decodes as a JSON object whose event is "pre_tool_use", whose tool is
// "bash", and whose args field is present and non-null.
func TestMarshalPreToolPayload(t *testing.T) {
	raw := MarshalPreToolPayload("bash", json.RawMessage(`{"command":"ls -la"}`))

	var decoded map[string]any
	err := json.Unmarshal(raw, &decoded)
	if err != nil {
		t.Fatalf("invalid JSON: %v", err)
	}

	if event := decoded["event"]; event != "pre_tool_use" {
		t.Errorf("event = %q, want %q", event, "pre_tool_use")
	}
	if tool := decoded["tool"]; tool != "bash" {
		t.Errorf("tool = %q, want %q", tool, "bash")
	}
	// A missing key and an explicit JSON null both decode to a nil value here.
	if argsField := decoded["args"]; argsField == nil {
		t.Error("args field missing")
	}
}
// TestMarshalPostToolPayload checks that the post-tool payload decodes as a
// JSON object with event "post_tool_use", tool "bash", and a nested result
// object whose output equals the output string passed in.
func TestMarshalPostToolPayload(t *testing.T) {
	raw := MarshalPostToolPayload("bash", json.RawMessage(`{"command":"ls"}`), "file1\nfile2", nil)

	var decoded map[string]any
	err := json.Unmarshal(raw, &decoded)
	if err != nil {
		t.Fatalf("invalid JSON: %v", err)
	}

	if event := decoded["event"]; event != "post_tool_use" {
		t.Errorf("event = %q, want %q", event, "post_tool_use")
	}
	if tool := decoded["tool"]; tool != "bash" {
		t.Errorf("tool = %q, want %q", tool, "bash")
	}

	// The result must be a JSON object; anything else stops the test here.
	resultObj, isObject := decoded["result"].(map[string]any)
	if !isObject {
		t.Fatal("result field missing or wrong type")
	}
	if output := resultObj["output"]; output != "file1\nfile2" {
		t.Errorf("result.output = %q, want %q", output, "file1\nfile2")
	}
}
// TestMarshalSessionStartPayload checks that the session-start payload decodes
// as a JSON object with event "session_start" and the session_id and mode
// values that were passed in.
func TestMarshalSessionStartPayload(t *testing.T) {
	var decoded map[string]any
	if err := json.Unmarshal(MarshalSessionStartPayload("abc-123", "tui"), &decoded); err != nil {
		t.Fatalf("invalid JSON: %v", err)
	}

	if event := decoded["event"]; event != "session_start" {
		t.Errorf("event = %q, want %q", event, "session_start")
	}
	if sessionID := decoded["session_id"]; sessionID != "abc-123" {
		t.Errorf("session_id = %q, want %q", sessionID, "abc-123")
	}
	if mode := decoded["mode"]; mode != "tui" {
		t.Errorf("mode = %q, want %q", mode, "tui")
	}
}
// TestMarshalSessionEndPayload checks that the session-end payload decodes as
// a JSON object with event "session_end", the session_id that was passed in,
// and a numeric turns field equal to 42.
func TestMarshalSessionEndPayload(t *testing.T) {
	payload := MarshalSessionEndPayload("abc-123", 42)
	var got map[string]any
	if err := json.Unmarshal(payload, &got); err != nil {
		t.Fatalf("invalid JSON: %v", err)
	}
	if got["event"] != "session_end" {
		t.Errorf("event = %q, want %q", got["event"], "session_end")
	}
	if got["session_id"] != "abc-123" {
		t.Errorf("session_id = %q, want %q", got["session_id"], "abc-123")
	}
	// encoding/json decodes JSON numbers into float64 when the target is any.
	// Use the comma-ok form so a missing or non-numeric field is reported as a
	// test failure instead of panicking the whole test binary.
	turns, ok := got["turns"].(float64)
	if !ok {
		t.Fatalf("turns = %v (%T), want a JSON number", got["turns"], got["turns"])
	}
	// Compare as float64: truncating to int would let a value such as 42.9 pass.
	if turns != 42 {
		t.Errorf("turns = %v, want 42", got["turns"])
	}
}
// TestMarshalPreCompactPayload checks that the pre-compact payload decodes as
// a JSON object with event "pre_compact" and numeric message_count and
// token_estimate fields equal to the two values that were passed in.
func TestMarshalPreCompactPayload(t *testing.T) {
	payload := MarshalPreCompactPayload(87, 120000)
	var got map[string]any
	if err := json.Unmarshal(payload, &got); err != nil {
		t.Fatalf("invalid JSON: %v", err)
	}
	if got["event"] != "pre_compact" {
		t.Errorf("event = %q, want %q", got["event"], "pre_compact")
	}
	// encoding/json decodes JSON numbers into float64 when the target is any.
	// The comma-ok assertions report a missing or non-numeric field as a test
	// failure instead of panicking, and the float comparison avoids the
	// truncation that int() would apply to a fractional value.
	if count, ok := got["message_count"].(float64); !ok {
		t.Errorf("message_count = %v (%T), want a JSON number", got["message_count"], got["message_count"])
	} else if count != 87 {
		t.Errorf("message_count = %v, want 87", got["message_count"])
	}
	if estimate, ok := got["token_estimate"].(float64); !ok {
		t.Errorf("token_estimate = %v (%T), want a JSON number", got["token_estimate"], got["token_estimate"])
	} else if estimate != 120000 {
		t.Errorf("token_estimate = %v, want 120000", got["token_estimate"])
	}
}
// TestMarshalStopPayload checks that the stop payload decodes as a JSON object
// with event "stop" and the reason string that was passed in.
func TestMarshalStopPayload(t *testing.T) {
	var decoded map[string]any
	if err := json.Unmarshal(MarshalStopPayload("max_turns"), &decoded); err != nil {
		t.Fatalf("invalid JSON: %v", err)
	}

	if event := decoded["event"]; event != "stop" {
		t.Errorf("event = %q, want %q", event, "stop")
	}
	if reason := decoded["reason"]; reason != "max_turns" {
		t.Errorf("reason = %q, want %q", reason, "max_turns")
	}
}
// TestExtractToolName runs ExtractToolName over a table of raw payloads and
// checks the returned name. Payloads with a "tool" key are expected to yield
// that value; payloads without one, an empty object, and non-JSON input are
// all expected to yield the empty string.
func TestExtractToolName(t *testing.T) {
	type testCase struct {
		name    string
		payload []byte
		want    string
	}

	cases := []testCase{
		{
			name:    "pre_tool_use payload",
			payload: []byte(`{"event":"pre_tool_use","tool":"bash","args":{}}`),
			want:    "bash",
		},
		{
			name:    "post_tool_use payload",
			payload: []byte(`{"event":"post_tool_use","tool":"fs.read","args":{}}`),
			want:    "fs.read",
		},
		{
			name:    "session_start has no tool",
			payload: []byte(`{"event":"session_start","session_id":"x","mode":"tui"}`),
			want:    "",
		},
		{
			name:    "empty payload",
			payload: []byte(`{}`),
			want:    "",
		},
		{
			name:    "malformed JSON",
			payload: []byte(`not json`),
			want:    "",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := ExtractToolName(tc.payload)
			if got == tc.want {
				return
			}
			t.Errorf("ExtractToolName() = %q, want %q", got, tc.want)
		})
	}
}
// TestParseHookOutput_JSONActionOverridesExitCode feeds ParseHookOutput a
// stdout document whose action is "deny" together with exit code 0, and
// expects the JSON action to win: no error, action Deny, and a non-nil
// transformed value.
func TestParseHookOutput_JSONActionOverridesExitCode(t *testing.T) {
	const exitCode = 0 // the exit code alone would not request a deny
	hookStdout := []byte(`{"action":"deny","transformed":{"command":"safe"}}`)

	action, transformed, err := ParseHookOutput(hookStdout, exitCode)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	if action != Deny {
		t.Errorf("action = %v, want Deny", action)
	}
	if transformed == nil {
		t.Error("transformed should not be nil")
	}
}
// TestParseHookOutput_EmptyStdoutFallsBackToExitCode calls ParseHookOutput
// with nil stdout and expects the action to be derived from the exit code
// alone: 0 gives Allow, 1 gives Skip, 2 gives Deny. Each case must return no
// error and a nil transformed value.
func TestParseHookOutput_EmptyStdoutFallsBackToExitCode(t *testing.T) {
	cases := []struct {
		exitCode int
		want     Action
	}{
		{exitCode: 0, want: Allow},
		{exitCode: 1, want: Skip},
		{exitCode: 2, want: Deny},
	}

	for i := range cases {
		tc := cases[i]

		action, transformed, err := ParseHookOutput(nil, tc.exitCode)
		if err != nil {
			// The other results are not inspected once an error is returned.
			t.Errorf("exit %d: unexpected error: %v", tc.exitCode, err)
			continue
		}

		if action != tc.want {
			t.Errorf("exit %d: action = %v, want %v", tc.exitCode, action, tc.want)
		}
		if transformed != nil {
			t.Errorf("exit %d: expected nil transformed", tc.exitCode)
		}
	}
}
// TestParseHookOutput_MalformedJSON passes non-empty stdout that is not valid
// JSON (with exit code 0) and expects ParseHookOutput to return a non-nil
// error. The returned action and transformed value are not inspected.
func TestParseHookOutput_MalformedJSON(t *testing.T) {
	if _, _, err := ParseHookOutput([]byte("not json"), 0); err == nil {
		t.Error("expected error for malformed JSON stdout")
	}
}
// TestParseHookOutput_AllowString passes a stdout document whose action is
// "allow" together with exit code 2, and expects no error and action Allow,
// so the JSON action takes precedence over the exit code.
func TestParseHookOutput_AllowString(t *testing.T) {
	const exitCode = 2 // deliberately disagrees with the JSON action
	hookStdout := []byte(`{"action":"allow"}`)

	action, _, err := ParseHookOutput(hookStdout, exitCode)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if action == Allow {
		return
	}
	t.Errorf("action = %v, want Allow", action)
}
// TestExtractTransformedOutput checks that ExtractTransformedOutput returns
// the value of the "output" key from a transformed JSON object that also
// carries an unrelated "metadata" key.
func TestExtractTransformedOutput(t *testing.T) {
	const want = "rewritten result"

	got := ExtractTransformedOutput(json.RawMessage(`{"output":"rewritten result","metadata":{"key":"val"}}`))
	if got == want {
		return
	}
	t.Errorf("ExtractTransformedOutput() = %q, want %q", got, want)
}
// TestExtractTransformedOutput_Empty checks that ExtractTransformedOutput
// returns the empty string when given nil.
func TestExtractTransformedOutput_Empty(t *testing.T) {
	if output := ExtractTransformedOutput(nil); output != "" {
		t.Errorf("ExtractTransformedOutput(nil) = %q, want %q", output, "")
	}
}