Files
ai-service/internal/worker/worker.go
Grendgi 631a45aff3
Some checks failed
CI / test (push) Failing after 8s
Build and Deploy / build-and-deploy (push) Successful in 31s
Classify LLM context length errors
2026-06-10 16:16:57 +03:00

220 lines
6.3 KiB
Go

package worker
import (
"context"
"encoding/json"
"log/slog"
"strings"
"sync"
"time"
"ai-service/internal/llm"
"ai-service/internal/model"
"ai-service/internal/store"
"ai-service/internal/transcription"
)
// Task type identifiers. Within this file only TaskTranscription is
// dispatched specially (see Worker.process); every other task type is handled
// as an LLM chat job.
const (
	TaskLLMChat        = "llm_chat"
	TaskChatCompletion = "chat_completion"
	TaskCallAnalysis   = "call_analysis"
	TaskTranscription  = "transcription"
)

// TranscriptionProfile is the model profile name that New adds to the default
// profile list when the caller supplies no explicit model profiles.
const TranscriptionProfile = "whisper-large-v3"
// Worker polls the job store, claims pending jobs and executes them against
// the LLM client or the transcription client.
type Worker struct {
	// Dependencies.
	store       *store.Store          // used to requeue, claim, complete and fail jobs
	llm         *llm.Client           // handles every non-transcription job via Chat
	transcriber *transcription.Client // may be nil; transcription jobs then fail with provider_unavailable

	// Identity and claim parameters.
	workerID      string   // sent as WorkerID when claiming jobs
	modelProfile  string   // stored by New; also seeds the default modelProfiles
	taskTypes     []string // sent as TaskTypes when claiming jobs
	modelProfiles []string // sent as ModelProfiles when claiming jobs

	// Scheduling.
	pollInterval time.Duration // period of the ticker driving Run
	claimLimit   int           // sent as Limit when claiming jobs
	leaseTimeout time.Duration // passed to the store's RequeueStaleRunning on every tick
}
// New builds a Worker from its dependencies and settings, substituting
// defaults for unset values:
//
//   - pollInterval <= 0 becomes 2 seconds
//   - leaseTimeout <= 0 becomes 15 minutes
//   - claimLimit <= 0 becomes 4
//   - a blank (whitespace-only) workerID becomes "ai-service-worker"
//   - an empty modelProfiles becomes {modelProfile, TranscriptionProfile}
//
// All other arguments are stored as given.
func New(store *store.Store, llmClient *llm.Client, transcriber *transcription.Client, workerID, modelProfile string, taskTypes, modelProfiles []string, pollInterval, leaseTimeout time.Duration, claimLimit int) *Worker {
	w := &Worker{
		store:         store,
		llm:           llmClient,
		transcriber:   transcriber,
		workerID:      workerID,
		modelProfile:  modelProfile,
		taskTypes:     taskTypes,
		modelProfiles: modelProfiles,
		pollInterval:  pollInterval,
		claimLimit:    claimLimit,
		leaseTimeout:  leaseTimeout,
	}

	// Apply defaults on the constructed value rather than on the arguments.
	if w.pollInterval <= 0 {
		w.pollInterval = 2 * time.Second
	}
	if w.leaseTimeout <= 0 {
		w.leaseTimeout = 15 * time.Minute
	}
	if w.claimLimit <= 0 {
		w.claimLimit = 4
	}
	if strings.TrimSpace(w.workerID) == "" {
		w.workerID = "ai-service-worker"
	}
	if len(w.modelProfiles) == 0 {
		w.modelProfiles = []string{modelProfile, TranscriptionProfile}
	}
	return w
}
// Run executes the polling loop until ctx is cancelled.
//
// A tick is performed immediately and then once per pollInterval. Run blocks
// the calling goroutine and returns only after ctx is done.
func (w *Worker) Run(ctx context.Context) {
	ticker := time.NewTicker(w.pollInterval)
	defer ticker.Stop()
	for {
		// select chooses pseudo-randomly when both ctx.Done() and ticker.C are
		// ready, so check cancellation explicitly before doing any work. This
		// also avoids a pointless tick when ctx is already cancelled on entry.
		if ctx.Err() != nil {
			return
		}
		w.tick(ctx)
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
		}
	}
}
// tick performs one polling cycle: it asks the store to requeue stale running
// jobs, claims a batch of jobs for this worker, and processes them.
//
// A batch of two or more jobs is processed concurrently, one goroutine per
// job, and tick waits for all of them. A batch of zero or one job is handled
// inline. Store errors are logged; a claim error ends the cycle.
func (w *Worker) tick(ctx context.Context) {
	// The literal 100 is presumably a cap on how many jobs are requeued per
	// cycle — confirm against store.RequeueStaleRunning.
	reset, err := w.store.RequeueStaleRunning(ctx, w.leaseTimeout, 100)
	switch {
	case err != nil:
		slog.Error("requeue stale jobs failed", "error", err)
	case reset > 0:
		slog.Warn("requeued stale jobs", "count", reset)
	}

	claim := model.ClaimJobs{
		WorkerID:      w.workerID,
		TaskTypes:     w.taskTypes,
		ModelProfiles: w.modelProfiles,
		Limit:         w.claimLimit,
	}
	jobs, err := w.store.ClaimJobs(ctx, claim)
	if err != nil {
		slog.Error("claim jobs failed", "error", err)
		return
	}

	// Zero or one job: no need to spawn goroutines.
	if len(jobs) <= 1 {
		for _, job := range jobs {
			w.process(ctx, job)
		}
		return
	}

	var pending sync.WaitGroup
	for _, job := range jobs {
		pending.Add(1)
		go func(j *model.Job) {
			defer pending.Done()
			w.process(ctx, j)
		}(job)
	}
	pending.Wait()
}
// process executes a single claimed job and records the outcome in the store.
//
// Transcription jobs are delegated to processTranscription. Every other task
// type is treated as an LLM chat request: job.Input is decoded into an
// llm.ChatInput, passed to the LLM client, and the JSON-encoded result is
// stored via CompleteJob. Failures mark the job as failed with one of the
// codes "bad_input", "bad_response", "model_unavailable" or whatever
// classifyLLMError derives from the client error.
func (w *Worker) process(ctx context.Context, job *model.Job) {
	if job.TaskType == TaskTranscription {
		w.processTranscription(ctx, job)
		return
	}

	// Mirror the nil-transcriber guard in processTranscription: fail the job
	// cleanly instead of calling a method on a nil client (which would
	// presumably panic inside Chat — not visible from this file).
	if w.llm == nil {
		w.fail(ctx, job, "model_unavailable", "llm client not configured")
		return
	}

	var input llm.ChatInput
	if err := json.Unmarshal(job.Input, &input); err != nil {
		w.fail(ctx, job, "bad_input", err.Error())
		return
	}

	result, err := w.llm.Chat(ctx, input)
	if err != nil {
		w.fail(ctx, job, classifyLLMError(err), err.Error())
		return
	}

	body, err := json.Marshal(result)
	if err != nil {
		w.fail(ctx, job, "bad_response", err.Error())
		return
	}

	// A failed completion write is only logged; the job is left as the store
	// has it.
	if _, err := w.store.CompleteJob(ctx, job.ID, model.CompleteJob{Result: body}); err != nil {
		slog.Error("complete job failed", "job_id", job.ID, "error", err)
	}
}
// processTranscription runs one transcription job and records its outcome.
//
// The job fails with "provider_unavailable" when no transcriber is configured,
// "bad_input" when job.Input is not valid JSON for transcription.Input, the
// code chosen by classifyTranscriptionError when Transcribe fails, and
// "bad_response" when the result cannot be JSON-encoded. On success the
// encoded result is stored via CompleteJob; a store error is only logged.
func (w *Worker) processTranscription(ctx context.Context, job *model.Job) {
	if w.transcriber == nil {
		w.fail(ctx, job, "provider_unavailable", "transcription providers not configured")
		return
	}

	var req transcription.Input
	if decodeErr := json.Unmarshal(job.Input, &req); decodeErr != nil {
		w.fail(ctx, job, "bad_input", decodeErr.Error())
		return
	}

	transcript, err := w.transcriber.Transcribe(ctx, req)
	if err != nil {
		w.fail(ctx, job, classifyTranscriptionError(err), err.Error())
		return
	}

	payload, err := json.Marshal(transcript)
	if err != nil {
		w.fail(ctx, job, "bad_response", err.Error())
		return
	}

	_, err = w.store.CompleteJob(ctx, job.ID, model.CompleteJob{Result: payload})
	if err == nil {
		return
	}
	slog.Error("complete transcription job failed", "job_id", job.ID, "error", err)
}
// fail marks job as failed in the store with the given error code and message.
// If the store call itself fails, the error is logged and otherwise ignored.
func (w *Worker) fail(ctx context.Context, job *model.Job, code, message string) {
	failure := model.FailJob{ErrorCode: code, ErrorMessage: message}
	_, err := w.store.FailJob(ctx, job.ID, failure)
	if err == nil {
		return
	}
	slog.Error("fail job failed", "job_id", job.ID, "error", err)
}
// classifyTranscriptionError maps a transcription error to a coarse error code
// by case-insensitive substring matching on err.Error().
//
// It returns one of "timeout", "bad_input", "storage_error", "bad_audio",
// "provider_unavailable", "bad_response" or "unknown". A nil err yields
// "unknown". Cases are evaluated top to bottom and the first match wins, so
// their order is significant.
func classifyTranscriptionError(err error) string {
	if err == nil {
		return "unknown"
	}
	s := strings.ToLower(err.Error())
	switch {
	case strings.Contains(s, "context deadline exceeded") || strings.Contains(s, "timeout"):
		return "timeout"
	case strings.Contains(s, "audio_url is required"):
		return "bad_input"
	case strings.Contains(s, "request has expired") || strings.Contains(s, "accessdenied"):
		return "storage_error"
	case strings.Contains(s, "audio http 4") || strings.Contains(s, "audio is empty"):
		return "bad_audio"
	case strings.Contains(s, "audio download") || strings.Contains(s, "audio http 5"):
		return "storage_error"
	// A later, duplicate "audio transcription http 4" case that returned
	// "bad_input" could never match because this case comes first; it has been
	// removed. Transcription 4xx responses therefore classify as "bad_audio",
	// exactly as before.
	case strings.Contains(s, "audio transcription http 4") || strings.Contains(s, "invalid data") || strings.Contains(s, "could not decode"):
		return "bad_audio"
	case strings.Contains(s, "audio transcription http 5") || strings.Contains(s, "audio transcription do") || strings.Contains(s, "connection refused") || strings.Contains(s, "connection reset") || strings.Contains(s, "closed network connection"):
		return "provider_unavailable"
	// Generic decode failures; "could not decode" was already claimed above.
	case strings.Contains(s, "decode"):
		return "bad_response"
	default:
		return "unknown"
	}
}
// classifyLLMError maps an LLM client error to a coarse error code by
// case-insensitive substring matching on err.Error().
//
// It returns one of "timeout", "model_unavailable", "context_length",
// "bad_input", "bad_response" or "unknown". A nil err yields "unknown".
// Checks run in order and the first match wins; in particular the
// context-length check precedes the generic "llm http 4" check.
func classifyLLMError(err error) string {
	if err == nil {
		return "unknown"
	}
	msg := strings.ToLower(err.Error())

	// containsAny reports whether msg contains at least one of the needles.
	containsAny := func(needles ...string) bool {
		for _, needle := range needles {
			if strings.Contains(msg, needle) {
				return true
			}
		}
		return false
	}

	if containsAny("context deadline exceeded", "timeout") {
		return "timeout"
	}
	if containsAny("connection refused", "connection reset", "no route to host", "llm http 5") {
		return "model_unavailable"
	}
	if containsAny("maximum context length", "context length", "input_tokens") {
		return "context_length"
	}
	if containsAny("llm http 4", "messages are required") {
		return "bad_input"
	}
	if containsAny("llm decode", "empty choices") {
		return "bad_response"
	}
	return "unknown"
}