parser_test.go 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416
  1. package hooks
  2. import (
  3. "encoding/base64"
  4. "testing"
  5. "github.com/stretchr/testify/assert"
  6. "github.com/stretchr/testify/require"
  7. )
  8. func TestParseShellEnv(t *testing.T) {
  9. t.Run("parses basic fields", func(t *testing.T) {
  10. env := []string{
  11. "PATH=/usr/bin",
  12. "CRUSH_CONTINUE=false",
  13. "CRUSH_PERMISSION=approve",
  14. "CRUSH_MESSAGE=test message",
  15. "HOME=/home/user",
  16. }
  17. result := parseShellEnv(env)
  18. assert.False(t, result.Continue)
  19. assert.Equal(t, "approve", result.Permission)
  20. assert.Equal(t, "test message", result.Message)
  21. })
  22. t.Run("parses modified prompt", func(t *testing.T) {
  23. env := []string{
  24. "CRUSH_MODIFIED_PROMPT=new prompt text",
  25. }
  26. result := parseShellEnv(env)
  27. require.NotNil(t, result.ModifiedPrompt)
  28. assert.Equal(t, "new prompt text", *result.ModifiedPrompt)
  29. })
  30. t.Run("parses context content", func(t *testing.T) {
  31. env := []string{
  32. "CRUSH_CONTEXT_CONTENT=some context",
  33. }
  34. result := parseShellEnv(env)
  35. assert.Equal(t, "some context", result.ContextContent)
  36. })
  37. t.Run("parses base64 context content", func(t *testing.T) {
  38. text := "multiline\ncontext\nhere"
  39. encoded := base64.StdEncoding.EncodeToString([]byte(text))
  40. env := []string{
  41. "CRUSH_CONTEXT_CONTENT=" + encoded,
  42. }
  43. result := parseShellEnv(env)
  44. assert.Equal(t, text, result.ContextContent)
  45. })
  46. t.Run("parses context files", func(t *testing.T) {
  47. env := []string{
  48. "CRUSH_CONTEXT_FILES=file1.md:file2.txt:file3.go",
  49. }
  50. result := parseShellEnv(env)
  51. assert.Equal(t, []string{"file1.md", "file2.txt", "file3.go"}, result.ContextFiles)
  52. })
  53. t.Run("defaults to continue=true", func(t *testing.T) {
  54. env := []string{}
  55. result := parseShellEnv(env)
  56. assert.True(t, result.Continue)
  57. })
  58. t.Run("ignores non-CRUSH env vars", func(t *testing.T) {
  59. env := []string{
  60. "PATH=/usr/bin",
  61. "HOME=/home/user",
  62. "CRUSH_MESSAGE=test",
  63. }
  64. result := parseShellEnv(env)
  65. assert.Equal(t, "test", result.Message)
  66. })
  67. t.Run("falls back to raw value for invalid base64", func(t *testing.T) {
  68. // Invalid base64 string should be used as-is.
  69. env := []string{
  70. "CRUSH_CONTEXT_CONTENT=this is not base64!@#$",
  71. }
  72. result := parseShellEnv(env)
  73. assert.Equal(t, "this is not base64!@#$", result.ContextContent)
  74. })
  75. t.Run("parses modified input", func(t *testing.T) {
  76. env := []string{
  77. "CRUSH_MODIFIED_INPUT=command=ls -la:working_dir=/tmp",
  78. }
  79. result := parseShellEnv(env)
  80. require.NotNil(t, result.ModifiedInput)
  81. assert.Equal(t, "ls -la", result.ModifiedInput["command"])
  82. assert.Equal(t, "/tmp", result.ModifiedInput["working_dir"])
  83. })
  84. t.Run("parses modified output", func(t *testing.T) {
  85. env := []string{
  86. "CRUSH_MODIFIED_OUTPUT=status=redacted:data=[REDACTED]",
  87. }
  88. result := parseShellEnv(env)
  89. require.NotNil(t, result.ModifiedOutput)
  90. assert.Equal(t, "redacted", result.ModifiedOutput["status"])
  91. assert.Equal(t, "[REDACTED]", result.ModifiedOutput["data"])
  92. })
  93. t.Run("parses modified input with JSON types", func(t *testing.T) {
  94. env := []string{
  95. `CRUSH_MODIFIED_INPUT=offset=100:limit=50:run_in_background=true:ignore=["*.log","*.tmp"]`,
  96. }
  97. result := parseShellEnv(env)
  98. require.NotNil(t, result.ModifiedInput)
  99. assert.Equal(t, float64(100), result.ModifiedInput["offset"]) // JSON numbers are float64
  100. assert.Equal(t, float64(50), result.ModifiedInput["limit"])
  101. assert.Equal(t, true, result.ModifiedInput["run_in_background"])
  102. assert.Equal(t, []any{"*.log", "*.tmp"}, result.ModifiedInput["ignore"])
  103. })
  104. t.Run("parses modified input with strings containing colons", func(t *testing.T) {
  105. // Colons in file paths should work if the value doesn't contain '='
  106. env := []string{
  107. `CRUSH_MODIFIED_INPUT=path=/usr/local/bin:name=test`,
  108. }
  109. result := parseShellEnv(env)
  110. require.NotNil(t, result.ModifiedInput)
  111. // First pair: path=/usr/local/bin
  112. // Second pair: name=test
  113. // Note: This splits on first '=' in each pair
  114. assert.Equal(t, "/usr/local/bin", result.ModifiedInput["path"])
  115. assert.Equal(t, "test", result.ModifiedInput["name"])
  116. })
  117. }
  118. func TestParseJSONResult(t *testing.T) {
  119. t.Run("parses basic fields", func(t *testing.T) {
  120. json := []byte(`{
  121. "continue": false,
  122. "permission": "deny",
  123. "message": "blocked"
  124. }`)
  125. result, err := parseJSONResult(json)
  126. require.NoError(t, err)
  127. assert.False(t, result.Continue)
  128. assert.Equal(t, "deny", result.Permission)
  129. assert.Equal(t, "blocked", result.Message)
  130. })
  131. t.Run("parses modified_input", func(t *testing.T) {
  132. json := []byte(`{
  133. "modified_input": {
  134. "command": "ls -la",
  135. "working_dir": "/tmp"
  136. }
  137. }`)
  138. result, err := parseJSONResult(json)
  139. require.NoError(t, err)
  140. assert.Equal(t, map[string]any{
  141. "command": "ls -la",
  142. "working_dir": "/tmp",
  143. }, result.ModifiedInput)
  144. })
  145. t.Run("parses modified_output", func(t *testing.T) {
  146. json := []byte(`{
  147. "modified_output": {
  148. "content": "filtered output"
  149. }
  150. }`)
  151. result, err := parseJSONResult(json)
  152. require.NoError(t, err)
  153. assert.Equal(t, map[string]any{
  154. "content": "filtered output",
  155. }, result.ModifiedOutput)
  156. })
  157. t.Run("parses context_files array", func(t *testing.T) {
  158. json := []byte(`{
  159. "context_files": ["file1.md", "file2.txt"]
  160. }`)
  161. result, err := parseJSONResult(json)
  162. require.NoError(t, err)
  163. assert.Equal(t, []string{"file1.md", "file2.txt"}, result.ContextFiles)
  164. })
  165. t.Run("returns error on invalid JSON", func(t *testing.T) {
  166. json := []byte(`{invalid}`)
  167. _, err := parseJSONResult(json)
  168. assert.Error(t, err)
  169. })
  170. t.Run("defaults to continue=true", func(t *testing.T) {
  171. json := []byte(`{"message": "test"}`)
  172. result, err := parseJSONResult(json)
  173. require.NoError(t, err)
  174. assert.True(t, result.Continue)
  175. })
  176. t.Run("handles wrong type for modified_input", func(t *testing.T) {
  177. // modified_input should be a map, but here it's a string.
  178. json := []byte(`{
  179. "modified_input": "not a map"
  180. }`)
  181. result, err := parseJSONResult(json)
  182. require.NoError(t, err)
  183. // Should be nil/empty since type assertion failed.
  184. assert.Nil(t, result.ModifiedInput)
  185. })
  186. t.Run("handles wrong type for modified_output", func(t *testing.T) {
  187. // modified_output should be a map, but here it's an array.
  188. json := []byte(`{
  189. "modified_output": ["not", "a", "map"]
  190. }`)
  191. result, err := parseJSONResult(json)
  192. require.NoError(t, err)
  193. assert.Nil(t, result.ModifiedOutput)
  194. })
  195. t.Run("handles non-string elements in context_files", func(t *testing.T) {
  196. // context_files should be array of strings, but has numbers.
  197. json := []byte(`{
  198. "context_files": ["file1.md", 123, "file2.md", null]
  199. }`)
  200. result, err := parseJSONResult(json)
  201. require.NoError(t, err)
  202. // Should only include valid strings.
  203. assert.Equal(t, []string{"file1.md", "file2.md"}, result.ContextFiles)
  204. })
  205. t.Run("handles wrong type for context_files", func(t *testing.T) {
  206. // context_files should be an array, but here it's a string.
  207. json := []byte(`{
  208. "context_files": "not an array"
  209. }`)
  210. result, err := parseJSONResult(json)
  211. require.NoError(t, err)
  212. // Should be empty since type assertion failed.
  213. assert.Empty(t, result.ContextFiles)
  214. })
  215. }
  216. func TestMergeJSONResult(t *testing.T) {
  217. t.Run("merges continue flag", func(t *testing.T) {
  218. base := &HookResult{Continue: true}
  219. json := &HookResult{Continue: false}
  220. mergeJSONResult(base, json)
  221. assert.False(t, base.Continue)
  222. })
  223. t.Run("merges permission", func(t *testing.T) {
  224. base := &HookResult{}
  225. json := &HookResult{Permission: "approve"}
  226. mergeJSONResult(base, json)
  227. assert.Equal(t, "approve", base.Permission)
  228. })
  229. t.Run("appends messages", func(t *testing.T) {
  230. base := &HookResult{Message: "first"}
  231. json := &HookResult{Message: "second"}
  232. mergeJSONResult(base, json)
  233. assert.Equal(t, "first; second", base.Message)
  234. })
  235. t.Run("merges modified_input maps", func(t *testing.T) {
  236. base := &HookResult{
  237. ModifiedInput: map[string]any{
  238. "field1": "value1",
  239. },
  240. }
  241. json := &HookResult{
  242. ModifiedInput: map[string]any{
  243. "field2": "value2",
  244. },
  245. }
  246. mergeJSONResult(base, json)
  247. assert.Equal(t, map[string]any{
  248. "field1": "value1",
  249. "field2": "value2",
  250. }, base.ModifiedInput)
  251. })
  252. t.Run("overwrites conflicting modified_input fields", func(t *testing.T) {
  253. base := &HookResult{
  254. ModifiedInput: map[string]any{
  255. "field": "old",
  256. },
  257. }
  258. json := &HookResult{
  259. ModifiedInput: map[string]any{
  260. "field": "new",
  261. },
  262. }
  263. mergeJSONResult(base, json)
  264. assert.Equal(t, "new", base.ModifiedInput["field"])
  265. })
  266. t.Run("appends context content", func(t *testing.T) {
  267. base := &HookResult{ContextContent: "first"}
  268. json := &HookResult{ContextContent: "second"}
  269. mergeJSONResult(base, json)
  270. assert.Equal(t, "first\n\nsecond", base.ContextContent)
  271. })
  272. t.Run("appends context files", func(t *testing.T) {
  273. base := &HookResult{ContextFiles: []string{"file1.md"}}
  274. json := &HookResult{ContextFiles: []string{"file2.md", "file3.md"}}
  275. mergeJSONResult(base, json)
  276. assert.Equal(t, []string{"file1.md", "file2.md", "file3.md"}, base.ContextFiles)
  277. })
  278. t.Run("initializes ModifiedInput when nil", func(t *testing.T) {
  279. // Base has nil ModifiedInput.
  280. base := &HookResult{}
  281. json := &HookResult{
  282. ModifiedInput: map[string]any{
  283. "field": "value",
  284. },
  285. }
  286. mergeJSONResult(base, json)
  287. require.NotNil(t, base.ModifiedInput)
  288. assert.Equal(t, "value", base.ModifiedInput["field"])
  289. })
  290. t.Run("initializes ModifiedOutput when nil", func(t *testing.T) {
  291. // Base has nil ModifiedOutput.
  292. base := &HookResult{}
  293. json := &HookResult{
  294. ModifiedOutput: map[string]any{
  295. "filtered": true,
  296. },
  297. }
  298. mergeJSONResult(base, json)
  299. require.NotNil(t, base.ModifiedOutput)
  300. assert.Equal(t, true, base.ModifiedOutput["filtered"])
  301. })
  302. t.Run("sets context content when base is empty", func(t *testing.T) {
  303. base := &HookResult{}
  304. json := &HookResult{ContextContent: "new content"}
  305. mergeJSONResult(base, json)
  306. assert.Equal(t, "new content", base.ContextContent)
  307. })
  308. t.Run("sets message when base is empty", func(t *testing.T) {
  309. base := &HookResult{}
  310. json := &HookResult{Message: "new message"}
  311. mergeJSONResult(base, json)
  312. assert.Equal(t, "new message", base.Message)
  313. })
  314. }