Browse Source

Update to rocket 0.5 and made code async, missing updating all db calls, that are currently blocking

Daniel García 4 years ago
parent
commit
0b7d6bf6df

File diff suppressed because it is too large
+ 322 - 221
Cargo.lock


+ 15 - 14
Cargo.toml

@@ -3,7 +3,7 @@ name = "vaultwarden"
 version = "1.0.0"
 authors = ["Daniel García <[email protected]>"]
 edition = "2021"
-rust-version = "1.60"
+rust-version = "1.56"
 resolver = "2"
 
 repository = "https://github.com/dani-garcia/vaultwarden"
@@ -13,6 +13,7 @@ publish = false
 build = "build.rs"
 
 [features]
+# default = ["sqlite"]
 # Empty to keep compatibility, prefer to set USE_SYSLOG=true
 enable_syslog = []
 mysql = ["diesel/mysql", "diesel_migrations/mysql"]
@@ -29,22 +30,22 @@ unstable = []
 syslog = "4.0.1"
 
 [dependencies]
-# Web framework for nightly with a focus on ease-of-use, expressibility, and speed.
-rocket = { version = "=0.5.0-dev", features = ["tls"], default-features = false }
-rocket_contrib = "=0.5.0-dev"
-
-# HTTP client
-reqwest = { version = "0.11.9", features = ["blocking", "json", "gzip", "brotli", "socks", "cookies", "trust-dns"] }
+# Web framework
+rocket = { version = "0.5.0-rc.1", features = ["tls", "json"], default-features = false }
+
+# Async futures
+futures = "0.3.19"
+tokio = { version = "1.16.1", features = ["rt-multi-thread", "fs", "io-util", "parking_lot"] }
+ 
+ # HTTP client
+reqwest = { version = "0.11.9", features = ["stream", "json", "gzip", "brotli", "socks", "cookies", "trust-dns"] }
+bytes = "1.1.0"
 
 # Used for custom short lived cookie jar
 cookie = "0.15.1"
 cookie_store = "0.15.1"
-bytes = "1.1.0"
 url = "2.2.2"
 
-# multipart/form-data support
-multipart = { version = "0.18.0", features = ["server"], default-features = false }
-
 # WebSockets library
 ws = { version = "0.11.1", package = "parity-ws" }
 
@@ -141,10 +142,10 @@ backtrace = "0.3.64"
 paste = "1.0.6"
 governor = "0.4.1"
 
+ctrlc = { version = "3.2.1", features = ["termination"] }
+
 [patch.crates-io]
-# Use newest ring
-rocket = { git = 'https://github.com/SergioBenitez/Rocket', rev = '263e39b5b429de1913ce7e3036575a7b4d88b6d7' }
-rocket_contrib = { git = 'https://github.com/SergioBenitez/Rocket', rev = '263e39b5b429de1913ce7e3036575a7b4d88b6d7' }
+rocket = { git = 'https://github.com/SergioBenitez/Rocket', rev = '8cae077ba1d54b92cdef3e171a730b819d5eeb8e' }
 
 # The maintainer of the `job_scheduler` crate doesn't seem to have responded
 # to any issues or PRs for almost a year (as of April 2021). This hopefully

+ 0 - 2
Rocket.toml

@@ -1,2 +0,0 @@
-[global.limits]
-json = 10485760 # 10 MiB

+ 1 - 1
rust-toolchain

@@ -1 +1 @@
-nightly-2022-01-23
+stable

+ 37 - 28
src/api/admin.rs

@@ -3,13 +3,14 @@ use serde::de::DeserializeOwned;
 use serde_json::Value;
 use std::env;
 
+use rocket::serde::json::Json;
 use rocket::{
-    http::{Cookie, Cookies, SameSite, Status},
-    request::{self, FlashMessage, Form, FromRequest, Outcome, Request},
-    response::{content::Html, Flash, Redirect},
+    form::Form,
+    http::{Cookie, CookieJar, SameSite, Status},
+    request::{self, FlashMessage, FromRequest, Outcome, Request},
+    response::{content::RawHtml as Html, Flash, Redirect},
     Route,
 };
