@@ -1,34 +1,46 @@
// npx vitest run api/providers/__tests__/native-ollama.spec.ts

-import { NativeOllamaHandler } from "../native-ollama"
-import { ApiHandlerOptions } from "../../../shared/api"
-import { getOllamaModels } from "../fetchers/ollama"
+// Use vi.hoisted to define mock functions that can be referenced in hoisted vi.mock() calls
+const { mockStreamText, mockGenerateText } = vi.hoisted(() => ({
+	mockStreamText: vi.fn(),
+	mockGenerateText: vi.fn(),
+}))

-// Mock the ollama package
-const mockChat = vitest.fn()
-vitest.mock("ollama", () => {
+vi.mock("ai", async (importOriginal) => {
+	const actual = await importOriginal<typeof import("ai")>()
	return {
-		Ollama: vitest.fn().mockImplementation(() => ({
-			chat: mockChat,
-		})),
-		Message: vitest.fn(),
+		...actual,
+		streamText: mockStreamText,
+		generateText: mockGenerateText,
	}
})

+vi.mock("ollama-ai-provider-v2", () => ({
+	createOllama: vi.fn(() => {
+		return vi.fn(() => ({
+			modelId: "llama2",
+			provider: "ollama",
+		}))
+	}),
+}))
+
// Mock the getOllamaModels function
-vitest.mock("../fetchers/ollama", () => ({
-	getOllamaModels: vitest.fn(),
+vi.mock("../fetchers/ollama", () => ({
+	getOllamaModels: vi.fn(),
}))

-const mockGetOllamaModels = vitest.mocked(getOllamaModels)
+import { NativeOllamaHandler } from "../native-ollama"
+import { ApiHandlerOptions } from "../../../shared/api"
+import { getOllamaModels } from "../fetchers/ollama"
+
+const mockGetOllamaModels = vi.mocked(getOllamaModels)

describe("NativeOllamaHandler", () => {
	let handler: NativeOllamaHandler

	beforeEach(() => {
-		vitest.clearAllMocks()
+		vi.clearAllMocks()

-		// Default mock for getOllamaModels
		mockGetOllamaModels.mockResolvedValue({
			llama2: {
				contextWindow: 4096,
@@ -49,18 +61,14 @@ describe("NativeOllamaHandler", () => {

	describe("createMessage", () => {
		it("should stream messages from Ollama", async () => {
-			// Mock the chat response as an async generator
-			mockChat.mockImplementation(async function* () {
-				yield {
-					message: { content: "Hello" },
-					eval_count: undefined,
-					prompt_eval_count: undefined,
-				}
-				yield {
-					message: { content: " world" },
-					eval_count: 2,
-					prompt_eval_count: 10,
-				}
+			async function* mockFullStream() {
+				yield { type: "text-delta", text: "Hello" }
+				yield { type: "text-delta", text: " world" }
+			}
+
+			mockStreamText.mockReturnValue({
+				fullStream: mockFullStream(),
+				usage: Promise.resolve({ inputTokens: 10, outputTokens: 2 }),
			})

			const systemPrompt = "You are a helpful assistant"
@@ -79,57 +87,57 @@ describe("NativeOllamaHandler", () => {
			expect(results[2]).toEqual({ type: "usage", inputTokens: 10, outputTokens: 2 })
		})

-		it("should not include num_ctx by default", async () => {
-			// Mock the chat response
-			mockChat.mockImplementation(async function* () {
-				yield { message: { content: "Response" } }
+		it("should not include providerOptions by default (no num_ctx)", async () => {
+			async function* mockFullStream() {
+				yield { type: "text-delta", text: "Response" }
+			}
+
+			mockStreamText.mockReturnValue({
+				fullStream: mockFullStream(),
+				usage: Promise.resolve({ inputTokens: 0, outputTokens: 0 }),
			})

			const stream = handler.createMessage("System", [{ role: "user" as const, content: "Test" }])

-			// Consume the stream
			for await (const _ of stream) {
				// consume stream
			}

-			// Verify that num_ctx was NOT included in the options
-			expect(mockChat).toHaveBeenCalledWith(
-				expect.objectContaining({
-					options: expect.not.objectContaining({
-						num_ctx: expect.anything(),
-					}),
+			expect(mockStreamText).toHaveBeenCalledWith(
+				expect.not.objectContaining({
+					providerOptions: expect.anything(),
				}),
			)
		})

-		it("should include num_ctx when explicitly set via ollamaNumCtx", async () => {
+		it("should include num_ctx via providerOptions when explicitly set via ollamaNumCtx", async () => {
			const options: ApiHandlerOptions = {
				apiModelId: "llama2",
				ollamaModelId: "llama2",
				ollamaBaseUrl: "http://localhost:11434",
-				ollamaNumCtx: 8192, // Explicitly set num_ctx
+				ollamaNumCtx: 8192,
			}

			handler = new NativeOllamaHandler(options)

-			// Mock the chat response
-			mockChat.mockImplementation(async function* () {
-				yield { message: { content: "Response" } }
+			async function* mockFullStream() {
+				yield { type: "text-delta", text: "Response" }
+			}
+
+			mockStreamText.mockReturnValue({
+				fullStream: mockFullStream(),
+				usage: Promise.resolve({ inputTokens: 0, outputTokens: 0 }),
			})

			const stream = handler.createMessage("System", [{ role: "user" as const, content: "Test" }])

-			// Consume the stream
			for await (const _ of stream) {
				// consume stream
			}

-			// Verify that num_ctx was included with the specified value
-			expect(mockChat).toHaveBeenCalledWith(
+			expect(mockStreamText).toHaveBeenCalledWith(
				expect.objectContaining({
-					options: expect.objectContaining({
-						num_ctx: 8192,
-					}),
+					providerOptions: { ollama: { options: { num_ctx: 8192 } } },
				}),
			)
		})
@@ -143,11 +151,14 @@ describe("NativeOllamaHandler", () => {

			handler = new NativeOllamaHandler(options)

-			// Mock response with thinking tags
-			mockChat.mockImplementation(async function* () {
-				yield { message: { content: "<think>Let me think" } }
-				yield { message: { content: " about this</think>" } }
-				yield { message: { content: "The answer is 42" } }
+			async function* mockFullStream() {
+				yield { type: "reasoning-delta", text: "Let me think about this" }
+				yield { type: "text-delta", text: "The answer is 42" }
+			}
+
+			mockStreamText.mockReturnValue({
+				fullStream: mockFullStream(),
+				usage: Promise.resolve({ inputTokens: 0, outputTokens: 0 }),
			})

			const stream = handler.createMessage("System", [{ role: "user" as const, content: "Question?" }])
@@ -157,70 +168,67 @@ describe("NativeOllamaHandler", () => {
				results.push(chunk)
			}

-			// Should detect reasoning vs regular text
			expect(results.some((r) => r.type === "reasoning")).toBe(true)
			expect(results.some((r) => r.type === "text")).toBe(true)
+
+			expect(mockStreamText).toHaveBeenCalledWith(
+				expect.objectContaining({
+					providerOptions: { ollama: { think: true } },
+				}),
+			)
		})
	})

	describe("completePrompt", () => {
		it("should complete a prompt without streaming", async () => {
-			mockChat.mockResolvedValue({
-				message: { content: "This is the response" },
+			mockGenerateText.mockResolvedValue({
+				text: "This is the response",
			})

			const result = await handler.completePrompt("Tell me a joke")

-			expect(mockChat).toHaveBeenCalledWith({
-				model: "llama2",
-				messages: [{ role: "user", content: "Tell me a joke" }],
-				stream: false,
-				options: {
+			expect(mockGenerateText).toHaveBeenCalledWith(
+				expect.objectContaining({
+					prompt: "Tell me a joke",
					temperature: 0,
-				},
-			})
+				}),
+			)
			expect(result).toBe("This is the response")
		})

- it("should not include num_ctx in completePrompt by default", async () => {
|
|
|
- mockChat.mockResolvedValue({
|
|
|
- message: { content: "Response" },
|
|
|
+ it("should not include providerOptions in completePrompt by default", async () => {
|
|
|
+ mockGenerateText.mockResolvedValue({
|
|
|
+ text: "Response",
|
|
|
})
|
|
|
|
|
|
await handler.completePrompt("Test prompt")
|
|
|
|
|
|
- // Verify that num_ctx was NOT included in the options
|
|
|
- expect(mockChat).toHaveBeenCalledWith(
|
|
|
- expect.objectContaining({
|
|
|
- options: expect.not.objectContaining({
|
|
|
- num_ctx: expect.anything(),
|
|
|
- }),
|
|
|
+ expect(mockGenerateText).toHaveBeenCalledWith(
|
|
|
+ expect.not.objectContaining({
|
|
|
+ providerOptions: expect.anything(),
|
|
|
}),
|
|
|
)
|
|
|
})
|
|
|
|
|
|
- it("should include num_ctx in completePrompt when explicitly set", async () => {
|
|
|
+ it("should include num_ctx via providerOptions in completePrompt when explicitly set", async () => {
|
|
|
const options: ApiHandlerOptions = {
|
|
|
apiModelId: "llama2",
|
|
|
ollamaModelId: "llama2",
|
|
|
ollamaBaseUrl: "http://localhost:11434",
|
|
|
- ollamaNumCtx: 4096, // Explicitly set num_ctx
|
|
|
+ ollamaNumCtx: 4096,
|
|
|
}
|
|
|
|
|
|
handler = new NativeOllamaHandler(options)
|
|
|
|
|
|
- mockChat.mockResolvedValue({
|
|
|
- message: { content: "Response" },
|
|
|
+ mockGenerateText.mockResolvedValue({
|
|
|
+ text: "Response",
|
|
|
})
|
|
|
|
|
|
await handler.completePrompt("Test prompt")
|
|
|
|
|
|
- // Verify that num_ctx was included with the specified value
|
|
|
- expect(mockChat).toHaveBeenCalledWith(
|
|
|
+ expect(mockGenerateText).toHaveBeenCalledWith(
|
|
|
expect.objectContaining({
|
|
|
- options: expect.objectContaining({
|
|
|
- num_ctx: 4096,
|
|
|
- }),
|
|
|
+ providerOptions: { ollama: { options: { num_ctx: 4096 } } },
|
|
|
}),
|
|
|
)
|
|
|
})
|
|
|
@@ -230,7 +238,17 @@ describe("NativeOllamaHandler", () => {
		it("should handle connection refused errors", async () => {
			const error = new Error("ECONNREFUSED") as any
			error.code = "ECONNREFUSED"
-			mockChat.mockRejectedValue(error)
+
+			const mockFullStream = {
+				[Symbol.asyncIterator]: () => ({
+					next: () => Promise.reject(error),
+				}),
+			}
+
+			mockStreamText.mockReturnValue({
+				fullStream: mockFullStream,
+				usage: Promise.resolve({ inputTokens: 0, outputTokens: 0 }),
+			})

			const stream = handler.createMessage("System", [{ role: "user" as const, content: "Test" }])

@@ -244,7 +262,17 @@ describe("NativeOllamaHandler", () => {
		it("should handle model not found errors", async () => {
			const error = new Error("Not found") as any
			error.status = 404
-			mockChat.mockRejectedValue(error)
+
+			const mockFullStream = {
+				[Symbol.asyncIterator]: () => ({
+					next: () => Promise.reject(error),
+				}),
+			}
+
+			mockStreamText.mockReturnValue({
+				fullStream: mockFullStream,
+				usage: Promise.resolve({ inputTokens: 0, outputTokens: 0 }),
+			})

			const stream = handler.createMessage("System", [{ role: "user" as const, content: "Test" }])

@@ -264,9 +292,14 @@ describe("NativeOllamaHandler", () => {
		})
	})

+	describe("isAiSdkProvider", () => {
+		it("should return true", () => {
+			expect(handler.isAiSdkProvider()).toBe(true)
+		})
+	})
+
	describe("tool calling", () => {
-		it("should include tools when tools are provided", async () => {
-			// Model metadata should not gate tool inclusion; metadata.tools controls it.
+		it("should pass tools via AI SDK when tools are provided", async () => {
			mockGetOllamaModels.mockResolvedValue({
				"llama3.2": {
					contextWindow: 128000,
@@ -284,9 +317,13 @@ describe("NativeOllamaHandler", () => {

			handler = new NativeOllamaHandler(options)

-			// Mock the chat response
-			mockChat.mockImplementation(async function* () {
-				yield { message: { content: "I will use the tool" } }
+			async function* mockFullStream() {
+				yield { type: "text-delta", text: "I will use the tool" }
+			}
+
+			mockStreamText.mockReturnValue({
+				fullStream: mockFullStream(),
+				usage: Promise.resolve({ inputTokens: 0, outputTokens: 0 }),
			})

			const tools = [
@@ -312,36 +349,18 @@ describe("NativeOllamaHandler", () => {
				{ taskId: "test", tools },
			)

-			// Consume the stream
			for await (const _ of stream) {
				// consume stream
			}

-			// Verify tools were passed to the API
-			expect(mockChat).toHaveBeenCalledWith(
+			expect(mockStreamText).toHaveBeenCalledWith(
				expect.objectContaining({
-					tools: [
-						{
-							type: "function",
-							function: {
-								name: "get_weather",
-								description: "Get the weather for a location",
-								parameters: {
-									type: "object",
-									properties: {
-										location: { type: "string", description: "The city name" },
-									},
-									required: ["location"],
-								},
-							},
-						},
-					],
+					tools: expect.any(Object),
				}),
			)
		})

-		it("should include tools even when model metadata doesn't advertise tool support", async () => {
-			// Model metadata should not gate tool inclusion; metadata.tools controls it.
+		it("should pass tools even when model metadata doesn't advertise tool support", async () => {
			mockGetOllamaModels.mockResolvedValue({
				llama2: {
					contextWindow: 4096,
@@ -351,9 +370,13 @@ describe("NativeOllamaHandler", () => {
				},
			})

-			// Mock the chat response
-			mockChat.mockImplementation(async function* () {
-				yield { message: { content: "Response without tools" } }
+			async function* mockFullStream() {
+				yield { type: "text-delta", text: "Response without tools" }
+			}
+
+			mockStreamText.mockReturnValue({
+				fullStream: mockFullStream(),
+				usage: Promise.resolve({ inputTokens: 0, outputTokens: 0 }),
			})

			const tools = [
@@ -372,21 +395,18 @@ describe("NativeOllamaHandler", () => {
				tools,
			})

-			// Consume the stream
			for await (const _ of stream) {
				// consume stream
			}

-			// Verify tools were passed
-			expect(mockChat).toHaveBeenCalledWith(
+			expect(mockStreamText).toHaveBeenCalledWith(
				expect.objectContaining({
-					tools: expect.any(Array),
+					tools: expect.any(Object),
				}),
			)
		})

		it("should not include tools when no tools are provided", async () => {
-			// Model metadata should not gate tool inclusion; metadata.tools controls it.
			mockGetOllamaModels.mockResolvedValue({
				"llama3.2": {
					contextWindow: 128000,
@@ -404,30 +424,31 @@ describe("NativeOllamaHandler", () => {

			handler = new NativeOllamaHandler(options)

-			// Mock the chat response
-			mockChat.mockImplementation(async function* () {
-				yield { message: { content: "Response" } }
+			async function* mockFullStream() {
+				yield { type: "text-delta", text: "Response" }
+			}
+
+			mockStreamText.mockReturnValue({
+				fullStream: mockFullStream(),
+				usage: Promise.resolve({ inputTokens: 0, outputTokens: 0 }),
			})

			const stream = handler.createMessage("System", [{ role: "user" as const, content: "Test" }], {
				taskId: "test",
			})

-			// Consume the stream
			for await (const _ of stream) {
				// consume stream
			}

-			// Verify tools were NOT passed
-			expect(mockChat).toHaveBeenCalledWith(
-				expect.not.objectContaining({
-					tools: expect.anything(),
+			expect(mockStreamText).toHaveBeenCalledWith(
+				expect.objectContaining({
+					tools: undefined,
				}),
			)
		})

-		it("should yield tool_call_partial when model returns tool calls", async () => {
-			// Model metadata should not gate tool inclusion; metadata.tools controls it.
+		it("should yield tool call events when model returns tool calls", async () => {
			mockGetOllamaModels.mockResolvedValue({
				"llama3.2": {
					contextWindow: 128000,
@@ -445,21 +466,26 @@ describe("NativeOllamaHandler", () => {

			handler = new NativeOllamaHandler(options)

-			// Mock the chat response with tool calls
-			mockChat.mockImplementation(async function* () {
+			async function* mockFullStream() {
				yield {
-					message: {
-						content: "",
-						tool_calls: [
-							{
-								function: {
-									name: "get_weather",
-									arguments: { location: "San Francisco" },
-								},
-							},
-						],
-					},
+					type: "tool-input-start",
+					id: "tool-call-1",
+					toolName: "get_weather",
				}
+				yield {
+					type: "tool-input-delta",
+					id: "tool-call-1",
+					delta: '{"location":"San Francisco"}',
+				}
+				yield {
+					type: "tool-input-end",
+					id: "tool-call-1",
+				}
+			}
+
+			mockStreamText.mockReturnValue({
+				fullStream: mockFullStream(),
+				usage: Promise.resolve({ inputTokens: 0, outputTokens: 0 }),
			})

			const tools = [
@@ -490,20 +516,26 @@ describe("NativeOllamaHandler", () => {
				results.push(chunk)
			}

-			// Should yield a tool_call_partial chunk
-			const toolCallChunk = results.find((r) => r.type === "tool_call_partial")
-			expect(toolCallChunk).toBeDefined()
-			expect(toolCallChunk).toEqual({
-				type: "tool_call_partial",
-				index: 0,
-				id: "ollama-tool-0",
+			const toolCallStart = results.find((r) => r.type === "tool_call_start")
+			expect(toolCallStart).toBeDefined()
+			expect(toolCallStart).toEqual({
+				type: "tool_call_start",
+				id: "tool-call-1",
				name: "get_weather",
-				arguments: JSON.stringify({ location: "San Francisco" }),
+			})
+
+			const toolCallDelta = results.find((r) => r.type === "tool_call_delta")
+			expect(toolCallDelta).toBeDefined()
+
+			const toolCallEnd = results.find((r) => r.type === "tool_call_end")
+			expect(toolCallEnd).toBeDefined()
+			expect(toolCallEnd).toEqual({
+				type: "tool_call_end",
+				id: "tool-call-1",
			})
		})

-		it("should yield tool_call_end events after tool_call_partial chunks", async () => {
-			// Model metadata should not gate tool inclusion; metadata.tools controls it.
+		it("should yield tool_call_end events after tool_call_start for multiple tools", async () => {
			mockGetOllamaModels.mockResolvedValue({
				"llama3.2": {
					contextWindow: 128000,
@@ -521,27 +553,18 @@ describe("NativeOllamaHandler", () => {

			handler = new NativeOllamaHandler(options)

-			// Mock the chat response with multiple tool calls
-			mockChat.mockImplementation(async function* () {
-				yield {
-					message: {
-						content: "",
-						tool_calls: [
-							{
-								function: {
-									name: "get_weather",
-									arguments: { location: "San Francisco" },
-								},
-							},
-							{
-								function: {
-									name: "get_time",
-									arguments: { timezone: "PST" },
-								},
-							},
-						],
-					},
-				}
+			async function* mockFullStream() {
+				yield { type: "tool-input-start", id: "tool-0", toolName: "get_weather" }
+				yield { type: "tool-input-delta", id: "tool-0", delta: '{"location":"SF"}' }
+				yield { type: "tool-input-end", id: "tool-0" }
+				yield { type: "tool-input-start", id: "tool-1", toolName: "get_time" }
+				yield { type: "tool-input-delta", id: "tool-1", delta: '{"timezone":"PST"}' }
+				yield { type: "tool-input-end", id: "tool-1" }
+			}
+
+			mockStreamText.mockReturnValue({
+				fullStream: mockFullStream(),
+				usage: Promise.resolve({ inputTokens: 0, outputTokens: 0 }),
			})

			const tools = [
@@ -582,27 +605,18 @@ describe("NativeOllamaHandler", () => {
				results.push(chunk)
			}

-			// Should yield tool_call_partial chunks
-			const toolCallPartials = results.filter((r) => r.type === "tool_call_partial")
-			expect(toolCallPartials).toHaveLength(2)
+			const toolCallStarts = results.filter((r) => r.type === "tool_call_start")
+			expect(toolCallStarts).toHaveLength(2)

-			// Should yield tool_call_end events for each tool call
			const toolCallEnds = results.filter((r) => r.type === "tool_call_end")
			expect(toolCallEnds).toHaveLength(2)
-			expect(toolCallEnds[0]).toEqual({ type: "tool_call_end", id: "ollama-tool-0" })
-			expect(toolCallEnds[1]).toEqual({ type: "tool_call_end", id: "ollama-tool-1" })
-
-			// tool_call_end should come after tool_call_partial
-			// Find the last tool_call_partial index
-			let lastPartialIndex = -1
-			for (let i = results.length - 1; i >= 0; i--) {
-				if (results[i].type === "tool_call_partial") {
-					lastPartialIndex = i
-					break
-				}
-			}
+			expect(toolCallEnds[0]).toEqual({ type: "tool_call_end", id: "tool-0" })
+			expect(toolCallEnds[1]).toEqual({ type: "tool_call_end", id: "tool-1" })
+
+			// tool_call_end should come after corresponding tool_call_start
+			const firstStartIndex = results.findIndex((r) => r.type === "tool_call_start")
			const firstEndIndex = results.findIndex((r) => r.type === "tool_call_end")
-			expect(firstEndIndex).toBeGreaterThan(lastPartialIndex)
+			expect(firstEndIndex).toBeGreaterThan(firstStartIndex)
		})
	})
})