From ef2d058cc075415c6b96270c6daa408b66b03a20 Mon Sep 17 00:00:00 2001 From: vikingowl Date: Sun, 5 Apr 2026 23:31:25 +0200 Subject: [PATCH] feat: JSON serialization for Message and Content (session persistence blocker) Add custom MarshalJSON/UnmarshalJSON on Content using string type discriminant ("text", "tool_call", "tool_result", "thinking"). Add json tags to Message. --- internal/message/json.go | 88 ++++++++++++++++++ internal/message/json_test.go | 167 ++++++++++++++++++++++++++++++++++ internal/message/message.go | 4 +- 3 files changed, 257 insertions(+), 2 deletions(-) create mode 100644 internal/message/json.go create mode 100644 internal/message/json_test.go diff --git a/internal/message/json.go b/internal/message/json.go new file mode 100644 index 0000000..7dc2dfc --- /dev/null +++ b/internal/message/json.go @@ -0,0 +1,88 @@ +package message + +import ( + "encoding/json" + "fmt" +) + +// MarshalJSON encodes Content as {"type":"","":...}. +func (c Content) MarshalJSON() ([]byte, error) { + switch c.Type { + case ContentText: + return json.Marshal(struct { + Type string `json:"type"` + Text string `json:"text"` + }{Type: "text", Text: c.Text}) + case ContentToolCall: + return json.Marshal(struct { + Type string `json:"type"` + ToolCall *ToolCall `json:"tool_call"` + }{Type: "tool_call", ToolCall: c.ToolCall}) + case ContentToolResult: + return json.Marshal(struct { + Type string `json:"type"` + ToolResult *ToolResult `json:"tool_result"` + }{Type: "tool_result", ToolResult: c.ToolResult}) + case ContentThinking: + return json.Marshal(struct { + Type string `json:"type"` + Thinking *Thinking `json:"thinking"` + }{Type: "thinking", Thinking: c.Thinking}) + default: + return nil, fmt.Errorf("message: unknown ContentType %d", c.Type) + } +} + +// UnmarshalJSON decodes Content from {"type":"","":...}. +func (c *Content) UnmarshalJSON(data []byte) error { + // First pass: extract the type discriminant. + var disc struct { + Type string `json:"type"` + } + if err := json.Unmarshal(data, &disc); err != nil { + return fmt.Errorf("message: unmarshal content type: %w", err) + } + + // Second pass: decode the payload for the known type. + switch disc.Type { + case "text": + var v struct { + Text string `json:"text"` + } + if err := json.Unmarshal(data, &v); err != nil { + return err + } + c.Type = ContentText + c.Text = v.Text + case "tool_call": + var v struct { + ToolCall *ToolCall `json:"tool_call"` + } + if err := json.Unmarshal(data, &v); err != nil { + return err + } + c.Type = ContentToolCall + c.ToolCall = v.ToolCall + case "tool_result": + var v struct { + ToolResult *ToolResult `json:"tool_result"` + } + if err := json.Unmarshal(data, &v); err != nil { + return err + } + c.Type = ContentToolResult + c.ToolResult = v.ToolResult + case "thinking": + var v struct { + Thinking *Thinking `json:"thinking"` + } + if err := json.Unmarshal(data, &v); err != nil { + return err + } + c.Type = ContentThinking + c.Thinking = v.Thinking + default: + return fmt.Errorf("message: unknown content type %q", disc.Type) + } + return nil +} diff --git a/internal/message/json_test.go b/internal/message/json_test.go new file mode 100644 index 0000000..4e67634 --- /dev/null +++ b/internal/message/json_test.go @@ -0,0 +1,167 @@ +package message_test + +import ( + "encoding/json" + "testing" + + "somegit.dev/Owlibou/gnoma/internal/message" +) + +func TestContent_MarshalJSON_Text(t *testing.T) { + c := message.NewTextContent("hello world") + data, err := json.Marshal(c) + if err != nil { + t.Fatal(err) + } + var got message.Content + if err := json.Unmarshal(data, &got); err != nil { + t.Fatal(err) + } + if got.Type != message.ContentText || got.Text != "hello world" { + t.Errorf("round-trip failed: %+v", got) + } +} + +func TestContent_MarshalJSON_ToolCall(t *testing.T) { + c := message.NewToolCallContent(message.ToolCall{ + ID: "tc_1", + Name: "bash", + Arguments: json.RawMessage(`{"command":"ls"}`), + }) + data, err := json.Marshal(c) + if err != nil { + t.Fatal(err) + } + var got message.Content + if err := json.Unmarshal(data, &got); err != nil { + t.Fatal(err) + } + if got.Type != message.ContentToolCall || got.ToolCall == nil { + t.Fatalf("round-trip failed: %+v", got) + } + if got.ToolCall.ID != "tc_1" || got.ToolCall.Name != "bash" { + t.Errorf("tool call fields wrong: %+v", got.ToolCall) + } + if string(got.ToolCall.Arguments) != `{"command":"ls"}` { + t.Errorf("arguments wrong: %s", got.ToolCall.Arguments) + } +} + +func TestContent_MarshalJSON_ToolResult(t *testing.T) { + c := message.NewToolResultContent(message.ToolResult{ + ToolCallID: "tc_1", + Content: "output", + IsError: false, + }) + data, err := json.Marshal(c) + if err != nil { + t.Fatal(err) + } + var got message.Content + if err := json.Unmarshal(data, &got); err != nil { + t.Fatal(err) + } + if got.Type != message.ContentToolResult || got.ToolResult == nil { + t.Fatalf("round-trip failed: %+v", got) + } + if got.ToolResult.ToolCallID != "tc_1" || got.ToolResult.Content != "output" { + t.Errorf("tool result fields wrong: %+v", got.ToolResult) + } +} + +func TestContent_MarshalJSON_ToolResult_IsError(t *testing.T) { + c := message.NewToolResultContent(message.ToolResult{ + ToolCallID: "tc_2", + Content: "error msg", + IsError: true, + }) + data, err := json.Marshal(c) + if err != nil { + t.Fatal(err) + } + var got message.Content + if err := json.Unmarshal(data, &got); err != nil { + t.Fatal(err) + } + if !got.ToolResult.IsError { + t.Errorf("IsError not preserved") + } +} + +func TestContent_MarshalJSON_Thinking(t *testing.T) { + c := message.NewThinkingContent(message.Thinking{ + Text: "let me think", + Signature: "sig_abc", + Redacted: false, + }) + data, err := json.Marshal(c) + if err != nil { + t.Fatal(err) + } + var got message.Content + if err := json.Unmarshal(data, &got); err != nil { + t.Fatal(err) + } + if got.Type != message.ContentThinking || got.Thinking == nil { + t.Fatalf("round-trip failed: %+v", got) + } + if got.Thinking.Text != "let me think" || got.Thinking.Signature != "sig_abc" { + t.Errorf("thinking fields wrong: %+v", got.Thinking) + } +} + +func TestContent_UnmarshalJSON_UnknownType(t *testing.T) { + data := []byte(`{"type":"unknown_xyz","text":"hi"}`) + var got message.Content + err := json.Unmarshal(data, &got) + if err == nil { + t.Error("expected error for unknown type, got nil") + } +} + +func TestMessage_RoundTrip(t *testing.T) { + msg := message.NewUserText("hello") + data, err := json.Marshal(msg) + if err != nil { + t.Fatal(err) + } + var got message.Message + if err := json.Unmarshal(data, &got); err != nil { + t.Fatal(err) + } + if got.Role != message.RoleUser || len(got.Content) != 1 || got.Content[0].Text != "hello" { + t.Errorf("message round-trip failed: %+v", got) + } +} + +func TestMessages_Slice_RoundTrip(t *testing.T) { + msgs := []message.Message{ + message.NewUserText("question"), + message.NewAssistantContent( + message.NewTextContent("answer"), + message.NewToolCallContent(message.ToolCall{ + ID: "tc_1", + Name: "bash", + Arguments: json.RawMessage(`{}`), + }), + ), + message.NewToolResults(message.ToolResult{ + ToolCallID: "tc_1", + Content: "result", + }), + } + data, err := json.Marshal(msgs) + if err != nil { + t.Fatal(err) + } + var got []message.Message + if err := json.Unmarshal(data, &got); err != nil { + t.Fatal(err) + } + if len(got) != 3 { + t.Fatalf("expected 3 messages, got %d", len(got)) + } + if got[1].Content[1].Type != message.ContentToolCall { + t.Errorf("tool call content type wrong: %v", got[1].Content[1].Type) + } +} diff --git a/internal/message/message.go b/internal/message/message.go index c160e24..94459f0 100644 --- a/internal/message/message.go +++ b/internal/message/message.go @@ -13,8 +13,8 @@ const ( // Message represents a single turn in the conversation. type Message struct { - Role Role - Content []Content + Role Role `json:"role"` + Content []Content `json:"content"` } func NewUserText(text string) Message {