CaIon пре 2 месеци
родитељ
комит
8b65623726
7 измењених фајлова са 108 додато и 113 уклоњено
  1. 2 1
      common/json.go
  2. 1 0
      model/log.go
  3. 82 15
      relay/channel/aws/adaptor.go
  4. 1 1
      relay/channel/aws/constants.go
  5. 13 0
      relay/channel/aws/dto.go
  6. 6 96
      relay/channel/aws/relay-aws.go
  7. 3 0
      types/error.go

+ 2 - 1
common/json.go

@@ -3,6 +3,7 @@ package common
 import (
 	"bytes"
 	"encoding/json"
+	"io"
 )
 
 func Unmarshal(data []byte, v any) error {
@@ -13,7 +14,7 @@ func UnmarshalJsonStr(data string, v any) error {
 	return json.Unmarshal(StringToByteSlice(data), v)
 }
 
-func DecodeJson(reader *bytes.Reader, v any) error {
+func DecodeJson(reader io.Reader, v any) error {
 	return json.NewDecoder(reader).Decode(v)
 }
 

+ 1 - 0
model/log.go

@@ -45,6 +45,7 @@ const (
 	LogTypeConsume
 	LogTypeManage
 	LogTypeSystem
+	LogTypeRefund
 	LogTypeError
 )
 

+ 82 - 15
relay/channel/aws/adaptor.go

@@ -1,14 +1,17 @@
 package aws
 
 import (
-	"errors"
 	"io"
 	"net/http"
 
+	"github.com/QuantumNous/new-api/common"
 	"github.com/QuantumNous/new-api/dto"
 	"github.com/QuantumNous/new-api/relay/channel/claude"
 	relaycommon "github.com/QuantumNous/new-api/relay/common"
 	"github.com/QuantumNous/new-api/types"
+	"github.com/aws/aws-sdk-go-v2/aws"
+	"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
+	"github.com/pkg/errors"
 
 	"github.com/gin-gonic/gin"
 )
@@ -19,7 +22,10 @@ const (
 )
 
 type Adaptor struct {
-	RequestMode int
+	AwsClient  *bedrockruntime.Client
+	AwsModelId string
+	AwsReq     any
+	IsNova     bool
 }
 
 func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
@@ -28,8 +34,6 @@ func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dt
 }
 
 func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
-	c.Set("request_model", request.Model)
-	c.Set("converted_request", request)
 	return request, nil
 }
 
@@ -44,7 +48,6 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
 }
 
 func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
-	a.RequestMode = RequestModeMessage
 }
 
 func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
@@ -63,9 +66,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
 	// 检查是否为Nova模型
 	if isNovaModel(request.Model) {
 		novaReq := convertToNovaRequest(request)
-		c.Set("request_model", request.Model)
-		c.Set("converted_request", novaReq)
-		c.Set("is_nova_model", true)
+		a.IsNova = true
 		return novaReq, nil
 	}
 
@@ -76,9 +77,6 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
 	if err != nil {
 		return nil, err
 	}
-	c.Set("request_model", claudeReq.Model)
-	c.Set("converted_request", claudeReq)
-	c.Set("is_nova_model", false)
 	return claudeReq, err
 }
 
