Browse Source

Add DeepSeek to the list of providers

Matt Rubens 1 year ago
parent
commit
eb8c4cc50f

+ 5 - 0
.changeset/modern-carrots-applaud.md

@@ -0,0 +1,5 @@
+---
+"roo-cline": patch
+---
+
+Add the DeepSeek provider along with logic to trim messages when it hits the context window

+ 1 - 0
README.md

@@ -13,6 +13,7 @@ A fork of Cline, an autonomous coding agent, with some additional experimental f
 - Includes current time in the system prompt
 - Uses a file system watcher to more reliably watch for file system changes
 - Language selection for Cline's communication (English, Japanese, Spanish, French, German, and more)
+- Support for DeepSeek V3 with logic to trim messages when it hits the context window
 - Support for Meta 3, 3.1, and 3.2 models via AWS Bedrock
 - Support for listing models from OpenAI-compatible providers
 - Per-tool MCP auto-approval

+ 3 - 0
src/api/index.ts

@@ -9,6 +9,7 @@ import { OllamaHandler } from "./providers/ollama"
 import { LmStudioHandler } from "./providers/lmstudio"
 import { GeminiHandler } from "./providers/gemini"
 import { OpenAiNativeHandler } from "./providers/openai-native"
+import { DeepSeekHandler } from "./providers/deepseek"
 import { ApiStream } from "./transform/stream"
 
 export interface SingleCompletionHandler {
@@ -41,6 +42,8 @@ export function buildApiHandler(configuration: ApiConfiguration): ApiHandler {
 			return new GeminiHandler(options)
 		case "openai-native":
 			return new OpenAiNativeHandler(options)
+		case "deepseek":
+			return new DeepSeekHandler(options)
 		default:
 			return new AnthropicHandler(options)
 	}

+ 251 - 0
src/api/providers/__tests__/deepseek.test.ts

@@ -0,0 +1,251 @@
+import { DeepSeekHandler } from '../deepseek'
+import { ApiHandlerOptions } from '../../../shared/api'
+import OpenAI from 'openai'
+import { Anthropic } from '@anthropic-ai/sdk'
+
+// Mock dependencies
+jest.mock('openai')
+jest.mock('../../../shared/api', () => ({
+    ...jest.requireActual('../../../shared/api'),
+    deepSeekModels: {
+        'deepseek-chat': {
+            maxTokens: 1000,
+            contextWindow: 2000,
+            supportsImages: false,
+            supportsPromptCache: false,
+            inputPrice: 0.014,
+            outputPrice: 0.28,
+        }
+    }
+}))
+
+describe('DeepSeekHandler', () => {
+
+    const mockOptions: ApiHandlerOptions = {
+        deepSeekApiKey: 'test-key',
+        deepSeekModelId: 'deepseek-chat',
+    }
+
+    beforeEach(() => {
+        jest.clearAllMocks()
+    })
+
+    test('constructor initializes with correct options', () => {
+        const handler = new DeepSeekHandler(mockOptions)
+        expect(handler).toBeInstanceOf(DeepSeekHandler)
+        expect(OpenAI).toHaveBeenCalledWith({
+            baseURL: 'https://api.deepseek.com/v1',
+            apiKey: mockOptions.deepSeekApiKey,
+        })
+    })
+
+    test('getModel returns correct model info', () => {
+        const handler = new DeepSeekHandler(mockOptions)
+        const result = handler.getModel()
+        
+        expect(result).toEqual({
+            id: mockOptions.deepSeekModelId,
+            info: expect.objectContaining({
+                maxTokens: 1000,
+                contextWindow: 2000,
+                supportsPromptCache: false,
+                supportsImages: false,
+                inputPrice: 0.014,
+                outputPrice: 0.28,
+            })
+        })
+    })
+
+    test('getModel returns default model info when no model specified', () => {
+        const handler = new DeepSeekHandler({ deepSeekApiKey: 'test-key' })
+        const result = handler.getModel()
+        
+        expect(result.id).toBe('deepseek-chat')
+        expect(result.info.maxTokens).toBe(1000)
+    })
+
+    test('createMessage handles string content correctly', async () => {
+        const handler = new DeepSeekHandler(mockOptions)
+        const mockStream = {
+            async *[Symbol.asyncIterator]() {
+                yield {
+                    choices: [{
+                        delta: {
+                            content: 'test response'
+                        }
+                    }]
+                }
+            }
+        }
+
+        const mockCreate = jest.fn().mockResolvedValue(mockStream)
+        ;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
+            completions: { create: mockCreate }
+        } as any
+
+        const systemPrompt = 'test system prompt'
+        const messages: Anthropic.Messages.MessageParam[] = [
+            { role: 'user', content: 'test message' }
+        ]
+
+        const generator = handler.createMessage(systemPrompt, messages)
+        const chunks = []
+        
+        for await (const chunk of generator) {
+            chunks.push(chunk)
+        }
+
+        expect(chunks).toHaveLength(1)
+        expect(chunks[0]).toEqual({
+            type: 'text',
+            text: 'test response'
+        })
+
+        expect(mockCreate).toHaveBeenCalledWith(expect.objectContaining({
+            model: mockOptions.deepSeekModelId,
+            messages: [
+                { role: 'system', content: systemPrompt },
+                { role: 'user', content: 'test message' }
+            ],
+            temperature: 0,
+            stream: true,
+            max_tokens: 1000,
+            stream_options: { include_usage: true }
+        }))
+    })
+
+    test('createMessage handles complex content correctly', async () => {
+        const handler = new DeepSeekHandler(mockOptions)
+        const mockStream = {
+            async *[Symbol.asyncIterator]() {
+                yield {
+                    choices: [{
+                        delta: {
+                            content: 'test response'
+                        }
+                    }]
+                }
+            }
+        }
+
+        const mockCreate = jest.fn().mockResolvedValue(mockStream)
+        ;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
+            completions: { create: mockCreate }
+        } as any
+
+        const systemPrompt = 'test system prompt'
+        const messages: Anthropic.Messages.MessageParam[] = [
+            {
+                role: 'user',
+                content: [
+                    { type: 'text', text: 'part 1' },
+                    { type: 'text', text: 'part 2' }
+                ]
+            }
+        ]
+
+        const generator = handler.createMessage(systemPrompt, messages)
+        await generator.next()
+
+        expect(mockCreate).toHaveBeenCalledWith(expect.objectContaining({
+            messages: [
+                { role: 'system', content: systemPrompt },
+                { role: 'user', content: 'part 1part 2' }
+            ]
+        }))
+    })
+
+    test('createMessage truncates messages when exceeding context window', async () => {
+        const handler = new DeepSeekHandler(mockOptions)
+        const longString = 'a'.repeat(1000) // ~300 tokens
+        const shortString = 'b'.repeat(100) // ~30 tokens
+        
+        const systemPrompt = 'test system prompt'
+        const messages: Anthropic.Messages.MessageParam[] = [
+            { role: 'user', content: longString }, // Old message
+            { role: 'assistant', content: 'short response' },
+            { role: 'user', content: shortString } // Recent message
+        ]
+
+        const mockStream = {
+            async *[Symbol.asyncIterator]() {
+                yield {
+                    choices: [{
+                        delta: {
+                            content: '(Note: Some earlier messages were truncated to fit within the model\'s context window)\n\n'
+                        }
+                    }]
+                }
+                yield {
+                    choices: [{
+                        delta: {
+                            content: 'test response'
+                        }
+                    }]
+                }
+            }
+        }
+
+        const mockCreate = jest.fn().mockResolvedValue(mockStream)
+        ;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
+            completions: { create: mockCreate }
+        } as any
+
+        const generator = handler.createMessage(systemPrompt, messages)
+        const chunks = []
+        for await (const chunk of generator) {
+            chunks.push(chunk)
+        }
+
+        // Should get two chunks: truncation notice and response
+        expect(chunks).toHaveLength(2)
+        expect(chunks[0]).toEqual({
+            type: 'text',
+            text: expect.stringContaining('truncated')
+        })
+        expect(chunks[1]).toEqual({
+            type: 'text',
+            text: 'test response'
+        })
+
+        // Verify API call includes system prompt and recent messages, but not old message
+        expect(mockCreate).toHaveBeenCalledWith(expect.objectContaining({
+            messages: expect.arrayContaining([
+                { role: 'system', content: systemPrompt },
+                { role: 'assistant', content: 'short response' },
+                { role: 'user', content: shortString }
+            ])
+        }))
+        
+        // Verify truncation notice was included
+        expect(chunks[0]).toEqual({
+            type: 'text',
+            text: expect.stringContaining('truncated')
+        })
+
+        // Verify the messages array contains the expected messages
+        const calledMessages = mockCreate.mock.calls[0][0].messages
+        expect(calledMessages).toHaveLength(4)
+        expect(calledMessages[0]).toEqual({ role: 'system', content: systemPrompt })
+        expect(calledMessages[1]).toEqual({ role: 'user', content: longString })
+        expect(calledMessages[2]).toEqual({ role: 'assistant', content: 'short response' })
+        expect(calledMessages[3]).toEqual({ role: 'user', content: shortString })
+    })
+
+    test('createMessage handles API errors', async () => {
+        const handler = new DeepSeekHandler(mockOptions)
+        const mockStream = {
+            async *[Symbol.asyncIterator]() {
+                throw new Error('API Error')
+            }
+        }
+
+        const mockCreate = jest.fn().mockResolvedValue(mockStream)
+        ;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
+            completions: { create: mockCreate }
+        } as any
+
+        const generator = handler.createMessage('test', [])
+        await expect(generator.next()).rejects.toThrow('API Error')
+    })
+})

