Files
monitoring-tg/cmd/classifier/main.go
2026-06-04 15:48:25 +03:00

519 lines
13 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package main
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"log/slog"
"net/http"
"net/url"
"os"
"os/signal"
"strconv"
"strings"
"syscall"
"time"
"github.com/jackc/pgx/v5/pgxpool"
)
// verticalHR is the channels.vertical value that selects the HR prompt and
// the "hr_lead" verdict key; every other vertical uses the real-estate prompt
// and the "lead" key.
const verticalHR = "hr"

// defaultREPromptKey names the real-estate vertical.
// NOTE(review): not referenced elsewhere in this file — confirm it is still used.
const defaultREPromptKey = "real_estate"
// config holds every runtime setting of the classifier; main obtains it from
// loadConfig, which reads environment variables.
type config struct {
	// PostgreSQL connection parameters (rendered by databaseURL).
	PostgresUser     string
	PostgresPassword string
	PostgresDB       string
	PostgresHost     string
	PostgresPort     int

	// LLM endpoint and request tuning.
	LLMEnabled       bool
	LLMBaseURL       string
	LLMAPIKey        string
	LLMModel         string
	LLMTimeout       time.Duration
	LLMMaxTokens     int
	LLMMinTextLength int

	// Polling loop: pause between passes and messages fetched per pass.
	ClassifyInterval  time.Duration
	ClassifyBatchSize int
}
// pendingMessage is one message awaiting classification, joined with the
// vertical, section and department of the channel it was posted in.
type pendingMessage struct {
	ID           int64          // messages.id
	Text         string         // messages.text
	Vertical     string         // channels.vertical
	SectionSlug  string         // sections.slug
	DepartmentID string         // sections.department_id ("" when NULL)
	Extracted    map[string]any // decoded messages.extracted
}
// Wire types for the OpenAI-compatible chat completions call made by classify.
type (
	// chatRequest is the JSON body POSTed to the chat completions endpoint.
	chatRequest struct {
		Model          string        `json:"model"`
		Messages       []chatMessage `json:"messages"`
		Temperature    float64       `json:"temperature"`
		MaxTokens      int           `json:"max_tokens"`
		ResponseFormat responseFmt   `json:"response_format"`
	}

	// chatMessage is a single role/content turn of the conversation.
	chatMessage struct {
		Role    string `json:"role"`
		Content string `json:"content"`
	}

	// responseFmt selects the response format of the completion.
	responseFmt struct {
		Type string `json:"type"`
	}

	// chatResponse decodes only the choices' messages from the reply.
	chatResponse struct {
		Choices []struct {
			Message chatMessage `json:"message"`
		} `json:"choices"`
	}
)
// main loads configuration, installs a JSON slog logger on stdout, opens the
// PostgreSQL pool and runs one classification pass per tick of
// ClassifyInterval until SIGINT or SIGTERM cancels the context.
func main() {
	cfg := loadConfig()
	slog.SetDefault(slog.New(slog.NewJSONHandler(os.Stdout, nil)))

	if !cfg.LLMEnabled {
		slog.Info("classifier_disabled")
		return
	}

	ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
	defer stop()

	pool, err := pgxpool.New(ctx, cfg.databaseURL())
	if err != nil {
		slog.Error("db_connect_failed", "error", err)
		os.Exit(1)
	}
	defer pool.Close()

	worker := &classifier{
		cfg:  cfg,
		db:   pool,
		http: &http.Client{Timeout: cfg.LLMTimeout},
	}
	slog.Info("classifier_started",
		"interval", cfg.ClassifyInterval.String(),
		"batch", cfg.ClassifyBatchSize,
		"model", cfg.LLMModel,
	)

	ticker := time.NewTicker(cfg.ClassifyInterval)
	defer ticker.Stop()

	// The first pass runs immediately; later passes wait for the ticker.
	for {
		switch updated, runErr := worker.runOnce(ctx); {
		case runErr != nil:
			slog.Error("classify_batch_failed", "error", runErr)
		case updated > 0:
			slog.Info("classify_batch_done", "updated", updated)
		}

		select {
		case <-ctx.Done():
			slog.Info("classifier_stopped")
			return
		case <-ticker.C:
		}
	}
}
// classifier bundles what a classification pass needs: settings, a database
// pool and an HTTP client for the LLM endpoint.
type classifier struct {
	cfg  config        // runtime settings
	db   *pgxpool.Pool // pool used for messages, channels, sections and app_settings
	http *http.Client  // LLM client; main sets its Timeout from cfg.LLMTimeout
}
// runOnce classifies one batch of pending messages and persists the verdicts.
//
// Messages that already carry a verdict under their vertical's key are
// skipped. Per-message failures (LLM call, verdict encoding, database update)
// are logged and skipped so the rest of the batch still progresses; those
// messages receive no verdict and are therefore selected again by a later
// pass. The loop stops early once ctx is cancelled rather than issuing
// requests that can only fail.
//
// It returns the number of verdicts saved. An error is returned only when the
// batch itself cannot be loaded.
func (c *classifier) runOnce(ctx context.Context) (int, error) {
	pending, err := c.loadPending(ctx)
	if err != nil {
		return 0, err
	}
	updated := 0
	for _, msg := range pending {
		// On shutdown, stop instead of logging a failure for every remaining message.
		if ctx.Err() != nil {
			break
		}
		key := verdictKey(msg.Vertical)
		if _, done := msg.Extracted[key]; done {
			continue
		}
		verdict, err := c.classify(ctx, msg)
		if err != nil {
			slog.Warn("llm_classify_failed", "message_id", msg.ID, "vertical", msg.Vertical, "error", err)
			continue
		}
		if len(verdict) == 0 {
			// Defensive: never persist an empty payload; store an explicit negative verdict instead.
			if verdict, err = marshalRaw(negativeVerdict(msg.Vertical)); err != nil {
				slog.Warn("negative_verdict_failed", "message_id", msg.ID, "error", err)
				continue
			}
		}
		if err := c.saveVerdict(ctx, msg.ID, key, verdict); err != nil {
			slog.Warn("save_verdict_failed", "message_id", msg.ID, "error", err)
			continue
		}
		updated++
	}
	return updated, nil
}
// loadPending fetches up to ClassifyBatchSize messages, newest first, that
// have text but no verdict yet for their channel's vertical ("hr_lead" for
// HR channels, "lead" for every other vertical).
//
// messages.extracted is decoded into Extracted; a NULL, JSON-null or
// undecodable value yields an empty map so callers can index it safely.
// On error no messages are returned.
func (c *classifier) loadPending(ctx context.Context) ([]pendingMessage, error) {
	// slug and department_id are COALESCEd so that a NULL in sections cannot
	// make Scan fail and block the whole batch on every pass.
	rows, err := c.db.Query(ctx, `
SELECT
m.id,
m.text,
c.vertical,
COALESCE(s.slug, ''),
COALESCE(s.department_id, ''),
COALESCE(m.extracted, '{}'::jsonb)::text
FROM messages m
JOIN channels c ON c.id = m.channel_id
JOIN sections s ON s.id = c.section_id
WHERE m.text IS NOT NULL
AND (
(c.vertical = 'hr' AND (m.extracted IS NULL OR m.extracted->'hr_lead' IS NULL))
OR
(c.vertical <> 'hr' AND (m.extracted IS NULL OR m.extracted->'lead' IS NULL))
)
ORDER BY m.id DESC
LIMIT $1
`, c.cfg.ClassifyBatchSize)
	if err != nil {
		return nil, fmt.Errorf("query pending messages: %w", err)
	}
	defer rows.Close()

	// PostgreSQL rejects a negative LIMIT (reported through rows.Err below);
	// clamp the capacity so make cannot panic before that error surfaces.
	capacity := c.cfg.ClassifyBatchSize
	if capacity < 0 {
		capacity = 0
	}
	out := make([]pendingMessage, 0, capacity)
	for rows.Next() {
		var msg pendingMessage
		var extractedText string
		if err := rows.Scan(&msg.ID, &msg.Text, &msg.Vertical, &msg.SectionSlug, &msg.DepartmentID, &extractedText); err != nil {
			return nil, fmt.Errorf("scan pending message: %w", err)
		}
		// JSON null decodes to a nil map without error; normalize it too.
		if err := json.Unmarshal([]byte(extractedText), &msg.Extracted); err != nil || msg.Extracted == nil {
			msg.Extracted = map[string]any{}
		}
		out = append(out, msg)
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("read pending messages: %w", err)
	}
	return out, nil
}
// classify asks the LLM for a verdict on msg and returns it as normalized JSON.
//
// Messages whose trimmed text has fewer than LLMMinTextLength characters get
// negativeVerdict locally and are never sent to the LLM. Otherwise the
// resolved system prompt and the message text are POSTed to the chat
// completions endpoint. The first choice's content is then parsed with
// extractJSONBlock and normalizeVerdict. Any failure is returned as an error.
func (c *classifier) classify(ctx context.Context, msg pendingMessage) (json.RawMessage, error) {
	// Count characters (runes), not bytes: Cyrillic letters take two bytes
	// each in UTF-8, so a byte length would halve the threshold for Russian text.
	if len([]rune(strings.TrimSpace(msg.Text))) < c.cfg.LLMMinTextLength {
		return marshalRaw(negativeVerdict(msg.Vertical))
	}
	systemPrompt, err := c.resolvePrompt(ctx, msg.Vertical, msg.DepartmentID, msg.SectionSlug)
	if err != nil {
		return nil, fmt.Errorf("resolve prompt: %w", err)
	}
	body, err := json.Marshal(chatRequest{
		Model: c.cfg.LLMModel,
		Messages: []chatMessage{
			{Role: "system", Content: systemPrompt},
			{Role: "user", Content: buildUserPrompt(msg.Text)},
		},
		Temperature:    0.1,
		MaxTokens:      c.cfg.LLMMaxTokens,
		ResponseFormat: responseFmt{Type: "json_object"},
	})
	if err != nil {
		return nil, fmt.Errorf("encode llm request: %w", err)
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, chatCompletionsURL(c.cfg.LLMBaseURL), bytes.NewReader(body))
	if err != nil {
		return nil, fmt.Errorf("build llm request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	if c.cfg.LLMAPIKey != "" {
		req.Header.Set("Authorization", "Bearer "+c.cfg.LLMAPIKey)
	}
	resp, err := c.http.Do(req)
	if err != nil {
		return nil, err // *url.Error already names the method and URL
	}
	defer func() {
		// Drain a bounded remainder so the keep-alive connection can be
		// reused for the next message, then release it.
		_, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, 64<<10))
		_ = resp.Body.Close()
	}()
	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		// Best effort: a short body excerpt only enriches the error message.
		snippet, _ := io.ReadAll(io.LimitReader(resp.Body, 512))
		return nil, fmt.Errorf("llm http %d: %s", resp.StatusCode, strings.TrimSpace(string(snippet)))
	}
	var parsed chatResponse
	if err := json.NewDecoder(resp.Body).Decode(&parsed); err != nil {
		return nil, fmt.Errorf("decode llm response: %w", err)
	}
	if len(parsed.Choices) == 0 {
		return nil, errors.New("llm returned no choices")
	}
	raw := strings.TrimSpace(parsed.Choices[0].Message.Content)
	if raw == "" {
		return nil, errors.New("llm returned empty content")
	}
	block, err := extractJSONBlock(raw)
	if err != nil {
		return nil, fmt.Errorf("extract verdict json: %w", err)
	}
	verdict, err := normalizeVerdict(msg.Vertical, block)
	if err != nil {
		return nil, fmt.Errorf("normalize verdict: %w", err)
	}
	return verdict, nil
}

