Просмотр исходного кода

feat: better collapsed tool call visuals

adamdottv 8 месяцев назад
Родитель
Сommit
2d68814abc

+ 40 - 21
packages/tui/internal/components/chat/message.go

@@ -201,17 +201,20 @@ func renderContentBlock(content string, options ...renderingOption) string {
 	return content
 }
 
-func renderText(message client.MessageInfo, text string, author string) string {
-	t := theme.CurrentTheme()
-	width := layout.Current.Container.Width
-	padding := 0
+func calculatePadding() int {
 	if layout.Current.Viewport.Width < 80 {
-		padding = 5
+		return 5
 	} else if layout.Current.Viewport.Width < 120 {
-		padding = 15
+		return 15
 	} else {
-		padding = 20
+		return 20
 	}
+}
+
+func renderText(message client.MessageInfo, text string, author string) string {
+	t := theme.CurrentTheme()
+	width := layout.Current.Container.Width
+	padding := calculatePadding()
 
 	timestamp := time.UnixMilli(int64(message.Metadata.Time.Created)).Local().Format("02 Jan 2006 03:04 PM")
 	if time.Now().Format("02 Jan 2006") == timestamp[:11] {
@@ -222,9 +225,11 @@ func renderText(message client.MessageInfo, text string, author string) string {
 
 	textWidth := max(lipgloss.Width(text), lipgloss.Width(info))
 	markdownWidth := min(textWidth, width-padding-4) // -4 for the border and padding
+	if message.Role == client.Assistant {
+		markdownWidth = width - padding - 4
+	}
 	content := toMarkdown(text, markdownWidth, t.BackgroundSubtle())
 	content = strings.Join([]string{content, info}, "\n")
-	// content = lipgloss.JoinVertical(align, content, info)
 
 	switch message.Role {
 	case client.User:
@@ -246,6 +251,7 @@ func renderToolInvocation(
 	result *string,
 	metadata client.MessageInfo_Metadata_Tool_AdditionalProperties,
 	showResult bool,
+	isLast bool,
 ) string {
 	ignoredTools := []string{"opencode_todoread"}
 	if slices.Contains(ignoredTools, toolCall.ToolName) {
@@ -333,7 +339,7 @@ func renderToolInvocation(
 	switch toolCall.ToolName {
 	case "opencode_read":
 		toolArgs = renderArgs(&toolArgsMap, "filePath")
-		title = fmt.Sprintf("Read: %s   %s", toolArgs, elapsed)
+		title = fmt.Sprintf("READ %s   %s", toolArgs, elapsed)
 		if preview, ok := metadata.Get("preview"); ok && toolArgsMap["filePath"] != nil {
 			filename := toolArgsMap["filePath"].(string)
 			body = preview.(string)
@@ -341,7 +347,7 @@ func renderToolInvocation(
 		}
 	case "opencode_edit":
 		if filename, ok := toolArgsMap["filePath"].(string); ok {
-			title = fmt.Sprintf("Edit: %s   %s", relative(filename), elapsed)
+			title = fmt.Sprintf("EDIT %s   %s", relative(filename), elapsed)
 			if d, ok := metadata.Get("diff"); ok {
 				patch := d.(string)
 				var formattedDiff string
@@ -382,14 +388,14 @@ func renderToolInvocation(
 		}
 	case "opencode_write":
 		if filename, ok := toolArgsMap["filePath"].(string); ok {
-			title = fmt.Sprintf("Write: %s   %s", relative(filename), elapsed)
+			title = fmt.Sprintf("WRITE %s   %s", relative(filename), elapsed)
 			if content, ok := toolArgsMap["content"].(string); ok {
 				body = renderFile(filename, content)
 			}
 		}
 	case "opencode_bash":
 		if description, ok := toolArgsMap["description"].(string); ok {
-			title = fmt.Sprintf("Shell: %s   %s", description, elapsed)
+			title = fmt.Sprintf("SHELL %s   %s", description, elapsed)
 		}
 		if stdout, ok := metadata.Get("stdout"); ok {
 			command := toolArgsMap["command"].(string)
@@ -400,7 +406,7 @@ func renderToolInvocation(
 		}
 	case "opencode_webfetch":
 		toolArgs = renderArgs(&toolArgsMap, "url")
-		title = fmt.Sprintf("Fetching: %s   %s", toolArgs, elapsed)
+		title = fmt.Sprintf("FETCH %s   %s", toolArgs, elapsed)
 		if format, ok := toolArgsMap["format"].(string); ok {
 			body = *result
 			body = truncateHeight(body, 10)
@@ -410,7 +416,7 @@ func renderToolInvocation(
 			body = renderContentBlock(body, WithFullWidth(), WithMarginBottom(1))
 		}
 	case "opencode_todowrite":
-		title = fmt.Sprintf("Planning   %s", elapsed)
+		title = fmt.Sprintf("PLAN   %s", elapsed)
 
 		if to, ok := metadata.Get("todos"); ok && finished {
 			todos := to.([]any)
@@ -431,12 +437,27 @@ func renderToolInvocation(
 		}
 	default:
 		toolName := renderToolName(toolCall.ToolName)
-		title = fmt.Sprintf("%s: %s   %s", toolName, toolArgs, elapsed)
+		title = fmt.Sprintf("%s %s   %s", toolName, toolArgs, elapsed)
 		body = *result
 		body = truncateHeight(body, 10)
 		body = renderContentBlock(body, WithFullWidth(), WithMarginBottom(1))
 	}
 
+	if !showResult {
+		padding := calculatePadding()
+		style := lipgloss.NewStyle().Width(outerWidth - padding - 4).Background(t.BackgroundSubtle())
+		paddingBottom := 0
+		if isLast {
+			paddingBottom = 1
+		}
+		return renderContentBlock(style.Render(title),
+			WithAlign(lipgloss.Left),
+			WithBorderColor(t.Accent()),
+			WithPaddingTop(0),
+			WithPaddingBottom(paddingBottom),
+		)
+	}
+
 	if body == "" && error == "" {
 		body = *result
 		body = truncateHeight(body, 10)
@@ -464,19 +485,17 @@ func renderToolName(name string) string {
 	// case agent.AgentToolName:
 	// 	return "Task"
 	case "opencode_ls":
-		return "List"
+		return "LIST"
 	case "opencode_webfetch":
-		return "Fetch"
-	case "opencode_todoread":
-		return "Planning"
+		return "FETCH"
 	case "opencode_todowrite":
-		return "Planning"
+		return "PLAN"
 	default:
 		normalizedName := name
 		if strings.HasPrefix(name, "opencode_") {
 			normalizedName = strings.TrimPrefix(name, "opencode_")
 		}
-		return cases.Title(language.Und).String(normalizedName)
+		return cases.Upper(language.Und).String(normalizedName)
 	}
 }
 

+ 30 - 5
packages/tui/internal/components/chat/messages.go

@@ -1,6 +1,7 @@
 package chat
 
 import (
+	"slices"
 	"strings"
 	"time"
 
@@ -144,6 +145,17 @@ func (m *messagesComponent) renderView() {
 	for _, message := range m.app.Messages {
 		var content string
 		var cached bool
+		lastToolIndex := 0
+		lastToolIndices := []int{}
+		for i, p := range message.Parts {
+			part, _ := p.ValueByDiscriminator()
+			switch part.(type) {
+			case client.MessagePartText:
+				lastToolIndices = append(lastToolIndices, lastToolIndex)
+			case client.MessagePartToolInvocation:
+				lastToolIndex = i
+			}
+		}
 
 		author := ""
 		switch message.Role {
@@ -153,7 +165,7 @@ func (m *messagesComponent) renderView() {
 			author = message.Metadata.Assistant.ModelID
 		}
 
-		for _, p := range message.Parts {
+		for i, p := range message.Parts {
 			part, err := p.ValueByDiscriminator()
 			if err != nil {
 				continue //TODO: handle error?
@@ -180,6 +192,7 @@ func (m *messagesComponent) renderView() {
 					previousBlockType = assistantTextBlock
 				}
 			case client.MessagePartToolInvocation:
+				isLastToolInvocation := slices.Contains(lastToolIndices, i)
 				toolInvocationPart := part.(client.MessagePartToolInvocation)
 				toolCall, _ := toolInvocationPart.ToolInvocation.AsMessageToolInvocationToolCall()
 				metadata := client.MessageInfo_Metadata_Tool_AdditionalProperties{}
@@ -200,15 +213,27 @@ func (m *messagesComponent) renderView() {
 					)
 					content, cached = m.cache.Get(key)
 					if !cached {
-						content = renderToolInvocation(toolCall, result, metadata, m.showToolResults)
+						content = renderToolInvocation(
+							toolCall,
+							result,
+							metadata,
+							m.showToolResults,
+							isLastToolInvocation,
+						)
 						m.cache.Set(key, content)
 					}
 				} else {
-					// if the tool call isn't finished, never cache
-					content = renderToolInvocation(toolCall, result, metadata, m.showToolResults)
+					// if the tool call isn't finished, don't cache
+					content = renderToolInvocation(
+						toolCall,
+						result,
+						metadata,
+						m.showToolResults,
+						isLastToolInvocation,
+					)
 				}
 
-				if previousBlockType != toolInvocationBlock {
+				if previousBlockType != toolInvocationBlock && m.showToolResults {
 					blocks = append(blocks, "")
 				}
 				blocks = append(blocks, content)