435 lines
11 KiB
Go
435 lines
11 KiB
Go
package store
|
|
|
|
import (
|
|
"database/sql"
|
|
"fmt"
|
|
"strings"
|
|
"time"
|
|
|
|
"somegit.dev/vikingowl/reddit-reader/internal/domain"
|
|
_ "modernc.org/sqlite"
|
|
)
|
|
|
|
// schema is the DDL that Open executes each time a database is opened. Every
// statement is CREATE TABLE IF NOT EXISTS, so running it against an existing
// database is harmless, but it also never alters a table that already exists.
//
// Columns defaulting to datetime('now') hold SQLite's "YYYY-MM-DD HH:MM:SS"
// text form, which is not RFC 3339.
const schema = `
CREATE TABLE IF NOT EXISTS subreddits (
	name TEXT PRIMARY KEY,
	enabled INTEGER DEFAULT 1,
	poll_sort TEXT DEFAULT 'new',
	added_at TEXT DEFAULT (datetime('now'))
);

CREATE TABLE IF NOT EXISTS filters (
	id INTEGER PRIMARY KEY,
	subreddit TEXT REFERENCES subreddits(name) ON DELETE CASCADE,
	pattern TEXT NOT NULL,
	is_regex INTEGER DEFAULT 0
);

CREATE TABLE IF NOT EXISTS posts (
	id TEXT PRIMARY KEY,
	subreddit TEXT NOT NULL,
	title TEXT NOT NULL,
	author TEXT,
	url TEXT,
	selftext TEXT,
	score INTEGER,
	created_utc TEXT,
	fetched_at TEXT DEFAULT (datetime('now')),
	relevance REAL,
	summary TEXT,
	read INTEGER DEFAULT 0,
	starred INTEGER DEFAULT 0,
	dismissed INTEGER DEFAULT 0
);

CREATE TABLE IF NOT EXISTS feedback (
	id INTEGER PRIMARY KEY,
	post_id TEXT REFERENCES posts(id),
	vote INTEGER NOT NULL,
	created_at TEXT DEFAULT (datetime('now'))
);
`
|
|
|
|
// Store is the SQLite-backed persistence layer. It wraps the *sql.DB handle
// that Open creates and Close releases; every method runs its SQL through it.
type Store struct {
	db *sql.DB // database handle shared by all Store methods
}
|
|
|
|
// ListFilter controls which posts ListPosts returns. The zero value applies
// no restriction and no row limit.
type ListFilter struct {
	// Subreddit restricts results to one subreddit; empty means all.
	Subreddit string
	// Unread is a tri-state flag: nil applies no filter, true selects rows
	// with read = 0, false selects rows with read = 1.
	Unread *bool
	// Starred is a tri-state flag: nil applies no filter, true selects
	// starred = 1, false selects starred = 0.
	Starred *bool
	// Dismissed is a tri-state flag: nil applies no filter, true selects
	// dismissed = 1, false selects dismissed = 0.
	Dismissed *bool
	// Limit caps the number of rows when greater than zero; otherwise no
	// LIMIT clause is added.
	Limit int
}
|
|
|
|
// PostUpdate carries the fields to update on a post; nil fields are skipped.
// If every field is nil, UpdatePost does nothing.
type PostUpdate struct {
	// Read, when non-nil, sets the read column to 1 (true) or 0 (false).
	Read *bool
	// Starred, when non-nil, sets the starred column to 1 or 0.
	Starred *bool
	// Dismissed, when non-nil, sets the dismissed column to 1 or 0.
	Dismissed *bool
	// Relevance, when non-nil, replaces the relevance column.
	Relevance *float64
	// Summary, when non-nil, replaces the summary column.
	Summary *string
}
|
|
|
|
// Open opens (or creates) the SQLite database at dsn and runs migrations.
|
|
func Open(dsn string) (*Store, error) {
|
|
db, err := sql.Open("sqlite", dsn)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("store.Open: %w", err)
|
|
}
|
|
db.SetMaxOpenConns(1)
|
|
|
|
if _, err := db.Exec("PRAGMA journal_mode=WAL"); err != nil {
|
|
db.Close()
|
|
return nil, fmt.Errorf("store.Open WAL: %w", err)
|
|
}
|
|
if _, err := db.Exec("PRAGMA foreign_keys=ON"); err != nil {
|
|
db.Close()
|
|
return nil, fmt.Errorf("store.Open foreign_keys: %w", err)
|
|
}
|
|
if _, err := db.Exec(schema); err != nil {
|
|
db.Close()
|
|
return nil, fmt.Errorf("store.Open schema: %w", err)
|
|
}
|
|
|
|
return &Store{db: db}, nil
|
|
}
|
|
|
|
// Close closes the underlying database handle and returns whatever error
// the handle's own Close reports.
func (s *Store) Close() error {
	return s.db.Close()
}
|
|
|
|
// InsertPost inserts a post; silently ignores duplicates (INSERT OR IGNORE).
|
|
func (s *Store) InsertPost(p domain.Post) error {
|
|
const q = `
|
|
INSERT OR IGNORE INTO posts
|
|
(id, subreddit, title, author, url, selftext, score, created_utc, relevance, summary)
|
|
VALUES
|
|
(?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`
|
|
|
|
_, err := s.db.Exec(q,
|
|
p.ID, p.Subreddit, p.Title, p.Author, p.URL, p.SelfText,
|
|
p.Score, p.CreatedUTC.UTC().Format(time.RFC3339),
|
|
p.Relevance, p.Summary,
|
|
)
|
|
if err != nil {
|
|
return fmt.Errorf("store.InsertPost: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// GetPost retrieves a single post by ID.
|
|
func (s *Store) GetPost(id string) (domain.Post, error) {
|
|
const q = `
|
|
SELECT id, subreddit, title, author, url, selftext, score,
|
|
created_utc, fetched_at, relevance, summary, read, starred, dismissed
|
|
FROM posts WHERE id = ?`
|
|
|
|
row := s.db.QueryRow(q, id)
|
|
return scanPost(row)
|
|
}
|
|
|
|
// PostExists reports whether a post with the given ID exists.
|
|
func (s *Store) PostExists(id string) (bool, error) {
|
|
var n int
|
|
err := s.db.QueryRow(`SELECT COUNT(1) FROM posts WHERE id = ?`, id).Scan(&n)
|
|
if err != nil {
|
|
return false, fmt.Errorf("store.PostExists: %w", err)
|
|
}
|
|
return n > 0, nil
|
|
}
|
|
|
|
// ListPosts returns posts ordered by relevance DESC, fetched_at DESC.
|
|
func (s *Store) ListPosts(f ListFilter) ([]domain.Post, error) {
|
|
where, args := buildPostWhere(f)
|
|
limit := ""
|
|
if f.Limit > 0 {
|
|
limit = fmt.Sprintf(" LIMIT %d", f.Limit)
|
|
}
|
|
q := `SELECT id, subreddit, title, author, url, selftext, score,
|
|
created_utc, fetched_at, relevance, summary, read, starred, dismissed
|
|
FROM posts` + where + ` ORDER BY COALESCE(relevance, 0) DESC, fetched_at DESC` + limit
|
|
|
|
rows, err := s.db.Query(q, args...)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("store.ListPosts: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
return collectPosts(rows)
|
|
}
|
|
|
|
// UpdatePost updates the non-nil fields in u for the post with the given ID.
|
|
func (s *Store) UpdatePost(id string, u PostUpdate) error {
|
|
setClauses, args := buildPostSetClauses(u)
|
|
if len(setClauses) == 0 {
|
|
return nil
|
|
}
|
|
args = append(args, id)
|
|
q := `UPDATE posts SET ` + strings.Join(setClauses, ", ") + ` WHERE id = ?`
|
|
if _, err := s.db.Exec(q, args...); err != nil {
|
|
return fmt.Errorf("store.UpdatePost: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// UnsummarizedPosts returns posts where summary IS NULL.
|
|
func (s *Store) UnsummarizedPosts() ([]domain.Post, error) {
|
|
const q = `SELECT id, subreddit, title, author, url, selftext, score,
|
|
created_utc, fetched_at, relevance, summary, read, starred, dismissed
|
|
FROM posts WHERE summary IS NULL`
|
|
|
|
rows, err := s.db.Query(q)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("store.UnsummarizedPosts: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
return collectPosts(rows)
|
|
}
|
|
|
|
// AddSubreddit inserts a subreddit (INSERT OR IGNORE).
|
|
func (s *Store) AddSubreddit(sub domain.Subreddit) error {
|
|
_, err := s.db.Exec(
|
|
`INSERT OR IGNORE INTO subreddits (name, poll_sort) VALUES (?, ?)`,
|
|
sub.Name, sub.PollSort,
|
|
)
|
|
if err != nil {
|
|
return fmt.Errorf("store.AddSubreddit: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// RemoveSubreddit deletes a subreddit by name.
|
|
func (s *Store) RemoveSubreddit(name string) error {
|
|
if _, err := s.db.Exec(`DELETE FROM subreddits WHERE name = ?`, name); err != nil {
|
|
return fmt.Errorf("store.RemoveSubreddit: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// ListSubreddits returns all subreddits.
|
|
func (s *Store) ListSubreddits() ([]domain.Subreddit, error) {
|
|
rows, err := s.db.Query(`SELECT name, enabled, poll_sort, added_at FROM subreddits ORDER BY name`)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("store.ListSubreddits: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
var subs []domain.Subreddit
|
|
for rows.Next() {
|
|
var sub domain.Subreddit
|
|
var enabled int
|
|
var addedAt string
|
|
if err := rows.Scan(&sub.Name, &enabled, &sub.PollSort, &addedAt); err != nil {
|
|
return nil, fmt.Errorf("store.ListSubreddits scan: %w", err)
|
|
}
|
|
sub.Enabled = enabled != 0
|
|
sub.AddedAt, _ = time.Parse(time.RFC3339, addedAt)
|
|
subs = append(subs, sub)
|
|
}
|
|
return subs, rows.Err()
|
|
}
|
|
|
|
// AddFilter inserts a filter and returns its ID.
|
|
func (s *Store) AddFilter(f domain.Filter) (int64, error) {
|
|
isRegex := 0
|
|
if f.IsRegex {
|
|
isRegex = 1
|
|
}
|
|
res, err := s.db.Exec(
|
|
`INSERT INTO filters (subreddit, pattern, is_regex) VALUES (?, ?, ?)`,
|
|
f.Subreddit, f.Pattern, isRegex,
|
|
)
|
|
if err != nil {
|
|
return 0, fmt.Errorf("store.AddFilter: %w", err)
|
|
}
|
|
id, err := res.LastInsertId()
|
|
if err != nil {
|
|
return 0, fmt.Errorf("store.AddFilter LastInsertId: %w", err)
|
|
}
|
|
return id, nil
|
|
}
|
|
|
|
// ListFilters returns all filters for a subreddit.
|
|
func (s *Store) ListFilters(subreddit string) ([]domain.Filter, error) {
|
|
rows, err := s.db.Query(
|
|
`SELECT id, subreddit, pattern, is_regex FROM filters WHERE subreddit = ?`,
|
|
subreddit,
|
|
)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("store.ListFilters: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
var filters []domain.Filter
|
|
for rows.Next() {
|
|
var f domain.Filter
|
|
var isRegex int
|
|
if err := rows.Scan(&f.ID, &f.Subreddit, &f.Pattern, &isRegex); err != nil {
|
|
return nil, fmt.Errorf("store.ListFilters scan: %w", err)
|
|
}
|
|
f.IsRegex = isRegex != 0
|
|
filters = append(filters, f)
|
|
}
|
|
return filters, rows.Err()
|
|
}
|
|
|
|
// RemoveFilter deletes a filter by ID.
|
|
func (s *Store) RemoveFilter(id int64) error {
|
|
if _, err := s.db.Exec(`DELETE FROM filters WHERE id = ?`, id); err != nil {
|
|
return fmt.Errorf("store.RemoveFilter: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// AddFeedback records a vote for a post.
|
|
func (s *Store) AddFeedback(postID string, vote int) error {
|
|
if _, err := s.db.Exec(
|
|
`INSERT INTO feedback (post_id, vote) VALUES (?, ?)`, postID, vote,
|
|
); err != nil {
|
|
return fmt.Errorf("store.AddFeedback: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// RecentFeedback returns up to limit feedback rows ordered by created_at DESC.
|
|
func (s *Store) RecentFeedback(limit int) ([]domain.Feedback, error) {
|
|
rows, err := s.db.Query(
|
|
`SELECT id, post_id, vote, created_at FROM feedback ORDER BY created_at DESC LIMIT ?`,
|
|
limit,
|
|
)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("store.RecentFeedback: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
var out []domain.Feedback
|
|
for rows.Next() {
|
|
var fb domain.Feedback
|
|
var createdAt string
|
|
if err := rows.Scan(&fb.ID, &fb.PostID, &fb.Vote, &createdAt); err != nil {
|
|
return nil, fmt.Errorf("store.RecentFeedback scan: %w", err)
|
|
}
|
|
fb.CreatedAt, _ = time.Parse(time.RFC3339, createdAt)
|
|
out = append(out, fb)
|
|
}
|
|
return out, rows.Err()
|
|
}
|
|
|
|
// --- helpers ---
|
|
|
|
// rowScanner is the single Scan method that scanPost needs. It lets the same
// decoding code serve both the *sql.Row returned by QueryRow and the
// *sql.Rows cursor iterated by collectPosts.
type rowScanner interface {
	// Scan copies the current row's columns into dest.
	Scan(dest ...any) error
}
|
|
|
|
func scanPost(row rowScanner) (domain.Post, error) {
|
|
var p domain.Post
|
|
var createdUTC, fetchedAt string
|
|
var readInt, starredInt, dismissedInt int
|
|
err := row.Scan(
|
|
&p.ID, &p.Subreddit, &p.Title, &p.Author, &p.URL, &p.SelfText,
|
|
&p.Score, &createdUTC, &fetchedAt,
|
|
&p.Relevance, &p.Summary,
|
|
&readInt, &starredInt, &dismissedInt,
|
|
)
|
|
if err != nil {
|
|
return domain.Post{}, fmt.Errorf("store scanPost: %w", err)
|
|
}
|
|
p.CreatedUTC, _ = time.Parse(time.RFC3339, createdUTC)
|
|
p.FetchedAt, _ = time.Parse(time.RFC3339, fetchedAt)
|
|
p.Read = readInt != 0
|
|
p.Starred = starredInt != 0
|
|
p.Dismissed = dismissedInt != 0
|
|
return p, nil
|
|
}
|
|
|
|
func collectPosts(rows *sql.Rows) ([]domain.Post, error) {
|
|
var posts []domain.Post
|
|
for rows.Next() {
|
|
p, err := scanPost(rows)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
posts = append(posts, p)
|
|
}
|
|
return posts, rows.Err()
|
|
}
|
|
|
|
func buildPostWhere(f ListFilter) (string, []any) {
|
|
var clauses []string
|
|
var args []any
|
|
|
|
if f.Subreddit != "" {
|
|
clauses = append(clauses, "subreddit = ?")
|
|
args = append(args, f.Subreddit)
|
|
}
|
|
if f.Unread != nil {
|
|
if *f.Unread {
|
|
clauses = append(clauses, "read = 0")
|
|
} else {
|
|
clauses = append(clauses, "read = 1")
|
|
}
|
|
}
|
|
if f.Starred != nil {
|
|
if *f.Starred {
|
|
clauses = append(clauses, "starred = 1")
|
|
} else {
|
|
clauses = append(clauses, "starred = 0")
|
|
}
|
|
}
|
|
if f.Dismissed != nil {
|
|
if *f.Dismissed {
|
|
clauses = append(clauses, "dismissed = 1")
|
|
} else {
|
|
clauses = append(clauses, "dismissed = 0")
|
|
}
|
|
}
|
|
|
|
if len(clauses) == 0 {
|
|
return "", args
|
|
}
|
|
return " WHERE " + strings.Join(clauses, " AND "), args
|
|
}
|
|
|
|
func buildPostSetClauses(u PostUpdate) ([]string, []any) {
|
|
var clauses []string
|
|
var args []any
|
|
|
|
if u.Read != nil {
|
|
clauses = append(clauses, "read = ?")
|
|
v := 0
|
|
if *u.Read {
|
|
v = 1
|
|
}
|
|
args = append(args, v)
|
|
}
|
|
if u.Starred != nil {
|
|
clauses = append(clauses, "starred = ?")
|
|
v := 0
|
|
if *u.Starred {
|
|
v = 1
|
|
}
|
|
args = append(args, v)
|
|
}
|
|
if u.Dismissed != nil {
|
|
clauses = append(clauses, "dismissed = ?")
|
|
v := 0
|
|
if *u.Dismissed {
|
|
v = 1
|
|
}
|
|
args = append(args, v)
|
|
}
|
|
if u.Relevance != nil {
|
|
clauses = append(clauses, "relevance = ?")
|
|
args = append(args, *u.Relevance)
|
|
}
|
|
if u.Summary != nil {
|
|
clauses = append(clauses, "summary = ?")
|
|
args = append(args, *u.Summary)
|
|
}
|
|
|
|
return clauses, args
|
|
}
|