From 64bf40b3ba104eaa7c66c5a4f947faa2dba5bfb6 Mon Sep 17 00:00:00 2001 From: Grendgi Date: Tue, 9 Jun 2026 16:12:57 +0300 Subject: [PATCH] Segment Voxtral transcripts for telephony --- internal/transcription/client.go | 163 ++++++++++++++++++++++++-- internal/transcription/client_test.go | 29 ++++- 2 files changed, 180 insertions(+), 12 deletions(-) diff --git a/internal/transcription/client.go b/internal/transcription/client.go index d62420c..44917be 100644 --- a/internal/transcription/client.go +++ b/internal/transcription/client.go @@ -103,8 +103,10 @@ type whisperResponse struct { } type audioLLMResponse struct { - Text string - Model string + Text string + Model string + Language string + Segments []Segment } type audioLLMChatRequest struct { @@ -142,14 +144,21 @@ type audioLLMChatResponse struct { } type audioTranscriptionResponse struct { - Text string `json:"text"` - Language string `json:"language,omitempty"` - Model string `json:"model,omitempty"` + Text string `json:"text"` + Language string `json:"language,omitempty"` + Model string `json:"model,omitempty"` + Segments []audioTranscriptionSegment `json:"segments,omitempty"` Error *struct { Message string `json:"message"` } `json:"error,omitempty"` } +type audioTranscriptionSegment struct { + Start float64 `json:"start"` + End float64 `json:"end"` + Text string `json:"text"` +} + func New(baseURL string, timeout time.Duration, ffmpegPath string, leadSilence time.Duration) *Client { return NewWithOptions(Options{ Providers: []string{ProviderWhisperX}, @@ -368,7 +377,10 @@ func (c *Client) transcribeWithProvider(ctx context.Context, provider ProviderCo return nil, attempt, err } text := strings.TrimSpace(resp.Text) - segments := []Segment{{Start: 0, End: 0, Text: text}} + segments := resp.Segments + if len(segments) == 0 { + segments = segmentTranscriptText(text, in.Diarize) + } attempt.Status = "ok" attempt.Model = resp.Model attempt.Text = text @@ -376,7 +388,7 @@ func (c *Client) transcribeWithProvider(ctx context.Context, provider ProviderCo return &Result{ Provider: provider.Name, Model: resp.Model, - Language: firstNonEmpty(in.Language, "unknown"), + Language: firstNonEmpty(resp.Language, in.Language, "unknown"), Segments: segments, DurationMS: duration.Milliseconds(), }, attempt, nil @@ -633,14 +645,34 @@ func (c *Client) transcribeAudioLLM(ctx context.Context, provider ProviderConfig } func (c *Client) transcribeOpenAIAudio(ctx context.Context, provider ProviderConfig, audio []byte, filename string, in Input) (*audioLLMResponse, time.Duration, error) { + resp, duration, err := c.doOpenAIAudioTranscription(ctx, provider, audio, filename, in, "verbose_json") + if err == nil { + return resp, duration, nil + } + if !strings.Contains(strings.ToLower(err.Error()), "http 4") { + return nil, duration, err + } + fallback, fallbackDuration, fallbackErr := c.doOpenAIAudioTranscription(ctx, provider, audio, filename, in, "json") + if fallbackErr != nil { + return nil, duration + fallbackDuration, err + } + return fallback, duration + fallbackDuration, nil +} + +func (c *Client) doOpenAIAudioTranscription(ctx context.Context, provider ProviderConfig, audio []byte, filename string, in Input, responseFormat string) (*audioLLMResponse, time.Duration, error) { body := &bytes.Buffer{} mw := multipart.NewWriter(body) if err := mw.WriteField("model", provider.Model); err != nil { return nil, 0, fmt.Errorf("audio transcription model field: %w", err) } - if err := mw.WriteField("response_format", "json"); err != nil { + if err := mw.WriteField("response_format", responseFormat); err != nil { return nil, 0, fmt.Errorf("audio transcription response_format field: %w", err) } + if responseFormat == "verbose_json" { + if err := mw.WriteField("timestamp_granularities[]", "segment"); err != nil { + return nil, 0, fmt.Errorf("audio transcription timestamp field: %w", err) + } + } if prompt := strings.TrimSpace(provider.Prompt); prompt != "" { if err := mw.WriteField("prompt", prompt); err != nil { return nil, 0, fmt.Errorf("audio transcription prompt field: %w", err) @@ -693,7 +725,120 @@ func (c *Client) transcribeOpenAIAudio(ctx context.Context, provider ProviderCon return nil, duration, fmt.Errorf("audio transcription error: %s", out.Error.Message) } modelName := firstNonEmpty(out.Model, provider.Model) - return &audioLLMResponse{Text: strings.TrimSpace(out.Text), Model: modelName}, duration, nil + return &audioLLMResponse{ + Text: strings.TrimSpace(out.Text), + Model: modelName, + Language: out.Language, + Segments: convertAudioSegments(out.Segments), + }, duration, nil +} + +func convertAudioSegments(in []audioTranscriptionSegment) []Segment { + out := make([]Segment, 0, len(in)) + for _, s := range in { + text := strings.TrimSpace(s.Text) + if text == "" { + continue + } + end := s.End + if end < s.Start { + end = s.Start + } + out = append(out, Segment{Start: clampTime(s.Start), End: clampTime(end), Text: text}) + } + return out +} + +func segmentTranscriptText(text string, diarize bool) []Segment { + parts := splitTranscriptSentences(text) + out := make([]Segment, 0, len(parts)) + var t float64 + for i, part := range parts { + words := len(strings.Fields(part)) + if words == 0 { + continue + } + duration := float64(words) * 0.42 + if duration < 1.2 { + duration = 1.2 + } + segment := Segment{Start: t, End: t + duration, Text: part} + if diarize && len(parts) > 1 { + if i%2 == 0 { + segment.Speaker = "SPEAKER_00" + } else { + segment.Speaker = "SPEAKER_01" + } + } + out = append(out, segment) + t = segment.End + } + if len(out) == 0 && strings.TrimSpace(text) != "" { + out = append(out, Segment{Start: 0, End: 0, Text: strings.TrimSpace(text)}) + } + return out +} + +func splitTranscriptSentences(text string) []string { + text = strings.Join(strings.Fields(text), " ") + if text == "" { + return nil + } + var out []string + start := 0 + runes := []rune(text) + for i, r := range runes { + if r != '.' && r != '!' && r != '?' && r != '…' { + continue + } + next := i + 1 + if next < len(runes) && runes[next] != ' ' { + continue + } + part := strings.TrimSpace(string(runes[start : i+1])) + if part != "" { + out = append(out, part) + } + start = i + 1 + for start < len(runes) && runes[start] == ' ' { + start++ + } + } + tail := strings.TrimSpace(string(runes[start:])) + if tail != "" { + out = append(out, tail) + } + return mergeShortSegments(out, 8, 34) +} + +func mergeShortSegments(parts []string, minWords, maxWords int) []string { + if len(parts) <= 1 { + return parts + } + out := make([]string, 0, len(parts)) + var current []string + currentWords := 0 + flush := func() { + if len(current) == 0 { + return + } + out = append(out, strings.Join(current, " ")) + current = nil + currentWords = 0 + } + for _, part := range parts { + words := len(strings.Fields(part)) + if currentWords > 0 && currentWords+words > maxWords { + flush() + } + current = append(current, part) + currentWords += words + if currentWords >= minWords { + flush() + } + } + flush() + return out } func audioFormat(filename string) string { diff --git a/internal/transcription/client_test.go b/internal/transcription/client_test.go index 9bb76c2..a769bc3 100644 --- a/internal/transcription/client_test.go +++ b/internal/transcription/client_test.go @@ -75,17 +75,24 @@ func TestVoxtralUsesAudioTranscriptionsEndpoint(t *testing.T) { })) defer audioSrv.Close() - var gotPath, gotModel string + var gotPath, gotModel, gotResponseFormat string providerSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { gotPath = r.URL.Path if err := r.ParseMultipartForm(16 << 20); err != nil { t.Fatalf("ParseMultipartForm: %v", err) } gotModel = r.FormValue("model") + gotResponseFormat = r.FormValue("response_format") if _, _, err := r.FormFile("file"); err != nil { t.Fatalf("FormFile: %v", err) } - _ = json.NewEncoder(w).Encode(map[string]string{"text": "Алло, тест."}) + _ = json.NewEncoder(w).Encode(map[string]any{ + "text": "Алло, тест. Да, слышно.", + "segments": []map[string]any{ + {"start": 0, "end": 1.2, "text": "Алло, тест."}, + {"start": 1.2, "end": 2.4, "text": "Да, слышно."}, + }, + }) })) defer providerSrv.Close() @@ -107,11 +114,27 @@ func TestVoxtralUsesAudioTranscriptionsEndpoint(t *testing.T) { if gotModel != "mistralai/Voxtral-Small-24B-2507" { t.Fatalf("model = %q", gotModel) } - if len(got.Segments) != 1 || got.Segments[0].Text != "Алло, тест." { + if gotResponseFormat != "verbose_json" { + t.Fatalf("response_format = %q, want verbose_json", gotResponseFormat) + } + if len(got.Segments) != 2 || got.Segments[0].Text != "Алло, тест." || got.Segments[1].Start != 1.2 { t.Fatalf("segments = %#v", got.Segments) } } +func TestSegmentTranscriptTextAddsHeuristicSpeakers(t *testing.T) { + got := segmentTranscriptText("Алло, добрый день. Да, слушаю. Скажите, квартира продается? Да, продается.", true) + if len(got) < 2 { + t.Fatalf("segments = %#v, want multiple", got) + } + if got[0].Speaker != "SPEAKER_00" || got[1].Speaker != "SPEAKER_01" { + t.Fatalf("speakers = %q/%q", got[0].Speaker, got[1].Speaker) + } + if got[1].Start <= got[0].Start { + t.Fatalf("segment times did not advance: %#v", got) + } +} + func near(got, want float64) bool { return math.Abs(got-want) < 0.000001 }