Răsfoiți Sursa

Streaming checkbox for OpenAI-compatible providers

Matt Rubens 1 an în urmă
părinte
comite
2cdfff02c0

+ 52 - 30
src/api/providers/openai.ts

@@ -32,43 +32,65 @@ export class OpenAiHandler implements ApiHandler {
 		}
 	}
 
-	// Include stream_options for OpenAI Compatible providers if the checkbox is checked
 	async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
-		const openAiMessages: OpenAI.Chat.ChatCompletionMessageParam[] = [
-			{ role: "system", content: systemPrompt },
-			...convertToOpenAiMessages(messages),
-		]
 		const modelInfo = this.getModel().info
-		const requestOptions: OpenAI.Chat.ChatCompletionCreateParams = {
-			model: this.options.openAiModelId ?? "",
-			messages: openAiMessages,
-			temperature: 0,
-			stream: true,
-		}
-		if (this.options.includeMaxTokens) {
-			requestOptions.max_tokens = modelInfo.maxTokens
-		}
+		const modelId = this.options.openAiModelId ?? ""
 
-		if (this.options.includeStreamOptions ?? true) {
-			requestOptions.stream_options = { include_usage: true }
-		}
+		if (this.options.openAiStreamingEnabled ?? true) {
+			const systemMessage: OpenAI.Chat.ChatCompletionSystemMessageParam = {
+				role: "system",
+				content: systemPrompt
+			}
+			const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = {
+				model: modelId,
+				temperature: 0,
+				messages: [systemMessage, ...convertToOpenAiMessages(messages)],
+				stream: true as const,
+				stream_options: { include_usage: true },
+			}
+			if (this.options.includeMaxTokens) {
+				requestOptions.max_tokens = modelInfo.maxTokens
+			}
+
+			const stream = await this.client.chat.completions.create(requestOptions)
 
-		const stream = await this.client.chat.completions.create(requestOptions)
-		for await (const chunk of stream) {
-			const delta = chunk.choices[0]?.delta
-			if (delta?.content) {
-				yield {
-					type: "text",
-					text: delta.content,
+			for await (const chunk of stream) {
+				const delta = chunk.choices[0]?.delta
+				if (delta?.content) {
+					yield {
+						type: "text",
+						text: delta.content,
+					}
 				}
-			}
-			if (chunk.usage) {
-				yield {
-					type: "usage",
-					inputTokens: chunk.usage.prompt_tokens || 0,
-					outputTokens: chunk.usage.completion_tokens || 0,
+				if (chunk.usage) {
+					yield {
+						type: "usage",
+						inputTokens: chunk.usage.prompt_tokens || 0,
+						outputTokens: chunk.usage.completion_tokens || 0,
+					}
 				}
 			}
+		} else {
+			// Some models (e.g. o1) don't support streaming, a non-1 temperature, or a system prompt
+			const systemMessage: OpenAI.Chat.ChatCompletionUserMessageParam = {
+				role: "user",
+				content: systemPrompt
+			}
+			const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = {
+				model: modelId,
+				messages: [systemMessage, ...convertToOpenAiMessages(messages)],
+			}
+			const response = await this.client.chat.completions.create(requestOptions)
+			
+			yield {
+				type: "text",
+				text: response.choices[0]?.message.content || "",
+			}
+			yield {
+				type: "usage",
+				inputTokens: response.usage?.prompt_tokens || 0,
+				outputTokens: response.usage?.completion_tokens || 0,
+			}
 		}
 	}
 

+ 6 - 6
src/core/webview/ClineProvider.ts

@@ -66,7 +66,7 @@ type GlobalStateKey =
 	| "lmStudioBaseUrl"
 	| "anthropicBaseUrl"
 	| "azureApiVersion"
-	| "includeStreamOptions"
+	| "openAiStreamingEnabled"
 	| "openRouterModelId"
 	| "openRouterModelInfo"
 	| "openRouterUseMiddleOutTransform"
@@ -447,7 +447,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 								geminiApiKey,
 								openAiNativeApiKey,
 								azureApiVersion,
-								includeStreamOptions,
+								openAiStreamingEnabled,
 								openRouterModelId,
 								openRouterModelInfo,
 								openRouterUseMiddleOutTransform,
@@ -478,7 +478,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 							await this.storeSecret("openAiNativeApiKey", openAiNativeApiKey)
 							await this.storeSecret("deepSeekApiKey", message.apiConfiguration.deepSeekApiKey)
 							await this.updateGlobalState("azureApiVersion", azureApiVersion)
-							await this.updateGlobalState("includeStreamOptions", includeStreamOptions)
+							await this.updateGlobalState("openAiStreamingEnabled", openAiStreamingEnabled)
 							await this.updateGlobalState("openRouterModelId", openRouterModelId)
 							await this.updateGlobalState("openRouterModelInfo", openRouterModelInfo)
 							await this.updateGlobalState("openRouterUseMiddleOutTransform", openRouterUseMiddleOutTransform)
@@ -1295,7 +1295,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 			openAiNativeApiKey,
 			deepSeekApiKey,
 			azureApiVersion,
-			includeStreamOptions,
+			openAiStreamingEnabled,
 			openRouterModelId,
 			openRouterModelInfo,
 			openRouterUseMiddleOutTransform,
@@ -1345,7 +1345,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 			this.getSecret("openAiNativeApiKey") as Promise<string | undefined>,
 			this.getSecret("deepSeekApiKey") as Promise<string | undefined>,
 			this.getGlobalState("azureApiVersion") as Promise<string | undefined>,
-			this.getGlobalState("includeStreamOptions") as Promise<boolean | undefined>,
+			this.getGlobalState("openAiStreamingEnabled") as Promise<boolean | undefined>,
 			this.getGlobalState("openRouterModelId") as Promise<string | undefined>,
 			this.getGlobalState("openRouterModelInfo") as Promise<ModelInfo | undefined>,
 			this.getGlobalState("openRouterUseMiddleOutTransform") as Promise<boolean | undefined>,
@@ -1412,7 +1412,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 				openAiNativeApiKey,
 				deepSeekApiKey,
 				azureApiVersion,
-				includeStreamOptions,
+				openAiStreamingEnabled,
 				openRouterModelId,
 				openRouterModelInfo,
 				openRouterUseMiddleOutTransform,

+ 1 - 1
src/shared/api.ts

@@ -41,7 +41,7 @@ export interface ApiHandlerOptions {
 	openAiNativeApiKey?: string
 	azureApiVersion?: string
 	openRouterUseMiddleOutTransform?: boolean
-	includeStreamOptions?: boolean
+	openAiStreamingEnabled?: boolean
 	setAzureApiVersion?: boolean
 	deepSeekBaseUrl?: string
 	deepSeekApiKey?: string

+ 3 - 8
webview-ui/src/components/settings/ApiOptions.tsx

@@ -477,21 +477,16 @@ const ApiOptions = ({ showModelOptions, apiErrorMessage, modelIdErrorMessage }:
 					<OpenAiModelPicker />
 					<div style={{ display: 'flex', alignItems: 'center' }}>
 						<VSCodeCheckbox
-							checked={apiConfiguration?.includeStreamOptions ?? true}
+							checked={apiConfiguration?.openAiStreamingEnabled ?? true}
 							onChange={(e: any) => {
 								const isChecked = e.target.checked
 								setApiConfiguration({
 									...apiConfiguration,
-									includeStreamOptions: isChecked
+									openAiStreamingEnabled: isChecked
 								})
 							}}>
-							Include stream options
+							Enable streaming
 						</VSCodeCheckbox>
-						<span
-							className="codicon codicon-info"
-							title="Stream options are for { include_usage: true }. Some providers may not support this option."
-							style={{ marginLeft: '5px', cursor: 'help' }}
-						></span>
 					</div>
 					<VSCodeCheckbox
 						checked={azureApiVersionSelected}