jwtmiddleware.go 1.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061
  1. package middleware
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "net/http"
  7. "slowwildws/internal/constants"
  8. "github.com/golang-jwt/jwt"
  9. xhttp "github.com/zeromicro/x/http"
  10. )
  11. // 自定义 Claims 类型
  12. type MyCustomClaims struct {
  13. UserAuthID int64 `json:"user_id"`
  14. jwt.StandardClaims
  15. }
  // JwtMiddleware is an HTTP middleware that authenticates requests by
  // verifying a JSON Web Token.
  type JwtMiddleware struct {
  	// Secret is the key material handed to the JWT parser when checking a
  	// token's signature.
  	Secret string
  }
  19. func NewJwtMiddleware(secret string) *JwtMiddleware {
  20. return &JwtMiddleware{
  21. Secret: secret,
  22. }
  23. }
  24. func (j *JwtMiddleware) Handle(next http.HandlerFunc) http.HandlerFunc {
  25. return func(w http.ResponseWriter, r *http.Request) {
  26. var authToken string
  27. if authToken = r.Header.Get("Sec-Websocket-Protocol"); authToken != "" {
  28. r.Header.Set("Authorization", authToken)
  29. }
  30. fmt.Println("获取到的token参数: ", authToken)
  31. parseToken, err := jwt.ParseWithClaims(authToken, &MyCustomClaims{}, func(t *jwt.Token) (interface{}, error) {
  32. return []byte(j.Secret), nil
  33. })
  34. if err != nil {
  35. w.WriteHeader(http.StatusUnauthorized)
  36. xhttp.JsonBaseResponseCtx(r.Context(), w, errors.New("token invalid2"))
  37. return
  38. }
  39. if !parseToken.Valid {
  40. w.WriteHeader(http.StatusUnauthorized)
  41. xhttp.JsonBaseResponseCtx(r.Context(), w, errors.New("token invalid"))
  42. return
  43. }
  44. fmt.Println("解析token成功 ")
  45. claims, ok := parseToken.Claims.(*MyCustomClaims)
  46. if !ok {
  47. w.WriteHeader(http.StatusUnauthorized)
  48. xhttp.JsonBaseResponseCtx(r.Context(), w, errors.New("parse token invalid"))
  49. return
  50. }
  51. *r = *r.WithContext(context.WithValue(r.Context(), constants.UserID, claims.UserAuthID))
  52. fmt.Println("获取到的用户id: ", claims.UserAuthID)
  53. next(w, r)
  54. }
  55. }