Просмотр исходного кода

logpolicy: split out DialContext into a func

Updates tailscale/corp#10030

Signed-off-by: Maisem Ali <[email protected]>
Maisem Ali 3 лет назад
Родитель
Сommit
7300b908fb
2 измененных файлов с 54 добавлено и 51 удалено
  1. 1 2
      ipn/ipnserver/proxyconnect.go
  2. 53 49
      logpolicy/logpolicy.go

+ 1 - 2
ipn/ipnserver/proxyconnect.go

@@ -37,8 +37,7 @@ func (s *Server) handleProxyConnectConn(w http.ResponseWriter, r *http.Request)
 		return
 	}
 
-	tr := logpolicy.NewLogtailTransport(logHost)
-	back, err := tr.DialContext(ctx, "tcp", hostPort)
+	back, err := logpolicy.DialContext(ctx, "tcp", hostPort)
 	if err != nil {
 		s.logf("error CONNECT dialing %v: %v", hostPort, err)
 		http.Error(w, "Connect failure", http.StatusBadGateway)

+ 53 - 49
logpolicy/logpolicy.go

@@ -667,11 +667,59 @@ func (p *Policy) Shutdown(ctx context.Context) error {
 	return nil
 }
 
-// NewLogtailTransport returns an HTTP Transport particularly suited to uploading
-// logs to the given host name. This includes:
-//   - If DNS lookup fails, consult the bootstrap DNS list of Tailscale hostnames.
+// DialContext is a net.Dialer.DialContext specialized for use by logtail.
+// It does the following:
+//   - If DNS lookup fails, consults the bootstrap DNS list of Tailscale hostnames.
 //   - If TLS connection fails, try again using LetsEncrypt's built-in root certificate,
 //     for the benefit of older OS platforms which might not include it.
+func DialContext(ctx context.Context, netw, addr string) (net.Conn, error) {
+	nd := netns.FromDialer(log.Printf, &net.Dialer{
+		Timeout:   30 * time.Second,
+		KeepAlive: netknob.PlatformTCPKeepAlive(),
+	})
+	t0 := time.Now()
+	c, err := nd.DialContext(ctx, netw, addr)
+	d := time.Since(t0).Round(time.Millisecond)
+	if err == nil {
+		dialLog.Printf("dialed %q in %v", addr, d)
+		return c, nil
+	}
+
+	if version.IsWindowsGUI() && strings.HasPrefix(netw, "tcp") {
+		if c, err := safesocket.Connect(safesocket.DefaultConnectionStrategy("")); err == nil {
+			fmt.Fprintf(c, "CONNECT %s HTTP/1.0\r\n\r\n", addr)
+			br := bufio.NewReader(c)
+			res, err := http.ReadResponse(br, nil)
+			if err == nil && res.StatusCode != 200 {
+				err = errors.New(res.Status)
+			}
+			if err != nil {
+				log.Printf("logtail: CONNECT response error from tailscaled: %v", err)
+				c.Close()
+			} else {
+				dialLog.Printf("connected via tailscaled")
+				return c, nil
+			}
+		}
+	}
+
+	// If we failed to dial, try again with bootstrap DNS.
+	log.Printf("logtail: dial %q failed: %v (in %v), trying bootstrap...", addr, err, d)
+	dnsCache := &dnscache.Resolver{
+		Forward:          dnscache.Get().Forward, // use default cache's forwarder
+		UseLastGood:      true,
+		LookupIPFallback: dnsfallback.Lookup,
+	}
+	dialer := dnscache.Dialer(nd.DialContext, dnsCache)
+	c, err = dialer(ctx, netw, addr)
+	if err == nil {
+		log.Printf("logtail: bootstrap dial succeeded")
+	}
+	return c, err
+}
+
+// NewLogtailTransport returns an HTTP Transport particularly suited to uploading
+// logs to the given host name. See DialContext for details on how it works.
 func NewLogtailTransport(host string) *http.Transport {
 	// Start with a copy of http.DefaultTransport and tweak it a bit.
 	tr := http.DefaultTransport.(*http.Transport).Clone()
@@ -685,51 +733,7 @@ func NewLogtailTransport(host string) *http.Transport {
 	tr.DisableCompression = true
 
 	// Log whenever we dial:
-	tr.DialContext = func(ctx context.Context, netw, addr string) (net.Conn, error) {
-		nd := netns.FromDialer(log.Printf, &net.Dialer{
-			Timeout:   30 * time.Second,
-			KeepAlive: netknob.PlatformTCPKeepAlive(),
-		})
-		t0 := time.Now()
-		c, err := nd.DialContext(ctx, netw, addr)
-		d := time.Since(t0).Round(time.Millisecond)
-		if err == nil {
-			dialLog.Printf("dialed %q in %v", addr, d)
-			return c, nil
-		}
-
-		if version.IsWindowsGUI() && strings.HasPrefix(netw, "tcp") {
-			if c, err := safesocket.Connect(safesocket.DefaultConnectionStrategy("")); err == nil {
-				fmt.Fprintf(c, "CONNECT %s HTTP/1.0\r\n\r\n", addr)
-				br := bufio.NewReader(c)
-				res, err := http.ReadResponse(br, nil)
-				if err == nil && res.StatusCode != 200 {
-					err = errors.New(res.Status)
-				}
-				if err != nil {
-					log.Printf("logtail: CONNECT response error from tailscaled: %v", err)
-					c.Close()
-				} else {
-					dialLog.Printf("connected via tailscaled")
-					return c, nil
-				}
-			}
-		}
-
-		// If we failed to dial, try again with bootstrap DNS.
-		log.Printf("logtail: dial %q failed: %v (in %v), trying bootstrap...", addr, err, d)
-		dnsCache := &dnscache.Resolver{
-			Forward:          dnscache.Get().Forward, // use default cache's forwarder
-			UseLastGood:      true,
-			LookupIPFallback: dnsfallback.Lookup,
-		}
-		dialer := dnscache.Dialer(nd.DialContext, dnsCache)
-		c, err = dialer(ctx, netw, addr)
-		if err == nil {
-			log.Printf("logtail: bootstrap dial succeeded")
-		}
-		return c, err
-	}
+	tr.DialContext = DialContext
 
 	// We're contacting exactly 1 hostname, so the default's 100
 	// max idle conns is very high for our needs. Even 2 is
@@ -762,7 +766,7 @@ func goVersion() string {
 type noopPretendSuccessTransport struct{}
 
 func (noopPretendSuccessTransport) RoundTrip(req *http.Request) (*http.Response, error) {
-	io.ReadAll(req.Body)
+	io.Copy(io.Discard, req.Body)
 	req.Body.Close()
 	return &http.Response{
 		StatusCode: 200,