Add generic LLM worker
All checks were successful
CI / test (push) Successful in 13s
Build and Deploy / build-and-deploy (push) Successful in 20s

This commit is contained in:
Grendgi
2026-06-08 13:52:29 +03:00
parent e0f74c62b0
commit 24c5d89c7b
10 changed files with 420 additions and 1 deletions

View File

@@ -17,6 +17,10 @@ type Config struct {
LLMModel string
LLMTimeout time.Duration
WhisperXURL string
WorkerID string
WorkerPollInterval time.Duration
WorkerClaimLimit int
}
func Load() Config {
@@ -31,6 +35,10 @@ func Load() Config {
LLMModel: envString("LLM_MODEL", "qwen2.5-14b"),
LLMTimeout: envDuration("LLM_TIMEOUT", 5*time.Minute),
WhisperXURL: envString("WHISPERX_URL", ""),
WorkerID: envString("WORKER_ID", hostname()),
WorkerPollInterval: envDuration("WORKER_POLL_INTERVAL", 2*time.Second),
WorkerClaimLimit: envInt("WORKER_CLAIM_LIMIT", 4),
}
}
@@ -76,3 +84,11 @@ func envDuration(key string, fallback time.Duration) time.Duration {
}
return v
}
func hostname() string {
h, err := os.Hostname()
if err != nil || h == "" {
return "ai-service-worker"
}
return h
}

159
internal/llm/client.go Normal file
View File

