浏览代码

feat: add support for client channel

JustSong 2 年之前
父节点
当前提交
1160489c7b
共有 9 个文件被更改,包括 167 次插入0 次删除
  1. 82 0
      channel/client.go
  2. 3 0
      channel/main.go
  3. 1 0
      controller/user.go
  4. 41 0
      controller/websocket.go
  5. 1 0
      go.mod
  6. 2 0
      go.sum
  7. 1 0
      model/user.go
  8. 1 0
      router/api-router.go
  9. 35 0
      web/src/components/PushSetting.js

+ 82 - 0
channel/client.go

@@ -0,0 +1,82 @@
+package channel
+
+import (
+	"errors"
+	"github.com/gorilla/websocket"
+	"message-pusher/common"
+	"message-pusher/model"
+	"sync"
+)
+
+var clientConnMap map[int]*websocket.Conn
+var clientConnMapMutex sync.Mutex
+
+func init() {
+	clientConnMapMutex.Lock()
+	clientConnMap = make(map[int]*websocket.Conn)
+	clientConnMapMutex.Unlock()
+}
+
+func SendMessageWithConn(message *Message, conn *websocket.Conn) error {
+	return conn.WriteJSON(message)
+}
+
+func LogoutClient(userId int) {
+	clientConnMapMutex.Lock()
+	delete(clientConnMap, userId)
+	clientConnMapMutex.Unlock()
+}
+
+func RegisterClient(userId int, conn *websocket.Conn) {
+	clientConnMapMutex.Lock()
+	oldConn, existed := clientConnMap[userId]
+	clientConnMapMutex.Unlock()
+	if existed {
+		byeMessage := &Message{
+			Title:       common.SystemName,
+			Description: "其他客户端已连接服务器,本客户端已被挤下线!",
+		}
+		err := SendMessageWithConn(byeMessage, oldConn)
+		if err != nil {
+			common.SysError("error send message to client: " + err.Error())
+		}
+		err = oldConn.Close()
+		if err != nil {
+			common.SysError("error close WebSocket connection: " + err.Error())
+		}
+	}
+	helloMessage := &Message{
+		Title:       common.SystemName,
+		Description: "客户端连接成功!",
+	}
+	err := SendMessageWithConn(helloMessage, conn)
+	if err != nil {
+		common.SysError("error send message to client: " + err.Error())
+		return
+	} else {
+		clientConnMapMutex.Lock()
+		clientConnMap[userId] = conn
+		clientConnMapMutex.Unlock()
+		conn.SetCloseHandler(func(code int, text string) error {
+			LogoutClient(userId)
+			return nil
+		})
+	}
+}
+
+func SendClientMessage(message *Message, user *model.User) error {
+	if user.ClientSecret == "" {
+		return errors.New("未配置 WebSocket 客户端消息推送方式")
+	}
+	clientConnMapMutex.Lock()
+	conn, existed := clientConnMap[user.Id]
+	clientConnMapMutex.Unlock()
+	if !existed {
+		return errors.New("客户端未连接")
+	}
+	err := SendMessageWithConn(message, conn)
+	if err != nil {
+		LogoutClient(user.Id)
+	}
+	return err
+}

+ 3 - 0
channel/main.go

@@ -14,6 +14,7 @@ const (
 	TypeDing              = "ding"
 	TypeTelegram          = "telegram"
 	TypeBark              = "bark"
+	TypeClient            = "client"
 )
 
 type Message struct {
@@ -42,6 +43,8 @@ func (message *Message) Send(user *model.User) error {
 		return SendDingMessage(message, user)
 	case TypeBark:
 		return SendBarkMessage(message, user)
+	case TypeClient:
+		return SendClientMessage(message, user)
 	default:
 		return errors.New("不支持的消息通道:" + message.Channel)
 	}

+ 1 - 0
controller/user.go

@@ -416,6 +416,7 @@ func UpdateSelf(c *gin.Context) {
 		DingWebhookSecret:                  user.DingWebhookSecret,
 		BarkServer:                         user.BarkServer,
 		BarkSecret:                         user.BarkSecret,
+		ClientSecret:                       user.ClientSecret,
 	}
 	channel.TokenStoreUpdateUser(&cleanUser, originUser)
 

+ 41 - 0
controller/websocket.go

@@ -0,0 +1,41 @@
+package controller
+
+import (
+	"github.com/gin-gonic/gin"
+	"github.com/gorilla/websocket"
+	"message-pusher/channel"
+	"message-pusher/model"
+	"net/http"
+)
+
+var upgrader = websocket.Upgrader{} // use default options
+
+func RegisterClient(c *gin.Context) {
+	secret := c.Query("secret")
+	if secret == "" {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": "secret 为空",
+		})
+		return
+	}
+	user := model.User{Username: c.Param("username")}
+	err := user.FillUserByUsername()
+	if secret != user.ClientSecret {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": "用户名与密钥不匹配",
+		})
+		return
+	}
+	conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
+	if err != nil {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": err.Error(),
+		})
+		return
+	}
+	channel.RegisterClient(user.Id, conn)
+	return
+}

