Przeglądaj źródła

refactor(provider): effectify ProviderAuthService (#17227)

Kit Langton 1 miesiąc temu
rodzic
commit
dd68b85f58

+ 4 - 2
packages/opencode/src/project/instance.ts

@@ -1,3 +1,4 @@
+import { Effect } from "effect"
 import { Log } from "@/util/log"
 import { Context } from "../util/context"
 import { Project } from "./project"
@@ -5,6 +6,7 @@ import { State } from "./state"
 import { iife } from "@/util/iife"
 import { GlobalBus } from "@/bus/global"
 import { Filesystem } from "@/util/filesystem"
+import { InstanceState } from "@/util/instance-state"
 
 interface Context {
   directory: string
@@ -106,7 +108,7 @@ export const Instance = {
   async reload(input: { directory: string; init?: () => Promise<any>; project?: Project.Info; worktree?: string }) {
     const directory = Filesystem.resolve(input.directory)
     Log.Default.info("reloading instance", { directory })
-    await State.dispose(directory)
+    await Promise.all([State.dispose(directory), Effect.runPromise(InstanceState.dispose(directory))])
     cache.delete(directory)
     const next = track(directory, boot({ ...input, directory }))
     emit(directory)
@@ -114,7 +116,7 @@ export const Instance = {
   },
   async dispose() {
     Log.Default.info("disposing instance", { directory: Instance.directory })
-    await State.dispose(Instance.directory)
+    await Promise.all([State.dispose(Instance.directory), Effect.runPromise(InstanceState.dispose(Instance.directory))])
     cache.delete(Instance.directory)
     emit(Instance.directory)
   },

+ 169 - 0
packages/opencode/src/provider/auth-service.ts

@@ -0,0 +1,169 @@
+import { Effect, Layer, Record, ServiceMap, Struct } from "effect"
+import { Instance } from "@/project/instance"
+import { Plugin } from "../plugin"
+import { filter, fromEntries, map, pipe } from "remeda"
+import type { AuthOuathResult } from "@opencode-ai/plugin"
+import { NamedError } from "@opencode-ai/util/error"
+import * as Auth from "@/auth/service"
+import { InstanceState } from "@/util/instance-state"
+import { ProviderID } from "./schema"
+import z from "zod"
+
+export const Method = z
+  .object({
+    type: z.union([z.literal("oauth"), z.literal("api")]),
+    label: z.string(),
+  })
+  .meta({
+    ref: "ProviderAuthMethod",
+  })
+export type Method = z.infer<typeof Method>
+
+export const Authorization = z
+  .object({
+    url: z.string(),
+    method: z.union([z.literal("auto"), z.literal("code")]),
+    instructions: z.string(),
+  })
+  .meta({
+    ref: "ProviderAuthAuthorization",
+  })
+export type Authorization = z.infer<typeof Authorization>
+
+export const OauthMissing = NamedError.create(
+  "ProviderAuthOauthMissing",
+  z.object({
+    providerID: ProviderID.zod,
+  }),
+)
+
+export const OauthCodeMissing = NamedError.create(
+  "ProviderAuthOauthCodeMissing",
+  z.object({
+    providerID: ProviderID.zod,
+  }),
+)
+
+export const OauthCallbackFailed = NamedError.create("ProviderAuthOauthCallbackFailed", z.object({}))
+
+export type ProviderAuthError =
+  | Auth.AuthServiceError
+  | InstanceType<typeof OauthMissing>
+  | InstanceType<typeof OauthCodeMissing>
+  | InstanceType<typeof OauthCallbackFailed>
+
+export namespace ProviderAuthService {
+  export interface Service {
+    /** Get available auth methods for each provider (e.g. OAuth, API key). */
+    readonly methods: () => Effect.Effect<Record<string, Method[]>>
+
+    /** Start an OAuth authorization flow for a provider. Returns the URL to redirect to. */
+    readonly authorize: (input: { providerID: ProviderID; method: number }) => Effect.Effect<Authorization | undefined>
+
+    /** Complete an OAuth flow after the user has authorized. Exchanges the code/callback for credentials. */
+    readonly callback: (input: {
+      providerID: ProviderID
+      method: number
+      code?: string
+    }) => Effect.Effect<void, ProviderAuthError>
+
+    /** Set an API key directly for a provider (no OAuth flow). */
+    readonly api: (input: { providerID: ProviderID; key: string }) => Effect.Effect<void, Auth.AuthServiceError>
+  }
+}
+
+export class ProviderAuthService extends ServiceMap.Service<ProviderAuthService, ProviderAuthService.Service>()(
+  "@opencode/ProviderAuth",
+) {
+  static readonly layer = Layer.effect(
+    ProviderAuthService,
+    Effect.gen(function* () {
+      const auth = yield* Auth.AuthService
+      const state = yield* InstanceState.make({
+        lookup: () =>
+          Effect.promise(async () => {
+            const methods = pipe(
+              await Plugin.list(),
+              filter((x) => x.auth?.provider !== undefined),
+              map((x) => [x.auth!.provider, x.auth!] as const),
+              fromEntries(),
+            )
+            return { methods, pending: new Map<ProviderID, AuthOuathResult>() }
+          }),
+      })
+
+      const methods = Effect.fn("ProviderAuthService.methods")(function* () {
+        const x = yield* InstanceState.get(state)
+        return Record.map(x.methods, (y) => y.methods.map((z): Method => Struct.pick(z, ["type", "label"])))
+      })
+
+      const authorize = Effect.fn("ProviderAuthService.authorize")(function* (input: {
+        providerID: ProviderID
+        method: number
+      }) {
+        const s = yield* InstanceState.get(state)
+        const method = s.methods[input.providerID].methods[input.method]
+        if (method.type !== "oauth") return
+        const result = yield* Effect.promise(() => method.authorize())
+        s.pending.set(input.providerID, result)
+        return {
+          url: result.url,
+          method: result.method,
+          instructions: result.instructions,
+        }
+      })
+
+      const callback = Effect.fn("ProviderAuthService.callback")(function* (input: {
+        providerID: ProviderID
+        method: number
+        code?: string
+      }) {
+        const s = yield* InstanceState.get(state)
+        const match = s.pending.get(input.providerID)
+        if (!match) return yield* Effect.fail(new OauthMissing({ providerID: input.providerID }))
+
+        if (match.method === "code" && !input.code)
+          return yield* Effect.fail(new OauthCodeMissing({ providerID: input.providerID }))
+
+        const result = yield* Effect.promise(() =>
+          match.method === "code" ? match.callback(input.code!) : match.callback(),
+        )
+
+        if (!result || result.type !== "success") return yield* Effect.fail(new OauthCallbackFailed({}))
+
+        if ("key" in result) {
+          yield* auth.set(input.providerID, {
+            type: "api",
+            key: result.key,
+          })
+        }
+
+        if ("refresh" in result) {
+          yield* auth.set(input.providerID, {
+            type: "oauth",
+            access: result.access,
+            refresh: result.refresh,
+            expires: result.expires,
+            ...(result.accountId ? { accountId: result.accountId } : {}),
+          })
+        }
+      })
+
+      const api = Effect.fn("ProviderAuthService.api")(function* (input: { providerID: ProviderID; key: string }) {
+        yield* auth.set(input.providerID, {
+          type: "api",
+          key: input.key,
+        })
+      })
+
+      return ProviderAuthService.of({
+        methods,
+        authorize,
+        callback,
+        api,
+      })
+    }),
+  )
+
+  static readonly defaultLayer = ProviderAuthService.layer.pipe(Layer.provide(Auth.AuthService.defaultLayer))
+}

+ 23 - 115
packages/opencode/src/provider/auth.ts

@@ -1,75 +1,36 @@
-import { Instance } from "@/project/instance"
-import { Plugin } from "../plugin"
-import { map, filter, pipe, fromEntries, mapValues } from "remeda"
+import { Effect, ManagedRuntime } from "effect"
 import z from "zod"
+
 import { fn } from "@/util/fn"
-import type { AuthOuathResult, Hooks } from "@opencode-ai/plugin"
-import { NamedError } from "@opencode-ai/util/error"
-import { Auth } from "@/auth"
+import * as S from "./auth-service"
 import { ProviderID } from "./schema"
 
-export namespace ProviderAuth {
-  const state = Instance.state(async () => {
-    const methods = pipe(
-      await Plugin.list(),
-      filter((x) => x.auth?.provider !== undefined),
-      map((x) => [x.auth!.provider, x.auth!] as const),
-      fromEntries(),
-    )
-    return { methods, pending: {} as Record<string, AuthOuathResult> }
-  })
+// Separate runtime: ProviderAuthService can't join the shared runtime because
+// runtime.ts → auth-service.ts → provider/auth.ts creates a circular import.
+// AuthService is stateless file I/O so the duplicate instance is harmless.
+const rt = ManagedRuntime.make(S.ProviderAuthService.defaultLayer)
 
-  export const Method = z
-    .object({
-      type: z.union([z.literal("oauth"), z.literal("api")]),
-      label: z.string(),
-    })
-    .meta({
-      ref: "ProviderAuthMethod",
-    })
-  export type Method = z.infer<typeof Method>
+function runPromise<A>(f: (service: S.ProviderAuthService.Service) => Effect.Effect<A, S.ProviderAuthError>) {
+  return rt.runPromise(S.ProviderAuthService.use(f))
+}
+
+export namespace ProviderAuth {
+  export const Method = S.Method
+  export type Method = S.Method
 
   export async function methods() {
-    const s = await state().then((x) => x.methods)
-    return mapValues(s, (x) =>
-      x.methods.map(
-        (y): Method => ({
-          type: y.type,
-          label: y.label,
-        }),
-      ),
-    )
+    return runPromise((service) => service.methods())
   }
 
-  export const Authorization = z
-    .object({
-      url: z.string(),
-      method: z.union([z.literal("auto"), z.literal("code")]),
-      instructions: z.string(),
-    })
-    .meta({
-      ref: "ProviderAuthAuthorization",
-    })
-  export type Authorization = z.infer<typeof Authorization>
+  export const Authorization = S.Authorization
+  export type Authorization = S.Authorization
 
   export const authorize = fn(
     z.object({
       providerID: ProviderID.zod,
       method: z.number(),
     }),
-    async (input): Promise<Authorization | undefined> => {
-      const auth = await state().then((s) => s.methods[input.providerID])
-      const method = auth.methods[input.method]
-      if (method.type === "oauth") {
-        const result = await method.authorize()
-        await state().then((s) => (s.pending[input.providerID] = result))
-        return {
-          url: result.url,
-          method: result.method,
-          instructions: result.instructions,
-        }
-      }
-    },
+    async (input): Promise<Authorization | undefined> => runPromise((service) => service.authorize(input)),
   )
 
   export const callback = fn(
@@ -78,44 +39,7 @@ export namespace ProviderAuth {
       method: z.number(),
       code: z.string().optional(),
     }),
-    async (input) => {
-      const match = await state().then((s) => s.pending[input.providerID])
-      if (!match) throw new OauthMissing({ providerID: input.providerID })
-      let result
-
-      if (match.method === "code") {
-        if (!input.code) throw new OauthCodeMissing({ providerID: input.providerID })
-        result = await match.callback(input.code)
-      }
-
-      if (match.method === "auto") {
-        result = await match.callback()
-      }
-
-      if (result?.type === "success") {
-        if ("key" in result) {
-          await Auth.set(input.providerID, {
-            type: "api",
-            key: result.key,
-          })
-        }
-        if ("refresh" in result) {
-          const info: Auth.Info = {
-            type: "oauth",
-            access: result.access,
-            refresh: result.refresh,
-            expires: result.expires,
-          }
-          if (result.accountId) {
-            info.accountId = result.accountId
-          }
-          await Auth.set(input.providerID, info)
-        }
-        return
-      }
-
-      throw new OauthCallbackFailed({})
-    },
+    async (input) => runPromise((service) => service.callback(input)),
   )
 
   export const api = fn(
@@ -123,26 +47,10 @@ export namespace ProviderAuth {
       providerID: ProviderID.zod,
       key: z.string(),
     }),
-    async (input) => {
-      await Auth.set(input.providerID, {
-        type: "api",
-        key: input.key,
-      })
-    },
-  )
-
-  export const OauthMissing = NamedError.create(
-    "ProviderAuthOauthMissing",
-    z.object({
-      providerID: ProviderID.zod,
-    }),
-  )
-  export const OauthCodeMissing = NamedError.create(
-    "ProviderAuthOauthCodeMissing",
-    z.object({
-      providerID: ProviderID.zod,
-    }),
+    async (input) => runPromise((service) => service.api(input)),
   )
 
-  export const OauthCallbackFailed = NamedError.create("ProviderAuthOauthCallbackFailed", z.object({}))
+  export import OauthMissing = S.OauthMissing
+  export import OauthCodeMissing = S.OauthCodeMissing
+  export import OauthCallbackFailed = S.OauthCallbackFailed
 }

+ 50 - 0
packages/opencode/src/util/instance-state.ts

@@ -0,0 +1,50 @@
+import { Effect, ScopedCache, Scope } from "effect"
+
+import { Instance } from "@/project/instance"
+
+const TypeId = Symbol.for("@opencode/InstanceState")
+
+type Task = (key: string) => Effect.Effect<void>
+
+const tasks = new Set<Task>()
+
+export namespace InstanceState {
+  export interface State<A, E = never, R = never> {
+    readonly [TypeId]: typeof TypeId
+    readonly cache: ScopedCache.ScopedCache<string, A, E, R>
+  }
+
+  export const make = <A, E = never, R = never>(input: {
+    lookup: (key: string) => Effect.Effect<A, E, R>
+    release?: (value: A, key: string) => Effect.Effect<void>
+  }): Effect.Effect<State<A, E, R>, never, R | Scope.Scope> =>
+    Effect.gen(function* () {
+      const cache = yield* ScopedCache.make<string, A, E, R>({
+        capacity: Number.POSITIVE_INFINITY,
+        lookup: (key) =>
+          Effect.acquireRelease(input.lookup(key), (value) => (input.release ? input.release(value, key) : Effect.void)),
+      })
+
+      const task: Task = (key) => ScopedCache.invalidate(cache, key)
+      tasks.add(task)
+      yield* Effect.addFinalizer(() => Effect.sync(() => void tasks.delete(task)))
+
+      return {
+        [TypeId]: TypeId,
+        cache,
+      }
+    })
+
+  export const get = <A, E, R>(self: State<A, E, R>) => ScopedCache.get(self.cache, Instance.directory)
+
+  export const has = <A, E, R>(self: State<A, E, R>) => ScopedCache.has(self.cache, Instance.directory)
+
+  export const invalidate = <A, E, R>(self: State<A, E, R>) =>
+    ScopedCache.invalidate(self.cache, Instance.directory)
+
+  export const dispose = (key: string) =>
+    Effect.all(
+      [...tasks].map((task) => task(key)),
+      { concurrency: "unbounded" },
+    )
+}

+ 115 - 0
packages/opencode/test/project/state.test.ts

@@ -0,0 +1,115 @@
+import { afterEach, expect, test } from "bun:test"
+
+import { Instance } from "../../src/project/instance"
+import { tmpdir } from "../fixture/fixture"
+
+afterEach(async () => {
+  await Instance.disposeAll()
+})
+
+test("Instance.state caches values for the same instance", async () => {
+  await using tmp = await tmpdir()
+  let n = 0
+  const state = Instance.state(() => ({ n: ++n }))
+
+  await Instance.provide({
+    directory: tmp.path,
+    fn: async () => {
+      const a = state()
+      const b = state()
+      expect(a).toBe(b)
+      expect(n).toBe(1)
+    },
+  })
+})
+
+test("Instance.state isolates values by directory", async () => {
+  await using a = await tmpdir()
+  await using b = await tmpdir()
+  let n = 0
+  const state = Instance.state(() => ({ n: ++n }))
+
+  const x = await Instance.provide({
+    directory: a.path,
+    fn: async () => state(),
+  })
+  const y = await Instance.provide({
+    directory: b.path,
+    fn: async () => state(),
+  })
+  const z = await Instance.provide({
+    directory: a.path,
+    fn: async () => state(),
+  })
+
+  expect(x).toBe(z)
+  expect(x).not.toBe(y)
+  expect(n).toBe(2)
+})
+
+test("Instance.state is disposed on instance reload", async () => {
+  await using tmp = await tmpdir()
+  const seen: string[] = []
+  let n = 0
+  const state = Instance.state(
+    () => ({ n: ++n }),
+    async (value) => {
+      seen.push(String(value.n))
+    },
+  )
+
+  const a = await Instance.provide({
+    directory: tmp.path,
+    fn: async () => state(),
+  })
+  await Instance.reload({ directory: tmp.path })
+  const b = await Instance.provide({
+    directory: tmp.path,
+    fn: async () => state(),
+  })
+
+  expect(a).not.toBe(b)
+  expect(seen).toEqual(["1"])
+})
+
+test("Instance.state is disposed on disposeAll", async () => {
+  await using a = await tmpdir()
+  await using b = await tmpdir()
+  const seen: string[] = []
+  const state = Instance.state(
+    () => ({ dir: Instance.directory }),
+    async (value) => {
+      seen.push(value.dir)
+    },
+  )
+
+  await Instance.provide({
+    directory: a.path,
+    fn: async () => state(),
+  })
+  await Instance.provide({
+    directory: b.path,
+    fn: async () => state(),
+  })
+  await Instance.disposeAll()
+
+  expect(seen.sort()).toEqual([a.path, b.path].sort())
+})
+
+test("Instance.state dedupes concurrent promise initialization", async () => {
+  await using tmp = await tmpdir()
+  let n = 0
+  const state = Instance.state(async () => {
+    n += 1
+    await Bun.sleep(10)
+    return { n }
+  })
+
+  const [a, b] = await Instance.provide({
+    directory: tmp.path,
+    fn: async () => Promise.all([state(), state()]),
+  })
+
+  expect(a).toBe(b)
+  expect(n).toBe(1)
+})

+ 20 - 0
packages/opencode/test/provider/auth.test.ts

@@ -0,0 +1,20 @@
+import { afterEach, expect, test } from "bun:test"
+import { Auth } from "../../src/auth"
+import { ProviderAuth } from "../../src/provider/auth"
+import { ProviderID } from "../../src/provider/schema"
+
+afterEach(async () => {
+  await Auth.remove("test-provider-auth")
+})
+
+test("ProviderAuth.api persists auth via AuthService", async () => {
+  await ProviderAuth.api({
+    providerID: ProviderID.make("test-provider-auth"),
+    key: "sk-test",
+  })
+
+  expect(await Auth.get("test-provider-auth")).toEqual({
+    type: "api",
+    key: "sk-test",
+  })
+})

+ 139 - 0
packages/opencode/test/util/instance-state.test.ts

@@ -0,0 +1,139 @@
+import { afterEach, expect, test } from "bun:test"
+import { Effect } from "effect"
+
+import { Instance } from "../../src/project/instance"
+import { InstanceState } from "../../src/util/instance-state"
+import { tmpdir } from "../fixture/fixture"
+
+async function access<A, E>(state: InstanceState.State<A, E>, dir: string) {
+  return Instance.provide({
+    directory: dir,
+    fn: () => Effect.runPromise(InstanceState.get(state)),
+  })
+}
+
+afterEach(async () => {
+  await Instance.disposeAll()
+})
+
+test("InstanceState caches values for the same instance", async () => {
+  await using tmp = await tmpdir()
+  let n = 0
+
+  await Effect.runPromise(
+    Effect.scoped(
+      Effect.gen(function* () {
+        const state = yield* InstanceState.make({
+          lookup: () => Effect.sync(() => ({ n: ++n })),
+        })
+
+        const a = yield* Effect.promise(() => access(state, tmp.path))
+        const b = yield* Effect.promise(() => access(state, tmp.path))
+
+        expect(a).toBe(b)
+        expect(n).toBe(1)
+      }),
+    ),
+  )
+})
+
+test("InstanceState isolates values by directory", async () => {
+  await using a = await tmpdir()
+  await using b = await tmpdir()
+  let n = 0
+
+  await Effect.runPromise(
+    Effect.scoped(
+      Effect.gen(function* () {
+        const state = yield* InstanceState.make({
+          lookup: (dir) => Effect.sync(() => ({ dir, n: ++n })),
+        })
+
+        const x = yield* Effect.promise(() => access(state, a.path))
+        const y = yield* Effect.promise(() => access(state, b.path))
+        const z = yield* Effect.promise(() => access(state, a.path))
+
+        expect(x).toBe(z)
+        expect(x).not.toBe(y)
+        expect(n).toBe(2)
+      }),
+    ),
+  )
+})
+
+test("InstanceState is disposed on instance reload", async () => {
+  await using tmp = await tmpdir()
+  const seen: string[] = []
+  let n = 0
+
+  await Effect.runPromise(
+    Effect.scoped(
+      Effect.gen(function* () {
+        const state = yield* InstanceState.make({
+          lookup: () => Effect.sync(() => ({ n: ++n })),
+          release: (value) =>
+            Effect.sync(() => {
+              seen.push(String(value.n))
+            }),
+        })
+
+        const a = yield* Effect.promise(() => access(state, tmp.path))
+        yield* Effect.promise(() => Instance.reload({ directory: tmp.path }))
+        const b = yield* Effect.promise(() => access(state, tmp.path))
+
+        expect(a).not.toBe(b)
+        expect(seen).toEqual(["1"])
+      }),
+    ),
+  )
+})
+
+test("InstanceState is disposed on disposeAll", async () => {
+  await using a = await tmpdir()
+  await using b = await tmpdir()
+  const seen: string[] = []
+
+  await Effect.runPromise(
+    Effect.scoped(
+      Effect.gen(function* () {
+        const state = yield* InstanceState.make({
+          lookup: (dir) => Effect.sync(() => ({ dir })),
+          release: (value) =>
+            Effect.sync(() => {
+              seen.push(value.dir)
+            }),
+        })
+
+        yield* Effect.promise(() => access(state, a.path))
+        yield* Effect.promise(() => access(state, b.path))
+        yield* Effect.promise(() => Instance.disposeAll())
+
+        expect(seen.sort()).toEqual([a.path, b.path].sort())
+      }),
+    ),
+  )
+})
+
+test("InstanceState dedupes concurrent lookups for the same directory", async () => {
+  await using tmp = await tmpdir()
+  let n = 0
+
+  await Effect.runPromise(
+    Effect.scoped(
+      Effect.gen(function* () {
+        const state = yield* InstanceState.make({
+          lookup: () =>
+            Effect.promise(async () => {
+              n += 1
+              await Bun.sleep(10)
+              return { n }
+            }),
+        })
+
+        const [a, b] = yield* Effect.promise(() => Promise.all([access(state, tmp.path), access(state, tmp.path)]))
+        expect(a).toBe(b)
+        expect(n).toBe(1)
+      }),
+    ),
+  )
+})