@@ -1,33 +1,64 @@
// npx vitest run api/providers/__tests__/chutes.spec.ts

-import { vitest, describe, it, expect, beforeEach } from "vitest"
-import OpenAI from "openai"
import { Anthropic } from "@anthropic-ai/sdk"
+import { afterEach, beforeEach, describe, expect, it, vi } from "vitest"
+import OpenAI from "openai"

-import { type ChutesModelId, chutesDefaultModelId, chutesModels } from "@roo-code/types"
+import { type ChutesModelId, chutesDefaultModelId, chutesModels, DEEP_SEEK_DEFAULT_TEMPERATURE } from "@roo-code/types"

import { ChutesHandler } from "../chutes"

-const mockCreate = vitest.fn()
+// Create mock functions
+const mockCreate = vi.fn()

-vitest.mock("openai", () => {
-	return {
-		default: vitest.fn().mockImplementation(() => ({
-			chat: {
-				completions: {
-					create: mockCreate,
-				},
+// Mock OpenAI module
+vi.mock("openai", () => ({
+	default: vi.fn(() => ({
+		chat: {
+			completions: {
+				create: mockCreate,
			},
-		})),
-	}
-})
+		},
+	})),
+}))

describe("ChutesHandler", () => {
	let handler: ChutesHandler

	beforeEach(() => {
-		vitest.clearAllMocks()
-		handler = new ChutesHandler({ chutesApiKey: "test-chutes-api-key" })
+		vi.clearAllMocks()
+		// Set up default mock implementation
+		mockCreate.mockImplementation(async () => ({
+			[Symbol.asyncIterator]: async function* () {
+				yield {
+					choices: [
+						{
+							delta: { content: "Test response" },
+							index: 0,
+						},
+					],
+					usage: null,
+				}
+				yield {
+					choices: [
+						{
+							delta: {},
+							index: 0,
+						},
+					],
+					usage: {
+						prompt_tokens: 10,
+						completion_tokens: 5,
+						total_tokens: 15,
+					},
+				}
+			},
+		}))
+		handler = new ChutesHandler({ chutesApiKey: "test-key" })
+	})
+
+	afterEach(() => {
+		vi.restoreAllMocks()
	})

	it("should use the correct Chutes base URL", () => {
@@ -41,18 +72,96 @@ describe("ChutesHandler", () => {
		expect(OpenAI).toHaveBeenCalledWith(expect.objectContaining({ apiKey: chutesApiKey }))
	})

+	it("should handle DeepSeek R1 reasoning format", async () => {
+		// Override the mock for this specific test
+		mockCreate.mockImplementationOnce(async () => ({
+			[Symbol.asyncIterator]: async function* () {
+				yield {
+					choices: [
+						{
+							delta: { content: "<think>Thinking..." },
+							index: 0,
+						},
+					],
+					usage: null,
+				}
+				yield {
+					choices: [
+						{
+							delta: { content: "</think>Hello" },
+							index: 0,
+						},
+					],
+					usage: null,
+				}
+				yield {
+					choices: [
+						{
+							delta: {},
+							index: 0,
+						},
+					],
+					usage: { prompt_tokens: 10, completion_tokens: 5 },
+				}
+			},
+		}))
+
+		const systemPrompt = "You are a helpful assistant."
+		const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Hi" }]
+		vi.spyOn(handler, "getModel").mockReturnValue({
+			id: "deepseek-ai/DeepSeek-R1-0528",
+			info: { maxTokens: 1024, temperature: 0.7 },
+		} as any)
+
+		const stream = handler.createMessage(systemPrompt, messages)
+		const chunks = []
+		for await (const chunk of stream) {
+			chunks.push(chunk)
+		}
+
+		expect(chunks).toEqual([
+			{ type: "reasoning", text: "Thinking..." },
+			{ type: "text", text: "Hello" },
+			{ type: "usage", inputTokens: 10, outputTokens: 5 },
+		])
+	})
+
+	it("should fall back to base provider for non-DeepSeek models", async () => {
+		// Use default mock implementation which returns text content
+		const systemPrompt = "You are a helpful assistant."
+		const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Hi" }]
+		vi.spyOn(handler, "getModel").mockReturnValue({
+			id: "some-other-model",
+			info: { maxTokens: 1024, temperature: 0.7 },
+		} as any)
+
+		const stream = handler.createMessage(systemPrompt, messages)
+		const chunks = []
+		for await (const chunk of stream) {
+			chunks.push(chunk)
+		}
+
+		expect(chunks).toEqual([
+			{ type: "text", text: "Test response" },
+			{ type: "usage", inputTokens: 10, outputTokens: 5 },
+		])
+	})
+
it("should return default model when no model is specified", () => {
|
|
|
const model = handler.getModel()
|
|
|
expect(model.id).toBe(chutesDefaultModelId)
|
|
|
- expect(model.info).toEqual(chutesModels[chutesDefaultModelId])
|
|
|
+ expect(model.info).toEqual(expect.objectContaining(chutesModels[chutesDefaultModelId]))
|
|
|
})
|
|
|
|
|
|
it("should return specified model when valid model is provided", () => {
|
|
|
const testModelId: ChutesModelId = "deepseek-ai/DeepSeek-R1"
|
|
|
- const handlerWithModel = new ChutesHandler({ apiModelId: testModelId, chutesApiKey: "test-chutes-api-key" })
|
|
|
+ const handlerWithModel = new ChutesHandler({
|
|
|
+ apiModelId: testModelId,
|
|
|
+ chutesApiKey: "test-chutes-api-key",
|
|
|
+ })
|
|
|
const model = handlerWithModel.getModel()
|
|
|
expect(model.id).toBe(testModelId)
|
|
|
- expect(model.info).toEqual(chutesModels[testModelId])
|
|
|
+ expect(model.info).toEqual(expect.objectContaining(chutesModels[testModelId]))
|
|
|
	})

	it("completePrompt method should return text from Chutes API", async () => {
@@ -74,7 +183,7 @@ describe("ChutesHandler", () => {
		mockCreate.mockImplementationOnce(() => {
			return {
				[Symbol.asyncIterator]: () => ({
-					next: vitest
+					next: vi
						.fn()
						.mockResolvedValueOnce({
							done: false,
@@ -96,7 +205,7 @@ describe("ChutesHandler", () => {
		mockCreate.mockImplementationOnce(() => {
			return {
				[Symbol.asyncIterator]: () => ({
-					next: vitest
+					next: vi
						.fn()
						.mockResolvedValueOnce({
							done: false,
@@ -114,8 +223,43 @@ describe("ChutesHandler", () => {
		expect(firstChunk.value).toEqual({ type: "usage", inputTokens: 10, outputTokens: 20 })
	})

-	it("createMessage should pass correct parameters to Chutes client", async () => {
+	it("createMessage should pass correct parameters to Chutes client for DeepSeek R1", async () => {
		const modelId: ChutesModelId = "deepseek-ai/DeepSeek-R1"
+
+		// Clear previous mocks and set up new implementation
+		mockCreate.mockClear()
+		mockCreate.mockImplementationOnce(async () => ({
+			[Symbol.asyncIterator]: async function* () {
+				// Empty stream for this test
+			},
+		}))
+
+		const handlerWithModel = new ChutesHandler({
+			apiModelId: modelId,
+			chutesApiKey: "test-chutes-api-key",
+		})
+
+		const systemPrompt = "Test system prompt for Chutes"
+		const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Test message for Chutes" }]
+
+		const messageGenerator = handlerWithModel.createMessage(systemPrompt, messages)
+		await messageGenerator.next()
+
+		expect(mockCreate).toHaveBeenCalledWith(
+			expect.objectContaining({
+				model: modelId,
+				messages: [
+					{
+						role: "user",
+						content: `${systemPrompt}\n${messages[0].content}`,
+					},
+				],
+			}),
+		)
+	})
+
+	it("createMessage should pass correct parameters to Chutes client for non-DeepSeek models", async () => {
+		const modelId: ChutesModelId = "unsloth/Llama-3.3-70B-Instruct"
		const modelInfo = chutesModels[modelId]
		const handlerWithModel = new ChutesHandler({ apiModelId: modelId, chutesApiKey: "test-chutes-api-key" })

@@ -146,4 +290,24 @@ describe("ChutesHandler", () => {
			}),
		)
	})
+
+	it("should apply DeepSeek default temperature for R1 models", () => {
+		const testModelId: ChutesModelId = "deepseek-ai/DeepSeek-R1"
+		const handlerWithModel = new ChutesHandler({
+			apiModelId: testModelId,
+			chutesApiKey: "test-chutes-api-key",
+		})
+		const model = handlerWithModel.getModel()
+		expect(model.info.temperature).toBe(DEEP_SEEK_DEFAULT_TEMPERATURE)
+	})
+
+	it("should use default temperature for non-DeepSeek models", () => {
+		const testModelId: ChutesModelId = "unsloth/Llama-3.3-70B-Instruct"
+		const handlerWithModel = new ChutesHandler({
+			apiModelId: testModelId,
+			chutesApiKey: "test-chutes-api-key",
+		})
+		const model = handlerWithModel.getModel()
+		expect(model.info.temperature).toBe(0.5)
+	})
})