config_test.go 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276
  1. package config_test
  2. import (
  3. "encoding/json"
  4. "io/ioutil"
  5. "os"
  6. "path/filepath"
  7. "strings"
  8. "testing"
  9. "github.com/stretchr/testify/assert"
  10. "github.com/drakkan/sftpgo/common"
  11. "github.com/drakkan/sftpgo/config"
  12. "github.com/drakkan/sftpgo/dataprovider"
  13. "github.com/drakkan/sftpgo/ftpd"
  14. "github.com/drakkan/sftpgo/httpclient"
  15. "github.com/drakkan/sftpgo/httpd"
  16. "github.com/drakkan/sftpgo/sftpd"
  17. "github.com/drakkan/sftpgo/utils"
  18. )
  19. const (
  20. tempConfigName = "temp"
  21. configName = "sftpgo"
  22. )
  23. func TestLoadConfigTest(t *testing.T) {
  24. configDir := ".."
  25. err := config.LoadConfig(configDir, configName)
  26. assert.NoError(t, err)
  27. assert.NotEqual(t, httpd.Conf{}, config.GetHTTPConfig())
  28. assert.NotEqual(t, dataprovider.Config{}, config.GetProviderConf())
  29. assert.NotEqual(t, sftpd.Configuration{}, config.GetSFTPDConfig())
  30. assert.NotEqual(t, httpclient.Config{}, config.GetHTTPConfig())
  31. confName := tempConfigName + ".json"
  32. configFilePath := filepath.Join(configDir, confName)
  33. err = config.LoadConfig(configDir, tempConfigName)
  34. assert.NotNil(t, err)
  35. err = ioutil.WriteFile(configFilePath, []byte("{invalid json}"), os.ModePerm)
  36. assert.NoError(t, err)
  37. err = config.LoadConfig(configDir, tempConfigName)
  38. assert.NotNil(t, err)
  39. err = ioutil.WriteFile(configFilePath, []byte("{\"sftpd\": {\"bind_port\": \"a\"}}"), os.ModePerm)
  40. assert.NoError(t, err)
  41. err = config.LoadConfig(configDir, tempConfigName)
  42. assert.NotNil(t, err)
  43. err = os.Remove(configFilePath)
  44. assert.NoError(t, err)
  45. }
  46. func TestEmptyBanner(t *testing.T) {
  47. configDir := ".."
  48. confName := tempConfigName + ".json"
  49. configFilePath := filepath.Join(configDir, confName)
  50. err := config.LoadConfig(configDir, configName)
  51. assert.NoError(t, err)
  52. sftpdConf := config.GetSFTPDConfig()
  53. sftpdConf.Banner = " "
  54. c := make(map[string]sftpd.Configuration)
  55. c["sftpd"] = sftpdConf
  56. jsonConf, _ := json.Marshal(c)
  57. err = ioutil.WriteFile(configFilePath, jsonConf, os.ModePerm)
  58. assert.NoError(t, err)
  59. err = config.LoadConfig(configDir, tempConfigName)
  60. assert.NoError(t, err)
  61. sftpdConf = config.GetSFTPDConfig()
  62. assert.NotEmpty(t, strings.TrimSpace(sftpdConf.Banner))
  63. err = os.Remove(configFilePath)
  64. assert.NoError(t, err)
  65. ftpdConf := config.GetFTPDConfig()
  66. ftpdConf.Banner = " "
  67. c1 := make(map[string]ftpd.Configuration)
  68. c1["ftpd"] = ftpdConf
  69. jsonConf, _ = json.Marshal(c1)
  70. err = ioutil.WriteFile(configFilePath, jsonConf, os.ModePerm)
  71. assert.NoError(t, err)
  72. err = config.LoadConfig(configDir, tempConfigName)
  73. assert.NoError(t, err)
  74. ftpdConf = config.GetFTPDConfig()
  75. assert.NotEmpty(t, strings.TrimSpace(ftpdConf.Banner))
  76. err = os.Remove(configFilePath)
  77. assert.NoError(t, err)
  78. }
  79. func TestInvalidUploadMode(t *testing.T) {
  80. configDir := ".."
  81. confName := tempConfigName + ".json"
  82. configFilePath := filepath.Join(configDir, confName)
  83. err := config.LoadConfig(configDir, configName)
  84. assert.NoError(t, err)
  85. commonConf := config.GetCommonConfig()
  86. commonConf.UploadMode = 10
  87. c := make(map[string]common.Configuration)
  88. c["common"] = commonConf
  89. jsonConf, err := json.Marshal(c)
  90. assert.NoError(t, err)
  91. err = ioutil.WriteFile(configFilePath, jsonConf, os.ModePerm)
  92. assert.NoError(t, err)
  93. err = config.LoadConfig(configDir, tempConfigName)
  94. assert.NotNil(t, err)
  95. err = os.Remove(configFilePath)
  96. assert.NoError(t, err)
  97. }
  98. func TestInvalidExternalAuthScope(t *testing.T) {
  99. configDir := ".."
  100. confName := tempConfigName + ".json"
  101. configFilePath := filepath.Join(configDir, confName)
  102. err := config.LoadConfig(configDir, configName)
  103. assert.NoError(t, err)
  104. providerConf := config.GetProviderConf()
  105. providerConf.ExternalAuthScope = 10
  106. c := make(map[string]dataprovider.Config)
  107. c["data_provider"] = providerConf
  108. jsonConf, err := json.Marshal(c)
  109. assert.NoError(t, err)
  110. err = ioutil.WriteFile(configFilePath, jsonConf, os.ModePerm)
  111. assert.NoError(t, err)
  112. err = config.LoadConfig(configDir, tempConfigName)
  113. assert.NotNil(t, err)
  114. err = os.Remove(configFilePath)
  115. assert.NoError(t, err)
  116. }
  117. func TestInvalidCredentialsPath(t *testing.T) {
  118. configDir := ".."
  119. confName := tempConfigName + ".json"
  120. configFilePath := filepath.Join(configDir, confName)
  121. err := config.LoadConfig(configDir, configName)
  122. assert.NoError(t, err)
  123. providerConf := config.GetProviderConf()
  124. providerConf.CredentialsPath = ""
  125. c := make(map[string]dataprovider.Config)
  126. c["data_provider"] = providerConf
  127. jsonConf, err := json.Marshal(c)
  128. assert.NoError(t, err)
  129. err = ioutil.WriteFile(configFilePath, jsonConf, os.ModePerm)
  130. assert.NoError(t, err)
  131. err = config.LoadConfig(configDir, tempConfigName)
  132. assert.NotNil(t, err)
  133. err = os.Remove(configFilePath)
  134. assert.NoError(t, err)
  135. }
  136. func TestInvalidProxyProtocol(t *testing.T) {
  137. configDir := ".."
  138. confName := tempConfigName + ".json"
  139. configFilePath := filepath.Join(configDir, confName)
  140. err := config.LoadConfig(configDir, configName)
  141. assert.NoError(t, err)
  142. commonConf := config.GetCommonConfig()
  143. commonConf.ProxyProtocol = 10
  144. c := make(map[string]common.Configuration)
  145. c["common"] = commonConf
  146. jsonConf, err := json.Marshal(c)
  147. assert.NoError(t, err)
  148. err = ioutil.WriteFile(configFilePath, jsonConf, os.ModePerm)
  149. assert.NoError(t, err)
  150. err = config.LoadConfig(configDir, tempConfigName)
  151. assert.NotNil(t, err)
  152. err = os.Remove(configFilePath)
  153. assert.NoError(t, err)
  154. }
  155. func TestInvalidUsersBaseDir(t *testing.T) {
  156. configDir := ".."
  157. confName := tempConfigName + ".json"
  158. configFilePath := filepath.Join(configDir, confName)
  159. err := config.LoadConfig(configDir, configName)
  160. assert.NoError(t, err)
  161. providerConf := config.GetProviderConf()
  162. providerConf.UsersBaseDir = "."
  163. c := make(map[string]dataprovider.Config)
  164. c["data_provider"] = providerConf
  165. jsonConf, err := json.Marshal(c)
  166. assert.NoError(t, err)
  167. err = ioutil.WriteFile(configFilePath, jsonConf, os.ModePerm)
  168. assert.NoError(t, err)
  169. err = config.LoadConfig(configDir, tempConfigName)
  170. assert.NotNil(t, err)
  171. err = os.Remove(configFilePath)
  172. assert.NoError(t, err)
  173. }
  174. func TestCommonParamsCompatibility(t *testing.T) {
  175. configDir := ".."
  176. confName := tempConfigName + ".json"
  177. configFilePath := filepath.Join(configDir, confName)
  178. err := config.LoadConfig(configDir, configName)
  179. assert.NoError(t, err)
  180. sftpdConf := config.GetSFTPDConfig()
  181. sftpdConf.IdleTimeout = 21 //nolint:staticcheck
  182. sftpdConf.Actions.Hook = "http://hook"
  183. sftpdConf.Actions.ExecuteOn = []string{"upload"}
  184. sftpdConf.SetstatMode = 1 //nolint:staticcheck
  185. sftpdConf.UploadMode = common.UploadModeAtomicWithResume //nolint:staticcheck
  186. sftpdConf.ProxyProtocol = 1 //nolint:staticcheck
  187. sftpdConf.ProxyAllowed = []string{"192.168.1.1"} //nolint:staticcheck
  188. c := make(map[string]sftpd.Configuration)
  189. c["sftpd"] = sftpdConf
  190. jsonConf, err := json.Marshal(c)
  191. assert.NoError(t, err)
  192. err = ioutil.WriteFile(configFilePath, jsonConf, os.ModePerm)
  193. assert.NoError(t, err)
  194. err = config.LoadConfig(configDir, tempConfigName)
  195. assert.NoError(t, err)
  196. commonConf := config.GetCommonConfig()
  197. assert.Equal(t, 21, commonConf.IdleTimeout)
  198. assert.Equal(t, "http://hook", commonConf.Actions.Hook)
  199. assert.Len(t, commonConf.Actions.ExecuteOn, 1)
  200. assert.True(t, utils.IsStringInSlice("upload", commonConf.Actions.ExecuteOn))
  201. assert.Equal(t, 1, commonConf.SetstatMode)
  202. assert.Equal(t, 1, commonConf.ProxyProtocol)
  203. assert.Len(t, commonConf.ProxyAllowed, 1)
  204. assert.True(t, utils.IsStringInSlice("192.168.1.1", commonConf.ProxyAllowed))
  205. err = os.Remove(configFilePath)
  206. assert.NoError(t, err)
  207. }
  208. func TestHostKeyCompatibility(t *testing.T) {
  209. configDir := ".."
  210. confName := tempConfigName + ".json"
  211. configFilePath := filepath.Join(configDir, confName)
  212. err := config.LoadConfig(configDir, configName)
  213. assert.NoError(t, err)
  214. sftpdConf := config.GetSFTPDConfig()
  215. sftpdConf.Keys = []sftpd.Key{ //nolint:staticcheck
  216. {
  217. PrivateKey: "rsa",
  218. },
  219. {
  220. PrivateKey: "ecdsa",
  221. },
  222. }
  223. c := make(map[string]sftpd.Configuration)
  224. c["sftpd"] = sftpdConf
  225. jsonConf, err := json.Marshal(c)
  226. assert.NoError(t, err)
  227. err = ioutil.WriteFile(configFilePath, jsonConf, os.ModePerm)
  228. assert.NoError(t, err)
  229. err = config.LoadConfig(configDir, tempConfigName)
  230. assert.NoError(t, err)
  231. sftpdConf = config.GetSFTPDConfig()
  232. assert.Equal(t, 2, len(sftpdConf.HostKeys))
  233. assert.True(t, utils.IsStringInSlice("rsa", sftpdConf.HostKeys))
  234. assert.True(t, utils.IsStringInSlice("ecdsa", sftpdConf.HostKeys))
  235. err = os.Remove(configFilePath)
  236. assert.NoError(t, err)
  237. }
  238. func TestSetGetConfig(t *testing.T) {
  239. sftpdConf := config.GetSFTPDConfig()
  240. sftpdConf.MaxAuthTries = 10
  241. config.SetSFTPDConfig(sftpdConf)
  242. assert.Equal(t, sftpdConf.MaxAuthTries, config.GetSFTPDConfig().MaxAuthTries)
  243. dataProviderConf := config.GetProviderConf()
  244. dataProviderConf.Host = "test host"
  245. config.SetProviderConf(dataProviderConf)
  246. assert.Equal(t, dataProviderConf.Host, config.GetProviderConf().Host)
  247. httpdConf := config.GetHTTPDConfig()
  248. httpdConf.BindAddress = "0.0.0.0"
  249. config.SetHTTPDConfig(httpdConf)
  250. assert.Equal(t, httpdConf.BindAddress, config.GetHTTPDConfig().BindAddress)
  251. commonConf := config.GetCommonConfig()
  252. commonConf.IdleTimeout = 10
  253. config.SetCommonConfig(commonConf)
  254. assert.Equal(t, commonConf.IdleTimeout, config.GetCommonConfig().IdleTimeout)
  255. ftpdConf := config.GetFTPDConfig()
  256. ftpdConf.CertificateFile = "cert"
  257. ftpdConf.CertificateKeyFile = "key"
  258. config.SetFTPDConfig(ftpdConf)
  259. assert.Equal(t, ftpdConf.CertificateFile, config.GetFTPDConfig().CertificateFile)
  260. assert.Equal(t, ftpdConf.CertificateKeyFile, config.GetFTPDConfig().CertificateKeyFile)
  261. }