session id check by refresh token

This commit is contained in:
nquidox 2025-09-10 20:19:51 +03:00
parent a8c974994b
commit 6d20fc3ed6
11 changed files with 74 additions and 103 deletions

View file

@ -2,7 +2,6 @@ package main
import ( import (
"context" "context"
"fmt"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"merch-parser-api/config" "merch-parser-api/config"
_ "merch-parser-api/docs" //for swagger _ "merch-parser-api/docs" //for swagger
@ -50,15 +49,11 @@ func main() {
utilsProvider := utils.NewUtils() utilsProvider := utils.NewUtils()
log.Debug("Utils provider initialized") log.Debug("Utils provider initialized")
//for users package anf router
usersRefreshRoute := fmt.Sprintf("%s/auth/refresh", c.AppConf.ApiPrefix)
//deps providers //deps providers
routerHandler := router.NewRouter(router.Deps{ routerHandler := router.NewRouter(router.Deps{
ApiPrefix: c.AppConf.ApiPrefix, ApiPrefix: c.AppConf.ApiPrefix,
GinMode: c.AppConf.GinMode, GinMode: c.AppConf.GinMode,
TokenProv: jwtProvider, TokenProv: jwtProvider,
UsersRefreshRoute: usersRefreshRoute,
}) })
log.Debug("Router handler initialized") log.Debug("Router handler initialized")
@ -74,7 +69,6 @@ func main() {
Auth: authProvider, Auth: authProvider,
DB: database, DB: database,
Utils: utilsProvider, Utils: utilsProvider,
RefreshRoute: usersRefreshRoute,
}) })
//collect modules //collect modules

View file

