191 lines
4.7 KiB
Go
191 lines
4.7 KiB
Go
package client
|
|
|
|
import (
	"context"
	"fmt"
	"io"
	"math"

	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"
	"somegit.dev/vikingowl/reddit-reader/internal/domain"
	pb "somegit.dev/vikingowl/reddit-reader/proto/redditreader"
)
|
|
|
|
// Client wraps a gRPC connection to the RedditReader service.
|
|
type Client struct {
|
|
conn *grpc.ClientConn
|
|
client pb.RedditReaderClient
|
|
}
|
|
|
|
// Dial connects to the gRPC server via a Unix socket.
|
|
func Dial(socketPath string) (*Client, error) {
|
|
conn, err := grpc.NewClient(
|
|
"unix://"+socketPath,
|
|
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
|
)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("grpc dial: %w", err)
|
|
}
|
|
return &Client{
|
|
conn: conn,
|
|
client: pb.NewRedditReaderClient(conn),
|
|
}, nil
|
|
}
|
|
|
|
// Close closes the underlying gRPC connection.
|
|
func (c *Client) Close() error {
|
|
return c.conn.Close()
|
|
}
|
|
|
|
// ListPosts retrieves posts matching the given subreddit and limit.
|
|
func (c *Client) ListPosts(ctx context.Context, subreddit string, limit int) ([]domain.Post, error) {
|
|
resp, err := c.client.ListPosts(ctx, &pb.ListRequest{
|
|
Subreddit: subreddit,
|
|
Limit: int32(limit),
|
|
})
|
|
if err != nil {
|
|
return nil, fmt.Errorf("list posts: %w", err)
|
|
}
|
|
posts := make([]domain.Post, len(resp.GetPosts()))
|
|
for i, p := range resp.GetPosts() {
|
|
posts[i] = protoToDomain(p)
|
|
}
|
|
return posts, nil
|
|
}
|
|
|
|
// UpdatePost updates flags on a post and returns the updated version.
|
|
func (c *Client) UpdatePost(ctx context.Context, id string, read, starred, dismissed *bool) (domain.Post, error) {
|
|
req := &pb.UpdateRequest{Id: id}
|
|
if read != nil {
|
|
req.Read = read
|
|
}
|
|
if starred != nil {
|
|
req.Starred = starred
|
|
}
|
|
if dismissed != nil {
|
|
req.Dismissed = dismissed
|
|
}
|
|
resp, err := c.client.UpdatePost(ctx, req)
|
|
if err != nil {
|
|
return domain.Post{}, fmt.Errorf("update post: %w", err)
|
|
}
|
|
return protoToDomain(resp), nil
|
|
}
|
|
|
|
// SubmitFeedback records a vote for a post.
|
|
func (c *Client) SubmitFeedback(ctx context.Context, postID string, vote int) error {
|
|
_, err := c.client.SubmitFeedback(ctx, &pb.FeedbackRequest{
|
|
PostId: postID,
|
|
Vote: int32(vote),
|
|
})
|
|
if err != nil {
|
|
return fmt.Errorf("submit feedback: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// ListSubreddits returns all configured subreddits.
|
|
func (c *Client) ListSubreddits(ctx context.Context) ([]domain.Subreddit, error) {
|
|
resp, err := c.client.ListSubreddits(ctx, &pb.Empty{})
|
|
if err != nil {
|
|
return nil, fmt.Errorf("list subreddits: %w", err)
|
|
}
|
|
subs := make([]domain.Subreddit, len(resp.GetSubreddits()))
|
|
for i, s := range resp.GetSubreddits() {
|
|
subs[i] = domain.Subreddit{
|
|
Name: s.GetName(),
|
|
Enabled: s.GetEnabled(),
|
|
PollSort: s.GetPollSort(),
|
|
}
|
|
}
|
|
return subs, nil
|
|
}
|
|
|
|
// AddSubreddit adds a subreddit with the given sort order.
|
|
func (c *Client) AddSubreddit(ctx context.Context, name, sort string) error {
|
|
_, err := c.client.AddSubreddit(ctx, &pb.AddSubredditRequest{
|
|
Name: name,
|
|
PollSort: sort,
|
|
})
|
|
if err != nil {
|
|
return fmt.Errorf("add subreddit: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// RemoveSubreddit deletes a subreddit by name.
|
|
func (c *Client) RemoveSubreddit(ctx context.Context, name string) error {
|
|
_, err := c.client.RemoveSubreddit(ctx, &pb.RemoveRequest{Name: name})
|
|
if err != nil {
|
|
return fmt.Errorf("remove subreddit: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// StreamPosts returns a channel that receives posts as they are pushed by the server.
|
|
func (c *Client) StreamPosts(ctx context.Context) (<-chan domain.Post, error) {
|
|
stream, err := c.client.StreamPosts(ctx, &pb.StreamRequest{})
|
|
if err != nil {
|
|
return nil, fmt.Errorf("stream posts: %w", err)
|
|
}
|
|
ch := make(chan domain.Post, 64)
|
|
go func() {
|
|
defer close(ch)
|
|
for {
|
|
p, err := stream.Recv()
|
|
if err == io.EOF {
|
|
return
|
|
}
|
|
if err != nil {
|
|
return
|
|
}
|
|
select {
|
|
case ch <- protoToDomain(p):
|
|
case <-ctx.Done():
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
return ch, nil
|
|
}
|
|
|
|
// Status returns the server's status information.
|
|
func (c *Client) Status(ctx context.Context) (*pb.StatusResponse, error) {
|
|
resp, err := c.client.Status(ctx, &pb.Empty{})
|
|
if err != nil {
|
|
return nil, fmt.Errorf("status: %w", err)
|
|
}
|
|
return resp, nil
|
|
}
|
|
|
|
func protoToDomain(p *pb.Post) domain.Post {
|
|
post := domain.Post{
|
|
ID: p.GetId(),
|
|
Subreddit: p.GetSubreddit(),
|
|
Title: p.GetTitle(),
|
|
Author: p.GetAuthor(),
|
|
URL: p.GetUrl(),
|
|
SelfText: p.GetSelfText(),
|
|
Score: int(p.GetScore()),
|
|
Read: p.GetRead(),
|
|
Starred: p.GetStarred(),
|
|
Dismissed: p.GetDismissed(),
|
|
}
|
|
if ts := p.GetCreatedUtc(); ts != nil {
|
|
post.CreatedUTC = ts.AsTime()
|
|
}
|
|
if ts := p.GetFetchedAt(); ts != nil {
|
|
post.FetchedAt = ts.AsTime()
|
|
}
|
|
if p.Relevance != nil {
|
|
v := p.GetRelevance()
|
|
post.Relevance = &v
|
|
}
|
|
if p.Summary != nil {
|
|
v := p.GetSummary()
|
|
post.Summary = &v
|
|
}
|
|
return post
|
|
}
|
|
|