| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115 |
- package tools
- import (
- "context"
- "encoding/json"
- "testing"
- "github.com/kujtimiihoxha/termai/internal/permission"
- "github.com/stretchr/testify/assert"
- "github.com/stretchr/testify/require"
- )
- func TestSourcegraphTool_Info(t *testing.T) {
- tool := NewSourcegraphTool()
- info := tool.Info()
- assert.Equal(t, SourcegraphToolName, info.Name)
- assert.NotEmpty(t, info.Description)
- assert.Contains(t, info.Parameters, "query")
- assert.Contains(t, info.Parameters, "count")
- assert.Contains(t, info.Parameters, "timeout")
- assert.Contains(t, info.Required, "query")
- }
- func TestSourcegraphTool_Run(t *testing.T) {
- // Setup a mock permission handler that always allows
- origPermission := permission.Default
- defer func() {
- permission.Default = origPermission
- }()
- permission.Default = newMockPermissionService(true)
- t.Run("handles missing query parameter", func(t *testing.T) {
- tool := NewSourcegraphTool()
- params := SourcegraphParams{
- Query: "",
- }
- paramsJSON, err := json.Marshal(params)
- require.NoError(t, err)
- call := ToolCall{
- Name: SourcegraphToolName,
- Input: string(paramsJSON),
- }
- response, err := tool.Run(context.Background(), call)
- require.NoError(t, err)
- assert.Contains(t, response.Content, "Query parameter is required")
- })
- t.Run("handles invalid parameters", func(t *testing.T) {
- tool := NewSourcegraphTool()
- call := ToolCall{
- Name: SourcegraphToolName,
- Input: "invalid json",
- }
- response, err := tool.Run(context.Background(), call)
- require.NoError(t, err)
- assert.Contains(t, response.Content, "Failed to parse sourcegraph parameters")
- })
- t.Run("handles permission denied", func(t *testing.T) {
- permission.Default = newMockPermissionService(false)
- tool := NewSourcegraphTool()
- params := SourcegraphParams{
- Query: "test query",
- }
- paramsJSON, err := json.Marshal(params)
- require.NoError(t, err)
- call := ToolCall{
- Name: SourcegraphToolName,
- Input: string(paramsJSON),
- }
- response, err := tool.Run(context.Background(), call)
- require.NoError(t, err)
- assert.Contains(t, response.Content, "Permission denied")
- })
- t.Run("normalizes count parameter", func(t *testing.T) {
- // Test cases for count normalization
- testCases := []struct {
- name string
- inputCount int
- expectedCount int
- }{
- {"negative count", -5, 10}, // Should use default (10)
- {"zero count", 0, 10}, // Should use default (10)
- {"valid count", 50, 50}, // Should keep as is
- {"excessive count", 150, 100}, // Should cap at 100
- }
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- // Verify count normalization logic directly
- assert.NotPanics(t, func() {
- // Apply the same normalization logic as in the tool
- normalizedCount := tc.inputCount
- if normalizedCount <= 0 {
- normalizedCount = 10
- } else if normalizedCount > 100 {
- normalizedCount = 100
- }
- assert.Equal(t, tc.expectedCount, normalizedCount)
- })
- })
- }
- })
- }
|