Parcourir la source

feat: able to refresh baidu access token automatically (#400, close #401)

* feat:baidu channel support apiKey and secretKey

添加百度文心渠道时支持填写secretKey|apiKey或者accessToken,支持自动刷新accessToken

* fix

* fix

* fix

* chore: update implementation

---------

Co-authored-by: JustSong <[email protected]>
Co-authored-by: JustSong <[email protected]>
igophper il y a 2 ans
Parent
commit
af20063a8d
4 fichiers modifiés avec 86 ajouts et 9 suppressions
  1. 73 6
      controller/relay-baidu.go
  2. 10 1
      controller/relay-text.go
  3. 2 1
      i18n/en.json
  4. 1 1
      web/src/pages/Channel/EditChannel.js

+ 73 - 6
controller/relay-baidu.go

@@ -3,22 +3,22 @@ package controller
 import (
 	"bufio"
 	"encoding/json"
+	"errors"
+	"fmt"
 	"github.com/gin-gonic/gin"
 	"io"
 	"net/http"
 	"one-api/common"
 	"strings"
+	"sync"
+	"time"
 )
 
 // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2
 
 type BaiduTokenResponse struct {
-	RefreshToken  string `json:"refresh_token"`
-	ExpiresIn     int    `json:"expires_in"`
-	SessionKey    string `json:"session_key"`
-	AccessToken   string `json:"access_token"`
-	Scope         string `json:"scope"`
-	SessionSecret string `json:"session_secret"`
+	ExpiresIn   int    `json:"expires_in"`
+	AccessToken string `json:"access_token"`
 }
 
 type BaiduMessage struct {
@@ -73,6 +73,16 @@ type BaiduEmbeddingResponse struct {
 	BaiduError
 }
 
+type BaiduAccessToken struct {
+	AccessToken      string    `json:"access_token"`
+	Error            string    `json:"error,omitempty"`
+	ErrorDescription string    `json:"error_description,omitempty"`
+	ExpiresIn        int64     `json:"expires_in,omitempty"`
+	ExpiresAt        time.Time `json:"-"`
+}
+
+var baiduTokenStore sync.Map
+
 func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest {
 	messages := make([]BaiduMessage, 0, len(request.Messages))
 	for _, message := range request.Messages {
@@ -295,3 +305,60 @@ func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWit
 	_, err = c.Writer.Write(jsonResponse)
 	return nil, &fullTextResponse.Usage
 }
+
+func getBaiduAccessToken(apiKey string) (string, error) {
+	if val, ok := baiduTokenStore.Load(apiKey); ok {
+		var accessToken BaiduAccessToken
+		if accessToken, ok = val.(BaiduAccessToken); ok {
+			// soon this will expire
+			if time.Now().Add(time.Hour).After(accessToken.ExpiresAt) {
+				go func() {
+					_, _ = getBaiduAccessTokenHelper(apiKey)
+				}()
+			}
+			return accessToken.AccessToken, nil
+		}
+	}
+	accessToken, err := getBaiduAccessTokenHelper(apiKey)
+	if err != nil {
+		return "", err
+	}
+	if accessToken == nil {
+		return "", errors.New("getBaiduAccessToken return a nil token")
+	}
+	return (*accessToken).AccessToken, nil
+}
+
+func getBaiduAccessTokenHelper(apiKey string) (*BaiduAccessToken, error) {
+	parts := strings.Split(apiKey, "|")
+	if len(parts) != 2 {
+		return nil, errors.New("invalid baidu apikey")
+	}
+	req, err := http.NewRequest("POST", fmt.Sprintf("https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id=%s&client_secret=%s",
+		parts[0], parts[1]), nil)
+	if err != nil {
+		return nil, err
+	}
+	req.Header.Add("Content-Type", "application/json")
+	req.Header.Add("Accept", "application/json")
+	res, err := impatientHTTPClient.Do(req)
+	if err != nil {
+		return nil, err
+	}
+	defer res.Body.Close()
+
+	var accessToken BaiduAccessToken
+	err = json.NewDecoder(res.Body).Decode(&accessToken)
+	if err != nil {
+		return nil, err
+	}
+	if accessToken.Error != "" {
+		return nil, errors.New(accessToken.Error + ": " + accessToken.ErrorDescription)
+	}
+	if accessToken.AccessToken == "" {
+		return nil, errors.New("getBaiduAccessTokenHelper get empty access token")
+	}
+	accessToken.ExpiresAt = time.Now().Add(time.Duration(accessToken.ExpiresIn) * time.Second)
+	baiduTokenStore.Store(apiKey, accessToken)
+	return &accessToken, nil
+}

+ 10 - 1
controller/relay-text.go

@@ -11,6 +11,7 @@ import (
 	"one-api/common"
 	"one-api/model"
 	"strings"
+	"time"
 )
 
 const (
@@ -24,9 +25,13 @@ const (
 )
 
 var httpClient *http.Client
+var impatientHTTPClient *http.Client
 
 func init() {
 	httpClient = &http.Client{}
+	impatientHTTPClient = &http.Client{
+		Timeout: 5 * time.Second,
+	}
 }
 
 func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
@@ -145,7 +150,11 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 		}
 		apiKey := c.Request.Header.Get("Authorization")
 		apiKey = strings.TrimPrefix(apiKey, "Bearer ")
-		fullRequestURL += "?access_token=" + apiKey // TODO: access token expire in 30 days
+		var err error
+		if apiKey, err = getBaiduAccessToken(apiKey); err != nil {
+			return errorWrapper(err, "invalid_baidu_config", http.StatusInternalServerError)
+		}
+		fullRequestURL += "?access_token=" + apiKey
 	case APITypePaLM:
 		fullRequestURL = "https://generativelanguage.googleapis.com/v1beta2/models/chat-bison-001:generateMessage"
 		if baseURL != "" {

+ 2 - 1
i18n/en.json

@@ -519,5 +519,6 @@
   "令牌创建成功,请在列表页面点击复制获取令牌!": "Token created successfully, please click copy on the list page to get the token!",
   "代理": "Proxy",
   "此项可选,用于通过代理站来进行 API 调用,请输入代理站地址,格式为:https://domain.com": "This is optional, used to make API calls through the proxy site, please enter the proxy site address, the format is: https://domain.com",
-  "取消密码登录将导致所有未绑定其他登录方式的用户(包括管理员)无法通过密码登录,确认取消?": "Canceling password login will cause all users (including administrators) who have not bound other login methods to be unable to log in via password, confirm cancel?"
+  "取消密码登录将导致所有未绑定其他登录方式的用户(包括管理员)无法通过密码登录,确认取消?": "Canceling password login will cause all users (including administrators) who have not bound other login methods to be unable to log in via password, confirm cancel?",
+  "按照如下格式输入:": "Enter in the following format:"
 }

+ 1 - 1
web/src/pages/Channel/EditChannel.js

@@ -355,7 +355,7 @@ const EditChannel = () => {
                 label='密钥'
                 name='key'
                 required
-                placeholder={inputs.type === 15 ? '请输入 access token,当前版本暂不支持自动刷新,请每 30 天更新一次' : (inputs.type === 18 ? '按照如下格式输入:APPID|APISecret|APIKey' : '请输入渠道对应的鉴权密钥')}
+                placeholder={inputs.type === 15 ? '按照如下格式输入:APIKey|SecretKey' : (inputs.type === 18 ? '按照如下格式输入:APPID|APISecret|APIKey' : '请输入渠道对应的鉴权密钥')}
                 onChange={handleInputChange}
                 value={inputs.key}
                 autoComplete='new-password'