Browse Source

chore: refactoring

adamdottv 9 months ago
parent
commit
1f9610e266
2 changed files with 23 additions and 14 deletions
  1. 6 7
      internal/llm/agent/agent.go
  2. 17 7
      internal/session/session.go

+ 6 - 7
internal/llm/agent/agent.go

@@ -207,9 +207,9 @@ func (a *agent) prepareMessageHistory(ctx context.Context, sessionID string) (se
 	}
 
 	var sessionMessages []message.Message
-	if currentSession.Summary != "" && currentSession.SummarizedAt > 0 {
+	if currentSession.Summary != "" && !currentSession.SummarizedAt.IsZero() {
 		// If summary exists, only fetch messages after the summarization timestamp
-		sessionMessages, err = a.messages.ListAfter(ctx, sessionID, currentSession.SummarizedAt)
+		sessionMessages, err = a.messages.ListAfter(ctx, sessionID, currentSession.SummarizedAt.UnixMilli())
 		if err != nil {
 			return currentSession, nil, fmt.Errorf("failed to list messages after summary: %w", err)
 		}
@@ -222,7 +222,7 @@ func (a *agent) prepareMessageHistory(ctx context.Context, sessionID string) (se
 	}
 
 	var messages []message.Message
-	if currentSession.Summary != "" && currentSession.SummarizedAt > 0 {
+	if currentSession.Summary != "" && !currentSession.SummarizedAt.IsZero() {
 		// If summary exists, create a temporary message for the summary
 		summaryMessage := message.Message{
 			Role: message.Assistant,
@@ -666,11 +666,11 @@ func (a *agent) CompactSession(ctx context.Context, sessionID string, force bool
 	}
 
 	var existingSummary string
-	if session.Summary != "" && session.SummarizedAt > 0 {
+	if session.Summary != "" && !session.SummarizedAt.IsZero() {
 		// Filter messages that were created after the last summarization
 		var newMessages []message.Message
 		for _, msg := range sessionMessages {
-			if msg.CreatedAt > session.SummarizedAt {
+			if msg.CreatedAt > session.SummarizedAt.UnixMilli() {
 				newMessages = append(newMessages, msg)
 			}
 		}
@@ -741,9 +741,8 @@ Your summary should be comprehensive enough to provide context but concise enoug
 	}
 
 	// Update the session with the new summary
-	currentTime := time.Now().UnixMilli()
 	session.Summary = summaryText
-	session.SummarizedAt = currentTime
+	session.SummarizedAt = time.Now()
 
 	// Save the updated session
 	_, err = a.sessions.Update(ctx, session)

+ 17 - 7
internal/session/session.go

@@ -21,9 +21,9 @@ type Session struct {
 	CompletionTokens int64
 	Cost             float64
 	Summary          string
-	SummarizedAt     int64
-	CreatedAt        int64
-	UpdatedAt        int64
+	SummarizedAt     time.Time
+	CreatedAt        time.Time
+	UpdatedAt        time.Time
 }
 
 const (
@@ -153,6 +153,11 @@ func (s *service) Update(ctx context.Context, session Session) (Session, error)
 	if session.ID == "" {
 		return Session{}, fmt.Errorf("cannot update session with empty ID")
 	}
+	var summarizedAt sql.NullInt64
+	if !session.SummarizedAt.IsZero() {
+		summarizedAt = sql.NullInt64{Int64: session.SummarizedAt.UnixMilli(), Valid: true}
+	}
+	
 	params := db.UpdateSessionParams{
 		ID:               session.ID,
 		Title:            session.Title,
@@ -160,7 +165,7 @@ func (s *service) Update(ctx context.Context, session Session) (Session, error)
 		CompletionTokens: session.CompletionTokens,
 		Cost:             session.Cost,
 		Summary:          sql.NullString{String: session.Summary, Valid: session.Summary != ""},
-		SummarizedAt:     sql.NullInt64{Int64: session.SummarizedAt, Valid: session.SummarizedAt > 0},
+		SummarizedAt:     summarizedAt,
 	}
 	dbSession, err := s.db.UpdateSession(ctx, params)
 	if err != nil {
@@ -199,6 +204,11 @@ func (s *service) Subscribe(ctx context.Context) <-chan pubsub.Event[Session] {
 }
 
 func (s *service) fromDBItem(item db.Session) Session {
+	var summarizedAt time.Time
+	if item.SummarizedAt.Valid {
+		summarizedAt = time.UnixMilli(item.SummarizedAt.Int64)
+	}
+	
 	return Session{
 		ID:               item.ID,
 		ParentSessionID:  item.ParentSessionID.String,
@@ -208,9 +218,9 @@ func (s *service) fromDBItem(item db.Session) Session {
 		CompletionTokens: item.CompletionTokens,
 		Cost:             item.Cost,
 		Summary:          item.Summary.String,
-		SummarizedAt:     item.SummarizedAt.Int64,
-		CreatedAt:        item.CreatedAt * 1000,
-		UpdatedAt:        item.UpdatedAt * 1000,
+		SummarizedAt:     summarizedAt,
+		CreatedAt:        time.UnixMilli(item.CreatedAt * 1000),
+		UpdatedAt:        time.UnixMilli(item.UpdatedAt * 1000),
 	}
 }