client.go 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164
  1. package channel
  2. import (
  3. "errors"
  4. "github.com/gorilla/websocket"
  5. "message-pusher/common"
  6. "message-pusher/model"
  7. "sync"
  8. "time"
  9. )
  // Timing and size limits applied to every WebSocket client connection.
  const (
  	// writeWait is added to "now" to form the write deadline set before each write.
  	writeWait = 10 * time.Second
  	// pongWait is the read deadline window: it is applied when reading starts
  	// and re-applied each time a pong arrives, so the connection is treated as
  	// dead if no pong is seen within this span.
  	pongWait = 60 * time.Second
  	// pingPeriod is how often the writer pings the peer. It is kept below
  	// pongWait so a pong can arrive before the read deadline expires.
  	pingPeriod = pongWait * 9 / 10
  	// maxMessageSize is the read limit, in bytes, for a single inbound message.
  	maxMessageSize = 512
  )
  16. type webSocketClient struct {
  17. userId int
  18. conn *websocket.Conn
  19. message chan *model.Message
  20. pong chan bool
  21. stop chan bool
  22. timestamp int64
  23. }
  24. func (c *webSocketClient) handleDataReading() {
  25. c.conn.SetReadLimit(maxMessageSize)
  26. _ = c.conn.SetReadDeadline(time.Now().Add(pongWait))
  27. c.conn.SetPongHandler(func(string) error {
  28. return c.conn.SetReadDeadline(time.Now().Add(pongWait))
  29. })
  30. for {
  31. messageType, _, err := c.conn.ReadMessage()
  32. if err != nil {
  33. if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseNoStatusReceived, websocket.CloseAbnormalClosure) {
  34. common.SysError("error read WebSocket client: " + err.Error())
  35. }
  36. c.close()
  37. break
  38. }
  39. switch messageType {
  40. case websocket.PingMessage:
  41. c.pong <- true
  42. case websocket.CloseMessage:
  43. c.close()
  44. break
  45. }
  46. }
  47. }
  48. func (c *webSocketClient) handleDataWriting() {
  49. pingTicker := time.NewTicker(pingPeriod)
  50. defer func() {
  51. pingTicker.Stop()
  52. clientConnMapMutex.Lock()
  53. client, ok := clientMap[c.userId]
  54. // otherwise we may delete the new added client!
  55. if ok && client.timestamp == c.timestamp {
  56. delete(clientMap, c.userId)
  57. }
  58. clientConnMapMutex.Unlock()
  59. err := c.conn.Close()
  60. if err != nil {
  61. common.SysError("error close WebSocket client: " + err.Error())
  62. }
  63. }()
  64. for {
  65. select {
  66. case message := <-c.message:
  67. _ = c.conn.SetWriteDeadline(time.Now().Add(writeWait))
  68. err := c.conn.WriteJSON(message)
  69. if err != nil {
  70. common.SysError("error write data to WebSocket client: " + err.Error())
  71. return
  72. }
  73. case <-c.pong:
  74. err := c.conn.WriteMessage(websocket.PongMessage, nil)
  75. if err != nil {
  76. common.SysError("error send pong to WebSocket client: " + err.Error())
  77. return
  78. }
  79. case <-pingTicker.C:
  80. _ = c.conn.SetWriteDeadline(time.Now().Add(writeWait))
  81. err := c.conn.WriteMessage(websocket.PingMessage, nil)
  82. if err != nil {
  83. common.SysError("error write data to WebSocket client: " + err.Error())
  84. return
  85. }
  86. case <-c.stop:
  87. err := c.conn.WriteMessage(websocket.CloseMessage, nil)
  88. if err != nil {
  89. common.SysError("error write data to WebSocket client: " + err.Error())
  90. }
  91. return
  92. }
  93. }
  94. }
  95. func (c *webSocketClient) sendMessage(message *model.Message) {
  96. c.message <- message
  97. }
  98. func (c *webSocketClient) close() {
  99. // should only be called once
  100. c.stop <- true
  101. // the defer function in handleDataWriting will do the cleanup
  102. }
  103. var clientMap map[int]*webSocketClient
  104. var clientConnMapMutex sync.Mutex
  105. func init() {
  106. clientConnMapMutex.Lock()
  107. clientMap = make(map[int]*webSocketClient)
  108. clientConnMapMutex.Unlock()
  109. }
  110. func RegisterClient(userId int, conn *websocket.Conn) {
  111. clientConnMapMutex.Lock()
  112. oldClient, existed := clientMap[userId]
  113. clientConnMapMutex.Unlock()
  114. if existed {
  115. byeMessage := &model.Message{
  116. Title: common.SystemName,
  117. Description: "其他客户端已连接服务器,本客户端已被挤下线!",
  118. }
  119. oldClient.sendMessage(byeMessage)
  120. oldClient.close()
  121. }
  122. helloMessage := &model.Message{
  123. Title: common.SystemName,
  124. Description: "客户端连接成功!",
  125. }
  126. newClient := &webSocketClient{
  127. userId: userId,
  128. conn: conn,
  129. message: make(chan *model.Message),
  130. pong: make(chan bool),
  131. stop: make(chan bool),
  132. timestamp: time.Now().UnixMilli(),
  133. }
  134. go newClient.handleDataWriting()
  135. go newClient.handleDataReading()
  136. defer newClient.sendMessage(helloMessage)
  137. clientConnMapMutex.Lock()
  138. clientMap[userId] = newClient
  139. clientConnMapMutex.Unlock()
  140. }
  141. func SendClientMessage(message *model.Message, user *model.User) error {
  142. if user.ClientSecret == "" {
  143. return errors.New("未配置 WebSocket 客户端消息推送方式")
  144. }
  145. clientConnMapMutex.Lock()
  146. client, existed := clientMap[user.Id]
  147. clientConnMapMutex.Unlock()
  148. if !existed {
  149. return errors.New("客户端未连接")
  150. }
  151. client.sendMessage(message)
  152. return nil
  153. }