transform.ts 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417
  1. import type { APICallError, ModelMessage } from "ai"
  2. import { unique } from "remeda"
  3. import type { JSONSchema } from "zod/v4/core"
  4. import type { Provider } from "./provider"
  5. import type { ModelsDev } from "./models"
  /** Element type of the `input` list in a model's declared `modalities`. */
  type Modality = NonNullable<ModelsDev.Model["modalities"]>["input"][number]

  /**
   * Maps a MIME type to the input modality it belongs to.
   *
   * @param mime - MIME type string, e.g. "image/png" or "application/pdf".
   * @returns "image", "audio" or "video" when `mime` starts with the matching
   *   type prefix, "pdf" when it is exactly "application/pdf", otherwise `undefined`.
   */
  function mimeToModality(mime: string): Modality | undefined {
    const byPrefix = [
      ["image/", "image"],
      ["audio/", "audio"],
      ["video/", "video"],
    ] as const
    for (const [prefix, modality] of byPrefix) {
      if (mime.startsWith(prefix)) return modality
    }
    // PDF is matched exactly, not by prefix.
    return mime === "application/pdf" ? "pdf" : undefined
  }
  14. export namespace ProviderTransform {
  /**
   * Returns a copy of `msg` in which every tool-call / tool-result part has its
   * `toolCallId` passed through `rewrite`. Messages that are not assistant/tool
   * messages with array content are returned unchanged (same reference).
   * The input message is never mutated.
   */
  function rewriteToolCallIds(msg: ModelMessage, rewrite: (id: string) => string): ModelMessage {
    if ((msg.role === "assistant" || msg.role === "tool") && Array.isArray(msg.content)) {
      const content = msg.content.map((part) => {
        if ((part.type === "tool-call" || part.type === "tool-result") && "toolCallId" in part) {
          return { ...part, toolCallId: rewrite(part.toolCallId) }
        }
        return part
      })
      return { ...msg, content } as ModelMessage
    }
    return msg
  }

  /**
   * Applies model-specific message fix-ups. Exactly one branch runs, checked in this order:
   *
   * 1. API id contains "claude": characters outside `[a-zA-Z0-9_-]` in tool call IDs
   *    are replaced with "_".
   * 2. Provider is "mistral" or API id contains "mistral" (case-insensitive): tool call IDs
   *    are reduced to exactly 9 alphanumeric characters, and an assistant "Done." message is
   *    inserted between a tool message and a directly following user message.
   * 3. Provider is "deepseek", API id contains "deepseek" (case-insensitive), or the model
   *    declares interleaved reasoning via the "reasoning_content" field: reasoning parts are
   *    removed from assistant content and their text is attached as
   *    `providerOptions.openaiCompatible.reasoning_content`.
   *
   * Otherwise `msgs` is returned as-is. The input messages are never mutated; changed
   * messages are returned as copies.
   *
   * @param msgs - Conversation messages to normalize.
   * @param model - Target model; `providerID`, `id`, `api.id` and `capabilities.interleaved` are read.
   * @returns The normalized message list.
   */
  function normalizeMessages(msgs: ModelMessage[], model: Provider.Model): ModelMessage[] {
    if (model.api.id.includes("claude")) {
      return msgs.map((msg) => rewriteToolCallIds(msg, (id) => id.replace(/[^a-zA-Z0-9_-]/g, "_")))
    }

    if (model.providerID === "mistral" || model.api.id.toLowerCase().includes("mistral")) {
      // Mistral requires alphanumeric tool call IDs with exactly 9 characters
      const toMistralId = (id: string) =>
        id
          .replace(/[^a-zA-Z0-9]/g, "") // Remove non-alphanumeric characters
          .substring(0, 9) // Take first 9 characters
          .padEnd(9, "0") // Pad with zeros if less than 9 characters

      const result: ModelMessage[] = []
      for (let i = 0; i < msgs.length; i++) {
        const msg = msgs[i]
        const nextMsg = msgs[i + 1]
        result.push(rewriteToolCallIds(msg, toMistralId))

        // Fix message sequence: tool messages cannot be followed by user messages
        if (msg.role === "tool" && nextMsg?.role === "user") {
          result.push({
            role: "assistant",
            content: [
              {
                type: "text",
                text: "Done.",
              },
            ],
          })
        }
      }
      return result
    }

    // TODO: rm later
    // These model/provider pairs are excluded from the interleaved reasoning_content handling below.
    const bugged =
      (model.id === "kimi-k2-thinking" && model.providerID === "opencode") ||
      (model.id === "moonshotai/Kimi-K2-Thinking" && model.providerID === "baseten")

    if (
      model.providerID === "deepseek" ||
      model.api.id.toLowerCase().includes("deepseek") ||
      (model.capabilities.interleaved &&
        typeof model.capabilities.interleaved === "object" &&
        model.capabilities.interleaved.field === "reasoning_content" &&
        !bugged)
    ) {
      return msgs.map((msg) => {
        if (msg.role === "assistant" && Array.isArray(msg.content)) {
          const reasoningParts = msg.content.filter((part: any) => part.type === "reasoning")
          const reasoningText = reasoningParts.map((part: any) => part.text).join("")

          // Filter out reasoning parts from content
          const filteredContent = msg.content.filter((part: any) => part.type !== "reasoning")

          // Include reasoning_content directly on the message for all assistant messages
          if (reasoningText) {
            return {
              ...msg,
              content: filteredContent,
              providerOptions: {
                ...msg.providerOptions,
                openaiCompatible: {
                  ...(msg.providerOptions as any)?.openaiCompatible,
                  reasoning_content: reasoningText,
                },
              },
            }
          }

          return {
            ...msg,
            content: filteredContent,
          }
        }
        return msg
      })
    }

    return msgs
  }
  /**
   * Marks the first two system messages and the last two non-system messages with
   * ephemeral cache-control provider options.
   *
   * For providers other than "anthropic", when a selected message has non-empty array
   * content whose last element is an object, the options are merged onto that last
   * content part; in every other case they are merged onto the message itself.
   *
   * The selected messages (or their last content part) are modified in place.
   *
   * @param msgs - Conversation messages.
   * @param providerID - Provider identifier; only compared against "anthropic".
   * @returns The same `msgs` array that was passed in.
   */
  function applyCaching(msgs: ModelMessage[], providerID: string): ModelMessage[] {
    // One cache marker per provider-options namespace; all are attached together.
    const cacheMarkers = {
      anthropic: {
        cacheControl: { type: "ephemeral" },
      },
      openrouter: {
        cache_control: { type: "ephemeral" },
      },
      bedrock: {
        cachePoint: { type: "ephemeral" },
      },
      openaiCompatible: {
        cache_control: { type: "ephemeral" },
      },
    }

    const leadingSystem = msgs.filter((candidate) => candidate.role === "system").slice(0, 2)
    const trailing = msgs.filter((candidate) => candidate.role !== "system").slice(-2)
    const markOnParts = providerID !== "anthropic"

    // `unique` drops duplicate references so no message is processed twice.
    unique([...leadingSystem, ...trailing]).forEach((target) => {
      if (markOnParts && Array.isArray(target.content) && target.content.length > 0) {
        const tail = target.content[target.content.length - 1]
        if (tail && typeof tail === "object") {
          tail.providerOptions = {
            ...tail.providerOptions,
            ...cacheMarkers,
          }
          return
        }
      }

      target.providerOptions = {
        ...target.providerOptions,
        ...cacheMarkers,
      }
    })

    return msgs
  }
  /**
   * Replaces file/image parts of user messages that the model cannot accept with a
   * text part carrying an error notice.
   *
   * Only user messages with array content are inspected. A part is replaced when its
   * MIME type maps to a known modality and `model.capabilities.input[modality]` is falsy;
   * parts with an unrecognized MIME type are kept untouched.
   *
   * @param msgs - Conversation messages.
   * @param model - Target model whose `capabilities.input` flags are consulted.
   * @returns A new array; inspected user messages are shallow copies with rewritten content.
   */
  function unsupportedParts(msgs: ModelMessage[], model: Provider.Model): ModelMessage[] {
    return msgs.map((msg) => {
      if (msg.role !== "user" || !Array.isArray(msg.content)) return msg

      const content = msg.content.map((part) => {
        let mime: string
        let label: string | undefined
        if (part.type === "image") {
          // Takes the text before the first ";" with a "data:" marker stripped (the MIME type of a data URL).
          mime = part.image.toString().split(";")[0].replace("data:", "")
        } else if (part.type === "file") {
          mime = part.mediaType
          label = part.filename
        } else {
          return part
        }

        const modality = mimeToModality(mime)
        if (!modality || model.capabilities.input[modality]) return part

        const name = label ? `"${label}"` : modality
        return {
          type: "text" as const,
          text: `ERROR: Cannot read ${name} (this model does not support ${modality} input). Inform the user.`,
        }
      })

      return { ...msg, content }
    })
  }
  /**
   * Runs the full message pipeline for a model: unsupported-part replacement, then
   * model-specific normalization, then cache markers when the provider is "anthropic"
   * or the API id contains "anthropic" or "claude".
   *
   * @param msgs - Conversation messages.
   * @param model - Target model.
   * @returns The transformed messages.
   */
  export function message(msgs: ModelMessage[], model: Provider.Model) {
    const normalized = normalizeMessages(unsupportedParts(msgs, model), model)
    const apiID = model.api.id
    const wantsCaching = model.providerID === "anthropic" || apiID.includes("anthropic") || apiID.includes("claude")
    return wantsCaching ? applyCaching(normalized, model.providerID) : normalized
  }
  /**
   * Picks the sampling temperature for a model from its API id (case-insensitive).
   *
   * @param model - Target model; only `api.id` is read.
   * @returns 0.55 for ids containing "qwen", `undefined` for "claude", 1.0 for
   *   "gemini-3-pro", and 0 for everything else. Checks run in that order.
   */
  export function temperature(model: Provider.Model) {
    const apiID = model.api.id.toLowerCase()
    if (apiID.includes("qwen")) return 0.55
    if (apiID.includes("claude")) return undefined
    if (apiID.includes("gemini-3-pro")) return 1.0
    return 0
  }
  /**
   * Picks the nucleus-sampling value for a model.
   *
   * @param model - Target model; only `api.id` is read.
   * @returns 1 when the API id contains "qwen" (case-insensitive), otherwise `undefined`.
   */
  export function topP(model: Provider.Model) {
    return model.api.id.toLowerCase().includes("qwen") ? 1 : undefined
  }
  /**
   * Builds the default provider option bag for a model.
   *
   * @param model - Target model; `providerID`, `api.npm` and `api.id` are read.
   * @param sessionID - Session identifier, used as the prompt cache key where applicable.
   * @param providerOptions - Optional provider settings; only a truthy `setCacheKey` is consulted.
   * @returns A fresh options object (possibly empty).
   */
  export function options(
    model: Provider.Model,
    sessionID: string,
    providerOptions?: Record<string, any>,
  ): Record<string, any> {
    const { npm, id: apiID } = model.api
    const providerID = model.providerID
    const out: Record<string, any> = {}

    // OpenRouter: always request usage accounting.
    if (npm === "@openrouter/ai-sdk-provider") {
      out.usage = { include: true }
      if (apiID.includes("gemini-3")) out.reasoning = { effort: "high" }
    }

    if (providerID === "baseten") {
      out.chat_template_args = { enable_thinking: true }
    }

    if (providerID === "openai" || providerOptions?.setCacheKey) {
      out.promptCacheKey = sessionID
    }

    // Google SDKs: include thoughts in the response.
    if (npm === "@ai-sdk/google" || npm === "@ai-sdk/google-vertex") {
      out.thinkingConfig = apiID.includes("gemini-3")
        ? { includeThoughts: true, thinkingLevel: "high" }
        : { includeThoughts: true }
    }

    // GPT-5 family, excluding the gpt-5-chat variants.
    const isGpt5 = apiID.includes("gpt-5") && !apiID.includes("gpt-5-chat")
    if (isGpt5) {
      if (providerID.includes("codex")) out.store = false

      const skipsEffort = apiID.includes("codex") || apiID.includes("gpt-5-pro")
      if (!skipsEffort) out.reasoningEffort = "medium"

      if (apiID.endsWith("gpt-5.1") && providerID !== "azure") out.textVerbosity = "low"

      if (providerID.startsWith("opencode")) {
        out.promptCacheKey = sessionID
        out.include = ["reasoning.encrypted_content"]
        out.reasoningSummary = "auto"
      }
    }

    return out
  }
  /**
   * Builds a reduced-reasoning option bag for a model.
   *
   * - Provider "openai" or an API id containing "gpt-5": `reasoningEffort` is "low" when
   *   the id contains "5.1", otherwise "minimal".
   * - Provider "google": `thinkingConfig.thinkingBudget` is 0.
   *
   * @param model - Target model; `providerID` and `api.id` are read.
   * @returns A fresh options object (possibly empty).
   */
  export function smallOptions(model: Provider.Model) {
    const out: Record<string, any> = {}
    const apiID = model.api.id

    if (model.providerID === "openai" || apiID.includes("gpt-5")) {
      out.reasoningEffort = apiID.includes("5.1") ? "low" : "minimal"
    }
    if (model.providerID === "google") {
      out.thinkingConfig = { thinkingBudget: 0 }
    }

    return out
  }
  /** Option-bag key used for each recognized SDK package name. */
  const PROVIDER_OPTION_KEYS = new Map<string, string>([
    ["@ai-sdk/openai", "openai"],
    ["@ai-sdk/azure", "openai"],
    ["@ai-sdk/amazon-bedrock", "bedrock"],
    ["@ai-sdk/anthropic", "anthropic"],
    ["@ai-sdk/google", "google"],
    ["@ai-sdk/gateway", "gateway"],
    ["@openrouter/ai-sdk-provider", "openrouter"],
  ])

  /**
   * Wraps `options` under the key that corresponds to the model's SDK package.
   *
   * @param model - Target model; `api.npm` selects the key, `providerID` is the fallback key.
   * @param options - Option bag to wrap (stored by reference, not copied).
   * @returns An object with a single property holding `options`.
   */
  export function providerOptions(model: Provider.Model, options: { [x: string]: any }) {
    // Unrecognized packages fall back to the model's own provider id.
    const key = PROVIDER_OPTION_KEYS.get(model.api.npm) ?? model.providerID
    return {
      [key]: options,
    }
  }
  /**
   * Computes the output-token limit to request for a call.
   *
   * The standard limit is `min(modelLimit || globalLimit, globalLimit)`. For
   * "@ai-sdk/anthropic" with thinking enabled and a positive `budgetTokens`, the result is
   * the text-token allowance such that text + thinking budget stays within the model cap.
   *
   * @param npm - SDK package name of the provider.
   * @param options - Provider options; only `thinking.type` and `thinking.budgetTokens` are read.
   * @param modelLimit - The model's own output limit; a falsy value (e.g. 0) falls back to `globalLimit`.
   * @param globalLimit - Global output limit.
   * @returns A non-negative token count.
   */
  export function maxOutputTokens(
    npm: string,
    options: Record<string, any>,
    modelLimit: number,
    globalLimit: number,
  ): number {
    const modelCap = modelLimit || globalLimit
    const standardLimit = Math.min(modelCap, globalLimit)

    if (npm === "@ai-sdk/anthropic") {
      const thinking = options?.["thinking"]
      const budgetTokens = typeof thinking?.["budgetTokens"] === "number" ? thinking["budgetTokens"] : 0
      const enabled = thinking?.["type"] === "enabled"
      if (enabled && budgetTokens > 0) {
        // Return text tokens so that text + thinking <= model cap, preferring the standard limit when it fits.
        if (budgetTokens + standardLimit <= modelCap) {
          return standardLimit
        }
        // A thinking budget larger than the model cap would otherwise yield a negative count.
        return Math.max(modelCap - budgetTokens, 0)
      }
    }

    return standardLimit
  }
  /**
   * Adapts a JSON schema to the target model.
   *
   * For provider "google" or an API id containing "gemini", a sanitized deep copy is returned:
   * - every `enum` array has its values converted to strings;
   * - a node whose `enum` is an array and whose `type` is "integer" or "number" gets
   *   `type: "string"`, regardless of the order in which the keys appear;
   * - on object nodes, `required` is filtered down to names present in `properties`.
   *
   * For all other models the schema is returned unchanged (same reference).
   *
   * @param model - Target model; `providerID` and `api.id` are read.
   * @param schema - JSON schema to adapt; never mutated.
   * @returns The adapted schema.
   */
  export function schema(model: Provider.Model, schema: JSONSchema.BaseSchema) {
    // Disabled: would make optional properties nullable for OpenAI/Azure.
    /*
    if (["openai", "azure"].includes(providerID)) {
      if (schema.type === "object" && schema.properties) {
        for (const [key, value] of Object.entries(schema.properties)) {
          if (schema.required?.includes(key)) continue
          schema.properties[key] = {
            anyOf: [
              value as JSONSchema.JSONSchema,
              {
                type: "null",
              },
            ],
          }
        }
      }
    }
    */

    // Convert integer enums to string enums for Google/Gemini
    if (model.providerID === "google" || model.api.id.includes("gemini")) {
      const sanitizeGemini = (obj: any): any => {
        if (obj === null || typeof obj !== "object") {
          return obj
        }

        if (Array.isArray(obj)) {
          return obj.map(sanitizeGemini)
        }

        const result: any = {}
        for (const [key, value] of Object.entries(obj)) {
          if (key === "enum" && Array.isArray(value)) {
            // Convert all enum values to strings
            result[key] = value.map((v) => String(v))
          } else if (typeof value === "object" && value !== null) {
            result[key] = sanitizeGemini(value)
          } else {
            result[key] = value
          }
        }

        // Done after all keys are copied so the outcome does not depend on whether
        // "type" or "enum" comes first in the source object.
        if (Array.isArray(result.enum) && (result.type === "integer" || result.type === "number")) {
          result.type = "string"
        }

        // Filter required array to only include fields that exist in properties
        if (result.type === "object" && result.properties && Array.isArray(result.required)) {
          result.required = result.required.filter((field: any) => field in result.properties)
        }

        return result
      }

      schema = sanitizeGemini(schema)
    }

    return schema
  }
  /**
   * Produces the user-facing message for a provider API error.
   *
   * @param providerID - Provider that raised the error.
   * @param error - The API call error; only `message` is read.
   * @returns The error message, with a Copilot settings hint appended when the provider is
   *   "github-copilot" and the message contains "The requested model is not supported".
   */
  export function error(providerID: string, error: APICallError) {
    const base = error.message
    const isCopilotUnsupported =
      providerID === "github-copilot" && base.includes("The requested model is not supported")
    if (!isCopilotUnsupported) return base

    const hint =
      "\n\nMake sure the model is enabled in your copilot settings: https://github.com/settings/copilot/features"
    return base + hint
  }
  363. }