Files

435 lines
11 KiB
Go

package store
import (
"database/sql"
"fmt"
"strings"
"time"
"somegit.dev/vikingowl/reddit-reader/internal/domain"
_ "modernc.org/sqlite"
)
// schema is executed by Open on every start. Every CREATE TABLE uses
// IF NOT EXISTS, so re-running it is a no-op for tables that already exist —
// which also means edits to this text do not alter existing databases.
//
// Timestamps are stored as TEXT. InsertPost writes created_utc as RFC 3339,
// while the datetime('now') column defaults (added_at, fetched_at,
// created_at) are rendered by SQLite as "YYYY-MM-DD HH:MM:SS" (UTC), so code
// reading these columns has to accept both layouts.
//
// filters.subreddit cascades on subreddit deletion, which relies on
// PRAGMA foreign_keys=ON (set in Open). feedback.post_id has no ON DELETE
// action, so deleting a post that still has feedback rows will be rejected
// while foreign keys are enforced.
const schema = `
CREATE TABLE IF NOT EXISTS subreddits (
name TEXT PRIMARY KEY,
enabled INTEGER DEFAULT 1,
poll_sort TEXT DEFAULT 'new',
added_at TEXT DEFAULT (datetime('now'))
);
CREATE TABLE IF NOT EXISTS filters (
id INTEGER PRIMARY KEY,
subreddit TEXT REFERENCES subreddits(name) ON DELETE CASCADE,
pattern TEXT NOT NULL,
is_regex INTEGER DEFAULT 0
);
CREATE TABLE IF NOT EXISTS posts (
id TEXT PRIMARY KEY,
subreddit TEXT NOT NULL,
title TEXT NOT NULL,
author TEXT,
url TEXT,
selftext TEXT,
score INTEGER,
created_utc TEXT,
fetched_at TEXT DEFAULT (datetime('now')),
relevance REAL,
summary TEXT,
read INTEGER DEFAULT 0,
starred INTEGER DEFAULT 0,
dismissed INTEGER DEFAULT 0
);
CREATE TABLE IF NOT EXISTS feedback (
id INTEGER PRIMARY KEY,
post_id TEXT REFERENCES posts(id),
vote INTEGER NOT NULL,
created_at TEXT DEFAULT (datetime('now'))
);
`
type (
	// Store is the SQLite-backed persistence layer. Obtain one from Open and
	// release it with Close.
	Store struct {
		// db is the connection pool configured by Open.
		db *sql.DB
	}

	// ListFilter narrows the posts returned by ListPosts. An empty Subreddit
	// or a nil flag means "do not filter on this column".
	ListFilter struct {
		// Subreddit, when non-empty, restricts results to that subreddit.
		Subreddit string
		// Unread true keeps only unread posts; false keeps only read ones.
		Unread *bool
		// Starred true keeps only starred posts; false only unstarred ones.
		Starred *bool
		// Dismissed true keeps only dismissed posts; false only the rest.
		Dismissed *bool
		// Limit caps the number of rows; zero or negative means no cap.
		Limit int
	}

	// PostUpdate names the post columns UpdatePost should overwrite. Any
	// field left nil keeps its current value in the database.
	PostUpdate struct {
		Read      *bool
		Starred   *bool
		Dismissed *bool
		Relevance *float64
		Summary   *string
	}
)
// Open opens (or creates) the SQLite database named by dsn, switches it to
// WAL journaling, turns on foreign-key enforcement and creates any missing
// tables from schema. If any step fails the handle is closed and an error
// prefixed with "store.Open" is returned.
func Open(dsn string) (*Store, error) {
	db, err := sql.Open("sqlite", dsn)
	if err != nil {
		return nil, fmt.Errorf("store.Open: %w", err)
	}
	// SQLite PRAGMAs such as foreign_keys are per-connection, so pinning the
	// pool to one connection keeps the settings below in force for every
	// later statement.
	db.SetMaxOpenConns(1)

	initSteps := []struct{ label, stmt string }{
		{label: "WAL", stmt: "PRAGMA journal_mode=WAL"},
		{label: "foreign_keys", stmt: "PRAGMA foreign_keys=ON"},
		{label: "schema", stmt: schema},
	}
	for _, step := range initSteps {
		if _, execErr := db.Exec(step.stmt); execErr != nil {
			_ = db.Close() // best effort; the Exec failure is what we report
			return nil, fmt.Errorf("store.Open %s: %w", step.label, execErr)
		}
	}
	return &Store{db: db}, nil
}

