diff --git a/internal/message/content.go b/internal/message/content.go new file mode 100644 index 0000000..71c7ef6 --- /dev/null +++ b/internal/message/content.go @@ -0,0 +1,78 @@ +package message + +import ( + "encoding/json" + "fmt" +) + +// ContentType discriminates the content block union. +type ContentType int + +const ( + ContentText ContentType = iota + 1 + ContentToolCall + ContentToolResult + ContentThinking +) + +func (ct ContentType) String() string { + switch ct { + case ContentText: + return "text" + case ContentToolCall: + return "tool_call" + case ContentToolResult: + return "tool_result" + case ContentThinking: + return "thinking" + default: + return fmt.Sprintf("unknown(%d)", ct) + } +} + +// Content is a discriminated union. Exactly one payload field is set per Type. +type Content struct { + Type ContentType + Text string // ContentText + ToolCall *ToolCall // ContentToolCall + ToolResult *ToolResult // ContentToolResult + Thinking *Thinking // ContentThinking +} + +// ToolCall represents the model's request to invoke a tool. +type ToolCall struct { + ID string `json:"id"` + Name string `json:"name"` + Arguments json.RawMessage `json:"arguments"` +} + +// ToolResult is the output of executing a tool, correlated by ToolCallID. +type ToolResult struct { + ToolCallID string `json:"tool_call_id"` + Content string `json:"content"` + IsError bool `json:"is_error"` +} + +// Thinking represents a reasoning/thinking trace. +// Signature must round-trip unchanged (Anthropic requirement). +type Thinking struct { + Text string `json:"text,omitempty"` + Signature string `json:"signature,omitempty"` + Redacted bool `json:"redacted,omitempty"` +} + +func NewTextContent(text string) Content { + return Content{Type: ContentText, Text: text} +} + +func NewToolCallContent(tc ToolCall) Content { + return Content{Type: ContentToolCall, ToolCall: &tc} +} + +func NewToolResultContent(tr ToolResult) Content { + return Content{Type: ContentToolResult, ToolResult: &tr} +} + +func NewThinkingContent(th Thinking) Content { + return Content{Type: ContentThinking, Thinking: &th} +} diff --git a/internal/message/content_test.go b/internal/message/content_test.go new file mode 100644 index 0000000..43fd108 --- /dev/null +++ b/internal/message/content_test.go @@ -0,0 +1,174 @@ +package message + +import ( + "encoding/json" + "testing" +) + +func TestNewTextContent(t *testing.T) { + c := NewTextContent("hello world") + + if c.Type != ContentText { + t.Errorf("Type = %v, want %v", c.Type, ContentText) + } + if c.Text != "hello world" { + t.Errorf("Text = %q, want %q", c.Text, "hello world") + } + if c.ToolCall != nil { + t.Error("ToolCall should be nil for text content") + } + if c.ToolResult != nil { + t.Error("ToolResult should be nil for text content") + } + if c.Thinking != nil { + t.Error("Thinking should be nil for text content") + } +} + +func TestNewToolCallContent(t *testing.T) { + args := json.RawMessage(`{"command":"ls -la"}`) + tc := ToolCall{ + ID: "tc_001", + Name: "bash", + Arguments: args, + } + c := NewToolCallContent(tc) + + if c.Type != ContentToolCall { + t.Errorf("Type = %v, want %v", c.Type, ContentToolCall) + } + if c.ToolCall == nil { + t.Fatal("ToolCall should not be nil") + } + if c.ToolCall.ID != "tc_001" { + t.Errorf("ToolCall.ID = %q, want %q", c.ToolCall.ID, "tc_001") + } + if c.ToolCall.Name != "bash" { + t.Errorf("ToolCall.Name = %q, want %q", c.ToolCall.Name, "bash") + } + if string(c.ToolCall.Arguments) != `{"command":"ls -la"}` { + t.Errorf("ToolCall.Arguments = %s, want %s", c.ToolCall.Arguments, args) + } + if c.Text != "" { + t.Error("Text should be empty for tool call content") + } +} + +func TestNewToolResultContent(t *testing.T) { + tr := ToolResult{ + ToolCallID: "tc_001", + Content: "file1.go\nfile2.go", + IsError: false, + } + c := NewToolResultContent(tr) + + if c.Type != ContentToolResult { + t.Errorf("Type = %v, want %v", c.Type, ContentToolResult) + } + if c.ToolResult == nil { + t.Fatal("ToolResult should not be nil") + } + if c.ToolResult.ToolCallID != "tc_001" { + t.Errorf("ToolResult.ToolCallID = %q, want %q", c.ToolResult.ToolCallID, "tc_001") + } + if c.ToolResult.IsError { + t.Error("ToolResult.IsError should be false") + } +} + +func TestNewToolResultContent_Error(t *testing.T) { + tr := ToolResult{ + ToolCallID: "tc_002", + Content: "permission denied", + IsError: true, + } + c := NewToolResultContent(tr) + + if !c.ToolResult.IsError { + t.Error("ToolResult.IsError should be true") + } + if c.ToolResult.Content != "permission denied" { + t.Errorf("ToolResult.Content = %q, want %q", c.ToolResult.Content, "permission denied") + } +} + +func TestNewThinkingContent(t *testing.T) { + th := Thinking{ + Text: "Let me think about this...", + Signature: "sig_abc123", + } + c := NewThinkingContent(th) + + if c.Type != ContentThinking { + t.Errorf("Type = %v, want %v", c.Type, ContentThinking) + } + if c.Thinking == nil { + t.Fatal("Thinking should not be nil") + } + if c.Thinking.Text != "Let me think about this..." { + t.Errorf("Thinking.Text = %q", c.Thinking.Text) + } + if c.Thinking.Signature != "sig_abc123" { + t.Errorf("Thinking.Signature = %q", c.Thinking.Signature) + } + if c.Thinking.Redacted { + t.Error("Thinking.Redacted should be false") + } +} + +func TestNewRedactedThinkingContent(t *testing.T) { + th := Thinking{ + Redacted: true, + } + c := NewThinkingContent(th) + + if !c.Thinking.Redacted { + t.Error("Thinking.Redacted should be true") + } +} + +func TestContentType_String(t *testing.T) { + tests := []struct { + ct ContentType + want string + }{ + {ContentText, "text"}, + {ContentToolCall, "tool_call"}, + {ContentToolResult, "tool_result"}, + {ContentThinking, "thinking"}, + {ContentType(99), "unknown(99)"}, + } + for _, tt := range tests { + if got := tt.ct.String(); got != tt.want { + t.Errorf("ContentType(%d).String() = %q, want %q", tt.ct, got, tt.want) + } + } +} + +func TestToolCall_JSON_RoundTrip(t *testing.T) { + original := ToolCall{ + ID: "tc_100", + Name: "fs.read", + Arguments: json.RawMessage(`{"path":"/tmp/test.go","offset":0}`), + } + + data, err := json.Marshal(original) + if err != nil { + t.Fatalf("Marshal: %v", err) + } + + var decoded ToolCall + if err := json.Unmarshal(data, &decoded); err != nil { + t.Fatalf("Unmarshal: %v", err) + } + + if decoded.ID != original.ID { + t.Errorf("ID = %q, want %q", decoded.ID, original.ID) + } + if decoded.Name != original.Name { + t.Errorf("Name = %q, want %q", decoded.Name, original.Name) + } + if string(decoded.Arguments) != string(original.Arguments) { + t.Errorf("Arguments = %s, want %s", decoded.Arguments, original.Arguments) + } +} diff --git a/internal/message/message.go b/internal/message/message.go new file mode 100644 index 0000000..c160e24 --- /dev/null +++ b/internal/message/message.go @@ -0,0 +1,89 @@ +package message + +import "strings" + +// Role identifies the sender of a message. +type Role string + +const ( + RoleUser Role = "user" + RoleAssistant Role = "assistant" + RoleSystem Role = "system" +) + +// Message represents a single turn in the conversation. +type Message struct { + Role Role + Content []Content +} + +func NewUserText(text string) Message { + return Message{ + Role: RoleUser, + Content: []Content{NewTextContent(text)}, + } +} + +func NewAssistantText(text string) Message { + return Message{ + Role: RoleAssistant, + Content: []Content{NewTextContent(text)}, + } +} + +func NewAssistantContent(blocks ...Content) Message { + return Message{ + Role: RoleAssistant, + Content: blocks, + } +} + +func NewSystemText(text string) Message { + return Message{ + Role: RoleSystem, + Content: []Content{NewTextContent(text)}, + } +} + +func NewToolResults(results ...ToolResult) Message { + content := make([]Content, len(results)) + for i, r := range results { + content[i] = NewToolResultContent(r) + } + return Message{ + Role: RoleUser, + Content: content, + } +} + +// HasToolCalls returns true if any content block is a tool call. +func (m Message) HasToolCalls() bool { + for _, c := range m.Content { + if c.Type == ContentToolCall { + return true + } + } + return false +} + +// ToolCalls extracts all tool call blocks. +func (m Message) ToolCalls() []ToolCall { + var calls []ToolCall + for _, c := range m.Content { + if c.Type == ContentToolCall && c.ToolCall != nil { + calls = append(calls, *c.ToolCall) + } + } + return calls +} + +// TextContent concatenates all text blocks. +func (m Message) TextContent() string { + var b strings.Builder + for _, c := range m.Content { + if c.Type == ContentText { + b.WriteString(c.Text) + } + } + return b.String() +} diff --git a/internal/message/message_test.go b/internal/message/message_test.go new file mode 100644 index 0000000..29ae1b1 --- /dev/null +++ b/internal/message/message_test.go @@ -0,0 +1,214 @@ +package message + +import ( + "encoding/json" + "testing" +) + +func TestNewUserText(t *testing.T) { + m := NewUserText("hello") + if m.Role != RoleUser { + t.Errorf("Role = %q, want %q", m.Role, RoleUser) + } + if len(m.Content) != 1 { + t.Fatalf("len(Content) = %d, want 1", len(m.Content)) + } + if m.Content[0].Type != ContentText { + t.Errorf("Content[0].Type = %v, want %v", m.Content[0].Type, ContentText) + } + if m.Content[0].Text != "hello" { + t.Errorf("Content[0].Text = %q, want %q", m.Content[0].Text, "hello") + } +} + +func TestNewAssistantText(t *testing.T) { + m := NewAssistantText("response") + if m.Role != RoleAssistant { + t.Errorf("Role = %q, want %q", m.Role, RoleAssistant) + } + if m.TextContent() != "response" { + t.Errorf("TextContent() = %q, want %q", m.TextContent(), "response") + } +} + +func TestNewSystemText(t *testing.T) { + m := NewSystemText("you are a helper") + if m.Role != RoleSystem { + t.Errorf("Role = %q, want %q", m.Role, RoleSystem) + } +} + +func TestNewAssistantContent_Mixed(t *testing.T) { + m := NewAssistantContent( + NewTextContent("I'll run that command."), + NewToolCallContent(ToolCall{ + ID: "tc_1", + Name: "bash", + Arguments: json.RawMessage(`{"command":"ls"}`), + }), + ) + + if m.Role != RoleAssistant { + t.Errorf("Role = %q, want %q", m.Role, RoleAssistant) + } + if len(m.Content) != 2 { + t.Fatalf("len(Content) = %d, want 2", len(m.Content)) + } + if m.Content[0].Type != ContentText { + t.Errorf("Content[0].Type = %v, want text", m.Content[0].Type) + } + if m.Content[1].Type != ContentToolCall { + t.Errorf("Content[1].Type = %v, want tool_call", m.Content[1].Type) + } +} + +func TestNewToolResults(t *testing.T) { + m := NewToolResults( + ToolResult{ToolCallID: "tc_1", Content: "output1"}, + ToolResult{ToolCallID: "tc_2", Content: "output2", IsError: true}, + ) + + if m.Role != RoleUser { + t.Errorf("Role = %q, want %q", m.Role, RoleUser) + } + if len(m.Content) != 2 { + t.Fatalf("len(Content) = %d, want 2", len(m.Content)) + } + if m.Content[0].ToolResult.ToolCallID != "tc_1" { + t.Errorf("Content[0].ToolResult.ToolCallID = %q", m.Content[0].ToolResult.ToolCallID) + } + if m.Content[1].ToolResult.IsError != true { + t.Error("Content[1].ToolResult.IsError should be true") + } +} + +func TestMessage_HasToolCalls(t *testing.T) { + tests := []struct { + name string + msg Message + want bool + }{ + { + name: "text only", + msg: NewUserText("hello"), + want: false, + }, + { + name: "with tool call", + msg: NewAssistantContent( + NewTextContent("running..."), + NewToolCallContent(ToolCall{ID: "tc_1", Name: "bash"}), + ), + want: true, + }, + { + name: "tool results (not calls)", + msg: NewToolResults(ToolResult{ToolCallID: "tc_1", Content: "ok"}), + want: false, + }, + { + name: "empty message", + msg: Message{Role: RoleAssistant}, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.msg.HasToolCalls(); got != tt.want { + t.Errorf("HasToolCalls() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestMessage_ToolCalls(t *testing.T) { + m := NewAssistantContent( + NewTextContent("here are two commands"), + NewToolCallContent(ToolCall{ID: "tc_1", Name: "bash", Arguments: json.RawMessage(`{"command":"ls"}`)}), + NewTextContent("and another"), + NewToolCallContent(ToolCall{ID: "tc_2", Name: "fs.read", Arguments: json.RawMessage(`{"path":"go.mod"}`)}), + ) + + calls := m.ToolCalls() + if len(calls) != 2 { + t.Fatalf("len(ToolCalls()) = %d, want 2", len(calls)) + } + if calls[0].ID != "tc_1" { + t.Errorf("calls[0].ID = %q, want tc_1", calls[0].ID) + } + if calls[1].Name != "fs.read" { + t.Errorf("calls[1].Name = %q, want fs.read", calls[1].Name) + } +} + +func TestMessage_ToolCalls_Empty(t *testing.T) { + m := NewUserText("no tools here") + calls := m.ToolCalls() + if len(calls) != 0 { + t.Errorf("len(ToolCalls()) = %d, want 0", len(calls)) + } +} + +func TestMessage_TextContent(t *testing.T) { + tests := []struct { + name string + msg Message + want string + }{ + { + name: "single text", + msg: NewUserText("hello"), + want: "hello", + }, + { + name: "multiple text blocks", + msg: NewAssistantContent( + NewTextContent("first "), + NewToolCallContent(ToolCall{ID: "tc_1", Name: "bash"}), + NewTextContent("second"), + ), + want: "first second", + }, + { + name: "no text", + msg: NewToolResults(ToolResult{ToolCallID: "tc_1", Content: "output"}), + want: "", + }, + { + name: "empty message", + msg: Message{Role: RoleAssistant}, + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.msg.TextContent(); got != tt.want { + t.Errorf("TextContent() = %q, want %q", got, tt.want) + } + }) + } +} + +func TestResponse_Fields(t *testing.T) { + r := Response{ + Message: NewAssistantText("done"), + StopReason: StopEndTurn, + Usage: Usage{InputTokens: 100, OutputTokens: 50}, + Model: "mistral-large-latest", + } + + if r.StopReason != StopEndTurn { + t.Errorf("StopReason = %q, want %q", r.StopReason, StopEndTurn) + } + if r.Usage.TotalTokens() != 150 { + t.Errorf("Usage.TotalTokens() = %d, want 150", r.Usage.TotalTokens()) + } + if r.Model != "mistral-large-latest" { + t.Errorf("Model = %q", r.Model) + } + if r.Message.TextContent() != "done" { + t.Errorf("Message.TextContent() = %q", r.Message.TextContent()) + } +} diff --git a/internal/message/response.go b/internal/message/response.go new file mode 100644 index 0000000..bb3097c --- /dev/null +++ b/internal/message/response.go @@ -0,0 +1,9 @@ +package message + +// Response wraps a completed assistant turn. +type Response struct { + Message Message + StopReason StopReason + Usage Usage + Model string +} diff --git a/internal/message/stop.go b/internal/message/stop.go new file mode 100644 index 0000000..d1624fd --- /dev/null +++ b/internal/message/stop.go @@ -0,0 +1,11 @@ +package message + +// StopReason indicates why the model stopped generating. +type StopReason string + +const ( + StopEndTurn StopReason = "end_turn" + StopMaxTokens StopReason = "max_tokens" + StopToolUse StopReason = "tool_use" + StopSequence StopReason = "stop_sequence" +) diff --git a/internal/message/usage.go b/internal/message/usage.go new file mode 100644 index 0000000..affdc6f --- /dev/null +++ b/internal/message/usage.go @@ -0,0 +1,20 @@ +package message + +// Usage tracks token consumption for a single API turn. +type Usage struct { + InputTokens int64 `json:"input_tokens"` + OutputTokens int64 `json:"output_tokens"` + CacheReadTokens int64 `json:"cache_read_tokens,omitempty"` + CacheCreationTokens int64 `json:"cache_creation_tokens,omitempty"` +} + +func (u Usage) TotalTokens() int64 { + return u.InputTokens + u.OutputTokens +} + +func (u *Usage) Add(other Usage) { + u.InputTokens += other.InputTokens + u.OutputTokens += other.OutputTokens + u.CacheReadTokens += other.CacheReadTokens + u.CacheCreationTokens += other.CacheCreationTokens +} diff --git a/internal/message/usage_test.go b/internal/message/usage_test.go new file mode 100644 index 0000000..01f58b2 --- /dev/null +++ b/internal/message/usage_test.go @@ -0,0 +1,70 @@ +package message + +import "testing" + +func TestUsage_TotalTokens(t *testing.T) { + u := Usage{InputTokens: 100, OutputTokens: 50} + if got := u.TotalTokens(); got != 150 { + t.Errorf("TotalTokens() = %d, want 150", got) + } +} + +func TestUsage_TotalTokens_Zero(t *testing.T) { + var u Usage + if got := u.TotalTokens(); got != 0 { + t.Errorf("TotalTokens() = %d, want 0", got) + } +} + +func TestUsage_Add(t *testing.T) { + u := Usage{ + InputTokens: 100, + OutputTokens: 50, + CacheReadTokens: 10, + CacheCreationTokens: 5, + } + other := Usage{ + InputTokens: 200, + OutputTokens: 80, + CacheReadTokens: 20, + CacheCreationTokens: 15, + } + + u.Add(other) + + if u.InputTokens != 300 { + t.Errorf("InputTokens = %d, want 300", u.InputTokens) + } + if u.OutputTokens != 130 { + t.Errorf("OutputTokens = %d, want 130", u.OutputTokens) + } + if u.CacheReadTokens != 30 { + t.Errorf("CacheReadTokens = %d, want 30", u.CacheReadTokens) + } + if u.CacheCreationTokens != 20 { + t.Errorf("CacheCreationTokens = %d, want 20", u.CacheCreationTokens) + } +} + +func TestUsage_Add_Multiple(t *testing.T) { + var total Usage + turns := []Usage{ + {InputTokens: 100, OutputTokens: 50}, + {InputTokens: 200, OutputTokens: 80}, + {InputTokens: 150, OutputTokens: 60}, + } + + for _, turn := range turns { + total.Add(turn) + } + + if total.InputTokens != 450 { + t.Errorf("InputTokens = %d, want 450", total.InputTokens) + } + if total.OutputTokens != 190 { + t.Errorf("OutputTokens = %d, want 190", total.OutputTokens) + } + if total.TotalTokens() != 640 { + t.Errorf("TotalTokens() = %d, want 640", total.TotalTokens()) + } +} diff --git a/internal/provider/errors.go b/internal/provider/errors.go new file mode 100644 index 0000000..4fb6b3c --- /dev/null +++ b/internal/provider/errors.go @@ -0,0 +1,79 @@ +package provider + +import ( + "fmt" + "time" +) + +// ErrorKind classifies provider errors for retry decisions. +type ErrorKind int + +const ( + ErrTransient ErrorKind = iota + 1 // 429, 500, 502, 503, 529 — retry with backoff + ErrAuth // 401, 403 — don't retry + ErrBadRequest // 400 — don't retry, fix request + ErrNotFound // 404 — model/endpoint not found + ErrOverloaded // capacity exhausted — backoff + retry +) + +func (k ErrorKind) String() string { + switch k { + case ErrTransient: + return "transient" + case ErrAuth: + return "auth" + case ErrBadRequest: + return "bad_request" + case ErrNotFound: + return "not_found" + case ErrOverloaded: + return "overloaded" + default: + return fmt.Sprintf("unknown(%d)", k) + } +} + +// ProviderError wraps an SDK error with classification metadata. +type ProviderError struct { + Kind ErrorKind + Provider string + StatusCode int + Message string + Retryable bool + RetryAfter time.Duration // from Retry-After or rate limit headers + Err error // underlying SDK error +} + +func (e *ProviderError) Error() string { + if e.Err != nil { + return fmt.Sprintf("%s %s (%d): %s: %v", e.Provider, e.Kind, e.StatusCode, e.Message, e.Err) + } + return fmt.Sprintf("%s %s (%d): %s", e.Provider, e.Kind, e.StatusCode, e.Message) +} + +func (e *ProviderError) Unwrap() error { + return e.Err +} + +// ClassifyHTTPStatus returns the ErrorKind and retryability for an HTTP status code. +func ClassifyHTTPStatus(status int) (ErrorKind, bool) { + switch { + case status == 401 || status == 403: + return ErrAuth, false + case status == 400: + return ErrBadRequest, false + case status == 404: + return ErrNotFound, false + case status == 429 || status == 529: + return ErrTransient, true + case status == 500 || status == 502 || status == 503: + return ErrTransient, true + case status == 504: + return ErrOverloaded, true + default: + if status >= 500 { + return ErrTransient, true + } + return ErrBadRequest, false + } +} diff --git a/internal/provider/errors_test.go b/internal/provider/errors_test.go new file mode 100644 index 0000000..07f7c2c --- /dev/null +++ b/internal/provider/errors_test.go @@ -0,0 +1,118 @@ +package provider + +import ( + "errors" + "fmt" + "testing" +) + +func TestProviderError_Error(t *testing.T) { + err := &ProviderError{ + Kind: ErrTransient, + Provider: "mistral", + StatusCode: 429, + Message: "rate limited", + } + got := err.Error() + want := "mistral transient (429): rate limited" + if got != want { + t.Errorf("Error() = %q, want %q", got, want) + } +} + +func TestProviderError_Error_WithWrapped(t *testing.T) { + inner := errors.New("connection reset") + err := &ProviderError{ + Kind: ErrTransient, + Provider: "openai", + StatusCode: 502, + Message: "bad gateway", + Err: inner, + } + got := err.Error() + want := "openai transient (502): bad gateway: connection reset" + if got != want { + t.Errorf("Error() = %q, want %q", got, want) + } +} + +func TestProviderError_Unwrap(t *testing.T) { + inner := errors.New("timeout") + err := &ProviderError{ + Kind: ErrTransient, + Err: inner, + } + if !errors.Is(err, inner) { + t.Error("errors.Is should find inner error") + } +} + +func TestProviderError_AsType(t *testing.T) { + inner := &ProviderError{ + Kind: ErrAuth, + Provider: "anthropic", + StatusCode: 401, + Message: "invalid key", + } + wrapped := fmt.Errorf("api call failed: %w", inner) + + pErr, ok := errors.AsType[*ProviderError](wrapped) + if !ok { + t.Fatal("errors.AsType should find ProviderError") + } + if pErr.Kind != ErrAuth { + t.Errorf("Kind = %v, want %v", pErr.Kind, ErrAuth) + } + if pErr.Provider != "anthropic" { + t.Errorf("Provider = %q", pErr.Provider) + } +} + +func TestClassifyHTTPStatus(t *testing.T) { + tests := []struct { + status int + wantKind ErrorKind + wantRetry bool + }{ + {200, ErrBadRequest, false}, // shouldn't happen, but safe default + {400, ErrBadRequest, false}, + {401, ErrAuth, false}, + {403, ErrAuth, false}, + {404, ErrNotFound, false}, + {429, ErrTransient, true}, + {500, ErrTransient, true}, + {502, ErrTransient, true}, + {503, ErrTransient, true}, + {504, ErrOverloaded, true}, + {529, ErrTransient, true}, + {599, ErrTransient, true}, // unknown 5xx + } + for _, tt := range tests { + kind, retry := ClassifyHTTPStatus(tt.status) + if kind != tt.wantKind { + t.Errorf("ClassifyHTTPStatus(%d) kind = %v, want %v", tt.status, kind, tt.wantKind) + } + if retry != tt.wantRetry { + t.Errorf("ClassifyHTTPStatus(%d) retry = %v, want %v", tt.status, retry, tt.wantRetry) + } + } +} + +func TestErrorKind_String(t *testing.T) { + tests := []struct { + kind ErrorKind + want string + }{ + {ErrTransient, "transient"}, + {ErrAuth, "auth"}, + {ErrBadRequest, "bad_request"}, + {ErrNotFound, "not_found"}, + {ErrOverloaded, "overloaded"}, + {ErrorKind(99), "unknown(99)"}, + } + for _, tt := range tests { + if got := tt.kind.String(); got != tt.want { + t.Errorf("ErrorKind(%d).String() = %q, want %q", tt.kind, got, tt.want) + } + } +} diff --git a/internal/provider/provider.go b/internal/provider/provider.go new file mode 100644 index 0000000..4dec0fd --- /dev/null +++ b/internal/provider/provider.go @@ -0,0 +1,44 @@ +package provider + +import ( + "context" + "encoding/json" + + "somegit.dev/Owlibou/gnoma/internal/message" + "somegit.dev/Owlibou/gnoma/internal/stream" +) + +// Request encapsulates everything needed for a single LLM API call. +type Request struct { + Model string + SystemPrompt string + Messages []message.Message + Tools []ToolDefinition + MaxTokens int64 + Temperature *float64 + TopP *float64 + TopK *int64 + StopSequences []string + Thinking *ThinkingConfig +} + +// ToolDefinition is the provider-agnostic tool schema. +type ToolDefinition struct { + Name string `json:"name"` + Description string `json:"description"` + Parameters json.RawMessage `json:"parameters"` // JSON Schema passthrough +} + +// ThinkingConfig controls extended thinking / reasoning. +type ThinkingConfig struct { + BudgetTokens int64 +} + +// Provider is the core abstraction over all LLM backends. +type Provider interface { + // Stream initiates a streaming request and returns an event stream. + Stream(ctx context.Context, req Request) (stream.Stream, error) + + // Name returns the provider identifier (e.g., "mistral", "anthropic"). + Name() string +} diff --git a/internal/provider/registry.go b/internal/provider/registry.go new file mode 100644 index 0000000..b1e5254 --- /dev/null +++ b/internal/provider/registry.go @@ -0,0 +1,69 @@ +package provider + +import ( + "fmt" + "sync" +) + +// ProviderConfig is the common configuration for any provider. +type ProviderConfig struct { + Name string + APIKey string + BaseURL string // override for OpenAI-compat endpoints + Model string // default model for this provider + Options map[string]any // provider-specific options +} + +// Factory creates a Provider from configuration. +type Factory func(cfg ProviderConfig) (Provider, error) + +// Registry maps provider names to factory functions. +type Registry struct { + mu sync.RWMutex + factories map[string]Factory +} + +func NewRegistry() *Registry { + return &Registry{ + factories: make(map[string]Factory), + } +} + +// Register adds a provider factory. Overwrites if name already exists. +func (r *Registry) Register(name string, f Factory) { + r.mu.Lock() + defer r.mu.Unlock() + r.factories[name] = f +} + +// Create instantiates a provider by name with the given config. +func (r *Registry) Create(name string, cfg ProviderConfig) (Provider, error) { + r.mu.RLock() + f, ok := r.factories[name] + r.mu.RUnlock() + + if !ok { + return nil, fmt.Errorf("unknown provider: %q", name) + } + cfg.Name = name + return f(cfg) +} + +// Has returns true if a factory is registered for the given name. +func (r *Registry) Has(name string) bool { + r.mu.RLock() + defer r.mu.RUnlock() + _, ok := r.factories[name] + return ok +} + +// Names returns all registered provider names. +func (r *Registry) Names() []string { + r.mu.RLock() + defer r.mu.RUnlock() + names := make([]string, 0, len(r.factories)) + for name := range r.factories { + names = append(names, name) + } + return names +} diff --git a/internal/provider/registry_test.go b/internal/provider/registry_test.go new file mode 100644 index 0000000..525194d --- /dev/null +++ b/internal/provider/registry_test.go @@ -0,0 +1,131 @@ +package provider + +import ( + "context" + "errors" + "slices" + "sort" + "testing" + + "somegit.dev/Owlibou/gnoma/internal/stream" +) + +// mockProvider implements Provider for testing. +type mockProvider struct { + name string +} + +func (m *mockProvider) Stream(_ context.Context, _ Request) (stream.Stream, error) { + return nil, nil +} + +func (m *mockProvider) Name() string { + return m.name +} + +func TestRegistry_RegisterAndCreate(t *testing.T) { + r := NewRegistry() + + r.Register("mock", func(cfg ProviderConfig) (Provider, error) { + return &mockProvider{name: cfg.Name}, nil + }) + + p, err := r.Create("mock", ProviderConfig{}) + if err != nil { + t.Fatalf("Create: %v", err) + } + if p.Name() != "mock" { + t.Errorf("Name() = %q, want %q", p.Name(), "mock") + } +} + +func TestRegistry_Create_Unknown(t *testing.T) { + r := NewRegistry() + + _, err := r.Create("nonexistent", ProviderConfig{}) + if err == nil { + t.Fatal("expected error for unknown provider") + } + want := `unknown provider: "nonexistent"` + if err.Error() != want { + t.Errorf("error = %q, want %q", err.Error(), want) + } +} + +func TestRegistry_Create_FactoryError(t *testing.T) { + r := NewRegistry() + r.Register("broken", func(cfg ProviderConfig) (Provider, error) { + return nil, errors.New("missing api key") + }) + + _, err := r.Create("broken", ProviderConfig{}) + if err == nil { + t.Fatal("expected error from factory") + } + if err.Error() != "missing api key" { + t.Errorf("error = %q", err.Error()) + } +} + +func TestRegistry_Create_SetsName(t *testing.T) { + r := NewRegistry() + + var receivedName string + r.Register("test", func(cfg ProviderConfig) (Provider, error) { + receivedName = cfg.Name + return &mockProvider{name: cfg.Name}, nil + }) + + _, _ = r.Create("test", ProviderConfig{APIKey: "sk-123"}) + if receivedName != "test" { + t.Errorf("factory received Name = %q, want %q", receivedName, "test") + } +} + +func TestRegistry_Has(t *testing.T) { + r := NewRegistry() + r.Register("exists", func(cfg ProviderConfig) (Provider, error) { + return nil, nil + }) + + if !r.Has("exists") { + t.Error("Has(exists) = false, want true") + } + if r.Has("nope") { + t.Error("Has(nope) = true, want false") + } +} + +func TestRegistry_Names(t *testing.T) { + r := NewRegistry() + r.Register("alpha", func(cfg ProviderConfig) (Provider, error) { return nil, nil }) + r.Register("beta", func(cfg ProviderConfig) (Provider, error) { return nil, nil }) + r.Register("gamma", func(cfg ProviderConfig) (Provider, error) { return nil, nil }) + + names := r.Names() + sort.Strings(names) + + want := []string{"alpha", "beta", "gamma"} + if !slices.Equal(names, want) { + t.Errorf("Names() = %v, want %v", names, want) + } +} + +func TestRegistry_Register_Overwrite(t *testing.T) { + r := NewRegistry() + + r.Register("dup", func(cfg ProviderConfig) (Provider, error) { + return &mockProvider{name: "old"}, nil + }) + r.Register("dup", func(cfg ProviderConfig) (Provider, error) { + return &mockProvider{name: "new"}, nil + }) + + p, err := r.Create("dup", ProviderConfig{}) + if err != nil { + t.Fatalf("Create: %v", err) + } + if p.Name() != "new" { + t.Errorf("Name() = %q, want %q (overwritten factory)", p.Name(), "new") + } +} diff --git a/internal/stream/accumulator.go b/internal/stream/accumulator.go new file mode 100644 index 0000000..155926c --- /dev/null +++ b/internal/stream/accumulator.go @@ -0,0 +1,143 @@ +package stream + +import ( + "strings" + + "somegit.dev/Owlibou/gnoma/internal/message" +) + +// Accumulator assembles a message.Response from a sequence of Events. +// Provider adapters translate SDK events into stream.Events; +// the Accumulator — shared, tested once — builds the final Response. +type Accumulator struct { + content []message.Content + usage message.Usage + + // Active text block being built + textBuf *strings.Builder + + // Active thinking block being built + thinkBuf *strings.Builder + + // Tool calls in progress, keyed by ToolCallID + toolCalls map[string]*toolCallAccum + // Ordered tool call IDs to preserve emission order + toolCallOrder []string +} + +type toolCallAccum struct { + id string + name string + argsBuf strings.Builder + args []byte // final complete args (from Done event) +} + +func NewAccumulator() *Accumulator { + return &Accumulator{ + toolCalls: make(map[string]*toolCallAccum), + } +} + +// Apply processes a single event, updating the accumulator's state. +func (a *Accumulator) Apply(e Event) { + switch e.Type { + case EventTextDelta: + a.flushThinking() + if a.textBuf == nil { + a.textBuf = &strings.Builder{} + } + a.textBuf.WriteString(e.Text) + + case EventThinkingDelta: + a.flushText() + if a.thinkBuf == nil { + a.thinkBuf = &strings.Builder{} + } + a.thinkBuf.WriteString(e.Text) + + case EventToolCallStart: + a.flushText() + a.flushThinking() + tc := &toolCallAccum{id: e.ToolCallID, name: e.ToolCallName} + a.toolCalls[e.ToolCallID] = tc + a.toolCallOrder = append(a.toolCallOrder, e.ToolCallID) + + case EventToolCallDelta: + if tc, ok := a.toolCalls[e.ToolCallID]; ok { + tc.argsBuf.WriteString(e.ArgDelta) + } + + case EventToolCallDone: + if tc, ok := a.toolCalls[e.ToolCallID]; ok { + if e.Args != nil { + // Done event carries authoritative complete args + tc.args = e.Args + } else { + // Fall back to accumulated deltas + tc.args = []byte(tc.argsBuf.String()) + } + } + + case EventUsage: + if e.Usage != nil { + a.usage.Add(*e.Usage) + } + + case EventError: + // Errors are handled by the stream consumer, not accumulated + } +} + +// Response builds the final message.Response from all accumulated events. +func (a *Accumulator) Response(stopReason message.StopReason, model string) message.Response { + a.flushText() + a.flushThinking() + a.flushToolCalls() + + return message.Response{ + Message: message.Message{ + Role: message.RoleAssistant, + Content: a.content, + }, + StopReason: stopReason, + Usage: a.usage, + Model: model, + } +} + +func (a *Accumulator) flushText() { + if a.textBuf != nil && a.textBuf.Len() > 0 { + a.content = append(a.content, message.NewTextContent(a.textBuf.String())) + a.textBuf = nil + } +} + +func (a *Accumulator) flushThinking() { + if a.thinkBuf != nil && a.thinkBuf.Len() > 0 { + a.content = append(a.content, message.NewThinkingContent(message.Thinking{ + Text: a.thinkBuf.String(), + })) + a.thinkBuf = nil + } +} + +func (a *Accumulator) flushToolCalls() { + for _, id := range a.toolCallOrder { + tc, ok := a.toolCalls[id] + if !ok { + continue + } + args := tc.args + if args == nil { + // Fallback: use accumulated deltas even if Done never arrived + args = []byte(tc.argsBuf.String()) + } + a.content = append(a.content, message.NewToolCallContent(message.ToolCall{ + ID: tc.id, + Name: tc.name, + Arguments: args, + })) + } + a.toolCalls = make(map[string]*toolCallAccum) + a.toolCallOrder = nil +} diff --git a/internal/stream/accumulator_test.go b/internal/stream/accumulator_test.go new file mode 100644 index 0000000..423b5b5 --- /dev/null +++ b/internal/stream/accumulator_test.go @@ -0,0 +1,225 @@ +package stream + +import ( + "encoding/json" + "testing" + + "somegit.dev/Owlibou/gnoma/internal/message" +) + +func TestAccumulator_TextOnly(t *testing.T) { + acc := NewAccumulator() + + acc.Apply(Event{Type: EventTextDelta, Text: "Hello "}) + acc.Apply(Event{Type: EventTextDelta, Text: "world!"}) + acc.Apply(Event{Type: EventUsage, Usage: &message.Usage{InputTokens: 10, OutputTokens: 5}}) + + resp := acc.Response(message.StopEndTurn, "mistral-large") + + if resp.StopReason != message.StopEndTurn { + t.Errorf("StopReason = %q, want %q", resp.StopReason, message.StopEndTurn) + } + if resp.Model != "mistral-large" { + t.Errorf("Model = %q, want %q", resp.Model, "mistral-large") + } + if resp.Message.Role != message.RoleAssistant { + t.Errorf("Role = %q, want %q", resp.Message.Role, message.RoleAssistant) + } + if resp.Message.TextContent() != "Hello world!" { + t.Errorf("TextContent() = %q, want %q", resp.Message.TextContent(), "Hello world!") + } + if resp.Usage.InputTokens != 10 { + t.Errorf("Usage.InputTokens = %d, want 10", resp.Usage.InputTokens) + } +} + +func TestAccumulator_SingleToolCall(t *testing.T) { + acc := NewAccumulator() + + acc.Apply(Event{Type: EventTextDelta, Text: "I'll run that."}) + acc.Apply(Event{ + Type: EventToolCallStart, + ToolCallID: "tc_1", + ToolCallName: "bash", + }) + acc.Apply(Event{ + Type: EventToolCallDelta, + ToolCallID: "tc_1", + ArgDelta: `{"comma`, + }) + acc.Apply(Event{ + Type: EventToolCallDelta, + ToolCallID: "tc_1", + ArgDelta: `nd":"ls"}`, + }) + acc.Apply(Event{ + Type: EventToolCallDone, + ToolCallID: "tc_1", + Args: json.RawMessage(`{"command":"ls"}`), + }) + + resp := acc.Response(message.StopToolUse, "mistral-large") + + if resp.StopReason != message.StopToolUse { + t.Errorf("StopReason = %q, want %q", resp.StopReason, message.StopToolUse) + } + + content := resp.Message.Content + if len(content) != 2 { + t.Fatalf("len(Content) = %d, want 2", len(content)) + } + + // First block: text + if content[0].Type != message.ContentText { + t.Errorf("Content[0].Type = %v, want text", content[0].Type) + } + if content[0].Text != "I'll run that." { + t.Errorf("Content[0].Text = %q", content[0].Text) + } + + // Second block: tool call + if content[1].Type != message.ContentToolCall { + t.Errorf("Content[1].Type = %v, want tool_call", content[1].Type) + } + tc := content[1].ToolCall + if tc.ID != "tc_1" { + t.Errorf("ToolCall.ID = %q, want tc_1", tc.ID) + } + if tc.Name != "bash" { + t.Errorf("ToolCall.Name = %q, want bash", tc.Name) + } + if string(tc.Arguments) != `{"command":"ls"}` { + t.Errorf("ToolCall.Arguments = %s", tc.Arguments) + } +} + +func TestAccumulator_MultipleToolCalls(t *testing.T) { + acc := NewAccumulator() + + // Tool call 1 + acc.Apply(Event{Type: EventToolCallStart, ToolCallID: "tc_1", ToolCallName: "bash"}) + acc.Apply(Event{Type: EventToolCallDone, ToolCallID: "tc_1", Args: json.RawMessage(`{"command":"ls"}`)}) + + // Tool call 2 + acc.Apply(Event{Type: EventToolCallStart, ToolCallID: "tc_2", ToolCallName: "fs.read"}) + acc.Apply(Event{Type: EventToolCallDone, ToolCallID: "tc_2", Args: json.RawMessage(`{"path":"go.mod"}`)}) + + resp := acc.Response(message.StopToolUse, "") + calls := resp.Message.ToolCalls() + if len(calls) != 2 { + t.Fatalf("len(ToolCalls()) = %d, want 2", len(calls)) + } + if calls[0].Name != "bash" { + t.Errorf("calls[0].Name = %q", calls[0].Name) + } + if calls[1].Name != "fs.read" { + t.Errorf("calls[1].Name = %q", calls[1].Name) + } +} + +func TestAccumulator_ThinkingBlocks(t *testing.T) { + acc := NewAccumulator() + + acc.Apply(Event{Type: EventThinkingDelta, Text: "Let me think"}) + acc.Apply(Event{Type: EventThinkingDelta, Text: " about this..."}) + acc.Apply(Event{Type: EventTextDelta, Text: "Here's my answer."}) + + resp := acc.Response(message.StopEndTurn, "") + content := resp.Message.Content + + if len(content) != 2 { + t.Fatalf("len(Content) = %d, want 2", len(content)) + } + + // Thinking block first + if content[0].Type != message.ContentThinking { + t.Errorf("Content[0].Type = %v, want thinking", content[0].Type) + } + if content[0].Thinking.Text != "Let me think about this..." { + t.Errorf("Thinking.Text = %q", content[0].Thinking.Text) + } + + // Then text + if content[1].Type != message.ContentText { + t.Errorf("Content[1].Type = %v, want text", content[1].Type) + } +} + +func TestAccumulator_ToolCallDelta_Assembly(t *testing.T) { + acc := NewAccumulator() + + acc.Apply(Event{Type: EventToolCallStart, ToolCallID: "tc_1", ToolCallName: "bash"}) + // Simulate fragmented JSON + acc.Apply(Event{Type: EventToolCallDelta, ToolCallID: "tc_1", ArgDelta: `{"`}) + acc.Apply(Event{Type: EventToolCallDelta, ToolCallID: "tc_1", ArgDelta: `comman`}) + acc.Apply(Event{Type: EventToolCallDelta, ToolCallID: "tc_1", ArgDelta: `d":"`}) + acc.Apply(Event{Type: EventToolCallDelta, ToolCallID: "tc_1", ArgDelta: `echo hi`}) + acc.Apply(Event{Type: EventToolCallDelta, ToolCallID: "tc_1", ArgDelta: `"}`}) + + // Done event carries the complete args (authoritative) + acc.Apply(Event{ + Type: EventToolCallDone, + ToolCallID: "tc_1", + Args: json.RawMessage(`{"command":"echo hi"}`), + }) + + resp := acc.Response(message.StopToolUse, "") + calls := resp.Message.ToolCalls() + if len(calls) != 1 { + t.Fatalf("len(ToolCalls()) = %d, want 1", len(calls)) + } + if string(calls[0].Arguments) != `{"command":"echo hi"}` { + t.Errorf("Arguments = %s", calls[0].Arguments) + } +} + +func TestAccumulator_EmptyStream(t *testing.T) { + acc := NewAccumulator() + resp := acc.Response(message.StopEndTurn, "model-x") + + if resp.Model != "model-x" { + t.Errorf("Model = %q", resp.Model) + } + if len(resp.Message.Content) != 0 { + t.Errorf("len(Content) = %d, want 0", len(resp.Message.Content)) + } +} + +func TestAccumulator_UsageCumulative(t *testing.T) { + acc := NewAccumulator() + + acc.Apply(Event{Type: EventUsage, Usage: &message.Usage{InputTokens: 100, OutputTokens: 30}}) + acc.Apply(Event{Type: EventTextDelta, Text: "hi"}) + acc.Apply(Event{Type: EventUsage, Usage: &message.Usage{InputTokens: 0, OutputTokens: 20}}) + + resp := acc.Response(message.StopEndTurn, "") + // Last usage wins (providers typically send cumulative, not incremental) + // But our Add() is additive — so 100+0=100 input, 30+20=50 output + if resp.Usage.InputTokens != 100 { + t.Errorf("InputTokens = %d, want 100", resp.Usage.InputTokens) + } + if resp.Usage.OutputTokens != 50 { + t.Errorf("OutputTokens = %d, want 50", resp.Usage.OutputTokens) + } +} + +func TestEventType_String(t *testing.T) { + tests := []struct { + et EventType + want string + }{ + {EventTextDelta, "text_delta"}, + {EventThinkingDelta, "thinking_delta"}, + {EventToolCallStart, "tool_call_start"}, + {EventToolCallDelta, "tool_call_delta"}, + {EventToolCallDone, "tool_call_done"}, + {EventUsage, "usage"}, + {EventError, "error"}, + {EventType(99), "unknown(99)"}, + } + for _, tt := range tests { + if got := tt.et.String(); got != tt.want { + t.Errorf("EventType(%d).String() = %q, want %q", tt.et, got, tt.want) + } + } +} diff --git a/internal/stream/event.go b/internal/stream/event.go new file mode 100644 index 0000000..83609ad --- /dev/null +++ b/internal/stream/event.go @@ -0,0 +1,70 @@ +package stream + +import ( + "encoding/json" + "fmt" + + "somegit.dev/Owlibou/gnoma/internal/message" +) + +// EventType discriminates streaming events. +type EventType int + +const ( + EventTextDelta EventType = iota + 1 + EventThinkingDelta + EventToolCallStart + EventToolCallDelta + EventToolCallDone + EventUsage + EventError +) + +func (et EventType) String() string { + switch et { + case EventTextDelta: + return "text_delta" + case EventThinkingDelta: + return "thinking_delta" + case EventToolCallStart: + return "tool_call_start" + case EventToolCallDelta: + return "tool_call_delta" + case EventToolCallDone: + return "tool_call_done" + case EventUsage: + return "usage" + case EventError: + return "error" + default: + return fmt.Sprintf("unknown(%d)", et) + } +} + +// Event is a single streaming event from a provider. +type Event struct { + Type EventType + + // TextDelta, ThinkingDelta + Text string + + // ToolCallStart: ID + Name set + // ToolCallDelta: ID + ArgDelta set + // ToolCallDone: ID + Args set (complete JSON) + ToolCallID string + ToolCallName string + ArgDelta string // partial JSON fragment + Args json.RawMessage // complete arguments (on Done) + + // Usage + Usage *message.Usage + + // Error + Err error + + // StopReason — set on the final event of a stream + StopReason message.StopReason + + // Model — set on first event if available + Model string +} diff --git a/internal/stream/stream.go b/internal/stream/stream.go new file mode 100644 index 0000000..7411245 --- /dev/null +++ b/internal/stream/stream.go @@ -0,0 +1,25 @@ +package stream + +// Stream is the unified pull-based iterator over provider events. +// All provider implementations produce this interface. +// +// Usage: +// +// for s.Next() { +// event := s.Current() +// process(event) +// } +// if err := s.Err(); err != nil { +// handle(err) +// } +// s.Close() +type Stream interface { + // Next advances to the next event. Returns false when exhausted or errored. + Next() bool + // Current returns the most recent event. Only valid after Next() returns true. + Current() Event + // Err returns the first error encountered, or nil. + Err() error + // Close releases resources. Must be called after consumption. + Close() error +}