diff --git a/go.mod b/go.mod index 693f131..db02ee1 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.26 require ( github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/pelletier/go-toml/v2 v2.3.0 // indirect github.com/spf13/cobra v1.10.2 // indirect github.com/spf13/pflag v1.0.9 // indirect ) diff --git a/go.sum b/go.sum index a6ee3e0..4af9525 100644 --- a/go.sum +++ b/go.sum @@ -1,6 +1,8 @@ github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/pelletier/go-toml/v2 v2.3.0 h1:k59bC/lIZREW0/iVaQR8nDHxVq8OVlIzYCOJf421CaM= +github.com/pelletier/go-toml/v2 v2.3.0/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/spf13/cobra v1.10.2 h1:DMTTonx5m65Ic0GOoRY2c16WCbHxOOw6xxezuLaBpcU= github.com/spf13/cobra v1.10.2/go.mod h1:7C1pvHqHw5A4vrJfjNwvOdzYu0Gml16OCs2GRiTUUS4= diff --git a/internal/config/config.go b/internal/config/config.go new file mode 100644 index 0000000..5172ee7 --- /dev/null +++ b/internal/config/config.go @@ -0,0 +1,126 @@ +package config + +import ( + "fmt" + "os" + "path/filepath" + "time" + + toml "github.com/pelletier/go-toml/v2" +) + +type Config struct { + Reddit RedditConfig `toml:"reddit"` + LLM LLMConfig `toml:"llm"` + Interests InterestsConfig `toml:"interests"` + Monitor MonitorConfig `toml:"monitor"` + GRPC GRPCConfig `toml:"grpc"` +} + +type RedditConfig struct { + ClientID string `toml:"client_id"` + ClientSecret string `toml:"client_secret"` + Username string `toml:"username"` + Password string `toml:"password"` +} + +type LLMConfig struct { + Backend string `toml:"backend"` + Endpoint string `toml:"endpoint"` + Model string `toml:"model"` + APIKey string `toml:"api_key"` + RelevanceThreshold float64 `toml:"relevance_threshold"` +} + +type InterestsConfig struct { + Description string `toml:"description"` +} + +type MonitorConfig struct { + PollInterval Duration `toml:"poll_interval"` + MaxPostsPerPoll int `toml:"max_posts_per_poll"` +} + +type GRPCConfig struct { + Socket string `toml:"socket"` +} + +// Duration wraps time.Duration for TOML string parsing. +type Duration struct { + time.Duration +} + +func (d *Duration) UnmarshalText(text []byte) error { + var err error + d.Duration, err = time.ParseDuration(string(text)) + return err +} + +func (d Duration) MarshalText() ([]byte, error) { + return []byte(d.Duration.String()), nil +} + +func LoadFromFile(path string) (*Config, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("read config: %w", err) + } + var cfg Config + if err := toml.Unmarshal(data, &cfg); err != nil { + return nil, fmt.Errorf("parse config: %w", err) + } + return &cfg, nil +} + +func (c *Config) ApplyEnvOverrides() { + if v := os.Getenv("REDDIT_READER_REDDIT_CLIENT_ID"); v != "" { + c.Reddit.ClientID = v + } + if v := os.Getenv("REDDIT_READER_REDDIT_CLIENT_SECRET"); v != "" { + c.Reddit.ClientSecret = v + } + if v := os.Getenv("REDDIT_READER_REDDIT_USERNAME"); v != "" { + c.Reddit.Username = v + } + if v := os.Getenv("REDDIT_READER_REDDIT_PASSWORD"); v != "" { + c.Reddit.Password = v + } + if v := os.Getenv("REDDIT_READER_LLM_API_KEY"); v != "" { + c.LLM.APIKey = v + } + if v := os.Getenv("REDDIT_READER_LLM_BACKEND"); v != "" { + c.LLM.Backend = v + } + if v := os.Getenv("REDDIT_READER_LLM_ENDPOINT"); v != "" { + c.LLM.Endpoint = v + } + if v := os.Getenv("REDDIT_READER_LLM_MODEL"); v != "" { + c.LLM.Model = v + } +} + +func DefaultPath() string { + dir, err := os.UserConfigDir() + if err != nil { + return "" + } + return filepath.Join(dir, "reddit-reader", "config.toml") +} + +func (c *Config) SaveToFile(path string) error { + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + return fmt.Errorf("create config dir: %w", err) + } + data, err := toml.Marshal(c) + if err != nil { + return fmt.Errorf("marshal config: %w", err) + } + return os.WriteFile(path, data, 0o600) +} + +func DefaultSocket() string { + if dir := os.Getenv("XDG_RUNTIME_DIR"); dir != "" { + return filepath.Join(dir, "reddit-reader.sock") + } + return "/tmp/reddit-reader.sock" +} diff --git a/internal/config/config_test.go b/internal/config/config_test.go new file mode 100644 index 0000000..348fa84 --- /dev/null +++ b/internal/config/config_test.go @@ -0,0 +1,121 @@ +package config_test + +import ( + "os" + "path/filepath" + "testing" + + "somegit.dev/vikingowl/reddit-reader/internal/config" +) + +func TestLoadFromFile(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "config.toml") + err := os.WriteFile(path, []byte(` +[reddit] +client_id = "test_id" +client_secret = "test_secret" +username = "test_user" +password = "test_pass" + +[llm] +backend = "ollama" +endpoint = "localhost:11434" +model = "mistral-small" +relevance_threshold = 0.7 + +[interests] +description = "Go programming, Linux" + +[monitor] +poll_interval = "5m" +max_posts_per_poll = 10 + +[grpc] +socket = "/tmp/test.sock" +`), 0o644) + if err != nil { + t.Fatal(err) + } + + cfg, err := config.LoadFromFile(path) + if err != nil { + t.Fatalf("LoadFromFile: %v", err) + } + + if cfg.Reddit.ClientID != "test_id" { + t.Errorf("ClientID = %q, want %q", cfg.Reddit.ClientID, "test_id") + } + if cfg.LLM.Backend != "ollama" { + t.Errorf("Backend = %q, want %q", cfg.LLM.Backend, "ollama") + } + if cfg.LLM.RelevanceThreshold != 0.7 { + t.Errorf("RelevanceThreshold = %f, want 0.7", cfg.LLM.RelevanceThreshold) + } + if cfg.Interests.Description != "Go programming, Linux" { + t.Errorf("Description = %q, want %q", cfg.Interests.Description, "Go programming, Linux") + } + if cfg.Monitor.PollInterval.String() != "5m0s" { + t.Errorf("PollInterval = %v, want 5m", cfg.Monitor.PollInterval) + } + if cfg.Monitor.MaxPostsPerPoll != 10 { + t.Errorf("MaxPostsPerPoll = %d, want 10", cfg.Monitor.MaxPostsPerPoll) + } + if cfg.GRPC.Socket != "/tmp/test.sock" { + t.Errorf("Socket = %q, want %q", cfg.GRPC.Socket, "/tmp/test.sock") + } +} + +func TestEnvVarOverride(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "config.toml") + err := os.WriteFile(path, []byte(` +[reddit] +client_id = "file_id" +client_secret = "file_secret" +username = "file_user" +password = "file_pass" + +[llm] +backend = "ollama" +endpoint = "localhost:11434" +model = "mistral-small" +relevance_threshold = 0.6 + +[interests] +description = "" + +[monitor] +poll_interval = "2m" +max_posts_per_poll = 25 + +[grpc] +socket = "/tmp/test.sock" +`), 0o644) + if err != nil { + t.Fatal(err) + } + + t.Setenv("REDDIT_READER_REDDIT_CLIENT_ID", "env_id") + t.Setenv("REDDIT_READER_LLM_API_KEY", "env_key") + + cfg, err := config.LoadFromFile(path) + if err != nil { + t.Fatalf("LoadFromFile: %v", err) + } + cfg.ApplyEnvOverrides() + + if cfg.Reddit.ClientID != "env_id" { + t.Errorf("ClientID = %q, want %q (env override)", cfg.Reddit.ClientID, "env_id") + } + if cfg.LLM.APIKey != "env_key" { + t.Errorf("APIKey = %q, want %q (env override)", cfg.LLM.APIKey, "env_key") + } +} + +func TestDefaultConfigPath(t *testing.T) { + path := config.DefaultPath() + if path == "" { + t.Error("DefaultPath returned empty string") + } +}