// Close releases the underlying database handle. Statements issued through
// the Store after Close will fail.
func (s *Store) Close() error {
	return s.db.Close()
}
// InsertPost stores p in the posts table. If a row with the same id already
// exists it is left untouched and no error is returned (INSERT OR IGNORE).
// CreatedUTC is normalized to UTC and stored as RFC 3339 text; fetched_at,
// read, starred and dismissed are not written and take their schema defaults.
func (s *Store) InsertPost(p domain.Post) error {
	createdText := p.CreatedUTC.UTC().Format(time.RFC3339)
	values := []any{
		p.ID, p.Subreddit, p.Title, p.Author, p.URL, p.SelfText,
		p.Score, createdText,
		p.Relevance, p.Summary,
	}
	const insertSQL = `
INSERT OR IGNORE INTO posts
(id, subreddit, title, author, url, selftext, score, created_utc, relevance, summary)
VALUES
(?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`
	if _, err := s.db.Exec(insertSQL, values...); err != nil {
		return fmt.Errorf("store.InsertPost: %w", err)
	}
	return nil
}
// GetPost loads the post with the given id. When no row matches, the
// returned error wraps sql.ErrNoRows (scanPost wraps with %w), so callers can
// test for it with errors.Is.
func (s *Store) GetPost(id string) (domain.Post, error) {
	return scanPost(s.db.QueryRow(`
SELECT id, subreddit, title, author, url, selftext, score,
created_utc, fetched_at, relevance, summary, read, starred, dismissed
FROM posts WHERE id = ?`, id))
}
// PostExists reports whether a post with the given id is stored. A query
// failure yields false together with the wrapped error.
func (s *Store) PostExists(id string) (bool, error) {
	row := s.db.QueryRow(`SELECT COUNT(1) FROM posts WHERE id = ?`, id)
	var count int
	if err := row.Scan(&count); err != nil {
		return false, fmt.Errorf("store.PostExists: %w", err)
	}
	// COUNT is never negative, so any non-zero value means "present".
	return count != 0, nil
}
// ListPosts returns the posts matching f, highest relevance first (a NULL
// relevance sorts as 0), then most recently fetched first. When f.Limit is
// positive at most that many rows are returned.
func (s *Store) ListPosts(f ListFilter) ([]domain.Post, error) {
	whereSQL, args := buildPostWhere(f)

	var query strings.Builder
	query.WriteString(`SELECT id, subreddit, title, author, url, selftext, score,
created_utc, fetched_at, relevance, summary, read, starred, dismissed
FROM posts`)
	query.WriteString(whereSQL)
	query.WriteString(` ORDER BY COALESCE(relevance, 0) DESC, fetched_at DESC`)
	if f.Limit > 0 {
		// Limit is an int, so formatting it into the SQL cannot inject text.
		fmt.Fprintf(&query, " LIMIT %d", f.Limit)
	}

	rows, err := s.db.Query(query.String(), args...)
	if err != nil {
		return nil, fmt.Errorf("store.ListPosts: %w", err)
	}
	defer rows.Close()
	return collectPosts(rows)
}
// UpdatePost writes the non-nil fields of u to the post with the given id.
// An all-nil update issues no query. Updating an id that does not exist is
// not reported as an error (affected rows are not checked).
func (s *Store) UpdatePost(id string, u PostUpdate) error {
	sets, args := buildPostSetClauses(u)
	if len(sets) == 0 {
		return nil
	}
	query := `UPDATE posts SET ` + strings.Join(sets, ", ") + ` WHERE id = ?`
	// The id binds to the trailing WHERE placeholder, after the SET values.
	_, err := s.db.Exec(query, append(args, id)...)
	if err != nil {
		return fmt.Errorf("store.UpdatePost: %w", err)
	}
	return nil
}
// UnsummarizedPosts returns every post whose summary column is still NULL.
// The query has no ORDER BY, so the order of the result is unspecified.
func (s *Store) UnsummarizedPosts() ([]domain.Post, error) {
	rows, err := s.db.Query(`SELECT id, subreddit, title, author, url, selftext, score,
created_utc, fetched_at, relevance, summary, read, starred, dismissed
FROM posts WHERE summary IS NULL`)
	if err != nil {
		return nil, fmt.Errorf("store.UnsummarizedPosts: %w", err)
	}
	defer rows.Close()
	return collectPosts(rows)
}
// AddSubreddit stores sub. If a subreddit with the same name already exists
// nothing is changed and no error is returned (INSERT OR IGNORE).
//
// An empty PollSort is not written, so the row takes the schema default for
// poll_sort ('new') instead of an empty string. sub.Enabled is never written;
// new rows take the schema default for enabled (1).
func (s *Store) AddSubreddit(sub domain.Subreddit) error {
	var err error
	if sub.PollSort == "" {
		// Omitting the column is what lets SQLite apply its DEFAULT; binding
		// "" explicitly would store the empty string.
		_, err = s.db.Exec(
			`INSERT OR IGNORE INTO subreddits (name) VALUES (?)`,
			sub.Name,
		)
	} else {
		_, err = s.db.Exec(
			`INSERT OR IGNORE INTO subreddits (name, poll_sort) VALUES (?, ?)`,
			sub.Name, sub.PollSort,
		)
	}
	if err != nil {
		return fmt.Errorf("store.AddSubreddit: %w", err)
	}
	return nil
}
// RemoveSubreddit deletes the subreddit with the given name. Its filters go
// with it through the ON DELETE CASCADE on filters.subreddit; stored posts
// are not touched. Removing an unknown name is not an error.
func (s *Store) RemoveSubreddit(name string) error {
	_, err := s.db.Exec(`DELETE FROM subreddits WHERE name = ?`, name)
	if err != nil {
		return fmt.Errorf("store.RemoveSubreddit: %w", err)
	}
	return nil
}
// ListSubreddits returns every stored subreddit ordered by name.
//
// added_at is filled by the schema default datetime('now'), which SQLite
// renders as "YYYY-MM-DD HH:MM:SS" (UTC) rather than RFC 3339. Both layouts
// are accepted; a value in neither layout leaves AddedAt as the zero time
// rather than failing the whole listing.
func (s *Store) ListSubreddits() ([]domain.Subreddit, error) {
	rows, err := s.db.Query(`SELECT name, enabled, poll_sort, added_at FROM subreddits ORDER BY name`)
	if err != nil {
		return nil, fmt.Errorf("store.ListSubreddits: %w", err)
	}
	defer rows.Close()

	var subs []domain.Subreddit
	for rows.Next() {
		var sub domain.Subreddit
		var enabled int
		var addedAt string
		if err := rows.Scan(&sub.Name, &enabled, &sub.PollSort, &addedAt); err != nil {
			return nil, fmt.Errorf("store.ListSubreddits scan: %w", err)
		}
		sub.Enabled = enabled != 0
		if t, perr := time.Parse(time.RFC3339, addedAt); perr == nil {
			sub.AddedAt = t
		} else if t, perr := time.Parse("2006-01-02 15:04:05", addedAt); perr == nil {
			// SQLite datetime() layout; no zone in the text, so Parse yields UTC.
			sub.AddedAt = t
		}
		subs = append(subs, sub)
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("store.ListSubreddits rows: %w", err)
	}
	return subs, nil
}
// AddFilter stores f and returns the id assigned to the new filters row.
// f.ID is ignored. Because filters.subreddit references subreddits(name), the
// insert is rejected when foreign keys are enforced and f.Subreddit is not a
// stored subreddit.
func (s *Store) AddFilter(f domain.Filter) (int64, error) {
	var regexFlag int
	if f.IsRegex {
		regexFlag = 1
	}
	res, err := s.db.Exec(
		`INSERT INTO filters (subreddit, pattern, is_regex) VALUES (?, ?, ?)`,
		f.Subreddit, f.Pattern, regexFlag,
	)
	if err != nil {
		return 0, fmt.Errorf("store.AddFilter: %w", err)
	}
	newID, idErr := res.LastInsertId()
	if idErr != nil {
		return 0, fmt.Errorf("store.AddFilter LastInsertId: %w", idErr)
	}
	return newID, nil
}
// ListFilters returns the filters attached to the named subreddit; the slice
// is nil when there are none. is_regex is converted from 0/1 to a bool.
func (s *Store) ListFilters(subreddit string) ([]domain.Filter, error) {
	const query = `SELECT id, subreddit, pattern, is_regex FROM filters WHERE subreddit = ?`
	rows, err := s.db.Query(query, subreddit)
	if err != nil {
		return nil, fmt.Errorf("store.ListFilters: %w", err)
	}
	defer rows.Close()

	var out []domain.Filter
	for rows.Next() {
		var flt domain.Filter
		var regexFlag int
		if scanErr := rows.Scan(&flt.ID, &flt.Subreddit, &flt.Pattern, &regexFlag); scanErr != nil {
			return nil, fmt.Errorf("store.ListFilters scan: %w", scanErr)
		}
		flt.IsRegex = regexFlag != 0
		out = append(out, flt)
	}
	return out, rows.Err()
}
// RemoveFilter deletes the filter with the given id. Deleting an id that
// does not exist is not an error.
func (s *Store) RemoveFilter(id int64) error {
	_, err := s.db.Exec(`DELETE FROM filters WHERE id = ?`, id)
	if err != nil {
		return fmt.Errorf("store.RemoveFilter: %w", err)
	}
	return nil
}
// AddFeedback appends a vote for postID to the feedback table; the vote is
// stored as given, without validation. feedback.post_id references posts(id),
// so with foreign keys enforced an unknown postID makes the insert fail.
func (s *Store) AddFeedback(postID string, vote int) error {
	_, err := s.db.Exec(`INSERT INTO feedback (post_id, vote) VALUES (?, ?)`, postID, vote)
	if err != nil {
		return fmt.Errorf("store.AddFeedback: %w", err)
	}
	return nil
}
// RecentFeedback returns up to limit feedback rows, newest first. limit is
// bound into LIMIT ? unchanged. Rows created within the same second are
// ordered by descending id so the result is deterministic.
//
// created_at is filled by the schema default datetime('now'), which SQLite
// renders as "YYYY-MM-DD HH:MM:SS" (UTC) rather than RFC 3339. Both layouts
// are accepted; a value in neither layout leaves CreatedAt as the zero time
// so one odd row does not hide the others.
func (s *Store) RecentFeedback(limit int) ([]domain.Feedback, error) {
	rows, err := s.db.Query(
		`SELECT id, post_id, vote, created_at FROM feedback ORDER BY created_at DESC, id DESC LIMIT ?`,
		limit,
	)
	if err != nil {
		return nil, fmt.Errorf("store.RecentFeedback: %w", err)
	}
	defer rows.Close()

	var out []domain.Feedback
	for rows.Next() {
		var fb domain.Feedback
		var createdAt string
		if err := rows.Scan(&fb.ID, &fb.PostID, &fb.Vote, &createdAt); err != nil {
			return nil, fmt.Errorf("store.RecentFeedback scan: %w", err)
		}
		if t, perr := time.Parse(time.RFC3339, createdAt); perr == nil {
			fb.CreatedAt = t
		} else if t, perr := time.Parse("2006-01-02 15:04:05", createdAt); perr == nil {
			// SQLite datetime() layout; no zone in the text, so Parse yields UTC.
			fb.CreatedAt = t
		}
		out = append(out, fb)
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("store.RecentFeedback rows: %w", err)
	}
	return out, nil
}
// --- helpers ---
// rowScanner is the Scan method shared by *sql.Row and *sql.Rows, letting
// scanPost read from either.
type rowScanner interface {
	Scan(dest ...any) error
}

