// transform.ts

  1. import type { ModelMessage } from "ai"
  2. import { unique } from "remeda"
  3. import type { JSONSchema } from "zod/v4/core"
  4. export namespace ProviderTransform {
  5. function normalizeMessages(
  6. msgs: ModelMessage[],
  7. providerID: string,
  8. modelID: string,
  9. ): ModelMessage[] {
  10. if (modelID.includes("claude")) {
  11. return msgs.map((msg) => {
  12. if ((msg.role === "assistant" || msg.role === "tool") && Array.isArray(msg.content)) {
  13. msg.content = msg.content.map((part) => {
  14. if (
  15. (part.type === "tool-call" || part.type === "tool-result") &&
  16. "toolCallId" in part
  17. ) {
  18. return {
  19. ...part,
  20. toolCallId: part.toolCallId.replace(/[^a-zA-Z0-9_-]/g, "_"),
  21. }
  22. }
  23. return part
  24. })
  25. }
  26. return msg
  27. })
  28. }
  29. if (providerID === "mistral" || modelID.toLowerCase().includes("mistral")) {
  30. const result: ModelMessage[] = []
  31. for (let i = 0; i < msgs.length; i++) {
  32. const msg = msgs[i]
  33. const prevMsg = msgs[i - 1]
  34. const nextMsg = msgs[i + 1]
  35. if ((msg.role === "assistant" || msg.role === "tool") && Array.isArray(msg.content)) {
  36. msg.content = msg.content.map((part) => {
  37. if (
  38. (part.type === "tool-call" || part.type === "tool-result") &&
  39. "toolCallId" in part
  40. ) {
  41. // Mistral requires alphanumeric tool call IDs with exactly 9 characters
  42. const normalizedId = part.toolCallId
  43. .replace(/[^a-zA-Z0-9]/g, "") // Remove non-alphanumeric characters
  44. .substring(0, 9) // Take first 9 characters
  45. .padEnd(9, "0") // Pad with zeros if less than 9 characters
  46. return {
  47. ...part,
  48. toolCallId: normalizedId,
  49. }
  50. }
  51. return part
  52. })
  53. }
  54. result.push(msg)
  55. // Fix message sequence: tool messages cannot be followed by user messages
  56. if (msg.role === "tool" && nextMsg?.role === "user") {
  57. result.push({
  58. role: "assistant",
  59. content: [
  60. {
  61. type: "text",
  62. text: "Done.",
  63. },
  64. ],
  65. })
  66. }
  67. }
  68. return result
  69. }
  70. return msgs
  71. }
  72. function applyCaching(msgs: ModelMessage[], providerID: string): ModelMessage[] {
  73. const system = msgs.filter((msg) => msg.role === "system").slice(0, 2)
  74. const final = msgs.filter((msg) => msg.role !== "system").slice(-2)
  75. const providerOptions = {
  76. anthropic: {
  77. cacheControl: { type: "ephemeral" },
  78. },
  79. openrouter: {
  80. cache_control: { type: "ephemeral" },
  81. },
  82. bedrock: {
  83. cachePoint: { type: "ephemeral" },
  84. },
  85. openaiCompatible: {
  86. cache_control: { type: "ephemeral" },
  87. },
  88. }
  89. for (const msg of unique([...system, ...final])) {
  90. const shouldUseContentOptions =
  91. providerID !== "anthropic" && Array.isArray(msg.content) && msg.content.length > 0
  92. if (shouldUseContentOptions) {
  93. const lastContent = msg.content[msg.content.length - 1]
  94. if (lastContent && typeof lastContent === "object") {
  95. lastContent.providerOptions = {
  96. ...lastContent.providerOptions,
  97. ...providerOptions,
  98. }
  99. continue
  100. }
  101. }
  102. msg.providerOptions = {
  103. ...msg.providerOptions,
  104. ...providerOptions,
  105. }
  106. }
  107. return msgs
  108. }
  109. export function message(msgs: ModelMessage[], providerID: string, modelID: string) {
  110. msgs = normalizeMessages(msgs, providerID, modelID)
  111. if (providerID === "anthropic" || modelID.includes("anthropic") || modelID.includes("claude")) {
  112. msgs = applyCaching(msgs, providerID)
  113. }
  114. return msgs
  115. }
  116. export function temperature(_providerID: string, modelID: string) {
  117. if (modelID.toLowerCase().includes("qwen")) return 0.55
  118. if (modelID.toLowerCase().includes("claude")) return undefined
  119. return 0
  120. }
  121. export function topP(_providerID: string, modelID: string) {
  122. if (modelID.toLowerCase().includes("qwen")) return 1
  123. return undefined
  124. }
  125. export function options(
  126. providerID: string,
  127. modelID: string,
  128. sessionID: string,
  129. ): Record<string, any> | undefined {
  130. const result: Record<string, any> = {}
  131. if (providerID === "openai") {
  132. result["promptCacheKey"] = sessionID
  133. }
  134. if (modelID.includes("gpt-5") && !modelID.includes("gpt-5-chat")) {
  135. if (modelID.includes("codex")) {
  136. result["store"] = false
  137. }
  138. if (!modelID.includes("codex") && !modelID.includes("gpt-5-pro")) {
  139. result["reasoningEffort"] = "medium"
  140. }
  141. if (providerID === "opencode") {
  142. result["promptCacheKey"] = sessionID
  143. result["include"] = ["reasoning.encrypted_content"]
  144. result["reasoningSummary"] = "auto"
  145. }
  146. }
  147. return result
  148. }
  149. export function providerOptions(
  150. npm: string | undefined,
  151. providerID: string,
  152. options: { [x: string]: any },
  153. ) {
  154. switch (npm) {
  155. case "@ai-sdk/openai":
  156. case "@ai-sdk/azure":
  157. return {
  158. ["openai" as string]: options,
  159. }
  160. case "@ai-sdk/amazon-bedrock":
  161. return {
  162. ["bedrock" as string]: options,
  163. }
  164. case "@ai-sdk/anthropic":
  165. return {
  166. ["anthropic" as string]: options,
  167. }
  168. default:
  169. return {
  170. [providerID]: options,
  171. }
  172. }
  173. }
  174. export function maxOutputTokens(
  175. providerID: string,
  176. options: Record<string, any>,
  177. modelLimit: number,
  178. globalLimit: number,
  179. ): number {
  180. const modelCap = modelLimit || globalLimit
  181. const standardLimit = Math.min(modelCap, globalLimit)
  182. if (providerID === "anthropic") {
  183. const thinking = options?.["thinking"]
  184. const budgetTokens =
  185. typeof thinking?.["budgetTokens"] === "number" ? thinking["budgetTokens"] : 0
  186. const enabled = thinking?.["type"] === "enabled"
  187. if (enabled && budgetTokens > 0) {
  188. // Return text tokens so that text + thinking <= model cap, preferring 32k text when possible.
  189. if (budgetTokens + standardLimit <= modelCap) {
  190. return standardLimit
  191. }
  192. return modelCap - budgetTokens
  193. }
  194. }
  195. return standardLimit
  196. }
  197. export function schema(_providerID: string, _modelID: string, schema: JSONSchema.BaseSchema) {
  198. /*
  199. if (["openai", "azure"].includes(providerID)) {
  200. if (schema.type === "object" && schema.properties) {
  201. for (const [key, value] of Object.entries(schema.properties)) {
  202. if (schema.required?.includes(key)) continue
  203. schema.properties[key] = {
  204. anyOf: [
  205. value as JSONSchema.JSONSchema,
  206. {
  207. type: "null",
  208. },
  209. ],
  210. }
  211. }
  212. }
  213. }
  214. if (providerID === "google") {
  215. }
  216. */
  217. return schema
  218. }
  219. }