optionaljwtmiddleware.go 1.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. package middleware
  2. import (
  3. "context"
  4. "errors"
  5. "net/http"
  6. "strings"
  7. "github.com/golang-jwt/jwt/v4"
  8. )
  // Names of the registered JWT claims (RFC 7519, section 4.1).
	// OptionalJwtMiddleware.Handle skips these when copying claims into the
	// request context.
	const (
		jwtAudience  = "aud"
		jwtExpire    = "exp"
		jwtId        = "jti"
		jwtIssueAt   = "iat"
		jwtIssuer    = "iss"
		jwtNotBefore = "nbf"
		jwtSubject   = "sub"
	)

	// noDetailReason is a generic reason string. It is not referenced in this
	// file; presumably it is used elsewhere in the package (TODO confirm).
	const noDetailReason = "no detail reason"
  // OptionalJwtMiddleware verifies a JWT from the Authorization header when one
	// is supplied. It never rejects a request: requests without a usable token
	// are passed through unchanged.
	type OptionalJwtMiddleware struct {
		// Secret is the shared HMAC key used to verify token signatures.
		Secret string
	}
  22. func NewOptionalJwtMiddleware(secret string) *OptionalJwtMiddleware {
  23. return &OptionalJwtMiddleware{
  24. Secret: secret,
  25. }
  26. }
  27. func (m *OptionalJwtMiddleware) Handle(next http.HandlerFunc) http.HandlerFunc {
  28. return func(w http.ResponseWriter, r *http.Request) {
  29. // 尝试从请求头中获取 JWT 令牌
  30. tokenString := r.Header.Get("Authorization")
  31. ctx := r.Context()
  32. if len(tokenString) > 0 {
  33. // 如果提供了 JWT 令牌,则验证它
  34. tokenValue := strings.Split(tokenString, " ")
  35. if len(tokenValue) == 2 {
  36. parsedToken, err := jwt.Parse(tokenValue[1], func(token *jwt.Token) (interface{}, error) {
  37. if _, ok := token.Method.(*jwt.SigningMethodHMAC); ok {
  38. return []byte(m.Secret), nil
  39. }
  40. return nil, errors.New("failed to parse token")
  41. })
  42. if err == nil {
  43. if claims, ok := parsedToken.Claims.(jwt.MapClaims); ok && parsedToken.Valid {
  44. // 将解析到的 claims 添加到请求上下文中
  45. for k, v := range claims {
  46. switch k {
  47. case jwtAudience, jwtExpire, jwtId, jwtIssueAt, jwtIssuer, jwtNotBefore, jwtSubject:
  48. // ignore the standard claims
  49. default:
  50. ctx = context.WithValue(ctx, k, v)
  51. }
  52. }
  53. }
  54. }
  55. }
  56. }
  57. newReq := r.WithContext(ctx)
  58. next(w, newReq)
  59. }
  60. }