diff --git a/internal/httpapi/server.go b/internal/httpapi/server.go index f363667..254c3a7 100644 --- a/internal/httpapi/server.go +++ b/internal/httpapi/server.go @@ -135,6 +135,8 @@ type createBatchResponse struct { Jobs []*model.Job `json:"jobs"` } +const maxCreateBatchJobs = 1000 + func (s *Server) handleCreateBatch(w http.ResponseWriter, r *http.Request) { var req createBatchRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { @@ -145,17 +147,21 @@ func (s *Server) handleCreateBatch(w http.ResponseWriter, r *http.Request) { writeError(w, http.StatusBadRequest, "jobs is required") return } + if len(req.Jobs) > maxCreateBatchJobs { + writeError(w, http.StatusBadRequest, fmt.Sprintf("jobs limit is %d", maxCreateBatchJobs)) + return + } ctx, cancel := contextWithTimeout(r, 20*time.Second) defer cancel() - out := createBatchResponse{Jobs: make([]*model.Job, 0, len(req.Jobs))} + items := make([]model.CreateJob, 0, len(req.Jobs)) for _, item := range req.Jobs { - if item.OwnerService == "" { + if strings.TrimSpace(item.OwnerService) == "" { item.OwnerService = req.OwnerService } - if item.TaskType == "" { + if strings.TrimSpace(item.TaskType) == "" { item.TaskType = req.TaskType } - if item.ModelProfile == "" { + if strings.TrimSpace(item.ModelProfile) == "" { item.ModelProfile = req.ModelProfile } if item.Priority == 0 { @@ -164,18 +170,18 @@ func (s *Server) handleCreateBatch(w http.ResponseWriter, r *http.Request) { if item.MaxAttempts == 0 { item.MaxAttempts = req.MaxAttempts } - job, err := s.store.CreateJob(ctx, item) - if err != nil { - status := http.StatusInternalServerError - if isValidationError(err) { - status = http.StatusBadRequest - } - writeError(w, status, err.Error()) - return - } - out.Jobs = append(out.Jobs, job) + items = append(items, item) } - writeJSON(w, http.StatusCreated, out) + jobs, err := s.store.CreateJobs(ctx, items) + if err != nil { + status := http.StatusInternalServerError + if isValidationError(err) { + status = http.StatusBadRequest + } + writeError(w, status, err.Error()) + return + } + writeJSON(w, http.StatusCreated, createBatchResponse{Jobs: jobs}) } func (s *Server) handleRetryJobs(w http.ResponseWriter, r *http.Request) { diff --git a/internal/store/store.go b/internal/store/store.go index 879d538..9d18855 100644 --- a/internal/store/store.go +++ b/internal/store/store.go @@ -69,16 +69,7 @@ func (s *Store) CreateJob(ctx context.Context, in model.CreateJob) (*model.Job, if err := validateCreateJob(in); err != nil { return nil, err } - if in.MaxAttempts <= 0 { - in.MaxAttempts = 3 - } - if len(in.Input) == 0 { - in.Input = json.RawMessage(`{}`) - } - scheduledAt := time.Now().UTC() - if in.ScheduledAt != nil { - scheduledAt = in.ScheduledAt.UTC() - } + normalizeCreateJob(&in) const q = ` INSERT INTO ai_jobs ( @@ -98,12 +89,81 @@ RETURNING ` + jobSelectColumns + ` in.Priority, in.MaxAttempts, in.Input, - scheduledAt, + *in.ScheduledAt, in.IdempotencyKey, ) return scanJob(row) } +func (s *Store) CreateJobs(ctx context.Context, items []model.CreateJob) ([]*model.Job, error) { + if len(items) == 0 { + return []*model.Job{}, nil + } + const q = ` +INSERT INTO ai_jobs ( + owner_service, owner_ref, task_type, model_profile, priority, max_attempts, + input, scheduled_at, idempotency_key +) +VALUES ($1,$2,$3,$4,$5,$6,$7,$8,$9) +ON CONFLICT (idempotency_key) WHERE idempotency_key IS NOT NULL +DO UPDATE SET updated_at = ai_jobs.updated_at +RETURNING ` + jobSelectColumns + ` +` + var batch pgx.Batch + for i := range items { + if err := validateCreateJob(items[i]); err != nil { + return nil, err + } + normalizeCreateJob(&items[i]) + batch.Queue(q, + items[i].OwnerService, + items[i].OwnerRef, + items[i].TaskType, + items[i].ModelProfile, + items[i].Priority, + items[i].MaxAttempts, + items[i].Input, + *items[i].ScheduledAt, + items[i].IdempotencyKey, + ) + } + br := s.pool.SendBatch(ctx, &batch) + batchClosed := false + defer func() { + if !batchClosed { + _ = br.Close() + } + }() + out := make([]*model.Job, 0, len(items)) + for range items { + job, err := scanJob(br.QueryRow()) + if err != nil { + return nil, err + } + out = append(out, job) + } + err := br.Close() + batchClosed = true + if err != nil { + return nil, err + } + return out, nil +} + +func normalizeCreateJob(in *model.CreateJob) { + if in.MaxAttempts <= 0 { + in.MaxAttempts = 3 + } + if len(in.Input) == 0 { + in.Input = json.RawMessage(`{}`) + } + scheduledAt := time.Now().UTC() + if in.ScheduledAt != nil { + scheduledAt = in.ScheduledAt.UTC() + } + in.ScheduledAt = &scheduledAt +} + func validateCreateJob(in model.CreateJob) error { switch { case strings.TrimSpace(in.OwnerService) == "":