Files
ai-service/internal/transcription/client.go
Grendgi 9bd6d726f0
Some checks failed
CI / test (push) Failing after 8s
Build and Deploy / build-and-deploy (push) Successful in 27s
Make Voxtral the only transcription provider
2026-06-09 16:54:54 +03:00

473 lines
13 KiB
Go

package transcription
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"mime/multipart"
"net/http"
"path/filepath"
"strings"
"time"
)
// Client transcribes audio through a single configured provider.
// A nil *Client is tolerated by Transcribe, which reports that the provider
// is not configured.
type Client struct {
	// provider holds the resolved endpoint, credentials, model and prompt.
	provider ProviderConfig
	// http is used both to download the audio and to call the provider.
	http *http.Client
}
// ProviderVoxtral is the provider name recorded in ProviderConfig.Name and,
// through it, in Result.Provider and Attempt.Provider.
const ProviderVoxtral = "voxtral-small"
// Options carries the caller-supplied settings consumed by NewWithOptions.
type Options struct {
	// VoxtralBaseURL is the provider base URL; surrounding whitespace and
	// trailing slashes are stripped. When it ends up empty, NewWithOptions
	// returns nil.
	VoxtralBaseURL string
	// VoxtralAPIKey, when non-empty after trimming, is sent as a Bearer token.
	VoxtralAPIKey string
	// VoxtralModel overrides the default model
	// "mistralai/Voxtral-Small-24B-2507".
	VoxtralModel string
	// VoxtralTimeout bounds provider calls; values <= 0 fall back to 10 minutes.
	VoxtralTimeout time.Duration
	// AudioLLMPrompt is the prompt sent with each request; a built-in
	// transcription prompt is used when it is blank.
	AudioLLMPrompt string
}
// ProviderConfig is the resolved configuration for one transcription provider,
// as produced by buildVoxtralProvider.
type ProviderConfig struct {
	// Name identifies the provider in results and attempts.
	Name string
	// BaseURL is the endpoint root; "/v1/audio/transcriptions" is appended to it.
	BaseURL string
	// APIKey is sent as "Authorization: Bearer <key>" when non-empty.
	APIKey string
	// Model is sent as the "model" form field.
	Model string
	// Timeout is used as the HTTP client timeout and, when > 0, as a
	// per-call context deadline.
	Timeout time.Duration
	// Prompt is sent as the "prompt" form field when non-blank.
	Prompt string
}
// Input describes one transcription request.
type Input struct {
	// AudioURL is fetched with an HTTP GET; it is required.
	AudioURL string `json:"audio_url"`
	// Filename names the uploaded file part; only its base name is used and
	// "audio.mp3" is substituted when it is empty.
	Filename string `json:"filename,omitempty"`
	// Language, when non-blank, is forwarded as the "language" form field and
	// is the fallback for Result.Language.
	Language string `json:"language,omitempty"`
	// Diarize enables the heuristic alternating speaker labels.
	Diarize bool `json:"diarize"`
	// MinSpeakers and MaxSpeakers are not read by the code visible in this file.
	MinSpeakers int `json:"min_speakers,omitempty"`
	MaxSpeakers int `json:"max_speakers,omitempty"`
}
// Segment is one timed piece of transcript text.
type Segment struct {
	// Start and End are offsets into the audio (presumably seconds — confirm
	// against the provider's response format).
	Start float64 `json:"start"`
	End float64 `json:"end"`
	// Text is the whitespace-trimmed transcript text of the segment.
	Text string `json:"text"`
	// Speaker is a label such as "SPEAKER_00"; empty when no label was assigned.
	Speaker string `json:"speaker,omitempty"`
}
// Result is the outcome of a successful Transcribe call.
type Result struct {
	// Provider and Model identify what produced the transcript.
	Provider string `json:"provider,omitempty"`
	Model string `json:"model,omitempty"`
	// Attempts lists the provider calls made; Transcribe records exactly one.
	Attempts []Attempt `json:"attempts,omitempty"`
	// Language is the provider-reported language, else the requested one,
	// else "unknown".
	Language string `json:"language"`
	// Segments holds the normalized transcript segments.
	Segments []Segment `json:"segments"`
	// DiarizeError and AlignError are not set by the code visible in this file.
	DiarizeError *string `json:"diarize_error,omitempty"`
	AlignError *string `json:"align_error,omitempty"`
	// DurationMS is the measured duration of the provider request(s) in
	// milliseconds.
	DurationMS int64 `json:"duration_ms"`
}
// Attempt records a single provider call and its outcome.
type Attempt struct {
	Provider string `json:"provider"`
	Model string `json:"model,omitempty"`
	// Status is "ok" on success and "failed" otherwise.
	Status string `json:"status"`
	// Error carries the error text of a failed attempt.
	Error string `json:"error,omitempty"`
	// Text and Segments are populated only for successful attempts.
	Text string `json:"text,omitempty"`
	Segments []Segment `json:"segments,omitempty"`
	// DurationMS is the measured request duration in milliseconds.
	DurationMS int64 `json:"duration_ms,omitempty"`
}
// audioLLMResponse is the internal, provider-independent form of a
// transcription response returned by doOpenAIAudioTranscription.
type audioLLMResponse struct {
	// Text is the trimmed full transcript.
	Text string
	// Model is the model reported by the provider, or the configured one.
	Model string
	// Language is the language reported by the provider, if any.
	Language string
	// Segments are the provider segments after convertAudioSegments.
	Segments []Segment
}
// audioTranscriptionResponse is the JSON body decoded from the
// /v1/audio/transcriptions endpoint.
type audioTranscriptionResponse struct {
	Text string `json:"text"`
	Language string `json:"language,omitempty"`
	Model string `json:"model,omitempty"`
	Segments []audioTranscriptionSegment `json:"segments,omitempty"`
	// Error, when present in a response body, is turned into a Go error by
	// doOpenAIAudioTranscription.
	Error *struct {
		Message string `json:"message"`
	} `json:"error,omitempty"`
}
// audioTranscriptionSegment is one entry of the "segments" array in the
// provider response.
type audioTranscriptionSegment struct {
	Start float64 `json:"start"`
	End float64 `json:"end"`
	Text string `json:"text"`
}
// New builds a Client from a base URL and timeout, delegating to
// NewWithOptions with every other option left at its zero value.
//
// ffmpegPath and leadSilence are accepted but not used by this function.
// Like NewWithOptions, it returns nil when no usable base URL is supplied.
func New(baseURL string, timeout time.Duration, ffmpegPath string, leadSilence time.Duration) *Client {
	var opts Options
	opts.VoxtralBaseURL = baseURL
	opts.VoxtralTimeout = timeout
	return NewWithOptions(opts)
}
// NewWithOptions builds a Client from opts.
//
// A blank AudioLLMPrompt is replaced with a built-in transcription prompt.
// It returns nil when the resolved provider has no base URL; Transcribe on a
// nil Client reports the provider as not configured.
func NewWithOptions(opts Options) *Client {
	const fallbackPrompt = "Transcribe the audio exactly. Return only the transcript text."

	prompt := strings.TrimSpace(opts.AudioLLMPrompt)
	if prompt == "" {
		prompt = fallbackPrompt
	}

	cfg := buildVoxtralProvider(opts, prompt)
	if cfg.BaseURL != "" {
		httpClient := &http.Client{Timeout: cfg.Timeout}
		return &Client{provider: cfg, http: httpClient}
	}
	return nil
}
// buildVoxtralProvider resolves opts into a ProviderConfig using prompt as the
// request prompt.
//
// The base URL is trimmed of surrounding whitespace and trailing slashes; when
// nothing remains, the zero ProviderConfig is returned. The model defaults to
// "mistralai/Voxtral-Small-24B-2507" and the timeout to 10 minutes.
func buildVoxtralProvider(opts Options, prompt string) ProviderConfig {
	var cfg ProviderConfig

	endpoint := strings.TrimSpace(opts.VoxtralBaseURL)
	endpoint = strings.TrimRight(endpoint, "/")
	if endpoint == "" {
		return cfg
	}

	cfg.Name = ProviderVoxtral
	cfg.BaseURL = endpoint
	cfg.APIKey = strings.TrimSpace(opts.VoxtralAPIKey)
	cfg.Model = firstNonEmpty(opts.VoxtralModel, "mistralai/Voxtral-Small-24B-2507")
	cfg.Timeout = defaultDuration(opts.VoxtralTimeout, 10*time.Minute)
	cfg.Prompt = prompt
	return cfg
}
// defaultDuration returns v when it is strictly positive and fallback
// otherwise.
func defaultDuration(v, fallback time.Duration) time.Duration {
	if v > 0 {
		return v
	}
	return fallback
}
// Transcribe downloads the audio referenced by in.AudioURL and sends it to the
// configured provider.
//
// It returns an error when the client is nil or has no base URL, when
// in.AudioURL is blank, or when the download or the provider call fails. On
// success the returned Result carries exactly one Attempt.
func (c *Client) Transcribe(ctx context.Context, in Input) (*Result, error) {
	switch {
	case c == nil || c.provider.BaseURL == "":
		return nil, fmt.Errorf("voxtral transcription provider not configured")
	case strings.TrimSpace(in.AudioURL) == "":
		return nil, fmt.Errorf("audio_url is required")
	}

	payload, name, err := c.downloadAudio(ctx, in)
	if err != nil {
		return nil, err
	}

	res, att, err := c.transcribeWithProvider(ctx, c.provider, payload, name, in)
	if err != nil {
		return nil, err
	}
	res.Attempts = []Attempt{att}
	return res, nil
}
// transcribeWithProvider runs one transcription against provider and returns
// both the Result and an Attempt describing the call.
//
// When provider.Timeout is positive the call runs under a derived context with
// that deadline. On failure the Result is nil and the Attempt has status
// "failed" with the error text; on success the Attempt has status "ok" and
// carries the transcript text and normalized segments.
func (c *Client) transcribeWithProvider(ctx context.Context, provider ProviderConfig, audio []byte, filename string, in Input) (*Result, Attempt, error) {
	callCtx := ctx
	if provider.Timeout > 0 {
		var stop context.CancelFunc
		callCtx, stop = context.WithTimeout(ctx, provider.Timeout)
		defer stop()
	}

	record := Attempt{Provider: provider.Name, Model: provider.Model, Status: "failed"}

	resp, elapsed, err := c.transcribeOpenAIAudio(callCtx, provider, audio, filename, in)
	elapsedMS := elapsed.Milliseconds()
	record.DurationMS = elapsedMS
	if err != nil {
		record.Error = err.Error()
		return nil, record, err
	}

	transcript := strings.TrimSpace(resp.Text)
	normalized := normalizeAudioLLMSegments(resp.Segments, transcript, in.Diarize)

	record.Status = "ok"
	record.Model = resp.Model
	record.Text = transcript
	record.Segments = normalized

	result := &Result{
		Provider:   provider.Name,
		Model:      resp.Model,
		Language:   firstNonEmpty(resp.Language, in.Language, "unknown"),
		Segments:   normalized,
		DurationMS: elapsedMS,
	}
	return result, record, nil
}
// downloadAudio fetches in.AudioURL with an HTTP GET and returns the audio
// bytes together with the file name to use for the upload.
//
// Errors are returned when the request cannot be built or sent, when the
// response status is >= 300 (up to 4 KiB of the body is included in the
// message), when the body cannot be read, when it is empty, or when it is
// larger than 512 MiB. The file name is the base name of in.Filename, or
// "audio.mp3" when that yields nothing usable.
func (c *Client) downloadAudio(ctx context.Context, in Input) ([]byte, string, error) {
	// maxAudioBytes caps how much audio is buffered in memory.
	const maxAudioBytes = 512 << 20

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, in.AudioURL, nil)
	if err != nil {
		return nil, "", fmt.Errorf("audio request: %w", err)
	}
	resp, err := c.http.Do(req)
	if err != nil {
		return nil, "", fmt.Errorf("audio download: %w", err)
	}
	defer resp.Body.Close()
	if resp.StatusCode >= 300 {
		// Best effort: the body is only used to enrich the error message.
		body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
		return nil, "", fmt.Errorf("audio HTTP %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
	}
	// Read one byte past the cap so an oversized body is detected instead of
	// being silently truncated and transcribed as if it were complete.
	audio, err := io.ReadAll(io.LimitReader(resp.Body, maxAudioBytes+1))
	if err != nil {
		return nil, "", fmt.Errorf("audio read: %w", err)
	}
	if len(audio) > maxAudioBytes {
		return nil, "", fmt.Errorf("audio exceeds %d bytes", maxAudioBytes)
	}
	if len(audio) == 0 {
		return nil, "", fmt.Errorf("audio is empty")
	}
	filename := filepath.Base(strings.TrimSpace(in.Filename))
	if filename == "." || filename == "/" || filename == "" {
		filename = "audio.mp3"
	}
	return audio, filename, nil
}
// clampTime replaces a negative v with 0 and returns every other value
// unchanged.
func clampTime(v float64) float64 {
	if v < 0 {
		v = 0
	}
	return v
}
// transcribeOpenAIAudio requests a transcription in "verbose_json" format and,
// if the provider rejects that request with a 4xx status, retries once with
// the plain "json" format.
//
// The returned duration is the time spent in the request(s) made. When the
// fallback also fails, the error of the first request is returned.
func (c *Client) transcribeOpenAIAudio(ctx context.Context, provider ProviderConfig, audio []byte, filename string, in Input) (*audioLLMResponse, time.Duration, error) {
	resp, duration, err := c.doOpenAIAudioTranscription(ctx, provider, audio, filename, in, "verbose_json")
	if err == nil {
		return resp, duration, nil
	}
	// Retry only on a 4xx status. Match the exact prefix that
	// doOpenAIAudioTranscription produces for non-2xx responses rather than
	// searching the whole message: the message embeds the server's response
	// body, so a 5xx or provider error mentioning "HTTP 4.." must not
	// trigger the fallback.
	if !strings.HasPrefix(err.Error(), "audio transcription HTTP 4") {
		return nil, duration, err
	}
	fallback, fallbackDuration, fallbackErr := c.doOpenAIAudioTranscription(ctx, provider, audio, filename, in, "json")
	if fallbackErr != nil {
		// Report the original failure; it describes the preferred request.
		return nil, duration + fallbackDuration, err
	}
	return fallback, duration + fallbackDuration, nil
}
// doOpenAIAudioTranscription posts audio as multipart form data to
// provider.BaseURL + "/v1/audio/transcriptions" and decodes the JSON reply.
//
// Form fields are written in this order: model, response_format,
// timestamp_granularities[] (only for "verbose_json"), prompt and language
// (only when non-blank), then the file part. The bearer token is attached when
// provider.APIKey is set. The returned duration covers only the HTTP round
// trip and is 0 when the request could not be built. At most 4 MiB of the
// response body is read. A status >= 300, an undecodable body, or a body that
// carries an "error" object all produce an error.
func (c *Client) doOpenAIAudioTranscription(ctx context.Context, provider ProviderConfig, audio []byte, filename string, in Input, responseFormat string) (*audioLLMResponse, time.Duration, error) {
	// formField pairs a form field with the label used in its error message.
	type formField struct {
		name, value, label string
	}
	fields := []formField{
		{"model", provider.Model, "model"},
		{"response_format", responseFormat, "response_format"},
	}
	if responseFormat == "verbose_json" {
		fields = append(fields, formField{"timestamp_granularities[]", "segment", "timestamp"})
	}
	if p := strings.TrimSpace(provider.Prompt); p != "" {
		fields = append(fields, formField{"prompt", p, "prompt"})
	}
	if l := strings.TrimSpace(in.Language); l != "" {
		fields = append(fields, formField{"language", l, "language"})
	}

	payload := new(bytes.Buffer)
	form := multipart.NewWriter(payload)
	for _, f := range fields {
		if err := form.WriteField(f.name, f.value); err != nil {
			return nil, 0, fmt.Errorf("audio transcription %s field: %w", f.label, err)
		}
	}
	part, err := form.CreateFormFile("file", filename)
	if err != nil {
		return nil, 0, fmt.Errorf("audio transcription create file: %w", err)
	}
	if _, err = part.Write(audio); err != nil {
		return nil, 0, fmt.Errorf("audio transcription copy audio: %w", err)
	}
	if err = form.Close(); err != nil {
		return nil, 0, fmt.Errorf("audio transcription close form: %w", err)
	}

	endpoint := provider.BaseURL + "/v1/audio/transcriptions"
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, payload)
	if err != nil {
		return nil, 0, fmt.Errorf("audio transcription request: %w", err)
	}
	req.Header.Set("Content-Type", form.FormDataContentType())
	if provider.APIKey != "" {
		req.Header.Set("Authorization", "Bearer "+provider.APIKey)
	}

	// Only the round trip itself is timed.
	began := time.Now()
	resp, err := c.http.Do(req)
	elapsed := time.Since(began)
	if err != nil {
		return nil, elapsed, fmt.Errorf("audio transcription do: %w", err)
	}
	defer resp.Body.Close()

	raw, err := io.ReadAll(io.LimitReader(resp.Body, 4<<20))
	if err != nil {
		return nil, elapsed, fmt.Errorf("audio transcription read: %w", err)
	}
	if resp.StatusCode >= 300 {
		return nil, elapsed, fmt.Errorf("audio transcription HTTP %d: %s", resp.StatusCode, strings.TrimSpace(string(raw)))
	}

	var decoded audioTranscriptionResponse
	if err := json.Unmarshal(raw, &decoded); err != nil {
		return nil, elapsed, fmt.Errorf("audio transcription decode: %w", err)
	}
	if decoded.Error != nil {
		return nil, elapsed, fmt.Errorf("audio transcription error: %s", decoded.Error.Message)
	}

	result := &audioLLMResponse{
		Text:     strings.TrimSpace(decoded.Text),
		Model:    firstNonEmpty(decoded.Model, provider.Model),
		Language: decoded.Language,
		Segments: convertAudioSegments(decoded.Segments),
	}
	return result, elapsed, nil
}
// convertAudioSegments maps provider segments to Segment values.
//
// Segments whose text is blank after trimming are dropped. An end time that
// precedes the start time is raised to the start time, and both times are then
// clamped to be non-negative. The result is never nil.
func convertAudioSegments(in []audioTranscriptionSegment) []Segment {
	converted := make([]Segment, 0, len(in))
	for i := range in {
		raw := in[i]
		trimmed := strings.TrimSpace(raw.Text)
		if trimmed == "" {
			continue
		}
		stop := raw.End
		if stop < raw.Start {
			stop = raw.Start
		}
		converted = append(converted, Segment{
			Start: clampTime(raw.Start),
			End:   clampTime(stop),
			Text:  trimmed,
		})
	}
	return converted
}
// normalizeAudioLLMSegments prepares provider segments for the result.
//
// When the provider returned at most one segment and the transcript is not
// blank, the transcript is split heuristically and that split replaces the
// provider segments if it yields more of them. Speaker labels are then filled
// in by ensureHeuristicSpeakers.
func normalizeAudioLLMSegments(segments []Segment, text string, diarize bool) []Segment {
	if trimmed := strings.TrimSpace(text); trimmed != "" && len(segments) <= 1 {
		if guess := segmentTranscriptText(trimmed, diarize); len(guess) > len(segments) {
			segments = guess
		}
	}
	return ensureHeuristicSpeakers(segments, diarize)
}
// ensureHeuristicSpeakers assigns alternating "SPEAKER_00"/"SPEAKER_01" labels
// to a copy of segments.
//
// The input is returned untouched when diarize is false, when there are fewer
// than two segments, or when any segment already carries a speaker label.
func ensureHeuristicSpeakers(segments []Segment, diarize bool) []Segment {
	if !diarize || len(segments) < 2 || segmentsHaveSpeakers(segments) {
		return segments
	}
	labels := [2]string{"SPEAKER_00", "SPEAKER_01"}
	// Work on a copy so the caller's slice is not modified.
	labelled := append(make([]Segment, 0, len(segments)), segments...)
	for idx := range labelled {
		labelled[idx].Speaker = labels[idx%2]
	}
	return labelled
}
// segmentsHaveSpeakers reports whether at least one segment has a non-blank
// speaker label.
func segmentsHaveSpeakers(segments []Segment) bool {
	for i := 0; i < len(segments); i++ {
		if label := strings.TrimSpace(segments[i].Speaker); label != "" {
			return true
		}
	}
	return false
}
// segmentTranscriptText splits text into sentence-based segments with
// synthetic timings.
//
// Each segment is given 0.42 per word with a floor of 1.2, and segments are
// laid out back to back starting at 0. When diarize is set and the split
// produced more than one part, parts are labelled alternately "SPEAKER_00" and
// "SPEAKER_01" by their index. If no segment results from a non-blank text, a
// single zero-length segment holding the trimmed text is returned.
func segmentTranscriptText(text string, diarize bool) []Segment {
	sentences := splitTranscriptSentences(text)
	labelSpeakers := diarize && len(sentences) > 1

	result := make([]Segment, 0, len(sentences))
	cursor := 0.0
	for idx, sentence := range sentences {
		wordCount := len(strings.Fields(sentence))
		if wordCount == 0 {
			continue
		}
		span := float64(wordCount) * 0.42
		if span < 1.2 {
			span = 1.2
		}
		seg := Segment{Start: cursor, End: cursor + span, Text: sentence}
		if labelSpeakers {
			seg.Speaker = "SPEAKER_01"
			if idx%2 == 0 {
				seg.Speaker = "SPEAKER_00"
			}
		}
		result = append(result, seg)
		cursor = seg.End
	}

	if trimmed := strings.TrimSpace(text); len(result) == 0 && trimmed != "" {
		result = append(result, Segment{Start: 0, End: 0, Text: trimmed})
	}
	return result
}
// splitTranscriptSentences breaks text into sentence-like parts and then
// merges short neighbours via mergeShortSegments(parts, 8, 34).
//
// Whitespace runs are first collapsed to single spaces. A part ends at '.',
// '!', '?' or '…' when that character is the last one or is followed by a
// space; any trailing remainder becomes a final part. A blank text yields nil.
func splitTranscriptSentences(text string) []string {
	normalized := strings.Join(strings.Fields(text), " ")
	if normalized == "" {
		return nil
	}

	chars := []rune(normalized)
	total := len(chars)
	var sentences []string
	begin := 0
	for pos := 0; pos < total; pos++ {
		switch chars[pos] {
		case '.', '!', '?', '…':
			// Candidate sentence end; checked below.
		default:
			continue
		}
		// Punctuation inside a token (e.g. "3.14" or "?!") is not a boundary.
		if pos+1 < total && chars[pos+1] != ' ' {
			continue
		}
		if sentence := strings.TrimSpace(string(chars[begin : pos+1])); sentence != "" {
			sentences = append(sentences, sentence)
		}
		begin = pos + 1
		for begin < total && chars[begin] == ' ' {
			begin++
		}
	}

	if rest := strings.TrimSpace(string(chars[begin:])); rest != "" {
		sentences = append(sentences, rest)
	}
	return mergeShortSegments(sentences, 8, 34)
}
// mergeShortSegments joins consecutive parts with single spaces until a group
// reaches minWords words.
//
// A group is closed early when adding the next part would push it past
// maxWords. A trailing group shorter than minWords is kept as is. Inputs with
// zero or one part are returned unchanged.
func mergeShortSegments(parts []string, minWords, maxWords int) []string {
	if len(parts) < 2 {
		return parts
	}

	merged := make([]string, 0, len(parts))
	var pending []string
	pendingWords := 0
	for _, sentence := range parts {
		n := len(strings.Fields(sentence))
		// Close the open group first if this part would overflow it.
		if pendingWords > 0 && pendingWords+n > maxWords {
			merged = append(merged, strings.Join(pending, " "))
			pending, pendingWords = nil, 0
		}
		pending = append(pending, sentence)
		pendingWords += n
		if pendingWords >= minWords {
			merged = append(merged, strings.Join(pending, " "))
			pending, pendingWords = nil, 0
		}
	}
	if len(pending) > 0 {
		merged = append(merged, strings.Join(pending, " "))
	}
	return merged
}
// firstNonEmpty returns the first value that is not blank after trimming
// whitespace, or "" when there is none. The value is returned as given, not
// trimmed.
func firstNonEmpty(values ...string) string {
	for i := 0; i < len(values); i++ {
		if candidate := values[i]; strings.TrimSpace(candidate) != "" {
			return candidate
		}
	}
	return ""
}