Files
reddit-reader/internal/llm/mistral_test.go

58 lines
2.0 KiB
Go

package llm_test
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"somegit.dev/vikingowl/reddit-reader/internal/domain"
"somegit.dev/vikingowl/reddit-reader/internal/llm"
)
// TestMistralScore checks that Score returns the number parsed from the
// assistant message of a chat-completion response.
//
// The stub server answers every request with the same chat-completion
// payload, whose message content is "0.72". The incoming request is not
// inspected.
func TestMistralScore(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		// The handler runs on the server's goroutine, not the test
		// goroutine, so failures are reported with t.Errorf (t.Fatalf
		// must only be called from the goroutine running the test).
		err := json.NewEncoder(w).Encode(map[string]any{
			"id": "test", "object": "chat.completion", "model": "mistral-small-latest", "created": 1234567890,
			"choices": []map[string]any{
				{"index": 0, "finish_reason": "stop", "message": map[string]any{"role": "assistant", "content": "0.72"}},
			},
			"usage": map[string]any{"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15},
		})
		if err != nil {
			t.Errorf("encoding stub response: %v", err)
		}
	}))
	defer srv.Close()

	client := llm.NewMistralClient("test-key", "mistral-small-latest", llm.WithMistralBaseURL(srv.URL))
	score, err := client.Score(context.Background(), domain.Post{Title: "Test"}, domain.Interests{Description: "Go"})
	if err != nil {
		t.Fatalf("Score: %v", err)
	}
	if score != 0.72 {
		t.Errorf("score = %f, want 0.72", score)
	}
}
// TestMistralSummarize checks that Summarize returns the assistant message
// content of a chat-completion response unchanged.
//
// The stub server answers every request with a chat-completion payload
// whose message content is a fixed five-line bullet list. The incoming
// request is not inspected.
func TestMistralSummarize(t *testing.T) {
	want := "- bullet one\n- bullet two\n- bullet three\n- bullet four\n- bullet five"
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		// The handler runs on the server's goroutine, not the test
		// goroutine, so failures are reported with t.Errorf (t.Fatalf
		// must only be called from the goroutine running the test).
		err := json.NewEncoder(w).Encode(map[string]any{
			"id": "test", "object": "chat.completion", "model": "mistral-small-latest", "created": 1234567890,
			"choices": []map[string]any{
				{"index": 0, "finish_reason": "stop", "message": map[string]any{"role": "assistant", "content": want}},
			},
			"usage": map[string]any{"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15},
		})
		if err != nil {
			t.Errorf("encoding stub response: %v", err)
		}
	}))
	defer srv.Close()

	client := llm.NewMistralClient("test-key", "mistral-small-latest", llm.WithMistralBaseURL(srv.URL))
	got, err := client.Summarize(context.Background(), domain.Post{Title: "Test", SelfText: "content"})
	if err != nil {
		t.Fatalf("Summarize: %v", err)
	}
	if got != want {
		t.Errorf("summary = %q, want %q", got, want)
	}
}