16 Commits

Author SHA1 Message Date
3167966b98 chore: move module path to github.com/VikingOwl91/mistral-go-sdk
Public discoverability on pkg.go.dev. Also fixes stream tool call
test fixture to match real Mistral API responses (finish_reason, usage).
2026-04-03 12:01:11 +02:00
e22732aa7c docs: update README with Workflows API, correct method/test counts 2026-04-02 16:58:16 +02:00
6928b9f1c9 chore: bump version to v1.2.0, update changelog and docs 2026-04-02 16:52:26 +02:00
0ab8064a06 feat(batch): add DeleteBatchJob method 2026-04-02 16:51:51 +02:00
c5b0011e30 feat: add workflow runs, schedules, and workers service methods 2026-04-02 16:49:35 +02:00
dc30e09c77 feat: add workflow events, deployments, and metrics service methods 2026-04-02 16:47:34 +02:00
3b0530a409 feat: add workflow execution service methods and WorkflowEventStream
14 execution service methods (Get, History, Stream, Signal, Query,
Update, Terminate, Cancel, Reset, BatchCancel, BatchTerminate, TraceOTel,
TraceSummary, TraceEvents), WorkflowEventStream with envelope unwrapping,
and ExecuteWorkflowAndWait with isTerminal polling. Extends StreamPayload
with WorkflowContext and BrokerSequence fields.
2026-04-02 16:42:35 +02:00
29aa8e0de1 feat: add workflows CRUD and registration service methods 2026-04-02 16:36:29 +02:00
910970f45e feat(workflow): add deployment, metrics, run, schedule, and registration types 2026-04-02 16:31:01 +02:00
a699495fc2 feat(workflow): add sealed Event interface with 17 types and UnmarshalEvent
Implements EventType/EventSource/Scope enums, eventBase struct, 17 concrete
event types with typed attributes for Started/Completed/Failed, UnknownEvent
fallback, UnmarshalEvent dispatcher, and SSE envelope types.
2026-04-02 16:27:39 +02:00
a41bf39325 feat(workflow): add package scaffold, core CRUD types, and execution types 2026-04-02 16:24:08 +02:00
58712f8364 docs: add workflows API implementation plan
10-task TDD plan covering workflow/ types package, 8 service files,
WorkflowEventStream, sealed Event interface, DeleteBatchJob, version
bump, and changelog update.
2026-04-02 16:15:57 +02:00
b2a1f141e0 docs: add workflows API integration design spec
Design spec for integrating Mistral Python SDK v2.2.0 changes into the
Go SDK v1.2.0. Covers Workflows API (37 methods across 8 sub-resources),
DeleteBatchJob addition, sealed event interface with 17 variants, and
SSE streaming with StreamPayload envelope.
2026-04-02 16:05:35 +02:00
aa5c53c407 feat: v1.1.0 — sync with upstream Python SDK v2.1.3
Add Connectors, Audio Speech/Voices, Audio Realtime types,
and Observability (beta). 41 new service methods, 116 total.

Breaking: ListModels and UploadFile signatures changed
(pass nil for previous behavior).
2026-03-24 09:07:03 +01:00
b1f0fc4907 feat: v1.0.0 stable release — tracks upstream Python SDK v2.0.4 2026-03-17 11:34:09 +01:00
94a938b733 docs: add pkg.go.dev badge to README 2026-03-17 11:31:19 +01:00
101 changed files with 9124 additions and 75 deletions

View File

@@ -1,3 +1,92 @@
## v1.2.1 — 2026-04-03
Move module path to `github.com/VikingOwl91/mistral-go-sdk` for public
discoverability on pkg.go.dev.
### Changed
- Module path changed from `somegit.dev/vikingowl/mistral-go-sdk` to
`github.com/VikingOwl91/mistral-go-sdk`.
### Fixed
- `TestChatCompleteStream_WithToolCalls` fixture now includes `finish_reason`
and `usage` to match real Mistral API responses.
## v1.2.0 — 2026-04-02
Upstream sync with Python SDK v2.2.0. Adds Workflows API and DeleteBatchJob.
### Added
- **Workflows API** (new `workflow/` package) — complete workflow orchestration
support with 37 service methods across 8 sub-resources:
- **Workflows CRUD** — `ListWorkflows`, `GetWorkflow`, `UpdateWorkflow`,
`ArchiveWorkflow`, `UnarchiveWorkflow`, `ExecuteWorkflow`,
`ExecuteWorkflowAndWait`.
- **Registrations** — `ListWorkflowRegistrations`, `GetWorkflowRegistration`,
`ExecuteWorkflowRegistration` (deprecated).
- **Executions** — `GetWorkflowExecution`, `GetWorkflowExecutionHistory`,
`StreamWorkflowExecution`, `SignalWorkflowExecution`,
`QueryWorkflowExecution`, `UpdateWorkflowExecution`,
`TerminateWorkflowExecution`, `CancelWorkflowExecution`,
`ResetWorkflowExecution`, `BatchCancelWorkflowExecutions`,
`BatchTerminateWorkflowExecutions`.
- **Trace** — `GetWorkflowExecutionTraceOTel`,
`GetWorkflowExecutionTraceSummary`, `GetWorkflowExecutionTraceEvents`.
- **Events** — `StreamWorkflowEvents`, `ListWorkflowEvents`.
- **Deployments** — `ListWorkflowDeployments`, `GetWorkflowDeployment`.
- **Metrics** — `GetWorkflowMetrics`.
- **Runs** — `ListWorkflowRuns`, `GetWorkflowRun`, `GetWorkflowRunHistory`.
- **Schedules** — `ListWorkflowSchedules`, `ScheduleWorkflow`,
`UnscheduleWorkflow`.
- **Workers** — `GetWorkflowWorkerInfo`.
- **`WorkflowEventStream`** — typed SSE stream wrapper with `StreamPayload`
envelope, sealed `Event` interface (17 concrete types + `UnknownEvent`).
- **`DeleteBatchJob`** — delete a batch job by ID.
## v1.1.0 — 2026-03-24
Upstream sync with Python SDK v2.1.3. Adds Connectors, Audio Speech/Voices, and Observability (beta).
### Breaking Changes
- **`ListModels`** signature changed from `(ctx)` to `(ctx, *model.ListParams)`.
Pass `nil` for previous behavior. The new `ListParams` supports `Provider` and
`Model` query filters.
- **`UploadFile`** signature changed from `(ctx, filename, reader, purpose)` to
`(ctx, filename, reader, *file.UploadParams)`. The new `UploadParams` struct
holds `Purpose`, `Expiry`, and `Visibility` fields.
### Added
- **`ReasoningEffort`** field on `chat.CompletionRequest` and
`agents.CompletionRequest` — controls reasoning effort (`"none"`, `"high"`).
- **Connectors API** (new `connector/` package) — `CreateConnector`,
`ListConnectors`, `GetConnector`, `UpdateConnector`, `DeleteConnector`,
`GetConnectorAuthURL`, `ListConnectorTools`, `CallConnectorTool`.
- **Audio Speech (TTS)** — `Speech`, `SpeechStream` with `SpeechStream` typed
wrapper, `SpeechOutputFormat` enum (pcm/wav/mp3/flac/opus).
- **Audio Voices** — `ListVoices`, `CreateVoice`, `GetVoice`, `UpdateVoice`,
`DeleteVoice`, `GetVoiceSampleAudio`.
- **Audio Realtime types** — `AudioEncoding`, `AudioFormat`, `RealtimeSession`,
and WebSocket message types in `audio/realtime.go`. No WebSocket client yet
(would require adding a dependency).
- **Observability API** (new `observability/` package, beta) — campaigns,
chat completion events, judges, datasets, records, and import tasks.
33 service methods total.
- **`file.Visibility`** enum — `shared_global`, `shared_org`,
`shared_workspace`, `private`.
- **`model.ListParams`** — filter models by `Provider` and `Model`.
## v1.0.0 — 2026-03-17
Stable release. Tracks upstream Python SDK v2.0.4.
No API changes from v0.2.0. This release signals that the SDK surface is
stable and follows Go module semver conventions — breaking changes will only
occur in future major versions.
## v0.2.0 — 2026-03-17
Sync with upstream Python SDK v2.0.4. Upstream reference changed from OpenAPI

89
CLAUDE.md Normal file
View File

