瀏覽代碼

fix: block on tcp loop

zu1k 3 年之前
父節點
當前提交
18a6cd4531
共有 5 個文件被更改,包括 41 次插入78 次删除
  1. 18 49
      crates/core/src/lib.rs
  2. 18 7
      crates/core/src/mitm.rs
  3. 1 0
      docs/_sidebar.md
  4. 3 3
      docs/guide/transparent_proxy.md
  5. 1 19
      src/main.rs

+ 18 - 49
crates/core/src/lib.rs

@@ -1,11 +1,6 @@
 use error::Error;
 use handler::{CustomContextData, HttpHandler, MitmFilter};
 use http_client::gen_client;
-use hyper::{
-    server::conn::AddrStream,
-    service::{make_service_fn, service_fn},
-    Server,
-};
 use hyper_proxy::Proxy as UpstreamProxy;
 use mitm::MitmProxy;
 use std::{future::Future, marker::PhantomData, net::SocketAddr, sync::Arc};
@@ -54,64 +49,38 @@ where
     pub async fn start_proxy(self) -> Result<(), Error> {
         let client = gen_client(self.upstream_proxy)?;
         let ca = Arc::new(self.ca);
-
         let http_handler = Arc::new(self.handler);
         let mitm_filter = Arc::new(MitmFilter::new(self.mitm_filters));
 
-        let make_service = make_service_fn(move |_conn: &AddrStream| {
+        let tcp_listener = TcpListener::bind(self.listen_addr).await?;
+        loop {
             let client = client.clone();
             let ca = Arc::clone(&ca);
             let http_handler = Arc::clone(&http_handler);
             let mitm_filter = Arc::clone(&mitm_filter);
 
-            // TODO: conn tls or http?
-
-            async move {
-                Ok::<_, Error>(service_fn(move |req| {
-                    MitmProxy {
-                        ca: Arc::clone(&ca),
+            if let Ok((tcp_stream, _)) = tcp_listener.accept().await {
+                tokio::spawn(async move {
+                    let mitm_proxy = MitmProxy {
+                        ca: ca.clone(),
                         client: client.clone(),
-
                         http_handler: Arc::clone(&http_handler),
                         mitm_filter: Arc::clone(&mitm_filter),
-
                         custom_contex_data: Default::default(),
+                    };
+
+                    let mut tls_content_type = [0; 1];
+                    if tcp_stream.peek(&mut tls_content_type).await.is_ok() {
+                        if tls_content_type[0] <= 0x40 {
+                            // ASCII < 'A', assuming tls
+                            mitm_proxy.serve_tls(tcp_stream).await;
+                        } else {
+                            // assuming http
+                            _ = mitm_proxy.serve_stream(tcp_stream).await;
+                        }
                     }
-                    .proxy(req)
-                }))
-            }
-        });
-
-        Server::bind(&self.listen_addr)
-            .http1_preserve_header_case(true)
-            .http1_title_case_headers(true)
-            .serve(make_service)
-            .with_graceful_shutdown(self.shutdown_signal)
-            .await
-            .map_err(Error::from)
-    }
-
-    pub async fn start_https_transparent_proxy(self) -> Result<(), Error> {
-        let client = gen_client(self.upstream_proxy)?;
-        let ca = Arc::new(self.ca);
-        let http_handler = Arc::new(self.handler);
-        let mitm_filter = Arc::new(MitmFilter::new(self.mitm_filters));
-
-        let tcp_listener = TcpListener::bind(self.listen_addr).await?;
-
-        loop {
-            let (tcp_stream, _) = tcp_listener.accept().await?;
-            MitmProxy {
-                ca: Arc::clone(&ca),
-                client: client.clone(),
-
-                http_handler: Arc::clone(&http_handler),
-                mitm_filter: Arc::clone(&mitm_filter),
-
-                custom_contex_data: Default::default(),
+                });
             }
-            .serve_tls(tcp_stream)
-            .await;
         }
     }
 }

+ 18 - 7
crates/core/src/mitm.rs

@@ -52,7 +52,10 @@ where
     H: HttpHandler<D>,
     D: CustomContextData,
 {
-    pub(crate) async fn proxy(self, req: Request<Body>) -> Result<Response<Body>, hyper::Error> {
+    pub(crate) async fn proxy_req(
+        self,
+        req: Request<Body>,
+    ) -> Result<Response<Body>, hyper::Error> {
         let res = if req.method() == Method::CONNECT {
             self.process_connect(req).await
         } else {
@@ -177,7 +180,16 @@ where
 
         match TlsAcceptor::from(server_config).accept(stream).await {
             Ok(stream) => {
-                if let Err(e) = self.serve_stream(stream, Scheme::HTTPS).await {
+                if let Err(e) = Http::new()
+                    .http1_preserve_header_case(true)
+                    .http1_title_case_headers(true)
+                    .serve_connection(
+                        stream,
+                        service_fn(|req| self.clone().process_request(req, Scheme::HTTPS)),
+                    )
+                    .with_upgrades()
+                    .await
+                {
                     let e_string = e.to_string();
                     if !e_string.starts_with("error shutting down connection") {
                         debug!("res:: {}", e);
@@ -190,15 +202,14 @@ where
         }
     }
 
-    async fn serve_stream<S>(self, stream: S, scheme: Scheme) -> Result<(), hyper::Error>
+    pub async fn serve_stream<S>(self, stream: S) -> Result<(), hyper::Error>
     where
         S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
     {
         Http::new()
-            .serve_connection(
-                stream,
-                service_fn(|req| self.clone().process_request(req, scheme.clone())),
-            )
+            .http1_preserve_header_case(true)
+            .http1_title_case_headers(true)
+            .serve_connection(stream, service_fn(|req| self.clone().proxy_req(req)))
             .with_upgrades()
             .await
     }

+ 1 - 0
docs/_sidebar.md

@@ -1,6 +1,7 @@
 * [介绍](/)
 * [指南](guide/README.md)
   * [证书准备](guide/0_cert.md)
+  * [透明代理](guide/transparent_proxy.md)
 * [Rule 规则](rule/README.md)
   * [Filter 筛选器](rule/filter.md)
   * [Action 动作](rule/action.md)

+ 3 - 3
docs/guide/transparent_proxy.md

@@ -8,8 +8,8 @@ sudo sysctl -w net.ipv6.conf.all.forwarding=1
 sudo sysctl -w net.ipv4.conf.all.send_redirects=0
 
 sudo useradd --create-home mitm
-sudo -u mitm -H bash -c 'good-mitm run -r rules/log.yaml -b 0.0.0.0:8080'
+sudo -u mitm -H bash -c 'good-mitm run -r rules/log.yaml -b 0.0.0.0:34567'
 
-iptables -t nat -A OUTPUT -p tcp -m owner ! --uid-owner mitm --dport 80 -j REDIRECT --to-port 8080
-iptables -t nat -A OUTPUT -p tcp -m owner ! --uid-owner mitm --dport 443 -j REDIRECT --to-port 8081
+iptables -t nat -A OUTPUT -p tcp -m owner ! --uid-owner mitm --dport 80 -j REDIRECT --to-port 34567
+iptables -t nat -A OUTPUT -p tcp -m owner ! --uid-owner mitm --dport 443 -j REDIRECT --to-port 34567
 ```

+ 1 - 19
src/main.rs

@@ -7,7 +7,7 @@ use log::*;
 use mitm_core::{CertificateAuthority, Proxy};
 use rule::RuleHttpHandler;
 use rustls_pemfile as pemfile;
-use std::{fs, net::SocketAddr, sync::Arc};
+use std::{fs, sync::Arc};
 
 use good_mitm::*;
 
@@ -111,24 +111,6 @@ async fn run(opts: &Run) -> Result<()> {
 
     tokio::spawn(proxy.start_proxy());
 
-    let mut bind: SocketAddr = opts.bind.parse().expect("bind address not valid!");
-    bind.set_port(bind.port() + 1);
-    info!("Https Transparent Proxy listen on: {}", bind);
-    let proxy2 = Proxy::builder()
-        .ca(ca)
-        .listen_addr(bind)
-        .upstream_proxy(
-            opts.proxy
-                .clone()
-                .map(|proxy| hyper_proxy::Proxy::new(Intercept::All, proxy.parse().unwrap())),
-        )
-        .shutdown_signal(shutdown_signal())
-        .mitm_filters(mitm_filters)
-        .handler(http_handler)
-        .build();
-
-    tokio::spawn(proxy2.start_https_transparent_proxy());
-
     tokio::signal::ctrl_c()
         .await
         .expect("failed to listen for event");