@@ -2,6 +2,7 @@
 
 import { NativeOllamaHandler } from "../native-ollama"
 import { ApiHandlerOptions } from "../../../shared/api"
+import { getOllamaModels } from "../fetchers/ollama"
 
 // Mock the ollama package
 const mockChat = vitest.fn()
@@ -16,22 +17,27 @@ vitest.mock("ollama", () => {
 
 // Mock the getOllamaModels function
 vitest.mock("../fetchers/ollama", () => ({
-	getOllamaModels: vitest.fn().mockResolvedValue({
-		llama2: {
-			contextWindow: 4096,
-			maxTokens: 4096,
-			supportsImages: false,
-			supportsPromptCache: false,
-		},
-	}),
+	getOllamaModels: vitest.fn(),
 }))
 
+const mockGetOllamaModels = vitest.mocked(getOllamaModels)
+
 describe("NativeOllamaHandler", () => {
 	let handler: NativeOllamaHandler
 
 	beforeEach(() => {
 		vitest.clearAllMocks()
 
+		// Default mock for getOllamaModels
+		mockGetOllamaModels.mockResolvedValue({
+			llama2: {
+				contextWindow: 4096,
+				maxTokens: 4096,
+				supportsImages: false,
+				supportsPromptCache: false,
+			},
+		})
+
 		const options: ApiHandlerOptions = {
 			apiModelId: "llama2",
 			ollamaModelId: "llama2",
@@ -257,4 +263,260 @@ describe("NativeOllamaHandler", () => {
 			expect(model.info).toBeDefined()
 		})
 	})
+
+	describe("tool calling", () => {
+		it("should include tools when model supports native tools", async () => {
+			// Mock model with native tool support
+			mockGetOllamaModels.mockResolvedValue({
+				"llama3.2": {
+					contextWindow: 128000,
+					maxTokens: 4096,
+					supportsImages: true,
+					supportsPromptCache: false,
+					supportsNativeTools: true,
+				},
+			})
+
+			const options: ApiHandlerOptions = {
+				apiModelId: "llama3.2",
+				ollamaModelId: "llama3.2",
+				ollamaBaseUrl: "http://localhost:11434",
+			}
+
+			handler = new NativeOllamaHandler(options)
+
+			// Mock the chat response
+			mockChat.mockImplementation(async function* () {
+				yield { message: { content: "I will use the tool" } }
+			})
+
+			const tools = [
+				{
+					type: "function" as const,
+					function: {
+						name: "get_weather",
+						description: "Get the weather for a location",
+						parameters: {
+							type: "object",
+							properties: {
+								location: { type: "string", description: "The city name" },
+							},
+							required: ["location"],
+						},
+					},
+				},
+			]
+
+			const stream = handler.createMessage(
+				"System",
+				[{ role: "user" as const, content: "What's the weather?" }],
+				{ taskId: "test", tools },
+			)
+
+			// Consume the stream
+			for await (const _ of stream) {
+				// consume stream
+			}
+
+			// Verify tools were passed to the API
+			expect(mockChat).toHaveBeenCalledWith(
+				expect.objectContaining({
+					tools: [
+						{
+							type: "function",
+							function: {
+								name: "get_weather",
+								description: "Get the weather for a location",
+								parameters: {
+									type: "object",
+									properties: {
+										location: { type: "string", description: "The city name" },
+									},
+									required: ["location"],
+								},
+							},
+						},
+					],
+				}),
+			)
+		})
+
+		it("should not include tools when model does not support native tools", async () => {
+			// Mock model without native tool support
+			mockGetOllamaModels.mockResolvedValue({
+				llama2: {
+					contextWindow: 4096,
+					maxTokens: 4096,
+					supportsImages: false,
+					supportsPromptCache: false,
+					supportsNativeTools: false,
+				},
+			})
+
+			// Mock the chat response
+			mockChat.mockImplementation(async function* () {
+				yield { message: { content: "Response without tools" } }
+			})
+
+			const tools = [
+				{
+					type: "function" as const,
+					function: {
+						name: "get_weather",
+						description: "Get the weather",
+						parameters: { type: "object", properties: {} },
+					},
+				},
+			]
+
+			const stream = handler.createMessage("System", [{ role: "user" as const, content: "Test" }], {
+				taskId: "test",
+				tools,
+			})
+
+			// Consume the stream
+			for await (const _ of stream) {
+				// consume stream
+			}
+
+			// Verify tools were NOT passed
+			expect(mockChat).toHaveBeenCalledWith(
+				expect.not.objectContaining({
+					tools: expect.anything(),
+				}),
+			)
+		})
+
+		it("should not include tools when toolProtocol is xml", async () => {
+			// Mock model with native tool support
+			mockGetOllamaModels.mockResolvedValue({
+				"llama3.2": {
+					contextWindow: 128000,
+					maxTokens: 4096,
+					supportsImages: true,
+					supportsPromptCache: false,
+					supportsNativeTools: true,
+				},
+			})
+
+			const options: ApiHandlerOptions = {
+				apiModelId: "llama3.2",
+				ollamaModelId: "llama3.2",
+				ollamaBaseUrl: "http://localhost:11434",
+			}
+
+			handler = new NativeOllamaHandler(options)
+
+			// Mock the chat response
+			mockChat.mockImplementation(async function* () {
+				yield { message: { content: "Response" } }
+			})
+
+			const tools = [
+				{
+					type: "function" as const,
+					function: {
+						name: "get_weather",
+						description: "Get the weather",
+						parameters: { type: "object", properties: {} },
+					},
+				},
+			]
+
+			const stream = handler.createMessage("System", [{ role: "user" as const, content: "Test" }], {
+				taskId: "test",
+				tools,
+				toolProtocol: "xml",
+			})
+
+			// Consume the stream
+			for await (const _ of stream) {
+				// consume stream
+			}
+
+			// Verify tools were NOT passed (XML protocol forces XML format)
+			expect(mockChat).toHaveBeenCalledWith(
+				expect.not.objectContaining({
+					tools: expect.anything(),
+				}),
+			)
+		})
+
+		it("should yield tool_call_partial when model returns tool calls", async () => {
+			// Mock model with native tool support
+			mockGetOllamaModels.mockResolvedValue({
+				"llama3.2": {
+					contextWindow: 128000,
+					maxTokens: 4096,
+					supportsImages: true,
+					supportsPromptCache: false,
+					supportsNativeTools: true,
+				},
+			})
+
+			const options: ApiHandlerOptions = {
+				apiModelId: "llama3.2",
+				ollamaModelId: "llama3.2",
+				ollamaBaseUrl: "http://localhost:11434",
+			}
+
+			handler = new NativeOllamaHandler(options)
+
+			// Mock the chat response with tool calls
+			mockChat.mockImplementation(async function* () {
+				yield {
+					message: {
+						content: "",
+						tool_calls: [
+							{
+								function: {
+									name: "get_weather",
+									arguments: { location: "San Francisco" },
+								},
+							},
+						],
+					},
+				}
+			})
+
+			const tools = [
+				{
+					type: "function" as const,
+					function: {
+						name: "get_weather",
+						description: "Get the weather for a location",
+						parameters: {
+							type: "object",
+							properties: {
+								location: { type: "string" },
+							},
+							required: ["location"],
+						},
+					},
+				},
+			]
+
+			const stream = handler.createMessage(
+				"System",
+				[{ role: "user" as const, content: "What's the weather in SF?" }],
+				{ taskId: "test", tools },
+			)
+
+			const results = []
+			for await (const chunk of stream) {
+				results.push(chunk)
+			}
+
+			// Should yield a tool_call_partial chunk
+			const toolCallChunk = results.find((r) => r.type === "tool_call_partial")
+			expect(toolCallChunk).toBeDefined()
+			expect(toolCallChunk).toEqual({
+				type: "tool_call_partial",
+				index: 0,
+				id: "ollama-tool-0",
+				name: "get_weather",
+				arguments: JSON.stringify({ location: "San Francisco" }),
+			})
+		})
+	})
 })
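For readers skimming the diff: the assertions above pin down two behaviors. Tools are forwarded to `ollama.chat` only when the fetched model info reports `supportsNativeTools` and the caller has not forced `toolProtocol: "xml"`, and streamed Ollama `tool_calls` are re-emitted as `tool_call_partial` chunks with a synthetic `ollama-tool-<index>` id. Below is a minimal sketch of that logic inferred from the tests; the helper names (`shouldSendNativeTools`, `toToolCallPartial`) and the `ModelInfo` shape are illustrative stand-ins, not the actual NativeOllamaHandler internals.

```typescript
// Hypothetical sketch, not the NativeOllamaHandler source: the gating and
// chunk mapping that the tests above exercise. Types are minimal stand-ins.

interface ModelInfo {
	contextWindow: number
	maxTokens: number
	supportsImages: boolean
	supportsPromptCache: boolean
	supportsNativeTools?: boolean
}

interface ToolCallPartialChunk {
	type: "tool_call_partial"
	index: number
	id: string
	name: string
	arguments: string
}

// Tools are only forwarded when the model reports native support AND the
// caller has not forced the XML tool protocol (which bypasses native tools).
function shouldSendNativeTools(info: ModelInfo, toolProtocol?: string): boolean {
	return info.supportsNativeTools === true && toolProtocol !== "xml"
}

// Ollama streams tool calls with already-parsed argument objects; the handler
// re-serializes them to a JSON string and assigns a synthetic per-index id.
function toToolCallPartial(
	call: { function: { name: string; arguments: Record<string, unknown> } },
	index: number,
): ToolCallPartialChunk {
	return {
		type: "tool_call_partial",
		index,
		id: `ollama-tool-${index}`,
		name: call.function.name,
		arguments: JSON.stringify(call.function.arguments),
	}
}
```

Serializing `arguments` with `JSON.stringify` matches the final test's expectation that consumers receive the OpenAI-style string form rather than Ollama's parsed object.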