+ 116 - 0
src/api/providers/deepseek.ts

@@ -0,0 +1,116 @@
+import { Anthropic } from "@anthropic-ai/sdk"
+import OpenAI from "openai"
+import { ApiHandlerOptions, ModelInfo, deepSeekModels, deepSeekDefaultModelId } from "../../shared/api"
+import { ApiHandler } from "../index"
+import { ApiStream } from "../transform/stream"
+
+export class DeepSeekHandler implements ApiHandler {
+	private options: ApiHandlerOptions
+	private client: OpenAI
+
+	constructor(options: ApiHandlerOptions) {
+		this.options = options
+		if (!options.deepSeekApiKey) {
+			throw new Error("DeepSeek API key is required. Please provide it in the settings.")
+		}
+		this.client = new OpenAI({
+			baseURL: this.options.deepSeekBaseUrl ?? "https://api.deepseek.com/v1",
+			apiKey: this.options.deepSeekApiKey,
+		})
+	}
+
+	async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
+		// Convert messages to simple format that DeepSeek expects
+		const formattedMessages = messages.map(msg => {
+			if (typeof msg.content === "string") {
+				return { role: msg.role, content: msg.content }
+			}
+			// For array content, concatenate text parts
+			return {
+				role: msg.role,
+				content: msg.content.reduce((acc, part) => {
+					if (part.type === "text") {
+						return acc + part.text
+					}
+					return acc
+				}, "")
+			}
+		})
+
+		const openAiMessages: OpenAI.Chat.ChatCompletionMessageParam[] = [
+			{ role: "system", content: systemPrompt },
+			...formattedMessages,
+		]
+		const modelInfo = deepSeekModels[this.options.deepSeekModelId as keyof typeof deepSeekModels] || deepSeekModels[deepSeekDefaultModelId]
+		
+		const contextWindow = modelInfo.contextWindow || 64_000
+		const getTokenCount = (content: string) => Math.ceil(content.length * 0.3)
+
+		// Always keep system prompt
+		const systemMsg = openAiMessages[0]
+		let availableTokens = contextWindow - getTokenCount(typeof systemMsg.content === 'string' ? systemMsg.content : '')
+		
+		// Start with most recent messages and work backwards
+		const userMessages = openAiMessages.slice(1).reverse()
+		const includedMessages = []
+		let truncated = false
+
+		for (const msg of userMessages) {
+			const content = typeof msg.content === 'string' ? msg.content : ''
+			const tokens = getTokenCount(content)
+			
+			if (tokens <= availableTokens) {
+				includedMessages.unshift(msg)
+				availableTokens -= tokens
+			} else {
+				truncated = true
+				break
+			}
+		}
+
+		if (truncated) {
+			yield {
+				type: 'text',
+				text: '(Note: Some earlier messages were truncated to fit within the model\'s context window)\n\n'
+			}
+		}
+
+		const requestOptions: OpenAI.Chat.ChatCompletionCreateParamsStreaming = {
+			model: this.options.deepSeekModelId ?? "deepseek-chat",
+			messages: [systemMsg, ...includedMessages],
+			temperature: 0,
+			stream: true,
+			max_tokens: modelInfo.maxTokens,
+		}
+
+		if (this.options.includeStreamOptions ?? true) {
+			requestOptions.stream_options = { include_usage: true }
+		}
+
+		const stream = await this.client.chat.completions.create(requestOptions)
+		for await (const chunk of stream) {
+			const delta = chunk.choices[0]?.delta
+			if (delta?.content) {
+				yield {
+					type: "text",
+					text: delta.content,
+				}
+			}
+			if (chunk.usage) {
+				yield {
+					type: "usage",
+					inputTokens: chunk.usage.prompt_tokens || 0,
+					outputTokens: chunk.usage.completion_tokens || 0,
+				}
+			}
+		}
+	}
+
+	getModel(): { id: string; info: ModelInfo } {
+		const modelId = this.options.deepSeekModelId ?? deepSeekDefaultModelId
+		return {
+			id: modelId,
+			info: deepSeekModels[modelId as keyof typeof deepSeekModels] || deepSeekModels[deepSeekDefaultModelId],
+		}
+	}
+}

