diff --git a/internal/transcription/client.go b/internal/transcription/client.go index 39a713c..1ce6ca4 100644 --- a/internal/transcription/client.go +++ b/internal/transcription/client.go @@ -9,6 +9,7 @@ import ( "mime/multipart" "net/http" "path/filepath" + "regexp" "strings" "time" ) @@ -20,6 +21,8 @@ type Client struct { const ProviderVoxtral = "voxtral-small" +var speakerLabelPattern = regexp.MustCompile(`(?i)(?:^|[\n\r ]+)((?:speaker|спикер|говорящий)\s*\d+)\s*[::-]`) + type Options struct { VoxtralBaseURL string VoxtralAPIKey string @@ -230,18 +233,7 @@ func clampTime(v float64) float64 { } func (c *Client) transcribeOpenAIAudio(ctx context.Context, provider ProviderConfig, audio []byte, filename string, in Input) (*audioLLMResponse, time.Duration, error) { - resp, duration, err := c.doOpenAIAudioTranscription(ctx, provider, audio, filename, in, "verbose_json") - if err == nil { - return resp, duration, nil - } - if !strings.Contains(strings.ToLower(err.Error()), "http 4") { - return nil, duration, err - } - fallback, fallbackDuration, fallbackErr := c.doOpenAIAudioTranscription(ctx, provider, audio, filename, in, "json") - if fallbackErr != nil { - return nil, duration + fallbackDuration, err - } - return fallback, duration + fallbackDuration, nil + return c.doOpenAIAudioTranscription(ctx, provider, audio, filename, in, "json") } func (c *Client) doOpenAIAudioTranscription(ctx context.Context, provider ProviderConfig, audio []byte, filename string, in Input, responseFormat string) (*audioLLMResponse, time.Duration, error) { @@ -336,45 +328,61 @@ func convertAudioSegments(in []audioTranscriptionSegment) []Segment { func normalizeAudioLLMSegments(segments []Segment, text string, diarize bool) []Segment { text = strings.TrimSpace(text) + if text != "" { + if labeled := segmentSpeakerLabeledText(text); len(labeled) > 0 { + return labeled + } + } if len(segments) <= 1 && text != "" { - heuristic := segmentTranscriptText(text, diarize) + heuristic := segmentTranscriptText(text) if len(heuristic) > len(segments) { segments = heuristic } } - return ensureHeuristicSpeakers(segments, diarize) + return segments } -func ensureHeuristicSpeakers(segments []Segment, diarize bool) []Segment { - if !diarize || len(segments) < 2 || segmentsHaveSpeakers(segments) { - return segments +func segmentSpeakerLabeledText(text string) []Segment { + matches := speakerLabelPattern.FindAllStringSubmatchIndex(text, -1) + if len(matches) == 0 { + return nil } - out := make([]Segment, len(segments)) - copy(out, segments) - for i := range out { - if i%2 == 0 { - out[i].Speaker = "SPEAKER_00" - } else { - out[i].Speaker = "SPEAKER_01" + speakerIDs := map[string]string{} + var out []Segment + var t float64 + for i, match := range matches { + label := strings.ToLower(strings.TrimSpace(text[match[2]:match[3]])) + speaker, ok := speakerIDs[label] + if !ok { + speaker = fmt.Sprintf("SPEAKER_%02d", len(speakerIDs)) + speakerIDs[label] = speaker } + start := match[1] + end := len(text) + if i+1 < len(matches) { + end = matches[i+1][0] + } + part := strings.TrimSpace(text[start:end]) + part = strings.Trim(part, ":-— ") + if part == "" { + continue + } + words := len(strings.Fields(part)) + duration := float64(words) * 0.42 + if duration < 1.2 { + duration = 1.2 + } + out = append(out, Segment{Start: t, End: t + duration, Text: part, Speaker: speaker}) + t += duration } return out } -func segmentsHaveSpeakers(segments []Segment) bool { - for _, segment := range segments { - if strings.TrimSpace(segment.Speaker) != "" { - return true - } - } - return false -} - -func segmentTranscriptText(text string, diarize bool) []Segment { +func segmentTranscriptText(text string) []Segment { parts := splitTranscriptSentences(text) out := make([]Segment, 0, len(parts)) var t float64 - for i, part := range parts { + for _, part := range parts { words := len(strings.Fields(part)) if words == 0 { continue @@ -384,13 +392,6 @@ func segmentTranscriptText(text string, diarize bool) []Segment { duration = 1.2 } segment := Segment{Start: t, End: t + duration, Text: part} - if diarize && len(parts) > 1 { - if i%2 == 0 { - segment.Speaker = "SPEAKER_00" - } else { - segment.Speaker = "SPEAKER_01" - } - } out = append(out, segment) t = segment.End } diff --git a/internal/transcription/client_test.go b/internal/transcription/client_test.go index 670ded1..fcc6049 100644 --- a/internal/transcription/client_test.go +++ b/internal/transcription/client_test.go @@ -66,21 +66,21 @@ func TestVoxtralUsesAudioTranscriptionsEndpoint(t *testing.T) { if gotModel != "mistralai/Voxtral-Small-24B-2507" { t.Fatalf("model = %q", gotModel) } - if gotResponseFormat != "verbose_json" { - t.Fatalf("response_format = %q, want verbose_json", gotResponseFormat) + if gotResponseFormat != "json" { + t.Fatalf("response_format = %q, want json", gotResponseFormat) } if len(got.Segments) != 2 || got.Segments[0].Text != "Алло, тест." || got.Segments[1].Start != 1.2 { t.Fatalf("segments = %#v", got.Segments) } } -func TestSegmentTranscriptTextAddsHeuristicSpeakers(t *testing.T) { - got := segmentTranscriptText("Алло, добрый день. Да, слушаю. Скажите, квартира продается? Да, продается.", true) +func TestSegmentTranscriptTextDoesNotInventSpeakers(t *testing.T) { + got := segmentTranscriptText("Алло, добрый день. Да, слушаю. Скажите, квартира продается? Да, продается.") if len(got) < 2 { t.Fatalf("segments = %#v, want multiple", got) } - if got[0].Speaker != "SPEAKER_00" || got[1].Speaker != "SPEAKER_01" { - t.Fatalf("speakers = %q/%q", got[0].Speaker, got[1].Speaker) + if got[0].Speaker != "" || got[1].Speaker != "" { + t.Fatalf("speakers = %q/%q, want empty", got[0].Speaker, got[1].Speaker) } if got[1].Start <= got[0].Start { t.Fatalf("segment times did not advance: %#v", got) @@ -93,12 +93,12 @@ func TestNormalizeAudioLLMSegmentsSplitsSingleLongSegment(t *testing.T) { if len(got) < 2 { t.Fatalf("segments = %#v, want heuristic split", got) } - if got[0].Speaker != "SPEAKER_00" || got[1].Speaker != "SPEAKER_01" { - t.Fatalf("speakers = %q/%q", got[0].Speaker, got[1].Speaker) + if got[0].Speaker != "" || got[1].Speaker != "" { + t.Fatalf("speakers = %q/%q, want empty", got[0].Speaker, got[1].Speaker) } } -func TestNormalizeAudioLLMSegmentsKeepsSegmentsAndAddsSpeakers(t *testing.T) { +func TestNormalizeAudioLLMSegmentsKeepsSegmentsWithoutFakeSpeakers(t *testing.T) { got := normalizeAudioLLMSegments([]Segment{ {Start: 0, End: 1, Text: "Алло."}, {Start: 1, End: 2, Text: "Да, слушаю."}, @@ -106,7 +106,21 @@ func TestNormalizeAudioLLMSegmentsKeepsSegmentsAndAddsSpeakers(t *testing.T) { if len(got) != 2 { t.Fatalf("segments = %#v", got) } - if got[0].Speaker != "SPEAKER_00" || got[1].Speaker != "SPEAKER_01" { - t.Fatalf("speakers = %q/%q", got[0].Speaker, got[1].Speaker) + if got[0].Speaker != "" || got[1].Speaker != "" { + t.Fatalf("speakers = %q/%q, want empty", got[0].Speaker, got[1].Speaker) + } +} + +func TestNormalizeAudioLLMSegmentsUsesExplicitSpeakerLabels(t *testing.T) { + text := "Спикер 1: Алло, добрый день. Спикер 2: Да, слушаю. Спикер 1: Скажите, квартира продается?" + got := normalizeAudioLLMSegments(nil, text, true) + if len(got) != 3 { + t.Fatalf("segments = %#v, want 3", got) + } + if got[0].Speaker != "SPEAKER_00" || got[1].Speaker != "SPEAKER_01" || got[2].Speaker != "SPEAKER_00" { + t.Fatalf("speakers = %q/%q/%q", got[0].Speaker, got[1].Speaker, got[2].Speaker) + } + if got[0].Text != "Алло, добрый день." || got[1].Text != "Да, слушаю." { + t.Fatalf("texts = %#v", got) } }