config_test.go 9.6 KB


  1. package config_test
  2. import (
  3. "encoding/json"
  4. "io/ioutil"
  5. "os"
  6. "path/filepath"
  7. "strings"
  8. "testing"
  9. "github.com/stretchr/testify/assert"
  10. "github.com/drakkan/sftpgo/config"
  11. "github.com/drakkan/sftpgo/dataprovider"
  12. "github.com/drakkan/sftpgo/httpclient"
  13. "github.com/drakkan/sftpgo/httpd"
  14. "github.com/drakkan/sftpgo/sftpd"
  15. "github.com/drakkan/sftpgo/utils"
  16. )
  // Base names (without the ".json" extension) of the configuration files the
// tests load from configDir. configName is the configuration the tests expect
// to load successfully; tempConfigName is a scratch file that individual tests
// write, load and then delete.
const (
	configName     = "sftpgo"
	tempConfigName = "temp"
)
  21. func TestLoadConfigTest(t *testing.T) {
  22. configDir := ".."
  23. err := config.LoadConfig(configDir, configName)
  24. assert.NoError(t, err)
  25. assert.NotEqual(t, httpd.Conf{}, config.GetHTTPConfig())
  26. assert.NotEqual(t, dataprovider.Config{}, config.GetProviderConf())
  27. assert.NotEqual(t, sftpd.Configuration{}, config.GetSFTPDConfig())
  28. assert.NotEqual(t, httpclient.Config{}, config.GetHTTPConfig())
  29. confName := tempConfigName + ".json"
  30. configFilePath := filepath.Join(configDir, confName)
  31. err = config.LoadConfig(configDir, tempConfigName)
  32. assert.NotNil(t, err)
  33. err = ioutil.WriteFile(configFilePath, []byte("{invalid json}"), 0666)
  34. assert.NoError(t, err)
  35. err = config.LoadConfig(configDir, tempConfigName)
  36. assert.NotNil(t, err)
  37. err = ioutil.WriteFile(configFilePath, []byte("{\"sftpd\": {\"bind_port\": \"a\"}}"), 0666)
  38. assert.NoError(t, err)
  39. err = config.LoadConfig(configDir, tempConfigName)
  40. assert.NotNil(t, err)
  41. err = os.Remove(configFilePath)
  42. assert.NoError(t, err)
  43. }
  44. func TestEmptyBanner(t *testing.T) {
  45. configDir := ".."
  46. confName := tempConfigName + ".json"
  47. configFilePath := filepath.Join(configDir, confName)
  48. err := config.LoadConfig(configDir, configName)
  49. assert.NoError(t, err)
  50. sftpdConf := config.GetSFTPDConfig()
  51. sftpdConf.Banner = " "
  52. c := make(map[string]sftpd.Configuration)
  53. c["sftpd"] = sftpdConf
  54. jsonConf, _ := json.Marshal(c)
  55. err = ioutil.WriteFile(configFilePath, jsonConf, 0666)
  56. assert.NoError(t, err)
  57. err = config.LoadConfig(configDir, tempConfigName)
  58. assert.NoError(t, err)
  59. sftpdConf = config.GetSFTPDConfig()
  60. assert.NotEmpty(t, strings.TrimSpace(sftpdConf.Banner))
  61. err = os.Remove(configFilePath)
  62. assert.NoError(t, err)
  63. }
  64. func TestInvalidUploadMode(t *testing.T) {
  65. configDir := ".."
  66. confName := tempConfigName + ".json"
  67. configFilePath := filepath.Join(configDir, confName)
  68. err := config.LoadConfig(configDir, configName)
  69. assert.NoError(t, err)
  70. sftpdConf := config.GetSFTPDConfig()
  71. sftpdConf.UploadMode = 10
  72. c := make(map[string]sftpd.Configuration)
  73. c["sftpd"] = sftpdConf
  74. jsonConf, err := json.Marshal(c)
  75. assert.NoError(t, err)
  76. err = ioutil.WriteFile(configFilePath, jsonConf, 0666)
  77. assert.NoError(t, err)
  78. err = config.LoadConfig(configDir, tempConfigName)
  79. assert.NotNil(t, err)
  80. err = os.Remove(configFilePath)
  81. assert.NoError(t, err)
  82. }
  83. func TestInvalidExternalAuthScope(t *testing.T) {
  84. configDir := ".."
  85. confName := tempConfigName + ".json"
  86. configFilePath := filepath.Join(configDir, confName)
  87. err := config.LoadConfig(configDir, configName)
  88. assert.NoError(t, err)
  89. providerConf := config.GetProviderConf()
  90. providerConf.ExternalAuthScope = 10
  91. c := make(map[string]dataprovider.Config)
  92. c["data_provider"] = providerConf
  93. jsonConf, err := json.Marshal(c)
  94. assert.NoError(t, err)
  95. err = ioutil.WriteFile(configFilePath, jsonConf, 0666)
  96. assert.NoError(t, err)
  97. err = config.LoadConfig(configDir, tempConfigName)
  98. assert.NotNil(t, err)
  99. err = os.Remove(configFilePath)
  100. assert.NoError(t, err)
  101. }
  102. func TestInvalidCredentialsPath(t *testing.T) {
  103. configDir := ".."
  104. confName := tempConfigName + ".json"
  105. configFilePath := filepath.Join(configDir, confName)
  106. err := config.LoadConfig(configDir, configName)
  107. assert.NoError(t, err)
  108. providerConf := config.GetProviderConf()
  109. providerConf.CredentialsPath = ""
  110. c := make(map[string]dataprovider.Config)
  111. c["data_provider"] = providerConf
  112. jsonConf, err := json.Marshal(c)
  113. assert.NoError(t, err)
  114. err = ioutil.WriteFile(configFilePath, jsonConf, 0666)
  115. assert.NoError(t, err)
  116. err = config.LoadConfig(configDir, tempConfigName)
  117. assert.NotNil(t, err)
  118. err = os.Remove(configFilePath)
  119. assert.NoError(t, err)
  120. }
  121. func TestInvalidProxyProtocol(t *testing.T) {
  122. configDir := ".."
  123. confName := tempConfigName + ".json"
  124. configFilePath := filepath.Join(configDir, confName)
  125. err := config.LoadConfig(configDir, configName)
  126. assert.NoError(t, err)
  127. sftpdConf := config.GetSFTPDConfig()
  128. sftpdConf.ProxyProtocol = 10
  129. c := make(map[string]sftpd.Configuration)
  130. c["sftpd"] = sftpdConf
  131. jsonConf, err := json.Marshal(c)
  132. assert.NoError(t, err)
  133. err = ioutil.WriteFile(configFilePath, jsonConf, 0666)
  134. assert.NoError(t, err)
  135. err = config.LoadConfig(configDir, tempConfigName)
  136. assert.NotNil(t, err)
  137. err = os.Remove(configFilePath)
  138. assert.NoError(t, err)
  139. }
  140. func TestInvalidUsersBaseDir(t *testing.T) {
  141. configDir := ".."
  142. confName := tempConfigName + ".json"
  143. configFilePath := filepath.Join(configDir, confName)
  144. err := config.LoadConfig(configDir, configName)
  145. assert.NoError(t, err)
  146. providerConf := config.GetProviderConf()
  147. providerConf.UsersBaseDir = "."
  148. c := make(map[string]dataprovider.Config)
  149. c["data_provider"] = providerConf
  150. jsonConf, err := json.Marshal(c)
  151. assert.NoError(t, err)
  152. err = ioutil.WriteFile(configFilePath, jsonConf, 0666)
  153. assert.NoError(t, err)
  154. err = config.LoadConfig(configDir, tempConfigName)
  155. assert.NotNil(t, err)
  156. err = os.Remove(configFilePath)
  157. assert.NoError(t, err)
  158. }
  159. func TestHookCompatibity(t *testing.T) {
  160. configDir := ".."
  161. confName := tempConfigName + ".json"
  162. configFilePath := filepath.Join(configDir, confName)
  163. err := config.LoadConfig(configDir, configName)
  164. assert.NoError(t, err)
  165. providerConf := config.GetProviderConf()
  166. providerConf.ExternalAuthProgram = "ext_auth_program" //nolint:staticcheck
  167. providerConf.PreLoginProgram = "pre_login_program" //nolint:staticcheck
  168. providerConf.Actions.Command = "/tmp/test_cmd" //nolint:staticcheck
  169. c := make(map[string]dataprovider.Config)
  170. c["data_provider"] = providerConf
  171. jsonConf, err := json.Marshal(c)
  172. assert.NoError(t, err)
  173. err = ioutil.WriteFile(configFilePath, jsonConf, 0666)
  174. assert.NoError(t, err)
  175. err = config.LoadConfig(configDir, tempConfigName)
  176. assert.NoError(t, err)
  177. providerConf = config.GetProviderConf()
  178. assert.Equal(t, "ext_auth_program", providerConf.ExternalAuthHook)
  179. assert.Equal(t, "pre_login_program", providerConf.PreLoginHook)
  180. assert.Equal(t, "/tmp/test_cmd", providerConf.Actions.Hook)
  181. err = os.Remove(configFilePath)
  182. assert.NoError(t, err)
  183. providerConf.Actions.Hook = ""
  184. providerConf.Actions.HTTPNotificationURL = "http://example.com/notify" //nolint:staticcheck
  185. c = make(map[string]dataprovider.Config)
  186. c["data_provider"] = providerConf
  187. jsonConf, err = json.Marshal(c)
  188. assert.NoError(t, err)
  189. err = ioutil.WriteFile(configFilePath, jsonConf, 0666)
  190. assert.NoError(t, err)
  191. err = config.LoadConfig(configDir, tempConfigName)
  192. assert.NoError(t, err)
  193. providerConf = config.GetProviderConf()
  194. assert.Equal(t, "http://example.com/notify", providerConf.Actions.Hook)
  195. err = os.Remove(configFilePath)
  196. assert.NoError(t, err)
  197. sftpdConf := config.GetSFTPDConfig()
  198. sftpdConf.KeyboardInteractiveProgram = "key_int_program" //nolint:staticcheck
  199. sftpdConf.Actions.Command = "/tmp/sftp_cmd" //nolint:staticcheck
  200. cnf := make(map[string]sftpd.Configuration)
  201. cnf["sftpd"] = sftpdConf
  202. jsonConf, err = json.Marshal(cnf)
  203. assert.NoError(t, err)
  204. err = ioutil.WriteFile(configFilePath, jsonConf, 0666)
  205. assert.NoError(t, err)
  206. err = config.LoadConfig(configDir, tempConfigName)
  207. assert.NoError(t, err)
  208. sftpdConf = config.GetSFTPDConfig()
  209. assert.Equal(t, "key_int_program", sftpdConf.KeyboardInteractiveHook)
  210. assert.Equal(t, "/tmp/sftp_cmd", sftpdConf.Actions.Hook)
  211. err = os.Remove(configFilePath)
  212. assert.NoError(t, err)
  213. sftpdConf.Actions.Hook = ""
  214. sftpdConf.Actions.HTTPNotificationURL = "http://example.com/sftp" //nolint:staticcheck
  215. cnf = make(map[string]sftpd.Configuration)
  216. cnf["sftpd"] = sftpdConf
  217. jsonConf, err = json.Marshal(cnf)
  218. assert.NoError(t, err)
  219. err = ioutil.WriteFile(configFilePath, jsonConf, 0666)
  220. assert.NoError(t, err)
  221. err = config.LoadConfig(configDir, tempConfigName)
  222. assert.NoError(t, err)
  223. sftpdConf = config.GetSFTPDConfig()
  224. assert.Equal(t, "http://example.com/sftp", sftpdConf.Actions.Hook)
  225. err = os.Remove(configFilePath)
  226. assert.NoError(t, err)
  227. }
  228. func TestHostKeyCompatibility(t *testing.T) {
  229. configDir := ".."
  230. confName := tempConfigName + ".json"
  231. configFilePath := filepath.Join(configDir, confName)
  232. err := config.LoadConfig(configDir, configName)
  233. assert.NoError(t, err)
  234. sftpdConf := config.GetSFTPDConfig()
  235. sftpdConf.Keys = []sftpd.Key{ //nolint:staticcheck
  236. {
  237. PrivateKey: "rsa",
  238. },
  239. {
  240. PrivateKey: "ecdsa",
  241. },
  242. }
  243. c := make(map[string]sftpd.Configuration)
  244. c["sftpd"] = sftpdConf
  245. jsonConf, err := json.Marshal(c)
  246. assert.NoError(t, err)
  247. err = ioutil.WriteFile(configFilePath, jsonConf, 0666)
  248. assert.NoError(t, err)
  249. err = config.LoadConfig(configDir, tempConfigName)
  250. assert.NoError(t, err)
  251. sftpdConf = config.GetSFTPDConfig()
  252. assert.Equal(t, 2, len(sftpdConf.HostKeys))
  253. assert.True(t, utils.IsStringInSlice("rsa", sftpdConf.HostKeys))
  254. assert.True(t, utils.IsStringInSlice("ecdsa", sftpdConf.HostKeys))
  255. err = os.Remove(configFilePath)
  256. assert.NoError(t, err)
  257. }
  258. func TestSetGetConfig(t *testing.T) {
  259. sftpdConf := config.GetSFTPDConfig()
  260. sftpdConf.IdleTimeout = 3
  261. config.SetSFTPDConfig(sftpdConf)
  262. assert.Equal(t, sftpdConf.IdleTimeout, config.GetSFTPDConfig().IdleTimeout)
  263. dataProviderConf := config.GetProviderConf()
  264. dataProviderConf.Host = "test host"
  265. config.SetProviderConf(dataProviderConf)
  266. assert.Equal(t, dataProviderConf.Host, config.GetProviderConf().Host)
  267. httpdConf := config.GetHTTPDConfig()
  268. httpdConf.BindAddress = "0.0.0.0"
  269. config.SetHTTPDConfig(httpdConf)
  270. assert.Equal(t, httpdConf.BindAddress, config.GetHTTPDConfig().BindAddress)
  271. }