llm.ts 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192
  1. import { Provider } from "@/provider/provider"
  2. import { Log } from "@/util/log"
  3. import { streamText, wrapLanguageModel, type ModelMessage, type StreamTextResult, type Tool, type ToolSet } from "ai"
  4. import { clone, mergeDeep, pipe } from "remeda"
  5. import { ProviderTransform } from "@/provider/transform"
  6. import { Config } from "@/config/config"
  7. import { Instance } from "@/project/instance"
  8. import type { Agent } from "@/agent/agent"
  9. import type { MessageV2 } from "./message-v2"
  10. import { Plugin } from "@/plugin"
  11. import { SystemPrompt } from "./system"
  12. import { ToolRegistry } from "@/tool/registry"
  13. import { Flag } from "@/flag/flag"
  14. export namespace LLM {
  /** Passed as the final argument to `ProviderTransform.maxOutputTokens` when sizing a response. */
  export const OUTPUT_TOKEN_MAX = 32_000

  // Base logger for this namespace; `stream` clones and tags it for each request.
  const log = Log.create({ service: "llm" })

  /** Everything `stream` needs to issue one streaming model request. */
  export type StreamInput = {
    /** User message being answered; its `id`, `system` and `tools` fields are read. */
    user: MessageV2.User
    /** Session identifier; tagged on log lines and forwarded to hooks, provider options and headers. */
    sessionID: string
    /** Model to call. */
    model: Provider.Model
    /** Agent on whose behalf the call is made; its prompt, sampling settings, options and tool switches are read. */
    agent: Agent.Info
    /** Extra system prompt fragments, placed after the agent/provider prompt. */
    system: string[]
    /** Forwarded to `streamText` as `abortSignal`. */
    abort: AbortSignal
    /** Conversation messages, sent after the system messages. */
    messages: ModelMessage[]
    /** When true, `ProviderTransform.smallOptions` is merged into the request options. */
    small?: boolean
    /** Candidate tools; entries switched off by the agent, the registry or the user message are withheld. */
    tools: Record<string, Tool>
    /** Forwarded to `streamText` as `maxRetries`; treated as 0 when omitted. */
    retries?: number
  }

  /** Result type of the underlying `streamText` call. */
  export type StreamOutput = StreamTextResult<ToolSet, unknown>
  /**
   * Issues one streaming model request for the given input.
   *
   * In order, this:
   *  1. resolves the language model and the config concurrently,
   *  2. assembles the system prompt (header, then agent-or-provider prompt,
   *     call-level prompts and the user message's own prompt joined by newlines),
   *  3. lets plugins rewrite the system prompt and the sampling parameters,
   *  4. resolves the usable tools, and
   *  5. hands everything to `streamText`.
   *
   * @param input Request description; see `StreamInput`.
   * @returns Whatever `streamText` returns for the assembled request.
   */
  export async function stream(input: StreamInput) {
    const { agent, model, sessionID, user } = input

    const logger = log
      .clone()
      .tag("providerID", model.providerID)
      .tag("modelID", model.id)
      .tag("sessionID", sessionID)
      .tag("small", String(input.small ?? false))
      .tag("agent", agent.name)
    logger.info("stream", {
      modelID: model.id,
      providerID: model.providerID,
    })

    // Both lookups are started before either is awaited.
    const [languageModel, config] = await Promise.all([Provider.getLanguage(model), Config.get()])

    // Header entries come first; everything else is collapsed into a single entry.
    const systemPrompts = SystemPrompt.header(model.providerID)
    // An agent prompt, when present, replaces the provider prompt.
    const basePrompts = agent.prompt ? [agent.prompt] : SystemPrompt.provider(model)
    // Custom prompt attached to the last user message, if any.
    const userPrompts = user.system ? [user.system] : []
    systemPrompts.push([...basePrompts, ...input.system, ...userPrompts].filter(Boolean).join("\n"))

    // Keep a copy so the prompt can be restored if the hook leaves the array empty.
    const fallbackPrompts = clone(systemPrompts)
    await Plugin.trigger("experimental.chat.system.transform", {}, { system: systemPrompts })
    if (!systemPrompts.length) {
      systemPrompts.push(...fallbackPrompts)
    }

    const provider = Provider.getProvider(model.providerID)
    const params = await Plugin.trigger(
      "chat.params",
      { sessionID, agent, model, provider, message: user },
      {
        // Temperature is only sent when the model declares support for it.
        temperature: model.capabilities.temperature
          ? (agent.temperature ?? ProviderTransform.temperature(model))
          : undefined,
        topP: agent.topP ?? ProviderTransform.topP(model),
        topK: ProviderTransform.topK(model),
        // Merge order: provider defaults, small-model options, model options, agent options.
        options: pipe(
          {},
          mergeDeep(ProviderTransform.options(model, sessionID)),
          input.small ? mergeDeep(ProviderTransform.smallOptions(model)) : mergeDeep({}),
          mergeDeep(model.options),
          mergeDeep(agent.options),
        ),
      },
    )
    logger.info("params", { params })

    const maxOutputTokens = ProviderTransform.maxOutputTokens(
      model.api.npm,
      params.options,
      model.limit.output,
      OUTPUT_TOKEN_MAX,
    )
    const tools = await resolveTools(input)

    const providerOptions = ProviderTransform.providerOptions(model, params.options)
    // "invalid" stays registered (repaired calls are routed to it) but is not listed as active.
    const activeTools = Object.keys(tools).filter((name) => name !== "invalid")

    // Tracking headers are added only for providers whose ID starts with "opencode".
    const opencodeHeaders = model.providerID.startsWith("opencode")
      ? {
          "x-opencode-project": Instance.project.id,
          "x-opencode-session": sessionID,
          "x-opencode-request": user.id,
          "x-opencode-client": Flag.OPENCODE_CLIENT,
        }
      : undefined
    // Model-level headers are spread last, so they win on key collisions.
    const headers = { ...opencodeHeaders, ...model.headers }

    const systemMessages = systemPrompts.map((content): ModelMessage => ({ role: "system", content }))
    const messages: ModelMessage[] = [...systemMessages, ...input.messages]

    const wrappedModel = wrapLanguageModel({
      model: languageModel,
      middleware: [
        {
          // Streaming calls get their prompt passed through ProviderTransform.message.
          async transformParams(args) {
            if (args.type !== "stream") return args.params
            // @ts-expect-error
            args.params.prompt = ProviderTransform.message(args.params.prompt, input.model)
            return args.params
          },
        },
      ],
    })

    return streamText({
      onError(error) {
        logger.error("stream error", { error })
      },
      async experimental_repairToolCall(failed) {
        const { toolCall } = failed
        const requested = toolCall.toolName
        const lowered = requested.toLowerCase()

        // Only a casing mismatch against a known tool can be repaired; anything
        // else is redirected to the "invalid" tool together with the error text.
        if (lowered === requested || !tools[lowered]) {
          return {
            ...toolCall,
            input: JSON.stringify({
              tool: requested,
              error: failed.error.message,
            }),
            toolName: "invalid",
          }
        }

        logger.info("repairing tool call", {
          tool: requested,
          repaired: lowered,
        })
        return {
          ...toolCall,
          toolName: lowered,
        }
      },
      temperature: params.temperature,
      topP: params.topP,
      topK: params.topK,
      providerOptions,
      activeTools,
      tools,
      maxOutputTokens,
      abortSignal: input.abort,
      headers,
      maxRetries: input.retries ?? 0,
      messages,
      model: wrappedModel,
      experimental_telemetry: { isEnabled: config.experimental?.openTelemetry },
    })
  }
  /**
   * Works out which of the supplied tools may be offered to the model.
   *
   * Tool switches are merged from three sources, in this order: the agent's own
   * `tools` map, `ToolRegistry.enabled(agent)`, and the user message's `tools`
   * map (later merges are applied on top of earlier ones). Any tool whose merged
   * switch is exactly `false` is left out.
   *
   * The result is a new record: `input.tools` is never modified, so a caller
   * that reuses the same tool record across calls cannot lose entries as a side
   * effect of an earlier call.
   *
   * @param input The candidate tools plus the agent and user message whose switches apply.
   * @returns A fresh record holding only the tools that were not switched off.
   */
  async function resolveTools(input: Pick<StreamInput, "tools" | "agent" | "user">) {
    const enabled = pipe(
      input.agent.tools,
      mergeDeep(await ToolRegistry.enabled(input.agent)),
      mergeDeep(input.user.tools ?? {}),
    )

    // Only an explicit `false` disables a tool; missing or other values leave it available.
    const disabled = new Set<string>()
    for (const [name, value] of Object.entries(enabled)) {
      if (value === false) disabled.add(name)
    }

    // Build a filtered copy instead of deleting keys from the caller's record.
    const available: Record<string, Tool> = Object.fromEntries(
      Object.entries(input.tools).filter(([name]) => !disabled.has(name)),
    )
    return available
  }
  179. }