package client import ( "context" "fmt" "io" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" "somegit.dev/vikingowl/reddit-reader/internal/domain" pb "somegit.dev/vikingowl/reddit-reader/proto/redditreader" ) // Client wraps a gRPC connection to the RedditReader service. type Client struct { conn *grpc.ClientConn client pb.RedditReaderClient } // Dial connects to the gRPC server via a Unix socket. func Dial(socketPath string) (*Client, error) { conn, err := grpc.NewClient( "unix://"+socketPath, grpc.WithTransportCredentials(insecure.NewCredentials()), ) if err != nil { return nil, fmt.Errorf("grpc dial: %w", err) } return &Client{ conn: conn, client: pb.NewRedditReaderClient(conn), }, nil } // Close closes the underlying gRPC connection. func (c *Client) Close() error { return c.conn.Close() } // ListPosts retrieves posts matching the given subreddit and limit. func (c *Client) ListPosts(ctx context.Context, subreddit string, limit int) ([]domain.Post, error) { resp, err := c.client.ListPosts(ctx, &pb.ListRequest{ Subreddit: subreddit, Limit: int32(limit), }) if err != nil { return nil, fmt.Errorf("list posts: %w", err) } posts := make([]domain.Post, len(resp.GetPosts())) for i, p := range resp.GetPosts() { posts[i] = protoToDomain(p) } return posts, nil } // UpdatePost updates flags on a post and returns the updated version. func (c *Client) UpdatePost(ctx context.Context, id string, read, starred, dismissed *bool) (domain.Post, error) { req := &pb.UpdateRequest{Id: id} if read != nil { req.Read = read } if starred != nil { req.Starred = starred } if dismissed != nil { req.Dismissed = dismissed } resp, err := c.client.UpdatePost(ctx, req) if err != nil { return domain.Post{}, fmt.Errorf("update post: %w", err) } return protoToDomain(resp), nil } // SubmitFeedback records a vote for a post. func (c *Client) SubmitFeedback(ctx context.Context, postID string, vote int) error { _, err := c.client.SubmitFeedback(ctx, &pb.FeedbackRequest{ PostId: postID, Vote: int32(vote), }) if err != nil { return fmt.Errorf("submit feedback: %w", err) } return nil } // ListSubreddits returns all configured subreddits. func (c *Client) ListSubreddits(ctx context.Context) ([]domain.Subreddit, error) { resp, err := c.client.ListSubreddits(ctx, &pb.Empty{}) if err != nil { return nil, fmt.Errorf("list subreddits: %w", err) } subs := make([]domain.Subreddit, len(resp.GetSubreddits())) for i, s := range resp.GetSubreddits() { subs[i] = domain.Subreddit{ Name: s.GetName(), Enabled: s.GetEnabled(), PollSort: s.GetPollSort(), } } return subs, nil } // AddSubreddit adds a subreddit with the given sort order. func (c *Client) AddSubreddit(ctx context.Context, name, sort string) error { _, err := c.client.AddSubreddit(ctx, &pb.AddSubredditRequest{ Name: name, PollSort: sort, }) if err != nil { return fmt.Errorf("add subreddit: %w", err) } return nil } // RemoveSubreddit deletes a subreddit by name. func (c *Client) RemoveSubreddit(ctx context.Context, name string) error { _, err := c.client.RemoveSubreddit(ctx, &pb.RemoveRequest{Name: name}) if err != nil { return fmt.Errorf("remove subreddit: %w", err) } return nil } // StreamPosts returns a channel that receives posts as they are pushed by the server. func (c *Client) StreamPosts(ctx context.Context) (<-chan domain.Post, error) { stream, err := c.client.StreamPosts(ctx, &pb.StreamRequest{}) if err != nil { return nil, fmt.Errorf("stream posts: %w", err) } ch := make(chan domain.Post, 64) go func() { defer close(ch) for { p, err := stream.Recv() if err == io.EOF { return } if err != nil { return } select { case ch <- protoToDomain(p): case <-ctx.Done(): return } } }() return ch, nil } // Status returns the server's status information. func (c *Client) Status(ctx context.Context) (*pb.StatusResponse, error) { resp, err := c.client.Status(ctx, &pb.Empty{}) if err != nil { return nil, fmt.Errorf("status: %w", err) } return resp, nil } func protoToDomain(p *pb.Post) domain.Post { post := domain.Post{ ID: p.GetId(), Subreddit: p.GetSubreddit(), Title: p.GetTitle(), Author: p.GetAuthor(), URL: p.GetUrl(), SelfText: p.GetSelfText(), Score: int(p.GetScore()), Read: p.GetRead(), Starred: p.GetStarred(), Dismissed: p.GetDismissed(), } if ts := p.GetCreatedUtc(); ts != nil { post.CreatedUTC = ts.AsTime() } if ts := p.GetFetchedAt(); ts != nil { post.FetchedAt = ts.AsTime() } if p.Relevance != nil { v := p.GetRelevance() post.Relevance = &v } if p.Summary != nil { v := p.GetSummary() post.Summary = &v } return post }