// context_test.go (5.3 KB)

  1. package tools
  2. import (
  3. "context"
  4. "testing"
  5. )
  6. // Test-specific context key types to avoid collisions
  7. type (
  8. testStringKey string
  9. testBoolKey string
  10. testIntKey string
  11. )
  12. const (
  13. testKey testStringKey = "testKey"
  14. missingKey testStringKey = "missingKey"
  15. boolTestKey testBoolKey = "boolKey"
  16. intTestKey testIntKey = "intKey"
  17. )
  18. func TestGetContextValue(t *testing.T) {
  19. tests := []struct {
  20. name string
  21. setup func(ctx context.Context) context.Context
  22. key any
  23. defaultValue any
  24. want any
  25. }{
  26. {
  27. name: "returns string value",
  28. setup: func(ctx context.Context) context.Context {
  29. return context.WithValue(ctx, testKey, "testValue")
  30. },
  31. key: testKey,
  32. defaultValue: "",
  33. want: "testValue",
  34. },
  35. {
  36. name: "returns default when key not found",
  37. setup: func(ctx context.Context) context.Context {
  38. return ctx
  39. },
  40. key: missingKey,
  41. defaultValue: "default",
  42. want: "default",
  43. },
  44. {
  45. name: "returns default when type mismatch",
  46. setup: func(ctx context.Context) context.Context {
  47. return context.WithValue(ctx, testKey, 123) // int, not string
  48. },
  49. key: testKey,
  50. defaultValue: "default",
  51. want: "default",
  52. },
  53. {
  54. name: "returns bool value",
  55. setup: func(ctx context.Context) context.Context {
  56. return context.WithValue(ctx, boolTestKey, true)
  57. },
  58. key: boolTestKey,
  59. defaultValue: false,
  60. want: true,
  61. },
  62. {
  63. name: "returns int value",
  64. setup: func(ctx context.Context) context.Context {
  65. return context.WithValue(ctx, intTestKey, 42)
  66. },
  67. key: intTestKey,
  68. defaultValue: 0,
  69. want: 42,
  70. },
  71. }
  72. for _, tt := range tests {
  73. t.Run(tt.name, func(t *testing.T) {
  74. ctx := tt.setup(context.Background())
  75. var got any
  76. switch tt.defaultValue.(type) {
  77. case string:
  78. got = getContextValue(ctx, tt.key, tt.defaultValue.(string))
  79. case bool:
  80. got = getContextValue(ctx, tt.key, tt.defaultValue.(bool))
  81. case int:
  82. got = getContextValue(ctx, tt.key, tt.defaultValue.(int))
  83. }
  84. if got != tt.want {
  85. t.Errorf("getContextValue() = %v, want %v", got, tt.want)
  86. }
  87. })
  88. }
  89. }
  90. func TestGetSessionFromContext(t *testing.T) {
  91. tests := []struct {
  92. name string
  93. ctx context.Context
  94. want string
  95. }{
  96. {
  97. name: "returns session ID when present",
  98. ctx: context.WithValue(context.Background(), SessionIDContextKey, "session-123"),
  99. want: "session-123",
  100. },
  101. {
  102. name: "returns empty string when not present",
  103. ctx: context.Background(),
  104. want: "",
  105. },
  106. {
  107. name: "returns empty string when wrong type",
  108. ctx: context.WithValue(context.Background(), SessionIDContextKey, 123),
  109. want: "",
  110. },
  111. }
  112. for _, tt := range tests {
  113. t.Run(tt.name, func(t *testing.T) {
  114. got := GetSessionFromContext(tt.ctx)
  115. if got != tt.want {
  116. t.Errorf("GetSessionFromContext() = %v, want %v", got, tt.want)
  117. }
  118. })
  119. }
  120. }
  121. func TestGetMessageFromContext(t *testing.T) {
  122. tests := []struct {
  123. name string
  124. ctx context.Context
  125. want string
  126. }{
  127. {
  128. name: "returns message ID when present",
  129. ctx: context.WithValue(context.Background(), MessageIDContextKey, "msg-456"),
  130. want: "msg-456",
  131. },
  132. {
  133. name: "returns empty string when not present",
  134. ctx: context.Background(),
  135. want: "",
  136. },
  137. {
  138. name: "returns empty string when wrong type",
  139. ctx: context.WithValue(context.Background(), MessageIDContextKey, 456),
  140. want: "",
  141. },
  142. }
  143. for _, tt := range tests {
  144. t.Run(tt.name, func(t *testing.T) {
  145. got := GetMessageFromContext(tt.ctx)
  146. if got != tt.want {
  147. t.Errorf("GetMessageFromContext() = %v, want %v", got, tt.want)
  148. }
  149. })
  150. }
  151. }
  152. func TestGetSupportsImagesFromContext(t *testing.T) {
  153. tests := []struct {
  154. name string
  155. ctx context.Context
  156. want bool
  157. }{
  158. {
  159. name: "returns true when present and true",
  160. ctx: context.WithValue(context.Background(), SupportsImagesContextKey, true),
  161. want: true,
  162. },
  163. {
  164. name: "returns false when present and false",
  165. ctx: context.WithValue(context.Background(), SupportsImagesContextKey, false),
  166. want: false,
  167. },
  168. {
  169. name: "returns false when not present",
  170. ctx: context.Background(),
  171. want: false,
  172. },
  173. {
  174. name: "returns false when wrong type",
  175. ctx: context.WithValue(context.Background(), SupportsImagesContextKey, "true"),
  176. want: false,
  177. },
  178. }
  179. for _, tt := range tests {
  180. t.Run(tt.name, func(t *testing.T) {
  181. got := GetSupportsImagesFromContext(tt.ctx)
  182. if got != tt.want {
  183. t.Errorf("GetSupportsImagesFromContext() = %v, want %v", got, tt.want)
  184. }
  185. })
  186. }
  187. }
  188. func TestGetModelNameFromContext(t *testing.T) {
  189. tests := []struct {
  190. name string
  191. ctx context.Context
  192. want string
  193. }{
  194. {
  195. name: "returns model name when present",
  196. ctx: context.WithValue(context.Background(), ModelNameContextKey, "claude-opus-4"),
  197. want: "claude-opus-4",
  198. },
  199. {
  200. name: "returns empty string when not present",
  201. ctx: context.Background(),
  202. want: "",
  203. },
  204. {
  205. name: "returns empty string when wrong type",
  206. ctx: context.WithValue(context.Background(), ModelNameContextKey, 789),
  207. want: "",
  208. },
  209. }
  210. for _, tt := range tests {
  211. t.Run(tt.name, func(t *testing.T) {
  212. got := GetModelNameFromContext(tt.ctx)
  213. if got != tt.want {
  214. t.Errorf("GetModelNameFromContext() = %v, want %v", got, tt.want)
  215. }
  216. })
  217. }
  218. }