diff --git a/cmd/worker/main.go b/cmd/worker/main.go
index 2aa8dd1..a8751ab 100644
--- a/cmd/worker/main.go
+++ b/cmd/worker/main.go
@@ -11,6 +11,7 @@ import (
 	"ai-service/internal/llm"
 	"ai-service/internal/migrate"
 	"ai-service/internal/store"
+	"ai-service/internal/transcription"
 	"ai-service/internal/worker"
 )
 
@@ -41,11 +42,13 @@ func main() {
 	}
 
 	llmClient := llm.New(cfg.LLMBaseURL, cfg.LLMAPIKey, cfg.LLMModel, cfg.LLMTimeout)
-	w := worker.New(db, llmClient, cfg.WorkerID, cfg.LLMModel, cfg.WorkerPollInterval, cfg.WorkerLeaseTimeout, cfg.WorkerClaimLimit)
+	transcriber := transcription.New(cfg.WhisperXURL, cfg.WhisperXTimeout)
+	w := worker.New(db, llmClient, transcriber, cfg.WorkerID, cfg.LLMModel, cfg.WorkerPollInterval, cfg.WorkerLeaseTimeout, cfg.WorkerClaimLimit)
 
 	slog.Info("ai_worker_started",
 		"worker_id", cfg.WorkerID,
 		"model", cfg.LLMModel,
+		"whisperx_enabled", transcriber != nil,
 		"poll_interval", cfg.WorkerPollInterval.String(),
 		"lease_timeout", cfg.WorkerLeaseTimeout.String(),
 		"claim_limit", cfg.WorkerClaimLimit,
diff --git a/internal/config/config.go b/internal/config/config.go
index fadad08..5e34d30 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -13,11 +13,12 @@ type Config struct {
 	MigrateOnStart bool
 	APIAuthToken   string
 
-	LLMBaseURL  string
-	LLMAPIKey   string
-	LLMModel    string
-	LLMTimeout  time.Duration
-	WhisperXURL string
+	LLMBaseURL      string
+	LLMAPIKey       string
+	LLMModel        string
+	LLMTimeout      time.Duration
+	WhisperXURL     string
+	WhisperXTimeout time.Duration
 
 	WorkerID           string
 	WorkerPollInterval time.Duration
@@ -33,11 +34,12 @@ func Load() Config {
 		MigrateOnStart: envBool("MIGRATE_ON_START", true),
 		APIAuthToken:   envString("AI_SERVICE_TOKEN", ""),
 
-		LLMBaseURL:  envString("LLM_BASE_URL", ""),
-		LLMAPIKey:   envString("LLM_API_KEY", ""),
-		LLMModel:    envString("LLM_MODEL", "qwen2.5-14b"),
-		LLMTimeout:  envDuration("LLM_TIMEOUT", 5*time.Minute),
-		WhisperXURL: envString("WHISPERX_URL", ""),
+		LLMBaseURL:      envString("LLM_BASE_URL", ""),
+		LLMAPIKey:       envString("LLM_API_KEY", ""),
+		LLMModel:        envString("LLM_MODEL", "qwen2.5-14b"),
+		LLMTimeout:      envDuration("LLM_TIMEOUT", 5*time.Minute),
+		WhisperXURL:     envString("WHISPERX_URL", ""),
+		WhisperXTimeout: envDuration("WHISPERX_TIMEOUT", 10*time.Minute),
 
 		WorkerID:           envString("WORKER_ID", hostname()),
 		WorkerPollInterval: envDuration("WORKER_POLL_INTERVAL", 2*time.Second),
diff --git a/internal/transcription/client.go b/internal/transcription/client.go
new file mode 100644
index 0000000..5770093
--- /dev/null
+++ b/internal/transcription/client.go
@@ -0,0 +1,194 @@
+// Package transcription provides a small HTTP client for a WhisperX
+// transcription service.
+package transcription
+
+import (
+	"bytes"
+	"context"
+	"encoding/json"
+	"fmt"
+	"io"
+	"mime/multipart"
+	"net/http"
+	"path/filepath"
+	"strings"
+	"time"
+)
+
+// maxAudioBytes caps how much downloaded audio is buffered in memory.
+const maxAudioBytes = 512 << 20
+
+// Client talks to a WhisperX HTTP endpoint. New returns a nil *Client when
+// no base URL is configured; Transcribe reports that case as an error.
+type Client struct {
+	baseURL string
+	http    *http.Client
+}
+
+// Input is the JSON job payload describing the audio to transcribe.
+type Input struct {
+	AudioURL    string `json:"audio_url"`
+	Filename    string `json:"filename,omitempty"`
+	Language    string `json:"language,omitempty"`
+	Diarize     bool   `json:"diarize"`
+	MinSpeakers int    `json:"min_speakers,omitempty"`
+	MaxSpeakers int    `json:"max_speakers,omitempty"`
+}
+
+// Segment is one transcribed span of the audio.
+type Segment struct {
+	Start   float64 `json:"start"`
+	End     float64 `json:"end"`
+	Text    string  `json:"text"`
+	Speaker string  `json:"speaker,omitempty"`
+}
+
+// Result is what Transcribe returns. DurationMS is the time spent in the
+// WhisperX HTTP call, in milliseconds.
+type Result struct {
+	Language     string    `json:"language"`
+	Segments     []Segment `json:"segments"`
+	DiarizeError *string   `json:"diarize_error,omitempty"`
+	AlignError   *string   `json:"align_error,omitempty"`
+	DurationMS   int64     `json:"duration_ms"`
+}
+
+// whisperResponse mirrors the JSON body returned by POST /transcribe.
+type whisperResponse struct {
+	Language     string    `json:"language"`
+	Segments     []Segment `json:"segments"`
+	DiarizeError *string   `json:"diarize_error,omitempty"`
+	AlignError   *string   `json:"align_error,omitempty"`
+}
+
+// New returns a Client for baseURL, or nil when baseURL is blank. A
+// non-positive timeout falls back to 10 minutes.
+func New(baseURL string, timeout time.Duration) *Client {
+	baseURL = strings.TrimRight(strings.TrimSpace(baseURL), "/")
+	if baseURL == "" {
+		return nil
+	}
+	if timeout <= 0 {
+		timeout = 10 * time.Minute
+	}
+	return &Client{
+		baseURL: baseURL,
+		http:    &http.Client{Timeout: timeout},
+	}
+}
+
+// Transcribe downloads in.AudioURL and submits the audio to WhisperX. It
+// returns an error when the client is nil or in.AudioURL is blank.
+func (c *Client) Transcribe(ctx context.Context, in Input) (*Result, error) {
+	if c == nil || c.baseURL == "" {
+		return nil, fmt.Errorf("whisperx not configured")
+	}
+	if strings.TrimSpace(in.AudioURL) == "" {
+		return nil, fmt.Errorf("audio_url is required")
+	}
+	audio, filename, err := c.downloadAudio(ctx, in)
+	if err != nil {
+		return nil, err
+	}
+	resp, duration, err := c.transcribeAudio(ctx, audio, filename, in)
+	if err != nil {
+		return nil, err
+	}
+	return &Result{
+		Language:     resp.Language,
+		Segments:     resp.Segments,
+		DiarizeError: resp.DiarizeError,
+		AlignError:   resp.AlignError,
+		DurationMS:   duration.Milliseconds(),
+	}, nil
+}
+
+// downloadAudio fetches in.AudioURL into memory and picks the upload
+// filename ("audio.mp3" when in.Filename yields no base name). Bodies larger
+// than maxAudioBytes are rejected instead of being silently truncated.
+func (c *Client) downloadAudio(ctx context.Context, in Input) ([]byte, string, error) {
+	req, err := http.NewRequestWithContext(ctx, http.MethodGet, in.AudioURL, nil)
+	if err != nil {
+		return nil, "", fmt.Errorf("audio request: %w", err)
+	}
+	resp, err := c.http.Do(req)
+	if err != nil {
+		return nil, "", fmt.Errorf("audio download: %w", err)
+	}
+	defer resp.Body.Close()
+	if resp.StatusCode >= 300 {
+		body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
+		return nil, "", fmt.Errorf("audio HTTP %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
+	}
+	// Read one byte past the cap so an oversized body can be detected.
+	audio, err := io.ReadAll(io.LimitReader(resp.Body, maxAudioBytes+1))
+	if err != nil {
+		return nil, "", fmt.Errorf("audio read: %w", err)
+	}
+	if len(audio) == 0 {
+		return nil, "", fmt.Errorf("audio is empty")
+	}
+	if len(audio) > maxAudioBytes {
+		return nil, "", fmt.Errorf("audio exceeds %d bytes", maxAudioBytes)
+	}
+	filename := filepath.Base(strings.TrimSpace(in.Filename))
+	if filename == "." || filename == "/" || filename == "" {
+		filename = "audio.mp3"
+	}
+	return audio, filename, nil
+}
+
+// transcribeAudio posts audio as multipart form data to /transcribe and
+// decodes the JSON reply. The returned duration measures the c.http.Do call.
+func (c *Client) transcribeAudio(ctx context.Context, audio []byte, filename string, in Input) (*whisperResponse, time.Duration, error) {
+	body := &bytes.Buffer{}
+	mw := multipart.NewWriter(body)
+	fw, err := mw.CreateFormFile("file", filename)
+	if err != nil {
+		return nil, 0, fmt.Errorf("create form file: %w", err)
+	}
+	if _, err := fw.Write(audio); err != nil {
+		return nil, 0, fmt.Errorf("copy audio: %w", err)
+	}
+	// WriteField errors are ignored below: the writer targets an in-memory buffer.
+	if in.Language != "" {
+		_ = mw.WriteField("language", in.Language)
+	}
+	if in.Diarize {
+		_ = mw.WriteField("diarize", "true")
+		if in.MinSpeakers > 0 {
+			_ = mw.WriteField("min_speakers", fmt.Sprintf("%d", in.MinSpeakers))
+		}
+		if in.MaxSpeakers > 0 {
+			_ = mw.WriteField("max_speakers", fmt.Sprintf("%d", in.MaxSpeakers))
+		}
+	} else {
+		_ = mw.WriteField("diarize", "false")
+	}
+	if err := mw.Close(); err != nil {
+		return nil, 0, fmt.Errorf("close form: %w", err)
+	}
+
+	req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/transcribe", body)
+	if err != nil {
+		return nil, 0, fmt.Errorf("whisperx request: %w", err)
+	}
+	req.Header.Set("Content-Type", mw.FormDataContentType())
+
+	start := time.Now()
+	resp, err := c.http.Do(req)
+	duration := time.Since(start)
+	if err != nil {
+		return nil, duration, fmt.Errorf("whisperx do: %w", err)
+	}
+	defer resp.Body.Close()
+	if resp.StatusCode >= 300 {
+		body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
+		return nil, duration, fmt.Errorf("whisperx HTTP %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
+	}
+	var out whisperResponse
+	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
+		return nil, duration, fmt.Errorf("whisperx decode: %w", err)
+	}
+	return &out, duration, nil
+}
diff --git a/internal/worker/worker.go b/internal/worker/worker.go
index 037899d..c21ebd3 100644
--- a/internal/worker/worker.go
+++ b/internal/worker/worker.go
@@ -10,17 +10,22 @@ import (
 	"ai-service/internal/llm"
 	"ai-service/internal/model"
 	"ai-service/internal/store"
+	"ai-service/internal/transcription"
 )
 
 const (
 	TaskLLMChat        = "llm_chat"
 	TaskChatCompletion = "chat_completion"
 	TaskCallAnalysis   = "call_analysis"
+	TaskTranscription  = "transcription"
+
+	TranscriptionProfile = "whisperx"
 )
 
 type Worker struct {
 	store        *store.Store
 	llm          *llm.Client
+	transcriber  *transcription.Client
 	workerID     string
 	modelProfile string
 	pollInterval time.Duration
@@ -28,7 +33,7 @@ type Worker struct {
 	leaseTimeout time.Duration
 }
 
-func New(store *store.Store, llmClient *llm.Client, workerID, modelProfile string, pollInterval, leaseTimeout time.Duration, claimLimit int) *Worker {
+func New(store *store.Store, llmClient *llm.Client, transcriber *transcription.Client, workerID, modelProfile string, pollInterval, leaseTimeout time.Duration, claimLimit int) *Worker {
 	if pollInterval <= 0 {
 		pollInterval = 2 * time.Second
 	}
@@ -44,6 +49,7 @@ func New(store *store.Store, llmClient *llm.Client, workerID, modelProfile strin
 	return &Worker{
 		store:        store,
 		llm:          llmClient,
+		transcriber:  transcriber,
 		workerID:     workerID,
 		modelProfile: modelProfile,
 		pollInterval: pollInterval,
@@ -73,8 +79,8 @@ func (w *Worker) tick(ctx context.Context) {
 	}
 	jobs, err := w.store.ClaimJobs(ctx, model.ClaimJobs{
 		WorkerID:      w.workerID,
-		TaskTypes:     []string{TaskLLMChat, TaskChatCompletion, TaskCallAnalysis},
-		ModelProfiles: []string{w.modelProfile},
+		TaskTypes:     []string{TaskLLMChat, TaskChatCompletion, TaskCallAnalysis, TaskTranscription},
+		ModelProfiles: []string{w.modelProfile, TranscriptionProfile},
 		Limit:         w.claimLimit,
 	})
 	if err != nil {
@@ -87,6 +93,10 @@ func (w *Worker) tick(ctx context.Context) {
 }
 
 func (w *Worker) process(ctx context.Context, job *model.Job) {
+	if job.TaskType == TaskTranscription {
+		w.processTranscription(ctx, job)
+		return
+	}
 	var input llm.ChatInput
 	if err := json.Unmarshal(job.Input, &input); err != nil {
 		w.fail(ctx, job, "bad_input", err.Error())
@@ -107,12 +117,67 @@ func (w *Worker) process(ctx context.Context, job *model.Job) {
 	}
 }
 
+// processTranscription runs a transcription job through the WhisperX client
+// and stores the JSON-encoded result, or fails the job with a classified code.
+func (w *Worker) processTranscription(ctx context.Context, job *model.Job) {
+	if w.transcriber == nil {
+		w.fail(ctx, job, "provider_unavailable", "whisperx not configured")
+		return
+	}
+	var input transcription.Input
+	if err := json.Unmarshal(job.Input, &input); err != nil {
+		w.fail(ctx, job, "bad_input", err.Error())
+		return
+	}
+	result, err := w.transcriber.Transcribe(ctx, input)
+	if err != nil {
+		w.fail(ctx, job, classifyTranscriptionError(err), err.Error())
+		return
+	}
+	body, err := json.Marshal(result)
+	if err != nil {
+		w.fail(ctx, job, "bad_response", err.Error())
+		return
+	}
+	if _, err := w.store.CompleteJob(ctx, job.ID, model.CompleteJob{Result: body}); err != nil {
+		slog.Error("complete transcription job failed", "job_id", job.ID, "error", err)
+	}
+}
+
 func (w *Worker) fail(ctx context.Context, job *model.Job, code, message string) {
 	if _, err := w.store.FailJob(ctx, job.ID, model.FailJob{ErrorCode: code, ErrorMessage: message}); err != nil {
 		slog.Error("fail job failed", "job_id", job.ID, "error", err)
 	}
 }
 
+// classifyTranscriptionError maps an error from the transcription client to a
+// job error code by matching substrings of its lowercased message. Cases are
+// checked in order, so earlier matches win.
+func classifyTranscriptionError(err error) string {
+	if err == nil {
+		return "unknown"
+	}
+	s := strings.ToLower(err.Error())
+	switch {
+	case strings.Contains(s, "context deadline exceeded") || strings.Contains(s, "timeout"):
+		return "timeout"
+	case strings.Contains(s, "audio_url is required"):
+		return "bad_input"
+	case strings.Contains(s, "audio http 4") || strings.Contains(s, "audio is empty") || strings.Contains(s, "audio exceeds"):
+		return "bad_audio"
+	case strings.Contains(s, "audio download") || strings.Contains(s, "audio http 5"):
+		return "storage_error"
+	case strings.Contains(s, "whisperx http 4") || strings.Contains(s, "ffmpeg") || strings.Contains(s, "invalid data") || strings.Contains(s, "could not decode"):
+		return "bad_audio"
+	case strings.Contains(s, "whisperx do") || strings.Contains(s, "connection refused") || strings.Contains(s, "connection reset") || strings.Contains(s, "closed network connection"):
+		return "provider_unavailable"
+	case strings.Contains(s, "decode"):
+		return "bad_response"
+	default:
+		return "unknown"
+	}
+}
+
 func classifyLLMError(err error) string {
 	if err == nil {
 		return "unknown"
diff --git a/k8s/configmap.yaml b/k8s/configmap.yaml
index bcada4c..6ed7baa 100644
--- a/k8s/configmap.yaml
+++ b/k8s/configmap.yaml
@@ -12,6 +12,7 @@ data:
   LLM_MODEL: "qwen2.5-14b"
   LLM_TIMEOUT: "5m"
   WHISPERX_URL: "http://10.2.3.5:8001"
+  WHISPERX_TIMEOUT: "10m"
   WORKER_POLL_INTERVAL: "2s"
   WORKER_CLAIM_LIMIT: "4"
   WORKER_LEASE_TIMEOUT: "15m"
diff --git a/k8s/worker-deployment.yaml b/k8s/worker-deployment.yaml
index ff2461c..3bdad01 100644
--- a/k8s/worker-deployment.yaml
+++ b/k8s/worker-deployment.yaml
@@ -14,6 +14,10 @@ spec:
         app: ai-service-worker
     spec:
       terminationGracePeriodSeconds: 20
+      hostAliases:
+        - ip: "77.105.173.42"
+          hostnames:
+            - "s3-minio.estateliga.work"
       containers:
         - name: worker
           image: localhost:30300/admin/ai-service:latest