config_test.go 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253
  1. package config_test
  2. import (
  3. "encoding/json"
  4. "io/ioutil"
  5. "os"
  6. "path/filepath"
  7. "strings"
  8. "testing"
  9. "github.com/stretchr/testify/assert"
  10. "github.com/drakkan/sftpgo/config"
  11. "github.com/drakkan/sftpgo/dataprovider"
  12. "github.com/drakkan/sftpgo/httpclient"
  13. "github.com/drakkan/sftpgo/httpd"
  14. "github.com/drakkan/sftpgo/sftpd"
  15. "github.com/drakkan/sftpgo/utils"
  16. )
  const (
  	// configName is the base name (no extension) of the default configuration that the
  	// tests load from the parent directory ("..").
  	configName = "sftpgo"
  	// tempConfigName is the base name of the scratch configuration file that the tests
  	// write into the same directory and remove again when they finish.
  	tempConfigName = "temp"
  )
  // TestLoadConfigTest checks that the default configuration loads and fills every section
  // exposed by the config package. It also checks that LoadConfig returns an error for a
  // missing file, for malformed JSON, and for a field encoded with the wrong JSON type.
  func TestLoadConfigTest(t *testing.T) {
  	configDir := ".."
  	err := config.LoadConfig(configDir, configName)
  	assert.NoError(t, err)
  	// Each getter must be compared with the zero value of its own return type. testify's
  	// NotEqual always succeeds when the two operands have different types, so a mismatched
  	// pair would make the assertion vacuous.
  	assert.NotEqual(t, httpd.Conf{}, config.GetHTTPDConfig())
  	assert.NotEqual(t, dataprovider.Config{}, config.GetProviderConf())
  	assert.NotEqual(t, sftpd.Configuration{}, config.GetSFTPDConfig())
  	assert.NotEqual(t, httpclient.Config{}, config.GetHTTPConfig())

  	confName := tempConfigName + ".json"
  	configFilePath := filepath.Join(configDir, confName)
  	// The scratch file has not been written yet, so loading it must fail (this assumes no
  	// stale temp.json was left behind by an earlier aborted run).
  	err = config.LoadConfig(configDir, tempConfigName)
  	assert.NotNil(t, err)
  	// Syntactically invalid JSON must be rejected.
  	err = ioutil.WriteFile(configFilePath, []byte("{invalid json}"), 0666)
  	assert.NoError(t, err)
  	err = config.LoadConfig(configDir, tempConfigName)
  	assert.NotNil(t, err)
  	// Valid JSON, but bind_port holds a string where a number is expected.
  	err = ioutil.WriteFile(configFilePath, []byte("{\"sftpd\": {\"bind_port\": \"a\"}}"), 0666)
  	assert.NoError(t, err)
  	err = config.LoadConfig(configDir, tempConfigName)
  	assert.NotNil(t, err)
  	err = os.Remove(configFilePath)
  	assert.NoError(t, err)
  }
  // TestEmptyBanner checks that a whitespace-only SFTP banner in the configuration file is
  // not kept as-is: after loading, the banner must contain non-space characters (presumably
  // LoadConfig substitutes a default banner — confirm in the config package).
  func TestEmptyBanner(t *testing.T) {
  	configDir := ".."
  	confName := tempConfigName + ".json"
  	configFilePath := filepath.Join(configDir, confName)
  	err := config.LoadConfig(configDir, configName)
  	assert.NoError(t, err)
  	sftpdConf := config.GetSFTPDConfig()
  	sftpdConf.Banner = " "
  	c := make(map[string]sftpd.Configuration)
  	c["sftpd"] = sftpdConf
  	// Check the marshalling error, as the sibling tests do, instead of discarding it. A
  	// failure here would otherwise surface later as a confusing load/assert failure.
  	jsonConf, err := json.Marshal(c)
  	assert.NoError(t, err)
  	err = ioutil.WriteFile(configFilePath, jsonConf, 0666)
  	assert.NoError(t, err)
  	err = config.LoadConfig(configDir, tempConfigName)
  	assert.NoError(t, err)
  	sftpdConf = config.GetSFTPDConfig()
  	assert.NotEmpty(t, strings.TrimSpace(sftpdConf.Banner))
  	err = os.Remove(configFilePath)
  	assert.NoError(t, err)
  }
  // TestInvalidUploadMode writes a configuration whose sftpd UploadMode is set to an
  // unsupported value (10) and expects LoadConfig to return an error.
  func TestInvalidUploadMode(t *testing.T) {
  	dir := ".."
  	tempFile := filepath.Join(dir, tempConfigName+".json")

  	assert.NoError(t, config.LoadConfig(dir, configName))
  	sftpdConf := config.GetSFTPDConfig()
  	sftpdConf.UploadMode = 10

  	payload, err := json.Marshal(map[string]sftpd.Configuration{"sftpd": sftpdConf})
  	assert.NoError(t, err)
  	assert.NoError(t, ioutil.WriteFile(tempFile, payload, 0666))

  	assert.NotNil(t, config.LoadConfig(dir, tempConfigName))
  	assert.NoError(t, os.Remove(tempFile))
  }
  // TestInvalidExternalAuthScope writes a configuration whose data provider
  // ExternalAuthScope is set to an unsupported value (10) and expects LoadConfig to fail.
  func TestInvalidExternalAuthScope(t *testing.T) {
  	dir := ".."
  	tempFile := filepath.Join(dir, tempConfigName+".json")

  	assert.NoError(t, config.LoadConfig(dir, configName))
  	providerConf := config.GetProviderConf()
  	providerConf.ExternalAuthScope = 10

  	payload, err := json.Marshal(map[string]dataprovider.Config{"data_provider": providerConf})
  	assert.NoError(t, err)
  	assert.NoError(t, ioutil.WriteFile(tempFile, payload, 0666))

  	assert.NotNil(t, config.LoadConfig(dir, tempConfigName))
  	assert.NoError(t, os.Remove(tempFile))
  }
  // TestInvalidCredentialsPath writes a configuration whose data provider CredentialsPath
  // is empty and expects LoadConfig to return an error.
  func TestInvalidCredentialsPath(t *testing.T) {
  	dir := ".."
  	tempFile := filepath.Join(dir, tempConfigName+".json")

  	assert.NoError(t, config.LoadConfig(dir, configName))
  	providerConf := config.GetProviderConf()
  	providerConf.CredentialsPath = ""

  	payload, err := json.Marshal(map[string]dataprovider.Config{"data_provider": providerConf})
  	assert.NoError(t, err)
  	assert.NoError(t, ioutil.WriteFile(tempFile, payload, 0666))

  	assert.NotNil(t, config.LoadConfig(dir, tempConfigName))
  	assert.NoError(t, os.Remove(tempFile))
  }
  // TestInvalidProxyProtocol writes a configuration whose sftpd ProxyProtocol is set to an
  // unsupported value (10) and expects LoadConfig to return an error.
  func TestInvalidProxyProtocol(t *testing.T) {
  	dir := ".."
  	tempFile := filepath.Join(dir, tempConfigName+".json")

  	assert.NoError(t, config.LoadConfig(dir, configName))
  	sftpdConf := config.GetSFTPDConfig()
  	sftpdConf.ProxyProtocol = 10

  	payload, err := json.Marshal(map[string]sftpd.Configuration{"sftpd": sftpdConf})
  	assert.NoError(t, err)
  	assert.NoError(t, ioutil.WriteFile(tempFile, payload, 0666))

  	assert.NotNil(t, config.LoadConfig(dir, tempConfigName))
  	assert.NoError(t, os.Remove(tempFile))
  }
  // TestInvalidUsersBaseDir writes a configuration whose data provider UsersBaseDir is the
  // relative path "." and expects LoadConfig to return an error (presumably an absolute
  // path is required — confirm in the config package).
  func TestInvalidUsersBaseDir(t *testing.T) {
  	dir := ".."
  	tempFile := filepath.Join(dir, tempConfigName+".json")

  	assert.NoError(t, config.LoadConfig(dir, configName))
  	providerConf := config.GetProviderConf()
  	providerConf.UsersBaseDir = "."

  	payload, err := json.Marshal(map[string]dataprovider.Config{"data_provider": providerConf})
  	assert.NoError(t, err)
  	assert.NoError(t, ioutil.WriteFile(tempFile, payload, 0666))

  	assert.NotNil(t, config.LoadConfig(dir, tempConfigName))
  	assert.NoError(t, os.Remove(tempFile))
  }
  // TestHookCompatibity checks that values set in the older "*Program" settings
  // (ExternalAuthProgram, PreLoginProgram, KeyboardInteractiveProgram) are exposed through
  // the matching "*Hook" fields after loading. The //nolint:staticcheck markers suggest the
  // "*Program" fields are deprecated.
  func TestHookCompatibity(t *testing.T) {
  	dir := ".."
  	tempFile := filepath.Join(dir, tempConfigName+".json")

  	assert.NoError(t, config.LoadConfig(dir, configName))

  	// Data provider: ExternalAuthProgram -> ExternalAuthHook, PreLoginProgram -> PreLoginHook.
  	providerConf := config.GetProviderConf()
  	providerConf.ExternalAuthProgram = "ext_auth_program" //nolint:staticcheck
  	providerConf.PreLoginProgram = "pre_login_program"    //nolint:staticcheck
  	payload, err := json.Marshal(map[string]dataprovider.Config{"data_provider": providerConf})
  	assert.NoError(t, err)
  	assert.NoError(t, ioutil.WriteFile(tempFile, payload, 0666))
  	assert.NoError(t, config.LoadConfig(dir, tempConfigName))
  	providerConf = config.GetProviderConf()
  	assert.Equal(t, "ext_auth_program", providerConf.ExternalAuthHook)
  	assert.Equal(t, "pre_login_program", providerConf.PreLoginHook)
  	assert.NoError(t, os.Remove(tempFile))

  	// SFTP server: KeyboardInteractiveProgram -> KeyboardInteractiveHook.
  	sftpdConf := config.GetSFTPDConfig()
  	sftpdConf.KeyboardInteractiveProgram = "key_int_program" //nolint:staticcheck
  	payload, err = json.Marshal(map[string]sftpd.Configuration{"sftpd": sftpdConf})
  	assert.NoError(t, err)
  	assert.NoError(t, ioutil.WriteFile(tempFile, payload, 0666))
  	assert.NoError(t, config.LoadConfig(dir, tempConfigName))
  	assert.Equal(t, "key_int_program", config.GetSFTPDConfig().KeyboardInteractiveHook)
  	assert.NoError(t, os.Remove(tempFile))
  }
  // TestHostKeyCompatibility checks that private keys listed in the older sftpd Keys
  // setting are carried over into HostKeys when the configuration is loaded. The
  // //nolint:staticcheck marker suggests Keys is deprecated.
  func TestHostKeyCompatibility(t *testing.T) {
  	dir := ".."
  	tempFile := filepath.Join(dir, tempConfigName+".json")

  	assert.NoError(t, config.LoadConfig(dir, configName))
  	sftpdConf := config.GetSFTPDConfig()
  	sftpdConf.Keys = []sftpd.Key{ //nolint:staticcheck
  		{PrivateKey: "rsa"},
  		{PrivateKey: "ecdsa"},
  	}

  	payload, err := json.Marshal(map[string]sftpd.Configuration{"sftpd": sftpdConf})
  	assert.NoError(t, err)
  	assert.NoError(t, ioutil.WriteFile(tempFile, payload, 0666))
  	assert.NoError(t, config.LoadConfig(dir, tempConfigName))

  	hostKeys := config.GetSFTPDConfig().HostKeys
  	assert.Equal(t, 2, len(hostKeys))
  	for _, want := range []string{"rsa", "ecdsa"} {
  		assert.True(t, utils.IsStringInSlice(want, hostKeys))
  	}
  	assert.NoError(t, os.Remove(tempFile))
  }
  // TestSetGetConfig checks that a value stored through each Set* function is returned by
  // the matching Get* function, for the SFTPD, data provider and HTTPD sections.
  // NOTE(review): this changes the package-level configuration and does not restore it, so
  // later tests in the package see the modified values unless they call LoadConfig again.
  func TestSetGetConfig(t *testing.T) {
  	sftpdCfg := config.GetSFTPDConfig()
  	sftpdCfg.IdleTimeout = 3
  	config.SetSFTPDConfig(sftpdCfg)
  	assert.Equal(t, sftpdCfg.IdleTimeout, config.GetSFTPDConfig().IdleTimeout)

  	providerCfg := config.GetProviderConf()
  	providerCfg.Host = "test host"
  	config.SetProviderConf(providerCfg)
  	assert.Equal(t, providerCfg.Host, config.GetProviderConf().Host)

  	httpdCfg := config.GetHTTPDConfig()
  	httpdCfg.BindAddress = "0.0.0.0"
  	config.SetHTTPDConfig(httpdCfg)
  	assert.Equal(t, httpdCfg.BindAddress, config.GetHTTPDConfig().BindAddress)
  }