
refactor: migrate zai provider to AI SDK (#11263)

* refactor: migrate zai provider to AI SDK using zhipu-ai-provider

* Update src/api/providers/zai.ts

Co-authored-by: roomote[bot] <219738659+roomote[bot]@users.noreply.github.com>

* fix: remove unused zai-format.ts (knip)

---------

Co-authored-by: roomote[bot] <219738659+roomote[bot]@users.noreply.github.com>
Daniel committed 5 days ago
Parent commit f179ba1b9e
5 files changed, 387 additions and 545 deletions
  1. pnpm-lock.yaml (+15 -1)
  2. src/api/providers/__tests__/zai.spec.ts (+240 -233)
  3. src/api/providers/zai.ts (+131 -69)
  4. src/api/transform/zai-format.ts (+0 -242)
  5. src/package.json (+1 -0)
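
At a glance, the migration swaps the hand-rolled OpenAI-compatible client (and its custom top-level thinking field) for the zhipu-ai-provider package driven through the AI SDK's streamText. The sketch below distills the new call shape from the diffs that follow; the endpoint and model id are taken from the tests, and this is illustrative rather than the handler's exact code:

```typescript
import { createZhipu } from "zhipu-ai-provider"
import { streamText } from "ai"

// Provider bound to the international coding endpoint (one of four zaiApiLine values).
const zhipu = createZhipu({
	baseURL: "https://api.z.ai/api/coding/paas/v4",
	apiKey: process.env.ZAI_API_KEY ?? "not-provided",
})

const result = streamText({
	model: zhipu("glm-4.7"),
	system: "You are a helpful assistant.",
	messages: [{ role: "user", content: "Hello!" }],
	// GLM-4.7 thinking mode now travels via providerOptions instead of a raw body field.
	providerOptions: { zhipu: { thinking: { type: "enabled" } } },
})

for await (const part of result.fullStream) {
	// Text and reasoning deltas arrive here; usage resolves separately via result.usage.
}
```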

+ 15 - 1
pnpm-lock.yaml

@@ -1019,6 +1019,9 @@ importers:
       yaml:
         specifier: ^2.8.0
         version: 2.8.0
+      zhipu-ai-provider:
+        specifier: ^0.2.2
+        version: 0.2.2(zod@3.25.76)
       zod:
         specifier: 3.25.76
         version: 3.25.76
@@ -10983,6 +10986,10 @@ packages:
   [email protected]:
     resolution: {integrity: sha512-0LPOt3AxKqMdFBZA3HBAt/t/8vIKq7VaQYbuA8WxCgung+p9TVyKRYdpvCb80HcdTN2NkbIKbhNwKUfm3tQywQ==}
 
+  zhipu-ai-provider@0.2.2:
+    resolution: {integrity: sha512-UjX1ho4DI9ICUv/mrpAnzmrRe5/LXrGkS5hF6h4WDY2aup5GketWWopFzWYCqsbArXAM5wbzzdH9QzZusgGiBg==}
+    engines: {node: '>=18'}
+
  zip-stream@4.1.1:
     resolution: {integrity: sha512-9qv4rlDiopXg4E69k+vMHjNN63YFMe9sZMrdlvKnCjlCRWeCBswPPMPUfx+ipsAWq1LXHe70RcbaHdJJpS6hyQ==}
     engines: {node: '>= 10'}
@@ -14951,7 +14958,7 @@ snapshots:
       sirv: 3.0.1
       tinyglobby: 0.2.14
       tinyrainbow: 2.0.0
-      vitest: 3.2.4(@types/[email protected])(@types/node@24.2.1)(@vitest/[email protected])([email protected])([email protected])([email protected])([email protected])([email protected])
+      vitest: 3.2.4(@types/[email protected])(@types/node@20.17.50)(@vitest/[email protected])([email protected])([email protected])([email protected])([email protected])([email protected])
 
   '@vitest/[email protected]':
     dependencies:
@@ -22242,6 +22249,13 @@ snapshots:
 
   [email protected]: {}
 
+  zhipu-ai-provider@0.2.2(zod@3.25.76):
+    dependencies:
+      '@ai-sdk/provider': 2.0.1
+      '@ai-sdk/provider-utils': 3.0.20(zod@3.25.76)
+    transitivePeerDependencies:
+      - zod
+
  zip-stream@4.1.1:
     dependencies:
       archiver-utils: 3.0.4

+ 240 - 233
src/api/providers/__tests__/zai.spec.ts

@@ -1,7 +1,30 @@
 // npx vitest run src/api/providers/__tests__/zai.spec.ts
 
-import OpenAI from "openai"
-import { Anthropic } from "@anthropic-ai/sdk"
+// Use vi.hoisted to define mock functions that can be referenced in hoisted vi.mock() calls
+const { mockStreamText, mockGenerateText } = vi.hoisted(() => ({
+	mockStreamText: vi.fn(),
+	mockGenerateText: vi.fn(),
+}))
+
+vi.mock("ai", async (importOriginal) => {
+	const actual = await importOriginal<typeof import("ai")>()
+	return {
+		...actual,
+		streamText: mockStreamText,
+		generateText: mockGenerateText,
+	}
+})
+
+vi.mock("zhipu-ai-provider", () => ({
+	createZhipu: vi.fn(() => {
+		return vi.fn(() => ({
+			modelId: "glm-4.6",
+			provider: "zhipu",
+		}))
+	}),
+}))
+
+import type { Anthropic } from "@anthropic-ai/sdk"
 
 import {
 	type InternationalZAiModelId,
@@ -13,42 +36,41 @@ import {
 	ZAI_DEFAULT_TEMPERATURE,
 } from "@roo-code/types"
 
-import { ZAiHandler } from "../zai"
+import type { ApiHandlerOptions } from "../../../shared/api"
 
-vitest.mock("openai", () => {
-	const createMock = vitest.fn()
-	return {
-		default: vitest.fn(() => ({ chat: { completions: { create: createMock } } })),
-	}
-})
+import { ZAiHandler } from "../zai"
 
 describe("ZAiHandler", () => {
 	let handler: ZAiHandler
-	let mockCreate: any
+	let mockOptions: ApiHandlerOptions
 
 	beforeEach(() => {
-		vitest.clearAllMocks()
-		mockCreate = (OpenAI as unknown as any)().chat.completions.create
+		mockOptions = {
+			zaiApiKey: "test-zai-api-key",
+			zaiApiLine: "international_coding",
+			apiModelId: "glm-4.6",
+		}
+		handler = new ZAiHandler(mockOptions)
+		vi.clearAllMocks()
 	})
 
-	describe("International Z AI", () => {
-		beforeEach(() => {
-			handler = new ZAiHandler({ zaiApiKey: "test-zai-api-key", zaiApiLine: "international_coding" })
+	describe("constructor", () => {
+		it("should initialize with provided options", () => {
+			expect(handler).toBeInstanceOf(ZAiHandler)
+			expect(handler.getModel().id).toBe(mockOptions.apiModelId)
 		})
 
-		it("should use the correct international Z AI base URL", () => {
-			new ZAiHandler({ zaiApiKey: "test-zai-api-key", zaiApiLine: "international_coding" })
-			expect(OpenAI).toHaveBeenCalledWith(
-				expect.objectContaining({
-					baseURL: "https://api.z.ai/api/coding/paas/v4",
-				}),
-			)
+		it("should default to international when no zaiApiLine is specified", () => {
+			const handlerDefault = new ZAiHandler({ zaiApiKey: "test-zai-api-key" })
+			const model = handlerDefault.getModel()
+			expect(model.id).toBe(internationalZAiDefaultModelId)
+			expect(model.info).toEqual(internationalZAiModels[internationalZAiDefaultModelId])
 		})
+	})
 
-		it("should use the provided API key for international", () => {
-			const zaiApiKey = "test-zai-api-key"
-			new ZAiHandler({ zaiApiKey, zaiApiLine: "international_coding" })
-			expect(OpenAI).toHaveBeenCalledWith(expect.objectContaining({ apiKey: zaiApiKey }))
+	describe("International Z AI", () => {
+		beforeEach(() => {
+			handler = new ZAiHandler({ zaiApiKey: "test-zai-api-key", zaiApiLine: "international_coding" })
 		})
 
 		it("should return international default model when no model is specified", () => {
@@ -119,19 +141,6 @@ describe("ZAiHandler", () => {
 			handler = new ZAiHandler({ zaiApiKey: "test-zai-api-key", zaiApiLine: "china_coding" })
 		})
 
-		it("should use the correct China Z AI base URL", () => {
-			new ZAiHandler({ zaiApiKey: "test-zai-api-key", zaiApiLine: "china_coding" })
-			expect(OpenAI).toHaveBeenCalledWith(
-				expect.objectContaining({ baseURL: "https://open.bigmodel.cn/api/coding/paas/v4" }),
-			)
-		})
-
-		it("should use the provided API key for China", () => {
-			const zaiApiKey = "test-zai-api-key"
-			new ZAiHandler({ zaiApiKey, zaiApiLine: "china_coding" })
-			expect(OpenAI).toHaveBeenCalledWith(expect.objectContaining({ apiKey: zaiApiKey }))
-		})
-
 		it("should return China default model when no model is specified", () => {
 			const model = handler.getModel()
 			expect(model.id).toBe(mainlandZAiDefaultModelId)
@@ -200,21 +209,6 @@ describe("ZAiHandler", () => {
 			handler = new ZAiHandler({ zaiApiKey: "test-zai-api-key", zaiApiLine: "international_api" })
 		})
 
-		it("should use the correct international API base URL", () => {
-			new ZAiHandler({ zaiApiKey: "test-zai-api-key", zaiApiLine: "international_api" })
-			expect(OpenAI).toHaveBeenCalledWith(
-				expect.objectContaining({
-					baseURL: "https://api.z.ai/api/paas/v4",
-				}),
-			)
-		})
-
-		it("should use the provided API key for international API", () => {
-			const zaiApiKey = "test-zai-api-key"
-			new ZAiHandler({ zaiApiKey, zaiApiLine: "international_api" })
-			expect(OpenAI).toHaveBeenCalledWith(expect.objectContaining({ apiKey: zaiApiKey }))
-		})
-
 		it("should return international default model when no model is specified", () => {
 			const model = handler.getModel()
 			expect(model.id).toBe(internationalZAiDefaultModelId)
@@ -239,21 +233,6 @@ describe("ZAiHandler", () => {
 			handler = new ZAiHandler({ zaiApiKey: "test-zai-api-key", zaiApiLine: "china_api" })
 		})
 
-		it("should use the correct China API base URL", () => {
-			new ZAiHandler({ zaiApiKey: "test-zai-api-key", zaiApiLine: "china_api" })
-			expect(OpenAI).toHaveBeenCalledWith(
-				expect.objectContaining({
-					baseURL: "https://open.bigmodel.cn/api/paas/v4",
-				}),
-			)
-		})
-
-		it("should use the provided API key for China API", () => {
-			const zaiApiKey = "test-zai-api-key"
-			new ZAiHandler({ zaiApiKey, zaiApiLine: "china_api" })
-			expect(OpenAI).toHaveBeenCalledWith(expect.objectContaining({ apiKey: zaiApiKey }))
-		})
-
 		it("should return China default model when no model is specified", () => {
 			const model = handler.getModel()
 			expect(model.id).toBe(mainlandZAiDefaultModelId)
@@ -273,133 +252,98 @@ describe("ZAiHandler", () => {
 		})
 	})
 
-	describe("Default behavior", () => {
-		it("should default to international when no zaiApiLine is specified", () => {
-			const handlerDefault = new ZAiHandler({ zaiApiKey: "test-zai-api-key" })
-			expect(OpenAI).toHaveBeenCalledWith(
-				expect.objectContaining({
-					baseURL: "https://api.z.ai/api/coding/paas/v4",
-				}),
-			)
-
-			const model = handlerDefault.getModel()
-			expect(model.id).toBe(internationalZAiDefaultModelId)
-			expect(model.info).toEqual(internationalZAiModels[internationalZAiDefaultModelId])
-		})
-
-		it("should use 'not-provided' as default API key when none is specified", () => {
-			new ZAiHandler({ zaiApiLine: "international_coding" })
-			expect(OpenAI).toHaveBeenCalledWith(expect.objectContaining({ apiKey: "not-provided" }))
+	describe("getModel", () => {
+		it("should include model parameters from getModelParams", () => {
+			const model = handler.getModel()
+			expect(model).toHaveProperty("temperature")
+			expect(model).toHaveProperty("maxTokens")
 		})
 	})
 
-	describe("API Methods", () => {
-		beforeEach(() => {
-			handler = new ZAiHandler({ zaiApiKey: "test-zai-api-key", zaiApiLine: "international_coding" })
-		})
-
-		it("completePrompt method should return text from Z AI API", async () => {
-			const expectedResponse = "This is a test response from Z AI"
-			mockCreate.mockResolvedValueOnce({ choices: [{ message: { content: expectedResponse } }] })
-			const result = await handler.completePrompt("test prompt")
-			expect(result).toBe(expectedResponse)
-		})
-
-		it("should handle errors in completePrompt", async () => {
-			const errorMessage = "Z AI API error"
-			mockCreate.mockRejectedValueOnce(new Error(errorMessage))
-			await expect(handler.completePrompt("test prompt")).rejects.toThrow(
-				`Z.ai completion error: ${errorMessage}`,
-			)
-		})
+	describe("createMessage", () => {
+		const systemPrompt = "You are a helpful assistant."
+		const messages: Anthropic.Messages.MessageParam[] = [
+			{
+				role: "user",
+				content: [{ type: "text" as const, text: "Hello!" }],
+			},
+		]
+
+		it("should handle streaming responses", async () => {
+			async function* mockFullStream() {
+				yield { type: "text-delta", text: "Test response from Z.ai" }
+			}
+
+			const mockUsage = Promise.resolve({
+				inputTokens: 10,
+				outputTokens: 5,
+			})
 
-		it("createMessage should yield text content from stream", async () => {
-			const testContent = "This is test content from Z AI stream"
-
-			mockCreate.mockImplementationOnce(() => {
-				return {
-					[Symbol.asyncIterator]: () => ({
-						next: vitest
-							.fn()
-							.mockResolvedValueOnce({
-								done: false,
-								value: { choices: [{ delta: { content: testContent } }] },
-							})
-							.mockResolvedValueOnce({ done: true }),
-					}),
-				}
+			mockStreamText.mockReturnValue({
+				fullStream: mockFullStream(),
+				usage: mockUsage,
 			})
 
-			const stream = handler.createMessage("system prompt", [])
-			const firstChunk = await stream.next()
+			const stream = handler.createMessage(systemPrompt, messages)
+			const chunks: any[] = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
 
-			expect(firstChunk.done).toBe(false)
-			expect(firstChunk.value).toEqual({ type: "text", text: testContent })
+			expect(chunks.length).toBeGreaterThan(0)
+			const textChunks = chunks.filter((chunk) => chunk.type === "text")
+			expect(textChunks).toHaveLength(1)
+			expect(textChunks[0].text).toBe("Test response from Z.ai")
 		})
 
-		it("createMessage should yield usage data from stream", async () => {
-			mockCreate.mockImplementationOnce(() => {
-				return {
-					[Symbol.asyncIterator]: () => ({
-						next: vitest
-							.fn()
-							.mockResolvedValueOnce({
-								done: false,
-								value: {
-									choices: [{ delta: {} }],
-									usage: { prompt_tokens: 10, completion_tokens: 20 },
-								},
-							})
-							.mockResolvedValueOnce({ done: true }),
-					}),
-				}
+		it("should include usage information", async () => {
+			async function* mockFullStream() {
+				yield { type: "text-delta", text: "Test response" }
+			}
+
+			const mockUsage = Promise.resolve({
+				inputTokens: 10,
+				outputTokens: 20,
 			})
 
-			const stream = handler.createMessage("system prompt", [])
-			const firstChunk = await stream.next()
+			mockStreamText.mockReturnValue({
+				fullStream: mockFullStream(),
+				usage: mockUsage,
+			})
 
-			expect(firstChunk.done).toBe(false)
-			expect(firstChunk.value).toMatchObject({ type: "usage", inputTokens: 10, outputTokens: 20 })
+			const stream = handler.createMessage(systemPrompt, messages)
+			const chunks: any[] = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			const usageChunks = chunks.filter((chunk) => chunk.type === "usage")
+			expect(usageChunks.length).toBeGreaterThan(0)
+			expect(usageChunks[0].inputTokens).toBe(10)
+			expect(usageChunks[0].outputTokens).toBe(20)
 		})
 
-		it("createMessage should pass correct parameters to Z AI client", async () => {
-			const modelId: InternationalZAiModelId = "glm-4.5"
-			const modelInfo = internationalZAiModels[modelId]
-			const handlerWithModel = new ZAiHandler({
-				apiModelId: modelId,
-				zaiApiKey: "test-zai-api-key",
-				zaiApiLine: "international_coding",
-			})
+		it("should pass correct parameters to streamText", async () => {
+			async function* mockFullStream() {
+				// empty stream
+			}
 
-			mockCreate.mockImplementationOnce(() => {
-				return {
-					[Symbol.asyncIterator]: () => ({
-						async next() {
-							return { done: true }
-						},
-					}),
-				}
+			mockStreamText.mockReturnValue({
+				fullStream: mockFullStream(),
+				usage: Promise.resolve({ inputTokens: 0, outputTokens: 0 }),
 			})
 
-			const systemPrompt = "Test system prompt for Z AI"
-			const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Test message for Z AI" }]
-
-			const messageGenerator = handlerWithModel.createMessage(systemPrompt, messages)
-			await messageGenerator.next()
+			const stream = handler.createMessage(systemPrompt, messages)
+			// Consume the stream
+			for await (const _chunk of stream) {
+				// drain
+			}
 
-			// Centralized 20% cap should apply to OpenAI-compatible providers like Z AI
-			const expectedMaxTokens = Math.min(modelInfo.maxTokens, Math.ceil(modelInfo.contextWindow * 0.2))
-
-			expect(mockCreate).toHaveBeenCalledWith(
+			expect(mockStreamText).toHaveBeenCalledWith(
 				expect.objectContaining({
-					model: modelId,
-					max_tokens: expectedMaxTokens,
-					temperature: ZAI_DEFAULT_TEMPERATURE,
-					messages: expect.arrayContaining([{ role: "system", content: systemPrompt }]),
-					stream: true,
-					stream_options: { include_usage: true },
+					system: systemPrompt,
+					temperature: expect.any(Number),
 				}),
-				undefined,
 			)
 		})
 	})
@@ -410,27 +354,29 @@ describe("ZAiHandler", () => {
 				apiModelId: "glm-4.7",
 				zaiApiKey: "test-zai-api-key",
 				zaiApiLine: "international_coding",
-				// No reasoningEffort setting - should use model default (medium)
 			})
 
-			mockCreate.mockImplementationOnce(() => {
-				return {
-					[Symbol.asyncIterator]: () => ({
-						async next() {
-							return { done: true }
-						},
-					}),
-				}
+			async function* mockFullStream() {
+				yield { type: "text-delta", text: "response" }
+			}
+
+			mockStreamText.mockReturnValue({
+				fullStream: mockFullStream(),
+				usage: Promise.resolve({ inputTokens: 0, outputTokens: 0 }),
 			})
 
-			const messageGenerator = handlerWithModel.createMessage("system prompt", [])
-			await messageGenerator.next()
+			const stream = handlerWithModel.createMessage("system prompt", [])
+			for await (const _chunk of stream) {
+				// drain
+			}
 
-			// For GLM-4.7 with default reasoning (medium), thinking should be enabled
-			expect(mockCreate).toHaveBeenCalledWith(
+			expect(mockStreamText).toHaveBeenCalledWith(
 				expect.objectContaining({
-					model: "glm-4.7",
-					thinking: { type: "enabled" },
+					providerOptions: {
+						zhipu: {
+							thinking: { type: "enabled" },
+						},
+					},
 				}),
 			)
 		})
@@ -444,24 +390,27 @@ describe("ZAiHandler", () => {
 				reasoningEffort: "disable",
 			})
 
-			mockCreate.mockImplementationOnce(() => {
-				return {
-					[Symbol.asyncIterator]: () => ({
-						async next() {
-							return { done: true }
-						},
-					}),
-				}
+			async function* mockFullStream() {
+				yield { type: "text-delta", text: "response" }
+			}
+
+			mockStreamText.mockReturnValue({
+				fullStream: mockFullStream(),
+				usage: Promise.resolve({ inputTokens: 0, outputTokens: 0 }),
 			})
 
-			const messageGenerator = handlerWithModel.createMessage("system prompt", [])
-			await messageGenerator.next()
+			const stream = handlerWithModel.createMessage("system prompt", [])
+			for await (const _chunk of stream) {
+				// drain
+			}
 
-			// For GLM-4.7 with reasoning disabled, thinking should be disabled
-			expect(mockCreate).toHaveBeenCalledWith(
+			expect(mockStreamText).toHaveBeenCalledWith(
 				expect.objectContaining({
-					model: "glm-4.7",
-					thinking: { type: "disabled" },
+					providerOptions: {
+						zhipu: {
+							thinking: { type: "disabled" },
+						},
+					},
 				}),
 			)
 		})
@@ -475,51 +424,109 @@ describe("ZAiHandler", () => {
 				reasoningEffort: "medium",
 			})
 
-			mockCreate.mockImplementationOnce(() => {
-				return {
-					[Symbol.asyncIterator]: () => ({
-						async next() {
-							return { done: true }
-						},
-					}),
-				}
+			async function* mockFullStream() {
+				yield { type: "text-delta", text: "response" }
+			}
+
+			mockStreamText.mockReturnValue({
+				fullStream: mockFullStream(),
+				usage: Promise.resolve({ inputTokens: 0, outputTokens: 0 }),
 			})
 
-			const messageGenerator = handlerWithModel.createMessage("system prompt", [])
-			await messageGenerator.next()
+			const stream = handlerWithModel.createMessage("system prompt", [])
+			for await (const _chunk of stream) {
+				// drain
+			}
 
-			// For GLM-4.7 with reasoning set to medium, thinking should be enabled
-			expect(mockCreate).toHaveBeenCalledWith(
+			expect(mockStreamText).toHaveBeenCalledWith(
 				expect.objectContaining({
-					model: "glm-4.7",
-					thinking: { type: "enabled" },
+					providerOptions: {
+						zhipu: {
+							thinking: { type: "enabled" },
+						},
+					},
 				}),
 			)
 		})
 
-		it("should NOT add thinking parameter for non-thinking models like GLM-4.6", async () => {
+		it("should NOT add providerOptions for non-thinking models like GLM-4.6", async () => {
 			const handlerWithModel = new ZAiHandler({
 				apiModelId: "glm-4.6",
 				zaiApiKey: "test-zai-api-key",
 				zaiApiLine: "international_coding",
 			})
 
-			mockCreate.mockImplementationOnce(() => {
-				return {
-					[Symbol.asyncIterator]: () => ({
-						async next() {
-							return { done: true }
-						},
-					}),
-				}
+			async function* mockFullStream() {
+				yield { type: "text-delta", text: "response" }
+			}
+
+			mockStreamText.mockReturnValue({
+				fullStream: mockFullStream(),
+				usage: Promise.resolve({ inputTokens: 0, outputTokens: 0 }),
+			})
+
+			const stream = handlerWithModel.createMessage("system prompt", [])
+			for await (const _chunk of stream) {
+				// drain
+			}
+
+			const callArgs = mockStreamText.mock.calls[0][0]
+			expect(callArgs.providerOptions).toBeUndefined()
+		})
+
+		it("should handle reasoning content in streaming responses", async () => {
+			const handlerWithModel = new ZAiHandler({
+				apiModelId: "glm-4.7",
+				zaiApiKey: "test-zai-api-key",
+				zaiApiLine: "international_coding",
+			})
+
+			async function* mockFullStream() {
+				yield { type: "reasoning", text: "Let me think about this..." }
+				yield { type: "text-delta", text: "Here is my answer" }
+			}
+
+			mockStreamText.mockReturnValue({
+				fullStream: mockFullStream(),
+				usage: Promise.resolve({ inputTokens: 10, outputTokens: 20 }),
 			})
 
-			const messageGenerator = handlerWithModel.createMessage("system prompt", [])
-			await messageGenerator.next()
+			const stream = handlerWithModel.createMessage("system prompt", [])
+			const chunks: any[] = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			const reasoningChunks = chunks.filter((chunk) => chunk.type === "reasoning")
+			expect(reasoningChunks).toHaveLength(1)
+			expect(reasoningChunks[0].text).toBe("Let me think about this...")
+
+			const textChunks = chunks.filter((chunk) => chunk.type === "text")
+			expect(textChunks).toHaveLength(1)
+			expect(textChunks[0].text).toBe("Here is my answer")
+		})
+	})
+
+	describe("completePrompt", () => {
+		it("should complete a prompt using generateText", async () => {
+			mockGenerateText.mockResolvedValue({
+				text: "Test completion from Z.ai",
+			})
+
+			const result = await handler.completePrompt("Test prompt")
+
+			expect(result).toBe("Test completion from Z.ai")
+			expect(mockGenerateText).toHaveBeenCalledWith(
+				expect.objectContaining({
+					prompt: "Test prompt",
+				}),
+			)
+		})
+	})
 
-			// For GLM-4.6 (no thinking support), thinking parameter should not be present
-			const callArgs = mockCreate.mock.calls[0][0]
-			expect(callArgs.thinking).toBeUndefined()
+	describe("isAiSdkProvider", () => {
+		it("should return true", () => {
+			expect(handler.isAiSdkProvider()).toBe(true)
 		})
 	})
 })
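
Worth noting for future tests: the streamText stub — an object exposing a fullStream async generator and a usage promise — is re-declared inline in nearly every case above. A hypothetical helper (not part of this PR) captures the pattern once:

```typescript
// Hypothetical helper, not in the PR: builds the minimal shape of a
// streamText() result that ZAiHandler.createMessage() consumes.
function makeStreamTextResult(
	parts: Array<{ type: string; text?: string }>,
	usage = { inputTokens: 0, outputTokens: 0 },
) {
	async function* fullStream() {
		yield* parts
	}
	return { fullStream: fullStream(), usage: Promise.resolve(usage) }
}

// Equivalent to the inline mocks above:
mockStreamText.mockReturnValue(
	makeStreamTextResult([{ type: "text-delta", text: "Test response from Z.ai" }], {
		inputTokens: 10,
		outputTokens: 5,
	}),
)
```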

+ 131 - 69
src/api/providers/zai.ts

@@ -1,5 +1,6 @@
 import { Anthropic } from "@anthropic-ai/sdk"
-import OpenAI from "openai"
+import { createZhipu } from "zhipu-ai-provider"
+import { streamText, generateText, ToolSet } from "ai"
 
 import {
 	internationalZAiModels,
@@ -11,101 +12,162 @@ import {
 	zaiApiLineConfigs,
 } from "@roo-code/types"
 
-import { type ApiHandlerOptions, getModelMaxOutputTokens, shouldUseReasoningEffort } from "../../shared/api"
-import { convertToZAiFormat } from "../transform/zai-format"
+import { type ApiHandlerOptions, shouldUseReasoningEffort } from "../../shared/api"
 
-import type { ApiHandlerCreateMessageMetadata } from "../index"
-import { BaseOpenAiCompatibleProvider } from "./base-openai-compatible-provider"
-
-// Custom interface for Z.ai params to support thinking mode
-type ZAiChatCompletionParams = OpenAI.Chat.ChatCompletionCreateParamsStreaming & {
-	thinking?: { type: "enabled" | "disabled" }
-}
+import {
+	convertToAiSdkMessages,
+	convertToolsForAiSdk,
+	processAiSdkStreamPart,
+	mapToolChoice,
+	handleAiSdkError,
+} from "../transform/ai-sdk"
+import { ApiStream } from "../transform/stream"
+import { getModelParams } from "../transform/model-params"
+
+import { DEFAULT_HEADERS } from "./constants"
+import { BaseProvider } from "./base-provider"
+import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index"
+
+/**
+ * Z.ai provider using the dedicated zhipu-ai-provider package.
+ * Provides native support for GLM-4.7 thinking mode and region-based model selection.
+ */
+export class ZAiHandler extends BaseProvider implements SingleCompletionHandler {
+	protected options: ApiHandlerOptions
+	protected provider: ReturnType<typeof createZhipu>
+	private isChina: boolean
 
-export class ZAiHandler extends BaseOpenAiCompatibleProvider<string> {
 	constructor(options: ApiHandlerOptions) {
-		const isChina = zaiApiLineConfigs[options.zaiApiLine ?? "international_coding"].isChina
-		const models = (isChina ? mainlandZAiModels : internationalZAiModels) as unknown as Record<string, ModelInfo>
-		const defaultModelId = (isChina ? mainlandZAiDefaultModelId : internationalZAiDefaultModelId) as string
+		super()
+		this.options = options
+		this.isChina = zaiApiLineConfigs[options.zaiApiLine ?? "international_coding"].isChina
 
-		super({
-			...options,
-			providerName: "Z.ai",
+		this.provider = createZhipu({
 			baseURL: zaiApiLineConfigs[options.zaiApiLine ?? "international_coding"].baseUrl,
 			apiKey: options.zaiApiKey ?? "not-provided",
-			defaultProviderModelId: defaultModelId,
-			providerModels: models,
+			headers: DEFAULT_HEADERS,
+		})
+	}
+
+	override getModel(): { id: string; info: ModelInfo; maxTokens?: number; temperature?: number } {
+		const models = (this.isChina ? mainlandZAiModels : internationalZAiModels) as unknown as Record<
+			string,
+			ModelInfo
+		>
+		const defaultModelId = (this.isChina ? mainlandZAiDefaultModelId : internationalZAiDefaultModelId) as string
+
+		const id = this.options.apiModelId ?? defaultModelId
+		const info = models[id] || models[defaultModelId]
+		const params = getModelParams({
+			format: "openai",
+			modelId: id,
+			model: info,
+			settings: this.options,
 			defaultTemperature: ZAI_DEFAULT_TEMPERATURE,
 		})
+
+		return { id, info, ...params }
+	}
+
+	/**
+	 * Get the language model for the configured model ID.
+	 */
+	protected getLanguageModel() {
+		const { id } = this.getModel()
+		return this.provider(id)
 	}
 
 	/**
-	 * Override createStream to handle GLM-4.7's thinking mode.
-	 * GLM-4.7 has thinking enabled by default in the API, so we need to
-	 * explicitly send { type: "disabled" } when the user turns off reasoning.
+	 * Get the max tokens parameter to include in the request.
 	 */
-	protected override createStream(
+	protected getMaxOutputTokens(): number | undefined {
+		const { info } = this.getModel()
+		return this.options.modelMaxTokens || info.maxTokens || undefined
+	}
+
+	/**
+	 * Create a message stream using the AI SDK.
+	 * For GLM-4.7, passes the thinking parameter via providerOptions.
+	 */
+	override async *createMessage(
 		systemPrompt: string,
 		messages: Anthropic.Messages.MessageParam[],
 		metadata?: ApiHandlerCreateMessageMetadata,
-		requestOptions?: OpenAI.RequestOptions,
-	) {
-		const { id: modelId, info } = this.getModel()
+	): ApiStream {
+		const { id: modelId, info, temperature } = this.getModel()
+		const languageModel = this.getLanguageModel()
+
+		const aiSdkMessages = convertToAiSdkMessages(messages)
+
+		const openAiTools = this.convertToolsForOpenAI(metadata?.tools)
+		const aiSdkTools = convertToolsForAiSdk(openAiTools) as ToolSet | undefined
+
+		const requestOptions: Parameters<typeof streamText>[0] = {
+			model: languageModel,
+			system: systemPrompt,
+			messages: aiSdkMessages,
+			temperature: this.options.modelTemperature ?? temperature ?? ZAI_DEFAULT_TEMPERATURE,
+			maxOutputTokens: this.getMaxOutputTokens(),
+			tools: aiSdkTools,
+			toolChoice: mapToolChoice(metadata?.tool_choice),
+		}
 
-		// Check if this is a GLM-4.7 model with thinking support
+		// GLM-4.7 thinking mode: pass thinking parameter via providerOptions
 		const isThinkingModel = modelId === "glm-4.7" && Array.isArray(info.supportsReasoningEffort)
 
 		if (isThinkingModel) {
-			// For GLM-4.7, thinking is ON by default in the API.
-			// We need to explicitly disable it when reasoning is off.
 			const useReasoning = shouldUseReasoningEffort({ model: info, settings: this.options })
-
-			// Create the stream with our custom thinking parameter
-			return this.createStreamWithThinking(systemPrompt, messages, metadata, useReasoning)
+			requestOptions.providerOptions = {
+				zhipu: {
+					thinking: useReasoning ? { type: "enabled" } : { type: "disabled" },
+				},
+			}
 		}
 
-		// For non-thinking models, use the default behavior
-		return super.createStream(systemPrompt, messages, metadata, requestOptions)
+		const result = streamText(requestOptions)
+
+		try {
+			for await (const part of result.fullStream) {
+				for (const chunk of processAiSdkStreamPart(part)) {
+					yield chunk
+				}
+			}
+
+			const usage = await result.usage
+			if (usage) {
+				yield {
+					type: "usage" as const,
+					inputTokens: usage.inputTokens || 0,
+					outputTokens: usage.outputTokens || 0,
+				}
+			}
+		} catch (error) {
+			throw handleAiSdkError(error, "Z.ai")
+		}
 	}
 
 	/**
-	 * Creates a stream with explicit thinking control for GLM-4.7
+	 * Complete a prompt using the AI SDK generateText.
 	 */
-	private createStreamWithThinking(
-		systemPrompt: string,
-		messages: Anthropic.Messages.MessageParam[],
-		metadata?: ApiHandlerCreateMessageMetadata,
-		useReasoning?: boolean,
-	) {
-		const { id: model, info } = this.getModel()
-
-		const max_tokens =
-			getModelMaxOutputTokens({
-				modelId: model,
-				model: info,
-				settings: this.options,
-				format: "openai",
-			}) ?? undefined
-
-		const temperature = this.options.modelTemperature ?? this.defaultTemperature
-
-		// Use Z.ai format to preserve reasoning_content and merge post-tool text into tool messages
-		const convertedMessages = convertToZAiFormat(messages, { mergeToolResultText: true })
-
-		const params: ZAiChatCompletionParams = {
-			model,
-			max_tokens,
-			temperature,
-			messages: [{ role: "system", content: systemPrompt }, ...convertedMessages],
-			stream: true,
-			stream_options: { include_usage: true },
-			// For GLM-4.7: thinking is ON by default, so we explicitly disable when needed
-			thinking: useReasoning ? { type: "enabled" } : { type: "disabled" },
-			tools: this.convertToolsForOpenAI(metadata?.tools),
-			tool_choice: metadata?.tool_choice,
-			parallel_tool_calls: metadata?.parallelToolCalls ?? true,
+	async completePrompt(prompt: string): Promise<string> {
+		const { temperature } = this.getModel()
+		const languageModel = this.getLanguageModel()
+
+		try {
+			const { text } = await generateText({
+				model: languageModel,
+				prompt,
+				maxOutputTokens: this.getMaxOutputTokens(),
+				temperature: this.options.modelTemperature ?? temperature ?? ZAI_DEFAULT_TEMPERATURE,
+			})
+
+			return text
+		} catch (error) {
+			throw handleAiSdkError(error, "Z.ai")
 		}
+	}
 
-		return this.client.chat.completions.create(params)
+	override isAiSdkProvider(): boolean {
+		return true
 	}
 }
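
The handler's public surface (createMessage, completePrompt, getModel, isAiSdkProvider) is unchanged by the migration, so callers need no updates. A quick consumption sketch, assuming a valid key; the option names are the ones exercised in the tests, and the import path is illustrative:

```typescript
import { ZAiHandler } from "./src/api/providers/zai"

const handler = new ZAiHandler({
	zaiApiKey: process.env.ZAI_API_KEY ?? "not-provided",
	zaiApiLine: "international_coding", // or china_coding / international_api / china_api
	apiModelId: "glm-4.7",
	reasoningEffort: "medium", // becomes providerOptions.zhipu.thinking = { type: "enabled" }
})

for await (const chunk of handler.createMessage("You are helpful.", [
	{ role: "user", content: "Hello!" },
])) {
	if (chunk.type === "reasoning" || chunk.type === "text") process.stdout.write(chunk.text)
	if (chunk.type === "usage") console.log(chunk.inputTokens, chunk.outputTokens)
}
```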

+ 0 - 242
src/api/transform/zai-format.ts

@@ -1,242 +0,0 @@
-import { Anthropic } from "@anthropic-ai/sdk"
-import OpenAI from "openai"
-
-type ContentPartText = OpenAI.Chat.ChatCompletionContentPartText
-type ContentPartImage = OpenAI.Chat.ChatCompletionContentPartImage
-type UserMessage = OpenAI.Chat.ChatCompletionUserMessageParam
-type AssistantMessage = OpenAI.Chat.ChatCompletionAssistantMessageParam
-type SystemMessage = OpenAI.Chat.ChatCompletionSystemMessageParam
-type ToolMessage = OpenAI.Chat.ChatCompletionToolMessageParam
-type Message = OpenAI.Chat.ChatCompletionMessageParam
-type AnthropicMessage = Anthropic.Messages.MessageParam
-
-/**
- * Extended assistant message type to support Z.ai's interleaved thinking.
- * Z.ai's API returns reasoning_content alongside content and tool_calls,
- * and requires it to be passed back in subsequent requests for preserved thinking.
- */
-export type ZAiAssistantMessage = AssistantMessage & {
-	reasoning_content?: string
-}
-
-/**
- * Converts Anthropic messages to OpenAI format optimized for Z.ai's GLM-4.7 thinking mode.
- *
- * Key differences from standard OpenAI format:
- * - Preserves reasoning_content on assistant messages for interleaved thinking
- * - Text content after tool_results (like environment_details) is merged into the last tool message
- *   to avoid creating user messages that would cause reasoning_content to be dropped
- *
- * @param messages Array of Anthropic messages
- * @param options Optional configuration for message conversion
- * @param options.mergeToolResultText If true, merge text content after tool_results into the last
- *                                     tool message instead of creating a separate user message.
- *                                     This is critical for Z.ai's interleaved thinking mode.
- * @returns Array of OpenAI messages optimized for Z.ai's thinking mode
- */
-export function convertToZAiFormat(
-	messages: AnthropicMessage[],
-	options?: { mergeToolResultText?: boolean },
-): Message[] {
-	const result: Message[] = []
-
-	for (const message of messages) {
-		// Check if the message has reasoning_content (for Z.ai interleaved thinking)
-		const messageWithReasoning = message as AnthropicMessage & { reasoning_content?: string }
-		const reasoningContent = messageWithReasoning.reasoning_content
-
-		if (message.role === "user") {
-			// Handle user messages - may contain tool_result blocks
-			if (Array.isArray(message.content)) {
-				const textParts: string[] = []
-				const imageParts: ContentPartImage[] = []
-				const toolResults: { tool_use_id: string; content: string }[] = []
-
-				for (const part of message.content) {
-					if (part.type === "text") {
-						textParts.push(part.text)
-					} else if (part.type === "image") {
-						imageParts.push({
-							type: "image_url",
-							image_url: { url: `data:${part.source.media_type};base64,${part.source.data}` },
-						})
-					} else if (part.type === "tool_result") {
-						// Convert tool_result to OpenAI tool message format
-						let content: string
-						if (typeof part.content === "string") {
-							content = part.content
-						} else if (Array.isArray(part.content)) {
-							content =
-								part.content
-									?.map((c) => {
-										if (c.type === "text") return c.text
-										if (c.type === "image") return "(image)"
-										return ""
-									})
-									.join("\n") ?? ""
-						} else {
-							content = ""
-						}
-						toolResults.push({
-							tool_use_id: part.tool_use_id,
-							content,
-						})
-					}
-				}
-
-				// Add tool messages first (they must follow assistant tool_use)
-				for (const toolResult of toolResults) {
-					const toolMessage: ToolMessage = {
-						role: "tool",
-						tool_call_id: toolResult.tool_use_id,
-						content: toolResult.content,
-					}
-					result.push(toolMessage)
-				}
-
-				// Handle text/image content after tool results
-				if (textParts.length > 0 || imageParts.length > 0) {
-					// For Z.ai interleaved thinking: when mergeToolResultText is enabled and we have
-					// tool results followed by text, merge the text into the last tool message to avoid
-					// creating a user message that would cause reasoning_content to be dropped.
-					// This is critical because Z.ai drops all reasoning_content when it sees a user message.
-					const shouldMergeIntoToolMessage =
-						options?.mergeToolResultText && toolResults.length > 0 && imageParts.length === 0
-
-					if (shouldMergeIntoToolMessage) {
-						// Merge text content into the last tool message
-						const lastToolMessage = result[result.length - 1] as ToolMessage
-						if (lastToolMessage?.role === "tool") {
-							const additionalText = textParts.join("\n")
-							lastToolMessage.content = `${lastToolMessage.content}\n\n${additionalText}`
-						}
-					} else {
-						// Standard behavior: add user message with text/image content
-						let content: UserMessage["content"]
-						if (imageParts.length > 0) {
-							const parts: (ContentPartText | ContentPartImage)[] = []
-							if (textParts.length > 0) {
-								parts.push({ type: "text", text: textParts.join("\n") })
-							}
-							parts.push(...imageParts)
-							content = parts
-						} else {
-							content = textParts.join("\n")
-						}
-
-						// Check if we can merge with the last message
-						const lastMessage = result[result.length - 1]
-						if (lastMessage?.role === "user") {
-							// Merge with existing user message
-							if (typeof lastMessage.content === "string" && typeof content === "string") {
-								lastMessage.content += `\n${content}`
-							} else {
-								const lastContent = Array.isArray(lastMessage.content)
-									? lastMessage.content
-									: [{ type: "text" as const, text: lastMessage.content || "" }]
-								const newContent = Array.isArray(content)
-									? content
-									: [{ type: "text" as const, text: content }]
-								lastMessage.content = [...lastContent, ...newContent] as UserMessage["content"]
-							}
-						} else {
-							result.push({ role: "user", content })
-						}
-					}
-				}
-			} else {
-				// Simple string content
-				const lastMessage = result[result.length - 1]
-				if (lastMessage?.role === "user") {
-					if (typeof lastMessage.content === "string") {
-						lastMessage.content += `\n${message.content}`
-					} else {
-						;(lastMessage.content as (ContentPartText | ContentPartImage)[]).push({
-							type: "text",
-							text: message.content,
-						})
-					}
-				} else {
-					result.push({ role: "user", content: message.content })
-				}
-			}
-		} else if (message.role === "assistant") {
-			// Handle assistant messages - may contain tool_use blocks and reasoning blocks
-			if (Array.isArray(message.content)) {
-				const textParts: string[] = []
-				const toolCalls: OpenAI.Chat.ChatCompletionMessageToolCall[] = []
-				let extractedReasoning: string | undefined
-
-				for (const part of message.content) {
-					if (part.type === "text") {
-						textParts.push(part.text)
-					} else if (part.type === "tool_use") {
-						toolCalls.push({
-							id: part.id,
-							type: "function",
-							function: {
-								name: part.name,
-								arguments: JSON.stringify(part.input),
-							},
-						})
-					} else if ((part as any).type === "reasoning" && (part as any).text) {
-						// Extract reasoning from content blocks (Task stores it this way)
-						extractedReasoning = (part as any).text
-					}
-				}
-
-				// Use reasoning from content blocks if not provided at top level
-				const finalReasoning = reasoningContent || extractedReasoning
-
-				const assistantMessage: ZAiAssistantMessage = {
-					role: "assistant",
-					content: textParts.length > 0 ? textParts.join("\n") : null,
-					...(toolCalls.length > 0 && { tool_calls: toolCalls }),
-					// Preserve reasoning_content for Z.ai interleaved thinking
-					...(finalReasoning && { reasoning_content: finalReasoning }),
-				}
-
-				// Check if we can merge with the last message (only if no tool calls)
-				const lastMessage = result[result.length - 1]
-				if (lastMessage?.role === "assistant" && !toolCalls.length && !(lastMessage as any).tool_calls) {
-					// Merge text content
-					if (typeof lastMessage.content === "string" && typeof assistantMessage.content === "string") {
-						lastMessage.content += `\n${assistantMessage.content}`
-					} else if (assistantMessage.content) {
-						const lastContent = lastMessage.content || ""
-						lastMessage.content = `${lastContent}\n${assistantMessage.content}`
-					}
-					// Preserve reasoning_content from the new message if present
-					if (finalReasoning) {
-						;(lastMessage as ZAiAssistantMessage).reasoning_content = finalReasoning
-					}
-				} else {
-					result.push(assistantMessage)
-				}
-			} else {
-				// Simple string content
-				const lastMessage = result[result.length - 1]
-				if (lastMessage?.role === "assistant" && !(lastMessage as any).tool_calls) {
-					if (typeof lastMessage.content === "string") {
-						lastMessage.content += `\n${message.content}`
-					} else {
-						lastMessage.content = message.content
-					}
-					// Preserve reasoning_content from the new message if present
-					if (reasoningContent) {
-						;(lastMessage as ZAiAssistantMessage).reasoning_content = reasoningContent
-					}
-				} else {
-					const assistantMessage: ZAiAssistantMessage = {
-						role: "assistant",
-						content: message.content,
-						...(reasoningContent && { reasoning_content: reasoningContent }),
-					}
-					result.push(assistantMessage)
-				}
-			}
-		}
-	}
-
-	return result
-}
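
For context on what the deletion removes: with mergeToolResultText enabled, the transform folded trailing user-turn text (such as environment_details) into the preceding tool message, so Z.ai never saw a user message that would make it drop reasoning_content. A rough before/after of that behavior, shapes only:

```typescript
// Input: an Anthropic user turn carrying a tool_result followed by plain text.
const input = [
	{
		role: "user" as const,
		content: [
			{ type: "tool_result" as const, tool_use_id: "call_1", content: "file contents" },
			{ type: "text" as const, text: "<environment_details>...</environment_details>" },
		],
	},
]

// convertToZAiFormat(input, { mergeToolResultText: true }) yielded one tool
// message with the text appended, rather than a separate user message:
// [{ role: "tool", tool_call_id: "call_1",
//    content: "file contents\n\n<environment_details>...</environment_details>" }]
```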

+ 1 - 0
src/package.json

@@ -541,6 +541,7 @@
 		"web-tree-sitter": "^0.25.6",
 		"workerpool": "^9.2.0",
 		"yaml": "^2.8.0",
+		"zhipu-ai-provider": "^0.2.2",
 		"zod": "3.25.76"
 	},
 	"devDependencies": {