feat: add foundation types, streaming, and provider interface

internal/message/ — Content discriminated union, Message, Usage,
StopReason, Response. 22 tests.

internal/stream/ — Stream pull-based iterator interface, Event types,
Accumulator (assembles Response from events). 8 tests.

internal/provider/ — Provider interface, Request, ToolDefinition,
Registry with factory pattern, ProviderError with HTTP status
classification. errors.AsType[E] for Go 1.26. 13 tests.

43 tests total, all passing.
This commit is contained in:
2026-04-03 10:57:54 +02:00
parent d3990214a5
commit 85c643fdca
17 changed files with 1569 additions and 0 deletions

View File

@@ -0,0 +1,78 @@
package message
import (
"encoding/json"
"fmt"
)
// ContentType discriminates the content block union.
type ContentType int
const (
ContentText ContentType = iota + 1
ContentToolCall
ContentToolResult
ContentThinking
)
func (ct ContentType) String() string {
switch ct {
case ContentText:
return "text"
case ContentToolCall:
return "tool_call"
case ContentToolResult:
return "tool_result"
case ContentThinking:
return "thinking"
default:
return fmt.Sprintf("unknown(%d)", ct)
}
}
// Content is a discriminated union. Exactly one payload field is set per Type.
type Content struct {
Type ContentType
Text string // ContentText
ToolCall *ToolCall // ContentToolCall
ToolResult *ToolResult // ContentToolResult
Thinking *Thinking // ContentThinking
}
// ToolCall represents the model's request to invoke a tool.
type ToolCall struct {
ID string `json:"id"`
Name string `json:"name"`
Arguments json.RawMessage `json:"arguments"`
}
// ToolResult is the output of executing a tool, correlated by ToolCallID.
type ToolResult struct {
ToolCallID string `json:"tool_call_id"`
Content string `json:"content"`
IsError bool `json:"is_error"`
}
// Thinking represents a reasoning/thinking trace.
// Signature must round-trip unchanged (Anthropic requirement).
type Thinking struct {
Text string `json:"text,omitempty"`
Signature string `json:"signature,omitempty"`
Redacted bool `json:"redacted,omitempty"`
}
func NewTextContent(text string) Content {
return Content{Type: ContentText, Text: text}
}
func NewToolCallContent(tc ToolCall) Content {
return Content{Type: ContentToolCall, ToolCall: &tc}
}
func NewToolResultContent(tr ToolResult) Content {
return Content{Type: ContentToolResult, ToolResult: &tr}
}
func NewThinkingContent(th Thinking) Content {
return Content{Type: ContentThinking, Thinking: &th}
}

View File

@@ -0,0 +1,174 @@
package message
import (
"encoding/json"
"testing"
)
func TestNewTextContent(t *testing.T) {
c := NewTextContent("hello world")
if c.Type != ContentText {
t.Errorf("Type = %v, want %v", c.Type, ContentText)
}
if c.Text != "hello world" {
t.Errorf("Text = %q, want %q", c.Text, "hello world")
}
if c.ToolCall != nil {
t.Error("ToolCall should be nil for text content")
}
if c.ToolResult != nil {
t.Error("ToolResult should be nil for text content")
}
if c.Thinking != nil {
t.Error("Thinking should be nil for text content")
}
}
func TestNewToolCallContent(t *testing.T) {
args := json.RawMessage(`{"command":"ls -la"}`)
tc := ToolCall{
ID: "tc_001",
Name: "bash",
Arguments: args,
}
c := NewToolCallContent(tc)
if c.Type != ContentToolCall {
t.Errorf("Type = %v, want %v", c.Type, ContentToolCall)
}
if c.ToolCall == nil {
t.Fatal("ToolCall should not be nil")
}
if c.ToolCall.ID != "tc_001" {
t.Errorf("ToolCall.ID = %q, want %q", c.ToolCall.ID, "tc_001")
}
if c.ToolCall.Name != "bash" {
t.Errorf("ToolCall.Name = %q, want %q", c.ToolCall.Name, "bash")
}
if string(c.ToolCall.Arguments) != `{"command":"ls -la"}` {
t.Errorf("ToolCall.Arguments = %s, want %s", c.ToolCall.Arguments, args)
}
if c.Text != "" {
t.Error("Text should be empty for tool call content")
}
}
func TestNewToolResultContent(t *testing.T) {
tr := ToolResult{
ToolCallID: "tc_001",
Content: "file1.go\nfile2.go",
IsError: false,
}
c := NewToolResultContent(tr)
if c.Type != ContentToolResult {
t.Errorf("Type = %v, want %v", c.Type, ContentToolResult)
}
if c.ToolResult == nil {
t.Fatal("ToolResult should not be nil")
}
if c.ToolResult.ToolCallID != "tc_001" {
t.Errorf("ToolResult.ToolCallID = %q, want %q", c.ToolResult.ToolCallID, "tc_001")
}
if c.ToolResult.IsError {
t.Error("ToolResult.IsError should be false")
}
}
func TestNewToolResultContent_Error(t *testing.T) {
tr := ToolResult{
ToolCallID: "tc_002",
Content: "permission denied",
IsError: true,
}
c := NewToolResultContent(tr)
if !c.ToolResult.IsError {
t.Error("ToolResult.IsError should be true")
}
if c.ToolResult.Content != "permission denied" {
t.Errorf("ToolResult.Content = %q, want %q", c.ToolResult.Content, "permission denied")
}
}
func TestNewThinkingContent(t *testing.T) {
th := Thinking{
Text: "Let me think about this...",
Signature: "sig_abc123",
}
c := NewThinkingContent(th)
if c.Type != ContentThinking {
t.Errorf("Type = %v, want %v", c.Type, ContentThinking)
}
if c.Thinking == nil {
t.Fatal("Thinking should not be nil")
}
if c.Thinking.Text != "Let me think about this..." {
t.Errorf("Thinking.Text = %q", c.Thinking.Text)
}
if c.Thinking.Signature != "sig_abc123" {
t.Errorf("Thinking.Signature = %q", c.Thinking.Signature)
}
if c.Thinking.Redacted {
t.Error("Thinking.Redacted should be false")
}
}
func TestNewRedactedThinkingContent(t *testing.T) {
th := Thinking{
Redacted: true,
}
c := NewThinkingContent(th)
if !c.Thinking.Redacted {
t.Error("Thinking.Redacted should be true")
}
}
func TestContentType_String(t *testing.T) {
tests := []struct {
ct ContentType
want string
}{
{ContentText, "text"},
{ContentToolCall, "tool_call"},
{ContentToolResult, "tool_result"},
{ContentThinking, "thinking"},
{ContentType(99), "unknown(99)"},
}
for _, tt := range tests {
if got := tt.ct.String(); got != tt.want {
t.Errorf("ContentType(%d).String() = %q, want %q", tt.ct, got, tt.want)
}
}
}
func TestToolCall_JSON_RoundTrip(t *testing.T) {
original := ToolCall{
ID: "tc_100",
Name: "fs.read",
Arguments: json.RawMessage(`{"path":"/tmp/test.go","offset":0}`),
}
data, err := json.Marshal(original)
if err != nil {
t.Fatalf("Marshal: %v", err)
}
var decoded ToolCall
if err := json.Unmarshal(data, &decoded); err != nil {
t.Fatalf("Unmarshal: %v", err)
}
if decoded.ID != original.ID {
t.Errorf("ID = %q, want %q", decoded.ID, original.ID)
}
if decoded.Name != original.Name {
t.Errorf("Name = %q, want %q", decoded.Name, original.Name)
}
if string(decoded.Arguments) != string(original.Arguments) {
t.Errorf("Arguments = %s, want %s", decoded.Arguments, original.Arguments)
}
}

