
refactor(api): improve OpenAI handler inheritance

- Add OpenAiHandlerOptions interface for configuration
- Extract processUsageMetrics to base class for reuse
- Update RequestyHandler to extend OpenAiHandler
- Add proper type safety for metrics handling
- Clean up code duplication across handlers
Author: sam hoang · 1 year ago
Commit: 41fcf85c48
3 changed files with 46 additions and 131 deletions:
  1. src/api/providers/deepseek.ts (+3, -3)
  2. src/api/providers/openai.ts (+18, -14)
  3. src/api/providers/requesty.ts (+25, -114)

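The pattern this commit establishes is that provider handlers no longer re-implement the OpenAI client wiring; they subclass OpenAiHandler and remap provider-specific options in their constructor. A minimal sketch of a new provider built on this base class (ExampleHandler and its base URL are hypothetical; the option names are taken from the diffs below):

```ts
import { OpenAiHandler, OpenAiHandlerOptions } from "./openai"

// Hypothetical provider, not part of this commit: illustrates how the
// refactored base class is intended to be reused by thin subclasses.
export class ExampleHandler extends OpenAiHandler {
	constructor(options: OpenAiHandlerOptions) {
		super({
			...options,
			// Remap provider-specific settings onto the generic OpenAI options,
			// mirroring the DeepSeekHandler and RequestyHandler constructors.
			openAiApiKey: options.openAiApiKey ?? "not-provided",
			openAiBaseUrl: "https://api.example.invalid/v1", // assumed endpoint
			defaultHeaders: { "X-Title": "Roo Code" },
		})
	}
}
```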
src/api/providers/deepseek.ts (+3, -3)

@@ -1,9 +1,9 @@
-import { OpenAiHandler } from "./openai"
-import { ApiHandlerOptions, ModelInfo } from "../../shared/api"
+import { OpenAiHandler, OpenAiHandlerOptions } from "./openai"
+import { ModelInfo } from "../../shared/api"
 import { deepSeekModels, deepSeekDefaultModelId } from "../../shared/api"
 
 export class DeepSeekHandler extends OpenAiHandler {
-	constructor(options: ApiHandlerOptions) {
+	constructor(options: OpenAiHandlerOptions) {
 		super({
 			...options,
 			openAiApiKey: options.deepSeekApiKey ?? "not-provided",

src/api/providers/openai.ts (+18, -14)

@@ -11,13 +11,17 @@ import { ApiHandler, SingleCompletionHandler } from "../index"
 import { convertToOpenAiMessages } from "../transform/openai-format"
 import { convertToR1Format } from "../transform/r1-format"
 import { convertToSimpleMessages } from "../transform/simple-format"
-import { ApiStream } from "../transform/stream"
+import { ApiStream, ApiStreamUsageChunk } from "../transform/stream"
+
+export interface OpenAiHandlerOptions extends ApiHandlerOptions {
+	defaultHeaders?: Record<string, string>
+}
 
 export class OpenAiHandler implements ApiHandler, SingleCompletionHandler {
-	protected options: ApiHandlerOptions
+	protected options: OpenAiHandlerOptions
 	private client: OpenAI
 
-	constructor(options: ApiHandlerOptions) {
+	constructor(options: OpenAiHandlerOptions) {
 		this.options = options
 
 		const baseURL = this.options.openAiBaseUrl ?? "https://api.openai.com/v1"
@@ -41,7 +45,7 @@ export class OpenAiHandler implements ApiHandler, SingleCompletionHandler {
 				apiVersion: this.options.azureApiVersion || azureOpenAiDefaultApiVersion,
 			})
 		} else {
-			this.client = new OpenAI({ baseURL, apiKey })
+			this.client = new OpenAI({ baseURL, apiKey, defaultHeaders: this.options.defaultHeaders })
 		}
 	}
 
@@ -98,11 +102,7 @@ export class OpenAiHandler implements ApiHandler, SingleCompletionHandler {
 					}
 				}
 				if (chunk.usage) {
-					yield {
-						type: "usage",
-						inputTokens: chunk.usage.prompt_tokens || 0,
-						outputTokens: chunk.usage.completion_tokens || 0,
-					}
+					yield this.processUsageMetrics(chunk.usage)
 				}
 			}
 		} else {
@@ -125,11 +125,15 @@ export class OpenAiHandler implements ApiHandler, SingleCompletionHandler {
 				type: "text",
 				text: response.choices[0]?.message.content || "",
 			}
-			yield {
-				type: "usage",
-				inputTokens: response.usage?.prompt_tokens || 0,
-				outputTokens: response.usage?.completion_tokens || 0,
-			}
+			yield this.processUsageMetrics(response.usage)
+		}
+	}
+
+	protected processUsageMetrics(usage: any): ApiStreamUsageChunk {
+		return {
+			type: "usage",
+			inputTokens: usage?.prompt_tokens || 0,
+			outputTokens: usage?.completion_tokens || 0,
 		}
 	}
 

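Outside Azure mode, the new defaultHeaders option is now forwarded straight to the OpenAI client, so custom headers can be attached without subclassing. A small usage sketch, assuming no other ApiHandlerOptions fields are required at construction time and using placeholder values:

```ts
import { OpenAiHandler } from "./openai"

// Sketch only: the API key is a placeholder, and the header values simply
// echo the ones RequestyHandler passes below.
const handler = new OpenAiHandler({
	openAiApiKey: "sk-placeholder",
	openAiBaseUrl: "https://api.openai.com/v1",
	defaultHeaders: {
		"HTTP-Referer": "https://github.com/RooVetGit/Roo-Cline",
		"X-Title": "Roo Code",
	},
})
```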
src/api/providers/requesty.ts (+25, -114)

@@ -1,21 +1,18 @@
-import { Anthropic } from "@anthropic-ai/sdk"
-import OpenAI from "openai"
-
-import { ApiHandlerOptions, ModelInfo, requestyModelInfoSaneDefaults } from "../../shared/api"
-import { ApiHandler, SingleCompletionHandler } from "../index"
-import { convertToOpenAiMessages } from "../transform/openai-format"
-import { convertToR1Format } from "../transform/r1-format"
-import { ApiStream } from "../transform/stream"
-
-export class RequestyHandler implements ApiHandler, SingleCompletionHandler {
-	protected options: ApiHandlerOptions
-	private client: OpenAI
-
-	constructor(options: ApiHandlerOptions) {
-		this.options = options
-		this.client = new OpenAI({
-			baseURL: "https://router.requesty.ai/v1",
-			apiKey: this.options.requestyApiKey,
+import { OpenAiHandler, OpenAiHandlerOptions } from "./openai"
+import { ModelInfo, requestyModelInfoSaneDefaults, requestyDefaultModelId } from "../../shared/api"
+import { ApiStream, ApiStreamUsageChunk } from "../transform/stream"
+
+export class RequestyHandler extends OpenAiHandler {
+	constructor(options: OpenAiHandlerOptions) {
+		if (!options.requestyApiKey) {
+			throw new Error("Requesty API key is required. Please provide it in the settings.")
+		}
+		super({
+			...options,
+			openAiApiKey: options.requestyApiKey,
+			openAiModelId: options.requestyModelId ?? requestyDefaultModelId,
+			openAiBaseUrl: "https://router.requesty.ai/v1",
+			openAiCustomModelInfo: options.requestyModelInfo ?? requestyModelInfoSaneDefaults,
 			defaultHeaders: {
 				"HTTP-Referer": "https://github.com/RooVetGit/Roo-Cline",
 				"X-Title": "Roo Code",
@@ -23,107 +20,21 @@ export class RequestyHandler implements ApiHandler, SingleCompletionHandler {
 		})
 	}
 
-	async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
-		const modelInfo = this.getModel().info
-		const modelId = this.options.requestyModelId ?? ""
-
-		const deepseekReasoner = modelId.includes("deepseek-reasoner")
-
-		if (this.options.openAiStreamingEnabled ?? true) {
-			const systemMessage: OpenAI.Chat.ChatCompletionSystemMessageParam = {
-				role: "system",
-				content: systemPrompt,
-			}
-			const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = {
-				model: modelId,
-				temperature: 0,
-				messages: deepseekReasoner
-					? convertToR1Format([{ role: "user", content: systemPrompt }, ...messages])
-					: [systemMessage, ...convertToOpenAiMessages(messages)],
-				stream: true as const,
-				stream_options: { include_usage: true },
-			}
-			if (this.options.includeMaxTokens) {
-				requestOptions.max_tokens = modelInfo.maxTokens
-			}
-
-			const stream = await this.client.chat.completions.create(requestOptions)
-
-			for await (const chunk of stream) {
-				const delta = chunk.choices[0]?.delta ?? {}
-
-				if (delta.content) {
-					yield {
-						type: "text",
-						text: delta.content,
-					}
-				}
-
-				if ("reasoning_content" in delta && delta.reasoning_content) {
-					yield {
-						type: "reasoning",
-						text: (delta.reasoning_content as string | undefined) || "",
-					}
-				}
-				if (chunk.usage) {
-					yield {
-						type: "usage",
-						inputTokens: chunk.usage.prompt_tokens || 0,
-						outputTokens: chunk.usage.completion_tokens || 0,
-						cacheWriteTokens: (chunk.usage as any).cache_creation_input_tokens || undefined,
-						cacheReadTokens: (chunk.usage as any).cache_read_input_tokens || undefined,
-					}
-				}
-			}
-		} else {
-			// o1 for instance doesnt support streaming, non-1 temp, or system prompt
-			const systemMessage: OpenAI.Chat.ChatCompletionUserMessageParam = {
-				role: "user",
-				content: systemPrompt,
-			}
-
-			const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = {
-				model: modelId,
-				messages: deepseekReasoner
-					? convertToR1Format([{ role: "user", content: systemPrompt }, ...messages])
-					: [systemMessage, ...convertToOpenAiMessages(messages)],
-			}
-
-			const response = await this.client.chat.completions.create(requestOptions)
-
-			yield {
-				type: "text",
-				text: response.choices[0]?.message.content || "",
-			}
-			yield {
-				type: "usage",
-				inputTokens: response.usage?.prompt_tokens || 0,
-				outputTokens: response.usage?.completion_tokens || 0,
-			}
-		}
-	}
-
-	getModel(): { id: string; info: ModelInfo } {
+	override getModel(): { id: string; info: ModelInfo } {
+		const modelId = this.options.requestyModelId ?? requestyDefaultModelId
 		return {
-			id: this.options.requestyModelId ?? "",
+			id: modelId,
 			info: this.options.requestyModelInfo ?? requestyModelInfoSaneDefaults,
 		}
 	}
 
-	async completePrompt(prompt: string): Promise<string> {
-		try {
-			const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = {
-				model: this.getModel().id,
-				messages: [{ role: "user", content: prompt }],
-			}
-
-			const response = await this.client.chat.completions.create(requestOptions)
-			return response.choices[0]?.message.content || ""
-		} catch (error) {
-			if (error instanceof Error) {
-				throw new Error(`OpenAI completion error: ${error.message}`)
-			}
-			throw error
+	protected override processUsageMetrics(usage: any): ApiStreamUsageChunk {
+		return {
+			type: "usage",
+			inputTokens: usage?.prompt_tokens || 0,
+			outputTokens: usage?.completion_tokens || 0,
+			cacheWriteTokens: usage?.cache_creation_input_tokens,
+			cacheReadTokens: usage?.cache_read_input_tokens,
 		}
 	}
 }
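With the duplicated streaming logic removed, RequestyHandler inherits createMessage from OpenAiHandler and only customizes option mapping, getModel, and usage metrics. A consumption sketch under those assumptions (the API key, model id, and messages are placeholders; the chunk shape follows the ApiStream types referenced above):

```ts
import { RequestyHandler } from "./requesty"

async function example() {
	// Placeholder credentials and model id, for illustration only.
	const handler = new RequestyHandler({
		requestyApiKey: "rq-placeholder",
		requestyModelId: "openai/gpt-4o-mini",
	})

	// createMessage is inherited from OpenAiHandler; usage chunks now carry
	// the cache token fields added by the overridden processUsageMetrics.
	for await (const chunk of handler.createMessage("You are a helpful assistant.", [])) {
		if (chunk.type === "text") {
			process.stdout.write(chunk.text)
		} else if (chunk.type === "usage") {
			console.log(`in: ${chunk.inputTokens}, out: ${chunk.outputTokens}`)
		}
	}
}
```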