From 261c19f90f5d6f21118ff953ff3607b160933f03 Mon Sep 17 00:00:00 2001 From: vikingowl <26+vikingowl@noreply.somegit.dev> Date: Fri, 3 Apr 2026 13:33:55 +0200 Subject: [PATCH] feat: add OpenAI provider adapter MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Streaming, tool use (index-based delta accumulation), tool name sanitization (fs.read → fs_read), StreamOptions.IncludeUsage for token tracking. Hardcoded model list (gpt-4o, gpt-4o-mini, o3, o3-mini). Wired into CLI with OPENAI_API_KEY env support. Live verified: text streaming + tool calling with gpt-4o. --- cmd/gnoma/main.go | 5 +- go.mod | 1 + go.sum | 2 + internal/provider/openai/provider.go | 103 +++++++++++++++++ internal/provider/openai/stream.go | 152 +++++++++++++++++++++++++ internal/provider/openai/translate.go | 153 ++++++++++++++++++++++++++ 6 files changed, 415 insertions(+), 1 deletion(-) create mode 100644 internal/provider/openai/provider.go create mode 100644 internal/provider/openai/stream.go create mode 100644 internal/provider/openai/translate.go diff --git a/cmd/gnoma/main.go b/cmd/gnoma/main.go index c4e1b18..6464dad 100644 --- a/cmd/gnoma/main.go +++ b/cmd/gnoma/main.go @@ -14,6 +14,7 @@ import ( "somegit.dev/Owlibou/gnoma/internal/provider" anthropicprov "somegit.dev/Owlibou/gnoma/internal/provider/anthropic" "somegit.dev/Owlibou/gnoma/internal/provider/mistral" + oaiprov "somegit.dev/Owlibou/gnoma/internal/provider/openai" "somegit.dev/Owlibou/gnoma/internal/stream" "somegit.dev/Owlibou/gnoma/internal/tool" "somegit.dev/Owlibou/gnoma/internal/tool/bash" @@ -191,8 +192,10 @@ func createProvider(name, apiKey, model string) (provider.Provider, error) { return mistral.New(cfg) case "anthropic": return anthropicprov.New(cfg) + case "openai": + return oaiprov.New(cfg) default: - return nil, fmt.Errorf("unknown provider %q (supports: mistral, anthropic)", name) + return nil, fmt.Errorf("unknown provider %q (supports: mistral, anthropic, openai)", name) } } diff --git a/go.mod b/go.mod index 4f34c7a..b27b00d 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.26.1 require ( github.com/VikingOwl91/mistral-go-sdk v1.2.1 github.com/anthropics/anthropic-sdk-go v1.29.0 + github.com/openai/openai-go v1.12.0 ) require ( diff --git a/go.sum b/go.sum index fc4adb7..8528421 100644 --- a/go.sum +++ b/go.sum @@ -6,6 +6,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dnaeon/go-vcr v1.2.0 h1:zHCHvJYTMh1N7xnV7zf1m1GPBF9Ad0Jk/whtQ1663qI= github.com/dnaeon/go-vcr v1.2.0/go.mod h1:R4UdLID7HZT3taECzJs4YgbbH6PIGXB6W/sc5OLb6RQ= +github.com/openai/openai-go v1.12.0 h1:NBQCnXzqOTv5wsgNC36PrFEiskGfO5wccfCWDo9S1U0= +github.com/openai/openai-go v1.12.0/go.mod h1:g461MYGXEXBVdV5SaR/5tNzNbSfwTBBefwc+LlDCK0Y= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= diff --git a/internal/provider/openai/provider.go b/internal/provider/openai/provider.go new file mode 100644 index 0000000..514e810 --- /dev/null +++ b/internal/provider/openai/provider.go @@ -0,0 +1,103 @@ +package openai + +import ( + "context" + "fmt" + + "somegit.dev/Owlibou/gnoma/internal/provider" + "somegit.dev/Owlibou/gnoma/internal/stream" + + oai "github.com/openai/openai-go" + "github.com/openai/openai-go/option" +) + +const defaultModel = "gpt-4o" + +// Provider implements provider.Provider for the OpenAI API. +type Provider struct { + client *oai.Client + name string + model string +} + +// New creates an OpenAI provider from config. +func New(cfg provider.ProviderConfig) (provider.Provider, error) { + if cfg.APIKey == "" { + return nil, fmt.Errorf("openai: api key required") + } + + opts := []option.RequestOption{ + option.WithAPIKey(cfg.APIKey), + } + if cfg.BaseURL != "" { + opts = append(opts, option.WithBaseURL(cfg.BaseURL)) + } + + client := oai.NewClient(opts...) + + model := cfg.Model + if model == "" { + model = defaultModel + } + + return &Provider{ + client: &client, + name: "openai", + model: model, + }, nil +} + +// Stream initiates a streaming chat completion request. +func (p *Provider) Stream(ctx context.Context, req provider.Request) (stream.Stream, error) { + model := req.Model + if model == "" { + model = p.model + } + + params := translateRequest(req) + params.Model = model + + raw := p.client.Chat.Completions.NewStreaming(ctx, params) + + return newOpenAIStream(raw), nil +} + +// Name returns "openai". +func (p *Provider) Name() string { return p.name } + +// DefaultModel returns the configured default model. +func (p *Provider) DefaultModel() string { return p.model } + +// Models returns known OpenAI models with capabilities. +func (p *Provider) Models(_ context.Context) ([]provider.ModelInfo, error) { + return []provider.ModelInfo{ + { + ID: "gpt-4o", Name: "GPT-4o", Provider: p.name, + Capabilities: provider.Capabilities{ + ToolUse: true, JSONOutput: true, Vision: true, + ContextWindow: 128000, MaxOutput: 16384, + }, + }, + { + ID: "gpt-4o-mini", Name: "GPT-4o Mini", Provider: p.name, + Capabilities: provider.Capabilities{ + ToolUse: true, JSONOutput: true, Vision: true, + ContextWindow: 128000, MaxOutput: 16384, + }, + }, + { + ID: "o3", Name: "o3", Provider: p.name, + Capabilities: provider.Capabilities{ + ToolUse: true, JSONOutput: true, Thinking: true, + ContextWindow: 200000, MaxOutput: 100000, + }, + }, + { + ID: "o3-mini", Name: "o3 Mini", Provider: p.name, + Capabilities: provider.Capabilities{ + ToolUse: true, JSONOutput: true, Thinking: true, + ContextWindow: 200000, MaxOutput: 100000, + }, + }, + }, nil +} diff --git a/internal/provider/openai/stream.go b/internal/provider/openai/stream.go new file mode 100644 index 0000000..7d2b565 --- /dev/null +++ b/internal/provider/openai/stream.go @@ -0,0 +1,152 @@ +package openai + +import ( + "encoding/json" + + "somegit.dev/Owlibou/gnoma/internal/message" + "somegit.dev/Owlibou/gnoma/internal/stream" + + oai "github.com/openai/openai-go" + "github.com/openai/openai-go/packages/ssestream" +) + +// openaiStream adapts OpenAI's ssestream to gnoma's stream.Stream. +type openaiStream struct { + raw *ssestream.Stream[oai.ChatCompletionChunk] + cur stream.Event + err error + model string + stopReason message.StopReason + emittedStop bool + + // Tool call tracking (OpenAI uses index-based accumulation) + toolCalls map[int64]*toolCallState + hadToolCalls bool +} + +type toolCallState struct { + id string + name string + args string +} + +func newOpenAIStream(raw *ssestream.Stream[oai.ChatCompletionChunk]) *openaiStream { + return &openaiStream{ + raw: raw, + toolCalls: make(map[int64]*toolCallState), + } +} + +func (s *openaiStream) Next() bool { + for s.raw.Next() { + chunk := s.raw.Current() + + if s.model == "" && chunk.Model != "" { + s.model = chunk.Model + } + + // Usage (only present when StreamOptions.IncludeUsage is true) + if chunk.Usage.PromptTokens > 0 || chunk.Usage.CompletionTokens > 0 { + usage := translateUsage(chunk.Usage) + s.cur = stream.Event{ + Type: stream.EventUsage, + Usage: usage, + } + return true + } + + if len(chunk.Choices) == 0 { + continue + } + + choice := chunk.Choices[0] + delta := choice.Delta + + // Finish reason + if choice.FinishReason != "" { + s.stopReason = translateFinishReason(string(choice.FinishReason)) + } + + // Tool calls (index-based) + if len(delta.ToolCalls) > 0 { + for _, tc := range delta.ToolCalls { + existing, ok := s.toolCalls[tc.Index] + if !ok { + // New tool call + existing = &toolCallState{ + id: tc.ID, + name: tc.Function.Name, + } + s.toolCalls[tc.Index] = existing + s.hadToolCalls = true + + if tc.Function.Name != "" { + s.cur = stream.Event{ + Type: stream.EventToolCallStart, + ToolCallID: tc.ID, + ToolCallName: unsanitizeToolName(tc.Function.Name), + } + return true + } + } + + // Accumulate arguments + if tc.Function.Arguments != "" { + existing.args += tc.Function.Arguments + s.cur = stream.Event{ + Type: stream.EventToolCallDelta, + ToolCallID: existing.id, + ArgDelta: tc.Function.Arguments, + } + return true + } + } + continue + } + + // Text content + if delta.Content != "" { + s.cur = stream.Event{ + Type: stream.EventTextDelta, + Text: delta.Content, + } + return true + } + } + + // Stream ended — flush tool call Done events, then emit stop + for idx, tc := range s.toolCalls { + s.cur = stream.Event{ + Type: stream.EventToolCallDone, + ToolCallID: tc.id, + ToolCallName: unsanitizeToolName(tc.name), + Args: json.RawMessage(tc.args), + } + delete(s.toolCalls, idx) + return true + } + + if !s.emittedStop { + s.emittedStop = true + if s.stopReason == "" { + if s.hadToolCalls { + s.stopReason = message.StopToolUse + } else { + s.stopReason = message.StopEndTurn + } + } + s.cur = stream.Event{ + Type: stream.EventTextDelta, + StopReason: s.stopReason, + Model: s.model, + } + return true + } + + s.err = s.raw.Err() + return false +} + +func (s *openaiStream) Current() stream.Event { return s.cur } +func (s *openaiStream) Err() error { return s.err } +func (s *openaiStream) Close() error { return s.raw.Close() } diff --git a/internal/provider/openai/translate.go b/internal/provider/openai/translate.go new file mode 100644 index 0000000..1d9a19f --- /dev/null +++ b/internal/provider/openai/translate.go @@ -0,0 +1,153 @@ +package openai + +import ( + "encoding/json" + "strings" + + "somegit.dev/Owlibou/gnoma/internal/message" + "somegit.dev/Owlibou/gnoma/internal/provider" + + oai "github.com/openai/openai-go" + "github.com/openai/openai-go/packages/param" + "github.com/openai/openai-go/shared" +) + +func sanitizeToolName(name string) string { + return strings.ReplaceAll(name, ".", "_") +} + +func unsanitizeToolName(name string) string { + if strings.HasPrefix(name, "fs_") { + return "fs." + name[3:] + } + return name +} + +// --- gnoma → OpenAI --- + +func translateMessages(msgs []message.Message) []oai.ChatCompletionMessageParamUnion { + out := make([]oai.ChatCompletionMessageParamUnion, 0, len(msgs)) + for _, m := range msgs { + out = append(out, translateMessage(m)...) + } + return out +} + +func translateMessage(m message.Message) []oai.ChatCompletionMessageParamUnion { + switch m.Role { + case message.RoleSystem: + return []oai.ChatCompletionMessageParamUnion{ + oai.SystemMessage(m.TextContent()), + } + + case message.RoleUser: + // Tool results → individual ToolMessages + if len(m.Content) > 0 && m.Content[0].Type == message.ContentToolResult { + var msgs []oai.ChatCompletionMessageParamUnion + for _, c := range m.Content { + if c.Type == message.ContentToolResult && c.ToolResult != nil { + msgs = append(msgs, oai.ToolMessage(c.ToolResult.Content, c.ToolResult.ToolCallID)) + } + } + return msgs + } + return []oai.ChatCompletionMessageParamUnion{ + oai.UserMessage(m.TextContent()), + } + + case message.RoleAssistant: + msg := oai.ChatCompletionMessageParamUnion{ + OfAssistant: &oai.ChatCompletionAssistantMessageParam{ + Content: oai.ChatCompletionAssistantMessageParamContentUnion{ + OfString: param.NewOpt(m.TextContent()), + }, + }, + } + // Add tool calls + for _, tc := range m.ToolCalls() { + msg.OfAssistant.ToolCalls = append(msg.OfAssistant.ToolCalls, oai.ChatCompletionMessageToolCallParam{ + ID: tc.ID, + Function: oai.ChatCompletionMessageToolCallFunctionParam{ + Name: tc.Name, + Arguments: string(tc.Arguments), + }, + }) + } + return []oai.ChatCompletionMessageParamUnion{msg} + + default: + return nil + } +} + +func translateTools(defs []provider.ToolDefinition) []oai.ChatCompletionToolParam { + if len(defs) == 0 { + return nil + } + tools := make([]oai.ChatCompletionToolParam, len(defs)) + for i, d := range defs { + var params shared.FunctionParameters + if d.Parameters != nil { + _ = json.Unmarshal(d.Parameters, ¶ms) + } + tools[i] = oai.ChatCompletionToolParam{ + Function: shared.FunctionDefinitionParam{ + Name: sanitizeToolName(d.Name), + Description: param.NewOpt(d.Description), + Parameters: params, + }, + } + } + return tools +} + +func translateRequest(req provider.Request) oai.ChatCompletionNewParams { + params := oai.ChatCompletionNewParams{ + Model: req.Model, + Messages: translateMessages(req.Messages), + Tools: translateTools(req.Tools), + } + + if req.MaxTokens > 0 { + params.MaxCompletionTokens = param.NewOpt(req.MaxTokens) + } + if req.Temperature != nil { + params.Temperature = param.NewOpt(*req.Temperature) + } + if req.TopP != nil { + params.TopP = param.NewOpt(*req.TopP) + } + if len(req.StopSequences) > 0 { + params.Stop = oai.ChatCompletionNewParamsStopUnion{ + OfStringArray: req.StopSequences, + } + } + // Enable usage in streaming + params.StreamOptions = oai.ChatCompletionStreamOptionsParam{ + IncludeUsage: param.NewOpt(true), + } + + return params +} + +// --- OpenAI → gnoma --- + +func translateFinishReason(fr string) message.StopReason { + switch fr { + case "stop": + return message.StopEndTurn + case "tool_calls": + return message.StopToolUse + case "length": + return message.StopMaxTokens + default: + return message.StopEndTurn + } +} + +func translateUsage(u oai.CompletionUsage) *message.Usage { + return &message.Usage{ + InputTokens: u.PromptTokens, + OutputTokens: u.CompletionTokens, + } +}