@@ -0,0 +1,159 @@
package llm
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
)
type Client struct {
baseURL string
apiKey string
model string
http *http.Client
}
type Message struct {
Role string `json:"role"`
Content string `json:"content"`
}
type ChatInput struct {
System string `json:"system,omitempty"`
User string `json:"user,omitempty"`
Messages []Message `json:"messages,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
ResponseFormat json.RawMessage `json:"response_format,omitempty"`
}
type Usage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
}
type ChatResult struct {
Content string `json:"content"`
Model string `json:"model"`
Usage *Usage `json:"usage,omitempty"`
DurationMS int64 `json:"duration_ms"`
}
type chatRequest struct {
Model string `json:"model"`
Messages []Message `json:"messages"`
Temperature float64 `json:"temperature"`
MaxTokens int `json:"max_tokens,omitempty"`
ResponseFormat *json.RawMessage `json:"response_format,omitempty"`
}
type chatResponse struct {
Model string `json:"model,omitempty"`
Choices []struct {
Message Message `json:"message"`
} `json:"choices"`
Usage *Usage `json:"usage,omitempty"`
Error *struct {
Message string `json:"message"`
} `json:"error,omitempty"`
}
func New(baseURL, apiKey, model string, timeout time.Duration) *Client {
return &Client{
baseURL: strings.TrimRight(strings.TrimSpace(baseURL), "/"),
apiKey: apiKey,
model: model,
http: &http.Client{Timeout: timeout},
}
}
func (c *Client) Chat(ctx context.Context, in ChatInput) (*ChatResult, error) {
if c == nil || c.baseURL == "" {
return nil, fmt.Errorf("llm not configured")
}
messages := normalizeMessages(in)
if len(messages) == 0 {
return nil, fmt.Errorf("messages are required")
}
temp := 0.1
if in.Temperature != nil {
temp = *in.Temperature
}
reqBody := chatRequest{
Model: c.model,
Messages: messages,
Temperature: temp,
MaxTokens: in.MaxTokens,
}
if len(in.ResponseFormat) > 0 {
reqBody.ResponseFormat = &in.ResponseFormat
}
body, err := json.Marshal(reqBody)
if err != nil {
return nil, err
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/v1/chat/completions", bytes.NewReader(body))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
if c.apiKey != "" {
req.Header.Set("Authorization", "Bearer "+c.apiKey)
}
start := time.Now()
resp, err := c.http.Do(req)
duration := time.Since(start)
if err != nil {
return nil, fmt.Errorf("llm do: %w", err)
}
defer resp.Body.Close()
raw, err := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
if err != nil {
return nil, fmt.Errorf("llm read: %w", err)
}
if resp.StatusCode >= 300 {
return nil, fmt.Errorf("llm HTTP %d: %s", resp.StatusCode, strings.TrimSpace(string(raw)))
}
var out chatResponse
if err := json.Unmarshal(raw, &out); err != nil {
return nil, fmt.Errorf("llm decode: %w", err)
}
if out.Error != nil {
return nil, fmt.Errorf("llm error: %s", out.Error.Message)
}
if len(out.Choices) == 0 {
return nil, fmt.Errorf("llm: empty choices")
}
modelName := out.Model
if modelName == "" {
modelName = c.model
}
return &ChatResult{
Content: out.Choices[0].Message.Content,
Model: modelName,
Usage: out.Usage,
DurationMS: duration.Milliseconds(),
}, nil
}
func normalizeMessages(in ChatInput) []Message {
if len(in.Messages) > 0 {
return in.Messages
}
var out []Message
if strings.TrimSpace(in.System) != "" {
out = append(out, Message{Role: "system", Content: in.System})
}
if strings.TrimSpace(in.User) != "" {
out = append(out, Message{Role: "user", Content: in.User})
}
return out
}

122
internal/worker/worker.go Normal file
View File

@@ -0,0 +1,122 @@
package worker
import (
"context"
"encoding/json"
"log/slog"
"strings"
"time"
"ai-service/internal/llm"
"ai-service/internal/model"
"ai-service/internal/store"
)
const (
TaskLLMChat = "llm_chat"
TaskChatCompletion = "chat_completion"
)
type Worker struct {
store *store.Store
llm *llm.Client
workerID string
modelProfile string
pollInterval time.Duration
claimLimit int
}
func New(store *store.Store, llmClient *llm.Client, workerID, modelProfile string, pollInterval time.Duration, claimLimit int) *Worker {
if pollInterval <= 0 {
pollInterval = 2 * time.Second
}
if claimLimit <= 0 {
claimLimit = 4
}
if strings.TrimSpace(workerID) == "" {
workerID = "ai-service-worker"
}
return &Worker{
store: store,
llm: llmClient,
workerID: workerID,
modelProfile: modelProfile,
pollInterval: pollInterval,
claimLimit: claimLimit,
}
}
func (w *Worker) Run(ctx context.Context) {
ticker := time.NewTicker(w.pollInterval)
defer ticker.Stop()
for {
w.tick(ctx)
select {
case <-ctx.Done():
return
case <-ticker.C:
}
}
}
func (w *Worker) tick(ctx context.Context) {
jobs, err := w.store.ClaimJobs(ctx, model.ClaimJobs{
WorkerID: w.workerID,
TaskTypes: []string{TaskLLMChat, TaskChatCompletion},
ModelProfiles: []string{w.modelProfile},
Limit: w.claimLimit,
})
if err != nil {
slog.Error("claim jobs failed", "error", err)
return
}
for _, job := range jobs {
w.process(ctx, job)
}
}
func (w *Worker) process(ctx context.Context, job *model.Job) {
var input llm.ChatInput
if err := json.Unmarshal(job.Input, &input); err != nil {
w.fail(ctx, job, "bad_input", err.Error())
return
}
result, err := w.llm.Chat(ctx, input)
if err != nil {
w.fail(ctx, job, classifyLLMError(err), err.Error())
return
}
body, err := json.Marshal(result)
if err != nil {
w.fail(ctx, job, "bad_response", err.Error())
return
}
if _, err := w.store.CompleteJob(ctx, job.ID, model.CompleteJob{Result: body}); err != nil {
slog.Error("complete job failed", "job_id", job.ID, "error", err)
}
}
func (w *Worker) fail(ctx context.Context, job *model.Job, code, message string) {
if _, err := w.store.FailJob(ctx, job.ID, model.FailJob{ErrorCode: code, ErrorMessage: message}); err != nil {
slog.Error("fail job failed", "job_id", job.ID, "error", err)
}
}
func classifyLLMError(err error) string {
if err == nil {
return "unknown"
}
s := strings.ToLower(err.Error())
switch {
case strings.Contains(s, "context deadline exceeded") || strings.Contains(s, "timeout"):
return "timeout"
case strings.Contains(s, "connection refused") || strings.Contains(s, "connection reset") || strings.Contains(s, "no route to host") || strings.Contains(s, "llm http 5"):
return "model_unavailable"
case strings.Contains(s, "llm http 4") || strings.Contains(s, "messages are required"):
return "bad_input"
case strings.Contains(s, "llm decode") || strings.Contains(s, "empty choices"):
return "bad_response"
default:
return "unknown"
}
}