ソースを参照

webui: add pagination for /api/client/.

Nick Peng 6 ヶ月 前
コミット
2bf153edd8

+ 5 - 2
etc/smartdns/smartdns.conf

@@ -435,11 +435,14 @@ log-level info
 
 # load plugin
 # plugin [path/to/file] [args]
-# plugin /usr/lib/libsmartdns-ui.so --p 8080 -i 0.0.0.0 -r /usr/share/smartdns/wwwroot
+
+# web ui plugin
+# plugin /usr/lib/libsmartdns-ui.so
 # smartdns-ui.www-root /usr/share/smartdns/wwwroot
 # smartdns-ui.ip http://0.0.0.0:6080
+# smartdns-ui.ip https://0.0.0.0:6080
 # smartdns-ui.token-expire 600
-# smartdns-ui.token-secret 123456
+# smartdns-ui.max-query-log-age 86400
 # smartdns-ui.enable-terminal yes
 # smartdns-ui.enable-cors yes
 # smartdns-ui.user admin

+ 6 - 3
plugin/smartdns-ui/src/data_server.rs

@@ -384,8 +384,11 @@ impl DataServer {
         self.db.delete_client_by_id(id)
     }
 
-    pub fn get_client_list(&self) -> Result<Vec<ClientData>, Box<dyn Error>> {
-        self.db.get_client_list()
+    pub fn get_client_list(
+        &self,
+        param: &ClientListGetParam,
+    ) -> Result<QueryClientListResult, Box<dyn Error>> {
+        self.db.get_client_list(Some(param))
     }
 
     pub fn get_top_client_top_list(
@@ -617,7 +620,7 @@ impl DataServer {
     async fn data_server_loop(this: Arc<DataServer>) -> Result<(), Box<dyn Error>> {
         let mut rx: mpsc::Receiver<()>;
         let mut data_rx: mpsc::Receiver<Box<dyn DnsRequest>>;
-        let batch_mode  = *this.recv_in_batch.lock().unwrap();
+        let batch_mode = *this.recv_in_batch.lock().unwrap();
 
         {
             let mut _rx = this.notify_rx.lock().unwrap();

+ 304 - 7
plugin/smartdns-ui/src/db.rs

@@ -114,6 +114,51 @@ pub struct DomainListGetParamCursor {
     pub direction: String,
 }
 
+#[derive(Debug, Clone)]
+pub struct QueryClientListResult {
+    pub client_list: Vec<ClientData>,
+    pub total_count: u64,
+    pub step_by_cursor: bool,
+}
+
+#[derive(Debug, Clone)]
+pub struct ClientListGetParamCursor {
+    pub id: Option<u64>,
+    pub total_count: u64,
+    pub direction: String,
+}
+
+#[derive(Debug, Clone)]
+pub struct ClientListGetParam {
+    pub id: Option<u64>,
+    pub order: Option<String>,
+    pub page_num: u64,
+    pub page_size: u64,
+    pub client_ip: Option<String>,
+    pub mac: Option<String>,
+    pub hostname: Option<String>,
+    pub timestamp_before: Option<u64>,
+    pub timestamp_after: Option<u64>,
+    pub cursor: Option<ClientListGetParamCursor>,
+}
+
+impl ClientListGetParam {
+    pub fn new() -> Self {
+        ClientListGetParam {
+            id: None,
+            page_num: 1,
+            order: None,
+            page_size: 10,
+            client_ip: None,
+            mac: None,
+            hostname: None,
+            timestamp_before: None,
+            timestamp_after: None,
+            cursor: None,
+        }
+    }
+}
+
 #[derive(Debug, Clone)]
 pub struct DomainListGetParam {
     pub id: Option<u64>,
@@ -241,6 +286,11 @@ impl DB {
             [],
         )?;
 
+        conn.execute(
+            "CREATE INDEX IF NOT EXISTS idx_client_last_query_timestamp ON client (last_query_timestamp)",
+            [],
+        )?;
+
         conn.execute(
             "CREATE TABLE IF NOT EXISTS config (
                 key TEXT PRIMARY KEY,
@@ -1333,7 +1383,7 @@ impl DB {
         if conn.as_ref().is_none() {
             return Err("db is not open".into());
         }
-        
+
         let conn = conn.as_mut().unwrap();
         let tx = conn.transaction()?;
         let mut stmt = tx.prepare("INSERT INTO client (id, client_ip, mac, hostname, last_query_timestamp) VALUES (
@@ -1361,16 +1411,240 @@ impl DB {
         Ok(())
     }
 
-    pub fn get_client_list(&self) -> Result<Vec<ClientData>, Box<dyn Error>> {
+    pub fn get_client_list_count(&self, param: Option<&ClientListGetParam>) -> u64 {
+        let conn = self.get_readonly_conn();
+        if conn.as_ref().is_none() {
+            return 0;
+        }
+
+        let conn = conn.as_ref().unwrap();
+        let mut sql = String::new();
+        let mut sql_param = Vec::new();
+        sql.push_str("SELECT COUNT(*) FROM client");
+        if let Ok((sql_where, sql_order, mut ret_sql_param)) = Self::get_client_sql_where(param) {
+            sql.push_str(sql_where.as_str());
+            sql.push_str(sql_order.as_str());
+            sql_param.append(&mut ret_sql_param);
+        }
+
+        let mut stmt = conn.prepare(sql.as_str()).unwrap();
+        let rows = stmt.query_map(rusqlite::params_from_iter(sql_param), |row| Ok(row.get(0)?));
+
+        if let Ok(rows) = rows {
+            for row in rows {
+                if let Ok(row) = row {
+                    return row;
+                }
+            }
+        }
+
+        0
+    }
+
+    fn get_client_sql_where(
+        param: Option<&ClientListGetParam>,
+    ) -> Result<(String, String, Vec<String>), Box<dyn Error>> {
+        let mut is_desc_order = true;
+        let mut is_cursor_prev = false;
+        let param = match param {
+            Some(v) => v,
+            None => return Ok((String::new(), String::new(), Vec::new())),
+        };
+        let mut order_timestamp_first = false;
+        let mut cusor_with_timestamp = false;
+
+        let mut sql_where = Vec::new();
+        let mut sql_param: Vec<String> = Vec::new();
+        let mut sql_order = String::new();
+
+        if let Some(v) = &param.id {
+            sql_where.push("id = ?".to_string());
+            sql_param.push(v.to_string());
+            order_timestamp_first = false;
+        }
+
+        if let Some(v) = &param.order {
+            if v.eq_ignore_ascii_case("asc") {
+                is_cursor_prev = true;
+            } else if v.eq_ignore_ascii_case("desc") {
+                is_cursor_prev = false;
+            } else {
+                return Err("order param error".into());
+            }
+        }
+
+        if let Some(v) = &param.cursor {
+            if v.direction.eq_ignore_ascii_case("prev") {
+                is_desc_order = !is_desc_order;
+            } else if v.direction.eq_ignore_ascii_case("next") {
+                is_desc_order = is_desc_order;
+            } else {
+                return Err("cursor direction param error".into());
+            }
+        }
+
+        if let Some(v) = &param.client_ip {
+            sql_where.push("client_ip = ?".to_string());
+            sql_param.push(v.to_string());
+        }
+
+        if let Some(v) = &param.mac {
+            sql_where.push("mac = ?".to_string());
+            sql_param.push(v.to_string());
+        }
+
+        if let Some(v) = &param.hostname {
+            sql_where.push("hostname = ?".to_string());
+            sql_param.push(v.to_string());
+        }
+
+        if let Some(v) = &param.timestamp_before {
+            let mut use_cursor = false;
+            if param.cursor.is_some() && (is_desc_order || is_cursor_prev) {
+                let v = param.cursor.as_ref().unwrap().id;
+                if let Some(v) = v {
+                    sql_where.push("id < ?".to_string());
+                    sql_param.push(v.to_string());
+                    use_cursor = true;
+                    order_timestamp_first = false;
+                    cusor_with_timestamp = true;
+                }
+            }
+
+            if use_cursor == false {
+                sql_where.push("last_query_timestamp <= ?".to_string());
+                sql_param.push(v.to_string());
+            }
+        }
+
+        if let Some(v) = &param.timestamp_after {
+            let mut use_cursor = false;
+            if param.cursor.is_some() && (!is_desc_order || is_cursor_prev) {
+                let v = param.cursor.as_ref().unwrap().id;
+                if let Some(v) = v {
+                    sql_where.push("id > ?".to_string());
+                    sql_param.push(v.to_string());
+                    use_cursor = true;
+                    order_timestamp_first = false;
+                    cusor_with_timestamp = true;
+                }
+            }
+
+            if use_cursor == false {
+                sql_where.push("last_query_timestamp >= ?".to_string());
+                sql_param.push(v.to_string());
+            }
+        }
+
+        if !cusor_with_timestamp {
+            if let Some(v) = &param.cursor {
+                if is_cursor_prev {
+                    if let Some(id) = &v.id {
+                        if is_desc_order {
+                            sql_where.push("id > ?".to_string());
+                        } else {
+                            sql_where.push("id < ?".to_string());
+                        }
+
+                        sql_param.push(id.to_string());
+                        order_timestamp_first = false;
+                    }
+                } else {
+                    if let Some(id) = &v.id {
+                        if is_desc_order {
+                            sql_where.push("id < ?".to_string());
+                        } else {
+                            sql_where.push("id > ?".to_string());
+                        }
+
+                        sql_param.push(id.to_string());
+                        order_timestamp_first = false;
+                    }
+                }
+            }
+        }
+
+        if is_desc_order {
+            if order_timestamp_first {
+                sql_order.push_str(" ORDER BY last_query_timestamp DESC, id DESC");
+            } else {
+                sql_order.push_str(" ORDER BY id DESC, last_query_timestamp DESC");
+            }
+        } else {
+            if order_timestamp_first {
+                sql_order.push_str(" ORDER BY last_query_timestamp ASC, id ASC");
+            } else {
+                sql_order.push_str(" ORDER BY id ASC, last_query_timestamp ASC");
+            }
+        }
+
+        let sql_where = if sql_where.is_empty() {
+            String::new()
+        } else {
+            format!(" WHERE {}", sql_where.join(" AND "))
+        };
+
+        Ok((sql_where, sql_order, sql_param))
+    }
+
+    pub fn get_client_list(
+        &self,
+        param: Option<&ClientListGetParam>,
+    ) -> Result<QueryClientListResult, Box<dyn Error>> {
+        let query_start = std::time::Instant::now();
+        let mut cursor_reverse = false;
+
+        let mut ret = QueryClientListResult {
+            client_list: vec![],
+            total_count: 0,
+            step_by_cursor: false,
+        };
+
         let conn = self.get_readonly_conn();
         if conn.as_ref().is_none() {
             return Err("db is not open".into());
         }
 
         let conn = conn.as_ref().unwrap();
-        let mut ret = Vec::new();
-        let mut stmt = conn.prepare("SELECT id, client_ip, mac, hostname, last_query_timestamp FROM client").unwrap();
-        let rows = stmt.query_map([], |row| {
+
+        let (sql_where, sql_order, mut sql_param) = Self::get_client_sql_where(param)?;
+
+        let mut sql = String::new();
+        sql.push_str("SELECT id, client_ip, mac, hostname, last_query_timestamp FROM client");
+
+        sql.push_str(sql_where.as_str());
+        sql.push_str(sql_order.as_str());
+
+        if let Some(p) = param {
+            let mut with_offset = true;
+            if let Some(cursor) = &p.cursor {
+                if cursor.id.is_some() {
+                    sql.push_str(" LIMIT ?");
+                    sql_param.push(p.page_size.to_string());
+                    with_offset = false;
+                }
+
+                if cursor.direction.eq_ignore_ascii_case("prev") {
+                    cursor_reverse = true;
+                }
+            }
+
+            if with_offset {
+                sql.push_str(" LIMIT ? OFFSET ?");
+                sql_param.push(p.page_size.to_string());
+                sql_param.push(((p.page_num - 1) * p.page_size).to_string());
+            }
+        }
+
+        self.debug_query_plan(conn, sql.clone(), &sql_param);
+        let stmt = conn.prepare(&sql);
+        if let Err(e) = stmt {
+            dns_log!(LogLevel::ERROR, "get_client_list error: {}", e);
+            return Err("get_client_list error".into());
+        }
+        let mut stmt = stmt?;
+
+        let rows = stmt.query_map(rusqlite::params_from_iter(sql_param), |row| {
             Ok(ClientData {
                 id: row.get(0)?,
                 client_ip: row.get(1)?,
@@ -1380,14 +1654,37 @@ impl DB {
             })
         });
 
+        if let Err(e) = rows {
+            return Err(Box::new(e));
+        }
+
         if let Ok(rows) = rows {
             for row in rows {
                 if let Ok(row) = row {
-                    ret.push(row);
+                    ret.client_list.push(row);
                 }
             }
         }
 
+        if cursor_reverse {
+            ret.client_list.reverse();
+        }
+
+        if let Some(p) = param {
+            if let Some(v) = &p.cursor {
+                ret.total_count = v.total_count;
+                ret.step_by_cursor = true;
+            } else {
+                let total_count = self.get_client_list_count(param);
+                ret.total_count = total_count;
+            }
+        }
+
+        dns_log!(
+            LogLevel::DEBUG,
+            "domain_list time: {}ms",
+            query_start.elapsed().as_millis()
+        );
         Ok(ret)
     }
 
@@ -1398,7 +1695,7 @@ impl DB {
         }
 
         let conn = conn.as_ref().unwrap();
-        
+
         let ret = conn.execute("DELETE FROM client WHERE id = ?", &[&id]);
 
         if let Err(e) = ret {

+ 21 - 15
plugin/smartdns-ui/src/http_api_msg.rs

@@ -294,22 +294,31 @@ pub fn api_msg_parse_client_list(data: &str) -> Result<Vec<ClientData>, Box<dyn
     Ok(client_list)
 }
 
-pub fn api_msg_gen_client_list(client_list: &Vec<ClientData>, total_count: u32) -> String {
+pub fn api_msg_gen_json_object_client(client: &ClientData) -> serde_json::Value {
+    json!({
+        "id": client.id,
+        "client_ip": client.client_ip,
+        "mac": client.mac,
+        "hostname": client.hostname,
+        "last_query_timestamp": client.last_query_timestamp,
+    })
+}
+
+pub fn api_msg_gen_client_list(
+    client_list_result: &QueryClientListResult,
+    total_page: u64,
+    total_count: u64,
+) -> String {
     let json_str = json!({
-        "list_count": client_list.len(),
+        "list_count": client_list_result.client_list.len(),
+        "total_page": total_page,
         "total_count": total_count,
+        "step_by_cursor": client_list_result.step_by_cursor,
         "client_list":
-            client_list
+        client_list_result.client_list
                 .iter()
                 .map(|x| {
-                    let s = json!({
-                        "id": x.id,
-                        "client_ip": x.client_ip,
-                        "mac": x.mac,
-                        "hostname": x.hostname,
-                        "last_query_timestamp": x.last_query_timestamp,
-                    });
-                    s
+                    api_msg_gen_json_object_client(x)
                 })
                 .collect::<Vec<serde_json::Value>>()
 
@@ -796,7 +805,6 @@ pub fn api_msg_parse_stats_overview(data: &str) -> Result<OverviewData, Box<dyn
         startup_timestamp: startup_timestamp.unwrap() as u64,
         free_disk_space: free_disk_space.unwrap() as u64,
         is_process_suspended: is_process_suspended.unwrap(),
-
     })
 }
 
@@ -818,9 +826,7 @@ pub fn api_msg_gen_hourly_query_count(hourly_count: &HourlyQueryCount) -> String
     json_str.to_string()
 }
 
-pub fn api_msg_parse_hourly_query_count(
-    data: &str,
-) -> Result<HourlyQueryCount, Box<dyn Error>> {
+pub fn api_msg_parse_hourly_query_count(data: &str) -> Result<HourlyQueryCount, Box<dyn Error>> {
     let v: serde_json::Value = serde_json::from_str(data)?;
     let query_timestamp = v["query_timestamp"].as_u64();
     if query_timestamp.is_none() {

+ 77 - 13
plugin/smartdns-ui/src/http_server_api.rs

@@ -640,7 +640,11 @@ impl API {
         if cursor.is_some() || total_count.is_some() {
             let param_cursor = DomainListGetParamCursor {
                 id: if cursor.is_some() { cursor } else { None },
-                total_count: total_count.unwrap(),
+                total_count: if total_count.is_some() {
+                    total_count.unwrap()
+                } else {
+                    0
+                },
                 direction: cursor_direction,
             };
             param.cursor = Some(param_cursor);
@@ -717,11 +721,65 @@ impl API {
     async fn api_client_get_list(
         this: Arc<HttpServer>,
         _param: APIRouteParam,
-        _req: Request<body::Incoming>,
+        req: Request<body::Incoming>,
     ) -> Result<Response<Full<Bytes>>, HttpError> {
+        let params = API::get_params(&req);
+
+        let page_num = API::params_get_value_default(&params, "page_num", 1 as u64)?;
+        let page_size = API::params_get_value_default(&params, "page_size", 10 as u64)?;
+        if page_num == 0 || page_size == 0 {
+            return API::response_error(
+                StatusCode::BAD_REQUEST,
+                "Invalid parameter: page_num or page_size",
+            );
+        }
+
+        let id = API::params_get_value(&params, "id");
+        let client_ip = API::params_get_value(&params, "client_ip");
+        let hostname = API::params_get_value(&params, "hostname");
+        let mac = API::params_get_value(&params, "mac");
+        let timestamp_after = API::params_get_value(&params, "timestamp_after");
+        let timestamp_before = API::params_get_value(&params, "timestamp_before");
+        let order = API::params_get_value(&params, "order");
+        let cursor = API::params_get_value(&params, "cursor");
+        let cursor_direction =
+            match API::params_get_value_default(&params, "cursor_direction", "next".to_string()) {
+                Ok(v) => v,
+                Err(e) => {
+                    return Ok(e.to_response());
+                }
+            };
+        let total_count = API::params_get_value(&params, "total_count");
+
+        let mut param = ClientListGetParam::new();
+        param.id = id;
+        param.page_num = page_num;
+        param.page_size = page_size;
+        param.client_ip = client_ip;
+        param.hostname = hostname;
+        param.mac = mac;
+        param.order = order;
+        param.timestamp_after = timestamp_after;
+        param.timestamp_before = timestamp_before;
+
+        if cursor.is_some() || total_count.is_some() {
+            let param_cursor = ClientListGetParamCursor {
+                id: if cursor.is_some() { cursor } else { None },
+                total_count: if total_count.is_some() {
+                    total_count.unwrap()
+                } else {
+                    0
+                },
+                direction: cursor_direction,
+            };
+            param.cursor = Some(param_cursor);
+        }
+
         let data_server = this.get_data_server();
         let ret = API::call_blocking(this, move || {
-            let ret = data_server.get_client_list();
+            let ret = data_server
+                .get_client_list(&param)
+                .map_err(|e| e.to_string());
             if let Err(e) = ret {
                 return Err(e.to_string());
             }
@@ -729,21 +787,27 @@ impl API {
             let ret = ret.unwrap();
 
             return Ok(ret);
-        }).await;
+        })
+        .await;
 
-        let ret = match ret {
-            Ok(v) => v,
-            Err(e) => {
-                return API::response_error(StatusCode::INTERNAL_SERVER_ERROR, e.to_string().as_str());
-            },
-        };
-    
+        if let Err(e) = ret {
+            return API::response_error(StatusCode::INTERNAL_SERVER_ERROR, e.to_string().as_str());
+        }
+
+        let ret = ret.unwrap();
         if let Err(e) = ret {
             return API::response_error(StatusCode::INTERNAL_SERVER_ERROR, e.to_string().as_str());
         }
 
         let client_list = ret.unwrap();
-        let body = api_msg_gen_client_list(&client_list, client_list.len() as u32);
+        let list_count = client_list.total_count;
+        let mut total_page = list_count / page_size;
+        if list_count % page_size != 0 {
+            total_page += 1;
+        }
+
+        let total_count = client_list.total_count;
+        let body = api_msg_gen_client_list(&client_list, total_page, total_count);
 
         API::response_build(StatusCode::OK, body)
     }
@@ -846,7 +910,7 @@ impl API {
         }
 
         if key.is_some() {
-            let key : String = key.unwrap();
+            let key: String = key.unwrap();
             let value = settings.get(key.as_str());
             if value.is_none() {
                 return API::response_error(StatusCode::NOT_FOUND, "Not found");

+ 6 - 2
plugin/smartdns-ui/src/plugin.rs

@@ -95,7 +95,7 @@ impl SmartdnsPlugin {
 
         let www_root = Plugin::dns_conf_plugin_config("smartdns-ui.www-root");
         if let Some(www_root) = www_root {
-            http_conf.http_root = www_root;
+            http_conf.http_root = smartdns_conf_get_conf_fullpath(&www_root);
         }
 
         let ip = Plugin::dns_conf_plugin_config("smartdns-ui.ip");
@@ -112,7 +112,11 @@ impl SmartdnsPlugin {
         }
         dns_log!(LogLevel::INFO, "www root: {}", http_conf.http_root);
 
-        if let Some(token_expire) = matches.opt_str("token-expire") {
+        let mut token_expire = Plugin::dns_conf_plugin_config("smartdns-ui.token-expire");
+        if token_expire.is_none() {
+            token_expire = matches.opt_str("token-expire");
+        }
+        if let Some(token_expire) = token_expire {
             let v = token_expire.parse::<u32>();
             if let Err(e) = v {
                 dns_log!(

+ 17 - 0
plugin/smartdns-ui/src/smartdns.rs

@@ -266,6 +266,23 @@ pub fn smartdns_enable_update_neighbour(enable: bool) {
     }
 }
 
+pub fn smartdns_conf_get_conf_fullpath(path: &str) -> String {
+    let path = CString::new(path).expect("Failed to convert to CString");
+    unsafe {
+        let mut buffer = [0u8; 4096];
+        smartdns_c::conf_get_conf_fullpath(
+            path.as_ptr(),
+            buffer.as_mut_ptr() as *mut c_char,
+            buffer.len() as usize,
+        );
+        let conf_fullpath = std::ffi::CStr::from_ptr(buffer.as_ptr() as *const c_char)
+            .to_string_lossy()
+            .into_owned();
+
+        conf_fullpath
+    }
+}
+
 pub fn smartdns_server_stop() {
     unsafe {
         smartdns_c::smartdns_server_stop();

+ 1 - 1
plugin/smartdns-ui/tests/restapi_test.rs

@@ -493,7 +493,7 @@ fn test_rest_api_get_client() {
     let res = client.login("admin", "password");
     assert!(res.is_ok());
 
-    let c = client.get("/api/client");
+    let c = client.get("/api/client?page_size=4096");
     assert!(c.is_ok());
     let (code, body) = c.unwrap();
     assert_eq!(code, 200);