Segment Voxtral transcripts for telephony
This commit is contained in:
@@ -103,8 +103,10 @@ type whisperResponse struct {
|
||||
}
|
||||
|
||||
type audioLLMResponse struct {
|
||||
Text string
|
||||
Model string
|
||||
Text string
|
||||
Model string
|
||||
Language string
|
||||
Segments []Segment
|
||||
}
|
||||
|
||||
type audioLLMChatRequest struct {
|
||||
@@ -142,14 +144,21 @@ type audioLLMChatResponse struct {
|
||||
}
|
||||
|
||||
type audioTranscriptionResponse struct {
|
||||
Text string `json:"text"`
|
||||
Language string `json:"language,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Text string `json:"text"`
|
||||
Language string `json:"language,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Segments []audioTranscriptionSegment `json:"segments,omitempty"`
|
||||
Error *struct {
|
||||
Message string `json:"message"`
|
||||
} `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
type audioTranscriptionSegment struct {
|
||||
Start float64 `json:"start"`
|
||||
End float64 `json:"end"`
|
||||
Text string `json:"text"`
|
||||
}
|
||||
|
||||
func New(baseURL string, timeout time.Duration, ffmpegPath string, leadSilence time.Duration) *Client {
|
||||
return NewWithOptions(Options{
|
||||
Providers: []string{ProviderWhisperX},
|
||||
@@ -368,7 +377,10 @@ func (c *Client) transcribeWithProvider(ctx context.Context, provider ProviderCo
|
||||
return nil, attempt, err
|
||||
}
|
||||
text := strings.TrimSpace(resp.Text)
|
||||
segments := []Segment{{Start: 0, End: 0, Text: text}}
|
||||
segments := resp.Segments
|
||||
if len(segments) == 0 {
|
||||
segments = segmentTranscriptText(text, in.Diarize)
|
||||
}
|
||||
attempt.Status = "ok"
|
||||
attempt.Model = resp.Model
|
||||
attempt.Text = text
|
||||
@@ -376,7 +388,7 @@ func (c *Client) transcribeWithProvider(ctx context.Context, provider ProviderCo
|
||||
return &Result{
|
||||
Provider: provider.Name,
|
||||
Model: resp.Model,
|
||||
Language: firstNonEmpty(in.Language, "unknown"),
|
||||
Language: firstNonEmpty(resp.Language, in.Language, "unknown"),
|
||||
Segments: segments,
|
||||
DurationMS: duration.Milliseconds(),
|
||||
}, attempt, nil
|
||||
@@ -633,14 +645,34 @@ func (c *Client) transcribeAudioLLM(ctx context.Context, provider ProviderConfig
|
||||
}
|
||||
|
||||
func (c *Client) transcribeOpenAIAudio(ctx context.Context, provider ProviderConfig, audio []byte, filename string, in Input) (*audioLLMResponse, time.Duration, error) {
|
||||
resp, duration, err := c.doOpenAIAudioTranscription(ctx, provider, audio, filename, in, "verbose_json")
|
||||
if err == nil {
|
||||
return resp, duration, nil
|
||||
}
|
||||
if !strings.Contains(strings.ToLower(err.Error()), "http 4") {
|
||||
return nil, duration, err
|
||||
}
|
||||
fallback, fallbackDuration, fallbackErr := c.doOpenAIAudioTranscription(ctx, provider, audio, filename, in, "json")
|
||||
if fallbackErr != nil {
|
||||
return nil, duration + fallbackDuration, err
|
||||
}
|
||||
return fallback, duration + fallbackDuration, nil
|
||||
}
|
||||
|
||||
func (c *Client) doOpenAIAudioTranscription(ctx context.Context, provider ProviderConfig, audio []byte, filename string, in Input, responseFormat string) (*audioLLMResponse, time.Duration, error) {
|
||||
body := &bytes.Buffer{}
|
||||
mw := multipart.NewWriter(body)
|
||||
if err := mw.WriteField("model", provider.Model); err != nil {
|
||||
return nil, 0, fmt.Errorf("audio transcription model field: %w", err)
|
||||
}
|
||||
if err := mw.WriteField("response_format", "json"); err != nil {
|
||||
if err := mw.WriteField("response_format", responseFormat); err != nil {
|
||||
return nil, 0, fmt.Errorf("audio transcription response_format field: %w", err)
|
||||
}
|
||||
if responseFormat == "verbose_json" {
|
||||
if err := mw.WriteField("timestamp_granularities[]", "segment"); err != nil {
|
||||
return nil, 0, fmt.Errorf("audio transcription timestamp field: %w", err)
|
||||
}
|
||||
}
|
||||
if prompt := strings.TrimSpace(provider.Prompt); prompt != "" {
|
||||
if err := mw.WriteField("prompt", prompt); err != nil {
|
||||
return nil, 0, fmt.Errorf("audio transcription prompt field: %w", err)
|
||||
@@ -693,7 +725,120 @@ func (c *Client) transcribeOpenAIAudio(ctx context.Context, provider ProviderCon
|
||||
return nil, duration, fmt.Errorf("audio transcription error: %s", out.Error.Message)
|
||||
}
|
||||
modelName := firstNonEmpty(out.Model, provider.Model)
|
||||
return &audioLLMResponse{Text: strings.TrimSpace(out.Text), Model: modelName}, duration, nil
|
||||
return &audioLLMResponse{
|
||||
Text: strings.TrimSpace(out.Text),
|
||||
Model: modelName,
|
||||
Language: out.Language,
|
||||
Segments: convertAudioSegments(out.Segments),
|
||||
}, duration, nil
|
||||
}
|
||||
|
||||
func convertAudioSegments(in []audioTranscriptionSegment) []Segment {
|
||||
out := make([]Segment, 0, len(in))
|
||||
for _, s := range in {
|
||||
text := strings.TrimSpace(s.Text)
|
||||
if text == "" {
|
||||
continue
|
||||
}
|
||||
end := s.End
|
||||
if end < s.Start {
|
||||
end = s.Start
|
||||
}
|
||||
out = append(out, Segment{Start: clampTime(s.Start), End: clampTime(end), Text: text})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func segmentTranscriptText(text string, diarize bool) []Segment {
|
||||
parts := splitTranscriptSentences(text)
|
||||
out := make([]Segment, 0, len(parts))
|
||||
var t float64
|
||||
for i, part := range parts {
|
||||
words := len(strings.Fields(part))
|
||||
if words == 0 {
|
||||
continue
|
||||
}
|
||||
duration := float64(words) * 0.42
|
||||
if duration < 1.2 {
|
||||
duration = 1.2
|
||||
}
|
||||
segment := Segment{Start: t, End: t + duration, Text: part}
|
||||
if diarize && len(parts) > 1 {
|
||||
if i%2 == 0 {
|
||||
segment.Speaker = "SPEAKER_00"
|
||||
} else {
|
||||
segment.Speaker = "SPEAKER_01"
|
||||
}
|
||||
}
|
||||
out = append(out, segment)
|
||||
t = segment.End
|
||||
}
|
||||
if len(out) == 0 && strings.TrimSpace(text) != "" {
|
||||
out = append(out, Segment{Start: 0, End: 0, Text: strings.TrimSpace(text)})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func splitTranscriptSentences(text string) []string {
|
||||
text = strings.Join(strings.Fields(text), " ")
|
||||
if text == "" {
|
||||
return nil
|
||||
}
|
||||
var out []string
|
||||
start := 0
|
||||
runes := []rune(text)
|
||||
for i, r := range runes {
|
||||
if r != '.' && r != '!' && r != '?' && r != '…' {
|
||||
continue
|
||||
}
|
||||
next := i + 1
|
||||
if next < len(runes) && runes[next] != ' ' {
|
||||
continue
|
||||
}
|
||||
part := strings.TrimSpace(string(runes[start : i+1]))
|
||||
if part != "" {
|
||||
out = append(out, part)
|
||||
}
|
||||
start = i + 1
|
||||
for start < len(runes) && runes[start] == ' ' {
|
||||
start++
|
||||
}
|
||||
}
|
||||
tail := strings.TrimSpace(string(runes[start:]))
|
||||
if tail != "" {
|
||||
out = append(out, tail)
|
||||
}
|
||||
return mergeShortSegments(out, 8, 34)
|
||||
}
|
||||
|
||||
func mergeShortSegments(parts []string, minWords, maxWords int) []string {
|
||||
if len(parts) <= 1 {
|
||||
return parts
|
||||
}
|
||||
out := make([]string, 0, len(parts))
|
||||
var current []string
|
||||
currentWords := 0
|
||||
flush := func() {
|
||||
if len(current) == 0 {
|
||||
return
|
||||
}
|
||||
out = append(out, strings.Join(current, " "))
|
||||
current = nil
|
||||
currentWords = 0
|
||||
}
|
||||
for _, part := range parts {
|
||||
words := len(strings.Fields(part))
|
||||
if currentWords > 0 && currentWords+words > maxWords {
|
||||
flush()
|
||||
}
|
||||
current = append(current, part)
|
||||
currentWords += words
|
||||
if currentWords >= minWords {
|
||||
flush()
|
||||
}
|
||||
}
|
||||
flush()
|
||||
return out
|
||||
}
|
||||
|
||||
func audioFormat(filename string) string {
|
||||
|
||||
@@ -75,17 +75,24 @@ func TestVoxtralUsesAudioTranscriptionsEndpoint(t *testing.T) {
|
||||
}))
|
||||
defer audioSrv.Close()
|
||||
|
||||
var gotPath, gotModel string
|
||||
var gotPath, gotModel, gotResponseFormat string
|
||||
providerSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotPath = r.URL.Path
|
||||
if err := r.ParseMultipartForm(16 << 20); err != nil {
|
||||
t.Fatalf("ParseMultipartForm: %v", err)
|
||||
}
|
||||
gotModel = r.FormValue("model")
|
||||
gotResponseFormat = r.FormValue("response_format")
|
||||
if _, _, err := r.FormFile("file"); err != nil {
|
||||
t.Fatalf("FormFile: %v", err)
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(map[string]string{"text": "Алло, тест."})
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"text": "Алло, тест. Да, слышно.",
|
||||
"segments": []map[string]any{
|
||||
{"start": 0, "end": 1.2, "text": "Алло, тест."},
|
||||
{"start": 1.2, "end": 2.4, "text": "Да, слышно."},
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer providerSrv.Close()
|
||||
|
||||
@@ -107,11 +114,27 @@ func TestVoxtralUsesAudioTranscriptionsEndpoint(t *testing.T) {
|
||||
if gotModel != "mistralai/Voxtral-Small-24B-2507" {
|
||||
t.Fatalf("model = %q", gotModel)
|
||||
}
|
||||
if len(got.Segments) != 1 || got.Segments[0].Text != "Алло, тест." {
|
||||
if gotResponseFormat != "verbose_json" {
|
||||
t.Fatalf("response_format = %q, want verbose_json", gotResponseFormat)
|
||||
}
|
||||
if len(got.Segments) != 2 || got.Segments[0].Text != "Алло, тест." || got.Segments[1].Start != 1.2 {
|
||||
t.Fatalf("segments = %#v", got.Segments)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSegmentTranscriptTextAddsHeuristicSpeakers(t *testing.T) {
|
||||
got := segmentTranscriptText("Алло, добрый день. Да, слушаю. Скажите, квартира продается? Да, продается.", true)
|
||||
if len(got) < 2 {
|
||||
t.Fatalf("segments = %#v, want multiple", got)
|
||||
}
|
||||
if got[0].Speaker != "SPEAKER_00" || got[1].Speaker != "SPEAKER_01" {
|
||||
t.Fatalf("speakers = %q/%q", got[0].Speaker, got[1].Speaker)
|
||||
}
|
||||
if got[1].Start <= got[0].Start {
|
||||
t.Fatalf("segment times did not advance: %#v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func near(got, want float64) bool {
|
||||
return math.Abs(got-want) < 0.000001
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user