|
|
@@ -280,4 +280,220 @@ describe("XAIHandler", () => {
|
|
|
}),
|
|
|
)
|
|
|
})
|
|
|
+
|
|
|
+ describe("Native Tool Calling", () => {
|
|
|
+ const testTools = [
|
|
|
+ {
|
|
|
+ type: "function" as const,
|
|
|
+ function: {
|
|
|
+ name: "test_tool",
|
|
|
+ description: "A test tool",
|
|
|
+ parameters: {
|
|
|
+ type: "object",
|
|
|
+ properties: {
|
|
|
+ arg1: { type: "string", description: "First argument" },
|
|
|
+ },
|
|
|
+ required: ["arg1"],
|
|
|
+ },
|
|
|
+ },
|
|
|
+ },
|
|
|
+ ]
|
|
|
+
|
|
|
+ it("should include tools in request when model supports native tools and tools are provided", async () => {
|
|
|
+ const handlerWithTools = new XAIHandler({ apiModelId: "grok-3" })
|
|
|
+
|
|
|
+ mockCreate.mockImplementationOnce(() => {
|
|
|
+ return {
|
|
|
+ [Symbol.asyncIterator]: () => ({
|
|
|
+ async next() {
|
|
|
+ return { done: true }
|
|
|
+ },
|
|
|
+ }),
|
|
|
+ }
|
|
|
+ })
|
|
|
+
|
|
|
+ const messageGenerator = handlerWithTools.createMessage("test prompt", [], {
|
|
|
+ taskId: "test-task-id",
|
|
|
+ tools: testTools,
|
|
|
+ toolProtocol: "native",
|
|
|
+ })
|
|
|
+ await messageGenerator.next()
|
|
|
+
|
|
|
+ expect(mockCreate).toHaveBeenCalledWith(
|
|
|
+ expect.objectContaining({
|
|
|
+ tools: expect.arrayContaining([
|
|
|
+ expect.objectContaining({
|
|
|
+ type: "function",
|
|
|
+ function: expect.objectContaining({
|
|
|
+ name: "test_tool",
|
|
|
+ }),
|
|
|
+ }),
|
|
|
+ ]),
|
|
|
+ parallel_tool_calls: false,
|
|
|
+ }),
|
|
|
+ )
|
|
|
+ })
|
|
|
+
|
|
|
+ it("should include tool_choice when provided", async () => {
|
|
|
+ const handlerWithTools = new XAIHandler({ apiModelId: "grok-3" })
|
|
|
+
|
|
|
+ mockCreate.mockImplementationOnce(() => {
|
|
|
+ return {
|
|
|
+ [Symbol.asyncIterator]: () => ({
|
|
|
+ async next() {
|
|
|
+ return { done: true }
|
|
|
+ },
|
|
|
+ }),
|
|
|
+ }
|
|
|
+ })
|
|
|
+
|
|
|
+ const messageGenerator = handlerWithTools.createMessage("test prompt", [], {
|
|
|
+ taskId: "test-task-id",
|
|
|
+ tools: testTools,
|
|
|
+ toolProtocol: "native",
|
|
|
+ tool_choice: "auto",
|
|
|
+ })
|
|
|
+ await messageGenerator.next()
|
|
|
+
|
|
|
+ expect(mockCreate).toHaveBeenCalledWith(
|
|
|
+ expect.objectContaining({
|
|
|
+ tool_choice: "auto",
|
|
|
+ }),
|
|
|
+ )
|
|
|
+ })
|
|
|
+
|
|
|
+ it("should not include tools when toolProtocol is xml", async () => {
|
|
|
+ const handlerWithTools = new XAIHandler({ apiModelId: "grok-3" })
|
|
|
+
|
|
|
+ mockCreate.mockImplementationOnce(() => {
|
|
|
+ return {
|
|
|
+ [Symbol.asyncIterator]: () => ({
|
|
|
+ async next() {
|
|
|
+ return { done: true }
|
|
|
+ },
|
|
|
+ }),
|
|
|
+ }
|
|
|
+ })
|
|
|
+
|
|
|
+ const messageGenerator = handlerWithTools.createMessage("test prompt", [], {
|
|
|
+ taskId: "test-task-id",
|
|
|
+ tools: testTools,
|
|
|
+ toolProtocol: "xml",
|
|
|
+ })
|
|
|
+ await messageGenerator.next()
|
|
|
+
|
|
|
+ const callArgs = mockCreate.mock.calls[mockCreate.mock.calls.length - 1][0]
|
|
|
+ expect(callArgs).not.toHaveProperty("tools")
|
|
|
+ expect(callArgs).not.toHaveProperty("tool_choice")
|
|
|
+ })
|
|
|
+
|
|
|
+ it("should yield tool_call_partial chunks during streaming", async () => {
|
|
|
+ const handlerWithTools = new XAIHandler({ apiModelId: "grok-3" })
|
|
|
+
|
|
|
+ mockCreate.mockImplementationOnce(() => {
|
|
|
+ return {
|
|
|
+ [Symbol.asyncIterator]: () => ({
|
|
|
+ next: vi
|
|
|
+ .fn()
|
|
|
+ .mockResolvedValueOnce({
|
|
|
+ done: false,
|
|
|
+ value: {
|
|
|
+ choices: [
|
|
|
+ {
|
|
|
+ delta: {
|
|
|
+ tool_calls: [
|
|
|
+ {
|
|
|
+ index: 0,
|
|
|
+ id: "call_123",
|
|
|
+ function: {
|
|
|
+ name: "test_tool",
|
|
|
+ arguments: '{"arg1":',
|
|
|
+ },
|
|
|
+ },
|
|
|
+ ],
|
|
|
+ },
|
|
|
+ },
|
|
|
+ ],
|
|
|
+ },
|
|
|
+ })
|
|
|
+ .mockResolvedValueOnce({
|
|
|
+ done: false,
|
|
|
+ value: {
|
|
|
+ choices: [
|
|
|
+ {
|
|
|
+ delta: {
|
|
|
+ tool_calls: [
|
|
|
+ {
|
|
|
+ index: 0,
|
|
|
+ function: {
|
|
|
+ arguments: '"value"}',
|
|
|
+ },
|
|
|
+ },
|
|
|
+ ],
|
|
|
+ },
|
|
|
+ },
|
|
|
+ ],
|
|
|
+ },
|
|
|
+ })
|
|
|
+ .mockResolvedValueOnce({ done: true }),
|
|
|
+ }),
|
|
|
+ }
|
|
|
+ })
|
|
|
+
|
|
|
+ const stream = handlerWithTools.createMessage("test prompt", [], {
|
|
|
+ taskId: "test-task-id",
|
|
|
+ tools: testTools,
|
|
|
+ toolProtocol: "native",
|
|
|
+ })
|
|
|
+
|
|
|
+ const chunks = []
|
|
|
+ for await (const chunk of stream) {
|
|
|
+ chunks.push(chunk)
|
|
|
+ }
|
|
|
+
|
|
|
+ expect(chunks).toContainEqual({
|
|
|
+ type: "tool_call_partial",
|
|
|
+ index: 0,
|
|
|
+ id: "call_123",
|
|
|
+ name: "test_tool",
|
|
|
+ arguments: '{"arg1":',
|
|
|
+ })
|
|
|
+
|
|
|
+ expect(chunks).toContainEqual({
|
|
|
+ type: "tool_call_partial",
|
|
|
+ index: 0,
|
|
|
+ id: undefined,
|
|
|
+ name: undefined,
|
|
|
+ arguments: '"value"}',
|
|
|
+ })
|
|
|
+ })
|
|
|
+
|
|
|
+ it("should set parallel_tool_calls based on metadata", async () => {
|
|
|
+ const handlerWithTools = new XAIHandler({ apiModelId: "grok-3" })
|
|
|
+
|
|
|
+ mockCreate.mockImplementationOnce(() => {
|
|
|
+ return {
|
|
|
+ [Symbol.asyncIterator]: () => ({
|
|
|
+ async next() {
|
|
|
+ return { done: true }
|
|
|
+ },
|
|
|
+ }),
|
|
|
+ }
|
|
|
+ })
|
|
|
+
|
|
|
+ const messageGenerator = handlerWithTools.createMessage("test prompt", [], {
|
|
|
+ taskId: "test-task-id",
|
|
|
+ tools: testTools,
|
|
|
+ toolProtocol: "native",
|
|
|
+ parallelToolCalls: true,
|
|
|
+ })
|
|
|
+ await messageGenerator.next()
|
|
|
+
|
|
|
+ expect(mockCreate).toHaveBeenCalledWith(
|
|
|
+ expect.objectContaining({
|
|
|
+ parallel_tool_calls: true,
|
|
|
+ }),
|
|
|
+ )
|
|
|
+ })
|
|
|
+ })
|
|
|
})
|