package store import ( "context" "encoding/json" "errors" "fmt" "strings" "time" "ai-service/internal/model" "github.com/google/uuid" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" ) type Store struct { pool *pgxpool.Pool } const jobSelectColumns = ` id, owner_service, owner_ref, task_type, model_profile, priority, status, attempts, max_attempts, input, result, error_code, error_message, scheduled_at, started_at, completed_at, worker_id, heartbeat_at, created_at, updated_at, idempotency_key ` const jobSelectColumnsFromJobAlias = ` j.id, j.owner_service, j.owner_ref, j.task_type, j.model_profile, j.priority, j.status, j.attempts, j.max_attempts, j.input, j.result, j.error_code, j.error_message, j.scheduled_at, j.started_at, j.completed_at, j.worker_id, j.heartbeat_at, j.created_at, j.updated_at, j.idempotency_key ` func Open(ctx context.Context, databaseURL string) (*Store, error) { if strings.TrimSpace(databaseURL) == "" { return nil, errors.New("DATABASE_URL is required") } cfg, err := pgxpool.ParseConfig(databaseURL) if err != nil { return nil, fmt.Errorf("parse database url: %w", err) } pool, err := pgxpool.NewWithConfig(ctx, cfg) if err != nil { return nil, fmt.Errorf("connect postgres: %w", err) } if err := pool.Ping(ctx); err != nil { pool.Close() return nil, fmt.Errorf("ping postgres: %w", err) } return &Store{pool: pool}, nil } func (s *Store) Close() { s.pool.Close() } func (s *Store) Ping(ctx context.Context) error { return s.pool.Ping(ctx) } func (s *Store) Exec(ctx context.Context, sql string, args ...any) error { _, err := s.pool.Exec(ctx, sql, args...) return err } func (s *Store) CreateJob(ctx context.Context, in model.CreateJob) (*model.Job, error) { if err := validateCreateJob(in); err != nil { return nil, err } normalizeCreateJob(&in) const q = ` INSERT INTO ai_jobs ( owner_service, owner_ref, task_type, model_profile, priority, max_attempts, input, scheduled_at, idempotency_key ) VALUES ($1,$2,$3,$4,$5,$6,$7,$8,$9) ON CONFLICT (idempotency_key) WHERE idempotency_key IS NOT NULL DO UPDATE SET updated_at = ai_jobs.updated_at RETURNING ` + jobSelectColumns + ` ` row := s.pool.QueryRow(ctx, q, in.OwnerService, in.OwnerRef, in.TaskType, in.ModelProfile, in.Priority, in.MaxAttempts, in.Input, *in.ScheduledAt, in.IdempotencyKey, ) return scanJob(row) } func (s *Store) CreateJobs(ctx context.Context, items []model.CreateJob) ([]*model.Job, error) { if len(items) == 0 { return []*model.Job{}, nil } const q = ` INSERT INTO ai_jobs ( owner_service, owner_ref, task_type, model_profile, priority, max_attempts, input, scheduled_at, idempotency_key ) VALUES ($1,$2,$3,$4,$5,$6,$7,$8,$9) ON CONFLICT (idempotency_key) WHERE idempotency_key IS NOT NULL DO UPDATE SET updated_at = ai_jobs.updated_at RETURNING ` + jobSelectColumns + ` ` var batch pgx.Batch for i := range items { if err := validateCreateJob(items[i]); err != nil { return nil, err } normalizeCreateJob(&items[i]) batch.Queue(q, items[i].OwnerService, items[i].OwnerRef, items[i].TaskType, items[i].ModelProfile, items[i].Priority, items[i].MaxAttempts, items[i].Input, *items[i].ScheduledAt, items[i].IdempotencyKey, ) } br := s.pool.SendBatch(ctx, &batch) batchClosed := false defer func() { if !batchClosed { _ = br.Close() } }() out := make([]*model.Job, 0, len(items)) for range items { job, err := scanJob(br.QueryRow()) if err != nil { return nil, err } out = append(out, job) } err := br.Close() batchClosed = true if err != nil { return nil, err } return out, nil } func normalizeCreateJob(in *model.CreateJob) { if in.MaxAttempts <= 0 { in.MaxAttempts = 3 } if len(in.Input) == 0 { in.Input = json.RawMessage(`{}`) } scheduledAt := time.Now().UTC() if in.ScheduledAt != nil { scheduledAt = in.ScheduledAt.UTC() } in.ScheduledAt = &scheduledAt } func validateCreateJob(in model.CreateJob) error { switch { case strings.TrimSpace(in.OwnerService) == "": return errors.New("owner_service is required") case strings.TrimSpace(in.OwnerRef) == "": return errors.New("owner_ref is required") case strings.TrimSpace(in.TaskType) == "": return errors.New("task_type is required") case strings.TrimSpace(in.ModelProfile) == "": return errors.New("model_profile is required") default: return nil } } func (s *Store) GetJob(ctx context.Context, id uuid.UUID) (*model.Job, error) { const q = ` SELECT id, owner_service, owner_ref, task_type, model_profile, priority, status, attempts, max_attempts, input, result, error_code, error_message, scheduled_at, started_at, completed_at, worker_id, heartbeat_at, created_at, updated_at, idempotency_key FROM ai_jobs WHERE id = $1 ` job, err := scanJob(s.pool.QueryRow(ctx, q, id)) if errors.Is(err, pgx.ErrNoRows) { return nil, nil } return job, err } func (s *Store) ListJobs(ctx context.Context, filter model.JobFilter) ([]*model.Job, error) { normalizeFilter(&filter) const q = ` SELECT ` + jobSelectColumns + ` FROM ai_jobs WHERE ($1 = '' OR owner_service = $1) AND ($2 = '' OR owner_ref = $2) AND ($3 = '' OR task_type = $3) AND ($4 = '' OR model_profile = $4) AND (cardinality($5::text[]) = 0 OR status = ANY($5::text[])) AND (cardinality($6::text[]) = 0 OR COALESCE(NULLIF(error_code, ''), 'unknown') = ANY($6::text[])) ORDER BY created_at DESC LIMIT $7 OFFSET $8 ` rows, err := s.pool.Query(ctx, q, filter.OwnerService, filter.OwnerRef, filter.TaskType, filter.ModelProfile, filter.Statuses, filter.ErrorCodes, filter.Limit, filter.Offset, ) if err != nil { return nil, err } defer rows.Close() var out []*model.Job for rows.Next() { job, err := scanJob(rows) if err != nil { return nil, err } out = append(out, job) } return out, rows.Err() } func (s *Store) ListJobSummaries(ctx context.Context, filter model.JobFilter) ([]*model.JobSummary, error) { normalizeFilter(&filter) const q = ` SELECT id, owner_service, owner_ref, task_type, model_profile, priority, status, attempts, max_attempts, error_code, error_message, scheduled_at, started_at, completed_at, worker_id, heartbeat_at, created_at, updated_at FROM ai_jobs WHERE ($1 = '' OR owner_service = $1) AND ($2 = '' OR owner_ref = $2) AND ($3 = '' OR task_type = $3) AND ($4 = '' OR model_profile = $4) AND (cardinality($5::text[]) = 0 OR status = ANY($5::text[])) AND (cardinality($6::text[]) = 0 OR COALESCE(NULLIF(error_code, ''), 'unknown') = ANY($6::text[])) ORDER BY updated_at DESC, created_at DESC LIMIT $7 OFFSET $8 ` rows, err := s.pool.Query(ctx, q, filter.OwnerService, filter.OwnerRef, filter.TaskType, filter.ModelProfile, filter.Statuses, filter.ErrorCodes, filter.Limit, filter.Offset, ) if err != nil { return nil, err } defer rows.Close() var out []*model.JobSummary for rows.Next() { job, err := scanJobSummary(rows) if err != nil { return nil, err } out = append(out, job) } return out, rows.Err() } func normalizeFilter(filter *model.JobFilter) { filter.OwnerService = strings.TrimSpace(filter.OwnerService) filter.OwnerRef = strings.TrimSpace(filter.OwnerRef) filter.TaskType = strings.TrimSpace(filter.TaskType) filter.ModelProfile = strings.TrimSpace(filter.ModelProfile) if filter.Statuses == nil { filter.Statuses = []string{} } if filter.ErrorCodes == nil { filter.ErrorCodes = []string{} } if filter.Limit <= 0 { filter.Limit = 100 } if filter.Limit > 500 { filter.Limit = 500 } if filter.Offset < 0 { filter.Offset = 0 } } func (s *Store) RetryJob(ctx context.Context, id uuid.UUID) (*model.Job, error) { const q = ` UPDATE ai_jobs SET status = 'pending', attempts = 0, started_at = NULL, completed_at = NULL, error_code = NULL, error_message = NULL, worker_id = NULL, heartbeat_at = NULL, scheduled_at = NOW(), updated_at = NOW() WHERE id = $1 AND status IN ('failed', 'running') RETURNING id, owner_service, owner_ref, task_type, model_profile, priority, status, attempts, max_attempts, input, result, error_code, error_message, scheduled_at, started_at, completed_at, worker_id, heartbeat_at, created_at, updated_at, idempotency_key ` job, err := scanJob(s.pool.QueryRow(ctx, q, id)) if errors.Is(err, pgx.ErrNoRows) { return nil, nil } return job, err } func (s *Store) RetryJobs(ctx context.Context, filter model.JobFilter) (int, error) { normalizeFilter(&filter) const q = ` WITH picked AS ( SELECT id FROM ai_jobs WHERE status IN ('failed', 'running') AND ($1 = '' OR owner_service = $1) AND ($2 = '' OR owner_ref = $2) AND ($3 = '' OR task_type = $3) AND ($4 = '' OR model_profile = $4) AND (cardinality($5::text[]) = 0 OR COALESCE(NULLIF(error_code, ''), 'unknown') = ANY($5::text[])) ORDER BY updated_at ASC LIMIT $6 ) UPDATE ai_jobs j SET status = 'pending', attempts = 0, started_at = NULL, completed_at = NULL, error_code = NULL, error_message = NULL, worker_id = NULL, heartbeat_at = NULL, scheduled_at = NOW(), updated_at = NOW() FROM picked WHERE j.id = picked.id ` tag, err := s.pool.Exec(ctx, q, filter.OwnerService, filter.OwnerRef, filter.TaskType, filter.ModelProfile, filter.ErrorCodes, filter.Limit, ) if err != nil { return 0, err } return int(tag.RowsAffected()), nil } func (s *Store) CancelJobs(ctx context.Context, filter model.JobFilter) (int, error) { normalizeFilter(&filter) const q = ` WITH picked AS ( SELECT id FROM ai_jobs WHERE status IN ('pending', 'running') AND ($1 = '' OR owner_service = $1) AND ($2 = '' OR owner_ref = $2) AND ($3 = '' OR task_type = $3) AND ($4 = '' OR model_profile = $4) AND (cardinality($5::text[]) = 0 OR status = ANY($5::text[])) ORDER BY updated_at ASC LIMIT $6 ) UPDATE ai_jobs j SET status = 'cancelled', completed_at = NOW(), worker_id = NULL, heartbeat_at = NULL, updated_at = NOW() FROM picked WHERE j.id = picked.id ` tag, err := s.pool.Exec(ctx, q, filter.OwnerService, filter.OwnerRef, filter.TaskType, filter.ModelProfile, filter.Statuses, filter.Limit, ) if err != nil { return 0, err } return int(tag.RowsAffected()), nil } func (s *Store) ClaimJobs(ctx context.Context, in model.ClaimJobs) ([]*model.Job, error) { if in.Limit <= 0 { in.Limit = 1 } if in.Limit > 100 { in.Limit = 100 } workerID := strings.TrimSpace(in.WorkerID) if workerID == "" { workerID = "unknown" } if in.TaskTypes == nil { in.TaskTypes = []string{} } if in.ModelProfiles == nil { in.ModelProfiles = []string{} } const q = ` WITH picked AS ( SELECT id FROM ai_jobs WHERE status = 'pending' AND attempts < max_attempts AND scheduled_at <= NOW() AND (cardinality($1::text[]) = 0 OR task_type = ANY($1::text[])) AND (cardinality($2::text[]) = 0 OR model_profile = ANY($2::text[])) ORDER BY priority DESC, scheduled_at DESC, created_at DESC LIMIT $3 FOR UPDATE SKIP LOCKED ) UPDATE ai_jobs j SET status = 'running', attempts = j.attempts + 1, started_at = NOW(), completed_at = NULL, error_code = NULL, error_message = NULL, worker_id = $4, heartbeat_at = NOW(), updated_at = NOW() FROM picked WHERE j.id = picked.id RETURNING ` + jobSelectColumnsFromJobAlias + ` ` rows, err := s.pool.Query(ctx, q, in.TaskTypes, in.ModelProfiles, in.Limit, workerID) if err != nil { return nil, err } defer rows.Close() var out []*model.Job for rows.Next() { job, err := scanJob(rows) if err != nil { return nil, err } out = append(out, job) } return out, rows.Err() } func (s *Store) CompleteJob(ctx context.Context, id uuid.UUID, in model.CompleteJob) (*model.Job, error) { if len(in.Result) == 0 { in.Result = json.RawMessage(`{}`) } const q = ` UPDATE ai_jobs SET status = 'done', result = $2, error_code = NULL, error_message = NULL, completed_at = NOW(), heartbeat_at = NOW(), updated_at = NOW() WHERE id = $1 AND status = 'running' RETURNING ` + jobSelectColumns + ` ` job, err := scanJob(s.pool.QueryRow(ctx, q, id, in.Result)) if errors.Is(err, pgx.ErrNoRows) { return nil, nil } return job, err } func (s *Store) FailJob(ctx context.Context, id uuid.UUID, in model.FailJob) (*model.Job, error) { errorCode := strings.TrimSpace(in.ErrorCode) if errorCode == "" { errorCode = "unknown" } errorMessage := strings.TrimSpace(in.ErrorMessage) const q = ` UPDATE ai_jobs SET status = 'failed', error_code = $2, error_message = $3, completed_at = NOW(), heartbeat_at = NOW(), updated_at = NOW() WHERE id = $1 AND status = 'running' RETURNING ` + jobSelectColumns + ` ` job, err := scanJob(s.pool.QueryRow(ctx, q, id, errorCode, errorMessage)) if errors.Is(err, pgx.ErrNoRows) { return nil, nil } return job, err } func (s *Store) RequeueStaleRunning(ctx context.Context, olderThan time.Duration, limit int) (int, error) { if olderThan <= 0 { olderThan = 15 * time.Minute } if limit <= 0 { limit = 100 } if limit > 1000 { limit = 1000 } const q = ` WITH picked AS ( SELECT id FROM ai_jobs WHERE status = 'running' AND COALESCE(heartbeat_at, started_at, updated_at) < NOW() - make_interval(secs => $1) ORDER BY COALESCE(heartbeat_at, started_at, updated_at) ASC LIMIT $2 ) UPDATE ai_jobs j SET status = CASE WHEN j.attempts < j.max_attempts THEN 'pending' ELSE 'failed' END, error_code = CASE WHEN j.attempts < j.max_attempts THEN NULL ELSE 'stale_worker' END, error_message = CASE WHEN j.attempts < j.max_attempts THEN NULL ELSE 'worker lease expired' END, worker_id = NULL, heartbeat_at = NULL, completed_at = CASE WHEN j.attempts < j.max_attempts THEN NULL ELSE NOW() END, scheduled_at = CASE WHEN j.attempts < j.max_attempts THEN NOW() ELSE j.scheduled_at END, updated_at = NOW() FROM picked WHERE j.id = picked.id ` tag, err := s.pool.Exec(ctx, q, int(olderThan.Seconds()), limit) if err != nil { return 0, err } return int(tag.RowsAffected()), nil } func (s *Store) Stats(ctx context.Context) (*model.Stats, error) { out := &model.Stats{At: time.Now().UTC()} queueRows, err := s.pool.Query(ctx, ` SELECT task_type, model_profile, status, count(*) FROM ai_jobs GROUP BY task_type, model_profile, status ORDER BY task_type, model_profile, status `) if err != nil { return nil, err } defer queueRows.Close() for queueRows.Next() { var stat model.QueueStat if err := queueRows.Scan(&stat.TaskType, &stat.ModelProfile, &stat.Status, &stat.Total); err != nil { return nil, err } out.Queues = append(out.Queues, stat) } if err := queueRows.Err(); err != nil { return nil, err } ownerRows, err := s.pool.Query(ctx, ` SELECT owner_service, task_type, model_profile, status, count(*) FROM ai_jobs GROUP BY owner_service, task_type, model_profile, status ORDER BY owner_service, task_type, model_profile, status `) if err != nil { return nil, err } defer ownerRows.Close() for ownerRows.Next() { var stat model.OwnerStat if err := ownerRows.Scan(&stat.OwnerService, &stat.TaskType, &stat.ModelProfile, &stat.Status, &stat.Total); err != nil { return nil, err } out.Owners = append(out.Owners, stat) } if err := ownerRows.Err(); err != nil { return nil, err } errorRows, err := s.pool.Query(ctx, ` SELECT owner_service, task_type, model_profile, COALESCE(NULLIF(error_code, ''), 'unknown') AS error_code, count(*) AS total, count(*) FILTER (WHERE updated_at > NOW() - INTERVAL '24 hours') AS last_24h FROM ai_jobs WHERE status = 'failed' GROUP BY owner_service, task_type, model_profile, COALESCE(NULLIF(error_code, ''), 'unknown') ORDER BY owner_service, last_24h DESC, total DESC `) if err != nil { return nil, err } defer errorRows.Close() for errorRows.Next() { var stat model.ErrorStat if err := errorRows.Scan(&stat.OwnerService, &stat.TaskType, &stat.ModelProfile, &stat.ErrorCode, &stat.Total, &stat.Last24h); err != nil { return nil, err } out.Errors = append(out.Errors, stat) } return out, errorRows.Err() } func scanJobSummary(row pgx.Row) (*model.JobSummary, error) { var job model.JobSummary err := row.Scan( &job.ID, &job.OwnerService, &job.OwnerRef, &job.TaskType, &job.ModelProfile, &job.Priority, &job.Status, &job.Attempts, &job.MaxAttempts, &job.ErrorCode, &job.ErrorMessage, &job.ScheduledAt, &job.StartedAt, &job.CompletedAt, &job.WorkerID, &job.HeartbeatAt, &job.CreatedAt, &job.UpdatedAt, ) if err != nil { return nil, err } return &job, nil } func scanJob(row pgx.Row) (*model.Job, error) { var job model.Job var input []byte var result []byte err := row.Scan( &job.ID, &job.OwnerService, &job.OwnerRef, &job.TaskType, &job.ModelProfile, &job.Priority, &job.Status, &job.Attempts, &job.MaxAttempts, &input, &result, &job.ErrorCode, &job.ErrorMessage, &job.ScheduledAt, &job.StartedAt, &job.CompletedAt, &job.WorkerID, &job.HeartbeatAt, &job.CreatedAt, &job.UpdatedAt, &job.IdempotencyKey, ) if err != nil { return nil, err } job.Input = json.RawMessage(input) if len(result) > 0 { job.Result = json.RawMessage(result) } return &job, nil }