
fix: use native Ollama API instead of OpenAI compatibility layer (#7137)

Daniel, 4 months ago
Parent
Commit
f3864ffebb
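In essence, the commit stops routing Ollama traffic through the OpenAI compatibility layer and instead uses the ollama npm package's native client. A minimal sketch of the new call pattern, with a placeholder model name and host (not values mandated by this commit):

import { Ollama } from "ollama"

async function demo() {
	// Native client; the new handler below constructs this lazily in ensureClient()
	const client = new Ollama({ host: "http://localhost:11434" })

	// stream: true yields chunks carrying message.content, plus
	// prompt_eval_count / eval_count for token accounting
	const stream = await client.chat({
		model: "llama2",
		messages: [{ role: "user", content: "Hi there" }],
		stream: true,
	})

	for await (const chunk of stream) {
		process.stdout.write(chunk.message.content)
	}
}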

+ 16 - 1
pnpm-lock.yaml

@@ -676,6 +676,9 @@ importers:
       node-ipc:
         specifier: ^12.0.0
         version: 12.0.0
+      ollama:
+        specifier: ^0.5.17
+        version: 0.5.17
       openai:
         specifier: ^5.0.0
         version: 5.5.1([email protected])([email protected])
@@ -7645,6 +7648,9 @@ packages:
     resolution: {integrity: sha512-gXah6aZrcUxjWg2zR2MwouP2eHlCBzdV4pygudehaKXSGW4v2AsRQUK+lwwXhii6KFZcunEnmSUoYp5CXibxtA==}
     engines: {node: '>= 0.4'}
 
+  ollama@0.5.17:
+    resolution: {integrity: sha512-q5LmPtk6GLFouS+3aURIVl+qcAOPC4+Msmx7uBb3pd+fxI55WnGjmLZ0yijI/CYy79x0QPGx3BwC3u5zv9fBvQ==}
+
   [email protected]:
     resolution: {integrity: sha512-oVlzkg3ENAhCk2zdv7IJwd/QUD4z2RxRwpkcGY8psCVcCYZNq4wYnVWALHM+brtuJjePWiYF/ClmuDr8Ch5+kg==}
     engines: {node: '>= 0.8'}
@@ -9655,6 +9661,9 @@ packages:
     resolution: {integrity: sha512-6qN4hJdMwfYBtE3YBTTHhoeuUrDBPZmbQaxWAqSALV/MeEnR5z1xd8UKud2RAkFoPkmB+hli1TZSnyi84xz1vQ==}
     engines: {node: '>=18'}
 
+  whatwg-fetch@3.6.20:
+    resolution: {integrity: sha512-EqhiFU6daOA8kpjOWTL0olhVOF3i7OrFzSYiGsEMB8GcXS+RrzauAERX65xMeNWVqxA6HXH2m69Z9LaKKdisfg==}
+
   [email protected]:
     resolution: {integrity: sha512-QaKxh0eNIi2mE9p2vEdzfagOKHCcj1pJ56EEHGQOVxp8r9/iszLUUV7v89x9O1p/T+NlTM5W7jW6+cz4Fq1YVg==}
     engines: {node: '>=18'}
@@ -13546,7 +13555,7 @@ snapshots:
       sirv: 3.0.1
       tinyglobby: 0.2.14
       tinyrainbow: 2.0.0
-      vitest: 3.2.4(@types/[email protected])(@types/node@24.2.1)(@vitest/[email protected])([email protected])([email protected])([email protected])([email protected])([email protected])
+      vitest: 3.2.4(@types/[email protected])(@types/node@20.17.50)(@vitest/[email protected])([email protected])([email protected])([email protected])([email protected])([email protected])
 
   '@vitest/[email protected]':
     dependencies:
@@ -17683,6 +17692,10 @@ snapshots:
       define-properties: 1.2.1
       es-object-atoms: 1.1.1
 
+  ollama@0.5.17:
+    dependencies:
+      whatwg-fetch: 3.6.20
+
   [email protected]:
     dependencies:
       ee-first: 1.1.1
@@ -20155,6 +20168,8 @@ snapshots:
     dependencies:
       iconv-lite: 0.6.3
 
+  whatwg-fetch@3.6.20: {}
+
   [email protected]: {}
 
   [email protected]:

+ 2 - 2
src/api/index.ts

@@ -13,7 +13,6 @@ import {
 	VertexHandler,
 	AnthropicVertexHandler,
 	OpenAiHandler,
-	OllamaHandler,
 	LmStudioHandler,
 	GeminiHandler,
 	OpenAiNativeHandler,
@@ -37,6 +36,7 @@ import {
 	ZAiHandler,
 	FireworksHandler,
 } from "./providers"
+import { NativeOllamaHandler } from "./providers/native-ollama"
 
 export interface SingleCompletionHandler {
 	completePrompt(prompt: string): Promise<string>
@@ -95,7 +95,7 @@ export function buildApiHandler(configuration: ProviderSettings): ApiHandler {
 		case "openai":
 			return new OpenAiHandler(options)
 		case "ollama":
-			return new OllamaHandler(options)
+			return new NativeOllamaHandler(options)
 		case "lmstudio":
 			return new LmStudioHandler(options)
 		case "gemini":
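With this factory change, a configuration selecting the "ollama" provider now gets the native handler. A hypothetical invocation sketch; the field name apiProvider and the cast are assumptions, while ollamaModelId / ollamaBaseUrl match the options used in the new test file:

import { buildApiHandler } from "./src/api"

// Hypothetical settings object: only the ollama-related fields visible in this
// diff are known; everything else about ProviderSettings is assumed.
const handler = buildApiHandler({
	apiProvider: "ollama",
	ollamaModelId: "llama2",
	ollamaBaseUrl: "http://localhost:11434",
} as any)

// Previously this case constructed OllamaHandler (OpenAI compatibility layer);
// it now constructs NativeOllamaHandler.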

+ 162 - 0
src/api/providers/__tests__/native-ollama.spec.ts

@@ -0,0 +1,162 @@
+// npx vitest run api/providers/__tests__/native-ollama.spec.ts
+
+import { NativeOllamaHandler } from "../native-ollama"
+import { ApiHandlerOptions } from "../../../shared/api"
+
+// Mock the ollama package
+const mockChat = vitest.fn()
+vitest.mock("ollama", () => {
+	return {
+		Ollama: vitest.fn().mockImplementation(() => ({
+			chat: mockChat,
+		})),
+		Message: vitest.fn(),
+	}
+})
+
+// Mock the getOllamaModels function
+vitest.mock("../fetchers/ollama", () => ({
+	getOllamaModels: vitest.fn().mockResolvedValue({
+		llama2: {
+			contextWindow: 4096,
+			maxTokens: 4096,
+			supportsImages: false,
+			supportsPromptCache: false,
+		},
+	}),
+}))
+
+describe("NativeOllamaHandler", () => {
+	let handler: NativeOllamaHandler
+
+	beforeEach(() => {
+		vitest.clearAllMocks()
+
+		const options: ApiHandlerOptions = {
+			apiModelId: "llama2",
+			ollamaModelId: "llama2",
+			ollamaBaseUrl: "http://localhost:11434",
+		}
+
+		handler = new NativeOllamaHandler(options)
+	})
+
+	describe("createMessage", () => {
+		it("should stream messages from Ollama", async () => {
+			// Mock the chat response as an async generator
+			mockChat.mockImplementation(async function* () {
+				yield {
+					message: { content: "Hello" },
+					eval_count: undefined,
+					prompt_eval_count: undefined,
+				}
+				yield {
+					message: { content: " world" },
+					eval_count: 2,
+					prompt_eval_count: 10,
+				}
+			})
+
+			const systemPrompt = "You are a helpful assistant"
+			const messages = [{ role: "user" as const, content: "Hi there" }]
+
+			const stream = handler.createMessage(systemPrompt, messages)
+			const results = []
+
+			for await (const chunk of stream) {
+				results.push(chunk)
+			}
+
+			expect(results).toHaveLength(3)
+			expect(results[0]).toEqual({ type: "text", text: "Hello" })
+			expect(results[1]).toEqual({ type: "text", text: " world" })
+			expect(results[2]).toEqual({ type: "usage", inputTokens: 10, outputTokens: 2 })
+		})
+
+		it("should handle DeepSeek R1 models with reasoning detection", async () => {
+			const options: ApiHandlerOptions = {
+				apiModelId: "deepseek-r1",
+				ollamaModelId: "deepseek-r1",
+				ollamaBaseUrl: "http://localhost:11434",
+			}
+
+			handler = new NativeOllamaHandler(options)
+
+			// Mock response with thinking tags
+			mockChat.mockImplementation(async function* () {
+				yield { message: { content: "<think>Let me think" } }
+				yield { message: { content: " about this</think>" } }
+				yield { message: { content: "The answer is 42" } }
+			})
+
+			const stream = handler.createMessage("System", [{ role: "user" as const, content: "Question?" }])
+			const results = []
+
+			for await (const chunk of stream) {
+				results.push(chunk)
+			}
+
+			// Should detect reasoning vs regular text
+			expect(results.some((r) => r.type === "reasoning")).toBe(true)
+			expect(results.some((r) => r.type === "text")).toBe(true)
+		})
+	})
+
+	describe("completePrompt", () => {
+		it("should complete a prompt without streaming", async () => {
+			mockChat.mockResolvedValue({
+				message: { content: "This is the response" },
+			})
+
+			const result = await handler.completePrompt("Tell me a joke")
+
+			expect(mockChat).toHaveBeenCalledWith({
+				model: "llama2",
+				messages: [{ role: "user", content: "Tell me a joke" }],
+				stream: false,
+				options: {
+					temperature: 0,
+				},
+			})
+			expect(result).toBe("This is the response")
+		})
+	})
+
+	describe("error handling", () => {
+		it("should handle connection refused errors", async () => {
+			const error = new Error("ECONNREFUSED") as any
+			error.code = "ECONNREFUSED"
+			mockChat.mockRejectedValue(error)
+
+			const stream = handler.createMessage("System", [{ role: "user" as const, content: "Test" }])
+
+			await expect(async () => {
+				for await (const _ of stream) {
+					// consume stream
+				}
+			}).rejects.toThrow("Ollama service is not running")
+		})
+
+		it("should handle model not found errors", async () => {
+			const error = new Error("Not found") as any
+			error.status = 404
+			mockChat.mockRejectedValue(error)
+
+			const stream = handler.createMessage("System", [{ role: "user" as const, content: "Test" }])
+
+			await expect(async () => {
+				for await (const _ of stream) {
+					// consume stream
+				}
+			}).rejects.toThrow("Model llama2 not found in Ollama")
+		})
+	})
+
+	describe("getModel", () => {
+		it("should return the configured model", () => {
+			const model = handler.getModel()
+			expect(model.id).toBe("llama2")
+			expect(model.info).toBeDefined()
+		})
+	})
+})

+ 285 - 0
src/api/providers/native-ollama.ts

@@ -0,0 +1,285 @@
+import { Anthropic } from "@anthropic-ai/sdk"
+import { Message, Ollama } from "ollama"
+import { ModelInfo, openAiModelInfoSaneDefaults, DEEP_SEEK_DEFAULT_TEMPERATURE } from "@roo-code/types"
+import { ApiStream } from "../transform/stream"
+import { BaseProvider } from "./base-provider"
+import type { ApiHandlerOptions } from "../../shared/api"
+import { getOllamaModels } from "./fetchers/ollama"
+import { XmlMatcher } from "../../utils/xml-matcher"
+import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index"
+
+function convertToOllamaMessages(anthropicMessages: Anthropic.Messages.MessageParam[]): Message[] {
+	const ollamaMessages: Message[] = []
+
+	for (const anthropicMessage of anthropicMessages) {
+		if (typeof anthropicMessage.content === "string") {
+			ollamaMessages.push({
+				role: anthropicMessage.role,
+				content: anthropicMessage.content,
+			})
+		} else {
+			if (anthropicMessage.role === "user") {
+				const { nonToolMessages, toolMessages } = anthropicMessage.content.reduce<{
+					nonToolMessages: (Anthropic.TextBlockParam | Anthropic.ImageBlockParam)[]
+					toolMessages: Anthropic.ToolResultBlockParam[]
+				}>(
+					(acc, part) => {
+						if (part.type === "tool_result") {
+							acc.toolMessages.push(part)
+						} else if (part.type === "text" || part.type === "image") {
+							acc.nonToolMessages.push(part)
+						}
+						return acc
+					},
+					{ nonToolMessages: [], toolMessages: [] },
+				)
+
+				// Process tool result messages FIRST since they must follow the tool use messages
+				const toolResultImages: string[] = []
+				toolMessages.forEach((toolMessage) => {
+					// The Anthropic SDK allows tool results to be a string or an array of text and image blocks, enabling rich and structured content. In contrast, the Ollama SDK only supports tool results as a single string, so we map the Anthropic tool result parts into one concatenated string to maintain compatibility.
+					let content: string
+
+					if (typeof toolMessage.content === "string") {
+						content = toolMessage.content
+					} else {
+						content =
+							toolMessage.content
+								?.map((part) => {
+									if (part.type === "image") {
+										// Handle base64 images only (Anthropic SDK uses base64)
+										// Ollama expects raw base64 strings, not data URLs
+										if ("source" in part && part.source.type === "base64") {
+											toolResultImages.push(part.source.data)
+										}
+										return "(see following user message for image)"
+									}
+									return part.text
+								})
+								.join("\n") ?? ""
+					}
+					ollamaMessages.push({
+						role: "user",
+						images: toolResultImages.length > 0 ? toolResultImages : undefined,
+						content: content,
+					})
+				})
+
+				// Process non-tool messages
+				if (nonToolMessages.length > 0) {
+					// Separate text and images for Ollama
+					const textContent = nonToolMessages
+						.filter((part) => part.type === "text")
+						.map((part) => part.text)
+						.join("\n")
+
+					const imageData: string[] = []
+					nonToolMessages.forEach((part) => {
+						if (part.type === "image" && "source" in part && part.source.type === "base64") {
+							// Ollama expects raw base64 strings, not data URLs
+							imageData.push(part.source.data)
+						}
+					})
+
+					ollamaMessages.push({
+						role: "user",
+						content: textContent,
+						images: imageData.length > 0 ? imageData : undefined,
+					})
+				}
+			} else if (anthropicMessage.role === "assistant") {
+				const { nonToolMessages } = anthropicMessage.content.reduce<{
+					nonToolMessages: (Anthropic.TextBlockParam | Anthropic.ImageBlockParam)[]
+					toolMessages: Anthropic.ToolUseBlockParam[]
+				}>(
+					(acc, part) => {
+						if (part.type === "tool_use") {
+							acc.toolMessages.push(part)
+						} else if (part.type === "text" || part.type === "image") {
+							acc.nonToolMessages.push(part)
+						} // assistant cannot send tool_result messages
+						return acc
+					},
+					{ nonToolMessages: [], toolMessages: [] },
+				)
+
+				// Process non-tool messages
+				let content: string = ""
+				if (nonToolMessages.length > 0) {
+					content = nonToolMessages
+						.map((part) => {
+							if (part.type === "image") {
+								return "" // impossible as the assistant cannot send images
+							}
+							return part.text
+						})
+						.join("\n")
+				}
+
+				ollamaMessages.push({
+					role: "assistant",
+					content,
+				})
+			}
+		}
+	}
+
+	return ollamaMessages
+}
+
+export class NativeOllamaHandler extends BaseProvider implements SingleCompletionHandler {
+	protected options: ApiHandlerOptions
+	private client: Ollama | undefined
+	protected models: Record<string, ModelInfo> = {}
+
+	constructor(options: ApiHandlerOptions) {
+		super()
+		this.options = options
+	}
+
+	private ensureClient(): Ollama {
+		if (!this.client) {
+			try {
+				this.client = new Ollama({
+					host: this.options.ollamaBaseUrl || "http://localhost:11434",
+					// Note: The ollama npm package handles timeouts internally
+				})
+			} catch (error: any) {
+				throw new Error(`Error creating Ollama client: ${error.message}`)
+			}
+		}
+		return this.client
+	}
+
+	override async *createMessage(
+		systemPrompt: string,
+		messages: Anthropic.Messages.MessageParam[],
+		metadata?: ApiHandlerCreateMessageMetadata,
+	): ApiStream {
+		const client = this.ensureClient()
+		const { id: modelId, info: modelInfo } = await this.fetchModel()
+		const useR1Format = modelId.toLowerCase().includes("deepseek-r1")
+
+		const ollamaMessages: Message[] = [
+			{ role: "system", content: systemPrompt },
+			...convertToOllamaMessages(messages),
+		]
+
+		const matcher = new XmlMatcher(
+			"think",
+			(chunk) =>
+				({
+					type: chunk.matched ? "reasoning" : "text",
+					text: chunk.data,
+				}) as const,
+		)
+
+		try {
+			// Create the actual API request promise
+			const stream = await client.chat({
+				model: modelId,
+				messages: ollamaMessages,
+				stream: true,
+				options: {
+					num_ctx: modelInfo.contextWindow,
+					temperature: this.options.modelTemperature ?? (useR1Format ? DEEP_SEEK_DEFAULT_TEMPERATURE : 0),
+				},
+			})
+
+			let totalInputTokens = 0
+			let totalOutputTokens = 0
+
+			try {
+				for await (const chunk of stream) {
+					if (typeof chunk.message.content === "string") {
+						// Process content through matcher for reasoning detection
+						for (const matcherChunk of matcher.update(chunk.message.content)) {
+							yield matcherChunk
+						}
+					}
+
+					// Handle token usage if available
+					if (chunk.eval_count !== undefined || chunk.prompt_eval_count !== undefined) {
+						if (chunk.prompt_eval_count) {
+							totalInputTokens = chunk.prompt_eval_count
+						}
+						if (chunk.eval_count) {
+							totalOutputTokens = chunk.eval_count
+						}
+					}
+				}
+
+				// Yield any remaining content from the matcher
+				for (const chunk of matcher.final()) {
+					yield chunk
+				}
+
+				// Yield usage information if available
+				if (totalInputTokens > 0 || totalOutputTokens > 0) {
+					yield {
+						type: "usage",
+						inputTokens: totalInputTokens,
+						outputTokens: totalOutputTokens,
+					}
+				}
+			} catch (streamError: any) {
+				console.error("Error processing Ollama stream:", streamError)
+				throw new Error(`Ollama stream processing error: ${streamError.message || "Unknown error"}`)
+			}
+		} catch (error: any) {
+			// Enhance error reporting
+			const statusCode = error.status || error.statusCode
+			const errorMessage = error.message || "Unknown error"
+
+			if (error.code === "ECONNREFUSED") {
+				throw new Error(
+					`Ollama service is not running at ${this.options.ollamaBaseUrl || "http://localhost:11434"}. Please start Ollama first.`,
+				)
+			} else if (statusCode === 404) {
+				throw new Error(
+					`Model ${this.getModel().id} not found in Ollama. Please pull the model first with: ollama pull ${this.getModel().id}`,
+				)
+			}
+
+			console.error(`Ollama API error (${statusCode || "unknown"}): ${errorMessage}`)
+			throw error
+		}
+	}
+
+	async fetchModel() {
+		this.models = await getOllamaModels(this.options.ollamaBaseUrl)
+		return this.getModel()
+	}
+
+	override getModel(): { id: string; info: ModelInfo } {
+		const modelId = this.options.ollamaModelId || ""
+		return {
+			id: modelId,
+			info: this.models[modelId] || openAiModelInfoSaneDefaults,
+		}
+	}
+
+	async completePrompt(prompt: string): Promise<string> {
+		try {
+			const client = this.ensureClient()
+			const { id: modelId } = await this.fetchModel()
+			const useR1Format = modelId.toLowerCase().includes("deepseek-r1")
+
+			const response = await client.chat({
+				model: modelId,
+				messages: [{ role: "user", content: prompt }],
+				stream: false,
+				options: {
+					temperature: this.options.modelTemperature ?? (useR1Format ? DEEP_SEEK_DEFAULT_TEMPERATURE : 0),
+				},
+			})
+
+			return response.message?.content || ""
+		} catch (error) {
+			if (error instanceof Error) {
+				throw new Error(`Ollama completion error: ${error.message}`)
+			}
+			throw error
+		}
+	}
+}
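
For reference, a rough end-to-end sketch of driving the new handler directly, the same way the spec above drives it. It assumes a local Ollama instance with the model already pulled; the options object mirrors the test setup and is cast because the full ApiHandlerOptions shape is not shown in this diff:

import { NativeOllamaHandler } from "./src/api/providers/native-ollama"

async function run() {
	const handler = new NativeOllamaHandler({
		apiModelId: "llama2",
		ollamaModelId: "llama2",
		ollamaBaseUrl: "http://localhost:11434",
	} as any)

	// createMessage streams text / reasoning chunks and, when token counts
	// are reported by Ollama, a final usage chunk
	for await (const chunk of handler.createMessage("You are a helpful assistant", [
		{ role: "user", content: "Hi there" },
	])) {
		if (chunk.type === "text") process.stdout.write(chunk.text)
		if (chunk.type === "reasoning") process.stdout.write(`\n[thinking] ${chunk.text}\n`)
		if (chunk.type === "usage") console.log(`\n${chunk.inputTokens} in / ${chunk.outputTokens} out`)
	}
}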

+ 1 - 0
src/package.json

@@ -458,6 +458,7 @@
 		"monaco-vscode-textmate-theme-converter": "^0.1.7",
 		"node-cache": "^5.1.2",
 		"node-ipc": "^12.0.0",
+		"ollama": "^0.5.17",
 		"openai": "^5.0.0",
 		"os-name": "^6.0.0",
 		"p-limit": "^6.2.0",