+ 6 - 0
src/core/webview/ClineProvider.ts

@@ -40,6 +40,7 @@ type SecretKey =
 	| "openAiApiKey"
 	| "geminiApiKey"
 	| "openAiNativeApiKey"
+	| "deepSeekApiKey"
 type GlobalStateKey =
 	| "apiProvider"
 	| "apiModelId"
@@ -443,6 +444,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 							await this.updateGlobalState("anthropicBaseUrl", anthropicBaseUrl)
 							await this.storeSecret("geminiApiKey", geminiApiKey)
 							await this.storeSecret("openAiNativeApiKey", openAiNativeApiKey)
+							await this.storeSecret("deepSeekApiKey", message.apiConfiguration.deepSeekApiKey)
 							await this.updateGlobalState("azureApiVersion", azureApiVersion)
 							await this.updateGlobalState("openRouterModelId", openRouterModelId)
 							await this.updateGlobalState("openRouterModelInfo", openRouterModelInfo)
@@ -1121,6 +1123,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 			anthropicBaseUrl,
 			geminiApiKey,
 			openAiNativeApiKey,
+			deepSeekApiKey,
 			azureApiVersion,
 			openRouterModelId,
 			openRouterModelInfo,
@@ -1163,6 +1166,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 			this.getGlobalState("anthropicBaseUrl") as Promise<string | undefined>,
 			this.getSecret("geminiApiKey") as Promise<string | undefined>,
 			this.getSecret("openAiNativeApiKey") as Promise<string | undefined>,
