|
|
@@ -25,7 +25,7 @@ vi.mock("@aws-sdk/client-bedrock-runtime", () => {
|
|
|
|
|
|
import { AwsBedrockHandler } from "../bedrock"
|
|
|
import { ConverseStreamCommand, BedrockRuntimeClient } from "@aws-sdk/client-bedrock-runtime"
|
|
|
-import { BEDROCK_1M_CONTEXT_MODEL_IDS } from "@roo-code/types"
|
|
|
+import { BEDROCK_1M_CONTEXT_MODEL_IDS, BEDROCK_SERVICE_TIER_MODEL_IDS, bedrockModels } from "@roo-code/types"
|
|
|
|
|
|
import type { Anthropic } from "@anthropic-ai/sdk"
|
|
|
|
|
|
@@ -755,4 +755,245 @@ describe("AwsBedrockHandler", () => {
|
|
|
expect(commandArg.modelId).toBe(`us.${BEDROCK_1M_CONTEXT_MODEL_IDS[0]}`)
|
|
|
})
|
|
|
})
|
|
|
+
|
|
|
+ describe("service tier feature", () => {
|
|
|
+ const supportedModelId = BEDROCK_SERVICE_TIER_MODEL_IDS[0] // amazon.nova-lite-v1:0
|
|
|
+
|
|
|
+ beforeEach(() => {
|
|
|
+ mockConverseStreamCommand.mockReset()
|
|
|
+ })
|
|
|
+
|
|
|
+ describe("pricing multipliers in getModel()", () => {
|
|
|
+ it("should apply FLEX tier pricing with 50% discount", () => {
|
|
|
+ const handler = new AwsBedrockHandler({
|
|
|
+ apiModelId: supportedModelId,
|
|
|
+ awsAccessKey: "test",
|
|
|
+ awsSecretKey: "test",
|
|
|
+ awsRegion: "us-east-1",
|
|
|
+ awsBedrockServiceTier: "FLEX",
|
|
|
+ })
|
|
|
+
|
|
|
+ const model = handler.getModel()
|
|
|
+ const baseModel = bedrockModels[supportedModelId as keyof typeof bedrockModels] as {
|
|
|
+ inputPrice: number
|
|
|
+ outputPrice: number
|
|
|
+ }
|
|
|
+
|
|
|
+ // FLEX tier should apply 0.5 multiplier (50% discount)
|
|
|
+ expect(model.info.inputPrice).toBe(baseModel.inputPrice * 0.5)
|
|
|
+ expect(model.info.outputPrice).toBe(baseModel.outputPrice * 0.5)
|
|
|
+ })
|
|
|
+
|
|
|
+ it("should apply PRIORITY tier pricing with 75% premium", () => {
|
|
|
+ const handler = new AwsBedrockHandler({
|
|
|
+ apiModelId: supportedModelId,
|
|
|
+ awsAccessKey: "test",
|
|
|
+ awsSecretKey: "test",
|
|
|
+ awsRegion: "us-east-1",
|
|
|
+ awsBedrockServiceTier: "PRIORITY",
|
|
|
+ })
|
|
|
+
|
|
|
+ const model = handler.getModel()
|
|
|
+ const baseModel = bedrockModels[supportedModelId as keyof typeof bedrockModels] as {
|
|
|
+ inputPrice: number
|
|
|
+ outputPrice: number
|
|
|
+ }
|
|
|
+
|
|
|
+ // PRIORITY tier should apply 1.75 multiplier (75% premium)
|
|
|
+ expect(model.info.inputPrice).toBe(baseModel.inputPrice * 1.75)
|
|
|
+ expect(model.info.outputPrice).toBe(baseModel.outputPrice * 1.75)
|
|
|
+ })
|
|
|
+
|
|
|
+ it("should not modify pricing for STANDARD tier", () => {
|
|
|
+ const handler = new AwsBedrockHandler({
|
|
|
+ apiModelId: supportedModelId,
|
|
|
+ awsAccessKey: "test",
|
|
|
+ awsSecretKey: "test",
|
|
|
+ awsRegion: "us-east-1",
|
|
|
+ awsBedrockServiceTier: "STANDARD",
|
|
|
+ })
|
|
|
+
|
|
|
+ const model = handler.getModel()
|
|
|
+ const baseModel = bedrockModels[supportedModelId as keyof typeof bedrockModels] as {
|
|
|
+ inputPrice: number
|
|
|
+ outputPrice: number
|
|
|
+ }
|
|
|
+
|
|
|
+ // STANDARD tier should not modify pricing (1.0 multiplier)
|
|
|
+ expect(model.info.inputPrice).toBe(baseModel.inputPrice)
|
|
|
+ expect(model.info.outputPrice).toBe(baseModel.outputPrice)
|
|
|
+ })
|
|
|
+
|
|
|
+ it("should not apply service tier pricing for unsupported models", () => {
|
|
|
+ const unsupportedModelId = "anthropic.claude-3-5-sonnet-20241022-v2:0"
|
|
|
+ const handler = new AwsBedrockHandler({
|
|
|
+ apiModelId: unsupportedModelId,
|
|
|
+ awsAccessKey: "test",
|
|
|
+ awsSecretKey: "test",
|
|
|
+ awsRegion: "us-east-1",
|
|
|
+ awsBedrockServiceTier: "FLEX", // Try to apply FLEX tier
|
|
|
+ })
|
|
|
+
|
|
|
+ const model = handler.getModel()
|
|
|
+ const baseModel = bedrockModels[unsupportedModelId as keyof typeof bedrockModels] as {
|
|
|
+ inputPrice: number
|
|
|
+ outputPrice: number
|
|
|
+ }
|
|
|
+
|
|
|
+ // Pricing should remain unchanged for unsupported models
|
|
|
+ expect(model.info.inputPrice).toBe(baseModel.inputPrice)
|
|
|
+ expect(model.info.outputPrice).toBe(baseModel.outputPrice)
|
|
|
+ })
|
|
|
+ })
|
|
|
+
|
|
|
+ describe("service_tier parameter in API requests", () => {
|
|
|
+ it("should include service_tier as top-level parameter for supported models", async () => {
|
|
|
+ const handler = new AwsBedrockHandler({
|
|
|
+ apiModelId: supportedModelId,
|
|
|
+ awsAccessKey: "test",
|
|
|
+ awsSecretKey: "test",
|
|
|
+ awsRegion: "us-east-1",
|
|
|
+ awsBedrockServiceTier: "PRIORITY",
|
|
|
+ })
|
|
|
+
|
|
|
+ const messages: Anthropic.Messages.MessageParam[] = [
|
|
|
+ {
|
|
|
+ role: "user",
|
|
|
+ content: "Test message",
|
|
|
+ },
|
|
|
+ ]
|
|
|
+
|
|
|
+ const generator = handler.createMessage("", messages)
|
|
|
+ await generator.next() // Start the generator
|
|
|
+
|
|
|
+ // Verify the command was created with service_tier at top level
|
|
|
+ // Per AWS documentation, service_tier must be a top-level parameter, not inside additionalModelRequestFields
|
|
|
+ // https://docs.aws.amazon.com/bedrock/latest/userguide/service-tiers-inference.html
|
|
|
+ expect(mockConverseStreamCommand).toHaveBeenCalled()
|
|
|
+ const commandArg = mockConverseStreamCommand.mock.calls[0][0] as any
|
|
|
+
|
|
|
+ // service_tier should be at the top level of the payload
|
|
|
+ expect(commandArg.service_tier).toBe("PRIORITY")
|
|
|
+ // service_tier should NOT be in additionalModelRequestFields
|
|
|
+ if (commandArg.additionalModelRequestFields) {
|
|
|
+ expect(commandArg.additionalModelRequestFields.service_tier).toBeUndefined()
|
|
|
+ }
|
|
|
+ })
|
|
|
+
|
|
|
+ it("should include service_tier FLEX as top-level parameter", async () => {
|
|
|
+ const handler = new AwsBedrockHandler({
|
|
|
+ apiModelId: supportedModelId,
|
|
|
+ awsAccessKey: "test",
|
|
|
+ awsSecretKey: "test",
|
|
|
+ awsRegion: "us-east-1",
|
|
|
+ awsBedrockServiceTier: "FLEX",
|
|
|
+ })
|
|
|
+
|
|
|
+ const messages: Anthropic.Messages.MessageParam[] = [
|
|
|
+ {
|
|
|
+ role: "user",
|
|
|
+ content: "Test message",
|
|
|
+ },
|
|
|
+ ]
|
|
|
+
|
|
|
+ const generator = handler.createMessage("", messages)
|
|
|
+ await generator.next() // Start the generator
|
|
|
+
|
|
|
+ expect(mockConverseStreamCommand).toHaveBeenCalled()
|
|
|
+ const commandArg = mockConverseStreamCommand.mock.calls[0][0] as any
|
|
|
+
|
|
|
+ // service_tier should be at the top level of the payload
|
|
|
+ expect(commandArg.service_tier).toBe("FLEX")
|
|
|
+ // service_tier should NOT be in additionalModelRequestFields
|
|
|
+ if (commandArg.additionalModelRequestFields) {
|
|
|
+ expect(commandArg.additionalModelRequestFields.service_tier).toBeUndefined()
|
|
|
+ }
|
|
|
+ })
|
|
|
+
|
|
|
+ it("should NOT include service_tier for unsupported models", async () => {
|
|
|
+ const unsupportedModelId = "anthropic.claude-3-5-sonnet-20241022-v2:0"
|
|
|
+ const handler = new AwsBedrockHandler({
|
|
|
+ apiModelId: unsupportedModelId,
|
|
|
+ awsAccessKey: "test",
|
|
|
+ awsSecretKey: "test",
|
|
|
+ awsRegion: "us-east-1",
|
|
|
+ awsBedrockServiceTier: "PRIORITY", // Try to apply PRIORITY tier
|
|
|
+ })
|
|
|
+
|
|
|
+ const messages: Anthropic.Messages.MessageParam[] = [
|
|
|
+ {
|
|
|
+ role: "user",
|
|
|
+ content: "Test message",
|
|
|
+ },
|
|
|
+ ]
|
|
|
+
|
|
|
+ const generator = handler.createMessage("", messages)
|
|
|
+ await generator.next() // Start the generator
|
|
|
+
|
|
|
+ expect(mockConverseStreamCommand).toHaveBeenCalled()
|
|
|
+ const commandArg = mockConverseStreamCommand.mock.calls[0][0] as any
|
|
|
+
|
|
|
+ // Service tier should NOT be included for unsupported models (at top level or in additionalModelRequestFields)
|
|
|
+ expect(commandArg.service_tier).toBeUndefined()
|
|
|
+ if (commandArg.additionalModelRequestFields) {
|
|
|
+ expect(commandArg.additionalModelRequestFields.service_tier).toBeUndefined()
|
|
|
+ }
|
|
|
+ })
|
|
|
+
|
|
|
+ it("should NOT include service_tier when not specified", async () => {
|
|
|
+ const handler = new AwsBedrockHandler({
|
|
|
+ apiModelId: supportedModelId,
|
|
|
+ awsAccessKey: "test",
|
|
|
+ awsSecretKey: "test",
|
|
|
+ awsRegion: "us-east-1",
|
|
|
+ // No awsBedrockServiceTier specified
|
|
|
+ })
|
|
|
+
|
|
|
+ const messages: Anthropic.Messages.MessageParam[] = [
|
|
|
+ {
|
|
|
+ role: "user",
|
|
|
+ content: "Test message",
|
|
|
+ },
|
|
|
+ ]
|
|
|
+
|
|
|
+ const generator = handler.createMessage("", messages)
|
|
|
+ await generator.next() // Start the generator
|
|
|
+
|
|
|
+ expect(mockConverseStreamCommand).toHaveBeenCalled()
|
|
|
+ const commandArg = mockConverseStreamCommand.mock.calls[0][0] as any
|
|
|
+
|
|
|
+ // Service tier should NOT be included when not specified (at top level or in additionalModelRequestFields)
|
|
|
+ expect(commandArg.service_tier).toBeUndefined()
|
|
|
+ if (commandArg.additionalModelRequestFields) {
|
|
|
+ expect(commandArg.additionalModelRequestFields.service_tier).toBeUndefined()
|
|
|
+ }
|
|
|
+ })
|
|
|
+ })
|
|
|
+
|
|
|
+ describe("service tier with cross-region inference", () => {
|
|
|
+ it("should apply service tier pricing with cross-region inference prefix", () => {
|
|
|
+ const handler = new AwsBedrockHandler({
|
|
|
+ apiModelId: supportedModelId,
|
|
|
+ awsAccessKey: "test",
|
|
|
+ awsSecretKey: "test",
|
|
|
+ awsRegion: "us-east-1",
|
|
|
+ awsUseCrossRegionInference: true,
|
|
|
+ awsBedrockServiceTier: "FLEX",
|
|
|
+ })
|
|
|
+
|
|
|
+ const model = handler.getModel()
|
|
|
+ const baseModel = bedrockModels[supportedModelId as keyof typeof bedrockModels] as {
|
|
|
+ inputPrice: number
|
|
|
+ outputPrice: number
|
|
|
+ }
|
|
|
+
|
|
|
+ // Model ID should have cross-region prefix
|
|
|
+ expect(model.id).toBe(`us.${supportedModelId}`)
|
|
|
+
|
|
|
+ // FLEX tier pricing should still be applied
|
|
|
+ expect(model.info.inputPrice).toBe(baseModel.inputPrice * 0.5)
|
|
|
+ expect(model.info.outputPrice).toBe(baseModel.outputPrice * 0.5)
|
|
|
+ })
|
|
|
+ })
|
|
|
+ })
|
|
|
})
|