浏览代码

feat: token store is done but not tested

JustSong 2 年之前
父节点
当前提交
4810db17d4
共有 9 个文件被更改,包括 270 次插入16 次删除
  1. 11 1
      README.md
  2. 97 0
      channel/token-store.go
  3. 61 0
      channel/wechat-corp-account.go
  4. 60 0
      channel/wechat-test-account.go
  5. 3 3
      common/init.go
  6. 5 0
      common/logger.go
  7. 8 0
      common/utils.go
  8. 4 0
      main.go
  9. 21 12
      model/user.go

+ 11 - 1
README.md

@@ -1,3 +1,13 @@
 # 消息推送服务
 # 消息推送服务
+> 正在用 Go 重写 Message Pusher,敬请期待!
 
 
-README 待重写。
+## TODOs
++ [ ] Token Store 测试(内存泄漏检查,不必要的拷贝的检查)
++ [ ] 添加 & 更新推送配置信息时对 Token Store 进行妥善更新
++ [ ] 微信消息推送 API
++ [ ] 支持飞书
++ [ ] 支持钉钉
++ [ ] 支持 Telegram
++ [ ] 重新编写 README
++ [ ] 推广
++ [ ] 支持从外部系统获取 Token

+ 97 - 0
channel/token-store.go

@@ -0,0 +1,97 @@
+package channel
+
+import (
+	"message-pusher/common"
+	"message-pusher/model"
+	"sync"
+	"time"
+)
+
+type TokenStoreItem interface {
+	Key() string
+	Token() string
+	Refresh()
+}
+
+type tokenStore struct {
+	Map               map[string]*TokenStoreItem
+	Mutex             sync.RWMutex
+	ExpirationSeconds int
+}
+
+var s tokenStore
+
+func TokenStoreInit() {
+	s.Map = make(map[string]*TokenStoreItem)
+	s.ExpirationSeconds = 2 * 60 * 60
+	go func() {
+		users, err := model.GetAllUsers()
+		if err != nil {
+			common.FatalLog(err.Error())
+		}
+		var items []TokenStoreItem
+		for _, user := range users {
+			if user.WeChatTestAccountId != "" {
+				item := &WeChatTestAccountTokenStoreItem{
+					AppID:     user.WeChatTestAccountId,
+					AppSecret: user.WeChatTestAccountSecret,
+				}
+				items = append(items, item)
+			}
+			if user.WeChatCorpAccountId != "" {
+				item := &WeChatCorpAccountTokenStoreItem{
+					CorpId:     user.WeChatCorpAccountId,
+					CorpSecret: user.WeChatCorpAccountSecret,
+					AgentId:    user.WeChatCorpAccountAgentId,
+				}
+				items = append(items, item)
+			}
+		}
+		s.Mutex.RLock()
+		for _, item := range items {
+			s.Map[item.Key()] = &item
+		}
+		s.Mutex.RUnlock()
+		for {
+			s.Mutex.RLock()
+			var tmpMap = make(map[string]*TokenStoreItem)
+			for k, v := range s.Map {
+				tmpMap[k] = v
+			}
+			s.Mutex.RUnlock()
+			for k := range tmpMap {
+				(*tmpMap[k]).Refresh()
+			}
+			s.Mutex.RLock()
+			// we shouldn't directly replace the old map with the new map, cause the old map's keys may already change
+			for k := range s.Map {
+				v, okay := tmpMap[k]
+				if okay {
+					s.Map[k] = v
+				}
+			}
+			sleepDuration := common.Max(s.ExpirationSeconds, 60)
+			s.Mutex.RUnlock()
+			time.Sleep(time.Duration(sleepDuration) * time.Second)
+		}
+	}()
+}
+
+func TokenStoreAddItem(item *TokenStoreItem) {
+	(*item).Refresh()
+	s.Mutex.RLock()
+	s.Map[(*item).Key()] = item
+	s.Mutex.RUnlock()
+}
+
+func TokenStoreRemoveItem(item *TokenStoreItem) {
+	s.Mutex.RLock()
+	delete(s.Map, (*item).Key())
+	s.Mutex.RUnlock()
+}
+
+func TokenStoreGetToken(key string) string {
+	s.Mutex.RLock()
+	defer s.Mutex.RUnlock()
+	return (*s.Map[key]).Token()
+}

+ 61 - 0
channel/wechat-corp-account.go

@@ -1 +1,62 @@
 package channel
 package channel
