// verify.go
  1. package main
import (
	"bytes"
	"errors"
	"flag"
	"fmt"
	"io"
	"os"
	"time"

	"github.com/slackhq/nebula/cert"
)
  11. type verifyFlags struct {
  12. set *flag.FlagSet
  13. caPath *string
  14. certPath *string
  15. }
  16. func newVerifyFlags() *verifyFlags {
  17. vf := verifyFlags{set: flag.NewFlagSet("verify", flag.ContinueOnError)}
  18. vf.set.Usage = func() {}
  19. vf.caPath = vf.set.String("ca", "", "Required: path to a file containing one or more ca certificates")
  20. vf.certPath = vf.set.String("crt", "", "Required: path to a file containing a single certificate")
  21. return &vf
  22. }
  23. func verify(args []string, out io.Writer, errOut io.Writer) error {
  24. vf := newVerifyFlags()
  25. err := vf.set.Parse(args)
  26. if err != nil {
  27. return err
  28. }
  29. if err := mustFlagString("ca", vf.caPath); err != nil {
  30. return err
  31. }
  32. if err := mustFlagString("crt", vf.certPath); err != nil {
  33. return err
  34. }
  35. caFile, err := os.Open(*vf.caPath)
  36. if err != nil {
  37. return fmt.Errorf("error while reading ca: %w", err)
  38. }
  39. defer caFile.Close()
  40. caPool, err := cert.NewCAPoolFromPEMReader(caFile)
  41. if err != nil && !errors.Is(err, cert.ErrExpired) {
  42. return fmt.Errorf("error while adding ca cert to pool: %w", err)
  43. }
  44. rawCert, err := os.ReadFile(*vf.certPath)
  45. if err != nil {
  46. return fmt.Errorf("unable to read crt: %w", err)
  47. }
  48. var errs []error
  49. for {
  50. if len(rawCert) == 0 {
  51. break
  52. }
  53. c, extra, err := cert.UnmarshalCertificateFromPEM(rawCert)
  54. if err != nil {
  55. return fmt.Errorf("error while parsing crt: %w", err)
  56. }
  57. rawCert = extra
  58. _, err = caPool.VerifyCertificate(time.Now(), c)
  59. if err != nil {
  60. switch {
  61. case errors.Is(err, cert.ErrCaNotFound):
  62. errs = append(errs, fmt.Errorf("error while verifying certificate v%d %s with issuer %s: %w", c.Version(), c.Name(), c.Issuer(), err))
  63. default:
  64. errs = append(errs, fmt.Errorf("error while verifying certificate %+v: %w", c, err))
  65. }
  66. }
  67. }
  68. return errors.Join(errs...)
  69. }
  70. func verifySummary() string {
  71. return "verify <flags>: verifies a certificate isn't expired and was signed by a trusted authority."
  72. }
  73. func verifyHelp(out io.Writer) {
  74. vf := newVerifyFlags()
  75. _, _ = out.Write([]byte("Usage of " + os.Args[0] + " " + verifySummary() + "\n"))
  76. vf.set.SetOutput(out)
  77. vf.set.PrintDefaults()
  78. }