@@ -16,68 +16,6 @@ import { t } from "../../i18n"
 const CEREBRAS_BASE_URL = "https://api.cerebras.ai/v1"
 const CEREBRAS_DEFAULT_TEMPERATURE = 0
 
-/**
- * Removes thinking tokens from text to prevent model confusion when processing conversation history.
- * This is crucial because models can get confused by their own thinking tokens in input.
- */
-function stripThinkingTokens(text: string): string {
-	// Remove <think>...</think> blocks entirely, including nested ones
-	return text.replace(/<think>[\s\S]*?<\/think>/g, "").trim()
-}
-
-/**
- * Flattens OpenAI message content to simple strings that Cerebras can handle.
- * Cerebras doesn't support complex content arrays like OpenAI does.
- */
-function flattenMessageContent(content: any): string {
-	if (typeof content === "string") {
-		return content
-	}
-
-	if (Array.isArray(content)) {
-		return content
-			.map((part) => {
-				if (typeof part === "string") {
-					return part
-				}
-				if (part.type === "text") {
-					return part.text || ""
-				}
-				if (part.type === "image_url") {
-					return "[Image]" // Placeholder for images since Cerebras doesn't support images
-				}
-				return ""
-			})
-			.filter(Boolean)
-			.join("\n")
-	}
-
-	// Fallback for any other content types
-	return String(content || "")
-}
-
-/**
- * Converts OpenAI messages to Cerebras-compatible format with simple string content.
- * Also strips thinking tokens from assistant messages to prevent model confusion.
- */
-function convertToCerebrasMessages(openaiMessages: any[]): Array<{ role: string; content: string }> {
-	return openaiMessages
-		.map((msg) => {
-			let content = flattenMessageContent(msg.content)
-
-			// Strip thinking tokens from assistant messages to prevent confusion
-			if (msg.role === "assistant") {
-				content = stripThinkingTokens(content)
-			}
-
-			return {
-				role: msg.role,
-				content,
-			}
-		})
-		.filter((msg) => msg.content.trim() !== "") // Remove empty messages
-}
-
 export class CerebrasHandler extends BaseProvider implements SingleCompletionHandler {
 	private apiKey: string
 	private providerModels: typeof cerebrasModels
@@ -106,26 +44,70 @@ export class CerebrasHandler extends BaseProvider implements SingleCompletionHan
 		}
 	}
 
+	/**
+	 * Override convertToolSchemaForOpenAI to remove unsupported schema fields for Cerebras.
+	 * Cerebras doesn't support minItems/maxItems in array schemas with strict mode.
+	 */
+	protected override convertToolSchemaForOpenAI(schema: any): any {
+		const converted = super.convertToolSchemaForOpenAI(schema)
+		return this.stripUnsupportedSchemaFields(converted)
+	}
+
+	/**
+	 * Recursively strips unsupported schema fields for Cerebras.
+	 * Cerebras strict mode doesn't support minItems, maxItems on arrays.
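+	 * Illustrative example: { type: "array", items: { type: "string" }, minItems: 1 }
+	 * comes back as { type: "array", items: { type: "string" } }.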
+	 */
+	private stripUnsupportedSchemaFields(schema: any): any {
+		if (!schema || typeof schema !== "object") {
+			return schema
+		}
+
+		const result = { ...schema }
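+		// Shallow copy so the caller's schema object is never mutated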
+
+		// Remove unsupported array constraints
+		if (result.type === "array" || (Array.isArray(result.type) && result.type.includes("array"))) {
+			delete result.minItems
+			delete result.maxItems
+		}
+
+		// Recursively process properties
+		if (result.properties) {
+			const newProps = { ...result.properties }
+			for (const key of Object.keys(newProps)) {
+				newProps[key] = this.stripUnsupportedSchemaFields(newProps[key])
+			}
+			result.properties = newProps
+		}
+
+		// Recursively process array items
+		if (result.items) {
+			result.items = this.stripUnsupportedSchemaFields(result.items)
+		}
+
+		return result
+	}
+
 	async *createMessage(
 		systemPrompt: string,
 		messages: Anthropic.Messages.MessageParam[],
 		metadata?: ApiHandlerCreateMessageMetadata,
 	): ApiStream {
-		const {
-			id: model,
-			info: { maxTokens: max_tokens },
-		} = this.getModel()
+		const { id: model, info: modelInfo } = this.getModel()
+		const max_tokens = modelInfo.maxTokens
+		const supportsNativeTools = modelInfo.supportsNativeTools ?? false
 		const temperature = this.options.modelTemperature ?? CEREBRAS_DEFAULT_TEMPERATURE
 
-		// Convert Anthropic messages to OpenAI format, then flatten for Cerebras
-		// This will automatically strip thinking tokens from assistant messages
+		// Check if we should use native tool calling
+		const useNativeTools =
+			supportsNativeTools && metadata?.tools && metadata.tools.length > 0 && metadata?.toolProtocol !== "xml"
+
+		// Convert Anthropic messages to OpenAI format (Cerebras is OpenAI-compatible)
 		const openaiMessages = convertToOpenAiMessages(messages)
-		const cerebrasMessages = convertToCerebrasMessages(openaiMessages)
 
 		// Prepare request body following Cerebras API specification exactly
-		const requestBody = {
+		const requestBody: Record<string, any> = {
 			model,
-			messages: [{ role: "system", content: systemPrompt }, ...cerebrasMessages],
+			messages: [{ role: "system", content: systemPrompt }, ...openaiMessages],
 			stream: true,
 			// Use max_completion_tokens (Cerebras-specific parameter)
 			...(max_tokens && max_tokens > 0 && max_tokens <= 32768 ? { max_completion_tokens: max_tokens } : {}),
@@ -135,6 +117,10 @@ export class CerebrasHandler extends BaseProvider implements SingleCompletionHan
 					temperature: Math.max(0, Math.min(1.5, temperature)),
 				}
 				: {}),
+			// Native tool calling support
+			...(useNativeTools && { tools: this.convertToolsForOpenAI(metadata.tools) }),
+			...(useNativeTools && metadata.tool_choice && { tool_choice: metadata.tool_choice }),
+			...(useNativeTools && { parallel_tool_calls: metadata?.parallelToolCalls ?? false }),
 		}
 
 		try {
@@ -216,9 +202,11 @@ export class CerebrasHandler extends BaseProvider implements SingleCompletionHan
 
 						const parsed = JSON.parse(jsonStr)
 
+						const delta = parsed.choices?.[0]?.delta
+
 						// Handle text content - parse for thinking tokens
-						if (parsed.choices?.[0]?.delta?.content) {
-							const content = parsed.choices[0].delta.content
+						if (delta?.content) {
+							const content = delta.content
 
 							// Use XmlMatcher to parse <think>...</think> tags
 							for (const chunk of matcher.update(content)) {
@@ -226,6 +214,19 @@ export class CerebrasHandler extends BaseProvider implements SingleCompletionHan
 							}
 						}
 
+						// Handle tool calls in stream - emit partial chunks for NativeToolCallParser
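+						// (OpenAI-style streams typically deliver function arguments as incremental JSON fragments across deltas)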
+						if (delta?.tool_calls) {
+							for (const toolCall of delta.tool_calls) {
+								yield {
+									type: "tool_call_partial",
+									index: toolCall.index,
+									id: toolCall.id,
+									name: toolCall.function?.name,
+									arguments: toolCall.function?.arguments,
+								}
+							}
+						}
+
 						// Handle usage information if available
 						if (parsed.usage) {
 							inputTokens = parsed.usage.prompt_tokens || 0
@@ -248,7 +249,11 @@ export class CerebrasHandler extends BaseProvider implements SingleCompletionHan
 
 			// Provide token usage estimate if not available from API
 			if (inputTokens === 0 || outputTokens === 0) {
-				const inputText = systemPrompt + cerebrasMessages.map((m) => m.content).join("")
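+				// Non-string content (e.g. structured tool results) is stringified so it still counts toward the estimate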
+				const inputText =
+					systemPrompt +
+					openaiMessages
+						.map((m: any) => (typeof m.content === "string" ? m.content : JSON.stringify(m.content)))
+						.join("")
 				inputTokens = inputTokens || Math.ceil(inputText.length / 4) // Rough estimate: 4 chars per token
 				outputTokens = outputTokens || Math.ceil((max_tokens || 1000) / 10) // Rough estimate
 			}