-use rocket_contrib::json::Json;
 
 use crate::{
     api::{ApiResult, EmptyResult, JsonResult, NumberOrString},
@@ -85,10 +86,11 @@ fn admin_path() -> String {
 
 struct Referer(Option<String>);
 
-impl<'a, 'r> FromRequest<'a, 'r> for Referer {
+#[rocket::async_trait]
+impl<'r> FromRequest<'r> for Referer {
     type Error = ();
 
-    fn from_request(request: &'a Request<'r>) -> request::Outcome<Self, Self::Error> {
+    async fn from_request(request: &'r Request<'_>) -> request::Outcome<Self, Self::Error> {
         Outcome::Success(Referer(request.headers().get_one("Referer").map(str::to_string)))
     }
 }
@@ -96,10 +98,11 @@ impl<'a, 'r> FromRequest<'a, 'r> for Referer {
 #[derive(Debug)]
 struct IpHeader(Option<String>);
 
-impl<'a, 'r> FromRequest<'a, 'r> for IpHeader {
+#[rocket::async_trait]
+impl<'r> FromRequest<'r> for IpHeader {
     type Error = ();
 
-    fn from_request(req: &'a Request<'r>) -> Outcome<Self, Self::Error> {
+    async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
         if req.headers().get_one(&CONFIG.ip_header()).is_some() {
             Outcome::Success(IpHeader(Some(CONFIG.ip_header())))
         } else if req.headers().get_one("X-Client-IP").is_some() {
@@ -138,7 +141,7 @@ fn admin_url(referer: Referer) -> String {
 #[get("/", rank = 2)]
 fn admin_login(flash: Option<FlashMessage>) -> ApiResult<Html<String>> {
     // If there is an error, show it
-    let msg = flash.map(|msg| format!("{}: {}", msg.name(), msg.msg()));
+    let msg = flash.map(|msg| format!("{}: {}", msg.kind(), msg.message()));
     let json = json!({
         "page_content": "admin/login",
         "version": VERSION,
@@ -159,7 +162,7 @@ struct LoginForm {
 #[post("/", data = "<data>")]
 fn post_admin_login(
     data: Form<LoginForm>,
-    mut cookies: Cookies,
+    cookies: &CookieJar,
     ip: ClientIp,
     referer: Referer,
 ) -> Result<Redirect, Flash<Redirect>> {
@@ -180,7 +183,7 @@ fn post_admin_login(
 
         let cookie = Cookie::build(COOKIE_NAME, jwt)
             .path(admin_path())
-            .max_age(time::Duration::minutes(20))
+            .max_age(rocket::time::Duration::minutes(20))
             .same_site(SameSite::Strict)
             .http_only(true)
             .finish();
@@ -297,7 +300,7 @@ fn test_smtp(data: Json<InviteData>, _token: AdminToken) -> EmptyResult {
 }
 
 #[get("/logout")]
-fn logout(mut cookies: Cookies, referer: Referer) -> Redirect {
+fn logout(cookies: &CookieJar, referer: Referer) -> Redirect {
     cookies.remove(Cookie::named(COOKIE_NAME));
     Redirect::to(admin_url(referer))
 }
@@ -462,23 +465,23 @@ struct GitCommit {
     sha: String,
 }
 
-fn get_github_api<T: DeserializeOwned>(url: &str) -> Result<T, Error> {
+async fn get_github_api<T: DeserializeOwned>(url: &str) -> Result<T, Error> {
     let github_api = get_reqwest_client();
 
-    Ok(github_api.get(url).send()?.error_for_status()?.json::<T>()?)
+    Ok(github_api.get(url).send().await?.error_for_status()?.json::<T>().await?)
 }
 
-fn has_http_access() -> bool {
+async fn has_http_access() -> bool {
     let http_access = get_reqwest_client();
 
-    match http_access.head("https://github.com/dani-garcia/vaultwarden").send() {
+    match http_access.head("https://github.com/dani-garcia/vaultwarden").send().await {
         Ok(r) => r.status().is_success(),
         _ => false,
     }
 }
 
 #[get("/diagnostics")]
-fn diagnostics(_token: AdminToken, ip_header: IpHeader, conn: DbConn) -> ApiResult<Html<String>> {
+async fn diagnostics(_token: AdminToken, ip_header: IpHeader, conn: DbConn) -> ApiResult<Html<String>> {
     use crate::util::read_file_string;
     use chrono::prelude::*;
     use std::net::ToSocketAddrs;
@@ -497,7 +500,7 @@ fn diagnostics(_token: AdminToken, ip_header: IpHeader, conn: DbConn) -> ApiResu
 
     // Execute some environment checks
     let running_within_docker = is_running_in_docker();
-    let has_http_access = has_http_access();
+    let has_http_access = has_http_access().await;
     let uses_proxy = env::var_os("HTTP_PROXY").is_some()
         || env::var_os("http_proxy").is_some()
         || env::var_os("HTTPS_PROXY").is_some()
@@ -513,11 +516,14 @@ fn diagnostics(_token: AdminToken, ip_header: IpHeader, conn: DbConn) -> ApiResu
     // TODO: Maybe we need to cache this using a LazyStatic or something. Github only allows 60 requests per hour, and we use 3 here already.
     let (latest_release, latest_commit, latest_web_build) = if has_http_access {
         (
-            match get_github_api::<GitRelease>("https://api.github.com/repos/dani-garcia/vaultwarden/releases/latest") {
+            match get_github_api::<GitRelease>("https://api.github.com/repos/dani-garcia/vaultwarden/releases/latest")
+                .await
+            {
                 Ok(r) => r.tag_name,
                 _ => "-".to_string(),
             },
-            match get_github_api::<GitCommit>("https://api.github.com/repos/dani-garcia/vaultwarden/commits/main") {
+            match get_github_api::<GitCommit>("https://api.github.com/repos/dani-garcia/vaultwarden/commits/main").await
+            {
                 Ok(mut c) => {
                     c.sha.truncate(8);
                     c.sha
@@ -531,7 +537,9 @@ fn diagnostics(_token: AdminToken, ip_header: IpHeader, conn: DbConn) -> ApiResu
             } else {
                 match get_github_api::<GitRelease>(
                     "https://api.github.com/repos/dani-garcia/bw_web_builds/releases/latest",
-                ) {
+                )
+                .await
+                {
                     Ok(r) => r.tag_name.trim_start_matches('v').to_string(),
                     _ => "-".to_string(),
                 }
@@ -562,7 +570,7 @@ fn diagnostics(_token: AdminToken, ip_header: IpHeader, conn: DbConn) -> ApiResu
         "ip_header_config": &CONFIG.ip_header(),
         "uses_proxy": uses_proxy,
         "db_type": *DB_TYPE,
-        "db_version": get_sql_server_version(&conn),
+        "db_version": get_sql_server_version(&conn).await,
         "admin_url": format!("{}/diagnostics", admin_url(Referer(None))),
         "overrides": &CONFIG.get_overrides().join(", "),
         "server_time_local": Local::now().format("%Y-%m-%d %H:%M:%S %Z").to_string(),
@@ -591,9 +599,9 @@ fn delete_config(_token: AdminToken) -> EmptyResult {
 }
 
 #[post("/config/backup_db")]
-fn backup_db(_token: AdminToken, conn: DbConn) -> EmptyResult {
+async fn backup_db(_token: AdminToken, conn: DbConn) -> EmptyResult {
     if *CAN_BACKUP {
-        backup_database(&conn)
+        backup_database(&conn).await
     } else {
         err!("Can't back up current DB (Only SQLite supports this feature)");
     }
@@ -601,21 +609,22 @@ fn backup_db(_token: AdminToken, conn: DbConn) -> EmptyResult {
 
 pub struct AdminToken {}
 
-impl<'a, 'r> FromRequest<'a, 'r> for AdminToken {
+#[rocket::async_trait]
+impl<'r> FromRequest<'r> for AdminToken {
     type Error = &'static str;
 
-    fn from_request(request: &'a Request<'r>) -> request::Outcome<Self, Self::Error> {
+    async fn from_request(request: &'r Request<'_>) -> request::Outcome<Self, Self::Error> {
         if CONFIG.disable_admin_token() {
             Outcome::Success(AdminToken {})
         } else {
-            let mut cookies = request.cookies();
+            let cookies = request.cookies();
 
             let access_token = match cookies.get(COOKIE_NAME) {
                 Some(cookie) => cookie.value(),
                 None => return Outcome::Forward(()), // If there is no cookie, redirect to login
             };
 
-            let ip = match request.guard::<ClientIp>() {
+            let ip = match ClientIp::from_request(request).await {
                 Outcome::Success(ip) => ip.ip,
                 _ => err_handler!("Error getting Client IP"),
             };

+ 1 - 1
src/api/core/accounts.rs

@@ -1,5 +1,5 @@
 use chrono::Utc;
-use rocket_contrib::json::Json;
+use rocket::serde::json::Json;
 use serde_json::Value;
 
 use crate::{

+ 100 - 150
src/api/core/ciphers.rs

@@ -1,13 +1,14 @@
 use std::collections::{HashMap, HashSet};
-use std::path::{Path, PathBuf};
 
 use chrono::{NaiveDateTime, Utc};
-use rocket::{http::ContentType, request::Form, Data, Route};
-use rocket_contrib::json::Json;
+use rocket::fs::TempFile;
+use rocket::serde::json::Json;
+use rocket::{
+    form::{Form, FromForm},
+    Route,
+};
 use serde_json::Value;
 
-use multipart::server::{save::SavedData, Multipart, SaveResult};
-
 use crate::{
     api::{self, EmptyResult, JsonResult, JsonUpcase, Notify, PasswordData, UpdateType},
     auth::Headers,
@@ -79,9 +80,9 @@ pub fn routes() -> Vec<Route> {
     ]
 }
 
-pub fn purge_trashed_ciphers(pool: DbPool) {
+pub async fn purge_trashed_ciphers(pool: DbPool) {
     debug!("Purging trashed ciphers");
-    if let Ok(conn) = pool.get() {
+    if let Ok(conn) = pool.get().await {
         Cipher::purge_trash(&conn);
     } else {
         error!("Failed to get DB connection while purging trashed ciphers")
@@ -90,12 +91,12 @@ pub fn purge_trashed_ciphers(pool: DbPool) {
 
 #[derive(FromForm, Default)]
 struct SyncData {
-    #[form(field = "excludeDomains")]
+    #[field(name = "excludeDomains")]
     exclude_domains: bool, // Default: 'false'
 }
 
 #[get("/sync?<data..>")]
-fn sync(data: Form<SyncData>, headers: Headers, conn: DbConn) -> Json<Value> {
+fn sync(data: SyncData, headers: Headers, conn: DbConn) -> Json<Value> {
     let user_json = headers.user.to_json(&conn);
 
     let folders = Folder::find_by_user(&headers.user.uuid, &conn);
@@ -828,6 +829,12 @@ fn post_attachment_v2(
     })))
 }
 
+#[derive(FromForm)]
+struct UploadData<'f> {
+    key: Option<String>,
+    data: TempFile<'f>,
+}
+
 /// Saves the data content of an attachment to a file. This is common code
 /// shared between the v2 and legacy attachment APIs.
 ///
@@ -836,22 +843,21 @@ fn post_attachment_v2(
 ///
 /// When used with the v2 API, post_attachment_v2() has already created the
 /// database record, which is passed in as `attachment`.
-fn save_attachment(
+async fn save_attachment(
     mut attachment: Option<Attachment>,
     cipher_uuid: String,
-    data: Data,
-    content_type: &ContentType,
+    data: Form<UploadData<'_>>,
     headers: &Headers,
-    conn: &DbConn,
-    nt: Notify,
-) -> Result<Cipher, crate::error::Error> {
-    let cipher = match Cipher::find_by_uuid(&cipher_uuid, conn) {
+    conn: DbConn,
+    nt: Notify<'_>,
+) -> Result<(Cipher, DbConn), crate::error::Error> {
+    let cipher = match Cipher::find_by_uuid(&cipher_uuid, &conn) {
         Some(cipher) => cipher,
-        None => err_discard!("Cipher doesn't exist", data),
+        None => err!("Cipher doesn't exist"),
     };
 
-    if !cipher.is_write_accessible_to_user(&headers.user.uuid, conn) {
-        err_discard!("Cipher is not write accessible", data)
+    if !cipher.is_write_accessible_to_user(&headers.user.uuid, &conn) {
+        err!("Cipher is not write accessible")
     }
 
     // In the v2 API, the attachment record has already been created,
@@ -863,11 +869,11 @@ fn save_attachment(
 
     let size_limit = if let Some(ref user_uuid) = cipher.user_uuid {
         match CONFIG.user_attachment_limit() {
-            Some(0) => err_discard!("Attachments are disabled", data),
+            Some(0) => err!("Attachments are disabled"),
             Some(limit_kb) => {
-                let left = (limit_kb * 1024) - Attachment::size_by_user(user_uuid, conn) + size_adjust;
+                let left = (limit_kb * 1024) - Attachment::size_by_user(user_uuid, &conn) + size_adjust;
                 if left <= 0 {
-                    err_discard!("Attachment storage limit reached! Delete some attachments to free up space", data)
+                    err!("Attachment storage limit reached! Delete some attachments to free up space")
                 }
                 Some(left as u64)
             }
@@ -875,130 +881,78 @@ fn save_attachment(
         }
     } else if let Some(ref org_uuid) = cipher.organization_uuid {
         match CONFIG.org_attachment_limit() {
-            Some(0) => err_discard!("Attachments are disabled", data),
+            Some(0) => err!("Attachments are disabled"),
             Some(limit_kb) => {
-                let left = (limit_kb * 1024) - Attachment::size_by_org(org_uuid, conn) + size_adjust;
+                let left = (limit_kb * 1024) - Attachment::size_by_org(org_uuid, &conn) + size_adjust;
                 if left <= 0 {
-                    err_discard!("Attachment storage limit reached! Delete some attachments to free up space", data)
+                    err!("Attachment storage limit reached! Delete some attachments to free up space")
                 }
                 Some(left as u64)
             }
             None => None,
         }
     } else {
-        err_discard!("Cipher is neither owned by a user nor an organization", data);
+        err!("Cipher is neither owned by a user nor an organization");
     };
 
-    let mut params = content_type.params();
-    let boundary_pair = params.next().expect("No boundary provided");
-    let boundary = boundary_pair.1;
-
-    let base_path = Path::new(&CONFIG.attachments_folder()).join(&cipher_uuid);
-    let mut path = PathBuf::new();
+    let mut data = data.into_inner();
 
-    let mut attachment_key = None;
-    let mut error = None;
-
-    Multipart::with_body(data.open(), boundary)
-        .foreach_entry(|mut field| {
-            match &*field.headers.name {
-                "key" => {
-                    use std::io::Read;
-                    let mut key_buffer = String::new();
-                    if field.data.read_to_string(&mut key_buffer).is_ok() {
-                        attachment_key = Some(key_buffer);
-                    }
-                }
-                "data" => {
-                    // In the legacy API, this is the encrypted filename
-                    // provided by the client, stored to the database as-is.
-                    // In the v2 API, this value doesn't matter, as it was
-                    // already provided and stored via an earlier API call.
-                    let encrypted_filename = field.headers.filename;
-
-                    // This random ID is used as the name of the file on disk.
-                    // In the legacy API, we need to generate this value here.
-                    // In the v2 API, we use the value from post_attachment_v2().
-                    let file_id = match &attachment {
-                        Some(attachment) => attachment.id.clone(), // v2 API
-                        None => crypto::generate_attachment_id(),  // Legacy API
-                    };
-                    path = base_path.join(&file_id);
-
-                    let size =
-                        match field.data.save().memory_threshold(0).size_limit(size_limit).with_path(path.clone()) {
-                            SaveResult::Full(SavedData::File(_, size)) => size as i32,
-                            SaveResult::Full(other) => {
-                                error = Some(format!("Attachment is not a file: {:?}", other));
-                                return;
-                            }
-                            SaveResult::Partial(_, reason) => {
-                                error = Some(format!("Attachment storage limit exceeded with this file: {:?}", reason));
-                                return;
-                            }
-                            SaveResult::Error(e) => {
-                                error = Some(format!("Error: {:?}", e));
-                                return;
-                            }
-                        };
-
-                    if let Some(attachment) = &mut attachment {
-                        // v2 API
-
-                        // Check the actual size against the size initially provided by
-                        // the client. Upstream allows +/- 1 MiB deviation from this
-                        // size, but it's not clear when or why this is needed.
-                        const LEEWAY: i32 = 1024 * 1024; // 1 MiB
-                        let min_size = attachment.file_size - LEEWAY;
-                        let max_size = attachment.file_size + LEEWAY;
-
-                        if min_size <= size && size <= max_size {
-                            if size != attachment.file_size {
-                                // Update the attachment with the actual file size.
-                                attachment.file_size = size;
-                                attachment.save(conn).expect("Error updating attachment");
-                            }
-                        } else {
-                            attachment.delete(conn).ok();
+    if let Some(size_limit) = size_limit {
+        if data.data.len() > size_limit {
+            err!("Attachment storage limit exceeded with this file");
+        }
+    }
 
-                            let err_msg = "Attachment size mismatch".to_string();
-                            error!("{} (expected within [{}, {}], got {})", err_msg, min_size, max_size, size);
-                            error = Some(err_msg);
-                        }
-                    } else {
-                        // Legacy API
+    let file_id = match &attachment {
+        Some(attachment) => attachment.id.clone(), // v2 API
+        None => crypto::generate_attachment_id(),  // Legacy API
+    };
 
-                        if encrypted_filename.is_none() {
-                            error = Some("No filename provided".to_string());
-                            return;
-                        }
-                        if attachment_key.is_none() {
-                            error = Some("No attachment key provided".to_string());
-                            return;
-                        }
-                        let attachment = Attachment::new(
-                            file_id,
-                            cipher_uuid.clone(),
-                            encrypted_filename.unwrap(),
-                            size,
-                            attachment_key.clone(),
-                        );
-                        attachment.save(conn).expect("Error saving attachment");
-                    }
-                }
-                _ => error!("Invalid multipart name"),
+    let folder_path = tokio::fs::canonicalize(&CONFIG.attachments_folder()).await?.join(&cipher_uuid);
+    let file_path = folder_path.join(&file_id);
+    tokio::fs::create_dir_all(&folder_path).await?;
+
+    let size = data.data.len() as i32;
+    if let Some(attachment) = &mut attachment {
+        // v2 API
+
+        // Check the actual size against the size initially provided by
+        // the client. Upstream allows +/- 1 MiB deviation from this
+        // size, but it's not clear when or why this is needed.
+        const LEEWAY: i32 = 1024 * 1024; // 1 MiB
+        let min_size = attachment.file_size - LEEWAY;
+        let max_size = attachment.file_size + LEEWAY;
+
+        if min_size <= size && size <= max_size {
+            if size != attachment.file_size {
+                // Update the attachment with the actual file size.
+                attachment.file_size = size;
+                attachment.save(&conn).expect("Error updating attachment");
             }
-        })
-        .expect("Error processing multipart data");
+        } else {
+            attachment.delete(&conn).ok();
+
+            err!(format!("Attachment size mismatch (expected within [{}, {}], got {})", min_size, max_size, size));
+        }
+    } else {
+        // Legacy API
+        let encrypted_filename = data.data.raw_name().map(|s| s.dangerous_unsafe_unsanitized_raw().to_string());
 
-    if let Some(ref e) = error {
-        std::fs::remove_file(path).ok();
-        err!(e);
+        if encrypted_filename.is_none() {
+            err!("No filename provided")
+        }
+        if data.key.is_none() {
+            err!("No attachment key provided")
+        }
+        let attachment = Attachment::new(file_id, cipher_uuid.clone(), encrypted_filename.unwrap(), size, data.key);
+        attachment.save(&conn).expect("Error saving attachment");
     }
 
-    nt.send_cipher_update(UpdateType::CipherUpdate, &cipher, &cipher.update_users_revision(conn));
+    data.data.persist_to(file_path).await?;
+
+    nt.send_cipher_update(UpdateType::CipherUpdate, &cipher, &cipher.update_users_revision(&conn));
 
-    Ok(cipher)
+    Ok((cipher, conn))
 }
 
 /// v2 API for uploading the actual data content of an attachment.
@@ -1006,14 +960,13 @@ fn save_attachment(
 /// /ciphers/<uuid>/attachment/v2 route, which would otherwise conflict
 /// with this one.
 #[post("/ciphers/<uuid>/attachment/<attachment_id>", format = "multipart/form-data", data = "<data>", rank = 1)]
-fn post_attachment_v2_data(
+async fn post_attachment_v2_data(
     uuid: String,
     attachment_id: String,
-    data: Data,
-    content_type: &ContentType,
+    data: Form<UploadData<'_>>,
     headers: Headers,
     conn: DbConn,
-    nt: Notify,
+    nt: Notify<'_>,
 ) -> EmptyResult {
     let attachment = match Attachment::find_by_id(&attachment_id, &conn) {
         Some(attachment) if uuid == attachment.cipher_uuid => Some(attachment),
@@ -1021,54 +974,51 @@ fn post_attachment_v2_data(
         None => err!("Attachment doesn't exist"),
     };
 
-    save_attachment(attachment, uuid, data, content_type, &headers, &conn, nt)?;
+    save_attachment(attachment, uuid, data, &headers, conn, nt).await?;
 
     Ok(())
 }
 
 /// Legacy API for creating an attachment associated with a cipher.
 #[post("/ciphers/<uuid>/attachment", format = "multipart/form-data", data = "<data>")]
-fn post_attachment(
+async fn post_attachment(
     uuid: String,
-    data: Data,
-    content_type: &ContentType,
+    data: Form<UploadData<'_>>,
     headers: Headers,
     conn: DbConn,
-    nt: Notify,
+    nt: Notify<'_>,
 ) -> JsonResult {
     // Setting this as None signifies to save_attachment() that it should create
     // the attachment database record as well as saving the data to disk.
     let attachment = None;
 
-    let cipher = save_attachment(attachment, uuid, data, content_type, &headers, &conn, nt)?;
+    let (cipher, conn) = save_attachment(attachment, uuid, data, &headers, conn, nt).await?;
 
     Ok(Json(cipher.to_json(&headers.host, &headers.user.uuid, &conn)))
 }
 
 #[post("/ciphers/<uuid>/attachment-admin", format = "multipart/form-data", data = "<data>")]
-fn post_attachment_admin(
+async fn post_attachment_admin(
     uuid: String,
-    data: Data,
-    content_type: &ContentType,
+    data: Form<UploadData<'_>>,
     headers: Headers,
     conn: DbConn,
-    nt: Notify,
+    nt: Notify<'_>,
 ) -> JsonResult {
-    post_attachment(uuid, data, content_type, headers, conn, nt)
+    post_attachment(uuid, data, headers, conn, nt).await
 }
 
 #[post("/ciphers/<uuid>/attachment/<attachment_id>/share", format = "multipart/form-data", data = "<data>")]
-fn post_attachment_share(
+async fn post_attachment_share(
     uuid: String,
     attachment_id: String,
-    data: Data,
-    content_type: &ContentType,
+    data: Form<UploadData<'_>>,
     headers: Headers,
     conn: DbConn,
-    nt: Notify,
+    nt: Notify<'_>,
 ) -> JsonResult {
     _delete_cipher_attachment_by_id(&uuid, &attachment_id, &headers, &conn, &nt)?;
-    post_attachment(uuid, data, content_type, headers, conn, nt)
+    post_attachment(uuid, data, headers, conn, nt).await
 }
 
 #[post("/ciphers/<uuid>/attachment/<attachment_id>/delete-admin")]
@@ -1248,13 +1198,13 @@ fn move_cipher_selected_put(
 
 #[derive(FromForm)]
 struct OrganizationId {
-    #[form(field = "organizationId")]
+    #[field(name = "organizationId")]
     org_id: String,
 }
 
 #[post("/ciphers/purge?<organization..>", data = "<data>")]
 fn delete_all(
-    organization: Option<Form<OrganizationId>>,
+    organization: Option<OrganizationId>,
     data: JsonUpcase<PasswordData>,
     headers: Headers,
     conn: DbConn,

+ 5 - 5
src/api/core/emergency_access.rs

@@ -1,6 +1,6 @@
 use chrono::{Duration, Utc};
+use rocket::serde::json::Json;
 use rocket::Route;
-use rocket_contrib::json::Json;
 use serde_json::Value;
 use std::borrow::Borrow;
 
@@ -709,13 +709,13 @@ fn check_emergency_access_allowed() -> EmptyResult {
     Ok(())
 }
 
-pub fn emergency_request_timeout_job(pool: DbPool) {
+pub async fn emergency_request_timeout_job(pool: DbPool) {
     debug!("Start emergency_request_timeout_job");
     if !CONFIG.emergency_access_allowed() {
         return;
     }
 
-    if let Ok(conn) = pool.get() {
+    if let Ok(conn) = pool.get().await {
         let emergency_access_list = EmergencyAccess::find_all_recoveries(&conn);
 
         if emergency_access_list.is_empty() {
@@ -756,13 +756,13 @@ pub fn emergency_request_timeout_job(pool: DbPool) {
     }
 }
 
-pub fn emergency_notification_reminder_job(pool: DbPool) {
+pub async fn emergency_notification_reminder_job(pool: DbPool) {
     debug!("Start emergency_notification_reminder_job");
     if !CONFIG.emergency_access_allowed() {
         return;
     }
 
-    if let Ok(conn) = pool.get() {
+    if let Ok(conn) = pool.get().await {
         let emergency_access_list = EmergencyAccess::find_all_recoveries(&conn);
 
         if emergency_access_list.is_empty() {

+ 1 - 1
src/api/core/folders.rs

@@ -1,4 +1,4 @@
-use rocket_contrib::json::Json;
+use rocket::serde::json::Json;
 use serde_json::Value;
 
 use crate::{

+ 4 - 4
src/api/core/mod.rs

@@ -31,8 +31,8 @@ pub fn routes() -> Vec<Route> {
 //
 // Move this somewhere else
 //
+use rocket::serde::json::Json;
 use rocket::Route;
-use rocket_contrib::json::Json;
 use serde_json::Value;
 
 use crate::{
@@ -144,7 +144,7 @@ fn put_eq_domains(data: JsonUpcase<EquivDomainData>, headers: Headers, conn: DbC
 }
 
 #[get("/hibp/breach?<username>")]
-fn hibp_breach(username: String) -> JsonResult {
+async fn hibp_breach(username: String) -> JsonResult {
     let url = format!(
         "https://haveibeenpwned.com/api/v3/breachedaccount/{}?truncateResponse=false&includeUnverified=false",
         username
@@ -153,14 +153,14 @@ fn hibp_breach(username: String) -> JsonResult {
     if let Some(api_key) = crate::CONFIG.hibp_api_key() {
         let hibp_client = get_reqwest_client();
 
-        let res = hibp_client.get(&url).header("hibp-api-key", api_key).send()?;
+        let res = hibp_client.get(&url).header("hibp-api-key", api_key).send().await?;
 
         // If we get a 404, return a 404, it means no breached accounts
         if res.status() == 404 {
             return Err(Error::empty().with_code(404));
         }
 
-        let value: Value = res.error_for_status()?.json()?;
+        let value: Value = res.error_for_status()?.json().await?;
         Ok(Json(value))
     } else {
         Ok(Json(json!([{

+ 6 - 6
src/api/core/organizations.rs

@@ -1,6 +1,6 @@
 use num_traits::FromPrimitive;
-use rocket::{request::Form, Route};
-use rocket_contrib::json::Json;
+use rocket::serde::json::Json;
+use rocket::Route;
 use serde_json::Value;
 
 use crate::{
@@ -469,12 +469,12 @@ fn put_collection_users(
 
 #[derive(FromForm)]
 struct OrgIdData {
-    #[form(field = "organizationId")]
+    #[field(name = "organizationId")]
     organization_id: String,
 }
 
 #[get("/ciphers/organization-details?<data..>")]
-fn get_org_details(data: Form<OrgIdData>, headers: Headers, conn: DbConn) -> Json<Value> {
+fn get_org_details(data: OrgIdData, headers: Headers, conn: DbConn) -> Json<Value> {
     let ciphers = Cipher::find_by_org(&data.organization_id, &conn);
     let ciphers_json: Vec<Value> =
         ciphers.iter().map(|c| c.to_json(&headers.host, &headers.user.uuid, &conn)).collect();
@@ -1097,14 +1097,14 @@ struct RelationsData {
 
 #[post("/ciphers/import-organization?<query..>", data = "<data>")]
 fn post_org_import(
-    query: Form<OrgIdData>,
+    query: OrgIdData,
     data: JsonUpcase<ImportData>,
     headers: AdminHeaders,
     conn: DbConn,
     nt: Notify,
 ) -> EmptyResult {
     let data: ImportData = data.into_inner().data;
-    let org_id = query.into_inner().organization_id;
+    let org_id = query.organization_id;
 
     // Read and create the collections
     let collections: Vec<_> = data

+ 33 - 52
src/api/core/sends.rs

@@ -1,9 +1,10 @@
-use std::{io::Read, path::Path};
+use std::path::Path;
 
 use chrono::{DateTime, Duration, Utc};
-use multipart::server::{save::SavedData, Multipart, SaveResult};
-use rocket::{http::ContentType, response::NamedFile, Data};
-use rocket_contrib::json::Json;
+use rocket::form::Form;
+use rocket::fs::NamedFile;
+use rocket::fs::TempFile;
+use rocket::serde::json::Json;
 use serde_json::Value;
 
 use crate::{
@@ -31,9 +32,9 @@ pub fn routes() -> Vec<rocket::Route> {
     ]
 }
 
-pub fn purge_sends(pool: DbPool) {
+pub async fn purge_sends(pool: DbPool) {
     debug!("Purging sends");
-    if let Ok(conn) = pool.get() {
+    if let Ok(conn) = pool.get().await {
         Send::purge(&conn);
     } else {
         error!("Failed to get DB connection while purging sends")
@@ -177,25 +178,23 @@ fn post_send(data: JsonUpcase<SendData>, headers: Headers, conn: DbConn, nt: Not
     Ok(Json(send.to_json()))
 }
 
+#[derive(FromForm)]
+struct UploadData<'f> {
+    model: Json<crate::util::UpCase<SendData>>,
+    data: TempFile<'f>,
+}
+
 #[post("/sends/file", format = "multipart/form-data", data = "<data>")]
-fn post_send_file(data: Data, content_type: &ContentType, headers: Headers, conn: DbConn, nt: Notify) -> JsonResult {
+async fn post_send_file(data: Form<UploadData<'_>>, headers: Headers, conn: DbConn, nt: Notify<'_>) -> JsonResult {
     enforce_disable_send_policy(&headers, &conn)?;
 
-    let boundary = content_type.params().next().expect("No boundary provided").1;
-
-    let mut mpart = Multipart::with_body(data.open(), boundary);
-
-    // First entry is the SendData JSON
-    let mut model_entry = match mpart.read_entry()? {
-        Some(e) if &*e.headers.name == "model" => e,
-        Some(_) => err!("Invalid entry name"),
-        None => err!("No model entry present"),
-    };
+    let UploadData {
+        model,
+        mut data,
+    } = data.into_inner();
+    let model = model.into_inner().data;
 
-    let mut buf = String::new();
-    model_entry.data.read_to_string(&mut buf)?;
-    let data = serde_json::from_str::<crate::util::UpCase<SendData>>(&buf)?;
-    enforce_disable_hide_email_policy(&data.data, &headers, &conn)?;
+    enforce_disable_hide_email_policy(&model, &headers, &conn)?;
 
     // Get the file length and add an extra 5% to avoid issues
     const SIZE_525_MB: u64 = 550_502_400;
@@ -212,45 +211,27 @@ fn post_send_file(data: Data, content_type: &ContentType, headers: Headers, conn
         None => SIZE_525_MB,
     };
 
-    // Create the Send
-    let mut send = create_send(data.data, headers.user.uuid)?;
-    let file_id = crate::crypto::generate_send_id();
-
+    let mut send = create_send(model, headers.user.uuid)?;
     if send.atype != SendType::File as i32 {
         err!("Send content is not a file");
     }
 
-    let file_path = Path::new(&CONFIG.sends_folder()).join(&send.uuid).join(&file_id);
-
-    // Read the data entry and save the file
-    let mut data_entry = match mpart.read_entry()? {
-        Some(e) if &*e.headers.name == "data" => e,
-        Some(_) => err!("Invalid entry name"),
-        None => err!("No model entry present"),
-    };
+    let size = data.len();
+    if size > size_limit {
+        err!("Attachment storage limit exceeded with this file");
+    }
 
-    let size = match data_entry.data.save().memory_threshold(0).size_limit(size_limit).with_path(&file_path) {
-        SaveResult::Full(SavedData::File(_, size)) => size as i32,
-        SaveResult::Full(other) => {
-            std::fs::remove_file(&file_path).ok();
-            err!(format!("Attachment is not a file: {:?}", other));
-        }
-        SaveResult::Partial(_, reason) => {
-            std::fs::remove_file(&file_path).ok();
-            err!(format!("Attachment storage limit exceeded with this file: {:?}", reason));
-        }
-        SaveResult::Error(e) => {
-            std::fs::remove_file(&file_path).ok();
-            err!(format!("Error: {:?}", e));
-        }
-    };
+    let file_id = crate::crypto::generate_send_id();
+    let folder_path = tokio::fs::canonicalize(&CONFIG.sends_folder()).await?.join(&send.uuid);
+    let file_path = folder_path.join(&file_id);
+    tokio::fs::create_dir_all(&folder_path).await?;
+    data.persist_to(&file_path).await?;
 
-    // Set ID and sizes
     let mut data_value: Value = serde_json::from_str(&send.data)?;
     if let Some(o) = data_value.as_object_mut() {
         o.insert(String::from("Id"), Value::String(file_id));
         o.insert(String::from("Size"), Value::Number(size.into()));
-        o.insert(String::from("SizeName"), Value::String(crate::util::get_display_size(size)));
+        o.insert(String::from("SizeName"), Value::String(crate::util::get_display_size(size as i32)));
     }
     send.data = serde_json::to_string(&data_value)?;
 
@@ -367,10 +348,10 @@ fn post_access_file(
 }
 
 #[get("/sends/<send_id>/<file_id>?<t>")]
-fn download_send(send_id: SafeString, file_id: SafeString, t: String) -> Option<NamedFile> {
+async fn download_send(send_id: SafeString, file_id: SafeString, t: String) -> Option<NamedFile> {
     if let Ok(claims) = crate::auth::decode_send(&t) {
         if claims.sub == format!("{}/{}", send_id, file_id) {
-            return NamedFile::open(Path::new(&CONFIG.sends_folder()).join(send_id).join(file_id)).ok();
+            return NamedFile::open(Path::new(&CONFIG.sends_folder()).join(send_id).join(file_id)).await.ok();
         }
     }
     None

+ 1 - 1
src/api/core/two_factor/authenticator.rs

@@ -1,6 +1,6 @@
 use data_encoding::BASE32;
+use rocket::serde::json::Json;
 use rocket::Route;
-use rocket_contrib::json::Json;
 
 use crate::{
     api::{

+ 8 - 7
src/api/core/two_factor/duo.rs

@@ -1,7 +1,7 @@
 use chrono::Utc;
 use data_encoding::BASE64;
+use rocket::serde::json::Json;
 use rocket::Route;
-use rocket_contrib::json::Json;
 
 use crate::{
     api::{core::two_factor::_generate_recover_code, ApiResult, EmptyResult, JsonResult, JsonUpcase, PasswordData},
@@ -152,7 +152,7 @@ fn check_duo_fields_custom(data: &EnableDuoData) -> bool {
 }
 
 #[post("/two-factor/duo", data = "<data>")]
-fn activate_duo(data: JsonUpcase<EnableDuoData>, headers: Headers, conn: DbConn) -> JsonResult {
+async fn activate_duo(data: JsonUpcase<EnableDuoData>, headers: Headers, conn: DbConn) -> JsonResult {
     let data: EnableDuoData = data.into_inner().data;
     let mut user = headers.user;
 
@@ -163,7 +163,7 @@ fn activate_duo(data: JsonUpcase<EnableDuoData>, headers: Headers, conn: DbConn)
     let (data, data_str) = if check_duo_fields_custom(&data) {
         let data_req: DuoData = data.into();
         let data_str = serde_json::to_string(&data_req)?;
-        duo_api_request("GET", "/auth/v2/check", "", &data_req).map_res("Failed to validate Duo credentials")?;
+        duo_api_request("GET", "/auth/v2/check", "", &data_req).await.map_res("Failed to validate Duo credentials")?;
         (data_req.obscure(), data_str)
     } else {
         (DuoData::secret(), String::new())
@@ -185,11 +185,11 @@ fn activate_duo(data: JsonUpcase<EnableDuoData>, headers: Headers, conn: DbConn)
 }
 
 #[put("/two-factor/duo", data = "<data>")]
-fn activate_duo_put(data: JsonUpcase<EnableDuoData>, headers: Headers, conn: DbConn) -> JsonResult {
-    activate_duo(data, headers, conn)
+async fn activate_duo_put(data: JsonUpcase<EnableDuoData>, headers: Headers, conn: DbConn) -> JsonResult {
+    activate_duo(data, headers, conn).await
 }
 
-fn duo_api_request(method: &str, path: &str, params: &str, data: &DuoData) -> EmptyResult {
+async fn duo_api_request(method: &str, path: &str, params: &str, data: &DuoData) -> EmptyResult {
     use reqwest::{header, Method};
     use std::str::FromStr;
 
@@ -209,7 +209,8 @@ fn duo_api_request(method: &str, path: &str, params: &str, data: &DuoData) -> Em
         .basic_auth(username, Some(password))
         .header(header::USER_AGENT, "vaultwarden:Duo/1.0 (Rust)")
         .header(header::DATE, date)
-        .send()?
+        .send()
+        .await?
         .error_for_status()?;
 
     Ok(())

+ 1 - 1
src/api/core/two_factor/email.rs

@@ -1,6 +1,6 @@
 use chrono::{Duration, NaiveDateTime, Utc};
+use rocket::serde::json::Json;
 use rocket::Route;
-use rocket_contrib::json::Json;
 
 use crate::{
     api::{core::two_factor::_generate_recover_code, EmptyResult, JsonResult, JsonUpcase, PasswordData},

+ 3 - 3
src/api/core/two_factor/mod.rs

@@ -1,7 +1,7 @@
 use chrono::{Duration, Utc};
 use data_encoding::BASE32;
+use rocket::serde::json::Json;
 use rocket::Route;
-use rocket_contrib::json::Json;
 use serde_json::Value;
 
 use crate::{
@@ -158,14 +158,14 @@ fn disable_twofactor_put(data: JsonUpcase<DisableTwoFactorData>, headers: Header
     disable_twofactor(data, headers, conn)
 }
 
-pub fn send_incomplete_2fa_notifications(pool: DbPool) {
+pub async fn send_incomplete_2fa_notifications(pool: DbPool) {
     debug!("Sending notifications for incomplete 2FA logins");
 
     if CONFIG.incomplete_2fa_time_limit() <= 0 || !CONFIG.mail_enabled() {
         return;
     }
 
-    let conn = match pool.get() {
+    let conn = match pool.get().await {
         Ok(conn) => conn,
         _ => {
             error!("Failed to get DB connection in send_incomplete_2fa_notifications()");

+ 1 - 1
src/api/core/two_factor/u2f.rs

@@ -1,6 +1,6 @@
 use once_cell::sync::Lazy;
+use rocket::serde::json::Json;
 use rocket::Route;
-use rocket_contrib::json::Json;
 use serde_json::Value;
 use u2f::{
     messages::{RegisterResponse, SignResponse, U2fSignRequest},

+ 1 - 1
src/api/core/two_factor/webauthn.rs

@@ -1,5 +1,5 @@
+use rocket::serde::json::Json;
 use rocket::Route;
-use rocket_contrib::json::Json;
 use serde_json::Value;
 use url::Url;
 use webauthn_rs::{base64_data::Base64UrlSafeData, proto::*, AuthenticationState, RegistrationState, Webauthn};

+ 1 - 1
src/api/core/two_factor/yubikey.rs

@@ -1,5 +1,5 @@
+use rocket::serde::json::Json;
 use rocket::Route;
-use rocket_contrib::json::Json;
 use serde_json::Value;
 use yubico::{config::Config, verify};
 

+ 66 - 59
src/api/icons.rs

@@ -1,19 +1,19 @@
 use std::{
     collections::HashMap,
-    fs::{create_dir_all, remove_file, symlink_metadata, File},
-    io::prelude::*,
     net::{IpAddr, ToSocketAddrs},
     sync::{Arc, RwLock},
     time::{Duration, SystemTime},
 };
 
+use bytes::{Buf, Bytes, BytesMut};
+use futures::{stream::StreamExt, TryFutureExt};
 use once_cell::sync::Lazy;
 use regex::Regex;
-use reqwest::{blocking::Client, blocking::Response, header};
-use rocket::{
-    http::ContentType,
-    response::{Content, Redirect},
-    Route,
+use reqwest::{header, Client, Response};
+use rocket::{http::ContentType, response::Redirect, Route};
+use tokio::{
+    fs::{create_dir_all, remove_file, symlink_metadata, File},
+    io::{AsyncReadExt, AsyncWriteExt},
 };
 
 use crate::{
@@ -104,27 +104,23 @@ fn icon_google(domain: String) -> Option<Redirect> {
 }
 
 #[get("/<domain>/icon.png")]
-fn icon_internal(domain: String) -> Cached<Content<Vec<u8>>> {
+async fn icon_internal(domain: String) -> Cached<(ContentType, Vec<u8>)> {
     const FALLBACK_ICON: &[u8] = include_bytes!("../static/images/fallback-icon.png");
 
     if !is_valid_domain(&domain) {
         warn!("Invalid domain: {}", domain);
         return Cached::ttl(
-            Content(ContentType::new("image", "png"), FALLBACK_ICON.to_vec()),
+            (ContentType::new("image", "png"), FALLBACK_ICON.to_vec()),
             CONFIG.icon_cache_negttl(),
             true,
         );
     }
 
-    match get_icon(&domain) {
+    match get_icon(&domain).await {
         Some((icon, icon_type)) => {
-            Cached::ttl(Content(ContentType::new("image", icon_type), icon), CONFIG.icon_cache_ttl(), true)
+            Cached::ttl((ContentType::new("image", icon_type), icon), CONFIG.icon_cache_ttl(), true)
         }
-        _ => Cached::ttl(
-            Content(ContentType::new("image", "png"), FALLBACK_ICON.to_vec()),
-            CONFIG.icon_cache_negttl(),
-            true,
-        ),
+        _ => Cached::ttl((ContentType::new("image", "png"), FALLBACK_ICON.to_vec()), CONFIG.icon_cache_negttl(), true),
     }
 }
 
@@ -317,15 +313,15 @@ fn is_domain_blacklisted(domain: &str) -> bool {
     is_blacklisted
 }
 
-fn get_icon(domain: &str) -> Option<(Vec<u8>, String)> {
+async fn get_icon(domain: &str) -> Option<(Vec<u8>, String)> {
     let path = format!("{}/{}.png", CONFIG.icon_cache_folder(), domain);
 
     // Check for expiration of negatively cached copy
-    if icon_is_negcached(&path) {
+    if icon_is_negcached(&path).await {
         return None;
     }
 
-    if let Some(icon) = get_cached_icon(&path) {
+    if let Some(icon) = get_cached_icon(&path).await {
         let icon_type = match get_icon_type(&icon) {
             Some(x) => x,
             _ => "x-icon",
@@ -338,31 +334,31 @@ fn get_icon(domain: &str) -> Option<(Vec<u8>, String)> {
     }
 
     // Get the icon, or None in case of error
-    match download_icon(domain) {
+    match download_icon(domain).await {
         Ok((icon, icon_type)) => {
-            save_icon(&path, &icon);
-            Some((icon, icon_type.unwrap_or("x-icon").to_string()))
+            save_icon(&path, &icon).await;
+            Some((icon.to_vec(), icon_type.unwrap_or("x-icon").to_string()))
         }
         Err(e) => {
             warn!("Unable to download icon: {:?}", e);
             let miss_indicator = path + ".miss";
-            save_icon(&miss_indicator, &[]);
+            save_icon(&miss_indicator, &[]).await;
             None
         }
     }
 }
 
-fn get_cached_icon(path: &str) -> Option<Vec<u8>> {
+async fn get_cached_icon(path: &str) -> Option<Vec<u8>> {
     // Check for expiration of successfully cached copy
-    if icon_is_expired(path) {
+    if icon_is_expired(path).await {
         return None;
     }
 
     // Try to read the cached icon, and return it if it exists
-    if let Ok(mut f) = File::open(path) {
+    if let Ok(mut f) = File::open(path).await {
         let mut buffer = Vec::new();
 
-        if f.read_to_end(&mut buffer).is_ok() {
+        if f.read_to_end(&mut buffer).await.is_ok() {
             return Some(buffer);
         }
     }
@@ -370,22 +366,22 @@ fn get_cached_icon(path: &str) -> Option<Vec<u8>> {
     None
 }
 
-fn file_is_expired(path: &str, ttl: u64) -> Result<bool, Error> {
-    let meta = symlink_metadata(path)?;
+async fn file_is_expired(path: &str, ttl: u64) -> Result<bool, Error> {
+    let meta = symlink_metadata(path).await?;
     let modified = meta.modified()?;
     let age = SystemTime::now().duration_since(modified)?;
 
     Ok(ttl > 0 && ttl <= age.as_secs())
 }
 
-fn icon_is_negcached(path: &str) -> bool {
+async fn icon_is_negcached(path: &str) -> bool {
     let miss_indicator = path.to_owned() + ".miss";
-    let expired = file_is_expired(&miss_indicator, CONFIG.icon_cache_negttl());
+    let expired = file_is_expired(&miss_indicator, CONFIG.icon_cache_negttl()).await;
 
     match expired {
         // No longer negatively cached, drop the marker
         Ok(true) => {
-            if let Err(e) = remove_file(&miss_indicator) {
+            if let Err(e) = remove_file(&miss_indicator).await {
                 error!("Could not remove negative cache indicator for icon {:?}: {:?}", path, e);
             }
             false
@@ -397,8 +393,8 @@ fn icon_is_negcached(path: &str) -> bool {
     }
 }
 
-fn icon_is_expired(path: &str) -> bool {
-    let expired = file_is_expired(path, CONFIG.icon_cache_ttl());
+async fn icon_is_expired(path: &str) -> bool {
+    let expired = file_is_expired(path, CONFIG.icon_cache_ttl()).await;
     expired.unwrap_or(true)
 }
 
@@ -521,13 +517,13 @@ struct IconUrlResult {
 /// let icon_result = get_icon_url("github.com")?;
 /// let icon_result = get_icon_url("vaultwarden.discourse.group")?;
 /// ```
-fn get_icon_url(domain: &str) -> Result<IconUrlResult, Error> {
+async fn get_icon_url(domain: &str) -> Result<IconUrlResult, Error> {
     // Default URL with secure and insecure schemes
     let ssldomain = format!("https://{}", domain);
     let httpdomain = format!("http://{}", domain);
 
     // First check the domain as given during the request for both HTTPS and HTTP.
-    let resp = match get_page(&ssldomain).or_else(|_| get_page(&httpdomain)) {
+    let resp = match get_page(&ssldomain).or_else(|_| get_page(&httpdomain)).await {
         Ok(c) => Ok(c),
         Err(e) => {
             let mut sub_resp = Err(e);
@@ -546,7 +542,7 @@ fn get_icon_url(domain: &str) -> Result<IconUrlResult, Error> {
                     let httpbase = format!("http://{}", base_domain);
                     debug!("[get_icon_url]: Trying without subdomains '{}'", base_domain);
 
-                    sub_resp = get_page(&sslbase).or_else(|_| get_page(&httpbase));
+                    sub_resp = get_page(&sslbase).or_else(|_| get_page(&httpbase)).await;
                 }
 
             // When the domain is not an IP, and has less then 2 dots, try to add www. infront of it.
@@ -557,7 +553,7 @@ fn get_icon_url(domain: &str) -> Result<IconUrlResult, Error> {
                     let httpwww = format!("http://{}", www_domain);
                     debug!("[get_icon_url]: Trying with www. prefix '{}'", www_domain);
 
-                    sub_resp = get_page(&sslwww).or_else(|_| get_page(&httpwww));
+                    sub_resp = get_page(&sslwww).or_else(|_| get_page(&httpwww)).await;
                 }
             }
 
@@ -581,7 +577,7 @@ fn get_icon_url(domain: &str) -> Result<IconUrlResult, Error> {
         iconlist.push(Icon::new(35, String::from(url.join("/favicon.ico").unwrap())));
 
         // 384KB should be more than enough for the HTML, though as we only really need the HTML header.
-        let mut limited_reader = content.take(384 * 1024);
+        let mut limited_reader = stream_to_bytes_limit(content, 384 * 1024).await?.reader();
 
         use html5ever::tendril::TendrilSink;
         let dom = html5ever::parse_document(markup5ever_rcdom::RcDom::default(), Default::default())
@@ -607,11 +603,11 @@ fn get_icon_url(domain: &str) -> Result<IconUrlResult, Error> {
     })
 }
 
-fn get_page(url: &str) -> Result<Response, Error> {
-    get_page_with_referer(url, "")
+async fn get_page(url: &str) -> Result<Response, Error> {
+    get_page_with_referer(url, "").await
 }
 
-fn get_page_with_referer(url: &str, referer: &str) -> Result<Response, Error> {
+async fn get_page_with_referer(url: &str, referer: &str) -> Result<Response, Error> {
     if is_domain_blacklisted(url::Url::parse(url).unwrap().host_str().unwrap_or_default()) {
         warn!("Favicon '{}' resolves to a blacklisted domain or IP!", url);
     }
@@ -621,7 +617,7 @@ fn get_page_with_referer(url: &str, referer: &str) -> Result<Response, Error> {
         client = client.header("Referer", referer)
     }
 
-    match client.send() {
+    match client.send().await {
         Ok(c) => c.error_for_status().map_err(Into::into),
         Err(e) => err_silent!(format!("{}", e)),
     }
@@ -706,14 +702,14 @@ fn parse_sizes(sizes: Option<&str>) -> (u16, u16) {
     (width, height)
 }
 
-fn download_icon(domain: &str) -> Result<(Vec<u8>, Option<&str>), Error> {
+async fn download_icon(domain: &str) -> Result<(Bytes, Option<&str>), Error> {
     if is_domain_blacklisted(domain) {
         err_silent!("Domain is blacklisted", domain)
     }
 
-    let icon_result = get_icon_url(domain)?;
+    let icon_result = get_icon_url(domain).await?;
 
-    let mut buffer = Vec::new();
+    let mut buffer = Bytes::new();
     let mut icon_type: Option<&str> = None;
 
     use data_url::DataUrl;
@@ -722,8 +718,12 @@ fn download_icon(domain: &str) -> Result<(Vec<u8>, Option<&str>), Error> {
         if icon.href.starts_with("data:image") {
             let datauri = DataUrl::process(&icon.href).unwrap();
             // Check if we are able to decode the data uri
-            match datauri.decode_to_vec() {
-                Ok((body, _fragment)) => {
+            let mut body = BytesMut::new();
+            match datauri.decode::<_, ()>(|bytes| {
+                body.extend_from_slice(bytes);
+                Ok(())
+            }) {
+                Ok(_) => {
                     // Also check if the size is atleast 67 bytes, which seems to be the smallest png i could create
                     if body.len() >= 67 {
                         // Check if the icon type is allowed, else try an icon from the list.
@@ -733,17 +733,17 @@ fn download_icon(domain: &str) -> Result<(Vec<u8>, Option<&str>), Error> {
                             continue;
                         }
                         info!("Extracted icon from data:image uri for {}", domain);
-                        buffer = body;
+                        buffer = body.freeze();
                         break;
                     }
                 }
                 _ => debug!("Extracted icon from data:image uri is invalid"),
             };
         } else {
-            match get_page_with_referer(&icon.href, &icon_result.referer) {
-                Ok(mut res) => {
-                    res.copy_to(&mut buffer)?;
-                    // Check if the icon type is allowed, else try an icon from the list.
+            match get_page_with_referer(&icon.href, &icon_result.referer).await {
+                Ok(res) => {
+                    buffer = stream_to_bytes_limit(res, 512 * 1024).await?; // 512 KB for each icon max
+                                                                            // Check if the icon type is allowed, else try an icon from the list.
                     icon_type = get_icon_type(&buffer);
                     if icon_type.is_none() {
                         buffer.clear();
@@ -765,13 +765,13 @@ fn download_icon(domain: &str) -> Result<(Vec<u8>, Option<&str>), Error> {
     Ok((buffer, icon_type))
 }
 
-fn save_icon(path: &str, icon: &[u8]) {
-    match File::create(path) {
+async fn save_icon(path: &str, icon: &[u8]) {
+    match File::create(path).await {
         Ok(mut f) => {
-            f.write_all(icon).expect("Error writing icon file");
+            f.write_all(icon).await.expect("Error writing icon file");
         }
         Err(ref e) if e.kind() == std::io::ErrorKind::NotFound => {
-            create_dir_all(&CONFIG.icon_cache_folder()).expect("Error creating icon cache folder");
+            create_dir_all(&CONFIG.icon_cache_folder()).await.expect("Error creating icon cache folder");
         }
         Err(e) => {
             warn!("Unable to save icon: {:?}", e);
@@ -820,8 +820,6 @@ impl reqwest::cookie::CookieStore for Jar {
     }
 
     fn cookies(&self, url: &url::Url) -> Option<header::HeaderValue> {
-        use bytes::Bytes;
-
         let cookie_store = self.0.read().unwrap();
         let s = cookie_store
             .get_request_values(url)
@@ -836,3 +834,12 @@ impl reqwest::cookie::CookieStore for Jar {
         header::HeaderValue::from_maybe_shared(Bytes::from(s)).ok()
     }
 }
+
+async fn stream_to_bytes_limit(res: Response, max_size: usize) -> Result<Bytes, reqwest::Error> {
+    let mut stream = res.bytes_stream().take(max_size);
+    let mut buf = BytesMut::new();
+    while let Some(chunk) = stream.next().await {
+        buf.extend(chunk?);
+    }
+    Ok(buf.freeze())
+}

+ 31 - 40
src/api/identity.rs

@@ -1,10 +1,10 @@
 use chrono::Utc;
 use num_traits::FromPrimitive;
+use rocket::serde::json::Json;
 use rocket::{
-    request::{Form, FormItems, FromForm},
+    form::{Form, FromForm},
     Route,
 };
-use rocket_contrib::json::Json;
 use serde_json::Value;
 
 use crate::{
@@ -455,66 +455,57 @@ fn _json_err_twofactor(providers: &[i32], user_uuid: &str, conn: &DbConn) -> Api
 
 // https://github.com/bitwarden/jslib/blob/master/common/src/models/request/tokenRequest.ts
 // https://github.com/bitwarden/mobile/blob/master/src/Core/Models/Request/TokenRequest.cs
-#[derive(Debug, Clone, Default)]
+#[derive(Debug, Clone, Default, FromForm)]
 #[allow(non_snake_case)]
 struct ConnectData {
-    // refresh_token, password, client_credentials (API key)
-    grant_type: String,
+    #[field(name = uncased("grant_type"))]
+    #[field(name = uncased("granttype"))]
+    grant_type: String, // refresh_token, password, client_credentials (API key)
 
     // Needed for grant_type="refresh_token"
+    #[field(name = uncased("refresh_token"))]
+    #[field(name = uncased("refreshtoken"))]
     refresh_token: Option<String>,
 
     // Needed for grant_type = "password" | "client_credentials"
-    client_id: Option<String>,     // web, cli, desktop, browser, mobile
-    client_secret: Option<String>, // API key login (cli only)
+    #[field(name = uncased("client_id"))]
+    #[field(name = uncased("clientid"))]
+    client_id: Option<String>, // web, cli, desktop, browser, mobile
+    #[field(name = uncased("client_secret"))]
+    #[field(name = uncased("clientsecret"))]
+    client_secret: Option<String>,
+    #[field(name = uncased("password"))]
     password: Option<String>,
+    #[field(name = uncased("scope"))]
     scope: Option<String>,
+    #[field(name = uncased("username"))]
     username: Option<String>,
 
+    #[field(name = uncased("device_identifier"))]
+    #[field(name = uncased("deviceidentifier"))]
     device_identifier: Option<String>,
+    #[field(name = uncased("device_name"))]
+    #[field(name = uncased("devicename"))]
     device_name: Option<String>,
+    #[field(name = uncased("device_type"))]
+    #[field(name = uncased("devicetype"))]
     device_type: Option<String>,
+    #[field(name = uncased("device_push_token"))]
+    #[field(name = uncased("devicepushtoken"))]
     device_push_token: Option<String>, // Unused; mobile device push not yet supported.
 
     // Needed for two-factor auth
+    #[field(name = uncased("two_factor_provider"))]
+    #[field(name = uncased("twofactorprovider"))]
     two_factor_provider: Option<i32>,
+    #[field(name = uncased("two_factor_token"))]
+    #[field(name = uncased("twofactortoken"))]
     two_factor_token: Option<String>,
+    #[field(name = uncased("two_factor_remember"))]
+    #[field(name = uncased("twofactorremember"))]
     two_factor_remember: Option<i32>,
 }
 
-impl<'f> FromForm<'f> for ConnectData {
-    type Error = String;
-
-    fn from_form(items: &mut FormItems<'f>, _strict: bool) -> Result<Self, Self::Error> {
-        let mut form = Self::default();
-        for item in items {
-            let (key, value) = item.key_value_decoded();
-            let mut normalized_key = key.to_lowercase();
-            normalized_key.retain(|c| c != '_'); // Remove '_'
-
-            match normalized_key.as_ref() {
-                "granttype" => form.grant_type = value,
-                "refreshtoken" => form.refresh_token = Some(value),
-                "clientid" => form.client_id = Some(value),
-                "clientsecret" => form.client_secret = Some(value),
-                "password" => form.password = Some(value),
-                "scope" => form.scope = Some(value),
-                "username" => form.username = Some(value),
-                "deviceidentifier" => form.device_identifier = Some(value),
-                "devicename" => form.device_name = Some(value),
-                "devicetype" => form.device_type = Some(value),
-                "devicepushtoken" => form.device_push_token = Some(value),
-                "twofactorprovider" => form.two_factor_provider = value.parse().ok(),
-                "twofactortoken" => form.two_factor_token = Some(value),
-                "twofactorremember" => form.two_factor_remember = value.parse().ok(),
-                key => warn!("Detected unexpected parameter during login: {}", key),
-            }
-        }
-
-        Ok(form)
-    }
-}
-
 fn _check_is_some<T>(value: &Option<T>, msg: &str) -> EmptyResult {
     if value.is_none() {
         err!(msg)

+ 1 - 1
src/api/mod.rs

@@ -5,7 +5,7 @@ mod identity;
 mod notifications;
 mod web;
 
-use rocket_contrib::json::Json;
+use rocket::serde::json::Json;
 use serde_json::Value;
 
 pub use crate::api::{

+ 7 - 8
src/api/notifications.rs

@@ -1,7 +1,7 @@
 use std::sync::atomic::{AtomicBool, Ordering};
 
+use rocket::serde::json::Json;
 use rocket::Route;
-use rocket_contrib::json::Json;
 use serde_json::Value as JsonValue;
 
 use crate::{api::EmptyResult, auth::Headers, Error, CONFIG};
@@ -417,7 +417,7 @@ pub enum UpdateType {
 }
 
 use rocket::State;
-pub type Notify<'a> = State<'a, WebSocketUsers>;
+pub type Notify<'a> = &'a State<WebSocketUsers>;
 
 pub fn start_notification_server() -> WebSocketUsers {
     let factory = WsFactory::init();
@@ -430,12 +430,11 @@ pub fn start_notification_server() -> WebSocketUsers {
             settings.queue_size = 2;
             settings.panic_on_internal = false;
 
-            ws::Builder::new()
-                .with_settings(settings)
-                .build(factory)
-                .unwrap()
-                .listen((CONFIG.websocket_address().as_str(), CONFIG.websocket_port()))
-                .unwrap();
+            let ws = ws::Builder::new().with_settings(settings).build(factory).unwrap();
+            CONFIG.set_ws_shutdown_handle(ws.broadcaster());
+            ws.listen((CONFIG.websocket_address().as_str(), CONFIG.websocket_port())).unwrap();
+
+            warn!("WS Server stopped!");
         });
     }
 

+ 22 - 27
src/api/web.rs

@@ -1,7 +1,7 @@
 use std::path::{Path, PathBuf};
 
-use rocket::{http::ContentType, response::content::Content, response::NamedFile, Route};
-use rocket_contrib::json::Json;
+use rocket::serde::json::Json;
+use rocket::{fs::NamedFile, http::ContentType, Route};
 use serde_json::Value;
 
 use crate::{
@@ -21,16 +21,16 @@ pub fn routes() -> Vec<Route> {
 }
 
 #[get("/")]
-fn web_index() -> Cached<Option<NamedFile>> {
-    Cached::short(NamedFile::open(Path::new(&CONFIG.web_vault_folder()).join("index.html")).ok(), false)
+async fn web_index() -> Cached<Option<NamedFile>> {
+    Cached::short(NamedFile::open(Path::new(&CONFIG.web_vault_folder()).join("index.html")).await.ok(), false)
 }
 
 #[get("/app-id.json")]
-fn app_id() -> Cached<Content<Json<Value>>> {
+fn app_id() -> Cached<(ContentType, Json<Value>)> {
     let content_type = ContentType::new("application", "fido.trusted-apps+json");
 
     Cached::long(
-        Content(
+        (
             content_type,
             Json(json!({
             "trustedFacets": [
@@ -58,13 +58,13 @@ fn app_id() -> Cached<Content<Json<Value>>> {
 }
 
 #[get("/<p..>", rank = 10)] // Only match this if the other routes don't match
-fn web_files(p: PathBuf) -> Cached<Option<NamedFile>> {
-    Cached::long(NamedFile::open(Path::new(&CONFIG.web_vault_folder()).join(p)).ok(), true)
+async fn web_files(p: PathBuf) -> Cached<Option<NamedFile>> {
+    Cached::long(NamedFile::open(Path::new(&CONFIG.web_vault_folder()).join(p)).await.ok(), true)
 }
 
 #[get("/attachments/<uuid>/<file_id>")]
-fn attachments(uuid: SafeString, file_id: SafeString) -> Option<NamedFile> {
-    NamedFile::open(Path::new(&CONFIG.attachments_folder()).join(uuid).join(file_id)).ok()
+async fn attachments(uuid: SafeString, file_id: SafeString) -> Option<NamedFile> {
+    NamedFile::open(Path::new(&CONFIG.attachments_folder()).join(uuid).join(file_id)).await.ok()
 }
 
 // We use DbConn here to let the alive healthcheck also verify the database connection.
@@ -78,25 +78,20 @@ fn alive(_conn: DbConn) -> Json<String> {
 }
 
 #[get("/vw_static/<filename>")]
-fn static_files(filename: String) -> Result<Content<&'static [u8]>, Error> {
+fn static_files(filename: String) -> Result<(ContentType, &'static [u8]), Error> {
     match filename.as_ref() {
-        "mail-github.png" => Ok(Content(ContentType::PNG, include_bytes!("../static/images/mail-github.png"))),
-        "logo-gray.png" => Ok(Content(ContentType::PNG, include_bytes!("../static/images/logo-gray.png"))),
-        "error-x.svg" => Ok(Content(ContentType::SVG, include_bytes!("../static/images/error-x.svg"))),
-        "hibp.png" => Ok(Content(ContentType::PNG, include_bytes!("../static/images/hibp.png"))),
-        "vaultwarden-icon.png" => {
-            Ok(Content(ContentType::PNG, include_bytes!("../static/images/vaultwarden-icon.png")))
-        }
-
-        "bootstrap.css" => Ok(Content(ContentType::CSS, include_bytes!("../static/scripts/bootstrap.css"))),
-        "bootstrap-native.js" => {
-            Ok(Content(ContentType::JavaScript, include_bytes!("../static/scripts/bootstrap-native.js")))
-        }
-        "identicon.js" => Ok(Content(ContentType::JavaScript, include_bytes!("../static/scripts/identicon.js"))),
-        "datatables.js" => Ok(Content(ContentType::JavaScript, include_bytes!("../static/scripts/datatables.js"))),
-        "datatables.css" => Ok(Content(ContentType::CSS, include_bytes!("../static/scripts/datatables.css"))),
+        "mail-github.png" => Ok((ContentType::PNG, include_bytes!("../static/images/mail-github.png"))),
+        "logo-gray.png" => Ok((ContentType::PNG, include_bytes!("../static/images/logo-gray.png"))),
+        "error-x.svg" => Ok((ContentType::SVG, include_bytes!("../static/images/error-x.svg"))),
+        "hibp.png" => Ok((ContentType::PNG, include_bytes!("../static/images/hibp.png"))),
+        "vaultwarden-icon.png" => Ok((ContentType::PNG, include_bytes!("../static/images/vaultwarden-icon.png"))),
+        "bootstrap.css" => Ok((ContentType::CSS, include_bytes!("../static/scripts/bootstrap.css"))),
+        "bootstrap-native.js" => Ok((ContentType::JavaScript, include_bytes!("../static/scripts/bootstrap-native.js"))),
+        "identicon.js" => Ok((ContentType::JavaScript, include_bytes!("../static/scripts/identicon.js"))),
+        "datatables.js" => Ok((ContentType::JavaScript, include_bytes!("../static/scripts/datatables.js"))),
+        "datatables.css" => Ok((ContentType::CSS, include_bytes!("../static/scripts/datatables.css"))),
         "jquery-3.6.0.slim.js" => {
-            Ok(Content(ContentType::JavaScript, include_bytes!("../static/scripts/jquery-3.6.0.slim.js")))
+            Ok((ContentType::JavaScript, include_bytes!("../static/scripts/jquery-3.6.0.slim.js")))
         }
         _ => err!(format!("Static file not found: {}", filename)),
     }

+ 122 - 144
src/auth.rs

@@ -257,7 +257,10 @@ pub fn generate_send_claims(send_id: &str, file_id: &str) -> BasicJwtClaims {
 //
 // Bearer token authentication
 //
-use rocket::request::{FromRequest, Outcome, Request};
+use rocket::{
+    outcome::try_outcome,
+    request::{FromRequest, Outcome, Request},
+};
 
 use crate::db::{
     models::{CollectionUser, Device, User, UserOrgStatus, UserOrgType, UserOrganization, UserStampException},
@@ -268,10 +271,11 @@ pub struct Host {
     pub host: String,
 }
 
-impl<'a, 'r> FromRequest<'a, 'r> for Host {
+#[rocket::async_trait]
+impl<'r> FromRequest<'r> for Host {
     type Error = &'static str;
 
-    fn from_request(request: &'a Request<'r>) -> Outcome<Self, Self::Error> {
+    async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
         let headers = request.headers();
 
         // Get host
@@ -314,17 +318,14 @@ pub struct Headers {
     pub user: User,
 }
 
-impl<'a, 'r> FromRequest<'a, 'r> for Headers {
+#[rocket::async_trait]
+impl<'r> FromRequest<'r> for Headers {
     type Error = &'static str;
 
-    fn from_request(request: &'a Request<'r>) -> Outcome<Self, Self::Error> {
+    async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
         let headers = request.headers();
 
-        let host = match Host::from_request(request) {
-            Outcome::Forward(_) => return Outcome::Forward(()),
-            Outcome::Failure(f) => return Outcome::Failure(f),
-            Outcome::Success(host) => host.host,
-        };
+        let host = try_outcome!(Host::from_request(request).await).host;
 
         // Get access_token
         let access_token: &str = match headers.get_one("Authorization") {
@@ -344,7 +345,7 @@ impl<'a, 'r> FromRequest<'a, 'r> for Headers {
         let device_uuid = claims.device;
         let user_uuid = claims.sub;
 
-        let conn = match request.guard::<DbConn>() {
+        let conn = match DbConn::from_request(request).await {
             Outcome::Success(conn) => conn,
             _ => err_handler!("Error getting DB"),
         };
@@ -363,7 +364,7 @@ impl<'a, 'r> FromRequest<'a, 'r> for Headers {
             if let Some(stamp_exception) =
                 user.stamp_exception.as_deref().and_then(|s| serde_json::from_str::<UserStampException>(s).ok())
             {
-                let current_route = match request.route().and_then(|r| r.name) {
+                let current_route = match request.route().and_then(|r| r.name.as_deref()) {
                     Some(name) => name,
                     _ => err_handler!("Error getting current route for stamp exception"),
                 };
@@ -411,13 +412,13 @@ pub struct OrgHeaders {
 // but there are cases where it is a query value.
 // First check the path, if this is not a valid uuid, try the query values.
 fn get_org_id(request: &Request) -> Option<String> {
-    if let Some(Ok(org_id)) = request.get_param::<String>(1) {
+    if let Some(Ok(org_id)) = request.param::<String>(1) {
         if uuid::Uuid::parse_str(&org_id).is_ok() {
             return Some(org_id);
         }
     }
 
-    if let Some(Ok(org_id)) = request.get_query_value::<String>("organizationId") {
+    if let Some(Ok(org_id)) = request.query_value::<String>("organizationId") {
         if uuid::Uuid::parse_str(&org_id).is_ok() {
             return Some(org_id);
         }
@@ -426,52 +427,48 @@ fn get_org_id(request: &Request) -> Option<String> {
     None
 }
 
-impl<'a, 'r> FromRequest<'a, 'r> for OrgHeaders {
+#[rocket::async_trait]
+impl<'r> FromRequest<'r> for OrgHeaders {
     type Error = &'static str;
 
-    fn from_request(request: &'a Request<'r>) -> Outcome<Self, Self::Error> {
-        match request.guard::<Headers>() {
-            Outcome::Forward(_) => Outcome::Forward(()),
-            Outcome::Failure(f) => Outcome::Failure(f),
-            Outcome::Success(headers) => {
-                match get_org_id(request) {
-                    Some(org_id) => {
-                        let conn = match request.guard::<DbConn>() {
-                            Outcome::Success(conn) => conn,
-                            _ => err_handler!("Error getting DB"),
-                        };
-
-                        let user = headers.user;
-                        let org_user = match UserOrganization::find_by_user_and_org(&user.uuid, &org_id, &conn) {
-                            Some(user) => {
-                                if user.status == UserOrgStatus::Confirmed as i32 {
-                                    user
-                                } else {
-                                    err_handler!("The current user isn't confirmed member of the organization")
-                                }
-                            }
-                            None => err_handler!("The current user isn't member of the organization"),
-                        };
-
-                        Outcome::Success(Self {
-                            host: headers.host,
-                            device: headers.device,
-                            user,
-                            org_user_type: {
-                                if let Some(org_usr_type) = UserOrgType::from_i32(org_user.atype) {
-                                    org_usr_type
-                                } else {
-                                    // This should only happen if the DB is corrupted
-                                    err_handler!("Unknown user type in the database")
-                                }
-                            },
-                            org_user,
-                            org_id,
-                        })
+    async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
+        let headers = try_outcome!(Headers::from_request(request).await);
+        match get_org_id(request) {
+            Some(org_id) => {
+                let conn = match DbConn::from_request(request).await {
+                    Outcome::Success(conn) => conn,
+                    _ => err_handler!("Error getting DB"),
+                };
+
+                let user = headers.user;
+                let org_user = match UserOrganization::find_by_user_and_org(&user.uuid, &org_id, &conn) {
+                    Some(user) => {
+                        if user.status == UserOrgStatus::Confirmed as i32 {
+                            user
+                        } else {
+                            err_handler!("The current user isn't confirmed member of the organization")
+                        }
                     }
-                    _ => err_handler!("Error getting the organization id"),
-                }
+                    None => err_handler!("The current user isn't member of the organization"),
+                };
+
+                Outcome::Success(Self {
+                    host: headers.host,
+                    device: headers.device,
+                    user,
+                    org_user_type: {
+                        if let Some(org_usr_type) = UserOrgType::from_i32(org_user.atype) {
+                            org_usr_type
+                        } else {
+                            // This should only happen if the DB is corrupted
+                            err_handler!("Unknown user type in the database")
+                        }
+                    },
+                    org_user,
+                    org_id,
+                })
             }
+            _ => err_handler!("Error getting the organization id"),
         }
     }
 }
@@ -483,25 +480,21 @@ pub struct AdminHeaders {
     pub org_user_type: UserOrgType,
 }
 
-impl<'a, 'r> FromRequest<'a, 'r> for AdminHeaders {
+#[rocket::async_trait]
+impl<'r> FromRequest<'r> for AdminHeaders {
     type Error = &'static str;
 
-    fn from_request(request: &'a Request<'r>) -> Outcome<Self, Self::Error> {
-        match request.guard::<OrgHeaders>() {
-            Outcome::Forward(_) => Outcome::Forward(()),
-            Outcome::Failure(f) => Outcome::Failure(f),
-            Outcome::Success(headers) => {
-                if headers.org_user_type >= UserOrgType::Admin {
-                    Outcome::Success(Self {
-                        host: headers.host,
-                        device: headers.device,
-                        user: headers.user,
-                        org_user_type: headers.org_user_type,
-                    })
-                } else {
-                    err_handler!("You need to be Admin or Owner to call this endpoint")
-                }
-            }
+    async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
+        let headers = try_outcome!(OrgHeaders::from_request(request).await);
+        if headers.org_user_type >= UserOrgType::Admin {
+            Outcome::Success(Self {
+                host: headers.host,
+                device: headers.device,
+                user: headers.user,
+                org_user_type: headers.org_user_type,
+            })
+        } else {
+            err_handler!("You need to be Admin or Owner to call this endpoint")
         }
     }
 }
@@ -520,13 +513,13 @@ impl From<AdminHeaders> for Headers {
 // but there could be cases where it is a query value.
 // First check the path, if this is not a valid uuid, try the query values.
 fn get_col_id(request: &Request) -> Option<String> {
-    if let Some(Ok(col_id)) = request.get_param::<String>(3) {
+    if let Some(Ok(col_id)) = request.param::<String>(3) {
         if uuid::Uuid::parse_str(&col_id).is_ok() {
             return Some(col_id);
         }
     }
 
-    if let Some(Ok(col_id)) = request.get_query_value::<String>("collectionId") {
+    if let Some(Ok(col_id)) = request.query_value::<String>("collectionId") {
         if uuid::Uuid::parse_str(&col_id).is_ok() {
             return Some(col_id);
         }
@@ -545,46 +538,38 @@ pub struct ManagerHeaders {
     pub org_user_type: UserOrgType,
 }
 
-impl<'a, 'r> FromRequest<'a, 'r> for ManagerHeaders {
+#[rocket::async_trait]
+impl<'r> FromRequest<'r> for ManagerHeaders {
     type Error = &'static str;
 
-    fn from_request(request: &'a Request<'r>) -> Outcome<Self, Self::Error> {
-        match request.guard::<OrgHeaders>() {
-            Outcome::Forward(_) => Outcome::Forward(()),
-            Outcome::Failure(f) => Outcome::Failure(f),
-            Outcome::Success(headers) => {
-                if headers.org_user_type >= UserOrgType::Manager {
-                    match get_col_id(request) {
-                        Some(col_id) => {
-                            let conn = match request.guard::<DbConn>() {
-                                Outcome::Success(conn) => conn,
-                                _ => err_handler!("Error getting DB"),
-                            };
-
-                            if !headers.org_user.has_full_access() {
-                                match CollectionUser::find_by_collection_and_user(
-                                    &col_id,
-                                    &headers.org_user.user_uuid,
-                                    &conn,
-                                ) {
-                                    Some(_) => (),
-                                    None => err_handler!("The current user isn't a manager for this collection"),
-                                }
-                            }
+    async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
+        let headers = try_outcome!(OrgHeaders::from_request(request).await);
+        if headers.org_user_type >= UserOrgType::Manager {
+            match get_col_id(request) {
+                Some(col_id) => {
+                    let conn = match DbConn::from_request(request).await {
+                        Outcome::Success(conn) => conn,
+                        _ => err_handler!("Error getting DB"),
+                    };
+
+                    if !headers.org_user.has_full_access() {
+                        match CollectionUser::find_by_collection_and_user(&col_id, &headers.org_user.user_uuid, &conn) {
+                            Some(_) => (),
+                            None => err_handler!("The current user isn't a manager for this collection"),
                         }
-                        _ => err_handler!("Error getting the collection id"),
                     }
-
-                    Outcome::Success(Self {
-                        host: headers.host,
-                        device: headers.device,
-                        user: headers.user,
-                        org_user_type: headers.org_user_type,
-                    })
-                } else {
-                    err_handler!("You need to be a Manager, Admin or Owner to call this endpoint")
                 }
+                _ => err_handler!("Error getting the collection id"),
             }
+
+            Outcome::Success(Self {
+                host: headers.host,
+                device: headers.device,
+                user: headers.user,
+                org_user_type: headers.org_user_type,
+            })
+        } else {
+            err_handler!("You need to be a Manager, Admin or Owner to call this endpoint")
         }
     }
 }
@@ -608,25 +593,21 @@ pub struct ManagerHeadersLoose {
     pub org_user_type: UserOrgType,
 }
 
-impl<'a, 'r> FromRequest<'a, 'r> for ManagerHeadersLoose {
+#[rocket::async_trait]
+impl<'r> FromRequest<'r> for ManagerHeadersLoose {
     type Error = &'static str;
 
-    fn from_request(request: &'a Request<'r>) -> Outcome<Self, Self::Error> {
-        match request.guard::<OrgHeaders>() {
-            Outcome::Forward(_) => Outcome::Forward(()),
-            Outcome::Failure(f) => Outcome::Failure(f),
-            Outcome::Success(headers) => {
-                if headers.org_user_type >= UserOrgType::Manager {
-                    Outcome::Success(Self {
-                        host: headers.host,
-                        device: headers.device,
-                        user: headers.user,
-                        org_user_type: headers.org_user_type,
-                    })
-                } else {
-                    err_handler!("You need to be a Manager, Admin or Owner to call this endpoint")
-                }
-            }
+    async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
+        let headers = try_outcome!(OrgHeaders::from_request(request).await);
+        if headers.org_user_type >= UserOrgType::Manager {
+            Outcome::Success(Self {
+                host: headers.host,
+                device: headers.device,
+                user: headers.user,
+                org_user_type: headers.org_user_type,
+            })
+        } else {
+            err_handler!("You need to be a Manager, Admin or Owner to call this endpoint")
         }
     }
 }
@@ -647,24 +628,20 @@ pub struct OwnerHeaders {
     pub user: User,
 }
 
-impl<'a, 'r> FromRequest<'a, 'r> for OwnerHeaders {
+#[rocket::async_trait]
+impl<'r> FromRequest<'r> for OwnerHeaders {
     type Error = &'static str;
 
-    fn from_request(request: &'a Request<'r>) -> Outcome<Self, Self::Error> {
-        match request.guard::<OrgHeaders>() {
-            Outcome::Forward(_) => Outcome::Forward(()),
-            Outcome::Failure(f) => Outcome::Failure(f),
-            Outcome::Success(headers) => {
-                if headers.org_user_type == UserOrgType::Owner {
-                    Outcome::Success(Self {
-                        host: headers.host,
-                        device: headers.device,
-                        user: headers.user,
-                    })
-                } else {
-                    err_handler!("You need to be Owner to call this endpoint")
-                }
-            }
+    async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
+        let headers = try_outcome!(OrgHeaders::from_request(request).await);
+        if headers.org_user_type == UserOrgType::Owner {
+            Outcome::Success(Self {
+                host: headers.host,
+                device: headers.device,
+                user: headers.user,
+            })
+        } else {
+            err_handler!("You need to be Owner to call this endpoint")
         }
     }
 }
@@ -678,10 +655,11 @@ pub struct ClientIp {
     pub ip: IpAddr,
 }
 
-impl<'a, 'r> FromRequest<'a, 'r> for ClientIp {
+#[rocket::async_trait]
+impl<'r> FromRequest<'r> for ClientIp {
     type Error = ();
 
-    fn from_request(req: &'a Request<'r>) -> Outcome<Self, Self::Error> {
+    async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
         let ip = if CONFIG._ip_header_enabled() {
             req.headers().get_one(&CONFIG.ip_header()).and_then(|ip| {
                 match ip.find(',') {

+ 31 - 0
src/config.rs

@@ -36,6 +36,9 @@ macro_rules! make_config {
         pub struct Config { inner: RwLock<Inner> }
 
         struct Inner {
+            rocket_shutdown_handle: Option<rocket::Shutdown>,
+            ws_shutdown_handle: Option<ws::Sender>,
+
             templates: Handlebars<'static>,
             config: ConfigItems,
 
@@ -332,6 +335,8 @@ make_config! {
         attachments_folder:     String, false,  auto,   |c| format!("{}/{}", c.data_folder, "attachments");
         /// Sends folder
         sends_folder:           String, false,  auto,   |c| format!("{}/{}", c.data_folder, "sends");
+        /// Temp folder |> Used for storing temporary file uploads
+        tmp_folder:           String, false,  auto,   |c| format!("{}/{}", c.data_folder, "tmp");
         /// Templates folder
         templates_folder:       String, false,  auto,   |c| format!("{}/{}", c.data_folder, "templates");
         /// Session JWT key
@@ -509,6 +514,9 @@ make_config! {
         /// Max database connection retries |> Number of times to retry the database connection during startup, with 1 second between each retry, set to 0 to retry indefinitely
         db_connection_retries:  u32,    false,  def,    15;
 
+        /// Timeout when aquiring database connection
+        database_timeout:     u64,    false,  def,    30;
+
         /// Database connection pool size
         database_max_conns:     u32,    false,  def,    10;
 
@@ -743,6 +751,8 @@ impl Config {
 
         Ok(Config {
             inner: RwLock::new(Inner {
+                rocket_shutdown_handle: None,
+                ws_shutdown_handle: None,
                 templates: load_templates(&config.templates_folder),
                 config,
                 _env,
@@ -907,6 +917,27 @@ impl Config {
             hb.render(name, data).map_err(Into::into)
         }
     }
+
+    pub fn set_rocket_shutdown_handle(&self, handle: rocket::Shutdown) {
+        self.inner.write().unwrap().rocket_shutdown_handle = Some(handle);
+    }
+
+    pub fn set_ws_shutdown_handle(&self, handle: ws::Sender) {
+        self.inner.write().unwrap().ws_shutdown_handle = Some(handle);
+    }
+
+    pub fn shutdown(&self) {
+        if let Ok(c) = self.inner.read() {
+            if let Some(handle) = c.ws_shutdown_handle.clone() {
+                handle.shutdown().ok();
+            }
+            // Wait a bit before stopping the web server
+            std::thread::sleep(std::time::Duration::from_secs(1));
+            if let Some(handle) = c.rocket_shutdown_handle.clone() {
+                handle.notify();
+            }
+        }
+    }
 }
 
 use handlebars::{Context, Handlebars, Helper, HelperResult, Output, RenderContext, RenderError, Renderable};

+ 185 - 46
src/db/mod.rs

@@ -1,8 +1,16 @@
+use std::{sync::Arc, time::Duration};
+
 use diesel::r2d2::{ConnectionManager, Pool, PooledConnection};
 use rocket::{
     http::Status,
+    outcome::IntoOutcome,
     request::{FromRequest, Outcome},
-    Request, State,
+    Request,
+};
+
+use tokio::{
+    sync::{Mutex, OwnedSemaphorePermit, Semaphore},
+    time::timeout,
 };
 
 use crate::{
@@ -22,6 +30,23 @@ pub mod __mysql_schema;
 #[path = "schemas/postgresql/schema.rs"]
 pub mod __postgresql_schema;
 
+// There changes are based on Rocket 0.5-rc wrapper of Diesel: https://github.com/SergioBenitez/Rocket/blob/v0.5-rc/contrib/sync_db_pools
+
+// A wrapper around spawn_blocking that propagates panics to the calling code.
+pub async fn run_blocking<F, R>(job: F) -> R
+where
+    F: FnOnce() -> R + Send + 'static,
+    R: Send + 'static,
+{
+    match tokio::task::spawn_blocking(job).await {
+        Ok(ret) => ret,
+        Err(e) => match e.try_into_panic() {
+            Ok(panic) => std::panic::resume_unwind(panic),
+            Err(_) => unreachable!("spawn_blocking tasks are never cancelled"),
+        },
+    }
+}
+
 // This is used to generate the main DbConn and DbPool enums, which contain one variant for each database supported
 macro_rules! generate_connections {
     ( $( $name:ident: $ty:ty ),+ ) => {
@@ -29,12 +54,53 @@ macro_rules! generate_connections {
         #[derive(Eq, PartialEq)]
         pub enum DbConnType { $( $name, )+ }
 
+        pub struct DbConn {
+            conn: Arc<Mutex<Option<DbConnInner>>>,
+            permit: Option<OwnedSemaphorePermit>,
+        }
+
         #[allow(non_camel_case_types)]
-        pub enum DbConn { $( #[cfg($name)] $name(PooledConnection<ConnectionManager< $ty >>), )+ }
+        pub enum DbConnInner { $( #[cfg($name)] $name(PooledConnection<ConnectionManager< $ty >>), )+ }
+
+
+        #[derive(Clone)]
+        pub struct DbPool {
+            // This is an 'Option' so that we can drop the pool in a 'spawn_blocking'.
+            pool: Option<DbPoolInner>,
+            semaphore: Arc<Semaphore>
+        }
 
         #[allow(non_camel_case_types)]
         #[derive(Clone)]
-        pub enum DbPool { $( #[cfg($name)] $name(Pool<ConnectionManager< $ty >>), )+ }
+        pub enum DbPoolInner { $( #[cfg($name)] $name(Pool<ConnectionManager< $ty >>), )+ }
+
+        impl Drop for DbConn {
+            fn drop(&mut self) {
+                let conn = self.conn.clone();
+                let permit = self.permit.take();
+
+                // Since connection can't be on the stack in an async fn during an
+                // await, we have to spawn a new blocking-safe thread...
+                tokio::task::spawn_blocking(move || {
+                    // And then re-enter the runtime to wait on the async mutex, but in a blocking fashion.
+                    let mut conn = tokio::runtime::Handle::current().block_on(conn.lock_owned());
+
+                    if let Some(conn) = conn.take() {
+                        drop(conn);
+                    }
+
+                    // Drop permit after the connection is dropped
+                    drop(permit);
+                });
+            }
+        }
+
+        impl Drop for DbPool {
+            fn drop(&mut self) {
+                let pool = self.pool.take();
+                tokio::task::spawn_blocking(move || drop(pool));
+            }
+        }
 
         impl DbPool {
             // For the given database URL, guess it's type, run migrations create pool and return it
@@ -50,9 +116,13 @@ macro_rules! generate_connections {
                             let manager = ConnectionManager::new(&url);
                             let pool = Pool::builder()
                                 .max_size(CONFIG.database_max_conns())
+                                .connection_timeout(Duration::from_secs(CONFIG.database_timeout()))
                                 .build(manager)
                                 .map_res("Failed to create pool")?;
-                            return Ok(Self::$name(pool));
+                            return Ok(DbPool {
+                                pool: Some(DbPoolInner::$name(pool)),
+                                semaphore: Arc::new(Semaphore::new(CONFIG.database_max_conns() as usize)),
+                            });
                         }
                         #[cfg(not($name))]
                         #[allow(unreachable_code)]
@@ -61,10 +131,26 @@ macro_rules! generate_connections {
                 )+ }
             }
             // Get a connection from the pool
-            pub fn get(&self) -> Result<DbConn, Error> {
-                match self {  $(
+            pub async fn get(&self) -> Result<DbConn, Error> {
+                let duration = Duration::from_secs(CONFIG.database_timeout());
+                let permit = match timeout(duration, self.semaphore.clone().acquire_owned()).await {
+                    Ok(p) => p.expect("Semaphore should be open"),
+                    Err(_) => {
+                        err!("Timeout waiting for database connection");
+                    }
+                };
+
+                match self.pool.as_ref().expect("DbPool.pool should always be Some()") {  $(
                     #[cfg($name)]
-                    Self::$name(p) => Ok(DbConn::$name(p.get().map_res("Error retrieving connection from pool")?)),
+                    DbPoolInner::$name(p) => {
+                        let pool = p.clone();
+                        let c = run_blocking(move || pool.get_timeout(duration)).await.map_res("Error retrieving connection from pool")?;
+
+                        return Ok(DbConn {
+                            conn: Arc::new(Mutex::new(Some(DbConnInner::$name(c)))),
+                            permit: Some(permit)
+                        });
+                    },
                 )+ }
             }
         }
@@ -113,42 +199,95 @@ macro_rules! db_run {
         db_run! { $conn: sqlite, mysql, postgresql $body }
     };
 
-    // Different code for each db
-    ( $conn:ident: $( $($db:ident),+ $body:block )+ ) => {{
-        #[allow(unused)] use diesel::prelude::*;
-        match $conn {
-            $($(
-                #[cfg($db)]
-                crate::db::DbConn::$db(ref $conn) => {
-                    paste::paste! {
-                        #[allow(unused)] use crate::db::[<__ $db _schema>]::{self as schema, *};
-                        #[allow(unused)] use [<__ $db _model>]::*;
-                        #[allow(unused)] use crate::db::FromDb;
-                    }
-                    $body
-                },
-            )+)+
-        }}
-    };
-
-    // Same for all dbs
     ( @raw $conn:ident: $body:block ) => {
         db_run! { @raw $conn: sqlite, mysql, postgresql $body }
     };
 
     // Different code for each db
-    ( @raw $conn:ident: $( $($db:ident),+ $body:block )+ ) => {
+    ( $conn:ident: $( $($db:ident),+ $body:block )+ ) => {{
         #[allow(unused)] use diesel::prelude::*;
-        #[allow(unused_variables)]
-        match $conn {
-            $($(
-                #[cfg($db)]
-                crate::db::DbConn::$db(ref $conn) => {
-                    $body
-                },
-            )+)+
-        }
-    };
+
+        // It is important that this inner Arc<Mutex<>> (or the OwnedMutexGuard
+        // derived from it) never be a variable on the stack at an await point,
+        // where Drop might be called at any time. This causes (synchronous)
+        // Drop to be called from asynchronous code, which some database
+        // wrappers do not or can not handle.
+        let conn = $conn.conn.clone();
+
+        // Since connection can't be on the stack in an async fn during an
+        // await, we have to spawn a new blocking-safe thread...
+        /*
+        run_blocking(move || {
+            // And then re-enter the runtime to wait on the async mutex, but in
+            // a blocking fashion.
+            let mut conn = tokio::runtime::Handle::current().block_on(conn.lock_owned());
+            let conn = conn.as_mut().expect("internal invariant broken: self.connection is Some");
+            */
+            let mut __conn_mutex = conn.try_lock_owned().unwrap();
+            let conn = __conn_mutex.as_mut().unwrap();
+            match conn {
+                    $($(
+                    #[cfg($db)]
+                    crate::db::DbConnInner::$db($conn) => {
+                        paste::paste! {
+                            #[allow(unused)] use crate::db::[<__ $db _schema>]::{self as schema, *};
+                            #[allow(unused)] use [<__ $db _model>]::*;
+                            #[allow(unused)] use crate::db::FromDb;
+                        }
+
+                        /*
+                        // Since connection can't be on the stack in an async fn during an
+                        // await, we have to spawn a new blocking-safe thread...
+                        run_blocking(move || {
+                            // And then re-enter the runtime to wait on the async mutex, but in
+                            // a blocking fashion.
+                            let mut conn = tokio::runtime::Handle::current().block_on(async {
+                                conn.lock_owned().await
+                            });
+
+                            let conn = conn.as_mut().expect("internal invariant broken: self.connection is Some");
+                            f(conn)
+                        }).await;*/
+
+                        $body
+                    },
+                )+)+
+            }
+        // }).await
+    }};
+
+    ( @raw $conn:ident: $( $($db:ident),+ $body:block )+ ) => {{
+        #[allow(unused)] use diesel::prelude::*;
+
+        // It is important that this inner Arc<Mutex<>> (or the OwnedMutexGuard
+        // derived from it) never be a variable on the stack at an await point,
+        // where Drop might be called at any time. This causes (synchronous)
+        // Drop to be called from asynchronous code, which some database
+        // wrappers do not or can not handle.
+        let conn = $conn.conn.clone();
+
+        // Since connection can't be on the stack in an async fn during an
+        // await, we have to spawn a new blocking-safe thread...
+        run_blocking(move || {
+            // And then re-enter the runtime to wait on the async mutex, but in
+            // a blocking fashion.
+            let mut conn = tokio::runtime::Handle::current().block_on(conn.lock_owned());
+            match conn.as_mut().expect("internal invariant broken: self.connection is Some") {
+                    $($(
+                    #[cfg($db)]
+                    crate::db::DbConnInner::$db($conn) => {
+                        paste::paste! {
+                            #[allow(unused)] use crate::db::[<__ $db _schema>]::{self as schema, *};
+                            // @RAW: #[allow(unused)] use [<__ $db _model>]::*;
+                            #[allow(unused)] use crate::db::FromDb;
+                        }
+
+                        $body
+                    },
+                )+)+
+            }
+        }).await
+    }};
 }
 
 pub trait FromDb {
@@ -227,9 +366,10 @@ pub mod models;
 
 /// Creates a back-up of the sqlite database
 /// MySQL/MariaDB and PostgreSQL are not supported.
-pub fn backup_database(conn: &DbConn) -> Result<(), Error> {
+pub async fn backup_database(conn: &DbConn) -> Result<(), Error> {
     db_run! {@raw conn:
         postgresql, mysql {
+            let _ = conn;
             err!("PostgreSQL and MySQL/MariaDB do not support this backup feature");
         }
         sqlite {
@@ -244,7 +384,7 @@ pub fn backup_database(conn: &DbConn) -> Result<(), Error> {
 }
 
 /// Get the SQL Server version
-pub fn get_sql_server_version(conn: &DbConn) -> String {
+pub async fn get_sql_server_version(conn: &DbConn) -> String {
     db_run! {@raw conn:
         postgresql, mysql {
             no_arg_sql_function!(version, diesel::sql_types::Text);
@@ -260,15 +400,14 @@ pub fn get_sql_server_version(conn: &DbConn) -> String {
 /// Attempts to retrieve a single connection from the managed database pool. If
 /// no pool is currently managed, fails with an `InternalServerError` status. If
 /// no connections are available, fails with a `ServiceUnavailable` status.
-impl<'a, 'r> FromRequest<'a, 'r> for DbConn {
+#[rocket::async_trait]
+impl<'r> FromRequest<'r> for DbConn {
     type Error = ();
 
-    fn from_request(request: &'a Request<'r>) -> Outcome<DbConn, ()> {
-        // https://github.com/SergioBenitez/Rocket/commit/e3c1a4ad3ab9b840482ec6de4200d30df43e357c
-        let pool = try_outcome!(request.guard::<State<DbPool>>());
-        match pool.get() {
-            Ok(conn) => Outcome::Success(conn),
-            Err(_) => Outcome::Failure((Status::ServiceUnavailable, ())),
+    async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
+        match request.rocket().state::<DbPool>() {
+            Some(p) => p.get().await.map_err(|_| ()).into_outcome(Status::ServiceUnavailable),
+            None => Outcome::Failure((Status::InternalServerError, ())),
         }
     }
 }

+ 6 - 4
src/error.rs

@@ -45,6 +45,7 @@ use lettre::transport::smtp::Error as SmtpErr;
 use openssl::error::ErrorStack as SSLErr;
 use regex::Error as RegexErr;
 use reqwest::Error as ReqErr;
+use rocket::error::Error as RocketErr;
 use serde_json::{Error as SerdeErr, Value};
 use std::io::Error as IoErr;
 use std::time::SystemTimeError as TimeErr;
@@ -84,6 +85,7 @@ make_error! {
     Address(AddrErr):  _has_source, _api_error,
     Smtp(SmtpErr):     _has_source, _api_error,
     OpenSSL(SSLErr):   _has_source, _api_error,
+    Rocket(RocketErr): _has_source, _api_error,
 
     DieselCon(DieselConErr): _has_source, _api_error,
     DieselMig(DieselMigErr): _has_source, _api_error,
@@ -193,8 +195,8 @@ use rocket::http::{ContentType, Status};
 use rocket::request::Request;
 use rocket::response::{self, Responder, Response};
 
-impl<'r> Responder<'r> for Error {
-    fn respond_to(self, _: &Request) -> response::Result<'r> {
+impl<'r> Responder<'r, 'static> for Error {
+    fn respond_to(self, _: &Request) -> response::Result<'static> {
         match self.error {
             ErrorKind::Empty(_) => {}  // Don't print the error in this situation
             ErrorKind::Simple(_) => {} // Don't print the error in this situation
@@ -202,8 +204,8 @@ impl<'r> Responder<'r> for Error {
         };
 
         let code = Status::from_code(self.error_code).unwrap_or(Status::BadRequest);
-
-        Response::build().status(code).header(ContentType::JSON).sized_body(Cursor::new(format!("{}", self))).ok()
+        let body = self.to_string();
+        Response::build().status(code).header(ContentType::JSON).sized_body(Some(body.len()), Cursor::new(body)).ok()
     }
 }
 

+ 65 - 33
src/main.rs

@@ -20,8 +20,15 @@ extern crate diesel;
 #[macro_use]
 extern crate diesel_migrations;
 
-use job_scheduler::{Job, JobScheduler};
-use std::{fs::create_dir_all, panic, path::Path, process::exit, str::FromStr, thread, time::Duration};
+use std::{
+    fs::{canonicalize, create_dir_all},
+    panic,
+    path::Path,
+    process::exit,
+    str::FromStr,
+    thread,
+    time::Duration,
+};
 
 #[macro_use]
 mod error;
@@ -37,9 +44,11 @@ mod util;
 
 pub use config::CONFIG;
 pub use error::{Error, MapResult};
+use rocket::data::{Limits, ToByteUnit};
 pub use util::is_running_in_docker;
 
-fn main() {
+#[rocket::main]
+async fn main() -> Result<(), Error> {
     parse_args();
     launch_info();
 
@@ -56,13 +65,16 @@ fn main() {
     });
     check_web_vault();
 
-    create_icon_cache_folder();
+    create_dir(&CONFIG.icon_cache_folder(), "icon cache");
+    create_dir(&CONFIG.tmp_folder(), "tmp folder");
+    create_dir(&CONFIG.sends_folder(), "sends folder");
+    create_dir(&CONFIG.attachments_folder(), "attachments folder");
 
     let pool = create_db_pool();
-    schedule_jobs(pool.clone());
-    crate::db::models::TwoFactor::migrate_u2f_to_webauthn(&pool.get().unwrap()).unwrap();
+    schedule_jobs(pool.clone()).await;
+    crate::db::models::TwoFactor::migrate_u2f_to_webauthn(&pool.get().await.unwrap()).unwrap();
 
-    launch_rocket(pool, extra_debug); // Blocks until program termination.
+    launch_rocket(pool, extra_debug).await // Blocks until program termination.
 }
 
 const HELP: &str = "\
@@ -127,10 +139,12 @@ fn init_logging(level: log::LevelFilter) -> Result<(), fern::InitError> {
         .level_for("hyper::server", log::LevelFilter::Warn)
         // Silence rocket logs
         .level_for("_", log::LevelFilter::Off)
-        .level_for("launch", log::LevelFilter::Off)
-        .level_for("launch_", log::LevelFilter::Off)
-        .level_for("rocket::rocket", log::LevelFilter::Off)
-        .level_for("rocket::fairing", log::LevelFilter::Off)
+        .level_for("rocket::launch", log::LevelFilter::Error)
+        .level_for("rocket::launch_", log::LevelFilter::Error)
+        .level_for("rocket::rocket", log::LevelFilter::Warn)
+        .level_for("rocket::server", log::LevelFilter::Warn)
+        .level_for("rocket::fairing::fairings", log::LevelFilter::Warn)
+        .level_for("rocket::shield::shield", log::LevelFilter::Warn)
         // Never show html5ever and hyper::proto logs, too noisy
         .level_for("html5ever", log::LevelFilter::Off)
         .level_for("hyper::proto", log::LevelFilter::Off)
@@ -243,10 +257,6 @@ fn create_dir(path: &str, description: &str) {
     create_dir_all(path).expect(&err_msg);
 }
 
-fn create_icon_cache_folder() {
-    create_dir(&CONFIG.icon_cache_folder(), "icon cache");
-}
-
 fn check_data_folder() {
     let data_folder = &CONFIG.data_folder();
     let path = Path::new(data_folder);
@@ -314,51 +324,73 @@ fn create_db_pool() -> db::DbPool {
     }
 }
 
-fn launch_rocket(pool: db::DbPool, extra_debug: bool) {
+async fn launch_rocket(pool: db::DbPool, extra_debug: bool) -> Result<(), Error> {
     let basepath = &CONFIG.domain_path();
 
+    let mut config = rocket::Config::from(rocket::Config::figment());
+    config.address = std::net::IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED); // TODO: Allow this to be changed, keep ROCKET_ADDRESS for compat
+    config.temp_dir = canonicalize(CONFIG.tmp_folder()).unwrap().into();
+    config.limits = Limits::new() //
+        .limit("json", 10.megabytes())
+        .limit("data-form", 150.megabytes())
+        .limit("file", 150.megabytes());
+
     // If adding more paths here, consider also adding them to
     // crate::utils::LOGGED_ROUTES to make sure they appear in the log
-    let result = rocket::ignite()
-        .mount(&[basepath, "/"].concat(), api::web_routes())
-        .mount(&[basepath, "/api"].concat(), api::core_routes())
-        .mount(&[basepath, "/admin"].concat(), api::admin_routes())
-        .mount(&[basepath, "/identity"].concat(), api::identity_routes())
-        .mount(&[basepath, "/icons"].concat(), api::icons_routes())
-        .mount(&[basepath, "/notifications"].concat(), api::notifications_routes())
+    let instance = rocket::custom(config)
+        .mount([basepath, "/"].concat(), api::web_routes())
+        .mount([basepath, "/api"].concat(), api::core_routes())
+        .mount([basepath, "/admin"].concat(), api::admin_routes())
+        .mount([basepath, "/identity"].concat(), api::identity_routes())
+        .mount([basepath, "/icons"].concat(), api::icons_routes())
+        .mount([basepath, "/notifications"].concat(), api::notifications_routes())
         .manage(pool)
         .manage(api::start_notification_server())
         .attach(util::AppHeaders())
         .attach(util::Cors())
         .attach(util::BetterLogging(extra_debug))
-        .launch();
+        .ignite()
+        .await?;
+
+    CONFIG.set_rocket_shutdown_handle(instance.shutdown());
+    ctrlc::set_handler(move || {
+        info!("Exiting vaultwarden!");
+        CONFIG.shutdown();
+    })
+    .expect("Error setting Ctrl-C handler");
 
-    // Launch and print error if there is one
-    // The launch will restore the original logging level
-    error!("Launch error {:#?}", result);
+    instance.launch().await?;
+
+    info!("Vaultwarden process exited!");
+    Ok(())
 }
 
-fn schedule_jobs(pool: db::DbPool) {
+async fn schedule_jobs(pool: db::DbPool) {
     if CONFIG.job_poll_interval_ms() == 0 {
         info!("Job scheduler disabled.");
         return;
     }
+
+    let runtime = tokio::runtime::Handle::current();
+
     thread::Builder::new()
         .name("job-scheduler".to_string())
         .spawn(move || {
+            use job_scheduler::{Job, JobScheduler};
+
             let mut sched = JobScheduler::new();
 
             // Purge sends that are past their deletion date.
             if !CONFIG.send_purge_schedule().is_empty() {
                 sched.add(Job::new(CONFIG.send_purge_schedule().parse().unwrap(), || {
-                    api::purge_sends(pool.clone());
+                    runtime.spawn(api::purge_sends(pool.clone()));
                 }));
             }
 
             // Purge trashed items that are old enough to be auto-deleted.
             if !CONFIG.trash_purge_schedule().is_empty() {
                 sched.add(Job::new(CONFIG.trash_purge_schedule().parse().unwrap(), || {
-                    api::purge_trashed_ciphers(pool.clone());
+                    runtime.spawn(api::purge_trashed_ciphers(pool.clone()));
                 }));
             }
 
@@ -366,7 +398,7 @@ fn schedule_jobs(pool: db::DbPool) {
             // indicates that a user's master password has been compromised.
             if !CONFIG.incomplete_2fa_schedule().is_empty() {
                 sched.add(Job::new(CONFIG.incomplete_2fa_schedule().parse().unwrap(), || {
-                    api::send_incomplete_2fa_notifications(pool.clone());
+                    runtime.spawn(api::send_incomplete_2fa_notifications(pool.clone()));
                 }));
             }
 
@@ -375,7 +407,7 @@ fn schedule_jobs(pool: db::DbPool) {
             // sending reminders for requests that are about to be granted anyway.
             if !CONFIG.emergency_request_timeout_schedule().is_empty() {
                 sched.add(Job::new(CONFIG.emergency_request_timeout_schedule().parse().unwrap(), || {
-                    api::emergency_request_timeout_job(pool.clone());
+                    runtime.spawn(api::emergency_request_timeout_job(pool.clone()));
                 }));
             }
 
@@ -383,7 +415,7 @@ fn schedule_jobs(pool: db::DbPool) {
             // emergency access requests.
             if !CONFIG.emergency_notification_reminder_schedule().is_empty() {
                 sched.add(Job::new(CONFIG.emergency_notification_reminder_schedule().parse().unwrap(), || {
-                    api::emergency_notification_reminder_job(pool.clone());
+                    runtime.spawn(api::emergency_notification_reminder_job(pool.clone()));
                 }));
             }
 

+ 33 - 37
src/util.rs

@@ -5,10 +5,10 @@ use std::io::Cursor;
 
 use rocket::{
     fairing::{Fairing, Info, Kind},
-    http::{ContentType, Header, HeaderMap, Method, RawStr, Status},
+    http::{ContentType, Header, HeaderMap, Method, Status},
     request::FromParam,
     response::{self, Responder},
-    Data, Request, Response, Rocket,
+    Data, Orbit, Request, Response, Rocket,
 };
 
 use std::thread::sleep;
@@ -18,6 +18,7 @@ use crate::CONFIG;
 
 pub struct AppHeaders();
 
+#[rocket::async_trait]
 impl Fairing for AppHeaders {
     fn info(&self) -> Info {
         Info {
@@ -26,7 +27,7 @@ impl Fairing for AppHeaders {
         }
     }
 
-    fn on_response(&self, _req: &Request, res: &mut Response) {
+    async fn on_response<'r>(&self, _req: &'r Request<'_>, res: &mut Response<'r>) {
         res.set_raw_header("Permissions-Policy", "accelerometer=(), ambient-light-sensor=(), autoplay=(), camera=(), encrypted-media=(), fullscreen=(), geolocation=(), gyroscope=(), magnetometer=(), microphone=(), midi=(), payment=(), picture-in-picture=(), sync-xhr=(self \"https://haveibeenpwned.com\" \"https://2fa.directory\"), usb=(), vr=()");
         res.set_raw_header("Referrer-Policy", "same-origin");
         res.set_raw_header("X-Frame-Options", "SAMEORIGIN");
@@ -72,6 +73,7 @@ impl Cors {
     }
 }
 
+#[rocket::async_trait]
 impl Fairing for Cors {
     fn info(&self) -> Info {
         Info {
@@ -80,7 +82,7 @@ impl Fairing for Cors {
         }
     }
 
-    fn on_response(&self, request: &Request, response: &mut Response) {
+    async fn on_response<'r>(&self, request: &'r Request<'_>, response: &mut Response<'r>) {
         let req_headers = request.headers();
 
         if let Some(origin) = Cors::get_allowed_origin(req_headers) {
@@ -97,7 +99,7 @@ impl Fairing for Cors {
             response.set_header(Header::new("Access-Control-Allow-Credentials", "true"));
             response.set_status(Status::Ok);
             response.set_header(ContentType::Plain);
-            response.set_sized_body(Cursor::new(""));
+            response.set_sized_body(Some(0), Cursor::new(""));
         }
     }
 }
@@ -134,25 +136,21 @@ impl<R> Cached<R> {
     }
 }
 
-impl<'r, R: Responder<'r>> Responder<'r> for Cached<R> {
-    fn respond_to(self, req: &Request) -> response::Result<'r> {
+impl<'r, R: 'r + Responder<'r, 'static> + Send> Responder<'r, 'static> for Cached<R> {
+    fn respond_to(self, request: &'r Request<'_>) -> response::Result<'static> {
+        let mut res = self.response.respond_to(request)?;
+
         let cache_control_header = if self.is_immutable {
             format!("public, immutable, max-age={}", self.ttl)
         } else {
             format!("public, max-age={}", self.ttl)
         };
+        res.set_raw_header("Cache-Control", cache_control_header);
 
         let time_now = chrono::Local::now();
-
-        match self.response.respond_to(req) {
-            Ok(mut res) => {
-                res.set_raw_header("Cache-Control", cache_control_header);
-                let expiry_time = time_now + chrono::Duration::seconds(self.ttl.try_into().unwrap());
-                res.set_raw_header("Expires", format_datetime_http(&expiry_time));
-                Ok(res)
-            }
-            e @ Err(_) => e,
-        }
+        let expiry_time = time_now + chrono::Duration::seconds(self.ttl.try_into().unwrap());
+        res.set_raw_header("Expires", format_datetime_http(&expiry_time));
+        Ok(res)
     }
 }
 
@@ -175,11 +173,9 @@ impl<'r> FromParam<'r> for SafeString {
     type Error = ();
 
     #[inline(always)]
-    fn from_param(param: &'r RawStr) -> Result<Self, Self::Error> {
-        let s = param.percent_decode().map(|cow| cow.into_owned()).map_err(|_| ())?;
-
-        if s.chars().all(|c| matches!(c, 'a'..='z' | 'A'..='Z' |'0'..='9' | '-')) {
-            Ok(SafeString(s))
+    fn from_param(param: &'r str) -> Result<Self, Self::Error> {
+        if param.chars().all(|c| matches!(c, 'a'..='z' | 'A'..='Z' |'0'..='9' | '-')) {
+            Ok(SafeString(param.to_string()))
         } else {
             Err(())
         }
@@ -193,15 +189,16 @@ const LOGGED_ROUTES: [&str; 6] =
 
 // Boolean is extra debug, when true, we ignore the whitelist above and also print the mounts
 pub struct BetterLogging(pub bool);
+#[rocket::async_trait]
 impl Fairing for BetterLogging {
     fn info(&self) -> Info {
         Info {
             name: "Better Logging",
-            kind: Kind::Launch | Kind::Request | Kind::Response,
+            kind: Kind::Liftoff | Kind::Request | Kind::Response,
         }
     }
 
-    fn on_launch(&self, rocket: &Rocket) {
+    async fn on_liftoff(&self, rocket: &Rocket<Orbit>) {
         if self.0 {
             info!(target: "routes", "Routes loaded:");
             let mut routes: Vec<_> = rocket.routes().collect();
@@ -225,34 +222,36 @@ impl Fairing for BetterLogging {
         info!(target: "start", "Rocket has launched from {}", addr);
     }
 
-    fn on_request(&self, request: &mut Request<'_>, _data: &Data) {
+    async fn on_request(&self, request: &mut Request<'_>, _data: &mut Data<'_>) {
         let method = request.method();
         if !self.0 && method == Method::Options {
             return;
         }
         let uri = request.uri();
         let uri_path = uri.path();
-        let uri_subpath = uri_path.strip_prefix(&CONFIG.domain_path()).unwrap_or(uri_path);
+        let uri_path_str = uri_path.url_decode_lossy();
+        let uri_subpath = uri_path_str.strip_prefix(&CONFIG.domain_path()).unwrap_or(&uri_path_str);
         if self.0 || LOGGED_ROUTES.iter().any(|r| uri_subpath.starts_with(r)) {
             match uri.query() {
-                Some(q) => info!(target: "request", "{} {}?{}", method, uri_path, &q[..q.len().min(30)]),
-                None => info!(target: "request", "{} {}", method, uri_path),
+                Some(q) => info!(target: "request", "{} {}?{}", method, uri_path_str, &q[..q.len().min(30)]),
+                None => info!(target: "request", "{} {}", method, uri_path_str),
             };
         }
     }
 
-    fn on_response(&self, request: &Request, response: &mut Response) {
+    async fn on_response<'r>(&self, request: &'r Request<'_>, response: &mut Response<'r>) {
         if !self.0 && request.method() == Method::Options {
             return;
         }
         let uri_path = request.uri().path();
-        let uri_subpath = uri_path.strip_prefix(&CONFIG.domain_path()).unwrap_or(uri_path);
+        let uri_path_str = uri_path.url_decode_lossy();
+        let uri_subpath = uri_path_str.strip_prefix(&CONFIG.domain_path()).unwrap_or(&uri_path_str);
         if self.0 || LOGGED_ROUTES.iter().any(|r| uri_subpath.starts_with(r)) {
             let status = response.status();
-            if let Some(route) = request.route() {
-                info!(target: "response", "{} => {} {}", route, status.code, status.reason)
+            if let Some(ref route) = request.route() {
+                info!(target: "response", "{} => {}", route, status)
             } else {
-                info!(target: "response", "{} {}", status.code, status.reason)
+                info!(target: "response", "{}", status)
             }
         }
     }
@@ -614,10 +613,7 @@ where
     }
 }
 
-use reqwest::{
-    blocking::{Client, ClientBuilder},
-    header,
-};
+use reqwest::{header, Client, ClientBuilder};
 
 pub fn get_reqwest_client() -> Client {
     get_reqwest_client_builder().build().expect("Failed to build client")

Some files were not shown because too many files changed in this diff