Przeglądaj źródła

feat: selective MITM based on SNI

zu1k 2 lat temu
rodzic
commit
e0ad482d12

+ 28 - 0
Cargo.lock

@@ -247,6 +247,12 @@ version = "0.6.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "2c676a478f63e9fa2dd5368a42f28bba0d6c560b775f38583c8bbaa7fcd67c9c"
 
+[[package]]
+name = "byteorder"
+version = "1.4.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610"
+
 [[package]]
 name = "bytes"
 version = "1.4.0"
@@ -812,6 +818,7 @@ name = "good-mitm-core"
 version = "0.4.1"
 dependencies = [
  "async-trait",
+ "byteorder",
  "bytes",
  "cfg-if",
  "http",
@@ -822,6 +829,7 @@ dependencies = [
  "log",
  "moka",
  "openssl",
+ "pin-project",
  "rand",
  "rcgen",
  "rustls",
@@ -1445,6 +1453,26 @@ dependencies = [
  "base64 0.13.1",
 ]
 
+[[package]]
+name = "pin-project"
+version = "1.1.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "030ad2bc4db10a8944cb0d837f158bdfec4d4a4873ab701a95046770d11f8842"
+dependencies = [
+ "pin-project-internal",
+]
+
+[[package]]
+name = "pin-project-internal"
+version = "1.1.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ec2e072ecce94ec471b13398d5402c188e76ac03cf74dd1a975161b23a3f6d9c"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn 2.0.22",
+]
+
 [[package]]
 name = "pin-project-lite"
 version = "0.2.9"

+ 2 - 0
crates/core/Cargo.toml

@@ -10,6 +10,7 @@ license = "MIT"
 [dependencies]
 async-trait = "0.1"
 bytes = { version = "1", features = ["serde"] }
+byteorder = "1.4"
 cfg-if = "1"
 http = "0.2"
 hyper = { version = "0.14", features = ["http1", "http2", "server", "stream", "tcp", "runtime"]  }
@@ -19,6 +20,7 @@ hyper-tls = { version = "0.5", optional = true }
 log = "0.4"
 moka = { version = "0.11", features = ["future"] }
 openssl = { version = "0.10", features = ["vendored"], optional = true }
+pin-project = "1"
 rcgen = { version = "0.10", features = ["x509-parser"] }
 serde = { version = "1.0", features = ["derive"] }
 thiserror = "1"

+ 11 - 1
crates/core/src/handler.rs

@@ -45,7 +45,7 @@ impl<D: CustomContextData> MitmFilter<D> {
         }
     }
 
-    pub async fn filter(&self, _ctx: &HttpContext<D>, req: &Request<Body>) -> bool {
+    pub async fn filter_req(&self, _ctx: &HttpContext<D>, req: &Request<Body>) -> bool {
         let host = req.uri().host().unwrap_or_default();
         let list = self.filters.read().unwrap();
         for m in list.iter() {
@@ -55,4 +55,14 @@ impl<D: CustomContextData> MitmFilter<D> {
         }
         false
     }
+
+    pub async fn filter(&self, host: &str) -> bool {
+        let list = self.filters.read().unwrap();
+        for m in list.iter() {
+            if m.matches(host) {
+                return true;
+            }
+        }
+        false
+    }
 }

+ 1 - 0
crates/core/src/lib.rs

@@ -17,6 +17,7 @@ mod error;
 pub mod handler;
 mod http_client;
 pub mod mitm;
+mod sni_reader;
 
 #[derive(TypedBuilder)]
 pub struct Proxy<F, H, D>

+ 39 - 8
crates/core/src/mitm.rs

@@ -2,17 +2,21 @@ use crate::{
     ca::CertificateAuthority,
     handler::{CustomContextData, HttpHandler, MitmFilter},
     http_client::HttpClient,
+    sni_reader::{
+        read_sni_host_name_from_client_hello, HandshakeRecordReader, PrefixedReaderWriter,
+        RecordingBufReader,
+    },
 };
 use http::{header, uri::Scheme, HeaderValue, Uri};
 use hyper::{
-    body::HttpBody, server::conn::Http, service::service_fn, upgrade::Upgraded, Body, Method,
-    Request, Response,
+    body::HttpBody, server::conn::Http, service::service_fn, Body, Method, Request, Response,
 };
 use log::*;
-use std::{marker::PhantomData, sync::Arc};
+use std::{marker::PhantomData, sync::Arc, time::Duration};
 use tokio::{
     io::{AsyncRead, AsyncWrite},
     net::TcpStream,
+    pin,
 };
 use tokio_rustls::TlsAcceptor;
 
@@ -150,7 +154,7 @@ where
             ..Default::default()
         };
 