+
+import (
+	"encoding/json"
+	"fmt"
+	"message-pusher/common"
+	"net/http"
+	"time"
+)
+
+type wechatCorpAccountResponse struct {
+	ErrorCode    int    `json:"errcode"`
+	ErrorMessage string `json:"errmsg"`
+	AccessToken  string `json:"access_token"`
+	ExpiresIn    int    `json:"expires_in"`
+}
+
+type WeChatCorpAccountTokenStoreItem struct {
+	CorpId      string
+	CorpSecret  string
+	AgentId     string
+	AccessToken string
+}
+
+func (i *WeChatCorpAccountTokenStoreItem) Key() string {
+	return i.CorpId + i.AgentId + i.CorpSecret
+}
+
+func (i *WeChatCorpAccountTokenStoreItem) Token() string {
+	return i.AccessToken
+}
+
+func (i *WeChatCorpAccountTokenStoreItem) Refresh() {
+	// https://work.weixin.qq.com/api/doc/90000/90135/91039
+	client := http.Client{
+		Timeout: 5 * time.Second,
+	}
+	req, err := http.NewRequest("GET", fmt.Sprintf("https://qyapi.weixin.qq.com/cgi-bin/gettoken?corpid=%s&corpsecret=%s",
+		i.CorpId, i.CorpSecret), nil)
+	if err != nil {
+		common.SysError(err.Error())
+		return
+	}
+	responseData, err := client.Do(req)
+	if err != nil {
+		common.SysError("failed to refresh access token: " + err.Error())
+		return
+	}
+	defer responseData.Body.Close()
+	var res wechatCorpAccountResponse
+	err = json.NewDecoder(responseData.Body).Decode(&res)
+	if err != nil {
+		common.SysError("failed to decode wechatCorpAccountResponse: " + err.Error())
+		return
+	}
+	if res.ErrorCode != 0 {
+		common.SysError(res.ErrorMessage)
+		return
+	}
+	i.AccessToken = res.AccessToken
+	common.SysLog("access token refreshed")
+}

+ 60 - 0
channel/wechat-test-account.go

@@ -1 +1,61 @@
 package channel
 package channel
+
+import (
+	"encoding/json"
+	"fmt"
+	"message-pusher/common"
+	"net/http"
+	"time"
+)
+
+type wechatTestAccountResponse struct {
+	ErrorCode    int    `json:"errcode"`
+	ErrorMessage string `json:"errmsg"`
+	AccessToken  string `json:"access_token"`
+	ExpiresIn    int    `json:"expires_in"`
+}
+
+type WeChatTestAccountTokenStoreItem struct {
+	AppID       string
+	AppSecret   string
+	AccessToken string
+}
+
+func (i *WeChatTestAccountTokenStoreItem) Key() string {
+	return i.AppID + i.AppSecret
+}
+
+func (i *WeChatTestAccountTokenStoreItem) Token() string {
+	return i.AccessToken
+}
+
+func (i *WeChatTestAccountTokenStoreItem) Refresh() {
+	// https://developers.weixin.qq.com/doc/offiaccount/Basic_Information/Get_access_token.html
+	client := http.Client{
+		Timeout: 5 * time.Second,
+	}
+	req, err := http.NewRequest("GET", fmt.Sprintf("https://api.weixin.qq.com/cgi-bin/token?grant_type=client_credential&appid=%s&secret=%s",
+		i.AppID, i.AppSecret), nil)
+	if err != nil {
+		common.SysError(err.Error())
+		return
+	}
+	responseData, err := client.Do(req)
+	if err != nil {
+		common.SysError("failed to refresh access token: " + err.Error())
+		return
+	}
+	defer responseData.Body.Close()
+	var res wechatTestAccountResponse
+	err = json.NewDecoder(responseData.Body).Decode(&res)
+	if err != nil {
+		common.SysError("failed to decode wechatTestAccountResponse: " + err.Error())
+		return
+	}
+	if res.ErrorCode != 0 {
+		common.SysError(res.ErrorMessage)
+		return
+	}
+	i.AccessToken = res.AccessToken
+	common.SysLog("access token refreshed")
+}

+ 3 - 3
common/init.go

@@ -12,9 +12,9 @@ var (
 	Port         = flag.Int("port", 3000, "the listening port")
 	Port         = flag.Int("port", 3000, "the listening port")
 	PrintVersion = flag.Bool("version", false, "print version and exit")
 	PrintVersion = flag.Bool("version", false, "print version and exit")
 	LogDir       = flag.String("log-dir", "", "specify the log directory")
 	LogDir       = flag.String("log-dir", "", "specify the log directory")
-	//Host         = flag.String("host", "localhost", "the server's ip address or domain")
-	//Path         = flag.String("path", "", "specify a local path to public")
-	//VideoPath    = flag.String("video", "", "specify a video folder to public")
+	//Host         = flag.Key("host", "localhost", "the server's ip address or domain")
+	//Path         = flag.Key("path", "", "specify a local path to public")
+	//VideoPath    = flag.Key("video", "", "specify a video folder to public")
 	//NoBrowser    = flag.Bool("no-browser", false, "open browser or not")
 	//NoBrowser    = flag.Bool("no-browser", false, "open browser or not")
 )
 )
 
 

+ 5 - 0
common/logger.go