@ -13,14 +13,13 @@ import (
type controller struct { type controller struct {
service *service service *service
utils interfaces.Utils utils interfaces.Utils
refreshRoute string authPath string
} }
func newController(service *service, utils interfaces.Utils, refreshRoute string) *controller { func newController(service *service, utils interfaces.Utils) *controller {
return &controller{ return &controller{
service: service, service: service,
utils: utils, utils: utils,
refreshRoute: refreshRoute,
} }
} }
@ -33,9 +32,12 @@ func (h *Handler) RegisterRoutes(r *gin.RouterGroup) {
userGroup.DELETE("/", h.controller.delete) userGroup.DELETE("/", h.controller.delete)
//auth //auth
userGroup.POST("/login", h.controller.login) h.controller.authPath = "/user/auth"
userGroup.POST("/logout", h.controller.logout)
userGroup.POST("/refresh", h.controller.refresh) authGroup := userGroup.Group("/auth")
authGroup.POST("/login", h.controller.login)
authGroup.POST("/logout", h.controller.logout)
authGroup.POST("/refresh", h.controller.refresh)
} }
func (h *Handler) ExcludeRoutes() []shared.ExcludeRoute { func (h *Handler) ExcludeRoutes() []shared.ExcludeRoute {
@ -187,7 +189,7 @@ func (co *controller) login(c *gin.Context) {
response.RefreshCookie.Name, response.RefreshCookie.Name,
response.RefreshCookie.Value, response.RefreshCookie.Value,
int(time.Until(response.RefreshCookie.Expires).Seconds()), int(time.Until(response.RefreshCookie.Expires).Seconds()),
co.refreshRoute, co.authPath,
"", "",
response.RefreshCookie.Secure, response.RefreshCookie.Secure,
response.RefreshCookie.HttpOnly, response.RefreshCookie.HttpOnly,
@ -206,14 +208,14 @@ func (co *controller) login(c *gin.Context) {
// @Failure 500 {object} responses.ErrorResponse500 // @Failure 500 {object} responses.ErrorResponse500
// @Router /user/logout [post] // @Router /user/logout [post]
func (co *controller) logout(c *gin.Context) { func (co *controller) logout(c *gin.Context) {
userUuid, refreshUuid, sessionUuid, err := co.utils.GetAllTokensFromContext(c) userUuid, refreshUuid, err := co.utils.GetAllTokensFromContext(c)
if err != nil { if err != nil {
c.JSON(http.StatusBadRequest, responses.ErrorResponse400{Error: err.Error()}) c.JSON(http.StatusBadRequest, responses.ErrorResponse400{Error: err.Error()})
log.WithError(err).Error("User | Failed to get uuids from context on refresh") log.WithError(err).Error("User | Failed to get uuids from context on refresh")
return return
} }
if err = co.service.logout(userUuid, refreshUuid, sessionUuid); err != nil { if err = co.service.logout(userUuid, refreshUuid); err != nil {
c.JSON(http.StatusInternalServerError, responses.ErrorResponse500{Error: err.Error()}) c.JSON(http.StatusInternalServerError, responses.ErrorResponse500{Error: err.Error()})
log.WithError(err).Error("User | Failed to logout") log.WithError(err).Error("User | Failed to logout")
return return
@ -232,14 +234,14 @@ func (co *controller) logout(c *gin.Context) {
// @Router /user/refresh [post] // @Router /user/refresh [post]
func (co *controller) refresh(c *gin.Context) { func (co *controller) refresh(c *gin.Context) {
//токены будут помещены в контекст при срабатывании мидлвари авторизации //токены будут помещены в контекст при срабатывании мидлвари авторизации
userUuid, refreshUuid, sessionUuid, err := co.utils.GetAllTokensFromContext(c) userUuid, refreshUuid, err := co.utils.GetAllTokensFromContext(c)
if err != nil { if err != nil {
c.JSON(http.StatusBadRequest, responses.ErrorResponse400{Error: err.Error()}) c.JSON(http.StatusBadRequest, responses.ErrorResponse400{Error: err.Error()})
log.WithError(err).Error("User | Failed to get uuids from context on refresh") log.WithError(err).Error("User | Failed to get uuids from context on refresh")
return return
} }
response, err := co.service.refresh(userUuid, refreshUuid, sessionUuid) response, err := co.service.refresh(userUuid, refreshUuid)
if err != nil { if err != nil {
c.JSON(http.StatusInternalServerError, responses.ErrorResponse500{Error: err.Error()}) c.JSON(http.StatusInternalServerError, responses.ErrorResponse500{Error: err.Error()})
log.WithError(err).Error("User | Failed to refresh user info") log.WithError(err).Error("User | Failed to refresh user info")

View file

@ -15,13 +15,12 @@ type Deps struct {
Auth interfaces.Auth Auth interfaces.Auth
DB *gorm.DB DB *gorm.DB
Utils interfaces.Utils Utils interfaces.Utils
RefreshRoute string
} }
func NewHandler(deps Deps) *Handler { func NewHandler(deps Deps) *Handler {
r := newRepo(deps.DB) r := newRepo(deps.DB)
s := newService(deps.Auth, r, deps.Utils) s := newService(deps.Auth, r, deps.Utils)
c := newController(s, deps.Utils, deps.RefreshRoute) c := newController(s, deps.Utils)
return &Handler{ return &Handler{
controller: c, controller: c,

View file

@ -120,10 +120,10 @@ func (s *service) login(login Login) (shared.AuthData, error) {
return authData, nil return authData, nil
} }
func (s *service) logout(userUuid, refreshUuid, sessionUuid string) error { func (s *service) logout(userUuid, refreshUuid string) error {
return s.auth.Logout(userUuid, refreshUuid, sessionUuid) return s.auth.Logout(userUuid, refreshUuid)
} }
func (s *service) refresh(userUuid, refreshUuid, sessionUuid string) (shared.AuthData, error) { func (s *service) refresh(userUuid, refreshUuid string) (shared.AuthData, error) {
return s.auth.Refresh(userUuid, refreshUuid, sessionUuid) return s.auth.Refresh(userUuid, refreshUuid)
} }

View file

@ -4,6 +4,6 @@ import "merch-parser-api/internal/shared"
type Auth interface { type Auth interface {
Login(userUuid string) (shared.AuthData, error) Login(userUuid string) (shared.AuthData, error)
Logout(userUuid, refreshUuid, sessionUuid string) error Logout(userUuid, refreshUuid string) error
Refresh(userUuid, refreshUuid, sessionUuid string) (shared.AuthData, error) Refresh(userUuid, refreshUuid string) (shared.AuthData, error)
} }

View file

@ -5,7 +5,7 @@ import "github.com/gin-gonic/gin"
type Utils interface { type Utils interface {
IsEmail(email string) bool IsEmail(email string) bool
GetUserUuidFromContext(c *gin.Context) (string, error) GetUserUuidFromContext(c *gin.Context) (string, error)
GetAllTokensFromContext(c *gin.Context) (string, string, string, error) GetAllTokensFromContext(c *gin.Context) (string, string, error)
HashPassword(password string) (string, error) HashPassword(password string) (string, error)
ComparePasswords(hashedPassword string, plainPassword string) error ComparePasswords(hashedPassword string, plainPassword string) error
} }

View file

@ -7,8 +7,8 @@ import (
type Repository interface { type Repository interface {
CreateRefreshToken(token *Session) error CreateRefreshToken(token *Session) error
ReadRefreshToken(userUuid, tokenUuid, sessionUuid string) (Session, error) ReadRefreshToken(userUuid, tokenUuid string) (Session, error)
InvalidateRefreshToken(userUuid, refreshUuid, sessionUuid string) error InvalidateRefreshToken(userUuid, refreshUuid string) error
} }
type repo struct { type repo struct {
@ -23,13 +23,12 @@ func (r *repo) CreateRefreshToken(token *Session) error {
return r.db.Create(token).Error return r.db.Create(token).Error
} }
func (r *repo) ReadRefreshToken(userUuid, tokenUuid, sessionUuid string) (Session, error) { func (r *repo) ReadRefreshToken(userUuid, tokenUuid string) (Session, error) {
var tokenData Session var tokenData Session
if err := r.db. if err := r.db.
Where("user_uuid = ?", userUuid). Where("user_uuid = ?", userUuid).
Where("refresh_uuid = ?", tokenUuid). Where("refresh_uuid = ?", tokenUuid).
Where("session_uuid = ?", sessionUuid).
Where("deleted_at IS NULL"). Where("deleted_at IS NULL").
First(&tokenData).Error; err != nil { First(&tokenData).Error; err != nil {
return Session{}, err return Session{}, err
@ -38,11 +37,10 @@ func (r *repo) ReadRefreshToken(userUuid, tokenUuid, sessionUuid string) (Sessio
return tokenData, nil return tokenData, nil
} }
func (r *repo) InvalidateRefreshToken(userUuid, refreshUuid, sessionUuid string) error { func (r *repo) InvalidateRefreshToken(userUuid, refreshUuid string) error {
return r.db. return r.db.
Model(&Session{}). Model(&Session{}).
Where("user_uuid = ?", userUuid). Where("user_uuid = ?", userUuid).
Where("refresh_uuid = ?", refreshUuid). Where("refresh_uuid = ?", refreshUuid).
Where("session_uuid = ?", sessionUuid).
Update("deleted_at", time.Now().UTC()).Error Update("deleted_at", time.Now().UTC()).Error
} }

View file

@ -26,10 +26,10 @@ func (s *Service) Login(userUuid string) (shared.AuthData, error) {
return s.newSession(userUuid) return s.newSession(userUuid)
} }
func (s *Service) Refresh(userUuid, refreshUuid, sessionUuid string) (shared.AuthData, error) { func (s *Service) Refresh(userUuid, refreshUuid string) (shared.AuthData, error) {
var err error var err error
tokenData, err := s.repo.ReadRefreshToken(userUuid, refreshUuid, sessionUuid) tokenData, err := s.repo.ReadRefreshToken(userUuid, refreshUuid)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
return shared.AuthData{}, errors.New("refresh token is not valid or doesn't exist") return shared.AuthData{}, errors.New("refresh token is not valid or doesn't exist")
@ -38,19 +38,20 @@ func (s *Service) Refresh(userUuid, refreshUuid, sessionUuid string) (shared.Aut
} }
if time.Now().After(tokenData.Expires) { if time.Now().After(tokenData.Expires) {
_ = s.repo.InvalidateRefreshToken(userUuid, refreshUuid)
return shared.AuthData{}, errors.New("token expired") return shared.AuthData{}, errors.New("token expired")
} }
err = s.repo.InvalidateRefreshToken(userUuid, refreshUuid, sessionUuid) err = s.repo.InvalidateRefreshToken(userUuid, refreshUuid)
if err != nil { if err != nil {
return shared.AuthData{}, err return shared.AuthData{}, err
} }
return s.updateSession(userUuid, sessionUuid) return s.updateSession(userUuid, tokenData.SessionUuid)
} }
func (s *Service) Logout(userUuid, refreshUuid, sessionUuid string) error { func (s *Service) Logout(userUuid, refreshUuid string) error {
return s.repo.InvalidateRefreshToken(userUuid, refreshUuid, sessionUuid) return s.repo.InvalidateRefreshToken(userUuid, refreshUuid)
} }
func (s *Service) newSession(userUuid string) (shared.AuthData, error) { func (s *Service) newSession(userUuid string) (shared.AuthData, error) {

View file

@ -18,14 +18,12 @@ type router struct {
ginMode string ginMode string
excludeRoutes map[string]shared.ExcludeRoute excludeRoutes map[string]shared.ExcludeRoute
tokenProv interfaces.JWTProvider tokenProv interfaces.JWTProvider
usersRefreshRoute string
} }
type Deps struct { type Deps struct {
ApiPrefix string ApiPrefix string
GinMode string GinMode string
TokenProv interfaces.JWTProvider TokenProv interfaces.JWTProvider
UsersRefreshRoute string
} }
func NewRouter(deps Deps) interfaces.Router { func NewRouter(deps Deps) interfaces.Router {
@ -53,7 +51,6 @@ func NewRouter(deps Deps) interfaces.Router {
apiPrefix: deps.ApiPrefix, apiPrefix: deps.ApiPrefix,
engine: engine, engine: engine,
tokenProv: deps.TokenProv, tokenProv: deps.TokenProv,
usersRefreshRoute: deps.UsersRefreshRoute,
} }
} }
@ -72,7 +69,6 @@ func (r *router) Set() *gin.Engine {
prefix: r.apiPrefix, prefix: r.apiPrefix,
excludeRoutes: &r.excludeRoutes, excludeRoutes: &r.excludeRoutes,
tokenProv: r.tokenProv, tokenProv: r.tokenProv,
usersRefreshRoute: r.usersRefreshRoute,
})) }))
return r.engine return r.engine

View file

@ -13,7 +13,6 @@ type mwDeps struct {
prefix string prefix string
excludeRoutes *map[string]shared.ExcludeRoute excludeRoutes *map[string]shared.ExcludeRoute
tokenProv interfaces.JWTProvider tokenProv interfaces.JWTProvider
usersRefreshRoute string
} }
func authMiddleware(deps mwDeps) gin.HandlerFunc { func authMiddleware(deps mwDeps) gin.HandlerFunc {
@ -24,7 +23,7 @@ func authMiddleware(deps mwDeps) gin.HandlerFunc {
return return
} }
if c.FullPath() == deps.usersRefreshRoute && c.Request.Method == "POST" { if (c.FullPath() == "/user/auth/refresh" || c.FullPath() == "/user/auth/logout") && c.Request.Method == "POST" {
refreshUuid, err := c.Cookie("refresh_uuid") refreshUuid, err := c.Cookie("refresh_uuid")
if err != nil { if err != nil {
c.JSON(http.StatusUnauthorized, responses.ErrorResponse401{Error: "Refresh token is required"}) c.JSON(http.StatusUnauthorized, responses.ErrorResponse401{Error: "Refresh token is required"})
@ -48,7 +47,7 @@ func authMiddleware(deps mwDeps) gin.HandlerFunc {
return return
} }
userUuid, sessionUuid, err := deps.tokenProv.Parse(token) userUuid, err := deps.tokenProv.Parse(token)
if err != nil { if err != nil {
c.JSON(http.StatusUnauthorized, responses.ErrorResponse401{Error: err.Error()}) c.JSON(http.StatusUnauthorized, responses.ErrorResponse401{Error: err.Error()})
log.WithField("msg", "error parsing jwt").Error("MW | Authorization") log.WithField("msg", "error parsing jwt").Error("MW | Authorization")
@ -57,11 +56,9 @@ func authMiddleware(deps mwDeps) gin.HandlerFunc {
} }
c.Set("userUuid", userUuid) c.Set("userUuid", userUuid)
c.Set("sessionUuid", sessionUuid)
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"userUuid": userUuid, "userUuid": userUuid,
"sessionUuid": sessionUuid,
}).Debug("MW | Parsed uuids") }).Debug("MW | Parsed uuids")
if !c.IsAborted() { if !c.IsAborted() {

View file

@ -28,58 +28,42 @@ func (u *Utils) GetUserUuidFromContext(c *gin.Context) (string, error) {
return userUuid.String(), nil return userUuid.String(), nil
} }
func (u *Utils) GetAllTokensFromContext(c *gin.Context) (string, string, string, error) { func (u *Utils) GetAllTokensFromContext(c *gin.Context) (string, string, error) {
if c == nil { if c == nil {
return "", "", "", errors.New("context is nil") return "", "", errors.New("context is nil")
} }
//get user uuid //get user uuid
userRaw, exists := c.Get("userUuid") userRaw, exists := c.Get("userUuid")
if !exists { if !exists {
return "", "", "", errors.New("user uuid not found in context") return "", "", errors.New("user uuid not found in context")
} }
userUuidStr, ok := userRaw.(string) userUuidStr, ok := userRaw.(string)
if !ok { if !ok {
return "", "", "", errors.New("user uuid is not a string") return "", "", errors.New("user uuid is not a string")
} }
userUuid, err := uuid.Parse(userUuidStr) userUuid, err := uuid.Parse(userUuidStr)
if err != nil { if err != nil {
return "", "", "", errors.New("error parsing user uuid") return "", "", errors.New("error parsing user uuid")
} }
//get refresh token uuid //get refresh token uuid
refreshRaw, exists := c.Get("refreshUuid") refreshRaw, exists := c.Get("refreshUuid")
if !exists { if !exists {
return "", "", "", errors.New("refresh uuid not found in context") return "", "", errors.New("refresh uuid not found in context")
} }
refreshUuidStr, ok := refreshRaw.(string) refreshUuidStr, ok := refreshRaw.(string)
if !ok { if !ok {
return "", "", "", errors.New("refresh uuid is not a string") return "", "", errors.New("refresh uuid is not a string")
} }
refreshUuid, err := uuid.Parse(refreshUuidStr) refreshUuid, err := uuid.Parse(refreshUuidStr)
if err != nil { if err != nil {
return "", "", "", errors.New("error parsing refresh uuid") return "", "", errors.New("error parsing refresh uuid")
} }
//get session token uuid return userUuid.String(), refreshUuid.String(), nil
sessionRaw, exists := c.Get("sessionUuid")
if !exists {
return "", "", "", errors.New("session uuid not found in context")
}
sessionUuidStr, ok := sessionRaw.(string)
if !ok {
return "", "", "", errors.New("session uuid is not a string")
}
sessionUuid, err := uuid.Parse(sessionUuidStr)
if err != nil {
return "", "", "", errors.New("error parsing session uuid")
}
return userUuid.String(), refreshUuid.String(), sessionUuid.String(), nil
} }