Add transcription jobs to AI service
This commit is contained in:
@@ -10,17 +10,22 @@ import (
|
||||
"ai-service/internal/llm"
|
||||
"ai-service/internal/model"
|
||||
"ai-service/internal/store"
|
||||
"ai-service/internal/transcription"
|
||||
)
|
||||
|
||||
const (
|
||||
TaskLLMChat = "llm_chat"
|
||||
TaskChatCompletion = "chat_completion"
|
||||
TaskCallAnalysis = "call_analysis"
|
||||
TaskTranscription = "transcription"
|
||||
|
||||
TranscriptionProfile = "whisperx"
|
||||
)
|
||||
|
||||
type Worker struct {
|
||||
store *store.Store
|
||||
llm *llm.Client
|
||||
transcriber *transcription.Client
|
||||
workerID string
|
||||
modelProfile string
|
||||
pollInterval time.Duration
|
||||
@@ -28,7 +33,7 @@ type Worker struct {
|
||||
leaseTimeout time.Duration
|
||||
}
|
||||
|
||||
func New(store *store.Store, llmClient *llm.Client, workerID, modelProfile string, pollInterval, leaseTimeout time.Duration, claimLimit int) *Worker {
|
||||
func New(store *store.Store, llmClient *llm.Client, transcriber *transcription.Client, workerID, modelProfile string, pollInterval, leaseTimeout time.Duration, claimLimit int) *Worker {
|
||||
if pollInterval <= 0 {
|
||||
pollInterval = 2 * time.Second
|
||||
}
|
||||
@@ -44,6 +49,7 @@ func New(store *store.Store, llmClient *llm.Client, workerID, modelProfile strin
|
||||
return &Worker{
|
||||
store: store,
|
||||
llm: llmClient,
|
||||
transcriber: transcriber,
|
||||
workerID: workerID,
|
||||
modelProfile: modelProfile,
|
||||
pollInterval: pollInterval,
|
||||
@@ -73,8 +79,8 @@ func (w *Worker) tick(ctx context.Context) {
|
||||
}
|
||||
jobs, err := w.store.ClaimJobs(ctx, model.ClaimJobs{
|
||||
WorkerID: w.workerID,
|
||||
TaskTypes: []string{TaskLLMChat, TaskChatCompletion, TaskCallAnalysis},
|
||||
ModelProfiles: []string{w.modelProfile},
|
||||
TaskTypes: []string{TaskLLMChat, TaskChatCompletion, TaskCallAnalysis, TaskTranscription},
|
||||
ModelProfiles: []string{w.modelProfile, TranscriptionProfile},
|
||||
Limit: w.claimLimit,
|
||||
})
|
||||
if err != nil {
|
||||
@@ -87,6 +93,10 @@ func (w *Worker) tick(ctx context.Context) {
|
||||
}
|
||||
|
||||
func (w *Worker) process(ctx context.Context, job *model.Job) {
|
||||
if job.TaskType == TaskTranscription {
|
||||
w.processTranscription(ctx, job)
|
||||
return
|
||||
}
|
||||
var input llm.ChatInput
|
||||
if err := json.Unmarshal(job.Input, &input); err != nil {
|
||||
w.fail(ctx, job, "bad_input", err.Error())
|
||||
@@ -107,12 +117,62 @@ func (w *Worker) process(ctx context.Context, job *model.Job) {
|
||||
}
|
||||
}
|
||||
|
||||
func (w *Worker) processTranscription(ctx context.Context, job *model.Job) {
|
||||
if w.transcriber == nil {
|
||||
w.fail(ctx, job, "provider_unavailable", "whisperx not configured")
|
||||
return
|
||||
}
|
||||
var input transcription.Input
|
||||
if err := json.Unmarshal(job.Input, &input); err != nil {
|
||||
w.fail(ctx, job, "bad_input", err.Error())
|
||||
return
|
||||
}
|
||||
result, err := w.transcriber.Transcribe(ctx, input)
|
||||
if err != nil {
|
||||
w.fail(ctx, job, classifyTranscriptionError(err), err.Error())
|
||||
return
|
||||
}
|
||||
body, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
w.fail(ctx, job, "bad_response", err.Error())
|
||||
return
|
||||
}
|
||||
if _, err := w.store.CompleteJob(ctx, job.ID, model.CompleteJob{Result: body}); err != nil {
|
||||
slog.Error("complete transcription job failed", "job_id", job.ID, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (w *Worker) fail(ctx context.Context, job *model.Job, code, message string) {
|
||||
if _, err := w.store.FailJob(ctx, job.ID, model.FailJob{ErrorCode: code, ErrorMessage: message}); err != nil {
|
||||
slog.Error("fail job failed", "job_id", job.ID, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func classifyTranscriptionError(err error) string {
|
||||
if err == nil {
|
||||
return "unknown"
|
||||
}
|
||||
s := strings.ToLower(err.Error())
|
||||
switch {
|
||||
case strings.Contains(s, "context deadline exceeded") || strings.Contains(s, "timeout"):
|
||||
return "timeout"
|
||||
case strings.Contains(s, "audio_url is required"):
|
||||
return "bad_input"
|
||||
case strings.Contains(s, "audio http 4") || strings.Contains(s, "audio is empty"):
|
||||
return "bad_audio"
|
||||
case strings.Contains(s, "audio download") || strings.Contains(s, "audio http 5"):
|
||||
return "storage_error"
|
||||
case strings.Contains(s, "whisperx http 4") || strings.Contains(s, "ffmpeg") || strings.Contains(s, "invalid data") || strings.Contains(s, "could not decode"):
|
||||
return "bad_audio"
|
||||
case strings.Contains(s, "whisperx do") || strings.Contains(s, "connection refused") || strings.Contains(s, "connection reset") || strings.Contains(s, "closed network connection"):
|
||||
return "provider_unavailable"
|
||||
case strings.Contains(s, "decode"):
|
||||
return "bad_response"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
func classifyLLMError(err error) string {
|
||||
if err == nil {
|
||||
return "unknown"
|
||||
|
||||
Reference in New Issue
Block a user