309 lines
8.8 KiB
Go
309 lines
8.8 KiB
Go
package aiservice
|
|
|
|
import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"strings"
	"time"
)
|
|
|
|
type Client struct {
|
|
baseURL string
|
|
token string
|
|
http *http.Client
|
|
}
|
|
|
|
// Message is a single entry of ChatInput.Messages.
type Message struct {
	// Role is serialized as "role".
	// NOTE(review): presumably a chat role such as "system" or "user" — the
	// accepted values are not visible in this file; confirm with the service.
	Role string `json:"role"`
	// Content is serialized as "content".
	Content string `json:"content"`
}
|
|
|
|
type ChatInput struct {
|
|
Messages []Message `json:"messages"`
|
|
Temperature float64 `json:"temperature"`
|
|
MaxTokens int `json:"max_tokens,omitempty"`
|
|
ResponseFormat json.RawMessage `json:"response_format,omitempty"`
|
|
}
|
|
|
|
type CreateJobRequest struct {
|
|
OwnerService string `json:"owner_service"`
|
|
OwnerRef string `json:"owner_ref"`
|
|
TaskType string `json:"task_type"`
|
|
ModelProfile string `json:"model_profile"`
|
|
Priority int `json:"priority"`
|
|
MaxAttempts int `json:"max_attempts"`
|
|
Input json.RawMessage `json:"input"`
|
|
IdempotencyKey string `json:"idempotency_key,omitempty"`
|
|
}
|
|
|
|
type CreateJobsRequest struct {
|
|
OwnerService string `json:"owner_service,omitempty"`
|
|
TaskType string `json:"task_type,omitempty"`
|
|
ModelProfile string `json:"model_profile,omitempty"`
|
|
Priority int `json:"priority,omitempty"`
|
|
MaxAttempts int `json:"max_attempts,omitempty"`
|
|
Jobs []CreateJobRequest `json:"jobs"`
|
|
}
|
|
|
|
type Job struct {
|
|
ID string `json:"id"`
|
|
OwnerService string `json:"owner_service,omitempty"`
|
|
OwnerRef string `json:"owner_ref,omitempty"`
|
|
TaskType string `json:"task_type,omitempty"`
|
|
ModelProfile string `json:"model_profile,omitempty"`
|
|
Status string `json:"status"`
|
|
Result json.RawMessage `json:"result,omitempty"`
|
|
ErrorCode *string `json:"error_code,omitempty"`
|
|
ErrorMessage *string `json:"error_message,omitempty"`
|
|
IdempotencyKey *string `json:"idempotency_key,omitempty"`
|
|
}
|
|
|
|
// ChatResult is the JSON shape of a chat completion result.
//
// NOTE(review): nothing in this file references ChatResult; presumably it is
// what callers decode Job.Result into for chat jobs — confirm.
type ChatResult struct {
	Content string `json:"content"`
	Model   string `json:"model"`
	// DurationMS is serialized as "duration_ms"; milliseconds, going by the
	// name.
	DurationMS int64 `json:"duration_ms"`
}
|
|
|
|
type ProvidersStatus struct {
|
|
At time.Time `json:"at"`
|
|
Providers []ProviderStatus `json:"providers"`
|
|
}
|
|
|
|
// ProviderStatus is one entry of ProvidersStatus.Providers.
//
// Name, Configured and OK are always serialized; the remaining fields are
// left out of the JSON when zero.
type ProviderStatus struct {
	Name       string `json:"name"`
	Configured bool   `json:"configured"`
	OK         bool   `json:"ok"`
	URL        string `json:"url,omitempty"`
	Model      string `json:"model,omitempty"`
	// LatencyMS is serialized as "latency_ms"; milliseconds, going by the
	// name.
	LatencyMS int64  `json:"latency_ms,omitempty"`
	Error     string `json:"error,omitempty"`
}
|
|
|
|
type Stats struct {
|
|
At time.Time `json:"at"`
|
|
Owners []OwnerStat `json:"owners,omitempty"`
|
|
Errors []ErrorStat `json:"errors,omitempty"`
|
|
Backlog []BacklogStat `json:"backlog,omitempty"`
|
|
}
|
|
|
|
// OwnerStat is one entry of Stats.Owners: a Total for one combination of
// owner service, task type, model profile and status.
type OwnerStat struct {
	OwnerService string `json:"owner_service"`
	TaskType     string `json:"task_type"`
	ModelProfile string `json:"model_profile"`
	Status       string `json:"status"`
	Total        int64  `json:"total"`
}
|
|
|
|
// ErrorStat is one entry of Stats.Errors: error totals for one combination
// of owner service, task type, model profile and error code.
type ErrorStat struct {
	// OwnerService is the only field left out of the JSON when empty.
	OwnerService string `json:"owner_service,omitempty"`
	TaskType     string `json:"task_type"`
	ModelProfile string `json:"model_profile"`
	ErrorCode    string `json:"error_code"`
	Total        int64  `json:"total"`
	// Last24h is serialized as "last_24h".
	Last24h int64 `json:"last_24h"`
}
|
|
|
|
// BacklogStat is one entry of Stats.Backlog: queue counters for one
// combination of owner service, task type and model profile.
type BacklogStat struct {
	OwnerService string `json:"owner_service"`
	TaskType     string `json:"task_type"`
	ModelProfile string `json:"model_profile"`
	Pending      int64  `json:"pending"`
	Running      int64  `json:"running"`
	StaleRunning int64  `json:"stale_running"`
	// The two age fields are in seconds, going by their names.
	OldestPendingAgeSeconds int64 `json:"oldest_pending_age_seconds"`
	OldestRunningAgeSeconds int64 `json:"oldest_running_age_seconds"`
}
|
|
|
|
func New(baseURL, token string, timeout time.Duration) *Client {
|
|
baseURL = strings.TrimRight(strings.TrimSpace(baseURL), "/")
|
|
if baseURL == "" {
|
|
return nil
|
|
}
|
|
if timeout <= 0 {
|
|
timeout = 2 * time.Minute
|
|
}
|
|
return &Client{
|
|
baseURL: baseURL,
|
|
token: strings.TrimSpace(token),
|
|
http: &http.Client{Timeout: timeout},
|
|
}
|
|
}
|
|
|
|
func (c *Client) CreateJob(ctx context.Context, req CreateJobRequest) (*Job, error) {
|
|
if c == nil {
|
|
return nil, fmt.Errorf("ai-service not configured")
|
|
}
|
|
body, err := json.Marshal(req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("marshal ai job: %w", err)
|
|
}
|
|
httpReq, err := c.request(ctx, http.MethodPost, "/api/v1/jobs", body)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
resp, err := c.http.Do(httpReq)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("create ai job: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
|
return nil, fmt.Errorf("create ai job: http %d: %s", resp.StatusCode, readSmall(resp.Body))
|
|
}
|
|
var out Job
|
|
if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
|
|
return nil, fmt.Errorf("decode ai job: %w", err)
|
|
}
|
|
return &out, nil
|
|
}
|
|
|
|
func (c *Client) CreateJobs(ctx context.Context, req CreateJobsRequest) ([]*Job, error) {
|
|
if c == nil {
|
|
return nil, fmt.Errorf("ai-service not configured")
|
|
}
|
|
body, err := json.Marshal(req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("marshal ai jobs: %w", err)
|
|
}
|
|
httpReq, err := c.request(ctx, http.MethodPost, "/api/v1/jobs/batch", body)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
resp, err := c.http.Do(httpReq)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("create ai jobs: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
|
return nil, fmt.Errorf("create ai jobs: http %d: %s", resp.StatusCode, readSmall(resp.Body))
|
|
}
|
|
var out struct {
|
|
Jobs []*Job `json:"jobs"`
|
|
}
|
|
if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
|
|
return nil, fmt.Errorf("decode ai jobs: %w", err)
|
|
}
|
|
return out.Jobs, nil
|
|
}
|
|
|
|
func (c *Client) GetJob(ctx context.Context, id string) (*Job, error) {
|
|
if c == nil || strings.TrimSpace(id) == "" {
|
|
return nil, fmt.Errorf("ai job id is required")
|
|
}
|
|
req, err := c.request(ctx, http.MethodGet, "/api/v1/jobs/"+id, nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
resp, err := c.http.Do(req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("get ai job: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
|
return nil, fmt.Errorf("get ai job: http %d: %s", resp.StatusCode, readSmall(resp.Body))
|
|
}
|
|
var out Job
|
|
if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
|
|
return nil, fmt.Errorf("decode ai job: %w", err)
|
|
}
|
|
return &out, nil
|
|
}
|
|
|
|
func (c *Client) WaitJob(ctx context.Context, id string, pollInterval time.Duration) (*Job, error) {
|
|
if pollInterval <= 0 {
|
|
pollInterval = 2 * time.Second
|
|
}
|
|
ticker := time.NewTicker(pollInterval)
|
|
defer ticker.Stop()
|
|
for {
|
|
job, err := c.GetJob(ctx, id)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
switch job.Status {
|
|
case "done", "failed", "cancelled":
|
|
return job, nil
|
|
}
|
|
select {
|
|
case <-ctx.Done():
|
|
return nil, ctx.Err()
|
|
case <-ticker.C:
|
|
}
|
|
}
|
|
}
|
|
|
|
func (c *Client) ProvidersStatus(ctx context.Context) (*ProvidersStatus, error) {
|
|
if c == nil {
|
|
return nil, fmt.Errorf("ai-service not configured")
|
|
}
|
|
req, err := c.request(ctx, http.MethodGet, "/api/v1/providers/status", nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
resp, err := c.http.Do(req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("ai providers status: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
|
return nil, fmt.Errorf("ai providers status: http %d: %s", resp.StatusCode, readSmall(resp.Body))
|
|
}
|
|
var out ProvidersStatus
|
|
if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
|
|
return nil, fmt.Errorf("decode ai providers status: %w", err)
|
|
}
|
|
return &out, nil
|
|
}
|
|
|
|
func (c *Client) Stats(ctx context.Context) (*Stats, error) {
|
|
if c == nil {
|
|
return nil, fmt.Errorf("ai-service not configured")
|
|
}
|
|
req, err := c.request(ctx, http.MethodGet, "/api/v1/stats", nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
resp, err := c.http.Do(req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("ai stats: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
|
return nil, fmt.Errorf("ai stats: http %d: %s", resp.StatusCode, readSmall(resp.Body))
|
|
}
|
|
var out Stats
|
|
if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
|
|
return nil, fmt.Errorf("decode ai stats: %w", err)
|
|
}
|
|
return &out, nil
|
|
}
|
|
|
|
func (c *Client) request(ctx context.Context, method, path string, body []byte) (*http.Request, error) {
|
|
var r io.Reader
|
|
if body != nil {
|
|
r = bytes.NewReader(body)
|
|
}
|
|
req, err := http.NewRequestWithContext(ctx, method, c.baseURL+path, r)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if body != nil {
|
|
req.Header.Set("Content-Type", "application/json")
|
|
}
|
|
if c.token != "" {
|
|
req.Header.Set("Authorization", "Bearer "+c.token)
|
|
}
|
|
return req, nil
|
|
}
|
|
|
|
func readSmall(r io.Reader) string {
|
|
body, err := io.ReadAll(io.LimitReader(r, 1024))
|
|
if err != nil {
|
|
return err.Error()
|
|
}
|
|
return strings.TrimSpace(string(body))
|
|
}
|