Use Voxtral audio transcription endpoint
All checks were successful
CI / test (push) Successful in 14s
Build and Deploy / build-and-deploy (push) Successful in 30s

This commit is contained in:
Grendgi
2026-06-09 15:51:50 +03:00
parent 817eb8ff71
commit e6c2b46cf6
2 changed files with 123 additions and 1 deletions

View File

@@ -141,6 +141,15 @@ type audioLLMChatResponse struct {
} `json:"error,omitempty"` } `json:"error,omitempty"`
} }
type audioTranscriptionResponse struct {
Text string `json:"text"`
Language string `json:"language,omitempty"`
Model string `json:"model,omitempty"`
Error *struct {
Message string `json:"message"`
} `json:"error,omitempty"`
}
func New(baseURL string, timeout time.Duration, ffmpegPath string, leadSilence time.Duration) *Client { func New(baseURL string, timeout time.Duration, ffmpegPath string, leadSilence time.Duration) *Client {
return NewWithOptions(Options{ return NewWithOptions(Options{
Providers: []string{ProviderWhisperX}, Providers: []string{ProviderWhisperX},
@@ -224,7 +233,7 @@ func buildProviders(opts Options, prompt string, maxTokens int) []ProviderConfig
model := firstNonEmpty(opts.VoxtralModel, "mistralai/Voxtral-Small-24B-2507") model := firstNonEmpty(opts.VoxtralModel, "mistralai/Voxtral-Small-24B-2507")
out = append(out, ProviderConfig{ out = append(out, ProviderConfig{
Name: ProviderVoxtral, Name: ProviderVoxtral,
Kind: "audio_llm", Kind: "audio_transcription",
BaseURL: baseURL, BaseURL: baseURL,
APIKey: strings.TrimSpace(opts.VoxtralAPIKey), APIKey: strings.TrimSpace(opts.VoxtralAPIKey),
Model: model, Model: model,
@@ -549,6 +558,9 @@ func (c *Client) transcribeAudio(ctx context.Context, provider ProviderConfig, a
} }
func (c *Client) transcribeAudioLLM(ctx context.Context, provider ProviderConfig, audio []byte, filename string, in Input) (*audioLLMResponse, time.Duration, error) { func (c *Client) transcribeAudioLLM(ctx context.Context, provider ProviderConfig, audio []byte, filename string, in Input) (*audioLLMResponse, time.Duration, error) {
if provider.Kind == "audio_transcription" {
return c.transcribeOpenAIAudio(ctx, provider, audio, filename, in)
}
prompt := provider.Prompt prompt := provider.Prompt
if in.Language != "" { if in.Language != "" {
prompt += "\nЯзык аудио: " + in.Language + "." prompt += "\nЯзык аудио: " + in.Language + "."
@@ -620,6 +632,70 @@ func (c *Client) transcribeAudioLLM(ctx context.Context, provider ProviderConfig
}, duration, nil }, duration, nil
} }
func (c *Client) transcribeOpenAIAudio(ctx context.Context, provider ProviderConfig, audio []byte, filename string, in Input) (*audioLLMResponse, time.Duration, error) {
body := &bytes.Buffer{}
mw := multipart.NewWriter(body)
if err := mw.WriteField("model", provider.Model); err != nil {
return nil, 0, fmt.Errorf("audio transcription model field: %w", err)
}
if err := mw.WriteField("response_format", "json"); err != nil {
return nil, 0, fmt.Errorf("audio transcription response_format field: %w", err)
}
if prompt := strings.TrimSpace(provider.Prompt); prompt != "" {
if err := mw.WriteField("prompt", prompt); err != nil {
return nil, 0, fmt.Errorf("audio transcription prompt field: %w", err)
}
}
if lang := strings.TrimSpace(in.Language); lang != "" {
if err := mw.WriteField("language", lang); err != nil {
return nil, 0, fmt.Errorf("audio transcription language field: %w", err)
}
}
fw, err := mw.CreateFormFile("file", filename)
if err != nil {
return nil, 0, fmt.Errorf("audio transcription create file: %w", err)
}
if _, err := fw.Write(audio); err != nil {
return nil, 0, fmt.Errorf("audio transcription copy audio: %w", err)
}
if err := mw.Close(); err != nil {
return nil, 0, fmt.Errorf("audio transcription close form: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, provider.BaseURL+"/v1/audio/transcriptions", body)
if err != nil {
return nil, 0, fmt.Errorf("audio transcription request: %w", err)
}
req.Header.Set("Content-Type", mw.FormDataContentType())
if provider.APIKey != "" {
req.Header.Set("Authorization", "Bearer "+provider.APIKey)
}
start := time.Now()
resp, err := c.http.Do(req)
duration := time.Since(start)
if err != nil {
return nil, duration, fmt.Errorf("audio transcription do: %w", err)
}
defer resp.Body.Close()
raw, err := io.ReadAll(io.LimitReader(resp.Body, 4<<20))
if err != nil {
return nil, duration, fmt.Errorf("audio transcription read: %w", err)
}
if resp.StatusCode >= 300 {
return nil, duration, fmt.Errorf("audio transcription HTTP %d: %s", resp.StatusCode, strings.TrimSpace(string(raw)))
}
var out audioTranscriptionResponse
if err := json.Unmarshal(raw, &out); err != nil {
return nil, duration, fmt.Errorf("audio transcription decode: %w", err)
}
if out.Error != nil {
return nil, duration, fmt.Errorf("audio transcription error: %s", out.Error.Message)
}
modelName := firstNonEmpty(out.Model, provider.Model)
return &audioLLMResponse{Text: strings.TrimSpace(out.Text), Model: modelName}, duration, nil
}
func audioFormat(filename string) string { func audioFormat(filename string) string {
ext := strings.TrimPrefix(strings.ToLower(filepath.Ext(filename)), ".") ext := strings.TrimPrefix(strings.ToLower(filepath.Ext(filename)), ".")
switch ext { switch ext {

View File

@@ -1,7 +1,10 @@
package transcription package transcription
import ( import (
"encoding/json"
"math" "math"
"net/http"
"net/http/httptest"
"testing" "testing"
"time" "time"
) )
@@ -66,6 +69,49 @@ func TestAudioDataURLUsesVLLMAudioURLFormat(t *testing.T) {
} }
} }
func TestVoxtralUsesAudioTranscriptionsEndpoint(t *testing.T) {
audioSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write([]byte("fake audio"))
}))
defer audioSrv.Close()
var gotPath, gotModel string
providerSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotPath = r.URL.Path
if err := r.ParseMultipartForm(16 << 20); err != nil {
t.Fatalf("ParseMultipartForm: %v", err)
}
gotModel = r.FormValue("model")
if _, _, err := r.FormFile("file"); err != nil {
t.Fatalf("FormFile: %v", err)
}
_ = json.NewEncoder(w).Encode(map[string]string{"text": "Алло, тест."})
}))
defer providerSrv.Close()
client := NewWithOptions(Options{
Providers: []string{"voxtral-small"},
VoxtralBaseURL: providerSrv.URL,
VoxtralModel: "mistralai/Voxtral-Small-24B-2507",
})
if client == nil {
t.Fatal("client is nil")
}
got, err := client.Transcribe(t.Context(), Input{AudioURL: audioSrv.URL, Filename: "call.mp3"})
if err != nil {
t.Fatalf("Transcribe: %v", err)
}
if gotPath != "/v1/audio/transcriptions" {
t.Fatalf("path = %q, want /v1/audio/transcriptions", gotPath)
}
if gotModel != "mistralai/Voxtral-Small-24B-2507" {
t.Fatalf("model = %q", gotModel)
}
if len(got.Segments) != 1 || got.Segments[0].Text != "Алло, тест." {
t.Fatalf("segments = %#v", got.Segments)
}
}
func near(got, want float64) bool { func near(got, want float64) bool {
return math.Abs(got-want) < 0.000001 return math.Abs(got-want) < 0.000001
} }