Files
reddit-reader/internal/llm/mistral.go

99 lines
2.9 KiB
Go

package llm
import (
	"context"
	"fmt"
	"math"
	"strconv"
	"strings"

	mistral "github.com/VikingOwl91/mistral-go-sdk"
	"github.com/VikingOwl91/mistral-go-sdk/chat"

	"somegit.dev/vikingowl/reddit-reader/internal/domain"
)
type (
	// MistralClient wraps the Mistral AI SDK and implements the Summarizer
	// interface. It pairs an SDK client with the model name that is sent on
	// every chat-completion request.
	MistralClient struct {
		client *mistral.Client // SDK handle used for ChatComplete calls
		model  string          // model identifier placed in each CompletionRequest
	}

	// MistralOption configures a MistralClient at construction time.
	MistralOption func(*mistralOpts)

	// mistralOpts accumulates the settings applied by MistralOption values
	// before the underlying SDK client is built.
	mistralOpts struct {
		baseURL string // API endpoint override; empty means no override is passed to the SDK
	}
)
// WithMistralBaseURL returns an option that makes the client talk to baseURL
// instead of the SDK's default Mistral endpoint. It is mainly useful for
// pointing the client at a local test server. An empty baseURL leaves the
// SDK default in place.
func WithMistralBaseURL(baseURL string) MistralOption {
	setBaseURL := func(cfg *mistralOpts) {
		cfg.baseURL = baseURL
	}
	return setBaseURL
}
// NewMistralClient builds a MistralClient that authenticates with apiKey and
// targets the given model for every request.
//
// Options are applied in the order given. A base-URL override is forwarded to
// the SDK only when it is non-empty; otherwise the SDK receives no options and
// (presumably) uses its own default endpoint.
func NewMistralClient(apiKey, model string, opts ...MistralOption) *MistralClient {
	cfg := mistralOpts{}
	for i := range opts {
		opts[i](&cfg)
	}

	var sdkOpts []mistral.Option
	if baseURL := cfg.baseURL; baseURL != "" {
		sdkOpts = []mistral.Option{mistral.WithBaseURL(baseURL)}
	}

	c := new(MistralClient)
	c.client = mistral.NewClient(apiKey, sdkOpts...)
	c.model = model
	return c
}
// Score asks the model how relevant post is to the user's interests and
// returns the answer as a float64 in [0.0, 1.0].
//
// The system prompt comes from buildScorePrompt(interests). The user prompt
// carries the post title and its self-text passed through truncate(…, 500) to
// keep the request small. The model is expected to reply with a bare number.
//
// Errors are returned when:
//   - the chat-completion call fails;
//   - the response has no choices;
//   - the reply does not parse as a float64;
//   - the reply parses to NaN or ±Inf.
//
// Finite values outside [0, 1] are clamped into that range.
func (m *MistralClient) Score(ctx context.Context, post domain.Post, interests domain.Interests) (float64, error) {
	systemPrompt := buildScorePrompt(interests)
	userPrompt := fmt.Sprintf("Title: %s\n\nContent: %s", post.Title, truncate(post.SelfText, 500))

	resp, err := m.client.ChatComplete(ctx, &chat.CompletionRequest{
		Model: m.model,
		Messages: []chat.Message{
			&chat.SystemMessage{Content: chat.TextContent(systemPrompt)},
			&chat.UserMessage{Content: chat.TextContent(userPrompt)},
		},
	})
	if err != nil {
		return 0, fmt.Errorf("mistral score: %w", err)
	}
	if len(resp.Choices) == 0 {
		return 0, fmt.Errorf("mistral score: no choices")
	}

	text := strings.TrimSpace(resp.Choices[0].Message.Content.String())
	score, err := strconv.ParseFloat(text, 64)
	if err != nil {
		return 0, fmt.Errorf("mistral parse score %q: %w", text, err)
	}
	// ParseFloat happily accepts "NaN", "Inf" and "-Inf". None of them is a
	// usable relevance score, and NaN would silently break comparisons/sorting.
	if math.IsNaN(score) || math.IsInf(score, 0) {
		return 0, fmt.Errorf("mistral parse score %q: not a finite number", text)
	}
	// The model may drift slightly outside the requested range (e.g. 1.2);
	// clamp so callers can rely on the documented [0, 1] contract.
	return math.Min(math.Max(score, 0), 1), nil
}
// Summarize asks the model for a 5-bullet summary of post and returns the
// model's reply with surrounding whitespace trimmed.
//
// The full title and self-text are sent; unlike Score, the self-text is not
// truncated. The bullet format is requested via the system prompt but is not
// validated here.
//
// Errors are returned when:
//   - the chat-completion call fails;
//   - the response has no choices;
//   - the reply is empty after trimming, so that a blank summary is never
//     reported as success.
func (m *MistralClient) Summarize(ctx context.Context, post domain.Post) (string, error) {
	systemPrompt := "You are a concise summarizer. Given a Reddit post, produce exactly 5 bullet points summarizing the key information. Each bullet starts with '- '. No other text."
	userPrompt := fmt.Sprintf("Title: %s\n\nContent: %s", post.Title, post.SelfText)

	resp, err := m.client.ChatComplete(ctx, &chat.CompletionRequest{
		Model: m.model,
		Messages: []chat.Message{
			&chat.SystemMessage{Content: chat.TextContent(systemPrompt)},
			&chat.UserMessage{Content: chat.TextContent(userPrompt)},
		},
	})
	if err != nil {
		return "", fmt.Errorf("mistral summarize: %w", err)
	}
	if len(resp.Choices) == 0 {
		return "", fmt.Errorf("mistral summarize: no choices")
	}

	summary := strings.TrimSpace(resp.Choices[0].Message.Content.String())
	if summary == "" {
		return "", fmt.Errorf("mistral summarize: empty response")
	}
	return summary, nil
}