Explorar o código

Merge pull request #872 from neotf/main

feat: support AWS Model CrossRegion
Calcium-Ion hai 9 meses
pai
achega
b5aa3c129b
Modificáronse 2 ficheiros con 65 adicións e 0 borrados
  1. 37 0
      relay/channel/aws/constants.go
  2. 28 0
      relay/channel/aws/relay-aws.go

+ 37 - 0
relay/channel/aws/constants.go

@@ -13,4 +13,41 @@ var awsModelIDMap = map[string]string{
 	"claude-3-7-sonnet-20250219": "anthropic.claude-3-7-sonnet-20250219-v1:0",
 }
 
+var awsModelCanCrossRegionMap = map[string]map[string]bool{
+	"anthropic.claude-3-sonnet-20240229-v1:0": {
+		"us": true,
+		"eu": true,
+		"ap": true,
+	},
+	"anthropic.claude-3-opus-20240229-v1:0": {
+		"us": true,
+	},
+	"anthropic.claude-3-haiku-20240307-v1:0": {
+		"us": true,
+		"eu": true,
+		"ap": true,
+	},
+	"anthropic.claude-3-5-sonnet-20240620-v1:0": {
+		"us": true,
+		"eu": true,
+		"ap": true,
+	},
+	"anthropic.claude-3-5-sonnet-20241022-v2:0": {
+		"us": true,
+		"ap": true,
+	},
+	"anthropic.claude-3-5-haiku-20241022-v1:0": {
+		"us": true,
+	},
+	"anthropic.claude-3-7-sonnet-20250219-v1:0": {
+		"us": true,
+	},
+}
+
+var awsRegionCrossModelPrefixMap = map[string]string{
+	"us": "us",
+	"eu": "eu",
+	"ap": "apac",
+}
+
 var ChannelName = "aws"

+ 28 - 0
relay/channel/aws/relay-aws.go

@@ -43,6 +43,28 @@ func wrapErr(err error) *dto.OpenAIErrorWithStatusCode {
 	}
 }
 
+func awsRegionPrefix(awsRegionId string) string {
+	parts := strings.Split(awsRegionId, "-")
+	regionPrefix := ""
+	if len(parts) > 0 {
+		regionPrefix = parts[0]
+	}
+	return regionPrefix
+}
+
+func awsModelCanCrossRegion(awsModelId, awsRegionPrefix string) bool {
+	regionSet, exists := awsModelCanCrossRegionMap[awsModelId]
+	return exists && regionSet[awsRegionPrefix]
+}
+
+func awsModelCrossRegion(awsModelId, awsRegionPrefix string) string {
+	modelPrefix, find := awsRegionCrossModelPrefixMap[awsRegionPrefix]
+	if !find {
+		return awsModelId
+	}
+	return modelPrefix + "." + awsModelId
+}
+
 func awsModelID(requestModel string) (string, error) {
 	if awsModelID, ok := awsModelIDMap[requestModel]; ok {
 		return awsModelID, nil
@@ -62,6 +84,12 @@ func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*
 		return wrapErr(errors.Wrap(err, "awsModelID")), nil
 	}
 
+	awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region)
+	canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
+	if canCrossRegion {
+		awsModelId = awsModelCrossRegion(awsModelId, awsRegionPrefix)
+	}
+
 	awsReq := &bedrockruntime.InvokeModelInput{
 		ModelId:     aws.String(awsModelId),
 		Accept:      aws.String("application/json"),