diff --git a/cmd/gnoma/main.go b/cmd/gnoma/main.go index b783e0d..3c4bd9c 100644 --- a/cmd/gnoma/main.go +++ b/cmd/gnoma/main.go @@ -432,7 +432,11 @@ func main() { if armModel == "" { armModel = prov.DefaultModel() } - sess := session.NewLocal(eng, *providerName, armModel) + sess := session.NewLocal(session.LocalConfig{ + Engine: eng, + Provider: *providerName, + Model: armModel, + }) defer sess.Close() m := tui.New(sess, tui.Config{ diff --git a/internal/session/local.go b/internal/session/local.go index d9b6f11..91edd4e 100644 --- a/internal/session/local.go +++ b/internal/session/local.go @@ -3,12 +3,26 @@ package session import ( "context" "fmt" + "log/slog" "sync" + "time" "somegit.dev/Owlibou/gnoma/internal/engine" + "somegit.dev/Owlibou/gnoma/internal/security" "somegit.dev/Owlibou/gnoma/internal/stream" ) +// LocalConfig holds all configuration for a Local session. +type LocalConfig struct { + Engine *engine.Engine + Provider string + Model string + SessionID string // identifies this session on disk + Store *SessionStore // nil = no persistence + Incognito *security.IncognitoMode // nil = always persist + Logger *slog.Logger // nil = slog.Default() +} + // Local implements Session using goroutines and channels within the same process. type Local struct { mu sync.Mutex @@ -26,16 +40,37 @@ type Local struct { provider string model string turnCount int + + // Persistence + sessionID string + store *SessionStore + incognito *security.IncognitoMode + createdAt time.Time + logger *slog.Logger } // NewLocal creates a channel-based in-process session. -func NewLocal(eng *engine.Engine, providerName, model string) *Local { - return &Local{ - eng: eng, - state: StateIdle, - provider: providerName, - model: model, +func NewLocal(cfg LocalConfig) *Local { + logger := cfg.Logger + if logger == nil { + logger = slog.Default() } + return &Local{ + eng: cfg.Engine, + state: StateIdle, + provider: cfg.Provider, + model: cfg.Model, + sessionID: cfg.SessionID, + store: cfg.Store, + incognito: cfg.Incognito, + createdAt: time.Now(), + logger: logger, + } +} + +// SessionID returns the persistent identifier for this session. +func (s *Local) SessionID() string { + return s.sessionID } func (s *Local) Send(input string) error { @@ -74,15 +109,40 @@ func (s *Local) SendWithOptions(input string, opts engine.TurnOptions) error { s.mu.Lock() s.turn = turn s.err = err + var finalState SessionState if err != nil && ctx.Err() != nil { s.state = StateCancelled + finalState = StateCancelled } else if err != nil { s.state = StateError + finalState = StateError } else { s.state = StateIdle + finalState = StateIdle } s.mu.Unlock() + // Auto-save after successful turn (outside lock to avoid holding it during I/O) + if finalState == StateIdle && s.store != nil && (s.incognito == nil || s.incognito.ShouldPersist()) { + snap := Snapshot{ + ID: s.sessionID, + Metadata: Metadata{ + ID: s.sessionID, + Provider: s.provider, + Model: s.model, + TurnCount: s.turnCount, + Usage: s.eng.Usage(), + CreatedAt: s.createdAt, + UpdatedAt: time.Now(), + MessageCount: len(s.eng.History()), + }, + Messages: s.eng.History(), + } + if saveErr := s.store.Save(snap); saveErr != nil { + s.logger.Warn("session auto-save failed", "error", saveErr) + } + } + close(s.events) }() diff --git a/internal/session/session.go b/internal/session/session.go index 41a59a5..198fb8f 100644 --- a/internal/session/session.go +++ b/internal/session/session.go @@ -66,4 +66,6 @@ type Session interface { Close() error // Status returns current session state. Status() Status + // SessionID returns the persistent identifier for this session. + SessionID() string } diff --git a/internal/session/session_test.go b/internal/session/session_test.go index da0a61c..d75ad47 100644 --- a/internal/session/session_test.go +++ b/internal/session/session_test.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "log/slog" "testing" "time" @@ -65,7 +66,7 @@ func TestLocal_SendAndReceive(t *testing.T) { } eng, _ := engine.New(engine.Config{Provider: mp, Tools: tool.NewRegistry()}) - sess := NewLocal(eng, "test", "mock-model") + sess := NewLocal(LocalConfig{Engine: eng, Provider: "test", Model: "mock-model"}) // Initial state status := sess.Status() @@ -120,7 +121,7 @@ func TestLocal_SendWhileBusy(t *testing.T) { } eng, _ := engine.New(engine.Config{Provider: mp, Tools: tool.NewRegistry()}) - sess := NewLocal(eng, "test", "model") + sess := NewLocal(LocalConfig{Engine: eng, Provider: "test", Model: "model"}) sess.Send("first") @@ -147,7 +148,7 @@ func TestLocal_Cancel(t *testing.T) { } eng, _ := engine.New(engine.Config{Provider: mp, Tools: tool.NewRegistry()}) - sess := NewLocal(eng, "test", "model") + sess := NewLocal(LocalConfig{Engine: eng, Provider: "test", Model: "model"}) sess.Send("slow task") @@ -170,7 +171,7 @@ func TestLocal_Cancel(t *testing.T) { func TestLocal_Close(t *testing.T) { mp := &mockProvider{name: "test"} eng, _ := engine.New(engine.Config{Provider: mp, Tools: tool.NewRegistry()}) - sess := NewLocal(eng, "test", "model") + sess := NewLocal(LocalConfig{Engine: eng, Provider: "test", Model: "model"}) if err := sess.Close(); err != nil { t.Fatalf("Close: %v", err) @@ -198,7 +199,7 @@ func TestLocal_StatusTracking(t *testing.T) { } eng, _ := engine.New(engine.Config{Provider: mp, Tools: tool.NewRegistry()}) - sess := NewLocal(eng, "test", "mock-model") + sess := NewLocal(LocalConfig{Engine: eng, Provider: "test", Model: "mock-model"}) // Turn 1 sess.Send("one") @@ -246,5 +247,84 @@ func (s *slowStream) Close() error { return nil } // Ensure Local implements Session interface var _ Session = (*Local)(nil) +func TestLocal_AutoSave(t *testing.T) { + mp := &mockProvider{ + name: "test", + streams: []stream.Stream{ + newEventStream(message.StopEndTurn, + stream.Event{Type: stream.EventTextDelta, Text: "saved!"}, + ), + }, + } + + eng, _ := engine.New(engine.Config{Provider: mp, Tools: tool.NewRegistry()}) + store := NewSessionStore(t.TempDir(), 10, slog.Default()) + sess := NewLocal(LocalConfig{ + Engine: eng, + Provider: "test", + Model: "mock-model", + SessionID: "test-session-001", + Store: store, + }) + + if err := sess.Send("hello"); err != nil { + t.Fatalf("Send: %v", err) + } + for range sess.Events() { + } + + snap, err := store.Load("test-session-001") + if err != nil { + t.Fatalf("Load: %v", err) + } + if snap.ID != "test-session-001" { + t.Errorf("snap.ID = %q, want %q", snap.ID, "test-session-001") + } + if snap.Metadata.Provider != "test" { + t.Errorf("snap.Metadata.Provider = %q, want %q", snap.Metadata.Provider, "test") + } + if snap.Metadata.TurnCount != 1 { + t.Errorf("snap.Metadata.TurnCount = %d, want 1", snap.Metadata.TurnCount) + } + if len(snap.Messages) == 0 { + t.Error("snap.Messages should not be empty after a turn") + } +} + +func TestLocal_AutoSave_SkipsWhenNoStore(t *testing.T) { + mp := &mockProvider{ + name: "test", + streams: []stream.Stream{ + newEventStream(message.StopEndTurn, + stream.Event{Type: stream.EventTextDelta, Text: "ok"}, + ), + }, + } + + eng, _ := engine.New(engine.Config{Provider: mp, Tools: tool.NewRegistry()}) + // No store — must not panic + sess := NewLocal(LocalConfig{Engine: eng, Provider: "test", Model: "mock-model"}) + + if err := sess.Send("hello"); err != nil { + t.Fatalf("Send: %v", err) + } + for range sess.Events() { + } + + status := sess.Status() + if status.State != StateIdle { + t.Errorf("state = %s, want idle", status.State) + } +} + +func TestLocal_SessionID(t *testing.T) { + mp := &mockProvider{name: "test"} + eng, _ := engine.New(engine.Config{Provider: mp, Tools: tool.NewRegistry()}) + sess := NewLocal(LocalConfig{Engine: eng, Provider: "test", Model: "m", SessionID: "my-id"}) + if sess.SessionID() != "my-id" { + t.Errorf("SessionID() = %q, want %q", sess.SessionID(), "my-id") + } +} + // Suppress unused import var _ = json.Marshal