
feat: Add DeepSeek R1 support to Chutes provider (#4523) (#4525)

* feat: Add DeepSeek R1 support to Chutes provider (#4523)

- Modified BaseOpenAiCompatibleProvider to expose client as protected
- Enhanced ChutesHandler to detect DeepSeek R1 models and parse reasoning chunks
- Applied R1 format conversion for message formatting
- Set appropriate temperature (0.6) for DeepSeek models
- Migrated tests from Jest to Vitest format
- Added comprehensive tests for DeepSeek R1 functionality

This ensures reasoning chunks are properly separated from regular content
when using DeepSeek R1 models via the Chutes provider.
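
For illustration, here is a minimal sketch of that chunk separation, using the same XmlMatcher wiring that appears in the chutes.ts diff below (the delta strings are made up for the example):

```ts
import { XmlMatcher } from "../../utils/xml-matcher"

// Illustrative deltas, as they might arrive from a Chutes stream.
const deltas = ["<think>Thinking...", "</think>Hello"]

// Same wiring as in the chutes.ts diff below: content inside
// <think>…</think> becomes "reasoning" chunks, everything else "text".
const matcher = new XmlMatcher(
	"think",
	(chunk) =>
		({
			type: chunk.matched ? "reasoning" : "text",
			text: chunk.data,
		}) as const,
)

for (const delta of deltas) {
	for (const chunk of matcher.update(delta)) {
		console.log(chunk) // { type: "reasoning", text: "Thinking..." }, then { type: "text", text: "Hello" }
	}
}

// Flush anything still buffered once the stream ends.
for (const chunk of matcher.final()) {
	console.log(chunk)
}
```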

* feat: Enhance DeepSeek R1 support with <think> tag handling in Chutes provider

* fix: Correct temperature retrieval in ChutesHandler to use model's info

* fix: Update condition for DeepSeek-R1 model identification in createMessage method

---------

Co-authored-by: Daniel Riccio <[email protected]>
Hannes Rudolph 6 months ago
Parent
Commit
a851ffb7cb

+ 186 - 22
src/api/providers/__tests__/chutes.spec.ts

@@ -1,33 +1,64 @@
 // npx vitest run api/providers/__tests__/chutes.spec.ts
 
-import { vitest, describe, it, expect, beforeEach } from "vitest"
-import OpenAI from "openai"
 import { Anthropic } from "@anthropic-ai/sdk"
+import { afterEach, beforeEach, describe, expect, it, vi } from "vitest"
+import OpenAI from "openai"
 
-import { type ChutesModelId, chutesDefaultModelId, chutesModels } from "@roo-code/types"
+import { type ChutesModelId, chutesDefaultModelId, chutesModels, DEEP_SEEK_DEFAULT_TEMPERATURE } from "@roo-code/types"
 
 import { ChutesHandler } from "../chutes"
 
-const mockCreate = vitest.fn()
+// Create mock functions
+const mockCreate = vi.fn()
 
-vitest.mock("openai", () => {
-	return {
-		default: vitest.fn().mockImplementation(() => ({
-			chat: {
-				completions: {
-					create: mockCreate,
-				},
+// Mock OpenAI module
+vi.mock("openai", () => ({
+	default: vi.fn(() => ({
+		chat: {
+			completions: {
+				create: mockCreate,
 			},
-		})),
-	}
-})
+		},
+	})),
+}))
 
 describe("ChutesHandler", () => {
 	let handler: ChutesHandler
 
 	beforeEach(() => {
-		vitest.clearAllMocks()
-		handler = new ChutesHandler({ chutesApiKey: "test-chutes-api-key" })
+		vi.clearAllMocks()
+		// Set up default mock implementation
+		mockCreate.mockImplementation(async () => ({
+			[Symbol.asyncIterator]: async function* () {
+				yield {
+					choices: [
+						{
+							delta: { content: "Test response" },
+							index: 0,
+						},
+					],
+					usage: null,
+				}
+				yield {
+					choices: [
+						{
+							delta: {},
+							index: 0,
+						},
+					],
+					usage: {
+						prompt_tokens: 10,
+						completion_tokens: 5,
+						total_tokens: 15,
+					},
+				}
+			},
+		}))
+		handler = new ChutesHandler({ chutesApiKey: "test-key" })
+	})
+
+	afterEach(() => {
+		vi.restoreAllMocks()
 	})
 
 	it("should use the correct Chutes base URL", () => {
@@ -41,18 +72,96 @@ describe("ChutesHandler", () => {
 		expect(OpenAI).toHaveBeenCalledWith(expect.objectContaining({ apiKey: chutesApiKey }))
 	})
 
+	it("should handle DeepSeek R1 reasoning format", async () => {
+		// Override the mock for this specific test
+		mockCreate.mockImplementationOnce(async () => ({
+			[Symbol.asyncIterator]: async function* () {
+				yield {
+					choices: [
+						{
+							delta: { content: "<think>Thinking..." },
+							index: 0,
+						},
+					],
+					usage: null,
+				}
+				yield {
+					choices: [
+						{
+							delta: { content: "</think>Hello" },
+							index: 0,
+						},
+					],
+					usage: null,
+				}
+				yield {
+					choices: [
+						{
+							delta: {},
+							index: 0,
+						},
+					],
+					usage: { prompt_tokens: 10, completion_tokens: 5 },
+				}
+			},
+		}))
+
+		const systemPrompt = "You are a helpful assistant."
+		const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Hi" }]
+		vi.spyOn(handler, "getModel").mockReturnValue({
+			id: "deepseek-ai/DeepSeek-R1-0528",
+			info: { maxTokens: 1024, temperature: 0.7 },
+		} as any)
+
+		const stream = handler.createMessage(systemPrompt, messages)
+		const chunks = []
+		for await (const chunk of stream) {
+			chunks.push(chunk)
+		}
+
+		expect(chunks).toEqual([
+			{ type: "reasoning", text: "Thinking..." },
+			{ type: "text", text: "Hello" },
+			{ type: "usage", inputTokens: 10, outputTokens: 5 },
+		])
+	})
+
+	it("should fall back to base provider for non-DeepSeek models", async () => {
+		// Use default mock implementation which returns text content
+		const systemPrompt = "You are a helpful assistant."
+		const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Hi" }]
+		vi.spyOn(handler, "getModel").mockReturnValue({
+			id: "some-other-model",
+			info: { maxTokens: 1024, temperature: 0.7 },
+		} as any)
+
+		const stream = handler.createMessage(systemPrompt, messages)
+		const chunks = []
+		for await (const chunk of stream) {
+			chunks.push(chunk)
+		}
+
+		expect(chunks).toEqual([
+			{ type: "text", text: "Test response" },
+			{ type: "usage", inputTokens: 10, outputTokens: 5 },
+		])
+	})
+
 	it("should return default model when no model is specified", () => {
 		const model = handler.getModel()
 		expect(model.id).toBe(chutesDefaultModelId)
-		expect(model.info).toEqual(chutesModels[chutesDefaultModelId])
+		expect(model.info).toEqual(expect.objectContaining(chutesModels[chutesDefaultModelId]))
 	})
 
 	it("should return specified model when valid model is provided", () => {
 		const testModelId: ChutesModelId = "deepseek-ai/DeepSeek-R1"
-		const handlerWithModel = new ChutesHandler({ apiModelId: testModelId, chutesApiKey: "test-chutes-api-key" })
+		const handlerWithModel = new ChutesHandler({
+			apiModelId: testModelId,
+			chutesApiKey: "test-chutes-api-key",
+		})
 		const model = handlerWithModel.getModel()
 		expect(model.id).toBe(testModelId)
-		expect(model.info).toEqual(chutesModels[testModelId])
+		expect(model.info).toEqual(expect.objectContaining(chutesModels[testModelId]))
 	})
 
 	it("completePrompt method should return text from Chutes API", async () => {
@@ -74,7 +183,7 @@ describe("ChutesHandler", () => {
 		mockCreate.mockImplementationOnce(() => {
 			return {
 				[Symbol.asyncIterator]: () => ({
-					next: vitest
+					next: vi
 						.fn()
 						.mockResolvedValueOnce({
 							done: false,
@@ -96,7 +205,7 @@ describe("ChutesHandler", () => {
 		mockCreate.mockImplementationOnce(() => {
 			return {
 				[Symbol.asyncIterator]: () => ({
-					next: vitest
+					next: vi
 						.fn()
 						.mockResolvedValueOnce({
 							done: false,
@@ -114,8 +223,43 @@ describe("ChutesHandler", () => {
 		expect(firstChunk.value).toEqual({ type: "usage", inputTokens: 10, outputTokens: 20 })
 	})
 
-	it("createMessage should pass correct parameters to Chutes client", async () => {
+	it("createMessage should pass correct parameters to Chutes client for DeepSeek R1", async () => {
 		const modelId: ChutesModelId = "deepseek-ai/DeepSeek-R1"
+
+		// Clear previous mocks and set up new implementation
+		mockCreate.mockClear()
+		mockCreate.mockImplementationOnce(async () => ({
+			[Symbol.asyncIterator]: async function* () {
+				// Empty stream for this test
+			},
+		}))
+
+		const handlerWithModel = new ChutesHandler({
+			apiModelId: modelId,
+			chutesApiKey: "test-chutes-api-key",
+		})
+
+		const systemPrompt = "Test system prompt for Chutes"
+		const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Test message for Chutes" }]
+
+		const messageGenerator = handlerWithModel.createMessage(systemPrompt, messages)
+		await messageGenerator.next()
+
+		expect(mockCreate).toHaveBeenCalledWith(
+			expect.objectContaining({
+				model: modelId,
+				messages: [
+					{
+						role: "user",
+						content: `${systemPrompt}\n${messages[0].content}`,
+					},
+				],
+			}),
+		)
+	})
+
+	it("createMessage should pass correct parameters to Chutes client for non-DeepSeek models", async () => {
+		const modelId: ChutesModelId = "unsloth/Llama-3.3-70B-Instruct"
 		const modelInfo = chutesModels[modelId]
 		const handlerWithModel = new ChutesHandler({ apiModelId: modelId, chutesApiKey: "test-chutes-api-key" })
 
@@ -146,4 +290,24 @@ describe("ChutesHandler", () => {
 			}),
 		)
 	})
+
+	it("should apply DeepSeek default temperature for R1 models", () => {
+		const testModelId: ChutesModelId = "deepseek-ai/DeepSeek-R1"
+		const handlerWithModel = new ChutesHandler({
+			apiModelId: testModelId,
+			chutesApiKey: "test-chutes-api-key",
+		})
+		const model = handlerWithModel.getModel()
+		expect(model.info.temperature).toBe(DEEP_SEEK_DEFAULT_TEMPERATURE)
+	})
+
+	it("should use default temperature for non-DeepSeek models", () => {
+		const testModelId: ChutesModelId = "unsloth/Llama-3.3-70B-Instruct"
+		const handlerWithModel = new ChutesHandler({
+			apiModelId: testModelId,
+			chutesApiKey: "test-chutes-api-key",
+		})
+		const model = handlerWithModel.getModel()
+		expect(model.info.temperature).toBe(0.5)
+	})
 })
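
A note on the mock layout in the spec above: `beforeEach` installs a default streaming implementation, and individual tests shadow it with `mockImplementationOnce`. Reduced to its core, the pattern looks like this (response bodies abbreviated):

```ts
import { vi } from "vitest"

const mockCreate = vi.fn()

// Default, installed in beforeEach: every test that doesn't care about
// stream contents still gets a well-formed async-iterable response.
mockCreate.mockImplementation(async () => ({
	[Symbol.asyncIterator]: async function* () {
		yield { choices: [{ delta: { content: "Test response" }, index: 0 }], usage: null }
	},
}))

// One-shot override inside a single test: consumed by the next call only,
// after which the default above takes over again.
mockCreate.mockImplementationOnce(async () => ({
	[Symbol.asyncIterator]: async function* () {
		yield { choices: [{ delta: { content: "<think>Thinking..." }, index: 0 }], usage: null }
	},
}))
```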

+ 1 - 1
src/api/providers/base-openai-compatible-provider.ts

@@ -31,7 +31,7 @@ export abstract class BaseOpenAiCompatibleProvider<ModelName extends string>
 
 	protected readonly options: ApiHandlerOptions
 
-	private client: OpenAI
+	protected client: OpenAI
 
 	constructor({
 		providerName,
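
This one-word visibility change is what the chutes.ts diff below depends on: a subclass method can now reach the shared OpenAI client directly. A minimal sketch of the pattern (constructor plumbing omitted; the real call is in createMessage below):

```ts
// Inside a BaseOpenAiCompatibleProvider subclass method, `this.client`
// is now accessible because the field is `protected` rather than `private`.
const stream = await this.client.chat.completions.create({
	model: "deepseek-ai/DeepSeek-R1",
	messages: [{ role: "user", content: "Hello" }],
	stream: true,
})
```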

+ 85 - 1
src/api/providers/chutes.ts

@@ -1,6 +1,12 @@
-import { type ChutesModelId, chutesDefaultModelId, chutesModels } from "@roo-code/types"
+import { DEEP_SEEK_DEFAULT_TEMPERATURE, type ChutesModelId, chutesDefaultModelId, chutesModels } from "@roo-code/types"
+import { Anthropic } from "@anthropic-ai/sdk"
+import OpenAI from "openai"
 
 import type { ApiHandlerOptions } from "../../shared/api"
+import { XmlMatcher } from "../../utils/xml-matcher"
+import { convertToR1Format } from "../transform/r1-format"
+import { convertToOpenAiMessages } from "../transform/openai-format"
+import { ApiStream } from "../transform/stream"
 
 import { BaseOpenAiCompatibleProvider } from "./base-openai-compatible-provider"
 
@@ -16,4 +22,82 @@ export class ChutesHandler extends BaseOpenAiCompatibleProvider<ChutesModelId> {
 			defaultTemperature: 0.5,
 		})
 	}
+
+	private getCompletionParams(
+		systemPrompt: string,
+		messages: Anthropic.Messages.MessageParam[],
+	): OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming {
+		const {
+			id: model,
+			info: { maxTokens: max_tokens },
+		} = this.getModel()
+
+		const temperature = this.options.modelTemperature ?? this.getModel().info.temperature
+
+		return {
+			model,
+			max_tokens,
+			temperature,
+			messages: [{ role: "system", content: systemPrompt }, ...convertToOpenAiMessages(messages)],
+			stream: true,
+			stream_options: { include_usage: true },
+		}
+	}
+
+	override async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
+		const model = this.getModel()
+
+		if (model.id.includes("DeepSeek-R1")) {
+			const stream = await this.client.chat.completions.create({
+				...this.getCompletionParams(systemPrompt, messages),
+				messages: convertToR1Format([{ role: "user", content: systemPrompt }, ...messages]),
+			})
+
+			const matcher = new XmlMatcher(
+				"think",
+				(chunk) =>
+					({
+						type: chunk.matched ? "reasoning" : "text",
+						text: chunk.data,
+					}) as const,
+			)
+
+			for await (const chunk of stream) {
+				const delta = chunk.choices[0]?.delta
+
+				if (delta?.content) {
+					for (const processedChunk of matcher.update(delta.content)) {
+						yield processedChunk
+					}
+				}
+
+				if (chunk.usage) {
+					yield {
+						type: "usage",
+						inputTokens: chunk.usage.prompt_tokens || 0,
+						outputTokens: chunk.usage.completion_tokens || 0,
+					}
+				}
+			}
+
+			// Process any remaining content
+			for (const processedChunk of matcher.final()) {
+				yield processedChunk
+			}
+		} else {
+			yield* super.createMessage(systemPrompt, messages)
+		}
+	}
+
+	override getModel() {
+		const model = super.getModel()
+		const isDeepSeekR1 = model.id.includes("DeepSeek-R1")
+		return {
+			...model,
+			info: {
+				...model.info,
+				temperature: isDeepSeekR1 ? DEEP_SEEK_DEFAULT_TEMPERATURE : this.defaultTemperature,
+			},
+		}
+	}
 }
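
Taken together, a hypothetical consumer of the new handler would see the same chunk shapes the spec asserts: reasoning, text, and usage arriving as separate items (the API key and prompts here are placeholders):

```ts
import { ChutesHandler } from "./chutes"

const handler = new ChutesHandler({
	apiModelId: "deepseek-ai/DeepSeek-R1",
	chutesApiKey: process.env.CHUTES_API_KEY ?? "",
})

for await (const chunk of handler.createMessage("You are a helpful assistant.", [
	{ role: "user", content: "Hi" },
])) {
	switch (chunk.type) {
		case "reasoning":
			// Content that arrived inside <think>…</think>.
			console.log("[reasoning]", chunk.text)
			break
		case "text":
			console.log("[text]", chunk.text)
			break
		case "usage":
			console.log("[usage]", chunk.inputTokens, "in /", chunk.outputTokens, "out")
			break
	}
}
```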