4 Commits

Author SHA1 Message Date
aa5c53c407 feat: v1.1.0 — sync with upstream Python SDK v2.1.3
Add Connectors, Audio Speech/Voices, Audio Realtime types,
and Observability (beta). 41 new service methods, 116 total.

Breaking: ListModels and UploadFile signatures changed
(pass nil for previous behavior).
2026-03-24 09:07:03 +01:00
b1f0fc4907 feat: v1.0.0 stable release — tracks upstream Python SDK v2.0.4 2026-03-17 11:34:09 +01:00
94a938b733 docs: add pkg.go.dev badge to README 2026-03-17 11:31:19 +01:00
52231df17b feat: v0.2.0 — sync with upstream Python SDK v2.0.4
Add ToolReferenceChunk, ToolFileChunk, BuiltInConnector enum,
ReferenceID union type (int|string), GuardrailConfig with v1/v2
moderation, ConnectorTool for custom connectors, and guardrails
field on chat/agents/conversation requests.

Add AudioTranscriptionRealtime and AudioSpeech to ModelCapabilities.

Move GuardrailConfig from agents/ to chat/ as shared base type.
Remove bundled OpenAPI spec; SDK now tracks upstream Python SDK.

BREAKING: ReferenceChunk.ReferenceIDs changed from []int to
[]ReferenceID. Use IntRef(n) / StringRef(s) constructors.
2026-03-17 11:23:58 +01:00
50 changed files with 3390 additions and 9570 deletions

1
.gitignore vendored
View File

@@ -29,3 +29,4 @@ vendor/
# OS
.DS_Store
Thumbs.db
client-python/

View File

