ソースを参照

add output length errors

Dax Raad 8 ヶ月 前
コミット
77ae3fb9b9

+ 18 - 0
packages/opencode/src/session/index.ts

@@ -657,6 +657,21 @@ export namespace Session {
             }
             }
             break
             break
 
 
+          case "finish":
+            log.info("message finish", {
+              reason: value.finishReason,
+            })
+            const assistant = next.metadata!.assistant!
+            const usage = getUsage(
+              model.info,
+              value.usage,
+              value.providerMetadata,
+            )
+            assistant.cost = usage.cost
+            await updateMessage(next)
+            if (value.finishReason === "length")
+              throw new Message.OutputLengthError({})
+            break
           default:
           default:
             l.info("unhandled", {
             l.info("unhandled", {
               type: value.type,
               type: value.type,
@@ -670,6 +685,9 @@ export namespace Session {
         error: e,
         error: e,
       })
       })
       switch (true) {
       switch (true) {
+        case Message.OutputLengthError.isInstance(e):
+          next.metadata.error = e
+          break
         case LoadAPIKeyError.isInstance(e):
         case LoadAPIKeyError.isInstance(e):
           next.metadata.error = new Provider.AuthError(
           next.metadata.error = new Provider.AuthError(
             {
             {

+ 6 - 0
packages/opencode/src/session/message.ts

@@ -4,6 +4,11 @@ import { Provider } from "../provider/provider"
 import { NamedError } from "../util/error"
 import { NamedError } from "../util/error"
 
 
 export namespace Message {
 export namespace Message {
+  export const OutputLengthError = NamedError.create(
+    "SessionOutputLengthError",
+    z.object({}),
+  )
+
   export const ToolCall = z
   export const ToolCall = z
     .object({
     .object({
       state: z.literal("call"),
       state: z.literal("call"),
@@ -145,6 +150,7 @@ export namespace Message {
             .discriminatedUnion("name", [
             .discriminatedUnion("name", [
               Provider.AuthError.Schema,
               Provider.AuthError.Schema,
               NamedError.Unknown.Schema,
               NamedError.Unknown.Schema,
+              OutputLengthError.Schema,
             ])
             ])
             .optional(),
             .optional(),
           sessionID: z.string(),
           sessionID: z.string(),