copilot.ts 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327
  1. import type { Hooks, PluginInput } from "@opencode-ai/plugin"
  2. import { Installation } from "@/installation"
  3. import { iife } from "@/util/iife"
  4. const CLIENT_ID = "Ov23li8tweQw6odWQebz"
  5. // Add a small safety buffer when polling to avoid hitting the server
  6. // slightly too early due to clock skew / timer drift.
  7. const OAUTH_POLLING_SAFETY_MARGIN_MS = 3000 // 3 seconds
  8. function normalizeDomain(url: string) {
  9. return url.replace(/^https?:\/\//, "").replace(/\/$/, "")
  10. }
  11. function getUrls(domain: string) {
  12. return {
  13. DEVICE_CODE_URL: `https://${domain}/login/device/code`,
  14. ACCESS_TOKEN_URL: `https://${domain}/login/oauth/access_token`,
  15. }
  16. }
  17. export async function CopilotAuthPlugin(input: PluginInput): Promise<Hooks> {
  18. const sdk = input.client
  19. return {
  20. auth: {
  21. provider: "github-copilot",
  22. async loader(getAuth, provider) {
  23. const info = await getAuth()
  24. if (!info || info.type !== "oauth") return {}
  25. const enterpriseUrl = info.enterpriseUrl
  26. const baseURL = enterpriseUrl ? `https://copilot-api.${normalizeDomain(enterpriseUrl)}` : undefined
  27. if (provider && provider.models) {
  28. for (const model of Object.values(provider.models)) {
  29. model.cost = {
  30. input: 0,
  31. output: 0,
  32. cache: {
  33. read: 0,
  34. write: 0,
  35. },
  36. }
  37. // TODO: re-enable once messages api has higher rate limits
  38. // TODO: move some of this hacky-ness to models.dev presets once we have better grasp of things here...
  39. // const base = baseURL ?? model.api.url
  40. // const claude = model.id.includes("claude")
  41. // const url = iife(() => {
  42. // if (!claude) return base
  43. // if (base.endsWith("/v1")) return base
  44. // if (base.endsWith("/")) return `${base}v1`
  45. // return `${base}/v1`
  46. // })
  47. // model.api.url = url
  48. // model.api.npm = claude ? "@ai-sdk/anthropic" : "@ai-sdk/github-copilot"
  49. model.api.npm = "@ai-sdk/github-copilot"
  50. }
  51. }
  52. return {
  53. baseURL,
  54. apiKey: "",
  55. async fetch(request: RequestInfo | URL, init?: RequestInit) {
  56. const info = await getAuth()
  57. if (info.type !== "oauth") return fetch(request, init)
  58. const url = request instanceof URL ? request.href : request.toString()
  59. const { isVision, isAgent } = iife(() => {
  60. try {
  61. const body = typeof init?.body === "string" ? JSON.parse(init.body) : init?.body
  62. // Completions API
  63. if (body?.messages && url.includes("completions")) {
  64. const last = body.messages[body.messages.length - 1]
  65. return {
  66. isVision: body.messages.some(
  67. (msg: any) =>
  68. Array.isArray(msg.content) && msg.content.some((part: any) => part.type === "image_url"),
  69. ),
  70. isAgent: last?.role !== "user",
  71. }
  72. }
  73. // Responses API
  74. if (body?.input) {
  75. const last = body.input[body.input.length - 1]
  76. return {
  77. isVision: body.input.some(
  78. (item: any) =>
  79. Array.isArray(item?.content) && item.content.some((part: any) => part.type === "input_image"),
  80. ),
  81. isAgent: last?.role !== "user",
  82. }
  83. }
  84. // Messages API
  85. if (body?.messages) {
  86. const last = body.messages[body.messages.length - 1]
  87. const hasNonToolCalls =
  88. Array.isArray(last?.content) && last.content.some((part: any) => part?.type !== "tool_result")
  89. return {
  90. isVision: body.messages.some(
  91. (item: any) =>
  92. Array.isArray(item?.content) &&
  93. item.content.some(
  94. (part: any) =>
  95. part?.type === "image" ||
  96. // images can be nested inside tool_result content
  97. (part?.type === "tool_result" &&
  98. Array.isArray(part?.content) &&
  99. part.content.some((nested: any) => nested?.type === "image")),
  100. ),
  101. ),
  102. isAgent: !(last?.role === "user" && hasNonToolCalls),
  103. }
  104. }
  105. } catch {}
  106. return { isVision: false, isAgent: false }
  107. })
  108. const headers: Record<string, string> = {
  109. "x-initiator": isAgent ? "agent" : "user",
  110. ...(init?.headers as Record<string, string>),
  111. "User-Agent": `opencode/${Installation.VERSION}`,
  112. Authorization: `Bearer ${info.refresh}`,
  113. "Openai-Intent": "conversation-edits",
  114. }
  115. if (isVision) {
  116. headers["Copilot-Vision-Request"] = "true"
  117. }
  118. delete headers["x-api-key"]
  119. delete headers["authorization"]
  120. return fetch(request, {
  121. ...init,
  122. headers,
  123. })
  124. },
  125. }
  126. },
  127. methods: [
  128. {
  129. type: "oauth",
  130. label: "Login with GitHub Copilot",
  131. prompts: [
  132. {
  133. type: "select",
  134. key: "deploymentType",
  135. message: "Select GitHub deployment type",
  136. options: [
  137. {
  138. label: "GitHub.com",
  139. value: "github.com",
  140. hint: "Public",
  141. },
  142. {
  143. label: "GitHub Enterprise",
  144. value: "enterprise",
  145. hint: "Data residency or self-hosted",
  146. },
  147. ],
  148. },
  149. {
  150. type: "text",
  151. key: "enterpriseUrl",
  152. message: "Enter your GitHub Enterprise URL or domain",
  153. placeholder: "company.ghe.com or https://company.ghe.com",
  154. condition: (inputs) => inputs.deploymentType === "enterprise",
  155. validate: (value) => {
  156. if (!value) return "URL or domain is required"
  157. try {
  158. const url = value.includes("://") ? new URL(value) : new URL(`https://${value}`)
  159. if (!url.hostname) return "Please enter a valid URL or domain"
  160. return undefined
  161. } catch {
  162. return "Please enter a valid URL (e.g., company.ghe.com or https://company.ghe.com)"
  163. }
  164. },
  165. },
  166. ],
  167. async authorize(inputs = {}) {
  168. const deploymentType = inputs.deploymentType || "github.com"
  169. let domain = "github.com"
  170. let actualProvider = "github-copilot"
  171. if (deploymentType === "enterprise") {
  172. const enterpriseUrl = inputs.enterpriseUrl
  173. domain = normalizeDomain(enterpriseUrl!)
  174. actualProvider = "github-copilot-enterprise"
  175. }
  176. const urls = getUrls(domain)
  177. const deviceResponse = await fetch(urls.DEVICE_CODE_URL, {
  178. method: "POST",
  179. headers: {
  180. Accept: "application/json",
  181. "Content-Type": "application/json",
  182. "User-Agent": `opencode/${Installation.VERSION}`,
  183. },
  184. body: JSON.stringify({
  185. client_id: CLIENT_ID,
  186. scope: "read:user",
  187. }),
  188. })
  189. if (!deviceResponse.ok) {
  190. throw new Error("Failed to initiate device authorization")
  191. }
  192. const deviceData = (await deviceResponse.json()) as {
  193. verification_uri: string
  194. user_code: string
  195. device_code: string
  196. interval: number
  197. }
  198. return {
  199. url: deviceData.verification_uri,
  200. instructions: `Enter code: ${deviceData.user_code}`,
  201. method: "auto" as const,
  202. async callback() {
  203. while (true) {
  204. const response = await fetch(urls.ACCESS_TOKEN_URL, {
  205. method: "POST",
  206. headers: {
  207. Accept: "application/json",
  208. "Content-Type": "application/json",
  209. "User-Agent": `opencode/${Installation.VERSION}`,
  210. },
  211. body: JSON.stringify({
  212. client_id: CLIENT_ID,
  213. device_code: deviceData.device_code,
  214. grant_type: "urn:ietf:params:oauth:grant-type:device_code",
  215. }),
  216. })
  217. if (!response.ok) return { type: "failed" as const }
  218. const data = (await response.json()) as {
  219. access_token?: string
  220. error?: string
  221. interval?: number
  222. }
  223. if (data.access_token) {
  224. const result: {
  225. type: "success"
  226. refresh: string
  227. access: string
  228. expires: number
  229. provider?: string
  230. enterpriseUrl?: string
  231. } = {
  232. type: "success",
  233. refresh: data.access_token,
  234. access: data.access_token,
  235. expires: 0,
  236. }
  237. if (actualProvider === "github-copilot-enterprise") {
  238. result.provider = "github-copilot-enterprise"
  239. result.enterpriseUrl = domain
  240. }
  241. return result
  242. }
  243. if (data.error === "authorization_pending") {
  244. await Bun.sleep(deviceData.interval * 1000 + OAUTH_POLLING_SAFETY_MARGIN_MS)
  245. continue
  246. }
  247. if (data.error === "slow_down") {
  248. // Based on the RFC spec, we must add 5 seconds to our current polling interval.
  249. // (See https://www.rfc-editor.org/rfc/rfc8628#section-3.5)
  250. let newInterval = (deviceData.interval + 5) * 1000
  251. // GitHub OAuth API may return the new interval in seconds in the response.
  252. // We should try to use that if provided with safety margin.
  253. const serverInterval = data.interval
  254. if (serverInterval && typeof serverInterval === "number" && serverInterval > 0) {
  255. newInterval = serverInterval * 1000
  256. }
  257. await Bun.sleep(newInterval + OAUTH_POLLING_SAFETY_MARGIN_MS)
  258. continue
  259. }
  260. if (data.error) return { type: "failed" as const }
  261. await Bun.sleep(deviceData.interval * 1000 + OAUTH_POLLING_SAFETY_MARGIN_MS)
  262. continue
  263. }
  264. },
  265. }
  266. },
  267. },
  268. ],
  269. },
  270. "chat.headers": async (incoming, output) => {
  271. if (!incoming.model.providerID.includes("github-copilot")) return
  272. if (incoming.model.api.npm === "@ai-sdk/anthropic") {
  273. output.headers["anthropic-beta"] = "interleaved-thinking-2025-05-14"
  274. }
  275. const session = await sdk.session
  276. .get({
  277. path: {
  278. id: incoming.sessionID,
  279. },
  280. query: {
  281. directory: input.directory,
  282. },
  283. throwOnError: true,
  284. })
  285. .catch(() => undefined)
  286. if (!session || !session.data.parentID) return
  287. // mark subagent sessions as agent initiated matching standard that other copilot tools have
  288. output.headers["x-initiator"] = "agent"
  289. },
  290. }
  291. }