provider.go 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990
  1. package provider
  2. import (
  3. "context"
  4. "github.com/kujtimiihoxha/termai/internal/llm/tools"
  5. "github.com/kujtimiihoxha/termai/internal/message"
  6. )
  7. // EventType represents the type of streaming event
  8. type EventType string
  9. const (
  10. EventContentStart EventType = "content_start"
  11. EventContentDelta EventType = "content_delta"
  12. EventThinkingDelta EventType = "thinking_delta"
  13. EventContentStop EventType = "content_stop"
  14. EventComplete EventType = "complete"
  15. EventError EventType = "error"
  16. EventWarning EventType = "warning"
  17. EventInfo EventType = "info"
  18. )
  19. type TokenUsage struct {
  20. InputTokens int64
  21. OutputTokens int64
  22. CacheCreationTokens int64
  23. CacheReadTokens int64
  24. }
  25. type ProviderResponse struct {
  26. Content string
  27. ToolCalls []message.ToolCall
  28. Usage TokenUsage
  29. FinishReason string
  30. }
  31. type ProviderEvent struct {
  32. Type EventType
  33. Content string
  34. Thinking string
  35. ToolCall *message.ToolCall
  36. Error error
  37. Response *ProviderResponse
  38. // Used for giving users info on e.x retry
  39. Info string
  40. }
  41. type Provider interface {
  42. SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)
  43. StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error)
  44. }
  45. func cleanupMessages(messages []message.Message) []message.Message {
  46. // First pass: filter out canceled messages
  47. var cleanedMessages []message.Message
  48. for _, msg := range messages {
  49. if msg.FinishReason() != "canceled" {
  50. cleanedMessages = append(cleanedMessages, msg)
  51. }
  52. }
  53. // Second pass: filter out tool messages without a corresponding tool call
  54. var result []message.Message
  55. toolMessageIDs := make(map[string]bool)
  56. for _, msg := range cleanedMessages {
  57. if msg.Role == message.Assistant {
  58. for _, toolCall := range msg.ToolCalls() {
  59. toolMessageIDs[toolCall.ID] = true // Mark as referenced
  60. }
  61. }
  62. }
  63. // Keep only messages that aren't unreferenced tool messages
  64. for _, msg := range cleanedMessages {
  65. if msg.Role == message.Tool {
  66. for _, toolCall := range msg.ToolResults() {
  67. if referenced, exists := toolMessageIDs[toolCall.ToolCallID]; exists && referenced {
  68. result = append(result, msg)
  69. }
  70. }
  71. } else {
  72. result = append(result, msg)
  73. }
  74. }
  75. return result
  76. }