View File

@@ -0,0 +1,89 @@
package message
import "strings"
// Role identifies the sender of a message.
type Role string
const (
RoleUser Role = "user"
RoleAssistant Role = "assistant"
RoleSystem Role = "system"
)
// Message represents a single turn in the conversation.
type Message struct {
Role Role
Content []Content
}
func NewUserText(text string) Message {
return Message{
Role: RoleUser,
Content: []Content{NewTextContent(text)},
}
}
func NewAssistantText(text string) Message {
return Message{
Role: RoleAssistant,
Content: []Content{NewTextContent(text)},
}
}
func NewAssistantContent(blocks ...Content) Message {
return Message{
Role: RoleAssistant,
Content: blocks,
}
}
func NewSystemText(text string) Message {
return Message{
Role: RoleSystem,
Content: []Content{NewTextContent(text)},
}
}
func NewToolResults(results ...ToolResult) Message {
content := make([]Content, len(results))
for i, r := range results {
content[i] = NewToolResultContent(r)
}
return Message{
Role: RoleUser,
Content: content,
}
}
// HasToolCalls returns true if any content block is a tool call.
func (m Message) HasToolCalls() bool {
for _, c := range m.Content {
if c.Type == ContentToolCall {
return true
}
}
return false
}
// ToolCalls extracts all tool call blocks.
func (m Message) ToolCalls() []ToolCall {
var calls []ToolCall
for _, c := range m.Content {
if c.Type == ContentToolCall && c.ToolCall != nil {
calls = append(calls, *c.ToolCall)
}
}
return calls
}
// TextContent concatenates all text blocks.
func (m Message) TextContent() string {
var b strings.Builder
for _, c := range m.Content {
if c.Type == ContentText {
b.WriteString(c.Text)
}
}
return b.String()
}

View File

