Compare commits
44 Commits
76ac9b8896
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8d06cfabb1 | ||
|
|
b81a8ee6be | ||
|
|
63553fba33 | ||
|
|
f32265400b | ||
|
|
aad905c2c8 | ||
|
|
22d85ce646 | ||
|
|
3c124c5f5a | ||
|
|
773f53f790 | ||
|
|
abc64214d2 | ||
|
|
e45884c5e5 | ||
|
|
c618ffaff9 | ||
|
|
92ac01d8b5 | ||
|
|
b536877181 | ||
|
|
bc71caa762 | ||
|
|
d0980007d7 | ||
|
|
ea632902bb | ||
|
|
800d1d7cdd | ||
|
|
837acf2f00 | ||
|
|
631a45aff3 | ||
|
|
7a1965e402 | ||
|
|
3c2f13b967 | ||
|
|
2a481fdc54 | ||
|
|
f54400e8e2 | ||
|
|
e6ae792325 | ||
|
|
80fa21ff80 | ||
|
|
7d0e27f681 | ||
|
|
11247f17de | ||
|
|
ae1802dab9 | ||
|
|
bde56978d6 | ||
|
|
8d6cd84403 | ||
|
|
1b63dcdbf5 | ||
|
|
bf945e05e3 | ||
|
|
e074f6b226 | ||
|
|
9bd6d726f0 | ||
|
|
5c965be8c9 | ||
|
|
64bf40b3ba | ||
|
|
e6c2b46cf6 | ||
|
|
817eb8ff71 | ||
|
|
94e0d03580 | ||
|
|
add15f1385 | ||
|
|
35c60f0e0e | ||
|
|
88e7c86836 | ||
|
|
1202ebcb7f | ||
|
|
2ef71a822b |
35
.gitea/scripts/hygiene-check.sh
Normal file
35
.gitea/scripts/hygiene-check.sh
Normal file
@@ -0,0 +1,35 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
fail=0
|
||||
|
||||
while IFS= read -r -d '' path; do
|
||||
base="$(basename "$path")"
|
||||
case "$base" in
|
||||
.DS_Store|.env)
|
||||
echo "::error file=$path::tracked local-only file is forbidden"
|
||||
fail=1
|
||||
;;
|
||||
esac
|
||||
|
||||
case "$path" in
|
||||
*node_modules/*|node_modules/*)
|
||||
echo "::error file=$path::tracked node_modules content is forbidden"
|
||||
fail=1
|
||||
;;
|
||||
*.tmp|*.temp|*.bak|*.orig|*.rej|*.zip|*.tar|*.tar.gz|*.tgz|*.rar|*.7z)
|
||||
echo "::error file=$path::tracked temporary/archive artifact is forbidden"
|
||||
fail=1
|
||||
;;
|
||||
esac
|
||||
|
||||
if [ -f "$path" ]; then
|
||||
size="$(wc -c < "$path" | tr -d ' ')"
|
||||
if [ "${size:-0}" -gt 52428800 ]; then
|
||||
echo "::error file=$path::tracked file is larger than 50 MiB"
|
||||
fail=1
|
||||
fi
|
||||
fi
|
||||
done < <(git ls-files -z)
|
||||
|
||||
exit "$fail"
|
||||
@@ -5,8 +5,15 @@ on:
|
||||
pull_request:
|
||||
|
||||
jobs:
|
||||
hygiene:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- run: bash .gitea/scripts/hygiene-check.sh
|
||||
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
needs: hygiene
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/setup-go@v5
|
||||
|
||||
@@ -58,8 +58,11 @@ jobs:
|
||||
server=${{ env.NODE_REGISTRY }}/admin/ai-service:${{ github.sha }}
|
||||
kubectl -n ai-service set image deployment/ai-service-worker \
|
||||
worker=${{ env.NODE_REGISTRY }}/admin/ai-service:${{ github.sha }}
|
||||
kubectl -n ai-service set image deployment/ai-service-analysis-worker \
|
||||
worker=${{ env.NODE_REGISTRY }}/admin/ai-service:${{ github.sha }}
|
||||
kubectl -n ai-service set image deployment/ai-service-transcription-worker \
|
||||
worker=${{ env.NODE_REGISTRY }}/admin/ai-service:${{ github.sha }}
|
||||
kubectl -n ai-service rollout status deployment/ai-service --timeout=180s
|
||||
kubectl -n ai-service rollout status deployment/ai-service-worker --timeout=180s
|
||||
kubectl -n ai-service rollout status deployment/ai-service-analysis-worker --timeout=180s
|
||||
kubectl -n ai-service rollout status deployment/ai-service-transcription-worker --timeout=180s
|
||||
|
||||
124
README.md
124
README.md
@@ -2,8 +2,9 @@
|
||||
|
||||
Technical AI job service for Portal workloads.
|
||||
|
||||
The first version owns only AI job lifecycle and metrics. Business data stays in
|
||||
domain services such as `telephony`, `monitoring-tg` and `monitoring-pf`.
|
||||
AI Service owns technical AI job lifecycle, provider execution and metrics.
|
||||
Business data stays in domain services such as `telephony`, `monitoring-tg` and
|
||||
`monitoring-pf`.
|
||||
|
||||
## Generic job contract
|
||||
|
||||
@@ -14,8 +15,9 @@ The service is intentionally domain-agnostic:
|
||||
- `owner_ref` is the caller's stable object reference, for example
|
||||
`beeline/{call_id}` or `channel/{message_id}`.
|
||||
- `task_type` describes the technical task class, for example
|
||||
`transcribe`, `call_analysis`, `tg_analysis`, `pf_competitor_analysis`.
|
||||
- `model_profile` selects a runtime profile, for example `whisperx`,
|
||||
`transcription`, `transcript_summary`, `call_analysis`,
|
||||
`telegram_classification`, `tg_analysis`, `pf_competitor_analysis`.
|
||||
- `model_profile` selects a runtime profile, for example `whisper-large-v3`,
|
||||
`qwen2.5-14b`, `vision`, or a future provider profile.
|
||||
- `input` and `result` are JSON payloads owned by the caller and worker.
|
||||
|
||||
@@ -24,8 +26,9 @@ service.
|
||||
|
||||
## Built-in workers
|
||||
|
||||
The first built-in worker processes `llm_chat`, `chat_completion` and
|
||||
`call_analysis` jobs whose `model_profile` equals `LLM_MODEL`.
|
||||
The LLM worker processes `llm_chat`, `chat_completion`, `call_analysis`,
|
||||
`transcript_summary` and `telegram_classification` jobs whose `model_profile`
|
||||
equals `LLM_MODEL`.
|
||||
|
||||
Input can be either explicit messages:
|
||||
|
||||
@@ -40,39 +43,34 @@ Input can be either explicit messages:
|
||||
```
|
||||
|
||||
or compact `system` / `user` fields. The completed job result contains
|
||||
`content`, `model`, `usage` and `duration_ms`.
|
||||
`schema_version=ai.chat_result.v1`, `content`, `model`, `usage` and
|
||||
`duration_ms`.
|
||||
|
||||
`call_analysis` uses the same input contract as `llm_chat`; callers may include
|
||||
domain metadata fields in `input`, but the worker only reads chat fields such as
|
||||
`system`, `user`, `messages`, `max_tokens` and `response_format`.
|
||||
`call_analysis` and `transcript_summary` use the same input contract as
|
||||
`llm_chat`; callers may include domain metadata fields in `input`, but the
|
||||
worker only reads chat fields such as `system`, `user`, `messages`,
|
||||
`max_tokens` and `response_format`.
|
||||
|
||||
`transcription` jobs can run several transcription providers in order for
|
||||
temporary A/B comparison. The main `segments` field remains compatible with
|
||||
telephony and contains the first successful provider result. The full comparison
|
||||
is stored in `attempts` with `provider`, `model`, `status`, `text`, `segments`,
|
||||
`duration_ms` and `error`.
|
||||
`transcription` jobs are processed only by Whisper Large v3
|
||||
(`openai/whisper-large-v3`) through an OpenAI-compatible
|
||||
`/v1/audio/transcriptions` endpoint. The returned `segments` field stays
|
||||
compatible with telephony. If the provider returns one long segment, AI Service
|
||||
splits it into smaller transcript segments without inventing speaker labels.
|
||||
The completed job result contains
|
||||
`schema_version=ai.transcription_result.v1`, `provider`, `model`, `language`,
|
||||
`segments`, optional provider `attempts` and `duration_ms`.
|
||||
|
||||
Recommended comparison order:
|
||||
AI-server compose snippet for Whisper Large v3 lives in
|
||||
`deploy/ai-server/docker-compose.audio.yml`:
|
||||
|
||||
1. `whisperx`
|
||||
2. `qwen2-audio` (`Qwen/Qwen2-Audio-7B-Instruct`)
|
||||
3. `voxtral-small` (`mistralai/Voxtral-Small-24B-2507`)
|
||||
- Whisper endpoint: `http://10.2.3.5:8004`
|
||||
- Start Whisper:
|
||||
`docker compose -f docker-compose.yml -f docker-compose.audio.yml --profile whisper-large-v3 up -d whisper-large-v3`
|
||||
|
||||
Qwen2-Audio and Voxtral are called through an OpenAI-compatible
|
||||
`/v1/chat/completions` endpoint with `input_audio`; set their endpoint URLs only
|
||||
after the models are actually exposed on the AI server.
|
||||
|
||||
AI-server compose snippets for these temporary comparison endpoints live in
|
||||
`deploy/ai-server/docker-compose.audio.yml`. They are profile-gated because the
|
||||
single GPU cannot keep the production text vLLM, two WhisperX instances, Qwen2
|
||||
Audio and Voxtral loaded at the same time:
|
||||
|
||||
- Qwen2-Audio endpoint: `http://10.2.3.5:8003`
|
||||
- Voxtral endpoint: `http://10.2.3.5:8004`
|
||||
- Start Qwen only:
|
||||
`docker compose -f docker-compose.yml -f docker-compose.audio.yml --profile qwen-audio up -d qwen-audio`
|
||||
- Start Voxtral only:
|
||||
`docker compose -f docker-compose.yml -f docker-compose.audio.yml --profile voxtral-small up -d voxtral-small`
|
||||
In Kubernetes the dedicated transcription worker may claim more than one
|
||||
`whisper-large-v3` job at a time. This keeps download/upload/wait overhead from
|
||||
serializing the queue while the Whisper provider still controls the actual GPU
|
||||
scheduling.
|
||||
|
||||
## API
|
||||
|
||||
@@ -90,7 +88,9 @@ Audio and Voxtral loaded at the same time:
|
||||
- `GET /api/v1/providers/status` checks configured AI providers without
|
||||
returning secrets.
|
||||
- `GET /api/v1/infra/status` returns AI-server sidecar telemetry
|
||||
(GPU, containers, vLLM and WhisperX live metrics) when configured.
|
||||
(GPU, containers and vLLM live metrics) when configured.
|
||||
- `GET /health/detail` returns PostgreSQL, provider, queue, error, throughput
|
||||
and infra components for Portal `admin/health`.
|
||||
- `GET /healthz` returns process health.
|
||||
- `GET /readyz` checks PostgreSQL readiness.
|
||||
- Built-in workers expose open Kubernetes endpoints on `WORKER_HTTP_PORT`:
|
||||
@@ -100,6 +100,34 @@ All `/api/v1/*` endpoints require `Authorization: Bearer <AI_SERVICE_TOKEN>`
|
||||
when `AI_SERVICE_TOKEN` is configured. Health and readiness endpoints stay open
|
||||
for Kubernetes probes.
|
||||
|
||||
## Retry policy
|
||||
|
||||
Workers store a normalized `error_code` on failed jobs. AI Service requeues only
|
||||
explicitly retryable categories while attempts remain.
|
||||
|
||||
| Category | Retry | Delay |
|
||||
| --- | --- | --- |
|
||||
| `provider_unavailable`, `model_unavailable`, `provider_error`, `dependency_error`, `timeout`, `storage_error`, `stale_worker` | yes | 30s |
|
||||
| `bad_response`, `transcript_hallucination`, `transcript_incomplete`, `internal_error`, `unknown` | yes | 2m |
|
||||
| `bad_audio`, `bad_input`, `context_length`, `unsupported_task`, `cancelled` | no | - |
|
||||
|
||||
Domain services may still expose manual retry for terminal errors after the
|
||||
underlying data or prompt is corrected.
|
||||
|
||||
## Result schemas
|
||||
|
||||
AI Service result payloads are versioned with `schema_version`. Consumers should
|
||||
ignore unknown fields and reject only unsupported major schema names.
|
||||
|
||||
Current schemas:
|
||||
|
||||
- `ai.chat_result.v1`: `{schema_version, content, model, usage?, duration_ms}`.
|
||||
- `ai.transcription_result.v1`:
|
||||
`{schema_version, provider?, model?, attempts?, language, segments, duration_ms}`.
|
||||
|
||||
New optional fields may be added to a `v1` schema without a breaking change.
|
||||
Breaking shape changes require a new schema name.
|
||||
|
||||
## Configuration
|
||||
|
||||
- `HTTP_HOST`, default `0.0.0.0`
|
||||
@@ -111,19 +139,11 @@ for Kubernetes probes.
|
||||
- `LLM_API_KEY`, primary LLM API key
|
||||
- `LLM_MODEL`, default `qwen2.5-14b`
|
||||
- `LLM_TIMEOUT`, default `5m`
|
||||
- `TRANSCRIPTION_PROVIDERS`, default `whisperx`, comma-separated ordered list:
|
||||
`whisperx,qwen2-audio,voxtral-small`
|
||||
- `WHISPERX_URL`, WhisperX endpoint for transcription jobs
|
||||
- `QWEN_AUDIO_BASE_URL`, OpenAI-compatible endpoint for Qwen2-Audio
|
||||
- `QWEN_AUDIO_MODEL`, default `Qwen/Qwen2-Audio-7B-Instruct`
|
||||
- `QWEN_AUDIO_API_KEY`, optional bearer token for Qwen2-Audio; falls back to
|
||||
- `AUDIO_TRANSCRIPTION_BASE_URL`, OpenAI-compatible transcription endpoint
|
||||
- `AUDIO_TRANSCRIPTION_MODEL`, default `openai/whisper-large-v3`
|
||||
- `AUDIO_TRANSCRIPTION_API_KEY`, optional bearer token; falls back to
|
||||
`AUDIO_LLM_API_KEY`, then `LLM_API_KEY`
|
||||
- `VOXTRAL_BASE_URL`, OpenAI-compatible endpoint for Voxtral
|
||||
- `VOXTRAL_MODEL`, default `mistralai/Voxtral-Small-24B-2507`
|
||||
- `VOXTRAL_API_KEY`, optional bearer token for Voxtral; falls back to
|
||||
`AUDIO_LLM_API_KEY`, then `LLM_API_KEY`
|
||||
- `AUDIO_LLM_PROMPT`, transcription instruction for audio LLM providers
|
||||
- `AUDIO_LLM_MAX_TOKENS`, default `4096`
|
||||
- `AUDIO_TRANSCRIPTION_PROMPT`, transcription instruction
|
||||
- `WORKER_ID`, default hostname
|
||||
- `WORKER_HTTP_HOST`, default `0.0.0.0`
|
||||
- `WORKER_HTTP_PORT`, default `8081`
|
||||
@@ -131,8 +151,10 @@ for Kubernetes probes.
|
||||
- `WORKER_CLAIM_LIMIT`, default `4`
|
||||
- `WORKER_LEASE_TIMEOUT`, default `15m`
|
||||
|
||||
## Next integration step
|
||||
## Current telephony pipeline
|
||||
|
||||
`telephony` should first mirror low-risk analysis jobs into this service while
|
||||
continuing local processing. Remote execution can then be enabled by feature
|
||||
flag per task type.
|
||||
`telephony` now uses AI Service as the only AI execution path:
|
||||
|
||||
1. `transcription` turns call audio into segments.
|
||||
2. `transcript_summary` creates a detailed Russian call summary.
|
||||
3. `call_analysis` runs tags and negotiation rules against the summary.
|
||||
|
||||
@@ -49,21 +49,11 @@ func main() {
|
||||
|
||||
llmClient := llm.New(cfg.LLMBaseURL, cfg.LLMAPIKey, cfg.LLMModel, cfg.LLMTimeout)
|
||||
transcriber := transcription.NewWithOptions(transcription.Options{
|
||||
Providers: cfg.TranscriptionProviders,
|
||||
WhisperXURL: cfg.WhisperXURL,
|
||||
WhisperXTimeout: cfg.WhisperXTimeout,
|
||||
FfmpegPath: cfg.FfmpegPath,
|
||||
LeadSilence: cfg.WhisperXLeadSilence,
|
||||
QwenAudioBaseURL: cfg.QwenAudioBaseURL,
|
||||
QwenAudioAPIKey: cfg.QwenAudioAPIKey,
|
||||
QwenAudioModel: cfg.QwenAudioModel,
|
||||
QwenAudioTimeout: cfg.QwenAudioTimeout,
|
||||
VoxtralBaseURL: cfg.VoxtralBaseURL,
|
||||
VoxtralAPIKey: cfg.VoxtralAPIKey,
|
||||
VoxtralModel: cfg.VoxtralModel,
|
||||
VoxtralTimeout: cfg.VoxtralTimeout,
|
||||
AudioLLMPrompt: cfg.AudioLLMPrompt,
|
||||
AudioLLMMaxTokens: cfg.AudioLLMMaxTokens,
|
||||
AudioBaseURL: cfg.AudioBaseURL,
|
||||
AudioAPIKey: cfg.AudioAPIKey,
|
||||
AudioModel: cfg.AudioModel,
|
||||
AudioTimeout: cfg.AudioTimeout,
|
||||
AudioPrompt: cfg.AudioPrompt,
|
||||
})
|
||||
w := worker.New(db, llmClient, transcriber, cfg.WorkerID, cfg.LLMModel, cfg.WorkerTaskTypes, cfg.WorkerModelProfiles, cfg.WorkerPollInterval, cfg.WorkerLeaseTimeout, cfg.WorkerClaimLimit)
|
||||
healthSrv := startHealthServer(ctx, db, cfg)
|
||||
@@ -72,8 +62,8 @@ func main() {
|
||||
"worker_id", cfg.WorkerID,
|
||||
"model", cfg.LLMModel,
|
||||
"transcription_enabled", transcriber != nil,
|
||||
"transcription_providers", cfg.TranscriptionProviders,
|
||||
"whisperx_lead_silence", cfg.WhisperXLeadSilence.String(),
|
||||
"transcription_provider", transcription.ProviderWhisperLargeV3,
|
||||
"transcription_model", cfg.AudioModel,
|
||||
"task_types", cfg.WorkerTaskTypes,
|
||||
"model_profiles", cfg.WorkerModelProfiles,
|
||||
"poll_interval", cfg.WorkerPollInterval.String(),
|
||||
@@ -144,7 +134,8 @@ func (h workerHealth) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
"worker_id": h.cfg.WorkerID,
|
||||
"task_types": h.cfg.WorkerTaskTypes,
|
||||
"model_profiles": h.cfg.WorkerModelProfiles,
|
||||
"transcription_providers": h.cfg.TranscriptionProviders,
|
||||
"transcription_provider": transcription.ProviderWhisperLargeV3,
|
||||
"transcription_model": h.cfg.AudioModel,
|
||||
"claim_limit": h.cfg.WorkerClaimLimit,
|
||||
"poll_interval": h.cfg.WorkerPollInterval.String(),
|
||||
"lease_timeout": h.cfg.WorkerLeaseTimeout.String(),
|
||||
|
||||
@@ -1,63 +1,12 @@
|
||||
services:
|
||||
qwen-audio:
|
||||
image: vllm/vllm-openai:latest
|
||||
container_name: qwen-audio
|
||||
whisper-large-v3:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: vllm-audio.Dockerfile
|
||||
image: vllm-audio:local
|
||||
container_name: whisper-large-v3
|
||||
profiles:
|
||||
- qwen-audio
|
||||
- audio-compare
|
||||
restart: unless-stopped
|
||||
ipc: host
|
||||
runtime: nvidia
|
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
devices:
|
||||
- driver: nvidia
|
||||
count: all
|
||||
capabilities: [gpu]
|
||||
environment:
|
||||
HUGGING_FACE_HUB_TOKEN: ${HF_TOKEN}
|
||||
VLLM_API_KEY: ${VLLM_API_KEY}
|
||||
HF_HOME: /cache
|
||||
volumes:
|
||||
- ./data/vllm-cache:/cache
|
||||
networks:
|
||||
- audio-models
|
||||
ports:
|
||||
- "10.2.3.5:8003:8000"
|
||||
command:
|
||||
- "--model"
|
||||
- "Qwen/Qwen2-Audio-7B-Instruct"
|
||||
- "--served-model-name"
|
||||
- "Qwen/Qwen2-Audio-7B-Instruct"
|
||||
- "--trust-remote-code"
|
||||
- "--host"
|
||||
- "0.0.0.0"
|
||||
- "--port"
|
||||
- "8000"
|
||||
- "--max-model-len"
|
||||
- "8192"
|
||||
- "--gpu-memory-utilization"
|
||||
- "0.25"
|
||||
- "--api-key"
|
||||
- "${VLLM_API_KEY}"
|
||||
- "--max-num-seqs"
|
||||
- "4"
|
||||
- "--max-num-batched-tokens"
|
||||
- "4096"
|
||||
healthcheck:
|
||||
test: ["CMD", "python3", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')"]
|
||||
interval: 30s
|
||||
timeout: 5s
|
||||
retries: 5
|
||||
start_period: 900s
|
||||
|
||||
voxtral-small:
|
||||
image: vllm/vllm-openai:latest
|
||||
container_name: voxtral-small
|
||||
profiles:
|
||||
- voxtral-small
|
||||
- audio-compare
|
||||
- whisper-large-v3
|
||||
restart: unless-stopped
|
||||
ipc: host
|
||||
runtime: nvidia
|
||||
@@ -80,32 +29,17 @@ services:
|
||||
- "10.2.3.5:8004:8000"
|
||||
command:
|
||||
- "--model"
|
||||
- "mistralai/Voxtral-Small-24B-2507"
|
||||
- "openai/whisper-large-v3"
|
||||
- "--served-model-name"
|
||||
- "mistralai/Voxtral-Small-24B-2507"
|
||||
- "--tokenizer-mode"
|
||||
- "mistral"
|
||||
- "--config-format"
|
||||
- "mistral"
|
||||
- "--load-format"
|
||||
- "mistral"
|
||||
- "--tool-call-parser"
|
||||
- "mistral"
|
||||
- "--enable-auto-tool-choice"
|
||||
- "openai/whisper-large-v3"
|
||||
- "--host"
|
||||
- "0.0.0.0"
|
||||
- "--port"
|
||||
- "8000"
|
||||
- "--max-model-len"
|
||||
- "32768"
|
||||
- "--gpu-memory-utilization"
|
||||
- "0.62"
|
||||
- "0.55"
|
||||
- "--api-key"
|
||||
- "${VLLM_API_KEY}"
|
||||
- "--max-num-seqs"
|
||||
- "2"
|
||||
- "--max-num-batched-tokens"
|
||||
- "8192"
|
||||
healthcheck:
|
||||
test: ["CMD", "python3", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')"]
|
||||
interval: 30s
|
||||
|
||||
3
deploy/ai-server/vllm-audio.Dockerfile
Normal file
3
deploy/ai-server/vllm-audio.Dockerfile
Normal file
@@ -0,0 +1,3 @@
|
||||
FROM vllm/vllm-openai:latest
|
||||
|
||||
RUN python3 -m pip install --no-cache-dir av soundfile librosa
|
||||
@@ -1,20 +0,0 @@
|
||||
upstream whisperx_upstream {
|
||||
server whisperx-1:8000 max_fails=3 fail_timeout=30s;
|
||||
}
|
||||
|
||||
server {
|
||||
listen 80 default_server;
|
||||
client_max_body_size 200m;
|
||||
|
||||
location / {
|
||||
proxy_pass http://whisperx_upstream;
|
||||
proxy_set_header Host $host;
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
proxy_connect_timeout 30s;
|
||||
proxy_send_timeout 10m;
|
||||
proxy_read_timeout 10m;
|
||||
proxy_request_buffering off;
|
||||
proxy_buffering off;
|
||||
}
|
||||
}
|
||||
@@ -18,21 +18,11 @@ type Config struct {
|
||||
LLMAPIKey string
|
||||
LLMModel string
|
||||
LLMTimeout time.Duration
|
||||
TranscriptionProviders []string
|
||||
WhisperXURL string
|
||||
WhisperXTimeout time.Duration
|
||||
WhisperXLeadSilence time.Duration
|
||||
QwenAudioBaseURL string
|
||||
QwenAudioAPIKey string
|
||||
QwenAudioModel string
|
||||
QwenAudioTimeout time.Duration
|
||||
VoxtralBaseURL string
|
||||
VoxtralAPIKey string
|
||||
VoxtralModel string
|
||||
VoxtralTimeout time.Duration
|
||||
AudioLLMMaxTokens int
|
||||
AudioLLMPrompt string
|
||||
FfmpegPath string
|
||||
AudioBaseURL string
|
||||
AudioAPIKey string
|
||||
AudioModel string
|
||||
AudioTimeout time.Duration
|
||||
AudioPrompt string
|
||||
AIStatsSidecarURL string
|
||||
AIStatsTimeout time.Duration
|
||||
|
||||
@@ -58,21 +48,11 @@ func Load() Config {
|
||||
LLMAPIKey: envString("LLM_API_KEY", ""),
|
||||
LLMModel: envString("LLM_MODEL", "qwen2.5-14b"),
|
||||
LLMTimeout: envDuration("LLM_TIMEOUT", 5*time.Minute),
|
||||
TranscriptionProviders: envCSVDefault("TRANSCRIPTION_PROVIDERS", []string{"whisperx"}),
|
||||
WhisperXURL: envString("WHISPERX_URL", ""),
|
||||
WhisperXTimeout: envDuration("WHISPERX_TIMEOUT", 10*time.Minute),
|
||||
WhisperXLeadSilence: envDuration("WHISPERX_LEAD_SILENCE", 800*time.Millisecond),
|
||||
QwenAudioBaseURL: envString("QWEN_AUDIO_BASE_URL", envString("AUDIO_LLM_BASE_URL", "")),
|
||||
QwenAudioAPIKey: envString("QWEN_AUDIO_API_KEY", envString("AUDIO_LLM_API_KEY", envString("LLM_API_KEY", ""))),
|
||||
QwenAudioModel: envString("QWEN_AUDIO_MODEL", "Qwen/Qwen2-Audio-7B-Instruct"),
|
||||
QwenAudioTimeout: envDuration("QWEN_AUDIO_TIMEOUT", envDuration("AUDIO_LLM_TIMEOUT", 10*time.Minute)),
|
||||
VoxtralBaseURL: envString("VOXTRAL_BASE_URL", envString("AUDIO_LLM_BASE_URL", "")),
|
||||
VoxtralAPIKey: envString("VOXTRAL_API_KEY", envString("AUDIO_LLM_API_KEY", envString("LLM_API_KEY", ""))),
|
||||
VoxtralModel: envString("VOXTRAL_MODEL", "mistralai/Voxtral-Small-24B-2507"),
|
||||
VoxtralTimeout: envDuration("VOXTRAL_TIMEOUT", envDuration("AUDIO_LLM_TIMEOUT", 10*time.Minute)),
|
||||
AudioLLMMaxTokens: envInt("AUDIO_LLM_MAX_TOKENS", 4096),
|
||||
AudioLLMPrompt: envString("AUDIO_LLM_PROMPT", defaultAudioLLMPrompt()),
|
||||
FfmpegPath: envString("FFMPEG_PATH", "/usr/bin/ffmpeg"),
|
||||
AudioBaseURL: envString("AUDIO_TRANSCRIPTION_BASE_URL", ""),
|
||||
AudioAPIKey: envString("AUDIO_TRANSCRIPTION_API_KEY", ""),
|
||||
AudioModel: envString("AUDIO_TRANSCRIPTION_MODEL", "openai/whisper-large-v3"),
|
||||
AudioTimeout: envDuration("AUDIO_TRANSCRIPTION_TIMEOUT", 10*time.Minute),
|
||||
AudioPrompt: envString("AUDIO_TRANSCRIPTION_PROMPT", defaultAudioPrompt()),
|
||||
AIStatsSidecarURL: envString("AI_STATS_SIDECAR_URL", ""),
|
||||
AIStatsTimeout: envDuration("AI_STATS_TIMEOUT", 8*time.Second),
|
||||
|
||||
@@ -145,15 +125,8 @@ func envCSV(key string) []string {
|
||||
return out
|
||||
}
|
||||
|
||||
func envCSVDefault(key string, fallback []string) []string {
|
||||
if values := envCSV(key); len(values) > 0 {
|
||||
return values
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
func defaultAudioLLMPrompt() string {
|
||||
return "Расшифруй речь из аудио максимально точно. Сохрани русский язык, имена, телефоны, суммы и смысловые паузы. Не добавляй комментарии, анализ, Markdown или JSON. Верни только чистый текст расшифровки."
|
||||
func defaultAudioPrompt() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func hostname() string {
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"time"
|
||||
|
||||
"ai-service/internal/model"
|
||||
"ai-service/internal/transcription"
|
||||
)
|
||||
|
||||
type dashboardResponse struct {
|
||||
@@ -30,7 +31,7 @@ func (s *Server) handleDashboard(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, cancel := contextWithTimeout(r, 12*time.Second)
|
||||
defer cancel()
|
||||
|
||||
stats, err := s.store.Stats(ctx)
|
||||
stats, err := s.store.Stats(ctx, s.cfg.WorkerLeaseTimeout)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
@@ -43,7 +44,6 @@ func (s *Server) handleDashboard(w http.ResponseWriter, r *http.Request) {
|
||||
writeError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
resp := dashboardResponse{
|
||||
At: now,
|
||||
Summary: summarizeQueues(stats),
|
||||
@@ -52,9 +52,7 @@ func (s *Server) handleDashboard(w http.ResponseWriter, r *http.Request) {
|
||||
At: now,
|
||||
Providers: []providerStatus{
|
||||
s.checkLLM(ctx),
|
||||
s.checkWhisperX(ctx),
|
||||
s.checkAudioLLM(ctx, "qwen2-audio", s.cfg.QwenAudioBaseURL, s.cfg.QwenAudioAPIKey, s.cfg.QwenAudioModel, s.cfg.QwenAudioTimeout),
|
||||
s.checkAudioLLM(ctx, "voxtral-small", s.cfg.VoxtralBaseURL, s.cfg.VoxtralAPIKey, s.cfg.VoxtralModel, s.cfg.VoxtralTimeout),
|
||||
s.checkAudioLLM(ctx, transcription.ProviderWhisperLargeV3, s.cfg.AudioBaseURL, s.cfg.AudioAPIKey, s.cfg.AudioModel, s.cfg.AudioTimeout),
|
||||
},
|
||||
},
|
||||
Infra: loadInfraSnapshot(r, s.cfg),
|
||||
|
||||
241
internal/httpapi/health.go
Normal file
241
internal/httpapi/health.go
Normal file
@@ -0,0 +1,241 @@
|
||||
package httpapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"ai-service/internal/model"
|
||||
"ai-service/internal/transcription"
|
||||
)
|
||||
|
||||
type healthDetailResponse struct {
|
||||
Status string `json:"status"`
|
||||
Generated time.Time `json:"generated_at"`
|
||||
Components []healthComponent `json:"components"`
|
||||
}
|
||||
|
||||
type healthComponent struct {
|
||||
Name string `json:"name"`
|
||||
Status string `json:"status"`
|
||||
Error string `json:"error,omitempty"`
|
||||
Data map[string]any `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
func (s *Server) handleHealthDetail(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, cancel := contextWithTimeout(r, 12*time.Second)
|
||||
defer cancel()
|
||||
|
||||
resp := healthDetailResponse{
|
||||
Status: "ok",
|
||||
Generated: time.Now().UTC(),
|
||||
}
|
||||
|
||||
if err := s.store.Ping(ctx); err != nil {
|
||||
resp.Components = append(resp.Components, healthComponent{
|
||||
Name: "postgres",
|
||||
Status: "down",
|
||||
Error: err.Error(),
|
||||
})
|
||||
resp.Status = worseHealthStatus(resp.Status, "down")
|
||||
writeJSON(w, http.StatusServiceUnavailable, resp)
|
||||
return
|
||||
}
|
||||
resp.Components = append(resp.Components, healthComponent{Name: "postgres", Status: "ok"})
|
||||
|
||||
stats, err := s.store.Stats(ctx, s.cfg.WorkerLeaseTimeout)
|
||||
if err != nil {
|
||||
resp.Components = append(resp.Components, healthComponent{
|
||||
Name: "queue",
|
||||
Status: "down",
|
||||
Error: err.Error(),
|
||||
})
|
||||
resp.Status = worseHealthStatus(resp.Status, "down")
|
||||
writeJSON(w, http.StatusServiceUnavailable, resp)
|
||||
return
|
||||
}
|
||||
|
||||
for _, component := range []healthComponent{
|
||||
s.healthProviders(ctx),
|
||||
healthQueue(stats),
|
||||
healthErrors(stats),
|
||||
healthThroughput(stats),
|
||||
healthInfra(loadInfraSnapshot(r, s.cfg)),
|
||||
} {
|
||||
resp.Components = append(resp.Components, component)
|
||||
resp.Status = worseHealthStatus(resp.Status, component.Status)
|
||||
}
|
||||
|
||||
statusCode := http.StatusOK
|
||||
if resp.Status == "down" {
|
||||
statusCode = http.StatusServiceUnavailable
|
||||
}
|
||||
writeJSON(w, statusCode, resp)
|
||||
}
|
||||
|
||||
func (s *Server) healthProviders(ctx context.Context) healthComponent {
|
||||
providers := []providerStatus{
|
||||
s.checkLLM(ctx),
|
||||
s.checkAudioLLM(ctx, transcription.ProviderWhisperLargeV3, s.cfg.AudioBaseURL, s.cfg.AudioAPIKey, s.cfg.AudioModel, s.cfg.AudioTimeout),
|
||||
}
|
||||
status := "ok"
|
||||
messages := make([]string, 0)
|
||||
for _, provider := range providers {
|
||||
switch {
|
||||
case !provider.Configured:
|
||||
status = worseHealthStatus(status, "degraded")
|
||||
messages = append(messages, provider.Name+" not configured")
|
||||
case !provider.OK:
|
||||
status = worseHealthStatus(status, "down")
|
||||
if provider.Error != "" {
|
||||
messages = append(messages, provider.Name+": "+provider.Error)
|
||||
} else {
|
||||
messages = append(messages, provider.Name+" unavailable")
|
||||
}
|
||||
case provider.Stale:
|
||||
status = worseHealthStatus(status, "degraded")
|
||||
if provider.Error != "" {
|
||||
messages = append(messages, provider.Name+": "+provider.Error)
|
||||
}
|
||||
}
|
||||
}
|
||||
return healthComponent{
|
||||
Name: "providers",
|
||||
Status: status,
|
||||
Error: strings.Join(messages, "; "),
|
||||
Data: map[string]any{
|
||||
"providers": providers,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func healthQueue(stats *model.Stats) healthComponent {
|
||||
var pending, running, staleRunning int64
|
||||
var oldestPendingAgeSeconds, oldestRunningAgeSeconds int64
|
||||
for _, row := range stats.Backlog {
|
||||
pending += row.Pending
|
||||
running += row.Running
|
||||
staleRunning += row.StaleRunning
|
||||
if row.OldestPendingAgeSeconds > oldestPendingAgeSeconds {
|
||||
oldestPendingAgeSeconds = row.OldestPendingAgeSeconds
|
||||
}
|
||||
if row.OldestRunningAgeSeconds > oldestRunningAgeSeconds {
|
||||
oldestRunningAgeSeconds = row.OldestRunningAgeSeconds
|
||||
}
|
||||
}
|
||||
status := "ok"
|
||||
message := ""
|
||||
if staleRunning > 0 {
|
||||
status = "degraded"
|
||||
message = "there are stale running jobs"
|
||||
}
|
||||
return healthComponent{
|
||||
Name: "queue",
|
||||
Status: status,
|
||||
Error: message,
|
||||
Data: map[string]any{
|
||||
"pending": pending,
|
||||
"running": running,
|
||||
"stale_running": staleRunning,
|
||||
"oldest_pending_age_seconds": oldestPendingAgeSeconds,
|
||||
"oldest_running_age_seconds": oldestRunningAgeSeconds,
|
||||
"backlog": stats.Backlog,
|
||||
"queue_status_totals": stats.Queues,
|
||||
"owner_status_totals": stats.Owners,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func healthErrors(stats *model.Stats) healthComponent {
|
||||
var failedTotal, failed24h int64
|
||||
for _, row := range stats.Errors {
|
||||
failedTotal += row.Total
|
||||
failed24h += row.Last24h
|
||||
}
|
||||
status := "ok"
|
||||
message := ""
|
||||
if failed24h > 0 {
|
||||
status = "degraded"
|
||||
message = "there are failed jobs in the last 24 hours"
|
||||
}
|
||||
return healthComponent{
|
||||
Name: "errors",
|
||||
Status: status,
|
||||
Error: message,
|
||||
Data: map[string]any{
|
||||
"failed_total": failedTotal,
|
||||
"failed_24h": failed24h,
|
||||
"by_code": stats.Errors,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func healthThroughput(stats *model.Stats) healthComponent {
|
||||
var done24h, retried24h int64
|
||||
for _, row := range stats.Stages {
|
||||
done24h += row.Done24h
|
||||
retried24h += row.Retried24h
|
||||
}
|
||||
|
||||
pendingByStage := make(map[string]int64)
|
||||
for _, row := range stats.Backlog {
|
||||
pendingByStage[row.TaskType+"|"+row.ModelProfile] += row.Pending + row.Running
|
||||
}
|
||||
doneByStage := make(map[string]int64)
|
||||
for _, row := range stats.Stages {
|
||||
doneByStage[row.TaskType+"|"+row.ModelProfile] += row.Done24h
|
||||
}
|
||||
|
||||
stuckStages := make([]string, 0)
|
||||
for key, total := range pendingByStage {
|
||||
if total > 0 && doneByStage[key] == 0 {
|
||||
stuckStages = append(stuckStages, key)
|
||||
}
|
||||
}
|
||||
|
||||
status := "ok"
|
||||
message := ""
|
||||
if len(stuckStages) > 0 {
|
||||
status = "degraded"
|
||||
message = "some active queues have no completed jobs in the last 24 hours"
|
||||
}
|
||||
return healthComponent{
|
||||
Name: "throughput",
|
||||
Status: status,
|
||||
Error: message,
|
||||
Data: map[string]any{
|
||||
"done_24h": done24h,
|
||||
"retried_24h": retried24h,
|
||||
"stuck_stages": stuckStages,
|
||||
"stages": stats.Stages,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func healthInfra(infra infraStatusResponse) healthComponent {
|
||||
status := "ok"
|
||||
message := ""
|
||||
if infra.SidecarError != "" {
|
||||
status = "degraded"
|
||||
message = infra.SidecarError
|
||||
}
|
||||
return healthComponent{
|
||||
Name: "infra",
|
||||
Status: status,
|
||||
Error: message,
|
||||
Data: map[string]any{
|
||||
"sidecar": infra.Sidecar,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func worseHealthStatus(current, next string) string {
|
||||
if current == "down" || next == "down" {
|
||||
return "down"
|
||||
}
|
||||
if current == "degraded" || next == "degraded" {
|
||||
return "degraded"
|
||||
}
|
||||
return "ok"
|
||||
}
|
||||
@@ -8,6 +8,8 @@ import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"ai-service/internal/transcription"
|
||||
)
|
||||
|
||||
type providerStatus struct {
|
||||
@@ -42,9 +44,7 @@ func (s *Server) handleProviderStatus(w http.ResponseWriter, r *http.Request) {
|
||||
At: time.Now().UTC(),
|
||||
Providers: []providerStatus{
|
||||
s.checkLLM(ctx),
|
||||
s.checkWhisperX(ctx),
|
||||
s.checkAudioLLM(ctx, "qwen2-audio", s.cfg.QwenAudioBaseURL, s.cfg.QwenAudioAPIKey, s.cfg.QwenAudioModel, s.cfg.QwenAudioTimeout),
|
||||
s.checkAudioLLM(ctx, "voxtral-small", s.cfg.VoxtralBaseURL, s.cfg.VoxtralAPIKey, s.cfg.VoxtralModel, s.cfg.VoxtralTimeout),
|
||||
s.checkAudioLLM(ctx, transcription.ProviderWhisperLargeV3, s.cfg.AudioBaseURL, s.cfg.AudioAPIKey, s.cfg.AudioModel, s.cfg.AudioTimeout),
|
||||
},
|
||||
}
|
||||
writeJSON(w, http.StatusOK, resp)
|
||||
@@ -132,47 +132,6 @@ func (s *Server) checkLLM(ctx context.Context) providerStatus {
|
||||
return st
|
||||
}
|
||||
|
||||
func (s *Server) checkWhisperX(ctx context.Context) providerStatus {
|
||||
baseURL := strings.TrimRight(strings.TrimSpace(s.cfg.WhisperXURL), "/")
|
||||
st := providerStatus{Name: "whisperx", Configured: baseURL != "", URL: baseURL}
|
||||
if !st.Configured {
|
||||
return st
|
||||
}
|
||||
paths := []string{"/health", "/healthz", "/readyz", "/"}
|
||||
var lastErr string
|
||||
for _, path := range paths {
|
||||
cctx, cancel := context.WithTimeout(ctx, 2*time.Second)
|
||||
start := time.Now()
|
||||
req, err := http.NewRequestWithContext(cctx, http.MethodGet, baseURL+path, nil)
|
||||
if err != nil {
|
||||
cancel()
|
||||
lastErr = err.Error()
|
||||
continue
|
||||
}
|
||||
res, err := (&http.Client{Timeout: 2 * time.Second}).Do(req)
|
||||
st.LatencyMS = time.Since(start).Milliseconds()
|
||||
cancel()
|
||||
if err != nil {
|
||||
lastErr = err.Error()
|
||||
continue
|
||||
}
|
||||
body := ""
|
||||
if res.StatusCode >= 300 {
|
||||
body = readSmallBody(res.Body)
|
||||
}
|
||||
_ = res.Body.Close()
|
||||
if res.StatusCode >= 300 {
|
||||
lastErr = fmt.Sprintf("%s http %d: %s", path, res.StatusCode, body)
|
||||
continue
|
||||
}
|
||||
st.OK = true
|
||||
s.rememberProviderOK("whisperx", st.LatencyMS)
|
||||
return st
|
||||
}
|
||||
st.Error = lastErr
|
||||
return s.withStaleProviderOK("whisperx", st)
|
||||
}
|
||||
|
||||
func (s *Server) rememberProviderOK(name string, latencyMS int64) {
|
||||
s.providerMu.Lock()
|
||||
defer s.providerMu.Unlock()
|
||||
|
||||
@@ -41,6 +41,8 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
writeJSON(w, http.StatusOK, map[string]string{"status": "ok"})
|
||||
case r.Method == http.MethodGet && path == "/readyz":
|
||||
s.handleReady(w, r)
|
||||
case r.Method == http.MethodGet && path == "/health/detail":
|
||||
s.handleHealthDetail(w, r)
|
||||
case r.Method == http.MethodGet && path == "/":
|
||||
writeJSON(w, http.StatusOK, map[string]string{"service": "ai-service"})
|
||||
case r.Method == http.MethodPost && path == "/api/v1/jobs":
|
||||
@@ -59,6 +61,8 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
s.handleGetJob(w, r, path)
|
||||
case r.Method == http.MethodPost && strings.HasPrefix(path, "/api/v1/jobs/") && strings.HasSuffix(path, "/retry"):
|
||||
s.handleRetryJob(w, r, path)
|
||||
case r.Method == http.MethodPost && strings.HasPrefix(path, "/api/v1/jobs/") && strings.HasSuffix(path, "/cancel"):
|
||||
s.handleCancelJob(w, r, path)
|
||||
case r.Method == http.MethodPost && strings.HasPrefix(path, "/api/v1/jobs/") && strings.HasSuffix(path, "/complete"):
|
||||
s.handleCompleteJob(w, r, path)
|
||||
case r.Method == http.MethodPost && strings.HasPrefix(path, "/api/v1/jobs/") && strings.HasSuffix(path, "/fail"):
|
||||
@@ -265,7 +269,7 @@ func (s *Server) handleGetJob(w http.ResponseWriter, r *http.Request, path strin
|
||||
}
|
||||
|
||||
func (s *Server) handleRetryJob(w http.ResponseWriter, r *http.Request, path string) {
|
||||
id, err := jobIDFromPath(path, true)
|
||||
id, err := jobIDFromActionPath(path, "retry")
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
@@ -284,6 +288,26 @@ func (s *Server) handleRetryJob(w http.ResponseWriter, r *http.Request, path str
|
||||
writeJSON(w, http.StatusOK, job)
|
||||
}
|
||||
|
||||
func (s *Server) handleCancelJob(w http.ResponseWriter, r *http.Request, path string) {
|
||||
id, err := jobIDFromActionPath(path, "cancel")
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
ctx, cancel := contextWithTimeout(r, 8*time.Second)
|
||||
defer cancel()
|
||||
job, err := s.store.CancelJob(ctx, id)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
if job == nil {
|
||||
writeError(w, http.StatusNotFound, "cancellable job not found")
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, job)
|
||||
}
|
||||
|
||||
func (s *Server) handleCompleteJob(w http.ResponseWriter, r *http.Request, path string) {
|
||||
id, err := jobIDFromActionPath(path, "complete")
|
||||
if err != nil {
|
||||
@@ -337,7 +361,7 @@ func (s *Server) handleFailJob(w http.ResponseWriter, r *http.Request, path stri
|
||||
func (s *Server) handleStats(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, cancel := contextWithTimeout(r, 8*time.Second)
|
||||
defer cancel()
|
||||
stats, err := s.store.Stats(ctx)
|
||||
stats, err := s.store.Stats(ctx, s.cfg.WorkerLeaseTimeout)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
|
||||
@@ -38,7 +38,10 @@ type Usage struct {
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
}
|
||||
|
||||
const ChatResultSchemaVersion = "ai.chat_result.v1"
|
||||
|
||||
type ChatResult struct {
|
||||
SchemaVersion string `json:"schema_version"`
|
||||
Content string `json:"content"`
|
||||
Model string `json:"model"`
|
||||
Usage *Usage `json:"usage,omitempty"`
|
||||
@@ -137,6 +140,7 @@ func (c *Client) Chat(ctx context.Context, in ChatInput) (*ChatResult, error) {
|
||||
modelName = c.model
|
||||
}
|
||||
return &ChatResult{
|
||||
SchemaVersion: ChatResultSchemaVersion,
|
||||
Content: out.Choices[0].Message.Content,
|
||||
Model: modelName,
|
||||
Usage: out.Usage,
|
||||
|
||||
43
internal/llm/client_test.go
Normal file
43
internal/llm/client_test.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestChatResultIncludesSchemaVersion(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/v1/chat/completions" {
|
||||
t.Fatalf("path = %q, want /v1/chat/completions", r.URL.Path)
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"model": "qwen2.5-14b",
|
||||
"choices": []map[string]any{
|
||||
{"message": map[string]string{"role": "assistant", "content": `{"ok":true}`}},
|
||||
},
|
||||
"usage": map[string]int{
|
||||
"prompt_tokens": 10,
|
||||
"completion_tokens": 2,
|
||||
"total_tokens": 12,
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := New(server.URL, "", "fallback-model", 0)
|
||||
got, err := client.Chat(t.Context(), ChatInput{User: "test", MaxTokens: 32})
|
||||
if err != nil {
|
||||
t.Fatalf("Chat: %v", err)
|
||||
}
|
||||
if got.SchemaVersion != ChatResultSchemaVersion {
|
||||
t.Fatalf("schema_version = %q, want %q", got.SchemaVersion, ChatResultSchemaVersion)
|
||||
}
|
||||
if got.Content != `{"ok":true}` {
|
||||
t.Fatalf("content = %q", got.Content)
|
||||
}
|
||||
if got.Usage == nil || got.Usage.TotalTokens != 12 {
|
||||
t.Fatalf("usage = %#v", got.Usage)
|
||||
}
|
||||
}
|
||||
@@ -119,6 +119,30 @@ type ErrorStat struct {
|
||||
Last24h int64 `json:"last_24h"`
|
||||
}
|
||||
|
||||
type StageStat struct {
|
||||
OwnerService string `json:"owner_service"`
|
||||
TaskType string `json:"task_type"`
|
||||
ModelProfile string `json:"model_profile"`
|
||||
Done24h int64 `json:"done_24h"`
|
||||
AvgDurationSeconds int64 `json:"avg_duration_seconds"`
|
||||
AvgAttempts int64 `json:"avg_attempts"`
|
||||
Retried24h int64 `json:"retried_24h"`
|
||||
}
|
||||
|
||||
type BacklogStat struct {
|
||||
OwnerService string `json:"owner_service"`
|
||||
TaskType string `json:"task_type"`
|
||||
ModelProfile string `json:"model_profile"`
|
||||
Pending int64 `json:"pending"`
|
||||
Running int64 `json:"running"`
|
||||
StaleRunning int64 `json:"stale_running"`
|
||||
OldestPendingAgeSeconds int64 `json:"oldest_pending_age_seconds"`
|
||||
OldestPendingScheduledAt string `json:"oldest_pending_scheduled_at,omitempty"`
|
||||
OldestRunningAgeSeconds int64 `json:"oldest_running_age_seconds"`
|
||||
OldestRunningStartedAt string `json:"oldest_running_started_at,omitempty"`
|
||||
LastHeartbeatAt string `json:"last_heartbeat_at,omitempty"`
|
||||
}
|
||||
|
||||
type OwnerStat struct {
|
||||
OwnerService string `json:"owner_service"`
|
||||
TaskType string `json:"task_type"`
|
||||
@@ -132,4 +156,6 @@ type Stats struct {
|
||||
Queues []QueueStat `json:"queues"`
|
||||
Owners []OwnerStat `json:"owners,omitempty"`
|
||||
Errors []ErrorStat `json:"errors,omitempty"`
|
||||
Stages []StageStat `json:"stages,omitempty"`
|
||||
Backlog []BacklogStat `json:"backlog,omitempty"`
|
||||
}
|
||||
|
||||
22
internal/store/retry_policy.go
Normal file
22
internal/store/retry_policy.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type failRetryPolicy struct {
|
||||
Retryable bool
|
||||
Delay time.Duration
|
||||
}
|
||||
|
||||
func retryPolicyForError(errorCode string) failRetryPolicy {
|
||||
switch strings.TrimSpace(errorCode) {
|
||||
case "provider_unavailable", "model_unavailable", "provider_error", "dependency_error", "timeout", "storage_error", "stale_worker":
|
||||
return failRetryPolicy{Retryable: true, Delay: 30 * time.Second}
|
||||
case "bad_response", "transcript_hallucination", "transcript_incomplete", "internal_error", "unknown":
|
||||
return failRetryPolicy{Retryable: true, Delay: 2 * time.Minute}
|
||||
default:
|
||||
return failRetryPolicy{}
|
||||
}
|
||||
}
|
||||
45
internal/store/retry_policy_test.go
Normal file
45
internal/store/retry_policy_test.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestRetryPolicyForError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
code string
|
||||
retryable bool
|
||||
delay time.Duration
|
||||
}{
|
||||
{name: "provider unavailable", code: "provider_unavailable", retryable: true, delay: 30 * time.Second},
|
||||
{name: "model unavailable", code: "model_unavailable", retryable: true, delay: 30 * time.Second},
|
||||
{name: "provider error", code: "provider_error", retryable: true, delay: 30 * time.Second},
|
||||
{name: "dependency error", code: "dependency_error", retryable: true, delay: 30 * time.Second},
|
||||
{name: "timeout", code: "timeout", retryable: true, delay: 30 * time.Second},
|
||||
{name: "storage", code: "storage_error", retryable: true, delay: 30 * time.Second},
|
||||
{name: "stale worker", code: "stale_worker", retryable: true, delay: 30 * time.Second},
|
||||
{name: "bad response", code: "bad_response", retryable: true, delay: 2 * time.Minute},
|
||||
{name: "transcript hallucination", code: "transcript_hallucination", retryable: true, delay: 2 * time.Minute},
|
||||
{name: "transcript incomplete", code: "transcript_incomplete", retryable: true, delay: 2 * time.Minute},
|
||||
{name: "internal error", code: "internal_error", retryable: true, delay: 2 * time.Minute},
|
||||
{name: "unknown", code: "unknown", retryable: true, delay: 2 * time.Minute},
|
||||
{name: "bad audio", code: "bad_audio"},
|
||||
{name: "bad input", code: "bad_input"},
|
||||
{name: "context length", code: "context_length"},
|
||||
{name: "unsupported task", code: "unsupported_task"},
|
||||
{name: "cancelled", code: "cancelled"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := retryPolicyForError(tt.code)
|
||||
if got.Retryable != tt.retryable {
|
||||
t.Fatalf("Retryable = %v, want %v", got.Retryable, tt.retryable)
|
||||
}
|
||||
if got.Delay != tt.delay {
|
||||
t.Fatalf("Delay = %s, want %s", got.Delay, tt.delay)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -41,17 +41,48 @@ func Open(ctx context.Context, databaseURL string) (*Store, error) {
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse database url: %w", err)
|
||||
}
|
||||
pool, err := pgxpool.NewWithConfig(ctx, cfg)
|
||||
pool, err := connectWithRetry(ctx, cfg, 2*time.Minute)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("connect postgres: %w", err)
|
||||
}
|
||||
if err := pool.Ping(ctx); err != nil {
|
||||
pool.Close()
|
||||
return nil, fmt.Errorf("ping postgres: %w", err)
|
||||
return nil, err
|
||||
}
|
||||
return &Store{pool: pool}, nil
|
||||
}
|
||||
|
||||
func connectWithRetry(ctx context.Context, cfg *pgxpool.Config, maxWait time.Duration) (*pgxpool.Pool, error) {
|
||||
deadline := time.Now().Add(maxWait)
|
||||
var lastErr error
|
||||
|
||||
for attempt := 1; ; attempt++ {
|
||||
pool, err := pgxpool.NewWithConfig(ctx, cfg)
|
||||
if err == nil {
|
||||
if pingErr := pool.Ping(ctx); pingErr == nil {
|
||||
return pool, nil
|
||||
} else {
|
||||
err = fmt.Errorf("ping postgres: %w", pingErr)
|
||||
pool.Close()
|
||||
}
|
||||
} else {
|
||||
err = fmt.Errorf("connect postgres: %w", err)
|
||||
}
|
||||
lastErr = err
|
||||
|
||||
if time.Now().After(deadline) {
|
||||
return nil, fmt.Errorf("connect postgres after retry: %w", lastErr)
|
||||
}
|
||||
sleep := time.Duration(attempt) * time.Second
|
||||
if sleep > 5*time.Second {
|
||||
sleep = 5 * time.Second
|
||||
}
|
||||
timer := time.NewTimer(sleep)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
timer.Stop()
|
||||
return nil, fmt.Errorf("connect postgres cancelled: %w", ctx.Err())
|
||||
case <-timer.C:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Store) Close() {
|
||||
s.pool.Close()
|
||||
}
|
||||
@@ -78,7 +109,33 @@ INSERT INTO ai_jobs (
|
||||
)
|
||||
VALUES ($1,$2,$3,$4,$5,$6,$7,$8,$9)
|
||||
ON CONFLICT (idempotency_key) WHERE idempotency_key IS NOT NULL
|
||||
DO UPDATE SET updated_at = ai_jobs.updated_at
|
||||
DO UPDATE SET
|
||||
status = CASE WHEN ai_jobs.status = 'failed' AND ai_jobs.error_code = 'stale_worker'
|
||||
THEN 'pending' ELSE ai_jobs.status END,
|
||||
priority = CASE WHEN ai_jobs.status = 'failed' AND ai_jobs.error_code = 'stale_worker'
|
||||
THEN EXCLUDED.priority ELSE ai_jobs.priority END,
|
||||
max_attempts = CASE WHEN ai_jobs.status = 'failed' AND ai_jobs.error_code = 'stale_worker'
|
||||
THEN EXCLUDED.max_attempts ELSE ai_jobs.max_attempts END,
|
||||
attempts = CASE WHEN ai_jobs.status = 'failed' AND ai_jobs.error_code = 'stale_worker'
|
||||
THEN 0 ELSE ai_jobs.attempts END,
|
||||
input = CASE WHEN ai_jobs.status = 'failed' AND ai_jobs.error_code = 'stale_worker'
|
||||
THEN EXCLUDED.input ELSE ai_jobs.input END,
|
||||
scheduled_at = CASE WHEN ai_jobs.status = 'failed' AND ai_jobs.error_code = 'stale_worker'
|
||||
THEN EXCLUDED.scheduled_at ELSE ai_jobs.scheduled_at END,
|
||||
started_at = CASE WHEN ai_jobs.status = 'failed' AND ai_jobs.error_code = 'stale_worker'
|
||||
THEN NULL ELSE ai_jobs.started_at END,
|
||||
completed_at = CASE WHEN ai_jobs.status = 'failed' AND ai_jobs.error_code = 'stale_worker'
|
||||
THEN NULL ELSE ai_jobs.completed_at END,
|
||||
worker_id = CASE WHEN ai_jobs.status = 'failed' AND ai_jobs.error_code = 'stale_worker'
|
||||
THEN NULL ELSE ai_jobs.worker_id END,
|
||||
heartbeat_at = CASE WHEN ai_jobs.status = 'failed' AND ai_jobs.error_code = 'stale_worker'
|
||||
THEN NULL ELSE ai_jobs.heartbeat_at END,
|
||||
error_code = CASE WHEN ai_jobs.status = 'failed' AND ai_jobs.error_code = 'stale_worker'
|
||||
THEN NULL ELSE ai_jobs.error_code END,
|
||||
error_message = CASE WHEN ai_jobs.status = 'failed' AND ai_jobs.error_code = 'stale_worker'
|
||||
THEN NULL ELSE ai_jobs.error_message END,
|
||||
updated_at = CASE WHEN ai_jobs.status = 'failed' AND ai_jobs.error_code = 'stale_worker'
|
||||
THEN NOW() ELSE ai_jobs.updated_at END
|
||||
RETURNING ` + jobSelectColumns + `
|
||||
`
|
||||
row := s.pool.QueryRow(ctx, q,
|
||||
@@ -106,7 +163,33 @@ INSERT INTO ai_jobs (
|
||||
)
|
||||
VALUES ($1,$2,$3,$4,$5,$6,$7,$8,$9)
|
||||
ON CONFLICT (idempotency_key) WHERE idempotency_key IS NOT NULL
|
||||
DO UPDATE SET updated_at = ai_jobs.updated_at
|
||||
DO UPDATE SET
|
||||
status = CASE WHEN ai_jobs.status = 'failed' AND ai_jobs.error_code = 'stale_worker'
|
||||
THEN 'pending' ELSE ai_jobs.status END,
|
||||
priority = CASE WHEN ai_jobs.status = 'failed' AND ai_jobs.error_code = 'stale_worker'
|
||||
THEN EXCLUDED.priority ELSE ai_jobs.priority END,
|
||||
max_attempts = CASE WHEN ai_jobs.status = 'failed' AND ai_jobs.error_code = 'stale_worker'
|
||||
THEN EXCLUDED.max_attempts ELSE ai_jobs.max_attempts END,
|
||||
attempts = CASE WHEN ai_jobs.status = 'failed' AND ai_jobs.error_code = 'stale_worker'
|
||||
THEN 0 ELSE ai_jobs.attempts END,
|
||||
input = CASE WHEN ai_jobs.status = 'failed' AND ai_jobs.error_code = 'stale_worker'
|
||||
THEN EXCLUDED.input ELSE ai_jobs.input END,
|
||||
scheduled_at = CASE WHEN ai_jobs.status = 'failed' AND ai_jobs.error_code = 'stale_worker'
|
||||
THEN EXCLUDED.scheduled_at ELSE ai_jobs.scheduled_at END,
|
||||
started_at = CASE WHEN ai_jobs.status = 'failed' AND ai_jobs.error_code = 'stale_worker'
|
||||
THEN NULL ELSE ai_jobs.started_at END,
|
||||
completed_at = CASE WHEN ai_jobs.status = 'failed' AND ai_jobs.error_code = 'stale_worker'
|
||||
THEN NULL ELSE ai_jobs.completed_at END,
|
||||
worker_id = CASE WHEN ai_jobs.status = 'failed' AND ai_jobs.error_code = 'stale_worker'
|
||||
THEN NULL ELSE ai_jobs.worker_id END,
|
||||
heartbeat_at = CASE WHEN ai_jobs.status = 'failed' AND ai_jobs.error_code = 'stale_worker'
|
||||
THEN NULL ELSE ai_jobs.heartbeat_at END,
|
||||
error_code = CASE WHEN ai_jobs.status = 'failed' AND ai_jobs.error_code = 'stale_worker'
|
||||
THEN NULL ELSE ai_jobs.error_code END,
|
||||
error_message = CASE WHEN ai_jobs.status = 'failed' AND ai_jobs.error_code = 'stale_worker'
|
||||
THEN NULL ELSE ai_jobs.error_message END,
|
||||
updated_at = CASE WHEN ai_jobs.status = 'failed' AND ai_jobs.error_code = 'stale_worker'
|
||||
THEN NOW() ELSE ai_jobs.updated_at END
|
||||
RETURNING ` + jobSelectColumns + `
|
||||
`
|
||||
var batch pgx.Batch
|
||||
@@ -368,6 +451,28 @@ WHERE j.id = picked.id
|
||||
return int(tag.RowsAffected()), nil
|
||||
}
|
||||
|
||||
func (s *Store) CancelJob(ctx context.Context, id uuid.UUID) (*model.Job, error) {
|
||||
const q = `
|
||||
UPDATE ai_jobs
|
||||
SET status = 'cancelled',
|
||||
completed_at = NOW(),
|
||||
worker_id = NULL,
|
||||
heartbeat_at = NULL,
|
||||
updated_at = NOW()
|
||||
WHERE id = $1
|
||||
AND status IN ('pending', 'running')
|
||||
RETURNING id, owner_service, owner_ref, task_type, model_profile, priority, status,
|
||||
attempts, max_attempts, input, result, error_code, error_message,
|
||||
scheduled_at, started_at, completed_at, worker_id, heartbeat_at,
|
||||
created_at, updated_at, idempotency_key
|
||||
`
|
||||
job, err := scanJob(s.pool.QueryRow(ctx, q, id))
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil, nil
|
||||
}
|
||||
return job, err
|
||||
}
|
||||
|
||||
func (s *Store) CancelJobs(ctx context.Context, filter model.JobFilter) (int, error) {
|
||||
normalizeFilter(&filter)
|
||||
const q = `
|
||||
@@ -490,25 +595,47 @@ RETURNING ` + jobSelectColumns + `
|
||||
return job, err
|
||||
}
|
||||
|
||||
func (s *Store) HeartbeatJob(ctx context.Context, id uuid.UUID) error {
|
||||
const q = `
|
||||
UPDATE ai_jobs
|
||||
SET heartbeat_at = NOW(),
|
||||
updated_at = NOW()
|
||||
WHERE id = $1
|
||||
AND status = 'running'
|
||||
`
|
||||
tag, err := s.pool.Exec(ctx, q, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if tag.RowsAffected() == 0 {
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) FailJob(ctx context.Context, id uuid.UUID, in model.FailJob) (*model.Job, error) {
|
||||
errorCode := strings.TrimSpace(in.ErrorCode)
|
||||
if errorCode == "" {
|
||||
errorCode = "unknown"
|
||||
}
|
||||
errorMessage := strings.TrimSpace(in.ErrorMessage)
|
||||
policy := retryPolicyForError(errorCode)
|
||||
const q = `
|
||||
UPDATE ai_jobs
|
||||
SET status = 'failed',
|
||||
SET status = CASE WHEN $4 AND attempts < max_attempts THEN 'pending' ELSE 'failed' END,
|
||||
error_code = $2,
|
||||
error_message = $3,
|
||||
completed_at = NOW(),
|
||||
heartbeat_at = NOW(),
|
||||
scheduled_at = CASE WHEN $4 AND attempts < max_attempts THEN NOW() + make_interval(secs => $5) ELSE scheduled_at END,
|
||||
started_at = CASE WHEN $4 AND attempts < max_attempts THEN NULL ELSE started_at END,
|
||||
completed_at = CASE WHEN $4 AND attempts < max_attempts THEN NULL ELSE NOW() END,
|
||||
worker_id = NULL,
|
||||
heartbeat_at = CASE WHEN $4 AND attempts < max_attempts THEN NULL ELSE NOW() END,
|
||||
updated_at = NOW()
|
||||
WHERE id = $1
|
||||
AND status = 'running'
|
||||
RETURNING ` + jobSelectColumns + `
|
||||
`
|
||||
job, err := scanJob(s.pool.QueryRow(ctx, q, id, errorCode, errorMessage))
|
||||
job, err := scanJob(s.pool.QueryRow(ctx, q, id, errorCode, errorMessage, policy.Retryable, int(policy.Delay.Seconds())))
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil, nil
|
||||
}
|
||||
@@ -553,7 +680,10 @@ WHERE j.id = picked.id
|
||||
return int(tag.RowsAffected()), nil
|
||||
}
|
||||
|
||||
func (s *Store) Stats(ctx context.Context) (*model.Stats, error) {
|
||||
func (s *Store) Stats(ctx context.Context, staleAfter time.Duration) (*model.Stats, error) {
|
||||
if staleAfter <= 0 {
|
||||
staleAfter = 15 * time.Minute
|
||||
}
|
||||
out := &model.Stats{At: time.Now().UTC()}
|
||||
|
||||
queueRows, err := s.pool.Query(ctx, `
|
||||
@@ -618,7 +748,105 @@ ORDER BY owner_service, last_24h DESC, total DESC
|
||||
}
|
||||
out.Errors = append(out.Errors, stat)
|
||||
}
|
||||
return out, errorRows.Err()
|
||||
if err := errorRows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
stageRows, err := s.pool.Query(ctx, `
|
||||
SELECT owner_service,
|
||||
task_type,
|
||||
model_profile,
|
||||
count(*) AS done_24h,
|
||||
COALESCE(ROUND(AVG(EXTRACT(EPOCH FROM (completed_at - started_at))))::bigint, 0) AS avg_duration_seconds,
|
||||
COALESCE(ROUND(AVG(attempts))::bigint, 0) AS avg_attempts,
|
||||
count(*) FILTER (WHERE attempts > 1) AS retried_24h
|
||||
FROM ai_jobs
|
||||
WHERE status = 'done'
|
||||
AND started_at IS NOT NULL
|
||||
AND completed_at IS NOT NULL
|
||||
AND completed_at > NOW() - INTERVAL '24 hours'
|
||||
GROUP BY owner_service, task_type, model_profile
|
||||
ORDER BY owner_service, task_type, model_profile
|
||||
`)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer stageRows.Close()
|
||||
for stageRows.Next() {
|
||||
var stat model.StageStat
|
||||
if err := stageRows.Scan(
|
||||
&stat.OwnerService,
|
||||
&stat.TaskType,
|
||||
&stat.ModelProfile,
|
||||
&stat.Done24h,
|
||||
&stat.AvgDurationSeconds,
|
||||
&stat.AvgAttempts,
|
||||
&stat.Retried24h,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out.Stages = append(out.Stages, stat)
|
||||
}
|
||||
if err := stageRows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
backlogRows, err := s.pool.Query(ctx, `
|
||||
SELECT owner_service,
|
||||
task_type,
|
||||
model_profile,
|
||||
count(*) FILTER (WHERE status = 'pending') AS pending,
|
||||
count(*) FILTER (WHERE status = 'running') AS running,
|
||||
count(*) FILTER (
|
||||
WHERE status = 'running'
|
||||
AND COALESCE(heartbeat_at, started_at, updated_at) < NOW() - make_interval(secs => $1)
|
||||
) AS stale_running,
|
||||
COALESCE(EXTRACT(EPOCH FROM (NOW() - MIN(scheduled_at) FILTER (WHERE status = 'pending')))::bigint, 0) AS oldest_pending_age_seconds,
|
||||
MIN(scheduled_at) FILTER (WHERE status = 'pending') AS oldest_pending_scheduled_at,
|
||||
COALESCE(EXTRACT(EPOCH FROM (NOW() - MIN(started_at) FILTER (WHERE status = 'running')))::bigint, 0) AS oldest_running_age_seconds,
|
||||
MIN(started_at) FILTER (WHERE status = 'running') AS oldest_running_started_at,
|
||||
MAX(heartbeat_at) FILTER (WHERE status = 'running') AS last_heartbeat_at
|
||||
FROM ai_jobs
|
||||
WHERE status IN ('pending', 'running')
|
||||
GROUP BY owner_service, task_type, model_profile
|
||||
ORDER BY stale_running DESC, pending DESC, running DESC, owner_service, task_type, model_profile
|
||||
`, int(staleAfter.Seconds()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer backlogRows.Close()
|
||||
for backlogRows.Next() {
|
||||
var stat model.BacklogStat
|
||||
var oldestPendingScheduledAt *time.Time
|
||||
var oldestRunningStartedAt *time.Time
|
||||
var lastHeartbeatAt *time.Time
|
||||
if err := backlogRows.Scan(
|
||||
&stat.OwnerService,
|
||||
&stat.TaskType,
|
||||
&stat.ModelProfile,
|
||||
&stat.Pending,
|
||||
&stat.Running,
|
||||
&stat.StaleRunning,
|
||||
&stat.OldestPendingAgeSeconds,
|
||||
&oldestPendingScheduledAt,
|
||||
&stat.OldestRunningAgeSeconds,
|
||||
&oldestRunningStartedAt,
|
||||
&lastHeartbeatAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if oldestPendingScheduledAt != nil {
|
||||
stat.OldestPendingScheduledAt = oldestPendingScheduledAt.UTC().Format(time.RFC3339)
|
||||
}
|
||||
if oldestRunningStartedAt != nil {
|
||||
stat.OldestRunningStartedAt = oldestRunningStartedAt.UTC().Format(time.RFC3339)
|
||||
}
|
||||
if lastHeartbeatAt != nil {
|
||||
stat.LastHeartbeatAt = lastHeartbeatAt.UTC().Format(time.RFC3339)
|
||||
}
|
||||
out.Backlog = append(out.Backlog, stat)
|
||||
}
|
||||
return out, backlogRows.Err()
|
||||
}
|
||||
|
||||
func scanJobSummary(row pgx.Row) (*model.JobSummary, error) {
|
||||
|
||||
@@ -3,58 +3,44 @@ package transcription
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
providers []ProviderConfig
|
||||
provider ProviderConfig
|
||||
http *http.Client
|
||||
ffmpegPath string
|
||||
leadSilence time.Duration
|
||||
}
|
||||
|
||||
const (
|
||||
ProviderWhisperX = "whisperx"
|
||||
ProviderQwenAudio = "qwen2-audio"
|
||||
ProviderVoxtral = "voxtral-small"
|
||||
ProviderWhisperLargeV3 = "whisper-large-v3"
|
||||
defaultWhisperModel = "openai/whisper-large-v3"
|
||||
)
|
||||
|
||||
var speakerLabelPattern = regexp.MustCompile(`(?i)(?:^|[\n\r ]+)((?:speaker|спикер|говорящий)\s*\d+)\s*[::-]`)
|
||||
|
||||
type Options struct {
|
||||
Providers []string
|
||||
WhisperXURL string
|
||||
WhisperXTimeout time.Duration
|
||||
FfmpegPath string
|
||||
LeadSilence time.Duration
|
||||
QwenAudioBaseURL string
|
||||
QwenAudioAPIKey string
|
||||
QwenAudioModel string
|
||||
QwenAudioTimeout time.Duration
|
||||
VoxtralBaseURL string
|
||||
VoxtralAPIKey string
|
||||
VoxtralModel string
|
||||
VoxtralTimeout time.Duration
|
||||
AudioLLMPrompt string
|
||||
AudioLLMMaxTokens int
|
||||
AudioBaseURL string
|
||||
AudioAPIKey string
|
||||
AudioModel string
|
||||
AudioTimeout time.Duration
|
||||
AudioPrompt string
|
||||
}
|
||||
|
||||
type ProviderConfig struct {
|
||||
Name string
|
||||
Kind string
|
||||
BaseURL string
|
||||
APIKey string
|
||||
Model string
|
||||
Timeout time.Duration
|
||||
MaxTokens int
|
||||
Prompt string
|
||||
}
|
||||
|
||||
@@ -74,7 +60,10 @@ type Segment struct {
|
||||
Speaker string `json:"speaker,omitempty"`
|
||||
}
|
||||
|
||||
const ResultSchemaVersion = "ai.transcription_result.v1"
|
||||
|
||||
type Result struct {
|
||||
SchemaVersion string `json:"schema_version"`
|
||||
Provider string `json:"provider,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Attempts []Attempt `json:"attempts,omitempty"`
|
||||
@@ -95,183 +84,71 @@ type Attempt struct {
|
||||
DurationMS int64 `json:"duration_ms,omitempty"`
|
||||
}
|
||||
|
||||
type whisperResponse struct {
|
||||
Language string `json:"language"`
|
||||
Segments []Segment `json:"segments"`
|
||||
DiarizeError *string `json:"diarize_error,omitempty"`
|
||||
AlignError *string `json:"align_error,omitempty"`
|
||||
}
|
||||
|
||||
type audioLLMResponse struct {
|
||||
Text string
|
||||
Model string
|
||||
Language string
|
||||
Segments []Segment
|
||||
}
|
||||
|
||||
type audioLLMChatRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []audioLLMChatMessage `json:"messages"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
Temperature float64 `json:"temperature"`
|
||||
}
|
||||
|
||||
type audioLLMChatMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content []audioLLMContentPart `json:"content"`
|
||||
}
|
||||
|
||||
type audioLLMContentPart struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
InputAudio *audioLLMAudio `json:"input_audio,omitempty"`
|
||||
}
|
||||
|
||||
type audioLLMAudio struct {
|
||||
Data string `json:"data"`
|
||||
Format string `json:"format,omitempty"`
|
||||
}
|
||||
|
||||
type audioLLMChatResponse struct {
|
||||
type audioTranscriptionResponse struct {
|
||||
Text string `json:"text"`
|
||||
Language string `json:"language,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Choices []struct {
|
||||
Message struct {
|
||||
Content string `json:"content"`
|
||||
} `json:"message"`
|
||||
} `json:"choices"`
|
||||
Segments []audioTranscriptionSegment `json:"segments,omitempty"`
|
||||
Error *struct {
|
||||
Message string `json:"message"`
|
||||
} `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
type audioTranscriptionSegment struct {
|
||||
Start float64 `json:"start"`
|
||||
End float64 `json:"end"`
|
||||
Text string `json:"text"`
|
||||
}
|
||||
|
||||
type audioTranscriptionStatusError struct {
|
||||
status int
|
||||
body string
|
||||
}
|
||||
|
||||
func (e audioTranscriptionStatusError) Error() string {
|
||||
return fmt.Sprintf("audio transcription HTTP %d: %s", e.status, e.body)
|
||||
}
|
||||
|
||||
func New(baseURL string, timeout time.Duration, ffmpegPath string, leadSilence time.Duration) *Client {
|
||||
return NewWithOptions(Options{
|
||||
Providers: []string{ProviderWhisperX},
|
||||
WhisperXURL: baseURL,
|
||||
WhisperXTimeout: timeout,
|
||||
FfmpegPath: ffmpegPath,
|
||||
LeadSilence: leadSilence,
|
||||
AudioBaseURL: baseURL,
|
||||
AudioTimeout: timeout,
|
||||
})
|
||||
}
|
||||
|
||||
func NewWithOptions(opts Options) *Client {
|
||||
leadSilence := opts.LeadSilence
|
||||
if leadSilence < 0 {
|
||||
leadSilence = 0
|
||||
}
|
||||
if leadSilence > 5*time.Second {
|
||||
leadSilence = 5 * time.Second
|
||||
}
|
||||
ffmpegPath := strings.TrimSpace(opts.FfmpegPath)
|
||||
if ffmpegPath == "" {
|
||||
ffmpegPath = "ffmpeg"
|
||||
}
|
||||
maxTokens := opts.AudioLLMMaxTokens
|
||||
if maxTokens <= 0 {
|
||||
maxTokens = 4096
|
||||
}
|
||||
audioLLMPrompt := strings.TrimSpace(opts.AudioLLMPrompt)
|
||||
if audioLLMPrompt == "" {
|
||||
audioLLMPrompt = "Transcribe the audio exactly. Return only the transcript text."
|
||||
}
|
||||
providers := buildProviders(opts, audioLLMPrompt, maxTokens)
|
||||
if len(providers) == 0 {
|
||||
audioPrompt := strings.TrimSpace(opts.AudioPrompt)
|
||||
provider := buildAudioProvider(opts, audioPrompt)
|
||||
if provider.BaseURL == "" {
|
||||
return nil
|
||||
}
|
||||
return &Client{
|
||||
providers: providers,
|
||||
http: &http.Client{Timeout: maxProviderTimeout(providers)},
|
||||
ffmpegPath: ffmpegPath,
|
||||
leadSilence: leadSilence,
|
||||
provider: provider,
|
||||
http: &http.Client{Timeout: provider.Timeout},
|
||||
}
|
||||
}
|
||||
|
||||
func buildProviders(opts Options, prompt string, maxTokens int) []ProviderConfig {
|
||||
order := normalizeProviderOrder(opts.Providers)
|
||||
out := make([]ProviderConfig, 0, len(order))
|
||||
for _, name := range order {
|
||||
switch name {
|
||||
case ProviderWhisperX:
|
||||
baseURL := strings.TrimRight(strings.TrimSpace(opts.WhisperXURL), "/")
|
||||
func buildAudioProvider(opts Options, prompt string) ProviderConfig {
|
||||
baseURL := strings.TrimRight(strings.TrimSpace(opts.AudioBaseURL), "/")
|
||||
if baseURL == "" {
|
||||
continue
|
||||
return ProviderConfig{}
|
||||
}
|
||||
out = append(out, ProviderConfig{
|
||||
Name: ProviderWhisperX,
|
||||
Kind: ProviderWhisperX,
|
||||
model := firstNonEmpty(opts.AudioModel, defaultWhisperModel)
|
||||
return ProviderConfig{
|
||||
Name: ProviderWhisperLargeV3,
|
||||
BaseURL: baseURL,
|
||||
Model: ProviderWhisperX,
|
||||
Timeout: defaultDuration(opts.WhisperXTimeout, 10*time.Minute),
|
||||
})
|
||||
case ProviderQwenAudio:
|
||||
baseURL := strings.TrimRight(strings.TrimSpace(opts.QwenAudioBaseURL), "/")
|
||||
if baseURL == "" {
|
||||
continue
|
||||
}
|
||||
model := firstNonEmpty(opts.QwenAudioModel, "Qwen/Qwen2-Audio-7B-Instruct")
|
||||
out = append(out, ProviderConfig{
|
||||
Name: ProviderQwenAudio,
|
||||
Kind: "audio_llm",
|
||||
BaseURL: baseURL,
|
||||
APIKey: strings.TrimSpace(opts.QwenAudioAPIKey),
|
||||
APIKey: strings.TrimSpace(opts.AudioAPIKey),
|
||||
Model: model,
|
||||
Timeout: defaultDuration(opts.QwenAudioTimeout, 10*time.Minute),
|
||||
MaxTokens: maxTokens,
|
||||
Timeout: defaultDuration(opts.AudioTimeout, 10*time.Minute),
|
||||
Prompt: prompt,
|
||||
})
|
||||
case ProviderVoxtral:
|
||||
baseURL := strings.TrimRight(strings.TrimSpace(opts.VoxtralBaseURL), "/")
|
||||
if baseURL == "" {
|
||||
continue
|
||||
}
|
||||
model := firstNonEmpty(opts.VoxtralModel, "mistralai/Voxtral-Small-24B-2507")
|
||||
out = append(out, ProviderConfig{
|
||||
Name: ProviderVoxtral,
|
||||
Kind: "audio_llm",
|
||||
BaseURL: baseURL,
|
||||
APIKey: strings.TrimSpace(opts.VoxtralAPIKey),
|
||||
Model: model,
|
||||
Timeout: defaultDuration(opts.VoxtralTimeout, 10*time.Minute),
|
||||
MaxTokens: maxTokens,
|
||||
Prompt: prompt,
|
||||
})
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func normalizeProviderOrder(in []string) []string {
|
||||
if len(in) == 0 {
|
||||
return []string{ProviderWhisperX}
|
||||
}
|
||||
out := make([]string, 0, len(in))
|
||||
seen := map[string]bool{}
|
||||
for _, raw := range in {
|
||||
name := strings.ToLower(strings.TrimSpace(raw))
|
||||
switch name {
|
||||
case "whisper", "whisperx":
|
||||
name = ProviderWhisperX
|
||||
case "qwen", "qwen-audio", "qwen2-audio", "qwen2-audio-7b-instruct":
|
||||
name = ProviderQwenAudio
|
||||
case "voxtral", "voxtral-small", "voxtral-small-24b-2507":
|
||||
name = ProviderVoxtral
|
||||
default:
|
||||
continue
|
||||
}
|
||||
if !seen[name] {
|
||||
out = append(out, name)
|
||||
seen[name] = true
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func maxProviderTimeout(providers []ProviderConfig) time.Duration {
|
||||
maxTimeout := 10 * time.Minute
|
||||
for _, provider := range providers {
|
||||
if provider.Timeout > maxTimeout {
|
||||
maxTimeout = provider.Timeout
|
||||
}
|
||||
}
|
||||
return maxTimeout
|
||||
}
|
||||
|
||||
func defaultDuration(v, fallback time.Duration) time.Duration {
|
||||
@@ -282,8 +159,8 @@ func defaultDuration(v, fallback time.Duration) time.Duration {
|
||||
}
|
||||
|
||||
func (c *Client) Transcribe(ctx context.Context, in Input) (*Result, error) {
|
||||
if c == nil || len(c.providers) == 0 {
|
||||
return nil, fmt.Errorf("transcription providers not configured")
|
||||
if c == nil || c.provider.BaseURL == "" {
|
||||
return nil, fmt.Errorf("audio transcription provider not configured")
|
||||
}
|
||||
if strings.TrimSpace(in.AudioURL) == "" {
|
||||
return nil, fmt.Errorf("audio_url is required")
|
||||
@@ -292,31 +169,12 @@ func (c *Client) Transcribe(ctx context.Context, in Input) (*Result, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if c.leadSilence > 0 {
|
||||
audio, filename, err = c.addLeadSilence(ctx, audio, filename)
|
||||
result, attempt, err := c.transcribeWithProvider(ctx, c.provider, audio, filename, in)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
var attempts []Attempt
|
||||
var winner *Result
|
||||
var errors []string
|
||||
for _, provider := range c.providers {
|
||||
result, attempt, err := c.transcribeWithProvider(ctx, provider, audio, filename, in)
|
||||
attempts = append(attempts, attempt)
|
||||
if err != nil {
|
||||
errors = append(errors, provider.Name+": "+err.Error())
|
||||
continue
|
||||
}
|
||||
if winner == nil {
|
||||
winner = result
|
||||
}
|
||||
}
|
||||
if winner == nil {
|
||||
return nil, fmt.Errorf("all transcription providers failed: %s", strings.Join(errors, "; "))
|
||||
}
|
||||
winner.Attempts = attempts
|
||||
return winner, nil
|
||||
result.Attempts = []Attempt{attempt}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (c *Client) transcribeWithProvider(ctx context.Context, provider ProviderConfig, audio []byte, filename string, in Input) (*Result, Attempt, error) {
|
||||
@@ -331,58 +189,26 @@ func (c *Client) transcribeWithProvider(ctx context.Context, provider ProviderCo
|
||||
Model: provider.Model,
|
||||
Status: "failed",
|
||||
}
|
||||
switch provider.Kind {
|
||||
case ProviderWhisperX:
|
||||
resp, duration, err := c.transcribeAudio(providerCtx, provider, audio, filename, in)
|
||||
attempt.DurationMS = duration.Milliseconds()
|
||||
if err != nil {
|
||||
attempt.Error = err.Error()
|
||||
return nil, attempt, err
|
||||
}
|
||||
segments := adjustLeadSilence(resp.Segments, c.leadSilence)
|
||||
attempt.Status = "ok"
|
||||
attempt.Segments = segments
|
||||
attempt.Text = segmentsText(segments)
|
||||
return &Result{
|
||||
Provider: provider.Name,
|
||||
Model: provider.Model,
|
||||
Language: resp.Language,
|
||||
Segments: segments,
|
||||
DiarizeError: resp.DiarizeError,
|
||||
AlignError: resp.AlignError,
|
||||
DurationMS: duration.Milliseconds(),
|
||||
}, attempt, nil
|
||||
default:
|
||||
resp, duration, err := c.transcribeAudioLLM(providerCtx, provider, audio, filename, in)
|
||||
resp, duration, err := c.transcribeOpenAIAudio(providerCtx, provider, audio, filename, in)
|
||||
attempt.DurationMS = duration.Milliseconds()
|
||||
if err != nil {
|
||||
attempt.Error = err.Error()
|
||||
return nil, attempt, err
|
||||
}
|
||||
text := strings.TrimSpace(resp.Text)
|
||||
segments := []Segment{{Start: 0, End: 0, Text: text}}
|
||||
segments := normalizeAudioLLMSegments(resp.Segments, text, in.Diarize)
|
||||
attempt.Status = "ok"
|
||||
attempt.Model = resp.Model
|
||||
attempt.Text = text
|
||||
attempt.Segments = segments
|
||||
return &Result{
|
||||
SchemaVersion: ResultSchemaVersion,
|
||||
Provider: provider.Name,
|
||||
Model: resp.Model,
|
||||
Language: firstNonEmpty(in.Language, "unknown"),
|
||||
Language: firstNonEmpty(resp.Language, in.Language, "unknown"),
|
||||
Segments: segments,
|
||||
DurationMS: duration.Milliseconds(),
|
||||
}, attempt, nil
|
||||
}
|
||||
}
|
||||
|
||||
func segmentsText(segments []Segment) string {
|
||||
parts := make([]string, 0, len(segments))
|
||||
for _, segment := range segments {
|
||||
if text := strings.TrimSpace(segment.Text); text != "" {
|
||||
parts = append(parts, text)
|
||||
}
|
||||
}
|
||||
return strings.Join(parts, "\n")
|
||||
}
|
||||
|
||||
func (c *Client) downloadAudio(ctx context.Context, in Input) ([]byte, string, error) {
|
||||
@@ -413,83 +239,6 @@ func (c *Client) downloadAudio(ctx context.Context, in Input) ([]byte, string, e
|
||||
return audio, filename, nil
|
||||
}
|
||||
|
||||
func (c *Client) addLeadSilence(ctx context.Context, audio []byte, filename string) ([]byte, string, error) {
|
||||
tmpDir, err := os.MkdirTemp("", "ai-transcribe-*")
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("prepare audio temp dir: %w", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
inputPath := filepath.Join(tmpDir, "input"+safeExt(filename))
|
||||
outputPath := filepath.Join(tmpDir, "padded.mp3")
|
||||
if err := os.WriteFile(inputPath, audio, 0o600); err != nil {
|
||||
return nil, "", fmt.Errorf("write audio temp file: %w", err)
|
||||
}
|
||||
delayMS := int(c.leadSilence.Milliseconds())
|
||||
if delayMS <= 0 {
|
||||
return audio, filename, nil
|
||||
}
|
||||
cmd := exec.CommandContext(ctx, c.ffmpegPath,
|
||||
"-nostdin", "-y",
|
||||
"-i", inputPath,
|
||||
"-af", fmt.Sprintf("adelay=%d:all=1", delayMS),
|
||||
"-codec:a", "libmp3lame",
|
||||
"-qscale:a", "5",
|
||||
outputPath,
|
||||
)
|
||||
out, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("ffmpeg lead silence: %w (%s)", err, trimOutput(out))
|
||||
}
|
||||
padded, err := os.ReadFile(outputPath)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("read padded audio: %w", err)
|
||||
}
|
||||
if len(padded) == 0 {
|
||||
return nil, "", fmt.Errorf("padded audio is empty")
|
||||
}
|
||||
base := strings.TrimSuffix(filepath.Base(filename), filepath.Ext(filename))
|
||||
if base == "" || base == "." || base == "/" {
|
||||
base = "audio"
|
||||
}
|
||||
return padded, base + "-padded.mp3", nil
|
||||
}
|
||||
|
||||
func safeExt(filename string) string {
|
||||
ext := strings.ToLower(filepath.Ext(filename))
|
||||
switch ext {
|
||||
case ".mp3", ".wav", ".m4a", ".ogg", ".opus", ".webm":
|
||||
return ext
|
||||
default:
|
||||
return ".audio"
|
||||
}
|
||||
}
|
||||
|
||||
func trimOutput(out []byte) string {
|
||||
s := strings.TrimSpace(string(out))
|
||||
if len(s) > 600 {
|
||||
return s[:600]
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func adjustLeadSilence(segments []Segment, silence time.Duration) []Segment {
|
||||
if len(segments) == 0 || silence <= 0 {
|
||||
return segments
|
||||
}
|
||||
shift := silence.Seconds()
|
||||
out := make([]Segment, 0, len(segments))
|
||||
for _, segment := range segments {
|
||||
segment.Start = clampTime(segment.Start - shift)
|
||||
segment.End = clampTime(segment.End - shift)
|
||||
if segment.End < segment.Start {
|
||||
segment.End = segment.Start
|
||||
}
|
||||
out = append(out, segment)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func clampTime(v float64) float64 {
|
||||
if v < 0 {
|
||||
return 0
|
||||
@@ -497,95 +246,61 @@ func clampTime(v float64) float64 {
|
||||
return v
|
||||
}
|
||||
|
||||
func (c *Client) transcribeAudio(ctx context.Context, provider ProviderConfig, audio []byte, filename string, in Input) (*whisperResponse, time.Duration, error) {
|
||||
body := &bytes.Buffer{}
|
||||
mw := multipart.NewWriter(body)
|
||||
fw, err := mw.CreateFormFile("file", filename)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("create form file: %w", err)
|
||||
func (c *Client) transcribeOpenAIAudio(ctx context.Context, provider ProviderConfig, audio []byte, filename string, in Input) (*audioLLMResponse, time.Duration, error) {
|
||||
resp, duration, err := c.doOpenAIAudioTranscription(ctx, provider, audio, filename, in, "verbose_json")
|
||||
if err == nil {
|
||||
return resp, duration, nil
|
||||
}
|
||||
if _, err := fw.Write(audio); err != nil {
|
||||
return nil, 0, fmt.Errorf("copy audio: %w", err)
|
||||
if !isVerboseJSONUnsupported(err) {
|
||||
return nil, duration, err
|
||||
}
|
||||
if in.Language != "" {
|
||||
_ = mw.WriteField("language", in.Language)
|
||||
}
|
||||
if in.Diarize {
|
||||
_ = mw.WriteField("diarize", "true")
|
||||
if in.MinSpeakers > 0 {
|
||||
_ = mw.WriteField("min_speakers", fmt.Sprintf("%d", in.MinSpeakers))
|
||||
}
|
||||
if in.MaxSpeakers > 0 {
|
||||
_ = mw.WriteField("max_speakers", fmt.Sprintf("%d", in.MaxSpeakers))
|
||||
}
|
||||
} else {
|
||||
_ = mw.WriteField("diarize", "false")
|
||||
}
|
||||
if err := mw.Close(); err != nil {
|
||||
return nil, 0, fmt.Errorf("close form: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, provider.BaseURL+"/transcribe", body)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("whisperx request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", mw.FormDataContentType())
|
||||
|
||||
start := time.Now()
|
||||
resp, err := c.http.Do(req)
|
||||
duration := time.Since(start)
|
||||
if err != nil {
|
||||
return nil, duration, fmt.Errorf("whisperx do: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode >= 300 {
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
|
||||
return nil, duration, fmt.Errorf("whisperx HTTP %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
|
||||
}
|
||||
var out whisperResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
|
||||
return nil, duration, fmt.Errorf("whisperx decode: %w", err)
|
||||
}
|
||||
return &out, duration, nil
|
||||
fallbackResp, fallbackDuration, fallbackErr := c.doOpenAIAudioTranscription(ctx, provider, audio, filename, in, "json")
|
||||
return fallbackResp, duration + fallbackDuration, fallbackErr
|
||||
}
|
||||
|
||||
func (c *Client) transcribeAudioLLM(ctx context.Context, provider ProviderConfig, audio []byte, filename string, in Input) (*audioLLMResponse, time.Duration, error) {
|
||||
prompt := provider.Prompt
|
||||
if in.Language != "" {
|
||||
prompt += "\nЯзык аудио: " + in.Language + "."
|
||||
func (c *Client) doOpenAIAudioTranscription(ctx context.Context, provider ProviderConfig, audio []byte, filename string, in Input, responseFormat string) (*audioLLMResponse, time.Duration, error) {
|
||||
body := &bytes.Buffer{}
|
||||
mw := multipart.NewWriter(body)
|
||||
if err := mw.WriteField("model", provider.Model); err != nil {
|
||||
return nil, 0, fmt.Errorf("audio transcription model field: %w", err)
|
||||
}
|
||||
if in.Diarize {
|
||||
prompt += "\nЕсли слышны разные говорящие, разделяй реплики с короткими пометками Спикер 1/Спикер 2."
|
||||
if err := mw.WriteField("response_format", responseFormat); err != nil {
|
||||
return nil, 0, fmt.Errorf("audio transcription response_format field: %w", err)
|
||||
}
|
||||
reqBody := audioLLMChatRequest{
|
||||
Model: provider.Model,
|
||||
MaxTokens: provider.MaxTokens,
|
||||
Temperature: 0,
|
||||
Messages: []audioLLMChatMessage{
|
||||
{
|
||||
Role: "user",
|
||||
Content: []audioLLMContentPart{
|
||||
{Type: "text", Text: prompt},
|
||||
{
|
||||
Type: "input_audio",
|
||||
InputAudio: &audioLLMAudio{
|
||||
Data: base64.StdEncoding.EncodeToString(audio),
|
||||
Format: audioFormat(filename),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
if err := mw.WriteField("temperature", "0"); err != nil {
|
||||
return nil, 0, fmt.Errorf("audio transcription temperature field: %w", err)
|
||||
}
|
||||
body, err := json.Marshal(reqBody)
|
||||
if responseFormat == "verbose_json" {
|
||||
if err := mw.WriteField("timestamp_granularities[]", "segment"); err != nil {
|
||||
return nil, 0, fmt.Errorf("audio transcription timestamp field: %w", err)
|
||||
}
|
||||
}
|
||||
if prompt := strings.TrimSpace(provider.Prompt); prompt != "" {
|
||||
if err := mw.WriteField("prompt", prompt); err != nil {
|
||||
return nil, 0, fmt.Errorf("audio transcription prompt field: %w", err)
|
||||
}
|
||||
}
|
||||
if lang := strings.TrimSpace(in.Language); lang != "" {
|
||||
if err := mw.WriteField("language", lang); err != nil {
|
||||
return nil, 0, fmt.Errorf("audio transcription language field: %w", err)
|
||||
}
|
||||
}
|
||||
fw, err := mw.CreateFormFile("file", filename)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("audio llm marshal: %w", err)
|
||||
return nil, 0, fmt.Errorf("audio transcription create file: %w", err)
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, provider.BaseURL+"/v1/chat/completions", bytes.NewReader(body))
|
||||
if _, err := fw.Write(audio); err != nil {
|
||||
return nil, 0, fmt.Errorf("audio transcription copy audio: %w", err)
|
||||
}
|
||||
if err := mw.Close(); err != nil {
|
||||
return nil, 0, fmt.Errorf("audio transcription close form: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, provider.BaseURL+"/v1/audio/transcriptions", body)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("audio llm request: %w", err)
|
||||
return nil, 0, fmt.Errorf("audio transcription request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Content-Type", mw.FormDataContentType())
|
||||
if provider.APIKey != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+provider.APIKey)
|
||||
}
|
||||
@@ -594,44 +309,197 @@ func (c *Client) transcribeAudioLLM(ctx context.Context, provider ProviderConfig
|
||||
resp, err := c.http.Do(req)
|
||||
duration := time.Since(start)
|
||||
if err != nil {
|
||||
return nil, duration, fmt.Errorf("audio llm do: %w", err)
|
||||
return nil, duration, fmt.Errorf("audio transcription do: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
raw, err := io.ReadAll(io.LimitReader(resp.Body, 4<<20))
|
||||
if err != nil {
|
||||
return nil, duration, fmt.Errorf("audio llm read: %w", err)
|
||||
return nil, duration, fmt.Errorf("audio transcription read: %w", err)
|
||||
}
|
||||
if resp.StatusCode >= 300 {
|
||||
return nil, duration, fmt.Errorf("audio llm HTTP %d: %s", resp.StatusCode, strings.TrimSpace(string(raw)))
|
||||
return nil, duration, audioTranscriptionStatusError{status: resp.StatusCode, body: strings.TrimSpace(string(raw))}
|
||||
}
|
||||
var out audioLLMChatResponse
|
||||
var out audioTranscriptionResponse
|
||||
if err := json.Unmarshal(raw, &out); err != nil {
|
||||
return nil, duration, fmt.Errorf("audio llm decode: %w", err)
|
||||
return nil, duration, fmt.Errorf("audio transcription decode: %w", err)
|
||||
}
|
||||
if out.Error != nil {
|
||||
return nil, duration, fmt.Errorf("audio llm error: %s", out.Error.Message)
|
||||
}
|
||||
if len(out.Choices) == 0 {
|
||||
return nil, duration, fmt.Errorf("audio llm: empty choices")
|
||||
}
|
||||
modelName := out.Model
|
||||
if modelName == "" {
|
||||
modelName = provider.Model
|
||||
return nil, duration, fmt.Errorf("audio transcription error: %s", out.Error.Message)
|
||||
}
|
||||
modelName := firstNonEmpty(out.Model, provider.Model)
|
||||
return &audioLLMResponse{
|
||||
Text: strings.TrimSpace(out.Choices[0].Message.Content),
|
||||
Text: strings.TrimSpace(out.Text),
|
||||
Model: modelName,
|
||||
Language: out.Language,
|
||||
Segments: convertAudioSegments(out.Segments),
|
||||
}, duration, nil
|
||||
}
|
||||
|
||||
func audioFormat(filename string) string {
|
||||
ext := strings.TrimPrefix(strings.ToLower(filepath.Ext(filename)), ".")
|
||||
switch ext {
|
||||
case "wav", "mp3", "flac", "m4a", "ogg", "opus", "webm":
|
||||
return ext
|
||||
default:
|
||||
return "mp3"
|
||||
func isVerboseJSONUnsupported(err error) bool {
|
||||
var statusErr audioTranscriptionStatusError
|
||||
if !errors.As(err, &statusErr) {
|
||||
return false
|
||||
}
|
||||
if statusErr.status != http.StatusBadRequest && statusErr.status != http.StatusUnprocessableEntity {
|
||||
return false
|
||||
}
|
||||
body := strings.ToLower(statusErr.body)
|
||||
return strings.Contains(body, "verbose_json") ||
|
||||
strings.Contains(body, "response_format") ||
|
||||
strings.Contains(body, "timestamp_granularities")
|
||||
}
|
||||
|
||||
func convertAudioSegments(in []audioTranscriptionSegment) []Segment {
|
||||
out := make([]Segment, 0, len(in))
|
||||
for _, s := range in {
|
||||
text := strings.TrimSpace(s.Text)
|
||||
if text == "" {
|
||||
continue
|
||||
}
|
||||
end := s.End
|
||||
if end < s.Start {
|
||||
end = s.Start
|
||||
}
|
||||
out = append(out, Segment{Start: clampTime(s.Start), End: clampTime(end), Text: text})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func normalizeAudioLLMSegments(segments []Segment, text string, diarize bool) []Segment {
|
||||
text = strings.TrimSpace(text)
|
||||
if text != "" {
|
||||
if labeled := segmentSpeakerLabeledText(text); len(labeled) > 0 {
|
||||
return labeled
|
||||
}
|
||||
}
|
||||
if len(segments) <= 1 && text != "" {
|
||||
heuristic := segmentTranscriptText(text)
|
||||
if len(heuristic) > len(segments) {
|
||||
segments = heuristic
|
||||
}
|
||||
}
|
||||
return segments
|
||||
}
|
||||
|
||||
func segmentSpeakerLabeledText(text string) []Segment {
|
||||
matches := speakerLabelPattern.FindAllStringSubmatchIndex(text, -1)
|
||||
if len(matches) == 0 {
|
||||
return nil
|
||||
}
|
||||
speakerIDs := map[string]string{}
|
||||
var out []Segment
|
||||
var t float64
|
||||
for i, match := range matches {
|
||||
label := strings.ToLower(strings.TrimSpace(text[match[2]:match[3]]))
|
||||
speaker, ok := speakerIDs[label]
|
||||
if !ok {
|
||||
speaker = fmt.Sprintf("SPEAKER_%02d", len(speakerIDs))
|
||||
speakerIDs[label] = speaker
|
||||
}
|
||||
start := match[1]
|
||||
end := len(text)
|
||||
if i+1 < len(matches) {
|
||||
end = matches[i+1][0]
|
||||
}
|
||||
part := strings.TrimSpace(text[start:end])
|
||||
part = strings.Trim(part, ":-— ")
|
||||
if part == "" {
|
||||
continue
|
||||
}
|
||||
words := len(strings.Fields(part))
|
||||
duration := float64(words) * 0.42
|
||||
if duration < 1.2 {
|
||||
duration = 1.2
|
||||
}
|
||||
out = append(out, Segment{Start: t, End: t + duration, Text: part, Speaker: speaker})
|
||||
t += duration
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func segmentTranscriptText(text string) []Segment {
|
||||
parts := splitTranscriptSentences(text)
|
||||
out := make([]Segment, 0, len(parts))
|
||||
var t float64
|
||||
for _, part := range parts {
|
||||
words := len(strings.Fields(part))
|
||||
if words == 0 {
|
||||
continue
|
||||
}
|
||||
duration := float64(words) * 0.42
|
||||
if duration < 1.2 {
|
||||
duration = 1.2
|
||||
}
|
||||
segment := Segment{Start: t, End: t + duration, Text: part}
|
||||
out = append(out, segment)
|
||||
t = segment.End
|
||||
}
|
||||
if len(out) == 0 && strings.TrimSpace(text) != "" {
|
||||
out = append(out, Segment{Start: 0, End: 0, Text: strings.TrimSpace(text)})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func splitTranscriptSentences(text string) []string {
|
||||
text = strings.Join(strings.Fields(text), " ")
|
||||
if text == "" {
|
||||
return nil
|
||||
}
|
||||
var out []string
|
||||
start := 0
|
||||
runes := []rune(text)
|
||||
for i, r := range runes {
|
||||
if r != '.' && r != '!' && r != '?' && r != '…' {
|
||||
continue
|
||||
}
|
||||
next := i + 1
|
||||
if next < len(runes) && runes[next] != ' ' {
|
||||
continue
|
||||
}
|
||||
part := strings.TrimSpace(string(runes[start : i+1]))
|
||||
if part != "" {
|
||||
out = append(out, part)
|
||||
}
|
||||
start = i + 1
|
||||
for start < len(runes) && runes[start] == ' ' {
|
||||
start++
|
||||
}
|
||||
}
|
||||
tail := strings.TrimSpace(string(runes[start:]))
|
||||
if tail != "" {
|
||||
out = append(out, tail)
|
||||
}
|
||||
return mergeShortSegments(out, 8, 34)
|
||||
}
|
||||
|
||||
func mergeShortSegments(parts []string, minWords, maxWords int) []string {
|
||||
if len(parts) <= 1 {
|
||||
return parts
|
||||
}
|
||||
out := make([]string, 0, len(parts))
|
||||
var current []string
|
||||
currentWords := 0
|
||||
flush := func() {
|
||||
if len(current) == 0 {
|
||||
return
|
||||
}
|
||||
out = append(out, strings.Join(current, " "))
|
||||
current = nil
|
||||
currentWords = 0
|
||||
}
|
||||
for _, part := range parts {
|
||||
words := len(strings.Fields(part))
|
||||
if currentWords > 0 && currentWords+words > maxWords {
|
||||
flush()
|
||||
}
|
||||
current = append(current, part)
|
||||
currentWords += words
|
||||
if currentWords >= minWords {
|
||||
flush()
|
||||
}
|
||||
}
|
||||
flush()
|
||||
return out
|
||||
}
|
||||
|
||||
func firstNonEmpty(values ...string) string {
|
||||
|
||||
@@ -1,63 +1,183 @@
|
||||
package transcription
|
||||
|
||||
import (
|
||||
"math"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestAdjustLeadSilence(t *testing.T) {
|
||||
got := adjustLeadSilence([]Segment{
|
||||
{Start: 0.2, End: 1.1, Text: "first"},
|
||||
{Start: 1.4, End: 2.0, Text: "second"},
|
||||
}, 800*time.Millisecond)
|
||||
|
||||
if got[0].Start != 0 {
|
||||
t.Fatalf("first start = %v, want 0", got[0].Start)
|
||||
}
|
||||
if !near(got[0].End, 0.3) {
|
||||
t.Fatalf("first end = %v, want 0.3", got[0].End)
|
||||
}
|
||||
if !near(got[1].Start, 0.6) {
|
||||
t.Fatalf("second start = %v, want 0.6", got[1].Start)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeProviderOrder(t *testing.T) {
|
||||
got := normalizeProviderOrder([]string{"whisperx", "qwen", "voxtral", "qwen2-audio"})
|
||||
want := []string{ProviderWhisperX, ProviderQwenAudio, ProviderVoxtral}
|
||||
if len(got) != len(want) {
|
||||
t.Fatalf("providers = %#v, want %#v", got, want)
|
||||
}
|
||||
for i := range want {
|
||||
if got[i] != want[i] {
|
||||
t.Fatalf("providers = %#v, want %#v", got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewWithOptionsBuildsComparisonProviders(t *testing.T) {
|
||||
func TestNewWithOptionsBuildsWhisperProvider(t *testing.T) {
|
||||
client := NewWithOptions(Options{
|
||||
Providers: []string{"whisperx", "qwen2-audio", "voxtral-small"},
|
||||
WhisperXURL: "http://whisperx",
|
||||
QwenAudioBaseURL: "http://qwen",
|
||||
VoxtralBaseURL: "http://voxtral",
|
||||
AudioBaseURL: "http://whisper",
|
||||
})
|
||||
if client == nil {
|
||||
t.Fatal("client is nil")
|
||||
}
|
||||
got := make([]string, 0, len(client.providers))
|
||||
for _, provider := range client.providers {
|
||||
got = append(got, provider.Name)
|
||||
}
|
||||
want := []string{ProviderWhisperX, ProviderQwenAudio, ProviderVoxtral}
|
||||
for i := range want {
|
||||
if got[i] != want[i] {
|
||||
t.Fatalf("providers = %#v, want %#v", got, want)
|
||||
if client.provider.Name != ProviderWhisperLargeV3 {
|
||||
t.Fatalf("provider = %q, want %q", client.provider.Name, ProviderWhisperLargeV3)
|
||||
}
|
||||
if client.provider.Model != "openai/whisper-large-v3" {
|
||||
t.Fatalf("model = %q", client.provider.Model)
|
||||
}
|
||||
}
|
||||
|
||||
func near(got, want float64) bool {
|
||||
return math.Abs(got-want) < 0.000001
|
||||
func TestWhisperUsesAudioTranscriptionsEndpoint(t *testing.T) {
|
||||
audioSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = w.Write([]byte("fake audio"))
|
||||
}))
|
||||
defer audioSrv.Close()
|
||||
|
||||
var gotPath, gotModel, gotResponseFormat, gotPrompt, gotTemperature, gotTimestampGranularity string
|
||||
providerSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotPath = r.URL.Path
|
||||
if err := r.ParseMultipartForm(16 << 20); err != nil {
|
||||
t.Fatalf("ParseMultipartForm: %v", err)
|
||||
}
|
||||
gotModel = r.FormValue("model")
|
||||
gotResponseFormat = r.FormValue("response_format")
|
||||
gotPrompt = r.FormValue("prompt")
|
||||
gotTemperature = r.FormValue("temperature")
|
||||
gotTimestampGranularity = r.FormValue("timestamp_granularities[]")
|
||||
if _, _, err := r.FormFile("file"); err != nil {
|
||||
t.Fatalf("FormFile: %v", err)
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"text": "Алло, тест. Да, слышно.",
|
||||
"segments": []map[string]any{
|
||||
{"start": 0, "end": 1.2, "text": "Алло, тест."},
|
||||
{"start": 1.2, "end": 2.4, "text": "Да, слышно."},
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer providerSrv.Close()
|
||||
|
||||
client := NewWithOptions(Options{
|
||||
AudioBaseURL: providerSrv.URL,
|
||||
AudioModel: "openai/whisper-large-v3",
|
||||
})
|
||||
if client == nil {
|
||||
t.Fatal("client is nil")
|
||||
}
|
||||
got, err := client.Transcribe(t.Context(), Input{AudioURL: audioSrv.URL, Filename: "call.mp3"})
|
||||
if err != nil {
|
||||
t.Fatalf("Transcribe: %v", err)
|
||||
}
|
||||
if gotPath != "/v1/audio/transcriptions" {
|
||||
t.Fatalf("path = %q, want /v1/audio/transcriptions", gotPath)
|
||||
}
|
||||
if gotModel != "openai/whisper-large-v3" {
|
||||
t.Fatalf("model = %q", gotModel)
|
||||
}
|
||||
if gotResponseFormat != "verbose_json" {
|
||||
t.Fatalf("response_format = %q, want verbose_json", gotResponseFormat)
|
||||
}
|
||||
if gotTemperature != "0" {
|
||||
t.Fatalf("temperature = %q, want 0", gotTemperature)
|
||||
}
|
||||
if gotTimestampGranularity != "segment" {
|
||||
t.Fatalf("timestamp_granularities[] = %q, want segment", gotTimestampGranularity)
|
||||
}
|
||||
if gotPrompt != "" {
|
||||
t.Fatalf("prompt = %q, want empty", gotPrompt)
|
||||
}
|
||||
if len(got.Segments) != 2 || got.Segments[0].Text != "Алло, тест." || got.Segments[1].Start != 1.2 {
|
||||
t.Fatalf("segments = %#v", got.Segments)
|
||||
}
|
||||
if got.SchemaVersion != ResultSchemaVersion {
|
||||
t.Fatalf("schema_version = %q, want %q", got.SchemaVersion, ResultSchemaVersion)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWhisperFallsBackToJSONWhenVerboseJSONUnsupported(t *testing.T) {
|
||||
audioSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = w.Write([]byte("fake audio"))
|
||||
}))
|
||||
defer audioSrv.Close()
|
||||
|
||||
var formats []string
|
||||
providerSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := r.ParseMultipartForm(16 << 20); err != nil {
|
||||
t.Fatalf("ParseMultipartForm: %v", err)
|
||||
}
|
||||
format := r.FormValue("response_format")
|
||||
formats = append(formats, format)
|
||||
if format == "verbose_json" {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"error": map[string]any{"message": "unsupported response_format verbose_json"},
|
||||
})
|
||||
return
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"text": "Алло, fallback работает.",
|
||||
})
|
||||
}))
|
||||
defer providerSrv.Close()
|
||||
|
||||
client := NewWithOptions(Options{
|
||||
AudioBaseURL: providerSrv.URL,
|
||||
AudioModel: "openai/whisper-large-v3",
|
||||
})
|
||||
got, err := client.Transcribe(t.Context(), Input{AudioURL: audioSrv.URL, Filename: "call.mp3"})
|
||||
if err != nil {
|
||||
t.Fatalf("Transcribe: %v", err)
|
||||
}
|
||||
if len(formats) != 2 || formats[0] != "verbose_json" || formats[1] != "json" {
|
||||
t.Fatalf("formats = %#v, want verbose_json then json", formats)
|
||||
}
|
||||
if got.Segments[0].Text != "Алло, fallback работает." {
|
||||
t.Fatalf("segments = %#v", got.Segments)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSegmentTranscriptTextDoesNotInventSpeakers(t *testing.T) {
|
||||
got := segmentTranscriptText("Алло, добрый день. Да, слушаю. Скажите, квартира продается? Да, продается.")
|
||||
if len(got) < 2 {
|
||||
t.Fatalf("segments = %#v, want multiple", got)
|
||||
}
|
||||
if got[0].Speaker != "" || got[1].Speaker != "" {
|
||||
t.Fatalf("speakers = %q/%q, want empty", got[0].Speaker, got[1].Speaker)
|
||||
}
|
||||
if got[1].Start <= got[0].Start {
|
||||
t.Fatalf("segment times did not advance: %#v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeAudioLLMSegmentsSplitsSingleLongSegment(t *testing.T) {
|
||||
text := "Алло, добрый день. Да, слушаю. Скажите, квартира продается? Да, продается."
|
||||
got := normalizeAudioLLMSegments([]Segment{{Start: 0, End: 12, Text: text}}, text, true)
|
||||
if len(got) < 2 {
|
||||
t.Fatalf("segments = %#v, want heuristic split", got)
|
||||
}
|
||||
if got[0].Speaker != "" || got[1].Speaker != "" {
|
||||
t.Fatalf("speakers = %q/%q, want empty", got[0].Speaker, got[1].Speaker)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeAudioLLMSegmentsKeepsSegmentsWithoutFakeSpeakers(t *testing.T) {
|
||||
got := normalizeAudioLLMSegments([]Segment{
|
||||
{Start: 0, End: 1, Text: "Алло."},
|
||||
{Start: 1, End: 2, Text: "Да, слушаю."},
|
||||
}, "Алло. Да, слушаю.", true)
|
||||
if len(got) != 2 {
|
||||
t.Fatalf("segments = %#v", got)
|
||||
}
|
||||
if got[0].Speaker != "" || got[1].Speaker != "" {
|
||||
t.Fatalf("speakers = %q/%q, want empty", got[0].Speaker, got[1].Speaker)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeAudioLLMSegmentsUsesExplicitSpeakerLabels(t *testing.T) {
|
||||
text := "Спикер 1: Алло, добрый день. Спикер 2: Да, слушаю. Спикер 1: Скажите, квартира продается?"
|
||||
got := normalizeAudioLLMSegments(nil, text, true)
|
||||
if len(got) != 3 {
|
||||
t.Fatalf("segments = %#v, want 3", got)
|
||||
}
|
||||
if got[0].Speaker != "SPEAKER_00" || got[1].Speaker != "SPEAKER_01" || got[2].Speaker != "SPEAKER_00" {
|
||||
t.Fatalf("speakers = %q/%q/%q", got[0].Speaker, got[1].Speaker, got[2].Speaker)
|
||||
}
|
||||
if got[0].Text != "Алло, добрый день." || got[1].Text != "Да, слушаю." {
|
||||
t.Fatalf("texts = %#v", got)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -20,7 +20,7 @@ const (
|
||||
TaskCallAnalysis = "call_analysis"
|
||||
TaskTranscription = "transcription"
|
||||
|
||||
TranscriptionProfile = "whisperx"
|
||||
TranscriptionProfile = "whisper-large-v3"
|
||||
)
|
||||
|
||||
type Worker struct {
|
||||
@@ -113,6 +113,9 @@ func (w *Worker) tick(ctx context.Context) {
|
||||
}
|
||||
|
||||
func (w *Worker) process(ctx context.Context, job *model.Job) {
|
||||
stopHeartbeat := w.startHeartbeat(ctx, job)
|
||||
defer stopHeartbeat()
|
||||
|
||||
if job.TaskType == TaskTranscription {
|
||||
w.processTranscription(ctx, job)
|
||||
return
|
||||
@@ -168,6 +171,41 @@ func (w *Worker) fail(ctx context.Context, job *model.Job, code, message string)
|
||||
}
|
||||
}
|
||||
|
||||
func (w *Worker) startHeartbeat(ctx context.Context, job *model.Job) func() {
|
||||
heartbeatCtx, cancel := context.WithCancel(ctx)
|
||||
done := make(chan struct{})
|
||||
ticker := time.NewTicker(w.heartbeatInterval())
|
||||
go func() {
|
||||
defer close(done)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-heartbeatCtx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := w.store.HeartbeatJob(heartbeatCtx, job.ID); err != nil {
|
||||
slog.Warn("heartbeat job failed", "job_id", job.ID, "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
return func() {
|
||||
cancel()
|
||||
<-done
|
||||
}
|
||||
}
|
||||
|
||||
func (w *Worker) heartbeatInterval() time.Duration {
|
||||
interval := w.leaseTimeout / 3
|
||||
if interval < 10*time.Second {
|
||||
return 10 * time.Second
|
||||
}
|
||||
if interval > time.Minute {
|
||||
return time.Minute
|
||||
}
|
||||
return interval
|
||||
}
|
||||
|
||||
func classifyTranscriptionError(err error) string {
|
||||
if err == nil {
|
||||
return "unknown"
|
||||
@@ -184,11 +222,11 @@ func classifyTranscriptionError(err error) string {
|
||||
return "bad_audio"
|
||||
case strings.Contains(s, "audio download") || strings.Contains(s, "audio http 5"):
|
||||
return "storage_error"
|
||||
case strings.Contains(s, "whisperx http 4") || strings.Contains(s, "ffmpeg") || strings.Contains(s, "invalid data") || strings.Contains(s, "could not decode"):
|
||||
case strings.Contains(s, "audio transcription http 4") || strings.Contains(s, "invalid data") || strings.Contains(s, "could not decode"):
|
||||
return "bad_audio"
|
||||
case strings.Contains(s, "whisperx http 5") || strings.Contains(s, "whisperx do") || strings.Contains(s, "audio llm http 5") || strings.Contains(s, "audio llm do") || strings.Contains(s, "connection refused") || strings.Contains(s, "connection reset") || strings.Contains(s, "closed network connection"):
|
||||
case strings.Contains(s, "audio transcription http 5") || strings.Contains(s, "audio transcription do") || strings.Contains(s, "connection refused") || strings.Contains(s, "connection reset") || strings.Contains(s, "closed network connection"):
|
||||
return "provider_unavailable"
|
||||
case strings.Contains(s, "audio llm http 4"):
|
||||
case strings.Contains(s, "audio transcription http 4"):
|
||||
return "bad_input"
|
||||
case strings.Contains(s, "decode"):
|
||||
return "bad_response"
|
||||
@@ -207,6 +245,8 @@ func classifyLLMError(err error) string {
|
||||
return "timeout"
|
||||
case strings.Contains(s, "connection refused") || strings.Contains(s, "connection reset") || strings.Contains(s, "no route to host") || strings.Contains(s, "llm http 5"):
|
||||
return "model_unavailable"
|
||||
case strings.Contains(s, "maximum context length") || strings.Contains(s, "context length") || strings.Contains(s, "input_tokens"):
|
||||
return "context_length"
|
||||
case strings.Contains(s, "llm http 4") || strings.Contains(s, "messages are required"):
|
||||
return "bad_input"
|
||||
case strings.Contains(s, "llm decode") || strings.Contains(s, "empty choices"):
|
||||
|
||||
29
internal/worker/worker_test.go
Normal file
29
internal/worker/worker_test.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package worker
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestClassifyLLMError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
want string
|
||||
}{
|
||||
{name: "timeout", err: errors.New("context deadline exceeded"), want: "timeout"},
|
||||
{name: "unavailable", err: errors.New("llm HTTP 500: internal server error"), want: "model_unavailable"},
|
||||
{name: "context length", err: errors.New("This model's maximum context length is 16384 tokens. input_tokens=16001"), want: "context_length"},
|
||||
{name: "bad input", err: errors.New("llm HTTP 400: messages are required"), want: "bad_input"},
|
||||
{name: "bad response", err: errors.New("llm decode: invalid character '<'"), want: "bad_response"},
|
||||
{name: "unknown", err: errors.New("strange failure"), want: "unknown"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := classifyLLMError(tt.err); got != tt.want {
|
||||
t.Fatalf("classifyLLMError() = %q, want %q", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -11,20 +11,11 @@ data:
|
||||
LLM_BASE_URL: "http://10.2.3.5:8002"
|
||||
LLM_MODEL: "qwen2.5-14b"
|
||||
LLM_TIMEOUT: "5m"
|
||||
TRANSCRIPTION_PROVIDERS: "whisperx,qwen2-audio,voxtral-small"
|
||||
WHISPERX_URL: "http://10.2.3.5:8001"
|
||||
WHISPERX_TIMEOUT: "10m"
|
||||
WHISPERX_LEAD_SILENCE: "800ms"
|
||||
# Fill these after Qwen2-Audio and Voxtral are exposed as OpenAI-compatible
|
||||
# chat-completions endpoints on the AI server.
|
||||
QWEN_AUDIO_BASE_URL: "http://10.2.3.5:8003"
|
||||
QWEN_AUDIO_MODEL: "Qwen/Qwen2-Audio-7B-Instruct"
|
||||
QWEN_AUDIO_TIMEOUT: "10m"
|
||||
VOXTRAL_BASE_URL: ""
|
||||
VOXTRAL_MODEL: "mistralai/Voxtral-Small-24B-2507"
|
||||
VOXTRAL_TIMEOUT: "10m"
|
||||
AUDIO_LLM_MAX_TOKENS: "4096"
|
||||
FFMPEG_PATH: "/usr/bin/ffmpeg"
|
||||
# Whisper Large v3 is exposed on the AI server through an OpenAI-compatible
|
||||
# /v1/audio/transcriptions endpoint.
|
||||
AUDIO_TRANSCRIPTION_BASE_URL: "http://10.2.3.5:8004"
|
||||
AUDIO_TRANSCRIPTION_MODEL: "openai/whisper-large-v3"
|
||||
AUDIO_TRANSCRIPTION_TIMEOUT: "30m"
|
||||
AI_STATS_SIDECAR_URL: "http://10.2.3.5:9090"
|
||||
AI_STATS_TIMEOUT: "8s"
|
||||
WORKER_POLL_INTERVAL: "2s"
|
||||
|
||||
@@ -18,6 +18,5 @@ type: Opaque
|
||||
stringData:
|
||||
DATABASE_URL: "postgres://ai_service:ai_service@postgres:5432/ai_service?sslmode=disable"
|
||||
LLM_API_KEY: "sk-111f838ccec43406e078cd9094b6797307cb895236179f32"
|
||||
QWEN_AUDIO_API_KEY: "sk-111f838ccec43406e078cd9094b6797307cb895236179f32"
|
||||
VOXTRAL_API_KEY: "sk-111f838ccec43406e078cd9094b6797307cb895236179f32"
|
||||
AUDIO_TRANSCRIPTION_API_KEY: "sk-111f838ccec43406e078cd9094b6797307cb895236179f32"
|
||||
AI_SERVICE_TOKEN: "d18bcacf9e02bae1806ee6b6eeda62b95be6a915c0a22936d9a700128b275442"
|
||||
|
||||
@@ -30,11 +30,81 @@ spec:
|
||||
fieldRef:
|
||||
fieldPath: metadata.name
|
||||
- name: WORKER_TASK_TYPES
|
||||
value: "llm_chat,chat_completion,call_analysis,telegram_classification"
|
||||
value: "llm_chat,chat_completion,call_analysis,telegram_classification,transcript_summary"
|
||||
- name: WORKER_MODEL_PROFILES
|
||||
value: "qwen2.5-14b"
|
||||
- name: WORKER_CLAIM_LIMIT
|
||||
value: "8"
|
||||
- name: WORKER_LEASE_TIMEOUT
|
||||
value: "15m"
|
||||
envFrom:
|
||||
- configMapRef:
|
||||
name: ai-service-config
|
||||
- secretRef:
|
||||
name: ai-service-secrets
|
||||
startupProbe:
|
||||
httpGet:
|
||||
path: /readyz
|
||||
port: 8081
|
||||
periodSeconds: 5
|
||||
failureThreshold: 30
|
||||
readinessProbe:
|
||||
httpGet:
|
||||
path: /readyz
|
||||
port: 8081
|
||||
periodSeconds: 10
|
||||
livenessProbe:
|
||||
httpGet:
|
||||
path: /healthz
|
||||
port: 8081
|
||||
periodSeconds: 10
|
||||
resources:
|
||||
requests:
|
||||
cpu: 50m
|
||||
memory: 96Mi
|
||||
limits:
|
||||
cpu: 500m
|
||||
memory: 384Mi
|
||||
---
|
||||
apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
metadata:
|
||||
name: ai-service-analysis-worker
|
||||
namespace: ai-service
|
||||
spec:
|
||||
replicas: 1
|
||||
selector:
|
||||
matchLabels:
|
||||
app: ai-service-analysis-worker
|
||||
template:
|
||||
metadata:
|
||||
labels:
|
||||
app: ai-service-analysis-worker
|
||||
spec:
|
||||
terminationGracePeriodSeconds: 20
|
||||
hostAliases:
|
||||
- ip: "77.105.173.42"
|
||||
hostnames:
|
||||
- "s3-minio.estateliga.work"
|
||||
containers:
|
||||
- name: worker
|
||||
image: localhost:30300/admin/ai-service:latest
|
||||
command: ["/usr/local/bin/ai-service-worker"]
|
||||
ports:
|
||||
- containerPort: 8081
|
||||
env:
|
||||
- name: WORKER_ID
|
||||
valueFrom:
|
||||
fieldRef:
|
||||
fieldPath: metadata.name
|
||||
- name: WORKER_TASK_TYPES
|
||||
value: "call_analysis"
|
||||
- name: WORKER_MODEL_PROFILES
|
||||
value: "qwen2.5-14b"
|
||||
- name: WORKER_CLAIM_LIMIT
|
||||
value: "8"
|
||||
- name: WORKER_LEASE_TIMEOUT
|
||||
value: "15m"
|
||||
envFrom:
|
||||
- configMapRef:
|
||||
name: ai-service-config
|
||||
@@ -98,9 +168,11 @@ spec:
|
||||
- name: WORKER_TASK_TYPES
|
||||
value: "transcription"
|
||||
- name: WORKER_MODEL_PROFILES
|
||||
value: "whisperx"
|
||||
value: "whisper-large-v3"
|
||||
- name: WORKER_CLAIM_LIMIT
|
||||
value: "1"
|
||||
value: "4"
|
||||
- name: WORKER_LEASE_TIMEOUT
|
||||
value: "15m"
|
||||
envFrom:
|
||||
- configMapRef:
|
||||
name: ai-service-config
|
||||
|
||||
Reference in New Issue
Block a user