Files
ai-service/internal/transcription/client.go
2026-06-17 16:46:03 +03:00

513 lines
14 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package transcription
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"mime/multipart"
"net/http"
"path/filepath"
"regexp"
"strings"
"time"
)
// Client transcribes audio by posting multipart requests to a single
// provider endpoint (<BaseURL>/v1/audio/transcriptions).
//
// The same http.Client is used both to download the source audio and to call
// the provider. NewWithOptions returns a nil *Client when no base URL is
// configured; Transcribe reports a nil receiver as an error.
type Client struct {
	// provider is the resolved provider configuration built by buildAudioProvider.
	provider ProviderConfig
	// http carries provider.Timeout as its overall request timeout.
	http *http.Client
}
const (
	// ProviderWhisperLargeV3 is the provider name recorded in Result.Provider
	// and Attempt.Provider.
	ProviderWhisperLargeV3 = "whisper-large-v3"
	// defaultWhisperModel is sent as the "model" form field when
	// Options.AudioModel is blank.
	defaultWhisperModel = "openai/whisper-large-v3"
)
// speakerLabelPattern finds inline speaker labels such as "Speaker 1:" or
// "speaker2 -" in transcript text. The Cyrillic alternatives are Russian words
// for "speaker". Matching is case-insensitive. A label must start the text or
// follow spaces/newlines, and must be followed (after optional whitespace) by
// ':' or '-'. Capture group 1 holds the label itself, e.g. "Speaker 1".
var speakerLabelPattern = regexp.MustCompile(`(?i)(?:^|[\n\r ]+)((?:speaker|спикер|говорящий)\s*\d+)\s*[:-]`)
// Options configures NewWithOptions.
type Options struct {
	// AudioBaseURL is the provider root URL. Surrounding whitespace and
	// trailing slashes are stripped; an empty value yields a nil Client.
	AudioBaseURL string
	// AudioAPIKey, when non-blank, is sent as a Bearer token.
	AudioAPIKey string
	// AudioModel overrides the default model ("openai/whisper-large-v3").
	AudioModel string
	// AudioTimeout bounds the client's HTTP requests (download and
	// transcription); values <= 0 mean 10 minutes.
	AudioTimeout time.Duration
	// AudioPrompt is an optional prompt forwarded to the provider (trimmed).
	AudioPrompt string
}
// ProviderConfig is the resolved configuration for one transcription provider.
type ProviderConfig struct {
	// Name is recorded as Result.Provider and Attempt.Provider.
	Name string
	// BaseURL has no trailing slash; requests go to BaseURL + "/v1/audio/transcriptions".
	BaseURL string
	// APIKey is sent as "Authorization: Bearer <key>" when non-empty.
	APIKey string
	// Model is sent as the "model" form field.
	Model string
	// Timeout is the http.Client timeout and also a context deadline around
	// each provider call (covering the verbose_json and json attempts together).
	Timeout time.Duration
	// Prompt is sent as the "prompt" form field when non-blank.
	Prompt string
}
// Input describes one transcription request.
type Input struct {
	// AudioURL is fetched with an HTTP GET; it is required.
	AudioURL string `json:"audio_url"`
	// Filename names the uploaded multipart file (base name only);
	// it defaults to "audio.mp3".
	Filename string `json:"filename,omitempty"`
	// Language, when set, is sent to the provider as the "language" field and
	// used as the fallback for Result.Language.
	Language string `json:"language,omitempty"`
	// Diarize is passed to segment normalization.
	// NOTE(review): normalizeAudioLLMSegments does not currently read it.
	Diarize bool `json:"diarize"`
	// MinSpeakers and MaxSpeakers are not read anywhere in this file.
	MinSpeakers int `json:"min_speakers,omitempty"`
	MaxSpeakers int `json:"max_speakers,omitempty"`
}
// Segment is one span of transcript text.
//
// Start and End come from the provider when it returns segments (presumably
// seconds; confirm against the provider), or are synthetic word-count
// estimates when the text is split locally. Speaker is set only for text split
// on inline speaker labels ("SPEAKER_00", "SPEAKER_01", ...).
type Segment struct {
	Start   float64 `json:"start"`
	End     float64 `json:"end"`
	Text    string  `json:"text"`
	Speaker string  `json:"speaker,omitempty"`
}
// ResultSchemaVersion is the value this client writes to Result.SchemaVersion.
const ResultSchemaVersion = "ai.transcription_result.v1"
// Result is the transcription payload returned by Client.Transcribe.
type Result struct {
	// SchemaVersion is always ResultSchemaVersion.
	SchemaVersion string `json:"schema_version"`
	Provider      string `json:"provider,omitempty"`
	// Model is the model reported by the provider, else the configured one.
	Model string `json:"model,omitempty"`
	// Attempts lists provider calls; Transcribe currently records exactly one.
	Attempts []Attempt `json:"attempts,omitempty"`
	// Language is the provider-reported language, else Input.Language, else "unknown".
	Language string    `json:"language"`
	Segments []Segment `json:"segments"`
	// DiarizeError and AlignError are never set by code in this file.
	DiarizeError *string `json:"diarize_error,omitempty"`
	AlignError   *string `json:"align_error,omitempty"`
	// DurationMS is the time spent in provider HTTP round trips, not the audio length.
	DurationMS int64 `json:"duration_ms"`
}
// Attempt records the outcome of one provider call.
type Attempt struct {
	Provider string `json:"provider"`
	Model    string `json:"model,omitempty"`
	// Status is "ok" on success and "failed" otherwise.
	Status string `json:"status"`
	// Error holds the failure message when Status is "failed".
	Error    string    `json:"error,omitempty"`
	Text     string    `json:"text,omitempty"`
	Segments []Segment `json:"segments,omitempty"`
	// DurationMS is the time spent in provider HTTP round trips for this attempt.
	DurationMS int64 `json:"duration_ms,omitempty"`
}
// audioLLMResponse is the decoded, provider-independent view of a successful
// transcription call: trimmed text, effective model name, reported language,
// and converted segments.
type audioLLMResponse struct {
	Text     string
	Model    string
	Language string
	Segments []Segment
}
// audioTranscriptionResponse is the JSON body returned by
// /v1/audio/transcriptions. Segments are presumably only present for
// response_format=verbose_json. Error is populated when the provider reports a
// failure inside a response body.
type audioTranscriptionResponse struct {
	Text     string                      `json:"text"`
	Language string                      `json:"language,omitempty"`
	Model    string                      `json:"model,omitempty"`
	Segments []audioTranscriptionSegment `json:"segments,omitempty"`
	Error    *struct {
		Message string `json:"message"`
	} `json:"error,omitempty"`
}
// audioTranscriptionSegment is one provider-reported segment; only the
// start, end and text fields are decoded.
type audioTranscriptionSegment struct {
	Start float64 `json:"start"`
	End   float64 `json:"end"`
	Text  string  `json:"text"`
}
// audioTranscriptionStatusError reports a provider response with an HTTP
// status >= 300. It keeps the (trimmed) response body so callers can inspect
// it, e.g. to decide whether to retry with a simpler response format.
type audioTranscriptionStatusError struct {
	status int
	body   string
}