@@ -0,0 +1,214 @@
package message
import (
"encoding/json"
"testing"
)
func TestNewUserText(t *testing.T) {
m := NewUserText("hello")
if m.Role != RoleUser {
t.Errorf("Role = %q, want %q", m.Role, RoleUser)
}
if len(m.Content) != 1 {
t.Fatalf("len(Content) = %d, want 1", len(m.Content))
}
if m.Content[0].Type != ContentText {
t.Errorf("Content[0].Type = %v, want %v", m.Content[0].Type, ContentText)
}
if m.Content[0].Text != "hello" {
t.Errorf("Content[0].Text = %q, want %q", m.Content[0].Text, "hello")
}
}
func TestNewAssistantText(t *testing.T) {
m := NewAssistantText("response")
if m.Role != RoleAssistant {
t.Errorf("Role = %q, want %q", m.Role, RoleAssistant)
}
if m.TextContent() != "response" {
t.Errorf("TextContent() = %q, want %q", m.TextContent(), "response")
}
}
func TestNewSystemText(t *testing.T) {
m := NewSystemText("you are a helper")
if m.Role != RoleSystem {
t.Errorf("Role = %q, want %q", m.Role, RoleSystem)
}
}
func TestNewAssistantContent_Mixed(t *testing.T) {
m := NewAssistantContent(
NewTextContent("I'll run that command."),
NewToolCallContent(ToolCall{
ID: "tc_1",
Name: "bash",
Arguments: json.RawMessage(`{"command":"ls"}`),
}),
)
if m.Role != RoleAssistant {
t.Errorf("Role = %q, want %q", m.Role, RoleAssistant)
}
if len(m.Content) != 2 {
t.Fatalf("len(Content) = %d, want 2", len(m.Content))
}
if m.Content[0].Type != ContentText {
t.Errorf("Content[0].Type = %v, want text", m.Content[0].Type)
}
if m.Content[1].Type != ContentToolCall {
t.Errorf("Content[1].Type = %v, want tool_call", m.Content[1].Type)
}
}
func TestNewToolResults(t *testing.T) {
m := NewToolResults(
ToolResult{ToolCallID: "tc_1", Content: "output1"},
ToolResult{ToolCallID: "tc_2", Content: "output2", IsError: true},
)
if m.Role != RoleUser {
t.Errorf("Role = %q, want %q", m.Role, RoleUser)
}
if len(m.Content) != 2 {
t.Fatalf("len(Content) = %d, want 2", len(m.Content))
}
if m.Content[0].ToolResult.ToolCallID != "tc_1" {
t.Errorf("Content[0].ToolResult.ToolCallID = %q", m.Content[0].ToolResult.ToolCallID)
}
if m.Content[1].ToolResult.IsError != true {
t.Error("Content[1].ToolResult.IsError should be true")
}
}
func TestMessage_HasToolCalls(t *testing.T) {
tests := []struct {
name string
msg Message
want bool
}{
{
name: "text only",
msg: NewUserText("hello"),
want: false,
},
{
name: "with tool call",
msg: NewAssistantContent(
NewTextContent("running..."),
NewToolCallContent(ToolCall{ID: "tc_1", Name: "bash"}),
),
want: true,
},
{
name: "tool results (not calls)",
msg: NewToolResults(ToolResult{ToolCallID: "tc_1", Content: "ok"}),
want: false,
},
{
name: "empty message",
msg: Message{Role: RoleAssistant},
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := tt.msg.HasToolCalls(); got != tt.want {
t.Errorf("HasToolCalls() = %v, want %v", got, tt.want)
}
})
}
}
func TestMessage_ToolCalls(t *testing.T) {
m := NewAssistantContent(
NewTextContent("here are two commands"),
NewToolCallContent(ToolCall{ID: "tc_1", Name: "bash", Arguments: json.RawMessage(`{"command":"ls"}`)}),
NewTextContent("and another"),
NewToolCallContent(ToolCall{ID: "tc_2", Name: "fs.read", Arguments: json.RawMessage(`{"path":"go.mod"}`)}),
)
calls := m.ToolCalls()
if len(calls) != 2 {
t.Fatalf("len(ToolCalls()) = %d, want 2", len(calls))
}
if calls[0].ID != "tc_1" {
t.Errorf("calls[0].ID = %q, want tc_1", calls[0].ID)
}
if calls[1].Name != "fs.read" {
t.Errorf("calls[1].Name = %q, want fs.read", calls[1].Name)
}
}
func TestMessage_ToolCalls_Empty(t *testing.T) {
m := NewUserText("no tools here")
calls := m.ToolCalls()
if len(calls) != 0 {
t.Errorf("len(ToolCalls()) = %d, want 0", len(calls))
}
}
func TestMessage_TextContent(t *testing.T) {
tests := []struct {
name string
msg Message
want string
}{
{
name: "single text",
msg: NewUserText("hello"),
want: "hello",
},
{
name: "multiple text blocks",
msg: NewAssistantContent(
NewTextContent("first "),
NewToolCallContent(ToolCall{ID: "tc_1", Name: "bash"}),
NewTextContent("second"),
),
want: "first second",
},
{
name: "no text",
msg: NewToolResults(ToolResult{ToolCallID: "tc_1", Content: "output"}),
want: "",
},
{
name: "empty message",
msg: Message{Role: RoleAssistant},
want: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := tt.msg.TextContent(); got != tt.want {
t.Errorf("TextContent() = %q, want %q", got, tt.want)
}
})
}
}
func TestResponse_Fields(t *testing.T) {
r := Response{
Message: NewAssistantText("done"),
StopReason: StopEndTurn,
Usage: Usage{InputTokens: 100, OutputTokens: 50},
Model: "mistral-large-latest",
}
if r.StopReason != StopEndTurn {
t.Errorf("StopReason = %q, want %q", r.StopReason, StopEndTurn)
}
if r.Usage.TotalTokens() != 150 {
t.Errorf("Usage.TotalTokens() = %d, want 150", r.Usage.TotalTokens())
}
if r.Model != "mistral-large-latest" {
t.Errorf("Model = %q", r.Model)
}
if r.Message.TextContent() != "done" {
t.Errorf("Message.TextContent() = %q", r.Message.TextContent())
}
}

View File

@@ -0,0 +1,9 @@
package message
// Response wraps a completed assistant turn.
type Response struct {
Message Message
StopReason StopReason
Usage Usage
Model string
}

11
internal/message/stop.go Normal file
View File

@@ -0,0 +1,11 @@
package message
// StopReason indicates why the model stopped generating.
type StopReason string
const (
StopEndTurn StopReason = "end_turn"
StopMaxTokens StopReason = "max_tokens"
StopToolUse StopReason = "tool_use"
StopSequence StopReason = "stop_sequence"
)

20
internal/message/usage.go Normal file
View File

@@ -0,0 +1,20 @@
package message
// Usage tracks token consumption for a single API turn.
type Usage struct {
InputTokens int64 `json:"input_tokens"`
OutputTokens int64 `json:"output_tokens"`
CacheReadTokens int64 `json:"cache_read_tokens,omitempty"`
CacheCreationTokens int64 `json:"cache_creation_tokens,omitempty"`
}
func (u Usage) TotalTokens() int64 {
return u.InputTokens + u.OutputTokens
}
func (u *Usage) Add(other Usage) {
u.InputTokens += other.InputTokens
u.OutputTokens += other.OutputTokens
u.CacheReadTokens += other.CacheReadTokens
u.CacheCreationTokens += other.CacheCreationTokens
}

