connertionlogic.go 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132
  1. package logic
  2. import (
  3. "context"
  4. "slowwildws/internal/constants"
  5. "slowwildws/internal/svc"
  6. "slowwildws/internal/types"
  7. "sync"
  8. "time"
  9. "github.com/gorilla/websocket"
  10. "github.com/zeromicro/go-zero/core/logx"
  11. )
  12. type ConnectionLogic struct {
  13. logx.Logger
  14. ctx context.Context
  15. svcCtx *svc.ServiceContext
  16. conn *websocket.Conn
  17. idleMu sync.Mutex // 超时锁
  18. idle time.Time // 闲置时间
  19. maxConnectionIdle time.Duration // 最大闲置时间,也就是空闲时间
  20. messageMu sync.Mutex // 消息锁
  21. readMessageList []*types.Message // 读取消息队列
  22. readMessageSeq map[string]*types.Message // 读取消息序列化
  23. done chan struct{} // 结束方法
  24. message chan *types.Message //消息通道
  25. }
  26. func NewConnectionLogic(ctx context.Context, svcCtx *svc.ServiceContext, conn *websocket.Conn) *ConnectionLogic {
  27. return &ConnectionLogic{
  28. Logger: logx.WithContext(ctx),
  29. ctx: ctx,
  30. svcCtx: svcCtx,
  31. conn: conn,
  32. idle: time.Now(),
  33. maxConnectionIdle: time.Duration(svcCtx.Config.MaxConnectionIdle) * time.Second,
  34. done: make(chan struct{}),
  35. readMessageList: make([]*types.Message, 0, 2),
  36. readMessageSeq: make(map[string]*types.Message, 2),
  37. message: make(chan *types.Message, 1), // 给容量为1的话可以确保收发顺序
  38. }
  39. }
  40. // 关闭连接
  41. func (c *ConnectionLogic) Close() error {
  42. select {
  43. case <-c.done:
  44. default:
  45. close(c.done)
  46. }
  47. return c.conn.Close()
  48. }
  49. // 读取消息
  50. func (c *ConnectionLogic) ReadMessage() (messageType int, p []byte, err error) {
  51. // 这里不能先获取锁,因为会被阻塞住,下面这个获取消息的方法是阻塞的,会导致锁一直得不到释放
  52. messageType, p, err = c.conn.ReadMessage()
  53. c.idleMu.Lock()
  54. defer func() {
  55. c.idleMu.Unlock()
  56. }()
  57. c.idle = time.Now()
  58. return
  59. }
  60. // 写消息
  61. func (c *ConnectionLogic) WriteMessage(messageType int, data []byte) error {
  62. c.idleMu.Lock()
  63. defer func() {
  64. c.idleMu.Unlock()
  65. }()
  66. err := c.conn.WriteMessage(messageType, data)
  67. c.idle = time.Now()
  68. return err
  69. }
  70. // 心跳检测
  71. func (c *ConnectionLogic) Keepalive() {
  72. idlerTimer := time.NewTimer(c.maxConnectionIdle)
  73. defer idlerTimer.Stop()
  74. for {
  75. select {
  76. case <-idlerTimer.C:
  77. c.idleMu.Lock()
  78. idle := c.idle
  79. if idle.IsZero() {
  80. idlerTimer.Reset(c.maxConnectionIdle)
  81. c.idleMu.Unlock()
  82. continue
  83. }
  84. val := c.maxConnectionIdle - time.Since(idle)
  85. if val <= 0 {
  86. c.idleMu.Unlock()
  87. c.Close()
  88. return
  89. }
  90. idlerTimer.Reset(val)
  91. c.idleMu.Unlock()
  92. case <-c.done:
  93. return
  94. }
  95. }
  96. }
  97. // 添加消息到队列中
  98. func (c *ConnectionLogic) appendMsgMq(msg *types.Message) {
  99. c.messageMu.Lock()
  100. defer c.messageMu.Unlock()
  101. if m, ok := c.readMessageSeq[msg.Id]; ok {
  102. // 已经有消息记录,该消息已经有ack确认
  103. if len(c.readMessageList) == 0 {
  104. // 队列中没有该消息
  105. return
  106. }
  107. if m.AckSeq >= msg.AckSeq {
  108. // 没有进行ack确认,或者重复
  109. return
  110. }
  111. c.readMessageSeq[msg.Id] = msg
  112. return
  113. }
  114. // 还没有进行ack确认,避免客户端重复发送多余的ack消息
  115. if msg.FrameType == constants.FrameAck {
  116. return
  117. }
  118. c.readMessageList = append(c.readMessageList, msg)
  119. c.readMessageSeq[msg.Id] = msg
  120. }