Browse Source

Add ollama provider option

Saoud Rizwan 1 year ago
parent
commit
286e569e09

+ 1 - 1
package.json

@@ -2,7 +2,7 @@
   "name": "claude-dev",
   "displayName": "Claude Dev",
   "description": "Autonomous coding agent right in your IDE, capable of creating/editing files, executing commands, and more with your permission every step of the way.",
-  "version": "1.5.20",
+  "version": "1.5.21",
   "icon": "icon.png",
   "engines": {
     "vscode": "^1.84.0"

+ 3 - 0
src/api/index.ts

@@ -5,6 +5,7 @@ import { AwsBedrockHandler } from "./bedrock"
 import { OpenRouterHandler } from "./openrouter"
 import { VertexHandler } from "./vertex"
 import { OpenAiHandler } from "./openai"
+import { OllamaHandler } from "./ollama"
 
 export interface ApiHandlerMessageResponse {
 	message: Anthropic.Messages.Message
@@ -43,6 +44,8 @@ export function buildApiHandler(configuration: ApiConfiguration): ApiHandler {
 			return new VertexHandler(options)
 		case "openai":
 			return new OpenAiHandler(options)
+		case "ollama":
+			return new OllamaHandler(options)
 		default:
 			return new AnthropicHandler(options)
 	}

+ 74 - 0
src/api/ollama.ts

@@ -0,0 +1,74 @@
+import { Anthropic } from "@anthropic-ai/sdk"
+import OpenAI from "openai"
+import { ApiHandler, ApiHandlerMessageResponse, withoutImageData } from "."
+import { ApiHandlerOptions, ModelInfo, openAiModelInfoSaneDefaults } from "../shared/api"
+import { convertToAnthropicMessage, convertToOpenAiMessages } from "../utils/openai-format"
+
+export class OllamaHandler implements ApiHandler {
+	private options: ApiHandlerOptions
+	private client: OpenAI
+
+	constructor(options: ApiHandlerOptions) {
+		this.options = options
+		this.client = new OpenAI({
+			baseURL: "http://localhost:11434/v1",
+			apiKey: "ollama",
+		})
+	}
+
+	async createMessage(
+		systemPrompt: string,
+		messages: Anthropic.Messages.MessageParam[],
+		tools: Anthropic.Messages.Tool[]
+	): Promise<ApiHandlerMessageResponse> {
+		const openAiMessages: OpenAI.Chat.ChatCompletionMessageParam[] = [
+			{ role: "system", content: systemPrompt },
+			...convertToOpenAiMessages(messages),
+		]
+		const openAiTools: OpenAI.Chat.ChatCompletionTool[] = tools.map((tool) => ({
+			type: "function",
+			function: {
+				name: tool.name,
+				description: tool.description,
+				parameters: tool.input_schema,
+			},
+		}))
+		const createParams: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = {
+			model: this.options.ollamaModelId ?? "",
+			messages: openAiMessages,
+			tools: openAiTools,
+			tool_choice: "auto",
+		}
+		const completion = await this.client.chat.completions.create(createParams)
+		const errorMessage = (completion as any).error?.message
+		if (errorMessage) {
+			throw new Error(errorMessage)
+		}
+		const anthropicMessage = convertToAnthropicMessage(completion)
+		return { message: anthropicMessage }
+	}
+
+	createUserReadableRequest(
+		userContent: Array<
+			| Anthropic.TextBlockParam
+			| Anthropic.ImageBlockParam
+			| Anthropic.ToolUseBlockParam
+			| Anthropic.ToolResultBlockParam
+		>
+	): any {
+		return {
+			model: this.options.ollamaModelId ?? "",
+			system: "(see SYSTEM_PROMPT in src/ClaudeDev.ts)",
+			messages: [{ conversation_history: "..." }, { role: "user", content: withoutImageData(userContent) }],
+			tools: "(see tools in src/ClaudeDev.ts)",
+			tool_choice: "auto",
+		}
+	}
+
+	getModel(): { id: string; info: ModelInfo } {
+		return {
+			id: this.options.ollamaModelId ?? "",
+			info: openAiModelInfoSaneDefaults,
+		}
+	}
+}

+ 6 - 0
src/providers/ClaudeDevProvider.ts

@@ -29,6 +29,7 @@ type GlobalStateKey =
 	| "taskHistory"
 	| "openAiBaseUrl"
 	| "openAiModelId"
+	| "ollamaModelId"
 
 export class ClaudeDevProvider implements vscode.WebviewViewProvider {
 	public static readonly sideBarId = "claude-dev.SidebarProvider" // used in package.json as the view's id. This value cannot be changed due to how vscode caches views based on their id, and updating the id would break existing instances of the extension.
@@ -319,6 +320,7 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider {
 								openAiBaseUrl,
 								openAiApiKey,
 								openAiModelId,
+								ollamaModelId,
 							} = message.apiConfiguration
 							await this.updateGlobalState("apiProvider", apiProvider)
 							await this.updateGlobalState("apiModelId", apiModelId)
@@ -333,6 +335,7 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider {
 							await this.updateGlobalState("openAiBaseUrl", openAiBaseUrl)
 							await this.storeSecret("openAiApiKey", openAiApiKey)
 							await this.updateGlobalState("openAiModelId", openAiModelId)
+							await this.updateGlobalState("ollamaModelId", ollamaModelId)
 							this.claudeDev?.updateApi(message.apiConfiguration)
 						}
 						await this.postStateToWebview()
@@ -623,6 +626,7 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider {
 			openAiBaseUrl,
 			openAiApiKey,
 			openAiModelId,
+			ollamaModelId,
 			lastShownAnnouncementId,
 			customInstructions,
 			alwaysAllowReadOnly,
@@ -641,6 +645,7 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider {
 			this.getGlobalState("openAiBaseUrl") as Promise<string | undefined>,
 			this.getSecret("openAiApiKey") as Promise<string | undefined>,
 			this.getGlobalState("openAiModelId") as Promise<string | undefined>,
+			this.getGlobalState("ollamaModelId") as Promise<string | undefined>,
 			this.getGlobalState("lastShownAnnouncementId") as Promise<string | undefined>,
 			this.getGlobalState("customInstructions") as Promise<string | undefined>,
 			this.getGlobalState("alwaysAllowReadOnly") as Promise<boolean | undefined>,
@@ -676,6 +681,7 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider {
 				openAiBaseUrl,
 				openAiApiKey,
 				openAiModelId,
+				ollamaModelId,
 			},
 			lastShownAnnouncementId,
 			customInstructions,

+ 2 - 1
src/shared/api.ts

@@ -1,4 +1,4 @@
-export type ApiProvider = "anthropic" | "openrouter" | "bedrock" | "vertex" | "openai"
+export type ApiProvider = "anthropic" | "openrouter" | "bedrock" | "vertex" | "openai" | "ollama"
 
 export interface ApiHandlerOptions {
 	apiModelId?: string
@@ -13,6 +13,7 @@ export interface ApiHandlerOptions {
 	openAiBaseUrl?: string
 	openAiApiKey?: string
 	openAiModelId?: string
+	ollamaModelId?: string
 }
 
 export type ApiConfiguration = ApiHandlerOptions & {

+ 45 - 4
webview-ui/src/components/ApiOptions.tsx

@@ -79,6 +79,7 @@ const ApiOptions: React.FC<ApiOptionsProps> = ({ showModelOptions, apiErrorMessa
 					<VSCodeOption value="bedrock">AWS Bedrock</VSCodeOption>
 					<VSCodeOption value="vertex">GCP Vertex AI</VSCodeOption>
 					<VSCodeOption value="openai">OpenAI Compatible</VSCodeOption>
+					<VSCodeOption value="ollama">Ollama</VSCodeOption>
 				</VSCodeDropdown>
 			</div>
 
@@ -268,7 +269,7 @@ const ApiOptions: React.FC<ApiOptionsProps> = ({ showModelOptions, apiErrorMessa
 						style={{ width: "100%" }}
 						type="url"
 						onInput={handleInputChange("openAiBaseUrl")}
-						placeholder={"e.g. http://localhost:11434/v1"}>
+						placeholder={"Enter base URL..."}>
 						<span style={{ fontWeight: 500 }}>Base URL</span>
 					</VSCodeTextField>
 					<VSCodeTextField
@@ -276,14 +277,14 @@ const ApiOptions: React.FC<ApiOptionsProps> = ({ showModelOptions, apiErrorMessa
 						style={{ width: "100%" }}
 						type="password"
 						onInput={handleInputChange("openAiApiKey")}
-						placeholder="e.g. ollama">
+						placeholder="Enter API Key...">
 						<span style={{ fontWeight: 500 }}>API Key</span>
 					</VSCodeTextField>
 					<VSCodeTextField
 						value={apiConfiguration?.openAiModelId || ""}
 						style={{ width: "100%" }}
 						onInput={handleInputChange("openAiModelId")}
-						placeholder={"e.g. llama3.1"}>
+						placeholder={"Enter Model ID..."}>
 						<span style={{ fontWeight: 500 }}>Model ID</span>
 					</VSCodeTextField>
 					<p
@@ -301,6 +302,40 @@ const ApiOptions: React.FC<ApiOptionsProps> = ({ showModelOptions, apiErrorMessa
 				</div>
 			)}
 
+			{selectedProvider === "ollama" && (
+				<div>
+					<VSCodeTextField
+						value={apiConfiguration?.ollamaModelId || ""}
+						style={{ width: "100%" }}
+						onInput={handleInputChange("ollamaModelId")}
+						placeholder={"e.g. llama3.1"}>
+						<span style={{ fontWeight: 500 }}>Model ID</span>
+					</VSCodeTextField>
+					<p
+						style={{
+							fontSize: "12px",
+							marginTop: "5px",
+							color: "var(--vscode-descriptionForeground)",
+						}}>
+						Ollama allows you to run models locally on your computer. For instructions on how to get started
+						with Ollama, see their
+						<VSCodeLink
+							href="https://github.com/ollama/ollama/blob/main/README.md"
+							style={{ display: "inline" }}>
+							quickstart guide.
+						</VSCodeLink>{" "}
+						You can use any models that support{" "}
+						<VSCodeLink href="https://ollama.com/search?c=tools" style={{ display: "inline" }}>
+							tool use.
+						</VSCodeLink>
+						<span style={{ color: "var(--vscode-errorForeground)" }}>
+							(<span style={{ fontWeight: 500 }}>Note:</span> Claude Dev uses complex prompts, so less
+							capable models may not work as expected.)
+						</span>
+					</p>
+				</div>
+			)}
+
 			{apiErrorMessage && (
 				<p
 					style={{
@@ -312,7 +347,7 @@ const ApiOptions: React.FC<ApiOptionsProps> = ({ showModelOptions, apiErrorMessa
 				</p>
 			)}
 
-			{selectedProvider !== "openai" && showModelOptions && (
+			{selectedProvider !== "openai" && selectedProvider !== "ollama" && showModelOptions && (
 				<>
 					<div className="dropdown-container">
 						<label htmlFor="model-id">
@@ -437,6 +472,12 @@ export function normalizeApiConfiguration(apiConfiguration?: ApiConfiguration) {
 				selectedModelId: apiConfiguration?.openAiModelId ?? "",
 				selectedModelInfo: openAiModelInfoSaneDefaults,
 			}
+		case "ollama":
+			return {
+				selectedProvider: provider,
+				selectedModelId: apiConfiguration?.ollamaModelId ?? "",
+				selectedModelInfo: openAiModelInfoSaneDefaults,
+			}
 		default:
 			return getProviderData(anthropicModels, anthropicDefaultModelId)
 	}

+ 4 - 2
webview-ui/src/components/TaskHeader.tsx

@@ -226,7 +226,9 @@ const TaskHeader: React.FC<TaskHeaderProps> = ({
 								{tokensOut?.toLocaleString()}
 							</span>
 						</div>
-						{apiConfiguration?.apiProvider === "openai" && <ExportButton />}
+						{(apiConfiguration?.apiProvider === "openai" || apiConfiguration?.apiProvider === "ollama") && (
+							<ExportButton />
+						)}
 					</div>
 
 					{(doesModelSupportPromptCache || cacheReads !== undefined || cacheWrites !== undefined) && (
@@ -248,7 +250,7 @@ const TaskHeader: React.FC<TaskHeaderProps> = ({
 							</span>
 						</div>
 					)}
-					{apiConfiguration?.apiProvider !== "openai" && (
+					{apiConfiguration?.apiProvider !== "openai" && apiConfiguration?.apiProvider !== "ollama" && (
 						<div
 							style={{
 								display: "flex",

+ 5 - 0
webview-ui/src/utils/validate.ts

@@ -32,6 +32,11 @@ export function validateApiConfiguration(apiConfiguration?: ApiConfiguration): s
 					return "You must provide a valid base URL, API key, and model ID."
 				}
 				break
+			case "ollama":
+				if (!apiConfiguration.ollamaModelId) {
+					return "You must provide a valid model ID."
+				}
+				break
 		}
 	}
 	return undefined