|
|
@@ -949,4 +949,299 @@ describe("VertexHandler", () => {
|
|
|
)
|
|
|
})
|
|
|
})
|
|
|
+
|
|
|
+ describe("native tool calling", () => {
|
|
|
+ const systemPrompt = "You are a helpful assistant"
|
|
|
+ const messages: Anthropic.Messages.MessageParam[] = [
|
|
|
+ {
|
|
|
+ role: "user",
|
|
|
+ content: [{ type: "text" as const, text: "What's the weather in London?" }],
|
|
|
+ },
|
|
|
+ ]
|
|
|
+
|
|
|
+ const mockTools = [
|
|
|
+ {
|
|
|
+ type: "function" as const,
|
|
|
+ function: {
|
|
|
+ name: "get_weather",
|
|
|
+ description: "Get the current weather",
|
|
|
+ parameters: {
|
|
|
+ type: "object",
|
|
|
+ properties: {
|
|
|
+ location: { type: "string" },
|
|
|
+ },
|
|
|
+ required: ["location"],
|
|
|
+ },
|
|
|
+ },
|
|
|
+ },
|
|
|
+ ]
|
|
|
+
|
|
|
+ it("should include tools in request when native protocol is used", async () => {
|
|
|
+ handler = new AnthropicVertexHandler({
|
|
|
+ apiModelId: "claude-3-5-sonnet-v2@20241022",
|
|
|
+ vertexProjectId: "test-project",
|
|
|
+ vertexRegion: "us-central1",
|
|
|
+ })
|
|
|
+
|
|
|
+ const mockStream = [
|
|
|
+ {
|
|
|
+ type: "message_start",
|
|
|
+ message: {
|
|
|
+ usage: {
|
|
|
+ input_tokens: 10,
|
|
|
+ output_tokens: 0,
|
|
|
+ },
|
|
|
+ },
|
|
|
+ },
|
|
|
+ ]
|
|
|
+
|
|
|
+ const asyncIterator = {
|
|
|
+ async *[Symbol.asyncIterator]() {
|
|
|
+ for (const chunk of mockStream) {
|
|
|
+ yield chunk
|
|
|
+ }
|
|
|
+ },
|
|
|
+ }
|
|
|
+
|
|
|
+ const mockCreate = vitest.fn().mockResolvedValue(asyncIterator)
|
|
|
+ ;(handler["client"].messages as any).create = mockCreate
|
|
|
+
|
|
|
+ const stream = handler.createMessage(systemPrompt, messages, {
|
|
|
+ taskId: "test-task",
|
|
|
+ tools: mockTools,
|
|
|
+ })
|
|
|
+
|
|
|
+ // Consume the stream to trigger the API call
|
|
|
+ for await (const _chunk of stream) {
|
|
|
+ // Just consume
|
|
|
+ }
|
|
|
+
|
|
|
+ expect(mockCreate).toHaveBeenCalledWith(
|
|
|
+ expect.objectContaining({
|
|
|
+ tools: expect.arrayContaining([
|
|
|
+ expect.objectContaining({
|
|
|
+ name: "get_weather",
|
|
|
+ description: "Get the current weather",
|
|
|
+ input_schema: expect.objectContaining({
|
|
|
+ type: "object",
|
|
|
+ properties: expect.objectContaining({
|
|
|
+ location: { type: "string" },
|
|
|
+ }),
|
|
|
+ }),
|
|
|
+ }),
|
|
|
+ ]),
|
|
|
+ tool_choice: { type: "auto", disable_parallel_tool_use: true },
|
|
|
+ }),
|
|
|
+ )
|
|
|
+ })
|
|
|
+
|
|
|
+ it("should not include tools when toolProtocol is xml", async () => {
|
|
|
+ handler = new AnthropicVertexHandler({
|
|
|
+ apiModelId: "claude-3-5-sonnet-v2@20241022",
|
|
|
+ vertexProjectId: "test-project",
|
|
|
+ vertexRegion: "us-central1",
|
|
|
+ toolProtocol: "xml",
|
|
|
+ })
|
|
|
+
|
|
|
+ const mockStream = [
|
|
|
+ {
|
|
|
+ type: "message_start",
|
|
|
+ message: {
|
|
|
+ usage: {
|
|
|
+ input_tokens: 10,
|
|
|
+ output_tokens: 0,
|
|
|
+ },
|
|
|
+ },
|
|
|
+ },
|
|
|
+ ]
|
|
|
+
|
|
|
+ const asyncIterator = {
|
|
|
+ async *[Symbol.asyncIterator]() {
|
|
|
+ for (const chunk of mockStream) {
|
|
|
+ yield chunk
|
|
|
+ }
|
|
|
+ },
|
|
|
+ }
|
|
|
+
|
|
|
+ const mockCreate = vitest.fn().mockResolvedValue(asyncIterator)
|
|
|
+ ;(handler["client"].messages as any).create = mockCreate
|
|
|
+
|
|
|
+ const stream = handler.createMessage(systemPrompt, messages, {
|
|
|
+ taskId: "test-task",
|
|
|
+ tools: mockTools,
|
|
|
+ })
|
|
|
+
|
|
|
+ // Consume the stream to trigger the API call
|
|
|
+ for await (const _chunk of stream) {
|
|
|
+ // Just consume
|
|
|
+ }
|
|
|
+
|
|
|
+ expect(mockCreate).toHaveBeenCalledWith(
|
|
|
+ expect.not.objectContaining({
|
|
|
+ tools: expect.anything(),
|
|
|
+ }),
|
|
|
+ )
|
|
|
+ })
|
|
|
+
|
|
|
+ it("should handle tool_use blocks in stream and emit tool_call_partial", async () => {
|
|
|
+ handler = new AnthropicVertexHandler({
|
|
|
+ apiModelId: "claude-3-5-sonnet-v2@20241022",
|
|
|
+ vertexProjectId: "test-project",
|
|
|
+ vertexRegion: "us-central1",
|
|
|
+ })
|
|
|
+
|
|
|
+ const mockStream = [
|
|
|
+ {
|
|
|
+ type: "message_start",
|
|
|
+ message: {
|
|
|
+ usage: {
|
|
|
+ input_tokens: 100,
|
|
|
+ output_tokens: 50,
|
|
|
+ },
|
|
|
+ },
|
|
|
+ },
|
|
|
+ {
|
|
|
+ type: "content_block_start",
|
|
|
+ index: 0,
|
|
|
+ content_block: {
|
|
|
+ type: "tool_use",
|
|
|
+ id: "toolu_123",
|
|
|
+ name: "get_weather",
|
|
|
+ },
|
|
|
+ },
|
|
|
+ ]
|
|
|
+
|
|
|
+ const asyncIterator = {
|
|
|
+ async *[Symbol.asyncIterator]() {
|
|
|
+ for (const chunk of mockStream) {
|
|
|
+ yield chunk
|
|
|
+ }
|
|
|
+ },
|
|
|
+ }
|
|
|
+
|
|
|
+ const mockCreate = vitest.fn().mockResolvedValue(asyncIterator)
|
|
|
+ ;(handler["client"].messages as any).create = mockCreate
|
|
|
+
|
|
|
+ const stream = handler.createMessage(systemPrompt, messages, {
|
|
|
+ taskId: "test-task",
|
|
|
+ tools: mockTools,
|
|
|
+ })
|
|
|
+
|
|
|
+ const chunks: ApiStreamChunk[] = []
|
|
|
+ for await (const chunk of stream) {
|
|
|
+ chunks.push(chunk)
|
|
|
+ }
|
|
|
+
|
|
|
+ // Find the tool_call_partial chunk
|
|
|
+ const toolCallChunk = chunks.find((chunk) => chunk.type === "tool_call_partial")
|
|
|
+ expect(toolCallChunk).toBeDefined()
|
|
|
+ expect(toolCallChunk).toEqual({
|
|
|
+ type: "tool_call_partial",
|
|
|
+ index: 0,
|
|
|
+ id: "toolu_123",
|
|
|
+ name: "get_weather",
|
|
|
+ arguments: undefined,
|
|
|
+ })
|
|
|
+ })
|
|
|
+
|
|
|
+ it("should handle input_json_delta in stream and emit tool_call_partial arguments", async () => {
|
|
|
+ handler = new AnthropicVertexHandler({
|
|
|
+ apiModelId: "claude-3-5-sonnet-v2@20241022",
|
|
|
+ vertexProjectId: "test-project",
|
|
|
+ vertexRegion: "us-central1",
|
|
|
+ })
|
|
|
+
|
|
|
+ const mockStream = [
|
|
|
+ {
|
|
|
+ type: "message_start",
|
|
|
+ message: {
|
|
|
+ usage: {
|
|
|
+ input_tokens: 100,
|
|
|
+ output_tokens: 50,
|
|
|
+ },
|
|
|
+ },
|
|
|
+ },
|
|
|
+ {
|
|
|
+ type: "content_block_start",
|
|
|
+ index: 0,
|
|
|
+ content_block: {
|
|
|
+ type: "tool_use",
|
|
|
+ id: "toolu_123",
|
|
|
+ name: "get_weather",
|
|
|
+ },
|
|
|
+ },
|
|
|
+ {
|
|
|
+ type: "content_block_delta",
|
|
|
+ index: 0,
|
|
|
+ delta: {
|
|
|
+ type: "input_json_delta",
|
|
|
+ partial_json: '{"location":',
|
|
|
+ },
|
|
|
+ },
|
|
|
+ {
|
|
|
+ type: "content_block_delta",
|
|
|
+ index: 0,
|
|
|
+ delta: {
|
|
|
+ type: "input_json_delta",
|
|
|
+ partial_json: '"London"}',
|
|
|
+ },
|
|
|
+ },
|
|
|
+ {
|
|
|
+ type: "content_block_stop",
|
|
|
+ index: 0,
|
|
|
+ },
|
|
|
+ ]
|
|
|
+
|
|
|
+ const asyncIterator = {
|
|
|
+ async *[Symbol.asyncIterator]() {
|
|
|
+ for (const chunk of mockStream) {
|
|
|
+ yield chunk
|
|
|
+ }
|
|
|
+ },
|
|
|
+ }
|
|
|
+
|
|
|
+ const mockCreate = vitest.fn().mockResolvedValue(asyncIterator)
|
|
|
+ ;(handler["client"].messages as any).create = mockCreate
|
|
|
+
|
|
|
+ const stream = handler.createMessage(systemPrompt, messages, {
|
|
|
+ taskId: "test-task",
|
|
|
+ tools: mockTools,
|
|
|
+ })
|
|
|
+
|
|
|
+ const chunks: ApiStreamChunk[] = []
|
|
|
+ for await (const chunk of stream) {
|
|
|
+ chunks.push(chunk)
|
|
|
+ }
|
|
|
+
|
|
|
+ // Find the tool_call_partial chunks
|
|
|
+ const toolCallChunks = chunks.filter((chunk) => chunk.type === "tool_call_partial")
|
|
|
+ expect(toolCallChunks).toHaveLength(3)
|
|
|
+
|
|
|
+ // First chunk has id and name
|
|
|
+ expect(toolCallChunks[0]).toEqual({
|
|
|
+ type: "tool_call_partial",
|
|
|
+ index: 0,
|
|
|
+ id: "toolu_123",
|
|
|
+ name: "get_weather",
|
|
|
+ arguments: undefined,
|
|
|
+ })
|
|
|
+
|
|
|
+ // Subsequent chunks have arguments
|
|
|
+ expect(toolCallChunks[1]).toEqual({
|
|
|
+ type: "tool_call_partial",
|
|
|
+ index: 0,
|
|
|
+ id: undefined,
|
|
|
+ name: undefined,
|
|
|
+ arguments: '{"location":',
|
|
|
+ })
|
|
|
+
|
|
|
+ expect(toolCallChunks[2]).toEqual({
|
|
|
+ type: "tool_call_partial",
|
|
|
+ index: 0,
|
|
|
+ id: undefined,
|
|
|
+ name: undefined,
|
|
|
+ arguments: '"London"}',
|
|
|
+ })
|
|
|
+ })
|
|
|
+ })
|
|
|
})
|