View File

@@ -0,0 +1,70 @@
package message
import "testing"
func TestUsage_TotalTokens(t *testing.T) {
u := Usage{InputTokens: 100, OutputTokens: 50}
if got := u.TotalTokens(); got != 150 {
t.Errorf("TotalTokens() = %d, want 150", got)
}
}
func TestUsage_TotalTokens_Zero(t *testing.T) {
var u Usage
if got := u.TotalTokens(); got != 0 {
t.Errorf("TotalTokens() = %d, want 0", got)
}
}
func TestUsage_Add(t *testing.T) {
u := Usage{
InputTokens: 100,
OutputTokens: 50,
CacheReadTokens: 10,
CacheCreationTokens: 5,
}
other := Usage{
InputTokens: 200,
OutputTokens: 80,
CacheReadTokens: 20,
CacheCreationTokens: 15,
}
u.Add(other)
if u.InputTokens != 300 {
t.Errorf("InputTokens = %d, want 300", u.InputTokens)
}
if u.OutputTokens != 130 {
t.Errorf("OutputTokens = %d, want 130", u.OutputTokens)
}
if u.CacheReadTokens != 30 {
t.Errorf("CacheReadTokens = %d, want 30", u.CacheReadTokens)
}
if u.CacheCreationTokens != 20 {
t.Errorf("CacheCreationTokens = %d, want 20", u.CacheCreationTokens)
}
}
func TestUsage_Add_Multiple(t *testing.T) {
var total Usage
turns := []Usage{
{InputTokens: 100, OutputTokens: 50},
{InputTokens: 200, OutputTokens: 80},
{InputTokens: 150, OutputTokens: 60},
}
for _, turn := range turns {
total.Add(turn)
}
if total.InputTokens != 450 {
t.Errorf("InputTokens = %d, want 450", total.InputTokens)
}
if total.OutputTokens != 190 {
t.Errorf("OutputTokens = %d, want 190", total.OutputTokens)
}
if total.TotalTokens() != 640 {
t.Errorf("TotalTokens() = %d, want 640", total.TotalTokens())
}
}

View File

@@ -0,0 +1,79 @@
package provider
import (
"fmt"
"time"
)
// ErrorKind classifies provider errors for retry decisions.
type ErrorKind int
const (
ErrTransient ErrorKind = iota + 1 // 429, 500, 502, 503, 529 — retry with backoff
ErrAuth // 401, 403 — don't retry
ErrBadRequest // 400 — don't retry, fix request
ErrNotFound // 404 — model/endpoint not found
ErrOverloaded // capacity exhausted — backoff + retry
)
func (k ErrorKind) String() string {
switch k {
case ErrTransient:
return "transient"
case ErrAuth:
return "auth"
case ErrBadRequest:
return "bad_request"
case ErrNotFound:
return "not_found"
case ErrOverloaded:
return "overloaded"
default:
return fmt.Sprintf("unknown(%d)", k)
}
}
// ProviderError wraps an SDK error with classification metadata.
type ProviderError struct {
Kind ErrorKind
Provider string
StatusCode int
Message string
Retryable bool
RetryAfter time.Duration // from Retry-After or rate limit headers
Err error // underlying SDK error
}
func (e *ProviderError) Error() string {
if e.Err != nil {
return fmt.Sprintf("%s %s (%d): %s: %v", e.Provider, e.Kind, e.StatusCode, e.Message, e.Err)
}
return fmt.Sprintf("%s %s (%d): %s", e.Provider, e.Kind, e.StatusCode, e.Message)
}
func (e *ProviderError) Unwrap() error {
return e.Err
}
// ClassifyHTTPStatus returns the ErrorKind and retryability for an HTTP status code.
func ClassifyHTTPStatus(status int) (ErrorKind, bool) {
switch {
case status == 401 || status == 403:
return ErrAuth, false
case status == 400:
return ErrBadRequest, false
case status == 404:
return ErrNotFound, false
case status == 429 || status == 529:
return ErrTransient, true
case status == 500 || status == 502 || status == 503:
return ErrTransient, true
case status == 504:
return ErrOverloaded, true
default:
if status >= 500 {
return ErrTransient, true
}
return ErrBadRequest, false
}
}

View File

