558 lines
14 KiB
Go
558 lines
14 KiB
Go
package store
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"strings"
|
|
"time"
|
|
|
|
"ai-service/internal/model"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/jackc/pgx/v5"
|
|
"github.com/jackc/pgx/v5/pgxpool"
|
|
)
|
|
|
|
type Store struct {
|
|
pool *pgxpool.Pool
|
|
}
|
|
|
|
// jobSelectColumns is the unqualified ai_jobs column list spliced into
// SELECT and RETURNING clauses. The order of the columns must match the
// order of the Scan destinations in scanJob.
const jobSelectColumns = `
	id, owner_service, owner_ref, task_type, model_profile, priority, status,
	attempts, max_attempts, input, result, error_code, error_message,
	scheduled_at, started_at, completed_at, worker_id, heartbeat_at,
	created_at, updated_at, idempotency_key
`
|
|
|
|
// jobSelectColumnsFromJobAlias is the same column list as jobSelectColumns,
// with every column qualified by the table alias "j". It is meant for
// statements that reference ai_jobs as "j" (for example UPDATE ai_jobs j ...
// FROM ...). The column order must match the Scan destinations in scanJob.
const jobSelectColumnsFromJobAlias = `
	j.id, j.owner_service, j.owner_ref, j.task_type, j.model_profile, j.priority, j.status,
	j.attempts, j.max_attempts, j.input, j.result, j.error_code, j.error_message,
	j.scheduled_at, j.started_at, j.completed_at, j.worker_id, j.heartbeat_at,
	j.created_at, j.updated_at, j.idempotency_key
`
|
|
|
|
func Open(ctx context.Context, databaseURL string) (*Store, error) {
|
|
if strings.TrimSpace(databaseURL) == "" {
|
|
return nil, errors.New("DATABASE_URL is required")
|
|
}
|
|
cfg, err := pgxpool.ParseConfig(databaseURL)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("parse database url: %w", err)
|
|
}
|
|
pool, err := pgxpool.NewWithConfig(ctx, cfg)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("connect postgres: %w", err)
|
|
}
|
|
if err := pool.Ping(ctx); err != nil {
|
|
pool.Close()
|
|
return nil, fmt.Errorf("ping postgres: %w", err)
|
|
}
|
|
return &Store{pool: pool}, nil
|
|
}
|
|
|
|
func (s *Store) Close() {
|
|
s.pool.Close()
|
|
}
|
|
|
|
func (s *Store) Ping(ctx context.Context) error {
|
|
return s.pool.Ping(ctx)
|
|
}
|
|
|
|
func (s *Store) Exec(ctx context.Context, sql string, args ...any) error {
|
|
_, err := s.pool.Exec(ctx, sql, args...)
|
|
return err
|
|
}
|
|
|
|
func (s *Store) CreateJob(ctx context.Context, in model.CreateJob) (*model.Job, error) {
|
|
if err := validateCreateJob(in); err != nil {
|
|
return nil, err
|
|
}
|
|
if in.MaxAttempts <= 0 {
|
|
in.MaxAttempts = 3
|
|
}
|
|
if len(in.Input) == 0 {
|
|
in.Input = json.RawMessage(`{}`)
|
|
}
|
|
scheduledAt := time.Now().UTC()
|
|
if in.ScheduledAt != nil {
|
|
scheduledAt = in.ScheduledAt.UTC()
|
|
}
|
|
|
|
const q = `
|
|
INSERT INTO ai_jobs (
|
|
owner_service, owner_ref, task_type, model_profile, priority, max_attempts,
|
|
input, scheduled_at, idempotency_key
|
|
)
|
|
VALUES ($1,$2,$3,$4,$5,$6,$7,$8,$9)
|
|
ON CONFLICT (idempotency_key) WHERE idempotency_key IS NOT NULL
|
|
DO UPDATE SET updated_at = ai_jobs.updated_at
|
|
RETURNING ` + jobSelectColumns + `
|
|
`
|
|
row := s.pool.QueryRow(ctx, q,
|
|
in.OwnerService,
|
|
in.OwnerRef,
|
|
in.TaskType,
|
|
in.ModelProfile,
|
|
in.Priority,
|
|
in.MaxAttempts,
|
|
in.Input,
|
|
scheduledAt,
|
|
in.IdempotencyKey,
|
|
)
|
|
return scanJob(row)
|
|
}
|
|
|
|
func validateCreateJob(in model.CreateJob) error {
|
|
switch {
|
|
case strings.TrimSpace(in.OwnerService) == "":
|
|
return errors.New("owner_service is required")
|
|
case strings.TrimSpace(in.OwnerRef) == "":
|
|
return errors.New("owner_ref is required")
|
|
case strings.TrimSpace(in.TaskType) == "":
|
|
return errors.New("task_type is required")
|
|
case strings.TrimSpace(in.ModelProfile) == "":
|
|
return errors.New("model_profile is required")
|
|
default:
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func (s *Store) GetJob(ctx context.Context, id uuid.UUID) (*model.Job, error) {
|
|
const q = `
|
|
SELECT id, owner_service, owner_ref, task_type, model_profile, priority, status,
|
|
attempts, max_attempts, input, result, error_code, error_message,
|
|
scheduled_at, started_at, completed_at, worker_id, heartbeat_at,
|
|
created_at, updated_at, idempotency_key
|
|
FROM ai_jobs
|
|
WHERE id = $1
|
|
`
|
|
job, err := scanJob(s.pool.QueryRow(ctx, q, id))
|
|
if errors.Is(err, pgx.ErrNoRows) {
|
|
return nil, nil
|
|
}
|
|
return job, err
|
|
}
|
|
|
|
func (s *Store) ListJobs(ctx context.Context, filter model.JobFilter) ([]*model.Job, error) {
|
|
normalizeFilter(&filter)
|
|
const q = `
|
|
SELECT ` + jobSelectColumns + `
|
|
FROM ai_jobs
|
|
WHERE ($1 = '' OR owner_service = $1)
|
|
AND ($2 = '' OR owner_ref = $2)
|
|
AND ($3 = '' OR task_type = $3)
|
|
AND ($4 = '' OR model_profile = $4)
|
|
AND (cardinality($5::text[]) = 0 OR status = ANY($5::text[]))
|
|
AND (cardinality($6::text[]) = 0 OR COALESCE(NULLIF(error_code, ''), 'unknown') = ANY($6::text[]))
|
|
ORDER BY created_at DESC
|
|
LIMIT $7 OFFSET $8
|
|
`
|
|
rows, err := s.pool.Query(ctx, q,
|
|
filter.OwnerService,
|
|
filter.OwnerRef,
|
|
filter.TaskType,
|
|
filter.ModelProfile,
|
|
filter.Statuses,
|
|
filter.ErrorCodes,
|
|
filter.Limit,
|
|
filter.Offset,
|
|
)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
var out []*model.Job
|
|
for rows.Next() {
|
|
job, err := scanJob(rows)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
out = append(out, job)
|
|
}
|
|
return out, rows.Err()
|
|
}
|
|
|
|
func normalizeFilter(filter *model.JobFilter) {
|
|
filter.OwnerService = strings.TrimSpace(filter.OwnerService)
|
|
filter.OwnerRef = strings.TrimSpace(filter.OwnerRef)
|
|
filter.TaskType = strings.TrimSpace(filter.TaskType)
|
|
filter.ModelProfile = strings.TrimSpace(filter.ModelProfile)
|
|
if filter.Statuses == nil {
|
|
filter.Statuses = []string{}
|
|
}
|
|
if filter.ErrorCodes == nil {
|
|
filter.ErrorCodes = []string{}
|
|
}
|
|
if filter.Limit <= 0 {
|
|
filter.Limit = 100
|
|
}
|
|
if filter.Limit > 500 {
|
|
filter.Limit = 500
|
|
}
|
|
if filter.Offset < 0 {
|
|
filter.Offset = 0
|
|
}
|
|
}
|
|
|
|
func (s *Store) RetryJob(ctx context.Context, id uuid.UUID) (*model.Job, error) {
|
|
const q = `
|
|
UPDATE ai_jobs
|
|
SET status = 'pending',
|
|
attempts = 0,
|
|
started_at = NULL,
|
|
completed_at = NULL,
|
|
error_code = NULL,
|
|
error_message = NULL,
|
|
worker_id = NULL,
|
|
heartbeat_at = NULL,
|
|
scheduled_at = NOW(),
|
|
updated_at = NOW()
|
|
WHERE id = $1
|
|
AND status IN ('failed', 'running')
|
|
RETURNING id, owner_service, owner_ref, task_type, model_profile, priority, status,
|
|
attempts, max_attempts, input, result, error_code, error_message,
|
|
scheduled_at, started_at, completed_at, worker_id, heartbeat_at,
|
|
created_at, updated_at, idempotency_key
|
|
`
|
|
job, err := scanJob(s.pool.QueryRow(ctx, q, id))
|
|
if errors.Is(err, pgx.ErrNoRows) {
|
|
return nil, nil
|
|
}
|
|
return job, err
|
|
}
|
|
|
|
func (s *Store) RetryJobs(ctx context.Context, filter model.JobFilter) (int, error) {
|
|
normalizeFilter(&filter)
|
|
const q = `
|
|
WITH picked AS (
|
|
SELECT id
|
|
FROM ai_jobs
|
|
WHERE status IN ('failed', 'running')
|
|
AND ($1 = '' OR owner_service = $1)
|
|
AND ($2 = '' OR owner_ref = $2)
|
|
AND ($3 = '' OR task_type = $3)
|
|
AND ($4 = '' OR model_profile = $4)
|
|
AND (cardinality($5::text[]) = 0 OR COALESCE(NULLIF(error_code, ''), 'unknown') = ANY($5::text[]))
|
|
ORDER BY updated_at ASC
|
|
LIMIT $6
|
|
)
|
|
UPDATE ai_jobs j
|
|
SET status = 'pending',
|
|
attempts = 0,
|
|
started_at = NULL,
|
|
completed_at = NULL,
|
|
error_code = NULL,
|
|
error_message = NULL,
|
|
worker_id = NULL,
|
|
heartbeat_at = NULL,
|
|
scheduled_at = NOW(),
|
|
updated_at = NOW()
|
|
FROM picked
|
|
WHERE j.id = picked.id
|
|
`
|
|
tag, err := s.pool.Exec(ctx, q,
|
|
filter.OwnerService,
|
|
filter.OwnerRef,
|
|
filter.TaskType,
|
|
filter.ModelProfile,
|
|
filter.ErrorCodes,
|
|
filter.Limit,
|
|
)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
return int(tag.RowsAffected()), nil
|
|
}
|
|
|
|
func (s *Store) CancelJobs(ctx context.Context, filter model.JobFilter) (int, error) {
|
|
normalizeFilter(&filter)
|
|
const q = `
|
|
WITH picked AS (
|
|
SELECT id
|
|
FROM ai_jobs
|
|
WHERE status IN ('pending', 'running')
|
|
AND ($1 = '' OR owner_service = $1)
|
|
AND ($2 = '' OR owner_ref = $2)
|
|
AND ($3 = '' OR task_type = $3)
|
|
AND ($4 = '' OR model_profile = $4)
|
|
AND (cardinality($5::text[]) = 0 OR status = ANY($5::text[]))
|
|
ORDER BY updated_at ASC
|
|
LIMIT $6
|
|
)
|
|
UPDATE ai_jobs j
|
|
SET status = 'cancelled',
|
|
completed_at = NOW(),
|
|
worker_id = NULL,
|
|
heartbeat_at = NULL,
|
|
updated_at = NOW()
|
|
FROM picked
|
|
WHERE j.id = picked.id
|
|
`
|
|
tag, err := s.pool.Exec(ctx, q,
|
|
filter.OwnerService,
|
|
filter.OwnerRef,
|
|
filter.TaskType,
|
|
filter.ModelProfile,
|
|
filter.Statuses,
|
|
filter.Limit,
|
|
)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
return int(tag.RowsAffected()), nil
|
|
}
|
|
|
|
func (s *Store) ClaimJobs(ctx context.Context, in model.ClaimJobs) ([]*model.Job, error) {
|
|
if in.Limit <= 0 {
|
|
in.Limit = 1
|
|
}
|
|
if in.Limit > 100 {
|
|
in.Limit = 100
|
|
}
|
|
workerID := strings.TrimSpace(in.WorkerID)
|
|
if workerID == "" {
|
|
workerID = "unknown"
|
|
}
|
|
if in.TaskTypes == nil {
|
|
in.TaskTypes = []string{}
|
|
}
|
|
if in.ModelProfiles == nil {
|
|
in.ModelProfiles = []string{}
|
|
}
|
|
const q = `
|
|
WITH picked AS (
|
|
SELECT id
|
|
FROM ai_jobs
|
|
WHERE status = 'pending'
|
|
AND attempts < max_attempts
|
|
AND scheduled_at <= NOW()
|
|
AND (cardinality($1::text[]) = 0 OR task_type = ANY($1::text[]))
|
|
AND (cardinality($2::text[]) = 0 OR model_profile = ANY($2::text[]))
|
|
ORDER BY priority DESC, scheduled_at ASC, created_at ASC
|
|
LIMIT $3
|
|
FOR UPDATE SKIP LOCKED
|
|
)
|
|
UPDATE ai_jobs j
|
|
SET status = 'running',
|
|
attempts = j.attempts + 1,
|
|
started_at = NOW(),
|
|
completed_at = NULL,
|
|
error_code = NULL,
|
|
error_message = NULL,
|
|
worker_id = $4,
|
|
heartbeat_at = NOW(),
|
|
updated_at = NOW()
|
|
FROM picked
|
|
WHERE j.id = picked.id
|
|
RETURNING ` + jobSelectColumnsFromJobAlias + `
|
|
`
|
|
rows, err := s.pool.Query(ctx, q, in.TaskTypes, in.ModelProfiles, in.Limit, workerID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
var out []*model.Job
|
|
for rows.Next() {
|
|
job, err := scanJob(rows)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
out = append(out, job)
|
|
}
|
|
return out, rows.Err()
|
|
}
|
|
|
|
func (s *Store) CompleteJob(ctx context.Context, id uuid.UUID, in model.CompleteJob) (*model.Job, error) {
|
|
if len(in.Result) == 0 {
|
|
in.Result = json.RawMessage(`{}`)
|
|
}
|
|
const q = `
|
|
UPDATE ai_jobs
|
|
SET status = 'done',
|
|
result = $2,
|
|
error_code = NULL,
|
|
error_message = NULL,
|
|
completed_at = NOW(),
|
|
heartbeat_at = NOW(),
|
|
updated_at = NOW()
|
|
WHERE id = $1
|
|
AND status = 'running'
|
|
RETURNING ` + jobSelectColumns + `
|
|
`
|
|
job, err := scanJob(s.pool.QueryRow(ctx, q, id, in.Result))
|
|
if errors.Is(err, pgx.ErrNoRows) {
|
|
return nil, nil
|
|
}
|
|
return job, err
|
|
}
|
|
|
|
func (s *Store) FailJob(ctx context.Context, id uuid.UUID, in model.FailJob) (*model.Job, error) {
|
|
errorCode := strings.TrimSpace(in.ErrorCode)
|
|
if errorCode == "" {
|
|
errorCode = "unknown"
|
|
}
|
|
errorMessage := strings.TrimSpace(in.ErrorMessage)
|
|
const q = `
|
|
UPDATE ai_jobs
|
|
SET status = 'failed',
|
|
error_code = $2,
|
|
error_message = $3,
|
|
completed_at = NOW(),
|
|
heartbeat_at = NOW(),
|
|
updated_at = NOW()
|
|
WHERE id = $1
|
|
AND status = 'running'
|
|
RETURNING ` + jobSelectColumns + `
|
|
`
|
|
job, err := scanJob(s.pool.QueryRow(ctx, q, id, errorCode, errorMessage))
|
|
if errors.Is(err, pgx.ErrNoRows) {
|
|
return nil, nil
|
|
}
|
|
return job, err
|
|
}
|
|
|
|
func (s *Store) RequeueStaleRunning(ctx context.Context, olderThan time.Duration, limit int) (int, error) {
|
|
if olderThan <= 0 {
|
|
olderThan = 15 * time.Minute
|
|
}
|
|
if limit <= 0 {
|
|
limit = 100
|
|
}
|
|
if limit > 1000 {
|
|
limit = 1000
|
|
}
|
|
const q = `
|
|
WITH picked AS (
|
|
SELECT id
|
|
FROM ai_jobs
|
|
WHERE status = 'running'
|
|
AND COALESCE(heartbeat_at, started_at, updated_at) < NOW() - make_interval(secs => $1)
|
|
ORDER BY COALESCE(heartbeat_at, started_at, updated_at) ASC
|
|
LIMIT $2
|
|
)
|
|
UPDATE ai_jobs j
|
|
SET status = CASE WHEN j.attempts < j.max_attempts THEN 'pending' ELSE 'failed' END,
|
|
error_code = CASE WHEN j.attempts < j.max_attempts THEN NULL ELSE 'stale_worker' END,
|
|
error_message = CASE WHEN j.attempts < j.max_attempts THEN NULL ELSE 'worker lease expired' END,
|
|
worker_id = NULL,
|
|
heartbeat_at = NULL,
|
|
completed_at = CASE WHEN j.attempts < j.max_attempts THEN NULL ELSE NOW() END,
|
|
scheduled_at = CASE WHEN j.attempts < j.max_attempts THEN NOW() ELSE j.scheduled_at END,
|
|
updated_at = NOW()
|
|
FROM picked
|
|
WHERE j.id = picked.id
|
|
`
|
|
tag, err := s.pool.Exec(ctx, q, int(olderThan.Seconds()), limit)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
return int(tag.RowsAffected()), nil
|
|
}
|
|
|
|
func (s *Store) Stats(ctx context.Context) (*model.Stats, error) {
|
|
out := &model.Stats{At: time.Now().UTC()}
|
|
|
|
queueRows, err := s.pool.Query(ctx, `
|
|
SELECT task_type, model_profile, status, count(*)
|
|
FROM ai_jobs
|
|
GROUP BY task_type, model_profile, status
|
|
ORDER BY task_type, model_profile, status
|
|
`)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer queueRows.Close()
|
|
for queueRows.Next() {
|
|
var stat model.QueueStat
|
|
if err := queueRows.Scan(&stat.TaskType, &stat.ModelProfile, &stat.Status, &stat.Total); err != nil {
|
|
return nil, err
|
|
}
|
|
out.Queues = append(out.Queues, stat)
|
|
}
|
|
if err := queueRows.Err(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
ownerRows, err := s.pool.Query(ctx, `
|
|
SELECT owner_service, task_type, model_profile, status, count(*)
|
|
FROM ai_jobs
|
|
GROUP BY owner_service, task_type, model_profile, status
|
|
ORDER BY owner_service, task_type, model_profile, status
|
|
`)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer ownerRows.Close()
|
|
for ownerRows.Next() {
|
|
var stat model.OwnerStat
|
|
if err := ownerRows.Scan(&stat.OwnerService, &stat.TaskType, &stat.ModelProfile, &stat.Status, &stat.Total); err != nil {
|
|
return nil, err
|
|
}
|
|
out.Owners = append(out.Owners, stat)
|
|
}
|
|
if err := ownerRows.Err(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
errorRows, err := s.pool.Query(ctx, `
|
|
SELECT task_type, model_profile, COALESCE(NULLIF(error_code, ''), 'unknown') AS error_code,
|
|
count(*) AS total,
|
|
count(*) FILTER (WHERE updated_at > NOW() - INTERVAL '24 hours') AS last_24h
|
|
FROM ai_jobs
|
|
WHERE status = 'failed'
|
|
GROUP BY task_type, model_profile, COALESCE(NULLIF(error_code, ''), 'unknown')
|
|
ORDER BY last_24h DESC, total DESC
|
|
`)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer errorRows.Close()
|
|
for errorRows.Next() {
|
|
var stat model.ErrorStat
|
|
if err := errorRows.Scan(&stat.TaskType, &stat.ModelProfile, &stat.ErrorCode, &stat.Total, &stat.Last24h); err != nil {
|
|
return nil, err
|
|
}
|
|
out.Errors = append(out.Errors, stat)
|
|
}
|
|
return out, errorRows.Err()
|
|
}
|
|
|
|
func scanJob(row pgx.Row) (*model.Job, error) {
|
|
var job model.Job
|
|
var input []byte
|
|
var result []byte
|
|
err := row.Scan(
|
|
&job.ID,
|
|
&job.OwnerService,
|
|
&job.OwnerRef,
|
|
&job.TaskType,
|
|
&job.ModelProfile,
|
|
&job.Priority,
|
|
&job.Status,
|
|
&job.Attempts,
|
|
&job.MaxAttempts,
|
|
&input,
|
|
&result,
|
|
&job.ErrorCode,
|
|
&job.ErrorMessage,
|
|
&job.ScheduledAt,
|
|
&job.StartedAt,
|
|
&job.CompletedAt,
|
|
&job.WorkerID,
|
|
&job.HeartbeatAt,
|
|
&job.CreatedAt,
|
|
&job.UpdatedAt,
|
|
&job.IdempotencyKey,
|
|
)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
job.Input = json.RawMessage(input)
|
|
if len(result) > 0 {
|
|
job.Result = json.RawMessage(result)
|
|
}
|
|
return &job, nil
|
|
}
|