浏览代码

mitm: Minor fixes

世界 8 月之前
父节点
当前提交
fb3007fa80
共有 2 个文件被更改,包括 54 次插入38 次删除
  1. 1 1
      experimental/clashapi/mitm.go
  2. 53 37
      mitm/engine.go

+ 1 - 1
experimental/clashapi/mitm.go

@@ -40,7 +40,7 @@ func getMobileConfig(ctx context.Context) http.HandlerFunc {
 		mobileConfig := map[string]interface{}{
 			"PayloadContent": []interface{}{
 				map[string]interface{}{
-					"PayloadCertificateFileName": "Certificate.cer",
+					"PayloadCertificateFileName": "Certificates.cer",
 					"PayloadContent":             certificate.Raw,
 					"PayloadDescription":         "Adds a root certificate",
 					"PayloadDisplayName":         certificate.Subject.CommonName,

+ 53 - 37
mitm/engine.go

@@ -26,6 +26,7 @@ import (
 	"github.com/sagernet/sing/common"
 	"github.com/sagernet/sing/common/atomic"
 	E "github.com/sagernet/sing/common/exceptions"
+	F "github.com/sagernet/sing/common/format"
 	"github.com/sagernet/sing/common/logger"
 	M "github.com/sagernet/sing/common/metadata"
 	N "github.com/sagernet/sing/common/network"
@@ -165,7 +166,7 @@ func (e *Engine) newTLS(ctx context.Context, this N.Dialer, conn net.Conn, metad
 	tlsConn := tls.Server(conn, tlsConfig)
 	err := tlsConn.HandshakeContext(ctx)
 	if err != nil {
-		return E.Cause(err, "TLS handshake")
+		return E.Cause(err, "TLS handshake failed for ", metadata.ClientHello.ServerName, ", ", strings.Join(metadata.ClientHello.SupportedProtos, ", "))
 	}
 	if tlsConn.ConnectionState().NegotiatedProtocol == "h2" {
 		return e.newHTTP2(ctx, this, tlsConn, tlsConfig, metadata, onClose)
@@ -183,7 +184,11 @@ func (e *Engine) newHTTP1(ctx context.Context, this N.Dialer, conn net.Conn, tls
 		return E.Cause(err, "read HTTP request")
 	}
 	rawRequestURL := request.URL
-	rawRequestURL.Scheme = "https"
+	if tlsConfig != nil {
+		rawRequestURL.Scheme = "https"
+	} else {
+		rawRequestURL.Scheme = "http"
+	}
 	if rawRequestURL.Host == "" {
 		rawRequestURL.Host = request.Host
 	}
@@ -482,7 +487,7 @@ func (e *Engine) newHTTP1(ctx context.Context, this N.Dialer, conn net.Conn, tls
 		response.Body = io.NopCloser(bytes.NewReader(responseBody))
 	}
 	if options.Print {
-		e.printResponse(ctx, response, responseBody)
+		e.printResponse(ctx, request, response, responseBody)
 	}
 	if responseScript != nil {
 		if responseBody == nil && responseScript.RequiresBody() && response.ContentLength > 0 && (responseScript.MaxSize() == 0 && response.ContentLength <= 131072 || response.ContentLength <= responseScript.MaxSize()) {
@@ -578,6 +583,22 @@ func (e *Engine) newHTTP1(ctx context.Context, this N.Dialer, conn net.Conn, tls
 }
 
 func (e *Engine) newHTTP2(ctx context.Context, this N.Dialer, conn net.Conn, tlsConfig *tls.Config, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) error {
+	httpTransport := &http.Transport{
+		ForceAttemptHTTP2: true,
+		DialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
+			ctx = adapter.WithContext(ctx, &metadata)
+			if len(metadata.DestinationAddresses) > 0 || metadata.Destination.IsIP() {
+				return dialer.DialSerialNetwork(ctx, this, N.NetworkTCP, metadata.Destination, metadata.DestinationAddresses, metadata.NetworkStrategy, metadata.NetworkType, metadata.FallbackNetworkType, metadata.FallbackDelay)
+			} else {
+				return this.DialContext(ctx, N.NetworkTCP, metadata.Destination)
+			}
+		},
+		TLSClientConfig: tlsConfig,
+	}
+	err := http2.ConfigureTransport(httpTransport)
+	if err != nil {
+		return E.Cause(err, "configure HTTP/2 transport")
+	}
 	handler := &engineHandler{
 		Engine:    e,
 		conn:      conn,
@@ -585,27 +606,7 @@ func (e *Engine) newHTTP2(ctx context.Context, this N.Dialer, conn net.Conn, tls
 		dialer:    this,
 		metadata:  metadata,
 		httpClient: &http.Client{
-			Transport: &http2.Transport{
-				AllowHTTP:        true,
-				MaxReadFrameSize: math.MaxUint32,
-				DialTLSContext: func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) {
-					ctx = adapter.WithContext(ctx, &metadata)
-					var (
-						remoteConn net.Conn
-						err        error
-					)
-					if len(metadata.DestinationAddresses) > 0 || metadata.Destination.IsIP() {
-						remoteConn, err = dialer.DialSerialNetwork(ctx, this, N.NetworkTCP, metadata.Destination, metadata.DestinationAddresses, metadata.NetworkStrategy, metadata.NetworkType, metadata.FallbackNetworkType, metadata.FallbackDelay)
-					} else {
-						remoteConn, err = this.DialContext(ctx, N.NetworkTCP, metadata.Destination)
-					}
-					if err != nil {
-						return nil, err
-					}
-					return tls.Client(remoteConn, cfg), nil
-				},
-				TLSClientConfig: tlsConfig,
-			},
+			Transport: httpTransport,
 			CheckRedirect: func(req *http.Request, via []*http.Request) error {
 				return http.ErrUseLastResponse
 			},
@@ -635,7 +636,6 @@ type engineHandler struct {
 func (e *engineHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
 	err := e.serveHTTP(request.Context(), writer, request)
 	if err != nil {
-		e.conn.Close()
 		if E.IsClosedOrCanceled(err) {
 			e.logger.DebugContext(request.Context(), E.Cause(err, "connection closed"))
 		} else {
@@ -921,7 +921,7 @@ func (e *engineHandler) serveHTTP(ctx context.Context, writer http.ResponseWrite
 		response.Body = io.NopCloser(bytes.NewReader(responseBody))
 	}
 	if options.Print {
-		e.printResponse(ctx, response, responseBody)
+		e.printResponse(ctx, request, response, responseBody)
 	}
 	if responseScript != nil {
 		if responseBody == nil && responseScript.RequiresBody() && response.ContentLength > 0 && (responseScript.MaxSize() == 0 && response.ContentLength <= 131072 || response.ContentLength <= responseScript.MaxSize()) {
@@ -1021,42 +1021,58 @@ func (e *engineHandler) serveHTTP(ctx context.Context, writer http.ResponseWrite
 }
 
 func (e *Engine) printRequest(ctx context.Context, request *http.Request, body []byte) {
-	e.logger.TraceContext(ctx, "request: ", request.Proto, " ", request.Method, " ", request.URL.String())
+	var builder strings.Builder
+	builder.WriteString(F.ToString(request.Proto, " ", request.Method, " ", request.URL))
+	builder.WriteString("\n")
 	if request.URL.Hostname() != "" && request.URL.Hostname() != request.Host {
-		e.logger.TraceContext(ctx, "request: ", "Host: ", request.Host)
+		builder.WriteString("Host: ")
+		builder.WriteString(request.Host)
+		builder.WriteString("\n")
 	}
 	for key, values := range request.Header {
 		for _, value := range values {
-			e.logger.TraceContext(ctx, "request: ", key, ": ", value)
+			builder.WriteString(key)
+			builder.WriteString(": ")
+			builder.WriteString(value)
+			builder.WriteString("\n")
 		}
 	}
 	if len(body) > 0 {
+		builder.WriteString("\n")
 		if !bytes.ContainsFunc(body, func(r rune) bool {
 			return !unicode.IsPrint(r) && !unicode.IsSpace(r)
 		}) {
-			e.logger.TraceContext(ctx, "request: body: ", string(body))
+			builder.Write(body)
 		} else {
-			e.logger.TraceContext(ctx, "request: body unprintable")
+			builder.WriteString("(body not printable)")
 		}
 	}
+	e.logger.InfoContext(ctx, "request: ", builder.String())
 }
 
-func (e *Engine) printResponse(ctx context.Context, response *http.Response, body []byte) {
-	e.logger.TraceContext(ctx, "response: ", response.Proto, " ", response.Status)
+func (e *Engine) printResponse(ctx context.Context, request *http.Request, response *http.Response, body []byte) {
+	var builder strings.Builder
+	builder.WriteString(F.ToString(response.Proto, " ", response.Status, " ", request.URL))
+	builder.WriteString("\n")
 	for key, values := range response.Header {
 		for _, value := range values {
-			e.logger.TraceContext(ctx, "response: ", key, ": ", value)
+			builder.WriteString(key)
+			builder.WriteString(": ")
+			builder.WriteString(value)
+			builder.WriteString("\n")
 		}
 	}
 	if len(body) > 0 {
+		builder.WriteString("\n")
 		if !bytes.ContainsFunc(body, func(r rune) bool {
 			return !unicode.IsPrint(r) && !unicode.IsSpace(r)
 		}) {
-			e.logger.TraceContext(ctx, "response: ", string(body))
+			builder.Write(body)
+		} else {
+			builder.WriteString("(body not printable)")
 		}
-	} else {
-		e.logger.TraceContext(ctx, "response: body unprintable")
 	}
+	e.logger.InfoContext(ctx, "response: ", builder.String())
 }
 
 type simpleResponseWriter struct {