瀏覽代碼

Make the admin URL redirect try to use the referrer first, and use /admin when DOMAIN is not configured and the referrer check doesn't work, to allow users without DOMAIN configured to use the admin page correctly

Daniel García 5 年之前
父節點
當前提交
6a972e4b19
共有 1 個文件被更改,包括 52 次插入24 次删除
  1. 52 24
      src/api/admin.rs

+ 52 - 24
src/api/admin.rs

@@ -5,7 +5,7 @@ use std::process::Command;
 
 use rocket::{
     http::{Cookie, Cookies, SameSite},
-    request::{self, FlashMessage, Form, FromRequest, Request, Outcome},
+    request::{self, FlashMessage, Form, FromRequest, Outcome, Request},
     response::{content::Html, Flash, Redirect},
     Route,
 };
@@ -66,12 +66,35 @@ fn admin_path() -> String {
     format!("{}{}", CONFIG.domain_path(), ADMIN_PATH)
 }
 
+struct Referer(Option<String>);
+
+impl<'a, 'r> FromRequest<'a, 'r> for Referer {
+    type Error = ();
+
+    fn from_request(request: &'a Request<'r>) -> request::Outcome<Self, Self::Error> {
+        Outcome::Success(Referer(request.headers().get_one("Referer").map(str::to_string)))
+    }
+}
+
 /// Used for `Location` response headers, which must specify an absolute URI
 /// (see https://tools.ietf.org/html/rfc2616#section-14.30).
-fn admin_url() -> String {
-    // Don't use CONFIG.domain() directly, since the user may want to keep a
-    // trailing slash there, particularly when running under a subpath.
-    format!("{}{}{}", CONFIG.domain_origin(), CONFIG.domain_path(), ADMIN_PATH)
+fn admin_url(referer: Referer) -> String {
+    // If we get a referer use that to make it work when, DOMAIN is not set
+    if let Some(mut referer) = referer.0 {
+        if let Some(start_index) = referer.find(ADMIN_PATH) {
+            referer.truncate(start_index + ADMIN_PATH.len());
+            return referer;
+        }
+    }
+
+    if CONFIG.domain_set() {
+        // Don't use CONFIG.domain() directly, since the user may want to keep a
+        // trailing slash there, particularly when running under a subpath.
+        format!("{}{}{}", CONFIG.domain_origin(), CONFIG.domain_path(), ADMIN_PATH)
+    } else {
+        // Last case, when no referer or domain set, technically invalid but better than nothing
+        ADMIN_PATH.to_string()
+    }
 }
 
 #[get("/", rank = 2)]
@@ -91,14 +114,19 @@ struct LoginForm {
 }
 
 #[post("/", data = "<data>")]
-fn post_admin_login(data: Form<LoginForm>, mut cookies: Cookies, ip: ClientIp) -> Result<Redirect, Flash<Redirect>> {
+fn post_admin_login(
+    data: Form<LoginForm>,
+    mut cookies: Cookies,
+    ip: ClientIp,
+    referer: Referer,
+) -> Result<Redirect, Flash<Redirect>> {
     let data = data.into_inner();
 
     // If the token is invalid, redirect to login page
     if !_validate_token(&data.token) {
         error!("Invalid admin token. IP: {}", ip.ip);
         Err(Flash::error(
-            Redirect::to(admin_url()),
+            Redirect::to(admin_url(referer)),
             "Invalid admin token, please try again.",
         ))
     } else {
@@ -114,7 +142,7 @@ fn post_admin_login(data: Form<LoginForm>, mut cookies: Cookies, ip: ClientIp) -
             .finish();
 
         cookies.add(cookie);
-        Ok(Redirect::to(admin_url()))
+        Ok(Redirect::to(admin_url(referer)))
     }
 }
 
@@ -243,9 +271,9 @@ fn test_smtp(data: Json<InviteData>, _token: AdminToken) -> EmptyResult {
 }
 
 #[get("/logout")]
-fn logout(mut cookies: Cookies) -> Result<Redirect, ()> {
+fn logout(mut cookies: Cookies, referer: Referer) -> Result<Redirect, ()> {
     cookies.remove(Cookie::named(COOKIE_NAME));
-    Ok(Redirect::to(admin_url()))
+    Ok(Redirect::to(admin_url(referer)))
 }
 
 #[get("/users")]
@@ -260,12 +288,12 @@ fn get_users_json(_token: AdminToken, conn: DbConn) -> JsonResult {
 fn users_overview(_token: AdminToken, conn: DbConn) -> ApiResult<Html<String>> {
     let users = User::get_all(&conn);
     let users_json: Vec<Value> = users.iter()
-    .map(|u| {
-        let mut usr = u.to_json(&conn);
-        usr["cipher_count"] = json!(Cipher::count_owned_by_user(&u.uuid, &conn));
-        usr["attachment_count"] = json!(Attachment::count_by_user(&u.uuid, &conn));
-        usr["attachment_size"] = json!(get_display_size(Attachment::size_by_user(&u.uuid, &conn) as i32));
-        usr
+        .map(|u| {
+            let mut usr = u.to_json(&conn);
+            usr["cipher_count"] = json!(Cipher::count_owned_by_user(&u.uuid, &conn));
+            usr["attachment_count"] = json!(Attachment::count_by_user(&u.uuid, &conn));
+            usr["attachment_size"] = json!(get_display_size(Attachment::size_by_user(&u.uuid, &conn) as i32));
+            usr
     }).collect();
 
     let text = AdminTemplateData::users(users_json).render()?;
@@ -304,12 +332,12 @@ fn update_revision_users(_token: AdminToken, conn: DbConn) -> EmptyResult {
 fn organizations_overview(_token: AdminToken, conn: DbConn) -> ApiResult<Html<String>> {
     let organizations = Organization::get_all(&conn);
     let organizations_json: Vec<Value> = organizations.iter().map(|o| {
-        let mut org = o.to_json();
-        org["user_count"] = json!(UserOrganization::count_by_org(&o.uuid, &conn));
-        org["cipher_count"] = json!(Cipher::count_by_org(&o.uuid, &conn));
-        org["attachment_count"] = json!(Attachment::count_by_org(&o.uuid, &conn));
-        org["attachment_size"] = json!(get_display_size(Attachment::size_by_org(&o.uuid, &conn) as i32));
-        org
+            let mut org = o.to_json();
+            org["user_count"] = json!(UserOrganization::count_by_org(&o.uuid, &conn));
+            org["cipher_count"] = json!(Cipher::count_by_org(&o.uuid, &conn));
+            org["attachment_count"] = json!(Attachment::count_by_org(&o.uuid, &conn));
+            org["attachment_size"] = json!(get_display_size(Attachment::size_by_org(&o.uuid, &conn) as i32));
+            org
     }).collect();
 
     let text = AdminTemplateData::organizations(organizations_json).render()?;
@@ -373,7 +401,7 @@ fn diagnostics(_token: AdminToken, _conn: DbConn) -> ApiResult<Html<String>> {
                 Ok(mut c) => {
                     c.sha.truncate(8);
                     c.sha
-                },
+            },
                 _ => "-".to_string()
             },
             match get_github_api::<GitRelease>("https://api.github.com/repos/dani-garcia/bw_web_builds/releases/latest") {
@@ -384,7 +412,7 @@ fn diagnostics(_token: AdminToken, _conn: DbConn) -> ApiResult<Html<String>> {
     } else {
         ("-".to_string(), "-".to_string(), "-".to_string())
     };
-    
+
     // Run the date check as the last item right before filling the json.
     // This should ensure that the time difference between the browser and the server is as minimal as possible.
     let dt = Utc::now();