Browse Source

chore: cleanup

adamdottv 9 months ago
parent
commit
0c21ca5318

+ 74 - 8
internal/llm/agent/agent.go

@@ -47,7 +47,9 @@ type Service interface {
 	IsSessionBusy(sessionID string) bool
 	IsBusy() bool
 	Update(agentName config.AgentName, modelID models.ModelID) (models.Model, error)
-	CompactSession(ctx context.Context, sessionID string) error
+	CompactSession(ctx context.Context, sessionID string, force bool) error
+	GetUsage(ctx context.Context, sessionID string) (*int64, error)
+	EstimateContextWindowUsage(ctx context.Context, sessionID string) (float64, bool, error)
 }
 
 type agent struct {
@@ -194,17 +196,16 @@ func (a *agent) Run(ctx context.Context, sessionID string, content string, attac
 		events <- result
 		close(events)
 	}()
+
 	return events, nil
 }
 
 func (a *agent) prepareMessageHistory(ctx context.Context, sessionID string) (session.Session, []message.Message, error) {
-	// Get the current session to check for summary
 	currentSession, err := a.sessions.Get(ctx, sessionID)
 	if err != nil {
 		return currentSession, nil, fmt.Errorf("failed to get session: %w", err)
 	}
 
-	// Fetch messages based on whether a summary exists
 	var sessionMessages []message.Message
 	if currentSession.Summary != "" && currentSession.SummarizedAt > 0 {
 		// If summary exists, only fetch messages after the summarization timestamp
@@ -220,7 +221,6 @@ func (a *agent) prepareMessageHistory(ctx context.Context, sessionID string) (se
 		}
 	}
 
-	// Prepare the message history for the LLM
 	var messages []message.Message
 	if currentSession.Summary != "" && currentSession.SummarizedAt > 0 {
 		// If summary exists, create a temporary message for the summary
@@ -253,7 +253,6 @@ func (a *agent) triggerTitleGeneration(sessionID string, content string) {
 }
 
 func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
-	// Get message history and session info
 	currentSession, sessionMessages, err := a.prepareMessageHistory(ctx, sessionID)
 	if err != nil {
 		return a.err(err)
@@ -269,7 +268,6 @@ func (a *agent) processGeneration(ctx context.Context, sessionID, content string
 		return a.err(fmt.Errorf("failed to create user message: %w", err))
 	}
 
-	// Append the new user message to the conversation history
 	messages := append(sessionMessages, userMsg)
 
 	for {
@@ -280,6 +278,41 @@ func (a *agent) processGeneration(ctx context.Context, sessionID, content string
 		default:
 			// Continue processing
 		}
+
+		// Check if auto-compaction is needed before calling the provider
+		usagePercentage, needsCompaction, errEstimate := a.EstimateContextWindowUsage(ctx, sessionID)
+		if errEstimate != nil {
+			slog.Warn("Failed to estimate context window usage for auto-compaction", "error", errEstimate, "sessionID", sessionID)
+		} else if needsCompaction {
+			status.Info(fmt.Sprintf("Context window usage is at %.2f%%. Auto-compacting conversation...", usagePercentage))
+
+			// Run compaction synchronously
+			compactCtx, cancelCompact := context.WithTimeout(ctx, 30*time.Second) // Use appropriate context
+			errCompact := a.CompactSession(compactCtx, sessionID, true)
+			cancelCompact()
+
+			if errCompact != nil {
+				status.Warn(fmt.Sprintf("Auto-compaction failed: %v. Context window usage may continue to grow.", errCompact))
+			} else {
+				status.Info("Auto-compaction completed successfully.")
+				// After compaction, message history needs to be re-prepared.
+				// The 'messages' slice needs to be updated with the new summary and subsequent messages,
+				// ensuring the latest user message is correctly appended.
+				_, sessionMessagesFromCompact, errPrepare := a.prepareMessageHistory(ctx, sessionID)
+				if errPrepare != nil {
+					return a.err(fmt.Errorf("failed to re-prepare message history after compaction: %w", errPrepare))
+				}
+				messages = sessionMessagesFromCompact
+
+				// Ensure the user message that triggered this cycle is the last one.
+				// 'userMsg' was created before this loop using a.createUserMessage.
+				// It should be appended to the 'messages' slice if it's not already the last element.
+				if len(messages) == 0 || (len(messages) > 0 && messages[len(messages)-1].ID != userMsg.ID) {
+					messages = append(messages, userMsg)
+				}
+			}
+		}
+
 		agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, messages)
 		if err != nil {
 			if errors.Is(err, context.Canceled) {
@@ -531,6 +564,39 @@ func (a *agent) GetUsage(ctx context.Context, sessionID string) (*int64, error)
 	return &usage, nil
 }
 
+func (a *agent) EstimateContextWindowUsage(ctx context.Context, sessionID string) (float64, bool, error) { // returns (usage as a percentage, needsCompaction, error)
+	session, err := a.sessions.Get(ctx, sessionID)
+	if err != nil {
+		return 0, false, fmt.Errorf("failed to get session: %w", err)
+	}
+
+	// Context window size as reported by the provider's current model
+	model := a.provider.Model()
+	contextWindow := model.ContextWindow
+	if contextWindow <= 0 {
+		// Fall back to 100000 tokens when the model reports a zero or negative window
+		contextWindow = 100000
+	}
+
+	// Usage is the sum of the session's recorded prompt and completion token counters
+	currentTokens := session.PromptTokens + session.CompletionTokens
+
+	// Provider's MaxTokens setting (presumably the response-token budget; TODO confirm)
+	maxTokens := a.provider.MaxTokens()
+
+	// Despite the name, this is a fraction (1.0 == full window); it is scaled to a percentage on return
+	usagePercentage := float64(currentTokens) / float64(contextWindow)
+
+	// Decide whether auto-compaction is due.
+	// Auto-compact when:
+	// 1. Usage is at least 90% of the context window, OR
+	// 2. Current usage + maxTokens would exceed 100% of the context window
+	needsCompaction := usagePercentage >= 0.9 ||
+		float64(currentTokens+maxTokens) > float64(contextWindow)
+
+	return usagePercentage * 100, needsCompaction, nil
+}
+
 func (a *agent) TrackUsage(ctx context.Context, sessionID string, model models.Model, usage provider.TokenUsage) error {
 	sess, err := a.sessions.Get(ctx, sessionID)
 	if err != nil {
@@ -572,9 +638,9 @@ func (a *agent) Update(agentName config.AgentName, modelID models.ModelID) (mode
 	return a.provider.Model(), nil
 }
 
-func (a *agent) CompactSession(ctx context.Context, sessionID string) error {
+func (a *agent) CompactSession(ctx context.Context, sessionID string, force bool) error {
 	// Check if the session is busy
-	if a.IsSessionBusy(sessionID) {
+	if a.IsSessionBusy(sessionID) && !force {
 		return ErrSessionBusy
 	}
 

+ 2 - 2
internal/logging/logging.go

@@ -95,11 +95,11 @@ func (s *service) Create(ctx context.Context, log Log) error {
 	err := s.db.CreateLog(ctx, db.CreateLogParams{
 		ID:         log.ID,
 		SessionID:  sql.NullString{String: log.SessionID, Valid: log.SessionID != ""},
-		Timestamp:  log.Timestamp / 1000,
+		Timestamp:  log.Timestamp,
 		Level:      log.Level,
 		Message:    log.Message,
 		Attributes: attributesJSON,
-		CreatedAt:  log.CreatedAt / 1000,
+		CreatedAt:  log.CreatedAt,
 	})
 	if err != nil {
 		return fmt.Errorf("db.CreateLog: %w", err)

+ 2 - 4
internal/message/message.go

@@ -128,7 +128,7 @@ func (s *service) Update(ctx context.Context, message Message) (Message, error)
 	finishPart := message.FinishPart()
 	if finishPart != nil && finishPart.Time > 0 {
 		dbFinishedAt = sql.NullInt64{
-			Int64: finishPart.Time / 1000, // Convert Milliseconds from Go struct to Seconds for DB
+			Int64: finishPart.Time,
 			Valid: true,
 		}
 	}
@@ -193,11 +193,9 @@ func (s *service) ListAfter(ctx context.Context, sessionID string, timestampMill
 	s.mu.RLock()
 	defer s.mu.RUnlock()
 
-	timestampSeconds := timestampMillis / 1000 // Convert to seconds for DB query
-
 	dbMessages, err := s.db.ListMessagesBySessionAfter(ctx, db.ListMessagesBySessionAfterParams{
 		SessionID: sessionID,
-		CreatedAt: timestampSeconds,
+		CreatedAt: timestampMillis,
 	})
 	if err != nil {
 		return nil, fmt.Errorf("db.ListMessagesBySessionAfter: %w", err)

+ 0 - 12
internal/tui/components/chat/message.go

@@ -672,15 +672,3 @@ func renderToolMessage(
 	}
 	return toolMsg
 }
-
-// formatTimestampDiff renders the gap between two millisecond timestamps as "<n>ms" (under 1s), "<x.x>s" (under 60s), or "<x.x>m" otherwise.
-func formatTimestampDiff(start, end int64) string {
-	diffSeconds := float64(end-start) / 1000.0 // Convert to seconds
-	if diffSeconds < 1 {
-		return fmt.Sprintf("%dms", int(diffSeconds*1000))
-	}
-	if diffSeconds < 60 {
-		return fmt.Sprintf("%.1fs", diffSeconds)
-	}
-	return fmt.Sprintf("%.1fm", diffSeconds/60)
-}

+ 1 - 1
internal/tui/page/chat.go

@@ -92,7 +92,7 @@ func (p *chatPage) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
 
 		// Run compaction in background
 		go func(sessionID string) {
-			err := p.app.PrimaryAgent.CompactSession(context.Background(), sessionID)
+			err := p.app.PrimaryAgent.CompactSession(context.Background(), sessionID, false)
 			if err != nil {
 				status.Error(fmt.Sprintf("Compaction failed: %v", err))
 			} else {