llm.go 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206
  1. package llm
  2. import (
  3. "context"
  4. "log"
  5. "sync"
  6. "time"
  7. "github.com/cloudwego/eino/callbacks"
  8. "github.com/cloudwego/eino/compose"
  9. "github.com/cloudwego/eino/schema"
  10. "github.com/google/uuid"
  11. "github.com/kujtimiihoxha/termai/internal/llm/agent"
  12. "github.com/kujtimiihoxha/termai/internal/logging"
  13. "github.com/kujtimiihoxha/termai/internal/message"
  14. "github.com/kujtimiihoxha/termai/internal/pubsub"
  15. "github.com/kujtimiihoxha/termai/internal/session"
  16. eModel "github.com/cloudwego/eino/components/model"
  17. enioAgent "github.com/cloudwego/eino/flow/agent"
  18. "github.com/spf13/viper"
  19. )
  20. const (
  21. AgentRequestoEvent pubsub.EventType = "agent_request"
  22. AgentErrorEvent pubsub.EventType = "agent_error"
  23. AgentResponseEvent pubsub.EventType = "agent_response"
  24. )
  25. type AgentMessageType int
  26. const (
  27. AgentMessageTypeNewUserMessage AgentMessageType = iota
  28. AgentMessageTypeAgentResponse
  29. AgentMessageTypeError
  30. )
  31. type agentID string
  32. const (
  33. RootAgent agentID = "root"
  34. TaskAgent agentID = "task"
  35. )
  36. type AgentEvent struct {
  37. ID string `json:"id"`
  38. Type AgentMessageType `json:"type"`
  39. AgentID agentID `json:"agent_id"`
  40. MessageID string `json:"message_id"`
  41. SessionID string `json:"session_id"`
  42. Content string `json:"content"`
  43. }
  44. type Service interface {
  45. pubsub.Suscriber[AgentEvent]
  46. SendRequest(sessionID string, content string)
  47. }
  48. type service struct {
  49. *pubsub.Broker[AgentEvent]
  50. Requests sync.Map
  51. ctx context.Context
  52. activeRequests sync.Map
  53. messages message.Service
  54. sessions session.Service
  55. logger logging.Interface
  56. }
  57. func (s *service) handleRequest(id string, sessionID string, content string) {
  58. cancel, ok := s.activeRequests.Load(id)
  59. if !ok {
  60. return
  61. }
  62. defer cancel.(context.CancelFunc)()
  63. defer s.activeRequests.Delete(id)
  64. history, err := s.messages.List(sessionID)
  65. if err != nil {
  66. s.Publish(AgentErrorEvent, AgentEvent{
  67. ID: id,
  68. Type: AgentMessageTypeError,
  69. AgentID: RootAgent,
  70. MessageID: "",
  71. SessionID: sessionID,
  72. Content: err.Error(),
  73. })
  74. return
  75. }
  76. log.Printf("Request: %s", content)
  77. agent, systemMessage, err := agent.GetAgent(s.ctx, viper.GetString("agents.default"))
  78. if err != nil {
  79. s.Publish(AgentErrorEvent, AgentEvent{
  80. ID: id,
  81. Type: AgentMessageTypeError,
  82. AgentID: RootAgent,
  83. MessageID: "",
  84. SessionID: sessionID,
  85. Content: err.Error(),
  86. })
  87. return
  88. }
  89. messages := []*schema.Message{
  90. {
  91. Role: schema.System,
  92. Content: systemMessage,
  93. },
  94. }
  95. for _, m := range history {
  96. messages = append(messages, &m.MessageData)
  97. }
  98. builder := callbacks.NewHandlerBuilder()
  99. builder.OnStartFn(func(ctx context.Context, info *callbacks.RunInfo, input callbacks.CallbackInput) context.Context {
  100. i, ok := input.(*eModel.CallbackInput)
  101. if info.Component == "ChatModel" && ok {
  102. if len(messages) < len(i.Messages) {
  103. // find new messages
  104. newMessages := i.Messages[len(messages):]
  105. for _, m := range newMessages {
  106. _, err = s.messages.Create(sessionID, *m)
  107. if err != nil {
  108. s.Publish(AgentErrorEvent, AgentEvent{
  109. ID: id,
  110. Type: AgentMessageTypeError,
  111. AgentID: RootAgent,
  112. MessageID: "",
  113. SessionID: sessionID,
  114. Content: err.Error(),
  115. })
  116. }
  117. messages = append(messages, m)
  118. }
  119. }
  120. }
  121. return ctx
  122. })
  123. builder.OnEndFn(func(ctx context.Context, info *callbacks.RunInfo, output callbacks.CallbackOutput) context.Context {
  124. return ctx
  125. })
  126. out, err := agent.Generate(s.ctx, messages, enioAgent.WithComposeOptions(compose.WithCallbacks(builder.Build())))
  127. if err != nil {
  128. s.Publish(AgentErrorEvent, AgentEvent{
  129. ID: id,
  130. Type: AgentMessageTypeError,
  131. AgentID: RootAgent,
  132. MessageID: "",
  133. SessionID: sessionID,
  134. Content: err.Error(),
  135. })
  136. return
  137. }
  138. usage := out.ResponseMeta.Usage
  139. if usage != nil {
  140. log.Printf("Prompt Tokens: %d, Completion Tokens: %d, Total Tokens: %d", usage.PromptTokens, usage.CompletionTokens, usage.TotalTokens)
  141. session, err := s.sessions.Get(sessionID)
  142. if err != nil {
  143. s.Publish(AgentErrorEvent, AgentEvent{
  144. ID: id,
  145. Type: AgentMessageTypeError,
  146. AgentID: RootAgent,
  147. MessageID: "",
  148. SessionID: sessionID,
  149. Content: err.Error(),
  150. })
  151. return
  152. }
  153. session.PromptTokens += int64(usage.PromptTokens)
  154. session.CompletionTokens += int64(usage.CompletionTokens)
  155. // TODO: calculate cost
  156. _, err = s.sessions.Save(session)
  157. if err != nil {
  158. s.Publish(AgentErrorEvent, AgentEvent{
  159. ID: id,
  160. Type: AgentMessageTypeError,
  161. AgentID: RootAgent,
  162. MessageID: "",
  163. SessionID: sessionID,
  164. Content: err.Error(),
  165. })
  166. return
  167. }
  168. }
  169. s.messages.Create(sessionID, *out)
  170. }
  171. func (s *service) SendRequest(sessionID string, content string) {
  172. id := uuid.New().String()
  173. _, cancel := context.WithTimeout(s.ctx, 5*time.Minute)
  174. s.activeRequests.Store(id, cancel)
  175. log.Printf("Request: %s", content)
  176. go s.handleRequest(id, sessionID, content)
  177. }
  178. func NewService(ctx context.Context, logger logging.Interface, sessions session.Service, messages message.Service) Service {
  179. return &service{
  180. Broker: pubsub.NewBroker[AgentEvent](),
  181. ctx: ctx,
  182. sessions: sessions,
  183. messages: messages,
  184. logger: logger,
  185. }
  186. }