|
|
@@ -1,67 +1,28 @@
|
|
|
-// Mocks must come first, before imports
|
|
|
-const mockCreate = vi.fn()
|
|
|
-vi.mock("openai", () => {
|
|
|
+// Use vi.hoisted to define mock functions that can be referenced in hoisted vi.mock() calls
|
|
|
+const { mockStreamText, mockGenerateText } = vi.hoisted(() => ({
|
|
|
+ mockStreamText: vi.fn(),
|
|
|
+ mockGenerateText: vi.fn(),
|
|
|
+}))
|
|
|
+
|
|
|
+vi.mock("ai", async (importOriginal) => {
|
|
|
+ const actual = await importOriginal<typeof import("ai")>()
|
|
|
return {
|
|
|
- __esModule: true,
|
|
|
- default: vi.fn().mockImplementation(() => ({
|
|
|
- chat: {
|
|
|
- completions: {
|
|
|
- create: mockCreate.mockImplementation(async (options) => {
|
|
|
- if (!options.stream) {
|
|
|
- return {
|
|
|
- id: "test-completion",
|
|
|
- choices: [
|
|
|
- {
|
|
|
- message: { role: "assistant", content: "Test response", refusal: null },
|
|
|
- finish_reason: "stop",
|
|
|
- index: 0,
|
|
|
- },
|
|
|
- ],
|
|
|
- usage: {
|
|
|
- prompt_tokens: 10,
|
|
|
- completion_tokens: 5,
|
|
|
- total_tokens: 15,
|
|
|
- cached_tokens: 2,
|
|
|
- },
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- // Return async iterator for streaming
|
|
|
- return {
|
|
|
- [Symbol.asyncIterator]: async function* () {
|
|
|
- yield {
|
|
|
- choices: [
|
|
|
- {
|
|
|
- delta: { content: "Test response" },
|
|
|
- index: 0,
|
|
|
- },
|
|
|
- ],
|
|
|
- usage: null,
|
|
|
- }
|
|
|
- yield {
|
|
|
- choices: [
|
|
|
- {
|
|
|
- delta: {},
|
|
|
- index: 0,
|
|
|
- },
|
|
|
- ],
|
|
|
- usage: {
|
|
|
- prompt_tokens: 10,
|
|
|
- completion_tokens: 5,
|
|
|
- total_tokens: 15,
|
|
|
- cached_tokens: 2,
|
|
|
- },
|
|
|
- }
|
|
|
- },
|
|
|
- }
|
|
|
- }),
|
|
|
- },
|
|
|
- },
|
|
|
- })),
|
|
|
+ ...actual,
|
|
|
+ streamText: mockStreamText,
|
|
|
+ generateText: mockGenerateText,
|
|
|
}
|
|
|
})
|
|
|
|
|
|
-import OpenAI from "openai"
|
|
|
+vi.mock("@ai-sdk/openai-compatible", () => ({
|
|
|
+ createOpenAICompatible: vi.fn(() => {
|
|
|
+ // Return a function that returns a mock language model
|
|
|
+ return vi.fn(() => ({
|
|
|
+ modelId: "moonshot-chat",
|
|
|
+ provider: "moonshot",
|
|
|
+ }))
|
|
|
+ }),
|
|
|
+}))
|
|
|
+
|
|
|
import type { Anthropic } from "@anthropic-ai/sdk"
|
|
|
|
|
|
import { moonshotDefaultModelId } from "@roo-code/types"
|
|
|
@@ -90,15 +51,6 @@ describe("MoonshotHandler", () => {
|
|
|
expect(handler.getModel().id).toBe(mockOptions.apiModelId)
|
|
|
})
|
|
|
|
|
|
- it.skip("should throw error if API key is missing", () => {
|
|
|
- expect(() => {
|
|
|
- new MoonshotHandler({
|
|
|
- ...mockOptions,
|
|
|
- moonshotApiKey: undefined,
|
|
|
- })
|
|
|
- }).toThrow("Moonshot API key is required")
|
|
|
- })
|
|
|
-
|
|
|
it("should use default model ID if not provided", () => {
|
|
|
const handlerWithoutModel = new MoonshotHandler({
|
|
|
...mockOptions,
|
|
|
@@ -113,12 +65,6 @@ describe("MoonshotHandler", () => {
|
|
|
moonshotBaseUrl: undefined,
|
|
|
})
|
|
|
expect(handlerWithoutBaseUrl).toBeInstanceOf(MoonshotHandler)
|
|
|
- // The base URL is passed to OpenAI client internally
|
|
|
- expect(OpenAI).toHaveBeenCalledWith(
|
|
|
- expect.objectContaining({
|
|
|
- baseURL: "https://api.moonshot.ai/v1",
|
|
|
- }),
|
|
|
- )
|
|
|
})
|
|
|
|
|
|
it("should use chinese base URL if provided", () => {
|
|
|
@@ -128,18 +74,6 @@ describe("MoonshotHandler", () => {
|
|
|
moonshotBaseUrl: customBaseUrl,
|
|
|
})
|
|
|
expect(handlerWithCustomUrl).toBeInstanceOf(MoonshotHandler)
|
|
|
- // The custom base URL is passed to OpenAI client
|
|
|
- expect(OpenAI).toHaveBeenCalledWith(
|
|
|
- expect.objectContaining({
|
|
|
- baseURL: customBaseUrl,
|
|
|
- }),
|
|
|
- )
|
|
|
- })
|
|
|
-
|
|
|
- it("should set includeMaxTokens to true", () => {
|
|
|
- // Create a new handler and verify OpenAI client was called with includeMaxTokens
|
|
|
- const _handler = new MoonshotHandler(mockOptions)
|
|
|
- expect(OpenAI).toHaveBeenCalledWith(expect.objectContaining({ apiKey: mockOptions.moonshotApiKey }))
|
|
|
})
|
|
|
})
|
|
|
|
|
|
@@ -151,7 +85,7 @@ describe("MoonshotHandler", () => {
|
|
|
expect(model.info.maxTokens).toBe(16384)
|
|
|
expect(model.info.contextWindow).toBe(262144)
|
|
|
expect(model.info.supportsImages).toBe(false)
|
|
|
- expect(model.info.supportsPromptCache).toBe(true) // Should be true now
|
|
|
+ expect(model.info.supportsPromptCache).toBe(true)
|
|
|
})
|
|
|
|
|
|
it("should return provided model ID with default model info if model does not exist", () => {
|
|
|
@@ -162,11 +96,8 @@ describe("MoonshotHandler", () => {
|
|
|
const model = handlerWithInvalidModel.getModel()
|
|
|
expect(model.id).toBe("invalid-model") // Returns provided ID
|
|
|
expect(model.info).toBeDefined()
|
|
|
- // With the current implementation, it's the same object reference when using default model info
|
|
|
- expect(model.info).toBe(handler.getModel().info)
|
|
|
- // Should have the same base properties
|
|
|
+ // Should have the same base properties as default model
|
|
|
expect(model.info.contextWindow).toBe(handler.getModel().info.contextWindow)
|
|
|
- // And should have supportsPromptCache set to true
|
|
|
expect(model.info.supportsPromptCache).toBe(true)
|
|
|
})
|
|
|
|
|
|
@@ -203,6 +134,24 @@ describe("MoonshotHandler", () => {
|
|
|
]
|
|
|
|
|
|
it("should handle streaming responses", async () => {
|
|
|
+ // Mock the fullStream async generator
|
|
|
+ async function* mockFullStream() {
|
|
|
+ yield { type: "text-delta", text: "Test response" }
|
|
|
+ }
|
|
|
+
|
|
|
+ // Mock usage promise
|
|
|
+ const mockUsage = Promise.resolve({
|
|
|
+ inputTokens: 10,
|
|
|
+ outputTokens: 5,
|
|
|
+ details: { cachedInputTokens: undefined },
|
|
|
+ raw: { cached_tokens: 2 },
|
|
|
+ })
|
|
|
+
|
|
|
+ mockStreamText.mockReturnValue({
|
|
|
+ fullStream: mockFullStream(),
|
|
|
+ usage: mockUsage,
|
|
|
+ })
|
|
|
+
|
|
|
const stream = handler.createMessage(systemPrompt, messages)
|
|
|
const chunks: any[] = []
|
|
|
for await (const chunk of stream) {
|
|
|
@@ -216,6 +165,22 @@ describe("MoonshotHandler", () => {
|
|
|
})
|
|
|
|
|
|
it("should include usage information", async () => {
|
|
|
+ async function* mockFullStream() {
|
|
|
+ yield { type: "text-delta", text: "Test response" }
|
|
|
+ }
|
|
|
+
|
|
|
+ const mockUsage = Promise.resolve({
|
|
|
+ inputTokens: 10,
|
|
|
+ outputTokens: 5,
|
|
|
+ details: {},
|
|
|
+ raw: { cached_tokens: 2 },
|
|
|
+ })
|
|
|
+
|
|
|
+ mockStreamText.mockReturnValue({
|
|
|
+ fullStream: mockFullStream(),
|
|
|
+ usage: mockUsage,
|
|
|
+ })
|
|
|
+
|
|
|
const stream = handler.createMessage(systemPrompt, messages)
|
|
|
const chunks: any[] = []
|
|
|
for await (const chunk of stream) {
|
|
|
@@ -229,6 +194,22 @@ describe("MoonshotHandler", () => {
|
|
|
})
|
|
|
|
|
|
it("should include cache metrics in usage information", async () => {
|
|
|
+ async function* mockFullStream() {
|
|
|
+ yield { type: "text-delta", text: "Test response" }
|
|
|
+ }
|
|
|
+
|
|
|
+ const mockUsage = Promise.resolve({
|
|
|
+ inputTokens: 10,
|
|
|
+ outputTokens: 5,
|
|
|
+ details: {},
|
|
|
+ raw: { cached_tokens: 2 },
|
|
|
+ })
|
|
|
+
|
|
|
+ mockStreamText.mockReturnValue({
|
|
|
+ fullStream: mockFullStream(),
|
|
|
+ usage: mockUsage,
|
|
|
+ })
|
|
|
+
|
|
|
const stream = handler.createMessage(systemPrompt, messages)
|
|
|
const chunks: any[] = []
|
|
|
for await (const chunk of stream) {
|
|
|
@@ -242,6 +223,23 @@ describe("MoonshotHandler", () => {
|
|
|
})
|
|
|
})
|
|
|
|
|
|
+ describe("completePrompt", () => {
|
|
|
+ it("should complete a prompt using generateText", async () => {
|
|
|
+ mockGenerateText.mockResolvedValue({
|
|
|
+ text: "Test completion",
|
|
|
+ })
|
|
|
+
|
|
|
+ const result = await handler.completePrompt("Test prompt")
|
|
|
+
|
|
|
+ expect(result).toBe("Test completion")
|
|
|
+ expect(mockGenerateText).toHaveBeenCalledWith(
|
|
|
+ expect.objectContaining({
|
|
|
+ prompt: "Test prompt",
|
|
|
+ }),
|
|
|
+ )
|
|
|
+ })
|
|
|
+ })
|
|
|
+
|
|
|
describe("processUsageMetrics", () => {
|
|
|
it("should correctly process usage metrics including cache information", () => {
|
|
|
// We need to access the protected method, so we'll create a test subclass
|
|
|
@@ -254,10 +252,12 @@ describe("MoonshotHandler", () => {
|
|
|
const testHandler = new TestMoonshotHandler(mockOptions)
|
|
|
|
|
|
const usage = {
|
|
|
- prompt_tokens: 100,
|
|
|
- completion_tokens: 50,
|
|
|
- total_tokens: 150,
|
|
|
- cached_tokens: 20,
|
|
|
+ inputTokens: 100,
|
|
|
+ outputTokens: 50,
|
|
|
+ details: {},
|
|
|
+ raw: {
|
|
|
+ cached_tokens: 20,
|
|
|
+ },
|
|
|
}
|
|
|
|
|
|
const result = testHandler.testProcessUsageMetrics(usage)
|
|
|
@@ -279,10 +279,10 @@ describe("MoonshotHandler", () => {
|
|
|
const testHandler = new TestMoonshotHandler(mockOptions)
|
|
|
|
|
|
const usage = {
|
|
|
- prompt_tokens: 100,
|
|
|
- completion_tokens: 50,
|
|
|
- total_tokens: 150,
|
|
|
- // No cached_tokens
|
|
|
+ inputTokens: 100,
|
|
|
+ outputTokens: 50,
|
|
|
+ details: {},
|
|
|
+ raw: {},
|
|
|
}
|
|
|
|
|
|
const result = testHandler.testProcessUsageMetrics(usage)
|
|
|
@@ -295,31 +295,25 @@ describe("MoonshotHandler", () => {
|
|
|
})
|
|
|
})
|
|
|
|
|
|
- describe("addMaxTokensIfNeeded", () => {
|
|
|
- it("should always add max_tokens regardless of includeMaxTokens option", () => {
|
|
|
- // Create a test subclass to access the protected method
|
|
|
+ describe("getMaxOutputTokens", () => {
|
|
|
+ it("should return maxTokens from model info", () => {
|
|
|
class TestMoonshotHandler extends MoonshotHandler {
|
|
|
- public testAddMaxTokensIfNeeded(requestOptions: any, modelInfo: any) {
|
|
|
- this.addMaxTokensIfNeeded(requestOptions, modelInfo)
|
|
|
+ public testGetMaxOutputTokens() {
|
|
|
+ return this.getMaxOutputTokens()
|
|
|
}
|
|
|
}
|
|
|
|
|
|
const testHandler = new TestMoonshotHandler(mockOptions)
|
|
|
- const requestOptions: any = {}
|
|
|
- const modelInfo = {
|
|
|
- maxTokens: 32_000,
|
|
|
- }
|
|
|
-
|
|
|
- // Test with includeMaxTokens set to false - should still add max tokens
|
|
|
- testHandler.testAddMaxTokensIfNeeded(requestOptions, modelInfo)
|
|
|
+ const result = testHandler.testGetMaxOutputTokens()
|
|
|
|
|
|
- expect(requestOptions.max_tokens).toBe(32_000)
|
|
|
+ // Default model maxTokens is 16384
|
|
|
+ expect(result).toBe(16384)
|
|
|
})
|
|
|
|
|
|
it("should use modelMaxTokens when provided", () => {
|
|
|
class TestMoonshotHandler extends MoonshotHandler {
|
|
|
- public testAddMaxTokensIfNeeded(requestOptions: any, modelInfo: any) {
|
|
|
- this.addMaxTokensIfNeeded(requestOptions, modelInfo)
|
|
|
+ public testGetMaxOutputTokens() {
|
|
|
+ return this.getMaxOutputTokens()
|
|
|
}
|
|
|
}
|
|
|
|
|
|
@@ -328,32 +322,153 @@ describe("MoonshotHandler", () => {
|
|
|
...mockOptions,
|
|
|
modelMaxTokens: customMaxTokens,
|
|
|
})
|
|
|
- const requestOptions: any = {}
|
|
|
- const modelInfo = {
|
|
|
- maxTokens: 32_000,
|
|
|
- }
|
|
|
|
|
|
- testHandler.testAddMaxTokensIfNeeded(requestOptions, modelInfo)
|
|
|
-
|
|
|
- expect(requestOptions.max_tokens).toBe(customMaxTokens)
|
|
|
+ const result = testHandler.testGetMaxOutputTokens()
|
|
|
+ expect(result).toBe(customMaxTokens)
|
|
|
})
|
|
|
|
|
|
it("should fall back to modelInfo.maxTokens when modelMaxTokens is not provided", () => {
|
|
|
class TestMoonshotHandler extends MoonshotHandler {
|
|
|
- public testAddMaxTokensIfNeeded(requestOptions: any, modelInfo: any) {
|
|
|
- this.addMaxTokensIfNeeded(requestOptions, modelInfo)
|
|
|
+ public testGetMaxOutputTokens() {
|
|
|
+ return this.getMaxOutputTokens()
|
|
|
}
|
|
|
}
|
|
|
|
|
|
const testHandler = new TestMoonshotHandler(mockOptions)
|
|
|
- const requestOptions: any = {}
|
|
|
- const modelInfo = {
|
|
|
- maxTokens: 16_000,
|
|
|
+ const result = testHandler.testGetMaxOutputTokens()
|
|
|
+
|
|
|
+ // moonshot-chat has maxTokens of 16384
|
|
|
+ expect(result).toBe(16384)
|
|
|
+ })
|
|
|
+ })
|
|
|
+
|
|
|
+ describe("tool handling", () => {
|
|
|
+ const systemPrompt = "You are a helpful assistant."
|
|
|
+ const messages: Anthropic.Messages.MessageParam[] = [
|
|
|
+ {
|
|
|
+ role: "user",
|
|
|
+ content: [{ type: "text" as const, text: "Hello!" }],
|
|
|
+ },
|
|
|
+ ]
|
|
|
+
|
|
|
+ it("should handle tool calls in streaming", async () => {
|
|
|
+ async function* mockFullStream() {
|
|
|
+ yield {
|
|
|
+ type: "tool-input-start",
|
|
|
+ id: "tool-call-1",
|
|
|
+ toolName: "read_file",
|
|
|
+ }
|
|
|
+ yield {
|
|
|
+ type: "tool-input-delta",
|
|
|
+ id: "tool-call-1",
|
|
|
+ delta: '{"path":"test.ts"}',
|
|
|
+ }
|
|
|
+ yield {
|
|
|
+ type: "tool-input-end",
|
|
|
+ id: "tool-call-1",
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
- testHandler.testAddMaxTokensIfNeeded(requestOptions, modelInfo)
|
|
|
+ const mockUsage = Promise.resolve({
|
|
|
+ inputTokens: 10,
|
|
|
+ outputTokens: 5,
|
|
|
+ details: {},
|
|
|
+ raw: {},
|
|
|
+ })
|
|
|
+
|
|
|
+ mockStreamText.mockReturnValue({
|
|
|
+ fullStream: mockFullStream(),
|
|
|
+ usage: mockUsage,
|
|
|
+ })
|
|
|
+
|
|
|
+ const stream = handler.createMessage(systemPrompt, messages, {
|
|
|
+ taskId: "test-task",
|
|
|
+ tools: [
|
|
|
+ {
|
|
|
+ type: "function",
|
|
|
+ function: {
|
|
|
+ name: "read_file",
|
|
|
+ description: "Read a file",
|
|
|
+ parameters: {
|
|
|
+ type: "object",
|
|
|
+ properties: { path: { type: "string" } },
|
|
|
+ required: ["path"],
|
|
|
+ },
|
|
|
+ },
|
|
|
+ },
|
|
|
+ ],
|
|
|
+ })
|
|
|
+
|
|
|
+ const chunks: any[] = []
|
|
|
+ for await (const chunk of stream) {
|
|
|
+ chunks.push(chunk)
|
|
|
+ }
|
|
|
+
|
|
|
+ const toolCallStartChunks = chunks.filter((c) => c.type === "tool_call_start")
|
|
|
+ const toolCallDeltaChunks = chunks.filter((c) => c.type === "tool_call_delta")
|
|
|
+ const toolCallEndChunks = chunks.filter((c) => c.type === "tool_call_end")
|
|
|
+
|
|
|
+ expect(toolCallStartChunks.length).toBe(1)
|
|
|
+ expect(toolCallStartChunks[0].id).toBe("tool-call-1")
|
|
|
+ expect(toolCallStartChunks[0].name).toBe("read_file")
|
|
|
+
|
|
|
+ expect(toolCallDeltaChunks.length).toBe(1)
|
|
|
+ expect(toolCallDeltaChunks[0].delta).toBe('{"path":"test.ts"}')
|
|
|
+
|
|
|
+ expect(toolCallEndChunks.length).toBe(1)
|
|
|
+ expect(toolCallEndChunks[0].id).toBe("tool-call-1")
|
|
|
+ })
|
|
|
+
|
|
|
+ it("should handle complete tool calls", async () => {
|
|
|
+ async function* mockFullStream() {
|
|
|
+ yield {
|
|
|
+ type: "tool-call",
|
|
|
+ toolCallId: "tool-call-1",
|
|
|
+ toolName: "read_file",
|
|
|
+ input: { path: "test.ts" },
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ const mockUsage = Promise.resolve({
|
|
|
+ inputTokens: 10,
|
|
|
+ outputTokens: 5,
|
|
|
+ details: {},
|
|
|
+ raw: {},
|
|
|
+ })
|
|
|
+
|
|
|
+ mockStreamText.mockReturnValue({
|
|
|
+ fullStream: mockFullStream(),
|
|
|
+ usage: mockUsage,
|
|
|
+ })
|
|
|
+
|
|
|
+ const stream = handler.createMessage(systemPrompt, messages, {
|
|
|
+ taskId: "test-task",
|
|
|
+ tools: [
|
|
|
+ {
|
|
|
+ type: "function",
|
|
|
+ function: {
|
|
|
+ name: "read_file",
|
|
|
+ description: "Read a file",
|
|
|
+ parameters: {
|
|
|
+ type: "object",
|
|
|
+ properties: { path: { type: "string" } },
|
|
|
+ required: ["path"],
|
|
|
+ },
|
|
|
+ },
|
|
|
+ },
|
|
|
+ ],
|
|
|
+ })
|
|
|
+
|
|
|
+ const chunks: any[] = []
|
|
|
+ for await (const chunk of stream) {
|
|
|
+ chunks.push(chunk)
|
|
|
+ }
|
|
|
|
|
|
- expect(requestOptions.max_tokens).toBe(16_000)
|
|
|
+ const toolCallChunks = chunks.filter((c) => c.type === "tool_call")
|
|
|
+ expect(toolCallChunks.length).toBe(1)
|
|
|
+ expect(toolCallChunks[0].id).toBe("tool-call-1")
|
|
|
+ expect(toolCallChunks[0].name).toBe("read_file")
|
|
|
+ expect(toolCallChunks[0].arguments).toBe('{"path":"test.ts"}')
|
|
|
})
|
|
|
})
|
|
|
})
|