node.go 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285
  1. // Copyright (C) 2019 Nicola Murino
  2. //
  3. // This program is free software: you can redistribute it and/or modify
  4. // it under the terms of the GNU Affero General Public License as published
  5. // by the Free Software Foundation, version 3.
  6. //
  7. // This program is distributed in the hope that it will be useful,
  8. // but WITHOUT ANY WARRANTY; without even the implied warranty of
  9. // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  10. // GNU Affero General Public License for more details.
  11. //
  12. // You should have received a copy of the GNU Affero General Public License
  13. // along with this program. If not, see <https://www.gnu.org/licenses/>.
  14. package dataprovider
import (
	"bytes"
	"context"
	"crypto/sha256"
	"encoding/hex"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net"
	"net/http"
	"strconv"
	"strings"
	"time"

	"github.com/lestrrat-go/jwx/v2/jwa"
	"github.com/lestrrat-go/jwx/v2/jwt"
	"github.com/rs/xid"

	"github.com/drakkan/sftpgo/v2/internal/httpclient"
	"github.com/drakkan/sftpgo/v2/internal/kms"
	"github.com/drakkan/sftpgo/v2/internal/logger"
	"github.com/drakkan/sftpgo/v2/internal/util"
)
  36. // Supported protocols for connecting to other nodes
  37. const (
  38. NodeProtoHTTP = "http"
  39. NodeProtoHTTPS = "https"
  40. )
  41. const (
  42. // NodeTokenHeader defines the header to use for the node auth token
  43. NodeTokenHeader = "X-SFTPGO-Node"
  44. )
  45. var (
  46. // current node
  47. currentNode *Node
  48. errNoClusterNodes = errors.New("no cluster node defined")
  49. activeNodeTimeDiff = -2 * time.Minute
  50. nodeReqTimeout = 8 * time.Second
  51. )
// NodeConfig defines the node configuration.
// Its values are copied into the NodeData of the current node by
// NodeConfig.validate; an empty Host means no node is registered.
type NodeConfig struct {
	// Host is the address other nodes use to build this node's base URL.
	Host string `json:"host" mapstructure:"host"`
	// Port is appended to the base URL only when it is greater than zero.
	Port int `json:"port" mapstructure:"port"`
	// Proto is the URL scheme; NodeData.validate accepts only NodeProtoHTTP
	// or NodeProtoHTTPS.
	Proto string `json:"proto" mapstructure:"proto"`
}
  58. func (n *NodeConfig) validate() error {
  59. currentNode = nil
  60. if config.IsShared != 1 {
  61. return nil
  62. }
  63. if n.Host == "" {
  64. return nil
  65. }
  66. currentNode = &Node{
  67. Data: NodeData{
  68. Host: n.Host,
  69. Port: n.Port,
  70. Proto: n.Proto,
  71. },
  72. }
  73. return provider.addNode()
  74. }
// NodeData defines the details to connect to a cluster node
type NodeData struct {
	// Host is the node address; mandatory.
	Host string `json:"host"`
	// Port must be in [0, 65535]; 0 leaves the port out of the base URL.
	Port int `json:"port"`
	// Proto must be NodeProtoHTTP or NodeProtoHTTPS.
	Proto string `json:"proto"`
	// Key is the HS256 secret used to sign (generateAuthToken) and verify
	// (authenticate) node tokens. validate replaces it with a newly generated
	// opaque string, encrypted with Host as additional data.
	Key *kms.Secret `json:"api_key"`
}
  82. func (n *NodeData) validate() error {
  83. if n.Host == "" {
  84. return util.NewValidationError("node host is mandatory")
  85. }
  86. if n.Port < 0 || n.Port > 65535 {
  87. return util.NewValidationError(fmt.Sprintf("invalid node port: %d", n.Port))
  88. }
  89. if n.Proto != NodeProtoHTTP && n.Proto != NodeProtoHTTPS {
  90. return util.NewValidationError(fmt.Sprintf("invalid node proto: %s", n.Proto))
  91. }
  92. n.Key = kms.NewPlainSecret(util.GenerateOpaqueString())
  93. n.Key.SetAdditionalData(n.Host)
  94. if err := n.Key.Encrypt(); err != nil {
  95. return fmt.Errorf("unable to encrypt node key: %w", err)
  96. }
  97. return nil
  98. }
  99. func (n *NodeData) getNodeName() string {
  100. h := sha256.New()
  101. var b bytes.Buffer
  102. b.WriteString(fmt.Sprintf("%s:%d", n.Host, n.Port))
  103. h.Write(b.Bytes())
  104. return hex.EncodeToString(h.Sum(nil))
  105. }
