diff --git a/internal/llm/llm.go b/internal/llm/llm.go new file mode 100644 index 0000000..49b6ff9 --- /dev/null +++ b/internal/llm/llm.go @@ -0,0 +1,11 @@ +package llm + +import ( + "context" + "somegit.dev/vikingowl/reddit-reader/internal/domain" +) + +// Summarizer abstracts an LLM backend that can relevance-score and summarize posts. +type Summarizer interface { + Score(ctx context.Context, post domain.Post, interests domain.Interests) (float64, error) + Summarize(ctx context.Context, post domain.Post) (string, error) +} diff --git a/internal/llm/openai.go b/internal/llm/openai.go new file mode 100644 index 0000000..9e1db16 --- /dev/null +++ b/internal/llm/openai.go @@ -0,0 +1,145 @@ +package llm + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net/http" + "strconv" + "strings" + + "somegit.dev/vikingowl/reddit-reader/internal/domain" +) + +// chatRequest is the JSON payload for a /v1/chat/completions call. +type chatRequest struct { + Model string `json:"model"` + Messages []chatMessage `json:"messages"` + Temperature float64 `json:"temperature"` +} + +// chatMessage is one role/content entry in the chat transcript. +type chatMessage struct { + Role string `json:"role"` + Content string `json:"content"` +} + +// chatResponse mirrors the subset of the completion response we read. +type chatResponse struct { + Choices []struct { + Message struct { + Content string `json:"content"` + } `json:"message"` + } `json:"choices"` +} + +// OpenAIClient speaks the OpenAI-compatible /v1/chat/completions API. +// Works with Ollama, llama.cpp, and any other compatible backend. 
+type OpenAIClient struct { + baseURL string + model string + client *http.Client +} + +// NewOpenAIClient returns a client for the given OpenAI-compatible base URL and model. +func NewOpenAIClient(baseURL, model string) *OpenAIClient { + return &OpenAIClient{ + baseURL: baseURL, + model: model, + client: &http.Client{}, // NOTE(review): no Timeout set — only the per-request ctx bounds a call; consider a default timeout. + } +} + +// Score asks the model to rate post relevance against interests and parses the reply as a bare float. +func (c *OpenAIClient) Score(ctx context.Context, post domain.Post, interests domain.Interests) (float64, error) { + system := buildScorePrompt(interests) + user := fmt.Sprintf("Title: %s\n\n%s", post.Title, truncate(post.SelfText, 500)) + + content, err := c.complete(ctx, system, user, 0.1) // low temperature: the reply must be a stable, parseable number + if err != nil { + return 0, err + } + + score, err := strconv.ParseFloat(strings.TrimSpace(content), 64) + if err != nil { + return 0, fmt.Errorf("llm: non-numeric score response %q: %w", content, err) + } + return score, nil +} + +// Summarize asks the model for a five-bullet summary of the post. +func (c *OpenAIClient) Summarize(ctx context.Context, post domain.Post) (string, error) { + system := "You are a concise summarizer. Produce exactly 5 bullet points summarizing the post. Each bullet starts with '- '." + user := fmt.Sprintf("Title: %s\n\n%s", post.Title, truncate(post.SelfText, 500)) + + return c.complete(ctx, system, user, 0.3) +} + +// complete performs one chat-completion round-trip and returns the first choice's content. +func (c *OpenAIClient) complete(ctx context.Context, system, user string, temperature float64) (string, error) { + req := chatRequest{ + Model: c.model, + Messages: []chatMessage{ + {Role: "system", Content: system}, + {Role: "user", Content: user}, + }, + Temperature: temperature, + } + + body, err := json.Marshal(req) + if err != nil { + return "", fmt.Errorf("llm: marshal request: %w", err) + } + + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/v1/chat/completions", bytes.NewReader(body)) + if err != nil { + return "", fmt.Errorf("llm: build request: %w", err) + } + httpReq.Header.Set("Content-Type", "application/json") + + resp, err := c.client.Do(httpReq) + if err != nil { + return "", fmt.Errorf("llm: http request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("llm: 
unexpected status %d", resp.StatusCode) + } + + var chatResp chatResponse + if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil { + return "", fmt.Errorf("llm: decode response: %w", err) + } + if len(chatResp.Choices) == 0 { + return "", fmt.Errorf("llm: no choices in response") + } + + return chatResp.Choices[0].Message.Content, nil +} + +// buildScorePrompt constructs the system prompt for relevance scoring. +// Package-level so mistral.go can reuse it. +func buildScorePrompt(interests domain.Interests) string { + var sb strings.Builder + sb.WriteString("You are a relevance scorer. Given a Reddit post and user interests, respond with a single float between 0.0 and 1.0 indicating relevance. No explanation, just the number.\n\n") + sb.WriteString("User interests: ") + sb.WriteString(interests.Description) + + if len(interests.Examples) > 0 { + sb.WriteString("\n\nFeedback examples (post IDs the user rated):\n") + for _, ex := range interests.Examples { + vote := "interesting" + if ex.Vote < 0 { + vote = "not interesting" + } + fmt.Fprintf(&sb, "- %s: %s\n", ex.PostID, vote) + } + } + + return sb.String() +} + +// truncate shortens s to at most maxLen runes. +// Package-level so mistral.go can reuse it. 
+func truncate(s string, maxLen int) string { + runes := []rune(s) + if len(runes) <= maxLen { + return s + } + return string(runes[:maxLen]) +} diff --git a/internal/llm/openai_test.go b/internal/llm/openai_test.go new file mode 100644 index 0000000..4793527 --- /dev/null +++ b/internal/llm/openai_test.go @@ -0,0 +1,103 @@ +package llm_test + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "somegit.dev/vikingowl/reddit-reader/internal/domain" + "somegit.dev/vikingowl/reddit-reader/internal/llm" +) + +func TestOpenAIScore(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/chat/completions" { + t.Errorf("unexpected path: %s", r.URL.Path) + } + json.NewEncoder(w).Encode(map[string]any{ + "choices": []map[string]any{ + {"message": map[string]any{"role": "assistant", "content": "0.85"}}, + }, + }) + })) + defer srv.Close() + + client := llm.NewOpenAIClient(srv.URL, "test-model") + post := domain.Post{Title: "Go iterators", SelfText: "range over func patterns"} + interests := domain.Interests{Description: "Go programming"} + + score, err := client.Score(context.Background(), post, interests) + if err != nil { + t.Fatalf("Score: %v", err) + } + if score != 0.85 { + t.Errorf("score = %f, want 0.85", score) + } +} + +func TestOpenAISummarize(t *testing.T) { + want := "- point one\n- point two\n- point three\n- point four\n- point five" + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode(map[string]any{ + "choices": []map[string]any{ + {"message": map[string]any{"role": "assistant", "content": want}}, + }, + }) + })) + defer srv.Close() + + client := llm.NewOpenAIClient(srv.URL, "test-model") + got, err := client.Summarize(context.Background(), domain.Post{Title: "Test", SelfText: "Some content"}) + if err != nil { + t.Fatalf("Summarize: %v", err) + } + if got != want { + 
t.Errorf("summary = %q, want %q", got, want) + } +} + +func TestOpenAIScoreInvalidResponse(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode(map[string]any{ + "choices": []map[string]any{ + {"message": map[string]any{"role": "assistant", "content": "not a number"}}, + }, + }) + })) + defer srv.Close() + + client := llm.NewOpenAIClient(srv.URL, "test-model") + _, err := client.Score(context.Background(), domain.Post{Title: "Test"}, domain.Interests{}) + if err == nil { + t.Error("expected error for non-numeric score response") + } +} + +func TestOpenAIScorePromptIncludesFeedback(t *testing.T) { + var receivedBody map[string]any + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + json.NewDecoder(r.Body).Decode(&receivedBody) + json.NewEncoder(w).Encode(map[string]any{ + "choices": []map[string]any{ + {"message": map[string]any{"role": "assistant", "content": "0.5"}}, + }, + }) + })) + defer srv.Close() + + client := llm.NewOpenAIClient(srv.URL, "test-model") + interests := domain.Interests{ + Description: "Go", + Examples: []domain.Feedback{{PostID: "t3_good", Vote: 1}}, + } + client.Score(context.Background(), domain.Post{Title: "Test"}, interests) // error deliberately ignored: the test only inspects the captured request body + + msgs := receivedBody["messages"].([]any) + systemMsg := msgs[0].(map[string]any)["content"].(string) + if !strings.Contains(systemMsg, "t3_good") { + t.Error("system prompt should contain feedback post IDs") + } +}