-        if self.mitm_filter.filter(&ctx, &req).await {
+        if self.mitm_filter.filter_req(&ctx, &req).await {
             tokio::task::spawn(async move {
                 let authority = req
                     .uri()
@@ -175,10 +179,34 @@ where
         Ok(Response::new(Body::empty()))
     }
 
-    pub async fn serve_tls<IO: AsyncRead + AsyncWrite + Unpin + Send + 'static>(self, stream: IO) {
+    pub async fn serve_tls<IO: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
+        self,
+        mut stream: IO,
+    ) {
+        // Read SNI hostname.
+        let mut recording_reader = RecordingBufReader::new(&mut stream);
+        let reader = HandshakeRecordReader::new(&mut recording_reader);
+        pin!(reader);
+        let sni_hostname = tokio::time::timeout(
+            Duration::from_secs(5),
+            read_sni_host_name_from_client_hello(reader),
+        )
+        .await
+        .unwrap()
+        .unwrap();
+
+        let read_buf = recording_reader.buf();
+        let client_stream = PrefixedReaderWriter::new(stream, read_buf);
+
+        if !self.mitm_filter.filter(&sni_hostname).await {
+            let remote_addr = format!("{sni_hostname}:443");
+            tokio::task::spawn(async move { tunnel(client_stream, remote_addr).await });
+            return;
+        }
+
         let server_config = self.ca.clone().gen_server_config();
 
-        match TlsAcceptor::from(server_config).accept(stream).await {
+        match TlsAcceptor::from(server_config).accept(client_stream).await {
             Ok(stream) => {
                 if let Err(e) = Http::new()
                     .http1_preserve_header_case(true)
@@ -239,8 +267,11 @@ fn host_addr(uri: &http::Uri) -> Option<String> {
     uri.authority().map(|auth| auth.to_string())
 }
 
-async fn tunnel(mut upgraded: Upgraded, addr: String) -> std::io::Result<()> {
+async fn tunnel<A>(mut client_stream: A, addr: String) -> std::io::Result<()>
+where
+    A: AsyncRead + AsyncWrite + Unpin,
+{
     let mut server = TcpStream::connect(addr).await?;
-    tokio::io::copy_bidirectional(&mut upgraded, &mut server).await?;
+    tokio::io::copy_bidirectional(&mut client_stream, &mut server).await?;
     Ok(())
 }

+ 365 - 0
crates/core/src/sni_reader.rs

@@ -0,0 +1,365 @@
+/// from https://github.com/branlwyd/rspd/blob/master/src/main.rs
+use byteorder::{ByteOrder, NetworkEndian};
+use pin_project::pin_project;
+use std::{
+    cmp::min,
+    io::ErrorKind,
+    mem,
+    pin::Pin,
+    task::{Context, Poll},
+};
+use tokio::{
+    io::{self, AsyncRead, AsyncReadExt, AsyncWrite, Error, ReadBuf},
+    pin,
+};
+
+#[pin_project]
+pub struct RecordingBufReader<R: AsyncRead> {
+    #[pin]
+    reader: R,
+    buf: Vec<u8>,
+    read_offset: usize,
+}
+
+const RECORDING_READER_BUF_SIZE: usize = 1 << 10; // 1 KiB
+
+impl<R: AsyncRead> RecordingBufReader<R> {
+    pub fn new(reader: R) -> RecordingBufReader<R> {
+        RecordingBufReader {
+            reader,
+            buf: Vec::with_capacity(RECORDING_READER_BUF_SIZE),
+            read_offset: 0,
+        }
+    }
+
+    pub fn buf(self) -> Vec<u8> {
+        self.buf
+    }
+}
+
+impl<R: AsyncRead> AsyncRead for RecordingBufReader<R> {
+    fn poll_read(
+        self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+        caller_buf: &mut ReadBuf<'_>,
+    ) -> Poll<Result<(), Error>> {
+        // if we don't have any buffered bytes, read some bytes into our buffer.
+        let mut this = self.project();
+        if *this.read_offset == this.buf.len() {
+            this.buf.reserve(RECORDING_READER_BUF_SIZE);
+            let mut read_buf = ReadBuf::uninit(this.buf.spare_capacity_mut());
+            match this.reader.as_mut().poll_read(cx, &mut read_buf) {
+                Poll::Ready(Ok(())) => {
+                    let bytes_read = read_buf.filled().len();
+                    let new_len = this.buf.len() + bytes_read;
+                    unsafe {
+                        this.buf.set_len(new_len); // lol
+                    }
+                }
+                rslt => return rslt,
+            };
+        }
+
+        // read from the buffered bytes into the caller's buffer.
+        let unread_bytes = &this.buf[*this.read_offset..];
+        let n = min(caller_buf.remaining(), unread_bytes.len());
+        caller_buf.put_slice(&unread_bytes[..n]);
+        *this.read_offset += n;
+        Poll::Ready(Ok(()))
+    }
+}
+
+#[pin_project]
+pub struct PrefixedReaderWriter<T: AsyncRead + AsyncWrite> {
+    #[pin]
+    inner: T,
+    prefix: Vec<u8>,
+    prefix_read_offset: usize,
+}
+
+impl<T: AsyncRead + AsyncWrite> PrefixedReaderWriter<T> {
+    pub fn new(inner: T, prefix: Vec<u8>) -> PrefixedReaderWriter<T> {
+        PrefixedReaderWriter {
+            inner,
+            prefix,
+            prefix_read_offset: 0,
+        }
+    }
+}
+
+impl<T: AsyncRead + AsyncWrite> AsyncRead for PrefixedReaderWriter<T> {
+    fn poll_read(
+        self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+        buf: &mut ReadBuf<'_>,
+    ) -> Poll<std::io::Result<()>> {
+        let this = self.project();
+        if this.prefix.is_empty() {
+            return this.inner.poll_read(cx, buf);
+        }
+
+        let prefix = &this.prefix[*this.prefix_read_offset..];
+        let n = min(buf.remaining(), prefix.len());
+        buf.put_slice(&prefix[..n]);
+        *this.prefix_read_offset += n;
+
+        if *this.prefix_read_offset == this.prefix.len() {
+            mem::take(this.prefix);
+        }
+
+        Poll::Ready(Ok(()))
+    }
+}
+
+impl<T: AsyncRead + AsyncWrite> AsyncWrite for PrefixedReaderWriter<T> {
+    fn poll_write(
+        self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+        buf: &[u8],
+    ) -> Poll<Result<usize, std::io::Error>> {
+        let this = self.project();
+        this.inner.poll_write(cx, buf)
+    }
+
+    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
+        let this = self.project();
+        this.inner.poll_flush(cx)
+    }
+
+    fn poll_shutdown(
+        self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+    ) -> Poll<Result<(), std::io::Error>> {
+        let this = self.project();
+        this.inner.poll_shutdown(cx)
+    }
+}
+
+#[pin_project]
+pub struct HandshakeRecordReader<R: AsyncRead> {
+    #[pin]
+    reader: R,
+    currently_reading: HandshakeRecordReaderReading,
+}
+
+impl<R: AsyncRead> HandshakeRecordReader<R> {
+    pub fn new(reader: R) -> HandshakeRecordReader<R> {
+        HandshakeRecordReader {
+            reader,
+            currently_reading: HandshakeRecordReaderReading::ContentType,
+        }
+    }
+}
+
+enum HandshakeRecordReaderReading {
+    ContentType,
+    MajorMinorVersion(usize),
+    RecordSize([u8; 2], usize),
+    Record(usize),
+}
+
+impl<R: AsyncRead> AsyncRead for HandshakeRecordReader<R> {
+    fn poll_read(
+        self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+        caller_buf: &mut ReadBuf<'_>,
+    ) -> Poll<Result<(), Error>> {
+        let mut this = self.project();
+        loop {
+            match this.currently_reading {
+                HandshakeRecordReaderReading::ContentType => {
+                    const CONTENT_TYPE_HANDSHAKE: u8 = 22;
+                    let mut buf = [0];
+                    let mut buf = ReadBuf::new(&mut buf[..]);
+                    match this.reader.as_mut().poll_read(cx, &mut buf) {
+                        Poll::Ready(Ok(())) if buf.filled().len() == 1 => {
+                            if buf.filled()[0] != CONTENT_TYPE_HANDSHAKE {
+                                return Poll::Ready(Err(io::Error::new(
+                                    io::ErrorKind::InvalidData,
+                                    format!(
+                                        "got wrong content type (wanted {}, got {})",
+                                        CONTENT_TYPE_HANDSHAKE,
+                                        buf.filled()[0]
+                                    ),
+                                )));
+                            }
+                            *this.currently_reading =
+                                HandshakeRecordReaderReading::MajorMinorVersion(0);
+                        }
+                        rslt => return rslt,
+                    }
+                }
+
+                HandshakeRecordReaderReading::MajorMinorVersion(bytes_read) => {
+                    let mut buf = [0, 0];
+                    let mut buf = ReadBuf::new(&mut buf[..]);
+                    buf.advance(*bytes_read);
+                    match this.reader.as_mut().poll_read(cx, &mut buf) {
+                        Poll::Ready(Ok(())) => {
+                            *bytes_read = buf.filled().len();
+                            if *bytes_read == 2 {
+                                *this.currently_reading =
+                                    HandshakeRecordReaderReading::RecordSize([0, 0], 0);
+                            }
+                        }
+                        rslt => return rslt,
+                    }
+                }
+
+                HandshakeRecordReaderReading::RecordSize(backing_array, bytes_read) => {
+                    const MAX_RECORD_SIZE: usize = 1 << 14;
+                    let mut buf = ReadBuf::new(&mut backing_array[..]);
+                    buf.advance(*bytes_read);
+                    match this.reader.as_mut().poll_read(cx, &mut buf) {
+                        Poll::Ready(Ok(())) => {
+                            *bytes_read = buf.filled().len();
+                            if *bytes_read == 2 {
+                                let record_size = u16::from_be_bytes(*backing_array).into();
+                                if record_size > MAX_RECORD_SIZE {
+                                    return Poll::Ready(Err(io::Error::new(
+                                        io::ErrorKind::InvalidData,
+                                        format!(
+                                            "record too large ({} > {})",
+                                            record_size, MAX_RECORD_SIZE
+                                        ),
+                                    )));
+                                }
+                                *this.currently_reading =
+                                    HandshakeRecordReaderReading::Record(record_size)
+                            }
+                        }
+                        rslt => return rslt,
+                    }
+                }
+
+                HandshakeRecordReaderReading::Record(remaining_record_bytes) => {
+                    // We ultimately want to read record bytes into `caller_buf`, but we need to
+                    // ensure that we don't read more bytes than there are record bytes (and end
+                    // up handing the caller record header bytes). So we call `caller_buf.take()`.
+                    // Since `take` returns an independent `ReadBuf`, we have to update `caller_buf`
+                    // once we're done reading: first we call `assume_init` to assert that the
+                    // `bytes_read` bytes we read are initialized, then we call `advance` to assert
+                    // that the appropriate section of the buffer is filled.
+
+                    let mut buf = caller_buf.take(*remaining_record_bytes);
+                    let rslt = this.reader.as_mut().poll_read(cx, &mut buf);
+                    if let Poll::Ready(Ok(())) = rslt {
+                        let bytes_read = buf.filled().len();
+                        unsafe {
+                            caller_buf.assume_init(bytes_read);
+                        }
+                        caller_buf.advance(bytes_read);
+                        *remaining_record_bytes -= bytes_read;
+                        if *remaining_record_bytes == 0 {
+                            *this.currently_reading = HandshakeRecordReaderReading::ContentType;
+                        }
+                    }
+                    return rslt;
+                }
+            }
+        }
+    }
+}
+
+pub async fn read_sni_host_name_from_client_hello<R: AsyncRead>(
+    mut reader: Pin<&mut R>,
+) -> io::Result<String> {
+    // Handshake message type.
+    const HANDSHAKE_TYPE_CLIENT_HELLO: u8 = 1;
+    let typ = reader.read_u8().await?;
+    if typ != HANDSHAKE_TYPE_CLIENT_HELLO {
+        return Err(io::Error::new(
+            io::ErrorKind::InvalidData,
+            format!(
+                "handshake message not a ClientHello (type {}, expected {})",
+                typ, HANDSHAKE_TYPE_CLIENT_HELLO
+            ),
+        ));
+    }
+
+    // Handshake message length.
+    let len = read_u24(reader.as_mut()).await?;
+    let reader = reader.take(len.into());
+    pin!(reader);
+
+    // ProtocolVersion (2 bytes) & random (32 bytes).
+    skip(reader.as_mut(), 34).await?;
+
+    // Session ID (u8-length vec), cipher suites (u16-length vec), compression methods (u8-length vec).
+    skip_vec_u8(reader.as_mut()).await?;
+    skip_vec_u16(reader.as_mut()).await?;
+    skip_vec_u8(reader.as_mut()).await?;
+
+    // Extensions.
+    let ext_len = reader.read_u16().await?;
+    let new_limit = min(reader.limit(), ext_len.into());
+    reader.set_limit(new_limit);
+    loop {
+        // Extension type & length.
+        let ext_typ = reader.read_u16().await?;
+        let ext_len = reader.read_u16().await?;
+
+        const EXTENSION_TYPE_SNI: u16 = 0;
+        if ext_typ != EXTENSION_TYPE_SNI {
+            skip(reader.as_mut(), ext_len.into()).await?;
+            continue;
+        }
+        let new_limit = min(reader.limit(), ext_len.into());
+        reader.set_limit(new_limit);
+
+        // ServerNameList length.
+        let snl_len = reader.read_u16().await?;
+        let new_limit = min(reader.limit(), snl_len.into());
+        reader.set_limit(new_limit);
+
+        // ServerNameList.
+        loop {
+            // NameType & length.
+            let name_typ = reader.read_u8().await?;
+
+            const NAME_TYPE_HOST_NAME: u8 = 0;
+            if name_typ != NAME_TYPE_HOST_NAME {
+                skip_vec_u16(reader.as_mut()).await?;
+                continue;
+            }
+
+            let name_len = reader.read_u16().await?;
+            let new_limit = min(reader.limit(), name_len.into());
+            reader.set_limit(new_limit);
+            let mut name_buf = vec![0; name_len.into()];
+            reader.read_exact(&mut name_buf).await?;
+            return String::from_utf8(name_buf)
+                .map_err(|err| io::Error::new(ErrorKind::InvalidData, err));
+        }
+    }
+}
+
+async fn skip<R: AsyncRead>(reader: Pin<&mut R>, len: u64) -> io::Result<()> {
+    let bytes_read = io::copy(&mut reader.take(len), &mut io::sink()).await?;
+    if bytes_read < len {
+        return Err(io::Error::new(
+            ErrorKind::UnexpectedEof,
+            format!("skip read {} < {} bytes", bytes_read, len),
+        ));
+    }
+    Ok(())
+}
+
+async fn skip_vec_u8<R: AsyncRead>(mut reader: Pin<&mut R>) -> io::Result<()> {
+    let sz = reader.read_u8().await?;
+    skip(reader.as_mut(), sz.into()).await
+}
+
+async fn skip_vec_u16<R: AsyncRead>(mut reader: Pin<&mut R>) -> io::Result<()> {
+    let sz = reader.read_u16().await?;
+    skip(reader.as_mut(), sz.into()).await
+}
+
+async fn read_u24<R: AsyncRead>(mut reader: Pin<&mut R>) -> io::Result<u32> {
+    let mut buf = [0; 3];
+    reader
+        .as_mut()
+        .read_exact(&mut buf)
+        .await
+        .map(|_| NetworkEndian::read_u24(&buf))
+}