Make Voxtral the only transcription provider
Some checks failed
CI / test (push) Failing after 8s
Build and Deploy / build-and-deploy (push) Successful in 27s

This commit is contained in:
Grendgi
2026-06-09 16:54:54 +03:00
parent 5c965be8c9
commit 9bd6d726f0
15 changed files with 128 additions and 900 deletions

View File

@@ -3,59 +3,38 @@ package transcription
import (
"bytes"
"context"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"mime/multipart"
"net/http"
"os"
"os/exec"
"path/filepath"
"strings"
"time"
)
type Client struct {
providers []ProviderConfig
http *http.Client
ffmpegPath string
leadSilence time.Duration
provider ProviderConfig
http *http.Client
}
const (
ProviderWhisperX = "whisperx"
ProviderQwenAudio = "qwen2-audio"
ProviderVoxtral = "voxtral-small"
)
const ProviderVoxtral = "voxtral-small"
type Options struct {
Providers []string
WhisperXURL string
WhisperXTimeout time.Duration
FfmpegPath string
LeadSilence time.Duration
QwenAudioBaseURL string
QwenAudioAPIKey string
QwenAudioModel string
QwenAudioTimeout time.Duration
VoxtralBaseURL string
VoxtralAPIKey string
VoxtralModel string
VoxtralTimeout time.Duration
AudioLLMPrompt string
AudioLLMMaxTokens int
VoxtralBaseURL string
VoxtralAPIKey string
VoxtralModel string
VoxtralTimeout time.Duration
AudioLLMPrompt string
}
type ProviderConfig struct {
Name string
Kind string
BaseURL string
APIKey string
Model string
Timeout time.Duration
MaxTokens int
Prompt string
Name string
BaseURL string
APIKey string
Model string
Timeout time.Duration
Prompt string
}
type Input struct {
@@ -95,13 +74,6 @@ type Attempt struct {
DurationMS int64 `json:"duration_ms,omitempty"`
}
type whisperResponse struct {
Language string `json:"language"`
Segments []Segment `json:"segments"`
DiarizeError *string `json:"diarize_error,omitempty"`
AlignError *string `json:"align_error,omitempty"`
}
type audioLLMResponse struct {
Text string
Model string
@@ -109,40 +81,6 @@ type audioLLMResponse struct {
Segments []Segment
}
type audioLLMChatRequest struct {
Model string `json:"model"`
Messages []audioLLMChatMessage `json:"messages"`
MaxTokens int `json:"max_tokens,omitempty"`
Temperature float64 `json:"temperature"`
}
type audioLLMChatMessage struct {
Role string `json:"role"`
Content []audioLLMContentPart `json:"content"`
}
type audioLLMContentPart struct {
Type string `json:"type"`
Text string `json:"text,omitempty"`
AudioURL *audioLLMURLRef `json:"audio_url,omitempty"`
}
type audioLLMURLRef struct {
URL string `json:"url"`
}
type audioLLMChatResponse struct {
Model string `json:"model,omitempty"`
Choices []struct {
Message struct {
Content string `json:"content"`
} `json:"message"`
} `json:"choices"`
Error *struct {
Message string `json:"message"`
} `json:"error,omitempty"`
}
type audioTranscriptionResponse struct {
Text string `json:"text"`
Language string `json:"language,omitempty"`
@@ -161,134 +99,40 @@ type audioTranscriptionSegment struct {
func New(baseURL string, timeout time.Duration, ffmpegPath string, leadSilence time.Duration) *Client {
return NewWithOptions(Options{
Providers: []string{ProviderWhisperX},
WhisperXURL: baseURL,
WhisperXTimeout: timeout,
FfmpegPath: ffmpegPath,
LeadSilence: leadSilence,
VoxtralBaseURL: baseURL,
VoxtralTimeout: timeout,
})
}
func NewWithOptions(opts Options) *Client {
leadSilence := opts.LeadSilence
if leadSilence < 0 {
leadSilence = 0
}
if leadSilence > 5*time.Second {
leadSilence = 5 * time.Second
}
ffmpegPath := strings.TrimSpace(opts.FfmpegPath)
if ffmpegPath == "" {
ffmpegPath = "ffmpeg"
}
maxTokens := opts.AudioLLMMaxTokens
if maxTokens <= 0 {
maxTokens = 4096
}
audioLLMPrompt := strings.TrimSpace(opts.AudioLLMPrompt)
if audioLLMPrompt == "" {
audioLLMPrompt = "Transcribe the audio exactly. Return only the transcript text."
}
providers := buildProviders(opts, audioLLMPrompt, maxTokens)
if len(providers) == 0 {
provider := buildVoxtralProvider(opts, audioLLMPrompt)
if provider.BaseURL == "" {
return nil
}
return &Client{
providers: providers,
http: &http.Client{Timeout: maxProviderTimeout(providers)},
ffmpegPath: ffmpegPath,
leadSilence: leadSilence,
provider: provider,
http: &http.Client{Timeout: provider.Timeout},
}
}
func buildProviders(opts Options, prompt string, maxTokens int) []ProviderConfig {
order := normalizeProviderOrder(opts.Providers)
out := make([]ProviderConfig, 0, len(order))
for _, name := range order {
switch name {
case ProviderWhisperX:
baseURL := strings.TrimRight(strings.TrimSpace(opts.WhisperXURL), "/")
if baseURL == "" {
continue
}
out = append(out, ProviderConfig{
Name: ProviderWhisperX,
Kind: ProviderWhisperX,
BaseURL: baseURL,
Model: ProviderWhisperX,
Timeout: defaultDuration(opts.WhisperXTimeout, 10*time.Minute),
})
case ProviderQwenAudio:
baseURL := strings.TrimRight(strings.TrimSpace(opts.QwenAudioBaseURL), "/")
if baseURL == "" {
continue
}
model := firstNonEmpty(opts.QwenAudioModel, "Qwen/Qwen2-Audio-7B-Instruct")
out = append(out, ProviderConfig{
Name: ProviderQwenAudio,
Kind: "audio_llm",
BaseURL: baseURL,
APIKey: strings.TrimSpace(opts.QwenAudioAPIKey),
Model: model,
Timeout: defaultDuration(opts.QwenAudioTimeout, 10*time.Minute),
MaxTokens: maxTokens,
Prompt: prompt,
})
case ProviderVoxtral:
baseURL := strings.TrimRight(strings.TrimSpace(opts.VoxtralBaseURL), "/")
if baseURL == "" {
continue
}
model := firstNonEmpty(opts.VoxtralModel, "mistralai/Voxtral-Small-24B-2507")
out = append(out, ProviderConfig{
Name: ProviderVoxtral,
Kind: "audio_transcription",
BaseURL: baseURL,
APIKey: strings.TrimSpace(opts.VoxtralAPIKey),
Model: model,
Timeout: defaultDuration(opts.VoxtralTimeout, 10*time.Minute),
MaxTokens: maxTokens,
Prompt: prompt,
})
}
func buildVoxtralProvider(opts Options, prompt string) ProviderConfig {
baseURL := strings.TrimRight(strings.TrimSpace(opts.VoxtralBaseURL), "/")
if baseURL == "" {
return ProviderConfig{}
}
return out
}
func normalizeProviderOrder(in []string) []string {
if len(in) == 0 {
return []string{ProviderWhisperX}
model := firstNonEmpty(opts.VoxtralModel, "mistralai/Voxtral-Small-24B-2507")
return ProviderConfig{
Name: ProviderVoxtral,
BaseURL: baseURL,
APIKey: strings.TrimSpace(opts.VoxtralAPIKey),
Model: model,
Timeout: defaultDuration(opts.VoxtralTimeout, 10*time.Minute),
Prompt: prompt,
}
out := make([]string, 0, len(in))
seen := map[string]bool{}
for _, raw := range in {
name := strings.ToLower(strings.TrimSpace(raw))
switch name {
case "whisper", "whisperx":
name = ProviderWhisperX
case "qwen", "qwen-audio", "qwen2-audio", "qwen2-audio-7b-instruct":
name = ProviderQwenAudio
case "voxtral", "voxtral-small", "voxtral-small-24b-2507":
name = ProviderVoxtral
default:
continue
}
if !seen[name] {
out = append(out, name)
seen[name] = true
}
}
return out
}
func maxProviderTimeout(providers []ProviderConfig) time.Duration {
maxTimeout := 10 * time.Minute
for _, provider := range providers {
if provider.Timeout > maxTimeout {
maxTimeout = provider.Timeout
}
}
return maxTimeout
}
func defaultDuration(v, fallback time.Duration) time.Duration {
@@ -299,8 +143,8 @@ func defaultDuration(v, fallback time.Duration) time.Duration {
}
func (c *Client) Transcribe(ctx context.Context, in Input) (*Result, error) {
if c == nil || len(c.providers) == 0 {
return nil, fmt.Errorf("transcription providers not configured")
if c == nil || c.provider.BaseURL == "" {
return nil, fmt.Errorf("voxtral transcription provider not configured")
}
if strings.TrimSpace(in.AudioURL) == "" {
return nil, fmt.Errorf("audio_url is required")
@@ -309,31 +153,12 @@ func (c *Client) Transcribe(ctx context.Context, in Input) (*Result, error) {
if err != nil {
return nil, err
}
if c.leadSilence > 0 {
audio, filename, err = c.addLeadSilence(ctx, audio, filename)
if err != nil {
return nil, err
}
result, attempt, err := c.transcribeWithProvider(ctx, c.provider, audio, filename, in)
if err != nil {
return nil, err
}
var attempts []Attempt
var winner *Result
var errors []string
for _, provider := range c.providers {
result, attempt, err := c.transcribeWithProvider(ctx, provider, audio, filename, in)
attempts = append(attempts, attempt)
if err != nil {
errors = append(errors, provider.Name+": "+err.Error())
continue
}
if winner == nil {
winner = result
}
}
if winner == nil {
return nil, fmt.Errorf("all transcription providers failed: %s", strings.Join(errors, "; "))
}
winner.Attempts = attempts
return winner, nil
result.Attempts = []Attempt{attempt}
return result, nil
}
func (c *Client) transcribeWithProvider(ctx context.Context, provider ProviderConfig, audio []byte, filename string, in Input) (*Result, Attempt, error) {
@@ -348,58 +173,25 @@ func (c *Client) transcribeWithProvider(ctx context.Context, provider ProviderCo
Model: provider.Model,
Status: "failed",
}
switch provider.Kind {
case ProviderWhisperX:
resp, duration, err := c.transcribeAudio(providerCtx, provider, audio, filename, in)
attempt.DurationMS = duration.Milliseconds()
if err != nil {
attempt.Error = err.Error()
return nil, attempt, err
}
segments := adjustLeadSilence(resp.Segments, c.leadSilence)
attempt.Status = "ok"
attempt.Segments = segments
attempt.Text = segmentsText(segments)
return &Result{
Provider: provider.Name,
Model: provider.Model,
Language: resp.Language,
Segments: segments,
DiarizeError: resp.DiarizeError,
AlignError: resp.AlignError,
DurationMS: duration.Milliseconds(),
}, attempt, nil
default:
resp, duration, err := c.transcribeAudioLLM(providerCtx, provider, audio, filename, in)
attempt.DurationMS = duration.Milliseconds()
if err != nil {
attempt.Error = err.Error()
return nil, attempt, err
}
text := strings.TrimSpace(resp.Text)
segments := normalizeAudioLLMSegments(resp.Segments, text, in.Diarize)
attempt.Status = "ok"
attempt.Model = resp.Model
attempt.Text = text
attempt.Segments = segments
return &Result{
Provider: provider.Name,
Model: resp.Model,
Language: firstNonEmpty(resp.Language, in.Language, "unknown"),
Segments: segments,
DurationMS: duration.Milliseconds(),
}, attempt, nil
resp, duration, err := c.transcribeOpenAIAudio(providerCtx, provider, audio, filename, in)
attempt.DurationMS = duration.Milliseconds()
if err != nil {
attempt.Error = err.Error()
return nil, attempt, err
}
}
func segmentsText(segments []Segment) string {
parts := make([]string, 0, len(segments))
for _, segment := range segments {
if text := strings.TrimSpace(segment.Text); text != "" {
parts = append(parts, text)
}
}
return strings.Join(parts, "\n")
text := strings.TrimSpace(resp.Text)
segments := normalizeAudioLLMSegments(resp.Segments, text, in.Diarize)
attempt.Status = "ok"
attempt.Model = resp.Model
attempt.Text = text
attempt.Segments = segments
return &Result{
Provider: provider.Name,
Model: resp.Model,
Language: firstNonEmpty(resp.Language, in.Language, "unknown"),
Segments: segments,
DurationMS: duration.Milliseconds(),
}, attempt, nil
}
func (c *Client) downloadAudio(ctx context.Context, in Input) ([]byte, string, error) {
@@ -430,83 +222,6 @@ func (c *Client) downloadAudio(ctx context.Context, in Input) ([]byte, string, e
return audio, filename, nil
}
func (c *Client) addLeadSilence(ctx context.Context, audio []byte, filename string) ([]byte, string, error) {
tmpDir, err := os.MkdirTemp("", "ai-transcribe-*")
if err != nil {
return nil, "", fmt.Errorf("prepare audio temp dir: %w", err)
}
defer os.RemoveAll(tmpDir)
inputPath := filepath.Join(tmpDir, "input"+safeExt(filename))
outputPath := filepath.Join(tmpDir, "padded.mp3")
if err := os.WriteFile(inputPath, audio, 0o600); err != nil {
return nil, "", fmt.Errorf("write audio temp file: %w", err)
}
delayMS := int(c.leadSilence.Milliseconds())
if delayMS <= 0 {
return audio, filename, nil
}
cmd := exec.CommandContext(ctx, c.ffmpegPath,
"-nostdin", "-y",
"-i", inputPath,
"-af", fmt.Sprintf("adelay=%d:all=1", delayMS),
"-codec:a", "libmp3lame",
"-qscale:a", "5",
outputPath,
)
out, err := cmd.CombinedOutput()
if err != nil {
return nil, "", fmt.Errorf("ffmpeg lead silence: %w (%s)", err, trimOutput(out))
}
padded, err := os.ReadFile(outputPath)
if err != nil {
return nil, "", fmt.Errorf("read padded audio: %w", err)
}
if len(padded) == 0 {
return nil, "", fmt.Errorf("padded audio is empty")
}
base := strings.TrimSuffix(filepath.Base(filename), filepath.Ext(filename))
if base == "" || base == "." || base == "/" {
base = "audio"
}
return padded, base + "-padded.mp3", nil
}
func safeExt(filename string) string {
ext := strings.ToLower(filepath.Ext(filename))
switch ext {
case ".mp3", ".wav", ".m4a", ".ogg", ".opus", ".webm":
return ext
default:
return ".audio"
}
}
func trimOutput(out []byte) string {
s := strings.TrimSpace(string(out))
if len(s) > 600 {
return s[:600]
}
return s
}
func adjustLeadSilence(segments []Segment, silence time.Duration) []Segment {
if len(segments) == 0 || silence <= 0 {
return segments
}
shift := silence.Seconds()
out := make([]Segment, 0, len(segments))
for _, segment := range segments {
segment.Start = clampTime(segment.Start - shift)
segment.End = clampTime(segment.End - shift)
if segment.End < segment.Start {
segment.End = segment.Start
}
out = append(out, segment)
}
return out
}
func clampTime(v float64) float64 {
if v < 0 {
return 0
@@ -514,133 +229,6 @@ func clampTime(v float64) float64 {
return v
}
func (c *Client) transcribeAudio(ctx context.Context, provider ProviderConfig, audio []byte, filename string, in Input) (*whisperResponse, time.Duration, error) {
body := &bytes.Buffer{}
mw := multipart.NewWriter(body)
fw, err := mw.CreateFormFile("file", filename)
if err != nil {
return nil, 0, fmt.Errorf("create form file: %w", err)
}
if _, err := fw.Write(audio); err != nil {
return nil, 0, fmt.Errorf("copy audio: %w", err)
}
if in.Language != "" {
_ = mw.WriteField("language", in.Language)
}
if in.Diarize {
_ = mw.WriteField("diarize", "true")
if in.MinSpeakers > 0 {
_ = mw.WriteField("min_speakers", fmt.Sprintf("%d", in.MinSpeakers))
}
if in.MaxSpeakers > 0 {
_ = mw.WriteField("max_speakers", fmt.Sprintf("%d", in.MaxSpeakers))
}
} else {
_ = mw.WriteField("diarize", "false")
}
if err := mw.Close(); err != nil {
return nil, 0, fmt.Errorf("close form: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, provider.BaseURL+"/transcribe", body)
if err != nil {
return nil, 0, fmt.Errorf("whisperx request: %w", err)
}
req.Header.Set("Content-Type", mw.FormDataContentType())
start := time.Now()
resp, err := c.http.Do(req)
duration := time.Since(start)
if err != nil {
return nil, duration, fmt.Errorf("whisperx do: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode >= 300 {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
return nil, duration, fmt.Errorf("whisperx HTTP %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
}
var out whisperResponse
if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
return nil, duration, fmt.Errorf("whisperx decode: %w", err)
}
return &out, duration, nil
}
func (c *Client) transcribeAudioLLM(ctx context.Context, provider ProviderConfig, audio []byte, filename string, in Input) (*audioLLMResponse, time.Duration, error) {
if provider.Kind == "audio_transcription" {
return c.transcribeOpenAIAudio(ctx, provider, audio, filename, in)
}
prompt := provider.Prompt
if in.Language != "" {
prompt += "\nЯзык аудио: " + in.Language + "."
}
if in.Diarize {
prompt += "\nЕсли слышны разные говорящие, разделяй реплики с короткими пометками Спикер 1/Спикер 2."
}
reqBody := audioLLMChatRequest{
Model: provider.Model,
MaxTokens: provider.MaxTokens,
Temperature: 0,
Messages: []audioLLMChatMessage{
{
Role: "user",
Content: []audioLLMContentPart{
{Type: "text", Text: prompt},
{
Type: "audio_url",
AudioURL: &audioLLMURLRef{URL: audioDataURL(audio, filename)},
},
},
},
},
}
body, err := json.Marshal(reqBody)
if err != nil {
return nil, 0, fmt.Errorf("audio llm marshal: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, provider.BaseURL+"/v1/chat/completions", bytes.NewReader(body))
if err != nil {
return nil, 0, fmt.Errorf("audio llm request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
if provider.APIKey != "" {
req.Header.Set("Authorization", "Bearer "+provider.APIKey)
}
start := time.Now()
resp, err := c.http.Do(req)
duration := time.Since(start)
if err != nil {
return nil, duration, fmt.Errorf("audio llm do: %w", err)
}
defer resp.Body.Close()
raw, err := io.ReadAll(io.LimitReader(resp.Body, 4<<20))
if err != nil {
return nil, duration, fmt.Errorf("audio llm read: %w", err)
}
if resp.StatusCode >= 300 {
return nil, duration, fmt.Errorf("audio llm HTTP %d: %s", resp.StatusCode, strings.TrimSpace(string(raw)))
}
var out audioLLMChatResponse
if err := json.Unmarshal(raw, &out); err != nil {
return nil, duration, fmt.Errorf("audio llm decode: %w", err)
}
if out.Error != nil {
return nil, duration, fmt.Errorf("audio llm error: %s", out.Error.Message)
}
if len(out.Choices) == 0 {
return nil, duration, fmt.Errorf("audio llm: empty choices")
}
modelName := out.Model
if modelName == "" {
modelName = provider.Model
}
return &audioLLMResponse{
Text: strings.TrimSpace(out.Choices[0].Message.Content),
Model: modelName,
}, duration, nil
}
func (c *Client) transcribeOpenAIAudio(ctx context.Context, provider ProviderConfig, audio []byte, filename string, in Input) (*audioLLMResponse, time.Duration, error) {
resp, duration, err := c.doOpenAIAudioTranscription(ctx, provider, audio, filename, in, "verbose_json")
if err == nil {
@@ -874,20 +462,6 @@ func mergeShortSegments(parts []string, minWords, maxWords int) []string {
return out
}
func audioFormat(filename string) string {
ext := strings.TrimPrefix(strings.ToLower(filepath.Ext(filename)), ".")
switch ext {
case "wav", "mp3", "flac", "m4a", "ogg", "opus", "webm":
return ext
default:
return "mp3"
}
}
func audioDataURL(audio []byte, filename string) string {
return "data:audio/" + audioFormat(filename) + ";base64," + base64.StdEncoding.EncodeToString(audio)
}
func firstNonEmpty(values ...string) string {
for _, value := range values {
if strings.TrimSpace(value) != "" {