فهرست منبع

feat: add new provider plugin hook for resolving models and sync models from github models endpoint (falls back to models.dev) (#20533)

Aiden Cline 2 هفته پیش
والد
کامیت
1fcfb69bf7

+ 41 - 31
packages/opencode/src/plugin/copilot.ts → packages/opencode/src/plugin/github-copilot/copilot.ts

@@ -1,7 +1,12 @@
 import type { Hooks, PluginInput } from "@opencode-ai/plugin"
+import type { Model } from "@opencode-ai/sdk/v2"
 import { Installation } from "@/installation"
 import { iife } from "@/util/iife"
+import { Log } from "../../util/log"
 import { setTimeout as sleep } from "node:timers/promises"
+import { CopilotModels } from "./models"
+
+const log = Log.create({ service: "plugin.copilot" })
 
 const CLIENT_ID = "Ov23li8tweQw6odWQebz"
 // Add a small safety buffer when polling to avoid hitting the server
@@ -18,45 +23,50 @@ function getUrls(domain: string) {
   }
 }
 
+function base(enterpriseUrl?: string) {
+  return enterpriseUrl ? `https://copilot-api.${normalizeDomain(enterpriseUrl)}` : "https://api.githubcopilot.com"
+}
+
+function fix(model: Model): Model {
+  return {
+    ...model,
+    api: {
+      ...model.api,
+      npm: "@ai-sdk/github-copilot",
+    },
+  }
+}
+
 export async function CopilotAuthPlugin(input: PluginInput): Promise<Hooks> {
   const sdk = input.client
   return {
+    provider: {
+      id: "github-copilot",
+      async models(provider, ctx) {
+        if (ctx.auth?.type !== "oauth") {
+          return Object.fromEntries(Object.entries(provider.models).map(([id, model]) => [id, fix(model)]))
+        }
+
+        return CopilotModels.get(
+          base(ctx.auth.enterpriseUrl),
+          {
+            Authorization: `Bearer ${ctx.auth.refresh}`,
+            "User-Agent": `opencode/${Installation.VERSION}`,
+          },
+          provider.models,
+        ).catch((error) => {
+          log.error("failed to fetch copilot models", { error })
+          return Object.fromEntries(Object.entries(provider.models).map(([id, model]) => [id, fix(model)]))
+        })
+      },
+    },
     auth: {
       provider: "github-copilot",
-      async loader(getAuth, provider) {
+      async loader(getAuth) {
         const info = await getAuth()
         if (!info || info.type !== "oauth") return {}
 
-        const enterpriseUrl = info.enterpriseUrl
-        const baseURL = enterpriseUrl ? `https://copilot-api.${normalizeDomain(enterpriseUrl)}` : undefined
-
-        if (provider && provider.models) {
-          for (const model of Object.values(provider.models)) {
-            model.cost = {
-              input: 0,
-              output: 0,
-              cache: {
-                read: 0,
-                write: 0,
-              },
-            }
-
-            // TODO: re-enable once messages api has higher rate limits
-            // TODO: move some of this hacky-ness to models.dev presets once we have better grasp of things here...
-            // const base = baseURL ?? model.api.url
-            // const claude = model.id.includes("claude")
-            // const url = iife(() => {
-            //   if (!claude) return base
-            //   if (base.endsWith("/v1")) return base
-            //   if (base.endsWith("/")) return `${base}v1`
-            //   return `${base}/v1`
-            // })
-
-            // model.api.url = url
-            // model.api.npm = claude ? "@ai-sdk/anthropic" : "@ai-sdk/github-copilot"
-            model.api.npm = "@ai-sdk/github-copilot"
-          }
-        }
+        const baseURL = base(info.enterpriseUrl)
 
         return {
           baseURL,

+ 143 - 0
packages/opencode/src/plugin/github-copilot/models.ts

@@ -0,0 +1,143 @@
+import { z } from "zod"
+import type { Model } from "@opencode-ai/sdk/v2"
+
+export namespace CopilotModels {
+  export const schema = z.object({
+    data: z.array(
+      z.object({
+        model_picker_enabled: z.boolean(),
+        id: z.string(),
+        name: z.string(),
+        // every version looks like: `{model.id}-YYYY-MM-DD`
+        version: z.string(),
+        supported_endpoints: z.array(z.string()).optional(),
+        capabilities: z.object({
+          family: z.string(),
+          limits: z.object({
+            max_context_window_tokens: z.number(),
+            max_output_tokens: z.number(),
+            max_prompt_tokens: z.number(),
+            vision: z
+              .object({
+                max_prompt_image_size: z.number(),
+                max_prompt_images: z.number(),
+                supported_media_types: z.array(z.string()),
+              })
+              .optional(),
+          }),
+          supports: z.object({
+            adaptive_thinking: z.boolean().optional(),
+            max_thinking_budget: z.number().optional(),
+            min_thinking_budget: z.number().optional(),
+            reasoning_effort: z.array(z.string()).optional(),
+            streaming: z.boolean(),
+            structured_outputs: z.boolean().optional(),
+            tool_calls: z.boolean(),
+            vision: z.boolean().optional(),
+          }),
+        }),
+      }),
+    ),
+  })
+
+  type Item = z.infer<typeof schema>["data"][number]
+
+  function build(key: string, remote: Item, url: string, prev?: Model): Model {
+    const reasoning =
+      !!remote.capabilities.supports.adaptive_thinking ||
+      !!remote.capabilities.supports.reasoning_effort?.length ||
+      remote.capabilities.supports.max_thinking_budget !== undefined ||
+      remote.capabilities.supports.min_thinking_budget !== undefined
+    const image =
+      (remote.capabilities.supports.vision ?? false) ||
+      (remote.capabilities.limits.vision?.supported_media_types ?? []).some((item) => item.startsWith("image/"))
+
+    return {
+      id: key,
+      providerID: "github-copilot",
+      api: {
+        id: remote.id,
+        url,
+        npm: "@ai-sdk/github-copilot",
+      },
+      // API response wins
+      status: "active",
+      limit: {
+        context: remote.capabilities.limits.max_context_window_tokens,
+        input: remote.capabilities.limits.max_prompt_tokens,
+        output: remote.capabilities.limits.max_output_tokens,
+      },
+      capabilities: {
+        temperature: prev?.capabilities.temperature ?? true,
+        reasoning: prev?.capabilities.reasoning ?? reasoning,
+        attachment: prev?.capabilities.attachment ?? true,
+        toolcall: remote.capabilities.supports.tool_calls,
+        input: {
+          text: true,
+          audio: false,
+          image,
+          video: false,
+          pdf: false,
+        },
+        output: {
+          text: true,
+          audio: false,
+          image: false,
+          video: false,
+          pdf: false,
+        },
+        interleaved: false,
+      },
+      // existing wins
+      family: prev?.family ?? remote.capabilities.family,
+      name: prev?.name ?? remote.name,
+      cost: {
+        input: 0,
+        output: 0,
+        cache: { read: 0, write: 0 },
+      },
+      options: prev?.options ?? {},
+      headers: prev?.headers ?? {},
+      release_date:
+        prev?.release_date ??
+        (remote.version.startsWith(`${remote.id}-`) ? remote.version.slice(remote.id.length + 1) : remote.version),
+      variants: prev?.variants ?? {},
+    }
+  }
+
+  export async function get(
+    baseURL: string,
+    headers: HeadersInit = {},
+    existing: Record<string, Model> = {},
+  ): Promise<Record<string, Model>> {
+    const data = await fetch(`${baseURL}/models`, {
+      headers,
+    }).then(async (res) => {
+      if (!res.ok) {
+        throw new Error(`Failed to fetch models: ${res.status}`)
+      }
+      return schema.parse(await res.json())
+    })
+
+    const result = { ...existing }
+    const remote = new Map(data.data.filter((m) => m.model_picker_enabled).map((m) => [m.id, m] as const))
+
+    // prune existing models whose api.id isn't in the endpoint response
+    for (const [key, model] of Object.entries(result)) {
+      const m = remote.get(model.api.id)
+      if (!m) {
+        delete result[key]
+        continue
+      }
+      result[key] = build(key, m, baseURL, model)
+    }
+
+    // add new endpoint models not already keyed in result
+    for (const [id, m] of remote) {
+      if (id in result) continue
+      result[id] = build(id, m, baseURL)
+    }
+
+    return result
+  }
+}

+ 1 - 1
packages/opencode/src/plugin/index.ts

@@ -7,7 +7,7 @@ import { Flag } from "../flag/flag"
 import { CodexAuthPlugin } from "./codex"
 import { Session } from "../session"
 import { NamedError } from "@opencode-ai/util/error"
-import { CopilotAuthPlugin } from "./copilot"
+import { CopilotAuthPlugin } from "./github-copilot/copilot"
 import { gitlabAuthPlugin as GitlabAuthPlugin } from "opencode-gitlab-auth"
 import { PoeAuthPlugin } from "opencode-poe-auth"
 import { Effect, Layer, ServiceMap, Stream } from "effect"

+ 43 - 16
packages/opencode/src/provider/provider.ts

@@ -1178,6 +1178,49 @@ export namespace Provider {
             mergeProvider(providerID, partial)
           }
 
+          const gitlab = ProviderID.make("gitlab")
+          if (discoveryLoaders[gitlab] && providers[gitlab] && isProviderAllowed(gitlab)) {
+            yield* Effect.promise(async () => {
+              try {
+                const discovered = await discoveryLoaders[gitlab]()
+                for (const [modelID, model] of Object.entries(discovered)) {
+                  if (!providers[gitlab].models[modelID]) {
+                    providers[gitlab].models[modelID] = model
+                  }
+                }
+              } catch (e) {
+                log.warn("state discovery error", { id: "gitlab", error: e })
+              }
+            })
+          }
+
+          for (const hook of plugins) {
+            const p = hook.provider
+            const models = p?.models
+            if (!p || !models) continue
+
+            const providerID = ProviderID.make(p.id)
+            if (disabled.has(providerID)) continue
+
+            const provider = providers[providerID]
+            if (!provider) continue
+            const pluginAuth = yield* auth.get(providerID).pipe(Effect.orDie)
+
+            provider.models = yield* Effect.promise(async () => {
+              const next = await models(provider, { auth: pluginAuth })
+              return Object.fromEntries(
+                Object.entries(next).map(([id, model]) => [
+                  id,
+                  {
+                    ...model,
+                    id: ModelID.make(id),
+                    providerID,
+                  },
+                ]),
+              )
+            })
+          }
+
           for (const [id, provider] of Object.entries(providers)) {
             const providerID = ProviderID.make(id)
             if (!isProviderAllowed(providerID)) {
@@ -1222,22 +1265,6 @@ export namespace Provider {
             log.info("found", { providerID })
           }
 
-          const gitlab = ProviderID.make("gitlab")
-          if (discoveryLoaders[gitlab] && providers[gitlab]) {
-            yield* Effect.promise(async () => {
-              try {
-                const discovered = await discoveryLoaders[gitlab]()
-                for (const [modelID, model] of Object.entries(discovered)) {
-                  if (!providers[gitlab].models[modelID]) {
-                    providers[gitlab].models[modelID] = model
-                  }
-                }
-              } catch (e) {
-                log.warn("state discovery error", { id: "gitlab", error: e })
-              }
-            })
-          }
-
           return {
             models: languages,
             providers,

+ 117 - 0
packages/opencode/test/plugin/github-copilot-models.test.ts

@@ -0,0 +1,117 @@
+import { afterEach, expect, mock, test } from "bun:test"
+import { CopilotModels } from "@/plugin/github-copilot/models"
+
+const originalFetch = globalThis.fetch
+
+afterEach(() => {
+  globalThis.fetch = originalFetch
+})
+
+test("preserves temperature support from existing provider models", async () => {
+  globalThis.fetch = mock(() =>
+    Promise.resolve(
+      new Response(
+        JSON.stringify({
+          data: [
+            {
+              model_picker_enabled: true,
+              id: "gpt-4o",
+              name: "GPT-4o",
+              version: "gpt-4o-2024-05-13",
+              capabilities: {
+                family: "gpt",
+                limits: {
+                  max_context_window_tokens: 64000,
+                  max_output_tokens: 16384,
+                  max_prompt_tokens: 64000,
+                },
+                supports: {
+                  streaming: true,
+                  tool_calls: true,
+                },
+              },
+            },
+            {
+              model_picker_enabled: true,
+              id: "brand-new",
+              name: "Brand New",
+              version: "brand-new-2026-04-01",
+              capabilities: {
+                family: "test",
+                limits: {
+                  max_context_window_tokens: 32000,
+                  max_output_tokens: 8192,
+                  max_prompt_tokens: 32000,
+                },
+                supports: {
+                  streaming: true,
+                  tool_calls: false,
+                },
+              },
+            },
+          ],
+        }),
+        { status: 200 },
+      ),
+    ),
+  ) as unknown as typeof fetch
+
+  const models = await CopilotModels.get(
+    "https://api.githubcopilot.com",
+    {},
+    {
+      "gpt-4o": {
+        id: "gpt-4o",
+        providerID: "github-copilot",
+        api: {
+          id: "gpt-4o",
+          url: "https://api.githubcopilot.com",
+          npm: "@ai-sdk/openai-compatible",
+        },
+        name: "GPT-4o",
+        family: "gpt",
+        capabilities: {
+          temperature: true,
+          reasoning: false,
+          attachment: true,
+          toolcall: true,
+          input: {
+            text: true,
+            audio: false,
+            image: true,
+            video: false,
+            pdf: false,
+          },
+          output: {
+            text: true,
+            audio: false,
+            image: false,
+            video: false,
+            pdf: false,
+          },
+          interleaved: false,
+        },
+        cost: {
+          input: 0,
+          output: 0,
+          cache: {
+            read: 0,
+            write: 0,
+          },
+        },
+        limit: {
+          context: 64000,
+          output: 16384,
+        },
+        options: {},
+        headers: {},
+        release_date: "2024-05-13",
+        variants: {},
+        status: "active",
+      },
+    },
+  )
+
+  expect(models["gpt-4o"].capabilities.temperature).toBe(true)
+  expect(models["brand-new"].capabilities.temperature).toBe(true)
+})

+ 11 - 0
packages/plugin/src/index.ts

@@ -11,6 +11,7 @@ import type {
   Auth,
   Config as SDKConfig,
 } from "@opencode-ai/sdk"
+import type { Provider as ProviderV2, Model as ModelV2 } from "@opencode-ai/sdk/v2"
 
 import type { BunShell } from "./shell.js"
 import { type ToolDefinition } from "./tool.js"
@@ -173,6 +174,15 @@ export type AuthOAuthResult = { url: string; instructions: string } & (
     }
 )
 
+export type ProviderHookContext = {
+  auth?: Auth
+}
+
+export type ProviderHook = {
+  id: string
+  models?: (provider: ProviderV2, ctx: ProviderHookContext) => Promise<Record<string, ModelV2>>
+}
+
 /** @deprecated Use AuthOAuthResult instead. */
 export type AuthOuathResult = AuthOAuthResult
 
@@ -183,6 +193,7 @@ export interface Hooks {
     [key: string]: ToolDefinition
   }
   auth?: AuthHook
+  provider?: ProviderHook
   /**
    * Called when a new message is received
    */