From 6d20fc3ed6e3979b7be0136c854b7b1ca666ba41 Mon Sep 17 00:00:00 2001 From: nquidox Date: Wed, 10 Sep 2025 20:19:51 +0300 Subject: [PATCH] session id check by refresh token --- cmd/main.go | 18 +++++---------- internal/api/user/controller.go | 32 ++++++++++++++------------ internal/api/user/handler.go | 9 ++++---- internal/api/user/service.go | 8 +++---- internal/interfaces/auth.go | 4 ++-- internal/interfaces/utils.go | 2 +- internal/provider/auth/repository.go | 10 ++++---- internal/provider/auth/service.go | 13 ++++++----- internal/router/handler.go | 32 ++++++++++++-------------- internal/router/middleware.go | 15 +++++------- pkg/utils/userUuid.go | 34 ++++++++-------------------- 11 files changed, 74 insertions(+), 103 deletions(-) diff --git a/cmd/main.go b/cmd/main.go index c630c2c..14cd0bd 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -2,7 +2,6 @@ package main import ( "context" - "fmt" log "github.com/sirupsen/logrus" "merch-parser-api/config" _ "merch-parser-api/docs" //for swagger @@ -50,15 +49,11 @@ func main() { utilsProvider := utils.NewUtils() log.Debug("Utils provider initialized") - //for users package anf router - usersRefreshRoute := fmt.Sprintf("%s/auth/refresh", c.AppConf.ApiPrefix) - //deps providers routerHandler := router.NewRouter(router.Deps{ - ApiPrefix: c.AppConf.ApiPrefix, - GinMode: c.AppConf.GinMode, - TokenProv: jwtProvider, - UsersRefreshRoute: usersRefreshRoute, + ApiPrefix: c.AppConf.ApiPrefix, + GinMode: c.AppConf.GinMode, + TokenProv: jwtProvider, }) log.Debug("Router handler initialized") @@ -71,10 +66,9 @@ func main() { //register app modules users := user.NewHandler(user.Deps{ - Auth: authProvider, - DB: database, - Utils: utilsProvider, - RefreshRoute: usersRefreshRoute, + Auth: authProvider, + DB: database, + Utils: utilsProvider, }) //collect modules diff --git a/internal/api/user/controller.go b/internal/api/user/controller.go index 145e0a0..d4ad1e9 100644 --- a/internal/api/user/controller.go +++ b/internal/api/user/controller.go @@ -11,16 +11,15 @@ import ( ) type controller struct { - service *service - utils interfaces.Utils - refreshRoute string + service *service + utils interfaces.Utils + authPath string } -func newController(service *service, utils interfaces.Utils, refreshRoute string) *controller { +func newController(service *service, utils interfaces.Utils) *controller { return &controller{ - service: service, - utils: utils, - refreshRoute: refreshRoute, + service: service, + utils: utils, } } @@ -33,9 +32,12 @@ func (h *Handler) RegisterRoutes(r *gin.RouterGroup) { userGroup.DELETE("/", h.controller.delete) //auth - userGroup.POST("/login", h.controller.login) - userGroup.POST("/logout", h.controller.logout) - userGroup.POST("/refresh", h.controller.refresh) + h.controller.authPath = "/user/auth" + + authGroup := userGroup.Group("/auth") + authGroup.POST("/login", h.controller.login) + authGroup.POST("/logout", h.controller.logout) + authGroup.POST("/refresh", h.controller.refresh) } func (h *Handler) ExcludeRoutes() []shared.ExcludeRoute { @@ -187,7 +189,7 @@ func (co *controller) login(c *gin.Context) { response.RefreshCookie.Name, response.RefreshCookie.Value, int(time.Until(response.RefreshCookie.Expires).Seconds()), - co.refreshRoute, + co.authPath, "", response.RefreshCookie.Secure, response.RefreshCookie.HttpOnly, @@ -206,14 +208,14 @@ func (co *controller) login(c *gin.Context) { // @Failure 500 {object} responses.ErrorResponse500 // @Router /user/logout [post] func (co *controller) logout(c *gin.Context) { - userUuid, refreshUuid, sessionUuid, err := co.utils.GetAllTokensFromContext(c) + userUuid, refreshUuid, err := co.utils.GetAllTokensFromContext(c) if err != nil { c.JSON(http.StatusBadRequest, responses.ErrorResponse400{Error: err.Error()}) log.WithError(err).Error("User | Failed to get uuids from context on refresh") return } - if err = co.service.logout(userUuid, refreshUuid, sessionUuid); err != nil { + if err = co.service.logout(userUuid, refreshUuid); err != nil { c.JSON(http.StatusInternalServerError, responses.ErrorResponse500{Error: err.Error()}) log.WithError(err).Error("User | Failed to logout") return @@ -232,14 +234,14 @@ func (co *controller) logout(c *gin.Context) { // @Router /user/refresh [post] func (co *controller) refresh(c *gin.Context) { //токены будут помещены в контекст при срабатывании мидлвари авторизации - userUuid, refreshUuid, sessionUuid, err := co.utils.GetAllTokensFromContext(c) + userUuid, refreshUuid, err := co.utils.GetAllTokensFromContext(c) if err != nil { c.JSON(http.StatusBadRequest, responses.ErrorResponse400{Error: err.Error()}) log.WithError(err).Error("User | Failed to get uuids from context on refresh") return } - response, err := co.service.refresh(userUuid, refreshUuid, sessionUuid) + response, err := co.service.refresh(userUuid, refreshUuid) if err != nil { c.JSON(http.StatusInternalServerError, responses.ErrorResponse500{Error: err.Error()}) log.WithError(err).Error("User | Failed to refresh user info") diff --git a/internal/api/user/handler.go b/internal/api/user/handler.go index 9747d1c..2b0c143 100644 --- a/internal/api/user/handler.go +++ b/internal/api/user/handler.go @@ -12,16 +12,15 @@ type Handler struct { } type Deps struct { - Auth interfaces.Auth - DB *gorm.DB - Utils interfaces.Utils - RefreshRoute string + Auth interfaces.Auth + DB *gorm.DB + Utils interfaces.Utils } func NewHandler(deps Deps) *Handler { r := newRepo(deps.DB) s := newService(deps.Auth, r, deps.Utils) - c := newController(s, deps.Utils, deps.RefreshRoute) + c := newController(s, deps.Utils) return &Handler{ controller: c, diff --git a/internal/api/user/service.go b/internal/api/user/service.go index df03998..1bf2240 100644 --- a/internal/api/user/service.go +++ b/internal/api/user/service.go @@ -120,10 +120,10 @@ func (s *service) login(login Login) (shared.AuthData, error) { return authData, nil } -func (s *service) logout(userUuid, refreshUuid, sessionUuid string) error { - return s.auth.Logout(userUuid, refreshUuid, sessionUuid) +func (s *service) logout(userUuid, refreshUuid string) error { + return s.auth.Logout(userUuid, refreshUuid) } -func (s *service) refresh(userUuid, refreshUuid, sessionUuid string) (shared.AuthData, error) { - return s.auth.Refresh(userUuid, refreshUuid, sessionUuid) +func (s *service) refresh(userUuid, refreshUuid string) (shared.AuthData, error) { + return s.auth.Refresh(userUuid, refreshUuid) } diff --git a/internal/interfaces/auth.go b/internal/interfaces/auth.go index 60012cb..dedc506 100644 --- a/internal/interfaces/auth.go +++ b/internal/interfaces/auth.go @@ -4,6 +4,6 @@ import "merch-parser-api/internal/shared" type Auth interface { Login(userUuid string) (shared.AuthData, error) - Logout(userUuid, refreshUuid, sessionUuid string) error - Refresh(userUuid, refreshUuid, sessionUuid string) (shared.AuthData, error) + Logout(userUuid, refreshUuid string) error + Refresh(userUuid, refreshUuid string) (shared.AuthData, error) } diff --git a/internal/interfaces/utils.go b/internal/interfaces/utils.go index d8ccf07..76de065 100644 --- a/internal/interfaces/utils.go +++ b/internal/interfaces/utils.go @@ -5,7 +5,7 @@ import "github.com/gin-gonic/gin" type Utils interface { IsEmail(email string) bool GetUserUuidFromContext(c *gin.Context) (string, error) - GetAllTokensFromContext(c *gin.Context) (string, string, string, error) + GetAllTokensFromContext(c *gin.Context) (string, string, error) HashPassword(password string) (string, error) ComparePasswords(hashedPassword string, plainPassword string) error } diff --git a/internal/provider/auth/repository.go b/internal/provider/auth/repository.go index 2643ed7..af21d03 100644 --- a/internal/provider/auth/repository.go +++ b/internal/provider/auth/repository.go @@ -7,8 +7,8 @@ import ( type Repository interface { CreateRefreshToken(token *Session) error - ReadRefreshToken(userUuid, tokenUuid, sessionUuid string) (Session, error) - InvalidateRefreshToken(userUuid, refreshUuid, sessionUuid string) error + ReadRefreshToken(userUuid, tokenUuid string) (Session, error) + InvalidateRefreshToken(userUuid, refreshUuid string) error } type repo struct { @@ -23,13 +23,12 @@ func (r *repo) CreateRefreshToken(token *Session) error { return r.db.Create(token).Error } -func (r *repo) ReadRefreshToken(userUuid, tokenUuid, sessionUuid string) (Session, error) { +func (r *repo) ReadRefreshToken(userUuid, tokenUuid string) (Session, error) { var tokenData Session if err := r.db. Where("user_uuid = ?", userUuid). Where("refresh_uuid = ?", tokenUuid). - Where("session_uuid = ?", sessionUuid). Where("deleted_at IS NULL"). First(&tokenData).Error; err != nil { return Session{}, err @@ -38,11 +37,10 @@ func (r *repo) ReadRefreshToken(userUuid, tokenUuid, sessionUuid string) (Sessio return tokenData, nil } -func (r *repo) InvalidateRefreshToken(userUuid, refreshUuid, sessionUuid string) error { +func (r *repo) InvalidateRefreshToken(userUuid, refreshUuid string) error { return r.db. Model(&Session{}). Where("user_uuid = ?", userUuid). Where("refresh_uuid = ?", refreshUuid). - Where("session_uuid = ?", sessionUuid). Update("deleted_at", time.Now().UTC()).Error } diff --git a/internal/provider/auth/service.go b/internal/provider/auth/service.go index 09ceb19..27c1538 100644 --- a/internal/provider/auth/service.go +++ b/internal/provider/auth/service.go @@ -26,10 +26,10 @@ func (s *Service) Login(userUuid string) (shared.AuthData, error) { return s.newSession(userUuid) } -func (s *Service) Refresh(userUuid, refreshUuid, sessionUuid string) (shared.AuthData, error) { +func (s *Service) Refresh(userUuid, refreshUuid string) (shared.AuthData, error) { var err error - tokenData, err := s.repo.ReadRefreshToken(userUuid, refreshUuid, sessionUuid) + tokenData, err := s.repo.ReadRefreshToken(userUuid, refreshUuid) if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return shared.AuthData{}, errors.New("refresh token is not valid or doesn't exist") @@ -38,19 +38,20 @@ func (s *Service) Refresh(userUuid, refreshUuid, sessionUuid string) (shared.Aut } if time.Now().After(tokenData.Expires) { + _ = s.repo.InvalidateRefreshToken(userUuid, refreshUuid) return shared.AuthData{}, errors.New("token expired") } - err = s.repo.InvalidateRefreshToken(userUuid, refreshUuid, sessionUuid) + err = s.repo.InvalidateRefreshToken(userUuid, refreshUuid) if err != nil { return shared.AuthData{}, err } - return s.updateSession(userUuid, sessionUuid) + return s.updateSession(userUuid, tokenData.SessionUuid) } -func (s *Service) Logout(userUuid, refreshUuid, sessionUuid string) error { - return s.repo.InvalidateRefreshToken(userUuid, refreshUuid, sessionUuid) +func (s *Service) Logout(userUuid, refreshUuid string) error { + return s.repo.InvalidateRefreshToken(userUuid, refreshUuid) } func (s *Service) newSession(userUuid string) (shared.AuthData, error) { diff --git a/internal/router/handler.go b/internal/router/handler.go index 875ed3f..91fdd22 100644 --- a/internal/router/handler.go +++ b/internal/router/handler.go @@ -13,19 +13,17 @@ import ( ) type router struct { - apiPrefix string - engine *gin.Engine - ginMode string - excludeRoutes map[string]shared.ExcludeRoute - tokenProv interfaces.JWTProvider - usersRefreshRoute string + apiPrefix string + engine *gin.Engine + ginMode string + excludeRoutes map[string]shared.ExcludeRoute + tokenProv interfaces.JWTProvider } type Deps struct { - ApiPrefix string - GinMode string - TokenProv interfaces.JWTProvider - UsersRefreshRoute string + ApiPrefix string + GinMode string + TokenProv interfaces.JWTProvider } func NewRouter(deps Deps) interfaces.Router { @@ -50,10 +48,9 @@ func NewRouter(deps Deps) interfaces.Router { })) return &router{ - apiPrefix: deps.ApiPrefix, - engine: engine, - tokenProv: deps.TokenProv, - usersRefreshRoute: deps.UsersRefreshRoute, + apiPrefix: deps.ApiPrefix, + engine: engine, + tokenProv: deps.TokenProv, } } @@ -69,10 +66,9 @@ func (r *router) Set() *gin.Engine { r.engine.GET("/swagger/*any", ginSwagger.WrapHandler(swaggerFiles.Handler)) r.engine.Use(authMiddleware(mwDeps{ - prefix: r.apiPrefix, - excludeRoutes: &r.excludeRoutes, - tokenProv: r.tokenProv, - usersRefreshRoute: r.usersRefreshRoute, + prefix: r.apiPrefix, + excludeRoutes: &r.excludeRoutes, + tokenProv: r.tokenProv, })) return r.engine diff --git a/internal/router/middleware.go b/internal/router/middleware.go index 5d3b758..64864bc 100644 --- a/internal/router/middleware.go +++ b/internal/router/middleware.go @@ -10,10 +10,9 @@ import ( ) type mwDeps struct { - prefix string - excludeRoutes *map[string]shared.ExcludeRoute - tokenProv interfaces.JWTProvider - usersRefreshRoute string + prefix string + excludeRoutes *map[string]shared.ExcludeRoute + tokenProv interfaces.JWTProvider } func authMiddleware(deps mwDeps) gin.HandlerFunc { @@ -24,7 +23,7 @@ func authMiddleware(deps mwDeps) gin.HandlerFunc { return } - if c.FullPath() == deps.usersRefreshRoute && c.Request.Method == "POST" { + if (c.FullPath() == "/user/auth/refresh" || c.FullPath() == "/user/auth/logout") && c.Request.Method == "POST" { refreshUuid, err := c.Cookie("refresh_uuid") if err != nil { c.JSON(http.StatusUnauthorized, responses.ErrorResponse401{Error: "Refresh token is required"}) @@ -48,7 +47,7 @@ func authMiddleware(deps mwDeps) gin.HandlerFunc { return } - userUuid, sessionUuid, err := deps.tokenProv.Parse(token) + userUuid, err := deps.tokenProv.Parse(token) if err != nil { c.JSON(http.StatusUnauthorized, responses.ErrorResponse401{Error: err.Error()}) log.WithField("msg", "error parsing jwt").Error("MW | Authorization") @@ -57,11 +56,9 @@ func authMiddleware(deps mwDeps) gin.HandlerFunc { } c.Set("userUuid", userUuid) - c.Set("sessionUuid", sessionUuid) log.WithFields(log.Fields{ - "userUuid": userUuid, - "sessionUuid": sessionUuid, + "userUuid": userUuid, }).Debug("MW | Parsed uuids") if !c.IsAborted() { diff --git a/pkg/utils/userUuid.go b/pkg/utils/userUuid.go index 16b7ea3..847f9f7 100644 --- a/pkg/utils/userUuid.go +++ b/pkg/utils/userUuid.go @@ -28,58 +28,42 @@ func (u *Utils) GetUserUuidFromContext(c *gin.Context) (string, error) { return userUuid.String(), nil } -func (u *Utils) GetAllTokensFromContext(c *gin.Context) (string, string, string, error) { +func (u *Utils) GetAllTokensFromContext(c *gin.Context) (string, string, error) { if c == nil { - return "", "", "", errors.New("context is nil") + return "", "", errors.New("context is nil") } //get user uuid userRaw, exists := c.Get("userUuid") if !exists { - return "", "", "", errors.New("user uuid not found in context") + return "", "", errors.New("user uuid not found in context") } userUuidStr, ok := userRaw.(string) if !ok { - return "", "", "", errors.New("user uuid is not a string") + return "", "", errors.New("user uuid is not a string") } userUuid, err := uuid.Parse(userUuidStr) if err != nil { - return "", "", "", errors.New("error parsing user uuid") + return "", "", errors.New("error parsing user uuid") } //get refresh token uuid refreshRaw, exists := c.Get("refreshUuid") if !exists { - return "", "", "", errors.New("refresh uuid not found in context") + return "", "", errors.New("refresh uuid not found in context") } refreshUuidStr, ok := refreshRaw.(string) if !ok { - return "", "", "", errors.New("refresh uuid is not a string") + return "", "", errors.New("refresh uuid is not a string") } refreshUuid, err := uuid.Parse(refreshUuidStr) if err != nil { - return "", "", "", errors.New("error parsing refresh uuid") + return "", "", errors.New("error parsing refresh uuid") } - //get session token uuid - sessionRaw, exists := c.Get("sessionUuid") - if !exists { - return "", "", "", errors.New("session uuid not found in context") - } - - sessionUuidStr, ok := sessionRaw.(string) - if !ok { - return "", "", "", errors.New("session uuid is not a string") - } - - sessionUuid, err := uuid.Parse(sessionUuidStr) - if err != nil { - return "", "", "", errors.New("error parsing session uuid") - } - - return userUuid.String(), refreshUuid.String(), sessionUuid.String(), nil + return userUuid.String(), refreshUuid.String(), nil }