@@ -1,3 +1,82 @@
## v1.1.0 — 2026-03-24
Upstream sync with Python SDK v2.1.3. Adds Connectors, Audio Speech/Voices, and Observability (beta).
### Breaking Changes
- **`ListModels`** signature changed from `(ctx)` to `(ctx, *model.ListParams)`.
Pass `nil` for previous behavior. The new `ListParams` supports `Provider` and
`Model` query filters.
- **`UploadFile`** signature changed from `(ctx, filename, reader, purpose)` to
`(ctx, filename, reader, *file.UploadParams)`. The new `UploadParams` struct
holds `Purpose`, `Expiry`, and `Visibility` fields.
### Added
- **`ReasoningEffort`** field on `chat.CompletionRequest` and
`agents.CompletionRequest` — controls reasoning effort (`"none"`, `"high"`).
- **Connectors API** (new `connector/` package) — `CreateConnector`,
`ListConnectors`, `GetConnector`, `UpdateConnector`, `DeleteConnector`,
`GetConnectorAuthURL`, `ListConnectorTools`, `CallConnectorTool`.
- **Audio Speech (TTS)** — `Speech`, `SpeechStream` with `SpeechStream` typed
wrapper, `SpeechOutputFormat` enum (pcm/wav/mp3/flac/opus).
- **Audio Voices** — `ListVoices`, `CreateVoice`, `GetVoice`, `UpdateVoice`,
`DeleteVoice`, `GetVoiceSampleAudio`.
- **Audio Realtime types** — `AudioEncoding`, `AudioFormat`, `RealtimeSession`,
and WebSocket message types in `audio/realtime.go`. No WebSocket client yet
(would require adding a dependency).
- **Observability API** (new `observability/` package, beta) — campaigns,
chat completion events, judges, datasets, records, and import tasks.
33 service methods total.
- **`file.Visibility`** enum — `shared_global`, `shared_org`,
`shared_workspace`, `private`.
- **`model.ListParams`** — filter models by `Provider` and `Model`.
## v1.0.0 — 2026-03-17
Stable release. Tracks upstream Python SDK v2.0.4.
No API changes from v0.2.0. This release signals that the SDK surface is
stable and follows Go module semver conventions — breaking changes will only
occur in future major versions.
## v0.2.0 — 2026-03-17
Sync with upstream Python SDK v2.0.4. Upstream reference changed from OpenAPI
spec to official Python SDK (https://github.com/mistralai/client-python).
### Breaking Changes
- **`ReferenceChunk.ReferenceIDs`** changed from `[]int` to `[]ReferenceID`.
The API now returns mixed integer and string identifiers. Use `IntRef(n)` and
`StringRef(s)` constructors; read back with `.Int()` and `.IsString()`.
- **`agents.GuardrailConfig` and `agents.ModerationLLMV1Config`** moved to
`chat.GuardrailConfig` and `chat.ModerationLLMV1Config`. The types are now
shared across chat, agents, and conversation packages.
### Added
- **`ToolReferenceChunk`** — new content chunk type for tool references
returned by built-in connectors (web search, code interpreter, etc.).
- **`ToolFileChunk`** — new content chunk type for tool-generated files.
- **`BuiltInConnector`** constants — `ConnectorWebSearch`,
`ConnectorWebSearchPremium`, `ConnectorCodeInterpreter`,
`ConnectorImageGeneration`, `ConnectorDocumentLibrary`.
- **`ModerationLLMV2Config`** — v2 moderation guardrail with split
`dangerous`/`criminal` categories and new `jailbreaking` category.
- **`GuardrailConfig`** on `chat.CompletionRequest`,
`agents.CompletionRequest`, `conversation.StartRequest`, and
`conversation.RestartRequest`.
- **`ConnectorTool`** — new agent tool type for custom connectors with
`ConnectorAuth` (api-key / oauth2-token authorization).
- **`ModelCapabilities`** — added `AudioTranscriptionRealtime` and
`AudioSpeech` fields.
### Removed
- Bundled `docs/openapi.yaml`. The SDK now tracks the upstream Python SDK
directly as its reference implementation.
## v0.1.0 — 2026-03-05
Initial release.

89
CLAUDE.md Normal file
View File

@@ -0,0 +1,89 @@
# CLAUDE.md
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
## Project
Idiomatic Go SDK for the Mistral AI API. Module path: `somegit.dev/vikingowl/mistral-go-sdk`. Requires Go 1.26+. Zero external dependencies (stdlib only). Tracks the upstream [Mistral Python SDK](https://github.com/mistralai/client-python) as reference for API surface and type definitions.
## Repository layout
- **Working directory**: `mistral-go-sdk/` — the Go SDK source. All development happens here.
- **`../client-python/`**: Clone of the upstream Mistral Python SDK. Read-only reference — pull/update it when checking for upstream API changes, but never modify it.
## Commands
```bash
# Run all unit tests
go test ./...
# Run a single test
go test -run TestChatComplete_Success
# Run integration tests (requires MISTRAL_API_KEY env var)
go test -tags=integration ./...
# Vet and build
go vet ./...
go build ./...
```
No Makefile, linter config, or code generation tooling — standard `go test` / `go vet` / `go build`.
## Architecture
### Two-layer design: types in sub-packages, methods on `*Client`
Sub-packages (`chat/`, `agents/`, `conversation/`, `embedding/`, `model/`, `file/`, `finetune/`, `batch/`, `ocr/`, `audio/`, `library/`, `moderation/`, `classification/`, `fim/`) are **types-only** — they define request/response structs and enums but contain no HTTP logic. All service methods live on `*Client` in the root package, prefix-namespaced by domain (e.g. `ChatComplete`, `AgentsComplete`, `CreateFineTuningJob`, `UploadFile`).
### HTTP internals (request.go)
All HTTP flows route through a small set of unexported helpers on `*Client`:
- `do()` — raw HTTP with auth headers + retry
- `doJSON()` — JSON marshal request → `do()` → unmarshal response
- `doStream()` — JSON request → raw `*http.Response` for SSE
- `doMultipart()` / `doMultipartStream()` — multipart file upload variants
- `doRetry()` — retry loop with exponential backoff + jitter + `Retry-After` parsing
### Streaming
Generic `Stream[T]` type wraps SSE (`sseReader`) with `Next()`/`Current()`/`Err()`/`Close()` iterator pattern. Typed wrappers `EventStream` (conversations) and `AudioStream` (transcription) unmarshal `json.RawMessage` into domain-specific event types.
### Sealed interfaces for discriminated unions
Polymorphic API types use **sealed interfaces** with unexported marker methods:
- `chat.Message` (marker: `isMessage()`) — `SystemMessage`, `UserMessage`, `AssistantMessage`, `ToolMessage`
- `chat.ContentChunk` (marker: `contentChunk()`) — `TextChunk`, `ImageURLChunk`, `DocumentURLChunk`, `FileChunk`, `ReferenceChunk`, `ThinkChunk`, `AudioChunk`, `ToolReferenceChunk`, `ToolFileChunk`
- `agents.AgentTool` (marker: `agentToolType()`) — `FunctionTool`, `WebSearchTool`, `CodeInterpreterTool`, `ConnectorTool`, etc.
- `conversation.Event` — conversation streaming events
Each has an `Unknown*` variant so the SDK doesn't break on new API types. Each has a `Unmarshal*` dispatch function that probes a `type`/`role` discriminator field.
### Custom JSON patterns
Several types require non-trivial marshal/unmarshal:
- **Type alias trick** — `type alias T` inside `MarshalJSON` to avoid infinite recursion when injecting a `type`/`role` discriminator field.
- **`json:"-"` + custom MarshalJSON** — `CompletionRequest.Messages` (and `stream`) are excluded from default marshaling and injected via custom `MarshalJSON`.
- **Union types** — `Content` handles `string | null | []ContentChunk`; `ToolChoice` handles `string | object`; `ImageURL` handles `string | object`; `FunctionCall.Arguments` handles `string | object`; `ReferenceID` handles `int | string` with type preservation.
- **Probe struct pattern** — `Unmarshal*` functions decode only the discriminator field first, then dispatch to the concrete type.
### Shared types in `chat/`
`GuardrailConfig`, `ModerationLLMV1Config`, `ModerationLLMV2Config` live in `chat/` because it's the base types package imported by both `agents/` and `conversation/`. This avoids import cycles.
### Error handling
`APIError` in `error.go` with sentinel checkers: `IsNotFound()`, `IsRateLimit()`, `IsAuth()`. All use `errors.As` for unwrapping.
## Testing patterns
- Unit tests use `httptest.NewServer` with inline handlers to mock the Mistral API. Client is pointed at the test server via `WithBaseURL(server.URL)`.
- Integration tests are behind `//go:build integration` build tag and require `MISTRAL_API_KEY`.
- Tests use stdlib `testing` only — no third-party test frameworks.
## Adding a new API endpoint
1. Define request/response types in the appropriate sub-package (or create a new one with a `doc.go`).
2. Add a method on `*Client` in the root package. Use `doJSON` for standard request/response, `doStream` for SSE, `doMultipart` for file uploads.
3. Add unit tests with `httptest.NewServer`.
4. If the endpoint supports streaming, return `*Stream[T]` and call `EnableStream()` on the request before sending.

View File

@@ -3,6 +3,7 @@
The most complete Go client for the [Mistral AI API](https://docs.mistral.ai/).
<!-- Badges -->
[![Go Reference](https://pkg.go.dev/badge/somegit.dev/vikingowl/mistral-go-sdk.svg)](https://pkg.go.dev/somegit.dev/vikingowl/mistral-go-sdk)
![Go Version](https://img.shields.io/badge/go-1.26-blue)
![License](https://img.shields.io/badge/license-MIT-green)
@@ -10,7 +11,7 @@ The most complete Go client for the [Mistral AI API](https://docs.mistral.ai/).
**Zero dependencies.** The entire SDK — including tests — uses only the Go standard library. No `go.sum`, no transitive dependency tree to audit, no version conflicts, no supply chain risk.
**Full API coverage.** 75 methods across every Mistral endpoint — including Conversations, Agents CRUD, Libraries, OCR, Audio, Fine-tuning, and Batch Jobs. No other Go SDK covers Conversations or Agents.
**Full API coverage.** 116 methods across every Mistral endpoint — including Connectors, Audio Speech/Voices, Conversations, Agents CRUD, Libraries, OCR, Observability, Fine-tuning, and Batch Jobs. No other Go SDK covers Conversations, Connectors, or Observability.
**Typed streaming.** A generic pull-based `Stream[T]` iterator — no channels, no goroutines, no leaks. Just `Next()` / `Current()` / `Err()` / `Close()`.
@@ -18,7 +19,7 @@ The most complete Go client for the [Mistral AI API](https://docs.mistral.ai/).
**Hand-written, not generated.** Idiomatic Go with sealed interfaces, discriminated unions, and functional options — not a Speakeasy/OpenAPI auto-gen dump with `any` everywhere.
**Test-driven.** 126 tests with race detection clean. Every endpoint tested against mock servers; integration tests against the real API.
**Test-driven.** 193 tests with race detection clean. Every endpoint tested against mock servers; integration tests against the real API.
## Install
@@ -131,7 +132,7 @@ for stream.Next() {
## API Coverage
75 public methods on `Client`, grouped by domain:
116 public methods on `Client`, grouped by domain:
| Domain | Methods |
|--------|---------|
@@ -139,6 +140,7 @@ for stream.Next() {
| **FIM** | `FIMComplete`, `FIMCompleteStream` |
| **Agents (completions)** | `AgentsComplete`, `AgentsCompleteStream` |
| **Agents (CRUD)** | `CreateAgent`, `ListAgents`, `GetAgent`, `UpdateAgent`, `DeleteAgent`, `UpdateAgentVersion`, `ListAgentVersions`, `GetAgentVersion`, `SetAgentAlias`, `ListAgentAliases`, `DeleteAgentAlias` |
| **Connectors** | `CreateConnector`, `ListConnectors`, `GetConnector`, `UpdateConnector`, `DeleteConnector`, `GetConnectorAuthURL`, `ListConnectorTools`, `CallConnectorTool` |
| **Conversations** | `StartConversation`, `StartConversationStream`, `AppendConversation`, `AppendConversationStream`, `RestartConversation`, `RestartConversationStream`, `GetConversation`, `ListConversations`, `DeleteConversation`, `GetConversationHistory`, `GetConversationMessages` |
| **Models** | `ListModels`, `GetModel`, `DeleteModel` |
| **Files** | `UploadFile`, `ListFiles`, `GetFile`, `DeleteFile`, `GetFileContent`, `GetFileURL` |
@@ -146,10 +148,16 @@ for stream.Next() {
| **Fine-tuning** | `CreateFineTuningJob`, `ListFineTuningJobs`, `GetFineTuningJob`, `CancelFineTuningJob`, `StartFineTuningJob`, `UpdateFineTunedModel`, `ArchiveFineTunedModel`, `UnarchiveFineTunedModel` |
| **Batch** | `CreateBatchJob`, `ListBatchJobs`, `GetBatchJob`, `CancelBatchJob` |
| **OCR** | `OCR` |
| **Audio** | `Transcribe`, `TranscribeStream` |
| **Audio (transcription)** | `Transcribe`, `TranscribeStream` |
| **Audio (speech)** | `Speech`, `SpeechStream` |
| **Audio (voices)** | `ListVoices`, `CreateVoice`, `GetVoice`, `UpdateVoice`, `DeleteVoice`, `GetVoiceSampleAudio` |
| **Libraries** | `CreateLibrary`, `ListLibraries`, `GetLibrary`, `UpdateLibrary`, `DeleteLibrary`, `UploadDocument`, `ListDocuments`, `GetDocument`, `UpdateDocument`, `DeleteDocument`, `GetDocumentTextContent`, `GetDocumentStatus`, `GetDocumentSignedURL`, `GetDocumentExtractedTextSignedURL`, `ReprocessDocument`, `ListLibrarySharing`, `ShareLibrary`, `UnshareLibrary` |
| **Moderation** | `Moderate`, `ModerateChat` |
| **Classification** | `Classify`, `ClassifyChat` |
| **Observability (campaigns)** | `CreateCampaign`, `ListCampaigns`, `GetCampaign`, `DeleteCampaign`, `GetCampaignStatus`, `ListCampaignEvents` |
| **Observability (events)** | `SearchChatCompletionEvents`, `SearchChatCompletionEventIDs`, `GetChatCompletionEvent`, `GetSimilarChatCompletionEvents`, `JudgeChatCompletionEvent` |
| **Observability (judges)** | `CreateJudge`, `ListJudges`, `GetJudge`, `UpdateJudge`, `DeleteJudge`, `JudgeConversation` |
| **Observability (datasets)** | `CreateDataset`, `ListDatasets`, `GetDataset`, `UpdateDataset`, `DeleteDataset`, `ExportDatasetToJSONL`, `ListDatasetRecords`, `CreateDatasetRecord`, `GetDatasetRecord`, `UpdateDatasetRecordPayload`, `UpdateDatasetRecordProperties`, `DeleteDatasetRecord`, `BulkDeleteDatasetRecords`, `JudgeDatasetRecord`, `ImportDatasetFromCampaign`, `ImportDatasetFromExplorer`, `ImportDatasetFromFile`, `ImportDatasetFromPlayground`, `ImportDatasetFromDataset`, `ListDatasetTasks`, `GetDatasetTask` |
## Comparison
@@ -162,11 +170,13 @@ There is no official Go SDK from Mistral AI (only Python and TypeScript). The ma
| Embeddings | Yes | Yes | Yes | Yes |
| Tool calling | Yes | No | No | No |
| Agents (completions + CRUD) | Yes | No | No | No |
| Connectors (MCP) | Yes | No | No | No |
| Conversations API | Yes | No | No | No |
| Libraries / Documents | Yes | No | No | No |
| Fine-tuning / Batch | Yes | No | No | No |
| OCR | Yes | No | No | Yes |
| Audio transcription | Yes | No | No | No |
| Audio (transcription + TTS + voices) | Yes | No | No | No |
| Observability (beta) | Yes | No | No | No |
| Moderation / Classification | Yes | No | No | No |
| Vision (multimodal) | Yes | No | No | Yes |
| Zero dependencies | Yes | test-only (testify) | test-only (testify) | test-only (testify) |
@@ -213,6 +223,16 @@ if err != nil {
}
```
## Upstream Reference
This SDK tracks the [official Mistral Python SDK](https://github.com/mistralai/client-python)
as its upstream reference for API surface and type definitions.
| SDK Version | Upstream Python SDK |
|-------------|---------------------|
| v1.1.0 | v2.1.3 |
| v1.0.0 | v2.0.4 |
## License
[MIT](LICENSE)

View File

@@ -3,6 +3,8 @@ package agents
import (
"encoding/json"
"fmt"
"somegit.dev/vikingowl/mistral-go-sdk/chat"
)
// AgentTool is a sealed interface for agent tool types.
@@ -59,6 +61,22 @@ type DocumentLibraryTool struct {
func (*DocumentLibraryTool) agentToolType() string { return "document_library" }
// ConnectorAuth holds authorization for a custom connector.
type ConnectorAuth struct {
Type string `json:"type"`
Value string `json:"value"`
}
// ConnectorTool represents a custom connector tool.
type ConnectorTool struct {
Type string `json:"type"`
ConnectorID string `json:"connector_id"`
Authorization *ConnectorAuth `json:"authorization,omitempty"`
ToolConfiguration *ToolConfiguration `json:"tool_configuration,omitempty"`
}
func (*ConnectorTool) agentToolType() string { return "connector" }
// UnknownAgentTool holds an unrecognized tool type.
type UnknownAgentTool struct {
Type string
@@ -99,6 +117,9 @@ func UnmarshalAgentTool(data []byte) (AgentTool, error) {
case "document_library":
var t DocumentLibraryTool
return &t, json.Unmarshal(data, &t)
case "connector":
var t ConnectorTool
return &t, json.Unmarshal(data, &t)
default:
return &UnknownAgentTool{Type: probe.Type, Raw: json.RawMessage(data)}, nil
}
@@ -151,7 +172,7 @@ type Agent struct {
Description *string `json:"description,omitempty"`
Tools AgentTools `json:"tools,omitempty"`
CompletionArgs *CompletionArgs `json:"completion_args,omitempty"`
Guardrails []GuardrailConfig `json:"guardrails,omitempty"`
Guardrails []chat.GuardrailConfig `json:"guardrails,omitempty"`
Handoffs []string `json:"handoffs,omitempty"`
Metadata map[string]any `json:"metadata,omitempty"`
VersionMessage *string `json:"version_message,omitempty"`
@@ -165,7 +186,7 @@ type CreateRequest struct {
Description *string `json:"description,omitempty"`
Tools AgentTools `json:"tools,omitempty"`
CompletionArgs *CompletionArgs `json:"completion_args,omitempty"`
Guardrails []GuardrailConfig `json:"guardrails,omitempty"`
Guardrails []chat.GuardrailConfig `json:"guardrails,omitempty"`
Handoffs []string `json:"handoffs,omitempty"`
Metadata map[string]any `json:"metadata,omitempty"`
VersionMessage *string `json:"version_message,omitempty"`
@@ -179,7 +200,7 @@ type UpdateRequest struct {
Description *string `json:"description,omitempty"`
Tools AgentTools `json:"tools,omitempty"`
CompletionArgs *CompletionArgs `json:"completion_args,omitempty"`
Guardrails []GuardrailConfig `json:"guardrails,omitempty"`
Guardrails []chat.GuardrailConfig `json:"guardrails,omitempty"`
Handoffs []string `json:"handoffs,omitempty"`
DeploymentChat *bool `json:"deployment_chat,omitempty"`
Metadata map[string]any `json:"metadata,omitempty"`
@@ -223,20 +244,6 @@ type CompletionArgs struct {
ToolChoice *string `json:"tool_choice,omitempty"`
}
// GuardrailConfig configures moderation guardrails for an agent.
type GuardrailConfig struct {
BlockOnError bool `json:"block_on_error"`
ModerationLLMV1 *ModerationLLMV1Config `json:"moderation_llm_v1"`
}
// ModerationLLMV1Config configures the moderation LLM guardrail.
type ModerationLLMV1Config struct {
ModelName string `json:"model_name,omitempty"`
CustomCategoryThresholds json.RawMessage `json:"custom_category_thresholds,omitempty"`
IgnoreOtherCategories bool `json:"ignore_other_categories,omitempty"`
Action string `json:"action,omitempty"`
}
// ToolConfiguration holds include/exclude/confirmation lists for tools.
type ToolConfiguration struct {
Exclude []string `json:"exclude,omitempty"`

View File

@@ -66,6 +66,30 @@ func TestAgentTools_RoundTrip(t *testing.T) {
}
}
func TestUnmarshalAgentTool_Connector(t *testing.T) {
data := []byte(`{"type":"connector","connector_id":"my-connector","authorization":{"type":"api-key","value":"sk-test"}}`)
tool, err := UnmarshalAgentTool(data)
if err != nil {
t.Fatal(err)
}
ct, ok := tool.(*ConnectorTool)
if !ok {
t.Fatalf("expected *ConnectorTool, got %T", tool)
}
if ct.ConnectorID != "my-connector" {
t.Errorf("got connector_id %q", ct.ConnectorID)
}
if ct.Authorization == nil {
t.Fatal("expected authorization")
}
if ct.Authorization.Type != "api-key" {
t.Errorf("got auth type %q", ct.Authorization.Type)
}
if ct.Authorization.Value != "sk-test" {
t.Errorf("got auth value %q", ct.Authorization.Value)
}
}
func TestAgent_UnmarshalWithTools(t *testing.T) {
data := []byte(`{
"id":"ag-1","object":"agent","name":"A","model":"m",

View File

@@ -1,3 +1,10 @@
// Package agents provides types for the Mistral agents API,
// including agent CRUD operations and agent chat completions.
//
// # Tool Types
//
// Agents support multiple tool types via the [AgentTool] sealed interface:
// [FunctionTool], [WebSearchTool], [WebSearchPremiumTool],
// [CodeInterpreterTool], [ImageGenerationTool], [DocumentLibraryTool],
// and [ConnectorTool] for custom connectors.
package agents

View File

@@ -23,8 +23,10 @@ type CompletionRequest struct {
N *int `json:"n,omitempty"`
ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"`
Metadata map[string]any `json:"metadata,omitempty"`
Prediction *chat.Prediction `json:"prediction,omitempty"`
PromptMode *chat.PromptMode `json:"prompt_mode,omitempty"`
Prediction *chat.Prediction `json:"prediction,omitempty"`
PromptMode *chat.PromptMode `json:"prompt_mode,omitempty"`
Guardrails []chat.GuardrailConfig `json:"guardrails,omitempty"`
ReasoningEffort *chat.ReasoningEffort `json:"reasoning_effort,omitempty"`
stream bool
}

View File

@@ -114,6 +114,39 @@ func TestAgentsComplete_WithTools(t *testing.T) {
}
}
func TestAgentsComplete_ReasoningEffort(t *testing.T) {
effort := chat.ReasoningEffortHigh
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var body map[string]any
json.NewDecoder(r.Body).Decode(&body)
if body["reasoning_effort"] != "high" {
t.Errorf("expected reasoning_effort=high, got %v", body["reasoning_effort"])
}
json.NewEncoder(w).Encode(map[string]any{
"id": "a-re", "object": "chat.completion",
"model": "m", "created": 0,
"choices": []map[string]any{{
"index": 0, "message": map[string]any{"role": "assistant", "content": "ok"},
"finish_reason": "stop",
}},
"usage": map[string]any{"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
_, err := client.AgentsComplete(context.Background(), &agents.CompletionRequest{
AgentID: "agent-1",
Messages: []chat.Message{&chat.UserMessage{Content: chat.TextContent("Hi")}},
ReasoningEffort: &effort,
})
if err != nil {
t.Fatal(err)
}
}
func TestAgentsCompleteStream_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var body map[string]any

View File

@@ -1,5 +1,25 @@
// Package audio provides types for the Mistral audio transcription API.
// Package audio provides types for the Mistral audio APIs.
//
// # Transcription
//
// [TranscriptionRequest] and [TranscriptionResponse] handle speech-to-text.
// Streaming transcription returns typed [StreamEvent] values via a sealed
// interface dispatched by the "type" field.
//
// # Speech (TTS)
//
// [SpeechRequest] and [SpeechResponse] handle text-to-speech.
// Streaming speech returns typed [SpeechStreamEvent] values
// ([SpeechAudioDelta] and [SpeechDone]).
//
// # Voices
//
// [VoiceResponse], [VoiceCreateRequest], and [VoiceUpdateRequest] manage
// custom voices for speech synthesis.
//
// # Realtime
//
// Realtime transcription types ([AudioEncoding], [AudioFormat],
// [RealtimeSession], and WebSocket message types) are defined here.
// The WebSocket client is not yet implemented.
package audio

75
audio/realtime.go Normal file
View File

@@ -0,0 +1,75 @@
package audio
// AudioEncoding is the encoding format for realtime audio streams.
type AudioEncoding string
const (
EncodingPCMS16LE AudioEncoding = "pcm_s16le"
EncodingPCMS32LE AudioEncoding = "pcm_s32le"
EncodingPCMF16LE AudioEncoding = "pcm_f16le"
EncodingPCMF32LE AudioEncoding = "pcm_f32le"
EncodingPCMMulaw AudioEncoding = "pcm_mulaw"
EncodingPCMAlaw AudioEncoding = "pcm_alaw"
)
// AudioFormat describes the encoding and sample rate for realtime audio.
type AudioFormat struct {
Encoding AudioEncoding `json:"encoding"`
SampleRate int `json:"sample_rate"`
}
// RealtimeSession describes a realtime transcription session.
type RealtimeSession struct {
RequestID string `json:"request_id"`
Model string `json:"model"`
AudioFormat AudioFormat `json:"audio_format"`
TargetStreamingDelayMs *int `json:"target_streaming_delay_ms,omitempty"`
}
// RealtimeSessionUpdate is sent to update session parameters.
// Parameters can only be changed before audio transmission starts.
type RealtimeSessionUpdate struct {
AudioFormat *AudioFormat `json:"audio_format,omitempty"`
TargetStreamingDelayMs *int `json:"target_streaming_delay_ms,omitempty"`
}
// InputAudioAppend sends a chunk of audio data.
// Audio is base64-encoded (max 262144 bytes decoded).
type InputAudioAppend struct {
Type string `json:"type"` // "input_audio.append"
Audio string `json:"audio"`
}
// InputAudioFlush flushes the audio buffer.
type InputAudioFlush struct {
Type string `json:"type"` // "input_audio.flush"
}
// InputAudioEnd signals the end of audio input.
type InputAudioEnd struct {
Type string `json:"type"` // "input_audio.end"
}
// RealtimeSessionCreated is received when a session is created.
type RealtimeSessionCreated struct {
Type string `json:"type"` // "session.created"
Session RealtimeSession `json:"session"`
}
// RealtimeSessionUpdated is received when a session is updated.
type RealtimeSessionUpdated struct {
Type string `json:"type"` // "session.updated"
Session RealtimeSession `json:"session"`
}
// RealtimeErrorDetail describes a realtime error.
type RealtimeErrorDetail struct {
Message string `json:"message"`
Code int `json:"code"`
}
// RealtimeError is received on error.
type RealtimeError struct {
Type string `json:"type"` // "error"
Error RealtimeErrorDetail `json:"error"`
}

88
audio/speech.go Normal file
View File

@@ -0,0 +1,88 @@
package audio
import (
"encoding/json"
"fmt"
)
// SpeechOutputFormat is the output audio format for speech synthesis.
type SpeechOutputFormat string
const (
SpeechFormatPCM SpeechOutputFormat = "pcm"
SpeechFormatWAV SpeechOutputFormat = "wav"
SpeechFormatMP3 SpeechOutputFormat = "mp3"
SpeechFormatFLAC SpeechOutputFormat = "flac"
SpeechFormatOpus SpeechOutputFormat = "opus"
)
// SpeechRequest represents a text-to-speech request.
type SpeechRequest struct {
Input string `json:"input"`
Model string `json:"model"`
Metadata map[string]any `json:"metadata,omitempty"`
VoiceID *string `json:"voice_id,omitempty"`
RefAudio *string `json:"ref_audio,omitempty"`
ResponseFormat *SpeechOutputFormat `json:"response_format,omitempty"`
stream bool
}
// EnableStream is used internally to enable streaming.
func (r *SpeechRequest) EnableStream() { r.stream = true }
func (r *SpeechRequest) MarshalJSON() ([]byte, error) {
type Alias SpeechRequest
return json.Marshal(&struct {
Stream bool `json:"stream"`
*Alias
}{
Stream: r.stream,
Alias: (*Alias)(r),
})
}
// SpeechResponse is the response from a non-streaming speech request.
type SpeechResponse struct {
AudioData string `json:"audio_data"`
}
// SpeechStreamEvent is a sealed interface for speech streaming events.
type SpeechStreamEvent interface {
speechStreamEvent()
}
// SpeechAudioDelta contains a chunk of audio data during streaming.
type SpeechAudioDelta struct {
Type string `json:"type"`
AudioData string `json:"audio_data"`
}
func (*SpeechAudioDelta) speechStreamEvent() {}
// SpeechDone is emitted when speech synthesis is complete.
type SpeechDone struct {
Type string `json:"type"`
Usage UsageInfo `json:"usage"`
}
func (*SpeechDone) speechStreamEvent() {}
// UnmarshalSpeechStreamEvent dispatches a raw JSON event to the correct type.
func UnmarshalSpeechStreamEvent(data []byte) (SpeechStreamEvent, error) {
var probe struct {
Type string `json:"type"`
}
if err := json.Unmarshal(data, &probe); err != nil {
return nil, err
}
switch probe.Type {
case "speech.audio.delta":
var e SpeechAudioDelta
return &e, json.Unmarshal(data, &e)
case "speech.audio.done":
var e SpeechDone
return &e, json.Unmarshal(data, &e)
default:
return nil, fmt.Errorf("unknown speech stream event type: %q", probe.Type)
}
}

48
audio/voice.go Normal file
View File

@@ -0,0 +1,48 @@
package audio
// VoiceResponse represents a voice entity.
type VoiceResponse struct {
Name string `json:"name"`
ID string `json:"id"`
CreatedAt string `json:"created_at"`
UserID *string `json:"user_id,omitempty"`
Slug *string `json:"slug,omitempty"`
Languages []string `json:"languages,omitempty"`
Gender *string `json:"gender,omitempty"`
Age *int `json:"age,omitempty"`
Tags []string `json:"tags,omitempty"`
Color *string `json:"color,omitempty"`
RetentionNotice *int `json:"retention_notice,omitempty"`
}
// VoiceCreateRequest creates a custom voice.
type VoiceCreateRequest struct {
Name string `json:"name"`
SampleAudio string `json:"sample_audio"`
Slug *string `json:"slug,omitempty"`
Languages []string `json:"languages,omitempty"`
Gender *string `json:"gender,omitempty"`
Age *int `json:"age,omitempty"`
Tags []string `json:"tags,omitempty"`
Color *string `json:"color,omitempty"`
RetentionNotice *int `json:"retention_notice,omitempty"`
SampleFilename *string `json:"sample_filename,omitempty"`
}
// VoiceUpdateRequest updates a voice.
type VoiceUpdateRequest struct {
Name *string `json:"name,omitempty"`
Languages []string `json:"languages,omitempty"`
Gender *string `json:"gender,omitempty"`
Age *int `json:"age,omitempty"`
Tags []string `json:"tags,omitempty"`
}
// VoiceListResponse is the response from listing voices.
type VoiceListResponse struct {
Items []VoiceResponse `json:"items"`
Total int `json:"total"`
Page int `json:"page"`
PageSize int `json:"page_size"`
TotalPages int `json:"total_pages"`
}

View File

@@ -3,7 +3,9 @@ package mistral
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"somegit.dev/vikingowl/mistral-go-sdk/audio"
)
@@ -95,3 +97,125 @@ func (s *AudioStream) Err() error { return s.err }
// Close releases the underlying connection.
func (s *AudioStream) Close() error { return s.stream.Close() }
// Speech sends a text-to-speech request and returns the full response.
func (c *Client) Speech(ctx context.Context, req *audio.SpeechRequest) (*audio.SpeechResponse, error) {
var resp audio.SpeechResponse
if err := c.doJSON(ctx, "POST", "/v1/audio/speech", req, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// SpeechStream sends a text-to-speech request and returns a stream of audio events.
func (c *Client) SpeechStream(ctx context.Context, req *audio.SpeechRequest) (*SpeechStream, error) {
req.EnableStream()
resp, err := c.doStream(ctx, "POST", "/v1/audio/speech", req)
if err != nil {
return nil, err
}
return newSpeechStream(resp.Body), nil
}
// SpeechStream wraps the generic Stream for speech streaming events.
type SpeechStream struct {
stream *Stream[json.RawMessage]
event audio.SpeechStreamEvent
err error
}
func newSpeechStream(body readCloser) *SpeechStream {
return &SpeechStream{
stream: newStream[json.RawMessage](body),
}
}
// Next advances to the next event. Returns false when done or on error.
func (s *SpeechStream) Next() bool {
if s.err != nil {
return false
}
if !s.stream.Next() {
s.err = s.stream.Err()
return false
}
event, err := audio.UnmarshalSpeechStreamEvent(s.stream.Current())
if err != nil {
s.err = err
return false
}
s.event = event
return true
}
// Current returns the most recently read event.
func (s *SpeechStream) Current() audio.SpeechStreamEvent { return s.event }
// Err returns any error encountered during streaming.
func (s *SpeechStream) Err() error { return s.err }
// Close releases the underlying connection.
func (s *SpeechStream) Close() error { return s.stream.Close() }
// ListVoices returns available voices.
func (c *Client) ListVoices(ctx context.Context) (*audio.VoiceListResponse, error) {
var resp audio.VoiceListResponse
if err := c.doJSON(ctx, "GET", "/v1/audio/voices", nil, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// CreateVoice creates a custom voice.
func (c *Client) CreateVoice(ctx context.Context, req *audio.VoiceCreateRequest) (*audio.VoiceResponse, error) {
var resp audio.VoiceResponse
if err := c.doJSON(ctx, "POST", "/v1/audio/voices", req, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// GetVoice retrieves a voice by ID.
func (c *Client) GetVoice(ctx context.Context, voiceID string) (*audio.VoiceResponse, error) {
var resp audio.VoiceResponse
if err := c.doJSON(ctx, "GET", fmt.Sprintf("/v1/audio/voices/%s", voiceID), nil, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// UpdateVoice updates a voice.
func (c *Client) UpdateVoice(ctx context.Context, voiceID string, req *audio.VoiceUpdateRequest) (*audio.VoiceResponse, error) {
var resp audio.VoiceResponse
if err := c.doJSON(ctx, "PATCH", fmt.Sprintf("/v1/audio/voices/%s", voiceID), req, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// DeleteVoice deletes a voice.
func (c *Client) DeleteVoice(ctx context.Context, voiceID string) error {
resp, err := c.do(ctx, "DELETE", fmt.Sprintf("/v1/audio/voices/%s", voiceID), nil)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode >= 400 {
return parseAPIError(resp)
}
return nil
}
// GetVoiceSampleAudio retrieves the sample audio for a voice.
// Returns the raw HTTP response; the caller must close the body.
func (c *Client) GetVoiceSampleAudio(ctx context.Context, voiceID string) (*http.Response, error) {
resp, err := c.do(ctx, "GET", fmt.Sprintf("/v1/audio/voices/%s/sample", voiceID), nil)
if err != nil {
return nil, err
}
if resp.StatusCode >= 400 {
defer resp.Body.Close()
return nil, parseAPIError(resp)
}
return resp, nil
}

217
audio_speech_test.go Normal file
View File

@@ -0,0 +1,217 @@
package mistral
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"somegit.dev/vikingowl/mistral-go-sdk/audio"
)
func TestSpeech_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" {
t.Errorf("expected POST, got %s", r.Method)
}
if r.URL.Path != "/v1/audio/speech" {
t.Errorf("got path %s", r.URL.Path)
}
var body map[string]any
json.NewDecoder(r.Body).Decode(&body)
if body["input"] != "Hello world" {
t.Errorf("got input %v", body["input"])
}
if body["stream"] != false {
t.Errorf("expected stream=false")
}
json.NewEncoder(w).Encode(map[string]any{
"audio_data": "base64audiodata==",
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.Speech(context.Background(), &audio.SpeechRequest{
Input: "Hello world",
Model: "mistral-speech",
})
if err != nil {
t.Fatal(err)
}
if resp.AudioData != "base64audiodata==" {
t.Errorf("got audio_data %q", resp.AudioData)
}
}
func TestSpeechStream_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var body map[string]any
json.NewDecoder(r.Body).Decode(&body)
if body["stream"] != true {
t.Errorf("expected stream=true")
}
w.Header().Set("Content-Type", "text/event-stream")
flusher, _ := w.(http.Flusher)
delta, _ := json.Marshal(map[string]any{
"type": "speech.audio.delta", "audio_data": "chunk1==",
})
fmt.Fprintf(w, "data: %s\n\n", delta)
flusher.Flush()
done, _ := json.Marshal(map[string]any{
"type": "speech.audio.done",
"usage": map[string]any{
"prompt_tokens": 5, "completion_tokens": 10, "total_tokens": 15,
},
})
fmt.Fprintf(w, "data: %s\n\n", done)
flusher.Flush()
fmt.Fprint(w, "data: [DONE]\n\n")
flusher.Flush()
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
stream, err := client.SpeechStream(context.Background(), &audio.SpeechRequest{
Input: "Hi",
Model: "mistral-speech",
})
if err != nil {
t.Fatal(err)
}
defer stream.Close()
var events int
for stream.Next() {
events++
}
if err := stream.Err(); err != nil {
t.Fatal(err)
}
if events != 2 {
t.Errorf("got %d events, want 2", events)
}
}
func TestListVoices_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/audio/voices" {
t.Errorf("got path %s", r.URL.Path)
}
json.NewEncoder(w).Encode(map[string]any{
"items": []map[string]any{
{"id": "v1", "name": "Default", "created_at": "2025-01-01"},
},
"total": 1, "page": 1, "page_size": 10, "total_pages": 1,
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.ListVoices(context.Background())
if err != nil {
t.Fatal(err)
}
if len(resp.Items) != 1 {
t.Fatalf("got %d voices", len(resp.Items))
}
if resp.Items[0].ID != "v1" {
t.Errorf("got id %q", resp.Items[0].ID)
}
}
func TestCreateVoice_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" {
t.Errorf("expected POST, got %s", r.Method)
}
var body map[string]any
json.NewDecoder(r.Body).Decode(&body)
if body["name"] != "MyVoice" {
t.Errorf("got name %v", body["name"])
}
json.NewEncoder(w).Encode(map[string]any{
"id": "v2", "name": "MyVoice", "created_at": "2025-01-01",
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.CreateVoice(context.Background(), &audio.VoiceCreateRequest{
Name: "MyVoice",
SampleAudio: "base64audio==",
})
if err != nil {
t.Fatal(err)
}
if resp.ID != "v2" {
t.Errorf("got id %q", resp.ID)
}
}
func TestGetVoice_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/audio/voices/v1" {
t.Errorf("got path %s", r.URL.Path)
}
json.NewEncoder(w).Encode(map[string]any{
"id": "v1", "name": "Default", "created_at": "2025-01-01",
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.GetVoice(context.Background(), "v1")
if err != nil {
t.Fatal(err)
}
if resp.Name != "Default" {
t.Errorf("got name %q", resp.Name)
}
}
func TestUpdateVoice_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "PATCH" {
t.Errorf("expected PATCH, got %s", r.Method)
}
json.NewEncoder(w).Encode(map[string]any{
"id": "v1", "name": "Renamed", "created_at": "2025-01-01",
})
}))
defer server.Close()
name := "Renamed"
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.UpdateVoice(context.Background(), "v1", &audio.VoiceUpdateRequest{
Name: &name,
})
if err != nil {
t.Fatal(err)
}
if resp.Name != "Renamed" {
t.Errorf("got name %q", resp.Name)
}
}
func TestDeleteVoice_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "DELETE" {
t.Errorf("expected DELETE, got %s", r.Method)
}
w.WriteHeader(http.StatusNoContent)
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
err := client.DeleteVoice(context.Background(), "v1")
if err != nil {
t.Fatal(err)
}
}

View File

@@ -3,6 +3,18 @@ package chat
import (
"encoding/json"
"fmt"
"strconv"
)
// BuiltInConnector identifies a built-in connector type.
type BuiltInConnector string
const (
ConnectorWebSearch BuiltInConnector = "web_search"
ConnectorWebSearchPremium BuiltInConnector = "web_search_premium"
ConnectorCodeInterpreter BuiltInConnector = "code_interpreter"
ConnectorImageGeneration BuiltInConnector = "image_generation"
ConnectorDocumentLibrary BuiltInConnector = "document_library"
)
// ContentChunk is a sealed interface for message content parts.
@@ -112,9 +124,65 @@ func (c *FileChunk) MarshalJSON() ([]byte, error) {
})
}
// ReferenceID is a reference identifier that can be an integer or string.
// Use [IntRef] or [StringRef] constructors.
type ReferenceID struct {
raw string
isString bool
}
// IntRef creates an integer reference ID.
func IntRef(n int) ReferenceID {
return ReferenceID{raw: strconv.Itoa(n)}
}
// StringRef creates a string reference ID.
func StringRef(s string) ReferenceID {
return ReferenceID{raw: s, isString: true}
}
// String returns the string representation.
func (id ReferenceID) String() string { return id.raw }
// Int returns the integer value and true if this is a numeric reference.
func (id ReferenceID) Int() (int, bool) {
if id.isString {
return 0, false
}
n, err := strconv.Atoi(id.raw)
return n, err == nil
}
// IsString reports whether this is a string reference.
func (id ReferenceID) IsString() bool { return id.isString }
func (id ReferenceID) MarshalJSON() ([]byte, error) {
if id.isString {
return json.Marshal(id.raw)
}
return []byte(id.raw), nil
}
func (id *ReferenceID) UnmarshalJSON(data []byte) error {
if len(data) == 0 {
return nil
}
if data[0] == '"' {
var s string
if err := json.Unmarshal(data, &s); err != nil {
return err
}
id.raw = s
id.isString = true
return nil
}
id.raw = string(data)
return nil
}
// ReferenceChunk represents a reference content part.
type ReferenceChunk struct {
ReferenceIDs []int `json:"reference_ids"`
ReferenceIDs []ReferenceID `json:"reference_ids"`
}
func (*ReferenceChunk) contentChunk() {}
@@ -192,6 +260,49 @@ func (c *AudioChunk) MarshalJSON() ([]byte, error) {
})
}
// ToolReferenceChunk represents a tool reference content part.
type ToolReferenceChunk struct {
Tool string `json:"tool"`
Title string `json:"title"`
URL *string `json:"url,omitempty"`
Favicon *string `json:"favicon,omitempty"`
Description *string `json:"description,omitempty"`
}
func (*ToolReferenceChunk) contentChunk() {}
func (c *ToolReferenceChunk) MarshalJSON() ([]byte, error) {
type alias ToolReferenceChunk
return json.Marshal(&struct {
Type string `json:"type"`
*alias
}{
Type: "tool_reference",
alias: (*alias)(c),
})
}
// ToolFileChunk represents a tool-generated file content part.
type ToolFileChunk struct {
Tool string `json:"tool"`
FileID string `json:"file_id"`
FileName *string `json:"file_name,omitempty"`
FileType *string `json:"file_type,omitempty"`
}
func (*ToolFileChunk) contentChunk() {}
func (c *ToolFileChunk) MarshalJSON() ([]byte, error) {
type alias ToolFileChunk
return json.Marshal(&struct {
Type string `json:"type"`
*alias
}{
Type: "tool_file",
alias: (*alias)(c),
})
}
// UnmarshalContentChunk dispatches to the concrete ContentChunk type
// based on the "type" discriminator field.
func UnmarshalContentChunk(data []byte) (ContentChunk, error) {
@@ -223,6 +334,12 @@ func UnmarshalContentChunk(data []byte) (ContentChunk, error) {
case "input_audio":
var c AudioChunk
return &c, json.Unmarshal(data, &c)
case "tool_reference":
var c ToolReferenceChunk
return &c, json.Unmarshal(data, &c)
case "tool_file":
var c ToolFileChunk
return &c, json.Unmarshal(data, &c)
default:
return &UnknownChunk{Type: probe.Type, Raw: json.RawMessage(data)}, nil
}

View File

@@ -150,12 +150,16 @@ func TestFileChunk_RoundTrip(t *testing.T) {
}
}
func TestReferenceChunk_RoundTrip(t *testing.T) {
original := &ReferenceChunk{ReferenceIDs: []int{1, 2, 3}}
func TestReferenceChunk_RoundTrip_IntIDs(t *testing.T) {
original := &ReferenceChunk{ReferenceIDs: []ReferenceID{IntRef(1), IntRef(2), IntRef(3)}}
data, err := json.Marshal(original)
if err != nil {
t.Fatal(err)
}
want := `{"type":"reference","reference_ids":[1,2,3]}`
if string(data) != want {
t.Errorf("marshal: got %s, want %s", data, want)
}
chunk, err := UnmarshalContentChunk(data)
if err != nil {
t.Fatal(err)
@@ -164,8 +168,64 @@ func TestReferenceChunk_RoundTrip(t *testing.T) {
if !ok {
t.Fatalf("expected *ReferenceChunk, got %T", chunk)
}
if len(rc.ReferenceIDs) != 3 || rc.ReferenceIDs[0] != 1 {
t.Errorf("got %v, want [1 2 3]", rc.ReferenceIDs)
if len(rc.ReferenceIDs) != 3 {
t.Fatalf("got %d IDs, want 3", len(rc.ReferenceIDs))
}
n, ok := rc.ReferenceIDs[0].Int()
if !ok || n != 1 {
t.Errorf("got %v (ok=%v), want 1", n, ok)
}
}
func TestReferenceChunk_RoundTrip_MixedIDs(t *testing.T) {
data := []byte(`{"type":"reference","reference_ids":[1,"abc",42]}`)
chunk, err := UnmarshalContentChunk(data)
if err != nil {
t.Fatal(err)
}
rc, ok := chunk.(*ReferenceChunk)
if !ok {
t.Fatalf("expected *ReferenceChunk, got %T", chunk)
}
if len(rc.ReferenceIDs) != 3 {
t.Fatalf("got %d IDs, want 3", len(rc.ReferenceIDs))
}
// First: int 1
if n, ok := rc.ReferenceIDs[0].Int(); !ok || n != 1 {
t.Errorf("IDs[0]: got %v (ok=%v), want int 1", n, ok)
}
// Second: string "abc"
if !rc.ReferenceIDs[1].IsString() || rc.ReferenceIDs[1].String() != "abc" {
t.Errorf("IDs[1]: got %q (isString=%v), want string abc", rc.ReferenceIDs[1].String(), rc.ReferenceIDs[1].IsString())
}
// Third: int 42
if n, ok := rc.ReferenceIDs[2].Int(); !ok || n != 42 {
t.Errorf("IDs[2]: got %v (ok=%v), want int 42", n, ok)
}
// Round-trip preserves types
out, err := json.Marshal(rc)
if err != nil {
t.Fatal(err)
}
if string(out) != `{"type":"reference","reference_ids":[1,"abc",42]}` {
t.Errorf("round-trip: got %s", out)
}
}
func TestReferenceID_StringRef(t *testing.T) {
id := StringRef("doc-123")
if !id.IsString() {
t.Error("expected IsString=true")
}
if id.String() != "doc-123" {
t.Errorf("got %q", id.String())
}
if _, ok := id.Int(); ok {
t.Error("Int() should return false for string ref")
}
data, _ := json.Marshal(id)
if string(data) != `"doc-123"` {
t.Errorf("marshal: got %s", data)
}
}
@@ -174,7 +234,7 @@ func TestThinkChunk_RoundTrip(t *testing.T) {
original := &ThinkChunk{
Thinking: []ContentChunk{
&TextChunk{Text: "reasoning step"},
&ReferenceChunk{ReferenceIDs: []int{42}},
&ReferenceChunk{ReferenceIDs: []ReferenceID{IntRef(42)}},
},
Closed: &closed,
}
@@ -365,3 +425,76 @@ func TestContent_IsNull(t *testing.T) {
t.Error("chunks content should not be null")
}
}
func TestToolReferenceChunk_RoundTrip(t *testing.T) {
url := "https://example.com/result"
desc := "A search result"
original := &ToolReferenceChunk{
Tool: string(ConnectorWebSearch),
Title: "Example Result",
URL: &url,
Description: &desc,
}
data, err := json.Marshal(original)
if err != nil {
t.Fatal(err)
}
chunk, err := UnmarshalContentChunk(data)
if err != nil {
t.Fatal(err)
}
tr, ok := chunk.(*ToolReferenceChunk)
if !ok {
t.Fatalf("expected *ToolReferenceChunk, got %T", chunk)
}
if tr.Tool != "web_search" {
t.Errorf("got tool %q, want web_search", tr.Tool)
}
if tr.Title != "Example Result" {
t.Errorf("got title %q", tr.Title)
}
if tr.URL == nil || *tr.URL != url {
t.Errorf("got url %v, want %q", tr.URL, url)
}
if tr.Description == nil || *tr.Description != desc {
t.Errorf("got description %v, want %q", tr.Description, desc)
}
if tr.Favicon != nil {
t.Errorf("expected nil favicon, got %v", tr.Favicon)
}
}
func TestToolFileChunk_RoundTrip(t *testing.T) {
fname := "output.csv"
ftype := "text/csv"
original := &ToolFileChunk{
Tool: string(ConnectorCodeInterpreter),
FileID: "file-abc123",
FileName: &fname,
FileType: &ftype,
}
data, err := json.Marshal(original)
if err != nil {
t.Fatal(err)
}
chunk, err := UnmarshalContentChunk(data)
if err != nil {
t.Fatal(err)
}
tf, ok := chunk.(*ToolFileChunk)
if !ok {
t.Fatalf("expected *ToolFileChunk, got %T", chunk)
}
if tf.Tool != "code_interpreter" {
t.Errorf("got tool %q", tf.Tool)
}
if tf.FileID != "file-abc123" {
t.Errorf("got file_id %q", tf.FileID)
}
if tf.FileName == nil || *tf.FileName != fname {
t.Errorf("got file_name %v", tf.FileName)
}
if tf.FileType == nil || *tf.FileType != ftype {
t.Errorf("got file_type %v", tf.FileType)
}
}

View File

@@ -5,5 +5,22 @@
// struct literals.
//
// Content is polymorphic: it can be a plain string (via [TextContent]),
// nil, or a slice of [ContentChunk] values (text, image URL, document URL, audio).
// nil, or a slice of [ContentChunk] values (text, image URL, document URL,
// audio, tool reference, tool file).
//
// # Guardrails
//
// [GuardrailConfig] configures moderation guardrails on completion requests.
// Both v1 ([ModerationLLMV1Config]) and v2 ([ModerationLLMV2Config]) moderation
// configs are supported.
//
// # Content Chunks
//
// The following chunk types are supported: [TextChunk], [ImageURLChunk],
// [DocumentURLChunk], [FileChunk], [ReferenceChunk], [ThinkChunk],
// [AudioChunk], [ToolReferenceChunk], [ToolFileChunk].
// Unrecognized types are preserved as [UnknownChunk].
//
// [ReferenceChunk] uses [ReferenceID] values that can hold either integer
// or string identifiers. Use [IntRef] and [StringRef] constructors.
package chat

62
chat/guardrail.go Normal file
View File

@@ -0,0 +1,62 @@
package chat
// ModerationLLMAction specifies the action to take when content exceeds thresholds.
type ModerationLLMAction string
const (
ModerationActionNone ModerationLLMAction = "none"
ModerationActionBlock ModerationLLMAction = "block"
)
// ModerationLLMV1CategoryThresholds defines per-category score thresholds for v1 moderation.
type ModerationLLMV1CategoryThresholds struct {
Sexual *float64 `json:"sexual,omitempty"`
HateAndDiscrimination *float64 `json:"hate_and_discrimination,omitempty"`
ViolenceAndThreats *float64 `json:"violence_and_threats,omitempty"`
DangerousAndCriminalContent *float64 `json:"dangerous_and_criminal_content,omitempty"`
Selfharm *float64 `json:"selfharm,omitempty"`
Health *float64 `json:"health,omitempty"`
Financial *float64 `json:"financial,omitempty"`
Law *float64 `json:"law,omitempty"`
PII *float64 `json:"pii,omitempty"`
}
// ModerationLLMV1Config configures the v1 moderation LLM guardrail.
type ModerationLLMV1Config struct {
ModelName string `json:"model_name,omitempty"`
CustomCategoryThresholds *ModerationLLMV1CategoryThresholds `json:"custom_category_thresholds,omitempty"`
IgnoreOtherCategories bool `json:"ignore_other_categories,omitempty"`
Action ModerationLLMAction `json:"action,omitempty"`
}
// ModerationLLMV2CategoryThresholds defines per-category score thresholds for v2 moderation.
// V2 splits "dangerous_and_criminal_content" into separate "dangerous" and "criminal"
// categories and adds "jailbreaking".
type ModerationLLMV2CategoryThresholds struct {
Sexual *float64 `json:"sexual,omitempty"`
HateAndDiscrimination *float64 `json:"hate_and_discrimination,omitempty"`
ViolenceAndThreats *float64 `json:"violence_and_threats,omitempty"`
Dangerous *float64 `json:"dangerous,omitempty"`
Criminal *float64 `json:"criminal,omitempty"`
Selfharm *float64 `json:"selfharm,omitempty"`
Health *float64 `json:"health,omitempty"`
Financial *float64 `json:"financial,omitempty"`
Law *float64 `json:"law,omitempty"`
PII *float64 `json:"pii,omitempty"`
Jailbreaking *float64 `json:"jailbreaking,omitempty"`
}
// ModerationLLMV2Config configures the v2 moderation LLM guardrail.
type ModerationLLMV2Config struct {
ModelName string `json:"model_name,omitempty"`
CustomCategoryThresholds *ModerationLLMV2CategoryThresholds `json:"custom_category_thresholds,omitempty"`
IgnoreOtherCategories bool `json:"ignore_other_categories,omitempty"`
Action ModerationLLMAction `json:"action,omitempty"`
}
// GuardrailConfig configures moderation guardrails for requests.
type GuardrailConfig struct {
BlockOnError bool `json:"block_on_error"`
ModerationLLMV1 *ModerationLLMV1Config `json:"moderation_llm_v1,omitempty"`
ModerationLLMV2 *ModerationLLMV2Config `json:"moderation_llm_v2,omitempty"`
}

View File

@@ -9,6 +9,14 @@ const (
PromptModeReasoning PromptMode = "reasoning"
)
// ReasoningEffort controls the amount of reasoning effort the model uses.
type ReasoningEffort string
const (
ReasoningEffortNone ReasoningEffort = "none"
ReasoningEffortHigh ReasoningEffort = "high"
)
// Prediction provides expected completion content for optimization.
type Prediction struct {
Type string `json:"type"`
@@ -30,11 +38,13 @@ type CompletionRequest struct {
PresencePenalty *float64 `json:"presence_penalty,omitempty"`
FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"`
N *int `json:"n,omitempty"`
SafePrompt bool `json:"safe_prompt,omitempty"`
ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"`
Metadata map[string]any `json:"metadata,omitempty"`
Prediction *Prediction `json:"prediction,omitempty"`
PromptMode *PromptMode `json:"prompt_mode,omitempty"`
SafePrompt bool `json:"safe_prompt,omitempty"`
ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"`
Metadata map[string]any `json:"metadata,omitempty"`
Prediction *Prediction `json:"prediction,omitempty"`
PromptMode *PromptMode `json:"prompt_mode,omitempty"`
Guardrails []GuardrailConfig `json:"guardrails,omitempty"`
ReasoningEffort *ReasoningEffort `json:"reasoning_effort,omitempty"`
stream bool
}

View File

@@ -350,6 +350,41 @@ func TestChatComplete_RequestBody(t *testing.T) {
}
}
func TestChatComplete_ReasoningEffort(t *testing.T) {
effort := chat.ReasoningEffortHigh
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
bodyBytes, _ := io.ReadAll(r.Body)
var body map[string]any
json.Unmarshal(bodyBytes, &body)
if body["reasoning_effort"] != "high" {
t.Errorf("expected reasoning_effort=high, got %v", body["reasoning_effort"])
}
json.NewEncoder(w).Encode(map[string]any{
"id": "chat-re", "object": "chat.completion",
"model": "m", "created": 0,
"choices": []map[string]any{{
"index": 0, "message": map[string]any{"role": "assistant", "content": "ok"},
"finish_reason": "stop",
}},
"usage": map[string]any{"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
_, err := client.ChatComplete(context.Background(), &chat.CompletionRequest{
Model: "m",
Messages: []chat.Message{&chat.UserMessage{Content: chat.TextContent("Hi")}},
ReasoningEffort: &effort,
})
if err != nil {
t.Fatal(err)
}
}
func TestChatComplete_ContextCanceled(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Never responds — context should cancel first

100
connector/connector.go Normal file
View File

@@ -0,0 +1,100 @@
package connector
import "encoding/json"
// Visibility controls who can see a connector or tool.
type Visibility string
const (
VisibilitySharedGlobal Visibility = "shared_global"
VisibilitySharedOrg Visibility = "shared_org"
VisibilitySharedWorkspace Visibility = "shared_workspace"
VisibilityPrivate Visibility = "private"
)
// AuthData holds OAuth2 client credentials for a connector.
type AuthData struct {
ClientID string `json:"client_id"`
ClientSecret string `json:"client_secret"`
}
// Connector represents a registered MCP connector.
type Connector struct {
ID string `json:"id"`
Name string `json:"name"`
Description string `json:"description"`
CreatedAt string `json:"created_at"`
ModifiedAt string `json:"modified_at"`
Server *string `json:"server,omitempty"`
AuthType *string `json:"auth_type,omitempty"`
}
// CreateRequest creates a new connector.
type CreateRequest struct {
Name string `json:"name"`
Description string `json:"description"`
Server string `json:"server"`
IconURL *string `json:"icon_url,omitempty"`
Visibility *Visibility `json:"visibility,omitempty"`
Headers map[string]string `json:"headers,omitempty"`
AuthData *AuthData `json:"auth_data,omitempty"`
SystemPrompt *string `json:"system_prompt,omitempty"`
}
// UpdateRequest updates an existing connector.
type UpdateRequest struct {
Name *string `json:"name,omitempty"`
Description *string `json:"description,omitempty"`
IconURL *string `json:"icon_url,omitempty"`
SystemPrompt *string `json:"system_prompt,omitempty"`
Server *string `json:"server,omitempty"`
Headers map[string]string `json:"headers,omitempty"`
AuthData *AuthData `json:"auth_data,omitempty"`
ConnectionConfig map[string]any `json:"connection_config,omitempty"`
ConnectionSecrets map[string]any `json:"connection_secrets,omitempty"`
}
// AuthURLResponse is the response from getting a connector's OAuth URL.
type AuthURLResponse struct {
AuthURL string `json:"auth_url"`
TTL int `json:"ttl"`
}
// CallToolRequest is the request body for calling a connector tool.
type CallToolRequest struct {
Arguments map[string]any `json:"arguments,omitempty"`
}
// CallToolResponse is the response from calling a connector tool.
// Content is left as raw JSON because the upstream API returns a union
// of 5 content types (text, image, audio, resource link, embedded resource).
type CallToolResponse struct {
Content json.RawMessage `json:"content"`
Metadata map[string]any `json:"metadata,omitempty"`
}
// Tool represents a tool exposed by a connector.
type Tool struct {
ID string `json:"id"`
Name string `json:"name"`
Description *string `json:"description,omitempty"`
Visibility Visibility `json:"visibility,omitempty"`
CreatedAt string `json:"created_at,omitempty"`
ModifiedAt string `json:"modified_at,omitempty"`
SystemPrompt *string `json:"system_prompt,omitempty"`
JsonSchema map[string]any `json:"jsonschema,omitempty"`
Active *bool `json:"active,omitempty"`
}
// ListParams holds query parameters for listing connectors.
type ListParams struct {
Page *int
PageSize *int
}
// ListToolsParams holds query parameters for listing connector tools.
type ListToolsParams struct {
Page *int
PageSize *int
Refresh *bool
}

6
connector/doc.go Normal file
View File

@@ -0,0 +1,6 @@
// Package connector provides types for the Mistral connectors API.
//
// Connectors represent MCP (Model Context Protocol) server integrations.
// Use [CreateRequest] to register a new connector, then use tools
// discovered via the list-tools endpoint in chat or agent completions.
package connector

119
connectors.go Normal file
View File

@@ -0,0 +1,119 @@
package mistral
import (
"context"
"fmt"
"net/url"
"strconv"
"somegit.dev/vikingowl/mistral-go-sdk/connector"
)
// CreateConnector registers a new MCP connector.
func (c *Client) CreateConnector(ctx context.Context, req *connector.CreateRequest) (*connector.Connector, error) {
var resp connector.Connector
if err := c.doJSON(ctx, "POST", "/v1/connectors", req, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// ListConnectors returns all connectors.
func (c *Client) ListConnectors(ctx context.Context, params *connector.ListParams) ([]connector.Connector, error) {
path := "/v1/connectors"
if params != nil {
q := url.Values{}
if params.Page != nil {
q.Set("page", strconv.Itoa(*params.Page))
}
if params.PageSize != nil {
q.Set("page_size", strconv.Itoa(*params.PageSize))
}
if encoded := q.Encode(); encoded != "" {
path += "?" + encoded
}
}
var resp []connector.Connector
if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil {
return nil, err
}
return resp, nil
}
// GetConnector retrieves a connector by ID or name.
func (c *Client) GetConnector(ctx context.Context, idOrName string) (*connector.Connector, error) {
var resp connector.Connector
if err := c.doJSON(ctx, "GET", fmt.Sprintf("/v1/connectors/%s", idOrName), nil, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// UpdateConnector updates an existing connector.
func (c *Client) UpdateConnector(ctx context.Context, idOrName string, req *connector.UpdateRequest) (*connector.Connector, error) {
var resp connector.Connector
if err := c.doJSON(ctx, "PATCH", fmt.Sprintf("/v1/connectors/%s", idOrName), req, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// DeleteConnector deletes a connector.
func (c *Client) DeleteConnector(ctx context.Context, idOrName string) error {
resp, err := c.do(ctx, "DELETE", fmt.Sprintf("/v1/connectors/%s", idOrName), nil)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode >= 400 {
return parseAPIError(resp)
}
return nil
}
// GetConnectorAuthURL returns the OAuth2 authorization URL for a connector.
func (c *Client) GetConnectorAuthURL(ctx context.Context, idOrName string, appReturnURL *string) (*connector.AuthURLResponse, error) {
path := fmt.Sprintf("/v1/connectors/%s/auth_url", idOrName)
if appReturnURL != nil {
path += "?app_return_url=" + url.QueryEscape(*appReturnURL)
}
var resp connector.AuthURLResponse
if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// ListConnectorTools lists tools exposed by a connector.
func (c *Client) ListConnectorTools(ctx context.Context, idOrName string, params *connector.ListToolsParams) ([]connector.Tool, error) {
path := fmt.Sprintf("/v1/connectors/%s/tools", idOrName)
if params != nil {
q := url.Values{}
if params.Page != nil {
q.Set("page", strconv.Itoa(*params.Page))
}
if params.PageSize != nil {
q.Set("page_size", strconv.Itoa(*params.PageSize))
}
if params.Refresh != nil {
q.Set("refresh", strconv.FormatBool(*params.Refresh))
}
if encoded := q.Encode(); encoded != "" {
path += "?" + encoded
}
}
var resp []connector.Tool
if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil {
return nil, err
}
return resp, nil
}
// CallConnectorTool invokes a tool on a connector.
func (c *Client) CallConnectorTool(ctx context.Context, idOrName, toolName string, req *connector.CallToolRequest) (*connector.CallToolResponse, error) {
var resp connector.CallToolResponse
if err := c.doJSON(ctx, "POST", fmt.Sprintf("/v1/connectors/%s/tools/%s/call", idOrName, toolName), req, &resp); err != nil {
return nil, err
}
return &resp, nil
}

217
connectors_test.go Normal file
View File

@@ -0,0 +1,217 @@
package mistral
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"somegit.dev/vikingowl/mistral-go-sdk/connector"
)
func TestCreateConnector_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" {
t.Errorf("expected POST, got %s", r.Method)
}
if r.URL.Path != "/v1/connectors" {
t.Errorf("got path %s", r.URL.Path)
}
var body map[string]any
json.NewDecoder(r.Body).Decode(&body)
if body["name"] != "my_connector" {
t.Errorf("got name %v", body["name"])
}
if body["server"] != "https://mcp.example.com" {
t.Errorf("got server %v", body["server"])
}
json.NewEncoder(w).Encode(map[string]any{
"id": "conn-1", "name": "my_connector",
"description": "test", "created_at": "2025-01-01",
"modified_at": "2025-01-01", "server": "https://mcp.example.com",
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.CreateConnector(context.Background(), &connector.CreateRequest{
Name: "my_connector",
Description: "test",
Server: "https://mcp.example.com",
})
if err != nil {
t.Fatal(err)
}
if resp.ID != "conn-1" {
t.Errorf("got id %q", resp.ID)
}
}
func TestListConnectors_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "GET" {
t.Errorf("expected GET, got %s", r.Method)
}
json.NewEncoder(w).Encode([]map[string]any{
{"id": "c1", "name": "conn1", "description": "d1", "created_at": "t", "modified_at": "t"},
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
list, err := client.ListConnectors(context.Background(), nil)
if err != nil {
t.Fatal(err)
}
if len(list) != 1 {
t.Fatalf("got %d connectors", len(list))
}
if list[0].ID != "c1" {
t.Errorf("got id %q", list[0].ID)
}
}
func TestGetConnector_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/connectors/my_conn" {
t.Errorf("got path %s", r.URL.Path)
}
json.NewEncoder(w).Encode(map[string]any{
"id": "c1", "name": "my_conn", "description": "d",
"created_at": "t", "modified_at": "t",
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
c, err := client.GetConnector(context.Background(), "my_conn")
if err != nil {
t.Fatal(err)
}
if c.Name != "my_conn" {
t.Errorf("got name %q", c.Name)
}
}
func TestUpdateConnector_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "PATCH" {
t.Errorf("expected PATCH, got %s", r.Method)
}
json.NewEncoder(w).Encode(map[string]any{
"id": "c1", "name": "updated", "description": "new desc",
"created_at": "t", "modified_at": "t",
})
}))
defer server.Close()
name := "updated"
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.UpdateConnector(context.Background(), "c1", &connector.UpdateRequest{
Name: &name,
})
if err != nil {
t.Fatal(err)
}
if resp.Name != "updated" {
t.Errorf("got name %q", resp.Name)
}
}
func TestDeleteConnector_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "DELETE" {
t.Errorf("expected DELETE, got %s", r.Method)
}
w.WriteHeader(http.StatusNoContent)
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
err := client.DeleteConnector(context.Background(), "c1")
if err != nil {
t.Fatal(err)
}
}
func TestGetConnectorAuthURL_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/connectors/c1/auth_url" {
t.Errorf("got path %s", r.URL.Path)
}
json.NewEncoder(w).Encode(map[string]any{
"auth_url": "https://oauth.example.com/authorize",
"ttl": 3600,
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.GetConnectorAuthURL(context.Background(), "c1", nil)
if err != nil {
t.Fatal(err)
}
if resp.AuthURL != "https://oauth.example.com/authorize" {
t.Errorf("got auth_url %q", resp.AuthURL)
}
if resp.TTL != 3600 {
t.Errorf("got ttl %d", resp.TTL)
}
}
func TestListConnectorTools_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/connectors/c1/tools" {
t.Errorf("got path %s", r.URL.Path)
}
json.NewEncoder(w).Encode([]map[string]any{
{"id": "t1", "name": "search", "description": "search the web"},
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
tools, err := client.ListConnectorTools(context.Background(), "c1", nil)
if err != nil {
t.Fatal(err)
}
if len(tools) != 1 {
t.Fatalf("got %d tools", len(tools))
}
if tools[0].Name != "search" {
t.Errorf("got name %q", tools[0].Name)
}
}
func TestCallConnectorTool_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" {
t.Errorf("expected POST, got %s", r.Method)
}
if r.URL.Path != "/v1/connectors/c1/tools/search/call" {
t.Errorf("got path %s", r.URL.Path)
}
var body map[string]any
json.NewDecoder(r.Body).Decode(&body)
args := body["arguments"].(map[string]any)
if args["query"] != "hello" {
t.Errorf("got query %v", args["query"])
}
json.NewEncoder(w).Encode(map[string]any{
"content": []map[string]any{{"type": "text", "text": "result"}},
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.CallConnectorTool(context.Background(), "c1", "search", &connector.CallToolRequest{
Arguments: map[string]any{"query": "hello"},
})
if err != nil {
t.Fatal(err)
}
if resp.Content == nil {
t.Error("expected non-nil content")
}
}

View File

@@ -1,6 +1,10 @@
package conversation
import "encoding/json"
import (
"encoding/json"
"somegit.dev/vikingowl/mistral-go-sdk/chat"
)
// StartRequest starts a new conversation.
type StartRequest struct {
@@ -11,9 +15,10 @@ type StartRequest struct {
Instructions *string `json:"instructions,omitempty"`
Tools []Tool `json:"tools,omitempty"`
CompletionArgs *CompletionArgs `json:"completion_args,omitempty"`
Store *bool `json:"store,omitempty"`
HandoffExecution *HandoffExecution `json:"handoff_execution,omitempty"`
Name *string `json:"name,omitempty"`
Store *bool `json:"store,omitempty"`
HandoffExecution *HandoffExecution `json:"handoff_execution,omitempty"`
Guardrails []chat.GuardrailConfig `json:"guardrails,omitempty"`
Name *string `json:"name,omitempty"`
Description *string `json:"description,omitempty"`
Metadata map[string]any `json:"metadata,omitempty"`
stream bool
@@ -61,13 +66,14 @@ func (r *AppendRequest) MarshalJSON() ([]byte, error) {
// RestartRequest restarts a conversation from a specific entry.
type RestartRequest struct {
Inputs Inputs `json:"inputs"`
FromEntryID string `json:"from_entry_id"`
CompletionArgs *CompletionArgs `json:"completion_args,omitempty"`
Store *bool `json:"store,omitempty"`
HandoffExecution *HandoffExecution `json:"handoff_execution,omitempty"`
AgentVersion json.RawMessage `json:"agent_version,omitempty"`
Metadata map[string]any `json:"metadata,omitempty"`
Inputs Inputs `json:"inputs"`
FromEntryID string `json:"from_entry_id"`
CompletionArgs *CompletionArgs `json:"completion_args,omitempty"`
Store *bool `json:"store,omitempty"`
HandoffExecution *HandoffExecution `json:"handoff_execution,omitempty"`
Guardrails []chat.GuardrailConfig `json:"guardrails,omitempty"`
AgentVersion json.RawMessage `json:"agent_version,omitempty"`
Metadata map[string]any `json:"metadata,omitempty"`
stream bool
}

12
doc.go
View File

@@ -39,7 +39,13 @@
// # Sub-packages
//
// Types are organized into sub-packages by domain: [chat], [agents],
// [conversation], [embedding], [model], [file], [finetune], [batch],
// [ocr], [audio], [library], [moderation], [classification], and [fim].
// All service methods live directly on [Client].
// [connector], [conversation], [embedding], [model], [file], [finetune],
// [batch], [ocr], [audio], [library], [moderation], [classification],
// [fim], and [observability]. All service methods live directly on [Client].
//
// # Reference
//
// This SDK tracks the official Mistral Python SDK
// (https://github.com/mistralai/client-python) as its upstream reference
// for API surface and type definitions.
package mistral

File diff suppressed because it is too large Load Diff

View File

@@ -64,6 +64,23 @@ type SignedURL struct {
URL string `json:"url"`
}
// Visibility controls who can see a file.
type Visibility string
const (
VisibilitySharedGlobal Visibility = "shared_global"
VisibilitySharedOrg Visibility = "shared_org"
VisibilitySharedWorkspace Visibility = "shared_workspace"
VisibilityPrivate Visibility = "private"
)
// UploadParams holds parameters for uploading a file.
type UploadParams struct {
Purpose Purpose
Expiry *int
Visibility *Visibility
}
// ListParams holds optional parameters for listing files.
type ListParams struct {
Page *int

View File

@@ -12,10 +12,18 @@ import (
)
// UploadFile uploads a file for use with fine-tuning, batch, or OCR.
func (c *Client) UploadFile(ctx context.Context, filename string, r io.Reader, purpose file.Purpose) (*file.File, error) {
func (c *Client) UploadFile(ctx context.Context, filename string, r io.Reader, params *file.UploadParams) (*file.File, error) {
fields := map[string]string{}
if purpose != "" {
fields["purpose"] = string(purpose)
if params != nil {
if params.Purpose != "" {
fields["purpose"] = string(params.Purpose)
}
if params.Expiry != nil {
fields["expiry"] = strconv.Itoa(*params.Expiry)
}
if params.Visibility != nil {
fields["visibility"] = string(*params.Visibility)
}
}
var resp file.File
if err := c.doMultipart(ctx, "/v1/files", filename, r, fields, &resp); err != nil {

View File

@@ -58,7 +58,7 @@ func TestUploadFile_Success(t *testing.T) {
context.Background(),
"train.jsonl",
strings.NewReader(`{"text":"hello"}`),
file.PurposeFineTune,
&file.UploadParams{Purpose: file.PurposeFineTune},
)
if err != nil {
t.Fatal(err)
@@ -74,6 +74,43 @@ func TestUploadFile_Success(t *testing.T) {
}
}
func TestUploadFile_WithExpiryAndVisibility(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if err := r.ParseMultipartForm(10 << 20); err != nil {
t.Fatal(err)
}
if r.FormValue("purpose") != "fine-tune" {
t.Errorf("got purpose %q", r.FormValue("purpose"))
}
if r.FormValue("expiry") != "48" {
t.Errorf("expected expiry=48, got %q", r.FormValue("expiry"))
}
if r.FormValue("visibility") != "private" {
t.Errorf("expected visibility=private, got %q", r.FormValue("visibility"))
}
json.NewEncoder(w).Encode(map[string]any{
"id": "file-ev", "object": "file", "bytes": 10,
"created_at": 1, "filename": "data.jsonl",
"purpose": "fine-tune", "sample_type": "instruct",
"source": "upload",
})
}))
defer server.Close()
expiry := 48
vis := file.VisibilityPrivate
client := NewClient("key", WithBaseURL(server.URL))
_, err := client.UploadFile(context.Background(), "data.jsonl", strings.NewReader("{}"), &file.UploadParams{
Purpose: file.PurposeFineTune,
Expiry: &expiry,
Visibility: &vis,
})
if err != nil {
t.Fatal(err)
}
}
func TestListFiles_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "GET" {
@@ -260,7 +297,7 @@ func TestUploadFile_Error(t *testing.T) {
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
_, err := client.UploadFile(context.Background(), "bad.txt", strings.NewReader(""), file.PurposeFineTune)
_, err := client.UploadFile(context.Background(), "bad.txt", strings.NewReader(""), &file.UploadParams{Purpose: file.PurposeFineTune})
if err == nil {
t.Fatal("expected error")
}

View File

@@ -24,7 +24,7 @@ func integrationClient(t *testing.T) *Client {
func TestIntegration_ListModels(t *testing.T) {
client := integrationClient(t)
resp, err := client.ListModels(context.Background())
resp, err := client.ListModels(context.Background(), nil)
if err != nil {
t.Fatal(err)
}

View File

@@ -6,7 +6,7 @@ import (
)
// Version is the SDK version string.
const Version = "0.1.0"
const Version = "1.1.0"
const (
defaultBaseURL = "https://api.mistral.ai"

View File

@@ -34,8 +34,10 @@ type ModelCapabilities struct {
OCR bool `json:"ocr"`
Classification bool `json:"classification"`
Moderation bool `json:"moderation"`
Audio bool `json:"audio"`
AudioTranscription bool `json:"audio_transcription"`
Audio bool `json:"audio"`
AudioTranscription bool `json:"audio_transcription"`
AudioTranscriptionRealtime bool `json:"audio_transcription_realtime"`
AudioSpeech bool `json:"audio_speech"`
}
// ModelList is the response from listing models.
@@ -50,3 +52,9 @@ type DeleteModelOut struct {
Object string `json:"object"`
Deleted bool `json:"deleted"`
}
// ListParams holds optional parameters for listing models.
type ListParams struct {
Provider *string
Model *string
}

View File

@@ -2,14 +2,28 @@ package mistral
import (
"context"
"net/url"
"somegit.dev/vikingowl/mistral-go-sdk/model"
)
// ListModels returns a list of available models.
func (c *Client) ListModels(ctx context.Context) (*model.ModelList, error) {
func (c *Client) ListModels(ctx context.Context, params *model.ListParams) (*model.ModelList, error) {
path := "/v1/models"
if params != nil {
q := url.Values{}
if params.Provider != nil {
q.Set("provider", *params.Provider)
}
if params.Model != nil {
q.Set("model", *params.Model)
}
if encoded := q.Encode(); encoded != "" {
path += "?" + encoded
}
}
var resp model.ModelList
if err := c.doJSON(ctx, "GET", "/v1/models", nil, &resp); err != nil {
if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil {
return nil, err
}
return &resp, nil

View File

@@ -6,6 +6,8 @@ import (
"net/http"
"net/http/httptest"
"testing"
"somegit.dev/vikingowl/mistral-go-sdk/model"
)
func TestListModels_Success(t *testing.T) {
@@ -45,7 +47,7 @@ func TestListModels_Success(t *testing.T) {
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
list, err := client.ListModels(context.Background())
list, err := client.ListModels(context.Background(), nil)
if err != nil {
t.Fatal(err)
}
@@ -82,6 +84,33 @@ func TestListModels_Success(t *testing.T) {
}
}
func TestListModels_WithParams(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Query().Get("provider") != "mistralai" {
t.Errorf("expected provider=mistralai, got %q", r.URL.Query().Get("provider"))
}
if r.URL.Query().Get("model") != "mistral-small" {
t.Errorf("expected model=mistral-small, got %q", r.URL.Query().Get("model"))
}
json.NewEncoder(w).Encode(map[string]any{
"object": "list",
"data": []map[string]any{},
})
}))
defer server.Close()
provider := "mistralai"
modelName := "mistral-small"
client := NewClient("key", WithBaseURL(server.URL))
_, err := client.ListModels(context.Background(), &model.ListParams{
Provider: &provider,
Model: &modelName,
})
if err != nil {
t.Fatal(err)
}
}
func TestGetModel_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/models/mistral-small-latest" {

46
observability/campaign.go Normal file
View File

@@ -0,0 +1,46 @@
package observability
// Campaign represents an observability campaign.
type Campaign struct {
ID string `json:"id"`
CreatedAt string `json:"created_at"`
UpdatedAt string `json:"updated_at"`
DeletedAt *string `json:"deleted_at,omitempty"`
Name string `json:"name"`
OwnerID string `json:"owner_id"`
WorkspaceID string `json:"workspace_id"`
Description string `json:"description"`
MaxNbEvents int `json:"max_nb_events"`
SearchParams FilterPayload `json:"search_params"`
Judge Judge `json:"judge"`
}
// CreateCampaignRequest creates a new campaign.
type CreateCampaignRequest struct {
SearchParams FilterPayload `json:"search_params"`
JudgeID string `json:"judge_id"`
Name string `json:"name"`
Description string `json:"description"`
MaxNbEvents int `json:"max_nb_events"`
}
// CampaignStatusResponse is the response for campaign status.
type CampaignStatusResponse struct {
Status TaskStatus `json:"status"`
}
// ListCampaignsResponse is the response from listing campaigns.
type ListCampaignsResponse struct {
Count int `json:"count"`
Results []Campaign `json:"results,omitempty"`
Next *string `json:"next,omitempty"`
Previous *string `json:"previous,omitempty"`
}
// ListCampaignEventsResponse is the response from listing campaign events.
type ListCampaignEventsResponse struct {
Count int `json:"count"`
Results []ChatCompletionEventPreview `json:"results,omitempty"`
Next *string `json:"next,omitempty"`
Previous *string `json:"previous,omitempty"`
}

156
observability/dataset.go Normal file
View File

@@ -0,0 +1,156 @@
package observability
import "encoding/json"
// ConversationSource indicates how a dataset record was created.
type ConversationSource string
const (
SourceExplorer ConversationSource = "EXPLORER"
SourceUploadedFile ConversationSource = "UPLOADED_FILE"
SourceDirectInput ConversationSource = "DIRECT_INPUT"
SourcePlayground ConversationSource = "PLAYGROUND"
)
// Dataset represents a dataset entity.
type Dataset struct {
ID string `json:"id"`
CreatedAt string `json:"created_at"`
UpdatedAt string `json:"updated_at"`
DeletedAt *string `json:"deleted_at,omitempty"`
Name string `json:"name"`
Description string `json:"description"`
OwnerID string `json:"owner_id"`
WorkspaceID string `json:"workspace_id"`
}
// CreateDatasetRequest creates a new dataset.
type CreateDatasetRequest struct {
Name string `json:"name"`
Description string `json:"description"`
}
// UpdateDatasetRequest updates a dataset.
type UpdateDatasetRequest struct {
Name *string `json:"name,omitempty"`
Description *string `json:"description,omitempty"`
}
// DatasetRecord is a single record in a dataset.
type DatasetRecord struct {
ID string `json:"id"`
CreatedAt string `json:"created_at"`
UpdatedAt string `json:"updated_at"`
DeletedAt *string `json:"deleted_at,omitempty"`
DatasetID string `json:"dataset_id"`
Payload ConversationPayload `json:"payload"`
Properties map[string]any `json:"properties,omitempty"`
Source ConversationSource `json:"source"`
}
// ConversationPayload holds the messages for a dataset record.
type ConversationPayload struct {
Messages []map[string]any `json:"messages"`
}
// CreateRecordRequest creates a new dataset record.
type CreateRecordRequest struct {
Payload ConversationPayload `json:"payload"`
Properties map[string]any `json:"properties"`
}
// UpdateRecordPayloadRequest updates a record's payload.
type UpdateRecordPayloadRequest struct {
Payload ConversationPayload `json:"payload"`
}
// UpdateRecordPropertiesRequest updates a record's properties.
type UpdateRecordPropertiesRequest struct {
Properties map[string]any `json:"properties"`
}
// BulkDeleteRecordsRequest deletes multiple records.
type BulkDeleteRecordsRequest struct {
DatasetRecordIDs []string `json:"dataset_record_ids"`
}
// JudgeRecordRequest judges a dataset record.
type JudgeRecordRequest struct {
JudgeDefinition CreateJudgeRequest `json:"judge_definition"`
}
// DatasetImportTask tracks an async import operation.
type DatasetImportTask struct {
ID string `json:"id"`
CreatedAt string `json:"created_at"`
UpdatedAt string `json:"updated_at"`
DeletedAt *string `json:"deleted_at,omitempty"`
CreatorID string `json:"creator_id"`
DatasetID string `json:"dataset_id"`
WorkspaceID string `json:"workspace_id"`
Status TaskStatus `json:"status"`
Progress *int `json:"progress,omitempty"`
Message *string `json:"message,omitempty"`
}
// ExportDatasetResponse is the response from exporting a dataset.
type ExportDatasetResponse struct {
FileURL string `json:"file_url"`
}
// Import request types.
// ImportFromCampaignRequest imports records from a campaign.
type ImportFromCampaignRequest struct {
CampaignID string `json:"campaign_id"`
}
// ImportFromExplorerRequest imports records from explorer events.
type ImportFromExplorerRequest struct {
CompletionEventIDs []string `json:"completion_event_ids"`
}
// ImportFromFileRequest imports records from a file.
type ImportFromFileRequest struct {
FileID string `json:"file_id"`
}
// ImportFromPlaygroundRequest imports records from playground conversations.
type ImportFromPlaygroundRequest struct {
ConversationIDs []string `json:"conversation_ids"`
}
// ImportFromDatasetRequest imports records from another dataset.
type ImportFromDatasetRequest struct {
DatasetRecordIDs []string `json:"dataset_record_ids"`
}
// List response types.
// ListDatasetsResponse is the response from listing datasets.
type ListDatasetsResponse struct {
Count int `json:"count"`
Results []Dataset `json:"results,omitempty"`
Next *string `json:"next,omitempty"`
Previous *string `json:"previous,omitempty"`
}
// ListRecordsResponse is the response from listing dataset records.
type ListRecordsResponse struct {
Count int `json:"count"`
Results []DatasetRecord `json:"results,omitempty"`
Next *string `json:"next,omitempty"`
Previous *string `json:"previous,omitempty"`
}
// ListTasksResponse is the response from listing import tasks.
type ListTasksResponse struct {
Count int `json:"count"`
Results []DatasetImportTask `json:"results,omitempty"`
Next *string `json:"next,omitempty"`
Previous *string `json:"previous,omitempty"`
}
// JudgeResultResponse is the raw response from judging operations.
// The shape depends on the judge type (classification or regression).
type JudgeResultResponse json.RawMessage

5
observability/doc.go Normal file
View File

@@ -0,0 +1,5 @@
// Package observability provides types for the Mistral observability API (beta).
//
// This includes campaigns, chat completion events, judges, and datasets
// for monitoring and evaluating model behavior.
package observability

70
observability/event.go Normal file
View File

@@ -0,0 +1,70 @@
package observability
// ChatCompletionEvent is a full chat completion event.
type ChatCompletionEvent struct {
EventID string `json:"event_id"`
CorrelationID string `json:"correlation_id"`
CreatedAt string `json:"created_at"`
ExtraFields map[string]any `json:"extra_fields,omitempty"`
NbInputTokens int `json:"nb_input_tokens"`
NbOutputTokens int `json:"nb_output_tokens"`
EnabledTools []map[string]any `json:"enabled_tools,omitempty"`
RequestMessages []map[string]any `json:"request_messages,omitempty"`
ResponseMessages []map[string]any `json:"response_messages,omitempty"`
NbMessages int `json:"nb_messages"`
ChatTranscriptionEvents []ChatTranscriptionEvent `json:"chat_transcription_events,omitempty"`
}
// ChatCompletionEventPreview is a summary of a chat completion event.
type ChatCompletionEventPreview struct {
EventID string `json:"event_id"`
CorrelationID string `json:"correlation_id"`
CreatedAt string `json:"created_at"`
ExtraFields map[string]any `json:"extra_fields,omitempty"`
NbInputTokens int `json:"nb_input_tokens"`
NbOutputTokens int `json:"nb_output_tokens"`
}
// ChatTranscriptionEvent is an audio transcription within a chat event.
type ChatTranscriptionEvent struct {
AudioURL string `json:"audio_url"`
Model string `json:"model"`
ResponseMessage map[string]any `json:"response_message"`
}
// SearchEventsRequest is the request body for searching chat completion events.
type SearchEventsRequest struct {
SearchParams FilterPayload `json:"search_params"`
ExtraFields []string `json:"extra_fields,omitempty"`
}
// SearchEventsResponse is the response from searching events.
type SearchEventsResponse struct {
Results []ChatCompletionEventPreview `json:"results,omitempty"`
Next *string `json:"next,omitempty"`
Cursor *string `json:"cursor,omitempty"`
}
// SearchEventIDsRequest is the request body for searching event IDs.
type SearchEventIDsRequest struct {
SearchParams FilterPayload `json:"search_params"`
ExtraFields []string `json:"extra_fields,omitempty"`
}
// SearchEventIDsResponse is the response from searching event IDs.
type SearchEventIDsResponse struct {
CompletionEventIDs []string `json:"completion_event_ids"`
}
// JudgeEventRequest is the request body for judging a chat completion event.
type JudgeEventRequest struct {
JudgeDefinition CreateJudgeRequest `json:"judge_definition"`
}
// SimilarEventsResponse is the response from fetching similar events.
type SimilarEventsResponse struct {
Count int `json:"count"`
Results []ChatCompletionEventPreview `json:"results,omitempty"`
Next *string `json:"next,omitempty"`
Previous *string `json:"previous,omitempty"`
}

75
observability/filter.go Normal file
View File

@@ -0,0 +1,75 @@
package observability
import "encoding/json"
// Op is a filter comparison operator.
type Op string
const (
OpLt Op = "lt"
OpLte Op = "lte"
OpGt Op = "gt"
OpGte Op = "gte"
OpEq Op = "eq"
OpNeq Op = "neq"
OpIsNull Op = "isnull"
OpStartsWith Op = "startswith"
OpIStartsWith Op = "istartswith"
OpEndsWith Op = "endswith"
OpIEndsWith Op = "iendswith"
OpContains Op = "contains"
OpIContains Op = "icontains"
OpMatches Op = "matches"
OpNotContains Op = "notcontains"
OpINotContains Op = "inotcontains"
OpIncludes Op = "includes"
OpExcludes Op = "excludes"
OpLenEq Op = "len_eq"
)
// FilterCondition is a single filter comparison.
type FilterCondition struct {
Field string `json:"field"`
Op Op `json:"op"`
Value any `json:"value"`
}
// FilterGroup combines filters with AND/OR logic.
// The JSON keys are uppercase "AND" / "OR".
type FilterGroup struct {
AND []json.RawMessage `json:"AND,omitempty"`
OR []json.RawMessage `json:"OR,omitempty"`
}
// FilterPayload wraps the top-level filter for search operations.
// Filters can be a FilterGroup or a FilterCondition.
type FilterPayload struct {
Filters json.RawMessage `json:"filters,omitempty"`
}
// TaskStatus is the status of an async task.
type TaskStatus string
const (
TaskStatusRunning TaskStatus = "RUNNING"
TaskStatusCompleted TaskStatus = "COMPLETED"
TaskStatusFailed TaskStatus = "FAILED"
TaskStatusCanceled TaskStatus = "CANCELED"
TaskStatusTerminated TaskStatus = "TERMINATED"
TaskStatusContinuedAsNew TaskStatus = "CONTINUED_AS_NEW"
TaskStatusTimedOut TaskStatus = "TIMED_OUT"
TaskStatusUnknown TaskStatus = "UNKNOWN"
)
// PaginationParams holds common pagination query parameters.
type PaginationParams struct {
Page *int
PageSize *int
}
// SearchParams holds common search query parameters.
type SearchParams struct {
Page *int
PageSize *int
Q *string
}

114
observability/judge.go Normal file
View File

@@ -0,0 +1,114 @@
package observability
import (
"encoding/json"
"fmt"
)
// JudgeOutputType identifies the kind of judge output.
type JudgeOutputType string
const (
JudgeOutputClassification JudgeOutputType = "CLASSIFICATION"
JudgeOutputRegression JudgeOutputType = "REGRESSION"
)
// JudgeOutput is a sealed interface for judge output configurations.
type JudgeOutput interface {
judgeOutputType() JudgeOutputType
}
// ClassificationOutput configures a classification judge.
type ClassificationOutput struct {
Type JudgeOutputType `json:"type"`
Options []ClassificationOption `json:"options"`
}
func (*ClassificationOutput) judgeOutputType() JudgeOutputType { return JudgeOutputClassification }
// ClassificationOption is a single option for a classification judge.
type ClassificationOption struct {
Value string `json:"value"`
Description string `json:"description"`
}
// RegressionOutput configures a regression judge.
type RegressionOutput struct {
Type JudgeOutputType `json:"type"`
MinDescription string `json:"min_description"`
MaxDescription string `json:"max_description"`
Min *float64 `json:"min,omitempty"`
Max *float64 `json:"max,omitempty"`
}
func (*RegressionOutput) judgeOutputType() JudgeOutputType { return JudgeOutputRegression }
// UnmarshalJudgeOutput dispatches to the concrete JudgeOutput type.
func UnmarshalJudgeOutput(data []byte) (JudgeOutput, error) {
var probe struct {
Type string `json:"type"`
}
if err := json.Unmarshal(data, &probe); err != nil {
return nil, fmt.Errorf("unmarshal judge output: %w", err)
}
switch JudgeOutputType(probe.Type) {
case JudgeOutputClassification:
var o ClassificationOutput
return &o, json.Unmarshal(data, &o)
case JudgeOutputRegression:
var o RegressionOutput
return &o, json.Unmarshal(data, &o)
default:
return nil, fmt.Errorf("unknown judge output type: %q", probe.Type)
}
}
// Judge represents a judge entity.
type Judge struct {
ID string `json:"id"`
CreatedAt string `json:"created_at"`
UpdatedAt string `json:"updated_at"`
DeletedAt *string `json:"deleted_at,omitempty"`
OwnerID string `json:"owner_id"`
WorkspaceID string `json:"workspace_id"`
Name string `json:"name"`
Description string `json:"description"`
ModelName string `json:"model_name"`
Output json.RawMessage `json:"output"`
Instructions string `json:"instructions"`
Tools []string `json:"tools,omitempty"`
}
// CreateJudgeRequest creates a new judge.
type CreateJudgeRequest struct {
Name string `json:"name"`
Description string `json:"description"`
ModelName string `json:"model_name"`
Output json.RawMessage `json:"output"`
Instructions string `json:"instructions"`
Tools []string `json:"tools"`
}
// UpdateJudgeRequest updates a judge.
type UpdateJudgeRequest struct {
Name string `json:"name"`
Description string `json:"description"`
ModelName string `json:"model_name"`
Output json.RawMessage `json:"output"`
Instructions string `json:"instructions"`
Tools []string `json:"tools"`
}
// JudgeConversationRequest is the request for live-judging a conversation.
type JudgeConversationRequest struct {
Messages []map[string]any `json:"messages"`
Properties map[string]any `json:"properties,omitempty"`
}
// ListJudgesResponse is the response from listing judges.
type ListJudgesResponse struct {
Count int `json:"count"`
Results []Judge `json:"results,omitempty"`
Next *string `json:"next,omitempty"`
Previous *string `json:"previous,omitempty"`
}

View File

@@ -0,0 +1,97 @@
package mistral
import (
"context"
"fmt"
"net/url"
"strconv"
"somegit.dev/vikingowl/mistral-go-sdk/observability"
)
// CreateCampaign creates a new observability campaign.
func (c *Client) CreateCampaign(ctx context.Context, req *observability.CreateCampaignRequest) (*observability.Campaign, error) {
var resp observability.Campaign
if err := c.doJSON(ctx, "POST", "/v1/observability/campaigns", req, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// ListCampaigns lists observability campaigns.
func (c *Client) ListCampaigns(ctx context.Context, params *observability.SearchParams) (*observability.ListCampaignsResponse, error) {
path := "/v1/observability/campaigns"
if params != nil {
q := url.Values{}
if params.PageSize != nil {
q.Set("page_size", strconv.Itoa(*params.PageSize))
}
if params.Page != nil {
q.Set("page", strconv.Itoa(*params.Page))
}
if params.Q != nil {
q.Set("q", *params.Q)
}
if encoded := q.Encode(); encoded != "" {
path += "?" + encoded
}
}
var resp observability.ListCampaignsResponse
if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// GetCampaign retrieves a campaign by ID.
func (c *Client) GetCampaign(ctx context.Context, campaignID string) (*observability.Campaign, error) {
var resp observability.Campaign
if err := c.doJSON(ctx, "GET", fmt.Sprintf("/v1/observability/campaigns/%s", campaignID), nil, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// DeleteCampaign deletes a campaign.
func (c *Client) DeleteCampaign(ctx context.Context, campaignID string) error {
resp, err := c.do(ctx, "DELETE", fmt.Sprintf("/v1/observability/campaigns/%s", campaignID), nil)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode >= 400 {
return parseAPIError(resp)
}
return nil
}
// GetCampaignStatus retrieves the status of a campaign.
func (c *Client) GetCampaignStatus(ctx context.Context, campaignID string) (*observability.CampaignStatusResponse, error) {
var resp observability.CampaignStatusResponse
if err := c.doJSON(ctx, "GET", fmt.Sprintf("/v1/observability/campaigns/%s/status", campaignID), nil, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// ListCampaignEvents lists events selected by a campaign.
func (c *Client) ListCampaignEvents(ctx context.Context, campaignID string, params *observability.PaginationParams) (*observability.ListCampaignEventsResponse, error) {
path := fmt.Sprintf("/v1/observability/campaigns/%s/selected-events", campaignID)
if params != nil {
q := url.Values{}
if params.PageSize != nil {
q.Set("page_size", strconv.Itoa(*params.PageSize))
}
if params.Page != nil {
q.Set("page", strconv.Itoa(*params.Page))
}
if encoded := q.Encode(); encoded != "" {
path += "?" + encoded
}
}
var resp observability.ListCampaignEventsResponse
if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil {
return nil, err
}
return &resp, nil
}

View File

@@ -0,0 +1,148 @@
package mistral
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"somegit.dev/vikingowl/mistral-go-sdk/observability"
)
func TestCreateCampaign_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" || r.URL.Path != "/v1/observability/campaigns" {
t.Errorf("unexpected %s %s", r.Method, r.URL.Path)
}
var body map[string]any
json.NewDecoder(r.Body).Decode(&body)
if body["name"] != "test-campaign" {
t.Errorf("got name %v", body["name"])
}
w.WriteHeader(http.StatusCreated)
json.NewEncoder(w).Encode(map[string]any{
"id": "camp-1", "name": "test-campaign", "description": "d",
"created_at": "t", "updated_at": "t", "owner_id": "o",
"workspace_id": "w", "max_nb_events": 100,
"search_params": map[string]any{}, "judge": map[string]any{"id": "j1"},
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.CreateCampaign(context.Background(), &observability.CreateCampaignRequest{
Name: "test-campaign",
Description: "d",
JudgeID: "j1",
MaxNbEvents: 100,
})
if err != nil {
t.Fatal(err)
}
if resp.ID != "camp-1" {
t.Errorf("got id %q", resp.ID)
}
}
func TestListCampaigns_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/observability/campaigns" {
t.Errorf("got path %s", r.URL.Path)
}
json.NewEncoder(w).Encode(map[string]any{
"count": 1,
"results": []map[string]any{{"id": "c1", "name": "c", "description": "d", "created_at": "t", "updated_at": "t", "owner_id": "o", "workspace_id": "w", "max_nb_events": 10, "search_params": map[string]any{}, "judge": map[string]any{"id": "j"}}},
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.ListCampaigns(context.Background(), nil)
if err != nil {
t.Fatal(err)
}
if resp.Count != 1 {
t.Errorf("got count %d", resp.Count)
}
}
func TestGetCampaign_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/observability/campaigns/camp-1" {
t.Errorf("got path %s", r.URL.Path)
}
json.NewEncoder(w).Encode(map[string]any{
"id": "camp-1", "name": "c", "description": "d",
"created_at": "t", "updated_at": "t", "owner_id": "o",
"workspace_id": "w", "max_nb_events": 10,
"search_params": map[string]any{}, "judge": map[string]any{"id": "j"},
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.GetCampaign(context.Background(), "camp-1")
if err != nil {
t.Fatal(err)
}
if resp.ID != "camp-1" {
t.Errorf("got id %q", resp.ID)
}
}
func TestDeleteCampaign_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "DELETE" {
t.Errorf("expected DELETE")
}
w.WriteHeader(http.StatusNoContent)
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
if err := client.DeleteCampaign(context.Background(), "camp-1"); err != nil {
t.Fatal(err)
}
}
func TestGetCampaignStatus_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/observability/campaigns/camp-1/status" {
t.Errorf("got path %s", r.URL.Path)
}
json.NewEncoder(w).Encode(map[string]any{"status": "COMPLETED"})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.GetCampaignStatus(context.Background(), "camp-1")
if err != nil {
t.Fatal(err)
}
if resp.Status != observability.TaskStatusCompleted {
t.Errorf("got status %q", resp.Status)
}
}
func TestListCampaignEvents_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/observability/campaigns/camp-1/selected-events" {
t.Errorf("got path %s", r.URL.Path)
}
json.NewEncoder(w).Encode(map[string]any{
"count": 0,
"results": []any{},
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.ListCampaignEvents(context.Background(), "camp-1", nil)
if err != nil {
t.Fatal(err)
}
if resp.Count != 0 {
t.Errorf("got count %d", resp.Count)
}
}

252
observability_datasets.go Normal file
View File

@@ -0,0 +1,252 @@
package mistral
import (
"context"
"encoding/json"
"fmt"
"net/url"
"strconv"
"somegit.dev/vikingowl/mistral-go-sdk/observability"
)
// CreateDataset creates a new observability dataset.
func (c *Client) CreateDataset(ctx context.Context, req *observability.CreateDatasetRequest) (*observability.Dataset, error) {
var resp observability.Dataset
if err := c.doJSON(ctx, "POST", "/v1/observability/datasets", req, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// ListDatasets lists observability datasets.
func (c *Client) ListDatasets(ctx context.Context, params *observability.SearchParams) (*observability.ListDatasetsResponse, error) {
path := "/v1/observability/datasets"
if params != nil {
q := url.Values{}
if params.PageSize != nil {
q.Set("page_size", strconv.Itoa(*params.PageSize))
}
if params.Page != nil {
q.Set("page", strconv.Itoa(*params.Page))
}
if params.Q != nil {
q.Set("q", *params.Q)
}
if encoded := q.Encode(); encoded != "" {
path += "?" + encoded
}
}
var resp observability.ListDatasetsResponse
if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// GetDataset retrieves a dataset by ID.
func (c *Client) GetDataset(ctx context.Context, datasetID string) (*observability.Dataset, error) {
var resp observability.Dataset
if err := c.doJSON(ctx, "GET", fmt.Sprintf("/v1/observability/datasets/%s", datasetID), nil, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// UpdateDataset updates a dataset.
func (c *Client) UpdateDataset(ctx context.Context, datasetID string, req *observability.UpdateDatasetRequest) (*observability.Dataset, error) {
var resp observability.Dataset
if err := c.doJSON(ctx, "PATCH", fmt.Sprintf("/v1/observability/datasets/%s", datasetID), req, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// DeleteDataset deletes a dataset.
func (c *Client) DeleteDataset(ctx context.Context, datasetID string) error {
resp, err := c.do(ctx, "DELETE", fmt.Sprintf("/v1/observability/datasets/%s", datasetID), nil)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode >= 400 {
return parseAPIError(resp)
}
return nil
}
// ExportDatasetToJSONL exports a dataset to JSONL format.
func (c *Client) ExportDatasetToJSONL(ctx context.Context, datasetID string) (*observability.ExportDatasetResponse, error) {
var resp observability.ExportDatasetResponse
if err := c.doJSON(ctx, "GET", fmt.Sprintf("/v1/observability/datasets/%s/exports/to-jsonl", datasetID), nil, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// Dataset records
// ListDatasetRecords lists records in a dataset.
func (c *Client) ListDatasetRecords(ctx context.Context, datasetID string, params *observability.PaginationParams) (*observability.ListRecordsResponse, error) {
path := fmt.Sprintf("/v1/observability/datasets/%s/records", datasetID)
if params != nil {
q := url.Values{}
if params.PageSize != nil {
q.Set("page_size", strconv.Itoa(*params.PageSize))
}
if params.Page != nil {
q.Set("page", strconv.Itoa(*params.Page))
}
if encoded := q.Encode(); encoded != "" {
path += "?" + encoded
}
}
var resp observability.ListRecordsResponse
if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// CreateDatasetRecord creates a record in a dataset.
func (c *Client) CreateDatasetRecord(ctx context.Context, datasetID string, req *observability.CreateRecordRequest) (*observability.DatasetRecord, error) {
var resp observability.DatasetRecord
if err := c.doJSON(ctx, "POST", fmt.Sprintf("/v1/observability/datasets/%s/records", datasetID), req, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// GetDatasetRecord retrieves a dataset record by ID.
func (c *Client) GetDatasetRecord(ctx context.Context, recordID string) (*observability.DatasetRecord, error) {
var resp observability.DatasetRecord
if err := c.doJSON(ctx, "GET", fmt.Sprintf("/v1/observability/dataset-records/%s", recordID), nil, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// UpdateDatasetRecordPayload updates a record's payload.
func (c *Client) UpdateDatasetRecordPayload(ctx context.Context, recordID string, req *observability.UpdateRecordPayloadRequest) (*observability.DatasetRecord, error) {
var resp observability.DatasetRecord
if err := c.doJSON(ctx, "PUT", fmt.Sprintf("/v1/observability/dataset-records/%s/payload", recordID), req, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// UpdateDatasetRecordProperties updates a record's properties.
func (c *Client) UpdateDatasetRecordProperties(ctx context.Context, recordID string, req *observability.UpdateRecordPropertiesRequest) (*observability.DatasetRecord, error) {
var resp observability.DatasetRecord
if err := c.doJSON(ctx, "PUT", fmt.Sprintf("/v1/observability/dataset-records/%s/properties", recordID), req, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// DeleteDatasetRecord deletes a dataset record.
func (c *Client) DeleteDatasetRecord(ctx context.Context, recordID string) error {
resp, err := c.do(ctx, "DELETE", fmt.Sprintf("/v1/observability/dataset-records/%s", recordID), nil)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode >= 400 {
return parseAPIError(resp)
}
return nil
}
// BulkDeleteDatasetRecords deletes multiple dataset records.
func (c *Client) BulkDeleteDatasetRecords(ctx context.Context, req *observability.BulkDeleteRecordsRequest) error {
return c.doJSON(ctx, "POST", "/v1/observability/dataset-records/bulk-delete", req, nil)
}
// JudgeDatasetRecord judges a dataset record.
func (c *Client) JudgeDatasetRecord(ctx context.Context, recordID string, req *observability.JudgeRecordRequest) (json.RawMessage, error) {
var resp json.RawMessage
if err := c.doJSON(ctx, "POST", fmt.Sprintf("/v1/observability/dataset-records/%s/live-judging", recordID), req, &resp); err != nil {
return nil, err
}
return resp, nil
}
// Import operations
// ImportDatasetFromCampaign imports records from a campaign.
func (c *Client) ImportDatasetFromCampaign(ctx context.Context, datasetID string, req *observability.ImportFromCampaignRequest) (*observability.DatasetImportTask, error) {
var resp observability.DatasetImportTask
if err := c.doJSON(ctx, "POST", fmt.Sprintf("/v1/observability/datasets/%s/imports/from-campaign", datasetID), req, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// ImportDatasetFromExplorer imports records from explorer events.
func (c *Client) ImportDatasetFromExplorer(ctx context.Context, datasetID string, req *observability.ImportFromExplorerRequest) (*observability.DatasetImportTask, error) {
var resp observability.DatasetImportTask
if err := c.doJSON(ctx, "POST", fmt.Sprintf("/v1/observability/datasets/%s/imports/from-explorer", datasetID), req, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// ImportDatasetFromFile imports records from a file.
func (c *Client) ImportDatasetFromFile(ctx context.Context, datasetID string, req *observability.ImportFromFileRequest) (*observability.DatasetImportTask, error) {
var resp observability.DatasetImportTask
if err := c.doJSON(ctx, "POST", fmt.Sprintf("/v1/observability/datasets/%s/imports/from-file", datasetID), req, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// ImportDatasetFromPlayground imports records from playground conversations.
func (c *Client) ImportDatasetFromPlayground(ctx context.Context, datasetID string, req *observability.ImportFromPlaygroundRequest) (*observability.DatasetImportTask, error) {
var resp observability.DatasetImportTask
if err := c.doJSON(ctx, "POST", fmt.Sprintf("/v1/observability/datasets/%s/imports/from-playground", datasetID), req, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// ImportDatasetFromDataset imports records from another dataset.
func (c *Client) ImportDatasetFromDataset(ctx context.Context, datasetID string, req *observability.ImportFromDatasetRequest) (*observability.DatasetImportTask, error) {
var resp observability.DatasetImportTask
if err := c.doJSON(ctx, "POST", fmt.Sprintf("/v1/observability/datasets/%s/imports/from-dataset", datasetID), req, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// Tasks
// ListDatasetTasks lists import tasks for a dataset.
func (c *Client) ListDatasetTasks(ctx context.Context, datasetID string, params *observability.PaginationParams) (*observability.ListTasksResponse, error) {
path := fmt.Sprintf("/v1/observability/datasets/%s/tasks", datasetID)
if params != nil {
q := url.Values{}
if params.PageSize != nil {
q.Set("page_size", strconv.Itoa(*params.PageSize))
}
if params.Page != nil {
q.Set("page", strconv.Itoa(*params.Page))
}
if encoded := q.Encode(); encoded != "" {
path += "?" + encoded
}
}
var resp observability.ListTasksResponse
if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// GetDatasetTask retrieves an import task by ID.
func (c *Client) GetDatasetTask(ctx context.Context, datasetID, taskID string) (*observability.DatasetImportTask, error) {
var resp observability.DatasetImportTask
if err := c.doJSON(ctx, "GET", fmt.Sprintf("/v1/observability/datasets/%s/tasks/%s", datasetID, taskID), nil, &resp); err != nil {
return nil, err
}
return &resp, nil
}

View File

@@ -0,0 +1,211 @@
package mistral
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"somegit.dev/vikingowl/mistral-go-sdk/observability"
)
func datasetJSON() map[string]any {
return map[string]any{
"id": "ds-1", "created_at": "t", "updated_at": "t",
"name": "test-ds", "description": "d",
"owner_id": "o", "workspace_id": "w",
}
}
func TestCreateDataset_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" || r.URL.Path != "/v1/observability/datasets" {
t.Errorf("unexpected %s %s", r.Method, r.URL.Path)
}
w.WriteHeader(http.StatusCreated)
json.NewEncoder(w).Encode(datasetJSON())
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.CreateDataset(context.Background(), &observability.CreateDatasetRequest{
Name: "test-ds",
Description: "d",
})
if err != nil {
t.Fatal(err)
}
if resp.ID != "ds-1" {
t.Errorf("got id %q", resp.ID)
}
}
func TestListDatasets_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(map[string]any{
"count": 1,
"results": []any{datasetJSON()},
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.ListDatasets(context.Background(), nil)
if err != nil {
t.Fatal(err)
}
if resp.Count != 1 {
t.Errorf("got count %d", resp.Count)
}
}
func TestGetDataset_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/observability/datasets/ds-1" {
t.Errorf("got path %s", r.URL.Path)
}
json.NewEncoder(w).Encode(datasetJSON())
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.GetDataset(context.Background(), "ds-1")
if err != nil {
t.Fatal(err)
}
if resp.Name != "test-ds" {
t.Errorf("got name %q", resp.Name)
}
}
func TestDeleteDataset_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNoContent)
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
if err := client.DeleteDataset(context.Background(), "ds-1"); err != nil {
t.Fatal(err)
}
}
func TestCreateDatasetRecord_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" || r.URL.Path != "/v1/observability/datasets/ds-1/records" {
t.Errorf("unexpected %s %s", r.Method, r.URL.Path)
}
w.WriteHeader(http.StatusCreated)
json.NewEncoder(w).Encode(map[string]any{
"id": "rec-1", "created_at": "t", "updated_at": "t",
"dataset_id": "ds-1", "source": "DIRECT_INPUT",
"payload": map[string]any{"messages": []any{}},
"properties": map[string]any{},
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.CreateDatasetRecord(context.Background(), "ds-1", &observability.CreateRecordRequest{
Payload: observability.ConversationPayload{Messages: []map[string]any{}},
Properties: map[string]any{},
})
if err != nil {
t.Fatal(err)
}
if resp.ID != "rec-1" {
t.Errorf("got id %q", resp.ID)
}
}
func TestListDatasetRecords_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/observability/datasets/ds-1/records" {
t.Errorf("got path %s", r.URL.Path)
}
json.NewEncoder(w).Encode(map[string]any{
"count": 0, "results": []any{},
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.ListDatasetRecords(context.Background(), "ds-1", nil)
if err != nil {
t.Fatal(err)
}
if resp.Count != 0 {
t.Errorf("got count %d", resp.Count)
}
}
func TestImportDatasetFromCampaign_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/observability/datasets/ds-1/imports/from-campaign" {
t.Errorf("got path %s", r.URL.Path)
}
w.WriteHeader(http.StatusCreated)
json.NewEncoder(w).Encode(map[string]any{
"id": "task-1", "created_at": "t", "updated_at": "t",
"creator_id": "u", "dataset_id": "ds-1", "workspace_id": "w",
"status": "RUNNING",
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.ImportDatasetFromCampaign(context.Background(), "ds-1", &observability.ImportFromCampaignRequest{
CampaignID: "camp-1",
})
if err != nil {
t.Fatal(err)
}
if resp.Status != observability.TaskStatusRunning {
t.Errorf("got status %q", resp.Status)
}
}
func TestExportDatasetToJSONL_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/observability/datasets/ds-1/exports/to-jsonl" {
t.Errorf("got path %s", r.URL.Path)
}
json.NewEncoder(w).Encode(map[string]any{
"file_url": "https://storage.example.com/export.jsonl",
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.ExportDatasetToJSONL(context.Background(), "ds-1")
if err != nil {
t.Fatal(err)
}
if resp.FileURL != "https://storage.example.com/export.jsonl" {
t.Errorf("got file_url %q", resp.FileURL)
}
}
func TestGetDatasetTask_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/observability/datasets/ds-1/tasks/task-1" {
t.Errorf("got path %s", r.URL.Path)
}
json.NewEncoder(w).Encode(map[string]any{
"id": "task-1", "created_at": "t", "updated_at": "t",
"creator_id": "u", "dataset_id": "ds-1", "workspace_id": "w",
"status": "COMPLETED",
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.GetDatasetTask(context.Background(), "ds-1", "task-1")
if err != nil {
t.Fatal(err)
}
if resp.Status != observability.TaskStatusCompleted {
t.Errorf("got status %q", resp.Status)
}
}

69
observability_events.go Normal file
View File

@@ -0,0 +1,69 @@
package mistral
import (
"context"
"encoding/json"
"fmt"
"net/url"
"strconv"
"somegit.dev/vikingowl/mistral-go-sdk/observability"
)
// SearchChatCompletionEvents searches for chat completion events.
func (c *Client) SearchChatCompletionEvents(ctx context.Context, req *observability.SearchEventsRequest) (*observability.SearchEventsResponse, error) {
var resp observability.SearchEventsResponse
if err := c.doJSON(ctx, "POST", "/v1/observability/chat-completion-events/search", req, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// SearchChatCompletionEventIDs searches for chat completion event IDs.
func (c *Client) SearchChatCompletionEventIDs(ctx context.Context, req *observability.SearchEventIDsRequest) (*observability.SearchEventIDsResponse, error) {
var resp observability.SearchEventIDsResponse
if err := c.doJSON(ctx, "POST", "/v1/observability/chat-completion-events/search-ids", req, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// GetChatCompletionEvent retrieves a chat completion event by ID.
func (c *Client) GetChatCompletionEvent(ctx context.Context, eventID string) (*observability.ChatCompletionEvent, error) {
var resp observability.ChatCompletionEvent
if err := c.doJSON(ctx, "GET", fmt.Sprintf("/v1/observability/chat-completion-events/%s", eventID), nil, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// GetSimilarChatCompletionEvents retrieves events similar to a given event.
func (c *Client) GetSimilarChatCompletionEvents(ctx context.Context, eventID string, params *observability.PaginationParams) (*observability.SimilarEventsResponse, error) {
path := fmt.Sprintf("/v1/observability/chat-completion-events/%s/similar-events", eventID)
if params != nil {
q := url.Values{}
if params.PageSize != nil {
q.Set("page_size", strconv.Itoa(*params.PageSize))
}
if params.Page != nil {
q.Set("page", strconv.Itoa(*params.Page))
}
if encoded := q.Encode(); encoded != "" {
path += "?" + encoded
}
}
var resp observability.SimilarEventsResponse
if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// JudgeChatCompletionEvent judges a chat completion event.
func (c *Client) JudgeChatCompletionEvent(ctx context.Context, eventID string, req *observability.JudgeEventRequest) (json.RawMessage, error) {
var resp json.RawMessage
if err := c.doJSON(ctx, "POST", fmt.Sprintf("/v1/observability/chat-completion-events/%s/live-judging", eventID), req, &resp); err != nil {
return nil, err
}
return resp, nil
}

View File

@@ -0,0 +1,101 @@
package mistral
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"somegit.dev/vikingowl/mistral-go-sdk/observability"
)
func TestSearchChatCompletionEvents_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" || r.URL.Path != "/v1/observability/chat-completion-events/search" {
t.Errorf("unexpected %s %s", r.Method, r.URL.Path)
}
json.NewEncoder(w).Encode(map[string]any{
"results": []map[string]any{
{"event_id": "ev-1", "correlation_id": "c1", "created_at": "t", "nb_input_tokens": 10, "nb_output_tokens": 5},
},
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.SearchChatCompletionEvents(context.Background(), &observability.SearchEventsRequest{})
if err != nil {
t.Fatal(err)
}
if len(resp.Results) != 1 {
t.Fatalf("got %d results", len(resp.Results))
}
if resp.Results[0].EventID != "ev-1" {
t.Errorf("got event_id %q", resp.Results[0].EventID)
}
}
func TestSearchChatCompletionEventIDs_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/observability/chat-completion-events/search-ids" {
t.Errorf("got path %s", r.URL.Path)
}
json.NewEncoder(w).Encode(map[string]any{
"completion_event_ids": []string{"ev-1", "ev-2"},
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.SearchChatCompletionEventIDs(context.Background(), &observability.SearchEventIDsRequest{})
if err != nil {
t.Fatal(err)
}
if len(resp.CompletionEventIDs) != 2 {
t.Errorf("got %d ids", len(resp.CompletionEventIDs))
}
}
func TestGetChatCompletionEvent_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/observability/chat-completion-events/ev-1" {
t.Errorf("got path %s", r.URL.Path)
}
json.NewEncoder(w).Encode(map[string]any{
"event_id": "ev-1", "correlation_id": "c1", "created_at": "t",
"nb_input_tokens": 10, "nb_output_tokens": 5, "nb_messages": 2,
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.GetChatCompletionEvent(context.Background(), "ev-1")
if err != nil {
t.Fatal(err)
}
if resp.EventID != "ev-1" {
t.Errorf("got event_id %q", resp.EventID)
}
}
func TestGetSimilarChatCompletionEvents_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/observability/chat-completion-events/ev-1/similar-events" {
t.Errorf("got path %s", r.URL.Path)
}
json.NewEncoder(w).Encode(map[string]any{
"count": 0, "results": []any{},
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.GetSimilarChatCompletionEvents(context.Background(), "ev-1", nil)
if err != nil {
t.Fatal(err)
}
if resp.Count != 0 {
t.Errorf("got count %d", resp.Count)
}
}

85
observability_judges.go Normal file
View File

@@ -0,0 +1,85 @@
package mistral
import (
"context"
"encoding/json"
"fmt"
"net/url"
"strconv"
"somegit.dev/vikingowl/mistral-go-sdk/observability"
)
// CreateJudge creates a new observability judge.
func (c *Client) CreateJudge(ctx context.Context, req *observability.CreateJudgeRequest) (*observability.Judge, error) {
var resp observability.Judge
if err := c.doJSON(ctx, "POST", "/v1/observability/judges", req, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// ListJudges lists observability judges.
func (c *Client) ListJudges(ctx context.Context, params *observability.SearchParams) (*observability.ListJudgesResponse, error) {
path := "/v1/observability/judges"
if params != nil {
q := url.Values{}
if params.PageSize != nil {
q.Set("page_size", strconv.Itoa(*params.PageSize))
}
if params.Page != nil {
q.Set("page", strconv.Itoa(*params.Page))
}
if params.Q != nil {
q.Set("q", *params.Q)
}
if encoded := q.Encode(); encoded != "" {
path += "?" + encoded
}
}
var resp observability.ListJudgesResponse
if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// GetJudge retrieves a judge by ID.
func (c *Client) GetJudge(ctx context.Context, judgeID string) (*observability.Judge, error) {
var resp observability.Judge
if err := c.doJSON(ctx, "GET", fmt.Sprintf("/v1/observability/judges/%s", judgeID), nil, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// UpdateJudge updates a judge.
func (c *Client) UpdateJudge(ctx context.Context, judgeID string, req *observability.UpdateJudgeRequest) (*observability.Judge, error) {
var resp observability.Judge
if err := c.doJSON(ctx, "PUT", fmt.Sprintf("/v1/observability/judges/%s", judgeID), req, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// DeleteJudge deletes a judge.
func (c *Client) DeleteJudge(ctx context.Context, judgeID string) error {
resp, err := c.do(ctx, "DELETE", fmt.Sprintf("/v1/observability/judges/%s", judgeID), nil)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode >= 400 {
return parseAPIError(resp)
}
return nil
}
// JudgeConversation performs live judging on a conversation.
func (c *Client) JudgeConversation(ctx context.Context, judgeID string, req *observability.JudgeConversationRequest) (json.RawMessage, error) {
var resp json.RawMessage
if err := c.doJSON(ctx, "POST", fmt.Sprintf("/v1/observability/judges/%s/live-judging", judgeID), req, &resp); err != nil {
return nil, err
}
return resp, nil
}

View File

@@ -0,0 +1,123 @@
package mistral
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"somegit.dev/vikingowl/mistral-go-sdk/observability"
)
func judgeJSON() map[string]any {
return map[string]any{
"id": "j1", "created_at": "t", "updated_at": "t",
"owner_id": "o", "workspace_id": "w", "name": "quality",
"description": "d", "model_name": "m", "instructions": "i",
"output": map[string]any{"type": "CLASSIFICATION", "options": []any{}},
}
}
func TestCreateJudge_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" || r.URL.Path != "/v1/observability/judges" {
t.Errorf("unexpected %s %s", r.Method, r.URL.Path)
}
w.WriteHeader(http.StatusCreated)
json.NewEncoder(w).Encode(judgeJSON())
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.CreateJudge(context.Background(), &observability.CreateJudgeRequest{
Name: "quality",
Description: "d",
ModelName: "m",
Instructions: "i",
Tools: []string{},
Output: json.RawMessage(`{"type":"CLASSIFICATION","options":[]}`),
})
if err != nil {
t.Fatal(err)
}
if resp.ID != "j1" {
t.Errorf("got id %q", resp.ID)
}
}
func TestListJudges_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(map[string]any{
"count": 1,
"results": []any{judgeJSON()},
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.ListJudges(context.Background(), nil)
if err != nil {
t.Fatal(err)
}
if resp.Count != 1 {
t.Errorf("got count %d", resp.Count)
}
}
func TestGetJudge_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/observability/judges/j1" {
t.Errorf("got path %s", r.URL.Path)
}
json.NewEncoder(w).Encode(judgeJSON())
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.GetJudge(context.Background(), "j1")
if err != nil {
t.Fatal(err)
}
if resp.Name != "quality" {
t.Errorf("got name %q", resp.Name)
}
}
func TestUpdateJudge_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "PUT" {
t.Errorf("expected PUT, got %s", r.Method)
}
json.NewEncoder(w).Encode(judgeJSON())
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
_, err := client.UpdateJudge(context.Background(), "j1", &observability.UpdateJudgeRequest{
Name: "quality",
Description: "d",
ModelName: "m",
Instructions: "i",
Tools: []string{},
Output: json.RawMessage(`{"type":"CLASSIFICATION","options":[]}`),
})
if err != nil {
t.Fatal(err)
}
}
func TestDeleteJudge_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "DELETE" {
t.Errorf("expected DELETE")
}
w.WriteHeader(http.StatusNoContent)
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
if err := client.DeleteJudge(context.Background(), "j1"); err != nil {
t.Fatal(err)
}
}