|
|
@@ -17,7 +17,6 @@ import {
|
|
|
tool,
|
|
|
wrapLanguageModel,
|
|
|
type StreamTextResult,
|
|
|
- LoadAPIKeyError,
|
|
|
stepCountIs,
|
|
|
jsonSchema,
|
|
|
} from "ai"
|
|
|
@@ -28,6 +27,7 @@ import { Bus } from "../bus"
|
|
|
import { ProviderTransform } from "../provider/transform"
|
|
|
import { SystemPrompt } from "./system"
|
|
|
import { Plugin } from "../plugin"
|
|
|
+import { SessionRetry } from "./retry"
|
|
|
|
|
|
import PROMPT_PLAN from "../session/prompt/plan.txt"
|
|
|
import BUILD_SWITCH from "../session/prompt/build-switch.txt"
|
|
|
@@ -44,7 +44,6 @@ import { TaskTool } from "../tool/task"
|
|
|
import { FileTime } from "../file/time"
|
|
|
import { Permission } from "../permission"
|
|
|
import { Snapshot } from "../snapshot"
|
|
|
-import { NamedError } from "../util/error"
|
|
|
import { ulid } from "ulid"
|
|
|
import { spawn } from "child_process"
|
|
|
import { Command } from "../command"
|
|
|
@@ -55,6 +54,7 @@ import { MessageSummary } from "./summary"
|
|
|
export namespace SessionPrompt {
|
|
|
const log = Log.create({ service: "session.prompt" })
|
|
|
export const OUTPUT_TOKEN_MAX = 32_000
|
|
|
+ const MAX_RETRIES = 10
|
|
|
|
|
|
export const Event = {
|
|
|
Idle: Bus.event(
|
|
|
@@ -240,93 +240,145 @@ export namespace SessionPrompt {
|
|
|
await using _ = defer(async () => {
|
|
|
await processor.end()
|
|
|
})
|
|
|
- const stream = streamText({
|
|
|
- onError(error) {
|
|
|
- log.error("stream error", {
|
|
|
- error,
|
|
|
- })
|
|
|
- },
|
|
|
- async experimental_repairToolCall(input) {
|
|
|
- const lower = input.toolCall.toolName.toLowerCase()
|
|
|
- if (lower !== input.toolCall.toolName && tools[lower]) {
|
|
|
- log.info("repairing tool call", {
|
|
|
- tool: input.toolCall.toolName,
|
|
|
- repaired: lower,
|
|
|
+ const doStream = () =>
|
|
|
+ streamText({
|
|
|
+ onError(error) {
|
|
|
+ log.error("stream error", {
|
|
|
+ error,
|
|
|
})
|
|
|
+ },
|
|
|
+ async experimental_repairToolCall(input) {
|
|
|
+ const lower = input.toolCall.toolName.toLowerCase()
|
|
|
+ if (lower !== input.toolCall.toolName && tools[lower]) {
|
|
|
+ log.info("repairing tool call", {
|
|
|
+ tool: input.toolCall.toolName,
|
|
|
+ repaired: lower,
|
|
|
+ })
|
|
|
+ return {
|
|
|
+ ...input.toolCall,
|
|
|
+ toolName: lower,
|
|
|
+ }
|
|
|
+ }
|
|
|
return {
|
|
|
...input.toolCall,
|
|
|
- toolName: lower,
|
|
|
+ input: JSON.stringify({
|
|
|
+ tool: input.toolCall.toolName,
|
|
|
+ error: input.error.message,
|
|
|
+ }),
|
|
|
+ toolName: "invalid",
|
|
|
}
|
|
|
- }
|
|
|
- return {
|
|
|
- ...input.toolCall,
|
|
|
- input: JSON.stringify({
|
|
|
- tool: input.toolCall.toolName,
|
|
|
- error: input.error.message,
|
|
|
- }),
|
|
|
- toolName: "invalid",
|
|
|
- }
|
|
|
- },
|
|
|
- headers:
|
|
|
- model.providerID === "opencode"
|
|
|
- ? {
|
|
|
- "x-opencode-session": input.sessionID,
|
|
|
- "x-opencode-request": userMsg.info.id,
|
|
|
- }
|
|
|
- : undefined,
|
|
|
- maxRetries: 10,
|
|
|
- activeTools: Object.keys(tools).filter((x) => x !== "invalid"),
|
|
|
- maxOutputTokens: ProviderTransform.maxOutputTokens(
|
|
|
- model.providerID,
|
|
|
- params.options,
|
|
|
- model.info.limit.output,
|
|
|
- OUTPUT_TOKEN_MAX,
|
|
|
- ),
|
|
|
- abortSignal: abort.signal,
|
|
|
- providerOptions: ProviderTransform.providerOptions(model.npm, model.providerID, params.options),
|
|
|
- stopWhen: stepCountIs(1),
|
|
|
- temperature: params.temperature,
|
|
|
- topP: params.topP,
|
|
|
- messages: [
|
|
|
- ...system.map(
|
|
|
- (x): ModelMessage => ({
|
|
|
- role: "system",
|
|
|
- content: x,
|
|
|
- }),
|
|
|
- ),
|
|
|
- ...MessageV2.toModelMessage(
|
|
|
- msgs.filter((m) => {
|
|
|
- if (m.info.role !== "assistant" || m.info.error === undefined) {
|
|
|
- return true
|
|
|
- }
|
|
|
- if (
|
|
|
- MessageV2.AbortedError.isInstance(m.info.error) &&
|
|
|
- m.parts.some((part) => part.type !== "step-start" && part.type !== "reasoning")
|
|
|
- ) {
|
|
|
- return true
|
|
|
- }
|
|
|
-
|
|
|
- return false
|
|
|
- }),
|
|
|
+ },
|
|
|
+ headers:
|
|
|
+ model.providerID === "opencode"
|
|
|
+ ? {
|
|
|
+ "x-opencode-session": input.sessionID,
|
|
|
+ "x-opencode-request": userMsg.info.id,
|
|
|
+ }
|
|
|
+ : undefined,
|
|
|
+ // set to 0, we handle loop
|
|
|
+ maxRetries: 0,
|
|
|
+ activeTools: Object.keys(tools).filter((x) => x !== "invalid"),
|
|
|
+ maxOutputTokens: ProviderTransform.maxOutputTokens(
|
|
|
+ model.providerID,
|
|
|
+ params.options,
|
|
|
+ model.info.limit.output,
|
|
|
+ OUTPUT_TOKEN_MAX,
|
|
|
),
|
|
|
- ],
|
|
|
- tools: model.info.tool_call === false ? undefined : tools,
|
|
|
- model: wrapLanguageModel({
|
|
|
- model: model.language,
|
|
|
- middleware: [
|
|
|
- {
|
|
|
- async transformParams(args) {
|
|
|
- if (args.type === "stream") {
|
|
|
- // @ts-expect-error
|
|
|
- args.params.prompt = ProviderTransform.message(args.params.prompt, model.providerID, model.modelID)
|
|
|
+ abortSignal: abort.signal,
|
|
|
+ providerOptions: ProviderTransform.providerOptions(model.npm, model.providerID, params.options),
|
|
|
+ stopWhen: stepCountIs(1),
|
|
|
+ temperature: params.temperature,
|
|
|
+ topP: params.topP,
|
|
|
+ messages: [
|
|
|
+ ...system.map(
|
|
|
+ (x): ModelMessage => ({
|
|
|
+ role: "system",
|
|
|
+ content: x,
|
|
|
+ }),
|
|
|
+ ),
|
|
|
+ ...MessageV2.toModelMessage(
|
|
|
+ msgs.filter((m) => {
|
|
|
+ if (m.info.role !== "assistant" || m.info.error === undefined) {
|
|
|
+ return true
|
|
|
}
|
|
|
- return args.params
|
|
|
- },
|
|
|
- },
|
|
|
+ if (
|
|
|
+ MessageV2.AbortedError.isInstance(m.info.error) &&
|
|
|
+ m.parts.some((part) => part.type !== "step-start" && part.type !== "reasoning")
|
|
|
+ ) {
|
|
|
+ return true
|
|
|
+ }
|
|
|
+
|
|
|
+ return false
|
|
|
+ }),
|
|
|
+ ),
|
|
|
],
|
|
|
- }),
|
|
|
+ tools: model.info.tool_call === false ? undefined : tools,
|
|
|
+ model: wrapLanguageModel({
|
|
|
+ model: model.language,
|
|
|
+ middleware: [
|
|
|
+ {
|
|
|
+ async transformParams(args) {
|
|
|
+ if (args.type === "stream") {
|
|
|
+ // @ts-expect-error
|
|
|
+ args.params.prompt = ProviderTransform.message(args.params.prompt, model.providerID, model.modelID)
|
|
|
+ }
|
|
|
+ return args.params
|
|
|
+ },
|
|
|
+ },
|
|
|
+ ],
|
|
|
+ }),
|
|
|
+ })
|
|
|
+
|
|
|
+ let stream = doStream()
|
|
|
+ let result = await processor.process(stream, {
|
|
|
+ count: 0,
|
|
|
+ max: MAX_RETRIES,
|
|
|
})
|
|
|
- const result = await processor.process(stream)
|
|
|
+ if (result.shouldRetry) {
|
|
|
+ for (let retry = 1; retry < MAX_RETRIES; retry++) {
|
|
|
+ const lastRetryPart = result.parts.findLast((p) => p.type === "retry")
|
|
|
+
|
|
|
+ if (lastRetryPart) {
|
|
|
+ const delayMs = SessionRetry.getRetryDelayInMs(lastRetryPart.error, retry)
|
|
|
+
|
|
|
+ log.info("retrying with backoff", {
|
|
|
+ attempt: retry,
|
|
|
+ delayMs,
|
|
|
+ })
|
|
|
+
|
|
|
+ const stop = await SessionRetry.sleep(delayMs, abort.signal)
|
|
|
+ .then(() => false)
|
|
|
+ .catch((error) => {
|
|
|
+ if (error instanceof DOMException && error.name === "AbortError") {
|
|
|
+ const err = new MessageV2.AbortedError(
|
|
|
+ { message: error.message },
|
|
|
+ {
|
|
|
+ cause: error,
|
|
|
+ },
|
|
|
+ ).toObject()
|
|
|
+ result.info.error = err
|
|
|
+ Bus.publish(Session.Event.Error, {
|
|
|
+ sessionID: result.info.sessionID,
|
|
|
+ error: result.info.error,
|
|
|
+ })
|
|
|
+ return true
|
|
|
+ }
|
|
|
+ throw error
|
|
|
+ })
|
|
|
+
|
|
|
+ if (stop) break
|
|
|
+ }
|
|
|
+
|
|
|
+ stream = doStream()
|
|
|
+ result = await processor.process(stream, {
|
|
|
+ count: retry,
|
|
|
+ max: MAX_RETRIES,
|
|
|
+ })
|
|
|
+ if (!result.shouldRetry) {
|
|
|
+ break
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
await processor.end()
|
|
|
|
|
|
const queued = state().queued.get(input.sessionID) ?? []
|
|
|
@@ -959,9 +1011,10 @@ export namespace SessionPrompt {
|
|
|
partFromToolCall(toolCallID: string) {
|
|
|
return toolcalls[toolCallID]
|
|
|
},
|
|
|
- async process(stream: StreamTextResult<Record<string, AITool>, never>) {
|
|
|
+ async process(stream: StreamTextResult<Record<string, AITool>, never>, retries: { count: number; max: number }) {
|
|
|
log.info("process")
|
|
|
if (!assistantMsg) throw new Error("call next() first before processing")
|
|
|
+ let shouldRetry = false
|
|
|
try {
|
|
|
let currentText: MessageV2.TextPart | undefined
|
|
|
let reasoningMap: Record<string, MessageV2.ReasoningPart> = {}
|
|
|
@@ -1314,37 +1367,27 @@ export namespace SessionPrompt {
|
|
|
log.error("process", {
|
|
|
error: e,
|
|
|
})
|
|
|
- switch (true) {
|
|
|
- case e instanceof DOMException && e.name === "AbortError":
|
|
|
- assistantMsg.error = new MessageV2.AbortedError(
|
|
|
- { message: e.message },
|
|
|
- {
|
|
|
- cause: e,
|
|
|
- },
|
|
|
- ).toObject()
|
|
|
- break
|
|
|
- case MessageV2.OutputLengthError.isInstance(e):
|
|
|
- assistantMsg.error = e
|
|
|
- break
|
|
|
- case LoadAPIKeyError.isInstance(e):
|
|
|
- assistantMsg.error = new MessageV2.AuthError(
|
|
|
- {
|
|
|
- providerID: input.providerID,
|
|
|
- message: e.message,
|
|
|
- },
|
|
|
- { cause: e },
|
|
|
- ).toObject()
|
|
|
- break
|
|
|
- case e instanceof Error:
|
|
|
- assistantMsg.error = new NamedError.Unknown({ message: e.toString() }, { cause: e }).toObject()
|
|
|
- break
|
|
|
- default:
|
|
|
- assistantMsg.error = new NamedError.Unknown({ message: JSON.stringify(e) }, { cause: e })
|
|
|
+ const error = MessageV2.fromError(e, { providerID: input.providerID })
|
|
|
+ if (retries.count < retries.max && MessageV2.APIError.isInstance(error) && error.data.isRetryable) {
|
|
|
+ shouldRetry = true
|
|
|
+ await Session.updatePart({
|
|
|
+ id: Identifier.ascending("part"),
|
|
|
+ messageID: assistantMsg.id,
|
|
|
+ sessionID: assistantMsg.sessionID,
|
|
|
+ type: "retry",
|
|
|
+ attempt: retries.count + 1,
|
|
|
+ time: {
|
|
|
+ created: Date.now(),
|
|
|
+ },
|
|
|
+ error,
|
|
|
+ })
|
|
|
+ } else {
|
|
|
+ assistantMsg.error = error
|
|
|
+ Bus.publish(Session.Event.Error, {
|
|
|
+ sessionID: assistantMsg.sessionID,
|
|
|
+ error: assistantMsg.error,
|
|
|
+ })
|
|
|
}
|
|
|
- Bus.publish(Session.Event.Error, {
|
|
|
- sessionID: assistantMsg.sessionID,
|
|
|
- error: assistantMsg.error,
|
|
|
- })
|
|
|
}
|
|
|
const p = await Session.getParts(assistantMsg.id)
|
|
|
for (const part of p) {
|
|
|
@@ -1363,9 +1406,11 @@ export namespace SessionPrompt {
|
|
|
})
|
|
|
}
|
|
|
}
|
|
|
- assistantMsg.time.completed = Date.now()
|
|
|
+ if (!shouldRetry) {
|
|
|
+ assistantMsg.time.completed = Date.now()
|
|
|
+ }
|
|
|
await Session.updateMessage(assistantMsg)
|
|
|
- return { info: assistantMsg, parts: p, blocked }
|
|
|
+ return { info: assistantMsg, parts: p, blocked, shouldRetry }
|
|
|
},
|
|
|
}
|
|
|
return result
|