Use verbose Whisper transcription output
This commit is contained in:
@@ -4,6 +4,7 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"mime/multipart"
|
"mime/multipart"
|
||||||
@@ -103,6 +104,15 @@ type audioTranscriptionSegment struct {
|
|||||||
Text string `json:"text"`
|
Text string `json:"text"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// audioTranscriptionStatusError is an error value that keeps the HTTP status
// code and the body text of a transcription response as separate fields, so
// that code handling the error can examine them individually.
type audioTranscriptionStatusError struct {
	status int    // HTTP status code
	body   string // body text that accompanies the status in the message
}

// Error implements the error interface. The message has the form
// "audio transcription HTTP <status>: <body>".
func (se audioTranscriptionStatusError) Error() string {
	const layout = "audio transcription HTTP %d: %s"
	return fmt.Sprintf(layout, se.status, se.body)
}
|
||||||
|
|
||||||
func New(baseURL string, timeout time.Duration, ffmpegPath string, leadSilence time.Duration) *Client {
|
func New(baseURL string, timeout time.Duration, ffmpegPath string, leadSilence time.Duration) *Client {
|
||||||
return NewWithOptions(Options{
|
return NewWithOptions(Options{
|
||||||
AudioBaseURL: baseURL,
|
AudioBaseURL: baseURL,
|
||||||
@@ -233,7 +243,15 @@ func clampTime(v float64) float64 {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) transcribeOpenAIAudio(ctx context.Context, provider ProviderConfig, audio []byte, filename string, in Input) (*audioLLMResponse, time.Duration, error) {
|
func (c *Client) transcribeOpenAIAudio(ctx context.Context, provider ProviderConfig, audio []byte, filename string, in Input) (*audioLLMResponse, time.Duration, error) {
|
||||||
return c.doOpenAIAudioTranscription(ctx, provider, audio, filename, in, "json")
|
resp, duration, err := c.doOpenAIAudioTranscription(ctx, provider, audio, filename, in, "verbose_json")
|
||||||
|
if err == nil {
|
||||||
|
return resp, duration, nil
|
||||||
|
}
|
||||||
|
if !isVerboseJSONUnsupported(err) {
|
||||||
|
return nil, duration, err
|
||||||
|
}
|
||||||
|
fallbackResp, fallbackDuration, fallbackErr := c.doOpenAIAudioTranscription(ctx, provider, audio, filename, in, "json")
|
||||||
|
return fallbackResp, duration + fallbackDuration, fallbackErr
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) doOpenAIAudioTranscription(ctx context.Context, provider ProviderConfig, audio []byte, filename string, in Input, responseFormat string) (*audioLLMResponse, time.Duration, error) {
|
func (c *Client) doOpenAIAudioTranscription(ctx context.Context, provider ProviderConfig, audio []byte, filename string, in Input, responseFormat string) (*audioLLMResponse, time.Duration, error) {
|
||||||
@@ -295,7 +313,7 @@ func (c *Client) doOpenAIAudioTranscription(ctx context.Context, provider Provid
|
|||||||
return nil, duration, fmt.Errorf("audio transcription read: %w", err)
|
return nil, duration, fmt.Errorf("audio transcription read: %w", err)
|
||||||
}
|
}
|
||||||
if resp.StatusCode >= 300 {
|
if resp.StatusCode >= 300 {
|
||||||
return nil, duration, fmt.Errorf("audio transcription HTTP %d: %s", resp.StatusCode, strings.TrimSpace(string(raw)))
|
return nil, duration, audioTranscriptionStatusError{status: resp.StatusCode, body: strings.TrimSpace(string(raw))}
|
||||||
}
|
}
|
||||||
var out audioTranscriptionResponse
|
var out audioTranscriptionResponse
|
||||||
if err := json.Unmarshal(raw, &out); err != nil {
|
if err := json.Unmarshal(raw, &out); err != nil {
|
||||||
@@ -313,6 +331,20 @@ func (c *Client) doOpenAIAudioTranscription(ctx context.Context, provider Provid
|
|||||||
}, duration, nil
|
}, duration, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isVerboseJSONUnsupported(err error) bool {
|
||||||
|
var statusErr audioTranscriptionStatusError
|
||||||
|
if !errors.As(err, &statusErr) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if statusErr.status != http.StatusBadRequest && statusErr.status != http.StatusUnprocessableEntity {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
body := strings.ToLower(statusErr.body)
|
||||||
|
return strings.Contains(body, "verbose_json") ||
|
||||||
|
strings.Contains(body, "response_format") ||
|
||||||
|
strings.Contains(body, "timestamp_granularities")
|
||||||
|
}
|
||||||
|
|
||||||
func convertAudioSegments(in []audioTranscriptionSegment) []Segment {
|
func convertAudioSegments(in []audioTranscriptionSegment) []Segment {
|
||||||
out := make([]Segment, 0, len(in))
|
out := make([]Segment, 0, len(in))
|
||||||
for _, s := range in {
|
for _, s := range in {
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ func TestWhisperUsesAudioTranscriptionsEndpoint(t *testing.T) {
|
|||||||
}))
|
}))
|
||||||
defer audioSrv.Close()
|
defer audioSrv.Close()
|
||||||
|
|
||||||
var gotPath, gotModel, gotResponseFormat, gotPrompt, gotTemperature string
|
var gotPath, gotModel, gotResponseFormat, gotPrompt, gotTemperature, gotTimestampGranularity string
|
||||||
providerSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
providerSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
gotPath = r.URL.Path
|
gotPath = r.URL.Path
|
||||||
if err := r.ParseMultipartForm(16 << 20); err != nil {
|
if err := r.ParseMultipartForm(16 << 20); err != nil {
|
||||||
@@ -38,6 +38,7 @@ func TestWhisperUsesAudioTranscriptionsEndpoint(t *testing.T) {
|
|||||||
gotResponseFormat = r.FormValue("response_format")
|
gotResponseFormat = r.FormValue("response_format")
|
||||||
gotPrompt = r.FormValue("prompt")
|
gotPrompt = r.FormValue("prompt")
|
||||||
gotTemperature = r.FormValue("temperature")
|
gotTemperature = r.FormValue("temperature")
|
||||||
|
gotTimestampGranularity = r.FormValue("timestamp_granularities[]")
|
||||||
if _, _, err := r.FormFile("file"); err != nil {
|
if _, _, err := r.FormFile("file"); err != nil {
|
||||||
t.Fatalf("FormFile: %v", err)
|
t.Fatalf("FormFile: %v", err)
|
||||||
}
|
}
|
||||||
@@ -68,12 +69,15 @@ func TestWhisperUsesAudioTranscriptionsEndpoint(t *testing.T) {
|
|||||||
if gotModel != "openai/whisper-large-v3" {
|
if gotModel != "openai/whisper-large-v3" {
|
||||||
t.Fatalf("model = %q", gotModel)
|
t.Fatalf("model = %q", gotModel)
|
||||||
}
|
}
|
||||||
if gotResponseFormat != "json" {
|
if gotResponseFormat != "verbose_json" {
|
||||||
t.Fatalf("response_format = %q, want json", gotResponseFormat)
|
t.Fatalf("response_format = %q, want verbose_json", gotResponseFormat)
|
||||||
}
|
}
|
||||||
if gotTemperature != "0" {
|
if gotTemperature != "0" {
|
||||||
t.Fatalf("temperature = %q, want 0", gotTemperature)
|
t.Fatalf("temperature = %q, want 0", gotTemperature)
|
||||||
}
|
}
|
||||||
|
if gotTimestampGranularity != "segment" {
|
||||||
|
t.Fatalf("timestamp_granularities[] = %q, want segment", gotTimestampGranularity)
|
||||||
|
}
|
||||||
if gotPrompt != "" {
|
if gotPrompt != "" {
|
||||||
t.Fatalf("prompt = %q, want empty", gotPrompt)
|
t.Fatalf("prompt = %q, want empty", gotPrompt)
|
||||||
}
|
}
|
||||||
@@ -82,6 +86,48 @@ func TestWhisperUsesAudioTranscriptionsEndpoint(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestWhisperFallsBackToJSONWhenVerboseJSONUnsupported(t *testing.T) {
|
||||||
|
audioSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
_, _ = w.Write([]byte("fake audio"))
|
||||||
|
}))
|
||||||
|
defer audioSrv.Close()
|
||||||
|
|
||||||
|
var formats []string
|
||||||
|
providerSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if err := r.ParseMultipartForm(16 << 20); err != nil {
|
||||||
|
t.Fatalf("ParseMultipartForm: %v", err)
|
||||||
|
}
|
||||||
|
format := r.FormValue("response_format")
|
||||||
|
formats = append(formats, format)
|
||||||
|
if format == "verbose_json" {
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"error": map[string]any{"message": "unsupported response_format verbose_json"},
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"text": "Алло, fallback работает.",
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer providerSrv.Close()
|
||||||
|
|
||||||
|
client := NewWithOptions(Options{
|
||||||
|
AudioBaseURL: providerSrv.URL,
|
||||||
|
AudioModel: "openai/whisper-large-v3",
|
||||||
|
})
|
||||||
|
got, err := client.Transcribe(t.Context(), Input{AudioURL: audioSrv.URL, Filename: "call.mp3"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Transcribe: %v", err)
|
||||||
|
}
|
||||||
|
if len(formats) != 2 || formats[0] != "verbose_json" || formats[1] != "json" {
|
||||||
|
t.Fatalf("formats = %#v, want verbose_json then json", formats)
|
||||||
|
}
|
||||||
|
if got.Segments[0].Text != "Алло, fallback работает." {
|
||||||
|
t.Fatalf("segments = %#v", got.Segments)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestSegmentTranscriptTextDoesNotInventSpeakers(t *testing.T) {
|
func TestSegmentTranscriptTextDoesNotInventSpeakers(t *testing.T) {
|
||||||
got := segmentTranscriptText("Алло, добрый день. Да, слушаю. Скажите, квартира продается? Да, продается.")
|
got := segmentTranscriptText("Алло, добрый день. Да, слушаю. Скажите, квартира продается? Да, продается.")
|
||||||
if len(got) < 2 {
|
if len(got) < 2 {
|
||||||
|
|||||||
Reference in New Issue
Block a user