Browse Source

feat: use sse to fetch new messages (close #70)

JustSong 2 years ago
parent
commit
8f7e5456e5
6 changed files with 97 additions and 11 deletions
  1. 56 0
      controller/message-sse.go
  2. 2 0
      controller/message.go
  3. 1 2
      main.go
  4. 13 0
      middleware/sse.go
  5. 1 0
      router/api-router.go
  6. 24 9
      web/src/components/MessagesTable.js

+ 56 - 0
controller/message-sse.go

@@ -0,0 +1,56 @@
+package controller
+
+import (
+	"github.com/gin-gonic/gin"
+	"io"
+	"message-pusher/model"
+	"sync"
+)
+
+var messageChanBufferSize = 10
+
+var messageChanStore struct {
+	Map   map[int]*chan *model.Message
+	Mutex sync.RWMutex
+}
+
+func messageChanStoreAdd(messageChan *chan *model.Message, userId int) {
+	messageChanStore.Mutex.Lock()
+	defer messageChanStore.Mutex.Unlock()
+	messageChanStore.Map[userId] = messageChan
+}
+
+func messageChanStoreRemove(userId int) {
+	messageChanStore.Mutex.Lock()
+	defer messageChanStore.Mutex.Unlock()
+	delete(messageChanStore.Map, userId)
+}
+
+func init() {
+	messageChanStore.Map = make(map[int]*chan *model.Message)
+}
+
+func syncMessageToUser(message *model.Message, userId int) {
+	messageChanStore.Mutex.RLock()
+	defer messageChanStore.Mutex.RUnlock()
+	messageChan, ok := messageChanStore.Map[userId]
+	if !ok {
+		return
+	}
+	*messageChan <- message
+}
+
+func GetNewMessages(c *gin.Context) {
+	userId := c.GetInt("id")
+	messageChan := make(chan *model.Message, messageChanBufferSize)
+	messageChanStoreAdd(&messageChan, userId)
+	c.Stream(func(w io.Writer) bool {
+		if msg, ok := <-messageChan; ok {
+			c.SSEvent("message", *msg)
+			return true
+		}
+		return false
+	})
+	messageChanStoreRemove(userId)
+	close(messageChan)
+}

+ 2 - 0
controller/message.go

@@ -185,11 +185,13 @@ func saveAndSendMessage(user *model.User, message *model.Message, channel_ *mode
 		if err != nil {
 		if err != nil {
 			return err
 			return err
 		}
 		}
+		go syncMessageToUser(message, user.Id)
 	} else {
 	} else {
 		if message.Async {
 		if message.Async {
 			return errors.New("异步发送消息需要用户具备消息持久化的权限")
 			return errors.New("异步发送消息需要用户具备消息持久化的权限")
 		}
 		}
 		message.Link = "unsaved" // This is for user to identify whether the message is saved
 		message.Link = "unsaved" // This is for user to identify whether the message is saved
+		go syncMessageToUser(message, user.Id)
 	}
 	}
 	if !message.Async {
 	if !message.Async {
 		err := channel.SendMessage(message, user, channel_)
 		err := channel.SendMessage(message, user, channel_)

+ 1 - 2
main.go

@@ -2,7 +2,6 @@ package main
 
 
 import (
 import (
 	"embed"
 	"embed"
-	"github.com/gin-contrib/gzip"
 	"github.com/gin-contrib/sessions"
 	"github.com/gin-contrib/sessions"
 	"github.com/gin-contrib/sessions/cookie"
 	"github.com/gin-contrib/sessions/cookie"
 	"github.com/gin-contrib/sessions/redis"
 	"github.com/gin-contrib/sessions/redis"
@@ -56,7 +55,7 @@ func main() {
 	// Initialize HTTP server
 	// Initialize HTTP server
 	server := gin.Default()
 	server := gin.Default()
 	server.SetHTMLTemplate(common.LoadTemplate())
 	server.SetHTMLTemplate(common.LoadTemplate())
-	server.Use(gzip.Gzip(gzip.DefaultCompression))
+	//server.Use(gzip.Gzip(gzip.DefaultCompression))  // conflict with sse
 
 
 	// Initialize session store
 	// Initialize session store
 	if common.RedisEnabled {
 	if common.RedisEnabled {

+ 13 - 0
middleware/sse.go

@@ -0,0 +1,13 @@
+package middleware
+
+import "github.com/gin-gonic/gin"
+
+func SetSSEHeaders() func(c *gin.Context) {
+	return func(c *gin.Context) {
+		c.Writer.Header().Set("Content-Type", "text/event-stream")
+		c.Writer.Header().Set("Cache-Control", "no-cache")
+		c.Writer.Header().Set("Connection", "keep-alive")
+		c.Writer.Header().Set("Transfer-Encoding", "chunked")
+		c.Next()
+	}
+}

+ 1 - 0
router/api-router.go

@@ -58,6 +58,7 @@ func SetApiRouter(router *gin.Engine) {
 		messageRoute := apiRouter.Group("/message")
 		messageRoute := apiRouter.Group("/message")
 		{
 		{
 			messageRoute.GET("/", middleware.UserAuth(), controller.GetUserMessages)
 			messageRoute.GET("/", middleware.UserAuth(), controller.GetUserMessages)
+			messageRoute.GET("/stream", middleware.UserAuth(), middleware.SetSSEHeaders(), controller.GetNewMessages)
 			messageRoute.GET("/search", middleware.UserAuth(), controller.SearchMessages)
 			messageRoute.GET("/search", middleware.UserAuth(), controller.SearchMessages)
 			messageRoute.GET("/status/:link", controller.GetMessageStatus)
 			messageRoute.GET("/status/:link", controller.GetMessageStatus)
 			messageRoute.POST("/resend/:id", middleware.UserAuth(), controller.ResendMessage)
 			messageRoute.POST("/resend/:id", middleware.UserAuth(), controller.ResendMessage)

+ 24 - 9
web/src/components/MessagesTable.js

@@ -1,12 +1,5 @@
 import React, { useEffect, useRef, useState } from 'react';
 import React, { useEffect, useRef, useState } from 'react';
-import {
-  Button,
-  Form,
-  Label,
-  Modal,
-  Pagination,
-  Table,
-} from 'semantic-ui-react';
+import { Button, Form, Label, Modal, Pagination, Table } from 'semantic-ui-react';
 import { API, openPage, showError, showSuccess, showWarning } from '../helpers';
 import { API, openPage, showError, showSuccess, showWarning } from '../helpers';
 
 
 import { ITEMS_PER_PAGE } from '../constants';
 import { ITEMS_PER_PAGE } from '../constants';
@@ -61,7 +54,7 @@ const MessagesTable = () => {
     title: '消息标题',
     title: '消息标题',
     description: '消息描述',
     description: '消息描述',
     content: '消息内容',
     content: '消息内容',
-    link: '',
+    link: ''
   }); // Message to be viewed
   }); // Message to be viewed
   const [viewModalOpen, setViewModalOpen] = useState(false);
   const [viewModalOpen, setViewModalOpen] = useState(false);
 
 
@@ -123,6 +116,17 @@ const MessagesTable = () => {
         showError(reason);
         showError(reason);
       });
       });
     checkPermission().then();
     checkPermission().then();
+    const eventSource = new EventSource('/api/message/stream');
+    eventSource.onerror = (e) => {
+      showError('服务端消息推送流连接出错!');
+    };
+    eventSource.onmessage = (e) => {
+      let newMessage = JSON.parse(e.data);
+      insertNewMessage(newMessage);
+    };
+    return () => {
+      eventSource.close();
+    };
   }, []);
   }, []);
 
 
   const viewMessage = async (id) => {
   const viewMessage = async (id) => {
@@ -203,6 +207,17 @@ const MessagesTable = () => {
     setLoading(false);
     setLoading(false);
   };
   };
 
 
+  const insertNewMessage = (message) => {
+    console.log(messages);
+    setMessages(messages => {
+        let newMessages = [message];
+        newMessages.push(...messages);
+        return newMessages;
+      }
+    );
+    setActivePage(1);
+  };
+
   const refresh = async () => {
   const refresh = async () => {
     await loadMessages(0);
     await loadMessages(0);
     setActivePage(1);
     setActivePage(1);