feat(llm): Summarizer interface and OpenAI-compatible backend
This commit is contained in:
11
internal/llm/llm.go
Normal file
11
internal/llm/llm.go
Normal file
@@ -0,0 +1,11 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"somegit.dev/vikingowl/reddit-reader/internal/domain"
|
||||
)
|
||||
|
||||
// Summarizer scores and summarizes posts using a language-model backend.
type Summarizer interface {
	// Score rates how relevant post is to interests. The OpenAI-backed
	// implementation prompts for a float in [0.0, 1.0].
	Score(ctx context.Context, post domain.Post, interests domain.Interests) (float64, error)
	// Summarize returns a short text summary of post (the OpenAI-backed
	// implementation prompts for five "- " bullet points).
	Summarize(ctx context.Context, post domain.Post) (string, error)
}
|
||||
145
internal/llm/openai.go
Normal file
145
internal/llm/openai.go
Normal file
@@ -0,0 +1,145 @@
|
||||
package llm
|
||||
|
||||
import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strconv"
	"strings"
	"time"

	"somegit.dev/vikingowl/reddit-reader/internal/domain"
)
|
||||
|
||||
// chatRequest is the JSON body POSTed to /v1/chat/completions.
type chatRequest struct {
	Model string `json:"model"`
	Messages []chatMessage `json:"messages"`
	Temperature float64 `json:"temperature"`
}

// chatMessage is one role/content entry in a chat request
// (roles used here: "system" and "user").
type chatMessage struct {
	Role string `json:"role"`
	Content string `json:"content"`
}

