sourcegraph_test.go 2.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. package tools
  2. import (
  3. "context"
  4. "encoding/json"
  5. "testing"
  6. "github.com/stretchr/testify/assert"
  7. "github.com/stretchr/testify/require"
  8. )
  9. func TestSourcegraphTool_Info(t *testing.T) {
  10. tool := NewSourcegraphTool()
  11. info := tool.Info()
  12. assert.Equal(t, SourcegraphToolName, info.Name)
  13. assert.NotEmpty(t, info.Description)
  14. assert.Contains(t, info.Parameters, "query")
  15. assert.Contains(t, info.Parameters, "count")
  16. assert.Contains(t, info.Parameters, "timeout")
  17. assert.Contains(t, info.Required, "query")
  18. }
  19. func TestSourcegraphTool_Run(t *testing.T) {
  20. t.Run("handles missing query parameter", func(t *testing.T) {
  21. tool := NewSourcegraphTool()
  22. params := SourcegraphParams{
  23. Query: "",
  24. }
  25. paramsJSON, err := json.Marshal(params)
  26. require.NoError(t, err)
  27. call := ToolCall{
  28. Name: SourcegraphToolName,
  29. Input: string(paramsJSON),
  30. }
  31. response, err := tool.Run(context.Background(), call)
  32. require.NoError(t, err)
  33. assert.Contains(t, response.Content, "Query parameter is required")
  34. })
  35. t.Run("handles invalid parameters", func(t *testing.T) {
  36. tool := NewSourcegraphTool()
  37. call := ToolCall{
  38. Name: SourcegraphToolName,
  39. Input: "invalid json",
  40. }
  41. response, err := tool.Run(context.Background(), call)
  42. require.NoError(t, err)
  43. assert.Contains(t, response.Content, "Failed to parse sourcegraph parameters")
  44. })
  45. t.Run("normalizes count parameter", func(t *testing.T) {
  46. // Test cases for count normalization
  47. testCases := []struct {
  48. name string
  49. inputCount int
  50. expectedCount int
  51. }{
  52. {"negative count", -5, 10}, // Should use default (10)
  53. {"zero count", 0, 10}, // Should use default (10)
  54. {"valid count", 50, 50}, // Should keep as is
  55. {"excessive count", 150, 100}, // Should cap at 100
  56. }
  57. for _, tc := range testCases {
  58. t.Run(tc.name, func(t *testing.T) {
  59. // Verify count normalization logic directly
  60. assert.NotPanics(t, func() {
  61. // Apply the same normalization logic as in the tool
  62. normalizedCount := tc.inputCount
  63. if normalizedCount <= 0 {
  64. normalizedCount = 10
  65. } else if normalizedCount > 100 {
  66. normalizedCount = 100
  67. }
  68. assert.Equal(t, tc.expectedCount, normalizedCount)
  69. })
  70. })
  71. }
  72. })
  73. }