465 lines
12 KiB
Go
465 lines
12 KiB
Go
package middlewares
|
||
|
||
import (
|
||
"context"
|
||
"crypto/sha256"
|
||
"encoding/hex"
|
||
"encoding/json"
|
||
"errors"
|
||
"fmt"
|
||
"log"
|
||
"net"
|
||
"net/http"
|
||
"os"
|
||
"time"
|
||
|
||
"git.insit.tech/psa/rtsp_reader-writer/reader/internal/config"
|
||
jwtware "github.com/gofiber/contrib/jwt"
|
||
"github.com/gofiber/fiber/v2"
|
||
"github.com/golang-jwt/jwt/v5"
|
||
"github.com/sirupsen/logrus"
|
||
)
|
||
|
||
// CheckJWT is custom middleware that checks if token is valid.
|
||
func CheckJWT(c *fiber.Ctx) error {
|
||
// Пропускаем JWT проверку, если есть session_id
|
||
if c.Query("session_id") != "" {
|
||
return c.Next()
|
||
}
|
||
|
||
// Выполняем JWT проверку только при наличии токена
|
||
return jwtware.New(jwtware.Config{
|
||
SigningKey: jwtware.SigningKey{
|
||
Key: config.JwtSecretKey,
|
||
},
|
||
ContextKey: config.ContextKeyUser,
|
||
TokenLookup: "query:token",
|
||
})(c)
|
||
}
|
||
|
||
// CheckJWT2 is custom middleware that checks if token is valid.
|
||
func CheckJWT2(next http.Handler) http.Handler {
|
||
log.Println("before return CheckJWT2")
|
||
|
||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
log.Println("hello from CheckJWT2")
|
||
tokenStr := r.URL.Query().Get("token")
|
||
if tokenStr == "" {
|
||
next.ServeHTTP(w, r)
|
||
return
|
||
}
|
||
|
||
token, err := jwt.Parse(tokenStr, func(t *jwt.Token) (interface{}, error) {
|
||
return config.JwtSecretKey, nil
|
||
})
|
||
|
||
if err != nil || !token.Valid {
|
||
w.WriteHeader(http.StatusUnauthorized)
|
||
return
|
||
}
|
||
|
||
ctx := context.WithValue(r.Context(), config.ContextKeyUser, token)
|
||
next.ServeHTTP(w, r.WithContext(ctx))
|
||
log.Println("CheckJWT2 finished")
|
||
})
|
||
}
|
||
|
||
// CheckSessionID middleware.
|
||
func CheckSessionID(c *fiber.Ctx) error {
|
||
// Если есть валидный JWT в контексте
|
||
if c.Locals(config.ContextKeyUser) != nil {
|
||
return c.Next()
|
||
}
|
||
|
||
// Проверка session_id
|
||
sessionID := c.Query("session_id")
|
||
if sessionID == "" {
|
||
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "Auth required"})
|
||
}
|
||
|
||
// Получаем данные из памяти
|
||
data, err := getSessionData(sessionID)
|
||
if err != nil {
|
||
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "Invalid session"})
|
||
}
|
||
|
||
if !validateSessionParams(c, data) {
|
||
deleteSession(sessionID)
|
||
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "Session parameters mismatch"})
|
||
}
|
||
|
||
if time.Since(data.LastCheck) > 3*time.Minute {
|
||
if !checkTokenInDB(data.Token) {
|
||
deleteSession(sessionID)
|
||
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "Token expired"})
|
||
}
|
||
data.LastCheck = time.Now()
|
||
saveSession(sessionID, data)
|
||
}
|
||
|
||
c.Locals(config.ContextKeySession, data)
|
||
return c.Next()
|
||
}
|
||
|
||
// CheckSessionID2 middleware.
|
||
func CheckSessionID2(next http.Handler) http.Handler {
|
||
log.Println("before return CheckSessionID2")
|
||
|
||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
log.Println("hello from CheckSessionID2")
|
||
if r.Context().Value(config.ContextKeyUser) != nil {
|
||
next.ServeHTTP(w, r)
|
||
return
|
||
}
|
||
|
||
sessionID := r.URL.Query().Get("session_id")
|
||
if sessionID == "" {
|
||
w.WriteHeader(http.StatusUnauthorized)
|
||
return
|
||
}
|
||
|
||
data, err := getSessionData(sessionID)
|
||
if err != nil {
|
||
w.WriteHeader(http.StatusUnauthorized)
|
||
return
|
||
}
|
||
|
||
if !validateSessionParam2(r, data) {
|
||
deleteSession(sessionID)
|
||
w.WriteHeader(http.StatusUnauthorized)
|
||
return
|
||
}
|
||
|
||
if time.Since(data.LastCheck) > 3*time.Minute {
|
||
if !checkTokenInDB(data.Token) {
|
||
deleteSession(sessionID)
|
||
w.WriteHeader(http.StatusUnauthorized)
|
||
return
|
||
}
|
||
data.LastCheck = time.Now()
|
||
saveSession(sessionID, data)
|
||
}
|
||
|
||
ctx := context.WithValue(r.Context(), config.ContextKeySession, data)
|
||
next.ServeHTTP(w, r.WithContext(ctx))
|
||
log.Println("CheckSessionID2 finished")
|
||
})
|
||
}
|
||
|
||
// Работа с памятью
|
||
func getSessionData(sessionID string) (*config.SessionData, error) {
|
||
// Для работы с файлами (раскомментировать при необходимости):
|
||
// return loadFromFile(sessionID)
|
||
|
||
val, ok := config.SessionStore.Load(sessionID)
|
||
if !ok {
|
||
return nil, fmt.Errorf("session not found")
|
||
}
|
||
|
||
data, ok := val.(*config.SessionData)
|
||
if !ok {
|
||
return nil, fmt.Errorf("invalid session data format")
|
||
}
|
||
return data, nil
|
||
}
|
||
|
||
// Остальные функции без изменений
|
||
func validateSessionParams(c *fiber.Ctx, data *config.SessionData) bool {
|
||
log.Printf("HTTP/1.1 && HTTP/2 Protocol: %s", c.Protocol())
|
||
log.Printf("HTTP/1.1 && HTTP/2 File: %s", c.Params("file"))
|
||
log.Printf("HTTP/1.1 && HTTP/2 IP: %s", c.IP())
|
||
|
||
log.Printf("DB Protocol: %s", data.Proto)
|
||
log.Printf("DB File: %s", data.FileName)
|
||
log.Printf("DB IP: %s", data.IP)
|
||
|
||
return c.Protocol() == data.Proto &&
|
||
c.Params("file") == data.FileName &&
|
||
c.IP() == data.IP
|
||
}
|
||
|
||
// Остальные функции без изменений
|
||
func validateSessionParam2(r *http.Request, data *config.SessionData) bool {
|
||
scheme := "http"
|
||
if r.TLS != nil {
|
||
scheme = "https"
|
||
}
|
||
|
||
file := r.PathValue("file")
|
||
|
||
clientIP, _, err := net.SplitHostPort(r.RemoteAddr)
|
||
if err != nil {
|
||
// Если SplitHostPort не сработал, оставляем как есть
|
||
clientIP = r.RemoteAddr
|
||
}
|
||
|
||
log.Printf("HTTP/3 Protocol: %s", scheme)
|
||
log.Printf("HTTP/3 File: %s", file)
|
||
log.Printf("HTTP/3 IP: %s", clientIP)
|
||
|
||
log.Printf("DB Protocol: %s", data.Proto)
|
||
log.Printf("DB File: %s", data.FileName)
|
||
log.Printf("DB IP: %s", data.IP)
|
||
|
||
return scheme == data.Proto &&
|
||
file == data.FileName &&
|
||
clientIP == data.IP
|
||
}
|
||
|
||
// checkTokenInDB is meant to report whether token is still valid according to
// the database.
//
// TODO: not implemented yet — this temporary stub ignores token and always
// returns true.
func checkTokenInDB(token string) bool {
	_ = token // unused until the real lookup is implemented
	return true
}
|
||
|
||
// Генерация session_id
|
||
func generateSessionID(c *fiber.Ctx, token string) string {
|
||
data := fmt.Sprintf("%s|%s|%s|%s",
|
||
c.Protocol(),
|
||
c.Params("file"),
|
||
c.IP(),
|
||
token,
|
||
)
|
||
hash := sha256.Sum256([]byte(data))
|
||
return hex.EncodeToString(hash[:])
|
||
}
|
||
|
||
// Генерация session_id
|
||
func generateSessionID2(r *http.Request, token string) string {
|
||
// Determine protocol.
|
||
scheme := "http"
|
||
if r.TLS != nil {
|
||
scheme = "https"
|
||
}
|
||
|
||
// Determine filename.
|
||
file := r.PathValue("file")
|
||
|
||
// Determine IP.
|
||
clientIP, _, err := net.SplitHostPort(r.RemoteAddr)
|
||
if err != nil {
|
||
// Если SplitHostPort не сработал, оставляем как есть
|
||
clientIP = r.RemoteAddr
|
||
}
|
||
|
||
data := fmt.Sprintf("%s|%s|%s|%s",
|
||
scheme,
|
||
file,
|
||
clientIP,
|
||
token,
|
||
)
|
||
hash := sha256.Sum256([]byte(data))
|
||
return hex.EncodeToString(hash[:])
|
||
}
|
||
|
||
func saveSession(sessionID string, data *config.SessionData) {
|
||
// Для работы с файлами (раскомментировать при необходимости):
|
||
// saveToFile(sessionID, data)
|
||
|
||
config.SessionStore.Store(sessionID, data)
|
||
}
|
||
|
||
func deleteSession(sessionID string) {
|
||
// Для работы с файлами (раскомментировать при необходимости):
|
||
// deleteFromFile(sessionID)
|
||
|
||
config.SessionStore.Delete(sessionID)
|
||
}
|
||
|
||
func saveToFile(sessionID string, data *config.SessionData) error {
|
||
file, _ := os.ReadFile(config.SessionFile)
|
||
var sessions []config.FileSession
|
||
json.Unmarshal(file, &sessions)
|
||
|
||
// Удаляем старую сессию если существует
|
||
for i, s := range sessions {
|
||
if s.ID == sessionID {
|
||
sessions = append(sessions[:i], sessions[i+1:]...)
|
||
break
|
||
}
|
||
}
|
||
|
||
sessions = append(sessions, config.FileSession{ID: sessionID, Data: data})
|
||
bytes, _ := json.MarshalIndent(sessions, "", " ")
|
||
return os.WriteFile(config.SessionFile, bytes, 0644)
|
||
}
|
||
|
||
func loadFromFile(sessionID string) (*config.SessionData, error) {
|
||
file, err := os.ReadFile(config.SessionFile)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
var sessions []config.FileSession
|
||
json.Unmarshal(file, &sessions)
|
||
|
||
for _, s := range sessions {
|
||
if s.ID == sessionID {
|
||
return s.Data, nil
|
||
}
|
||
}
|
||
return nil, fmt.Errorf("session not found")
|
||
}
|
||
|
||
func deleteFromFile(sessionID string) error {
|
||
file, _ := os.ReadFile(config.SessionFile)
|
||
var sessions []config.FileSession
|
||
json.Unmarshal(file, &sessions)
|
||
|
||
for i, s := range sessions {
|
||
if s.ID == sessionID {
|
||
sessions = append(sessions[:i], sessions[i+1:]...)
|
||
bytes, _ := json.MarshalIndent(sessions, "", " ")
|
||
return os.WriteFile(config.SessionFile, bytes, 0644)
|
||
}
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// CreateSessionID is custom middleware that creates session ID by using token.
|
||
func CreateSessionID(c *fiber.Ctx) error {
|
||
if token := c.Query("token"); token != "" {
|
||
// Get payload from token.
|
||
jwtPayload, ok := jwtPayloadFromRequest(c)
|
||
if !ok {
|
||
return c.SendStatus(fiber.StatusUnauthorized)
|
||
}
|
||
|
||
// Check if the owner of the token is registered.
|
||
_, ok = config.Storage.Users[jwtPayload["sub"].(string)]
|
||
if !ok {
|
||
return errors.New("user not found")
|
||
}
|
||
|
||
// Create session data.
|
||
sessionData := &config.SessionData{
|
||
Proto: c.Protocol(),
|
||
FileName: c.Params("file"),
|
||
IP: c.IP(),
|
||
Token: token,
|
||
LastCheck: time.Now(),
|
||
}
|
||
|
||
// Генерируем session_id
|
||
sessionID := generateSessionID(c, token)
|
||
|
||
// Сохраняем сессию
|
||
saveSession(sessionID, sessionData)
|
||
|
||
return c.JSON(fiber.Map{
|
||
"session_id": sessionID,
|
||
})
|
||
}
|
||
return c.Next()
|
||
}
|
||
|
||
// CreateSessionID2 is custom middleware that creates session ID by using token.
|
||
func CreateSessionID2(next http.Handler) http.Handler {
|
||
log.Println("before return CreateSessionID2")
|
||
|
||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
log.Println("hello from CreateSessionID2")
|
||
|
||
if token := r.URL.Query().Get("token"); token != "" {
|
||
// Get payload from token.
|
||
jwtPayload, ok := jwtPayloadFromRequest2(r)
|
||
if !ok {
|
||
w.WriteHeader(http.StatusUnauthorized)
|
||
return
|
||
}
|
||
|
||
// Check if the owner of the token is registered.
|
||
_, ok = config.Storage.Users[jwtPayload["sub"].(string)]
|
||
if !ok {
|
||
w.WriteHeader(http.StatusUnauthorized)
|
||
w.Write([]byte("user not found"))
|
||
return
|
||
}
|
||
|
||
// Determine protocol.
|
||
scheme := "http"
|
||
if r.TLS != nil {
|
||
scheme = "https"
|
||
}
|
||
|
||
// Determine filename.
|
||
file := r.PathValue("file")
|
||
|
||
// Determine IP.
|
||
clientIP, _, err := net.SplitHostPort(r.RemoteAddr)
|
||
if err != nil {
|
||
// Если SplitHostPort не сработал, оставляем как есть
|
||
clientIP = r.RemoteAddr
|
||
}
|
||
|
||
// Create session data.
|
||
sessionData := &config.SessionData{
|
||
Proto: scheme,
|
||
FileName: file,
|
||
IP: clientIP,
|
||
Token: token,
|
||
LastCheck: time.Now(),
|
||
}
|
||
|
||
// Генерируем session_id
|
||
sessionID := generateSessionID2(r, token)
|
||
|
||
// Сохраняем сессию
|
||
saveSession(sessionID, sessionData)
|
||
|
||
w.Header().Add("Content-Type", "application/json")
|
||
if err := json.NewEncoder(w).Encode(sessionID); err != nil {
|
||
w.WriteHeader(http.StatusInternalServerError)
|
||
}
|
||
return
|
||
}
|
||
next.ServeHTTP(w, r)
|
||
log.Println("hello from CreateSessionID2: StatusOK")
|
||
})
|
||
}
|
||
|
||
func jwtPayloadFromRequest(c *fiber.Ctx) (jwt.MapClaims, bool) {
|
||
jwtToken, ok := c.Context().Value(config.ContextKeyUser).(*jwt.Token)
|
||
if !ok {
|
||
logrus.WithFields(logrus.Fields{
|
||
"jwt_token_context_value": c.Context().Value(config.ContextKeyUser),
|
||
}).Error("wrong type of JWT token in context")
|
||
return nil, false
|
||
}
|
||
|
||
payload, ok := jwtToken.Claims.(jwt.MapClaims)
|
||
if !ok {
|
||
logrus.WithFields(logrus.Fields{
|
||
"jwt_token_claims": jwtToken.Claims,
|
||
}).Error("wrong type of JWT token claims")
|
||
return nil, false
|
||
}
|
||
|
||
return payload, true
|
||
}
|
||
|
||
func jwtPayloadFromRequest2(r *http.Request) (jwt.MapClaims, bool) {
|
||
jwtToken, ok := r.Context().Value(config.ContextKeyUser).(*jwt.Token)
|
||
|
||
if !ok {
|
||
logrus.WithFields(logrus.Fields{
|
||
"jwt_token_context_value": r.Context().Value(config.ContextKeyUser),
|
||
}).Error("wrong type of JWT token in context")
|
||
return nil, false
|
||
}
|
||
|
||
payload, ok := jwtToken.Claims.(jwt.MapClaims)
|
||
if !ok {
|
||
logrus.WithFields(logrus.Fields{
|
||
"jwt_token_claims": jwtToken.Claims,
|
||
}).Error("wrong type of JWT token claims")
|
||
return nil, false
|
||
}
|
||
|
||
return payload, true
|
||
}
|
||
|
||
// AuthMiddleware collects auth middlewares.
|
||
func AuthMiddleware(h http.Handler) http.Handler {
|
||
return CheckJWT2(CheckSessionID2(CreateSessionID2(h)))
|
||
}
|