Browse Source

Add openai compatible provider

Saoud Rizwan 1 year ago
parent
commit
c209198b23

+ 5 - 2
src/api/index.ts

@@ -1,9 +1,10 @@
 import { Anthropic } from "@anthropic-ai/sdk"
 import { Anthropic } from "@anthropic-ai/sdk"
-import { ApiConfiguration, ApiModelId, ModelInfo } from "../shared/api"
+import { ApiConfiguration, ModelInfo } from "../shared/api"
 import { AnthropicHandler } from "./anthropic"
 import { AnthropicHandler } from "./anthropic"
 import { AwsBedrockHandler } from "./bedrock"
 import { AwsBedrockHandler } from "./bedrock"
 import { OpenRouterHandler } from "./openrouter"
 import { OpenRouterHandler } from "./openrouter"
 import { VertexHandler } from "./vertex"
 import { VertexHandler } from "./vertex"
+import { OpenAiHandler } from "./openai"
 
 
 export interface ApiHandlerMessageResponse {
 export interface ApiHandlerMessageResponse {
 	message: Anthropic.Messages.Message
 	message: Anthropic.Messages.Message
@@ -26,7 +27,7 @@ export interface ApiHandler {
 		>
 		>
 	): any
 	): any
 
 
-	getModel(): { id: ApiModelId; info: ModelInfo }
+	getModel(): { id: string; info: ModelInfo }
 }
 }
 
 
 export function buildApiHandler(configuration: ApiConfiguration): ApiHandler {
 export function buildApiHandler(configuration: ApiConfiguration): ApiHandler {
@@ -40,6 +41,8 @@ export function buildApiHandler(configuration: ApiConfiguration): ApiHandler {
 			return new AwsBedrockHandler(options)
 			return new AwsBedrockHandler(options)
 		case "vertex":
 		case "vertex":
 			return new VertexHandler(options)
 			return new VertexHandler(options)
+		case "openai":
+			return new OpenAiHandler(options)
 		default:
 		default:
 			return new AnthropicHandler(options)
 			return new AnthropicHandler(options)
 	}
 	}

+ 74 - 0
src/api/openai.ts

@@ -0,0 +1,74 @@
+import { Anthropic } from "@anthropic-ai/sdk"
+import OpenAI from "openai"
+import { ApiHandler, ApiHandlerMessageResponse, withoutImageData } from "."
+import { ApiHandlerOptions, ModelInfo, openAiModelInfoSaneDefaults } from "../shared/api"
+import { convertToAnthropicMessage, convertToOpenAiMessages } from "../utils/openai-format"
+
+export class OpenAiHandler implements ApiHandler {
+	private options: ApiHandlerOptions
+	private client: OpenAI
+
+	constructor(options: ApiHandlerOptions) {
+		this.options = options
+		this.client = new OpenAI({
+			baseURL: this.options.openAiBaseUrl,
+			apiKey: this.options.openAiApiKey,
+		})
+	}
+
+	async createMessage(
+		systemPrompt: string,
+		messages: Anthropic.Messages.MessageParam[],
+		tools: Anthropic.Messages.Tool[]
+	): Promise<ApiHandlerMessageResponse> {
+		const openAiMessages: OpenAI.Chat.ChatCompletionMessageParam[] = [
+			{ role: "system", content: systemPrompt },
+			...convertToOpenAiMessages(messages),
+		]
+		const openAiTools: OpenAI.Chat.ChatCompletionTool[] = tools.map((tool) => ({
+			type: "function",
+			function: {
+				name: tool.name,
+				description: tool.description,
+				parameters: tool.input_schema,
+			},
+		}))
+		const createParams: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = {
+			model: this.options.openAiModelId ?? "",
+			messages: openAiMessages,
+			tools: openAiTools,
+			tool_choice: "auto",
+		}
+		const completion = await this.client.chat.completions.create(createParams)
+		const errorMessage = (completion as any).error?.message
+		if (errorMessage) {
+			throw new Error(errorMessage)
+		}
+		const anthropicMessage = convertToAnthropicMessage(completion)
+		return { message: anthropicMessage }
+	}
+
+	createUserReadableRequest(
+		userContent: Array<
+			| Anthropic.TextBlockParam
+			| Anthropic.ImageBlockParam
+			| Anthropic.ToolUseBlockParam
+			| Anthropic.ToolResultBlockParam
+		>
+	): any {
+		return {
+			model: this.options.openAiModelId ?? "",
+			system: "(see SYSTEM_PROMPT in src/ClaudeDev.ts)",
+			messages: [{ conversation_history: "..." }, { role: "user", content: withoutImageData(userContent) }],
+			tools: "(see tools in src/ClaudeDev.ts)",
+			tool_choice: "auto",
+		}
+	}
+
+	getModel(): { id: string; info: ModelInfo } {
+		return {
+			id: this.options.openAiModelId ?? "",
+			info: openAiModelInfoSaneDefaults,
+		}
+	}
+}

+ 2 - 52
src/api/openrouter.ts

@@ -8,7 +8,7 @@ import {
 	OpenRouterModelId,
 	OpenRouterModelId,
 	openRouterModels,
 	openRouterModels,
 } from "../shared/api"
 } from "../shared/api"
