|
|
@@ -88,50 +88,9 @@ func awsModelID(requestModel string) string {
|
|
|
return requestModel
|
|
|
}
|
|
|
|
|
|
-func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*types.NewAPIError, *dto.Usage) {
|
|
|
- awsCli, err := newAwsClient(c, info)
|
|
|
- if err != nil {
|
|
|
- return types.NewError(err, types.ErrorCodeChannelAwsClientError), nil
|
|
|
- }
|
|
|
-
|
|
|
- awsModelId := awsModelID(c.GetString("request_model"))
|
|
|
- // 检查是否为Nova模型
|
|
|
- isNova, _ := c.Get("is_nova_model")
|
|
|
- if isNova == true {
|
|
|
- // Nova模型也支持跨区域
|
|
|
- awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region)
|
|
|
- canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
|
|
|
- if canCrossRegion {
|
|
|
- awsModelId = awsModelCrossRegion(awsModelId, awsRegionPrefix)
|
|
|
- }
|
|
|
- return handleNovaRequest(c, awsCli, info, awsModelId)
|
|
|
- }
|
|
|
-
|
|
|
- // 原有的Claude处理逻辑
|
|
|
- awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region)
|
|
|
- canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
|
|
|
- if canCrossRegion {
|
|
|
- awsModelId = awsModelCrossRegion(awsModelId, awsRegionPrefix)
|
|
|
- }
|
|
|
-
|
|
|
- awsReq := &bedrockruntime.InvokeModelInput{
|
|
|
- ModelId: aws.String(awsModelId),
|
|
|
- Accept: aws.String("application/json"),
|
|
|
- ContentType: aws.String("application/json"),
|
|
|
- }
|
|
|
-
|
|
|
- claudeReq_, ok := c.Get("converted_request")
|
|
|
- if !ok {
|
|
|
- return types.NewError(errors.New("aws claude request not found"), types.ErrorCodeInvalidRequest), nil
|
|
|
- }
|
|
|
- claudeReq := claudeReq_.(*dto.ClaudeRequest)
|
|
|
- awsClaudeReq := copyRequest(claudeReq)
|
|
|
- awsReq.Body, err = common.Marshal(awsClaudeReq)
|
|
|
- if err != nil {
|
|
|
- return types.NewError(errors.Wrap(err, "marshal request"), types.ErrorCodeBadResponseBody), nil
|
|
|
- }
|
|
|
+func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor) (*types.NewAPIError, *dto.Usage) {
|
|
|
|
|
|
- awsResp, err := awsCli.InvokeModel(c.Request.Context(), awsReq)
|
|
|
+ awsResp, err := a.AwsClient.InvokeModel(c.Request.Context(), a.AwsReq.(*bedrockruntime.InvokeModelInput))
|
|
|
if err != nil {
|
|
|
return types.NewOpenAIError(errors.Wrap(err, "InvokeModel"), types.ErrorCodeAwsInvokeError, http.StatusInternalServerError), nil
|
|
|
}
|
|
|
@@ -156,39 +115,8 @@ func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*
|
|
|
return nil, claudeInfo.Usage
|
|
|
}
|
|
|
|
|
|
-func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*types.NewAPIError, *dto.Usage) {
|
|
|
- awsCli, err := newAwsClient(c, info)
|
|
|
- if err != nil {
|
|
|
- return types.NewError(err, types.ErrorCodeChannelAwsClientError), nil
|
|
|
- }
|
|
|
-
|
|
|
- awsModelId := awsModelID(c.GetString("request_model"))
|
|
|
-
|
|
|
- awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region)
|
|
|
- canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
|
|
|
- if canCrossRegion {
|
|
|
- awsModelId = awsModelCrossRegion(awsModelId, awsRegionPrefix)
|
|
|
- }
|
|
|
-
|
|
|
- awsReq := &bedrockruntime.InvokeModelWithResponseStreamInput{
|
|
|
- ModelId: aws.String(awsModelId),
|
|
|
- Accept: aws.String("application/json"),
|
|
|
- ContentType: aws.String("application/json"),
|
|
|
- }
|
|
|
-
|
|
|
- claudeReq_, ok := c.Get("converted_request")
|
|
|
- if !ok {
|
|
|
- return types.NewError(errors.New("aws claude request not found"), types.ErrorCodeInvalidRequest), nil
|
|
|
- }
|
|
|
- claudeReq := claudeReq_.(*dto.ClaudeRequest)
|
|
|
-
|
|
|
- awsClaudeReq := copyRequest(claudeReq)
|
|
|
- awsReq.Body, err = common.Marshal(awsClaudeReq)
|
|
|
- if err != nil {
|
|
|
- return types.NewError(errors.Wrap(err, "marshal request"), types.ErrorCodeBadResponseBody), nil
|
|
|
- }
|
|
|
-
|
|
|
- awsResp, err := awsCli.InvokeModelWithResponseStream(c.Request.Context(), awsReq)
|
|
|
+func awsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor) (*types.NewAPIError, *dto.Usage) {
|
|
|
+ awsResp, err := a.AwsClient.InvokeModelWithResponseStream(c.Request.Context(), a.AwsReq.(*bedrockruntime.InvokeModelWithResponseStreamInput))
|
|
|
if err != nil {
|
|
|
return types.NewOpenAIError(errors.Wrap(err, "InvokeModelWithResponseStream"), types.ErrorCodeAwsInvokeError, http.StatusInternalServerError), nil
|
|
|
}
|
|
|
@@ -225,27 +153,9 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
|
|
}
|
|
|
|
|
|
// Nova模型处理函数
|
|
|
-func handleNovaRequest(c *gin.Context, awsCli *bedrockruntime.Client, info *relaycommon.RelayInfo, awsModelId string) (*types.NewAPIError, *dto.Usage) {
|
|
|
- novaReq_, ok := c.Get("converted_request")
|
|
|
- if !ok {
|
|
|
- return types.NewError(errors.New("nova request not found"), types.ErrorCodeInvalidRequest), nil
|
|
|
- }
|
|
|
- novaReq := novaReq_.(*NovaRequest)
|
|
|
-
|
|
|
- // 使用InvokeModel API,但使用Nova格式的请求体
|
|
|
- awsReq := &bedrockruntime.InvokeModelInput{
|
|
|
- ModelId: aws.String(awsModelId),
|
|
|
- Accept: aws.String("application/json"),
|
|
|
- ContentType: aws.String("application/json"),
|
|
|
- }
|
|
|
-
|
|
|
- reqBody, err := json.Marshal(novaReq)
|
|
|
- if err != nil {
|
|
|
- return types.NewError(errors.Wrap(err, "marshal nova request"), types.ErrorCodeBadResponseBody), nil
|
|
|
- }
|
|
|
- awsReq.Body = reqBody
|
|
|
+func handleNovaRequest(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor) (*types.NewAPIError, *dto.Usage) {
|
|
|
|
|
|
- awsResp, err := awsCli.InvokeModel(c.Request.Context(), awsReq)
|
|
|
+ awsResp, err := a.AwsClient.InvokeModel(c.Request.Context(), a.AwsReq.(*bedrockruntime.InvokeModelInput))
|
|
|
if err != nil {
|
|
|
return types.NewError(errors.Wrap(err, "InvokeModel"), types.ErrorCodeChannelAwsClientError), nil
|
|
|
}
|