client.go 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164
  1. package channel
  2. import (
  3. "errors"
  4. "fmt"
  5. "github.com/gorilla/websocket"
  6. "message-pusher/common"
  7. "message-pusher/model"
  8. "sync"
  9. "time"
  10. )
  11. const (
  12. writeWait = 10 * time.Second
  13. pongWait = 60 * time.Second
  14. pingPeriod = (pongWait * 9) / 10
  15. maxMessageSize = 512
  16. )
  17. type webSocketClient struct {
  18. key string
  19. conn *websocket.Conn
  20. message chan *model.Message
  21. pong chan bool
  22. stop chan bool
  23. timestamp int64
  24. }
  25. func (c *webSocketClient) handleDataReading() {
  26. c.conn.SetReadLimit(maxMessageSize)
  27. _ = c.conn.SetReadDeadline(time.Now().Add(pongWait))
  28. c.conn.SetPongHandler(func(string) error {
  29. return c.conn.SetReadDeadline(time.Now().Add(pongWait))
  30. })
  31. for {
  32. messageType, _, err := c.conn.ReadMessage()
  33. if err != nil {
  34. if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseNoStatusReceived, websocket.CloseAbnormalClosure) {
  35. common.SysError("error read WebSocket client: " + err.Error())
  36. }
  37. c.close()
  38. break
  39. }
  40. switch messageType {
  41. case websocket.PingMessage:
  42. c.pong <- true
  43. case websocket.CloseMessage:
  44. c.close()
  45. break
  46. }
  47. }
  48. }
  49. func (c *webSocketClient) handleDataWriting() {
  50. pingTicker := time.NewTicker(pingPeriod)
  51. defer func() {
  52. pingTicker.Stop()
  53. clientConnMapMutex.Lock()
  54. client, ok := clientMap[c.key]
  55. // otherwise we may delete the new added client!
  56. if ok && client.timestamp == c.timestamp {
  57. delete(clientMap, c.key)
  58. }
  59. clientConnMapMutex.Unlock()
  60. err := c.conn.Close()
  61. if err != nil {
  62. common.SysError("error close WebSocket client: " + err.Error())
  63. }
  64. }()
  65. for {
  66. select {
  67. case message := <-c.message:
  68. _ = c.conn.SetWriteDeadline(time.Now().Add(writeWait))
  69. err := c.conn.WriteJSON(message)
  70. if err != nil {
  71. common.SysError("error write data to WebSocket client: " + err.Error())
  72. return
  73. }
  74. case <-c.pong:
  75. err := c.conn.WriteMessage(websocket.PongMessage, nil)
  76. if err != nil {
  77. common.SysError("error send pong to WebSocket client: " + err.Error())
  78. return
  79. }
  80. case <-pingTicker.C:
  81. _ = c.conn.SetWriteDeadline(time.Now().Add(writeWait))
  82. err := c.conn.WriteMessage(websocket.PingMessage, nil)
  83. if err != nil {
  84. common.SysError("error write data to WebSocket client: " + err.Error())
  85. return
  86. }
  87. case <-c.stop:
  88. err := c.conn.WriteMessage(websocket.CloseMessage, nil)
  89. if err != nil {
  90. common.SysError("error write data to WebSocket client: " + err.Error())
  91. }
  92. return
  93. }
  94. }
  95. }
  96. func (c *webSocketClient) sendMessage(message *model.Message) {
  97. c.message <- message
  98. }
  99. func (c *webSocketClient) close() {
  100. // should only be called once
  101. c.stop <- true
  102. // the defer function in handleDataWriting will do the cleanup
  103. }
  104. var clientMap map[string]*webSocketClient
  105. var clientConnMapMutex sync.Mutex
  106. func init() {
  107. clientConnMapMutex.Lock()
  108. clientMap = make(map[string]*webSocketClient)
  109. clientConnMapMutex.Unlock()
  110. }
  // RegisterClient makes conn the active WebSocket client for userId on
  // channelName. If a client is already registered under the same key, it is
  // sent a notice and disconnected. The new client then receives a greeting.
  func RegisterClient(channelName string, userId int, conn *websocket.Conn) {
  	key := fmt.Sprintf("%s:%d", channelName, userId)
  	newClient := &webSocketClient{
  		key:       key,
  		conn:      conn,
  		message:   make(chan *model.Message),
  		pong:      make(chan bool),
  		stop:      make(chan bool),
  		timestamp: time.Now().UnixMilli(),
  	}

  	// Look up and replace the previous client in one critical section, so two
  	// concurrent registrations for the same key cannot both keep a client.
  	// newClient is published before its goroutines start, so the writer's
  	// exit cleanup can always find the entry and remove it. A connection that
  	// dies immediately therefore does not leave a stale entry behind.
  	clientConnMapMutex.Lock()
  	oldClient, existed := clientMap[key]
  	clientMap[key] = newClient
  	clientConnMapMutex.Unlock()

  	go newClient.handleDataWriting()
  	go newClient.handleDataReading()

  	if existed {
  		oldClient.sendMessage(&model.Message{
  			Title:       common.SystemName,
  			Description: "其他客户端已连接服务器,本客户端已被挤下线!",
  		})
  		oldClient.close()
  	}
  	newClient.sendMessage(&model.Message{
  		Title:       common.SystemName,
  		Description: "客户端连接成功!",
  	})
  }
  // SendClientMessage forwards message to the WebSocket client registered for
  // user on channel_. It returns an error if no such client is in the registry.
  func SendClientMessage(message *model.Message, user *model.User, channel_ *model.Channel) error {
  	lookupKey := fmt.Sprintf("%s:%d", channel_.Name, user.Id)

  	clientConnMapMutex.Lock()
  	target, ok := clientMap[lookupKey]
  	clientConnMapMutex.Unlock()

  	if ok {
  		target.sendMessage(message)
  		return nil
  	}
  	return errors.New("客户端未连接")
  }