瀏覽代碼

Fix OCSP Stapling (#172)

Co-authored-by: RPRX <[email protected]>
eMeab 4 年之前
父節點
當前提交
c13b8ec9bb
共有 2 個文件被更改,包括 76 次插入14 次删除
  1. 38 7
      transport/internet/tls/config.go
  2. 38 7
      transport/internet/xtls/config.go

+ 38 - 7
transport/internet/tls/config.go

@@ -42,8 +42,8 @@ func (c *Config) loadSelfCertPool() (*x509.CertPool, error) {
 }
 
 // BuildCertificates builds a list of TLS certificates from proto definition.
-func (c *Config) BuildCertificates() []tls.Certificate {
-	certs := make([]tls.Certificate, 0, len(c.Certificate))
+func (c *Config) BuildCertificates() []*tls.Certificate {
+	certs := make([]*tls.Certificate, 0, len(c.Certificate))
 	for _, entry := range c.Certificate {
 		if entry.Usage != Certificate_ENCIPHERMENT {
 			continue
@@ -53,7 +53,12 @@ func (c *Config) BuildCertificates() []tls.Certificate {
 			newError("ignoring invalid X509 key pair").Base(err).AtWarning().WriteToLog()
 			continue
 		}
-		certs = append(certs, keyPair)
+		keyPair.Leaf, err = x509.ParseCertificate(keyPair.Certificate[0])
+		if err != nil {
+			newError("ignoring invalid certificate").Base(err).AtWarning().WriteToLog()
+			continue
+		}
+		certs = append(certs, &keyPair)
 		if entry.OcspStapling != 0 {
 			go func(cert *tls.Certificate) {
 				t := time.NewTicker(time.Duration(entry.OcspStapling) * time.Second)
@@ -65,7 +70,7 @@ func (c *Config) BuildCertificates() []tls.Certificate {
 					}
 					<-t.C
 				}
-			}(&certs[len(certs)-1])
+			}(certs[len(certs)-1])
 		}
 	}
 	return certs
@@ -169,6 +174,33 @@ func getGetCertificateFunc(c *tls.Config, ca []*Certificate) func(hello *tls.Cli
 	}
 }
 
+func getNewGetCertficateFunc(certs []*tls.Certificate) func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
+	return func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
+		if len(certs) == 0 {
+			return nil, newError("empty certs")
+		}
+		sni := strings.ToLower(hello.ServerName)
+		if len(certs) == 1 || sni == "" {
+			return certs[0], nil
+		}
+		gsni := "*"
+		if index := strings.IndexByte(sni, '.'); index != -1 {
+			gsni += sni[index:]
+		}
+		for _, keyPair := range certs {
+			if keyPair.Leaf.Subject.CommonName == sni || keyPair.Leaf.Subject.CommonName == gsni {
+				return keyPair, nil
+			}
+			for _, name := range keyPair.Leaf.DNSNames {
+				if name == sni || name == gsni {
+					return keyPair, nil
+				}
+			}
+		}
+		return certs[0], nil
+	}
+}
+
 func (c *Config) IsExperiment8357() bool {
 	return strings.HasPrefix(c.ServerName, exp8357)
 }
@@ -210,12 +242,11 @@ func (c *Config) GetTLSConfig(opts ...Option) *tls.Config {
 		opt(config)
 	}
 
-	config.Certificates = c.BuildCertificates()
-	config.BuildNameToCertificate()
-
 	caCerts := c.getCustomCA()
 	if len(caCerts) > 0 {
 		config.GetCertificate = getGetCertificateFunc(config, caCerts)
+	} else {
+		config.GetCertificate = getNewGetCertficateFunc(c.BuildCertificates())
 	}
 
 	if sn := c.parseServerName(); len(sn) > 0 {

+ 38 - 7
transport/internet/xtls/config.go

@@ -41,8 +41,8 @@ func (c *Config) loadSelfCertPool() (*x509.CertPool, error) {
 }
 
 // BuildCertificates builds a list of TLS certificates from proto definition.
-func (c *Config) BuildCertificates() []xtls.Certificate {
-	certs := make([]xtls.Certificate, 0, len(c.Certificate))
+func (c *Config) BuildCertificates() []*xtls.Certificate {
+	certs := make([]*xtls.Certificate, 0, len(c.Certificate))
 	for _, entry := range c.Certificate {
 		if entry.Usage != Certificate_ENCIPHERMENT {
 			continue
@@ -52,7 +52,12 @@ func (c *Config) BuildCertificates() []xtls.Certificate {
 			newError("ignoring invalid X509 key pair").Base(err).AtWarning().WriteToLog()
 			continue
 		}
-		certs = append(certs, keyPair)
+		keyPair.Leaf, err = x509.ParseCertificate(keyPair.Certificate[0])
+		if err != nil {
+			newError("ignoring invalid certificate").Base(err).AtWarning().WriteToLog()
+			continue
+		}
+		certs = append(certs, &keyPair)
 		if entry.OcspStapling != 0 {
 			go func(cert *xtls.Certificate) {
 				t := time.NewTicker(time.Duration(entry.OcspStapling) * time.Second)
@@ -64,7 +69,7 @@ func (c *Config) BuildCertificates() []xtls.Certificate {
 					}
 					<-t.C
 				}
-			}(&certs[len(certs)-1])
+			}(certs[len(certs)-1])
 		}
 	}
 	return certs
@@ -168,6 +173,33 @@ func getGetCertificateFunc(c *xtls.Config, ca []*Certificate) func(hello *xtls.C
 	}
 }
 
+func getNewGetCertficateFunc(certs []*xtls.Certificate) func(hello *xtls.ClientHelloInfo) (*xtls.Certificate, error) {
+	return func(hello *xtls.ClientHelloInfo) (*xtls.Certificate, error) {
+		if len(certs) == 0 {
+			return nil, newError("empty certs")
+		}
+		sni := strings.ToLower(hello.ServerName)
+		if len(certs) == 1 || sni == "" {
+			return certs[0], nil
+		}
+		gsni := "*"
+		if index := strings.IndexByte(sni, '.'); index != -1 {
+			gsni += sni[index:]
+		}
+		for _, keyPair := range certs {
+			if keyPair.Leaf.Subject.CommonName == sni || keyPair.Leaf.Subject.CommonName == gsni {
+				return keyPair, nil
+			}
+			for _, name := range keyPair.Leaf.DNSNames {
+				if name == sni || name == gsni {
+					return keyPair, nil
+				}
+			}
+		}
+		return certs[0], nil
+	}
+}
+
 func (c *Config) parseServerName() string {
 	return c.ServerName
 }
@@ -201,12 +233,11 @@ func (c *Config) GetXTLSConfig(opts ...Option) *xtls.Config {
 		opt(config)
 	}
 
-	config.Certificates = c.BuildCertificates()
-	config.BuildNameToCertificate()
-
 	caCerts := c.getCustomCA()
 	if len(caCerts) > 0 {
 		config.GetCertificate = getGetCertificateFunc(config, caCerts)
+	} else {
+		config.GetCertificate = getNewGetCertficateFunc(c.BuildCertificates())
 	}
 
 	if sn := c.parseServerName(); len(sn) > 0 {