// Error renders the status code followed by the raw response body.
func (e audioTranscriptionStatusError) Error() string {
	prefix := fmt.Sprintf("audio transcription HTTP %d: ", e.status)
	return prefix + e.body
}
// New builds a Client from just a base URL and a timeout; see NewWithOptions.
//
// ffmpegPath and leadSilence are accepted for signature compatibility but are
// not used by this implementation. New returns nil when baseURL is blank.
func New(baseURL string, timeout time.Duration, ffmpegPath string, leadSilence time.Duration) *Client {
	opts := Options{AudioBaseURL: baseURL, AudioTimeout: timeout}
	return NewWithOptions(opts)
}
// NewWithOptions builds a Client from opts.
//
// It returns nil when opts.AudioBaseURL is empty after trimming whitespace and
// trailing slashes. Callers should treat that as "transcription disabled";
// calling Transcribe on the nil Client returns an error rather than panicking.
func NewWithOptions(opts Options) *Client {
	cfg := buildAudioProvider(opts, strings.TrimSpace(opts.AudioPrompt))
	if len(cfg.BaseURL) == 0 {
		return nil
	}
	client := &Client{provider: cfg}
	client.http = &http.Client{Timeout: cfg.Timeout}
	return client
}
// buildAudioProvider resolves opts into a ProviderConfig.
//
// It returns the zero ProviderConfig when the base URL is empty after trimming
// whitespace and trailing slashes. Otherwise the model defaults to
// defaultWhisperModel and the timeout to 10 minutes. prompt is stored as given;
// the caller is expected to have trimmed it.
func buildAudioProvider(opts Options, prompt string) ProviderConfig {
	baseURL := strings.TrimRight(strings.TrimSpace(opts.AudioBaseURL), "/")
	if baseURL == "" {
		return ProviderConfig{}
	}
	return ProviderConfig{
		Name:    ProviderWhisperLargeV3,
		BaseURL: baseURL,
		APIKey:  strings.TrimSpace(opts.AudioAPIKey),
		// Trim like the other string options: firstNonEmpty returns its argument
		// untrimmed, so stray whitespace from config would otherwise be sent to
		// the provider as part of the model id.
		Model:   firstNonEmpty(strings.TrimSpace(opts.AudioModel), defaultWhisperModel),
		Timeout: defaultDuration(opts.AudioTimeout, 10*time.Minute),
		Prompt:  prompt,
	}
}
// defaultDuration returns v when it is strictly positive and fallback otherwise,
// so zero or negative configured durations mean "use the default".
func defaultDuration(v, fallback time.Duration) time.Duration {
	if v > 0 {
		return v
	}
	return fallback
}
// Transcribe downloads the audio at in.AudioURL and transcribes it with the
// configured provider.
//
// It fails when the Client is nil or unconfigured, when in.AudioURL is blank,
// when the download fails, or when the provider call fails. On success the
// Result carries exactly one Attempt describing the provider call.
func (c *Client) Transcribe(ctx context.Context, in Input) (*Result, error) {
	configured := c != nil && c.provider.BaseURL != ""
	if !configured {
		return nil, fmt.Errorf("audio transcription provider not configured")
	}
	if len(strings.TrimSpace(in.AudioURL)) == 0 {
		return nil, fmt.Errorf("audio_url is required")
	}

	payload, name, err := c.downloadAudio(ctx, in)
	if err != nil {
		return nil, err
	}

	res, att, err := c.transcribeWithProvider(ctx, c.provider, payload, name, in)
	if err != nil {
		return nil, err
	}
	res.Attempts = []Attempt{att}
	return res, nil
}
// transcribeWithProvider runs one provider call under provider.Timeout (when
// positive) and converts the response into a Result.
//
// The returned Attempt is always populated. On error it has Status "failed"
// and the error text; on success it has Status "ok", the provider-reported
// model, the trimmed text and the normalized segments. The Result's DurationMS
// is the time spent in provider HTTP round trips.
func (c *Client) transcribeWithProvider(ctx context.Context, provider ProviderConfig, audio []byte, filename string, in Input) (*Result, Attempt, error) {
	callCtx, cancel := ctx, context.CancelFunc(func() {})
	if provider.Timeout > 0 {
		callCtx, cancel = context.WithTimeout(ctx, provider.Timeout)
	}
	defer cancel()

	resp, elapsed, err := c.transcribeOpenAIAudio(callCtx, provider, audio, filename, in)
	attempt := Attempt{
		Provider:   provider.Name,
		Model:      provider.Model,
		Status:     "failed",
		DurationMS: elapsed.Milliseconds(),
	}
	if err != nil {
		attempt.Error = err.Error()
		return nil, attempt, err
	}

	transcript := strings.TrimSpace(resp.Text)
	segs := normalizeAudioLLMSegments(resp.Segments, transcript, in.Diarize)
	attempt.Status, attempt.Model = "ok", resp.Model
	attempt.Text, attempt.Segments = transcript, segs

	result := &Result{
		SchemaVersion: ResultSchemaVersion,
		Provider:      provider.Name,
		Model:         resp.Model,
		Language:      firstNonEmpty(resp.Language, in.Language, "unknown"),
		Segments:      segs,
		DurationMS:    elapsed.Milliseconds(),
	}
	return result, attempt, nil
}
// maxAudioDownloadBytes caps how much source audio downloadAudio buffers in
// memory (512 MiB).
const maxAudioDownloadBytes = 512 << 20

