181 lines
6.3 KiB
Go
181 lines
6.3 KiB
Go
package transcription
|
|
|
|
import (
|
|
"encoding/json"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
)
|
|
|
|
func TestNewWithOptionsBuildsWhisperProvider(t *testing.T) {
|
|
client := NewWithOptions(Options{
|
|
AudioBaseURL: "http://whisper",
|
|
})
|
|
if client == nil {
|
|
t.Fatal("client is nil")
|
|
}
|
|
if client.provider.Name != ProviderWhisperLargeV3 {
|
|
t.Fatalf("provider = %q, want %q", client.provider.Name, ProviderWhisperLargeV3)
|
|
}
|
|
if client.provider.Model != "openai/whisper-large-v3" {
|
|
t.Fatalf("model = %q", client.provider.Model)
|
|
}
|
|
}
|
|
|
|
func TestWhisperUsesAudioTranscriptionsEndpoint(t *testing.T) {
|
|
audioSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
_, _ = w.Write([]byte("fake audio"))
|
|
}))
|
|
defer audioSrv.Close()
|
|
|
|
var gotPath, gotModel, gotResponseFormat, gotPrompt, gotTemperature, gotTimestampGranularity string
|
|
providerSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
gotPath = r.URL.Path
|
|
if err := r.ParseMultipartForm(16 << 20); err != nil {
|
|
t.Fatalf("ParseMultipartForm: %v", err)
|
|
}
|
|
gotModel = r.FormValue("model")
|
|
gotResponseFormat = r.FormValue("response_format")
|
|
gotPrompt = r.FormValue("prompt")
|
|
gotTemperature = r.FormValue("temperature")
|
|
gotTimestampGranularity = r.FormValue("timestamp_granularities[]")
|
|
if _, _, err := r.FormFile("file"); err != nil {
|
|
t.Fatalf("FormFile: %v", err)
|
|
}
|
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
|
"text": "Алло, тест. Да, слышно.",
|
|
"segments": []map[string]any{
|
|
{"start": 0, "end": 1.2, "text": "Алло, тест."},
|
|
{"start": 1.2, "end": 2.4, "text": "Да, слышно."},
|
|
},
|
|
})
|
|
}))
|
|
defer providerSrv.Close()
|
|
|
|
client := NewWithOptions(Options{
|
|
AudioBaseURL: providerSrv.URL,
|
|
AudioModel: "openai/whisper-large-v3",
|
|
})
|
|
if client == nil {
|
|
t.Fatal("client is nil")
|
|
}
|
|
got, err := client.Transcribe(t.Context(), Input{AudioURL: audioSrv.URL, Filename: "call.mp3"})
|
|
if err != nil {
|
|
t.Fatalf("Transcribe: %v", err)
|
|
}
|
|
if gotPath != "/v1/audio/transcriptions" {
|
|
t.Fatalf("path = %q, want /v1/audio/transcriptions", gotPath)
|
|
}
|
|
if gotModel != "openai/whisper-large-v3" {
|
|
t.Fatalf("model = %q", gotModel)
|
|
}
|
|
if gotResponseFormat != "verbose_json" {
|
|
t.Fatalf("response_format = %q, want verbose_json", gotResponseFormat)
|
|
}
|
|
if gotTemperature != "0" {
|
|
t.Fatalf("temperature = %q, want 0", gotTemperature)
|
|
}
|
|
if gotTimestampGranularity != "segment" {
|
|
t.Fatalf("timestamp_granularities[] = %q, want segment", gotTimestampGranularity)
|
|
}
|
|
if gotPrompt != "" {
|
|
t.Fatalf("prompt = %q, want empty", gotPrompt)
|
|
}
|
|
if len(got.Segments) != 2 || got.Segments[0].Text != "Алло, тест." || got.Segments[1].Start != 1.2 {
|
|
t.Fatalf("segments = %#v", got.Segments)
|
|
}
|
|
}
|
|
|
|
func TestWhisperFallsBackToJSONWhenVerboseJSONUnsupported(t *testing.T) {
|
|
audioSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
_, _ = w.Write([]byte("fake audio"))
|
|
}))
|
|
defer audioSrv.Close()
|
|
|
|
var formats []string
|
|
providerSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if err := r.ParseMultipartForm(16 << 20); err != nil {
|
|
t.Fatalf("ParseMultipartForm: %v", err)
|
|
}
|
|
format := r.FormValue("response_format")
|
|
formats = append(formats, format)
|
|
if format == "verbose_json" {
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
|
"error": map[string]any{"message": "unsupported response_format verbose_json"},
|
|
})
|
|
return
|
|
}
|
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
|
"text": "Алло, fallback работает.",
|
|
})
|
|
}))
|
|
defer providerSrv.Close()
|
|
|
|
client := NewWithOptions(Options{
|
|
AudioBaseURL: providerSrv.URL,
|
|
AudioModel: "openai/whisper-large-v3",
|
|
})
|
|
got, err := client.Transcribe(t.Context(), Input{AudioURL: audioSrv.URL, Filename: "call.mp3"})
|
|
if err != nil {
|
|
t.Fatalf("Transcribe: %v", err)
|
|
}
|
|
if len(formats) != 2 || formats[0] != "verbose_json" || formats[1] != "json" {
|
|
t.Fatalf("formats = %#v, want verbose_json then json", formats)
|
|
}
|
|
if got.Segments[0].Text != "Алло, fallback работает." {
|
|
t.Fatalf("segments = %#v", got.Segments)
|
|
}
|
|
}
|
|
|
|
func TestSegmentTranscriptTextDoesNotInventSpeakers(t *testing.T) {
|
|
got := segmentTranscriptText("Алло, добрый день. Да, слушаю. Скажите, квартира продается? Да, продается.")
|
|
if len(got) < 2 {
|
|
t.Fatalf("segments = %#v, want multiple", got)
|
|
}
|
|
if got[0].Speaker != "" || got[1].Speaker != "" {
|
|
t.Fatalf("speakers = %q/%q, want empty", got[0].Speaker, got[1].Speaker)
|
|
}
|
|
if got[1].Start <= got[0].Start {
|
|
t.Fatalf("segment times did not advance: %#v", got)
|
|
}
|
|
}
|
|
|
|
func TestNormalizeAudioLLMSegmentsSplitsSingleLongSegment(t *testing.T) {
|
|
text := "Алло, добрый день. Да, слушаю. Скажите, квартира продается? Да, продается."
|
|
got := normalizeAudioLLMSegments([]Segment{{Start: 0, End: 12, Text: text}}, text, true)
|
|
if len(got) < 2 {
|
|
t.Fatalf("segments = %#v, want heuristic split", got)
|
|
}
|
|
if got[0].Speaker != "" || got[1].Speaker != "" {
|
|
t.Fatalf("speakers = %q/%q, want empty", got[0].Speaker, got[1].Speaker)
|
|
}
|
|
}
|
|
|
|
func TestNormalizeAudioLLMSegmentsKeepsSegmentsWithoutFakeSpeakers(t *testing.T) {
|
|
got := normalizeAudioLLMSegments([]Segment{
|
|
{Start: 0, End: 1, Text: "Алло."},
|
|
{Start: 1, End: 2, Text: "Да, слушаю."},
|
|
}, "Алло. Да, слушаю.", true)
|
|
if len(got) != 2 {
|
|
t.Fatalf("segments = %#v", got)
|
|
}
|
|
if got[0].Speaker != "" || got[1].Speaker != "" {
|
|
t.Fatalf("speakers = %q/%q, want empty", got[0].Speaker, got[1].Speaker)
|
|
}
|
|
}
|
|
|
|
func TestNormalizeAudioLLMSegmentsUsesExplicitSpeakerLabels(t *testing.T) {
|
|
text := "Спикер 1: Алло, добрый день. Спикер 2: Да, слушаю. Спикер 1: Скажите, квартира продается?"
|
|
got := normalizeAudioLLMSegments(nil, text, true)
|
|
if len(got) != 3 {
|
|
t.Fatalf("segments = %#v, want 3", got)
|
|
}
|
|
if got[0].Speaker != "SPEAKER_00" || got[1].Speaker != "SPEAKER_01" || got[2].Speaker != "SPEAKER_00" {
|
|
t.Fatalf("speakers = %q/%q/%q", got[0].Speaker, got[1].Speaker, got[2].Speaker)
|
|
}
|
|
if got[0].Text != "Алло, добрый день." || got[1].Text != "Да, слушаю." {
|
|
t.Fatalf("texts = %#v", got)
|
|
}
|
|
}
|