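// Listing metadata: 474 lines, 13 KiB, Go.

// Package transcription downloads audio over HTTP and posts it to a
// provider's /v1/audio/transcriptions endpoint to obtain a transcript.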
package transcription

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"mime/multipart"
	"net/http"
	"path/filepath"
	"regexp"
	"strings"
	"time"
)

type Client struct {
|
||
provider ProviderConfig
|
||
http *http.Client
|
||
}
|
||
|
||
// ProviderVoxtral is the provider name stored in ProviderConfig.Name and
// reported in Result.Provider and Attempt.Provider.
const ProviderVoxtral = "voxtral-small"

// speakerLabelPattern matches a speaker label inside transcript text: the
// word "speaker", "спикер" or "говорящий" (case-insensitive) followed by
// digits and then ':' or '-'. The label must sit at the start of the text or
// after spaces/newlines. Capture group 1 is the label itself, e.g. "Speaker 1".
var speakerLabelPattern = regexp.MustCompile(`(?i)(?:^|[\n\r ]+)((?:speaker|спикер|говорящий)\s*\d+)\s*[::-]`)

// Options configures NewWithOptions.
//
// VoxtralBaseURL is the provider base URL; surrounding whitespace and
// trailing slashes are removed, and an empty value makes NewWithOptions
// return nil. VoxtralAPIKey, when non-empty, is sent as a Bearer token.
// VoxtralModel defaults to "mistralai/Voxtral-Small-24B-2507" when blank.
// VoxtralTimeout defaults to 10 minutes when zero or negative.
// AudioLLMPrompt is the prompt sent with each request; a blank value is
// replaced by a built-in "transcribe exactly" prompt.
type Options struct {
	VoxtralBaseURL string
	VoxtralAPIKey  string
	VoxtralModel   string
	VoxtralTimeout time.Duration
	AudioLLMPrompt string
}

// ProviderConfig is the resolved configuration of a transcription provider.
//
// Name is reported in results and attempts. BaseURL is prefixed to
// "/v1/audio/transcriptions". APIKey, when non-empty, is sent as a Bearer
// token. Model is sent as the "model" form field. Timeout bounds a provider
// call when positive. Prompt, when non-blank, is sent as the "prompt" field.
type ProviderConfig struct {
	Name    string
	BaseURL string
	APIKey  string
	Model   string
	Timeout time.Duration
	Prompt  string
}

// Input describes one transcription request.
//
// AudioURL is required and is fetched with an HTTP GET. Filename names the
// uploaded file part; its base name is used and "audio.mp3" is substituted
// when it is empty. Language, when non-blank, is sent to the provider and is
// also the fallback for Result.Language. Diarize is passed to segment
// normalization. MinSpeakers and MaxSpeakers are not read by the code in this
// file.
type Input struct {
	AudioURL    string `json:"audio_url"`
	Filename    string `json:"filename,omitempty"`
	Language    string `json:"language,omitempty"`
	Diarize     bool   `json:"diarize"`
	MinSpeakers int    `json:"min_speakers,omitempty"`
	MaxSpeakers int    `json:"max_speakers,omitempty"`
}

// Segment is one piece of transcript text with its time span.
//
// Start and End are presumably seconds (TODO confirm against the provider);
// when segments are synthesized from plain text they are estimates derived
// from word counts. Speaker is set only for text split on speaker labels.
type Segment struct {
	Start   float64 `json:"start"`
	End     float64 `json:"end"`
	Text    string  `json:"text"`
	Speaker string  `json:"speaker,omitempty"`
}

type Result struct {
|
||
Provider string `json:"provider,omitempty"`
|
||
Model string `json:"model,omitempty"`
|
||
Attempts []Attempt `json:"attempts,omitempty"`
|
||
Language string `json:"language"`
|
||
Segments []Segment `json:"segments"`
|
||
DiarizeError *string `json:"diarize_error,omitempty"`
|
||
AlignError *string `json:"align_error,omitempty"`
|
||
DurationMS int64 `json:"duration_ms"`
|
||
}
|
||
|
||
type Attempt struct {
|
||
Provider string `json:"provider"`
|
||
Model string `json:"model,omitempty"`
|
||
Status string `json:"status"`
|
||
Error string `json:"error,omitempty"`
|
||
Text string `json:"text,omitempty"`
|
||
Segments []Segment `json:"segments,omitempty"`
|
||
DurationMS int64 `json:"duration_ms,omitempty"`
|
||
}
|
||
|
||
type audioLLMResponse struct {
|
||
Text string
|
||
Model string
|
||
Language string
|
||
Segments []Segment
|
||
}
|
||
|
||
type audioTranscriptionResponse struct {
|
||
Text string `json:"text"`
|
||
Language string `json:"language,omitempty"`
|
||
Model string `json:"model,omitempty"`
|
||
Segments []audioTranscriptionSegment `json:"segments,omitempty"`
|
||
Error *struct {
|
||
Message string `json:"message"`
|
||
} `json:"error,omitempty"`
|
||
}
|
||
|
||
// audioTranscriptionSegment is one timed segment as it appears in the
// provider's JSON response.
type audioTranscriptionSegment struct {
	Start float64 `json:"start"`
	End   float64 `json:"end"`
	Text  string  `json:"text"`
}

func New(baseURL string, timeout time.Duration, ffmpegPath string, leadSilence time.Duration) *Client {
|
||
return NewWithOptions(Options{
|
||
VoxtralBaseURL: baseURL,
|
||
VoxtralTimeout: timeout,
|
||
})
|
||
}
|
||
|
||
func NewWithOptions(opts Options) *Client {
|
||
audioLLMPrompt := strings.TrimSpace(opts.AudioLLMPrompt)
|
||
if audioLLMPrompt == "" {
|
||
audioLLMPrompt = "Transcribe the audio exactly. Return only the transcript text."
|
||
}
|
||
provider := buildVoxtralProvider(opts, audioLLMPrompt)
|
||
if provider.BaseURL == "" {
|
||
return nil
|
||
}
|
||
return &Client{
|
||
provider: provider,
|
||
http: &http.Client{Timeout: provider.Timeout},
|
||
}
|
||
}
|
||
|
||
func buildVoxtralProvider(opts Options, prompt string) ProviderConfig {
|
||
baseURL := strings.TrimRight(strings.TrimSpace(opts.VoxtralBaseURL), "/")
|
||
if baseURL == "" {
|
||
return ProviderConfig{}
|
||
}
|
||
model := firstNonEmpty(opts.VoxtralModel, "mistralai/Voxtral-Small-24B-2507")
|
||
return ProviderConfig{
|
||
Name: ProviderVoxtral,
|
||
BaseURL: baseURL,
|
||
APIKey: strings.TrimSpace(opts.VoxtralAPIKey),
|
||
Model: model,
|
||
Timeout: defaultDuration(opts.VoxtralTimeout, 10*time.Minute),
|
||
Prompt: prompt,
|
||
}
|
||
}
|
||
|
||
// defaultDuration returns v, or fallback when v is zero or negative.
func defaultDuration(v, fallback time.Duration) time.Duration {
	if v <= 0 {
		return fallback
	}
	return v
}

func (c *Client) Transcribe(ctx context.Context, in Input) (*Result, error) {
|
||
if c == nil || c.provider.BaseURL == "" {
|
||
return nil, fmt.Errorf("voxtral transcription provider not configured")
|
||
}
|
||
if strings.TrimSpace(in.AudioURL) == "" {
|
||
return nil, fmt.Errorf("audio_url is required")
|
||
}
|
||
audio, filename, err := c.downloadAudio(ctx, in)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
result, attempt, err := c.transcribeWithProvider(ctx, c.provider, audio, filename, in)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
result.Attempts = []Attempt{attempt}
|
||
return result, nil
|
||
}
|
||
|
||
func (c *Client) transcribeWithProvider(ctx context.Context, provider ProviderConfig, audio []byte, filename string, in Input) (*Result, Attempt, error) {
|
||
providerCtx := ctx
|
||
cancel := func() {}
|
||
if provider.Timeout > 0 {
|
||
providerCtx, cancel = context.WithTimeout(ctx, provider.Timeout)
|
||
}
|
||
defer cancel()
|
||
attempt := Attempt{
|
||
Provider: provider.Name,
|
||
Model: provider.Model,
|
||
Status: "failed",
|
||
}
|
||
resp, duration, err := c.transcribeOpenAIAudio(providerCtx, provider, audio, filename, in)
|
||
attempt.DurationMS = duration.Milliseconds()
|
||
if err != nil {
|
||
attempt.Error = err.Error()
|
||
return nil, attempt, err
|
||
}
|
||
text := strings.TrimSpace(resp.Text)
|
||
segments := normalizeAudioLLMSegments(resp.Segments, text, in.Diarize)
|
||
attempt.Status = "ok"
|
||
attempt.Model = resp.Model
|
||
attempt.Text = text
|
||
attempt.Segments = segments
|
||
return &Result{
|
||
Provider: provider.Name,
|
||
Model: resp.Model,
|
||
Language: firstNonEmpty(resp.Language, in.Language, "unknown"),
|
||
Segments: segments,
|
||
DurationMS: duration.Milliseconds(),
|
||
}, attempt, nil
|
||
}
|
||
|
||
func (c *Client) downloadAudio(ctx context.Context, in Input) ([]byte, string, error) {
|
||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, in.AudioURL, nil)
|
||
if err != nil {
|
||
return nil, "", fmt.Errorf("audio request: %w", err)
|
||
}
|
||
resp, err := c.http.Do(req)
|
||
if err != nil {
|
||
return nil, "", fmt.Errorf("audio download: %w", err)
|
||
}
|
||
defer resp.Body.Close()
|
||
if resp.StatusCode >= 300 {
|
||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
|
||
return nil, "", fmt.Errorf("audio HTTP %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
|
||
}
|
||
audio, err := io.ReadAll(io.LimitReader(resp.Body, 512<<20))
|
||
if err != nil {
|
||
return nil, "", fmt.Errorf("audio read: %w", err)
|
||
}
|
||
if len(audio) == 0 {
|
||
return nil, "", fmt.Errorf("audio is empty")
|
||
}
|
||
filename := filepath.Base(strings.TrimSpace(in.Filename))
|
||
if filename == "." || filename == "/" || filename == "" {
|
||
filename = "audio.mp3"
|
||
}
|
||
return audio, filename, nil
|
||
}
|
||
|
||
// clampTime returns v with negative values raised to 0.
func clampTime(v float64) float64 {
	if v < 0 {
		return 0
	}
	return v
}

func (c *Client) transcribeOpenAIAudio(ctx context.Context, provider ProviderConfig, audio []byte, filename string, in Input) (*audioLLMResponse, time.Duration, error) {
|
||
return c.doOpenAIAudioTranscription(ctx, provider, audio, filename, in, "json")
|
||
}
|
||
|
||
func (c *Client) doOpenAIAudioTranscription(ctx context.Context, provider ProviderConfig, audio []byte, filename string, in Input, responseFormat string) (*audioLLMResponse, time.Duration, error) {
|
||
body := &bytes.Buffer{}
|
||
mw := multipart.NewWriter(body)
|
||
if err := mw.WriteField("model", provider.Model); err != nil {
|
||
return nil, 0, fmt.Errorf("audio transcription model field: %w", err)
|
||
}
|
||
if err := mw.WriteField("response_format", responseFormat); err != nil {
|
||
return nil, 0, fmt.Errorf("audio transcription response_format field: %w", err)
|
||
}
|
||
if responseFormat == "verbose_json" {
|
||
if err := mw.WriteField("timestamp_granularities[]", "segment"); err != nil {
|
||
return nil, 0, fmt.Errorf("audio transcription timestamp field: %w", err)
|
||
}
|
||
}
|
||
if prompt := strings.TrimSpace(provider.Prompt); prompt != "" {
|
||
if err := mw.WriteField("prompt", prompt); err != nil {
|
||
return nil, 0, fmt.Errorf("audio transcription prompt field: %w", err)
|
||
}
|
||
}
|
||
if lang := strings.TrimSpace(in.Language); lang != "" {
|
||
if err := mw.WriteField("language", lang); err != nil {
|
||
return nil, 0, fmt.Errorf("audio transcription language field: %w", err)
|
||
}
|
||
}
|
||
fw, err := mw.CreateFormFile("file", filename)
|
||
if err != nil {
|
||
return nil, 0, fmt.Errorf("audio transcription create file: %w", err)
|
||
}
|
||
if _, err := fw.Write(audio); err != nil {
|
||
return nil, 0, fmt.Errorf("audio transcription copy audio: %w", err)
|
||
}
|
||
if err := mw.Close(); err != nil {
|
||
return nil, 0, fmt.Errorf("audio transcription close form: %w", err)
|
||
}
|
||
|
||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, provider.BaseURL+"/v1/audio/transcriptions", body)
|
||
if err != nil {
|
||
return nil, 0, fmt.Errorf("audio transcription request: %w", err)
|
||
}
|
||
req.Header.Set("Content-Type", mw.FormDataContentType())
|
||
if provider.APIKey != "" {
|
||
req.Header.Set("Authorization", "Bearer "+provider.APIKey)
|
||
}
|
||
|
||
start := time.Now()
|
||
resp, err := c.http.Do(req)
|
||
duration := time.Since(start)
|
||
if err != nil {
|
||
return nil, duration, fmt.Errorf("audio transcription do: %w", err)
|
||
}
|
||
defer resp.Body.Close()
|
||
raw, err := io.ReadAll(io.LimitReader(resp.Body, 4<<20))
|
||
if err != nil {
|
||
return nil, duration, fmt.Errorf("audio transcription read: %w", err)
|
||
}
|
||
if resp.StatusCode >= 300 {
|
||
return nil, duration, fmt.Errorf("audio transcription HTTP %d: %s", resp.StatusCode, strings.TrimSpace(string(raw)))
|
||
}
|
||
var out audioTranscriptionResponse
|
||
if err := json.Unmarshal(raw, &out); err != nil {
|
||
return nil, duration, fmt.Errorf("audio transcription decode: %w", err)
|
||
}
|
||
if out.Error != nil {
|
||
return nil, duration, fmt.Errorf("audio transcription error: %s", out.Error.Message)
|
||
}
|
||
modelName := firstNonEmpty(out.Model, provider.Model)
|
||
return &audioLLMResponse{
|
||
Text: strings.TrimSpace(out.Text),
|
||
Model: modelName,
|
||
Language: out.Language,
|
||
Segments: convertAudioSegments(out.Segments),
|
||
}, duration, nil
|
||
}
|
||
|
||
func convertAudioSegments(in []audioTranscriptionSegment) []Segment {
|
||
out := make([]Segment, 0, len(in))
|
||
for _, s := range in {
|
||
text := strings.TrimSpace(s.Text)
|
||
if text == "" {
|
||
continue
|
||
}
|
||
end := s.End
|
||
if end < s.Start {
|
||
end = s.Start
|
||
}
|
||
out = append(out, Segment{Start: clampTime(s.Start), End: clampTime(end), Text: text})
|
||
}
|
||
return out
|
||
}
|
||
|
||
func normalizeAudioLLMSegments(segments []Segment, text string, diarize bool) []Segment {
|
||
text = strings.TrimSpace(text)
|
||
if text != "" {
|
||
if labeled := segmentSpeakerLabeledText(text); len(labeled) > 0 {
|
||
return labeled
|
||
}
|
||
}
|
||
if len(segments) <= 1 && text != "" {
|
||
heuristic := segmentTranscriptText(text)
|
||
if len(heuristic) > len(segments) {
|
||
segments = heuristic
|
||
}
|
||
}
|
||
return segments
|
||
}
|
||
|
||
func segmentSpeakerLabeledText(text string) []Segment {
|
||
matches := speakerLabelPattern.FindAllStringSubmatchIndex(text, -1)
|
||
if len(matches) == 0 {
|
||
return nil
|
||
}
|
||
speakerIDs := map[string]string{}
|
||
var out []Segment
|
||
var t float64
|
||
for i, match := range matches {
|
||
label := strings.ToLower(strings.TrimSpace(text[match[2]:match[3]]))
|
||
speaker, ok := speakerIDs[label]
|
||
if !ok {
|
||
speaker = fmt.Sprintf("SPEAKER_%02d", len(speakerIDs))
|
||
speakerIDs[label] = speaker
|
||
}
|
||
start := match[1]
|
||
end := len(text)
|
||
if i+1 < len(matches) {
|
||
end = matches[i+1][0]
|
||
}
|
||
part := strings.TrimSpace(text[start:end])
|
||
part = strings.Trim(part, ":-— ")
|
||
if part == "" {
|
||
continue
|
||
}
|
||
words := len(strings.Fields(part))
|
||
duration := float64(words) * 0.42
|
||
if duration < 1.2 {
|
||
duration = 1.2
|
||
}
|
||
out = append(out, Segment{Start: t, End: t + duration, Text: part, Speaker: speaker})
|
||
t += duration
|
||
}
|
||
return out
|
||
}
|
||
|
||
func segmentTranscriptText(text string) []Segment {
|
||
parts := splitTranscriptSentences(text)
|
||
out := make([]Segment, 0, len(parts))
|
||
var t float64
|
||
for _, part := range parts {
|
||
words := len(strings.Fields(part))
|
||
if words == 0 {
|
||
continue
|
||
}
|
||
duration := float64(words) * 0.42
|
||
if duration < 1.2 {
|
||
duration = 1.2
|
||
}
|
||
segment := Segment{Start: t, End: t + duration, Text: part}
|
||
out = append(out, segment)
|
||
t = segment.End
|
||
}
|
||
if len(out) == 0 && strings.TrimSpace(text) != "" {
|
||
out = append(out, Segment{Start: 0, End: 0, Text: strings.TrimSpace(text)})
|
||
}
|
||
return out
|
||
}
|
||
|
||
func splitTranscriptSentences(text string) []string {
|
||
text = strings.Join(strings.Fields(text), " ")
|
||
if text == "" {
|
||
return nil
|
||
}
|
||
var out []string
|
||
start := 0
|
||
runes := []rune(text)
|
||
for i, r := range runes {
|
||
if r != '.' && r != '!' && r != '?' && r != '…' {
|
||
continue
|
||
}
|
||
next := i + 1
|
||
if next < len(runes) && runes[next] != ' ' {
|
||
continue
|
||
}
|
||
part := strings.TrimSpace(string(runes[start : i+1]))
|
||
if part != "" {
|
||
out = append(out, part)
|
||
}
|
||
start = i + 1
|
||
for start < len(runes) && runes[start] == ' ' {
|
||
start++
|
||
}
|
||
}
|
||
tail := strings.TrimSpace(string(runes[start:]))
|
||
if tail != "" {
|
||
out = append(out, tail)
|
||
}
|
||
return mergeShortSegments(out, 8, 34)
|
||
}
|
||
|
||
// mergeShortSegments greedily joins consecutive parts into larger groups.
//
// Parts are appended to the current group in order. The group is closed
// before adding a part that would push it past maxWords, and again as soon
// as it holds at least minWords. A single part longer than maxWords is kept
// whole. Grouped parts are joined with a single space. Input of length 0 or
// 1 is returned unchanged.
func mergeShortSegments(parts []string, minWords, maxWords int) []string {
	if len(parts) <= 1 {
		return parts
	}
	out := make([]string, 0, len(parts))
	var current []string
	currentWords := 0
	// flush emits the current group, if any, and starts a new one.
	flush := func() {
		if len(current) == 0 {
			return
		}
		out = append(out, strings.Join(current, " "))
		current = nil
		currentWords = 0
	}
	for _, part := range parts {
		words := len(strings.Fields(part))
		if currentWords > 0 && currentWords+words > maxWords {
			flush()
		}
		current = append(current, part)
		currentWords += words
		if currentWords >= minWords {
			flush()
		}
	}
	flush()
	return out
}

// firstNonEmpty returns the first value that is not blank after trimming
// whitespace, or "" when there is none. The value is returned as given, not
// trimmed.
func firstNonEmpty(values ...string) string {
	for _, value := range values {
		if strings.TrimSpace(value) != "" {
			return value
		}
	}
	return ""
}