-import { convertToOpenAiMessages } from "../utils/openai-format"
+import { convertToAnthropicMessage, convertToOpenAiMessages } from "../utils/openai-format"
 
 
 export class OpenRouterHandler implements ApiHandler {
 export class OpenRouterHandler implements ApiHandler {
 	private options: ApiHandlerOptions
 	private options: ApiHandlerOptions
@@ -68,57 +68,7 @@ export class OpenRouterHandler implements ApiHandler {
 			throw new Error(errorMessage)
 			throw new Error(errorMessage)
 		}
 		}
 
 
-		// Convert OpenAI response to Anthropic format
-		const openAiMessage = completion.choices[0].message
-		const anthropicMessage: Anthropic.Messages.Message = {
-			id: completion.id,
-			type: "message",
-			role: openAiMessage.role, // always "assistant"
-			content: [
-				{
-					type: "text",
-					text: openAiMessage.content || "",
-				},
-			],
-			model: completion.model,
-			stop_reason: (() => {
-				switch (completion.choices[0].finish_reason) {
-					case "stop":
-						return "end_turn"
-					case "length":
-						return "max_tokens"
-					case "tool_calls":
-						return "tool_use"
-					case "content_filter": // Anthropic doesn't have an exact equivalent
-					default:
-						return null
-				}
-			})(),
-			stop_sequence: null, // which custom stop_sequence was generated, if any (not applicable if you don't use stop_sequence)
-			usage: {
-				input_tokens: completion.usage?.prompt_tokens || 0,
-				output_tokens: completion.usage?.completion_tokens || 0,
-			},
-		}
-
-		if (openAiMessage.tool_calls && openAiMessage.tool_calls.length > 0) {
-			anthropicMessage.content.push(
-				...openAiMessage.tool_calls.map((toolCall): Anthropic.ToolUseBlock => {
-					let parsedInput = {}
-					try {
-						parsedInput = JSON.parse(toolCall.function.arguments || "{}")
-					} catch (error) {
-						console.error("Failed to parse tool arguments:", error)
-					}
-					return {
-						type: "tool_use",
-						id: toolCall.id,
-						name: toolCall.function.name,
-						input: parsedInput,
-					}
-				})
-			)
-		}
+		const anthropicMessage = convertToAnthropicMessage(completion)
 
 
 		return { message: anthropicMessage }
 		return { message: anthropicMessage }
 	}
 	}

+ 21 - 3
src/providers/ClaudeDevProvider.ts

@@ -1,7 +1,7 @@
 import { Anthropic } from "@anthropic-ai/sdk"
 import { Anthropic } from "@anthropic-ai/sdk"
 import * as vscode from "vscode"
 import * as vscode from "vscode"
 import { ClaudeDev } from "../ClaudeDev"
 import { ClaudeDev } from "../ClaudeDev"
-import { ApiModelId, ApiProvider } from "../shared/api"
+import { ApiProvider } from "../shared/api"
 import { ExtensionMessage } from "../shared/ExtensionMessage"
 import { ExtensionMessage } from "../shared/ExtensionMessage"
 import { WebviewMessage } from "../shared/WebviewMessage"
 import { WebviewMessage } from "../shared/WebviewMessage"
 import { downloadTask, findLast, getNonce, getUri, selectImages } from "../utils"
 import { downloadTask, findLast, getNonce, getUri, selectImages } from "../utils"
@@ -16,7 +16,7 @@ https://github.com/microsoft/vscode-webview-ui-toolkit-samples/blob/main/default
 https://github.com/KumarVariable/vscode-extension-sidebar-html/blob/master/src/customSidebarViewProvider.ts
 https://github.com/KumarVariable/vscode-extension-sidebar-html/blob/master/src/customSidebarViewProvider.ts
 */
 */
 
 
-type SecretKey = "apiKey" | "openRouterApiKey" | "awsAccessKey" | "awsSecretKey" | "awsSessionToken"
+type SecretKey = "apiKey" | "openRouterApiKey" | "awsAccessKey" | "awsSecretKey" | "awsSessionToken" | "openAiApiKey"
 type GlobalStateKey =
 type GlobalStateKey =
 	| "apiProvider"
 	| "apiProvider"
 	| "apiModelId"
 	| "apiModelId"
@@ -27,6 +27,8 @@ type GlobalStateKey =
 	| "customInstructions"
 	| "customInstructions"
 	| "alwaysAllowReadOnly"
 	| "alwaysAllowReadOnly"
 	| "taskHistory"
 	| "taskHistory"
+	| "openAiBaseUrl"
+	| "openAiModelId"
 
 
 export class ClaudeDevProvider implements vscode.WebviewViewProvider {
 export class ClaudeDevProvider implements vscode.WebviewViewProvider {
 	public static readonly sideBarId = "claude-dev.SidebarProvider" // used in package.json as the view's id. This value cannot be changed due to how vscode caches views based on their id, and updating the id would break existing instances of the extension.
 	public static readonly sideBarId = "claude-dev.SidebarProvider" // used in package.json as the view's id. This value cannot be changed due to how vscode caches views based on their id, and updating the id would break existing instances of the extension.
@@ -314,6 +316,9 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider {
 								awsRegion,
 								awsRegion,
 								vertexProjectId,
 								vertexProjectId,
 								vertexRegion,
 								vertexRegion,
+								openAiBaseUrl,
+								openAiApiKey,
+								openAiModelId,
 							} = message.apiConfiguration
 							} = message.apiConfiguration
 							await this.updateGlobalState("apiProvider", apiProvider)
 							await this.updateGlobalState("apiProvider", apiProvider)
 							await this.updateGlobalState("apiModelId", apiModelId)
 							await this.updateGlobalState("apiModelId", apiModelId)
@@ -325,6 +330,9 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider {
 							await this.updateGlobalState("awsRegion", awsRegion)
 							await this.updateGlobalState("awsRegion", awsRegion)
 							await this.updateGlobalState("vertexProjectId", vertexProjectId)
 							await this.updateGlobalState("vertexProjectId", vertexProjectId)
 							await this.updateGlobalState("vertexRegion", vertexRegion)
 							await this.updateGlobalState("vertexRegion", vertexRegion)
+							await this.updateGlobalState("openAiBaseUrl", openAiBaseUrl)
+							await this.storeSecret("openAiApiKey", openAiApiKey)
+							await this.updateGlobalState("openAiModelId", openAiModelId)
 							this.claudeDev?.updateApi(message.apiConfiguration)
 							this.claudeDev?.updateApi(message.apiConfiguration)
 						}
 						}
 						await this.postStateToWebview()
 						await this.postStateToWebview()
