123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263 |
- package ws_helper
- import (
- "bytes"
- "encoding/json"
- "net/http"
- "sync"
- "time"
- ws2 "github.com/allanpk716/ChineseSubFinder/pkg/types/backend/ws"
- "github.com/allanpk716/ChineseSubFinder/pkg/common"
- "github.com/gorilla/websocket"
- "github.com/sirupsen/logrus"
- )
// Timing and sizing parameters for each client's WebSocket connection.
const (
	// writeWait bounds how long any single write (message, ping or close frame) may take.
	writeWait = 10 * time.Second

	// pongWait is the read deadline; it is pushed forward every time a pong arrives.
	pongWait = 60 * time.Second

	// pingPeriod is how often a ping is sent. It must stay below pongWait so that a
	// healthy peer's pong arrives before the read deadline expires.
	pingPeriod = pongWait * 9 / 10

	// maxMessageSize is the largest message, in bytes, accepted from the peer.
	maxMessageSize = 5 << 10

	// bufSize is the capacity of the buffered send channel created for each client.
	bufSize = 5 << 10

	// I/O buffer sizes, in bytes, handed to the websocket.Upgrader.
	upGraderReadBufferSize  = 5 << 10
	upGraderWriteBufferSize = 5 << 10
)
- var upGrader = websocket.Upgrader{
- ReadBufferSize: upGraderReadBufferSize,
- WriteBufferSize: upGraderWriteBufferSize,
- CheckOrigin: func(r *http.Request) bool {
- return true
- },
- }
// Client is one WebSocket connection served by this package. ServeWs registers it
// with a Hub and runs readPump and writePump for it in separate goroutines.
type Client struct {
	log              *logrus.Logger
	hub              *Hub
	conn             *websocket.Conn // the upgraded WebSocket connection to this client (created in ServeWs)
	sendLogLineIndex int             // how far (up to which log line) logs have been sent to this client; not read in this file
	authed           bool            // whether the client has passed token authentication (set by readPump)
	send             chan []byte     // outbound payloads for this client, drained by writePump
	closeOnce        sync.Once       // ensures close() performs its teardown only once
}
- func NewClient(log *logrus.Logger, hub *Hub, conn *websocket.Conn, sendLogLineIndex int, authed bool, send chan []byte) *Client {
- return &Client{log: log, hub: hub, conn: conn, sendLogLineIndex: sendLogLineIndex, authed: authed, send: send}
- }
- func (c *Client) close() {
- c.closeOnce.Do(func() {
- c.hub.unregister <- c
- _ = c.conn.Close()
- })
- }
- // 接收 Client 发送来的消息
- func (c *Client) readPump() {
- defer func() {
- if err := recover(); err != nil {
- c.log.Debugln("readPump.recover", err)
- }
- }()
- defer func() {
- // 触发移除 client 的逻辑
- //c.hub.unregister <- c
- c.close()
- }()
- var err error
- var message []byte
- c.conn.SetReadLimit(maxMessageSize)
- err = c.conn.SetReadDeadline(time.Now().Add(pongWait))
- if err != nil {
- c.log.Debugln("readPump.SetReadDeadline", err)
- return
- }
- c.conn.SetPongHandler(func(string) error {
- return c.conn.SetReadDeadline(time.Now().Add(pongWait))
- })
- // 收取 client 发送过来的消息
- for {
- _, message, err = c.conn.ReadMessage()
- if err != nil {
- if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
- c.log.Debugln("readPump.IsUnexpectedCloseError", err)
- }
- return
- }
- revMessage := ws2.BaseMessage{}
- err = json.Unmarshal(message, &revMessage)
- if err != nil {
- c.log.Debugln("readPump.BaseMessage.parse", err)
- return
- }
- if c.authed == false {
- // 如果没有经过认证,那么第一条一定需要判断是认证的消息
- if revMessage.Type != ws2.Auth.String() {
- // 提掉线
- return
- }
- // 认证
- login := ws2.Login{}
- err = json.Unmarshal([]byte(revMessage.Data), &login)
- if err != nil {
- c.log.Debugln("readPump.Login.parse", err)
- return
- }
- if login.Token != common.GetAccessToken() {
- // 登录 Token 不对
- // 发送 token 失败的消息
- outBytes, err := AuthReply(ws2.AuthError)
- if err != nil {
- c.log.Debugln("readPump.AuthReply", err)
- return
- }
- c.send <- outBytes
- // 直接退出可能会导致发送的队列没有清空,这里单独判断一条特殊的命令,收到 Write 线程就退出
- c.send <- ws2.CloseThisConnect
- } else {
- // Token 通过
- outBytes, err := AuthReply(ws2.AuthOk)
- if err != nil {
- c.log.Debugln("readPump.AuthReply", err)
- return
- }
- c.send <- outBytes
- c.authed = true
- }
- } else {
- // 进过认证后的消息,无需再次带有 token 信息
- }
- }
- }
- // 向 Client 发送消息的队列
- func (c *Client) writePump() {
- defer func() {
- if err := recover(); err != nil {
- c.log.Debugln("writePump.recover", err)
- }
- }()
- // 心跳计时器
- pingTicker := time.NewTicker(pingPeriod)
- defer func() {
- pingTicker.Stop()
- c.close()
- }()
- for {
- select {
- case message, ok := <-c.send:
- if bytes.Equal(message, ws2.CloseThisConnect) == true {
- return
- }
- // 这里是需要发送给 client 的消息
- // 当然首先还是得先把当前消息的发送超时,给确定下来
- err := c.conn.SetWriteDeadline(time.Now().Add(writeWait))
- if err != nil {
- c.log.Debugln("writePump.SetWriteDeadline", err)
- return
- }
- if ok == false {
- // The hub closed the channel.
- err = c.conn.WriteMessage(websocket.CloseMessage, []byte{})
- if err != nil {
- c.log.Debugln("writePump close hub WriteMessage", err)
- }
- return
- }
- w, err := c.conn.NextWriter(websocket.TextMessage)
- if err != nil {
- c.log.Debugln("writePump.NextWriter", err)
- return
- }
- _, err = w.Write(message)
- if err != nil {
- c.log.Debugln("writePump.Write", err)
- return
- }
- if err := w.Close(); err != nil {
- c.log.Debugln("writePump.Close", err)
- return
- }
- case <-pingTicker.C:
- // 心跳相关,这里是定时器到了触发的间隔,设置发送下一条心跳的超时时间
- if err := c.conn.SetWriteDeadline(time.Now().Add(writeWait)); err != nil {
- c.log.Debugln("writePump.pingTicker.C.SetWriteDeadline", err)
- return
- }
- // 然后发送心跳
- if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil {
- c.log.Debugln("writePump.pingTicker.C.WriteMessage", err)
- return
- }
- }
- }
- }
- // AuthReply 生成认证通过的回复数据
- func AuthReply(inType ws2.AuthMessage) ([]byte, error) {
- var err error
- var outData, outBytes []byte
- outData, err = json.Marshal(&ws2.Reply{
- Message: inType.String(),
- })
- if err != nil {
- return nil, err
- }
- outBytes, err = ws2.NewBaseMessage(ws2.CommonReply.String(), string(outData)).Bytes()
- if err != nil {
- return nil, err
- }
- return outBytes, nil
- }
- // ServeWs 每个 Client 连接 ws 上线时触发
- func ServeWs(log *logrus.Logger, hub *Hub, w http.ResponseWriter, r *http.Request) {
- conn, err := upGrader.Upgrade(w, r, nil)
- if err != nil {
- log.Errorln("ServeWs.Upgrade", err)
- return
- }
- client := NewClient(
- log,
- hub,
- conn,
- 0,
- false,
- make(chan []byte, bufSize),
- )
- client.hub.register <- client
- go client.writePump()
- go client.readPump()
- }
|