|
@@ -3,11 +3,7 @@ use crate::{
|
|
|
handler::{CustomContextData, HttpHandler, MitmFilter},
|
|
|
http_client::HttpClient,
|
|
|
};
|
|
|
-use http::{
|
|
|
- header,
|
|
|
- uri::{PathAndQuery, Scheme},
|
|
|
- HeaderValue, Uri,
|
|
|
-};
|
|
|
+use http::{header, uri::Scheme, HeaderValue, Uri};
|
|
|
use hyper::{
|
|
|
body::HttpBody, server::conn::Http, service::service_fn, upgrade::Upgraded, Body, Method,
|
|
|
Request, Response,
|
|
@@ -60,7 +56,7 @@ where
|
|
|
let res = if req.method() == Method::CONNECT {
|
|
|
self.process_connect(req).await
|
|
|
} else {
|
|
|
- self.process_request(req).await
|
|
|
+ self.process_request(req, Scheme::HTTP).await
|
|
|
};
|
|
|
|
|
|
match res {
|
|
@@ -75,34 +71,39 @@ where
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- async fn process_request(self, mut req: Request<Body>) -> Result<Response<Body>, hyper::Error> {
|
|
|
+ async fn process_request(
|
|
|
+ self,
|
|
|
+ mut req: Request<Body>,
|
|
|
+ scheme: Scheme,
|
|
|
+ ) -> Result<Response<Body>, hyper::Error> {
|
|
|
+ if req.uri().path().starts_with("/mitm/cert") {
|
|
|
+ return Ok(self.get_cert_res());
|
|
|
+ }
|
|
|
+
|
|
|
let mut ctx = HttpContext {
|
|
|
uri: None,
|
|
|
should_modify_response: false,
|
|
|
..Default::default()
|
|
|
};
|
|
|
|
|
|
- let host = req
|
|
|
- .headers()
|
|
|
- .get(http::header::HOST)
|
|
|
- .map(|h| h.to_str())
|
|
|
- .unwrap()
|
|
|
- .map(|h| h.to_owned())
|
|
|
- .unwrap_or_default();
|
|
|
-
|
|
|
- let uri = req.uri_mut();
|
|
|
- if uri.authority().is_none() {
|
|
|
- *uri = http::uri::Uri::builder()
|
|
|
- .scheme(uri.scheme().unwrap_or(&Scheme::HTTP).as_str())
|
|
|
- .authority(host.as_str())
|
|
|
- .path_and_query(uri.path_and_query().map_or("/", |p| p.as_str()))
|
|
|
- .build()
|
|
|
- .unwrap();
|
|
|
- }
|
|
|
+ // if req.uri().authority().is_none() {
|
|
|
+ if req.version() == http::Version::HTTP_10 || req.version() == http::Version::HTTP_11 {
|
|
|
+ let (mut parts, body) = req.into_parts();
|
|
|
+
|
|
|
+ if let Some(Ok(authority)) = parts
|
|
|
+ .headers
|
|
|
+ .get(http::header::HOST)
|
|
|
+ .map(|host| host.to_str())
|
|
|
+ {
|
|
|
+ let mut uri = parts.uri.into_parts();
|
|
|
+ uri.scheme = Some(scheme.clone());
|
|
|
+ uri.authority = authority.try_into().ok();
|
|
|
+ parts.uri = Uri::from_parts(uri).expect("build uri");
|
|
|
+ }
|
|
|
|
|
|
- if req.uri().path().starts_with("/mitm/cert") || host.contains("cert.mitm") {
|
|
|
- return Ok(self.get_cert_res());
|
|
|
- }
|
|
|
+ req = Request::from_parts(parts, body);
|
|
|
+ };
|
|
|
+ // }
|
|
|
|
|
|
let mut req = match self.http_handler.handle_request(&mut ctx, req).await {
|
|
|
RequestOrResponse::Request(req) => req,
|
|
@@ -111,7 +112,7 @@ where
|
|
|
|
|
|
{
|
|
|
let header_mut = req.headers_mut();
|
|
|
- // header_mut.remove(http::header::HOST);
|
|
|
+ header_mut.remove(http::header::HOST);
|
|
|
header_mut.remove(http::header::ACCEPT_ENCODING);
|
|
|
header_mut.remove(http::header::CONTENT_LENGTH);
|
|
|
}
|
|
@@ -139,50 +140,6 @@ where
|
|
|
Ok(res)
|
|
|
}
|
|
|
|
|
|
- // async fn process_connect(self, req: Request<Body>) -> Result<Response<Body>, hyper::Error> {
|
|
|
- // let ctx = HttpContext {
|
|
|
- // uri: None,
|
|
|
- // should_modify_response: false,
|
|
|
- // ..Default::default()
|
|
|
- // };
|
|
|
-
|
|
|
- // if self.mitm_filter.filter(&ctx, &req).await {
|
|
|
- // tokio::task::spawn(async move {
|
|
|
- // let authority = req
|
|
|
- // .uri()
|
|
|
- // .authority()
|
|
|
- // .expect("URI does not contain authority")
|
|
|
- // .clone();
|
|
|
-
|
|
|
- // match hyper::upgrade::on(req).await {
|
|
|
- // Ok(upgraded) => {
|
|
|
- // let server_config = self.ca.clone().gen_server_config();
|
|
|
-
|
|
|
- // let stream = TlsAcceptor::from(server_config)
|
|
|
- // .accept(upgraded)
|
|
|
- // .await
|
|
|
- // .expect("Failed to establish TLS connection with client");
|
|
|
-
|
|
|
- // if let Err(e) = self.serve_tls_stream(stream).await {
|
|
|
- // let e_string = e.to_string();
|
|
|
- // if !e_string.starts_with("error shutting down connection") {
|
|
|
- // debug!("res:: {}", e);
|
|
|
- // }
|
|
|
- // }
|
|
|
- // }
|
|
|
- // Err(e) => debug!("upgrade error for {}: {}", authority, e),
|
|
|
- // };
|
|
|
- // });
|
|
|
- // } else {
|
|
|
- // tokio::task::spawn(async move {
|
|
|
- // let remote_addr = host_addr(req.uri()).unwrap();
|
|
|
- // let upgraded = hyper::upgrade::on(req).await.unwrap();
|
|
|
- // tunnel(upgraded, remote_addr).await
|
|
|
- // });
|
|
|
- // }
|
|
|
- // Ok(Response::new(Body::empty()))
|
|
|
- // }
|
|
|
-
|
|
|
async fn process_connect(self, req: Request<Body>) -> Result<Response<Body>, hyper::Error> {
|
|
|
let ctx = HttpContext {
|
|
|
uri: None,
|
|
@@ -218,54 +175,30 @@ where
|
|
|
pub async fn serve_tls<IO: AsyncRead + AsyncWrite + Unpin + Send + 'static>(self, stream: IO) {
|
|
|
let server_config = self.ca.clone().gen_server_config();
|
|
|
|
|
|
- let stream = TlsAcceptor::from(server_config)
|
|
|
- .accept(stream)
|
|
|
- .await
|
|
|
- .expect("Failed to establish TLS connection with client");
|
|
|
-
|
|
|
- if let Err(e) = self.serve_tls_stream(stream).await {
|
|
|
- let e_string = e.to_string();
|
|
|
- if !e_string.starts_with("error shutting down connection") {
|
|
|
- debug!("res:: {}", e);
|
|
|
+ match TlsAcceptor::from(server_config).accept(stream).await {
|
|
|
+ Ok(stream) => {
|
|
|
+ if let Err(e) = self.serve_stream(stream, Scheme::HTTPS).await {
|
|
|
+ let e_string = e.to_string();
|
|
|
+ if !e_string.starts_with("error shutting down connection") {
|
|
|
+ debug!("res:: {}", e);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ Err(err) => {
|
|
|
+ error!("Tls accept failed: {err}")
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- async fn serve_tls_stream<IO: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
|
|
|
- self,
|
|
|
- stream: tokio_rustls::server::TlsStream<IO>,
|
|
|
- ) -> Result<(), hyper::Error> {
|
|
|
- let service = service_fn(|mut req| {
|
|
|
- if req.version() == http::Version::HTTP_10 || req.version() == http::Version::HTTP_11 {
|
|
|
- let authority = req
|
|
|
- .headers()
|
|
|
- .get(http::header::HOST)
|
|
|
- .expect("Host is a required header")
|
|
|
- .to_str()
|
|
|
- .expect("Failed to convert host to str");
|
|
|
-
|
|
|
- let uri = http::uri::Builder::new()
|
|
|
- .scheme(http::uri::Scheme::HTTPS)
|
|
|
- .authority(authority)
|
|
|
- .path_and_query(
|
|
|
- req.uri()
|
|
|
- .path_and_query()
|
|
|
- .unwrap_or(&PathAndQuery::from_static("/"))
|
|
|
- .to_owned(),
|
|
|
- )
|
|
|
- .build()
|
|
|
- .expect("Failed to build URI");
|
|
|
-
|
|
|
- let (mut parts, body) = req.into_parts();
|
|
|
- parts.uri = uri;
|
|
|
- req = Request::from_parts(parts, body)
|
|
|
- };
|
|
|
-
|
|
|
- self.clone().process_request(req)
|
|
|
- });
|
|
|
-
|
|
|
+ async fn serve_stream<S>(self, stream: S, scheme: Scheme) -> Result<(), hyper::Error>
|
|
|
+ where
|
|
|
+ S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
|
|
|
+ {
|
|
|
Http::new()
|
|
|
- .serve_connection(stream, service)
|
|
|
+ .serve_connection(
|
|
|
+ stream,
|
|
|
+ service_fn(|req| self.clone().process_request(req, scheme.clone())),
|
|
|
+ )
|
|
|
.with_upgrades()
|
|
|
.await
|
|
|
}
|