generic.go 3.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. package chat
  2. import (
  3. "encoding/json"
  4. "strings"
  5. "github.com/charmbracelet/crush/internal/message"
  6. "github.com/charmbracelet/crush/internal/stringext"
  7. "github.com/charmbracelet/crush/internal/ui/styles"
  8. )
  9. // GenericToolMessageItem is a message item that represents an unknown tool call.
  10. type GenericToolMessageItem struct {
  11. *baseToolMessageItem
  12. }
  13. var _ ToolMessageItem = (*GenericToolMessageItem)(nil)
  14. // NewGenericToolMessageItem creates a new [GenericToolMessageItem].
  15. func NewGenericToolMessageItem(
  16. sty *styles.Styles,
  17. toolCall message.ToolCall,
  18. result *message.ToolResult,
  19. canceled bool,
  20. ) ToolMessageItem {
  21. return newBaseToolMessageItem(sty, toolCall, result, &GenericToolRenderContext{}, canceled)
  22. }
  // GenericToolRenderContext is the stateless renderer for tool calls that
  // are unknown to the UI; see its RenderTool method.
  type GenericToolRenderContext struct{}
  25. // RenderTool implements the [ToolRenderer] interface.
  26. func (g *GenericToolRenderContext) RenderTool(sty *styles.Styles, width int, opts *ToolRenderOpts) string {
  27. cappedWidth := cappedMessageWidth(width)
  28. name := genericPrettyName(opts.ToolCall.Name)
  29. if opts.IsPending() {
  30. return pendingTool(sty, name, opts.Anim)
  31. }
  32. var params map[string]any
  33. if err := json.Unmarshal([]byte(opts.ToolCall.Input), &params); err != nil {
  34. return toolErrorContent(sty, &message.ToolResult{Content: "Invalid parameters"}, cappedWidth)
  35. }
  36. var toolParams []string
  37. if len(params) > 0 {
  38. parsed, _ := json.Marshal(params)
  39. toolParams = append(toolParams, string(parsed))
  40. }
  41. header := toolHeader(sty, opts.Status, name, cappedWidth, opts.Compact, toolParams...)
  42. if opts.Compact {
  43. return header
  44. }
  45. if earlyState, ok := toolEarlyStateContent(sty, opts, cappedWidth); ok {
  46. return joinToolParts(header, earlyState)
  47. }
  48. if !opts.HasResult() || opts.Result.Content == "" {
  49. return header
  50. }
  51. bodyWidth := cappedWidth - toolBodyLeftPaddingTotal
  52. // Handle image data.
  53. if opts.Result.Data != "" && strings.HasPrefix(opts.Result.MIMEType, "image/") {
  54. body := sty.Tool.Body.Render(toolOutputImageContent(sty, opts.Result.Data, opts.Result.MIMEType))
  55. return joinToolParts(header, body)
  56. }
  57. // Try to parse result as JSON for pretty display.
  58. var result json.RawMessage
  59. var body string
  60. if err := json.Unmarshal([]byte(opts.Result.Content), &result); err == nil {
  61. prettyResult, err := json.MarshalIndent(result, "", " ")
  62. if err == nil {
  63. body = sty.Tool.Body.Render(toolOutputCodeContent(sty, "result.json", string(prettyResult), 0, bodyWidth, opts.ExpandedContent))
  64. } else {
  65. body = sty.Tool.Body.Render(toolOutputPlainContent(sty, opts.Result.Content, bodyWidth, opts.ExpandedContent))
  66. }
  67. } else if looksLikeMarkdown(opts.Result.Content) {
  68. body = sty.Tool.Body.Render(toolOutputCodeContent(sty, "result.md", opts.Result.Content, 0, bodyWidth, opts.ExpandedContent))
  69. } else {
  70. body = sty.Tool.Body.Render(toolOutputPlainContent(sty, opts.Result.Content, bodyWidth, opts.ExpandedContent))
  71. }
  72. return joinToolParts(header, body)
  73. }
  74. // genericPrettyName converts a snake_case or kebab-case tool name to a
  75. // human-readable title case name.
  76. func genericPrettyName(name string) string {
  77. name = strings.ReplaceAll(name, "_", " ")
  78. name = strings.ReplaceAll(name, "-", " ")
  79. return stringext.Capitalize(name)
  80. }