Browse Source

mitm: Add HTTP2 support

世界 8 months ago
parent
commit
74920b44ac
3 changed files with 420 additions and 10 deletions
  1. 1 1
      box.go
  2. 417 7
      mitm/engine.go
  3. 2 2
      option/mitm.go

+ 1 - 1
box.go

@@ -348,7 +348,7 @@ func New(options Options) (*Box, error) {
 		services = append(services, adapter.NewLifecycleService(ntpService, "ntp service"))
 	}
 	mitmOptions := common.PtrValueOrDefault(options.MITM)
-	var mitmEngine *mitm.Engine
+	var mitmEngine adapter.MITMEngine
 	if mitmOptions.Enabled {
 		engine, err := mitm.NewEngine(ctx, logFactory.NewLogger("mitm"), mitmOptions)
 		if err != nil {

+ 417 - 7
mitm/engine.go

@@ -8,6 +8,7 @@ import (
 	"crypto/x509"
 	"encoding/base64"
 	"io"
+	"math"
 	"mime"
 	"net"
 	"net/http"
@@ -32,6 +33,7 @@ import (
 	"github.com/sagernet/sing/service"
 
 	"golang.org/x/crypto/pkcs12"
+	"golang.org/x/net/http2"
 )
 
 var _ adapter.MITMEngine = (*Engine)(nil)
@@ -51,9 +53,9 @@ type Engine struct {
 
 func NewEngine(ctx context.Context, logger logger.ContextLogger, options option.MITMOptions) (*Engine, error) {
 	engine := &Engine{
-		ctx:    ctx,
-		logger: logger,
-		// http2Enabled: options.HTTP2Enabled,
+		ctx:          ctx,
+		logger:       logger,
+		http2Enabled: options.HTTP2Enabled,
 	}
 	if options.TLSDecryptionOptions != nil && options.TLSDecryptionOptions.Enabled {
 		pfxBytes, err := base64.StdEncoding.DecodeString(options.TLSDecryptionOptions.KeyPair)
@@ -265,7 +267,7 @@ func (e *Engine) newHTTP1(ctx context.Context, this N.Dialer, conn net.Conn, tls
 			}
 			if result.Body != nil {
 				request.Body = io.NopCloser(bytes.NewReader(result.Body))
-				request.ContentLength = int64(len(body))
+				request.ContentLength = int64(len(result.Body))
 			}
 		}
 	}
@@ -421,7 +423,6 @@ func (e *Engine) newHTTP1(ctx context.Context, this N.Dialer, conn net.Conn, tls
 	var innerErr atomic.TypedValue[error]
 	httpClient := &http.Client{
 		Transport: &http.Transport{
-			DisableCompression: true,
 			DialTLSContext: func(ctx context.Context, network, address string) (net.Conn, error) {
 				if tlsConfig != nil {
 					return tls.Client(remoteConn, tlsConfig), nil
@@ -558,8 +559,417 @@ func (e *Engine) newHTTP1(ctx context.Context, this N.Dialer, conn net.Conn, tls
 	return nil
 }
 
-func (e *Engine) newHTTP2(ctx context.Context, this N.Dialer, conn *tls.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) error {
-	// TODO: implement http2 support
+func (e *Engine) newHTTP2(ctx context.Context, this N.Dialer, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) error {
+	handler := &engineHandler{
+		Engine:   e,
+		conn:     conn,
+		dialer:   this,
+		metadata: metadata,
+		httpClient: &http.Client{
+			Transport: &http2.Transport{
+				AllowHTTP:        true,
+				MaxReadFrameSize: math.MaxUint32,
+				DialTLSContext: func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) {
+					ctx = adapter.WithContext(ctx, &metadata)
+					var (
+						remoteConn net.Conn
+						err        error
+					)
+					if len(metadata.DestinationAddresses) > 0 || metadata.Destination.IsIP() {
+						remoteConn, err = dialer.DialSerialNetwork(ctx, this, N.NetworkTCP, metadata.Destination, metadata.DestinationAddresses, metadata.NetworkStrategy, metadata.NetworkType, metadata.FallbackNetworkType, metadata.FallbackDelay)
+					} else {
+						remoteConn, err = this.DialContext(ctx, N.NetworkTCP, metadata.Destination)
+					}
+					if err != nil {
+						return nil, err
+					}
+					return tls.Client(remoteConn, cfg), nil
+				},
+			},
+			CheckRedirect: func(req *http.Request, via []*http.Request) error {
+				return http.ErrUseLastResponse
+			},
+		},
+		onClose: onClose,
+	}
+	http2Server := &http2.Server{
+		MaxReadFrameSize: math.MaxUint32,
+	}
+	http2Server.ServeConn(conn, &http2.ServeConnOpts{
+		Context: ctx,
+		Handler: handler,
+	})
+	return nil
+}
+
+type engineHandler struct {
+	*Engine
+	conn     net.Conn
+	dialer   N.Dialer
+	metadata adapter.InboundContext
+	onClose  N.CloseHandlerFunc
+
+	httpClient *http.Client
+}
+
+func (e *engineHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
+	err := e.serveHTTP(request.Context(), writer, request)
+	if err != nil {
+		if E.IsClosedOrCanceled(err) {
+			e.logger.DebugContext(request.Context(), E.Cause(err, "connection closed"))
+		} else {
+			e.logger.ErrorContext(request.Context(), err)
+		}
+	}
+}
+
+func (e *engineHandler) serveHTTP(ctx context.Context, writer http.ResponseWriter, request *http.Request) error {
+	options := e.metadata.MITM
+	e.metadata.MITM = nil
+	rawRequestURL := request.URL
+	rawRequestURL.Scheme = "https"
+	if rawRequestURL.Host == "" {
+		rawRequestURL.Host = request.Host
+	}
+	requestURL := rawRequestURL.String()
+	request.RequestURI = ""
+	var (
+		requestMatch  bool
+		requestScript adapter.HTTPRequestScript
+	)
+	for _, script := range e.script.Scripts() {
+		if !common.Contains(options.Script, script.Tag()) {
+			continue
+		}
+		httpScript, isHTTP := script.(adapter.HTTPRequestScript)
+		if !isHTTP {
+			_, isHTTP = script.(adapter.HTTPScript)
+			if !isHTTP {
+				e.logger.WarnContext(ctx, "specified script/", script.Type(), "[", script.Tag(), "] is not a HTTP request/response script")
+			}
+			continue
+		}
+		if !httpScript.Match(requestURL) {
+			continue
+		}
+		e.logger.DebugContext(ctx, "match script/", httpScript.Type(), "[", httpScript.Tag(), "]")
+		requestScript = httpScript
+		requestMatch = true
+		break
+	}
+	var err error
+	if requestScript != nil {
+		var body []byte
+		if requestScript.RequiresBody() && request.ContentLength > 0 && (requestScript.MaxSize() == 0 && request.ContentLength <= 131072 || request.ContentLength <= requestScript.MaxSize()) {
+			body, err = io.ReadAll(request.Body)
+			if err != nil {
+				return E.Cause(err, "read HTTP request body")
+			}
+			request.Body.Close()
+			request.Body = io.NopCloser(bytes.NewReader(body))
+		}
+		result, err := requestScript.Run(ctx, request, body)
+		if err != nil {
+			return E.Cause(err, "execute script/", requestScript.Type(), "[", requestScript.Tag(), "]")
+		}
+		if result.Response != nil {
+			if result.Response.Status == 0 {
+				result.Response.Status = http.StatusOK
+			}
+			for key, values := range result.Response.Headers {
+				writer.Header()[key] = values
+			}
+			writer.WriteHeader(result.Response.Status)
+			if result.Response.Body != nil {
+				_, err = writer.Write(result.Response.Body)
+				if err != nil {
+					return E.Cause(err, "write fake response body")
+				}
+			}
+			return nil
+		} else {
+			if result.URL != "" {
+				var newURL *url.URL
+				newURL, err = url.Parse(result.URL)
+				if err != nil {
+					return E.Cause(err, "parse updated request URL")
+				}
+				request.URL = newURL
+				newDestination := M.ParseSocksaddrHostPortStr(newURL.Hostname(), newURL.Port())
+				if newDestination.Port == 0 {
+					newDestination.Port = e.metadata.Destination.Port
+				}
+				e.metadata.Destination = newDestination
+			}
+			for key, values := range result.Headers {
+				request.Header[key] = values
+			}
+			if newHost := result.Headers.Get("Host"); newHost != "" {
+				request.Host = newHost
+				request.Header.Del("Host")
+			}
+			if result.Body != nil {
+				io.Copy(io.Discard, request.Body)
+				request.Body = io.NopCloser(bytes.NewReader(result.Body))
+				request.ContentLength = int64(len(result.Body))
+			}
+		}
+	}
+	if !requestMatch {
+		for i, rule := range options.SurgeURLRewrite {
+			if !rule.Pattern.MatchString(requestURL) {
+				continue
+			}
+			e.logger.DebugContext(ctx, "match url_rewrite[", i, "] => ", rule.String())
+			if rule.Reject {
+				return E.New("request rejected by url_rewrite")
+			} else if rule.Redirect {
+				http.Redirect(writer, request, rule.Destination.String(), http.StatusFound)
+				return nil
+			}
+			requestMatch = true
+			request.URL = rule.Destination
+			newDestination := M.ParseSocksaddrHostPortStr(rule.Destination.Hostname(), rule.Destination.Port())
+			if newDestination.Port == 0 {
+				newDestination.Port = e.metadata.Destination.Port
+			}
+			e.metadata.Destination = newDestination
+			break
+		}
+		for i, rule := range options.SurgeHeaderRewrite {
+			if rule.Response {
+				continue
+			}
+			if !rule.Pattern.MatchString(requestURL) {
+				continue
+			}
+			requestMatch = true
+			e.logger.DebugContext(ctx, "match header_rewrite[", i, "] => ", rule.String())
+			switch {
+			case rule.Add:
+				if strings.ToLower(rule.Key) == "host" {
+					request.Host = rule.Value
+					continue
+				}
+				request.Header.Add(rule.Key, rule.Value)
+			case rule.Delete:
+				request.Header.Del(rule.Key)
+			case rule.Replace:
+				if request.Header.Get(rule.Key) != "" {
+					request.Header.Set(rule.Key, rule.Value)
+				}
+			case rule.ReplaceRegex:
+				if value := request.Header.Get(rule.Key); value != "" {
+					request.Header.Set(rule.Key, rule.Match.ReplaceAllString(value, rule.Value))
+				}
+			}
+		}
+		for i, rule := range options.SurgeBodyRewrite {
+			if rule.Response {
+				continue
+			}
+			if !rule.Pattern.MatchString(requestURL) {
+				continue
+			}
+			requestMatch = true
+			e.logger.DebugContext(ctx, "match body_rewrite[", i, "] => ", rule.String())
+			var body []byte
+			if request.ContentLength <= 0 {
+				e.logger.WarnContext(ctx, "body replace skipped due to non-fixed content length")
+				break
+			} else if request.ContentLength > 131072 {
+				e.logger.WarnContext(ctx, "body replace skipped due to large content length: ", request.ContentLength)
+				break
+			}
+			body, err := io.ReadAll(request.Body)
+			if err != nil {
+				return E.Cause(err, "read HTTP request body")
+			}
+			request.Body.Close()
+			for mi := 0; i < len(rule.Match); i++ {
+				body = rule.Match[mi].ReplaceAll(body, []byte(rule.Replace[i]))
+			}
+			request.Body = io.NopCloser(bytes.NewReader(body))
+			request.ContentLength = int64(len(body))
+		}
+	}
+	if !requestMatch {
+		for i, rule := range options.SurgeMapLocal {
+			if !rule.Pattern.MatchString(requestURL) {
+				continue
+			}
+			requestMatch = true
+			e.logger.DebugContext(ctx, "match map_local[", i, "] => ", rule.String())
+			go func() {
+				io.Copy(io.Discard, request.Body)
+				request.Body.Close()
+			}()
+			var (
+				statusCode = http.StatusOK
+				headers    = make(http.Header)
+				body       []byte
+			)
+			if rule.StatusCode > 0 {
+				statusCode = rule.StatusCode
+			}
+			switch {
+			case rule.File:
+				resource, err := os.ReadFile(rule.Data)
+				if err != nil {
+					return E.Cause(err, "open map local source")
+				}
+				mimeType := mime.TypeByExtension(filepath.Ext(rule.Data))
+				if mimeType == "" {
+					mimeType = "application/octet-stream"
+				}
+				headers.Set("Content-Type", mimeType)
+				body = resource
+			case rule.Text:
+				headers.Set("Content-Type", "text/plain")
+				body = []byte(rule.Data)
+			case rule.TinyGif:
+				headers.Set("Content-Type", "image/gif")
+				body = surgeTinyGif()
+			case rule.Base64:
+				headers.Set("Content-Type", "application/octet-stream")
+				body = rule.Base64Data
+			}
+			for key, values := range headers {
+				writer.Header()[key] = values
+			}
+			writer.WriteHeader(statusCode)
+			_, err = writer.Write(body)
+			if err != nil {
+				return E.Cause(err, "write map local response")
+			}
+			return nil
+		}
+	}
+	requestCtx, cancel := context.WithCancel(ctx)
+	defer cancel()
+	response, err := e.httpClient.Do(request.WithContext(requestCtx))
+	if err != nil {
+		cancel()
+		return E.Cause(err, "exchange request")
+	}
+	var (
+		responseScript adapter.HTTPResponseScript
+		responseMatch  bool
+	)
+	for _, script := range e.script.Scripts() {
+		if !common.Contains(options.Script, script.Tag()) {
+			continue
+		}
+		httpScript, isHTTP := script.(adapter.HTTPResponseScript)
+		if !isHTTP {
+			_, isHTTP = script.(adapter.HTTPScript)
+			if !isHTTP {
+				e.logger.WarnContext(ctx, "specified script/", script.Type(), "[", script.Tag(), "] is not a HTTP request/response script")
+			}
+			continue
+		}
+		if !httpScript.Match(requestURL) {
+			continue
+		}
+		e.logger.DebugContext(ctx, "match script/", httpScript.Type(), "[", httpScript.Tag(), "]")
+		responseScript = httpScript
+		responseMatch = true
+		break
+	}
+	if responseScript != nil {
+		var body []byte
+		if responseScript.RequiresBody() && response.ContentLength > 0 && (responseScript.MaxSize() == 0 && response.ContentLength <= 131072 || response.ContentLength <= responseScript.MaxSize()) {
+			body, err = io.ReadAll(response.Body)
+			if err != nil {
+				return E.Cause(err, "read HTTP response body")
+			}
+			response.Body.Close()
+			response.Body = io.NopCloser(bytes.NewReader(body))
+		}
+		var result *adapter.HTTPResponseScriptResult
+		result, err = responseScript.Run(ctx, request, response, body)
+		if err != nil {
+			return E.Cause(err, "execute script/", responseScript.Type(), "[", responseScript.Tag(), "]")
+		}
+		if result.Status > 0 {
+			response.Status = http.StatusText(result.Status)
+			response.StatusCode = result.Status
+		}
+		for key, values := range result.Headers {
+			response.Header[key] = values
+		}
+		if result.Body != nil {
+			response.Body.Close()
+			response.Body = io.NopCloser(bytes.NewReader(result.Body))
+			response.ContentLength = int64(len(result.Body))
+		}
+	}
+	if !responseMatch {
+		for i, rule := range options.SurgeHeaderRewrite {
+			if !rule.Response {
+				continue
+			}
+			if !rule.Pattern.MatchString(requestURL) {
+				continue
+			}
+			responseMatch = true
+			e.logger.DebugContext(ctx, "match header_rewrite[", i, "] => ", rule.String())
+			switch {
+			case rule.Add:
+				response.Header.Add(rule.Key, rule.Value)
+			case rule.Delete:
+				response.Header.Del(rule.Key)
+			case rule.Replace:
+				if response.Header.Get(rule.Key) != "" {
+					response.Header.Set(rule.Key, rule.Value)
+				}
+			case rule.ReplaceRegex:
+				if value := response.Header.Get(rule.Key); value != "" {
+					response.Header.Set(rule.Key, rule.Match.ReplaceAllString(value, rule.Value))
+				}
+			}
+		}
+		for i, rule := range options.SurgeBodyRewrite {
+			if !rule.Response {
+				continue
+			}
+			if !rule.Pattern.MatchString(requestURL) {
+				continue
+			}
+			responseMatch = true
+			e.logger.DebugContext(ctx, "match body_rewrite[", i, "] => ", rule.String())
+			var body []byte
+			if response.ContentLength <= 0 {
+				e.logger.WarnContext(ctx, "body replace skipped due to non-fixed content length")
+				break
+			} else if response.ContentLength > 131072 {
+				e.logger.WarnContext(ctx, "body replace skipped due to large content length: ", request.ContentLength)
+				break
+			}
+			body, err = io.ReadAll(response.Body)
+			if err != nil {
+				return E.Cause(err, "read HTTP request body")
+			}
+			response.Body.Close()
+			for mi := 0; i < len(rule.Match); i++ {
+				body = rule.Match[mi].ReplaceAll(body, []byte(rule.Replace[i]))
+			}
+			response.Body = io.NopCloser(bytes.NewReader(body))
+			response.ContentLength = int64(len(body))
+		}
+	}
+	if !requestMatch && !responseMatch {
+		e.logger.WarnContext(ctx, "request not modified")
+	}
+	for key, values := range request.Header {
+		writer.Header()[key] = values
+	}
+	writer.WriteHeader(response.StatusCode)
+	_, err = io.Copy(writer, response.Body)
+	response.Body.Close()
+	if err != nil {
+		return E.Cause(err, "write HTTP response")
+	}
 	return nil
 }
 

+ 2 - 2
option/mitm.go

@@ -5,8 +5,8 @@ import (
 )
 
 type MITMOptions struct {
-	Enabled bool `json:"enabled,omitempty"`
-	// HTTP2Enabled         bool                  `json:"http2_enabled,omitempty"`
+	Enabled              bool                  `json:"enabled,omitempty"`
+	HTTP2Enabled         bool                  `json:"http2_enabled,omitempty"`
 	TLSDecryptionOptions *TLSDecryptionOptions `json:"tls_decryption,omitempty"`
 }