Compare commits
17 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 3167966b98 | |||
| e22732aa7c | |||
| 6928b9f1c9 | |||
| 0ab8064a06 | |||
| c5b0011e30 | |||
| dc30e09c77 | |||
| 3b0530a409 | |||
| 29aa8e0de1 | |||
| 910970f45e | |||
| a699495fc2 | |||
| a41bf39325 | |||
| 58712f8364 | |||
| b2a1f141e0 | |||
| aa5c53c407 | |||
| b1f0fc4907 | |||
| 94a938b733 | |||
| 52231df17b |
1
.gitignore
vendored
1
.gitignore
vendored
@@ -29,3 +29,4 @@ vendor/
|
||||
# OS
|
||||
.DS_Store
|
||||
Thumbs.db
|
||||
client-python/
|
||||
|
||||
126
CHANGELOG.md
126
CHANGELOG.md
@@ -1,3 +1,129 @@
|
||||
## v1.2.1 — 2026-04-03
|
||||
|
||||
Move module path to `github.com/VikingOwl91/mistral-go-sdk` for public
|
||||
discoverability on pkg.go.dev.
|
||||
|
||||
### Changed
|
||||
|
||||
- Module path changed from `somegit.dev/vikingowl/mistral-go-sdk` to
|
||||
`github.com/VikingOwl91/mistral-go-sdk`.
|
||||
|
||||
### Fixed
|
||||
|
||||
- `TestChatCompleteStream_WithToolCalls` fixture now includes `finish_reason`
|
||||
and `usage` to match real Mistral API responses.
|
||||
|
||||
## v1.2.0 — 2026-04-02
|
||||
|
||||
Upstream sync with Python SDK v2.2.0. Adds Workflows API and DeleteBatchJob.
|
||||
|
||||
### Added
|
||||
|
||||
- **Workflows API** (new `workflow/` package) — complete workflow orchestration
|
||||
support with 37 service methods across 8 sub-resources:
|
||||
- **Workflows CRUD** — `ListWorkflows`, `GetWorkflow`, `UpdateWorkflow`,
|
||||
`ArchiveWorkflow`, `UnarchiveWorkflow`, `ExecuteWorkflow`,
|
||||
`ExecuteWorkflowAndWait`.
|
||||
- **Registrations** — `ListWorkflowRegistrations`, `GetWorkflowRegistration`,
|
||||
`ExecuteWorkflowRegistration` (deprecated).
|
||||
- **Executions** — `GetWorkflowExecution`, `GetWorkflowExecutionHistory`,
|
||||
`StreamWorkflowExecution`, `SignalWorkflowExecution`,
|
||||
`QueryWorkflowExecution`, `UpdateWorkflowExecution`,
|
||||
`TerminateWorkflowExecution`, `CancelWorkflowExecution`,
|
||||
`ResetWorkflowExecution`, `BatchCancelWorkflowExecutions`,
|
||||
`BatchTerminateWorkflowExecutions`.
|
||||
- **Trace** — `GetWorkflowExecutionTraceOTel`,
|
||||
`GetWorkflowExecutionTraceSummary`, `GetWorkflowExecutionTraceEvents`.
|
||||
- **Events** — `StreamWorkflowEvents`, `ListWorkflowEvents`.
|
||||
- **Deployments** — `ListWorkflowDeployments`, `GetWorkflowDeployment`.
|
||||
- **Metrics** — `GetWorkflowMetrics`.
|
||||
- **Runs** — `ListWorkflowRuns`, `GetWorkflowRun`, `GetWorkflowRunHistory`.
|
||||
- **Schedules** — `ListWorkflowSchedules`, `ScheduleWorkflow`,
|
||||
`UnscheduleWorkflow`.
|
||||
- **Workers** — `GetWorkflowWorkerInfo`.
|
||||
- **`WorkflowEventStream`** — typed SSE stream wrapper with `StreamPayload`
|
||||
envelope, sealed `Event` interface (17 concrete types + `UnknownEvent`).
|
||||
- **`DeleteBatchJob`** — delete a batch job by ID.
|
||||
|
||||
## v1.1.0 — 2026-03-24
|
||||
|
||||
Upstream sync with Python SDK v2.1.3. Adds Connectors, Audio Speech/Voices, and Observability (beta).
|
||||
|
||||
### Breaking Changes
|
||||
|
||||
- **`ListModels`** signature changed from `(ctx)` to `(ctx, *model.ListParams)`.
|
||||
Pass `nil` for previous behavior. The new `ListParams` supports `Provider` and
|
||||
`Model` query filters.
|
||||
- **`UploadFile`** signature changed from `(ctx, filename, reader, purpose)` to
|
||||
`(ctx, filename, reader, *file.UploadParams)`. The new `UploadParams` struct
|
||||
holds `Purpose`, `Expiry`, and `Visibility` fields.
|
||||
|
||||
### Added
|
||||
|
||||
- **`ReasoningEffort`** field on `chat.CompletionRequest` and
|
||||
`agents.CompletionRequest` — controls reasoning effort (`"none"`, `"high"`).
|
||||
- **Connectors API** (new `connector/` package) — `CreateConnector`,
|
||||
`ListConnectors`, `GetConnector`, `UpdateConnector`, `DeleteConnector`,
|
||||
`GetConnectorAuthURL`, `ListConnectorTools`, `CallConnectorTool`.
|
||||
- **Audio Speech (TTS)** — `Speech`, `SpeechStream` with `SpeechStream` typed
|
||||
wrapper, `SpeechOutputFormat` enum (pcm/wav/mp3/flac/opus).
|
||||
- **Audio Voices** — `ListVoices`, `CreateVoice`, `GetVoice`, `UpdateVoice`,
|
||||
`DeleteVoice`, `GetVoiceSampleAudio`.
|
||||
- **Audio Realtime types** — `AudioEncoding`, `AudioFormat`, `RealtimeSession`,
|
||||
and WebSocket message types in `audio/realtime.go`. No WebSocket client yet
|
||||
(would require adding a dependency).
|
||||
- **Observability API** (new `observability/` package, beta) — campaigns,
|
||||
chat completion events, judges, datasets, records, and import tasks.
|
||||
33 service methods total.
|
||||
- **`file.Visibility`** enum — `shared_global`, `shared_org`,
|
||||
`shared_workspace`, `private`.
|
||||
- **`model.ListParams`** — filter models by `Provider` and `Model`.
|
||||
|
||||
## v1.0.0 — 2026-03-17
|
||||
|
||||
Stable release. Tracks upstream Python SDK v2.0.4.
|
||||
|
||||
No API changes from v0.2.0. This release signals that the SDK surface is
|
||||
stable and follows Go module semver conventions — breaking changes will only
|
||||
occur in future major versions.
|
||||
|
||||
## v0.2.0 — 2026-03-17
|
||||
|
||||
Sync with upstream Python SDK v2.0.4. Upstream reference changed from OpenAPI
|
||||
spec to official Python SDK (https://github.com/mistralai/client-python).
|
||||
|
||||
### Breaking Changes
|
||||
|
||||
- **`ReferenceChunk.ReferenceIDs`** changed from `[]int` to `[]ReferenceID`.
|
||||
The API now returns mixed integer and string identifiers. Use `IntRef(n)` and
|
||||
`StringRef(s)` constructors; read back with `.Int()` and `.IsString()`.
|
||||
- **`agents.GuardrailConfig` and `agents.ModerationLLMV1Config`** moved to
|
||||
`chat.GuardrailConfig` and `chat.ModerationLLMV1Config`. The types are now
|
||||
shared across chat, agents, and conversation packages.
|
||||
|
||||
### Added
|
||||
|
||||
- **`ToolReferenceChunk`** — new content chunk type for tool references
|
||||
returned by built-in connectors (web search, code interpreter, etc.).
|
||||
- **`ToolFileChunk`** — new content chunk type for tool-generated files.
|
||||
- **`BuiltInConnector`** constants — `ConnectorWebSearch`,
|
||||
`ConnectorWebSearchPremium`, `ConnectorCodeInterpreter`,
|
||||
`ConnectorImageGeneration`, `ConnectorDocumentLibrary`.
|
||||
- **`ModerationLLMV2Config`** — v2 moderation guardrail with split
|
||||
`dangerous`/`criminal` categories and new `jailbreaking` category.
|
||||
- **`GuardrailConfig`** on `chat.CompletionRequest`,
|
||||
`agents.CompletionRequest`, `conversation.StartRequest`, and
|
||||
`conversation.RestartRequest`.
|
||||
- **`ConnectorTool`** — new agent tool type for custom connectors with
|
||||
`ConnectorAuth` (api-key / oauth2-token authorization).
|
||||
- **`ModelCapabilities`** — added `AudioTranscriptionRealtime` and
|
||||
`AudioSpeech` fields.
|
||||
|
||||
### Removed
|
||||
|
||||
- Bundled `docs/openapi.yaml`. The SDK now tracks the upstream Python SDK
|
||||
directly as its reference implementation.
|
||||
|
||||
## v0.1.0 — 2026-03-05
|
||||
|
||||
Initial release.
|
||||
|
||||
89
CLAUDE.md
Normal file
89
CLAUDE.md
Normal file
@@ -0,0 +1,89 @@
|
||||
# CLAUDE.md
|
||||
|
||||
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
|
||||
|
||||
## Project
|
||||
|
||||
Idiomatic Go SDK for the Mistral AI API. Module path: `github.com/VikingOwl91/mistral-go-sdk`. Requires Go 1.26+. Zero external dependencies (stdlib only). Tracks the upstream [Mistral Python SDK](https://github.com/mistralai/client-python) as reference for API surface and type definitions.
|
||||
|
||||
## Repository layout
|
||||
|
||||
- **Working directory**: `mistral-go-sdk/` — the Go SDK source. All development happens here.
|
||||
- **`../client-python/`**: Clone of the upstream Mistral Python SDK. Read-only reference — pull/update it when checking for upstream API changes, but never modify it.
|
||||
|
||||
## Commands
|
||||
|
||||
```bash
|
||||
# Run all unit tests
|
||||
go test ./...
|
||||
|
||||
# Run a single test
|
||||
go test -run TestChatComplete_Success
|
||||
|
||||
# Run integration tests (requires MISTRAL_API_KEY env var)
|
||||
go test -tags=integration ./...
|
||||
|
||||
# Vet and build
|
||||
go vet ./...
|
||||
go build ./...
|
||||
```
|
||||
|
||||
No Makefile, linter config, or code generation tooling — standard `go test` / `go vet` / `go build`.
|
||||
|
||||
## Architecture
|
||||
|
||||
### Two-layer design: types in sub-packages, methods on `*Client`
|
||||
|
||||
Sub-packages (`chat/`, `agents/`, `conversation/`, `embedding/`, `model/`, `file/`, `finetune/`, `batch/`, `ocr/`, `audio/`, `library/`, `moderation/`, `classification/`, `fim/`, `connector/`, `observability/`, `workflow/`) are **types-only** — they define request/response structs and enums but contain no HTTP logic. All service methods live on `*Client` in the root package, prefix-namespaced by domain (e.g. `ChatComplete`, `AgentsComplete`, `CreateFineTuningJob`, `UploadFile`).
|
||||
|
||||
### HTTP internals (request.go)
|
||||
|
||||
All HTTP flows route through a small set of unexported helpers on `*Client`:
|
||||
- `do()` — raw HTTP with auth headers + retry
|
||||
- `doJSON()` — JSON marshal request → `do()` → unmarshal response
|
||||
- `doStream()` — JSON request → raw `*http.Response` for SSE
|
||||
- `doMultipart()` / `doMultipartStream()` — multipart file upload variants
|
||||
- `doRetry()` — retry loop with exponential backoff + jitter + `Retry-After` parsing
|
||||
|
||||
### Streaming
|
||||
|
||||
Generic `Stream[T]` type wraps SSE (`sseReader`) with `Next()`/`Current()`/`Err()`/`Close()` iterator pattern. Typed wrappers `EventStream` (conversations) and `AudioStream` (transcription) unmarshal `json.RawMessage` into domain-specific event types.
|
||||
|
||||
### Sealed interfaces for discriminated unions
|
||||
|
||||
Polymorphic API types use **sealed interfaces** with unexported marker methods:
|
||||
- `chat.Message` (marker: `isMessage()`) — `SystemMessage`, `UserMessage`, `AssistantMessage`, `ToolMessage`
|
||||
- `chat.ContentChunk` (marker: `contentChunk()`) — `TextChunk`, `ImageURLChunk`, `DocumentURLChunk`, `FileChunk`, `ReferenceChunk`, `ThinkChunk`, `AudioChunk`, `ToolReferenceChunk`, `ToolFileChunk`
|
||||
- `agents.AgentTool` (marker: `agentToolType()`) — `FunctionTool`, `WebSearchTool`, `CodeInterpreterTool`, `ConnectorTool`, etc.
|
||||
- `conversation.Event` — conversation streaming events
|
||||
|
||||
Each has an `Unknown*` variant so the SDK doesn't break on new API types. Each has a `Unmarshal*` dispatch function that probes a `type`/`role` discriminator field.
|
||||
|
||||
### Custom JSON patterns
|
||||
|
||||
Several types require non-trivial marshal/unmarshal:
|
||||
- **Type alias trick** — `type alias T` inside `MarshalJSON` to avoid infinite recursion when injecting a `type`/`role` discriminator field.
|
||||
- **`json:"-"` + custom MarshalJSON** — `CompletionRequest.Messages` (and `stream`) are excluded from default marshaling and injected via custom `MarshalJSON`.
|
||||
- **Union types** — `Content` handles `string | null | []ContentChunk`; `ToolChoice` handles `string | object`; `ImageURL` handles `string | object`; `FunctionCall.Arguments` handles `string | object`; `ReferenceID` handles `int | string` with type preservation.
|
||||
- **Probe struct pattern** — `Unmarshal*` functions decode only the discriminator field first, then dispatch to the concrete type.
|
||||
|
||||
### Shared types in `chat/`
|
||||
|
||||
`GuardrailConfig`, `ModerationLLMV1Config`, `ModerationLLMV2Config` live in `chat/` because it's the base types package imported by both `agents/` and `conversation/`. This avoids import cycles.
|
||||
|
||||
### Error handling
|
||||
|
||||
`APIError` in `error.go` with sentinel checkers: `IsNotFound()`, `IsRateLimit()`, `IsAuth()`. All use `errors.As` for unwrapping.
|
||||
|
||||
## Testing patterns
|
||||
|
||||
- Unit tests use `httptest.NewServer` with inline handlers to mock the Mistral API. Client is pointed at the test server via `WithBaseURL(server.URL)`.
|
||||
- Integration tests are behind `//go:build integration` build tag and require `MISTRAL_API_KEY`.
|
||||
- Tests use stdlib `testing` only — no third-party test frameworks.
|
||||
|
||||
## Adding a new API endpoint
|
||||
|
||||
1. Define request/response types in the appropriate sub-package (or create a new one with a `doc.go`).
|
||||
2. Add a method on `*Client` in the root package. Use `doJSON` for standard request/response, `doStream` for SSE, `doMultipart` for file uploads.
|
||||
3. Add unit tests with `httptest.NewServer`.
|
||||
4. If the endpoint supports streaming, return `*Stream[T]` and call `EnableStream()` on the request before sending.
|
||||
54
README.md
54
README.md
@@ -3,6 +3,7 @@
|
||||
The most complete Go client for the [Mistral AI API](https://docs.mistral.ai/).
|
||||
|
||||
<!-- Badges -->
|
||||
[](https://pkg.go.dev/github.com/VikingOwl91/mistral-go-sdk)
|
||||

|
||||

|
||||
|
||||
@@ -10,20 +11,20 @@ The most complete Go client for the [Mistral AI API](https://docs.mistral.ai/).
|
||||
|
||||
**Zero dependencies.** The entire SDK — including tests — uses only the Go standard library. No `go.sum`, no transitive dependency tree to audit, no version conflicts, no supply chain risk.
|
||||
|
||||
**Full API coverage.** 75 methods across every Mistral endpoint — including Conversations, Agents CRUD, Libraries, OCR, Audio, Fine-tuning, and Batch Jobs. No other Go SDK covers Conversations or Agents.
|
||||
**Full API coverage.** 166 methods across every Mistral endpoint — including Workflows, Connectors, Audio Speech/Voices, Conversations, Agents CRUD, Libraries, OCR, Observability, Fine-tuning, and Batch Jobs. No other Go SDK covers Workflows, Conversations, Connectors, or Observability.
|
||||
|
||||
**Typed streaming.** A generic pull-based `Stream[T]` iterator — no channels, no goroutines, no leaks. Just `Next()` / `Current()` / `Err()` / `Close()`.
|
||||
|
||||
**Forward-compatible.** Unknown types (`UnknownEntry`, `UnknownEvent`, `UnknownMessage`, `UnknownChunk`, `UnknownAgentTool`) capture raw JSON instead of returning errors. When Mistral ships a new message role or event type, your code keeps running — it doesn't panic.
|
||||
**Forward-compatible.** Unknown types (`UnknownEntry`, `UnknownEvent`, `UnknownMessage`, `UnknownChunk`, `UnknownAgentTool`, workflow `UnknownEvent`) capture raw JSON instead of returning errors. When Mistral ships a new message role or event type, your code keeps running — it doesn't panic.
|
||||
|
||||
**Hand-written, not generated.** Idiomatic Go with sealed interfaces, discriminated unions, and functional options — not a Speakeasy/OpenAPI auto-gen dump with `any` everywhere.
|
||||
|
||||
**Test-driven.** 126 tests with race detection clean. Every endpoint tested against mock servers; integration tests against the real API.
|
||||
**Test-driven.** 284 tests with race detection clean. Every endpoint tested against mock servers; integration tests against the real API.
|
||||
|
||||
## Install
|
||||
|
||||
```sh
|
||||
go get somegit.dev/vikingowl/mistral-go-sdk
|
||||
go get github.com/VikingOwl91/mistral-go-sdk
|
||||
```
|
||||
|
||||
## Quick Start
|
||||
@@ -38,8 +39,8 @@ import (
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
mistral "somegit.dev/vikingowl/mistral-go-sdk"
|
||||
"somegit.dev/vikingowl/mistral-go-sdk/chat"
|
||||
mistral "github.com/VikingOwl91/mistral-go-sdk"
|
||||
"github.com/VikingOwl91/mistral-go-sdk/chat"
|
||||
)
|
||||
|
||||
func main() {
|
||||
@@ -111,7 +112,7 @@ resp, err := client.ChatComplete(ctx, &chat.CompletionRequest{
|
||||
### Conversations
|
||||
|
||||
```go
|
||||
import "somegit.dev/vikingowl/mistral-go-sdk/conversation"
|
||||
import "github.com/VikingOwl91/mistral-go-sdk/conversation"
|
||||
|
||||
resp, err := client.StartConversation(ctx, &conversation.StartRequest{
|
||||
AgentID: "ag-your-agent-id",
|
||||
@@ -131,7 +132,7 @@ for stream.Next() {
|
||||
|
||||
## API Coverage
|
||||
|
||||
75 public methods on `Client`, grouped by domain:
|
||||
166 public methods on `Client`, grouped by domain:
|
||||
|
||||
| Domain | Methods |
|
||||
|--------|---------|
|
||||
@@ -139,17 +140,34 @@ for stream.Next() {
|
||||
| **FIM** | `FIMComplete`, `FIMCompleteStream` |
|
||||
| **Agents (completions)** | `AgentsComplete`, `AgentsCompleteStream` |
|
||||
| **Agents (CRUD)** | `CreateAgent`, `ListAgents`, `GetAgent`, `UpdateAgent`, `DeleteAgent`, `UpdateAgentVersion`, `ListAgentVersions`, `GetAgentVersion`, `SetAgentAlias`, `ListAgentAliases`, `DeleteAgentAlias` |
|
||||
| **Connectors** | `CreateConnector`, `ListConnectors`, `GetConnector`, `UpdateConnector`, `DeleteConnector`, `GetConnectorAuthURL`, `ListConnectorTools`, `CallConnectorTool` |
|
||||
| **Conversations** | `StartConversation`, `StartConversationStream`, `AppendConversation`, `AppendConversationStream`, `RestartConversation`, `RestartConversationStream`, `GetConversation`, `ListConversations`, `DeleteConversation`, `GetConversationHistory`, `GetConversationMessages` |
|
||||
| **Models** | `ListModels`, `GetModel`, `DeleteModel` |
|
||||
| **Files** | `UploadFile`, `ListFiles`, `GetFile`, `DeleteFile`, `GetFileContent`, `GetFileURL` |
|
||||
| **Embeddings** | `CreateEmbeddings` |
|
||||
| **Fine-tuning** | `CreateFineTuningJob`, `ListFineTuningJobs`, `GetFineTuningJob`, `CancelFineTuningJob`, `StartFineTuningJob`, `UpdateFineTunedModel`, `ArchiveFineTunedModel`, `UnarchiveFineTunedModel` |
|
||||
| **Batch** | `CreateBatchJob`, `ListBatchJobs`, `GetBatchJob`, `CancelBatchJob` |
|
||||
| **Batch** | `CreateBatchJob`, `ListBatchJobs`, `GetBatchJob`, `CancelBatchJob`, `DeleteBatchJob` |
|
||||
| **OCR** | `OCR` |
|
||||
| **Audio** | `Transcribe`, `TranscribeStream` |
|
||||
| **Audio (transcription)** | `Transcribe`, `TranscribeStream` |
|
||||
| **Audio (speech)** | `Speech`, `SpeechStream` |
|
||||
| **Audio (voices)** | `ListVoices`, `CreateVoice`, `GetVoice`, `UpdateVoice`, `DeleteVoice`, `GetVoiceSampleAudio` |
|
||||
| **Libraries** | `CreateLibrary`, `ListLibraries`, `GetLibrary`, `UpdateLibrary`, `DeleteLibrary`, `UploadDocument`, `ListDocuments`, `GetDocument`, `UpdateDocument`, `DeleteDocument`, `GetDocumentTextContent`, `GetDocumentStatus`, `GetDocumentSignedURL`, `GetDocumentExtractedTextSignedURL`, `ReprocessDocument`, `ListLibrarySharing`, `ShareLibrary`, `UnshareLibrary` |
|
||||
| **Moderation** | `Moderate`, `ModerateChat` |
|
||||
| **Classification** | `Classify`, `ClassifyChat` |
|
||||
| **Observability (campaigns)** | `CreateCampaign`, `ListCampaigns`, `GetCampaign`, `DeleteCampaign`, `GetCampaignStatus`, `ListCampaignEvents` |
|
||||
| **Observability (events)** | `SearchChatCompletionEvents`, `SearchChatCompletionEventIDs`, `GetChatCompletionEvent`, `GetSimilarChatCompletionEvents`, `JudgeChatCompletionEvent` |
|
||||
| **Observability (judges)** | `CreateJudge`, `ListJudges`, `GetJudge`, `UpdateJudge`, `DeleteJudge`, `JudgeConversation` |
|
||||
| **Observability (datasets)** | `CreateDataset`, `ListDatasets`, `GetDataset`, `UpdateDataset`, `DeleteDataset`, `ExportDatasetToJSONL`, `ListDatasetRecords`, `CreateDatasetRecord`, `GetDatasetRecord`, `UpdateDatasetRecordPayload`, `UpdateDatasetRecordProperties`, `DeleteDatasetRecord`, `BulkDeleteDatasetRecords`, `JudgeDatasetRecord`, `ImportDatasetFromCampaign`, `ImportDatasetFromExplorer`, `ImportDatasetFromFile`, `ImportDatasetFromPlayground`, `ImportDatasetFromDataset`, `ListDatasetTasks`, `GetDatasetTask` |
|
||||
| **Workflows (CRUD)** | `ListWorkflows`, `GetWorkflow`, `UpdateWorkflow`, `ArchiveWorkflow`, `UnarchiveWorkflow`, `ExecuteWorkflow`, `ExecuteWorkflowAndWait` |
|
||||
| **Workflows (registrations)** | `ListWorkflowRegistrations`, `GetWorkflowRegistration`, `ExecuteWorkflowRegistration` |
|
||||
| **Workflows (executions)** | `GetWorkflowExecution`, `GetWorkflowExecutionHistory`, `StreamWorkflowExecution`, `SignalWorkflowExecution`, `QueryWorkflowExecution`, `UpdateWorkflowExecution`, `TerminateWorkflowExecution`, `CancelWorkflowExecution`, `ResetWorkflowExecution`, `BatchCancelWorkflowExecutions`, `BatchTerminateWorkflowExecutions` |
|
||||
| **Workflows (trace)** | `GetWorkflowExecutionTraceOTel`, `GetWorkflowExecutionTraceSummary`, `GetWorkflowExecutionTraceEvents` |
|
||||
| **Workflows (events)** | `StreamWorkflowEvents`, `ListWorkflowEvents` |
|
||||
| **Workflows (deployments)** | `ListWorkflowDeployments`, `GetWorkflowDeployment` |
|
||||
| **Workflows (metrics)** | `GetWorkflowMetrics` |
|
||||
| **Workflows (runs)** | `ListWorkflowRuns`, `GetWorkflowRun`, `GetWorkflowRunHistory` |
|
||||
| **Workflows (schedules)** | `ListWorkflowSchedules`, `ScheduleWorkflow`, `UnscheduleWorkflow` |
|
||||
| **Workflows (workers)** | `GetWorkflowWorkerInfo` |
|
||||
|
||||
## Comparison
|
||||
|
||||
@@ -162,11 +180,14 @@ There is no official Go SDK from Mistral AI (only Python and TypeScript). The ma
|
||||
| Embeddings | Yes | Yes | Yes | Yes |
|
||||
| Tool calling | Yes | No | No | No |
|
||||
| Agents (completions + CRUD) | Yes | No | No | No |
|
||||
| Connectors (MCP) | Yes | No | No | No |
|
||||
| Conversations API | Yes | No | No | No |
|
||||
| Libraries / Documents | Yes | No | No | No |
|
||||
| Fine-tuning / Batch | Yes | No | No | No |
|
||||
| OCR | Yes | No | No | Yes |
|
||||
| Audio transcription | Yes | No | No | No |
|
||||
| Audio (transcription + TTS + voices) | Yes | No | No | No |
|
||||
| Workflows API | Yes | No | No | No |
|
||||
| Observability (beta) | Yes | No | No | No |
|
||||
| Moderation / Classification | Yes | No | No | No |
|
||||
| Vision (multimodal) | Yes | No | No | Yes |
|
||||
| Zero dependencies | Yes | test-only (testify) | test-only (testify) | test-only (testify) |
|
||||
@@ -213,6 +234,17 @@ if err != nil {
|
||||
}
|
||||
```
|
||||
|
||||
## Upstream Reference
|
||||
|
||||
This SDK tracks the [official Mistral Python SDK](https://github.com/mistralai/client-python)
|
||||
as its upstream reference for API surface and type definitions.
|
||||
|
||||
| SDK Version | Upstream Python SDK |
|
||||
|-------------|---------------------|
|
||||
| v1.2.0 | v2.2.0 |
|
||||
| v1.1.0 | v2.1.3 |
|
||||
| v1.0.0 | v2.0.4 |
|
||||
|
||||
## License
|
||||
|
||||
[MIT](LICENSE)
|
||||
|
||||
@@ -3,6 +3,8 @@ package agents
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/VikingOwl91/mistral-go-sdk/chat"
|
||||
)
|
||||
|
||||
// AgentTool is a sealed interface for agent tool types.
|
||||
@@ -59,6 +61,22 @@ type DocumentLibraryTool struct {
|
||||
|
||||
func (*DocumentLibraryTool) agentToolType() string { return "document_library" }
|
||||
|
||||
// ConnectorAuth holds authorization for a custom connector.
|
||||
type ConnectorAuth struct {
|
||||
Type string `json:"type"`
|
||||
Value string `json:"value"`
|
||||
}
|
||||
|
||||
// ConnectorTool represents a custom connector tool.
|
||||
type ConnectorTool struct {
|
||||
Type string `json:"type"`
|
||||
ConnectorID string `json:"connector_id"`
|
||||
Authorization *ConnectorAuth `json:"authorization,omitempty"`
|
||||
ToolConfiguration *ToolConfiguration `json:"tool_configuration,omitempty"`
|
||||
}
|
||||
|
||||
func (*ConnectorTool) agentToolType() string { return "connector" }
|
||||
|
||||
// UnknownAgentTool holds an unrecognized tool type.
|
||||
type UnknownAgentTool struct {
|
||||
Type string
|
||||
@@ -99,6 +117,9 @@ func UnmarshalAgentTool(data []byte) (AgentTool, error) {
|
||||
case "document_library":
|
||||
var t DocumentLibraryTool
|
||||
return &t, json.Unmarshal(data, &t)
|
||||
case "connector":
|
||||
var t ConnectorTool
|
||||
return &t, json.Unmarshal(data, &t)
|
||||
default:
|
||||
return &UnknownAgentTool{Type: probe.Type, Raw: json.RawMessage(data)}, nil
|
||||
}
|
||||
@@ -151,7 +172,7 @@ type Agent struct {
|
||||
Description *string `json:"description,omitempty"`
|
||||
Tools AgentTools `json:"tools,omitempty"`
|
||||
CompletionArgs *CompletionArgs `json:"completion_args,omitempty"`
|
||||
Guardrails []GuardrailConfig `json:"guardrails,omitempty"`
|
||||
Guardrails []chat.GuardrailConfig `json:"guardrails,omitempty"`
|
||||
Handoffs []string `json:"handoffs,omitempty"`
|
||||
Metadata map[string]any `json:"metadata,omitempty"`
|
||||
VersionMessage *string `json:"version_message,omitempty"`
|
||||
@@ -165,7 +186,7 @@ type CreateRequest struct {
|
||||
Description *string `json:"description,omitempty"`
|
||||
Tools AgentTools `json:"tools,omitempty"`
|
||||
CompletionArgs *CompletionArgs `json:"completion_args,omitempty"`
|
||||
Guardrails []GuardrailConfig `json:"guardrails,omitempty"`
|
||||
Guardrails []chat.GuardrailConfig `json:"guardrails,omitempty"`
|
||||
Handoffs []string `json:"handoffs,omitempty"`
|
||||
Metadata map[string]any `json:"metadata,omitempty"`
|
||||
VersionMessage *string `json:"version_message,omitempty"`
|
||||
@@ -179,7 +200,7 @@ type UpdateRequest struct {
|
||||
Description *string `json:"description,omitempty"`
|
||||
Tools AgentTools `json:"tools,omitempty"`
|
||||
CompletionArgs *CompletionArgs `json:"completion_args,omitempty"`
|
||||
Guardrails []GuardrailConfig `json:"guardrails,omitempty"`
|
||||
Guardrails []chat.GuardrailConfig `json:"guardrails,omitempty"`
|
||||
Handoffs []string `json:"handoffs,omitempty"`
|
||||
DeploymentChat *bool `json:"deployment_chat,omitempty"`
|
||||
Metadata map[string]any `json:"metadata,omitempty"`
|
||||
@@ -223,20 +244,6 @@ type CompletionArgs struct {
|
||||
ToolChoice *string `json:"tool_choice,omitempty"`
|
||||
}
|
||||
|
||||
// GuardrailConfig configures moderation guardrails for an agent.
|
||||
type GuardrailConfig struct {
|
||||
BlockOnError bool `json:"block_on_error"`
|
||||
ModerationLLMV1 *ModerationLLMV1Config `json:"moderation_llm_v1"`
|
||||
}
|
||||
|
||||
// ModerationLLMV1Config configures the moderation LLM guardrail.
|
||||
type ModerationLLMV1Config struct {
|
||||
ModelName string `json:"model_name,omitempty"`
|
||||
CustomCategoryThresholds json.RawMessage `json:"custom_category_thresholds,omitempty"`
|
||||
IgnoreOtherCategories bool `json:"ignore_other_categories,omitempty"`
|
||||
Action string `json:"action,omitempty"`
|
||||
}
|
||||
|
||||
// ToolConfiguration holds include/exclude/confirmation lists for tools.
|
||||
type ToolConfiguration struct {
|
||||
Exclude []string `json:"exclude,omitempty"`
|
||||
|
||||
@@ -66,6 +66,30 @@ func TestAgentTools_RoundTrip(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnmarshalAgentTool_Connector(t *testing.T) {
|
||||
data := []byte(`{"type":"connector","connector_id":"my-connector","authorization":{"type":"api-key","value":"sk-test"}}`)
|
||||
tool, err := UnmarshalAgentTool(data)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ct, ok := tool.(*ConnectorTool)
|
||||
if !ok {
|
||||
t.Fatalf("expected *ConnectorTool, got %T", tool)
|
||||
}
|
||||
if ct.ConnectorID != "my-connector" {
|
||||
t.Errorf("got connector_id %q", ct.ConnectorID)
|
||||
}
|
||||
if ct.Authorization == nil {
|
||||
t.Fatal("expected authorization")
|
||||
}
|
||||
if ct.Authorization.Type != "api-key" {
|
||||
t.Errorf("got auth type %q", ct.Authorization.Type)
|
||||
}
|
||||
if ct.Authorization.Value != "sk-test" {
|
||||
t.Errorf("got auth value %q", ct.Authorization.Value)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgent_UnmarshalWithTools(t *testing.T) {
|
||||
data := []byte(`{
|
||||
"id":"ag-1","object":"agent","name":"A","model":"m",
|
||||
|
||||
@@ -1,3 +1,10 @@
|
||||
// Package agents provides types for the Mistral agents API,
|
||||
// including agent CRUD operations and agent chat completions.
|
||||
//
|
||||
// # Tool Types
|
||||
//
|
||||
// Agents support multiple tool types via the [AgentTool] sealed interface:
|
||||
// [FunctionTool], [WebSearchTool], [WebSearchPremiumTool],
|
||||
// [CodeInterpreterTool], [ImageGenerationTool], [DocumentLibraryTool],
|
||||
// and [ConnectorTool] for custom connectors.
|
||||
package agents
|
||||
|
||||
@@ -3,7 +3,7 @@ package agents
|
||||
import (
|
||||
"encoding/json"
|
||||
|
||||
"somegit.dev/vikingowl/mistral-go-sdk/chat"
|
||||
"github.com/VikingOwl91/mistral-go-sdk/chat"
|
||||
)
|
||||
|
||||
// CompletionRequest represents an agents completion request.
|
||||
@@ -23,8 +23,10 @@ type CompletionRequest struct {
|
||||
N *int `json:"n,omitempty"`
|
||||
ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"`
|
||||
Metadata map[string]any `json:"metadata,omitempty"`
|
||||
Prediction *chat.Prediction `json:"prediction,omitempty"`
|
||||
PromptMode *chat.PromptMode `json:"prompt_mode,omitempty"`
|
||||
Prediction *chat.Prediction `json:"prediction,omitempty"`
|
||||
PromptMode *chat.PromptMode `json:"prompt_mode,omitempty"`
|
||||
Guardrails []chat.GuardrailConfig `json:"guardrails,omitempty"`
|
||||
ReasoningEffort *chat.ReasoningEffort `json:"reasoning_effort,omitempty"`
|
||||
stream bool
|
||||
}
|
||||
|
||||
|
||||
@@ -3,8 +3,8 @@ package mistral
|
||||
import (
|
||||
"context"
|
||||
|
||||
"somegit.dev/vikingowl/mistral-go-sdk/agents"
|
||||
"somegit.dev/vikingowl/mistral-go-sdk/chat"
|
||||
"github.com/VikingOwl91/mistral-go-sdk/agents"
|
||||
"github.com/VikingOwl91/mistral-go-sdk/chat"
|
||||
)
|
||||
|
||||
// AgentsComplete sends an agents completion request.
|
||||
|
||||
@@ -8,8 +8,8 @@ import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"somegit.dev/vikingowl/mistral-go-sdk/agents"
|
||||
"somegit.dev/vikingowl/mistral-go-sdk/chat"
|
||||
"github.com/VikingOwl91/mistral-go-sdk/agents"
|
||||
"github.com/VikingOwl91/mistral-go-sdk/chat"
|
||||
)
|
||||
|
||||
func TestAgentsComplete_Success(t *testing.T) {
|
||||
@@ -114,6 +114,39 @@ func TestAgentsComplete_WithTools(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentsComplete_ReasoningEffort(t *testing.T) {
|
||||
effort := chat.ReasoningEffortHigh
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var body map[string]any
|
||||
json.NewDecoder(r.Body).Decode(&body)
|
||||
if body["reasoning_effort"] != "high" {
|
||||
t.Errorf("expected reasoning_effort=high, got %v", body["reasoning_effort"])
|
||||
}
|
||||
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"id": "a-re", "object": "chat.completion",
|
||||
"model": "m", "created": 0,
|
||||
"choices": []map[string]any{{
|
||||
"index": 0, "message": map[string]any{"role": "assistant", "content": "ok"},
|
||||
"finish_reason": "stop",
|
||||
}},
|
||||
"usage": map[string]any{"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
_, err := client.AgentsComplete(context.Background(), &agents.CompletionRequest{
|
||||
AgentID: "agent-1",
|
||||
Messages: []chat.Message{&chat.UserMessage{Content: chat.TextContent("Hi")}},
|
||||
ReasoningEffort: &effort,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentsCompleteStream_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var body map[string]any
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"net/url"
|
||||
"strconv"
|
||||
|
||||
"somegit.dev/vikingowl/mistral-go-sdk/agents"
|
||||
"github.com/VikingOwl91/mistral-go-sdk/agents"
|
||||
)
|
||||
|
||||
// CreateAgent creates a new agent.
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"somegit.dev/vikingowl/mistral-go-sdk/agents"
|
||||
"github.com/VikingOwl91/mistral-go-sdk/agents"
|
||||
)
|
||||
|
||||
func TestCreateAgent_Success(t *testing.T) {
|
||||
|
||||
22
audio/doc.go
22
audio/doc.go
@@ -1,5 +1,25 @@
|
||||
// Package audio provides types for the Mistral audio transcription API.
|
||||
// Package audio provides types for the Mistral audio APIs.
|
||||
//
|
||||
// # Transcription
|
||||
//
|
||||
// [TranscriptionRequest] and [TranscriptionResponse] handle speech-to-text.
|
||||
// Streaming transcription returns typed [StreamEvent] values via a sealed
|
||||
// interface dispatched by the "type" field.
|
||||
//
|
||||
// # Speech (TTS)
|
||||
//
|
||||
// [SpeechRequest] and [SpeechResponse] handle text-to-speech.
|
||||
// Streaming speech returns typed [SpeechStreamEvent] values
|
||||
// ([SpeechAudioDelta] and [SpeechDone]).
|
||||
//
|
||||
// # Voices
|
||||
//
|
||||
// [VoiceResponse], [VoiceCreateRequest], and [VoiceUpdateRequest] manage
|
||||
// custom voices for speech synthesis.
|
||||
//
|
||||
// # Realtime
|
||||
//
|
||||
// Realtime transcription types ([AudioEncoding], [AudioFormat],
|
||||
// [RealtimeSession], and WebSocket message types) are defined here.
|
||||
// The WebSocket client is not yet implemented.
|
||||
package audio
|
||||
|
||||
75
audio/realtime.go
Normal file
75
audio/realtime.go
Normal file
@@ -0,0 +1,75 @@
|
||||
package audio
|
||||
|
||||
// AudioEncoding is the encoding format for realtime audio streams.
|
||||
type AudioEncoding string
|
||||
|
||||
const (
|
||||
EncodingPCMS16LE AudioEncoding = "pcm_s16le"
|
||||
EncodingPCMS32LE AudioEncoding = "pcm_s32le"
|
||||
EncodingPCMF16LE AudioEncoding = "pcm_f16le"
|
||||
EncodingPCMF32LE AudioEncoding = "pcm_f32le"
|
||||
EncodingPCMMulaw AudioEncoding = "pcm_mulaw"
|
||||
EncodingPCMAlaw AudioEncoding = "pcm_alaw"
|
||||
)
|
||||
|
||||
// AudioFormat describes the encoding and sample rate for realtime audio.
|
||||
type AudioFormat struct {
|
||||
Encoding AudioEncoding `json:"encoding"`
|
||||
SampleRate int `json:"sample_rate"`
|
||||
}
|
||||
|
||||
// RealtimeSession describes a realtime transcription session.
|
||||
type RealtimeSession struct {
|
||||
RequestID string `json:"request_id"`
|
||||
Model string `json:"model"`
|
||||
AudioFormat AudioFormat `json:"audio_format"`
|
||||
TargetStreamingDelayMs *int `json:"target_streaming_delay_ms,omitempty"`
|
||||
}
|
||||
|
||||
// RealtimeSessionUpdate is sent to update session parameters.
|
||||
// Parameters can only be changed before audio transmission starts.
|
||||
type RealtimeSessionUpdate struct {
|
||||
AudioFormat *AudioFormat `json:"audio_format,omitempty"`
|
||||
TargetStreamingDelayMs *int `json:"target_streaming_delay_ms,omitempty"`
|
||||
}
|
||||
|
||||
// InputAudioAppend sends a chunk of audio data.
|
||||
// Audio is base64-encoded (max 262144 bytes decoded).
|
||||
type InputAudioAppend struct {
|
||||
Type string `json:"type"` // "input_audio.append"
|
||||
Audio string `json:"audio"`
|
||||
}
|
||||
|
||||
// InputAudioFlush flushes the audio buffer.
|
||||
type InputAudioFlush struct {
|
||||
Type string `json:"type"` // "input_audio.flush"
|
||||
}
|
||||
|
||||
// InputAudioEnd signals the end of audio input.
|
||||
type InputAudioEnd struct {
|
||||
Type string `json:"type"` // "input_audio.end"
|
||||
}
|
||||
|
||||
// RealtimeSessionCreated is received when a session is created.
|
||||
type RealtimeSessionCreated struct {
|
||||
Type string `json:"type"` // "session.created"
|
||||
Session RealtimeSession `json:"session"`
|
||||
}
|
||||
|
||||
// RealtimeSessionUpdated is received when a session is updated.
|
||||
type RealtimeSessionUpdated struct {
|
||||
Type string `json:"type"` // "session.updated"
|
||||
Session RealtimeSession `json:"session"`
|
||||
}
|
||||
|
||||
// RealtimeErrorDetail describes a realtime error.
|
||||
type RealtimeErrorDetail struct {
|
||||
Message string `json:"message"`
|
||||
Code int `json:"code"`
|
||||
}
|
||||
|
||||
// RealtimeError is received on error.
|
||||
type RealtimeError struct {
|
||||
Type string `json:"type"` // "error"
|
||||
Error RealtimeErrorDetail `json:"error"`
|
||||
}
|
||||
88
audio/speech.go
Normal file
88
audio/speech.go
Normal file
@@ -0,0 +1,88 @@
|
||||
package audio
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// SpeechOutputFormat is the output audio format for speech synthesis.
|
||||
type SpeechOutputFormat string
|
||||
|
||||
const (
|
||||
SpeechFormatPCM SpeechOutputFormat = "pcm"
|
||||
SpeechFormatWAV SpeechOutputFormat = "wav"
|
||||
SpeechFormatMP3 SpeechOutputFormat = "mp3"
|
||||
SpeechFormatFLAC SpeechOutputFormat = "flac"
|
||||
SpeechFormatOpus SpeechOutputFormat = "opus"
|
||||
)
|
||||
|
||||
// SpeechRequest represents a text-to-speech request.
|
||||
type SpeechRequest struct {
|
||||
Input string `json:"input"`
|
||||
Model string `json:"model"`
|
||||
Metadata map[string]any `json:"metadata,omitempty"`
|
||||
VoiceID *string `json:"voice_id,omitempty"`
|
||||
RefAudio *string `json:"ref_audio,omitempty"`
|
||||
ResponseFormat *SpeechOutputFormat `json:"response_format,omitempty"`
|
||||
stream bool
|
||||
}
|
||||
|
||||
// EnableStream is used internally to enable streaming.
|
||||
func (r *SpeechRequest) EnableStream() { r.stream = true }
|
||||
|
||||
func (r *SpeechRequest) MarshalJSON() ([]byte, error) {
|
||||
type Alias SpeechRequest
|
||||
return json.Marshal(&struct {
|
||||
Stream bool `json:"stream"`
|
||||
*Alias
|
||||
}{
|
||||
Stream: r.stream,
|
||||
Alias: (*Alias)(r),
|
||||
})
|
||||
}
|
||||
|
||||
// SpeechResponse is the response from a non-streaming speech request.
|
||||
type SpeechResponse struct {
|
||||
AudioData string `json:"audio_data"`
|
||||
}
|
||||
|
||||
// SpeechStreamEvent is a sealed interface for speech streaming events.
|
||||
type SpeechStreamEvent interface {
|
||||
speechStreamEvent()
|
||||
}
|
||||
|
||||
// SpeechAudioDelta contains a chunk of audio data during streaming.
|
||||
type SpeechAudioDelta struct {
|
||||
Type string `json:"type"`
|
||||
AudioData string `json:"audio_data"`
|
||||
}
|
||||
|
||||
func (*SpeechAudioDelta) speechStreamEvent() {}
|
||||
|
||||
// SpeechDone is emitted when speech synthesis is complete.
|
||||
type SpeechDone struct {
|
||||
Type string `json:"type"`
|
||||
Usage UsageInfo `json:"usage"`
|
||||
}
|
||||
|
||||
func (*SpeechDone) speechStreamEvent() {}
|
||||
|
||||
// UnmarshalSpeechStreamEvent dispatches a raw JSON event to the correct type.
|
||||
func UnmarshalSpeechStreamEvent(data []byte) (SpeechStreamEvent, error) {
|
||||
var probe struct {
|
||||
Type string `json:"type"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &probe); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
switch probe.Type {
|
||||
case "speech.audio.delta":
|
||||
var e SpeechAudioDelta
|
||||
return &e, json.Unmarshal(data, &e)
|
||||
case "speech.audio.done":
|
||||
var e SpeechDone
|
||||
return &e, json.Unmarshal(data, &e)
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown speech stream event type: %q", probe.Type)
|
||||
}
|
||||
}
|
||||
48
audio/voice.go
Normal file
48
audio/voice.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package audio
|
||||
|
||||
// VoiceResponse represents a voice entity.
|
||||
type VoiceResponse struct {
|
||||
Name string `json:"name"`
|
||||
ID string `json:"id"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
UserID *string `json:"user_id,omitempty"`
|
||||
Slug *string `json:"slug,omitempty"`
|
||||
Languages []string `json:"languages,omitempty"`
|
||||
Gender *string `json:"gender,omitempty"`
|
||||
Age *int `json:"age,omitempty"`
|
||||
Tags []string `json:"tags,omitempty"`
|
||||
Color *string `json:"color,omitempty"`
|
||||
RetentionNotice *int `json:"retention_notice,omitempty"`
|
||||
}
|
||||
|
||||
// VoiceCreateRequest creates a custom voice.
|
||||
type VoiceCreateRequest struct {
|
||||
Name string `json:"name"`
|
||||
SampleAudio string `json:"sample_audio"`
|
||||
Slug *string `json:"slug,omitempty"`
|
||||
Languages []string `json:"languages,omitempty"`
|
||||
Gender *string `json:"gender,omitempty"`
|
||||
Age *int `json:"age,omitempty"`
|
||||
Tags []string `json:"tags,omitempty"`
|
||||
Color *string `json:"color,omitempty"`
|
||||
RetentionNotice *int `json:"retention_notice,omitempty"`
|
||||
SampleFilename *string `json:"sample_filename,omitempty"`
|
||||
}
|
||||
|
||||
// VoiceUpdateRequest updates a voice.
|
||||
type VoiceUpdateRequest struct {
|
||||
Name *string `json:"name,omitempty"`
|
||||
Languages []string `json:"languages,omitempty"`
|
||||
Gender *string `json:"gender,omitempty"`
|
||||
Age *int `json:"age,omitempty"`
|
||||
Tags []string `json:"tags,omitempty"`
|
||||
}
|
||||
|
||||
// VoiceListResponse is the response from listing voices.
|
||||
type VoiceListResponse struct {
|
||||
Items []VoiceResponse `json:"items"`
|
||||
Total int `json:"total"`
|
||||
Page int `json:"page"`
|
||||
PageSize int `json:"page_size"`
|
||||
TotalPages int `json:"total_pages"`
|
||||
}
|
||||
126
audio_api.go
126
audio_api.go
@@ -3,9 +3,11 @@ package mistral
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"somegit.dev/vikingowl/mistral-go-sdk/audio"
|
||||
"github.com/VikingOwl91/mistral-go-sdk/audio"
|
||||
)
|
||||
|
||||
// Transcribe sends an audio file for transcription.
|
||||
@@ -95,3 +97,125 @@ func (s *AudioStream) Err() error { return s.err }
|
||||
|
||||
// Close releases the underlying connection.
|
||||
func (s *AudioStream) Close() error { return s.stream.Close() }
|
||||
|
||||
// Speech sends a text-to-speech request and returns the full response.
|
||||
func (c *Client) Speech(ctx context.Context, req *audio.SpeechRequest) (*audio.SpeechResponse, error) {
|
||||
var resp audio.SpeechResponse
|
||||
if err := c.doJSON(ctx, "POST", "/v1/audio/speech", req, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// SpeechStream sends a text-to-speech request and returns a stream of audio events.
|
||||
func (c *Client) SpeechStream(ctx context.Context, req *audio.SpeechRequest) (*SpeechStream, error) {
|
||||
req.EnableStream()
|
||||
resp, err := c.doStream(ctx, "POST", "/v1/audio/speech", req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return newSpeechStream(resp.Body), nil
|
||||
}
|
||||
|
||||
// SpeechStream wraps the generic Stream for speech streaming events.
|
||||
type SpeechStream struct {
|
||||
stream *Stream[json.RawMessage]
|
||||
event audio.SpeechStreamEvent
|
||||
err error
|
||||
}
|
||||
|
||||
func newSpeechStream(body readCloser) *SpeechStream {
|
||||
return &SpeechStream{
|
||||
stream: newStream[json.RawMessage](body),
|
||||
}
|
||||
}
|
||||
|
||||
// Next advances to the next event. Returns false when done or on error.
|
||||
func (s *SpeechStream) Next() bool {
|
||||
if s.err != nil {
|
||||
return false
|
||||
}
|
||||
if !s.stream.Next() {
|
||||
s.err = s.stream.Err()
|
||||
return false
|
||||
}
|
||||
event, err := audio.UnmarshalSpeechStreamEvent(s.stream.Current())
|
||||
if err != nil {
|
||||
s.err = err
|
||||
return false
|
||||
}
|
||||
s.event = event
|
||||
return true
|
||||
}
|
||||
|
||||
// Current returns the most recently read event.
|
||||
func (s *SpeechStream) Current() audio.SpeechStreamEvent { return s.event }
|
||||
|
||||
// Err returns any error encountered during streaming.
|
||||
func (s *SpeechStream) Err() error { return s.err }
|
||||
|
||||
// Close releases the underlying connection.
|
||||
func (s *SpeechStream) Close() error { return s.stream.Close() }
|
||||
|
||||
// ListVoices returns available voices.
|
||||
func (c *Client) ListVoices(ctx context.Context) (*audio.VoiceListResponse, error) {
|
||||
var resp audio.VoiceListResponse
|
||||
if err := c.doJSON(ctx, "GET", "/v1/audio/voices", nil, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// CreateVoice creates a custom voice.
|
||||
func (c *Client) CreateVoice(ctx context.Context, req *audio.VoiceCreateRequest) (*audio.VoiceResponse, error) {
|
||||
var resp audio.VoiceResponse
|
||||
if err := c.doJSON(ctx, "POST", "/v1/audio/voices", req, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// GetVoice retrieves a voice by ID.
|
||||
func (c *Client) GetVoice(ctx context.Context, voiceID string) (*audio.VoiceResponse, error) {
|
||||
var resp audio.VoiceResponse
|
||||
if err := c.doJSON(ctx, "GET", fmt.Sprintf("/v1/audio/voices/%s", voiceID), nil, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// UpdateVoice updates a voice.
|
||||
func (c *Client) UpdateVoice(ctx context.Context, voiceID string, req *audio.VoiceUpdateRequest) (*audio.VoiceResponse, error) {
|
||||
var resp audio.VoiceResponse
|
||||
if err := c.doJSON(ctx, "PATCH", fmt.Sprintf("/v1/audio/voices/%s", voiceID), req, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// DeleteVoice deletes a voice.
|
||||
func (c *Client) DeleteVoice(ctx context.Context, voiceID string) error {
|
||||
resp, err := c.do(ctx, "DELETE", fmt.Sprintf("/v1/audio/voices/%s", voiceID), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode >= 400 {
|
||||
return parseAPIError(resp)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetVoiceSampleAudio retrieves the sample audio for a voice.
|
||||
// Returns the raw HTTP response; the caller must close the body.
|
||||
func (c *Client) GetVoiceSampleAudio(ctx context.Context, voiceID string) (*http.Response, error) {
|
||||
resp, err := c.do(ctx, "GET", fmt.Sprintf("/v1/audio/voices/%s/sample", voiceID), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.StatusCode >= 400 {
|
||||
defer resp.Body.Close()
|
||||
return nil, parseAPIError(resp)
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
217
audio_speech_test.go
Normal file
217
audio_speech_test.go
Normal file
@@ -0,0 +1,217 @@
|
||||
package mistral
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/VikingOwl91/mistral-go-sdk/audio"
|
||||
)
|
||||
|
||||
func TestSpeech_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != "POST" {
|
||||
t.Errorf("expected POST, got %s", r.Method)
|
||||
}
|
||||
if r.URL.Path != "/v1/audio/speech" {
|
||||
t.Errorf("got path %s", r.URL.Path)
|
||||
}
|
||||
var body map[string]any
|
||||
json.NewDecoder(r.Body).Decode(&body)
|
||||
if body["input"] != "Hello world" {
|
||||
t.Errorf("got input %v", body["input"])
|
||||
}
|
||||
if body["stream"] != false {
|
||||
t.Errorf("expected stream=false")
|
||||
}
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"audio_data": "base64audiodata==",
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
resp, err := client.Speech(context.Background(), &audio.SpeechRequest{
|
||||
Input: "Hello world",
|
||||
Model: "mistral-speech",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp.AudioData != "base64audiodata==" {
|
||||
t.Errorf("got audio_data %q", resp.AudioData)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpeechStream_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var body map[string]any
|
||||
json.NewDecoder(r.Body).Decode(&body)
|
||||
if body["stream"] != true {
|
||||
t.Errorf("expected stream=true")
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
flusher, _ := w.(http.Flusher)
|
||||
|
||||
delta, _ := json.Marshal(map[string]any{
|
||||
"type": "speech.audio.delta", "audio_data": "chunk1==",
|
||||
})
|
||||
fmt.Fprintf(w, "data: %s\n\n", delta)
|
||||
flusher.Flush()
|
||||
|
||||
done, _ := json.Marshal(map[string]any{
|
||||
"type": "speech.audio.done",
|
||||
"usage": map[string]any{
|
||||
"prompt_tokens": 5, "completion_tokens": 10, "total_tokens": 15,
|
||||
},
|
||||
})
|
||||
fmt.Fprintf(w, "data: %s\n\n", done)
|
||||
flusher.Flush()
|
||||
|
||||
fmt.Fprint(w, "data: [DONE]\n\n")
|
||||
flusher.Flush()
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
stream, err := client.SpeechStream(context.Background(), &audio.SpeechRequest{
|
||||
Input: "Hi",
|
||||
Model: "mistral-speech",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer stream.Close()
|
||||
|
||||
var events int
|
||||
for stream.Next() {
|
||||
events++
|
||||
}
|
||||
if err := stream.Err(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if events != 2 {
|
||||
t.Errorf("got %d events, want 2", events)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListVoices_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/v1/audio/voices" {
|
||||
t.Errorf("got path %s", r.URL.Path)
|
||||
}
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"items": []map[string]any{
|
||||
{"id": "v1", "name": "Default", "created_at": "2025-01-01"},
|
||||
},
|
||||
"total": 1, "page": 1, "page_size": 10, "total_pages": 1,
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
resp, err := client.ListVoices(context.Background())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(resp.Items) != 1 {
|
||||
t.Fatalf("got %d voices", len(resp.Items))
|
||||
}
|
||||
if resp.Items[0].ID != "v1" {
|
||||
t.Errorf("got id %q", resp.Items[0].ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateVoice_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != "POST" {
|
||||
t.Errorf("expected POST, got %s", r.Method)
|
||||
}
|
||||
var body map[string]any
|
||||
json.NewDecoder(r.Body).Decode(&body)
|
||||
if body["name"] != "MyVoice" {
|
||||
t.Errorf("got name %v", body["name"])
|
||||
}
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"id": "v2", "name": "MyVoice", "created_at": "2025-01-01",
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
resp, err := client.CreateVoice(context.Background(), &audio.VoiceCreateRequest{
|
||||
Name: "MyVoice",
|
||||
SampleAudio: "base64audio==",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp.ID != "v2" {
|
||||
t.Errorf("got id %q", resp.ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetVoice_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/v1/audio/voices/v1" {
|
||||
t.Errorf("got path %s", r.URL.Path)
|
||||
}
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"id": "v1", "name": "Default", "created_at": "2025-01-01",
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
resp, err := client.GetVoice(context.Background(), "v1")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp.Name != "Default" {
|
||||
t.Errorf("got name %q", resp.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateVoice_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != "PATCH" {
|
||||
t.Errorf("expected PATCH, got %s", r.Method)
|
||||
}
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"id": "v1", "name": "Renamed", "created_at": "2025-01-01",
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
name := "Renamed"
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
resp, err := client.UpdateVoice(context.Background(), "v1", &audio.VoiceUpdateRequest{
|
||||
Name: &name,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp.Name != "Renamed" {
|
||||
t.Errorf("got name %q", resp.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteVoice_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != "DELETE" {
|
||||
t.Errorf("expected DELETE, got %s", r.Method)
|
||||
}
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
err := client.DeleteVoice(context.Background(), "v1")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"somegit.dev/vikingowl/mistral-go-sdk/audio"
|
||||
"github.com/VikingOwl91/mistral-go-sdk/audio"
|
||||
)
|
||||
|
||||
func TestTranscribe_WithFileURL(t *testing.T) {
|
||||
|
||||
@@ -60,3 +60,10 @@ type ListParams struct {
|
||||
Status []string
|
||||
OrderBy *string
|
||||
}
|
||||
|
||||
// DeleteResponse is the response from deleting a batch job.
|
||||
type DeleteResponse struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Deleted bool `json:"deleted"`
|
||||
}
|
||||
|
||||
11
batch_api.go
11
batch_api.go
@@ -7,7 +7,7 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"somegit.dev/vikingowl/mistral-go-sdk/batch"
|
||||
"github.com/VikingOwl91/mistral-go-sdk/batch"
|
||||
)
|
||||
|
||||
// CreateBatchJob creates a new batch inference job.
|
||||
@@ -76,3 +76,12 @@ func (c *Client) CancelBatchJob(ctx context.Context, jobID string) (*batch.JobOu
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// DeleteBatchJob deletes a batch job.
|
||||
func (c *Client) DeleteBatchJob(ctx context.Context, jobID string) (*batch.DeleteResponse, error) {
|
||||
var resp batch.DeleteResponse
|
||||
if err := c.doJSON(ctx, "DELETE", fmt.Sprintf("/v1/batch/jobs/%s", jobID), nil, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"somegit.dev/vikingowl/mistral-go-sdk/batch"
|
||||
"github.com/VikingOwl91/mistral-go-sdk/batch"
|
||||
)
|
||||
|
||||
func TestCreateBatchJob_Success(t *testing.T) {
|
||||
@@ -121,3 +121,30 @@ func TestCancelBatchJob_Success(t *testing.T) {
|
||||
t.Errorf("got status %q", job.Status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteBatchJob_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != "DELETE" {
|
||||
t.Errorf("expected DELETE, got %s", r.Method)
|
||||
}
|
||||
if r.URL.Path != "/v1/batch/jobs/batch-123" {
|
||||
t.Errorf("got path %s", r.URL.Path)
|
||||
}
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"id": "batch-123", "object": "batch", "deleted": true,
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
resp, err := client.DeleteBatchJob(context.Background(), "batch-123")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp.ID != "batch-123" {
|
||||
t.Errorf("got id %q", resp.ID)
|
||||
}
|
||||
if !resp.Deleted {
|
||||
t.Error("expected deleted=true")
|
||||
}
|
||||
}
|
||||
|
||||
119
chat/content.go
119
chat/content.go
@@ -3,6 +3,18 @@ package chat
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
// BuiltInConnector identifies a built-in connector type.
|
||||
type BuiltInConnector string
|
||||
|
||||
const (
|
||||
ConnectorWebSearch BuiltInConnector = "web_search"
|
||||
ConnectorWebSearchPremium BuiltInConnector = "web_search_premium"
|
||||
ConnectorCodeInterpreter BuiltInConnector = "code_interpreter"
|
||||
ConnectorImageGeneration BuiltInConnector = "image_generation"
|
||||
ConnectorDocumentLibrary BuiltInConnector = "document_library"
|
||||
)
|
||||
|
||||
// ContentChunk is a sealed interface for message content parts.
|
||||
@@ -112,9 +124,65 @@ func (c *FileChunk) MarshalJSON() ([]byte, error) {
|
||||
})
|
||||
}
|
||||
|
||||
// ReferenceID is a reference identifier that can be an integer or string.
|
||||
// Use [IntRef] or [StringRef] constructors.
|
||||
type ReferenceID struct {
|
||||
raw string
|
||||
isString bool
|
||||
}
|
||||
|
||||
// IntRef creates an integer reference ID.
|
||||
func IntRef(n int) ReferenceID {
|
||||
return ReferenceID{raw: strconv.Itoa(n)}
|
||||
}
|
||||
|
||||
// StringRef creates a string reference ID.
|
||||
func StringRef(s string) ReferenceID {
|
||||
return ReferenceID{raw: s, isString: true}
|
||||
}
|
||||
|
||||
// String returns the string representation.
|
||||
func (id ReferenceID) String() string { return id.raw }
|
||||
|
||||
// Int returns the integer value and true if this is a numeric reference.
|
||||
func (id ReferenceID) Int() (int, bool) {
|
||||
if id.isString {
|
||||
return 0, false
|
||||
}
|
||||
n, err := strconv.Atoi(id.raw)
|
||||
return n, err == nil
|
||||
}
|
||||
|
||||
// IsString reports whether this is a string reference.
|
||||
func (id ReferenceID) IsString() bool { return id.isString }
|
||||
|
||||
func (id ReferenceID) MarshalJSON() ([]byte, error) {
|
||||
if id.isString {
|
||||
return json.Marshal(id.raw)
|
||||
}
|
||||
return []byte(id.raw), nil
|
||||
}
|
||||
|
||||
func (id *ReferenceID) UnmarshalJSON(data []byte) error {
|
||||
if len(data) == 0 {
|
||||
return nil
|
||||
}
|
||||
if data[0] == '"' {
|
||||
var s string
|
||||
if err := json.Unmarshal(data, &s); err != nil {
|
||||
return err
|
||||
}
|
||||
id.raw = s
|
||||
id.isString = true
|
||||
return nil
|
||||
}
|
||||
id.raw = string(data)
|
||||
return nil
|
||||
}
|
||||
|
||||
// ReferenceChunk represents a reference content part.
|
||||
type ReferenceChunk struct {
|
||||
ReferenceIDs []int `json:"reference_ids"`
|
||||
ReferenceIDs []ReferenceID `json:"reference_ids"`
|
||||
}
|
||||
|
||||
func (*ReferenceChunk) contentChunk() {}
|
||||
@@ -192,6 +260,49 @@ func (c *AudioChunk) MarshalJSON() ([]byte, error) {
|
||||
})
|
||||
}
|
||||
|
||||
// ToolReferenceChunk represents a tool reference content part.
|
||||
type ToolReferenceChunk struct {
|
||||
Tool string `json:"tool"`
|
||||
Title string `json:"title"`
|
||||
URL *string `json:"url,omitempty"`
|
||||
Favicon *string `json:"favicon,omitempty"`
|
||||
Description *string `json:"description,omitempty"`
|
||||
}
|
||||
|
||||
func (*ToolReferenceChunk) contentChunk() {}
|
||||
|
||||
func (c *ToolReferenceChunk) MarshalJSON() ([]byte, error) {
|
||||
type alias ToolReferenceChunk
|
||||
return json.Marshal(&struct {
|
||||
Type string `json:"type"`
|
||||
*alias
|
||||
}{
|
||||
Type: "tool_reference",
|
||||
alias: (*alias)(c),
|
||||
})
|
||||
}
|
||||
|
||||
// ToolFileChunk represents a tool-generated file content part.
|
||||
type ToolFileChunk struct {
|
||||
Tool string `json:"tool"`
|
||||
FileID string `json:"file_id"`
|
||||
FileName *string `json:"file_name,omitempty"`
|
||||
FileType *string `json:"file_type,omitempty"`
|
||||
}
|
||||
|
||||
func (*ToolFileChunk) contentChunk() {}
|
||||
|
||||
func (c *ToolFileChunk) MarshalJSON() ([]byte, error) {
|
||||
type alias ToolFileChunk
|
||||
return json.Marshal(&struct {
|
||||
Type string `json:"type"`
|
||||
*alias
|
||||
}{
|
||||
Type: "tool_file",
|
||||
alias: (*alias)(c),
|
||||
})
|
||||
}
|
||||
|
||||
// UnmarshalContentChunk dispatches to the concrete ContentChunk type
|
||||
// based on the "type" discriminator field.
|
||||
func UnmarshalContentChunk(data []byte) (ContentChunk, error) {
|
||||
@@ -223,6 +334,12 @@ func UnmarshalContentChunk(data []byte) (ContentChunk, error) {
|
||||
case "input_audio":
|
||||
var c AudioChunk
|
||||
return &c, json.Unmarshal(data, &c)
|
||||
case "tool_reference":
|
||||
var c ToolReferenceChunk
|
||||
return &c, json.Unmarshal(data, &c)
|
||||
case "tool_file":
|
||||
var c ToolFileChunk
|
||||
return &c, json.Unmarshal(data, &c)
|
||||
default:
|
||||
return &UnknownChunk{Type: probe.Type, Raw: json.RawMessage(data)}, nil
|
||||
}
|
||||
|
||||
@@ -150,12 +150,16 @@ func TestFileChunk_RoundTrip(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestReferenceChunk_RoundTrip(t *testing.T) {
|
||||
original := &ReferenceChunk{ReferenceIDs: []int{1, 2, 3}}
|
||||
func TestReferenceChunk_RoundTrip_IntIDs(t *testing.T) {
|
||||
original := &ReferenceChunk{ReferenceIDs: []ReferenceID{IntRef(1), IntRef(2), IntRef(3)}}
|
||||
data, err := json.Marshal(original)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
want := `{"type":"reference","reference_ids":[1,2,3]}`
|
||||
if string(data) != want {
|
||||
t.Errorf("marshal: got %s, want %s", data, want)
|
||||
}
|
||||
chunk, err := UnmarshalContentChunk(data)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -164,8 +168,64 @@ func TestReferenceChunk_RoundTrip(t *testing.T) {
|
||||
if !ok {
|
||||
t.Fatalf("expected *ReferenceChunk, got %T", chunk)
|
||||
}
|
||||
if len(rc.ReferenceIDs) != 3 || rc.ReferenceIDs[0] != 1 {
|
||||
t.Errorf("got %v, want [1 2 3]", rc.ReferenceIDs)
|
||||
if len(rc.ReferenceIDs) != 3 {
|
||||
t.Fatalf("got %d IDs, want 3", len(rc.ReferenceIDs))
|
||||
}
|
||||
n, ok := rc.ReferenceIDs[0].Int()
|
||||
if !ok || n != 1 {
|
||||
t.Errorf("got %v (ok=%v), want 1", n, ok)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReferenceChunk_RoundTrip_MixedIDs(t *testing.T) {
|
||||
data := []byte(`{"type":"reference","reference_ids":[1,"abc",42]}`)
|
||||
chunk, err := UnmarshalContentChunk(data)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
rc, ok := chunk.(*ReferenceChunk)
|
||||
if !ok {
|
||||
t.Fatalf("expected *ReferenceChunk, got %T", chunk)
|
||||
}
|
||||
if len(rc.ReferenceIDs) != 3 {
|
||||
t.Fatalf("got %d IDs, want 3", len(rc.ReferenceIDs))
|
||||
}
|
||||
// First: int 1
|
||||
if n, ok := rc.ReferenceIDs[0].Int(); !ok || n != 1 {
|
||||
t.Errorf("IDs[0]: got %v (ok=%v), want int 1", n, ok)
|
||||
}
|
||||
// Second: string "abc"
|
||||
if !rc.ReferenceIDs[1].IsString() || rc.ReferenceIDs[1].String() != "abc" {
|
||||
t.Errorf("IDs[1]: got %q (isString=%v), want string abc", rc.ReferenceIDs[1].String(), rc.ReferenceIDs[1].IsString())
|
||||
}
|
||||
// Third: int 42
|
||||
if n, ok := rc.ReferenceIDs[2].Int(); !ok || n != 42 {
|
||||
t.Errorf("IDs[2]: got %v (ok=%v), want int 42", n, ok)
|
||||
}
|
||||
// Round-trip preserves types
|
||||
out, err := json.Marshal(rc)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if string(out) != `{"type":"reference","reference_ids":[1,"abc",42]}` {
|
||||
t.Errorf("round-trip: got %s", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReferenceID_StringRef(t *testing.T) {
|
||||
id := StringRef("doc-123")
|
||||
if !id.IsString() {
|
||||
t.Error("expected IsString=true")
|
||||
}
|
||||
if id.String() != "doc-123" {
|
||||
t.Errorf("got %q", id.String())
|
||||
}
|
||||
if _, ok := id.Int(); ok {
|
||||
t.Error("Int() should return false for string ref")
|
||||
}
|
||||
data, _ := json.Marshal(id)
|
||||
if string(data) != `"doc-123"` {
|
||||
t.Errorf("marshal: got %s", data)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -174,7 +234,7 @@ func TestThinkChunk_RoundTrip(t *testing.T) {
|
||||
original := &ThinkChunk{
|
||||
Thinking: []ContentChunk{
|
||||
&TextChunk{Text: "reasoning step"},
|
||||
&ReferenceChunk{ReferenceIDs: []int{42}},
|
||||
&ReferenceChunk{ReferenceIDs: []ReferenceID{IntRef(42)}},
|
||||
},
|
||||
Closed: &closed,
|
||||
}
|
||||
@@ -365,3 +425,76 @@ func TestContent_IsNull(t *testing.T) {
|
||||
t.Error("chunks content should not be null")
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolReferenceChunk_RoundTrip(t *testing.T) {
|
||||
url := "https://example.com/result"
|
||||
desc := "A search result"
|
||||
original := &ToolReferenceChunk{
|
||||
Tool: string(ConnectorWebSearch),
|
||||
Title: "Example Result",
|
||||
URL: &url,
|
||||
Description: &desc,
|
||||
}
|
||||
data, err := json.Marshal(original)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
chunk, err := UnmarshalContentChunk(data)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
tr, ok := chunk.(*ToolReferenceChunk)
|
||||
if !ok {
|
||||
t.Fatalf("expected *ToolReferenceChunk, got %T", chunk)
|
||||
}
|
||||
if tr.Tool != "web_search" {
|
||||
t.Errorf("got tool %q, want web_search", tr.Tool)
|
||||
}
|
||||
if tr.Title != "Example Result" {
|
||||
t.Errorf("got title %q", tr.Title)
|
||||
}
|
||||
if tr.URL == nil || *tr.URL != url {
|
||||
t.Errorf("got url %v, want %q", tr.URL, url)
|
||||
}
|
||||
if tr.Description == nil || *tr.Description != desc {
|
||||
t.Errorf("got description %v, want %q", tr.Description, desc)
|
||||
}
|
||||
if tr.Favicon != nil {
|
||||
t.Errorf("expected nil favicon, got %v", tr.Favicon)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolFileChunk_RoundTrip(t *testing.T) {
|
||||
fname := "output.csv"
|
||||
ftype := "text/csv"
|
||||
original := &ToolFileChunk{
|
||||
Tool: string(ConnectorCodeInterpreter),
|
||||
FileID: "file-abc123",
|
||||
FileName: &fname,
|
||||
FileType: &ftype,
|
||||
}
|
||||
data, err := json.Marshal(original)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
chunk, err := UnmarshalContentChunk(data)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
tf, ok := chunk.(*ToolFileChunk)
|
||||
if !ok {
|
||||
t.Fatalf("expected *ToolFileChunk, got %T", chunk)
|
||||
}
|
||||
if tf.Tool != "code_interpreter" {
|
||||
t.Errorf("got tool %q", tf.Tool)
|
||||
}
|
||||
if tf.FileID != "file-abc123" {
|
||||
t.Errorf("got file_id %q", tf.FileID)
|
||||
}
|
||||
if tf.FileName == nil || *tf.FileName != fname {
|
||||
t.Errorf("got file_name %v", tf.FileName)
|
||||
}
|
||||
if tf.FileType == nil || *tf.FileType != ftype {
|
||||
t.Errorf("got file_type %v", tf.FileType)
|
||||
}
|
||||
}
|
||||
|
||||
19
chat/doc.go
19
chat/doc.go
@@ -5,5 +5,22 @@
|
||||
// struct literals.
|
||||
//
|
||||
// Content is polymorphic: it can be a plain string (via [TextContent]),
|
||||
// nil, or a slice of [ContentChunk] values (text, image URL, document URL, audio).
|
||||
// nil, or a slice of [ContentChunk] values (text, image URL, document URL,
|
||||
// audio, tool reference, tool file).
|
||||
//
|
||||
// # Guardrails
|
||||
//
|
||||
// [GuardrailConfig] configures moderation guardrails on completion requests.
|
||||
// Both v1 ([ModerationLLMV1Config]) and v2 ([ModerationLLMV2Config]) moderation
|
||||
// configs are supported.
|
||||
//
|
||||
// # Content Chunks
|
||||
//
|
||||
// The following chunk types are supported: [TextChunk], [ImageURLChunk],
|
||||
// [DocumentURLChunk], [FileChunk], [ReferenceChunk], [ThinkChunk],
|
||||
// [AudioChunk], [ToolReferenceChunk], [ToolFileChunk].
|
||||
// Unrecognized types are preserved as [UnknownChunk].
|
||||
//
|
||||
// [ReferenceChunk] uses [ReferenceID] values that can hold either integer
|
||||
// or string identifiers. Use [IntRef] and [StringRef] constructors.
|
||||
package chat
|
||||
|
||||
62
chat/guardrail.go
Normal file
62
chat/guardrail.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package chat
|
||||
|
||||
// ModerationLLMAction specifies the action to take when content exceeds thresholds.
|
||||
type ModerationLLMAction string
|
||||
|
||||
const (
|
||||
ModerationActionNone ModerationLLMAction = "none"
|
||||
ModerationActionBlock ModerationLLMAction = "block"
|
||||
)
|
||||
|
||||
// ModerationLLMV1CategoryThresholds defines per-category score thresholds for v1 moderation.
|
||||
type ModerationLLMV1CategoryThresholds struct {
|
||||
Sexual *float64 `json:"sexual,omitempty"`
|
||||
HateAndDiscrimination *float64 `json:"hate_and_discrimination,omitempty"`
|
||||
ViolenceAndThreats *float64 `json:"violence_and_threats,omitempty"`
|
||||
DangerousAndCriminalContent *float64 `json:"dangerous_and_criminal_content,omitempty"`
|
||||
Selfharm *float64 `json:"selfharm,omitempty"`
|
||||
Health *float64 `json:"health,omitempty"`
|
||||
Financial *float64 `json:"financial,omitempty"`
|
||||
Law *float64 `json:"law,omitempty"`
|
||||
PII *float64 `json:"pii,omitempty"`
|
||||
}
|
||||
|
||||
// ModerationLLMV1Config configures the v1 moderation LLM guardrail.
|
||||
type ModerationLLMV1Config struct {
|
||||
ModelName string `json:"model_name,omitempty"`
|
||||
CustomCategoryThresholds *ModerationLLMV1CategoryThresholds `json:"custom_category_thresholds,omitempty"`
|
||||
IgnoreOtherCategories bool `json:"ignore_other_categories,omitempty"`
|
||||
Action ModerationLLMAction `json:"action,omitempty"`
|
||||
}
|
||||
|
||||
// ModerationLLMV2CategoryThresholds defines per-category score thresholds for v2 moderation.
|
||||
// V2 splits "dangerous_and_criminal_content" into separate "dangerous" and "criminal"
|
||||
// categories and adds "jailbreaking".
|
||||
type ModerationLLMV2CategoryThresholds struct {
|
||||
Sexual *float64 `json:"sexual,omitempty"`
|
||||
HateAndDiscrimination *float64 `json:"hate_and_discrimination,omitempty"`
|
||||
ViolenceAndThreats *float64 `json:"violence_and_threats,omitempty"`
|
||||
Dangerous *float64 `json:"dangerous,omitempty"`
|
||||
Criminal *float64 `json:"criminal,omitempty"`
|
||||
Selfharm *float64 `json:"selfharm,omitempty"`
|
||||
Health *float64 `json:"health,omitempty"`
|
||||
Financial *float64 `json:"financial,omitempty"`
|
||||
Law *float64 `json:"law,omitempty"`
|
||||
PII *float64 `json:"pii,omitempty"`
|
||||
Jailbreaking *float64 `json:"jailbreaking,omitempty"`
|
||||
}
|
||||
|
||||
// ModerationLLMV2Config configures the v2 moderation LLM guardrail.
|
||||
type ModerationLLMV2Config struct {
|
||||
ModelName string `json:"model_name,omitempty"`
|
||||
CustomCategoryThresholds *ModerationLLMV2CategoryThresholds `json:"custom_category_thresholds,omitempty"`
|
||||
IgnoreOtherCategories bool `json:"ignore_other_categories,omitempty"`
|
||||
Action ModerationLLMAction `json:"action,omitempty"`
|
||||
}
|
||||
|
||||
// GuardrailConfig configures moderation guardrails for requests.
|
||||
type GuardrailConfig struct {
|
||||
BlockOnError bool `json:"block_on_error"`
|
||||
ModerationLLMV1 *ModerationLLMV1Config `json:"moderation_llm_v1,omitempty"`
|
||||
ModerationLLMV2 *ModerationLLMV2Config `json:"moderation_llm_v2,omitempty"`
|
||||
}
|
||||
@@ -9,6 +9,14 @@ const (
|
||||
PromptModeReasoning PromptMode = "reasoning"
|
||||
)
|
||||
|
||||
// ReasoningEffort controls the amount of reasoning effort the model uses.
|
||||
type ReasoningEffort string
|
||||
|
||||
const (
|
||||
ReasoningEffortNone ReasoningEffort = "none"
|
||||
ReasoningEffortHigh ReasoningEffort = "high"
|
||||
)
|
||||
|
||||
// Prediction provides expected completion content for optimization.
|
||||
type Prediction struct {
|
||||
Type string `json:"type"`
|
||||
@@ -30,11 +38,13 @@ type CompletionRequest struct {
|
||||
PresencePenalty *float64 `json:"presence_penalty,omitempty"`
|
||||
FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"`
|
||||
N *int `json:"n,omitempty"`
|
||||
SafePrompt bool `json:"safe_prompt,omitempty"`
|
||||
ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"`
|
||||
Metadata map[string]any `json:"metadata,omitempty"`
|
||||
Prediction *Prediction `json:"prediction,omitempty"`
|
||||
PromptMode *PromptMode `json:"prompt_mode,omitempty"`
|
||||
SafePrompt bool `json:"safe_prompt,omitempty"`
|
||||
ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"`
|
||||
Metadata map[string]any `json:"metadata,omitempty"`
|
||||
Prediction *Prediction `json:"prediction,omitempty"`
|
||||
PromptMode *PromptMode `json:"prompt_mode,omitempty"`
|
||||
Guardrails []GuardrailConfig `json:"guardrails,omitempty"`
|
||||
ReasoningEffort *ReasoningEffort `json:"reasoning_effort,omitempty"`
|
||||
stream bool
|
||||
}
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ package mistral
|
||||
import (
|
||||
"context"
|
||||
|
||||
"somegit.dev/vikingowl/mistral-go-sdk/chat"
|
||||
"github.com/VikingOwl91/mistral-go-sdk/chat"
|
||||
)
|
||||
|
||||
// ChatComplete sends a chat completion request and returns the full response.
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"somegit.dev/vikingowl/mistral-go-sdk/chat"
|
||||
"github.com/VikingOwl91/mistral-go-sdk/chat"
|
||||
)
|
||||
|
||||
func TestChatComplete_Success(t *testing.T) {
|
||||
@@ -350,6 +350,41 @@ func TestChatComplete_RequestBody(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatComplete_ReasoningEffort(t *testing.T) {
|
||||
effort := chat.ReasoningEffortHigh
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
bodyBytes, _ := io.ReadAll(r.Body)
|
||||
var body map[string]any
|
||||
json.Unmarshal(bodyBytes, &body)
|
||||
|
||||
if body["reasoning_effort"] != "high" {
|
||||
t.Errorf("expected reasoning_effort=high, got %v", body["reasoning_effort"])
|
||||
}
|
||||
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"id": "chat-re", "object": "chat.completion",
|
||||
"model": "m", "created": 0,
|
||||
"choices": []map[string]any{{
|
||||
"index": 0, "message": map[string]any{"role": "assistant", "content": "ok"},
|
||||
"finish_reason": "stop",
|
||||
}},
|
||||
"usage": map[string]any{"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
_, err := client.ChatComplete(context.Background(), &chat.CompletionRequest{
|
||||
Model: "m",
|
||||
Messages: []chat.Message{&chat.UserMessage{Content: chat.TextContent("Hi")}},
|
||||
ReasoningEffort: &effort,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatComplete_ContextCanceled(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Never responds — context should cancel first
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"somegit.dev/vikingowl/mistral-go-sdk/chat"
|
||||
"github.com/VikingOwl91/mistral-go-sdk/chat"
|
||||
)
|
||||
|
||||
func TestChatCompleteStream_Success(t *testing.T) {
|
||||
@@ -165,6 +165,7 @@ func TestChatCompleteStream_WithToolCalls(t *testing.T) {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
flusher, _ := w.(http.Flusher)
|
||||
|
||||
toolCalls := chat.FinishReasonToolCalls
|
||||
chunk := chat.CompletionChunk{
|
||||
ID: "c",
|
||||
Model: "m",
|
||||
@@ -177,7 +178,9 @@ func TestChatCompleteStream_WithToolCalls(t *testing.T) {
|
||||
Function: chat.FunctionCall{Name: "get_weather", Arguments: `{"city":"Paris"}`},
|
||||
}},
|
||||
},
|
||||
FinishReason: &toolCalls,
|
||||
}},
|
||||
Usage: &chat.UsageInfo{PromptTokens: 10, CompletionTokens: 5, TotalTokens: 15},
|
||||
}
|
||||
data, _ := json.Marshal(chunk)
|
||||
fmt.Fprintf(w, "data: %s\n\n", data)
|
||||
@@ -199,10 +202,17 @@ func TestChatCompleteStream_WithToolCalls(t *testing.T) {
|
||||
if !stream.Next() {
|
||||
t.Fatalf("expected chunk, err: %v", stream.Err())
|
||||
}
|
||||
tc := stream.Current().Choices[0].Delta.ToolCalls
|
||||
cur := stream.Current()
|
||||
tc := cur.Choices[0].Delta.ToolCalls
|
||||
if len(tc) != 1 || tc[0].Function.Name != "get_weather" {
|
||||
t.Errorf("got tool calls %+v", tc)
|
||||
}
|
||||
if cur.Choices[0].FinishReason == nil || *cur.Choices[0].FinishReason != chat.FinishReasonToolCalls {
|
||||
t.Errorf("expected finish_reason tool_calls, got %v", cur.Choices[0].FinishReason)
|
||||
}
|
||||
if cur.Usage == nil || cur.Usage.TotalTokens != 15 {
|
||||
t.Errorf("expected usage with total_tokens=15, got %+v", cur.Usage)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatCompleteStream_APIError(t *testing.T) {
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
package classification
|
||||
|
||||
import "somegit.dev/vikingowl/mistral-go-sdk/chat"
|
||||
import "github.com/VikingOwl91/mistral-go-sdk/chat"
|
||||
|
||||
// Request represents a text classification request (/v1/classifications).
|
||||
type Request struct {
|
||||
|
||||
100
connector/connector.go
Normal file
100
connector/connector.go
Normal file
@@ -0,0 +1,100 @@
|
||||
package connector
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
// Visibility controls who can see a connector or tool.
|
||||
type Visibility string
|
||||
|
||||
const (
|
||||
VisibilitySharedGlobal Visibility = "shared_global"
|
||||
VisibilitySharedOrg Visibility = "shared_org"
|
||||
VisibilitySharedWorkspace Visibility = "shared_workspace"
|
||||
VisibilityPrivate Visibility = "private"
|
||||
)
|
||||
|
||||
// AuthData holds OAuth2 client credentials for a connector.
|
||||
type AuthData struct {
|
||||
ClientID string `json:"client_id"`
|
||||
ClientSecret string `json:"client_secret"`
|
||||
}
|
||||
|
||||
// Connector represents a registered MCP connector.
|
||||
type Connector struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
ModifiedAt string `json:"modified_at"`
|
||||
Server *string `json:"server,omitempty"`
|
||||
AuthType *string `json:"auth_type,omitempty"`
|
||||
}
|
||||
|
||||
// CreateRequest creates a new connector.
|
||||
type CreateRequest struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Server string `json:"server"`
|
||||
IconURL *string `json:"icon_url,omitempty"`
|
||||
Visibility *Visibility `json:"visibility,omitempty"`
|
||||
Headers map[string]string `json:"headers,omitempty"`
|
||||
AuthData *AuthData `json:"auth_data,omitempty"`
|
||||
SystemPrompt *string `json:"system_prompt,omitempty"`
|
||||
}
|
||||
|
||||
// UpdateRequest updates an existing connector.
|
||||
type UpdateRequest struct {
|
||||
Name *string `json:"name,omitempty"`
|
||||
Description *string `json:"description,omitempty"`
|
||||
IconURL *string `json:"icon_url,omitempty"`
|
||||
SystemPrompt *string `json:"system_prompt,omitempty"`
|
||||
Server *string `json:"server,omitempty"`
|
||||
Headers map[string]string `json:"headers,omitempty"`
|
||||
AuthData *AuthData `json:"auth_data,omitempty"`
|
||||
ConnectionConfig map[string]any `json:"connection_config,omitempty"`
|
||||
ConnectionSecrets map[string]any `json:"connection_secrets,omitempty"`
|
||||
}
|
||||
|
||||
// AuthURLResponse is the response from getting a connector's OAuth URL.
|
||||
type AuthURLResponse struct {
|
||||
AuthURL string `json:"auth_url"`
|
||||
TTL int `json:"ttl"`
|
||||
}
|
||||
|
||||
// CallToolRequest is the request body for calling a connector tool.
|
||||
type CallToolRequest struct {
|
||||
Arguments map[string]any `json:"arguments,omitempty"`
|
||||
}
|
||||
|
||||
// CallToolResponse is the response from calling a connector tool.
|
||||
// Content is left as raw JSON because the upstream API returns a union
|
||||
// of 5 content types (text, image, audio, resource link, embedded resource).
|
||||
type CallToolResponse struct {
|
||||
Content json.RawMessage `json:"content"`
|
||||
Metadata map[string]any `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
// Tool represents a tool exposed by a connector.
|
||||
type Tool struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Description *string `json:"description,omitempty"`
|
||||
Visibility Visibility `json:"visibility,omitempty"`
|
||||
CreatedAt string `json:"created_at,omitempty"`
|
||||
ModifiedAt string `json:"modified_at,omitempty"`
|
||||
SystemPrompt *string `json:"system_prompt,omitempty"`
|
||||
JsonSchema map[string]any `json:"jsonschema,omitempty"`
|
||||
Active *bool `json:"active,omitempty"`
|
||||
}
|
||||
|
||||
// ListParams holds query parameters for listing connectors.
|
||||
type ListParams struct {
|
||||
Page *int
|
||||
PageSize *int
|
||||
}
|
||||
|
||||
// ListToolsParams holds query parameters for listing connector tools.
|
||||
type ListToolsParams struct {
|
||||
Page *int
|
||||
PageSize *int
|
||||
Refresh *bool
|
||||
}
|
||||
6
connector/doc.go
Normal file
6
connector/doc.go
Normal file
@@ -0,0 +1,6 @@
|
||||
// Package connector provides types for the Mistral connectors API.
|
||||
//
|
||||
// Connectors represent MCP (Model Context Protocol) server integrations.
|
||||
// Use [CreateRequest] to register a new connector, then use tools
|
||||
// discovered via the list-tools endpoint in chat or agent completions.
|
||||
package connector
|
||||
119
connectors.go
Normal file
119
connectors.go
Normal file
@@ -0,0 +1,119 @@
|
||||
package mistral
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strconv"
|
||||
|
||||
"github.com/VikingOwl91/mistral-go-sdk/connector"
|
||||
)
|
||||
|
||||
// CreateConnector registers a new MCP connector.
|
||||
func (c *Client) CreateConnector(ctx context.Context, req *connector.CreateRequest) (*connector.Connector, error) {
|
||||
var resp connector.Connector
|
||||
if err := c.doJSON(ctx, "POST", "/v1/connectors", req, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// ListConnectors returns all connectors.
|
||||
func (c *Client) ListConnectors(ctx context.Context, params *connector.ListParams) ([]connector.Connector, error) {
|
||||
path := "/v1/connectors"
|
||||
if params != nil {
|
||||
q := url.Values{}
|
||||
if params.Page != nil {
|
||||
q.Set("page", strconv.Itoa(*params.Page))
|
||||
}
|
||||
if params.PageSize != nil {
|
||||
q.Set("page_size", strconv.Itoa(*params.PageSize))
|
||||
}
|
||||
if encoded := q.Encode(); encoded != "" {
|
||||
path += "?" + encoded
|
||||
}
|
||||
}
|
||||
var resp []connector.Connector
|
||||
if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// GetConnector retrieves a connector by ID or name.
|
||||
func (c *Client) GetConnector(ctx context.Context, idOrName string) (*connector.Connector, error) {
|
||||
var resp connector.Connector
|
||||
if err := c.doJSON(ctx, "GET", fmt.Sprintf("/v1/connectors/%s", idOrName), nil, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// UpdateConnector updates an existing connector.
|
||||
func (c *Client) UpdateConnector(ctx context.Context, idOrName string, req *connector.UpdateRequest) (*connector.Connector, error) {
|
||||
var resp connector.Connector
|
||||
if err := c.doJSON(ctx, "PATCH", fmt.Sprintf("/v1/connectors/%s", idOrName), req, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// DeleteConnector deletes a connector.
|
||||
func (c *Client) DeleteConnector(ctx context.Context, idOrName string) error {
|
||||
resp, err := c.do(ctx, "DELETE", fmt.Sprintf("/v1/connectors/%s", idOrName), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode >= 400 {
|
||||
return parseAPIError(resp)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetConnectorAuthURL returns the OAuth2 authorization URL for a connector.
|
||||
func (c *Client) GetConnectorAuthURL(ctx context.Context, idOrName string, appReturnURL *string) (*connector.AuthURLResponse, error) {
|
||||
path := fmt.Sprintf("/v1/connectors/%s/auth_url", idOrName)
|
||||
if appReturnURL != nil {
|
||||
path += "?app_return_url=" + url.QueryEscape(*appReturnURL)
|
||||
}
|
||||
var resp connector.AuthURLResponse
|
||||
if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// ListConnectorTools lists tools exposed by a connector.
|
||||
func (c *Client) ListConnectorTools(ctx context.Context, idOrName string, params *connector.ListToolsParams) ([]connector.Tool, error) {
|
||||
path := fmt.Sprintf("/v1/connectors/%s/tools", idOrName)
|
||||
if params != nil {
|
||||
q := url.Values{}
|
||||
if params.Page != nil {
|
||||
q.Set("page", strconv.Itoa(*params.Page))
|
||||
}
|
||||
if params.PageSize != nil {
|
||||
q.Set("page_size", strconv.Itoa(*params.PageSize))
|
||||
}
|
||||
if params.Refresh != nil {
|
||||
q.Set("refresh", strconv.FormatBool(*params.Refresh))
|
||||
}
|
||||
if encoded := q.Encode(); encoded != "" {
|
||||
path += "?" + encoded
|
||||
}
|
||||
}
|
||||
var resp []connector.Tool
|
||||
if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// CallConnectorTool invokes a tool on a connector.
|
||||
func (c *Client) CallConnectorTool(ctx context.Context, idOrName, toolName string, req *connector.CallToolRequest) (*connector.CallToolResponse, error) {
|
||||
var resp connector.CallToolResponse
|
||||
if err := c.doJSON(ctx, "POST", fmt.Sprintf("/v1/connectors/%s/tools/%s/call", idOrName, toolName), req, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
217
connectors_test.go
Normal file
217
connectors_test.go
Normal file
@@ -0,0 +1,217 @@
|
||||
package mistral
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/VikingOwl91/mistral-go-sdk/connector"
|
||||
)
|
||||
|
||||
func TestCreateConnector_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != "POST" {
|
||||
t.Errorf("expected POST, got %s", r.Method)
|
||||
}
|
||||
if r.URL.Path != "/v1/connectors" {
|
||||
t.Errorf("got path %s", r.URL.Path)
|
||||
}
|
||||
var body map[string]any
|
||||
json.NewDecoder(r.Body).Decode(&body)
|
||||
if body["name"] != "my_connector" {
|
||||
t.Errorf("got name %v", body["name"])
|
||||
}
|
||||
if body["server"] != "https://mcp.example.com" {
|
||||
t.Errorf("got server %v", body["server"])
|
||||
}
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"id": "conn-1", "name": "my_connector",
|
||||
"description": "test", "created_at": "2025-01-01",
|
||||
"modified_at": "2025-01-01", "server": "https://mcp.example.com",
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
resp, err := client.CreateConnector(context.Background(), &connector.CreateRequest{
|
||||
Name: "my_connector",
|
||||
Description: "test",
|
||||
Server: "https://mcp.example.com",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp.ID != "conn-1" {
|
||||
t.Errorf("got id %q", resp.ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListConnectors_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != "GET" {
|
||||
t.Errorf("expected GET, got %s", r.Method)
|
||||
}
|
||||
json.NewEncoder(w).Encode([]map[string]any{
|
||||
{"id": "c1", "name": "conn1", "description": "d1", "created_at": "t", "modified_at": "t"},
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
list, err := client.ListConnectors(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(list) != 1 {
|
||||
t.Fatalf("got %d connectors", len(list))
|
||||
}
|
||||
if list[0].ID != "c1" {
|
||||
t.Errorf("got id %q", list[0].ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetConnector_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/v1/connectors/my_conn" {
|
||||
t.Errorf("got path %s", r.URL.Path)
|
||||
}
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"id": "c1", "name": "my_conn", "description": "d",
|
||||
"created_at": "t", "modified_at": "t",
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
c, err := client.GetConnector(context.Background(), "my_conn")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if c.Name != "my_conn" {
|
||||
t.Errorf("got name %q", c.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateConnector_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != "PATCH" {
|
||||
t.Errorf("expected PATCH, got %s", r.Method)
|
||||
}
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"id": "c1", "name": "updated", "description": "new desc",
|
||||
"created_at": "t", "modified_at": "t",
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
name := "updated"
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
resp, err := client.UpdateConnector(context.Background(), "c1", &connector.UpdateRequest{
|
||||
Name: &name,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp.Name != "updated" {
|
||||
t.Errorf("got name %q", resp.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteConnector_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != "DELETE" {
|
||||
t.Errorf("expected DELETE, got %s", r.Method)
|
||||
}
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
err := client.DeleteConnector(context.Background(), "c1")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetConnectorAuthURL_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/v1/connectors/c1/auth_url" {
|
||||
t.Errorf("got path %s", r.URL.Path)
|
||||
}
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"auth_url": "https://oauth.example.com/authorize",
|
||||
"ttl": 3600,
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
resp, err := client.GetConnectorAuthURL(context.Background(), "c1", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp.AuthURL != "https://oauth.example.com/authorize" {
|
||||
t.Errorf("got auth_url %q", resp.AuthURL)
|
||||
}
|
||||
if resp.TTL != 3600 {
|
||||
t.Errorf("got ttl %d", resp.TTL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListConnectorTools_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/v1/connectors/c1/tools" {
|
||||
t.Errorf("got path %s", r.URL.Path)
|
||||
}
|
||||
json.NewEncoder(w).Encode([]map[string]any{
|
||||
{"id": "t1", "name": "search", "description": "search the web"},
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
tools, err := client.ListConnectorTools(context.Background(), "c1", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(tools) != 1 {
|
||||
t.Fatalf("got %d tools", len(tools))
|
||||
}
|
||||
if tools[0].Name != "search" {
|
||||
t.Errorf("got name %q", tools[0].Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCallConnectorTool_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != "POST" {
|
||||
t.Errorf("expected POST, got %s", r.Method)
|
||||
}
|
||||
if r.URL.Path != "/v1/connectors/c1/tools/search/call" {
|
||||
t.Errorf("got path %s", r.URL.Path)
|
||||
}
|
||||
var body map[string]any
|
||||
json.NewDecoder(r.Body).Decode(&body)
|
||||
args := body["arguments"].(map[string]any)
|
||||
if args["query"] != "hello" {
|
||||
t.Errorf("got query %v", args["query"])
|
||||
}
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"content": []map[string]any{{"type": "text", "text": "result"}},
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
resp, err := client.CallConnectorTool(context.Background(), "c1", "search", &connector.CallToolRequest{
|
||||
Arguments: map[string]any{"query": "hello"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp.Content == nil {
|
||||
t.Error("expected non-nil content")
|
||||
}
|
||||
}
|
||||
@@ -4,7 +4,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"somegit.dev/vikingowl/mistral-go-sdk/chat"
|
||||
"github.com/VikingOwl91/mistral-go-sdk/chat"
|
||||
)
|
||||
|
||||
// HandoffExecution controls tool call execution.
|
||||
|
||||
@@ -4,7 +4,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"somegit.dev/vikingowl/mistral-go-sdk/chat"
|
||||
"github.com/VikingOwl91/mistral-go-sdk/chat"
|
||||
)
|
||||
|
||||
// Entry is a sealed interface for conversation history entries.
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
package conversation
|
||||
|
||||
import "encoding/json"
|
||||
import (
|
||||
"encoding/json"
|
||||
|
||||
"github.com/VikingOwl91/mistral-go-sdk/chat"
|
||||
)
|
||||
|
||||
// StartRequest starts a new conversation.
|
||||
type StartRequest struct {
|
||||
@@ -11,9 +15,10 @@ type StartRequest struct {
|
||||
Instructions *string `json:"instructions,omitempty"`
|
||||
Tools []Tool `json:"tools,omitempty"`
|
||||
CompletionArgs *CompletionArgs `json:"completion_args,omitempty"`
|
||||
Store *bool `json:"store,omitempty"`
|
||||
HandoffExecution *HandoffExecution `json:"handoff_execution,omitempty"`
|
||||
Name *string `json:"name,omitempty"`
|
||||
Store *bool `json:"store,omitempty"`
|
||||
HandoffExecution *HandoffExecution `json:"handoff_execution,omitempty"`
|
||||
Guardrails []chat.GuardrailConfig `json:"guardrails,omitempty"`
|
||||
Name *string `json:"name,omitempty"`
|
||||
Description *string `json:"description,omitempty"`
|
||||
Metadata map[string]any `json:"metadata,omitempty"`
|
||||
stream bool
|
||||
@@ -61,13 +66,14 @@ func (r *AppendRequest) MarshalJSON() ([]byte, error) {
|
||||
|
||||
// RestartRequest restarts a conversation from a specific entry.
|
||||
type RestartRequest struct {
|
||||
Inputs Inputs `json:"inputs"`
|
||||
FromEntryID string `json:"from_entry_id"`
|
||||
CompletionArgs *CompletionArgs `json:"completion_args,omitempty"`
|
||||
Store *bool `json:"store,omitempty"`
|
||||
HandoffExecution *HandoffExecution `json:"handoff_execution,omitempty"`
|
||||
AgentVersion json.RawMessage `json:"agent_version,omitempty"`
|
||||
Metadata map[string]any `json:"metadata,omitempty"`
|
||||
Inputs Inputs `json:"inputs"`
|
||||
FromEntryID string `json:"from_entry_id"`
|
||||
CompletionArgs *CompletionArgs `json:"completion_args,omitempty"`
|
||||
Store *bool `json:"store,omitempty"`
|
||||
HandoffExecution *HandoffExecution `json:"handoff_execution,omitempty"`
|
||||
Guardrails []chat.GuardrailConfig `json:"guardrails,omitempty"`
|
||||
AgentVersion json.RawMessage `json:"agent_version,omitempty"`
|
||||
Metadata map[string]any `json:"metadata,omitempty"`
|
||||
stream bool
|
||||
}
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"net/url"
|
||||
"strconv"
|
||||
|
||||
"somegit.dev/vikingowl/mistral-go-sdk/conversation"
|
||||
"github.com/VikingOwl91/mistral-go-sdk/conversation"
|
||||
)
|
||||
|
||||
// StartConversation creates and starts a new conversation.
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"somegit.dev/vikingowl/mistral-go-sdk/conversation"
|
||||
"github.com/VikingOwl91/mistral-go-sdk/conversation"
|
||||
)
|
||||
|
||||
func TestStartConversation_Success(t *testing.T) {
|
||||
|
||||
12
doc.go
12
doc.go
@@ -39,7 +39,13 @@
|
||||
// # Sub-packages
|
||||
//
|
||||
// Types are organized into sub-packages by domain: [chat], [agents],
|
||||
// [conversation], [embedding], [model], [file], [finetune], [batch],
|
||||
// [ocr], [audio], [library], [moderation], [classification], and [fim].
|
||||
// All service methods live directly on [Client].
|
||||
// [connector], [conversation], [embedding], [model], [file], [finetune],
|
||||
// [batch], [ocr], [audio], [library], [moderation], [classification],
|
||||
// [fim], and [observability]. All service methods live directly on [Client].
|
||||
//
|
||||
// # Reference
|
||||
//
|
||||
// This SDK tracks the official Mistral Python SDK
|
||||
// (https://github.com/mistralai/client-python) as its upstream reference
|
||||
// for API surface and type definitions.
|
||||
package mistral
|
||||
|
||||
9507
docs/openapi.yaml
9507
docs/openapi.yaml
File diff suppressed because it is too large
Load Diff
2966
docs/superpowers/plans/2026-04-02-workflows-api.md
Normal file
2966
docs/superpowers/plans/2026-04-02-workflows-api.md
Normal file
File diff suppressed because it is too large
Load Diff
650
docs/superpowers/specs/2026-04-02-workflows-api-design.md
Normal file
650
docs/superpowers/specs/2026-04-02-workflows-api-design.md
Normal file
@@ -0,0 +1,650 @@
|
||||
# Workflows API Integration — Design Spec
|
||||
|
||||
**Date:** 2026-04-02
|
||||
**Upstream:** Mistral Python SDK v2.2.0 (released 2026-03-31)
|
||||
**SDK version:** v1.2.0
|
||||
**Scope:** Full parity with Python SDK v2.2.0 changes since v2.1.3
|
||||
|
||||
## Summary
|
||||
|
||||
Add the Workflows API (37 new methods) and `DeleteBatchJob` (1 method) to the Go SDK.
|
||||
This is purely additive — no breaking changes to existing API surface.
|
||||
|
||||
## New Package: `workflow/`
|
||||
|
||||
Types-only package following the two-layer architecture. 8 type files + `doc.go`.
|
||||
|
||||
### `workflow/doc.go`
|
||||
|
||||
Package documentation.
|
||||
|
||||
### `workflow/workflow.go` — Core CRUD types
|
||||
|
||||
```go
|
||||
type Workflow struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
DisplayName *string `json:"display_name,omitempty"`
|
||||
Description *string `json:"description,omitempty"`
|
||||
OwnerID string `json:"owner_id"`
|
||||
WorkspaceID string `json:"workspace_id"`
|
||||
AvailableInChatAssistant bool `json:"available_in_chat_assistant"`
|
||||
Archived bool `json:"archived"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
UpdatedAt string `json:"updated_at"`
|
||||
}
|
||||
|
||||
type WorkflowUpdateRequest struct {
|
||||
DisplayName *string `json:"display_name,omitempty"`
|
||||
Description *string `json:"description,omitempty"`
|
||||
AvailableInChatAssistant *bool `json:"available_in_chat_assistant,omitempty"`
|
||||
}
|
||||
|
||||
type WorkflowListResponse struct {
|
||||
Workflows []Workflow `json:"workflows"`
|
||||
NextCursor *string `json:"next_cursor,omitempty"`
|
||||
}
|
||||
|
||||
type WorkflowListParams struct {
|
||||
ActiveOnly *bool
|
||||
IncludeShared *bool
|
||||
AvailableInChatAssistant *bool
|
||||
Archived *bool
|
||||
Cursor *string
|
||||
Limit *int
|
||||
}
|
||||
|
||||
type WorkflowArchiveResponse struct {
|
||||
ID string `json:"id"`
|
||||
Archived bool `json:"archived"`
|
||||
}
|
||||
```
|
||||
|
||||
### `workflow/execution.go` — Execution types
|
||||
|
||||
```go
|
||||
type ExecutionStatus string
|
||||
|
||||
const (
|
||||
ExecutionRunning ExecutionStatus = "RUNNING"
|
||||
ExecutionCompleted ExecutionStatus = "COMPLETED"
|
||||
ExecutionFailed ExecutionStatus = "FAILED"
|
||||
ExecutionCanceled ExecutionStatus = "CANCELED"
|
||||
ExecutionTerminated ExecutionStatus = "TERMINATED"
|
||||
ExecutionContinuedAsNew ExecutionStatus = "CONTINUED_AS_NEW"
|
||||
ExecutionTimedOut ExecutionStatus = "TIMED_OUT"
|
||||
ExecutionRetryingAfterErr ExecutionStatus = "RETRYING_AFTER_ERROR"
|
||||
)
|
||||
|
||||
type ExecutionRequest struct {
|
||||
ExecutionID *string `json:"execution_id,omitempty"`
|
||||
Input map[string]any `json:"input,omitempty"`
|
||||
EncodedInput *NetworkEncodedInput `json:"encoded_input,omitempty"`
|
||||
WaitForResult bool `json:"wait_for_result,omitempty"`
|
||||
TimeoutSeconds *float64 `json:"timeout_seconds,omitempty"`
|
||||
CustomTracingAttributes map[string]string `json:"custom_tracing_attributes,omitempty"`
|
||||
DeploymentName *string `json:"deployment_name,omitempty"`
|
||||
}
|
||||
|
||||
type ExecutionResponse struct {
|
||||
WorkflowName string `json:"workflow_name"`
|
||||
ExecutionID string `json:"execution_id"`
|
||||
RootExecutionID string `json:"root_execution_id"`
|
||||
Status ExecutionStatus `json:"status"`
|
||||
StartTime string `json:"start_time"`
|
||||
EndTime *string `json:"end_time,omitempty"`
|
||||
Result any `json:"result,omitempty"`
|
||||
ParentExecutionID *string `json:"parent_execution_id,omitempty"`
|
||||
TotalDurationMs *int `json:"total_duration_ms,omitempty"`
|
||||
}
|
||||
|
||||
type NetworkEncodedInput struct {
|
||||
B64Payload string `json:"b64payload"`
|
||||
EncodingOptions []string `json:"encoding_options,omitempty"`
|
||||
Empty bool `json:"empty,omitempty"`
|
||||
}
|
||||
|
||||
type SignalInvocationBody struct {
|
||||
Name string `json:"name"`
|
||||
Input any `json:"input"`
|
||||
}
|
||||
|
||||
type SignalResponse struct {
|
||||
Message string `json:"message"` // default: "Signal accepted"
|
||||
}
|
||||
|
||||
type QueryInvocationBody struct {
|
||||
Name string `json:"name"`
|
||||
Input any `json:"input,omitempty"`
|
||||
}
|
||||
|
||||
type QueryResponse struct {
|
||||
QueryName string `json:"query_name"`
|
||||
Result any `json:"result"`
|
||||
}
|
||||
|
||||
type UpdateInvocationBody struct {
|
||||
Name string `json:"name"`
|
||||
Input any `json:"input,omitempty"`
|
||||
}
|
||||
|
||||
type UpdateResponse struct {
|
||||
UpdateName string `json:"update_name"`
|
||||
Result any `json:"result"`
|
||||
}
|
||||
|
||||
// Trace response types
|
||||
|
||||
type TraceOTelResponse struct {
|
||||
WorkflowName string `json:"workflow_name"`
|
||||
ExecutionID string `json:"execution_id"`
|
||||
RootExecutionID string `json:"root_execution_id"`
|
||||
Status *ExecutionStatus `json:"status"`
|
||||
StartTime string `json:"start_time"`
|
||||
EndTime *string `json:"end_time,omitempty"`
|
||||
Result any `json:"result"`
|
||||
DataSource string `json:"data_source"`
|
||||
ParentExecutionID *string `json:"parent_execution_id,omitempty"`
|
||||
TotalDurationMs *int `json:"total_duration_ms,omitempty"`
|
||||
OTelTraceID *string `json:"otel_trace_id,omitempty"`
|
||||
OTelTraceData any `json:"otel_trace_data,omitempty"`
|
||||
}
|
||||
|
||||
type TraceSummaryResponse struct {
|
||||
WorkflowName string `json:"workflow_name"`
|
||||
ExecutionID string `json:"execution_id"`
|
||||
RootExecutionID string `json:"root_execution_id"`
|
||||
Status *ExecutionStatus `json:"status"`
|
||||
StartTime string `json:"start_time"`
|
||||
EndTime *string `json:"end_time,omitempty"`
|
||||
Result any `json:"result"`
|
||||
ParentExecutionID *string `json:"parent_execution_id,omitempty"`
|
||||
TotalDurationMs *int `json:"total_duration_ms,omitempty"`
|
||||
SpanTree any `json:"span_tree,omitempty"`
|
||||
}
|
||||
|
||||
type TraceEventsResponse struct {
|
||||
WorkflowName string `json:"workflow_name"`
|
||||
ExecutionID string `json:"execution_id"`
|
||||
RootExecutionID string `json:"root_execution_id"`
|
||||
Status *ExecutionStatus `json:"status"`
|
||||
StartTime string `json:"start_time"`
|
||||
EndTime *string `json:"end_time,omitempty"`
|
||||
Result any `json:"result"`
|
||||
ParentExecutionID *string `json:"parent_execution_id,omitempty"`
|
||||
TotalDurationMs *int `json:"total_duration_ms,omitempty"`
|
||||
Events []json.RawMessage `json:"events,omitempty"`
|
||||
}
|
||||
|
||||
type TraceEventsParams struct {
|
||||
MergeSameIDEvents *bool
|
||||
IncludeInternalEvents *bool
|
||||
}
|
||||
|
||||
type ResetInvocationBody struct {
|
||||
EventID int `json:"event_id"`
|
||||
Reason *string `json:"reason,omitempty"`
|
||||
ExcludeSignals bool `json:"exclude_signals,omitempty"`
|
||||
ExcludeUpdates bool `json:"exclude_updates,omitempty"`
|
||||
}
|
||||
|
||||
type BatchExecutionBody struct {
|
||||
ExecutionIDs []string `json:"execution_ids"`
|
||||
}
|
||||
|
||||
type BatchExecutionResponse struct {
|
||||
Results map[string]BatchExecutionResult `json:"results,omitempty"`
|
||||
}
|
||||
|
||||
type BatchExecutionResult struct {
|
||||
Status string `json:"status"`
|
||||
Error *string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
type StreamParams struct {
|
||||
EventSource *EventSource
|
||||
LastEventID *string
|
||||
}
|
||||
```
|
||||
|
||||
### `workflow/event.go` — Sealed event interface + 17 variants
|
||||
|
||||
Discriminator field: `event_type`
|
||||
|
||||
```go
|
||||
type Event interface {
|
||||
workflowEvent()
|
||||
EventType() EventType
|
||||
}
|
||||
|
||||
type EventType string
|
||||
|
||||
const (
|
||||
EventWorkflowStarted EventType = "WORKFLOW_EXECUTION_STARTED"
|
||||
EventWorkflowCompleted EventType = "WORKFLOW_EXECUTION_COMPLETED"
|
||||
EventWorkflowFailed EventType = "WORKFLOW_EXECUTION_FAILED"
|
||||
EventWorkflowCanceled EventType = "WORKFLOW_EXECUTION_CANCELED"
|
||||
EventWorkflowContinuedAsNew EventType = "WORKFLOW_EXECUTION_CONTINUED_AS_NEW"
|
||||
EventWorkflowTaskTimedOut EventType = "WORKFLOW_TASK_TIMED_OUT"
|
||||
EventWorkflowTaskFailed EventType = "WORKFLOW_TASK_FAILED"
|
||||
EventCustomTaskStarted EventType = "CUSTOM_TASK_STARTED"
|
||||
EventCustomTaskInProgress EventType = "CUSTOM_TASK_IN_PROGRESS"
|
||||
EventCustomTaskCompleted EventType = "CUSTOM_TASK_COMPLETED"
|
||||
EventCustomTaskFailed EventType = "CUSTOM_TASK_FAILED"
|
||||
EventCustomTaskTimedOut EventType = "CUSTOM_TASK_TIMED_OUT"
|
||||
EventCustomTaskCanceled EventType = "CUSTOM_TASK_CANCELED"
|
||||
EventActivityTaskStarted EventType = "ACTIVITY_TASK_STARTED"
|
||||
EventActivityTaskCompleted EventType = "ACTIVITY_TASK_COMPLETED"
|
||||
EventActivityTaskRetrying EventType = "ACTIVITY_TASK_RETRYING"
|
||||
EventActivityTaskFailed EventType = "ACTIVITY_TASK_FAILED"
|
||||
)
|
||||
|
||||
type EventSource string
|
||||
|
||||
const (
|
||||
EventSourceDatabase EventSource = "DATABASE"
|
||||
EventSourceLive EventSource = "LIVE"
|
||||
)
|
||||
|
||||
type Scope string
|
||||
|
||||
const (
|
||||
ScopeActivity Scope = "activity"
|
||||
ScopeWorkflow Scope = "workflow"
|
||||
ScopeAll Scope = "*"
|
||||
)
|
||||
```
|
||||
|
||||
Each concrete event type has common fields + type-specific attributes:
|
||||
|
||||
```go
|
||||
// Common fields embedded in all event types
|
||||
type eventBase struct {
|
||||
ID string `json:"event_id"`
|
||||
Timestamp int64 `json:"event_timestamp"`
|
||||
RootWorkflowExecID string `json:"root_workflow_exec_id"`
|
||||
ParentWorkflowExecID *string `json:"parent_workflow_exec_id"`
|
||||
WorkflowExecID string `json:"workflow_exec_id"`
|
||||
WorkflowRunID string `json:"workflow_run_id"`
|
||||
WorkflowName string `json:"workflow_name"`
|
||||
}
|
||||
|
||||
// Example concrete types:
|
||||
|
||||
type WorkflowExecutionStartedEvent struct {
|
||||
eventBase
|
||||
Attributes WorkflowStartedAttributes `json:"attributes"`
|
||||
}
|
||||
func (WorkflowExecutionStartedEvent) workflowEvent() {}
|
||||
func (WorkflowExecutionStartedEvent) EventType() EventType { return EventWorkflowStarted }
|
||||
|
||||
type WorkflowExecutionCompletedEvent struct {
|
||||
eventBase
|
||||
Attributes WorkflowCompletedAttributes `json:"attributes"`
|
||||
}
|
||||
// ... pattern repeats for all 17 types
|
||||
|
||||
type UnknownEvent struct {
|
||||
eventBase
|
||||
RawType string
|
||||
Raw json.RawMessage
|
||||
}
|
||||
```
|
||||
|
||||
SSE envelope types:
|
||||
|
||||
```go
|
||||
type StreamPayload struct {
|
||||
Stream string `json:"stream"`
|
||||
Data json.RawMessage `json:"data"`
|
||||
WorkflowContext StreamWorkflowContext `json:"workflow_context"`
|
||||
BrokerSequence int `json:"broker_sequence"`
|
||||
Timestamp *string `json:"timestamp,omitempty"`
|
||||
Metadata map[string]any `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
type StreamWorkflowContext struct {
|
||||
Namespace string `json:"namespace"`
|
||||
WorkflowName string `json:"workflow_name"`
|
||||
WorkflowExecID string `json:"workflow_exec_id"`
|
||||
ParentWorkflowExecID *string `json:"parent_workflow_exec_id,omitempty"`
|
||||
RootWorkflowExecID *string `json:"root_workflow_exec_id,omitempty"`
|
||||
}
|
||||
|
||||
func UnmarshalEvent(data json.RawMessage) (Event, error)
|
||||
// Probes event_type discriminator, dispatches to concrete type.
|
||||
// Unknown event_type returns UnknownEvent.
|
||||
```
|
||||
|
||||
### `workflow/deployment.go`
|
||||
|
||||
```go
|
||||
type Deployment struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
IsActive bool `json:"is_active"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
UpdatedAt string `json:"updated_at"`
|
||||
}
|
||||
|
||||
type DeploymentListResponse struct {
|
||||
Deployments []Deployment `json:"deployments"`
|
||||
}
|
||||
|
||||
type DeploymentListParams struct {
|
||||
ActiveOnly *bool
|
||||
WorkflowName *string
|
||||
}
|
||||
```
|
||||
|
||||
### `workflow/metrics.go`
|
||||
|
||||
```go
|
||||
type Metrics struct {
|
||||
ExecutionCount ScalarMetric `json:"execution_count"`
|
||||
SuccessCount ScalarMetric `json:"success_count"`
|
||||
ErrorCount ScalarMetric `json:"error_count"`
|
||||
AverageLatencyMs ScalarMetric `json:"average_latency_ms"`
|
||||
LatencyOverTime TimeSeriesMetric `json:"latency_over_time"`
|
||||
RetryRate ScalarMetric `json:"retry_rate"`
|
||||
}
|
||||
|
||||
type ScalarMetric struct {
|
||||
Value float64 `json:"value"`
|
||||
}
|
||||
|
||||
type TimeSeriesMetric struct {
|
||||
Value [][]float64 `json:"value"`
|
||||
}
|
||||
|
||||
type MetricsParams struct {
|
||||
StartTime *string
|
||||
EndTime *string
|
||||
}
|
||||
```
|
||||
|
||||
### `workflow/run.go`
|
||||
|
||||
```go
|
||||
type Run struct {
|
||||
ID string `json:"id"`
|
||||
WorkflowName string `json:"workflow_name"`
|
||||
ExecutionID string `json:"execution_id"`
|
||||
Status ExecutionStatus `json:"status"`
|
||||
StartTime string `json:"start_time"`
|
||||
EndTime *string `json:"end_time,omitempty"`
|
||||
}
|
||||
|
||||
type ListRunsResponse struct {
|
||||
Runs []Run `json:"runs"`
|
||||
NextPageToken *string `json:"next_page_token,omitempty"`
|
||||
}
|
||||
|
||||
type RunListParams struct {
|
||||
WorkflowIdentifier *string
|
||||
Search *string
|
||||
Status *string
|
||||
PageSize *int
|
||||
NextPageToken *string
|
||||
}
|
||||
```
|
||||
|
||||
### `workflow/schedule.go`
|
||||
|
||||
```go
|
||||
type ScheduleRequest struct {
|
||||
Schedule ScheduleDefinition `json:"schedule"`
|
||||
WorkflowRegistrationID *string `json:"workflow_registration_id,omitempty"`
|
||||
WorkflowIdentifier *string `json:"workflow_identifier,omitempty"`
|
||||
ScheduleID *string `json:"schedule_id,omitempty"`
|
||||
DeploymentName *string `json:"deployment_name,omitempty"`
|
||||
}
|
||||
|
||||
type ScheduleDefinition struct {
|
||||
Input any `json:"input"`
|
||||
Calendars []ScheduleCalendar `json:"calendars,omitempty"`
|
||||
Intervals []ScheduleInterval `json:"intervals,omitempty"`
|
||||
CronExpressions []string `json:"cron_expressions,omitempty"`
|
||||
Skip []ScheduleCalendar `json:"skip,omitempty"`
|
||||
StartAt *string `json:"start_at,omitempty"`
|
||||
EndAt *string `json:"end_at,omitempty"`
|
||||
Jitter *string `json:"jitter,omitempty"`
|
||||
TimeZoneName *string `json:"time_zone_name,omitempty"`
|
||||
Policy *SchedulePolicy `json:"policy,omitempty"`
|
||||
}
|
||||
|
||||
type ScheduleCalendar struct {
|
||||
Second []ScheduleRange `json:"second,omitempty"`
|
||||
Minute []ScheduleRange `json:"minute,omitempty"`
|
||||
Hour []ScheduleRange `json:"hour,omitempty"`
|
||||
DayOfMonth []ScheduleRange `json:"day_of_month,omitempty"`
|
||||
Month []ScheduleRange `json:"month,omitempty"`
|
||||
Year []ScheduleRange `json:"year,omitempty"`
|
||||
DayOfWeek []ScheduleRange `json:"day_of_week,omitempty"`
|
||||
Comment *string `json:"comment,omitempty"`
|
||||
}
|
||||
|
||||
type ScheduleRange struct {
|
||||
Start int `json:"start"`
|
||||
End int `json:"end,omitempty"`
|
||||
Step int `json:"step,omitempty"`
|
||||
}
|
||||
|
||||
type ScheduleInterval struct {
|
||||
Every string `json:"every"`
|
||||
Offset *string `json:"offset,omitempty"`
|
||||
}
|
||||
|
||||
type SchedulePolicy struct {
|
||||
CatchupWindowSeconds int `json:"catchup_window_seconds,omitempty"`
|
||||
Overlap *int `json:"overlap,omitempty"`
|
||||
PauseOnFailure bool `json:"pause_on_failure,omitempty"`
|
||||
}
|
||||
|
||||
type ScheduleResponse struct {
|
||||
ScheduleID string `json:"schedule_id"`
|
||||
}
|
||||
|
||||
type ScheduleListResponse struct {
|
||||
Schedules []Schedule `json:"schedules"`
|
||||
}
|
||||
|
||||
type Schedule struct {
|
||||
ScheduleID string `json:"schedule_id"`
|
||||
Definition ScheduleDefinition `json:"definition"`
|
||||
WorkflowName string `json:"workflow_name"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
UpdatedAt string `json:"updated_at"`
|
||||
}
|
||||
```
|
||||
|
||||
### `workflow/registration.go`
|
||||
|
||||
```go
|
||||
type Registration struct {
|
||||
ID string `json:"id"`
|
||||
WorkflowID string `json:"workflow_id"`
|
||||
TaskQueue string `json:"task_queue"`
|
||||
Workflow *Workflow `json:"workflow,omitempty"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
UpdatedAt string `json:"updated_at"`
|
||||
}
|
||||
|
||||
type RegistrationListResponse struct {
|
||||
Registrations []Registration `json:"registrations"`
|
||||
NextCursor *string `json:"next_cursor,omitempty"`
|
||||
}
|
||||
|
||||
type RegistrationListParams struct {
|
||||
WorkflowID *string
|
||||
TaskQueue *string
|
||||
ActiveOnly *bool
|
||||
IncludeShared *bool
|
||||
WorkflowSearch *string
|
||||
Archived *bool
|
||||
WithWorkflow *bool
|
||||
AvailableInChatAssistant *bool
|
||||
Limit *int
|
||||
Cursor *string
|
||||
}
|
||||
|
||||
type RegistrationGetParams struct {
|
||||
WithWorkflow *bool
|
||||
IncludeShared *bool
|
||||
}
|
||||
|
||||
type WorkerInfo struct {
|
||||
SchedulerURL string `json:"scheduler_url"`
|
||||
Namespace string `json:"namespace"`
|
||||
TLS bool `json:"tls"`
|
||||
}
|
||||
```
|
||||
|
||||
## Service Methods (root package)
|
||||
|
||||
### `workflows.go` — 10 methods
|
||||
|
||||
| Method | HTTP | Path |
|
||||
|--------|------|------|
|
||||
| `ListWorkflows` | GET | `/v1/workflows` |
|
||||
| `GetWorkflow` | GET | `/v1/workflows/{id}` |
|
||||
| `UpdateWorkflow` | PUT | `/v1/workflows/{id}` |
|
||||
| `ArchiveWorkflow` | PUT | `/v1/workflows/{id}/archive` |
|
||||
| `UnarchiveWorkflow` | PUT | `/v1/workflows/{id}/unarchive` |
|
||||
| `ExecuteWorkflow` | POST | `/v1/workflows/{id}/execute` |
|
||||
| `ListWorkflowRegistrations` | GET | `/v1/workflows/registrations` |
|
||||
| `GetWorkflowRegistration` | GET | `/v1/workflows/registrations/{id}` |
|
||||
| `ExecuteWorkflowRegistration` | POST | `/v1/workflows/registrations/{id}/execute` |
|
||||
| `ExecuteWorkflowAndWait` | (composite) | execute + poll |
|
||||
|
||||
`ExecuteWorkflowRegistration` is deprecated (doc comment only).
|
||||
|
||||
`ExecuteWorkflowAndWait` calls `ExecuteWorkflow`, then polls `GetWorkflowExecution`
|
||||
in a loop until status is terminal or context is canceled.
|
||||
|
||||
### `workflows_executions.go` — 14 methods
|
||||
|
||||
| Method | HTTP | Path |
|
||||
|--------|------|------|
|
||||
| `GetWorkflowExecution` | GET | `/v1/workflows/executions/{id}` |
|
||||
| `GetWorkflowExecutionHistory` | GET | `/v1/workflows/executions/{id}/history` |
|
||||
| `StreamWorkflowExecution` | GET (SSE) | `/v1/workflows/executions/{id}/stream` |
|
||||
| `SignalWorkflowExecution` | POST | `/v1/workflows/executions/{id}/signals` |
|
||||
| `QueryWorkflowExecution` | POST | `/v1/workflows/executions/{id}/queries` |
|
||||
| `UpdateWorkflowExecution` | POST | `/v1/workflows/executions/{id}/updates` |
|
||||
| `TerminateWorkflowExecution` | POST (204) | `/v1/workflows/executions/{id}/terminate` |
|
||||
| `CancelWorkflowExecution` | POST (204) | `/v1/workflows/executions/{id}/cancel` |
|
||||
| `ResetWorkflowExecution` | POST (204) | `/v1/workflows/executions/{id}/reset` |
|
||||
| `BatchCancelWorkflowExecutions` | POST | `/v1/workflows/executions/cancel` |
|
||||
| `BatchTerminateWorkflowExecutions` | POST | `/v1/workflows/executions/terminate` |
|
||||
| `GetWorkflowExecutionTraceOTel` | GET | `/v1/workflows/executions/{id}/trace/otel` |
|
||||
| `GetWorkflowExecutionTraceSummary` | GET | `/v1/workflows/executions/{id}/trace/summary` |
|
||||
| `GetWorkflowExecutionTraceEvents` | GET | `/v1/workflows/executions/{id}/trace/events` |
|
||||
|
||||
Also contains `WorkflowEventStream` type (wraps `Stream[json.RawMessage]`,
|
||||
dispatches via `workflow.UnmarshalEvent`).
|
||||
|
||||
### `workflows_events.go` — 2 methods
|
||||
|
||||
| Method | HTTP | Path |
|
||||
|--------|------|------|
|
||||
| `StreamWorkflowEvents` | GET (SSE) | `/v1/workflows/events/stream` |
|
||||
| `ListWorkflowEvents` | GET | `/v1/workflows/events/list` |
|
||||
|
||||
### `workflows_deployments.go` — 2 methods
|
||||
|
||||
| Method | HTTP | Path |
|
||||
|--------|------|------|
|
||||
| `ListWorkflowDeployments` | GET | `/v1/workflows/deployments` |
|
||||
| `GetWorkflowDeployment` | GET | `/v1/workflows/deployments/{id}` |
|
||||
|
||||
### `workflows_metrics.go` — 1 method
|
||||
|
||||
| Method | HTTP | Path |
|
||||
|--------|------|------|
|
||||
| `GetWorkflowMetrics` | GET | `/v1/workflows/{name}/metrics` |
|
||||
|
||||
### `workflows_runs.go` — 3 methods
|
||||
|
||||
| Method | HTTP | Path |
|
||||
|--------|------|------|
|
||||
| `ListWorkflowRuns` | GET | `/v1/workflows/runs` |
|
||||
| `GetWorkflowRun` | GET | `/v1/workflows/runs/{id}` |
|
||||
| `GetWorkflowRunHistory` | GET | `/v1/workflows/runs/{id}/history` |
|
||||
|
||||
### `workflows_schedules.go` — 3 methods
|
||||
|
||||
| Method | HTTP | Path |
|
||||
|--------|------|------|
|
||||
| `ListWorkflowSchedules` | GET | `/v1/workflows/schedules` |
|
||||
| `ScheduleWorkflow` | POST | `/v1/workflows/schedules` |
|
||||
| `UnscheduleWorkflow` | DELETE | `/v1/workflows/schedules/{id}` |
|
||||
|
||||
### `workflows_workers.go` — 1 method
|
||||
|
||||
| Method | HTTP | Path |
|
||||
|--------|------|------|
|
||||
| `GetWorkflowWorkerInfo` | GET | `/v1/workflows/workers/whoami` |
|
||||
|
||||
### `batch_api.go` — 1 new method
|
||||
|
||||
| Method | HTTP | Path |
|
||||
|--------|------|------|
|
||||
| `DeleteBatchJob` | DELETE | `/v1/batch/jobs/{id}` |
|
||||
|
||||
New type in `batch/`: `DeleteResponse { ID, Object, Deleted }`.
|
||||
|
||||
## Streaming Design
|
||||
|
||||
`WorkflowEventStream` wraps `Stream[json.RawMessage]` like `EventStream` does for conversations.
|
||||
|
||||
SSE data arrives as `StreamPayload` envelope:
|
||||
```json
|
||||
{
|
||||
"stream": "...",
|
||||
"data": { "event_type": "WORKFLOW_EXECUTION_COMPLETED", ... },
|
||||
"workflow_context": { ... },
|
||||
"broker_sequence": 42,
|
||||
"timestamp": "...",
|
||||
"metadata": {}
|
||||
}
|
||||
```
|
||||
|
||||
`WorkflowEventStream.Next()`:
|
||||
1. Read next SSE `data:` line via inner `Stream[json.RawMessage]`
|
||||
2. Unmarshal into `workflow.StreamPayload`
|
||||
3. Dispatch `payload.Data` via `workflow.UnmarshalEvent` (probes `event_type`)
|
||||
4. Expose both `Current() workflow.Event` and `CurrentPayload() *workflow.StreamPayload`
|
||||
|
||||
Both `StreamWorkflowExecution` and `StreamWorkflowEvents` use GET (not POST)
|
||||
with SSE response. They use `doStream` without a request body — the stream method
|
||||
needs to support GET + query params (verify `doStream` handles nil body for GET).
|
||||
|
||||
## Testing
|
||||
|
||||
One test file per service file. `httptest.NewServer` with inline handlers. Stdlib only.
|
||||
|
||||
Key scenarios:
|
||||
- Query param encoding for list/filter endpoints
|
||||
- PUT body marshaling for update/archive
|
||||
- 204 no-body responses for terminate/cancel/reset
|
||||
- 202 response for signal
|
||||
- SSE streaming with StreamPayload envelope + event type dispatch
|
||||
- UnknownEvent forward compatibility
|
||||
- ExecuteWorkflowAndWait polling loop (mock multiple get-execution responses)
|
||||
- Sealed interface UnmarshalEvent for all 17 event types + unknown
|
||||
- Batch operations with map response
|
||||
|
||||
## Version & Docs
|
||||
|
||||
- Bump version constant in `mistral.go` to `1.2.0`
|
||||
- Update `CLAUDE.md` sub-packages list to include `workflow/`
|
||||
- Update `CHANGELOG.md` with v1.2.0 entry
|
||||
- Upstream sync reference: Python SDK v2.2.0
|
||||
|
||||
## Non-Goals
|
||||
|
||||
- No pagination helpers (cursor chaining) — callers manage pagination manually, same as existing endpoints
|
||||
- No traceparent injection hook — Go callers manage their own tracing headers
|
||||
- No `execute_workflow_and_wait_async` — Go has context cancellation instead
|
||||
- No WebSocket/realtime workflow support (not in Python SDK either)
|
||||
@@ -1,6 +1,6 @@
|
||||
package embedding
|
||||
|
||||
import "somegit.dev/vikingowl/mistral-go-sdk/chat"
|
||||
import "github.com/VikingOwl91/mistral-go-sdk/chat"
|
||||
|
||||
// Dtype specifies the data type of output embeddings.
|
||||
type Dtype string
|
||||
|
||||
@@ -3,7 +3,7 @@ package mistral
|
||||
import (
|
||||
"context"
|
||||
|
||||
"somegit.dev/vikingowl/mistral-go-sdk/embedding"
|
||||
"github.com/VikingOwl91/mistral-go-sdk/embedding"
|
||||
)
|
||||
|
||||
// CreateEmbeddings sends an embedding request and returns the response.
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"somegit.dev/vikingowl/mistral-go-sdk/embedding"
|
||||
"github.com/VikingOwl91/mistral-go-sdk/embedding"
|
||||
)
|
||||
|
||||
func TestCreateEmbeddings_Success(t *testing.T) {
|
||||
|
||||
@@ -5,9 +5,9 @@ import (
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
mistral "somegit.dev/vikingowl/mistral-go-sdk"
|
||||
"somegit.dev/vikingowl/mistral-go-sdk/chat"
|
||||
"somegit.dev/vikingowl/mistral-go-sdk/embedding"
|
||||
mistral "github.com/VikingOwl91/mistral-go-sdk"
|
||||
"github.com/VikingOwl91/mistral-go-sdk/chat"
|
||||
"github.com/VikingOwl91/mistral-go-sdk/embedding"
|
||||
)
|
||||
|
||||
func ExampleNewClient() {
|
||||
|
||||
17
file/file.go
17
file/file.go
@@ -64,6 +64,23 @@ type SignedURL struct {
|
||||
URL string `json:"url"`
|
||||
}
|
||||
|
||||
// Visibility controls who can see a file.
|
||||
type Visibility string
|
||||
|
||||
const (
|
||||
VisibilitySharedGlobal Visibility = "shared_global"
|
||||
VisibilitySharedOrg Visibility = "shared_org"
|
||||
VisibilitySharedWorkspace Visibility = "shared_workspace"
|
||||
VisibilityPrivate Visibility = "private"
|
||||
)
|
||||
|
||||
// UploadParams holds parameters for uploading a file.
|
||||
type UploadParams struct {
|
||||
Purpose Purpose
|
||||
Expiry *int
|
||||
Visibility *Visibility
|
||||
}
|
||||
|
||||
// ListParams holds optional parameters for listing files.
|
||||
type ListParams struct {
|
||||
Page *int
|
||||
|
||||
16
files.go
16
files.go
@@ -8,14 +8,22 @@ import (
|
||||
"net/url"
|
||||
"strconv"
|
||||
|
||||
"somegit.dev/vikingowl/mistral-go-sdk/file"
|
||||
"github.com/VikingOwl91/mistral-go-sdk/file"
|
||||
)
|
||||
|
||||
// UploadFile uploads a file for use with fine-tuning, batch, or OCR.
|
||||
func (c *Client) UploadFile(ctx context.Context, filename string, r io.Reader, purpose file.Purpose) (*file.File, error) {
|
||||
func (c *Client) UploadFile(ctx context.Context, filename string, r io.Reader, params *file.UploadParams) (*file.File, error) {
|
||||
fields := map[string]string{}
|
||||
if purpose != "" {
|
||||
fields["purpose"] = string(purpose)
|
||||
if params != nil {
|
||||
if params.Purpose != "" {
|
||||
fields["purpose"] = string(params.Purpose)
|
||||
}
|
||||
if params.Expiry != nil {
|
||||
fields["expiry"] = strconv.Itoa(*params.Expiry)
|
||||
}
|
||||
if params.Visibility != nil {
|
||||
fields["visibility"] = string(*params.Visibility)
|
||||
}
|
||||
}
|
||||
var resp file.File
|
||||
if err := c.doMultipart(ctx, "/v1/files", filename, r, fields, &resp); err != nil {
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"somegit.dev/vikingowl/mistral-go-sdk/file"
|
||||
"github.com/VikingOwl91/mistral-go-sdk/file"
|
||||
)
|
||||
|
||||
func TestUploadFile_Success(t *testing.T) {
|
||||
@@ -58,7 +58,7 @@ func TestUploadFile_Success(t *testing.T) {
|
||||
context.Background(),
|
||||
"train.jsonl",
|
||||
strings.NewReader(`{"text":"hello"}`),
|
||||
file.PurposeFineTune,
|
||||
&file.UploadParams{Purpose: file.PurposeFineTune},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -74,6 +74,43 @@ func TestUploadFile_Success(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestUploadFile_WithExpiryAndVisibility(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := r.ParseMultipartForm(10 << 20); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if r.FormValue("purpose") != "fine-tune" {
|
||||
t.Errorf("got purpose %q", r.FormValue("purpose"))
|
||||
}
|
||||
if r.FormValue("expiry") != "48" {
|
||||
t.Errorf("expected expiry=48, got %q", r.FormValue("expiry"))
|
||||
}
|
||||
if r.FormValue("visibility") != "private" {
|
||||
t.Errorf("expected visibility=private, got %q", r.FormValue("visibility"))
|
||||
}
|
||||
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"id": "file-ev", "object": "file", "bytes": 10,
|
||||
"created_at": 1, "filename": "data.jsonl",
|
||||
"purpose": "fine-tune", "sample_type": "instruct",
|
||||
"source": "upload",
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
expiry := 48
|
||||
vis := file.VisibilityPrivate
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
_, err := client.UploadFile(context.Background(), "data.jsonl", strings.NewReader("{}"), &file.UploadParams{
|
||||
Purpose: file.PurposeFineTune,
|
||||
Expiry: &expiry,
|
||||
Visibility: &vis,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListFiles_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != "GET" {
|
||||
@@ -260,7 +297,7 @@ func TestUploadFile_Error(t *testing.T) {
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
_, err := client.UploadFile(context.Background(), "bad.txt", strings.NewReader(""), file.PurposeFineTune)
|
||||
_, err := client.UploadFile(context.Background(), "bad.txt", strings.NewReader(""), &file.UploadParams{Purpose: file.PurposeFineTune})
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
|
||||
@@ -3,8 +3,8 @@ package mistral
|
||||
import (
|
||||
"context"
|
||||
|
||||
"somegit.dev/vikingowl/mistral-go-sdk/chat"
|
||||
"somegit.dev/vikingowl/mistral-go-sdk/fim"
|
||||
"github.com/VikingOwl91/mistral-go-sdk/chat"
|
||||
"github.com/VikingOwl91/mistral-go-sdk/fim"
|
||||
)
|
||||
|
||||
// FIMComplete sends a Fill-In-the-Middle completion request.
|
||||
|
||||
@@ -8,8 +8,8 @@ import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"somegit.dev/vikingowl/mistral-go-sdk/chat"
|
||||
"somegit.dev/vikingowl/mistral-go-sdk/fim"
|
||||
"github.com/VikingOwl91/mistral-go-sdk/chat"
|
||||
"github.com/VikingOwl91/mistral-go-sdk/fim"
|
||||
)
|
||||
|
||||
func TestFIMComplete_Success(t *testing.T) {
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"net/url"
|
||||
"strconv"
|
||||
|
||||
"somegit.dev/vikingowl/mistral-go-sdk/finetune"
|
||||
"github.com/VikingOwl91/mistral-go-sdk/finetune"
|
||||
)
|
||||
|
||||
// CreateFineTuningJob creates a new fine-tuning job.
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"somegit.dev/vikingowl/mistral-go-sdk/finetune"
|
||||
"github.com/VikingOwl91/mistral-go-sdk/finetune"
|
||||
)
|
||||
|
||||
func TestCreateFineTuningJob_Success(t *testing.T) {
|
||||
|
||||
2
go.mod
2
go.mod
@@ -1,3 +1,3 @@
|
||||
module somegit.dev/vikingowl/mistral-go-sdk
|
||||
module github.com/VikingOwl91/mistral-go-sdk
|
||||
|
||||
go 1.26
|
||||
|
||||
@@ -8,8 +8,8 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"somegit.dev/vikingowl/mistral-go-sdk/chat"
|
||||
"somegit.dev/vikingowl/mistral-go-sdk/embedding"
|
||||
"github.com/VikingOwl91/mistral-go-sdk/chat"
|
||||
"github.com/VikingOwl91/mistral-go-sdk/embedding"
|
||||
)
|
||||
|
||||
func integrationClient(t *testing.T) *Client {
|
||||
@@ -24,7 +24,7 @@ func integrationClient(t *testing.T) *Client {
|
||||
func TestIntegration_ListModels(t *testing.T) {
|
||||
client := integrationClient(t)
|
||||
|
||||
resp, err := client.ListModels(context.Background())
|
||||
resp, err := client.ListModels(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"net/url"
|
||||
"strconv"
|
||||
|
||||
"somegit.dev/vikingowl/mistral-go-sdk/library"
|
||||
"github.com/VikingOwl91/mistral-go-sdk/library"
|
||||
)
|
||||
|
||||
// CreateLibrary creates a new document library.
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"somegit.dev/vikingowl/mistral-go-sdk/library"
|
||||
"github.com/VikingOwl91/mistral-go-sdk/library"
|
||||
)
|
||||
|
||||
func newLibraryJSON() map[string]any {
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
)
|
||||
|
||||
// Version is the SDK version string.
|
||||
const Version = "0.1.0"
|
||||
const Version = "1.2.1"
|
||||
|
||||
const (
|
||||
defaultBaseURL = "https://api.mistral.ai"
|
||||
|
||||
@@ -34,8 +34,10 @@ type ModelCapabilities struct {
|
||||
OCR bool `json:"ocr"`
|
||||
Classification bool `json:"classification"`
|
||||
Moderation bool `json:"moderation"`
|
||||
Audio bool `json:"audio"`
|
||||
AudioTranscription bool `json:"audio_transcription"`
|
||||
Audio bool `json:"audio"`
|
||||
AudioTranscription bool `json:"audio_transcription"`
|
||||
AudioTranscriptionRealtime bool `json:"audio_transcription_realtime"`
|
||||
AudioSpeech bool `json:"audio_speech"`
|
||||
}
|
||||
|
||||
// ModelList is the response from listing models.
|
||||
@@ -50,3 +52,9 @@ type DeleteModelOut struct {
|
||||
Object string `json:"object"`
|
||||
Deleted bool `json:"deleted"`
|
||||
}
|
||||
|
||||
// ListParams holds optional parameters for listing models.
|
||||
type ListParams struct {
|
||||
Provider *string
|
||||
Model *string
|
||||
}
|
||||
|
||||
20
models.go
20
models.go
@@ -2,14 +2,28 @@ package mistral
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/url"
|
||||
|
||||
"somegit.dev/vikingowl/mistral-go-sdk/model"
|
||||
"github.com/VikingOwl91/mistral-go-sdk/model"
|
||||
)
|
||||
|
||||
// ListModels returns a list of available models.
|
||||
func (c *Client) ListModels(ctx context.Context) (*model.ModelList, error) {
|
||||
func (c *Client) ListModels(ctx context.Context, params *model.ListParams) (*model.ModelList, error) {
|
||||
path := "/v1/models"
|
||||
if params != nil {
|
||||
q := url.Values{}
|
||||
if params.Provider != nil {
|
||||
q.Set("provider", *params.Provider)
|
||||
}
|
||||
if params.Model != nil {
|
||||
q.Set("model", *params.Model)
|
||||
}
|
||||
if encoded := q.Encode(); encoded != "" {
|
||||
path += "?" + encoded
|
||||
}
|
||||
}
|
||||
var resp model.ModelList
|
||||
if err := c.doJSON(ctx, "GET", "/v1/models", nil, &resp); err != nil {
|
||||
if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
|
||||
@@ -6,6 +6,8 @@ import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/VikingOwl91/mistral-go-sdk/model"
|
||||
)
|
||||
|
||||
func TestListModels_Success(t *testing.T) {
|
||||
@@ -45,7 +47,7 @@ func TestListModels_Success(t *testing.T) {
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
list, err := client.ListModels(context.Background())
|
||||
list, err := client.ListModels(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -82,6 +84,33 @@ func TestListModels_Success(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestListModels_WithParams(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Query().Get("provider") != "mistralai" {
|
||||
t.Errorf("expected provider=mistralai, got %q", r.URL.Query().Get("provider"))
|
||||
}
|
||||
if r.URL.Query().Get("model") != "mistral-small" {
|
||||
t.Errorf("expected model=mistral-small, got %q", r.URL.Query().Get("model"))
|
||||
}
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"object": "list",
|
||||
"data": []map[string]any{},
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
provider := "mistralai"
|
||||
modelName := "mistral-small"
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
_, err := client.ListModels(context.Background(), &model.ListParams{
|
||||
Provider: &provider,
|
||||
Model: &modelName,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetModel_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/v1/models/mistral-small-latest" {
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
package moderation
|
||||
|
||||
import "somegit.dev/vikingowl/mistral-go-sdk/chat"
|
||||
import "github.com/VikingOwl91/mistral-go-sdk/chat"
|
||||
|
||||
// Request represents a text moderation request (/v1/moderations).
|
||||
type Request struct {
|
||||
|
||||
@@ -3,8 +3,8 @@ package mistral
|
||||
import (
|
||||
"context"
|
||||
|
||||
"somegit.dev/vikingowl/mistral-go-sdk/classification"
|
||||
"somegit.dev/vikingowl/mistral-go-sdk/moderation"
|
||||
"github.com/VikingOwl91/mistral-go-sdk/classification"
|
||||
"github.com/VikingOwl91/mistral-go-sdk/moderation"
|
||||
)
|
||||
|
||||
// Moderate sends a text moderation request.
|
||||
|
||||
@@ -7,8 +7,8 @@ import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"somegit.dev/vikingowl/mistral-go-sdk/classification"
|
||||
"somegit.dev/vikingowl/mistral-go-sdk/moderation"
|
||||
"github.com/VikingOwl91/mistral-go-sdk/classification"
|
||||
"github.com/VikingOwl91/mistral-go-sdk/moderation"
|
||||
)
|
||||
|
||||
func TestModerate_Success(t *testing.T) {
|
||||
|
||||
46
observability/campaign.go
Normal file
46
observability/campaign.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package observability
|
||||
|
||||
// Campaign represents an observability campaign.
|
||||
type Campaign struct {
|
||||
ID string `json:"id"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
UpdatedAt string `json:"updated_at"`
|
||||
DeletedAt *string `json:"deleted_at,omitempty"`
|
||||
Name string `json:"name"`
|
||||
OwnerID string `json:"owner_id"`
|
||||
WorkspaceID string `json:"workspace_id"`
|
||||
Description string `json:"description"`
|
||||
MaxNbEvents int `json:"max_nb_events"`
|
||||
SearchParams FilterPayload `json:"search_params"`
|
||||
Judge Judge `json:"judge"`
|
||||
}
|
||||
|
||||
// CreateCampaignRequest creates a new campaign.
|
||||
type CreateCampaignRequest struct {
|
||||
SearchParams FilterPayload `json:"search_params"`
|
||||
JudgeID string `json:"judge_id"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
MaxNbEvents int `json:"max_nb_events"`
|
||||
}
|
||||
|
||||
// CampaignStatusResponse is the response for campaign status.
|
||||
type CampaignStatusResponse struct {
|
||||
Status TaskStatus `json:"status"`
|
||||
}
|
||||
|
||||
// ListCampaignsResponse is the response from listing campaigns.
|
||||
type ListCampaignsResponse struct {
|
||||
Count int `json:"count"`
|
||||
Results []Campaign `json:"results,omitempty"`
|
||||
Next *string `json:"next,omitempty"`
|
||||
Previous *string `json:"previous,omitempty"`
|
||||
}
|
||||
|
||||
// ListCampaignEventsResponse is the response from listing campaign events.
|
||||
type ListCampaignEventsResponse struct {
|
||||
Count int `json:"count"`
|
||||
Results []ChatCompletionEventPreview `json:"results,omitempty"`
|
||||
Next *string `json:"next,omitempty"`
|
||||
Previous *string `json:"previous,omitempty"`
|
||||
}
|
||||
156
observability/dataset.go
Normal file
156
observability/dataset.go
Normal file
@@ -0,0 +1,156 @@
|
||||
package observability
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
// ConversationSource indicates how a dataset record was created.
|
||||
type ConversationSource string
|
||||
|
||||
const (
|
||||
SourceExplorer ConversationSource = "EXPLORER"
|
||||
SourceUploadedFile ConversationSource = "UPLOADED_FILE"
|
||||
SourceDirectInput ConversationSource = "DIRECT_INPUT"
|
||||
SourcePlayground ConversationSource = "PLAYGROUND"
|
||||
)
|
||||
|
||||
// Dataset represents a dataset entity.
|
||||
type Dataset struct {
|
||||
ID string `json:"id"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
UpdatedAt string `json:"updated_at"`
|
||||
DeletedAt *string `json:"deleted_at,omitempty"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
OwnerID string `json:"owner_id"`
|
||||
WorkspaceID string `json:"workspace_id"`
|
||||
}
|
||||
|
||||
// CreateDatasetRequest creates a new dataset.
|
||||
type CreateDatasetRequest struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
}
|
||||
|
||||
// UpdateDatasetRequest updates a dataset.
|
||||
type UpdateDatasetRequest struct {
|
||||
Name *string `json:"name,omitempty"`
|
||||
Description *string `json:"description,omitempty"`
|
||||
}
|
||||
|
||||
// DatasetRecord is a single record in a dataset.
|
||||
type DatasetRecord struct {
|
||||
ID string `json:"id"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
UpdatedAt string `json:"updated_at"`
|
||||
DeletedAt *string `json:"deleted_at,omitempty"`
|
||||
DatasetID string `json:"dataset_id"`
|
||||
Payload ConversationPayload `json:"payload"`
|
||||
Properties map[string]any `json:"properties,omitempty"`
|
||||
Source ConversationSource `json:"source"`
|
||||
}
|
||||
|
||||
// ConversationPayload holds the messages for a dataset record.
|
||||
type ConversationPayload struct {
|
||||
Messages []map[string]any `json:"messages"`
|
||||
}
|
||||
|
||||
// CreateRecordRequest creates a new dataset record.
|
||||
type CreateRecordRequest struct {
|
||||
Payload ConversationPayload `json:"payload"`
|
||||
Properties map[string]any `json:"properties"`
|
||||
}
|
||||
|
||||
// UpdateRecordPayloadRequest updates a record's payload.
|
||||
type UpdateRecordPayloadRequest struct {
|
||||
Payload ConversationPayload `json:"payload"`
|
||||
}
|
||||
|
||||
// UpdateRecordPropertiesRequest updates a record's properties.
|
||||
type UpdateRecordPropertiesRequest struct {
|
||||
Properties map[string]any `json:"properties"`
|
||||
}
|
||||
|
||||
// BulkDeleteRecordsRequest deletes multiple records.
|
||||
type BulkDeleteRecordsRequest struct {
|
||||
DatasetRecordIDs []string `json:"dataset_record_ids"`
|
||||
}
|
||||
|
||||
// JudgeRecordRequest judges a dataset record.
|
||||
type JudgeRecordRequest struct {
|
||||
JudgeDefinition CreateJudgeRequest `json:"judge_definition"`
|
||||
}
|
||||
|
||||
// DatasetImportTask tracks an async import operation.
|
||||
type DatasetImportTask struct {
|
||||
ID string `json:"id"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
UpdatedAt string `json:"updated_at"`
|
||||
DeletedAt *string `json:"deleted_at,omitempty"`
|
||||
CreatorID string `json:"creator_id"`
|
||||
DatasetID string `json:"dataset_id"`
|
||||
WorkspaceID string `json:"workspace_id"`
|
||||
Status TaskStatus `json:"status"`
|
||||
Progress *int `json:"progress,omitempty"`
|
||||
Message *string `json:"message,omitempty"`
|
||||
}
|
||||
|
||||
// ExportDatasetResponse is the response from exporting a dataset.
|
||||
type ExportDatasetResponse struct {
|
||||
FileURL string `json:"file_url"`
|
||||
}
|
||||
|
||||
// Import request types.
|
||||
|
||||
// ImportFromCampaignRequest imports records from a campaign.
|
||||
type ImportFromCampaignRequest struct {
|
||||
CampaignID string `json:"campaign_id"`
|
||||
}
|
||||
|
||||
// ImportFromExplorerRequest imports records from explorer events.
|
||||
type ImportFromExplorerRequest struct {
|
||||
CompletionEventIDs []string `json:"completion_event_ids"`
|
||||
}
|
||||
|
||||
// ImportFromFileRequest imports records from a file.
|
||||
type ImportFromFileRequest struct {
|
||||
FileID string `json:"file_id"`
|
||||
}
|
||||
|
||||
// ImportFromPlaygroundRequest imports records from playground conversations.
|
||||
type ImportFromPlaygroundRequest struct {
|
||||
ConversationIDs []string `json:"conversation_ids"`
|
||||
}
|
||||
|
||||
// ImportFromDatasetRequest imports records from another dataset.
|
||||
type ImportFromDatasetRequest struct {
|
||||
DatasetRecordIDs []string `json:"dataset_record_ids"`
|
||||
}
|
||||
|
||||
// List response types.
|
||||
|
||||
// ListDatasetsResponse is the response from listing datasets.
|
||||
type ListDatasetsResponse struct {
|
||||
Count int `json:"count"`
|
||||
Results []Dataset `json:"results,omitempty"`
|
||||
Next *string `json:"next,omitempty"`
|
||||
Previous *string `json:"previous,omitempty"`
|
||||
}
|
||||
|
||||
// ListRecordsResponse is the response from listing dataset records.
|
||||
type ListRecordsResponse struct {
|
||||
Count int `json:"count"`
|
||||
Results []DatasetRecord `json:"results,omitempty"`
|
||||
Next *string `json:"next,omitempty"`
|
||||
Previous *string `json:"previous,omitempty"`
|
||||
}
|
||||
|
||||
// ListTasksResponse is the response from listing import tasks.
|
||||
type ListTasksResponse struct {
|
||||
Count int `json:"count"`
|
||||
Results []DatasetImportTask `json:"results,omitempty"`
|
||||
Next *string `json:"next,omitempty"`
|
||||
Previous *string `json:"previous,omitempty"`
|
||||
}
|
||||
|
||||
// JudgeResultResponse is the raw response from judging operations.
|
||||
// The shape depends on the judge type (classification or regression).
|
||||
type JudgeResultResponse json.RawMessage
|
||||
5
observability/doc.go
Normal file
5
observability/doc.go
Normal file
@@ -0,0 +1,5 @@
|
||||
// Package observability provides types for the Mistral observability API (beta).
|
||||
//
|
||||
// This includes campaigns, chat completion events, judges, and datasets
|
||||
// for monitoring and evaluating model behavior.
|
||||
package observability
|
||||
70
observability/event.go
Normal file
70
observability/event.go
Normal file
@@ -0,0 +1,70 @@
|
||||
package observability
|
||||
|
||||
// ChatCompletionEvent is a full chat completion event.
|
||||
type ChatCompletionEvent struct {
|
||||
EventID string `json:"event_id"`
|
||||
CorrelationID string `json:"correlation_id"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
ExtraFields map[string]any `json:"extra_fields,omitempty"`
|
||||
NbInputTokens int `json:"nb_input_tokens"`
|
||||
NbOutputTokens int `json:"nb_output_tokens"`
|
||||
EnabledTools []map[string]any `json:"enabled_tools,omitempty"`
|
||||
RequestMessages []map[string]any `json:"request_messages,omitempty"`
|
||||
ResponseMessages []map[string]any `json:"response_messages,omitempty"`
|
||||
NbMessages int `json:"nb_messages"`
|
||||
ChatTranscriptionEvents []ChatTranscriptionEvent `json:"chat_transcription_events,omitempty"`
|
||||
}
|
||||
|
||||
// ChatCompletionEventPreview is a summary of a chat completion event.
|
||||
type ChatCompletionEventPreview struct {
|
||||
EventID string `json:"event_id"`
|
||||
CorrelationID string `json:"correlation_id"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
ExtraFields map[string]any `json:"extra_fields,omitempty"`
|
||||
NbInputTokens int `json:"nb_input_tokens"`
|
||||
NbOutputTokens int `json:"nb_output_tokens"`
|
||||
}
|
||||
|
||||
// ChatTranscriptionEvent is an audio transcription within a chat event.
|
||||
type ChatTranscriptionEvent struct {
|
||||
AudioURL string `json:"audio_url"`
|
||||
Model string `json:"model"`
|
||||
ResponseMessage map[string]any `json:"response_message"`
|
||||
}
|
||||
|
||||
// SearchEventsRequest is the request body for searching chat completion events.
|
||||
type SearchEventsRequest struct {
|
||||
SearchParams FilterPayload `json:"search_params"`
|
||||
ExtraFields []string `json:"extra_fields,omitempty"`
|
||||
}
|
||||
|
||||
// SearchEventsResponse is the response from searching events.
|
||||
type SearchEventsResponse struct {
|
||||
Results []ChatCompletionEventPreview `json:"results,omitempty"`
|
||||
Next *string `json:"next,omitempty"`
|
||||
Cursor *string `json:"cursor,omitempty"`
|
||||
}
|
||||
|
||||
// SearchEventIDsRequest is the request body for searching event IDs.
|
||||
type SearchEventIDsRequest struct {
|
||||
SearchParams FilterPayload `json:"search_params"`
|
||||
ExtraFields []string `json:"extra_fields,omitempty"`
|
||||
}
|
||||
|
||||
// SearchEventIDsResponse is the response from searching event IDs.
|
||||
type SearchEventIDsResponse struct {
|
||||
CompletionEventIDs []string `json:"completion_event_ids"`
|
||||
}
|
||||
|
||||
// JudgeEventRequest is the request body for judging a chat completion event.
|
||||
type JudgeEventRequest struct {
|
||||
JudgeDefinition CreateJudgeRequest `json:"judge_definition"`
|
||||
}
|
||||
|
||||
// SimilarEventsResponse is the response from fetching similar events.
|
||||
type SimilarEventsResponse struct {
|
||||
Count int `json:"count"`
|
||||
Results []ChatCompletionEventPreview `json:"results,omitempty"`
|
||||
Next *string `json:"next,omitempty"`
|
||||
Previous *string `json:"previous,omitempty"`
|
||||
}
|
||||
75
observability/filter.go
Normal file
75
observability/filter.go
Normal file
@@ -0,0 +1,75 @@
|
||||
package observability
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
// Op is a filter comparison operator.
|
||||
type Op string
|
||||
|
||||
const (
|
||||
OpLt Op = "lt"
|
||||
OpLte Op = "lte"
|
||||
OpGt Op = "gt"
|
||||
OpGte Op = "gte"
|
||||
OpEq Op = "eq"
|
||||
OpNeq Op = "neq"
|
||||
OpIsNull Op = "isnull"
|
||||
OpStartsWith Op = "startswith"
|
||||
OpIStartsWith Op = "istartswith"
|
||||
OpEndsWith Op = "endswith"
|
||||
OpIEndsWith Op = "iendswith"
|
||||
OpContains Op = "contains"
|
||||
OpIContains Op = "icontains"
|
||||
OpMatches Op = "matches"
|
||||
OpNotContains Op = "notcontains"
|
||||
OpINotContains Op = "inotcontains"
|
||||
OpIncludes Op = "includes"
|
||||
OpExcludes Op = "excludes"
|
||||
OpLenEq Op = "len_eq"
|
||||
)
|
||||
|
||||
// FilterCondition is a single filter comparison.
|
||||
type FilterCondition struct {
|
||||
Field string `json:"field"`
|
||||
Op Op `json:"op"`
|
||||
Value any `json:"value"`
|
||||
}
|
||||
|
||||
// FilterGroup combines filters with AND/OR logic.
|
||||
// The JSON keys are uppercase "AND" / "OR".
|
||||
type FilterGroup struct {
|
||||
AND []json.RawMessage `json:"AND,omitempty"`
|
||||
OR []json.RawMessage `json:"OR,omitempty"`
|
||||
}
|
||||
|
||||
// FilterPayload wraps the top-level filter for search operations.
|
||||
// Filters can be a FilterGroup or a FilterCondition.
|
||||
type FilterPayload struct {
|
||||
Filters json.RawMessage `json:"filters,omitempty"`
|
||||
}
|
||||
|
||||
// TaskStatus is the status of an async task.
|
||||
type TaskStatus string
|
||||
|
||||
const (
|
||||
TaskStatusRunning TaskStatus = "RUNNING"
|
||||
TaskStatusCompleted TaskStatus = "COMPLETED"
|
||||
TaskStatusFailed TaskStatus = "FAILED"
|
||||
TaskStatusCanceled TaskStatus = "CANCELED"
|
||||
TaskStatusTerminated TaskStatus = "TERMINATED"
|
||||
TaskStatusContinuedAsNew TaskStatus = "CONTINUED_AS_NEW"
|
||||
TaskStatusTimedOut TaskStatus = "TIMED_OUT"
|
||||
TaskStatusUnknown TaskStatus = "UNKNOWN"
|
||||
)
|
||||
|
||||
// PaginationParams holds common pagination query parameters.
|
||||
type PaginationParams struct {
|
||||
Page *int
|
||||
PageSize *int
|
||||
}
|
||||
|
||||
// SearchParams holds common search query parameters.
|
||||
type SearchParams struct {
|
||||
Page *int
|
||||
PageSize *int
|
||||
Q *string
|
||||
}
|
||||
114
observability/judge.go
Normal file
114
observability/judge.go
Normal file
@@ -0,0 +1,114 @@
|
||||
package observability
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// JudgeOutputType identifies the kind of judge output.
|
||||
type JudgeOutputType string
|
||||
|
||||
const (
|
||||
JudgeOutputClassification JudgeOutputType = "CLASSIFICATION"
|
||||
JudgeOutputRegression JudgeOutputType = "REGRESSION"
|
||||
)
|
||||
|
||||
// JudgeOutput is a sealed interface for judge output configurations.
|
||||
type JudgeOutput interface {
|
||||
judgeOutputType() JudgeOutputType
|
||||
}
|
||||
|
||||
// ClassificationOutput configures a classification judge.
|
||||
type ClassificationOutput struct {
|
||||
Type JudgeOutputType `json:"type"`
|
||||
Options []ClassificationOption `json:"options"`
|
||||
}
|
||||
|
||||
func (*ClassificationOutput) judgeOutputType() JudgeOutputType { return JudgeOutputClassification }
|
||||
|
||||
// ClassificationOption is a single option for a classification judge.
|
||||
type ClassificationOption struct {
|
||||
Value string `json:"value"`
|
||||
Description string `json:"description"`
|
||||
}
|
||||
|
||||
// RegressionOutput configures a regression judge.
|
||||
type RegressionOutput struct {
|
||||
Type JudgeOutputType `json:"type"`
|
||||
MinDescription string `json:"min_description"`
|
||||
MaxDescription string `json:"max_description"`
|
||||
Min *float64 `json:"min,omitempty"`
|
||||
Max *float64 `json:"max,omitempty"`
|
||||
}
|
||||
|
||||
func (*RegressionOutput) judgeOutputType() JudgeOutputType { return JudgeOutputRegression }
|
||||
|
||||
// UnmarshalJudgeOutput dispatches to the concrete JudgeOutput type.
|
||||
func UnmarshalJudgeOutput(data []byte) (JudgeOutput, error) {
|
||||
var probe struct {
|
||||
Type string `json:"type"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &probe); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal judge output: %w", err)
|
||||
}
|
||||
switch JudgeOutputType(probe.Type) {
|
||||
case JudgeOutputClassification:
|
||||
var o ClassificationOutput
|
||||
return &o, json.Unmarshal(data, &o)
|
||||
case JudgeOutputRegression:
|
||||
var o RegressionOutput
|
||||
return &o, json.Unmarshal(data, &o)
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown judge output type: %q", probe.Type)
|
||||
}
|
||||
}
|
||||
|
||||
// Judge represents a judge entity.
|
||||
type Judge struct {
|
||||
ID string `json:"id"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
UpdatedAt string `json:"updated_at"`
|
||||
DeletedAt *string `json:"deleted_at,omitempty"`
|
||||
OwnerID string `json:"owner_id"`
|
||||
WorkspaceID string `json:"workspace_id"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
ModelName string `json:"model_name"`
|
||||
Output json.RawMessage `json:"output"`
|
||||
Instructions string `json:"instructions"`
|
||||
Tools []string `json:"tools,omitempty"`
|
||||
}
|
||||
|
||||
// CreateJudgeRequest creates a new judge.
|
||||
type CreateJudgeRequest struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
ModelName string `json:"model_name"`
|
||||
Output json.RawMessage `json:"output"`
|
||||
Instructions string `json:"instructions"`
|
||||
Tools []string `json:"tools"`
|
||||
}
|
||||
|
||||
// UpdateJudgeRequest updates a judge.
|
||||
type UpdateJudgeRequest struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
ModelName string `json:"model_name"`
|
||||
Output json.RawMessage `json:"output"`
|
||||
Instructions string `json:"instructions"`
|
||||
Tools []string `json:"tools"`
|
||||
}
|
||||
|
||||
// JudgeConversationRequest is the request for live-judging a conversation.
|
||||
type JudgeConversationRequest struct {
|
||||
Messages []map[string]any `json:"messages"`
|
||||
Properties map[string]any `json:"properties,omitempty"`
|
||||
}
|
||||
|
||||
// ListJudgesResponse is the response from listing judges.
|
||||
type ListJudgesResponse struct {
|
||||
Count int `json:"count"`
|
||||
Results []Judge `json:"results,omitempty"`
|
||||
Next *string `json:"next,omitempty"`
|
||||
Previous *string `json:"previous,omitempty"`
|
||||
}
|
||||
97
observability_campaigns.go
Normal file
97
observability_campaigns.go
Normal file
@@ -0,0 +1,97 @@
|
||||
package mistral
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strconv"
|
||||
|
||||
"github.com/VikingOwl91/mistral-go-sdk/observability"
|
||||
)
|
||||
|
||||
// CreateCampaign creates a new observability campaign.
|
||||
func (c *Client) CreateCampaign(ctx context.Context, req *observability.CreateCampaignRequest) (*observability.Campaign, error) {
|
||||
var resp observability.Campaign
|
||||
if err := c.doJSON(ctx, "POST", "/v1/observability/campaigns", req, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// ListCampaigns lists observability campaigns.
|
||||
func (c *Client) ListCampaigns(ctx context.Context, params *observability.SearchParams) (*observability.ListCampaignsResponse, error) {
|
||||
path := "/v1/observability/campaigns"
|
||||
if params != nil {
|
||||
q := url.Values{}
|
||||
if params.PageSize != nil {
|
||||
q.Set("page_size", strconv.Itoa(*params.PageSize))
|
||||
}
|
||||
if params.Page != nil {
|
||||
q.Set("page", strconv.Itoa(*params.Page))
|
||||
}
|
||||
if params.Q != nil {
|
||||
q.Set("q", *params.Q)
|
||||
}
|
||||
if encoded := q.Encode(); encoded != "" {
|
||||
path += "?" + encoded
|
||||
}
|
||||
}
|
||||
var resp observability.ListCampaignsResponse
|
||||
if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// GetCampaign retrieves a campaign by ID.
|
||||
func (c *Client) GetCampaign(ctx context.Context, campaignID string) (*observability.Campaign, error) {
|
||||
var resp observability.Campaign
|
||||
if err := c.doJSON(ctx, "GET", fmt.Sprintf("/v1/observability/campaigns/%s", campaignID), nil, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// DeleteCampaign deletes a campaign.
|
||||
func (c *Client) DeleteCampaign(ctx context.Context, campaignID string) error {
|
||||
resp, err := c.do(ctx, "DELETE", fmt.Sprintf("/v1/observability/campaigns/%s", campaignID), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode >= 400 {
|
||||
return parseAPIError(resp)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetCampaignStatus retrieves the status of a campaign.
|
||||
func (c *Client) GetCampaignStatus(ctx context.Context, campaignID string) (*observability.CampaignStatusResponse, error) {
|
||||
var resp observability.CampaignStatusResponse
|
||||
if err := c.doJSON(ctx, "GET", fmt.Sprintf("/v1/observability/campaigns/%s/status", campaignID), nil, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// ListCampaignEvents lists events selected by a campaign.
|
||||
func (c *Client) ListCampaignEvents(ctx context.Context, campaignID string, params *observability.PaginationParams) (*observability.ListCampaignEventsResponse, error) {
|
||||
path := fmt.Sprintf("/v1/observability/campaigns/%s/selected-events", campaignID)
|
||||
if params != nil {
|
||||
q := url.Values{}
|
||||
if params.PageSize != nil {
|
||||
q.Set("page_size", strconv.Itoa(*params.PageSize))
|
||||
}
|
||||
if params.Page != nil {
|
||||
q.Set("page", strconv.Itoa(*params.Page))
|
||||
}
|
||||
if encoded := q.Encode(); encoded != "" {
|
||||
path += "?" + encoded
|
||||
}
|
||||
}
|
||||
var resp observability.ListCampaignEventsResponse
|
||||
if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
148
observability_campaigns_test.go
Normal file
148
observability_campaigns_test.go
Normal file
@@ -0,0 +1,148 @@
|
||||
package mistral
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/VikingOwl91/mistral-go-sdk/observability"
|
||||
)
|
||||
|
||||
func TestCreateCampaign_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != "POST" || r.URL.Path != "/v1/observability/campaigns" {
|
||||
t.Errorf("unexpected %s %s", r.Method, r.URL.Path)
|
||||
}
|
||||
var body map[string]any
|
||||
json.NewDecoder(r.Body).Decode(&body)
|
||||
if body["name"] != "test-campaign" {
|
||||
t.Errorf("got name %v", body["name"])
|
||||
}
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"id": "camp-1", "name": "test-campaign", "description": "d",
|
||||
"created_at": "t", "updated_at": "t", "owner_id": "o",
|
||||
"workspace_id": "w", "max_nb_events": 100,
|
||||
"search_params": map[string]any{}, "judge": map[string]any{"id": "j1"},
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
resp, err := client.CreateCampaign(context.Background(), &observability.CreateCampaignRequest{
|
||||
Name: "test-campaign",
|
||||
Description: "d",
|
||||
JudgeID: "j1",
|
||||
MaxNbEvents: 100,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp.ID != "camp-1" {
|
||||
t.Errorf("got id %q", resp.ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListCampaigns_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/v1/observability/campaigns" {
|
||||
t.Errorf("got path %s", r.URL.Path)
|
||||
}
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"count": 1,
|
||||
"results": []map[string]any{{"id": "c1", "name": "c", "description": "d", "created_at": "t", "updated_at": "t", "owner_id": "o", "workspace_id": "w", "max_nb_events": 10, "search_params": map[string]any{}, "judge": map[string]any{"id": "j"}}},
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
resp, err := client.ListCampaigns(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp.Count != 1 {
|
||||
t.Errorf("got count %d", resp.Count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetCampaign_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/v1/observability/campaigns/camp-1" {
|
||||
t.Errorf("got path %s", r.URL.Path)
|
||||
}
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"id": "camp-1", "name": "c", "description": "d",
|
||||
"created_at": "t", "updated_at": "t", "owner_id": "o",
|
||||
"workspace_id": "w", "max_nb_events": 10,
|
||||
"search_params": map[string]any{}, "judge": map[string]any{"id": "j"},
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
resp, err := client.GetCampaign(context.Background(), "camp-1")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp.ID != "camp-1" {
|
||||
t.Errorf("got id %q", resp.ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteCampaign_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != "DELETE" {
|
||||
t.Errorf("expected DELETE")
|
||||
}
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
if err := client.DeleteCampaign(context.Background(), "camp-1"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetCampaignStatus_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/v1/observability/campaigns/camp-1/status" {
|
||||
t.Errorf("got path %s", r.URL.Path)
|
||||
}
|
||||
json.NewEncoder(w).Encode(map[string]any{"status": "COMPLETED"})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
resp, err := client.GetCampaignStatus(context.Background(), "camp-1")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp.Status != observability.TaskStatusCompleted {
|
||||
t.Errorf("got status %q", resp.Status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListCampaignEvents_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/v1/observability/campaigns/camp-1/selected-events" {
|
||||
t.Errorf("got path %s", r.URL.Path)
|
||||
}
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"count": 0,
|
||||
"results": []any{},
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
resp, err := client.ListCampaignEvents(context.Background(), "camp-1", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp.Count != 0 {
|
||||
t.Errorf("got count %d", resp.Count)
|
||||
}
|
||||
}
|
||||
252
observability_datasets.go
Normal file
252
observability_datasets.go
Normal file
@@ -0,0 +1,252 @@
|
||||
package mistral
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strconv"
|
||||
|
||||
"github.com/VikingOwl91/mistral-go-sdk/observability"
|
||||
)
|
||||
|
||||
// CreateDataset creates a new observability dataset.
|
||||
func (c *Client) CreateDataset(ctx context.Context, req *observability.CreateDatasetRequest) (*observability.Dataset, error) {
|
||||
var resp observability.Dataset
|
||||
if err := c.doJSON(ctx, "POST", "/v1/observability/datasets", req, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// ListDatasets lists observability datasets.
|
||||
func (c *Client) ListDatasets(ctx context.Context, params *observability.SearchParams) (*observability.ListDatasetsResponse, error) {
|
||||
path := "/v1/observability/datasets"
|
||||
if params != nil {
|
||||
q := url.Values{}
|
||||
if params.PageSize != nil {
|
||||
q.Set("page_size", strconv.Itoa(*params.PageSize))
|
||||
}
|
||||
if params.Page != nil {
|
||||
q.Set("page", strconv.Itoa(*params.Page))
|
||||
}
|
||||
if params.Q != nil {
|
||||
q.Set("q", *params.Q)
|
||||
}
|
||||
if encoded := q.Encode(); encoded != "" {
|
||||
path += "?" + encoded
|
||||
}
|
||||
}
|
||||
var resp observability.ListDatasetsResponse
|
||||
if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// GetDataset retrieves a dataset by ID.
|
||||
func (c *Client) GetDataset(ctx context.Context, datasetID string) (*observability.Dataset, error) {
|
||||
var resp observability.Dataset
|
||||
if err := c.doJSON(ctx, "GET", fmt.Sprintf("/v1/observability/datasets/%s", datasetID), nil, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// UpdateDataset updates a dataset.
|
||||
func (c *Client) UpdateDataset(ctx context.Context, datasetID string, req *observability.UpdateDatasetRequest) (*observability.Dataset, error) {
|
||||
var resp observability.Dataset
|
||||
if err := c.doJSON(ctx, "PATCH", fmt.Sprintf("/v1/observability/datasets/%s", datasetID), req, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// DeleteDataset deletes a dataset.
|
||||
func (c *Client) DeleteDataset(ctx context.Context, datasetID string) error {
|
||||
resp, err := c.do(ctx, "DELETE", fmt.Sprintf("/v1/observability/datasets/%s", datasetID), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode >= 400 {
|
||||
return parseAPIError(resp)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ExportDatasetToJSONL exports a dataset to JSONL format.
|
||||
func (c *Client) ExportDatasetToJSONL(ctx context.Context, datasetID string) (*observability.ExportDatasetResponse, error) {
|
||||
var resp observability.ExportDatasetResponse
|
||||
if err := c.doJSON(ctx, "GET", fmt.Sprintf("/v1/observability/datasets/%s/exports/to-jsonl", datasetID), nil, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// Dataset records
|
||||
|
||||
// ListDatasetRecords lists records in a dataset.
|
||||
func (c *Client) ListDatasetRecords(ctx context.Context, datasetID string, params *observability.PaginationParams) (*observability.ListRecordsResponse, error) {
|
||||
path := fmt.Sprintf("/v1/observability/datasets/%s/records", datasetID)
|
||||
if params != nil {
|
||||
q := url.Values{}
|
||||
if params.PageSize != nil {
|
||||
q.Set("page_size", strconv.Itoa(*params.PageSize))
|
||||
}
|
||||
if params.Page != nil {
|
||||
q.Set("page", strconv.Itoa(*params.Page))
|
||||
}
|
||||
if encoded := q.Encode(); encoded != "" {
|
||||
path += "?" + encoded
|
||||
}
|
||||
}
|
||||
var resp observability.ListRecordsResponse
|
||||
if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// CreateDatasetRecord creates a record in a dataset.
|
||||
func (c *Client) CreateDatasetRecord(ctx context.Context, datasetID string, req *observability.CreateRecordRequest) (*observability.DatasetRecord, error) {
|
||||
var resp observability.DatasetRecord
|
||||
if err := c.doJSON(ctx, "POST", fmt.Sprintf("/v1/observability/datasets/%s/records", datasetID), req, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// GetDatasetRecord retrieves a dataset record by ID.
|
||||
func (c *Client) GetDatasetRecord(ctx context.Context, recordID string) (*observability.DatasetRecord, error) {
|
||||
var resp observability.DatasetRecord
|
||||
if err := c.doJSON(ctx, "GET", fmt.Sprintf("/v1/observability/dataset-records/%s", recordID), nil, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// UpdateDatasetRecordPayload updates a record's payload.
|
||||
func (c *Client) UpdateDatasetRecordPayload(ctx context.Context, recordID string, req *observability.UpdateRecordPayloadRequest) (*observability.DatasetRecord, error) {
|
||||
var resp observability.DatasetRecord
|
||||
if err := c.doJSON(ctx, "PUT", fmt.Sprintf("/v1/observability/dataset-records/%s/payload", recordID), req, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// UpdateDatasetRecordProperties updates a record's properties.
|
||||
func (c *Client) UpdateDatasetRecordProperties(ctx context.Context, recordID string, req *observability.UpdateRecordPropertiesRequest) (*observability.DatasetRecord, error) {
|
||||
var resp observability.DatasetRecord
|
||||
if err := c.doJSON(ctx, "PUT", fmt.Sprintf("/v1/observability/dataset-records/%s/properties", recordID), req, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// DeleteDatasetRecord deletes a dataset record.
|
||||
func (c *Client) DeleteDatasetRecord(ctx context.Context, recordID string) error {
|
||||
resp, err := c.do(ctx, "DELETE", fmt.Sprintf("/v1/observability/dataset-records/%s", recordID), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode >= 400 {
|
||||
return parseAPIError(resp)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// BulkDeleteDatasetRecords deletes multiple dataset records.
|
||||
func (c *Client) BulkDeleteDatasetRecords(ctx context.Context, req *observability.BulkDeleteRecordsRequest) error {
|
||||
return c.doJSON(ctx, "POST", "/v1/observability/dataset-records/bulk-delete", req, nil)
|
||||
}
|
||||
|
||||
// JudgeDatasetRecord judges a dataset record.
|
||||
func (c *Client) JudgeDatasetRecord(ctx context.Context, recordID string, req *observability.JudgeRecordRequest) (json.RawMessage, error) {
|
||||
var resp json.RawMessage
|
||||
if err := c.doJSON(ctx, "POST", fmt.Sprintf("/v1/observability/dataset-records/%s/live-judging", recordID), req, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// Import operations
|
||||
|
||||
// ImportDatasetFromCampaign imports records from a campaign.
|
||||
func (c *Client) ImportDatasetFromCampaign(ctx context.Context, datasetID string, req *observability.ImportFromCampaignRequest) (*observability.DatasetImportTask, error) {
|
||||
var resp observability.DatasetImportTask
|
||||
if err := c.doJSON(ctx, "POST", fmt.Sprintf("/v1/observability/datasets/%s/imports/from-campaign", datasetID), req, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// ImportDatasetFromExplorer imports records from explorer events.
|
||||
func (c *Client) ImportDatasetFromExplorer(ctx context.Context, datasetID string, req *observability.ImportFromExplorerRequest) (*observability.DatasetImportTask, error) {
|
||||
var resp observability.DatasetImportTask
|
||||
if err := c.doJSON(ctx, "POST", fmt.Sprintf("/v1/observability/datasets/%s/imports/from-explorer", datasetID), req, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// ImportDatasetFromFile imports records from a file.
|
||||
func (c *Client) ImportDatasetFromFile(ctx context.Context, datasetID string, req *observability.ImportFromFileRequest) (*observability.DatasetImportTask, error) {
|
||||
var resp observability.DatasetImportTask
|
||||
if err := c.doJSON(ctx, "POST", fmt.Sprintf("/v1/observability/datasets/%s/imports/from-file", datasetID), req, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// ImportDatasetFromPlayground imports records from playground conversations.
|
||||
func (c *Client) ImportDatasetFromPlayground(ctx context.Context, datasetID string, req *observability.ImportFromPlaygroundRequest) (*observability.DatasetImportTask, error) {
|
||||
var resp observability.DatasetImportTask
|
||||
if err := c.doJSON(ctx, "POST", fmt.Sprintf("/v1/observability/datasets/%s/imports/from-playground", datasetID), req, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// ImportDatasetFromDataset imports records from another dataset.
|
||||
func (c *Client) ImportDatasetFromDataset(ctx context.Context, datasetID string, req *observability.ImportFromDatasetRequest) (*observability.DatasetImportTask, error) {
|
||||
var resp observability.DatasetImportTask
|
||||
if err := c.doJSON(ctx, "POST", fmt.Sprintf("/v1/observability/datasets/%s/imports/from-dataset", datasetID), req, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// Tasks
|
||||
|
||||
// ListDatasetTasks lists import tasks for a dataset.
|
||||
func (c *Client) ListDatasetTasks(ctx context.Context, datasetID string, params *observability.PaginationParams) (*observability.ListTasksResponse, error) {
|
||||
path := fmt.Sprintf("/v1/observability/datasets/%s/tasks", datasetID)
|
||||
if params != nil {
|
||||
q := url.Values{}
|
||||
if params.PageSize != nil {
|
||||
q.Set("page_size", strconv.Itoa(*params.PageSize))
|
||||
}
|
||||
if params.Page != nil {
|
||||
q.Set("page", strconv.Itoa(*params.Page))
|
||||
}
|
||||
if encoded := q.Encode(); encoded != "" {
|
||||
path += "?" + encoded
|
||||
}
|
||||
}
|
||||
var resp observability.ListTasksResponse
|
||||
if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// GetDatasetTask retrieves an import task by ID.
|
||||
func (c *Client) GetDatasetTask(ctx context.Context, datasetID, taskID string) (*observability.DatasetImportTask, error) {
|
||||
var resp observability.DatasetImportTask
|
||||
if err := c.doJSON(ctx, "GET", fmt.Sprintf("/v1/observability/datasets/%s/tasks/%s", datasetID, taskID), nil, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
211
observability_datasets_test.go
Normal file
211
observability_datasets_test.go
Normal file
@@ -0,0 +1,211 @@
|
||||
package mistral
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/VikingOwl91/mistral-go-sdk/observability"
|
||||
)
|
||||
|
||||
func datasetJSON() map[string]any {
|
||||
return map[string]any{
|
||||
"id": "ds-1", "created_at": "t", "updated_at": "t",
|
||||
"name": "test-ds", "description": "d",
|
||||
"owner_id": "o", "workspace_id": "w",
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateDataset_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != "POST" || r.URL.Path != "/v1/observability/datasets" {
|
||||
t.Errorf("unexpected %s %s", r.Method, r.URL.Path)
|
||||
}
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
json.NewEncoder(w).Encode(datasetJSON())
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
resp, err := client.CreateDataset(context.Background(), &observability.CreateDatasetRequest{
|
||||
Name: "test-ds",
|
||||
Description: "d",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp.ID != "ds-1" {
|
||||
t.Errorf("got id %q", resp.ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListDatasets_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"count": 1,
|
||||
"results": []any{datasetJSON()},
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
resp, err := client.ListDatasets(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp.Count != 1 {
|
||||
t.Errorf("got count %d", resp.Count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetDataset_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/v1/observability/datasets/ds-1" {
|
||||
t.Errorf("got path %s", r.URL.Path)
|
||||
}
|
||||
json.NewEncoder(w).Encode(datasetJSON())
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
resp, err := client.GetDataset(context.Background(), "ds-1")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp.Name != "test-ds" {
|
||||
t.Errorf("got name %q", resp.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteDataset_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
if err := client.DeleteDataset(context.Background(), "ds-1"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateDatasetRecord_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != "POST" || r.URL.Path != "/v1/observability/datasets/ds-1/records" {
|
||||
t.Errorf("unexpected %s %s", r.Method, r.URL.Path)
|
||||
}
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"id": "rec-1", "created_at": "t", "updated_at": "t",
|
||||
"dataset_id": "ds-1", "source": "DIRECT_INPUT",
|
||||
"payload": map[string]any{"messages": []any{}},
|
||||
"properties": map[string]any{},
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
resp, err := client.CreateDatasetRecord(context.Background(), "ds-1", &observability.CreateRecordRequest{
|
||||
Payload: observability.ConversationPayload{Messages: []map[string]any{}},
|
||||
Properties: map[string]any{},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp.ID != "rec-1" {
|
||||
t.Errorf("got id %q", resp.ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListDatasetRecords_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/v1/observability/datasets/ds-1/records" {
|
||||
t.Errorf("got path %s", r.URL.Path)
|
||||
}
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"count": 0, "results": []any{},
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
resp, err := client.ListDatasetRecords(context.Background(), "ds-1", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp.Count != 0 {
|
||||
t.Errorf("got count %d", resp.Count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestImportDatasetFromCampaign_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/v1/observability/datasets/ds-1/imports/from-campaign" {
|
||||
t.Errorf("got path %s", r.URL.Path)
|
||||
}
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"id": "task-1", "created_at": "t", "updated_at": "t",
|
||||
"creator_id": "u", "dataset_id": "ds-1", "workspace_id": "w",
|
||||
"status": "RUNNING",
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
resp, err := client.ImportDatasetFromCampaign(context.Background(), "ds-1", &observability.ImportFromCampaignRequest{
|
||||
CampaignID: "camp-1",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp.Status != observability.TaskStatusRunning {
|
||||
t.Errorf("got status %q", resp.Status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExportDatasetToJSONL_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/v1/observability/datasets/ds-1/exports/to-jsonl" {
|
||||
t.Errorf("got path %s", r.URL.Path)
|
||||
}
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"file_url": "https://storage.example.com/export.jsonl",
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
resp, err := client.ExportDatasetToJSONL(context.Background(), "ds-1")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp.FileURL != "https://storage.example.com/export.jsonl" {
|
||||
t.Errorf("got file_url %q", resp.FileURL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetDatasetTask_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/v1/observability/datasets/ds-1/tasks/task-1" {
|
||||
t.Errorf("got path %s", r.URL.Path)
|
||||
}
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"id": "task-1", "created_at": "t", "updated_at": "t",
|
||||
"creator_id": "u", "dataset_id": "ds-1", "workspace_id": "w",
|
||||
"status": "COMPLETED",
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
resp, err := client.GetDatasetTask(context.Background(), "ds-1", "task-1")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp.Status != observability.TaskStatusCompleted {
|
||||
t.Errorf("got status %q", resp.Status)
|
||||
}
|
||||
}
|
||||
69
observability_events.go
Normal file
69
observability_events.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package mistral
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strconv"
|
||||
|
||||
"github.com/VikingOwl91/mistral-go-sdk/observability"
|
||||
)
|
||||
|
||||
// SearchChatCompletionEvents searches for chat completion events.
|
||||
func (c *Client) SearchChatCompletionEvents(ctx context.Context, req *observability.SearchEventsRequest) (*observability.SearchEventsResponse, error) {
|
||||
var resp observability.SearchEventsResponse
|
||||
if err := c.doJSON(ctx, "POST", "/v1/observability/chat-completion-events/search", req, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// SearchChatCompletionEventIDs searches for chat completion event IDs.
|
||||
func (c *Client) SearchChatCompletionEventIDs(ctx context.Context, req *observability.SearchEventIDsRequest) (*observability.SearchEventIDsResponse, error) {
|
||||
var resp observability.SearchEventIDsResponse
|
||||
if err := c.doJSON(ctx, "POST", "/v1/observability/chat-completion-events/search-ids", req, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// GetChatCompletionEvent retrieves a chat completion event by ID.
|
||||
func (c *Client) GetChatCompletionEvent(ctx context.Context, eventID string) (*observability.ChatCompletionEvent, error) {
|
||||
var resp observability.ChatCompletionEvent
|
||||
if err := c.doJSON(ctx, "GET", fmt.Sprintf("/v1/observability/chat-completion-events/%s", eventID), nil, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// GetSimilarChatCompletionEvents retrieves events similar to a given event.
|
||||
func (c *Client) GetSimilarChatCompletionEvents(ctx context.Context, eventID string, params *observability.PaginationParams) (*observability.SimilarEventsResponse, error) {
|
||||
path := fmt.Sprintf("/v1/observability/chat-completion-events/%s/similar-events", eventID)
|
||||
if params != nil {
|
||||
q := url.Values{}
|
||||
if params.PageSize != nil {
|
||||
q.Set("page_size", strconv.Itoa(*params.PageSize))
|
||||
}
|
||||
if params.Page != nil {
|
||||
q.Set("page", strconv.Itoa(*params.Page))
|
||||
}
|
||||
if encoded := q.Encode(); encoded != "" {
|
||||
path += "?" + encoded
|
||||
}
|
||||
}
|
||||
var resp observability.SimilarEventsResponse
|
||||
if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// JudgeChatCompletionEvent judges a chat completion event.
|
||||
func (c *Client) JudgeChatCompletionEvent(ctx context.Context, eventID string, req *observability.JudgeEventRequest) (json.RawMessage, error) {
|
||||
var resp json.RawMessage
|
||||
if err := c.doJSON(ctx, "POST", fmt.Sprintf("/v1/observability/chat-completion-events/%s/live-judging", eventID), req, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
101
observability_events_test.go
Normal file
101
observability_events_test.go
Normal file
@@ -0,0 +1,101 @@
|
||||
package mistral
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/VikingOwl91/mistral-go-sdk/observability"
|
||||
)
|
||||
|
||||
func TestSearchChatCompletionEvents_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != "POST" || r.URL.Path != "/v1/observability/chat-completion-events/search" {
|
||||
t.Errorf("unexpected %s %s", r.Method, r.URL.Path)
|
||||
}
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"results": []map[string]any{
|
||||
{"event_id": "ev-1", "correlation_id": "c1", "created_at": "t", "nb_input_tokens": 10, "nb_output_tokens": 5},
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
resp, err := client.SearchChatCompletionEvents(context.Background(), &observability.SearchEventsRequest{})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(resp.Results) != 1 {
|
||||
t.Fatalf("got %d results", len(resp.Results))
|
||||
}
|
||||
if resp.Results[0].EventID != "ev-1" {
|
||||
t.Errorf("got event_id %q", resp.Results[0].EventID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearchChatCompletionEventIDs_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/v1/observability/chat-completion-events/search-ids" {
|
||||
t.Errorf("got path %s", r.URL.Path)
|
||||
}
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"completion_event_ids": []string{"ev-1", "ev-2"},
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
resp, err := client.SearchChatCompletionEventIDs(context.Background(), &observability.SearchEventIDsRequest{})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(resp.CompletionEventIDs) != 2 {
|
||||
t.Errorf("got %d ids", len(resp.CompletionEventIDs))
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetChatCompletionEvent_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/v1/observability/chat-completion-events/ev-1" {
|
||||
t.Errorf("got path %s", r.URL.Path)
|
||||
}
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"event_id": "ev-1", "correlation_id": "c1", "created_at": "t",
|
||||
"nb_input_tokens": 10, "nb_output_tokens": 5, "nb_messages": 2,
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
resp, err := client.GetChatCompletionEvent(context.Background(), "ev-1")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp.EventID != "ev-1" {
|
||||
t.Errorf("got event_id %q", resp.EventID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetSimilarChatCompletionEvents_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/v1/observability/chat-completion-events/ev-1/similar-events" {
|
||||
t.Errorf("got path %s", r.URL.Path)
|
||||
}
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"count": 0, "results": []any{},
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
resp, err := client.GetSimilarChatCompletionEvents(context.Background(), "ev-1", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp.Count != 0 {
|
||||
t.Errorf("got count %d", resp.Count)
|
||||
}
|
||||
}
|
||||
85
observability_judges.go
Normal file
85
observability_judges.go
Normal file
@@ -0,0 +1,85 @@
|
||||
package mistral
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strconv"
|
||||
|
||||
"github.com/VikingOwl91/mistral-go-sdk/observability"
|
||||
)
|
||||
|
||||
// CreateJudge creates a new observability judge.
|
||||
func (c *Client) CreateJudge(ctx context.Context, req *observability.CreateJudgeRequest) (*observability.Judge, error) {
|
||||
var resp observability.Judge
|
||||
if err := c.doJSON(ctx, "POST", "/v1/observability/judges", req, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// ListJudges lists observability judges.
|
||||
func (c *Client) ListJudges(ctx context.Context, params *observability.SearchParams) (*observability.ListJudgesResponse, error) {
|
||||
path := "/v1/observability/judges"
|
||||
if params != nil {
|
||||
q := url.Values{}
|
||||
if params.PageSize != nil {
|
||||
q.Set("page_size", strconv.Itoa(*params.PageSize))
|
||||
}
|
||||
if params.Page != nil {
|
||||
q.Set("page", strconv.Itoa(*params.Page))
|
||||
}
|
||||
if params.Q != nil {
|
||||
q.Set("q", *params.Q)
|
||||
}
|
||||
if encoded := q.Encode(); encoded != "" {
|
||||
path += "?" + encoded
|
||||
}
|
||||
}
|
||||
var resp observability.ListJudgesResponse
|
||||
if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// GetJudge retrieves a judge by ID.
|
||||
func (c *Client) GetJudge(ctx context.Context, judgeID string) (*observability.Judge, error) {
|
||||
var resp observability.Judge
|
||||
if err := c.doJSON(ctx, "GET", fmt.Sprintf("/v1/observability/judges/%s", judgeID), nil, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// UpdateJudge updates a judge.
|
||||
func (c *Client) UpdateJudge(ctx context.Context, judgeID string, req *observability.UpdateJudgeRequest) (*observability.Judge, error) {
|
||||
var resp observability.Judge
|
||||
if err := c.doJSON(ctx, "PUT", fmt.Sprintf("/v1/observability/judges/%s", judgeID), req, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// DeleteJudge deletes a judge.
|
||||
func (c *Client) DeleteJudge(ctx context.Context, judgeID string) error {
|
||||
resp, err := c.do(ctx, "DELETE", fmt.Sprintf("/v1/observability/judges/%s", judgeID), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode >= 400 {
|
||||
return parseAPIError(resp)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// JudgeConversation performs live judging on a conversation.
|
||||
func (c *Client) JudgeConversation(ctx context.Context, judgeID string, req *observability.JudgeConversationRequest) (json.RawMessage, error) {
|
||||
var resp json.RawMessage
|
||||
if err := c.doJSON(ctx, "POST", fmt.Sprintf("/v1/observability/judges/%s/live-judging", judgeID), req, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
123
observability_judges_test.go
Normal file
123
observability_judges_test.go
Normal file
@@ -0,0 +1,123 @@
|
||||
package mistral
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/VikingOwl91/mistral-go-sdk/observability"
|
||||
)
|
||||
|
||||
func judgeJSON() map[string]any {
|
||||
return map[string]any{
|
||||
"id": "j1", "created_at": "t", "updated_at": "t",
|
||||
"owner_id": "o", "workspace_id": "w", "name": "quality",
|
||||
"description": "d", "model_name": "m", "instructions": "i",
|
||||
"output": map[string]any{"type": "CLASSIFICATION", "options": []any{}},
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateJudge_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != "POST" || r.URL.Path != "/v1/observability/judges" {
|
||||
t.Errorf("unexpected %s %s", r.Method, r.URL.Path)
|
||||
}
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
json.NewEncoder(w).Encode(judgeJSON())
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
resp, err := client.CreateJudge(context.Background(), &observability.CreateJudgeRequest{
|
||||
Name: "quality",
|
||||
Description: "d",
|
||||
ModelName: "m",
|
||||
Instructions: "i",
|
||||
Tools: []string{},
|
||||
Output: json.RawMessage(`{"type":"CLASSIFICATION","options":[]}`),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp.ID != "j1" {
|
||||
t.Errorf("got id %q", resp.ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListJudges_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"count": 1,
|
||||
"results": []any{judgeJSON()},
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
resp, err := client.ListJudges(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp.Count != 1 {
|
||||
t.Errorf("got count %d", resp.Count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetJudge_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/v1/observability/judges/j1" {
|
||||
t.Errorf("got path %s", r.URL.Path)
|
||||
}
|
||||
json.NewEncoder(w).Encode(judgeJSON())
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
resp, err := client.GetJudge(context.Background(), "j1")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp.Name != "quality" {
|
||||
t.Errorf("got name %q", resp.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateJudge_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != "PUT" {
|
||||
t.Errorf("expected PUT, got %s", r.Method)
|
||||
}
|
||||
json.NewEncoder(w).Encode(judgeJSON())
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
_, err := client.UpdateJudge(context.Background(), "j1", &observability.UpdateJudgeRequest{
|
||||
Name: "quality",
|
||||
Description: "d",
|
||||
ModelName: "m",
|
||||
Instructions: "i",
|
||||
Tools: []string{},
|
||||
Output: json.RawMessage(`{"type":"CLASSIFICATION","options":[]}`),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteJudge_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != "DELETE" {
|
||||
t.Errorf("expected DELETE")
|
||||
}
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
if err := client.DeleteJudge(context.Background(), "j1"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
@@ -3,7 +3,7 @@ package mistral
|
||||
import (
|
||||
"context"
|
||||
|
||||
"somegit.dev/vikingowl/mistral-go-sdk/ocr"
|
||||
"github.com/VikingOwl91/mistral-go-sdk/ocr"
|
||||
)
|
||||
|
||||
// OCR performs optical character recognition on a document.
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"somegit.dev/vikingowl/mistral-go-sdk/ocr"
|
||||
"github.com/VikingOwl91/mistral-go-sdk/ocr"
|
||||
)
|
||||
|
||||
func TestOCR_Success(t *testing.T) {
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"somegit.dev/vikingowl/mistral-go-sdk/chat"
|
||||
"github.com/VikingOwl91/mistral-go-sdk/chat"
|
||||
)
|
||||
|
||||
func TestRetry_429ThenSuccess(t *testing.T) {
|
||||
|
||||
21
workflow/deployment.go
Normal file
21
workflow/deployment.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package workflow
|
||||
|
||||
// Deployment represents a workflow deployment.
|
||||
type Deployment struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
IsActive bool `json:"is_active"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
UpdatedAt string `json:"updated_at"`
|
||||
}
|
||||
|
||||
// DeploymentListResponse is the response from listing deployments.
|
||||
type DeploymentListResponse struct {
|
||||
Deployments []Deployment `json:"deployments"`
|
||||
}
|
||||
|
||||
// DeploymentListParams holds query parameters for listing deployments.
|
||||
type DeploymentListParams struct {
|
||||
ActiveOnly *bool
|
||||
WorkflowName *string
|
||||
}
|
||||
5
workflow/doc.go
Normal file
5
workflow/doc.go
Normal file
@@ -0,0 +1,5 @@
|
||||
// Package workflow provides types for the Mistral workflows API.
|
||||
//
|
||||
// Workflows support orchestrating multi-step processes with execution
|
||||
// management, scheduling, event streaming, and observability.
|
||||
package workflow
|
||||
369
workflow/event.go
Normal file
369
workflow/event.go
Normal file
@@ -0,0 +1,369 @@
|
||||
package workflow
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// EventType identifies the kind of workflow event.
|
||||
type EventType string
|
||||
|
||||
const (
|
||||
EventWorkflowStarted EventType = "WORKFLOW_EXECUTION_STARTED"
|
||||
EventWorkflowCompleted EventType = "WORKFLOW_EXECUTION_COMPLETED"
|
||||
EventWorkflowFailed EventType = "WORKFLOW_EXECUTION_FAILED"
|
||||
EventWorkflowCanceled EventType = "WORKFLOW_EXECUTION_CANCELED"
|
||||
EventWorkflowContinuedAsNew EventType = "WORKFLOW_EXECUTION_CONTINUED_AS_NEW"
|
||||
EventWorkflowTaskTimedOut EventType = "WORKFLOW_TASK_TIMED_OUT"
|
||||
EventWorkflowTaskFailed EventType = "WORKFLOW_TASK_FAILED"
|
||||
EventCustomTaskStarted EventType = "CUSTOM_TASK_STARTED"
|
||||
EventCustomTaskInProgress EventType = "CUSTOM_TASK_IN_PROGRESS"
|
||||
EventCustomTaskCompleted EventType = "CUSTOM_TASK_COMPLETED"
|
||||
EventCustomTaskFailed EventType = "CUSTOM_TASK_FAILED"
|
||||
EventCustomTaskTimedOut EventType = "CUSTOM_TASK_TIMED_OUT"
|
||||
EventCustomTaskCanceled EventType = "CUSTOM_TASK_CANCELED"
|
||||
EventActivityTaskStarted EventType = "ACTIVITY_TASK_STARTED"
|
||||
EventActivityTaskCompleted EventType = "ACTIVITY_TASK_COMPLETED"
|
||||
EventActivityTaskRetrying EventType = "ACTIVITY_TASK_RETRYING"
|
||||
EventActivityTaskFailed EventType = "ACTIVITY_TASK_FAILED"
|
||||
)
|
||||
|
||||
// EventSource identifies where an event originated.
|
||||
type EventSource string
|
||||
|
||||
const (
|
||||
EventSourceDatabase EventSource = "DATABASE"
|
||||
EventSourceLive EventSource = "LIVE"
|
||||
)
|
||||
|
||||
// Scope identifies the scope of an event subscription.
|
||||
type Scope string
|
||||
|
||||
const (
|
||||
ScopeActivity Scope = "activity"
|
||||
ScopeWorkflow Scope = "workflow"
|
||||
ScopeAll Scope = "*"
|
||||
)
|
||||
|
||||
// Event is a sealed interface for workflow execution events.
|
||||
type Event interface {
|
||||
workflowEvent()
|
||||
EventType() EventType
|
||||
}
|
||||
|
||||
// eventBase holds fields common to all workflow events.
|
||||
type eventBase struct {
|
||||
ID string `json:"event_id"`
|
||||
Timestamp int64 `json:"event_timestamp"`
|
||||
RootWorkflowExecID string `json:"root_workflow_exec_id"`
|
||||
ParentWorkflowExecID *string `json:"parent_workflow_exec_id"`
|
||||
WorkflowExecID string `json:"workflow_exec_id"`
|
||||
WorkflowRunID string `json:"workflow_run_id"`
|
||||
WorkflowName string `json:"workflow_name"`
|
||||
}
|
||||
|
||||
// WorkflowStartedAttributes holds typed attributes for workflow started events.
|
||||
type WorkflowStartedAttributes struct {
|
||||
TaskID string `json:"task_id"`
|
||||
}
|
||||
|
||||
// JSONPayload holds a typed JSON value.
|
||||
type JSONPayload struct {
|
||||
Value any `json:"value"`
|
||||
Type string `json:"type"`
|
||||
}
|
||||
|
||||
// WorkflowCompletedAttributes holds typed attributes for workflow completed events.
|
||||
type WorkflowCompletedAttributes struct {
|
||||
TaskID string `json:"task_id"`
|
||||
Result JSONPayload `json:"result"`
|
||||
}
|
||||
|
||||
// WorkflowFailedAttributes holds typed attributes for workflow failed events.
|
||||
type WorkflowFailedAttributes struct {
|
||||
TaskID string `json:"task_id"`
|
||||
Failure any `json:"failure"`
|
||||
}
|
||||
|
||||
// WorkflowExecutionStartedEvent signals that a workflow execution has started.
|
||||
type WorkflowExecutionStartedEvent struct {
|
||||
eventBase
|
||||
Attributes WorkflowStartedAttributes `json:"attributes"`
|
||||
}
|
||||
|
||||
func (*WorkflowExecutionStartedEvent) workflowEvent() {}
|
||||
func (*WorkflowExecutionStartedEvent) EventType() EventType { return EventWorkflowStarted }
|
||||
|
||||
// WorkflowExecutionCompletedEvent signals that a workflow execution has completed.
|
||||
type WorkflowExecutionCompletedEvent struct {
|
||||
eventBase
|
||||
Attributes WorkflowCompletedAttributes `json:"attributes"`
|
||||
}
|
||||
|
||||
func (*WorkflowExecutionCompletedEvent) workflowEvent() {}
|
||||
func (*WorkflowExecutionCompletedEvent) EventType() EventType { return EventWorkflowCompleted }
|
||||
|
||||
// WorkflowExecutionFailedEvent signals that a workflow execution has failed.
|
||||
type WorkflowExecutionFailedEvent struct {
|
||||
eventBase
|
||||
Attributes WorkflowFailedAttributes `json:"attributes"`
|
||||
}
|
||||
|
||||
func (*WorkflowExecutionFailedEvent) workflowEvent() {}
|
||||
func (*WorkflowExecutionFailedEvent) EventType() EventType { return EventWorkflowFailed }
|
||||
|
||||
// WorkflowExecutionCanceledEvent signals that a workflow execution was canceled.
|
||||
type WorkflowExecutionCanceledEvent struct {
|
||||
eventBase
|
||||
Attributes json.RawMessage `json:"attributes"`
|
||||
}
|
||||
|
||||
func (*WorkflowExecutionCanceledEvent) workflowEvent() {}
|
||||
func (*WorkflowExecutionCanceledEvent) EventType() EventType { return EventWorkflowCanceled }
|
||||
|
||||
// WorkflowExecutionContinuedAsNewEvent signals that a workflow continued as a new execution.
|
||||
type WorkflowExecutionContinuedAsNewEvent struct {
|
||||
eventBase
|
||||
Attributes json.RawMessage `json:"attributes"`
|
||||
}
|
||||
|
||||
func (*WorkflowExecutionContinuedAsNewEvent) workflowEvent() {}
|
||||
func (*WorkflowExecutionContinuedAsNewEvent) EventType() EventType { return EventWorkflowContinuedAsNew }
|
||||
|
||||
// WorkflowTaskTimedOutEvent signals that a workflow task timed out.
|
||||
type WorkflowTaskTimedOutEvent struct {
|
||||
eventBase
|
||||
Attributes json.RawMessage `json:"attributes"`
|
||||
}
|
||||
|
||||
func (*WorkflowTaskTimedOutEvent) workflowEvent() {}
|
||||
func (*WorkflowTaskTimedOutEvent) EventType() EventType { return EventWorkflowTaskTimedOut }
|
||||
|
||||
// WorkflowTaskFailedEvent signals that a workflow task failed.
|
||||
type WorkflowTaskFailedEvent struct {
|
||||
eventBase
|
||||
Attributes json.RawMessage `json:"attributes"`
|
||||
}
|
||||
|
||||
func (*WorkflowTaskFailedEvent) workflowEvent() {}
|
||||
func (*WorkflowTaskFailedEvent) EventType() EventType { return EventWorkflowTaskFailed }
|
||||
|
||||
// CustomTaskStartedEvent signals that a custom task has started.
|
||||
type CustomTaskStartedEvent struct {
|
||||
eventBase
|
||||
Attributes json.RawMessage `json:"attributes"`
|
||||
}
|
||||
|
||||
func (*CustomTaskStartedEvent) workflowEvent() {}
|
||||
func (*CustomTaskStartedEvent) EventType() EventType { return EventCustomTaskStarted }
|
||||
|
||||
// CustomTaskInProgressEvent signals that a custom task is in progress.
|
||||
type CustomTaskInProgressEvent struct {
|
||||
eventBase
|
||||
Attributes json.RawMessage `json:"attributes"`
|
||||
}
|
||||
|
||||
func (*CustomTaskInProgressEvent) workflowEvent() {}
|
||||
func (*CustomTaskInProgressEvent) EventType() EventType { return EventCustomTaskInProgress }
|
||||
|
||||
// CustomTaskCompletedEvent signals that a custom task has completed.
|
||||
type CustomTaskCompletedEvent struct {
|
||||
eventBase
|
||||
Attributes json.RawMessage `json:"attributes"`
|
||||
}
|
||||
|
||||
func (*CustomTaskCompletedEvent) workflowEvent() {}
|
||||
func (*CustomTaskCompletedEvent) EventType() EventType { return EventCustomTaskCompleted }
|
||||
|
||||
// CustomTaskFailedEvent signals that a custom task has failed.
|
||||
type CustomTaskFailedEvent struct {
|
||||
eventBase
|
||||
Attributes json.RawMessage `json:"attributes"`
|
||||
}
|
||||
|
||||
func (*CustomTaskFailedEvent) workflowEvent() {}
|
||||
func (*CustomTaskFailedEvent) EventType() EventType { return EventCustomTaskFailed }
|
||||
|
||||
// CustomTaskTimedOutEvent signals that a custom task timed out.
|
||||
type CustomTaskTimedOutEvent struct {
|
||||
eventBase
|
||||
Attributes json.RawMessage `json:"attributes"`
|
||||
}
|
||||
|
||||
func (*CustomTaskTimedOutEvent) workflowEvent() {}
|
||||
func (*CustomTaskTimedOutEvent) EventType() EventType { return EventCustomTaskTimedOut }
|
||||
|
||||
// CustomTaskCanceledEvent signals that a custom task was canceled.
|
||||
type CustomTaskCanceledEvent struct {
|
||||
eventBase
|
||||
Attributes json.RawMessage `json:"attributes"`
|
||||
}
|
||||
|
||||
func (*CustomTaskCanceledEvent) workflowEvent() {}
|
||||
func (*CustomTaskCanceledEvent) EventType() EventType { return EventCustomTaskCanceled }
|
||||
|
||||
// ActivityTaskStartedEvent signals that an activity task has started.
|
||||
type ActivityTaskStartedEvent struct {
|
||||
eventBase
|
||||
Attributes json.RawMessage `json:"attributes"`
|
||||
}
|
||||
|
||||
func (*ActivityTaskStartedEvent) workflowEvent() {}
|
||||
func (*ActivityTaskStartedEvent) EventType() EventType { return EventActivityTaskStarted }
|
||||
|
||||
// ActivityTaskCompletedEvent signals that an activity task has completed.
|
||||
type ActivityTaskCompletedEvent struct {
|
||||
eventBase
|
||||
Attributes json.RawMessage `json:"attributes"`
|
||||
}
|
||||
|
||||
func (*ActivityTaskCompletedEvent) workflowEvent() {}
|
||||
func (*ActivityTaskCompletedEvent) EventType() EventType { return EventActivityTaskCompleted }
|
||||
|
||||
// ActivityTaskRetryingEvent signals that an activity task is being retried.
|
||||
type ActivityTaskRetryingEvent struct {
|
||||
eventBase
|
||||
Attributes json.RawMessage `json:"attributes"`
|
||||
}
|
||||
|
||||
func (*ActivityTaskRetryingEvent) workflowEvent() {}
|
||||
func (*ActivityTaskRetryingEvent) EventType() EventType { return EventActivityTaskRetrying }
|
||||
|
||||
// ActivityTaskFailedEvent signals that an activity task has failed.
|
||||
type ActivityTaskFailedEvent struct {
|
||||
eventBase
|
||||
Attributes json.RawMessage `json:"attributes"`
|
||||
}
|
||||
|
||||
func (*ActivityTaskFailedEvent) workflowEvent() {}
|
||||
func (*ActivityTaskFailedEvent) EventType() EventType { return EventActivityTaskFailed }
|
||||
|
||||
// UnknownEvent holds an event with an unrecognized event_type.
|
||||
// This prevents the SDK from breaking when new event types are added.
|
||||
type UnknownEvent struct {
|
||||
eventBase
|
||||
RawType string
|
||||
Raw json.RawMessage
|
||||
}
|
||||
|
||||
func (*UnknownEvent) workflowEvent() {}
|
||||
func (e *UnknownEvent) EventType() EventType { return EventType(e.RawType) }
|
||||
|
||||
// UnmarshalEvent dispatches JSON to the concrete Event type
|
||||
// based on the "event_type" discriminator field.
|
||||
func UnmarshalEvent(data []byte) (Event, error) {
|
||||
var probe struct {
|
||||
Type string `json:"event_type"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &probe); err != nil {
|
||||
return nil, fmt.Errorf("mistral: unmarshal workflow event: %w", err)
|
||||
}
|
||||
switch probe.Type {
|
||||
case string(EventWorkflowStarted):
|
||||
var e WorkflowExecutionStartedEvent
|
||||
return &e, json.Unmarshal(data, &e)
|
||||
case string(EventWorkflowCompleted):
|
||||
var e WorkflowExecutionCompletedEvent
|
||||
return &e, json.Unmarshal(data, &e)
|
||||
case string(EventWorkflowFailed):
|
||||
var e WorkflowExecutionFailedEvent
|
||||
return &e, json.Unmarshal(data, &e)
|
||||
case string(EventWorkflowCanceled):
|
||||
var e WorkflowExecutionCanceledEvent
|
||||
return &e, json.Unmarshal(data, &e)
|
||||
case string(EventWorkflowContinuedAsNew):
|
||||
var e WorkflowExecutionContinuedAsNewEvent
|
||||
return &e, json.Unmarshal(data, &e)
|
||||
case string(EventWorkflowTaskTimedOut):
|
||||
var e WorkflowTaskTimedOutEvent
|
||||
return &e, json.Unmarshal(data, &e)
|
||||
case string(EventWorkflowTaskFailed):
|
||||
var e WorkflowTaskFailedEvent
|
||||
return &e, json.Unmarshal(data, &e)
|
||||
case string(EventCustomTaskStarted):
|
||||
var e CustomTaskStartedEvent
|
||||
return &e, json.Unmarshal(data, &e)
|
||||
case string(EventCustomTaskInProgress):
|
||||
var e CustomTaskInProgressEvent
|
||||
return &e, json.Unmarshal(data, &e)
|
||||
case string(EventCustomTaskCompleted):
|
||||
var e CustomTaskCompletedEvent
|
||||
return &e, json.Unmarshal(data, &e)
|
||||
case string(EventCustomTaskFailed):
|
||||
var e CustomTaskFailedEvent
|
||||
return &e, json.Unmarshal(data, &e)
|
||||
case string(EventCustomTaskTimedOut):
|
||||
var e CustomTaskTimedOutEvent
|
||||
return &e, json.Unmarshal(data, &e)
|
||||
case string(EventCustomTaskCanceled):
|
||||
var e CustomTaskCanceledEvent
|
||||
return &e, json.Unmarshal(data, &e)
|
||||
case string(EventActivityTaskStarted):
|
||||
var e ActivityTaskStartedEvent
|
||||
return &e, json.Unmarshal(data, &e)
|
||||
case string(EventActivityTaskCompleted):
|
||||
var e ActivityTaskCompletedEvent
|
||||
return &e, json.Unmarshal(data, &e)
|
||||
case string(EventActivityTaskRetrying):
|
||||
var e ActivityTaskRetryingEvent
|
||||
return &e, json.Unmarshal(data, &e)
|
||||
case string(EventActivityTaskFailed):
|
||||
var e ActivityTaskFailedEvent
|
||||
return &e, json.Unmarshal(data, &e)
|
||||
default:
|
||||
var base eventBase
|
||||
if err := json.Unmarshal(data, &base); err != nil {
|
||||
return nil, fmt.Errorf("mistral: unmarshal workflow event base: %w", err)
|
||||
}
|
||||
return &UnknownEvent{
|
||||
eventBase: base,
|
||||
RawType: probe.Type,
|
||||
Raw: json.RawMessage(data),
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
// StreamPayload is a single SSE payload from the workflow event stream.
|
||||
type StreamPayload struct {
|
||||
Stream string `json:"stream"`
|
||||
Data json.RawMessage `json:"data"`
|
||||
WorkflowContext StreamWorkflowContext `json:"workflow_context"`
|
||||
BrokerSequence int64 `json:"broker_sequence"`
|
||||
}
|
||||
|
||||
// StreamWorkflowContext holds context for a workflow event stream.
|
||||
type StreamWorkflowContext struct {
|
||||
WorkflowName string `json:"workflow_name"`
|
||||
ExecutionID string `json:"execution_id"`
|
||||
}
|
||||
|
||||
// EventStreamParams holds query parameters for streaming workflow events.
|
||||
type EventStreamParams struct {
|
||||
Scope *Scope
|
||||
ActivityName *string
|
||||
ActivityID *string
|
||||
WorkflowName *string
|
||||
WorkflowExecID *string
|
||||
RootWorkflowExecID *string
|
||||
ParentWorkflowExecID *string
|
||||
Stream *string
|
||||
StartSeq *int
|
||||
MetadataFilters map[string]any
|
||||
WorkflowEventTypes []EventType
|
||||
LastEventID *string
|
||||
}
|
||||
|
||||
// EventListParams holds query parameters for listing workflow events.
|
||||
type EventListParams struct {
|
||||
RootWorkflowExecID *string
|
||||
WorkflowExecID *string
|
||||
WorkflowRunID *string
|
||||
Limit *int
|
||||
Cursor *string
|
||||
}
|
||||
|
||||
// EventListResponse is the response from listing workflow events.
|
||||
type EventListResponse struct {
|
||||
Events []json.RawMessage `json:"events"`
|
||||
NextCursor *string `json:"next_cursor,omitempty"`
|
||||
}
|
||||
158
workflow/event_test.go
Normal file
158
workflow/event_test.go
Normal file
@@ -0,0 +1,158 @@
|
||||
package workflow
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestUnmarshalEvent_WorkflowExecutionCompleted(t *testing.T) {
|
||||
data := []byte(`{
|
||||
"event_id": "evt-1",
|
||||
"event_timestamp": 1711929600000000000,
|
||||
"root_workflow_exec_id": "root-1",
|
||||
"parent_workflow_exec_id": null,
|
||||
"workflow_exec_id": "exec-1",
|
||||
"workflow_run_id": "run-1",
|
||||
"workflow_name": "my-workflow",
|
||||
"event_type": "WORKFLOW_EXECUTION_COMPLETED",
|
||||
"attributes": {"task_id": "t1", "result": {"value": {"answer": 42}, "type": "json"}}
|
||||
}`)
|
||||
event, err := UnmarshalEvent(data)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
e, ok := event.(*WorkflowExecutionCompletedEvent)
|
||||
if !ok {
|
||||
t.Fatalf("expected *WorkflowExecutionCompletedEvent, got %T", event)
|
||||
}
|
||||
if e.ID != "evt-1" {
|
||||
t.Errorf("got ID %q", e.ID)
|
||||
}
|
||||
if e.WorkflowName != "my-workflow" {
|
||||
t.Errorf("got WorkflowName %q", e.WorkflowName)
|
||||
}
|
||||
if e.EventType() != EventWorkflowCompleted {
|
||||
t.Errorf("got EventType %q", e.EventType())
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnmarshalEvent_CustomTaskStarted(t *testing.T) {
|
||||
data := []byte(`{
|
||||
"event_id": "evt-2",
|
||||
"event_timestamp": 1711929600000000000,
|
||||
"root_workflow_exec_id": "root-1",
|
||||
"parent_workflow_exec_id": "parent-1",
|
||||
"workflow_exec_id": "exec-1",
|
||||
"workflow_run_id": "run-1",
|
||||
"workflow_name": "my-workflow",
|
||||
"event_type": "CUSTOM_TASK_STARTED",
|
||||
"attributes": {}
|
||||
}`)
|
||||
event, err := UnmarshalEvent(data)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
e, ok := event.(*CustomTaskStartedEvent)
|
||||
if !ok {
|
||||
t.Fatalf("expected *CustomTaskStartedEvent, got %T", event)
|
||||
}
|
||||
parent := "parent-1"
|
||||
if e.ParentWorkflowExecID == nil || *e.ParentWorkflowExecID != parent {
|
||||
t.Errorf("expected parent %q, got %v", parent, e.ParentWorkflowExecID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnmarshalEvent_ActivityTaskRetrying(t *testing.T) {
|
||||
data := []byte(`{
|
||||
"event_id": "evt-3",
|
||||
"event_timestamp": 1711929600000000000,
|
||||
"root_workflow_exec_id": "root-1",
|
||||
"parent_workflow_exec_id": null,
|
||||
"workflow_exec_id": "exec-1",
|
||||
"workflow_run_id": "run-1",
|
||||
"workflow_name": "my-workflow",
|
||||
"event_type": "ACTIVITY_TASK_RETRYING",
|
||||
"attributes": {}
|
||||
}`)
|
||||
event, err := UnmarshalEvent(data)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, ok := event.(*ActivityTaskRetryingEvent); !ok {
|
||||
t.Fatalf("expected *ActivityTaskRetryingEvent, got %T", event)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnmarshalEvent_UnknownType(t *testing.T) {
|
||||
data := []byte(`{
|
||||
"event_id": "evt-4",
|
||||
"event_timestamp": 1711929600000000000,
|
||||
"root_workflow_exec_id": "root-1",
|
||||
"parent_workflow_exec_id": null,
|
||||
"workflow_exec_id": "exec-1",
|
||||
"workflow_run_id": "run-1",
|
||||
"workflow_name": "my-workflow",
|
||||
"event_type": "FUTURE_EVENT_TYPE",
|
||||
"attributes": {}
|
||||
}`)
|
||||
event, err := UnmarshalEvent(data)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
unk, ok := event.(*UnknownEvent)
|
||||
if !ok {
|
||||
t.Fatalf("expected *UnknownEvent, got %T", event)
|
||||
}
|
||||
if unk.RawType != "FUTURE_EVENT_TYPE" {
|
||||
t.Errorf("got RawType %q", unk.RawType)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnmarshalEvent_AllTypes(t *testing.T) {
|
||||
types := []struct {
|
||||
eventType string
|
||||
wantType string
|
||||
}{
|
||||
{"WORKFLOW_EXECUTION_STARTED", "*workflow.WorkflowExecutionStartedEvent"},
|
||||
{"WORKFLOW_EXECUTION_COMPLETED", "*workflow.WorkflowExecutionCompletedEvent"},
|
||||
{"WORKFLOW_EXECUTION_FAILED", "*workflow.WorkflowExecutionFailedEvent"},
|
||||
{"WORKFLOW_EXECUTION_CANCELED", "*workflow.WorkflowExecutionCanceledEvent"},
|
||||
{"WORKFLOW_EXECUTION_CONTINUED_AS_NEW", "*workflow.WorkflowExecutionContinuedAsNewEvent"},
|
||||
{"WORKFLOW_TASK_TIMED_OUT", "*workflow.WorkflowTaskTimedOutEvent"},
|
||||
{"WORKFLOW_TASK_FAILED", "*workflow.WorkflowTaskFailedEvent"},
|
||||
{"CUSTOM_TASK_STARTED", "*workflow.CustomTaskStartedEvent"},
|
||||
{"CUSTOM_TASK_IN_PROGRESS", "*workflow.CustomTaskInProgressEvent"},
|
||||
{"CUSTOM_TASK_COMPLETED", "*workflow.CustomTaskCompletedEvent"},
|
||||
{"CUSTOM_TASK_FAILED", "*workflow.CustomTaskFailedEvent"},
|
||||
{"CUSTOM_TASK_TIMED_OUT", "*workflow.CustomTaskTimedOutEvent"},
|
||||
{"CUSTOM_TASK_CANCELED", "*workflow.CustomTaskCanceledEvent"},
|
||||
{"ACTIVITY_TASK_STARTED", "*workflow.ActivityTaskStartedEvent"},
|
||||
{"ACTIVITY_TASK_COMPLETED", "*workflow.ActivityTaskCompletedEvent"},
|
||||
{"ACTIVITY_TASK_RETRYING", "*workflow.ActivityTaskRetryingEvent"},
|
||||
{"ACTIVITY_TASK_FAILED", "*workflow.ActivityTaskFailedEvent"},
|
||||
}
|
||||
for _, tt := range types {
|
||||
t.Run(tt.eventType, func(t *testing.T) {
|
||||
data, _ := json.Marshal(map[string]any{
|
||||
"event_id": "evt",
|
||||
"event_timestamp": 1711929600000000000,
|
||||
"root_workflow_exec_id": "root",
|
||||
"parent_workflow_exec_id": nil,
|
||||
"workflow_exec_id": "exec",
|
||||
"workflow_run_id": "run",
|
||||
"workflow_name": "wf",
|
||||
"event_type": tt.eventType,
|
||||
"attributes": map[string]any{},
|
||||
})
|
||||
event, err := UnmarshalEvent(data)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
got := fmt.Sprintf("%T", event)
|
||||
if got != tt.wantType {
|
||||
t.Errorf("event_type %q: got %s, want %s", tt.eventType, got, tt.wantType)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
163
workflow/execution.go
Normal file
163
workflow/execution.go
Normal file
@@ -0,0 +1,163 @@
|
||||
package workflow
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
// ExecutionStatus is the status of a workflow execution.
|
||||
type ExecutionStatus string
|
||||
|
||||
const (
|
||||
ExecutionRunning ExecutionStatus = "RUNNING"
|
||||
ExecutionCompleted ExecutionStatus = "COMPLETED"
|
||||
ExecutionFailed ExecutionStatus = "FAILED"
|
||||
ExecutionCanceled ExecutionStatus = "CANCELED"
|
||||
ExecutionTerminated ExecutionStatus = "TERMINATED"
|
||||
ExecutionContinuedAsNew ExecutionStatus = "CONTINUED_AS_NEW"
|
||||
ExecutionTimedOut ExecutionStatus = "TIMED_OUT"
|
||||
ExecutionRetryingAfterErr ExecutionStatus = "RETRYING_AFTER_ERROR"
|
||||
)
|
||||
|
||||
// ExecutionRequest is the request body for executing a workflow.
|
||||
type ExecutionRequest struct {
|
||||
ExecutionID *string `json:"execution_id,omitempty"`
|
||||
Input map[string]any `json:"input,omitempty"`
|
||||
EncodedInput *NetworkEncodedInput `json:"encoded_input,omitempty"`
|
||||
WaitForResult bool `json:"wait_for_result,omitempty"`
|
||||
TimeoutSeconds *float64 `json:"timeout_seconds,omitempty"`
|
||||
CustomTracingAttributes map[string]string `json:"custom_tracing_attributes,omitempty"`
|
||||
DeploymentName *string `json:"deployment_name,omitempty"`
|
||||
}
|
||||
|
||||
// ExecutionResponse is the response from a workflow execution.
|
||||
type ExecutionResponse struct {
|
||||
WorkflowName string `json:"workflow_name"`
|
||||
ExecutionID string `json:"execution_id"`
|
||||
RootExecutionID string `json:"root_execution_id"`
|
||||
Status ExecutionStatus `json:"status"`
|
||||
StartTime string `json:"start_time"`
|
||||
EndTime *string `json:"end_time,omitempty"`
|
||||
Result any `json:"result,omitempty"`
|
||||
ParentExecutionID *string `json:"parent_execution_id,omitempty"`
|
||||
TotalDurationMs *int `json:"total_duration_ms,omitempty"`
|
||||
}
|
||||
|
||||
// NetworkEncodedInput holds a base64-encoded payload for workflow input.
|
||||
type NetworkEncodedInput struct {
|
||||
B64Payload string `json:"b64payload"`
|
||||
EncodingOptions []string `json:"encoding_options,omitempty"`
|
||||
Empty bool `json:"empty,omitempty"`
|
||||
}
|
||||
|
||||
// SignalInvocationBody is the request body for signaling a workflow execution.
|
||||
type SignalInvocationBody struct {
|
||||
Name string `json:"name"`
|
||||
Input any `json:"input"`
|
||||
}
|
||||
|
||||
// SignalResponse is the response from signaling a workflow execution.
|
||||
type SignalResponse struct {
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// QueryInvocationBody is the request body for querying a workflow execution.
|
||||
type QueryInvocationBody struct {
|
||||
Name string `json:"name"`
|
||||
Input any `json:"input,omitempty"`
|
||||
}
|
||||
|
||||
// QueryResponse is the response from querying a workflow execution.
|
||||
type QueryResponse struct {
|
||||
QueryName string `json:"query_name"`
|
||||
Result any `json:"result"`
|
||||
}
|
||||
|
||||
// UpdateInvocationBody is the request body for updating a workflow execution.
|
||||
type UpdateInvocationBody struct {
|
||||
Name string `json:"name"`
|
||||
Input any `json:"input,omitempty"`
|
||||
}
|
||||
|
||||
// UpdateResponse is the response from updating a workflow execution.
|
||||
type UpdateResponse struct {
|
||||
UpdateName string `json:"update_name"`
|
||||
Result any `json:"result"`
|
||||
}
|
||||
|
||||
// ResetInvocationBody is the request body for resetting a workflow execution.
|
||||
type ResetInvocationBody struct {
|
||||
EventID int `json:"event_id"`
|
||||
Reason *string `json:"reason,omitempty"`
|
||||
ExcludeSignals bool `json:"exclude_signals,omitempty"`
|
||||
ExcludeUpdates bool `json:"exclude_updates,omitempty"`
|
||||
}
|
||||
|
||||
// BatchExecutionBody is the request body for batch execution operations.
|
||||
type BatchExecutionBody struct {
|
||||
ExecutionIDs []string `json:"execution_ids"`
|
||||
}
|
||||
|
||||
// BatchExecutionResponse is the response from batch execution operations.
|
||||
type BatchExecutionResponse struct {
|
||||
Results map[string]BatchExecutionResult `json:"results,omitempty"`
|
||||
}
|
||||
|
||||
// BatchExecutionResult is the result of a single batch operation.
|
||||
type BatchExecutionResult struct {
|
||||
Status string `json:"status"`
|
||||
Error *string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// StreamParams holds query parameters for streaming workflow executions.
|
||||
type StreamParams struct {
|
||||
EventSource *EventSource
|
||||
LastEventID *string
|
||||
}
|
||||
|
||||
// TraceOTelResponse is the response from the OTel trace endpoint.
|
||||
type TraceOTelResponse struct {
|
||||
WorkflowName string `json:"workflow_name"`
|
||||
ExecutionID string `json:"execution_id"`
|
||||
RootExecutionID string `json:"root_execution_id"`
|
||||
Status *ExecutionStatus `json:"status"`
|
||||
StartTime string `json:"start_time"`
|
||||
EndTime *string `json:"end_time,omitempty"`
|
||||
Result any `json:"result"`
|
||||
DataSource string `json:"data_source"`
|
||||
ParentExecutionID *string `json:"parent_execution_id,omitempty"`
|
||||
TotalDurationMs *int `json:"total_duration_ms,omitempty"`
|
||||
OTelTraceID *string `json:"otel_trace_id,omitempty"`
|
||||
OTelTraceData any `json:"otel_trace_data,omitempty"`
|
||||
}
|
||||
|
||||
// TraceSummaryResponse is the response from the trace summary endpoint.
|
||||
type TraceSummaryResponse struct {
|
||||
WorkflowName string `json:"workflow_name"`
|
||||
ExecutionID string `json:"execution_id"`
|
||||
RootExecutionID string `json:"root_execution_id"`
|
||||
Status *ExecutionStatus `json:"status"`
|
||||
StartTime string `json:"start_time"`
|
||||
EndTime *string `json:"end_time,omitempty"`
|
||||
Result any `json:"result"`
|
||||
ParentExecutionID *string `json:"parent_execution_id,omitempty"`
|
||||
TotalDurationMs *int `json:"total_duration_ms,omitempty"`
|
||||
SpanTree any `json:"span_tree,omitempty"`
|
||||
}
|
||||
|
||||
// TraceEventsResponse is the response from the trace events endpoint.
|
||||
type TraceEventsResponse struct {
|
||||
WorkflowName string `json:"workflow_name"`
|
||||
ExecutionID string `json:"execution_id"`
|
||||
RootExecutionID string `json:"root_execution_id"`
|
||||
Status *ExecutionStatus `json:"status"`
|
||||
StartTime string `json:"start_time"`
|
||||
EndTime *string `json:"end_time,omitempty"`
|
||||
Result any `json:"result"`
|
||||
ParentExecutionID *string `json:"parent_execution_id,omitempty"`
|
||||
TotalDurationMs *int `json:"total_duration_ms,omitempty"`
|
||||
Events []json.RawMessage `json:"events,omitempty"`
|
||||
}
|
||||
|
||||
// TraceEventsParams holds query parameters for the trace events endpoint.
|
||||
type TraceEventsParams struct {
|
||||
MergeSameIDEvents *bool
|
||||
IncludeInternalEvents *bool
|
||||
}
|
||||
27
workflow/metrics.go
Normal file
27
workflow/metrics.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package workflow
|
||||
|
||||
// Metrics holds workflow performance metrics.
|
||||
type Metrics struct {
|
||||
ExecutionCount ScalarMetric `json:"execution_count"`
|
||||
SuccessCount ScalarMetric `json:"success_count"`
|
||||
ErrorCount ScalarMetric `json:"error_count"`
|
||||
AverageLatencyMs ScalarMetric `json:"average_latency_ms"`
|
||||
LatencyOverTime TimeSeriesMetric `json:"latency_over_time"`
|
||||
RetryRate ScalarMetric `json:"retry_rate"`
|
||||
}
|
||||
|
||||
// ScalarMetric holds a single numeric metric value.
|
||||
type ScalarMetric struct {
|
||||
Value float64 `json:"value"`
|
||||
}
|
||||
|
||||
// TimeSeriesMetric holds a time series of [timestamp, value] pairs.
|
||||
type TimeSeriesMetric struct {
|
||||
Value [][]float64 `json:"value"`
|
||||
}
|
||||
|
||||
// MetricsParams holds query parameters for workflow metrics.
|
||||
type MetricsParams struct {
|
||||
StartTime *string
|
||||
EndTime *string
|
||||
}
|
||||
44
workflow/registration.go
Normal file
44
workflow/registration.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package workflow
|
||||
|
||||
// Registration represents a workflow registration.
|
||||
type Registration struct {
|
||||
ID string `json:"id"`
|
||||
WorkflowID string `json:"workflow_id"`
|
||||
TaskQueue string `json:"task_queue"`
|
||||
Workflow *Workflow `json:"workflow,omitempty"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
UpdatedAt string `json:"updated_at"`
|
||||
}
|
||||
|
||||
// RegistrationListResponse is the response from listing workflow registrations.
|
||||
type RegistrationListResponse struct {
|
||||
Registrations []Registration `json:"registrations"`
|
||||
NextCursor *string `json:"next_cursor,omitempty"`
|
||||
}
|
||||
|
||||
// RegistrationListParams holds query parameters for listing registrations.
|
||||
type RegistrationListParams struct {
|
||||
WorkflowID *string
|
||||
TaskQueue *string
|
||||
ActiveOnly *bool
|
||||
IncludeShared *bool
|
||||
WorkflowSearch *string
|
||||
Archived *bool
|
||||
WithWorkflow *bool
|
||||
AvailableInChatAssistant *bool
|
||||
Limit *int
|
||||
Cursor *string
|
||||
}
|
||||
|
||||
// RegistrationGetParams holds query parameters for getting a registration.
|
||||
type RegistrationGetParams struct {
|
||||
WithWorkflow *bool
|
||||
IncludeShared *bool
|
||||
}
|
||||
|
||||
// WorkerInfo holds information about the current worker.
|
||||
type WorkerInfo struct {
|
||||
SchedulerURL string `json:"scheduler_url"`
|
||||
Namespace string `json:"namespace"`
|
||||
TLS bool `json:"tls"`
|
||||
}
|
||||
26
workflow/run.go
Normal file
26
workflow/run.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package workflow
|
||||
|
||||
// Run represents a workflow run.
|
||||
type Run struct {
|
||||
ID string `json:"id"`
|
||||
WorkflowName string `json:"workflow_name"`
|
||||
ExecutionID string `json:"execution_id"`
|
||||
Status ExecutionStatus `json:"status"`
|
||||
StartTime string `json:"start_time"`
|
||||
EndTime *string `json:"end_time,omitempty"`
|
||||
}
|
||||
|
||||
// ListRunsResponse is the response from listing workflow runs.
|
||||
type ListRunsResponse struct {
|
||||
Runs []Run `json:"runs"`
|
||||
NextPageToken *string `json:"next_page_token,omitempty"`
|
||||
}
|
||||
|
||||
// RunListParams holds query parameters for listing workflow runs.
|
||||
type RunListParams struct {
|
||||
WorkflowIdentifier *string
|
||||
Search *string
|
||||
Status *string
|
||||
PageSize *int
|
||||
NextPageToken *string
|
||||
}
|
||||
75
workflow/schedule.go
Normal file
75
workflow/schedule.go
Normal file
@@ -0,0 +1,75 @@
|
||||
package workflow
|
||||
|
||||
// ScheduleRequest is the request body for scheduling a workflow.
|
||||
type ScheduleRequest struct {
|
||||
Schedule ScheduleDefinition `json:"schedule"`
|
||||
WorkflowRegistrationID *string `json:"workflow_registration_id,omitempty"`
|
||||
WorkflowIdentifier *string `json:"workflow_identifier,omitempty"`
|
||||
ScheduleID *string `json:"schedule_id,omitempty"`
|
||||
DeploymentName *string `json:"deployment_name,omitempty"`
|
||||
}
|
||||
|
||||
// ScheduleDefinition describes when and how a workflow should be scheduled.
|
||||
type ScheduleDefinition struct {
|
||||
Input any `json:"input"`
|
||||
Calendars []ScheduleCalendar `json:"calendars,omitempty"`
|
||||
Intervals []ScheduleInterval `json:"intervals,omitempty"`
|
||||
CronExpressions []string `json:"cron_expressions,omitempty"`
|
||||
Skip []ScheduleCalendar `json:"skip,omitempty"`
|
||||
StartAt *string `json:"start_at,omitempty"`
|
||||
EndAt *string `json:"end_at,omitempty"`
|
||||
Jitter *string `json:"jitter,omitempty"`
|
||||
TimeZoneName *string `json:"time_zone_name,omitempty"`
|
||||
Policy *SchedulePolicy `json:"policy,omitempty"`
|
||||
}
|
||||
|
||||
// ScheduleCalendar defines calendar-based schedule entries.
|
||||
type ScheduleCalendar struct {
|
||||
Second []ScheduleRange `json:"second,omitempty"`
|
||||
Minute []ScheduleRange `json:"minute,omitempty"`
|
||||
Hour []ScheduleRange `json:"hour,omitempty"`
|
||||
DayOfMonth []ScheduleRange `json:"day_of_month,omitempty"`
|
||||
Month []ScheduleRange `json:"month,omitempty"`
|
||||
Year []ScheduleRange `json:"year,omitempty"`
|
||||
DayOfWeek []ScheduleRange `json:"day_of_week,omitempty"`
|
||||
Comment *string `json:"comment,omitempty"`
|
||||
}
|
||||
|
||||
// ScheduleRange defines a numeric range for calendar schedules.
|
||||
type ScheduleRange struct {
|
||||
Start int `json:"start"`
|
||||
End int `json:"end,omitempty"`
|
||||
Step int `json:"step,omitempty"`
|
||||
}
|
||||
|
||||
// ScheduleInterval defines an interval-based schedule.
|
||||
type ScheduleInterval struct {
|
||||
Every string `json:"every"`
|
||||
Offset *string `json:"offset,omitempty"`
|
||||
}
|
||||
|
||||
// SchedulePolicy controls schedule overlap and failure behavior.
|
||||
type SchedulePolicy struct {
|
||||
CatchupWindowSeconds int `json:"catchup_window_seconds,omitempty"`
|
||||
Overlap *int `json:"overlap,omitempty"`
|
||||
PauseOnFailure bool `json:"pause_on_failure,omitempty"`
|
||||
}
|
||||
|
||||
// ScheduleResponse is the response from creating a workflow schedule.
|
||||
type ScheduleResponse struct {
|
||||
ScheduleID string `json:"schedule_id"`
|
||||
}
|
||||
|
||||
// ScheduleListResponse is the response from listing workflow schedules.
|
||||
type ScheduleListResponse struct {
|
||||
Schedules []Schedule `json:"schedules"`
|
||||
}
|
||||
|
||||
// Schedule represents a workflow schedule.
|
||||
type Schedule struct {
|
||||
ScheduleID string `json:"schedule_id"`
|
||||
Definition ScheduleDefinition `json:"definition"`
|
||||
WorkflowName string `json:"workflow_name"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
UpdatedAt string `json:"updated_at"`
|
||||
}
|
||||
44
workflow/workflow.go
Normal file
44
workflow/workflow.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package workflow
|
||||
|
||||
// Workflow represents a workflow definition.
|
||||
type Workflow struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
DisplayName *string `json:"display_name,omitempty"`
|
||||
Description *string `json:"description,omitempty"`
|
||||
OwnerID string `json:"owner_id"`
|
||||
WorkspaceID string `json:"workspace_id"`
|
||||
AvailableInChatAssistant bool `json:"available_in_chat_assistant"`
|
||||
Archived bool `json:"archived"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
UpdatedAt string `json:"updated_at"`
|
||||
}
|
||||
|
||||
// WorkflowUpdateRequest is the request body for updating a workflow.
|
||||
type WorkflowUpdateRequest struct {
|
||||
DisplayName *string `json:"display_name,omitempty"`
|
||||
Description *string `json:"description,omitempty"`
|
||||
AvailableInChatAssistant *bool `json:"available_in_chat_assistant,omitempty"`
|
||||
}
|
||||
|
||||
// WorkflowListResponse is the response from listing workflows.
|
||||
type WorkflowListResponse struct {
|
||||
Workflows []Workflow `json:"workflows"`
|
||||
NextCursor *string `json:"next_cursor,omitempty"`
|
||||
}
|
||||
|
||||
// WorkflowListParams holds query parameters for listing workflows.
|
||||
type WorkflowListParams struct {
|
||||
ActiveOnly *bool
|
||||
IncludeShared *bool
|
||||
AvailableInChatAssistant *bool
|
||||
Archived *bool
|
||||
Cursor *string
|
||||
Limit *int
|
||||
}
|
||||
|
||||
// WorkflowArchiveResponse is the response from archiving/unarchiving a workflow.
|
||||
type WorkflowArchiveResponse struct {
|
||||
ID string `json:"id"`
|
||||
Archived bool `json:"archived"`
|
||||
}
|
||||
168
workflows.go
Normal file
168
workflows.go
Normal file
@@ -0,0 +1,168 @@
|
||||
package mistral
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strconv"
|
||||
|
||||
"github.com/VikingOwl91/mistral-go-sdk/workflow"
|
||||
)
|
||||
|
||||
// ListWorkflows lists workflows.
|
||||
func (c *Client) ListWorkflows(ctx context.Context, params *workflow.WorkflowListParams) (*workflow.WorkflowListResponse, error) {
|
||||
path := "/v1/workflows"
|
||||
if params != nil {
|
||||
q := url.Values{}
|
||||
if params.ActiveOnly != nil {
|
||||
q.Set("active_only", strconv.FormatBool(*params.ActiveOnly))
|
||||
}
|
||||
if params.IncludeShared != nil {
|
||||
q.Set("include_shared", strconv.FormatBool(*params.IncludeShared))
|
||||
}
|
||||
if params.AvailableInChatAssistant != nil {
|
||||
q.Set("available_in_chat_assistant", strconv.FormatBool(*params.AvailableInChatAssistant))
|
||||
}
|
||||
if params.Archived != nil {
|
||||
q.Set("archived", strconv.FormatBool(*params.Archived))
|
||||
}
|
||||
if params.Cursor != nil {
|
||||
q.Set("cursor", *params.Cursor)
|
||||
}
|
||||
if params.Limit != nil {
|
||||
q.Set("limit", strconv.Itoa(*params.Limit))
|
||||
}
|
||||
if encoded := q.Encode(); encoded != "" {
|
||||
path += "?" + encoded
|
||||
}
|
||||
}
|
||||
var resp workflow.WorkflowListResponse
|
||||
if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// GetWorkflow retrieves a workflow by identifier.
|
||||
func (c *Client) GetWorkflow(ctx context.Context, workflowIdentifier string) (*workflow.Workflow, error) {
|
||||
var resp workflow.Workflow
|
||||
if err := c.doJSON(ctx, "GET", fmt.Sprintf("/v1/workflows/%s", workflowIdentifier), nil, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// UpdateWorkflow updates a workflow.
|
||||
func (c *Client) UpdateWorkflow(ctx context.Context, workflowIdentifier string, req *workflow.WorkflowUpdateRequest) (*workflow.Workflow, error) {
|
||||
var resp workflow.Workflow
|
||||
if err := c.doJSON(ctx, "PUT", fmt.Sprintf("/v1/workflows/%s", workflowIdentifier), req, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// ArchiveWorkflow archives a workflow.
|
||||
func (c *Client) ArchiveWorkflow(ctx context.Context, workflowIdentifier string) (*workflow.WorkflowArchiveResponse, error) {
|
||||
var resp workflow.WorkflowArchiveResponse
|
||||
if err := c.doJSON(ctx, "PUT", fmt.Sprintf("/v1/workflows/%s/archive", workflowIdentifier), nil, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// UnarchiveWorkflow unarchives a workflow.
|
||||
func (c *Client) UnarchiveWorkflow(ctx context.Context, workflowIdentifier string) (*workflow.WorkflowArchiveResponse, error) {
|
||||
var resp workflow.WorkflowArchiveResponse
|
||||
if err := c.doJSON(ctx, "PUT", fmt.Sprintf("/v1/workflows/%s/unarchive", workflowIdentifier), nil, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// ExecuteWorkflow executes a workflow.
|
||||
func (c *Client) ExecuteWorkflow(ctx context.Context, workflowIdentifier string, req *workflow.ExecutionRequest) (*workflow.ExecutionResponse, error) {
|
||||
var resp workflow.ExecutionResponse
|
||||
if err := c.doJSON(ctx, "POST", fmt.Sprintf("/v1/workflows/%s/execute", workflowIdentifier), req, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// ListWorkflowRegistrations lists workflow registrations.
|
||||
func (c *Client) ListWorkflowRegistrations(ctx context.Context, params *workflow.RegistrationListParams) (*workflow.RegistrationListResponse, error) {
|
||||
path := "/v1/workflows/registrations"
|
||||
if params != nil {
|
||||
q := url.Values{}
|
||||
if params.WorkflowID != nil {
|
||||
q.Set("workflow_id", *params.WorkflowID)
|
||||
}
|
||||
if params.TaskQueue != nil {
|
||||
q.Set("task_queue", *params.TaskQueue)
|
||||
}
|
||||
if params.ActiveOnly != nil {
|
||||
q.Set("active_only", strconv.FormatBool(*params.ActiveOnly))
|
||||
}
|
||||
if params.IncludeShared != nil {
|
||||
q.Set("include_shared", strconv.FormatBool(*params.IncludeShared))
|
||||
}
|
||||
if params.WorkflowSearch != nil {
|
||||
q.Set("workflow_search", *params.WorkflowSearch)
|
||||
}
|
||||
if params.Archived != nil {
|
||||
q.Set("archived", strconv.FormatBool(*params.Archived))
|
||||
}
|
||||
if params.WithWorkflow != nil {
|
||||
q.Set("with_workflow", strconv.FormatBool(*params.WithWorkflow))
|
||||
}
|
||||
if params.AvailableInChatAssistant != nil {
|
||||
q.Set("available_in_chat_assistant", strconv.FormatBool(*params.AvailableInChatAssistant))
|
||||
}
|
||||
if params.Limit != nil {
|
||||
q.Set("limit", strconv.Itoa(*params.Limit))
|
||||
}
|
||||
if params.Cursor != nil {
|
||||
q.Set("cursor", *params.Cursor)
|
||||
}
|
||||
if encoded := q.Encode(); encoded != "" {
|
||||
path += "?" + encoded
|
||||
}
|
||||
}
|
||||
var resp workflow.RegistrationListResponse
|
||||
if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// GetWorkflowRegistration retrieves a workflow registration by ID.
|
||||
func (c *Client) GetWorkflowRegistration(ctx context.Context, registrationID string, params *workflow.RegistrationGetParams) (*workflow.Registration, error) {
|
||||
path := fmt.Sprintf("/v1/workflows/registrations/%s", registrationID)
|
||||
if params != nil {
|
||||
q := url.Values{}
|
||||
if params.WithWorkflow != nil {
|
||||
q.Set("with_workflow", strconv.FormatBool(*params.WithWorkflow))
|
||||
}
|
||||
if params.IncludeShared != nil {
|
||||
q.Set("include_shared", strconv.FormatBool(*params.IncludeShared))
|
||||
}
|
||||
if encoded := q.Encode(); encoded != "" {
|
||||
path += "?" + encoded
|
||||
}
|
||||
}
|
||||
var resp workflow.Registration
|
||||
if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// ExecuteWorkflowRegistration executes a workflow via its registration.
|
||||
//
|
||||
// Deprecated: Use ExecuteWorkflow instead. This method will be removed in a future release.
|
||||
func (c *Client) ExecuteWorkflowRegistration(ctx context.Context, registrationID string, req *workflow.ExecutionRequest) (*workflow.ExecutionResponse, error) {
|
||||
var resp workflow.ExecutionResponse
|
||||
if err := c.doJSON(ctx, "POST", fmt.Sprintf("/v1/workflows/registrations/%s/execute", registrationID), req, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
41
workflows_deployments.go
Normal file
41
workflows_deployments.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package mistral
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strconv"
|
||||
|
||||
"github.com/VikingOwl91/mistral-go-sdk/workflow"
|
||||
)
|
||||
|
||||
// ListWorkflowDeployments lists workflow deployments.
|
||||
func (c *Client) ListWorkflowDeployments(ctx context.Context, params *workflow.DeploymentListParams) (*workflow.DeploymentListResponse, error) {
|
||||
path := "/v1/workflows/deployments"
|
||||
if params != nil {
|
||||
q := url.Values{}
|
||||
if params.ActiveOnly != nil {
|
||||
q.Set("active_only", strconv.FormatBool(*params.ActiveOnly))
|
||||
}
|
||||
if params.WorkflowName != nil {
|
||||
q.Set("workflow_name", *params.WorkflowName)
|
||||
}
|
||||
if encoded := q.Encode(); encoded != "" {
|
||||
path += "?" + encoded
|
||||
}
|
||||
}
|
||||
var resp workflow.DeploymentListResponse
|
||||
if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// GetWorkflowDeployment retrieves a workflow deployment by ID.
|
||||
func (c *Client) GetWorkflowDeployment(ctx context.Context, deploymentID string) (*workflow.Deployment, error) {
|
||||
var resp workflow.Deployment
|
||||
if err := c.doJSON(ctx, "GET", fmt.Sprintf("/v1/workflows/deployments/%s", deploymentID), nil, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
57
workflows_deployments_test.go
Normal file
57
workflows_deployments_test.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package mistral
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestListWorkflowDeployments_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/v1/workflows/deployments" {
|
||||
t.Errorf("got path %s", r.URL.Path)
|
||||
}
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"deployments": []map[string]any{
|
||||
{"id": "dep-1", "name": "prod", "is_active": true, "created_at": "2026-01-01", "updated_at": "2026-01-01"},
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
resp, err := client.ListWorkflowDeployments(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(resp.Deployments) != 1 {
|
||||
t.Fatalf("got %d deployments", len(resp.Deployments))
|
||||
}
|
||||
if resp.Deployments[0].Name != "prod" {
|
||||
t.Errorf("got name %q", resp.Deployments[0].Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetWorkflowDeployment_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/v1/workflows/deployments/dep-1" {
|
||||
t.Errorf("got path %s", r.URL.Path)
|
||||
}
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"id": "dep-1", "name": "prod", "is_active": true,
|
||||
"created_at": "2026-01-01", "updated_at": "2026-01-01",
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
dep, err := client.GetWorkflowDeployment(context.Background(), "dep-1")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if dep.ID != "dep-1" {
|
||||
t.Errorf("got id %q", dep.ID)
|
||||
}
|
||||
}
|
||||
99
workflows_events.go
Normal file
99
workflows_events.go
Normal file
@@ -0,0 +1,99 @@
|
||||
package mistral
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/VikingOwl91/mistral-go-sdk/workflow"
|
||||
)
|
||||
|
||||
// StreamWorkflowEvents streams workflow events via SSE.
|
||||
func (c *Client) StreamWorkflowEvents(ctx context.Context, params *workflow.EventStreamParams) (*WorkflowEventStream, error) {
|
||||
path := "/v1/workflows/events/stream"
|
||||
if params != nil {
|
||||
q := url.Values{}
|
||||
if params.Scope != nil {
|
||||
q.Set("scope", string(*params.Scope))
|
||||
}
|
||||
if params.ActivityName != nil {
|
||||
q.Set("activity_name", *params.ActivityName)
|
||||
}
|
||||
if params.ActivityID != nil {
|
||||
q.Set("activity_id", *params.ActivityID)
|
||||
}
|
||||
if params.WorkflowName != nil {
|
||||
q.Set("workflow_name", *params.WorkflowName)
|
||||
}
|
||||
if params.WorkflowExecID != nil {
|
||||
q.Set("workflow_exec_id", *params.WorkflowExecID)
|
||||
}
|
||||
if params.RootWorkflowExecID != nil {
|
||||
q.Set("root_workflow_exec_id", *params.RootWorkflowExecID)
|
||||
}
|
||||
if params.ParentWorkflowExecID != nil {
|
||||
q.Set("parent_workflow_exec_id", *params.ParentWorkflowExecID)
|
||||
}
|
||||
if params.Stream != nil {
|
||||
q.Set("stream", *params.Stream)
|
||||
}
|
||||
if params.StartSeq != nil {
|
||||
q.Set("start_seq", strconv.Itoa(*params.StartSeq))
|
||||
}
|
||||
if params.MetadataFilters != nil {
|
||||
data, _ := json.Marshal(params.MetadataFilters)
|
||||
q.Set("metadata_filters", string(data))
|
||||
}
|
||||
if len(params.WorkflowEventTypes) > 0 {
|
||||
types := make([]string, len(params.WorkflowEventTypes))
|
||||
for i, et := range params.WorkflowEventTypes {
|
||||
types[i] = string(et)
|
||||
}
|
||||
q.Set("workflow_event_types", strings.Join(types, ","))
|
||||
}
|
||||
if params.LastEventID != nil {
|
||||
q.Set("last_event_id", *params.LastEventID)
|
||||
}
|
||||
if encoded := q.Encode(); encoded != "" {
|
||||
path += "?" + encoded
|
||||
}
|
||||
}
|
||||
resp, err := c.doStream(ctx, "GET", path, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return newWorkflowEventStream(resp.Body), nil
|
||||
}
|
||||
|
||||
// ListWorkflowEvents lists workflow events.
|
||||
func (c *Client) ListWorkflowEvents(ctx context.Context, params *workflow.EventListParams) (*workflow.EventListResponse, error) {
|
||||
path := "/v1/workflows/events/list"
|
||||
if params != nil {
|
||||
q := url.Values{}
|
||||
if params.RootWorkflowExecID != nil {
|
||||
q.Set("root_workflow_exec_id", *params.RootWorkflowExecID)
|
||||
}
|
||||
if params.WorkflowExecID != nil {
|
||||
q.Set("workflow_exec_id", *params.WorkflowExecID)
|
||||
}
|
||||
if params.WorkflowRunID != nil {
|
||||
q.Set("workflow_run_id", *params.WorkflowRunID)
|
||||
}
|
||||
if params.Limit != nil {
|
||||
q.Set("limit", strconv.Itoa(*params.Limit))
|
||||
}
|
||||
if params.Cursor != nil {
|
||||
q.Set("cursor", *params.Cursor)
|
||||
}
|
||||
if encoded := q.Encode(); encoded != "" {
|
||||
path += "?" + encoded
|
||||
}
|
||||
}
|
||||
var resp workflow.EventListResponse
|
||||
if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
40
workflows_events_test.go
Normal file
40
workflows_events_test.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package mistral
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/VikingOwl91/mistral-go-sdk/workflow"
|
||||
)
|
||||
|
||||
func TestListWorkflowEvents_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/v1/workflows/events/list" {
|
||||
t.Errorf("got path %s", r.URL.Path)
|
||||
}
|
||||
if r.URL.Query().Get("limit") != "50" {
|
||||
t.Errorf("got limit %q", r.URL.Query().Get("limit"))
|
||||
}
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"events": []map[string]any{{"event_type": "WORKFLOW_EXECUTION_STARTED"}},
|
||||
"next_cursor": "cur-1",
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
limit := 50
|
||||
resp, err := client.ListWorkflowEvents(context.Background(), &workflow.EventListParams{Limit: &limit})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(resp.Events) != 1 {
|
||||
t.Fatalf("got %d events", len(resp.Events))
|
||||
}
|
||||
if resp.NextCursor == nil || *resp.NextCursor != "cur-1" {
|
||||
t.Errorf("got cursor %v", resp.NextCursor)
|
||||
}
|
||||
}
|
||||
255
workflows_executions.go
Normal file
255
workflows_executions.go
Normal file
@@ -0,0 +1,255 @@
|
||||
package mistral
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/VikingOwl91/mistral-go-sdk/workflow"
|
||||
)
|
||||
|
||||
// GetWorkflowExecution retrieves a workflow execution by ID.
|
||||
func (c *Client) GetWorkflowExecution(ctx context.Context, executionID string) (*workflow.ExecutionResponse, error) {
|
||||
var resp workflow.ExecutionResponse
|
||||
if err := c.doJSON(ctx, "GET", fmt.Sprintf("/v1/workflows/executions/%s", executionID), nil, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// GetWorkflowExecutionHistory retrieves the history of a workflow execution.
|
||||
func (c *Client) GetWorkflowExecutionHistory(ctx context.Context, executionID string, decodePayloads *bool) (json.RawMessage, error) {
|
||||
path := fmt.Sprintf("/v1/workflows/executions/%s/history", executionID)
|
||||
if decodePayloads != nil {
|
||||
path += "?decode_payloads=" + strconv.FormatBool(*decodePayloads)
|
||||
}
|
||||
var resp json.RawMessage
|
||||
if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// StreamWorkflowExecution streams events for a workflow execution via SSE.
|
||||
func (c *Client) StreamWorkflowExecution(ctx context.Context, executionID string, params *workflow.StreamParams) (*WorkflowEventStream, error) {
|
||||
path := fmt.Sprintf("/v1/workflows/executions/%s/stream", executionID)
|
||||
if params != nil {
|
||||
q := url.Values{}
|
||||
if params.EventSource != nil {
|
||||
q.Set("event_source", string(*params.EventSource))
|
||||
}
|
||||
if params.LastEventID != nil {
|
||||
q.Set("last_event_id", *params.LastEventID)
|
||||
}
|
||||
if encoded := q.Encode(); encoded != "" {
|
||||
path += "?" + encoded
|
||||
}
|
||||
}
|
||||
resp, err := c.doStream(ctx, "GET", path, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return newWorkflowEventStream(resp.Body), nil
|
||||
}
|
||||
|
||||
// SignalWorkflowExecution sends a signal to a workflow execution.
|
||||
func (c *Client) SignalWorkflowExecution(ctx context.Context, executionID string, req *workflow.SignalInvocationBody) (*workflow.SignalResponse, error) {
|
||||
var resp workflow.SignalResponse
|
||||
if err := c.doJSON(ctx, "POST", fmt.Sprintf("/v1/workflows/executions/%s/signals", executionID), req, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// QueryWorkflowExecution queries a workflow execution.
|
||||
func (c *Client) QueryWorkflowExecution(ctx context.Context, executionID string, req *workflow.QueryInvocationBody) (*workflow.QueryResponse, error) {
|
||||
var resp workflow.QueryResponse
|
||||
if err := c.doJSON(ctx, "POST", fmt.Sprintf("/v1/workflows/executions/%s/queries", executionID), req, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// UpdateWorkflowExecution sends an update to a workflow execution.
|
||||
func (c *Client) UpdateWorkflowExecution(ctx context.Context, executionID string, req *workflow.UpdateInvocationBody) (*workflow.UpdateResponse, error) {
|
||||
var resp workflow.UpdateResponse
|
||||
if err := c.doJSON(ctx, "POST", fmt.Sprintf("/v1/workflows/executions/%s/updates", executionID), req, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// TerminateWorkflowExecution terminates a workflow execution.
|
||||
func (c *Client) TerminateWorkflowExecution(ctx context.Context, executionID string) error {
|
||||
resp, err := c.do(ctx, "POST", fmt.Sprintf("/v1/workflows/executions/%s/terminate", executionID), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode >= 400 {
|
||||
return parseAPIError(resp)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CancelWorkflowExecution cancels a workflow execution.
|
||||
func (c *Client) CancelWorkflowExecution(ctx context.Context, executionID string) error {
|
||||
resp, err := c.do(ctx, "POST", fmt.Sprintf("/v1/workflows/executions/%s/cancel", executionID), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode >= 400 {
|
||||
return parseAPIError(resp)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ResetWorkflowExecution resets a workflow execution to a specific event.
|
||||
func (c *Client) ResetWorkflowExecution(ctx context.Context, executionID string, req *workflow.ResetInvocationBody) error {
|
||||
return c.doJSON(ctx, "POST", fmt.Sprintf("/v1/workflows/executions/%s/reset", executionID), req, nil)
|
||||
}
|
||||
|
||||
// BatchCancelWorkflowExecutions cancels multiple workflow executions.
|
||||
func (c *Client) BatchCancelWorkflowExecutions(ctx context.Context, req *workflow.BatchExecutionBody) (*workflow.BatchExecutionResponse, error) {
|
||||
var resp workflow.BatchExecutionResponse
|
||||
if err := c.doJSON(ctx, "POST", "/v1/workflows/executions/cancel", req, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// BatchTerminateWorkflowExecutions terminates multiple workflow executions.
|
||||
func (c *Client) BatchTerminateWorkflowExecutions(ctx context.Context, req *workflow.BatchExecutionBody) (*workflow.BatchExecutionResponse, error) {
|
||||
var resp workflow.BatchExecutionResponse
|
||||
if err := c.doJSON(ctx, "POST", "/v1/workflows/executions/terminate", req, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// GetWorkflowExecutionTraceOTel retrieves the OpenTelemetry trace for a workflow execution.
|
||||
func (c *Client) GetWorkflowExecutionTraceOTel(ctx context.Context, executionID string) (*workflow.TraceOTelResponse, error) {
|
||||
var resp workflow.TraceOTelResponse
|
||||
if err := c.doJSON(ctx, "GET", fmt.Sprintf("/v1/workflows/executions/%s/trace/otel", executionID), nil, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// GetWorkflowExecutionTraceSummary retrieves the trace summary for a workflow execution.
|
||||
func (c *Client) GetWorkflowExecutionTraceSummary(ctx context.Context, executionID string) (*workflow.TraceSummaryResponse, error) {
|
||||
var resp workflow.TraceSummaryResponse
|
||||
if err := c.doJSON(ctx, "GET", fmt.Sprintf("/v1/workflows/executions/%s/trace/summary", executionID), nil, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// GetWorkflowExecutionTraceEvents retrieves the trace events for a workflow execution.
|
||||
func (c *Client) GetWorkflowExecutionTraceEvents(ctx context.Context, executionID string, params *workflow.TraceEventsParams) (*workflow.TraceEventsResponse, error) {
|
||||
path := fmt.Sprintf("/v1/workflows/executions/%s/trace/events", executionID)
|
||||
if params != nil {
|
||||
q := url.Values{}
|
||||
if params.MergeSameIDEvents != nil {
|
||||
q.Set("merge_same_id_events", strconv.FormatBool(*params.MergeSameIDEvents))
|
||||
}
|
||||
if params.IncludeInternalEvents != nil {
|
||||
q.Set("include_internal_events", strconv.FormatBool(*params.IncludeInternalEvents))
|
||||
}
|
||||
if encoded := q.Encode(); encoded != "" {
|
||||
path += "?" + encoded
|
||||
}
|
||||
}
|
||||
var resp workflow.TraceEventsResponse
|
||||
if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// WorkflowEventStream wraps the generic Stream to provide typed workflow events
|
||||
// with StreamPayload envelope metadata.
|
||||
type WorkflowEventStream struct {
|
||||
stream *Stream[json.RawMessage]
|
||||
event workflow.Event
|
||||
payload *workflow.StreamPayload
|
||||
err error
|
||||
}
|
||||
|
||||
func newWorkflowEventStream(body readCloser) *WorkflowEventStream {
|
||||
return &WorkflowEventStream{
|
||||
stream: newStream[json.RawMessage](body),
|
||||
}
|
||||
}
|
||||
|
||||
// Next advances to the next event. Returns false when done or on error.
|
||||
func (s *WorkflowEventStream) Next() bool {
|
||||
if s.err != nil {
|
||||
return false
|
||||
}
|
||||
if !s.stream.Next() {
|
||||
s.err = s.stream.Err()
|
||||
return false
|
||||
}
|
||||
var payload workflow.StreamPayload
|
||||
if err := json.Unmarshal(s.stream.Current(), &payload); err != nil {
|
||||
s.err = fmt.Errorf("mistral: decode workflow stream payload: %w", err)
|
||||
return false
|
||||
}
|
||||
event, err := workflow.UnmarshalEvent(payload.Data)
|
||||
if err != nil {
|
||||
s.err = err
|
||||
return false
|
||||
}
|
||||
s.event = event
|
||||
s.payload = &payload
|
||||
return true
|
||||
}
|
||||
|
||||
// Current returns the most recently read workflow event.
|
||||
func (s *WorkflowEventStream) Current() workflow.Event { return s.event }
|
||||
|
||||
// CurrentPayload returns the full StreamPayload envelope of the current event.
|
||||
func (s *WorkflowEventStream) CurrentPayload() *workflow.StreamPayload { return s.payload }
|
||||
|
||||
// Err returns any error encountered during streaming.
|
||||
func (s *WorkflowEventStream) Err() error { return s.err }
|
||||
|
||||
// Close releases the underlying connection.
|
||||
func (s *WorkflowEventStream) Close() error { return s.stream.Close() }
|
||||
|
||||
// ExecuteWorkflowAndWait executes a workflow and polls until completion.
|
||||
func (c *Client) ExecuteWorkflowAndWait(ctx context.Context, workflowIdentifier string, req *workflow.ExecutionRequest) (*workflow.ExecutionResponse, error) {
|
||||
execResp, err := c.ExecuteWorkflow(ctx, workflowIdentifier, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for {
|
||||
if isTerminal(execResp.Status) {
|
||||
return execResp, nil
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
}
|
||||
execResp, err = c.GetWorkflowExecution(ctx, execResp.ExecutionID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func isTerminal(s workflow.ExecutionStatus) bool {
|
||||
switch s {
|
||||
case workflow.ExecutionCompleted, workflow.ExecutionFailed,
|
||||
workflow.ExecutionCanceled, workflow.ExecutionTerminated,
|
||||
workflow.ExecutionTimedOut:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
266
workflows_executions_test.go
Normal file
266
workflows_executions_test.go
Normal file
@@ -0,0 +1,266 @@
|
||||
package mistral
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/VikingOwl91/mistral-go-sdk/workflow"
|
||||
)
|
||||
|
||||
func TestGetWorkflowExecution_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/v1/workflows/executions/exec-1" {
|
||||
t.Errorf("got path %s", r.URL.Path)
|
||||
}
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"workflow_name": "my-flow", "execution_id": "exec-1",
|
||||
"root_execution_id": "exec-1", "status": "COMPLETED",
|
||||
"start_time": "2026-01-01T00:00:00Z",
|
||||
"end_time": "2026-01-01T00:01:00Z",
|
||||
"result": map[string]any{"answer": 42},
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
resp, err := client.GetWorkflowExecution(context.Background(), "exec-1")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp.Status != workflow.ExecutionCompleted {
|
||||
t.Errorf("got status %q", resp.Status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignalWorkflowExecution_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != "POST" {
|
||||
t.Errorf("got method %s", r.Method)
|
||||
}
|
||||
if r.URL.Path != "/v1/workflows/executions/exec-1/signals" {
|
||||
t.Errorf("got path %s", r.URL.Path)
|
||||
}
|
||||
var body map[string]any
|
||||
json.NewDecoder(r.Body).Decode(&body)
|
||||
if body["name"] != "approval" {
|
||||
t.Errorf("got name %v", body["name"])
|
||||
}
|
||||
w.WriteHeader(http.StatusAccepted)
|
||||
json.NewEncoder(w).Encode(map[string]any{"message": "Signal accepted"})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
resp, err := client.SignalWorkflowExecution(context.Background(), "exec-1", &workflow.SignalInvocationBody{
|
||||
Name: "approval",
|
||||
Input: map[string]any{"approved": true},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp.Message != "Signal accepted" {
|
||||
t.Errorf("got message %q", resp.Message)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTerminateWorkflowExecution_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != "POST" {
|
||||
t.Errorf("got method %s", r.Method)
|
||||
}
|
||||
if r.URL.Path != "/v1/workflows/executions/exec-1/terminate" {
|
||||
t.Errorf("got path %s", r.URL.Path)
|
||||
}
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
err := client.TerminateWorkflowExecution(context.Background(), "exec-1")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBatchCancelWorkflowExecutions_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != "POST" {
|
||||
t.Errorf("got method %s", r.Method)
|
||||
}
|
||||
if r.URL.Path != "/v1/workflows/executions/cancel" {
|
||||
t.Errorf("got path %s", r.URL.Path)
|
||||
}
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"results": map[string]any{
|
||||
"exec-1": map[string]any{"status": "success"},
|
||||
"exec-2": map[string]any{"status": "failure", "error": "not found"},
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
resp, err := client.BatchCancelWorkflowExecutions(context.Background(), &workflow.BatchExecutionBody{
|
||||
ExecutionIDs: []string{"exec-1", "exec-2"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp.Results["exec-1"].Status != "success" {
|
||||
t.Errorf("got exec-1 status %q", resp.Results["exec-1"].Status)
|
||||
}
|
||||
if resp.Results["exec-2"].Error == nil || *resp.Results["exec-2"].Error != "not found" {
|
||||
t.Errorf("expected exec-2 error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamWorkflowExecution_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != "GET" {
|
||||
t.Errorf("got method %s", r.Method)
|
||||
}
|
||||
if r.URL.Path != "/v1/workflows/executions/exec-1/stream" {
|
||||
t.Errorf("got path %s", r.URL.Path)
|
||||
}
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
flusher, _ := w.(http.Flusher)
|
||||
|
||||
payloads := []map[string]any{
|
||||
{
|
||||
"stream": "events",
|
||||
"data": map[string]any{
|
||||
"event_id": "evt-1", "event_timestamp": 1711929600000000000,
|
||||
"root_workflow_exec_id": "exec-1", "parent_workflow_exec_id": nil,
|
||||
"workflow_exec_id": "exec-1", "workflow_run_id": "run-1",
|
||||
"workflow_name": "my-flow", "event_type": "WORKFLOW_EXECUTION_STARTED",
|
||||
"attributes": map[string]any{},
|
||||
},
|
||||
"workflow_context": map[string]any{
|
||||
"namespace": "default", "workflow_name": "my-flow", "workflow_exec_id": "exec-1",
|
||||
},
|
||||
"broker_sequence": 1,
|
||||
},
|
||||
{
|
||||
"stream": "events",
|
||||
"data": map[string]any{
|
||||
"event_id": "evt-2", "event_timestamp": 1711929601000000000,
|
||||
"root_workflow_exec_id": "exec-1", "parent_workflow_exec_id": nil,
|
||||
"workflow_exec_id": "exec-1", "workflow_run_id": "run-1",
|
||||
"workflow_name": "my-flow", "event_type": "WORKFLOW_EXECUTION_COMPLETED",
|
||||
"attributes": map[string]any{"result": map[string]any{"value": 42, "type": "json"}},
|
||||
},
|
||||
"workflow_context": map[string]any{
|
||||
"namespace": "default", "workflow_name": "my-flow", "workflow_exec_id": "exec-1",
|
||||
},
|
||||
"broker_sequence": 2,
|
||||
},
|
||||
}
|
||||
for _, p := range payloads {
|
||||
data, _ := json.Marshal(p)
|
||||
fmt.Fprintf(w, "data: %s\n\n", data)
|
||||
flusher.Flush()
|
||||
}
|
||||
fmt.Fprint(w, "data: [DONE]\n\n")
|
||||
flusher.Flush()
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
stream, err := client.StreamWorkflowExecution(context.Background(), "exec-1", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer stream.Close()
|
||||
|
||||
var events []workflow.Event
|
||||
var lastPayload *workflow.StreamPayload
|
||||
for stream.Next() {
|
||||
events = append(events, stream.Current())
|
||||
lastPayload = stream.CurrentPayload()
|
||||
}
|
||||
if stream.Err() != nil {
|
||||
t.Fatal(stream.Err())
|
||||
}
|
||||
if len(events) != 2 {
|
||||
t.Fatalf("got %d events, want 2", len(events))
|
||||
}
|
||||
if _, ok := events[0].(*workflow.WorkflowExecutionStartedEvent); !ok {
|
||||
t.Errorf("expected *WorkflowExecutionStartedEvent, got %T", events[0])
|
||||
}
|
||||
if _, ok := events[1].(*workflow.WorkflowExecutionCompletedEvent); !ok {
|
||||
t.Errorf("expected *WorkflowExecutionCompletedEvent, got %T", events[1])
|
||||
}
|
||||
if lastPayload.WorkflowContext.WorkflowName != "my-flow" {
|
||||
t.Errorf("got workflow context name %q", lastPayload.WorkflowContext.WorkflowName)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetWorkflowExecutionTraceOTel_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/v1/workflows/executions/exec-1/trace/otel" {
|
||||
t.Errorf("got path %s", r.URL.Path)
|
||||
}
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"workflow_name": "my-flow", "execution_id": "exec-1",
|
||||
"root_execution_id": "exec-1", "status": "COMPLETED",
|
||||
"start_time": "2026-01-01T00:00:00Z", "data_source": "temporal",
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
resp, err := client.GetWorkflowExecutionTraceOTel(context.Background(), "exec-1")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp.DataSource != "temporal" {
|
||||
t.Errorf("got data_source %q", resp.DataSource)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteWorkflowAndWait_Success(t *testing.T) {
|
||||
calls := 0
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case r.Method == "POST" && r.URL.Path == "/v1/workflows/wf-1/execute":
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"workflow_name": "my-flow", "execution_id": "exec-1",
|
||||
"root_execution_id": "exec-1", "status": "RUNNING",
|
||||
"start_time": "2026-01-01T00:00:00Z",
|
||||
})
|
||||
case r.Method == "GET" && r.URL.Path == "/v1/workflows/executions/exec-1":
|
||||
calls++
|
||||
status := "RUNNING"
|
||||
if calls >= 2 {
|
||||
status = "COMPLETED"
|
||||
}
|
||||
resp := map[string]any{
|
||||
"workflow_name": "my-flow", "execution_id": "exec-1",
|
||||
"root_execution_id": "exec-1", "status": status,
|
||||
"start_time": "2026-01-01T00:00:00Z",
|
||||
}
|
||||
if status == "COMPLETED" {
|
||||
resp["result"] = map[string]any{"answer": 42}
|
||||
}
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
default:
|
||||
t.Errorf("unexpected %s %s", r.Method, r.URL.Path)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
resp, err := client.ExecuteWorkflowAndWait(context.Background(), "wf-1", &workflow.ExecutionRequest{
|
||||
Input: map[string]any{"prompt": "hello"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp.Status != workflow.ExecutionCompleted {
|
||||
t.Errorf("got status %q", resp.Status)
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user