// sqliteDateTimeLayout is the text layout of SQLite's datetime('now'), which
// the posts.fetched_at column default produces.
const sqliteDateTimeLayout = "2006-01-02 15:04:05"

// parseStoredTime parses a TEXT timestamp column. It accepts RFC 3339 (how
// InsertPost writes created_utc) and SQLite's datetime layout (how column
// defaults are written). Input in neither layout yields the zero time, which
// keeps reads best-effort as before.
func parseStoredTime(s string) time.Time {
	for _, layout := range []string{time.RFC3339, sqliteDateTimeLayout} {
		if t, err := time.Parse(layout, s); err == nil {
			return t
		}
	}
	return time.Time{}
}

// scanPost reads one posts row whose columns are, in order: id, subreddit,
// title, author, url, selftext, score, created_utc, fetched_at, relevance,
// summary, read, starred, dismissed. The 0/1 flag columns become bools.
// Scan failures (including sql.ErrNoRows) are wrapped with %w.
func scanPost(row rowScanner) (domain.Post, error) {
	var p domain.Post
	var createdUTC, fetchedAt string
	var readInt, starredInt, dismissedInt int
	err := row.Scan(
		&p.ID, &p.Subreddit, &p.Title, &p.Author, &p.URL, &p.SelfText,
		&p.Score, &createdUTC, &fetchedAt,
		&p.Relevance, &p.Summary,
		&readInt, &starredInt, &dismissedInt,
	)
	if err != nil {
		return domain.Post{}, fmt.Errorf("store scanPost: %w", err)
	}
	p.CreatedUTC = parseStoredTime(createdUTC)
	// fetched_at comes from datetime('now'), which RFC 3339 parsing alone
	// always rejected, leaving FetchedAt zero.
	p.FetchedAt = parseStoredTime(fetchedAt)
	p.Read = readInt != 0
	p.Starred = starredInt != 0
	p.Dismissed = dismissedInt != 0
	return p, nil
}
// collectPosts drains rows through scanPost. On the first scan failure it
// returns nil and that error. Otherwise it returns the posts read together
// with rows.Err. It does not close rows; callers do that.
func collectPosts(rows *sql.Rows) ([]domain.Post, error) {
	var out []domain.Post
	for rows.Next() {
		post, scanErr := scanPost(rows)
		if scanErr != nil {
			return nil, scanErr
		}
		out = append(out, post)
	}
	if iterErr := rows.Err(); iterErr != nil {
		return out, iterErr
	}
	return out, nil
}
// buildPostWhere turns f into a SQL WHERE fragment (with a leading space)
// plus its bind arguments. Only Subreddit is bound as a parameter; the flag
// predicates are fixed literals. Conditions are ANDed in the order
// subreddit, read, starred, dismissed. With no conditions it returns "".
// f.Limit is not handled here.
func buildPostWhere(f ListFilter) (string, []any) {
	var conds []string
	var args []any

	if f.Subreddit != "" {
		conds = append(conds, "subreddit = ?")
		args = append(args, f.Subreddit)
	}

	// tri adds one of two fixed predicates for a tri-state flag; nil adds nothing.
	tri := func(flag *bool, whenTrue, whenFalse string) {
		switch {
		case flag == nil:
		case *flag:
			conds = append(conds, whenTrue)
		default:
			conds = append(conds, whenFalse)
		}
	}
	// Unread is the inverse of the stored read column.
	tri(f.Unread, "read = 0", "read = 1")
	tri(f.Starred, "starred = 1", "starred = 0")
	tri(f.Dismissed, "dismissed = 1", "dismissed = 0")

	if len(conds) == 0 {
		return "", args
	}
	return " WHERE " + strings.Join(conds, " AND "), args
}
// buildPostSetClauses converts the non-nil fields of u into "col = ?" SET
// clauses and their bind values, in the fixed order read, starred,
// dismissed, relevance, summary. Bool flags are bound as int 0/1 to match
// the INTEGER columns. An all-nil update yields two nil slices.
func buildPostSetClauses(u PostUpdate) ([]string, []any) {
	var sets []string
	var vals []any

	set := func(col string, v any) {
		sets = append(sets, col+" = ?")
		vals = append(vals, v)
	}
	asInt := func(b bool) int {
		if b {
			return 1
		}
		return 0
	}

	if u.Read != nil {
		set("read", asInt(*u.Read))
	}
	if u.Starred != nil {
		set("starred", asInt(*u.Starred))
	}
	if u.Dismissed != nil {
		set("dismissed", asInt(*u.Dismissed))
	}
	if u.Relevance != nil {
		set("relevance", *u.Relevance)
	}
	if u.Summary != nil {
		set("summary", *u.Summary)
	}
	return sets, vals
}