|
@@ -1,14 +1,17 @@
|
|
|
package aws
|
|
package aws
|
|
|
|
|
|
|
|
import (
|
|
import (
|
|
|
- "errors"
|
|
|
|
|
"io"
|
|
"io"
|
|
|
"net/http"
|
|
"net/http"
|
|
|
|
|
|
|
|
|
|
+ "github.com/QuantumNous/new-api/common"
|
|
|
"github.com/QuantumNous/new-api/dto"
|
|
"github.com/QuantumNous/new-api/dto"
|
|
|
"github.com/QuantumNous/new-api/relay/channel/claude"
|
|
"github.com/QuantumNous/new-api/relay/channel/claude"
|
|
|
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
|
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
|
|
"github.com/QuantumNous/new-api/types"
|
|
"github.com/QuantumNous/new-api/types"
|
|
|
|
|
+ "github.com/aws/aws-sdk-go-v2/aws"
|
|
|
|
|
+ "github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
|
|
|
|
|
+ "github.com/pkg/errors"
|
|
|
|
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/gin-gonic/gin"
|
|
|
)
|
|
)
|
|
@@ -19,7 +22,10 @@ const (
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
type Adaptor struct {
|
|
type Adaptor struct {
|
|
|
- RequestMode int
|
|
|
|
|
|
|
+ AwsClient *bedrockruntime.Client
|
|
|
|
|
+ AwsModelId string
|
|
|
|
|
+ AwsReq any
|
|
|
|
|
+ IsNova bool
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
|
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
|
@@ -28,8 +34,6 @@ func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dt
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
|
|
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
|
|
|
- c.Set("request_model", request.Model)
|
|
|
|
|
- c.Set("converted_request", request)
|
|
|
|
|
return request, nil
|
|
return request, nil
|
|
|
}
|
|
}
|
|
|
|
|
|
|
@@ -44,7 +48,6 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
|
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
|
|
- a.RequestMode = RequestModeMessage
|
|
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
|
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
|
@@ -63,9 +66,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
|
|
|
// 检查是否为Nova模型
|
|
// 检查是否为Nova模型
|
|
|
if isNovaModel(request.Model) {
|
|
if isNovaModel(request.Model) {
|
|
|
novaReq := convertToNovaRequest(request)
|
|
novaReq := convertToNovaRequest(request)
|
|
|
- c.Set("request_model", request.Model)
|
|
|
|
|
- c.Set("converted_request", novaReq)
|
|
|
|
|
- c.Set("is_nova_model", true)
|
|
|
|
|
|
|
+ a.IsNova = true
|
|
|
return novaReq, nil
|
|
return novaReq, nil
|
|
|
}
|
|
}
|
|
|
|
|
|
|
@@ -76,9 +77,6 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
return nil, err
|
|
return nil, err
|
|
|
}
|
|
}
|
|
|
- c.Set("request_model", claudeReq.Model)
|
|
|
|
|
- c.Set("converted_request", claudeReq)
|
|
|
|
|
- c.Set("is_nova_model", false)
|
|
|
|
|
return claudeReq, err
|
|
return claudeReq, err
|
|
|
}
|
|
}
|
|
|
|
|
|
|
@@ -97,14 +95,83 @@ func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommo
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
|
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
|
|
- return nil, nil
|
|
|
|
|
|
|
+ awsCli, err := newAwsClient(c, info)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ return nil, types.NewError(err, types.ErrorCodeChannelAwsClientError)
|
|
|
|
|
+ }
|
|
|
|
|
+ a.AwsClient = awsCli
|
|
|
|
|
+
|
|
|
|
|
+ awsModelId := awsModelID(info.UpstreamModelName)
|
|
|
|
|
+
|
|
|
|
|
+ awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region)
|
|
|
|
|
+ canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
|
|
|
|
|
+ if canCrossRegion {
|
|
|
|
|
+ awsModelId = awsModelCrossRegion(awsModelId, awsRegionPrefix)
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ if isNovaModel(awsModelId) {
|
|
|
|
|
+ var novaReq *NovaRequest
|
|
|
|
|
+ err = common.DecodeJson(requestBody, &novaReq)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ return nil, types.NewError(errors.Wrap(err, "decode nova request fail"), types.ErrorCodeBadRequestBody)
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // 使用InvokeModel API,但使用Nova格式的请求体
|
|
|
|
|
+ awsReq := &bedrockruntime.InvokeModelInput{
|
|
|
|
|
+ ModelId: aws.String(awsModelId),
|
|
|
|
|
+ Accept: aws.String("application/json"),
|
|
|
|
|
+ ContentType: aws.String("application/json"),
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ reqBody, err := common.Marshal(novaReq)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ return nil, types.NewError(errors.Wrap(err, "marshal nova request"), types.ErrorCodeBadResponseBody)
|
|
|
|
|
+ }
|
|
|
|
|
+ awsReq.Body = reqBody
|
|
|
|
|
+ return nil, nil
|
|
|
|
|
+ } else {
|
|
|
|
|
+ awsClaudeReq, err := formatRequest(requestBody)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ return nil, types.NewError(errors.Wrap(err, "format aws request fail"), types.ErrorCodeBadRequestBody)
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ if info.IsStream {
|
|
|
|
|
+ awsReq := &bedrockruntime.InvokeModelWithResponseStreamInput{
|
|
|
|
|
+ ModelId: aws.String(awsModelId),
|
|
|
|
|
+ Accept: aws.String("application/json"),
|
|
|
|
|
+ ContentType: aws.String("application/json"),
|
|
|
|
|
+ }
|
|
|
|
|
+ awsReq.Body, err = common.Marshal(awsClaudeReq)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ return nil, types.NewError(errors.Wrap(err, "marshal aws request fail"), types.ErrorCodeBadRequestBody)
|
|
|
|
|
+ }
|
|
|
|
|
+ a.AwsReq = awsReq
|
|
|
|
|
+ return nil, nil
|
|
|
|
|
+ } else {
|
|
|
|
|
+ awsReq := &bedrockruntime.InvokeModelInput{
|
|
|
|
|
+ ModelId: aws.String(awsModelId),
|
|
|
|
|
+ Accept: aws.String("application/json"),
|
|
|
|
|
+ ContentType: aws.String("application/json"),
|
|
|
|
|
+ }
|
|
|
|
|
+ awsReq.Body, err = common.Marshal(awsClaudeReq)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ return nil, types.NewError(errors.Wrap(err, "marshal aws request fail"), types.ErrorCodeBadRequestBody)
|
|
|
|
|
+ }
|
|
|
|
|
+ a.AwsReq = awsReq
|
|
|
|
|
+ return nil, nil
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
|
|
- if info.IsStream {
|
|
|
|
|
- err, usage = awsStreamHandler(c, resp, info, a.RequestMode)
|
|
|
|
|
|
|
+ if a.IsNova {
|
|
|
|
|
+ err, usage = handleNovaRequest(c, info, a)
|
|
|
} else {
|
|
} else {
|
|
|
- err, usage = awsHandler(c, info, a.RequestMode)
|
|
|
|
|
|
|
+ if info.IsStream {
|
|
|
|
|
+ err, usage = awsStreamHandler(c, info, a)
|
|
|
|
|
+ } else {
|
|
|
|
|
+ err, usage = awsHandler(c, info, a)
|
|
|
|
|
+ }
|
|
|
}
|
|
}
|
|
|
return
|
|
return
|
|
|
}
|
|
}
|