client.go 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263
  1. package ws_helper
  2. import (
  3. "bytes"
  4. "encoding/json"
  5. "net/http"
  6. "sync"
  7. "time"
  8. ws2 "github.com/allanpk716/ChineseSubFinder/pkg/types/backend/ws"
  9. "github.com/allanpk716/ChineseSubFinder/pkg/common"
  10. "github.com/gorilla/websocket"
  11. "github.com/sirupsen/logrus"
  12. )
  13. const (
  14. // Time allowed to write a message to the peer.
  15. writeWait = 10 * time.Second
  16. // Time allowed to read the next pong message from the peer.
  17. pongWait = 60 * time.Second
  18. // Send pings to peer with this period. Must be less than pongWait.
  19. pingPeriod = (pongWait * 9) / 10
  20. // Maximum message size allowed from peer.
  21. maxMessageSize = 5 * 1024
  22. // 发送 chan 的队列长度
  23. bufSize = 5 * 1024
  24. upGraderReadBufferSize = 5 * 1024
  25. upGraderWriteBufferSize = 5 * 1024
  26. )
  27. var upGrader = websocket.Upgrader{
  28. ReadBufferSize: upGraderReadBufferSize,
  29. WriteBufferSize: upGraderWriteBufferSize,
  30. CheckOrigin: func(r *http.Request) bool {
  31. return true
  32. },
  33. }
