Browse Source

Add WaitReadCloser to make H2 real 0-RTT

RPRX 2 years ago
parent
commit
6526e74d49
1 changed files with 57 additions and 9 deletions
  1. 57 9
      transport/internet/http/dialer.go

+ 57 - 9
transport/internet/http/dialer.go

@@ -3,6 +3,7 @@ package http
 import (
 	"context"
 	gotls "crypto/tls"
+	"io"
 	"net/http"
 	"net/url"
 	"sync"
@@ -166,23 +167,70 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me
 	// Disable any compression method from server.
 	request.Header.Set("Accept-Encoding", "identity")
 
-	response, err := client.Do(request)
-	if err != nil {
-		return nil, newError("failed to dial to ", dest).Base(err).AtWarning()
-	}
-	if response.StatusCode != 200 {
-		return nil, newError("unexpected status", response.StatusCode).AtWarning()
-	}
+	wrc := &WaitReadCloser{Wait: make(chan struct{})}
+	go func() {
+		response, err := client.Do(request)
+		if err != nil {
+			newError("failed to dial to ", dest).Base(err).AtWarning().WriteToLog(session.ExportIDToError(ctx))
+			wrc.Close()
+			return
+		}
+		if response.StatusCode != 200 {
+			newError("unexpected status", response.StatusCode).AtWarning().WriteToLog(session.ExportIDToError(ctx))
+			wrc.Close()
+			return
+		}
+		wrc.Set(response.Body)
+	}()
 
 	bwriter := buf.NewBufferedWriter(pwriter)
 	common.Must(bwriter.SetBuffered(false))
 	return cnc.NewConnection(
-		cnc.ConnectionOutput(response.Body),
+		cnc.ConnectionOutput(wrc),
 		cnc.ConnectionInput(bwriter),
-		cnc.ConnectionOnClose(common.ChainedClosable{breader, bwriter, response.Body}),
+		cnc.ConnectionOnClose(common.ChainedClosable{breader, bwriter, wrc}),
 	), nil
 }
 
 func init() {
 	common.Must(internet.RegisterTransportDialer(protocolName, Dial))
 }
+
+type WaitReadCloser struct {
+	Wait chan struct{}
+	io.ReadCloser
+}
+
+func (w *WaitReadCloser) Set(rc io.ReadCloser) {
+	w.ReadCloser = rc
+	defer func() {
+		if err := recover(); err != nil {
+			rc.Close()
+		}
+	}()
+	close(w.Wait)
+}
+
+func (w *WaitReadCloser) Read(b []byte) (int, error) {
+	if w.ReadCloser == nil {
+		if <-w.Wait; w.ReadCloser == nil {
+			return 0, io.ErrClosedPipe
+		}
+	}
+	return w.ReadCloser.Read(b)
+}
+
+func (w *WaitReadCloser) Close() error {
+	if w.ReadCloser != nil {
+		return w.ReadCloser.Close()
+	}
+	defer func() {
+		if err := recover(); err != nil {
+			if w.ReadCloser != nil {
+				w.ReadCloser.Close()
+			}
+		}
+	}()
+	close(w.Wait)
+	return nil
+}