util.rs 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728
  1. //
  2. // Web Headers and caching
  3. //
  4. use std::io::{Cursor, ErrorKind};
  5. use rocket::{
  6. fairing::{Fairing, Info, Kind},
  7. http::{ContentType, Header, HeaderMap, Method, Status},
  8. request::FromParam,
  9. response::{self, Responder},
  10. Data, Orbit, Request, Response, Rocket,
  11. };
  12. use tokio::{
  13. runtime::Handle,
  14. time::{sleep, Duration},
  15. };
  16. use crate::CONFIG;
  17. pub struct AppHeaders();
  18. #[rocket::async_trait]
  19. impl Fairing for AppHeaders {
  20. fn info(&self) -> Info {
  21. Info {
  22. name: "Application Headers",
  23. kind: Kind::Response,
  24. }
  25. }
  26. async fn on_response<'r>(&self, req: &'r Request<'_>, res: &mut Response<'r>) {
  27. res.set_raw_header("Permissions-Policy", "accelerometer=(), ambient-light-sensor=(), autoplay=(), battery=(), camera=(), display-capture=(), document-domain=(), encrypted-media=(), execution-while-not-rendered=(), execution-while-out-of-viewport=(), fullscreen=(), geolocation=(), gyroscope=(), keyboard-map=(), magnetometer=(), microphone=(), midi=(), payment=(), picture-in-picture=(), screen-wake-lock=(), sync-xhr=(), usb=(), web-share=(), xr-spatial-tracking=()");
  28. res.set_raw_header("Referrer-Policy", "same-origin");
  29. res.set_raw_header("X-Content-Type-Options", "nosniff");
  30. // Obsolete in modern browsers, unsafe (XS-Leak), and largely replaced by CSP
  31. res.set_raw_header("X-XSS-Protection", "0");
  32. let req_uri_path = req.uri().path();
  33. // Do not send the Content-Security-Policy (CSP) Header and X-Frame-Options for the *-connector.html files.
  34. // This can cause issues when some MFA requests needs to open a popup or page within the clients like WebAuthn, or Duo.
  35. // This is the same behaviour as upstream Bitwarden.
  36. if !req_uri_path.ends_with("connector.html") {
  37. // # Frame Ancestors:
  38. // Chrome Web Store: https://chrome.google.com/webstore/detail/bitwarden-free-password-m/nngceckbapebfimnlniiiahkandclblb
  39. // Edge Add-ons: https://microsoftedge.microsoft.com/addons/detail/bitwarden-free-password/jbkfoedolllekgbhcbcoahefnbanhhlh?hl=en-US
  40. // Firefox Browser Add-ons: https://addons.mozilla.org/en-US/firefox/addon/bitwarden-password-manager/
  41. // # img/child/frame src:
  42. // Have I Been Pwned and Gravator to allow those calls to work.
  43. // # Connect src:
  44. // Leaked Passwords check: api.pwnedpasswords.com
  45. // 2FA/MFA Site check: api.2fa.directory
  46. // # Mail Relay: https://bitwarden.com/blog/add-privacy-and-security-using-email-aliases-with-bitwarden/
  47. // app.simplelogin.io, app.anonaddy.com, api.fastmail.com, quack.duckduckgo.com
  48. let csp = format!(
  49. "default-src 'self'; \
  50. base-uri 'self'; \
  51. form-action 'self'; \
  52. object-src 'self' blob:; \
  53. script-src 'self' 'wasm-unsafe-eval'; \
  54. style-src 'self' 'unsafe-inline'; \
  55. child-src 'self' https://*.duosecurity.com https://*.duofederal.com; \
  56. frame-src 'self' https://*.duosecurity.com https://*.duofederal.com; \
  57. frame-ancestors 'self' \
  58. chrome-extension://nngceckbapebfimnlniiiahkandclblb \
  59. chrome-extension://jbkfoedolllekgbhcbcoahefnbanhhlh \
  60. moz-extension://* \
  61. {allowed_iframe_ancestors}; \
  62. img-src 'self' data: \
  63. https://haveibeenpwned.com \
  64. https://www.gravatar.com \
  65. {icon_service_csp}; \
  66. connect-src 'self' \
  67. https://api.pwnedpasswords.com \
  68. https://api.2fa.directory \
  69. https://app.simplelogin.io/api/ \
  70. https://app.anonaddy.com/api/ \
  71. https://api.fastmail.com/ \
  72. ;\
  73. ",
  74. icon_service_csp = CONFIG._icon_service_csp(),
  75. allowed_iframe_ancestors = CONFIG.allowed_iframe_ancestors()
  76. );
  77. res.set_raw_header("Content-Security-Policy", csp);
  78. res.set_raw_header("X-Frame-Options", "SAMEORIGIN");
  79. } else {
  80. // It looks like this header get's set somewhere else also, make sure this is not sent for these files, it will cause MFA issues.
  81. res.remove_header("X-Frame-Options");
  82. }
  83. // Disable cache unless otherwise specified
  84. if !res.headers().contains("cache-control") {
  85. res.set_raw_header("Cache-Control", "no-cache, no-store, max-age=0");
  86. }
  87. }
  88. }
  89. pub struct Cors();
  90. impl Cors {
  91. fn get_header(headers: &HeaderMap<'_>, name: &str) -> String {
  92. match headers.get_one(name) {
  93. Some(h) => h.to_string(),
  94. _ => String::new(),
  95. }
  96. }
  97. // Check a request's `Origin` header against the list of allowed origins.
  98. // If a match exists, return it. Otherwise, return None.
  99. fn get_allowed_origin(headers: &HeaderMap<'_>) -> Option<String> {
  100. let origin = Cors::get_header(headers, "Origin");
  101. let domain_origin = CONFIG.domain_origin();
  102. let safari_extension_origin = "file://";
  103. if origin == domain_origin || origin == safari_extension_origin {
  104. Some(origin)
  105. } else {
  106. None
  107. }
  108. }
  109. }
  110. #[rocket::async_trait]
  111. impl Fairing for Cors {
  112. fn info(&self) -> Info {
  113. Info {
  114. name: "Cors",
  115. kind: Kind::Response,
  116. }
  117. }
  118. async fn on_response<'r>(&self, request: &'r Request<'_>, response: &mut Response<'r>) {
  119. let req_headers = request.headers();
  120. if let Some(origin) = Cors::get_allowed_origin(req_headers) {
  121. response.set_header(Header::new("Access-Control-Allow-Origin", origin));
  122. }
  123. // Preflight request
  124. if request.method() == Method::Options {
  125. let req_allow_headers = Cors::get_header(req_headers, "Access-Control-Request-Headers");
  126. let req_allow_method = Cors::get_header(req_headers, "Access-Control-Request-Method");
  127. response.set_header(Header::new("Access-Control-Allow-Methods", req_allow_method));
  128. response.set_header(Header::new("Access-Control-Allow-Headers", req_allow_headers));
  129. response.set_header(Header::new("Access-Control-Allow-Credentials", "true"));
  130. response.set_status(Status::Ok);
  131. response.set_header(ContentType::Plain);
  132. response.set_sized_body(Some(0), Cursor::new(""));
  133. }
  134. }
  135. }
  136. pub struct Cached<R> {
  137. response: R,
  138. is_immutable: bool,
  139. ttl: u64,
  140. }
  141. impl<R> Cached<R> {
  142. pub fn long(response: R, is_immutable: bool) -> Cached<R> {
  143. Self {
  144. response,
  145. is_immutable,
  146. ttl: 604800, // 7 days
  147. }
  148. }
  149. pub fn short(response: R, is_immutable: bool) -> Cached<R> {
  150. Self {
  151. response,
  152. is_immutable,
  153. ttl: 600, // 10 minutes
  154. }
  155. }
  156. pub fn ttl(response: R, ttl: u64, is_immutable: bool) -> Cached<R> {
  157. Self {
  158. response,
  159. is_immutable,
  160. ttl,
  161. }
  162. }
  163. }
  164. impl<'r, R: 'r + Responder<'r, 'static> + Send> Responder<'r, 'static> for Cached<R> {
  165. fn respond_to(self, request: &'r Request<'_>) -> response::Result<'static> {
  166. let mut res = self.response.respond_to(request)?;
  167. let cache_control_header = if self.is_immutable {
  168. format!("public, immutable, max-age={}", self.ttl)
  169. } else {
  170. format!("public, max-age={}", self.ttl)
  171. };
  172. res.set_raw_header("Cache-Control", cache_control_header);
  173. let time_now = chrono::Local::now();
  174. let expiry_time = time_now + chrono::Duration::seconds(self.ttl.try_into().unwrap());
  175. res.set_raw_header("Expires", format_datetime_http(&expiry_time));
  176. Ok(res)
  177. }
  178. }
  179. pub struct SafeString(String);
  180. impl std::fmt::Display for SafeString {
  181. fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  182. self.0.fmt(f)
  183. }
  184. }
  185. impl AsRef<Path> for SafeString {
  186. #[inline]
  187. fn as_ref(&self) -> &Path {
  188. Path::new(&self.0)
  189. }
  190. }
  191. impl<'r> FromParam<'r> for SafeString {
  192. type Error = ();
  193. #[inline(always)]
  194. fn from_param(param: &'r str) -> Result<Self, Self::Error> {
  195. if param.chars().all(|c| matches!(c, 'a'..='z' | 'A'..='Z' |'0'..='9' | '-')) {
  196. Ok(SafeString(param.to_string()))
  197. } else {
  198. Err(())
  199. }
  200. }
  201. }
  202. // Log all the routes from the main paths list, and the attachments endpoint
  203. // Effectively ignores, any static file route, and the alive endpoint
  204. const LOGGED_ROUTES: [&str; 7] = ["/api", "/admin", "/identity", "/icons", "/attachments", "/events", "/notifications"];
  205. // Boolean is extra debug, when true, we ignore the whitelist above and also print the mounts
  206. pub struct BetterLogging(pub bool);
  207. #[rocket::async_trait]
  208. impl Fairing for BetterLogging {
  209. fn info(&self) -> Info {
  210. Info {
  211. name: "Better Logging",
  212. kind: Kind::Liftoff | Kind::Request | Kind::Response,
  213. }
  214. }
  215. async fn on_liftoff(&self, rocket: &Rocket<Orbit>) {
  216. if self.0 {
  217. info!(target: "routes", "Routes loaded:");
  218. let mut routes: Vec<_> = rocket.routes().collect();
  219. routes.sort_by_key(|r| r.uri.path().as_str());
  220. for route in routes {
  221. if route.rank < 0 {
  222. info!(target: "routes", "{:<6} {}", route.method, route.uri);
  223. } else {
  224. info!(target: "routes", "{:<6} {} [{}]", route.method, route.uri, route.rank);
  225. }
  226. }
  227. }
  228. let config = rocket.config();
  229. let scheme = if config.tls_enabled() {
  230. "https"
  231. } else {
  232. "http"
  233. };
  234. let addr = format!("{}://{}:{}", &scheme, &config.address, &config.port);
  235. info!(target: "start", "Rocket has launched from {}", addr);
  236. }
  237. async fn on_request(&self, request: &mut Request<'_>, _data: &mut Data<'_>) {
  238. let method = request.method();
  239. if !self.0 && method == Method::Options {
  240. return;
  241. }
  242. let uri = request.uri();
  243. let uri_path = uri.path();
  244. let uri_path_str = uri_path.url_decode_lossy();
  245. let uri_subpath = uri_path_str.strip_prefix(&CONFIG.domain_path()).unwrap_or(&uri_path_str);
  246. if self.0 || LOGGED_ROUTES.iter().any(|r| uri_subpath.starts_with(r)) {
  247. match uri.query() {
  248. Some(q) => info!(target: "request", "{} {}?{}", method, uri_path_str, &q[..q.len().min(30)]),
  249. None => info!(target: "request", "{} {}", method, uri_path_str),
  250. };
  251. }
  252. }
  253. async fn on_response<'r>(&self, request: &'r Request<'_>, response: &mut Response<'r>) {
  254. if !self.0 && request.method() == Method::Options {
  255. return;
  256. }
  257. let uri_path = request.uri().path();
  258. let uri_path_str = uri_path.url_decode_lossy();
  259. let uri_subpath = uri_path_str.strip_prefix(&CONFIG.domain_path()).unwrap_or(&uri_path_str);
  260. if self.0 || LOGGED_ROUTES.iter().any(|r| uri_subpath.starts_with(r)) {
  261. let status = response.status();
  262. if let Some(ref route) = request.route() {
  263. info!(target: "response", "{} => {}", route, status)
  264. } else {
  265. info!(target: "response", "{}", status)
  266. }
  267. }
  268. }
  269. }
  270. //
  271. // File handling
  272. //
  273. use std::{
  274. fs::{self, File},
  275. io::Result as IOResult,
  276. path::Path,
  277. };
  278. pub fn file_exists(path: &str) -> bool {
  279. Path::new(path).exists()
  280. }
  281. pub fn write_file(path: &str, content: &[u8]) -> Result<(), crate::error::Error> {
  282. use std::io::Write;
  283. let mut f = match File::create(path) {
  284. Ok(file) => file,
  285. Err(e) => {
  286. if e.kind() == ErrorKind::PermissionDenied {
  287. error!("Can't create '{}': Permission denied", path);
  288. }
  289. return Err(From::from(e));
  290. }
  291. };
  292. f.write_all(content)?;
  293. f.flush()?;
  294. Ok(())
  295. }
  296. pub fn delete_file(path: &str) -> IOResult<()> {
  297. let res = fs::remove_file(path);
  298. if let Some(parent) = Path::new(path).parent() {
  299. // If the directory isn't empty, this returns an error, which we ignore
  300. // We only want to delete the folder if it's empty
  301. fs::remove_dir(parent).ok();
  302. }
  303. res
  304. }
  305. pub fn get_display_size(size: i32) -> String {
  306. const UNITS: [&str; 6] = ["bytes", "KB", "MB", "GB", "TB", "PB"];
  307. let mut size: f64 = size.into();
  308. let mut unit_counter = 0;
  309. loop {
  310. if size > 1024. {
  311. size /= 1024.;
  312. unit_counter += 1;
  313. } else {
  314. break;
  315. }
  316. }
  317. format!("{:.2} {}", size, UNITS[unit_counter])
  318. }
  319. pub fn get_uuid() -> String {
  320. uuid::Uuid::new_v4().to_string()
  321. }
  322. //
  323. // String util methods
  324. //
  325. use std::str::FromStr;
  326. #[inline]
  327. pub fn upcase_first(s: &str) -> String {
  328. let mut c = s.chars();
  329. match c.next() {
  330. None => String::new(),
  331. Some(f) => f.to_uppercase().collect::<String>() + c.as_str(),
  332. }
  333. }
  334. #[inline]
  335. pub fn lcase_first(s: &str) -> String {
  336. let mut c = s.chars();
  337. match c.next() {
  338. None => String::new(),
  339. Some(f) => f.to_lowercase().collect::<String>() + c.as_str(),
  340. }
  341. }
  342. pub fn try_parse_string<S, T>(string: Option<S>) -> Option<T>
  343. where
  344. S: AsRef<str>,
  345. T: FromStr,
  346. {
  347. if let Some(Ok(value)) = string.map(|s| s.as_ref().parse::<T>()) {
  348. Some(value)
  349. } else {
  350. None
  351. }
  352. }
  353. //
  354. // Env methods
  355. //
  356. use std::env;
  357. pub fn get_env_str_value(key: &str) -> Option<String> {
  358. let key_file = format!("{key}_FILE");
  359. let value_from_env = env::var(key);
  360. let value_file = env::var(&key_file);
  361. match (value_from_env, value_file) {
  362. (Ok(_), Ok(_)) => panic!("You should not define both {key} and {key_file}!"),
  363. (Ok(v_env), Err(_)) => Some(v_env),
  364. (Err(_), Ok(v_file)) => match fs::read_to_string(v_file) {
  365. Ok(content) => Some(content.trim().to_string()),
  366. Err(e) => panic!("Failed to load {key}: {e:?}"),
  367. },
  368. _ => None,
  369. }
  370. }
  371. pub fn get_env<V>(key: &str) -> Option<V>
  372. where
  373. V: FromStr,
  374. {
  375. try_parse_string(get_env_str_value(key))
  376. }
  377. pub fn get_env_bool(key: &str) -> Option<bool> {
  378. const TRUE_VALUES: &[&str] = &["true", "t", "yes", "y", "1"];
  379. const FALSE_VALUES: &[&str] = &["false", "f", "no", "n", "0"];
  380. match get_env_str_value(key) {
  381. Some(val) if TRUE_VALUES.contains(&val.to_lowercase().as_ref()) => Some(true),
  382. Some(val) if FALSE_VALUES.contains(&val.to_lowercase().as_ref()) => Some(false),
  383. _ => None,
  384. }
  385. }
  386. //
  387. // Date util methods
  388. //
  389. use chrono::{DateTime, Local, NaiveDateTime, TimeZone};
  390. // Format used by Bitwarden API
  391. const DATETIME_FORMAT: &str = "%Y-%m-%dT%H:%M:%S%.6fZ";
  392. /// Formats a UTC-offset `NaiveDateTime` in the format used by Bitwarden API
  393. /// responses with "date" fields (`CreationDate`, `RevisionDate`, etc.).
  394. pub fn format_date(dt: &NaiveDateTime) -> String {
  395. dt.format(DATETIME_FORMAT).to_string()
  396. }
  397. /// Formats a `DateTime<Local>` using the specified format string.
  398. ///
  399. /// For a `DateTime<Local>`, the `%Z` specifier normally formats as the
  400. /// time zone's UTC offset (e.g., `+00:00`). In this function, if the
  401. /// `TZ` environment variable is set, then `%Z` instead formats as the
  402. /// abbreviation for that time zone (e.g., `UTC`).
  403. pub fn format_datetime_local(dt: &DateTime<Local>, fmt: &str) -> String {
  404. // Try parsing the `TZ` environment variable to enable formatting `%Z` as
  405. // a time zone abbreviation.
  406. if let Ok(tz) = env::var("TZ") {
  407. if let Ok(tz) = tz.parse::<chrono_tz::Tz>() {
  408. return dt.with_timezone(&tz).format(fmt).to_string();
  409. }
  410. }
  411. // Otherwise, fall back to formatting `%Z` as a UTC offset.
  412. dt.format(fmt).to_string()
  413. }
  414. /// Formats a UTC-offset `NaiveDateTime` as a datetime in the local time zone.
  415. ///
  416. /// This function basically converts the `NaiveDateTime` to a `DateTime<Local>`,
  417. /// and then calls [format_datetime_local](crate::util::format_datetime_local).
  418. pub fn format_naive_datetime_local(dt: &NaiveDateTime, fmt: &str) -> String {
  419. format_datetime_local(&Local.from_utc_datetime(dt), fmt)
  420. }
  421. /// Formats a `DateTime<Local>` as required for HTTP
  422. ///
  423. /// https://httpwg.org/specs/rfc7231.html#http.date
  424. pub fn format_datetime_http(dt: &DateTime<Local>) -> String {
  425. let expiry_time: chrono::DateTime<chrono::Utc> = chrono::DateTime::from_utc(dt.naive_utc(), chrono::Utc);
  426. // HACK: HTTP expects the date to always be GMT (UTC) rather than giving an
  427. // offset (which would always be 0 in UTC anyway)
  428. expiry_time.to_rfc2822().replace("+0000", "GMT")
  429. }
  430. pub fn parse_date(date: &str) -> NaiveDateTime {
  431. NaiveDateTime::parse_from_str(date, DATETIME_FORMAT).unwrap()
  432. }
  433. //
  434. // Deployment environment methods
  435. //
  436. /// Returns true if the program is running in Docker or Podman.
  437. pub fn is_running_in_docker() -> bool {
  438. Path::new("/.dockerenv").exists() || Path::new("/run/.containerenv").exists()
  439. }
  440. /// Simple check to determine on which docker base image vaultwarden is running.
  441. /// We build images based upon Debian or Alpine, so these we check here.
  442. pub fn docker_base_image() -> &'static str {
  443. if Path::new("/etc/debian_version").exists() {
  444. "Debian"
  445. } else if Path::new("/etc/alpine-release").exists() {
  446. "Alpine"
  447. } else {
  448. "Unknown"
  449. }
  450. }
  451. //
  452. // Deserialization methods
  453. //
  454. use std::fmt;
  455. use serde::de::{self, DeserializeOwned, Deserializer, MapAccess, SeqAccess, Visitor};
  456. use serde_json::{self, Value};
  457. pub type JsonMap = serde_json::Map<String, Value>;
  458. #[derive(Serialize, Deserialize)]
  459. pub struct UpCase<T: DeserializeOwned> {
  460. #[serde(deserialize_with = "upcase_deserialize")]
  461. #[serde(flatten)]
  462. pub data: T,
  463. }
  464. // https://github.com/serde-rs/serde/issues/586
  465. pub fn upcase_deserialize<'de, T, D>(deserializer: D) -> Result<T, D::Error>
  466. where
  467. T: DeserializeOwned,
  468. D: Deserializer<'de>,
  469. {
  470. let d = deserializer.deserialize_any(UpCaseVisitor)?;
  471. T::deserialize(d).map_err(de::Error::custom)
  472. }
  473. struct UpCaseVisitor;
  474. impl<'de> Visitor<'de> for UpCaseVisitor {
  475. type Value = Value;
  476. fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
  477. formatter.write_str("an object or an array")
  478. }
  479. fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
  480. where
  481. A: MapAccess<'de>,
  482. {
  483. let mut result_map = JsonMap::new();
  484. while let Some((key, value)) = map.next_entry()? {
  485. result_map.insert(upcase_first(key), upcase_value(value));
  486. }
  487. Ok(Value::Object(result_map))
  488. }
  489. fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
  490. where
  491. A: SeqAccess<'de>,
  492. {
  493. let mut result_seq = Vec::<Value>::new();
  494. while let Some(value) = seq.next_element()? {
  495. result_seq.push(upcase_value(value));
  496. }
  497. Ok(Value::Array(result_seq))
  498. }
  499. }
  500. fn upcase_value(value: Value) -> Value {
  501. if let Value::Object(map) = value {
  502. let mut new_value = Value::Object(serde_json::Map::new());
  503. for (key, val) in map.into_iter() {
  504. let processed_key = _process_key(&key);
  505. new_value[processed_key] = upcase_value(val);
  506. }
  507. new_value
  508. } else if let Value::Array(array) = value {
  509. // Initialize array with null values
  510. let mut new_value = Value::Array(vec![Value::Null; array.len()]);
  511. for (index, val) in array.into_iter().enumerate() {
  512. new_value[index] = upcase_value(val);
  513. }
  514. new_value
  515. } else {
  516. value
  517. }
  518. }
  519. // Inner function to handle some speciale case for the 'ssn' key.
  520. // This key is part of the Identity Cipher (Social Security Number)
  521. fn _process_key(key: &str) -> String {
  522. match key.to_lowercase().as_ref() {
  523. "ssn" => "SSN".into(),
  524. _ => self::upcase_first(key),
  525. }
  526. }
  527. //
  528. // Retry methods
  529. //
  530. pub fn retry<F, T, E>(mut func: F, max_tries: u32) -> Result<T, E>
  531. where
  532. F: FnMut() -> Result<T, E>,
  533. {
  534. let mut tries = 0;
  535. loop {
  536. match func() {
  537. ok @ Ok(_) => return ok,
  538. err @ Err(_) => {
  539. tries += 1;
  540. if tries >= max_tries {
  541. return err;
  542. }
  543. Handle::current().block_on(sleep(Duration::from_millis(500)));
  544. }
  545. }
  546. }
  547. }
  548. pub async fn retry_db<F, T, E>(mut func: F, max_tries: u32) -> Result<T, E>
  549. where
  550. F: FnMut() -> Result<T, E>,
  551. E: std::error::Error,
  552. {
  553. let mut tries = 0;
  554. loop {
  555. match func() {
  556. ok @ Ok(_) => return ok,
  557. Err(e) => {
  558. tries += 1;
  559. if tries >= max_tries && max_tries > 0 {
  560. return Err(e);
  561. }
  562. warn!("Can't connect to database, retrying: {:?}", e);
  563. sleep(Duration::from_millis(1_000)).await;
  564. }
  565. }
  566. }
  567. }
  568. use reqwest::{header, Client, ClientBuilder};
  569. pub fn get_reqwest_client() -> Client {
  570. match get_reqwest_client_builder().build() {
  571. Ok(client) => client,
  572. Err(e) => {
  573. error!("Possible trust-dns error, trying with trust-dns disabled: '{e}'");
  574. get_reqwest_client_builder().trust_dns(false).build().expect("Failed to build client")
  575. }
  576. }
  577. }
  578. pub fn get_reqwest_client_builder() -> ClientBuilder {
  579. let mut headers = header::HeaderMap::new();
  580. headers.insert(header::USER_AGENT, header::HeaderValue::from_static("Vaultwarden"));
  581. Client::builder().default_headers(headers).timeout(Duration::from_secs(10))
  582. }
  583. pub fn convert_json_key_lcase_first(src_json: Value) -> Value {
  584. match src_json {
  585. Value::Array(elm) => {
  586. let mut new_array: Vec<Value> = Vec::with_capacity(elm.len());
  587. for obj in elm {
  588. new_array.push(convert_json_key_lcase_first(obj));
  589. }
  590. Value::Array(new_array)
  591. }
  592. Value::Object(obj) => {
  593. let mut json_map = JsonMap::new();
  594. for (key, value) in obj.iter() {
  595. match (key, value) {
  596. (key, Value::Object(elm)) => {
  597. let inner_value = convert_json_key_lcase_first(Value::Object(elm.clone()));
  598. json_map.insert(lcase_first(key), inner_value);
  599. }
  600. (key, Value::Array(elm)) => {
  601. let mut inner_array: Vec<Value> = Vec::with_capacity(elm.len());
  602. for inner_obj in elm {
  603. inner_array.push(convert_json_key_lcase_first(inner_obj.clone()));
  604. }
  605. json_map.insert(lcase_first(key), Value::Array(inner_array));
  606. }
  607. (key, value) => {
  608. json_map.insert(lcase_first(key), value.clone());
  609. }
  610. }
  611. }
  612. Value::Object(json_map)
  613. }
  614. value => value,
  615. }
  616. }