getpostlistlogic.go 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181
  1. package logic
  2. import (
  3. "context"
  4. "slowwild/internal/errorx"
  5. "slowwild/internal/model"
  6. "slowwild/internal/svc"
  7. "git.banshen.xyz/huangguangrong/slow_wild_protobuff/slowwild/slowwildserver"
  8. "strconv"
  9. "strings"
  10. "github.com/zeromicro/go-zero/core/logx"
  11. )
  12. type GetPostListLogic struct {
  13. ctx context.Context
  14. svcCtx *svc.ServiceContext
  15. logx.Logger
  16. }
  17. func NewGetPostListLogic(ctx context.Context, svcCtx *svc.ServiceContext) *GetPostListLogic {
  18. return &GetPostListLogic{
  19. ctx: ctx,
  20. svcCtx: svcCtx,
  21. Logger: logx.WithContext(ctx),
  22. }
  23. }
  24. // 获取帖子列表
  25. func (l *GetPostListLogic) GetPostList(in *slowwildserver.GetPostListReq) (*slowwildserver.GetPostListRsp, error) {
  26. if in.Page <= 0 || in.PageSize <= 0 {
  27. return nil, errorx.ErrInvalidParam
  28. }
  29. posts, err := l.svcCtx.PostModel.GetPostList(l.ctx, in)
  30. if err != nil {
  31. l.Logger.Errorf("获取帖子列表失败: %v", err)
  32. return nil, errorx.ErrPostQueryFailed
  33. }
  34. // 构建响应
  35. resp := &slowwildserver.GetPostListRsp{
  36. List: make([]*slowwildserver.GetPostListItem, 0, len(posts)),
  37. }
  38. // 收集所有用户ID
  39. userIds := make([]int64, 0, len(posts))
  40. for _, post := range posts {
  41. userIds = append(userIds, post.UserId)
  42. }
  43. // 批量获取用户信息
  44. users, err := l.svcCtx.UserModel.FindByIds(l.ctx, userIds)
  45. if err != nil {
  46. l.Logger.Errorf("批量获取用户信息失败: %v", err)
  47. return nil, errorx.ErrUserQueryFailed
  48. }
  49. // 构建用户信息map
  50. userMap := make(map[int64]*model.User)
  51. for _, user := range users {
  52. userMap[user.ID] = user
  53. }
  54. // 获取所有话题ID
  55. var allTagIds []int64
  56. for _, post := range posts {
  57. if post.Tags != "" {
  58. tagStrs := strings.Split(post.Tags, ",")
  59. for _, tagStr := range tagStrs {
  60. tagId, err := strconv.ParseInt(tagStr, 10, 64)
  61. if err != nil {
  62. continue
  63. }
  64. allTagIds = append(allTagIds, tagId)
  65. }
  66. }
  67. }
  68. // 批量获取话题信息
  69. tags, err := l.svcCtx.TagModel.FindByIds(l.ctx, allTagIds)
  70. if err != nil {
  71. l.Logger.Errorf("批量获取话题信息失败: %v", err)
  72. return nil, errorx.ErrPostQueryFailed
  73. }
  74. // 构建话题信息map
  75. tagMap := make(map[int64]*model.Tag)
  76. for _, tag := range tags {
  77. tagMap[tag.ID] = tag
  78. }
  79. // 收集所有帖子ID
  80. postIds := make([]int64, 0, len(posts))
  81. for _, post := range posts {
  82. postIds = append(postIds, post.ID)
  83. }
  84. // 批量获取帖子内容
  85. postContents, err := l.svcCtx.PostModel.GetPostContents(l.ctx, postIds)
  86. if err != nil {
  87. l.Logger.Errorf("批量获取帖子内容失败: %v", err)
  88. return nil, errorx.ErrPostQueryFailed
  89. }
  90. for _, post := range posts {
  91. // 查询用户状态
  92. isLiked, err := l.svcCtx.UserModel.IsPostLikedByUser(l.ctx, in.UserId, post.ID)
  93. if err != nil {
  94. l.Logger.Errorf("查询用户点赞状态失败: %v", err)
  95. return nil, errorx.ErrPostQueryFailed
  96. }
  97. isCollected, err := l.svcCtx.UserModel.IsPostCollectedByUser(l.ctx, in.UserId, post.ID)
  98. if err != nil {
  99. l.Logger.Errorf("查询用户收藏状态失败: %v", err)
  100. return nil, errorx.ErrPostQueryFailed
  101. }
  102. // 获取话题信息
  103. var tagItems []*slowwildserver.TagItem
  104. if post.Tags != "" {
  105. tagStrs := strings.Split(post.Tags, ",")
  106. for _, tagStr := range tagStrs {
  107. tagId, err := strconv.ParseInt(tagStr, 10, 64)
  108. if err != nil {
  109. continue
  110. }
  111. if tag, ok := tagMap[tagId]; ok {
  112. tagItems = append(tagItems, &slowwildserver.TagItem{
  113. Id: tag.ID,
  114. Name: tag.Name,
  115. })
  116. }
  117. }
  118. }
  119. // 获取用户信息
  120. user := userMap[post.UserId]
  121. if user == nil {
  122. continue
  123. }
  124. // 获取帖子内容
  125. content := postContents[post.ID]
  126. if content == nil {
  127. continue
  128. }
  129. resp.List = append(resp.List, &slowwildserver.GetPostListItem{
  130. Id: post.ID,
  131. User: &slowwildserver.UserInfo{
  132. Id: user.ID,
  133. Nickname: user.Nickname,
  134. Avatar: user.Avatar,
  135. Sex: int32(user.Sex),
  136. },
  137. PostType: post.PostType,
  138. Title: content.Title,
  139. ContentSummary: content.ContentSummary,
  140. Images: strings.Split(content.PostCovers, ","),
  141. VideoUrl: content.PostVideoUrl,
  142. VideoCover: content.PostVideoCover,
  143. CommentCount: post.CommentCount,
  144. CollectionCount: post.CollectionCount,
  145. UpvoteCount: post.UpvoteCount,
  146. ShareCount: post.ShareCount,
  147. Tags: tagItems,
  148. IpLoc: post.IpLoc,
  149. HotNum: post.HotNum,
  150. CreatedOn: post.CreatedOn,
  151. IsLiked: isLiked,
  152. IsCollected: isCollected,
  153. IsMine: in.UserId == post.UserId,
  154. })
  155. }
  156. return resp, nil
  157. }