feat(config): TOML config parsing with env var overrides
This commit is contained in:
1
go.mod
1
go.mod
@@ -4,6 +4,7 @@ go 1.26
|
||||
|
||||
require (
|
||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.3.0 // indirect
|
||||
github.com/spf13/cobra v1.10.2 // indirect
|
||||
github.com/spf13/pflag v1.0.9 // indirect
|
||||
)
|
||||
|
||||
2
go.sum
2
go.sum
@@ -1,6 +1,8 @@
|
||||
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
|
||||
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
|
||||
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
|
||||
github.com/pelletier/go-toml/v2 v2.3.0 h1:k59bC/lIZREW0/iVaQR8nDHxVq8OVlIzYCOJf421CaM=
|
||||
github.com/pelletier/go-toml/v2 v2.3.0/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
|
||||
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
|
||||
github.com/spf13/cobra v1.10.2 h1:DMTTonx5m65Ic0GOoRY2c16WCbHxOOw6xxezuLaBpcU=
|
||||
github.com/spf13/cobra v1.10.2/go.mod h1:7C1pvHqHw5A4vrJfjNwvOdzYu0Gml16OCs2GRiTUUS4=
|
||||
|
||||
126
internal/config/config.go
Normal file
126
internal/config/config.go
Normal file
@@ -0,0 +1,126 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
toml "github.com/pelletier/go-toml/v2"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Reddit RedditConfig `toml:"reddit"`
|
||||
LLM LLMConfig `toml:"llm"`
|
||||
Interests InterestsConfig `toml:"interests"`
|
||||
Monitor MonitorConfig `toml:"monitor"`
|
||||
GRPC GRPCConfig `toml:"grpc"`
|
||||
}
|
||||
|
||||
type RedditConfig struct {
|
||||
ClientID string `toml:"client_id"`
|
||||
ClientSecret string `toml:"client_secret"`
|
||||
Username string `toml:"username"`
|
||||
Password string `toml:"password"`
|
||||
}
|
||||
|
||||
type LLMConfig struct {
|
||||
Backend string `toml:"backend"`
|
||||
Endpoint string `toml:"endpoint"`
|
||||
Model string `toml:"model"`
|
||||
APIKey string `toml:"api_key"`
|
||||
RelevanceThreshold float64 `toml:"relevance_threshold"`
|
||||
}
|
||||
|
||||
type InterestsConfig struct {
|
||||
Description string `toml:"description"`
|
||||
}
|
||||
|
||||
type MonitorConfig struct {
|
||||
PollInterval Duration `toml:"poll_interval"`
|
||||
MaxPostsPerPoll int `toml:"max_posts_per_poll"`
|
||||
}
|
||||
|
||||
type GRPCConfig struct {
|
||||
Socket string `toml:"socket"`
|
||||
}
|
||||
|
||||
// Duration wraps time.Duration for TOML string parsing.
|
||||
type Duration struct {
|
||||
time.Duration
|
||||
}
|
||||
|
||||
func (d *Duration) UnmarshalText(text []byte) error {
|
||||
var err error
|
||||
d.Duration, err = time.ParseDuration(string(text))
|
||||
return err
|
||||
}
|
||||
|
||||
func (d Duration) MarshalText() ([]byte, error) {
|
||||
return []byte(d.Duration.String()), nil
|
||||
}
|
||||
|
||||
func LoadFromFile(path string) (*Config, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read config: %w", err)
|
||||
}
|
||||
var cfg Config
|
||||
if err := toml.Unmarshal(data, &cfg); err != nil {
|
||||
return nil, fmt.Errorf("parse config: %w", err)
|
||||
}
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
func (c *Config) ApplyEnvOverrides() {
|
||||
if v := os.Getenv("REDDIT_READER_REDDIT_CLIENT_ID"); v != "" {
|
||||
c.Reddit.ClientID = v
|
||||
}
|
||||
if v := os.Getenv("REDDIT_READER_REDDIT_CLIENT_SECRET"); v != "" {
|
||||
c.Reddit.ClientSecret = v
|
||||
}
|
||||
if v := os.Getenv("REDDIT_READER_REDDIT_USERNAME"); v != "" {
|
||||
c.Reddit.Username = v
|
||||
}
|
||||
if v := os.Getenv("REDDIT_READER_REDDIT_PASSWORD"); v != "" {
|
||||
c.Reddit.Password = v
|
||||
}
|
||||
if v := os.Getenv("REDDIT_READER_LLM_API_KEY"); v != "" {
|
||||
c.LLM.APIKey = v
|
||||
}
|
||||
if v := os.Getenv("REDDIT_READER_LLM_BACKEND"); v != "" {
|
||||
c.LLM.Backend = v
|
||||
}
|
||||
if v := os.Getenv("REDDIT_READER_LLM_ENDPOINT"); v != "" {
|
||||
c.LLM.Endpoint = v
|
||||
}
|
||||
if v := os.Getenv("REDDIT_READER_LLM_MODEL"); v != "" {
|
||||
c.LLM.Model = v
|
||||
}
|
||||
}
|
||||
|
||||
func DefaultPath() string {
|
||||
dir, err := os.UserConfigDir()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return filepath.Join(dir, "reddit-reader", "config.toml")
|
||||
}
|
||||
|
||||
func (c *Config) SaveToFile(path string) error {
|
||||
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
|
||||
return fmt.Errorf("create config dir: %w", err)
|
||||
}
|
||||
data, err := toml.Marshal(c)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal config: %w", err)
|
||||
}
|
||||
return os.WriteFile(path, data, 0o600)
|
||||
}
|
||||
|
||||
func DefaultSocket() string {
|
||||
if dir := os.Getenv("XDG_RUNTIME_DIR"); dir != "" {
|
||||
return filepath.Join(dir, "reddit-reader.sock")
|
||||
}
|
||||
return "/tmp/reddit-reader.sock"
|
||||
}
|
||||
121
internal/config/config_test.go
Normal file
121
internal/config/config_test.go
Normal file
@@ -0,0 +1,121 @@
|
||||
package config_test
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"somegit.dev/vikingowl/reddit-reader/internal/config"
|
||||
)
|
||||
|
||||
func TestLoadFromFile(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "config.toml")
|
||||
err := os.WriteFile(path, []byte(`
|
||||
[reddit]
|
||||
client_id = "test_id"
|
||||
client_secret = "test_secret"
|
||||
username = "test_user"
|
||||
password = "test_pass"
|
||||
|
||||
[llm]
|
||||
backend = "ollama"
|
||||
endpoint = "localhost:11434"
|
||||
model = "mistral-small"
|
||||
relevance_threshold = 0.7
|
||||
|
||||
[interests]
|
||||
description = "Go programming, Linux"
|
||||
|
||||
[monitor]
|
||||
poll_interval = "5m"
|
||||
max_posts_per_poll = 10
|
||||
|
||||
[grpc]
|
||||
socket = "/tmp/test.sock"
|
||||
`), 0o644)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cfg, err := config.LoadFromFile(path)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadFromFile: %v", err)
|
||||
}
|
||||
|
||||
if cfg.Reddit.ClientID != "test_id" {
|
||||
t.Errorf("ClientID = %q, want %q", cfg.Reddit.ClientID, "test_id")
|
||||
}
|
||||
if cfg.LLM.Backend != "ollama" {
|
||||
t.Errorf("Backend = %q, want %q", cfg.LLM.Backend, "ollama")
|
||||
}
|
||||
if cfg.LLM.RelevanceThreshold != 0.7 {
|
||||
t.Errorf("RelevanceThreshold = %f, want 0.7", cfg.LLM.RelevanceThreshold)
|
||||
}
|
||||
if cfg.Interests.Description != "Go programming, Linux" {
|
||||
t.Errorf("Description = %q, want %q", cfg.Interests.Description, "Go programming, Linux")
|
||||
}
|
||||
if cfg.Monitor.PollInterval.String() != "5m0s" {
|
||||
t.Errorf("PollInterval = %v, want 5m", cfg.Monitor.PollInterval)
|
||||
}
|
||||
if cfg.Monitor.MaxPostsPerPoll != 10 {
|
||||
t.Errorf("MaxPostsPerPoll = %d, want 10", cfg.Monitor.MaxPostsPerPoll)
|
||||
}
|
||||
if cfg.GRPC.Socket != "/tmp/test.sock" {
|
||||
t.Errorf("Socket = %q, want %q", cfg.GRPC.Socket, "/tmp/test.sock")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnvVarOverride(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "config.toml")
|
||||
err := os.WriteFile(path, []byte(`
|
||||
[reddit]
|
||||
client_id = "file_id"
|
||||
client_secret = "file_secret"
|
||||
username = "file_user"
|
||||
password = "file_pass"
|
||||
|
||||
[llm]
|
||||
backend = "ollama"
|
||||
endpoint = "localhost:11434"
|
||||
model = "mistral-small"
|
||||
relevance_threshold = 0.6
|
||||
|
||||
[interests]
|
||||
description = ""
|
||||
|
||||
[monitor]
|
||||
poll_interval = "2m"
|
||||
max_posts_per_poll = 25
|
||||
|
||||
[grpc]
|
||||
socket = "/tmp/test.sock"
|
||||
`), 0o644)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Setenv("REDDIT_READER_REDDIT_CLIENT_ID", "env_id")
|
||||
t.Setenv("REDDIT_READER_LLM_API_KEY", "env_key")
|
||||
|
||||
cfg, err := config.LoadFromFile(path)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadFromFile: %v", err)
|
||||
}
|
||||
cfg.ApplyEnvOverrides()
|
||||
|
||||
if cfg.Reddit.ClientID != "env_id" {
|
||||
t.Errorf("ClientID = %q, want %q (env override)", cfg.Reddit.ClientID, "env_id")
|
||||
}
|
||||
if cfg.LLM.APIKey != "env_key" {
|
||||
t.Errorf("APIKey = %q, want %q (env override)", cfg.LLM.APIKey, "env_key")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultConfigPath(t *testing.T) {
|
||||
path := config.DefaultPath()
|
||||
if path == "" {
|
||||
t.Error("DefaultPath returned empty string")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user