Files
ai-service/internal/transcription/client.go
2026-06-09 12:34:08 +03:00

645 lines
18 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package transcription
import (
"bytes"
"context"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"mime/multipart"
"net/http"
"os"
"os/exec"
"path/filepath"
"strings"
"time"
)
// Client transcribes audio by downloading it and sending it to the
// configured providers. Build it with New or NewWithOptions.
type Client struct {
// providers lists the configured backends in the order Transcribe tries them.
providers []ProviderConfig
// http is the shared HTTP client used for the audio download and all provider calls.
http *http.Client
// ffmpegPath is the ffmpeg executable invoked to prepend lead silence.
ffmpegPath string
// leadSilence is the amount of silence prepended to the audio; padding is
// skipped when it is not positive.
leadSilence time.Duration
}
// Canonical provider names. Aliases given in Options.Providers are mapped
// onto these values by normalizeProviderOrder.
const (
ProviderWhisperX = "whisperx"
ProviderQwenAudio = "qwen2-audio"
ProviderVoxtral = "voxtral-small"
)
// Options configures NewWithOptions.
type Options struct {
// Providers is the ordered list of provider names or aliases to use.
// Unknown names and duplicates are dropped; an empty list selects WhisperX.
Providers []string
// WhisperXURL is the WhisperX base URL; the provider is skipped when it is empty.
WhisperXURL string
// WhisperXTimeout bounds one WhisperX request; non-positive means 10 minutes.
WhisperXTimeout time.Duration
// FfmpegPath is the ffmpeg executable; defaults to "ffmpeg" when blank.
FfmpegPath string
// LeadSilence is the silence prepended before transcription, clamped to [0, 5s].
LeadSilence time.Duration
// QwenAudioBaseURL is the Qwen audio endpoint base URL; the provider is skipped when empty.
QwenAudioBaseURL string
// QwenAudioAPIKey, when non-empty, is sent as a Bearer token.
QwenAudioAPIKey string
// QwenAudioModel defaults to "Qwen/Qwen2-Audio-7B-Instruct" when blank.
QwenAudioModel string
// QwenAudioTimeout bounds one Qwen request; non-positive means 10 minutes.
QwenAudioTimeout time.Duration
// VoxtralBaseURL is the Voxtral endpoint base URL; the provider is skipped when empty.
VoxtralBaseURL string
// VoxtralAPIKey, when non-empty, is sent as a Bearer token.
VoxtralAPIKey string
// VoxtralModel defaults to "mistralai/Voxtral-Small-24B-2507" when blank.
VoxtralModel string
// VoxtralTimeout bounds one Voxtral request; non-positive means 10 minutes.
VoxtralTimeout time.Duration
// AudioLLMPrompt is the instruction sent to audio LLM providers; a built-in
// transcription prompt is used when blank.
AudioLLMPrompt string
// AudioLLMMaxTokens is the max_tokens value for audio LLM requests;
// non-positive means 4096.
AudioLLMMaxTokens int
}
// ProviderConfig is the resolved configuration of a single transcription backend.
type ProviderConfig struct {
// Name is the canonical provider name reported in results and attempts.
Name string
// Kind selects the request protocol: ProviderWhisperX uses the multipart
// /transcribe call, any other value uses the chat-completions call.
Kind string
// BaseURL is the endpoint base URL without a trailing slash.
BaseURL string
// APIKey, when non-empty, is sent as a Bearer token on audio LLM requests.
APIKey string
// Model is the model name sent to (or reported for) the provider.
Model string
// Timeout bounds a single request to this provider when positive.
Timeout time.Duration
// MaxTokens is the max_tokens value for audio LLM requests.
MaxTokens int
// Prompt is the base instruction text for audio LLM requests.
Prompt string
}
// Input describes one transcription request.
type Input struct {
// AudioURL is fetched with an HTTP GET to obtain the audio; required.
AudioURL string `json:"audio_url"`
// Filename is the name attached to the uploaded audio; only its base name is
// used and it defaults to "audio.mp3".
Filename string `json:"filename,omitempty"`
// Language is an optional language hint forwarded to the providers.
Language string `json:"language,omitempty"`
// Diarize requests speaker separation.
Diarize bool `json:"diarize"`
// MinSpeakers is forwarded to WhisperX only when Diarize is set and the value is positive.
MinSpeakers int `json:"min_speakers,omitempty"`
// MaxSpeakers is forwarded to WhisperX only when Diarize is set and the value is positive.
MaxSpeakers int `json:"max_speakers,omitempty"`
}
// Segment is one piece of transcript text with its time range.
type Segment struct {
// Start is the segment start in seconds (lead silence is subtracted in seconds).
Start float64 `json:"start"`
// End is the segment end in seconds; never less than Start after lead-silence adjustment.
End float64 `json:"end"`
// Text is the transcribed text of the segment.
Text string `json:"text"`
// Speaker is the speaker label, if the provider supplied one.
Speaker string `json:"speaker,omitempty"`
}
// Result is the transcript returned by Transcribe: the output of the first
// provider that succeeded, plus a record of every provider attempt.
type Result struct {
// Provider is the name of the provider that produced this result.
Provider string `json:"provider,omitempty"`
// Model is the model that produced this result.
Model string `json:"model,omitempty"`
// Attempts records every provider that was tried, successful or not.
Attempts []Attempt `json:"attempts,omitempty"`
// Language is the language reported by WhisperX, or Input.Language
// (falling back to "unknown") for audio LLM providers.
Language string `json:"language"`
// Segments is the transcript; audio LLM providers yield a single untimed segment.
Segments []Segment `json:"segments"`
// DiarizeError is passed through from the WhisperX response.
DiarizeError *string `json:"diarize_error,omitempty"`
// AlignError is passed through from the WhisperX response.
AlignError *string `json:"align_error,omitempty"`
// DurationMS is how long the winning provider's HTTP call took, in milliseconds.
DurationMS int64 `json:"duration_ms"`
}
// Attempt records the outcome of calling a single provider.
type Attempt struct {
// Provider is the provider name.
Provider string `json:"provider"`
// Model is the configured model, or the model reported by an audio LLM response.
Model string `json:"model,omitempty"`
// Status is "ok" on success and "failed" otherwise.
Status string `json:"status"`
// Error is the error text of a failed attempt.
Error string `json:"error,omitempty"`
// Text is the full transcript text of a successful attempt.
Text string `json:"text,omitempty"`
// Segments is the transcript of a successful attempt.
Segments []Segment `json:"segments,omitempty"`
// DurationMS is how long the provider's HTTP call took, in milliseconds.
DurationMS int64 `json:"duration_ms,omitempty"`
}
// whisperResponse is the JSON body decoded from the WhisperX /transcribe endpoint.
type whisperResponse struct {
Language string `json:"language"`
// Segments still include any lead silence offset; the caller adjusts them.
Segments []Segment `json:"segments"`
DiarizeError *string `json:"diarize_error,omitempty"`
AlignError *string `json:"align_error,omitempty"`
}
// audioLLMResponse is the transcript extracted from a chat-completions reply.
type audioLLMResponse struct {
// Text is the trimmed content of the first choice.
Text string
// Model is the model reported by the server, or the configured model when absent.
Model string
}
// audioLLMChatRequest is the JSON body POSTed to the /v1/chat/completions endpoint.
type audioLLMChatRequest struct {
Model string `json:"model"`
Messages []audioLLMChatMessage `json:"messages"`
MaxTokens int `json:"max_tokens,omitempty"`
// Temperature has no omitempty, so a zero value is sent explicitly.
Temperature float64 `json:"temperature"`
}
// audioLLMChatMessage is one chat message made of multiple content parts.
type audioLLMChatMessage struct {
Role string `json:"role"`
Content []audioLLMContentPart `json:"content"`
}
// audioLLMContentPart is one part of a chat message: either text or audio.
type audioLLMContentPart struct {
// Type is "text" or "input_audio" in the requests built by this package.
Type string `json:"type"`
Text string `json:"text,omitempty"`
InputAudio *audioLLMAudio `json:"input_audio,omitempty"`
}
// audioLLMAudio carries inline audio for an "input_audio" content part.
type audioLLMAudio struct {
// Data is the audio bytes encoded with standard base64.
Data string `json:"data"`
// Format is the audio format name derived from the file extension.
Format string `json:"format,omitempty"`
}
// audioLLMChatResponse is the subset of a chat-completions reply that is decoded.
type audioLLMChatResponse struct {
// Model is the model name reported by the server; may be empty.
Model string `json:"model,omitempty"`
// Choices holds the completions; only the first choice's content is used.
Choices []struct {
Message struct {
Content string `json:"content"`
} `json:"message"`
} `json:"choices"`
// Error, when present, makes the call fail even if the HTTP status is below 300.
Error *struct {
Message string `json:"message"`
} `json:"error,omitempty"`
}
// New returns a Client that uses WhisperX as its only provider.
//
// baseURL is the WhisperX base URL, timeout bounds one WhisperX request,
// ffmpegPath is the ffmpeg executable and leadSilence is the silence to
// prepend before transcription. It is a convenience wrapper around
// NewWithOptions and returns nil when no provider can be configured
// (for example when baseURL is blank).
func New(baseURL string, timeout time.Duration, ffmpegPath string, leadSilence time.Duration) *Client {
	var opts Options
	opts.Providers = []string{ProviderWhisperX}
	opts.WhisperXURL = baseURL
	opts.WhisperXTimeout = timeout
	opts.FfmpegPath = ffmpegPath
	opts.LeadSilence = leadSilence
	return NewWithOptions(opts)
}
// NewWithOptions builds a Client from opts, applying defaults:
// lead silence is clamped to [0, 5s], the ffmpeg path defaults to "ffmpeg",
// the audio LLM token limit defaults to 4096 and a built-in transcription
// prompt is used when none is given.
//
// It returns nil when no provider ends up configured.
func NewWithOptions(opts Options) *Client {
	const maxLeadSilence = 5 * time.Second

	silence := opts.LeadSilence
	switch {
	case silence < 0:
		silence = 0
	case silence > maxLeadSilence:
		silence = maxLeadSilence
	}

	ffmpeg := strings.TrimSpace(opts.FfmpegPath)
	if ffmpeg == "" {
		ffmpeg = "ffmpeg"
	}

	tokenLimit := opts.AudioLLMMaxTokens
	if tokenLimit <= 0 {
		tokenLimit = 4096
	}

	prompt := strings.TrimSpace(opts.AudioLLMPrompt)
	if prompt == "" {
		prompt = "Transcribe the audio exactly. Return only the transcript text."
	}

	configured := buildProviders(opts, prompt, tokenLimit)
	if len(configured) == 0 {
		return nil
	}

	client := &Client{
		providers:   configured,
		ffmpegPath:  ffmpeg,
		leadSilence: silence,
	}
	// The shared client must not time out before the slowest provider does.
	client.http = &http.Client{Timeout: maxProviderTimeout(configured)}
	return client
}
// buildProviders turns opts into the ordered list of usable provider configs.
//
// Providers are emitted in the order given by normalizeProviderOrder. A
// provider whose base URL is blank is skipped. Base URLs lose surrounding
// whitespace and trailing slashes; API keys and model names are trimmed so
// that padded configuration values are not sent verbatim to the API.
// prompt and maxTokens are applied to every audio LLM provider.
func buildProviders(opts Options, prompt string, maxTokens int) []ProviderConfig {
	const defaultTimeout = 10 * time.Minute

	// cleanURL normalizes a configured base URL so paths can be appended directly.
	cleanURL := func(raw string) string {
		return strings.TrimRight(strings.TrimSpace(raw), "/")
	}
	// audioLLM builds the config shared by all chat-completions style providers.
	audioLLM := func(name, baseURL, apiKey, model, fallbackModel string, timeout time.Duration) ProviderConfig {
		return ProviderConfig{
			Name:      name,
			Kind:      "audio_llm",
			BaseURL:   baseURL,
			APIKey:    strings.TrimSpace(apiKey),
			Model:     firstNonEmpty(strings.TrimSpace(model), fallbackModel),
			Timeout:   defaultDuration(timeout, defaultTimeout),
			MaxTokens: maxTokens,
			Prompt:    prompt,
		}
	}

	order := normalizeProviderOrder(opts.Providers)
	out := make([]ProviderConfig, 0, len(order))
	for _, name := range order {
		switch name {
		case ProviderWhisperX:
			baseURL := cleanURL(opts.WhisperXURL)
			if baseURL == "" {
				continue
			}
			out = append(out, ProviderConfig{
				Name:    ProviderWhisperX,
				Kind:    ProviderWhisperX,
				BaseURL: baseURL,
				Model:   ProviderWhisperX,
				Timeout: defaultDuration(opts.WhisperXTimeout, defaultTimeout),
			})
		case ProviderQwenAudio:
			baseURL := cleanURL(opts.QwenAudioBaseURL)
			if baseURL == "" {
				continue
			}
			out = append(out, audioLLM(
				ProviderQwenAudio, baseURL, opts.QwenAudioAPIKey,
				opts.QwenAudioModel, "Qwen/Qwen2-Audio-7B-Instruct",
				opts.QwenAudioTimeout,
			))
		case ProviderVoxtral:
			baseURL := cleanURL(opts.VoxtralBaseURL)
			if baseURL == "" {
				continue
			}
			out = append(out, audioLLM(
				ProviderVoxtral, baseURL, opts.VoxtralAPIKey,
				opts.VoxtralModel, "mistralai/Voxtral-Small-24B-2507",
				opts.VoxtralTimeout,
			))
		}
	}
	return out
}
// normalizeProviderOrder maps user-supplied provider names onto the canonical
// provider constants, preserving first-seen order.
//
// Names are matched case-insensitively after trimming whitespace. Unknown
// names and repeated providers are dropped. An empty input selects WhisperX;
// a non-empty input with no recognised name yields an empty slice.
func normalizeProviderOrder(in []string) []string {
	if len(in) == 0 {
		return []string{ProviderWhisperX}
	}

	aliases := map[string]string{
		"whisper":                 ProviderWhisperX,
		"whisperx":                ProviderWhisperX,
		"qwen":                    ProviderQwenAudio,
		"qwen-audio":              ProviderQwenAudio,
		"qwen2-audio":             ProviderQwenAudio,
		"qwen2-audio-7b-instruct": ProviderQwenAudio,
		"voxtral":                 ProviderVoxtral,
		"voxtral-small":           ProviderVoxtral,
		"voxtral-small-24b-2507":  ProviderVoxtral,
	}

	ordered := make([]string, 0, len(in))
	added := make(map[string]struct{}, len(in))
	for _, raw := range in {
		canonical, known := aliases[strings.ToLower(strings.TrimSpace(raw))]
		if !known {
			continue
		}
		if _, dup := added[canonical]; dup {
			continue
		}
		added[canonical] = struct{}{}
		ordered = append(ordered, canonical)
	}
	return ordered
}
// maxProviderTimeout returns the largest provider timeout, but never less
// than 10 minutes. It is used as the timeout of the shared HTTP client.
func maxProviderTimeout(providers []ProviderConfig) time.Duration {
	longest := 10 * time.Minute
	for i := range providers {
		if candidate := providers[i].Timeout; candidate > longest {
			longest = candidate
		}
	}
	return longest
}
// defaultDuration returns v when it is positive and fallback otherwise.
func defaultDuration(v, fallback time.Duration) time.Duration {
	if v > 0 {
		return v
	}
	return fallback
}
// Transcribe downloads the audio referenced by in, optionally prepends lead
// silence, and sends it to every configured provider in order.
//
// All providers are called even after one succeeds so that each outcome is
// recorded; the returned Result is that of the first successful provider and
// its Attempts field lists every attempt. An error is returned when the
// client is unconfigured, in.AudioURL is blank, the audio cannot be prepared,
// or every provider fails.
func (c *Client) Transcribe(ctx context.Context, in Input) (*Result, error) {
	if c == nil || len(c.providers) == 0 {
		return nil, fmt.Errorf("transcription providers not configured")
	}
	if strings.TrimSpace(in.AudioURL) == "" {
		return nil, fmt.Errorf("audio_url is required")
	}

	audio, filename, err := c.downloadAudio(ctx, in)
	if err != nil {
		return nil, err
	}
	if c.leadSilence > 0 {
		if audio, filename, err = c.addLeadSilence(ctx, audio, filename); err != nil {
			return nil, err
		}
	}

	var (
		attempts []Attempt
		failures []string
		winner   *Result
	)
	for _, provider := range c.providers {
		result, attempt, runErr := c.transcribeWithProvider(ctx, provider, audio, filename, in)
		attempts = append(attempts, attempt)
		switch {
		case runErr != nil:
			failures = append(failures, provider.Name+": "+runErr.Error())
		case winner == nil:
			// Only the first success becomes the result; later ones are kept as attempts.
			winner = result
		}
	}

	if winner == nil {
		return nil, fmt.Errorf("all transcription providers failed: %s", strings.Join(failures, "; "))
	}
	winner.Attempts = attempts
	return winner, nil
}
// transcribeWithProvider runs a single provider against the prepared audio.
//
// The call is bounded by provider.Timeout when that is positive. The returned
// Attempt is always populated (status "failed" plus the error text on
// failure, "ok" plus transcript on success); the Result is non-nil only when
// err is nil. WhisperX segments are shifted back by the lead silence; audio
// LLM providers yield one untimed segment holding the whole transcript.
func (c *Client) transcribeWithProvider(ctx context.Context, provider ProviderConfig, audio []byte, filename string, in Input) (*Result, Attempt, error) {
	if provider.Timeout > 0 {
		var cancel context.CancelFunc
		ctx, cancel = context.WithTimeout(ctx, provider.Timeout)
		defer cancel()
	}

	attempt := Attempt{Provider: provider.Name, Model: provider.Model, Status: "failed"}

	if provider.Kind == ProviderWhisperX {
		resp, elapsed, err := c.transcribeAudio(ctx, provider, audio, filename, in)
		attempt.DurationMS = elapsed.Milliseconds()
		if err != nil {
			attempt.Error = err.Error()
			return nil, attempt, err
		}

		shifted := adjustLeadSilence(resp.Segments, c.leadSilence)
		attempt.Status = "ok"
		attempt.Segments = shifted
		attempt.Text = segmentsText(shifted)

		result := &Result{
			Provider:     provider.Name,
			Model:        provider.Model,
			Language:     resp.Language,
			Segments:     shifted,
			DiarizeError: resp.DiarizeError,
			AlignError:   resp.AlignError,
			DurationMS:   elapsed.Milliseconds(),
		}
		return result, attempt, nil
	}

	// Every other kind is handled as a chat-completions audio LLM.
	resp, elapsed, err := c.transcribeAudioLLM(ctx, provider, audio, filename, in)
	attempt.DurationMS = elapsed.Milliseconds()
	if err != nil {
		attempt.Error = err.Error()
		return nil, attempt, err
	}

	transcript := strings.TrimSpace(resp.Text)
	// Audio LLMs return no timing, so the transcript is a single zero-length segment.
	whole := []Segment{{Start: 0, End: 0, Text: transcript}}
	attempt.Status = "ok"
	attempt.Model = resp.Model
	attempt.Text = transcript
	attempt.Segments = whole

	result := &Result{
		Provider:   provider.Name,
		Model:      resp.Model,
		Language:   firstNonEmpty(in.Language, "unknown"),
		Segments:   whole,
		DurationMS: elapsed.Milliseconds(),
	}
	return result, attempt, nil
}
// segmentsText joins the trimmed, non-blank segment texts with newlines.
func segmentsText(segments []Segment) string {
	var b strings.Builder
	for i := range segments {
		line := strings.TrimSpace(segments[i].Text)
		if line == "" {
			continue
		}
		// Lines are never empty, so a non-empty builder means a separator is due.
		if b.Len() > 0 {
			b.WriteByte('\n')
		}
		b.WriteString(line)
	}
	return b.String()
}
// downloadAudio fetches in.AudioURL with an HTTP GET and returns the audio
// bytes together with the filename to attach to provider uploads.
//
// The filename is the base name of in.Filename, defaulting to "audio.mp3".
// An error is returned for request failures, statuses of 300 and above, an
// empty body, or a body larger than 512 MiB. Oversized audio is rejected
// rather than truncated, because a silently cut-off file would be
// transcribed as if it were complete.
func (c *Client) downloadAudio(ctx context.Context, in Input) ([]byte, string, error) {
	const maxAudioBytes = 512 << 20

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, in.AudioURL, nil)
	if err != nil {
		return nil, "", fmt.Errorf("audio request: %w", err)
	}
	resp, err := c.http.Do(req)
	if err != nil {
		return nil, "", fmt.Errorf("audio download: %w", err)
	}
	defer resp.Body.Close()
	if resp.StatusCode >= 300 {
		// Only a short prefix of the error body is useful in the message.
		body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
		return nil, "", fmt.Errorf("audio HTTP %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
	}
	// Read one byte past the limit so an oversized body can be told apart
	// from one that is exactly at the limit.
	audio, err := io.ReadAll(io.LimitReader(resp.Body, maxAudioBytes+1))
	if err != nil {
		return nil, "", fmt.Errorf("audio read: %w", err)
	}
	if len(audio) > maxAudioBytes {
		return nil, "", fmt.Errorf("audio exceeds %d MiB limit", maxAudioBytes>>20)
	}
	if len(audio) == 0 {
		return nil, "", fmt.Errorf("audio is empty")
	}
	filename := filepath.Base(strings.TrimSpace(in.Filename))
	if filename == "." || filename == "/" || filename == "" {
		filename = "audio.mp3"
	}
	return audio, filename, nil
}
// addLeadSilence prepends c.leadSilence of silence to audio using ffmpeg and
// returns the re-encoded MP3 bytes with a "<base>-padded.mp3" filename.
//
// When the configured silence is below one millisecond the input is returned
// unchanged and no temporary files are created. Otherwise the audio is
// written to a temporary directory (removed on return), ffmpeg is run with
// the adelay filter, and the padded output is read back. An error is returned
// when a temp file cannot be written, ffmpeg fails, or the output is empty.
func (c *Client) addLeadSilence(ctx context.Context, audio []byte, filename string) ([]byte, string, error) {
	// Decide first whether there is anything to do, so a no-op call touches
	// neither the filesystem nor ffmpeg.
	delayMS := int(c.leadSilence.Milliseconds())
	if delayMS <= 0 {
		return audio, filename, nil
	}

	tmpDir, err := os.MkdirTemp("", "ai-transcribe-*")
	if err != nil {
		return nil, "", fmt.Errorf("prepare audio temp dir: %w", err)
	}
	defer os.RemoveAll(tmpDir)

	// The input keeps a known extension as a format hint for ffmpeg.
	inputPath := filepath.Join(tmpDir, "input"+safeExt(filename))
	outputPath := filepath.Join(tmpDir, "padded.mp3")
	if err := os.WriteFile(inputPath, audio, 0o600); err != nil {
		return nil, "", fmt.Errorf("write audio temp file: %w", err)
	}

	cmd := exec.CommandContext(ctx, c.ffmpegPath,
		"-nostdin", "-y",
		"-i", inputPath,
		"-af", fmt.Sprintf("adelay=%d:all=1", delayMS),
		"-codec:a", "libmp3lame",
		"-qscale:a", "5",
		outputPath,
	)
	out, err := cmd.CombinedOutput()
	if err != nil {
		return nil, "", fmt.Errorf("ffmpeg lead silence: %w (%s)", err, trimOutput(out))
	}

	padded, err := os.ReadFile(outputPath)
	if err != nil {
		return nil, "", fmt.Errorf("read padded audio: %w", err)
	}
	if len(padded) == 0 {
		return nil, "", fmt.Errorf("padded audio is empty")
	}

	base := strings.TrimSuffix(filepath.Base(filename), filepath.Ext(filename))
	if base == "" || base == "." || base == "/" {
		base = "audio"
	}
	return padded, base + "-padded.mp3", nil
}
// safeExt returns the lower-cased extension of filename when it is one of the
// recognised audio extensions, and ".audio" otherwise.
func safeExt(filename string) string {
	ext := strings.ToLower(filepath.Ext(filename))
	for _, known := range [...]string{".mp3", ".wav", ".m4a", ".ogg", ".opus", ".webm"} {
		if ext == known {
			return ext
		}
	}
	return ".audio"
}
// trimOutput returns out as a whitespace-trimmed string limited to at most
// 600 bytes, for embedding command output in error messages.
//
// When the text has to be cut, the cut point is moved back to the nearest
// UTF-8 rune boundary so the result never ends in a partial multi-byte
// character.
func trimOutput(out []byte) string {
	const maxLen = 600

	s := strings.TrimSpace(string(out))
	if len(s) <= maxLen {
		return s
	}
	cut := maxLen
	// s[cut] is the first dropped byte. If it is a UTF-8 continuation byte
	// (10xxxxxx) the rune it belongs to started earlier and would be split,
	// so back up. A rune has at most three continuation bytes.
	for back := 0; back < 3 && cut > 0 && s[cut]&0xC0 == 0x80; back++ {
		cut--
	}
	return s[:cut]
}
// adjustLeadSilence shifts segment times back by the prepended silence so
// they refer to the original audio.
//
// Times are clamped at zero and End is never left below Start. The input
// slice is returned as-is when it is empty or silence is not positive;
// otherwise a new slice is returned and the input is not modified.
func adjustLeadSilence(segments []Segment, silence time.Duration) []Segment {
	if silence <= 0 || len(segments) == 0 {
		return segments
	}

	offset := silence.Seconds()
	shifted := make([]Segment, len(segments))
	for i, seg := range segments {
		seg.Start = clampTime(seg.Start - offset)
		seg.End = clampTime(seg.End - offset)
		if seg.End < seg.Start {
			seg.End = seg.Start
		}
		shifted[i] = seg
	}
	return shifted
}
// clampTime returns 0 for negative v and v itself otherwise.
func clampTime(v float64) float64 {
	negative := v < 0
	if !negative {
		return v
	}
	return 0
}
// transcribeAudio uploads audio to the WhisperX "/transcribe" endpoint as a
// multipart form and decodes the JSON reply.
//
// The form carries the file plus the optional language and the diarization
// settings from in; speaker bounds are sent only when diarization is on and
// the bound is positive. The returned duration covers the HTTP round trip
// (zero when the request could not be built). Statuses of 300 and above are
// returned as errors with up to 4096 bytes of the response body.
func (c *Client) transcribeAudio(ctx context.Context, provider ProviderConfig, audio []byte, filename string, in Input) (*whisperResponse, time.Duration, error) {
	var payload bytes.Buffer
	form := multipart.NewWriter(&payload)

	part, err := form.CreateFormFile("file", filename)
	if err != nil {
		return nil, 0, fmt.Errorf("create form file: %w", err)
	}
	if _, err = part.Write(audio); err != nil {
		return nil, 0, fmt.Errorf("copy audio: %w", err)
	}

	// Field write errors are ignored, as the form is backed by an in-memory buffer.
	setField := func(key, value string) { _ = form.WriteField(key, value) }

	if in.Language != "" {
		setField("language", in.Language)
	}
	if !in.Diarize {
		setField("diarize", "false")
	} else {
		setField("diarize", "true")
		if in.MinSpeakers > 0 {
			setField("min_speakers", fmt.Sprintf("%d", in.MinSpeakers))
		}
		if in.MaxSpeakers > 0 {
			setField("max_speakers", fmt.Sprintf("%d", in.MaxSpeakers))
		}
	}
	if err = form.Close(); err != nil {
		return nil, 0, fmt.Errorf("close form: %w", err)
	}

	endpoint := provider.BaseURL + "/transcribe"
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, &payload)
	if err != nil {
		return nil, 0, fmt.Errorf("whisperx request: %w", err)
	}
	req.Header.Set("Content-Type", form.FormDataContentType())

	began := time.Now()
	resp, err := c.http.Do(req)
	elapsed := time.Since(began)
	if err != nil {
		return nil, elapsed, fmt.Errorf("whisperx do: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode >= 300 {
		detail, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
		return nil, elapsed, fmt.Errorf("whisperx HTTP %d: %s", resp.StatusCode, strings.TrimSpace(string(detail)))
	}

	parsed := new(whisperResponse)
	if err := json.NewDecoder(resp.Body).Decode(parsed); err != nil {
		return nil, elapsed, fmt.Errorf("whisperx decode: %w", err)
	}
	return parsed, elapsed, nil
}
// transcribeAudioLLM sends audio inline (base64) to a chat-completions style
// endpoint at provider.BaseURL + "/v1/chat/completions" and returns the text
// of the first choice.
//
// The provider prompt is extended with a language hint and a diarization
// instruction when in requests them. provider.APIKey, when set, is sent as a
// Bearer token. The returned duration covers the HTTP round trip (zero when
// the request could not be built). Errors are returned for transport
// failures, statuses of 300 and above, undecodable bodies, an "error" object
// in the reply, or an empty choices list. The response body included in an
// HTTP error message is capped at 4096 bytes so a large error page does not
// end up in the attempt record.
func (c *Client) transcribeAudioLLM(ctx context.Context, provider ProviderConfig, audio []byte, filename string, in Input) (*audioLLMResponse, time.Duration, error) {
	const maxErrorBody = 4096

	prompt := provider.Prompt
	if in.Language != "" {
		prompt += "\nЯзык аудио: " + in.Language + "."
	}
	if in.Diarize {
		prompt += "\nЕсли слышны разные говорящие, разделяй реплики с короткими пометками Спикер 1/Спикер 2."
	}
	reqBody := audioLLMChatRequest{
		Model:       provider.Model,
		MaxTokens:   provider.MaxTokens,
		Temperature: 0,
		Messages: []audioLLMChatMessage{
			{
				Role: "user",
				Content: []audioLLMContentPart{
					{Type: "text", Text: prompt},
					{
						Type: "input_audio",
						InputAudio: &audioLLMAudio{
							Data:   base64.StdEncoding.EncodeToString(audio),
							Format: audioFormat(filename),
						},
					},
				},
			},
		},
	}
	body, err := json.Marshal(reqBody)
	if err != nil {
		return nil, 0, fmt.Errorf("audio llm marshal: %w", err)
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, provider.BaseURL+"/v1/chat/completions", bytes.NewReader(body))
	if err != nil {
		return nil, 0, fmt.Errorf("audio llm request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	if provider.APIKey != "" {
		req.Header.Set("Authorization", "Bearer "+provider.APIKey)
	}
	start := time.Now()
	resp, err := c.http.Do(req)
	duration := time.Since(start)
	if err != nil {
		return nil, duration, fmt.Errorf("audio llm do: %w", err)
	}
	defer resp.Body.Close()
	raw, err := io.ReadAll(io.LimitReader(resp.Body, 4<<20))
	if err != nil {
		return nil, duration, fmt.Errorf("audio llm read: %w", err)
	}
	if resp.StatusCode >= 300 {
		// Keep only a short prefix of the body in the error message.
		detail := raw
		if len(detail) > maxErrorBody {
			detail = detail[:maxErrorBody]
		}
		return nil, duration, fmt.Errorf("audio llm HTTP %d: %s", resp.StatusCode, strings.TrimSpace(string(detail)))
	}
	var out audioLLMChatResponse
	if err := json.Unmarshal(raw, &out); err != nil {
		return nil, duration, fmt.Errorf("audio llm decode: %w", err)
	}
	if out.Error != nil {
		return nil, duration, fmt.Errorf("audio llm error: %s", out.Error.Message)
	}
	if len(out.Choices) == 0 {
		return nil, duration, fmt.Errorf("audio llm: empty choices")
	}
	// Prefer the model name the server reports; fall back to the configured one.
	modelName := out.Model
	if modelName == "" {
		modelName = provider.Model
	}
	return &audioLLMResponse{
		Text:  strings.TrimSpace(out.Choices[0].Message.Content),
		Model: modelName,
	}, duration, nil
}
// audioFormat returns the audio format name for filename: its lower-cased
// extension without the dot when that is a recognised format, "mp3" otherwise.
func audioFormat(filename string) string {
	ext := strings.ToLower(filepath.Ext(filename))
	ext = strings.TrimPrefix(ext, ".")
	for _, known := range [...]string{"wav", "mp3", "flac", "m4a", "ogg", "opus", "webm"} {
		if ext == known {
			return ext
		}
	}
	return "mp3"
}
// firstNonEmpty returns the first value that is not blank (empty or
// whitespace only). The value is returned unmodified, without trimming.
// It returns "" when every value is blank or no values are given.
func firstNonEmpty(values ...string) string {
	for i := range values {
		if candidate := values[i]; len(strings.TrimSpace(candidate)) > 0 {
			return candidate
		}
	}
	return ""
}