@@ -0,0 +1,118 @@
package provider
import (
"errors"
"fmt"
"testing"
)
func TestProviderError_Error(t *testing.T) {
err := &ProviderError{
Kind: ErrTransient,
Provider: "mistral",
StatusCode: 429,
Message: "rate limited",
}
got := err.Error()
want := "mistral transient (429): rate limited"
if got != want {
t.Errorf("Error() = %q, want %q", got, want)
}
}
func TestProviderError_Error_WithWrapped(t *testing.T) {
inner := errors.New("connection reset")
err := &ProviderError{
Kind: ErrTransient,
Provider: "openai",
StatusCode: 502,
Message: "bad gateway",
Err: inner,
}
got := err.Error()
want := "openai transient (502): bad gateway: connection reset"
if got != want {
t.Errorf("Error() = %q, want %q", got, want)
}
}
func TestProviderError_Unwrap(t *testing.T) {
inner := errors.New("timeout")
err := &ProviderError{
Kind: ErrTransient,
Err: inner,
}
if !errors.Is(err, inner) {
t.Error("errors.Is should find inner error")
}
}
func TestProviderError_AsType(t *testing.T) {
inner := &ProviderError{
Kind: ErrAuth,
Provider: "anthropic",
StatusCode: 401,
Message: "invalid key",
}
wrapped := fmt.Errorf("api call failed: %w", inner)
pErr, ok := errors.AsType[*ProviderError](wrapped)
if !ok {
t.Fatal("errors.AsType should find ProviderError")
}
if pErr.Kind != ErrAuth {
t.Errorf("Kind = %v, want %v", pErr.Kind, ErrAuth)
}
if pErr.Provider != "anthropic" {
t.Errorf("Provider = %q", pErr.Provider)
}
}
func TestClassifyHTTPStatus(t *testing.T) {
tests := []struct {
status int
wantKind ErrorKind
wantRetry bool
}{
{200, ErrBadRequest, false}, // shouldn't happen, but safe default
{400, ErrBadRequest, false},
{401, ErrAuth, false},
{403, ErrAuth, false},
{404, ErrNotFound, false},
{429, ErrTransient, true},
{500, ErrTransient, true},
{502, ErrTransient, true},
{503, ErrTransient, true},
{504, ErrOverloaded, true},
{529, ErrTransient, true},
{599, ErrTransient, true}, // unknown 5xx
}
for _, tt := range tests {
kind, retry := ClassifyHTTPStatus(tt.status)
if kind != tt.wantKind {
t.Errorf("ClassifyHTTPStatus(%d) kind = %v, want %v", tt.status, kind, tt.wantKind)
}
if retry != tt.wantRetry {
t.Errorf("ClassifyHTTPStatus(%d) retry = %v, want %v", tt.status, retry, tt.wantRetry)
}
}
}
func TestErrorKind_String(t *testing.T) {
tests := []struct {
kind ErrorKind
want string
}{
{ErrTransient, "transient"},
{ErrAuth, "auth"},
{ErrBadRequest, "bad_request"},
{ErrNotFound, "not_found"},
{ErrOverloaded, "overloaded"},
{ErrorKind(99), "unknown(99)"},
}
for _, tt := range tests {
if got := tt.kind.String(); got != tt.want {
t.Errorf("ErrorKind(%d).String() = %q, want %q", tt.kind, got, tt.want)
}
}
}

View File

@@ -0,0 +1,44 @@
package provider
import (
"context"
"encoding/json"
"somegit.dev/Owlibou/gnoma/internal/message"
"somegit.dev/Owlibou/gnoma/internal/stream"
)
// Request encapsulates everything needed for a single LLM API call.
type Request struct {
Model string
SystemPrompt string
Messages []message.Message
Tools []ToolDefinition
MaxTokens int64
Temperature *float64
TopP *float64
TopK *int64
StopSequences []string
Thinking *ThinkingConfig
}
// ToolDefinition is the provider-agnostic tool schema.
type ToolDefinition struct {
Name string `json:"name"`
Description string `json:"description"`
Parameters json.RawMessage `json:"parameters"` // JSON Schema passthrough
}
// ThinkingConfig controls extended thinking / reasoning.
type ThinkingConfig struct {
BudgetTokens int64
}
// Provider is the core abstraction over all LLM backends.
type Provider interface {
// Stream initiates a streaming request and returns an event stream.
Stream(ctx context.Context, req Request) (stream.Stream, error)
// Name returns the provider identifier (e.g., "mistral", "anthropic").
Name() string
}

View File

@@ -0,0 +1,69 @@
package provider
import (
"fmt"
"sync"
)
// ProviderConfig is the common configuration for any provider.
type ProviderConfig struct {
Name string
APIKey string
BaseURL string // override for OpenAI-compat endpoints
Model string // default model for this provider
Options map[string]any // provider-specific options
}
// Factory creates a Provider from configuration.
type Factory func(cfg ProviderConfig) (Provider, error)
// Registry maps provider names to factory functions.
type Registry struct {
mu sync.RWMutex
factories map[string]Factory
}
func NewRegistry() *Registry {
return &Registry{
factories: make(map[string]Factory),
}
}
// Register adds a provider factory. Overwrites if name already exists.
func (r *Registry) Register(name string, f Factory) {
r.mu.Lock()
defer r.mu.Unlock()
r.factories[name] = f
}
// Create instantiates a provider by name with the given config.
func (r *Registry) Create(name string, cfg ProviderConfig) (Provider, error) {
r.mu.RLock()
f, ok := r.factories[name]
r.mu.RUnlock()
if !ok {
return nil, fmt.Errorf("unknown provider: %q", name)
}
cfg.Name = name
return f(cfg)
}
// Has returns true if a factory is registered for the given name.
func (r *Registry) Has(name string) bool {
r.mu.RLock()
defer r.mu.RUnlock()
_, ok := r.factories[name]
return ok
}
// Names returns all registered provider names.
func (r *Registry) Names() []string {
r.mu.RLock()
defer r.mu.RUnlock()
names := make([]string, 0, len(r.factories))
for name := range r.factories {
names = append(names, name)
}
return names
}

View File

