From fc465e5f294b44a67316e3ea341abd5dcc69abd6 Mon Sep 17 00:00:00 2001 From: vikingowl Date: Thu, 7 May 2026 15:22:22 +0200 Subject: [PATCH] =?UTF-8?q?feat(engine):=20M8=20cleanup=20=E2=80=94=20Wave?= =?UTF-8?q?=20A=20wiring=20gaps?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Remove stale TODO(P0c) comment from main.go (resolved by P0c tier routing) - Wire config.Provider.Temperature → engine.Config.Temperature → provider.Request - Add WithMaxFileSize option to fs.write; wire cfg.Tools.MaxFileSize in main.go - Wire router.ReportOutcome after each runLoop return (success = err == nil) - Fix nil-callback guard on EventRouting dispatch (pre-existing bug exposed by new test) --- cmd/gnoma/main.go | 7 ++-- internal/engine/buildrequest_test.go | 35 ++++++++++++++++++++ internal/engine/engine.go | 25 +++++++------- internal/engine/engine_test.go | 49 ++++++++++++++++++++++++++++ internal/engine/loop.go | 39 +++++++++++++++++++--- internal/tool/fs/fs_test.go | 45 +++++++++++++++++++++++++ internal/tool/fs/write.go | 23 +++++++++++-- 7 files changed, 201 insertions(+), 22 deletions(-) diff --git a/cmd/gnoma/main.go b/cmd/gnoma/main.go index 9933434..bf309fe 100644 --- a/cmd/gnoma/main.go +++ b/cmd/gnoma/main.go @@ -176,6 +176,9 @@ func main() { // Create tool registry reg := buildToolRegistry() + if cfg.Tools.MaxFileSize > 0 { + reg.Register(fs.NewWriteTool(fs.WithMaxFileSize(cfg.Tools.MaxFileSize))) + } // Harvest shell aliases aliases, err := bash.HarvestAliases(context.Background()) @@ -290,9 +293,6 @@ func main() { } // Discover CLI agents (claude, gemini, vibe) and register as arms. - // TODO(P0c): CLI arms have cost=0 and ToolUse=true, so the router currently - // always prefers them over API arms. Tier-based routing (subprocess > local > API) - // needs to be explicit in the selector, not implicit through cost. cliAgents := subprocprov.DiscoverCLIAgents(context.Background()) for _, agent := range cliAgents { cliArmID := router.NewArmID("subprocess", agent.Name) @@ -561,6 +561,7 @@ func main() { Context: ctxWindow, System: systemPrompt, Model: *model, + Temperature: cfg.Provider.Temperature, MaxTurns: *maxTurns, Store: store, Hooks: dispatcher, diff --git a/internal/engine/buildrequest_test.go b/internal/engine/buildrequest_test.go index ba7e9c2..b56418c 100644 --- a/internal/engine/buildrequest_test.go +++ b/internal/engine/buildrequest_test.go @@ -155,6 +155,41 @@ func TestBuildRequest_AllowedToolsFilter(t *testing.T) { } } +func TestBuildRequest_Temperature(t *testing.T) { + temp := 0.7 + e, err := New(Config{ + Provider: &mockProvider{name: "test"}, + Tools: tool.NewRegistry(), + Temperature: &temp, + }) + if err != nil { + t.Fatalf("New: %v", err) + } + + req := e.buildRequest(context.Background()) + if req.Temperature == nil { + t.Fatal("expected Temperature in request, got nil") + } + if *req.Temperature != temp { + t.Errorf("Temperature = %v, want %v", *req.Temperature, temp) + } +} + +func TestBuildRequest_TemperatureNilWhenNotSet(t *testing.T) { + e, err := New(Config{ + Provider: &mockProvider{name: "test"}, + Tools: tool.NewRegistry(), + }) + if err != nil { + t.Fatalf("New: %v", err) + } + + req := e.buildRequest(context.Background()) + if req.Temperature != nil { + t.Errorf("expected nil Temperature, got %v", *req.Temperature) + } +} + func TestBuildRequest_MultiArmRouting_IncludesTools(t *testing.T) { rtr := router.New(router.Config{}) rtr.RegisterArm(&router.Arm{ diff --git a/internal/engine/engine.go b/internal/engine/engine.go index a76a38b..38ca6fe 100644 --- a/internal/engine/engine.go +++ b/internal/engine/engine.go @@ -18,18 +18,19 @@ import ( // Config holds engine configuration. type Config struct { - Provider provider.Provider // direct provider (used if Router is nil) - Router *router.Router // nil = use Provider directly - Tools *tool.Registry - Firewall *security.Firewall // nil = no scanning - Permissions *permission.Checker // nil = allow all - Context *gnomactx.Window // nil = no compaction - System string // system prompt - Model string // override model (empty = provider default) - MaxTurns int // safety limit on tool loops (0 = unlimited) - Store *persist.Store // nil = no result persistence - Hooks *hook.Dispatcher // nil = no hooks - Logger *slog.Logger + Provider provider.Provider // direct provider (used if Router is nil) + Router *router.Router // nil = use Provider directly + Tools *tool.Registry + Firewall *security.Firewall // nil = no scanning + Permissions *permission.Checker // nil = allow all + Context *gnomactx.Window // nil = no compaction + System string // system prompt + Model string // override model (empty = provider default) + Temperature *float64 // nil = provider default + MaxTurns int // safety limit on tool loops (0 = unlimited) + Store *persist.Store // nil = no result persistence + Hooks *hook.Dispatcher // nil = no hooks + Logger *slog.Logger } func (c Config) validate() error { diff --git a/internal/engine/engine_test.go b/internal/engine/engine_test.go index c440213..62e3dad 100644 --- a/internal/engine/engine_test.go +++ b/internal/engine/engine_test.go @@ -10,6 +10,7 @@ import ( gnomactx "somegit.dev/Owlibou/gnoma/internal/context" "somegit.dev/Owlibou/gnoma/internal/message" "somegit.dev/Owlibou/gnoma/internal/provider" + "somegit.dev/Owlibou/gnoma/internal/router" "somegit.dev/Owlibou/gnoma/internal/stream" "somegit.dev/Owlibou/gnoma/internal/tool" ) @@ -577,3 +578,51 @@ func TestSubmit_CumulativeUsage(t *testing.T) { t.Errorf("cumulative OutputTokens = %d, want 130", e.Usage().OutputTokens) } } + +func TestSubmit_ReportsOutcomeToRouter(t *testing.T) { + rtr := router.New(router.Config{}) + armID := router.NewArmID("test", "mock-model") + + makeStream := func() stream.Stream { + return newEventStream(message.StopEndTurn, "mock-model", + stream.Event{Type: stream.EventTextDelta, Text: "hi"}, + stream.Event{Type: stream.EventUsage, Usage: &message.Usage{InputTokens: 10, OutputTokens: 5}}, + ) + } + mp := &mockProvider{ + name: "test", + streams: []stream.Stream{makeStream(), makeStream(), makeStream()}, + } + rtr.RegisterArm(&router.Arm{ + ID: armID, + Provider: mp, + ModelName: "mock-model", + Capabilities: provider.Capabilities{ToolUse: true}, + }) + rtr.ForceArm(armID) + + e, err := New(Config{ + Provider: mp, + Router: rtr, + Tools: tool.NewRegistry(), + }) + if err != nil { + t.Fatalf("New: %v", err) + } + + ctx := context.Background() + for i := 0; i < 3; i++ { + if _, err := e.Submit(ctx, "hello", nil); err != nil { + t.Fatalf("Submit %d: %v", i, err) + } + } + + taskType := router.ClassifyTask("hello").Type + score, hasData := rtr.QualityTracker().Quality(armID, taskType) + if !hasData { + t.Fatal("expected quality data after 3 successful turns — ReportOutcome may not be wired") + } + if score < 0.9 { + t.Errorf("quality score = %f, want ≥0.9 for all successful turns", score) + } +} diff --git a/internal/engine/loop.go b/internal/engine/loop.go index 55b3157..37d6a9e 100644 --- a/internal/engine/loop.go +++ b/internal/engine/loop.go @@ -54,12 +54,30 @@ func (e *Engine) SubmitMessages(ctx context.Context, msgs []message.Message, cb func (e *Engine) runLoop(ctx context.Context, cb Callback) (*Turn, error) { turn := &Turn{} + loopStart := time.Now() + var lastArmID router.ArmID + var lastTaskType router.TaskType + + reportOutcome := func(err error) { + if e.cfg.Router == nil || lastArmID == "" { + return + } + e.cfg.Router.ReportOutcome(router.Outcome{ + ArmID: lastArmID, + TaskType: lastTaskType, + Success: err == nil, + Tokens: int(turn.Usage.InputTokens + turn.Usage.OutputTokens), + Duration: time.Since(loopStart), + }) + } for { turn.Rounds++ if e.cfg.MaxTurns > 0 && turn.Rounds > e.cfg.MaxTurns { e.cfg.Hooks.Fire(hook.Stop, hook.MarshalStopPayload("max_turns")) //nolint:errcheck - return turn, fmt.Errorf("safety limit: %d rounds exceeded", e.cfg.MaxTurns) + err := fmt.Errorf("safety limit: %d rounds exceeded", e.cfg.MaxTurns) + reportOutcome(err) + return turn, err } // Build provider request (gates tools on model capabilities) @@ -94,6 +112,8 @@ func (e *Engine) runLoop(ctx context.Context, cb Callback) (*Turn, error) { s, decision, err = e.cfg.Router.Stream(ctx, task, req) if decision.Arm != nil { + lastArmID = decision.Arm.ID + lastTaskType = task.Type e.logger.Debug("streaming request", "provider", decision.Arm.Provider.Name(), "model", decision.Arm.ModelName, @@ -102,7 +122,7 @@ func (e *Engine) runLoop(ctx context.Context, cb Callback) (*Turn, error) { "tools", len(req.Tools), "round", turn.Rounds, ) - if turn.Rounds == 1 { + if turn.Rounds == 1 && cb != nil { cb(stream.Event{ Type: stream.EventRouting, RoutingModel: string(decision.Arm.ID), @@ -149,7 +169,9 @@ func (e *Engine) runLoop(ctx context.Context, cb Callback) (*Turn, error) { s, err = e.handleRequestTooLarge(ctx, err, req) if err != nil { decision.Rollback() - return nil, fmt.Errorf("provider stream: %w", err) + streamErr := fmt.Errorf("provider stream: %w", err) + reportOutcome(streamErr) + return nil, streamErr } } } @@ -192,7 +214,9 @@ func (e *Engine) runLoop(ctx context.Context, cb Callback) (*Turn, error) { ) s.Close() decision.Rollback() - return nil, e.annotateStreamError(err, len(req.Tools)) + streamErr := e.annotateStreamError(err, len(req.Tools)) + reportOutcome(streamErr) + return nil, streamErr } s.Close() @@ -243,6 +267,7 @@ func (e *Engine) runLoop(ctx context.Context, cb Callback) (*Turn, error) { switch resp.StopReason { case message.StopEndTurn, message.StopSequence: e.cfg.Hooks.Fire(hook.Stop, hook.MarshalStopPayload("end_turn")) //nolint:errcheck + reportOutcome(nil) return turn, nil case message.StopMaxTokens: @@ -258,7 +283,9 @@ func (e *Engine) runLoop(ctx context.Context, cb Callback) (*Turn, error) { case message.StopToolUse: results, err := e.executeTools(ctx, resp.Message.ToolCalls(), cb) if err != nil { - return nil, fmt.Errorf("tool execution: %w", err) + toolErr := fmt.Errorf("tool execution: %w", err) + reportOutcome(toolErr) + return nil, toolErr } toolMsg := message.NewToolResults(results...) turn.Messages = append(turn.Messages, toolMsg) @@ -271,6 +298,7 @@ func (e *Engine) runLoop(ctx context.Context, cb Callback) (*Turn, error) { default: // Unknown stop reason or empty — treat as end of turn e.cfg.Hooks.Fire(hook.Stop, hook.MarshalStopPayload("unknown")) //nolint:errcheck + reportOutcome(nil) return turn, nil } } @@ -298,6 +326,7 @@ func (e *Engine) buildRequest(ctx context.Context) provider.Request { SystemPrompt: systemPrompt, Messages: messages, ToolChoice: e.turnOpts.ToolChoice, + Temperature: e.cfg.Temperature, } // Only include tools if the model supports them. diff --git a/internal/tool/fs/fs_test.go b/internal/tool/fs/fs_test.go index 9b9dd72..dccf02c 100644 --- a/internal/tool/fs/fs_test.go +++ b/internal/tool/fs/fs_test.go @@ -167,6 +167,51 @@ func TestWriteTool_OverwriteExisting(t *testing.T) { } } +func TestWriteTool_MaxFileSize_Rejected(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "test.txt") + w := NewWriteTool(WithMaxFileSize(5)) + + result, err := w.Execute(context.Background(), mustJSON(t, writeArgs{Path: path, Content: "hello world"})) + if err != nil { + t.Fatalf("Execute: %v", err) + } + if !strings.Contains(result.Output, "too large") { + t.Errorf("Output = %q, want rejection message containing 'too large'", result.Output) + } + if _, statErr := os.Stat(path); !os.IsNotExist(statErr) { + t.Error("file should not be created when content exceeds max size") + } +} + +func TestWriteTool_MaxFileSize_Accepted(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "test.txt") + w := NewWriteTool(WithMaxFileSize(100)) + + _, err := w.Execute(context.Background(), mustJSON(t, writeArgs{Path: path, Content: "hello"})) + if err != nil { + t.Fatalf("Execute: %v", err) + } + if _, statErr := os.Stat(path); os.IsNotExist(statErr) { + t.Error("file should be created when content is within limit") + } +} + +func TestWriteTool_MaxFileSize_ZeroMeansNoLimit(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "test.txt") + w := NewWriteTool(WithMaxFileSize(0)) + + _, err := w.Execute(context.Background(), mustJSON(t, writeArgs{Path: path, Content: "hello world"})) + if err != nil { + t.Fatalf("Execute: %v", err) + } + if _, statErr := os.Stat(path); os.IsNotExist(statErr) { + t.Error("file should be created when max size is 0 (no limit)") + } +} + // --- Edit --- func TestEditTool_Interface(t *testing.T) { diff --git a/internal/tool/fs/write.go b/internal/tool/fs/write.go index 25ba916..9a4c6a1 100644 --- a/internal/tool/fs/write.go +++ b/internal/tool/fs/write.go @@ -27,9 +27,24 @@ var writeParams = json.RawMessage(`{ "required": ["path", "content"] }`) -type WriteTool struct{} +type WriteOption func(*WriteTool) -func NewWriteTool() *WriteTool { return &WriteTool{} } +// WithMaxFileSize rejects writes where the content exceeds n bytes. 0 means no limit. +func WithMaxFileSize(n int64) WriteOption { + return func(t *WriteTool) { t.maxFileSize = n } +} + +type WriteTool struct { + maxFileSize int64 +} + +func NewWriteTool(opts ...WriteOption) *WriteTool { + t := &WriteTool{} + for _, opt := range opts { + opt(t) + } + return t +} func (t *WriteTool) Name() string { return writeToolName } func (t *WriteTool) Description() string { return "Write content to a file, creating parent directories as needed" } @@ -51,6 +66,10 @@ func (t *WriteTool) Execute(_ context.Context, args json.RawMessage) (tool.Resul return tool.Result{}, fmt.Errorf("fs.write: path required") } + if t.maxFileSize > 0 && int64(len(a.Content)) > t.maxFileSize { + return tool.Result{Output: fmt.Sprintf("Error: content too large (%d bytes, limit %d bytes)", len(a.Content), t.maxFileSize)}, nil + } + // Create parent directories dir := filepath.Dir(a.Path) if err := os.MkdirAll(dir, 0o755); err != nil {