Browse Source

feat: support aws bedrock api-keys-use

CaIon 2 tháng trước cách đây
mục cha
commit
f023efdbfc

+ 8 - 0
dto/channel_settings.go

@@ -16,6 +16,13 @@ const (
 	VertexKeyTypeAPIKey VertexKeyType = "api_key"
 )
 
+type AwsKeyType string
+
+const (
+	AwsKeyTypeAKSK   AwsKeyType = "ak_sk" // 默认
+	AwsKeyTypeApiKey AwsKeyType = "api_key"
+)
+
 type ChannelOtherSettings struct {
 	AzureResponsesVersion string        `json:"azure_responses_version,omitempty"`
 	VertexKeyType         VertexKeyType `json:"vertex_key_type,omitempty"` // "json" or "api_key"
@@ -23,6 +30,7 @@ type ChannelOtherSettings struct {
 	AllowServiceTier      bool          `json:"allow_service_tier,omitempty"`      // 是否允许 service_tier 透传(默认过滤以避免额外计费)
 	DisableStore          bool          `json:"disable_store,omitempty"`           // 是否禁用 store 透传(默认允许透传,禁用后可能导致 Codex 无法使用)
 	AllowSafetyIdentifier bool          `json:"allow_safety_identifier,omitempty"` // 是否允许 safety_identifier 透传(默认过滤以保护用户隐私)
+	AwsKeyType            AwsKeyType    `json:"aws_key_type,omitempty"`
 }
 
 func (s *ChannelOtherSettings) IsOpenRouterEnterprise() bool {

+ 36 - 74
relay/channel/aws/adaptor.go

@@ -1,27 +1,31 @@
 package aws
 
 import (
+	"fmt"
 	"io"
 	"net/http"
+	"strings"
 
-	"github.com/QuantumNous/new-api/common"
 	"github.com/QuantumNous/new-api/dto"
+	"github.com/QuantumNous/new-api/relay/channel"
 	"github.com/QuantumNous/new-api/relay/channel/claude"
 	relaycommon "github.com/QuantumNous/new-api/relay/common"
 	"github.com/QuantumNous/new-api/types"
-	"github.com/aws/aws-sdk-go-v2/aws"
 	"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
 	"github.com/pkg/errors"
 
 	"github.com/gin-gonic/gin"
 )
 
+type ClientMode int
+
 const (
-	RequestModeCompletion = 1
-	RequestModeMessage    = 2
+	ClientModeApiKey ClientMode = iota + 1
+	ClientModeAKSK
 )
 
 type Adaptor struct {
+	ClientMode ClientMode
 	AwsClient  *bedrockruntime.Client
 	AwsModelId string
 	AwsReq     any
@@ -51,11 +55,25 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
 }
 
 func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
-	return "", nil
+	if info.ChannelOtherSettings.AwsKeyType == dto.AwsKeyTypeApiKey {
+		awsModelId := awsModelID(info.UpstreamModelName)
+		a.ClientMode = ClientModeApiKey
+		awsSecret := strings.Split(info.ApiKey, "|")
+		if len(awsSecret) != 2 {
+			return "", errors.New("invalid aws api key, should be in format of <api-key>|<region>")
+		}
+		return fmt.Sprintf("https://bedrock-runtime.%s.amazonaws.com/model/%s/converse", awsModelId, awsSecret[1]), nil
+	} else {
+		a.ClientMode = ClientModeAKSK
+		return "", nil
+	}
 }
 
 func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
 	claude.CommonClaudeHeadersOperation(c, req, info)
+	if a.ClientMode == ClientModeApiKey {
+		req.Set("Authorization", "Bearer "+info.ApiKey)
+	}
 	return nil
 }
 
