
refactor: migrate baseten provider to AI SDK (#11261)

* refactor: migrate baseten provider to AI SDK

* refactor(baseten): migrate to native @ai-sdk/baseten package

Replace BaseOpenAiCompatibleProvider with the dedicated @ai-sdk/baseten package,
following the same pattern used by other native AI SDK providers (groq,
deepseek, etc.). The handler now uses createBaseten() for provider
initialization and extends BaseProvider directly instead of the generic
OpenAI-compatible handler.
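
A minimal sketch of that pattern (not part of the diff; condensed from src/api/providers/baseten.ts below, with the API key read from a hypothetical BASETEN_API_KEY environment variable):

import { createBaseten } from "@ai-sdk/baseten"
import { streamText } from "ai"

// Build the provider once; language models are then resolved from it by ID.
const baseten = createBaseten({
	baseURL: "https://inference.baseten.co/v1",
	apiKey: process.env.BASETEN_API_KEY ?? "not-provided",
})

// Streaming goes through the AI SDK's streamText instead of an OpenAI-compatible client.
const result = streamText({
	model: baseten("zai-org/GLM-4.6"),
	system: "You are a helpful assistant.",
	prompt: "Hello!",
})

for await (const part of result.fullStream) {
	// Each part is a typed stream event (text-delta, tool-input-start, ...).
}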
Daniel, 6 days ago
parent
commit
43a3073545
5 changed files with 807 additions and 10 deletions
  1. pnpm-lock.yaml (+196, -0)
  2. src/api/providers/__tests__/baseten.spec.ts (+446, -0)
  3. src/api/providers/baseten.ts (+148, -10)
  4. src/esbuild.mjs (+16, -0)
  5. src/package.json (+1, -0)

+ 196 - 0
pnpm-lock.yaml

@@ -749,6 +749,9 @@ importers:
       '@ai-sdk/amazon-bedrock':
         specifier: ^4.0.50
         version: 4.0.50(zod@3.25.76)
+      '@ai-sdk/baseten':
+        specifier: ^1.0.31
+        version: 1.0.31(zod@3.25.76)
       '@ai-sdk/cerebras':
         specifier: ^1.0.0
         version: 1.0.35(zod@3.25.76)
@@ -1435,6 +1438,12 @@ packages:
     peerDependencies:
       zod: 3.25.76
 
+  '@ai-sdk/baseten@1.0.31':
+    resolution: {integrity: sha512-tGbV96WBb5nnfyUYFrPyBxrhw53YlKSJbMC+rH3HhQlUaIs8+m/Bm4M0isrek9owIIf4MmmSDZ5VZL08zz7eFQ==}
+    engines: {node: '>=18'}
+    peerDependencies:
+      zod: 3.25.76
+
   '@ai-sdk/cerebras@1.0.35':
     resolution: {integrity: sha512-JrNdMYptrOUjNthibgBeAcBjZ/H+fXb49sSrWhOx5Aq8eUcrYvwQ2DtSAi8VraHssZu78NAnBMrgFWSUOTXFxw==}
     engines: {node: '>=18'}
@@ -1513,6 +1522,12 @@ packages:
     peerDependencies:
       zod: 3.25.76
 
+  '@ai-sdk/openai-compatible@2.0.28':
+    resolution: {integrity: sha512-WzDnU0B13FMSSupDtm2lksFZvWGXnOfhG5S0HoPI0pkX5uVkr6N1UTATMyVaxLCG0MRkMhXCjkg4NXgEbb330Q==}
+    engines: {node: '>=18'}
+    peerDependencies:
+      zod: 3.25.76
+
   '@ai-sdk/[email protected]':
     resolution: {integrity: sha512-iXHVe0apM2zUEzauqJwqmpC37A5rihrStAih5Ks+JE32iTe4LZ58y17UGBjpQQTCRw9YxMeo2UFLxLpBluyvLQ==}
     engines: {node: '>=18'}
@@ -1543,6 +1558,12 @@ packages:
     peerDependencies:
       zod: 3.25.76
 
+  '@ai-sdk/provider-utils@4.0.14':
+    resolution: {integrity: sha512-7bzKd9lgiDeXM7O4U4nQ8iTxguAOkg8LZGD9AfDVZYjO5cKYRwBPwVjboFcVrxncRHu0tYxZtXZtiLKpG4pEng==}
+    engines: {node: '>=18'}
+    peerDependencies:
+      zod: 3.25.76
+
   '@ai-sdk/[email protected]':
     resolution: {integrity: sha512-6o7Y2SeO9vFKB8lArHXehNuusnpddKPk7xqL7T2/b+OvXMRIXUO1rR4wcv1hAFUAT9avGZshty3Wlua/XA7TvA==}
     engines: {node: '>=18'}
@@ -1563,6 +1584,10 @@ packages:
     resolution: {integrity: sha512-VkPLrutM6VdA924/mG8OS+5frbVTcu6e046D2bgDo00tehBANR1QBJ/mPcZ9tXMFOsVcm6SQArOregxePzTFPw==}
     engines: {node: '>=18'}
 
+  '@ai-sdk/provider@3.0.8':
+    resolution: {integrity: sha512-oGMAgGoQdBXbZqNG0Ze56CHjDZ1IDYOwGYxYjO5KLSlz5HiNQ9udIXsPZ61VWaHGZ5XW/jyjmr6t2xz2jGVwbQ==}
+    engines: {node: '>=18'}
+
   '@ai-sdk/[email protected]':
     resolution: {integrity: sha512-26qM/jYcFhF5krTM7bQT1CiZcdz22EQmA+r5me1hKYFM/yM20sSUMHnAcUzvzuuG9oQVKF0tziU2IcC0HX5huQ==}
     engines: {node: '>=18'}
@@ -1880,6 +1905,93 @@ packages:
     resolution: {integrity: sha512-+EzkxvLNfiUeKMgy/3luqfsCWFRXLb7U6wNQTk60tovuckwB15B191tJWvpp4HjiQWdJkCxO3Wbvc6jlk3Xb2Q==}
     engines: {node: '>=6.9.0'}
 
+  '@basetenlabs/performance-client-android-arm-eabi@0.0.10':
+    resolution: {integrity: sha512-gwDZ6GDJA0AAmQAHxt2vaCz0tYTaLjxJKZnoYt+0Eji4gy231JZZFAwvbAqNdQCrGEQ9lXnk7SNM1Apet4NlYg==}
+    engines: {node: '>= 10'}
+    cpu: [arm]
+    os: [android]
+
+  '@basetenlabs/performance-client-android-arm64@0.0.10':
+    resolution: {integrity: sha512-oGRB/6hH89majhsmoVmj1IAZv4C7F2aLeTSebevBelmdYO4CFkn5qewxLzU1pDkkmxVVk2k+TRpYa1Dt4B96qQ==}
+    engines: {node: '>= 10'}
+    cpu: [arm64]
+    os: [android]
+
+  '@basetenlabs/performance-client-darwin-arm64@0.0.10':
+    resolution: {integrity: sha512-QpBOUjeO05tWgFWkDw2RUQZa3BMplX5jNiBBTi5mH1lIL/m1sm2vkxoc0iorEESp1mMPstYFS/fr4ssBuO7wyA==}
+    engines: {node: '>= 10'}
+    cpu: [arm64]
+    os: [darwin]
+
+  '@basetenlabs/performance-client-darwin-universal@0.0.10':
+    resolution: {integrity: sha512-CBM38GAhekjylrlf7jW/0WNyFAGnAMBCNHZxaPnAjjhDNzJh1tcrwhvtOs66XbAqCOjO/tkt5Pdu6mg2Ui2Pjw==}
+    engines: {node: '>= 10'}
+    os: [darwin]
+
+  '@basetenlabs/performance-client-darwin-x64@0.0.10':
+    resolution: {integrity: sha512-R+NsA72Axclh1CUpmaWOCLTWCqXn5/tFMj2z9BnHVSRTelx/pYFlx6ZngVTB1HYp1n21m3upPXGo8CHF8R7Itw==}
+    engines: {node: '>= 10'}
+    cpu: [x64]
+    os: [darwin]
+
+  '@basetenlabs/performance-client-linux-arm-gnueabihf@0.0.10':
+    resolution: {integrity: sha512-96kEo0Eas4GVQdFkxIB1aAv6dy5Ga57j+RIg5l0Yiawv+AYIEmgk9BsGkqcwayp8Iiu6LN22Z+AUsGY2gstNrg==}
+    engines: {node: '>= 10'}
+    cpu: [arm]
+    os: [linux]
+
+  '@basetenlabs/performance-client-linux-arm-musleabihf@0.0.10':
+    resolution: {integrity: sha512-lzEHeu+/BWDl2q+QZcqCkg1rDGF4MeyM3HgYwX+07t+vGZoqtM2we9vEV68wXMpl6ToEHQr7ML2KHA1Gb6ogxg==}
+    engines: {node: '>= 10'}
+    cpu: [arm]
+    os: [linux]
+
+  '@basetenlabs/performance-client-linux-arm64-gnu@0.0.10':
+    resolution: {integrity: sha512-MnY2cIRY/cQOYERWIHhh5CoaS2wgmmXtGDVGSLYyZvjwizrXZvjkEz7Whv2jaQ21T5S56VER67RABjz2TItrHQ==}
+    engines: {node: '>= 10'}
+    cpu: [arm64]
+    os: [linux]
+
+  '@basetenlabs/performance-client-linux-riscv64-gnu@0.0.10':
+    resolution: {integrity: sha512-2KUvdK4wuoZdIqNnJhx7cu6ybXCwtiwGAtlrEvhai3FOkUQ3wE2Xa+TQ33mNGSyFbw6wAvLawYtKVFmmw27gJw==}
+    engines: {node: '>= 10'}
+    cpu: [riscv64]
+    os: [linux]
+
+  '@basetenlabs/performance-client-linux-x64-gnu@0.0.10':
+    resolution: {integrity: sha512-9jjQPjHLiVOGwUPlmhnBl7OmmO7hQ8WMt+v3mJuxkS5JTNDmVOngfmgGlbN9NjBhQMENjdcMUVOquVo7HeybGQ==}
+    engines: {node: '>= 10'}
+    cpu: [x64]
+    os: [linux]
+
+  '@basetenlabs/performance-client-linux-x64-musl@0.0.10':
+    resolution: {integrity: sha512-bjYB8FKcPvEa251Ep2Gm3tvywADL9eavVjZsikdf0AvJ1K5pT+vLLvJBU9ihBsTPWnbF4pJgxVjwS6UjVObsQA==}
+    engines: {node: '>= 10'}
+    cpu: [x64]
+    os: [linux]
+
+  '@basetenlabs/performance-client-win32-arm64-msvc@0.0.10':
+    resolution: {integrity: sha512-Vxq5UXEmfh3C3hpwXdp3Daaf0dnLR9zFH2x8MJ1Hf/TcilmOP1clneewNpIv0e7MrnT56Z4pM6P3d8VFMZqBKg==}
+    engines: {node: '>= 10'}
+    cpu: [arm64]
+    os: [win32]
+
+  '@basetenlabs/performance-client-win32-ia32-msvc@0.0.10':
+    resolution: {integrity: sha512-KJrm7CgZdP/UDC5+tHtqE6w9XMfY5YUfMOxJfBZGSsLMqS2OGsakQsaF0a55k+58l29X5w/nAkjHrI1BcQO03w==}
+    engines: {node: '>= 10'}
+    cpu: [ia32]
+    os: [win32]
+
+  '@basetenlabs/performance-client-win32-x64-msvc@0.0.10':
+    resolution: {integrity: sha512-M/mhvfTItUcUX+aeXRb5g5MbRlndfg6yelV7tSYfLU4YixMIe5yoGaAP3iDilpFJjcC99f+EU4l4+yLbPtpXig==}
+    engines: {node: '>= 10'}
+    cpu: [x64]
+    os: [win32]
+
+  '@basetenlabs/performance-client@0.0.10':
+    resolution: {integrity: sha512-H6bpd1JcDbuJsOS2dNft+CCGLzBqHJO/ST/4mMKhLAW641J6PpVJUw1szYsk/dTetdedbWxHpMkvFObOKeP8nw==}
+    engines: {node: '>= 10'}
+
   '@bcoe/[email protected]':
     resolution: {integrity: sha512-0hYQ8SB4Db5zvZB4axdMHGwEaQjkZzFjQiN9LVYvIFB2nSUHW9tYpxWriPrWDASIxiaXax83REcLxuSdnGPZtw==}
 
@@ -11016,6 +11128,14 @@ snapshots:
       '@ai-sdk/provider-utils': 4.0.13(zod@3.25.76)
       zod: 3.25.76
 
+  '@ai-sdk/baseten@1.0.31(zod@3.25.76)':
+    dependencies:
+      '@ai-sdk/openai-compatible': 2.0.28(zod@3.25.76)
+      '@ai-sdk/provider': 3.0.8
+      '@ai-sdk/provider-utils': 4.0.14(zod@3.25.76)
+      '@basetenlabs/performance-client': 0.0.10
+      zod: 3.25.76
+
   '@ai-sdk/[email protected]([email protected])':
     dependencies:
       '@ai-sdk/openai-compatible': 1.0.31(zod@3.25.76)
@@ -11102,6 +11222,12 @@ snapshots:
       '@ai-sdk/provider-utils': 4.0.13(zod@3.25.76)
       zod: 3.25.76
 
+  '@ai-sdk/openai-compatible@2.0.28(zod@3.25.76)':
+    dependencies:
+      '@ai-sdk/provider': 3.0.8
+      '@ai-sdk/provider-utils': 4.0.14(zod@3.25.76)
+      zod: 3.25.76
+
   '@ai-sdk/[email protected]([email protected])':
     dependencies:
       '@ai-sdk/provider': 2.0.1
@@ -11138,6 +11264,13 @@ snapshots:
       eventsource-parser: 3.0.6
       zod: 3.25.76
 
+  '@ai-sdk/provider-utils@4.0.14(zod@3.25.76)':
+    dependencies:
+      '@ai-sdk/provider': 3.0.8
+      '@standard-schema/spec': 1.1.0
+      eventsource-parser: 3.0.6
+      zod: 3.25.76
+
   '@ai-sdk/[email protected]':
     dependencies:
       json-schema: 0.4.0
@@ -11158,6 +11291,10 @@ snapshots:
     dependencies:
       json-schema: 0.4.0
 
+  '@ai-sdk/provider@3.0.8':
+    dependencies:
+      json-schema: 0.4.0
+
   '@ai-sdk/[email protected]([email protected])':
     dependencies:
       '@ai-sdk/openai-compatible': 2.0.26(zod@3.25.76)
@@ -11893,6 +12030,65 @@ snapshots:
       '@babel/helper-string-parser': 7.27.1
       '@babel/helper-validator-identifier': 7.27.1
 
+  '@basetenlabs/performance-client-android-arm-eabi@0.0.10':
+    optional: true
+
+  '@basetenlabs/performance-client-android-arm64@0.0.10':
+    optional: true
+
+  '@basetenlabs/performance-client-darwin-arm64@0.0.10':
+    optional: true
+
+  '@basetenlabs/performance-client-darwin-universal@0.0.10':
+    optional: true
+
+  '@basetenlabs/performance-client-darwin-x64@0.0.10':
+    optional: true
+
+  '@basetenlabs/performance-client-linux-arm-gnueabihf@0.0.10':
+    optional: true
+
+  '@basetenlabs/performance-client-linux-arm-musleabihf@0.0.10':
+    optional: true
+
+  '@basetenlabs/performance-client-linux-arm64-gnu@0.0.10':
+    optional: true
+
+  '@basetenlabs/performance-client-linux-riscv64-gnu@0.0.10':
+    optional: true
+
+  '@basetenlabs/performance-client-linux-x64-gnu@0.0.10':
+    optional: true
+
+  '@basetenlabs/performance-client-linux-x64-musl@0.0.10':
+    optional: true
+
+  '@basetenlabs/performance-client-win32-arm64-msvc@0.0.10':
+    optional: true
+
+  '@basetenlabs/performance-client-win32-ia32-msvc@0.0.10':
+    optional: true
+
+  '@basetenlabs/performance-client-win32-x64-msvc@0.0.10':
+    optional: true
+
+  '@basetenlabs/performance-client@0.0.10':
+    optionalDependencies:
+      '@basetenlabs/performance-client-android-arm-eabi': 0.0.10
+      '@basetenlabs/performance-client-android-arm64': 0.0.10
+      '@basetenlabs/performance-client-darwin-arm64': 0.0.10
+      '@basetenlabs/performance-client-darwin-universal': 0.0.10
+      '@basetenlabs/performance-client-darwin-x64': 0.0.10
+      '@basetenlabs/performance-client-linux-arm-gnueabihf': 0.0.10
+      '@basetenlabs/performance-client-linux-arm-musleabihf': 0.0.10
+      '@basetenlabs/performance-client-linux-arm64-gnu': 0.0.10
+      '@basetenlabs/performance-client-linux-riscv64-gnu': 0.0.10
+      '@basetenlabs/performance-client-linux-x64-gnu': 0.0.10
+      '@basetenlabs/performance-client-linux-x64-musl': 0.0.10
+      '@basetenlabs/performance-client-win32-arm64-msvc': 0.0.10
+      '@basetenlabs/performance-client-win32-ia32-msvc': 0.0.10
+      '@basetenlabs/performance-client-win32-x64-msvc': 0.0.10
+
   '@bcoe/[email protected]': {}
 
   '@braintree/[email protected]': {}

+ 446 - 0
src/api/providers/__tests__/baseten.spec.ts

@@ -0,0 +1,446 @@
+// npx vitest run src/api/providers/__tests__/baseten.spec.ts
+
+// Use vi.hoisted to define mock functions that can be referenced in hoisted vi.mock() calls
+const { mockStreamText, mockGenerateText } = vi.hoisted(() => ({
+	mockStreamText: vi.fn(),
+	mockGenerateText: vi.fn(),
+}))
+
+vi.mock("ai", async (importOriginal) => {
+	const actual = await importOriginal<typeof import("ai")>()
+	return {
+		...actual,
+		streamText: mockStreamText,
+		generateText: mockGenerateText,
+	}
+})
+
+vi.mock("@ai-sdk/baseten", () => ({
+	createBaseten: vi.fn(() => {
+		return vi.fn(() => ({
+			modelId: "zai-org/GLM-4.6",
+			provider: "baseten",
+		}))
+	}),
+}))
+
+import type { Anthropic } from "@anthropic-ai/sdk"
+
+import { basetenDefaultModelId, basetenModels, type BasetenModelId } from "@roo-code/types"
+
+import type { ApiHandlerOptions } from "../../../shared/api"
+
+import { BasetenHandler } from "../baseten"
+
+describe("BasetenHandler", () => {
+	let handler: BasetenHandler
+	let mockOptions: ApiHandlerOptions
+
+	beforeEach(() => {
+		mockOptions = {
+			basetenApiKey: "test-baseten-api-key",
+			apiModelId: "zai-org/GLM-4.6",
+		}
+		handler = new BasetenHandler(mockOptions)
+		vi.clearAllMocks()
+	})
+
+	describe("constructor", () => {
+		it("should initialize with provided options", () => {
+			expect(handler).toBeInstanceOf(BasetenHandler)
+			expect(handler.getModel().id).toBe(mockOptions.apiModelId)
+		})
+
+		it("should use default model ID if not provided", () => {
+			const handlerWithoutModel = new BasetenHandler({
+				...mockOptions,
+				apiModelId: undefined,
+			})
+			expect(handlerWithoutModel.getModel().id).toBe(basetenDefaultModelId)
+		})
+	})
+
+	describe("getModel", () => {
+		it("should return default model when no model is specified", () => {
+			const handlerWithoutModel = new BasetenHandler({
+				basetenApiKey: "test-baseten-api-key",
+			})
+			const model = handlerWithoutModel.getModel()
+			expect(model.id).toBe(basetenDefaultModelId)
+			expect(model.info).toEqual(basetenModels[basetenDefaultModelId])
+		})
+
+		it("should return specified model when valid model is provided", () => {
+			const testModelId: BasetenModelId = "deepseek-ai/DeepSeek-R1"
+			const handlerWithModel = new BasetenHandler({
+				apiModelId: testModelId,
+				basetenApiKey: "test-baseten-api-key",
+			})
+			const model = handlerWithModel.getModel()
+			expect(model.id).toBe(testModelId)
+			expect(model.info).toEqual(basetenModels[testModelId])
+		})
+
+		it("should return provided model ID with default model info if model does not exist", () => {
+			const handlerWithInvalidModel = new BasetenHandler({
+				...mockOptions,
+				apiModelId: "invalid-model",
+			})
+			const model = handlerWithInvalidModel.getModel()
+			expect(model.id).toBe("invalid-model")
+			expect(model.info).toBeDefined()
+			expect(model.info).toBe(basetenModels[basetenDefaultModelId])
+		})
+
+		it("should include model parameters from getModelParams", () => {
+			const model = handler.getModel()
+			expect(model).toHaveProperty("temperature")
+			expect(model).toHaveProperty("maxTokens")
+		})
+	})
+
+	describe("createMessage", () => {
+		const systemPrompt = "You are a helpful assistant."
+		const messages: Anthropic.Messages.MessageParam[] = [
+			{
+				role: "user",
+				content: [
+					{
+						type: "text" as const,
+						text: "Hello!",
+					},
+				],
+			},
+		]
+
+		it("should handle streaming responses", async () => {
+			async function* mockFullStream() {
+				yield { type: "text-delta", text: "Test response from Baseten" }
+			}
+
+			const mockUsage = Promise.resolve({
+				inputTokens: 10,
+				outputTokens: 5,
+			})
+
+			mockStreamText.mockReturnValue({
+				fullStream: mockFullStream(),
+				usage: mockUsage,
+			})
+
+			const stream = handler.createMessage(systemPrompt, messages)
+			const chunks: any[] = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			expect(chunks.length).toBeGreaterThan(0)
+			const textChunks = chunks.filter((chunk) => chunk.type === "text")
+			expect(textChunks).toHaveLength(1)
+			expect(textChunks[0].text).toBe("Test response from Baseten")
+		})
+
+		it("should include usage information", async () => {
+			async function* mockFullStream() {
+				yield { type: "text-delta", text: "Test response" }
+			}
+
+			const mockUsage = Promise.resolve({
+				inputTokens: 10,
+				outputTokens: 20,
+			})
+
+			mockStreamText.mockReturnValue({
+				fullStream: mockFullStream(),
+				usage: mockUsage,
+			})
+
+			const stream = handler.createMessage(systemPrompt, messages)
+			const chunks: any[] = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			const usageChunks = chunks.filter((chunk) => chunk.type === "usage")
+			expect(usageChunks.length).toBeGreaterThan(0)
+			expect(usageChunks[0].inputTokens).toBe(10)
+			expect(usageChunks[0].outputTokens).toBe(20)
+		})
+
+		it("should pass correct temperature (0.5 default) to streamText", async () => {
+			async function* mockFullStream() {
+				yield { type: "text-delta", text: "Test" }
+			}
+
+			mockStreamText.mockReturnValue({
+				fullStream: mockFullStream(),
+				usage: Promise.resolve({ inputTokens: 0, outputTokens: 0 }),
+			})
+
+			const handlerWithDefaultTemp = new BasetenHandler({
+				basetenApiKey: "test-key",
+				apiModelId: "zai-org/GLM-4.6",
+			})
+
+			const stream = handlerWithDefaultTemp.createMessage(systemPrompt, messages)
+			for await (const _ of stream) {
+				// consume stream
+			}
+
+			expect(mockStreamText).toHaveBeenCalledWith(
+				expect.objectContaining({
+					temperature: 0.5,
+				}),
+			)
+		})
+
+		it("should use user-specified temperature over default", async () => {
+			async function* mockFullStream() {
+				yield { type: "text-delta", text: "Test" }
+			}
+
+			mockStreamText.mockReturnValue({
+				fullStream: mockFullStream(),
+				usage: Promise.resolve({ inputTokens: 0, outputTokens: 0 }),
+			})
+
+			const handlerWithCustomTemp = new BasetenHandler({
+				basetenApiKey: "test-key",
+				apiModelId: "zai-org/GLM-4.6",
+				modelTemperature: 0.9,
+			})
+
+			const stream = handlerWithCustomTemp.createMessage(systemPrompt, messages)
+			for await (const _ of stream) {
+				// consume stream
+			}
+
+			expect(mockStreamText).toHaveBeenCalledWith(
+				expect.objectContaining({
+					temperature: 0.9,
+				}),
+			)
+		})
+
+		it("should handle stream with multiple chunks", async () => {
+			async function* mockFullStream() {
+				yield { type: "text-delta", text: "Hello" }
+				yield { type: "text-delta", text: " world" }
+			}
+
+			mockStreamText.mockReturnValue({
+				fullStream: mockFullStream(),
+				usage: Promise.resolve({ inputTokens: 5, outputTokens: 10 }),
+			})
+
+			const stream = handler.createMessage(systemPrompt, messages)
+			const chunks: any[] = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			const textChunks = chunks.filter((c) => c.type === "text")
+			expect(textChunks[0]).toEqual({ type: "text", text: "Hello" })
+			expect(textChunks[1]).toEqual({ type: "text", text: " world" })
+
+			const usageChunks = chunks.filter((c) => c.type === "usage")
+			expect(usageChunks[0]).toMatchObject({ type: "usage", inputTokens: 5, outputTokens: 10 })
+		})
+	})
+
+	describe("completePrompt", () => {
+		it("should complete a prompt using generateText", async () => {
+			mockGenerateText.mockResolvedValue({
+				text: "Test completion from Baseten",
+			})
+
+			const result = await handler.completePrompt("Test prompt")
+
+			expect(result).toBe("Test completion from Baseten")
+			expect(mockGenerateText).toHaveBeenCalledWith(
+				expect.objectContaining({
+					prompt: "Test prompt",
+				}),
+			)
+		})
+
+		it("should use default temperature in completePrompt", async () => {
+			mockGenerateText.mockResolvedValue({
+				text: "Test completion",
+			})
+
+			await handler.completePrompt("Test prompt")
+
+			expect(mockGenerateText).toHaveBeenCalledWith(
+				expect.objectContaining({
+					temperature: 0.5,
+				}),
+			)
+		})
+	})
+
+	describe("tool handling", () => {
+		const systemPrompt = "You are a helpful assistant."
+		const messages: Anthropic.Messages.MessageParam[] = [
+			{
+				role: "user",
+				content: [{ type: "text" as const, text: "Hello!" }],
+			},
+		]
+
+		it("should handle tool calls in streaming", async () => {
+			async function* mockFullStream() {
+				yield {
+					type: "tool-input-start",
+					id: "tool-call-1",
+					toolName: "read_file",
+				}
+				yield {
+					type: "tool-input-delta",
+					id: "tool-call-1",
+					delta: '{"path":"test.ts"}',
+				}
+				yield {
+					type: "tool-input-end",
+					id: "tool-call-1",
+				}
+			}
+
+			const mockUsage = Promise.resolve({
+				inputTokens: 10,
+				outputTokens: 5,
+			})
+
+			mockStreamText.mockReturnValue({
+				fullStream: mockFullStream(),
+				usage: mockUsage,
+			})
+
+			const stream = handler.createMessage(systemPrompt, messages, {
+				taskId: "test-task",
+				tools: [
+					{
+						type: "function",
+						function: {
+							name: "read_file",
+							description: "Read a file",
+							parameters: {
+								type: "object",
+								properties: { path: { type: "string" } },
+								required: ["path"],
+							},
+						},
+					},
+				],
+			})
+
+			const chunks: any[] = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			const toolCallStartChunks = chunks.filter((c) => c.type === "tool_call_start")
+			const toolCallDeltaChunks = chunks.filter((c) => c.type === "tool_call_delta")
+			const toolCallEndChunks = chunks.filter((c) => c.type === "tool_call_end")
+
+			expect(toolCallStartChunks.length).toBe(1)
+			expect(toolCallStartChunks[0].id).toBe("tool-call-1")
+			expect(toolCallStartChunks[0].name).toBe("read_file")
+
+			expect(toolCallDeltaChunks.length).toBe(1)
+			expect(toolCallDeltaChunks[0].delta).toBe('{"path":"test.ts"}')
+
+			expect(toolCallEndChunks.length).toBe(1)
+			expect(toolCallEndChunks[0].id).toBe("tool-call-1")
+		})
+
+		it("should ignore tool-call events to prevent duplicate tools in UI", async () => {
+			async function* mockFullStream() {
+				yield {
+					type: "tool-call",
+					toolCallId: "tool-call-1",
+					toolName: "read_file",
+					input: { path: "test.ts" },
+				}
+			}
+
+			const mockUsage = Promise.resolve({
+				inputTokens: 10,
+				outputTokens: 5,
+			})
+
+			mockStreamText.mockReturnValue({
+				fullStream: mockFullStream(),
+				usage: mockUsage,
+			})
+
+			const stream = handler.createMessage(systemPrompt, messages)
+			const chunks: any[] = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			const toolCallChunks = chunks.filter(
+				(c) => c.type === "tool_call_start" || c.type === "tool_call_delta" || c.type === "tool_call_end",
+			)
+			expect(toolCallChunks.length).toBe(0)
+		})
+	})
+
+	describe("error handling", () => {
+		const systemPrompt = "You are a helpful assistant."
+		const messages: Anthropic.Messages.MessageParam[] = [
+			{
+				role: "user",
+				content: [{ type: "text" as const, text: "Hello!" }],
+			},
+		]
+
+		it("should handle AI SDK errors with handleAiSdkError", async () => {
+			// eslint-disable-next-line require-yield
+			async function* mockFullStream(): AsyncGenerator<any> {
+				throw new Error("API Error")
+			}
+
+			mockStreamText.mockReturnValue({
+				fullStream: mockFullStream(),
+				usage: Promise.resolve({ inputTokens: 0, outputTokens: 0 }),
+			})
+
+			const stream = handler.createMessage(systemPrompt, messages)
+
+			await expect(async () => {
+				for await (const _ of stream) {
+					// consume stream
+				}
+			}).rejects.toThrow("Baseten: API Error")
+		})
+
+		it("should preserve status codes in error handling", async () => {
+			const apiError = new Error("Rate limit exceeded")
+			;(apiError as any).status = 429
+
+			// eslint-disable-next-line require-yield
+			async function* mockFullStream(): AsyncGenerator<any> {
+				throw apiError
+			}
+
+			mockStreamText.mockReturnValue({
+				fullStream: mockFullStream(),
+				usage: Promise.resolve({ inputTokens: 0, outputTokens: 0 }),
+			})
+
+			const stream = handler.createMessage(systemPrompt, messages)
+
+			try {
+				for await (const _ of stream) {
+					// consume stream
+				}
+				expect.fail("Should have thrown an error")
+			} catch (error: any) {
+				expect(error.message).toContain("Baseten")
+				expect(error.status).toBe(429)
+			}
+		})
+	})
+})

+ 148 - 10
src/api/providers/baseten.ts

@@ -1,18 +1,156 @@
-import { type BasetenModelId, basetenDefaultModelId, basetenModels } from "@roo-code/types"
+import { Anthropic } from "@anthropic-ai/sdk"
+import { createBaseten } from "@ai-sdk/baseten"
+import { streamText, generateText, ToolSet } from "ai"
+
+import { basetenModels, basetenDefaultModelId, type ModelInfo } from "@roo-code/types"
 
 import type { ApiHandlerOptions } from "../../shared/api"
-import { BaseOpenAiCompatibleProvider } from "./base-openai-compatible-provider"
 
-export class BasetenHandler extends BaseOpenAiCompatibleProvider<BasetenModelId> {
+import {
+	convertToAiSdkMessages,
+	convertToolsForAiSdk,
+	processAiSdkStreamPart,
+	mapToolChoice,
+	handleAiSdkError,
+} from "../transform/ai-sdk"
+import { ApiStream, ApiStreamUsageChunk } from "../transform/stream"
+import { getModelParams } from "../transform/model-params"
+
+import { DEFAULT_HEADERS } from "./constants"
+import { BaseProvider } from "./base-provider"
+import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index"
+
+const BASETEN_DEFAULT_TEMPERATURE = 0.5
+
+/**
+ * Baseten provider using the dedicated @ai-sdk/baseten package.
+ * Provides native support for Baseten's inference API.
+ */
+export class BasetenHandler extends BaseProvider implements SingleCompletionHandler {
+	protected options: ApiHandlerOptions
+	protected provider: ReturnType<typeof createBaseten>
+
 	constructor(options: ApiHandlerOptions) {
-		super({
-			...options,
-			providerName: "Baseten",
+		super()
+		this.options = options
+
+		this.provider = createBaseten({
 			baseURL: "https://inference.baseten.co/v1",
-			apiKey: options.basetenApiKey,
-			defaultProviderModelId: basetenDefaultModelId,
-			providerModels: basetenModels,
-			defaultTemperature: 0.5,
+			apiKey: options.basetenApiKey ?? "not-provided",
+			headers: DEFAULT_HEADERS,
+		})
+	}
+
+	override getModel(): { id: string; info: ModelInfo; maxTokens?: number; temperature?: number } {
+		const id = this.options.apiModelId ?? basetenDefaultModelId
+		const info = basetenModels[id as keyof typeof basetenModels] || basetenModels[basetenDefaultModelId]
+		const params = getModelParams({
+			format: "openai",
+			modelId: id,
+			model: info,
+			settings: this.options,
+			defaultTemperature: BASETEN_DEFAULT_TEMPERATURE,
+		})
+		return { id, info, ...params }
+	}
+
+	/**
+	 * Get the language model for the configured model ID.
+	 */
+	protected getLanguageModel() {
+		const { id } = this.getModel()
+		return this.provider(id)
+	}
+
+	/**
+	 * Process usage metrics from the AI SDK response.
+	 */
+	protected processUsageMetrics(usage: {
+		inputTokens?: number
+		outputTokens?: number
+		details?: {
+			cachedInputTokens?: number
+			reasoningTokens?: number
+		}
+	}): ApiStreamUsageChunk {
+		return {
+			type: "usage",
+			inputTokens: usage.inputTokens || 0,
+			outputTokens: usage.outputTokens || 0,
+			reasoningTokens: usage.details?.reasoningTokens,
+		}
+	}
+
+	/**
+	 * Get the max tokens parameter to include in the request.
+	 */
+	protected getMaxOutputTokens(): number | undefined {
+		const { info } = this.getModel()
+		return this.options.modelMaxTokens || info.maxTokens || undefined
+	}
+
+	/**
+	 * Create a message stream using the AI SDK.
+	 */
+	override async *createMessage(
+		systemPrompt: string,
+		messages: Anthropic.Messages.MessageParam[],
+		metadata?: ApiHandlerCreateMessageMetadata,
+	): ApiStream {
+		const { temperature } = this.getModel()
+		const languageModel = this.getLanguageModel()
+
+		const aiSdkMessages = convertToAiSdkMessages(messages)
+
+		const openAiTools = this.convertToolsForOpenAI(metadata?.tools)
+		const aiSdkTools = convertToolsForAiSdk(openAiTools) as ToolSet | undefined
+
+		const requestOptions: Parameters<typeof streamText>[0] = {
+			model: languageModel,
+			system: systemPrompt,
+			messages: aiSdkMessages,
+			temperature: this.options.modelTemperature ?? temperature ?? BASETEN_DEFAULT_TEMPERATURE,
+			maxOutputTokens: this.getMaxOutputTokens(),
+			tools: aiSdkTools,
+			toolChoice: mapToolChoice(metadata?.tool_choice),
+		}
+
+		const result = streamText(requestOptions)
+
+		try {
+			for await (const part of result.fullStream) {
+				for (const chunk of processAiSdkStreamPart(part)) {
+					yield chunk
+				}
+			}
+
+			const usage = await result.usage
+			if (usage) {
+				yield this.processUsageMetrics(usage)
+			}
+		} catch (error) {
+			throw handleAiSdkError(error, "Baseten")
+		}
+	}
+
+	/**
+	 * Complete a prompt using the AI SDK generateText.
+	 */
+	async completePrompt(prompt: string): Promise<string> {
+		const { temperature } = this.getModel()
+		const languageModel = this.getLanguageModel()
+
+		const { text } = await generateText({
+			model: languageModel,
+			prompt,
+			maxOutputTokens: this.getMaxOutputTokens(),
+			temperature: this.options.modelTemperature ?? temperature ?? BASETEN_DEFAULT_TEMPERATURE,
 		})
+
+		return text
+	}
+
+	override isAiSdkProvider(): boolean {
+		return true
 	}
 }
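
For reference, a minimal usage sketch of the new handler, mirroring the tests above (the import path and API key value are placeholders):

import { BasetenHandler } from "./baseten"

const handler = new BasetenHandler({
	basetenApiKey: "YOUR_BASETEN_API_KEY",
	apiModelId: "zai-org/GLM-4.6",
})

// createMessage yields typed chunks: "text", "usage", and tool_call_* events.
const stream = handler.createMessage("You are a helpful assistant.", [
	{ role: "user", content: [{ type: "text", text: "Hello!" }] },
])

for await (const chunk of stream) {
	if (chunk.type === "text") {
		process.stdout.write(chunk.text)
	}
}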

+ 16 - 0
src/esbuild.mjs

@@ -43,6 +43,22 @@ async function main() {
 	 * @type {import('esbuild').Plugin[]}
 	 */
 	const plugins = [
+		{
+			// Stub out @basetenlabs/performance-client which contains native .node
+			// binaries that esbuild cannot bundle. This module is only used by
+			// @ai-sdk/baseten for embedding models, not for chat completions.
+			name: "stub-baseten-native",
+			setup(build) {
+				build.onResolve({ filter: /^@basetenlabs\/performance-client/ }, (args) => ({
+					path: args.path,
+					namespace: "stub-baseten-native",
+				}))
+				build.onLoad({ filter: /.*/, namespace: "stub-baseten-native" }, () => ({
+					contents: "module.exports = { PerformanceClient: class PerformanceClient {} };",
+					loader: "js",
+				}))
+			},
+		},
 		{
 			name: "copyFiles",
 			setup(build) {

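As a rough illustration of how this stub behaves in isolation (a hypothetical standalone build, not the repo's actual esbuild config): onResolve claims any matching import and routes it into the plugin's namespace, and onLoad then serves an empty PerformanceClient stand-in, so esbuild never tries to bundle the native .node binaries.

import * as esbuild from "esbuild"

await esbuild.build({
	stdin: {
		contents: `import { PerformanceClient } from "@basetenlabs/performance-client"; console.log(PerformanceClient)`,
		resolveDir: ".",
	},
	bundle: true,
	platform: "node",
	write: false,
	plugins: [
		{
			name: "stub-baseten-native",
			setup(build) {
				// Claim the import before esbuild attempts node-module resolution.
				build.onResolve({ filter: /^@basetenlabs\/performance-client/ }, (args) => ({
					path: args.path,
					namespace: "stub-baseten-native",
				}))
				// Serve a stub module instead of the real native package.
				build.onLoad({ filter: /.*/, namespace: "stub-baseten-native" }, () => ({
					contents: "module.exports = { PerformanceClient: class PerformanceClient {} };",
					loader: "js",
				}))
			},
		},
	],
})
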
+ 1 - 0
src/package.json

@@ -451,6 +451,7 @@
 	},
 	"dependencies": {
 		"@ai-sdk/amazon-bedrock": "^4.0.50",
+		"@ai-sdk/baseten": "^1.0.31",
 		"@ai-sdk/cerebras": "^1.0.0",
 		"@ai-sdk/deepseek": "^2.0.14",
 		"@ai-sdk/fireworks": "^2.0.26",