@@ -615,13 +623,16 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider {
 			awsRegion,
 			awsRegion,
 			vertexProjectId,
 			vertexProjectId,
 			vertexRegion,
 			vertexRegion,
+			openAiBaseUrl,
+			openAiApiKey,
+			openAiModelId,
 			lastShownAnnouncementId,
 			lastShownAnnouncementId,
 			customInstructions,
 			customInstructions,
 			alwaysAllowReadOnly,
 			alwaysAllowReadOnly,
 			taskHistory,
 			taskHistory,
 		] = await Promise.all([
 		] = await Promise.all([
 			this.getGlobalState("apiProvider") as Promise<ApiProvider | undefined>,
 			this.getGlobalState("apiProvider") as Promise<ApiProvider | undefined>,
-			this.getGlobalState("apiModelId") as Promise<ApiModelId | undefined>,
+			this.getGlobalState("apiModelId") as Promise<string | undefined>,
 			this.getSecret("apiKey") as Promise<string | undefined>,
 			this.getSecret("apiKey") as Promise<string | undefined>,
 			this.getSecret("openRouterApiKey") as Promise<string | undefined>,
 			this.getSecret("openRouterApiKey") as Promise<string | undefined>,
 			this.getSecret("awsAccessKey") as Promise<string | undefined>,
 			this.getSecret("awsAccessKey") as Promise<string | undefined>,
@@ -630,6 +641,9 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider {
 			this.getGlobalState("awsRegion") as Promise<string | undefined>,
 			this.getGlobalState("awsRegion") as Promise<string | undefined>,
 			this.getGlobalState("vertexProjectId") as Promise<string | undefined>,
 			this.getGlobalState("vertexProjectId") as Promise<string | undefined>,
 			this.getGlobalState("vertexRegion") as Promise<string | undefined>,
 			this.getGlobalState("vertexRegion") as Promise<string | undefined>,
+			this.getGlobalState("openAiBaseUrl") as Promise<string | undefined>,
+			this.getSecret("openAiApiKey") as Promise<string | undefined>,
+			this.getGlobalState("openAiModelId") as Promise<string | undefined>,
 			this.getGlobalState("lastShownAnnouncementId") as Promise<string | undefined>,
 			this.getGlobalState("lastShownAnnouncementId") as Promise<string | undefined>,
 			this.getGlobalState("customInstructions") as Promise<string | undefined>,
 			this.getGlobalState("customInstructions") as Promise<string | undefined>,
 			this.getGlobalState("alwaysAllowReadOnly") as Promise<boolean | undefined>,
 			this.getGlobalState("alwaysAllowReadOnly") as Promise<boolean | undefined>,
@@ -662,6 +676,9 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider {
 				awsRegion,
 				awsRegion,
 				vertexProjectId,
 				vertexProjectId,
 				vertexRegion,
 				vertexRegion,
+				openAiBaseUrl,
+				openAiApiKey,
+				openAiModelId,
 			},
 			},
 			lastShownAnnouncementId,
 			lastShownAnnouncementId,
 			customInstructions,
 			customInstructions,
@@ -739,6 +756,7 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider {
 			"awsAccessKey",
 			"awsAccessKey",
 			"awsSecretKey",
 			"awsSecretKey",
 			"awsSessionToken",
 			"awsSessionToken",
+			"openAiApiKey",
 		]
 		]
 		for (const key of secretKeys) {
 		for (const key of secretKeys) {
 			await this.storeSecret(key, undefined)
 			await this.storeSecret(key, undefined)

+ 14 - 4
src/shared/api.ts

@@ -1,7 +1,7 @@
-export type ApiProvider = "anthropic" | "openrouter" | "bedrock" | "vertex"
+export type ApiProvider = "anthropic" | "openrouter" | "bedrock" | "vertex" | "openai"
 
 
 export interface ApiHandlerOptions {
 export interface ApiHandlerOptions {
-	apiModelId?: ApiModelId
+	apiModelId?: string
 	apiKey?: string // anthropic
 	apiKey?: string // anthropic
 	openRouterApiKey?: string
 	openRouterApiKey?: string
 	awsAccessKey?: string
 	awsAccessKey?: string
@@ -10,6 +10,9 @@ export interface ApiHandlerOptions {
 	awsRegion?: string
 	awsRegion?: string
 	vertexProjectId?: string
 	vertexProjectId?: string
 	vertexRegion?: string
 	vertexRegion?: string
+	openAiBaseUrl?: string
+	openAiApiKey?: string
+	openAiModelId?: string
 }
 }
 
 
 export type ApiConfiguration = ApiHandlerOptions & {
 export type ApiConfiguration = ApiHandlerOptions & {
@@ -29,8 +32,6 @@ export interface ModelInfo {
 	cacheReadsPrice?: number
 	cacheReadsPrice?: number
 }
 }
 
 
-export type ApiModelId = AnthropicModelId | OpenRouterModelId | BedrockModelId | VertexModelId
-
 // Anthropic
 // Anthropic
 // https://docs.anthropic.com/en/docs/about-claude/models
 // https://docs.anthropic.com/en/docs/about-claude/models
 export type AnthropicModelId = keyof typeof anthropicModels
 export type AnthropicModelId = keyof typeof anthropicModels
@@ -292,3 +293,12 @@ export const vertexModels = {
 		outputPrice: 1.25,
 		outputPrice: 1.25,
 	},
 	},
 } as const satisfies Record<string, ModelInfo>
 } as const satisfies Record<string, ModelInfo>
+
+export const openAiModelInfoSaneDefaults: ModelInfo = {
+	maxTokens: -1,
+	contextWindow: 128_000,
+	supportsImages: true,
+	supportsPromptCache: false,
+	inputPrice: 0,
+	outputPrice: 0,
+}

+ 57 - 0
src/utils/openai-format.ts

@@ -142,3 +142,60 @@ export function convertToOpenAiMessages(
 
 
 	return openAiMessages
 	return openAiMessages
 }
 }
+
+// Convert OpenAI response to Anthropic format
+export function convertToAnthropicMessage(
+	completion: OpenAI.Chat.Completions.ChatCompletion
+): Anthropic.Messages.Message {
+	const openAiMessage = completion.choices[0].message
+	const anthropicMessage: Anthropic.Messages.Message = {
+		id: completion.id,
+		type: "message",
+		role: openAiMessage.role, // always "assistant"
+		content: [
+			{
+				type: "text",
+				text: openAiMessage.content || "",
+			},
+		],
+		model: completion.model,
+		stop_reason: (() => {
+			switch (completion.choices[0].finish_reason) {
+				case "stop":
+					return "end_turn"
+				case "length":
+					return "max_tokens"
+				case "tool_calls":
+					return "tool_use"
+				case "content_filter": // Anthropic doesn't have an exact equivalent
+				default:
+					return null
+			}
+		})(),
+		stop_sequence: null, // which custom stop_sequence was generated, if any (not applicable if you don't use stop_sequence)
+		usage: {
+			input_tokens: completion.usage?.prompt_tokens || 0,
+			output_tokens: completion.usage?.completion_tokens || 0,
+		},
+	}
+
+	if (openAiMessage.tool_calls && openAiMessage.tool_calls.length > 0) {
+		anthropicMessage.content.push(
+			...openAiMessage.tool_calls.map((toolCall): Anthropic.ToolUseBlock => {
+				let parsedInput = {}
+				try {
+					parsedInput = JSON.parse(toolCall.function.arguments || "{}")
+				} catch (error) {
+					console.error("Failed to parse tool arguments:", error)
+				}
+				return {
+					type: "tool_use",
+					id: toolCall.id,
+					name: toolCall.function.name,
+					input: parsedInput,
+				}
+			})
+		)
+	}
+	return anthropicMessage
+}

+ 57 - 5
webview-ui/src/components/ApiOptions.tsx

@@ -2,12 +2,12 @@ import { VSCodeDropdown, VSCodeLink, VSCodeOption, VSCodeTextField } from "@vsco
 import React, { useMemo } from "react"
 import React, { useMemo } from "react"
 import {
 import {
 	ApiConfiguration,
 	ApiConfiguration,
-	ApiModelId,
 	ModelInfo,
 	ModelInfo,
 	anthropicDefaultModelId,
 	anthropicDefaultModelId,
 	anthropicModels,
 	anthropicModels,
 	bedrockDefaultModelId,
 	bedrockDefaultModelId,
 	bedrockModels,
 	bedrockModels,
+	openAiModelInfoSaneDefaults,
 	openRouterDefaultModelId,
 	openRouterDefaultModelId,
 	openRouterModels,
 	openRouterModels,
 	vertexDefaultModelId,
 	vertexDefaultModelId,
@@ -69,11 +69,16 @@ const ApiOptions: React.FC<ApiOptionsProps> = ({ showModelOptions, apiErrorMessa
 				<label htmlFor="api-provider">
 				<label htmlFor="api-provider">
 					<span style={{ fontWeight: 500 }}>API Provider</span>
 					<span style={{ fontWeight: 500 }}>API Provider</span>
 				</label>
 				</label>
-				<VSCodeDropdown id="api-provider" value={selectedProvider} onChange={handleInputChange("apiProvider")}>
+				<VSCodeDropdown
+					id="api-provider"
+					value={selectedProvider}
+					onChange={handleInputChange("apiProvider")}
+					style={{ minWidth: 125 }}>
 					<VSCodeOption value="anthropic">Anthropic</VSCodeOption>
 					<VSCodeOption value="anthropic">Anthropic</VSCodeOption>
 					<VSCodeOption value="openrouter">OpenRouter</VSCodeOption>
 					<VSCodeOption value="openrouter">OpenRouter</VSCodeOption>
 					<VSCodeOption value="bedrock">AWS Bedrock</VSCodeOption>
 					<VSCodeOption value="bedrock">AWS Bedrock</VSCodeOption>
 					<VSCodeOption value="vertex">GCP Vertex AI</VSCodeOption>
 					<VSCodeOption value="vertex">GCP Vertex AI</VSCodeOption>
+					<VSCodeOption value="openai">OpenAI Compatible</VSCodeOption>
 				</VSCodeDropdown>
 				</VSCodeDropdown>
 			</div>
 			</div>
 
 
@@ -256,6 +261,47 @@ const ApiOptions: React.FC<ApiOptionsProps> = ({ showModelOptions, apiErrorMessa
 				</div>
 				</div>
 			)}
 			)}
 
 
+			{selectedProvider === "openai" && (
+				<div>
+					<VSCodeTextField
+						value={apiConfiguration?.openAiBaseUrl || ""}
+						style={{ width: "100%" }}
+						type="url"
+						onInput={handleInputChange("openAiBaseUrl")}
+						placeholder={"e.g. http://localhost:11434"}>
+						<span style={{ fontWeight: 500 }}>Base URL</span>
+					</VSCodeTextField>
+					<VSCodeTextField
+						value={apiConfiguration?.openAiApiKey || ""}
+						style={{ width: "100%" }}
+						type="password"
+						onInput={handleInputChange("openAiApiKey")}
+						placeholder="e.g. ollama">
+						<span style={{ fontWeight: 500 }}>API Key</span>
+					</VSCodeTextField>
+					<VSCodeTextField
+						value={apiConfiguration?.openAiModelId || ""}
+						style={{ width: "100%" }}
+						onInput={handleInputChange("openAiModelId")}
+						placeholder={"e.g. llama3.1"}>
+						<span style={{ fontWeight: 500 }}>Model ID</span>
+					</VSCodeTextField>
+					<p
+						style={{
+							fontSize: "12px",
+							marginTop: "5px",
+							color: "var(--vscode-descriptionForeground)",
+						}}>
+						You can use any OpenAI compatible API with models that support tool use.{" "}
+						<span style={{ color: "var(--vscode-errorForeground)" }}>
+							(<span style={{ fontWeight: 500 }}>Note:</span> Claude Dev uses complex prompts, so results
+							may vary depending on the quality of the model you choose. Less capable models may not work
+							as expected.)
+						</span>
+					</p>
+				</div>
+			)}
+
 			{apiErrorMessage && (
 			{apiErrorMessage && (
 				<p
 				<p
 					style={{
 					style={{
@@ -267,7 +313,7 @@ const ApiOptions: React.FC<ApiOptionsProps> = ({ showModelOptions, apiErrorMessa
 				</p>
 				</p>
 			)}
 			)}
 
 
-			{showModelOptions && (
+			{selectedProvider !== "openai" && showModelOptions && (
 				<>
 				<>
 					<div className="dropdown-container">
 					<div className="dropdown-container">
 						<label htmlFor="model-id">
 						<label htmlFor="model-id">
@@ -365,8 +411,8 @@ export function normalizeApiConfiguration(apiConfiguration?: ApiConfiguration) {
 	const provider = apiConfiguration?.apiProvider || "anthropic"
 	const provider = apiConfiguration?.apiProvider || "anthropic"
 	const modelId = apiConfiguration?.apiModelId
 	const modelId = apiConfiguration?.apiModelId
 
 
-	const getProviderData = (models: Record<string, ModelInfo>, defaultId: ApiModelId) => {
-		let selectedModelId: ApiModelId
+	const getProviderData = (models: Record<string, ModelInfo>, defaultId: string) => {
+		let selectedModelId: string
 		let selectedModelInfo: ModelInfo
 		let selectedModelInfo: ModelInfo
 		if (modelId && modelId in models) {
 		if (modelId && modelId in models) {
 			selectedModelId = modelId
 			selectedModelId = modelId
@@ -386,6 +432,12 @@ export function normalizeApiConfiguration(apiConfiguration?: ApiConfiguration) {
 			return getProviderData(bedrockModels, bedrockDefaultModelId)
 			return getProviderData(bedrockModels, bedrockDefaultModelId)
 		case "vertex":
 		case "vertex":
 			return getProviderData(vertexModels, vertexDefaultModelId)
 			return getProviderData(vertexModels, vertexDefaultModelId)
+		case "openai":
+			return {
+				selectedProvider: provider,
+				selectedModelId: apiConfiguration?.openAiModelId ?? "",
+				selectedModelInfo: openAiModelInfoSaneDefaults,
+			}
 		default:
 		default:
 			return getProviderData(anthropicModels, anthropicDefaultModelId)
 			return getProviderData(anthropicModels, anthropicDefaultModelId)
 	}
 	}

+ 0 - 3
webview-ui/src/components/ChatView.tsx

@@ -497,9 +497,6 @@ const ChatView = ({
 					cacheReads={apiMetrics.totalCacheReads}
 					cacheReads={apiMetrics.totalCacheReads}
 					totalCost={apiMetrics.totalCost}
 					totalCost={apiMetrics.totalCost}
 					onClose={handleTaskCloseButtonClick}
 					onClose={handleTaskCloseButtonClick}
-					isHidden={isHidden}
-					vscodeUriScheme={uriScheme}
-					apiProvider={apiConfiguration?.apiProvider}
 				/>
 				/>
 			) : (
 			) : (
 				<>
 				<>

+ 6 - 2
webview-ui/src/components/HistoryPreview.tsx

@@ -108,17 +108,21 @@ const HistoryPreview = ({ showHistoryView }: HistoryPreviewProps) => {
 									<span>
 									<span>
 										Tokens: ↑{item.tokensIn?.toLocaleString()} ↓{item.tokensOut?.toLocaleString()}
 										Tokens: ↑{item.tokensIn?.toLocaleString()} ↓{item.tokensOut?.toLocaleString()}
 									</span>
 									</span>
-									{" • "}
 									{item.cacheWrites && item.cacheReads && (
 									{item.cacheWrites && item.cacheReads && (
 										<>
 										<>
+											{" • "}
 											<span>
 											<span>
 												Cache: +{item.cacheWrites?.toLocaleString()} →{" "}
 												Cache: +{item.cacheWrites?.toLocaleString()} →{" "}
 												{item.cacheReads?.toLocaleString()}
 												{item.cacheReads?.toLocaleString()}
 											</span>
 											</span>
+										</>
+									)}
+									{!!item.totalCost && (
+										<>
 											{" • "}
 											{" • "}
+											<span>API Cost: ${item.totalCost?.toFixed(4)}</span>
 										</>
 										</>
 									)}
 									)}
-									<span>API Cost: ${item.totalCost?.toFixed(4)}</span>
 								</div>
 								</div>
 							</div>
 							</div>
 						</div>
 						</div>

+ 78 - 65
webview-ui/src/components/HistoryView.tsx

@@ -63,6 +63,17 @@ const HistoryView = ({ onDone }: HistoryViewProps) => {
 		)
 		)
 	}
 	}
 
 
+	const ExportButton = ({ itemId }: { itemId: string }) => (
+		<VSCodeButton
+			appearance="icon"
+			onClick={(e) => {
+				e.stopPropagation()
+				handleExportMd(itemId)
+			}}>
+			<div style={{ fontSize: "11px", fontWeight: 500, opacity: 1 }}>EXPORT .MD</div>
+		</VSCodeButton>
+	)
+
 	return (
 	return (
 		<>
 		<>
 			<style>
 			<style>
@@ -216,52 +227,61 @@ const HistoryView = ({ onDone }: HistoryViewProps) => {
 										<div
 										<div
 											style={{
 											style={{
 												display: "flex",
 												display: "flex",
+												justifyContent: "space-between",
 												alignItems: "center",
 												alignItems: "center",
-												gap: "4px",
-												flexWrap: "wrap",
 											}}>
 											}}>
-											<span
-												style={{
-													fontWeight: 500,
-													color: "var(--vscode-descriptionForeground)",
-												}}>
-												Tokens:
-											</span>
-											<span
+											<div
 												style={{
 												style={{
 													display: "flex",
 													display: "flex",
 													alignItems: "center",
 													alignItems: "center",
-													gap: "3px",
-													color: "var(--vscode-descriptionForeground)",
+													gap: "4px",
+													flexWrap: "wrap",
 												}}>
 												}}>
-												<i
-													className="codicon codicon-arrow-up"
+												<span
 													style={{
 													style={{
-														fontSize: "12px",
-														fontWeight: "bold",
-														marginBottom: "-2px",
-													}}
-												/>
-												{item.tokensIn?.toLocaleString()}
-											</span>
-											<span
-												style={{
-													display: "flex",
-													alignItems: "center",
-													gap: "3px",
-													color: "var(--vscode-descriptionForeground)",
-												}}>
-												<i
-													className="codicon codicon-arrow-down"
+														fontWeight: 500,
+														color: "var(--vscode-descriptionForeground)",
+													}}>
+													Tokens:
+												</span>
+												<span
 													style={{
 													style={{
-														fontSize: "12px",
-														fontWeight: "bold",
-														marginBottom: "-2px",
-													}}
-												/>
-												{item.tokensOut?.toLocaleString()}
-											</span>
+														display: "flex",
+														alignItems: "center",
+														gap: "3px",
+														color: "var(--vscode-descriptionForeground)",
+													}}>
+													<i
+														className="codicon codicon-arrow-up"
+														style={{
+															fontSize: "12px",
+															fontWeight: "bold",
+															marginBottom: "-2px",
+														}}
+													/>
+													{item.tokensIn?.toLocaleString()}
+												</span>
+												<span
+													style={{
+														display: "flex",
+														alignItems: "center",
+														gap: "3px",
+														color: "var(--vscode-descriptionForeground)",
+													}}>
+													<i
+														className="codicon codicon-arrow-down"
+														style={{
+															fontSize: "12px",
+															fontWeight: "bold",
+															marginBottom: "-2px",
+														}}
+													/>
+													{item.tokensOut?.toLocaleString()}
+												</span>
+											</div>
+											{!item.totalCost && <ExportButton itemId={item.id} />}
 										</div>
 										</div>
+
 										{item.cacheWrites && item.cacheReads && (
 										{item.cacheWrites && item.cacheReads && (
 											<div
 											<div
 												style={{
 												style={{
@@ -313,36 +333,29 @@ const HistoryView = ({ onDone }: HistoryViewProps) => {
 												</span>
 												</span>
 											</div>
 											</div>
 										)}
 										)}
-										<div
-											style={{
-												display: "flex",
-												justifyContent: "space-between",
-												alignItems: "center",
-												marginTop: -2,
-											}}>
-											<div style={{ display: "flex", alignItems: "center", gap: "4px" }}>
-												<span
-													style={{
-														fontWeight: 500,
-														color: "var(--vscode-descriptionForeground)",
-													}}>
-													API Cost:
-												</span>
-												<span style={{ color: "var(--vscode-descriptionForeground)" }}>
-													${item.totalCost?.toFixed(4)}
-												</span>
-											</div>
-											<VSCodeButton
-												appearance="icon"
-												onClick={(e) => {
-													e.stopPropagation()
-													handleExportMd(item.id)
+										{!!item.totalCost && (
+											<div
+												style={{
+													display: "flex",
+													justifyContent: "space-between",
+													alignItems: "center",
+													marginTop: -2,
 												}}>
 												}}>
-												<div style={{ fontSize: "11px", fontWeight: 500, opacity: 1 }}>
-													EXPORT .MD
+												<div style={{ display: "flex", alignItems: "center", gap: "4px" }}>
+													<span
+														style={{
+															fontWeight: 500,
+															color: "var(--vscode-descriptionForeground)",
+														}}>
+														API Cost:
+													</span>
+													<span style={{ color: "var(--vscode-descriptionForeground)" }}>
+														${item.totalCost?.toFixed(4)}
+													</span>
 												</div>
 												</div>
-											</VSCodeButton>
-										</div>
+												<ExportButton itemId={item.id} />
+											</div>
+										)}
 									</div>
 									</div>
 								</div>
 								</div>
 							</div>
 							</div>

+ 1 - 6
webview-ui/src/components/SettingsView.tsx

@@ -1,9 +1,4 @@
-import {
-	VSCodeButton,
-	VSCodeCheckbox,
-	VSCodeLink,
-	VSCodeTextArea
-} from "@vscode/webview-ui-toolkit/react"
+import { VSCodeButton, VSCodeCheckbox, VSCodeLink, VSCodeTextArea } from "@vscode/webview-ui-toolkit/react"
 import { useEffect, useState } from "react"
 import { useEffect, useState } from "react"
 import { useExtensionState } from "../context/ExtensionStateContext"
 import { useExtensionState } from "../context/ExtensionStateContext"
 import { validateApiConfiguration } from "../utils/validate"
 import { validateApiConfiguration } from "../utils/validate"

+ 51 - 41
webview-ui/src/components/TaskHeader.tsx

@@ -1,8 +1,8 @@
 import { VSCodeButton } from "@vscode/webview-ui-toolkit/react"
 import { VSCodeButton } from "@vscode/webview-ui-toolkit/react"
 import React, { useEffect, useRef, useState } from "react"
 import React, { useEffect, useRef, useState } from "react"
 import { useWindowSize } from "react-use"
 import { useWindowSize } from "react-use"
-import { ApiProvider } from "../../../src/shared/api"
 import { ClaudeMessage } from "../../../src/shared/ExtensionMessage"
 import { ClaudeMessage } from "../../../src/shared/ExtensionMessage"
+import { useExtensionState } from "../context/ExtensionStateContext"
 import { vscode } from "../utils/vscode"
 import { vscode } from "../utils/vscode"
 import Thumbnails from "./Thumbnails"
 import Thumbnails from "./Thumbnails"
 
 
@@ -15,9 +15,6 @@ interface TaskHeaderProps {
 	cacheReads?: number
 	cacheReads?: number
 	totalCost: number
 	totalCost: number
 	onClose: () => void
 	onClose: () => void
-	isHidden: boolean
-	vscodeUriScheme?: string
-	apiProvider?: ApiProvider
 }
 }
 
 
 const TaskHeader: React.FC<TaskHeaderProps> = ({
 const TaskHeader: React.FC<TaskHeaderProps> = ({
@@ -29,10 +26,8 @@ const TaskHeader: React.FC<TaskHeaderProps> = ({
 	cacheReads,
 	cacheReads,
 	totalCost,
 	totalCost,
 	onClose,
 	onClose,
-	isHidden,
-	vscodeUriScheme,
-	apiProvider,
 }) => {
 }) => {
+	const { apiConfiguration } = useExtensionState()
 	const [isExpanded, setIsExpanded] = useState(false)
 	const [isExpanded, setIsExpanded] = useState(false)
 	const [showSeeMore, setShowSeeMore] = useState(false)
 	const [showSeeMore, setShowSeeMore] = useState(false)
 	const textContainerRef = useRef<HTMLDivElement>(null)
 	const textContainerRef = useRef<HTMLDivElement>(null)
@@ -100,6 +95,18 @@ const TaskHeader: React.FC<TaskHeaderProps> = ({
 		vscode.postMessage({ type: "exportCurrentTask" })
 		vscode.postMessage({ type: "exportCurrentTask" })
 	}
 	}
 
 
+	const ExportButton = () => (
+		<VSCodeButton
+			appearance="icon"
+			onClick={handleDownload}
+			style={{
+				marginBottom: "-2px",
+				marginRight: "-2.5px",
+			}}>
+			<div style={{ fontSize: "10.5px", fontWeight: "bold", opacity: 0.6 }}>EXPORT .MD</div>
+		</VSCodeButton>
+	)
+
 	return (
 	return (
 		<div style={{ padding: "10px 13px 10px 13px" }}>
 		<div style={{ padding: "10px 13px 10px 13px" }}>
 			<div
 			<div
@@ -196,23 +203,32 @@ const TaskHeader: React.FC<TaskHeaderProps> = ({
 				)}
 				)}
 				{task.images && task.images.length > 0 && <Thumbnails images={task.images} />}
 				{task.images && task.images.length > 0 && <Thumbnails images={task.images} />}
 				<div style={{ display: "flex", flexDirection: "column", gap: "4px" }}>
 				<div style={{ display: "flex", flexDirection: "column", gap: "4px" }}>
-					<div style={{ display: "flex", alignItems: "center", gap: "4px", flexWrap: "wrap" }}>
-						<span style={{ fontWeight: "bold" }}>Tokens:</span>
-						<span style={{ display: "flex", alignItems: "center", gap: "3px" }}>
-							<i
-								className="codicon codicon-arrow-up"
-								style={{ fontSize: "12px", fontWeight: "bold", marginBottom: "-2px" }}
-							/>
-							{tokensIn?.toLocaleString()}
-						</span>
-						<span style={{ display: "flex", alignItems: "center", gap: "3px" }}>
-							<i
-								className="codicon codicon-arrow-down"
-								style={{ fontSize: "12px", fontWeight: "bold", marginBottom: "-2px" }}
-							/>
-							{tokensOut?.toLocaleString()}
-						</span>
+					<div
+						style={{
+							display: "flex",
+							justifyContent: "space-between",
+							alignItems: "center",
+						}}>
+						<div style={{ display: "flex", alignItems: "center", gap: "4px", flexWrap: "wrap" }}>
+							<span style={{ fontWeight: "bold" }}>Tokens:</span>
+							<span style={{ display: "flex", alignItems: "center", gap: "3px" }}>
+								<i
+									className="codicon codicon-arrow-up"
+									style={{ fontSize: "12px", fontWeight: "bold", marginBottom: "-2px" }}
+								/>
+								{tokensIn?.toLocaleString()}
+							</span>
+							<span style={{ display: "flex", alignItems: "center", gap: "3px" }}>
+								<i
+									className="codicon codicon-arrow-down"
+									style={{ fontSize: "12px", fontWeight: "bold", marginBottom: "-2px" }}
+								/>
+								{tokensOut?.toLocaleString()}
+							</span>
+						</div>
+						{apiConfiguration?.apiProvider === "openai" && <ExportButton />}
 					</div>
 					</div>
+
 					{(doesModelSupportPromptCache || cacheReads !== undefined || cacheWrites !== undefined) && (
 					{(doesModelSupportPromptCache || cacheReads !== undefined || cacheWrites !== undefined) && (
 						<div style={{ display: "flex", alignItems: "center", gap: "4px", flexWrap: "wrap" }}>
 						<div style={{ display: "flex", alignItems: "center", gap: "4px", flexWrap: "wrap" }}>
 							<span style={{ fontWeight: "bold" }}>Cache:</span>
 							<span style={{ fontWeight: "bold" }}>Cache:</span>
@@ -232,26 +248,20 @@ const TaskHeader: React.FC<TaskHeaderProps> = ({
 							</span>
 							</span>
 						</div>
 						</div>
 					)}
 					)}
-					<div
-						style={{
-							display: "flex",
-							justifyContent: "space-between",
-							alignItems: "center",
-						}}>
-						<div style={{ display: "flex", alignItems: "center", gap: "4px" }}>
-							<span style={{ fontWeight: "bold" }}>API Cost:</span>
-							<span>${totalCost?.toFixed(4)}</span>
-						</div>
-						<VSCodeButton
-							appearance="icon"
-							onClick={handleDownload}
+					{apiConfiguration?.apiProvider !== "openai" && (
+						<div
 							style={{
 							style={{
-								marginBottom: "-2px",
-								marginRight: "-2.5px",
+								display: "flex",
+								justifyContent: "space-between",
+								alignItems: "center",
 							}}>
 							}}>
-							<div style={{ fontSize: "10.5px", fontWeight: "bold", opacity: 0.6 }}>EXPORT .MD</div>
-						</VSCodeButton>
-					</div>
+							<div style={{ display: "flex", alignItems: "center", gap: "4px" }}>
+								<span style={{ fontWeight: "bold" }}>API Cost:</span>
+								<span>${totalCost?.toFixed(4)}</span>
+							</div>
+							<ExportButton />
+						</div>
+					)}
 				</div>
 				</div>
 			</div>
 			</div>
 			{/* {apiProvider === "kodu" && (
 			{/* {apiProvider === "kodu" && (

+ 7 - 3
webview-ui/src/context/ExtensionStateContext.tsx

@@ -31,9 +31,13 @@ export const ExtensionStateContextProvider: React.FC<{ children: React.ReactNode
 			setState(message.state)
 			setState(message.state)
 			const config = message.state?.apiConfiguration
 			const config = message.state?.apiConfiguration
 			const hasKey = config
 			const hasKey = config
-				? [config.apiKey, config.openRouterApiKey, config.awsRegion, config.vertexProjectId].some(
-						(key) => key !== undefined
-				  )
+				? [
+						config.apiKey,
+						config.openRouterApiKey,
+						config.awsRegion,
+						config.vertexProjectId,
+						config.openAiApiKey,
+				  ].some((key) => key !== undefined)
 				: false
 				: false
 			setShowWelcome(!hasKey)
 			setShowWelcome(!hasKey)
 			setDidHydrateState(true)
 			setDidHydrateState(true)

+ 9 - 0
webview-ui/src/utils/validate.ts

@@ -23,6 +23,15 @@ export function validateApiConfiguration(apiConfiguration?: ApiConfiguration): s
 					return "You must provide a valid Google Cloud Project ID and Region."
 					return "You must provide a valid Google Cloud Project ID and Region."
 				}
 				}
 				break
 				break
+			case "openai":
+				if (
+					!apiConfiguration.openAiBaseUrl ||
+					!apiConfiguration.openAiApiKey ||
+					!apiConfiguration.openAiModelId
+				) {
+					return "You must provide a valid base URL, API key, and model ID."
+				}
+				break
 		}
 		}
 	}
 	}
 	return undefined
 	return undefined