2
0

resolve_session_test.go 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174
  1. package app
  2. import (
  3. "context"
  4. "database/sql"
  5. "fmt"
  6. "strings"
  7. "testing"
  8. "github.com/charmbracelet/crush/internal/pubsub"
  9. "github.com/charmbracelet/crush/internal/session"
  10. "github.com/stretchr/testify/require"
  11. )
  12. // mockSessionService is a minimal mock of session.Service for testing resolveSession.
  13. type mockSessionService struct {
  14. sessions []session.Session
  15. created []session.Session
  16. }
  17. func (m *mockSessionService) Subscribe(context.Context) <-chan pubsub.Event[session.Session] {
  18. return make(chan pubsub.Event[session.Session])
  19. }
  20. func (m *mockSessionService) Create(_ context.Context, title string) (session.Session, error) {
  21. s := session.Session{ID: "new-session-id", Title: title}
  22. m.created = append(m.created, s)
  23. return s, nil
  24. }
  25. func (m *mockSessionService) CreateTitleSession(context.Context, string) (session.Session, error) {
  26. return session.Session{}, nil
  27. }
  28. func (m *mockSessionService) CreateTaskSession(context.Context, string, string, string) (session.Session, error) {
  29. return session.Session{}, nil
  30. }
  31. func (m *mockSessionService) Get(_ context.Context, id string) (session.Session, error) {
  32. for _, s := range m.sessions {
  33. if s.ID == id {
  34. return s, nil
  35. }
  36. }
  37. return session.Session{}, sql.ErrNoRows
  38. }
  39. func (m *mockSessionService) GetLast(_ context.Context) (session.Session, error) {
  40. if len(m.sessions) > 0 {
  41. return m.sessions[0], nil
  42. }
  43. return session.Session{}, sql.ErrNoRows
  44. }
  45. func (m *mockSessionService) List(context.Context) ([]session.Session, error) {
  46. return m.sessions, nil
  47. }
  48. func (m *mockSessionService) Save(_ context.Context, s session.Session) (session.Session, error) {
  49. return s, nil
  50. }
  51. func (m *mockSessionService) UpdateTitleAndUsage(context.Context, string, string, int64, int64, float64) error {
  52. return nil
  53. }
  54. func (m *mockSessionService) Rename(context.Context, string, string) error {
  55. return nil
  56. }
  57. func (m *mockSessionService) Delete(context.Context, string) error {
  58. return nil
  59. }
  60. func (m *mockSessionService) CreateAgentToolSessionID(messageID, toolCallID string) string {
  61. return fmt.Sprintf("%s$$%s", messageID, toolCallID)
  62. }
  63. func (m *mockSessionService) ParseAgentToolSessionID(sessionID string) (string, string, bool) {
  64. parts := strings.Split(sessionID, "$$")
  65. if len(parts) != 2 {
  66. return "", "", false
  67. }
  68. return parts[0], parts[1], true
  69. }
  70. func (m *mockSessionService) IsAgentToolSession(sessionID string) bool {
  71. _, _, ok := m.ParseAgentToolSessionID(sessionID)
  72. return ok
  73. }
  74. func newTestApp(sessions session.Service) *App {
  75. return &App{Sessions: sessions}
  76. }
  77. func TestResolveSession_NewSession(t *testing.T) {
  78. mock := &mockSessionService{}
  79. app := newTestApp(mock)
  80. sess, err := app.resolveSession(t.Context(), "", false)
  81. require.NoError(t, err)
  82. require.Equal(t, "new-session-id", sess.ID)
  83. require.Len(t, mock.created, 1)
  84. }
  85. func TestResolveSession_ContinueByID(t *testing.T) {
  86. mock := &mockSessionService{
  87. sessions: []session.Session{
  88. {ID: "existing-id", Title: "Old session"},
  89. },
  90. }
  91. app := newTestApp(mock)
  92. sess, err := app.resolveSession(t.Context(), "existing-id", false)
  93. require.NoError(t, err)
  94. require.Equal(t, "existing-id", sess.ID)
  95. require.Equal(t, "Old session", sess.Title)
  96. require.Empty(t, mock.created)
  97. }
  98. func TestResolveSession_ContinueByID_NotFound(t *testing.T) {
  99. mock := &mockSessionService{}
  100. app := newTestApp(mock)
  101. _, err := app.resolveSession(t.Context(), "nonexistent", false)
  102. require.Error(t, err)
  103. require.Contains(t, err.Error(), "session not found")
  104. }
  105. func TestResolveSession_ContinueByID_ChildSession(t *testing.T) {
  106. mock := &mockSessionService{
  107. sessions: []session.Session{
  108. {ID: "child-id", ParentSessionID: "parent-id", Title: "Child session"},
  109. },
  110. }
  111. app := newTestApp(mock)
  112. _, err := app.resolveSession(t.Context(), "child-id", false)
  113. require.Error(t, err)
  114. require.Contains(t, err.Error(), "cannot continue a child session")
  115. }
  116. func TestResolveSession_ContinueByID_AgentToolSession(t *testing.T) {
  117. mock := &mockSessionService{}
  118. app := newTestApp(mock)
  119. _, err := app.resolveSession(t.Context(), "msg123$$tool456", false)
  120. require.Error(t, err)
  121. require.Contains(t, err.Error(), "cannot continue an agent tool session")
  122. }
  123. func TestResolveSession_Last(t *testing.T) {
  124. mock := &mockSessionService{
  125. sessions: []session.Session{
  126. {ID: "most-recent", Title: "Latest session"},
  127. {ID: "older", Title: "Older session"},
  128. },
  129. }
  130. app := newTestApp(mock)
  131. sess, err := app.resolveSession(t.Context(), "", true)
  132. require.NoError(t, err)
  133. require.Equal(t, "most-recent", sess.ID)
  134. require.Empty(t, mock.created)
  135. }
  136. func TestResolveSession_Last_NoSessions(t *testing.T) {
  137. mock := &mockSessionService{}
  138. app := newTestApp(mock)
  139. _, err := app.resolveSession(t.Context(), "", true)
  140. require.Error(t, err)
  141. require.Contains(t, err.Error(), "no sessions found")
  142. }