@@ -0,0 +1,131 @@
package provider
import (
"context"
"errors"
"slices"
"sort"
"testing"
"somegit.dev/Owlibou/gnoma/internal/stream"
)
// mockProvider implements Provider for testing.
type mockProvider struct {
name string
}
func (m *mockProvider) Stream(_ context.Context, _ Request) (stream.Stream, error) {
return nil, nil
}
func (m *mockProvider) Name() string {
return m.name
}
func TestRegistry_RegisterAndCreate(t *testing.T) {
r := NewRegistry()
r.Register("mock", func(cfg ProviderConfig) (Provider, error) {
return &mockProvider{name: cfg.Name}, nil
})
p, err := r.Create("mock", ProviderConfig{})
if err != nil {
t.Fatalf("Create: %v", err)
}
if p.Name() != "mock" {
t.Errorf("Name() = %q, want %q", p.Name(), "mock")
}
}
func TestRegistry_Create_Unknown(t *testing.T) {
r := NewRegistry()
_, err := r.Create("nonexistent", ProviderConfig{})
if err == nil {
t.Fatal("expected error for unknown provider")
}
want := `unknown provider: "nonexistent"`
if err.Error() != want {
t.Errorf("error = %q, want %q", err.Error(), want)
}
}
func TestRegistry_Create_FactoryError(t *testing.T) {
r := NewRegistry()
r.Register("broken", func(cfg ProviderConfig) (Provider, error) {
return nil, errors.New("missing api key")
})
_, err := r.Create("broken", ProviderConfig{})
if err == nil {
t.Fatal("expected error from factory")
}
if err.Error() != "missing api key" {
t.Errorf("error = %q", err.Error())
}
}
func TestRegistry_Create_SetsName(t *testing.T) {
r := NewRegistry()
var receivedName string
r.Register("test", func(cfg ProviderConfig) (Provider, error) {
receivedName = cfg.Name
return &mockProvider{name: cfg.Name}, nil
})
_, _ = r.Create("test", ProviderConfig{APIKey: "sk-123"})
if receivedName != "test" {
t.Errorf("factory received Name = %q, want %q", receivedName, "test")
}
}
func TestRegistry_Has(t *testing.T) {
r := NewRegistry()
r.Register("exists", func(cfg ProviderConfig) (Provider, error) {
return nil, nil
})
if !r.Has("exists") {
t.Error("Has(exists) = false, want true")
}
if r.Has("nope") {
t.Error("Has(nope) = true, want false")
}
}
func TestRegistry_Names(t *testing.T) {
r := NewRegistry()
r.Register("alpha", func(cfg ProviderConfig) (Provider, error) { return nil, nil })
r.Register("beta", func(cfg ProviderConfig) (Provider, error) { return nil, nil })
r.Register("gamma", func(cfg ProviderConfig) (Provider, error) { return nil, nil })
names := r.Names()
sort.Strings(names)
want := []string{"alpha", "beta", "gamma"}
if !slices.Equal(names, want) {
t.Errorf("Names() = %v, want %v", names, want)
}
}
func TestRegistry_Register_Overwrite(t *testing.T) {
r := NewRegistry()
r.Register("dup", func(cfg ProviderConfig) (Provider, error) {
return &mockProvider{name: "old"}, nil
})
r.Register("dup", func(cfg ProviderConfig) (Provider, error) {
return &mockProvider{name: "new"}, nil
})
p, err := r.Create("dup", ProviderConfig{})
if err != nil {
t.Fatalf("Create: %v", err)
}
if p.Name() != "new" {
t.Errorf("Name() = %q, want %q (overwritten factory)", p.Name(), "new")
}
}

View File

@@ -0,0 +1,143 @@
package stream
import (
"strings"
"somegit.dev/Owlibou/gnoma/internal/message"
)
// Accumulator assembles a message.Response from a sequence of Events.
// Provider adapters translate SDK events into stream.Events;
// the Accumulator — shared, tested once — builds the final Response.
type Accumulator struct {
content []message.Content
usage message.Usage
// Active text block being built
textBuf *strings.Builder
// Active thinking block being built
thinkBuf *strings.Builder
// Tool calls in progress, keyed by ToolCallID
toolCalls map[string]*toolCallAccum
// Ordered tool call IDs to preserve emission order
toolCallOrder []string
}
type toolCallAccum struct {
id string
name string
argsBuf strings.Builder
args []byte // final complete args (from Done event)
}
func NewAccumulator() *Accumulator {
return &Accumulator{
toolCalls: make(map[string]*toolCallAccum),
}
}
// Apply processes a single event, updating the accumulator's state.
func (a *Accumulator) Apply(e Event) {
switch e.Type {
case EventTextDelta:
a.flushThinking()
if a.textBuf == nil {
a.textBuf = &strings.Builder{}
}
a.textBuf.WriteString(e.Text)
case EventThinkingDelta:
a.flushText()
if a.thinkBuf == nil {
a.thinkBuf = &strings.Builder{}
}
a.thinkBuf.WriteString(e.Text)
case EventToolCallStart:
a.flushText()
a.flushThinking()
tc := &toolCallAccum{id: e.ToolCallID, name: e.ToolCallName}
a.toolCalls[e.ToolCallID] = tc
a.toolCallOrder = append(a.toolCallOrder, e.ToolCallID)
case EventToolCallDelta:
if tc, ok := a.toolCalls[e.ToolCallID]; ok {
tc.argsBuf.WriteString(e.ArgDelta)
}
case EventToolCallDone:
if tc, ok := a.toolCalls[e.ToolCallID]; ok {
if e.Args != nil {
// Done event carries authoritative complete args
tc.args = e.Args
} else {
// Fall back to accumulated deltas
tc.args = []byte(tc.argsBuf.String())
}
}
case EventUsage:
if e.Usage != nil {
a.usage.Add(*e.Usage)
}
case EventError:
// Errors are handled by the stream consumer, not accumulated
}
}
// Response builds the final message.Response from all accumulated events.
func (a *Accumulator) Response(stopReason message.StopReason, model string) message.Response {
a.flushText()
a.flushThinking()
a.flushToolCalls()
return message.Response{
Message: message.Message{
Role: message.RoleAssistant,
Content: a.content,
},
StopReason: stopReason,
Usage: a.usage,
Model: model,
}
}
func (a *Accumulator) flushText() {
if a.textBuf != nil && a.textBuf.Len() > 0 {
a.content = append(a.content, message.NewTextContent(a.textBuf.String()))
a.textBuf = nil
}
}
func (a *Accumulator) flushThinking() {
if a.thinkBuf != nil && a.thinkBuf.Len() > 0 {
a.content = append(a.content, message.NewThinkingContent(message.Thinking{
Text: a.thinkBuf.String(),
}))
a.thinkBuf = nil
}
}
func (a *Accumulator) flushToolCalls() {
for _, id := range a.toolCallOrder {
tc, ok := a.toolCalls[id]
if !ok {
continue
}
args := tc.args
if args == nil {
// Fallback: use accumulated deltas even if Done never arrived
args = []byte(tc.argsBuf.String())
}
a.content = append(a.content, message.NewToolCallContent(message.ToolCall{
ID: tc.id,
Name: tc.name,
Arguments: args,
}))
}
a.toolCalls = make(map[string]*toolCallAccum)
a.toolCallOrder = nil
}

