|
|
@@ -3,21 +3,77 @@ import { unique } from "remeda"
|
|
|
import type { JSONSchema } from "zod/v4/core"
|
|
|
|
|
|
export namespace ProviderTransform {
|
|
|
- function normalizeToolCallIds(msgs: ModelMessage[]): ModelMessage[] {
|
|
|
- return msgs.map((msg) => {
|
|
|
- if ((msg.role === "assistant" || msg.role === "tool") && Array.isArray(msg.content)) {
|
|
|
- msg.content = msg.content.map((part) => {
|
|
|
- if ((part.type === "tool-call" || part.type === "tool-result") && "toolCallId" in part) {
|
|
|
- return {
|
|
|
- ...part,
|
|
|
- toolCallId: part.toolCallId.replace(/[^a-zA-Z0-9_-]/g, "_"),
|
|
|
+ function normalizeMessages(
|
|
|
+ msgs: ModelMessage[],
|
|
|
+ providerID: string,
|
|
|
+ modelID: string,
|
|
|
+ ): ModelMessage[] {
|
|
|
+ if (modelID.includes("claude")) {
|
|
|
+ return msgs.map((msg) => {
|
|
|
+ if ((msg.role === "assistant" || msg.role === "tool") && Array.isArray(msg.content)) {
|
|
|
+ msg.content = msg.content.map((part) => {
|
|
|
+ if (
|
|
|
+ (part.type === "tool-call" || part.type === "tool-result") &&
|
|
|
+ "toolCallId" in part
|
|
|
+ ) {
|
|
|
+ return {
|
|
|
+ ...part,
|
|
|
+ toolCallId: part.toolCallId.replace(/[^a-zA-Z0-9_-]/g, "_"),
|
|
|
+ }
|
|
|
}
|
|
|
- }
|
|
|
- return part
|
|
|
- })
|
|
|
+ return part
|
|
|
+ })
|
|
|
+ }
|
|
|
+ return msg
|
|
|
+ })
|
|
|
+ }
|
|
|
+ if (providerID === "mistral" || modelID.toLowerCase().includes("mistral")) {
|
|
|
+ const result: ModelMessage[] = []
|
|
|
+ for (let i = 0; i < msgs.length; i++) {
|
|
|
+ const msg = msgs[i]
|
|
|
+ const prevMsg = msgs[i - 1]
|
|
|
+ const nextMsg = msgs[i + 1]
|
|
|
+
|
|
|
+ if ((msg.role === "assistant" || msg.role === "tool") && Array.isArray(msg.content)) {
|
|
|
+ msg.content = msg.content.map((part) => {
|
|
|
+ if (
|
|
|
+ (part.type === "tool-call" || part.type === "tool-result") &&
|
|
|
+ "toolCallId" in part
|
|
|
+ ) {
|
|
|
+ // Mistral requires alphanumeric tool call IDs with exactly 9 characters
|
|
|
+ const normalizedId = part.toolCallId
|
|
|
+ .replace(/[^a-zA-Z0-9]/g, "") // Remove non-alphanumeric characters
|
|
|
+ .substring(0, 9) // Take first 9 characters
|
|
|
+ .padEnd(9, "0") // Pad with zeros if less than 9 characters
|
|
|
+
|
|
|
+ return {
|
|
|
+ ...part,
|
|
|
+ toolCallId: normalizedId,
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return part
|
|
|
+ })
|
|
|
+ }
|
|
|
+
|
|
|
+ result.push(msg)
|
|
|
+
|
|
|
+ // Fix message sequence: tool messages cannot be followed by user messages
|
|
|
+ if (msg.role === "tool" && nextMsg?.role === "user") {
|
|
|
+ result.push({
|
|
|
+ role: "assistant",
|
|
|
+ content: [
|
|
|
+ {
|
|
|
+ type: "text",
|
|
|
+ text: "Done.",
|
|
|
+ },
|
|
|
+ ],
|
|
|
+ })
|
|
|
+ }
|
|
|
}
|
|
|
- return msg
|
|
|
- })
|
|
|
+ return result
|
|
|
+ }
|
|
|
+
|
|
|
+ return msgs
|
|
|
}
|
|
|
|
|
|
function applyCaching(msgs: ModelMessage[], providerID: string): ModelMessage[] {
|
|
|
@@ -64,9 +120,7 @@ export namespace ProviderTransform {
|
|
|
}
|
|
|
|
|
|
export function message(msgs: ModelMessage[], providerID: string, modelID: string) {
|
|
|
- if (modelID.includes("claude")) {
|
|
|
- msgs = normalizeToolCallIds(msgs)
|
|
|
- }
|
|
|
+ msgs = normalizeMessages(msgs, providerID, modelID)
|
|
|
if (providerID === "anthropic" || modelID.includes("anthropic") || modelID.includes("claude")) {
|
|
|
msgs = applyCaching(msgs, providerID)
|
|
|
}
|