473 lines
13 KiB
Go
473 lines
13 KiB
Go
package transcription
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"mime/multipart"
|
|
"net/http"
|
|
"path/filepath"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
type Client struct {
|
|
provider ProviderConfig
|
|
http *http.Client
|
|
}
|
|
|
|
// ProviderVoxtral is the provider name stored in ProviderConfig.Name by
// buildVoxtralProvider.
const ProviderVoxtral = "voxtral-small"
|
|
|
|
// Options carries the settings NewWithOptions uses to build a Client.
type Options struct {
	// VoxtralBaseURL is the provider base URL. Surrounding whitespace and
	// trailing slashes are stripped; when the result is empty no client is built.
	VoxtralBaseURL string
	// VoxtralAPIKey, when non-empty after trimming, is sent as a Bearer token.
	VoxtralAPIKey string
	// VoxtralModel overrides the default model name
	// "mistralai/Voxtral-Small-24B-2507".
	VoxtralModel string
	// VoxtralTimeout is the provider timeout; non-positive values fall back
	// to 10 minutes.
	VoxtralTimeout time.Duration
	// AudioLLMPrompt is the prompt sent with each transcription request; a
	// blank value selects a built-in default prompt.
	AudioLLMPrompt string
}
|
|
|
|
// ProviderConfig is the resolved configuration of one transcription provider.
type ProviderConfig struct {
	// Name identifies the provider in Result.Provider and Attempt.Provider.
	Name string
	// BaseURL is the endpoint prefix; "/v1/audio/transcriptions" is appended
	// to it for requests. An empty BaseURL means "not configured".
	BaseURL string
	// APIKey, when non-empty, is sent as "Authorization: Bearer <key>".
	APIKey string
	// Model is sent as the "model" form field and used as the reported model
	// name when the provider response does not include one.
	Model string
	// Timeout bounds a provider transcription attempt when positive.
	Timeout time.Duration
	// Prompt is sent as the "prompt" form field when non-blank.
	Prompt string
}
|
|
|
|
// Input describes one transcription request.
type Input struct {
	// AudioURL is the location the audio is downloaded from; it is required.
	AudioURL string `json:"audio_url"`
	// Filename is the name attached to the uploaded audio. Only its base name
	// is used, and "audio.mp3" is substituted when it is blank.
	Filename string `json:"filename,omitempty"`
	// Language, when non-blank, is forwarded as the "language" form field and
	// is the fallback for Result.Language.
	Language string `json:"language,omitempty"`
	// Diarize requests heuristic speaker labels on the returned segments.
	Diarize bool `json:"diarize"`
	// MinSpeakers is not read by the code in this file.
	MinSpeakers int `json:"min_speakers,omitempty"`
	// MaxSpeakers is not read by the code in this file.
	MaxSpeakers int `json:"max_speakers,omitempty"`
}
|
|
|
|
// Segment is one piece of transcript text with its time range.
type Segment struct {
	// Start is the segment start time.
	Start float64 `json:"start"`
	// End is the segment end time.
	End float64 `json:"end"`
	// Text is the transcript text of the segment.
	Text string `json:"text"`
	// Speaker is an optional speaker label such as "SPEAKER_00".
	Speaker string `json:"speaker,omitempty"`
}
|
|
|
|
type Result struct {
|
|
Provider string `json:"provider,omitempty"`
|
|
Model string `json:"model,omitempty"`
|
|
Attempts []Attempt `json:"attempts,omitempty"`
|
|
Language string `json:"language"`
|
|
Segments []Segment `json:"segments"`
|
|
DiarizeError *string `json:"diarize_error,omitempty"`
|
|
AlignError *string `json:"align_error,omitempty"`
|
|
DurationMS int64 `json:"duration_ms"`
|
|
}
|
|
|
|
type Attempt struct {
|
|
Provider string `json:"provider"`
|
|
Model string `json:"model,omitempty"`
|
|
Status string `json:"status"`
|
|
Error string `json:"error,omitempty"`
|
|
Text string `json:"text,omitempty"`
|
|
Segments []Segment `json:"segments,omitempty"`
|
|
DurationMS int64 `json:"duration_ms,omitempty"`
|
|
}
|
|
|
|
type audioLLMResponse struct {
|
|
Text string
|
|
Model string
|
|
Language string
|
|
Segments []Segment
|
|
}
|
|
|
|
type audioTranscriptionResponse struct {
|
|
Text string `json:"text"`
|
|
Language string `json:"language,omitempty"`
|
|
Model string `json:"model,omitempty"`
|
|
Segments []audioTranscriptionSegment `json:"segments,omitempty"`
|
|
Error *struct {
|
|
Message string `json:"message"`
|
|
} `json:"error,omitempty"`
|
|
}
|
|
|
|
// audioTranscriptionSegment is one timed segment in a provider response.
type audioTranscriptionSegment struct {
	// Start is the segment start time as reported by the provider.
	Start float64 `json:"start"`
	// End is the segment end time as reported by the provider.
	End float64 `json:"end"`
	// Text is the segment text as reported by the provider.
	Text string `json:"text"`
}
|
|
|
|
func New(baseURL string, timeout time.Duration, ffmpegPath string, leadSilence time.Duration) *Client {
|
|
return NewWithOptions(Options{
|
|
VoxtralBaseURL: baseURL,
|
|
VoxtralTimeout: timeout,
|
|
})
|
|
}
|
|
|
|
func NewWithOptions(opts Options) *Client {
|
|
audioLLMPrompt := strings.TrimSpace(opts.AudioLLMPrompt)
|
|
if audioLLMPrompt == "" {
|
|
audioLLMPrompt = "Transcribe the audio exactly. Return only the transcript text."
|
|
}
|
|
provider := buildVoxtralProvider(opts, audioLLMPrompt)
|
|
if provider.BaseURL == "" {
|
|
return nil
|
|
}
|
|
return &Client{
|
|
provider: provider,
|
|
http: &http.Client{Timeout: provider.Timeout},
|
|
}
|
|
}
|
|
|
|
func buildVoxtralProvider(opts Options, prompt string) ProviderConfig {
|
|
baseURL := strings.TrimRight(strings.TrimSpace(opts.VoxtralBaseURL), "/")
|
|
if baseURL == "" {
|
|
return ProviderConfig{}
|
|
}
|
|
model := firstNonEmpty(opts.VoxtralModel, "mistralai/Voxtral-Small-24B-2507")
|
|
return ProviderConfig{
|
|
Name: ProviderVoxtral,
|
|
BaseURL: baseURL,
|
|
APIKey: strings.TrimSpace(opts.VoxtralAPIKey),
|
|
Model: model,
|
|
Timeout: defaultDuration(opts.VoxtralTimeout, 10*time.Minute),
|
|
Prompt: prompt,
|
|
}
|
|
}
|
|
|
|
// defaultDuration returns v when it is positive and fallback otherwise.
func defaultDuration(v, fallback time.Duration) time.Duration {
	if v <= 0 {
		return fallback
	}
	return v
}
|
|
|
|
func (c *Client) Transcribe(ctx context.Context, in Input) (*Result, error) {
|
|
if c == nil || c.provider.BaseURL == "" {
|
|
return nil, fmt.Errorf("voxtral transcription provider not configured")
|
|
}
|
|
if strings.TrimSpace(in.AudioURL) == "" {
|
|
return nil, fmt.Errorf("audio_url is required")
|
|
}
|
|
audio, filename, err := c.downloadAudio(ctx, in)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
result, attempt, err := c.transcribeWithProvider(ctx, c.provider, audio, filename, in)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
result.Attempts = []Attempt{attempt}
|
|
return result, nil
|
|
}
|
|
|
|
func (c *Client) transcribeWithProvider(ctx context.Context, provider ProviderConfig, audio []byte, filename string, in Input) (*Result, Attempt, error) {
|
|
providerCtx := ctx
|
|
cancel := func() {}
|
|
if provider.Timeout > 0 {
|
|
providerCtx, cancel = context.WithTimeout(ctx, provider.Timeout)
|
|
}
|
|
defer cancel()
|
|
attempt := Attempt{
|
|
Provider: provider.Name,
|
|
Model: provider.Model,
|
|
Status: "failed",
|
|
}
|
|
resp, duration, err := c.transcribeOpenAIAudio(providerCtx, provider, audio, filename, in)
|
|
attempt.DurationMS = duration.Milliseconds()
|
|
if err != nil {
|
|
attempt.Error = err.Error()
|
|
return nil, attempt, err
|
|
}
|
|
text := strings.TrimSpace(resp.Text)
|
|
segments := normalizeAudioLLMSegments(resp.Segments, text, in.Diarize)
|
|
attempt.Status = "ok"
|
|
attempt.Model = resp.Model
|
|
attempt.Text = text
|
|
attempt.Segments = segments
|
|
return &Result{
|
|
Provider: provider.Name,
|
|
Model: resp.Model,
|
|
Language: firstNonEmpty(resp.Language, in.Language, "unknown"),
|
|
Segments: segments,
|
|
DurationMS: duration.Milliseconds(),
|
|
}, attempt, nil
|
|
}
|
|
|
|
func (c *Client) downloadAudio(ctx context.Context, in Input) ([]byte, string, error) {
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, in.AudioURL, nil)
|
|
if err != nil {
|
|
return nil, "", fmt.Errorf("audio request: %w", err)
|
|
}
|
|
resp, err := c.http.Do(req)
|
|
if err != nil {
|
|
return nil, "", fmt.Errorf("audio download: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
if resp.StatusCode >= 300 {
|
|
body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
|
|
return nil, "", fmt.Errorf("audio HTTP %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
|
|
}
|
|
audio, err := io.ReadAll(io.LimitReader(resp.Body, 512<<20))
|
|
if err != nil {
|
|
return nil, "", fmt.Errorf("audio read: %w", err)
|
|
}
|
|
if len(audio) == 0 {
|
|
return nil, "", fmt.Errorf("audio is empty")
|
|
}
|
|
filename := filepath.Base(strings.TrimSpace(in.Filename))
|
|
if filename == "." || filename == "/" || filename == "" {
|
|
filename = "audio.mp3"
|
|
}
|
|
return audio, filename, nil
|
|
}
|
|
|
|
// clampTime returns v with negative values raised to 0.
func clampTime(v float64) float64 {
	if v < 0 {
		return 0
	}
	return v
}
|
|
|
|
func (c *Client) transcribeOpenAIAudio(ctx context.Context, provider ProviderConfig, audio []byte, filename string, in Input) (*audioLLMResponse, time.Duration, error) {
|
|
resp, duration, err := c.doOpenAIAudioTranscription(ctx, provider, audio, filename, in, "verbose_json")
|
|
if err == nil {
|
|
return resp, duration, nil
|
|
}
|
|
if !strings.Contains(strings.ToLower(err.Error()), "http 4") {
|
|
return nil, duration, err
|
|
}
|
|
fallback, fallbackDuration, fallbackErr := c.doOpenAIAudioTranscription(ctx, provider, audio, filename, in, "json")
|
|
if fallbackErr != nil {
|
|
return nil, duration + fallbackDuration, err
|
|
}
|
|
return fallback, duration + fallbackDuration, nil
|
|
}
|
|
|
|
func (c *Client) doOpenAIAudioTranscription(ctx context.Context, provider ProviderConfig, audio []byte, filename string, in Input, responseFormat string) (*audioLLMResponse, time.Duration, error) {
|
|
body := &bytes.Buffer{}
|
|
mw := multipart.NewWriter(body)
|
|
if err := mw.WriteField("model", provider.Model); err != nil {
|
|
return nil, 0, fmt.Errorf("audio transcription model field: %w", err)
|
|
}
|
|
if err := mw.WriteField("response_format", responseFormat); err != nil {
|
|
return nil, 0, fmt.Errorf("audio transcription response_format field: %w", err)
|
|
}
|
|
if responseFormat == "verbose_json" {
|
|
if err := mw.WriteField("timestamp_granularities[]", "segment"); err != nil {
|
|
return nil, 0, fmt.Errorf("audio transcription timestamp field: %w", err)
|
|
}
|
|
}
|
|
if prompt := strings.TrimSpace(provider.Prompt); prompt != "" {
|
|
if err := mw.WriteField("prompt", prompt); err != nil {
|
|
return nil, 0, fmt.Errorf("audio transcription prompt field: %w", err)
|
|
}
|
|
}
|
|
if lang := strings.TrimSpace(in.Language); lang != "" {
|
|
if err := mw.WriteField("language", lang); err != nil {
|
|
return nil, 0, fmt.Errorf("audio transcription language field: %w", err)
|
|
}
|
|
}
|
|
fw, err := mw.CreateFormFile("file", filename)
|
|
if err != nil {
|
|
return nil, 0, fmt.Errorf("audio transcription create file: %w", err)
|
|
}
|
|
if _, err := fw.Write(audio); err != nil {
|
|
return nil, 0, fmt.Errorf("audio transcription copy audio: %w", err)
|
|
}
|
|
if err := mw.Close(); err != nil {
|
|
return nil, 0, fmt.Errorf("audio transcription close form: %w", err)
|
|
}
|
|
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, provider.BaseURL+"/v1/audio/transcriptions", body)
|
|
if err != nil {
|
|
return nil, 0, fmt.Errorf("audio transcription request: %w", err)
|
|
}
|
|
req.Header.Set("Content-Type", mw.FormDataContentType())
|
|
if provider.APIKey != "" {
|
|
req.Header.Set("Authorization", "Bearer "+provider.APIKey)
|
|
}
|
|
|
|
start := time.Now()
|
|
resp, err := c.http.Do(req)
|
|
duration := time.Since(start)
|
|
if err != nil {
|
|
return nil, duration, fmt.Errorf("audio transcription do: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
raw, err := io.ReadAll(io.LimitReader(resp.Body, 4<<20))
|
|
if err != nil {
|
|
return nil, duration, fmt.Errorf("audio transcription read: %w", err)
|
|
}
|
|
if resp.StatusCode >= 300 {
|
|
return nil, duration, fmt.Errorf("audio transcription HTTP %d: %s", resp.StatusCode, strings.TrimSpace(string(raw)))
|
|
}
|
|
var out audioTranscriptionResponse
|
|
if err := json.Unmarshal(raw, &out); err != nil {
|
|
return nil, duration, fmt.Errorf("audio transcription decode: %w", err)
|
|
}
|
|
if out.Error != nil {
|
|
return nil, duration, fmt.Errorf("audio transcription error: %s", out.Error.Message)
|
|
}
|
|
modelName := firstNonEmpty(out.Model, provider.Model)
|
|
return &audioLLMResponse{
|
|
Text: strings.TrimSpace(out.Text),
|
|
Model: modelName,
|
|
Language: out.Language,
|
|
Segments: convertAudioSegments(out.Segments),
|
|
}, duration, nil
|
|
}
|
|
|
|
func convertAudioSegments(in []audioTranscriptionSegment) []Segment {
|
|
out := make([]Segment, 0, len(in))
|
|
for _, s := range in {
|
|
text := strings.TrimSpace(s.Text)
|
|
if text == "" {
|
|
continue
|
|
}
|
|
end := s.End
|
|
if end < s.Start {
|
|
end = s.Start
|
|
}
|
|
out = append(out, Segment{Start: clampTime(s.Start), End: clampTime(end), Text: text})
|
|
}
|
|
return out
|
|
}
|
|
|
|
func normalizeAudioLLMSegments(segments []Segment, text string, diarize bool) []Segment {
|
|
text = strings.TrimSpace(text)
|
|
if len(segments) <= 1 && text != "" {
|
|
heuristic := segmentTranscriptText(text, diarize)
|
|
if len(heuristic) > len(segments) {
|
|
segments = heuristic
|
|
}
|
|
}
|
|
return ensureHeuristicSpeakers(segments, diarize)
|
|
}
|
|
|
|
func ensureHeuristicSpeakers(segments []Segment, diarize bool) []Segment {
|
|
if !diarize || len(segments) < 2 || segmentsHaveSpeakers(segments) {
|
|
return segments
|
|
}
|
|
out := make([]Segment, len(segments))
|
|
copy(out, segments)
|
|
for i := range out {
|
|
if i%2 == 0 {
|
|
out[i].Speaker = "SPEAKER_00"
|
|
} else {
|
|
out[i].Speaker = "SPEAKER_01"
|
|
}
|
|
}
|
|
return out
|
|
}
|
|
|
|
func segmentsHaveSpeakers(segments []Segment) bool {
|
|
for _, segment := range segments {
|
|
if strings.TrimSpace(segment.Speaker) != "" {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
func segmentTranscriptText(text string, diarize bool) []Segment {
|
|
parts := splitTranscriptSentences(text)
|
|
out := make([]Segment, 0, len(parts))
|
|
var t float64
|
|
for i, part := range parts {
|
|
words := len(strings.Fields(part))
|
|
if words == 0 {
|
|
continue
|
|
}
|
|
duration := float64(words) * 0.42
|
|
if duration < 1.2 {
|
|
duration = 1.2
|
|
}
|
|
segment := Segment{Start: t, End: t + duration, Text: part}
|
|
if diarize && len(parts) > 1 {
|
|
if i%2 == 0 {
|
|
segment.Speaker = "SPEAKER_00"
|
|
} else {
|
|
segment.Speaker = "SPEAKER_01"
|
|
}
|
|
}
|
|
out = append(out, segment)
|
|
t = segment.End
|
|
}
|
|
if len(out) == 0 && strings.TrimSpace(text) != "" {
|
|
out = append(out, Segment{Start: 0, End: 0, Text: strings.TrimSpace(text)})
|
|
}
|
|
return out
|
|
}
|
|
|
|
func splitTranscriptSentences(text string) []string {
|
|
text = strings.Join(strings.Fields(text), " ")
|
|
if text == "" {
|
|
return nil
|
|
}
|
|
var out []string
|
|
start := 0
|
|
runes := []rune(text)
|
|
for i, r := range runes {
|
|
if r != '.' && r != '!' && r != '?' && r != '…' {
|
|
continue
|
|
}
|
|
next := i + 1
|
|
if next < len(runes) && runes[next] != ' ' {
|
|
continue
|
|
}
|
|
part := strings.TrimSpace(string(runes[start : i+1]))
|
|
if part != "" {
|
|
out = append(out, part)
|
|
}
|
|
start = i + 1
|
|
for start < len(runes) && runes[start] == ' ' {
|
|
start++
|
|
}
|
|
}
|
|
tail := strings.TrimSpace(string(runes[start:]))
|
|
if tail != "" {
|
|
out = append(out, tail)
|
|
}
|
|
return mergeShortSegments(out, 8, 34)
|
|
}
|
|
|
|
// mergeShortSegments joins consecutive parts with single spaces into groups.
//
// A group is closed as soon as it holds at least minWords words, and before
// adding a part that would push a non-empty group past maxWords. Whatever
// remains at the end becomes the final group. Inputs with at most one part
// are returned unchanged.
func mergeShortSegments(parts []string, minWords, maxWords int) []string {
	if len(parts) <= 1 {
		return parts
	}

	out := make([]string, 0, len(parts))
	var current []string
	currentWords := 0

	// flush emits the pending group, if any, and starts a new one.
	flush := func() {
		if len(current) == 0 {
			return
		}
		out = append(out, strings.Join(current, " "))
		current = nil
		currentWords = 0
	}

	for _, part := range parts {
		words := len(strings.Fields(part))
		if currentWords > 0 && currentWords+words > maxWords {
			flush()
		}
		current = append(current, part)
		currentWords += words
		if currentWords >= minWords {
			flush()
		}
	}
	flush()
	return out
}
|
|
|
|
// firstNonEmpty returns the first value that is non-empty after trimming
// whitespace, with that whitespace removed, or "" when there is none.
//
// The trimmed value is returned so that padded inputs (for example a model
// name or language with stray spaces) do not leak into requests or results.
func firstNonEmpty(values ...string) string {
	for _, value := range values {
		if trimmed := strings.TrimSpace(value); trimmed != "" {
			return trimmed
		}
	}
	return ""
}
|