View File

@@ -0,0 +1,225 @@
package stream
import (
"encoding/json"
"testing"
"somegit.dev/Owlibou/gnoma/internal/message"
)
func TestAccumulator_TextOnly(t *testing.T) {
acc := NewAccumulator()
acc.Apply(Event{Type: EventTextDelta, Text: "Hello "})
acc.Apply(Event{Type: EventTextDelta, Text: "world!"})
acc.Apply(Event{Type: EventUsage, Usage: &message.Usage{InputTokens: 10, OutputTokens: 5}})
resp := acc.Response(message.StopEndTurn, "mistral-large")
if resp.StopReason != message.StopEndTurn {
t.Errorf("StopReason = %q, want %q", resp.StopReason, message.StopEndTurn)
}
if resp.Model != "mistral-large" {
t.Errorf("Model = %q, want %q", resp.Model, "mistral-large")
}
if resp.Message.Role != message.RoleAssistant {
t.Errorf("Role = %q, want %q", resp.Message.Role, message.RoleAssistant)
}
if resp.Message.TextContent() != "Hello world!" {
t.Errorf("TextContent() = %q, want %q", resp.Message.TextContent(), "Hello world!")
}
if resp.Usage.InputTokens != 10 {
t.Errorf("Usage.InputTokens = %d, want 10", resp.Usage.InputTokens)
}
}
func TestAccumulator_SingleToolCall(t *testing.T) {
acc := NewAccumulator()
acc.Apply(Event{Type: EventTextDelta, Text: "I'll run that."})
acc.Apply(Event{
Type: EventToolCallStart,
ToolCallID: "tc_1",
ToolCallName: "bash",
})
acc.Apply(Event{
Type: EventToolCallDelta,
ToolCallID: "tc_1",
ArgDelta: `{"comma`,
})
acc.Apply(Event{
Type: EventToolCallDelta,
ToolCallID: "tc_1",
ArgDelta: `nd":"ls"}`,
})
acc.Apply(Event{
Type: EventToolCallDone,
ToolCallID: "tc_1",
Args: json.RawMessage(`{"command":"ls"}`),
})
resp := acc.Response(message.StopToolUse, "mistral-large")
if resp.StopReason != message.StopToolUse {
t.Errorf("StopReason = %q, want %q", resp.StopReason, message.StopToolUse)
}
content := resp.Message.Content
if len(content) != 2 {
t.Fatalf("len(Content) = %d, want 2", len(content))
}
// First block: text
if content[0].Type != message.ContentText {
t.Errorf("Content[0].Type = %v, want text", content[0].Type)
}
if content[0].Text != "I'll run that." {
t.Errorf("Content[0].Text = %q", content[0].Text)
}
// Second block: tool call
if content[1].Type != message.ContentToolCall {
t.Errorf("Content[1].Type = %v, want tool_call", content[1].Type)
}
tc := content[1].ToolCall
if tc.ID != "tc_1" {
t.Errorf("ToolCall.ID = %q, want tc_1", tc.ID)
}
if tc.Name != "bash" {
t.Errorf("ToolCall.Name = %q, want bash", tc.Name)
}
if string(tc.Arguments) != `{"command":"ls"}` {
t.Errorf("ToolCall.Arguments = %s", tc.Arguments)
}
}
func TestAccumulator_MultipleToolCalls(t *testing.T) {
acc := NewAccumulator()
// Tool call 1
acc.Apply(Event{Type: EventToolCallStart, ToolCallID: "tc_1", ToolCallName: "bash"})
acc.Apply(Event{Type: EventToolCallDone, ToolCallID: "tc_1", Args: json.RawMessage(`{"command":"ls"}`)})
// Tool call 2
acc.Apply(Event{Type: EventToolCallStart, ToolCallID: "tc_2", ToolCallName: "fs.read"})
acc.Apply(Event{Type: EventToolCallDone, ToolCallID: "tc_2", Args: json.RawMessage(`{"path":"go.mod"}`)})
resp := acc.Response(message.StopToolUse, "")
calls := resp.Message.ToolCalls()
if len(calls) != 2 {
t.Fatalf("len(ToolCalls()) = %d, want 2", len(calls))
}
if calls[0].Name != "bash" {
t.Errorf("calls[0].Name = %q", calls[0].Name)
}
if calls[1].Name != "fs.read" {
t.Errorf("calls[1].Name = %q", calls[1].Name)
}
}
func TestAccumulator_ThinkingBlocks(t *testing.T) {
acc := NewAccumulator()
acc.Apply(Event{Type: EventThinkingDelta, Text: "Let me think"})
acc.Apply(Event{Type: EventThinkingDelta, Text: " about this..."})
acc.Apply(Event{Type: EventTextDelta, Text: "Here's my answer."})
resp := acc.Response(message.StopEndTurn, "")
content := resp.Message.Content
if len(content) != 2 {
t.Fatalf("len(Content) = %d, want 2", len(content))
}
// Thinking block first
if content[0].Type != message.ContentThinking {
t.Errorf("Content[0].Type = %v, want thinking", content[0].Type)
}
if content[0].Thinking.Text != "Let me think about this..." {
t.Errorf("Thinking.Text = %q", content[0].Thinking.Text)
}
// Then text
if content[1].Type != message.ContentText {
t.Errorf("Content[1].Type = %v, want text", content[1].Type)
}
}
func TestAccumulator_ToolCallDelta_Assembly(t *testing.T) {
acc := NewAccumulator()
acc.Apply(Event{Type: EventToolCallStart, ToolCallID: "tc_1", ToolCallName: "bash"})
// Simulate fragmented JSON
acc.Apply(Event{Type: EventToolCallDelta, ToolCallID: "tc_1", ArgDelta: `{"`})
acc.Apply(Event{Type: EventToolCallDelta, ToolCallID: "tc_1", ArgDelta: `comman`})
acc.Apply(Event{Type: EventToolCallDelta, ToolCallID: "tc_1", ArgDelta: `d":"`})
acc.Apply(Event{Type: EventToolCallDelta, ToolCallID: "tc_1", ArgDelta: `echo hi`})
acc.Apply(Event{Type: EventToolCallDelta, ToolCallID: "tc_1", ArgDelta: `"}`})
// Done event carries the complete args (authoritative)
acc.Apply(Event{
Type: EventToolCallDone,
ToolCallID: "tc_1",
Args: json.RawMessage(`{"command":"echo hi"}`),
})
resp := acc.Response(message.StopToolUse, "")
calls := resp.Message.ToolCalls()
if len(calls) != 1 {
t.Fatalf("len(ToolCalls()) = %d, want 1", len(calls))
}
if string(calls[0].Arguments) != `{"command":"echo hi"}` {
t.Errorf("Arguments = %s", calls[0].Arguments)
}
}
func TestAccumulator_EmptyStream(t *testing.T) {
acc := NewAccumulator()
resp := acc.Response(message.StopEndTurn, "model-x")
if resp.Model != "model-x" {
t.Errorf("Model = %q", resp.Model)
}
if len(resp.Message.Content) != 0 {
t.Errorf("len(Content) = %d, want 0", len(resp.Message.Content))
}
}
func TestAccumulator_UsageCumulative(t *testing.T) {
acc := NewAccumulator()
acc.Apply(Event{Type: EventUsage, Usage: &message.Usage{InputTokens: 100, OutputTokens: 30}})
acc.Apply(Event{Type: EventTextDelta, Text: "hi"})
acc.Apply(Event{Type: EventUsage, Usage: &message.Usage{InputTokens: 0, OutputTokens: 20}})
resp := acc.Response(message.StopEndTurn, "")
// Last usage wins (providers typically send cumulative, not incremental)
// But our Add() is additive — so 100+0=100 input, 30+20=50 output
if resp.Usage.InputTokens != 100 {
t.Errorf("InputTokens = %d, want 100", resp.Usage.InputTokens)
}
if resp.Usage.OutputTokens != 50 {
t.Errorf("OutputTokens = %d, want 50", resp.Usage.OutputTokens)
}
}
func TestEventType_String(t *testing.T) {
tests := []struct {
et EventType
want string
}{
{EventTextDelta, "text_delta"},
{EventThinkingDelta, "thinking_delta"},
{EventToolCallStart, "tool_call_start"},
{EventToolCallDelta, "tool_call_delta"},
{EventToolCallDone, "tool_call_done"},
{EventUsage, "usage"},
{EventError, "error"},
{EventType(99), "unknown(99)"},
}
for _, tt := range tests {
if got := tt.et.String(); got != tt.want {
t.Errorf("EventType(%d).String() = %q, want %q", tt.et, got, tt.want)
}
}
}

