feat: Engine.SetHistory/SetUsage/SetActivatedTools for session restore

This commit is contained in:
2026-04-05 23:39:38 +02:00
parent bbd7791428
commit 20fb045cba
2 changed files with 272 additions and 0 deletions

View File

@@ -154,6 +154,26 @@ func (e *Engine) SetModel(model string) {
e.cfg.Model = model
}
// SetHistory replaces the conversation history (for session restore).
// Also syncs the context window and re-estimates the tracker's token count.
func (e *Engine) SetHistory(msgs []message.Message) {
e.history = msgs
if e.cfg.Context != nil {
e.cfg.Context.SetMessages(msgs)
e.cfg.Context.Tracker().Set(e.cfg.Context.Tracker().CountMessages(msgs))
}
}
// SetUsage sets cumulative token usage (for session restore).
func (e *Engine) SetUsage(u message.Usage) {
e.usage = u
}
// SetActivatedTools restores the set of activated deferred tools (for session restore).
func (e *Engine) SetActivatedTools(tools map[string]bool) {
e.activatedTools = tools
}
// Reset clears conversation history and usage.
func (e *Engine) Reset() {
e.history = nil

View File

@@ -0,0 +1,252 @@
package engine
import (
"context"
"encoding/json"
"testing"
gnomactx "somegit.dev/Owlibou/gnoma/internal/context"
"somegit.dev/Owlibou/gnoma/internal/message"
"somegit.dev/Owlibou/gnoma/internal/stream"
"somegit.dev/Owlibou/gnoma/internal/tool"
)
// deferredMockTool implements tool.Tool and tool.DeferrableTool.
type deferredMockTool struct {
name string
}
func (d *deferredMockTool) Name() string { return d.name }
func (d *deferredMockTool) Description() string { return "deferred mock" }
func (d *deferredMockTool) Parameters() json.RawMessage { return json.RawMessage(`{"type":"object"}`) }
func (d *deferredMockTool) IsReadOnly() bool { return true }
func (d *deferredMockTool) IsDestructive() bool { return false }
func (d *deferredMockTool) ShouldDefer() bool { return true }
func (d *deferredMockTool) Execute(_ context.Context, _ json.RawMessage) (tool.Result, error) {
return tool.Result{Output: "deferred output"}, nil
}
func TestSetHistory_ReplacesHistory(t *testing.T) {
e, _ := New(Config{
Provider: &mockProvider{name: "test"},
Tools: tool.NewRegistry(),
})
msgs := []message.Message{
message.NewUserText("hello"),
message.NewAssistantText("hi there"),
}
e.SetHistory(msgs)
got := e.History()
if len(got) != 2 {
t.Fatalf("History() len = %d, want 2", len(got))
}
if got[0].Role != message.RoleUser {
t.Errorf("History()[0].Role = %q, want user", got[0].Role)
}
if got[1].Role != message.RoleAssistant {
t.Errorf("History()[1].Role = %q, want assistant", got[1].Role)
}
}
func TestSetHistory_OverwritesPreviousHistory(t *testing.T) {
mp := &mockProvider{
name: "test",
streams: []stream.Stream{
newEventStream(message.StopEndTurn, "",
stream.Event{Type: stream.EventTextDelta, Text: "original"},
),
},
}
e, _ := New(Config{Provider: mp, Tools: tool.NewRegistry()})
e.Submit(context.Background(), "first message", nil)
if len(e.History()) == 0 {
t.Fatal("history should not be empty after Submit")
}
replacement := []message.Message{
message.NewUserText("restored message"),
}
e.SetHistory(replacement)
got := e.History()
if len(got) != 1 {
t.Fatalf("History() len = %d, want 1 after restore", len(got))
}
if got[0].TextContent() != "restored message" {
t.Errorf("History()[0].TextContent() = %q, want %q", got[0].TextContent(), "restored message")
}
}
func TestSetHistory_SyncsContextWindow(t *testing.T) {
ctxWindow := gnomactx.NewWindow(gnomactx.WindowConfig{MaxTokens: 200_000})
e, _ := New(Config{
Provider: &mockProvider{name: "test"},
Tools: tool.NewRegistry(),
Context: ctxWindow,
})
msgs := []message.Message{
message.NewUserText("user turn"),
message.NewAssistantText("assistant turn"),
}
e.SetHistory(msgs)
all := e.ContextWindow().AllMessages()
if len(all) != 2 {
t.Fatalf("ContextWindow().AllMessages() len = %d, want 2", len(all))
}
if all[0].TextContent() != "user turn" {
t.Errorf("AllMessages()[0].TextContent() = %q, want %q", all[0].TextContent(), "user turn")
}
}
func TestSetHistory_SyncsTrackerTokenCount(t *testing.T) {
ctxWindow := gnomactx.NewWindow(gnomactx.WindowConfig{MaxTokens: 200_000})
e, _ := New(Config{
Provider: &mockProvider{name: "test"},
Tools: tool.NewRegistry(),
Context: ctxWindow,
})
// Start with zero tracker usage.
if ctxWindow.Tracker().Used() != 0 {
t.Fatal("tracker should start at zero")
}
msgs := []message.Message{
message.NewUserText("hello world"),
}
e.SetHistory(msgs)
// After SetHistory, tracker should reflect a non-zero estimate.
used := ctxWindow.Tracker().Used()
if used == 0 {
t.Error("tracker should be non-zero after SetHistory with messages")
}
}
func TestSetHistory_NilContextWindow_NoPanic(t *testing.T) {
e, _ := New(Config{
Provider: &mockProvider{name: "test"},
Tools: tool.NewRegistry(),
// Context intentionally nil
})
msgs := []message.Message{message.NewUserText("safe")}
// Should not panic when no context window is configured.
e.SetHistory(msgs)
if len(e.History()) != 1 {
t.Errorf("History() len = %d, want 1", len(e.History()))
}
}
func TestSetUsage_ReplacesUsage(t *testing.T) {
e, _ := New(Config{
Provider: &mockProvider{name: "test"},
Tools: tool.NewRegistry(),
})
u := message.Usage{InputTokens: 500, OutputTokens: 250}
e.SetUsage(u)
got := e.Usage()
if got.InputTokens != 500 {
t.Errorf("Usage().InputTokens = %d, want 500", got.InputTokens)
}
if got.OutputTokens != 250 {
t.Errorf("Usage().OutputTokens = %d, want 250", got.OutputTokens)
}
}
func TestSetUsage_OverwritesPreviousUsage(t *testing.T) {
mp := &mockProvider{
name: "test",
streams: []stream.Stream{
newEventStream(message.StopEndTurn, "",
stream.Event{Type: stream.EventUsage, Usage: &message.Usage{InputTokens: 100, OutputTokens: 50}},
stream.Event{Type: stream.EventTextDelta, Text: "hi"},
),
},
}
e, _ := New(Config{Provider: mp, Tools: tool.NewRegistry()})
e.Submit(context.Background(), "hello", nil)
if e.Usage().InputTokens == 0 {
t.Fatal("usage should be non-zero after Submit")
}
restored := message.Usage{InputTokens: 999, OutputTokens: 111}
e.SetUsage(restored)
got := e.Usage()
if got.InputTokens != 999 {
t.Errorf("Usage().InputTokens = %d, want 999", got.InputTokens)
}
if got.OutputTokens != 111 {
t.Errorf("Usage().OutputTokens = %d, want 111", got.OutputTokens)
}
}
func TestSetActivatedTools_DeferredToolIncludedInRequest(t *testing.T) {
reg := tool.NewRegistry()
reg.Register(&deferredMockTool{name: "bash"})
mp := &mockProvider{
name: "test",
streams: []stream.Stream{
newEventStream(message.StopEndTurn, "mock-model",
stream.Event{Type: stream.EventTextDelta, Text: "done"},
),
},
}
e, _ := New(Config{Provider: mp, Tools: reg})
// Before activation: buildRequest should omit "bash" (deferred).
reqBefore := e.buildRequest(context.Background())
for _, td := range reqBefore.Tools {
if td.Name == "bash" {
t.Fatal("deferred tool 'bash' should not appear in request before activation")
}
}
// Restore activated tools.
e.SetActivatedTools(map[string]bool{"bash": true})
// After activation: buildRequest should include "bash".
reqAfter := e.buildRequest(context.Background())
found := false
for _, td := range reqAfter.Tools {
if td.Name == "bash" {
found = true
break
}
}
if !found {
t.Error("deferred tool 'bash' should appear in request after SetActivatedTools")
}
}
func TestSetActivatedTools_EmptyMap_DeactivatesAll(t *testing.T) {
reg := tool.NewRegistry()
reg.Register(&deferredMockTool{name: "bash"})
mp := &mockProvider{name: "test"}
e, _ := New(Config{Provider: mp, Tools: reg})
// Manually activate, then restore to empty.
e.activatedTools["bash"] = true
e.SetActivatedTools(map[string]bool{})
req := e.buildRequest(context.Background())
for _, td := range req.Tools {
if td.Name == "bash" {
t.Error("deferred tool 'bash' should not appear after SetActivatedTools(empty)")
}
}
}