auth.rs 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645
  1. //
  2. // JWT Handling
  3. //
  4. use chrono::{Duration, Utc};
  5. use num_traits::FromPrimitive;
  6. use once_cell::sync::Lazy;
  7. use jsonwebtoken::{self, Algorithm, DecodingKey, EncodingKey, Header};
  8. use serde::de::DeserializeOwned;
  9. use serde::ser::Serialize;
  10. use crate::{
  11. error::{Error, MapResult},
  12. util::read_file,
  13. CONFIG,
  14. };
  15. const JWT_ALGORITHM: Algorithm = Algorithm::RS256;
  16. pub static DEFAULT_VALIDITY: Lazy<Duration> = Lazy::new(|| Duration::hours(2));
  17. static JWT_HEADER: Lazy<Header> = Lazy::new(|| Header::new(JWT_ALGORITHM));
  18. pub static JWT_LOGIN_ISSUER: Lazy<String> = Lazy::new(|| format!("{}|login", CONFIG.domain_origin()));
  19. static JWT_INVITE_ISSUER: Lazy<String> = Lazy::new(|| format!("{}|invite", CONFIG.domain_origin()));
  20. static JWT_DELETE_ISSUER: Lazy<String> = Lazy::new(|| format!("{}|delete", CONFIG.domain_origin()));
  21. static JWT_VERIFYEMAIL_ISSUER: Lazy<String> = Lazy::new(|| format!("{}|verifyemail", CONFIG.domain_origin()));
  22. static JWT_ADMIN_ISSUER: Lazy<String> = Lazy::new(|| format!("{}|admin", CONFIG.domain_origin()));
  23. static PRIVATE_RSA_KEY: Lazy<Vec<u8>> = Lazy::new(|| match read_file(&CONFIG.private_rsa_key()) {
  24. Ok(key) => key,
  25. Err(e) => panic!("Error loading private RSA Key.\n Error: {}", e),
  26. });
  27. static PUBLIC_RSA_KEY: Lazy<Vec<u8>> = Lazy::new(|| match read_file(&CONFIG.public_rsa_key()) {
  28. Ok(key) => key,
  29. Err(e) => panic!("Error loading public RSA Key.\n Error: {}", e),
  30. });
  31. pub fn encode_jwt<T: Serialize>(claims: &T) -> String {
  32. match jsonwebtoken::encode(&JWT_HEADER, claims, &EncodingKey::from_rsa_der(&PRIVATE_RSA_KEY)) {
  33. Ok(token) => token,
  34. Err(e) => panic!("Error encoding jwt {}", e),
  35. }
  36. }
  37. fn decode_jwt<T: DeserializeOwned>(token: &str, issuer: String) -> Result<T, Error> {
  38. let validation = jsonwebtoken::Validation {
  39. leeway: 30, // 30 seconds
  40. validate_exp: true,
  41. validate_nbf: true,
  42. aud: None,
  43. iss: Some(issuer),
  44. sub: None,
  45. algorithms: vec![JWT_ALGORITHM],
  46. };
  47. let token = token.replace(char::is_whitespace, "");
  48. jsonwebtoken::decode(&token, &DecodingKey::from_rsa_der(&PUBLIC_RSA_KEY), &validation)
  49. .map(|d| d.claims)
  50. .map_res("Error decoding JWT")
  51. }
  52. pub fn decode_login(token: &str) -> Result<LoginJWTClaims, Error> {
  53. decode_jwt(token, JWT_LOGIN_ISSUER.to_string())
  54. }
  55. pub fn decode_invite(token: &str) -> Result<InviteJWTClaims, Error> {
  56. decode_jwt(token, JWT_INVITE_ISSUER.to_string())
  57. }
  58. pub fn decode_delete(token: &str) -> Result<DeleteJWTClaims, Error> {
  59. decode_jwt(token, JWT_DELETE_ISSUER.to_string())
  60. }
  61. pub fn decode_verify_email(token: &str) -> Result<VerifyEmailJWTClaims, Error> {
  62. decode_jwt(token, JWT_VERIFYEMAIL_ISSUER.to_string())
  63. }
  64. pub fn decode_admin(token: &str) -> Result<AdminJWTClaims, Error> {
  65. decode_jwt(token, JWT_ADMIN_ISSUER.to_string())
  66. }
  67. #[derive(Debug, Serialize, Deserialize)]
  68. pub struct LoginJWTClaims {
  69. // Not before
  70. pub nbf: i64,
  71. // Expiration time
  72. pub exp: i64,
  73. // Issuer
  74. pub iss: String,
  75. // Subject
  76. pub sub: String,
  77. pub premium: bool,
  78. pub name: String,
  79. pub email: String,
  80. pub email_verified: bool,
  81. pub orgowner: Vec<String>,
  82. pub orgadmin: Vec<String>,
  83. pub orguser: Vec<String>,
  84. pub orgmanager: Vec<String>,
  85. // user security_stamp
  86. pub sstamp: String,
  87. // device uuid
  88. pub device: String,
  89. // [ "api", "offline_access" ]
  90. pub scope: Vec<String>,
  91. // [ "Application" ]
  92. pub amr: Vec<String>,
  93. }
  94. #[derive(Debug, Serialize, Deserialize)]
  95. pub struct InviteJWTClaims {
  96. // Not before
  97. pub nbf: i64,
  98. // Expiration time
  99. pub exp: i64,
  100. // Issuer
  101. pub iss: String,
  102. // Subject
  103. pub sub: String,
  104. pub email: String,
  105. pub org_id: Option<String>,
  106. pub user_org_id: Option<String>,
  107. pub invited_by_email: Option<String>,
  108. }
  109. pub fn generate_invite_claims(
  110. uuid: String,
  111. email: String,
  112. org_id: Option<String>,
  113. user_org_id: Option<String>,
  114. invited_by_email: Option<String>,
  115. ) -> InviteJWTClaims {
  116. let time_now = Utc::now().naive_utc();
  117. InviteJWTClaims {
  118. nbf: time_now.timestamp(),
  119. exp: (time_now + Duration::days(5)).timestamp(),
  120. iss: JWT_INVITE_ISSUER.to_string(),
  121. sub: uuid,
  122. email,
  123. org_id,
  124. user_org_id,
  125. invited_by_email,
  126. }
  127. }
  128. #[derive(Debug, Serialize, Deserialize)]
  129. pub struct DeleteJWTClaims {
  130. // Not before
  131. pub nbf: i64,
  132. // Expiration time
  133. pub exp: i64,
  134. // Issuer
  135. pub iss: String,
  136. // Subject
  137. pub sub: String,
  138. }
  139. pub fn generate_delete_claims(uuid: String) -> DeleteJWTClaims {
  140. let time_now = Utc::now().naive_utc();
  141. DeleteJWTClaims {
  142. nbf: time_now.timestamp(),
  143. exp: (time_now + Duration::days(5)).timestamp(),
  144. iss: JWT_DELETE_ISSUER.to_string(),
  145. sub: uuid,
  146. }
  147. }
  148. #[derive(Debug, Serialize, Deserialize)]
  149. pub struct VerifyEmailJWTClaims {
  150. // Not before
  151. pub nbf: i64,
  152. // Expiration time
  153. pub exp: i64,
  154. // Issuer
  155. pub iss: String,
  156. // Subject
  157. pub sub: String,
  158. }
  159. pub fn generate_verify_email_claims(uuid: String) -> DeleteJWTClaims {
  160. let time_now = Utc::now().naive_utc();
  161. DeleteJWTClaims {
  162. nbf: time_now.timestamp(),
  163. exp: (time_now + Duration::days(5)).timestamp(),
  164. iss: JWT_VERIFYEMAIL_ISSUER.to_string(),
  165. sub: uuid,
  166. }
  167. }
  168. #[derive(Debug, Serialize, Deserialize)]
  169. pub struct AdminJWTClaims {
  170. // Not before
  171. pub nbf: i64,
  172. // Expiration time
  173. pub exp: i64,
  174. // Issuer
  175. pub iss: String,
  176. // Subject
  177. pub sub: String,
  178. }
  179. pub fn generate_admin_claims() -> AdminJWTClaims {
  180. let time_now = Utc::now().naive_utc();
  181. AdminJWTClaims {
  182. nbf: time_now.timestamp(),
  183. exp: (time_now + Duration::minutes(20)).timestamp(),
  184. iss: JWT_ADMIN_ISSUER.to_string(),
  185. sub: "admin_panel".to_string(),
  186. }
  187. }
  188. //
  189. // Bearer token authentication
  190. //
  191. use rocket::request::{FromRequest, Outcome, Request};
  192. use crate::db::{
  193. models::{CollectionUser, Device, User, UserOrgStatus, UserOrgType, UserOrganization, UserStampException},
  194. DbConn,
  195. };
  196. pub struct Host {
  197. pub host: String
  198. }
  199. impl<'a, 'r> FromRequest<'a, 'r> for Host {
  200. type Error = &'static str;
  201. fn from_request(request: &'a Request<'r>) -> Outcome<Self, Self::Error> {
  202. let headers = request.headers();
  203. // Get host
  204. let host = if CONFIG.domain_set() {
  205. CONFIG.domain()
  206. } else if let Some(referer) = headers.get_one("Referer") {
  207. referer.to_string()
  208. } else {
  209. // Try to guess from the headers
  210. use std::env;
  211. let protocol = if let Some(proto) = headers.get_one("X-Forwarded-Proto") {
  212. proto
  213. } else if env::var("ROCKET_TLS").is_ok() {
  214. "https"
  215. } else {
  216. "http"
  217. };
  218. let host = if let Some(host) = headers.get_one("X-Forwarded-Host") {
  219. host
  220. } else if let Some(host) = headers.get_one("Host") {
  221. host
  222. } else {
  223. ""
  224. };
  225. format!("{}://{}", protocol, host)
  226. };
  227. Outcome::Success(Host { host })
  228. }
  229. }
  230. pub struct Headers {
  231. pub host: String,
  232. pub device: Device,
  233. pub user: User,
  234. }
  235. impl<'a, 'r> FromRequest<'a, 'r> for Headers {
  236. type Error = &'static str;
  237. fn from_request(request: &'a Request<'r>) -> Outcome<Self, Self::Error> {
  238. let headers = request.headers();
  239. let host = match Host::from_request(request) {
  240. Outcome::Forward(_) => return Outcome::Forward(()),
  241. Outcome::Failure(f) => return Outcome::Failure(f),
  242. Outcome::Success(host) => host.host,
  243. };
  244. // Get access_token
  245. let access_token: &str = match headers.get_one("Authorization") {
  246. Some(a) => match a.rsplit("Bearer ").next() {
  247. Some(split) => split,
  248. None => err_handler!("No access token provided"),
  249. },
  250. None => err_handler!("No access token provided"),
  251. };
  252. // Check JWT token is valid and get device and user from it
  253. let claims = match decode_login(access_token) {
  254. Ok(claims) => claims,
  255. Err(_) => err_handler!("Invalid claim"),
  256. };
  257. let device_uuid = claims.device;
  258. let user_uuid = claims.sub;
  259. let conn = match request.guard::<DbConn>() {
  260. Outcome::Success(conn) => conn,
  261. _ => err_handler!("Error getting DB"),
  262. };
  263. let device = match Device::find_by_uuid(&device_uuid, &conn) {
  264. Some(device) => device,
  265. None => err_handler!("Invalid device id"),
  266. };
  267. let user = match User::find_by_uuid(&user_uuid, &conn) {
  268. Some(user) => user,
  269. None => err_handler!("Device has no user associated"),
  270. };
  271. if user.security_stamp != claims.sstamp {
  272. if let Some(stamp_exception) = user
  273. .stamp_exception
  274. .as_deref()
  275. .and_then(|s| serde_json::from_str::<UserStampException>(s).ok())
  276. {
  277. let current_route = match request.route().and_then(|r| r.name) {
  278. Some(name) => name,
  279. _ => err_handler!("Error getting current route for stamp exception"),
  280. };
  281. // Check if both match, if not this route is not allowed with the current security stamp.
  282. if stamp_exception.route != current_route {
  283. err_handler!("Invalid security stamp: Current route and exception route do not match")
  284. } else if stamp_exception.security_stamp != claims.sstamp {
  285. err_handler!("Invalid security stamp for matched stamp exception")
  286. }
  287. } else {
  288. err_handler!("Invalid security stamp")
  289. }
  290. }
  291. Outcome::Success(Headers { host, device, user })
  292. }
  293. }
  294. pub struct OrgHeaders {
  295. pub host: String,
  296. pub device: Device,
  297. pub user: User,
  298. pub org_user_type: UserOrgType,
  299. pub org_user: UserOrganization,
  300. pub org_id: String,
  301. }
  302. // org_id is usually the second path param ("/organizations/<org_id>"),
  303. // but there are cases where it is a query value.
  304. // First check the path, if this is not a valid uuid, try the query values.
  305. fn get_org_id(request: &Request) -> Option<String> {
  306. if let Some(Ok(org_id)) = request.get_param::<String>(1) {
  307. if uuid::Uuid::parse_str(&org_id).is_ok() {
  308. return Some(org_id);
  309. }
  310. }
  311. if let Some(Ok(org_id)) = request.get_query_value::<String>("organizationId") {
  312. if uuid::Uuid::parse_str(&org_id).is_ok() {
  313. return Some(org_id);
  314. }
  315. }
  316. None
  317. }
  318. impl<'a, 'r> FromRequest<'a, 'r> for OrgHeaders {
  319. type Error = &'static str;
  320. fn from_request(request: &'a Request<'r>) -> Outcome<Self, Self::Error> {
  321. match request.guard::<Headers>() {
  322. Outcome::Forward(_) => Outcome::Forward(()),
  323. Outcome::Failure(f) => Outcome::Failure(f),
  324. Outcome::Success(headers) => {
  325. match get_org_id(request) {
  326. Some(org_id) => {
  327. let conn = match request.guard::<DbConn>() {
  328. Outcome::Success(conn) => conn,
  329. _ => err_handler!("Error getting DB"),
  330. };
  331. let user = headers.user;
  332. let org_user = match UserOrganization::find_by_user_and_org(&user.uuid, &org_id, &conn) {
  333. Some(user) => {
  334. if user.status == UserOrgStatus::Confirmed as i32 {
  335. user
  336. } else {
  337. err_handler!("The current user isn't confirmed member of the organization")
  338. }
  339. }
  340. None => err_handler!("The current user isn't member of the organization"),
  341. };
  342. Outcome::Success(Self {
  343. host: headers.host,
  344. device: headers.device,
  345. user,
  346. org_user_type: {
  347. if let Some(org_usr_type) = UserOrgType::from_i32(org_user.atype) {
  348. org_usr_type
  349. } else {
  350. // This should only happen if the DB is corrupted
  351. err_handler!("Unknown user type in the database")
  352. }
  353. },
  354. org_user,
  355. org_id,
  356. })
  357. }
  358. _ => err_handler!("Error getting the organization id"),
  359. }
  360. }
  361. }
  362. }
  363. }
  364. pub struct AdminHeaders {
  365. pub host: String,
  366. pub device: Device,
  367. pub user: User,
  368. pub org_user_type: UserOrgType,
  369. }
  370. impl<'a, 'r> FromRequest<'a, 'r> for AdminHeaders {
  371. type Error = &'static str;
  372. fn from_request(request: &'a Request<'r>) -> Outcome<Self, Self::Error> {
  373. match request.guard::<OrgHeaders>() {
  374. Outcome::Forward(_) => Outcome::Forward(()),
  375. Outcome::Failure(f) => Outcome::Failure(f),
  376. Outcome::Success(headers) => {
  377. if headers.org_user_type >= UserOrgType::Admin {
  378. Outcome::Success(Self {
  379. host: headers.host,
  380. device: headers.device,
  381. user: headers.user,
  382. org_user_type: headers.org_user_type,
  383. })
  384. } else {
  385. err_handler!("You need to be Admin or Owner to call this endpoint")
  386. }
  387. }
  388. }
  389. }
  390. }
  391. impl From<AdminHeaders> for Headers {
  392. fn from(h: AdminHeaders) -> Headers {
  393. Headers {
  394. host: h.host,
  395. device: h.device,
  396. user: h.user,
  397. }
  398. }
  399. }
  400. // col_id is usually the fourth path param ("/organizations/<org_id>/collections/<col_id>"),
  401. // but there could be cases where it is a query value.
  402. // First check the path, if this is not a valid uuid, try the query values.
  403. fn get_col_id(request: &Request) -> Option<String> {
  404. if let Some(Ok(col_id)) = request.get_param::<String>(3) {
  405. if uuid::Uuid::parse_str(&col_id).is_ok() {
  406. return Some(col_id);
  407. }
  408. }
  409. if let Some(Ok(col_id)) = request.get_query_value::<String>("collectionId") {
  410. if uuid::Uuid::parse_str(&col_id).is_ok() {
  411. return Some(col_id);
  412. }
  413. }
  414. None
  415. }
  416. /// The ManagerHeaders are used to check if you are at least a Manager
  417. /// and have access to the specific collection provided via the <col_id>/collections/collectionId.
  418. /// This does strict checking on the collection_id, ManagerHeadersLoose does not.
  419. pub struct ManagerHeaders {
  420. pub host: String,
  421. pub device: Device,
  422. pub user: User,
  423. pub org_user_type: UserOrgType,
  424. }
  425. impl<'a, 'r> FromRequest<'a, 'r> for ManagerHeaders {
  426. type Error = &'static str;
  427. fn from_request(request: &'a Request<'r>) -> Outcome<Self, Self::Error> {
  428. match request.guard::<OrgHeaders>() {
  429. Outcome::Forward(_) => Outcome::Forward(()),
  430. Outcome::Failure(f) => Outcome::Failure(f),
  431. Outcome::Success(headers) => {
  432. if headers.org_user_type >= UserOrgType::Manager {
  433. match get_col_id(request) {
  434. Some(col_id) => {
  435. let conn = match request.guard::<DbConn>() {
  436. Outcome::Success(conn) => conn,
  437. _ => err_handler!("Error getting DB"),
  438. };
  439. if !headers.org_user.has_full_access() {
  440. match CollectionUser::find_by_collection_and_user(&col_id, &headers.org_user.user_uuid, &conn) {
  441. Some(_) => (),
  442. None => err_handler!("The current user isn't a manager for this collection"),
  443. }
  444. }
  445. }
  446. _ => err_handler!("Error getting the collection id"),
  447. }
  448. Outcome::Success(Self {
  449. host: headers.host,
  450. device: headers.device,
  451. user: headers.user,
  452. org_user_type: headers.org_user_type,
  453. })
  454. } else {
  455. err_handler!("You need to be a Manager, Admin or Owner to call this endpoint")
  456. }
  457. }
  458. }
  459. }
  460. }
  461. impl From<ManagerHeaders> for Headers {
  462. fn from(h: ManagerHeaders) -> Headers {
  463. Headers {
  464. host: h.host,
  465. device: h.device,
  466. user: h.user,
  467. }
  468. }
  469. }
  470. /// The ManagerHeadersLoose is used when you at least need to be a Manager,
  471. /// but there is no collection_id sent with the request (either in the path or as form data).
  472. pub struct ManagerHeadersLoose {
  473. pub host: String,
  474. pub device: Device,
  475. pub user: User,
  476. pub org_user_type: UserOrgType,
  477. }
  478. impl<'a, 'r> FromRequest<'a, 'r> for ManagerHeadersLoose {
  479. type Error = &'static str;
  480. fn from_request(request: &'a Request<'r>) -> Outcome<Self, Self::Error> {
  481. match request.guard::<OrgHeaders>() {
  482. Outcome::Forward(_) => Outcome::Forward(()),
  483. Outcome::Failure(f) => Outcome::Failure(f),
  484. Outcome::Success(headers) => {
  485. if headers.org_user_type >= UserOrgType::Manager {
  486. Outcome::Success(Self {
  487. host: headers.host,
  488. device: headers.device,
  489. user: headers.user,
  490. org_user_type: headers.org_user_type,
  491. })
  492. } else {
  493. err_handler!("You need to be a Manager, Admin or Owner to call this endpoint")
  494. }
  495. }
  496. }
  497. }
  498. }
  499. impl From<ManagerHeadersLoose> for Headers {
  500. fn from(h: ManagerHeadersLoose) -> Headers {
  501. Headers {
  502. host: h.host,
  503. device: h.device,
  504. user: h.user,
  505. }
  506. }
  507. }
  508. pub struct OwnerHeaders {
  509. pub host: String,
  510. pub device: Device,
  511. pub user: User,
  512. }
  513. impl<'a, 'r> FromRequest<'a, 'r> for OwnerHeaders {
  514. type Error = &'static str;
  515. fn from_request(request: &'a Request<'r>) -> Outcome<Self, Self::Error> {
  516. match request.guard::<OrgHeaders>() {
  517. Outcome::Forward(_) => Outcome::Forward(()),
  518. Outcome::Failure(f) => Outcome::Failure(f),
  519. Outcome::Success(headers) => {
  520. if headers.org_user_type == UserOrgType::Owner {
  521. Outcome::Success(Self {
  522. host: headers.host,
  523. device: headers.device,
  524. user: headers.user,
  525. })
  526. } else {
  527. err_handler!("You need to be Owner to call this endpoint")
  528. }
  529. }
  530. }
  531. }
  532. }
  533. //
  534. // Client IP address detection
  535. //
  536. use std::net::IpAddr;
  537. pub struct ClientIp {
  538. pub ip: IpAddr,
  539. }
  540. impl<'a, 'r> FromRequest<'a, 'r> for ClientIp {
  541. type Error = ();
  542. fn from_request(req: &'a Request<'r>) -> Outcome<Self, Self::Error> {
  543. let ip = if CONFIG._ip_header_enabled() {
  544. req.headers().get_one(&CONFIG.ip_header()).and_then(|ip| {
  545. match ip.find(',') {
  546. Some(idx) => &ip[..idx],
  547. None => ip,
  548. }
  549. .parse()
  550. .map_err(|_| warn!("'{}' header is malformed: {}", CONFIG.ip_header(), ip))
  551. .ok()
  552. })
  553. } else {
  554. None
  555. };
  556. let ip = ip
  557. .or_else(|| req.remote().map(|r| r.ip()))
  558. .unwrap_or_else(|| "0.0.0.0".parse().unwrap());
  559. Outcome::Success(ClientIp { ip })
  560. }
  561. }