Move TG LLM classifier to Go

This commit is contained in:
Grendgi
2026-06-04 15:48:25 +03:00
parent 5cda374fb1
commit 37d27308c2
10 changed files with 594 additions and 4 deletions

518
cmd/classifier/main.go Normal file
View File

@@ -0,0 +1,518 @@
package main
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"log/slog"
"net/http"
"net/url"
"os"
"os/signal"
"strconv"
"strings"
"syscall"
"time"
"github.com/jackc/pgx/v5/pgxpool"
)
// Identifiers shared by the vertical-specific branches in this file.
const (
// verticalHR is the channel vertical that selects the HR prompt and the "hr_lead" verdict key.
verticalHR = "hr"
// defaultREPromptKey is not referenced in the visible part of this file — TODO confirm it is still needed.
defaultREPromptKey = "real_estate"
)
// config holds the runtime settings that loadConfig reads from the environment.
type config struct {
// Postgres connection parameters; databaseURL combines them into a connection URL.
PostgresUser string
PostgresPassword string
PostgresDB string
PostgresHost string
PostgresPort int
// LLMEnabled switches the worker on or off; main returns immediately when it is false.
LLMEnabled bool
// LLMBaseURL is the server root; classify appends /v1/chat/completions to it.
LLMBaseURL string
// LLMAPIKey, when non-empty, is sent as a Bearer token in the Authorization header.
LLMAPIKey string
// LLMModel is the model name placed into every chat request.
LLMModel string
// LLMTimeout is the timeout of the HTTP client used for LLM calls.
LLMTimeout time.Duration
// LLMMaxTokens is passed through as max_tokens.
LLMMaxTokens int
// LLMMinTextLength is the shortest trimmed text that is sent to the model;
// shorter texts receive a negative verdict without an LLM call.
LLMMinTextLength int
// ClassifyInterval is the period of the ticker that paces batches.
ClassifyInterval time.Duration
// ClassifyBatchSize is the maximum number of messages loaded per batch.
ClassifyBatchSize int
}
// pendingMessage is one row produced by loadPending: a message that still lacks a verdict.
type pendingMessage struct {
// ID is messages.id.
ID int64
// Text is messages.text; the query filters out NULL texts.
Text string
// Vertical is channels.vertical; "hr" selects the HR flow, anything else the listing flow.
Vertical string
// SectionSlug is sections.slug, used to look up a section-specific prompt.
SectionSlug string
// DepartmentID is sections.department_id, or "" when that column is NULL.
DepartmentID string
// Extracted is the decoded messages.extracted JSON object.
Extracted map[string]any
}
// chatRequest is the JSON body that classify posts to the /v1/chat/completions endpoint.
type chatRequest struct {
Model string `json:"model"`
// Messages carries the system prompt followed by the user prompt.
Messages []chatMessage `json:"messages"`
Temperature float64 `json:"temperature"`
MaxTokens int `json:"max_tokens"`
// ResponseFormat is set to {"type": "json_object"} by classify.
ResponseFormat responseFmt `json:"response_format"`
}
// chatMessage is a single chat turn; it is used both in requests and in decoded responses.
type chatMessage struct {
// Role is "system" or "user" in the requests built by classify.
Role string `json:"role"`
Content string `json:"content"`
}
// responseFmt is the response_format object of a chat request.
type responseFmt struct {
Type string `json:"type"`
}
// chatResponse is the subset of the completion response that classify reads:
// only the message of each choice is decoded, and only the first choice is used.
type chatResponse struct {
Choices []struct {
Message chatMessage `json:"message"`
} `json:"choices"`
}
// main loads configuration, installs a JSON logger, opens the database pool
// and then runs classification batches until SIGINT or SIGTERM is received.
// The first batch runs immediately; later batches are paced by a ticker.
func main() {
	cfg := loadConfig()
	logger := slog.New(slog.NewJSONHandler(os.Stdout, nil))
	slog.SetDefault(logger)

	if !cfg.LLMEnabled {
		slog.Info("classifier_disabled")
		return
	}

	// time.NewTicker panics on a non-positive duration, which a zero or
	// negative LLM_CLASSIFY_INTERVAL_SECONDS would produce. Fall back to the
	// documented default instead of crashing at startup.
	if cfg.ClassifyInterval <= 0 {
		const fallback = 20 * time.Second
		slog.Warn(
			"classify_interval_invalid",
			"interval", cfg.ClassifyInterval.String(),
			"fallback", fallback.String(),
		)
		cfg.ClassifyInterval = fallback
	}

	ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
	defer stop()

	pool, err := pgxpool.New(ctx, cfg.databaseURL())
	if err != nil {
		slog.Error("db_connect_failed", "error", err)
		os.Exit(1)
	}
	defer pool.Close()

	worker := &classifier{cfg: cfg, db: pool, http: &http.Client{Timeout: cfg.LLMTimeout}}
	slog.Info(
		"classifier_started",
		"interval", cfg.ClassifyInterval.String(),
		"batch", cfg.ClassifyBatchSize,
		"model", cfg.LLMModel,
	)

	ticker := time.NewTicker(cfg.ClassifyInterval)
	defer ticker.Stop()

	for {
		updated, err := worker.runOnce(ctx)
		if err != nil {
			slog.Error("classify_batch_failed", "error", err)
		} else if updated > 0 {
			slog.Info("classify_batch_done", "updated", updated)
		}

		// Wait for the next tick, or stop as soon as a signal cancels ctx.
		select {
		case <-ctx.Done():
			slog.Info("classifier_stopped")
			return
		case <-ticker.C:
		}
	}
}
// classifier bundles the dependencies of the worker: settings, the database
// pool used for loading and saving, and the HTTP client used for LLM calls.
type classifier struct {
cfg config
db *pgxpool.Pool
http *http.Client
}
// runOnce processes a single batch of pending messages and reports how many
// verdicts were stored. Only a failure to load the batch is returned as an
// error; per-message failures are logged and the message is skipped, so it
// stays pending for a later batch.
func (c *classifier) runOnce(ctx context.Context) (int, error) {
	pending, err := c.loadPending(ctx)
	if err != nil {
		return 0, err
	}

	var saved int
	for i := range pending {
		item := pending[i]
		field := verdictKey(item.Vertical)

		// Already classified under this key: nothing to do.
		if _, done := item.Extracted[field]; done {
			continue
		}

		result, classifyErr := c.classify(ctx, item)
		if classifyErr != nil {
			slog.Warn("llm_classify_failed", "message_id", item.ID, "vertical", item.Vertical, "error", classifyErr)
			continue
		}

		// An empty verdict is replaced by the explicit negative one.
		if len(result) == 0 {
			fallback, marshalErr := marshalRaw(negativeVerdict(item.Vertical))
			if marshalErr != nil {
				slog.Warn("negative_verdict_failed", "message_id", item.ID, "error", marshalErr)
				continue
			}
			result = fallback
		}

		if saveErr := c.saveVerdict(ctx, item.ID, field, result); saveErr != nil {
			slog.Warn("save_verdict_failed", "message_id", item.ID, "error", saveErr)
			continue
		}
		saved++
	}
	return saved, nil
}
// loadPending returns up to cfg.ClassifyBatchSize messages, newest first,
// whose extracted JSON has no verdict yet ("hr_lead" for the hr vertical,
// "lead" otherwise).
//
// A negative batch size is rejected up front: it is not a valid LIMIT and
// would also be an invalid slice capacity. Extracted JSON that cannot be
// decoded into an object is replaced by an empty map, so the message is
// treated as unclassified.
func (c *classifier) loadPending(ctx context.Context) ([]pendingMessage, error) {
	limit := c.cfg.ClassifyBatchSize
	if limit < 0 {
		return nil, fmt.Errorf("invalid classify batch size %d", limit)
	}

	rows, err := c.db.Query(ctx, `
SELECT
m.id,
m.text,
c.vertical,
s.slug,
COALESCE(s.department_id, ''),
COALESCE(m.extracted, '{}'::jsonb)::text
FROM messages m
JOIN channels c ON c.id = m.channel_id
JOIN sections s ON s.id = c.section_id
WHERE m.text IS NOT NULL
AND (
(c.vertical = 'hr' AND (m.extracted IS NULL OR m.extracted->'hr_lead' IS NULL))
OR
(c.vertical <> 'hr' AND (m.extracted IS NULL OR m.extracted->'lead' IS NULL))
)
ORDER BY m.id DESC
LIMIT $1
`, limit)
	if err != nil {
		return nil, fmt.Errorf("query pending messages: %w", err)
	}
	defer rows.Close()

	out := make([]pendingMessage, 0, limit)
	for rows.Next() {
		var msg pendingMessage
		var extractedText string
		if err := rows.Scan(&msg.ID, &msg.Text, &msg.Vertical, &msg.SectionSlug, &msg.DepartmentID, &extractedText); err != nil {
			return nil, fmt.Errorf("scan pending message: %w", err)
		}
		// Best effort: a malformed extracted value must not block the batch.
		if err := json.Unmarshal([]byte(extractedText), &msg.Extracted); err != nil {
			msg.Extracted = map[string]any{}
		}
		out = append(out, msg)
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("read pending messages: %w", err)
	}
	return out, nil
}
// classify asks the LLM for a verdict on msg and returns it as normalized JSON.
//
// Texts whose trimmed length is below cfg.LLMMinTextLength are not sent to
// the model and receive the negative verdict for their vertical. The length
// is measured in characters (runes) rather than bytes, so the threshold
// means the same thing for Cyrillic and Latin text.
//
// An error is returned when the prompt cannot be resolved, the HTTP call
// fails or answers with a non-2xx status, or the reply contains no usable
// JSON object.
func (c *classifier) classify(ctx context.Context, msg pendingMessage) (json.RawMessage, error) {
	// len(string) would count UTF-8 bytes; convert to runes to count characters.
	if len([]rune(strings.TrimSpace(msg.Text))) < c.cfg.LLMMinTextLength {
		return marshalRaw(negativeVerdict(msg.Vertical))
	}

	systemPrompt, err := c.resolvePrompt(ctx, msg.Vertical, msg.DepartmentID, msg.SectionSlug)
	if err != nil {
		return nil, fmt.Errorf("resolve prompt: %w", err)
	}

	payload := chatRequest{
		Model: c.cfg.LLMModel,
		Messages: []chatMessage{
			{Role: "system", Content: systemPrompt},
			{Role: "user", Content: buildUserPrompt(msg.Text)},
		},
		Temperature:    0.1,
		MaxTokens:      c.cfg.LLMMaxTokens,
		ResponseFormat: responseFmt{Type: "json_object"},
	}
	body, err := json.Marshal(payload)
	if err != nil {
		return nil, fmt.Errorf("encode llm request: %w", err)
	}

	endpoint := strings.TrimRight(c.cfg.LLMBaseURL, "/") + "/v1/chat/completions"
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body))
	if err != nil {
		return nil, fmt.Errorf("build llm request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	if c.cfg.LLMAPIKey != "" {
		req.Header.Set("Authorization", "Bearer "+c.cfg.LLMAPIKey)
	}

	resp, err := c.http.Do(req)
	if err != nil {
		return nil, fmt.Errorf("send llm request: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		// Best effort: include at most 512 bytes of the body for diagnostics;
		// a read failure only shortens the message.
		b, _ := io.ReadAll(io.LimitReader(resp.Body, 512))
		return nil, fmt.Errorf("llm http %d: %s", resp.StatusCode, strings.TrimSpace(string(b)))
	}

	var parsed chatResponse
	if err := json.NewDecoder(resp.Body).Decode(&parsed); err != nil {
		return nil, fmt.Errorf("decode llm response: %w", err)
	}
	if len(parsed.Choices) == 0 {
		return nil, errors.New("llm returned no choices")
	}

	raw := strings.TrimSpace(parsed.Choices[0].Message.Content)
	if raw == "" {
		return nil, errors.New("llm returned empty content")
	}

	block, err := extractJSONBlock(raw)
	if err != nil {
		return nil, err
	}
	return normalizeVerdict(msg.Vertical, block)
}
// resolvePrompt picks the system prompt for a message. It looks in
// app_settings for a section-specific key first (when sectionSlug is set),
// then for the department-wide key, and falls back to the built-in default
// prompt. An empty departmentID is looked up under "global".
//
// Any lookup error is treated the same as a missing or blank setting, so the
// returned error is always nil.
func (c *classifier) resolvePrompt(ctx context.Context, vertical, departmentID, sectionSlug string) (string, error) {
	const query = `SELECT value #>> '{}' FROM app_settings WHERE key = $1`

	scope := departmentID
	if scope == "" {
		scope = "global"
	}

	// Most specific key first.
	candidates := make([]string, 0, 2)
	if sectionSlug != "" {
		candidates = append(candidates, promptKey(scope, vertical, sectionSlug))
	}
	candidates = append(candidates, promptKey(scope, vertical, ""))

	for _, candidate := range candidates {
		var override string
		if err := c.db.QueryRow(ctx, query, candidate).Scan(&override); err != nil {
			continue
		}
		if strings.TrimSpace(override) == "" {
			continue
		}
		return override, nil
	}
	return defaultPrompt(vertical), nil
}
// saveVerdict writes verdict under key into the extracted JSON of message id,
// starting from an empty object when extracted is NULL.
func (c *classifier) saveVerdict(ctx context.Context, id int64, key string, verdict json.RawMessage) error {
	const stmt = `
UPDATE messages
SET extracted = jsonb_set(COALESCE(extracted, '{}'::jsonb), ARRAY[$2], $3::jsonb, true)
WHERE id = $1
`
	if _, err := c.db.Exec(ctx, stmt, id, key, string(verdict)); err != nil {
		return err
	}
	return nil
}
// promptKey builds the app_settings key of a system prompt:
// "llm_system_prompt:<department>:<vertical>", with ":<section>" appended
// when sectionSlug is not empty.
func promptKey(departmentID, vertical, sectionSlug string) string {
	base := fmt.Sprintf("llm_system_prompt:%s:%s", departmentID, vertical)
	if sectionSlug == "" {
		return base
	}
	return base + ":" + sectionSlug
}
// verdictKey names the field of messages.extracted that stores the verdict:
// "hr_lead" for the HR vertical and "lead" for every other vertical.
func verdictKey(vertical string) string {
	switch vertical {
	case verticalHR:
		return "hr_lead"
	default:
		return "lead"
	}
}
// buildUserPrompt wraps the message text in a fenced block between a short
// header and a closing instruction to return JSON.
func buildUserPrompt(text string) string {
	const fence = "```"
	header := "Текст сообщения:\n" + fence + "\n"
	footer := "\n" + fence + "\nВерни JSON."
	return header + text + footer
}
// extractJSONBlock pulls a JSON value out of the model's reply. When the
// whole reply is valid JSON it is returned as is. Otherwise the text from the
// first "{" through the last "}" is tried, which copes with replies wrapped
// in prose or markdown fences.
func extractJSONBlock(raw string) (json.RawMessage, error) {
	var payload json.RawMessage
	if json.Unmarshal([]byte(raw), &payload) == nil {
		return payload, nil
	}

	open := strings.Index(raw, "{")
	shut := strings.LastIndex(raw, "}")
	if open < 0 || shut < open {
		return nil, errors.New("no json object in llm content")
	}

	candidate := []byte(raw[open : shut+1])
	if err := json.Unmarshal(candidate, &payload); err != nil {
		return nil, err
	}
	return payload, nil
}
// normalizeVerdict decodes the model's JSON object, fills in the flag the
// rest of the pipeline relies on ("is_lead" for the HR vertical, "is_listing"
// otherwise) with false when it is missing, and clamps a numeric
// "confidence" into the range 0..1.
//
// It returns an error when raw is not a JSON object. That includes a literal
// null, which decodes without error but leaves the map nil; writing a default
// into a nil map would panic.
func normalizeVerdict(vertical string, raw json.RawMessage) (json.RawMessage, error) {
	var m map[string]any
	if err := json.Unmarshal(raw, &m); err != nil {
		return nil, err
	}
	if m == nil {
		return nil, errors.New("llm verdict is not a json object")
	}

	if vertical == verticalHR {
		if _, ok := m["is_lead"]; !ok {
			m["is_lead"] = false
		}
	} else if _, ok := m["is_listing"]; !ok {
		m["is_listing"] = false
	}

	// A non-numeric or absent confidence is left untouched.
	if confidence, ok := asFloat(m["confidence"]); ok {
		if confidence < 0 {
			confidence = 0
		}
		if confidence > 1 {
			confidence = 1
		}
		m["confidence"] = confidence
	}
	return marshalRaw(m)
}
// negativeVerdict returns the "not a match" verdict for a vertical. The two
// variants differ only in the name of the boolean flag: "is_lead" for the HR
// vertical, "is_listing" for everything else.
func negativeVerdict(vertical string) map[string]any {
	flag := "is_listing"
	if vertical == verticalHR {
		flag = "is_lead"
	}
	return map[string]any{
		flag:         false,
		"kind":       nil,
		"summary":    "",
		"confidence": 0,
	}
}
// marshalRaw encodes v as JSON and returns the bytes as a json.RawMessage.
func marshalRaw(v any) (json.RawMessage, error) {
	encoded, err := json.Marshal(v)
	if err != nil {
		return nil, err
	}
	return encoded, nil
}
// asFloat converts v to float64 when it holds one of the supported numeric
// types (float64, float32, int, int64, json.Number). The boolean reports
// whether the conversion succeeded.
func asFloat(v any) (float64, bool) {
	if f, ok := v.(float64); ok {
		return f, true
	}
	if f, ok := v.(float32); ok {
		return float64(f), true
	}
	if i, ok := v.(int); ok {
		return float64(i), true
	}
	if i, ok := v.(int64); ok {
		return float64(i), true
	}
	if num, ok := v.(json.Number); ok {
		f, err := num.Float64()
		return f, err == nil
	}
	return 0, false
}
// defaultPrompt returns the built-in system prompt used when no override is
// stored: the HR prompt for the HR vertical, the real-estate prompt otherwise.
func defaultPrompt(vertical string) string {
	switch vertical {
	case verticalHR:
		return defaultHRPrompt
	default:
		return defaultREPrompt
	}
}
// loadConfig reads every setting from the environment, substituting the
// built-in default for values that are unset, blank or unparsable.
func loadConfig() config {
	// seconds reads an integer number of seconds and converts it to a Duration.
	seconds := func(key string, fallback int) time.Duration {
		return time.Duration(envInt(key, fallback)) * time.Second
	}

	var cfg config

	cfg.PostgresUser = env("POSTGRES_USER", "parser")
	cfg.PostgresPassword = env("POSTGRES_PASSWORD", "parser")
	cfg.PostgresDB = env("POSTGRES_DB", "parser")
	cfg.PostgresHost = env("POSTGRES_HOST", "db")
	cfg.PostgresPort = envInt("POSTGRES_PORT", 5432)

	cfg.LLMEnabled = envBool("LLM_ENABLED", true)
	cfg.LLMBaseURL = env("LLM_BASE_URL", "http://10.2.3.5:8002")
	cfg.LLMAPIKey = env("LLM_API_KEY", "")
	cfg.LLMModel = env("LLM_MODEL", "qwen2.5-14b")
	cfg.LLMTimeout = seconds("LLM_TIMEOUT_SECONDS", 120)
	cfg.LLMMaxTokens = envInt("LLM_MAX_TOKENS", 600)
	cfg.LLMMinTextLength = envInt("LLM_MIN_TEXT_LENGTH", 20)

	cfg.ClassifyInterval = seconds("LLM_CLASSIFY_INTERVAL_SECONDS", 20)
	cfg.ClassifyBatchSize = envInt("LLM_CLASSIFY_BATCH_SIZE", 5)

	return cfg
}
// databaseURL assembles the postgres:// connection URL from the Postgres
// settings.
//
// The URL is built with url.URL so each component gets the escaping that
// fits its position. Query escaping would turn a space into "+", which is a
// literal plus sign inside userinfo and path, so a password or database name
// containing a space would be sent to the server altered.
func (c config) databaseURL() string {
	u := url.URL{
		Scheme: "postgres",
		User:   url.UserPassword(c.PostgresUser, c.PostgresPassword),
		Host:   fmt.Sprintf("%s:%d", c.PostgresHost, c.PostgresPort),
		Path:   "/" + c.PostgresDB,
	}
	return u.String()
}
// env returns the whitespace-trimmed value of the environment variable key,
// or fallback when the variable is unset or blank.
func env(key, fallback string) string {
	value := strings.TrimSpace(os.Getenv(key))
	if value == "" {
		return fallback
	}
	return value
}
// envInt parses the environment variable key as a decimal integer. It returns
// fallback when the variable is unset, blank or not a valid integer.
func envInt(key string, fallback int) int {
	raw := strings.TrimSpace(os.Getenv(key))
	if raw == "" {
		return fallback
	}
	parsed, err := strconv.Atoi(raw)
	if err != nil {
		return fallback
	}
	return parsed
}
// envBool reads the environment variable key as a boolean. It returns
// fallback when the variable is unset, blank or not recognised.
//
// Everything strconv.ParseBool accepts is honoured (1, t, T, TRUE, true,
// True, 0, f, F, FALSE, false, False). In addition, the common spellings
// true/t/yes/y/on and false/f/no/n/off are accepted in any letter case, so a
// value such as "off" or "No" disables a flag instead of silently falling
// back to the default.
func envBool(key string, fallback bool) bool {
	raw := strings.TrimSpace(os.Getenv(key))
	if raw == "" {
		return fallback
	}
	if b, err := strconv.ParseBool(raw); err == nil {
		return b
	}
	switch strings.ToLower(raw) {
	case "true", "t", "yes", "y", "on":
		return true
	case "false", "f", "no", "n", "off":
		return false
	}
	return fallback
}
// defaultREPrompt is the built-in system prompt for every vertical other than
// HR. It instructs the model to decide whether a message is a real-estate
// listing and to answer with a JSON object whose flag field is "is_listing".
const defaultREPrompt = `Ты — аналитик объявлений о недвижимости. Тебе дают текст из Telegram-канала.
Определи, является ли сообщение реальным объявлением о покупке, продаже или аренде недвижимости.
Отвечай строго валидным JSON без markdown:
{
"is_listing": boolean,
"kind": "sale" | "rent" | "purchase" | null,
"property_type": string | null,
"rooms": string | null,
"area_m2": number | null,
"price_text": string | null,
"price_value": number | null,
"currency": "RUB" | "USD" | "EUR" | "AED" | "GBP" | "CNY" | "TRY" | "KZT" | "BYN" | "UAH" | null,
"location": string | null,
"contact_phone": string | null,
"contact_name": string | null,
"summary": string,
"confidence": number
}
summary всегда по-русски, confidence в диапазоне 0..1.`
// defaultHRPrompt is the built-in system prompt for the HR vertical. It
// instructs the model to decide whether a message is a vacancy, resume or HR
// contact and to answer with a JSON object whose flag field is "is_lead".
const defaultHRPrompt = `Ты — аналитик HR-объявлений. Тебе дают текст из Telegram-канала.
Определи, относится ли сообщение к рынку труда: вакансия, резюме или короткий HR-контакт.
Отвечай строго валидным JSON без markdown:
{
"is_lead": boolean,
"kind": "vacancy" | "resume" | "contact" | null,
"title": string | null,
"company": string | null,
"candidate_name": string | null,
"experience_years": number | null,
"skills": string[],
"location": string | null,
"remote": boolean | null,
"employment_type": "full-time" | "part-time" | "contract" | "internship" | null,
"salary_text": string | null,
"salary_value": number | null,
"currency": "RUB" | "USD" | "EUR" | "AED" | "GBP" | "CNY" | "TRY" | "KZT" | "BYN" | "UAH" | null,
"contact_phone": string | null,
"contact_name": string | null,
"summary": string,
"confidence": number
}
summary всегда по-русски, confidence в диапазоне 0..1.`