package server import ( "context" "sync" "time" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "google.golang.org/protobuf/types/known/timestamppb" "somegit.dev/vikingowl/reddit-reader/internal/domain" "somegit.dev/vikingowl/reddit-reader/internal/store" pb "somegit.dev/vikingowl/reddit-reader/proto/redditreader" ) // Server implements the RedditReader gRPC service. type Server struct { pb.UnimplementedRedditReaderServer store *store.Store startedAt time.Time mu sync.RWMutex subscribers map[chan *pb.Post]struct{} } // Register creates a Server and registers it with the gRPC server. func Register(srv *grpc.Server, st *store.Store, startedAt time.Time) *Server { s := &Server{ store: st, startedAt: startedAt, subscribers: make(map[chan *pb.Post]struct{}), } pb.RegisterRedditReaderServer(srv, s) return s } // Notify pushes new posts to all connected stream subscribers. func (s *Server) Notify(posts []domain.Post) { s.mu.RLock() defer s.mu.RUnlock() for ch := range s.subscribers { for i := range posts { p := domainToProto(posts[i]) select { case ch <- p: default: // subscriber too slow, drop } } } } // StreamPosts adds a subscriber channel and sends posts as they arrive. func (s *Server) StreamPosts(_ *pb.StreamRequest, stream grpc.ServerStreamingServer[pb.Post]) error { ch := make(chan *pb.Post, 64) s.mu.Lock() s.subscribers[ch] = struct{}{} s.mu.Unlock() defer func() { s.mu.Lock() delete(s.subscribers, ch) s.mu.Unlock() }() for { select { case <-stream.Context().Done(): return stream.Context().Err() case post := <-ch: if err := stream.Send(post); err != nil { return err } } } } // ListPosts returns posts matching the filter criteria. func (s *Server) ListPosts(_ context.Context, req *pb.ListRequest) (*pb.ListResponse, error) { f := store.ListFilter{ Subreddit: req.GetSubreddit(), Limit: int(req.GetLimit()), } if req.Unread != nil { v := req.GetUnread() f.Unread = &v } if req.Starred != nil { v := req.GetStarred() f.Starred = &v } if req.Dismissed != nil { v := req.GetDismissed() f.Dismissed = &v } posts, err := s.store.ListPosts(f) if err != nil { return nil, status.Errorf(codes.Internal, "list posts: %v", err) } resp := &pb.ListResponse{Posts: make([]*pb.Post, len(posts))} for i := range posts { resp.Posts[i] = domainToProto(posts[i]) } return resp, nil } // UpdatePost updates flags on a post and returns the updated version. func (s *Server) UpdatePost(_ context.Context, req *pb.UpdateRequest) (*pb.Post, error) { if req.GetId() == "" { return nil, status.Error(codes.InvalidArgument, "id is required") } u := store.PostUpdate{} if req.Read != nil { v := req.GetRead() u.Read = &v } if req.Starred != nil { v := req.GetStarred() u.Starred = &v } if req.Dismissed != nil { v := req.GetDismissed() u.Dismissed = &v } if err := s.store.UpdatePost(req.GetId(), u); err != nil { return nil, status.Errorf(codes.Internal, "update post: %v", err) } post, err := s.store.GetPost(req.GetId()) if err != nil { return nil, status.Errorf(codes.NotFound, "post %q not found", req.GetId()) } return domainToProto(post), nil } // SubmitFeedback records a vote for a post. func (s *Server) SubmitFeedback(_ context.Context, req *pb.FeedbackRequest) (*pb.FeedbackResponse, error) { if err := s.store.AddFeedback(req.GetPostId(), int(req.GetVote())); err != nil { return nil, status.Errorf(codes.Internal, "add feedback: %v", err) } return &pb.FeedbackResponse{}, nil } // ListSubreddits returns all configured subreddits. func (s *Server) ListSubreddits(_ context.Context, _ *pb.Empty) (*pb.SubredditList, error) { subs, err := s.store.ListSubreddits() if err != nil { return nil, status.Errorf(codes.Internal, "list subreddits: %v", err) } resp := &pb.SubredditList{Subreddits: make([]*pb.SubredditMsg, len(subs))} for i, sub := range subs { resp.Subreddits[i] = &pb.SubredditMsg{ Name: sub.Name, Enabled: sub.Enabled, PollSort: sub.PollSort, } } return resp, nil } // AddSubreddit adds a subreddit and returns it. func (s *Server) AddSubreddit(_ context.Context, req *pb.AddSubredditRequest) (*pb.SubredditMsg, error) { sub := domain.Subreddit{ Name: req.GetName(), PollSort: req.GetPollSort(), } if err := s.store.AddSubreddit(sub); err != nil { return nil, status.Errorf(codes.Internal, "add subreddit: %v", err) } return &pb.SubredditMsg{ Name: sub.Name, Enabled: true, PollSort: sub.PollSort, }, nil } // RemoveSubreddit deletes a subreddit by name. func (s *Server) RemoveSubreddit(_ context.Context, req *pb.RemoveRequest) (*pb.Empty, error) { if err := s.store.RemoveSubreddit(req.GetName()); err != nil { return nil, status.Errorf(codes.Internal, "remove subreddit: %v", err) } return &pb.Empty{}, nil } // UpdateFilters adds the provided filters for a subreddit and returns all filters. func (s *Server) UpdateFilters(_ context.Context, req *pb.FilterRequest) (*pb.FilterResponse, error) { for _, f := range req.GetFilters() { _, err := s.store.AddFilter(domain.Filter{ Subreddit: req.GetSubreddit(), Pattern: f.GetPattern(), IsRegex: f.GetIsRegex(), }) if err != nil { return nil, status.Errorf(codes.Internal, "add filter: %v", err) } } filters, err := s.store.ListFilters(req.GetSubreddit()) if err != nil { return nil, status.Errorf(codes.Internal, "list filters: %v", err) } resp := &pb.FilterResponse{Filters: make([]*pb.FilterMsg, len(filters))} for i, f := range filters { resp.Filters[i] = &pb.FilterMsg{ Id: f.ID, Subreddit: f.Subreddit, Pattern: f.Pattern, IsRegex: f.IsRegex, } } return resp, nil } // Status returns uptime and post counts. func (s *Server) Status(_ context.Context, _ *pb.Empty) (*pb.StatusResponse, error) { allPosts, err := s.store.ListPosts(store.ListFilter{}) if err != nil { return nil, status.Errorf(codes.Internal, "count posts: %v", err) } unread := 0 for _, p := range allPosts { if !p.Read { unread++ } } return &pb.StatusResponse{ UptimeSeconds: int64(time.Since(s.startedAt).Seconds()), TotalPosts: int32(len(allPosts)), UnreadPosts: int32(unread), }, nil } func domainToProto(p domain.Post) *pb.Post { out := &pb.Post{ Id: p.ID, Subreddit: p.Subreddit, Title: p.Title, Author: p.Author, Url: p.URL, SelfText: p.SelfText, Score: int32(p.Score), CreatedUtc: timestamppb.New(p.CreatedUTC), FetchedAt: timestamppb.New(p.FetchedAt), Read: p.Read, Starred: p.Starred, Dismissed: p.Dismissed, } if p.Relevance != nil { out.Relevance = p.Relevance } if p.Summary != nil { out.Summary = p.Summary } return out }