config_test.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465
  1. package config
  2. import (
  3. "fmt"
  4. "os"
  5. "path/filepath"
  6. "testing"
  7. "github.com/kujtimiihoxha/termai/internal/llm/models"
  8. "github.com/spf13/viper"
  9. "github.com/stretchr/testify/assert"
  10. "github.com/stretchr/testify/require"
  11. )
  12. func TestLoad(t *testing.T) {
  13. setupTest(t)
  14. t.Run("loads configuration successfully", func(t *testing.T) {
  15. homeDir := t.TempDir()
  16. t.Setenv("HOME", homeDir)
  17. configPath := filepath.Join(homeDir, ".termai.json")
  18. configContent := `{
  19. "data": {
  20. "directory": "custom-dir"
  21. },
  22. "log": {
  23. "level": "debug"
  24. },
  25. "mcpServers": {
  26. "test-server": {
  27. "command": "test-command",
  28. "env": ["TEST_ENV=value"],
  29. "args": ["--arg1", "--arg2"],
  30. "type": "stdio",
  31. "url": "",
  32. "headers": {}
  33. },
  34. "sse-server": {
  35. "command": "",
  36. "env": [],
  37. "args": [],
  38. "type": "sse",
  39. "url": "https://api.example.com/events",
  40. "headers": {
  41. "Authorization": "Bearer token123",
  42. "Content-Type": "application/json"
  43. }
  44. }
  45. },
  46. "providers": {
  47. "anthropic": {
  48. "apiKey": "test-api-key",
  49. "enabled": true
  50. }
  51. },
  52. "model": {
  53. "coder": "claude-3-haiku",
  54. "task": "claude-3-haiku"
  55. }
  56. }`
  57. err := os.WriteFile(configPath, []byte(configContent), 0o644)
  58. require.NoError(t, err)
  59. cfg = nil
  60. viper.Reset()
  61. err = Load(false)
  62. require.NoError(t, err)
  63. config := Get()
  64. assert.NotNil(t, config)
  65. assert.Equal(t, "custom-dir", config.Data.Directory)
  66. assert.Equal(t, "debug", config.Log.Level)
  67. assert.Contains(t, config.MCPServers, "test-server")
  68. stdioServer := config.MCPServers["test-server"]
  69. assert.Equal(t, "test-command", stdioServer.Command)
  70. assert.Equal(t, []string{"TEST_ENV=value"}, stdioServer.Env)
  71. assert.Equal(t, []string{"--arg1", "--arg2"}, stdioServer.Args)
  72. assert.Equal(t, MCPStdio, stdioServer.Type)
  73. assert.Equal(t, "", stdioServer.URL)
  74. assert.Empty(t, stdioServer.Headers)
  75. assert.Contains(t, config.MCPServers, "sse-server")
  76. sseServer := config.MCPServers["sse-server"]
  77. assert.Equal(t, "", sseServer.Command)
  78. assert.Empty(t, sseServer.Env)
  79. assert.Empty(t, sseServer.Args)
  80. assert.Equal(t, MCPSse, sseServer.Type)
  81. assert.Equal(t, "https://api.example.com/events", sseServer.URL)
  82. assert.Equal(t, map[string]string{
  83. "authorization": "Bearer token123",
  84. "content-type": "application/json",
  85. }, sseServer.Headers)
  86. assert.Contains(t, config.Providers, models.ModelProvider("anthropic"))
  87. provider := config.Providers[models.ModelProvider("anthropic")]
  88. assert.Equal(t, "test-api-key", provider.APIKey)
  89. assert.True(t, provider.Enabled)
  90. assert.NotNil(t, config.Model)
  91. assert.Equal(t, models.Claude3Haiku, config.Model.Coder)
  92. assert.Equal(t, models.Claude3Haiku, config.Model.Task)
  93. assert.Equal(t, defaultMaxTokens, config.Model.CoderMaxTokens)
  94. })
  95. t.Run("loads configuration with environment variables", func(t *testing.T) {
  96. homeDir := t.TempDir()
  97. t.Setenv("HOME", homeDir)
  98. configPath := filepath.Join(homeDir, ".termai.json")
  99. err := os.WriteFile(configPath, []byte("{}"), 0o644)
  100. require.NoError(t, err)
  101. t.Setenv("ANTHROPIC_API_KEY", "env-anthropic-key")
  102. t.Setenv("OPENAI_API_KEY", "env-openai-key")
  103. t.Setenv("GEMINI_API_KEY", "env-gemini-key")
  104. cfg = nil
  105. viper.Reset()
  106. err = Load(false)
  107. require.NoError(t, err)
  108. config := Get()
  109. assert.NotNil(t, config)
  110. assert.Equal(t, defaultDataDirectory, config.Data.Directory)
  111. assert.Equal(t, defaultLogLevel, config.Log.Level)
  112. assert.Contains(t, config.Providers, models.ModelProvider("anthropic"))
  113. assert.Equal(t, "env-anthropic-key", config.Providers[models.ModelProvider("anthropic")].APIKey)
  114. assert.True(t, config.Providers[models.ModelProvider("anthropic")].Enabled)
  115. assert.Contains(t, config.Providers, models.ModelProvider("openai"))
  116. assert.Equal(t, "env-openai-key", config.Providers[models.ModelProvider("openai")].APIKey)
  117. assert.True(t, config.Providers[models.ModelProvider("openai")].Enabled)
  118. assert.Contains(t, config.Providers, models.ModelProvider("gemini"))
  119. assert.Equal(t, "env-gemini-key", config.Providers[models.ModelProvider("gemini")].APIKey)
  120. assert.True(t, config.Providers[models.ModelProvider("gemini")].Enabled)
  121. assert.Equal(t, models.Claude37Sonnet, config.Model.Coder)
  122. })
  123. t.Run("local config overrides global config", func(t *testing.T) {
  124. homeDir := t.TempDir()
  125. t.Setenv("HOME", homeDir)
  126. globalConfigPath := filepath.Join(homeDir, ".termai.json")
  127. globalConfig := `{
  128. "data": {
  129. "directory": "global-dir"
  130. },
  131. "log": {
  132. "level": "info"
  133. }
  134. }`
  135. err := os.WriteFile(globalConfigPath, []byte(globalConfig), 0o644)
  136. require.NoError(t, err)
  137. workDir := t.TempDir()
  138. origDir, err := os.Getwd()
  139. require.NoError(t, err)
  140. defer os.Chdir(origDir)
  141. err = os.Chdir(workDir)
  142. require.NoError(t, err)
  143. localConfigPath := filepath.Join(workDir, ".termai.json")
  144. localConfig := `{
  145. "data": {
  146. "directory": "local-dir"
  147. },
  148. "log": {
  149. "level": "debug"
  150. }
  151. }`
  152. err = os.WriteFile(localConfigPath, []byte(localConfig), 0o644)
  153. require.NoError(t, err)
  154. cfg = nil
  155. viper.Reset()
  156. err = Load(false)
  157. require.NoError(t, err)
  158. config := Get()
  159. assert.NotNil(t, config)
  160. assert.Equal(t, "local-dir", config.Data.Directory)
  161. assert.Equal(t, "debug", config.Log.Level)
  162. })
  163. t.Run("missing config file should not return error", func(t *testing.T) {
  164. emptyDir := t.TempDir()
  165. t.Setenv("HOME", emptyDir)
  166. cfg = nil
  167. viper.Reset()
  168. err := Load(false)
  169. assert.NoError(t, err)
  170. })
  171. t.Run("model priority and fallbacks", func(t *testing.T) {
  172. testCases := []struct {
  173. name string
  174. anthropicKey string
  175. openaiKey string
  176. geminiKey string
  177. expectedModel models.ModelID
  178. explicitModel models.ModelID
  179. useExplicitModel bool
  180. }{
  181. {
  182. name: "anthropic has priority",
  183. anthropicKey: "test-key",
  184. openaiKey: "test-key",
  185. geminiKey: "test-key",
  186. expectedModel: models.Claude37Sonnet,
  187. },
  188. {
  189. name: "fallback to openai when no anthropic",
  190. anthropicKey: "",
  191. openaiKey: "test-key",
  192. geminiKey: "test-key",
  193. expectedModel: models.GPT41,
  194. },
  195. {
  196. name: "fallback to gemini when no others",
  197. anthropicKey: "",
  198. openaiKey: "",
  199. geminiKey: "test-key",
  200. expectedModel: models.GRMINI20Flash,
  201. },
  202. {
  203. name: "explicit model overrides defaults",
  204. anthropicKey: "test-key",
  205. openaiKey: "test-key",
  206. geminiKey: "test-key",
  207. explicitModel: models.GPT41,
  208. useExplicitModel: true,
  209. expectedModel: models.GPT41,
  210. },
  211. }
  212. for _, tc := range testCases {
  213. t.Run(tc.name, func(t *testing.T) {
  214. homeDir := t.TempDir()
  215. t.Setenv("HOME", homeDir)
  216. configPath := filepath.Join(homeDir, ".termai.json")
  217. configContent := "{}"
  218. if tc.useExplicitModel {
  219. configContent = fmt.Sprintf(`{"model":{"coder":"%s"}}`, tc.explicitModel)
  220. }
  221. err := os.WriteFile(configPath, []byte(configContent), 0o644)
  222. require.NoError(t, err)
  223. if tc.anthropicKey != "" {
  224. t.Setenv("ANTHROPIC_API_KEY", tc.anthropicKey)
  225. } else {
  226. t.Setenv("ANTHROPIC_API_KEY", "")
  227. }
  228. if tc.openaiKey != "" {
  229. t.Setenv("OPENAI_API_KEY", tc.openaiKey)
  230. } else {
  231. t.Setenv("OPENAI_API_KEY", "")
  232. }
  233. if tc.geminiKey != "" {
  234. t.Setenv("GEMINI_API_KEY", tc.geminiKey)
  235. } else {
  236. t.Setenv("GEMINI_API_KEY", "")
  237. }
  238. cfg = nil
  239. viper.Reset()
  240. err = Load(false)
  241. require.NoError(t, err)
  242. config := Get()
  243. assert.NotNil(t, config)
  244. assert.Equal(t, tc.expectedModel, config.Model.Coder)
  245. })
  246. }
  247. })
  248. }
  249. func TestGet(t *testing.T) {
  250. t.Run("get returns same config instance", func(t *testing.T) {
  251. setupTest(t)
  252. homeDir := t.TempDir()
  253. t.Setenv("HOME", homeDir)
  254. configPath := filepath.Join(homeDir, ".termai.json")
  255. err := os.WriteFile(configPath, []byte("{}"), 0o644)
  256. require.NoError(t, err)
  257. cfg = nil
  258. viper.Reset()
  259. config1 := Get()
  260. require.NotNil(t, config1)
  261. config2 := Get()
  262. require.NotNil(t, config2)
  263. assert.Same(t, config1, config2)
  264. })
  265. t.Run("get loads config if not loaded", func(t *testing.T) {
  266. setupTest(t)
  267. homeDir := t.TempDir()
  268. t.Setenv("HOME", homeDir)
  269. configPath := filepath.Join(homeDir, ".termai.json")
  270. configContent := `{"data":{"directory":"test-dir"}}`
  271. err := os.WriteFile(configPath, []byte(configContent), 0o644)
  272. require.NoError(t, err)
  273. cfg = nil
  274. viper.Reset()
  275. config := Get()
  276. require.NotNil(t, config)
  277. assert.Equal(t, "test-dir", config.Data.Directory)
  278. })
  279. }
  280. func TestWorkingDirectory(t *testing.T) {
  281. t.Run("returns current working directory", func(t *testing.T) {
  282. setupTest(t)
  283. homeDir := t.TempDir()
  284. t.Setenv("HOME", homeDir)
  285. configPath := filepath.Join(homeDir, ".termai.json")
  286. err := os.WriteFile(configPath, []byte("{}"), 0o644)
  287. require.NoError(t, err)
  288. cfg = nil
  289. viper.Reset()
  290. err = Load(false)
  291. require.NoError(t, err)
  292. wd := WorkingDirectory()
  293. expectedWd, err := os.Getwd()
  294. require.NoError(t, err)
  295. assert.Equal(t, expectedWd, wd)
  296. })
  297. }
  298. func TestWrite(t *testing.T) {
  299. t.Run("writes config to file", func(t *testing.T) {
  300. setupTest(t)
  301. homeDir := t.TempDir()
  302. t.Setenv("HOME", homeDir)
  303. configPath := filepath.Join(homeDir, ".termai.json")
  304. err := os.WriteFile(configPath, []byte("{}"), 0o644)
  305. require.NoError(t, err)
  306. cfg = nil
  307. viper.Reset()
  308. err = Load(false)
  309. require.NoError(t, err)
  310. viper.Set("data.directory", "modified-dir")
  311. err = Write()
  312. require.NoError(t, err)
  313. content, err := os.ReadFile(configPath)
  314. require.NoError(t, err)
  315. assert.Contains(t, string(content), "modified-dir")
  316. })
  317. }
  318. func TestMCPType(t *testing.T) {
  319. t.Run("MCPType constants", func(t *testing.T) {
  320. assert.Equal(t, MCPType("stdio"), MCPStdio)
  321. assert.Equal(t, MCPType("sse"), MCPSse)
  322. })
  323. t.Run("MCPType JSON unmarshaling", func(t *testing.T) {
  324. homeDir := t.TempDir()
  325. t.Setenv("HOME", homeDir)
  326. configPath := filepath.Join(homeDir, ".termai.json")
  327. configContent := `{
  328. "mcpServers": {
  329. "stdio-server": {
  330. "type": "stdio"
  331. },
  332. "sse-server": {
  333. "type": "sse"
  334. },
  335. "invalid-server": {
  336. "type": "invalid"
  337. }
  338. }
  339. }`
  340. err := os.WriteFile(configPath, []byte(configContent), 0o644)
  341. require.NoError(t, err)
  342. cfg = nil
  343. viper.Reset()
  344. err = Load(false)
  345. require.NoError(t, err)
  346. config := Get()
  347. assert.NotNil(t, config)
  348. assert.Equal(t, MCPStdio, config.MCPServers["stdio-server"].Type)
  349. assert.Equal(t, MCPSse, config.MCPServers["sse-server"].Type)
  350. assert.Equal(t, MCPType("invalid"), config.MCPServers["invalid-server"].Type)
  351. })
  352. t.Run("default MCPType", func(t *testing.T) {
  353. homeDir := t.TempDir()
  354. t.Setenv("HOME", homeDir)
  355. configPath := filepath.Join(homeDir, ".termai.json")
  356. configContent := `{
  357. "mcpServers": {
  358. "test-server": {
  359. "command": "test-command"
  360. }
  361. }
  362. }`
  363. err := os.WriteFile(configPath, []byte(configContent), 0o644)
  364. require.NoError(t, err)
  365. cfg = nil
  366. viper.Reset()
  367. err = Load(false)
  368. require.NoError(t, err)
  369. config := Get()
  370. assert.NotNil(t, config)
  371. assert.Equal(t, MCPType(""), config.MCPServers["test-server"].Type)
  372. })
  373. }
  374. func setupTest(t *testing.T) {
  375. origHome := os.Getenv("HOME")
  376. origXdgConfigHome := os.Getenv("XDG_CONFIG_HOME")
  377. origAnthropicKey := os.Getenv("ANTHROPIC_API_KEY")
  378. origOpenAIKey := os.Getenv("OPENAI_API_KEY")
  379. origGeminiKey := os.Getenv("GEMINI_API_KEY")
  380. t.Cleanup(func() {
  381. t.Setenv("HOME", origHome)
  382. t.Setenv("XDG_CONFIG_HOME", origXdgConfigHome)
  383. t.Setenv("ANTHROPIC_API_KEY", origAnthropicKey)
  384. t.Setenv("OPENAI_API_KEY", origOpenAIKey)
  385. t.Setenv("GEMINI_API_KEY", origGeminiKey)
  386. cfg = nil
  387. viper.Reset()
  388. })
  389. }