58 lines
2.0 KiB
Go
58 lines
2.0 KiB
Go
package llm_test
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
|
|
"somegit.dev/vikingowl/reddit-reader/internal/domain"
|
|
"somegit.dev/vikingowl/reddit-reader/internal/llm"
|
|
)
|
|
|
|
func TestMistralScore(t *testing.T) {
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
json.NewEncoder(w).Encode(map[string]any{
|
|
"id": "test", "object": "chat.completion", "model": "mistral-small-latest", "created": 1234567890,
|
|
"choices": []map[string]any{
|
|
{"index": 0, "finish_reason": "stop", "message": map[string]any{"role": "assistant", "content": "0.72"}},
|
|
},
|
|
"usage": map[string]any{"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15},
|
|
})
|
|
}))
|
|
defer srv.Close()
|
|
|
|
client := llm.NewMistralClient("test-key", "mistral-small-latest", llm.WithMistralBaseURL(srv.URL))
|
|
score, err := client.Score(context.Background(), domain.Post{Title: "Test"}, domain.Interests{Description: "Go"})
|
|
if err != nil {
|
|
t.Fatalf("Score: %v", err)
|
|
}
|
|
if score != 0.72 {
|
|
t.Errorf("score = %f, want 0.72", score)
|
|
}
|
|
}
|
|
|
|
func TestMistralSummarize(t *testing.T) {
|
|
want := "- bullet one\n- bullet two\n- bullet three\n- bullet four\n- bullet five"
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
json.NewEncoder(w).Encode(map[string]any{
|
|
"id": "test", "object": "chat.completion", "model": "mistral-small-latest", "created": 1234567890,
|
|
"choices": []map[string]any{
|
|
{"index": 0, "finish_reason": "stop", "message": map[string]any{"role": "assistant", "content": want}},
|
|
},
|
|
"usage": map[string]any{"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15},
|
|
})
|
|
}))
|
|
defer srv.Close()
|
|
|
|
client := llm.NewMistralClient("test-key", "mistral-small-latest", llm.WithMistralBaseURL(srv.URL))
|
|
got, err := client.Summarize(context.Background(), domain.Post{Title: "Test", SelfText: "content"})
|
|
if err != nil {
|
|
t.Fatalf("Summarize: %v", err)
|
|
}
|
|
if got != want {
|
|
t.Errorf("summary = %q, want %q", got, want)
|
|
}
|
|
}
|