diff --git a/internal/store/retry_policy.go b/internal/store/retry_policy.go new file mode 100644 index 0000000..58a6a70 --- /dev/null +++ b/internal/store/retry_policy.go @@ -0,0 +1,22 @@ +package store + +import ( + "strings" + "time" +) + +type failRetryPolicy struct { + Retryable bool + Delay time.Duration +} + +func retryPolicyForError(errorCode string) failRetryPolicy { + switch strings.TrimSpace(errorCode) { + case "provider_unavailable", "model_unavailable", "timeout", "storage_error", "stale_worker": + return failRetryPolicy{Retryable: true, Delay: 30 * time.Second} + case "bad_response", "unknown": + return failRetryPolicy{Retryable: true, Delay: 2 * time.Minute} + default: + return failRetryPolicy{} + } +} diff --git a/internal/store/retry_policy_test.go b/internal/store/retry_policy_test.go new file mode 100644 index 0000000..429ba04 --- /dev/null +++ b/internal/store/retry_policy_test.go @@ -0,0 +1,37 @@ +package store + +import ( + "testing" + "time" +) + +func TestRetryPolicyForError(t *testing.T) { + tests := []struct { + name string + code string + retryable bool + delay time.Duration + }{ + {name: "provider unavailable", code: "provider_unavailable", retryable: true, delay: 30 * time.Second}, + {name: "model unavailable", code: "model_unavailable", retryable: true, delay: 30 * time.Second}, + {name: "timeout", code: "timeout", retryable: true, delay: 30 * time.Second}, + {name: "storage", code: "storage_error", retryable: true, delay: 30 * time.Second}, + {name: "bad response", code: "bad_response", retryable: true, delay: 2 * time.Minute}, + {name: "unknown", code: "unknown", retryable: true, delay: 2 * time.Minute}, + {name: "bad audio", code: "bad_audio"}, + {name: "bad input", code: "bad_input"}, + {name: "context length", code: "context_length"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := retryPolicyForError(tt.code) + if got.Retryable != tt.retryable { + t.Fatalf("Retryable = %v, want %v", got.Retryable, tt.retryable) + } + if got.Delay != tt.delay { + t.Fatalf("Delay = %s, want %s", got.Delay, tt.delay) + } + }) + } +} diff --git a/internal/store/store.go b/internal/store/store.go index 9899e6a..d5ff862 100644 --- a/internal/store/store.go +++ b/internal/store/store.go @@ -548,19 +548,23 @@ func (s *Store) FailJob(ctx context.Context, id uuid.UUID, in model.FailJob) (*m errorCode = "unknown" } errorMessage := strings.TrimSpace(in.ErrorMessage) + policy := retryPolicyForError(errorCode) const q = ` UPDATE ai_jobs -SET status = 'failed', +SET status = CASE WHEN $4 AND attempts < max_attempts THEN 'pending' ELSE 'failed' END, error_code = $2, error_message = $3, - completed_at = NOW(), - heartbeat_at = NOW(), + scheduled_at = CASE WHEN $4 AND attempts < max_attempts THEN NOW() + make_interval(secs => $5) ELSE scheduled_at END, + started_at = CASE WHEN $4 AND attempts < max_attempts THEN NULL ELSE started_at END, + completed_at = CASE WHEN $4 AND attempts < max_attempts THEN NULL ELSE NOW() END, + worker_id = NULL, + heartbeat_at = CASE WHEN $4 AND attempts < max_attempts THEN NULL ELSE NOW() END, updated_at = NOW() WHERE id = $1 AND status = 'running' RETURNING ` + jobSelectColumns + ` ` - job, err := scanJob(s.pool.QueryRow(ctx, q, id, errorCode, errorMessage)) + job, err := scanJob(s.pool.QueryRow(ctx, q, id, errorCode, errorMessage, policy.Retryable, int(policy.Delay.Seconds()))) if errors.Is(err, pgx.ErrNoRows) { return nil, nil }