feat: add OpenAI provider adapter
Streaming, tool use (index-based delta accumulation), tool name sanitization (fs.read → fs_read), StreamOptions.IncludeUsage for token tracking. Hardcoded model list (gpt-4o, gpt-4o-mini, o3, o3-mini). Wired into CLI with OPENAI_API_KEY env support. Live verified: text streaming + tool calling with gpt-4o.
This commit is contained in:
103
internal/provider/openai/provider.go
Normal file
103
internal/provider/openai/provider.go
Normal file
@@ -0,0 +1,103 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"somegit.dev/Owlibou/gnoma/internal/provider"
|
||||
"somegit.dev/Owlibou/gnoma/internal/stream"
|
||||
|
||||
oai "github.com/openai/openai-go"
|
||||
"github.com/openai/openai-go/option"
|
||||
)
|
||||
|
||||
const defaultModel = "gpt-4o"
|
||||
|
||||
// Provider implements provider.Provider for the OpenAI API.
|
||||
type Provider struct {
|
||||
client *oai.Client
|
||||
name string
|
||||
model string
|
||||
}
|
||||
|
||||
// New creates an OpenAI provider from config.
|
||||
func New(cfg provider.ProviderConfig) (provider.Provider, error) {
|
||||
if cfg.APIKey == "" {
|
||||
return nil, fmt.Errorf("openai: api key required")
|
||||
}
|
||||
|
||||
opts := []option.RequestOption{
|
||||
option.WithAPIKey(cfg.APIKey),
|
||||
}
|
||||
if cfg.BaseURL != "" {
|
||||
opts = append(opts, option.WithBaseURL(cfg.BaseURL))
|
||||
}
|
||||
|
||||
client := oai.NewClient(opts...)
|
||||
|
||||
model := cfg.Model
|
||||
if model == "" {
|
||||
model = defaultModel
|
||||
}
|
||||
|
||||
return &Provider{
|
||||
client: &client,
|
||||
name: "openai",
|
||||
model: model,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Stream initiates a streaming chat completion request.
|
||||
func (p *Provider) Stream(ctx context.Context, req provider.Request) (stream.Stream, error) {
|
||||
model := req.Model
|
||||
if model == "" {
|
||||
model = p.model
|
||||
}
|
||||
|
||||
params := translateRequest(req)
|
||||
params.Model = model
|
||||
|
||||
raw := p.client.Chat.Completions.NewStreaming(ctx, params)
|
||||
|
||||
return newOpenAIStream(raw), nil
|
||||
}
|
||||
|
||||
// Name returns "openai".
|
||||
func (p *Provider) Name() string { return p.name }
|
||||
|
||||
// DefaultModel returns the configured default model.
|
||||
func (p *Provider) DefaultModel() string { return p.model }
|
||||
|
||||
// Models returns known OpenAI models with capabilities.
|
||||
func (p *Provider) Models(_ context.Context) ([]provider.ModelInfo, error) {
|
||||
return []provider.ModelInfo{
|
||||
{
|
||||
ID: "gpt-4o", Name: "GPT-4o", Provider: p.name,
|
||||
Capabilities: provider.Capabilities{
|
||||
ToolUse: true, JSONOutput: true, Vision: true,
|
||||
ContextWindow: 128000, MaxOutput: 16384,
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "gpt-4o-mini", Name: "GPT-4o Mini", Provider: p.name,
|
||||
Capabilities: provider.Capabilities{
|
||||
ToolUse: true, JSONOutput: true, Vision: true,
|
||||
ContextWindow: 128000, MaxOutput: 16384,
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "o3", Name: "o3", Provider: p.name,
|
||||
Capabilities: provider.Capabilities{
|
||||
ToolUse: true, JSONOutput: true, Thinking: true,
|
||||
ContextWindow: 200000, MaxOutput: 100000,
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "o3-mini", Name: "o3 Mini", Provider: p.name,
|
||||
Capabilities: provider.Capabilities{
|
||||
ToolUse: true, JSONOutput: true, Thinking: true,
|
||||
ContextWindow: 200000, MaxOutput: 100000,
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
152
internal/provider/openai/stream.go
Normal file
152
internal/provider/openai/stream.go
Normal file
@@ -0,0 +1,152 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
|
||||
"somegit.dev/Owlibou/gnoma/internal/message"
|
||||
"somegit.dev/Owlibou/gnoma/internal/stream"
|
||||
|
||||
oai "github.com/openai/openai-go"
|
||||
"github.com/openai/openai-go/packages/ssestream"
|
||||
)
|
||||
|
||||
// openaiStream adapts OpenAI's ssestream to gnoma's stream.Stream.
|
||||
type openaiStream struct {
|
||||
raw *ssestream.Stream[oai.ChatCompletionChunk]
|
||||
cur stream.Event
|
||||
err error
|
||||
model string
|
||||
stopReason message.StopReason
|
||||
emittedStop bool
|
||||
|
||||
// Tool call tracking (OpenAI uses index-based accumulation)
|
||||
toolCalls map[int64]*toolCallState
|
||||
hadToolCalls bool
|
||||
}
|
||||
|
||||
type toolCallState struct {
|
||||
id string
|
||||
name string
|
||||
args string
|
||||
}
|
||||
|
||||
func newOpenAIStream(raw *ssestream.Stream[oai.ChatCompletionChunk]) *openaiStream {
|
||||
return &openaiStream{
|
||||
raw: raw,
|
||||
toolCalls: make(map[int64]*toolCallState),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *openaiStream) Next() bool {
|
||||
for s.raw.Next() {
|
||||
chunk := s.raw.Current()
|
||||
|
||||
if s.model == "" && chunk.Model != "" {
|
||||
s.model = chunk.Model
|
||||
}
|
||||
|
||||
// Usage (only present when StreamOptions.IncludeUsage is true)
|
||||
if chunk.Usage.PromptTokens > 0 || chunk.Usage.CompletionTokens > 0 {
|
||||
usage := translateUsage(chunk.Usage)
|
||||
s.cur = stream.Event{
|
||||
Type: stream.EventUsage,
|
||||
Usage: usage,
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
if len(chunk.Choices) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
choice := chunk.Choices[0]
|
||||
delta := choice.Delta
|
||||
|
||||
// Finish reason
|
||||
if choice.FinishReason != "" {
|
||||
s.stopReason = translateFinishReason(string(choice.FinishReason))
|
||||
}
|
||||
|
||||
// Tool calls (index-based)
|
||||
if len(delta.ToolCalls) > 0 {
|
||||
for _, tc := range delta.ToolCalls {
|
||||
existing, ok := s.toolCalls[tc.Index]
|
||||
if !ok {
|
||||
// New tool call
|
||||
existing = &toolCallState{
|
||||
id: tc.ID,
|
||||
name: tc.Function.Name,
|
||||
}
|
||||
s.toolCalls[tc.Index] = existing
|
||||
s.hadToolCalls = true
|
||||
|
||||
if tc.Function.Name != "" {
|
||||
s.cur = stream.Event{
|
||||
Type: stream.EventToolCallStart,
|
||||
ToolCallID: tc.ID,
|
||||
ToolCallName: unsanitizeToolName(tc.Function.Name),
|
||||
}
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Accumulate arguments
|
||||
if tc.Function.Arguments != "" {
|
||||
existing.args += tc.Function.Arguments
|
||||
s.cur = stream.Event{
|
||||
Type: stream.EventToolCallDelta,
|
||||
ToolCallID: existing.id,
|
||||
ArgDelta: tc.Function.Arguments,
|
||||
}
|
||||
return true
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Text content
|
||||
if delta.Content != "" {
|
||||
s.cur = stream.Event{
|
||||
Type: stream.EventTextDelta,
|
||||
Text: delta.Content,
|
||||
}
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Stream ended — flush tool call Done events, then emit stop
|
||||
for idx, tc := range s.toolCalls {
|
||||
s.cur = stream.Event{
|
||||
Type: stream.EventToolCallDone,
|
||||
ToolCallID: tc.id,
|
||||
ToolCallName: unsanitizeToolName(tc.name),
|
||||
Args: json.RawMessage(tc.args),
|
||||
}
|
||||
delete(s.toolCalls, idx)
|
||||
return true
|
||||
}
|
||||
|
||||
if !s.emittedStop {
|
||||
s.emittedStop = true
|
||||
if s.stopReason == "" {
|
||||
if s.hadToolCalls {
|
||||
s.stopReason = message.StopToolUse
|
||||
} else {
|
||||
s.stopReason = message.StopEndTurn
|
||||
}
|
||||
}
|
||||
s.cur = stream.Event{
|
||||
Type: stream.EventTextDelta,
|
||||
StopReason: s.stopReason,
|
||||
Model: s.model,
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
s.err = s.raw.Err()
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *openaiStream) Current() stream.Event { return s.cur }
|
||||
func (s *openaiStream) Err() error { return s.err }
|
||||
func (s *openaiStream) Close() error { return s.raw.Close() }
|
||||
153
internal/provider/openai/translate.go
Normal file
153
internal/provider/openai/translate.go
Normal file
@@ -0,0 +1,153 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
"somegit.dev/Owlibou/gnoma/internal/message"
|
||||
"somegit.dev/Owlibou/gnoma/internal/provider"
|
||||
|
||||
oai "github.com/openai/openai-go"
|
||||
"github.com/openai/openai-go/packages/param"
|
||||
"github.com/openai/openai-go/shared"
|
||||
)
|
||||
|
||||
func sanitizeToolName(name string) string {
|
||||
return strings.ReplaceAll(name, ".", "_")
|
||||
}
|
||||
|
||||
func unsanitizeToolName(name string) string {
|
||||
if strings.HasPrefix(name, "fs_") {
|
||||
return "fs." + name[3:]
|
||||
}
|
||||
return name
|
||||
}
|
||||
|
||||
// --- gnoma → OpenAI ---
|
||||
|
||||
func translateMessages(msgs []message.Message) []oai.ChatCompletionMessageParamUnion {
|
||||
out := make([]oai.ChatCompletionMessageParamUnion, 0, len(msgs))
|
||||
for _, m := range msgs {
|
||||
out = append(out, translateMessage(m)...)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func translateMessage(m message.Message) []oai.ChatCompletionMessageParamUnion {
|
||||
switch m.Role {
|
||||
case message.RoleSystem:
|
||||
return []oai.ChatCompletionMessageParamUnion{
|
||||
oai.SystemMessage(m.TextContent()),
|
||||
}
|
||||
|
||||
case message.RoleUser:
|
||||
// Tool results → individual ToolMessages
|
||||
if len(m.Content) > 0 && m.Content[0].Type == message.ContentToolResult {
|
||||
var msgs []oai.ChatCompletionMessageParamUnion
|
||||
for _, c := range m.Content {
|
||||
if c.Type == message.ContentToolResult && c.ToolResult != nil {
|
||||
msgs = append(msgs, oai.ToolMessage(c.ToolResult.Content, c.ToolResult.ToolCallID))
|
||||
}
|
||||
}
|
||||
return msgs
|
||||
}
|
||||
return []oai.ChatCompletionMessageParamUnion{
|
||||
oai.UserMessage(m.TextContent()),
|
||||
}
|
||||
|
||||
case message.RoleAssistant:
|
||||
msg := oai.ChatCompletionMessageParamUnion{
|
||||
OfAssistant: &oai.ChatCompletionAssistantMessageParam{
|
||||
Content: oai.ChatCompletionAssistantMessageParamContentUnion{
|
||||
OfString: param.NewOpt(m.TextContent()),
|
||||
},
|
||||
},
|
||||
}
|
||||
// Add tool calls
|
||||
for _, tc := range m.ToolCalls() {
|
||||
msg.OfAssistant.ToolCalls = append(msg.OfAssistant.ToolCalls, oai.ChatCompletionMessageToolCallParam{
|
||||
ID: tc.ID,
|
||||
Function: oai.ChatCompletionMessageToolCallFunctionParam{
|
||||
Name: tc.Name,
|
||||
Arguments: string(tc.Arguments),
|
||||
},
|
||||
})
|
||||
}
|
||||
return []oai.ChatCompletionMessageParamUnion{msg}
|
||||
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func translateTools(defs []provider.ToolDefinition) []oai.ChatCompletionToolParam {
|
||||
if len(defs) == 0 {
|
||||
return nil
|
||||
}
|
||||
tools := make([]oai.ChatCompletionToolParam, len(defs))
|
||||
for i, d := range defs {
|
||||
var params shared.FunctionParameters
|
||||
if d.Parameters != nil {
|
||||
_ = json.Unmarshal(d.Parameters, ¶ms)
|
||||
}
|
||||
tools[i] = oai.ChatCompletionToolParam{
|
||||
Function: shared.FunctionDefinitionParam{
|
||||
Name: sanitizeToolName(d.Name),
|
||||
Description: param.NewOpt(d.Description),
|
||||
Parameters: params,
|
||||
},
|
||||
}
|
||||
}
|
||||
return tools
|
||||
}
|
||||
|
||||
func translateRequest(req provider.Request) oai.ChatCompletionNewParams {
|
||||
params := oai.ChatCompletionNewParams{
|
||||
Model: req.Model,
|
||||
Messages: translateMessages(req.Messages),
|
||||
Tools: translateTools(req.Tools),
|
||||
}
|
||||
|
||||
if req.MaxTokens > 0 {
|
||||
params.MaxCompletionTokens = param.NewOpt(req.MaxTokens)
|
||||
}
|
||||
if req.Temperature != nil {
|
||||
params.Temperature = param.NewOpt(*req.Temperature)
|
||||
}
|
||||
if req.TopP != nil {
|
||||
params.TopP = param.NewOpt(*req.TopP)
|
||||
}
|
||||
if len(req.StopSequences) > 0 {
|
||||
params.Stop = oai.ChatCompletionNewParamsStopUnion{
|
||||
OfStringArray: req.StopSequences,
|
||||
}
|
||||
}
|
||||
// Enable usage in streaming
|
||||
params.StreamOptions = oai.ChatCompletionStreamOptionsParam{
|
||||
IncludeUsage: param.NewOpt(true),
|
||||
}
|
||||
|
||||
return params
|
||||
}
|
||||
|
||||
// --- OpenAI → gnoma ---
|
||||
|
||||
func translateFinishReason(fr string) message.StopReason {
|
||||
switch fr {
|
||||
case "stop":
|
||||
return message.StopEndTurn
|
||||
case "tool_calls":
|
||||
return message.StopToolUse
|
||||
case "length":
|
||||
return message.StopMaxTokens
|
||||
default:
|
||||
return message.StopEndTurn
|
||||
}
|
||||
}
|
||||
|
||||
func translateUsage(u oai.CompletionUsage) *message.Usage {
|
||||
return &message.Usage{
|
||||
InputTokens: u.PromptTokens,
|
||||
OutputTokens: u.CompletionTokens,
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user