From 20fb045cba1a8e3bbffed9a8d21027865665c3f6 Mon Sep 17 00:00:00 2001 From: vikingowl Date: Sun, 5 Apr 2026 23:39:38 +0200 Subject: [PATCH] feat: Engine.SetHistory/SetUsage/SetActivatedTools for session restore --- internal/engine/engine.go | 20 +++ internal/engine/restore_test.go | 252 ++++++++++++++++++++++++++++++++ 2 files changed, 272 insertions(+) create mode 100644 internal/engine/restore_test.go diff --git a/internal/engine/engine.go b/internal/engine/engine.go index 18cb80b..b4560e6 100644 --- a/internal/engine/engine.go +++ b/internal/engine/engine.go @@ -154,6 +154,26 @@ func (e *Engine) SetModel(model string) { e.cfg.Model = model } +// SetHistory replaces the conversation history (for session restore). +// Also syncs the context window and re-estimates the tracker's token count. +func (e *Engine) SetHistory(msgs []message.Message) { + e.history = msgs + if e.cfg.Context != nil { + e.cfg.Context.SetMessages(msgs) + e.cfg.Context.Tracker().Set(e.cfg.Context.Tracker().CountMessages(msgs)) + } +} + +// SetUsage sets cumulative token usage (for session restore). +func (e *Engine) SetUsage(u message.Usage) { + e.usage = u +} + +// SetActivatedTools restores the set of activated deferred tools (for session restore). +func (e *Engine) SetActivatedTools(tools map[string]bool) { + e.activatedTools = tools +} + // Reset clears conversation history and usage. func (e *Engine) Reset() { e.history = nil diff --git a/internal/engine/restore_test.go b/internal/engine/restore_test.go new file mode 100644 index 0000000..9413edd --- /dev/null +++ b/internal/engine/restore_test.go @@ -0,0 +1,252 @@ +package engine + +import ( + "context" + "encoding/json" + "testing" + + gnomactx "somegit.dev/Owlibou/gnoma/internal/context" + "somegit.dev/Owlibou/gnoma/internal/message" + "somegit.dev/Owlibou/gnoma/internal/stream" + "somegit.dev/Owlibou/gnoma/internal/tool" +) + +// deferredMockTool implements tool.Tool and tool.DeferrableTool. +type deferredMockTool struct { + name string +} + +func (d *deferredMockTool) Name() string { return d.name } +func (d *deferredMockTool) Description() string { return "deferred mock" } +func (d *deferredMockTool) Parameters() json.RawMessage { return json.RawMessage(`{"type":"object"}`) } +func (d *deferredMockTool) IsReadOnly() bool { return true } +func (d *deferredMockTool) IsDestructive() bool { return false } +func (d *deferredMockTool) ShouldDefer() bool { return true } +func (d *deferredMockTool) Execute(_ context.Context, _ json.RawMessage) (tool.Result, error) { + return tool.Result{Output: "deferred output"}, nil +} + +func TestSetHistory_ReplacesHistory(t *testing.T) { + e, _ := New(Config{ + Provider: &mockProvider{name: "test"}, + Tools: tool.NewRegistry(), + }) + + msgs := []message.Message{ + message.NewUserText("hello"), + message.NewAssistantText("hi there"), + } + e.SetHistory(msgs) + + got := e.History() + if len(got) != 2 { + t.Fatalf("History() len = %d, want 2", len(got)) + } + if got[0].Role != message.RoleUser { + t.Errorf("History()[0].Role = %q, want user", got[0].Role) + } + if got[1].Role != message.RoleAssistant { + t.Errorf("History()[1].Role = %q, want assistant", got[1].Role) + } +} + +func TestSetHistory_OverwritesPreviousHistory(t *testing.T) { + mp := &mockProvider{ + name: "test", + streams: []stream.Stream{ + newEventStream(message.StopEndTurn, "", + stream.Event{Type: stream.EventTextDelta, Text: "original"}, + ), + }, + } + e, _ := New(Config{Provider: mp, Tools: tool.NewRegistry()}) + e.Submit(context.Background(), "first message", nil) + + if len(e.History()) == 0 { + t.Fatal("history should not be empty after Submit") + } + + replacement := []message.Message{ + message.NewUserText("restored message"), + } + e.SetHistory(replacement) + + got := e.History() + if len(got) != 1 { + t.Fatalf("History() len = %d, want 1 after restore", len(got)) + } + if got[0].TextContent() != "restored message" { + t.Errorf("History()[0].TextContent() = %q, want %q", got[0].TextContent(), "restored message") + } +} + +func TestSetHistory_SyncsContextWindow(t *testing.T) { + ctxWindow := gnomactx.NewWindow(gnomactx.WindowConfig{MaxTokens: 200_000}) + e, _ := New(Config{ + Provider: &mockProvider{name: "test"}, + Tools: tool.NewRegistry(), + Context: ctxWindow, + }) + + msgs := []message.Message{ + message.NewUserText("user turn"), + message.NewAssistantText("assistant turn"), + } + e.SetHistory(msgs) + + all := e.ContextWindow().AllMessages() + if len(all) != 2 { + t.Fatalf("ContextWindow().AllMessages() len = %d, want 2", len(all)) + } + if all[0].TextContent() != "user turn" { + t.Errorf("AllMessages()[0].TextContent() = %q, want %q", all[0].TextContent(), "user turn") + } +} + +func TestSetHistory_SyncsTrackerTokenCount(t *testing.T) { + ctxWindow := gnomactx.NewWindow(gnomactx.WindowConfig{MaxTokens: 200_000}) + e, _ := New(Config{ + Provider: &mockProvider{name: "test"}, + Tools: tool.NewRegistry(), + Context: ctxWindow, + }) + + // Start with zero tracker usage. + if ctxWindow.Tracker().Used() != 0 { + t.Fatal("tracker should start at zero") + } + + msgs := []message.Message{ + message.NewUserText("hello world"), + } + e.SetHistory(msgs) + + // After SetHistory, tracker should reflect a non-zero estimate. + used := ctxWindow.Tracker().Used() + if used == 0 { + t.Error("tracker should be non-zero after SetHistory with messages") + } +} + +func TestSetHistory_NilContextWindow_NoPanic(t *testing.T) { + e, _ := New(Config{ + Provider: &mockProvider{name: "test"}, + Tools: tool.NewRegistry(), + // Context intentionally nil + }) + + msgs := []message.Message{message.NewUserText("safe")} + + // Should not panic when no context window is configured. + e.SetHistory(msgs) + + if len(e.History()) != 1 { + t.Errorf("History() len = %d, want 1", len(e.History())) + } +} + +func TestSetUsage_ReplacesUsage(t *testing.T) { + e, _ := New(Config{ + Provider: &mockProvider{name: "test"}, + Tools: tool.NewRegistry(), + }) + + u := message.Usage{InputTokens: 500, OutputTokens: 250} + e.SetUsage(u) + + got := e.Usage() + if got.InputTokens != 500 { + t.Errorf("Usage().InputTokens = %d, want 500", got.InputTokens) + } + if got.OutputTokens != 250 { + t.Errorf("Usage().OutputTokens = %d, want 250", got.OutputTokens) + } +} + +func TestSetUsage_OverwritesPreviousUsage(t *testing.T) { + mp := &mockProvider{ + name: "test", + streams: []stream.Stream{ + newEventStream(message.StopEndTurn, "", + stream.Event{Type: stream.EventUsage, Usage: &message.Usage{InputTokens: 100, OutputTokens: 50}}, + stream.Event{Type: stream.EventTextDelta, Text: "hi"}, + ), + }, + } + e, _ := New(Config{Provider: mp, Tools: tool.NewRegistry()}) + e.Submit(context.Background(), "hello", nil) + + if e.Usage().InputTokens == 0 { + t.Fatal("usage should be non-zero after Submit") + } + + restored := message.Usage{InputTokens: 999, OutputTokens: 111} + e.SetUsage(restored) + + got := e.Usage() + if got.InputTokens != 999 { + t.Errorf("Usage().InputTokens = %d, want 999", got.InputTokens) + } + if got.OutputTokens != 111 { + t.Errorf("Usage().OutputTokens = %d, want 111", got.OutputTokens) + } +} + +func TestSetActivatedTools_DeferredToolIncludedInRequest(t *testing.T) { + reg := tool.NewRegistry() + reg.Register(&deferredMockTool{name: "bash"}) + + mp := &mockProvider{ + name: "test", + streams: []stream.Stream{ + newEventStream(message.StopEndTurn, "mock-model", + stream.Event{Type: stream.EventTextDelta, Text: "done"}, + ), + }, + } + + e, _ := New(Config{Provider: mp, Tools: reg}) + + // Before activation: buildRequest should omit "bash" (deferred). + reqBefore := e.buildRequest(context.Background()) + for _, td := range reqBefore.Tools { + if td.Name == "bash" { + t.Fatal("deferred tool 'bash' should not appear in request before activation") + } + } + + // Restore activated tools. + e.SetActivatedTools(map[string]bool{"bash": true}) + + // After activation: buildRequest should include "bash". + reqAfter := e.buildRequest(context.Background()) + found := false + for _, td := range reqAfter.Tools { + if td.Name == "bash" { + found = true + break + } + } + if !found { + t.Error("deferred tool 'bash' should appear in request after SetActivatedTools") + } +} + +func TestSetActivatedTools_EmptyMap_DeactivatesAll(t *testing.T) { + reg := tool.NewRegistry() + reg.Register(&deferredMockTool{name: "bash"}) + + mp := &mockProvider{name: "test"} + e, _ := New(Config{Provider: mp, Tools: reg}) + + // Manually activate, then restore to empty. + e.activatedTools["bash"] = true + e.SetActivatedTools(map[string]bool{}) + + req := e.buildRequest(context.Background()) + for _, td := range req.Tools { + if td.Name == "bash" { + t.Error("deferred tool 'bash' should not appear after SetActivatedTools(empty)") + } + } +}