diff --git a/internal/engine/loop.go b/internal/engine/loop.go index 1f3cb06..7be5db3 100644 --- a/internal/engine/loop.go +++ b/internal/engine/loop.go @@ -3,8 +3,10 @@ package engine import ( "context" "encoding/json" + "errors" "fmt" "sync" + "time" gnomactx "somegit.dev/Owlibou/gnoma/internal/context" "somegit.dev/Owlibou/gnoma/internal/message" @@ -88,7 +90,26 @@ func (e *Engine) runLoop(ctx context.Context, cb Callback) (*Turn, error) { s, err = e.cfg.Provider.Stream(ctx, req) } if err != nil { - return nil, fmt.Errorf("provider stream: %w", err) + // Retry on transient errors (429, 5xx) with exponential backoff + s, err = e.retryOnTransient(ctx, err, func() (stream.Stream, error) { + if e.cfg.Router != nil { + prompt := "" + for i := len(e.history) - 1; i >= 0; i-- { + if e.history[i].Role == message.RoleUser { + prompt = e.history[i].TextContent() + break + } + } + task := router.ClassifyTask(prompt) + task.EstimatedTokens = 4000 + s, _, retryErr := e.cfg.Router.Stream(ctx, task, req) + return s, retryErr + } + return e.cfg.Provider.Stream(ctx, req) + }) + if err != nil { + return nil, fmt.Errorf("provider stream: %w", err) + } } // Consume stream, forwarding events to callback @@ -320,6 +341,53 @@ func truncate(s string, maxLen int) string { return s[:maxLen] + "..." } +// retryOnTransient retries the stream call on 429/5xx with exponential backoff. +// Returns the original error if not retryable or all retries exhausted. 
+func (e *Engine) retryOnTransient(ctx context.Context, firstErr error, fn func() (stream.Stream, error)) (stream.Stream, error) {
+	// Only errors the provider marked retryable enter the retry loop.
+	var provErr *provider.ProviderError
+	if !errors.As(firstErr, &provErr) || !provErr.Retryable {
+		return nil, firstErr
+	}
+
+	// Exponential backoff schedule; its length is the retry budget.
+	backoff := [...]time.Duration{time.Second, 2 * time.Second, 4 * time.Second, 8 * time.Second}
+	// Retry-After hints at or above this cap fall back to the schedule.
+	const maxRetryAfter = 30 * time.Second
+
+	for attempt, delay := range backoff {
+		// provErr is the most recent retryable failure, so a Retry-After
+		// hint is honoured on every attempt rather than only the first.
+		if provErr.RetryAfter > 0 && provErr.RetryAfter < maxRetryAfter {
+			delay = provErr.RetryAfter
+		}
+		e.logger.Debug("retrying after transient error",
+			"attempt", attempt+1,
+			"delay", delay,
+			"status", provErr.StatusCode,
+		)
+
+		// Explicit timer so cancellation releases it instead of leaving it pending.
+		timer := time.NewTimer(delay)
+		select {
+		case <-timer.C:
+		case <-ctx.Done():
+			timer.Stop()
+			return nil, ctx.Err()
+		}
+		s, err := fn()
+		if err == nil {
+			return s, nil
+		}
+		// A failure that is not retryable ends the loop and is returned as-is.
+		if !errors.As(err, &provErr) || !provErr.Retryable {
+			return nil, err
+		}
+	}
+	// Retry budget exhausted: report the error that triggered the retries.
+	return nil, firstErr
+}
+
 // toolDefFromTool converts a tool.Tool to provider.ToolDefinition.
 // Unused currently but kept for reference when building tool definitions dynamically.
 func toolDefFromJSON(name, description string, params json.RawMessage) provider.ToolDefinition {
diff --git a/internal/tool/agent/batch.go b/internal/tool/agent/batch.go
index d5ea687..c0f385c 100644
--- a/internal/tool/agent/batch.go
+++ b/internal/tool/agent/batch.go
@@ -94,7 +94,7 @@ func (t *BatchTool) Execute(ctx context.Context, args json.RawMessage) (tool.Res
 	systemPrompt := "You are an elf — a focused sub-agent of gnoma. Complete the given task thoroughly and concisely. Use tools as needed."
- // Spawn all elfs + // Spawn all elfs with slight stagger to avoid rate limit bursts type elfEntry struct { elf elf.Elf desc string @@ -102,11 +102,22 @@ func (t *BatchTool) Execute(ctx context.Context, args json.RawMessage) (tool.Res } var elfs []elfEntry - for _, task := range a.Tasks { + for i, task := range a.Tasks { + // Stagger spawns to avoid hitting rate limits (e.g., Mistral's 1 req/s) + if i > 0 { + select { + case <-time.After(300 * time.Millisecond): + case <-ctx.Done(): + for _, entry := range elfs { + entry.elf.Cancel() + } + return tool.Result{Output: "cancelled during spawn"}, nil + } + } + taskType := parseTaskType(task.TaskType) e, err := t.manager.Spawn(ctx, taskType, task.Prompt, systemPrompt, maxTurns) if err != nil { - // Clean up already-spawned elfs for _, entry := range elfs { entry.elf.Cancel() } @@ -120,7 +131,6 @@ func (t *BatchTool) Execute(ctx context.Context, args json.RawMessage) (tool.Res elfs = append(elfs, elfEntry{elf: e, desc: desc, task: task}) - // Send initial progress t.sendProgress(elf.Progress{ ElfID: e.ID(), Description: desc,