// chatCompletionsURL builds the chat completions endpoint from base. A
// trailing "/v1" (and trailing slashes) on base is tolerated, so both
// "http://host:8002" and "http://host:8002/v1" resolve to
// "http://host:8002/v1/chat/completions" instead of a doubled "/v1/v1".
func chatCompletionsURL(base string) string {
	base = strings.TrimRight(base, "/")
	base = strings.TrimSuffix(base, "/v1")
	return base + "/v1/chat/completions"
}
// resolvePrompt returns the system prompt to use for a message.
//
// app_settings is searched in this order, and the first non-blank value wins:
//  1. llm_system_prompt:<dept>:<vertical>:<section> (only when sectionSlug != "")
//  2. llm_system_prompt:<dept>:<vertical>
//
// <dept> is departmentID, or "global" when it is empty. When neither key holds
// a non-blank value, the built-in prompt for the vertical is returned.
//
// Database errors are returned instead of being treated as "not configured".
// Otherwise a transient outage or a cancelled context would classify the
// message with the default prompt, and that verdict would be stored permanently.
func (c *classifier) resolvePrompt(ctx context.Context, vertical, departmentID, sectionSlug string) (string, error) {
	dept := departmentID
	if dept == "" {
		dept = "global"
	}
	keys := make([]string, 0, 2)
	if sectionSlug != "" {
		keys = append(keys, promptKey(dept, vertical, sectionSlug))
	}
	keys = append(keys, promptKey(dept, vertical, ""))
	for _, key := range keys {
		// The outer SELECT always yields exactly one row: '' when the key is
		// missing or its value is JSON null. So every error here is a real failure.
		var text string
		err := c.db.QueryRow(ctx,
			`SELECT COALESCE((SELECT value #>> '{}' FROM app_settings WHERE key = $1 LIMIT 1), '')`,
			key,
		).Scan(&text)
		if err != nil {
			return "", fmt.Errorf("load prompt %q: %w", key, err)
		}
		if strings.TrimSpace(text) != "" {
			return text, nil
		}
	}
	return defaultPrompt(vertical), nil
}
func (c *classifier) saveVerdict(ctx context.Context, id int64, key string, verdict json.RawMessage) error {
_, err := c.db.Exec(ctx, `
UPDATE messages
SET extracted = jsonb_set(COALESCE(extracted, '{}'::jsonb), ARRAY[$2], $3::jsonb, true)
WHERE id = $1
`, id, key, string(verdict))
return err
}
// promptKey returns the app_settings key of a custom system prompt:
// "llm_system_prompt:<department>:<vertical>", with ":<sectionSlug>" appended
// when a section slug is given.
func promptKey(departmentID, vertical, sectionSlug string) string {
	key := fmt.Sprintf("llm_system_prompt:%s:%s", departmentID, vertical)
	if sectionSlug == "" {
		return key
	}
	return key + ":" + sectionSlug
}
// verdictKey names the messages.extracted key that holds the verdict for a
// vertical: "hr_lead" for HR channels, "lead" for all others.
func verdictKey(vertical string) string {
	switch vertical {
	case verticalHR:
		return "hr_lead"
	default:
		return "lead"
	}
}
// buildUserPrompt wraps the raw message text in a fenced block followed by a
// request to answer in JSON.
func buildUserPrompt(text string) string {
	const (
		header = "Текст сообщения:\n```\n"
		footer = "\n```\nВерни JSON."
	)
	return header + text + footer
}
// extractJSONBlock returns the JSON object embedded in raw LLM content.
//
// If raw as a whole is a JSON object, it is returned unchanged. Otherwise each
// '{' in raw is tried in order, and the first complete object that decodes
// from that position is returned. This tolerates markdown fences, leading or
// trailing prose, and stray braces around the answer. Non-object JSON such as
// `null` or a bare string is never returned. When no object can be decoded,
// the last decoding error is returned, or a "no json object" error when raw
// contains no '{' at all.
func extractJSONBlock(raw string) (json.RawMessage, error) {
	var payload json.RawMessage
	if err := json.Unmarshal([]byte(raw), &payload); err == nil && strings.HasPrefix(strings.TrimSpace(raw), "{") {
		return payload, nil
	}
	lastErr := errors.New("no json object in llm content")
	for offset := 0; offset < len(raw); {
		idx := strings.IndexByte(raw[offset:], '{')
		if idx < 0 {
			break
		}
		start := offset + idx
		// A Decoder stops after the first complete value, so trailing text is ignored.
		var candidate json.RawMessage
		err := json.NewDecoder(strings.NewReader(raw[start:])).Decode(&candidate)
		if err == nil {
			return candidate, nil
		}
		lastErr = err
		offset = start + 1
	}
	return nil, lastErr
}
// normalizeVerdict parses an LLM verdict object and guarantees its core fields.
//
// The vertical's boolean flag ("is_lead" for HR, "is_listing" otherwise) is set
// to false when absent, and a numeric "confidence" is clamped into [0, 1]. All
// other fields pass through unchanged. An error is returned when raw is not a
// JSON object, including the literal `null`.
func normalizeVerdict(vertical string, raw json.RawMessage) (json.RawMessage, error) {
	var m map[string]any
	if err := json.Unmarshal(raw, &m); err != nil {
		return nil, err
	}
	// `null` unmarshals without error but leaves m nil; writing into a nil
	// map would panic and take the whole worker down.
	if m == nil {
		return nil, errors.New("llm verdict is not a json object")
	}
	flag := "is_listing"
	if vertical == verticalHR {
		flag = "is_lead"
	}
	if _, ok := m[flag]; !ok {
		m[flag] = false
	}
	if confidence, ok := asFloat(m["confidence"]); ok {
		if confidence < 0 {
			confidence = 0
		}
		if confidence > 1 {
			confidence = 1
		}
		m["confidence"] = confidence
	}
	return marshalRaw(m)
}
// negativeVerdict builds the "not relevant" verdict for a vertical: its
// boolean flag ("is_lead" for HR, "is_listing" otherwise) set to false, no
// kind, an empty summary and zero confidence.
func negativeVerdict(vertical string) map[string]any {
	flag := "is_listing"
	if vertical == verticalHR {
		flag = "is_lead"
	}
	return map[string]any{
		flag:         false,
		"kind":       nil,
		"summary":    "",
		"confidence": 0,
	}
}
// marshalRaw JSON-encodes v and returns the bytes as a json.RawMessage,
// together with any encoding error from json.Marshal.
func marshalRaw(v any) (raw json.RawMessage, err error) {
	raw, err = json.Marshal(v)
	return raw, err
}
// asFloat converts a decoded JSON number to float64. It accepts float64,
// float32, int, int64 and json.Number, and reports false for anything else,
// including a json.Number that does not parse as a float.
func asFloat(v any) (float64, bool) {
	var f float64
	switch x := v.(type) {
	case float64:
		f = x
	case float32:
		f = float64(x)
	case int:
		f = float64(x)
	case int64:
		f = float64(x)
	case json.Number:
		parsed, err := x.Float64()
		return parsed, err == nil
	default:
		return 0, false
	}
	return f, true
}
// defaultPrompt returns the built-in system prompt for a vertical: the HR
// prompt for HR channels, the real-estate prompt for everything else.
func defaultPrompt(vertical string) string {
	if vertical != verticalHR {
		return defaultREPrompt
	}
	return defaultHRPrompt
}
func loadConfig() config {
return config{
PostgresUser: env("POSTGRES_USER", "parser"),
PostgresPassword: env("POSTGRES_PASSWORD", "parser"),
PostgresDB: env("POSTGRES_DB", "parser"),
PostgresHost: env("POSTGRES_HOST", "db"),
PostgresPort: envInt("POSTGRES_PORT", 5432),
LLMEnabled: envBool("LLM_ENABLED", true),
LLMBaseURL: env("LLM_BASE_URL", "http://10.2.3.5:8002"),
LLMAPIKey: env("LLM_API_KEY", ""),
LLMModel: env("LLM_MODEL", "qwen2.5-14b"),
LLMTimeout: time.Duration(envInt("LLM_TIMEOUT_SECONDS", 120)) * time.Second,
LLMMaxTokens: envInt("LLM_MAX_TOKENS", 600),
LLMMinTextLength: envInt("LLM_MIN_TEXT_LENGTH", 20),
ClassifyInterval: time.Duration(envInt("LLM_CLASSIFY_INTERVAL_SECONDS", 20)) * time.Second,
ClassifyBatchSize: envInt("LLM_CLASSIFY_BATCH_SIZE", 5),
}
}
// databaseURL renders the PostgreSQL settings as a postgres:// URL for pgxpool.
//
// The user, password and database name are escaped with the URL userinfo and
// path rules via url.URL. The previous url.QueryEscape turned a space into
// '+', which a URL parser reads back as a literal plus sign. IPv6 host
// literals are bracketed so the ":port" suffix stays unambiguous. For the
// default settings the result is "postgres://parser:parser@db:5432/parser".
func (c config) databaseURL() string {
	host := c.PostgresHost
	if strings.Contains(host, ":") && !strings.HasPrefix(host, "[") {
		host = "[" + host + "]"
	}
	u := url.URL{
		Scheme: "postgres",
		User:   url.UserPassword(c.PostgresUser, c.PostgresPassword),
		Host:   host + ":" + strconv.Itoa(c.PostgresPort),
		Path:   "/" + c.PostgresDB,
	}
	return u.String()
}
// env returns the whitespace-trimmed value of the environment variable key,
// or fallback when it is unset or blank.
func env(key, fallback string) string {
	v := strings.TrimSpace(os.Getenv(key))
	if v == "" {
		return fallback
	}
	return v
}
// envInt parses the environment variable key as a base-10 int. It returns
// fallback when the variable is unset, blank or not a valid integer.
func envInt(key string, fallback int) int {
	raw := strings.TrimSpace(os.Getenv(key))
	if raw == "" {
		return fallback
	}
	n, err := strconv.Atoi(raw)
	if err != nil {
		return fallback
	}
	return n
}
// envBool reads a boolean environment variable.
//
// Recognized values, case-insensitive and after trimming whitespace:
//   - true:  "1", "t", "true", "y", "yes", "on"
//   - false: "0", "f", "false", "n", "no", "off"
//
// These are a superset of what strconv.ParseBool accepts. Any other value,
// or an unset or blank variable, yields fallback. The former explicit "1"/"0"
// checks were unreachable because ParseBool already handles them.
func envBool(key string, fallback bool) bool {
	raw := strings.TrimSpace(os.Getenv(key))
	switch strings.ToLower(raw) {
	case "1", "t", "true", "y", "yes", "on":
		return true
	case "0", "f", "false", "n", "no", "off":
		return false
	default:
		return fallback
	}
}
// Built-in system prompts, used by defaultPrompt when app_settings holds no
// override for a message. Each asks the model for a single JSON object and
// is sent verbatim; the text is intentionally in Russian.
const (
	// defaultREPrompt classifies real-estate posts (sale, rent, purchase).
	defaultREPrompt = `Ты — аналитик объявлений о недвижимости. Тебе дают текст из Telegram-канала.
Определи, является ли сообщение реальным объявлением о покупке, продаже или аренде недвижимости.
Отвечай строго валидным JSON без markdown:
{
"is_listing": boolean,
"kind": "sale" | "rent" | "purchase" | null,
"property_type": string | null,
"rooms": string | null,
"area_m2": number | null,
"price_text": string | null,
"price_value": number | null,
"currency": "RUB" | "USD" | "EUR" | "AED" | "GBP" | "CNY" | "TRY" | "KZT" | "BYN" | "UAH" | null,
"location": string | null,
"contact_phone": string | null,
"contact_name": string | null,
"summary": string,
"confidence": number
}
summary всегда по-русски, confidence в диапазоне 0..1.`

	// defaultHRPrompt classifies labour-market posts (vacancy, resume, contact).
	defaultHRPrompt = `Ты — аналитик HR-объявлений. Тебе дают текст из Telegram-канала.
Определи, относится ли сообщение к рынку труда: вакансия, резюме или короткий HR-контакт.
Отвечай строго валидным JSON без markdown:
{
"is_lead": boolean,
"kind": "vacancy" | "resume" | "contact" | null,
"title": string | null,
"company": string | null,
"candidate_name": string | null,
"experience_years": number | null,
"skills": string[],
"location": string | null,
"remote": boolean | null,
"employment_type": "full-time" | "part-time" | "contract" | "internship" | null,
"salary_text": string | null,
"salary_value": number | null,
"currency": "RUB" | "USD" | "EUR" | "AED" | "GBP" | "CNY" | "TRY" | "KZT" | "BYN" | "UAH" | null,
"contact_phone": string | null,
"contact_name": string | null,
"summary": string,
"confidence": number
}
summary всегда по-русски, confidence в диапазоне 0..1.`
)