| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657 |
- package middleware
- import (
- "context"
- "errors"
- "net/http"
- "slowwildws/internal/constants"
- "github.com/golang-jwt/jwt"
- xhttp "github.com/zeromicro/x/http"
- )
// MyCustomClaims is the custom JWT claims type: a user ID plus the registered
// claims from jwt.StandardClaims. Because StandardClaims is embedded, its
// fields are flattened next to "user_id" in the JSON payload.
type MyCustomClaims struct {
	// UserAuthID is read from the "user_id" key of the token payload; Handle
	// stores it in the request context under constants.UserID.
	UserAuthID int64 `json:"user_id"`
	jwt.StandardClaims
}
// JwtMiddleware authenticates HTTP requests by verifying a JWT before passing
// them on to the wrapped handler (see Handle).
type JwtMiddleware struct {
	// Secret is the shared key used to verify token signatures; Handle hands
	// it to the JWT parser as []byte.
	Secret string
}
- func NewJwtMiddleware(secret string) *JwtMiddleware {
- return &JwtMiddleware{
- Secret: secret,
- }
- }
- func (j *JwtMiddleware) Handle(next http.HandlerFunc) http.HandlerFunc {
- return func(w http.ResponseWriter, r *http.Request) {
- var authToken string
- if authToken = r.Header.Get("Sec-Websocket-Protocol"); authToken != "" {
- r.Header.Set("Authorization", authToken)
- }
- parseToken, err := jwt.ParseWithClaims(authToken, &MyCustomClaims{}, func(t *jwt.Token) (interface{}, error) {
- return []byte(j.Secret), nil
- })
- if err != nil {
- w.WriteHeader(http.StatusUnauthorized)
- xhttp.JsonBaseResponseCtx(r.Context(), w, errors.New("token invalid"))
- return
- }
- if !parseToken.Valid {
- w.WriteHeader(http.StatusUnauthorized)
- xhttp.JsonBaseResponseCtx(r.Context(), w, errors.New("token invalid"))
- return
- }
- claims, ok := parseToken.Claims.(*MyCustomClaims)
- if !ok {
- w.WriteHeader(http.StatusUnauthorized)
- xhttp.JsonBaseResponseCtx(r.Context(), w, errors.New("parse token invalid"))
- return
- }
- *r = *r.WithContext(context.WithValue(r.Context(), constants.UserID, claims.UserAuthID))
- next(w, r)
- }
- }
|