// Client is a single websocket connection registered with a Hub.
// ServeWs starts one readPump and one writePump goroutine for each Client.
type Client struct {
	log              *logrus.Logger
	hub              *Hub
	conn             *websocket.Conn // the websocket connection for this client
	sendLogLineIndex int             // how far the log has been sent to this client (not referenced in this file; presumably maintained elsewhere — confirm)
	authed           bool            // whether the client has passed token authentication
	send             chan []byte     // outgoing payloads queued for delivery by writePump
	closeOnce        sync.Once       // makes close() run its shutdown logic at most once
}
  43. func NewClient(log *logrus.Logger, hub *Hub, conn *websocket.Conn, sendLogLineIndex int, authed bool, send chan []byte) *Client {
  44. return &Client{log: log, hub: hub, conn: conn, sendLogLineIndex: sendLogLineIndex, authed: authed, send: send}
  45. }
  46. func (c *Client) close() {
  47. c.closeOnce.Do(func() {
  48. c.hub.unregister <- c
  49. _ = c.conn.Close()
  50. })
  51. }
  52. // 接收 Client 发送来的消息
  53. func (c *Client) readPump() {
  54. defer func() {
  55. if err := recover(); err != nil {
  56. c.log.Debugln("readPump.recover", err)
  57. }
  58. }()
  59. defer func() {
  60. // 触发移除 client 的逻辑
  61. //c.hub.unregister <- c
  62. c.close()
  63. }()
  64. var err error
  65. var message []byte
  66. c.conn.SetReadLimit(maxMessageSize)
  67. err = c.conn.SetReadDeadline(time.Now().Add(pongWait))
  68. if err != nil {
  69. c.log.Debugln("readPump.SetReadDeadline", err)
  70. return
  71. }
  72. c.conn.SetPongHandler(func(string) error {
  73. return c.conn.SetReadDeadline(time.Now().Add(pongWait))
  74. })
  75. // 收取 client 发送过来的消息
  76. for {
  77. _, message, err = c.conn.ReadMessage()
  78. if err != nil {
  79. if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
  80. c.log.Debugln("readPump.IsUnexpectedCloseError", err)
  81. }
  82. return
  83. }
  84. revMessage := ws2.BaseMessage{}
  85. err = json.Unmarshal(message, &revMessage)
  86. if err != nil {
  87. c.log.Debugln("readPump.BaseMessage.parse", err)
  88. return
  89. }
  90. if c.authed == false {
  91. // 如果没有经过认证,那么第一条一定需要判断是认证的消息
  92. if revMessage.Type != ws2.Auth.String() {
  93. // 提掉线
  94. return
  95. }
  96. // 认证
  97. login := ws2.Login{}
  98. err = json.Unmarshal([]byte(revMessage.Data), &login)
  99. if err != nil {
  100. c.log.Debugln("readPump.Login.parse", err)
  101. return
  102. }
  103. if login.Token != common.GetAccessToken() {
  104. // 登录 Token 不对
  105. // 发送 token 失败的消息
  106. outBytes, err := AuthReply(ws2.AuthError)
  107. if err != nil {
  108. c.log.Debugln("readPump.AuthReply", err)
  109. return
  110. }
  111. c.send <- outBytes
  112. // 直接退出可能会导致发送的队列没有清空,这里单独判断一条特殊的命令,收到 Write 线程就退出
  113. c.send <- ws2.CloseThisConnect
  114. } else {
  115. // Token 通过
  116. outBytes, err := AuthReply(ws2.AuthOk)
  117. if err != nil {
  118. c.log.Debugln("readPump.AuthReply", err)
  119. return
  120. }
  121. c.send <- outBytes
  122. c.authed = true
  123. }
  124. } else {
  125. // 进过认证后的消息,无需再次带有 token 信息
  126. }
  127. }
  128. }
  129. // 向 Client 发送消息的队列
  130. func (c *Client) writePump() {
  131. defer func() {
  132. if err := recover(); err != nil {
  133. c.log.Debugln("writePump.recover", err)
  134. }
  135. }()
  136. // 心跳计时器
  137. pingTicker := time.NewTicker(pingPeriod)
  138. defer func() {
  139. pingTicker.Stop()
  140. c.close()
  141. }()
  142. for {
  143. select {
  144. case message, ok := <-c.send:
  145. if bytes.Equal(message, ws2.CloseThisConnect) == true {
  146. return
  147. }
  148. // 这里是需要发送给 client 的消息
  149. // 当然首先还是得先把当前消息的发送超时,给确定下来
  150. err := c.conn.SetWriteDeadline(time.Now().Add(writeWait))
  151. if err != nil {
  152. c.log.Debugln("writePump.SetWriteDeadline", err)
  153. return
  154. }
  155. if ok == false {
  156. // The hub closed the channel.
  157. err = c.conn.WriteMessage(websocket.CloseMessage, []byte{})
  158. if err != nil {
  159. c.log.Debugln("writePump close hub WriteMessage", err)
  160. }
  161. return
  162. }
  163. w, err := c.conn.NextWriter(websocket.TextMessage)
  164. if err != nil {
  165. c.log.Debugln("writePump.NextWriter", err)
  166. return
  167. }
  168. _, err = w.Write(message)
  169. if err != nil {
  170. c.log.Debugln("writePump.Write", err)
  171. return
  172. }
  173. if err := w.Close(); err != nil {
  174. c.log.Debugln("writePump.Close", err)
  175. return
  176. }
  177. case <-pingTicker.C:
  178. // 心跳相关,这里是定时器到了触发的间隔,设置发送下一条心跳的超时时间
  179. if err := c.conn.SetWriteDeadline(time.Now().Add(writeWait)); err != nil {
  180. c.log.Debugln("writePump.pingTicker.C.SetWriteDeadline", err)
  181. return
  182. }
  183. // 然后发送心跳
  184. if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil {
  185. c.log.Debugln("writePump.pingTicker.C.WriteMessage", err)
  186. return
  187. }
  188. }
  189. }
  190. }
  191. // AuthReply 生成认证通过的回复数据
  192. func AuthReply(inType ws2.AuthMessage) ([]byte, error) {
  193. var err error
  194. var outData, outBytes []byte
  195. outData, err = json.Marshal(&ws2.Reply{
  196. Message: inType.String(),
  197. })
  198. if err != nil {
  199. return nil, err
  200. }
  201. outBytes, err = ws2.NewBaseMessage(ws2.CommonReply.String(), string(outData)).Bytes()
  202. if err != nil {
  203. return nil, err
  204. }
  205. return outBytes, nil
  206. }
  207. // ServeWs 每个 Client 连接 ws 上线时触发
  208. func ServeWs(log *logrus.Logger, hub *Hub, w http.ResponseWriter, r *http.Request) {
  209. conn, err := upGrader.Upgrade(w, r, nil)
  210. if err != nil {
  211. log.Errorln("ServeWs.Upgrade", err)
  212. return
  213. }
  214. client := NewClient(
  215. log,
  216. hub,
  217. conn,
  218. 0,
  219. false,
  220. make(chan []byte, bufSize),
  221. )
  222. client.hub.register <- client
  223. go client.writePump()
  224. go client.readPump()
  225. }