pki.go 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264
  1. package nebula
  2. import (
  3. "errors"
  4. "fmt"
  5. "os"
  6. "strings"
  7. "sync/atomic"
  8. "time"
  9. "github.com/sirupsen/logrus"
  10. "github.com/slackhq/nebula/cert"
  11. "github.com/slackhq/nebula/config"
  12. "github.com/slackhq/nebula/util"
  13. )
// PKI holds this host's certificate state and trusted CA pool, both loaded
// from config and replaced on config reload (see NewPKIFromConfig).
//
// Both are stored behind atomic.Pointer values, so readers (GetCertState,
// GetCAPool) and the reload path that stores new values do not need a lock.
type PKI struct {
	// cs is the current *CertState; written by reloadCert.
	cs atomic.Pointer[CertState]
	// caPool is the current trusted CA pool; written by reloadCAPool.
	caPool atomic.Pointer[cert.CAPool]
	// l is used for reload diagnostics and logging of loaded material.
	l *logrus.Logger
}
  19. type CertState struct {
  20. Certificate cert.Certificate
  21. RawCertificate []byte
  22. RawCertificateNoKey []byte
  23. PublicKey []byte
  24. PrivateKey []byte
  25. pkcs11Backed bool
  26. }
  27. func NewPKIFromConfig(l *logrus.Logger, c *config.C) (*PKI, error) {
  28. pki := &PKI{l: l}
  29. err := pki.reload(c, true)
  30. if err != nil {
  31. return nil, err
  32. }
  33. c.RegisterReloadCallback(func(c *config.C) {
  34. rErr := pki.reload(c, false)
  35. if rErr != nil {
  36. util.LogWithContextIfNeeded("Failed to reload PKI from config", rErr, l)
  37. }
  38. })
  39. return pki, nil
  40. }
  41. func (p *PKI) GetCertState() *CertState {
  42. return p.cs.Load()
  43. }
  44. func (p *PKI) GetCAPool() *cert.CAPool {
  45. return p.caPool.Load()
  46. }
  47. func (p *PKI) reload(c *config.C, initial bool) error {
  48. err := p.reloadCert(c, initial)
  49. if err != nil {
  50. if initial {
  51. return err
  52. }
  53. err.Log(p.l)
  54. }
  55. err = p.reloadCAPool(c)
  56. if err != nil {
  57. if initial {
  58. return err
  59. }
  60. err.Log(p.l)
  61. }
  62. return nil
  63. }
  64. func (p *PKI) reloadCert(c *config.C, initial bool) *util.ContextualError {
  65. cs, err := newCertStateFromConfig(c)
  66. if err != nil {
  67. return util.NewContextualError("Could not load client cert", nil, err)
  68. }
  69. if !initial {
  70. //TODO: include check for mask equality as well
  71. // did IP in cert change? if so, don't set
  72. currentCert := p.cs.Load().Certificate
  73. oldIPs := currentCert.Networks()
  74. newIPs := cs.Certificate.Networks()
  75. if len(oldIPs) > 0 && len(newIPs) > 0 && oldIPs[0].String() != newIPs[0].String() {
  76. return util.NewContextualError(
  77. "Networks in new cert was different from old",
  78. m{"new_network": newIPs[0], "old_network": oldIPs[0]},
  79. nil,
  80. )
  81. }
  82. }
  83. p.cs.Store(cs)
  84. if initial {
  85. p.l.WithField("cert", cs.Certificate).Debug("Client nebula certificate")
  86. } else {
  87. p.l.WithField("cert", cs.Certificate).Info("Client cert refreshed from disk")
  88. }
  89. return nil
  90. }
  91. func (p *PKI) reloadCAPool(c *config.C) *util.ContextualError {
  92. caPool, err := loadCAPoolFromConfig(p.l, c)
  93. if err != nil {
  94. return util.NewContextualError("Failed to load ca from config", nil, err)
  95. }
  96. p.caPool.Store(caPool)
  97. p.l.WithField("fingerprints", caPool.GetFingerprints()).Debug("Trusted CA fingerprints")
  98. return nil
  99. }
  100. func newCertState(certificate cert.Certificate, pkcs11backed bool, privateKey []byte) (*CertState, error) {
  101. // Marshal the certificate to ensure it is valid
  102. rawCertificate, err := certificate.Marshal()
  103. if err != nil {
  104. return nil, fmt.Errorf("invalid nebula certificate on interface: %s", err)
  105. }
  106. publicKey := certificate.PublicKey()
  107. cs := &CertState{
  108. RawCertificate: rawCertificate,
  109. Certificate: certificate,
  110. PrivateKey: privateKey,
  111. PublicKey: publicKey,
  112. pkcs11Backed: pkcs11backed,
  113. }
  114. rawCertNoKey, err := cs.Certificate.MarshalForHandshakes()
  115. if err != nil {
  116. return nil, fmt.Errorf("error marshalling certificate no key: %s", err)
  117. }
  118. cs.RawCertificateNoKey = rawCertNoKey
  119. return cs, nil
  120. }
  121. func loadPrivateKey(privPathOrPEM string) (rawKey []byte, curve cert.Curve, isPkcs11 bool, err error) {
  122. var pemPrivateKey []byte
  123. if strings.Contains(privPathOrPEM, "-----BEGIN") {
  124. pemPrivateKey = []byte(privPathOrPEM)
  125. privPathOrPEM = "<inline>"
  126. rawKey, _, curve, err = cert.UnmarshalPrivateKeyFromPEM(pemPrivateKey)
  127. if err != nil {
  128. return nil, curve, false, fmt.Errorf("error while unmarshaling pki.key %s: %s", privPathOrPEM, err)
  129. }
  130. } else if strings.HasPrefix(privPathOrPEM, "pkcs11:") {
  131. rawKey = []byte(privPathOrPEM)
  132. return rawKey, cert.Curve_P256, true, nil
  133. } else {
  134. pemPrivateKey, err = os.ReadFile(privPathOrPEM)
  135. if err != nil {
  136. return nil, curve, false, fmt.Errorf("unable to read pki.key file %s: %s", privPathOrPEM, err)
  137. }
  138. rawKey, _, curve, err = cert.UnmarshalPrivateKeyFromPEM(pemPrivateKey)
  139. if err != nil {
  140. return nil, curve, false, fmt.Errorf("error while unmarshaling pki.key %s: %s", privPathOrPEM, err)
  141. }
  142. }
  143. return
  144. }
  145. func newCertStateFromConfig(c *config.C) (*CertState, error) {
  146. var err error
  147. privPathOrPEM := c.GetString("pki.key", "")
  148. if privPathOrPEM == "" {
  149. return nil, errors.New("no pki.key path or PEM data provided")
  150. }
  151. rawKey, curve, isPkcs11, err := loadPrivateKey(privPathOrPEM)
  152. if err != nil {
  153. return nil, err
  154. }
  155. var rawCert []byte
  156. pubPathOrPEM := c.GetString("pki.cert", "")
  157. if pubPathOrPEM == "" {
  158. return nil, errors.New("no pki.cert path or PEM data provided")
  159. }
  160. if strings.Contains(pubPathOrPEM, "-----BEGIN") {
  161. rawCert = []byte(pubPathOrPEM)
  162. pubPathOrPEM = "<inline>"
  163. } else {
  164. rawCert, err = os.ReadFile(pubPathOrPEM)
  165. if err != nil {
  166. return nil, fmt.Errorf("unable to read pki.cert file %s: %s", pubPathOrPEM, err)
  167. }
  168. }
  169. nebulaCert, _, err := cert.UnmarshalCertificateFromPEM(rawCert)
  170. if err != nil {
  171. return nil, fmt.Errorf("error while unmarshaling pki.cert %s: %s", pubPathOrPEM, err)
  172. }
  173. if nebulaCert.Expired(time.Now()) {
  174. return nil, fmt.Errorf("nebula certificate for this host is expired")
  175. }
  176. if len(nebulaCert.Networks()) == 0 {
  177. return nil, fmt.Errorf("no networks encoded in certificate")
  178. }
  179. if err = nebulaCert.VerifyPrivateKey(curve, rawKey); err != nil {
  180. return nil, fmt.Errorf("private key is not a pair with public key in nebula cert")
  181. }
  182. return newCertState(nebulaCert, isPkcs11, rawKey)
  183. }
  184. func loadCAPoolFromConfig(l *logrus.Logger, c *config.C) (*cert.CAPool, error) {
  185. var rawCA []byte
  186. var err error
  187. caPathOrPEM := c.GetString("pki.ca", "")
  188. if caPathOrPEM == "" {
  189. return nil, errors.New("no pki.ca path or PEM data provided")
  190. }
  191. if strings.Contains(caPathOrPEM, "-----BEGIN") {
  192. rawCA = []byte(caPathOrPEM)
  193. } else {
  194. rawCA, err = os.ReadFile(caPathOrPEM)
  195. if err != nil {
  196. return nil, fmt.Errorf("unable to read pki.ca file %s: %s", caPathOrPEM, err)
  197. }
  198. }
  199. caPool, err := cert.NewCAPoolFromPEM(rawCA)
  200. if errors.Is(err, cert.ErrExpired) {
  201. var expired int
  202. for _, crt := range caPool.CAs {
  203. if crt.Certificate.Expired(time.Now()) {
  204. expired++
  205. l.WithField("cert", crt).Warn("expired certificate present in CA pool")
  206. }
  207. }
  208. if expired >= len(caPool.CAs) {
  209. return nil, errors.New("no valid CA certificates present")
  210. }
  211. } else if err != nil {
  212. return nil, fmt.Errorf("error while adding CA certificate to CA trust store: %s", err)
  213. }
  214. for _, fp := range c.GetStringSlice("pki.blocklist", []string{}) {
  215. l.WithField("fingerprint", fp).Info("Blocklisting cert")
  216. caPool.BlocklistFingerprint(fp)
  217. }
  218. return caPool, nil
  219. }