getpostlistlogic.go 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188
  1. package logic
  2. import (
  3. "context"
  4. "slowwild/internal/constants"
  5. "slowwild/internal/errorx"
  6. "slowwild/internal/model"
  7. "slowwild/internal/svc"
  8. "git.banshen.xyz/huangguangrong/slow_wild_protobuff/slowwild/slowwildserver"
  9. "strconv"
  10. "strings"
  11. "github.com/spf13/cast"
  12. "github.com/zeromicro/go-zero/core/logx"
  13. )
  14. type GetPostListLogic struct {
  15. ctx context.Context
  16. svcCtx *svc.ServiceContext
  17. logx.Logger
  18. }
  // NewGetPostListLogic returns a GetPostListLogic bound to ctx and svcCtx,
// with its Logger derived from ctx via logx.WithContext.
func NewGetPostListLogic(ctx context.Context, svcCtx *svc.ServiceContext) *GetPostListLogic {
	g := &GetPostListLogic{ctx: ctx, svcCtx: svcCtx}
	g.Logger = logx.WithContext(ctx)
	return g
}
  // GetPostList returns one page of posts for the request's Page/PageSize.
// Each post is enriched with its author, its topic tags, its title, summary
// and media, and, when the caller is identified, whether the caller has liked
// or collected it.
//
// Posts whose author or content record is not found are left out of the page
// instead of failing the request. Backend failures are logged and mapped to
// errorx.ErrPostQueryFailed or errorx.ErrUserQueryFailed. A non-positive Page
// or PageSize yields errorx.ErrInvalidParam.
func (l *GetPostListLogic) GetPostList(in *slowwildserver.GetPostListReq) (*slowwildserver.GetPostListRsp, error) {
	// The caller's ID is optional. When it is absent or not convertible, cast
	// yields 0. The code below treats 0 as an anonymous caller, so the
	// conversion error is intentionally ignored.
	userID, _ := cast.ToInt64E(l.ctx.Value(constants.UserIDKey))

	if in.Page <= 0 || in.PageSize <= 0 {
		return nil, errorx.ErrInvalidParam
	}

	posts, err := l.svcCtx.PostModel.GetPostList(l.ctx, in)
	if err != nil {
		l.Logger.Errorf("获取帖子列表失败: %v", err)
		return nil, errorx.ErrPostQueryFailed
	}

	resp := &slowwildserver.GetPostListRsp{
		List: make([]*slowwildserver.GetPostListItem, 0, len(posts)),
	}
	// An empty page needs no enrichment. Returning early also avoids issuing
	// the batch lookups below with empty ID lists.
	if len(posts) == 0 {
		return resp, nil
	}

	// splitList splits a comma-separated column value. It trims surrounding
	// spaces and drops empty segments, so "" yields no elements (plain
	// strings.Split would yield [""]).
	splitList := func(raw string) []string {
		var parts []string
		for _, part := range strings.Split(raw, ",") {
			if part = strings.TrimSpace(part); part != "" {
				parts = append(parts, part)
			}
		}
		return parts
	}

	// Gather the IDs needed for the batch lookups in one pass. Each post's tag
	// list is parsed once here and reused when building the response.
	// Malformed tag IDs are skipped. User and tag IDs are de-duplicated so the
	// batch queries do not repeat keys.
	postIDs := make([]int64, 0, len(posts))
	userIDs := make([]int64, 0, len(posts))
	seenUsers := make(map[int64]struct{}, len(posts))
	postTagIDs := make([][]int64, len(posts)) // parallel to posts
	var allTagIDs []int64
	seenTags := make(map[int64]struct{})
	for i, post := range posts {
		postIDs = append(postIDs, post.ID)
		if _, ok := seenUsers[post.UserId]; !ok {
			seenUsers[post.UserId] = struct{}{}
			userIDs = append(userIDs, post.UserId)
		}
		for _, raw := range splitList(post.Tags) {
			tagID, parseErr := strconv.ParseInt(raw, 10, 64)
			if parseErr != nil {
				continue
			}
			postTagIDs[i] = append(postTagIDs[i], tagID)
			if _, ok := seenTags[tagID]; !ok {
				seenTags[tagID] = struct{}{}
				allTagIDs = append(allTagIDs, tagID)
			}
		}
	}

	// Batch-load the authors.
	users, err := l.svcCtx.UserModel.FindByIds(l.ctx, userIDs)
	if err != nil {
		l.Logger.Errorf("批量获取用户信息失败: %v", err)
		return nil, errorx.ErrUserQueryFailed
	}
	userMap := make(map[int64]*model.User, len(users))
	for _, user := range users {
		userMap[user.ID] = user
	}

	// Batch-load the topics. Skip the query when no post carries a valid tag ID.
	tagMap := make(map[int64]*model.Tag, len(allTagIDs))
	if len(allTagIDs) > 0 {
		tags, err := l.svcCtx.TagModel.FindByIds(l.ctx, allTagIDs)
		if err != nil {
			l.Logger.Errorf("批量获取话题信息失败: %v", err)
			return nil, errorx.ErrPostQueryFailed
		}
		for _, tag := range tags {
			tagMap[tag.ID] = tag
		}
	}

	// Batch-load the post bodies (title, summary, covers, video).
	postContents, err := l.svcCtx.PostModel.GetPostContents(l.ctx, postIDs)
	if err != nil {
		l.Logger.Errorf("批量获取帖子内容失败: %v", err)
		return nil, errorx.ErrPostQueryFailed
	}

	for i, post := range posts {
		// Skip posts with no author or content record. Checking this first
		// avoids spending the per-post like/collect queries on posts that
		// would be dropped anyway.
		user := userMap[post.UserId]
		if user == nil {
			continue
		}
		content := postContents[post.ID]
		if content == nil {
			continue
		}

		var isLiked, isCollected bool
		if userID > 0 {
			// NOTE(review): two queries per post (N+1 pattern). A batch
			// variant on UserModel would make this a constant number per page.
			isLiked, err = l.svcCtx.UserModel.IsPostLikedByUser(l.ctx, userID, post.ID)
			if err != nil {
				l.Logger.Errorf("查询用户点赞状态失败: %v", err)
				return nil, errorx.ErrPostQueryFailed
			}
			isCollected, err = l.svcCtx.UserModel.IsPostCollectedByUser(l.ctx, userID, post.ID)
			if err != nil {
				l.Logger.Errorf("查询用户收藏状态失败: %v", err)
				return nil, errorx.ErrPostQueryFailed
			}
		}

		// Keep the post's own tag order. Tag IDs with no matching topic are omitted.
		var tagItems []*slowwildserver.TagItem
		for _, tagID := range postTagIDs[i] {
			if tag, ok := tagMap[tagID]; ok {
				tagItems = append(tagItems, &slowwildserver.TagItem{
					Id:   tag.ID,
					Name: tag.Name,
				})
			}
		}

		resp.List = append(resp.List, &slowwildserver.GetPostListItem{
			Id: post.ID,
			User: &slowwildserver.UserInfo{
				Id:       user.ID,
				Nickname: user.Nickname,
				Avatar:   user.Avatar,
				Sex:      int32(user.Sex),
			},
			PostType:       post.PostType,
			Title:          content.Title,
			ContentSummary: content.ContentSummary,
			// An empty cover list yields no images rather than one "" URL.
			Images:          splitList(content.PostCovers),
			VideoUrl:        content.PostVideoUrl,
			VideoCover:      content.PostVideoCover,
			CommentCount:    post.CommentCount,
			CollectionCount: post.CollectionCount,
			UpvoteCount:     post.UpvoteCount,
			ShareCount:      post.ShareCount,
			Tags:            tagItems,
			IpLoc:           post.IpLoc,
			HotNum:          post.HotNum,
			CreatedOn:       post.CreatedOn,
			IsLiked:         isLiked,
			IsCollected:     isCollected,
			// An anonymous caller (ID 0) never owns a post.
			IsMine: userID > 0 && userID == post.UserId,
		})
	}
	return resp, nil
}