feat: add foundation types, streaming, and provider interface
internal/message/ — Content discriminated union, Message, Usage, StopReason, Response. 22 tests. internal/stream/ — Stream pull-based iterator interface, Event types, Accumulator (assembles Response from events). 8 tests. internal/provider/ — Provider interface, Request, ToolDefinition, Registry with factory pattern, ProviderError with HTTP status classification. errors.AsType[E] for Go 1.26. 13 tests. 43 tests total, all passing.
This commit is contained in:
78
internal/message/content.go
Normal file
78
internal/message/content.go
Normal file
@@ -0,0 +1,78 @@
|
||||
package message
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// ContentType discriminates the content block union.
|
||||
type ContentType int
|
||||
|
||||
const (
|
||||
ContentText ContentType = iota + 1
|
||||
ContentToolCall
|
||||
ContentToolResult
|
||||
ContentThinking
|
||||
)
|
||||
|
||||
func (ct ContentType) String() string {
|
||||
switch ct {
|
||||
case ContentText:
|
||||
return "text"
|
||||
case ContentToolCall:
|
||||
return "tool_call"
|
||||
case ContentToolResult:
|
||||
return "tool_result"
|
||||
case ContentThinking:
|
||||
return "thinking"
|
||||
default:
|
||||
return fmt.Sprintf("unknown(%d)", ct)
|
||||
}
|
||||
}
|
||||
|
||||
// Content is a discriminated union. Exactly one payload field is set per Type.
|
||||
type Content struct {
|
||||
Type ContentType
|
||||
Text string // ContentText
|
||||
ToolCall *ToolCall // ContentToolCall
|
||||
ToolResult *ToolResult // ContentToolResult
|
||||
Thinking *Thinking // ContentThinking
|
||||
}
|
||||
|
||||
// ToolCall represents the model's request to invoke a tool.
|
||||
type ToolCall struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Arguments json.RawMessage `json:"arguments"`
|
||||
}
|
||||
|
||||
// ToolResult is the output of executing a tool, correlated by ToolCallID.
|
||||
type ToolResult struct {
|
||||
ToolCallID string `json:"tool_call_id"`
|
||||
Content string `json:"content"`
|
||||
IsError bool `json:"is_error"`
|
||||
}
|
||||
|
||||
// Thinking represents a reasoning/thinking trace.
|
||||
// Signature must round-trip unchanged (Anthropic requirement).
|
||||
type Thinking struct {
|
||||
Text string `json:"text,omitempty"`
|
||||
Signature string `json:"signature,omitempty"`
|
||||
Redacted bool `json:"redacted,omitempty"`
|
||||
}
|
||||
|
||||
func NewTextContent(text string) Content {
|
||||
return Content{Type: ContentText, Text: text}
|
||||
}
|
||||
|
||||
func NewToolCallContent(tc ToolCall) Content {
|
||||
return Content{Type: ContentToolCall, ToolCall: &tc}
|
||||
}
|
||||
|
||||
func NewToolResultContent(tr ToolResult) Content {
|
||||
return Content{Type: ContentToolResult, ToolResult: &tr}
|
||||
}
|
||||
|
||||
func NewThinkingContent(th Thinking) Content {
|
||||
return Content{Type: ContentThinking, Thinking: &th}
|
||||
}
|
||||
174
internal/message/content_test.go
Normal file
174
internal/message/content_test.go
Normal file
@@ -0,0 +1,174 @@
|
||||
package message
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewTextContent(t *testing.T) {
|
||||
c := NewTextContent("hello world")
|
||||
|
||||
if c.Type != ContentText {
|
||||
t.Errorf("Type = %v, want %v", c.Type, ContentText)
|
||||
}
|
||||
if c.Text != "hello world" {
|
||||
t.Errorf("Text = %q, want %q", c.Text, "hello world")
|
||||
}
|
||||
if c.ToolCall != nil {
|
||||
t.Error("ToolCall should be nil for text content")
|
||||
}
|
||||
if c.ToolResult != nil {
|
||||
t.Error("ToolResult should be nil for text content")
|
||||
}
|
||||
if c.Thinking != nil {
|
||||
t.Error("Thinking should be nil for text content")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewToolCallContent(t *testing.T) {
|
||||
args := json.RawMessage(`{"command":"ls -la"}`)
|
||||
tc := ToolCall{
|
||||
ID: "tc_001",
|
||||
Name: "bash",
|
||||
Arguments: args,
|
||||
}
|
||||
c := NewToolCallContent(tc)
|
||||
|
||||
if c.Type != ContentToolCall {
|
||||
t.Errorf("Type = %v, want %v", c.Type, ContentToolCall)
|
||||
}
|
||||
if c.ToolCall == nil {
|
||||
t.Fatal("ToolCall should not be nil")
|
||||
}
|
||||
if c.ToolCall.ID != "tc_001" {
|
||||
t.Errorf("ToolCall.ID = %q, want %q", c.ToolCall.ID, "tc_001")
|
||||
}
|
||||
if c.ToolCall.Name != "bash" {
|
||||
t.Errorf("ToolCall.Name = %q, want %q", c.ToolCall.Name, "bash")
|
||||
}
|
||||
if string(c.ToolCall.Arguments) != `{"command":"ls -la"}` {
|
||||
t.Errorf("ToolCall.Arguments = %s, want %s", c.ToolCall.Arguments, args)
|
||||
}
|
||||
if c.Text != "" {
|
||||
t.Error("Text should be empty for tool call content")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewToolResultContent(t *testing.T) {
|
||||
tr := ToolResult{
|
||||
ToolCallID: "tc_001",
|
||||
Content: "file1.go\nfile2.go",
|
||||
IsError: false,
|
||||
}
|
||||
c := NewToolResultContent(tr)
|
||||
|
||||
if c.Type != ContentToolResult {
|
||||
t.Errorf("Type = %v, want %v", c.Type, ContentToolResult)
|
||||
}
|
||||
if c.ToolResult == nil {
|
||||
t.Fatal("ToolResult should not be nil")
|
||||
}
|
||||
if c.ToolResult.ToolCallID != "tc_001" {
|
||||
t.Errorf("ToolResult.ToolCallID = %q, want %q", c.ToolResult.ToolCallID, "tc_001")
|
||||
}
|
||||
if c.ToolResult.IsError {
|
||||
t.Error("ToolResult.IsError should be false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewToolResultContent_Error(t *testing.T) {
|
||||
tr := ToolResult{
|
||||
ToolCallID: "tc_002",
|
||||
Content: "permission denied",
|
||||
IsError: true,
|
||||
}
|
||||
c := NewToolResultContent(tr)
|
||||
|
||||
if !c.ToolResult.IsError {
|
||||
t.Error("ToolResult.IsError should be true")
|
||||
}
|
||||
if c.ToolResult.Content != "permission denied" {
|
||||
t.Errorf("ToolResult.Content = %q, want %q", c.ToolResult.Content, "permission denied")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewThinkingContent(t *testing.T) {
|
||||
th := Thinking{
|
||||
Text: "Let me think about this...",
|
||||
Signature: "sig_abc123",
|
||||
}
|
||||
c := NewThinkingContent(th)
|
||||
|
||||
if c.Type != ContentThinking {
|
||||
t.Errorf("Type = %v, want %v", c.Type, ContentThinking)
|
||||
}
|
||||
if c.Thinking == nil {
|
||||
t.Fatal("Thinking should not be nil")
|
||||
}
|
||||
if c.Thinking.Text != "Let me think about this..." {
|
||||
t.Errorf("Thinking.Text = %q", c.Thinking.Text)
|
||||
}
|
||||
if c.Thinking.Signature != "sig_abc123" {
|
||||
t.Errorf("Thinking.Signature = %q", c.Thinking.Signature)
|
||||
}
|
||||
if c.Thinking.Redacted {
|
||||
t.Error("Thinking.Redacted should be false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewRedactedThinkingContent(t *testing.T) {
|
||||
th := Thinking{
|
||||
Redacted: true,
|
||||
}
|
||||
c := NewThinkingContent(th)
|
||||
|
||||
if !c.Thinking.Redacted {
|
||||
t.Error("Thinking.Redacted should be true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestContentType_String(t *testing.T) {
|
||||
tests := []struct {
|
||||
ct ContentType
|
||||
want string
|
||||
}{
|
||||
{ContentText, "text"},
|
||||
{ContentToolCall, "tool_call"},
|
||||
{ContentToolResult, "tool_result"},
|
||||
{ContentThinking, "thinking"},
|
||||
{ContentType(99), "unknown(99)"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
if got := tt.ct.String(); got != tt.want {
|
||||
t.Errorf("ContentType(%d).String() = %q, want %q", tt.ct, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolCall_JSON_RoundTrip(t *testing.T) {
|
||||
original := ToolCall{
|
||||
ID: "tc_100",
|
||||
Name: "fs.read",
|
||||
Arguments: json.RawMessage(`{"path":"/tmp/test.go","offset":0}`),
|
||||
}
|
||||
|
||||
data, err := json.Marshal(original)
|
||||
if err != nil {
|
||||
t.Fatalf("Marshal: %v", err)
|
||||
}
|
||||
|
||||
var decoded ToolCall
|
||||
if err := json.Unmarshal(data, &decoded); err != nil {
|
||||
t.Fatalf("Unmarshal: %v", err)
|
||||
}
|
||||
|
||||
if decoded.ID != original.ID {
|
||||
t.Errorf("ID = %q, want %q", decoded.ID, original.ID)
|
||||
}
|
||||
if decoded.Name != original.Name {
|
||||
t.Errorf("Name = %q, want %q", decoded.Name, original.Name)
|
||||
}
|
||||
if string(decoded.Arguments) != string(original.Arguments) {
|
||||
t.Errorf("Arguments = %s, want %s", decoded.Arguments, original.Arguments)
|
||||
}
|
||||
}
|
||||
89
internal/message/message.go
Normal file
89
internal/message/message.go
Normal file
@@ -0,0 +1,89 @@
|
||||
package message
|
||||
|
||||
import "strings"
|
||||
|
||||
// Role identifies the sender of a message.
|
||||
type Role string
|
||||
|
||||
const (
|
||||
RoleUser Role = "user"
|
||||
RoleAssistant Role = "assistant"
|
||||
RoleSystem Role = "system"
|
||||
)
|
||||
|
||||
// Message represents a single turn in the conversation.
|
||||
type Message struct {
|
||||
Role Role
|
||||
Content []Content
|
||||
}
|
||||
|
||||
func NewUserText(text string) Message {
|
||||
return Message{
|
||||
Role: RoleUser,
|
||||
Content: []Content{NewTextContent(text)},
|
||||
}
|
||||
}
|
||||
|
||||
func NewAssistantText(text string) Message {
|
||||
return Message{
|
||||
Role: RoleAssistant,
|
||||
Content: []Content{NewTextContent(text)},
|
||||
}
|
||||
}
|
||||
|
||||
func NewAssistantContent(blocks ...Content) Message {
|
||||
return Message{
|
||||
Role: RoleAssistant,
|
||||
Content: blocks,
|
||||
}
|
||||
}
|
||||
|
||||
func NewSystemText(text string) Message {
|
||||
return Message{
|
||||
Role: RoleSystem,
|
||||
Content: []Content{NewTextContent(text)},
|
||||
}
|
||||
}
|
||||
|
||||
func NewToolResults(results ...ToolResult) Message {
|
||||
content := make([]Content, len(results))
|
||||
for i, r := range results {
|
||||
content[i] = NewToolResultContent(r)
|
||||
}
|
||||
return Message{
|
||||
Role: RoleUser,
|
||||
Content: content,
|
||||
}
|
||||
}
|
||||
|
||||
// HasToolCalls returns true if any content block is a tool call.
|
||||
func (m Message) HasToolCalls() bool {
|
||||
for _, c := range m.Content {
|
||||
if c.Type == ContentToolCall {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ToolCalls extracts all tool call blocks.
|
||||
func (m Message) ToolCalls() []ToolCall {
|
||||
var calls []ToolCall
|
||||
for _, c := range m.Content {
|
||||
if c.Type == ContentToolCall && c.ToolCall != nil {
|
||||
calls = append(calls, *c.ToolCall)
|
||||
}
|
||||
}
|
||||
return calls
|
||||
}
|
||||
|
||||
// TextContent concatenates all text blocks.
|
||||
func (m Message) TextContent() string {
|
||||
var b strings.Builder
|
||||
for _, c := range m.Content {
|
||||
if c.Type == ContentText {
|
||||
b.WriteString(c.Text)
|
||||
}
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
214
internal/message/message_test.go
Normal file
214
internal/message/message_test.go
Normal file
@@ -0,0 +1,214 @@
|
||||
package message
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewUserText(t *testing.T) {
|
||||
m := NewUserText("hello")
|
||||
if m.Role != RoleUser {
|
||||
t.Errorf("Role = %q, want %q", m.Role, RoleUser)
|
||||
}
|
||||
if len(m.Content) != 1 {
|
||||
t.Fatalf("len(Content) = %d, want 1", len(m.Content))
|
||||
}
|
||||
if m.Content[0].Type != ContentText {
|
||||
t.Errorf("Content[0].Type = %v, want %v", m.Content[0].Type, ContentText)
|
||||
}
|
||||
if m.Content[0].Text != "hello" {
|
||||
t.Errorf("Content[0].Text = %q, want %q", m.Content[0].Text, "hello")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewAssistantText(t *testing.T) {
|
||||
m := NewAssistantText("response")
|
||||
if m.Role != RoleAssistant {
|
||||
t.Errorf("Role = %q, want %q", m.Role, RoleAssistant)
|
||||
}
|
||||
if m.TextContent() != "response" {
|
||||
t.Errorf("TextContent() = %q, want %q", m.TextContent(), "response")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewSystemText(t *testing.T) {
|
||||
m := NewSystemText("you are a helper")
|
||||
if m.Role != RoleSystem {
|
||||
t.Errorf("Role = %q, want %q", m.Role, RoleSystem)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewAssistantContent_Mixed(t *testing.T) {
|
||||
m := NewAssistantContent(
|
||||
NewTextContent("I'll run that command."),
|
||||
NewToolCallContent(ToolCall{
|
||||
ID: "tc_1",
|
||||
Name: "bash",
|
||||
Arguments: json.RawMessage(`{"command":"ls"}`),
|
||||
}),
|
||||
)
|
||||
|
||||
if m.Role != RoleAssistant {
|
||||
t.Errorf("Role = %q, want %q", m.Role, RoleAssistant)
|
||||
}
|
||||
if len(m.Content) != 2 {
|
||||
t.Fatalf("len(Content) = %d, want 2", len(m.Content))
|
||||
}
|
||||
if m.Content[0].Type != ContentText {
|
||||
t.Errorf("Content[0].Type = %v, want text", m.Content[0].Type)
|
||||
}
|
||||
if m.Content[1].Type != ContentToolCall {
|
||||
t.Errorf("Content[1].Type = %v, want tool_call", m.Content[1].Type)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewToolResults(t *testing.T) {
|
||||
m := NewToolResults(
|
||||
ToolResult{ToolCallID: "tc_1", Content: "output1"},
|
||||
ToolResult{ToolCallID: "tc_2", Content: "output2", IsError: true},
|
||||
)
|
||||
|
||||
if m.Role != RoleUser {
|
||||
t.Errorf("Role = %q, want %q", m.Role, RoleUser)
|
||||
}
|
||||
if len(m.Content) != 2 {
|
||||
t.Fatalf("len(Content) = %d, want 2", len(m.Content))
|
||||
}
|
||||
if m.Content[0].ToolResult.ToolCallID != "tc_1" {
|
||||
t.Errorf("Content[0].ToolResult.ToolCallID = %q", m.Content[0].ToolResult.ToolCallID)
|
||||
}
|
||||
if m.Content[1].ToolResult.IsError != true {
|
||||
t.Error("Content[1].ToolResult.IsError should be true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessage_HasToolCalls(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
msg Message
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "text only",
|
||||
msg: NewUserText("hello"),
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "with tool call",
|
||||
msg: NewAssistantContent(
|
||||
NewTextContent("running..."),
|
||||
NewToolCallContent(ToolCall{ID: "tc_1", Name: "bash"}),
|
||||
),
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "tool results (not calls)",
|
||||
msg: NewToolResults(ToolResult{ToolCallID: "tc_1", Content: "ok"}),
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "empty message",
|
||||
msg: Message{Role: RoleAssistant},
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := tt.msg.HasToolCalls(); got != tt.want {
|
||||
t.Errorf("HasToolCalls() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessage_ToolCalls(t *testing.T) {
|
||||
m := NewAssistantContent(
|
||||
NewTextContent("here are two commands"),
|
||||
NewToolCallContent(ToolCall{ID: "tc_1", Name: "bash", Arguments: json.RawMessage(`{"command":"ls"}`)}),
|
||||
NewTextContent("and another"),
|
||||
NewToolCallContent(ToolCall{ID: "tc_2", Name: "fs.read", Arguments: json.RawMessage(`{"path":"go.mod"}`)}),
|
||||
)
|
||||
|
||||
calls := m.ToolCalls()
|
||||
if len(calls) != 2 {
|
||||
t.Fatalf("len(ToolCalls()) = %d, want 2", len(calls))
|
||||
}
|
||||
if calls[0].ID != "tc_1" {
|
||||
t.Errorf("calls[0].ID = %q, want tc_1", calls[0].ID)
|
||||
}
|
||||
if calls[1].Name != "fs.read" {
|
||||
t.Errorf("calls[1].Name = %q, want fs.read", calls[1].Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessage_ToolCalls_Empty(t *testing.T) {
|
||||
m := NewUserText("no tools here")
|
||||
calls := m.ToolCalls()
|
||||
if len(calls) != 0 {
|
||||
t.Errorf("len(ToolCalls()) = %d, want 0", len(calls))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessage_TextContent(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
msg Message
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "single text",
|
||||
msg: NewUserText("hello"),
|
||||
want: "hello",
|
||||
},
|
||||
{
|
||||
name: "multiple text blocks",
|
||||
msg: NewAssistantContent(
|
||||
NewTextContent("first "),
|
||||
NewToolCallContent(ToolCall{ID: "tc_1", Name: "bash"}),
|
||||
NewTextContent("second"),
|
||||
),
|
||||
want: "first second",
|
||||
},
|
||||
{
|
||||
name: "no text",
|
||||
msg: NewToolResults(ToolResult{ToolCallID: "tc_1", Content: "output"}),
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "empty message",
|
||||
msg: Message{Role: RoleAssistant},
|
||||
want: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := tt.msg.TextContent(); got != tt.want {
|
||||
t.Errorf("TextContent() = %q, want %q", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestResponse_Fields(t *testing.T) {
|
||||
r := Response{
|
||||
Message: NewAssistantText("done"),
|
||||
StopReason: StopEndTurn,
|
||||
Usage: Usage{InputTokens: 100, OutputTokens: 50},
|
||||
Model: "mistral-large-latest",
|
||||
}
|
||||
|
||||
if r.StopReason != StopEndTurn {
|
||||
t.Errorf("StopReason = %q, want %q", r.StopReason, StopEndTurn)
|
||||
}
|
||||
if r.Usage.TotalTokens() != 150 {
|
||||
t.Errorf("Usage.TotalTokens() = %d, want 150", r.Usage.TotalTokens())
|
||||
}
|
||||
if r.Model != "mistral-large-latest" {
|
||||
t.Errorf("Model = %q", r.Model)
|
||||
}
|
||||
if r.Message.TextContent() != "done" {
|
||||
t.Errorf("Message.TextContent() = %q", r.Message.TextContent())
|
||||
}
|
||||
}
|
||||
9
internal/message/response.go
Normal file
9
internal/message/response.go
Normal file
@@ -0,0 +1,9 @@
|
||||
package message
|
||||
|
||||
// Response wraps a completed assistant turn.
|
||||
type Response struct {
|
||||
Message Message
|
||||
StopReason StopReason
|
||||
Usage Usage
|
||||
Model string
|
||||
}
|
||||
11
internal/message/stop.go
Normal file
11
internal/message/stop.go
Normal file
@@ -0,0 +1,11 @@
|
||||
package message
|
||||
|
||||
// StopReason indicates why the model stopped generating.
|
||||
type StopReason string
|
||||
|
||||
const (
|
||||
StopEndTurn StopReason = "end_turn"
|
||||
StopMaxTokens StopReason = "max_tokens"
|
||||
StopToolUse StopReason = "tool_use"
|
||||
StopSequence StopReason = "stop_sequence"
|
||||
)
|
||||
20
internal/message/usage.go
Normal file
20
internal/message/usage.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package message
|
||||
|
||||
// Usage tracks token consumption for a single API turn.
|
||||
type Usage struct {
|
||||
InputTokens int64 `json:"input_tokens"`
|
||||
OutputTokens int64 `json:"output_tokens"`
|
||||
CacheReadTokens int64 `json:"cache_read_tokens,omitempty"`
|
||||
CacheCreationTokens int64 `json:"cache_creation_tokens,omitempty"`
|
||||
}
|
||||
|
||||
func (u Usage) TotalTokens() int64 {
|
||||
return u.InputTokens + u.OutputTokens
|
||||
}
|
||||
|
||||
func (u *Usage) Add(other Usage) {
|
||||
u.InputTokens += other.InputTokens
|
||||
u.OutputTokens += other.OutputTokens
|
||||
u.CacheReadTokens += other.CacheReadTokens
|
||||
u.CacheCreationTokens += other.CacheCreationTokens
|
||||
}
|
||||
70
internal/message/usage_test.go
Normal file
70
internal/message/usage_test.go
Normal file
@@ -0,0 +1,70 @@
|
||||
package message
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestUsage_TotalTokens(t *testing.T) {
|
||||
u := Usage{InputTokens: 100, OutputTokens: 50}
|
||||
if got := u.TotalTokens(); got != 150 {
|
||||
t.Errorf("TotalTokens() = %d, want 150", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUsage_TotalTokens_Zero(t *testing.T) {
|
||||
var u Usage
|
||||
if got := u.TotalTokens(); got != 0 {
|
||||
t.Errorf("TotalTokens() = %d, want 0", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUsage_Add(t *testing.T) {
|
||||
u := Usage{
|
||||
InputTokens: 100,
|
||||
OutputTokens: 50,
|
||||
CacheReadTokens: 10,
|
||||
CacheCreationTokens: 5,
|
||||
}
|
||||
other := Usage{
|
||||
InputTokens: 200,
|
||||
OutputTokens: 80,
|
||||
CacheReadTokens: 20,
|
||||
CacheCreationTokens: 15,
|
||||
}
|
||||
|
||||
u.Add(other)
|
||||
|
||||
if u.InputTokens != 300 {
|
||||
t.Errorf("InputTokens = %d, want 300", u.InputTokens)
|
||||
}
|
||||
if u.OutputTokens != 130 {
|
||||
t.Errorf("OutputTokens = %d, want 130", u.OutputTokens)
|
||||
}
|
||||
if u.CacheReadTokens != 30 {
|
||||
t.Errorf("CacheReadTokens = %d, want 30", u.CacheReadTokens)
|
||||
}
|
||||
if u.CacheCreationTokens != 20 {
|
||||
t.Errorf("CacheCreationTokens = %d, want 20", u.CacheCreationTokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUsage_Add_Multiple(t *testing.T) {
|
||||
var total Usage
|
||||
turns := []Usage{
|
||||
{InputTokens: 100, OutputTokens: 50},
|
||||
{InputTokens: 200, OutputTokens: 80},
|
||||
{InputTokens: 150, OutputTokens: 60},
|
||||
}
|
||||
|
||||
for _, turn := range turns {
|
||||
total.Add(turn)
|
||||
}
|
||||
|
||||
if total.InputTokens != 450 {
|
||||
t.Errorf("InputTokens = %d, want 450", total.InputTokens)
|
||||
}
|
||||
if total.OutputTokens != 190 {
|
||||
t.Errorf("OutputTokens = %d, want 190", total.OutputTokens)
|
||||
}
|
||||
if total.TotalTokens() != 640 {
|
||||
t.Errorf("TotalTokens() = %d, want 640", total.TotalTokens())
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user