feat(engine): M8 cleanup — Wave A wiring gaps
- Remove stale TODO(P0c) comment from main.go (resolved by P0c tier routing) - Wire config.Provider.Temperature → engine.Config.Temperature → provider.Request - Add WithMaxFileSize option to fs.write; wire cfg.Tools.MaxFileSize in main.go - Wire router.ReportOutcome after each runLoop return (success = err == nil) - Fix nil-callback guard on EventRouting dispatch (pre-existing bug exposed by new test)
This commit is contained in:
@@ -176,6 +176,9 @@ func main() {
|
||||
|
||||
// Create tool registry
|
||||
reg := buildToolRegistry()
|
||||
if cfg.Tools.MaxFileSize > 0 {
|
||||
reg.Register(fs.NewWriteTool(fs.WithMaxFileSize(cfg.Tools.MaxFileSize)))
|
||||
}
|
||||
|
||||
// Harvest shell aliases
|
||||
aliases, err := bash.HarvestAliases(context.Background())
|
||||
@@ -290,9 +293,6 @@ func main() {
|
||||
}
|
||||
|
||||
// Discover CLI agents (claude, gemini, vibe) and register as arms.
|
||||
// TODO(P0c): CLI arms have cost=0 and ToolUse=true, so the router currently
|
||||
// always prefers them over API arms. Tier-based routing (subprocess > local > API)
|
||||
// needs to be explicit in the selector, not implicit through cost.
|
||||
cliAgents := subprocprov.DiscoverCLIAgents(context.Background())
|
||||
for _, agent := range cliAgents {
|
||||
cliArmID := router.NewArmID("subprocess", agent.Name)
|
||||
@@ -561,6 +561,7 @@ func main() {
|
||||
Context: ctxWindow,
|
||||
System: systemPrompt,
|
||||
Model: *model,
|
||||
Temperature: cfg.Provider.Temperature,
|
||||
MaxTurns: *maxTurns,
|
||||
Store: store,
|
||||
Hooks: dispatcher,
|
||||
|
||||
@@ -155,6 +155,41 @@ func TestBuildRequest_AllowedToolsFilter(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildRequest_Temperature(t *testing.T) {
|
||||
temp := 0.7
|
||||
e, err := New(Config{
|
||||
Provider: &mockProvider{name: "test"},
|
||||
Tools: tool.NewRegistry(),
|
||||
Temperature: &temp,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("New: %v", err)
|
||||
}
|
||||
|
||||
req := e.buildRequest(context.Background())
|
||||
if req.Temperature == nil {
|
||||
t.Fatal("expected Temperature in request, got nil")
|
||||
}
|
||||
if *req.Temperature != temp {
|
||||
t.Errorf("Temperature = %v, want %v", *req.Temperature, temp)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildRequest_TemperatureNilWhenNotSet(t *testing.T) {
|
||||
e, err := New(Config{
|
||||
Provider: &mockProvider{name: "test"},
|
||||
Tools: tool.NewRegistry(),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("New: %v", err)
|
||||
}
|
||||
|
||||
req := e.buildRequest(context.Background())
|
||||
if req.Temperature != nil {
|
||||
t.Errorf("expected nil Temperature, got %v", *req.Temperature)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildRequest_MultiArmRouting_IncludesTools(t *testing.T) {
|
||||
rtr := router.New(router.Config{})
|
||||
rtr.RegisterArm(&router.Arm{
|
||||
|
||||
@@ -18,18 +18,19 @@ import (
|
||||
|
||||
// Config holds engine configuration.
|
||||
type Config struct {
|
||||
Provider provider.Provider // direct provider (used if Router is nil)
|
||||
Router *router.Router // nil = use Provider directly
|
||||
Tools *tool.Registry
|
||||
Firewall *security.Firewall // nil = no scanning
|
||||
Permissions *permission.Checker // nil = allow all
|
||||
Context *gnomactx.Window // nil = no compaction
|
||||
System string // system prompt
|
||||
Model string // override model (empty = provider default)
|
||||
MaxTurns int // safety limit on tool loops (0 = unlimited)
|
||||
Store *persist.Store // nil = no result persistence
|
||||
Hooks *hook.Dispatcher // nil = no hooks
|
||||
Logger *slog.Logger
|
||||
Provider provider.Provider // direct provider (used if Router is nil)
|
||||
Router *router.Router // nil = use Provider directly
|
||||
Tools *tool.Registry
|
||||
Firewall *security.Firewall // nil = no scanning
|
||||
Permissions *permission.Checker // nil = allow all
|
||||
Context *gnomactx.Window // nil = no compaction
|
||||
System string // system prompt
|
||||
Model string // override model (empty = provider default)
|
||||
Temperature *float64 // nil = provider default
|
||||
MaxTurns int // safety limit on tool loops (0 = unlimited)
|
||||
Store *persist.Store // nil = no result persistence
|
||||
Hooks *hook.Dispatcher // nil = no hooks
|
||||
Logger *slog.Logger
|
||||
}
|
||||
|
||||
func (c Config) validate() error {
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
gnomactx "somegit.dev/Owlibou/gnoma/internal/context"
|
||||
"somegit.dev/Owlibou/gnoma/internal/message"
|
||||
"somegit.dev/Owlibou/gnoma/internal/provider"
|
||||
"somegit.dev/Owlibou/gnoma/internal/router"
|
||||
"somegit.dev/Owlibou/gnoma/internal/stream"
|
||||
"somegit.dev/Owlibou/gnoma/internal/tool"
|
||||
)
|
||||
@@ -577,3 +578,51 @@ func TestSubmit_CumulativeUsage(t *testing.T) {
|
||||
t.Errorf("cumulative OutputTokens = %d, want 130", e.Usage().OutputTokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubmit_ReportsOutcomeToRouter(t *testing.T) {
|
||||
rtr := router.New(router.Config{})
|
||||
armID := router.NewArmID("test", "mock-model")
|
||||
|
||||
makeStream := func() stream.Stream {
|
||||
return newEventStream(message.StopEndTurn, "mock-model",
|
||||
stream.Event{Type: stream.EventTextDelta, Text: "hi"},
|
||||
stream.Event{Type: stream.EventUsage, Usage: &message.Usage{InputTokens: 10, OutputTokens: 5}},
|
||||
)
|
||||
}
|
||||
mp := &mockProvider{
|
||||
name: "test",
|
||||
streams: []stream.Stream{makeStream(), makeStream(), makeStream()},
|
||||
}
|
||||
rtr.RegisterArm(&router.Arm{
|
||||
ID: armID,
|
||||
Provider: mp,
|
||||
ModelName: "mock-model",
|
||||
Capabilities: provider.Capabilities{ToolUse: true},
|
||||
})
|
||||
rtr.ForceArm(armID)
|
||||
|
||||
e, err := New(Config{
|
||||
Provider: mp,
|
||||
Router: rtr,
|
||||
Tools: tool.NewRegistry(),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("New: %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
for i := 0; i < 3; i++ {
|
||||
if _, err := e.Submit(ctx, "hello", nil); err != nil {
|
||||
t.Fatalf("Submit %d: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
taskType := router.ClassifyTask("hello").Type
|
||||
score, hasData := rtr.QualityTracker().Quality(armID, taskType)
|
||||
if !hasData {
|
||||
t.Fatal("expected quality data after 3 successful turns — ReportOutcome may not be wired")
|
||||
}
|
||||
if score < 0.9 {
|
||||
t.Errorf("quality score = %f, want ≥0.9 for all successful turns", score)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -54,12 +54,30 @@ func (e *Engine) SubmitMessages(ctx context.Context, msgs []message.Message, cb
|
||||
|
||||
func (e *Engine) runLoop(ctx context.Context, cb Callback) (*Turn, error) {
|
||||
turn := &Turn{}
|
||||
loopStart := time.Now()
|
||||
var lastArmID router.ArmID
|
||||
var lastTaskType router.TaskType
|
||||
|
||||
reportOutcome := func(err error) {
|
||||
if e.cfg.Router == nil || lastArmID == "" {
|
||||
return
|
||||
}
|
||||
e.cfg.Router.ReportOutcome(router.Outcome{
|
||||
ArmID: lastArmID,
|
||||
TaskType: lastTaskType,
|
||||
Success: err == nil,
|
||||
Tokens: int(turn.Usage.InputTokens + turn.Usage.OutputTokens),
|
||||
Duration: time.Since(loopStart),
|
||||
})
|
||||
}
|
||||
|
||||
for {
|
||||
turn.Rounds++
|
||||
if e.cfg.MaxTurns > 0 && turn.Rounds > e.cfg.MaxTurns {
|
||||
e.cfg.Hooks.Fire(hook.Stop, hook.MarshalStopPayload("max_turns")) //nolint:errcheck
|
||||
return turn, fmt.Errorf("safety limit: %d rounds exceeded", e.cfg.MaxTurns)
|
||||
err := fmt.Errorf("safety limit: %d rounds exceeded", e.cfg.MaxTurns)
|
||||
reportOutcome(err)
|
||||
return turn, err
|
||||
}
|
||||
|
||||
// Build provider request (gates tools on model capabilities)
|
||||
@@ -94,6 +112,8 @@ func (e *Engine) runLoop(ctx context.Context, cb Callback) (*Turn, error) {
|
||||
|
||||
s, decision, err = e.cfg.Router.Stream(ctx, task, req)
|
||||
if decision.Arm != nil {
|
||||
lastArmID = decision.Arm.ID
|
||||
lastTaskType = task.Type
|
||||
e.logger.Debug("streaming request",
|
||||
"provider", decision.Arm.Provider.Name(),
|
||||
"model", decision.Arm.ModelName,
|
||||
@@ -102,7 +122,7 @@ func (e *Engine) runLoop(ctx context.Context, cb Callback) (*Turn, error) {
|
||||
"tools", len(req.Tools),
|
||||
"round", turn.Rounds,
|
||||
)
|
||||
if turn.Rounds == 1 {
|
||||
if turn.Rounds == 1 && cb != nil {
|
||||
cb(stream.Event{
|
||||
Type: stream.EventRouting,
|
||||
RoutingModel: string(decision.Arm.ID),
|
||||
@@ -149,7 +169,9 @@ func (e *Engine) runLoop(ctx context.Context, cb Callback) (*Turn, error) {
|
||||
s, err = e.handleRequestTooLarge(ctx, err, req)
|
||||
if err != nil {
|
||||
decision.Rollback()
|
||||
return nil, fmt.Errorf("provider stream: %w", err)
|
||||
streamErr := fmt.Errorf("provider stream: %w", err)
|
||||
reportOutcome(streamErr)
|
||||
return nil, streamErr
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -192,7 +214,9 @@ func (e *Engine) runLoop(ctx context.Context, cb Callback) (*Turn, error) {
|
||||
)
|
||||
s.Close()
|
||||
decision.Rollback()
|
||||
return nil, e.annotateStreamError(err, len(req.Tools))
|
||||
streamErr := e.annotateStreamError(err, len(req.Tools))
|
||||
reportOutcome(streamErr)
|
||||
return nil, streamErr
|
||||
}
|
||||
s.Close()
|
||||
|
||||
@@ -243,6 +267,7 @@ func (e *Engine) runLoop(ctx context.Context, cb Callback) (*Turn, error) {
|
||||
switch resp.StopReason {
|
||||
case message.StopEndTurn, message.StopSequence:
|
||||
e.cfg.Hooks.Fire(hook.Stop, hook.MarshalStopPayload("end_turn")) //nolint:errcheck
|
||||
reportOutcome(nil)
|
||||
return turn, nil
|
||||
|
||||
case message.StopMaxTokens:
|
||||
@@ -258,7 +283,9 @@ func (e *Engine) runLoop(ctx context.Context, cb Callback) (*Turn, error) {
|
||||
case message.StopToolUse:
|
||||
results, err := e.executeTools(ctx, resp.Message.ToolCalls(), cb)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("tool execution: %w", err)
|
||||
toolErr := fmt.Errorf("tool execution: %w", err)
|
||||
reportOutcome(toolErr)
|
||||
return nil, toolErr
|
||||
}
|
||||
toolMsg := message.NewToolResults(results...)
|
||||
turn.Messages = append(turn.Messages, toolMsg)
|
||||
@@ -271,6 +298,7 @@ func (e *Engine) runLoop(ctx context.Context, cb Callback) (*Turn, error) {
|
||||
default:
|
||||
// Unknown stop reason or empty — treat as end of turn
|
||||
e.cfg.Hooks.Fire(hook.Stop, hook.MarshalStopPayload("unknown")) //nolint:errcheck
|
||||
reportOutcome(nil)
|
||||
return turn, nil
|
||||
}
|
||||
}
|
||||
@@ -298,6 +326,7 @@ func (e *Engine) buildRequest(ctx context.Context) provider.Request {
|
||||
SystemPrompt: systemPrompt,
|
||||
Messages: messages,
|
||||
ToolChoice: e.turnOpts.ToolChoice,
|
||||
Temperature: e.cfg.Temperature,
|
||||
}
|
||||
|
||||
// Only include tools if the model supports them.
|
||||
|
||||
@@ -167,6 +167,51 @@ func TestWriteTool_OverwriteExisting(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteTool_MaxFileSize_Rejected(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "test.txt")
|
||||
w := NewWriteTool(WithMaxFileSize(5))
|
||||
|
||||
result, err := w.Execute(context.Background(), mustJSON(t, writeArgs{Path: path, Content: "hello world"}))
|
||||
if err != nil {
|
||||
t.Fatalf("Execute: %v", err)
|
||||
}
|
||||
if !strings.Contains(result.Output, "too large") {
|
||||
t.Errorf("Output = %q, want rejection message containing 'too large'", result.Output)
|
||||
}
|
||||
if _, statErr := os.Stat(path); !os.IsNotExist(statErr) {
|
||||
t.Error("file should not be created when content exceeds max size")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteTool_MaxFileSize_Accepted(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "test.txt")
|
||||
w := NewWriteTool(WithMaxFileSize(100))
|
||||
|
||||
_, err := w.Execute(context.Background(), mustJSON(t, writeArgs{Path: path, Content: "hello"}))
|
||||
if err != nil {
|
||||
t.Fatalf("Execute: %v", err)
|
||||
}
|
||||
if _, statErr := os.Stat(path); os.IsNotExist(statErr) {
|
||||
t.Error("file should be created when content is within limit")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteTool_MaxFileSize_ZeroMeansNoLimit(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "test.txt")
|
||||
w := NewWriteTool(WithMaxFileSize(0))
|
||||
|
||||
_, err := w.Execute(context.Background(), mustJSON(t, writeArgs{Path: path, Content: "hello world"}))
|
||||
if err != nil {
|
||||
t.Fatalf("Execute: %v", err)
|
||||
}
|
||||
if _, statErr := os.Stat(path); os.IsNotExist(statErr) {
|
||||
t.Error("file should be created when max size is 0 (no limit)")
|
||||
}
|
||||
}
|
||||
|
||||
// --- Edit ---
|
||||
|
||||
func TestEditTool_Interface(t *testing.T) {
|
||||
|
||||
@@ -27,9 +27,24 @@ var writeParams = json.RawMessage(`{
|
||||
"required": ["path", "content"]
|
||||
}`)
|
||||
|
||||
type WriteTool struct{}
|
||||
type WriteOption func(*WriteTool)
|
||||
|
||||
func NewWriteTool() *WriteTool { return &WriteTool{} }
|
||||
// WithMaxFileSize rejects writes where the content exceeds n bytes. 0 means no limit.
|
||||
func WithMaxFileSize(n int64) WriteOption {
|
||||
return func(t *WriteTool) { t.maxFileSize = n }
|
||||
}
|
||||
|
||||
type WriteTool struct {
|
||||
maxFileSize int64
|
||||
}
|
||||
|
||||
func NewWriteTool(opts ...WriteOption) *WriteTool {
|
||||
t := &WriteTool{}
|
||||
for _, opt := range opts {
|
||||
opt(t)
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
func (t *WriteTool) Name() string { return writeToolName }
|
||||
func (t *WriteTool) Description() string { return "Write content to a file, creating parent directories as needed" }
|
||||
@@ -51,6 +66,10 @@ func (t *WriteTool) Execute(_ context.Context, args json.RawMessage) (tool.Resul
|
||||
return tool.Result{}, fmt.Errorf("fs.write: path required")
|
||||
}
|
||||
|
||||
if t.maxFileSize > 0 && int64(len(a.Content)) > t.maxFileSize {
|
||||
return tool.Result{Output: fmt.Sprintf("Error: content too large (%d bytes, limit %d bytes)", len(a.Content), t.maxFileSize)}, nil
|
||||
}
|
||||
|
||||
// Create parent directories
|
||||
dir := filepath.Dir(a.Path)
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
|
||||
Reference in New Issue
Block a user