@@ -95,82 +113,26 @@ func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommo
 }
 
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
-	awsCli, err := newAwsClient(c, info)
-	if err != nil {
-		return nil, types.NewError(err, types.ErrorCodeChannelAwsClientError)
-	}
-	a.AwsClient = awsCli
-
-	awsModelId := awsModelID(info.UpstreamModelName)
-
-	awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region)
-	canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
-	if canCrossRegion {
-		awsModelId = awsModelCrossRegion(awsModelId, awsRegionPrefix)
-	}
-
-	if isNovaModel(awsModelId) {
-		var novaReq *NovaRequest
-		err = common.DecodeJson(requestBody, &novaReq)
-		if err != nil {
-			return nil, types.NewError(errors.Wrap(err, "decode nova request fail"), types.ErrorCodeBadRequestBody)
-		}
-
-		// 使用InvokeModel API,但使用Nova格式的请求体
-		awsReq := &bedrockruntime.InvokeModelInput{
-			ModelId:     aws.String(awsModelId),
-			Accept:      aws.String("application/json"),
-			ContentType: aws.String("application/json"),
-		}
-
-		reqBody, err := common.Marshal(novaReq)
-		if err != nil {
-			return nil, types.NewError(errors.Wrap(err, "marshal nova request"), types.ErrorCodeBadResponseBody)
-		}
-		awsReq.Body = reqBody
-		return nil, nil
+	if a.ClientMode == ClientModeApiKey {
+		return channel.DoApiRequest(a, c, info, requestBody)
 	} else {
-		awsClaudeReq, err := formatRequest(requestBody)
-		if err != nil {
-			return nil, types.NewError(errors.Wrap(err, "format aws request fail"), types.ErrorCodeBadRequestBody)
-		}
-
-		if info.IsStream {
-			awsReq := &bedrockruntime.InvokeModelWithResponseStreamInput{
-				ModelId:     aws.String(awsModelId),
-				Accept:      aws.String("application/json"),
-				ContentType: aws.String("application/json"),
-			}
-			awsReq.Body, err = common.Marshal(awsClaudeReq)
-			if err != nil {
-				return nil, types.NewError(errors.Wrap(err, "marshal aws request fail"), types.ErrorCodeBadRequestBody)
-			}
-			a.AwsReq = awsReq
-			return nil, nil
-		} else {
-			awsReq := &bedrockruntime.InvokeModelInput{
-				ModelId:     aws.String(awsModelId),
-				Accept:      aws.String("application/json"),
-				ContentType: aws.String("application/json"),
-			}
-			awsReq.Body, err = common.Marshal(awsClaudeReq)
-			if err != nil {
-				return nil, types.NewError(errors.Wrap(err, "marshal aws request fail"), types.ErrorCodeBadRequestBody)
-			}
-			a.AwsReq = awsReq
-			return nil, nil
-		}
+		return doAwsClientRequest(c, info, a, requestBody)
 	}
 }
 
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
-	if a.IsNova {
-		err, usage = handleNovaRequest(c, info, a)
+	if a.ClientMode == ClientModeApiKey {
+		claudeAdaptor := claude.Adaptor{}
+		usage, err = claudeAdaptor.DoResponse(c, resp, info)
 	} else {
-		if info.IsStream {
-			err, usage = awsStreamHandler(c, info, a)
+		if a.IsNova {
+			err, usage = handleNovaRequest(c, info, a)
 		} else {
-			err, usage = awsHandler(c, info, a)
+			if info.IsStream {
+				err, usage = awsStreamHandler(c, info, a)
+			} else {
+				err, usage = awsHandler(c, info, a)
+			}
 		}
 	}
 	return

+ 70 - 9
relay/channel/aws/relay-aws.go

@@ -3,6 +3,7 @@ package aws
 import (
 	"encoding/json"
 	"fmt"
+	"io"
 	"net/http"
 	"strings"
 
@@ -49,12 +50,72 @@ func newAwsClient(c *gin.Context, info *relaycommon.RelayInfo) (*bedrockruntime.
 	return client, nil
 }
 
-func wrapErr(err error) *dto.OpenAIErrorWithStatusCode {
-	return &dto.OpenAIErrorWithStatusCode{
-		StatusCode: http.StatusInternalServerError,
-		Error: dto.OpenAIError{
-			Message: fmt.Sprintf("%s", err.Error()),
-		},
+func doAwsClientRequest(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor, requestBody io.Reader) (any, error) {
+	awsCli, err := newAwsClient(c, info)
+	if err != nil {
+		return nil, types.NewError(err, types.ErrorCodeChannelAwsClientError)
+	}
+	a.AwsClient = awsCli
+
+	awsModelId := awsModelID(info.UpstreamModelName)
+
+	awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region)
+	canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
+	if canCrossRegion {
+		awsModelId = awsModelCrossRegion(awsModelId, awsRegionPrefix)
+	}
+
+	if isNovaModel(awsModelId) {
+		var novaReq *NovaRequest
+		err = common.DecodeJson(requestBody, &novaReq)
+		if err != nil {
+			return nil, types.NewError(errors.Wrap(err, "decode nova request fail"), types.ErrorCodeBadRequestBody)
+		}
+
+		// 使用InvokeModel API,但使用Nova格式的请求体
+		awsReq := &bedrockruntime.InvokeModelInput{
+			ModelId:     aws.String(awsModelId),
+			Accept:      aws.String("application/json"),
+			ContentType: aws.String("application/json"),
+		}
+
+		reqBody, err := common.Marshal(novaReq)
+		if err != nil {
+			return nil, types.NewError(errors.Wrap(err, "marshal nova request"), types.ErrorCodeBadResponseBody)
+		}
+		awsReq.Body = reqBody
+		return nil, nil
+	} else {
+		awsClaudeReq, err := formatRequest(requestBody)
+		if err != nil {
+			return nil, types.NewError(errors.Wrap(err, "format aws request fail"), types.ErrorCodeBadRequestBody)
+		}
+
+		if info.IsStream {
+			awsReq := &bedrockruntime.InvokeModelWithResponseStreamInput{
+				ModelId:     aws.String(awsModelId),
+				Accept:      aws.String("application/json"),
+				ContentType: aws.String("application/json"),
+			}
+			awsReq.Body, err = common.Marshal(awsClaudeReq)
+			if err != nil {
+				return nil, types.NewError(errors.Wrap(err, "marshal aws request fail"), types.ErrorCodeBadRequestBody)
+			}
+			a.AwsReq = awsReq
+			return nil, nil
+		} else {
+			awsReq := &bedrockruntime.InvokeModelInput{
+				ModelId:     aws.String(awsModelId),
+				Accept:      aws.String("application/json"),
+				ContentType: aws.String("application/json"),
+			}
+			awsReq.Body, err = common.Marshal(awsClaudeReq)
+			if err != nil {
+				return nil, types.NewError(errors.Wrap(err, "marshal aws request fail"), types.ErrorCodeBadRequestBody)
+			}
+			a.AwsReq = awsReq
+			return nil, nil
+		}
 	}
 }
 
@@ -108,7 +169,7 @@ func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor) (*types
 		c.Writer.Header().Set("Content-Type", *awsResp.ContentType)
 	}
 
-	handlerErr := claude.HandleClaudeResponseData(c, info, claudeInfo, nil, awsResp.Body, RequestModeMessage)
+	handlerErr := claude.HandleClaudeResponseData(c, info, claudeInfo, nil, awsResp.Body, claude.RequestModeMessage)
 	if handlerErr != nil {
 		return handlerErr, nil
 	}
@@ -135,7 +196,7 @@ func awsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor) (
 		switch v := event.(type) {
 		case *bedrockruntimeTypes.ResponseStreamMemberChunk:
 			info.SetFirstResponseTime()
-			respErr := claude.HandleStreamResponseData(c, info, claudeInfo, string(v.Value.Bytes), RequestModeMessage)
+			respErr := claude.HandleStreamResponseData(c, info, claudeInfo, string(v.Value.Bytes), claude.RequestModeMessage)
 			if respErr != nil {
 				return respErr, nil
 			}
@@ -148,7 +209,7 @@ func awsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor) (
 		}
 	}
 
-	claude.HandleStreamFinalResponse(c, info, claudeInfo, RequestModeMessage)
+	claude.HandleStreamFinalResponse(c, info, claudeInfo, claude.RequestModeMessage)
 	return nil, claudeInfo.Usage
 }
 

+ 54 - 2
web/src/components/table/channels/modals/EditChannelModal.jsx

@@ -153,6 +153,8 @@ const EditChannelModal = (props) => {
     settings: '',
     // 仅 Vertex: 密钥格式(存入 settings.vertex_key_type)
     vertex_key_type: 'json',
+    // 仅 AWS: 密钥格式和区域(存入 settings.aws_key_type 和 settings.aws_region)
+    aws_key_type: 'ak_sk',
     // 企业账户设置
     is_enterprise_account: false,
     // 字段透传控制默认值
@@ -515,6 +517,8 @@ const EditChannelModal = (props) => {
             parsedSettings.azure_responses_version || '';
           // 读取 Vertex 密钥格式
           data.vertex_key_type = parsedSettings.vertex_key_type || 'json';
+          // 读取 AWS 密钥格式和区域
+          data.aws_key_type = parsedSettings.aws_key_type || 'ak_sk';
           // 读取企业账户设置
           data.is_enterprise_account =
             parsedSettings.openrouter_enterprise === true;
@@ -528,6 +532,7 @@ const EditChannelModal = (props) => {
           data.azure_responses_version = '';
           data.region = '';
           data.vertex_key_type = 'json';
+          data.aws_key_type = 'ak_sk';
           data.is_enterprise_account = false;
           data.allow_service_tier = false;
           data.disable_store = false;
@@ -536,6 +541,7 @@ const EditChannelModal = (props) => {
       } else {
         // 兼容历史数据:老渠道没有 settings 时,默认按 json 展示
         data.vertex_key_type = 'json';
+        data.aws_key_type = 'ak_sk';
         data.is_enterprise_account = false;
         data.allow_service_tier = false;
         data.disable_store = false;
@@ -997,6 +1003,11 @@ const EditChannelModal = (props) => {
         localInputs.is_enterprise_account === true;
     }
 
+    // type === 33 (AWS): 保存 aws_key_type 到 settings
+    if (localInputs.type === 33) {
+      settings.aws_key_type = localInputs.aws_key_type || 'ak_sk';
+    }
+
     // type === 1 (OpenAI) 或 type === 14 (Claude): 设置字段透传控制(显式保存布尔值)
     if (localInputs.type === 1 || localInputs.type === 14) {
       settings.allow_service_tier = localInputs.allow_service_tier === true;
@@ -1020,6 +1031,8 @@ const EditChannelModal = (props) => {
     delete localInputs.is_enterprise_account;
     // 顶层的 vertex_key_type 不应发送给后端
     delete localInputs.vertex_key_type;
+    // 顶层的 aws_key_type 不应发送给后端
+    delete localInputs.aws_key_type;
     // 清理字段透传控制的临时字段
     delete localInputs.allow_service_tier;
     delete localInputs.disable_store;
@@ -1468,6 +1481,31 @@ const EditChannelModal = (props) => {
                       autoComplete='new-password'
                     />
 
+                    {inputs.type === 33 && (
+                      <>
+                        <Form.Select
+                          field='aws_key_type'
+                          label={t('密钥格式')}
+                          placeholder={t('请选择密钥格式')}
+                          optionList={[
+                            {
+                              label: 'Access Key ID / Secret Access Key',
+                              value: 'ak_sk',
+                            },
+                            { label: 'API Key', value: 'api_key' },
+                          ]}
+                          style={{ width: '100%' }}
+                          value={inputs.aws_key_type || 'ak_sk'}
+                          onChange={(value) => {
+                            handleChannelOtherSettingsChange('aws_key_type', value);
+                          }}
+                          extraText={t(
+                            'AK/SK 模式:使用 Access Key ID 和 Secret Access Key;API Key 模式:使用 API Key',
+                          )}
+                        />
+                      </>
+                    )}
+
                     {inputs.type === 41 && (
                       <Form.Select
                         field='vertex_key_type'
@@ -1536,7 +1574,15 @@ const EditChannelModal = (props) => {
                         <Form.TextArea
                           field='key'
                           label={t('密钥')}
-                          placeholder={t('请输入密钥,一行一个')}
+                          placeholder={
+                            inputs.type === 33
+                              ? inputs.aws_key_type === 'api_key'
+                                ? t('请输入 API Key,一行一个,格式:API Key|Region')
+                                : t(
+                                    '请输入密钥,一行一个,格式:Access Key ID|Secret Access Key|Region',
+                                  )
+                              : t('请输入密钥,一行一个')
+                          }
                           rules={
                             isEdit
                               ? []
@@ -1730,7 +1776,13 @@ const EditChannelModal = (props) => {
                                 ? t('密钥(编辑模式下,保存的密钥不会显示)')
                                 : t('密钥')
                             }
-                            placeholder={t(type2secretPrompt(inputs.type))}
+                            placeholder={
+                              inputs.type === 33
+                                ? inputs.aws_key_type === 'api_key'
+                                  ? t('请输入 API Key,格式:API Key|Region')
+                                  : t('按照如下格式输入:Access Key ID|Secret Access Key|Region')
+                                : t(type2secretPrompt(inputs.type))
+                            }
                             rules={
                               isEdit
                                 ? []