70
internal/stream/event.go Normal file
View File

@@ -0,0 +1,70 @@
package stream
import (
"encoding/json"
"fmt"
"somegit.dev/Owlibou/gnoma/internal/message"
)
// EventType discriminates streaming events.
type EventType int
const (
EventTextDelta EventType = iota + 1
EventThinkingDelta
EventToolCallStart
EventToolCallDelta
EventToolCallDone
EventUsage
EventError
)
func (et EventType) String() string {
switch et {
case EventTextDelta:
return "text_delta"
case EventThinkingDelta:
return "thinking_delta"
case EventToolCallStart:
return "tool_call_start"
case EventToolCallDelta:
return "tool_call_delta"
case EventToolCallDone:
return "tool_call_done"
case EventUsage:
return "usage"
case EventError:
return "error"
default:
return fmt.Sprintf("unknown(%d)", et)
}
}
// Event is a single streaming event from a provider.
type Event struct {
Type EventType
// TextDelta, ThinkingDelta
Text string
// ToolCallStart: ID + Name set
// ToolCallDelta: ID + ArgDelta set
// ToolCallDone: ID + Args set (complete JSON)
ToolCallID string
ToolCallName string
ArgDelta string // partial JSON fragment
Args json.RawMessage // complete arguments (on Done)
// Usage
Usage *message.Usage
// Error
Err error
// StopReason — set on the final event of a stream
StopReason message.StopReason
// Model — set on first event if available
Model string
}

25
internal/stream/stream.go Normal file
View File

@@ -0,0 +1,25 @@
package stream
// Stream is the unified pull-based iterator over provider events.
// All provider implementations produce this interface.
//
// Usage:
//
// for s.Next() {
// event := s.Current()
// process(event)
// }
// if err := s.Err(); err != nil {
// handle(err)
// }
// s.Close()
type Stream interface {
// Next advances to the next event. Returns false when exhausted or errored.
Next() bool
// Current returns the most recent event. Only valid after Next() returns true.
Current() Event
// Err returns the first error encountered, or nil.
Err() error
// Close releases resources. Must be called after consumption.
Close() error
}