// anthropic-vertex.ts

import type { Anthropic } from "@anthropic-ai/sdk"
import { createVertexAnthropic } from "@ai-sdk/google-vertex/anthropic"
import { streamText, generateText, ToolSet } from "ai"

import {
	type ModelInfo,
	type VertexModelId,
	vertexDefaultModelId,
	vertexModels,
	ANTHROPIC_DEFAULT_MAX_TOKENS,
	VERTEX_1M_CONTEXT_MODEL_IDS,
	ApiProviderError,
} from "@roo-code/types"
import { TelemetryService } from "@roo-code/telemetry"

import type { ApiHandlerOptions } from "../../shared/api"
import { shouldUseReasoningBudget } from "../../shared/api"
import type { ApiStream, ApiStreamUsageChunk } from "../transform/stream"
import { getModelParams } from "../transform/model-params"
import {
	convertToAiSdkMessages,
	convertToolsForAiSdk,
	processAiSdkStreamPart,
	mapToolChoice,
	handleAiSdkError,
} from "../transform/ai-sdk"
import { calculateApiCostAnthropic } from "../../shared/cost"

import { DEFAULT_HEADERS } from "./constants"
import { BaseProvider } from "./base-provider"
import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index"

// https://docs.anthropic.com/en/api/claude-on-vertex-ai
export class AnthropicVertexHandler extends BaseProvider implements SingleCompletionHandler {
	protected options: ApiHandlerOptions
	private provider: ReturnType<typeof createVertexAnthropic>
	private readonly providerName = "Vertex (Anthropic)"
	private lastThoughtSignature: string | undefined
	private lastRedactedThinkingBlocks: Array<{ type: "redacted_thinking"; data: string }> = []

	constructor(options: ApiHandlerOptions) {
		super()
		this.options = options

		// https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/use-claude#regions
		const projectId = this.options.vertexProjectId ?? "not-provided"
		const region = this.options.vertexRegion ?? "us-east5"

		// Build googleAuthOptions based on the provided credentials
		let googleAuthOptions: { credentials?: object; keyFile?: string } | undefined
		if (options.vertexJsonCredentials) {
			try {
				googleAuthOptions = { credentials: JSON.parse(options.vertexJsonCredentials) }
			} catch {
				// If JSON parsing fails, leave googleAuthOptions undefined and fall back to the SDK's default auth
			}
		} else if (options.vertexKeyFile) {
			googleAuthOptions = { keyFile: options.vertexKeyFile }
		}
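		// With neither credential option set, googleAuthOptions stays undefined and the
		// underlying google-auth-library is expected to fall back to Application Default
		// Credentials (ADC).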

		// Build beta headers for 1M context support
		const modelId = options.apiModelId
		const betas: string[] = []
		if (modelId) {
			const supports1MContext = VERTEX_1M_CONTEXT_MODEL_IDS.includes(
				modelId as (typeof VERTEX_1M_CONTEXT_MODEL_IDS)[number],
			)
			if (supports1MContext && options.vertex1MContext) {
				betas.push("context-1m-2025-08-07")
			}
		}

		this.provider = createVertexAnthropic({
			project: projectId,
			location: region,
			googleAuthOptions,
			headers: {
				...DEFAULT_HEADERS,
				...(betas.length > 0 ? { "anthropic-beta": betas.join(",") } : {}),
			},
		})
	}

	override async *createMessage(
		systemPrompt: string,
		messages: Anthropic.Messages.MessageParam[],
		metadata?: ApiHandlerCreateMessageMetadata,
	): ApiStream {
		const modelConfig = this.getModel()

		// Reset thinking state for this request
		this.lastThoughtSignature = undefined
		this.lastRedactedThinkingBlocks = []

		// Convert messages to AI SDK format
		const aiSdkMessages = convertToAiSdkMessages(messages)

		// Convert tools to AI SDK format
		const openAiTools = this.convertToolsForOpenAI(metadata?.tools)
		const aiSdkTools = convertToolsForAiSdk(openAiTools) as ToolSet | undefined

		// Build Anthropic provider options
		const anthropicProviderOptions: Record<string, unknown> = {}

		// Configure thinking/reasoning if the model supports it
		const isThinkingEnabled =
			shouldUseReasoningBudget({ model: modelConfig.info, settings: this.options }) &&
			modelConfig.reasoning &&
			modelConfig.reasoningBudget
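
		// Anthropic requires the thinking budget to be at least 1024 tokens and smaller
		// than maxOutputTokens; getModelParams is assumed to have clamped reasoningBudget
		// into that range upstream.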
		if (isThinkingEnabled) {
			anthropicProviderOptions.thinking = {
				type: "enabled",
				budgetTokens: modelConfig.reasoningBudget,
			}
		}

		// Forward the parallelToolCalls setting: when it is explicitly false,
		// disable parallel tool use
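		// (in @ai-sdk/anthropic this maps to Anthropic's disable_parallel_tool_use flag)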
		if (metadata?.parallelToolCalls === false) {
			anthropicProviderOptions.disableParallelToolUse = true
		}

		/**
		 * Vertex API has specific limitations for prompt caching:
		 * 1. Maximum of 4 blocks can have cache_control
		 * 2. Only text blocks can be cached (images and other content types cannot)
		 * 3. Cache control can only be applied to user messages, not assistant messages
		 *
		 * Our caching strategy:
		 * - Cache the system prompt (1 block)
		 * - Cache the last text block of the second-to-last user message (1 block)
		 * - Cache the last text block of the last user message (1 block)
		 * This ensures we stay under the 4-block limit while maintaining effective caching
		 * for the most relevant context.
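		 *
		 * Example: with messages [user, assistant, user, assistant, user], cache_control
		 * lands on the system prompt and on the last text blocks of the third and fifth
		 * messages, so the cached prefix advances as the conversation grows.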
		 */
		const cacheProviderOption = { anthropic: { cacheControl: { type: "ephemeral" as const } } }
		const userMsgIndices = messages.reduce(
			(acc, msg, index) => (msg.role === "user" ? [...acc, index] : acc),
			[] as number[],
		)
		const targetIndices = new Set<number>()
		const lastUserMsgIndex = userMsgIndices[userMsgIndices.length - 1] ?? -1
		const secondLastUserMsgIndex = userMsgIndices[userMsgIndices.length - 2] ?? -1
		if (lastUserMsgIndex >= 0) targetIndices.add(lastUserMsgIndex)
		if (secondLastUserMsgIndex >= 0) targetIndices.add(secondLastUserMsgIndex)
		if (targetIndices.size > 0) {
			this.applyCacheControlToAiSdkMessages(messages, aiSdkMessages, targetIndices, cacheProviderOption)
		}

		// Build the streamText request. Cast the provider options to bypass strict
		// JSONObject typing; the AI SDK accepts the correct values at runtime.
		const requestOptions: Parameters<typeof streamText>[0] = {
			model: this.provider(modelConfig.id),
			system: systemPrompt,
			...({
				systemProviderOptions: { anthropic: { cacheControl: { type: "ephemeral" } } },
			} as Record<string, unknown>),
			messages: aiSdkMessages,
			temperature: modelConfig.temperature,
			maxOutputTokens: modelConfig.maxTokens ?? ANTHROPIC_DEFAULT_MAX_TOKENS,
			tools: aiSdkTools,
			toolChoice: mapToolChoice(metadata?.tool_choice),
			...(Object.keys(anthropicProviderOptions).length > 0 && {
				providerOptions: { anthropic: anthropicProviderOptions } as any,
			}),
		}

		try {
			const result = streamText(requestOptions)
			let lastStreamError: string | undefined

			for await (const part of result.fullStream) {
				// Capture the thinking signature from stream events. @ai-sdk/anthropic
				// emits the signature as a reasoning-delta event with
				// providerMetadata.anthropic.signature
				const partAny = part as any
				if (partAny.providerMetadata?.anthropic?.signature) {
					this.lastThoughtSignature = partAny.providerMetadata.anthropic.signature
				}

				// Capture redacted thinking blocks from stream events
				if (partAny.providerMetadata?.anthropic?.redactedData) {
					this.lastRedactedThinkingBlocks.push({
						type: "redacted_thinking",
						data: partAny.providerMetadata.anthropic.redactedData,
					})
				}

				for (const chunk of processAiSdkStreamPart(part)) {
					if (chunk.type === "error") {
						lastStreamError = chunk.message
					}
					yield chunk
				}
			}

			// Yield usage metrics at the end, including cache metrics from providerMetadata
			try {
				const usage = await result.usage
				const providerMetadata = await result.providerMetadata
				if (usage) {
					yield this.processUsageMetrics(usage, modelConfig.info, providerMetadata)
				}
			} catch (usageError) {
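				// Prefer an error already reported on the stream over the opaque
				// rejection from the usage promise.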
				if (lastStreamError) {
					throw new Error(lastStreamError)
				}
				throw usageError
			}
		} catch (error) {
			const errorMessage = error instanceof Error ? error.message : String(error)
			TelemetryService.instance.captureException(
				new ApiProviderError(errorMessage, this.providerName, modelConfig.id, "createMessage"),
			)
			throw handleAiSdkError(error, this.providerName)
		}
	}

	/**
	 * Process usage metrics from the AI SDK response, including Anthropic's cache metrics.
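	 *
	 * Note: under Anthropic-style pricing, cache writes and cache reads are billed at
	 * different rates than regular input tokens, which is why they are passed to the
	 * cost calculation separately.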
	 */
	private processUsageMetrics(
		usage: { inputTokens?: number; outputTokens?: number },
		info: ModelInfo,
		providerMetadata?: Record<string, Record<string, unknown>>,
	): ApiStreamUsageChunk {
		const inputTokens = usage.inputTokens ?? 0
		const outputTokens = usage.outputTokens ?? 0

		// Extract cache metrics from Anthropic's providerMetadata
		const anthropicMeta = providerMetadata?.anthropic as
			| { cacheCreationInputTokens?: number; cacheReadInputTokens?: number }
			| undefined
		const cacheWriteTokens = anthropicMeta?.cacheCreationInputTokens ?? 0
		const cacheReadTokens = anthropicMeta?.cacheReadInputTokens ?? 0

		const { totalCost } = calculateApiCostAnthropic(
			info,
			inputTokens,
			outputTokens,
			cacheWriteTokens,
			cacheReadTokens,
		)

		return {
			type: "usage",
			inputTokens,
			outputTokens,
			cacheWriteTokens: cacheWriteTokens > 0 ? cacheWriteTokens : undefined,
			cacheReadTokens: cacheReadTokens > 0 ? cacheReadTokens : undefined,
			totalCost,
		}
	}

	/**
	 * Apply cacheControl providerOptions to the correct AI SDK messages by walking
	 * the original Anthropic messages and converted AI SDK messages in parallel.
	 *
	 * convertToAiSdkMessages() can split a single Anthropic user message (containing
	 * tool_results + text) into 2 AI SDK messages (tool role + user role). This method
	 * accounts for that split so cache control lands on the right message.
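	 *
	 * Example: an Anthropic user message with content [tool_result, text] converts to
	 * [{ role: "tool", ... }, { role: "user", ... }], so the cache option must target
	 * the message at aiSdkIdx + 1 rather than aiSdkIdx.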
	 */
	private applyCacheControlToAiSdkMessages(
		originalMessages: Anthropic.Messages.MessageParam[],
		aiSdkMessages: { role: string; providerOptions?: Record<string, Record<string, unknown>> }[],
		targetOriginalIndices: Set<number>,
		cacheProviderOption: Record<string, Record<string, unknown>>,
	): void {
		let aiSdkIdx = 0
		for (let origIdx = 0; origIdx < originalMessages.length; origIdx++) {
			const origMsg = originalMessages[origIdx]
			if (typeof origMsg.content === "string") {
				// Plain string content converts 1:1
				if (targetOriginalIndices.has(origIdx) && aiSdkIdx < aiSdkMessages.length) {
					aiSdkMessages[aiSdkIdx].providerOptions = {
						...aiSdkMessages[aiSdkIdx].providerOptions,
						...cacheProviderOption,
					}
				}
				aiSdkIdx++
			} else if (origMsg.role === "user") {
				const hasToolResults = origMsg.content.some((part) => (part as { type: string }).type === "tool_result")
				const hasNonToolContent = origMsg.content.some(
					(part) => (part as { type: string }).type === "text" || (part as { type: string }).type === "image",
				)
				if (hasToolResults && hasNonToolContent) {
					// Split into a tool message followed by a user message; cache the user part
					const userMsgIdx = aiSdkIdx + 1
					if (targetOriginalIndices.has(origIdx) && userMsgIdx < aiSdkMessages.length) {
						aiSdkMessages[userMsgIdx].providerOptions = {
							...aiSdkMessages[userMsgIdx].providerOptions,
							...cacheProviderOption,
						}
					}
					aiSdkIdx += 2
				} else if (hasToolResults) {
					// Tool results only: converts to a single tool message
					if (targetOriginalIndices.has(origIdx) && aiSdkIdx < aiSdkMessages.length) {
						aiSdkMessages[aiSdkIdx].providerOptions = {
							...aiSdkMessages[aiSdkIdx].providerOptions,
							...cacheProviderOption,
						}
					}
					aiSdkIdx++
				} else {
					// Text and/or image content only: converts to a single user message
					if (targetOriginalIndices.has(origIdx) && aiSdkIdx < aiSdkMessages.length) {
						aiSdkMessages[aiSdkIdx].providerOptions = {
							...aiSdkMessages[aiSdkIdx].providerOptions,
							...cacheProviderOption,
						}
					}
					aiSdkIdx++
				}
			} else {
				// Assistant messages are never cached; just advance past them
				aiSdkIdx++
			}
		}
	}

	getModel() {
		const modelId = this.options.apiModelId
		let id = modelId && modelId in vertexModels ? (modelId as VertexModelId) : vertexDefaultModelId
		let info: ModelInfo = vertexModels[id]

		// Check whether the 1M context beta should be enabled for supported models
		const supports1MContext = VERTEX_1M_CONTEXT_MODEL_IDS.includes(
			id as (typeof VERTEX_1M_CONTEXT_MODEL_IDS)[number],
		)
		const enable1MContext = supports1MContext && this.options.vertex1MContext

		// If the 1M context beta is enabled, update the model info with tier pricing
		if (enable1MContext) {
			const tier = info.tiers?.[0]
			if (tier) {
				info = {
					...info,
					contextWindow: tier.contextWindow,
					inputPrice: tier.inputPrice,
					outputPrice: tier.outputPrice,
					cacheWritesPrice: tier.cacheWritesPrice,
					cacheReadsPrice: tier.cacheReadsPrice,
				}
			}
		}

		const params = getModelParams({
			format: "anthropic",
			modelId: id,
			model: info,
			settings: this.options,
			defaultTemperature: 0,
		})

		// Build the betas array for request headers (kept for backward compatibility / testing)
		const betas: string[] = []
		if (enable1MContext) {
			betas.push("context-1m-2025-08-07")
		}

		// The `:thinking` suffix marks a "hybrid" reasoning model for which reasoning
		// must be enabled. The actual model ID honored by Anthropic's API does not
		// carry this suffix.
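		// e.g. "claude-3-7-sonnet@20250219:thinking" is sent as "claude-3-7-sonnet@20250219".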
		return {
			id: id.endsWith(":thinking") ? id.replace(":thinking", "") : id,
			info,
			betas: betas.length > 0 ? betas : undefined,
			...params,
		}
	}

	async completePrompt(prompt: string): Promise<string> {
		const { id, temperature } = this.getModel()

		try {
			const { text } = await generateText({
				model: this.provider(id),
				prompt,
				maxOutputTokens: ANTHROPIC_DEFAULT_MAX_TOKENS,
				temperature,
			})
			return text
		} catch (error) {
			TelemetryService.instance.captureException(
				new ApiProviderError(
					error instanceof Error ? error.message : String(error),
					this.providerName,
					id,
					"completePrompt",
				),
			)
			throw handleAiSdkError(error, this.providerName)
		}
	}

	/**
	 * Returns the thinking signature captured from the last Anthropic response.
	 * Claude models with extended thinking return a cryptographic signature
	 * which must be round-tripped back for multi-turn conversations with tool use.
	 */
	getThoughtSignature(): string | undefined {
		return this.lastThoughtSignature
	}

	/**
	 * Returns any redacted thinking blocks captured from the last Anthropic response.
	 * Anthropic returns these when safety filters trigger on reasoning content.
	 */
	getRedactedThinkingBlocks(): Array<{ type: "redacted_thinking"; data: string }> | undefined {
		return this.lastRedactedThinkingBlocks.length > 0 ? this.lastRedactedThinkingBlocks : undefined
	}

	override isAiSdkProvider(): boolean {
		return true
	}
}
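
// Usage sketch (hypothetical wiring; the exact ApiHandlerOptions shape and stream
// chunk types come from this codebase's shared definitions):
//
//   const handler = new AnthropicVertexHandler({
//   	apiModelId: "claude-sonnet-4@20250514",
//   	vertexProjectId: "my-gcp-project",
//   	vertexRegion: "us-east5",
//   })
//   for await (const chunk of handler.createMessage(systemPrompt, anthropicMessages)) {
//   	if (chunk.type === "text") process.stdout.write(chunk.text)
//   }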