huangguangrong 9 сар өмнө
parent
commit
7276083fe2

+ 1 - 1
go.mod

@@ -3,7 +3,7 @@ module slowwild
 go 1.22.0
 
 require (
-	git.banshen.xyz/huangguangrong/slow_wild_protobuff v0.1.6
+	git.banshen.xyz/huangguangrong/slow_wild_protobuff v0.1.8
 	github.com/zeromicro/go-zero v1.8.0
 	google.golang.org/grpc v1.70.0
 )

+ 2 - 0
go.sum

@@ -28,6 +28,8 @@ git.banshen.xyz/huangguangrong/slow_wild_protobuff v0.1.5 h1:sk2kkLJQoO8UUKP0qQ3
 git.banshen.xyz/huangguangrong/slow_wild_protobuff v0.1.5/go.mod h1:v8AzHCelFBbIkoY+gR4WIEuc6mG5okJ12IXbjvGJWHk=
 git.banshen.xyz/huangguangrong/slow_wild_protobuff v0.1.6 h1:AZtJwYwmLnq/pDYtPd4vm2TU1UGFZHPJevhQEvjcnbM=
 git.banshen.xyz/huangguangrong/slow_wild_protobuff v0.1.6/go.mod h1:v8AzHCelFBbIkoY+gR4WIEuc6mG5okJ12IXbjvGJWHk=
+git.banshen.xyz/huangguangrong/slow_wild_protobuff v0.1.8 h1:BW3DLwhDNATa45Jyzp18LCbXPNGEcz9k3v9mct81ayE=
+git.banshen.xyz/huangguangrong/slow_wild_protobuff v0.1.8/go.mod h1:RvtFTWaCnJcB8iy/clOqBoimd7UxSx5+KB96G+bdq5c=
 github.com/alicebob/gopher-json v0.0.0-20230218143504-906a9b012302 h1:uvdUDbHQHO85qeSydJtItA4T55Pw6BtAejd0APRJOCE=
 github.com/alicebob/gopher-json v0.0.0-20230218143504-906a9b012302/go.mod h1:SGnFV6hVsYE877CKEZ6tDNTjaSXYUk6QqoIK6PrAtcc=
 github.com/alicebob/miniredis/v2 v2.34.0 h1:mBFWMaJSNL9RwdGRyEDoAAv8OQc5UlEhLDQggTglU/0=

+ 3 - 0
internal/constants/common.go

@@ -0,0 +1,3 @@
+package constants
+
+const UserIDKey = "user_id"

+ 10 - 4
internal/logic/commentdeletelogic.go

@@ -2,12 +2,14 @@ package logic
 
 import (
 	"context"
+	"slowwild/internal/constants"
 	"slowwild/internal/errorx"
 	"slowwild/internal/model"
 	"slowwild/internal/svc"
 
 	"git.banshen.xyz/huangguangrong/slow_wild_protobuff/slowwild/slowwildserver"
 
+	"github.com/spf13/cast"
 	"github.com/zeromicro/go-zero/core/logx"
 	"gorm.io/gorm"
 )
