|
@@ -11,6 +11,7 @@ import (
|
|
|
"github.com/cloudwego/eino/schema"
|
|
"github.com/cloudwego/eino/schema"
|
|
|
"github.com/google/uuid"
|
|
"github.com/google/uuid"
|
|
|
"github.com/kujtimiihoxha/termai/internal/llm/agent"
|
|
"github.com/kujtimiihoxha/termai/internal/llm/agent"
|
|
|
|
|
+ "github.com/kujtimiihoxha/termai/internal/llm/models"
|
|
|
"github.com/kujtimiihoxha/termai/internal/logging"
|
|
"github.com/kujtimiihoxha/termai/internal/logging"
|
|
|
"github.com/kujtimiihoxha/termai/internal/message"
|
|
"github.com/kujtimiihoxha/termai/internal/message"
|
|
|
"github.com/kujtimiihoxha/termai/internal/pubsub"
|
|
"github.com/kujtimiihoxha/termai/internal/pubsub"
|
|
@@ -88,7 +89,7 @@ func (s *service) handleRequest(id string, sessionID string, content string) {
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
log.Printf("Request: %s", content)
|
|
log.Printf("Request: %s", content)
|
|
|
- agent, systemMessage, err := agent.GetAgent(s.ctx, viper.GetString("agents.default"))
|
|
|
|
|
|
|
+ currentAgent, systemMessage, err := agent.GetAgent(s.ctx, viper.GetString("agents.default"))
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
s.Publish(AgentErrorEvent, AgentEvent{
|
|
s.Publish(AgentErrorEvent, AgentEvent{
|
|
|
ID: id,
|
|
ID: id,
|
|
@@ -110,6 +111,7 @@ func (s *service) handleRequest(id string, sessionID string, content string) {
|
|
|
for _, m := range history {
|
|
for _, m := range history {
|
|
|
messages = append(messages, &m.MessageData)
|
|
messages = append(messages, &m.MessageData)
|
|
|
}
|
|
}
|
|
|
|
|
+
|
|
|
builder := callbacks.NewHandlerBuilder()
|
|
builder := callbacks.NewHandlerBuilder()
|
|
|
builder.OnStartFn(func(ctx context.Context, info *callbacks.RunInfo, input callbacks.CallbackInput) context.Context {
|
|
builder.OnStartFn(func(ctx context.Context, info *callbacks.RunInfo, input callbacks.CallbackInput) context.Context {
|
|
|
i, ok := input.(*eModel.CallbackInput)
|
|
i, ok := input.(*eModel.CallbackInput)
|
|
@@ -140,7 +142,7 @@ func (s *service) handleRequest(id string, sessionID string, content string) {
|
|
|
return ctx
|
|
return ctx
|
|
|
})
|
|
})
|
|
|
|
|
|
|
|
- out, err := agent.Generate(s.ctx, messages, enioAgent.WithComposeOptions(compose.WithCallbacks(builder.Build())))
|
|
|
|
|
|
|
+ out, err := currentAgent.Generate(s.ctx, messages, enioAgent.WithComposeOptions(compose.WithCallbacks(builder.Build())))
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
s.Publish(AgentErrorEvent, AgentEvent{
|
|
s.Publish(AgentErrorEvent, AgentEvent{
|
|
|
ID: id,
|
|
ID: id,
|
|
@@ -153,6 +155,7 @@ func (s *service) handleRequest(id string, sessionID string, content string) {
|
|
|
return
|
|
return
|
|
|
}
|
|
}
|
|
|
usage := out.ResponseMeta.Usage
|
|
usage := out.ResponseMeta.Usage
|
|
|
|
|
+ s.messages.Create(sessionID, *out)
|
|
|
if usage != nil {
|
|
if usage != nil {
|
|
|
log.Printf("Prompt Tokens: %d, Completion Tokens: %d, Total Tokens: %d", usage.PromptTokens, usage.CompletionTokens, usage.TotalTokens)
|
|
log.Printf("Prompt Tokens: %d, Completion Tokens: %d, Total Tokens: %d", usage.PromptTokens, usage.CompletionTokens, usage.TotalTokens)
|
|
|
session, err := s.sessions.Get(sessionID)
|
|
session, err := s.sessions.Get(sessionID)
|
|
@@ -170,6 +173,29 @@ func (s *service) handleRequest(id string, sessionID string, content string) {
|
|
|
session.PromptTokens += int64(usage.PromptTokens)
|
|
session.PromptTokens += int64(usage.PromptTokens)
|
|
|
session.CompletionTokens += int64(usage.CompletionTokens)
|
|
session.CompletionTokens += int64(usage.CompletionTokens)
|
|
|
// TODO: calculate cost
|
|
// TODO: calculate cost
|
|
|
|
|
+ model := models.SupportedModels[models.ModelID(viper.GetString("models.big"))]
|
|
|
|
|
+ session.Cost += float64(usage.PromptTokens)*(model.CostPer1MIn/1_000_000) +
|
|
|
|
|
+ float64(usage.CompletionTokens)*(model.CostPer1MOut/1_000_000)
|
|
|
|
|
+ var newTitle string
|
|
|
|
|
+ if len(history) == 1 {
|
|
|
|
|
+ // first message generate the title
|
|
|
|
|
+ newTitle, err = agent.GenerateTitle(s.ctx, content)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ s.Publish(AgentErrorEvent, AgentEvent{
|
|
|
|
|
+ ID: id,
|
|
|
|
|
+ Type: AgentMessageTypeError,
|
|
|
|
|
+ AgentID: RootAgent,
|
|
|
|
|
+ MessageID: "",
|
|
|
|
|
+ SessionID: sessionID,
|
|
|
|
|
+ Content: err.Error(),
|
|
|
|
|
+ })
|
|
|
|
|
+ return
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+ if newTitle != "" {
|
|
|
|
|
+ session.Title = newTitle
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
_, err = s.sessions.Save(session)
|
|
_, err = s.sessions.Save(session)
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
s.Publish(AgentErrorEvent, AgentEvent{
|
|
s.Publish(AgentErrorEvent, AgentEvent{
|
|
@@ -183,7 +209,6 @@ func (s *service) handleRequest(id string, sessionID string, content string) {
|
|
|
return
|
|
return
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
- s.messages.Create(sessionID, *out)
|
|
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
func (s *service) SendRequest(sessionID string, content string) {
|
|
func (s *service) SendRequest(sessionID string, content string) {
|