// chatResponse mirrors only the subset of the completion response
// this package reads: the first choice's message content.
type chatResponse struct {
	Choices []struct {
		Message struct {
			Content string `json:"content"`
		} `json:"message"`
	} `json:"choices"`
}
|
||||
|
||||
// OpenAIClient speaks the OpenAI-compatible /v1/chat/completions API.
// Works with Ollama, llama.cpp, and any other compatible backend.
type OpenAIClient struct {
	baseURL string // server root; "/v1/chat/completions" is appended per request
	model string // model name sent in every chat request
	client *http.Client // reused across requests
}
|
||||
|
||||
func NewOpenAIClient(baseURL, model string) *OpenAIClient {
|
||||
return &OpenAIClient{
|
||||
baseURL: baseURL,
|
||||
model: model,
|
||||
client: &http.Client{},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *OpenAIClient) Score(ctx context.Context, post domain.Post, interests domain.Interests) (float64, error) {
|
||||
system := buildScorePrompt(interests)
|
||||
user := fmt.Sprintf("Title: %s\n\n%s", post.Title, truncate(post.SelfText, 500))
|
||||
|
||||
content, err := c.complete(ctx, system, user, 0.1)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
score, err := strconv.ParseFloat(strings.TrimSpace(content), 64)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("llm: non-numeric score response %q: %w", content, err)
|
||||
}
|
||||
return score, nil
|
||||
}
|
||||
|
||||
func (c *OpenAIClient) Summarize(ctx context.Context, post domain.Post) (string, error) {
|
||||
system := "You are a concise summarizer. Produce exactly 5 bullet points summarizing the post. Each bullet starts with '- '."
|
||||
user := fmt.Sprintf("Title: %s\n\n%s", post.Title, truncate(post.SelfText, 500))
|
||||
|
||||
return c.complete(ctx, system, user, 0.3)
|
||||
}
|
||||
|
||||
func (c *OpenAIClient) complete(ctx context.Context, system, user string, temperature float64) (string, error) {
|
||||
req := chatRequest{
|
||||
Model: c.model,
|
||||
Messages: []chatMessage{
|
||||
{Role: "system", Content: system},
|
||||
{Role: "user", Content: user},
|
||||
},
|
||||
Temperature: temperature,
|
||||
}
|
||||
|
||||
body, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("llm: marshal request: %w", err)
|
||||
}
|
||||
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/v1/chat/completions", bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("llm: build request: %w", err)
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := c.client.Do(httpReq)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("llm: http request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("llm: unexpected status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var chatResp chatResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil {
|
||||
return "", fmt.Errorf("llm: decode response: %w", err)
|
||||
}
|
||||
if len(chatResp.Choices) == 0 {
|
||||
return "", fmt.Errorf("llm: no choices in response")
|
||||
}
|
||||
|
||||
return chatResp.Choices[0].Message.Content, nil
|
||||
}
|
||||
|
||||
// buildScorePrompt constructs the system prompt for relevance scoring.
|
||||
// Package-level so mistral.go can reuse it.
|
||||
func buildScorePrompt(interests domain.Interests) string {
|
||||
var sb strings.Builder
|
||||
sb.WriteString("You are a relevance scorer. Given a Reddit post and user interests, respond with a single float between 0.0 and 1.0 indicating relevance. No explanation, just the number.\n\n")
|
||||
sb.WriteString("User interests: ")
|
||||
sb.WriteString(interests.Description)
|
||||
|
||||
if len(interests.Examples) > 0 {
|
||||
sb.WriteString("\n\nFeedback examples (post IDs the user rated):\n")
|
||||
for _, ex := range interests.Examples {
|
||||
vote := "interesting"
|
||||
if ex.Vote < 0 {
|
||||
vote = "not interesting"
|
||||
}
|
||||
fmt.Fprintf(&sb, "- %s: %s\n", ex.PostID, vote)
|
||||
}
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// truncate shortens s to at most maxLen runes.
// Package-level so mistral.go can reuse it.
func truncate(s string, maxLen int) string {
	// Count in runes, not bytes, so multi-byte UTF-8 never gets split.
	r := []rune(s)
	if len(r) > maxLen {
		return string(r[:maxLen])
	}
	return s
}
|
||||
103
internal/llm/openai_test.go
Normal file
103
internal/llm/openai_test.go
Normal file
@@ -0,0 +1,103 @@
|
||||
package llm_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"somegit.dev/vikingowl/reddit-reader/internal/domain"
|
||||
"somegit.dev/vikingowl/reddit-reader/internal/llm"
|
||||
)
|
||||
|
||||
func TestOpenAIScore(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/v1/chat/completions" {
|
||||
t.Errorf("unexpected path: %s", r.URL.Path)
|
||||
}
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"choices": []map[string]any{
|
||||
{"message": map[string]any{"role": "assistant", "content": "0.85"}},
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
client := llm.NewOpenAIClient(srv.URL, "test-model")
|
||||
post := domain.Post{Title: "Go iterators", SelfText: "range over func patterns"}
|
||||
interests := domain.Interests{Description: "Go programming"}
|
||||
|
||||
score, err := client.Score(context.Background(), post, interests)
|
||||
if err != nil {
|
||||
t.Fatalf("Score: %v", err)
|
||||
}
|
||||
if score != 0.85 {
|
||||
t.Errorf("score = %f, want 0.85", score)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAISummarize(t *testing.T) {
|
||||
want := "- point one\n- point two\n- point three\n- point four\n- point five"
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"choices": []map[string]any{
|
||||
{"message": map[string]any{"role": "assistant", "content": want}},
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
client := llm.NewOpenAIClient(srv.URL, "test-model")
|
||||
got, err := client.Summarize(context.Background(), domain.Post{Title: "Test", SelfText: "Some content"})
|
||||
if err != nil {
|
||||
t.Fatalf("Summarize: %v", err)
|
||||
}
|
||||
if got != want {
|
||||
t.Errorf("summary = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIScoreInvalidResponse(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"choices": []map[string]any{
|
||||
{"message": map[string]any{"role": "assistant", "content": "not a number"}},
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
client := llm.NewOpenAIClient(srv.URL, "test-model")
|
||||
_, err := client.Score(context.Background(), domain.Post{Title: "Test"}, domain.Interests{})
|
||||
if err == nil {
|
||||
t.Error("expected error for non-numeric score response")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIScorePromptIncludesFeedback(t *testing.T) {
|
||||
var receivedBody map[string]any
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
json.NewDecoder(r.Body).Decode(&receivedBody)
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"choices": []map[string]any{
|
||||
{"message": map[string]any{"role": "assistant", "content": "0.5"}},
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
client := llm.NewOpenAIClient(srv.URL, "test-model")
|
||||
interests := domain.Interests{
|
||||
Description: "Go",
|
||||
Examples: []domain.Feedback{{PostID: "t3_good", Vote: 1}},
|
||||
}
|
||||
client.Score(context.Background(), domain.Post{Title: "Test"}, interests)
|
||||
|
||||
msgs := receivedBody["messages"].([]any)
|
||||
systemMsg := msgs[0].(map[string]any)["content"].(string)
|
||||
if !strings.Contains(systemMsg, "t3_good") {
|
||||
t.Error("system prompt should contain feedback post IDs")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user