Add transcription provider comparison chain
This commit is contained in:
@@ -3,6 +3,7 @@ package transcription
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -16,12 +17,47 @@ import (
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
baseURL string
|
||||
providers []ProviderConfig
|
||||
http *http.Client
|
||||
ffmpegPath string
|
||||
leadSilence time.Duration
|
||||
}
|
||||
|
||||
const (
|
||||
ProviderWhisperX = "whisperx"
|
||||
ProviderQwenAudio = "qwen2-audio"
|
||||
ProviderVoxtral = "voxtral-small"
|
||||
)
|
||||
|
||||
type Options struct {
|
||||
Providers []string
|
||||
WhisperXURL string
|
||||
WhisperXTimeout time.Duration
|
||||
FfmpegPath string
|
||||
LeadSilence time.Duration
|
||||
QwenAudioBaseURL string
|
||||
QwenAudioAPIKey string
|
||||
QwenAudioModel string
|
||||
QwenAudioTimeout time.Duration
|
||||
VoxtralBaseURL string
|
||||
VoxtralAPIKey string
|
||||
VoxtralModel string
|
||||
VoxtralTimeout time.Duration
|
||||
AudioLLMPrompt string
|
||||
AudioLLMMaxTokens int
|
||||
}
|
||||
|
||||
type ProviderConfig struct {
|
||||
Name string
|
||||
Kind string
|
||||
BaseURL string
|
||||
APIKey string
|
||||
Model string
|
||||
Timeout time.Duration
|
||||
MaxTokens int
|
||||
Prompt string
|
||||
}
|
||||
|
||||
type Input struct {
|
||||
AudioURL string `json:"audio_url"`
|
||||
Filename string `json:"filename,omitempty"`
|
||||
@@ -39,6 +75,9 @@ type Segment struct {
|
||||
}
|
||||
|
||||
type Result struct {
|
||||
Provider string `json:"provider,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Attempts []Attempt `json:"attempts,omitempty"`
|
||||
Language string `json:"language"`
|
||||
Segments []Segment `json:"segments"`
|
||||
DiarizeError *string `json:"diarize_error,omitempty"`
|
||||
@@ -46,6 +85,16 @@ type Result struct {
|
||||
DurationMS int64 `json:"duration_ms"`
|
||||
}
|
||||
|
||||
type Attempt struct {
|
||||
Provider string `json:"provider"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Status string `json:"status"`
|
||||
Error string `json:"error,omitempty"`
|
||||
Text string `json:"text,omitempty"`
|
||||
Segments []Segment `json:"segments,omitempty"`
|
||||
DurationMS int64 `json:"duration_ms,omitempty"`
|
||||
}
|
||||
|
||||
type whisperResponse struct {
|
||||
Language string `json:"language"`
|
||||
Segments []Segment `json:"segments"`
|
||||
@@ -53,35 +102,188 @@ type whisperResponse struct {
|
||||
AlignError *string `json:"align_error,omitempty"`
|
||||
}
|
||||
|
||||
type audioLLMResponse struct {
|
||||
Text string
|
||||
Model string
|
||||
}
|
||||
|
||||
type audioLLMChatRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []audioLLMChatMessage `json:"messages"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
Temperature float64 `json:"temperature"`
|
||||
}
|
||||
|
||||
type audioLLMChatMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content []audioLLMContentPart `json:"content"`
|
||||
}
|
||||
|
||||
type audioLLMContentPart struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
InputAudio *audioLLMAudio `json:"input_audio,omitempty"`
|
||||
}
|
||||
|
||||
type audioLLMAudio struct {
|
||||
Data string `json:"data"`
|
||||
Format string `json:"format,omitempty"`
|
||||
}
|
||||
|
||||
type audioLLMChatResponse struct {
|
||||
Model string `json:"model,omitempty"`
|
||||
Choices []struct {
|
||||
Message struct {
|
||||
Content string `json:"content"`
|
||||
} `json:"message"`
|
||||
} `json:"choices"`
|
||||
Error *struct {
|
||||
Message string `json:"message"`
|
||||
} `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func New(baseURL string, timeout time.Duration, ffmpegPath string, leadSilence time.Duration) *Client {
|
||||
baseURL = strings.TrimRight(strings.TrimSpace(baseURL), "/")
|
||||
if baseURL == "" {
|
||||
return nil
|
||||
}
|
||||
if timeout <= 0 {
|
||||
timeout = 10 * time.Minute
|
||||
}
|
||||
return NewWithOptions(Options{
|
||||
Providers: []string{ProviderWhisperX},
|
||||
WhisperXURL: baseURL,
|
||||
WhisperXTimeout: timeout,
|
||||
FfmpegPath: ffmpegPath,
|
||||
LeadSilence: leadSilence,
|
||||
})
|
||||
}
|
||||
|
||||
func NewWithOptions(opts Options) *Client {
|
||||
leadSilence := opts.LeadSilence
|
||||
if leadSilence < 0 {
|
||||
leadSilence = 0
|
||||
}
|
||||
if leadSilence > 5*time.Second {
|
||||
leadSilence = 5 * time.Second
|
||||
}
|
||||
ffmpegPath = strings.TrimSpace(ffmpegPath)
|
||||
ffmpegPath := strings.TrimSpace(opts.FfmpegPath)
|
||||
if ffmpegPath == "" {
|
||||
ffmpegPath = "ffmpeg"
|
||||
}
|
||||
maxTokens := opts.AudioLLMMaxTokens
|
||||
if maxTokens <= 0 {
|
||||
maxTokens = 4096
|
||||
}
|
||||
audioLLMPrompt := strings.TrimSpace(opts.AudioLLMPrompt)
|
||||
if audioLLMPrompt == "" {
|
||||
audioLLMPrompt = "Transcribe the audio exactly. Return only the transcript text."
|
||||
}
|
||||
providers := buildProviders(opts, audioLLMPrompt, maxTokens)
|
||||
if len(providers) == 0 {
|
||||
return nil
|
||||
}
|
||||
return &Client{
|
||||
baseURL: baseURL,
|
||||
http: &http.Client{Timeout: timeout},
|
||||
providers: providers,
|
||||
http: &http.Client{Timeout: maxProviderTimeout(providers)},
|
||||
ffmpegPath: ffmpegPath,
|
||||
leadSilence: leadSilence,
|
||||
}
|
||||
}
|
||||
|
||||
func buildProviders(opts Options, prompt string, maxTokens int) []ProviderConfig {
|
||||
order := normalizeProviderOrder(opts.Providers)
|
||||
out := make([]ProviderConfig, 0, len(order))
|
||||
for _, name := range order {
|
||||
switch name {
|
||||
case ProviderWhisperX:
|
||||
baseURL := strings.TrimRight(strings.TrimSpace(opts.WhisperXURL), "/")
|
||||
if baseURL == "" {
|
||||
continue
|
||||
}
|
||||
out = append(out, ProviderConfig{
|
||||
Name: ProviderWhisperX,
|
||||
Kind: ProviderWhisperX,
|
||||
BaseURL: baseURL,
|
||||
Model: ProviderWhisperX,
|
||||
Timeout: defaultDuration(opts.WhisperXTimeout, 10*time.Minute),
|
||||
})
|
||||
case ProviderQwenAudio:
|
||||
baseURL := strings.TrimRight(strings.TrimSpace(opts.QwenAudioBaseURL), "/")
|
||||
if baseURL == "" {
|
||||
continue
|
||||
}
|
||||
model := firstNonEmpty(opts.QwenAudioModel, "Qwen/Qwen2-Audio-7B-Instruct")
|
||||
out = append(out, ProviderConfig{
|
||||
Name: ProviderQwenAudio,
|
||||
Kind: "audio_llm",
|
||||
BaseURL: baseURL,
|
||||
APIKey: strings.TrimSpace(opts.QwenAudioAPIKey),
|
||||
Model: model,
|
||||
Timeout: defaultDuration(opts.QwenAudioTimeout, 10*time.Minute),
|
||||
MaxTokens: maxTokens,
|
||||
Prompt: prompt,
|
||||
})
|
||||
case ProviderVoxtral:
|
||||
baseURL := strings.TrimRight(strings.TrimSpace(opts.VoxtralBaseURL), "/")
|
||||
if baseURL == "" {
|
||||
continue
|
||||
}
|
||||
model := firstNonEmpty(opts.VoxtralModel, "mistralai/Voxtral-Small-24B-2507")
|
||||
out = append(out, ProviderConfig{
|
||||
Name: ProviderVoxtral,
|
||||
Kind: "audio_llm",
|
||||
BaseURL: baseURL,
|
||||
APIKey: strings.TrimSpace(opts.VoxtralAPIKey),
|
||||
Model: model,
|
||||
Timeout: defaultDuration(opts.VoxtralTimeout, 10*time.Minute),
|
||||
MaxTokens: maxTokens,
|
||||
Prompt: prompt,
|
||||
})
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func normalizeProviderOrder(in []string) []string {
|
||||
if len(in) == 0 {
|
||||
return []string{ProviderWhisperX}
|
||||
}
|
||||
out := make([]string, 0, len(in))
|
||||
seen := map[string]bool{}
|
||||
for _, raw := range in {
|
||||
name := strings.ToLower(strings.TrimSpace(raw))
|
||||
switch name {
|
||||
case "whisper", "whisperx":
|
||||
name = ProviderWhisperX
|
||||
case "qwen", "qwen-audio", "qwen2-audio", "qwen2-audio-7b-instruct":
|
||||
name = ProviderQwenAudio
|
||||
case "voxtral", "voxtral-small", "voxtral-small-24b-2507":
|
||||
name = ProviderVoxtral
|
||||
default:
|
||||
continue
|
||||
}
|
||||
if !seen[name] {
|
||||
out = append(out, name)
|
||||
seen[name] = true
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func maxProviderTimeout(providers []ProviderConfig) time.Duration {
|
||||
maxTimeout := 10 * time.Minute
|
||||
for _, provider := range providers {
|
||||
if provider.Timeout > maxTimeout {
|
||||
maxTimeout = provider.Timeout
|
||||
}
|
||||
}
|
||||
return maxTimeout
|
||||
}
|
||||
|
||||
func defaultDuration(v, fallback time.Duration) time.Duration {
|
||||
if v <= 0 {
|
||||
return fallback
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
func (c *Client) Transcribe(ctx context.Context, in Input) (*Result, error) {
|
||||
if c == nil || c.baseURL == "" {
|
||||
return nil, fmt.Errorf("whisperx not configured")
|
||||
if c == nil || len(c.providers) == 0 {
|
||||
return nil, fmt.Errorf("transcription providers not configured")
|
||||
}
|
||||
if strings.TrimSpace(in.AudioURL) == "" {
|
||||
return nil, fmt.Errorf("audio_url is required")
|
||||
@@ -96,18 +298,91 @@ func (c *Client) Transcribe(ctx context.Context, in Input) (*Result, error) {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
resp, duration, err := c.transcribeAudio(ctx, audio, filename, in)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
var attempts []Attempt
|
||||
var winner *Result
|
||||
var errors []string
|
||||
for _, provider := range c.providers {
|
||||
result, attempt, err := c.transcribeWithProvider(ctx, provider, audio, filename, in)
|
||||
attempts = append(attempts, attempt)
|
||||
if err != nil {
|
||||
errors = append(errors, provider.Name+": "+err.Error())
|
||||
continue
|
||||
}
|
||||
if winner == nil {
|
||||
winner = result
|
||||
}
|
||||
}
|
||||
segments := adjustLeadSilence(resp.Segments, c.leadSilence)
|
||||
return &Result{
|
||||
Language: resp.Language,
|
||||
Segments: segments,
|
||||
DiarizeError: resp.DiarizeError,
|
||||
AlignError: resp.AlignError,
|
||||
DurationMS: duration.Milliseconds(),
|
||||
}, nil
|
||||
if winner == nil {
|
||||
return nil, fmt.Errorf("all transcription providers failed: %s", strings.Join(errors, "; "))
|
||||
}
|
||||
winner.Attempts = attempts
|
||||
return winner, nil
|
||||
}
|
||||
|
||||
func (c *Client) transcribeWithProvider(ctx context.Context, provider ProviderConfig, audio []byte, filename string, in Input) (*Result, Attempt, error) {
|
||||
providerCtx := ctx
|
||||
cancel := func() {}
|
||||
if provider.Timeout > 0 {
|
||||
providerCtx, cancel = context.WithTimeout(ctx, provider.Timeout)
|
||||
}
|
||||
defer cancel()
|
||||
attempt := Attempt{
|
||||
Provider: provider.Name,
|
||||
Model: provider.Model,
|
||||
Status: "failed",
|
||||
}
|
||||
switch provider.Kind {
|
||||
case ProviderWhisperX:
|
||||
resp, duration, err := c.transcribeAudio(providerCtx, provider, audio, filename, in)
|
||||
attempt.DurationMS = duration.Milliseconds()
|
||||
if err != nil {
|
||||
attempt.Error = err.Error()
|
||||
return nil, attempt, err
|
||||
}
|
||||
segments := adjustLeadSilence(resp.Segments, c.leadSilence)
|
||||
attempt.Status = "ok"
|
||||
attempt.Segments = segments
|
||||
attempt.Text = segmentsText(segments)
|
||||
return &Result{
|
||||
Provider: provider.Name,
|
||||
Model: provider.Model,
|
||||
Language: resp.Language,
|
||||
Segments: segments,
|
||||
DiarizeError: resp.DiarizeError,
|
||||
AlignError: resp.AlignError,
|
||||
DurationMS: duration.Milliseconds(),
|
||||
}, attempt, nil
|
||||
default:
|
||||
resp, duration, err := c.transcribeAudioLLM(providerCtx, provider, audio, filename, in)
|
||||
attempt.DurationMS = duration.Milliseconds()
|
||||
if err != nil {
|
||||
attempt.Error = err.Error()
|
||||
return nil, attempt, err
|
||||
}
|
||||
text := strings.TrimSpace(resp.Text)
|
||||
segments := []Segment{{Start: 0, End: 0, Text: text}}
|
||||
attempt.Status = "ok"
|
||||
attempt.Model = resp.Model
|
||||
attempt.Text = text
|
||||
attempt.Segments = segments
|
||||
return &Result{
|
||||
Provider: provider.Name,
|
||||
Model: resp.Model,
|
||||
Language: firstNonEmpty(in.Language, "unknown"),
|
||||
Segments: segments,
|
||||
DurationMS: duration.Milliseconds(),
|
||||
}, attempt, nil
|
||||
}
|
||||
}
|
||||
|
||||
func segmentsText(segments []Segment) string {
|
||||
parts := make([]string, 0, len(segments))
|
||||
for _, segment := range segments {
|
||||
if text := strings.TrimSpace(segment.Text); text != "" {
|
||||
parts = append(parts, text)
|
||||
}
|
||||
}
|
||||
return strings.Join(parts, "\n")
|
||||
}
|
||||
|
||||
func (c *Client) downloadAudio(ctx context.Context, in Input) ([]byte, string, error) {
|
||||
@@ -222,7 +497,7 @@ func clampTime(v float64) float64 {
|
||||
return v
|
||||
}
|
||||
|
||||
func (c *Client) transcribeAudio(ctx context.Context, audio []byte, filename string, in Input) (*whisperResponse, time.Duration, error) {
|
||||
func (c *Client) transcribeAudio(ctx context.Context, provider ProviderConfig, audio []byte, filename string, in Input) (*whisperResponse, time.Duration, error) {
|
||||
body := &bytes.Buffer{}
|
||||
mw := multipart.NewWriter(body)
|
||||
fw, err := mw.CreateFormFile("file", filename)
|
||||
@@ -250,7 +525,7 @@ func (c *Client) transcribeAudio(ctx context.Context, audio []byte, filename str
|
||||
return nil, 0, fmt.Errorf("close form: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/transcribe", body)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, provider.BaseURL+"/transcribe", body)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("whisperx request: %w", err)
|
||||
}
|
||||
@@ -273,3 +548,97 @@ func (c *Client) transcribeAudio(ctx context.Context, audio []byte, filename str
|
||||
}
|
||||
return &out, duration, nil
|
||||
}
|
||||
|
||||
func (c *Client) transcribeAudioLLM(ctx context.Context, provider ProviderConfig, audio []byte, filename string, in Input) (*audioLLMResponse, time.Duration, error) {
|
||||
prompt := provider.Prompt
|
||||
if in.Language != "" {
|
||||
prompt += "\nЯзык аудио: " + in.Language + "."
|
||||
}
|
||||
if in.Diarize {
|
||||
prompt += "\nЕсли слышны разные говорящие, разделяй реплики с короткими пометками Спикер 1/Спикер 2."
|
||||
}
|
||||
reqBody := audioLLMChatRequest{
|
||||
Model: provider.Model,
|
||||
MaxTokens: provider.MaxTokens,
|
||||
Temperature: 0,
|
||||
Messages: []audioLLMChatMessage{
|
||||
{
|
||||
Role: "user",
|
||||
Content: []audioLLMContentPart{
|
||||
{Type: "text", Text: prompt},
|
||||
{
|
||||
Type: "input_audio",
|
||||
InputAudio: &audioLLMAudio{
|
||||
Data: base64.StdEncoding.EncodeToString(audio),
|
||||
Format: audioFormat(filename),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
body, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("audio llm marshal: %w", err)
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, provider.BaseURL+"/v1/chat/completions", bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("audio llm request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if provider.APIKey != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+provider.APIKey)
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
resp, err := c.http.Do(req)
|
||||
duration := time.Since(start)
|
||||
if err != nil {
|
||||
return nil, duration, fmt.Errorf("audio llm do: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
raw, err := io.ReadAll(io.LimitReader(resp.Body, 4<<20))
|
||||
if err != nil {
|
||||
return nil, duration, fmt.Errorf("audio llm read: %w", err)
|
||||
}
|
||||
if resp.StatusCode >= 300 {
|
||||
return nil, duration, fmt.Errorf("audio llm HTTP %d: %s", resp.StatusCode, strings.TrimSpace(string(raw)))
|
||||
}
|
||||
var out audioLLMChatResponse
|
||||
if err := json.Unmarshal(raw, &out); err != nil {
|
||||
return nil, duration, fmt.Errorf("audio llm decode: %w", err)
|
||||
}
|
||||
if out.Error != nil {
|
||||
return nil, duration, fmt.Errorf("audio llm error: %s", out.Error.Message)
|
||||
}
|
||||
if len(out.Choices) == 0 {
|
||||
return nil, duration, fmt.Errorf("audio llm: empty choices")
|
||||
}
|
||||
modelName := out.Model
|
||||
if modelName == "" {
|
||||
modelName = provider.Model
|
||||
}
|
||||
return &audioLLMResponse{
|
||||
Text: strings.TrimSpace(out.Choices[0].Message.Content),
|
||||
Model: modelName,
|
||||
}, duration, nil
|
||||
}
|
||||
|
||||
func audioFormat(filename string) string {
|
||||
ext := strings.TrimPrefix(strings.ToLower(filepath.Ext(filename)), ".")
|
||||
switch ext {
|
||||
case "wav", "mp3", "flac", "m4a", "ogg", "opus", "webm":
|
||||
return ext
|
||||
default:
|
||||
return "mp3"
|
||||
}
|
||||
}
|
||||
|
||||
func firstNonEmpty(values ...string) string {
|
||||
for _, value := range values {
|
||||
if strings.TrimSpace(value) != "" {
|
||||
return value
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user