Add generic AI job queue lifecycle
This commit is contained in:
@@ -39,10 +39,16 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
s.handleCreateJob(w, r)
|
||||
case r.Method == http.MethodPost && path == "/api/v1/jobs/batch":
|
||||
s.handleCreateBatch(w, r)
|
||||
case r.Method == http.MethodPost && path == "/api/v1/jobs/claim":
|
||||
s.handleClaimJobs(w, r)
|
||||
case r.Method == http.MethodGet && strings.HasPrefix(path, "/api/v1/jobs/"):
|
||||
s.handleGetJob(w, r, path)
|
||||
case r.Method == http.MethodPost && strings.HasPrefix(path, "/api/v1/jobs/") && strings.HasSuffix(path, "/retry"):
|
||||
s.handleRetryJob(w, r, path)
|
||||
case r.Method == http.MethodPost && strings.HasPrefix(path, "/api/v1/jobs/") && strings.HasSuffix(path, "/complete"):
|
||||
s.handleCompleteJob(w, r, path)
|
||||
case r.Method == http.MethodPost && strings.HasPrefix(path, "/api/v1/jobs/") && strings.HasSuffix(path, "/fail"):
|
||||
s.handleFailJob(w, r, path)
|
||||
case r.Method == http.MethodGet && path == "/api/v1/stats":
|
||||
s.handleStats(w, r)
|
||||
default:
|
||||
@@ -136,6 +142,26 @@ func (s *Server) handleCreateBatch(w http.ResponseWriter, r *http.Request) {
|
||||
writeJSON(w, http.StatusCreated, out)
|
||||
}
|
||||
|
||||
type claimJobsResponse struct {
|
||||
Jobs []*model.Job `json:"jobs"`
|
||||
}
|
||||
|
||||
func (s *Server) handleClaimJobs(w http.ResponseWriter, r *http.Request) {
|
||||
var req model.ClaimJobs
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "bad json")
|
||||
return
|
||||
}
|
||||
ctx, cancel := contextWithTimeout(r, 8*time.Second)
|
||||
defer cancel()
|
||||
jobs, err := s.store.ClaimJobs(ctx, req)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, claimJobsResponse{Jobs: jobs})
|
||||
}
|
||||
|
||||
func (s *Server) handleGetJob(w http.ResponseWriter, r *http.Request, path string) {
|
||||
id, err := jobIDFromPath(path, false)
|
||||
if err != nil {
|
||||
@@ -176,6 +202,56 @@ func (s *Server) handleRetryJob(w http.ResponseWriter, r *http.Request, path str
|
||||
writeJSON(w, http.StatusOK, job)
|
||||
}
|
||||
|
||||
func (s *Server) handleCompleteJob(w http.ResponseWriter, r *http.Request, path string) {
|
||||
id, err := jobIDFromActionPath(path, "complete")
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
var req model.CompleteJob
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "bad json")
|
||||
return
|
||||
}
|
||||
ctx, cancel := contextWithTimeout(r, 8*time.Second)
|
||||
defer cancel()
|
||||
job, err := s.store.CompleteJob(ctx, id, req)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
if job == nil {
|
||||
writeError(w, http.StatusNotFound, "running job not found")
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, job)
|
||||
}
|
||||
|
||||
func (s *Server) handleFailJob(w http.ResponseWriter, r *http.Request, path string) {
|
||||
id, err := jobIDFromActionPath(path, "fail")
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
var req model.FailJob
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "bad json")
|
||||
return
|
||||
}
|
||||
ctx, cancel := contextWithTimeout(r, 8*time.Second)
|
||||
defer cancel()
|
||||
job, err := s.store.FailJob(ctx, id, req)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
if job == nil {
|
||||
writeError(w, http.StatusNotFound, "running job not found")
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, job)
|
||||
}
|
||||
|
||||
func (s *Server) handleStats(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, cancel := contextWithTimeout(r, 8*time.Second)
|
||||
defer cancel()
|
||||
@@ -199,6 +275,16 @@ func jobIDFromPath(path string, retry bool) (uuid.UUID, error) {
|
||||
return id, nil
|
||||
}
|
||||
|
||||
func jobIDFromActionPath(path string, action string) (uuid.UUID, error) {
|
||||
raw := strings.TrimPrefix(path, "/api/v1/jobs/")
|
||||
raw = strings.TrimSuffix(raw, "/"+action)
|
||||
id, err := uuid.Parse(strings.Trim(raw, "/"))
|
||||
if err != nil {
|
||||
return uuid.Nil, errors.New("bad job id")
|
||||
}
|
||||
return id, nil
|
||||
}
|
||||
|
||||
func isValidationError(err error) bool {
|
||||
msg := err.Error()
|
||||
return strings.Contains(msg, " is required")
|
||||
|
||||
@@ -18,6 +18,8 @@ CREATE TABLE IF NOT EXISTS ai_jobs (
|
||||
scheduled_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
started_at TIMESTAMPTZ,
|
||||
completed_at TIMESTAMPTZ,
|
||||
worker_id TEXT,
|
||||
heartbeat_at TIMESTAMPTZ,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
idempotency_key TEXT
|
||||
|
||||
3
internal/migrate/sql/002_ai_jobs_worker_lease.up.sql
Normal file
3
internal/migrate/sql/002_ai_jobs_worker_lease.up.sql
Normal file
@@ -0,0 +1,3 @@
|
||||
ALTER TABLE ai_jobs
|
||||
ADD COLUMN IF NOT EXISTS worker_id TEXT,
|
||||
ADD COLUMN IF NOT EXISTS heartbeat_at TIMESTAMPTZ;
|
||||
@@ -32,6 +32,8 @@ type Job struct {
|
||||
ScheduledAt time.Time `json:"scheduled_at"`
|
||||
StartedAt *time.Time `json:"started_at,omitempty"`
|
||||
CompletedAt *time.Time `json:"completed_at,omitempty"`
|
||||
WorkerID *string `json:"worker_id,omitempty"`
|
||||
HeartbeatAt *time.Time `json:"heartbeat_at,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
IdempotencyKey *string `json:"idempotency_key,omitempty"`
|
||||
@@ -49,6 +51,22 @@ type CreateJob struct {
|
||||
IdempotencyKey *string `json:"idempotency_key,omitempty"`
|
||||
}
|
||||
|
||||
type ClaimJobs struct {
|
||||
WorkerID string `json:"worker_id"`
|
||||
TaskTypes []string `json:"task_types"`
|
||||
ModelProfiles []string `json:"model_profiles"`
|
||||
Limit int `json:"limit"`
|
||||
}
|
||||
|
||||
type CompleteJob struct {
|
||||
Result json.RawMessage `json:"result"`
|
||||
}
|
||||
|
||||
type FailJob struct {
|
||||
ErrorCode string `json:"error_code"`
|
||||
ErrorMessage string `json:"error_message"`
|
||||
}
|
||||
|
||||
type QueueStat struct {
|
||||
TaskType string `json:"task_type"`
|
||||
ModelProfile string `json:"model_profile"`
|
||||
|
||||
@@ -19,6 +19,13 @@ type Store struct {
|
||||
pool *pgxpool.Pool
|
||||
}
|
||||
|
||||
const jobSelectColumns = `
|
||||
id, owner_service, owner_ref, task_type, model_profile, priority, status,
|
||||
attempts, max_attempts, input, result, error_code, error_message,
|
||||
scheduled_at, started_at, completed_at, worker_id, heartbeat_at,
|
||||
created_at, updated_at, idempotency_key
|
||||
`
|
||||
|
||||
func Open(ctx context.Context, databaseURL string) (*Store, error) {
|
||||
if strings.TrimSpace(databaseURL) == "" {
|
||||
return nil, errors.New("DATABASE_URL is required")
|
||||
@@ -74,9 +81,7 @@ INSERT INTO ai_jobs (
|
||||
VALUES ($1,$2,$3,$4,$5,$6,$7,$8,$9)
|
||||
ON CONFLICT (idempotency_key) WHERE idempotency_key IS NOT NULL
|
||||
DO UPDATE SET updated_at = ai_jobs.updated_at
|
||||
RETURNING id, owner_service, owner_ref, task_type, model_profile, priority, status,
|
||||
attempts, max_attempts, input, result, error_code, error_message,
|
||||
scheduled_at, started_at, completed_at, created_at, updated_at, idempotency_key
|
||||
RETURNING ` + jobSelectColumns + `
|
||||
`
|
||||
row := s.pool.QueryRow(ctx, q,
|
||||
in.OwnerService,
|
||||
@@ -111,7 +116,8 @@ func (s *Store) GetJob(ctx context.Context, id uuid.UUID) (*model.Job, error) {
|
||||
const q = `
|
||||
SELECT id, owner_service, owner_ref, task_type, model_profile, priority, status,
|
||||
attempts, max_attempts, input, result, error_code, error_message,
|
||||
scheduled_at, started_at, completed_at, created_at, updated_at, idempotency_key
|
||||
scheduled_at, started_at, completed_at, worker_id, heartbeat_at,
|
||||
created_at, updated_at, idempotency_key
|
||||
FROM ai_jobs
|
||||
WHERE id = $1
|
||||
`
|
||||
@@ -131,13 +137,16 @@ SET status = 'pending',
|
||||
completed_at = NULL,
|
||||
error_code = NULL,
|
||||
error_message = NULL,
|
||||
worker_id = NULL,
|
||||
heartbeat_at = NULL,
|
||||
scheduled_at = NOW(),
|
||||
updated_at = NOW()
|
||||
WHERE id = $1
|
||||
AND status IN ('failed', 'running')
|
||||
RETURNING id, owner_service, owner_ref, task_type, model_profile, priority, status,
|
||||
attempts, max_attempts, input, result, error_code, error_message,
|
||||
scheduled_at, started_at, completed_at, created_at, updated_at, idempotency_key
|
||||
scheduled_at, started_at, completed_at, worker_id, heartbeat_at,
|
||||
created_at, updated_at, idempotency_key
|
||||
`
|
||||
job, err := scanJob(s.pool.QueryRow(ctx, q, id))
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
@@ -146,6 +155,115 @@ RETURNING id, owner_service, owner_ref, task_type, model_profile, priority, stat
|
||||
return job, err
|
||||
}
|
||||
|
||||
func (s *Store) ClaimJobs(ctx context.Context, in model.ClaimJobs) ([]*model.Job, error) {
|
||||
if in.Limit <= 0 {
|
||||
in.Limit = 1
|
||||
}
|
||||
if in.Limit > 100 {
|
||||
in.Limit = 100
|
||||
}
|
||||
workerID := strings.TrimSpace(in.WorkerID)
|
||||
if workerID == "" {
|
||||
workerID = "unknown"
|
||||
}
|
||||
if in.TaskTypes == nil {
|
||||
in.TaskTypes = []string{}
|
||||
}
|
||||
if in.ModelProfiles == nil {
|
||||
in.ModelProfiles = []string{}
|
||||
}
|
||||
const q = `
|
||||
WITH picked AS (
|
||||
SELECT id
|
||||
FROM ai_jobs
|
||||
WHERE status = 'pending'
|
||||
AND attempts < max_attempts
|
||||
AND scheduled_at <= NOW()
|
||||
AND (cardinality($1::text[]) = 0 OR task_type = ANY($1::text[]))
|
||||
AND (cardinality($2::text[]) = 0 OR model_profile = ANY($2::text[]))
|
||||
ORDER BY priority DESC, scheduled_at ASC, created_at ASC
|
||||
LIMIT $3
|
||||
FOR UPDATE SKIP LOCKED
|
||||
)
|
||||
UPDATE ai_jobs j
|
||||
SET status = 'running',
|
||||
attempts = j.attempts + 1,
|
||||
started_at = NOW(),
|
||||
completed_at = NULL,
|
||||
error_code = NULL,
|
||||
error_message = NULL,
|
||||
worker_id = $4,
|
||||
heartbeat_at = NOW(),
|
||||
updated_at = NOW()
|
||||
FROM picked
|
||||
WHERE j.id = picked.id
|
||||
RETURNING ` + jobSelectColumns + `
|
||||
`
|
||||
rows, err := s.pool.Query(ctx, q, in.TaskTypes, in.ModelProfiles, in.Limit, workerID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var out []*model.Job
|
||||
for rows.Next() {
|
||||
job, err := scanJob(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out = append(out, job)
|
||||
}
|
||||
return out, rows.Err()
|
||||
}
|
||||
|
||||
func (s *Store) CompleteJob(ctx context.Context, id uuid.UUID, in model.CompleteJob) (*model.Job, error) {
|
||||
if len(in.Result) == 0 {
|
||||
in.Result = json.RawMessage(`{}`)
|
||||
}
|
||||
const q = `
|
||||
UPDATE ai_jobs
|
||||
SET status = 'done',
|
||||
result = $2,
|
||||
error_code = NULL,
|
||||
error_message = NULL,
|
||||
completed_at = NOW(),
|
||||
heartbeat_at = NOW(),
|
||||
updated_at = NOW()
|
||||
WHERE id = $1
|
||||
AND status = 'running'
|
||||
RETURNING ` + jobSelectColumns + `
|
||||
`
|
||||
job, err := scanJob(s.pool.QueryRow(ctx, q, id, in.Result))
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil, nil
|
||||
}
|
||||
return job, err
|
||||
}
|
||||
|
||||
func (s *Store) FailJob(ctx context.Context, id uuid.UUID, in model.FailJob) (*model.Job, error) {
|
||||
errorCode := strings.TrimSpace(in.ErrorCode)
|
||||
if errorCode == "" {
|
||||
errorCode = "unknown"
|
||||
}
|
||||
errorMessage := strings.TrimSpace(in.ErrorMessage)
|
||||
const q = `
|
||||
UPDATE ai_jobs
|
||||
SET status = 'failed',
|
||||
error_code = $2,
|
||||
error_message = $3,
|
||||
completed_at = NOW(),
|
||||
heartbeat_at = NOW(),
|
||||
updated_at = NOW()
|
||||
WHERE id = $1
|
||||
AND status = 'running'
|
||||
RETURNING ` + jobSelectColumns + `
|
||||
`
|
||||
job, err := scanJob(s.pool.QueryRow(ctx, q, id, errorCode, errorMessage))
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil, nil
|
||||
}
|
||||
return job, err
|
||||
}
|
||||
|
||||
func (s *Store) Stats(ctx context.Context) (*model.Stats, error) {
|
||||
out := &model.Stats{At: time.Now().UTC()}
|
||||
|
||||
@@ -214,6 +332,8 @@ func scanJob(row pgx.Row) (*model.Job, error) {
|
||||
&job.ScheduledAt,
|
||||
&job.StartedAt,
|
||||
&job.CompletedAt,
|
||||
&job.WorkerID,
|
||||
&job.HeartbeatAt,
|
||||
&job.CreatedAt,
|
||||
&job.UpdatedAt,
|
||||
&job.IdempotencyKey,
|
||||
|
||||
Reference in New Issue
Block a user