465 lines
12 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package middlewares
import (
	"context"
	"crypto/sha256"
	"encoding/hex"
	"encoding/json"
	"errors"
	"fmt"
	"log"
	"net"
	"net/http"
	"os"
	"strings"
	"time"

	"git.insit.tech/psa/rtsp_reader-writer/reader/internal/config"
	jwtware "github.com/gofiber/contrib/jwt"
	"github.com/gofiber/fiber/v2"
	"github.com/golang-jwt/jwt/v5"
	"github.com/sirupsen/logrus"
)
// CheckJWT is a fiber middleware that authenticates a request by its JWT.
//
// A request carrying a non-empty "session_id" query parameter skips JWT
// verification and continues down the chain unchanged (its session is
// presumably validated later by CheckSessionID). Every other request is handed
// to a jwtware verifier that reads the token from the "token" query parameter,
// checks it against config.JwtSecretKey and is configured to publish the parsed
// token under config.ContextKeyUser.
//
// NOTE(review): the verifier is rebuilt on every call, and it runs even when
// the "token" parameter is absent — the earlier comment claiming otherwise was
// inaccurate.
func CheckJWT(c *fiber.Ctx) error {
	if sessionID := c.Query("session_id"); sessionID != "" {
		return c.Next()
	}

	verifier := jwtware.New(jwtware.Config{
		SigningKey:  jwtware.SigningKey{Key: config.JwtSecretKey},
		ContextKey:  config.ContextKeyUser,
		TokenLookup: "query:token",
	})
	return verifier(c)
}
// CheckJWT2 is the net/http counterpart of CheckJWT.
//
// When the "token" query parameter is empty the request is passed to next
// unchanged. Otherwise the token is parsed and verified against
// config.JwtSecretKey: an invalid token is answered with 401 Unauthorized, and
// a valid one is stored (as *jwt.Token) in the request context under
// config.ContextKeyUser before next is called.
//
// Only HMAC signing methods are accepted, so the token header can never choose
// how the key is interpreted. This assumes config.JwtSecretKey is a shared
// HMAC secret — TODO confirm.
func CheckJWT2(next http.Handler) http.Handler {
	log.Println("before return CheckJWT2")
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		log.Println("hello from CheckJWT2")
		tokenStr := r.URL.Query().Get("token")
		if tokenStr == "" {
			next.ServeHTTP(w, r)
			return
		}
		token, err := jwt.Parse(tokenStr, func(t *jwt.Token) (interface{}, error) {
			// Pin the algorithm family (golang-jwt recommendation) instead of
			// trusting the "alg" header supplied by the client.
			if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
				return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
			}
			return config.JwtSecretKey, nil
		})
		if err != nil || !token.Valid {
			w.WriteHeader(http.StatusUnauthorized)
			return
		}
		ctx := context.WithValue(r.Context(), config.ContextKeyUser, token)
		next.ServeHTTP(w, r.WithContext(ctx))
		log.Println("CheckJWT2 finished")
	})
}
// CheckSessionID is a fiber middleware that authorizes a request by session ID.
//
// A request whose locals already hold a value under config.ContextKeyUser
// (presumably a JWT verified by CheckJWT) is passed through. Otherwise the
// "session_id" query parameter must name a stored session whose protocol, file
// and client IP match this request; a mismatch deletes the session. If the
// session's token was last checked more than three minutes ago, checkTokenInDB
// is consulted again and the session is deleted when it fails. On success the
// session data is placed in the locals under config.ContextKeySession.
//
// Every rejection answers 401 with a JSON {"error": ...} body.
func CheckSessionID(c *fiber.Ctx) error {
	if c.Locals(config.ContextKeyUser) != nil {
		return c.Next()
	}

	sessionID := c.Query("session_id")
	if sessionID == "" {
		return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "Auth required"})
	}

	data, err := getSessionData(sessionID)
	if err != nil {
		return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "Invalid session"})
	}

	if !validateSessionParams(c, data) {
		deleteSession(sessionID)
		return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "Session parameters mismatch"})
	}

	if time.Since(data.LastCheck) > 3*time.Minute {
		if !checkTokenInDB(data.Token) {
			deleteSession(sessionID)
			return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "Token expired"})
		}
		// The stored *SessionData is shared by every concurrent request using
		// this session_id, so mutating it in place is a data race. Publish an
		// updated copy instead.
		refreshed := *data
		refreshed.LastCheck = time.Now()
		// Fiber query values are only valid inside the handler (unless the app
		// enables Immutable); clone before using one as a long-lived store key.
		saveSession(strings.Clone(sessionID), &refreshed)
		data = &refreshed
	}

	c.Locals(config.ContextKeySession, data)
	return c.Next()
}
// CheckSessionID2 is the net/http counterpart of CheckSessionID.
//
// A request whose context already holds a value under config.ContextKeyUser is
// passed straight to next. Otherwise the "session_id" query parameter must name
// a stored session whose scheme, file and client IP match this request; a
// mismatch deletes the session. A session whose token was last checked more
// than three minutes ago is revalidated with checkTokenInDB and deleted if that
// fails. On success the session data is added to the request context under
// config.ContextKeySession.
//
// Every rejection is a bare 401 Unauthorized.
func CheckSessionID2(next http.Handler) http.Handler {
	log.Println("before return CheckSessionID2")
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		log.Println("hello from CheckSessionID2")
		if r.Context().Value(config.ContextKeyUser) != nil {
			next.ServeHTTP(w, r)
			return
		}

		sessionID := r.URL.Query().Get("session_id")
		if sessionID == "" {
			w.WriteHeader(http.StatusUnauthorized)
			return
		}

		data, err := getSessionData(sessionID)
		if err != nil {
			w.WriteHeader(http.StatusUnauthorized)
			return
		}

		if !validateSessionParam2(r, data) {
			deleteSession(sessionID)
			w.WriteHeader(http.StatusUnauthorized)
			return
		}

		if time.Since(data.LastCheck) > 3*time.Minute {
			if !checkTokenInDB(data.Token) {
				deleteSession(sessionID)
				w.WriteHeader(http.StatusUnauthorized)
				return
			}
			// The stored *SessionData is shared by every concurrent request
			// using this session_id; writing LastCheck in place is a data race.
			// Publish an updated copy instead.
			refreshed := *data
			refreshed.LastCheck = time.Now()
			saveSession(sessionID, &refreshed)
			data = &refreshed
		}

		ctx := context.WithValue(r.Context(), config.ContextKeySession, data)
		next.ServeHTTP(w, r.WithContext(ctx))
		log.Println("CheckSessionID2 finished")
	})
}
// getSessionData looks sessionID up in the in-memory session store
// (config.SessionStore).
//
// It returns an error when the ID is unknown, or when the stored value is not a
// *config.SessionData.
func getSessionData(sessionID string) (*config.SessionData, error) {
	// File-backed alternative (uncomment if needed):
	// return loadFromFile(sessionID)
	raw, found := config.SessionStore.Load(sessionID)
	if !found {
		return nil, errors.New("session not found")
	}
	if data, isSession := raw.(*config.SessionData); isSession {
		return data, nil
	}
	return nil, errors.New("invalid session data format")
}
// validateSessionParams reports whether a fiber request still matches the
// session it presents. The request protocol, the "file" route parameter and
// the client IP must all equal the values stored in data. Both sides are
// logged first to help diagnose mismatches.
func validateSessionParams(c *fiber.Ctx, data *config.SessionData) bool {
	proto := c.Protocol()
	file := c.Params("file")
	ip := c.IP()

	log.Printf("HTTP/1.1 && HTTP/2 Protocol: %s", proto)
	log.Printf("HTTP/1.1 && HTTP/2 File: %s", file)
	log.Printf("HTTP/1.1 && HTTP/2 IP: %s", ip)
	log.Printf("DB Protocol: %s", data.Proto)
	log.Printf("DB File: %s", data.FileName)
	log.Printf("DB IP: %s", data.IP)

	switch {
	case proto != data.Proto, file != data.FileName:
		return false
	default:
		return ip == data.IP
	}
}
// validateSessionParam2 is the net/http counterpart of validateSessionParams.
//
// It derives the request's scheme ("https" over TLS, otherwise "http"), the
// "file" path value and the client IP. The client IP is the host part of
// r.RemoteAddr, or the raw RemoteAddr when it cannot be split. These values are
// logged next to the stored ones, and the function reports whether all three
// match data.
func validateSessionParam2(r *http.Request, data *config.SessionData) bool {
	var scheme string
	if r.TLS == nil {
		scheme = "http"
	} else {
		scheme = "https"
	}

	file := r.PathValue("file")

	clientIP := r.RemoteAddr
	if host, _, splitErr := net.SplitHostPort(r.RemoteAddr); splitErr == nil {
		clientIP = host
	}

	log.Printf("HTTP/3 Protocol: %s", scheme)
	log.Printf("HTTP/3 File: %s", file)
	log.Printf("HTTP/3 IP: %s", clientIP)
	log.Printf("DB Protocol: %s", data.Proto)
	log.Printf("DB File: %s", data.FileName)
	log.Printf("DB IP: %s", data.IP)

	if scheme != data.Proto || file != data.FileName {
		return false
	}
	return clientIP == data.IP
}
// checkTokenInDB is meant to report whether token is still accepted by the
// backing store.
//
// TODO: not implemented — this temporary stub accepts every token, so the
// periodic re-check in the session middlewares never revokes a session.
func checkTokenInDB(_ string) bool {
	const stubAccepts = true // temporary stub
	return stubAccepts
}
// generateSessionID derives the session ID for a fiber request. The ID is the
// hex-encoded SHA-256 of "protocol|file|ip|token", built from the request
// protocol, the "file" route parameter, the client IP and the JWT string. The
// same inputs always yield the same ID.
func generateSessionID(c *fiber.Ctx, token string) string {
	digest := sha256.New()
	// Writes to a hash.Hash never return an error, so Fprintf's result is unused.
	fmt.Fprintf(digest, "%s|%s|%s|%s", c.Protocol(), c.Params("file"), c.IP(), token)
	return hex.EncodeToString(digest.Sum(nil))
}
// Генерация session_id
func generateSessionID2(r *http.Request, token string) string {
// Determine protocol.
scheme := "http"
if r.TLS != nil {
scheme = "https"
}
// Determine filename.
file := r.PathValue("file")
// Determine IP.
clientIP, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
// Если SplitHostPort не сработал, оставляем как есть
clientIP = r.RemoteAddr
}
data := fmt.Sprintf("%s|%s|%s|%s",
scheme,
file,
clientIP,
token,
)
hash := sha256.Sum256([]byte(data))
return hex.EncodeToString(hash[:])
}
// saveSession records data under sessionID in config.SessionStore.
//
// A file-backed variant (saveToFile) exists but is currently disabled; swap the
// commented call in to use it.
func saveSession(sessionID string, data *config.SessionData) {
	// File-backed storage (enable if needed):
	// saveToFile(sessionID, data)

	config.SessionStore.Store(sessionID, data)
}
// deleteSession removes sessionID from config.SessionStore.
//
// A file-backed variant (deleteFromFile) exists but is currently disabled; swap
// the commented call in to use it.
func deleteSession(sessionID string) {
	// File-backed storage (enable if needed):
	// deleteFromFile(sessionID)

	config.SessionStore.Delete(sessionID)
}
// saveToFile persists data under sessionID in the JSON session file
// (config.SessionFile), replacing the first earlier entry with the same ID.
//
// A missing or empty file is treated as an empty session list. Any other read
// error, or content that does not parse as JSON, is returned. The previous
// version ignored these errors and then rewrote the file with only the new
// entry, silently discarding every other stored session.
//
// NOTE(review): the read-modify-write is not synchronized, so concurrent
// callers can lose updates. The file also stores raw tokens with 0644
// permissions — consider 0600.
func saveToFile(sessionID string, data *config.SessionData) error {
	var sessions []config.FileSession

	raw, err := os.ReadFile(config.SessionFile)
	switch {
	case errors.Is(err, os.ErrNotExist):
		// No file yet: start from an empty list.
	case err != nil:
		return fmt.Errorf("reading session file: %w", err)
	case len(raw) > 0:
		if err := json.Unmarshal(raw, &sessions); err != nil {
			return fmt.Errorf("parsing session file: %w", err)
		}
	}

	// Drop the existing entry for this ID, if any.
	for i, s := range sessions {
		if s.ID == sessionID {
			sessions = append(sessions[:i], sessions[i+1:]...)
			break
		}
	}
	sessions = append(sessions, config.FileSession{ID: sessionID, Data: data})

	out, err := json.MarshalIndent(sessions, "", " ")
	if err != nil {
		return fmt.Errorf("encoding session file: %w", err)
	}
	return os.WriteFile(config.SessionFile, out, 0644)
}
// loadFromFile returns the session stored under sessionID in the JSON session
// file (config.SessionFile).
//
// Errors returned:
//   - the os.ReadFile error, returned unwrapped;
//   - a parse error when the file's content is not a valid session list (this
//     was previously swallowed and surfaced as "session not found");
//   - "session not found" when no entry has the ID. An empty file is treated as
//     an empty list.
func loadFromFile(sessionID string) (*config.SessionData, error) {
	raw, err := os.ReadFile(config.SessionFile)
	if err != nil {
		return nil, err
	}

	var sessions []config.FileSession
	if len(raw) > 0 {
		if err := json.Unmarshal(raw, &sessions); err != nil {
			return nil, fmt.Errorf("parsing session file: %w", err)
		}
	}

	for _, s := range sessions {
		if s.ID == sessionID {
			return s.Data, nil
		}
	}
	return nil, fmt.Errorf("session not found")
}
// deleteFromFile removes the first entry with sessionID from the JSON session
// file (config.SessionFile) and rewrites the file.
//
// A missing file, an empty file, or an ID that is not present is a no-op
// (nil). Other read errors, unparsable content and encoding failures are now
// returned; previously they were ignored and reported as success.
//
// NOTE(review): the read-modify-write is not synchronized with saveToFile.
func deleteFromFile(sessionID string) error {
	raw, err := os.ReadFile(config.SessionFile)
	if errors.Is(err, os.ErrNotExist) {
		return nil // nothing stored, nothing to delete
	}
	if err != nil {
		return fmt.Errorf("reading session file: %w", err)
	}

	var sessions []config.FileSession
	if len(raw) > 0 {
		if err := json.Unmarshal(raw, &sessions); err != nil {
			return fmt.Errorf("parsing session file: %w", err)
		}
	}

	for i, s := range sessions {
		if s.ID != sessionID {
			continue
		}
		sessions = append(sessions[:i], sessions[i+1:]...)
		out, err := json.MarshalIndent(sessions, "", " ")
		if err != nil {
			return fmt.Errorf("encoding session file: %w", err)
		}
		return os.WriteFile(config.SessionFile, out, 0644)
	}
	return nil
}
// CreateSessionID is a fiber middleware that exchanges a JWT for a session ID.
//
// When the "token" query parameter is present, the flow is:
//  1. Read the claims of the JWT stored for this request (presumably verified
//     earlier by CheckJWT).
//  2. Require the "sub" claim to be a string naming a user in
//     config.Storage.Users.
//  3. Store a session bound to the request's protocol, "file" route parameter
//     and client IP.
//  4. Respond with {"session_id": "<id>"} without calling the next handler.
//
// Requests without a token continue down the chain unchanged.
func CreateSessionID(c *fiber.Ctx) error {
	token := c.Query("token")
	if token == "" {
		return c.Next()
	}

	jwtPayload, ok := jwtPayloadFromRequest(c)
	if !ok {
		return c.SendStatus(fiber.StatusUnauthorized)
	}

	// A missing or non-string "sub" claim used to hit an unchecked type
	// assertion and panic; reject it as unauthorized instead.
	subject, ok := jwtPayload["sub"].(string)
	if !ok {
		return c.SendStatus(fiber.StatusUnauthorized)
	}

	// The owner of the token must be a registered user.
	if _, registered := config.Storage.Users[subject]; !registered {
		// NOTE(review): CreateSessionID2 answers 401 here, while this plain
		// error is rendered by the app's fiber error handler — consider aligning.
		return errors.New("user not found")
	}

	// Values from c.Query/c.Params/c.IP are only valid inside the handler
	// (unless fiber's Immutable setting is on), so clone them before storing
	// them in long-lived session data.
	sessionData := &config.SessionData{
		Proto:     c.Protocol(),
		FileName:  strings.Clone(c.Params("file")),
		IP:        strings.Clone(c.IP()),
		Token:     strings.Clone(token),
		LastCheck: time.Now(),
	}

	sessionID := generateSessionID(c, token)
	saveSession(sessionID, sessionData)

	return c.JSON(fiber.Map{
		"session_id": sessionID,
	})
}
// CreateSessionID2 is the net/http counterpart of CreateSessionID.
//
// When the "token" query parameter is present, the flow is:
//  1. Read the JWT claims from the request context (presumably stored there by
//     CheckJWT2).
//  2. Require the "sub" claim to be a string naming a user in
//     config.Storage.Users.
//  3. Store a session bound to the request's scheme, "file" path value and
//     client IP.
//  4. Respond with {"session_id": "<id>"}, the same body shape as the fiber
//     version.
//
// Requests without a token are passed to next.
func CreateSessionID2(next http.Handler) http.Handler {
	log.Println("before return CreateSessionID2")
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		log.Println("hello from CreateSessionID2")

		token := r.URL.Query().Get("token")
		if token == "" {
			next.ServeHTTP(w, r)
			log.Println("hello from CreateSessionID2: StatusOK")
			return
		}

		jwtPayload, ok := jwtPayloadFromRequest2(r)
		if !ok {
			w.WriteHeader(http.StatusUnauthorized)
			return
		}

		// A missing or non-string "sub" claim used to hit an unchecked type
		// assertion and panic; reject it as unauthorized instead.
		subject, ok := jwtPayload["sub"].(string)
		if !ok {
			w.WriteHeader(http.StatusUnauthorized)
			return
		}

		// The owner of the token must be a registered user.
		if _, registered := config.Storage.Users[subject]; !registered {
			w.WriteHeader(http.StatusUnauthorized)
			w.Write([]byte("user not found"))
			return
		}

		scheme := "http"
		if r.TLS != nil {
			scheme = "https"
		}
		clientIP, _, err := net.SplitHostPort(r.RemoteAddr)
		if err != nil {
			// RemoteAddr without a port: use it as is.
			clientIP = r.RemoteAddr
		}

		sessionData := &config.SessionData{
			Proto:     scheme,
			FileName:  r.PathValue("file"),
			IP:        clientIP,
			Token:     token,
			LastCheck: time.Now(),
		}
		sessionID := generateSessionID2(r, token)
		saveSession(sessionID, sessionData)

		// Marshal before touching the response so a failure can still produce a
		// proper 500. The old code called WriteHeader after the body had already
		// been written, which has no effect.
		body, err := json.Marshal(map[string]string{"session_id": sessionID})
		if err != nil {
			http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
			return
		}
		w.Header().Set("Content-Type", "application/json")
		if _, err := w.Write(body); err != nil {
			log.Printf("CreateSessionID2: writing response: %v", err)
		}
	})
}
// jwtPayloadFromRequest returns the claims of the JWT stored in the fiber
// locals under config.ContextKeyUser (the key CheckJWT hands to jwtware).
//
// It reports false, after logging the offending value, when:
//   - nothing of type *jwt.Token is stored;
//   - the stored token is nil;
//   - the token's claims are not jwt.MapClaims.
func jwtPayloadFromRequest(c *fiber.Ctx) (jwt.MapClaims, bool) {
	// Read through c.Locals — the same accessor CheckSessionID uses for this
	// key — and only once, instead of going through the fasthttp context adapter.
	stored := c.Locals(config.ContextKeyUser)
	jwtToken, ok := stored.(*jwt.Token)
	if !ok || jwtToken == nil {
		logrus.WithFields(logrus.Fields{
			"jwt_token_context_value": stored,
		}).Error("wrong type of JWT token in context")
		return nil, false
	}
	payload, ok := jwtToken.Claims.(jwt.MapClaims)
	if !ok {
		logrus.WithFields(logrus.Fields{
			"jwt_token_claims": jwtToken.Claims,
		}).Error("wrong type of JWT token claims")
		return nil, false
	}
	return payload, true
}
// jwtPayloadFromRequest2 returns the claims of the *jwt.Token stored in the
// request context under config.ContextKeyUser.
//
// It reports false, after logging the offending value, when the context value
// is not a *jwt.Token or its claims are not jwt.MapClaims.
func jwtPayloadFromRequest2(r *http.Request) (jwt.MapClaims, bool) {
	stored := r.Context().Value(config.ContextKeyUser)
	jwtToken, isToken := stored.(*jwt.Token)
	if !isToken {
		logrus.WithField("jwt_token_context_value", stored).
			Error("wrong type of JWT token in context")
		return nil, false
	}

	claims, isMapClaims := jwtToken.Claims.(jwt.MapClaims)
	if !isMapClaims {
		logrus.WithField("jwt_token_claims", jwtToken.Claims).
			Error("wrong type of JWT token claims")
		return nil, false
	}
	return claims, true
}
// AuthMiddleware collects auth middlewares.
func AuthMiddleware(h http.Handler) http.Handler {
return CheckJWT2(CheckSessionID2(CreateSessionID2(h)))
}