Browse Source

Add AWS cross-region inference toggle

Saoud Rizwan 1 year ago
parent
commit
ad29ff2a03

+ 21 - 1
src/api/providers/bedrock.ts

@@ -25,8 +25,28 @@ export class AwsBedrockHandler implements ApiHandler {
 	}
 	}
 
 
 	async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
 	async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
+		// cross region inference requires prefixing the model id with the region
+		let modelId: string
+		if (this.options.awsUseCrossRegionInference) {
+			let regionPrefix = (this.options.awsRegion || "").slice(0, 3)
+			switch (regionPrefix) {
+				case "us-":
+					modelId = `us.${this.getModel().id}`
+					break
+				case "eu-":
+					modelId = `eu.${this.getModel().id}`
+					break
+				default:
+					// cross region inference is not supported in this region, falling back to default model
+					modelId = this.getModel().id
+					break
+			}
+		} else {
+			modelId = this.getModel().id
+		}
+
 		const stream = await this.client.messages.create({
 		const stream = await this.client.messages.create({
-			model: this.getModel().id,
+			model: modelId,
 			max_tokens: this.getModel().info.maxTokens || 8192,
 			max_tokens: this.getModel().info.maxTokens || 8192,
 			temperature: 0,
 			temperature: 0,
 			system: systemPrompt,
 			system: systemPrompt,

+ 6 - 0
src/core/webview/ClineProvider.ts

@@ -40,6 +40,7 @@ type GlobalStateKey =
 	| "apiProvider"
 	| "apiProvider"
 	| "apiModelId"
 	| "apiModelId"
 	| "awsRegion"
 	| "awsRegion"
+	| "awsUseCrossRegionInference"
 	| "vertexProjectId"
 	| "vertexProjectId"
 	| "vertexRegion"
 	| "vertexRegion"
 	| "lastShownAnnouncementId"
 	| "lastShownAnnouncementId"
@@ -350,6 +351,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 								awsSecretKey,
 								awsSecretKey,
 								awsSessionToken,
 								awsSessionToken,
 								awsRegion,
 								awsRegion,
+								awsUseCrossRegionInference,
 								vertexProjectId,
 								vertexProjectId,
 								vertexRegion,
 								vertexRegion,
 								openAiBaseUrl,
 								openAiBaseUrl,
@@ -372,6 +374,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 							await this.storeSecret("awsSecretKey", awsSecretKey)
 							await this.storeSecret("awsSecretKey", awsSecretKey)
 							await this.storeSecret("awsSessionToken", awsSessionToken)
 							await this.storeSecret("awsSessionToken", awsSessionToken)
 							await this.updateGlobalState("awsRegion", awsRegion)
 							await this.updateGlobalState("awsRegion", awsRegion)
+							await this.updateGlobalState("awsUseCrossRegionInference", awsUseCrossRegionInference)
 							await this.updateGlobalState("vertexProjectId", vertexProjectId)
 							await this.updateGlobalState("vertexProjectId", vertexProjectId)
 							await this.updateGlobalState("vertexRegion", vertexRegion)
 							await this.updateGlobalState("vertexRegion", vertexRegion)
 							await this.updateGlobalState("openAiBaseUrl", openAiBaseUrl)
 							await this.updateGlobalState("openAiBaseUrl", openAiBaseUrl)
@@ -824,6 +827,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 			awsSecretKey,
 			awsSecretKey,
 			awsSessionToken,
 			awsSessionToken,
 			awsRegion,
 			awsRegion,
+			awsUseCrossRegionInference,
 			vertexProjectId,
 			vertexProjectId,
 			vertexRegion,
 			vertexRegion,
 			openAiBaseUrl,
 			openAiBaseUrl,
@@ -850,6 +854,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 			this.getSecret("awsSecretKey") as Promise<string | undefined>,
 			this.getSecret("awsSecretKey") as Promise<string | undefined>,
 			this.getSecret("awsSessionToken") as Promise<string | undefined>,
 			this.getSecret("awsSessionToken") as Promise<string | undefined>,
 			this.getGlobalState("awsRegion") as Promise<string | undefined>,
 			this.getGlobalState("awsRegion") as Promise<string | undefined>,
+			this.getGlobalState("awsUseCrossRegionInference") as Promise<boolean | undefined>,
 			this.getGlobalState("vertexProjectId") as Promise<string | undefined>,
 			this.getGlobalState("vertexProjectId") as Promise<string | undefined>,
 			this.getGlobalState("vertexRegion") as Promise<string | undefined>,
 			this.getGlobalState("vertexRegion") as Promise<string | undefined>,
 			this.getGlobalState("openAiBaseUrl") as Promise<string | undefined>,
 			this.getGlobalState("openAiBaseUrl") as Promise<string | undefined>,
@@ -893,6 +898,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 				awsSecretKey,
 				awsSecretKey,
 				awsSessionToken,
 				awsSessionToken,
 				awsRegion,
 				awsRegion,
+				awsUseCrossRegionInference,
 				vertexProjectId,
 				vertexProjectId,
 				vertexRegion,
 				vertexRegion,
 				openAiBaseUrl,
 				openAiBaseUrl,

+ 1 - 0
src/shared/api.ts

@@ -19,6 +19,7 @@ export interface ApiHandlerOptions {
 	awsSecretKey?: string
 	awsSecretKey?: string
 	awsSessionToken?: string
 	awsSessionToken?: string
 	awsRegion?: string
 	awsRegion?: string
+	awsUseCrossRegionInference?: boolean
 	vertexProjectId?: string
 	vertexProjectId?: string
 	vertexRegion?: string
 	vertexRegion?: string
 	openAiBaseUrl?: string
 	openAiBaseUrl?: string

+ 8 - 0
webview-ui/src/components/settings/ApiOptions.tsx

@@ -307,6 +307,14 @@ const ApiOptions = ({ showModelOptions, apiErrorMessage, modelIdErrorMessage }:
 							{/* <VSCodeOption value="us-gov-east-1">us-gov-east-1</VSCodeOption> */}
 							{/* <VSCodeOption value="us-gov-east-1">us-gov-east-1</VSCodeOption> */}
 						</VSCodeDropdown>
 						</VSCodeDropdown>
 					</div>
 					</div>
+					<VSCodeCheckbox
+						checked={apiConfiguration?.awsUseCrossRegionInference || false}
+						onChange={(e: any) => {
+							const isChecked = e.target.checked === true
+							setApiConfiguration({ ...apiConfiguration, awsUseCrossRegionInference: isChecked })
+						}}>
+						Use cross-region inference
+					</VSCodeCheckbox>
 					<p
 					<p
 						style={{
 						style={{
 							fontSize: "12px",
 							fontSize: "12px",