zai.ts 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164
  1. import { Anthropic } from "@anthropic-ai/sdk"
  2. import { createZhipu } from "zhipu-ai-provider"
  3. import { streamText, generateText, ToolSet } from "ai"
  4. import {
  5. internationalZAiModels,
  6. mainlandZAiModels,
  7. internationalZAiDefaultModelId,
  8. mainlandZAiDefaultModelId,
  9. type ModelInfo,
  10. ZAI_DEFAULT_TEMPERATURE,
  11. zaiApiLineConfigs,
  12. } from "@roo-code/types"
  13. import { type ApiHandlerOptions, shouldUseReasoningEffort } from "../../shared/api"
  14. import {
  15. convertToAiSdkMessages,
  16. convertToolsForAiSdk,
  17. consumeAiSdkStream,
  18. mapToolChoice,
  19. handleAiSdkError,
  20. } from "../transform/ai-sdk"
  21. import { applyToolCacheOptions } from "../transform/cache-breakpoints"
  22. import { ApiStream } from "../transform/stream"
  23. import { getModelParams } from "../transform/model-params"
  24. import { DEFAULT_HEADERS } from "./constants"
  25. import { BaseProvider } from "./base-provider"
  26. import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index"
  27. import type { RooMessage } from "../../core/task-persistence/rooMessage"
  28. import { sanitizeMessagesForProvider } from "../transform/sanitize-messages"
  /**
   * Z.ai provider backed by the dedicated `zhipu-ai-provider` package.
   *
   * The configured `zaiApiLine` (default `"international_coding"`) selects both the
   * API base URL and which model catalogue (mainland vs. international) is used.
   * For models whose `supportsReasoningEffort` is an array (e.g. GLM-4.7), the Zhipu
   * `thinking` switch is sent explicitly, derived from the user's reasoning setting.
   */
  export class ZAiHandler extends BaseProvider implements SingleCompletionHandler {
    protected options: ApiHandlerOptions
    protected provider: ReturnType<typeof createZhipu>
    /** When true, model lookup uses the mainland catalogue (`mainlandZAiModels`). */
    private readonly isChina: boolean

    constructor(options: ApiHandlerOptions) {
      super()
      this.options = options
      // Resolve the API line once so the base URL and the region flag always agree.
      const apiLine = zaiApiLineConfigs[options.zaiApiLine ?? "international_coding"]
      this.isChina = apiLine.isChina
      this.provider = createZhipu({
        baseURL: apiLine.baseUrl,
        apiKey: options.zaiApiKey ?? "not-provided",
        headers: DEFAULT_HEADERS,
      })
    }

    /**
     * Resolve the active model ID, its metadata, and derived request parameters.
     *
     * An `apiModelId` that is not in the region's catalogue is still used as the
     * request ID, but its metadata falls back to the region's default model.
     */
    override getModel(): { id: string; info: ModelInfo; maxTokens?: number; temperature?: number } {
      // Widen the catalogue to a string-keyed map so arbitrary IDs can be looked up.
      const models = (this.isChina ? mainlandZAiModels : internationalZAiModels) as unknown as Record<
        string,
        ModelInfo
      >
      const defaultModelId = (this.isChina ? mainlandZAiDefaultModelId : internationalZAiDefaultModelId) as string
      const id = this.options.apiModelId ?? defaultModelId
      const info = models[id] || models[defaultModelId]
      const params = getModelParams({
        format: "openai",
        modelId: id,
        model: info,
        settings: this.options,
        defaultTemperature: ZAI_DEFAULT_TEMPERATURE,
      })
      return { id, info, ...params }
    }

    /**
     * Get the AI SDK language model for the configured model ID.
     */
    protected getLanguageModel() {
      const { id } = this.getModel()
      return this.provider(id)
    }

    /**
     * Get the max output tokens to include in the request: the user override first,
     * then the model's declared `maxTokens`; `undefined` lets the API decide.
     */
    protected getMaxOutputTokens(): number | undefined {
      const { info } = this.getModel()
      return this.options.modelMaxTokens || info.maxTokens || undefined
    }

    /**
     * Build Zhipu `providerOptions` that explicitly toggle thinking mode.
     *
     * Returns `undefined` for models whose `supportsReasoningEffort` is not an array,
     * so no thinking parameter is sent for them. Shared by `createMessage` and
     * `completePrompt` so both paths honour the user's reasoning setting.
     */
    private getThinkingProviderOptions(info: ModelInfo): Parameters<typeof streamText>[0]["providerOptions"] {
      if (!Array.isArray(info.supportsReasoningEffort)) {
        return undefined
      }
      const useReasoning = shouldUseReasoningEffort({ model: info, settings: this.options })
      return {
        zhipu: {
          thinking: { type: useReasoning ? "enabled" : "disabled" },
        },
      }
    }

    /**
     * Create a message stream using the AI SDK `streamText`.
     *
     * Tools are converted to the AI SDK format (with cache options applied), and
     * thinking-capable models receive the `thinking` switch via `providerOptions`.
     * Any error — synchronous setup or mid-stream — is normalised by `handleAiSdkError`.
     */
    override async *createMessage(
      systemPrompt: string,
      messages: RooMessage[],
      metadata?: ApiHandlerCreateMessageMetadata,
    ): ApiStream {
      const { info, temperature } = this.getModel()
      const languageModel = this.getLanguageModel()
      const aiSdkMessages = sanitizeMessagesForProvider(messages)
      const openAiTools = this.convertToolsForOpenAI(metadata?.tools)
      const aiSdkTools = convertToolsForAiSdk(openAiTools) as ToolSet | undefined
      applyToolCacheOptions(aiSdkTools as Parameters<typeof applyToolCacheOptions>[0], metadata?.toolProviderOptions)

      const requestOptions: Parameters<typeof streamText>[0] = {
        model: languageModel,
        system: systemPrompt || undefined,
        messages: aiSdkMessages,
        temperature: this.options.modelTemperature ?? temperature ?? ZAI_DEFAULT_TEMPERATURE,
        maxOutputTokens: this.getMaxOutputTokens(),
        tools: aiSdkTools,
        toolChoice: mapToolChoice(metadata?.tool_choice),
      }

      const providerOptions = this.getThinkingProviderOptions(info)
      if (providerOptions) {
        requestOptions.providerOptions = providerOptions
      }

      try {
        // Kept inside the try so synchronous setup errors are normalised too.
        const result = streamText(requestOptions)
        yield* consumeAiSdkStream(result)
      } catch (error) {
        throw handleAiSdkError(error, "Z.ai")
      }
    }

    /**
     * Complete a single prompt using the AI SDK `generateText`.
     *
     * Applies the same temperature, max-token, and thinking-mode settings as
     * `createMessage`, so a disabled reasoning setting is respected here as well.
     */
    async completePrompt(prompt: string): Promise<string> {
      const { info, temperature } = this.getModel()
      const languageModel = this.getLanguageModel()
      const providerOptions = this.getThinkingProviderOptions(info)
      try {
        const { text } = await generateText({
          model: languageModel,
          prompt,
          maxOutputTokens: this.getMaxOutputTokens(),
          temperature: this.options.modelTemperature ?? temperature ?? ZAI_DEFAULT_TEMPERATURE,
          ...(providerOptions ? { providerOptions } : {}),
        })
        return text
      } catch (error) {
        throw handleAiSdkError(error, "Z.ai")
      }
    }

    override isAiSdkProvider(): boolean {
      return true
    }
  }