Browse Source

Implemented better errors for JWT

Daniel García 7 years ago
parent
commit
2bb0b15e04
4 changed files with 37 additions and 41 deletions
  1. 2 4
      src/api/core/accounts.rs
  2. 1 4
      src/api/core/organizations.rs
  3. 18 25
      src/auth.rs
  4. 16 8
      src/error.rs

+ 2 - 4
src/api/core/accounts.rs

@@ -76,10 +76,8 @@ fn register(data: JsonUpcase<RegisterData>, conn: DbConn) -> EmptyResult {
                         Some(token) => token,
                         None => err!("No valid invite token")
                     };
-                    let claims: InviteJWTClaims = match decode_invite_jwt(&token) {
-                        Ok(claims) => claims,
-                        Err(msg) => err!("Invalid claim: {:#?}", msg),
-                    };
+                    
+                    let claims: InviteJWTClaims = decode_invite_jwt(&token)?;
                     if &claims.email == &data.Email {
                         user
                     } else {

+ 1 - 4
src/api/core/organizations.rs

@@ -522,10 +522,7 @@ fn accept_invite(_org_id: String, _org_user_id: String, data: JsonUpcase<AcceptD
 // The web-vault passes org_id and org_user_id in the URL, but we are just reading them from the JWT instead
     let data: AcceptData = data.into_inner().data;
     let token = &data.Token;
-    let claims: InviteJWTClaims = match decode_invite_jwt(&token) {
-            Ok(claims) => claims,
-            Err(msg) => err!("Invalid claim: {:#?}", msg),
-    };
+    let claims: InviteJWTClaims = decode_invite_jwt(&token)?;
 
     match User::find_by_mail(&claims.email, &conn) {
         Some(_) => {

+ 18 - 25
src/auth.rs

@@ -7,6 +7,7 @@ use chrono::Duration;
 use jsonwebtoken::{self, Algorithm, Header};
 use serde::ser::Serialize;
 
+use crate::error::{Error, MapResult};
 use crate::CONFIG;
 
 const JWT_ALGORITHM: Algorithm = Algorithm::RS256;
@@ -31,11 +32,11 @@ lazy_static! {
 pub fn encode_jwt<T: Serialize>(claims: &T) -> String {
     match jsonwebtoken::encode(&JWT_HEADER, claims, &PRIVATE_RSA_KEY) {
         Ok(token) => token,
-        Err(e) => panic!("Error encoding jwt {}", e)
+        Err(e) => panic!("Error encoding jwt {}", e),
     }
 }
 
-pub fn decode_jwt(token: &str) -> Result<JWTClaims, String> {
+pub fn decode_jwt(token: &str) -> Result<JWTClaims, Error> {
     let validation = jsonwebtoken::Validation {
         leeway: 30, // 30 seconds
         validate_exp: true,
@@ -47,16 +48,12 @@ pub fn decode_jwt(token: &str) -> Result<JWTClaims, String> {
         algorithms: vec![JWT_ALGORITHM],
     };
 
-    match jsonwebtoken::decode(token, &PUBLIC_RSA_KEY, &validation) {
-        Ok(decoded) => Ok(decoded.claims),
-        Err(msg) => {
-            error!("Error validating jwt - {:#?}", msg);
-            Err(msg.to_string())
-        }
-    }
+    jsonwebtoken::decode(token, &PUBLIC_RSA_KEY, &validation)
+        .map(|d| d.claims)
+        .map_res("Error decoding login JWT")
 }
 
-pub fn decode_invite_jwt(token: &str) -> Result<InviteJWTClaims, String> {
+pub fn decode_invite_jwt(token: &str) -> Result<InviteJWTClaims, Error> {
     let validation = jsonwebtoken::Validation {
         leeway: 30, // 30 seconds
         validate_exp: true,
@@ -68,13 +65,9 @@ pub fn decode_invite_jwt(token: &str) -> Result<InviteJWTClaims, String> {
         algorithms: vec![JWT_ALGORITHM],
     };
 
-    match jsonwebtoken::decode(token, &PUBLIC_RSA_KEY, &validation) {
-        Ok(decoded) => Ok(decoded.claims),
-        Err(msg) => {
-            error!("Error validating jwt - {:#?}", msg);
-            Err(msg.to_string())
-        }
-    }
+    jsonwebtoken::decode(token, &PUBLIC_RSA_KEY, &validation) 
+        .map(|d| d.claims)
+        .map_res("Error decoding invite JWT")
 }
 
 #[derive(Debug, Serialize, Deserialize)]
@@ -150,7 +143,7 @@ impl<'a, 'r> FromRequest<'a, 'r> for Headers {
             CONFIG.domain.clone()
         } else if let Some(referer) = headers.get_one("Referer") {
             referer.to_string()
-        } else {   
+        } else {
             // Try to guess from the headers
             use std::env;
 
@@ -185,7 +178,7 @@ impl<'a, 'r> FromRequest<'a, 'r> for Headers {
         // Check JWT token is valid and get device and user from it
         let claims: JWTClaims = match decode_jwt(access_token) {
             Ok(claims) => claims,
-            Err(_) => err_handler!("Invalid claim")
+            Err(_) => err_handler!("Invalid claim"),
         };
 
         let device_uuid = claims.device;
@@ -193,17 +186,17 @@ impl<'a, 'r> FromRequest<'a, 'r> for Headers {
 
         let conn = match request.guard::<DbConn>() {
             Outcome::Success(conn) => conn,
-            _ => err_handler!("Error getting DB")
+            _ => err_handler!("Error getting DB"),
         };
 
         let device = match Device::find_by_uuid(&device_uuid, &conn) {
             Some(device) => device,
-            None => err_handler!("Invalid device id")
+            None => err_handler!("Invalid device id"),
         };
 
         let user = match User::find_by_uuid(&user_uuid, &conn) {
             Some(user) => user,
-            None => err_handler!("Device has no user associated")
+            None => err_handler!("Device has no user associated"),
         };
 
         if user.security_stamp != claims.sstamp {
@@ -248,11 +241,11 @@ impl<'a, 'r> FromRequest<'a, 'r> for OrgHeaders {
                             None => err_handler!("The current user isn't member of the organization")
                         };
 
-                        Outcome::Success(Self{
+                        Outcome::Success(Self {
                             host: headers.host,
                             device: headers.device,
                             user: headers.user,
-                            org_user_type: { 
+                            org_user_type: {
                                 if let Some(org_usr_type) = UserOrgType::from_i32(org_user.type_) {
                                     org_usr_type
                                 } else { // This should only happen if the DB is corrupted
@@ -260,7 +253,7 @@ impl<'a, 'r> FromRequest<'a, 'r> for OrgHeaders {
                                 }
                             },
                         })
-                    },
+                    }
                     _ => err_handler!("Error getting the organization id"),
                 }
             }

+ 16 - 8
src/error.rs

@@ -44,14 +44,15 @@ macro_rules! make_error {
     };
 }
 
-use diesel::result::{Error as DieselError, QueryResult};
-use serde_json::{Value, Error as SerError};
+use diesel::result::Error as DieselError;
+use jsonwebtoken::errors::Error as JwtError;
+use serde_json::{Error as SerError, Value};
 use u2f::u2ferror::U2fError as U2fErr;
 
 // Error struct
 // Each variant has two elements, the first is an error of different types, used for logging purposes
 // The second is a String, and it's contents are displayed to the user when the error occurs. Inside the macro, this is represented as _
-// 
+//
 // After the variant itself, there are two expressions. The first one is a bool to indicate whether the error cause will be printed to the log.
 // The second one contains the function used to obtain the response sent to the client
 make_error! {
@@ -63,6 +64,7 @@ make_error! {
     DbError(DieselError, _): true,  _api_error,
     U2fError(U2fErr,     _): true,  _api_error,
     SerdeError(SerError, _): true,  _api_error,
+    JWTError(JwtError,   _): true,  _api_error,
     //WsError(ws::Error, _): true,  _api_error,
 }
 
@@ -73,19 +75,25 @@ impl Error {
 }
 
 pub trait MapResult<S, E> {
-    fn map_res(self, msg: &str) -> Result<(), E>;
+    fn map_res(self, msg: &str) -> Result<S, E>;
+}
+
+impl<S, E: Into<Error>> MapResult<S, Error> for Result<S, E> {
+    fn map_res(self, msg: &str) -> Result<S, Error> {
+        self.map_err(Into::into).map_err(|e| e.with_msg(msg))
+    }
 }
 
-impl MapResult<(), Error> for QueryResult<usize> {
+impl<E: Into<Error>> MapResult<(), Error> for Result<usize, E> {
     fn map_res(self, msg: &str) -> Result<(), Error> {
-        self.and(Ok(())).map_err(Error::from).map_err(|e| e.with_msg(msg))
+        self.and(Ok(())).map_res(msg)
     }
 }
 
 use serde::Serialize;
 use std::any::Any;
 
-fn _serialize(e: &impl Serialize, _: &impl Any) -> String {
+fn _serialize(e: &impl Serialize, _msg: &str) -> String {
     serde_json::to_string(e).unwrap()
 }
 
@@ -102,7 +110,7 @@ fn _api_error(_: &impl Any, msg: &str) -> String {
         "Object": "error"
     });
 
-    _serialize(&json, &false)
+    _serialize(&json, "")
 }
 
 //