// llm.ts
  1. import { Provider } from "@/provider/provider"
  2. import { Log } from "@/util/log"
  3. import {
  4. streamText,
  5. wrapLanguageModel,
  6. type ModelMessage,
  7. type StreamTextResult,
  8. type Tool,
  9. type ToolSet,
  10. extractReasoningMiddleware,
  11. } from "ai"
  12. import { clone, mergeDeep, pipe } from "remeda"
  13. import { ProviderTransform } from "@/provider/transform"
  14. import { Config } from "@/config/config"
  15. import { Instance } from "@/project/instance"
  16. import type { Agent } from "@/agent/agent"
  17. import type { MessageV2 } from "./message-v2"
  18. import { Plugin } from "@/plugin"
  19. import { SystemPrompt } from "./system"
  20. import { ToolRegistry } from "@/tool/registry"
  21. import { Flag } from "@/flag/flag"
  22. export namespace LLM {
  23. const log = Log.create({ service: "llm" })
  24. export const OUTPUT_TOKEN_MAX = Flag.OPENCODE_EXPERIMENTAL_OUTPUT_TOKEN_MAX || 32_000
  25. export type StreamInput = {
  26. user: MessageV2.User
  27. sessionID: string
  28. model: Provider.Model
  29. agent: Agent.Info
  30. system: string[]
  31. abort: AbortSignal
  32. messages: ModelMessage[]
  33. small?: boolean
  34. tools: Record<string, Tool>
  35. retries?: number
  36. }
  37. export type StreamOutput = StreamTextResult<ToolSet, unknown>
  38. export async function stream(input: StreamInput) {
  39. const l = log
  40. .clone()
  41. .tag("providerID", input.model.providerID)
  42. .tag("modelID", input.model.id)
  43. .tag("sessionID", input.sessionID)
  44. .tag("small", (input.small ?? false).toString())
  45. .tag("agent", input.agent.name)
  46. l.info("stream", {
  47. modelID: input.model.id,
  48. providerID: input.model.providerID,
  49. })
  50. const [language, cfg] = await Promise.all([Provider.getLanguage(input.model), Config.get()])
  51. const system = SystemPrompt.header(input.model.providerID)
  52. system.push(
  53. [
  54. // use agent prompt otherwise provider prompt
  55. ...(input.agent.prompt ? [input.agent.prompt] : SystemPrompt.provider(input.model)),
  56. // any custom prompt passed into this call
  57. ...input.system,
  58. // any custom prompt from last user message
  59. ...(input.user.system ? [input.user.system] : []),
  60. ]
  61. .filter((x) => x)
  62. .join("\n"),
  63. )
  64. const header = system[0]
  65. const original = clone(system)
  66. await Plugin.trigger("experimental.chat.system.transform", {}, { system })
  67. if (system.length === 0) {
  68. system.push(...original)
  69. }
  70. // rejoin to maintain 2-part structure for caching if header unchanged
  71. if (system.length > 2 && system[0] === header) {
  72. const rest = system.slice(1)
  73. system.length = 0
  74. system.push(header, rest.join("\n"))
  75. }
  76. const provider = await Provider.getProvider(input.model.providerID)
  77. const variant = input.model.variants && input.user.variant ? input.model.variants[input.user.variant] : undefined
  78. const options = pipe(
  79. ProviderTransform.options(input.model, input.sessionID, provider.options),
  80. mergeDeep(input.small ? ProviderTransform.smallOptions(input.model) : {}),
  81. mergeDeep(input.model.options),
  82. mergeDeep(input.agent.options),
  83. mergeDeep(variant && !variant.disabled ? variant : {}),
  84. )
  85. const params = await Plugin.trigger(
  86. "chat.params",
  87. {
  88. sessionID: input.sessionID,
  89. agent: input.agent,
  90. model: input.model,
  91. provider: Provider.getProvider(input.model.providerID),
  92. message: input.user,
  93. },
  94. {
  95. temperature: input.model.capabilities.temperature
  96. ? (input.agent.temperature ?? ProviderTransform.temperature(input.model))
  97. : undefined,
  98. topP: input.agent.topP ?? ProviderTransform.topP(input.model),
  99. topK: ProviderTransform.topK(input.model),
  100. options,
  101. },
  102. )
  103. l.info("params", {
  104. params,
  105. })
  106. const maxOutputTokens = ProviderTransform.maxOutputTokens(
  107. input.model.api.npm,
  108. params.options,
  109. input.model.limit.output,
  110. OUTPUT_TOKEN_MAX,
  111. )
  112. const tools = await resolveTools(input)
  113. return streamText({
  114. onError(error) {
  115. l.error("stream error", {
  116. error,
  117. })
  118. },
  119. async experimental_repairToolCall(failed) {
  120. const lower = failed.toolCall.toolName.toLowerCase()
  121. if (lower !== failed.toolCall.toolName && tools[lower]) {
  122. l.info("repairing tool call", {
  123. tool: failed.toolCall.toolName,
  124. repaired: lower,
  125. })
  126. return {
  127. ...failed.toolCall,
  128. toolName: lower,
  129. }
  130. }
  131. return {
  132. ...failed.toolCall,
  133. input: JSON.stringify({
  134. tool: failed.toolCall.toolName,
  135. error: failed.error.message,
  136. }),
  137. toolName: "invalid",
  138. }
  139. },
  140. temperature: params.temperature,
  141. topP: params.topP,
  142. topK: params.topK,
  143. providerOptions: ProviderTransform.providerOptions(input.model, params.options),
  144. activeTools: Object.keys(tools).filter((x) => x !== "invalid"),
  145. tools,
  146. maxOutputTokens,
  147. abortSignal: input.abort,
  148. headers: {
  149. ...(input.model.providerID.startsWith("opencode")
  150. ? {
  151. "x-opencode-project": Instance.project.id,
  152. "x-opencode-session": input.sessionID,
  153. "x-opencode-request": input.user.id,
  154. "x-opencode-client": Flag.OPENCODE_CLIENT,
  155. }
  156. : undefined),
  157. ...input.model.headers,
  158. },
  159. maxRetries: input.retries ?? 0,
  160. messages: [
  161. ...system.map(
  162. (x): ModelMessage => ({
  163. role: "system",
  164. content: x,
  165. }),
  166. ),
  167. ...input.messages,
  168. ],
  169. model: wrapLanguageModel({
  170. model: language,
  171. middleware: [
  172. {
  173. async transformParams(args) {
  174. if (args.type === "stream") {
  175. // @ts-expect-error
  176. args.params.prompt = ProviderTransform.message(args.params.prompt, input.model)
  177. }
  178. return args.params
  179. },
  180. },
  181. extractReasoningMiddleware({ tagName: "think", startWithReasoning: false }),
  182. ],
  183. }),
  184. experimental_telemetry: { isEnabled: cfg.experimental?.openTelemetry },
  185. })
  186. }
  187. async function resolveTools(input: Pick<StreamInput, "tools" | "agent" | "user">) {
  188. const enabled = pipe(
  189. input.agent.tools,
  190. mergeDeep(await ToolRegistry.enabled(input.agent)),
  191. mergeDeep(input.user.tools ?? {}),
  192. )
  193. for (const [key, value] of Object.entries(enabled)) {
  194. if (value === false) delete input.tools[key]
  195. }
  196. return input.tools
  197. }
  198. }