|
|
@@ -1,27 +1,31 @@
|
|
|
package aws
|
|
|
|
|
|
import (
|
|
|
+ "fmt"
|
|
|
"io"
|
|
|
"net/http"
|
|
|
+ "strings"
|
|
|
|
|
|
- "github.com/QuantumNous/new-api/common"
|
|
|
"github.com/QuantumNous/new-api/dto"
|
|
|
+ "github.com/QuantumNous/new-api/relay/channel"
|
|
|
"github.com/QuantumNous/new-api/relay/channel/claude"
|
|
|
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
|
|
"github.com/QuantumNous/new-api/types"
|
|
|
- "github.com/aws/aws-sdk-go-v2/aws"
|
|
|
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
|
|
|
"github.com/pkg/errors"
|
|
|
|
|
|
"github.com/gin-gonic/gin"
|
|
|
)
|
|
|
|
|
|
+type ClientMode int
|
|
|
+
|
|
|
const (
|
|
|
- RequestModeCompletion = 1
|
|
|
- RequestModeMessage = 2
|
|
|
+ ClientModeApiKey ClientMode = iota + 1
|
|
|
+ ClientModeAKSK
|
|
|
)
|
|
|
|
|
|
type Adaptor struct {
|
|
|
+ ClientMode ClientMode
|
|
|
AwsClient *bedrockruntime.Client
|
|
|
AwsModelId string
|
|
|
AwsReq any
|
|
|
@@ -51,11 +55,25 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
|
|
}
|
|
|
|
|
|
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
|
|
- return "", nil
|
|
|
+ if info.ChannelOtherSettings.AwsKeyType == dto.AwsKeyTypeApiKey {
|
|
|
+ awsModelId := awsModelID(info.UpstreamModelName)
|
|
|
+ a.ClientMode = ClientModeApiKey
|
|
|
+ awsSecret := strings.Split(info.ApiKey, "|")
|
|
|
+ if len(awsSecret) != 2 {
|
|
|
+ return "", errors.New("invalid aws api key, should be in format of <api-key>|<region>")
|
|
|
+ }
|
|
|
+ return fmt.Sprintf("https://bedrock-runtime.%s.amazonaws.com/model/%s/converse", awsModelId, awsSecret[1]), nil
|
|
|
+ } else {
|
|
|
+ a.ClientMode = ClientModeAKSK
|
|
|
+ return "", nil
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
|
|
claude.CommonClaudeHeadersOperation(c, req, info)
|
|
|
+ if a.ClientMode == ClientModeApiKey {
|
|
|
+ req.Set("Authorization", "Bearer "+info.ApiKey)
|
|
|
+ }
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
@@ -95,82 +113,26 @@ func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommo
|
|
|
}
|
|
|
|
|
|
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
|
|
- awsCli, err := newAwsClient(c, info)
|
|
|
- if err != nil {
|
|
|
- return nil, types.NewError(err, types.ErrorCodeChannelAwsClientError)
|
|
|
- }
|
|
|
- a.AwsClient = awsCli
|
|
|
-
|
|
|
- awsModelId := awsModelID(info.UpstreamModelName)
|
|
|
-
|
|
|
- awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region)
|
|
|
- canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
|
|
|
- if canCrossRegion {
|
|
|
- awsModelId = awsModelCrossRegion(awsModelId, awsRegionPrefix)
|
|
|
- }
|
|
|
-
|
|
|
- if isNovaModel(awsModelId) {
|
|
|
- var novaReq *NovaRequest
|
|
|
- err = common.DecodeJson(requestBody, &novaReq)
|
|
|
- if err != nil {
|
|
|
- return nil, types.NewError(errors.Wrap(err, "decode nova request fail"), types.ErrorCodeBadRequestBody)
|
|
|
- }
|
|
|
-
|
|
|
- // 使用InvokeModel API,但使用Nova格式的请求体
|
|
|
- awsReq := &bedrockruntime.InvokeModelInput{
|
|
|
- ModelId: aws.String(awsModelId),
|
|
|
- Accept: aws.String("application/json"),
|
|
|
- ContentType: aws.String("application/json"),
|
|
|
- }
|
|
|
-
|
|
|
- reqBody, err := common.Marshal(novaReq)
|
|
|
- if err != nil {
|
|
|
- return nil, types.NewError(errors.Wrap(err, "marshal nova request"), types.ErrorCodeBadResponseBody)
|
|
|
- }
|
|
|
- awsReq.Body = reqBody
|
|
|
- return nil, nil
|
|
|
+ if a.ClientMode == ClientModeApiKey {
|
|
|
+ return channel.DoApiRequest(a, c, info, requestBody)
|
|
|
} else {
|
|
|
- awsClaudeReq, err := formatRequest(requestBody)
|
|
|
- if err != nil {
|
|
|
- return nil, types.NewError(errors.Wrap(err, "format aws request fail"), types.ErrorCodeBadRequestBody)
|
|
|
- }
|
|
|
-
|
|
|
- if info.IsStream {
|
|
|
- awsReq := &bedrockruntime.InvokeModelWithResponseStreamInput{
|
|
|
- ModelId: aws.String(awsModelId),
|
|
|
- Accept: aws.String("application/json"),
|
|
|
- ContentType: aws.String("application/json"),
|
|
|
- }
|
|
|
- awsReq.Body, err = common.Marshal(awsClaudeReq)
|
|
|
- if err != nil {
|
|
|
- return nil, types.NewError(errors.Wrap(err, "marshal aws request fail"), types.ErrorCodeBadRequestBody)
|
|
|
- }
|
|
|
- a.AwsReq = awsReq
|
|
|
- return nil, nil
|
|
|
- } else {
|
|
|
- awsReq := &bedrockruntime.InvokeModelInput{
|
|
|
- ModelId: aws.String(awsModelId),
|
|
|
- Accept: aws.String("application/json"),
|
|
|
- ContentType: aws.String("application/json"),
|
|
|
- }
|
|
|
- awsReq.Body, err = common.Marshal(awsClaudeReq)
|
|
|
- if err != nil {
|
|
|
- return nil, types.NewError(errors.Wrap(err, "marshal aws request fail"), types.ErrorCodeBadRequestBody)
|
|
|
- }
|
|
|
- a.AwsReq = awsReq
|
|
|
- return nil, nil
|
|
|
- }
|
|
|
+ return doAwsClientRequest(c, info, a, requestBody)
|
|
|
}
|
|
|
}
|
|
|
|
|
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
|
|
- if a.IsNova {
|
|
|
- err, usage = handleNovaRequest(c, info, a)
|
|
|
+ if a.ClientMode == ClientModeApiKey {
|
|
|
+ claudeAdaptor := claude.Adaptor{}
|
|
|
+ usage, err = claudeAdaptor.DoResponse(c, resp, info)
|
|
|
} else {
|
|
|
- if info.IsStream {
|
|
|
- err, usage = awsStreamHandler(c, info, a)
|
|
|
+ if a.IsNova {
|
|
|
+ err, usage = handleNovaRequest(c, info, a)
|
|
|
} else {
|
|
|
- err, usage = awsHandler(c, info, a)
|
|
|
+ if info.IsStream {
|
|
|
+ err, usage = awsStreamHandler(c, info, a)
|
|
|
+ } else {
|
|
|
+ err, usage = awsHandler(c, info, a)
|
|
|
+ }
|
|
|
}
|
|
|
}
|
|
|
return
|