104 lines
3.1 KiB
Go
104 lines
3.1 KiB
Go
package llm_test
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"testing"
|
|
|
|
"somegit.dev/vikingowl/reddit-reader/internal/domain"
|
|
"somegit.dev/vikingowl/reddit-reader/internal/llm"
|
|
)
|
|
|
|
func TestOpenAIScore(t *testing.T) {
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.URL.Path != "/v1/chat/completions" {
|
|
t.Errorf("unexpected path: %s", r.URL.Path)
|
|
}
|
|
json.NewEncoder(w).Encode(map[string]any{
|
|
"choices": []map[string]any{
|
|
{"message": map[string]any{"role": "assistant", "content": "0.85"}},
|
|
},
|
|
})
|
|
}))
|
|
defer srv.Close()
|
|
|
|
client := llm.NewOpenAIClient(srv.URL, "test-model")
|
|
post := domain.Post{Title: "Go iterators", SelfText: "range over func patterns"}
|
|
interests := domain.Interests{Description: "Go programming"}
|
|
|
|
score, err := client.Score(context.Background(), post, interests)
|
|
if err != nil {
|
|
t.Fatalf("Score: %v", err)
|
|
}
|
|
if score != 0.85 {
|
|
t.Errorf("score = %f, want 0.85", score)
|
|
}
|
|
}
|
|
|
|
func TestOpenAISummarize(t *testing.T) {
|
|
want := "- point one\n- point two\n- point three\n- point four\n- point five"
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
json.NewEncoder(w).Encode(map[string]any{
|
|
"choices": []map[string]any{
|
|
{"message": map[string]any{"role": "assistant", "content": want}},
|
|
},
|
|
})
|
|
}))
|
|
defer srv.Close()
|
|
|
|
client := llm.NewOpenAIClient(srv.URL, "test-model")
|
|
got, err := client.Summarize(context.Background(), domain.Post{Title: "Test", SelfText: "Some content"})
|
|
if err != nil {
|
|
t.Fatalf("Summarize: %v", err)
|
|
}
|
|
if got != want {
|
|
t.Errorf("summary = %q, want %q", got, want)
|
|
}
|
|
}
|
|
|
|
func TestOpenAIScoreInvalidResponse(t *testing.T) {
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
json.NewEncoder(w).Encode(map[string]any{
|
|
"choices": []map[string]any{
|
|
{"message": map[string]any{"role": "assistant", "content": "not a number"}},
|
|
},
|
|
})
|
|
}))
|
|
defer srv.Close()
|
|
|
|
client := llm.NewOpenAIClient(srv.URL, "test-model")
|
|
_, err := client.Score(context.Background(), domain.Post{Title: "Test"}, domain.Interests{})
|
|
if err == nil {
|
|
t.Error("expected error for non-numeric score response")
|
|
}
|
|
}
|
|
|
|
func TestOpenAIScorePromptIncludesFeedback(t *testing.T) {
|
|
var receivedBody map[string]any
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
json.NewDecoder(r.Body).Decode(&receivedBody)
|
|
json.NewEncoder(w).Encode(map[string]any{
|
|
"choices": []map[string]any{
|
|
{"message": map[string]any{"role": "assistant", "content": "0.5"}},
|
|
},
|
|
})
|
|
}))
|
|
defer srv.Close()
|
|
|
|
client := llm.NewOpenAIClient(srv.URL, "test-model")
|
|
interests := domain.Interests{
|
|
Description: "Go",
|
|
Examples: []domain.Feedback{{PostID: "t3_good", Vote: 1}},
|
|
}
|
|
client.Score(context.Background(), domain.Post{Title: "Test"}, interests)
|
|
|
|
msgs := receivedBody["messages"].([]any)
|
|
systemMsg := msgs[0].(map[string]any)["content"].(string)
|
|
if !strings.Contains(systemMsg, "t3_good") {
|
|
t.Error("system prompt should contain feedback post IDs")
|
|
}
|
|
}
|