+			this.getSecret("deepSeekApiKey") as Promise<string | undefined>,
 			this.getGlobalState("azureApiVersion") as Promise<string | undefined>,
 			this.getGlobalState("openRouterModelId") as Promise<string | undefined>,
 			this.getGlobalState("openRouterModelInfo") as Promise<ModelInfo | undefined>,
@@ -1222,6 +1226,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 				anthropicBaseUrl,
 				geminiApiKey,
 				openAiNativeApiKey,
+				deepSeekApiKey,
 				azureApiVersion,
 				openRouterModelId,
 				openRouterModelInfo,
@@ -1344,6 +1349,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 			"openAiApiKey",
 			"geminiApiKey",
 			"openAiNativeApiKey",
+			"deepSeekApiKey",
 		]
 		for (const key of secretKeys) {
 			await this.storeSecret(key, undefined)

+ 20 - 0
src/shared/api.ts

@@ -8,6 +8,7 @@ export type ApiProvider =
 	| "lmstudio"
 	| "gemini"
 	| "openai-native"
+	| "deepseek"
 
 export interface ApiHandlerOptions {
 	apiModelId?: string
@@ -38,6 +39,9 @@ export interface ApiHandlerOptions {
 	openRouterUseMiddleOutTransform?: boolean
 	includeStreamOptions?: boolean
 	setAzureApiVersion?: boolean
+	deepSeekBaseUrl?: string
+	deepSeekApiKey?: string
+	deepSeekModelId?: string
 }
 
 export type ApiConfiguration = ApiHandlerOptions & {
@@ -489,6 +493,22 @@ export const openAiNativeModels = {
 	},
 } as const satisfies Record<string, ModelInfo>
 
+// DeepSeek
+// https://platform.deepseek.com/docs/api
+export type DeepSeekModelId = keyof typeof deepSeekModels
+export const deepSeekDefaultModelId: DeepSeekModelId = "deepseek-chat"
+export const deepSeekModels = {
+	"deepseek-chat": {
+		maxTokens: 8192,
+		contextWindow: 64_000,
+		supportsImages: false,
+		supportsPromptCache: false,
+		inputPrice: 0.014,  // $0.014 per million tokens
+		outputPrice: 0.28,  // $0.28 per million tokens
+		description: `DeepSeek-V3 achieves a significant breakthrough in inference speed over previous models. It tops the leaderboard among open-source models and rivals the most advanced closed-source models globally.`,
+	},
+} as const satisfies Record<string, ModelInfo>
+
 // Azure OpenAI
 // https://learn.microsoft.com/en-us/azure/ai-services/openai/api-version-deprecation
 // https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#api-specs

+ 36 - 2
webview-ui/src/components/settings/ApiOptions.tsx

@@ -17,6 +17,8 @@ import {
 	azureOpenAiDefaultApiVersion,
 	bedrockDefaultModelId,
 	bedrockModels,
+	deepSeekDefaultModelId,
+	deepSeekModels,
 	geminiDefaultModelId,
 	geminiModels,
 	openAiModelInfoSaneDefaults,
@@ -130,10 +132,11 @@ const ApiOptions = ({ showModelOptions, apiErrorMessage, modelIdErrorMessage }:
 					<VSCodeOption value="openrouter">OpenRouter</VSCodeOption>
 					<VSCodeOption value="anthropic">Anthropic</VSCodeOption>
 					<VSCodeOption value="gemini">Google Gemini</VSCodeOption>
-					<VSCodeOption value="vertex">GCP Vertex AI</VSCodeOption>
-					<VSCodeOption value="bedrock">AWS Bedrock</VSCodeOption>
+					<VSCodeOption value="deepseek">DeepSeek</VSCodeOption>
 					<VSCodeOption value="openai-native">OpenAI</VSCodeOption>
 					<VSCodeOption value="openai">OpenAI Compatible</VSCodeOption>
+					<VSCodeOption value="vertex">GCP Vertex AI</VSCodeOption>
+					<VSCodeOption value="bedrock">AWS Bedrock</VSCodeOption>
 					<VSCodeOption value="lmstudio">LM Studio</VSCodeOption>
 					<VSCodeOption value="ollama">Ollama</VSCodeOption>
 				</VSCodeDropdown>
@@ -560,6 +563,34 @@ const ApiOptions = ({ showModelOptions, apiErrorMessage, modelIdErrorMessage }:
 				</div>
 			)}
 
+			{selectedProvider === "deepseek" && (
+				<div>
+					<VSCodeTextField
+						value={apiConfiguration?.deepSeekApiKey || ""}
+						style={{ width: "100%" }}
+						type="password"
+						onInput={handleInputChange("deepSeekApiKey")}
+						placeholder="Enter API Key...">
+						<span style={{ fontWeight: 500 }}>DeepSeek API Key</span>
+					</VSCodeTextField>
+					<p
+						style={{
+							fontSize: "12px",
+							marginTop: "5px",
+							color: "var(--vscode-descriptionForeground)",
+						}}>
+						This key is stored locally and only used to make API requests from this extension.
+						{!apiConfiguration?.deepSeekApiKey && (
+							<VSCodeLink
+								href="https://platform.deepseek.com/"
+								style={{ display: "inline", fontSize: "inherit" }}>
+								You can get a DeepSeek API key by signing up here.
+							</VSCodeLink>
+						)}
+					</p>
+				</div>
+			)}
+
 			{selectedProvider === "ollama" && (
 				<div>
 					<VSCodeTextField
@@ -652,6 +683,7 @@ const ApiOptions = ({ showModelOptions, apiErrorMessage, modelIdErrorMessage }:
 							{selectedProvider === "vertex" && createDropdown(vertexModels)}
 							{selectedProvider === "gemini" && createDropdown(geminiModels)}
 							{selectedProvider === "openai-native" && createDropdown(openAiNativeModels)}
+							{selectedProvider === "deepseek" && createDropdown(deepSeekModels)}
 						</div>
 
 						<ModelInfoView
@@ -836,6 +868,8 @@ export function normalizeApiConfiguration(apiConfiguration?: ApiConfiguration) {
 			return getProviderData(vertexModels, vertexDefaultModelId)
 		case "gemini":
 			return getProviderData(geminiModels, geminiDefaultModelId)
+		case "deepseek":
+			return getProviderData(deepSeekModels, deepSeekDefaultModelId)
 		case "openai-native":
 			return getProviderData(openAiNativeModels, openAiNativeDefaultModelId)
 		case "openrouter":