241 lines
6.0 KiB
Go
241 lines
6.0 KiB
Go
package setup
|
|
|
|
import (
|
|
"bufio"
|
|
"context"
|
|
"fmt"
|
|
"net/http"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"time"
|
|
|
|
"somegit.dev/vikingowl/reddit-reader/internal/config"
|
|
"somegit.dev/vikingowl/reddit-reader/internal/store"
|
|
)
|
|
|
|
type Wizard struct {
|
|
scanner *bufio.Scanner
|
|
cfg *config.Config
|
|
}
|
|
|
|
func Run() error {
|
|
w := &Wizard{scanner: bufio.NewScanner(os.Stdin), cfg: &config.Config{}}
|
|
return w.run()
|
|
}
|
|
|
|
func (w *Wizard) run() error {
|
|
fmt.Println("=== Reddit Reader Setup ===")
|
|
fmt.Println()
|
|
|
|
if err := w.setupReddit(); err != nil {
|
|
return fmt.Errorf("setup reddit: %w", err)
|
|
}
|
|
|
|
if err := w.setupLLM(); err != nil {
|
|
return fmt.Errorf("setup llm: %w", err)
|
|
}
|
|
|
|
if err := w.setupSubreddits(); err != nil {
|
|
return fmt.Errorf("setup subreddits: %w", err)
|
|
}
|
|
|
|
if err := w.setupInterests(); err != nil {
|
|
return fmt.Errorf("setup interests: %w", err)
|
|
}
|
|
|
|
w.cfg.Monitor.PollInterval = config.Duration{Duration: 2 * time.Minute}
|
|
w.cfg.Monitor.MaxPostsPerPoll = 25
|
|
w.cfg.GRPC.Socket = config.DefaultSocket()
|
|
|
|
cfgPath := config.DefaultPath()
|
|
if err := w.cfg.SaveToFile(cfgPath); err != nil {
|
|
return fmt.Errorf("save config: %w", err)
|
|
}
|
|
fmt.Printf("Config saved to %s\n", cfgPath)
|
|
|
|
dbPath := filepath.Join(filepath.Dir(cfgPath), "reddit-reader.db")
|
|
db, err := store.Open(dbPath)
|
|
if err != nil {
|
|
return fmt.Errorf("create database: %w", err)
|
|
}
|
|
db.Close()
|
|
fmt.Printf("Database initialized at %s\n", dbPath)
|
|
|
|
w.offerSystemd()
|
|
|
|
fmt.Println()
|
|
fmt.Println("Setup complete. Run 'reddit-reader serve' to start.")
|
|
return nil
|
|
}
|
|
|
|
func (w *Wizard) setupReddit() error {
|
|
fmt.Println("--- Reddit credentials ---")
|
|
fmt.Println("Create an app at https://www.reddit.com/prefs/apps (script type).")
|
|
fmt.Println()
|
|
|
|
w.cfg.Reddit.ClientID = w.prompt("Client ID")
|
|
w.cfg.Reddit.ClientSecret = w.prompt("Client Secret")
|
|
w.cfg.Reddit.Username = w.prompt("Reddit Username")
|
|
w.cfg.Reddit.Password = w.prompt("Reddit Password")
|
|
return nil
|
|
}
|
|
|
|
func (w *Wizard) setupLLM() error {
|
|
fmt.Println()
|
|
fmt.Println("--- LLM backend ---")
|
|
|
|
ollamaAvailable := w.probeOllama()
|
|
if ollamaAvailable {
|
|
fmt.Println("Ollama detected at localhost:11434.")
|
|
} else {
|
|
fmt.Println("Ollama not detected at localhost:11434.")
|
|
}
|
|
|
|
fmt.Println("Backend options: ollama, openai")
|
|
defaultBackend := "ollama"
|
|
if !ollamaAvailable {
|
|
defaultBackend = "openai"
|
|
}
|
|
w.cfg.LLM.Backend = w.promptDefault("Backend", defaultBackend)
|
|
|
|
switch w.cfg.LLM.Backend {
|
|
case "ollama":
|
|
w.cfg.LLM.Endpoint = w.promptDefault("Ollama endpoint", "http://localhost:11434")
|
|
w.cfg.LLM.Model = w.promptDefault("Model", "llama3")
|
|
case "openai":
|
|
w.cfg.LLM.Endpoint = w.promptDefault("OpenAI endpoint", "https://api.openai.com/v1")
|
|
w.cfg.LLM.Model = w.promptDefault("Model", "gpt-4o-mini")
|
|
w.cfg.LLM.APIKey = w.prompt("API Key")
|
|
default:
|
|
w.cfg.LLM.Endpoint = w.prompt("Endpoint")
|
|
w.cfg.LLM.Model = w.prompt("Model")
|
|
}
|
|
|
|
w.cfg.LLM.RelevanceThreshold = 0.5
|
|
return nil
|
|
}
|
|
|
|
func (w *Wizard) setupSubreddits() error {
|
|
fmt.Println()
|
|
fmt.Println("--- Subreddits ---")
|
|
fmt.Println("Enter subreddit names (comma-separated, without r/).")
|
|
fmt.Println("Keyword filters can be configured via the TUI later.")
|
|
fmt.Println()
|
|
|
|
raw := w.prompt("Subreddits")
|
|
parts := strings.Split(raw, ",")
|
|
for _, p := range parts {
|
|
name := strings.TrimSpace(p)
|
|
if name != "" {
|
|
fmt.Printf(" + %s\n", name)
|
|
}
|
|
}
|
|
// Subreddits are stored in the DB, not the config; we print them here so
|
|
// the user sees they were accepted. The monitor will read from the DB.
|
|
// We don't add them to the DB here because we haven't opened it yet;
|
|
// they can be added via the TUI after first run.
|
|
_ = parts
|
|
return nil
|
|
}
|
|
|
|
func (w *Wizard) setupInterests() error {
|
|
fmt.Println()
|
|
fmt.Println("--- Interests ---")
|
|
fmt.Println("Describe your interests in plain text. This drives relevance scoring.")
|
|
fmt.Println()
|
|
|
|
w.cfg.Interests.Description = w.prompt("Interests")
|
|
return nil
|
|
}
|
|
|
|
func (w *Wizard) prompt(label string) string {
|
|
fmt.Printf("%s: ", label)
|
|
if w.scanner.Scan() {
|
|
return strings.TrimSpace(w.scanner.Text())
|
|
}
|
|
return ""
|
|
}
|
|
|
|
func (w *Wizard) promptDefault(label, def string) string {
|
|
fmt.Printf("%s [%s]: ", label, def)
|
|
if w.scanner.Scan() {
|
|
v := strings.TrimSpace(w.scanner.Text())
|
|
if v == "" {
|
|
return def
|
|
}
|
|
return v
|
|
}
|
|
return def
|
|
}
|
|
|
|
func (w *Wizard) probeOllama() bool {
|
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
|
defer cancel()
|
|
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:11434/api/tags", nil)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
resp, err := http.DefaultClient.Do(req)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
resp.Body.Close()
|
|
return resp.StatusCode == http.StatusOK
|
|
}
|
|
|
|
func (w *Wizard) offerSystemd() {
|
|
fmt.Println()
|
|
answer := w.promptDefault("Install systemd user units? (y/N)", "N")
|
|
if !strings.EqualFold(answer, "y") {
|
|
return
|
|
}
|
|
|
|
systemdDir := filepath.Join(os.Getenv("HOME"), ".config", "systemd", "user")
|
|
if err := os.MkdirAll(systemdDir, 0o755); err != nil {
|
|
fmt.Fprintf(os.Stderr, "warning: could not create systemd dir: %v\n", err)
|
|
return
|
|
}
|
|
|
|
serviceUnit := `[Unit]
|
|
Description=Reddit Reader Monitor
|
|
After=network-online.target
|
|
|
|
[Service]
|
|
Type=simple
|
|
ExecStart=%h/.local/bin/reddit-reader serve
|
|
Restart=on-failure
|
|
RestartSec=5
|
|
|
|
[Install]
|
|
WantedBy=default.target
|
|
`
|
|
|
|
socketUnit := `[Unit]
|
|
Description=Reddit Reader Socket
|
|
|
|
[Socket]
|
|
ListenStream=%t/reddit-reader.sock
|
|
|
|
[Install]
|
|
WantedBy=sockets.target
|
|
`
|
|
|
|
servicePath := filepath.Join(systemdDir, "reddit-reader.service")
|
|
if err := os.WriteFile(servicePath, []byte(serviceUnit), 0o644); err != nil {
|
|
fmt.Fprintf(os.Stderr, "warning: could not write service unit: %v\n", err)
|
|
return
|
|
}
|
|
|
|
socketPath := filepath.Join(systemdDir, "reddit-reader.socket")
|
|
if err := os.WriteFile(socketPath, []byte(socketUnit), 0o644); err != nil {
|
|
fmt.Fprintf(os.Stderr, "warning: could not write socket unit: %v\n", err)
|
|
return
|
|
}
|
|
|
|
fmt.Printf("Wrote %s\n", servicePath)
|
|
fmt.Printf("Wrote %s\n", socketPath)
|
|
fmt.Println("Run: systemctl --user daemon-reload && systemctl --user enable --now reddit-reader.socket")
|
|
}
|