@@ -97,14 +95,83 @@ func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommo
 }
 
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
-	return nil, nil
+	awsCli, err := newAwsClient(c, info)
+	if err != nil {
+		return nil, types.NewError(err, types.ErrorCodeChannelAwsClientError)
+	}
+	a.AwsClient = awsCli
+
+	awsModelId := awsModelID(info.UpstreamModelName)
+
+	awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region)
+	canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
+	if canCrossRegion {
+		awsModelId = awsModelCrossRegion(awsModelId, awsRegionPrefix)
+	}
+
+	if isNovaModel(awsModelId) {
+		var novaReq *NovaRequest
+		err = common.DecodeJson(requestBody, &novaReq)
+		if err != nil {
+			return nil, types.NewError(errors.Wrap(err, "decode nova request fail"), types.ErrorCodeBadRequestBody)
+		}
+
+		// 使用InvokeModel API,但使用Nova格式的请求体
+		awsReq := &bedrockruntime.InvokeModelInput{
+			ModelId:     aws.String(awsModelId),
+			Accept:      aws.String("application/json"),
+			ContentType: aws.String("application/json"),
+		}
+
+		reqBody, err := common.Marshal(novaReq)
+		if err != nil {
+			return nil, types.NewError(errors.Wrap(err, "marshal nova request"), types.ErrorCodeBadResponseBody)
+		}
+		awsReq.Body = reqBody
+		return nil, nil
+	} else {
+		awsClaudeReq, err := formatRequest(requestBody)
+		if err != nil {
+			return nil, types.NewError(errors.Wrap(err, "format aws request fail"), types.ErrorCodeBadRequestBody)
+		}
+
+		if info.IsStream {
+			awsReq := &bedrockruntime.InvokeModelWithResponseStreamInput{
+				ModelId:     aws.String(awsModelId),
+				Accept:      aws.String("application/json"),
+				ContentType: aws.String("application/json"),
+			}
+			awsReq.Body, err = common.Marshal(awsClaudeReq)
+			if err != nil {
+				return nil, types.NewError(errors.Wrap(err, "marshal aws request fail"), types.ErrorCodeBadRequestBody)
+			}
+			a.AwsReq = awsReq
+			return nil, nil
+		} else {
+			awsReq := &bedrockruntime.InvokeModelInput{
+				ModelId:     aws.String(awsModelId),
+				Accept:      aws.String("application/json"),
+				ContentType: aws.String("application/json"),
+			}
+			awsReq.Body, err = common.Marshal(awsClaudeReq)
+			if err != nil {
+				return nil, types.NewError(errors.Wrap(err, "marshal aws request fail"), types.ErrorCodeBadRequestBody)
+			}
+			a.AwsReq = awsReq
+			return nil, nil
+		}
+	}
 }
 
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
-	if info.IsStream {
-		err, usage = awsStreamHandler(c, resp, info, a.RequestMode)
+	if a.IsNova {
+		err, usage = handleNovaRequest(c, info, a)
 	} else {
-		err, usage = awsHandler(c, info, a.RequestMode)
+		if info.IsStream {
+			err, usage = awsStreamHandler(c, info, a)
+		} else {
+			err, usage = awsHandler(c, info, a)
+		}
 	}
 	return
 }

+ 1 - 1
relay/channel/aws/constants.go

@@ -124,5 +124,5 @@ var ChannelName = "aws"
 
 // 判断是否为Nova模型
 func isNovaModel(modelId string) bool {
-	return strings.HasPrefix(modelId, "nova-")
+	return strings.Contains(modelId, "nova-")
 }

+ 13 - 0
relay/channel/aws/dto.go

@@ -1,6 +1,9 @@
 package aws
 
 import (
+	"io"
+
+	"github.com/QuantumNous/new-api/common"
 	"github.com/QuantumNous/new-api/dto"
 )
 
@@ -35,6 +38,16 @@ func copyRequest(req *dto.ClaudeRequest) *AwsClaudeRequest {
 	}
 }
 
+func formatRequest(requestBody io.Reader) (*AwsClaudeRequest, error) {
+	var awsClaudeRequest AwsClaudeRequest
+	err := common.DecodeJson(requestBody, &awsClaudeRequest)
+	if err != nil {
+		return nil, err
+	}
+	awsClaudeRequest.AnthropicVersion = "bedrock-2023-05-31"
+	return &awsClaudeRequest, nil
+}
+
 // NovaMessage Nova模型使用messages-v1格式
 type NovaMessage struct {
 	Role    string        `json:"role"`

+ 6 - 96
relay/channel/aws/relay-aws.go

@@ -88,50 +88,9 @@ func awsModelID(requestModel string) string {
 	return requestModel
 }
 
-func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*types.NewAPIError, *dto.Usage) {
-	awsCli, err := newAwsClient(c, info)
-	if err != nil {
-		return types.NewError(err, types.ErrorCodeChannelAwsClientError), nil
-	}
-
-	awsModelId := awsModelID(c.GetString("request_model"))
-	// 检查是否为Nova模型
-	isNova, _ := c.Get("is_nova_model")
-	if isNova == true {
-		// Nova模型也支持跨区域
-		awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region)
-		canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
-		if canCrossRegion {
-			awsModelId = awsModelCrossRegion(awsModelId, awsRegionPrefix)
-		}
-		return handleNovaRequest(c, awsCli, info, awsModelId)
-	}
-
-	// 原有的Claude处理逻辑
-	awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region)
-	canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
-	if canCrossRegion {
-		awsModelId = awsModelCrossRegion(awsModelId, awsRegionPrefix)
-	}
-
-	awsReq := &bedrockruntime.InvokeModelInput{
-		ModelId:     aws.String(awsModelId),
-		Accept:      aws.String("application/json"),
-		ContentType: aws.String("application/json"),
-	}
-
-	claudeReq_, ok := c.Get("converted_request")
-	if !ok {
-		return types.NewError(errors.New("aws claude request not found"), types.ErrorCodeInvalidRequest), nil
-	}
-	claudeReq := claudeReq_.(*dto.ClaudeRequest)
-	awsClaudeReq := copyRequest(claudeReq)
-	awsReq.Body, err = common.Marshal(awsClaudeReq)
-	if err != nil {
-		return types.NewError(errors.Wrap(err, "marshal request"), types.ErrorCodeBadResponseBody), nil
-	}
+func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor) (*types.NewAPIError, *dto.Usage) {
 
-	awsResp, err := awsCli.InvokeModel(c.Request.Context(), awsReq)
+	awsResp, err := a.AwsClient.InvokeModel(c.Request.Context(), a.AwsReq.(*bedrockruntime.InvokeModelInput))
 	if err != nil {
 		return types.NewOpenAIError(errors.Wrap(err, "InvokeModel"), types.ErrorCodeAwsInvokeError, http.StatusInternalServerError), nil
 	}
@@ -156,39 +115,8 @@ func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*
 	return nil, claudeInfo.Usage
 }
 
