瀏覽代碼

feat(effect-zod): transform support + walk memoization + flattened checks (#23203)

Kit Langton 1 天之前
父節點
當前提交
f3d1fd9ce8
共有 2 個文件被更改,包括 213 次插入15 次删除
  1. 66 14
      packages/opencode/src/util/effect-zod.ts
  2. 147 1
      packages/opencode/test/util/effect-zod.test.ts

+ 66 - 14
packages/opencode/src/util/effect-zod.ts

@@ -1,4 +1,4 @@
-import { Schema, SchemaAST } from "effect"
+import { Effect, Option, Schema, SchemaAST } from "effect"
 import z from "zod"
 
 /**
@@ -8,33 +8,85 @@ import z from "zod"
  */
 export const ZodOverride: unique symbol = Symbol.for("effect-zod/override")
 
+// AST nodes are immutable and frequently shared across schemas (e.g. a single
+// Schema.Class embedded in multiple parents). Memoizing by node identity
+// avoids rebuilding equivalent Zod subtrees and keeps derived children stable
+// by reference across callers.
+const walkCache = new WeakMap<SchemaAST.AST, z.ZodTypeAny>()
+
+// Shared empty ParseOptions for the rare callers that need one — avoids
+// allocating a fresh object per parse inside refinements and transforms.
+const EMPTY_PARSE_OPTIONS = {} as SchemaAST.ParseOptions
+
 export function zod<S extends Schema.Top>(schema: S): z.ZodType<Schema.Schema.Type<S>> {
   return walk(schema.ast) as z.ZodType<Schema.Schema.Type<S>>
 }
 
 function walk(ast: SchemaAST.AST): z.ZodTypeAny {
+  const cached = walkCache.get(ast)
+  if (cached) return cached
+  const result = walkUncached(ast)
+  walkCache.set(ast, result)
+  return result
+}
+
+function walkUncached(ast: SchemaAST.AST): z.ZodTypeAny {
   const override = (ast.annotations as any)?.[ZodOverride] as z.ZodTypeAny | undefined
   if (override) return override
 
-  let out = body(ast)
-  for (const check of ast.checks ?? []) {
-    out = applyCheck(out, check, ast)
-  }
+  // Schema.Class wraps its fields in a Declaration AST plus an encoding that
+  // constructs the class instance. For the Zod derivation we want the plain
+  // field shape (the decoded/consumer view), not the class instance — so
+  // Declarations fall through to body(), not encoded(). User-level
+  // Schema.decodeTo / Schema.transform attach encoding to non-Declaration
+  // nodes, where we do apply the transform.
+  const hasTransform = ast.encoding?.length && ast._tag !== "Declaration"
+  const base = hasTransform ? encoded(ast) : body(ast)
+  const out = ast.checks?.length ? applyChecks(base, ast.checks, ast) : base
   const desc = SchemaAST.resolveDescription(ast)
   const ref = SchemaAST.resolveIdentifier(ast)
-  const next = desc ? out.describe(desc) : out
-  return ref ? next.meta({ ref }) : next
+  const described = desc ? out.describe(desc) : out
+  return ref ? described.meta({ ref }) : described
+}
+
+// Walk the encoded side and apply each link's decode to produce the decoded
+// shape. A node `Target` produced by `from.decodeTo(Target)` carries
+// `Target.encoding = [Link(from, transformation)]`. Chained decodeTo calls
+// nest the encoding via `Link.to` so walking it recursively threads all
+// prior transforms — typical encoding.length is 1.
+function encoded(ast: SchemaAST.AST): z.ZodTypeAny {
+  const encoding = ast.encoding!
+  return encoding.reduce<z.ZodTypeAny>((acc, link) => acc.transform((v) => decode(link.transformation, v)), walk(encoding[0].to))
+}
+
+// Transformations built via pure `SchemaGetter.transform(fn)` (the common
+// decodeTo case) resolve synchronously, so running with no services is safe.
+// Effectful / middleware-based transforms will surface as Effect defects.
+function decode(transformation: SchemaAST.Link["transformation"], value: unknown): unknown {
+  const exit = Effect.runSyncExit(
+    (transformation.decode as any).run(Option.some(value), EMPTY_PARSE_OPTIONS) as Effect.Effect<Option.Option<unknown>>,
+  )
+  if (exit._tag === "Failure") throw new Error(`effect-zod: transform failed: ${String(exit.cause)}`)
+  return Option.getOrElse(exit.value, () => value)
 }
 
-function applyCheck(out: z.ZodTypeAny, check: SchemaAST.Check<any>, ast: SchemaAST.AST): z.ZodTypeAny {
-  if (check._tag === "FilterGroup") {
-    return check.checks.reduce((acc, sub) => applyCheck(acc, sub, ast), out)
+// Flatten FilterGroups and any nested variants into a linear list of Filters
+// so we can run all of them inside a single Zod .superRefine wrapper instead
+// of stacking N wrapper layers (one per check).
+function applyChecks(out: z.ZodTypeAny, checks: SchemaAST.Checks, ast: SchemaAST.AST): z.ZodTypeAny {
+  const filters: SchemaAST.Filter<unknown>[] = []
+  const collect = (c: SchemaAST.Check<unknown>) => {
+    if (c._tag === "FilterGroup") c.checks.forEach(collect)
+    else filters.push(c)
   }
+  checks.forEach(collect)
   return out.superRefine((value, ctx) => {
-    const issue = check.run(value, ast, {} as any)
-    if (!issue) return
-    const message = issueMessage(issue) ?? (check.annotations as any)?.message ?? "Validation failed"
-    ctx.addIssue({ code: "custom", message })
+    for (const filter of filters) {
+      const issue = filter.run(value, ast, EMPTY_PARSE_OPTIONS)
+      if (!issue) continue
+      const message = issueMessage(issue) ?? (filter.annotations as any)?.message ?? "Validation failed"
+      ctx.addIssue({ code: "custom", message })
+    }
   })
 }
 

+ 147 - 1
packages/opencode/test/util/effect-zod.test.ts

@@ -1,5 +1,5 @@
 import { describe, expect, test } from "bun:test"
-import { Schema } from "effect"
+import { Schema, SchemaGetter } from "effect"
 import z from "zod"
 
 import { zod, ZodOverride } from "../../src/util/effect-zod"
@@ -332,4 +332,150 @@ describe("util.effect-zod", () => {
       expect(schema.parse({ id: "x" })).toEqual({ id: "x" })
     })
   })
+
+  describe("transforms (Schema.decodeTo)", () => {
+    test("Number -> pseudo-Duration (seconds) applies the decode function", () => {
+      // Models the account/account.ts DurationFromSeconds pattern.
+      const SecondsToMs = Schema.Number.pipe(
+        Schema.decodeTo(Schema.Number, {
+          decode: SchemaGetter.transform((n: number) => n * 1000),
+          encode: SchemaGetter.transform((ms: number) => ms / 1000),
+        }),
+      )
+
+      const schema = zod(SecondsToMs)
+      expect(schema.parse(3)).toBe(3000)
+      expect(schema.parse(0)).toBe(0)
+    })
+
+    test("String -> Number via parseInt decode", () => {
+      const ParsedInt = Schema.String.pipe(
+        Schema.decodeTo(Schema.Number, {
+          decode: SchemaGetter.transform((s: string) => Number.parseInt(s, 10)),
+          encode: SchemaGetter.transform((n: number) => String(n)),
+        }),
+      )
+
+      const schema = zod(ParsedInt)
+      expect(schema.parse("42")).toBe(42)
+      expect(schema.parse("0")).toBe(0)
+    })
+
+    test("transform inside a struct field applies per-field", () => {
+      const Field = Schema.Number.pipe(
+        Schema.decodeTo(Schema.Number, {
+          decode: SchemaGetter.transform((n: number) => n + 1),
+          encode: SchemaGetter.transform((n: number) => n - 1),
+        }),
+      )
+
+      const schema = zod(
+        Schema.Struct({
+          plain: Schema.Number,
+          bumped: Field,
+        }),
+      )
+
+      expect(schema.parse({ plain: 5, bumped: 10 })).toEqual({ plain: 5, bumped: 11 })
+    })
+
+    test("chained decodeTo composes transforms in order", () => {
+      // String -> Number (parseInt) -> Number (doubled).
+      // Exercises the encoded() reduce, not just a single link.
+      const Chained = Schema.String.pipe(
+        Schema.decodeTo(Schema.Number, {
+          decode: SchemaGetter.transform((s: string) => Number.parseInt(s, 10)),
+          encode: SchemaGetter.transform((n: number) => String(n)),
+        }),
+        Schema.decodeTo(Schema.Number, {
+          decode: SchemaGetter.transform((n: number) => n * 2),
+          encode: SchemaGetter.transform((n: number) => n / 2),
+        }),
+      )
+
+      const schema = zod(Chained)
+      expect(schema.parse("21")).toBe(42)
+      expect(schema.parse("0")).toBe(0)
+    })
+
+    test("Schema.Class is unaffected by transform walker (returns plain object, not instance)", () => {
+      // Schema.Class uses Declaration + encoding under the hood to construct
+      // class instances. The walker must NOT apply that transform, or zod
+      // parsing would return class instances instead of plain objects.
+      class Method extends Schema.Class<Method>("TxTestMethod")({
+        type: Schema.String,
+        value: Schema.Number,
+      }) {}
+
+      const schema = zod(Method)
+      const parsed = schema.parse({ type: "oauth", value: 1 })
+      expect(parsed).toEqual({ type: "oauth", value: 1 })
+      // Guardrail: ensure we didn't get back a Method instance.
+      expect(parsed).not.toBeInstanceOf(Method)
+    })
+  })
+
+  describe("optimizations", () => {
+    test("walk() memoizes by AST identity — same AST node returns same Zod", () => {
+      const shared = Schema.Struct({ id: Schema.String, name: Schema.String })
+      const left = zod(shared)
+      const right = zod(shared)
+      expect(left).toBe(right)
+    })
+
+    test("nested reuse of the same AST reuses the cached Zod child", () => {
+      // Two different parents embed the same inner schema. The inner zod
+      // child should be identical by reference inside both parents.
+      class Inner extends Schema.Class<Inner>("MemoTestInner")({
+        value: Schema.String,
+      }) {}
+
+      class OuterA extends Schema.Class<OuterA>("MemoTestOuterA")({
+        inner: Inner,
+      }) {}
+
+      class OuterB extends Schema.Class<OuterB>("MemoTestOuterB")({
+        inner: Inner,
+      }) {}
+
+      const shapeA = (zod(OuterA) as any).shape ?? (zod(OuterA) as any)._def?.shape?.()
+      const shapeB = (zod(OuterB) as any).shape ?? (zod(OuterB) as any)._def?.shape?.()
+      expect(shapeA.inner).toBe(shapeB.inner)
+    })
+
+    test("multiple checks run in a single refinement layer (all fire on one value)", () => {
+      // Three checks attached to the same schema. All three must run and
+      // report — asserting that no check silently got dropped when we
+      // flattened into one superRefine.
+      const positive = Schema.makeFilter((n: number) => (n > 0 ? undefined : "not positive"))
+      const even = Schema.makeFilter((n: number) => (n % 2 === 0 ? undefined : "not even"))
+      const under100 = Schema.makeFilter((n: number) => (n < 100 ? undefined : "too big"))
+
+      const schema = zod(Schema.Number.check(positive).check(even).check(under100))
+
+      const neg = schema.safeParse(-3)
+      expect(neg.success).toBe(false)
+      expect(neg.error!.issues.map((i) => i.message)).toEqual(expect.arrayContaining(["not positive", "not even"]))
+
+      const big = schema.safeParse(101)
+      expect(big.success).toBe(false)
+      expect(big.error!.issues.map((i) => i.message)).toContain("too big")
+
+      // Passing value satisfies all three
+      expect(schema.parse(42)).toBe(42)
+    })
+
+    test("FilterGroup flattens into the single refinement layer alongside its siblings", () => {
+      const positive = Schema.makeFilter((n: number) => (n > 0 ? undefined : "not positive"))
+      const even = Schema.makeFilter((n: number) => (n % 2 === 0 ? undefined : "not even"))
+      const group = Schema.makeFilterGroup([positive, even])
+      const under100 = Schema.makeFilter((n: number) => (n < 100 ? undefined : "too big"))
+
+      const schema = zod(Schema.Number.check(group).check(under100))
+
+      const bad = schema.safeParse(-3)
+      expect(bad.success).toBe(false)
+      expect(bad.error!.issues.map((i) => i.message)).toEqual(expect.arrayContaining(["not positive", "not even"]))
+    })
+  })
 })