package handler import ( "bytes" "context" "encoding/json" "log/slog" "net/http" "strconv" "strings" "time" chimw "github.com/go-chi/chi/v5/middleware" commonaudit "gitea.estateliga.work/admin/portal-common/audit" commonmw "gitea.estateliga.work/admin/portal-common/middleware" ) type AuditMiddleware struct { client *commonaudit.Client } func NewAuditMiddleware(client *commonaudit.Client) *AuditMiddleware { return &AuditMiddleware{client: client} } func (m *AuditMiddleware) Middleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if m == nil || m.client == nil || !m.client.Enabled() || !isAuditWriteMethod(r.Method) { next.ServeHTTP(w, r) return } rec := &auditResponseRecorder{ResponseWriter: w, status: http.StatusOK} start := time.Now() next.ServeHTTP(rec, r) if rec.status < 200 || rec.status >= 300 { return } action := matchLearningAuditAction(r.Method, r.URL.Path) if action.Action == "" { return } if action.EntityID == "" { action.EntityID = extractAuditEntityID(rec.body.Bytes()) } details := map[string]any{ "method": r.Method, "path": r.URL.Path, "status": rec.status, "duration_ms": time.Since(start).Milliseconds(), } if rid := chimw.GetReqID(r.Context()); rid != "" { details["request_id"] = rid } if q := r.URL.RawQuery; q != "" { details["query"] = q } event := commonaudit.Event{ Action: action.Action, EntityType: action.EntityType, EntityID: action.EntityID, UserID: commonmw.GetUserID(r.Context()), UserName: commonmw.GetUserName(r.Context()), IPAddress: commonmw.GetClientIP(r.Context()), Details: details, } go func() { if err := m.client.Send(context.Background(), event); err != nil { slog.Warn("learning audit send failed", "error", err, "action", event.Action, "entity_id", event.EntityID) } }() }) } type learningAuditAction struct { Action string EntityType string EntityID string } type auditResponseRecorder struct { http.ResponseWriter status int body bytes.Buffer } func (r *auditResponseRecorder) WriteHeader(status int) { r.status = status r.ResponseWriter.WriteHeader(status) } func (r *auditResponseRecorder) Write(body []byte) (int, error) { if r.body.Len() < 64*1024 { _, _ = r.body.Write(body) } return r.ResponseWriter.Write(body) } func isAuditWriteMethod(method string) bool { switch method { case http.MethodPost, http.MethodPatch, http.MethodPut, http.MethodDelete: return true default: return false } } func matchLearningAuditAction(method, path string) learningAuditAction { parts := strings.Split(strings.Trim(path, "/"), "/") if len(parts) > 0 && parts[0] == "api" { parts = parts[1:] } if len(parts) == 0 { return learningAuditAction{} } switch parts[0] { case "tests": return matchTestAuditAction(method, parts) case "attempts": return matchAttemptAuditAction(method, parts) case "courses": return matchCourseAuditAction(method, parts) case "lessons": return matchLessonAuditAction(method, parts) case "access": return matchAccessAuditAction(method, parts) case "public-tokens": return matchPublicTokenAuditAction(method, parts) default: return learningAuditAction{} } } func matchTestAuditAction(method string, parts []string) learningAuditAction { if method == http.MethodPost && len(parts) == 1 { return learningAuditAction{"learning.test_create", "learning_test", ""} } if len(parts) < 2 { return learningAuditAction{} } testID := parts[1] if len(parts) == 2 { switch method { case http.MethodPatch: return learningAuditAction{"learning.test_update", "learning_test", testID} case http.MethodDelete: return learningAuditAction{"learning.test_delete", "learning_test", testID} } } if len(parts) >= 3 { switch parts[2] { case "questions": return matchQuestionAuditAction(method, parts, testID) case "attempts": if method == http.MethodPost && len(parts) == 3 { return learningAuditAction{"learning.attempt_start", "learning_test", testID} } } } return learningAuditAction{} } func matchQuestionAuditAction(method string, parts []string, testID string) learningAuditAction { if method == http.MethodPost && len(parts) == 3 { return learningAuditAction{"learning.question_create", "learning_test", testID} } if method == http.MethodPost && len(parts) == 4 && parts[3] == "reorder" { return learningAuditAction{"learning.question_reorder", "learning_test", testID} } if len(parts) == 4 { switch method { case http.MethodPut: return learningAuditAction{"learning.question_update", "learning_question", parts[3]} case http.MethodDelete: return learningAuditAction{"learning.question_delete", "learning_question", parts[3]} } } return learningAuditAction{} } func matchAttemptAuditAction(method string, parts []string) learningAuditAction { if method == http.MethodPost && len(parts) == 3 && parts[2] == "submit" { return learningAuditAction{"learning.attempt_submit", "learning_attempt", parts[1]} } return learningAuditAction{} } func matchCourseAuditAction(method string, parts []string) learningAuditAction { if method == http.MethodPost && len(parts) == 1 { return learningAuditAction{"learning.course_create", "learning_course", ""} } if len(parts) < 2 { return learningAuditAction{} } courseID := parts[1] if len(parts) == 2 { switch method { case http.MethodPatch: return learningAuditAction{"learning.course_update", "learning_course", courseID} case http.MethodDelete: return learningAuditAction{"learning.course_delete", "learning_course", courseID} } } if method == http.MethodPost && len(parts) == 3 && parts[2] == "lessons" { return learningAuditAction{"learning.lesson_create", "learning_course", courseID} } if method == http.MethodPost && len(parts) == 4 && parts[2] == "lessons" && parts[3] == "reorder" { return learningAuditAction{"learning.lesson_reorder", "learning_course", courseID} } return learningAuditAction{} } func matchLessonAuditAction(method string, parts []string) learningAuditAction { if len(parts) < 2 { return learningAuditAction{} } lessonID := parts[1] if len(parts) == 2 { switch method { case http.MethodPatch: return learningAuditAction{"learning.lesson_update", "learning_lesson", lessonID} case http.MethodDelete: return learningAuditAction{"learning.lesson_delete", "learning_lesson", lessonID} } } if len(parts) == 3 && parts[2] == "video" { switch method { case http.MethodPost: return learningAuditAction{"learning.lesson_video_upload", "learning_lesson", lessonID} case http.MethodDelete: return learningAuditAction{"learning.lesson_video_delete", "learning_lesson", lessonID} } } return learningAuditAction{} } func matchAccessAuditAction(method string, parts []string) learningAuditAction { entityID := strings.Join(parts, "/") if method == http.MethodPost && len(parts) == 3 { return learningAuditAction{"learning.access_grant_create", "learning_access", entityID} } if method == http.MethodDelete && len(parts) == 5 && parts[3] == "grants" { return learningAuditAction{"learning.access_grant_delete", "learning_access_grant", parts[4]} } return learningAuditAction{} } func matchPublicTokenAuditAction(method string, parts []string) learningAuditAction { if method == http.MethodPost && len(parts) == 1 { return learningAuditAction{"learning.public_token_create", "learning_public_token", ""} } if method == http.MethodDelete && len(parts) == 2 { return learningAuditAction{"learning.public_token_revoke", "learning_public_token", parts[1]} } return learningAuditAction{} } func extractAuditEntityID(body []byte) string { if len(body) == 0 { return "" } var payload struct { ID any `json:"id"` } if err := json.Unmarshal(body, &payload); err != nil { return "" } switch id := payload.ID.(type) { case string: return id case float64: return strconv.FormatInt(int64(id), 10) default: return "" } }