// downloadAudio fetches in.AudioURL with an HTTP GET. It returns the whole body
// together with the filename to use for the multipart upload.
//
// It returns an error for:
//   - transport failures;
//   - HTTP status >= 300 (up to 4 KiB of the body goes into the message);
//   - empty bodies;
//   - bodies larger than maxAudioDownloadBytes.
//
// The filename is the base name of in.Filename. It defaults to "audio.mp3"
// when that is empty or a bare separator.
func (c *Client) downloadAudio(ctx context.Context, in Input) ([]byte, string, error) {
	// Transcribe validates the trimmed URL, so request the trimmed URL too;
	// stray whitespace or newlines would otherwise fail URL parsing.
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, strings.TrimSpace(in.AudioURL), nil)
	if err != nil {
		return nil, "", fmt.Errorf("audio request: %w", err)
	}
	resp, err := c.http.Do(req)
	if err != nil {
		return nil, "", fmt.Errorf("audio download: %w", err)
	}
	defer resp.Body.Close()
	if resp.StatusCode >= 300 {
		// Best effort: the snippet only enriches the error, so a read failure is ignored.
		body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
		return nil, "", fmt.Errorf("audio HTTP %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
	}
	// Fail fast when the server announces an oversized body (ContentLength is -1 when unknown).
	if resp.ContentLength > maxAudioDownloadBytes {
		return nil, "", fmt.Errorf("audio too large: %d bytes exceeds limit of %d", resp.ContentLength, maxAudioDownloadBytes)
	}
	// Read one byte past the limit so truncation is detected instead of
	// silently uploading a cut-off file to the provider.
	audio, err := io.ReadAll(io.LimitReader(resp.Body, maxAudioDownloadBytes+1))
	if err != nil {
		return nil, "", fmt.Errorf("audio read: %w", err)
	}
	if len(audio) > maxAudioDownloadBytes {
		return nil, "", fmt.Errorf("audio too large: exceeds limit of %d bytes", maxAudioDownloadBytes)
	}
	if len(audio) == 0 {
		return nil, "", fmt.Errorf("audio is empty")
	}
	filename := filepath.Base(strings.TrimSpace(in.Filename))
	if filename == "." || filename == "/" || filename == "" {
		filename = "audio.mp3"
	}
	return audio, filename, nil
}
// clampTime replaces negative timestamps with 0 and passes every other value
// (including NaN) through unchanged.
func clampTime(v float64) float64 {
	switch {
	case v < 0:
		return 0
	default:
		return v
	}
}
// transcribeOpenAIAudio first asks for response_format=verbose_json, which
// carries segment timestamps. It retries once with plain "json" only when the
// provider rejects the verbose request in a way isVerboseJSONUnsupported
// recognizes.
//
// The returned duration is the sum of both calls when a retry happens.
func (c *Client) transcribeOpenAIAudio(ctx context.Context, provider ProviderConfig, audio []byte, filename string, in Input) (*audioLLMResponse, time.Duration, error) {
	verbose, verboseTook, verboseErr := c.doOpenAIAudioTranscription(ctx, provider, audio, filename, in, "verbose_json")
	switch {
	case verboseErr == nil:
		return verbose, verboseTook, nil
	case !isVerboseJSONUnsupported(verboseErr):
		return nil, verboseTook, verboseErr
	}

	plain, plainTook, plainErr := c.doOpenAIAudioTranscription(ctx, provider, audio, filename, in, "json")
	return plain, verboseTook + plainTook, plainErr
}
// doOpenAIAudioTranscription performs one multipart POST to
// <BaseURL>/v1/audio/transcriptions with the given response format.
//
// The returned duration covers only the HTTP round trip (c.http.Do), not form
// encoding or body decoding. The outcome depends on the response:
//   - status >= 300: an audioTranscriptionStatusError carrying the trimmed body
//     (up to 4 MiB);
//   - a JSON body with an "error" object: a plain error;
//   - otherwise: the decoded response with the model falling back to
//     provider.Model.
func (c *Client) doOpenAIAudioTranscription(ctx context.Context, provider ProviderConfig, audio []byte, filename string, in Input, responseFormat string) (*audioLLMResponse, time.Duration, error) {
	payload, contentType, err := encodeTranscriptionForm(provider, audio, filename, in, responseFormat)
	if err != nil {
		return nil, 0, err
	}

	endpoint := provider.BaseURL + "/v1/audio/transcriptions"
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, payload)
	if err != nil {
		return nil, 0, fmt.Errorf("audio transcription request: %w", err)
	}
	req.Header.Set("Content-Type", contentType)
	if provider.APIKey != "" {
		req.Header.Set("Authorization", "Bearer "+provider.APIKey)
	}

	began := time.Now()
	resp, err := c.http.Do(req)
	took := time.Since(began)
	if err != nil {
		return nil, took, fmt.Errorf("audio transcription do: %w", err)
	}
	defer resp.Body.Close()

	raw, err := io.ReadAll(io.LimitReader(resp.Body, 4<<20))
	if err != nil {
		return nil, took, fmt.Errorf("audio transcription read: %w", err)
	}
	if resp.StatusCode >= 300 {
		return nil, took, audioTranscriptionStatusError{status: resp.StatusCode, body: strings.TrimSpace(string(raw))}
	}

	var decoded audioTranscriptionResponse
	if err := json.Unmarshal(raw, &decoded); err != nil {
		return nil, took, fmt.Errorf("audio transcription decode: %w", err)
	}
	if decoded.Error != nil {
		return nil, took, fmt.Errorf("audio transcription error: %s", decoded.Error.Message)
	}

	return &audioLLMResponse{
		Text:     strings.TrimSpace(decoded.Text),
		Model:    firstNonEmpty(decoded.Model, provider.Model),
		Language: decoded.Language,
		Segments: convertAudioSegments(decoded.Segments),
	}, took, nil
}

