diff --git a/cmd/server/main.go b/cmd/server/main.go index ec438d8..39c9ddf 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -74,6 +74,7 @@ type accessScope struct { CanManage bool CanAuth bool DeptID string + DeptIDs []string } type sectionOut struct { @@ -261,6 +262,7 @@ func (a *app) handleAccessMe(w http.ResponseWriter, r *http.Request) { "can_manage_department": scope.CanManage, "can_auth_telegram": scope.CanAuth, "department_id": nullableString(scope.DeptID), + "department_ids": scope.departmentIDs(), }) } @@ -288,8 +290,7 @@ func (a *app) listSections(ctx context.Context, w http.ResponseWriter, r *http.R args := []any{vertical} deptFilter := "" if !scope.IsAdmin { - args = append(args, scope.DeptID) - deptFilter = fmt.Sprintf(" AND s.department_id = $%d", len(args)) + args, deptFilter = appendDepartmentFilter(args, scope, "s.department_id") } rows, err := a.db.Query(ctx, ` @@ -367,7 +368,7 @@ func (a *app) createSection(ctx context.Context, w http.ResponseWriter, r *http. writeError(w, http.StatusBadRequest, "vertical, slug and title are required") return } - dept := nullableString(scope.DeptID) + dept := nullableString(scope.primaryDepartmentID()) row := a.db.QueryRow(ctx, ` INSERT INTO sections (vertical, department_id, slug, title, emoji, description) VALUES ($1, $2, $3, $4, $5, $6) @@ -459,8 +460,9 @@ func (a *app) updateSection(ctx context.Context, w http.ResponseWriter, r *http. args = append(args, vertical, slug) where := fmt.Sprintf("vertical = $%d AND slug = $%d", len(args)-1, len(args)) if !scope.IsAdmin { - args = append(args, scope.DeptID) - where += fmt.Sprintf(" AND department_id = $%d", len(args)) + var deptFilter string + args, deptFilter = appendDepartmentFilter(args, scope, "department_id") + where += deptFilter } row := a.db.QueryRow(ctx, ` UPDATE sections @@ -510,8 +512,9 @@ func (a *app) findSection(ctx context.Context, vertical, slug string, scope acce args := []any{vertical, slug} where := "s.vertical = $1 AND s.slug = $2" if !scope.IsAdmin { - args = append(args, scope.DeptID) - where += fmt.Sprintf(" AND s.department_id = $%d", len(args)) + var deptFilter string + args, deptFilter = appendDepartmentFilter(args, scope, "s.department_id") + where += deptFilter } row := a.db.QueryRow(ctx, ` SELECT s.id, s.vertical, COALESCE(s.department_id, ''), s.slug, s.title, COALESCE(s.emoji, ''), COALESCE(s.description, ''), s.created_at, @@ -553,8 +556,9 @@ func (a *app) listChannels(ctx context.Context, w http.ResponseWriter, r *http.R where += fmt.Sprintf(" AND s.slug = $%d", len(args)) } if !scope.IsAdmin { - args = append(args, scope.DeptID) - where += fmt.Sprintf(" AND s.department_id = $%d", len(args)) + var deptFilter string + args, deptFilter = appendDepartmentFilter(args, scope, "s.department_id") + where += deptFilter } rows, err := a.db.Query(ctx, ` SELECT c.id, COALESCE(c.tg_id, src.tg_id), c.source_channel_id, @@ -785,8 +789,9 @@ func (a *app) findChannel(ctx context.Context, id int64, scope accessScope, vert where += fmt.Sprintf(" AND s.slug = $%d", len(args)) } if !scope.IsAdmin { - args = append(args, scope.DeptID) - where += fmt.Sprintf(" AND s.department_id = $%d", len(args)) + var deptFilter string + args, deptFilter = appendDepartmentFilter(args, scope, "s.department_id") + where += deptFilter } row := a.db.QueryRow(ctx, ` SELECT c.id, COALESCE(c.tg_id, src.tg_id), c.source_channel_id, @@ -934,8 +939,9 @@ func (a *app) handleMessages(ctx context.Context, w http.ResponseWriter, r *http where += fmt.Sprintf(" AND COALESCE(mc.verdict ->> $%d, m.extracted -> $%d ->> $%d) = 'true'", len(args), len(args)-1, len(args)) } if !scope.IsAdmin { - args = append(args, scope.DeptID) - where += fmt.Sprintf(" AND s.department_id = $%d", len(args)) + var deptFilter string + args, deptFilter = appendDepartmentFilter(args, scope, "s.department_id") + where += deptFilter } fetchLimit := clampInt(limit*5, limit, 1000) args = append(args, fetchLimit, offset) @@ -1002,8 +1008,9 @@ func (a *app) handleMessageItem(ctx context.Context, w http.ResponseWriter, r *h args := []any{id} where := "m.id = $1" if !scope.IsAdmin { - args = append(args, scope.DeptID) - where += fmt.Sprintf(" AND s.department_id = $%d", len(args)) + var deptFilter string + args, deptFilter = appendDepartmentFilter(args, scope, "s.department_id") + where += deptFilter } row := a.db.QueryRow(ctx, ` SELECT m.id, c.id, c.vertical, s.slug, m.tg_message_id, m.grouped_id, 1::int, @@ -1119,11 +1126,11 @@ func (a *app) serveMinioMedia(w http.ResponseWriter, r *http.Request, key string func (a *app) canReadChannelMedia(ctx context.Context, scope accessScope, channelID int64) (bool, error) { var allowed bool err := a.db.QueryRow(ctx, ` - SELECT COALESCE(bool_or(s.department_id = $2 OR $3::boolean), false) + SELECT COALESCE(bool_or($3::boolean OR s.department_id::text = ANY($2::text[])), false) FROM channels c JOIN sections s ON s.id = c.section_id WHERE c.id = $1 OR c.source_channel_id = $1 - `, channelID, scope.DeptID, scope.IsAdmin).Scan(&allowed) + `, channelID, scope.departmentIDs(), scope.IsAdmin).Scan(&allowed) if errors.Is(err, pgx.ErrNoRows) { return false, nil } @@ -1154,8 +1161,9 @@ func (a *app) handleStats(ctx context.Context, w http.ResponseWriter, r *http.Re where += fmt.Sprintf(" AND s.slug = $%d", len(args)) } if !scope.IsAdmin { - args = append(args, scope.DeptID) - where += fmt.Sprintf(" AND s.department_id = $%d", len(args)) + var deptFilter string + args, deptFilter = appendDepartmentFilter(args, scope, "s.department_id") + where += deptFilter } var channelsTotal, channelsActive, messagesTotal, messages24h, leadsTotal, leads24h int64 @@ -1272,8 +1280,9 @@ func (a *app) pendingLLM(ctx context.Context, scope accessScope, vertical, secti where += fmt.Sprintf(" AND s.slug = $%d", len(args)) } if !scope.IsAdmin { - args = append(args, scope.DeptID) - where += fmt.Sprintf(" AND s.department_id = $%d", len(args)) + var deptFilter string + args, deptFilter = appendDepartmentFilter(args, scope, "s.department_id") + where += deptFilter } var pending int64 err := a.db.QueryRow(ctx, ` @@ -1310,25 +1319,24 @@ func (a *app) getPrompt(ctx context.Context, w http.ResponseWriter, r *http.Requ return } section := strings.TrimSpace(r.URL.Query().Get("section")) - if section != "" { - if _, err := a.findSection(ctx, vertical, section, scope); err != nil { - writeDBError(w, err) - return - } - } - prompt, source, err := a.resolvePrompt(ctx, scope.DeptID, vertical, section) + deptID, err := a.promptDepartmentID(ctx, scope, vertical, section) if err != nil { writeDBError(w, err) return } - overridden, err := a.promptExists(ctx, scope.DeptID, vertical, section) + prompt, source, err := a.resolvePrompt(ctx, deptID, vertical, section) + if err != nil { + writeDBError(w, err) + return + } + overridden, err := a.promptExists(ctx, deptID, vertical, section) if err != nil { writeDBError(w, err) return } writeJSON(w, http.StatusOK, map[string]any{ "vertical": vertical, - "department_id": nullableString(scope.DeptID), + "department_id": nullableString(deptID), "section": nullableString(section), "prompt": prompt, "default": defaultPrompt(vertical), @@ -1362,7 +1370,12 @@ func (a *app) savePrompt(ctx context.Context, w http.ResponseWriter, r *http.Req writeError(w, http.StatusBadRequest, "prompt is too long (max 30000 chars)") return } - key := promptKey(scope.DeptID, vertical, section) + deptID, err := a.promptDepartmentID(ctx, scope, vertical, section) + if err != nil { + writeDBError(w, err) + return + } + key := promptKey(deptID, vertical, section) value, _ := json.Marshal(text) if _, err := a.db.Exec(ctx, ` INSERT INTO app_settings (key, value, updated_at) @@ -1372,7 +1385,7 @@ func (a *app) savePrompt(ctx context.Context, w http.ResponseWriter, r *http.Req writeDBError(w, err) return } - writeJSON(w, http.StatusOK, map[string]any{"saved": true, "vertical": vertical, "department_id": nullableString(scope.DeptID), "section": nullableString(section), "length": len(text)}) + writeJSON(w, http.StatusOK, map[string]any{"saved": true, "vertical": vertical, "department_id": nullableString(deptID), "section": nullableString(section), "length": len(text)}) } func (a *app) resetPrompt(ctx context.Context, w http.ResponseWriter, r *http.Request) { @@ -1385,13 +1398,29 @@ func (a *app) resetPrompt(ctx context.Context, w http.ResponseWriter, r *http.Re return } section := strings.TrimSpace(r.URL.Query().Get("section")) - if _, err := a.db.Exec(ctx, `DELETE FROM app_settings WHERE key = $1`, promptKey(scope.DeptID, vertical, section)); err != nil { + deptID, err := a.promptDepartmentID(ctx, scope, vertical, section) + if err != nil { + writeDBError(w, err) + return + } + if _, err := a.db.Exec(ctx, `DELETE FROM app_settings WHERE key = $1`, promptKey(deptID, vertical, section)); err != nil { writeDBError(w, err) return } w.WriteHeader(http.StatusNoContent) } +func (a *app) promptDepartmentID(ctx context.Context, scope accessScope, vertical, section string) (string, error) { + if strings.TrimSpace(section) == "" { + return scope.primaryDepartmentID(), nil + } + item, err := a.findSection(ctx, vertical, section, scope) + if err != nil { + return "", err + } + return valueOrEmpty(item.DepartmentID), nil +} + func (a *app) resolvePrompt(ctx context.Context, deptID, vertical, section string) (string, string, error) { keys := []struct { key string @@ -1474,11 +1503,11 @@ func (a *app) readScope(w http.ResponseWriter, r *http.Request, manage bool) (ac writeError(w, http.StatusNotFound, "not found") return scope, false } - } else if !scope.IsAdmin && scope.DeptID == "" { + } else if !scope.IsAdmin && len(scope.departmentIDs()) == 0 { writeError(w, http.StatusForbidden, "department is required") return scope, false } - if manage && !scope.IsAdmin && scope.DeptID == "" { + if manage && !scope.IsAdmin && len(scope.departmentIDs()) == 0 { writeError(w, http.StatusForbidden, "department is required") return scope, false } @@ -1490,14 +1519,75 @@ func readAccess(r *http.Request) accessScope { deptHead := r.Header.Get("X-User-Is-Department-Head") == "1" canManage := r.Header.Get("X-Monitoring-TG-Can-Manage") == "1" canAuth := r.Header.Get("X-Monitoring-TG-Can-Auth") == "1" + deptID := strings.TrimSpace(r.Header.Get("X-User-Department-Id")) + deptIDs := parseCSVHeader(r.Header.Get("X-User-Department-Ids")) + if deptID != "" { + deptIDs = appendUniqueString(deptIDs, deptID) + } return accessScope{ IsAdmin: admin, CanManage: admin || deptHead || canManage, CanAuth: admin || canAuth, - DeptID: strings.TrimSpace(r.Header.Get("X-User-Department-Id")), + DeptID: deptID, + DeptIDs: deptIDs, } } +func parseCSVHeader(raw string) []string { + raw = strings.TrimSpace(raw) + if raw == "" { + return nil + } + parts := strings.Split(raw, ",") + out := make([]string, 0, len(parts)) + for _, part := range parts { + out = appendUniqueString(out, strings.TrimSpace(part)) + } + return out +} + +func appendUniqueString(items []string, value string) []string { + value = strings.TrimSpace(value) + if value == "" { + return items + } + for _, item := range items { + if item == value { + return items + } + } + return append(items, value) +} + +func (s accessScope) departmentIDs() []string { + out := make([]string, 0, len(s.DeptIDs)+1) + for _, id := range s.DeptIDs { + out = appendUniqueString(out, id) + } + out = appendUniqueString(out, s.DeptID) + return out +} + +func (s accessScope) primaryDepartmentID() string { + if strings.TrimSpace(s.DeptID) != "" { + return strings.TrimSpace(s.DeptID) + } + ids := s.departmentIDs() + if len(ids) == 0 { + return "" + } + return ids[0] +} + +func appendDepartmentFilter(args []any, scope accessScope, column string) ([]any, string) { + ids := scope.departmentIDs() + if len(ids) == 0 { + ids = []string{"__no_department_scope__"} + } + args = append(args, ids) + return args, fmt.Sprintf(" AND %s::text = ANY($%d::text[])", column, len(args)) +} + type rowScanner interface { Scan(dest ...any) error } diff --git a/src/parser_bot/access.py b/src/parser_bot/access.py index d2462e1..fdc6ced 100644 --- a/src/parser_bot/access.py +++ b/src/parser_bot/access.py @@ -13,6 +13,19 @@ def portal_department_id(request: Request) -> str | None: return value or None +def portal_department_ids(request: Request) -> list[str]: + raw = (request.headers.get("x-user-department-ids") or "").strip() + out: list[str] = [] + for part in raw.split(","): + value = part.strip() + if value and value not in out: + out.append(value) + current = portal_department_id(request) + if current and current not in out: + out.append(current) + return out + + def is_department_head_request(request: Request) -> bool: return request.headers.get("x-user-is-department-head") == "1" diff --git a/src/parser_bot/api/routes.py b/src/parser_bot/api/routes.py index d12a199..bb1463c 100644 --- a/src/parser_bot/api/routes.py +++ b/src/parser_bot/api/routes.py @@ -7,7 +7,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from parser_bot.access import ( is_admin_request, - portal_department_id, + portal_department_ids, require_department_manager, require_telegram_auth_manager, ) @@ -39,13 +39,13 @@ class AuthCodeResult(BaseModel): needs_password: bool -def _department_scope(request: Request) -> str | None: +def _department_scopes(request: Request) -> list[str] | None: if is_admin_request(request): return None - dept_id = portal_department_id(request) - if not dept_id: + dept_ids = portal_department_ids(request) + if not dept_ids: raise HTTPException(status_code=403, detail="department is required") - return dept_id + return dept_ids async def _require_channel_scope( @@ -55,7 +55,7 @@ async def _require_channel_scope( vertical: str | None, section: str | None, ) -> None: - department_id = _department_scope(request) + department_ids = _department_scopes(request) stmt = ( select(Channel.id) .join(Section, Section.id == Channel.section_id) @@ -65,8 +65,8 @@ async def _require_channel_scope( stmt = stmt.where(Channel.vertical == vertical) if section: stmt = stmt.where(Section.slug == section) - if department_id is not None: - stmt = stmt.where(Section.department_id == department_id) + if department_ids is not None: + stmt = stmt.where(Section.department_id.in_(department_ids)) exists = (await session.execute(stmt)).scalar_one_or_none() if exists is None: raise HTTPException(status_code=404) @@ -182,7 +182,7 @@ async def trigger_poll_all( section: str | None = Query(None), session: AsyncSession = Depends(get_session), ) -> dict[str, Any]: - department_id = _department_scope(request) + department_ids = _department_scopes(request) stmt = ( select(Channel.id) .join(Section, Section.id == Channel.section_id) @@ -190,8 +190,8 @@ async def trigger_poll_all( ) if section: stmt = stmt.where(Section.slug == section) - if department_id is not None: - stmt = stmt.where(Section.department_id == department_id) + if department_ids is not None: + stmt = stmt.where(Section.department_id.in_(department_ids)) result = await session.execute(stmt) ids = [row[0] for row in result.all()] background.add_task(_poll_all_in_background, ids)