|
@@ -43,6 +43,28 @@ func wrapErr(err error) *dto.OpenAIErrorWithStatusCode {
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+func awsRegionPrefix(awsRegionId string) string {
|
|
|
|
|
+ parts := strings.Split(awsRegionId, "-")
|
|
|
|
|
+ regionPrefix := ""
|
|
|
|
|
+ if len(parts) > 0 {
|
|
|
|
|
+ regionPrefix = parts[0]
|
|
|
|
|
+ }
|
|
|
|
|
+ return regionPrefix
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+func awsModelCanCrossRegion(awsModelId, awsRegionPrefix string) bool {
|
|
|
|
|
+ regionSet, exists := awsModelCanCrossRegionMap[awsModelId]
|
|
|
|
|
+ return exists && regionSet[awsRegionPrefix]
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+func awsModelCrossRegion(awsModelId, awsRegionPrefix string) string {
|
|
|
|
|
+ modelPrefix, find := awsRegionCrossModelPrefixMap[awsRegionPrefix]
|
|
|
|
|
+ if !find {
|
|
|
|
|
+ return awsModelId
|
|
|
|
|
+ }
|
|
|
|
|
+ return modelPrefix + "." + awsModelId
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
func awsModelID(requestModel string) (string, error) {
|
|
func awsModelID(requestModel string) (string, error) {
|
|
|
if awsModelID, ok := awsModelIDMap[requestModel]; ok {
|
|
if awsModelID, ok := awsModelIDMap[requestModel]; ok {
|
|
|
return awsModelID, nil
|
|
return awsModelID, nil
|
|
@@ -62,6 +84,12 @@ func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*
|
|
|
return wrapErr(errors.Wrap(err, "awsModelID")), nil
|
|
return wrapErr(errors.Wrap(err, "awsModelID")), nil
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+ awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region)
|
|
|
|
|
+ canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
|
|
|
|
|
+ if canCrossRegion {
|
|
|
|
|
+ awsModelId = awsModelCrossRegion(awsModelId, awsRegionPrefix)
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
awsReq := &bedrockruntime.InvokeModelInput{
|
|
awsReq := &bedrockruntime.InvokeModelInput{
|
|
|
ModelId: aws.String(awsModelId),
|
|
ModelId: aws.String(awsModelId),
|
|
|
Accept: aws.String("application/json"),
|
|
Accept: aws.String("application/json"),
|