Browse Source

Feature/vertex ai model name conversion (#5728)

Co-authored-by: Claude <[email protected]>
Co-authored-by: Daniel Riccio <[email protected]>
flameboy 6 tháng trước cách đây
mục cha
commit
9db64de363

+ 41 - 0
packages/types/src/providers/__tests__/claude-code.spec.ts

@@ -0,0 +1,41 @@
+import { describe, test, expect } from "vitest"
+import { convertModelNameForVertex, getClaudeCodeModelId } from "../claude-code.js"
+
+describe("convertModelNameForVertex", () => {
+	test("should convert hyphen-date format to @date format", () => {
+		expect(convertModelNameForVertex("claude-sonnet-4-20250514")).toBe("claude-sonnet-4@20250514")
+		expect(convertModelNameForVertex("claude-opus-4-20250514")).toBe("claude-opus-4@20250514")
+		expect(convertModelNameForVertex("claude-3-7-sonnet-20250219")).toBe("claude-3-7-sonnet@20250219")
+		expect(convertModelNameForVertex("claude-3-5-sonnet-20241022")).toBe("claude-3-5-sonnet@20241022")
+		expect(convertModelNameForVertex("claude-3-5-haiku-20241022")).toBe("claude-3-5-haiku@20241022")
+	})
+
+	test("should not modify models without date pattern", () => {
+		expect(convertModelNameForVertex("some-other-model")).toBe("some-other-model")
+		expect(convertModelNameForVertex("claude-model")).toBe("claude-model")
+		expect(convertModelNameForVertex("model-with-short-date-123")).toBe("model-with-short-date-123")
+	})
+
+	test("should only convert 8-digit date patterns at the end", () => {
+		expect(convertModelNameForVertex("claude-20250514-sonnet")).toBe("claude-20250514-sonnet")
+		expect(convertModelNameForVertex("model-20250514-with-more")).toBe("model-20250514-with-more")
+	})
+})
+
+describe("getClaudeCodeModelId", () => {
+	test("should return original model when useVertex is false", () => {
+		expect(getClaudeCodeModelId("claude-sonnet-4-20250514", false)).toBe("claude-sonnet-4-20250514")
+		expect(getClaudeCodeModelId("claude-opus-4-20250514", false)).toBe("claude-opus-4-20250514")
+		expect(getClaudeCodeModelId("claude-3-7-sonnet-20250219", false)).toBe("claude-3-7-sonnet-20250219")
+	})
+
+	test("should return converted model when useVertex is true", () => {
+		expect(getClaudeCodeModelId("claude-sonnet-4-20250514", true)).toBe("claude-sonnet-4@20250514")
+		expect(getClaudeCodeModelId("claude-opus-4-20250514", true)).toBe("claude-opus-4@20250514")
+		expect(getClaudeCodeModelId("claude-3-7-sonnet-20250219", true)).toBe("claude-3-7-sonnet@20250219")
+	})
+
+	test("should default to useVertex false when parameter not provided", () => {
+		expect(getClaudeCodeModelId("claude-sonnet-4-20250514")).toBe("claude-sonnet-4-20250514")
+	})
+})

+ 34 - 0
packages/types/src/providers/claude-code.ts

@@ -1,10 +1,44 @@
 import type { ModelInfo } from "../model.js"
 import { anthropicModels } from "./anthropic.js"
 
+// Regex pattern to match 8-digit date at the end of model names
+const VERTEX_DATE_PATTERN = /-(\d{8})$/
+
+/**
+ * Converts Claude model names from hyphen-date format to Vertex AI's @-date format.
+ *
+ * @param modelName - The original model name (e.g., "claude-sonnet-4-20250514")
+ * @returns The converted model name for Vertex AI (e.g., "claude-sonnet-4@20250514")
+ *
+ * @example
+ * convertModelNameForVertex("claude-sonnet-4-20250514") // returns "claude-sonnet-4@20250514"
+ * convertModelNameForVertex("claude-model") // returns "claude-model" (no change)
+ */
+export function convertModelNameForVertex(modelName: string): string {
+	// Convert hyphen-date format to @date format for Vertex AI
+	return modelName.replace(VERTEX_DATE_PATTERN, "@$1")
+}
+
 // Claude Code
 export type ClaudeCodeModelId = keyof typeof claudeCodeModels
 export const claudeCodeDefaultModelId: ClaudeCodeModelId = "claude-sonnet-4-20250514"
 export const CLAUDE_CODE_DEFAULT_MAX_OUTPUT_TOKENS = 8000
+
+/**
+ * Gets the appropriate model ID based on whether Vertex AI is being used.
+ *
+ * @param baseModelId - The base Claude Code model ID
+ * @param useVertex - Whether to format the model ID for Vertex AI (default: false)
+ * @returns The model ID, potentially formatted for Vertex AI
+ *
+ * @example
+ * getClaudeCodeModelId("claude-sonnet-4-20250514", true) // returns "claude-sonnet-4@20250514"
+ * getClaudeCodeModelId("claude-sonnet-4-20250514", false) // returns "claude-sonnet-4-20250514"
+ */
+export function getClaudeCodeModelId(baseModelId: ClaudeCodeModelId, useVertex = false): string {
+	return useVertex ? convertModelNameForVertex(baseModelId) : baseModelId
+}
+
 export const claudeCodeModels = {
 	"claude-sonnet-4-20250514": {
 		...anthropicModels["claude-sonnet-4-20250514"],

+ 14 - 2
src/api/providers/claude-code.ts

@@ -1,5 +1,11 @@
 import type { Anthropic } from "@anthropic-ai/sdk"
-import { claudeCodeDefaultModelId, type ClaudeCodeModelId, claudeCodeModels, type ModelInfo } from "@roo-code/types"
+import {
+	claudeCodeDefaultModelId,
+	type ClaudeCodeModelId,
+	claudeCodeModels,
+	type ModelInfo,
+	getClaudeCodeModelId,
+} from "@roo-code/types"
 import { type ApiHandler } from ".."
 import { ApiStreamUsageChunk, type ApiStream } from "../transform/stream"
 import { runClaudeCode } from "../../integrations/claude-code/run"
@@ -20,11 +26,17 @@ export class ClaudeCodeHandler extends BaseProvider implements ApiHandler {
 		// Filter out image blocks since Claude Code doesn't support them
 		const filteredMessages = filterMessagesForClaudeCode(messages)
 
+		const useVertex = process.env.CLAUDE_CODE_USE_VERTEX === "1"
+		const model = this.getModel()
+
+		// Validate that the model ID is a valid ClaudeCodeModelId
+		const modelId = model.id in claudeCodeModels ? (model.id as ClaudeCodeModelId) : claudeCodeDefaultModelId
+
 		const claudeProcess = runClaudeCode({
 			systemPrompt,
 			messages: filteredMessages,
 			path: this.options.claudeCodePath,
-			modelId: this.getModel().id,
+			modelId: getClaudeCodeModelId(modelId, useVertex),
 			maxOutputTokens: this.options.claudeCodeMaxOutputTokens,
 		})