@@ -28,12 +30,16 @@ func NewCommentDeleteLogic(ctx context.Context, svcCtx *svc.ServiceContext) *Com
 
 // 删除回复/评论
 func (l *CommentDeleteLogic) CommentDelete(in *slowwildserver.CommentDeleteReq) (*slowwildserver.CommentDeleteRsp, error) {
-	if in.UserId <= 0 || in.CommentId <= 0 {
+	userId, err := cast.ToInt64E(l.ctx.Value(constants.UserIDKey))
+	if err != nil {
+		return nil, errorx.NewCodeError(50001, "用户未登录")
+	}
+	if userId <= 0 || in.CommentId <= 0 {
 		return nil, errorx.ErrInvalidParam
 	}
 
 	// 开启事务
-	err := l.svcCtx.UserModel.Transaction(l.ctx, func(tx *gorm.DB) error {
+	err = l.svcCtx.UserModel.Transaction(l.ctx, func(tx *gorm.DB) error {
 		postModel := model.NewPostModel(tx, l.svcCtx.Redis)
 		commentModel := model.NewCommentModel(tx, l.svcCtx.Redis)
 		replyModel := model.NewReplyModel(tx, l.svcCtx.Redis)
@@ -46,7 +52,7 @@ func (l *CommentDeleteLogic) CommentDelete(in *slowwildserver.CommentDeleteReq)
 			}
 
 			// 检查是否是自己的评论
-			if comment.UserId != in.UserId {
+			if comment.UserId != userId {
 				return errorx.NewCodeError(20012, "无权删除该评论")
 			}
 
@@ -72,7 +78,7 @@ func (l *CommentDeleteLogic) CommentDelete(in *slowwildserver.CommentDeleteReq)
 			}
 
 			// 检查是否是自己的回复
-			if reply.UserId != in.UserId {
+			if reply.UserId != userId {
 				return errorx.NewCodeError(20012, "无权删除该回复")
 			}
 

+ 12 - 7
internal/logic/createpostlogic.go

@@ -2,6 +2,7 @@ package logic
 
 import (
 	"context"
+	"slowwild/internal/constants"
 	"slowwild/internal/svc"
 	"strings"
 
@@ -34,7 +35,11 @@ func NewCreatePostLogic(ctx context.Context, svcCtx *svc.ServiceContext) *Create
 
 // 发布帖子
 func (l *CreatePostLogic) CreatePost(in *slowwildserver.CreatePostReq) (*slowwildserver.CreatePostRsp, error) {
-	if in.UserId <= 0 || in.Content == "" {
+	userId, err := cast.ToInt64E(l.ctx.Value(constants.UserIDKey))
+	if err != nil {
+		return nil, errorx.NewCodeError(50001, "用户未登录")
+	}
+	if userId <= 0 || in.Content == "" {
 		return nil, errorx.ErrInvalidParam
 	}
 
@@ -44,7 +49,7 @@ func (l *CreatePostLogic) CreatePost(in *slowwildserver.CreatePostReq) (*slowwil
 	}
 
 	// 开启事务
-	err := l.svcCtx.UserModel.Transaction(l.ctx, func(tx *gorm.DB) error {
+	err = l.svcCtx.UserModel.Transaction(l.ctx, func(tx *gorm.DB) error {
 		postModel := model.NewPostModel(tx, l.svcCtx.Redis)
 		postContentModel := model.NewPostContentModel(tx)
 		userModel := model.NewUserModel(tx)
@@ -78,7 +83,7 @@ func (l *CreatePostLogic) CreatePost(in *slowwildserver.CreatePostReq) (*slowwil
 						CreatedOn:  time.Now().Unix(),
 						ModifiedOn: time.Now().Unix(),
 					},
-					UserId: in.UserId,
+					UserId: userId,
 					Name:   tag.Name,
 					HotNum: 0,
 				})
@@ -121,7 +126,7 @@ func (l *CreatePostLogic) CreatePost(in *slowwildserver.CreatePostReq) (*slowwil
 				CreatedOn:  time.Now().Unix(),
 				ModifiedOn: time.Now().Unix(),
 			},
-			UserId:          in.UserId,
+			UserId:          userId,
 			PostType:        in.Type,
 			Visibility:      int8(in.Visibility),
 			Tags:            strings.Join(tagNames, ","),
@@ -153,7 +158,7 @@ func (l *CreatePostLogic) CreatePost(in *slowwildserver.CreatePostReq) (*slowwil
 				ModifiedOn: time.Now().Unix(),
 			},
 			PostId:         post.ID,
-			UserId:         in.UserId,
+			UserId:         userId,
 			Title:          in.Title,
 			Content:        in.Content,     // 保存原始内容
 			ContentSummary: contentSummary, // 保存处理后的摘要
@@ -181,7 +186,7 @@ func (l *CreatePostLogic) CreatePost(in *slowwildserver.CreatePostReq) (*slowwil
 
 			// 发送站内信给被艾特的用户
 			for _, user := range atUsers {
-				err = messageModel.SendNotification(l.ctx, in.UserId, user.ID, 1, "你被艾特了", "你被艾特在一条动态中", post.ID, 0, 0)
+				err = messageModel.SendNotification(l.ctx, userId, user.ID, 1, "你被艾特了", "你被艾特在一条动态中", post.ID, 0, 0)
 				if err != nil {
 					return err
 				}
@@ -189,7 +194,7 @@ func (l *CreatePostLogic) CreatePost(in *slowwildserver.CreatePostReq) (*slowwil
 		}
 
 		// 增加用户的帖子数
-		if err := userModel.IncrementTweetCount(l.ctx, in.UserId); err != nil {
+		if err := userModel.IncrementTweetCount(l.ctx, userId); err != nil {
 			return err
 		}
 

+ 22 - 14
internal/logic/getpostcommentlistlogic.go

@@ -2,6 +2,7 @@ package logic
 
 import (
 	"context"
+	"slowwild/internal/constants"
 	"slowwild/internal/errorx"
 	"slowwild/internal/model"
 	"slowwild/internal/svc"
@@ -10,6 +11,7 @@ import (
 
 	"git.banshen.xyz/huangguangrong/slow_wild_protobuff/slowwild/slowwildserver"
 
+	"github.com/spf13/cast"
 	"github.com/zeromicro/go-zero/core/logx"
 )
 
@@ -29,6 +31,8 @@ func NewGetPostCommentListLogic(ctx context.Context, svcCtx *svc.ServiceContext)
 
 // 获取评论列表
 func (l *GetPostCommentListLogic) GetPostCommentList(in *slowwildserver.GetPostCommentListReq) (*slowwildserver.GetPostCommentListRsp, error) {
+	userId, _ := cast.ToInt64E(l.ctx.Value(constants.UserIDKey))
+
 	if in.PostId <= 0 || in.Page <= 0 || in.PageSize <= 0 {
 		return nil, errorx.ErrInvalidParam
 	}
@@ -105,12 +109,14 @@ func (l *GetPostCommentListLogic) GetPostCommentList(in *slowwildserver.GetPostC
 			l.Logger.Errorf("获取评论回复失败: %v", err)
 			continue
 		}
-
-		// 获取评论点赞状态
-		isUpvoted, err := l.svcCtx.CommentModel.IsCommentUpvotedByUser(l.ctx, in.UserId, comment.ID)
-		if err != nil {
-			l.Logger.Errorf("获取评论点赞状态失败: %v", err)
-			continue
+		var isUpvoted bool
+		if userId > 0 {
+			// 获取评论点赞状态
+			isUpvoted, err = l.svcCtx.CommentModel.IsCommentUpvotedByUser(l.ctx, userId, comment.ID)
+			if err != nil {
+				l.Logger.Errorf("获取评论点赞状态失败: %v", err)
+				continue
+			}
 		}
 
 		// 构建回复列表
@@ -137,12 +143,14 @@ func (l *GetPostCommentListLogic) GetPostCommentList(in *slowwildserver.GetPostC
 					}
 				}
 			}
-
-			// 获取回复点赞状态
-			isReplyUpvoted, err := l.svcCtx.CommentModel.IsCommentUpvotedByUser(l.ctx, in.UserId, reply.ID)
-			if err != nil {
-				l.Logger.Errorf("获取回复点赞状态失败: %v", err)
-				continue
+			var isReplyUpvoted bool
+			if userId > 0 {
+				// 获取回复点赞状态
+				isReplyUpvoted, err = l.svcCtx.CommentModel.IsCommentUpvotedByUser(l.ctx, userId, reply.ID)
+				if err != nil {
+					l.Logger.Errorf("获取回复点赞状态失败: %v", err)
+					continue
+				}
 			}
 
 			replyList = append(replyList, &slowwildserver.RepliesItem{
@@ -158,7 +166,7 @@ func (l *GetPostCommentListLogic) GetPostCommentList(in *slowwildserver.GetPostC
 				UpvoteCount: reply.ThumbsUpCount,
 				IpLoc:       reply.IpLoc,
 				CreatedOn:   reply.CreatedOn,
-				IsMine:      in.UserId == reply.UserId,
+				IsMine:      userId == reply.UserId,
 				IsLiked:     isReplyUpvoted,
 			})
 		}
@@ -177,7 +185,7 @@ func (l *GetPostCommentListLogic) GetPostCommentList(in *slowwildserver.GetPostC
 			UpvoteCount: comment.ThumbsUpCount,
 			IpLoc:       comment.IpLoc,
 			CreatedOn:   comment.CreatedOn,
-			IsMine:      in.UserId == comment.UserId,
+			IsMine:      userId == comment.UserId,
 			ReplyItem:   replyList,
 			IsLiked:     isUpvoted,
 		})

+ 17 - 10
internal/logic/getpostlistlogic.go

@@ -2,6 +2,7 @@ package logic
 
 import (
 	"context"
+	"slowwild/internal/constants"
 	"slowwild/internal/errorx"
 	"slowwild/internal/model"
 	"slowwild/internal/svc"
@@ -11,6 +12,7 @@ import (
 	"strconv"
 	"strings"
 
+	"github.com/spf13/cast"
 	"github.com/zeromicro/go-zero/core/logx"
 )
 
@@ -30,6 +32,8 @@ func NewGetPostListLogic(ctx context.Context, svcCtx *svc.ServiceContext) *GetPo
 
 // 获取帖子列表
 func (l *GetPostListLogic) GetPostList(in *slowwildserver.GetPostListReq) (*slowwildserver.GetPostListRsp, error) {
+	userId, _ := cast.ToInt64E(l.ctx.Value(constants.UserIDKey))
+
 	if in.Page <= 0 || in.PageSize <= 0 {
 		return nil, errorx.ErrInvalidParam
 	}
@@ -106,17 +110,20 @@ func (l *GetPostListLogic) GetPostList(in *slowwildserver.GetPostListReq) (*slow
 	}
 
 	for _, post := range posts {
+		var isLiked, isCollected bool
 		// 查询用户状态
-		isLiked, err := l.svcCtx.UserModel.IsPostLikedByUser(l.ctx, in.UserId, post.ID)
-		if err != nil {
-			l.Logger.Errorf("查询用户点赞状态失败: %v", err)
-			return nil, errorx.ErrPostQueryFailed
-		}
+		if userId > 0 {
+			isLiked, err = l.svcCtx.UserModel.IsPostLikedByUser(l.ctx, userId, post.ID)
+			if err != nil {
+				l.Logger.Errorf("查询用户点赞状态失败: %v", err)
+				return nil, errorx.ErrPostQueryFailed
+			}
 
-		isCollected, err := l.svcCtx.UserModel.IsPostCollectedByUser(l.ctx, in.UserId, post.ID)
-		if err != nil {
-			l.Logger.Errorf("查询用户收藏状态失败: %v", err)
-			return nil, errorx.ErrPostQueryFailed
+			isCollected, err = l.svcCtx.UserModel.IsPostCollectedByUser(l.ctx, userId, post.ID)
+			if err != nil {
+				l.Logger.Errorf("查询用户收藏状态失败: %v", err)
+				return nil, errorx.ErrPostQueryFailed
+			}
 		}
 
 		// 获取话题信息
@@ -173,7 +180,7 @@ func (l *GetPostListLogic) GetPostList(in *slowwildserver.GetPostListReq) (*slow
 			CreatedOn:       post.CreatedOn,
 			IsLiked:         isLiked,
 			IsCollected:     isCollected,
-			IsMine:          in.UserId == post.UserId,
+			IsMine:          userId == post.UserId,
 		})
 	}
 

+ 16 - 11
internal/logic/getpostlogic.go

@@ -2,6 +2,7 @@ package logic
 
 import (
 	"context"
+	"slowwild/internal/constants"
 	"slowwild/internal/errorx"
 	"slowwild/internal/svc"
 	"strconv"
@@ -9,6 +10,7 @@ import (
 
 	"git.banshen.xyz/huangguangrong/slow_wild_protobuff/slowwild/slowwildserver"
 
+	"github.com/spf13/cast"
 	"github.com/zeromicro/go-zero/core/logx"
 )
 
@@ -28,6 +30,7 @@ func NewGetPostLogic(ctx context.Context, svcCtx *svc.ServiceContext) *GetPostLo
 
 // 获取动态详情
 func (l *GetPostLogic) GetPost(in *slowwildserver.GetPostReq) (*slowwildserver.GetPostRsp, error) {
+	userId, _ := cast.ToInt64E(l.ctx.Value(constants.UserIDKey))
 	if in.PostId <= 0 {
 		return nil, errorx.ErrInvalidParam
 	}
@@ -79,18 +82,20 @@ func (l *GetPostLogic) GetPost(in *slowwildserver.GetPostReq) (*slowwildserver.G
 			})
 		}
 	}
-
+	var isLiked, isCollected bool
 	// 获取用户交互状态
-	isLiked, err := l.svcCtx.UserModel.IsPostLikedByUser(l.ctx, in.UserId, post.ID)
-	if err != nil {
-		l.Logger.Errorf("查询用户点赞状态失败: %v", err)
-		return nil, errorx.ErrPostQueryFailed
-	}
+	if userId > 0 {
+		isLiked, err = l.svcCtx.UserModel.IsPostLikedByUser(l.ctx, userId, post.ID)
+		if err != nil {
+			l.Logger.Errorf("查询用户点赞状态失败: %v", err)
+			return nil, errorx.ErrPostQueryFailed
+		}
 
-	isCollected, err := l.svcCtx.UserModel.IsPostCollectedByUser(l.ctx, in.UserId, post.ID)
-	if err != nil {
-		l.Logger.Errorf("查询用户收藏状态失败: %v", err)
-		return nil, errorx.ErrPostQueryFailed
+		isCollected, err = l.svcCtx.UserModel.IsPostCollectedByUser(l.ctx, userId, post.ID)
+		if err != nil {
+			l.Logger.Errorf("查询用户收藏状态失败: %v", err)
+			return nil, errorx.ErrPostQueryFailed
+		}
 	}
 
 	// 获取@用户信息
@@ -153,7 +158,7 @@ func (l *GetPostLogic) GetPost(in *slowwildserver.GetPostReq) (*slowwildserver.G
 		CreatedOn:       post.CreatedOn,
 		IsLiked:         isLiked,
 		IsCollected:     isCollected,
-		IsMine:          in.UserId == post.UserId,
+		IsMine:          userId == post.UserId,
 		LatestRepliedOn: post.LatestRepliedOn,
 		WithUser:        atUserInfos,
 	}

+ 13 - 6
internal/logic/getreplylistlogic.go

@@ -2,6 +2,7 @@ package logic
 
 import (
 	"context"
+	"slowwild/internal/constants"
 	"slowwild/internal/errorx"
 	"slowwild/internal/model"
 	"slowwild/internal/svc"
@@ -10,6 +11,7 @@ import (
 
 	"git.banshen.xyz/huangguangrong/slow_wild_protobuff/slowwild/slowwildserver"
 
+	"github.com/spf13/cast"
 	"github.com/zeromicro/go-zero/core/logx"
 )
 
@@ -29,6 +31,9 @@ func NewGetReplyListLogic(ctx context.Context, svcCtx *svc.ServiceContext) *GetR
 
 // 获取回复列表
 func (l *GetReplyListLogic) GetReplyList(in *slowwildserver.GetReplyListReq) (*slowwildserver.GetReplyListRsp, error) {
+	// 从context中获取userId
+	userId, _ := cast.ToInt64E(l.ctx.Value(constants.UserIDKey))
+
 	if in.CommentId <= 0 || in.Page <= 0 || in.PageSize <= 0 {
 		return nil, errorx.ErrInvalidParam
 	}
@@ -98,12 +103,14 @@ func (l *GetReplyListLogic) GetReplyList(in *slowwildserver.GetReplyListReq) (*s
 				}
 			}
 		}
-
+		var isUpvoted bool
 		// 获取回复点赞状态
-		isUpvoted, err := l.svcCtx.CommentModel.IsCommentUpvotedByUser(l.ctx, in.UserId, reply.ID)
-		if err != nil {
-			l.Logger.Errorf("获取回复点赞状态失败: %v", err)
-			continue
+		if userId > 0 {
+			isUpvoted, err = l.svcCtx.CommentModel.IsCommentUpvotedByUser(l.ctx, userId, reply.ID)
+			if err != nil {
+				l.Logger.Errorf("获取回复点赞状态失败: %v", err)
+				continue
+			}
 		}
 
 		resp.List = append(resp.List, &slowwildserver.RepliesItem{
@@ -119,7 +126,7 @@ func (l *GetReplyListLogic) GetReplyList(in *slowwildserver.GetReplyListReq) (*s
 			UpvoteCount: reply.ThumbsUpCount,
 			IpLoc:       reply.IpLoc,
 			CreatedOn:   reply.CreatedOn,
-			IsMine:      in.UserId == reply.UserId,
+			IsMine:      userId == reply.UserId,
 			IsLiked:     isUpvoted,
 		})
 	}

+ 7 - 2
internal/logic/gettaglogic.go

@@ -2,11 +2,13 @@ package logic
 
 import (
 	"context"
+	"slowwild/internal/constants"
 	"slowwild/internal/errorx"
 	"slowwild/internal/svc"
 
 	"git.banshen.xyz/huangguangrong/slow_wild_protobuff/slowwild/slowwildserver"
 
+	"github.com/spf13/cast"
 	"github.com/zeromicro/go-zero/core/logx"
 )
 
@@ -26,6 +28,9 @@ func NewGetTagLogic(ctx context.Context, svcCtx *svc.ServiceContext) *GetTagLogi
 
 // 获取话题详情
 func (l *GetTagLogic) GetTag(in *slowwildserver.GetTagReq) (*slowwildserver.GetTagRsp, error) {
+	// 从context中获取userId
+	userId, _ := cast.ToInt64E(l.ctx.Value(constants.UserIDKey))
+
 	if in.TagId <= 0 {
 		return nil, errorx.ErrInvalidParam
 	}
@@ -53,8 +58,8 @@ func (l *GetTagLogic) GetTag(in *slowwildserver.GetTagReq) (*slowwildserver.GetT
 
 	// 获取用户关注状态
 	isFollowed := false
-	if in.UserId > 0 {
-		isFollowed, err = l.svcCtx.TagModel.IsTagFollowedByUser(l.ctx, in.UserId, int64(in.TagId))
+	if userId > 0 {
+		isFollowed, err = l.svcCtx.TagModel.IsTagFollowedByUser(l.ctx, userId, int64(in.TagId))
 		if err != nil {
 			l.Logger.Errorf("获取话题关注状态失败: %v", err)
 			return nil, errorx.ErrTagQueryFailed

+ 8 - 3
internal/logic/getuserpostcollectionlistlogic.go

@@ -2,6 +2,7 @@ package logic
 
 import (
 	"context"
+	"slowwild/internal/constants"
 	"slowwild/internal/errorx"
 	"slowwild/internal/model"
 	"slowwild/internal/svc"
@@ -11,6 +12,7 @@ import (
 
 	"strings"
 
+	"github.com/spf13/cast"
 	"github.com/zeromicro/go-zero/core/logx"
 )
 
@@ -30,7 +32,10 @@ func NewGetUserPostCollectionListLogic(ctx context.Context, svcCtx *svc.ServiceC
 
 // 获取用户收藏的帖子
 func (l *GetUserPostCollectionListLogic) GetUserPostCollectionList(in *slowwildserver.GetUserPostCollectionListReq) (*slowwildserver.GetUserPostCollectionListRsp, error) {
-	if in.UserId <= 0 || in.QueryUserId <= 0 || in.Page <= 0 || in.PageSize <= 0 {
+	// 从context中获取userId
+	userId, _ := cast.ToInt64E(l.ctx.Value(constants.UserIDKey))
+
+	if in.QueryUserId <= 0 || in.Page <= 0 || in.PageSize <= 0 {
 		return nil, errorx.ErrInvalidParam
 	}
 
@@ -80,13 +85,13 @@ func (l *GetUserPostCollectionListLogic) GetUserPostCollectionList(in *slowwilds
 		}
 
 		// 检查当前用户是否点赞过该帖子
-		isLiked, err := l.svcCtx.PostActionModel.CheckPostAction(l.ctx, int64(in.UserId), post.ID, 0)
+		isLiked, err := l.svcCtx.PostActionModel.CheckPostAction(l.ctx, userId, post.ID, 0)
 		if err != nil {
 			l.Logger.Errorf("检查点赞状态失败: %v", err)
 		}
 
 		// 检查当前用户是否收藏过该帖子
-		isCollected, err := l.svcCtx.PostActionModel.CheckPostAction(l.ctx, int64(in.UserId), post.ID, 1)
+		isCollected, err := l.svcCtx.PostActionModel.CheckPostAction(l.ctx, userId, post.ID, 1)
 		if err != nil {
 			l.Logger.Errorf("检查收藏状态失败: %v", err)
 		}

+ 19 - 7
internal/logic/getuserpostlikelistlogic.go

@@ -2,6 +2,7 @@ package logic
 
 import (
 	"context"
+	"slowwild/internal/constants"
 	"slowwild/internal/errorx"
 	"slowwild/internal/model"
 	"slowwild/internal/svc"
@@ -10,6 +11,7 @@ import (
 
 	"git.banshen.xyz/huangguangrong/slow_wild_protobuff/slowwild/slowwildserver"
 
+	"github.com/spf13/cast"
 	"github.com/zeromicro/go-zero/core/logx"
 )
 
@@ -29,7 +31,10 @@ func NewGetUserPostLikeListLogic(ctx context.Context, svcCtx *svc.ServiceContext
 
 // 获取用户点赞过的帖子
 func (l *GetUserPostLikeListLogic) GetUserPostLikeList(in *slowwildserver.GetUserPostLikeListReq) (*slowwildserver.GetUserPostLikeListRsp, error) {
-	if in.UserId <= 0 || in.QueryUserId <= 0 || in.Page <= 0 || in.PageSize <= 0 {
+	// 从context中获取userId
+	userId, _ := cast.ToInt64E(l.ctx.Value(constants.UserIDKey))
+
+	if in.QueryUserId <= 0 || in.Page <= 0 || in.PageSize <= 0 {
 		return nil, errorx.ErrInvalidParam
 	}
 
@@ -78,16 +83,23 @@ func (l *GetUserPostLikeListLogic) GetUserPostLikeList(in *slowwildserver.GetUse
 			continue
 		}
 
+		var isLiked bool
 		// 检查当前用户是否点赞过该帖子
-		isLiked, err := l.svcCtx.PostActionModel.CheckPostAction(l.ctx, int64(in.UserId), post.ID, 0)
-		if err != nil {
-			l.Logger.Errorf("检查点赞状态失败: %v", err)
+		if userId > 0 {
+			isLiked, err = l.svcCtx.PostActionModel.CheckPostAction(l.ctx, userId, post.ID, 0)
+			if err != nil {
+				l.Logger.Errorf("检查点赞状态失败: %v", err)
+			}
+
 		}
 
+		var isCollected bool
 		// 检查当前用户是否收藏过该帖子
-		isCollected, err := l.svcCtx.PostActionModel.CheckPostAction(l.ctx, int64(in.UserId), post.ID, 1)
-		if err != nil {
-			l.Logger.Errorf("检查收藏状态失败: %v", err)
+		if userId > 0 {
+			isCollected, err = l.svcCtx.PostActionModel.CheckPostAction(l.ctx, userId, post.ID, 1)
+			if err != nil {
+				l.Logger.Errorf("检查收藏状态失败: %v", err)
+			}
 		}
 
 		// 获取话题信息

+ 19 - 7
internal/logic/getuserpostlistlogic.go

@@ -10,6 +10,9 @@ import (
 
 	"git.banshen.xyz/huangguangrong/slow_wild_protobuff/slowwild/slowwildserver"
 
+	"slowwild/internal/constants"
+
+	"github.com/spf13/cast"
 	"github.com/zeromicro/go-zero/core/logx"
 )
 
@@ -29,7 +32,10 @@ func NewGetUserPostListLogic(ctx context.Context, svcCtx *svc.ServiceContext) *G
 
 // 获取用户发布的帖子
 func (l *GetUserPostListLogic) GetUserPostList(in *slowwildserver.GetUserPostListReq) (*slowwildserver.GetUserPostListRsp, error) {
-	if in.UserId <= 0 || in.QueryUserId <= 0 || in.Page <= 0 || in.PageSize <= 0 {
+	// 从context中获取userId
+	userId, _ := cast.ToInt64E(l.ctx.Value(constants.UserIDKey))
+
+	if in.QueryUserId <= 0 || in.Page <= 0 || in.PageSize <= 0 {
 		return nil, errorx.ErrInvalidParam
 	}
 
@@ -78,16 +84,22 @@ func (l *GetUserPostListLogic) GetUserPostList(in *slowwildserver.GetUserPostLis
 			continue
 		}
 
+		var isLiked bool
 		// 检查当前用户是否点赞过该帖子
-		isLiked, err := l.svcCtx.PostActionModel.CheckPostAction(l.ctx, int64(in.UserId), post.ID, 0)
-		if err != nil {
-			l.Logger.Errorf("检查点赞状态失败: %v", err)
+		if userId > 0 {
+			isLiked, err = l.svcCtx.PostActionModel.CheckPostAction(l.ctx, userId, post.ID, 0)
+			if err != nil {
+				l.Logger.Errorf("检查点赞状态失败: %v", err)
+			}
 		}
 
+		var isCollected bool
 		// 检查当前用户是否收藏过该帖子
-		isCollected, err := l.svcCtx.PostActionModel.CheckPostAction(l.ctx, int64(in.UserId), post.ID, 1)
-		if err != nil {
-			l.Logger.Errorf("检查收藏状态失败: %v", err)
+		if userId > 0 {
+			isCollected, err = l.svcCtx.PostActionModel.CheckPostAction(l.ctx, userId, post.ID, 1)
+			if err != nil {
+				l.Logger.Errorf("检查收藏状态失败: %v", err)
+			}
 		}
 
 		// 获取话题信息

+ 14 - 6
internal/logic/postcollectionlogic.go

@@ -2,6 +2,7 @@ package logic
 
 import (
 	"context"
+	"slowwild/internal/constants"
 	"slowwild/internal/errorx"
 	"slowwild/internal/model"
 	"slowwild/internal/svc"
@@ -9,6 +10,7 @@ import (
 
 	"git.banshen.xyz/huangguangrong/slow_wild_protobuff/slowwild/slowwildserver"
 
+	"github.com/spf13/cast"
 	"github.com/zeromicro/go-zero/core/logx"
 	"gorm.io/gorm"
 )
@@ -29,12 +31,18 @@ func NewPostCollectionLogic(ctx context.Context, svcCtx *svc.ServiceContext) *Po
 
 // 收藏帖子
 func (l *PostCollectionLogic) PostCollection(in *slowwildserver.PostCollectionReq) (*slowwildserver.PostCollectionRsp, error) {
-	if in.UserId <= 0 || in.PostId <= 0 {
+	// 从context中获取userId
+	userId, err := cast.ToInt64E(l.ctx.Value(constants.UserIDKey))
+	if err != nil {
+		return nil, errorx.ErrInvalidParam
+	}
+
+	if in.PostId <= 0 {
 		return nil, errorx.ErrInvalidParam
 	}
 
 	// 开启事务
-	err := l.svcCtx.UserModel.Transaction(l.ctx, func(tx *gorm.DB) error {
+	err = l.svcCtx.UserModel.Transaction(l.ctx, func(tx *gorm.DB) error {
 		postModel := model.NewPostModel(tx, l.svcCtx.Redis)
 		userModel := model.NewUserModel(tx)
 		postActionModel := model.NewPostActionModel(tx, l.svcCtx.Redis)
@@ -46,7 +54,7 @@ func (l *PostCollectionLogic) PostCollection(in *slowwildserver.PostCollectionRe
 		}
 
 		// 检查是否已经收藏
-		isCollected, err := userModel.IsPostCollectedByUser(l.ctx, in.UserId, in.PostId)
+		isCollected, err := userModel.IsPostCollectedByUser(l.ctx, userId, in.PostId)
 		if err != nil {
 			return err
 		}
@@ -58,7 +66,7 @@ func (l *PostCollectionLogic) PostCollection(in *slowwildserver.PostCollectionRe
 				ModifiedOn: time.Now().Unix(),
 			},
 			PostId:     in.PostId,
-			UserId:     in.UserId,
+			UserId:     userId,
 			ActionType: 1, // 1表示收藏
 		}
 
@@ -73,9 +81,9 @@ func (l *PostCollectionLogic) PostCollection(in *slowwildserver.PostCollectionRe
 			}
 
 			// 如果不是自己的帖子,发送消息通知
-			if post.UserId != in.UserId {
+			if post.UserId != userId {
 				messageModel := model.NewMessageModel(tx)
-				err = messageModel.SendNotification(l.ctx, in.UserId, post.UserId, 2,
+				err = messageModel.SendNotification(l.ctx, userId, post.UserId, 2,
 					"收藏通知",
 					"有人收藏了你的动态",
 					in.PostId, 0, 0)

+ 13 - 7
internal/logic/postcommentlogic.go

@@ -2,6 +2,7 @@ package logic
 
 import (
 	"context"
+	"slowwild/internal/constants"
 	"slowwild/internal/errorx"
 	"slowwild/internal/model"
 	"slowwild/internal/svc"
@@ -29,12 +30,18 @@ func NewPostCommentLogic(ctx context.Context, svcCtx *svc.ServiceContext) *PostC
 
 // 发布评论
 func (l *PostCommentLogic) PostComment(in *slowwildserver.PostCommentReq) (*slowwildserver.PostCommentRsp, error) {
-	if in.UserId <= 0 || in.PostId <= 0 || in.Content == "" {
+	// 从context中获取userId
+	userId, err := cast.ToInt64E(l.ctx.Value(constants.UserIDKey))
+	if err != nil {
+		return nil, errorx.ErrInvalidParam
+	}
+
+	if in.PostId <= 0 || in.Content == "" {
 		return nil, errorx.ErrInvalidParam
 	}
 
 	// 开启事务
-	err := l.svcCtx.UserModel.Transaction(l.ctx, func(tx *gorm.DB) error {
+	err = l.svcCtx.UserModel.Transaction(l.ctx, func(tx *gorm.DB) error {
 		postModel := model.NewPostModel(tx, l.svcCtx.Redis)
 		commentModel := model.NewCommentModel(tx, l.svcCtx.Redis)
 		messageModel := model.NewMessageModel(tx)
@@ -52,7 +59,7 @@ func (l *PostCommentLogic) PostComment(in *slowwildserver.PostCommentReq) (*slow
 				ModifiedOn: time.Now().Unix(),
 			},
 			PostId:      in.PostId,
-			UserId:      in.UserId,
+			UserId:      userId,
 			Content:     in.Content,
 			Image:       in.Image,
 			WithUserIds: cast.ToString(in.AtUserId),
@@ -71,8 +78,8 @@ func (l *PostCommentLogic) PostComment(in *slowwildserver.PostCommentReq) (*slow
 		}
 
 		// 如果不是评论自己的帖子,发送消息通知
-		if post.UserId != in.UserId {
-			err = messageModel.SendNotification(l.ctx, in.UserId, post.UserId, 3,
+		if post.UserId != userId {
+			err = messageModel.SendNotification(l.ctx, userId, post.UserId, 3,
 				"评论通知",
 				"有人评论了你的动态",
 				in.PostId, comment.ID, 0)
@@ -83,14 +90,13 @@ func (l *PostCommentLogic) PostComment(in *slowwildserver.PostCommentReq) (*slow
 
 		// 如果有@用户,给被@的用户发送消息通知
 		if in.AtUserId > 0 {
-			err = messageModel.SendNotification(l.ctx, in.UserId, in.AtUserId, 4,
+			err = messageModel.SendNotification(l.ctx, userId, in.AtUserId, 4,
 				"@通知",
 				"有人在评论中@了你",
 				in.PostId, comment.ID, 0)
 			if err != nil {
 				return err
 			}
-
 		}
 
 		return nil

+ 16 - 7
internal/logic/postcommentupvotelogic.go

@@ -10,6 +10,9 @@ import (
 
 	"time"
 
+	"slowwild/internal/constants"
+
+	"github.com/spf13/cast"
 	"github.com/zeromicro/go-zero/core/logx"
 	"gorm.io/gorm"
 )
@@ -30,12 +33,18 @@ func NewPostCommentUpvoteLogic(ctx context.Context, svcCtx *svc.ServiceContext)
 
 // 评论点赞
 func (l *PostCommentUpvoteLogic) PostCommentUpvote(in *slowwildserver.PostCommentUpvoteReq) (*slowwildserver.PostCommentUpvoteRsp, error) {
-	if in.UserId <= 0 || in.CommentId <= 0 {
+	// 从context中获取userId
+	userId, err := cast.ToInt64E(l.ctx.Value(constants.UserIDKey))
+	if err != nil {
+		return nil, errorx.ErrInvalidParam
+	}
+
+	if in.CommentId <= 0 {
 		return nil, errorx.ErrInvalidParam
 	}
 
 	// 开启事务
-	err := l.svcCtx.UserModel.Transaction(l.ctx, func(tx *gorm.DB) error {
+	err = l.svcCtx.UserModel.Transaction(l.ctx, func(tx *gorm.DB) error {
 		commentModel := model.NewCommentModel(tx, l.svcCtx.Redis)
 		messageModel := model.NewMessageModel(tx)
 
@@ -46,7 +55,7 @@ func (l *PostCommentUpvoteLogic) PostCommentUpvote(in *slowwildserver.PostCommen
 		}
 
 		// 检查是否已经点赞
-		isUpvoted, err := commentModel.IsCommentUpvotedByUser(l.ctx, in.UserId, in.CommentId)
+		isUpvoted, err := commentModel.IsCommentUpvotedByUser(l.ctx, userId, in.CommentId)
 		if err != nil {
 			return err
 		}
@@ -58,7 +67,7 @@ func (l *PostCommentUpvoteLogic) PostCommentUpvote(in *slowwildserver.PostCommen
 					CreatedOn:  time.Now().Unix(),
 					ModifiedOn: time.Now().Unix(),
 				},
-				UserId:     in.UserId,
+				UserId:     userId,
 				CommentId:  in.CommentId,
 				ActionType: in.CommentType, // 0-评论点赞 1-回复点赞
 			}
@@ -72,12 +81,12 @@ func (l *PostCommentUpvoteLogic) PostCommentUpvote(in *slowwildserver.PostCommen
 			}
 
 			// 如果不是给自己点赞,发送消息通知
-			if comment.UserId != in.UserId {
+			if comment.UserId != userId {
 				notificationType := 3 // 评论通知
 				if in.CommentType == 1 {
 					notificationType = 4 // 回复通知
 				}
-				err = messageModel.SendNotification(l.ctx, in.UserId, comment.UserId, int8(notificationType),
+				err = messageModel.SendNotification(l.ctx, userId, comment.UserId, int8(notificationType),
 					"点赞通知",
 					"有人点赞了你的评论",
 					comment.PostId, in.CommentId, 0)
@@ -87,7 +96,7 @@ func (l *PostCommentUpvoteLogic) PostCommentUpvote(in *slowwildserver.PostCommen
 			}
 		} else {
 			// 取消点赞
-			if err := commentModel.DeleteCommentUpvote(l.ctx, in.UserId, in.CommentId); err != nil {
+			if err := commentModel.DeleteCommentUpvote(l.ctx, userId, in.CommentId); err != nil {
 				return err
 			}
 

+ 13 - 4
internal/logic/postdeletelogic.go

@@ -8,6 +8,9 @@ import (
 
 	"git.banshen.xyz/huangguangrong/slow_wild_protobuff/slowwild/slowwildserver"
 
+	"slowwild/internal/constants"
+
+	"github.com/spf13/cast"
 	"github.com/zeromicro/go-zero/core/logx"
 	"gorm.io/gorm"
 )
@@ -28,12 +31,18 @@ func NewPostDeleteLogic(ctx context.Context, svcCtx *svc.ServiceContext) *PostDe
 
 // 删除帖子
 func (l *PostDeleteLogic) PostDelete(in *slowwildserver.PostDeleteReq) (*slowwildserver.PostDeleteRsp, error) {
-	if in.UserId <= 0 || in.PostId <= 0 {
+	// 从context中获取userId
+	userId, err := cast.ToInt64E(l.ctx.Value(constants.UserIDKey))
+	if err != nil {
+		return nil, errorx.ErrInvalidParam
+	}
+
+	if in.PostId <= 0 {
 		return nil, errorx.ErrInvalidParam
 	}
 
 	// 开启事务
-	err := l.svcCtx.UserModel.Transaction(l.ctx, func(tx *gorm.DB) error {
+	err = l.svcCtx.UserModel.Transaction(l.ctx, func(tx *gorm.DB) error {
 		postModel := model.NewPostModel(tx, l.svcCtx.Redis)
 		userModel := model.NewUserModel(tx)
 
@@ -44,7 +53,7 @@ func (l *PostDeleteLogic) PostDelete(in *slowwildserver.PostDeleteReq) (*slowwil
 		}
 
 		// 检查是否是自己的帖子
-		if post.UserId != in.UserId {
+		if post.UserId != userId {
 			return errorx.NewCodeError(20008, "无权删除该帖子")
 		}
 
@@ -54,7 +63,7 @@ func (l *PostDeleteLogic) PostDelete(in *slowwildserver.PostDeleteReq) (*slowwil
 		}
 
 		// 减少用户的帖子数
-		if err := userModel.DecrementTweetCount(l.ctx, in.UserId); err != nil {
+		if err := userModel.DecrementTweetCount(l.ctx, userId); err != nil {
 			return err
 		}
 

+ 16 - 9
internal/logic/postreplylogic.go

@@ -2,6 +2,7 @@ package logic
 
 import (
 	"context"
+	"slowwild/internal/constants"
 	"slowwild/internal/errorx"
 	"slowwild/internal/model"
 	"slowwild/internal/svc"
@@ -30,12 +31,18 @@ func NewPostReplyLogic(ctx context.Context, svcCtx *svc.ServiceContext) *PostRep
 
 // 发布回复
 func (l *PostReplyLogic) PostReply(in *slowwildserver.PostReplyReq) (*slowwildserver.PostReplyRsp, error) {
-	if in.UserId <= 0 || in.PostId <= 0 || in.CommentId <= 0 || in.Content == "" {
+	// 从context中获取userId
+	userId, err := cast.ToInt64E(l.ctx.Value(constants.UserIDKey))
+	if err != nil {
+		return nil, errorx.ErrInvalidParam
+	}
+
+	if in.PostId <= 0 || in.CommentId <= 0 || in.Content == "" {
 		return nil, errorx.ErrInvalidParam
 	}
 
 	// 开启事务
-	err := l.svcCtx.UserModel.Transaction(l.ctx, func(tx *gorm.DB) error {
+	err = l.svcCtx.UserModel.Transaction(l.ctx, func(tx *gorm.DB) error {
 		postModel := model.NewPostModel(tx, l.svcCtx.Redis)
 		commentModel := model.NewCommentModel(tx, l.svcCtx.Redis)
 		replyModel := model.NewReplyModel(tx, l.svcCtx.Redis)
@@ -54,7 +61,7 @@ func (l *PostReplyLogic) PostReply(in *slowwildserver.PostReplyReq) (*slowwildse
 				ModifiedOn: time.Now().Unix(),
 			},
 			PostId:      in.PostId,
-			UserId:      in.UserId,
+			UserId:      userId,
 			CommentId:   in.CommentId,
 			Content:     in.Content,
 			Image:       in.Image,
@@ -85,8 +92,8 @@ func (l *PostReplyLogic) PostReply(in *slowwildserver.PostReplyReq) (*slowwildse
 			if err != nil {
 				return err
 			}
-			if parentReply.UserId != in.UserId {
-				err = messageModel.SendNotification(l.ctx, in.UserId, parentReply.UserId, 3,
+			if parentReply.UserId != userId {
+				err = messageModel.SendNotification(l.ctx, userId, parentReply.UserId, 3,
 					"回复通知",
 					"有人回复了你的回复",
 					in.PostId, in.CommentId, reply.ID)
@@ -94,8 +101,8 @@ func (l *PostReplyLogic) PostReply(in *slowwildserver.PostReplyReq) (*slowwildse
 					return err
 				}
 			}
-		} else if parentComment.UserId != in.UserId { // 如果是回复评论,需要通知评论作者
-			err = messageModel.SendNotification(l.ctx, in.UserId, parentComment.UserId, 3,
+		} else if parentComment.UserId != userId { // 如果是回复评论,需要通知评论作者
+			err = messageModel.SendNotification(l.ctx, userId, parentComment.UserId, 3,
 				"回复通知",
 				"有人回复了你的评论",
 				in.PostId, in.CommentId, reply.ID)
@@ -105,8 +112,8 @@ func (l *PostReplyLogic) PostReply(in *slowwildserver.PostReplyReq) (*slowwildse
 		}
 
 		// 如果有@用户,给被@的用户发送消息通知
-		if in.AtUserId > 0 && in.AtUserId != in.UserId {
-			err = messageModel.SendNotification(l.ctx, in.UserId, in.AtUserId, 4,
+		if in.AtUserId > 0 && in.AtUserId != userId {
+			err = messageModel.SendNotification(l.ctx, userId, in.AtUserId, 4,
 				"@通知",
 				"有人在回复中@了你",
 				in.PostId, in.CommentId, reply.ID)

+ 15 - 6
internal/logic/postsharelogic.go

@@ -10,6 +10,9 @@ import (
 
 	"time"
 
+	"slowwild/internal/constants"
+
+	"github.com/spf13/cast"
 	"github.com/zeromicro/go-zero/core/logx"
 	"gorm.io/gorm"
 )
@@ -30,12 +33,18 @@ func NewPostShareLogic(ctx context.Context, svcCtx *svc.ServiceContext) *PostSha
 
 // 分享帖子
 func (l *PostShareLogic) PostShare(in *slowwildserver.PostShareReq) (*slowwildserver.PostShareRsp, error) {
-	if in.UserId <= 0 || in.PostId <= 0 {
+	// 从context中获取userId
+	userId, err := cast.ToInt64E(l.ctx.Value(constants.UserIDKey))
+	if err != nil {
+		return nil, errorx.ErrInvalidParam
+	}
+
+	if in.PostId <= 0 {
 		return nil, errorx.ErrInvalidParam
 	}
 
 	// 开启事务
-	err := l.svcCtx.UserModel.Transaction(l.ctx, func(tx *gorm.DB) error {
+	err = l.svcCtx.UserModel.Transaction(l.ctx, func(tx *gorm.DB) error {
 		postModel := model.NewPostModel(tx, l.svcCtx.Redis)
 		postActionModel := model.NewPostActionModel(tx, l.svcCtx.Redis)
 		messageModel := model.NewMessageModel(tx)
@@ -47,7 +56,7 @@ func (l *PostShareLogic) PostShare(in *slowwildserver.PostShareReq) (*slowwildse
 		}
 
 		// 检查是否已经分享过
-		isShared, err := postActionModel.CheckPostAction(l.ctx, in.UserId, in.PostId, 2) // 2表示分享动作
+		isShared, err := postActionModel.CheckPostAction(l.ctx, userId, in.PostId, 2) // 2表示分享动作
 		if err != nil {
 			return err
 		}
@@ -62,7 +71,7 @@ func (l *PostShareLogic) PostShare(in *slowwildserver.PostShareReq) (*slowwildse
 				ModifiedOn: time.Now().Unix(),
 			},
 			PostId:     in.PostId,
-			UserId:     in.UserId,
+			UserId:     userId,
 			ActionType: 2, // 2表示分享
 		}
 		if err := postActionModel.CreatePostAction(l.ctx, postAction); err != nil {
@@ -75,8 +84,8 @@ func (l *PostShareLogic) PostShare(in *slowwildserver.PostShareReq) (*slowwildse
 		}
 
 		// 如果不是分享自己的帖子,发送消息通知
-		if post.UserId != in.UserId {
-			err = messageModel.SendNotification(l.ctx, in.UserId, post.UserId, 5,
+		if post.UserId != userId {
+			err = messageModel.SendNotification(l.ctx, userId, post.UserId, 5,
 				"分享通知",
 				"有人分享了你的动态",
 				in.PostId, 0, 0)

+ 16 - 7
internal/logic/postupvotelogic.go

@@ -9,6 +9,9 @@ import (
 
 	"git.banshen.xyz/huangguangrong/slow_wild_protobuff/slowwild/slowwildserver"
 
+	"slowwild/internal/constants"
+
+	"github.com/spf13/cast"
 	"github.com/zeromicro/go-zero/core/logx"
 	"gorm.io/gorm"
 )
@@ -29,12 +32,18 @@ func NewPostUpvoteLogic(ctx context.Context, svcCtx *svc.ServiceContext) *PostUp
 
 // 点赞帖子
 func (l *PostUpvoteLogic) PostUpvote(in *slowwildserver.PostUpvoteReq) (*slowwildserver.PostUpvoteRsp, error) {
-	if in.UserId <= 0 || in.PostId <= 0 {
+	// 从context中获取userId
+	userId, err := cast.ToInt64E(l.ctx.Value(constants.UserIDKey))
+	if err != nil {
+		return nil, errorx.ErrInvalidParam
+	}
+
+	if in.PostId <= 0 {
 		return nil, errorx.ErrInvalidParam
 	}
 
 	// 开启事务
-	err := l.svcCtx.UserModel.Transaction(l.ctx, func(tx *gorm.DB) error {
+	err = l.svcCtx.UserModel.Transaction(l.ctx, func(tx *gorm.DB) error {
 		postModel := model.NewPostModel(tx, l.svcCtx.Redis)
 		postActionModel := model.NewPostActionModel(tx, l.svcCtx.Redis)
 		messageModel := model.NewMessageModel(tx)
@@ -46,7 +55,7 @@ func (l *PostUpvoteLogic) PostUpvote(in *slowwildserver.PostUpvoteReq) (*slowwil
 		}
 
 		// 检查是否已经点赞过
-		isLiked, err := postActionModel.CheckPostAction(l.ctx, in.UserId, in.PostId, 0) // 0表示点赞动作
+		isLiked, err := postActionModel.CheckPostAction(l.ctx, userId, in.PostId, 0) // 0表示点赞动作
 		if err != nil {
 			return err
 		}
@@ -59,7 +68,7 @@ func (l *PostUpvoteLogic) PostUpvote(in *slowwildserver.PostUpvoteReq) (*slowwil
 					ModifiedOn: time.Now().Unix(),
 				},
 				PostId:     in.PostId,
-				UserId:     in.UserId,
+				UserId:     userId,
 				ActionType: 0, // 0表示点赞
 			}
 			if err := postActionModel.CreatePostAction(l.ctx, postAction); err != nil {
@@ -72,8 +81,8 @@ func (l *PostUpvoteLogic) PostUpvote(in *slowwildserver.PostUpvoteReq) (*slowwil
 			}
 
 			// 如果不是点赞自己的帖子,发送消息通知
-			if post.UserId != in.UserId {
-				err = messageModel.SendNotification(l.ctx, in.UserId, post.UserId, 6,
+			if post.UserId != userId {
+				err = messageModel.SendNotification(l.ctx, userId, post.UserId, 6,
 					"点赞通知",
 					"有人点赞了你的动态",
 					in.PostId, 0, 0)
@@ -90,7 +99,7 @@ func (l *PostUpvoteLogic) PostUpvote(in *slowwildserver.PostUpvoteReq) (*slowwil
 					IsDel:      1,
 				},
 				PostId:     in.PostId,
-				UserId:     in.UserId,
+				UserId:     userId,
 				ActionType: 0,
 			}
 			if err := postActionModel.DeletePostAction(l.ctx, postAction); err != nil {

+ 6 - 9
internal/logic/registerlogic.go

@@ -2,7 +2,6 @@ package logic
 
 import (
 	"context"
-	"time"
 
 	"slowwild/internal/errorx"
 	"slowwild/internal/model"
@@ -55,14 +54,12 @@ func (l *RegisterLogic) Register(in *slowwildserver.RegisterReq) (*slowwildserve
 
 	// 创建用户
 	user := &model.User{
-		Phone:      in.Phone,
-		Password:   encryptedPassword,
-		Salt:       salt,
-		Nickname:   in.Nickname,
-		Sex:        int(in.Sex),
-		Status:     1, // 正常状态
-		CreateTime: time.Now(),
-		UpdateTime: time.Now(),
+		Phone:    in.Phone,
+		Password: encryptedPassword,
+		Salt:     salt,
+		Nickname: in.Nickname,
+		Sex:      int(in.Sex),
+		Status:   1, // 正常状态
 	}
 
 	err = l.svcCtx.UserModel.Create(l.ctx, user)

+ 34 - 0
internal/middleware/auth/userauth.go

@@ -0,0 +1,34 @@
+package auth
+
+import (
+	"context"
+	"slowwild/internal/constants"
+
+	"google.golang.org/grpc"
+	"google.golang.org/grpc/metadata"
+)
+
+// GetUserIDFromContext 从context中获取用户ID
+func GetUserIDFromContext(ctx context.Context) (string, bool) {
+	userID, ok := ctx.Value(constants.UserIDKey).(string)
+	return userID, ok
+}
+
+// NewUserAuthInterceptor 创建一个新的用户认证拦截器
+func NewUserAuthInterceptor() grpc.UnaryServerInterceptor {
+	return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
+		md, ok := metadata.FromIncomingContext(ctx)
+		if !ok {
+			return handler(ctx, req)
+		}
+
+		// 获取user_id
+		userIDs := md.Get("user_id")
+		if len(userIDs) > 0 {
+			// 将user_id写入context
+			ctx = context.WithValue(ctx, constants.UserIDKey, userIDs[0])
+		}
+
+		return handler(ctx, req)
+	}
+}

+ 15 - 19
internal/model/user_model.go

@@ -2,7 +2,6 @@ package model
 
 import (
 	"context"
-	"time"
 
 	"gorm.io/gorm"
 )
@@ -10,28 +9,25 @@ import (
 // User 用户模型
 type User struct {
 	*Model
-	Username        string    `gorm:"column:username;type:varchar(32);not null"`
-	Phone           string    `gorm:"column:phone;type:varchar(11);not null"`
-	Password        string    `gorm:"column:password;type:varchar(128);not null"`
-	Avatar          string    `gorm:"column:avatar;type:varchar(255)"`
-	Gender          int32     `gorm:"column:gender;type:tinyint(1);default:0"` // 0:未知 1:男 2:女
-	Status          int32     `gorm:"column:status;type:tinyint(1);default:1"` // 1:正常 2:禁用
-	LastLoginIp     string    `gorm:"column:last_login_ip;type:varchar(15)"`
-	CreateTime      time.Time `gorm:"column:create_time;autoCreateTime"`
-	UpdateTime      time.Time `gorm:"column:update_time;autoUpdateTime"`
-	Nickname        string    `gorm:"column:nickname;NOT NULL"`                   // 昵称
-	Salt            string    `gorm:"column:salt;NOT NULL"`                       // 盐值
-	Sex             int       `gorm:"column:sex;default:0;NOT NULL"`              // 性别 0 女、1 男
-	LikeCount       int       `gorm:"column:like_count;default:0;NOT NULL"`       // 收获的点赞数量
-	TweetCount      int       `gorm:"column:tweet_count;default:0;NOT NULL"`      // 帖子数量
-	CollectionCount int       `gorm:"column:collection_count;default:0;NOT NULL"` // 收获的收藏数量
-	FollowCount     int       `gorm:"column:follow_count;default:0;NOT NULL"`     // 关注数量
-	FansCount       int       `gorm:"column:fans_count;default:0;NOT NULL"`       // 粉丝数量
+	Username        string `gorm:"column:username;type:varchar(32);not null"`
+	Phone           string `gorm:"column:phone;type:varchar(11);not null"`
+	Password        string `gorm:"column:password;type:varchar(128);not null"`
+	Avatar          string `gorm:"column:avatar;type:varchar(255)"`
+	Status          int32  `gorm:"column:status;type:tinyint(1);default:1"` // 1:正常 2:禁用
+	LastLoginIp     string `gorm:"column:last_login_ip;type:varchar(15)"`
+	Nickname        string `gorm:"column:nickname;NOT NULL"`                   // 昵称
+	Salt            string `gorm:"column:salt;NOT NULL"`                       // 盐值
+	Sex             int    `gorm:"column:sex;default:0;NOT NULL"`              // 性别 0 女、1 男
+	LikeCount       int    `gorm:"column:like_count;default:0;NOT NULL"`       // 收获的点赞数量
+	TweetCount      int    `gorm:"column:tweet_count;default:0;NOT NULL"`      // 帖子数量
+	CollectionCount int    `gorm:"column:collection_count;default:0;NOT NULL"` // 收获的收藏数量
+	FollowCount     int    `gorm:"column:follow_count;default:0;NOT NULL"`     // 关注数量
+	FansCount       int    `gorm:"column:fans_count;default:0;NOT NULL"`       // 粉丝数量
 }
 
 // TableName 表名
 func (User) TableName() string {
-	return "user"
+	return "p_user"
 }
 
 // FindByIds 通过ID列表批量查询用户

+ 4 - 0
slowwild.go

@@ -5,6 +5,7 @@ import (
 	"fmt"
 
 	"slowwild/internal/config"
+	"slowwild/internal/middleware/auth"
 	"slowwild/internal/server"
 	"slowwild/internal/svc"
 
@@ -34,6 +35,9 @@ func main() {
 	})
 	defer s.Stop()
 
+	// 添加用户认证拦截器
+	s.AddUnaryInterceptors(auth.NewUserAuthInterceptor())
+
 	fmt.Printf("Starting rpc server at %s...\n", c.ListenOn)
 	s.Start()
 }