99 lines
2.9 KiB
Go
99 lines
2.9 KiB
Go
package llm
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"strconv"
|
|
"strings"
|
|
|
|
mistral "github.com/VikingOwl91/mistral-go-sdk"
|
|
"github.com/VikingOwl91/mistral-go-sdk/chat"
|
|
|
|
"somegit.dev/vikingowl/reddit-reader/internal/domain"
|
|
)
|
|
|
|
// MistralClient wraps the Mistral AI SDK and implements the Summarizer interface.
// Both Score and Summarize issue a single chat-completion request against the
// configured model. Construct it with NewMistralClient; the zero value has a nil
// SDK client.
type MistralClient struct {
	// client is the underlying SDK client used for every request.
	client *mistral.Client
	// model is the Mistral model name sent with each chat-completion request.
	model string
}
|
|
|
|
// MistralOption configures a MistralClient. NewMistralClient applies options in
// the order they are passed, before the underlying SDK client is constructed,
// so a later option overrides an earlier one that sets the same field.
type MistralOption func(*mistralOpts)
|
|
|
|
// mistralOpts accumulates the settings produced by MistralOption values while
// NewMistralClient is building a client.
type mistralOpts struct {
	// baseURL, when non-empty, is forwarded to the SDK via mistral.WithBaseURL;
	// when empty, no base-URL override is passed to the SDK.
	baseURL string
}
|
|
|
|
// WithMistralBaseURL overrides the Mistral API base URL (useful for tests).
|
|
func WithMistralBaseURL(url string) MistralOption {
|
|
return func(o *mistralOpts) { o.baseURL = url }
|
|
}
|
|
|
|
// NewMistralClient creates a MistralClient using the given API key and model.
|
|
func NewMistralClient(apiKey, model string, opts ...MistralOption) *MistralClient {
|
|
var mo mistralOpts
|
|
for _, o := range opts {
|
|
o(&mo)
|
|
}
|
|
|
|
var clientOpts []mistral.Option
|
|
if mo.baseURL != "" {
|
|
clientOpts = append(clientOpts, mistral.WithBaseURL(mo.baseURL))
|
|
}
|
|
|
|
return &MistralClient{
|
|
client: mistral.NewClient(apiKey, clientOpts...),
|
|
model: model,
|
|
}
|
|
}
|
|
|
|
// Score returns a relevance score in [0.0, 1.0] for the given post against the user's interests.
|
|
func (m *MistralClient) Score(ctx context.Context, post domain.Post, interests domain.Interests) (float64, error) {
|
|
systemPrompt := buildScorePrompt(interests)
|
|
userPrompt := fmt.Sprintf("Title: %s\n\nContent: %s", post.Title, truncate(post.SelfText, 500))
|
|
|
|
resp, err := m.client.ChatComplete(ctx, &chat.CompletionRequest{
|
|
Model: m.model,
|
|
Messages: []chat.Message{
|
|
&chat.SystemMessage{Content: chat.TextContent(systemPrompt)},
|
|
&chat.UserMessage{Content: chat.TextContent(userPrompt)},
|
|
},
|
|
})
|
|
if err != nil {
|
|
return 0, fmt.Errorf("mistral score: %w", err)
|
|
}
|
|
if len(resp.Choices) == 0 {
|
|
return 0, fmt.Errorf("mistral score: no choices")
|
|
}
|
|
|
|
text := strings.TrimSpace(resp.Choices[0].Message.Content.String())
|
|
score, err := strconv.ParseFloat(text, 64)
|
|
if err != nil {
|
|
return 0, fmt.Errorf("mistral parse score %q: %w", text, err)
|
|
}
|
|
return score, nil
|
|
}
|
|
|
|
// Summarize produces a 5-bullet summary of the given post.
|
|
func (m *MistralClient) Summarize(ctx context.Context, post domain.Post) (string, error) {
|
|
systemPrompt := "You are a concise summarizer. Given a Reddit post, produce exactly 5 bullet points summarizing the key information. Each bullet starts with '- '. No other text."
|
|
userPrompt := fmt.Sprintf("Title: %s\n\nContent: %s", post.Title, post.SelfText)
|
|
|
|
resp, err := m.client.ChatComplete(ctx, &chat.CompletionRequest{
|
|
Model: m.model,
|
|
Messages: []chat.Message{
|
|
&chat.SystemMessage{Content: chat.TextContent(systemPrompt)},
|
|
&chat.UserMessage{Content: chat.TextContent(userPrompt)},
|
|
},
|
|
})
|
|
if err != nil {
|
|
return "", fmt.Errorf("mistral summarize: %w", err)
|
|
}
|
|
if len(resp.Choices) == 0 {
|
|
return "", fmt.Errorf("mistral summarize: no choices")
|
|
}
|
|
|
|
return strings.TrimSpace(resp.Choices[0].Message.Content.String()), nil
|
|
}
|