|
|
@@ -3,22 +3,22 @@ package controller
|
|
|
import (
|
|
|
"bufio"
|
|
|
"encoding/json"
|
|
|
+ "errors"
|
|
|
+ "fmt"
|
|
|
"github.com/gin-gonic/gin"
|
|
|
"io"
|
|
|
"net/http"
|
|
|
"one-api/common"
|
|
|
"strings"
|
|
|
+ "sync"
|
|
|
+ "time"
|
|
|
)
|
|
|
|
|
|
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2
|
|
|
|
|
|
type BaiduTokenResponse struct {
|
|
|
- RefreshToken string `json:"refresh_token"`
|
|
|
- ExpiresIn int `json:"expires_in"`
|
|
|
- SessionKey string `json:"session_key"`
|
|
|
- AccessToken string `json:"access_token"`
|
|
|
- Scope string `json:"scope"`
|
|
|
- SessionSecret string `json:"session_secret"`
|
|
|
+ ExpiresIn int `json:"expires_in"`
|
|
|
+ AccessToken string `json:"access_token"`
|
|
|
}
|
|
|
|
|
|
type BaiduMessage struct {
|
|
|
@@ -73,6 +73,16 @@ type BaiduEmbeddingResponse struct {
|
|
|
BaiduError
|
|
|
}
|
|
|
|
|
|
+type BaiduAccessToken struct {
|
|
|
+ AccessToken string `json:"access_token"`
|
|
|
+ Error string `json:"error,omitempty"`
|
|
|
+ ErrorDescription string `json:"error_description,omitempty"`
|
|
|
+ ExpiresIn int64 `json:"expires_in,omitempty"`
|
|
|
+ ExpiresAt time.Time `json:"-"`
|
|
|
+}
|
|
|
+
|
|
|
+var baiduTokenStore sync.Map
|
|
|
+
|
|
|
func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest {
|
|
|
messages := make([]BaiduMessage, 0, len(request.Messages))
|
|
|
for _, message := range request.Messages {
|
|
|
@@ -295,3 +305,60 @@ func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWit
|
|
|
_, err = c.Writer.Write(jsonResponse)
|
|
|
return nil, &fullTextResponse.Usage
|
|
|
}
|
|
|
+
|
|
|
+func getBaiduAccessToken(apiKey string) (string, error) {
|
|
|
+ if val, ok := baiduTokenStore.Load(apiKey); ok {
|
|
|
+ var accessToken BaiduAccessToken
|
|
|
+ if accessToken, ok = val.(BaiduAccessToken); ok {
|
|
|
+ // soon this will expire
|
|
|
+ if time.Now().Add(time.Hour).After(accessToken.ExpiresAt) {
|
|
|
+ go func() {
|
|
|
+ _, _ = getBaiduAccessTokenHelper(apiKey)
|
|
|
+ }()
|
|
|
+ }
|
|
|
+ return accessToken.AccessToken, nil
|
|
|
+ }
|
|
|
+ }
|
|
|
+ accessToken, err := getBaiduAccessTokenHelper(apiKey)
|
|
|
+ if err != nil {
|
|
|
+ return "", err
|
|
|
+ }
|
|
|
+ if accessToken == nil {
|
|
|
+ return "", errors.New("getBaiduAccessToken return a nil token")
|
|
|
+ }
|
|
|
+ return (*accessToken).AccessToken, nil
|
|
|
+}
|
|
|
+
|
|
|
+func getBaiduAccessTokenHelper(apiKey string) (*BaiduAccessToken, error) {
|
|
|
+ parts := strings.Split(apiKey, "|")
|
|
|
+ if len(parts) != 2 {
|
|
|
+ return nil, errors.New("invalid baidu apikey")
|
|
|
+ }
|
|
|
+ req, err := http.NewRequest("POST", fmt.Sprintf("https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id=%s&client_secret=%s",
|
|
|
+ parts[0], parts[1]), nil)
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+ req.Header.Add("Content-Type", "application/json")
|
|
|
+ req.Header.Add("Accept", "application/json")
|
|
|
+ res, err := impatientHTTPClient.Do(req)
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+ defer res.Body.Close()
|
|
|
+
|
|
|
+ var accessToken BaiduAccessToken
|
|
|
+ err = json.NewDecoder(res.Body).Decode(&accessToken)
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+ if accessToken.Error != "" {
|
|
|
+ return nil, errors.New(accessToken.Error + ": " + accessToken.ErrorDescription)
|
|
|
+ }
|
|
|
+ if accessToken.AccessToken == "" {
|
|
|
+ return nil, errors.New("getBaiduAccessTokenHelper get empty access token")
|
|
|
+ }
|
|
|
+ accessToken.ExpiresAt = time.Now().Add(time.Duration(accessToken.ExpiresIn) * time.Second)
|
|
|
+ baiduTokenStore.Store(apiKey, accessToken)
|
|
|
+ return &accessToken, nil
|
|
|
+}
|