diff --git a/internal/transcription/client.go b/internal/transcription/client.go index ebc95b9..d62420c 100644 --- a/internal/transcription/client.go +++ b/internal/transcription/client.go @@ -141,6 +141,15 @@ type audioLLMChatResponse struct { } `json:"error,omitempty"` } +type audioTranscriptionResponse struct { + Text string `json:"text"` + Language string `json:"language,omitempty"` + Model string `json:"model,omitempty"` + Error *struct { + Message string `json:"message"` + } `json:"error,omitempty"` +} + func New(baseURL string, timeout time.Duration, ffmpegPath string, leadSilence time.Duration) *Client { return NewWithOptions(Options{ Providers: []string{ProviderWhisperX}, @@ -224,7 +233,7 @@ func buildProviders(opts Options, prompt string, maxTokens int) []ProviderConfig model := firstNonEmpty(opts.VoxtralModel, "mistralai/Voxtral-Small-24B-2507") out = append(out, ProviderConfig{ Name: ProviderVoxtral, - Kind: "audio_llm", + Kind: "audio_transcription", BaseURL: baseURL, APIKey: strings.TrimSpace(opts.VoxtralAPIKey), Model: model, @@ -549,6 +558,9 @@ func (c *Client) transcribeAudio(ctx context.Context, provider ProviderConfig, a } func (c *Client) transcribeAudioLLM(ctx context.Context, provider ProviderConfig, audio []byte, filename string, in Input) (*audioLLMResponse, time.Duration, error) { + if provider.Kind == "audio_transcription" { + return c.transcribeOpenAIAudio(ctx, provider, audio, filename, in) + } prompt := provider.Prompt if in.Language != "" { prompt += "\nЯзык аудио: " + in.Language + "." @@ -620,6 +632,70 @@ func (c *Client) transcribeAudioLLM(ctx context.Context, provider ProviderConfig }, duration, nil } +func (c *Client) transcribeOpenAIAudio(ctx context.Context, provider ProviderConfig, audio []byte, filename string, in Input) (*audioLLMResponse, time.Duration, error) { + body := &bytes.Buffer{} + mw := multipart.NewWriter(body) + if err := mw.WriteField("model", provider.Model); err != nil { + return nil, 0, fmt.Errorf("audio transcription model field: %w", err) + } + if err := mw.WriteField("response_format", "json"); err != nil { + return nil, 0, fmt.Errorf("audio transcription response_format field: %w", err) + } + if prompt := strings.TrimSpace(provider.Prompt); prompt != "" { + if err := mw.WriteField("prompt", prompt); err != nil { + return nil, 0, fmt.Errorf("audio transcription prompt field: %w", err) + } + } + if lang := strings.TrimSpace(in.Language); lang != "" { + if err := mw.WriteField("language", lang); err != nil { + return nil, 0, fmt.Errorf("audio transcription language field: %w", err) + } + } + fw, err := mw.CreateFormFile("file", filename) + if err != nil { + return nil, 0, fmt.Errorf("audio transcription create file: %w", err) + } + if _, err := fw.Write(audio); err != nil { + return nil, 0, fmt.Errorf("audio transcription copy audio: %w", err) + } + if err := mw.Close(); err != nil { + return nil, 0, fmt.Errorf("audio transcription close form: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, provider.BaseURL+"/v1/audio/transcriptions", body) + if err != nil { + return nil, 0, fmt.Errorf("audio transcription request: %w", err) + } + req.Header.Set("Content-Type", mw.FormDataContentType()) + if provider.APIKey != "" { + req.Header.Set("Authorization", "Bearer "+provider.APIKey) + } + + start := time.Now() + resp, err := c.http.Do(req) + duration := time.Since(start) + if err != nil { + return nil, duration, fmt.Errorf("audio transcription do: %w", err) + } + defer resp.Body.Close() + raw, err := io.ReadAll(io.LimitReader(resp.Body, 4<<20)) + if err != nil { + return nil, duration, fmt.Errorf("audio transcription read: %w", err) + } + if resp.StatusCode >= 300 { + return nil, duration, fmt.Errorf("audio transcription HTTP %d: %s", resp.StatusCode, strings.TrimSpace(string(raw))) + } + var out audioTranscriptionResponse + if err := json.Unmarshal(raw, &out); err != nil { + return nil, duration, fmt.Errorf("audio transcription decode: %w", err) + } + if out.Error != nil { + return nil, duration, fmt.Errorf("audio transcription error: %s", out.Error.Message) + } + modelName := firstNonEmpty(out.Model, provider.Model) + return &audioLLMResponse{Text: strings.TrimSpace(out.Text), Model: modelName}, duration, nil +} + func audioFormat(filename string) string { ext := strings.TrimPrefix(strings.ToLower(filepath.Ext(filename)), ".") switch ext { diff --git a/internal/transcription/client_test.go b/internal/transcription/client_test.go index 8f9fb61..9bb76c2 100644 --- a/internal/transcription/client_test.go +++ b/internal/transcription/client_test.go @@ -1,7 +1,10 @@ package transcription import ( + "encoding/json" "math" + "net/http" + "net/http/httptest" "testing" "time" ) @@ -66,6 +69,49 @@ func TestAudioDataURLUsesVLLMAudioURLFormat(t *testing.T) { } } +func TestVoxtralUsesAudioTranscriptionsEndpoint(t *testing.T) { + audioSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte("fake audio")) + })) + defer audioSrv.Close() + + var gotPath, gotModel string + providerSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotPath = r.URL.Path + if err := r.ParseMultipartForm(16 << 20); err != nil { + t.Fatalf("ParseMultipartForm: %v", err) + } + gotModel = r.FormValue("model") + if _, _, err := r.FormFile("file"); err != nil { + t.Fatalf("FormFile: %v", err) + } + _ = json.NewEncoder(w).Encode(map[string]string{"text": "Алло, тест."}) + })) + defer providerSrv.Close() + + client := NewWithOptions(Options{ + Providers: []string{"voxtral-small"}, + VoxtralBaseURL: providerSrv.URL, + VoxtralModel: "mistralai/Voxtral-Small-24B-2507", + }) + if client == nil { + t.Fatal("client is nil") + } + got, err := client.Transcribe(t.Context(), Input{AudioURL: audioSrv.URL, Filename: "call.mp3"}) + if err != nil { + t.Fatalf("Transcribe: %v", err) + } + if gotPath != "/v1/audio/transcriptions" { + t.Fatalf("path = %q, want /v1/audio/transcriptions", gotPath) + } + if gotModel != "mistralai/Voxtral-Small-24B-2507" { + t.Fatalf("model = %q", gotModel) + } + if len(got.Segments) != 1 || got.Segments[0].Text != "Алло, тест." { + t.Fatalf("segments = %#v", got.Segments) + } +} + func near(got, want float64) bool { return math.Abs(got-want) < 0.000001 }