Browse Source

mitm: Fix HTTP2 support & Add print

世界 9 months ago
parent
commit
1fe983a81b
2 changed files with 179 additions and 96 deletions
  1. 178 96
      mitm/engine.go
  2. 1 0
      option/mitm.go

+ 178 - 96
mitm/engine.go

@@ -17,6 +17,7 @@ import (
 	"path/filepath"
 	"strings"
 	"time"
+	"unicode"
 
 	"github.com/sagernet/sing-box/adapter"
 	"github.com/sagernet/sing-box/common/dialer"
@@ -124,6 +125,7 @@ func (e *Engine) newTLS(ctx context.Context, this N.Dialer, conn net.Conn, metad
 	acceptHTTP := len(metadata.ClientHello.SupportedProtos) == 0 || common.Contains(metadata.ClientHello.SupportedProtos, "http/1.1")
 	acceptH2 := e.http2Enabled && common.Contains(metadata.ClientHello.SupportedProtos, "h2")
 	if !acceptHTTP && !acceptH2 {
+		metadata.MITM = nil
 		e.logger.DebugContext(ctx, "unsupported application protocol: ", strings.Join(metadata.ClientHello.SupportedProtos, ","))
 		e.connection.NewConnection(ctx, this, conn, metadata, onClose)
 		return nil
@@ -147,12 +149,11 @@ func (e *Engine) newTLS(ctx context.Context, this N.Dialer, conn net.Conn, metad
 		serverName = metadata.Destination.Addr.String()
 	}
 	tlsConfig := &tls.Config{
-		Time:             e.timeFunc,
-		CipherSuites:     metadata.ClientHello.CipherSuites,
-		ServerName:       serverName,
-		CurvePreferences: metadata.ClientHello.SupportedCurves,
-		NextProtos:       nextProtos,
-		MinVersion:       minVersion,
+		Time:       e.timeFunc,
+		ServerName: serverName,
+		NextProtos: nextProtos,
+		MinVersion: minVersion,
+		MaxVersion: maxVersion,
 		GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
 			return sTLS.GenerateKeyPair(e.tlsCertificate, e.tlsPrivateKey, e.timeFunc, serverName)
 		},
@@ -163,7 +164,7 @@ func (e *Engine) newTLS(ctx context.Context, this N.Dialer, conn net.Conn, metad
 		return E.Cause(err, "TLS handshake")
 	}
 	if tlsConn.ConnectionState().NegotiatedProtocol == "h2" {
-		return e.newHTTP2(ctx, this, tlsConn, metadata, onClose)
+		return e.newHTTP2(ctx, this, tlsConn, tlsConfig, metadata, onClose)
 	} else {
 		return e.newHTTP1(ctx, this, tlsConn, tlsConfig, metadata)
 	}
@@ -171,7 +172,6 @@ func (e *Engine) newTLS(ctx context.Context, this N.Dialer, conn net.Conn, metad
 
 func (e *Engine) newHTTP1(ctx context.Context, this N.Dialer, conn net.Conn, tlsConfig *tls.Config, metadata adapter.InboundContext) error {
 	options := metadata.MITM
-	metadata.MITM = nil
 	defer conn.Close()
 	reader := bufio.NewReader(conn)
 	request, err := sHTTP.ReadRequest(reader)
@@ -209,9 +209,19 @@ func (e *Engine) newHTTP1(ctx context.Context, this N.Dialer, conn net.Conn, tls
 		requestMatch = true
 		break
 	}
+	var body []byte
+	if options.Print && request.ContentLength > 0 && request.ContentLength <= 131072 {
+		body, err = io.ReadAll(request.Body)
+		if err != nil {
+			return E.Cause(err, "read HTTP request body")
+		}
+		request.Body = io.NopCloser(bytes.NewReader(body))
+	}
+	if options.Print {
+		e.printRequest(ctx, request, body)
+	}
 	if requestScript != nil {
-		var body []byte
-		if requestScript.RequiresBody() && request.ContentLength > 0 && (requestScript.MaxSize() == 0 && request.ContentLength <= 131072 || request.ContentLength <= requestScript.MaxSize()) {
+		if body == nil && requestScript.RequiresBody() && request.ContentLength > 0 && (requestScript.MaxSize() == 0 && request.ContentLength <= 131072 || request.ContentLength <= requestScript.MaxSize()) {
 			body, err = io.ReadAll(request.Body)
 			if err != nil {
 				return E.Cause(err, "read HTTP request body")
@@ -266,8 +276,9 @@ func (e *Engine) newHTTP1(ctx context.Context, this N.Dialer, conn net.Conn, tls
 				request.Header.Del("Host")
 			}
 			if result.Body != nil {
-				request.Body = io.NopCloser(bytes.NewReader(result.Body))
-				request.ContentLength = int64(len(result.Body))
+				body = result.Body
+				request.Body = io.NopCloser(bytes.NewReader(body))
+				request.ContentLength = int64(len(body))
 			}
 		}
 	}
@@ -337,17 +348,18 @@ func (e *Engine) newHTTP1(ctx context.Context, this N.Dialer, conn net.Conn, tls
 			}
 			requestMatch = true
 			e.logger.DebugContext(ctx, "match body_rewrite[", i, "] => ", rule.String())
-			var body []byte
-			if request.ContentLength <= 0 {
-				e.logger.WarnContext(ctx, "body replace skipped due to non-fixed content length")
-				break
-			} else if request.ContentLength > 131072 {
-				e.logger.WarnContext(ctx, "body replace skipped due to large content length: ", request.ContentLength)
-				break
-			}
-			body, err = io.ReadAll(request.Body)
-			if err != nil {
-				return E.Cause(err, "read HTTP request body")
+			if body == nil {
+				if request.ContentLength <= 0 {
+					e.logger.WarnContext(ctx, "body replace skipped due to non-fixed content length")
+					break
+				} else if request.ContentLength > 131072 {
+					e.logger.WarnContext(ctx, "body replace skipped due to large content length: ", request.ContentLength)
+					break
+				}
+				body, err = io.ReadAll(request.Body)
+				if err != nil {
+					return E.Cause(err, "read HTTP request body")
+				}
 			}
 			for mi := 0; i < len(rule.Match); i++ {
 				body = rule.Match[mi].ReplaceAll(body, []byte(rule.Replace[i]))
@@ -366,7 +378,6 @@ func (e *Engine) newHTTP1(ctx context.Context, this N.Dialer, conn net.Conn, tls
 			var (
 				statusCode = http.StatusOK
 				headers    = make(http.Header)
-				body       []byte
 			)
 			if rule.StatusCode > 0 {
 				statusCode = rule.StatusCode
@@ -410,26 +421,17 @@ func (e *Engine) newHTTP1(ctx context.Context, this N.Dialer, conn net.Conn, tls
 		}
 	}
 	ctx = adapter.WithContext(ctx, &metadata)
-	var remoteConn net.Conn
-	if len(metadata.DestinationAddresses) > 0 || metadata.Destination.IsIP() {
-		remoteConn, err = dialer.DialSerialNetwork(ctx, this, N.NetworkTCP, metadata.Destination, metadata.DestinationAddresses, metadata.NetworkStrategy, metadata.NetworkType, metadata.FallbackNetworkType, metadata.FallbackDelay)
-	} else {
-		remoteConn, err = this.DialContext(ctx, N.NetworkTCP, metadata.Destination)
-	}
-	if err != nil {
-		return E.Cause(err, "open outbound connection")
-	}
-	defer remoteConn.Close()
 	var innerErr atomic.TypedValue[error]
 	httpClient := &http.Client{
 		Transport: &http.Transport{
-			DialTLSContext: func(ctx context.Context, network, address string) (net.Conn, error) {
-				if tlsConfig != nil {
-					return tls.Client(remoteConn, tlsConfig), nil
+			DialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
+				if len(metadata.DestinationAddresses) > 0 || metadata.Destination.IsIP() {
+					return dialer.DialSerialNetwork(ctx, this, N.NetworkTCP, metadata.Destination, metadata.DestinationAddresses, metadata.NetworkStrategy, metadata.NetworkType, metadata.FallbackNetworkType, metadata.FallbackDelay)
 				} else {
-					return remoteConn, nil
+					return this.DialContext(ctx, N.NetworkTCP, metadata.Destination)
 				}
 			},
+			TLSClientConfig: tlsConfig,
 		},
 		CheckRedirect: func(req *http.Request, via []*http.Request) error {
 			return http.ErrUseLastResponse
@@ -467,17 +469,27 @@ func (e *Engine) newHTTP1(ctx context.Context, this N.Dialer, conn net.Conn, tls
 		responseMatch = true
 		break
 	}
+	var responseBody []byte
+	if options.Print && response.ContentLength > 0 && response.ContentLength <= 131072 {
+		responseBody, err = io.ReadAll(response.Body)
+		if err != nil {
+			return E.Cause(err, "read HTTP response body")
+		}
+		response.Body = io.NopCloser(bytes.NewReader(responseBody))
+	}
+	if options.Print {
+		e.printResponse(ctx, response, responseBody)
+	}
 	if responseScript != nil {
-		var body []byte
-		if responseScript.RequiresBody() && response.ContentLength > 0 && (responseScript.MaxSize() == 0 && response.ContentLength <= 131072 || response.ContentLength <= responseScript.MaxSize()) {
-			body, err = io.ReadAll(response.Body)
+		if responseBody == nil && responseScript.RequiresBody() && response.ContentLength > 0 && (responseScript.MaxSize() == 0 && response.ContentLength <= 131072 || response.ContentLength <= responseScript.MaxSize()) {
+			responseBody, err = io.ReadAll(response.Body)
 			if err != nil {
 				return E.Cause(err, "read HTTP response body")
 			}
-			response.Body = io.NopCloser(bytes.NewReader(body))
+			response.Body = io.NopCloser(bytes.NewReader(responseBody))
 		}
 		var result *adapter.HTTPResponseScriptResult
-		result, err = responseScript.Run(ctx, request, response, body)
+		result, err = responseScript.Run(ctx, request, response, responseBody)
 		if err != nil {
 			return E.Cause(err, "execute script/", responseScript.Type(), "[", responseScript.Tag(), "]")
 		}
@@ -490,8 +502,9 @@ func (e *Engine) newHTTP1(ctx context.Context, this N.Dialer, conn net.Conn, tls
 		}
 		if result.Body != nil {
 			response.Body.Close()
-			response.Body = io.NopCloser(bytes.NewReader(result.Body))
-			response.ContentLength = int64(len(result.Body))
+			responseBody = result.Body
+			response.Body = io.NopCloser(bytes.NewReader(responseBody))
+			response.ContentLength = int64(len(responseBody))
 		}
 	}
 	if !responseMatch {
@@ -528,26 +541,27 @@ func (e *Engine) newHTTP1(ctx context.Context, this N.Dialer, conn net.Conn, tls
 			}
 			responseMatch = true
 			e.logger.DebugContext(ctx, "match body_rewrite[", i, "] => ", rule.String())
-			var body []byte
-			if response.ContentLength <= 0 {
-				e.logger.WarnContext(ctx, "body replace skipped due to non-fixed content length")
-				break
-			} else if response.ContentLength > 131072 {
-				e.logger.WarnContext(ctx, "body replace skipped due to large content length: ", request.ContentLength)
-				break
-			}
-			body, err = io.ReadAll(response.Body)
-			if err != nil {
-				return E.Cause(err, "read HTTP request body")
+			if responseBody == nil {
+				if response.ContentLength <= 0 {
+					e.logger.WarnContext(ctx, "body replace skipped due to non-fixed content length")
+					break
+				} else if response.ContentLength > 131072 {
+					e.logger.WarnContext(ctx, "body replace skipped due to large content length: ", request.ContentLength)
+					break
+				}
+				responseBody, err = io.ReadAll(response.Body)
+				if err != nil {
+					return E.Cause(err, "read HTTP request body")
+				}
 			}
 			for mi := 0; i < len(rule.Match); i++ {
-				body = rule.Match[mi].ReplaceAll(body, []byte(rule.Replace[i]))
+				responseBody = rule.Match[mi].ReplaceAll(responseBody, []byte(rule.Replace[i]))
 			}
-			response.Body = io.NopCloser(bytes.NewReader(body))
-			response.ContentLength = int64(len(body))
+			response.Body = io.NopCloser(bytes.NewReader(responseBody))
+			response.ContentLength = int64(len(responseBody))
 		}
 	}
-	if !requestMatch && !responseMatch {
+	if !options.Print && !requestMatch && !responseMatch {
 		e.logger.WarnContext(ctx, "request not modified")
 	}
 	err = response.Write(conn)
@@ -559,12 +573,13 @@ func (e *Engine) newHTTP1(ctx context.Context, this N.Dialer, conn net.Conn, tls
 	return nil
 }
 
-func (e *Engine) newHTTP2(ctx context.Context, this N.Dialer, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) error {
+func (e *Engine) newHTTP2(ctx context.Context, this N.Dialer, conn net.Conn, tlsConfig *tls.Config, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) error {
 	handler := &engineHandler{
-		Engine:   e,
-		conn:     conn,
-		dialer:   this,
-		metadata: metadata,
+		Engine:    e,
+		conn:      conn,
+		tlsConfig: tlsConfig,
+		dialer:    this,
+		metadata:  metadata,
 		httpClient: &http.Client{
 			Transport: &http2.Transport{
 				AllowHTTP:        true,
@@ -585,6 +600,7 @@ func (e *Engine) newHTTP2(ctx context.Context, this N.Dialer, conn net.Conn, met
 					}
 					return tls.Client(remoteConn, cfg), nil
 				},
+				TLSClientConfig: tlsConfig,
 			},
 			CheckRedirect: func(req *http.Request, via []*http.Request) error {
 				return http.ErrUseLastResponse
@@ -604,17 +620,18 @@ func (e *Engine) newHTTP2(ctx context.Context, this N.Dialer, conn net.Conn, met
 
 type engineHandler struct {
 	*Engine
-	conn     net.Conn
-	dialer   N.Dialer
-	metadata adapter.InboundContext
-	onClose  N.CloseHandlerFunc
-
+	conn       net.Conn
+	tlsConfig  *tls.Config
+	dialer     N.Dialer
+	metadata   adapter.InboundContext
+	onClose    N.CloseHandlerFunc
 	httpClient *http.Client
 }
 
 func (e *engineHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
 	err := e.serveHTTP(request.Context(), writer, request)
 	if err != nil {
+		e.conn.Close()
 		if E.IsClosedOrCanceled(err) {
 			e.logger.DebugContext(request.Context(), E.Cause(err, "connection closed"))
 		} else {
@@ -625,7 +642,6 @@ func (e *engineHandler) ServeHTTP(writer http.ResponseWriter, request *http.Requ
 
 func (e *engineHandler) serveHTTP(ctx context.Context, writer http.ResponseWriter, request *http.Request) error {
 	options := e.metadata.MITM
-	e.metadata.MITM = nil
 	rawRequestURL := request.URL
 	rawRequestURL.Scheme = "https"
 	if rawRequestURL.Host == "" {
@@ -657,10 +673,23 @@ func (e *engineHandler) serveHTTP(ctx context.Context, writer http.ResponseWrite
 		requestMatch = true
 		break
 	}
-	var err error
+	var (
+		body []byte
+		err  error
+	)
+	if options.Print && request.ContentLength > 0 && request.ContentLength <= 131072 {
+		body, err = io.ReadAll(request.Body)
+		if err != nil {
+			return E.Cause(err, "read HTTP request body")
+		}
+		request.Body.Close()
+		request.Body = io.NopCloser(bytes.NewReader(body))
+	}
+	if options.Print {
+		e.printRequest(ctx, request, body)
+	}
 	if requestScript != nil {
-		var body []byte
-		if requestScript.RequiresBody() && request.ContentLength > 0 && (requestScript.MaxSize() == 0 && request.ContentLength <= 131072 || request.ContentLength <= requestScript.MaxSize()) {
+		if body == nil && requestScript.RequiresBody() && request.ContentLength > 0 && (requestScript.MaxSize() == 0 && request.ContentLength <= 131072 || request.ContentLength <= requestScript.MaxSize()) {
 			body, err = io.ReadAll(request.Body)
 			if err != nil {
 				return E.Cause(err, "read HTTP request body")
@@ -700,6 +729,7 @@ func (e *engineHandler) serveHTTP(ctx context.Context, writer http.ResponseWrite
 					newDestination.Port = e.metadata.Destination.Port
 				}
 				e.metadata.Destination = newDestination
+				e.tlsConfig.ServerName = newURL.Hostname()
 			}
 			for key, values := range result.Headers {
 				request.Header[key] = values
@@ -734,6 +764,7 @@ func (e *engineHandler) serveHTTP(ctx context.Context, writer http.ResponseWrite
 				newDestination.Port = e.metadata.Destination.Port
 			}
 			e.metadata.Destination = newDestination
+			e.tlsConfig.ServerName = rule.Destination.Hostname()
 			break
 		}
 		for i, rule := range options.SurgeHeaderRewrite {
@@ -876,18 +907,29 @@ func (e *engineHandler) serveHTTP(ctx context.Context, writer http.ResponseWrite
 		responseMatch = true
 		break
 	}
+	var responseBody []byte
+	if options.Print && response.ContentLength > 0 && response.ContentLength <= 131072 {
+		responseBody, err = io.ReadAll(response.Body)
+		if err != nil {
+			return E.Cause(err, "read HTTP response body")
+		}
+		response.Body.Close()
+		response.Body = io.NopCloser(bytes.NewReader(responseBody))
+	}
+	if options.Print {
+		e.printResponse(ctx, response, responseBody)
+	}
 	if responseScript != nil {
-		var body []byte
-		if responseScript.RequiresBody() && response.ContentLength > 0 && (responseScript.MaxSize() == 0 && response.ContentLength <= 131072 || response.ContentLength <= responseScript.MaxSize()) {
-			body, err = io.ReadAll(response.Body)
+		if responseBody == nil && responseScript.RequiresBody() && response.ContentLength > 0 && (responseScript.MaxSize() == 0 && response.ContentLength <= 131072 || response.ContentLength <= responseScript.MaxSize()) {
+			responseBody, err = io.ReadAll(response.Body)
 			if err != nil {
 				return E.Cause(err, "read HTTP response body")
 			}
 			response.Body.Close()
-			response.Body = io.NopCloser(bytes.NewReader(body))
+			response.Body = io.NopCloser(bytes.NewReader(responseBody))
 		}
 		var result *adapter.HTTPResponseScriptResult
-		result, err = responseScript.Run(ctx, request, response, body)
+		result, err = responseScript.Run(ctx, request, response, responseBody)
 		if err != nil {
 			return E.Cause(err, "execute script/", responseScript.Type(), "[", responseScript.Tag(), "]")
 		}
@@ -938,30 +980,31 @@ func (e *engineHandler) serveHTTP(ctx context.Context, writer http.ResponseWrite
 			}
 			responseMatch = true
 			e.logger.DebugContext(ctx, "match body_rewrite[", i, "] => ", rule.String())
-			var body []byte
-			if response.ContentLength <= 0 {
-				e.logger.WarnContext(ctx, "body replace skipped due to non-fixed content length")
-				break
-			} else if response.ContentLength > 131072 {
-				e.logger.WarnContext(ctx, "body replace skipped due to large content length: ", request.ContentLength)
-				break
-			}
-			body, err = io.ReadAll(response.Body)
-			if err != nil {
-				return E.Cause(err, "read HTTP request body")
+			if responseBody == nil {
+				if response.ContentLength <= 0 {
+					e.logger.WarnContext(ctx, "body replace skipped due to non-fixed content length")
+					break
+				} else if response.ContentLength > 131072 {
+					e.logger.WarnContext(ctx, "body replace skipped due to large content length: ", request.ContentLength)
+					break
+				}
+				responseBody, err = io.ReadAll(response.Body)
+				if err != nil {
+					return E.Cause(err, "read HTTP request body")
+				}
+				response.Body.Close()
 			}
-			response.Body.Close()
 			for mi := 0; i < len(rule.Match); i++ {
-				body = rule.Match[mi].ReplaceAll(body, []byte(rule.Replace[i]))
+				responseBody = rule.Match[mi].ReplaceAll(responseBody, []byte(rule.Replace[i]))
 			}
-			response.Body = io.NopCloser(bytes.NewReader(body))
-			response.ContentLength = int64(len(body))
+			response.Body = io.NopCloser(bytes.NewReader(responseBody))
+			response.ContentLength = int64(len(responseBody))
 		}
 	}
-	if !requestMatch && !responseMatch {
+	if !options.Print && !requestMatch && !responseMatch {
 		e.logger.WarnContext(ctx, "request not modified")
 	}
-	for key, values := range request.Header {
+	for key, values := range response.Header {
 		writer.Header()[key] = values
 	}
 	writer.WriteHeader(response.StatusCode)
@@ -973,6 +1016,45 @@ func (e *engineHandler) serveHTTP(ctx context.Context, writer http.ResponseWrite
 	return nil
 }
 
+func (e *Engine) printRequest(ctx context.Context, request *http.Request, body []byte) {
+	e.logger.TraceContext(ctx, "request: ", request.Proto, " ", request.Method, " ", request.URL.String())
+	if request.URL.Hostname() != "" && request.URL.Hostname() != request.Host {
+		e.logger.TraceContext(ctx, "request: ", "Host: ", request.Host)
+	}
+	for key, values := range request.Header {
+		for _, value := range values {
+			e.logger.TraceContext(ctx, "request: ", key, ": ", value)
+		}
+	}
+	if len(body) > 0 {
+		if !bytes.ContainsFunc(body, func(r rune) bool {
+			return !unicode.IsPrint(r) && !unicode.IsSpace(r)
+		}) {
+			e.logger.TraceContext(ctx, "request: body: ", string(body))
+		} else {
+			e.logger.TraceContext(ctx, "request: body unprintable")
+		}
+	}
+}
+
+func (e *Engine) printResponse(ctx context.Context, response *http.Response, body []byte) {
+	e.logger.TraceContext(ctx, "response: ", response.Proto, " ", response.Status)
+	for key, values := range response.Header {
+		for _, value := range values {
+			e.logger.TraceContext(ctx, "response: ", key, ": ", value)
+		}
+	}
+	if len(body) > 0 {
+		if !bytes.ContainsFunc(body, func(r rune) bool {
+			return !unicode.IsPrint(r) && !unicode.IsSpace(r)
+		}) {
+			e.logger.TraceContext(ctx, "response: ", string(body))
+		}
+	} else {
+		e.logger.TraceContext(ctx, "response: body unprintable")
+	}
+}
+
 type simpleResponseWriter struct {
 	statusCode int
 	header     http.Header

+ 1 - 0
option/mitm.go

@@ -18,6 +18,7 @@ type TLSDecryptionOptions struct {
 
 type MITMRouteOptions struct {
 	Enabled            bool                                       `json:"enabled,omitempty"`
+	Print              bool                                       `json:"print,omitempty"`
 	Script             badoption.Listable[string]                 `json:"script,omitempty"`
 	SurgeURLRewrite    badoption.Listable[SurgeURLRewriteLine]    `json:"sg_url_rewrite,omitempty"`
 	SurgeHeaderRewrite badoption.Listable[SurgeHeaderRewriteLine] `json:"sg_header_rewrite,omitempty"`