@@ -32,6 +32,11 @@ func SysLog(s string) {
 	_, _ = fmt.Fprintf(gin.DefaultWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
 	_, _ = fmt.Fprintf(gin.DefaultWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
 }
 }
 
 
+func SysError(s string) {
+	t := time.Now()
+	_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
+}
+
 func FatalLog(v ...any) {
 func FatalLog(v ...any) {
 	t := time.Now()
 	t := time.Now()
 	_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v)
 	_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v)

+ 8 - 0
common/utils.go

@@ -131,3 +131,11 @@ func GetUUID() string {
 	code = strings.Replace(code, "-", "", -1)
 	code = strings.Replace(code, "-", "", -1)
 	return code
 	return code
 }
 }
+
+func Max(a int, b int) int {
+	if a >= b {
+		return a
+	} else {
+		return b
+	}
+}

+ 4 - 0
main.go

@@ -7,6 +7,7 @@ import (
 	"github.com/gin-contrib/sessions/redis"
 	"github.com/gin-contrib/sessions/redis"
 	"github.com/gin-gonic/gin"
 	"github.com/gin-gonic/gin"
 	"log"
 	"log"
+	"message-pusher/channel"
 	"message-pusher/common"
 	"message-pusher/common"
 	"message-pusher/middleware"
 	"message-pusher/middleware"
 	"message-pusher/model"
 	"message-pusher/model"
@@ -48,6 +49,9 @@ func main() {
 	// Initialize options
 	// Initialize options
 	model.InitOptionMap()
 	model.InitOptionMap()
 
 
+	// Initialize token store
+	channel.TokenStoreInit()
+
 	// Initialize HTTP server
 	// Initialize HTTP server
 	server := gin.Default()
 	server := gin.Default()
 	server.Use(middleware.CORS())
 	server.Use(middleware.CORS())

+ 21 - 12
model/user.go

@@ -6,18 +6,27 @@ import (
 )
 )
 
 
 type User struct {
 type User struct {
-	Id               int    `json:"id"`
-	Username         string `json:"username" gorm:"unique;index" validate:"max=12"`
-	Password         string `json:"password" gorm:"not null;" validate:"min=8,max=20"`
-	DisplayName      string `json:"display_name" gorm:"index" validate:"max=20"`
-	Role             int    `json:"role" gorm:"type:int;default:1"`   // admin, common
-	Status           int    `json:"status" gorm:"type:int;default:1"` // enabled, disabled
-	Token            string `json:"token"`
-	Email            string `json:"email" gorm:"index" validate:"max=50"`
-	GitHubId         string `json:"github_id" gorm:"column:github_id;index"`
-	WeChatId         string `json:"wechat_id" gorm:"column:wechat_id;index"`
-	Channel          string `json:"channel"`
-	VerificationCode string `json:"verification_code" gorm:"-:all"`
+	Id                                 int    `json:"id"`
+	Username                           string `json:"username" gorm:"unique;index" validate:"max=12"`
+	Password                           string `json:"password" gorm:"not null;" validate:"min=8,max=20"`
+	DisplayName                        string `json:"display_name" gorm:"index" validate:"max=20"`
+	Role                               int    `json:"role" gorm:"type:int;default:1"`   // admin, common
+	Status                             int    `json:"status" gorm:"type:int;default:1"` // enabled, disabled
+	Token                              string `json:"token"`
+	Email                              string `json:"email" gorm:"index" validate:"max=50"`
+	GitHubId                           string `json:"github_id" gorm:"column:github_id;index"`
+	WeChatId                           string `json:"wechat_id" gorm:"column:wechat_id;index"`
+	Channel                            string `json:"channel"`
+	VerificationCode                   string `json:"verification_code" gorm:"-:all"`
+	WeChatTestAccountId                string `json:"wechat_test_account_id" gorm:"column:wechat_test_account_id"`
+	WeChatTestAccountSecret            string `json:"wechat_test_account_secret" gorm:"column:wechat_test_account_secret"`
+	WeChatTestAccountTemplateId        string `json:"wechat_test_account_template_id" gorm:"column:wechat_test_account_template_id"`
+	WeChatTestAccountOpenId            string `json:"wechat_test_account_open_id" gorm:"column:wechat_test_account_open_id"`
+	WeChatTestAccountVerificationToken string `json:"wechat_test_account_verification_token" gorm:"column:wechat_test_account_verification_token"`
+	WeChatCorpAccountId                string `json:"wechat_corp_account_id" gorm:"column:wechat_corp_account_id"`
+	WeChatCorpAccountSecret            string `json:"wechat_corp_account_secret" gorm:"column:wechat_corp_account_secret"`
+	WeChatCorpAccountAgentId           string `json:"wechat_corp_account_agent_id" gorm:"column:wechat_corp_account_agent_id"`
+	WeChatCorpAccountUserId            string `json:"wechat_corp_account_user_id" gorm:"column:wechat_corp_account_user_id"`
 }
 }
 
 
 func GetMaxUserId() int {
 func GetMaxUserId() int {