Use Voxtral audio transcription endpoint
This commit is contained in:
@@ -141,6 +141,15 @@ type audioLLMChatResponse struct {
|
||||
} `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// audioTranscriptionResponse is the JSON body decoded from the
// /v1/audio/transcriptions endpoint.
type audioTranscriptionResponse struct {
	// Text is the transcribed text.
	Text string `json:"text"`
	// Language is the optional "language" value reported by the server.
	Language string `json:"language,omitempty"`
	// Model is the optional "model" value reported by the server.
	Model string `json:"model,omitempty"`
	// Error is non-nil when the body carries an "error" object.
	Error *struct {
		// Message is the server-supplied error description.
		Message string `json:"message"`
	} `json:"error,omitempty"`
}
|
||||
|
||||
func New(baseURL string, timeout time.Duration, ffmpegPath string, leadSilence time.Duration) *Client {
|
||||
return NewWithOptions(Options{
|
||||
Providers: []string{ProviderWhisperX},
|
||||
@@ -224,7 +233,7 @@ func buildProviders(opts Options, prompt string, maxTokens int) []ProviderConfig
|
||||
model := firstNonEmpty(opts.VoxtralModel, "mistralai/Voxtral-Small-24B-2507")
|
||||
out = append(out, ProviderConfig{
|
||||
Name: ProviderVoxtral,
|
||||
Kind: "audio_llm",
|
||||
Kind: "audio_transcription",
|
||||
BaseURL: baseURL,
|
||||
APIKey: strings.TrimSpace(opts.VoxtralAPIKey),
|
||||
Model: model,
|
||||
@@ -549,6 +558,9 @@ func (c *Client) transcribeAudio(ctx context.Context, provider ProviderConfig, a
|
||||
}
|
||||
|
||||
func (c *Client) transcribeAudioLLM(ctx context.Context, provider ProviderConfig, audio []byte, filename string, in Input) (*audioLLMResponse, time.Duration, error) {
|
||||
if provider.Kind == "audio_transcription" {
|
||||
return c.transcribeOpenAIAudio(ctx, provider, audio, filename, in)
|
||||
}
|
||||
prompt := provider.Prompt
|
||||
if in.Language != "" {
|
||||
prompt += "\nЯзык аудио: " + in.Language + "."
|
||||
@@ -620,6 +632,70 @@ func (c *Client) transcribeAudioLLM(ctx context.Context, provider ProviderConfig
|
||||
}, duration, nil
|
||||
}
|
||||
|
||||
func (c *Client) transcribeOpenAIAudio(ctx context.Context, provider ProviderConfig, audio []byte, filename string, in Input) (*audioLLMResponse, time.Duration, error) {
|
||||
body := &bytes.Buffer{}
|
||||
mw := multipart.NewWriter(body)
|
||||
if err := mw.WriteField("model", provider.Model); err != nil {
|
||||
return nil, 0, fmt.Errorf("audio transcription model field: %w", err)
|
||||
}
|
||||
if err := mw.WriteField("response_format", "json"); err != nil {
|
||||
return nil, 0, fmt.Errorf("audio transcription response_format field: %w", err)
|
||||
}
|
||||
if prompt := strings.TrimSpace(provider.Prompt); prompt != "" {
|
||||
if err := mw.WriteField("prompt", prompt); err != nil {
|
||||
return nil, 0, fmt.Errorf("audio transcription prompt field: %w", err)
|
||||
}
|
||||
}
|
||||
if lang := strings.TrimSpace(in.Language); lang != "" {
|
||||
if err := mw.WriteField("language", lang); err != nil {
|
||||
return nil, 0, fmt.Errorf("audio transcription language field: %w", err)
|
||||
}
|
||||
}
|
||||
fw, err := mw.CreateFormFile("file", filename)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("audio transcription create file: %w", err)
|
||||
}
|
||||
if _, err := fw.Write(audio); err != nil {
|
||||
return nil, 0, fmt.Errorf("audio transcription copy audio: %w", err)
|
||||
}
|
||||
if err := mw.Close(); err != nil {
|
||||
return nil, 0, fmt.Errorf("audio transcription close form: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, provider.BaseURL+"/v1/audio/transcriptions", body)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("audio transcription request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", mw.FormDataContentType())
|
||||
if provider.APIKey != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+provider.APIKey)
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
resp, err := c.http.Do(req)
|
||||
duration := time.Since(start)
|
||||
if err != nil {
|
||||
return nil, duration, fmt.Errorf("audio transcription do: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
raw, err := io.ReadAll(io.LimitReader(resp.Body, 4<<20))
|
||||
if err != nil {
|
||||
return nil, duration, fmt.Errorf("audio transcription read: %w", err)
|
||||
}
|
||||
if resp.StatusCode >= 300 {
|
||||
return nil, duration, fmt.Errorf("audio transcription HTTP %d: %s", resp.StatusCode, strings.TrimSpace(string(raw)))
|
||||
}
|
||||
var out audioTranscriptionResponse
|
||||
if err := json.Unmarshal(raw, &out); err != nil {
|
||||
return nil, duration, fmt.Errorf("audio transcription decode: %w", err)
|
||||
}
|
||||
if out.Error != nil {
|
||||
return nil, duration, fmt.Errorf("audio transcription error: %s", out.Error.Message)
|
||||
}
|
||||
modelName := firstNonEmpty(out.Model, provider.Model)
|
||||
return &audioLLMResponse{Text: strings.TrimSpace(out.Text), Model: modelName}, duration, nil
|
||||
}
|
||||
|
||||
func audioFormat(filename string) string {
|
||||
ext := strings.TrimPrefix(strings.ToLower(filepath.Ext(filename)), ".")
|
||||
switch ext {
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
package transcription
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"math"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
@@ -66,6 +69,49 @@ func TestAudioDataURLUsesVLLMAudioURLFormat(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestVoxtralUsesAudioTranscriptionsEndpoint(t *testing.T) {
|
||||
audioSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = w.Write([]byte("fake audio"))
|
||||
}))
|
||||
defer audioSrv.Close()
|
||||
|
||||
var gotPath, gotModel string
|
||||
providerSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotPath = r.URL.Path
|
||||
if err := r.ParseMultipartForm(16 << 20); err != nil {
|
||||
t.Fatalf("ParseMultipartForm: %v", err)
|
||||
}
|
||||
gotModel = r.FormValue("model")
|
||||
if _, _, err := r.FormFile("file"); err != nil {
|
||||
t.Fatalf("FormFile: %v", err)
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(map[string]string{"text": "Алло, тест."})
|
||||
}))
|
||||
defer providerSrv.Close()
|
||||
|
||||
client := NewWithOptions(Options{
|
||||
Providers: []string{"voxtral-small"},
|
||||
VoxtralBaseURL: providerSrv.URL,
|
||||
VoxtralModel: "mistralai/Voxtral-Small-24B-2507",
|
||||
})
|
||||
if client == nil {
|
||||
t.Fatal("client is nil")
|
||||
}
|
||||
got, err := client.Transcribe(t.Context(), Input{AudioURL: audioSrv.URL, Filename: "call.mp3"})
|
||||
if err != nil {
|
||||
t.Fatalf("Transcribe: %v", err)
|
||||
}
|
||||
if gotPath != "/v1/audio/transcriptions" {
|
||||
t.Fatalf("path = %q, want /v1/audio/transcriptions", gotPath)
|
||||
}
|
||||
if gotModel != "mistralai/Voxtral-Small-24B-2507" {
|
||||
t.Fatalf("model = %q", gotModel)
|
||||
}
|
||||
if len(got.Segments) != 1 || got.Segments[0].Text != "Алло, тест." {
|
||||
t.Fatalf("segments = %#v", got.Segments)
|
||||
}
|
||||
}
|
||||
|
||||
// near reports whether got and want differ by less than 0.000001 in absolute
// value.
func near(got, want float64) bool {
	const tolerance = 0.000001
	delta := math.Abs(got - want)
	return delta < tolerance
}
|
||||
|
||||
Reference in New Issue
Block a user