Browse Source

Add AcceptNoWrap to DowngradingListener

AudriusButkevicius 10 years ago
parent
commit
61130ea191
1 changed files with 28 additions and 11 deletions
  1. 28 11
      lib/tlsutil/tlsutil.go

+ 28 - 11
lib/tlsutil/tlsutil.go

@@ -23,6 +23,10 @@ import (
 	"time"
 )
 
+var (
+	ErrIdentificationFailed = fmt.Errorf("failed to identify socket type")
+)
+
 func NewCertificate(certFile, keyFile, tlsDefaultCommonName string, tlsRSABits int) (tls.Certificate, error) {
 	priv, err := rsa.GenerateKey(rand.Reader, tlsRSABits)
 	if err != nil {
@@ -85,9 +89,28 @@ type DowngradingListener struct {
 }
 
 func (l *DowngradingListener) Accept() (net.Conn, error) {
+	conn, isTLS, err := l.AcceptNoWrap()
+
+	// We failed to identify the socket type, pretend that everything is fine,
+	// and pass it to the underlying handler, and let them deal with it.
+	if err == ErrIdentificationFailed {
+		return conn, nil
+	}
+
+	if err != nil {
+		return conn, err
+	}
+
+	if isTLS {
+		return tls.Server(conn, l.TLSConfig), nil
+	}
+	return conn, nil
+}
+
+func (l *DowngradingListener) AcceptNoWrap() (net.Conn, bool, error) {
 	conn, err := l.Listener.Accept()
 	if err != nil {
-		return nil, err
+		return nil, false, err
 	}
 
 	br := bufio.NewReader(conn)
@@ -96,18 +119,12 @@ func (l *DowngradingListener) Accept() (net.Conn, error) {
 	conn.SetReadDeadline(time.Time{})
 	if err != nil {
 		// We hit a read error here, but the Accept() call succeeded so we must not return an error.
-		// We return the connection as is and let whoever tries to use it deal with the error.
-		return conn, nil
-	}
-
-	wrapper := &WrappedConnection{br, conn}
-
-	// 0x16 is the first byte of a TLS handshake
-	if bs[0] == 0x16 {
-		return tls.Server(wrapper, l.TLSConfig), nil
+		// We return the connection as is with a special error which handles this
+		// special case in Accept().
+		return conn, false, ErrIdentificationFailed
 	}
 
-	return wrapper, nil
+	return &WrappedConnection{br, conn}, bs[0] == 0x16, nil
 }
 
 type WrappedConnection struct {