Files

191 lines
4.7 KiB
Go

package client
import (
"context"
"fmt"
"io"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"somegit.dev/vikingowl/reddit-reader/internal/domain"
pb "somegit.dev/vikingowl/reddit-reader/proto/redditreader"
)
// Client talks to the RedditReader service over a single gRPC channel and
// converts the generated protobuf messages into domain types. Obtain one
// from Dial and release it with Close.
type Client struct {
	// conn is the underlying gRPC channel; Close shuts it down.
	conn *grpc.ClientConn
	// client is the generated RedditReader stub bound to conn.
	client pb.RedditReaderClient
}
// Dial creates a Client for the RedditReader service listening on the Unix
// domain socket at socketPath.
//
// grpc.NewClient performs no I/O, so Dial does not open the socket: a
// missing or unreachable server surfaces as an error from the first RPC,
// not from Dial. The target uses the "unix:path" form because it accepts
// both relative and absolute paths. The previous "unix://" form only works
// with absolute paths; a relative path is parsed as a URL authority, which
// the unix resolver rejects.
//
// An empty socketPath is rejected up front. The returned Client owns the
// connection and must be released with Close.
func Dial(socketPath string) (*Client, error) {
	if socketPath == "" {
		return nil, fmt.Errorf("grpc dial: empty socket path")
	}
	conn, err := grpc.NewClient(
		"unix:"+socketPath,
		// Local Unix socket; no TLS is configured on this channel.
		grpc.WithTransportCredentials(insecure.NewCredentials()),
	)
	if err != nil {
		return nil, fmt.Errorf("grpc dial: %w", err)
	}
	return &Client{
		conn:   conn,
		client: pb.NewRedditReaderClient(conn),
	}, nil
}
// Close shuts down the gRPC channel owned by c and returns whatever error
// the channel's own Close reports.
func (c *Client) Close() error {
	conn := c.conn
	return conn.Close()
}
// ListPosts asks the server for posts in subreddit, returning at most limit
// of them converted to domain.Post values. How the server interprets an
// empty subreddit or a zero/negative limit is not visible here — TODO
// confirm against the server implementation.
//
// The wire field is int32. limit is clamped to the int32 range instead of
// being converted with a bare int32(limit), which would silently wrap on
// 64-bit platforms (e.g. 1<<32 would become 0).
//
// On success the returned slice is non-nil, even when no posts match.
func (c *Client) ListPosts(ctx context.Context, subreddit string, limit int) ([]domain.Post, error) {
	const maxInt32, minInt32 = 1<<31 - 1, -1 << 31

	limit32 := int32(limit)
	switch {
	case limit > maxInt32:
		limit32 = maxInt32
	case limit < minInt32:
		limit32 = minInt32
	}

	resp, err := c.client.ListPosts(ctx, &pb.ListRequest{
		Subreddit: subreddit,
		Limit:     limit32,
	})
	if err != nil {
		return nil, fmt.Errorf("list posts: %w", err)
	}
	remote := resp.GetPosts()
	posts := make([]domain.Post, len(remote))
	for i, p := range remote {
		posts[i] = protoToDomain(p)
	}
	return posts, nil
}
// UpdatePost sends new read/starred/dismissed flags for the post identified
// by id and returns the post as the server reports it afterwards.
//
// Each flag is optional: a nil pointer leaves that field unset in the
// request, which presumably means "leave unchanged" on the server — TODO
// confirm. Non-nil pointers are placed into the request as-is.
func (c *Client) UpdatePost(ctx context.Context, id string, read, starred, dismissed *bool) (domain.Post, error) {
	updated, err := c.client.UpdatePost(ctx, &pb.UpdateRequest{
		Id:        id,
		Read:      read,
		Starred:   starred,
		Dismissed: dismissed,
	})
	if err != nil {
		return domain.Post{}, fmt.Errorf("update post: %w", err)
	}
	return protoToDomain(updated), nil
}
// SubmitFeedback records a vote for the post identified by postID. The
// meaning of particular vote values (e.g. sign) is defined by the server.
//
// The wire field is int32. A vote outside the int32 range is rejected with
// an error rather than truncated, because int32(vote) would wrap and could
// flip the vote's sign (1<<31 becomes -1<<31).
func (c *Client) SubmitFeedback(ctx context.Context, postID string, vote int) error {
	const maxInt32, minInt32 = 1<<31 - 1, -1 << 31

	if vote < minInt32 || vote > maxInt32 {
		return fmt.Errorf("submit feedback: vote %d out of int32 range", vote)
	}
	_, err := c.client.SubmitFeedback(ctx, &pb.FeedbackRequest{
		PostId: postID,
		Vote:   int32(vote),
	})
	if err != nil {
		return fmt.Errorf("submit feedback: %w", err)
	}
	return nil
}
// ListSubreddits fetches every subreddit configured on the server, mapping
// each entry's name, enabled flag and poll sort into a domain.Subreddit.
// On success the returned slice is non-nil, even when it is empty.
func (c *Client) ListSubreddits(ctx context.Context) ([]domain.Subreddit, error) {
	resp, err := c.client.ListSubreddits(ctx, &pb.Empty{})
	if err != nil {
		return nil, fmt.Errorf("list subreddits: %w", err)
	}
	remote := resp.GetSubreddits()
	out := make([]domain.Subreddit, 0, len(remote))
	for _, entry := range remote {
		out = append(out, domain.Subreddit{
			Name:     entry.GetName(),
			Enabled:  entry.GetEnabled(),
			PollSort: entry.GetPollSort(),
		})
	}
	return out, nil
}
// AddSubreddit asks the server to add the subreddit called name. sort is
// passed through verbatim as the request's PollSort; valid values are
// defined server-side.
func (c *Client) AddSubreddit(ctx context.Context, name, sort string) error {
	req := &pb.AddSubredditRequest{Name: name, PollSort: sort}
	if _, err := c.client.AddSubreddit(ctx, req); err != nil {
		return fmt.Errorf("add subreddit: %w", err)
	}
	return nil
}
// RemoveSubreddit asks the server to delete the subreddit called name.
func (c *Client) RemoveSubreddit(ctx context.Context, name string) error {
	req := &pb.RemoveRequest{Name: name}
	if _, err := c.client.RemoveSubreddit(ctx, req); err != nil {
		return fmt.Errorf("remove subreddit: %w", err)
	}
	return nil
}
// StreamPosts opens the server-streaming StreamPosts RPC and forwards every
// received post, converted to a domain.Post, on the returned channel (which
// buffers up to 64 posts).
//
// A background goroutine closes the channel when any of these happens:
//   - the server ends the stream (Recv returns io.EOF);
//   - Recv fails for any other reason;
//   - ctx is done while a send is pending.
//
// Receive errors are not reported; a closed channel is the only signal. The
// goroutine blocks whenever the buffer is full, so callers must keep
// draining the channel or cancel ctx for it to exit.
func (c *Client) StreamPosts(ctx context.Context) (<-chan domain.Post, error) {
	stream, err := c.client.StreamPosts(ctx, &pb.StreamRequest{})
	if err != nil {
		return nil, fmt.Errorf("stream posts: %w", err)
	}

	out := make(chan domain.Post, 64)
	go func() {
		defer close(out)
		for {
			msg, recvErr := stream.Recv()
			switch {
			case recvErr == io.EOF:
				// Server finished the stream normally.
				return
			case recvErr != nil:
				// Transport failure or ctx cancellation; dropped silently.
				return
			}
			select {
			case out <- protoToDomain(msg):
			case <-ctx.Done():
				return
			}
		}
	}()
	return out, nil
}
// Status fetches the server's status report. The generated protobuf message
// is returned as-is, without conversion to a domain type.
func (c *Client) Status(ctx context.Context) (*pb.StatusResponse, error) {
	report, callErr := c.client.Status(ctx, &pb.Empty{})
	if callErr != nil {
		return nil, fmt.Errorf("status: %w", callErr)
	}
	return report, nil
}
// protoToDomain converts a generated pb.Post into a domain.Post.
//
// A nil p yields the zero domain.Post. Generated getters are nil-safe, but
// the presence checks below read the p.Relevance / p.Summary fields
// directly, and those reads would panic on a nil message.
//
// Timestamps absent from the message leave the corresponding time.Time at
// its zero value. The explicit nil checks matter because AsTime on a nil
// Timestamp returns the Unix epoch rather than the zero time.
//
// The optional Relevance and Summary fields stay nil when unset. When set,
// their values are copied so the result does not alias the message's memory.
func protoToDomain(p *pb.Post) domain.Post {
	if p == nil {
		return domain.Post{}
	}
	post := domain.Post{
		ID:        p.GetId(),
		Subreddit: p.GetSubreddit(),
		Title:     p.GetTitle(),
		Author:    p.GetAuthor(),
		URL:       p.GetUrl(),
		SelfText:  p.GetSelfText(),
		Score:     int(p.GetScore()),
		Read:      p.GetRead(),
		Starred:   p.GetStarred(),
		Dismissed: p.GetDismissed(),
	}
	if ts := p.GetCreatedUtc(); ts != nil {
		post.CreatedUTC = ts.AsTime()
	}
	if ts := p.GetFetchedAt(); ts != nil {
		post.FetchedAt = ts.AsTime()
	}
	// Presence of optional fields is signalled by a non-nil pointer field.
	if p.Relevance != nil {
		v := p.GetRelevance()
		post.Relevance = &v
	}
	if p.Summary != nil {
		v := p.GetSummary()
		post.Summary = &v
	}
	return post
}