sourcegraph_test.go 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115
  1. package tools
  2. import (
  3. "context"
  4. "encoding/json"
  5. "testing"
  6. "github.com/kujtimiihoxha/termai/internal/permission"
  7. "github.com/stretchr/testify/assert"
  8. "github.com/stretchr/testify/require"
  9. )
  10. func TestSourcegraphTool_Info(t *testing.T) {
  11. tool := NewSourcegraphTool()
  12. info := tool.Info()
  13. assert.Equal(t, SourcegraphToolName, info.Name)
  14. assert.NotEmpty(t, info.Description)
  15. assert.Contains(t, info.Parameters, "query")
  16. assert.Contains(t, info.Parameters, "count")
  17. assert.Contains(t, info.Parameters, "timeout")
  18. assert.Contains(t, info.Required, "query")
  19. }
  20. func TestSourcegraphTool_Run(t *testing.T) {
  21. // Setup a mock permission handler that always allows
  22. origPermission := permission.Default
  23. defer func() {
  24. permission.Default = origPermission
  25. }()
  26. permission.Default = newMockPermissionService(true)
  27. t.Run("handles missing query parameter", func(t *testing.T) {
  28. tool := NewSourcegraphTool()
  29. params := SourcegraphParams{
  30. Query: "",
  31. }
  32. paramsJSON, err := json.Marshal(params)
  33. require.NoError(t, err)
  34. call := ToolCall{
  35. Name: SourcegraphToolName,
  36. Input: string(paramsJSON),
  37. }
  38. response, err := tool.Run(context.Background(), call)
  39. require.NoError(t, err)
  40. assert.Contains(t, response.Content, "Query parameter is required")
  41. })
  42. t.Run("handles invalid parameters", func(t *testing.T) {
  43. tool := NewSourcegraphTool()
  44. call := ToolCall{
  45. Name: SourcegraphToolName,
  46. Input: "invalid json",
  47. }
  48. response, err := tool.Run(context.Background(), call)
  49. require.NoError(t, err)
  50. assert.Contains(t, response.Content, "Failed to parse sourcegraph parameters")
  51. })
  52. t.Run("handles permission denied", func(t *testing.T) {
  53. permission.Default = newMockPermissionService(false)
  54. tool := NewSourcegraphTool()
  55. params := SourcegraphParams{
  56. Query: "test query",
  57. }
  58. paramsJSON, err := json.Marshal(params)
  59. require.NoError(t, err)
  60. call := ToolCall{
  61. Name: SourcegraphToolName,
  62. Input: string(paramsJSON),
  63. }
  64. response, err := tool.Run(context.Background(), call)
  65. require.NoError(t, err)
  66. assert.Contains(t, response.Content, "Permission denied")
  67. })
  68. t.Run("normalizes count parameter", func(t *testing.T) {
  69. // Test cases for count normalization
  70. testCases := []struct {
  71. name string
  72. inputCount int
  73. expectedCount int
  74. }{
  75. {"negative count", -5, 10}, // Should use default (10)
  76. {"zero count", 0, 10}, // Should use default (10)
  77. {"valid count", 50, 50}, // Should keep as is
  78. {"excessive count", 150, 100}, // Should cap at 100
  79. }
  80. for _, tc := range testCases {
  81. t.Run(tc.name, func(t *testing.T) {
  82. // Verify count normalization logic directly
  83. assert.NotPanics(t, func() {
  84. // Apply the same normalization logic as in the tool
  85. normalizedCount := tc.inputCount
  86. if normalizedCount <= 0 {
  87. normalizedCount = 10
  88. } else if normalizedCount > 100 {
  89. normalizedCount = 100
  90. }
  91. assert.Equal(t, tc.expectedCount, normalizedCount)
  92. })
  93. })
  94. }
  95. })
  96. }