266 lines
6.7 KiB
Go
266 lines
6.7 KiB
Go
package server
|
|
|
|
import (
	"context"
	"regexp"
	"sync"
	"time"

	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
	"google.golang.org/protobuf/types/known/timestamppb"

	"somegit.dev/vikingowl/reddit-reader/internal/domain"
	"somegit.dev/vikingowl/reddit-reader/internal/store"
	pb "somegit.dev/vikingowl/reddit-reader/proto/redditreader"
)
|
|
|
|
// Server implements the RedditReader gRPC service.
|
|
type Server struct {
|
|
pb.UnimplementedRedditReaderServer
|
|
store *store.Store
|
|
startedAt time.Time
|
|
mu sync.RWMutex
|
|
subscribers map[chan *pb.Post]struct{}
|
|
}
|
|
|
|
// Register creates a Server and registers it with the gRPC server.
|
|
func Register(srv *grpc.Server, st *store.Store, startedAt time.Time) *Server {
|
|
s := &Server{
|
|
store: st,
|
|
startedAt: startedAt,
|
|
subscribers: make(map[chan *pb.Post]struct{}),
|
|
}
|
|
pb.RegisterRedditReaderServer(srv, s)
|
|
return s
|
|
}
|
|
|
|
// Notify pushes new posts to all connected stream subscribers.
|
|
func (s *Server) Notify(posts []domain.Post) {
|
|
s.mu.RLock()
|
|
defer s.mu.RUnlock()
|
|
for ch := range s.subscribers {
|
|
for i := range posts {
|
|
p := domainToProto(posts[i])
|
|
select {
|
|
case ch <- p:
|
|
default:
|
|
// subscriber too slow, drop
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// StreamPosts adds a subscriber channel and sends posts as they arrive.
|
|
func (s *Server) StreamPosts(_ *pb.StreamRequest, stream grpc.ServerStreamingServer[pb.Post]) error {
|
|
ch := make(chan *pb.Post, 64)
|
|
|
|
s.mu.Lock()
|
|
s.subscribers[ch] = struct{}{}
|
|
s.mu.Unlock()
|
|
|
|
defer func() {
|
|
s.mu.Lock()
|
|
delete(s.subscribers, ch)
|
|
s.mu.Unlock()
|
|
}()
|
|
|
|
for {
|
|
select {
|
|
case <-stream.Context().Done():
|
|
return stream.Context().Err()
|
|
case post := <-ch:
|
|
if err := stream.Send(post); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// ListPosts returns posts matching the filter criteria.
|
|
func (s *Server) ListPosts(_ context.Context, req *pb.ListRequest) (*pb.ListResponse, error) {
|
|
f := store.ListFilter{
|
|
Subreddit: req.GetSubreddit(),
|
|
Limit: int(req.GetLimit()),
|
|
}
|
|
if req.Unread != nil {
|
|
v := req.GetUnread()
|
|
f.Unread = &v
|
|
}
|
|
if req.Starred != nil {
|
|
v := req.GetStarred()
|
|
f.Starred = &v
|
|
}
|
|
if req.Dismissed != nil {
|
|
v := req.GetDismissed()
|
|
f.Dismissed = &v
|
|
}
|
|
|
|
posts, err := s.store.ListPosts(f)
|
|
if err != nil {
|
|
return nil, status.Errorf(codes.Internal, "list posts: %v", err)
|
|
}
|
|
|
|
resp := &pb.ListResponse{Posts: make([]*pb.Post, len(posts))}
|
|
for i := range posts {
|
|
resp.Posts[i] = domainToProto(posts[i])
|
|
}
|
|
return resp, nil
|
|
}
|
|
|
|
// UpdatePost updates flags on a post and returns the updated version.
|
|
func (s *Server) UpdatePost(_ context.Context, req *pb.UpdateRequest) (*pb.Post, error) {
|
|
if req.GetId() == "" {
|
|
return nil, status.Error(codes.InvalidArgument, "id is required")
|
|
}
|
|
|
|
u := store.PostUpdate{}
|
|
if req.Read != nil {
|
|
v := req.GetRead()
|
|
u.Read = &v
|
|
}
|
|
if req.Starred != nil {
|
|
v := req.GetStarred()
|
|
u.Starred = &v
|
|
}
|
|
if req.Dismissed != nil {
|
|
v := req.GetDismissed()
|
|
u.Dismissed = &v
|
|
}
|
|
|
|
if err := s.store.UpdatePost(req.GetId(), u); err != nil {
|
|
return nil, status.Errorf(codes.Internal, "update post: %v", err)
|
|
}
|
|
|
|
post, err := s.store.GetPost(req.GetId())
|
|
if err != nil {
|
|
return nil, status.Errorf(codes.NotFound, "post %q not found", req.GetId())
|
|
}
|
|
return domainToProto(post), nil
|
|
}
|
|
|
|
// SubmitFeedback records a vote for a post.
|
|
func (s *Server) SubmitFeedback(_ context.Context, req *pb.FeedbackRequest) (*pb.FeedbackResponse, error) {
|
|
if err := s.store.AddFeedback(req.GetPostId(), int(req.GetVote())); err != nil {
|
|
return nil, status.Errorf(codes.Internal, "add feedback: %v", err)
|
|
}
|
|
return &pb.FeedbackResponse{}, nil
|
|
}
|
|
|
|
// ListSubreddits returns all configured subreddits.
|
|
func (s *Server) ListSubreddits(_ context.Context, _ *pb.Empty) (*pb.SubredditList, error) {
|
|
subs, err := s.store.ListSubreddits()
|
|
if err != nil {
|
|
return nil, status.Errorf(codes.Internal, "list subreddits: %v", err)
|
|
}
|
|
|
|
resp := &pb.SubredditList{Subreddits: make([]*pb.SubredditMsg, len(subs))}
|
|
for i, sub := range subs {
|
|
resp.Subreddits[i] = &pb.SubredditMsg{
|
|
Name: sub.Name,
|
|
Enabled: sub.Enabled,
|
|
PollSort: sub.PollSort,
|
|
}
|
|
}
|
|
return resp, nil
|
|
}
|
|
|
|
// AddSubreddit adds a subreddit and returns it.
|
|
func (s *Server) AddSubreddit(_ context.Context, req *pb.AddSubredditRequest) (*pb.SubredditMsg, error) {
|
|
sub := domain.Subreddit{
|
|
Name: req.GetName(),
|
|
PollSort: req.GetPollSort(),
|
|
}
|
|
if err := s.store.AddSubreddit(sub); err != nil {
|
|
return nil, status.Errorf(codes.Internal, "add subreddit: %v", err)
|
|
}
|
|
return &pb.SubredditMsg{
|
|
Name: sub.Name,
|
|
Enabled: true,
|
|
PollSort: sub.PollSort,
|
|
}, nil
|
|
}
|
|
|
|
// RemoveSubreddit deletes a subreddit by name.
|
|
func (s *Server) RemoveSubreddit(_ context.Context, req *pb.RemoveRequest) (*pb.Empty, error) {
|
|
if err := s.store.RemoveSubreddit(req.GetName()); err != nil {
|
|
return nil, status.Errorf(codes.Internal, "remove subreddit: %v", err)
|
|
}
|
|
return &pb.Empty{}, nil
|
|
}
|
|
|
|
// UpdateFilters adds the provided filters for a subreddit and returns all filters.
|
|
func (s *Server) UpdateFilters(_ context.Context, req *pb.FilterRequest) (*pb.FilterResponse, error) {
|
|
for _, f := range req.GetFilters() {
|
|
_, err := s.store.AddFilter(domain.Filter{
|
|
Subreddit: req.GetSubreddit(),
|
|
Pattern: f.GetPattern(),
|
|
IsRegex: f.GetIsRegex(),
|
|
})
|
|
if err != nil {
|
|
return nil, status.Errorf(codes.Internal, "add filter: %v", err)
|
|
}
|
|
}
|
|
|
|
filters, err := s.store.ListFilters(req.GetSubreddit())
|
|
if err != nil {
|
|
return nil, status.Errorf(codes.Internal, "list filters: %v", err)
|
|
}
|
|
|
|
resp := &pb.FilterResponse{Filters: make([]*pb.FilterMsg, len(filters))}
|
|
for i, f := range filters {
|
|
resp.Filters[i] = &pb.FilterMsg{
|
|
Id: f.ID,
|
|
Subreddit: f.Subreddit,
|
|
Pattern: f.Pattern,
|
|
IsRegex: f.IsRegex,
|
|
}
|
|
}
|
|
return resp, nil
|
|
}
|
|
|
|
// Status returns uptime and post counts.
|
|
func (s *Server) Status(_ context.Context, _ *pb.Empty) (*pb.StatusResponse, error) {
|
|
allPosts, err := s.store.ListPosts(store.ListFilter{})
|
|
if err != nil {
|
|
return nil, status.Errorf(codes.Internal, "count posts: %v", err)
|
|
}
|
|
|
|
unread := 0
|
|
for _, p := range allPosts {
|
|
if !p.Read {
|
|
unread++
|
|
}
|
|
}
|
|
|
|
return &pb.StatusResponse{
|
|
UptimeSeconds: int64(time.Since(s.startedAt).Seconds()),
|
|
TotalPosts: int32(len(allPosts)),
|
|
UnreadPosts: int32(unread),
|
|
}, nil
|
|
}
|
|
|
|
func domainToProto(p domain.Post) *pb.Post {
|
|
out := &pb.Post{
|
|
Id: p.ID,
|
|
Subreddit: p.Subreddit,
|
|
Title: p.Title,
|
|
Author: p.Author,
|
|
Url: p.URL,
|
|
SelfText: p.SelfText,
|
|
Score: int32(p.Score),
|
|
CreatedUtc: timestamppb.New(p.CreatedUTC),
|
|
FetchedAt: timestamppb.New(p.FetchedAt),
|
|
Read: p.Read,
|
|
Starred: p.Starred,
|
|
Dismissed: p.Dismissed,
|
|
}
|
|
if p.Relevance != nil {
|
|
out.Relevance = p.Relevance
|
|
}
|
|
if p.Summary != nil {
|
|
out.Summary = p.Summary
|
|
}
|
|
return out
|
|
}
|