commit 6290f90fa5 by Matt Rubens, 1 year ago

src/api/providers/__tests__/deepseek.test.ts (+7 -1)

@@ -137,7 +137,13 @@ describe('DeepSeekHandler', () => {
         expect(mockCreate).toHaveBeenCalledWith(expect.objectContaining({
             messages: [
                 { role: 'system', content: systemPrompt },
-                { role: 'user', content: 'part 1part 2' }
+                {
+                    role: 'user',
+                    content: [
+                        { type: 'text', text: 'part 1' },
+                        { type: 'text', text: 'part 2' }
+                    ]
+                }
             ]
         }))
     })
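Note: the updated expectation matches the refactor below. DeepSeekHandler now inherits message handling from OpenAiHandler, which routes through convertToOpenAiMessages and keeps multi-part user content as an array of text parts instead of concatenating the strings. A minimal TypeScript sketch of that mapping for text-only blocks follows; treat it as an assumption, since convertToOpenAiMessages itself is not part of this commit.

import { Anthropic } from "@anthropic-ai/sdk"
import OpenAI from "openai"

// Hedged sketch: how Anthropic-style text blocks presumably end up as OpenAI
// content parts, matching the updated expectation above. This is NOT the
// actual convertToOpenAiMessages implementation (that file is not in this
// diff); it only illustrates the shape the test now asserts.
function toOpenAiUserMessage(
	msg: Anthropic.Messages.MessageParam,
): OpenAI.Chat.ChatCompletionUserMessageParam {
	if (typeof msg.content === "string") {
		return { role: "user", content: msg.content }
	}
	// Keep each text block as its own content part instead of concatenating,
	// so "part 1" and "part 2" remain separate { type: "text" } entries.
	const parts = msg.content
		.filter((block): block is Anthropic.Messages.TextBlockParam => block.type === "text")
		.map((block) => ({ type: "text" as const, text: block.text }))
	return { role: "user", content: parts }
}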

src/api/providers/deepseek.ts (+25 -95)

@@ -1,96 +1,26 @@
-import { Anthropic } from "@anthropic-ai/sdk"
-import OpenAI from "openai"
-import { ApiHandlerOptions, ModelInfo, deepSeekModels, deepSeekDefaultModelId } from "../../shared/api"
-import { ApiHandler } from "../index"
-import { ApiStream } from "../transform/stream"
-
-export class DeepSeekHandler implements ApiHandler {
-	private options: ApiHandlerOptions
-	private client: OpenAI
-
-	constructor(options: ApiHandlerOptions) {
-		this.options = options
-		if (!options.deepSeekApiKey) {
-			throw new Error("DeepSeek API key is required. Please provide it in the settings.")
-		}
-		this.client = new OpenAI({
-			baseURL: this.options.deepSeekBaseUrl ?? "https://api.deepseek.com/v1",
-			apiKey: this.options.deepSeekApiKey,
-		})
-	}
-
-	async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
-		const modelInfo = deepSeekModels[this.options.deepSeekModelId as keyof typeof deepSeekModels] || deepSeekModels[deepSeekDefaultModelId]
-
-		// Format all messages
-		const messagesToInclude: OpenAI.Chat.ChatCompletionMessageParam[] = [
-			{ role: 'system' as const, content: systemPrompt }
-		]
-
-		// Add the rest of the messages
-		for (const msg of messages) {
-			let messageContent = ""
-			if (typeof msg.content === "string") {
-				messageContent = msg.content
-			} else if (Array.isArray(msg.content)) {
-				messageContent = msg.content.reduce((acc, part) => {
-					if (part.type === "text") {
-						return acc + part.text
-					}
-					return acc
-				}, "")
-			}
-			
-			messagesToInclude.push({
-				role: msg.role === 'user' ? 'user' as const : 'assistant' as const,
-				content: messageContent
-			})
-		}
-
-		const requestOptions: OpenAI.Chat.ChatCompletionCreateParamsStreaming = {
-			model: this.options.deepSeekModelId ?? "deepseek-chat",
-			messages: messagesToInclude,
-			temperature: 0,
-			stream: true,
-			max_tokens: modelInfo.maxTokens,
-		}
-
-		if (this.options.includeStreamOptions ?? true) {
-			requestOptions.stream_options = { include_usage: true }
-		}
-
-		let totalInputTokens = 0;
-		let totalOutputTokens = 0;
-
-		try {
-			const stream = await this.client.chat.completions.create(requestOptions)
-			for await (const chunk of stream) {
-				const delta = chunk.choices[0]?.delta
-				if (delta?.content) {
-					yield {
-						type: "text",
-						text: delta.content,
-					}
-				}
-				if (chunk.usage) {
-					yield {
-						type: "usage",
-						inputTokens: chunk.usage.prompt_tokens || 0,
-						outputTokens: chunk.usage.completion_tokens || 0,
-					}
-				}
-			}
-		} catch (error) {
-			console.error("DeepSeek API Error:", error)
-			throw error
-		}
-	}
-
-	getModel(): { id: string; info: ModelInfo } {
-		const modelId = this.options.deepSeekModelId ?? deepSeekDefaultModelId
-		return {
-			id: modelId,
-			info: deepSeekModels[modelId as keyof typeof deepSeekModels] || deepSeekModels[deepSeekDefaultModelId],
-		}
-	}
+import { OpenAiHandler } from "./openai"
+import { ApiHandlerOptions, ModelInfo } from "../../shared/api"
+import { deepSeekModels, deepSeekDefaultModelId } from "../../shared/api"
+
+export class DeepSeekHandler extends OpenAiHandler {
+    constructor(options: ApiHandlerOptions) {
+        if (!options.deepSeekApiKey) {
+            throw new Error("DeepSeek API key is required. Please provide it in the settings.")
+        }
+        super({
+            ...options,
+            openAiApiKey: options.deepSeekApiKey,
+            openAiModelId: options.deepSeekModelId ?? deepSeekDefaultModelId,
+            openAiBaseUrl: options.deepSeekBaseUrl ?? "https://api.deepseek.com/v1",
+            includeMaxTokens: true
+        })
+    }
+
+    override getModel(): { id: string; info: ModelInfo } {
+        const modelId = this.options.deepSeekModelId ?? deepSeekDefaultModelId
+        return {
+            id: modelId,
+            info: deepSeekModels[modelId as keyof typeof deepSeekModels] || deepSeekModels[deepSeekDefaultModelId]
+        }
+    }
 }
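After the refactor, DeepSeekHandler is a thin subclass that remaps the DeepSeek-specific options onto the shared OpenAI-compatible ones and opts in to includeMaxTokens. A rough usage sketch follows; the option field names are from this diff, but the concrete values and the consuming loop are illustrative assumptions based on the removed implementation, not something this commit shows.

// Hedged usage sketch: constructing the slimmed-down handler and draining its
// stream. Values are placeholders; the for-await pattern assumes ApiStream is
// an async iterable, as the removed generator implementation suggests.
import { DeepSeekHandler } from "./deepseek"

const handler = new DeepSeekHandler({
	deepSeekApiKey: "sk-...", // constructor still throws if this is missing
	deepSeekModelId: "deepseek-chat", // falls back to deepSeekDefaultModelId
})

console.log(handler.getModel().id) // "deepseek-chat"

// createMessage is now inherited from OpenAiHandler; it yields "text" and
// "usage" chunks just as the removed implementation did.
for await (const chunk of handler.createMessage("You are a helpful assistant.", [
	{ role: "user", content: [{ type: "text", text: "Hello" }] },
])) {
	if (chunk.type === "text") {
		process.stdout.write(chunk.text)
	}
}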

src/api/providers/openai.ts (+5 -1)

@@ -11,7 +11,7 @@ import { convertToOpenAiMessages } from "../transform/openai-format"
 import { ApiStream } from "../transform/stream"
 
 export class OpenAiHandler implements ApiHandler {
-	private options: ApiHandlerOptions
+	protected options: ApiHandlerOptions
 	private client: OpenAI
 
 	constructor(options: ApiHandlerOptions) {
@@ -38,12 +38,16 @@ export class OpenAiHandler implements ApiHandler {
 			{ role: "system", content: systemPrompt },
 			...convertToOpenAiMessages(messages),
 		]
+		const modelInfo = this.getModel().info
 		const requestOptions: OpenAI.Chat.ChatCompletionCreateParams = {
 			model: this.options.openAiModelId ?? "",
 			messages: openAiMessages,
 			temperature: 0,
 			stream: true,
 		}
+		if (this.options.includeMaxTokens) {
+			requestOptions.max_tokens = modelInfo.maxTokens
+		}
 
 		if (this.options.includeStreamOptions ?? true) {
 			requestOptions.stream_options = { include_usage: true }
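With includeMaxTokens set (the DeepSeek subclass passes it in its super() call), the shared handler now forwards the model's maxTokens as max_tokens; providers that leave the flag unset keep the previous behavior. A hedged sketch of the resulting streaming params, with placeholder values:

// Hedged sketch of the params OpenAiHandler builds when a provider opts in to
// includeMaxTokens. The max_tokens value is a placeholder; the real one comes
// from getModel().info.maxTokens, defined in shared/api.ts and not shown here.
import OpenAI from "openai"

const requestOptions: OpenAI.Chat.ChatCompletionCreateParamsStreaming = {
	model: "deepseek-chat",
	messages: [
		{ role: "system", content: "You are a helpful assistant." },
		{ role: "user", content: [{ type: "text", text: "Hello" }] },
	],
	temperature: 0,
	stream: true,
	// Added only when options.includeMaxTokens is truthy.
	max_tokens: 8192,
	// Added unless includeStreamOptions is explicitly false.
	stream_options: { include_usage: true },
}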

src/shared/api.ts (+1 -0)

@@ -42,6 +42,7 @@ export interface ApiHandlerOptions {
 	deepSeekBaseUrl?: string
 	deepSeekApiKey?: string
 	deepSeekModelId?: string
+	includeMaxTokens?: boolean
 }
 
 export type ApiConfiguration = ApiHandlerOptions & {
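Because the new flag lives on the shared ApiHandlerOptions rather than on the DeepSeek fields, any OpenAI-compatible configuration could set it. A small hedged sketch (import path as used by the provider files; values are placeholders):

import { ApiHandlerOptions } from "../../shared/api"

// Illustrative only: field names are from this diff, values are placeholders.
const options: ApiHandlerOptions = {
	openAiBaseUrl: "https://api.example.com/v1",
	openAiApiKey: "...",
	openAiModelId: "my-model",
	includeMaxTokens: true, // forwards getModel().info.maxTokens as max_tokens
}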