Files

266 lines
6.7 KiB
Go

package server
import (
"context"
"sync"
"time"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/timestamppb"
"somegit.dev/vikingowl/reddit-reader/internal/domain"
"somegit.dev/vikingowl/reddit-reader/internal/store"
pb "somegit.dev/vikingowl/reddit-reader/proto/redditreader"
)
// Server implements the RedditReader gRPC service.
//
// Unary RPCs read from and write to the backing store; newly fetched posts
// are pushed to StreamPosts clients through Notify.
type Server struct {
	pb.UnimplementedRedditReaderServer

	// store is the persistence layer every RPC delegates to.
	store *store.Store
	// startedAt is the reference instant for the uptime reported by Status.
	startedAt time.Time

	// mu guards subscribers: StreamPosts takes the write lock to add/remove
	// its channel, Notify takes the read lock while fanning out.
	mu sync.RWMutex
	// subscribers is the set of per-stream delivery channels, used as a set
	// (zero-width values).
	subscribers map[chan *pb.Post]struct{}
}
// Register builds a Server backed by st, remembers startedAt as the origin
// for uptime reporting, and installs it as the RedditReader implementation on
// srv. The returned *Server is the same instance gRPC dispatches to, so callers
// can use it to push posts to streaming clients via Notify.
func Register(srv *grpc.Server, st *store.Store, startedAt time.Time) *Server {
	var s Server
	s.store = st
	s.startedAt = startedAt
	s.subscribers = map[chan *pb.Post]struct{}{}

	pb.RegisterRedditReaderServer(srv, &s)
	return &s
}
// Notify pushes new posts to all connected stream subscribers.
//
// Delivery is best-effort and never blocks the caller: each subscriber has a
// bounded channel, and a post is dropped for any subscriber whose channel is
// full at send time.
//
// Each post is converted to its wire form once per call, and the resulting
// messages are shared read-only by all subscribers. Previously it was
// converted once per subscriber. The messages must not be mutated after this
// point.
func (s *Server) Notify(posts []domain.Post) {
	if len(posts) == 0 {
		return
	}

	s.mu.RLock()
	defer s.mu.RUnlock()

	// Nobody is listening: skip the conversion work entirely.
	if len(s.subscribers) == 0 {
		return
	}

	msgs := make([]*pb.Post, len(posts))
	for i := range posts {
		msgs[i] = domainToProto(posts[i])
	}

	for ch := range s.subscribers {
		for _, p := range msgs {
			select {
			case ch <- p:
			default:
				// Subscriber too slow; drop rather than block the notifier.
			}
		}
	}
}
// StreamPosts registers a buffered subscriber channel (capacity 64) and
// forwards every post Notify places on it to the client. It returns when the
// stream's context is done (yielding the context's error) or when a Send fails.
// The channel is always unregistered before returning.
func (s *Server) StreamPosts(_ *pb.StreamRequest, stream grpc.ServerStreamingServer[pb.Post]) error {
	inbox := make(chan *pb.Post, 64)

	s.mu.Lock()
	s.subscribers[inbox] = struct{}{}
	s.mu.Unlock()

	defer func() {
		s.mu.Lock()
		defer s.mu.Unlock()
		delete(s.subscribers, inbox)
	}()

	ctx := stream.Context()
	for {
		select {
		case post := <-inbox:
			if err := stream.Send(post); err != nil {
				return err
			}
		case <-ctx.Done():
			return ctx.Err()
		}
	}
}
// ListPosts returns posts matching the filter criteria.
//
// Subreddit and limit are copied straight through. The unread/starred/
// dismissed flags are tri-state: each is applied to the store query only
// when the client explicitly set it. Store failures are reported as Internal.
func (s *Server) ListPosts(_ context.Context, req *pb.ListRequest) (*pb.ListResponse, error) {
	boolPtr := func(b bool) *bool { return &b }

	filter := store.ListFilter{
		Subreddit: req.GetSubreddit(),
		Limit:     int(req.GetLimit()),
	}
	if req.Unread != nil {
		filter.Unread = boolPtr(req.GetUnread())
	}
	if req.Starred != nil {
		filter.Starred = boolPtr(req.GetStarred())
	}
	if req.Dismissed != nil {
		filter.Dismissed = boolPtr(req.GetDismissed())
	}

	found, err := s.store.ListPosts(filter)
	if err != nil {
		return nil, status.Errorf(codes.Internal, "list posts: %v", err)
	}

	out := make([]*pb.Post, 0, len(found))
	for i := range found {
		out = append(out, domainToProto(found[i]))
	}
	return &pb.ListResponse{Posts: out}, nil
}
// UpdatePost updates flags on a post and returns the updated version.
//
// The id is required (InvalidArgument otherwise). Only the read/starred/
// dismissed flags the client explicitly set are included in the update. A
// store write failure is reported as Internal. The post is re-read after the
// write so the response reflects the stored state.
func (s *Server) UpdatePost(_ context.Context, req *pb.UpdateRequest) (*pb.Post, error) {
	id := req.GetId()
	if id == "" {
		return nil, status.Error(codes.InvalidArgument, "id is required")
	}

	boolPtr := func(b bool) *bool { return &b }

	var patch store.PostUpdate
	if req.Read != nil {
		patch.Read = boolPtr(req.GetRead())
	}
	if req.Starred != nil {
		patch.Starred = boolPtr(req.GetStarred())
	}
	if req.Dismissed != nil {
		patch.Dismissed = boolPtr(req.GetDismissed())
	}

	if err := s.store.UpdatePost(id, patch); err != nil {
		return nil, status.Errorf(codes.Internal, "update post: %v", err)
	}

	// NOTE(review): every GetPost error is reported as NotFound, including
	// genuine storage failures; confirm against the store's error contract.
	updated, err := s.store.GetPost(id)
	if err != nil {
		return nil, status.Errorf(codes.NotFound, "post %q not found", id)
	}
	return domainToProto(updated), nil
}
// SubmitFeedback records a vote for a post.
//
// post_id is required: an empty id is rejected with InvalidArgument instead
// of being passed to the store. This matches how UpdatePost validates its id.
// The vote value is forwarded unchanged. Store failures are reported as
// Internal.
func (s *Server) SubmitFeedback(_ context.Context, req *pb.FeedbackRequest) (*pb.FeedbackResponse, error) {
	postID := req.GetPostId()
	if postID == "" {
		return nil, status.Error(codes.InvalidArgument, "post_id is required")
	}
	if err := s.store.AddFeedback(postID, int(req.GetVote())); err != nil {
		return nil, status.Errorf(codes.Internal, "add feedback: %v", err)
	}
	return &pb.FeedbackResponse{}, nil
}
// ListSubreddits returns all configured subreddits, mapping each stored entry
// to its wire message (name, enabled flag, poll sort). Store failures are
// reported as Internal.
func (s *Server) ListSubreddits(_ context.Context, _ *pb.Empty) (*pb.SubredditList, error) {
	configured, err := s.store.ListSubreddits()
	if err != nil {
		return nil, status.Errorf(codes.Internal, "list subreddits: %v", err)
	}

	msgs := make([]*pb.SubredditMsg, 0, len(configured))
	for _, sr := range configured {
		msgs = append(msgs, &pb.SubredditMsg{
			Name:     sr.Name,
			Enabled:  sr.Enabled,
			PollSort: sr.PollSort,
		})
	}
	return &pb.SubredditList{Subreddits: msgs}, nil
}
// AddSubreddit adds a subreddit in the enabled state and returns it.
//
// The name is required; an empty name is rejected with InvalidArgument.
// The subreddit handed to the store carries Enabled: true, and the response
// echoes exactly that value. Previously the response claimed enabled while
// the stored value said disabled. Store failures are reported as Internal.
func (s *Server) AddSubreddit(_ context.Context, req *pb.AddSubredditRequest) (*pb.SubredditMsg, error) {
	if req.GetName() == "" {
		return nil, status.Error(codes.InvalidArgument, "name is required")
	}

	sub := domain.Subreddit{
		Name:     req.GetName(),
		Enabled:  true,
		PollSort: req.GetPollSort(),
	}
	if err := s.store.AddSubreddit(sub); err != nil {
		return nil, status.Errorf(codes.Internal, "add subreddit: %v", err)
	}
	return &pb.SubredditMsg{
		Name:     sub.Name,
		Enabled:  sub.Enabled,
		PollSort: sub.PollSort,
	}, nil
}
// RemoveSubreddit deletes the subreddit with the requested name. Store
// failures are reported as Internal. On success it returns an empty message.
func (s *Server) RemoveSubreddit(_ context.Context, req *pb.RemoveRequest) (*pb.Empty, error) {
	err := s.store.RemoveSubreddit(req.GetName())
	if err != nil {
		return nil, status.Errorf(codes.Internal, "remove subreddit: %v", err)
	}
	return new(pb.Empty), nil
}
// UpdateFilters adds the provided filters for a subreddit and returns all filters.
//
// Every incoming filter is validated before any is written. A request
// containing an empty pattern is therefore rejected with InvalidArgument
// without persisting any of its filters. Store failures are reported as
// Internal. Inserts are issued one at a time, so a store error part-way
// through can still leave earlier filters from the same request persisted.
// The response lists every filter the store holds for the subreddit after
// the inserts.
func (s *Server) UpdateFilters(_ context.Context, req *pb.FilterRequest) (*pb.FilterResponse, error) {
	subreddit := req.GetSubreddit()
	incoming := req.GetFilters()

	// Presumably an empty pattern would match every post (an empty substring
	// or empty regex matches any text), which is never a useful filter.
	for i, f := range incoming {
		if f.GetPattern() == "" {
			return nil, status.Errorf(codes.InvalidArgument, "filter %d: pattern is required", i)
		}
	}
	// NOTE(review): regex syntax of IsRegex patterns is not validated here;
	// an invalid expression is stored as-is.

	for _, f := range incoming {
		_, err := s.store.AddFilter(domain.Filter{
			Subreddit: subreddit,
			Pattern:   f.GetPattern(),
			IsRegex:   f.GetIsRegex(),
		})
		if err != nil {
			return nil, status.Errorf(codes.Internal, "add filter: %v", err)
		}
	}

	filters, err := s.store.ListFilters(subreddit)
	if err != nil {
		return nil, status.Errorf(codes.Internal, "list filters: %v", err)
	}
	resp := &pb.FilterResponse{Filters: make([]*pb.FilterMsg, len(filters))}
	for i, f := range filters {
		resp.Filters[i] = &pb.FilterMsg{
			Id:        f.ID,
			Subreddit: f.Subreddit,
			Pattern:   f.Pattern,
			IsRegex:   f.IsRegex,
		}
	}
	return resp, nil
}
// Status reports server uptime (whole seconds since startedAt) together with
// the total number of posts and how many of them are unread. Store failures
// are reported as Internal.
//
// NOTE(review): this loads every post via ListPosts with a zero-value filter
// just to count them. Whether the store applies a default limit to Limit 0 is
// not visible here; confirm, and consider a dedicated count query.
func (s *Server) Status(_ context.Context, _ *pb.Empty) (*pb.StatusResponse, error) {
	posts, err := s.store.ListPosts(store.ListFilter{})
	if err != nil {
		return nil, status.Errorf(codes.Internal, "count posts: %v", err)
	}

	var unread int
	for i := range posts {
		if posts[i].Read {
			continue
		}
		unread++
	}

	uptime := time.Since(s.startedAt)
	return &pb.StatusResponse{
		UptimeSeconds: int64(uptime.Seconds()),
		TotalPosts:    int32(len(posts)),
		UnreadPosts:   int32(unread),
	}, nil
}
// domainToProto converts a domain.Post into its wire representation.
//
// Timestamps go through timestamppb.New unconditionally. A zero time.Time
// therefore becomes a concrete timestamp rather than an unset field. Score is
// narrowed to int32. The optional Relevance and Summary pointers are passed
// through unchanged, so nil stays unset.
func domainToProto(p domain.Post) *pb.Post {
	return &pb.Post{
		Id:         p.ID,
		Subreddit:  p.Subreddit,
		Title:      p.Title,
		Author:     p.Author,
		Url:        p.URL,
		SelfText:   p.SelfText,
		Score:      int32(p.Score),
		CreatedUtc: timestamppb.New(p.CreatedUTC),
		FetchedAt:  timestamppb.New(p.FetchedAt),
		Read:       p.Read,
		Starred:    p.Starred,
		Dismissed:  p.Dismissed,
		Relevance:  p.Relevance,
		Summary:    p.Summary,
	}
}