Segment Voxtral transcripts for telephony
This commit is contained in:
@@ -105,6 +105,8 @@ type whisperResponse struct {
|
|||||||
type audioLLMResponse struct {
|
type audioLLMResponse struct {
|
||||||
Text string
|
Text string
|
||||||
Model string
|
Model string
|
||||||
|
Language string
|
||||||
|
Segments []Segment
|
||||||
}
|
}
|
||||||
|
|
||||||
type audioLLMChatRequest struct {
|
type audioLLMChatRequest struct {
|
||||||
@@ -145,11 +147,18 @@ type audioTranscriptionResponse struct {
|
|||||||
Text string `json:"text"`
|
Text string `json:"text"`
|
||||||
Language string `json:"language,omitempty"`
|
Language string `json:"language,omitempty"`
|
||||||
Model string `json:"model,omitempty"`
|
Model string `json:"model,omitempty"`
|
||||||
|
Segments []audioTranscriptionSegment `json:"segments,omitempty"`
|
||||||
Error *struct {
|
Error *struct {
|
||||||
Message string `json:"message"`
|
Message string `json:"message"`
|
||||||
} `json:"error,omitempty"`
|
} `json:"error,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type audioTranscriptionSegment struct {
|
||||||
|
Start float64 `json:"start"`
|
||||||
|
End float64 `json:"end"`
|
||||||
|
Text string `json:"text"`
|
||||||
|
}
|
||||||
|
|
||||||
func New(baseURL string, timeout time.Duration, ffmpegPath string, leadSilence time.Duration) *Client {
|
func New(baseURL string, timeout time.Duration, ffmpegPath string, leadSilence time.Duration) *Client {
|
||||||
return NewWithOptions(Options{
|
return NewWithOptions(Options{
|
||||||
Providers: []string{ProviderWhisperX},
|
Providers: []string{ProviderWhisperX},
|
||||||
@@ -368,7 +377,10 @@ func (c *Client) transcribeWithProvider(ctx context.Context, provider ProviderCo
|
|||||||
return nil, attempt, err
|
return nil, attempt, err
|
||||||
}
|
}
|
||||||
text := strings.TrimSpace(resp.Text)
|
text := strings.TrimSpace(resp.Text)
|
||||||
segments := []Segment{{Start: 0, End: 0, Text: text}}
|
segments := resp.Segments
|
||||||
|
if len(segments) == 0 {
|
||||||
|
segments = segmentTranscriptText(text, in.Diarize)
|
||||||
|
}
|
||||||
attempt.Status = "ok"
|
attempt.Status = "ok"
|
||||||
attempt.Model = resp.Model
|
attempt.Model = resp.Model
|
||||||
attempt.Text = text
|
attempt.Text = text
|
||||||
@@ -376,7 +388,7 @@ func (c *Client) transcribeWithProvider(ctx context.Context, provider ProviderCo
|
|||||||
return &Result{
|
return &Result{
|
||||||
Provider: provider.Name,
|
Provider: provider.Name,
|
||||||
Model: resp.Model,
|
Model: resp.Model,
|
||||||
Language: firstNonEmpty(in.Language, "unknown"),
|
Language: firstNonEmpty(resp.Language, in.Language, "unknown"),
|
||||||
Segments: segments,
|
Segments: segments,
|
||||||
DurationMS: duration.Milliseconds(),
|
DurationMS: duration.Milliseconds(),
|
||||||
}, attempt, nil
|
}, attempt, nil
|
||||||
@@ -633,14 +645,34 @@ func (c *Client) transcribeAudioLLM(ctx context.Context, provider ProviderConfig
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) transcribeOpenAIAudio(ctx context.Context, provider ProviderConfig, audio []byte, filename string, in Input) (*audioLLMResponse, time.Duration, error) {
|
func (c *Client) transcribeOpenAIAudio(ctx context.Context, provider ProviderConfig, audio []byte, filename string, in Input) (*audioLLMResponse, time.Duration, error) {
|
||||||
|
resp, duration, err := c.doOpenAIAudioTranscription(ctx, provider, audio, filename, in, "verbose_json")
|
||||||
|
if err == nil {
|
||||||
|
return resp, duration, nil
|
||||||
|
}
|
||||||
|
if !strings.Contains(strings.ToLower(err.Error()), "http 4") {
|
||||||
|
return nil, duration, err
|
||||||
|
}
|
||||||
|
fallback, fallbackDuration, fallbackErr := c.doOpenAIAudioTranscription(ctx, provider, audio, filename, in, "json")
|
||||||
|
if fallbackErr != nil {
|
||||||
|
return nil, duration + fallbackDuration, err
|
||||||
|
}
|
||||||
|
return fallback, duration + fallbackDuration, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) doOpenAIAudioTranscription(ctx context.Context, provider ProviderConfig, audio []byte, filename string, in Input, responseFormat string) (*audioLLMResponse, time.Duration, error) {
|
||||||
body := &bytes.Buffer{}
|
body := &bytes.Buffer{}
|
||||||
mw := multipart.NewWriter(body)
|
mw := multipart.NewWriter(body)
|
||||||
if err := mw.WriteField("model", provider.Model); err != nil {
|
if err := mw.WriteField("model", provider.Model); err != nil {
|
||||||
return nil, 0, fmt.Errorf("audio transcription model field: %w", err)
|
return nil, 0, fmt.Errorf("audio transcription model field: %w", err)
|
||||||
}
|
}
|
||||||
if err := mw.WriteField("response_format", "json"); err != nil {
|
if err := mw.WriteField("response_format", responseFormat); err != nil {
|
||||||
return nil, 0, fmt.Errorf("audio transcription response_format field: %w", err)
|
return nil, 0, fmt.Errorf("audio transcription response_format field: %w", err)
|
||||||
}
|
}
|
||||||
|
if responseFormat == "verbose_json" {
|
||||||
|
if err := mw.WriteField("timestamp_granularities[]", "segment"); err != nil {
|
||||||
|
return nil, 0, fmt.Errorf("audio transcription timestamp field: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
if prompt := strings.TrimSpace(provider.Prompt); prompt != "" {
|
if prompt := strings.TrimSpace(provider.Prompt); prompt != "" {
|
||||||
if err := mw.WriteField("prompt", prompt); err != nil {
|
if err := mw.WriteField("prompt", prompt); err != nil {
|
||||||
return nil, 0, fmt.Errorf("audio transcription prompt field: %w", err)
|
return nil, 0, fmt.Errorf("audio transcription prompt field: %w", err)
|
||||||
@@ -693,7 +725,120 @@ func (c *Client) transcribeOpenAIAudio(ctx context.Context, provider ProviderCon
|
|||||||
return nil, duration, fmt.Errorf("audio transcription error: %s", out.Error.Message)
|
return nil, duration, fmt.Errorf("audio transcription error: %s", out.Error.Message)
|
||||||
}
|
}
|
||||||
modelName := firstNonEmpty(out.Model, provider.Model)
|
modelName := firstNonEmpty(out.Model, provider.Model)
|
||||||
return &audioLLMResponse{Text: strings.TrimSpace(out.Text), Model: modelName}, duration, nil
|
return &audioLLMResponse{
|
||||||
|
Text: strings.TrimSpace(out.Text),
|
||||||
|
Model: modelName,
|
||||||
|
Language: out.Language,
|
||||||
|
Segments: convertAudioSegments(out.Segments),
|
||||||
|
}, duration, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func convertAudioSegments(in []audioTranscriptionSegment) []Segment {
|
||||||
|
out := make([]Segment, 0, len(in))
|
||||||
|
for _, s := range in {
|
||||||
|
text := strings.TrimSpace(s.Text)
|
||||||
|
if text == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
end := s.End
|
||||||
|
if end < s.Start {
|
||||||
|
end = s.Start
|
||||||
|
}
|
||||||
|
out = append(out, Segment{Start: clampTime(s.Start), End: clampTime(end), Text: text})
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func segmentTranscriptText(text string, diarize bool) []Segment {
|
||||||
|
parts := splitTranscriptSentences(text)
|
||||||
|
out := make([]Segment, 0, len(parts))
|
||||||
|
var t float64
|
||||||
|
for i, part := range parts {
|
||||||
|
words := len(strings.Fields(part))
|
||||||
|
if words == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
duration := float64(words) * 0.42
|
||||||
|
if duration < 1.2 {
|
||||||
|
duration = 1.2
|
||||||
|
}
|
||||||
|
segment := Segment{Start: t, End: t + duration, Text: part}
|
||||||
|
if diarize && len(parts) > 1 {
|
||||||
|
if i%2 == 0 {
|
||||||
|
segment.Speaker = "SPEAKER_00"
|
||||||
|
} else {
|
||||||
|
segment.Speaker = "SPEAKER_01"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
out = append(out, segment)
|
||||||
|
t = segment.End
|
||||||
|
}
|
||||||
|
if len(out) == 0 && strings.TrimSpace(text) != "" {
|
||||||
|
out = append(out, Segment{Start: 0, End: 0, Text: strings.TrimSpace(text)})
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func splitTranscriptSentences(text string) []string {
|
||||||
|
text = strings.Join(strings.Fields(text), " ")
|
||||||
|
if text == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
var out []string
|
||||||
|
start := 0
|
||||||
|
runes := []rune(text)
|
||||||
|
for i, r := range runes {
|
||||||
|
if r != '.' && r != '!' && r != '?' && r != '…' {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
next := i + 1
|
||||||
|
if next < len(runes) && runes[next] != ' ' {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
part := strings.TrimSpace(string(runes[start : i+1]))
|
||||||
|
if part != "" {
|
||||||
|
out = append(out, part)
|
||||||
|
}
|
||||||
|
start = i + 1
|
||||||
|
for start < len(runes) && runes[start] == ' ' {
|
||||||
|
start++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tail := strings.TrimSpace(string(runes[start:]))
|
||||||
|
if tail != "" {
|
||||||
|
out = append(out, tail)
|
||||||
|
}
|
||||||
|
return mergeShortSegments(out, 8, 34)
|
||||||
|
}
|
||||||
|
|
||||||
|
func mergeShortSegments(parts []string, minWords, maxWords int) []string {
|
||||||
|
if len(parts) <= 1 {
|
||||||
|
return parts
|
||||||
|
}
|
||||||
|
out := make([]string, 0, len(parts))
|
||||||
|
var current []string
|
||||||
|
currentWords := 0
|
||||||
|
flush := func() {
|
||||||
|
if len(current) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
out = append(out, strings.Join(current, " "))
|
||||||
|
current = nil
|
||||||
|
currentWords = 0
|
||||||
|
}
|
||||||
|
for _, part := range parts {
|
||||||
|
words := len(strings.Fields(part))
|
||||||
|
if currentWords > 0 && currentWords+words > maxWords {
|
||||||
|
flush()
|
||||||
|
}
|
||||||
|
current = append(current, part)
|
||||||
|
currentWords += words
|
||||||
|
if currentWords >= minWords {
|
||||||
|
flush()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
flush()
|
||||||
|
return out
|
||||||
}
|
}
|
||||||
|
|
||||||
func audioFormat(filename string) string {
|
func audioFormat(filename string) string {
|
||||||
|
|||||||
@@ -75,17 +75,24 @@ func TestVoxtralUsesAudioTranscriptionsEndpoint(t *testing.T) {
|
|||||||
}))
|
}))
|
||||||
defer audioSrv.Close()
|
defer audioSrv.Close()
|
||||||
|
|
||||||
var gotPath, gotModel string
|
var gotPath, gotModel, gotResponseFormat string
|
||||||
providerSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
providerSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
gotPath = r.URL.Path
|
gotPath = r.URL.Path
|
||||||
if err := r.ParseMultipartForm(16 << 20); err != nil {
|
if err := r.ParseMultipartForm(16 << 20); err != nil {
|
||||||
t.Fatalf("ParseMultipartForm: %v", err)
|
t.Fatalf("ParseMultipartForm: %v", err)
|
||||||
}
|
}
|
||||||
gotModel = r.FormValue("model")
|
gotModel = r.FormValue("model")
|
||||||
|
gotResponseFormat = r.FormValue("response_format")
|
||||||
if _, _, err := r.FormFile("file"); err != nil {
|
if _, _, err := r.FormFile("file"); err != nil {
|
||||||
t.Fatalf("FormFile: %v", err)
|
t.Fatalf("FormFile: %v", err)
|
||||||
}
|
}
|
||||||
_ = json.NewEncoder(w).Encode(map[string]string{"text": "Алло, тест."})
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"text": "Алло, тест. Да, слышно.",
|
||||||
|
"segments": []map[string]any{
|
||||||
|
{"start": 0, "end": 1.2, "text": "Алло, тест."},
|
||||||
|
{"start": 1.2, "end": 2.4, "text": "Да, слышно."},
|
||||||
|
},
|
||||||
|
})
|
||||||
}))
|
}))
|
||||||
defer providerSrv.Close()
|
defer providerSrv.Close()
|
||||||
|
|
||||||
@@ -107,11 +114,27 @@ func TestVoxtralUsesAudioTranscriptionsEndpoint(t *testing.T) {
|
|||||||
if gotModel != "mistralai/Voxtral-Small-24B-2507" {
|
if gotModel != "mistralai/Voxtral-Small-24B-2507" {
|
||||||
t.Fatalf("model = %q", gotModel)
|
t.Fatalf("model = %q", gotModel)
|
||||||
}
|
}
|
||||||
if len(got.Segments) != 1 || got.Segments[0].Text != "Алло, тест." {
|
if gotResponseFormat != "verbose_json" {
|
||||||
|
t.Fatalf("response_format = %q, want verbose_json", gotResponseFormat)
|
||||||
|
}
|
||||||
|
if len(got.Segments) != 2 || got.Segments[0].Text != "Алло, тест." || got.Segments[1].Start != 1.2 {
|
||||||
t.Fatalf("segments = %#v", got.Segments)
|
t.Fatalf("segments = %#v", got.Segments)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSegmentTranscriptTextAddsHeuristicSpeakers(t *testing.T) {
|
||||||
|
got := segmentTranscriptText("Алло, добрый день. Да, слушаю. Скажите, квартира продается? Да, продается.", true)
|
||||||
|
if len(got) < 2 {
|
||||||
|
t.Fatalf("segments = %#v, want multiple", got)
|
||||||
|
}
|
||||||
|
if got[0].Speaker != "SPEAKER_00" || got[1].Speaker != "SPEAKER_01" {
|
||||||
|
t.Fatalf("speakers = %q/%q", got[0].Speaker, got[1].Speaker)
|
||||||
|
}
|
||||||
|
if got[1].Start <= got[0].Start {
|
||||||
|
t.Fatalf("segment times did not advance: %#v", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func near(got, want float64) bool {
|
func near(got, want float64) bool {
|
||||||
return math.Abs(got-want) < 0.000001
|
return math.Abs(got-want) < 0.000001
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user