899 lines
25 KiB
Go
899 lines
25 KiB
Go
package transcription
|
||
|
||
import (
|
||
"bytes"
|
||
"context"
|
||
"encoding/base64"
|
||
"encoding/json"
|
||
"fmt"
|
||
"io"
|
||
"mime/multipart"
|
||
"net/http"
|
||
"os"
|
||
"os/exec"
|
||
"path/filepath"
|
||
"strings"
|
||
"time"
|
||
)
|
||
|
||
type Client struct {
|
||
providers []ProviderConfig
|
||
http *http.Client
|
||
ffmpegPath string
|
||
leadSilence time.Duration
|
||
}
|
||
|
||
// Canonical provider names. normalizeProviderOrder maps user-supplied aliases
// onto these values, and they are reported as Provider in Result and Attempt.
// ProviderWhisperX also doubles as the ProviderConfig.Kind of the WhisperX
// provider.
const (
	ProviderWhisperX  = "whisperx"
	ProviderQwenAudio = "qwen2-audio"
	ProviderVoxtral   = "voxtral-small"
)
|
||
|
||
// Options configures NewWithOptions. Zero values select the defaults noted
// on each field. A provider whose base URL is blank is left out entirely.
type Options struct {
	// Providers lists provider names or aliases in the order they are run.
	// Empty means WhisperX only; unrecognized names are ignored.
	Providers []string

	// WhisperXURL is the WhisperX base URL ("/transcribe" is appended).
	WhisperXURL string
	// WhisperXTimeout bounds each WhisperX call; <= 0 means 10 minutes.
	WhisperXTimeout time.Duration

	// FfmpegPath is the ffmpeg binary used for padding; blank means "ffmpeg".
	FfmpegPath string
	// LeadSilence is prepended to the audio before upload; clamped to [0, 5s].
	LeadSilence time.Duration

	// QwenAudioBaseURL is the base of an OpenAI-compatible chat-completions API.
	QwenAudioBaseURL string
	// QwenAudioAPIKey is sent as a Bearer token when non-empty.
	QwenAudioAPIKey string
	// QwenAudioModel defaults to "Qwen/Qwen2-Audio-7B-Instruct".
	QwenAudioModel string
	// QwenAudioTimeout bounds each call; <= 0 means 10 minutes.
	QwenAudioTimeout time.Duration

	// VoxtralBaseURL is the base of an OpenAI-compatible audio transcription API.
	VoxtralBaseURL string
	// VoxtralAPIKey is sent as a Bearer token when non-empty.
	VoxtralAPIKey string
	// VoxtralModel defaults to "mistralai/Voxtral-Small-24B-2507".
	VoxtralModel string
	// VoxtralTimeout bounds each call; <= 0 means 10 minutes.
	VoxtralTimeout time.Duration

	// AudioLLMPrompt is the instruction sent to Qwen2-Audio and the "prompt"
	// form field sent to Voxtral; blank selects a default transcription prompt.
	AudioLLMPrompt string
	// AudioLLMMaxTokens is the chat-completions max_tokens; <= 0 means 4096.
	AudioLLMMaxTokens int
}
|
||
|
||
// ProviderConfig is the resolved configuration of one provider, as built by
// buildProviders. Field order is kept stable for unkeyed composite literals.
type ProviderConfig struct {
	// Name is the canonical provider name (one of the Provider* constants).
	Name string
	// Kind selects the protocol: ProviderWhisperX, "audio_llm" (chat
	// completions) or "audio_transcription" (/v1/audio/transcriptions).
	Kind string
	// BaseURL has surrounding whitespace and trailing slashes removed.
	BaseURL string
	// APIKey is sent as a Bearer token when non-empty.
	APIKey string
	// Model is the model name sent to, and reported for, the provider.
	Model string
	// Timeout, when positive, bounds one call to this provider.
	Timeout time.Duration
	// MaxTokens is the chat-completions max_tokens value.
	MaxTokens int
	// Prompt is the instruction/prompt text for audio LLM providers.
	Prompt string
}
|
||
|
||
// Input describes one transcription request.
type Input struct {
	// AudioURL is required; the audio is fetched from it with a GET request.
	AudioURL string `json:"audio_url"`
	// Filename is an upload-name hint; only its base name is used and it
	// defaults to "audio.mp3".
	Filename string `json:"filename,omitempty"`
	// Language is an optional language hint forwarded to providers.
	Language string `json:"language,omitempty"`
	// Diarize requests speaker labels.
	Diarize bool `json:"diarize"`
	// MinSpeakers and MaxSpeakers are forwarded to WhisperX only when Diarize
	// is set and the value is positive.
	MinSpeakers int `json:"min_speakers,omitempty"`
	MaxSpeakers int `json:"max_speakers,omitempty"`
}
|
||
|
||
// Segment is one timed span of transcript text. Start and End are in
// seconds; Speaker is an optional label such as "SPEAKER_00".
type Segment struct {
	Start   float64 `json:"start"`
	End     float64 `json:"end"`
	Text    string  `json:"text"`
	Speaker string  `json:"speaker,omitempty"`
}
|
||
|
||
type Result struct {
|
||
Provider string `json:"provider,omitempty"`
|
||
Model string `json:"model,omitempty"`
|
||
Attempts []Attempt `json:"attempts,omitempty"`
|
||
Language string `json:"language"`
|
||
Segments []Segment `json:"segments"`
|
||
DiarizeError *string `json:"diarize_error,omitempty"`
|
||
AlignError *string `json:"align_error,omitempty"`
|
||
DurationMS int64 `json:"duration_ms"`
|
||
}
|
||
|
||
type Attempt struct {
|
||
Provider string `json:"provider"`
|
||
Model string `json:"model,omitempty"`
|
||
Status string `json:"status"`
|
||
Error string `json:"error,omitempty"`
|
||
Text string `json:"text,omitempty"`
|
||
Segments []Segment `json:"segments,omitempty"`
|
||
DurationMS int64 `json:"duration_ms,omitempty"`
|
||
}
|
||
|
||
type whisperResponse struct {
|
||
Language string `json:"language"`
|
||
Segments []Segment `json:"segments"`
|
||
DiarizeError *string `json:"diarize_error,omitempty"`
|
||
AlignError *string `json:"align_error,omitempty"`
|
||
}
|
||
|
||
type audioLLMResponse struct {
|
||
Text string
|
||
Model string
|
||
Language string
|
||
Segments []Segment
|
||
}
|
||
|
||
type audioLLMChatRequest struct {
|
||
Model string `json:"model"`
|
||
Messages []audioLLMChatMessage `json:"messages"`
|
||
MaxTokens int `json:"max_tokens,omitempty"`
|
||
Temperature float64 `json:"temperature"`
|
||
}
|
||
|
||
type audioLLMChatMessage struct {
|
||
Role string `json:"role"`
|
||
Content []audioLLMContentPart `json:"content"`
|
||
}
|
||
|
||
type audioLLMContentPart struct {
|
||
Type string `json:"type"`
|
||
Text string `json:"text,omitempty"`
|
||
AudioURL *audioLLMURLRef `json:"audio_url,omitempty"`
|
||
}
|
||
|
||
// audioLLMURLRef wraps the audio reference; this client always fills URL
// with a base64 data: URL.
type audioLLMURLRef struct {
	URL string `json:"url"`
}
|
||
|
||
// audioLLMChatResponse is the subset of a chat-completions response that is
// read: the reported model, each choice's message content, and an optional
// error object.
type audioLLMChatResponse struct {
	Model   string `json:"model,omitempty"`
	Choices []struct {
		Message struct {
			Content string `json:"content"`
		} `json:"message"`
	} `json:"choices"`
	Error *struct {
		Message string `json:"message"`
	} `json:"error,omitempty"`
}
|
||
|
||
type audioTranscriptionResponse struct {
|
||
Text string `json:"text"`
|
||
Language string `json:"language,omitempty"`
|
||
Model string `json:"model,omitempty"`
|
||
Segments []audioTranscriptionSegment `json:"segments,omitempty"`
|
||
Error *struct {
|
||
Message string `json:"message"`
|
||
} `json:"error,omitempty"`
|
||
}
|
||
|
||
// audioTranscriptionSegment is one timed segment of a verbose_json
// transcription response, with times in seconds.
type audioTranscriptionSegment struct {
	Start float64 `json:"start"`
	End   float64 `json:"end"`
	Text  string  `json:"text"`
}
|
||
|
||
func New(baseURL string, timeout time.Duration, ffmpegPath string, leadSilence time.Duration) *Client {
|
||
return NewWithOptions(Options{
|
||
Providers: []string{ProviderWhisperX},
|
||
WhisperXURL: baseURL,
|
||
WhisperXTimeout: timeout,
|
||
FfmpegPath: ffmpegPath,
|
||
LeadSilence: leadSilence,
|
||
})
|
||
}
|
||
|
||
func NewWithOptions(opts Options) *Client {
|
||
leadSilence := opts.LeadSilence
|
||
if leadSilence < 0 {
|
||
leadSilence = 0
|
||
}
|
||
if leadSilence > 5*time.Second {
|
||
leadSilence = 5 * time.Second
|
||
}
|
||
ffmpegPath := strings.TrimSpace(opts.FfmpegPath)
|
||
if ffmpegPath == "" {
|
||
ffmpegPath = "ffmpeg"
|
||
}
|
||
maxTokens := opts.AudioLLMMaxTokens
|
||
if maxTokens <= 0 {
|
||
maxTokens = 4096
|
||
}
|
||
audioLLMPrompt := strings.TrimSpace(opts.AudioLLMPrompt)
|
||
if audioLLMPrompt == "" {
|
||
audioLLMPrompt = "Transcribe the audio exactly. Return only the transcript text."
|
||
}
|
||
providers := buildProviders(opts, audioLLMPrompt, maxTokens)
|
||
if len(providers) == 0 {
|
||
return nil
|
||
}
|
||
return &Client{
|
||
providers: providers,
|
||
http: &http.Client{Timeout: maxProviderTimeout(providers)},
|
||
ffmpegPath: ffmpegPath,
|
||
leadSilence: leadSilence,
|
||
}
|
||
}
|
||
|
||
func buildProviders(opts Options, prompt string, maxTokens int) []ProviderConfig {
|
||
order := normalizeProviderOrder(opts.Providers)
|
||
out := make([]ProviderConfig, 0, len(order))
|
||
for _, name := range order {
|
||
switch name {
|
||
case ProviderWhisperX:
|
||
baseURL := strings.TrimRight(strings.TrimSpace(opts.WhisperXURL), "/")
|
||
if baseURL == "" {
|
||
continue
|
||
}
|
||
out = append(out, ProviderConfig{
|
||
Name: ProviderWhisperX,
|
||
Kind: ProviderWhisperX,
|
||
BaseURL: baseURL,
|
||
Model: ProviderWhisperX,
|
||
Timeout: defaultDuration(opts.WhisperXTimeout, 10*time.Minute),
|
||
})
|
||
case ProviderQwenAudio:
|
||
baseURL := strings.TrimRight(strings.TrimSpace(opts.QwenAudioBaseURL), "/")
|
||
if baseURL == "" {
|
||
continue
|
||
}
|
||
model := firstNonEmpty(opts.QwenAudioModel, "Qwen/Qwen2-Audio-7B-Instruct")
|
||
out = append(out, ProviderConfig{
|
||
Name: ProviderQwenAudio,
|
||
Kind: "audio_llm",
|
||
BaseURL: baseURL,
|
||
APIKey: strings.TrimSpace(opts.QwenAudioAPIKey),
|
||
Model: model,
|
||
Timeout: defaultDuration(opts.QwenAudioTimeout, 10*time.Minute),
|
||
MaxTokens: maxTokens,
|
||
Prompt: prompt,
|
||
})
|
||
case ProviderVoxtral:
|
||
baseURL := strings.TrimRight(strings.TrimSpace(opts.VoxtralBaseURL), "/")
|
||
if baseURL == "" {
|
||
continue
|
||
}
|
||
model := firstNonEmpty(opts.VoxtralModel, "mistralai/Voxtral-Small-24B-2507")
|
||
out = append(out, ProviderConfig{
|
||
Name: ProviderVoxtral,
|
||
Kind: "audio_transcription",
|
||
BaseURL: baseURL,
|
||
APIKey: strings.TrimSpace(opts.VoxtralAPIKey),
|
||
Model: model,
|
||
Timeout: defaultDuration(opts.VoxtralTimeout, 10*time.Minute),
|
||
MaxTokens: maxTokens,
|
||
Prompt: prompt,
|
||
})
|
||
}
|
||
}
|
||
return out
|
||
}
|
||
|
||
func normalizeProviderOrder(in []string) []string {
|
||
if len(in) == 0 {
|
||
return []string{ProviderWhisperX}
|
||
}
|
||
out := make([]string, 0, len(in))
|
||
seen := map[string]bool{}
|
||
for _, raw := range in {
|
||
name := strings.ToLower(strings.TrimSpace(raw))
|
||
switch name {
|
||
case "whisper", "whisperx":
|
||
name = ProviderWhisperX
|
||
case "qwen", "qwen-audio", "qwen2-audio", "qwen2-audio-7b-instruct":
|
||
name = ProviderQwenAudio
|
||
case "voxtral", "voxtral-small", "voxtral-small-24b-2507":
|
||
name = ProviderVoxtral
|
||
default:
|
||
continue
|
||
}
|
||
if !seen[name] {
|
||
out = append(out, name)
|
||
seen[name] = true
|
||
}
|
||
}
|
||
return out
|
||
}
|
||
|
||
func maxProviderTimeout(providers []ProviderConfig) time.Duration {
|
||
maxTimeout := 10 * time.Minute
|
||
for _, provider := range providers {
|
||
if provider.Timeout > maxTimeout {
|
||
maxTimeout = provider.Timeout
|
||
}
|
||
}
|
||
return maxTimeout
|
||
}
|
||
|
||
// defaultDuration returns v when it is positive and fallback otherwise.
func defaultDuration(v, fallback time.Duration) time.Duration {
	if v > 0 {
		return v
	}
	return fallback
}
|
||
|
||
func (c *Client) Transcribe(ctx context.Context, in Input) (*Result, error) {
|
||
if c == nil || len(c.providers) == 0 {
|
||
return nil, fmt.Errorf("transcription providers not configured")
|
||
}
|
||
if strings.TrimSpace(in.AudioURL) == "" {
|
||
return nil, fmt.Errorf("audio_url is required")
|
||
}
|
||
audio, filename, err := c.downloadAudio(ctx, in)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
if c.leadSilence > 0 {
|
||
audio, filename, err = c.addLeadSilence(ctx, audio, filename)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
}
|
||
var attempts []Attempt
|
||
var winner *Result
|
||
var errors []string
|
||
for _, provider := range c.providers {
|
||
result, attempt, err := c.transcribeWithProvider(ctx, provider, audio, filename, in)
|
||
attempts = append(attempts, attempt)
|
||
if err != nil {
|
||
errors = append(errors, provider.Name+": "+err.Error())
|
||
continue
|
||
}
|
||
if winner == nil {
|
||
winner = result
|
||
}
|
||
}
|
||
if winner == nil {
|
||
return nil, fmt.Errorf("all transcription providers failed: %s", strings.Join(errors, "; "))
|
||
}
|
||
winner.Attempts = attempts
|
||
return winner, nil
|
||
}
|
||
|
||
func (c *Client) transcribeWithProvider(ctx context.Context, provider ProviderConfig, audio []byte, filename string, in Input) (*Result, Attempt, error) {
|
||
providerCtx := ctx
|
||
cancel := func() {}
|
||
if provider.Timeout > 0 {
|
||
providerCtx, cancel = context.WithTimeout(ctx, provider.Timeout)
|
||
}
|
||
defer cancel()
|
||
attempt := Attempt{
|
||
Provider: provider.Name,
|
||
Model: provider.Model,
|
||
Status: "failed",
|
||
}
|
||
switch provider.Kind {
|
||
case ProviderWhisperX:
|
||
resp, duration, err := c.transcribeAudio(providerCtx, provider, audio, filename, in)
|
||
attempt.DurationMS = duration.Milliseconds()
|
||
if err != nil {
|
||
attempt.Error = err.Error()
|
||
return nil, attempt, err
|
||
}
|
||
segments := adjustLeadSilence(resp.Segments, c.leadSilence)
|
||
attempt.Status = "ok"
|
||
attempt.Segments = segments
|
||
attempt.Text = segmentsText(segments)
|
||
return &Result{
|
||
Provider: provider.Name,
|
||
Model: provider.Model,
|
||
Language: resp.Language,
|
||
Segments: segments,
|
||
DiarizeError: resp.DiarizeError,
|
||
AlignError: resp.AlignError,
|
||
DurationMS: duration.Milliseconds(),
|
||
}, attempt, nil
|
||
default:
|
||
resp, duration, err := c.transcribeAudioLLM(providerCtx, provider, audio, filename, in)
|
||
attempt.DurationMS = duration.Milliseconds()
|
||
if err != nil {
|
||
attempt.Error = err.Error()
|
||
return nil, attempt, err
|
||
}
|
||
text := strings.TrimSpace(resp.Text)
|
||
segments := normalizeAudioLLMSegments(resp.Segments, text, in.Diarize)
|
||
attempt.Status = "ok"
|
||
attempt.Model = resp.Model
|
||
attempt.Text = text
|
||
attempt.Segments = segments
|
||
return &Result{
|
||
Provider: provider.Name,
|
||
Model: resp.Model,
|
||
Language: firstNonEmpty(resp.Language, in.Language, "unknown"),
|
||
Segments: segments,
|
||
DurationMS: duration.Milliseconds(),
|
||
}, attempt, nil
|
||
}
|
||
}
|
||
|
||
func segmentsText(segments []Segment) string {
|
||
parts := make([]string, 0, len(segments))
|
||
for _, segment := range segments {
|
||
if text := strings.TrimSpace(segment.Text); text != "" {
|
||
parts = append(parts, text)
|
||
}
|
||
}
|
||
return strings.Join(parts, "\n")
|
||
}
|
||
|
||
func (c *Client) downloadAudio(ctx context.Context, in Input) ([]byte, string, error) {
|
||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, in.AudioURL, nil)
|
||
if err != nil {
|
||
return nil, "", fmt.Errorf("audio request: %w", err)
|
||
}
|
||
resp, err := c.http.Do(req)
|
||
if err != nil {
|
||
return nil, "", fmt.Errorf("audio download: %w", err)
|
||
}
|
||
defer resp.Body.Close()
|
||
if resp.StatusCode >= 300 {
|
||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
|
||
return nil, "", fmt.Errorf("audio HTTP %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
|
||
}
|
||
audio, err := io.ReadAll(io.LimitReader(resp.Body, 512<<20))
|
||
if err != nil {
|
||
return nil, "", fmt.Errorf("audio read: %w", err)
|
||
}
|
||
if len(audio) == 0 {
|
||
return nil, "", fmt.Errorf("audio is empty")
|
||
}
|
||
filename := filepath.Base(strings.TrimSpace(in.Filename))
|
||
if filename == "." || filename == "/" || filename == "" {
|
||
filename = "audio.mp3"
|
||
}
|
||
return audio, filename, nil
|
||
}
|
||
|
||
func (c *Client) addLeadSilence(ctx context.Context, audio []byte, filename string) ([]byte, string, error) {
|
||
tmpDir, err := os.MkdirTemp("", "ai-transcribe-*")
|
||
if err != nil {
|
||
return nil, "", fmt.Errorf("prepare audio temp dir: %w", err)
|
||
}
|
||
defer os.RemoveAll(tmpDir)
|
||
|
||
inputPath := filepath.Join(tmpDir, "input"+safeExt(filename))
|
||
outputPath := filepath.Join(tmpDir, "padded.mp3")
|
||
if err := os.WriteFile(inputPath, audio, 0o600); err != nil {
|
||
return nil, "", fmt.Errorf("write audio temp file: %w", err)
|
||
}
|
||
delayMS := int(c.leadSilence.Milliseconds())
|
||
if delayMS <= 0 {
|
||
return audio, filename, nil
|
||
}
|
||
cmd := exec.CommandContext(ctx, c.ffmpegPath,
|
||
"-nostdin", "-y",
|
||
"-i", inputPath,
|
||
"-af", fmt.Sprintf("adelay=%d:all=1", delayMS),
|
||
"-codec:a", "libmp3lame",
|
||
"-qscale:a", "5",
|
||
outputPath,
|
||
)
|
||
out, err := cmd.CombinedOutput()
|
||
if err != nil {
|
||
return nil, "", fmt.Errorf("ffmpeg lead silence: %w (%s)", err, trimOutput(out))
|
||
}
|
||
padded, err := os.ReadFile(outputPath)
|
||
if err != nil {
|
||
return nil, "", fmt.Errorf("read padded audio: %w", err)
|
||
}
|
||
if len(padded) == 0 {
|
||
return nil, "", fmt.Errorf("padded audio is empty")
|
||
}
|
||
base := strings.TrimSuffix(filepath.Base(filename), filepath.Ext(filename))
|
||
if base == "" || base == "." || base == "/" {
|
||
base = "audio"
|
||
}
|
||
return padded, base + "-padded.mp3", nil
|
||
}
|
||
|
||
// safeExt returns the lower-cased extension of filename when it is in a fixed
// allow-list of audio extensions, and ".audio" otherwise. It names the temp
// file handed to ffmpeg, so arbitrary user-supplied extensions never reach
// the filesystem. The list matches the formats recognized by audioFormat.
func safeExt(filename string) string {
	switch ext := strings.ToLower(filepath.Ext(filename)); ext {
	case ".mp3", ".wav", ".flac", ".m4a", ".ogg", ".opus", ".webm":
		return ext
	default:
		return ".audio"
	}
}
|
||
|
||
// trimOutput returns out with surrounding whitespace removed, capped at 600
// bytes so tool output (e.g. ffmpeg's combined output) stays readable inside
// an error message. The cut is moved back to the start of a rune, so a valid
// multi-byte UTF-8 character is never split in half.
func trimOutput(out []byte) string {
	const limit = 600
	s := strings.TrimSpace(string(out))
	if len(s) <= limit {
		return s
	}
	// Ranging over a string yields the byte offset of each rune start; keep
	// the last one that still fits within the limit.
	cut := 0
	for i := range s {
		if i > limit {
			break
		}
		cut = i
	}
	return s[:cut]
}
|
||
|
||
func adjustLeadSilence(segments []Segment, silence time.Duration) []Segment {
|
||
if len(segments) == 0 || silence <= 0 {
|
||
return segments
|
||
}
|
||
shift := silence.Seconds()
|
||
out := make([]Segment, 0, len(segments))
|
||
for _, segment := range segments {
|
||
segment.Start = clampTime(segment.Start - shift)
|
||
segment.End = clampTime(segment.End - shift)
|
||
if segment.End < segment.Start {
|
||
segment.End = segment.Start
|
||
}
|
||
out = append(out, segment)
|
||
}
|
||
return out
|
||
}
|
||
|
||
// clampTime replaces negative times with 0. Other values, including NaN,
// pass through unchanged.
func clampTime(v float64) float64 {
	switch {
	case v < 0:
		return 0
	default:
		return v
	}
}
|
||
|
||
func (c *Client) transcribeAudio(ctx context.Context, provider ProviderConfig, audio []byte, filename string, in Input) (*whisperResponse, time.Duration, error) {
|
||
body := &bytes.Buffer{}
|
||
mw := multipart.NewWriter(body)
|
||
fw, err := mw.CreateFormFile("file", filename)
|
||
if err != nil {
|
||
return nil, 0, fmt.Errorf("create form file: %w", err)
|
||
}
|
||
if _, err := fw.Write(audio); err != nil {
|
||
return nil, 0, fmt.Errorf("copy audio: %w", err)
|
||
}
|
||
if in.Language != "" {
|
||
_ = mw.WriteField("language", in.Language)
|
||
}
|
||
if in.Diarize {
|
||
_ = mw.WriteField("diarize", "true")
|
||
if in.MinSpeakers > 0 {
|
||
_ = mw.WriteField("min_speakers", fmt.Sprintf("%d", in.MinSpeakers))
|
||
}
|
||
if in.MaxSpeakers > 0 {
|
||
_ = mw.WriteField("max_speakers", fmt.Sprintf("%d", in.MaxSpeakers))
|
||
}
|
||
} else {
|
||
_ = mw.WriteField("diarize", "false")
|
||
}
|
||
if err := mw.Close(); err != nil {
|
||
return nil, 0, fmt.Errorf("close form: %w", err)
|
||
}
|
||
|
||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, provider.BaseURL+"/transcribe", body)
|
||
if err != nil {
|
||
return nil, 0, fmt.Errorf("whisperx request: %w", err)
|
||
}
|
||
req.Header.Set("Content-Type", mw.FormDataContentType())
|
||
|
||
start := time.Now()
|
||
resp, err := c.http.Do(req)
|
||
duration := time.Since(start)
|
||
if err != nil {
|
||
return nil, duration, fmt.Errorf("whisperx do: %w", err)
|
||
}
|
||
defer resp.Body.Close()
|
||
if resp.StatusCode >= 300 {
|
||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
|
||
return nil, duration, fmt.Errorf("whisperx HTTP %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
|
||
}
|
||
var out whisperResponse
|
||
if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
|
||
return nil, duration, fmt.Errorf("whisperx decode: %w", err)
|
||
}
|
||
return &out, duration, nil
|
||
}
|
||
|
||
func (c *Client) transcribeAudioLLM(ctx context.Context, provider ProviderConfig, audio []byte, filename string, in Input) (*audioLLMResponse, time.Duration, error) {
|
||
if provider.Kind == "audio_transcription" {
|
||
return c.transcribeOpenAIAudio(ctx, provider, audio, filename, in)
|
||
}
|
||
prompt := provider.Prompt
|
||
if in.Language != "" {
|
||
prompt += "\nЯзык аудио: " + in.Language + "."
|
||
}
|
||
if in.Diarize {
|
||
prompt += "\nЕсли слышны разные говорящие, разделяй реплики с короткими пометками Спикер 1/Спикер 2."
|
||
}
|
||
reqBody := audioLLMChatRequest{
|
||
Model: provider.Model,
|
||
MaxTokens: provider.MaxTokens,
|
||
Temperature: 0,
|
||
Messages: []audioLLMChatMessage{
|
||
{
|
||
Role: "user",
|
||
Content: []audioLLMContentPart{
|
||
{Type: "text", Text: prompt},
|
||
{
|
||
Type: "audio_url",
|
||
AudioURL: &audioLLMURLRef{URL: audioDataURL(audio, filename)},
|
||
},
|
||
},
|
||
},
|
||
},
|
||
}
|
||
body, err := json.Marshal(reqBody)
|
||
if err != nil {
|
||
return nil, 0, fmt.Errorf("audio llm marshal: %w", err)
|
||
}
|
||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, provider.BaseURL+"/v1/chat/completions", bytes.NewReader(body))
|
||
if err != nil {
|
||
return nil, 0, fmt.Errorf("audio llm request: %w", err)
|
||
}
|
||
req.Header.Set("Content-Type", "application/json")
|
||
if provider.APIKey != "" {
|
||
req.Header.Set("Authorization", "Bearer "+provider.APIKey)
|
||
}
|
||
|
||
start := time.Now()
|
||
resp, err := c.http.Do(req)
|
||
duration := time.Since(start)
|
||
if err != nil {
|
||
return nil, duration, fmt.Errorf("audio llm do: %w", err)
|
||
}
|
||
defer resp.Body.Close()
|
||
raw, err := io.ReadAll(io.LimitReader(resp.Body, 4<<20))
|
||
if err != nil {
|
||
return nil, duration, fmt.Errorf("audio llm read: %w", err)
|
||
}
|
||
if resp.StatusCode >= 300 {
|
||
return nil, duration, fmt.Errorf("audio llm HTTP %d: %s", resp.StatusCode, strings.TrimSpace(string(raw)))
|
||
}
|
||
var out audioLLMChatResponse
|
||
if err := json.Unmarshal(raw, &out); err != nil {
|
||
return nil, duration, fmt.Errorf("audio llm decode: %w", err)
|
||
}
|
||
if out.Error != nil {
|
||
return nil, duration, fmt.Errorf("audio llm error: %s", out.Error.Message)
|
||
}
|
||
if len(out.Choices) == 0 {
|
||
return nil, duration, fmt.Errorf("audio llm: empty choices")
|
||
}
|
||
modelName := out.Model
|
||
if modelName == "" {
|
||
modelName = provider.Model
|
||
}
|
||
return &audioLLMResponse{
|
||
Text: strings.TrimSpace(out.Choices[0].Message.Content),
|
||
Model: modelName,
|
||
}, duration, nil
|
||
}
|
||
|
||
func (c *Client) transcribeOpenAIAudio(ctx context.Context, provider ProviderConfig, audio []byte, filename string, in Input) (*audioLLMResponse, time.Duration, error) {
|
||
resp, duration, err := c.doOpenAIAudioTranscription(ctx, provider, audio, filename, in, "verbose_json")
|
||
if err == nil {
|
||
return resp, duration, nil
|
||
}
|
||
if !strings.Contains(strings.ToLower(err.Error()), "http 4") {
|
||
return nil, duration, err
|
||
}
|
||
fallback, fallbackDuration, fallbackErr := c.doOpenAIAudioTranscription(ctx, provider, audio, filename, in, "json")
|
||
if fallbackErr != nil {
|
||
return nil, duration + fallbackDuration, err
|
||
}
|
||
return fallback, duration + fallbackDuration, nil
|
||
}
|
||
|
||
func (c *Client) doOpenAIAudioTranscription(ctx context.Context, provider ProviderConfig, audio []byte, filename string, in Input, responseFormat string) (*audioLLMResponse, time.Duration, error) {
|
||
body := &bytes.Buffer{}
|
||
mw := multipart.NewWriter(body)
|
||
if err := mw.WriteField("model", provider.Model); err != nil {
|
||
return nil, 0, fmt.Errorf("audio transcription model field: %w", err)
|
||
}
|
||
if err := mw.WriteField("response_format", responseFormat); err != nil {
|
||
return nil, 0, fmt.Errorf("audio transcription response_format field: %w", err)
|
||
}
|
||
if responseFormat == "verbose_json" {
|
||
if err := mw.WriteField("timestamp_granularities[]", "segment"); err != nil {
|
||
return nil, 0, fmt.Errorf("audio transcription timestamp field: %w", err)
|
||
}
|
||
}
|
||
if prompt := strings.TrimSpace(provider.Prompt); prompt != "" {
|
||
if err := mw.WriteField("prompt", prompt); err != nil {
|
||
return nil, 0, fmt.Errorf("audio transcription prompt field: %w", err)
|
||
}
|
||
}
|
||
if lang := strings.TrimSpace(in.Language); lang != "" {
|
||
if err := mw.WriteField("language", lang); err != nil {
|
||
return nil, 0, fmt.Errorf("audio transcription language field: %w", err)
|
||
}
|
||
}
|
||
fw, err := mw.CreateFormFile("file", filename)
|
||
if err != nil {
|
||
return nil, 0, fmt.Errorf("audio transcription create file: %w", err)
|
||
}
|
||
if _, err := fw.Write(audio); err != nil {
|
||
return nil, 0, fmt.Errorf("audio transcription copy audio: %w", err)
|
||
}
|
||
if err := mw.Close(); err != nil {
|
||
return nil, 0, fmt.Errorf("audio transcription close form: %w", err)
|
||
}
|
||
|
||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, provider.BaseURL+"/v1/audio/transcriptions", body)
|
||
if err != nil {
|
||
return nil, 0, fmt.Errorf("audio transcription request: %w", err)
|
||
}
|
||
req.Header.Set("Content-Type", mw.FormDataContentType())
|
||
if provider.APIKey != "" {
|
||
req.Header.Set("Authorization", "Bearer "+provider.APIKey)
|
||
}
|
||
|
||
start := time.Now()
|
||
resp, err := c.http.Do(req)
|
||
duration := time.Since(start)
|
||
if err != nil {
|
||
return nil, duration, fmt.Errorf("audio transcription do: %w", err)
|
||
}
|
||
defer resp.Body.Close()
|
||
raw, err := io.ReadAll(io.LimitReader(resp.Body, 4<<20))
|
||
if err != nil {
|
||
return nil, duration, fmt.Errorf("audio transcription read: %w", err)
|
||
}
|
||
if resp.StatusCode >= 300 {
|
||
return nil, duration, fmt.Errorf("audio transcription HTTP %d: %s", resp.StatusCode, strings.TrimSpace(string(raw)))
|
||
}
|
||
var out audioTranscriptionResponse
|
||
if err := json.Unmarshal(raw, &out); err != nil {
|
||
return nil, duration, fmt.Errorf("audio transcription decode: %w", err)
|
||
}
|
||
if out.Error != nil {
|
||
return nil, duration, fmt.Errorf("audio transcription error: %s", out.Error.Message)
|
||
}
|
||
modelName := firstNonEmpty(out.Model, provider.Model)
|
||
return &audioLLMResponse{
|
||
Text: strings.TrimSpace(out.Text),
|
||
Model: modelName,
|
||
Language: out.Language,
|
||
Segments: convertAudioSegments(out.Segments),
|
||
}, duration, nil
|
||
}
|
||
|
||
func convertAudioSegments(in []audioTranscriptionSegment) []Segment {
|
||
out := make([]Segment, 0, len(in))
|
||
for _, s := range in {
|
||
text := strings.TrimSpace(s.Text)
|
||
if text == "" {
|
||
continue
|
||
}
|
||
end := s.End
|
||
if end < s.Start {
|
||
end = s.Start
|
||
}
|
||
out = append(out, Segment{Start: clampTime(s.Start), End: clampTime(end), Text: text})
|
||
}
|
||
return out
|
||
}
|
||
|
||
func normalizeAudioLLMSegments(segments []Segment, text string, diarize bool) []Segment {
|
||
text = strings.TrimSpace(text)
|
||
if len(segments) <= 1 && text != "" {
|
||
heuristic := segmentTranscriptText(text, diarize)
|
||
if len(heuristic) > len(segments) {
|
||
segments = heuristic
|
||
}
|
||
}
|
||
return ensureHeuristicSpeakers(segments, diarize)
|
||
}
|
||
|
||
func ensureHeuristicSpeakers(segments []Segment, diarize bool) []Segment {
|
||
if !diarize || len(segments) < 2 || segmentsHaveSpeakers(segments) {
|
||
return segments
|
||
}
|
||
out := make([]Segment, len(segments))
|
||
copy(out, segments)
|
||
for i := range out {
|
||
if i%2 == 0 {
|
||
out[i].Speaker = "SPEAKER_00"
|
||
} else {
|
||
out[i].Speaker = "SPEAKER_01"
|
||
}
|
||
}
|
||
return out
|
||
}
|
||
|
||
func segmentsHaveSpeakers(segments []Segment) bool {
|
||
for _, segment := range segments {
|
||
if strings.TrimSpace(segment.Speaker) != "" {
|
||
return true
|
||
}
|
||
}
|
||
return false
|
||
}
|
||
|
||
func segmentTranscriptText(text string, diarize bool) []Segment {
|
||
parts := splitTranscriptSentences(text)
|
||
out := make([]Segment, 0, len(parts))
|
||
var t float64
|
||
for i, part := range parts {
|
||
words := len(strings.Fields(part))
|
||
if words == 0 {
|
||
continue
|
||
}
|
||
duration := float64(words) * 0.42
|
||
if duration < 1.2 {
|
||
duration = 1.2
|
||
}
|
||
segment := Segment{Start: t, End: t + duration, Text: part}
|
||
if diarize && len(parts) > 1 {
|
||
if i%2 == 0 {
|
||
segment.Speaker = "SPEAKER_00"
|
||
} else {
|
||
segment.Speaker = "SPEAKER_01"
|
||
}
|
||
}
|
||
out = append(out, segment)
|
||
t = segment.End
|
||
}
|
||
if len(out) == 0 && strings.TrimSpace(text) != "" {
|
||
out = append(out, Segment{Start: 0, End: 0, Text: strings.TrimSpace(text)})
|
||
}
|
||
return out
|
||
}
|
||
|
||
func splitTranscriptSentences(text string) []string {
|
||
text = strings.Join(strings.Fields(text), " ")
|
||
if text == "" {
|
||
return nil
|
||
}
|
||
var out []string
|
||
start := 0
|
||
runes := []rune(text)
|
||
for i, r := range runes {
|
||
if r != '.' && r != '!' && r != '?' && r != '…' {
|
||
continue
|
||
}
|
||
next := i + 1
|
||
if next < len(runes) && runes[next] != ' ' {
|
||
continue
|
||
}
|
||
part := strings.TrimSpace(string(runes[start : i+1]))
|
||
if part != "" {
|
||
out = append(out, part)
|
||
}
|
||
start = i + 1
|
||
for start < len(runes) && runes[start] == ' ' {
|
||
start++
|
||
}
|
||
}
|
||
tail := strings.TrimSpace(string(runes[start:]))
|
||
if tail != "" {
|
||
out = append(out, tail)
|
||
}
|
||
return mergeShortSegments(out, 8, 34)
|
||
}
|
||
|
||
// mergeShortSegments greedily joins consecutive parts with single spaces.
//
// A group is emitted as soon as it reaches minWords words. A group is also
// emitted early when adding the next part would push it past maxWords. A
// single part is never split. Inputs with at most one part are returned
// unchanged.
func mergeShortSegments(parts []string, minWords, maxWords int) []string {
	if len(parts) <= 1 {
		return parts
	}

	merged := make([]string, 0, len(parts))
	var (
		group []string
		count int
	)
	emit := func() {
		if len(group) > 0 {
			merged = append(merged, strings.Join(group, " "))
		}
		group, count = nil, 0
	}

	for _, p := range parts {
		n := len(strings.Fields(p))
		if count > 0 && count+n > maxWords {
			emit()
		}
		group = append(group, p)
		count += n
		if count >= minWords {
			emit()
		}
	}
	emit()
	return merged
}
|
||
|
||
// audioFormat returns the lower-cased extension of filename, without the
// dot, when it is a known audio format. Any other extension yields "mp3".
func audioFormat(filename string) string {
	ext := strings.ToLower(filepath.Ext(filename))
	ext = strings.TrimPrefix(ext, ".")
	switch ext {
	case "wav", "mp3", "flac", "m4a", "ogg", "opus", "webm":
		return ext
	}
	return "mp3"
}
|
||
|
||
func audioDataURL(audio []byte, filename string) string {
|
||
return "data:audio/" + audioFormat(filename) + ";base64," + base64.StdEncoding.EncodeToString(audio)
|
||
}
|
||
|
||
// firstNonEmpty returns the first value that is not blank after trimming
// whitespace. The value is returned as given, untrimmed. It returns "" when
// every value is blank.
func firstNonEmpty(values ...string) string {
	for i := range values {
		if strings.TrimSpace(values[i]) == "" {
			continue
		}
		return values[i]
	}
	return ""
}
|