247 lines
6.0 KiB
Go
247 lines
6.0 KiB
Go
package server_test
|
|
|
|
import (
|
|
"context"
|
|
"net"
|
|
"testing"
|
|
"time"
|
|
|
|
"google.golang.org/grpc"
|
|
"google.golang.org/grpc/credentials/insecure"
|
|
|
|
"somegit.dev/vikingowl/reddit-reader/internal/domain"
|
|
grpcserver "somegit.dev/vikingowl/reddit-reader/internal/grpc/server"
|
|
"somegit.dev/vikingowl/reddit-reader/internal/store"
|
|
pb "somegit.dev/vikingowl/reddit-reader/proto/redditreader"
|
|
)
|
|
|
|
func setupTestServer(t *testing.T) (pb.RedditReaderClient, *store.Store) {
|
|
t.Helper()
|
|
st, err := store.Open(":memory:")
|
|
if err != nil {
|
|
t.Fatalf("store.Open: %v", err)
|
|
}
|
|
t.Cleanup(func() { st.Close() })
|
|
|
|
lis, err := net.Listen("tcp", "localhost:0")
|
|
if err != nil {
|
|
t.Fatalf("net.Listen: %v", err)
|
|
}
|
|
srv := grpc.NewServer()
|
|
grpcserver.Register(srv, st, time.Now())
|
|
go srv.Serve(lis)
|
|
t.Cleanup(func() { srv.GracefulStop() })
|
|
|
|
conn, err := grpc.NewClient(lis.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials()))
|
|
if err != nil {
|
|
t.Fatalf("grpc.NewClient: %v", err)
|
|
}
|
|
t.Cleanup(func() { conn.Close() })
|
|
|
|
return pb.NewRedditReaderClient(conn), st
|
|
}
|
|
|
|
func TestListPostsEmpty(t *testing.T) {
|
|
client, _ := setupTestServer(t)
|
|
resp, err := client.ListPosts(context.Background(), &pb.ListRequest{})
|
|
if err != nil {
|
|
t.Fatalf("ListPosts: %v", err)
|
|
}
|
|
if len(resp.Posts) != 0 {
|
|
t.Errorf("expected 0 posts, got %d", len(resp.Posts))
|
|
}
|
|
}
|
|
|
|
func TestListPostsWithData(t *testing.T) {
|
|
client, st := setupTestServer(t)
|
|
rel := 0.8
|
|
if err := st.InsertPost(domain.Post{
|
|
ID: "t3_a",
|
|
Subreddit: "golang",
|
|
Title: "Test",
|
|
CreatedUTC: time.Now(),
|
|
Relevance: &rel,
|
|
}); err != nil {
|
|
t.Fatalf("InsertPost: %v", err)
|
|
}
|
|
|
|
resp, err := client.ListPosts(context.Background(), &pb.ListRequest{})
|
|
if err != nil {
|
|
t.Fatalf("ListPosts: %v", err)
|
|
}
|
|
if len(resp.Posts) != 1 {
|
|
t.Fatalf("expected 1 post, got %d", len(resp.Posts))
|
|
}
|
|
if resp.Posts[0].Title != "Test" {
|
|
t.Errorf("Title = %q, want Test", resp.Posts[0].Title)
|
|
}
|
|
}
|
|
|
|
func TestUpdatePost(t *testing.T) {
|
|
client, st := setupTestServer(t)
|
|
if err := st.InsertPost(domain.Post{
|
|
ID: "t3_a",
|
|
Subreddit: "test",
|
|
Title: "Test",
|
|
CreatedUTC: time.Now(),
|
|
}); err != nil {
|
|
t.Fatalf("InsertPost: %v", err)
|
|
}
|
|
|
|
starred := true
|
|
resp, err := client.UpdatePost(context.Background(), &pb.UpdateRequest{
|
|
Id: "t3_a",
|
|
Starred: &starred,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("UpdatePost: %v", err)
|
|
}
|
|
if !resp.Starred {
|
|
t.Error("expected starred")
|
|
}
|
|
}
|
|
|
|
func TestSubmitFeedback(t *testing.T) {
|
|
client, st := setupTestServer(t)
|
|
if err := st.InsertPost(domain.Post{
|
|
ID: "t3_a",
|
|
Subreddit: "test",
|
|
Title: "Test",
|
|
CreatedUTC: time.Now(),
|
|
}); err != nil {
|
|
t.Fatalf("InsertPost: %v", err)
|
|
}
|
|
|
|
_, err := client.SubmitFeedback(context.Background(), &pb.FeedbackRequest{
|
|
PostId: "t3_a",
|
|
Vote: 1,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("SubmitFeedback: %v", err)
|
|
}
|
|
|
|
fb, err := st.RecentFeedback(10)
|
|
if err != nil {
|
|
t.Fatalf("RecentFeedback: %v", err)
|
|
}
|
|
if len(fb) != 1 || fb[0].Vote != 1 {
|
|
t.Errorf("feedback = %v", fb)
|
|
}
|
|
}
|
|
|
|
func TestSubredditCRUD(t *testing.T) {
|
|
client, _ := setupTestServer(t)
|
|
|
|
_, err := client.AddSubreddit(context.Background(), &pb.AddSubredditRequest{
|
|
Name: "golang",
|
|
PollSort: "new",
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("AddSubreddit: %v", err)
|
|
}
|
|
|
|
list, err := client.ListSubreddits(context.Background(), &pb.Empty{})
|
|
if err != nil {
|
|
t.Fatalf("ListSubreddits: %v", err)
|
|
}
|
|
if len(list.Subreddits) != 1 {
|
|
t.Fatalf("expected 1 sub, got %d", len(list.Subreddits))
|
|
}
|
|
|
|
_, err = client.RemoveSubreddit(context.Background(), &pb.RemoveRequest{Name: "golang"})
|
|
if err != nil {
|
|
t.Fatalf("RemoveSubreddit: %v", err)
|
|
}
|
|
|
|
list, err = client.ListSubreddits(context.Background(), &pb.Empty{})
|
|
if err != nil {
|
|
t.Fatalf("ListSubreddits after remove: %v", err)
|
|
}
|
|
if len(list.Subreddits) != 0 {
|
|
t.Errorf("expected 0 after remove, got %d", len(list.Subreddits))
|
|
}
|
|
}
|
|
|
|
func TestStatus(t *testing.T) {
|
|
client, _ := setupTestServer(t)
|
|
resp, err := client.Status(context.Background(), &pb.Empty{})
|
|
if err != nil {
|
|
t.Fatalf("Status: %v", err)
|
|
}
|
|
if resp.UptimeSeconds < 0 {
|
|
t.Error("uptime should be >= 0")
|
|
}
|
|
}
|
|
|
|
func TestUpdateFilters(t *testing.T) {
|
|
client, st := setupTestServer(t)
|
|
|
|
// Need to add the subreddit first (foreign key constraint).
|
|
if err := st.AddSubreddit(domain.Subreddit{Name: "golang", PollSort: "new"}); err != nil {
|
|
t.Fatalf("AddSubreddit: %v", err)
|
|
}
|
|
|
|
resp, err := client.UpdateFilters(context.Background(), &pb.FilterRequest{
|
|
Subreddit: "golang",
|
|
Filters: []*pb.FilterMsg{
|
|
{Pattern: "hiring", IsRegex: false},
|
|
},
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("UpdateFilters: %v", err)
|
|
}
|
|
if len(resp.Filters) != 1 {
|
|
t.Fatalf("expected 1 filter, got %d", len(resp.Filters))
|
|
}
|
|
if resp.Filters[0].Pattern != "hiring" {
|
|
t.Errorf("Pattern = %q, want hiring", resp.Filters[0].Pattern)
|
|
}
|
|
}
|
|
|
|
func TestStreamNotify(t *testing.T) {
|
|
st, err := store.Open(":memory:")
|
|
if err != nil {
|
|
t.Fatalf("store.Open: %v", err)
|
|
}
|
|
t.Cleanup(func() { st.Close() })
|
|
|
|
lis, err := net.Listen("tcp", "localhost:0")
|
|
if err != nil {
|
|
t.Fatalf("net.Listen: %v", err)
|
|
}
|
|
srv := grpc.NewServer()
|
|
s := grpcserver.Register(srv, st, time.Now())
|
|
go srv.Serve(lis)
|
|
t.Cleanup(func() { srv.GracefulStop() })
|
|
|
|
conn, err := grpc.NewClient(lis.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials()))
|
|
if err != nil {
|
|
t.Fatalf("grpc.NewClient: %v", err)
|
|
}
|
|
t.Cleanup(func() { conn.Close() })
|
|
|
|
client := pb.NewRedditReaderClient(conn)
|
|
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
|
defer cancel()
|
|
|
|
stream, err := client.StreamPosts(ctx, &pb.StreamRequest{})
|
|
if err != nil {
|
|
t.Fatalf("StreamPosts: %v", err)
|
|
}
|
|
|
|
// Give the stream a moment to register the subscriber.
|
|
time.Sleep(100 * time.Millisecond)
|
|
|
|
s.Notify([]domain.Post{
|
|
{ID: "t3_stream", Subreddit: "test", Title: "Streamed", CreatedUTC: time.Now()},
|
|
})
|
|
|
|
post, err := stream.Recv()
|
|
if err != nil {
|
|
t.Fatalf("Recv: %v", err)
|
|
}
|
|
if post.Title != "Streamed" {
|
|
t.Errorf("Title = %q, want Streamed", post.Title)
|
|
}
|
|
}
|