Files
reddit-reader/internal/llm/openai.go

146 lines
3.9 KiB
Go

package llm
import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strconv"
	"strings"
	"time"

	"somegit.dev/vikingowl/reddit-reader/internal/domain"
)
// chatRequest is the JSON payload sent to POST /v1/chat/completions.
type chatRequest struct {
	Model       string        `json:"model"`       // model identifier the backend should use
	Messages    []chatMessage `json:"messages"`    // conversation so far (system prompt first, then user)
	Temperature float64       `json:"temperature"` // sampling temperature; lower = more deterministic
}
// chatMessage is one turn of a chat conversation. This package only ever
// sends "system" and "user" roles.
type chatMessage struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}
// chatResponse mirrors only the subset of the /v1/chat/completions response
// this package reads: each choice's message content. All other fields in the
// server's reply are ignored by the decoder.
type chatResponse struct {
	Choices []struct {
		Message struct {
			Content string `json:"content"`
		} `json:"message"`
	} `json:"choices"`
}
// OpenAIClient speaks the OpenAI-compatible /v1/chat/completions API.
// Works with Ollama, llama.cpp, and any other compatible backend.
type OpenAIClient struct {
	baseURL string       // server root; "/v1/chat/completions" is appended per request
	model   string       // model name sent in every request
	client  *http.Client // shared HTTP client reused across requests
}
// NewOpenAIClient returns a client for the OpenAI-compatible chat completions
// endpoint rooted at baseURL, using the given model name. A trailing slash on
// baseURL is tolerated.
func NewOpenAIClient(baseURL, model string) *OpenAIClient {
	return &OpenAIClient{
		// Normalize so "http://host/" and "http://host" both yield
		// ".../v1/chat/completions" without a doubled slash.
		baseURL: strings.TrimSuffix(baseURL, "/"),
		model:   model,
		// The zero-value http.Client never times out, so a hung backend would
		// stall callers forever. Local LLMs can be slow, hence a generous
		// ceiling; per-request deadlines can still be set via context.
		client: &http.Client{Timeout: 2 * time.Minute},
	}
}
// Score asks the model to rate the post's relevance to the user's interests
// and returns a float in [0, 1].
//
// The model is instructed (via buildScorePrompt) to reply with a bare number;
// a non-numeric reply is returned as an error. Because LLMs occasionally
// drift outside the requested range, the parsed value is clamped to [0, 1]
// before being returned, so callers can rely on the documented bounds.
func (c *OpenAIClient) Score(ctx context.Context, post domain.Post, interests domain.Interests) (float64, error) {
	system := buildScorePrompt(interests)
	// Truncate self-text to keep the prompt small; title + 500 runes is
	// enough signal for a relevance judgment.
	user := fmt.Sprintf("Title: %s\n\n%s", post.Title, truncate(post.SelfText, 500))
	content, err := c.complete(ctx, system, user, 0.1)
	if err != nil {
		return 0, err
	}
	score, err := strconv.ParseFloat(strings.TrimSpace(content), 64)
	if err != nil {
		return 0, fmt.Errorf("llm: non-numeric score response %q: %w", content, err)
	}
	// Enforce the promised range even if the model misbehaves.
	if score < 0 {
		score = 0
	} else if score > 1 {
		score = 1
	}
	return score, nil
}
// Summarize asks the model for a five-bullet summary of the post and returns
// the model's raw text response.
func (c *OpenAIClient) Summarize(ctx context.Context, post domain.Post) (string, error) {
	const system = "You are a concise summarizer. Produce exactly 5 bullet points summarizing the post. Each bullet starts with '- '."
	prompt := fmt.Sprintf("Title: %s\n\n%s", post.Title, truncate(post.SelfText, 500))
	return c.complete(ctx, system, prompt, 0.3)
}
// complete performs one chat-completion round trip: it sends the system and
// user messages at the given temperature and returns the first choice's
// message content.
//
// Errors are prefixed "llm:" and wrap the underlying cause where one exists.
// On a non-200 status, a bounded excerpt of the response body is included in
// the error, since OpenAI-compatible servers report the actual failure
// reason there.
func (c *OpenAIClient) complete(ctx context.Context, system, user string, temperature float64) (string, error) {
	req := chatRequest{
		Model: c.model,
		Messages: []chatMessage{
			{Role: "system", Content: system},
			{Role: "user", Content: user},
		},
		Temperature: temperature,
	}
	body, err := json.Marshal(req)
	if err != nil {
		return "", fmt.Errorf("llm: marshal request: %w", err)
	}
	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/v1/chat/completions", bytes.NewReader(body))
	if err != nil {
		return "", fmt.Errorf("llm: build request: %w", err)
	}
	httpReq.Header.Set("Content-Type", "application/json")
	resp, err := c.client.Do(httpReq)
	if err != nil {
		return "", fmt.Errorf("llm: http request: %w", err)
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		// Read at most 512 bytes of the error body so a misbehaving server
		// cannot bloat the error message. Read failures here are ignored:
		// the status code alone is still a useful error.
		detail, _ := io.ReadAll(io.LimitReader(resp.Body, 512))
		if msg := strings.TrimSpace(string(detail)); msg != "" {
			return "", fmt.Errorf("llm: unexpected status %d: %s", resp.StatusCode, msg)
		}
		return "", fmt.Errorf("llm: unexpected status %d", resp.StatusCode)
	}
	var chatResp chatResponse
	if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil {
		return "", fmt.Errorf("llm: decode response: %w", err)
	}
	if len(chatResp.Choices) == 0 {
		return "", fmt.Errorf("llm: no choices in response")
	}
	return chatResp.Choices[0].Message.Content, nil
}
// buildScorePrompt constructs the system prompt for relevance scoring.
// Package-level so mistral.go can reuse it.
func buildScorePrompt(interests domain.Interests) string {
	var prompt strings.Builder
	prompt.WriteString("You are a relevance scorer. Given a Reddit post and user interests, respond with a single float between 0.0 and 1.0 indicating relevance. No explanation, just the number.\n\n")
	prompt.WriteString("User interests: ")
	prompt.WriteString(interests.Description)
	if len(interests.Examples) == 0 {
		return prompt.String()
	}
	// Append past user feedback so the model can calibrate its scores.
	prompt.WriteString("\n\nFeedback examples (post IDs the user rated):\n")
	for _, example := range interests.Examples {
		label := "interesting"
		if example.Vote < 0 {
			label = "not interesting"
		}
		fmt.Fprintf(&prompt, "- %s: %s\n", example.PostID, label)
	}
	return prompt.String()
}
// truncate shortens s to at most maxLen runes (not bytes, so multi-byte
// UTF-8 text is never cut mid-character). A maxLen <= 0 yields "".
// Package-level so mistral.go can reuse it.
//
// It walks the string's rune boundaries instead of converting to []rune,
// avoiding a full copy of s just to count characters; this also removes the
// panic the []rune version had on negative maxLen.
func truncate(s string, maxLen int) string {
	if maxLen <= 0 {
		return ""
	}
	count := 0
	// range over a string visits rune start offsets, so s[:i] is always a
	// valid UTF-8 boundary.
	for i := range s {
		if count == maxLen {
			return s[:i]
		}
		count++
	}
	// Fewer than maxLen runes: return s unchanged (no allocation).
	return s
}