|
|
@@ -14,6 +14,7 @@ import (
|
|
|
"one-api/service"
|
|
|
"one-api/setting/operation_setting"
|
|
|
"one-api/types"
|
|
|
+ "strings"
|
|
|
"sync"
|
|
|
"time"
|
|
|
|
|
|
@@ -36,6 +37,26 @@ func SetupApiRequestHeader(info *common.RelayInfo, c *gin.Context, req *http.Hea
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+// processHeaderOverride 处理请求头覆盖,支持变量替换
|
|
|
+// 支持的变量:{api_key}
|
|
|
+func processHeaderOverride(info *common.RelayInfo) (map[string]string, error) {
|
|
|
+ headerOverride := make(map[string]string)
|
|
|
+ for k, v := range info.HeadersOverride {
|
|
|
+ str, ok := v.(string)
|
|
|
+ if !ok {
|
|
|
+ return nil, types.NewError(nil, types.ErrorCodeChannelHeaderOverrideInvalid)
|
|
|
+ }
|
|
|
+
|
|
|
+ // 替换支持的变量
|
|
|
+ if strings.Contains(str, "{api_key}") {
|
|
|
+ str = strings.ReplaceAll(str, "{api_key}", info.ApiKey)
|
|
|
+ }
|
|
|
+
|
|
|
+ headerOverride[k] = str
|
|
|
+ }
|
|
|
+ return headerOverride, nil
|
|
|
+}
|
|
|
+
|
|
|
func DoApiRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
|
|
fullRequestURL, err := a.GetRequestURL(info)
|
|
|
if err != nil {
|
|
|
@@ -49,13 +70,9 @@ func DoApiRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody
|
|
|
return nil, fmt.Errorf("new request failed: %w", err)
|
|
|
}
|
|
|
headers := req.Header
|
|
|
- headerOverride := make(map[string]string)
|
|
|
- for k, v := range info.HeadersOverride {
|
|
|
- if str, ok := v.(string); ok {
|
|
|
- headerOverride[k] = str
|
|
|
- } else {
|
|
|
- return nil, types.NewError(err, types.ErrorCodeChannelHeaderOverrideInvalid)
|
|
|
- }
|
|
|
+ headerOverride, err := processHeaderOverride(info)
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
}
|
|
|
for key, value := range headerOverride {
|
|
|
headers.Set(key, value)
|
|
|
@@ -86,13 +103,9 @@ func DoFormRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBod
|
|
|
// set form data
|
|
|
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
|
|
headers := req.Header
|
|
|
- headerOverride := make(map[string]string)
|
|
|
- for k, v := range info.HeadersOverride {
|
|
|
- if str, ok := v.(string); ok {
|
|
|
- headerOverride[k] = str
|
|
|
- } else {
|
|
|
- return nil, types.NewError(err, types.ErrorCodeChannelHeaderOverrideInvalid)
|
|
|
- }
|
|
|
+ headerOverride, err := processHeaderOverride(info)
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
}
|
|
|
for key, value := range headerOverride {
|
|
|
headers.Set(key, value)
|
|
|
@@ -114,6 +127,13 @@ func DoWssRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody
|
|
|
return nil, fmt.Errorf("get request url failed: %w", err)
|
|
|
}
|
|
|
targetHeader := http.Header{}
|
|
|
+ headerOverride, err := processHeaderOverride(info)
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+ for key, value := range headerOverride {
|
|
|
+ targetHeader.Set(key, value)
|
|
|
+ }
|
|
|
err = a.SetupRequestHeader(c, &targetHeader, info)
|
|
|
if err != nil {
|
|
|
return nil, fmt.Errorf("setup request header failed: %w", err)
|