Pad audio before WhisperX transcription
This commit is contained in:
@@ -13,7 +13,7 @@ RUN CGO_ENABLED=0 GOOS=linux go build -o /out/ai-service ./cmd/server \
|
|||||||
|
|
||||||
FROM alpine:3.22
|
FROM alpine:3.22
|
||||||
|
|
||||||
RUN apk add --no-cache ca-certificates tini
|
RUN apk add --no-cache ca-certificates ffmpeg tini
|
||||||
|
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
COPY --from=builder /out/ai-service /usr/local/bin/ai-service
|
COPY --from=builder /out/ai-service /usr/local/bin/ai-service
|
||||||
|
|||||||
@@ -42,13 +42,14 @@ func main() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
llmClient := llm.New(cfg.LLMBaseURL, cfg.LLMAPIKey, cfg.LLMModel, cfg.LLMTimeout)
|
llmClient := llm.New(cfg.LLMBaseURL, cfg.LLMAPIKey, cfg.LLMModel, cfg.LLMTimeout)
|
||||||
transcriber := transcription.New(cfg.WhisperXURL, cfg.WhisperXTimeout)
|
transcriber := transcription.New(cfg.WhisperXURL, cfg.WhisperXTimeout, cfg.FfmpegPath, cfg.WhisperXLeadSilence)
|
||||||
w := worker.New(db, llmClient, transcriber, cfg.WorkerID, cfg.LLMModel, cfg.WorkerTaskTypes, cfg.WorkerModelProfiles, cfg.WorkerPollInterval, cfg.WorkerLeaseTimeout, cfg.WorkerClaimLimit)
|
w := worker.New(db, llmClient, transcriber, cfg.WorkerID, cfg.LLMModel, cfg.WorkerTaskTypes, cfg.WorkerModelProfiles, cfg.WorkerPollInterval, cfg.WorkerLeaseTimeout, cfg.WorkerClaimLimit)
|
||||||
|
|
||||||
slog.Info("ai_worker_started",
|
slog.Info("ai_worker_started",
|
||||||
"worker_id", cfg.WorkerID,
|
"worker_id", cfg.WorkerID,
|
||||||
"model", cfg.LLMModel,
|
"model", cfg.LLMModel,
|
||||||
"whisperx_enabled", transcriber != nil,
|
"whisperx_enabled", transcriber != nil,
|
||||||
|
"whisperx_lead_silence", cfg.WhisperXLeadSilence.String(),
|
||||||
"task_types", cfg.WorkerTaskTypes,
|
"task_types", cfg.WorkerTaskTypes,
|
||||||
"model_profiles", cfg.WorkerModelProfiles,
|
"model_profiles", cfg.WorkerModelProfiles,
|
||||||
"poll_interval", cfg.WorkerPollInterval.String(),
|
"poll_interval", cfg.WorkerPollInterval.String(),
|
||||||
|
|||||||
@@ -20,6 +20,8 @@ type Config struct {
|
|||||||
LLMTimeout time.Duration
|
LLMTimeout time.Duration
|
||||||
WhisperXURL string
|
WhisperXURL string
|
||||||
WhisperXTimeout time.Duration
|
WhisperXTimeout time.Duration
|
||||||
|
WhisperXLeadSilence time.Duration
|
||||||
|
FfmpegPath string
|
||||||
AIStatsSidecarURL string
|
AIStatsSidecarURL string
|
||||||
AIStatsTimeout time.Duration
|
AIStatsTimeout time.Duration
|
||||||
|
|
||||||
@@ -45,6 +47,8 @@ func Load() Config {
|
|||||||
LLMTimeout: envDuration("LLM_TIMEOUT", 5*time.Minute),
|
LLMTimeout: envDuration("LLM_TIMEOUT", 5*time.Minute),
|
||||||
WhisperXURL: envString("WHISPERX_URL", ""),
|
WhisperXURL: envString("WHISPERX_URL", ""),
|
||||||
WhisperXTimeout: envDuration("WHISPERX_TIMEOUT", 10*time.Minute),
|
WhisperXTimeout: envDuration("WHISPERX_TIMEOUT", 10*time.Minute),
|
||||||
|
WhisperXLeadSilence: envDuration("WHISPERX_LEAD_SILENCE", 800*time.Millisecond),
|
||||||
|
FfmpegPath: envString("FFMPEG_PATH", "/usr/bin/ffmpeg"),
|
||||||
AIStatsSidecarURL: envString("AI_STATS_SIDECAR_URL", ""),
|
AIStatsSidecarURL: envString("AI_STATS_SIDECAR_URL", ""),
|
||||||
AIStatsTimeout: envDuration("AI_STATS_TIMEOUT", 8*time.Second),
|
AIStatsTimeout: envDuration("AI_STATS_TIMEOUT", 8*time.Second),
|
||||||
|
|
||||||
|
|||||||
@@ -8,6 +8,8 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"mime/multipart"
|
"mime/multipart"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
@@ -16,6 +18,8 @@ import (
|
|||||||
type Client struct {
|
type Client struct {
|
||||||
baseURL string
|
baseURL string
|
||||||
http *http.Client
|
http *http.Client
|
||||||
|
ffmpegPath string
|
||||||
|
leadSilence time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
type Input struct {
|
type Input struct {
|
||||||
@@ -49,7 +53,7 @@ type whisperResponse struct {
|
|||||||
AlignError *string `json:"align_error,omitempty"`
|
AlignError *string `json:"align_error,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(baseURL string, timeout time.Duration) *Client {
|
func New(baseURL string, timeout time.Duration, ffmpegPath string, leadSilence time.Duration) *Client {
|
||||||
baseURL = strings.TrimRight(strings.TrimSpace(baseURL), "/")
|
baseURL = strings.TrimRight(strings.TrimSpace(baseURL), "/")
|
||||||
if baseURL == "" {
|
if baseURL == "" {
|
||||||
return nil
|
return nil
|
||||||
@@ -57,9 +61,21 @@ func New(baseURL string, timeout time.Duration) *Client {
|
|||||||
if timeout <= 0 {
|
if timeout <= 0 {
|
||||||
timeout = 10 * time.Minute
|
timeout = 10 * time.Minute
|
||||||
}
|
}
|
||||||
|
if leadSilence < 0 {
|
||||||
|
leadSilence = 0
|
||||||
|
}
|
||||||
|
if leadSilence > 5*time.Second {
|
||||||
|
leadSilence = 5 * time.Second
|
||||||
|
}
|
||||||
|
ffmpegPath = strings.TrimSpace(ffmpegPath)
|
||||||
|
if ffmpegPath == "" {
|
||||||
|
ffmpegPath = "ffmpeg"
|
||||||
|
}
|
||||||
return &Client{
|
return &Client{
|
||||||
baseURL: baseURL,
|
baseURL: baseURL,
|
||||||
http: &http.Client{Timeout: timeout},
|
http: &http.Client{Timeout: timeout},
|
||||||
|
ffmpegPath: ffmpegPath,
|
||||||
|
leadSilence: leadSilence,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -74,13 +90,20 @@ func (c *Client) Transcribe(ctx context.Context, in Input) (*Result, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
if c.leadSilence > 0 {
|
||||||
|
audio, filename, err = c.addLeadSilence(ctx, audio, filename)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
resp, duration, err := c.transcribeAudio(ctx, audio, filename, in)
|
resp, duration, err := c.transcribeAudio(ctx, audio, filename, in)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
segments := adjustLeadSilence(resp.Segments, c.leadSilence)
|
||||||
return &Result{
|
return &Result{
|
||||||
Language: resp.Language,
|
Language: resp.Language,
|
||||||
Segments: resp.Segments,
|
Segments: segments,
|
||||||
DiarizeError: resp.DiarizeError,
|
DiarizeError: resp.DiarizeError,
|
||||||
AlignError: resp.AlignError,
|
AlignError: resp.AlignError,
|
||||||
DurationMS: duration.Milliseconds(),
|
DurationMS: duration.Milliseconds(),
|
||||||
@@ -115,6 +138,90 @@ func (c *Client) downloadAudio(ctx context.Context, in Input) ([]byte, string, e
|
|||||||
return audio, filename, nil
|
return audio, filename, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Client) addLeadSilence(ctx context.Context, audio []byte, filename string) ([]byte, string, error) {
|
||||||
|
tmpDir, err := os.MkdirTemp("", "ai-transcribe-*")
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", fmt.Errorf("prepare audio temp dir: %w", err)
|
||||||
|
}
|
||||||
|
defer os.RemoveAll(tmpDir)
|
||||||
|
|
||||||
|
inputPath := filepath.Join(tmpDir, "input"+safeExt(filename))
|
||||||
|
outputPath := filepath.Join(tmpDir, "padded.mp3")
|
||||||
|
if err := os.WriteFile(inputPath, audio, 0o600); err != nil {
|
||||||
|
return nil, "", fmt.Errorf("write audio temp file: %w", err)
|
||||||
|
}
|
||||||
|
delayMS := int(c.leadSilence.Milliseconds())
|
||||||
|
if delayMS <= 0 {
|
||||||
|
return audio, filename, nil
|
||||||
|
}
|
||||||
|
cmd := exec.CommandContext(ctx, c.ffmpegPath,
|
||||||
|
"-nostdin", "-y",
|
||||||
|
"-i", inputPath,
|
||||||
|
"-af", fmt.Sprintf("adelay=%d:all=1", delayMS),
|
||||||
|
"-codec:a", "libmp3lame",
|
||||||
|
"-qscale:a", "5",
|
||||||
|
outputPath,
|
||||||
|
)
|
||||||
|
out, err := cmd.CombinedOutput()
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", fmt.Errorf("ffmpeg lead silence: %w (%s)", err, trimOutput(out))
|
||||||
|
}
|
||||||
|
padded, err := os.ReadFile(outputPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", fmt.Errorf("read padded audio: %w", err)
|
||||||
|
}
|
||||||
|
if len(padded) == 0 {
|
||||||
|
return nil, "", fmt.Errorf("padded audio is empty")
|
||||||
|
}
|
||||||
|
base := strings.TrimSuffix(filepath.Base(filename), filepath.Ext(filename))
|
||||||
|
if base == "" || base == "." || base == "/" {
|
||||||
|
base = "audio"
|
||||||
|
}
|
||||||
|
return padded, base + "-padded.mp3", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func safeExt(filename string) string {
|
||||||
|
ext := strings.ToLower(filepath.Ext(filename))
|
||||||
|
switch ext {
|
||||||
|
case ".mp3", ".wav", ".m4a", ".ogg", ".opus", ".webm":
|
||||||
|
return ext
|
||||||
|
default:
|
||||||
|
return ".audio"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func trimOutput(out []byte) string {
|
||||||
|
s := strings.TrimSpace(string(out))
|
||||||
|
if len(s) > 600 {
|
||||||
|
return s[:600]
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func adjustLeadSilence(segments []Segment, silence time.Duration) []Segment {
|
||||||
|
if len(segments) == 0 || silence <= 0 {
|
||||||
|
return segments
|
||||||
|
}
|
||||||
|
shift := silence.Seconds()
|
||||||
|
out := make([]Segment, 0, len(segments))
|
||||||
|
for _, segment := range segments {
|
||||||
|
segment.Start = clampTime(segment.Start - shift)
|
||||||
|
segment.End = clampTime(segment.End - shift)
|
||||||
|
if segment.End < segment.Start {
|
||||||
|
segment.End = segment.Start
|
||||||
|
}
|
||||||
|
out = append(out, segment)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func clampTime(v float64) float64 {
|
||||||
|
if v < 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
|
||||||
func (c *Client) transcribeAudio(ctx context.Context, audio []byte, filename string, in Input) (*whisperResponse, time.Duration, error) {
|
func (c *Client) transcribeAudio(ctx context.Context, audio []byte, filename string, in Input) (*whisperResponse, time.Duration, error) {
|
||||||
body := &bytes.Buffer{}
|
body := &bytes.Buffer{}
|
||||||
mw := multipart.NewWriter(body)
|
mw := multipart.NewWriter(body)
|
||||||
|
|||||||
28
internal/transcription/client_test.go
Normal file
28
internal/transcription/client_test.go
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
package transcription
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAdjustLeadSilence(t *testing.T) {
|
||||||
|
got := adjustLeadSilence([]Segment{
|
||||||
|
{Start: 0.2, End: 1.1, Text: "first"},
|
||||||
|
{Start: 1.4, End: 2.0, Text: "second"},
|
||||||
|
}, 800*time.Millisecond)
|
||||||
|
|
||||||
|
if got[0].Start != 0 {
|
||||||
|
t.Fatalf("first start = %v, want 0", got[0].Start)
|
||||||
|
}
|
||||||
|
if !near(got[0].End, 0.3) {
|
||||||
|
t.Fatalf("first end = %v, want 0.3", got[0].End)
|
||||||
|
}
|
||||||
|
if !near(got[1].Start, 0.6) {
|
||||||
|
t.Fatalf("second start = %v, want 0.6", got[1].Start)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func near(got, want float64) bool {
|
||||||
|
return math.Abs(got-want) < 0.000001
|
||||||
|
}
|
||||||
@@ -13,6 +13,8 @@ data:
|
|||||||
LLM_TIMEOUT: "5m"
|
LLM_TIMEOUT: "5m"
|
||||||
WHISPERX_URL: "http://10.2.3.5:8001"
|
WHISPERX_URL: "http://10.2.3.5:8001"
|
||||||
WHISPERX_TIMEOUT: "10m"
|
WHISPERX_TIMEOUT: "10m"
|
||||||
|
WHISPERX_LEAD_SILENCE: "800ms"
|
||||||
|
FFMPEG_PATH: "/usr/bin/ffmpeg"
|
||||||
AI_STATS_SIDECAR_URL: "http://10.2.3.5:9090"
|
AI_STATS_SIDECAR_URL: "http://10.2.3.5:9090"
|
||||||
AI_STATS_TIMEOUT: "8s"
|
AI_STATS_TIMEOUT: "8s"
|
||||||
WORKER_POLL_INTERVAL: "2s"
|
WORKER_POLL_INTERVAL: "2s"
|
||||||
|
|||||||
Reference in New Issue
Block a user