feat: add foundation types, streaming, and provider interface
internal/message/ — Content discriminated union, Message, Usage, StopReason, Response. 22 tests. internal/stream/ — Stream pull-based iterator interface, Event types, Accumulator (assembles Response from events). 8 tests. internal/provider/ — Provider interface, Request, ToolDefinition, Registry with factory pattern, ProviderError with HTTP status classification. errors.AsType[E] for Go 1.26. 13 tests. 43 tests total, all passing.
This commit is contained in:
78
internal/message/content.go
Normal file
78
internal/message/content.go
Normal file
@@ -0,0 +1,78 @@
|
||||
package message
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// ContentType discriminates the content block union.
|
||||
type ContentType int
|
||||
|
||||
const (
|
||||
ContentText ContentType = iota + 1
|
||||
ContentToolCall
|
||||
ContentToolResult
|
||||
ContentThinking
|
||||
)
|
||||
|
||||
func (ct ContentType) String() string {
|
||||
switch ct {
|
||||
case ContentText:
|
||||
return "text"
|
||||
case ContentToolCall:
|
||||
return "tool_call"
|
||||
case ContentToolResult:
|
||||
return "tool_result"
|
||||
case ContentThinking:
|
||||
return "thinking"
|
||||
default:
|
||||
return fmt.Sprintf("unknown(%d)", ct)
|
||||
}
|
||||
}
|
||||
|
||||
// Content is a discriminated union. Exactly one payload field is set per Type.
|
||||
type Content struct {
|
||||
Type ContentType
|
||||
Text string // ContentText
|
||||
ToolCall *ToolCall // ContentToolCall
|
||||
ToolResult *ToolResult // ContentToolResult
|
||||
Thinking *Thinking // ContentThinking
|
||||
}
|
||||
|
||||
// ToolCall represents the model's request to invoke a tool.
|
||||
type ToolCall struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Arguments json.RawMessage `json:"arguments"`
|
||||
}
|
||||
|
||||
// ToolResult is the output of executing a tool, correlated by ToolCallID.
|
||||
type ToolResult struct {
|
||||
ToolCallID string `json:"tool_call_id"`
|
||||
Content string `json:"content"`
|
||||
IsError bool `json:"is_error"`
|
||||
}
|
||||
|
||||
// Thinking represents a reasoning/thinking trace.
|
||||
// Signature must round-trip unchanged (Anthropic requirement).
|
||||
type Thinking struct {
|
||||
Text string `json:"text,omitempty"`
|
||||
Signature string `json:"signature,omitempty"`
|
||||
Redacted bool `json:"redacted,omitempty"`
|
||||
}
|
||||
|
||||
func NewTextContent(text string) Content {
|
||||
return Content{Type: ContentText, Text: text}
|
||||
}
|
||||
|
||||
func NewToolCallContent(tc ToolCall) Content {
|
||||
return Content{Type: ContentToolCall, ToolCall: &tc}
|
||||
}
|
||||
|
||||
func NewToolResultContent(tr ToolResult) Content {
|
||||
return Content{Type: ContentToolResult, ToolResult: &tr}
|
||||
}
|
||||
|
||||
func NewThinkingContent(th Thinking) Content {
|
||||
return Content{Type: ContentThinking, Thinking: &th}
|
||||
}
|
||||
174
internal/message/content_test.go
Normal file
174
internal/message/content_test.go
Normal file
@@ -0,0 +1,174 @@
|
||||
package message
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewTextContent(t *testing.T) {
|
||||
c := NewTextContent("hello world")
|
||||
|
||||
if c.Type != ContentText {
|
||||
t.Errorf("Type = %v, want %v", c.Type, ContentText)
|
||||
}
|
||||
if c.Text != "hello world" {
|
||||
t.Errorf("Text = %q, want %q", c.Text, "hello world")
|
||||
}
|
||||
if c.ToolCall != nil {
|
||||
t.Error("ToolCall should be nil for text content")
|
||||
}
|
||||
if c.ToolResult != nil {
|
||||
t.Error("ToolResult should be nil for text content")
|
||||
}
|
||||
if c.Thinking != nil {
|
||||
t.Error("Thinking should be nil for text content")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewToolCallContent(t *testing.T) {
|
||||
args := json.RawMessage(`{"command":"ls -la"}`)
|
||||
tc := ToolCall{
|
||||
ID: "tc_001",
|
||||
Name: "bash",
|
||||
Arguments: args,
|
||||
}
|
||||
c := NewToolCallContent(tc)
|
||||
|
||||
if c.Type != ContentToolCall {
|
||||
t.Errorf("Type = %v, want %v", c.Type, ContentToolCall)
|
||||
}
|
||||
if c.ToolCall == nil {
|
||||
t.Fatal("ToolCall should not be nil")
|
||||
}
|
||||
if c.ToolCall.ID != "tc_001" {
|
||||
t.Errorf("ToolCall.ID = %q, want %q", c.ToolCall.ID, "tc_001")
|
||||
}
|
||||
if c.ToolCall.Name != "bash" {
|
||||
t.Errorf("ToolCall.Name = %q, want %q", c.ToolCall.Name, "bash")
|
||||
}
|
||||
if string(c.ToolCall.Arguments) != `{"command":"ls -la"}` {
|
||||
t.Errorf("ToolCall.Arguments = %s, want %s", c.ToolCall.Arguments, args)
|
||||
}
|
||||
if c.Text != "" {
|
||||
t.Error("Text should be empty for tool call content")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewToolResultContent(t *testing.T) {
|
||||
tr := ToolResult{
|
||||
ToolCallID: "tc_001",
|
||||
Content: "file1.go\nfile2.go",
|
||||
IsError: false,
|
||||
}
|
||||
c := NewToolResultContent(tr)
|
||||
|
||||
if c.Type != ContentToolResult {
|
||||
t.Errorf("Type = %v, want %v", c.Type, ContentToolResult)
|
||||
}
|
||||
if c.ToolResult == nil {
|
||||
t.Fatal("ToolResult should not be nil")
|
||||
}
|
||||
if c.ToolResult.ToolCallID != "tc_001" {
|
||||
t.Errorf("ToolResult.ToolCallID = %q, want %q", c.ToolResult.ToolCallID, "tc_001")
|
||||
}
|
||||
if c.ToolResult.IsError {
|
||||
t.Error("ToolResult.IsError should be false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewToolResultContent_Error(t *testing.T) {
|
||||
tr := ToolResult{
|
||||
ToolCallID: "tc_002",
|
||||
Content: "permission denied",
|
||||
IsError: true,
|
||||
}
|
||||
c := NewToolResultContent(tr)
|
||||
|
||||
if !c.ToolResult.IsError {
|
||||
t.Error("ToolResult.IsError should be true")
|
||||
}
|
||||
if c.ToolResult.Content != "permission denied" {
|
||||
t.Errorf("ToolResult.Content = %q, want %q", c.ToolResult.Content, "permission denied")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewThinkingContent(t *testing.T) {
|
||||
th := Thinking{
|
||||
Text: "Let me think about this...",
|
||||
Signature: "sig_abc123",
|
||||
}
|
||||
c := NewThinkingContent(th)
|
||||
|
||||
if c.Type != ContentThinking {
|
||||
t.Errorf("Type = %v, want %v", c.Type, ContentThinking)
|
||||
}
|
||||
if c.Thinking == nil {
|
||||
t.Fatal("Thinking should not be nil")
|
||||
}
|
||||
if c.Thinking.Text != "Let me think about this..." {
|
||||
t.Errorf("Thinking.Text = %q", c.Thinking.Text)
|
||||
}
|
||||
if c.Thinking.Signature != "sig_abc123" {
|
||||
t.Errorf("Thinking.Signature = %q", c.Thinking.Signature)
|
||||
}
|
||||
if c.Thinking.Redacted {
|
||||
t.Error("Thinking.Redacted should be false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewRedactedThinkingContent(t *testing.T) {
|
||||
th := Thinking{
|
||||
Redacted: true,
|
||||
}
|
||||
c := NewThinkingContent(th)
|
||||
|
||||
if !c.Thinking.Redacted {
|
||||
t.Error("Thinking.Redacted should be true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestContentType_String(t *testing.T) {
|
||||
tests := []struct {
|
||||
ct ContentType
|
||||
want string
|
||||
}{
|
||||
{ContentText, "text"},
|
||||
{ContentToolCall, "tool_call"},
|
||||
{ContentToolResult, "tool_result"},
|
||||
{ContentThinking, "thinking"},
|
||||
{ContentType(99), "unknown(99)"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
if got := tt.ct.String(); got != tt.want {
|
||||
t.Errorf("ContentType(%d).String() = %q, want %q", tt.ct, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolCall_JSON_RoundTrip(t *testing.T) {
|
||||
original := ToolCall{
|
||||
ID: "tc_100",
|
||||
Name: "fs.read",
|
||||
Arguments: json.RawMessage(`{"path":"/tmp/test.go","offset":0}`),
|
||||
}
|
||||
|
||||
data, err := json.Marshal(original)
|
||||
if err != nil {
|
||||
t.Fatalf("Marshal: %v", err)
|
||||
}
|
||||
|
||||
var decoded ToolCall
|
||||
if err := json.Unmarshal(data, &decoded); err != nil {
|
||||
t.Fatalf("Unmarshal: %v", err)
|
||||
}
|
||||
|
||||
if decoded.ID != original.ID {
|
||||
t.Errorf("ID = %q, want %q", decoded.ID, original.ID)
|
||||
}
|
||||
if decoded.Name != original.Name {
|
||||
t.Errorf("Name = %q, want %q", decoded.Name, original.Name)
|
||||
}
|
||||
if string(decoded.Arguments) != string(original.Arguments) {
|
||||
t.Errorf("Arguments = %s, want %s", decoded.Arguments, original.Arguments)
|
||||
}
|
||||
}
|
||||
89
internal/message/message.go
Normal file
89
internal/message/message.go
Normal file
@@ -0,0 +1,89 @@
|
||||
package message
|
||||
|
||||
import "strings"
|
||||
|
||||
// Role identifies the sender of a message.
|
||||
type Role string
|
||||
|
||||
const (
|
||||
RoleUser Role = "user"
|
||||
RoleAssistant Role = "assistant"
|
||||
RoleSystem Role = "system"
|
||||
)
|
||||
|
||||
// Message represents a single turn in the conversation.
|
||||
type Message struct {
|
||||
Role Role
|
||||
Content []Content
|
||||
}
|
||||
|
||||
func NewUserText(text string) Message {
|
||||
return Message{
|
||||
Role: RoleUser,
|
||||
Content: []Content{NewTextContent(text)},
|
||||
}
|
||||
}
|
||||
|
||||
func NewAssistantText(text string) Message {
|
||||
return Message{
|
||||
Role: RoleAssistant,
|
||||
Content: []Content{NewTextContent(text)},
|
||||
}
|
||||
}
|
||||
|
||||
func NewAssistantContent(blocks ...Content) Message {
|
||||
return Message{
|
||||
Role: RoleAssistant,
|
||||
Content: blocks,
|
||||
}
|
||||
}
|
||||
|
||||
func NewSystemText(text string) Message {
|
||||
return Message{
|
||||
Role: RoleSystem,
|
||||
Content: []Content{NewTextContent(text)},
|
||||
}
|
||||
}
|
||||
|
||||
func NewToolResults(results ...ToolResult) Message {
|
||||
content := make([]Content, len(results))
|
||||
for i, r := range results {
|
||||
content[i] = NewToolResultContent(r)
|
||||
}
|
||||
return Message{
|
||||
Role: RoleUser,
|
||||
Content: content,
|
||||
}
|
||||
}
|
||||
|
||||
// HasToolCalls returns true if any content block is a tool call.
|
||||
func (m Message) HasToolCalls() bool {
|
||||
for _, c := range m.Content {
|
||||
if c.Type == ContentToolCall {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ToolCalls extracts all tool call blocks.
|
||||
func (m Message) ToolCalls() []ToolCall {
|
||||
var calls []ToolCall
|
||||
for _, c := range m.Content {
|
||||
if c.Type == ContentToolCall && c.ToolCall != nil {
|
||||
calls = append(calls, *c.ToolCall)
|
||||
}
|
||||
}
|
||||
return calls
|
||||
}
|
||||
|
||||
// TextContent concatenates all text blocks.
|
||||
func (m Message) TextContent() string {
|
||||
var b strings.Builder
|
||||
for _, c := range m.Content {
|
||||
if c.Type == ContentText {
|
||||
b.WriteString(c.Text)
|
||||
}
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
214
internal/message/message_test.go
Normal file
214
internal/message/message_test.go
Normal file
@@ -0,0 +1,214 @@
|
||||
package message
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewUserText(t *testing.T) {
|
||||
m := NewUserText("hello")
|
||||
if m.Role != RoleUser {
|
||||
t.Errorf("Role = %q, want %q", m.Role, RoleUser)
|
||||
}
|
||||
if len(m.Content) != 1 {
|
||||
t.Fatalf("len(Content) = %d, want 1", len(m.Content))
|
||||
}
|
||||
if m.Content[0].Type != ContentText {
|
||||
t.Errorf("Content[0].Type = %v, want %v", m.Content[0].Type, ContentText)
|
||||
}
|
||||
if m.Content[0].Text != "hello" {
|
||||
t.Errorf("Content[0].Text = %q, want %q", m.Content[0].Text, "hello")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewAssistantText(t *testing.T) {
|
||||
m := NewAssistantText("response")
|
||||
if m.Role != RoleAssistant {
|
||||
t.Errorf("Role = %q, want %q", m.Role, RoleAssistant)
|
||||
}
|
||||
if m.TextContent() != "response" {
|
||||
t.Errorf("TextContent() = %q, want %q", m.TextContent(), "response")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewSystemText(t *testing.T) {
|
||||
m := NewSystemText("you are a helper")
|
||||
if m.Role != RoleSystem {
|
||||
t.Errorf("Role = %q, want %q", m.Role, RoleSystem)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewAssistantContent_Mixed(t *testing.T) {
|
||||
m := NewAssistantContent(
|
||||
NewTextContent("I'll run that command."),
|
||||
NewToolCallContent(ToolCall{
|
||||
ID: "tc_1",
|
||||
Name: "bash",
|
||||
Arguments: json.RawMessage(`{"command":"ls"}`),
|
||||
}),
|
||||
)
|
||||
|
||||
if m.Role != RoleAssistant {
|
||||
t.Errorf("Role = %q, want %q", m.Role, RoleAssistant)
|
||||
}
|
||||
if len(m.Content) != 2 {
|
||||
t.Fatalf("len(Content) = %d, want 2", len(m.Content))
|
||||
}
|
||||
if m.Content[0].Type != ContentText {
|
||||
t.Errorf("Content[0].Type = %v, want text", m.Content[0].Type)
|
||||
}
|
||||
if m.Content[1].Type != ContentToolCall {
|
||||
t.Errorf("Content[1].Type = %v, want tool_call", m.Content[1].Type)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewToolResults(t *testing.T) {
|
||||
m := NewToolResults(
|
||||
ToolResult{ToolCallID: "tc_1", Content: "output1"},
|
||||
ToolResult{ToolCallID: "tc_2", Content: "output2", IsError: true},
|
||||
)
|
||||
|
||||
if m.Role != RoleUser {
|
||||
t.Errorf("Role = %q, want %q", m.Role, RoleUser)
|
||||
}
|
||||
if len(m.Content) != 2 {
|
||||
t.Fatalf("len(Content) = %d, want 2", len(m.Content))
|
||||
}
|
||||
if m.Content[0].ToolResult.ToolCallID != "tc_1" {
|
||||
t.Errorf("Content[0].ToolResult.ToolCallID = %q", m.Content[0].ToolResult.ToolCallID)
|
||||
}
|
||||
if m.Content[1].ToolResult.IsError != true {
|
||||
t.Error("Content[1].ToolResult.IsError should be true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessage_HasToolCalls(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
msg Message
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "text only",
|
||||
msg: NewUserText("hello"),
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "with tool call",
|
||||
msg: NewAssistantContent(
|
||||
NewTextContent("running..."),
|
||||
NewToolCallContent(ToolCall{ID: "tc_1", Name: "bash"}),
|
||||
),
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "tool results (not calls)",
|
||||
msg: NewToolResults(ToolResult{ToolCallID: "tc_1", Content: "ok"}),
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "empty message",
|
||||
msg: Message{Role: RoleAssistant},
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := tt.msg.HasToolCalls(); got != tt.want {
|
||||
t.Errorf("HasToolCalls() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessage_ToolCalls(t *testing.T) {
|
||||
m := NewAssistantContent(
|
||||
NewTextContent("here are two commands"),
|
||||
NewToolCallContent(ToolCall{ID: "tc_1", Name: "bash", Arguments: json.RawMessage(`{"command":"ls"}`)}),
|
||||
NewTextContent("and another"),
|
||||
NewToolCallContent(ToolCall{ID: "tc_2", Name: "fs.read", Arguments: json.RawMessage(`{"path":"go.mod"}`)}),
|
||||
)
|
||||
|
||||
calls := m.ToolCalls()
|
||||
if len(calls) != 2 {
|
||||
t.Fatalf("len(ToolCalls()) = %d, want 2", len(calls))
|
||||
}
|
||||
if calls[0].ID != "tc_1" {
|
||||
t.Errorf("calls[0].ID = %q, want tc_1", calls[0].ID)
|
||||
}
|
||||
if calls[1].Name != "fs.read" {
|
||||
t.Errorf("calls[1].Name = %q, want fs.read", calls[1].Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessage_ToolCalls_Empty(t *testing.T) {
|
||||
m := NewUserText("no tools here")
|
||||
calls := m.ToolCalls()
|
||||
if len(calls) != 0 {
|
||||
t.Errorf("len(ToolCalls()) = %d, want 0", len(calls))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessage_TextContent(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
msg Message
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "single text",
|
||||
msg: NewUserText("hello"),
|
||||
want: "hello",
|
||||
},
|
||||
{
|
||||
name: "multiple text blocks",
|
||||
msg: NewAssistantContent(
|
||||
NewTextContent("first "),
|
||||
NewToolCallContent(ToolCall{ID: "tc_1", Name: "bash"}),
|
||||
NewTextContent("second"),
|
||||
),
|
||||
want: "first second",
|
||||
},
|
||||
{
|
||||
name: "no text",
|
||||
msg: NewToolResults(ToolResult{ToolCallID: "tc_1", Content: "output"}),
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "empty message",
|
||||
msg: Message{Role: RoleAssistant},
|
||||
want: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := tt.msg.TextContent(); got != tt.want {
|
||||
t.Errorf("TextContent() = %q, want %q", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestResponse_Fields(t *testing.T) {
|
||||
r := Response{
|
||||
Message: NewAssistantText("done"),
|
||||
StopReason: StopEndTurn,
|
||||
Usage: Usage{InputTokens: 100, OutputTokens: 50},
|
||||
Model: "mistral-large-latest",
|
||||
}
|
||||
|
||||
if r.StopReason != StopEndTurn {
|
||||
t.Errorf("StopReason = %q, want %q", r.StopReason, StopEndTurn)
|
||||
}
|
||||
if r.Usage.TotalTokens() != 150 {
|
||||
t.Errorf("Usage.TotalTokens() = %d, want 150", r.Usage.TotalTokens())
|
||||
}
|
||||
if r.Model != "mistral-large-latest" {
|
||||
t.Errorf("Model = %q", r.Model)
|
||||
}
|
||||
if r.Message.TextContent() != "done" {
|
||||
t.Errorf("Message.TextContent() = %q", r.Message.TextContent())
|
||||
}
|
||||
}
|
||||
9
internal/message/response.go
Normal file
9
internal/message/response.go
Normal file
@@ -0,0 +1,9 @@
|
||||
package message
|
||||
|
||||
// Response wraps a completed assistant turn.
|
||||
type Response struct {
|
||||
Message Message
|
||||
StopReason StopReason
|
||||
Usage Usage
|
||||
Model string
|
||||
}
|
||||
11
internal/message/stop.go
Normal file
11
internal/message/stop.go
Normal file
@@ -0,0 +1,11 @@
|
||||
package message
|
||||
|
||||
// StopReason indicates why the model stopped generating.
|
||||
type StopReason string
|
||||
|
||||
const (
|
||||
StopEndTurn StopReason = "end_turn"
|
||||
StopMaxTokens StopReason = "max_tokens"
|
||||
StopToolUse StopReason = "tool_use"
|
||||
StopSequence StopReason = "stop_sequence"
|
||||
)
|
||||
20
internal/message/usage.go
Normal file
20
internal/message/usage.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package message
|
||||
|
||||
// Usage tracks token consumption for a single API turn.
|
||||
type Usage struct {
|
||||
InputTokens int64 `json:"input_tokens"`
|
||||
OutputTokens int64 `json:"output_tokens"`
|
||||
CacheReadTokens int64 `json:"cache_read_tokens,omitempty"`
|
||||
CacheCreationTokens int64 `json:"cache_creation_tokens,omitempty"`
|
||||
}
|
||||
|
||||
func (u Usage) TotalTokens() int64 {
|
||||
return u.InputTokens + u.OutputTokens
|
||||
}
|
||||
|
||||
func (u *Usage) Add(other Usage) {
|
||||
u.InputTokens += other.InputTokens
|
||||
u.OutputTokens += other.OutputTokens
|
||||
u.CacheReadTokens += other.CacheReadTokens
|
||||
u.CacheCreationTokens += other.CacheCreationTokens
|
||||
}
|
||||
70
internal/message/usage_test.go
Normal file
70
internal/message/usage_test.go
Normal file
@@ -0,0 +1,70 @@
|
||||
package message
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestUsage_TotalTokens(t *testing.T) {
|
||||
u := Usage{InputTokens: 100, OutputTokens: 50}
|
||||
if got := u.TotalTokens(); got != 150 {
|
||||
t.Errorf("TotalTokens() = %d, want 150", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUsage_TotalTokens_Zero(t *testing.T) {
|
||||
var u Usage
|
||||
if got := u.TotalTokens(); got != 0 {
|
||||
t.Errorf("TotalTokens() = %d, want 0", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUsage_Add(t *testing.T) {
|
||||
u := Usage{
|
||||
InputTokens: 100,
|
||||
OutputTokens: 50,
|
||||
CacheReadTokens: 10,
|
||||
CacheCreationTokens: 5,
|
||||
}
|
||||
other := Usage{
|
||||
InputTokens: 200,
|
||||
OutputTokens: 80,
|
||||
CacheReadTokens: 20,
|
||||
CacheCreationTokens: 15,
|
||||
}
|
||||
|
||||
u.Add(other)
|
||||
|
||||
if u.InputTokens != 300 {
|
||||
t.Errorf("InputTokens = %d, want 300", u.InputTokens)
|
||||
}
|
||||
if u.OutputTokens != 130 {
|
||||
t.Errorf("OutputTokens = %d, want 130", u.OutputTokens)
|
||||
}
|
||||
if u.CacheReadTokens != 30 {
|
||||
t.Errorf("CacheReadTokens = %d, want 30", u.CacheReadTokens)
|
||||
}
|
||||
if u.CacheCreationTokens != 20 {
|
||||
t.Errorf("CacheCreationTokens = %d, want 20", u.CacheCreationTokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUsage_Add_Multiple(t *testing.T) {
|
||||
var total Usage
|
||||
turns := []Usage{
|
||||
{InputTokens: 100, OutputTokens: 50},
|
||||
{InputTokens: 200, OutputTokens: 80},
|
||||
{InputTokens: 150, OutputTokens: 60},
|
||||
}
|
||||
|
||||
for _, turn := range turns {
|
||||
total.Add(turn)
|
||||
}
|
||||
|
||||
if total.InputTokens != 450 {
|
||||
t.Errorf("InputTokens = %d, want 450", total.InputTokens)
|
||||
}
|
||||
if total.OutputTokens != 190 {
|
||||
t.Errorf("OutputTokens = %d, want 190", total.OutputTokens)
|
||||
}
|
||||
if total.TotalTokens() != 640 {
|
||||
t.Errorf("TotalTokens() = %d, want 640", total.TotalTokens())
|
||||
}
|
||||
}
|
||||
79
internal/provider/errors.go
Normal file
79
internal/provider/errors.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ErrorKind classifies provider errors for retry decisions.
|
||||
type ErrorKind int
|
||||
|
||||
const (
|
||||
ErrTransient ErrorKind = iota + 1 // 429, 500, 502, 503, 529 — retry with backoff
|
||||
ErrAuth // 401, 403 — don't retry
|
||||
ErrBadRequest // 400 — don't retry, fix request
|
||||
ErrNotFound // 404 — model/endpoint not found
|
||||
ErrOverloaded // capacity exhausted — backoff + retry
|
||||
)
|
||||
|
||||
func (k ErrorKind) String() string {
|
||||
switch k {
|
||||
case ErrTransient:
|
||||
return "transient"
|
||||
case ErrAuth:
|
||||
return "auth"
|
||||
case ErrBadRequest:
|
||||
return "bad_request"
|
||||
case ErrNotFound:
|
||||
return "not_found"
|
||||
case ErrOverloaded:
|
||||
return "overloaded"
|
||||
default:
|
||||
return fmt.Sprintf("unknown(%d)", k)
|
||||
}
|
||||
}
|
||||
|
||||
// ProviderError wraps an SDK error with classification metadata.
|
||||
type ProviderError struct {
|
||||
Kind ErrorKind
|
||||
Provider string
|
||||
StatusCode int
|
||||
Message string
|
||||
Retryable bool
|
||||
RetryAfter time.Duration // from Retry-After or rate limit headers
|
||||
Err error // underlying SDK error
|
||||
}
|
||||
|
||||
func (e *ProviderError) Error() string {
|
||||
if e.Err != nil {
|
||||
return fmt.Sprintf("%s %s (%d): %s: %v", e.Provider, e.Kind, e.StatusCode, e.Message, e.Err)
|
||||
}
|
||||
return fmt.Sprintf("%s %s (%d): %s", e.Provider, e.Kind, e.StatusCode, e.Message)
|
||||
}
|
||||
|
||||
func (e *ProviderError) Unwrap() error {
|
||||
return e.Err
|
||||
}
|
||||
|
||||
// ClassifyHTTPStatus returns the ErrorKind and retryability for an HTTP status code.
|
||||
func ClassifyHTTPStatus(status int) (ErrorKind, bool) {
|
||||
switch {
|
||||
case status == 401 || status == 403:
|
||||
return ErrAuth, false
|
||||
case status == 400:
|
||||
return ErrBadRequest, false
|
||||
case status == 404:
|
||||
return ErrNotFound, false
|
||||
case status == 429 || status == 529:
|
||||
return ErrTransient, true
|
||||
case status == 500 || status == 502 || status == 503:
|
||||
return ErrTransient, true
|
||||
case status == 504:
|
||||
return ErrOverloaded, true
|
||||
default:
|
||||
if status >= 500 {
|
||||
return ErrTransient, true
|
||||
}
|
||||
return ErrBadRequest, false
|
||||
}
|
||||
}
|
||||
118
internal/provider/errors_test.go
Normal file
118
internal/provider/errors_test.go
Normal file
@@ -0,0 +1,118 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestProviderError_Error(t *testing.T) {
|
||||
err := &ProviderError{
|
||||
Kind: ErrTransient,
|
||||
Provider: "mistral",
|
||||
StatusCode: 429,
|
||||
Message: "rate limited",
|
||||
}
|
||||
got := err.Error()
|
||||
want := "mistral transient (429): rate limited"
|
||||
if got != want {
|
||||
t.Errorf("Error() = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderError_Error_WithWrapped(t *testing.T) {
|
||||
inner := errors.New("connection reset")
|
||||
err := &ProviderError{
|
||||
Kind: ErrTransient,
|
||||
Provider: "openai",
|
||||
StatusCode: 502,
|
||||
Message: "bad gateway",
|
||||
Err: inner,
|
||||
}
|
||||
got := err.Error()
|
||||
want := "openai transient (502): bad gateway: connection reset"
|
||||
if got != want {
|
||||
t.Errorf("Error() = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderError_Unwrap(t *testing.T) {
|
||||
inner := errors.New("timeout")
|
||||
err := &ProviderError{
|
||||
Kind: ErrTransient,
|
||||
Err: inner,
|
||||
}
|
||||
if !errors.Is(err, inner) {
|
||||
t.Error("errors.Is should find inner error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderError_AsType(t *testing.T) {
|
||||
inner := &ProviderError{
|
||||
Kind: ErrAuth,
|
||||
Provider: "anthropic",
|
||||
StatusCode: 401,
|
||||
Message: "invalid key",
|
||||
}
|
||||
wrapped := fmt.Errorf("api call failed: %w", inner)
|
||||
|
||||
pErr, ok := errors.AsType[*ProviderError](wrapped)
|
||||
if !ok {
|
||||
t.Fatal("errors.AsType should find ProviderError")
|
||||
}
|
||||
if pErr.Kind != ErrAuth {
|
||||
t.Errorf("Kind = %v, want %v", pErr.Kind, ErrAuth)
|
||||
}
|
||||
if pErr.Provider != "anthropic" {
|
||||
t.Errorf("Provider = %q", pErr.Provider)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyHTTPStatus(t *testing.T) {
|
||||
tests := []struct {
|
||||
status int
|
||||
wantKind ErrorKind
|
||||
wantRetry bool
|
||||
}{
|
||||
{200, ErrBadRequest, false}, // shouldn't happen, but safe default
|
||||
{400, ErrBadRequest, false},
|
||||
{401, ErrAuth, false},
|
||||
{403, ErrAuth, false},
|
||||
{404, ErrNotFound, false},
|
||||
{429, ErrTransient, true},
|
||||
{500, ErrTransient, true},
|
||||
{502, ErrTransient, true},
|
||||
{503, ErrTransient, true},
|
||||
{504, ErrOverloaded, true},
|
||||
{529, ErrTransient, true},
|
||||
{599, ErrTransient, true}, // unknown 5xx
|
||||
}
|
||||
for _, tt := range tests {
|
||||
kind, retry := ClassifyHTTPStatus(tt.status)
|
||||
if kind != tt.wantKind {
|
||||
t.Errorf("ClassifyHTTPStatus(%d) kind = %v, want %v", tt.status, kind, tt.wantKind)
|
||||
}
|
||||
if retry != tt.wantRetry {
|
||||
t.Errorf("ClassifyHTTPStatus(%d) retry = %v, want %v", tt.status, retry, tt.wantRetry)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrorKind_String(t *testing.T) {
|
||||
tests := []struct {
|
||||
kind ErrorKind
|
||||
want string
|
||||
}{
|
||||
{ErrTransient, "transient"},
|
||||
{ErrAuth, "auth"},
|
||||
{ErrBadRequest, "bad_request"},
|
||||
{ErrNotFound, "not_found"},
|
||||
{ErrOverloaded, "overloaded"},
|
||||
{ErrorKind(99), "unknown(99)"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
if got := tt.kind.String(); got != tt.want {
|
||||
t.Errorf("ErrorKind(%d).String() = %q, want %q", tt.kind, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
44
internal/provider/provider.go
Normal file
44
internal/provider/provider.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
|
||||
"somegit.dev/Owlibou/gnoma/internal/message"
|
||||
"somegit.dev/Owlibou/gnoma/internal/stream"
|
||||
)
|
||||
|
||||
// Request encapsulates everything needed for a single LLM API call.
|
||||
type Request struct {
|
||||
Model string
|
||||
SystemPrompt string
|
||||
Messages []message.Message
|
||||
Tools []ToolDefinition
|
||||
MaxTokens int64
|
||||
Temperature *float64
|
||||
TopP *float64
|
||||
TopK *int64
|
||||
StopSequences []string
|
||||
Thinking *ThinkingConfig
|
||||
}
|
||||
|
||||
// ToolDefinition is the provider-agnostic tool schema.
|
||||
type ToolDefinition struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Parameters json.RawMessage `json:"parameters"` // JSON Schema passthrough
|
||||
}
|
||||
|
||||
// ThinkingConfig controls extended thinking / reasoning.
|
||||
type ThinkingConfig struct {
|
||||
BudgetTokens int64
|
||||
}
|
||||
|
||||
// Provider is the core abstraction over all LLM backends.
|
||||
type Provider interface {
|
||||
// Stream initiates a streaming request and returns an event stream.
|
||||
Stream(ctx context.Context, req Request) (stream.Stream, error)
|
||||
|
||||
// Name returns the provider identifier (e.g., "mistral", "anthropic").
|
||||
Name() string
|
||||
}
|
||||
69
internal/provider/registry.go
Normal file
69
internal/provider/registry.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// ProviderConfig is the common configuration for any provider.
|
||||
type ProviderConfig struct {
|
||||
Name string
|
||||
APIKey string
|
||||
BaseURL string // override for OpenAI-compat endpoints
|
||||
Model string // default model for this provider
|
||||
Options map[string]any // provider-specific options
|
||||
}
|
||||
|
||||
// Factory creates a Provider from configuration.
|
||||
type Factory func(cfg ProviderConfig) (Provider, error)
|
||||
|
||||
// Registry maps provider names to factory functions.
|
||||
type Registry struct {
|
||||
mu sync.RWMutex
|
||||
factories map[string]Factory
|
||||
}
|
||||
|
||||
func NewRegistry() *Registry {
|
||||
return &Registry{
|
||||
factories: make(map[string]Factory),
|
||||
}
|
||||
}
|
||||
|
||||
// Register adds a provider factory. Overwrites if name already exists.
|
||||
func (r *Registry) Register(name string, f Factory) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.factories[name] = f
|
||||
}
|
||||
|
||||
// Create instantiates a provider by name with the given config.
|
||||
func (r *Registry) Create(name string, cfg ProviderConfig) (Provider, error) {
|
||||
r.mu.RLock()
|
||||
f, ok := r.factories[name]
|
||||
r.mu.RUnlock()
|
||||
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unknown provider: %q", name)
|
||||
}
|
||||
cfg.Name = name
|
||||
return f(cfg)
|
||||
}
|
||||
|
||||
// Has returns true if a factory is registered for the given name.
|
||||
func (r *Registry) Has(name string) bool {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
_, ok := r.factories[name]
|
||||
return ok
|
||||
}
|
||||
|
||||
// Names returns all registered provider names.
|
||||
func (r *Registry) Names() []string {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
names := make([]string, 0, len(r.factories))
|
||||
for name := range r.factories {
|
||||
names = append(names, name)
|
||||
}
|
||||
return names
|
||||
}
|
||||
131
internal/provider/registry_test.go
Normal file
131
internal/provider/registry_test.go
Normal file
@@ -0,0 +1,131 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"slices"
|
||||
"sort"
|
||||
"testing"
|
||||
|
||||
"somegit.dev/Owlibou/gnoma/internal/stream"
|
||||
)
|
||||
|
||||
// mockProvider implements Provider for testing.
|
||||
type mockProvider struct {
|
||||
name string
|
||||
}
|
||||
|
||||
func (m *mockProvider) Stream(_ context.Context, _ Request) (stream.Stream, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockProvider) Name() string {
|
||||
return m.name
|
||||
}
|
||||
|
||||
func TestRegistry_RegisterAndCreate(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
|
||||
r.Register("mock", func(cfg ProviderConfig) (Provider, error) {
|
||||
return &mockProvider{name: cfg.Name}, nil
|
||||
})
|
||||
|
||||
p, err := r.Create("mock", ProviderConfig{})
|
||||
if err != nil {
|
||||
t.Fatalf("Create: %v", err)
|
||||
}
|
||||
if p.Name() != "mock" {
|
||||
t.Errorf("Name() = %q, want %q", p.Name(), "mock")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistry_Create_Unknown(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
|
||||
_, err := r.Create("nonexistent", ProviderConfig{})
|
||||
if err == nil {
|
||||
t.Fatal("expected error for unknown provider")
|
||||
}
|
||||
want := `unknown provider: "nonexistent"`
|
||||
if err.Error() != want {
|
||||
t.Errorf("error = %q, want %q", err.Error(), want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistry_Create_FactoryError(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
r.Register("broken", func(cfg ProviderConfig) (Provider, error) {
|
||||
return nil, errors.New("missing api key")
|
||||
})
|
||||
|
||||
_, err := r.Create("broken", ProviderConfig{})
|
||||
if err == nil {
|
||||
t.Fatal("expected error from factory")
|
||||
}
|
||||
if err.Error() != "missing api key" {
|
||||
t.Errorf("error = %q", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistry_Create_SetsName(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
|
||||
var receivedName string
|
||||
r.Register("test", func(cfg ProviderConfig) (Provider, error) {
|
||||
receivedName = cfg.Name
|
||||
return &mockProvider{name: cfg.Name}, nil
|
||||
})
|
||||
|
||||
_, _ = r.Create("test", ProviderConfig{APIKey: "sk-123"})
|
||||
if receivedName != "test" {
|
||||
t.Errorf("factory received Name = %q, want %q", receivedName, "test")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistry_Has(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
r.Register("exists", func(cfg ProviderConfig) (Provider, error) {
|
||||
return nil, nil
|
||||
})
|
||||
|
||||
if !r.Has("exists") {
|
||||
t.Error("Has(exists) = false, want true")
|
||||
}
|
||||
if r.Has("nope") {
|
||||
t.Error("Has(nope) = true, want false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistry_Names(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
r.Register("alpha", func(cfg ProviderConfig) (Provider, error) { return nil, nil })
|
||||
r.Register("beta", func(cfg ProviderConfig) (Provider, error) { return nil, nil })
|
||||
r.Register("gamma", func(cfg ProviderConfig) (Provider, error) { return nil, nil })
|
||||
|
||||
names := r.Names()
|
||||
sort.Strings(names)
|
||||
|
||||
want := []string{"alpha", "beta", "gamma"}
|
||||
if !slices.Equal(names, want) {
|
||||
t.Errorf("Names() = %v, want %v", names, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistry_Register_Overwrite(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
|
||||
r.Register("dup", func(cfg ProviderConfig) (Provider, error) {
|
||||
return &mockProvider{name: "old"}, nil
|
||||
})
|
||||
r.Register("dup", func(cfg ProviderConfig) (Provider, error) {
|
||||
return &mockProvider{name: "new"}, nil
|
||||
})
|
||||
|
||||
p, err := r.Create("dup", ProviderConfig{})
|
||||
if err != nil {
|
||||
t.Fatalf("Create: %v", err)
|
||||
}
|
||||
if p.Name() != "new" {
|
||||
t.Errorf("Name() = %q, want %q (overwritten factory)", p.Name(), "new")
|
||||
}
|
||||
}
|
||||
143
internal/stream/accumulator.go
Normal file
143
internal/stream/accumulator.go
Normal file
@@ -0,0 +1,143 @@
|
||||
package stream
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"somegit.dev/Owlibou/gnoma/internal/message"
|
||||
)
|
||||
|
||||
// Accumulator assembles a message.Response from a sequence of Events.
|
||||
// Provider adapters translate SDK events into stream.Events;
|
||||
// the Accumulator — shared, tested once — builds the final Response.
|
||||
type Accumulator struct {
|
||||
content []message.Content
|
||||
usage message.Usage
|
||||
|
||||
// Active text block being built
|
||||
textBuf *strings.Builder
|
||||
|
||||
// Active thinking block being built
|
||||
thinkBuf *strings.Builder
|
||||
|
||||
// Tool calls in progress, keyed by ToolCallID
|
||||
toolCalls map[string]*toolCallAccum
|
||||
// Ordered tool call IDs to preserve emission order
|
||||
toolCallOrder []string
|
||||
}
|
||||
|
||||
type toolCallAccum struct {
|
||||
id string
|
||||
name string
|
||||
argsBuf strings.Builder
|
||||
args []byte // final complete args (from Done event)
|
||||
}
|
||||
|
||||
func NewAccumulator() *Accumulator {
|
||||
return &Accumulator{
|
||||
toolCalls: make(map[string]*toolCallAccum),
|
||||
}
|
||||
}
|
||||
|
||||
// Apply processes a single event, updating the accumulator's state.
|
||||
func (a *Accumulator) Apply(e Event) {
|
||||
switch e.Type {
|
||||
case EventTextDelta:
|
||||
a.flushThinking()
|
||||
if a.textBuf == nil {
|
||||
a.textBuf = &strings.Builder{}
|
||||
}
|
||||
a.textBuf.WriteString(e.Text)
|
||||
|
||||
case EventThinkingDelta:
|
||||
a.flushText()
|
||||
if a.thinkBuf == nil {
|
||||
a.thinkBuf = &strings.Builder{}
|
||||
}
|
||||
a.thinkBuf.WriteString(e.Text)
|
||||
|
||||
case EventToolCallStart:
|
||||
a.flushText()
|
||||
a.flushThinking()
|
||||
tc := &toolCallAccum{id: e.ToolCallID, name: e.ToolCallName}
|
||||
a.toolCalls[e.ToolCallID] = tc
|
||||
a.toolCallOrder = append(a.toolCallOrder, e.ToolCallID)
|
||||
|
||||
case EventToolCallDelta:
|
||||
if tc, ok := a.toolCalls[e.ToolCallID]; ok {
|
||||
tc.argsBuf.WriteString(e.ArgDelta)
|
||||
}
|
||||
|
||||
case EventToolCallDone:
|
||||
if tc, ok := a.toolCalls[e.ToolCallID]; ok {
|
||||
if e.Args != nil {
|
||||
// Done event carries authoritative complete args
|
||||
tc.args = e.Args
|
||||
} else {
|
||||
// Fall back to accumulated deltas
|
||||
tc.args = []byte(tc.argsBuf.String())
|
||||
}
|
||||
}
|
||||
|
||||
case EventUsage:
|
||||
if e.Usage != nil {
|
||||
a.usage.Add(*e.Usage)
|
||||
}
|
||||
|
||||
case EventError:
|
||||
// Errors are handled by the stream consumer, not accumulated
|
||||
}
|
||||
}
|
||||
|
||||
// Response builds the final message.Response from all accumulated events.
|
||||
func (a *Accumulator) Response(stopReason message.StopReason, model string) message.Response {
|
||||
a.flushText()
|
||||
a.flushThinking()
|
||||
a.flushToolCalls()
|
||||
|
||||
return message.Response{
|
||||
Message: message.Message{
|
||||
Role: message.RoleAssistant,
|
||||
Content: a.content,
|
||||
},
|
||||
StopReason: stopReason,
|
||||
Usage: a.usage,
|
||||
Model: model,
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Accumulator) flushText() {
|
||||
if a.textBuf != nil && a.textBuf.Len() > 0 {
|
||||
a.content = append(a.content, message.NewTextContent(a.textBuf.String()))
|
||||
a.textBuf = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Accumulator) flushThinking() {
|
||||
if a.thinkBuf != nil && a.thinkBuf.Len() > 0 {
|
||||
a.content = append(a.content, message.NewThinkingContent(message.Thinking{
|
||||
Text: a.thinkBuf.String(),
|
||||
}))
|
||||
a.thinkBuf = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Accumulator) flushToolCalls() {
|
||||
for _, id := range a.toolCallOrder {
|
||||
tc, ok := a.toolCalls[id]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
args := tc.args
|
||||
if args == nil {
|
||||
// Fallback: use accumulated deltas even if Done never arrived
|
||||
args = []byte(tc.argsBuf.String())
|
||||
}
|
||||
a.content = append(a.content, message.NewToolCallContent(message.ToolCall{
|
||||
ID: tc.id,
|
||||
Name: tc.name,
|
||||
Arguments: args,
|
||||
}))
|
||||
}
|
||||
a.toolCalls = make(map[string]*toolCallAccum)
|
||||
a.toolCallOrder = nil
|
||||
}
|
||||
225
internal/stream/accumulator_test.go
Normal file
225
internal/stream/accumulator_test.go
Normal file
@@ -0,0 +1,225 @@
|
||||
package stream
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"somegit.dev/Owlibou/gnoma/internal/message"
|
||||
)
|
||||
|
||||
func TestAccumulator_TextOnly(t *testing.T) {
|
||||
acc := NewAccumulator()
|
||||
|
||||
acc.Apply(Event{Type: EventTextDelta, Text: "Hello "})
|
||||
acc.Apply(Event{Type: EventTextDelta, Text: "world!"})
|
||||
acc.Apply(Event{Type: EventUsage, Usage: &message.Usage{InputTokens: 10, OutputTokens: 5}})
|
||||
|
||||
resp := acc.Response(message.StopEndTurn, "mistral-large")
|
||||
|
||||
if resp.StopReason != message.StopEndTurn {
|
||||
t.Errorf("StopReason = %q, want %q", resp.StopReason, message.StopEndTurn)
|
||||
}
|
||||
if resp.Model != "mistral-large" {
|
||||
t.Errorf("Model = %q, want %q", resp.Model, "mistral-large")
|
||||
}
|
||||
if resp.Message.Role != message.RoleAssistant {
|
||||
t.Errorf("Role = %q, want %q", resp.Message.Role, message.RoleAssistant)
|
||||
}
|
||||
if resp.Message.TextContent() != "Hello world!" {
|
||||
t.Errorf("TextContent() = %q, want %q", resp.Message.TextContent(), "Hello world!")
|
||||
}
|
||||
if resp.Usage.InputTokens != 10 {
|
||||
t.Errorf("Usage.InputTokens = %d, want 10", resp.Usage.InputTokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccumulator_SingleToolCall(t *testing.T) {
|
||||
acc := NewAccumulator()
|
||||
|
||||
acc.Apply(Event{Type: EventTextDelta, Text: "I'll run that."})
|
||||
acc.Apply(Event{
|
||||
Type: EventToolCallStart,
|
||||
ToolCallID: "tc_1",
|
||||
ToolCallName: "bash",
|
||||
})
|
||||
acc.Apply(Event{
|
||||
Type: EventToolCallDelta,
|
||||
ToolCallID: "tc_1",
|
||||
ArgDelta: `{"comma`,
|
||||
})
|
||||
acc.Apply(Event{
|
||||
Type: EventToolCallDelta,
|
||||
ToolCallID: "tc_1",
|
||||
ArgDelta: `nd":"ls"}`,
|
||||
})
|
||||
acc.Apply(Event{
|
||||
Type: EventToolCallDone,
|
||||
ToolCallID: "tc_1",
|
||||
Args: json.RawMessage(`{"command":"ls"}`),
|
||||
})
|
||||
|
||||
resp := acc.Response(message.StopToolUse, "mistral-large")
|
||||
|
||||
if resp.StopReason != message.StopToolUse {
|
||||
t.Errorf("StopReason = %q, want %q", resp.StopReason, message.StopToolUse)
|
||||
}
|
||||
|
||||
content := resp.Message.Content
|
||||
if len(content) != 2 {
|
||||
t.Fatalf("len(Content) = %d, want 2", len(content))
|
||||
}
|
||||
|
||||
// First block: text
|
||||
if content[0].Type != message.ContentText {
|
||||
t.Errorf("Content[0].Type = %v, want text", content[0].Type)
|
||||
}
|
||||
if content[0].Text != "I'll run that." {
|
||||
t.Errorf("Content[0].Text = %q", content[0].Text)
|
||||
}
|
||||
|
||||
// Second block: tool call
|
||||
if content[1].Type != message.ContentToolCall {
|
||||
t.Errorf("Content[1].Type = %v, want tool_call", content[1].Type)
|
||||
}
|
||||
tc := content[1].ToolCall
|
||||
if tc.ID != "tc_1" {
|
||||
t.Errorf("ToolCall.ID = %q, want tc_1", tc.ID)
|
||||
}
|
||||
if tc.Name != "bash" {
|
||||
t.Errorf("ToolCall.Name = %q, want bash", tc.Name)
|
||||
}
|
||||
if string(tc.Arguments) != `{"command":"ls"}` {
|
||||
t.Errorf("ToolCall.Arguments = %s", tc.Arguments)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccumulator_MultipleToolCalls(t *testing.T) {
|
||||
acc := NewAccumulator()
|
||||
|
||||
// Tool call 1
|
||||
acc.Apply(Event{Type: EventToolCallStart, ToolCallID: "tc_1", ToolCallName: "bash"})
|
||||
acc.Apply(Event{Type: EventToolCallDone, ToolCallID: "tc_1", Args: json.RawMessage(`{"command":"ls"}`)})
|
||||
|
||||
// Tool call 2
|
||||
acc.Apply(Event{Type: EventToolCallStart, ToolCallID: "tc_2", ToolCallName: "fs.read"})
|
||||
acc.Apply(Event{Type: EventToolCallDone, ToolCallID: "tc_2", Args: json.RawMessage(`{"path":"go.mod"}`)})
|
||||
|
||||
resp := acc.Response(message.StopToolUse, "")
|
||||
calls := resp.Message.ToolCalls()
|
||||
if len(calls) != 2 {
|
||||
t.Fatalf("len(ToolCalls()) = %d, want 2", len(calls))
|
||||
}
|
||||
if calls[0].Name != "bash" {
|
||||
t.Errorf("calls[0].Name = %q", calls[0].Name)
|
||||
}
|
||||
if calls[1].Name != "fs.read" {
|
||||
t.Errorf("calls[1].Name = %q", calls[1].Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccumulator_ThinkingBlocks(t *testing.T) {
|
||||
acc := NewAccumulator()
|
||||
|
||||
acc.Apply(Event{Type: EventThinkingDelta, Text: "Let me think"})
|
||||
acc.Apply(Event{Type: EventThinkingDelta, Text: " about this..."})
|
||||
acc.Apply(Event{Type: EventTextDelta, Text: "Here's my answer."})
|
||||
|
||||
resp := acc.Response(message.StopEndTurn, "")
|
||||
content := resp.Message.Content
|
||||
|
||||
if len(content) != 2 {
|
||||
t.Fatalf("len(Content) = %d, want 2", len(content))
|
||||
}
|
||||
|
||||
// Thinking block first
|
||||
if content[0].Type != message.ContentThinking {
|
||||
t.Errorf("Content[0].Type = %v, want thinking", content[0].Type)
|
||||
}
|
||||
if content[0].Thinking.Text != "Let me think about this..." {
|
||||
t.Errorf("Thinking.Text = %q", content[0].Thinking.Text)
|
||||
}
|
||||
|
||||
// Then text
|
||||
if content[1].Type != message.ContentText {
|
||||
t.Errorf("Content[1].Type = %v, want text", content[1].Type)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccumulator_ToolCallDelta_Assembly(t *testing.T) {
|
||||
acc := NewAccumulator()
|
||||
|
||||
acc.Apply(Event{Type: EventToolCallStart, ToolCallID: "tc_1", ToolCallName: "bash"})
|
||||
// Simulate fragmented JSON
|
||||
acc.Apply(Event{Type: EventToolCallDelta, ToolCallID: "tc_1", ArgDelta: `{"`})
|
||||
acc.Apply(Event{Type: EventToolCallDelta, ToolCallID: "tc_1", ArgDelta: `comman`})
|
||||
acc.Apply(Event{Type: EventToolCallDelta, ToolCallID: "tc_1", ArgDelta: `d":"`})
|
||||
acc.Apply(Event{Type: EventToolCallDelta, ToolCallID: "tc_1", ArgDelta: `echo hi`})
|
||||
acc.Apply(Event{Type: EventToolCallDelta, ToolCallID: "tc_1", ArgDelta: `"}`})
|
||||
|
||||
// Done event carries the complete args (authoritative)
|
||||
acc.Apply(Event{
|
||||
Type: EventToolCallDone,
|
||||
ToolCallID: "tc_1",
|
||||
Args: json.RawMessage(`{"command":"echo hi"}`),
|
||||
})
|
||||
|
||||
resp := acc.Response(message.StopToolUse, "")
|
||||
calls := resp.Message.ToolCalls()
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("len(ToolCalls()) = %d, want 1", len(calls))
|
||||
}
|
||||
if string(calls[0].Arguments) != `{"command":"echo hi"}` {
|
||||
t.Errorf("Arguments = %s", calls[0].Arguments)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccumulator_EmptyStream(t *testing.T) {
|
||||
acc := NewAccumulator()
|
||||
resp := acc.Response(message.StopEndTurn, "model-x")
|
||||
|
||||
if resp.Model != "model-x" {
|
||||
t.Errorf("Model = %q", resp.Model)
|
||||
}
|
||||
if len(resp.Message.Content) != 0 {
|
||||
t.Errorf("len(Content) = %d, want 0", len(resp.Message.Content))
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccumulator_UsageCumulative(t *testing.T) {
|
||||
acc := NewAccumulator()
|
||||
|
||||
acc.Apply(Event{Type: EventUsage, Usage: &message.Usage{InputTokens: 100, OutputTokens: 30}})
|
||||
acc.Apply(Event{Type: EventTextDelta, Text: "hi"})
|
||||
acc.Apply(Event{Type: EventUsage, Usage: &message.Usage{InputTokens: 0, OutputTokens: 20}})
|
||||
|
||||
resp := acc.Response(message.StopEndTurn, "")
|
||||
// Last usage wins (providers typically send cumulative, not incremental)
|
||||
// But our Add() is additive — so 100+0=100 input, 30+20=50 output
|
||||
if resp.Usage.InputTokens != 100 {
|
||||
t.Errorf("InputTokens = %d, want 100", resp.Usage.InputTokens)
|
||||
}
|
||||
if resp.Usage.OutputTokens != 50 {
|
||||
t.Errorf("OutputTokens = %d, want 50", resp.Usage.OutputTokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEventType_String(t *testing.T) {
|
||||
tests := []struct {
|
||||
et EventType
|
||||
want string
|
||||
}{
|
||||
{EventTextDelta, "text_delta"},
|
||||
{EventThinkingDelta, "thinking_delta"},
|
||||
{EventToolCallStart, "tool_call_start"},
|
||||
{EventToolCallDelta, "tool_call_delta"},
|
||||
{EventToolCallDone, "tool_call_done"},
|
||||
{EventUsage, "usage"},
|
||||
{EventError, "error"},
|
||||
{EventType(99), "unknown(99)"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
if got := tt.et.String(); got != tt.want {
|
||||
t.Errorf("EventType(%d).String() = %q, want %q", tt.et, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
70
internal/stream/event.go
Normal file
70
internal/stream/event.go
Normal file
@@ -0,0 +1,70 @@
|
||||
package stream
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"somegit.dev/Owlibou/gnoma/internal/message"
|
||||
)
|
||||
|
||||
// EventType discriminates streaming events.
|
||||
type EventType int
|
||||
|
||||
const (
|
||||
EventTextDelta EventType = iota + 1
|
||||
EventThinkingDelta
|
||||
EventToolCallStart
|
||||
EventToolCallDelta
|
||||
EventToolCallDone
|
||||
EventUsage
|
||||
EventError
|
||||
)
|
||||
|
||||
func (et EventType) String() string {
|
||||
switch et {
|
||||
case EventTextDelta:
|
||||
return "text_delta"
|
||||
case EventThinkingDelta:
|
||||
return "thinking_delta"
|
||||
case EventToolCallStart:
|
||||
return "tool_call_start"
|
||||
case EventToolCallDelta:
|
||||
return "tool_call_delta"
|
||||
case EventToolCallDone:
|
||||
return "tool_call_done"
|
||||
case EventUsage:
|
||||
return "usage"
|
||||
case EventError:
|
||||
return "error"
|
||||
default:
|
||||
return fmt.Sprintf("unknown(%d)", et)
|
||||
}
|
||||
}
|
||||
|
||||
// Event is a single streaming event from a provider.
|
||||
type Event struct {
|
||||
Type EventType
|
||||
|
||||
// TextDelta, ThinkingDelta
|
||||
Text string
|
||||
|
||||
// ToolCallStart: ID + Name set
|
||||
// ToolCallDelta: ID + ArgDelta set
|
||||
// ToolCallDone: ID + Args set (complete JSON)
|
||||
ToolCallID string
|
||||
ToolCallName string
|
||||
ArgDelta string // partial JSON fragment
|
||||
Args json.RawMessage // complete arguments (on Done)
|
||||
|
||||
// Usage
|
||||
Usage *message.Usage
|
||||
|
||||
// Error
|
||||
Err error
|
||||
|
||||
// StopReason — set on the final event of a stream
|
||||
StopReason message.StopReason
|
||||
|
||||
// Model — set on first event if available
|
||||
Model string
|
||||
}
|
||||
25
internal/stream/stream.go
Normal file
25
internal/stream/stream.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package stream
|
||||
|
||||
// Stream is the unified pull-based iterator over provider events.
|
||||
// All provider implementations produce this interface.
|
||||
//
|
||||
// Usage:
|
||||
//
|
||||
// for s.Next() {
|
||||
// event := s.Current()
|
||||
// process(event)
|
||||
// }
|
||||
// if err := s.Err(); err != nil {
|
||||
// handle(err)
|
||||
// }
|
||||
// s.Close()
|
||||
type Stream interface {
|
||||
// Next advances to the next event. Returns false when exhausted or errored.
|
||||
Next() bool
|
||||
// Current returns the most recent event. Only valid after Next() returns true.
|
||||
Current() Event
|
||||
// Err returns the first error encountered, or nil.
|
||||
Err() error
|
||||
// Close releases resources. Must be called after consumption.
|
||||
Close() error
|
||||
}
|
||||
Reference in New Issue
Block a user