feat: JSON serialization for Message and Content (session persistence blocker)
Add custom MarshalJSON/UnmarshalJSON on Content using string type discriminant
("text", "tool_call", "tool_result", "thinking"). Add json tags to Message.
This commit is contained in:
88
internal/message/json.go
Normal file
88
internal/message/json.go
Normal file
@@ -0,0 +1,88 @@
|
||||
package message
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// MarshalJSON encodes Content as {"type":"<name>","<field>":...}.
|
||||
func (c Content) MarshalJSON() ([]byte, error) {
|
||||
switch c.Type {
|
||||
case ContentText:
|
||||
return json.Marshal(struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text"`
|
||||
}{Type: "text", Text: c.Text})
|
||||
case ContentToolCall:
|
||||
return json.Marshal(struct {
|
||||
Type string `json:"type"`
|
||||
ToolCall *ToolCall `json:"tool_call"`
|
||||
}{Type: "tool_call", ToolCall: c.ToolCall})
|
||||
case ContentToolResult:
|
||||
return json.Marshal(struct {
|
||||
Type string `json:"type"`
|
||||
ToolResult *ToolResult `json:"tool_result"`
|
||||
}{Type: "tool_result", ToolResult: c.ToolResult})
|
||||
case ContentThinking:
|
||||
return json.Marshal(struct {
|
||||
Type string `json:"type"`
|
||||
Thinking *Thinking `json:"thinking"`
|
||||
}{Type: "thinking", Thinking: c.Thinking})
|
||||
default:
|
||||
return nil, fmt.Errorf("message: unknown ContentType %d", c.Type)
|
||||
}
|
||||
}
|
||||
|
||||
// UnmarshalJSON decodes Content from {"type":"<name>","<field>":...}.
|
||||
func (c *Content) UnmarshalJSON(data []byte) error {
|
||||
// First pass: extract the type discriminant.
|
||||
var disc struct {
|
||||
Type string `json:"type"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &disc); err != nil {
|
||||
return fmt.Errorf("message: unmarshal content type: %w", err)
|
||||
}
|
||||
|
||||
// Second pass: decode the payload for the known type.
|
||||
switch disc.Type {
|
||||
case "text":
|
||||
var v struct {
|
||||
Text string `json:"text"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &v); err != nil {
|
||||
return err
|
||||
}
|
||||
c.Type = ContentText
|
||||
c.Text = v.Text
|
||||
case "tool_call":
|
||||
var v struct {
|
||||
ToolCall *ToolCall `json:"tool_call"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &v); err != nil {
|
||||
return err
|
||||
}
|
||||
c.Type = ContentToolCall
|
||||
c.ToolCall = v.ToolCall
|
||||
case "tool_result":
|
||||
var v struct {
|
||||
ToolResult *ToolResult `json:"tool_result"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &v); err != nil {
|
||||
return err
|
||||
}
|
||||
c.Type = ContentToolResult
|
||||
c.ToolResult = v.ToolResult
|
||||
case "thinking":
|
||||
var v struct {
|
||||
Thinking *Thinking `json:"thinking"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &v); err != nil {
|
||||
return err
|
||||
}
|
||||
c.Type = ContentThinking
|
||||
c.Thinking = v.Thinking
|
||||
default:
|
||||
return fmt.Errorf("message: unknown content type %q", disc.Type)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
167
internal/message/json_test.go
Normal file
167
internal/message/json_test.go
Normal file
@@ -0,0 +1,167 @@
|
||||
package message_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"somegit.dev/Owlibou/gnoma/internal/message"
|
||||
)
|
||||
|
||||
func TestContent_MarshalJSON_Text(t *testing.T) {
|
||||
c := message.NewTextContent("hello world")
|
||||
data, err := json.Marshal(c)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var got message.Content
|
||||
if err := json.Unmarshal(data, &got); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got.Type != message.ContentText || got.Text != "hello world" {
|
||||
t.Errorf("round-trip failed: %+v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestContent_MarshalJSON_ToolCall(t *testing.T) {
|
||||
c := message.NewToolCallContent(message.ToolCall{
|
||||
ID: "tc_1",
|
||||
Name: "bash",
|
||||
Arguments: json.RawMessage(`{"command":"ls"}`),
|
||||
})
|
||||
data, err := json.Marshal(c)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var got message.Content
|
||||
if err := json.Unmarshal(data, &got); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got.Type != message.ContentToolCall || got.ToolCall == nil {
|
||||
t.Fatalf("round-trip failed: %+v", got)
|
||||
}
|
||||
if got.ToolCall.ID != "tc_1" || got.ToolCall.Name != "bash" {
|
||||
t.Errorf("tool call fields wrong: %+v", got.ToolCall)
|
||||
}
|
||||
if string(got.ToolCall.Arguments) != `{"command":"ls"}` {
|
||||
t.Errorf("arguments wrong: %s", got.ToolCall.Arguments)
|
||||
}
|
||||
}
|
||||
|
||||
func TestContent_MarshalJSON_ToolResult(t *testing.T) {
|
||||
c := message.NewToolResultContent(message.ToolResult{
|
||||
ToolCallID: "tc_1",
|
||||
Content: "output",
|
||||
IsError: false,
|
||||
})
|
||||
data, err := json.Marshal(c)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var got message.Content
|
||||
if err := json.Unmarshal(data, &got); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got.Type != message.ContentToolResult || got.ToolResult == nil {
|
||||
t.Fatalf("round-trip failed: %+v", got)
|
||||
}
|
||||
if got.ToolResult.ToolCallID != "tc_1" || got.ToolResult.Content != "output" {
|
||||
t.Errorf("tool result fields wrong: %+v", got.ToolResult)
|
||||
}
|
||||
}
|
||||
|
||||
func TestContent_MarshalJSON_ToolResult_IsError(t *testing.T) {
|
||||
c := message.NewToolResultContent(message.ToolResult{
|
||||
ToolCallID: "tc_2",
|
||||
Content: "error msg",
|
||||
IsError: true,
|
||||
})
|
||||
data, err := json.Marshal(c)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var got message.Content
|
||||
if err := json.Unmarshal(data, &got); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !got.ToolResult.IsError {
|
||||
t.Errorf("IsError not preserved")
|
||||
}
|
||||
}
|
||||
|
||||
func TestContent_MarshalJSON_Thinking(t *testing.T) {
|
||||
c := message.NewThinkingContent(message.Thinking{
|
||||
Text: "let me think",
|
||||
Signature: "sig_abc",
|
||||
Redacted: false,
|
||||
})
|
||||
data, err := json.Marshal(c)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var got message.Content
|
||||
if err := json.Unmarshal(data, &got); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got.Type != message.ContentThinking || got.Thinking == nil {
|
||||
t.Fatalf("round-trip failed: %+v", got)
|
||||
}
|
||||
if got.Thinking.Text != "let me think" || got.Thinking.Signature != "sig_abc" {
|
||||
t.Errorf("thinking fields wrong: %+v", got.Thinking)
|
||||
}
|
||||
}
|
||||
|
||||
func TestContent_UnmarshalJSON_UnknownType(t *testing.T) {
|
||||
data := []byte(`{"type":"unknown_xyz","text":"hi"}`)
|
||||
var got message.Content
|
||||
err := json.Unmarshal(data, &got)
|
||||
if err == nil {
|
||||
t.Error("expected error for unknown type, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessage_RoundTrip(t *testing.T) {
|
||||
msg := message.NewUserText("hello")
|
||||
data, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var got message.Message
|
||||
if err := json.Unmarshal(data, &got); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got.Role != message.RoleUser || len(got.Content) != 1 || got.Content[0].Text != "hello" {
|
||||
t.Errorf("message round-trip failed: %+v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessages_Slice_RoundTrip(t *testing.T) {
|
||||
msgs := []message.Message{
|
||||
message.NewUserText("question"),
|
||||
message.NewAssistantContent(
|
||||
message.NewTextContent("answer"),
|
||||
message.NewToolCallContent(message.ToolCall{
|
||||
ID: "tc_1",
|
||||
Name: "bash",
|
||||
Arguments: json.RawMessage(`{}`),
|
||||
}),
|
||||
),
|
||||
message.NewToolResults(message.ToolResult{
|
||||
ToolCallID: "tc_1",
|
||||
Content: "result",
|
||||
}),
|
||||
}
|
||||
data, err := json.Marshal(msgs)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var got []message.Message
|
||||
if err := json.Unmarshal(data, &got); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(got) != 3 {
|
||||
t.Fatalf("expected 3 messages, got %d", len(got))
|
||||
}
|
||||
if got[1].Content[1].Type != message.ContentToolCall {
|
||||
t.Errorf("tool call content type wrong: %v", got[1].Content[1].Type)
|
||||
}
|
||||
}
|
||||
@@ -13,8 +13,8 @@ const (
|
||||
|
||||
// Message represents a single turn in the conversation.
|
||||
type Message struct {
|
||||
Role Role
|
||||
Content []Content
|
||||
Role Role `json:"role"`
|
||||
Content []Content `json:"content"`
|
||||
}
|
||||
|
||||
func NewUserText(text string) Message {
|
||||
|
||||
Reference in New Issue
Block a user