broker_test.go 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144
  1. package pubsub
  2. import (
  3. "context"
  4. "sync"
  5. "testing"
  6. "time"
  7. "github.com/stretchr/testify/assert"
  8. )
  9. func TestBrokerSubscribe(t *testing.T) {
  10. t.Parallel()
  11. t.Run("with cancellable context", func(t *testing.T) {
  12. t.Parallel()
  13. broker := NewBroker[string]()
  14. ctx, cancel := context.WithCancel(context.Background())
  15. defer cancel()
  16. ch := broker.Subscribe(ctx)
  17. assert.NotNil(t, ch)
  18. assert.Equal(t, 1, broker.GetSubscriberCount())
  19. // Cancel the context should remove the subscription
  20. cancel()
  21. time.Sleep(10 * time.Millisecond) // Give time for goroutine to process
  22. assert.Equal(t, 0, broker.GetSubscriberCount())
  23. })
  24. t.Run("with background context", func(t *testing.T) {
  25. t.Parallel()
  26. broker := NewBroker[string]()
  27. // Using context.Background() should not leak goroutines
  28. ch := broker.Subscribe(context.Background())
  29. assert.NotNil(t, ch)
  30. assert.Equal(t, 1, broker.GetSubscriberCount())
  31. // Shutdown should clean up all subscriptions
  32. broker.Shutdown()
  33. assert.Equal(t, 0, broker.GetSubscriberCount())
  34. })
  35. }
  36. func TestBrokerPublish(t *testing.T) {
  37. t.Parallel()
  38. broker := NewBroker[string]()
  39. ctx := t.Context()
  40. ch := broker.Subscribe(ctx)
  41. // Publish a message
  42. broker.Publish(EventTypeCreated, "test message")
  43. // Verify message is received
  44. select {
  45. case event := <-ch:
  46. assert.Equal(t, EventTypeCreated, event.Type)
  47. assert.Equal(t, "test message", event.Payload)
  48. case <-time.After(100 * time.Millisecond):
  49. t.Fatal("timeout waiting for message")
  50. }
  51. }
  52. func TestBrokerShutdown(t *testing.T) {
  53. t.Parallel()
  54. broker := NewBroker[string]()
  55. // Create multiple subscribers
  56. ch1 := broker.Subscribe(context.Background())
  57. ch2 := broker.Subscribe(context.Background())
  58. assert.Equal(t, 2, broker.GetSubscriberCount())
  59. // Shutdown should close all channels and clean up
  60. broker.Shutdown()
  61. // Verify channels are closed
  62. _, ok1 := <-ch1
  63. _, ok2 := <-ch2
  64. assert.False(t, ok1, "channel 1 should be closed")
  65. assert.False(t, ok2, "channel 2 should be closed")
  66. // Verify subscriber count is reset
  67. assert.Equal(t, 0, broker.GetSubscriberCount())
  68. }
  69. func TestBrokerConcurrency(t *testing.T) {
  70. t.Parallel()
  71. broker := NewBroker[int]()
  72. // Create a large number of subscribers
  73. const numSubscribers = 100
  74. var wg sync.WaitGroup
  75. wg.Add(numSubscribers)
  76. // Create a channel to collect received events
  77. receivedEvents := make(chan int, numSubscribers)
  78. for i := range numSubscribers {
  79. go func(id int) {
  80. defer wg.Done()
  81. ctx, cancel := context.WithCancel(context.Background())
  82. defer cancel()
  83. ch := broker.Subscribe(ctx)
  84. // Receive one message then cancel
  85. select {
  86. case event := <-ch:
  87. receivedEvents <- event.Payload
  88. case <-time.After(1 * time.Second):
  89. t.Errorf("timeout waiting for message %d", id)
  90. }
  91. cancel()
  92. }(i)
  93. }
  94. // Give subscribers time to set up
  95. time.Sleep(10 * time.Millisecond)
  96. // Publish messages to all subscribers
  97. for i := range numSubscribers {
  98. broker.Publish(EventTypeCreated, i)
  99. }
  100. // Wait for all subscribers to finish
  101. wg.Wait()
  102. close(receivedEvents)
  103. // Give time for cleanup goroutines to run
  104. time.Sleep(10 * time.Millisecond)
  105. // Verify all subscribers are cleaned up
  106. assert.Equal(t, 0, broker.GetSubscriberCount())
  107. // Verify we received the expected number of events
  108. count := 0
  109. for range receivedEvents {
  110. count++
  111. }
  112. assert.Equal(t, numSubscribers, count)
  113. }