| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061 |
- package middleware
- import (
- "context"
- "errors"
- "fmt"
- "net/http"
- "slowwildws/internal/constants"
- "github.com/golang-jwt/jwt"
- xhttp "github.com/zeromicro/x/http"
- )
// MyCustomClaims is the custom JWT claims payload: the user's id, carried in
// the token under the "user_id" key, plus the registered claims provided by
// the embedded jwt.StandardClaims.
type MyCustomClaims struct {
	// UserAuthID is decoded from the "user_id" claim.
	UserAuthID int64 `json:"user_id"`
	jwt.StandardClaims
}
// JwtMiddleware holds the configuration for JWT authentication of requests.
type JwtMiddleware struct {
	// Secret is the key used to verify token signatures.
	Secret string
}

// NewJwtMiddleware builds a JwtMiddleware that verifies tokens with secret.
// Each call returns a new, independent instance.
func NewJwtMiddleware(secret string) *JwtMiddleware {
	mw := new(JwtMiddleware)
	mw.Secret = secret
	return mw
}
- func (j *JwtMiddleware) Handle(next http.HandlerFunc) http.HandlerFunc {
- return func(w http.ResponseWriter, r *http.Request) {
- var authToken string
- if authToken = r.Header.Get("Sec-Websocket-Protocol"); authToken != "" {
- r.Header.Set("Authorization", authToken)
- }
- fmt.Println("获取到的token参数: ", authToken)
- parseToken, err := jwt.ParseWithClaims(authToken, &MyCustomClaims{}, func(t *jwt.Token) (interface{}, error) {
- return []byte(j.Secret), nil
- })
- if err != nil {
- w.WriteHeader(http.StatusUnauthorized)
- xhttp.JsonBaseResponseCtx(r.Context(), w, errors.New("token invalid2"))
- return
- }
- if !parseToken.Valid {
- w.WriteHeader(http.StatusUnauthorized)
- xhttp.JsonBaseResponseCtx(r.Context(), w, errors.New("token invalid"))
- return
- }
- fmt.Println("解析token成功 ")
- claims, ok := parseToken.Claims.(*MyCustomClaims)
- if !ok {
- w.WriteHeader(http.StatusUnauthorized)
- xhttp.JsonBaseResponseCtx(r.Context(), w, errors.New("parse token invalid"))
- return
- }
- *r = *r.WithContext(context.WithValue(r.Context(), constants.UserID, claims.UserAuthID))
- fmt.Println("获取到的用户id: ", claims.UserAuthID)
- next(w, r)
- }
- }
|