Przeglądaj źródła

ignore: diff stuff

Dax Raad 4 miesięcy temu
rodzic
commit
f4dfae0bb0

+ 15 - 38
packages/opencode/src/server/server.ts

@@ -1,12 +1,6 @@
 import { Log } from "../util/log"
 import { Log } from "../util/log"
 import { Bus } from "../bus"
 import { Bus } from "../bus"
-import {
-  describeRoute,
-  generateSpecs,
-  validator,
-  resolver,
-  openAPIRouteHandler,
-} from "hono-openapi"
+import { describeRoute, generateSpecs, validator, resolver, openAPIRouteHandler } from "hono-openapi"
 import { Hono } from "hono"
 import { Hono } from "hono"
 import { cors } from "hono/cors"
 import { cors } from "hono/cors"
 import { streamSSE } from "hono/streaming"
 import { streamSSE } from "hono/streaming"
@@ -42,6 +36,7 @@ import { MCP } from "../mcp"
 import { Storage } from "../storage/storage"
 import { Storage } from "../storage/storage"
 import type { ContentfulStatusCode } from "hono/utils/http-status"
 import type { ContentfulStatusCode } from "hono/utils/http-status"
 import { Snapshot } from "@/snapshot"
 import { Snapshot } from "@/snapshot"
+import { MessageSummary } from "@/session/summary"
 
 
 const ERRORS = {
 const ERRORS = {
   400: {
   400: {
@@ -73,9 +68,7 @@ const ERRORS = {
 } as const
 } as const
 
 
 function errors(...codes: number[]) {
 function errors(...codes: number[]) {
-  return Object.fromEntries(
-    codes.map((code) => [code, ERRORS[code as keyof typeof ERRORS]]),
-  )
+  return Object.fromEntries(codes.map((code) => [code, ERRORS[code as keyof typeof ERRORS]]))
 }
 }
 
 
 export namespace Server {
 export namespace Server {
@@ -99,8 +92,7 @@ export namespace Server {
           else status = 500
           else status = 500
           return c.json(err.toObject(), { status })
           return c.json(err.toObject(), { status })
         }
         }
-        const message =
-          err instanceof Error && err.stack ? err.stack : err.toString()
+        const message = err instanceof Error && err.stack ? err.stack : err.toString()
         return c.json(new NamedError.Unknown({ message }).toObject(), {
         return c.json(new NamedError.Unknown({ message }).toObject(), {
           status: 500,
           status: 500,
         })
         })
@@ -194,17 +186,14 @@ export namespace Server {
       .get(
       .get(
         "/experimental/tool/ids",
         "/experimental/tool/ids",
         describeRoute({
         describeRoute({
-          description:
-            "List all tool IDs (including built-in and dynamically registered)",
+          description: "List all tool IDs (including built-in and dynamically registered)",
           operationId: "tool.ids",
           operationId: "tool.ids",
           responses: {
           responses: {
             200: {
             200: {
               description: "Tool IDs",
               description: "Tool IDs",
               content: {
               content: {
                 "application/json": {
                 "application/json": {
-                  schema: resolver(
-                    z.array(z.string()).meta({ ref: "ToolIDs" }),
-                  ),
+                  schema: resolver(z.array(z.string()).meta({ ref: "ToolIDs" })),
                 },
                 },
               },
               },
             },
             },
@@ -218,8 +207,7 @@ export namespace Server {
       .get(
       .get(
         "/experimental/tool",
         "/experimental/tool",
         describeRoute({
         describeRoute({
-          description:
-            "List tools with JSON schema parameters for a provider/model",
+          description: "List tools with JSON schema parameters for a provider/model",
           operationId: "tool.list",
           operationId: "tool.list",
           responses: {
           responses: {
             200: {
             200: {
@@ -260,9 +248,7 @@ export namespace Server {
               id: t.id,
               id: t.id,
               description: t.description,
               description: t.description,
               // Handle both Zod schemas and plain JSON schemas
               // Handle both Zod schemas and plain JSON schemas
-              parameters: (t.parameters as any)?._def
-                ? zodToJsonSchema(t.parameters as any)
-                : t.parameters,
+              parameters: (t.parameters as any)?._def ? zodToJsonSchema(t.parameters as any) : t.parameters,
             })),
             })),
           )
           )
         },
         },
@@ -643,19 +629,19 @@ export namespace Server {
         validator(
         validator(
           "param",
           "param",
           z.object({
           z.object({
-            id: Session.diff.schema.shape.sessionID,
+            id: MessageSummary.diff.schema.shape.sessionID,
           }),
           }),
         ),
         ),
         validator(
         validator(
           "query",
           "query",
           z.object({
           z.object({
-            messageID: Session.diff.schema.shape.messageID,
+            messageID: MessageSummary.diff.schema.shape.messageID,
           }),
           }),
         ),
         ),
         async (c) => {
         async (c) => {
           const query = c.req.valid("query")
           const query = c.req.valid("query")
           const params = c.req.valid("param")
           const params = c.req.valid("param")
-          const result = await Session.diff({
+          const result = await MessageSummary.diff({
             sessionID: params.id,
             sessionID: params.id,
             messageID: query.messageID,
             messageID: query.messageID,
           })
           })
@@ -1040,15 +1026,10 @@ export namespace Server {
           },
           },
         }),
         }),
         async (c) => {
         async (c) => {
-          const providers = await Provider.list().then((x) =>
-            mapValues(x, (item) => item.info),
-          )
+          const providers = await Provider.list().then((x) => mapValues(x, (item) => item.info))
           return c.json({
           return c.json({
             providers: Object.values(providers),
             providers: Object.values(providers),
-            default: mapValues(
-              providers,
-              (item) => Provider.sort(Object.values(item.models))[0].id,
-            ),
+            default: mapValues(providers, (item) => Provider.sort(Object.values(item.models))[0].id),
           })
           })
         },
         },
       )
       )
@@ -1243,12 +1224,8 @@ export namespace Server {
         validator(
         validator(
           "json",
           "json",
           z.object({
           z.object({
-            service: z
-              .string()
-              .meta({ description: "Service name for the log entry" }),
-            level: z
-              .enum(["debug", "info", "error", "warn"])
-              .meta({ description: "Log level" }),
+            service: z.string().meta({ description: "Service name for the log entry" }),
+            level: z.enum(["debug", "info", "error", "warn"]).meta({ description: "Log level" }),
             message: z.string().meta({ description: "Log message" }),
             message: z.string().meta({ description: "Log message" }),
             extra: z
             extra: z
               .record(z.string(), z.any())
               .record(z.string(), z.any())

+ 0 - 43
packages/opencode/src/session/index.ts

@@ -406,47 +406,4 @@ export namespace Session {
       await Project.setInitialized(Instance.project.id)
       await Project.setInitialized(Instance.project.id)
     },
     },
   )
   )
-
-  export const diff = fn(
-    z.object({
-      sessionID: Identifier.schema("session"),
-      messageID: Identifier.schema("message").optional(),
-    }),
-    async (input) => {
-      const all = await messages(input.sessionID)
-      const index = !input.messageID ? 0 : all.findIndex((x) => x.info.id === input.messageID)
-      if (index === -1) return []
-
-      let from: string | undefined
-      let to: string | undefined
-
-      // scan assistant messages to find earliest from and latest to
-      // snapshot
-      for (let i = index + 1; i < all.length; i++) {
-        const item = all[i]
-
-        // if messageID is provided, stop at the next user message
-        if (input.messageID && item.info.role === "user") break
-
-        if (!from) {
-          for (const part of item.parts) {
-            if (part.type === "step-start" && part.snapshot) {
-              from = part.snapshot
-              break
-            }
-          }
-        }
-
-        for (const part of item.parts) {
-          if (part.type === "step-finish" && part.snapshot) {
-            to = part.snapshot
-            break
-          }
-        }
-      }
-
-      if (from && to) return Snapshot.diffFull(from, to)
-      return []
-    },
-  )
 }
 }

+ 5 - 5
packages/opencode/src/session/prompt.ts

@@ -398,11 +398,6 @@ export namespace SessionPrompt {
       }
       }
       state().queued.delete(input.sessionID)
       state().queued.delete(input.sessionID)
       SessionCompaction.prune(input)
       SessionCompaction.prune(input)
-      MessageSummary.summarize({
-        sessionID: input.sessionID,
-        messageID: result.info.parentID,
-        providerID: model.providerID,
-      })
       return result
       return result
     }
     }
   }
   }
@@ -1297,6 +1292,11 @@ export namespace SessionPrompt {
                   }
                   }
                   snapshot = undefined
                   snapshot = undefined
                 }
                 }
+                MessageSummary.summarize({
+                  sessionID: input.sessionID,
+                  messageID: assistantMsg.parentID,
+                  providerID: assistantMsg.modelID,
+                })
                 break
                 break
 
 
               case "text-start":
               case "text-start":

+ 74 - 19
packages/opencode/src/session/summary.ts

@@ -5,6 +5,8 @@ import { Session } from "."
 import { generateText } from "ai"
 import { generateText } from "ai"
 import { MessageV2 } from "./message-v2"
 import { MessageV2 } from "./message-v2"
 import { Flag } from "@/flag/flag"
 import { Flag } from "@/flag/flag"
+import { Identifier } from "@/id/id"
+import { Snapshot } from "@/snapshot"
 
 
 export namespace MessageSummary {
 export namespace MessageSummary {
   export const summarize = fn(
   export const summarize = fn(
@@ -14,37 +16,90 @@ export namespace MessageSummary {
       providerID: z.string(),
       providerID: z.string(),
     }),
     }),
     async (input) => {
     async (input) => {
-      if (!Flag.OPENCODE_EXPERIMENTAL_TURN_SUMMARY) return
       const messages = await Session.messages(input.sessionID).then((msgs) =>
       const messages = await Session.messages(input.sessionID).then((msgs) =>
         msgs.filter(
         msgs.filter(
           (m) => m.info.id === input.messageID || (m.info.role === "assistant" && m.info.parentID === input.messageID),
           (m) => m.info.id === input.messageID || (m.info.role === "assistant" && m.info.parentID === input.messageID),
         ),
         ),
       )
       )
-      const small = await Provider.getSmallModel(input.providerID)
-      if (!small) return
-
-      const result = await generateText({
-        model: small.language,
-        maxOutputTokens: 100,
-        messages: [
-          {
-            role: "user",
-            content: `
+      const userMsg = messages.find((m) => m.info.id === input.messageID)!
+      const diffs = await computeDiff({ messages })
+      userMsg.info.summary = {
+        diffs,
+        text: "",
+      }
+      if (
+        Flag.OPENCODE_EXPERIMENTAL_TURN_SUMMARY &&
+        messages.every((m) => m.info.role !== "assistant" || m.info.time.completed)
+      ) {
+        const small = await Provider.getSmallModel(input.providerID)
+        if (!small) return
+        const result = await generateText({
+          model: small.language,
+          maxOutputTokens: 100,
+          messages: [
+            {
+              role: "user",
+              content: `
             Summarize the following conversation into 2 sentences MAX explaining what the assistant did and why. Do not explain the user's input.
             Summarize the following conversation into 2 sentences MAX explaining what the assistant did and why. Do not explain the user's input.
             <conversation>
             <conversation>
             ${JSON.stringify(MessageV2.toModelMessage(messages))}
             ${JSON.stringify(MessageV2.toModelMessage(messages))}
             </conversation>
             </conversation>
             `,
             `,
-          },
-        ],
-      })
-
-      const userMsg = messages.find((m) => m.info.id === input.messageID)!
-      userMsg.info.summary = {
-        text: result.text,
-        diffs: [],
+            },
+          ],
+        })
+        userMsg.info.summary = {
+          text: result.text,
+          diffs: [],
+        }
       }
       }
       await Session.updateMessage(userMsg.info)
       await Session.updateMessage(userMsg.info)
     },
     },
   )
   )
+
+  export const diff = fn(
+    z.object({
+      sessionID: Identifier.schema("session"),
+      messageID: Identifier.schema("message").optional(),
+    }),
+    async (input) => {
+      let all = await Session.messages(input.sessionID)
+      if (input.messageID)
+        all = all.filter(
+          (x) => x.info.id === input.messageID || (x.info.role === "assistant" && x.info.parentID === input.messageID),
+        )
+
+      return computeDiff({
+        messages: all,
+      })
+    },
+  )
+
+  async function computeDiff(input: { messages: MessageV2.WithParts[] }) {
+    let from: string | undefined
+    let to: string | undefined
+
+    // scan assistant messages to find earliest from and latest to
+    // snapshot
+    for (const item of input.messages) {
+      if (!from) {
+        for (const part of item.parts) {
+          if (part.type === "step-start" && part.snapshot) {
+            from = part.snapshot
+            break
+          }
+        }
+      }
+
+      for (const part of item.parts) {
+        if (part.type === "step-finish" && part.snapshot) {
+          to = part.snapshot
+          break
+        }
+      }
+    }
+
+    if (from && to) return Snapshot.diffFull(from, to)
+    return []
+  }
 }
 }