resolve_session_test.go 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183
  1. package app
  2. import (
  3. "context"
  4. "database/sql"
  5. "fmt"
  6. "strings"
  7. "testing"
  8. "github.com/charmbracelet/crush/internal/config"
  9. "github.com/charmbracelet/crush/internal/pubsub"
  10. "github.com/charmbracelet/crush/internal/session"
  11. "github.com/stretchr/testify/require"
  12. )
  13. // mockSessionService is a minimal mock of session.Service for testing resolveSession.
  14. type mockSessionService struct {
  15. sessions []session.Session
  16. created []session.Session
  17. }
  18. func (m *mockSessionService) Subscribe(context.Context) <-chan pubsub.Event[session.Session] {
  19. return make(chan pubsub.Event[session.Session])
  20. }
  21. func (m *mockSessionService) Create(_ context.Context, title string) (session.Session, error) {
  22. s := session.Session{ID: "new-session-id", Title: title}
  23. m.created = append(m.created, s)
  24. return s, nil
  25. }
  26. func (m *mockSessionService) CreateTitleSession(context.Context, string) (session.Session, error) {
  27. return session.Session{}, nil
  28. }
  29. func (m *mockSessionService) CreateTaskSession(context.Context, string, string, string) (session.Session, error) {
  30. return session.Session{}, nil
  31. }
  32. func (m *mockSessionService) Get(_ context.Context, id string) (session.Session, error) {
  33. for _, s := range m.sessions {
  34. if s.ID == id {
  35. return s, nil
  36. }
  37. }
  38. return session.Session{}, sql.ErrNoRows
  39. }
  40. func (m *mockSessionService) GetLast(_ context.Context) (session.Session, error) {
  41. if len(m.sessions) > 0 {
  42. return m.sessions[0], nil
  43. }
  44. return session.Session{}, sql.ErrNoRows
  45. }
  46. func (m *mockSessionService) List(context.Context) ([]session.Session, error) {
  47. return m.sessions, nil
  48. }
  49. func (m *mockSessionService) Save(_ context.Context, s session.Session) (session.Session, error) {
  50. return s, nil
  51. }
  52. func (m *mockSessionService) SaveWithModels(_ context.Context, s session.Session, _ map[config.SelectedModelType]config.SelectedModel) (session.Session, error) {
  53. return s, nil
  54. }
  55. func (m *mockSessionService) UpdateSessionModels(context.Context, string, map[config.SelectedModelType]config.SelectedModel) error {
  56. return nil
  57. }
  58. func (m *mockSessionService) UpdateTitleAndUsage(context.Context, string, string, int64, int64, float64) error {
  59. return nil
  60. }
  61. func (m *mockSessionService) Rename(context.Context, string, string) error {
  62. return nil
  63. }
  64. func (m *mockSessionService) Delete(context.Context, string) error {
  65. return nil
  66. }
  67. func (m *mockSessionService) CreateAgentToolSessionID(messageID, toolCallID string) string {
  68. return fmt.Sprintf("%s$$%s", messageID, toolCallID)
  69. }
  70. func (m *mockSessionService) ParseAgentToolSessionID(sessionID string) (string, string, bool) {
  71. parts := strings.Split(sessionID, "$$")
  72. if len(parts) != 2 {
  73. return "", "", false
  74. }
  75. return parts[0], parts[1], true
  76. }
  77. func (m *mockSessionService) IsAgentToolSession(sessionID string) bool {
  78. _, _, ok := m.ParseAgentToolSessionID(sessionID)
  79. return ok
  80. }
  81. func newTestApp(sessions session.Service) *App {
  82. return &App{Sessions: sessions}
  83. }
  84. func TestResolveSession_NewSession(t *testing.T) {
  85. mock := &mockSessionService{}
  86. app := newTestApp(mock)
  87. sess, err := app.resolveSession(t.Context(), "", false)
  88. require.NoError(t, err)
  89. require.Equal(t, "new-session-id", sess.ID)
  90. require.Len(t, mock.created, 1)
  91. }
  92. func TestResolveSession_ContinueByID(t *testing.T) {
  93. mock := &mockSessionService{
  94. sessions: []session.Session{
  95. {ID: "existing-id", Title: "Old session"},
  96. },
  97. }
  98. app := newTestApp(mock)
  99. sess, err := app.resolveSession(t.Context(), "existing-id", false)
  100. require.NoError(t, err)
  101. require.Equal(t, "existing-id", sess.ID)
  102. require.Equal(t, "Old session", sess.Title)
  103. require.Empty(t, mock.created)
  104. }
  105. func TestResolveSession_ContinueByID_NotFound(t *testing.T) {
  106. mock := &mockSessionService{}
  107. app := newTestApp(mock)
  108. _, err := app.resolveSession(t.Context(), "nonexistent", false)
  109. require.Error(t, err)
  110. require.Contains(t, err.Error(), "session not found")
  111. }
  112. func TestResolveSession_ContinueByID_ChildSession(t *testing.T) {
  113. mock := &mockSessionService{
  114. sessions: []session.Session{
  115. {ID: "child-id", ParentSessionID: "parent-id", Title: "Child session"},
  116. },
  117. }
  118. app := newTestApp(mock)
  119. _, err := app.resolveSession(t.Context(), "child-id", false)
  120. require.Error(t, err)
  121. require.Contains(t, err.Error(), "cannot continue a child session")
  122. }
  123. func TestResolveSession_ContinueByID_AgentToolSession(t *testing.T) {
  124. mock := &mockSessionService{}
  125. app := newTestApp(mock)
  126. _, err := app.resolveSession(t.Context(), "msg123$$tool456", false)
  127. require.Error(t, err)
  128. require.Contains(t, err.Error(), "cannot continue an agent tool session")
  129. }
  130. func TestResolveSession_Last(t *testing.T) {
  131. mock := &mockSessionService{
  132. sessions: []session.Session{
  133. {ID: "most-recent", Title: "Latest session"},
  134. {ID: "older", Title: "Older session"},
  135. },
  136. }
  137. app := newTestApp(mock)
  138. sess, err := app.resolveSession(t.Context(), "", true)
  139. require.NoError(t, err)
  140. require.Equal(t, "most-recent", sess.ID)
  141. require.Empty(t, mock.created)
  142. }
  143. func TestResolveSession_Last_NoSessions(t *testing.T) {
  144. mock := &mockSessionService{}
  145. app := newTestApp(mock)
  146. _, err := app.resolveSession(t.Context(), "", true)
  147. require.Error(t, err)
  148. require.Contains(t, err.Error(), "no sessions found")
  149. }