-func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*types.NewAPIError, *dto.Usage) {
-	awsCli, err := newAwsClient(c, info)
-	if err != nil {
-		return types.NewError(err, types.ErrorCodeChannelAwsClientError), nil
-	}
-
-	awsModelId := awsModelID(c.GetString("request_model"))
-
-	awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region)
-	canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
-	if canCrossRegion {
-		awsModelId = awsModelCrossRegion(awsModelId, awsRegionPrefix)
-	}
-
-	awsReq := &bedrockruntime.InvokeModelWithResponseStreamInput{
-		ModelId:     aws.String(awsModelId),
-		Accept:      aws.String("application/json"),
-		ContentType: aws.String("application/json"),
-	}
-
-	claudeReq_, ok := c.Get("converted_request")
-	if !ok {
-		return types.NewError(errors.New("aws claude request not found"), types.ErrorCodeInvalidRequest), nil
-	}
-	claudeReq := claudeReq_.(*dto.ClaudeRequest)
-
-	awsClaudeReq := copyRequest(claudeReq)
-	awsReq.Body, err = common.Marshal(awsClaudeReq)
-	if err != nil {
-		return types.NewError(errors.Wrap(err, "marshal request"), types.ErrorCodeBadResponseBody), nil
-	}
-
-	awsResp, err := awsCli.InvokeModelWithResponseStream(c.Request.Context(), awsReq)
+func awsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor) (*types.NewAPIError, *dto.Usage) {
+	awsResp, err := a.AwsClient.InvokeModelWithResponseStream(c.Request.Context(), a.AwsReq.(*bedrockruntime.InvokeModelWithResponseStreamInput))
 	if err != nil {
 		return types.NewOpenAIError(errors.Wrap(err, "InvokeModelWithResponseStream"), types.ErrorCodeAwsInvokeError, http.StatusInternalServerError), nil
 	}
@@ -225,27 +153,9 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
 }
 
 // Nova模型处理函数
-func handleNovaRequest(c *gin.Context, awsCli *bedrockruntime.Client, info *relaycommon.RelayInfo, awsModelId string) (*types.NewAPIError, *dto.Usage) {
-	novaReq_, ok := c.Get("converted_request")
-	if !ok {
-		return types.NewError(errors.New("nova request not found"), types.ErrorCodeInvalidRequest), nil
-	}
-	novaReq := novaReq_.(*NovaRequest)
-
-	// 使用InvokeModel API,但使用Nova格式的请求体
-	awsReq := &bedrockruntime.InvokeModelInput{
-		ModelId:     aws.String(awsModelId),
-		Accept:      aws.String("application/json"),
-		ContentType: aws.String("application/json"),
-	}
-
-	reqBody, err := json.Marshal(novaReq)
-	if err != nil {
-		return types.NewError(errors.Wrap(err, "marshal nova request"), types.ErrorCodeBadResponseBody), nil
-	}
-	awsReq.Body = reqBody
+func handleNovaRequest(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor) (*types.NewAPIError, *dto.Usage) {
 
-	awsResp, err := awsCli.InvokeModel(c.Request.Context(), awsReq)
+	awsResp, err := a.AwsClient.InvokeModel(c.Request.Context(), a.AwsReq.(*bedrockruntime.InvokeModelInput))
 	if err != nil {
 		return types.NewError(errors.Wrap(err, "InvokeModel"), types.ErrorCodeChannelAwsClientError), nil
 	}

+ 3 - 0
types/error.go

@@ -62,6 +62,9 @@ const (
 	ErrorCodeConvertRequestFailed  ErrorCode = "convert_request_failed"
 	ErrorCodeAccessDenied          ErrorCode = "access_denied"
 
+	// request error
+	ErrorCodeBadRequestBody ErrorCode = "bad_request_body"
+
 	// response error
 	ErrorCodeReadResponseBodyFailed ErrorCode = "read_response_body_failed"
 	ErrorCodeBadResponseStatusCode  ErrorCode = "bad_response_status_code"