276 lines
7.4 KiB
Go
276 lines
7.4 KiB
Go
package transcription
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"mime/multipart"
|
|
"net/http"
|
|
"os"
|
|
"os/exec"
|
|
"path/filepath"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
// Client talks to a WhisperX transcription HTTP service. It downloads the
// source audio itself, optionally prepends silence with ffmpeg, and uploads
// the result to the service as multipart form data.
//
// Build one with New; the fields are unexported.
type Client struct {
	baseURL     string        // service root; New strips surrounding spaces and trailing slashes
	http        *http.Client  // shared by the audio download and the transcription request
	ffmpegPath  string        // ffmpeg executable used for lead-silence padding
	leadSilence time.Duration // silence prepended before upload; 0 disables padding
}
|
|
|
|
// Input describes one transcription request.
type Input struct {
	// AudioURL is fetched with an HTTP GET; it is required.
	AudioURL string `json:"audio_url"`
	// Filename names the upload; only its base name is used, and
	// "audio.mp3" is substituted when it is empty.
	Filename string `json:"filename,omitempty"`
	// Language is forwarded as the "language" form field when non-empty.
	Language string `json:"language,omitempty"`
	// Diarize asks the service to label speakers.
	Diarize bool `json:"diarize"`
	// MinSpeakers and MaxSpeakers bound the speaker count. Each is sent only
	// when Diarize is set and the value is positive.
	MinSpeakers int `json:"min_speakers,omitempty"`
	MaxSpeakers int `json:"max_speakers,omitempty"`
}
|
|
|
|
// Segment is one transcribed span of audio. Start and End are offsets in
// seconds.
type Segment struct {
	Start   float64 `json:"start"`
	End     float64 `json:"end"`
	Text    string  `json:"text"`
	Speaker string  `json:"speaker,omitempty"` // speaker label; omitted from JSON when empty
}
|
|
|
|
type Result struct {
|
|
Language string `json:"language"`
|
|
Segments []Segment `json:"segments"`
|
|
DiarizeError *string `json:"diarize_error,omitempty"`
|
|
AlignError *string `json:"align_error,omitempty"`
|
|
DurationMS int64 `json:"duration_ms"`
|
|
}
|
|
|
|
type whisperResponse struct {
|
|
Language string `json:"language"`
|
|
Segments []Segment `json:"segments"`
|
|
DiarizeError *string `json:"diarize_error,omitempty"`
|
|
AlignError *string `json:"align_error,omitempty"`
|
|
}
|
|
|
|
func New(baseURL string, timeout time.Duration, ffmpegPath string, leadSilence time.Duration) *Client {
|
|
baseURL = strings.TrimRight(strings.TrimSpace(baseURL), "/")
|
|
if baseURL == "" {
|
|
return nil
|
|
}
|
|
if timeout <= 0 {
|
|
timeout = 10 * time.Minute
|
|
}
|
|
if leadSilence < 0 {
|
|
leadSilence = 0
|
|
}
|
|
if leadSilence > 5*time.Second {
|
|
leadSilence = 5 * time.Second
|
|
}
|
|
ffmpegPath = strings.TrimSpace(ffmpegPath)
|
|
if ffmpegPath == "" {
|
|
ffmpegPath = "ffmpeg"
|
|
}
|
|
return &Client{
|
|
baseURL: baseURL,
|
|
http: &http.Client{Timeout: timeout},
|
|
ffmpegPath: ffmpegPath,
|
|
leadSilence: leadSilence,
|
|
}
|
|
}
|
|
|
|
func (c *Client) Transcribe(ctx context.Context, in Input) (*Result, error) {
|
|
if c == nil || c.baseURL == "" {
|
|
return nil, fmt.Errorf("whisperx not configured")
|
|
}
|
|
if strings.TrimSpace(in.AudioURL) == "" {
|
|
return nil, fmt.Errorf("audio_url is required")
|
|
}
|
|
audio, filename, err := c.downloadAudio(ctx, in)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if c.leadSilence > 0 {
|
|
audio, filename, err = c.addLeadSilence(ctx, audio, filename)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
resp, duration, err := c.transcribeAudio(ctx, audio, filename, in)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
segments := adjustLeadSilence(resp.Segments, c.leadSilence)
|
|
return &Result{
|
|
Language: resp.Language,
|
|
Segments: segments,
|
|
DiarizeError: resp.DiarizeError,
|
|
AlignError: resp.AlignError,
|
|
DurationMS: duration.Milliseconds(),
|
|
}, nil
|
|
}
|
|
|
|
func (c *Client) downloadAudio(ctx context.Context, in Input) ([]byte, string, error) {
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, in.AudioURL, nil)
|
|
if err != nil {
|
|
return nil, "", fmt.Errorf("audio request: %w", err)
|
|
}
|
|
resp, err := c.http.Do(req)
|
|
if err != nil {
|
|
return nil, "", fmt.Errorf("audio download: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
if resp.StatusCode >= 300 {
|
|
body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
|
|
return nil, "", fmt.Errorf("audio HTTP %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
|
|
}
|
|
audio, err := io.ReadAll(io.LimitReader(resp.Body, 512<<20))
|
|
if err != nil {
|
|
return nil, "", fmt.Errorf("audio read: %w", err)
|
|
}
|
|
if len(audio) == 0 {
|
|
return nil, "", fmt.Errorf("audio is empty")
|
|
}
|
|
filename := filepath.Base(strings.TrimSpace(in.Filename))
|
|
if filename == "." || filename == "/" || filename == "" {
|
|
filename = "audio.mp3"
|
|
}
|
|
return audio, filename, nil
|
|
}
|
|
|
|
func (c *Client) addLeadSilence(ctx context.Context, audio []byte, filename string) ([]byte, string, error) {
|
|
tmpDir, err := os.MkdirTemp("", "ai-transcribe-*")
|
|
if err != nil {
|
|
return nil, "", fmt.Errorf("prepare audio temp dir: %w", err)
|
|
}
|
|
defer os.RemoveAll(tmpDir)
|
|
|
|
inputPath := filepath.Join(tmpDir, "input"+safeExt(filename))
|
|
outputPath := filepath.Join(tmpDir, "padded.mp3")
|
|
if err := os.WriteFile(inputPath, audio, 0o600); err != nil {
|
|
return nil, "", fmt.Errorf("write audio temp file: %w", err)
|
|
}
|
|
delayMS := int(c.leadSilence.Milliseconds())
|
|
if delayMS <= 0 {
|
|
return audio, filename, nil
|
|
}
|
|
cmd := exec.CommandContext(ctx, c.ffmpegPath,
|
|
"-nostdin", "-y",
|
|
"-i", inputPath,
|
|
"-af", fmt.Sprintf("adelay=%d:all=1", delayMS),
|
|
"-codec:a", "libmp3lame",
|
|
"-qscale:a", "5",
|
|
outputPath,
|
|
)
|
|
out, err := cmd.CombinedOutput()
|
|
if err != nil {
|
|
return nil, "", fmt.Errorf("ffmpeg lead silence: %w (%s)", err, trimOutput(out))
|
|
}
|
|
padded, err := os.ReadFile(outputPath)
|
|
if err != nil {
|
|
return nil, "", fmt.Errorf("read padded audio: %w", err)
|
|
}
|
|
if len(padded) == 0 {
|
|
return nil, "", fmt.Errorf("padded audio is empty")
|
|
}
|
|
base := strings.TrimSuffix(filepath.Base(filename), filepath.Ext(filename))
|
|
if base == "" || base == "." || base == "/" {
|
|
base = "audio"
|
|
}
|
|
return padded, base + "-padded.mp3", nil
|
|
}
|
|
|
|
// safeExt maps filename to an extension for the temporary input file.
//
// It returns the lower-cased extension when that is one of the recognised
// audio/container types, and ".audio" for anything else (including no
// extension at all). Only these fixed values can reach the temp-file path.
func safeExt(filename string) string {
	lowered := strings.ToLower(filepath.Ext(filename))
	for _, known := range [...]string{".mp3", ".wav", ".m4a", ".ogg", ".opus", ".webm"} {
		if lowered == known {
			return known
		}
	}
	return ".audio"
}
|
|
|
|
// trimOutput turns captured command output into a short string for an error
// message: surrounding whitespace is removed and at most the LAST 600 bytes
// are kept.
//
// The tail is kept because tools such as ffmpeg print their startup banner
// first and the actual failure reason last. The cut point is moved forward
// to a UTF-8 rune boundary so a multi-byte character is never split, and a
// leading "..." marks that earlier output was dropped.
func trimOutput(out []byte) string {
	const limit = 600

	s := strings.TrimSpace(string(out))
	if len(s) <= limit {
		return s
	}

	cut := len(s) - limit
	// Skip UTF-8 continuation bytes (10xxxxxx) so the slice starts on a rune.
	for cut < len(s) && s[cut]&0xC0 == 0x80 {
		cut++
	}
	return "..." + s[cut:]
}
|
|
|
|
func adjustLeadSilence(segments []Segment, silence time.Duration) []Segment {
|
|
if len(segments) == 0 || silence <= 0 {
|
|
return segments
|
|
}
|
|
shift := silence.Seconds()
|
|
out := make([]Segment, 0, len(segments))
|
|
for _, segment := range segments {
|
|
segment.Start = clampTime(segment.Start - shift)
|
|
segment.End = clampTime(segment.End - shift)
|
|
if segment.End < segment.Start {
|
|
segment.End = segment.Start
|
|
}
|
|
out = append(out, segment)
|
|
}
|
|
return out
|
|
}
|
|
|
|
// clampTime floors a timestamp (in seconds) at zero. Values that are not
// less than zero, including NaN, pass through unchanged.
func clampTime(v float64) float64 {
	result := v
	if result < 0 {
		result = 0
	}
	return result
}
|
|
|
|
func (c *Client) transcribeAudio(ctx context.Context, audio []byte, filename string, in Input) (*whisperResponse, time.Duration, error) {
|
|
body := &bytes.Buffer{}
|
|
mw := multipart.NewWriter(body)
|
|
fw, err := mw.CreateFormFile("file", filename)
|
|
if err != nil {
|
|
return nil, 0, fmt.Errorf("create form file: %w", err)
|
|
}
|
|
if _, err := fw.Write(audio); err != nil {
|
|
return nil, 0, fmt.Errorf("copy audio: %w", err)
|
|
}
|
|
if in.Language != "" {
|
|
_ = mw.WriteField("language", in.Language)
|
|
}
|
|
if in.Diarize {
|
|
_ = mw.WriteField("diarize", "true")
|
|
if in.MinSpeakers > 0 {
|
|
_ = mw.WriteField("min_speakers", fmt.Sprintf("%d", in.MinSpeakers))
|
|
}
|
|
if in.MaxSpeakers > 0 {
|
|
_ = mw.WriteField("max_speakers", fmt.Sprintf("%d", in.MaxSpeakers))
|
|
}
|
|
} else {
|
|
_ = mw.WriteField("diarize", "false")
|
|
}
|
|
if err := mw.Close(); err != nil {
|
|
return nil, 0, fmt.Errorf("close form: %w", err)
|
|
}
|
|
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/transcribe", body)
|
|
if err != nil {
|
|
return nil, 0, fmt.Errorf("whisperx request: %w", err)
|
|
}
|
|
req.Header.Set("Content-Type", mw.FormDataContentType())
|
|
|
|
start := time.Now()
|
|
resp, err := c.http.Do(req)
|
|
duration := time.Since(start)
|
|
if err != nil {
|
|
return nil, duration, fmt.Errorf("whisperx do: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
if resp.StatusCode >= 300 {
|
|
body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
|
|
return nil, duration, fmt.Errorf("whisperx HTTP %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
|
|
}
|
|
var out whisperResponse
|
|
if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
|
|
return nil, duration, fmt.Errorf("whisperx decode: %w", err)
|
|
}
|
|
return &out, duration, nil
|
|
}
|