sourcegraph.go 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268
  1. package tools
  2. import (
  3. "bytes"
  4. "context"
  5. _ "embed"
  6. "encoding/json"
  7. "fmt"
  8. "io"
  9. "net/http"
  10. "strings"
  11. "time"
  12. "charm.land/fantasy"
  13. )
  14. type SourcegraphParams struct {
  15. Query string `json:"query" description:"The Sourcegraph search query"`
  16. Count int `json:"count,omitempty" description:"Optional number of results to return (default: 10, max: 20)"`
  17. ContextWindow int `json:"context_window,omitempty" description:"The context around the match to return (default: 10 lines)"`
  18. Timeout int `json:"timeout,omitempty" description:"Optional timeout in seconds (max 120)"`
  19. }
  20. type SourcegraphResponseMetadata struct {
  21. NumberOfMatches int `json:"number_of_matches"`
  22. Truncated bool `json:"truncated"`
  23. }
// SourcegraphToolName is the name the Sourcegraph search tool is registered
// under.
const SourcegraphToolName = "sourcegraph"

// sourcegraphDescription is the tool description text, embedded at build time
// from sourcegraph.md.
//
//go:embed sourcegraph.md
var sourcegraphDescription []byte
  27. func NewSourcegraphTool(client *http.Client) fantasy.AgentTool {
  28. if client == nil {
  29. transport := http.DefaultTransport.(*http.Transport).Clone()
  30. transport.MaxIdleConns = 100
  31. transport.MaxIdleConnsPerHost = 10
  32. transport.IdleConnTimeout = 90 * time.Second
  33. client = &http.Client{
  34. Timeout: 30 * time.Second,
  35. Transport: transport,
  36. }
  37. }
  38. return fantasy.NewParallelAgentTool(
  39. SourcegraphToolName,
  40. string(sourcegraphDescription),
  41. func(ctx context.Context, params SourcegraphParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
  42. if params.Query == "" {
  43. return fantasy.NewTextErrorResponse("Query parameter is required"), nil
  44. }
  45. if params.Count <= 0 {
  46. params.Count = 10
  47. } else if params.Count > 20 {
  48. params.Count = 20 // Limit to 20 results
  49. }
  50. if params.ContextWindow <= 0 {
  51. params.ContextWindow = 10 // Default context window
  52. }
  53. // Handle timeout with context
  54. requestCtx := ctx
  55. if params.Timeout > 0 {
  56. maxTimeout := 120 // 2 minutes
  57. if params.Timeout > maxTimeout {
  58. params.Timeout = maxTimeout
  59. }
  60. var cancel context.CancelFunc
  61. requestCtx, cancel = context.WithTimeout(ctx, time.Duration(params.Timeout)*time.Second)
  62. defer cancel()
  63. }
  64. type graphqlRequest struct {
  65. Query string `json:"query"`
  66. Variables struct {
  67. Query string `json:"query"`
  68. } `json:"variables"`
  69. }
  70. request := graphqlRequest{
  71. Query: "query Search($query: String!) { search(query: $query, version: V2, patternType: keyword ) { results { matchCount, limitHit, resultCount, approximateResultCount, missing { name }, timedout { name }, indexUnavailable, results { __typename, ... on FileMatch { repository { name }, file { path, url, content }, lineMatches { preview, lineNumber, offsetAndLengths } } } } } }",
  72. }
  73. request.Variables.Query = params.Query
  74. graphqlQueryBytes, err := json.Marshal(request)
  75. if err != nil {
  76. return fantasy.ToolResponse{}, fmt.Errorf("failed to marshal GraphQL request: %w", err)
  77. }
  78. graphqlQuery := string(graphqlQueryBytes)
  79. req, err := http.NewRequestWithContext(
  80. requestCtx,
  81. "POST",
  82. "https://sourcegraph.com/.api/graphql",
  83. bytes.NewBuffer([]byte(graphqlQuery)),
  84. )
  85. if err != nil {
  86. return fantasy.ToolResponse{}, fmt.Errorf("failed to create request: %w", err)
  87. }
  88. req.Header.Set("Content-Type", "application/json")
  89. req.Header.Set("User-Agent", "crush/1.0")
  90. resp, err := client.Do(req)
  91. if err != nil {
  92. return fantasy.ToolResponse{}, fmt.Errorf("failed to fetch URL: %w", err)
  93. }
  94. defer resp.Body.Close()
  95. if resp.StatusCode != http.StatusOK {
  96. body, _ := io.ReadAll(resp.Body)
  97. if len(body) > 0 {
  98. return fantasy.NewTextErrorResponse(fmt.Sprintf("Request failed with status code: %d, response: %s", resp.StatusCode, string(body))), nil
  99. }
  100. return fantasy.NewTextErrorResponse(fmt.Sprintf("Request failed with status code: %d", resp.StatusCode)), nil
  101. }
  102. body, err := io.ReadAll(resp.Body)
  103. if err != nil {
  104. return fantasy.ToolResponse{}, fmt.Errorf("failed to read response body: %w", err)
  105. }
  106. var result map[string]any
  107. if err = json.Unmarshal(body, &result); err != nil {
  108. return fantasy.ToolResponse{}, fmt.Errorf("failed to unmarshal response: %w", err)
  109. }
  110. formattedResults, err := formatSourcegraphResults(result, params.ContextWindow)
  111. if err != nil {
  112. return fantasy.NewTextErrorResponse("Failed to format results: " + err.Error()), nil
  113. }
  114. return fantasy.NewTextResponse(formattedResults), nil
  115. })
  116. }
  117. func formatSourcegraphResults(result map[string]any, contextWindow int) (string, error) {
  118. var buffer strings.Builder
  119. if errors, ok := result["errors"].([]any); ok && len(errors) > 0 {
  120. buffer.WriteString("## Sourcegraph API Error\n\n")
  121. for _, err := range errors {
  122. if errMap, ok := err.(map[string]any); ok {
  123. if message, ok := errMap["message"].(string); ok {
  124. fmt.Fprintf(&buffer, "- %s\n", message)
  125. }
  126. }
  127. }
  128. return buffer.String(), nil
  129. }
  130. data, ok := result["data"].(map[string]any)
  131. if !ok {
  132. return "", fmt.Errorf("invalid response format: missing data field")
  133. }
  134. search, ok := data["search"].(map[string]any)
  135. if !ok {
  136. return "", fmt.Errorf("invalid response format: missing search field")
  137. }
  138. searchResults, ok := search["results"].(map[string]any)
  139. if !ok {
  140. return "", fmt.Errorf("invalid response format: missing results field")
  141. }
  142. matchCount, _ := searchResults["matchCount"].(float64)
  143. resultCount, _ := searchResults["resultCount"].(float64)
  144. limitHit, _ := searchResults["limitHit"].(bool)
  145. buffer.WriteString("# Sourcegraph Search Results\n\n")
  146. fmt.Fprintf(&buffer, "Found %d matches across %d results\n", int(matchCount), int(resultCount))
  147. if limitHit {
  148. buffer.WriteString("(Result limit reached, try a more specific query)\n")
  149. }
  150. buffer.WriteString("\n")
  151. results, ok := searchResults["results"].([]any)
  152. if !ok || len(results) == 0 {
  153. buffer.WriteString("No results found. Try a different query.\n")
  154. return buffer.String(), nil
  155. }
  156. maxResults := 10
  157. if len(results) > maxResults {
  158. results = results[:maxResults]
  159. }
  160. for i, res := range results {
  161. fileMatch, ok := res.(map[string]any)
  162. if !ok {
  163. continue
  164. }
  165. typeName, _ := fileMatch["__typename"].(string)
  166. if typeName != "FileMatch" {
  167. continue
  168. }
  169. repo, _ := fileMatch["repository"].(map[string]any)
  170. file, _ := fileMatch["file"].(map[string]any)
  171. lineMatches, _ := fileMatch["lineMatches"].([]any)
  172. if repo == nil || file == nil {
  173. continue
  174. }
  175. repoName, _ := repo["name"].(string)
  176. filePath, _ := file["path"].(string)
  177. fileURL, _ := file["url"].(string)
  178. fileContent, _ := file["content"].(string)
  179. fmt.Fprintf(&buffer, "## Result %d: %s/%s\n\n", i+1, repoName, filePath)
  180. if fileURL != "" {
  181. fmt.Fprintf(&buffer, "URL: %s\n\n", fileURL)
  182. }
  183. if len(lineMatches) > 0 {
  184. for _, lm := range lineMatches {
  185. lineMatch, ok := lm.(map[string]any)
  186. if !ok {
  187. continue
  188. }
  189. lineNumber, _ := lineMatch["lineNumber"].(float64)
  190. preview, _ := lineMatch["preview"].(string)
  191. if fileContent != "" {
  192. lines := strings.Split(fileContent, "\n")
  193. buffer.WriteString("```\n")
  194. startLine := max(1, int(lineNumber)-contextWindow)
  195. for j := startLine - 1; j < int(lineNumber)-1 && j < len(lines); j++ {
  196. if j >= 0 {
  197. fmt.Fprintf(&buffer, "%d| %s\n", j+1, lines[j])
  198. }
  199. }
  200. fmt.Fprintf(&buffer, "%d| %s\n", int(lineNumber), preview)
  201. endLine := int(lineNumber) + contextWindow
  202. for j := int(lineNumber); j < endLine && j < len(lines); j++ {
  203. if j < len(lines) {
  204. fmt.Fprintf(&buffer, "%d| %s\n", j+1, lines[j])
  205. }
  206. }
  207. buffer.WriteString("```\n\n")
  208. } else {
  209. buffer.WriteString("```\n")
  210. fmt.Fprintf(&buffer, "%d| %s\n", int(lineNumber), preview)
  211. buffer.WriteString("```\n\n")
  212. }
  213. }
  214. }
  215. }
  216. return buffer.String(), nil
  217. }