feat: LocalConfig + auto-save hook in session.Local
Refactor NewLocal to accept LocalConfig (matching engine/router patterns), add persistence fields (SessionID, Store, Incognito, Logger), capture finalState before releasing the lock to avoid data races, and auto-save a Snapshot after each successful turn when a store is configured. Add SessionID() to the Session interface and three new tests covering auto-save, no-store no-panic, and SessionID accessors.
This commit is contained in:
@@ -432,7 +432,11 @@ func main() {
|
|||||||
if armModel == "" {
|
if armModel == "" {
|
||||||
armModel = prov.DefaultModel()
|
armModel = prov.DefaultModel()
|
||||||
}
|
}
|
||||||
sess := session.NewLocal(eng, *providerName, armModel)
|
sess := session.NewLocal(session.LocalConfig{
|
||||||
|
Engine: eng,
|
||||||
|
Provider: *providerName,
|
||||||
|
Model: armModel,
|
||||||
|
})
|
||||||
defer sess.Close()
|
defer sess.Close()
|
||||||
|
|
||||||
m := tui.New(sess, tui.Config{
|
m := tui.New(sess, tui.Config{
|
||||||
|
|||||||
@@ -3,12 +3,26 @@ package session
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
"somegit.dev/Owlibou/gnoma/internal/engine"
|
"somegit.dev/Owlibou/gnoma/internal/engine"
|
||||||
|
"somegit.dev/Owlibou/gnoma/internal/security"
|
||||||
"somegit.dev/Owlibou/gnoma/internal/stream"
|
"somegit.dev/Owlibou/gnoma/internal/stream"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// LocalConfig holds all configuration for a Local session.
|
||||||
|
type LocalConfig struct {
|
||||||
|
Engine *engine.Engine
|
||||||
|
Provider string
|
||||||
|
Model string
|
||||||
|
SessionID string // identifies this session on disk
|
||||||
|
Store *SessionStore // nil = no persistence
|
||||||
|
Incognito *security.IncognitoMode // nil = always persist
|
||||||
|
Logger *slog.Logger // nil = slog.Default()
|
||||||
|
}
|
||||||
|
|
||||||
// Local implements Session using goroutines and channels within the same process.
|
// Local implements Session using goroutines and channels within the same process.
|
||||||
type Local struct {
|
type Local struct {
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
@@ -26,16 +40,37 @@ type Local struct {
|
|||||||
provider string
|
provider string
|
||||||
model string
|
model string
|
||||||
turnCount int
|
turnCount int
|
||||||
|
|
||||||
|
// Persistence
|
||||||
|
sessionID string
|
||||||
|
store *SessionStore
|
||||||
|
incognito *security.IncognitoMode
|
||||||
|
createdAt time.Time
|
||||||
|
logger *slog.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewLocal creates a channel-based in-process session.
|
// NewLocal creates a channel-based in-process session.
|
||||||
func NewLocal(eng *engine.Engine, providerName, model string) *Local {
|
func NewLocal(cfg LocalConfig) *Local {
|
||||||
return &Local{
|
logger := cfg.Logger
|
||||||
eng: eng,
|
if logger == nil {
|
||||||
state: StateIdle,
|
logger = slog.Default()
|
||||||
provider: providerName,
|
|
||||||
model: model,
|
|
||||||
}
|
}
|
||||||
|
return &Local{
|
||||||
|
eng: cfg.Engine,
|
||||||
|
state: StateIdle,
|
||||||
|
provider: cfg.Provider,
|
||||||
|
model: cfg.Model,
|
||||||
|
sessionID: cfg.SessionID,
|
||||||
|
store: cfg.Store,
|
||||||
|
incognito: cfg.Incognito,
|
||||||
|
createdAt: time.Now(),
|
||||||
|
logger: logger,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SessionID returns the persistent identifier for this session.
|
||||||
|
func (s *Local) SessionID() string {
|
||||||
|
return s.sessionID
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Local) Send(input string) error {
|
func (s *Local) Send(input string) error {
|
||||||
@@ -74,15 +109,40 @@ func (s *Local) SendWithOptions(input string, opts engine.TurnOptions) error {
|
|||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
s.turn = turn
|
s.turn = turn
|
||||||
s.err = err
|
s.err = err
|
||||||
|
var finalState SessionState
|
||||||
if err != nil && ctx.Err() != nil {
|
if err != nil && ctx.Err() != nil {
|
||||||
s.state = StateCancelled
|
s.state = StateCancelled
|
||||||
|
finalState = StateCancelled
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
s.state = StateError
|
s.state = StateError
|
||||||
|
finalState = StateError
|
||||||
} else {
|
} else {
|
||||||
s.state = StateIdle
|
s.state = StateIdle
|
||||||
|
finalState = StateIdle
|
||||||
}
|
}
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
|
|
||||||
|
// Auto-save after successful turn (outside lock to avoid holding it during I/O)
|
||||||
|
if finalState == StateIdle && s.store != nil && (s.incognito == nil || s.incognito.ShouldPersist()) {
|
||||||
|
snap := Snapshot{
|
||||||
|
ID: s.sessionID,
|
||||||
|
Metadata: Metadata{
|
||||||
|
ID: s.sessionID,
|
||||||
|
Provider: s.provider,
|
||||||
|
Model: s.model,
|
||||||
|
TurnCount: s.turnCount,
|
||||||
|
Usage: s.eng.Usage(),
|
||||||
|
CreatedAt: s.createdAt,
|
||||||
|
UpdatedAt: time.Now(),
|
||||||
|
MessageCount: len(s.eng.History()),
|
||||||
|
},
|
||||||
|
Messages: s.eng.History(),
|
||||||
|
}
|
||||||
|
if saveErr := s.store.Save(snap); saveErr != nil {
|
||||||
|
s.logger.Warn("session auto-save failed", "error", saveErr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
close(s.events)
|
close(s.events)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
|||||||
@@ -66,4 +66,6 @@ type Session interface {
|
|||||||
Close() error
|
Close() error
|
||||||
// Status returns current session state.
|
// Status returns current session state.
|
||||||
Status() Status
|
Status() Status
|
||||||
|
// SessionID returns the persistent identifier for this session.
|
||||||
|
SessionID() string
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -65,7 +66,7 @@ func TestLocal_SendAndReceive(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
eng, _ := engine.New(engine.Config{Provider: mp, Tools: tool.NewRegistry()})
|
eng, _ := engine.New(engine.Config{Provider: mp, Tools: tool.NewRegistry()})
|
||||||
sess := NewLocal(eng, "test", "mock-model")
|
sess := NewLocal(LocalConfig{Engine: eng, Provider: "test", Model: "mock-model"})
|
||||||
|
|
||||||
// Initial state
|
// Initial state
|
||||||
status := sess.Status()
|
status := sess.Status()
|
||||||
@@ -120,7 +121,7 @@ func TestLocal_SendWhileBusy(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
eng, _ := engine.New(engine.Config{Provider: mp, Tools: tool.NewRegistry()})
|
eng, _ := engine.New(engine.Config{Provider: mp, Tools: tool.NewRegistry()})
|
||||||
sess := NewLocal(eng, "test", "model")
|
sess := NewLocal(LocalConfig{Engine: eng, Provider: "test", Model: "model"})
|
||||||
|
|
||||||
sess.Send("first")
|
sess.Send("first")
|
||||||
|
|
||||||
@@ -147,7 +148,7 @@ func TestLocal_Cancel(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
eng, _ := engine.New(engine.Config{Provider: mp, Tools: tool.NewRegistry()})
|
eng, _ := engine.New(engine.Config{Provider: mp, Tools: tool.NewRegistry()})
|
||||||
sess := NewLocal(eng, "test", "model")
|
sess := NewLocal(LocalConfig{Engine: eng, Provider: "test", Model: "model"})
|
||||||
|
|
||||||
sess.Send("slow task")
|
sess.Send("slow task")
|
||||||
|
|
||||||
@@ -170,7 +171,7 @@ func TestLocal_Cancel(t *testing.T) {
|
|||||||
func TestLocal_Close(t *testing.T) {
|
func TestLocal_Close(t *testing.T) {
|
||||||
mp := &mockProvider{name: "test"}
|
mp := &mockProvider{name: "test"}
|
||||||
eng, _ := engine.New(engine.Config{Provider: mp, Tools: tool.NewRegistry()})
|
eng, _ := engine.New(engine.Config{Provider: mp, Tools: tool.NewRegistry()})
|
||||||
sess := NewLocal(eng, "test", "model")
|
sess := NewLocal(LocalConfig{Engine: eng, Provider: "test", Model: "model"})
|
||||||
|
|
||||||
if err := sess.Close(); err != nil {
|
if err := sess.Close(); err != nil {
|
||||||
t.Fatalf("Close: %v", err)
|
t.Fatalf("Close: %v", err)
|
||||||
@@ -198,7 +199,7 @@ func TestLocal_StatusTracking(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
eng, _ := engine.New(engine.Config{Provider: mp, Tools: tool.NewRegistry()})
|
eng, _ := engine.New(engine.Config{Provider: mp, Tools: tool.NewRegistry()})
|
||||||
sess := NewLocal(eng, "test", "mock-model")
|
sess := NewLocal(LocalConfig{Engine: eng, Provider: "test", Model: "mock-model"})
|
||||||
|
|
||||||
// Turn 1
|
// Turn 1
|
||||||
sess.Send("one")
|
sess.Send("one")
|
||||||
@@ -246,5 +247,84 @@ func (s *slowStream) Close() error { return nil }
|
|||||||
// Ensure Local implements Session interface
|
// Ensure Local implements Session interface
|
||||||
var _ Session = (*Local)(nil)
|
var _ Session = (*Local)(nil)
|
||||||
|
|
||||||
|
func TestLocal_AutoSave(t *testing.T) {
|
||||||
|
mp := &mockProvider{
|
||||||
|
name: "test",
|
||||||
|
streams: []stream.Stream{
|
||||||
|
newEventStream(message.StopEndTurn,
|
||||||
|
stream.Event{Type: stream.EventTextDelta, Text: "saved!"},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
eng, _ := engine.New(engine.Config{Provider: mp, Tools: tool.NewRegistry()})
|
||||||
|
store := NewSessionStore(t.TempDir(), 10, slog.Default())
|
||||||
|
sess := NewLocal(LocalConfig{
|
||||||
|
Engine: eng,
|
||||||
|
Provider: "test",
|
||||||
|
Model: "mock-model",
|
||||||
|
SessionID: "test-session-001",
|
||||||
|
Store: store,
|
||||||
|
})
|
||||||
|
|
||||||
|
if err := sess.Send("hello"); err != nil {
|
||||||
|
t.Fatalf("Send: %v", err)
|
||||||
|
}
|
||||||
|
for range sess.Events() {
|
||||||
|
}
|
||||||
|
|
||||||
|
snap, err := store.Load("test-session-001")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Load: %v", err)
|
||||||
|
}
|
||||||
|
if snap.ID != "test-session-001" {
|
||||||
|
t.Errorf("snap.ID = %q, want %q", snap.ID, "test-session-001")
|
||||||
|
}
|
||||||
|
if snap.Metadata.Provider != "test" {
|
||||||
|
t.Errorf("snap.Metadata.Provider = %q, want %q", snap.Metadata.Provider, "test")
|
||||||
|
}
|
||||||
|
if snap.Metadata.TurnCount != 1 {
|
||||||
|
t.Errorf("snap.Metadata.TurnCount = %d, want 1", snap.Metadata.TurnCount)
|
||||||
|
}
|
||||||
|
if len(snap.Messages) == 0 {
|
||||||
|
t.Error("snap.Messages should not be empty after a turn")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLocal_AutoSave_SkipsWhenNoStore(t *testing.T) {
|
||||||
|
mp := &mockProvider{
|
||||||
|
name: "test",
|
||||||
|
streams: []stream.Stream{
|
||||||
|
newEventStream(message.StopEndTurn,
|
||||||
|
stream.Event{Type: stream.EventTextDelta, Text: "ok"},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
eng, _ := engine.New(engine.Config{Provider: mp, Tools: tool.NewRegistry()})
|
||||||
|
// No store — must not panic
|
||||||
|
sess := NewLocal(LocalConfig{Engine: eng, Provider: "test", Model: "mock-model"})
|
||||||
|
|
||||||
|
if err := sess.Send("hello"); err != nil {
|
||||||
|
t.Fatalf("Send: %v", err)
|
||||||
|
}
|
||||||
|
for range sess.Events() {
|
||||||
|
}
|
||||||
|
|
||||||
|
status := sess.Status()
|
||||||
|
if status.State != StateIdle {
|
||||||
|
t.Errorf("state = %s, want idle", status.State)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLocal_SessionID(t *testing.T) {
|
||||||
|
mp := &mockProvider{name: "test"}
|
||||||
|
eng, _ := engine.New(engine.Config{Provider: mp, Tools: tool.NewRegistry()})
|
||||||
|
sess := NewLocal(LocalConfig{Engine: eng, Provider: "test", Model: "m", SessionID: "my-id"})
|
||||||
|
if sess.SessionID() != "my-id" {
|
||||||
|
t.Errorf("SessionID() = %q, want %q", sess.SessionID(), "my-id")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Suppress unused import
|
// Suppress unused import
|
||||||
var _ = json.Marshal
|
var _ = json.Marshal
|
||||||
|
|||||||
Reference in New Issue
Block a user