// encodeTranscriptionForm builds the multipart/form-data request body and
// returns it with its Content-Type.
//
// Fields are written in this order: model, response_format, temperature=0,
// timestamp_granularities[]=segment (verbose_json only), prompt (if
// non-blank), language (if non-blank). The audio file part comes last.
func encodeTranscriptionForm(provider ProviderConfig, audio []byte, filename string, in Input, responseFormat string) (*bytes.Buffer, string, error) {
	// label is the word used in the error message; name is the form field name.
	type formField struct{ label, name, value string }

	fields := []formField{
		{"model", "model", provider.Model},
		{"response_format", "response_format", responseFormat},
		{"temperature", "temperature", "0"},
	}
	if responseFormat == "verbose_json" {
		fields = append(fields, formField{"timestamp", "timestamp_granularities[]", "segment"})
	}
	if prompt := strings.TrimSpace(provider.Prompt); prompt != "" {
		fields = append(fields, formField{"prompt", "prompt", prompt})
	}
	if lang := strings.TrimSpace(in.Language); lang != "" {
		fields = append(fields, formField{"language", "language", lang})
	}

	buf := &bytes.Buffer{}
	form := multipart.NewWriter(buf)
	for _, f := range fields {
		if err := form.WriteField(f.name, f.value); err != nil {
			return nil, "", fmt.Errorf("audio transcription %s field: %w", f.label, err)
		}
	}
	part, err := form.CreateFormFile("file", filename)
	if err != nil {
		return nil, "", fmt.Errorf("audio transcription create file: %w", err)
	}
	if _, err := part.Write(audio); err != nil {
		return nil, "", fmt.Errorf("audio transcription copy audio: %w", err)
	}
	if err := form.Close(); err != nil {
		return nil, "", fmt.Errorf("audio transcription close form: %w", err)
	}
	return buf, form.FormDataContentType(), nil
}
// isVerboseJSONUnsupported reports whether err looks like the provider
// rejecting verbose_json output. It requires two things:
//   - an audioTranscriptionStatusError with status 400 or 422;
//   - a body that mentions verbose_json, response_format or
//     timestamp_granularities (case-insensitive).
func isVerboseJSONUnsupported(err error) bool {
	var se audioTranscriptionStatusError
	if !errors.As(err, &se) {
		return false
	}
	switch se.status {
	case http.StatusBadRequest, http.StatusUnprocessableEntity:
	default:
		return false
	}
	lowered := strings.ToLower(se.body)
	for _, hint := range []string{"verbose_json", "response_format", "timestamp_granularities"} {
		if strings.Contains(lowered, hint) {
			return true
		}
	}
	return false
}
// convertAudioSegments maps provider segments to Segments, as follows:
//   - segments whose text is blank after trimming are dropped;
//   - an end earlier than the start is raised to the start;
//   - negative times are clamped to 0.
func convertAudioSegments(in []audioTranscriptionSegment) []Segment {
	converted := make([]Segment, 0, len(in))
	for i := range in {
		raw := in[i]
		txt := strings.TrimSpace(raw.Text)
		if len(txt) == 0 {
			continue
		}
		stop := raw.End
		if raw.Start > stop {
			stop = raw.Start
		}
		converted = append(converted, Segment{
			Text:  txt,
			Start: clampTime(raw.Start),
			End:   clampTime(stop),
		})
	}
	return converted
}
// normalizeAudioLLMSegments chooses the segments to report for a transcript.
//
// It tries each rule in order:
//  1. If the text contains inline speaker labels, the label-based split wins.
//  2. If the provider returned at most one segment, a sentence-based split
//     replaces it when that yields more segments.
//  3. Otherwise the provider segments are returned unchanged.
//
// NOTE(review): diarize is accepted but not consulted.
func normalizeAudioLLMSegments(segments []Segment, text string, diarize bool) []Segment {
	trimmed := strings.TrimSpace(text)
	if trimmed == "" {
		return segments
	}
	if labeled := segmentSpeakerLabeledText(trimmed); len(labeled) > 0 {
		return labeled
	}
	if len(segments) > 1 {
		return segments
	}
	if guessed := segmentTranscriptText(trimmed); len(guessed) > len(segments) {
		return guessed
	}
	return segments
}
// segmentSpeakerLabeledText splits text on inline speaker labels matched by
// speakerLabelPattern ("Speaker 1:", "спикер 2 -", ...). It returns one
// Segment per labeled turn, or nil when the text contains no labels.
//
// Labels are case-folded and have internal whitespace removed, so "Speaker 1"
// and "speaker1" map to the same ID. IDs are assigned as SPEAKER_00,
// SPEAKER_01, ... in order of first appearance. Any text before the first
// label becomes a segment with no speaker rather than being dropped.
//
// Timings are synthetic: 0.42 per word with a 1.2 minimum, laid end to end
// starting at 0.
func segmentSpeakerLabeledText(text string) []Segment {
	matches := speakerLabelPattern.FindAllStringSubmatchIndex(text, -1)
	if len(matches) == 0 {
		return nil
	}
	speakerIDs := map[string]string{}
	var out []Segment
	var t float64

	// appendPart trims one span of text and, if anything remains, appends it
	// with the next synthetic time window.
	appendPart := func(part, speaker string) {
		part = strings.Trim(strings.TrimSpace(part), ":-— ")
		if part == "" {
			return
		}
		duration := float64(len(strings.Fields(part))) * 0.42
		if duration < 1.2 {
			duration = 1.2
		}
		out = append(out, Segment{Start: t, End: t + duration, Text: part, Speaker: speaker})
		t += duration
	}

	// Keep untagged text that precedes the first label instead of losing it.
	appendPart(text[:matches[0][0]], "")

	for i, match := range matches {
		// match[2]:match[3] is capture group 1 (the label without the separator).
		label := strings.ToLower(strings.Join(strings.Fields(text[match[2]:match[3]]), ""))
		speaker, ok := speakerIDs[label]
		if !ok {
			speaker = fmt.Sprintf("SPEAKER_%02d", len(speakerIDs))
			speakerIDs[label] = speaker
		}
		// The turn runs from just after this label's ':'/'-' to the next label.
		end := len(text)
		if i+1 < len(matches) {
			end = matches[i+1][0]
		}
		appendPart(text[match[1]:end], speaker)
	}
	return out
}
// segmentTranscriptText splits unlabeled text into sentence groups and assigns
// synthetic, back-to-back timings: 0.42 per word with a 1.2 minimum. When no
// group has any words but the text is non-blank, it returns a single untimed
// segment holding the trimmed text.
func segmentTranscriptText(text string) []Segment {
	sentences := splitTranscriptSentences(text)
	segments := make([]Segment, 0, len(sentences))
	cursor := 0.0
	for _, sentence := range sentences {
		wordCount := len(strings.Fields(sentence))
		if wordCount == 0 {
			continue
		}
		span := float64(wordCount) * 0.42
		if span < 1.2 {
			span = 1.2
		}
		segments = append(segments, Segment{Start: cursor, End: cursor + span, Text: sentence})
		cursor += span
	}
	if len(segments) == 0 {
		if whole := strings.TrimSpace(text); whole != "" {
			segments = append(segments, Segment{Text: whole})
		}
	}
	return segments
}
// splitTranscriptSentences collapses all whitespace runs to single spaces and
// then cuts the text after '.', '!', '?' or '…'. A cut happens only when the
// terminator is followed by a space or ends the text, so "3.14" and "..." stay
// intact. The sentences are then merged into groups of roughly 8-34 words by
// mergeShortSegments. Blank input yields nil.
func splitTranscriptSentences(text string) []string {
	normalized := strings.Join(strings.Fields(text), " ")
	if normalized == "" {
		return nil
	}

	runes := []rune(normalized)
	var sentences []string
	var current strings.Builder
	emit := func() {
		if s := strings.TrimSpace(current.String()); s != "" {
			sentences = append(sentences, s)
		}
		current.Reset()
	}

	for i, r := range runes {
		current.WriteRune(r)
		terminal := r == '.' || r == '!' || r == '?' || r == '…'
		atBoundary := i+1 == len(runes) || runes[i+1] == ' '
		if terminal && atBoundary {
			emit()
		}
	}
	emit()

	return mergeShortSegments(sentences, 8, 34)
}
// mergeShortSegments joins consecutive parts with single spaces into groups.
//
// A group is closed as soon as it reaches minWords words. A part that would
// push a non-empty group past maxWords starts a new group instead. A trailing
// group may be shorter than minWords. Inputs with fewer than two parts are
// returned as-is.
func mergeShortSegments(parts []string, minWords, maxWords int) []string {
	if len(parts) < 2 {
		return parts
	}

	merged := make([]string, 0, len(parts))
	// The open group is always the contiguous range parts[groupStart:i].
	groupStart, groupWords := 0, 0
	closeGroup := func(end int) {
		if end > groupStart {
			merged = append(merged, strings.Join(parts[groupStart:end], " "))
		}
		groupStart, groupWords = end, 0
	}

	for i, p := range parts {
		n := len(strings.Fields(p))
		if groupWords > 0 && groupWords+n > maxWords {
			closeGroup(i)
		}
		groupWords += n
		if groupWords >= minWords {
			closeGroup(i + 1)
		}
	}
	closeGroup(len(parts))
	return merged
}
// firstNonEmpty returns the first value that is not blank after trimming
// whitespace, unchanged (untrimmed), or "" when every value is blank.
func firstNonEmpty(values ...string) string {
	for i := 0; i < len(values); i++ {
		if candidate := values[i]; len(strings.TrimSpace(candidate)) > 0 {
			return candidate
		}
	}
	return ""
}