|
@@ -296,36 +296,19 @@ func (a *agent) createUserMessage(ctx context.Context, sessionID, content string
|
|
|
})
|
|
})
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-// estimateTokens provides a rough estimate of token count based on character count
|
|
|
|
|
-// using a simple heuristic of ~4 characters per token
|
|
|
|
|
-func estimateTokens(messages []message.Message) int64 {
|
|
|
|
|
- totalChars := 0
|
|
|
|
|
- for _, msg := range messages {
|
|
|
|
|
- // Get text content from all parts
|
|
|
|
|
- for _, part := range msg.Parts {
|
|
|
|
|
- if textContent, ok := part.(message.TextContent); ok {
|
|
|
|
|
- totalChars += len(textContent.Text)
|
|
|
|
|
- } else {
|
|
|
|
|
- // For non-text parts, add a conservative estimate
|
|
|
|
|
- totalChars += 100
|
|
|
|
|
- }
|
|
|
|
|
- }
|
|
|
|
|
- // Add chars for role (conservative estimate)
|
|
|
|
|
- totalChars += 10
|
|
|
|
|
- }
|
|
|
|
|
- // Heuristic: ~4 chars per token
|
|
|
|
|
- return int64(totalChars / 4)
|
|
|
|
|
-}
|
|
|
|
|
-
|
|
|
|
|
func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
|
|
func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
|
|
|
// Check if we need to auto-compact based on token count
|
|
// Check if we need to auto-compact based on token count
|
|
|
contextWindow := a.provider.Model().ContextWindow
|
|
contextWindow := a.provider.Model().ContextWindow
|
|
|
|
|
+ maxTokens := a.provider.MaxTokens()
|
|
|
threshold := int64(float64(contextWindow) * 0.80)
|
|
threshold := int64(float64(contextWindow) * 0.80)
|
|
|
- estimatedTokens := estimateTokens(msgHistory)
|
|
|
|
|
|
|
+ usage, err := a.GetUsage(ctx, sessionID)
|
|
|
|
|
+ if err != nil || usage == nil {
|
|
|
|
|
+ return message.Message{}, nil, fmt.Errorf("failed to get usage: %w", err)
|
|
|
|
|
+ }
|
|
|
|
|
|
|
|
// If we're approaching the context window limit, trigger auto-compaction
|
|
// If we're approaching the context window limit, trigger auto-compaction
|
|
|
- if estimatedTokens >= threshold {
|
|
|
|
|
- logging.InfoPersist(fmt.Sprintf("Auto-compaction triggered for session %s. Estimated tokens: %d, Threshold: %d", sessionID, estimatedTokens, threshold))
|
|
|
|
|
|
|
+ if (*usage + maxTokens) >= threshold {
|
|
|
|
|
+ logging.InfoPersist(fmt.Sprintf("Auto-compaction triggered for session %s. Estimated tokens: %d, Threshold: %d", sessionID, usage, threshold))
|
|
|
|
|
|
|
|
// Perform compaction with pause/resume to ensure safety
|
|
// Perform compaction with pause/resume to ensure safety
|
|
|
if err := a.CompactSession(ctx, sessionID); err != nil {
|
|
if err := a.CompactSession(ctx, sessionID); err != nil {
|
|
@@ -357,11 +340,6 @@ func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msg
|
|
|
|
|
|
|
|
// Replace msgHistory with the new compacted version
|
|
// Replace msgHistory with the new compacted version
|
|
|
msgHistory = append([]message.Message{summaryMessage}, sessionMessages...)
|
|
msgHistory = append([]message.Message{summaryMessage}, sessionMessages...)
|
|
|
-
|
|
|
|
|
- // Log the new token estimate after compaction
|
|
|
|
|
- newEstimate := estimateTokens(msgHistory)
|
|
|
|
|
- logging.InfoPersist(fmt.Sprintf("After compaction: Estimated tokens: %d (reduced by %d)",
|
|
|
|
|
- newEstimate, estimatedTokens-newEstimate))
|
|
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
@@ -534,6 +512,16 @@ func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg
|
|
|
return nil
|
|
return nil
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+func (a *agent) GetUsage(ctx context.Context, sessionID string) (*int64, error) {
|
|
|
|
|
+ session, err := a.sessions.Get(ctx, sessionID)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ return nil, fmt.Errorf("failed to get session: %w", err)
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ usage := session.PromptTokens + session.CompletionTokens
|
|
|
|
|
+ return &usage, nil
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
func (a *agent) TrackUsage(ctx context.Context, sessionID string, model models.Model, usage provider.TokenUsage) error {
|
|
func (a *agent) TrackUsage(ctx context.Context, sessionID string, model models.Model, usage provider.TokenUsage) error {
|
|
|
sess, err := a.sessions.Get(ctx, sessionID)
|
|
sess, err := a.sessions.Get(ctx, sessionID)
|
|
|
if err != nil {
|
|
if err != nil {
|