Browse Source

Add native tool calling (NTC) support for Cerebras (#9692)

Matt Rubens 1 month ago
parent
commit
0a2d1a41e4

+ 5 - 0
packages/types/src/providers/cerebras.ts

@@ -11,6 +11,7 @@ export const cerebrasModels = {
 		contextWindow: 131072,
 		supportsImages: false,
 		supportsPromptCache: false,
+		supportsNativeTools: true,
 		inputPrice: 0,
 		outputPrice: 0,
 		description: "Highly intelligent general purpose model with up to 1,000 tokens/s",
@@ -20,6 +21,7 @@ export const cerebrasModels = {
 		contextWindow: 64000,
 		supportsImages: false,
 		supportsPromptCache: false,
+		supportsNativeTools: true,
 		inputPrice: 0,
 		outputPrice: 0,
 		description: "Intelligent model with ~1400 tokens/s",
@@ -29,6 +31,7 @@ export const cerebrasModels = {
 		contextWindow: 64000,
 		supportsImages: false,
 		supportsPromptCache: false,
+		supportsNativeTools: true,
 		inputPrice: 0,
 		outputPrice: 0,
 		description: "Powerful model with ~2600 tokens/s",
@@ -38,6 +41,7 @@ export const cerebrasModels = {
 		contextWindow: 64000,
 		supportsImages: false,
 		supportsPromptCache: false,
+		supportsNativeTools: true,
 		inputPrice: 0,
 		outputPrice: 0,
 		description: "SOTA coding performance with ~2500 tokens/s",
@@ -47,6 +51,7 @@ export const cerebrasModels = {
 		contextWindow: 64000,
 		supportsImages: false,
 		supportsPromptCache: false,
+		supportsNativeTools: true,
 		inputPrice: 0,
 		outputPrice: 0,
 		description:

+ 3 - 1
src/api/providers/base-provider.ts

@@ -20,7 +20,8 @@ export abstract class BaseProvider implements ApiHandler {
 
 	/**
 	 * Converts an array of tools to be compatible with OpenAI's strict mode.
-	 * Filters for function tools and applies schema conversion to their parameters.
+	 * Filters for function tools, applies schema conversion to their parameters,
+	 * and ensures all tools have consistent strict: true values.
 	 */
 	protected convertToolsForOpenAI(tools: any[] | undefined): any[] | undefined {
 		if (!tools) {
@@ -33,6 +34,7 @@ export abstract class BaseProvider implements ApiHandler {
 						...tool,
 						function: {
 							...tool.function,
+							strict: true,
 							parameters: this.convertToolSchemaForOpenAI(tool.function.parameters),
 						},
 					}

+ 79 - 74
src/api/providers/cerebras.ts

@@ -16,68 +16,6 @@ import { t } from "../../i18n"
 const CEREBRAS_BASE_URL = "https://api.cerebras.ai/v1"
 const CEREBRAS_DEFAULT_TEMPERATURE = 0
 
-/**
- * Removes thinking tokens from text to prevent model confusion when processing conversation history.
- * This is crucial because models can get confused by their own thinking tokens in input.
- */
-function stripThinkingTokens(text: string): string {
-	// Remove <think>...</think> blocks entirely, including nested ones
-	return text.replace(/<think>[\s\S]*?<\/think>/g, "").trim()
-}
-
-/**
- * Flattens OpenAI message content to simple strings that Cerebras can handle.
- * Cerebras doesn't support complex content arrays like OpenAI does.
- */
-function flattenMessageContent(content: any): string {
-	if (typeof content === "string") {
-		return content
-	}
-
-	if (Array.isArray(content)) {
-		return content
-			.map((part) => {
-				if (typeof part === "string") {
-					return part
-				}
-				if (part.type === "text") {
-					return part.text || ""
-				}
-				if (part.type === "image_url") {
-					return "[Image]" // Placeholder for images since Cerebras doesn't support images
-				}
-				return ""
-			})
-			.filter(Boolean)
-			.join("\n")
-	}
-
-	// Fallback for any other content types
-	return String(content || "")
-}
-
-/**
- * Converts OpenAI messages to Cerebras-compatible format with simple string content.
- * Also strips thinking tokens from assistant messages to prevent model confusion.
- */
-function convertToCerebrasMessages(openaiMessages: any[]): Array<{ role: string; content: string }> {
-	return openaiMessages
-		.map((msg) => {
-			let content = flattenMessageContent(msg.content)
-
-			// Strip thinking tokens from assistant messages to prevent confusion
-			if (msg.role === "assistant") {
-				content = stripThinkingTokens(content)
-			}
-
-			return {
-				role: msg.role,
-				content,
-			}
-		})
-		.filter((msg) => msg.content.trim() !== "") // Remove empty messages
-}
-
 export class CerebrasHandler extends BaseProvider implements SingleCompletionHandler {
 	private apiKey: string
 	private providerModels: typeof cerebrasModels
@@ -106,26 +44,70 @@ export class CerebrasHandler extends BaseProvider implements SingleCompletionHan
 		}
 	}
 
+	/**
+	 * Override convertToolSchemaForOpenAI to remove unsupported schema fields for Cerebras.
+	 * Cerebras doesn't support minItems/maxItems in array schemas with strict mode.
+	 */
+	protected override convertToolSchemaForOpenAI(schema: any): any {
+		const converted = super.convertToolSchemaForOpenAI(schema)
+		return this.stripUnsupportedSchemaFields(converted)
+	}
+
+	/**
+	 * Recursively strips unsupported schema fields for Cerebras.
+	 * Cerebras strict mode doesn't support minItems, maxItems on arrays.
+	 */
+	private stripUnsupportedSchemaFields(schema: any): any {
+		if (!schema || typeof schema !== "object") {
+			return schema
+		}
+
+		const result = { ...schema }
+
+		// Remove unsupported array constraints
+		if (result.type === "array" || (Array.isArray(result.type) && result.type.includes("array"))) {
+			delete result.minItems
+			delete result.maxItems
+		}
+
+		// Recursively process properties
+		if (result.properties) {
+			const newProps = { ...result.properties }
+			for (const key of Object.keys(newProps)) {
+				newProps[key] = this.stripUnsupportedSchemaFields(newProps[key])
+			}
+			result.properties = newProps
+		}
+
+		// Recursively process array items
+		if (result.items) {
+			result.items = this.stripUnsupportedSchemaFields(result.items)
+		}
+
+		return result
+	}
+
 	async *createMessage(
 		systemPrompt: string,
 		messages: Anthropic.Messages.MessageParam[],
 		metadata?: ApiHandlerCreateMessageMetadata,
 	): ApiStream {
-		const {
-			id: model,
-			info: { maxTokens: max_tokens },
-		} = this.getModel()
+		const { id: model, info: modelInfo } = this.getModel()
+		const max_tokens = modelInfo.maxTokens
+		const supportsNativeTools = modelInfo.supportsNativeTools ?? false
 		const temperature = this.options.modelTemperature ?? CEREBRAS_DEFAULT_TEMPERATURE
 
-		// Convert Anthropic messages to OpenAI format, then flatten for Cerebras
-		// This will automatically strip thinking tokens from assistant messages
+		// Check if we should use native tool calling
+		const useNativeTools =
+			supportsNativeTools && metadata?.tools && metadata.tools.length > 0 && metadata?.toolProtocol !== "xml"
+
+		// Convert Anthropic messages to OpenAI format (Cerebras is OpenAI-compatible)
 		const openaiMessages = convertToOpenAiMessages(messages)
-		const cerebrasMessages = convertToCerebrasMessages(openaiMessages)
 
 		// Prepare request body following Cerebras API specification exactly
-		const requestBody = {
+		const requestBody: Record<string, any> = {
 			model,
-			messages: [{ role: "system", content: systemPrompt }, ...cerebrasMessages],
+			messages: [{ role: "system", content: systemPrompt }, ...openaiMessages],
 			stream: true,
 			// Use max_completion_tokens (Cerebras-specific parameter)
 			...(max_tokens && max_tokens > 0 && max_tokens <= 32768 ? { max_completion_tokens: max_tokens } : {}),
@@ -135,6 +117,10 @@ export class CerebrasHandler extends BaseProvider implements SingleCompletionHan
 						temperature: Math.max(0, Math.min(1.5, temperature)),
 					}
 				: {}),
+			// Native tool calling support
+			...(useNativeTools && { tools: this.convertToolsForOpenAI(metadata.tools) }),
+			...(useNativeTools && metadata.tool_choice && { tool_choice: metadata.tool_choice }),
+			...(useNativeTools && { parallel_tool_calls: metadata?.parallelToolCalls ?? false }),
 		}
 
 		try {
@@ -216,9 +202,11 @@ export class CerebrasHandler extends BaseProvider implements SingleCompletionHan
 
 								const parsed = JSON.parse(jsonStr)
 
+								const delta = parsed.choices?.[0]?.delta
+
 								// Handle text content - parse for thinking tokens
-								if (parsed.choices?.[0]?.delta?.content) {
-									const content = parsed.choices[0].delta.content
+								if (delta?.content) {
+									const content = delta.content
 
 									// Use XmlMatcher to parse <think>...</think> tags
 									for (const chunk of matcher.update(content)) {
@@ -226,6 +214,19 @@ export class CerebrasHandler extends BaseProvider implements SingleCompletionHan
 									}
 								}
 
+								// Handle tool calls in stream - emit partial chunks for NativeToolCallParser
+								if (delta?.tool_calls) {
+									for (const toolCall of delta.tool_calls) {
+										yield {
+											type: "tool_call_partial",
+											index: toolCall.index,
+											id: toolCall.id,
+											name: toolCall.function?.name,
+											arguments: toolCall.function?.arguments,
+										}
+									}
+								}
+
 								// Handle usage information if available
 								if (parsed.usage) {
 									inputTokens = parsed.usage.prompt_tokens || 0
@@ -248,7 +249,11 @@ export class CerebrasHandler extends BaseProvider implements SingleCompletionHan
 
 			// Provide token usage estimate if not available from API
 			if (inputTokens === 0 || outputTokens === 0) {
-				const inputText = systemPrompt + cerebrasMessages.map((m) => m.content).join("")
+				const inputText =
+					systemPrompt +
+					openaiMessages
+						.map((m: any) => (typeof m.content === "string" ? m.content : JSON.stringify(m.content)))
+						.join("")
 				inputTokens = inputTokens || Math.ceil(inputText.length / 4) // Rough estimate: 4 chars per token
 				outputTokens = outputTokens || Math.ceil((max_tokens || 1000) / 10) // Rough estimate
 			}