jwtmiddleware.go 1.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657
  1. package middleware
  2. import (
  3. "context"
  4. "errors"
  5. "net/http"
  6. "slowwildws/internal/constants"
  7. "github.com/golang-jwt/jwt"
  8. xhttp "github.com/zeromicro/x/http"
  9. )
  10. // 自定义 Claims 类型
  11. type MyCustomClaims struct {
  12. UserAuthID int64 `json:"user_id"`
  13. jwt.StandardClaims
  14. }
  15. type JwtMiddleware struct {
  16. Secret string
  17. }
  18. func NewJwtMiddleware(secret string) *JwtMiddleware {
  19. return &JwtMiddleware{
  20. Secret: secret,
  21. }
  22. }
  23. func (j *JwtMiddleware) Handle(next http.HandlerFunc) http.HandlerFunc {
  24. return func(w http.ResponseWriter, r *http.Request) {
  25. var authToken string
  26. if authToken = r.Header.Get("Sec-Websocket-Protocol"); authToken != "" {
  27. r.Header.Set("Authorization", authToken)
  28. }
  29. parseToken, err := jwt.ParseWithClaims(authToken, &MyCustomClaims{}, func(t *jwt.Token) (interface{}, error) {
  30. return []byte(j.Secret), nil
  31. })
  32. if err != nil {
  33. w.WriteHeader(http.StatusUnauthorized)
  34. xhttp.JsonBaseResponseCtx(r.Context(), w, errors.New("token invalid"))
  35. return
  36. }
  37. if !parseToken.Valid {
  38. w.WriteHeader(http.StatusUnauthorized)
  39. xhttp.JsonBaseResponseCtx(r.Context(), w, errors.New("token invalid"))
  40. return
  41. }
  42. claims, ok := parseToken.Claims.(*MyCustomClaims)
  43. if !ok {
  44. w.WriteHeader(http.StatusUnauthorized)
  45. xhttp.JsonBaseResponseCtx(r.Context(), w, errors.New("parse token invalid"))
  46. return
  47. }
  48. *r = *r.WithContext(context.WithValue(r.Context(), constants.UserID, claims.UserAuthID))
  49. next(w, r)
  50. }
  51. }