
Rework unbound.ts to follow the same conventions as all other providers

Vignesh Subbiah committed 11 months ago · commit 4008a1a53e
2 changed files with 299 additions and 73 deletions
  1. src/api/providers/__tests__/unbound.test.ts (+186 −40)
  2. src/api/providers/unbound.ts (+113 −33)
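
For orientation, the core of the change is that the Unbound provider now talks to the gateway through the OpenAI SDK client rather than a hand-rolled fetch call, matching the other providers. A minimal standalone sketch of that pattern (the model ID and environment variable below are illustrative, not taken from the diff):

import OpenAI from "openai"

// Hypothetical sketch: route Unbound traffic through the OpenAI SDK.
const client = new OpenAI({
	baseURL: "https://api.getunbound.ai/v1",
	apiKey: process.env.UNBOUND_API_KEY, // placeholder; the handler reads it from options
})

const stream = await client.chat.completions.create({
	model: "claude-3-5-sonnet-20241022", // provider prefix stripped, as the handler does
	messages: [{ role: "user", content: "Hello!" }],
	stream: true,
})

for await (const chunk of stream) {
	process.stdout.write(chunk.choices[0]?.delta?.content ?? "")
}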

src/api/providers/__tests__/unbound.test.ts (+186 −40)

@@ -1,64 +1,210 @@
 import { UnboundHandler } from "../unbound"
 import { ApiHandlerOptions } from "../../../shared/api"
-import fetchMock from "jest-fetch-mock"
+import OpenAI from "openai"
+import { Anthropic } from "@anthropic-ai/sdk"
 
-fetchMock.enableMocks()
+// Mock OpenAI client
+const mockCreate = jest.fn()
+const mockWithResponse = jest.fn()
 
-describe("UnboundHandler", () => {
-	const mockOptions: ApiHandlerOptions = {
-		unboundApiKey: "test-api-key",
-		apiModelId: "test-model-id",
+jest.mock("openai", () => {
+	return {
+		__esModule: true,
+		default: jest.fn().mockImplementation(() => ({
+			chat: {
+				completions: {
+					create: (...args: any[]) => {
+						const stream = {
+							[Symbol.asyncIterator]: async function* () {
+								yield {
+									choices: [
+										{
+											delta: { content: "Test response" },
+											index: 0,
+										},
+									],
+								}
+								yield {
+									choices: [
+										{
+											delta: {},
+											index: 0,
+										},
+									],
+								}
+							},
+						}
+
+						const result = mockCreate(...args)
+						if (args[0].stream) {
+							mockWithResponse.mockReturnValue(
+								Promise.resolve({
+									data: stream,
+									response: { headers: new Map() },
+								}),
+							)
+							result.withResponse = mockWithResponse
+						}
+						return result
+					},
+				},
+			},
+		})),
 	}
+})
+
+describe("UnboundHandler", () => {
+	let handler: UnboundHandler
+	let mockOptions: ApiHandlerOptions
 
 	beforeEach(() => {
-		fetchMock.resetMocks()
+		mockOptions = {
+			apiModelId: "anthropic/claude-3-5-sonnet-20241022",
+			unboundApiKey: "test-api-key",
+		}
+		handler = new UnboundHandler(mockOptions)
+		mockCreate.mockClear()
+		mockWithResponse.mockClear()
+
+		// Default mock implementation for non-streaming responses
+		mockCreate.mockResolvedValue({
+			id: "test-completion",
+			choices: [
+				{
+					message: { role: "assistant", content: "Test response" },
+					finish_reason: "stop",
+					index: 0,
+				},
+			],
+		})
 	})
 
-	it("should initialize with options", () => {
-		const handler = new UnboundHandler(mockOptions)
-		expect(handler).toBeDefined()
+	describe("constructor", () => {
+		it("should initialize with provided options", () => {
+			expect(handler).toBeInstanceOf(UnboundHandler)
+			expect(handler.getModel().id).toBe(mockOptions.apiModelId)
+		})
 	})
 
-	it("should create a message successfully", async () => {
-		const handler = new UnboundHandler(mockOptions)
-		const mockResponse = {
-			choices: [{ message: { content: "Hello, world!" } }],
-			usage: { prompt_tokens: 5, completion_tokens: 7 },
-		}
+	describe("createMessage", () => {
+		const systemPrompt = "You are a helpful assistant."
+		const messages: Anthropic.Messages.MessageParam[] = [
+			{
+				role: "user",
+				content: "Hello!",
+			},
+		]
 
-		fetchMock.mockResponseOnce(JSON.stringify(mockResponse))
+		it("should handle streaming responses", async () => {
+			const stream = handler.createMessage(systemPrompt, messages)
+			const chunks: any[] = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
 
-		const generator = handler.createMessage("system prompt", [])
-		const textResult = await generator.next()
-		const usageResult = await generator.next()
+			expect(chunks.length).toBe(1)
+			expect(chunks[0]).toEqual({
+				type: "text",
+				text: "Test response",
+			})
 
-		expect(textResult.value).toEqual({ type: "text", text: "Hello, world!" })
-		expect(usageResult.value).toEqual({
-			type: "usage",
-			inputTokens: 5,
-			outputTokens: 7,
+			expect(mockCreate).toHaveBeenCalledWith(
+				expect.objectContaining({
+					model: "claude-3-5-sonnet-20241022",
+					messages: expect.any(Array),
+					stream: true,
+				}),
+				expect.objectContaining({
+					headers: {
+						"X-Unbound-Metadata": expect.stringContaining("roo-code"),
+					},
+				}),
+			)
 		})
-	})
 
-	it("should handle API errors", async () => {
-		const handler = new UnboundHandler(mockOptions)
-		fetchMock.mockResponseOnce(JSON.stringify({ error: "API error" }), { status: 400 })
+		it("should handle API errors", async () => {
+			mockCreate.mockImplementationOnce(() => {
+				throw new Error("API Error")
+			})
+
+			const stream = handler.createMessage(systemPrompt, messages)
+			const chunks = []
 
-		const generator = handler.createMessage("system prompt", [])
-		await expect(generator.next()).rejects.toThrow("Unbound Gateway completion error: API error")
+			try {
+				for await (const chunk of stream) {
+					chunks.push(chunk)
+				}
+				fail("Expected error to be thrown")
+			} catch (error) {
+				expect(error).toBeInstanceOf(Error)
+				expect(error.message).toBe("API Error")
+			}
+		})
 	})
 
-	it("should handle network errors", async () => {
-		const handler = new UnboundHandler(mockOptions)
-		fetchMock.mockRejectOnce(new Error("Network error"))
+	describe("completePrompt", () => {
+		it("should complete prompt successfully", async () => {
+			const result = await handler.completePrompt("Test prompt")
+			expect(result).toBe("Test response")
+			expect(mockCreate).toHaveBeenCalledWith(
+				expect.objectContaining({
+					model: "claude-3-5-sonnet-20241022",
+					messages: [{ role: "user", content: "Test prompt" }],
+					temperature: 0,
+					max_tokens: 8192,
+				}),
+			)
+		})
+
+		it("should handle API errors", async () => {
+			mockCreate.mockRejectedValueOnce(new Error("API Error"))
+			await expect(handler.completePrompt("Test prompt")).rejects.toThrow("Unbound completion error: API Error")
+		})
+
+		it("should handle empty response", async () => {
+			mockCreate.mockResolvedValueOnce({
+				choices: [{ message: { content: "" } }],
+			})
+			const result = await handler.completePrompt("Test prompt")
+			expect(result).toBe("")
+		})
+
+		it("should not set max_tokens for non-Anthropic models", async () => {
+			mockCreate.mockClear()
+
+			const nonAnthropicOptions = {
+				apiModelId: "openai/gpt-4o",
+				unboundApiKey: "test-key",
+			}
+			const nonAnthropicHandler = new UnboundHandler(nonAnthropicOptions)
 
-		const generator = handler.createMessage("system prompt", [])
-		await expect(generator.next()).rejects.toThrow("Unbound Gateway completion error: Network error")
+			await nonAnthropicHandler.completePrompt("Test prompt")
+			expect(mockCreate).toHaveBeenCalledWith(
+				expect.objectContaining({
+					model: "gpt-4o",
+					messages: [{ role: "user", content: "Test prompt" }],
+					temperature: 0,
+				}),
+			)
+			expect(mockCreate.mock.calls[0][0]).not.toHaveProperty("max_tokens")
+		})
 	})
 
-	it("should return the correct model", () => {
-		const handler = new UnboundHandler(mockOptions)
-		const model = handler.getModel()
-		expect(model.id).toBe("gpt-4o")
+	describe("getModel", () => {
+		it("should return model info", () => {
+			const modelInfo = handler.getModel()
+			expect(modelInfo.id).toBe(mockOptions.apiModelId)
+			expect(modelInfo.info).toBeDefined()
+		})
+
+		it("should return default model when invalid model provided", () => {
+			const handlerWithInvalidModel = new UnboundHandler({
+				...mockOptions,
+				apiModelId: "invalid/model",
+			})
+			const modelInfo = handlerWithInvalidModel.getModel()
+			expect(modelInfo.id).toBe("openai/gpt-4o") // Default model
+			expect(modelInfo.info).toBeDefined()
+		})
 	})
 })
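
The mock above mirrors the shape the OpenAI SDK uses for streaming responses: an async iterable of chunks, plus a withResponse() wrapper exposing raw response headers. A minimal reusable stub in the same spirit (the helper name is hypothetical, not part of the diff):

// Async-iterable stub mimicking an OpenAI streaming completion; useful for
// unit-testing stream consumers without the real SDK.
function makeStream(texts: string[]) {
	return {
		[Symbol.asyncIterator]: async function* () {
			for (const text of texts) {
				yield { choices: [{ delta: { content: text }, index: 0 }] }
			}
			yield { choices: [{ delta: {}, index: 0 }] } // terminal empty delta
		},
	}
}

// Usage: for await (const chunk of makeStream(["Hello", " world"])) { ... }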

src/api/providers/unbound.ts (+113 −33)

@@ -1,50 +1,108 @@
-import { ApiHandlerOptions, unboundModels, UnboundModelId, unboundDefaultModelId, ModelInfo } from "../../shared/api"
-import { ApiStream } from "../transform/stream"
 import { Anthropic } from "@anthropic-ai/sdk"
-import { ApiHandler } from "../index"
+import OpenAI from "openai"
+import { ApiHandler, SingleCompletionHandler } from "../"
+import { ApiHandlerOptions, ModelInfo, UnboundModelId, unboundDefaultModelId, unboundModels } from "../../shared/api"
+import { convertToOpenAiMessages } from "../transform/openai-format"
+import { ApiStream } from "../transform/stream"
 
-export class UnboundHandler implements ApiHandler {
-	private unboundBaseUrl: string = "https://api.getunbound.ai/v1"
+export class UnboundHandler implements ApiHandler, SingleCompletionHandler {
 	private options: ApiHandlerOptions
+	private client: OpenAI
 
 	constructor(options: ApiHandlerOptions) {
 		this.options = options
+		this.client = new OpenAI({
+			baseURL: "https://api.getunbound.ai/v1",
+			apiKey: this.options.unboundApiKey,
+		})
 	}
 
 	async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
-		try {
-			const response = await fetch(`${this.unboundBaseUrl}/chat/completions`, {
-				method: "POST",
-				headers: {
-					Authorization: `Bearer ${this.options.unboundApiKey}`,
-					"Content-Type": "application/json",
-				},
-				body: JSON.stringify({
-					model: this.getModel().id.split("/")[1],
-					messages: [{ role: "system", content: systemPrompt }, ...messages],
-				}),
+		// Convert Anthropic messages to OpenAI format
+		const openAiMessages: OpenAI.Chat.ChatCompletionMessageParam[] = [
+			{ role: "system", content: systemPrompt },
+			...convertToOpenAiMessages(messages),
+		]
+
+		// This is specific to Claude models (some models may support prompt caching automatically without these markers)
+		if (this.getModel().id.startsWith("anthropic/claude-3")) {
+			openAiMessages[0] = {
+				role: "system",
+				content: [
+					{
+						type: "text",
+						text: systemPrompt,
+						// @ts-ignore-next-line
+						cache_control: { type: "ephemeral" },
+					},
+				],
+			}
+
+			// Add cache_control to the last two user messages
+			// (note: this works because we only ever add one user message at a time,
+			// but if we added multiple we'd need to mark the user message before the last assistant message)
+			const lastTwoUserMessages = openAiMessages.filter((msg) => msg.role === "user").slice(-2)
+			lastTwoUserMessages.forEach((msg) => {
+				if (typeof msg.content === "string") {
+					msg.content = [{ type: "text", text: msg.content }]
+				}
+				if (Array.isArray(msg.content)) {
+					// NOTE: this is fine because environment details are always appended as the
+					// final text part. If they weren't, and the content ended with an image_url
+					// part, the fallback below adds a placeholder text part to carry the marker.
+					let lastTextPart = msg.content.filter((part) => part.type === "text").pop()
+
+					if (!lastTextPart) {
+						lastTextPart = { type: "text", text: "..." }
+						msg.content.push(lastTextPart)
+					}
+					// @ts-ignore-next-line
+					lastTextPart["cache_control"] = { type: "ephemeral" }
+				}
 			})
+		}
 
-			const data = await response.json()
+		// max_tokens is required by Anthropic models.
+		// Other providers default to the maximum tokens allowed.
+		let maxTokens: number | undefined
 
-			if (!response.ok) {
-				throw new Error(data.error.message)
-			}
+		if (this.getModel().id.startsWith("anthropic/")) {
+			maxTokens = 8_192
+		}
 
-			yield {
-				type: "text",
-				text: data.choices[0]?.message?.content || "",
-			}
-			yield {
-				type: "usage",
-				inputTokens: data.usage?.prompt_tokens || 0,
-				outputTokens: data.usage?.completion_tokens || 0,
-			}
-		} catch (error) {
-			if (error instanceof Error) {
-				throw new Error(`Unbound Gateway completion error:\n ${error.message}`)
+		const { data: completion, response } = await this.client.chat.completions
+			.create(
+				{
+					model: this.getModel().id.split("/")[1],
+					max_tokens: maxTokens,
+					temperature: 0,
+					messages: openAiMessages,
+					stream: true,
+				},
+				{
+					headers: {
+						"X-Unbound-Metadata": JSON.stringify({
+							labels: [
+								{
+									key: "app",
+									value: "roo-code",
+								},
+							],
+						}),
+					},
+				},
+			)
+			.withResponse()
+
+		for await (const chunk of completion) {
+			const delta = chunk.choices[0]?.delta
+
+			if (delta?.content) {
+				yield {
+					type: "text",
+					text: delta.content,
+				}
 			}
-			throw error
 		}
 	}
 
@@ -59,4 +117,26 @@ export class UnboundHandler implements ApiHandler {
 			info: unboundModels[unboundDefaultModelId],
 		}
 	}
+
+	async completePrompt(prompt: string): Promise<string> {
+		try {
+			const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = {
+				model: this.getModel().id.split("/")[1],
+				messages: [{ role: "user", content: prompt }],
+				temperature: 0,
+			}
+
+			if (this.getModel().id.startsWith("anthropic/")) {
+				requestOptions.max_tokens = 8192
+			}
+
+			const response = await this.client.chat.completions.create(requestOptions)
+			return response.choices[0]?.message.content || ""
+		} catch (error) {
+			if (error instanceof Error) {
+				throw new Error(`Unbound completion error: ${error.message}`)
+			}
+			throw error
+		}
+	}
 }
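
Taken together, a sketch of how the reworked handler is consumed (the option values are placeholders, and the top-level awaits assume a module context):

import { UnboundHandler } from "./unbound"

// Illustrative options; real values come from the extension's configuration.
const handler = new UnboundHandler({
	apiModelId: "anthropic/claude-3-5-sonnet-20241022",
	unboundApiKey: "...", // placeholder key
})

// Streaming path: createMessage() yields { type: "text", text } chunks as deltas arrive.
for await (const chunk of handler.createMessage("You are a helpful assistant.", [
	{ role: "user", content: "Hello!" },
])) {
	if (chunk.type === "text") {
		process.stdout.write(chunk.text)
	}
}

// Non-streaming path added by this commit via SingleCompletionHandler.
const answer = await handler.completePrompt("Summarize the change.")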