Add generic LLM worker
This commit is contained in:
@@ -53,6 +53,10 @@ jobs:
|
||||
kubectl apply -f k8s/postgres.yaml
|
||||
kubectl apply -f k8s/server-deployment.yaml
|
||||
kubectl apply -f k8s/server-service.yaml
|
||||
kubectl apply -f k8s/worker-deployment.yaml
|
||||
kubectl -n ai-service set image deployment/ai-service \
|
||||
server=${{ env.NODE_REGISTRY }}/admin/ai-service:${{ github.sha }}
|
||||
kubectl -n ai-service set image deployment/ai-service-worker \
|
||||
worker=${{ env.NODE_REGISTRY }}/admin/ai-service:${{ github.sha }}
|
||||
kubectl -n ai-service rollout status deployment/ai-service --timeout=180s
|
||||
kubectl -n ai-service rollout status deployment/ai-service-worker --timeout=180s
|
||||
|
||||
@@ -8,7 +8,8 @@ RUN go mod download
|
||||
COPY cmd ./cmd
|
||||
COPY internal ./internal
|
||||
|
||||
RUN CGO_ENABLED=0 GOOS=linux go build -o /out/ai-service ./cmd/server
|
||||
RUN CGO_ENABLED=0 GOOS=linux go build -o /out/ai-service ./cmd/server \
|
||||
&& CGO_ENABLED=0 GOOS=linux go build -o /out/ai-service-worker ./cmd/worker
|
||||
|
||||
FROM alpine:3.22
|
||||
|
||||
@@ -16,6 +17,7 @@ RUN apk add --no-cache ca-certificates tini
|
||||
|
||||
WORKDIR /app
|
||||
COPY --from=builder /out/ai-service /usr/local/bin/ai-service
|
||||
COPY --from=builder /out/ai-service-worker /usr/local/bin/ai-service-worker
|
||||
|
||||
EXPOSE 8080
|
||||
|
||||
|
||||
23
README.md
23
README.md
@@ -22,6 +22,26 @@ The service is intentionally domain-agnostic:
|
||||
This keeps AI service as shared infrastructure rather than a telephony-specific
|
||||
service.
|
||||
|
||||
## Built-in workers
|
||||
|
||||
The first built-in worker processes `llm_chat` and `chat_completion` jobs whose
|
||||
`model_profile` equals `LLM_MODEL`.
|
||||
|
||||
Input can be either explicit messages:
|
||||
|
||||
```json
|
||||
{
|
||||
"messages": [
|
||||
{"role": "system", "content": "Answer as JSON."},
|
||||
{"role": "user", "content": "Classify this text"}
|
||||
],
|
||||
"max_tokens": 256
|
||||
}
|
||||
```
|
||||
|
||||
or compact `system` / `user` fields. The completed job result contains
|
||||
`content`, `model`, `usage` and `duration_ms`.
|
||||
|
||||
## API
|
||||
|
||||
- `POST /api/v1/jobs` creates one job.
|
||||
@@ -48,6 +68,9 @@ service.
|
||||
- `LLM_MODEL`, default `qwen2.5-14b`
|
||||
- `LLM_TIMEOUT`, default `5m`
|
||||
- `WHISPERX_URL`, WhisperX endpoint for transcription jobs
|
||||
- `WORKER_ID`, default hostname
|
||||
- `WORKER_POLL_INTERVAL`, default `2s`
|
||||
- `WORKER_CLAIM_LIMIT`, default `4`
|
||||
|
||||
## Next integration step
|
||||
|
||||
|
||||
53
cmd/worker/main.go
Normal file
53
cmd/worker/main.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"ai-service/internal/config"
|
||||
"ai-service/internal/llm"
|
||||
"ai-service/internal/migrate"
|
||||
"ai-service/internal/store"
|
||||
"ai-service/internal/worker"
|
||||
)
|
||||
|
||||
func main() {
|
||||
cfg := config.Load()
|
||||
slog.SetDefault(slog.New(slog.NewJSONHandler(os.Stdout, nil)))
|
||||
|
||||
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
|
||||
defer stop()
|
||||
|
||||
db, err := store.Open(ctx, cfg.DatabaseURL)
|
||||
if err != nil {
|
||||
slog.Error("db_open_failed", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
if cfg.MigrateOnStart {
|
||||
if err := migrate.Up(ctx, db); err != nil {
|
||||
slog.Error("migrate_failed", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
if cfg.LLMBaseURL == "" {
|
||||
slog.Error("llm_not_configured")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
llmClient := llm.New(cfg.LLMBaseURL, cfg.LLMAPIKey, cfg.LLMModel, cfg.LLMTimeout)
|
||||
w := worker.New(db, llmClient, cfg.WorkerID, cfg.LLMModel, cfg.WorkerPollInterval, cfg.WorkerClaimLimit)
|
||||
|
||||
slog.Info("ai_worker_started",
|
||||
"worker_id", cfg.WorkerID,
|
||||
"model", cfg.LLMModel,
|
||||
"poll_interval", cfg.WorkerPollInterval.String(),
|
||||
"claim_limit", cfg.WorkerClaimLimit,
|
||||
)
|
||||
w.Run(ctx)
|
||||
}
|
||||
@@ -17,6 +17,10 @@ type Config struct {
|
||||
LLMModel string
|
||||
LLMTimeout time.Duration
|
||||
WhisperXURL string
|
||||
|
||||
WorkerID string
|
||||
WorkerPollInterval time.Duration
|
||||
WorkerClaimLimit int
|
||||
}
|
||||
|
||||
func Load() Config {
|
||||
@@ -31,6 +35,10 @@ func Load() Config {
|
||||
LLMModel: envString("LLM_MODEL", "qwen2.5-14b"),
|
||||
LLMTimeout: envDuration("LLM_TIMEOUT", 5*time.Minute),
|
||||
WhisperXURL: envString("WHISPERX_URL", ""),
|
||||
|
||||
WorkerID: envString("WORKER_ID", hostname()),
|
||||
WorkerPollInterval: envDuration("WORKER_POLL_INTERVAL", 2*time.Second),
|
||||
WorkerClaimLimit: envInt("WORKER_CLAIM_LIMIT", 4),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -76,3 +84,11 @@ func envDuration(key string, fallback time.Duration) time.Duration {
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
func hostname() string {
|
||||
h, err := os.Hostname()
|
||||
if err != nil || h == "" {
|
||||
return "ai-service-worker"
|
||||
}
|
||||
return h
|
||||
}
|
||||
|
||||
159
internal/llm/client.go
Normal file
159
internal/llm/client.go
Normal file
@@ -0,0 +1,159 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
baseURL string
|
||||
apiKey string
|
||||
model string
|
||||
http *http.Client
|
||||
}
|
||||
|
||||
type Message struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type ChatInput struct {
|
||||
System string `json:"system,omitempty"`
|
||||
User string `json:"user,omitempty"`
|
||||
Messages []Message `json:"messages,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
ResponseFormat json.RawMessage `json:"response_format,omitempty"`
|
||||
}
|
||||
|
||||
type Usage struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
}
|
||||
|
||||
type ChatResult struct {
|
||||
Content string `json:"content"`
|
||||
Model string `json:"model"`
|
||||
Usage *Usage `json:"usage,omitempty"`
|
||||
DurationMS int64 `json:"duration_ms"`
|
||||
}
|
||||
|
||||
type chatRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []Message `json:"messages"`
|
||||
Temperature float64 `json:"temperature"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
ResponseFormat *json.RawMessage `json:"response_format,omitempty"`
|
||||
}
|
||||
|
||||
type chatResponse struct {
|
||||
Model string `json:"model,omitempty"`
|
||||
Choices []struct {
|
||||
Message Message `json:"message"`
|
||||
} `json:"choices"`
|
||||
Usage *Usage `json:"usage,omitempty"`
|
||||
Error *struct {
|
||||
Message string `json:"message"`
|
||||
} `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func New(baseURL, apiKey, model string, timeout time.Duration) *Client {
|
||||
return &Client{
|
||||
baseURL: strings.TrimRight(strings.TrimSpace(baseURL), "/"),
|
||||
apiKey: apiKey,
|
||||
model: model,
|
||||
http: &http.Client{Timeout: timeout},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) Chat(ctx context.Context, in ChatInput) (*ChatResult, error) {
|
||||
if c == nil || c.baseURL == "" {
|
||||
return nil, fmt.Errorf("llm not configured")
|
||||
}
|
||||
messages := normalizeMessages(in)
|
||||
if len(messages) == 0 {
|
||||
return nil, fmt.Errorf("messages are required")
|
||||
}
|
||||
temp := 0.1
|
||||
if in.Temperature != nil {
|
||||
temp = *in.Temperature
|
||||
}
|
||||
reqBody := chatRequest{
|
||||
Model: c.model,
|
||||
Messages: messages,
|
||||
Temperature: temp,
|
||||
MaxTokens: in.MaxTokens,
|
||||
}
|
||||
if len(in.ResponseFormat) > 0 {
|
||||
reqBody.ResponseFormat = &in.ResponseFormat
|
||||
}
|
||||
body, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/v1/chat/completions", bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if c.apiKey != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+c.apiKey)
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
resp, err := c.http.Do(req)
|
||||
duration := time.Since(start)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("llm do: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
raw, err := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("llm read: %w", err)
|
||||
}
|
||||
if resp.StatusCode >= 300 {
|
||||
return nil, fmt.Errorf("llm HTTP %d: %s", resp.StatusCode, strings.TrimSpace(string(raw)))
|
||||
}
|
||||
var out chatResponse
|
||||
if err := json.Unmarshal(raw, &out); err != nil {
|
||||
return nil, fmt.Errorf("llm decode: %w", err)
|
||||
}
|
||||
if out.Error != nil {
|
||||
return nil, fmt.Errorf("llm error: %s", out.Error.Message)
|
||||
}
|
||||
if len(out.Choices) == 0 {
|
||||
return nil, fmt.Errorf("llm: empty choices")
|
||||
}
|
||||
modelName := out.Model
|
||||
if modelName == "" {
|
||||
modelName = c.model
|
||||
}
|
||||
return &ChatResult{
|
||||
Content: out.Choices[0].Message.Content,
|
||||
Model: modelName,
|
||||
Usage: out.Usage,
|
||||
DurationMS: duration.Milliseconds(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func normalizeMessages(in ChatInput) []Message {
|
||||
if len(in.Messages) > 0 {
|
||||
return in.Messages
|
||||
}
|
||||
var out []Message
|
||||
if strings.TrimSpace(in.System) != "" {
|
||||
out = append(out, Message{Role: "system", Content: in.System})
|
||||
}
|
||||
if strings.TrimSpace(in.User) != "" {
|
||||
out = append(out, Message{Role: "user", Content: in.User})
|
||||
}
|
||||
return out
|
||||
}
|
||||
122
internal/worker/worker.go
Normal file
122
internal/worker/worker.go
Normal file
@@ -0,0 +1,122 @@
|
||||
package worker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"ai-service/internal/llm"
|
||||
"ai-service/internal/model"
|
||||
"ai-service/internal/store"
|
||||
)
|
||||
|
||||
const (
|
||||
TaskLLMChat = "llm_chat"
|
||||
TaskChatCompletion = "chat_completion"
|
||||
)
|
||||
|
||||
type Worker struct {
|
||||
store *store.Store
|
||||
llm *llm.Client
|
||||
workerID string
|
||||
modelProfile string
|
||||
pollInterval time.Duration
|
||||
claimLimit int
|
||||
}
|
||||
|
||||
func New(store *store.Store, llmClient *llm.Client, workerID, modelProfile string, pollInterval time.Duration, claimLimit int) *Worker {
|
||||
if pollInterval <= 0 {
|
||||
pollInterval = 2 * time.Second
|
||||
}
|
||||
if claimLimit <= 0 {
|
||||
claimLimit = 4
|
||||
}
|
||||
if strings.TrimSpace(workerID) == "" {
|
||||
workerID = "ai-service-worker"
|
||||
}
|
||||
return &Worker{
|
||||
store: store,
|
||||
llm: llmClient,
|
||||
workerID: workerID,
|
||||
modelProfile: modelProfile,
|
||||
pollInterval: pollInterval,
|
||||
claimLimit: claimLimit,
|
||||
}
|
||||
}
|
||||
|
||||
func (w *Worker) Run(ctx context.Context) {
|
||||
ticker := time.NewTicker(w.pollInterval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
w.tick(ctx)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (w *Worker) tick(ctx context.Context) {
|
||||
jobs, err := w.store.ClaimJobs(ctx, model.ClaimJobs{
|
||||
WorkerID: w.workerID,
|
||||
TaskTypes: []string{TaskLLMChat, TaskChatCompletion},
|
||||
ModelProfiles: []string{w.modelProfile},
|
||||
Limit: w.claimLimit,
|
||||
})
|
||||
if err != nil {
|
||||
slog.Error("claim jobs failed", "error", err)
|
||||
return
|
||||
}
|
||||
for _, job := range jobs {
|
||||
w.process(ctx, job)
|
||||
}
|
||||
}
|
||||
|
||||
func (w *Worker) process(ctx context.Context, job *model.Job) {
|
||||
var input llm.ChatInput
|
||||
if err := json.Unmarshal(job.Input, &input); err != nil {
|
||||
w.fail(ctx, job, "bad_input", err.Error())
|
||||
return
|
||||
}
|
||||
result, err := w.llm.Chat(ctx, input)
|
||||
if err != nil {
|
||||
w.fail(ctx, job, classifyLLMError(err), err.Error())
|
||||
return
|
||||
}
|
||||
body, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
w.fail(ctx, job, "bad_response", err.Error())
|
||||
return
|
||||
}
|
||||
if _, err := w.store.CompleteJob(ctx, job.ID, model.CompleteJob{Result: body}); err != nil {
|
||||
slog.Error("complete job failed", "job_id", job.ID, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (w *Worker) fail(ctx context.Context, job *model.Job, code, message string) {
|
||||
if _, err := w.store.FailJob(ctx, job.ID, model.FailJob{ErrorCode: code, ErrorMessage: message}); err != nil {
|
||||
slog.Error("fail job failed", "job_id", job.ID, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func classifyLLMError(err error) string {
|
||||
if err == nil {
|
||||
return "unknown"
|
||||
}
|
||||
s := strings.ToLower(err.Error())
|
||||
switch {
|
||||
case strings.Contains(s, "context deadline exceeded") || strings.Contains(s, "timeout"):
|
||||
return "timeout"
|
||||
case strings.Contains(s, "connection refused") || strings.Contains(s, "connection reset") || strings.Contains(s, "no route to host") || strings.Contains(s, "llm http 5"):
|
||||
return "model_unavailable"
|
||||
case strings.Contains(s, "llm http 4") || strings.Contains(s, "messages are required"):
|
||||
return "bad_input"
|
||||
case strings.Contains(s, "llm decode") || strings.Contains(s, "empty choices"):
|
||||
return "bad_response"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
@@ -12,3 +12,5 @@ data:
|
||||
LLM_MODEL: "qwen2.5-14b"
|
||||
LLM_TIMEOUT: "5m"
|
||||
WHISPERX_URL: "http://10.2.3.5:8001"
|
||||
WORKER_POLL_INTERVAL: "2s"
|
||||
WORKER_CLAIM_LIMIT: "4"
|
||||
|
||||
@@ -10,3 +10,4 @@ resources:
|
||||
- postgres.yaml
|
||||
- server-deployment.yaml
|
||||
- server-service.yaml
|
||||
- worker-deployment.yaml
|
||||
|
||||
37
k8s/worker-deployment.yaml
Normal file
37
k8s/worker-deployment.yaml
Normal file
@@ -0,0 +1,37 @@
|
||||
apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
metadata:
|
||||
name: ai-service-worker
|
||||
namespace: ai-service
|
||||
spec:
|
||||
replicas: 1
|
||||
selector:
|
||||
matchLabels:
|
||||
app: ai-service-worker
|
||||
template:
|
||||
metadata:
|
||||
labels:
|
||||
app: ai-service-worker
|
||||
spec:
|
||||
terminationGracePeriodSeconds: 20
|
||||
containers:
|
||||
- name: worker
|
||||
image: localhost:30300/admin/ai-service:latest
|
||||
command: ["/usr/local/bin/ai-service-worker"]
|
||||
env:
|
||||
- name: WORKER_ID
|
||||
valueFrom:
|
||||
fieldRef:
|
||||
fieldPath: metadata.name
|
||||
envFrom:
|
||||
- configMapRef:
|
||||
name: ai-service-config
|
||||
- secretRef:
|
||||
name: ai-service-secrets
|
||||
resources:
|
||||
requests:
|
||||
cpu: 50m
|
||||
memory: 96Mi
|
||||
limits:
|
||||
cpu: 500m
|
||||
memory: 384Mi
|
||||
Reference in New Issue
Block a user