package middlewares import ( "context" "crypto/sha256" "encoding/hex" "encoding/json" "errors" "fmt" "log" "net" "net/http" "os" "time" "git.insit.tech/psa/rtsp_reader-writer/reader/internal/config" jwtware "github.com/gofiber/contrib/jwt" "github.com/gofiber/fiber/v2" "github.com/golang-jwt/jwt/v5" "github.com/sirupsen/logrus" ) // CheckJWT is custom middleware that checks if token is valid. func CheckJWT(c *fiber.Ctx) error { // Пропускаем JWT проверку, если есть session_id if c.Query("session_id") != "" { return c.Next() } // Выполняем JWT проверку только при наличии токена return jwtware.New(jwtware.Config{ SigningKey: jwtware.SigningKey{ Key: config.JwtSecretKey, }, ContextKey: config.ContextKeyUser, TokenLookup: "query:token", })(c) } // CheckJWT2 is custom middleware that checks if token is valid. func CheckJWT2(next http.Handler) http.Handler { log.Println("before return CheckJWT2") return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { log.Println("hello from CheckJWT2") tokenStr := r.URL.Query().Get("token") if tokenStr == "" { next.ServeHTTP(w, r) return } token, err := jwt.Parse(tokenStr, func(t *jwt.Token) (interface{}, error) { return config.JwtSecretKey, nil }) if err != nil || !token.Valid { w.WriteHeader(http.StatusUnauthorized) return } ctx := context.WithValue(r.Context(), config.ContextKeyUser, token) next.ServeHTTP(w, r.WithContext(ctx)) log.Println("CheckJWT2 finished") }) } // CheckSessionID middleware. func CheckSessionID(c *fiber.Ctx) error { // Если есть валидный JWT в контексте if c.Locals(config.ContextKeyUser) != nil { return c.Next() } // Проверка session_id sessionID := c.Query("session_id") if sessionID == "" { return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "Auth required"}) } // Получаем данные из памяти data, err := getSessionData(sessionID) if err != nil { return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "Invalid session"}) } if !validateSessionParams(c, data) { deleteSession(sessionID) return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "Session parameters mismatch"}) } if time.Since(data.LastCheck) > 3*time.Minute { if !checkTokenInDB(data.Token) { deleteSession(sessionID) return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "Token expired"}) } data.LastCheck = time.Now() saveSession(sessionID, data) } c.Locals(config.ContextKeySession, data) return c.Next() } // CheckSessionID2 middleware. func CheckSessionID2(next http.Handler) http.Handler { log.Println("before return CheckSessionID2") return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { log.Println("hello from CheckSessionID2") if r.Context().Value(config.ContextKeyUser) != nil { next.ServeHTTP(w, r) return } sessionID := r.URL.Query().Get("session_id") if sessionID == "" { w.WriteHeader(http.StatusUnauthorized) return } data, err := getSessionData(sessionID) if err != nil { w.WriteHeader(http.StatusUnauthorized) return } if !validateSessionParam2(r, data) { deleteSession(sessionID) w.WriteHeader(http.StatusUnauthorized) return } if time.Since(data.LastCheck) > 3*time.Minute { if !checkTokenInDB(data.Token) { deleteSession(sessionID) w.WriteHeader(http.StatusUnauthorized) return } data.LastCheck = time.Now() saveSession(sessionID, data) } ctx := context.WithValue(r.Context(), config.ContextKeySession, data) next.ServeHTTP(w, r.WithContext(ctx)) log.Println("CheckSessionID2 finished") }) } // Работа с памятью func getSessionData(sessionID string) (*config.SessionData, error) { // Для работы с файлами (раскомментировать при необходимости): // return loadFromFile(sessionID) val, ok := config.SessionStore.Load(sessionID) if !ok { return nil, fmt.Errorf("session not found") } data, ok := val.(*config.SessionData) if !ok { return nil, fmt.Errorf("invalid session data format") } return data, nil } // Остальные функции без изменений func validateSessionParams(c *fiber.Ctx, data *config.SessionData) bool { log.Printf("HTTP/1.1 && HTTP/2 Protocol: %s", c.Protocol()) log.Printf("HTTP/1.1 && HTTP/2 File: %s", c.Params("file")) log.Printf("HTTP/1.1 && HTTP/2 IP: %s", c.IP()) log.Printf("DB Protocol: %s", data.Proto) log.Printf("DB File: %s", data.FileName) log.Printf("DB IP: %s", data.IP) return c.Protocol() == data.Proto && c.Params("file") == data.FileName && c.IP() == data.IP } // Остальные функции без изменений func validateSessionParam2(r *http.Request, data *config.SessionData) bool { scheme := "http" if r.TLS != nil { scheme = "https" } file := r.PathValue("file") clientIP, _, err := net.SplitHostPort(r.RemoteAddr) if err != nil { // Если SplitHostPort не сработал, оставляем как есть clientIP = r.RemoteAddr } log.Printf("HTTP/3 Protocol: %s", scheme) log.Printf("HTTP/3 File: %s", file) log.Printf("HTTP/3 IP: %s", clientIP) log.Printf("DB Protocol: %s", data.Proto) log.Printf("DB File: %s", data.FileName) log.Printf("DB IP: %s", data.IP) return scheme == data.Proto && file == data.FileName && clientIP == data.IP } func checkTokenInDB(token string) bool { // Ваша реализация проверки токена return true // временная заглушка } // Генерация session_id func generateSessionID(c *fiber.Ctx, token string) string { data := fmt.Sprintf("%s|%s|%s|%s", c.Protocol(), c.Params("file"), c.IP(), token, ) hash := sha256.Sum256([]byte(data)) return hex.EncodeToString(hash[:]) } // Генерация session_id func generateSessionID2(r *http.Request, token string) string { // Determine protocol. scheme := "http" if r.TLS != nil { scheme = "https" } // Determine filename. file := r.PathValue("file") // Determine IP. clientIP, _, err := net.SplitHostPort(r.RemoteAddr) if err != nil { // Если SplitHostPort не сработал, оставляем как есть clientIP = r.RemoteAddr } data := fmt.Sprintf("%s|%s|%s|%s", scheme, file, clientIP, token, ) hash := sha256.Sum256([]byte(data)) return hex.EncodeToString(hash[:]) } func saveSession(sessionID string, data *config.SessionData) { // Для работы с файлами (раскомментировать при необходимости): // saveToFile(sessionID, data) config.SessionStore.Store(sessionID, data) } func deleteSession(sessionID string) { // Для работы с файлами (раскомментировать при необходимости): // deleteFromFile(sessionID) config.SessionStore.Delete(sessionID) } func saveToFile(sessionID string, data *config.SessionData) error { file, _ := os.ReadFile(config.SessionFile) var sessions []config.FileSession json.Unmarshal(file, &sessions) // Удаляем старую сессию если существует for i, s := range sessions { if s.ID == sessionID { sessions = append(sessions[:i], sessions[i+1:]...) break } } sessions = append(sessions, config.FileSession{ID: sessionID, Data: data}) bytes, _ := json.MarshalIndent(sessions, "", " ") return os.WriteFile(config.SessionFile, bytes, 0644) } func loadFromFile(sessionID string) (*config.SessionData, error) { file, err := os.ReadFile(config.SessionFile) if err != nil { return nil, err } var sessions []config.FileSession json.Unmarshal(file, &sessions) for _, s := range sessions { if s.ID == sessionID { return s.Data, nil } } return nil, fmt.Errorf("session not found") } func deleteFromFile(sessionID string) error { file, _ := os.ReadFile(config.SessionFile) var sessions []config.FileSession json.Unmarshal(file, &sessions) for i, s := range sessions { if s.ID == sessionID { sessions = append(sessions[:i], sessions[i+1:]...) bytes, _ := json.MarshalIndent(sessions, "", " ") return os.WriteFile(config.SessionFile, bytes, 0644) } } return nil } // CreateSessionID is custom middleware that creates session ID by using token. func CreateSessionID(c *fiber.Ctx) error { if token := c.Query("token"); token != "" { // Get payload from token. jwtPayload, ok := jwtPayloadFromRequest(c) if !ok { return c.SendStatus(fiber.StatusUnauthorized) } // Check if the owner of the token is registered. _, ok = config.Storage.Users[jwtPayload["sub"].(string)] if !ok { return errors.New("user not found") } // Create session data. sessionData := &config.SessionData{ Proto: c.Protocol(), FileName: c.Params("file"), IP: c.IP(), Token: token, LastCheck: time.Now(), } // Генерируем session_id sessionID := generateSessionID(c, token) // Сохраняем сессию saveSession(sessionID, sessionData) return c.JSON(fiber.Map{ "session_id": sessionID, }) } return c.Next() } // CreateSessionID2 is custom middleware that creates session ID by using token. func CreateSessionID2(next http.Handler) http.Handler { log.Println("before return CreateSessionID2") return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { log.Println("hello from CreateSessionID2") if token := r.URL.Query().Get("token"); token != "" { // Get payload from token. jwtPayload, ok := jwtPayloadFromRequest2(r) if !ok { w.WriteHeader(http.StatusUnauthorized) return } // Check if the owner of the token is registered. _, ok = config.Storage.Users[jwtPayload["sub"].(string)] if !ok { w.WriteHeader(http.StatusUnauthorized) w.Write([]byte("user not found")) return } // Determine protocol. scheme := "http" if r.TLS != nil { scheme = "https" } // Determine filename. file := r.PathValue("file") // Determine IP. clientIP, _, err := net.SplitHostPort(r.RemoteAddr) if err != nil { // Если SplitHostPort не сработал, оставляем как есть clientIP = r.RemoteAddr } // Create session data. sessionData := &config.SessionData{ Proto: scheme, FileName: file, IP: clientIP, Token: token, LastCheck: time.Now(), } // Генерируем session_id sessionID := generateSessionID2(r, token) // Сохраняем сессию saveSession(sessionID, sessionData) w.Header().Add("Content-Type", "application/json") if err := json.NewEncoder(w).Encode(sessionID); err != nil { w.WriteHeader(http.StatusInternalServerError) } return } next.ServeHTTP(w, r) log.Println("hello from CreateSessionID2: StatusOK") }) } func jwtPayloadFromRequest(c *fiber.Ctx) (jwt.MapClaims, bool) { jwtToken, ok := c.Context().Value(config.ContextKeyUser).(*jwt.Token) if !ok { logrus.WithFields(logrus.Fields{ "jwt_token_context_value": c.Context().Value(config.ContextKeyUser), }).Error("wrong type of JWT token in context") return nil, false } payload, ok := jwtToken.Claims.(jwt.MapClaims) if !ok { logrus.WithFields(logrus.Fields{ "jwt_token_claims": jwtToken.Claims, }).Error("wrong type of JWT token claims") return nil, false } return payload, true } func jwtPayloadFromRequest2(r *http.Request) (jwt.MapClaims, bool) { jwtToken, ok := r.Context().Value(config.ContextKeyUser).(*jwt.Token) if !ok { logrus.WithFields(logrus.Fields{ "jwt_token_context_value": r.Context().Value(config.ContextKeyUser), }).Error("wrong type of JWT token in context") return nil, false } payload, ok := jwtToken.Claims.(jwt.MapClaims) if !ok { logrus.WithFields(logrus.Fields{ "jwt_token_claims": jwtToken.Claims, }).Error("wrong type of JWT token claims") return nil, false } return payload, true } // AuthMiddleware collects auth middlewares. func AuthMiddleware(h http.Handler) http.Handler { return CheckJWT2(CheckSessionID2(CreateSessionID2(h))) }