Browse Source

Merge pull request #1534 from dqroid/support-custom-baseUrl-for-google-ai-studio-gemini

support custom base url for gemini in google AI studio
Matt Rubens 11 months ago
parent
commit
705b3ba11f

+ 17 - 7
src/api/providers/__tests__/gemini.test.ts

@@ -101,10 +101,15 @@ describe("GeminiHandler", () => {
 			})
 
 			// Verify the model configuration
-			expect(mockGetGenerativeModel).toHaveBeenCalledWith({
-				model: "gemini-2.0-flash-thinking-exp-1219",
-				systemInstruction: systemPrompt,
-			})
+			expect(mockGetGenerativeModel).toHaveBeenCalledWith(
+				{
+					model: "gemini-2.0-flash-thinking-exp-1219",
+					systemInstruction: systemPrompt,
+				},
+				{
+					baseUrl: undefined,
+				},
+			)
 
 			// Verify generation config
 			expect(mockGenerateContentStream).toHaveBeenCalledWith(
@@ -149,9 +154,14 @@ describe("GeminiHandler", () => {
 
 			const result = await handler.completePrompt("Test prompt")
 			expect(result).toBe("Test response")
-			expect(mockGetGenerativeModel).toHaveBeenCalledWith({
-				model: "gemini-2.0-flash-thinking-exp-1219",
-			})
+			expect(mockGetGenerativeModel).toHaveBeenCalledWith(
+				{
+					model: "gemini-2.0-flash-thinking-exp-1219",
+				},
+				{
+					baseUrl: undefined,
+				},
+			)
 			expect(mockGenerateContent).toHaveBeenCalledWith({
 				contents: [{ role: "user", parts: [{ text: "Test prompt" }] }],
 				generationConfig: {

+ 17 - 7
src/api/providers/gemini.ts

@@ -19,10 +19,15 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl
 	}
 
 	override async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
-		const model = this.client.getGenerativeModel({
-			model: this.getModel().id,
-			systemInstruction: systemPrompt,
-		})
+		const model = this.client.getGenerativeModel(
+			{
+				model: this.getModel().id,
+				systemInstruction: systemPrompt,
+			},
+			{
+				baseUrl: this.options.googleGeminiBaseUrl || undefined,
+			},
+		)
 		const result = await model.generateContentStream({
 			contents: messages.map(convertAnthropicMessageToGemini),
 			generationConfig: {
@@ -57,9 +62,14 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl
 
 	async completePrompt(prompt: string): Promise<string> {
 		try {
-			const model = this.client.getGenerativeModel({
-				model: this.getModel().id,
-			})
+			const model = this.client.getGenerativeModel(
+				{
+					model: this.getModel().id,
+				},
+				{
+					baseUrl: this.options.googleGeminiBaseUrl || undefined,
+				},
+			)
 
 			const result = await model.generateContent({
 				contents: [{ role: "user", parts: [{ text: prompt }] }],

+ 1 - 0
src/exports/roo-code.d.ts

@@ -154,6 +154,7 @@ export type GlobalStateKey =
 	| "openRouterModelInfo"
 	| "openRouterBaseUrl"
 	| "openRouterUseMiddleOutTransform"
+	| "googleGeminiBaseUrl"
 	| "allowedCommands"
 	| "soundEnabled"
 	| "soundVolume"

+ 2 - 0
src/shared/api.ts

@@ -56,6 +56,7 @@ export interface ApiHandlerOptions {
 	lmStudioDraftModelId?: string
 	lmStudioSpeculativeDecodingEnabled?: boolean
 	geminiApiKey?: string
+	googleGeminiBaseUrl?: string
 	openAiNativeApiKey?: string
 	mistralApiKey?: string
 	mistralCodestralUrl?: string // New option for Codestral URL
@@ -115,6 +116,7 @@ export const API_CONFIG_KEYS: GlobalStateKey[] = [
 	"lmStudioBaseUrl",
 	"lmStudioDraftModelId",
 	"lmStudioSpeculativeDecodingEnabled",
+	"googleGeminiBaseUrl",
 	"mistralCodestralUrl",
 	"azureApiVersion",
 	"openRouterUseMiddleOutTransform",

+ 1 - 0
src/shared/globalState.ts

@@ -72,6 +72,7 @@ export const GLOBAL_STATE_KEYS = [
 	"openRouterModelInfo",
 	"openRouterBaseUrl",
 	"openRouterUseMiddleOutTransform",
+	"googleGeminiBaseUrl",
 	"allowedCommands",
 	"soundEnabled",
 	"soundVolume",

+ 25 - 0
webview-ui/src/components/settings/ApiOptions.tsx

@@ -116,6 +116,9 @@ const ApiOptions = ({
 	const [anthropicBaseUrlSelected, setAnthropicBaseUrlSelected] = useState(!!apiConfiguration?.anthropicBaseUrl)
 	const [azureApiVersionSelected, setAzureApiVersionSelected] = useState(!!apiConfiguration?.azureApiVersion)
 	const [openRouterBaseUrlSelected, setOpenRouterBaseUrlSelected] = useState(!!apiConfiguration?.openRouterBaseUrl)
+	const [googleGeminiBaseUrlSelected, setGoogleGeminiBaseUrlSelected] = useState(
+		!!apiConfiguration?.googleGeminiBaseUrl,
+	)
 	const [isDescriptionExpanded, setIsDescriptionExpanded] = useState(false)
 
 	const noTransform = <T,>(value: T) => value
@@ -646,6 +649,28 @@ const ApiOptions = ({
 							Get Gemini API Key
 						</VSCodeButtonLink>
 					)}
+					<div>
+						<Checkbox
+							checked={googleGeminiBaseUrlSelected}
+							onChange={(checked: boolean) => {
+								setGoogleGeminiBaseUrlSelected(checked)
+
+								if (!checked) {
+									setApiConfigurationField("googleGeminiBaseUrl", "")
+								}
+							}}>
+							Use custom base URL
+						</Checkbox>
+						{googleGeminiBaseUrlSelected && (
+							<VSCodeTextField
+								value={apiConfiguration?.googleGeminiBaseUrl || ""}
+								type="url"
+								onInput={handleInputChange("googleGeminiBaseUrl")}
+								placeholder="https://generativelanguage.googleapis.com"
+								className="w-full mt-1"
+							/>
+						)}
+					</div>
 				</>
 			)}