Преглед изворни кода

core: improve session API reliability with proper input validation

Dax Raad пре 4 месеци
родитељ
комит
10998d62b9

+ 1 - 1
packages/opencode/src/cli/cmd/run.ts

@@ -106,7 +106,7 @@ export const RunCommand = cmd({
 
         if (args.session) return Session.get(args.session)
 
-        return Session.create()
+        return Session.create({})
       })()
 
       if (!session) {

+ 9 - 25
packages/opencode/src/server/server.ts

@@ -31,7 +31,6 @@ import { SessionRevert } from "../session/revert"
 import { lazy } from "../util/lazy"
 import { Todo } from "../session/todo"
 import { InstanceBootstrap } from "../project/bootstrap"
-import { Identifier } from "@/id/id"
 
 const ERRORS = {
   400: {
@@ -308,7 +307,7 @@ export namespace Server {
         validator(
           "param",
           z.object({
-            id: z.string(),
+            id: Session.get.schema,
           }),
         ),
         async (c) => {
@@ -336,7 +335,7 @@ export namespace Server {
         validator(
           "param",
           z.object({
-            id: z.string(),
+            id: Session.children.schema,
           }),
         ),
         async (c) => {
@@ -390,18 +389,10 @@ export namespace Server {
             },
           },
         }),
-        validator(
-          "json",
-          z
-            .object({
-              parentID: z.string().optional(),
-              title: z.string().optional(),
-            })
-            .optional(),
-        ),
+        validator("json", Session.create.schema.optional()),
         async (c) => {
           const body = c.req.valid("json") ?? {}
-          const session = await Session.create(body.parentID, body.title)
+          const session = await Session.create(body)
           return c.json(session)
         },
       )
@@ -424,7 +415,7 @@ export namespace Server {
         validator(
           "param",
           z.object({
-            id: z.string(),
+            id: Session.remove.schema,
           }),
         ),
         async (c) => {
@@ -495,14 +486,7 @@ export namespace Server {
             id: z.string().meta({ description: "Session ID" }),
           }),
         ),
-        validator(
-          "json",
-          z.object({
-            messageID: z.string(),
-            providerID: z.string(),
-            modelID: z.string(),
-          }),
-        ),
+        validator("json", Session.initialize.schema.omit({ sessionID: true })),
         async (c) => {
           const sessionID = c.req.valid("param").id
           const body = c.req.valid("json")
@@ -529,7 +513,7 @@ export namespace Server {
         validator(
           "param",
           z.object({
-            id: Identifier.schema("session").meta({ description: "Session ID" }),
+            id: Session.fork.schema.shape.sessionID,
           }),
         ),
         validator("json", Session.fork.schema.omit({ sessionID: true })),
@@ -614,7 +598,7 @@ export namespace Server {
         validator(
           "param",
           z.object({
-            id: z.string(),
+            id: Session.unshare.schema,
           }),
         ),
         async (c) => {
@@ -717,7 +701,7 @@ export namespace Server {
         ),
         async (c) => {
           const params = c.req.valid("param")
-          const message = await Session.getMessage(params.id, params.messageID)
+          const message = await Session.getMessage({ sessionID: params.id, messageID: params.messageID })
           return c.json(message)
         },
       )

+ 1 - 1
packages/opencode/src/session/compaction.ts

@@ -144,7 +144,7 @@ export namespace SessionCompaction {
         },
       ],
     })
-    const usage = Session.getUsage(model.info, generated.usage, generated.providerMetadata)
+    const usage = Session.getUsage({ model: model.info, usage: generated.usage, metadata: generated.providerMetadata })
     msg.cost += usage.cost
     msg.tokens = usage.tokens
     msg.summary = true

+ 118 - 88
packages/opencode/src/session/index.ts

@@ -93,13 +93,21 @@ export namespace Session {
     ),
   }
 
-  export async function create(parentID?: string, title?: string) {
-    return createNext({
-      parentID,
-      directory: Instance.directory,
-      title,
-    })
-  }
+  export const create = fn(
+    z
+      .object({
+        parentID: Identifier.schema("session").optional(),
+        title: z.string().optional(),
+      })
+      .optional(),
+    async (input) => {
+      return createNext({
+        parentID: input?.parentID,
+        directory: Instance.directory,
+        title: input?.title,
+      })
+    },
+  )
 
   export const fork = fn(
     z.object({
@@ -132,11 +140,11 @@ export namespace Session {
     },
   )
 
-  export async function touch(sessionID: string) {
+  export const touch = fn(Identifier.schema("session"), async (sessionID) => {
     await update(sessionID, (draft) => {
       draft.time.updated = Date.now()
     })
-  }
+  })
 
   export async function createNext(input: { id?: string; title?: string; parentID?: string; directory: string }) {
     const result: Info = {
@@ -170,16 +178,16 @@ export namespace Session {
     return result
   }
 
-  export async function get(id: string) {
+  export const get = fn(Identifier.schema("session"), async (id) => {
     const read = await Storage.read<Info>(["session", Instance.project.id, id])
     return read as Info
-  }
+  })
 
-  export async function getShare(id: string) {
+  export const getShare = fn(Identifier.schema("session"), async (id) => {
     return Storage.read<ShareInfo>(["share", id])
-  }
+  })
 
-  export async function share(id: string) {
+  export const share = fn(Identifier.schema("session"), async (id) => {
     const cfg = await Config.get()
     if (cfg.share === "disabled") {
       throw new Error("Sharing is disabled in configuration")
@@ -202,9 +210,9 @@ export namespace Session {
       }
     }
     return share
-  }
+  })
 
-  export async function unshare(id: string) {
+  export const unshare = fn(Identifier.schema("session"), async (id) => {
     const share = await getShare(id)
     if (!share) return
     await Storage.remove(["share", id])
@@ -212,7 +220,7 @@ export namespace Session {
       draft.share = undefined
     })
     await Share.remove(id, share.secret)
-  }
+  })
 
   export async function update(id: string, editor: (session: Info) => void) {
     const project = Instance.project
@@ -226,7 +234,7 @@ export namespace Session {
     return result
   }
 
-  export async function messages(sessionID: string) {
+  export const messages = fn(Identifier.schema("session"), async (sessionID) => {
     const result = [] as MessageV2.WithParts[]
     for (const p of await Storage.list(["message", sessionID])) {
       const read = await Storage.read<MessageV2.Info>(p)
@@ -237,16 +245,22 @@ export namespace Session {
     }
     result.sort((a, b) => (a.info.id > b.info.id ? 1 : -1))
     return result
-  }
+  })
 
-  export async function getMessage(sessionID: string, messageID: string) {
-    return {
-      info: await Storage.read<MessageV2.Info>(["message", sessionID, messageID]),
-      parts: await getParts(messageID),
-    }
-  }
+  export const getMessage = fn(
+    z.object({
+      sessionID: Identifier.schema("session"),
+      messageID: Identifier.schema("message"),
+    }),
+    async (input) => {
+      return {
+        info: await Storage.read<MessageV2.Info>(["message", input.sessionID, input.messageID]),
+        parts: await getParts(input.messageID),
+      }
+    },
+  )
 
-  export async function getParts(messageID: string) {
+  export const getParts = fn(Identifier.schema("message"), async (messageID) => {
     const result = [] as MessageV2.Part[]
     for (const item of await Storage.list(["part", messageID])) {
       const read = await Storage.read<MessageV2.Part>(item)
@@ -254,7 +268,7 @@ export namespace Session {
     }
     result.sort((a, b) => (a.id > b.id ? 1 : -1))
     return result
-  }
+  })
 
   export async function* list() {
     const project = Instance.project
@@ -263,7 +277,7 @@ export namespace Session {
     }
   }
 
-  export async function children(parentID: string) {
+  export const children = fn(Identifier.schema("session"), async (parentID) => {
     const project = Instance.project
     const result = [] as Session.Info[]
     for (const item of await Storage.list(["session", project.id])) {
@@ -272,9 +286,9 @@ export namespace Session {
       result.push(session)
     }
     return result
-  }
+  })
 
-  export async function remove(sessionID: string) {
+  export const remove = fn(Identifier.schema("session"), async (sessionID) => {
     const project = Instance.project
     try {
       const session = await get(sessionID)
@@ -295,56 +309,69 @@ export namespace Session {
     } catch (e) {
       log.error(e)
     }
-  }
+  })
 
-  export async function updateMessage(msg: MessageV2.Info) {
+  export const updateMessage = fn(MessageV2.Info, async (msg) => {
     await Storage.write(["message", msg.sessionID, msg.id], msg)
     Bus.publish(MessageV2.Event.Updated, {
       info: msg,
     })
     return msg
-  }
+  })
 
-  export async function removeMessage(sessionID: string, messageID: string) {
-    await Storage.remove(["message", sessionID, messageID])
-    Bus.publish(MessageV2.Event.Removed, {
-      sessionID,
-      messageID,
-    })
-    return messageID
-  }
+  export const removeMessage = fn(
+    z.object({
+      sessionID: Identifier.schema("session"),
+      messageID: Identifier.schema("message"),
+    }),
+    async (input) => {
+      await Storage.remove(["message", input.sessionID, input.messageID])
+      Bus.publish(MessageV2.Event.Removed, {
+        sessionID: input.sessionID,
+        messageID: input.messageID,
+      })
+      return input.messageID
+    },
+  )
 
-  export async function updatePart(part: MessageV2.Part) {
+  export const updatePart = fn(MessageV2.Part, async (part) => {
     await Storage.write(["part", part.messageID, part.id], part)
     Bus.publish(MessageV2.Event.PartUpdated, {
       part,
     })
     return part
-  }
+  })
 
-  export function getUsage(model: ModelsDev.Model, usage: LanguageModelUsage, metadata?: ProviderMetadata) {
-    const tokens = {
-      input: usage.inputTokens ?? 0,
-      output: usage.outputTokens ?? 0,
-      reasoning: usage?.reasoningTokens ?? 0,
-      cache: {
-        write: (metadata?.["anthropic"]?.["cacheCreationInputTokens"] ??
-          // @ts-expect-error
-          metadata?.["bedrock"]?.["usage"]?.["cacheWriteInputTokens"] ??
-          0) as number,
-        read: usage.cachedInputTokens ?? 0,
-      },
-    }
-    return {
-      cost: new Decimal(0)
-        .add(new Decimal(tokens.input).mul(model.cost?.input ?? 0).div(1_000_000))
-        .add(new Decimal(tokens.output).mul(model.cost?.output ?? 0).div(1_000_000))
-        .add(new Decimal(tokens.cache.read).mul(model.cost?.cache_read ?? 0).div(1_000_000))
-        .add(new Decimal(tokens.cache.write).mul(model.cost?.cache_write ?? 0).div(1_000_000))
-        .toNumber(),
-      tokens,
-    }
-  }
+  export const getUsage = fn(
+    z.object({
+      model: z.custom<ModelsDev.Model>(),
+      usage: z.custom<LanguageModelUsage>(),
+      metadata: z.custom<ProviderMetadata>().optional(),
+    }),
+    (input) => {
+      const tokens = {
+        input: input.usage.inputTokens ?? 0,
+        output: input.usage.outputTokens ?? 0,
+        reasoning: input.usage?.reasoningTokens ?? 0,
+        cache: {
+          write: (input.metadata?.["anthropic"]?.["cacheCreationInputTokens"] ??
+            // @ts-expect-error
+            input.metadata?.["bedrock"]?.["usage"]?.["cacheWriteInputTokens"] ??
+            0) as number,
+          read: input.usage.cachedInputTokens ?? 0,
+        },
+      }
+      return {
+        cost: new Decimal(0)
+          .add(new Decimal(tokens.input).mul(input.model.cost?.input ?? 0).div(1_000_000))
+          .add(new Decimal(tokens.output).mul(input.model.cost?.output ?? 0).div(1_000_000))
+          .add(new Decimal(tokens.cache.read).mul(input.model.cost?.cache_read ?? 0).div(1_000_000))
+          .add(new Decimal(tokens.cache.write).mul(input.model.cost?.cache_write ?? 0).div(1_000_000))
+          .toNumber(),
+        tokens,
+      }
+    },
+  )
 
   export class BusyError extends Error {
     constructor(public readonly sessionID: string) {
@@ -352,27 +379,30 @@ export namespace Session {
     }
   }
 
-  export async function initialize(input: {
-    sessionID: string
-    modelID: string
-    providerID: string
-    messageID: string
-  }) {
-    await SessionPrompt.prompt({
-      sessionID: input.sessionID,
-      messageID: input.messageID,
-      model: {
-        providerID: input.providerID,
-        modelID: input.modelID,
-      },
-      parts: [
-        {
-          id: Identifier.ascending("part"),
-          type: "text",
-          text: PROMPT_INITIALIZE.replace("${path}", Instance.worktree),
+  export const initialize = fn(
+    z.object({
+      sessionID: Identifier.schema("session"),
+      modelID: z.string(),
+      providerID: z.string(),
+      messageID: Identifier.schema("message"),
+    }),
+    async (input) => {
+      await SessionPrompt.prompt({
+        sessionID: input.sessionID,
+        messageID: input.messageID,
+        model: {
+          providerID: input.providerID,
+          modelID: input.modelID,
         },
-      ],
-    })
-    await Project.setInitialized(Instance.project.id)
-  }
+        parts: [
+          {
+            id: Identifier.ascending("part"),
+            type: "text",
+            text: PROMPT_INITIALIZE.replace("${path}", Instance.worktree),
+          },
+        ],
+      })
+      await Project.setInitialized(Instance.project.id)
+    },
+  )
 }

+ 5 - 1
packages/opencode/src/session/prompt.ts

@@ -1031,7 +1031,11 @@ export namespace SessionPrompt {
                 break
 
               case "finish-step":
-                const usage = Session.getUsage(input.model, value.usage, value.providerMetadata)
+                const usage = Session.getUsage({
+                  model: input.model,
+                  usage: value.usage,
+                  metadata: value.providerMetadata,
+                })
                 assistantMsg.cost += usage.cost
                 assistantMsg.tokens = usage.tokens
                 await Session.updatePart({

+ 5 - 2
packages/opencode/src/tool/task.ts

@@ -26,8 +26,11 @@ export const TaskTool = Tool.define("task", async () => {
     async execute(params, ctx) {
       const agent = await Agent.get(params.subagent_type)
       if (!agent) throw new Error(`Unknown agent type: ${params.subagent_type} is not a valid agent type`)
-      const session = await Session.create(ctx.sessionID, params.description + ` (@${agent.name} subagent)`)
-      const msg = await Session.getMessage(ctx.sessionID, ctx.messageID)
+      const session = await Session.create({
+        parentID: ctx.sessionID,
+        title: params.description + ` (@${agent.name} subagent)`,
+      })
+      const msg = await Session.getMessage({ sessionID: ctx.sessionID, messageID: ctx.messageID })
       if (msg.info.role !== "assistant") throw new Error("Not an assistant message")
       const messageID = Identifier.ascending("message")
       const parts: Record<string, MessageV2.ToolPart> = {}