httpclient.go 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277
  1. // Copyright (C) 2019-2023 Nicola Murino
  2. //
  3. // This program is free software: you can redistribute it and/or modify
  4. // it under the terms of the GNU Affero General Public License as published
  5. // by the Free Software Foundation, version 3.
  6. //
  7. // This program is distributed in the hope that it will be useful,
  8. // but WITHOUT ANY WARRANTY; without even the implied warranty of
  9. // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  10. // GNU Affero General Public License for more details.
  11. //
  12. // You should have received a copy of the GNU Affero General Public License
  13. // along with this program. If not, see <https://www.gnu.org/licenses/>.
  14. // Package httpclient provides HTTP client configuration for SFTPGo hooks
  15. package httpclient
  16. import (
  17. "crypto/tls"
  18. "crypto/x509"
  19. "fmt"
  20. "io"
  21. "net/http"
  22. "os"
  23. "path/filepath"
  24. "strings"
  25. "time"
  26. "github.com/hashicorp/go-retryablehttp"
  27. "github.com/drakkan/sftpgo/v2/internal/logger"
  28. "github.com/drakkan/sftpgo/v2/internal/util"
  29. )
  30. // TLSKeyPair defines the paths for a TLS key pair
  31. type TLSKeyPair struct {
  32. Cert string `json:"cert" mapstructure:"cert"`
  33. Key string `json:"key" mapstructure:"key"`
  34. }
  35. // Header defines an HTTP header.
  36. // If the URL is not empty, the header is added only if the
  37. // requested URL starts with the one specified
  38. type Header struct {
  39. Key string `json:"key" mapstructure:"key"`
  40. Value string `json:"value" mapstructure:"value"`
  41. URL string `json:"url" mapstructure:"url"`
  42. }
  43. // Config defines the configuration for HTTP clients.
  44. // HTTP clients are used for executing hooks such as the ones used for
  45. // custom actions, external authentication and pre-login user modifications
  46. type Config struct {
  47. // Timeout specifies a time limit, in seconds, for a request
  48. Timeout float64 `json:"timeout" mapstructure:"timeout"`
  49. // RetryWaitMin defines the minimum waiting time between attempts in seconds
  50. RetryWaitMin int `json:"retry_wait_min" mapstructure:"retry_wait_min"`
  51. // RetryWaitMax defines the minimum waiting time between attempts in seconds
  52. RetryWaitMax int `json:"retry_wait_max" mapstructure:"retry_wait_max"`
  53. // RetryMax defines the maximum number of attempts
  54. RetryMax int `json:"retry_max" mapstructure:"retry_max"`
  55. // CACertificates defines extra CA certificates to trust.
  56. // The paths can be absolute or relative to the config dir.
  57. // Adding trusted CA certificates is a convenient way to use self-signed
  58. // certificates without defeating the purpose of using TLS
  59. CACertificates []string `json:"ca_certificates" mapstructure:"ca_certificates"`
  60. // Certificates defines the certificates to use for mutual TLS
  61. Certificates []TLSKeyPair `json:"certificates" mapstructure:"certificates"`
  62. // if enabled the HTTP client accepts any TLS certificate presented by
  63. // the server and any host name in that certificate.
  64. // In this mode, TLS is susceptible to man-in-the-middle attacks.
  65. // This should be used only for testing.
  66. SkipTLSVerify bool `json:"skip_tls_verify" mapstructure:"skip_tls_verify"`
  67. // Headers defines a list of http headers to add to each request
  68. Headers []Header `json:"headers" mapstructure:"headers"`
  69. customTransport *http.Transport
  70. }
  71. const logSender = "httpclient"
  72. var httpConfig Config
  73. // Initialize configures HTTP clients
  74. func (c *Config) Initialize(configDir string) error {
  75. if c.Timeout <= 0 {
  76. return fmt.Errorf("invalid timeout: %v", c.Timeout)
  77. }
  78. rootCAs, err := c.loadCACerts(configDir)
  79. if err != nil {
  80. return err
  81. }
  82. customTransport := http.DefaultTransport.(*http.Transport).Clone()
  83. if customTransport.TLSClientConfig != nil {
  84. customTransport.TLSClientConfig.RootCAs = rootCAs
  85. } else {
  86. customTransport.TLSClientConfig = &tls.Config{
  87. RootCAs: rootCAs,
  88. NextProtos: []string{"h2", "http/1.1"},
  89. }
  90. }
  91. customTransport.TLSClientConfig.InsecureSkipVerify = c.SkipTLSVerify
  92. c.customTransport = customTransport
  93. err = c.loadCertificates(configDir)
  94. if err != nil {
  95. return err
  96. }
  97. var headers []Header
  98. for _, h := range c.Headers {
  99. if h.Key != "" && h.Value != "" {
  100. headers = append(headers, h)
  101. }
  102. }
  103. c.Headers = headers
  104. httpConfig = *c
  105. return nil
  106. }
  107. // loadCACerts returns system cert pools and try to add the configured
  108. // CA certificates to it
  109. func (c *Config) loadCACerts(configDir string) (*x509.CertPool, error) {
  110. if len(c.CACertificates) == 0 {
  111. return nil, nil
  112. }
  113. rootCAs, err := x509.SystemCertPool()
  114. if err != nil {
  115. rootCAs = x509.NewCertPool()
  116. }
  117. for _, ca := range c.CACertificates {
  118. if !util.IsFileInputValid(ca) {
  119. return nil, fmt.Errorf("unable to load invalid CA certificate: %q", ca)
  120. }
  121. if !filepath.IsAbs(ca) {
  122. ca = filepath.Join(configDir, ca)
  123. }
  124. certs, err := os.ReadFile(ca)
  125. if err != nil {
  126. return nil, fmt.Errorf("unable to load CA certificate: %v", err)
  127. }
  128. if rootCAs.AppendCertsFromPEM(certs) {
  129. logger.Debug(logSender, "", "CA certificate %q added to the trusted certificates", ca)
  130. } else {
  131. return nil, fmt.Errorf("unable to add CA certificate %q to the trusted cetificates", ca)
  132. }
  133. }
  134. return rootCAs, nil
  135. }
  136. func (c *Config) loadCertificates(configDir string) error {
  137. if len(c.Certificates) == 0 {
  138. return nil
  139. }
  140. for _, keyPair := range c.Certificates {
  141. cert := keyPair.Cert
  142. key := keyPair.Key
  143. if !util.IsFileInputValid(cert) {
  144. return fmt.Errorf("unable to load invalid certificate: %q", cert)
  145. }
  146. if !util.IsFileInputValid(key) {
  147. return fmt.Errorf("unable to load invalid key: %q", key)
  148. }
  149. if !filepath.IsAbs(cert) {
  150. cert = filepath.Join(configDir, cert)
  151. }
  152. if !filepath.IsAbs(key) {
  153. key = filepath.Join(configDir, key)
  154. }
  155. tlsCert, err := tls.LoadX509KeyPair(cert, key)
  156. if err != nil {
  157. return fmt.Errorf("unable to load key pair %q, %q: %v", cert, key, err)
  158. }
  159. x509Cert, err := x509.ParseCertificate(tlsCert.Certificate[0])
  160. if err == nil {
  161. logger.Debug(logSender, "", "adding leaf certificate for key pair %q, %q", cert, key)
  162. tlsCert.Leaf = x509Cert
  163. }
  164. logger.Debug(logSender, "", "client certificate %q and key %q successfully loaded", cert, key)
  165. c.customTransport.TLSClientConfig.Certificates = append(c.customTransport.TLSClientConfig.Certificates, tlsCert)
  166. }
  167. return nil
  168. }
  169. // GetHTTPClient returns a new HTTP client with the configured parameters
  170. func GetHTTPClient() *http.Client {
  171. return &http.Client{
  172. Timeout: time.Duration(httpConfig.Timeout * float64(time.Second)),
  173. Transport: httpConfig.customTransport,
  174. }
  175. }
  176. // GetRetraybleHTTPClient returns an HTTP client that retry a request on error.
  177. // It uses the configured retry parameters
  178. func GetRetraybleHTTPClient() *retryablehttp.Client {
  179. client := retryablehttp.NewClient()
  180. client.HTTPClient.Timeout = time.Duration(httpConfig.Timeout * float64(time.Second))
  181. client.HTTPClient.Transport.(*http.Transport).TLSClientConfig = httpConfig.customTransport.TLSClientConfig
  182. client.Logger = &logger.LeveledLogger{Sender: "RetryableHTTPClient"}
  183. client.RetryWaitMin = time.Duration(httpConfig.RetryWaitMin) * time.Second
  184. client.RetryWaitMax = time.Duration(httpConfig.RetryWaitMax) * time.Second
  185. client.RetryMax = httpConfig.RetryMax
  186. return client
  187. }
  188. // Get issues a GET to the specified URL
  189. func Get(url string) (*http.Response, error) {
  190. req, err := http.NewRequest(http.MethodGet, url, nil)
  191. if err != nil {
  192. return nil, err
  193. }
  194. addHeaders(req, url)
  195. client := GetHTTPClient()
  196. defer client.CloseIdleConnections()
  197. return client.Do(req)
  198. }
  199. // Post issues a POST to the specified URL
  200. func Post(url string, contentType string, body io.Reader) (*http.Response, error) {
  201. req, err := http.NewRequest(http.MethodPost, url, body)
  202. if err != nil {
  203. return nil, err
  204. }
  205. req.Header.Set("Content-Type", contentType)
  206. addHeaders(req, url)
  207. client := GetHTTPClient()
  208. defer client.CloseIdleConnections()
  209. return client.Do(req)
  210. }
  211. // RetryableGet issues a GET to the specified URL using the retryable client
  212. func RetryableGet(url string) (*http.Response, error) {
  213. req, err := retryablehttp.NewRequest(http.MethodGet, url, nil)
  214. if err != nil {
  215. return nil, err
  216. }
  217. addHeadersToRetryableReq(req, url)
  218. client := GetRetraybleHTTPClient()
  219. defer client.HTTPClient.CloseIdleConnections()
  220. return client.Do(req)
  221. }
  222. // RetryablePost issues a POST to the specified URL using the retryable client
  223. func RetryablePost(url string, contentType string, body io.Reader) (*http.Response, error) {
  224. req, err := retryablehttp.NewRequest(http.MethodPost, url, body)
  225. if err != nil {
  226. return nil, err
  227. }
  228. req.Header.Set("Content-Type", contentType)
  229. addHeadersToRetryableReq(req, url)
  230. client := GetRetraybleHTTPClient()
  231. defer client.HTTPClient.CloseIdleConnections()
  232. return client.Do(req)
  233. }
  234. func addHeaders(req *http.Request, url string) {
  235. for idx := range httpConfig.Headers {
  236. h := &httpConfig.Headers[idx]
  237. if h.URL == "" || strings.HasPrefix(url, h.URL) {
  238. req.Header.Set(h.Key, h.Value)
  239. }
  240. }
  241. }
  242. func addHeadersToRetryableReq(req *retryablehttp.Request, url string) {
  243. for idx := range httpConfig.Headers {
  244. h := &httpConfig.Headers[idx]
  245. if h.URL == "" || strings.HasPrefix(url, h.URL) {
  246. req.Header.Set(h.Key, h.Value)
  247. }
  248. }
  249. }