218 lines
6.1 KiB
Go
218 lines
6.1 KiB
Go
package worker
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"log/slog"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"ai-service/internal/llm"
|
|
"ai-service/internal/model"
|
|
"ai-service/internal/store"
|
|
"ai-service/internal/transcription"
|
|
)
|
|
|
|
// Task types understood by the worker. process hands TaskTranscription jobs
// to processTranscription; every other task type goes to the LLM chat client.
const (
	TaskLLMChat        = "llm_chat"
	TaskChatCompletion = "chat_completion"
	TaskCallAnalysis   = "call_analysis"
	TaskTranscription  = "transcription"
)

// TranscriptionProfile is the model profile that New appends to the default
// claim filter (next to the configured LLM profile) when no explicit model
// profiles are supplied.
const TranscriptionProfile = "whisper-large-v3"
|
|
|
|
type Worker struct {
|
|
store *store.Store
|
|
llm *llm.Client
|
|
transcriber *transcription.Client
|
|
workerID string
|
|
modelProfile string
|
|
taskTypes []string
|
|
modelProfiles []string
|
|
pollInterval time.Duration
|
|
claimLimit int
|
|
leaseTimeout time.Duration
|
|
}
|
|
|
|
func New(store *store.Store, llmClient *llm.Client, transcriber *transcription.Client, workerID, modelProfile string, taskTypes, modelProfiles []string, pollInterval, leaseTimeout time.Duration, claimLimit int) *Worker {
|
|
if pollInterval <= 0 {
|
|
pollInterval = 2 * time.Second
|
|
}
|
|
if leaseTimeout <= 0 {
|
|
leaseTimeout = 15 * time.Minute
|
|
}
|
|
if claimLimit <= 0 {
|
|
claimLimit = 4
|
|
}
|
|
if strings.TrimSpace(workerID) == "" {
|
|
workerID = "ai-service-worker"
|
|
}
|
|
if len(modelProfiles) == 0 {
|
|
modelProfiles = []string{modelProfile, TranscriptionProfile}
|
|
}
|
|
return &Worker{
|
|
store: store,
|
|
llm: llmClient,
|
|
transcriber: transcriber,
|
|
workerID: workerID,
|
|
modelProfile: modelProfile,
|
|
taskTypes: taskTypes,
|
|
modelProfiles: modelProfiles,
|
|
pollInterval: pollInterval,
|
|
claimLimit: claimLimit,
|
|
leaseTimeout: leaseTimeout,
|
|
}
|
|
}
|
|
|
|
func (w *Worker) Run(ctx context.Context) {
|
|
ticker := time.NewTicker(w.pollInterval)
|
|
defer ticker.Stop()
|
|
for {
|
|
w.tick(ctx)
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case <-ticker.C:
|
|
}
|
|
}
|
|
}
|
|
|
|
func (w *Worker) tick(ctx context.Context) {
|
|
if reset, err := w.store.RequeueStaleRunning(ctx, w.leaseTimeout, 100); err != nil {
|
|
slog.Error("requeue stale jobs failed", "error", err)
|
|
} else if reset > 0 {
|
|
slog.Warn("requeued stale jobs", "count", reset)
|
|
}
|
|
jobs, err := w.store.ClaimJobs(ctx, model.ClaimJobs{
|
|
WorkerID: w.workerID,
|
|
TaskTypes: w.taskTypes,
|
|
ModelProfiles: w.modelProfiles,
|
|
Limit: w.claimLimit,
|
|
})
|
|
if err != nil {
|
|
slog.Error("claim jobs failed", "error", err)
|
|
return
|
|
}
|
|
if len(jobs) > 1 {
|
|
var wg sync.WaitGroup
|
|
wg.Add(len(jobs))
|
|
for _, job := range jobs {
|
|
go func(job *model.Job) {
|
|
defer wg.Done()
|
|
w.process(ctx, job)
|
|
}(job)
|
|
}
|
|
wg.Wait()
|
|
return
|
|
}
|
|
for _, job := range jobs {
|
|
w.process(ctx, job)
|
|
}
|
|
}
|
|
|
|
func (w *Worker) process(ctx context.Context, job *model.Job) {
|
|
if job.TaskType == TaskTranscription {
|
|
w.processTranscription(ctx, job)
|
|
return
|
|
}
|
|
var input llm.ChatInput
|
|
if err := json.Unmarshal(job.Input, &input); err != nil {
|
|
w.fail(ctx, job, "bad_input", err.Error())
|
|
return
|
|
}
|
|
result, err := w.llm.Chat(ctx, input)
|
|
if err != nil {
|
|
w.fail(ctx, job, classifyLLMError(err), err.Error())
|
|
return
|
|
}
|
|
body, err := json.Marshal(result)
|
|
if err != nil {
|
|
w.fail(ctx, job, "bad_response", err.Error())
|
|
return
|
|
}
|
|
if _, err := w.store.CompleteJob(ctx, job.ID, model.CompleteJob{Result: body}); err != nil {
|
|
slog.Error("complete job failed", "job_id", job.ID, "error", err)
|
|
}
|
|
}
|
|
|
|
func (w *Worker) processTranscription(ctx context.Context, job *model.Job) {
|
|
if w.transcriber == nil {
|
|
w.fail(ctx, job, "provider_unavailable", "transcription providers not configured")
|
|
return
|
|
}
|
|
var input transcription.Input
|
|
if err := json.Unmarshal(job.Input, &input); err != nil {
|
|
w.fail(ctx, job, "bad_input", err.Error())
|
|
return
|
|
}
|
|
result, err := w.transcriber.Transcribe(ctx, input)
|
|
if err != nil {
|
|
w.fail(ctx, job, classifyTranscriptionError(err), err.Error())
|
|
return
|
|
}
|
|
body, err := json.Marshal(result)
|
|
if err != nil {
|
|
w.fail(ctx, job, "bad_response", err.Error())
|
|
return
|
|
}
|
|
if _, err := w.store.CompleteJob(ctx, job.ID, model.CompleteJob{Result: body}); err != nil {
|
|
slog.Error("complete transcription job failed", "job_id", job.ID, "error", err)
|
|
}
|
|
}
|
|
|
|
func (w *Worker) fail(ctx context.Context, job *model.Job, code, message string) {
|
|
if _, err := w.store.FailJob(ctx, job.ID, model.FailJob{ErrorCode: code, ErrorMessage: message}); err != nil {
|
|
slog.Error("fail job failed", "job_id", job.ID, "error", err)
|
|
}
|
|
}
|
|
|
|
// classifyTranscriptionError maps an error from the transcription pipeline to
// the job error code recorded by fail.
//
// Rules match lower-cased substrings of the error text, and the first
// matching rule wins. More specific patterns must therefore precede broader
// ones: 429 before the generic "http 4" prefixes, and "could not decode"
// before "decode".
//
// Codes:
//   - timeout: the call timed out.
//   - bad_input: the request is missing required fields.
//   - storage_error: fetching the source audio failed.
//   - bad_audio: the audio was rejected or unreadable.
//   - provider_unavailable: the transcription backend is unreachable,
//     failing, or throttling.
//   - bad_response: the backend's reply could not be decoded.
//   - unknown: anything else, including a nil error.
func classifyTranscriptionError(err error) string {
	if err == nil {
		return "unknown"
	}
	s := strings.ToLower(err.Error())
	has := func(subs ...string) bool {
		for _, sub := range subs {
			if strings.Contains(s, sub) {
				return true
			}
		}
		return false
	}
	switch {
	case has("context deadline exceeded", "timeout"):
		return "timeout"
	case has("audio_url is required"):
		return "bad_input"
	// HTTP 429 is throttling, not a defect in the audio. Report it as a
	// transient failure of whichever side throttled instead of letting the
	// generic 4xx rules below turn it into bad_audio.
	case has("audio http 429"):
		return "storage_error"
	case has("audio transcription http 429"):
		return "provider_unavailable"
	case has("request has expired", "accessdenied"):
		return "storage_error"
	case has("audio http 4", "audio is empty"):
		return "bad_audio"
	case has("audio download", "audio http 5"):
		return "storage_error"
	// The transcription backend rejected the audio. This case previously had
	// an unreachable duplicate mapping "audio transcription http 4" to
	// bad_input; it was removed, keeping the behavior that actually ran.
	case has("audio transcription http 4", "invalid data", "could not decode"):
		return "bad_audio"
	case has("audio transcription http 5", "audio transcription do",
		"connection refused", "connection reset", "no route to host", "closed network connection"):
		return "provider_unavailable"
	case has("decode"):
		return "bad_response"
	default:
		return "unknown"
	}
}
|
|
|
|
// classifyLLMError maps an error from the LLM chat call to the job error code
// recorded by fail.
//
// Rules match lower-cased substrings of the error text, and the first
// matching rule wins. The 429 rule must therefore precede the generic
// "llm http 4" rule.
//
// Codes:
//   - timeout: the call timed out.
//   - model_unavailable: the model is unreachable, throttling (HTTP 429), or
//     returning 5xx.
//   - bad_input: other 4xx responses, or no messages supplied.
//   - bad_response: the reply could not be decoded or had no choices.
//   - unknown: anything else, including a nil error.
func classifyLLMError(err error) string {
	if err == nil {
		return "unknown"
	}
	s := strings.ToLower(err.Error())
	has := func(subs ...string) bool {
		for _, sub := range subs {
			if strings.Contains(s, sub) {
				return true
			}
		}
		return false
	}
	switch {
	case has("context deadline exceeded", "timeout"):
		return "timeout"
	// HTTP 429 is a transient rate limit, not a malformed request. Without
	// this rule, "llm http 4" below would classify it as bad_input.
	case has("llm http 429", "llm http 5",
		"connection refused", "connection reset", "no route to host", "closed network connection"):
		return "model_unavailable"
	case has("llm http 4", "messages are required"):
		return "bad_input"
	case has("llm decode", "empty choices"):
		return "bad_response"
	default:
		return "unknown"
	}
}
|