batch_test.go 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224
  1. package tools
  2. import (
  3. "context"
  4. "encoding/json"
  5. "testing"
  6. "github.com/stretchr/testify/assert"
  7. )
  8. // MockTool is a simple tool implementation for testing
  9. type MockTool struct {
  10. name string
  11. description string
  12. response ToolResponse
  13. err error
  14. }
  15. func (m *MockTool) Info() ToolInfo {
  16. return ToolInfo{
  17. Name: m.name,
  18. Description: m.description,
  19. Parameters: map[string]any{},
  20. Required: []string{},
  21. }
  22. }
  23. func (m *MockTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
  24. return m.response, m.err
  25. }
  26. func TestBatchTool(t *testing.T) {
  27. t.Parallel()
  28. t.Run("successful batch execution", func(t *testing.T) {
  29. t.Parallel()
  30. // Create mock tools
  31. mockTools := map[string]BaseTool{
  32. "tool1": &MockTool{
  33. name: "tool1",
  34. description: "Mock Tool 1",
  35. response: NewTextResponse("Tool 1 Response"),
  36. err: nil,
  37. },
  38. "tool2": &MockTool{
  39. name: "tool2",
  40. description: "Mock Tool 2",
  41. response: NewTextResponse("Tool 2 Response"),
  42. err: nil,
  43. },
  44. }
  45. // Create batch tool
  46. batchTool := NewBatchTool(mockTools)
  47. // Create batch call
  48. input := `{
  49. "calls": [
  50. {
  51. "name": "tool1",
  52. "input": {}
  53. },
  54. {
  55. "name": "tool2",
  56. "input": {}
  57. }
  58. ]
  59. }`
  60. call := ToolCall{
  61. ID: "test-batch",
  62. Name: "batch",
  63. Input: input,
  64. }
  65. // Execute batch
  66. response, err := batchTool.Run(context.Background(), call)
  67. // Verify results
  68. assert.NoError(t, err)
  69. assert.Equal(t, ToolResponseTypeText, response.Type)
  70. assert.False(t, response.IsError)
  71. // Parse the response
  72. var batchResult BatchResult
  73. err = json.Unmarshal([]byte(response.Content), &batchResult)
  74. assert.NoError(t, err)
  75. // Verify batch results
  76. assert.Len(t, batchResult.Results, 2)
  77. assert.Empty(t, batchResult.Results[0].Error)
  78. assert.Empty(t, batchResult.Results[1].Error)
  79. assert.Empty(t, batchResult.Results[0].Separator)
  80. assert.NotEmpty(t, batchResult.Results[1].Separator)
  81. // Verify individual results
  82. var result1 ToolResponse
  83. err = json.Unmarshal(batchResult.Results[0].Result, &result1)
  84. assert.NoError(t, err)
  85. assert.Equal(t, "Tool 1 Response", result1.Content)
  86. var result2 ToolResponse
  87. err = json.Unmarshal(batchResult.Results[1].Result, &result2)
  88. assert.NoError(t, err)
  89. assert.Equal(t, "Tool 2 Response", result2.Content)
  90. })
  91. t.Run("tool not found", func(t *testing.T) {
  92. t.Parallel()
  93. // Create mock tools
  94. mockTools := map[string]BaseTool{
  95. "tool1": &MockTool{
  96. name: "tool1",
  97. description: "Mock Tool 1",
  98. response: NewTextResponse("Tool 1 Response"),
  99. err: nil,
  100. },
  101. }
  102. // Create batch tool
  103. batchTool := NewBatchTool(mockTools)
  104. // Create batch call with non-existent tool
  105. input := `{
  106. "calls": [
  107. {
  108. "name": "tool1",
  109. "input": {}
  110. },
  111. {
  112. "name": "nonexistent",
  113. "input": {}
  114. }
  115. ]
  116. }`
  117. call := ToolCall{
  118. ID: "test-batch",
  119. Name: "batch",
  120. Input: input,
  121. }
  122. // Execute batch
  123. response, err := batchTool.Run(context.Background(), call)
  124. // Verify results
  125. assert.NoError(t, err)
  126. assert.Equal(t, ToolResponseTypeText, response.Type)
  127. assert.False(t, response.IsError)
  128. // Parse the response
  129. var batchResult BatchResult
  130. err = json.Unmarshal([]byte(response.Content), &batchResult)
  131. assert.NoError(t, err)
  132. // Verify batch results
  133. assert.Len(t, batchResult.Results, 2)
  134. assert.Empty(t, batchResult.Results[0].Error)
  135. assert.Contains(t, batchResult.Results[1].Error, "tool not found: nonexistent")
  136. })
  137. t.Run("empty calls", func(t *testing.T) {
  138. t.Parallel()
  139. // Create batch tool with empty tools map
  140. batchTool := NewBatchTool(map[string]BaseTool{})
  141. // Create batch call with empty calls
  142. input := `{
  143. "calls": []
  144. }`
  145. call := ToolCall{
  146. ID: "test-batch",
  147. Name: "batch",
  148. Input: input,
  149. }
  150. // Execute batch
  151. response, err := batchTool.Run(context.Background(), call)
  152. // Verify results
  153. assert.NoError(t, err)
  154. assert.Equal(t, ToolResponseTypeText, response.Type)
  155. assert.True(t, response.IsError)
  156. assert.Contains(t, response.Content, "no tool calls provided")
  157. })
  158. t.Run("invalid input", func(t *testing.T) {
  159. t.Parallel()
  160. // Create batch tool with empty tools map
  161. batchTool := NewBatchTool(map[string]BaseTool{})
  162. // Create batch call with invalid JSON
  163. input := `{
  164. "calls": [
  165. {
  166. "name": "tool1",
  167. "input": {
  168. "invalid": json
  169. }
  170. }
  171. ]
  172. }`
  173. call := ToolCall{
  174. ID: "test-batch",
  175. Name: "batch",
  176. Input: input,
  177. }
  178. // Execute batch
  179. response, err := batchTool.Run(context.Background(), call)
  180. // Verify results
  181. assert.NoError(t, err)
  182. assert.Equal(t, ToolResponseTypeText, response.Type)
  183. assert.True(t, response.IsError)
  184. assert.Contains(t, response.Content, "error parsing parameters")
  185. })
  186. }