1
0

ca_pool.go 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314
  1. package cert
  2. import (
  3. "errors"
  4. "fmt"
  5. "net/netip"
  6. "slices"
  7. "strings"
  8. "time"
  9. )
  10. type CAPool struct {
  11. CAs map[string]*CachedCertificate
  12. certBlocklist map[string]struct{}
  13. }
  14. // NewCAPool creates an empty CAPool
  15. func NewCAPool() *CAPool {
  16. ca := CAPool{
  17. CAs: make(map[string]*CachedCertificate),
  18. certBlocklist: make(map[string]struct{}),
  19. }
  20. return &ca
  21. }
  22. // NewCAPoolFromPEM will create a new CA pool from the provided
  23. // input bytes, which must be a PEM-encoded set of nebula certificates.
  24. // If the pool contains any expired certificates, an ErrExpired will be
  25. // returned along with the pool. The caller must handle any such errors.
  26. func NewCAPoolFromPEM(caPEMs []byte) (*CAPool, error) {
  27. pool := NewCAPool()
  28. var err error
  29. var expired bool
  30. for {
  31. caPEMs, err = pool.AddCAFromPEM(caPEMs)
  32. if errors.Is(err, ErrExpired) {
  33. expired = true
  34. err = nil
  35. }
  36. if err != nil {
  37. return nil, err
  38. }
  39. if len(caPEMs) == 0 || strings.TrimSpace(string(caPEMs)) == "" {
  40. break
  41. }
  42. }
  43. if expired {
  44. return pool, ErrExpired
  45. }
  46. return pool, nil
  47. }
  48. // AddCAFromPEM verifies a Nebula CA certificate and adds it to the pool.
  49. // Only the first pem encoded object will be consumed, any remaining bytes are returned.
  50. // Parsed certificates will be verified and must be a CA
  51. func (ncp *CAPool) AddCAFromPEM(pemBytes []byte) ([]byte, error) {
  52. c, pemBytes, err := UnmarshalCertificateFromPEM(pemBytes)
  53. if err != nil {
  54. return pemBytes, err
  55. }
  56. err = ncp.AddCA(c)
  57. if err != nil {
  58. return pemBytes, err
  59. }
  60. return pemBytes, nil
  61. }
  62. // AddCA verifies a Nebula CA certificate and adds it to the pool.
  63. func (ncp *CAPool) AddCA(c Certificate) error {
  64. if !c.IsCA() {
  65. return fmt.Errorf("%s: %w", c.Name(), ErrNotCA)
  66. }
  67. if !c.CheckSignature(c.PublicKey()) {
  68. return fmt.Errorf("%s: %w", c.Name(), ErrNotSelfSigned)
  69. }
  70. sum, err := c.Fingerprint()
  71. if err != nil {
  72. return fmt.Errorf("could not calculate fingerprint for provided CA; error: %w; %s", err, c.Name())
  73. }
  74. cc := &CachedCertificate{
  75. Certificate: c,
  76. Fingerprint: sum,
  77. InvertedGroups: make(map[string]struct{}),
  78. }
  79. for _, g := range c.Groups() {
  80. cc.InvertedGroups[g] = struct{}{}
  81. }
  82. ncp.CAs[sum] = cc
  83. if c.Expired(time.Now()) {
  84. return fmt.Errorf("%s: %w", c.Name(), ErrExpired)
  85. }
  86. return nil
  87. }
  88. // BlocklistFingerprint adds a cert fingerprint to the blocklist
  89. func (ncp *CAPool) BlocklistFingerprint(f string) {
  90. ncp.certBlocklist[f] = struct{}{}
  91. }
  92. // ResetCertBlocklist removes all previously blocklisted cert fingerprints
  93. func (ncp *CAPool) ResetCertBlocklist() {
  94. ncp.certBlocklist = make(map[string]struct{})
  95. }
  96. // IsBlocklisted tests the provided fingerprint against the pools blocklist.
  97. // Returns true if the fingerprint is blocked.
  98. func (ncp *CAPool) IsBlocklisted(fingerprint string) bool {
  99. if _, ok := ncp.certBlocklist[fingerprint]; ok {
  100. return true
  101. }
  102. return false
  103. }
  104. // VerifyCertificate verifies the certificate is valid and is signed by a trusted CA in the pool.
  105. // If the certificate is valid then the returned CachedCertificate can be used in subsequent verification attempts
  106. // to increase performance.
  107. func (ncp *CAPool) VerifyCertificate(now time.Time, c Certificate) (*CachedCertificate, error) {
  108. if c == nil {
  109. return nil, fmt.Errorf("no certificate")
  110. }
  111. fp, err := c.Fingerprint()
  112. if err != nil {
  113. return nil, fmt.Errorf("could not calculate fingerprint to verify: %w", err)
  114. }
  115. signer, err := ncp.verify(c, now, fp, "")
  116. if err != nil {
  117. return nil, err
  118. }
  119. // Pre nebula v1.10.3 could generate signatures in either high or low s form and validation
  120. // of signatures allowed for either. Nebula v1.10.3 and beyond clamps signature generation to low-s form
  121. // but validation still allows for either. Since a change in the signature bytes affects the fingerprint, we
  122. // need to test both forms until such a time comes that we enforce low-s form on signature validation.
  123. fp2, err := CalculateAlternateFingerprint(c)
  124. if err != nil {
  125. return nil, fmt.Errorf("could not calculate alternate fingerprint to verify: %w", err)
  126. }
  127. if fp2 != "" && ncp.IsBlocklisted(fp2) {
  128. return nil, ErrBlockListed
  129. }
  130. cc := CachedCertificate{
  131. Certificate: c,
  132. InvertedGroups: make(map[string]struct{}),
  133. Fingerprint: fp,
  134. fingerprint2: fp2,
  135. signerFingerprint: signer.Fingerprint,
  136. }
  137. for _, g := range c.Groups() {
  138. cc.InvertedGroups[g] = struct{}{}
  139. }
  140. return &cc, nil
  141. }
  142. // VerifyCachedCertificate is the same as VerifyCertificate other than it operates on a pre-verified structure and
  143. // is a cheaper operation to perform as a result.
  144. func (ncp *CAPool) VerifyCachedCertificate(now time.Time, c *CachedCertificate) error {
  145. // Check any available alternate fingerprint forms for this certificate, re P256 high-s/low-s
  146. if c.fingerprint2 != "" && ncp.IsBlocklisted(c.fingerprint2) {
  147. return ErrBlockListed
  148. }
  149. _, err := ncp.verify(c.Certificate, now, c.Fingerprint, c.signerFingerprint)
  150. return err
  151. }
  152. func (ncp *CAPool) verify(c Certificate, now time.Time, certFp string, signerFp string) (*CachedCertificate, error) {
  153. if ncp.IsBlocklisted(certFp) {
  154. return nil, ErrBlockListed
  155. }
  156. signer, err := ncp.GetCAForCert(c)
  157. if err != nil {
  158. return nil, err
  159. }
  160. if signer.Certificate.Expired(now) {
  161. return nil, ErrRootExpired
  162. }
  163. if c.Expired(now) {
  164. return nil, ErrExpired
  165. }
  166. // If we are checking a cached certificate then we can bail early here
  167. // Either the root is no longer trusted or everything is fine
  168. if len(signerFp) > 0 {
  169. if signerFp != signer.Fingerprint {
  170. return nil, ErrFingerprintMismatch
  171. }
  172. return signer, nil
  173. }
  174. if !c.CheckSignature(signer.Certificate.PublicKey()) {
  175. return nil, ErrSignatureMismatch
  176. }
  177. err = CheckCAConstraints(signer.Certificate, c)
  178. if err != nil {
  179. return nil, err
  180. }
  181. return signer, nil
  182. }
  183. // GetCAForCert attempts to return the signing certificate for the provided certificate.
  184. // No signature validation is performed
  185. func (ncp *CAPool) GetCAForCert(c Certificate) (*CachedCertificate, error) {
  186. issuer := c.Issuer()
  187. if issuer == "" {
  188. return nil, fmt.Errorf("no issuer in certificate")
  189. }
  190. signer, ok := ncp.CAs[issuer]
  191. if ok {
  192. return signer, nil
  193. }
  194. return nil, ErrCaNotFound
  195. }
  196. // GetFingerprints returns an array of trusted CA fingerprints
  197. func (ncp *CAPool) GetFingerprints() []string {
  198. fp := make([]string, len(ncp.CAs))
  199. i := 0
  200. for k := range ncp.CAs {
  201. fp[i] = k
  202. i++
  203. }
  204. return fp
  205. }
  206. // CheckCAConstraints returns an error if the sub certificate violates constraints present in the signer certificate.
  207. func CheckCAConstraints(signer Certificate, sub Certificate) error {
  208. return checkCAConstraints(signer, sub.NotBefore(), sub.NotAfter(), sub.Groups(), sub.Networks(), sub.UnsafeNetworks())
  209. }
  210. // checkCAConstraints is a very generic function allowing both Certificates and TBSCertificates to be tested.
  211. func checkCAConstraints(signer Certificate, notBefore, notAfter time.Time, groups []string, networks, unsafeNetworks []netip.Prefix) error {
  212. // Make sure this cert isn't valid after the root
  213. if notAfter.After(signer.NotAfter()) {
  214. return fmt.Errorf("certificate expires after signing certificate")
  215. }
  216. // Make sure this cert wasn't valid before the root
  217. if notBefore.Before(signer.NotBefore()) {
  218. return fmt.Errorf("certificate is valid before the signing certificate")
  219. }
  220. // If the signer has a limited set of groups make sure the cert only contains a subset
  221. signerGroups := signer.Groups()
  222. if len(signerGroups) > 0 {
  223. for _, g := range groups {
  224. if !slices.Contains(signerGroups, g) {
  225. return fmt.Errorf("certificate contained a group not present on the signing ca: %s", g)
  226. }
  227. }
  228. }
  229. // If the signer has a limited set of ip ranges to issue from make sure the cert only contains a subset
  230. signingNetworks := signer.Networks()
  231. if len(signingNetworks) > 0 {
  232. for _, certNetwork := range networks {
  233. found := false
  234. for _, signingNetwork := range signingNetworks {
  235. if signingNetwork.Contains(certNetwork.Addr()) && signingNetwork.Bits() <= certNetwork.Bits() {
  236. found = true
  237. break
  238. }
  239. }
  240. if !found {
  241. return fmt.Errorf("certificate contained a network assignment outside the limitations of the signing ca: %s", certNetwork.String())
  242. }
  243. }
  244. }
  245. // If the signer has a limited set of subnet ranges to issue from make sure the cert only contains a subset
  246. signingUnsafeNetworks := signer.UnsafeNetworks()
  247. if len(signingUnsafeNetworks) > 0 {
  248. for _, certUnsafeNetwork := range unsafeNetworks {
  249. found := false
  250. for _, caNetwork := range signingUnsafeNetworks {
  251. if caNetwork.Contains(certUnsafeNetwork.Addr()) && caNetwork.Bits() <= certUnsafeNetwork.Bits() {
  252. found = true
  253. break
  254. }
  255. }
  256. if !found {
  257. return fmt.Errorf("certificate contained an unsafe network assignment outside the limitations of the signing ca: %s", certUnsafeNetwork.String())
  258. }
  259. }
  260. }
  261. return nil
  262. }