@@ -0,0 +1,89 @@
# CLAUDE.md
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
## Project
Idiomatic Go SDK for the Mistral AI API. Module path: `github.com/VikingOwl91/mistral-go-sdk`. Requires Go 1.26+. Zero external dependencies (stdlib only). Tracks the upstream [Mistral Python SDK](https://github.com/mistralai/client-python) as reference for API surface and type definitions.
## Repository layout
- **Working directory**: `mistral-go-sdk/` — the Go SDK source. All development happens here.
- **`../client-python/`**: Clone of the upstream Mistral Python SDK. Read-only reference — pull/update it when checking for upstream API changes, but never modify it.
## Commands
```bash
# Run all unit tests
go test ./...
# Run a single test
go test -run TestChatComplete_Success
# Run integration tests (requires MISTRAL_API_KEY env var)
go test -tags=integration ./...
# Vet and build
go vet ./...
go build ./...
```
No Makefile, linter config, or code generation tooling — standard `go test` / `go vet` / `go build`.
## Architecture
### Two-layer design: types in sub-packages, methods on `*Client`
Sub-packages (`chat/`, `agents/`, `conversation/`, `embedding/`, `model/`, `file/`, `finetune/`, `batch/`, `ocr/`, `audio/`, `library/`, `moderation/`, `classification/`, `fim/`, `connector/`, `observability/`, `workflow/`) are **types-only** — they define request/response structs and enums but contain no HTTP logic. All service methods live on `*Client` in the root package, prefix-namespaced by domain (e.g. `ChatComplete`, `AgentsComplete`, `CreateFineTuningJob`, `UploadFile`).
### HTTP internals (request.go)
All HTTP flows route through a small set of unexported helpers on `*Client`:
- `do()` — raw HTTP with auth headers + retry
- `doJSON()` — JSON marshal request → `do()` → unmarshal response
- `doStream()` — JSON request → raw `*http.Response` for SSE
- `doMultipart()` / `doMultipartStream()` — multipart file upload variants
- `doRetry()` — retry loop with exponential backoff + jitter + `Retry-After` parsing
### Streaming
Generic `Stream[T]` type wraps SSE (`sseReader`) with `Next()`/`Current()`/`Err()`/`Close()` iterator pattern. Typed wrappers `EventStream` (conversations) and `AudioStream` (transcription) unmarshal `json.RawMessage` into domain-specific event types.
### Sealed interfaces for discriminated unions
Polymorphic API types use **sealed interfaces** with unexported marker methods:
- `chat.Message` (marker: `isMessage()`) — `SystemMessage`, `UserMessage`, `AssistantMessage`, `ToolMessage`
- `chat.ContentChunk` (marker: `contentChunk()`) — `TextChunk`, `ImageURLChunk`, `DocumentURLChunk`, `FileChunk`, `ReferenceChunk`, `ThinkChunk`, `AudioChunk`, `ToolReferenceChunk`, `ToolFileChunk`
- `agents.AgentTool` (marker: `agentToolType()`) — `FunctionTool`, `WebSearchTool`, `CodeInterpreterTool`, `ConnectorTool`, etc.
- `conversation.Event` — conversation streaming events
Each has an `Unknown*` variant so the SDK doesn't break on new API types. Each has a `Unmarshal*` dispatch function that probes a `type`/`role` discriminator field.
### Custom JSON patterns
Several types require non-trivial marshal/unmarshal:
- **Type alias trick** — `type alias T` inside `MarshalJSON` to avoid infinite recursion when injecting a `type`/`role` discriminator field.
- **`json:"-"` + custom MarshalJSON** — `CompletionRequest.Messages` (and `stream`) are excluded from default marshaling and injected via custom `MarshalJSON`.
- **Union types** — `Content` handles `string | null | []ContentChunk`; `ToolChoice` handles `string | object`; `ImageURL` handles `string | object`; `FunctionCall.Arguments` handles `string | object`; `ReferenceID` handles `int | string` with type preservation.
- **Probe struct pattern** — `Unmarshal*` functions decode only the discriminator field first, then dispatch to the concrete type.
### Shared types in `chat/`
`GuardrailConfig`, `ModerationLLMV1Config`, `ModerationLLMV2Config` live in `chat/` because it's the base types package imported by both `agents/` and `conversation/`. This avoids import cycles.
### Error handling
`APIError` in `error.go` with sentinel checkers: `IsNotFound()`, `IsRateLimit()`, `IsAuth()`. All use `errors.As` for unwrapping.
## Testing patterns
- Unit tests use `httptest.NewServer` with inline handlers to mock the Mistral API. Client is pointed at the test server via `WithBaseURL(server.URL)`.
- Integration tests are behind `//go:build integration` build tag and require `MISTRAL_API_KEY`.
- Tests use stdlib `testing` only — no third-party test frameworks.
## Adding a new API endpoint
1. Define request/response types in the appropriate sub-package (or create a new one with a `doc.go`).
2. Add a method on `*Client` in the root package. Use `doJSON` for standard request/response, `doStream` for SSE, `doMultipart` for file uploads.
3. Add unit tests with `httptest.NewServer`.
4. If the endpoint supports streaming, return `*Stream[T]` and call `EnableStream()` on the request before sending.

View File

@@ -3,6 +3,7 @@
The most complete Go client for the [Mistral AI API](https://docs.mistral.ai/).
<!-- Badges -->
[![Go Reference](https://pkg.go.dev/badge/github.com/VikingOwl91/mistral-go-sdk.svg)](https://pkg.go.dev/github.com/VikingOwl91/mistral-go-sdk)
![Go Version](https://img.shields.io/badge/go-1.26-blue)
![License](https://img.shields.io/badge/license-MIT-green)
@@ -10,20 +11,20 @@ The most complete Go client for the [Mistral AI API](https://docs.mistral.ai/).
**Zero dependencies.** The entire SDK — including tests — uses only the Go standard library. No `go.sum`, no transitive dependency tree to audit, no version conflicts, no supply chain risk.
**Full API coverage.** 75 methods across every Mistral endpoint — including Conversations, Agents CRUD, Libraries, OCR, Audio, Fine-tuning, and Batch Jobs. No other Go SDK covers Conversations or Agents.
**Full API coverage.** 166 methods across every Mistral endpoint — including Workflows, Connectors, Audio Speech/Voices, Conversations, Agents CRUD, Libraries, OCR, Observability, Fine-tuning, and Batch Jobs. No other Go SDK covers Workflows, Conversations, Connectors, or Observability.
**Typed streaming.** A generic pull-based `Stream[T]` iterator — no channels, no goroutines, no leaks. Just `Next()` / `Current()` / `Err()` / `Close()`.
**Forward-compatible.** Unknown types (`UnknownEntry`, `UnknownEvent`, `UnknownMessage`, `UnknownChunk`, `UnknownAgentTool`) capture raw JSON instead of returning errors. When Mistral ships a new message role or event type, your code keeps running — it doesn't panic.
**Forward-compatible.** Unknown types (`UnknownEntry`, `UnknownEvent`, `UnknownMessage`, `UnknownChunk`, `UnknownAgentTool`, workflow `UnknownEvent`) capture raw JSON instead of returning errors. When Mistral ships a new message role or event type, your code keeps running — it doesn't panic.
**Hand-written, not generated.** Idiomatic Go with sealed interfaces, discriminated unions, and functional options — not a Speakeasy/OpenAPI auto-gen dump with `any` everywhere.
**Test-driven.** 193 tests with race detection clean. Every endpoint tested against mock servers; integration tests against the real API.
**Test-driven.** 284 tests with race detection clean. Every endpoint tested against mock servers; integration tests against the real API.
## Install
```sh
go get somegit.dev/vikingowl/mistral-go-sdk
go get github.com/VikingOwl91/mistral-go-sdk
```
## Quick Start
@@ -38,8 +39,8 @@ import (
"fmt"
"log"
mistral "somegit.dev/vikingowl/mistral-go-sdk"
"somegit.dev/vikingowl/mistral-go-sdk/chat"
mistral "github.com/VikingOwl91/mistral-go-sdk"
"github.com/VikingOwl91/mistral-go-sdk/chat"
)
func main() {
@@ -111,7 +112,7 @@ resp, err := client.ChatComplete(ctx, &chat.CompletionRequest{
### Conversations
```go
import "somegit.dev/vikingowl/mistral-go-sdk/conversation"
import "github.com/VikingOwl91/mistral-go-sdk/conversation"
resp, err := client.StartConversation(ctx, &conversation.StartRequest{
AgentID: "ag-your-agent-id",
@@ -131,7 +132,7 @@ for stream.Next() {
## API Coverage
75 public methods on `Client`, grouped by domain:
166 public methods on `Client`, grouped by domain:
| Domain | Methods |
|--------|---------|
@@ -139,17 +140,34 @@ for stream.Next() {
| **FIM** | `FIMComplete`, `FIMCompleteStream` |
| **Agents (completions)** | `AgentsComplete`, `AgentsCompleteStream` |
| **Agents (CRUD)** | `CreateAgent`, `ListAgents`, `GetAgent`, `UpdateAgent`, `DeleteAgent`, `UpdateAgentVersion`, `ListAgentVersions`, `GetAgentVersion`, `SetAgentAlias`, `ListAgentAliases`, `DeleteAgentAlias` |
| **Connectors** | `CreateConnector`, `ListConnectors`, `GetConnector`, `UpdateConnector`, `DeleteConnector`, `GetConnectorAuthURL`, `ListConnectorTools`, `CallConnectorTool` |
| **Conversations** | `StartConversation`, `StartConversationStream`, `AppendConversation`, `AppendConversationStream`, `RestartConversation`, `RestartConversationStream`, `GetConversation`, `ListConversations`, `DeleteConversation`, `GetConversationHistory`, `GetConversationMessages` |
| **Models** | `ListModels`, `GetModel`, `DeleteModel` |
| **Files** | `UploadFile`, `ListFiles`, `GetFile`, `DeleteFile`, `GetFileContent`, `GetFileURL` |
| **Embeddings** | `CreateEmbeddings` |
| **Fine-tuning** | `CreateFineTuningJob`, `ListFineTuningJobs`, `GetFineTuningJob`, `CancelFineTuningJob`, `StartFineTuningJob`, `UpdateFineTunedModel`, `ArchiveFineTunedModel`, `UnarchiveFineTunedModel` |
| **Batch** | `CreateBatchJob`, `ListBatchJobs`, `GetBatchJob`, `CancelBatchJob` |
| **Batch** | `CreateBatchJob`, `ListBatchJobs`, `GetBatchJob`, `CancelBatchJob`, `DeleteBatchJob` |
| **OCR** | `OCR` |
| **Audio** | `Transcribe`, `TranscribeStream` |
| **Audio (transcription)** | `Transcribe`, `TranscribeStream` |
| **Audio (speech)** | `Speech`, `SpeechStream` |
| **Audio (voices)** | `ListVoices`, `CreateVoice`, `GetVoice`, `UpdateVoice`, `DeleteVoice`, `GetVoiceSampleAudio` |
| **Libraries** | `CreateLibrary`, `ListLibraries`, `GetLibrary`, `UpdateLibrary`, `DeleteLibrary`, `UploadDocument`, `ListDocuments`, `GetDocument`, `UpdateDocument`, `DeleteDocument`, `GetDocumentTextContent`, `GetDocumentStatus`, `GetDocumentSignedURL`, `GetDocumentExtractedTextSignedURL`, `ReprocessDocument`, `ListLibrarySharing`, `ShareLibrary`, `UnshareLibrary` |
| **Moderation** | `Moderate`, `ModerateChat` |
| **Classification** | `Classify`, `ClassifyChat` |
| **Observability (campaigns)** | `CreateCampaign`, `ListCampaigns`, `GetCampaign`, `DeleteCampaign`, `GetCampaignStatus`, `ListCampaignEvents` |
| **Observability (events)** | `SearchChatCompletionEvents`, `SearchChatCompletionEventIDs`, `GetChatCompletionEvent`, `GetSimilarChatCompletionEvents`, `JudgeChatCompletionEvent` |
| **Observability (judges)** | `CreateJudge`, `ListJudges`, `GetJudge`, `UpdateJudge`, `DeleteJudge`, `JudgeConversation` |
| **Observability (datasets)** | `CreateDataset`, `ListDatasets`, `GetDataset`, `UpdateDataset`, `DeleteDataset`, `ExportDatasetToJSONL`, `ListDatasetRecords`, `CreateDatasetRecord`, `GetDatasetRecord`, `UpdateDatasetRecordPayload`, `UpdateDatasetRecordProperties`, `DeleteDatasetRecord`, `BulkDeleteDatasetRecords`, `JudgeDatasetRecord`, `ImportDatasetFromCampaign`, `ImportDatasetFromExplorer`, `ImportDatasetFromFile`, `ImportDatasetFromPlayground`, `ImportDatasetFromDataset`, `ListDatasetTasks`, `GetDatasetTask` |
| **Workflows (CRUD)** | `ListWorkflows`, `GetWorkflow`, `UpdateWorkflow`, `ArchiveWorkflow`, `UnarchiveWorkflow`, `ExecuteWorkflow`, `ExecuteWorkflowAndWait` |
| **Workflows (registrations)** | `ListWorkflowRegistrations`, `GetWorkflowRegistration`, `ExecuteWorkflowRegistration` |
| **Workflows (executions)** | `GetWorkflowExecution`, `GetWorkflowExecutionHistory`, `StreamWorkflowExecution`, `SignalWorkflowExecution`, `QueryWorkflowExecution`, `UpdateWorkflowExecution`, `TerminateWorkflowExecution`, `CancelWorkflowExecution`, `ResetWorkflowExecution`, `BatchCancelWorkflowExecutions`, `BatchTerminateWorkflowExecutions` |
| **Workflows (trace)** | `GetWorkflowExecutionTraceOTel`, `GetWorkflowExecutionTraceSummary`, `GetWorkflowExecutionTraceEvents` |
| **Workflows (events)** | `StreamWorkflowEvents`, `ListWorkflowEvents` |
| **Workflows (deployments)** | `ListWorkflowDeployments`, `GetWorkflowDeployment` |
| **Workflows (metrics)** | `GetWorkflowMetrics` |
| **Workflows (runs)** | `ListWorkflowRuns`, `GetWorkflowRun`, `GetWorkflowRunHistory` |
| **Workflows (schedules)** | `ListWorkflowSchedules`, `ScheduleWorkflow`, `UnscheduleWorkflow` |
| **Workflows (workers)** | `GetWorkflowWorkerInfo` |
## Comparison
@@ -162,11 +180,14 @@ There is no official Go SDK from Mistral AI (only Python and TypeScript). The ma
| Embeddings | Yes | Yes | Yes | Yes |
| Tool calling | Yes | No | No | No |
| Agents (completions + CRUD) | Yes | No | No | No |
| Connectors (MCP) | Yes | No | No | No |
| Conversations API | Yes | No | No | No |
| Libraries / Documents | Yes | No | No | No |
| Fine-tuning / Batch | Yes | No | No | No |
| OCR | Yes | No | No | Yes |
| Audio transcription | Yes | No | No | No |
| Audio (transcription + TTS + voices) | Yes | No | No | No |
| Workflows API | Yes | No | No | No |
| Observability (beta) | Yes | No | No | No |
| Moderation / Classification | Yes | No | No | No |
| Vision (multimodal) | Yes | No | No | Yes |
| Zero dependencies | Yes | test-only (testify) | test-only (testify) | test-only (testify) |
@@ -218,6 +239,12 @@ if err != nil {
This SDK tracks the [official Mistral Python SDK](https://github.com/mistralai/client-python)
as its upstream reference for API surface and type definitions.
| SDK Version | Upstream Python SDK |
|-------------|---------------------|
| v1.2.0 | v2.2.0 |
| v1.1.0 | v2.1.3 |
| v1.0.0 | v2.0.4 |
## License
[MIT](LICENSE)

View File

@@ -4,7 +4,7 @@ import (
"encoding/json"
"fmt"
"somegit.dev/vikingowl/mistral-go-sdk/chat"
"github.com/VikingOwl91/mistral-go-sdk/chat"
)
// AgentTool is a sealed interface for agent tool types.

View File

@@ -3,7 +3,7 @@ package agents
import (
"encoding/json"
"somegit.dev/vikingowl/mistral-go-sdk/chat"
"github.com/VikingOwl91/mistral-go-sdk/chat"
)
// CompletionRequest represents an agents completion request.
@@ -26,6 +26,7 @@ type CompletionRequest struct {
Prediction *chat.Prediction `json:"prediction,omitempty"`
PromptMode *chat.PromptMode `json:"prompt_mode,omitempty"`
Guardrails []chat.GuardrailConfig `json:"guardrails,omitempty"`
ReasoningEffort *chat.ReasoningEffort `json:"reasoning_effort,omitempty"`
stream bool
}

View File

@@ -3,8 +3,8 @@ package mistral
import (
"context"
"somegit.dev/vikingowl/mistral-go-sdk/agents"
"somegit.dev/vikingowl/mistral-go-sdk/chat"
"github.com/VikingOwl91/mistral-go-sdk/agents"
"github.com/VikingOwl91/mistral-go-sdk/chat"
)
// AgentsComplete sends an agents completion request.

View File

@@ -8,8 +8,8 @@ import (
"net/http/httptest"
"testing"
"somegit.dev/vikingowl/mistral-go-sdk/agents"
"somegit.dev/vikingowl/mistral-go-sdk/chat"
"github.com/VikingOwl91/mistral-go-sdk/agents"
"github.com/VikingOwl91/mistral-go-sdk/chat"
)
func TestAgentsComplete_Success(t *testing.T) {
@@ -114,6 +114,39 @@ func TestAgentsComplete_WithTools(t *testing.T) {
}
}
func TestAgentsComplete_ReasoningEffort(t *testing.T) {
effort := chat.ReasoningEffortHigh
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var body map[string]any
json.NewDecoder(r.Body).Decode(&body)
if body["reasoning_effort"] != "high" {
t.Errorf("expected reasoning_effort=high, got %v", body["reasoning_effort"])
}
json.NewEncoder(w).Encode(map[string]any{
"id": "a-re", "object": "chat.completion",
"model": "m", "created": 0,
"choices": []map[string]any{{
"index": 0, "message": map[string]any{"role": "assistant", "content": "ok"},
"finish_reason": "stop",
}},
"usage": map[string]any{"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
_, err := client.AgentsComplete(context.Background(), &agents.CompletionRequest{
AgentID: "agent-1",
Messages: []chat.Message{&chat.UserMessage{Content: chat.TextContent("Hi")}},
ReasoningEffort: &effort,
})
if err != nil {
t.Fatal(err)
}
}
func TestAgentsCompleteStream_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var body map[string]any

View File

@@ -6,7 +6,7 @@ import (
"net/url"
"strconv"
"somegit.dev/vikingowl/mistral-go-sdk/agents"
"github.com/VikingOwl91/mistral-go-sdk/agents"
)
// CreateAgent creates a new agent.

View File

@@ -7,7 +7,7 @@ import (
"net/http/httptest"
"testing"
"somegit.dev/vikingowl/mistral-go-sdk/agents"
"github.com/VikingOwl91/mistral-go-sdk/agents"
)
func TestCreateAgent_Success(t *testing.T) {

View File

@@ -1,5 +1,25 @@
// Package audio provides types for the Mistral audio transcription API.
// Package audio provides types for the Mistral audio APIs.
//
// # Transcription
//
// [TranscriptionRequest] and [TranscriptionResponse] handle speech-to-text.
// Streaming transcription returns typed [StreamEvent] values via a sealed
// interface dispatched by the "type" field.
//
// # Speech (TTS)
//
// [SpeechRequest] and [SpeechResponse] handle text-to-speech.
// Streaming speech returns typed [SpeechStreamEvent] values
// ([SpeechAudioDelta] and [SpeechDone]).
//
// # Voices
//
// [VoiceResponse], [VoiceCreateRequest], and [VoiceUpdateRequest] manage
// custom voices for speech synthesis.
//
// # Realtime
//
// Realtime transcription types ([AudioEncoding], [AudioFormat],
// [RealtimeSession], and WebSocket message types) are defined here.
// The WebSocket client is not yet implemented.
package audio

75
audio/realtime.go Normal file
View File

@@ -0,0 +1,75 @@
package audio
// AudioEncoding is the encoding format for realtime audio streams.
type AudioEncoding string
const (
EncodingPCMS16LE AudioEncoding = "pcm_s16le"
EncodingPCMS32LE AudioEncoding = "pcm_s32le"
EncodingPCMF16LE AudioEncoding = "pcm_f16le"
EncodingPCMF32LE AudioEncoding = "pcm_f32le"
EncodingPCMMulaw AudioEncoding = "pcm_mulaw"
EncodingPCMAlaw AudioEncoding = "pcm_alaw"
)
// AudioFormat describes the encoding and sample rate for realtime audio.
type AudioFormat struct {
Encoding AudioEncoding `json:"encoding"`
SampleRate int `json:"sample_rate"`
}
// RealtimeSession describes a realtime transcription session.
type RealtimeSession struct {
RequestID string `json:"request_id"`
Model string `json:"model"`
AudioFormat AudioFormat `json:"audio_format"`
TargetStreamingDelayMs *int `json:"target_streaming_delay_ms,omitempty"`
}
// RealtimeSessionUpdate is sent to update session parameters.
// Parameters can only be changed before audio transmission starts.
type RealtimeSessionUpdate struct {
AudioFormat *AudioFormat `json:"audio_format,omitempty"`
TargetStreamingDelayMs *int `json:"target_streaming_delay_ms,omitempty"`
}
// InputAudioAppend sends a chunk of audio data.
// Audio is base64-encoded (max 262144 bytes decoded).
type InputAudioAppend struct {
Type string `json:"type"` // "input_audio.append"
Audio string `json:"audio"`
}
// InputAudioFlush flushes the audio buffer.
type InputAudioFlush struct {
Type string `json:"type"` // "input_audio.flush"
}
// InputAudioEnd signals the end of audio input.
type InputAudioEnd struct {
Type string `json:"type"` // "input_audio.end"
}
// RealtimeSessionCreated is received when a session is created.
type RealtimeSessionCreated struct {
Type string `json:"type"` // "session.created"
Session RealtimeSession `json:"session"`
}
// RealtimeSessionUpdated is received when a session is updated.
type RealtimeSessionUpdated struct {
Type string `json:"type"` // "session.updated"
Session RealtimeSession `json:"session"`
}
// RealtimeErrorDetail describes a realtime error.
type RealtimeErrorDetail struct {
Message string `json:"message"`
Code int `json:"code"`
}
// RealtimeError is received on error.
type RealtimeError struct {
Type string `json:"type"` // "error"
Error RealtimeErrorDetail `json:"error"`
}

88
audio/speech.go Normal file
View File

@@ -0,0 +1,88 @@
package audio
import (
"encoding/json"
"fmt"
)
// SpeechOutputFormat is the output audio format for speech synthesis.
type SpeechOutputFormat string
const (
SpeechFormatPCM SpeechOutputFormat = "pcm"
SpeechFormatWAV SpeechOutputFormat = "wav"
SpeechFormatMP3 SpeechOutputFormat = "mp3"
SpeechFormatFLAC SpeechOutputFormat = "flac"
SpeechFormatOpus SpeechOutputFormat = "opus"
)
// SpeechRequest represents a text-to-speech request.
type SpeechRequest struct {
Input string `json:"input"`
Model string `json:"model"`
Metadata map[string]any `json:"metadata,omitempty"`
VoiceID *string `json:"voice_id,omitempty"`
RefAudio *string `json:"ref_audio,omitempty"`
ResponseFormat *SpeechOutputFormat `json:"response_format,omitempty"`
stream bool
}
// EnableStream is used internally to enable streaming.
func (r *SpeechRequest) EnableStream() { r.stream = true }
func (r *SpeechRequest) MarshalJSON() ([]byte, error) {
type Alias SpeechRequest
return json.Marshal(&struct {
Stream bool `json:"stream"`
*Alias
}{
Stream: r.stream,
Alias: (*Alias)(r),
})
}
// SpeechResponse is the response from a non-streaming speech request.
type SpeechResponse struct {
AudioData string `json:"audio_data"`
}
// SpeechStreamEvent is a sealed interface for speech streaming events.
type SpeechStreamEvent interface {
speechStreamEvent()
}
// SpeechAudioDelta contains a chunk of audio data during streaming.
type SpeechAudioDelta struct {
Type string `json:"type"`
AudioData string `json:"audio_data"`
}
func (*SpeechAudioDelta) speechStreamEvent() {}
// SpeechDone is emitted when speech synthesis is complete.
type SpeechDone struct {
Type string `json:"type"`
Usage UsageInfo `json:"usage"`
}
func (*SpeechDone) speechStreamEvent() {}
// UnmarshalSpeechStreamEvent dispatches a raw JSON event to the correct type.
func UnmarshalSpeechStreamEvent(data []byte) (SpeechStreamEvent, error) {
var probe struct {
Type string `json:"type"`
}
if err := json.Unmarshal(data, &probe); err != nil {
return nil, err
}
switch probe.Type {
case "speech.audio.delta":
var e SpeechAudioDelta
return &e, json.Unmarshal(data, &e)
case "speech.audio.done":
var e SpeechDone
return &e, json.Unmarshal(data, &e)
default:
return nil, fmt.Errorf("unknown speech stream event type: %q", probe.Type)
}
}

48
audio/voice.go Normal file
View File

@@ -0,0 +1,48 @@
package audio
// VoiceResponse represents a voice entity.
type VoiceResponse struct {
Name string `json:"name"`
ID string `json:"id"`
CreatedAt string `json:"created_at"`
UserID *string `json:"user_id,omitempty"`
Slug *string `json:"slug,omitempty"`
Languages []string `json:"languages,omitempty"`
Gender *string `json:"gender,omitempty"`
Age *int `json:"age,omitempty"`
Tags []string `json:"tags,omitempty"`
Color *string `json:"color,omitempty"`
RetentionNotice *int `json:"retention_notice,omitempty"`
}
// VoiceCreateRequest creates a custom voice.
type VoiceCreateRequest struct {
Name string `json:"name"`
SampleAudio string `json:"sample_audio"`
Slug *string `json:"slug,omitempty"`
Languages []string `json:"languages,omitempty"`
Gender *string `json:"gender,omitempty"`
Age *int `json:"age,omitempty"`
Tags []string `json:"tags,omitempty"`
Color *string `json:"color,omitempty"`
RetentionNotice *int `json:"retention_notice,omitempty"`
SampleFilename *string `json:"sample_filename,omitempty"`
}
// VoiceUpdateRequest updates a voice.
type VoiceUpdateRequest struct {
Name *string `json:"name,omitempty"`
Languages []string `json:"languages,omitempty"`
Gender *string `json:"gender,omitempty"`
Age *int `json:"age,omitempty"`
Tags []string `json:"tags,omitempty"`
}
// VoiceListResponse is the response from listing voices.
type VoiceListResponse struct {
Items []VoiceResponse `json:"items"`
Total int `json:"total"`
Page int `json:"page"`
PageSize int `json:"page_size"`
TotalPages int `json:"total_pages"`
}

View File

@@ -3,9 +3,11 @@ package mistral
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"somegit.dev/vikingowl/mistral-go-sdk/audio"
"github.com/VikingOwl91/mistral-go-sdk/audio"
)
// Transcribe sends an audio file for transcription.
@@ -95,3 +97,125 @@ func (s *AudioStream) Err() error { return s.err }
// Close releases the underlying connection.
func (s *AudioStream) Close() error { return s.stream.Close() }
// Speech sends a text-to-speech request and returns the full response.
func (c *Client) Speech(ctx context.Context, req *audio.SpeechRequest) (*audio.SpeechResponse, error) {
var resp audio.SpeechResponse
if err := c.doJSON(ctx, "POST", "/v1/audio/speech", req, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// SpeechStream sends a text-to-speech request and returns a stream of audio events.
func (c *Client) SpeechStream(ctx context.Context, req *audio.SpeechRequest) (*SpeechStream, error) {
req.EnableStream()
resp, err := c.doStream(ctx, "POST", "/v1/audio/speech", req)
if err != nil {
return nil, err
}
return newSpeechStream(resp.Body), nil
}
// SpeechStream wraps the generic Stream for speech streaming events.
type SpeechStream struct {
stream *Stream[json.RawMessage]
event audio.SpeechStreamEvent
err error
}
func newSpeechStream(body readCloser) *SpeechStream {
return &SpeechStream{
stream: newStream[json.RawMessage](body),
}
}
// Next advances to the next event. Returns false when done or on error.
func (s *SpeechStream) Next() bool {
if s.err != nil {
return false
}
if !s.stream.Next() {
s.err = s.stream.Err()
return false
}
event, err := audio.UnmarshalSpeechStreamEvent(s.stream.Current())
if err != nil {
s.err = err
return false
}
s.event = event
return true
}
// Current returns the most recently read event.
func (s *SpeechStream) Current() audio.SpeechStreamEvent { return s.event }
// Err returns any error encountered during streaming.
func (s *SpeechStream) Err() error { return s.err }
// Close releases the underlying connection.
func (s *SpeechStream) Close() error { return s.stream.Close() }
// ListVoices returns available voices.
func (c *Client) ListVoices(ctx context.Context) (*audio.VoiceListResponse, error) {
var resp audio.VoiceListResponse
if err := c.doJSON(ctx, "GET", "/v1/audio/voices", nil, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// CreateVoice creates a custom voice.
func (c *Client) CreateVoice(ctx context.Context, req *audio.VoiceCreateRequest) (*audio.VoiceResponse, error) {
var resp audio.VoiceResponse
if err := c.doJSON(ctx, "POST", "/v1/audio/voices", req, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// GetVoice retrieves a voice by ID.
func (c *Client) GetVoice(ctx context.Context, voiceID string) (*audio.VoiceResponse, error) {
var resp audio.VoiceResponse
if err := c.doJSON(ctx, "GET", fmt.Sprintf("/v1/audio/voices/%s", voiceID), nil, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// UpdateVoice updates a voice.
func (c *Client) UpdateVoice(ctx context.Context, voiceID string, req *audio.VoiceUpdateRequest) (*audio.VoiceResponse, error) {
var resp audio.VoiceResponse
if err := c.doJSON(ctx, "PATCH", fmt.Sprintf("/v1/audio/voices/%s", voiceID), req, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// DeleteVoice deletes a voice.
func (c *Client) DeleteVoice(ctx context.Context, voiceID string) error {
resp, err := c.do(ctx, "DELETE", fmt.Sprintf("/v1/audio/voices/%s", voiceID), nil)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode >= 400 {
return parseAPIError(resp)
}
return nil
}
// GetVoiceSampleAudio retrieves the sample audio for a voice.
// Returns the raw HTTP response; the caller must close the body.
func (c *Client) GetVoiceSampleAudio(ctx context.Context, voiceID string) (*http.Response, error) {
resp, err := c.do(ctx, "GET", fmt.Sprintf("/v1/audio/voices/%s/sample", voiceID), nil)
if err != nil {
return nil, err
}
if resp.StatusCode >= 400 {
defer resp.Body.Close()
return nil, parseAPIError(resp)
}
return resp, nil
}

217
audio_speech_test.go Normal file
View File

@@ -0,0 +1,217 @@
package mistral
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"github.com/VikingOwl91/mistral-go-sdk/audio"
)
func TestSpeech_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" {
t.Errorf("expected POST, got %s", r.Method)
}
if r.URL.Path != "/v1/audio/speech" {
t.Errorf("got path %s", r.URL.Path)
}
var body map[string]any
json.NewDecoder(r.Body).Decode(&body)
if body["input"] != "Hello world" {
t.Errorf("got input %v", body["input"])
}
if body["stream"] != false {
t.Errorf("expected stream=false")
}
json.NewEncoder(w).Encode(map[string]any{
"audio_data": "base64audiodata==",
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.Speech(context.Background(), &audio.SpeechRequest{
Input: "Hello world",
Model: "mistral-speech",
})
if err != nil {
t.Fatal(err)
}
if resp.AudioData != "base64audiodata==" {
t.Errorf("got audio_data %q", resp.AudioData)
}
}
func TestSpeechStream_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var body map[string]any
json.NewDecoder(r.Body).Decode(&body)
if body["stream"] != true {
t.Errorf("expected stream=true")
}
w.Header().Set("Content-Type", "text/event-stream")
flusher, _ := w.(http.Flusher)
delta, _ := json.Marshal(map[string]any{
"type": "speech.audio.delta", "audio_data": "chunk1==",
})
fmt.Fprintf(w, "data: %s\n\n", delta)
flusher.Flush()
done, _ := json.Marshal(map[string]any{
"type": "speech.audio.done",
"usage": map[string]any{
"prompt_tokens": 5, "completion_tokens": 10, "total_tokens": 15,
},
})
fmt.Fprintf(w, "data: %s\n\n", done)
flusher.Flush()
fmt.Fprint(w, "data: [DONE]\n\n")
flusher.Flush()
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
stream, err := client.SpeechStream(context.Background(), &audio.SpeechRequest{
Input: "Hi",
Model: "mistral-speech",
})
if err != nil {
t.Fatal(err)
}
defer stream.Close()
var events int
for stream.Next() {
events++
}
if err := stream.Err(); err != nil {
t.Fatal(err)
}
if events != 2 {
t.Errorf("got %d events, want 2", events)
}
}
func TestListVoices_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/audio/voices" {
t.Errorf("got path %s", r.URL.Path)
}
json.NewEncoder(w).Encode(map[string]any{
"items": []map[string]any{
{"id": "v1", "name": "Default", "created_at": "2025-01-01"},
},
"total": 1, "page": 1, "page_size": 10, "total_pages": 1,
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.ListVoices(context.Background())
if err != nil {
t.Fatal(err)
}
if len(resp.Items) != 1 {
t.Fatalf("got %d voices", len(resp.Items))
}
if resp.Items[0].ID != "v1" {
t.Errorf("got id %q", resp.Items[0].ID)
}
}
func TestCreateVoice_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" {
t.Errorf("expected POST, got %s", r.Method)
}
var body map[string]any
json.NewDecoder(r.Body).Decode(&body)
if body["name"] != "MyVoice" {
t.Errorf("got name %v", body["name"])
}
json.NewEncoder(w).Encode(map[string]any{
"id": "v2", "name": "MyVoice", "created_at": "2025-01-01",
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.CreateVoice(context.Background(), &audio.VoiceCreateRequest{
Name: "MyVoice",
SampleAudio: "base64audio==",
})
if err != nil {
t.Fatal(err)
}
if resp.ID != "v2" {
t.Errorf("got id %q", resp.ID)
}
}
func TestGetVoice_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/audio/voices/v1" {
t.Errorf("got path %s", r.URL.Path)
}
json.NewEncoder(w).Encode(map[string]any{
"id": "v1", "name": "Default", "created_at": "2025-01-01",
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.GetVoice(context.Background(), "v1")
if err != nil {
t.Fatal(err)
}
if resp.Name != "Default" {
t.Errorf("got name %q", resp.Name)
}
}
func TestUpdateVoice_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "PATCH" {
t.Errorf("expected PATCH, got %s", r.Method)
}
json.NewEncoder(w).Encode(map[string]any{
"id": "v1", "name": "Renamed", "created_at": "2025-01-01",
})
}))
defer server.Close()
name := "Renamed"
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.UpdateVoice(context.Background(), "v1", &audio.VoiceUpdateRequest{
Name: &name,
})
if err != nil {
t.Fatal(err)
}
if resp.Name != "Renamed" {
t.Errorf("got name %q", resp.Name)
}
}
func TestDeleteVoice_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "DELETE" {
t.Errorf("expected DELETE, got %s", r.Method)
}
w.WriteHeader(http.StatusNoContent)
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
err := client.DeleteVoice(context.Background(), "v1")
if err != nil {
t.Fatal(err)
}
}

View File

@@ -8,7 +8,7 @@ import (
"strings"
"testing"
"somegit.dev/vikingowl/mistral-go-sdk/audio"
"github.com/VikingOwl91/mistral-go-sdk/audio"
)
func TestTranscribe_WithFileURL(t *testing.T) {

View File

@@ -60,3 +60,10 @@ type ListParams struct {
Status []string
OrderBy *string
}
// DeleteResponse is the response from deleting a batch job.
type DeleteResponse struct {
ID string `json:"id"`
Object string `json:"object"`
Deleted bool `json:"deleted"`
}

View File

@@ -7,7 +7,7 @@ import (
"strconv"
"strings"
"somegit.dev/vikingowl/mistral-go-sdk/batch"
"github.com/VikingOwl91/mistral-go-sdk/batch"
)
// CreateBatchJob creates a new batch inference job.
@@ -76,3 +76,12 @@ func (c *Client) CancelBatchJob(ctx context.Context, jobID string) (*batch.JobOu
}
return &resp, nil
}
// DeleteBatchJob deletes a batch job.
func (c *Client) DeleteBatchJob(ctx context.Context, jobID string) (*batch.DeleteResponse, error) {
var resp batch.DeleteResponse
if err := c.doJSON(ctx, "DELETE", fmt.Sprintf("/v1/batch/jobs/%s", jobID), nil, &resp); err != nil {
return nil, err
}
return &resp, nil
}

View File

@@ -7,7 +7,7 @@ import (
"net/http/httptest"
"testing"
"somegit.dev/vikingowl/mistral-go-sdk/batch"
"github.com/VikingOwl91/mistral-go-sdk/batch"
)
func TestCreateBatchJob_Success(t *testing.T) {
@@ -121,3 +121,30 @@ func TestCancelBatchJob_Success(t *testing.T) {
t.Errorf("got status %q", job.Status)
}
}
func TestDeleteBatchJob_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "DELETE" {
t.Errorf("expected DELETE, got %s", r.Method)
}
if r.URL.Path != "/v1/batch/jobs/batch-123" {
t.Errorf("got path %s", r.URL.Path)
}
json.NewEncoder(w).Encode(map[string]any{
"id": "batch-123", "object": "batch", "deleted": true,
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.DeleteBatchJob(context.Background(), "batch-123")
if err != nil {
t.Fatal(err)
}
if resp.ID != "batch-123" {
t.Errorf("got id %q", resp.ID)
}
if !resp.Deleted {
t.Error("expected deleted=true")
}
}

View File

@@ -9,6 +9,14 @@ const (
PromptModeReasoning PromptMode = "reasoning"
)
// ReasoningEffort controls the amount of reasoning effort the model uses.
type ReasoningEffort string
const (
ReasoningEffortNone ReasoningEffort = "none"
ReasoningEffortHigh ReasoningEffort = "high"
)
// Prediction provides expected completion content for optimization.
type Prediction struct {
Type string `json:"type"`
@@ -36,6 +44,7 @@ type CompletionRequest struct {
Prediction *Prediction `json:"prediction,omitempty"`
PromptMode *PromptMode `json:"prompt_mode,omitempty"`
Guardrails []GuardrailConfig `json:"guardrails,omitempty"`
ReasoningEffort *ReasoningEffort `json:"reasoning_effort,omitempty"`
stream bool
}

View File

@@ -3,7 +3,7 @@ package mistral
import (
"context"
"somegit.dev/vikingowl/mistral-go-sdk/chat"
"github.com/VikingOwl91/mistral-go-sdk/chat"
)
// ChatComplete sends a chat completion request and returns the full response.

View File

@@ -8,7 +8,7 @@ import (
"net/http/httptest"
"testing"
"somegit.dev/vikingowl/mistral-go-sdk/chat"
"github.com/VikingOwl91/mistral-go-sdk/chat"
)
func TestChatComplete_Success(t *testing.T) {
@@ -350,6 +350,41 @@ func TestChatComplete_RequestBody(t *testing.T) {
}
}
func TestChatComplete_ReasoningEffort(t *testing.T) {
effort := chat.ReasoningEffortHigh
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
bodyBytes, _ := io.ReadAll(r.Body)
var body map[string]any
json.Unmarshal(bodyBytes, &body)
if body["reasoning_effort"] != "high" {
t.Errorf("expected reasoning_effort=high, got %v", body["reasoning_effort"])
}
json.NewEncoder(w).Encode(map[string]any{
"id": "chat-re", "object": "chat.completion",
"model": "m", "created": 0,
"choices": []map[string]any{{
"index": 0, "message": map[string]any{"role": "assistant", "content": "ok"},
"finish_reason": "stop",
}},
"usage": map[string]any{"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
_, err := client.ChatComplete(context.Background(), &chat.CompletionRequest{
Model: "m",
Messages: []chat.Message{&chat.UserMessage{Content: chat.TextContent("Hi")}},
ReasoningEffort: &effort,
})
if err != nil {
t.Fatal(err)
}
}
func TestChatComplete_ContextCanceled(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Never responds — context should cancel first

View File

@@ -9,7 +9,7 @@ import (
"strings"
"testing"
"somegit.dev/vikingowl/mistral-go-sdk/chat"
"github.com/VikingOwl91/mistral-go-sdk/chat"
)
func TestChatCompleteStream_Success(t *testing.T) {
@@ -165,6 +165,7 @@ func TestChatCompleteStream_WithToolCalls(t *testing.T) {
w.Header().Set("Content-Type", "text/event-stream")
flusher, _ := w.(http.Flusher)
toolCalls := chat.FinishReasonToolCalls
chunk := chat.CompletionChunk{
ID: "c",
Model: "m",
@@ -177,7 +178,9 @@ func TestChatCompleteStream_WithToolCalls(t *testing.T) {
Function: chat.FunctionCall{Name: "get_weather", Arguments: `{"city":"Paris"}`},
}},
},
FinishReason: &toolCalls,
}},
Usage: &chat.UsageInfo{PromptTokens: 10, CompletionTokens: 5, TotalTokens: 15},
}
data, _ := json.Marshal(chunk)
fmt.Fprintf(w, "data: %s\n\n", data)
@@ -199,10 +202,17 @@ func TestChatCompleteStream_WithToolCalls(t *testing.T) {
if !stream.Next() {
t.Fatalf("expected chunk, err: %v", stream.Err())
}
tc := stream.Current().Choices[0].Delta.ToolCalls
cur := stream.Current()
tc := cur.Choices[0].Delta.ToolCalls
if len(tc) != 1 || tc[0].Function.Name != "get_weather" {
t.Errorf("got tool calls %+v", tc)
}
if cur.Choices[0].FinishReason == nil || *cur.Choices[0].FinishReason != chat.FinishReasonToolCalls {
t.Errorf("expected finish_reason tool_calls, got %v", cur.Choices[0].FinishReason)
}
if cur.Usage == nil || cur.Usage.TotalTokens != 15 {
t.Errorf("expected usage with total_tokens=15, got %+v", cur.Usage)
}
}
func TestChatCompleteStream_APIError(t *testing.T) {

View File

@@ -1,6 +1,6 @@
package classification
import "somegit.dev/vikingowl/mistral-go-sdk/chat"
import "github.com/VikingOwl91/mistral-go-sdk/chat"
// Request represents a text classification request (/v1/classifications).
type Request struct {

100
connector/connector.go Normal file
View File

@@ -0,0 +1,100 @@
package connector
import "encoding/json"
// Visibility controls who can see a connector or tool.
type Visibility string
const (
VisibilitySharedGlobal Visibility = "shared_global"
VisibilitySharedOrg Visibility = "shared_org"
VisibilitySharedWorkspace Visibility = "shared_workspace"
VisibilityPrivate Visibility = "private"
)
// AuthData holds OAuth2 client credentials for a connector.
type AuthData struct {
ClientID string `json:"client_id"`
ClientSecret string `json:"client_secret"`
}
// Connector represents a registered MCP connector.
type Connector struct {
ID string `json:"id"`
Name string `json:"name"`
Description string `json:"description"`
CreatedAt string `json:"created_at"`
ModifiedAt string `json:"modified_at"`
Server *string `json:"server,omitempty"`
AuthType *string `json:"auth_type,omitempty"`
}
// CreateRequest creates a new connector.
type CreateRequest struct {
Name string `json:"name"`
Description string `json:"description"`
Server string `json:"server"`
IconURL *string `json:"icon_url,omitempty"`
Visibility *Visibility `json:"visibility,omitempty"`
Headers map[string]string `json:"headers,omitempty"`
AuthData *AuthData `json:"auth_data,omitempty"`
SystemPrompt *string `json:"system_prompt,omitempty"`
}
// UpdateRequest updates an existing connector.
type UpdateRequest struct {
Name *string `json:"name,omitempty"`
Description *string `json:"description,omitempty"`
IconURL *string `json:"icon_url,omitempty"`
SystemPrompt *string `json:"system_prompt,omitempty"`
Server *string `json:"server,omitempty"`
Headers map[string]string `json:"headers,omitempty"`
AuthData *AuthData `json:"auth_data,omitempty"`
ConnectionConfig map[string]any `json:"connection_config,omitempty"`
ConnectionSecrets map[string]any `json:"connection_secrets,omitempty"`
}
// AuthURLResponse is the response from getting a connector's OAuth URL.
type AuthURLResponse struct {
AuthURL string `json:"auth_url"`
TTL int `json:"ttl"`
}
// CallToolRequest is the request body for calling a connector tool.
type CallToolRequest struct {
Arguments map[string]any `json:"arguments,omitempty"`
}
// CallToolResponse is the response from calling a connector tool.
// Content is left as raw JSON because the upstream API returns a union
// of 5 content types (text, image, audio, resource link, embedded resource).
type CallToolResponse struct {
Content json.RawMessage `json:"content"`
Metadata map[string]any `json:"metadata,omitempty"`
}
// Tool represents a tool exposed by a connector.
type Tool struct {
ID string `json:"id"`
Name string `json:"name"`
Description *string `json:"description,omitempty"`
Visibility Visibility `json:"visibility,omitempty"`
CreatedAt string `json:"created_at,omitempty"`
ModifiedAt string `json:"modified_at,omitempty"`
SystemPrompt *string `json:"system_prompt,omitempty"`
JsonSchema map[string]any `json:"jsonschema,omitempty"`
Active *bool `json:"active,omitempty"`
}
// ListParams holds query parameters for listing connectors.
type ListParams struct {
Page *int
PageSize *int
}
// ListToolsParams holds query parameters for listing connector tools.
type ListToolsParams struct {
Page *int
PageSize *int
Refresh *bool
}

6
connector/doc.go Normal file
View File

@@ -0,0 +1,6 @@
// Package connector provides types for the Mistral connectors API.
//
// Connectors represent MCP (Model Context Protocol) server integrations.
// Use [CreateRequest] to register a new connector, then use tools
// discovered via the list-tools endpoint in chat or agent completions.
package connector

119
connectors.go Normal file
View File

@@ -0,0 +1,119 @@
package mistral
import (
"context"
"fmt"
"net/url"
"strconv"
"github.com/VikingOwl91/mistral-go-sdk/connector"
)
// CreateConnector registers a new MCP connector.
func (c *Client) CreateConnector(ctx context.Context, req *connector.CreateRequest) (*connector.Connector, error) {
var resp connector.Connector
if err := c.doJSON(ctx, "POST", "/v1/connectors", req, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// ListConnectors returns all connectors.
func (c *Client) ListConnectors(ctx context.Context, params *connector.ListParams) ([]connector.Connector, error) {
path := "/v1/connectors"
if params != nil {
q := url.Values{}
if params.Page != nil {
q.Set("page", strconv.Itoa(*params.Page))
}
if params.PageSize != nil {
q.Set("page_size", strconv.Itoa(*params.PageSize))
}
if encoded := q.Encode(); encoded != "" {
path += "?" + encoded
}
}
var resp []connector.Connector
if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil {
return nil, err
}
return resp, nil
}
// GetConnector retrieves a connector by ID or name.
func (c *Client) GetConnector(ctx context.Context, idOrName string) (*connector.Connector, error) {
var resp connector.Connector
if err := c.doJSON(ctx, "GET", fmt.Sprintf("/v1/connectors/%s", idOrName), nil, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// UpdateConnector updates an existing connector.
func (c *Client) UpdateConnector(ctx context.Context, idOrName string, req *connector.UpdateRequest) (*connector.Connector, error) {
var resp connector.Connector
if err := c.doJSON(ctx, "PATCH", fmt.Sprintf("/v1/connectors/%s", idOrName), req, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// DeleteConnector deletes a connector.
func (c *Client) DeleteConnector(ctx context.Context, idOrName string) error {
resp, err := c.do(ctx, "DELETE", fmt.Sprintf("/v1/connectors/%s", idOrName), nil)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode >= 400 {
return parseAPIError(resp)
}
return nil
}
// GetConnectorAuthURL returns the OAuth2 authorization URL for a connector.
func (c *Client) GetConnectorAuthURL(ctx context.Context, idOrName string, appReturnURL *string) (*connector.AuthURLResponse, error) {
path := fmt.Sprintf("/v1/connectors/%s/auth_url", idOrName)
if appReturnURL != nil {
path += "?app_return_url=" + url.QueryEscape(*appReturnURL)
}
var resp connector.AuthURLResponse
if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// ListConnectorTools lists tools exposed by a connector.
func (c *Client) ListConnectorTools(ctx context.Context, idOrName string, params *connector.ListToolsParams) ([]connector.Tool, error) {
path := fmt.Sprintf("/v1/connectors/%s/tools", idOrName)
if params != nil {
q := url.Values{}
if params.Page != nil {
q.Set("page", strconv.Itoa(*params.Page))
}
if params.PageSize != nil {
q.Set("page_size", strconv.Itoa(*params.PageSize))
}
if params.Refresh != nil {
q.Set("refresh", strconv.FormatBool(*params.Refresh))
}
if encoded := q.Encode(); encoded != "" {
path += "?" + encoded
}
}
var resp []connector.Tool
if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil {
return nil, err
}
return resp, nil
}
// CallConnectorTool invokes a tool on a connector.
func (c *Client) CallConnectorTool(ctx context.Context, idOrName, toolName string, req *connector.CallToolRequest) (*connector.CallToolResponse, error) {
var resp connector.CallToolResponse
if err := c.doJSON(ctx, "POST", fmt.Sprintf("/v1/connectors/%s/tools/%s/call", idOrName, toolName), req, &resp); err != nil {
return nil, err
}
return &resp, nil
}

217
connectors_test.go Normal file
View File

@@ -0,0 +1,217 @@
package mistral
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/VikingOwl91/mistral-go-sdk/connector"
)
func TestCreateConnector_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" {
t.Errorf("expected POST, got %s", r.Method)
}
if r.URL.Path != "/v1/connectors" {
t.Errorf("got path %s", r.URL.Path)
}
var body map[string]any
json.NewDecoder(r.Body).Decode(&body)
if body["name"] != "my_connector" {
t.Errorf("got name %v", body["name"])
}
if body["server"] != "https://mcp.example.com" {
t.Errorf("got server %v", body["server"])
}
json.NewEncoder(w).Encode(map[string]any{
"id": "conn-1", "name": "my_connector",
"description": "test", "created_at": "2025-01-01",
"modified_at": "2025-01-01", "server": "https://mcp.example.com",
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.CreateConnector(context.Background(), &connector.CreateRequest{
Name: "my_connector",
Description: "test",
Server: "https://mcp.example.com",
})
if err != nil {
t.Fatal(err)
}
if resp.ID != "conn-1" {
t.Errorf("got id %q", resp.ID)
}
}
func TestListConnectors_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "GET" {
t.Errorf("expected GET, got %s", r.Method)
}
json.NewEncoder(w).Encode([]map[string]any{
{"id": "c1", "name": "conn1", "description": "d1", "created_at": "t", "modified_at": "t"},
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
list, err := client.ListConnectors(context.Background(), nil)
if err != nil {
t.Fatal(err)
}
if len(list) != 1 {
t.Fatalf("got %d connectors", len(list))
}
if list[0].ID != "c1" {
t.Errorf("got id %q", list[0].ID)
}
}
func TestGetConnector_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/connectors/my_conn" {
t.Errorf("got path %s", r.URL.Path)
}
json.NewEncoder(w).Encode(map[string]any{
"id": "c1", "name": "my_conn", "description": "d",
"created_at": "t", "modified_at": "t",
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
c, err := client.GetConnector(context.Background(), "my_conn")
if err != nil {
t.Fatal(err)
}
if c.Name != "my_conn" {
t.Errorf("got name %q", c.Name)
}
}
func TestUpdateConnector_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "PATCH" {
t.Errorf("expected PATCH, got %s", r.Method)
}
json.NewEncoder(w).Encode(map[string]any{
"id": "c1", "name": "updated", "description": "new desc",
"created_at": "t", "modified_at": "t",
})
}))
defer server.Close()
name := "updated"
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.UpdateConnector(context.Background(), "c1", &connector.UpdateRequest{
Name: &name,
})
if err != nil {
t.Fatal(err)
}
if resp.Name != "updated" {
t.Errorf("got name %q", resp.Name)
}
}
func TestDeleteConnector_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "DELETE" {
t.Errorf("expected DELETE, got %s", r.Method)
}
w.WriteHeader(http.StatusNoContent)
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
err := client.DeleteConnector(context.Background(), "c1")
if err != nil {
t.Fatal(err)
}
}
func TestGetConnectorAuthURL_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/connectors/c1/auth_url" {
t.Errorf("got path %s", r.URL.Path)
}
json.NewEncoder(w).Encode(map[string]any{
"auth_url": "https://oauth.example.com/authorize",
"ttl": 3600,
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.GetConnectorAuthURL(context.Background(), "c1", nil)
if err != nil {
t.Fatal(err)
}
if resp.AuthURL != "https://oauth.example.com/authorize" {
t.Errorf("got auth_url %q", resp.AuthURL)
}
if resp.TTL != 3600 {
t.Errorf("got ttl %d", resp.TTL)
}
}
func TestListConnectorTools_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/connectors/c1/tools" {
t.Errorf("got path %s", r.URL.Path)
}
json.NewEncoder(w).Encode([]map[string]any{
{"id": "t1", "name": "search", "description": "search the web"},
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
tools, err := client.ListConnectorTools(context.Background(), "c1", nil)
if err != nil {
t.Fatal(err)
}
if len(tools) != 1 {
t.Fatalf("got %d tools", len(tools))
}
if tools[0].Name != "search" {
t.Errorf("got name %q", tools[0].Name)
}
}
func TestCallConnectorTool_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" {
t.Errorf("expected POST, got %s", r.Method)
}
if r.URL.Path != "/v1/connectors/c1/tools/search/call" {
t.Errorf("got path %s", r.URL.Path)
}
var body map[string]any
json.NewDecoder(r.Body).Decode(&body)
args := body["arguments"].(map[string]any)
if args["query"] != "hello" {
t.Errorf("got query %v", args["query"])
}
json.NewEncoder(w).Encode(map[string]any{
"content": []map[string]any{{"type": "text", "text": "result"}},
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.CallConnectorTool(context.Background(), "c1", "search", &connector.CallToolRequest{
Arguments: map[string]any{"query": "hello"},
})
if err != nil {
t.Fatal(err)
}
if resp.Content == nil {
t.Error("expected non-nil content")
}
}

View File

@@ -4,7 +4,7 @@ import (
"encoding/json"
"fmt"
"somegit.dev/vikingowl/mistral-go-sdk/chat"
"github.com/VikingOwl91/mistral-go-sdk/chat"
)
// HandoffExecution controls tool call execution.

View File

@@ -4,7 +4,7 @@ import (
"encoding/json"
"fmt"
"somegit.dev/vikingowl/mistral-go-sdk/chat"
"github.com/VikingOwl91/mistral-go-sdk/chat"
)
// Entry is a sealed interface for conversation history entries.

View File

@@ -3,7 +3,7 @@ package conversation
import (
"encoding/json"
"somegit.dev/vikingowl/mistral-go-sdk/chat"
"github.com/VikingOwl91/mistral-go-sdk/chat"
)
// StartRequest starts a new conversation.

View File

@@ -7,7 +7,7 @@ import (
"net/url"
"strconv"
"somegit.dev/vikingowl/mistral-go-sdk/conversation"
"github.com/VikingOwl91/mistral-go-sdk/conversation"
)
// StartConversation creates and starts a new conversation.

View File

@@ -8,7 +8,7 @@ import (
"net/http/httptest"
"testing"
"somegit.dev/vikingowl/mistral-go-sdk/conversation"
"github.com/VikingOwl91/mistral-go-sdk/conversation"
)
func TestStartConversation_Success(t *testing.T) {

6
doc.go
View File

@@ -39,9 +39,9 @@
// # Sub-packages
//
// Types are organized into sub-packages by domain: [chat], [agents],
// [conversation], [embedding], [model], [file], [finetune], [batch],
// [ocr], [audio], [library], [moderation], [classification], and [fim].
// All service methods live directly on [Client].
// [connector], [conversation], [embedding], [model], [file], [finetune],
// [batch], [ocr], [audio], [library], [moderation], [classification],
// [fim], and [observability]. All service methods live directly on [Client].
//
// # Reference
//

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,650 @@
# Workflows API Integration — Design Spec
**Date:** 2026-04-02
**Upstream:** Mistral Python SDK v2.2.0 (released 2026-03-31)
**SDK version:** v1.2.0
**Scope:** Full parity with Python SDK v2.2.0 changes since v2.1.3
## Summary
Add the Workflows API (37 new methods) and `DeleteBatchJob` (1 method) to the Go SDK.
This is purely additive — no breaking changes to existing API surface.
## New Package: `workflow/`
Types-only package following the two-layer architecture. 8 type files + `doc.go`.
### `workflow/doc.go`
Package documentation.
### `workflow/workflow.go` — Core CRUD types
```go
type Workflow struct {
ID string `json:"id"`
Name string `json:"name"`
DisplayName *string `json:"display_name,omitempty"`
Description *string `json:"description,omitempty"`
OwnerID string `json:"owner_id"`
WorkspaceID string `json:"workspace_id"`
AvailableInChatAssistant bool `json:"available_in_chat_assistant"`
Archived bool `json:"archived"`
CreatedAt string `json:"created_at"`
UpdatedAt string `json:"updated_at"`
}
type WorkflowUpdateRequest struct {
DisplayName *string `json:"display_name,omitempty"`
Description *string `json:"description,omitempty"`
AvailableInChatAssistant *bool `json:"available_in_chat_assistant,omitempty"`
}
type WorkflowListResponse struct {
Workflows []Workflow `json:"workflows"`
NextCursor *string `json:"next_cursor,omitempty"`
}
type WorkflowListParams struct {
ActiveOnly *bool
IncludeShared *bool
AvailableInChatAssistant *bool
Archived *bool
Cursor *string
Limit *int
}
type WorkflowArchiveResponse struct {
ID string `json:"id"`
Archived bool `json:"archived"`
}
```
### `workflow/execution.go` — Execution types
```go
type ExecutionStatus string
const (
ExecutionRunning ExecutionStatus = "RUNNING"
ExecutionCompleted ExecutionStatus = "COMPLETED"
ExecutionFailed ExecutionStatus = "FAILED"
ExecutionCanceled ExecutionStatus = "CANCELED"
ExecutionTerminated ExecutionStatus = "TERMINATED"
ExecutionContinuedAsNew ExecutionStatus = "CONTINUED_AS_NEW"
ExecutionTimedOut ExecutionStatus = "TIMED_OUT"
ExecutionRetryingAfterErr ExecutionStatus = "RETRYING_AFTER_ERROR"
)
type ExecutionRequest struct {
ExecutionID *string `json:"execution_id,omitempty"`
Input map[string]any `json:"input,omitempty"`
EncodedInput *NetworkEncodedInput `json:"encoded_input,omitempty"`
WaitForResult bool `json:"wait_for_result,omitempty"`
TimeoutSeconds *float64 `json:"timeout_seconds,omitempty"`
CustomTracingAttributes map[string]string `json:"custom_tracing_attributes,omitempty"`
DeploymentName *string `json:"deployment_name,omitempty"`
}
type ExecutionResponse struct {
WorkflowName string `json:"workflow_name"`
ExecutionID string `json:"execution_id"`
RootExecutionID string `json:"root_execution_id"`
Status ExecutionStatus `json:"status"`
StartTime string `json:"start_time"`
EndTime *string `json:"end_time,omitempty"`
Result any `json:"result,omitempty"`
ParentExecutionID *string `json:"parent_execution_id,omitempty"`
TotalDurationMs *int `json:"total_duration_ms,omitempty"`
}
type NetworkEncodedInput struct {
B64Payload string `json:"b64payload"`
EncodingOptions []string `json:"encoding_options,omitempty"`
Empty bool `json:"empty,omitempty"`
}
type SignalInvocationBody struct {
Name string `json:"name"`
Input any `json:"input"`
}
type SignalResponse struct {
Message string `json:"message"` // default: "Signal accepted"
}
type QueryInvocationBody struct {
Name string `json:"name"`
Input any `json:"input,omitempty"`
}
type QueryResponse struct {
QueryName string `json:"query_name"`
Result any `json:"result"`
}
type UpdateInvocationBody struct {
Name string `json:"name"`
Input any `json:"input,omitempty"`
}
type UpdateResponse struct {
UpdateName string `json:"update_name"`
Result any `json:"result"`
}
// Trace response types
type TraceOTelResponse struct {
WorkflowName string `json:"workflow_name"`
ExecutionID string `json:"execution_id"`
RootExecutionID string `json:"root_execution_id"`
Status *ExecutionStatus `json:"status"`
StartTime string `json:"start_time"`
EndTime *string `json:"end_time,omitempty"`
Result any `json:"result"`
DataSource string `json:"data_source"`
ParentExecutionID *string `json:"parent_execution_id,omitempty"`
TotalDurationMs *int `json:"total_duration_ms,omitempty"`
OTelTraceID *string `json:"otel_trace_id,omitempty"`
OTelTraceData any `json:"otel_trace_data,omitempty"`
}
type TraceSummaryResponse struct {
WorkflowName string `json:"workflow_name"`
ExecutionID string `json:"execution_id"`
RootExecutionID string `json:"root_execution_id"`
Status *ExecutionStatus `json:"status"`
StartTime string `json:"start_time"`
EndTime *string `json:"end_time,omitempty"`
Result any `json:"result"`
ParentExecutionID *string `json:"parent_execution_id,omitempty"`
TotalDurationMs *int `json:"total_duration_ms,omitempty"`
SpanTree any `json:"span_tree,omitempty"`
}
type TraceEventsResponse struct {
WorkflowName string `json:"workflow_name"`
ExecutionID string `json:"execution_id"`
RootExecutionID string `json:"root_execution_id"`
Status *ExecutionStatus `json:"status"`
StartTime string `json:"start_time"`
EndTime *string `json:"end_time,omitempty"`
Result any `json:"result"`
ParentExecutionID *string `json:"parent_execution_id,omitempty"`
TotalDurationMs *int `json:"total_duration_ms,omitempty"`
Events []json.RawMessage `json:"events,omitempty"`
}
type TraceEventsParams struct {
MergeSameIDEvents *bool
IncludeInternalEvents *bool
}
type ResetInvocationBody struct {
EventID int `json:"event_id"`
Reason *string `json:"reason,omitempty"`
ExcludeSignals bool `json:"exclude_signals,omitempty"`
ExcludeUpdates bool `json:"exclude_updates,omitempty"`
}
type BatchExecutionBody struct {
ExecutionIDs []string `json:"execution_ids"`
}
type BatchExecutionResponse struct {
Results map[string]BatchExecutionResult `json:"results,omitempty"`
}
type BatchExecutionResult struct {
Status string `json:"status"`
Error *string `json:"error,omitempty"`
}
type StreamParams struct {
EventSource *EventSource
LastEventID *string
}
```
### `workflow/event.go` — Sealed event interface + 17 variants
Discriminator field: `event_type`
```go
type Event interface {
workflowEvent()
EventType() EventType
}
type EventType string
const (
EventWorkflowStarted EventType = "WORKFLOW_EXECUTION_STARTED"
EventWorkflowCompleted EventType = "WORKFLOW_EXECUTION_COMPLETED"
EventWorkflowFailed EventType = "WORKFLOW_EXECUTION_FAILED"
EventWorkflowCanceled EventType = "WORKFLOW_EXECUTION_CANCELED"
EventWorkflowContinuedAsNew EventType = "WORKFLOW_EXECUTION_CONTINUED_AS_NEW"
EventWorkflowTaskTimedOut EventType = "WORKFLOW_TASK_TIMED_OUT"
EventWorkflowTaskFailed EventType = "WORKFLOW_TASK_FAILED"
EventCustomTaskStarted EventType = "CUSTOM_TASK_STARTED"
EventCustomTaskInProgress EventType = "CUSTOM_TASK_IN_PROGRESS"
EventCustomTaskCompleted EventType = "CUSTOM_TASK_COMPLETED"
EventCustomTaskFailed EventType = "CUSTOM_TASK_FAILED"
EventCustomTaskTimedOut EventType = "CUSTOM_TASK_TIMED_OUT"
EventCustomTaskCanceled EventType = "CUSTOM_TASK_CANCELED"
EventActivityTaskStarted EventType = "ACTIVITY_TASK_STARTED"
EventActivityTaskCompleted EventType = "ACTIVITY_TASK_COMPLETED"
EventActivityTaskRetrying EventType = "ACTIVITY_TASK_RETRYING"
EventActivityTaskFailed EventType = "ACTIVITY_TASK_FAILED"
)
type EventSource string
const (
EventSourceDatabase EventSource = "DATABASE"
EventSourceLive EventSource = "LIVE"
)
type Scope string
const (
ScopeActivity Scope = "activity"
ScopeWorkflow Scope = "workflow"
ScopeAll Scope = "*"
)
```
Each concrete event type has common fields + type-specific attributes:
```go
// Common fields embedded in all event types
type eventBase struct {
ID string `json:"event_id"`
Timestamp int64 `json:"event_timestamp"`
RootWorkflowExecID string `json:"root_workflow_exec_id"`
ParentWorkflowExecID *string `json:"parent_workflow_exec_id"`
WorkflowExecID string `json:"workflow_exec_id"`
WorkflowRunID string `json:"workflow_run_id"`
WorkflowName string `json:"workflow_name"`
}
// Example concrete types:
type WorkflowExecutionStartedEvent struct {
eventBase
Attributes WorkflowStartedAttributes `json:"attributes"`
}
func (WorkflowExecutionStartedEvent) workflowEvent() {}
func (WorkflowExecutionStartedEvent) EventType() EventType { return EventWorkflowStarted }
type WorkflowExecutionCompletedEvent struct {
eventBase
Attributes WorkflowCompletedAttributes `json:"attributes"`
}
// ... pattern repeats for all 17 types
type UnknownEvent struct {
eventBase
RawType string
Raw json.RawMessage
}
```
SSE envelope types:
```go
type StreamPayload struct {
Stream string `json:"stream"`
Data json.RawMessage `json:"data"`
WorkflowContext StreamWorkflowContext `json:"workflow_context"`
BrokerSequence int `json:"broker_sequence"`
Timestamp *string `json:"timestamp,omitempty"`
Metadata map[string]any `json:"metadata,omitempty"`
}
type StreamWorkflowContext struct {
Namespace string `json:"namespace"`
WorkflowName string `json:"workflow_name"`
WorkflowExecID string `json:"workflow_exec_id"`
ParentWorkflowExecID *string `json:"parent_workflow_exec_id,omitempty"`
RootWorkflowExecID *string `json:"root_workflow_exec_id,omitempty"`
}
func UnmarshalEvent(data json.RawMessage) (Event, error)
// Probes event_type discriminator, dispatches to concrete type.
// Unknown event_type returns UnknownEvent.
```
### `workflow/deployment.go`
```go
type Deployment struct {
ID string `json:"id"`
Name string `json:"name"`
IsActive bool `json:"is_active"`
CreatedAt string `json:"created_at"`
UpdatedAt string `json:"updated_at"`
}
type DeploymentListResponse struct {
Deployments []Deployment `json:"deployments"`
}
type DeploymentListParams struct {
ActiveOnly *bool
WorkflowName *string
}
```
### `workflow/metrics.go`
```go
type Metrics struct {
ExecutionCount ScalarMetric `json:"execution_count"`
SuccessCount ScalarMetric `json:"success_count"`
ErrorCount ScalarMetric `json:"error_count"`
AverageLatencyMs ScalarMetric `json:"average_latency_ms"`
LatencyOverTime TimeSeriesMetric `json:"latency_over_time"`
RetryRate ScalarMetric `json:"retry_rate"`
}
type ScalarMetric struct {
Value float64 `json:"value"`
}
type TimeSeriesMetric struct {
Value [][]float64 `json:"value"`
}
type MetricsParams struct {
StartTime *string
EndTime *string
}
```
### `workflow/run.go`
```go
type Run struct {
ID string `json:"id"`
WorkflowName string `json:"workflow_name"`
ExecutionID string `json:"execution_id"`
Status ExecutionStatus `json:"status"`
StartTime string `json:"start_time"`
EndTime *string `json:"end_time,omitempty"`
}
type ListRunsResponse struct {
Runs []Run `json:"runs"`
NextPageToken *string `json:"next_page_token,omitempty"`
}
type RunListParams struct {
WorkflowIdentifier *string
Search *string
Status *string
PageSize *int
NextPageToken *string
}
```
### `workflow/schedule.go`
```go
type ScheduleRequest struct {
Schedule ScheduleDefinition `json:"schedule"`
WorkflowRegistrationID *string `json:"workflow_registration_id,omitempty"`
WorkflowIdentifier *string `json:"workflow_identifier,omitempty"`
ScheduleID *string `json:"schedule_id,omitempty"`
DeploymentName *string `json:"deployment_name,omitempty"`
}
type ScheduleDefinition struct {
Input any `json:"input"`
Calendars []ScheduleCalendar `json:"calendars,omitempty"`
Intervals []ScheduleInterval `json:"intervals,omitempty"`
CronExpressions []string `json:"cron_expressions,omitempty"`
Skip []ScheduleCalendar `json:"skip,omitempty"`
StartAt *string `json:"start_at,omitempty"`
EndAt *string `json:"end_at,omitempty"`
Jitter *string `json:"jitter,omitempty"`
TimeZoneName *string `json:"time_zone_name,omitempty"`
Policy *SchedulePolicy `json:"policy,omitempty"`
}
type ScheduleCalendar struct {
Second []ScheduleRange `json:"second,omitempty"`
Minute []ScheduleRange `json:"minute,omitempty"`
Hour []ScheduleRange `json:"hour,omitempty"`
DayOfMonth []ScheduleRange `json:"day_of_month,omitempty"`
Month []ScheduleRange `json:"month,omitempty"`
Year []ScheduleRange `json:"year,omitempty"`
DayOfWeek []ScheduleRange `json:"day_of_week,omitempty"`
Comment *string `json:"comment,omitempty"`
}
type ScheduleRange struct {
Start int `json:"start"`
End int `json:"end,omitempty"`
Step int `json:"step,omitempty"`
}
type ScheduleInterval struct {
Every string `json:"every"`
Offset *string `json:"offset,omitempty"`
}
type SchedulePolicy struct {
CatchupWindowSeconds int `json:"catchup_window_seconds,omitempty"`
Overlap *int `json:"overlap,omitempty"`
PauseOnFailure bool `json:"pause_on_failure,omitempty"`
}
type ScheduleResponse struct {
ScheduleID string `json:"schedule_id"`
}
type ScheduleListResponse struct {
Schedules []Schedule `json:"schedules"`
}
type Schedule struct {
ScheduleID string `json:"schedule_id"`
Definition ScheduleDefinition `json:"definition"`
WorkflowName string `json:"workflow_name"`
CreatedAt string `json:"created_at"`
UpdatedAt string `json:"updated_at"`
}
```
### `workflow/registration.go`
```go
type Registration struct {
ID string `json:"id"`
WorkflowID string `json:"workflow_id"`
TaskQueue string `json:"task_queue"`
Workflow *Workflow `json:"workflow,omitempty"`
CreatedAt string `json:"created_at"`
UpdatedAt string `json:"updated_at"`
}
type RegistrationListResponse struct {
Registrations []Registration `json:"registrations"`
NextCursor *string `json:"next_cursor,omitempty"`
}
type RegistrationListParams struct {
WorkflowID *string
TaskQueue *string
ActiveOnly *bool
IncludeShared *bool
WorkflowSearch *string
Archived *bool
WithWorkflow *bool
AvailableInChatAssistant *bool
Limit *int
Cursor *string
}
type RegistrationGetParams struct {
WithWorkflow *bool
IncludeShared *bool
}
type WorkerInfo struct {
SchedulerURL string `json:"scheduler_url"`
Namespace string `json:"namespace"`
TLS bool `json:"tls"`
}
```
## Service Methods (root package)
### `workflows.go` — 10 methods
| Method | HTTP | Path |
|--------|------|------|
| `ListWorkflows` | GET | `/v1/workflows` |
| `GetWorkflow` | GET | `/v1/workflows/{id}` |
| `UpdateWorkflow` | PUT | `/v1/workflows/{id}` |
| `ArchiveWorkflow` | PUT | `/v1/workflows/{id}/archive` |
| `UnarchiveWorkflow` | PUT | `/v1/workflows/{id}/unarchive` |
| `ExecuteWorkflow` | POST | `/v1/workflows/{id}/execute` |
| `ListWorkflowRegistrations` | GET | `/v1/workflows/registrations` |
| `GetWorkflowRegistration` | GET | `/v1/workflows/registrations/{id}` |
| `ExecuteWorkflowRegistration` | POST | `/v1/workflows/registrations/{id}/execute` |
| `ExecuteWorkflowAndWait` | (composite) | execute + poll |
`ExecuteWorkflowRegistration` is deprecated (doc comment only).
`ExecuteWorkflowAndWait` calls `ExecuteWorkflow`, then polls `GetWorkflowExecution`
in a loop until status is terminal or context is canceled.
### `workflows_executions.go` — 14 methods
| Method | HTTP | Path |
|--------|------|------|
| `GetWorkflowExecution` | GET | `/v1/workflows/executions/{id}` |
| `GetWorkflowExecutionHistory` | GET | `/v1/workflows/executions/{id}/history` |
| `StreamWorkflowExecution` | GET (SSE) | `/v1/workflows/executions/{id}/stream` |
| `SignalWorkflowExecution` | POST | `/v1/workflows/executions/{id}/signals` |
| `QueryWorkflowExecution` | POST | `/v1/workflows/executions/{id}/queries` |
| `UpdateWorkflowExecution` | POST | `/v1/workflows/executions/{id}/updates` |
| `TerminateWorkflowExecution` | POST (204) | `/v1/workflows/executions/{id}/terminate` |
| `CancelWorkflowExecution` | POST (204) | `/v1/workflows/executions/{id}/cancel` |
| `ResetWorkflowExecution` | POST (204) | `/v1/workflows/executions/{id}/reset` |
| `BatchCancelWorkflowExecutions` | POST | `/v1/workflows/executions/cancel` |
| `BatchTerminateWorkflowExecutions` | POST | `/v1/workflows/executions/terminate` |
| `GetWorkflowExecutionTraceOTel` | GET | `/v1/workflows/executions/{id}/trace/otel` |
| `GetWorkflowExecutionTraceSummary` | GET | `/v1/workflows/executions/{id}/trace/summary` |
| `GetWorkflowExecutionTraceEvents` | GET | `/v1/workflows/executions/{id}/trace/events` |
Also contains `WorkflowEventStream` type (wraps `Stream[json.RawMessage]`,
dispatches via `workflow.UnmarshalEvent`).
### `workflows_events.go` — 2 methods
| Method | HTTP | Path |
|--------|------|------|
| `StreamWorkflowEvents` | GET (SSE) | `/v1/workflows/events/stream` |
| `ListWorkflowEvents` | GET | `/v1/workflows/events/list` |
### `workflows_deployments.go` — 2 methods
| Method | HTTP | Path |
|--------|------|------|
| `ListWorkflowDeployments` | GET | `/v1/workflows/deployments` |
| `GetWorkflowDeployment` | GET | `/v1/workflows/deployments/{id}` |
### `workflows_metrics.go` — 1 method
| Method | HTTP | Path |
|--------|------|------|
| `GetWorkflowMetrics` | GET | `/v1/workflows/{name}/metrics` |
### `workflows_runs.go` — 3 methods
| Method | HTTP | Path |
|--------|------|------|
| `ListWorkflowRuns` | GET | `/v1/workflows/runs` |
| `GetWorkflowRun` | GET | `/v1/workflows/runs/{id}` |
| `GetWorkflowRunHistory` | GET | `/v1/workflows/runs/{id}/history` |
### `workflows_schedules.go` — 3 methods
| Method | HTTP | Path |
|--------|------|------|
| `ListWorkflowSchedules` | GET | `/v1/workflows/schedules` |
| `ScheduleWorkflow` | POST | `/v1/workflows/schedules` |
| `UnscheduleWorkflow` | DELETE | `/v1/workflows/schedules/{id}` |
### `workflows_workers.go` — 1 method
| Method | HTTP | Path |
|--------|------|------|
| `GetWorkflowWorkerInfo` | GET | `/v1/workflows/workers/whoami` |
### `batch_api.go` — 1 new method
| Method | HTTP | Path |
|--------|------|------|
| `DeleteBatchJob` | DELETE | `/v1/batch/jobs/{id}` |
New type in `batch/`: `DeleteResponse { ID, Object, Deleted }`.
## Streaming Design
`WorkflowEventStream` wraps `Stream[json.RawMessage]` like `EventStream` does for conversations.
SSE data arrives as `StreamPayload` envelope:
```json
{
"stream": "...",
"data": { "event_type": "WORKFLOW_EXECUTION_COMPLETED", ... },
"workflow_context": { ... },
"broker_sequence": 42,
"timestamp": "...",
"metadata": {}
}
```
`WorkflowEventStream.Next()`:
1. Read next SSE `data:` line via inner `Stream[json.RawMessage]`
2. Unmarshal into `workflow.StreamPayload`
3. Dispatch `payload.Data` via `workflow.UnmarshalEvent` (probes `event_type`)
4. Expose both `Current() workflow.Event` and `CurrentPayload() *workflow.StreamPayload`
Both `StreamWorkflowExecution` and `StreamWorkflowEvents` use GET (not POST)
with SSE response. They use `doStream` without a request body — the stream method
needs to support GET + query params (verify `doStream` handles nil body for GET).
## Testing
One test file per service file. `httptest.NewServer` with inline handlers. Stdlib only.
Key scenarios:
- Query param encoding for list/filter endpoints
- PUT body marshaling for update/archive
- 204 no-body responses for terminate/cancel/reset
- 202 response for signal
- SSE streaming with StreamPayload envelope + event type dispatch
- UnknownEvent forward compatibility
- ExecuteWorkflowAndWait polling loop (mock multiple get-execution responses)
- Sealed interface UnmarshalEvent for all 17 event types + unknown
- Batch operations with map response
## Version & Docs
- Bump version constant in `mistral.go` to `1.2.0`
- Update `CLAUDE.md` sub-packages list to include `workflow/`
- Update `CHANGELOG.md` with v1.2.0 entry
- Upstream sync reference: Python SDK v2.2.0
## Non-Goals
- No pagination helpers (cursor chaining) — callers manage pagination manually, same as existing endpoints
- No traceparent injection hook — Go callers manage their own tracing headers
- No `execute_workflow_and_wait_async` — Go has context cancellation instead
- No WebSocket/realtime workflow support (not in Python SDK either)

View File

@@ -1,6 +1,6 @@
package embedding
import "somegit.dev/vikingowl/mistral-go-sdk/chat"
import "github.com/VikingOwl91/mistral-go-sdk/chat"
// Dtype specifies the data type of output embeddings.
type Dtype string

View File

@@ -3,7 +3,7 @@ package mistral
import (
"context"
"somegit.dev/vikingowl/mistral-go-sdk/embedding"
"github.com/VikingOwl91/mistral-go-sdk/embedding"
)
// CreateEmbeddings sends an embedding request and returns the response.

View File

@@ -7,7 +7,7 @@ import (
"net/http/httptest"
"testing"
"somegit.dev/vikingowl/mistral-go-sdk/embedding"
"github.com/VikingOwl91/mistral-go-sdk/embedding"
)
func TestCreateEmbeddings_Success(t *testing.T) {

View File

@@ -5,9 +5,9 @@ import (
"fmt"
"log"
mistral "somegit.dev/vikingowl/mistral-go-sdk"
"somegit.dev/vikingowl/mistral-go-sdk/chat"
"somegit.dev/vikingowl/mistral-go-sdk/embedding"
mistral "github.com/VikingOwl91/mistral-go-sdk"
"github.com/VikingOwl91/mistral-go-sdk/chat"
"github.com/VikingOwl91/mistral-go-sdk/embedding"
)
func ExampleNewClient() {

View File

@@ -64,6 +64,23 @@ type SignedURL struct {
URL string `json:"url"`
}
// Visibility controls who can see a file.
type Visibility string
const (
VisibilitySharedGlobal Visibility = "shared_global"
VisibilitySharedOrg Visibility = "shared_org"
VisibilitySharedWorkspace Visibility = "shared_workspace"
VisibilityPrivate Visibility = "private"
)
// UploadParams holds parameters for uploading a file.
type UploadParams struct {
Purpose Purpose
Expiry *int
Visibility *Visibility
}
// ListParams holds optional parameters for listing files.
type ListParams struct {
Page *int

View File

@@ -8,14 +8,22 @@ import (
"net/url"
"strconv"
"somegit.dev/vikingowl/mistral-go-sdk/file"
"github.com/VikingOwl91/mistral-go-sdk/file"
)
// UploadFile uploads a file for use with fine-tuning, batch, or OCR.
func (c *Client) UploadFile(ctx context.Context, filename string, r io.Reader, purpose file.Purpose) (*file.File, error) {
func (c *Client) UploadFile(ctx context.Context, filename string, r io.Reader, params *file.UploadParams) (*file.File, error) {
fields := map[string]string{}
if purpose != "" {
fields["purpose"] = string(purpose)
if params != nil {
if params.Purpose != "" {
fields["purpose"] = string(params.Purpose)
}
if params.Expiry != nil {
fields["expiry"] = strconv.Itoa(*params.Expiry)
}
if params.Visibility != nil {
fields["visibility"] = string(*params.Visibility)
}
}
var resp file.File
if err := c.doMultipart(ctx, "/v1/files", filename, r, fields, &resp); err != nil {

View File

@@ -9,7 +9,7 @@ import (
"strings"
"testing"
"somegit.dev/vikingowl/mistral-go-sdk/file"
"github.com/VikingOwl91/mistral-go-sdk/file"
)
func TestUploadFile_Success(t *testing.T) {
@@ -58,7 +58,7 @@ func TestUploadFile_Success(t *testing.T) {
context.Background(),
"train.jsonl",
strings.NewReader(`{"text":"hello"}`),
file.PurposeFineTune,
&file.UploadParams{Purpose: file.PurposeFineTune},
)
if err != nil {
t.Fatal(err)
@@ -74,6 +74,43 @@ func TestUploadFile_Success(t *testing.T) {
}
}
func TestUploadFile_WithExpiryAndVisibility(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if err := r.ParseMultipartForm(10 << 20); err != nil {
t.Fatal(err)
}
if r.FormValue("purpose") != "fine-tune" {
t.Errorf("got purpose %q", r.FormValue("purpose"))
}
if r.FormValue("expiry") != "48" {
t.Errorf("expected expiry=48, got %q", r.FormValue("expiry"))
}
if r.FormValue("visibility") != "private" {
t.Errorf("expected visibility=private, got %q", r.FormValue("visibility"))
}
json.NewEncoder(w).Encode(map[string]any{
"id": "file-ev", "object": "file", "bytes": 10,
"created_at": 1, "filename": "data.jsonl",
"purpose": "fine-tune", "sample_type": "instruct",
"source": "upload",
})
}))
defer server.Close()
expiry := 48
vis := file.VisibilityPrivate
client := NewClient("key", WithBaseURL(server.URL))
_, err := client.UploadFile(context.Background(), "data.jsonl", strings.NewReader("{}"), &file.UploadParams{
Purpose: file.PurposeFineTune,
Expiry: &expiry,
Visibility: &vis,
})
if err != nil {
t.Fatal(err)
}
}
func TestListFiles_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "GET" {
@@ -260,7 +297,7 @@ func TestUploadFile_Error(t *testing.T) {
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
_, err := client.UploadFile(context.Background(), "bad.txt", strings.NewReader(""), file.PurposeFineTune)
_, err := client.UploadFile(context.Background(), "bad.txt", strings.NewReader(""), &file.UploadParams{Purpose: file.PurposeFineTune})
if err == nil {
t.Fatal("expected error")
}

View File

@@ -3,8 +3,8 @@ package mistral
import (
"context"
"somegit.dev/vikingowl/mistral-go-sdk/chat"
"somegit.dev/vikingowl/mistral-go-sdk/fim"
"github.com/VikingOwl91/mistral-go-sdk/chat"
"github.com/VikingOwl91/mistral-go-sdk/fim"
)
// FIMComplete sends a Fill-In-the-Middle completion request.

View File

@@ -8,8 +8,8 @@ import (
"net/http/httptest"
"testing"
"somegit.dev/vikingowl/mistral-go-sdk/chat"
"somegit.dev/vikingowl/mistral-go-sdk/fim"
"github.com/VikingOwl91/mistral-go-sdk/chat"
"github.com/VikingOwl91/mistral-go-sdk/fim"
)
func TestFIMComplete_Success(t *testing.T) {

View File

@@ -7,7 +7,7 @@ import (
"net/url"
"strconv"
"somegit.dev/vikingowl/mistral-go-sdk/finetune"
"github.com/VikingOwl91/mistral-go-sdk/finetune"
)
// CreateFineTuningJob creates a new fine-tuning job.

View File

@@ -7,7 +7,7 @@ import (
"net/http/httptest"
"testing"
"somegit.dev/vikingowl/mistral-go-sdk/finetune"
"github.com/VikingOwl91/mistral-go-sdk/finetune"
)
func TestCreateFineTuningJob_Success(t *testing.T) {

2
go.mod
View File

@@ -1,3 +1,3 @@
module somegit.dev/vikingowl/mistral-go-sdk
module github.com/VikingOwl91/mistral-go-sdk
go 1.26

View File

@@ -8,8 +8,8 @@ import (
"strings"
"testing"
"somegit.dev/vikingowl/mistral-go-sdk/chat"
"somegit.dev/vikingowl/mistral-go-sdk/embedding"
"github.com/VikingOwl91/mistral-go-sdk/chat"
"github.com/VikingOwl91/mistral-go-sdk/embedding"
)
func integrationClient(t *testing.T) *Client {
@@ -24,7 +24,7 @@ func integrationClient(t *testing.T) *Client {
func TestIntegration_ListModels(t *testing.T) {
client := integrationClient(t)
resp, err := client.ListModels(context.Background())
resp, err := client.ListModels(context.Background(), nil)
if err != nil {
t.Fatal(err)
}

View File

@@ -7,7 +7,7 @@ import (
"net/url"
"strconv"
"somegit.dev/vikingowl/mistral-go-sdk/library"
"github.com/VikingOwl91/mistral-go-sdk/library"
)
// CreateLibrary creates a new document library.

View File

@@ -8,7 +8,7 @@ import (
"strings"
"testing"
"somegit.dev/vikingowl/mistral-go-sdk/library"
"github.com/VikingOwl91/mistral-go-sdk/library"
)
func newLibraryJSON() map[string]any {

View File

@@ -6,7 +6,7 @@ import (
)
// Version is the SDK version string.
const Version = "0.2.0"
const Version = "1.2.1"
const (
defaultBaseURL = "https://api.mistral.ai"

View File

@@ -52,3 +52,9 @@ type DeleteModelOut struct {
Object string `json:"object"`
Deleted bool `json:"deleted"`
}
// ListParams holds optional parameters for listing models.
type ListParams struct {
Provider *string
Model *string
}

View File

@@ -2,14 +2,28 @@ package mistral
import (
"context"
"net/url"
"somegit.dev/vikingowl/mistral-go-sdk/model"
"github.com/VikingOwl91/mistral-go-sdk/model"
)
// ListModels returns a list of available models.
func (c *Client) ListModels(ctx context.Context) (*model.ModelList, error) {
func (c *Client) ListModels(ctx context.Context, params *model.ListParams) (*model.ModelList, error) {
path := "/v1/models"
if params != nil {
q := url.Values{}
if params.Provider != nil {
q.Set("provider", *params.Provider)
}
if params.Model != nil {
q.Set("model", *params.Model)
}
if encoded := q.Encode(); encoded != "" {
path += "?" + encoded
}
}
var resp model.ModelList
if err := c.doJSON(ctx, "GET", "/v1/models", nil, &resp); err != nil {
if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil {
return nil, err
}
return &resp, nil

View File

@@ -6,6 +6,8 @@ import (
"net/http"
"net/http/httptest"
"testing"
"github.com/VikingOwl91/mistral-go-sdk/model"
)
func TestListModels_Success(t *testing.T) {
@@ -45,7 +47,7 @@ func TestListModels_Success(t *testing.T) {
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
list, err := client.ListModels(context.Background())
list, err := client.ListModels(context.Background(), nil)
if err != nil {
t.Fatal(err)
}
@@ -82,6 +84,33 @@ func TestListModels_Success(t *testing.T) {
}
}
func TestListModels_WithParams(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Query().Get("provider") != "mistralai" {
t.Errorf("expected provider=mistralai, got %q", r.URL.Query().Get("provider"))
}
if r.URL.Query().Get("model") != "mistral-small" {
t.Errorf("expected model=mistral-small, got %q", r.URL.Query().Get("model"))
}
json.NewEncoder(w).Encode(map[string]any{
"object": "list",
"data": []map[string]any{},
})
}))
defer server.Close()
provider := "mistralai"
modelName := "mistral-small"
client := NewClient("key", WithBaseURL(server.URL))
_, err := client.ListModels(context.Background(), &model.ListParams{
Provider: &provider,
Model: &modelName,
})
if err != nil {
t.Fatal(err)
}
}
func TestGetModel_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/models/mistral-small-latest" {

View File

@@ -1,6 +1,6 @@
package moderation
import "somegit.dev/vikingowl/mistral-go-sdk/chat"
import "github.com/VikingOwl91/mistral-go-sdk/chat"
// Request represents a text moderation request (/v1/moderations).
type Request struct {

View File

@@ -3,8 +3,8 @@ package mistral
import (
"context"
"somegit.dev/vikingowl/mistral-go-sdk/classification"
"somegit.dev/vikingowl/mistral-go-sdk/moderation"
"github.com/VikingOwl91/mistral-go-sdk/classification"
"github.com/VikingOwl91/mistral-go-sdk/moderation"
)
// Moderate sends a text moderation request.

View File

@@ -7,8 +7,8 @@ import (
"net/http/httptest"
"testing"
"somegit.dev/vikingowl/mistral-go-sdk/classification"
"somegit.dev/vikingowl/mistral-go-sdk/moderation"
"github.com/VikingOwl91/mistral-go-sdk/classification"
"github.com/VikingOwl91/mistral-go-sdk/moderation"
)
func TestModerate_Success(t *testing.T) {

46
observability/campaign.go Normal file
View File

@@ -0,0 +1,46 @@
package observability
// Campaign represents an observability campaign.
type Campaign struct {
ID string `json:"id"`
CreatedAt string `json:"created_at"`
UpdatedAt string `json:"updated_at"`
DeletedAt *string `json:"deleted_at,omitempty"`
Name string `json:"name"`
OwnerID string `json:"owner_id"`
WorkspaceID string `json:"workspace_id"`
Description string `json:"description"`
MaxNbEvents int `json:"max_nb_events"`
SearchParams FilterPayload `json:"search_params"`
Judge Judge `json:"judge"`
}
// CreateCampaignRequest creates a new campaign.
type CreateCampaignRequest struct {
SearchParams FilterPayload `json:"search_params"`
JudgeID string `json:"judge_id"`
Name string `json:"name"`
Description string `json:"description"`
MaxNbEvents int `json:"max_nb_events"`
}
// CampaignStatusResponse is the response for campaign status.
type CampaignStatusResponse struct {
Status TaskStatus `json:"status"`
}
// ListCampaignsResponse is the response from listing campaigns.
type ListCampaignsResponse struct {
Count int `json:"count"`
Results []Campaign `json:"results,omitempty"`
Next *string `json:"next,omitempty"`
Previous *string `json:"previous,omitempty"`
}
// ListCampaignEventsResponse is the response from listing campaign events.
type ListCampaignEventsResponse struct {
Count int `json:"count"`
Results []ChatCompletionEventPreview `json:"results,omitempty"`
Next *string `json:"next,omitempty"`
Previous *string `json:"previous,omitempty"`
}

156
observability/dataset.go Normal file
View File

@@ -0,0 +1,156 @@
package observability
import "encoding/json"
// ConversationSource indicates how a dataset record was created.
type ConversationSource string
const (
SourceExplorer ConversationSource = "EXPLORER"
SourceUploadedFile ConversationSource = "UPLOADED_FILE"
SourceDirectInput ConversationSource = "DIRECT_INPUT"
SourcePlayground ConversationSource = "PLAYGROUND"
)
// Dataset represents a dataset entity.
type Dataset struct {
ID string `json:"id"`
CreatedAt string `json:"created_at"`
UpdatedAt string `json:"updated_at"`
DeletedAt *string `json:"deleted_at,omitempty"`
Name string `json:"name"`
Description string `json:"description"`
OwnerID string `json:"owner_id"`
WorkspaceID string `json:"workspace_id"`
}
// CreateDatasetRequest creates a new dataset.
type CreateDatasetRequest struct {
Name string `json:"name"`
Description string `json:"description"`
}
// UpdateDatasetRequest updates a dataset.
type UpdateDatasetRequest struct {
Name *string `json:"name,omitempty"`
Description *string `json:"description,omitempty"`
}
// DatasetRecord is a single record in a dataset.
type DatasetRecord struct {
ID string `json:"id"`
CreatedAt string `json:"created_at"`
UpdatedAt string `json:"updated_at"`
DeletedAt *string `json:"deleted_at,omitempty"`
DatasetID string `json:"dataset_id"`
Payload ConversationPayload `json:"payload"`
Properties map[string]any `json:"properties,omitempty"`
Source ConversationSource `json:"source"`
}
// ConversationPayload holds the messages for a dataset record.
type ConversationPayload struct {
Messages []map[string]any `json:"messages"`
}
// CreateRecordRequest creates a new dataset record.
type CreateRecordRequest struct {
Payload ConversationPayload `json:"payload"`
Properties map[string]any `json:"properties"`
}
// UpdateRecordPayloadRequest updates a record's payload.
type UpdateRecordPayloadRequest struct {
Payload ConversationPayload `json:"payload"`
}
// UpdateRecordPropertiesRequest updates a record's properties.
type UpdateRecordPropertiesRequest struct {
Properties map[string]any `json:"properties"`
}
// BulkDeleteRecordsRequest deletes multiple records.
type BulkDeleteRecordsRequest struct {
DatasetRecordIDs []string `json:"dataset_record_ids"`
}
// JudgeRecordRequest judges a dataset record.
type JudgeRecordRequest struct {
JudgeDefinition CreateJudgeRequest `json:"judge_definition"`
}
// DatasetImportTask tracks an async import operation.
type DatasetImportTask struct {
ID string `json:"id"`
CreatedAt string `json:"created_at"`
UpdatedAt string `json:"updated_at"`
DeletedAt *string `json:"deleted_at,omitempty"`
CreatorID string `json:"creator_id"`
DatasetID string `json:"dataset_id"`
WorkspaceID string `json:"workspace_id"`
Status TaskStatus `json:"status"`
Progress *int `json:"progress,omitempty"`
Message *string `json:"message,omitempty"`
}
// ExportDatasetResponse is the response from exporting a dataset.
type ExportDatasetResponse struct {
FileURL string `json:"file_url"`
}
// Import request types.
// ImportFromCampaignRequest imports records from a campaign.
type ImportFromCampaignRequest struct {
CampaignID string `json:"campaign_id"`
}
// ImportFromExplorerRequest imports records from explorer events.
type ImportFromExplorerRequest struct {
CompletionEventIDs []string `json:"completion_event_ids"`
}
// ImportFromFileRequest imports records from a file.
type ImportFromFileRequest struct {
FileID string `json:"file_id"`
}
// ImportFromPlaygroundRequest imports records from playground conversations.
type ImportFromPlaygroundRequest struct {
ConversationIDs []string `json:"conversation_ids"`
}
// ImportFromDatasetRequest imports records from another dataset.
type ImportFromDatasetRequest struct {
DatasetRecordIDs []string `json:"dataset_record_ids"`
}
// List response types.
// ListDatasetsResponse is the response from listing datasets.
type ListDatasetsResponse struct {
Count int `json:"count"`
Results []Dataset `json:"results,omitempty"`
Next *string `json:"next,omitempty"`
Previous *string `json:"previous,omitempty"`
}
// ListRecordsResponse is the response from listing dataset records.
type ListRecordsResponse struct {
Count int `json:"count"`
Results []DatasetRecord `json:"results,omitempty"`
Next *string `json:"next,omitempty"`
Previous *string `json:"previous,omitempty"`
}
// ListTasksResponse is the response from listing import tasks.
type ListTasksResponse struct {
Count int `json:"count"`
Results []DatasetImportTask `json:"results,omitempty"`
Next *string `json:"next,omitempty"`
Previous *string `json:"previous,omitempty"`
}
// JudgeResultResponse is the raw response from judging operations.
// The shape depends on the judge type (classification or regression).
type JudgeResultResponse json.RawMessage

5
observability/doc.go Normal file
View File

@@ -0,0 +1,5 @@
// Package observability provides types for the Mistral observability API (beta).
//
// This includes campaigns, chat completion events, judges, and datasets
// for monitoring and evaluating model behavior.
package observability

70
observability/event.go Normal file
View File

@@ -0,0 +1,70 @@
package observability
// ChatCompletionEvent is a full chat completion event.
type ChatCompletionEvent struct {
EventID string `json:"event_id"`
CorrelationID string `json:"correlation_id"`
CreatedAt string `json:"created_at"`
ExtraFields map[string]any `json:"extra_fields,omitempty"`
NbInputTokens int `json:"nb_input_tokens"`
NbOutputTokens int `json:"nb_output_tokens"`
EnabledTools []map[string]any `json:"enabled_tools,omitempty"`
RequestMessages []map[string]any `json:"request_messages,omitempty"`
ResponseMessages []map[string]any `json:"response_messages,omitempty"`
NbMessages int `json:"nb_messages"`
ChatTranscriptionEvents []ChatTranscriptionEvent `json:"chat_transcription_events,omitempty"`
}
// ChatCompletionEventPreview is a summary of a chat completion event.
type ChatCompletionEventPreview struct {
EventID string `json:"event_id"`
CorrelationID string `json:"correlation_id"`
CreatedAt string `json:"created_at"`
ExtraFields map[string]any `json:"extra_fields,omitempty"`
NbInputTokens int `json:"nb_input_tokens"`
NbOutputTokens int `json:"nb_output_tokens"`
}
// ChatTranscriptionEvent is an audio transcription within a chat event.
type ChatTranscriptionEvent struct {
AudioURL string `json:"audio_url"`
Model string `json:"model"`
ResponseMessage map[string]any `json:"response_message"`
}
// SearchEventsRequest is the request body for searching chat completion events.
type SearchEventsRequest struct {
SearchParams FilterPayload `json:"search_params"`
ExtraFields []string `json:"extra_fields,omitempty"`
}
// SearchEventsResponse is the response from searching events.
type SearchEventsResponse struct {
Results []ChatCompletionEventPreview `json:"results,omitempty"`
Next *string `json:"next,omitempty"`
Cursor *string `json:"cursor,omitempty"`
}
// SearchEventIDsRequest is the request body for searching event IDs.
type SearchEventIDsRequest struct {
SearchParams FilterPayload `json:"search_params"`
ExtraFields []string `json:"extra_fields,omitempty"`
}
// SearchEventIDsResponse is the response from searching event IDs.
type SearchEventIDsResponse struct {
CompletionEventIDs []string `json:"completion_event_ids"`
}
// JudgeEventRequest is the request body for judging a chat completion event.
type JudgeEventRequest struct {
JudgeDefinition CreateJudgeRequest `json:"judge_definition"`
}
// SimilarEventsResponse is the response from fetching similar events.
type SimilarEventsResponse struct {
Count int `json:"count"`
Results []ChatCompletionEventPreview `json:"results,omitempty"`
Next *string `json:"next,omitempty"`
Previous *string `json:"previous,omitempty"`
}

75
observability/filter.go Normal file
View File

@@ -0,0 +1,75 @@
package observability
import "encoding/json"
// Op is a filter comparison operator.
type Op string
const (
OpLt Op = "lt"
OpLte Op = "lte"
OpGt Op = "gt"
OpGte Op = "gte"
OpEq Op = "eq"
OpNeq Op = "neq"
OpIsNull Op = "isnull"
OpStartsWith Op = "startswith"
OpIStartsWith Op = "istartswith"
OpEndsWith Op = "endswith"
OpIEndsWith Op = "iendswith"
OpContains Op = "contains"
OpIContains Op = "icontains"
OpMatches Op = "matches"
OpNotContains Op = "notcontains"
OpINotContains Op = "inotcontains"
OpIncludes Op = "includes"
OpExcludes Op = "excludes"
OpLenEq Op = "len_eq"
)
// FilterCondition is a single filter comparison.
type FilterCondition struct {
Field string `json:"field"`
Op Op `json:"op"`
Value any `json:"value"`
}
// FilterGroup combines filters with AND/OR logic.
// The JSON keys are uppercase "AND" / "OR".
type FilterGroup struct {
AND []json.RawMessage `json:"AND,omitempty"`
OR []json.RawMessage `json:"OR,omitempty"`
}
// FilterPayload wraps the top-level filter for search operations.
// Filters can be a FilterGroup or a FilterCondition.
type FilterPayload struct {
Filters json.RawMessage `json:"filters,omitempty"`
}
// TaskStatus is the status of an async task.
type TaskStatus string
const (
TaskStatusRunning TaskStatus = "RUNNING"
TaskStatusCompleted TaskStatus = "COMPLETED"
TaskStatusFailed TaskStatus = "FAILED"
TaskStatusCanceled TaskStatus = "CANCELED"
TaskStatusTerminated TaskStatus = "TERMINATED"
TaskStatusContinuedAsNew TaskStatus = "CONTINUED_AS_NEW"
TaskStatusTimedOut TaskStatus = "TIMED_OUT"
TaskStatusUnknown TaskStatus = "UNKNOWN"
)
// PaginationParams holds common pagination query parameters.
type PaginationParams struct {
Page *int
PageSize *int
}
// SearchParams holds common search query parameters.
type SearchParams struct {
Page *int
PageSize *int
Q *string
}

114
observability/judge.go Normal file
View File

@@ -0,0 +1,114 @@
package observability
import (
"encoding/json"
"fmt"
)
// JudgeOutputType identifies the kind of judge output.
type JudgeOutputType string
const (
JudgeOutputClassification JudgeOutputType = "CLASSIFICATION"
JudgeOutputRegression JudgeOutputType = "REGRESSION"
)
// JudgeOutput is a sealed interface for judge output configurations.
type JudgeOutput interface {
judgeOutputType() JudgeOutputType
}
// ClassificationOutput configures a classification judge.
type ClassificationOutput struct {
Type JudgeOutputType `json:"type"`
Options []ClassificationOption `json:"options"`
}
func (*ClassificationOutput) judgeOutputType() JudgeOutputType { return JudgeOutputClassification }
// ClassificationOption is a single option for a classification judge.
type ClassificationOption struct {
Value string `json:"value"`
Description string `json:"description"`
}
// RegressionOutput configures a regression judge.
type RegressionOutput struct {
Type JudgeOutputType `json:"type"`
MinDescription string `json:"min_description"`
MaxDescription string `json:"max_description"`
Min *float64 `json:"min,omitempty"`
Max *float64 `json:"max,omitempty"`
}
func (*RegressionOutput) judgeOutputType() JudgeOutputType { return JudgeOutputRegression }
// UnmarshalJudgeOutput dispatches to the concrete JudgeOutput type.
func UnmarshalJudgeOutput(data []byte) (JudgeOutput, error) {
var probe struct {
Type string `json:"type"`
}
if err := json.Unmarshal(data, &probe); err != nil {
return nil, fmt.Errorf("unmarshal judge output: %w", err)
}
switch JudgeOutputType(probe.Type) {
case JudgeOutputClassification:
var o ClassificationOutput
return &o, json.Unmarshal(data, &o)
case JudgeOutputRegression:
var o RegressionOutput
return &o, json.Unmarshal(data, &o)
default:
return nil, fmt.Errorf("unknown judge output type: %q", probe.Type)
}
}
// Judge represents a judge entity.
type Judge struct {
ID string `json:"id"`
CreatedAt string `json:"created_at"`
UpdatedAt string `json:"updated_at"`
DeletedAt *string `json:"deleted_at,omitempty"`
OwnerID string `json:"owner_id"`
WorkspaceID string `json:"workspace_id"`
Name string `json:"name"`
Description string `json:"description"`
ModelName string `json:"model_name"`
Output json.RawMessage `json:"output"`
Instructions string `json:"instructions"`
Tools []string `json:"tools,omitempty"`
}
// CreateJudgeRequest creates a new judge.
type CreateJudgeRequest struct {
Name string `json:"name"`
Description string `json:"description"`
ModelName string `json:"model_name"`
Output json.RawMessage `json:"output"`
Instructions string `json:"instructions"`
Tools []string `json:"tools"`
}
// UpdateJudgeRequest updates a judge.
type UpdateJudgeRequest struct {
Name string `json:"name"`
Description string `json:"description"`
ModelName string `json:"model_name"`
Output json.RawMessage `json:"output"`
Instructions string `json:"instructions"`
Tools []string `json:"tools"`
}
// JudgeConversationRequest is the request for live-judging a conversation.
type JudgeConversationRequest struct {
Messages []map[string]any `json:"messages"`
Properties map[string]any `json:"properties,omitempty"`
}
// ListJudgesResponse is the response from listing judges.
type ListJudgesResponse struct {
Count int `json:"count"`
Results []Judge `json:"results,omitempty"`
Next *string `json:"next,omitempty"`
Previous *string `json:"previous,omitempty"`
}

View File

@@ -0,0 +1,97 @@
package mistral
import (
"context"
"fmt"
"net/url"
"strconv"
"github.com/VikingOwl91/mistral-go-sdk/observability"
)
// CreateCampaign creates a new observability campaign.
func (c *Client) CreateCampaign(ctx context.Context, req *observability.CreateCampaignRequest) (*observability.Campaign, error) {
var resp observability.Campaign
if err := c.doJSON(ctx, "POST", "/v1/observability/campaigns", req, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// ListCampaigns lists observability campaigns.
func (c *Client) ListCampaigns(ctx context.Context, params *observability.SearchParams) (*observability.ListCampaignsResponse, error) {
path := "/v1/observability/campaigns"
if params != nil {
q := url.Values{}
if params.PageSize != nil {
q.Set("page_size", strconv.Itoa(*params.PageSize))
}
if params.Page != nil {
q.Set("page", strconv.Itoa(*params.Page))
}
if params.Q != nil {
q.Set("q", *params.Q)
}
if encoded := q.Encode(); encoded != "" {
path += "?" + encoded
}
}
var resp observability.ListCampaignsResponse
if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// GetCampaign retrieves a campaign by ID.
func (c *Client) GetCampaign(ctx context.Context, campaignID string) (*observability.Campaign, error) {
var resp observability.Campaign
if err := c.doJSON(ctx, "GET", fmt.Sprintf("/v1/observability/campaigns/%s", campaignID), nil, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// DeleteCampaign deletes a campaign.
func (c *Client) DeleteCampaign(ctx context.Context, campaignID string) error {
resp, err := c.do(ctx, "DELETE", fmt.Sprintf("/v1/observability/campaigns/%s", campaignID), nil)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode >= 400 {
return parseAPIError(resp)
}
return nil
}
// GetCampaignStatus retrieves the status of a campaign.
func (c *Client) GetCampaignStatus(ctx context.Context, campaignID string) (*observability.CampaignStatusResponse, error) {
var resp observability.CampaignStatusResponse
if err := c.doJSON(ctx, "GET", fmt.Sprintf("/v1/observability/campaigns/%s/status", campaignID), nil, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// ListCampaignEvents lists events selected by a campaign.
func (c *Client) ListCampaignEvents(ctx context.Context, campaignID string, params *observability.PaginationParams) (*observability.ListCampaignEventsResponse, error) {
path := fmt.Sprintf("/v1/observability/campaigns/%s/selected-events", campaignID)
if params != nil {
q := url.Values{}
if params.PageSize != nil {
q.Set("page_size", strconv.Itoa(*params.PageSize))
}
if params.Page != nil {
q.Set("page", strconv.Itoa(*params.Page))
}
if encoded := q.Encode(); encoded != "" {
path += "?" + encoded
}
}
var resp observability.ListCampaignEventsResponse
if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil {
return nil, err
}
return &resp, nil
}

View File

@@ -0,0 +1,148 @@
package mistral
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/VikingOwl91/mistral-go-sdk/observability"
)
func TestCreateCampaign_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" || r.URL.Path != "/v1/observability/campaigns" {
t.Errorf("unexpected %s %s", r.Method, r.URL.Path)
}
var body map[string]any
json.NewDecoder(r.Body).Decode(&body)
if body["name"] != "test-campaign" {
t.Errorf("got name %v", body["name"])
}
w.WriteHeader(http.StatusCreated)
json.NewEncoder(w).Encode(map[string]any{
"id": "camp-1", "name": "test-campaign", "description": "d",
"created_at": "t", "updated_at": "t", "owner_id": "o",
"workspace_id": "w", "max_nb_events": 100,
"search_params": map[string]any{}, "judge": map[string]any{"id": "j1"},
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.CreateCampaign(context.Background(), &observability.CreateCampaignRequest{
Name: "test-campaign",
Description: "d",
JudgeID: "j1",
MaxNbEvents: 100,
})
if err != nil {
t.Fatal(err)
}
if resp.ID != "camp-1" {
t.Errorf("got id %q", resp.ID)
}
}
func TestListCampaigns_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/observability/campaigns" {
t.Errorf("got path %s", r.URL.Path)
}
json.NewEncoder(w).Encode(map[string]any{
"count": 1,
"results": []map[string]any{{"id": "c1", "name": "c", "description": "d", "created_at": "t", "updated_at": "t", "owner_id": "o", "workspace_id": "w", "max_nb_events": 10, "search_params": map[string]any{}, "judge": map[string]any{"id": "j"}}},
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.ListCampaigns(context.Background(), nil)
if err != nil {
t.Fatal(err)
}
if resp.Count != 1 {
t.Errorf("got count %d", resp.Count)
}
}
func TestGetCampaign_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/observability/campaigns/camp-1" {
t.Errorf("got path %s", r.URL.Path)
}
json.NewEncoder(w).Encode(map[string]any{
"id": "camp-1", "name": "c", "description": "d",
"created_at": "t", "updated_at": "t", "owner_id": "o",
"workspace_id": "w", "max_nb_events": 10,
"search_params": map[string]any{}, "judge": map[string]any{"id": "j"},
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.GetCampaign(context.Background(), "camp-1")
if err != nil {
t.Fatal(err)
}
if resp.ID != "camp-1" {
t.Errorf("got id %q", resp.ID)
}
}
func TestDeleteCampaign_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "DELETE" {
t.Errorf("expected DELETE")
}
w.WriteHeader(http.StatusNoContent)
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
if err := client.DeleteCampaign(context.Background(), "camp-1"); err != nil {
t.Fatal(err)
}
}
func TestGetCampaignStatus_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/observability/campaigns/camp-1/status" {
t.Errorf("got path %s", r.URL.Path)
}
json.NewEncoder(w).Encode(map[string]any{"status": "COMPLETED"})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.GetCampaignStatus(context.Background(), "camp-1")
if err != nil {
t.Fatal(err)
}
if resp.Status != observability.TaskStatusCompleted {
t.Errorf("got status %q", resp.Status)
}
}
func TestListCampaignEvents_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/observability/campaigns/camp-1/selected-events" {
t.Errorf("got path %s", r.URL.Path)
}
json.NewEncoder(w).Encode(map[string]any{
"count": 0,
"results": []any{},
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.ListCampaignEvents(context.Background(), "camp-1", nil)
if err != nil {
t.Fatal(err)
}
if resp.Count != 0 {
t.Errorf("got count %d", resp.Count)
}
}

252
observability_datasets.go Normal file
View File

@@ -0,0 +1,252 @@
package mistral
import (
"context"
"encoding/json"
"fmt"
"net/url"
"strconv"
"github.com/VikingOwl91/mistral-go-sdk/observability"
)
// CreateDataset creates a new observability dataset.
func (c *Client) CreateDataset(ctx context.Context, req *observability.CreateDatasetRequest) (*observability.Dataset, error) {
var resp observability.Dataset
if err := c.doJSON(ctx, "POST", "/v1/observability/datasets", req, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// ListDatasets lists observability datasets.
func (c *Client) ListDatasets(ctx context.Context, params *observability.SearchParams) (*observability.ListDatasetsResponse, error) {
path := "/v1/observability/datasets"
if params != nil {
q := url.Values{}
if params.PageSize != nil {
q.Set("page_size", strconv.Itoa(*params.PageSize))
}
if params.Page != nil {
q.Set("page", strconv.Itoa(*params.Page))
}
if params.Q != nil {
q.Set("q", *params.Q)
}
if encoded := q.Encode(); encoded != "" {
path += "?" + encoded
}
}
var resp observability.ListDatasetsResponse
if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// GetDataset retrieves a dataset by ID.
func (c *Client) GetDataset(ctx context.Context, datasetID string) (*observability.Dataset, error) {
var resp observability.Dataset
if err := c.doJSON(ctx, "GET", fmt.Sprintf("/v1/observability/datasets/%s", datasetID), nil, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// UpdateDataset updates a dataset.
func (c *Client) UpdateDataset(ctx context.Context, datasetID string, req *observability.UpdateDatasetRequest) (*observability.Dataset, error) {
var resp observability.Dataset
if err := c.doJSON(ctx, "PATCH", fmt.Sprintf("/v1/observability/datasets/%s", datasetID), req, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// DeleteDataset deletes a dataset.
func (c *Client) DeleteDataset(ctx context.Context, datasetID string) error {
resp, err := c.do(ctx, "DELETE", fmt.Sprintf("/v1/observability/datasets/%s", datasetID), nil)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode >= 400 {
return parseAPIError(resp)
}
return nil
}
// ExportDatasetToJSONL exports a dataset to JSONL format.
func (c *Client) ExportDatasetToJSONL(ctx context.Context, datasetID string) (*observability.ExportDatasetResponse, error) {
var resp observability.ExportDatasetResponse
if err := c.doJSON(ctx, "GET", fmt.Sprintf("/v1/observability/datasets/%s/exports/to-jsonl", datasetID), nil, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// Dataset records
// ListDatasetRecords lists records in a dataset.
func (c *Client) ListDatasetRecords(ctx context.Context, datasetID string, params *observability.PaginationParams) (*observability.ListRecordsResponse, error) {
path := fmt.Sprintf("/v1/observability/datasets/%s/records", datasetID)
if params != nil {
q := url.Values{}
if params.PageSize != nil {
q.Set("page_size", strconv.Itoa(*params.PageSize))
}
if params.Page != nil {
q.Set("page", strconv.Itoa(*params.Page))
}
if encoded := q.Encode(); encoded != "" {
path += "?" + encoded
}
}
var resp observability.ListRecordsResponse
if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// CreateDatasetRecord creates a record in a dataset.
func (c *Client) CreateDatasetRecord(ctx context.Context, datasetID string, req *observability.CreateRecordRequest) (*observability.DatasetRecord, error) {
var resp observability.DatasetRecord
if err := c.doJSON(ctx, "POST", fmt.Sprintf("/v1/observability/datasets/%s/records", datasetID), req, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// GetDatasetRecord retrieves a dataset record by ID.
func (c *Client) GetDatasetRecord(ctx context.Context, recordID string) (*observability.DatasetRecord, error) {
var resp observability.DatasetRecord
if err := c.doJSON(ctx, "GET", fmt.Sprintf("/v1/observability/dataset-records/%s", recordID), nil, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// UpdateDatasetRecordPayload updates a record's payload.
func (c *Client) UpdateDatasetRecordPayload(ctx context.Context, recordID string, req *observability.UpdateRecordPayloadRequest) (*observability.DatasetRecord, error) {
var resp observability.DatasetRecord
if err := c.doJSON(ctx, "PUT", fmt.Sprintf("/v1/observability/dataset-records/%s/payload", recordID), req, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// UpdateDatasetRecordProperties updates a record's properties.
func (c *Client) UpdateDatasetRecordProperties(ctx context.Context, recordID string, req *observability.UpdateRecordPropertiesRequest) (*observability.DatasetRecord, error) {
var resp observability.DatasetRecord
if err := c.doJSON(ctx, "PUT", fmt.Sprintf("/v1/observability/dataset-records/%s/properties", recordID), req, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// DeleteDatasetRecord deletes a dataset record.
func (c *Client) DeleteDatasetRecord(ctx context.Context, recordID string) error {
resp, err := c.do(ctx, "DELETE", fmt.Sprintf("/v1/observability/dataset-records/%s", recordID), nil)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode >= 400 {
return parseAPIError(resp)
}
return nil
}
// BulkDeleteDatasetRecords deletes multiple dataset records.
func (c *Client) BulkDeleteDatasetRecords(ctx context.Context, req *observability.BulkDeleteRecordsRequest) error {
return c.doJSON(ctx, "POST", "/v1/observability/dataset-records/bulk-delete", req, nil)
}
// JudgeDatasetRecord judges a dataset record.
func (c *Client) JudgeDatasetRecord(ctx context.Context, recordID string, req *observability.JudgeRecordRequest) (json.RawMessage, error) {
var resp json.RawMessage
if err := c.doJSON(ctx, "POST", fmt.Sprintf("/v1/observability/dataset-records/%s/live-judging", recordID), req, &resp); err != nil {
return nil, err
}
return resp, nil
}
// Import operations
// ImportDatasetFromCampaign imports records from a campaign.
func (c *Client) ImportDatasetFromCampaign(ctx context.Context, datasetID string, req *observability.ImportFromCampaignRequest) (*observability.DatasetImportTask, error) {
var resp observability.DatasetImportTask
if err := c.doJSON(ctx, "POST", fmt.Sprintf("/v1/observability/datasets/%s/imports/from-campaign", datasetID), req, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// ImportDatasetFromExplorer imports records from explorer events.
func (c *Client) ImportDatasetFromExplorer(ctx context.Context, datasetID string, req *observability.ImportFromExplorerRequest) (*observability.DatasetImportTask, error) {
var resp observability.DatasetImportTask
if err := c.doJSON(ctx, "POST", fmt.Sprintf("/v1/observability/datasets/%s/imports/from-explorer", datasetID), req, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// ImportDatasetFromFile imports records from a file.
func (c *Client) ImportDatasetFromFile(ctx context.Context, datasetID string, req *observability.ImportFromFileRequest) (*observability.DatasetImportTask, error) {
var resp observability.DatasetImportTask
if err := c.doJSON(ctx, "POST", fmt.Sprintf("/v1/observability/datasets/%s/imports/from-file", datasetID), req, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// ImportDatasetFromPlayground imports records from playground conversations.
func (c *Client) ImportDatasetFromPlayground(ctx context.Context, datasetID string, req *observability.ImportFromPlaygroundRequest) (*observability.DatasetImportTask, error) {
var resp observability.DatasetImportTask
if err := c.doJSON(ctx, "POST", fmt.Sprintf("/v1/observability/datasets/%s/imports/from-playground", datasetID), req, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// ImportDatasetFromDataset imports records from another dataset.
func (c *Client) ImportDatasetFromDataset(ctx context.Context, datasetID string, req *observability.ImportFromDatasetRequest) (*observability.DatasetImportTask, error) {
var resp observability.DatasetImportTask
if err := c.doJSON(ctx, "POST", fmt.Sprintf("/v1/observability/datasets/%s/imports/from-dataset", datasetID), req, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// Tasks
// ListDatasetTasks lists import tasks for a dataset.
func (c *Client) ListDatasetTasks(ctx context.Context, datasetID string, params *observability.PaginationParams) (*observability.ListTasksResponse, error) {
path := fmt.Sprintf("/v1/observability/datasets/%s/tasks", datasetID)
if params != nil {
q := url.Values{}
if params.PageSize != nil {
q.Set("page_size", strconv.Itoa(*params.PageSize))
}
if params.Page != nil {
q.Set("page", strconv.Itoa(*params.Page))
}
if encoded := q.Encode(); encoded != "" {
path += "?" + encoded
}
}
var resp observability.ListTasksResponse
if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// GetDatasetTask retrieves an import task by ID.
func (c *Client) GetDatasetTask(ctx context.Context, datasetID, taskID string) (*observability.DatasetImportTask, error) {
var resp observability.DatasetImportTask
if err := c.doJSON(ctx, "GET", fmt.Sprintf("/v1/observability/datasets/%s/tasks/%s", datasetID, taskID), nil, &resp); err != nil {
return nil, err
}
return &resp, nil
}

View File

@@ -0,0 +1,211 @@
package mistral
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/VikingOwl91/mistral-go-sdk/observability"
)
func datasetJSON() map[string]any {
return map[string]any{
"id": "ds-1", "created_at": "t", "updated_at": "t",
"name": "test-ds", "description": "d",
"owner_id": "o", "workspace_id": "w",
}
}
func TestCreateDataset_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" || r.URL.Path != "/v1/observability/datasets" {
t.Errorf("unexpected %s %s", r.Method, r.URL.Path)
}
w.WriteHeader(http.StatusCreated)
json.NewEncoder(w).Encode(datasetJSON())
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.CreateDataset(context.Background(), &observability.CreateDatasetRequest{
Name: "test-ds",
Description: "d",
})
if err != nil {
t.Fatal(err)
}
if resp.ID != "ds-1" {
t.Errorf("got id %q", resp.ID)
}
}
func TestListDatasets_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(map[string]any{
"count": 1,
"results": []any{datasetJSON()},
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.ListDatasets(context.Background(), nil)
if err != nil {
t.Fatal(err)
}
if resp.Count != 1 {
t.Errorf("got count %d", resp.Count)
}
}
func TestGetDataset_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/observability/datasets/ds-1" {
t.Errorf("got path %s", r.URL.Path)
}
json.NewEncoder(w).Encode(datasetJSON())
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.GetDataset(context.Background(), "ds-1")
if err != nil {
t.Fatal(err)
}
if resp.Name != "test-ds" {
t.Errorf("got name %q", resp.Name)
}
}
func TestDeleteDataset_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNoContent)
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
if err := client.DeleteDataset(context.Background(), "ds-1"); err != nil {
t.Fatal(err)
}
}
func TestCreateDatasetRecord_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" || r.URL.Path != "/v1/observability/datasets/ds-1/records" {
t.Errorf("unexpected %s %s", r.Method, r.URL.Path)
}
w.WriteHeader(http.StatusCreated)
json.NewEncoder(w).Encode(map[string]any{
"id": "rec-1", "created_at": "t", "updated_at": "t",
"dataset_id": "ds-1", "source": "DIRECT_INPUT",
"payload": map[string]any{"messages": []any{}},
"properties": map[string]any{},
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.CreateDatasetRecord(context.Background(), "ds-1", &observability.CreateRecordRequest{
Payload: observability.ConversationPayload{Messages: []map[string]any{}},
Properties: map[string]any{},
})
if err != nil {
t.Fatal(err)
}
if resp.ID != "rec-1" {
t.Errorf("got id %q", resp.ID)
}
}
func TestListDatasetRecords_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/observability/datasets/ds-1/records" {
t.Errorf("got path %s", r.URL.Path)
}
json.NewEncoder(w).Encode(map[string]any{
"count": 0, "results": []any{},
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.ListDatasetRecords(context.Background(), "ds-1", nil)
if err != nil {
t.Fatal(err)
}
if resp.Count != 0 {
t.Errorf("got count %d", resp.Count)
}
}
func TestImportDatasetFromCampaign_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/observability/datasets/ds-1/imports/from-campaign" {
t.Errorf("got path %s", r.URL.Path)
}
w.WriteHeader(http.StatusCreated)
json.NewEncoder(w).Encode(map[string]any{
"id": "task-1", "created_at": "t", "updated_at": "t",
"creator_id": "u", "dataset_id": "ds-1", "workspace_id": "w",
"status": "RUNNING",
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.ImportDatasetFromCampaign(context.Background(), "ds-1", &observability.ImportFromCampaignRequest{
CampaignID: "camp-1",
})
if err != nil {
t.Fatal(err)
}
if resp.Status != observability.TaskStatusRunning {
t.Errorf("got status %q", resp.Status)
}
}
func TestExportDatasetToJSONL_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/observability/datasets/ds-1/exports/to-jsonl" {
t.Errorf("got path %s", r.URL.Path)
}
json.NewEncoder(w).Encode(map[string]any{
"file_url": "https://storage.example.com/export.jsonl",
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.ExportDatasetToJSONL(context.Background(), "ds-1")
if err != nil {
t.Fatal(err)
}
if resp.FileURL != "https://storage.example.com/export.jsonl" {
t.Errorf("got file_url %q", resp.FileURL)
}
}
func TestGetDatasetTask_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/observability/datasets/ds-1/tasks/task-1" {
t.Errorf("got path %s", r.URL.Path)
}
json.NewEncoder(w).Encode(map[string]any{
"id": "task-1", "created_at": "t", "updated_at": "t",
"creator_id": "u", "dataset_id": "ds-1", "workspace_id": "w",
"status": "COMPLETED",
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.GetDatasetTask(context.Background(), "ds-1", "task-1")
if err != nil {
t.Fatal(err)
}
if resp.Status != observability.TaskStatusCompleted {
t.Errorf("got status %q", resp.Status)
}
}

69
observability_events.go Normal file
View File

@@ -0,0 +1,69 @@
package mistral
import (
"context"
"encoding/json"
"fmt"
"net/url"
"strconv"
"github.com/VikingOwl91/mistral-go-sdk/observability"
)
// SearchChatCompletionEvents searches for chat completion events.
func (c *Client) SearchChatCompletionEvents(ctx context.Context, req *observability.SearchEventsRequest) (*observability.SearchEventsResponse, error) {
var resp observability.SearchEventsResponse
if err := c.doJSON(ctx, "POST", "/v1/observability/chat-completion-events/search", req, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// SearchChatCompletionEventIDs searches for chat completion event IDs.
func (c *Client) SearchChatCompletionEventIDs(ctx context.Context, req *observability.SearchEventIDsRequest) (*observability.SearchEventIDsResponse, error) {
var resp observability.SearchEventIDsResponse
if err := c.doJSON(ctx, "POST", "/v1/observability/chat-completion-events/search-ids", req, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// GetChatCompletionEvent retrieves a chat completion event by ID.
func (c *Client) GetChatCompletionEvent(ctx context.Context, eventID string) (*observability.ChatCompletionEvent, error) {
var resp observability.ChatCompletionEvent
if err := c.doJSON(ctx, "GET", fmt.Sprintf("/v1/observability/chat-completion-events/%s", eventID), nil, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// GetSimilarChatCompletionEvents retrieves events similar to a given event.
func (c *Client) GetSimilarChatCompletionEvents(ctx context.Context, eventID string, params *observability.PaginationParams) (*observability.SimilarEventsResponse, error) {
path := fmt.Sprintf("/v1/observability/chat-completion-events/%s/similar-events", eventID)
if params != nil {
q := url.Values{}
if params.PageSize != nil {
q.Set("page_size", strconv.Itoa(*params.PageSize))
}
if params.Page != nil {
q.Set("page", strconv.Itoa(*params.Page))
}
if encoded := q.Encode(); encoded != "" {
path += "?" + encoded
}
}
var resp observability.SimilarEventsResponse
if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// JudgeChatCompletionEvent judges a chat completion event.
func (c *Client) JudgeChatCompletionEvent(ctx context.Context, eventID string, req *observability.JudgeEventRequest) (json.RawMessage, error) {
var resp json.RawMessage
if err := c.doJSON(ctx, "POST", fmt.Sprintf("/v1/observability/chat-completion-events/%s/live-judging", eventID), req, &resp); err != nil {
return nil, err
}
return resp, nil
}

View File

@@ -0,0 +1,101 @@
package mistral
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/VikingOwl91/mistral-go-sdk/observability"
)
func TestSearchChatCompletionEvents_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" || r.URL.Path != "/v1/observability/chat-completion-events/search" {
t.Errorf("unexpected %s %s", r.Method, r.URL.Path)
}
json.NewEncoder(w).Encode(map[string]any{
"results": []map[string]any{
{"event_id": "ev-1", "correlation_id": "c1", "created_at": "t", "nb_input_tokens": 10, "nb_output_tokens": 5},
},
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.SearchChatCompletionEvents(context.Background(), &observability.SearchEventsRequest{})
if err != nil {
t.Fatal(err)
}
if len(resp.Results) != 1 {
t.Fatalf("got %d results", len(resp.Results))
}
if resp.Results[0].EventID != "ev-1" {
t.Errorf("got event_id %q", resp.Results[0].EventID)
}
}
func TestSearchChatCompletionEventIDs_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/observability/chat-completion-events/search-ids" {
t.Errorf("got path %s", r.URL.Path)
}
json.NewEncoder(w).Encode(map[string]any{
"completion_event_ids": []string{"ev-1", "ev-2"},
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.SearchChatCompletionEventIDs(context.Background(), &observability.SearchEventIDsRequest{})
if err != nil {
t.Fatal(err)
}
if len(resp.CompletionEventIDs) != 2 {
t.Errorf("got %d ids", len(resp.CompletionEventIDs))
}
}
func TestGetChatCompletionEvent_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/observability/chat-completion-events/ev-1" {
t.Errorf("got path %s", r.URL.Path)
}
json.NewEncoder(w).Encode(map[string]any{
"event_id": "ev-1", "correlation_id": "c1", "created_at": "t",
"nb_input_tokens": 10, "nb_output_tokens": 5, "nb_messages": 2,
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.GetChatCompletionEvent(context.Background(), "ev-1")
if err != nil {
t.Fatal(err)
}
if resp.EventID != "ev-1" {
t.Errorf("got event_id %q", resp.EventID)
}
}
func TestGetSimilarChatCompletionEvents_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/observability/chat-completion-events/ev-1/similar-events" {
t.Errorf("got path %s", r.URL.Path)
}
json.NewEncoder(w).Encode(map[string]any{
"count": 0, "results": []any{},
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.GetSimilarChatCompletionEvents(context.Background(), "ev-1", nil)
if err != nil {
t.Fatal(err)
}
if resp.Count != 0 {
t.Errorf("got count %d", resp.Count)
}
}

85
observability_judges.go Normal file
View File

@@ -0,0 +1,85 @@
package mistral
import (
"context"
"encoding/json"
"fmt"
"net/url"
"strconv"
"github.com/VikingOwl91/mistral-go-sdk/observability"
)
// CreateJudge creates a new observability judge.
func (c *Client) CreateJudge(ctx context.Context, req *observability.CreateJudgeRequest) (*observability.Judge, error) {
var resp observability.Judge
if err := c.doJSON(ctx, "POST", "/v1/observability/judges", req, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// ListJudges lists observability judges.
func (c *Client) ListJudges(ctx context.Context, params *observability.SearchParams) (*observability.ListJudgesResponse, error) {
path := "/v1/observability/judges"
if params != nil {
q := url.Values{}
if params.PageSize != nil {
q.Set("page_size", strconv.Itoa(*params.PageSize))
}
if params.Page != nil {
q.Set("page", strconv.Itoa(*params.Page))
}
if params.Q != nil {
q.Set("q", *params.Q)
}
if encoded := q.Encode(); encoded != "" {
path += "?" + encoded
}
}
var resp observability.ListJudgesResponse
if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// GetJudge retrieves a judge by ID.
func (c *Client) GetJudge(ctx context.Context, judgeID string) (*observability.Judge, error) {
var resp observability.Judge
if err := c.doJSON(ctx, "GET", fmt.Sprintf("/v1/observability/judges/%s", judgeID), nil, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// UpdateJudge updates a judge.
func (c *Client) UpdateJudge(ctx context.Context, judgeID string, req *observability.UpdateJudgeRequest) (*observability.Judge, error) {
var resp observability.Judge
if err := c.doJSON(ctx, "PUT", fmt.Sprintf("/v1/observability/judges/%s", judgeID), req, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// DeleteJudge deletes a judge.
func (c *Client) DeleteJudge(ctx context.Context, judgeID string) error {
resp, err := c.do(ctx, "DELETE", fmt.Sprintf("/v1/observability/judges/%s", judgeID), nil)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode >= 400 {
return parseAPIError(resp)
}
return nil
}
// JudgeConversation performs live judging on a conversation.
func (c *Client) JudgeConversation(ctx context.Context, judgeID string, req *observability.JudgeConversationRequest) (json.RawMessage, error) {
var resp json.RawMessage
if err := c.doJSON(ctx, "POST", fmt.Sprintf("/v1/observability/judges/%s/live-judging", judgeID), req, &resp); err != nil {
return nil, err
}
return resp, nil
}

View File

@@ -0,0 +1,123 @@
package mistral
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/VikingOwl91/mistral-go-sdk/observability"
)
func judgeJSON() map[string]any {
return map[string]any{
"id": "j1", "created_at": "t", "updated_at": "t",
"owner_id": "o", "workspace_id": "w", "name": "quality",
"description": "d", "model_name": "m", "instructions": "i",
"output": map[string]any{"type": "CLASSIFICATION", "options": []any{}},
}
}
func TestCreateJudge_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" || r.URL.Path != "/v1/observability/judges" {
t.Errorf("unexpected %s %s", r.Method, r.URL.Path)
}
w.WriteHeader(http.StatusCreated)
json.NewEncoder(w).Encode(judgeJSON())
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.CreateJudge(context.Background(), &observability.CreateJudgeRequest{
Name: "quality",
Description: "d",
ModelName: "m",
Instructions: "i",
Tools: []string{},
Output: json.RawMessage(`{"type":"CLASSIFICATION","options":[]}`),
})
if err != nil {
t.Fatal(err)
}
if resp.ID != "j1" {
t.Errorf("got id %q", resp.ID)
}
}
func TestListJudges_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(map[string]any{
"count": 1,
"results": []any{judgeJSON()},
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.ListJudges(context.Background(), nil)
if err != nil {
t.Fatal(err)
}
if resp.Count != 1 {
t.Errorf("got count %d", resp.Count)
}
}
func TestGetJudge_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/observability/judges/j1" {
t.Errorf("got path %s", r.URL.Path)
}
json.NewEncoder(w).Encode(judgeJSON())
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.GetJudge(context.Background(), "j1")
if err != nil {
t.Fatal(err)
}
if resp.Name != "quality" {
t.Errorf("got name %q", resp.Name)
}
}
func TestUpdateJudge_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "PUT" {
t.Errorf("expected PUT, got %s", r.Method)
}
json.NewEncoder(w).Encode(judgeJSON())
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
_, err := client.UpdateJudge(context.Background(), "j1", &observability.UpdateJudgeRequest{
Name: "quality",
Description: "d",
ModelName: "m",
Instructions: "i",
Tools: []string{},
Output: json.RawMessage(`{"type":"CLASSIFICATION","options":[]}`),
})
if err != nil {
t.Fatal(err)
}
}
func TestDeleteJudge_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "DELETE" {
t.Errorf("expected DELETE")
}
w.WriteHeader(http.StatusNoContent)
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
if err := client.DeleteJudge(context.Background(), "j1"); err != nil {
t.Fatal(err)
}
}

View File

@@ -3,7 +3,7 @@ package mistral
import (
"context"
"somegit.dev/vikingowl/mistral-go-sdk/ocr"
"github.com/VikingOwl91/mistral-go-sdk/ocr"
)
// OCR performs optical character recognition on a document.

View File

@@ -7,7 +7,7 @@ import (
"net/http/httptest"
"testing"
"somegit.dev/vikingowl/mistral-go-sdk/ocr"
"github.com/VikingOwl91/mistral-go-sdk/ocr"
)
func TestOCR_Success(t *testing.T) {

View File

@@ -9,7 +9,7 @@ import (
"testing"
"time"
"somegit.dev/vikingowl/mistral-go-sdk/chat"
"github.com/VikingOwl91/mistral-go-sdk/chat"
)
func TestRetry_429ThenSuccess(t *testing.T) {

21
workflow/deployment.go Normal file
View File

@@ -0,0 +1,21 @@
package workflow
// Deployment represents a workflow deployment.
type Deployment struct {
ID string `json:"id"`
Name string `json:"name"`
IsActive bool `json:"is_active"`
CreatedAt string `json:"created_at"`
UpdatedAt string `json:"updated_at"`
}
// DeploymentListResponse is the response from listing deployments.
type DeploymentListResponse struct {
Deployments []Deployment `json:"deployments"`
}
// DeploymentListParams holds query parameters for listing deployments.
type DeploymentListParams struct {
ActiveOnly *bool
WorkflowName *string
}

5
workflow/doc.go Normal file
View File

@@ -0,0 +1,5 @@
// Package workflow provides types for the Mistral workflows API.
//
// Workflows support orchestrating multi-step processes with execution
// management, scheduling, event streaming, and observability.
package workflow

369
workflow/event.go Normal file
View File

@@ -0,0 +1,369 @@
package workflow
import (
"encoding/json"
"fmt"
)
// EventType identifies the kind of workflow event.
type EventType string
const (
EventWorkflowStarted EventType = "WORKFLOW_EXECUTION_STARTED"
EventWorkflowCompleted EventType = "WORKFLOW_EXECUTION_COMPLETED"
EventWorkflowFailed EventType = "WORKFLOW_EXECUTION_FAILED"
EventWorkflowCanceled EventType = "WORKFLOW_EXECUTION_CANCELED"
EventWorkflowContinuedAsNew EventType = "WORKFLOW_EXECUTION_CONTINUED_AS_NEW"
EventWorkflowTaskTimedOut EventType = "WORKFLOW_TASK_TIMED_OUT"
EventWorkflowTaskFailed EventType = "WORKFLOW_TASK_FAILED"
EventCustomTaskStarted EventType = "CUSTOM_TASK_STARTED"
EventCustomTaskInProgress EventType = "CUSTOM_TASK_IN_PROGRESS"
EventCustomTaskCompleted EventType = "CUSTOM_TASK_COMPLETED"
EventCustomTaskFailed EventType = "CUSTOM_TASK_FAILED"
EventCustomTaskTimedOut EventType = "CUSTOM_TASK_TIMED_OUT"
EventCustomTaskCanceled EventType = "CUSTOM_TASK_CANCELED"
EventActivityTaskStarted EventType = "ACTIVITY_TASK_STARTED"
EventActivityTaskCompleted EventType = "ACTIVITY_TASK_COMPLETED"
EventActivityTaskRetrying EventType = "ACTIVITY_TASK_RETRYING"
EventActivityTaskFailed EventType = "ACTIVITY_TASK_FAILED"
)
// EventSource identifies where an event originated.
type EventSource string
const (
EventSourceDatabase EventSource = "DATABASE"
EventSourceLive EventSource = "LIVE"
)
// Scope identifies the scope of an event subscription.
type Scope string
const (
ScopeActivity Scope = "activity"
ScopeWorkflow Scope = "workflow"
ScopeAll Scope = "*"
)
// Event is a sealed interface for workflow execution events.
type Event interface {
workflowEvent()
EventType() EventType
}
// eventBase holds fields common to all workflow events.
type eventBase struct {
ID string `json:"event_id"`
Timestamp int64 `json:"event_timestamp"`
RootWorkflowExecID string `json:"root_workflow_exec_id"`
ParentWorkflowExecID *string `json:"parent_workflow_exec_id"`
WorkflowExecID string `json:"workflow_exec_id"`
WorkflowRunID string `json:"workflow_run_id"`
WorkflowName string `json:"workflow_name"`
}
// WorkflowStartedAttributes holds typed attributes for workflow started events.
type WorkflowStartedAttributes struct {
TaskID string `json:"task_id"`
}
// JSONPayload holds a typed JSON value.
type JSONPayload struct {
Value any `json:"value"`
Type string `json:"type"`
}
// WorkflowCompletedAttributes holds typed attributes for workflow completed events.
type WorkflowCompletedAttributes struct {
TaskID string `json:"task_id"`
Result JSONPayload `json:"result"`
}
// WorkflowFailedAttributes holds typed attributes for workflow failed events.
type WorkflowFailedAttributes struct {
TaskID string `json:"task_id"`
Failure any `json:"failure"`
}
// WorkflowExecutionStartedEvent signals that a workflow execution has started.
type WorkflowExecutionStartedEvent struct {
eventBase
Attributes WorkflowStartedAttributes `json:"attributes"`
}
func (*WorkflowExecutionStartedEvent) workflowEvent() {}
func (*WorkflowExecutionStartedEvent) EventType() EventType { return EventWorkflowStarted }
// WorkflowExecutionCompletedEvent signals that a workflow execution has completed.
type WorkflowExecutionCompletedEvent struct {
eventBase
Attributes WorkflowCompletedAttributes `json:"attributes"`
}
func (*WorkflowExecutionCompletedEvent) workflowEvent() {}
func (*WorkflowExecutionCompletedEvent) EventType() EventType { return EventWorkflowCompleted }
// WorkflowExecutionFailedEvent signals that a workflow execution has failed.
type WorkflowExecutionFailedEvent struct {
eventBase
Attributes WorkflowFailedAttributes `json:"attributes"`
}
func (*WorkflowExecutionFailedEvent) workflowEvent() {}
func (*WorkflowExecutionFailedEvent) EventType() EventType { return EventWorkflowFailed }
// WorkflowExecutionCanceledEvent signals that a workflow execution was canceled.
type WorkflowExecutionCanceledEvent struct {
eventBase
Attributes json.RawMessage `json:"attributes"`
}
func (*WorkflowExecutionCanceledEvent) workflowEvent() {}
func (*WorkflowExecutionCanceledEvent) EventType() EventType { return EventWorkflowCanceled }
// WorkflowExecutionContinuedAsNewEvent signals that a workflow continued as a new execution.
type WorkflowExecutionContinuedAsNewEvent struct {
eventBase
Attributes json.RawMessage `json:"attributes"`
}
func (*WorkflowExecutionContinuedAsNewEvent) workflowEvent() {}
func (*WorkflowExecutionContinuedAsNewEvent) EventType() EventType { return EventWorkflowContinuedAsNew }
// WorkflowTaskTimedOutEvent signals that a workflow task timed out.
type WorkflowTaskTimedOutEvent struct {
eventBase
Attributes json.RawMessage `json:"attributes"`
}
func (*WorkflowTaskTimedOutEvent) workflowEvent() {}
func (*WorkflowTaskTimedOutEvent) EventType() EventType { return EventWorkflowTaskTimedOut }
// WorkflowTaskFailedEvent signals that a workflow task failed.
type WorkflowTaskFailedEvent struct {
eventBase
Attributes json.RawMessage `json:"attributes"`
}
func (*WorkflowTaskFailedEvent) workflowEvent() {}
func (*WorkflowTaskFailedEvent) EventType() EventType { return EventWorkflowTaskFailed }
// CustomTaskStartedEvent signals that a custom task has started.
type CustomTaskStartedEvent struct {
eventBase
Attributes json.RawMessage `json:"attributes"`
}
func (*CustomTaskStartedEvent) workflowEvent() {}
func (*CustomTaskStartedEvent) EventType() EventType { return EventCustomTaskStarted }
// CustomTaskInProgressEvent signals that a custom task is in progress.
type CustomTaskInProgressEvent struct {
eventBase
Attributes json.RawMessage `json:"attributes"`
}
func (*CustomTaskInProgressEvent) workflowEvent() {}
func (*CustomTaskInProgressEvent) EventType() EventType { return EventCustomTaskInProgress }
// CustomTaskCompletedEvent signals that a custom task has completed.
type CustomTaskCompletedEvent struct {
eventBase
Attributes json.RawMessage `json:"attributes"`
}
func (*CustomTaskCompletedEvent) workflowEvent() {}
func (*CustomTaskCompletedEvent) EventType() EventType { return EventCustomTaskCompleted }
// CustomTaskFailedEvent signals that a custom task has failed.
type CustomTaskFailedEvent struct {
eventBase
Attributes json.RawMessage `json:"attributes"`
}
func (*CustomTaskFailedEvent) workflowEvent() {}
func (*CustomTaskFailedEvent) EventType() EventType { return EventCustomTaskFailed }
// CustomTaskTimedOutEvent signals that a custom task timed out.
type CustomTaskTimedOutEvent struct {
eventBase
Attributes json.RawMessage `json:"attributes"`
}
func (*CustomTaskTimedOutEvent) workflowEvent() {}
func (*CustomTaskTimedOutEvent) EventType() EventType { return EventCustomTaskTimedOut }
// CustomTaskCanceledEvent signals that a custom task was canceled.
type CustomTaskCanceledEvent struct {
eventBase
Attributes json.RawMessage `json:"attributes"`
}
func (*CustomTaskCanceledEvent) workflowEvent() {}
func (*CustomTaskCanceledEvent) EventType() EventType { return EventCustomTaskCanceled }
// ActivityTaskStartedEvent signals that an activity task has started.
type ActivityTaskStartedEvent struct {
eventBase
Attributes json.RawMessage `json:"attributes"`
}
func (*ActivityTaskStartedEvent) workflowEvent() {}
func (*ActivityTaskStartedEvent) EventType() EventType { return EventActivityTaskStarted }
// ActivityTaskCompletedEvent signals that an activity task has completed.
type ActivityTaskCompletedEvent struct {
eventBase
Attributes json.RawMessage `json:"attributes"`
}
func (*ActivityTaskCompletedEvent) workflowEvent() {}
func (*ActivityTaskCompletedEvent) EventType() EventType { return EventActivityTaskCompleted }
// ActivityTaskRetryingEvent signals that an activity task is being retried.
type ActivityTaskRetryingEvent struct {
eventBase
Attributes json.RawMessage `json:"attributes"`
}
func (*ActivityTaskRetryingEvent) workflowEvent() {}
func (*ActivityTaskRetryingEvent) EventType() EventType { return EventActivityTaskRetrying }
// ActivityTaskFailedEvent signals that an activity task has failed.
type ActivityTaskFailedEvent struct {
eventBase
Attributes json.RawMessage `json:"attributes"`
}
func (*ActivityTaskFailedEvent) workflowEvent() {}
func (*ActivityTaskFailedEvent) EventType() EventType { return EventActivityTaskFailed }
// UnknownEvent holds an event with an unrecognized event_type.
// This prevents the SDK from breaking when new event types are added.
type UnknownEvent struct {
eventBase
RawType string
Raw json.RawMessage
}
func (*UnknownEvent) workflowEvent() {}
func (e *UnknownEvent) EventType() EventType { return EventType(e.RawType) }
// UnmarshalEvent dispatches JSON to the concrete Event type
// based on the "event_type" discriminator field.
func UnmarshalEvent(data []byte) (Event, error) {
var probe struct {
Type string `json:"event_type"`
}
if err := json.Unmarshal(data, &probe); err != nil {
return nil, fmt.Errorf("mistral: unmarshal workflow event: %w", err)
}
switch probe.Type {
case string(EventWorkflowStarted):
var e WorkflowExecutionStartedEvent
return &e, json.Unmarshal(data, &e)
case string(EventWorkflowCompleted):
var e WorkflowExecutionCompletedEvent
return &e, json.Unmarshal(data, &e)
case string(EventWorkflowFailed):
var e WorkflowExecutionFailedEvent
return &e, json.Unmarshal(data, &e)
case string(EventWorkflowCanceled):
var e WorkflowExecutionCanceledEvent
return &e, json.Unmarshal(data, &e)
case string(EventWorkflowContinuedAsNew):
var e WorkflowExecutionContinuedAsNewEvent
return &e, json.Unmarshal(data, &e)
case string(EventWorkflowTaskTimedOut):
var e WorkflowTaskTimedOutEvent
return &e, json.Unmarshal(data, &e)
case string(EventWorkflowTaskFailed):
var e WorkflowTaskFailedEvent
return &e, json.Unmarshal(data, &e)
case string(EventCustomTaskStarted):
var e CustomTaskStartedEvent
return &e, json.Unmarshal(data, &e)
case string(EventCustomTaskInProgress):
var e CustomTaskInProgressEvent
return &e, json.Unmarshal(data, &e)
case string(EventCustomTaskCompleted):
var e CustomTaskCompletedEvent
return &e, json.Unmarshal(data, &e)
case string(EventCustomTaskFailed):
var e CustomTaskFailedEvent
return &e, json.Unmarshal(data, &e)
case string(EventCustomTaskTimedOut):
var e CustomTaskTimedOutEvent
return &e, json.Unmarshal(data, &e)
case string(EventCustomTaskCanceled):
var e CustomTaskCanceledEvent
return &e, json.Unmarshal(data, &e)
case string(EventActivityTaskStarted):
var e ActivityTaskStartedEvent
return &e, json.Unmarshal(data, &e)
case string(EventActivityTaskCompleted):
var e ActivityTaskCompletedEvent
return &e, json.Unmarshal(data, &e)
case string(EventActivityTaskRetrying):
var e ActivityTaskRetryingEvent
return &e, json.Unmarshal(data, &e)
case string(EventActivityTaskFailed):
var e ActivityTaskFailedEvent
return &e, json.Unmarshal(data, &e)
default:
var base eventBase
if err := json.Unmarshal(data, &base); err != nil {
return nil, fmt.Errorf("mistral: unmarshal workflow event base: %w", err)
}
return &UnknownEvent{
eventBase: base,
RawType: probe.Type,
Raw: json.RawMessage(data),
}, nil
}
}
// StreamPayload is a single SSE payload from the workflow event stream.
type StreamPayload struct {
Stream string `json:"stream"`
Data json.RawMessage `json:"data"`
WorkflowContext StreamWorkflowContext `json:"workflow_context"`
BrokerSequence int64 `json:"broker_sequence"`
}
// StreamWorkflowContext holds context for a workflow event stream.
type StreamWorkflowContext struct {
WorkflowName string `json:"workflow_name"`
ExecutionID string `json:"execution_id"`
}
// EventStreamParams holds query parameters for streaming workflow events.
type EventStreamParams struct {
Scope *Scope
ActivityName *string
ActivityID *string
WorkflowName *string
WorkflowExecID *string
RootWorkflowExecID *string
ParentWorkflowExecID *string
Stream *string
StartSeq *int
MetadataFilters map[string]any
WorkflowEventTypes []EventType
LastEventID *string
}
// EventListParams holds query parameters for listing workflow events.
type EventListParams struct {
RootWorkflowExecID *string
WorkflowExecID *string
WorkflowRunID *string
Limit *int
Cursor *string
}
// EventListResponse is the response from listing workflow events.
type EventListResponse struct {
Events []json.RawMessage `json:"events"`
NextCursor *string `json:"next_cursor,omitempty"`
}

158
workflow/event_test.go Normal file
View File

@@ -0,0 +1,158 @@
package workflow
import (
"encoding/json"
"fmt"
"testing"
)
func TestUnmarshalEvent_WorkflowExecutionCompleted(t *testing.T) {
data := []byte(`{
"event_id": "evt-1",
"event_timestamp": 1711929600000000000,
"root_workflow_exec_id": "root-1",
"parent_workflow_exec_id": null,
"workflow_exec_id": "exec-1",
"workflow_run_id": "run-1",
"workflow_name": "my-workflow",
"event_type": "WORKFLOW_EXECUTION_COMPLETED",
"attributes": {"task_id": "t1", "result": {"value": {"answer": 42}, "type": "json"}}
}`)
event, err := UnmarshalEvent(data)
if err != nil {
t.Fatal(err)
}
e, ok := event.(*WorkflowExecutionCompletedEvent)
if !ok {
t.Fatalf("expected *WorkflowExecutionCompletedEvent, got %T", event)
}
if e.ID != "evt-1" {
t.Errorf("got ID %q", e.ID)
}
if e.WorkflowName != "my-workflow" {
t.Errorf("got WorkflowName %q", e.WorkflowName)
}
if e.EventType() != EventWorkflowCompleted {
t.Errorf("got EventType %q", e.EventType())
}
}
func TestUnmarshalEvent_CustomTaskStarted(t *testing.T) {
data := []byte(`{
"event_id": "evt-2",
"event_timestamp": 1711929600000000000,
"root_workflow_exec_id": "root-1",
"parent_workflow_exec_id": "parent-1",
"workflow_exec_id": "exec-1",
"workflow_run_id": "run-1",
"workflow_name": "my-workflow",
"event_type": "CUSTOM_TASK_STARTED",
"attributes": {}
}`)
event, err := UnmarshalEvent(data)
if err != nil {
t.Fatal(err)
}
e, ok := event.(*CustomTaskStartedEvent)
if !ok {
t.Fatalf("expected *CustomTaskStartedEvent, got %T", event)
}
parent := "parent-1"
if e.ParentWorkflowExecID == nil || *e.ParentWorkflowExecID != parent {
t.Errorf("expected parent %q, got %v", parent, e.ParentWorkflowExecID)
}
}
func TestUnmarshalEvent_ActivityTaskRetrying(t *testing.T) {
data := []byte(`{
"event_id": "evt-3",
"event_timestamp": 1711929600000000000,
"root_workflow_exec_id": "root-1",
"parent_workflow_exec_id": null,
"workflow_exec_id": "exec-1",
"workflow_run_id": "run-1",
"workflow_name": "my-workflow",
"event_type": "ACTIVITY_TASK_RETRYING",
"attributes": {}
}`)
event, err := UnmarshalEvent(data)
if err != nil {
t.Fatal(err)
}
if _, ok := event.(*ActivityTaskRetryingEvent); !ok {
t.Fatalf("expected *ActivityTaskRetryingEvent, got %T", event)
}
}
func TestUnmarshalEvent_UnknownType(t *testing.T) {
data := []byte(`{
"event_id": "evt-4",
"event_timestamp": 1711929600000000000,
"root_workflow_exec_id": "root-1",
"parent_workflow_exec_id": null,
"workflow_exec_id": "exec-1",
"workflow_run_id": "run-1",
"workflow_name": "my-workflow",
"event_type": "FUTURE_EVENT_TYPE",
"attributes": {}
}`)
event, err := UnmarshalEvent(data)
if err != nil {
t.Fatal(err)
}
unk, ok := event.(*UnknownEvent)
if !ok {
t.Fatalf("expected *UnknownEvent, got %T", event)
}
if unk.RawType != "FUTURE_EVENT_TYPE" {
t.Errorf("got RawType %q", unk.RawType)
}
}
func TestUnmarshalEvent_AllTypes(t *testing.T) {
types := []struct {
eventType string
wantType string
}{
{"WORKFLOW_EXECUTION_STARTED", "*workflow.WorkflowExecutionStartedEvent"},
{"WORKFLOW_EXECUTION_COMPLETED", "*workflow.WorkflowExecutionCompletedEvent"},
{"WORKFLOW_EXECUTION_FAILED", "*workflow.WorkflowExecutionFailedEvent"},
{"WORKFLOW_EXECUTION_CANCELED", "*workflow.WorkflowExecutionCanceledEvent"},
{"WORKFLOW_EXECUTION_CONTINUED_AS_NEW", "*workflow.WorkflowExecutionContinuedAsNewEvent"},
{"WORKFLOW_TASK_TIMED_OUT", "*workflow.WorkflowTaskTimedOutEvent"},
{"WORKFLOW_TASK_FAILED", "*workflow.WorkflowTaskFailedEvent"},
{"CUSTOM_TASK_STARTED", "*workflow.CustomTaskStartedEvent"},
{"CUSTOM_TASK_IN_PROGRESS", "*workflow.CustomTaskInProgressEvent"},
{"CUSTOM_TASK_COMPLETED", "*workflow.CustomTaskCompletedEvent"},
{"CUSTOM_TASK_FAILED", "*workflow.CustomTaskFailedEvent"},
{"CUSTOM_TASK_TIMED_OUT", "*workflow.CustomTaskTimedOutEvent"},
{"CUSTOM_TASK_CANCELED", "*workflow.CustomTaskCanceledEvent"},
{"ACTIVITY_TASK_STARTED", "*workflow.ActivityTaskStartedEvent"},
{"ACTIVITY_TASK_COMPLETED", "*workflow.ActivityTaskCompletedEvent"},
{"ACTIVITY_TASK_RETRYING", "*workflow.ActivityTaskRetryingEvent"},
{"ACTIVITY_TASK_FAILED", "*workflow.ActivityTaskFailedEvent"},
}
for _, tt := range types {
t.Run(tt.eventType, func(t *testing.T) {
data, _ := json.Marshal(map[string]any{
"event_id": "evt",
"event_timestamp": 1711929600000000000,
"root_workflow_exec_id": "root",
"parent_workflow_exec_id": nil,
"workflow_exec_id": "exec",
"workflow_run_id": "run",
"workflow_name": "wf",
"event_type": tt.eventType,
"attributes": map[string]any{},
})
event, err := UnmarshalEvent(data)
if err != nil {
t.Fatal(err)
}
got := fmt.Sprintf("%T", event)
if got != tt.wantType {
t.Errorf("event_type %q: got %s, want %s", tt.eventType, got, tt.wantType)
}
})
}
}

163
workflow/execution.go Normal file
View File

@@ -0,0 +1,163 @@
package workflow
import "encoding/json"
// ExecutionStatus is the status of a workflow execution.
type ExecutionStatus string
const (
ExecutionRunning ExecutionStatus = "RUNNING"
ExecutionCompleted ExecutionStatus = "COMPLETED"
ExecutionFailed ExecutionStatus = "FAILED"
ExecutionCanceled ExecutionStatus = "CANCELED"
ExecutionTerminated ExecutionStatus = "TERMINATED"
ExecutionContinuedAsNew ExecutionStatus = "CONTINUED_AS_NEW"
ExecutionTimedOut ExecutionStatus = "TIMED_OUT"
ExecutionRetryingAfterErr ExecutionStatus = "RETRYING_AFTER_ERROR"
)
// ExecutionRequest is the request body for executing a workflow.
type ExecutionRequest struct {
ExecutionID *string `json:"execution_id,omitempty"`
Input map[string]any `json:"input,omitempty"`
EncodedInput *NetworkEncodedInput `json:"encoded_input,omitempty"`
WaitForResult bool `json:"wait_for_result,omitempty"`
TimeoutSeconds *float64 `json:"timeout_seconds,omitempty"`
CustomTracingAttributes map[string]string `json:"custom_tracing_attributes,omitempty"`
DeploymentName *string `json:"deployment_name,omitempty"`
}
// ExecutionResponse is the response from a workflow execution.
type ExecutionResponse struct {
WorkflowName string `json:"workflow_name"`
ExecutionID string `json:"execution_id"`
RootExecutionID string `json:"root_execution_id"`
Status ExecutionStatus `json:"status"`
StartTime string `json:"start_time"`
EndTime *string `json:"end_time,omitempty"`
Result any `json:"result,omitempty"`
ParentExecutionID *string `json:"parent_execution_id,omitempty"`
TotalDurationMs *int `json:"total_duration_ms,omitempty"`
}
// NetworkEncodedInput holds a base64-encoded payload for workflow input.
type NetworkEncodedInput struct {
B64Payload string `json:"b64payload"`
EncodingOptions []string `json:"encoding_options,omitempty"`
Empty bool `json:"empty,omitempty"`
}
// SignalInvocationBody is the request body for signaling a workflow execution.
type SignalInvocationBody struct {
Name string `json:"name"`
Input any `json:"input"`
}
// SignalResponse is the response from signaling a workflow execution.
type SignalResponse struct {
Message string `json:"message"`
}
// QueryInvocationBody is the request body for querying a workflow execution.
type QueryInvocationBody struct {
Name string `json:"name"`
Input any `json:"input,omitempty"`
}
// QueryResponse is the response from querying a workflow execution.
type QueryResponse struct {
QueryName string `json:"query_name"`
Result any `json:"result"`
}
// UpdateInvocationBody is the request body for updating a workflow execution.
type UpdateInvocationBody struct {
Name string `json:"name"`
Input any `json:"input,omitempty"`
}
// UpdateResponse is the response from updating a workflow execution.
type UpdateResponse struct {
UpdateName string `json:"update_name"`
Result any `json:"result"`
}
// ResetInvocationBody is the request body for resetting a workflow execution.
type ResetInvocationBody struct {
EventID int `json:"event_id"`
Reason *string `json:"reason,omitempty"`
ExcludeSignals bool `json:"exclude_signals,omitempty"`
ExcludeUpdates bool `json:"exclude_updates,omitempty"`
}
// BatchExecutionBody is the request body for batch execution operations.
type BatchExecutionBody struct {
ExecutionIDs []string `json:"execution_ids"`
}
// BatchExecutionResponse is the response from batch execution operations.
type BatchExecutionResponse struct {
Results map[string]BatchExecutionResult `json:"results,omitempty"`
}
// BatchExecutionResult is the result of a single batch operation.
type BatchExecutionResult struct {
Status string `json:"status"`
Error *string `json:"error,omitempty"`
}
// StreamParams holds query parameters for streaming workflow executions.
type StreamParams struct {
EventSource *EventSource
LastEventID *string
}
// TraceOTelResponse is the response from the OTel trace endpoint.
type TraceOTelResponse struct {
WorkflowName string `json:"workflow_name"`
ExecutionID string `json:"execution_id"`
RootExecutionID string `json:"root_execution_id"`
Status *ExecutionStatus `json:"status"`
StartTime string `json:"start_time"`
EndTime *string `json:"end_time,omitempty"`
Result any `json:"result"`
DataSource string `json:"data_source"`
ParentExecutionID *string `json:"parent_execution_id,omitempty"`
TotalDurationMs *int `json:"total_duration_ms,omitempty"`
OTelTraceID *string `json:"otel_trace_id,omitempty"`
OTelTraceData any `json:"otel_trace_data,omitempty"`
}
// TraceSummaryResponse is the response from the trace summary endpoint.
type TraceSummaryResponse struct {
WorkflowName string `json:"workflow_name"`
ExecutionID string `json:"execution_id"`
RootExecutionID string `json:"root_execution_id"`
Status *ExecutionStatus `json:"status"`
StartTime string `json:"start_time"`
EndTime *string `json:"end_time,omitempty"`
Result any `json:"result"`
ParentExecutionID *string `json:"parent_execution_id,omitempty"`
TotalDurationMs *int `json:"total_duration_ms,omitempty"`
SpanTree any `json:"span_tree,omitempty"`
}
// TraceEventsResponse is the response from the trace events endpoint.
type TraceEventsResponse struct {
WorkflowName string `json:"workflow_name"`
ExecutionID string `json:"execution_id"`
RootExecutionID string `json:"root_execution_id"`
Status *ExecutionStatus `json:"status"`
StartTime string `json:"start_time"`
EndTime *string `json:"end_time,omitempty"`
Result any `json:"result"`
ParentExecutionID *string `json:"parent_execution_id,omitempty"`
TotalDurationMs *int `json:"total_duration_ms,omitempty"`
Events []json.RawMessage `json:"events,omitempty"`
}
// TraceEventsParams holds query parameters for the trace events endpoint.
type TraceEventsParams struct {
MergeSameIDEvents *bool
IncludeInternalEvents *bool
}

27
workflow/metrics.go Normal file
View File

@@ -0,0 +1,27 @@
package workflow
// Metrics holds workflow performance metrics.
type Metrics struct {
ExecutionCount ScalarMetric `json:"execution_count"`
SuccessCount ScalarMetric `json:"success_count"`
ErrorCount ScalarMetric `json:"error_count"`
AverageLatencyMs ScalarMetric `json:"average_latency_ms"`
LatencyOverTime TimeSeriesMetric `json:"latency_over_time"`
RetryRate ScalarMetric `json:"retry_rate"`
}
// ScalarMetric holds a single numeric metric value.
type ScalarMetric struct {
Value float64 `json:"value"`
}
// TimeSeriesMetric holds a time series of [timestamp, value] pairs.
type TimeSeriesMetric struct {
Value [][]float64 `json:"value"`
}
// MetricsParams holds query parameters for workflow metrics.
type MetricsParams struct {
StartTime *string
EndTime *string
}

44
workflow/registration.go Normal file
View File

@@ -0,0 +1,44 @@
package workflow
// Registration represents a workflow registration.
type Registration struct {
ID string `json:"id"`
WorkflowID string `json:"workflow_id"`
TaskQueue string `json:"task_queue"`
Workflow *Workflow `json:"workflow,omitempty"`
CreatedAt string `json:"created_at"`
UpdatedAt string `json:"updated_at"`
}
// RegistrationListResponse is the response from listing workflow registrations.
type RegistrationListResponse struct {
Registrations []Registration `json:"registrations"`
NextCursor *string `json:"next_cursor,omitempty"`
}
// RegistrationListParams holds query parameters for listing registrations.
type RegistrationListParams struct {
WorkflowID *string
TaskQueue *string
ActiveOnly *bool
IncludeShared *bool
WorkflowSearch *string
Archived *bool
WithWorkflow *bool
AvailableInChatAssistant *bool
Limit *int
Cursor *string
}
// RegistrationGetParams holds query parameters for getting a registration.
type RegistrationGetParams struct {
WithWorkflow *bool
IncludeShared *bool
}
// WorkerInfo holds information about the current worker.
type WorkerInfo struct {
SchedulerURL string `json:"scheduler_url"`
Namespace string `json:"namespace"`
TLS bool `json:"tls"`
}

26
workflow/run.go Normal file
View File

@@ -0,0 +1,26 @@
package workflow
// Run represents a workflow run.
type Run struct {
ID string `json:"id"`
WorkflowName string `json:"workflow_name"`
ExecutionID string `json:"execution_id"`
Status ExecutionStatus `json:"status"`
StartTime string `json:"start_time"`
EndTime *string `json:"end_time,omitempty"`
}
// ListRunsResponse is the response from listing workflow runs.
type ListRunsResponse struct {
Runs []Run `json:"runs"`
NextPageToken *string `json:"next_page_token,omitempty"`
}
// RunListParams holds query parameters for listing workflow runs.
type RunListParams struct {
WorkflowIdentifier *string
Search *string
Status *string
PageSize *int
NextPageToken *string
}

75
workflow/schedule.go Normal file
View File

@@ -0,0 +1,75 @@
package workflow
// ScheduleRequest is the request body for scheduling a workflow.
type ScheduleRequest struct {
Schedule ScheduleDefinition `json:"schedule"`
WorkflowRegistrationID *string `json:"workflow_registration_id,omitempty"`
WorkflowIdentifier *string `json:"workflow_identifier,omitempty"`
ScheduleID *string `json:"schedule_id,omitempty"`
DeploymentName *string `json:"deployment_name,omitempty"`
}
// ScheduleDefinition describes when and how a workflow should be scheduled.
type ScheduleDefinition struct {
Input any `json:"input"`
Calendars []ScheduleCalendar `json:"calendars,omitempty"`
Intervals []ScheduleInterval `json:"intervals,omitempty"`
CronExpressions []string `json:"cron_expressions,omitempty"`
Skip []ScheduleCalendar `json:"skip,omitempty"`
StartAt *string `json:"start_at,omitempty"`
EndAt *string `json:"end_at,omitempty"`
Jitter *string `json:"jitter,omitempty"`
TimeZoneName *string `json:"time_zone_name,omitempty"`
Policy *SchedulePolicy `json:"policy,omitempty"`
}
// ScheduleCalendar defines calendar-based schedule entries.
type ScheduleCalendar struct {
Second []ScheduleRange `json:"second,omitempty"`
Minute []ScheduleRange `json:"minute,omitempty"`
Hour []ScheduleRange `json:"hour,omitempty"`
DayOfMonth []ScheduleRange `json:"day_of_month,omitempty"`
Month []ScheduleRange `json:"month,omitempty"`
Year []ScheduleRange `json:"year,omitempty"`
DayOfWeek []ScheduleRange `json:"day_of_week,omitempty"`
Comment *string `json:"comment,omitempty"`
}
// ScheduleRange defines a numeric range for calendar schedules.
type ScheduleRange struct {
Start int `json:"start"`
End int `json:"end,omitempty"`
Step int `json:"step,omitempty"`
}
// ScheduleInterval defines an interval-based schedule.
type ScheduleInterval struct {
Every string `json:"every"`
Offset *string `json:"offset,omitempty"`
}
// SchedulePolicy controls schedule overlap and failure behavior.
type SchedulePolicy struct {
CatchupWindowSeconds int `json:"catchup_window_seconds,omitempty"`
Overlap *int `json:"overlap,omitempty"`
PauseOnFailure bool `json:"pause_on_failure,omitempty"`
}
// ScheduleResponse is the response from creating a workflow schedule.
type ScheduleResponse struct {
ScheduleID string `json:"schedule_id"`
}
// ScheduleListResponse is the response from listing workflow schedules.
type ScheduleListResponse struct {
Schedules []Schedule `json:"schedules"`
}
// Schedule represents a workflow schedule.
type Schedule struct {
ScheduleID string `json:"schedule_id"`
Definition ScheduleDefinition `json:"definition"`
WorkflowName string `json:"workflow_name"`
CreatedAt string `json:"created_at"`
UpdatedAt string `json:"updated_at"`
}

44
workflow/workflow.go Normal file
View File

@@ -0,0 +1,44 @@
package workflow
// Workflow represents a workflow definition.
type Workflow struct {
ID string `json:"id"`
Name string `json:"name"`
DisplayName *string `json:"display_name,omitempty"`
Description *string `json:"description,omitempty"`
OwnerID string `json:"owner_id"`
WorkspaceID string `json:"workspace_id"`
AvailableInChatAssistant bool `json:"available_in_chat_assistant"`
Archived bool `json:"archived"`
CreatedAt string `json:"created_at"`
UpdatedAt string `json:"updated_at"`
}
// WorkflowUpdateRequest is the request body for updating a workflow.
type WorkflowUpdateRequest struct {
DisplayName *string `json:"display_name,omitempty"`
Description *string `json:"description,omitempty"`
AvailableInChatAssistant *bool `json:"available_in_chat_assistant,omitempty"`
}
// WorkflowListResponse is the response from listing workflows.
type WorkflowListResponse struct {
Workflows []Workflow `json:"workflows"`
NextCursor *string `json:"next_cursor,omitempty"`
}
// WorkflowListParams holds query parameters for listing workflows.
type WorkflowListParams struct {
ActiveOnly *bool
IncludeShared *bool
AvailableInChatAssistant *bool
Archived *bool
Cursor *string
Limit *int
}
// WorkflowArchiveResponse is the response from archiving/unarchiving a workflow.
type WorkflowArchiveResponse struct {
ID string `json:"id"`
Archived bool `json:"archived"`
}

168
workflows.go Normal file
View File

@@ -0,0 +1,168 @@
package mistral
import (
"context"
"fmt"
"net/url"
"strconv"
"github.com/VikingOwl91/mistral-go-sdk/workflow"
)
// ListWorkflows lists workflows.
func (c *Client) ListWorkflows(ctx context.Context, params *workflow.WorkflowListParams) (*workflow.WorkflowListResponse, error) {
path := "/v1/workflows"
if params != nil {
q := url.Values{}
if params.ActiveOnly != nil {
q.Set("active_only", strconv.FormatBool(*params.ActiveOnly))
}
if params.IncludeShared != nil {
q.Set("include_shared", strconv.FormatBool(*params.IncludeShared))
}
if params.AvailableInChatAssistant != nil {
q.Set("available_in_chat_assistant", strconv.FormatBool(*params.AvailableInChatAssistant))
}
if params.Archived != nil {
q.Set("archived", strconv.FormatBool(*params.Archived))
}
if params.Cursor != nil {
q.Set("cursor", *params.Cursor)
}
if params.Limit != nil {
q.Set("limit", strconv.Itoa(*params.Limit))
}
if encoded := q.Encode(); encoded != "" {
path += "?" + encoded
}
}
var resp workflow.WorkflowListResponse
if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// GetWorkflow retrieves a workflow by identifier.
func (c *Client) GetWorkflow(ctx context.Context, workflowIdentifier string) (*workflow.Workflow, error) {
var resp workflow.Workflow
if err := c.doJSON(ctx, "GET", fmt.Sprintf("/v1/workflows/%s", workflowIdentifier), nil, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// UpdateWorkflow updates a workflow.
func (c *Client) UpdateWorkflow(ctx context.Context, workflowIdentifier string, req *workflow.WorkflowUpdateRequest) (*workflow.Workflow, error) {
var resp workflow.Workflow
if err := c.doJSON(ctx, "PUT", fmt.Sprintf("/v1/workflows/%s", workflowIdentifier), req, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// ArchiveWorkflow archives a workflow.
func (c *Client) ArchiveWorkflow(ctx context.Context, workflowIdentifier string) (*workflow.WorkflowArchiveResponse, error) {
var resp workflow.WorkflowArchiveResponse
if err := c.doJSON(ctx, "PUT", fmt.Sprintf("/v1/workflows/%s/archive", workflowIdentifier), nil, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// UnarchiveWorkflow unarchives a workflow.
func (c *Client) UnarchiveWorkflow(ctx context.Context, workflowIdentifier string) (*workflow.WorkflowArchiveResponse, error) {
var resp workflow.WorkflowArchiveResponse
if err := c.doJSON(ctx, "PUT", fmt.Sprintf("/v1/workflows/%s/unarchive", workflowIdentifier), nil, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// ExecuteWorkflow executes a workflow.
func (c *Client) ExecuteWorkflow(ctx context.Context, workflowIdentifier string, req *workflow.ExecutionRequest) (*workflow.ExecutionResponse, error) {
var resp workflow.ExecutionResponse
if err := c.doJSON(ctx, "POST", fmt.Sprintf("/v1/workflows/%s/execute", workflowIdentifier), req, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// ListWorkflowRegistrations lists workflow registrations.
func (c *Client) ListWorkflowRegistrations(ctx context.Context, params *workflow.RegistrationListParams) (*workflow.RegistrationListResponse, error) {
path := "/v1/workflows/registrations"
if params != nil {
q := url.Values{}
if params.WorkflowID != nil {
q.Set("workflow_id", *params.WorkflowID)
}
if params.TaskQueue != nil {
q.Set("task_queue", *params.TaskQueue)
}
if params.ActiveOnly != nil {
q.Set("active_only", strconv.FormatBool(*params.ActiveOnly))
}
if params.IncludeShared != nil {
q.Set("include_shared", strconv.FormatBool(*params.IncludeShared))
}
if params.WorkflowSearch != nil {
q.Set("workflow_search", *params.WorkflowSearch)
}
if params.Archived != nil {
q.Set("archived", strconv.FormatBool(*params.Archived))
}
if params.WithWorkflow != nil {
q.Set("with_workflow", strconv.FormatBool(*params.WithWorkflow))
}
if params.AvailableInChatAssistant != nil {
q.Set("available_in_chat_assistant", strconv.FormatBool(*params.AvailableInChatAssistant))
}
if params.Limit != nil {
q.Set("limit", strconv.Itoa(*params.Limit))
}
if params.Cursor != nil {
q.Set("cursor", *params.Cursor)
}
if encoded := q.Encode(); encoded != "" {
path += "?" + encoded
}
}
var resp workflow.RegistrationListResponse
if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// GetWorkflowRegistration retrieves a workflow registration by ID.
func (c *Client) GetWorkflowRegistration(ctx context.Context, registrationID string, params *workflow.RegistrationGetParams) (*workflow.Registration, error) {
path := fmt.Sprintf("/v1/workflows/registrations/%s", registrationID)
if params != nil {
q := url.Values{}
if params.WithWorkflow != nil {
q.Set("with_workflow", strconv.FormatBool(*params.WithWorkflow))
}
if params.IncludeShared != nil {
q.Set("include_shared", strconv.FormatBool(*params.IncludeShared))
}
if encoded := q.Encode(); encoded != "" {
path += "?" + encoded
}
}
var resp workflow.Registration
if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// ExecuteWorkflowRegistration executes a workflow via its registration.
//
// Deprecated: Use ExecuteWorkflow instead. This method will be removed in a future release.
func (c *Client) ExecuteWorkflowRegistration(ctx context.Context, registrationID string, req *workflow.ExecutionRequest) (*workflow.ExecutionResponse, error) {
var resp workflow.ExecutionResponse
if err := c.doJSON(ctx, "POST", fmt.Sprintf("/v1/workflows/registrations/%s/execute", registrationID), req, &resp); err != nil {
return nil, err
}
return &resp, nil
}

41
workflows_deployments.go Normal file
View File

@@ -0,0 +1,41 @@
package mistral
import (
"context"
"fmt"
"net/url"
"strconv"
"github.com/VikingOwl91/mistral-go-sdk/workflow"
)
// ListWorkflowDeployments lists workflow deployments.
func (c *Client) ListWorkflowDeployments(ctx context.Context, params *workflow.DeploymentListParams) (*workflow.DeploymentListResponse, error) {
path := "/v1/workflows/deployments"
if params != nil {
q := url.Values{}
if params.ActiveOnly != nil {
q.Set("active_only", strconv.FormatBool(*params.ActiveOnly))
}
if params.WorkflowName != nil {
q.Set("workflow_name", *params.WorkflowName)
}
if encoded := q.Encode(); encoded != "" {
path += "?" + encoded
}
}
var resp workflow.DeploymentListResponse
if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// GetWorkflowDeployment retrieves a workflow deployment by ID.
func (c *Client) GetWorkflowDeployment(ctx context.Context, deploymentID string) (*workflow.Deployment, error) {
var resp workflow.Deployment
if err := c.doJSON(ctx, "GET", fmt.Sprintf("/v1/workflows/deployments/%s", deploymentID), nil, &resp); err != nil {
return nil, err
}
return &resp, nil
}

View File

@@ -0,0 +1,57 @@
package mistral
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
)
func TestListWorkflowDeployments_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/workflows/deployments" {
t.Errorf("got path %s", r.URL.Path)
}
json.NewEncoder(w).Encode(map[string]any{
"deployments": []map[string]any{
{"id": "dep-1", "name": "prod", "is_active": true, "created_at": "2026-01-01", "updated_at": "2026-01-01"},
},
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.ListWorkflowDeployments(context.Background(), nil)
if err != nil {
t.Fatal(err)
}
if len(resp.Deployments) != 1 {
t.Fatalf("got %d deployments", len(resp.Deployments))
}
if resp.Deployments[0].Name != "prod" {
t.Errorf("got name %q", resp.Deployments[0].Name)
}
}
func TestGetWorkflowDeployment_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/workflows/deployments/dep-1" {
t.Errorf("got path %s", r.URL.Path)
}
json.NewEncoder(w).Encode(map[string]any{
"id": "dep-1", "name": "prod", "is_active": true,
"created_at": "2026-01-01", "updated_at": "2026-01-01",
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
dep, err := client.GetWorkflowDeployment(context.Background(), "dep-1")
if err != nil {
t.Fatal(err)
}
if dep.ID != "dep-1" {
t.Errorf("got id %q", dep.ID)
}
}

99
workflows_events.go Normal file
View File

@@ -0,0 +1,99 @@
package mistral
import (
"context"
"encoding/json"
"net/url"
"strconv"
"strings"
"github.com/VikingOwl91/mistral-go-sdk/workflow"
)
// StreamWorkflowEvents streams workflow events via SSE.
func (c *Client) StreamWorkflowEvents(ctx context.Context, params *workflow.EventStreamParams) (*WorkflowEventStream, error) {
path := "/v1/workflows/events/stream"
if params != nil {
q := url.Values{}
if params.Scope != nil {
q.Set("scope", string(*params.Scope))
}
if params.ActivityName != nil {
q.Set("activity_name", *params.ActivityName)
}
if params.ActivityID != nil {
q.Set("activity_id", *params.ActivityID)
}
if params.WorkflowName != nil {
q.Set("workflow_name", *params.WorkflowName)
}
if params.WorkflowExecID != nil {
q.Set("workflow_exec_id", *params.WorkflowExecID)
}
if params.RootWorkflowExecID != nil {
q.Set("root_workflow_exec_id", *params.RootWorkflowExecID)
}
if params.ParentWorkflowExecID != nil {
q.Set("parent_workflow_exec_id", *params.ParentWorkflowExecID)
}
if params.Stream != nil {
q.Set("stream", *params.Stream)
}
if params.StartSeq != nil {
q.Set("start_seq", strconv.Itoa(*params.StartSeq))
}
if params.MetadataFilters != nil {
data, _ := json.Marshal(params.MetadataFilters)
q.Set("metadata_filters", string(data))
}
if len(params.WorkflowEventTypes) > 0 {
types := make([]string, len(params.WorkflowEventTypes))
for i, et := range params.WorkflowEventTypes {
types[i] = string(et)
}
q.Set("workflow_event_types", strings.Join(types, ","))
}
if params.LastEventID != nil {
q.Set("last_event_id", *params.LastEventID)
}
if encoded := q.Encode(); encoded != "" {
path += "?" + encoded
}
}
resp, err := c.doStream(ctx, "GET", path, nil)
if err != nil {
return nil, err
}
return newWorkflowEventStream(resp.Body), nil
}
// ListWorkflowEvents lists workflow events.
func (c *Client) ListWorkflowEvents(ctx context.Context, params *workflow.EventListParams) (*workflow.EventListResponse, error) {
path := "/v1/workflows/events/list"
if params != nil {
q := url.Values{}
if params.RootWorkflowExecID != nil {
q.Set("root_workflow_exec_id", *params.RootWorkflowExecID)
}
if params.WorkflowExecID != nil {
q.Set("workflow_exec_id", *params.WorkflowExecID)
}
if params.WorkflowRunID != nil {
q.Set("workflow_run_id", *params.WorkflowRunID)
}
if params.Limit != nil {
q.Set("limit", strconv.Itoa(*params.Limit))
}
if params.Cursor != nil {
q.Set("cursor", *params.Cursor)
}
if encoded := q.Encode(); encoded != "" {
path += "?" + encoded
}
}
var resp workflow.EventListResponse
if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil {
return nil, err
}
return &resp, nil
}

40
workflows_events_test.go Normal file
View File

@@ -0,0 +1,40 @@
package mistral
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/VikingOwl91/mistral-go-sdk/workflow"
)
func TestListWorkflowEvents_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/workflows/events/list" {
t.Errorf("got path %s", r.URL.Path)
}
if r.URL.Query().Get("limit") != "50" {
t.Errorf("got limit %q", r.URL.Query().Get("limit"))
}
json.NewEncoder(w).Encode(map[string]any{
"events": []map[string]any{{"event_type": "WORKFLOW_EXECUTION_STARTED"}},
"next_cursor": "cur-1",
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
limit := 50
resp, err := client.ListWorkflowEvents(context.Background(), &workflow.EventListParams{Limit: &limit})
if err != nil {
t.Fatal(err)
}
if len(resp.Events) != 1 {
t.Fatalf("got %d events", len(resp.Events))
}
if resp.NextCursor == nil || *resp.NextCursor != "cur-1" {
t.Errorf("got cursor %v", resp.NextCursor)
}
}

255
workflows_executions.go Normal file
View File

@@ -0,0 +1,255 @@
package mistral
import (
"context"
"encoding/json"
"fmt"
"net/url"
"strconv"
"time"
"github.com/VikingOwl91/mistral-go-sdk/workflow"
)
// GetWorkflowExecution retrieves a workflow execution by ID.
func (c *Client) GetWorkflowExecution(ctx context.Context, executionID string) (*workflow.ExecutionResponse, error) {
var resp workflow.ExecutionResponse
if err := c.doJSON(ctx, "GET", fmt.Sprintf("/v1/workflows/executions/%s", executionID), nil, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// GetWorkflowExecutionHistory retrieves the history of a workflow execution.
func (c *Client) GetWorkflowExecutionHistory(ctx context.Context, executionID string, decodePayloads *bool) (json.RawMessage, error) {
path := fmt.Sprintf("/v1/workflows/executions/%s/history", executionID)
if decodePayloads != nil {
path += "?decode_payloads=" + strconv.FormatBool(*decodePayloads)
}
var resp json.RawMessage
if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil {
return nil, err
}
return resp, nil
}
// StreamWorkflowExecution streams events for a workflow execution via SSE.
func (c *Client) StreamWorkflowExecution(ctx context.Context, executionID string, params *workflow.StreamParams) (*WorkflowEventStream, error) {
path := fmt.Sprintf("/v1/workflows/executions/%s/stream", executionID)
if params != nil {
q := url.Values{}
if params.EventSource != nil {
q.Set("event_source", string(*params.EventSource))
}
if params.LastEventID != nil {
q.Set("last_event_id", *params.LastEventID)
}
if encoded := q.Encode(); encoded != "" {
path += "?" + encoded
}
}
resp, err := c.doStream(ctx, "GET", path, nil)
if err != nil {
return nil, err
}
return newWorkflowEventStream(resp.Body), nil
}
// SignalWorkflowExecution sends a signal to a workflow execution.
func (c *Client) SignalWorkflowExecution(ctx context.Context, executionID string, req *workflow.SignalInvocationBody) (*workflow.SignalResponse, error) {
var resp workflow.SignalResponse
if err := c.doJSON(ctx, "POST", fmt.Sprintf("/v1/workflows/executions/%s/signals", executionID), req, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// QueryWorkflowExecution queries a workflow execution.
func (c *Client) QueryWorkflowExecution(ctx context.Context, executionID string, req *workflow.QueryInvocationBody) (*workflow.QueryResponse, error) {
var resp workflow.QueryResponse
if err := c.doJSON(ctx, "POST", fmt.Sprintf("/v1/workflows/executions/%s/queries", executionID), req, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// UpdateWorkflowExecution sends an update to a workflow execution.
func (c *Client) UpdateWorkflowExecution(ctx context.Context, executionID string, req *workflow.UpdateInvocationBody) (*workflow.UpdateResponse, error) {
var resp workflow.UpdateResponse
if err := c.doJSON(ctx, "POST", fmt.Sprintf("/v1/workflows/executions/%s/updates", executionID), req, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// TerminateWorkflowExecution terminates a workflow execution.
func (c *Client) TerminateWorkflowExecution(ctx context.Context, executionID string) error {
resp, err := c.do(ctx, "POST", fmt.Sprintf("/v1/workflows/executions/%s/terminate", executionID), nil)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode >= 400 {
return parseAPIError(resp)
}
return nil
}
// CancelWorkflowExecution cancels a workflow execution.
func (c *Client) CancelWorkflowExecution(ctx context.Context, executionID string) error {
resp, err := c.do(ctx, "POST", fmt.Sprintf("/v1/workflows/executions/%s/cancel", executionID), nil)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode >= 400 {
return parseAPIError(resp)
}
return nil
}
// ResetWorkflowExecution resets a workflow execution to a specific event.
func (c *Client) ResetWorkflowExecution(ctx context.Context, executionID string, req *workflow.ResetInvocationBody) error {
return c.doJSON(ctx, "POST", fmt.Sprintf("/v1/workflows/executions/%s/reset", executionID), req, nil)
}
// BatchCancelWorkflowExecutions cancels multiple workflow executions.
func (c *Client) BatchCancelWorkflowExecutions(ctx context.Context, req *workflow.BatchExecutionBody) (*workflow.BatchExecutionResponse, error) {
var resp workflow.BatchExecutionResponse
if err := c.doJSON(ctx, "POST", "/v1/workflows/executions/cancel", req, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// BatchTerminateWorkflowExecutions terminates multiple workflow executions.
func (c *Client) BatchTerminateWorkflowExecutions(ctx context.Context, req *workflow.BatchExecutionBody) (*workflow.BatchExecutionResponse, error) {
var resp workflow.BatchExecutionResponse
if err := c.doJSON(ctx, "POST", "/v1/workflows/executions/terminate", req, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// GetWorkflowExecutionTraceOTel retrieves the OpenTelemetry trace for a workflow execution.
func (c *Client) GetWorkflowExecutionTraceOTel(ctx context.Context, executionID string) (*workflow.TraceOTelResponse, error) {
var resp workflow.TraceOTelResponse
if err := c.doJSON(ctx, "GET", fmt.Sprintf("/v1/workflows/executions/%s/trace/otel", executionID), nil, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// GetWorkflowExecutionTraceSummary retrieves the trace summary for a workflow execution.
func (c *Client) GetWorkflowExecutionTraceSummary(ctx context.Context, executionID string) (*workflow.TraceSummaryResponse, error) {
var resp workflow.TraceSummaryResponse
if err := c.doJSON(ctx, "GET", fmt.Sprintf("/v1/workflows/executions/%s/trace/summary", executionID), nil, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// GetWorkflowExecutionTraceEvents retrieves the trace events for a workflow execution.
func (c *Client) GetWorkflowExecutionTraceEvents(ctx context.Context, executionID string, params *workflow.TraceEventsParams) (*workflow.TraceEventsResponse, error) {
path := fmt.Sprintf("/v1/workflows/executions/%s/trace/events", executionID)
if params != nil {
q := url.Values{}
if params.MergeSameIDEvents != nil {
q.Set("merge_same_id_events", strconv.FormatBool(*params.MergeSameIDEvents))
}
if params.IncludeInternalEvents != nil {
q.Set("include_internal_events", strconv.FormatBool(*params.IncludeInternalEvents))
}
if encoded := q.Encode(); encoded != "" {
path += "?" + encoded
}
}
var resp workflow.TraceEventsResponse
if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// WorkflowEventStream wraps the generic Stream to provide typed workflow events
// with StreamPayload envelope metadata.
type WorkflowEventStream struct {
stream *Stream[json.RawMessage]
event workflow.Event
payload *workflow.StreamPayload
err error
}
func newWorkflowEventStream(body readCloser) *WorkflowEventStream {
return &WorkflowEventStream{
stream: newStream[json.RawMessage](body),
}
}
// Next advances to the next event. Returns false when done or on error.
func (s *WorkflowEventStream) Next() bool {
if s.err != nil {
return false
}
if !s.stream.Next() {
s.err = s.stream.Err()
return false
}
var payload workflow.StreamPayload
if err := json.Unmarshal(s.stream.Current(), &payload); err != nil {
s.err = fmt.Errorf("mistral: decode workflow stream payload: %w", err)
return false
}
event, err := workflow.UnmarshalEvent(payload.Data)
if err != nil {
s.err = err
return false
}
s.event = event
s.payload = &payload
return true
}
// Current returns the most recently read workflow event.
func (s *WorkflowEventStream) Current() workflow.Event { return s.event }
// CurrentPayload returns the full StreamPayload envelope of the current event.
func (s *WorkflowEventStream) CurrentPayload() *workflow.StreamPayload { return s.payload }
// Err returns any error encountered during streaming.
func (s *WorkflowEventStream) Err() error { return s.err }
// Close releases the underlying connection.
func (s *WorkflowEventStream) Close() error { return s.stream.Close() }
// ExecuteWorkflowAndWait executes a workflow and polls until completion.
func (c *Client) ExecuteWorkflowAndWait(ctx context.Context, workflowIdentifier string, req *workflow.ExecutionRequest) (*workflow.ExecutionResponse, error) {
execResp, err := c.ExecuteWorkflow(ctx, workflowIdentifier, req)
if err != nil {
return nil, err
}
for {
if isTerminal(execResp.Status) {
return execResp, nil
}
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-time.After(500 * time.Millisecond):
}
execResp, err = c.GetWorkflowExecution(ctx, execResp.ExecutionID)
if err != nil {
return nil, err
}
}
}
func isTerminal(s workflow.ExecutionStatus) bool {
switch s {
case workflow.ExecutionCompleted, workflow.ExecutionFailed,
workflow.ExecutionCanceled, workflow.ExecutionTerminated,
workflow.ExecutionTimedOut:
return true
}
return false
}

View File

@@ -0,0 +1,266 @@
package mistral
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"github.com/VikingOwl91/mistral-go-sdk/workflow"
)
func TestGetWorkflowExecution_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/workflows/executions/exec-1" {
t.Errorf("got path %s", r.URL.Path)
}
json.NewEncoder(w).Encode(map[string]any{
"workflow_name": "my-flow", "execution_id": "exec-1",
"root_execution_id": "exec-1", "status": "COMPLETED",
"start_time": "2026-01-01T00:00:00Z",
"end_time": "2026-01-01T00:01:00Z",
"result": map[string]any{"answer": 42},
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.GetWorkflowExecution(context.Background(), "exec-1")
if err != nil {
t.Fatal(err)
}
if resp.Status != workflow.ExecutionCompleted {
t.Errorf("got status %q", resp.Status)
}
}
func TestSignalWorkflowExecution_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" {
t.Errorf("got method %s", r.Method)
}
if r.URL.Path != "/v1/workflows/executions/exec-1/signals" {
t.Errorf("got path %s", r.URL.Path)
}
var body map[string]any
json.NewDecoder(r.Body).Decode(&body)
if body["name"] != "approval" {
t.Errorf("got name %v", body["name"])
}
w.WriteHeader(http.StatusAccepted)
json.NewEncoder(w).Encode(map[string]any{"message": "Signal accepted"})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.SignalWorkflowExecution(context.Background(), "exec-1", &workflow.SignalInvocationBody{
Name: "approval",
Input: map[string]any{"approved": true},
})
if err != nil {
t.Fatal(err)
}
if resp.Message != "Signal accepted" {
t.Errorf("got message %q", resp.Message)
}
}
func TestTerminateWorkflowExecution_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" {
t.Errorf("got method %s", r.Method)
}
if r.URL.Path != "/v1/workflows/executions/exec-1/terminate" {
t.Errorf("got path %s", r.URL.Path)
}
w.WriteHeader(http.StatusNoContent)
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
err := client.TerminateWorkflowExecution(context.Background(), "exec-1")
if err != nil {
t.Fatal(err)
}
}
func TestBatchCancelWorkflowExecutions_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" {
t.Errorf("got method %s", r.Method)
}
if r.URL.Path != "/v1/workflows/executions/cancel" {
t.Errorf("got path %s", r.URL.Path)
}
json.NewEncoder(w).Encode(map[string]any{
"results": map[string]any{
"exec-1": map[string]any{"status": "success"},
"exec-2": map[string]any{"status": "failure", "error": "not found"},
},
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.BatchCancelWorkflowExecutions(context.Background(), &workflow.BatchExecutionBody{
ExecutionIDs: []string{"exec-1", "exec-2"},
})
if err != nil {
t.Fatal(err)
}
if resp.Results["exec-1"].Status != "success" {
t.Errorf("got exec-1 status %q", resp.Results["exec-1"].Status)
}
if resp.Results["exec-2"].Error == nil || *resp.Results["exec-2"].Error != "not found" {
t.Errorf("expected exec-2 error")
}
}
func TestStreamWorkflowExecution_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "GET" {
t.Errorf("got method %s", r.Method)
}
if r.URL.Path != "/v1/workflows/executions/exec-1/stream" {
t.Errorf("got path %s", r.URL.Path)
}
w.Header().Set("Content-Type", "text/event-stream")
flusher, _ := w.(http.Flusher)
payloads := []map[string]any{
{
"stream": "events",
"data": map[string]any{
"event_id": "evt-1", "event_timestamp": 1711929600000000000,
"root_workflow_exec_id": "exec-1", "parent_workflow_exec_id": nil,
"workflow_exec_id": "exec-1", "workflow_run_id": "run-1",
"workflow_name": "my-flow", "event_type": "WORKFLOW_EXECUTION_STARTED",
"attributes": map[string]any{},
},
"workflow_context": map[string]any{
"namespace": "default", "workflow_name": "my-flow", "workflow_exec_id": "exec-1",
},
"broker_sequence": 1,
},
{
"stream": "events",
"data": map[string]any{
"event_id": "evt-2", "event_timestamp": 1711929601000000000,
"root_workflow_exec_id": "exec-1", "parent_workflow_exec_id": nil,
"workflow_exec_id": "exec-1", "workflow_run_id": "run-1",
"workflow_name": "my-flow", "event_type": "WORKFLOW_EXECUTION_COMPLETED",
"attributes": map[string]any{"result": map[string]any{"value": 42, "type": "json"}},
},
"workflow_context": map[string]any{
"namespace": "default", "workflow_name": "my-flow", "workflow_exec_id": "exec-1",
},
"broker_sequence": 2,
},
}
for _, p := range payloads {
data, _ := json.Marshal(p)
fmt.Fprintf(w, "data: %s\n\n", data)
flusher.Flush()
}
fmt.Fprint(w, "data: [DONE]\n\n")
flusher.Flush()
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
stream, err := client.StreamWorkflowExecution(context.Background(), "exec-1", nil)
if err != nil {
t.Fatal(err)
}
defer stream.Close()
var events []workflow.Event
var lastPayload *workflow.StreamPayload
for stream.Next() {
events = append(events, stream.Current())
lastPayload = stream.CurrentPayload()
}
if stream.Err() != nil {
t.Fatal(stream.Err())
}
if len(events) != 2 {
t.Fatalf("got %d events, want 2", len(events))
}
if _, ok := events[0].(*workflow.WorkflowExecutionStartedEvent); !ok {
t.Errorf("expected *WorkflowExecutionStartedEvent, got %T", events[0])
}
if _, ok := events[1].(*workflow.WorkflowExecutionCompletedEvent); !ok {
t.Errorf("expected *WorkflowExecutionCompletedEvent, got %T", events[1])
}
if lastPayload.WorkflowContext.WorkflowName != "my-flow" {
t.Errorf("got workflow context name %q", lastPayload.WorkflowContext.WorkflowName)
}
}
func TestGetWorkflowExecutionTraceOTel_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/workflows/executions/exec-1/trace/otel" {
t.Errorf("got path %s", r.URL.Path)
}
json.NewEncoder(w).Encode(map[string]any{
"workflow_name": "my-flow", "execution_id": "exec-1",
"root_execution_id": "exec-1", "status": "COMPLETED",
"start_time": "2026-01-01T00:00:00Z", "data_source": "temporal",
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.GetWorkflowExecutionTraceOTel(context.Background(), "exec-1")
if err != nil {
t.Fatal(err)
}
if resp.DataSource != "temporal" {
t.Errorf("got data_source %q", resp.DataSource)
}
}
func TestExecuteWorkflowAndWait_Success(t *testing.T) {
calls := 0
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case r.Method == "POST" && r.URL.Path == "/v1/workflows/wf-1/execute":
json.NewEncoder(w).Encode(map[string]any{
"workflow_name": "my-flow", "execution_id": "exec-1",
"root_execution_id": "exec-1", "status": "RUNNING",
"start_time": "2026-01-01T00:00:00Z",
})
case r.Method == "GET" && r.URL.Path == "/v1/workflows/executions/exec-1":
calls++
status := "RUNNING"
if calls >= 2 {
status = "COMPLETED"
}
resp := map[string]any{
"workflow_name": "my-flow", "execution_id": "exec-1",
"root_execution_id": "exec-1", "status": status,
"start_time": "2026-01-01T00:00:00Z",
}
if status == "COMPLETED" {
resp["result"] = map[string]any{"answer": 42}
}
json.NewEncoder(w).Encode(resp)
default:
t.Errorf("unexpected %s %s", r.Method, r.URL.Path)
}
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.ExecuteWorkflowAndWait(context.Background(), "wf-1", &workflow.ExecutionRequest{
Input: map[string]any{"prompt": "hello"},
})
if err != nil {
t.Fatal(err)
}
if resp.Status != workflow.ExecutionCompleted {
t.Errorf("got status %q", resp.Status)
}
}

31
workflows_metrics.go Normal file
View File

@@ -0,0 +1,31 @@
package mistral
import (
"context"
"fmt"
"net/url"
"github.com/VikingOwl91/mistral-go-sdk/workflow"
)
// GetWorkflowMetrics retrieves performance metrics for a workflow.
func (c *Client) GetWorkflowMetrics(ctx context.Context, workflowName string, params *workflow.MetricsParams) (*workflow.Metrics, error) {
path := fmt.Sprintf("/v1/workflows/%s/metrics", workflowName)
if params != nil {
q := url.Values{}
if params.StartTime != nil {
q.Set("start_time", *params.StartTime)
}
if params.EndTime != nil {
q.Set("end_time", *params.EndTime)
}
if encoded := q.Encode(); encoded != "" {
path += "?" + encoded
}
}
var resp workflow.Metrics
if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil {
return nil, err
}
return &resp, nil
}

44
workflows_metrics_test.go Normal file
View File

@@ -0,0 +1,44 @@
package mistral
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/VikingOwl91/mistral-go-sdk/workflow"
)
func TestGetWorkflowMetrics_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/workflows/my-flow/metrics" {
t.Errorf("got path %s", r.URL.Path)
}
if r.URL.Query().Get("start_time") != "2026-01-01T00:00:00Z" {
t.Errorf("got start_time %q", r.URL.Query().Get("start_time"))
}
json.NewEncoder(w).Encode(map[string]any{
"execution_count": map[string]any{"value": 100},
"success_count": map[string]any{"value": 95},
"error_count": map[string]any{"value": 5},
"average_latency_ms": map[string]any{"value": 1234.5},
"latency_over_time": map[string]any{"value": [][]float64{{1711929600, 1200}, {1711929660, 1300}}},
"retry_rate": map[string]any{"value": 0.02},
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
start := "2026-01-01T00:00:00Z"
resp, err := client.GetWorkflowMetrics(context.Background(), "my-flow", &workflow.MetricsParams{StartTime: &start})
if err != nil {
t.Fatal(err)
}
if resp.ExecutionCount.Value != 100 {
t.Errorf("got execution_count %v", resp.ExecutionCount.Value)
}
if resp.AverageLatencyMs.Value != 1234.5 {
t.Errorf("got average_latency_ms %v", resp.AverageLatencyMs.Value)
}
}

60
workflows_runs.go Normal file
View File

@@ -0,0 +1,60 @@
package mistral
import (
"context"
"encoding/json"
"fmt"
"net/url"
"strconv"
"github.com/VikingOwl91/mistral-go-sdk/workflow"
)
// ListWorkflowRuns lists workflow runs.
func (c *Client) ListWorkflowRuns(ctx context.Context, params *workflow.RunListParams) (*workflow.ListRunsResponse, error) {
path := "/v1/workflows/runs"
if params != nil {
q := url.Values{}
if params.WorkflowIdentifier != nil {
q.Set("workflow_identifier", *params.WorkflowIdentifier)
}
if params.Search != nil {
q.Set("search", *params.Search)
}
if params.Status != nil {
q.Set("status", *params.Status)
}
if params.PageSize != nil {
q.Set("page_size", strconv.Itoa(*params.PageSize))
}
if params.NextPageToken != nil {
q.Set("next_page_token", *params.NextPageToken)
}
if encoded := q.Encode(); encoded != "" {
path += "?" + encoded
}
}
var resp workflow.ListRunsResponse
if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// GetWorkflowRun retrieves a workflow run by ID.
func (c *Client) GetWorkflowRun(ctx context.Context, runID string) (*workflow.Run, error) {
var resp workflow.Run
if err := c.doJSON(ctx, "GET", fmt.Sprintf("/v1/workflows/runs/%s", runID), nil, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// GetWorkflowRunHistory retrieves the history of a workflow run.
func (c *Client) GetWorkflowRunHistory(ctx context.Context, runID string) (json.RawMessage, error) {
var resp json.RawMessage
if err := c.doJSON(ctx, "GET", fmt.Sprintf("/v1/workflows/runs/%s/history", runID), nil, &resp); err != nil {
return nil, err
}
return resp, nil
}

58
workflows_runs_test.go Normal file
View File

@@ -0,0 +1,58 @@
package mistral
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
)
func TestListWorkflowRuns_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/workflows/runs" {
t.Errorf("got path %s", r.URL.Path)
}
json.NewEncoder(w).Encode(map[string]any{
"runs": []map[string]any{
{"id": "run-1", "workflow_name": "my-flow", "execution_id": "exec-1", "status": "COMPLETED", "start_time": "2026-01-01"},
},
"next_page_token": "tok-1",
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.ListWorkflowRuns(context.Background(), nil)
if err != nil {
t.Fatal(err)
}
if len(resp.Runs) != 1 {
t.Fatalf("got %d runs", len(resp.Runs))
}
if resp.NextPageToken == nil || *resp.NextPageToken != "tok-1" {
t.Errorf("got token %v", resp.NextPageToken)
}
}
func TestGetWorkflowRun_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/workflows/runs/run-1" {
t.Errorf("got path %s", r.URL.Path)
}
json.NewEncoder(w).Encode(map[string]any{
"id": "run-1", "workflow_name": "my-flow", "execution_id": "exec-1",
"status": "COMPLETED", "start_time": "2026-01-01",
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
run, err := client.GetWorkflowRun(context.Background(), "run-1")
if err != nil {
t.Fatal(err)
}
if run.ID != "run-1" {
t.Errorf("got id %q", run.ID)
}
}

39
workflows_schedules.go Normal file
View File

@@ -0,0 +1,39 @@
package mistral
import (
"context"
"fmt"
"github.com/VikingOwl91/mistral-go-sdk/workflow"
)
// ListWorkflowSchedules lists workflow schedules.
func (c *Client) ListWorkflowSchedules(ctx context.Context) (*workflow.ScheduleListResponse, error) {
var resp workflow.ScheduleListResponse
if err := c.doJSON(ctx, "GET", "/v1/workflows/schedules", nil, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// ScheduleWorkflow creates a workflow schedule.
func (c *Client) ScheduleWorkflow(ctx context.Context, req *workflow.ScheduleRequest) (*workflow.ScheduleResponse, error) {
var resp workflow.ScheduleResponse
if err := c.doJSON(ctx, "POST", "/v1/workflows/schedules", req, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// UnscheduleWorkflow removes a workflow schedule.
func (c *Client) UnscheduleWorkflow(ctx context.Context, scheduleID string) error {
resp, err := c.do(ctx, "DELETE", fmt.Sprintf("/v1/workflows/schedules/%s", scheduleID), nil)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode >= 400 {
return parseAPIError(resp)
}
return nil
}

View File

@@ -0,0 +1,89 @@
package mistral
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/VikingOwl91/mistral-go-sdk/workflow"
)
func TestScheduleWorkflow_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" {
t.Errorf("got method %s", r.Method)
}
if r.URL.Path != "/v1/workflows/schedules" {
t.Errorf("got path %s", r.URL.Path)
}
var body map[string]any
json.NewDecoder(r.Body).Decode(&body)
schedule, _ := body["schedule"].(map[string]any)
cronExprs, _ := schedule["cron_expressions"].([]any)
if len(cronExprs) != 1 || cronExprs[0] != "0 9 * * MON-FRI" {
t.Errorf("got cron_expressions %v", cronExprs)
}
json.NewEncoder(w).Encode(map[string]any{"schedule_id": "sched-1"})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
wfID := "wf-1"
resp, err := client.ScheduleWorkflow(context.Background(), &workflow.ScheduleRequest{
WorkflowIdentifier: &wfID,
Schedule: workflow.ScheduleDefinition{
Input: map[string]any{"prompt": "daily report"},
CronExpressions: []string{"0 9 * * MON-FRI"},
},
})
if err != nil {
t.Fatal(err)
}
if resp.ScheduleID != "sched-1" {
t.Errorf("got schedule_id %q", resp.ScheduleID)
}
}
func TestUnscheduleWorkflow_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "DELETE" {
t.Errorf("got method %s", r.Method)
}
if r.URL.Path != "/v1/workflows/schedules/sched-1" {
t.Errorf("got path %s", r.URL.Path)
}
w.WriteHeader(http.StatusNoContent)
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
err := client.UnscheduleWorkflow(context.Background(), "sched-1")
if err != nil {
t.Fatal(err)
}
}
func TestListWorkflowSchedules_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/workflows/schedules" {
t.Errorf("got path %s", r.URL.Path)
}
json.NewEncoder(w).Encode(map[string]any{
"schedules": []map[string]any{
{"schedule_id": "sched-1", "workflow_name": "my-flow", "created_at": "2026-01-01", "updated_at": "2026-01-01"},
},
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.ListWorkflowSchedules(context.Background())
if err != nil {
t.Fatal(err)
}
if len(resp.Schedules) != 1 {
t.Fatalf("got %d schedules", len(resp.Schedules))
}
}

190
workflows_test.go Normal file
View File

@@ -0,0 +1,190 @@
package mistral
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/VikingOwl91/mistral-go-sdk/workflow"
)
func TestListWorkflows_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "GET" {
t.Errorf("got method %s", r.Method)
}
if r.URL.Path != "/v1/workflows" {
t.Errorf("got path %s", r.URL.Path)
}
if r.URL.Query().Get("limit") != "10" {
t.Errorf("got limit %q", r.URL.Query().Get("limit"))
}
if r.URL.Query().Get("active_only") != "true" {
t.Errorf("got active_only %q", r.URL.Query().Get("active_only"))
}
json.NewEncoder(w).Encode(map[string]any{
"workflows": []map[string]any{
{"id": "wf-1", "name": "my-flow", "owner_id": "u1", "workspace_id": "ws1", "created_at": "2026-01-01", "updated_at": "2026-01-01"},
},
"next_cursor": "cur-abc",
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
active := true
limit := 10
resp, err := client.ListWorkflows(context.Background(), &workflow.WorkflowListParams{
ActiveOnly: &active,
Limit: &limit,
})
if err != nil {
t.Fatal(err)
}
if len(resp.Workflows) != 1 {
t.Fatalf("got %d workflows", len(resp.Workflows))
}
if resp.Workflows[0].ID != "wf-1" {
t.Errorf("got id %q", resp.Workflows[0].ID)
}
if resp.NextCursor == nil || *resp.NextCursor != "cur-abc" {
t.Errorf("got cursor %v", resp.NextCursor)
}
}
func TestGetWorkflow_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/workflows/wf-1" {
t.Errorf("got path %s", r.URL.Path)
}
json.NewEncoder(w).Encode(map[string]any{
"id": "wf-1", "name": "my-flow", "owner_id": "u1", "workspace_id": "ws1",
"created_at": "2026-01-01", "updated_at": "2026-01-01",
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
wf, err := client.GetWorkflow(context.Background(), "wf-1")
if err != nil {
t.Fatal(err)
}
if wf.Name != "my-flow" {
t.Errorf("got name %q", wf.Name)
}
}
func TestUpdateWorkflow_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "PUT" {
t.Errorf("got method %s", r.Method)
}
var body map[string]any
json.NewDecoder(r.Body).Decode(&body)
if body["display_name"] != "New Name" {
t.Errorf("got display_name %v", body["display_name"])
}
json.NewEncoder(w).Encode(map[string]any{
"id": "wf-1", "name": "my-flow", "display_name": "New Name",
"owner_id": "u1", "workspace_id": "ws1",
"created_at": "2026-01-01", "updated_at": "2026-01-02",
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
name := "New Name"
wf, err := client.UpdateWorkflow(context.Background(), "wf-1", &workflow.WorkflowUpdateRequest{
DisplayName: &name,
})
if err != nil {
t.Fatal(err)
}
if wf.DisplayName == nil || *wf.DisplayName != "New Name" {
t.Errorf("got display_name %v", wf.DisplayName)
}
}
func TestArchiveWorkflow_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "PUT" {
t.Errorf("got method %s", r.Method)
}
if r.URL.Path != "/v1/workflows/wf-1/archive" {
t.Errorf("got path %s", r.URL.Path)
}
json.NewEncoder(w).Encode(map[string]any{"id": "wf-1", "archived": true})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.ArchiveWorkflow(context.Background(), "wf-1")
if err != nil {
t.Fatal(err)
}
if !resp.Archived {
t.Error("expected archived=true")
}
}
func TestExecuteWorkflow_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" {
t.Errorf("got method %s", r.Method)
}
if r.URL.Path != "/v1/workflows/wf-1/execute" {
t.Errorf("got path %s", r.URL.Path)
}
var body map[string]any
json.NewDecoder(r.Body).Decode(&body)
input, _ := body["input"].(map[string]any)
if input["prompt"] != "hello" {
t.Errorf("got input %v", body["input"])
}
json.NewEncoder(w).Encode(map[string]any{
"workflow_name": "my-flow", "execution_id": "exec-1",
"root_execution_id": "exec-1", "status": "RUNNING",
"start_time": "2026-01-01T00:00:00Z",
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.ExecuteWorkflow(context.Background(), "wf-1", &workflow.ExecutionRequest{
Input: map[string]any{"prompt": "hello"},
})
if err != nil {
t.Fatal(err)
}
if resp.ExecutionID != "exec-1" {
t.Errorf("got execution_id %q", resp.ExecutionID)
}
if resp.Status != workflow.ExecutionRunning {
t.Errorf("got status %q", resp.Status)
}
}
func TestListWorkflowRegistrations_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/workflows/registrations" {
t.Errorf("got path %s", r.URL.Path)
}
json.NewEncoder(w).Encode(map[string]any{
"registrations": []map[string]any{
{"id": "reg-1", "workflow_id": "wf-1", "task_queue": "default", "created_at": "2026-01-01", "updated_at": "2026-01-01"},
},
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.ListWorkflowRegistrations(context.Background(), nil)
if err != nil {
t.Fatal(err)
}
if len(resp.Registrations) != 1 {
t.Fatalf("got %d registrations", len(resp.Registrations))
}
}

16
workflows_workers.go Normal file
View File

@@ -0,0 +1,16 @@
package mistral
import (
"context"
"github.com/VikingOwl91/mistral-go-sdk/workflow"
)
// GetWorkflowWorkerInfo retrieves information about the current worker.
func (c *Client) GetWorkflowWorkerInfo(ctx context.Context) (*workflow.WorkerInfo, error) {
var resp workflow.WorkerInfo
if err := c.doJSON(ctx, "GET", "/v1/workflows/workers/whoami", nil, &resp); err != nil {
return nil, err
}
return &resp, nil
}

Some files were not shown because too many files have changed in this diff Show More