Browse Source

fix: leaf node encrypt/decrypt

Tienson Qin 1 month ago
parent
commit
2ea6a290a6

+ 1 - 2
src/main/frontend/handler/db_based/db_sync.cljs

@@ -274,7 +274,6 @@
                             (throw (ex-info "missing snapshot download url"
                                             {:graph graph-name
                                              :response download-resp})))
-                        e2ee? (fetch-graph-e2ee? base (str graph-uuid))
                         resp (js/fetch download-url (clj->js (with-auth-headers {:method "GET"})))]
                   (when-not (.-ok resp)
                     (throw (ex-info "snapshot download failed"
@@ -295,7 +294,7 @@
                             (when (seq total-rows')
                               (p/do!
                                (state/<invoke-db-worker :thread-api/db-sync-import-kvs-rows
-                                                        graph total-rows' true graph-uuid e2ee?)
+                                                        graph total-rows' true graph-uuid)
                                (state/<invoke-db-worker :thread-api/db-sync-finalize-kvs-import graph remote-tx)))
                             total')
                           (let [value (.-value chunk)

+ 74 - 101
src/main/frontend/worker/db_sync.cljs

@@ -318,14 +318,6 @@
 
 (def ^:private invalid-transit ::invalid-transit)
 
-(declare encrypt-snapshot-rows decrypt-snapshot-rows)
-
-(defn- try-read-transit [value]
-  (try
-    (ldb/read-transit-str value)
-    (catch :default _
-      invalid-transit)))
-
 (defn- graph-e2ee?
   [repo]
   (when-let [conn (worker-state/get-datascript-conn repo)]
@@ -490,16 +482,6 @@
           (swap! *repo->aes-key assoc repo aes-key)
           aes-key)))))
 
-(defn <decrypt-kvs-rows
-  [repo graph-id rows e2ee?]
-  (if-not e2ee?
-    (p/resolved rows)
-    (p/let [aes-key (<fetch-graph-aes-key-for-download repo graph-id)
-            _ (when (nil? aes-key)
-                (fail-fast :db-sync/missing-field {:repo repo :field :aes-key}))
-            rows* (decrypt-snapshot-rows aes-key rows)]
-      rows*)))
-
 (defn- <grant-graph-access!
   [repo graph-id target-email]
   (if-not (graph-e2ee? repo)
@@ -536,37 +518,19 @@
 
 (defn- <encrypt-text-value
   [aes-key value]
-  (p/let [text (ldb/write-transit-str value)
-          encrypted (crypt/<encrypt-text aes-key text)]
+  (assert (string? value) (str "encrypting value should be a string, value: " value))
+  (p/let [encrypted (crypt/<encrypt-text aes-key (ldb/write-transit-str value))]
     (ldb/write-transit-str encrypted)))
 
 (defn- <decrypt-text-value
   [aes-key value]
-  (if-not (string? value)
-    (p/resolved value)
-    (let [decoded (try-read-transit value)]
-      (if (= decoded invalid-transit)
-        (p/resolved value)
-        (p/let [decrypted (crypt/<decrypt-text-if-encrypted aes-key decoded)]
-          (if decrypted
-            (ldb/read-transit-str decrypted)
-            value))))))
-
-(defn- <encrypt-keys-attrs
-  [aes-key keys]
-  (p/all (mapv (fn [[e a v t]]
-                 (if (contains? rtc-const/encrypt-attr-set a)
-                   (p/let [v' (<encrypt-text-value aes-key v)]
-                     [e a v' t])
-                   [e a v t])) keys)))
-
-(defn- <decrypt-keys-attrs
-  [aes-key keys]
-  (p/all (mapv (fn [[e a v t]]
-                 (if (contains? rtc-const/encrypt-attr-set a)
-                   (p/let [v' (<decrypt-text-value aes-key v)]
-                     [e a v' t])
-                   [e a v t])) keys)))
+  (assert (string? value) (str "encrypted value should be a string, value: " value))
+  (let [decoded (ldb/read-transit-str value)]
+    (if (= decoded invalid-transit)
+      (p/resolved value)
+      (p/let [value (crypt/<decrypt-text-if-encrypted aes-key decoded)
+              value' (ldb/read-transit-str value)]
+        value'))))
 
 (defn- encrypt-tx-item
   [aes-key item]
@@ -574,21 +538,11 @@
     (and (vector? item) (<= 4 (count item)))
     (let [attr (nth item 2)
           v (nth item 3)]
-      (if (and (contains? rtc-const/encrypt-attr-set attr)
-               (string? v))
+      (if (contains? rtc-const/encrypt-attr-set attr)
         (p/let [v' (<encrypt-text-value aes-key v)]
           (assoc item 3 v'))
         (p/resolved item)))
 
-    (map? item)
-    (let [attr (:a item)
-          v (:v item)]
-      (if (and (contains? rtc-const/encrypt-attr-set attr)
-               (string? v))
-        (p/let [v' (<encrypt-text-value aes-key v)]
-          (assoc item :v v'))
-        (p/resolved item)))
-
     :else
     (p/resolved item)))
 
@@ -598,70 +552,85 @@
     (and (vector? item) (<= 4 (count item)))
     (let [attr (nth item 2)
           v (nth item 3)]
-      (if (and (contains? rtc-const/encrypt-attr-set attr)
-               (string? v))
+      (if (contains? rtc-const/encrypt-attr-set attr)
         (p/let [v' (<decrypt-text-value aes-key v)]
           (assoc item 3 v'))
         (p/resolved item)))
 
-    (map? item)
-    (let [attr (:a item)
-          v (:v item)]
-      (if (and (contains? rtc-const/encrypt-attr-set attr)
-               (string? v))
-        (p/let [v' (<decrypt-text-value aes-key v)]
-          (assoc item :v v'))
-        (p/resolved item)))
-
     :else
     (p/resolved item)))
 
-(defn- encrypt-tx-data
+(defn- <encrypt-tx-data
   [aes-key tx-data]
-  (if-not (seq tx-data)
-    (p/resolved [])
+  (when (seq tx-data)
     (p/let [items (p/all (mapv (fn [item] (encrypt-tx-item aes-key item)) tx-data))]
       (vec items))))
 
-(defn- decrypt-tx-data
+(defn- <decrypt-tx-data
   [aes-key tx-data]
-  (if-not (seq tx-data)
-    (p/resolved [])
+  (when (seq tx-data)
     (p/let [items (p/all (mapv (fn [item] (decrypt-tx-item aes-key item)) tx-data))]
       (vec items))))
 
-(defn- encrypt-snapshot-rows
+(defn- <encrypt-keys-attrs
+  [aes-key keys]
+  (p/all (mapv (fn [[e a v t]]
+                 (if (contains? rtc-const/encrypt-attr-set a)
+                   (p/let [v' (<encrypt-text-value aes-key v)]
+                     [e a v' t])
+                   [e a v t])) keys)))
+
+(defn- <decrypt-keys-attrs
+  [aes-key keys]
+  (p/all (mapv (fn [[e a v t]]
+                 (if (contains? rtc-const/encrypt-attr-set a)
+                   (p/let [v' (<decrypt-text-value aes-key v)]
+                     [e a v' t])
+                   (p/resolved [e a v t]))) keys)))
+
+(defn- <encrypt-snapshot-rows
   [aes-key rows]
   (if-not (seq rows)
     (p/resolved [])
     (p/let [items (p/all
                    (mapv (fn [[addr content addresses]]
-                           (let [data (try-read-transit content)]
-                             (if (and (not= data invalid-transit) (map? data))
-                               (p/let [keys' (<encrypt-keys-attrs aes-key (:keys data))
-                                       data' (assoc data :keys keys')
-                                       content' (ldb/write-transit-str data')]
-                                 [addr content' addresses])
-                               (p/resolved [addr content addresses]))))
+                           (let [data (ldb/read-transit-str content)]
+                             (p/let [keys' (if (map? data) ; node
+                                             (<encrypt-keys-attrs aes-key (:keys data))
+                                             ;; leaf
+                                             (p/let [result (p/all (map #(<encrypt-keys-attrs aes-key %) data))]
+                                               (vec result)))
+                                     data' (if (map? data) (assoc data :keys keys') keys')
+                                     content' (ldb/write-transit-str data')]
+                               [addr content' addresses])))
                          rows))]
       (vec items))))
 
-(defn- decrypt-snapshot-rows
+(defn- <decrypt-snapshot-rows
   [aes-key rows]
   (if-not (seq rows)
     (p/resolved [])
     (p/let [items (p/all
                    (mapv (fn [[addr content addresses]]
-                           (let [data (try-read-transit content)]
-                             (if (and (not= data invalid-transit) (map? data))
-                               (p/let [keys (<decrypt-keys-attrs aes-key (:keys data))
-                                       data' (assoc data :keys keys)
-                                       content' (ldb/write-transit-str data')]
-                                 [addr content' addresses])
-                               (p/resolved [addr content addresses]))))
+                           (let [data (ldb/read-transit-str content)]
+                             (p/let [keys' (if (map? data) ; node
+                                             (<decrypt-keys-attrs aes-key (:keys data))
+                                             ;; leaf
+                                             (p/let [result (p/all (map #(<decrypt-keys-attrs aes-key %) data))]
+                                               (vec result)))
+                                     data' (if (map? data) (assoc data :keys keys') keys')
+                                     content' (ldb/write-transit-str data')]
+                               [addr content' addresses])))
                          rows))]
       (vec items))))
 
+(defn <decrypt-kvs-rows
+  [repo graph-id rows]
+  (p/let [aes-key (<fetch-graph-aes-key-for-download repo graph-id)
+          _ (when (nil? aes-key)
+              (fail-fast :db-sync/missing-field {:repo repo :field :aes-key}))]
+    (<decrypt-snapshot-rows aes-key rows)))
+
 (defn- require-asset-field
   [repo field value context]
   (when (or (nil? value) (and (string? value) (string/blank? value)))
@@ -767,16 +736,20 @@
                   ;; (prn :debug :before-keep-last-update txs)
                   ;; (prn :debug :upload :tx-data tx-data)
                   (when (seq txs)
-                    (p/let [aes-key (<ensure-graph-aes-key repo (:graph-id client))
-                            _ (when (and (graph-e2ee? repo) (nil? aes-key))
-                                (fail-fast :db-sync/missing-field {:repo repo :field :aes-key}))
-                            tx-data* (if aes-key
-                                       (encrypt-tx-data aes-key tx-data)
-                                       (p/resolved tx-data))]
-                      (reset! (:inflight client) tx-ids)
-                      (send! ws {:type "tx/batch"
-                                 :t-before local-tx
-                                 :txs (sqlite-util/write-transit-str tx-data*)}))))))))))))
+                    (->
+                     (p/let [aes-key (<ensure-graph-aes-key repo (:graph-id client))
+                             _ (when (and (graph-e2ee? repo) (nil? aes-key))
+                                 (fail-fast :db-sync/missing-field {:repo repo :field :aes-key}))
+                             tx-data* (if aes-key
+                                        (<encrypt-tx-data aes-key tx-data)
+                                        tx-data)]
+
+                       (reset! (:inflight client) tx-ids)
+                       (send! ws {:type "tx/batch"
+                                  :t-before local-tx
+                                  :txs (sqlite-util/write-transit-str tx-data*)}))
+                     (p/catch (fn [error]
+                                (js/console.error error))))))))))))))
 
 (defn- ensure-client-state! [repo]
   (let [client {:repo repo
@@ -1192,7 +1165,7 @@
                                 _ (when (and (graph-e2ee? repo) (nil? aes-key))
                                     (fail-fast :db-sync/missing-field {:repo repo :field :aes-key}))
                                 tx* (if aes-key
-                                      (decrypt-tx-data aes-key tx)
+                                      (<decrypt-tx-data aes-key tx)
                                       (p/resolved tx))]
                           (apply-remote-tx! repo client tx*
                                             :local-tx local-tx
@@ -1429,7 +1402,7 @@
                       rows (normalize-snapshot-rows rows)
                       upload-url (str base "/sync/" graph-id "/snapshot/upload?reset=" (if first-batch? "true" "false"))]
                   (p/let [rows* (if aes-key
-                                  (encrypt-snapshot-rows aes-key rows)
+                                  (<encrypt-snapshot-rows aes-key rows)
                                   (p/resolved rows))
                           {:keys [body encoding]} (<snapshot-upload-body rows*)
                           headers (cond-> {"content-type" snapshot-content-type}

+ 2 - 2
src/main/frontend/worker/db_worker.cljs

@@ -615,10 +615,10 @@
   nil)
 
 (def-thread-api :thread-api/db-sync-import-kvs-rows
-  [repo rows reset? graph-id e2ee?]
+  [repo rows reset? graph-id]
   (p/let [_ (when reset?
               (close-db! repo))
-          rows* (db-sync/<decrypt-kvs-rows repo graph-id rows e2ee?)
+          rows* (db-sync/<decrypt-kvs-rows repo graph-id rows)
           db (ensure-db-sync-import-db! repo reset?)]
     (when (seq rows*)
       (upsert-addr-content! db (rows->sqlite-binds rows*)))

+ 4 - 4
src/test/frontend/worker/db_sync_test.cljs

@@ -72,12 +72,12 @@
                      tx-data [[:db/add 1 :block/title "hello"]
                               [:db/add 2 :block/name "page"]
                               [:db/add 3 :block/uuid (random-uuid)]]
-                     encrypted (#'db-sync/encrypt-tx-data aes-key tx-data)]
+                     encrypted (#'db-sync/<encrypt-tx-data aes-key tx-data)]
                (is (not= tx-data encrypted))
                (is (string? (nth (first encrypted) 3)))
                (is (= (nth (second encrypted) 3)
                       "page"))
-               (p/let [decrypted (#'db-sync/decrypt-tx-data aes-key encrypted)]
+               (p/let [decrypted (#'db-sync/<decrypt-tx-data aes-key encrypted)]
                  (is (= tx-data decrypted))
                  (done)))
              (p/catch (fn [e]
@@ -91,12 +91,12 @@
                                                              :block/title "hello"
                                                              :block/name "page"})
                      rows [[1 content nil]]
-                     encrypted (#'db-sync/encrypt-snapshot-rows aes-key rows)]
+                     encrypted (#'db-sync/<encrypt-snapshot-rows aes-key rows)]
                (is (not= rows encrypted))
                (let [[_ content* _] (first encrypted)]
                  (is (string? content*))
                  (is (not= content content*)))
-               (p/let [decrypted (#'db-sync/decrypt-snapshot-rows aes-key encrypted)]
+               (p/let [decrypted (#'db-sync/<decrypt-snapshot-rows aes-key encrypted)]
                  (is (= rows decrypted))
                  (done)))
              (p/catch (fn [e]