فهرست منبع

Implement streaming for all providers

Saoud Rizwan 1 سال پیش
والد
کامیت
06ccaf6f67

+ 0 - 1
src/api/index.ts

@@ -12,7 +12,6 @@ import { ApiStream } from "./transform/stream"
 
 export interface ApiHandler {
 	createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream
-
 	getModel(): { id: string; info: ModelInfo }
 }
 

+ 2 - 4
src/api/providers/anthropic.ts

@@ -117,8 +117,8 @@ export class AnthropicHandler implements ApiHandler {
 						type: "usage",
 						inputTokens: usage.input_tokens || 0,
 						outputTokens: usage.output_tokens || 0,
-						cacheWriteTokens: usage.cache_creation_input_tokens || 0,
-						cacheReadTokens: usage.cache_read_input_tokens || 0,
+						cacheWriteTokens: usage.cache_creation_input_tokens || undefined,
+						cacheReadTokens: usage.cache_read_input_tokens || undefined,
 					}
 					break
 				case "message_delta":
@@ -128,8 +128,6 @@ export class AnthropicHandler implements ApiHandler {
 						type: "usage",
 						inputTokens: 0,
 						outputTokens: chunk.usage.output_tokens || 0,
-						cacheWriteTokens: 0,
-						cacheReadTokens: 0,
 					}
 					break
 				case "message_stop":

+ 52 - 11
src/api/providers/bedrock.ts

@@ -1,7 +1,8 @@
 import AnthropicBedrock from "@anthropic-ai/bedrock-sdk"
 import { Anthropic } from "@anthropic-ai/sdk"
-import { ApiHandler, ApiHandlerMessageResponse } from "../"
+import { ApiHandler } from "../"
 import { ApiHandlerOptions, bedrockDefaultModelId, BedrockModelId, bedrockModels, ModelInfo } from "../../shared/api"
+import { ApiStream } from "../transform/stream"
 
 // https://docs.anthropic.com/en/api/claude-on-amazon-bedrock
 export class AwsBedrockHandler implements ApiHandler {
@@ -23,21 +24,61 @@ export class AwsBedrockHandler implements ApiHandler {
 		})
 	}
 