// Node defines a cluster node
type Node struct {
	// Name identifies the node; when empty, validate fills it with the hex
	// SHA-256 of "host:port" (NodeData.getNodeName).
	Name string `json:"name"`
	// Data holds the connection details and the token signing key.
	Data NodeData `json:"data"`
	// CreatedAt and UpdatedAt are not set in this file; presumably Unix
	// timestamps maintained by the data provider — TODO confirm the unit.
	CreatedAt int64 `json:"created_at"`
	UpdatedAt int64 `json:"updated_at"`
}
  113. func (n *Node) validate() error {
  114. if n.Name == "" {
  115. n.Name = n.Data.getNodeName()
  116. }
  117. return n.Data.validate()
  118. }
  119. func (n *Node) authenticate(token string) (string, string, error) {
  120. if err := n.Data.Key.TryDecrypt(); err != nil {
  121. providerLog(logger.LevelError, "unable to decrypt node key: %v", err)
  122. return "", "", err
  123. }
  124. if token == "" {
  125. return "", "", ErrInvalidCredentials
  126. }
  127. t, err := jwt.Parse([]byte(token), jwt.WithKey(jwa.HS256, []byte(n.Data.Key.GetPayload())), jwt.WithValidate(true))
  128. if err != nil {
  129. return "", "", fmt.Errorf("unable to parse and validate token: %v", err)
  130. }
  131. var adminUsername, role string
  132. if admin, ok := t.Get("admin"); ok {
  133. if val, ok := admin.(string); ok && val != "" {
  134. adminUsername = val
  135. }
  136. }
  137. if adminUsername == "" {
  138. return "", "", errors.New("no admin username associated with node token")
  139. }
  140. if r, ok := t.Get("role"); ok {
  141. if val, ok := r.(string); ok && val != "" {
  142. role = val
  143. }
  144. }
  145. return adminUsername, role, nil
  146. }
  147. // getBaseURL returns the base URL for this node
  148. func (n *Node) getBaseURL() string {
  149. var sb strings.Builder
  150. sb.WriteString(n.Data.Proto)
  151. sb.WriteString("://")
  152. sb.WriteString(n.Data.Host)
  153. if n.Data.Port > 0 {
  154. sb.WriteString(":")
  155. sb.WriteString(strconv.Itoa(n.Data.Port))
  156. }
  157. return sb.String()
  158. }
  159. // generateAuthToken generates a new auth token
  160. func (n *Node) generateAuthToken(username, role string) (string, error) {
  161. if err := n.Data.Key.TryDecrypt(); err != nil {
  162. return "", fmt.Errorf("unable to decrypt node key: %w", err)
  163. }
  164. now := time.Now().UTC()
  165. t := jwt.New()
  166. t.Set("admin", username) //nolint:errcheck
  167. t.Set("role", role) //nolint:errcheck
  168. t.Set(jwt.JwtIDKey, xid.New().String()) //nolint:errcheck
  169. t.Set(jwt.NotBeforeKey, now.Add(-30*time.Second)) //nolint:errcheck
  170. t.Set(jwt.ExpirationKey, now.Add(1*time.Minute)) //nolint:errcheck
  171. payload, err := jwt.Sign(t, jwt.WithKey(jwa.HS256, []byte(n.Data.Key.GetPayload())))
  172. if err != nil {
  173. return "", fmt.Errorf("unable to sign authentication token: %w", err)
  174. }
  175. return util.BytesToString(payload), nil
  176. }
  177. func (n *Node) prepareRequest(ctx context.Context, username, role, relativeURL, method string,
  178. body io.Reader,
  179. ) (*http.Request, error) {
  180. url := fmt.Sprintf("%s%s", n.getBaseURL(), relativeURL)
  181. req, err := http.NewRequestWithContext(ctx, method, url, body)
  182. if err != nil {
  183. return nil, err
  184. }
  185. token, err := n.generateAuthToken(username, role)
  186. if err != nil {
  187. return nil, err
  188. }
  189. req.Header.Set(NodeTokenHeader, fmt.Sprintf("Bearer %s", token))
  190. return req, nil
  191. }
  192. // SendGetRequest sends an HTTP GET request to this node.
  193. // The responseHolder must be a pointer
  194. func (n *Node) SendGetRequest(username, role, relativeURL string, responseHolder any) error {
  195. ctx, cancel := context.WithTimeout(context.Background(), nodeReqTimeout)
  196. defer cancel()
  197. req, err := n.prepareRequest(ctx, username, role, relativeURL, http.MethodGet, nil)
  198. if err != nil {
  199. return err
  200. }
  201. client := httpclient.GetHTTPClient()
  202. defer client.CloseIdleConnections()
  203. resp, err := client.Do(req)
  204. if err != nil {
  205. return fmt.Errorf("unable to send HTTP GET to node %s: %w", n.Name, err)
  206. }
  207. defer resp.Body.Close()
  208. if resp.StatusCode < http.StatusOK || resp.StatusCode > http.StatusNoContent {
  209. return fmt.Errorf("unexpected status code: %d", resp.StatusCode)
  210. }
  211. respBody, err := io.ReadAll(io.LimitReader(resp.Body, 10485760))
  212. if err != nil {
  213. return fmt.Errorf("unable to read response body: %w", err)
  214. }
  215. err = json.Unmarshal(respBody, responseHolder)
  216. if err != nil {
  217. return errors.New("unable to decode response as json")
  218. }
  219. return nil
  220. }
  221. // SendDeleteRequest sends an HTTP DELETE request to this node
  222. func (n *Node) SendDeleteRequest(username, role, relativeURL string) error {
  223. ctx, cancel := context.WithTimeout(context.Background(), nodeReqTimeout)
  224. defer cancel()
  225. req, err := n.prepareRequest(ctx, username, role, relativeURL, http.MethodDelete, nil)
  226. if err != nil {
  227. return err
  228. }
  229. client := httpclient.GetHTTPClient()
  230. defer client.CloseIdleConnections()
  231. resp, err := client.Do(req)
  232. if err != nil {
  233. return fmt.Errorf("unable to send HTTP DELETE to node %s: %w", n.Name, err)
  234. }
  235. defer resp.Body.Close()
  236. if resp.StatusCode < http.StatusOK || resp.StatusCode > http.StatusNoContent {
  237. return fmt.Errorf("unexpected status code: %d", resp.StatusCode)
  238. }
  239. return nil
  240. }
  241. // AuthenticateNodeToken check the validity of the provided token
  242. func AuthenticateNodeToken(token string) (string, string, error) {
  243. if currentNode == nil {
  244. return "", "", errNoClusterNodes
  245. }
  246. return currentNode.authenticate(token)
  247. }
  248. // GetNodeName returns the node name or an empty string
  249. func GetNodeName() string {
  250. if currentNode == nil {
  251. return ""
  252. }
  253. return currentNode.Name
  254. }