+ 1 - 0
go.mod

@@ -32,6 +32,7 @@ require (
 	github.com/gorilla/context v1.1.1 // indirect
 	github.com/gorilla/securecookie v1.1.1 // indirect
 	github.com/gorilla/sessions v1.2.1 // indirect
+	github.com/gorilla/websocket v1.5.0 // indirect
 	github.com/jinzhu/inflection v1.0.0 // indirect
 	github.com/jinzhu/now v1.1.5 // indirect
 	github.com/json-iterator/go v1.1.12 // indirect

+ 2 - 0
go.sum

@@ -54,6 +54,8 @@ github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+
 github.com/gorilla/sessions v1.1.1/go.mod h1:8KCfur6+4Mqcc6S0FEfKuN15Vl5MgXW92AE8ovaJD0w=
 github.com/gorilla/sessions v1.2.1 h1:DHd3rPN5lE3Ts3D8rKkQ8x/0kqfeNmBAaiSi+o7FsgI=
 github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM=
+github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
+github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
 github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
 github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
 github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=

+ 1 - 0
model/user.go

@@ -38,6 +38,7 @@ type User struct {
 	DingWebhookSecret                  string `json:"ding_webhook_secret"`
 	BarkServer                         string `json:"bark_server"`
 	BarkSecret                         string `json:"bark_secret"`
+	ClientSecret                       string `json:"client_secret"`
 }
 
 func GetMaxUserId() int {

+ 1 - 0
router/api-router.go

@@ -15,6 +15,7 @@ func SetApiRouter(router *gin.Engine) {
 		apiRouter.GET("/about", controller.GetAbout)
 		apiRouter.GET("/verification", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendEmailVerification)
 		apiRouter.GET("/reset_password", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendPasswordResetEmail)
+		apiRouter.GET("/register_client/:username", middleware.CriticalRateLimit(), controller.RegisterClient)
 		apiRouter.POST("/user/reset", middleware.CriticalRateLimit(), controller.ResetPassword)
 		apiRouter.GET("/oauth/github", middleware.CriticalRateLimit(), controller.GitHubOAuth)
 		apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), controller.WeChatAuth)

+ 35 - 0
web/src/components/PushSetting.js

@@ -31,6 +31,7 @@ const PushSetting = () => {
     ding_webhook_secret: '',
     bark_server: '',
     bark_secret: '',
+    client_secret: '',
   });
   let [loading, setLoading] = useState(false);
 
@@ -106,6 +107,9 @@ const PushSetting = () => {
         data.bark_server = removeTrailingSlash(inputs.bark_server);
         data.bark_secret = inputs.bark_secret;
         break;
+      case 'client':
+        data.client_secret = inputs.client_secret;
+        break;
       default:
         showError(`无效的参数:${which}`);
         return;
@@ -151,6 +155,7 @@ const PushSetting = () => {
                 { key: 'lark', text: '飞书群机器人', value: 'lark' },
                 { key: 'ding', text: '钉钉群机器人', value: 'ding' },
                 { key: 'bark', text: 'Bark App', value: 'bark' },
+                { key: 'client', text: 'WebSocket 客户端', value: 'client' },
               ]}
               value={inputs.channel}
               onChange={handleInputChange}
@@ -443,6 +448,36 @@ const PushSetting = () => {
             保存
           </Button>
           <Button onClick={() => test('bark')}>测试</Button>
+          <Divider />
+          <Header as='h3'>
+            WebSocket 客户端设置(client)
+            <Header.Subheader>
+              通过 WebSocket
+              客户端进行推送,可以使用官方客户端实现,或者根据协议自行实现。官方客户端
+              <a
+                target='_blank'
+                href='https://github.com/songquanpeng/personal-assistant'
+              >
+                详见此处
+              </a>
+              。
+            </Header.Subheader>
+          </Header>
+          <Form.Group widths={2}>
+            <Form.Input
+              label='服务器连接密钥'
+              name='client_secret'
+              type='password'
+              onChange={handleInputChange}
+              autoComplete='off'
+              value={inputs.client_secret}
+              placeholder='在此设置服务器连接密钥'
+            />
+          </Form.Group>
+          <Button onClick={() => submit('client')} loading={loading}>
+            保存
+          </Button>
+          <Button onClick={() => test('client')}>测试</Button>
         </Form>
       </Grid.Column>
     </Grid>