package middleware import ( "context" "errors" "net/http" "slowwildws/internal/constants" "github.com/golang-jwt/jwt" xhttp "github.com/zeromicro/x/http" ) // 自定义 Claims 类型 type MyCustomClaims struct { UserAuthID int64 `json:"user_id"` jwt.StandardClaims } type JwtMiddleware struct { Secret string } func NewJwtMiddleware(secret string) *JwtMiddleware { return &JwtMiddleware{ Secret: secret, } } func (j *JwtMiddleware) Handle(next http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { var authToken string if authToken = r.Header.Get("Sec-Websocket-Protocol"); authToken != "" { r.Header.Set("Authorization", authToken) } parseToken, err := jwt.ParseWithClaims(authToken, &MyCustomClaims{}, func(t *jwt.Token) (interface{}, error) { return []byte(j.Secret), nil }) if err != nil { w.WriteHeader(http.StatusUnauthorized) xhttp.JsonBaseResponseCtx(r.Context(), w, errors.New("token invalid")) return } if !parseToken.Valid { w.WriteHeader(http.StatusUnauthorized) xhttp.JsonBaseResponseCtx(r.Context(), w, errors.New("token invalid")) return } claims, ok := parseToken.Claims.(*MyCustomClaims) if !ok { w.WriteHeader(http.StatusUnauthorized) xhttp.JsonBaseResponseCtx(r.Context(), w, errors.New("parse token invalid")) return } *r = *r.WithContext(context.WithValue(r.Context(), constants.UserID, claims.UserAuthID)) next(w, r) } }