Move TG LLM classifier to Go
This commit is contained in:
518
cmd/classifier/main.go
Normal file
518
cmd/classifier/main.go
Normal file
@@ -0,0 +1,518 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
const (
	// verticalHR is the channel vertical value that switches verdict keys,
	// default prompts and negative verdicts to their HR variants.
	verticalHR = "hr"
	// defaultREPromptKey is not referenced anywhere in the visible code.
	// NOTE(review): presumably a settings key for the real-estate prompt — confirm or remove.
	defaultREPromptKey = "real_estate"
)
|
||||
|
||||
// config carries every runtime setting of the classifier. loadConfig
// populates it from environment variables.
type config struct {
	// PostgreSQL connection parameters, combined by databaseURL.
	PostgresUser, PostgresPassword string
	PostgresDB, PostgresHost       string
	PostgresPort                   int

	// LLM endpoint settings.
	LLMEnabled            bool          // when false, main exits right after start
	LLMBaseURL, LLMAPIKey string        // endpoint root and optional bearer token
	LLMModel              string        // model name sent in each chat request
	LLMTimeout            time.Duration // timeout of the HTTP client
	LLMMaxTokens          int           // max_tokens sent in each chat request
	LLMMinTextLength      int           // shorter texts get a negative verdict without an LLM call

	// Polling loop settings.
	ClassifyInterval  time.Duration // pause between batches
	ClassifyBatchSize int           // LIMIT of the pending-messages query
}
|
||||
|
||||
// pendingMessage is one row produced by loadPending: a message that does not
// yet carry a verdict for its vertical.
type pendingMessage struct {
	// ID is messages.id.
	ID int64
	// Text is messages.text; Vertical is channels.vertical.
	Text, Vertical string
	// SectionSlug is sections.slug; DepartmentID is sections.department_id,
	// or "" when that column is NULL.
	SectionSlug, DepartmentID string
	// Extracted is the decoded messages.extracted JSON document.
	Extracted map[string]any
}
|
||||
|
||||
// chatRequest is the JSON body that classify POSTs to the
// /v1/chat/completions endpoint.
type chatRequest struct {
	// Model is the model name taken from the configuration.
	Model string `json:"model"`
	// Messages holds the system prompt followed by the user prompt.
	Messages []chatMessage `json:"messages"`
	// Temperature is the sampling temperature sent with the request.
	Temperature float64 `json:"temperature"`
	// MaxTokens is the completion length limit sent with the request.
	MaxTokens int `json:"max_tokens"`
	// ResponseFormat carries the requested response format type.
	ResponseFormat responseFmt `json:"response_format"`
}
|
||||
|
||||
// chatMessage is a single chat turn; it is used both in the request
// messages and in the choices of the response.
type chatMessage struct {
	// Role is the speaker of the turn; requests use "system" and "user".
	Role string `json:"role"`
	// Content is the text of the turn.
	Content string `json:"content"`
}
|
||||
|
||||
// responseFmt is the "response_format" object of a chat request.
type responseFmt struct {
	// Type is the format name; classify sends "json_object".
	Type string `json:"type"`
}
|
||||
|
||||
type chatResponse struct {
|
||||
Choices []struct {
|
||||
Message chatMessage `json:"message"`
|
||||
} `json:"choices"`
|
||||
}
|
||||
|
||||
func main() {
|
||||
cfg := loadConfig()
|
||||
logger := slog.New(slog.NewJSONHandler(os.Stdout, nil))
|
||||
slog.SetDefault(logger)
|
||||
|
||||
if !cfg.LLMEnabled {
|
||||
slog.Info("classifier_disabled")
|
||||
return
|
||||
}
|
||||
|
||||
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
|
||||
defer stop()
|
||||
|
||||
pool, err := pgxpool.New(ctx, cfg.databaseURL())
|
||||
if err != nil {
|
||||
slog.Error("db_connect_failed", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
defer pool.Close()
|
||||
|
||||
worker := &classifier{cfg: cfg, db: pool, http: &http.Client{Timeout: cfg.LLMTimeout}}
|
||||
slog.Info(
|
||||
"classifier_started",
|
||||
"interval", cfg.ClassifyInterval.String(),
|
||||
"batch", cfg.ClassifyBatchSize,
|
||||
"model", cfg.LLMModel,
|
||||
)
|
||||
|
||||
ticker := time.NewTicker(cfg.ClassifyInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
updated, err := worker.runOnce(ctx)
|
||||
if err != nil {
|
||||
slog.Error("classify_batch_failed", "error", err)
|
||||
} else if updated > 0 {
|
||||
slog.Info("classify_batch_done", "updated", updated)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
slog.Info("classifier_stopped")
|
||||
return
|
||||
case <-ticker.C:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// classifier bundles the dependencies of the classification worker.
type classifier struct {
	// cfg is the configuration the worker was started with.
	cfg config
	// db is the PostgreSQL pool used to read pending messages and prompts
	// and to store verdicts.
	db *pgxpool.Pool
	// http is the client used for LLM requests.
	http *http.Client
}
|
||||
|
||||
func (c *classifier) runOnce(ctx context.Context) (int, error) {
|
||||
rows, err := c.loadPending(ctx)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if len(rows) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
updated := 0
|
||||
for _, msg := range rows {
|
||||
key := verdictKey(msg.Vertical)
|
||||
if _, ok := msg.Extracted[key]; ok {
|
||||
continue
|
||||
}
|
||||
|
||||
verdict, err := c.classify(ctx, msg)
|
||||
if err != nil {
|
||||
slog.Warn("llm_classify_failed", "message_id", msg.ID, "vertical", msg.Vertical, "error", err)
|
||||
continue
|
||||
}
|
||||
if len(verdict) == 0 {
|
||||
verdict, err = marshalRaw(negativeVerdict(msg.Vertical))
|
||||
if err != nil {
|
||||
slog.Warn("negative_verdict_failed", "message_id", msg.ID, "error", err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if err := c.saveVerdict(ctx, msg.ID, key, verdict); err != nil {
|
||||
slog.Warn("save_verdict_failed", "message_id", msg.ID, "error", err)
|
||||
continue
|
||||
}
|
||||
updated++
|
||||
}
|
||||
return updated, nil
|
||||
}
|
||||
|
||||
func (c *classifier) loadPending(ctx context.Context) ([]pendingMessage, error) {
|
||||
rows, err := c.db.Query(ctx, `
|
||||
SELECT
|
||||
m.id,
|
||||
m.text,
|
||||
c.vertical,
|
||||
s.slug,
|
||||
COALESCE(s.department_id, ''),
|
||||
COALESCE(m.extracted, '{}'::jsonb)::text
|
||||
FROM messages m
|
||||
JOIN channels c ON c.id = m.channel_id
|
||||
JOIN sections s ON s.id = c.section_id
|
||||
WHERE m.text IS NOT NULL
|
||||
AND (
|
||||
(c.vertical = 'hr' AND (m.extracted IS NULL OR m.extracted->'hr_lead' IS NULL))
|
||||
OR
|
||||
(c.vertical <> 'hr' AND (m.extracted IS NULL OR m.extracted->'lead' IS NULL))
|
||||
)
|
||||
ORDER BY m.id DESC
|
||||
LIMIT $1
|
||||
`, c.cfg.ClassifyBatchSize)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
out := make([]pendingMessage, 0, c.cfg.ClassifyBatchSize)
|
||||
for rows.Next() {
|
||||
var msg pendingMessage
|
||||
var extractedText string
|
||||
if err := rows.Scan(&msg.ID, &msg.Text, &msg.Vertical, &msg.SectionSlug, &msg.DepartmentID, &extractedText); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := json.Unmarshal([]byte(extractedText), &msg.Extracted); err != nil {
|
||||
msg.Extracted = map[string]any{}
|
||||
}
|
||||
out = append(out, msg)
|
||||
}
|
||||
return out, rows.Err()
|
||||
}
|
||||
|
||||
func (c *classifier) classify(ctx context.Context, msg pendingMessage) (json.RawMessage, error) {
|
||||
if len(strings.TrimSpace(msg.Text)) < c.cfg.LLMMinTextLength {
|
||||
return marshalRaw(negativeVerdict(msg.Vertical))
|
||||
}
|
||||
|
||||
systemPrompt, err := c.resolvePrompt(ctx, msg.Vertical, msg.DepartmentID, msg.SectionSlug)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
payload := chatRequest{
|
||||
Model: c.cfg.LLMModel,
|
||||
Messages: []chatMessage{
|
||||
{Role: "system", Content: systemPrompt},
|
||||
{Role: "user", Content: buildUserPrompt(msg.Text)},
|
||||
},
|
||||
Temperature: 0.1,
|
||||
MaxTokens: c.cfg.LLMMaxTokens,
|
||||
ResponseFormat: responseFmt{Type: "json_object"},
|
||||
}
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, strings.TrimRight(c.cfg.LLMBaseURL, "/")+"/v1/chat/completions", bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if c.cfg.LLMAPIKey != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+c.cfg.LLMAPIKey)
|
||||
}
|
||||
|
||||
resp, err := c.http.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
b, _ := io.ReadAll(io.LimitReader(resp.Body, 512))
|
||||
return nil, fmt.Errorf("llm http %d: %s", resp.StatusCode, strings.TrimSpace(string(b)))
|
||||
}
|
||||
|
||||
var parsed chatResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&parsed); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(parsed.Choices) == 0 {
|
||||
return nil, errors.New("llm returned no choices")
|
||||
}
|
||||
|
||||
raw := strings.TrimSpace(parsed.Choices[0].Message.Content)
|
||||
if raw == "" {
|
||||
return nil, errors.New("llm returned empty content")
|
||||
}
|
||||
block, err := extractJSONBlock(raw)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
normalized, err := normalizeVerdict(msg.Vertical, block)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return normalized, nil
|
||||
}
|
||||
|
||||
func (c *classifier) resolvePrompt(ctx context.Context, vertical, departmentID, sectionSlug string) (string, error) {
|
||||
dept := departmentID
|
||||
if dept == "" {
|
||||
dept = "global"
|
||||
}
|
||||
keys := []string{}
|
||||
if sectionSlug != "" {
|
||||
keys = append(keys, promptKey(dept, vertical, sectionSlug))
|
||||
}
|
||||
keys = append(keys, promptKey(dept, vertical, ""))
|
||||
|
||||
for _, key := range keys {
|
||||
var text string
|
||||
err := c.db.QueryRow(ctx, `SELECT value #>> '{}' FROM app_settings WHERE key = $1`, key).Scan(&text)
|
||||
if err == nil && strings.TrimSpace(text) != "" {
|
||||
return text, nil
|
||||
}
|
||||
}
|
||||
return defaultPrompt(vertical), nil
|
||||
}
|
||||
|
||||
func (c *classifier) saveVerdict(ctx context.Context, id int64, key string, verdict json.RawMessage) error {
|
||||
_, err := c.db.Exec(ctx, `
|
||||
UPDATE messages
|
||||
SET extracted = jsonb_set(COALESCE(extracted, '{}'::jsonb), ARRAY[$2], $3::jsonb, true)
|
||||
WHERE id = $1
|
||||
`, id, key, string(verdict))
|
||||
return err
|
||||
}
|
||||
|
||||
// promptKey builds the app_settings key of a system prompt override:
// "llm_system_prompt:<department>:<vertical>", with ":<section>" appended
// when sectionSlug is not empty.
func promptKey(departmentID, vertical, sectionSlug string) string {
	base := fmt.Sprintf("llm_system_prompt:%s:%s", departmentID, vertical)
	if sectionSlug == "" {
		return base
	}
	return base + ":" + sectionSlug
}
|
||||
|
||||
func verdictKey(vertical string) string {
|
||||
if vertical == verticalHR {
|
||||
return "hr_lead"
|
||||
}
|
||||
return "lead"
|
||||
}
|
||||
|
||||
// buildUserPrompt wraps the message text in a fenced block between a Russian
// header line and a closing request to return JSON.
func buildUserPrompt(text string) string {
	const (
		header = "Текст сообщения:\n```\n"
		footer = "\n```\nВерни JSON."
	)
	return header + text + footer
}
|
||||
|
||||
// extractJSONBlock returns the JSON object contained in an LLM answer.
//
// If raw as a whole is a valid JSON object it is returned as is. Otherwise
// the span from the first "{" to the last "}" is tried, which copes with
// answers wrapped in markdown fences or surrounding prose. Valid JSON that is
// not an object (null, arrays, scalars) is never returned on its own, because
// callers decode the result into a map.
//
// It returns an error when no braces are present or when the braced span is
// not valid JSON.
func extractJSONBlock(raw string) (json.RawMessage, error) {
	var payload json.RawMessage
	if strings.HasPrefix(strings.TrimSpace(raw), "{") {
		if err := json.Unmarshal([]byte(raw), &payload); err == nil {
			return payload, nil
		}
	}

	start := strings.Index(raw, "{")
	end := strings.LastIndex(raw, "}")
	if start < 0 || end < start {
		return nil, errors.New("no json object in llm content")
	}
	if err := json.Unmarshal([]byte(raw[start:end+1]), &payload); err != nil {
		return nil, err
	}
	return payload, nil
}
|
||||
|
||||
func normalizeVerdict(vertical string, raw json.RawMessage) (json.RawMessage, error) {
|
||||
var m map[string]any
|
||||
if err := json.Unmarshal(raw, &m); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if vertical == verticalHR {
|
||||
if _, ok := m["is_lead"]; !ok {
|
||||
m["is_lead"] = false
|
||||
}
|
||||
} else if _, ok := m["is_listing"]; !ok {
|
||||
m["is_listing"] = false
|
||||
}
|
||||
if confidence, ok := asFloat(m["confidence"]); ok {
|
||||
if confidence < 0 {
|
||||
confidence = 0
|
||||
}
|
||||
if confidence > 1 {
|
||||
confidence = 1
|
||||
}
|
||||
m["confidence"] = confidence
|
||||
}
|
||||
return marshalRaw(m)
|
||||
}
|
||||
|
||||
func negativeVerdict(vertical string) map[string]any {
|
||||
if vertical == verticalHR {
|
||||
return map[string]any{
|
||||
"is_lead": false,
|
||||
"kind": nil,
|
||||
"summary": "",
|
||||
"confidence": 0,
|
||||
}
|
||||
}
|
||||
return map[string]any{
|
||||
"is_listing": false,
|
||||
"kind": nil,
|
||||
"summary": "",
|
||||
"confidence": 0,
|
||||
}
|
||||
}
|
||||
|
||||
// marshalRaw encodes v as JSON and hands the bytes back as a json.RawMessage.
// On an encoding error the message is nil and the error is returned.
func marshalRaw(v any) (json.RawMessage, error) {
	encoded, err := json.Marshal(v)
	if err != nil {
		return nil, err
	}
	return encoded, nil
}
|
||||
|
||||
// asFloat converts a decoded JSON number to float64. It accepts float64,
// float32, int, int64 and json.Number; the boolean result is false for every
// other type (including nil) and for a json.Number that does not parse.
func asFloat(v any) (float64, bool) {
	if num, isNumber := v.(json.Number); isNumber {
		parsed, err := num.Float64()
		return parsed, err == nil
	}

	switch n := v.(type) {
	case float64:
		return n, true
	case float32:
		return float64(n), true
	case int64:
		return float64(n), true
	case int:
		return float64(n), true
	}
	return 0, false
}
|
||||
|
||||
func defaultPrompt(vertical string) string {
|
||||
if vertical == verticalHR {
|
||||
return defaultHRPrompt
|
||||
}
|
||||
return defaultREPrompt
|
||||
}
|
||||
|
||||
func loadConfig() config {
|
||||
return config{
|
||||
PostgresUser: env("POSTGRES_USER", "parser"),
|
||||
PostgresPassword: env("POSTGRES_PASSWORD", "parser"),
|
||||
PostgresDB: env("POSTGRES_DB", "parser"),
|
||||
PostgresHost: env("POSTGRES_HOST", "db"),
|
||||
PostgresPort: envInt("POSTGRES_PORT", 5432),
|
||||
LLMEnabled: envBool("LLM_ENABLED", true),
|
||||
LLMBaseURL: env("LLM_BASE_URL", "http://10.2.3.5:8002"),
|
||||
LLMAPIKey: env("LLM_API_KEY", ""),
|
||||
LLMModel: env("LLM_MODEL", "qwen2.5-14b"),
|
||||
LLMTimeout: time.Duration(envInt("LLM_TIMEOUT_SECONDS", 120)) * time.Second,
|
||||
LLMMaxTokens: envInt("LLM_MAX_TOKENS", 600),
|
||||
LLMMinTextLength: envInt("LLM_MIN_TEXT_LENGTH", 20),
|
||||
ClassifyInterval: time.Duration(envInt("LLM_CLASSIFY_INTERVAL_SECONDS", 20)) * time.Second,
|
||||
ClassifyBatchSize: envInt("LLM_CLASSIFY_BATCH_SIZE", 5),
|
||||
}
|
||||
}
|
||||
|
||||
func (c config) databaseURL() string {
|
||||
return fmt.Sprintf(
|
||||
"postgres://%s:%s@%s:%d/%s",
|
||||
url.QueryEscape(c.PostgresUser),
|
||||
url.QueryEscape(c.PostgresPassword),
|
||||
c.PostgresHost,
|
||||
c.PostgresPort,
|
||||
url.QueryEscape(c.PostgresDB),
|
||||
)
|
||||
}
|
||||
|
||||
// env returns the value of the environment variable key with surrounding
// whitespace removed, or fallback when the variable is unset or blank.
func env(key, fallback string) string {
	value := strings.TrimSpace(os.Getenv(key))
	if value == "" {
		return fallback
	}
	return value
}
|
||||
|
||||
// envInt returns the environment variable key parsed as a decimal integer,
// or fallback when the variable is unset, blank or not a valid integer.
func envInt(key string, fallback int) int {
	text := strings.TrimSpace(os.Getenv(key))
	if text == "" {
		return fallback
	}
	value, err := strconv.Atoi(text)
	if err != nil {
		return fallback
	}
	return value
}
|
||||
|
||||
// envBool returns the environment variable key interpreted as a boolean, or
// fallback when the variable is unset, blank or not recognised.
//
// Every spelling accepted by strconv.ParseBool is understood (this already
// covers "1" and "0"). In addition the common switches yes/no, y/n and on/off
// are accepted, and matching is case-insensitive, so that a value such as
// "off" or "No" disables a feature instead of silently selecting the default.
func envBool(key string, fallback bool) bool {
	raw := strings.TrimSpace(os.Getenv(key))
	if raw == "" {
		return fallback
	}
	if b, err := strconv.ParseBool(raw); err == nil {
		return b
	}
	switch strings.ToLower(raw) {
	case "true", "t", "yes", "y", "on":
		return true
	case "false", "f", "no", "n", "off":
		return false
	}
	return fallback
}
|
||||
|
||||
// defaultREPrompt is the built-in system prompt (in Russian) that
// defaultPrompt returns for every vertical other than HR. It asks the model
// to decide whether a Telegram message is a real-estate listing and to answer
// with a JSON object in the schema spelled out below.
const defaultREPrompt = `Ты — аналитик объявлений о недвижимости. Тебе дают текст из Telegram-канала.
Определи, является ли сообщение реальным объявлением о покупке, продаже или аренде недвижимости.
Отвечай строго валидным JSON без markdown:
{
"is_listing": boolean,
"kind": "sale" | "rent" | "purchase" | null,
"property_type": string | null,
"rooms": string | null,
"area_m2": number | null,
"price_text": string | null,
"price_value": number | null,
"currency": "RUB" | "USD" | "EUR" | "AED" | "GBP" | "CNY" | "TRY" | "KZT" | "BYN" | "UAH" | null,
"location": string | null,
"contact_phone": string | null,
"contact_name": string | null,
"summary": string,
"confidence": number
}
summary всегда по-русски, confidence в диапазоне 0..1.`
|
||||
|
||||
// defaultHRPrompt is the built-in system prompt (in Russian) that
// defaultPrompt returns for the HR vertical. It asks the model to decide
// whether a Telegram message is a vacancy, a resume or an HR contact and to
// answer with a JSON object in the schema spelled out below.
const defaultHRPrompt = `Ты — аналитик HR-объявлений. Тебе дают текст из Telegram-канала.
Определи, относится ли сообщение к рынку труда: вакансия, резюме или короткий HR-контакт.
Отвечай строго валидным JSON без markdown:
{
"is_lead": boolean,
"kind": "vacancy" | "resume" | "contact" | null,
"title": string | null,
"company": string | null,
"candidate_name": string | null,
"experience_years": number | null,
"skills": string[],
"location": string | null,
"remote": boolean | null,
"employment_type": "full-time" | "part-time" | "contract" | "internship" | null,
"salary_text": string | null,
"salary_value": number | null,
"currency": "RUB" | "USD" | "EUR" | "AED" | "GBP" | "CNY" | "TRY" | "KZT" | "BYN" | "UAH" | null,
"contact_phone": string | null,
"contact_name": string | null,
"summary": string,
"confidence": number
}
summary всегда по-русски, confidence в диапазоне 0..1.`
|
||||
Reference in New Issue
Block a user