package worker import ( "context" "encoding/json" "log/slog" "strings" "sync" "time" "ai-service/internal/llm" "ai-service/internal/model" "ai-service/internal/store" "ai-service/internal/transcription" ) const ( TaskLLMChat = "llm_chat" TaskChatCompletion = "chat_completion" TaskCallAnalysis = "call_analysis" TaskTranscription = "transcription" TranscriptionProfile = "whisperx" ) type Worker struct { store *store.Store llm *llm.Client transcriber *transcription.Client workerID string modelProfile string taskTypes []string modelProfiles []string pollInterval time.Duration claimLimit int leaseTimeout time.Duration } func New(store *store.Store, llmClient *llm.Client, transcriber *transcription.Client, workerID, modelProfile string, taskTypes, modelProfiles []string, pollInterval, leaseTimeout time.Duration, claimLimit int) *Worker { if pollInterval <= 0 { pollInterval = 2 * time.Second } if leaseTimeout <= 0 { leaseTimeout = 15 * time.Minute } if claimLimit <= 0 { claimLimit = 4 } if strings.TrimSpace(workerID) == "" { workerID = "ai-service-worker" } if len(modelProfiles) == 0 { modelProfiles = []string{modelProfile, TranscriptionProfile} } return &Worker{ store: store, llm: llmClient, transcriber: transcriber, workerID: workerID, modelProfile: modelProfile, taskTypes: taskTypes, modelProfiles: modelProfiles, pollInterval: pollInterval, claimLimit: claimLimit, leaseTimeout: leaseTimeout, } } func (w *Worker) Run(ctx context.Context) { ticker := time.NewTicker(w.pollInterval) defer ticker.Stop() for { w.tick(ctx) select { case <-ctx.Done(): return case <-ticker.C: } } } func (w *Worker) tick(ctx context.Context) { if reset, err := w.store.RequeueStaleRunning(ctx, w.leaseTimeout, 100); err != nil { slog.Error("requeue stale jobs failed", "error", err) } else if reset > 0 { slog.Warn("requeued stale jobs", "count", reset) } jobs, err := w.store.ClaimJobs(ctx, model.ClaimJobs{ WorkerID: w.workerID, TaskTypes: w.taskTypes, ModelProfiles: w.modelProfiles, Limit: w.claimLimit, }) if err != nil { slog.Error("claim jobs failed", "error", err) return } if len(jobs) > 1 { var wg sync.WaitGroup wg.Add(len(jobs)) for _, job := range jobs { go func(job *model.Job) { defer wg.Done() w.process(ctx, job) }(job) } wg.Wait() return } for _, job := range jobs { w.process(ctx, job) } } func (w *Worker) process(ctx context.Context, job *model.Job) { if job.TaskType == TaskTranscription { w.processTranscription(ctx, job) return } var input llm.ChatInput if err := json.Unmarshal(job.Input, &input); err != nil { w.fail(ctx, job, "bad_input", err.Error()) return } result, err := w.llm.Chat(ctx, input) if err != nil { w.fail(ctx, job, classifyLLMError(err), err.Error()) return } body, err := json.Marshal(result) if err != nil { w.fail(ctx, job, "bad_response", err.Error()) return } if _, err := w.store.CompleteJob(ctx, job.ID, model.CompleteJob{Result: body}); err != nil { slog.Error("complete job failed", "job_id", job.ID, "error", err) } } func (w *Worker) processTranscription(ctx context.Context, job *model.Job) { if w.transcriber == nil { w.fail(ctx, job, "provider_unavailable", "whisperx not configured") return } var input transcription.Input if err := json.Unmarshal(job.Input, &input); err != nil { w.fail(ctx, job, "bad_input", err.Error()) return } result, err := w.transcriber.Transcribe(ctx, input) if err != nil { w.fail(ctx, job, classifyTranscriptionError(err), err.Error()) return } body, err := json.Marshal(result) if err != nil { w.fail(ctx, job, "bad_response", err.Error()) return } if _, err := w.store.CompleteJob(ctx, job.ID, model.CompleteJob{Result: body}); err != nil { slog.Error("complete transcription job failed", "job_id", job.ID, "error", err) } } func (w *Worker) fail(ctx context.Context, job *model.Job, code, message string) { if _, err := w.store.FailJob(ctx, job.ID, model.FailJob{ErrorCode: code, ErrorMessage: message}); err != nil { slog.Error("fail job failed", "job_id", job.ID, "error", err) } } func classifyTranscriptionError(err error) string { if err == nil { return "unknown" } s := strings.ToLower(err.Error()) switch { case strings.Contains(s, "context deadline exceeded") || strings.Contains(s, "timeout"): return "timeout" case strings.Contains(s, "audio_url is required"): return "bad_input" case strings.Contains(s, "request has expired") || strings.Contains(s, "accessdenied"): return "storage_error" case strings.Contains(s, "audio http 4") || strings.Contains(s, "audio is empty"): return "bad_audio" case strings.Contains(s, "audio download") || strings.Contains(s, "audio http 5"): return "storage_error" case strings.Contains(s, "whisperx http 4") || strings.Contains(s, "ffmpeg") || strings.Contains(s, "invalid data") || strings.Contains(s, "could not decode"): return "bad_audio" case strings.Contains(s, "whisperx http 5") || strings.Contains(s, "whisperx do") || strings.Contains(s, "connection refused") || strings.Contains(s, "connection reset") || strings.Contains(s, "closed network connection"): return "provider_unavailable" case strings.Contains(s, "decode"): return "bad_response" default: return "unknown" } } func classifyLLMError(err error) string { if err == nil { return "unknown" } s := strings.ToLower(err.Error()) switch { case strings.Contains(s, "context deadline exceeded") || strings.Contains(s, "timeout"): return "timeout" case strings.Contains(s, "connection refused") || strings.Contains(s, "connection reset") || strings.Contains(s, "no route to host") || strings.Contains(s, "llm http 5"): return "model_unavailable" case strings.Contains(s, "llm http 4") || strings.Contains(s, "messages are required"): return "bad_input" case strings.Contains(s, "llm decode") || strings.Contains(s, "empty choices"): return "bad_response" default: return "unknown" } }