-	async createMessage(
-		systemPrompt: string,
-		messages: Anthropic.Messages.MessageParam[],
-		tools: Anthropic.Messages.Tool[]
-	): Promise<ApiHandlerMessageResponse> {
-		const message = await this.client.messages.create({
+	async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
+		const stream = await this.client.messages.create({
 			model: this.getModel().id,
 			max_tokens: this.getModel().info.maxTokens,
-			temperature: 0.2,
+			temperature: 0,
 			system: systemPrompt,
 			messages,
-			tools,
-			tool_choice: { type: "auto" },
+			stream: true,
 		})
-		return { message }
+		for await (const chunk of stream) {
+			switch (chunk.type) {
+				case "message_start":
+					const usage = chunk.message.usage
+					yield {
+						type: "usage",
+						inputTokens: usage.input_tokens || 0,
+						outputTokens: usage.output_tokens || 0,
+					}
+					break
+				case "message_delta":
+					yield {
+						type: "usage",
+						inputTokens: 0,
+						outputTokens: chunk.usage.output_tokens || 0,
+					}
+					break
+
+				case "content_block_start":
+					switch (chunk.content_block.type) {
+						case "text":
+							if (chunk.index > 0) {
+								yield {
+									type: "text",
+									text: "\n",
+								}
+							}
+							yield {
+								type: "text",
+								text: chunk.content_block.text,
+							}
+							break
+					}
+					break
+				case "content_block_delta":
+					switch (chunk.delta.type) {
+						case "text_delta":
+							yield {
+								type: "text",
+								text: chunk.delta.text,
+							}
+							break
+					}
+					break
+			}
+		}
 	}
 
 	getModel(): { id: BedrockModelId; info: ModelInfo } {

+ 20 - 22
src/api/providers/gemini.ts

@@ -1,12 +1,9 @@
 import { Anthropic } from "@anthropic-ai/sdk"
-import { FunctionCallingMode, GoogleGenerativeAI } from "@google/generative-ai"
-import { ApiHandler, ApiHandlerMessageResponse } from "../"
+import { GoogleGenerativeAI } from "@google/generative-ai"
+import { ApiHandler } from "../"
 import { ApiHandlerOptions, geminiDefaultModelId, GeminiModelId, geminiModels, ModelInfo } from "../../shared/api"
-import {
-	convertAnthropicMessageToGemini,
-	convertAnthropicToolToGemini,
-	convertGeminiResponseToAnthropic,
-} from "../transform/gemini-format"
+import { convertAnthropicMessageToGemini } from "../transform/gemini-format"
+import { ApiStream } from "../transform/stream"
 
 export class GeminiHandler implements ApiHandler {
 	private options: ApiHandlerOptions
@@ -20,31 +17,32 @@ export class GeminiHandler implements ApiHandler {
 		this.client = new GoogleGenerativeAI(options.geminiApiKey)
 	}
 
-	async createMessage(
-		systemPrompt: string,
-		messages: Anthropic.Messages.MessageParam[],
-		tools: Anthropic.Messages.Tool[]
-	): Promise<ApiHandlerMessageResponse> {
+	async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
 		const model = this.client.getGenerativeModel({
 			model: this.getModel().id,
 			systemInstruction: systemPrompt,
-			tools: [{ functionDeclarations: tools.map(convertAnthropicToolToGemini) }],
-			toolConfig: {
-				functionCallingConfig: {
-					mode: FunctionCallingMode.AUTO,
-				},
-			},
 		})
-		const result = await model.generateContent({
+		const result = await model.generateContentStream({
 			contents: messages.map(convertAnthropicMessageToGemini),
 			generationConfig: {
 				maxOutputTokens: this.getModel().info.maxTokens,
-				temperature: 0.2,
+				temperature: 0,
 			},
 		})
-		const message = convertGeminiResponseToAnthropic(result.response)
 
-		return { message }
+		for await (const chunk of result.stream) {
+			yield {
+				type: "text",
+				text: chunk.text(),
+			}
+		}
+
+		const response = await result.response
+		yield {
+			type: "usage",
+			inputTokens: response.usageMetadata?.promptTokenCount ?? 0,
+			outputTokens: response.usageMetadata?.candidatesTokenCount ?? 0,
+		}
 	}
 
 	getModel(): { id: GeminiModelId; info: ModelInfo } {

+ 17 - 26
src/api/providers/ollama.ts

@@ -1,8 +1,9 @@
 import { Anthropic } from "@anthropic-ai/sdk"
 import OpenAI from "openai"
-import { ApiHandler, ApiHandlerMessageResponse } from "../"
+import { ApiHandler } from "../"
 import { ApiHandlerOptions, ModelInfo, openAiModelInfoSaneDefaults } from "../../shared/api"
-import { convertToAnthropicMessage, convertToOpenAiMessages } from "../transform/openai-format"
+import { convertToOpenAiMessages } from "../transform/openai-format"
+import { ApiStream } from "../transform/stream"
 
 export class OllamaHandler implements ApiHandler {
 	private options: ApiHandlerOptions
@@ -16,37 +17,27 @@ export class OllamaHandler implements ApiHandler {
 		})
 	}
 
-	async createMessage(
-		systemPrompt: string,
-		messages: Anthropic.Messages.MessageParam[],
-		tools: Anthropic.Messages.Tool[]
-	): Promise<ApiHandlerMessageResponse> {
+	async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
 		const openAiMessages: OpenAI.Chat.ChatCompletionMessageParam[] = [
 			{ role: "system", content: systemPrompt },
 			...convertToOpenAiMessages(messages),
 		]
-		const openAiTools: OpenAI.Chat.ChatCompletionTool[] = tools.map((tool) => ({
-			type: "function",
-			function: {
-				name: tool.name,
-				description: tool.description,
-				parameters: tool.input_schema,
-			},
-		}))
-		const createParams: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = {
+
+		const stream = await this.client.chat.completions.create({
 			model: this.options.ollamaModelId ?? "",
 			messages: openAiMessages,
-			temperature: 0.2,
-			tools: openAiTools,
-			tool_choice: "auto",
-		}
-		const completion = await this.client.chat.completions.create(createParams)
-		const errorMessage = (completion as any).error?.message
-		if (errorMessage) {
-			throw new Error(errorMessage)
+			temperature: 0,
+			stream: true,
+		})
+		for await (const chunk of stream) {
+			const delta = chunk.choices[0]?.delta
+			if (delta?.content) {
+				yield {
+					type: "text",
+					text: delta.content,
+				}
+			}
 		}
-		const anthropicMessage = convertToAnthropicMessage(completion)
-		return { message: anthropicMessage }
 	}
 
 	getModel(): { id: string; info: ModelInfo } {

+ 27 - 53
src/api/providers/openai-native.ts

@@ -1,6 +1,6 @@
 import { Anthropic } from "@anthropic-ai/sdk"
 import OpenAI from "openai"
-import { ApiHandler, ApiHandlerMessageResponse } from "../"
+import { ApiHandler } from "../"
 import {
 	ApiHandlerOptions,
 	ModelInfo,
@@ -8,8 +8,8 @@ import {
 	OpenAiNativeModelId,
 	openAiNativeModels,
 } from "../../shared/api"
-import { convertToAnthropicMessage, convertToOpenAiMessages } from "../transform/openai-format"
-import { convertO1ResponseToAnthropicMessage, convertToO1Messages } from "../transform/o1-format"
+import { convertToOpenAiMessages } from "../transform/openai-format"
+import { ApiStream } from "../transform/stream"
 
 export class OpenAiNativeHandler implements ApiHandler {
 	private options: ApiHandlerOptions
@@ -22,65 +22,39 @@ export class OpenAiNativeHandler implements ApiHandler {
 		})
 	}
 
-	async createMessage(
-		systemPrompt: string,
-		messages: Anthropic.Messages.MessageParam[],
-		tools: Anthropic.Messages.Tool[]
-	): Promise<ApiHandlerMessageResponse> {
+	async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
 		const openAiMessages: OpenAI.Chat.ChatCompletionMessageParam[] = [
 			{ role: "system", content: systemPrompt },
 			...convertToOpenAiMessages(messages),
 		]
-		const openAiTools: OpenAI.Chat.ChatCompletionTool[] = tools.map((tool) => ({
-			type: "function",
-			function: {
-				name: tool.name,
-				description: tool.description,
-				parameters: tool.input_schema,
-			},
-		}))
 
-		let createParams: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming
+		const stream = await this.client.chat.completions.create({
+			model: this.getModel().id,
+			max_completion_tokens: this.getModel().info.maxTokens,
+			temperature: 0,
+			messages: openAiMessages,
+			stream: true,
+			stream_options: { include_usage: true },
+		})
 
-		switch (this.getModel().id) {
-			case "o1-preview":
-			case "o1-mini":
-				createParams = {
-					model: this.getModel().id,
-					max_completion_tokens: this.getModel().info.maxTokens,
-					messages: convertToO1Messages(convertToOpenAiMessages(messages), systemPrompt),
-				}
-				break
-			default:
-				createParams = {
-					model: this.getModel().id,
-					max_completion_tokens: this.getModel().info.maxTokens,
-					temperature: 0.2,
-					messages: openAiMessages,
-					tools: openAiTools,
-					tool_choice: "auto",
+		for await (const chunk of stream) {
+			const delta = chunk.choices[0]?.delta
+			if (delta?.content) {
+				yield {
+					type: "text",
+					text: delta.content,
 				}
-				break
-		}
+			}
 
-		const completion = await this.client.chat.completions.create(createParams)
-		const errorMessage = (completion as any).error?.message
-		if (errorMessage) {
-			throw new Error(errorMessage)
-		}
-
-		let anthropicMessage: Anthropic.Messages.Message
-		switch (this.getModel().id) {
-			case "o1-preview":
-			case "o1-mini":
-				anthropicMessage = convertO1ResponseToAnthropicMessage(completion)
-				break
-			default:
-				anthropicMessage = convertToAnthropicMessage(completion)
-				break
+			// contains a null value except for the last chunk which contains the token usage statistics for the entire request
+			if (chunk.usage) {
+				yield {
+					type: "usage",
+					inputTokens: chunk.usage.prompt_tokens || 0,
+					outputTokens: chunk.usage.completion_tokens || 0,
+				}
+			}
 		}
-
-		return { message: anthropicMessage }
 	}
 
 	getModel(): { id: OpenAiNativeModelId; info: ModelInfo } {

+ 24 - 26
src/api/providers/openai.ts

@@ -1,13 +1,14 @@
 import { Anthropic } from "@anthropic-ai/sdk"
 import OpenAI, { AzureOpenAI } from "openai"
-import { ApiHandler, ApiHandlerMessageResponse } from "../index"
 import {
 	ApiHandlerOptions,
 	azureOpenAiDefaultApiVersion,
 	ModelInfo,
 	openAiModelInfoSaneDefaults,
 } from "../../shared/api"
-import { convertToAnthropicMessage, convertToOpenAiMessages } from "../transform/openai-format"
+import { ApiHandler } from "../index"
+import { convertToOpenAiMessages } from "../transform/openai-format"
+import { ApiStream } from "../transform/stream"
 
 export class OpenAiHandler implements ApiHandler {
 	private options: ApiHandlerOptions
@@ -30,37 +31,34 @@ export class OpenAiHandler implements ApiHandler {
 		}
 	}
 
-	async createMessage(
-		systemPrompt: string,
-		messages: Anthropic.Messages.MessageParam[],
-		tools: Anthropic.Messages.Tool[]
-	): Promise<ApiHandlerMessageResponse> {
+	async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
 		const openAiMessages: OpenAI.Chat.ChatCompletionMessageParam[] = [
 			{ role: "system", content: systemPrompt },
 			...convertToOpenAiMessages(messages),
 		]
-		const openAiTools: OpenAI.Chat.ChatCompletionTool[] = tools.map((tool) => ({
-			type: "function",
-			function: {
-				name: tool.name,
-				description: tool.description,
-				parameters: tool.input_schema,
-			},
-		}))
-		const createParams: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = {
+		const stream = await this.client.chat.completions.create({
 			model: this.options.openAiModelId ?? "",
 			messages: openAiMessages,
-			temperature: 0.2,
-			tools: openAiTools,
-			tool_choice: "auto",
+			temperature: 0,
+			stream: true,
+			stream_options: { include_usage: true },
+		})
+		for await (const chunk of stream) {
+			const delta = chunk.choices[0]?.delta
+			if (delta?.content) {
+				yield {
+					type: "text",
+					text: delta.content,
+				}
+			}
+			if (chunk.usage) {
+				yield {
+					type: "usage",
+					inputTokens: chunk.usage.prompt_tokens || 0,
+					outputTokens: chunk.usage.completion_tokens || 0,
+				}
+			}
 		}
-		const completion = await this.client.chat.completions.create(createParams)
-		const errorMessage = (completion as any).error?.message
-		if (errorMessage) {
-			throw new Error(errorMessage)
-		}
-		const anthropicMessage = convertToAnthropicMessage(completion)
-		return { message: anthropicMessage }
 	}
 
 	getModel(): { id: string; info: ModelInfo } {

+ 2 - 2
src/api/providers/openrouter.ts

@@ -124,8 +124,8 @@ export class OpenRouterHandler implements ApiHandler {
 				type: "usage",
 				inputTokens: generation?.native_tokens_prompt || 0,
 				outputTokens: generation?.native_tokens_completion || 0,
-				cacheWriteTokens: 0,
-				cacheReadTokens: 0,
+				// cacheWriteTokens: 0,
+				// cacheReadTokens: 0,
 				totalCost: generation?.total_cost || 0,
 			}
 		} catch (error) {

+ 53 - 12
src/api/providers/vertex.ts

@@ -1,7 +1,8 @@
-import { AnthropicVertex } from "@anthropic-ai/vertex-sdk"
 import { Anthropic } from "@anthropic-ai/sdk"
-import { ApiHandler, ApiHandlerMessageResponse } from "../"
+import { AnthropicVertex } from "@anthropic-ai/vertex-sdk"
+import { ApiHandler } from "../"
 import { ApiHandlerOptions, ModelInfo, vertexDefaultModelId, VertexModelId, vertexModels } from "../../shared/api"
+import { ApiStream } from "../transform/stream"
 
 // https://docs.anthropic.com/en/api/claude-on-vertex-ai
 export class VertexHandler implements ApiHandler {
@@ -17,21 +18,61 @@ export class VertexHandler implements ApiHandler {
 		})
 	}
 
-	async createMessage(
-		systemPrompt: string,
-		messages: Anthropic.Messages.MessageParam[],
-		tools: Anthropic.Messages.Tool[]
-	): Promise<ApiHandlerMessageResponse> {
-		const message = await this.client.messages.create({
+	async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
+		const stream = await this.client.messages.create({
 			model: this.getModel().id,
 			max_tokens: this.getModel().info.maxTokens,
-			temperature: 0.2,
+			temperature: 0,
 			system: systemPrompt,
 			messages,
-			tools,
-			tool_choice: { type: "auto" },
+			stream: true,
 		})
-		return { message }
+		for await (const chunk of stream) {
+			switch (chunk.type) {
+				case "message_start":
+					const usage = chunk.message.usage
+					yield {
+						type: "usage",
+						inputTokens: usage.input_tokens || 0,
+						outputTokens: usage.output_tokens || 0,
+					}
+					break
+				case "message_delta":
+					yield {
+						type: "usage",
+						inputTokens: 0,
+						outputTokens: chunk.usage.output_tokens || 0,
+					}
+					break
+
+				case "content_block_start":
+					switch (chunk.content_block.type) {
+						case "text":
+							if (chunk.index > 0) {
+								yield {
+									type: "text",
+									text: "\n",
+								}
+							}
+							yield {
+								type: "text",
+								text: chunk.content_block.text,
+							}
+							break
+					}
+					break
+				case "content_block_delta":
+					switch (chunk.delta.type) {
+						case "text_delta":
+							yield {
+								type: "text",
+								text: chunk.delta.text,
+							}
+							break
+					}
+					break
+			}
+		}
 	}
 
 	getModel(): { id: VertexModelId; info: ModelInfo } {

+ 2 - 2
src/api/transform/stream.ts

@@ -10,7 +10,7 @@ export interface ApiStreamUsageChunk {
 	type: "usage"
 	inputTokens: number
 	outputTokens: number
-	cacheWriteTokens: number
-	cacheReadTokens: number
+	cacheWriteTokens?: number
+	cacheReadTokens?: number
 	totalCost?: number // openrouter
 }

+ 2 - 2
src/core/ClaudeDev.ts

@@ -2397,8 +2397,8 @@ ${this.customInstructions.trim()}
 					case "usage":
 						inputTokens += chunk.inputTokens
 						outputTokens += chunk.outputTokens
-						cacheWriteTokens += chunk.cacheWriteTokens
-						cacheReadTokens += chunk.cacheReadTokens
+						cacheWriteTokens += chunk.cacheWriteTokens ?? 0
+						cacheReadTokens += chunk.cacheReadTokens ?? 0
 						totalCost = chunk.totalCost
 						break
 					case "text":