فهرست منبع

webui: add some API.

Nick Peng 1 سال پیش
والد
کامیت
63f50fd64a

+ 9 - 3
plugin/smartdns-ui/Cargo.toml

@@ -7,7 +7,6 @@ edition = "2021"
 crate-type = ["cdylib", "lib"]
 
 [dependencies]
-libc = "0.2.155"
 ctor = "0.2.8"
 bytes = "1.6.1"
 rusqlite = { version = "0.32.0", features = ["bundled"] }
@@ -16,8 +15,8 @@ hyper-util = { version = "0.1.6", features = ["full"] }
 hyper-tungstenite = "0.14.0"
 tokio = { version = "1.38.1", features = ["full"] }
 serde = { version = "1.0.204", features = ["derive"] }
-tokio-rustls = "0.26.0"
-rustls-pemfile = "2.1.2"
+tokio-rustls = { version = "0.26.0", optional = true}
+rustls-pemfile = { version = "2.1.2", optional = true}
 serde_json = "1.0.120"
 http-body-util = "0.1.2"
 getopts = "0.2.21"
@@ -26,9 +25,16 @@ jsonwebtoken = "9"
 matchit = "0.8.4"
 futures = "0.3.30"
 socket2 = "0.5.7"
+cfg-if = "1.0.0"
+urlencoding = "2.1.3"
+chrono = "0.4.38"
+nix = "0.29.0"
+tokio-fd = "0.3.0"
 
 [features]
 build-release = []
+https = ["tokio-rustls", "rustls-pemfile"]
+default = ["https"]
 
 [dev-dependencies]
 reqwest = {version = "0.12.5", features = ["blocking"]}

+ 22 - 11
plugin/smartdns-ui/build.rs

@@ -1,6 +1,6 @@
+use std::collections::HashSet;
 use std::env;
 use std::path::PathBuf;
-use std::collections::HashSet;
 
 #[derive(Debug)]
 struct IgnoreMacros(HashSet<String>);
@@ -20,19 +20,30 @@ fn link_smartdns_lib() {
     let smartdns_src_dir = format!("{}/../../src", curr_source_dir);
     let smartdns_lib_file = format!("{}/libsmartdns-test.a", smartdns_src_dir);
 
-    let ignored_macros = IgnoreMacros(
-        vec![
-            "IPPORT_RESERVED".into(),
-        ]
-        .into_iter()
-        .collect(),
-    );
+    let cc = env::var("RUSTC_LINKER")
+        .unwrap_or_else(|_| env::var("CC").unwrap_or_else(|_| "cc".to_string()));
 
-    let bindings = bindgen::Builder::default()
-        .header(format!("{}/smartdns.h", smartdns_src_dir))
+    let sysroot_output = std::process::Command::new(&cc)
+        .arg("--print-sysroot")
+        .output();
+    let mut sysroot = None;
+    if let Ok(output) = sysroot_output {
+        if output.status.success() {
+            let path = String::from_utf8(output.stdout).unwrap();
+            sysroot = Some(path.trim().to_string());
+        }
+    }
+
+    let ignored_macros = IgnoreMacros(vec!["IPPORT_RESERVED".into()].into_iter().collect());
+
+    let mut bindings_builder =
+        bindgen::Builder::default().header(format!("{}/smartdns.h", smartdns_src_dir));
+    if let Some(sysroot) = sysroot {
+        bindings_builder = bindings_builder.clang_arg(format!("--sysroot={}", sysroot));
+    }
+    let bindings = bindings_builder
         .clang_arg(format!("-I{}/include", smartdns_src_dir))
         .parse_callbacks(Box::new(ignored_macros))
-        .blocklist_file("/usr/include/.*")
         .generate()
         .expect("Unable to generate bindings");
 

+ 158 - 8
plugin/smartdns-ui/src/data_server.rs

@@ -16,34 +16,76 @@
  * along with this program.  If not, see <http://www.gnu.org/licenses/>.
  */
 
+use crate::data_stats::*;
 use crate::db::*;
 use crate::dns_log;
+use crate::smartdns;
 use crate::smartdns::*;
+use crate::utils;
 
+use std::collections::HashMap;
 use std::error::Error;
 use std::sync::{Arc, Mutex, RwLock};
 use std::thread;
 use tokio::sync::mpsc;
 use tokio::time::{interval_at, Duration, Instant};
 
+pub const DEFAULT_MAX_LOG_AGE: u64 = 30 * 24 * 60 * 60;
+pub const DEFAULT_MAX_LOG_AGE_MS: u64 = DEFAULT_MAX_LOG_AGE * 1000;
+pub const MAX_LOG_AGE_VALUE_MIN: u64 = 3600;
+pub const MAX_LOG_AGE_VALUE_MAX: u64 = 365 * 24 * 60 * 60 * 10;
+
+pub struct OverviewData {
+    pub total_query_count: u64,
+    pub block_query_count: u64,
+    pub avg_query_time: f64,
+    pub cache_hit_rate: f64,
+}
+
 #[derive(Clone)]
 pub struct DataServerConfig {
     pub data_root: String,
-    pub max_log_age_ms: u32,
+    pub max_log_age_ms: u64,
 }
 
 impl DataServerConfig {
     pub fn new() -> Self {
         DataServerConfig {
             data_root: Plugin::dns_conf_data_dir() + "/ui.db",
-            max_log_age_ms: 7 * 24 * 60 * 60 * 1000,
+            max_log_age_ms: DEFAULT_MAX_LOG_AGE_MS,
         }
     }
+
+    pub fn load_config(&mut self, data_server: Arc<DataServer>) -> Result<(), Box<dyn Error>> {
+        self.max_log_age_ms = utils::parse_value(
+            data_server.get_server_config("smartdns-ui.max-query-log-age"),
+            MAX_LOG_AGE_VALUE_MIN,
+            MAX_LOG_AGE_VALUE_MAX,
+            DEFAULT_MAX_LOG_AGE,
+        ) * 1000;
+
+        let log_level = data_server.get_server_config("log-level");
+        if let Some(log_level) = log_level {
+            let log_level = log_level.try_into();
+            match log_level {
+                Ok(log_level) => {
+                    dns_log_set_level(log_level);
+                }
+                Err(_) => {
+                    dns_log!(LogLevel::WARN, "log level is invalid");
+                }
+            }
+        }
+
+        Ok(())
+    }
 }
 
 pub struct DataServerControl {
     data_server: Arc<DataServer>,
     server_thread: Mutex<Option<thread::JoinHandle<()>>>,
+    is_init: Mutex<bool>,
+    is_run: Mutex<bool>,
 }
 
 impl DataServerControl {
@@ -51,6 +93,8 @@ impl DataServerControl {
         DataServerControl {
             data_server: Arc::new(DataServer::new()),
             server_thread: Mutex::new(None),
+            is_init: Mutex::new(false),
+            is_run: Mutex::new(false),
         }
     }
 
@@ -58,32 +102,50 @@ impl DataServerControl {
         Arc::clone(&self.data_server)
     }
 
-    pub fn start_data_server(&self, conf: &DataServerConfig) -> Result<(), Box<dyn Error>> {
+    pub fn init_db(&self, conf: &DataServerConfig) -> Result<(), Box<dyn Error>> {
         let inner_clone = Arc::clone(&self.data_server);
-
         let ret = inner_clone.init_server(conf);
         if let Err(e) = ret {
             return Err(e);
         }
 
+        *self.is_init.lock().unwrap() = true;
+        Ok(())
+    }
+
+    pub fn start_data_server(&self) -> Result<(), Box<dyn Error>> {
+        let inner_clone = Arc::clone(&self.data_server);
+
+        if *self.is_init.lock().unwrap() == false {
+            return Err("data server not init".into());
+        }
+
         let server_thread = thread::spawn(move || {
             let ret = DataServer::data_server_loop(inner_clone);
             if let Err(e) = ret {
                 dns_log!(LogLevel::ERROR, "data server error: {}", e);
                 Plugin::smartdns_exit(1);
             }
+
+            dns_log!(LogLevel::INFO, "data server exit.");
         });
 
+        *self.is_run.lock().unwrap() = true;
         *self.server_thread.lock().unwrap() = Some(server_thread);
         Ok(())
     }
 
     pub fn stop_data_server(&self) {
+        if *self.is_run.lock().unwrap() == false {
+            return;
+        }
+    
         self.data_server.stop_data_server();
         let _server_thread = self.server_thread.lock().unwrap().take();
         if let Some(server_thread) = _server_thread {
             server_thread.join().unwrap();
         }
+        *self.is_run.lock().unwrap() = false;
     }
 
     pub fn send_request(&self, request: &mut DnsRequest) -> Result<(), Box<dyn Error>> {
@@ -107,6 +169,7 @@ pub struct DataServer {
     data_tx: Option<mpsc::Sender<DnsRequest>>,
     data_rx: Mutex<Option<mpsc::Receiver<DnsRequest>>>,
     db: DB,
+    stat: Arc<DataStats>,
 }
 
 impl DataServer {
@@ -118,6 +181,7 @@ impl DataServer {
             data_tx: None,
             data_rx: Mutex::new(None),
             db: DB::new(),
+            stat: DataStats::new(),
         };
 
         let (tx, rx) = mpsc::channel(100);
@@ -137,12 +201,48 @@ impl DataServer {
         dns_log!(LogLevel::INFO, "open db: {}", conf_clone.data_root);
         let ret = self.db.open(&conf_clone.data_root);
         if let Err(e) = ret {
-            dns_log!(LogLevel::ERROR, "open db error: {}", e);
             return Err(e);
         }
+
+        let ret = self.stat.clone().init();
+        if let Err(e) = ret {
+            return Err(e);
+        }
+
         Ok(())
     }
 
+    pub fn get_config(&self, key: &str) -> Option<String> {
+        let ret = self.db.get_config(key);
+        if let Ok(value) = ret {
+            return value;
+        }
+
+        None
+    }
+
+    pub fn get_server_config(&self, key: &str) -> Option<String> {
+        let ret = self.get_config(key);
+        if let Some(value) = ret {
+            return Some(value);
+        }
+
+        let ret = Plugin::dns_conf_plugin_config(key);
+        if let Some(value) = ret {
+            return Some(value);
+        }
+
+        None
+    }
+
+    pub fn get_config_list(&self) -> Result<HashMap<String, String>, Box<dyn Error>> {
+        self.db.get_config_list()
+    }
+
+    pub fn set_config(&self, key: &str, value: &str) -> Result<(), Box<dyn Error>> {
+        self.db.set_config(key, value)
+    }
+
     pub fn get_domain_list(
         &self,
         param: &DomainListGetParam,
@@ -166,11 +266,50 @@ impl DataServer {
         self.db.get_client_list()
     }
 
+    pub fn get_top_client_top_list(
+        &self,
+        count: u32,
+    ) -> Result<Vec<ClientQueryCount>, Box<dyn Error>> {
+        self.db.get_client_top_list(count)
+    }
+
+    pub fn get_top_domain_top_list(
+        &self,
+        count: u32,
+    ) -> Result<Vec<DomainQueryCount>, Box<dyn Error>> {
+        self.db.get_domain_top_list(count)
+    }
+
+    pub fn get_hourly_query_count(&self, pastt_hours: u32) -> Result<Vec<HourlyQueryCount>, Box<dyn Error>> {
+        self.db.get_hourly_query_count(pastt_hours)
+    }
+
+    pub fn get_overview(&self) -> Result<OverviewData, Box<dyn Error>> {
+        let overview = OverviewData {
+            total_query_count: smartdns::Stats::get_request_total(),
+            block_query_count: smartdns::Stats::get_request_blocked(),
+            avg_query_time: smartdns::Stats::get_avg_process_time(),
+            cache_hit_rate: smartdns::Stats::get_cache_hit_rate(),
+        };
+
+        Ok(overview)
+    }
+
     pub fn insert_domain(&self, data: &DomainData) -> Result<(), Box<dyn Error>> {
+        let client_ip = &data.client;
+        self.db.insert_client(client_ip.as_str())?;
         self.db.insert_domain(data)
     }
 
     async fn data_server_handle(this: Arc<DataServer>, req: DnsRequest) {
+        if req.is_prefetch_request() {
+            return;
+        }
+
+        if req.is_dualstack_request() {
+            return;
+        }
+
         let domain_data = DomainData {
             id: 0,
             domain: req.get_domain(),
@@ -178,7 +317,11 @@ impl DataServer {
             client: req.get_remote_addr(),
             domain_group: req.get_group_name(),
             reply_code: req.get_rcode(),
-            timestamp: req.get_query_time(),
+            timestamp: req.get_query_timestamp(),
+            query_time: req.get_query_time(),
+            ping_time: req.get_ping_time(),
+            is_blocked: req.get_is_blocked(),
+            is_cached: req.get_is_cached(),
         };
         dns_log!(
             LogLevel::DEBUG,
@@ -194,7 +337,7 @@ impl DataServer {
     }
 
     async fn hourly_work(this: Arc<DataServer>) {
-        dns_log!(LogLevel::INFO, "start hourly work");
+        dns_log!(LogLevel::ERROR, "start hourly work");
         let now = get_utc_time_ms();
 
         let ret = this
@@ -220,9 +363,14 @@ impl DataServer {
             data_rx = _rx.take().unwrap();
         }
 
-        let start = Instant::now() + Duration::from_secs(60);
+        this.stat.clone().start_worker()?;
+
+        let start: Instant =
+            Instant::now() + Duration::from_secs(utils::seconds_until_next_hour());
         let mut hour_timer = interval_at(start, Duration::from_secs(60 * 60));
 
+        dns_log!(LogLevel::INFO, "data server start.");
+
         loop {
             tokio::select! {
                 _ = rx.recv() => {
@@ -244,6 +392,8 @@ impl DataServer {
             }
         }
 
+        this.stat.clone().stop_worker();
+
         Ok(())
     }
 

+ 143 - 0
plugin/smartdns-ui/src/data_stats.rs

@@ -0,0 +1,143 @@
+/*************************************************************************
+ *
+ * Copyright (C) 2018-2024 Ruilin Peng (Nick) <[email protected]>.
+ *
+ * smartdns is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * smartdns is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ */
+
+use std::error::Error;
+use std::thread;
+
+use crate::smartdns::*;
+
+use std::sync::{
+    atomic::{AtomicBool, Ordering},
+    Arc, Mutex,
+};
+use std::time::Duration;
+use tokio::sync::mpsc;
+use tokio::time::{interval_at, Instant};
+
+use crate::utils;
+
+struct DataStatsItem {
+
+}
+
+impl DataStatsItem {
+    pub fn new() -> Self {
+        DataStatsItem { }
+    }
+
+    #[allow(dead_code)]
+    pub fn get_current_hour_total(&self) -> u64 {
+        return Stats::get_request_total();
+    }
+
+    #[allow(dead_code)]
+    pub fn update_total(&mut self, _total: u64) {
+
+    }
+}
+
+pub struct DataStats {
+    task: Mutex<Option<tokio::task::JoinHandle<()>>>,
+    notify_tx: Option<mpsc::Sender<()>>,
+    notify_rx: Mutex<Option<mpsc::Receiver<()>>>,
+    is_run: AtomicBool,
+    data: Mutex<DataStatsItem>,
+}
+
+impl DataStats {
+    pub fn new() -> Arc<Self> {
+        let (tx, rx) = mpsc::channel(100);
+
+        Arc::new(DataStats {
+            task: Mutex::new(None),
+            notify_rx: Mutex::new(Some(rx)),
+            notify_tx: Some(tx),
+            is_run: AtomicBool::new(false),
+            data: Mutex::new(DataStatsItem::new()),
+        })
+    }
+
+    pub fn init(self: Arc<Self>) -> Result<(), Box<dyn Error>> {
+        Ok(())
+    }
+
+    pub fn start_worker(self: Arc<Self>) -> Result<(), Box<dyn Error>> {
+        let this = self.clone();
+        let task = tokio::spawn(async move {
+            DataStats::worker_loop(this).await;
+        });
+
+        *(self.task.lock().unwrap()) = Some(task);
+        self.is_run.store(true, Ordering::Relaxed);
+        Ok(())
+    }
+
+    async fn update_stats(&self) {
+        let mut data = self.data.lock().unwrap();
+        let total = Stats::get_request_total();
+        data.update_total(total);
+    }
+
+    async fn worker_loop(this: Arc<Self>) {
+        let mut rx: mpsc::Receiver<()>;
+        {
+            let mut _rx = this.notify_rx.lock().unwrap();
+            rx = _rx.take().unwrap();
+        }
+
+        let start: Instant = Instant::now() + Duration::from_secs(utils::seconds_until_next_hour());
+        let mut hour_timer = interval_at(start, Duration::from_secs(60 * 60));
+
+        loop {
+            tokio::select! {
+                _ = rx.recv() => {
+                    break;
+                }
+
+                _ = hour_timer.tick() => {
+                    this.update_stats().await;
+                }
+            }
+        }
+    }
+
+    pub fn stop_worker(&self) {
+        if self.is_run.load(Ordering::Relaxed) == false {
+            return;
+        }
+
+        if let Some(tx) = self.notify_tx.as_ref().cloned() {
+            let t = thread::spawn(move || {
+                let rt = tokio::runtime::Runtime::new().unwrap();
+                rt.block_on(async move {
+                    _ = tx.send(()).await;
+                });
+            });
+
+            let _ = t.join();
+        }
+
+        self.is_run.store(false, Ordering::Relaxed);
+    }
+}
+
+impl Drop for DataStats {
+    fn drop(&mut self) {
+        self.stop_worker();
+    }
+}

+ 275 - 17
plugin/smartdns-ui/src/db.rs

@@ -18,6 +18,8 @@
 
 use crate::dns_log;
 use crate::smartdns::*;
+use crate::utils;
+use std::collections::HashMap;
 use std::error::Error;
 use std::fs;
 use std::sync::Mutex;
@@ -26,16 +28,31 @@ use rusqlite::{Connection, OpenFlags, Result};
 
 pub struct DB {
     conn: Mutex<Option<Connection>>,
+    version: i32,
 }
 
+#[derive(Debug, Clone)]
 pub struct ClientData {
     pub id: u32,
     pub client_ip: String,
 }
 
-pub struct ConfigData {
-    pub key: String,
-    pub value: String,
+#[derive(Debug, Clone)]
+pub struct ClientQueryCount {
+    pub client_ip: String,
+    pub count: u32,
+}
+
+#[derive(Debug, Clone)]
+pub struct DomainQueryCount {
+    pub domain: String,
+    pub count: u32,
+}
+
+#[derive(Debug, Clone)]
+pub struct HourlyQueryCount {
+    pub hour: String,
+    pub query_count: u32,
 }
 
 #[derive(Debug, Clone)]
@@ -47,8 +64,13 @@ pub struct DomainData {
     pub client: String,
     pub domain_group: String,
     pub reply_code: u16,
+    pub query_time: i32,
+    pub ping_time: f64,
+    pub is_blocked: bool,
+    pub is_cached: bool,
 }
 
+#[derive(Debug, Clone)]
 pub struct DomainListGetParam {
     pub id: Option<u64>,
     pub order: Option<String>,
@@ -61,6 +83,8 @@ pub struct DomainListGetParam {
     pub reply_code: Option<u16>,
     pub timestamp_before: Option<u64>,
     pub timestamp_after: Option<u64>,
+    pub is_blocked: Option<bool>,
+    pub is_cached: Option<bool>,
 }
 
 impl DomainListGetParam {
@@ -77,6 +101,8 @@ impl DomainListGetParam {
             reply_code: None,
             timestamp_before: None,
             timestamp_after: None,
+            is_blocked: None,
+            is_cached: None,
         }
     }
 }
@@ -85,10 +111,11 @@ impl DB {
     pub fn new() -> Self {
         DB {
             conn: Mutex::new(None),
+            version: 10000, /* x: major version, xx: minor version, xx: patch version */
         }
     }
 
-    fn init_db(&self, conn: &Connection) -> Result<()> {
+    fn create_table(&self, conn: &Connection) -> Result<()> {
         conn.execute(
             "CREATE TABLE IF NOT EXISTS domain (
                 id INTEGER PRIMARY KEY AUTOINCREMENT,
@@ -97,7 +124,11 @@ impl DB {
                 domain_type INTEGER NOT NULL,
                 client TEXT NOT NULL,
                 domain_group TEXT NOT NULL,
-                reply_code INTEGER NOT NULL
+                reply_code INTEGER NOT NULL,
+                query_time INTEGER NOT NULL,
+                ping_time REAL NOT NULL,
+                is_blocked INTEGER DEFAULT 0,
+                is_cached INTEGER DEFAULT 0
             )",
             [],
         )?;
@@ -106,7 +137,7 @@ impl DB {
             "
         CREATE TABLE IF NOT EXISTS client (
             id INTEGER PRIMARY KEY AUTOINCREMENT,
-            client_ip TEXT NOT NULL
+            client_ip TEXT NOT NULL UNIQUE
         )",
             [],
         )?;
@@ -119,6 +150,43 @@ impl DB {
             [],
         )?;
 
+        conn.execute(
+            "INSERT INTO schema_version (version) VALUES (?)",
+            [self.version],
+        )?;
+
+        Ok(())
+    }
+
+    fn migrate_db(&self, _conn: &Connection) -> Result<(), Box<dyn Error>> {
+        return Err(
+            "Currently Not Support Migrate Database, Please Backup DB File, And Restart Server."
+                .into(),
+        );
+    }
+
+    fn init_db(&self, conn: &Connection) -> Result<(), Box<dyn Error>> {
+        conn.execute(
+            "CREATE TABLE IF NOT EXISTS schema_version (
+                version INTEGER PRIMARY KEY
+            )",
+            [],
+        )?;
+
+        let current_version: i32 = conn
+            .query_row(
+                "SELECT version FROM schema_version ORDER BY version DESC LIMIT 1",
+                [],
+                |row| row.get(0),
+            )
+            .unwrap_or(self.version);
+
+        if current_version >= self.version {
+            self.create_table(conn)?;
+        } else {
+            self.migrate_db(conn)?;
+        }
+
         Ok(())
     }
 
@@ -136,7 +204,7 @@ impl DB {
             if let Err(e) = ret {
                 _ = ruconn.close();
                 fs::remove_file(path)?;
-                return Err(Box::new(e));
+                return Err(e);
             }
 
             *conn = Some(ruconn);
@@ -153,17 +221,16 @@ impl DB {
         Ok(())
     }
 
-    pub fn insert_config(&self, conf: &ConfigData) -> Result<(), Box<dyn Error>> {
+    pub fn set_config(&self, key: &str, value: &str) -> Result<(), Box<dyn Error>> {
         let conn = self.conn.lock().unwrap();
         if conn.as_ref().is_none() {
             return Ok(());
         }
 
         let conn = conn.as_ref().unwrap();
-        let mut stmt = conn
-            .prepare("INSERT OR REPLACE INTO config (key, value) VALUES (?1, ?2)")
-            .unwrap();
-        let ret = stmt.execute(&[&conf.key, &conf.value]);
+        let mut stmt =
+            conn.prepare("INSERT OR REPLACE INTO config (key, value) VALUES (?1, ?2)")?;
+        let ret = stmt.execute(&[&key, &value]);
 
         if let Err(e) = ret {
             return Err(Box::new(e));
@@ -172,6 +239,33 @@ impl DB {
         Ok(())
     }
 
+    pub fn get_config_list(&self) -> Result<HashMap<String, String>, Box<dyn Error>> {
+        let mut ret = HashMap::new();
+        let conn = self.conn.lock().unwrap();
+        if conn.as_ref().is_none() {
+            return Ok(ret);
+        }
+
+        let conn = conn.as_ref().unwrap();
+        let mut stmt = conn.prepare("SELECT key, value FROM config").unwrap();
+
+        let rows = stmt.query_map([], |row| {
+            let key: String = row.get(0)?;
+            let value: String = row.get(1)?;
+            Ok((key, value))
+        });
+
+        if let Ok(rows) = rows {
+            for row in rows {
+                if let Ok(row) = row {
+                    ret.insert(row.0, row.1);
+                }
+            }
+        }
+
+        Ok(ret)
+    }
+
     pub fn get_config(&self, key: &str) -> Result<Option<String>, Box<dyn Error>> {
         let conn = self.conn.lock().unwrap();
         if conn.as_ref().is_none() {
@@ -202,14 +296,21 @@ impl DB {
         }
 
         let conn = conn.as_ref().unwrap();
-        let mut stmt = conn.prepare("INSERT INTO domain (timestamp, domain, domain_type, client, domain_group, reply_code) VALUES (?1, ?2, ?3, ?4, ?5, ?6)").unwrap();
-        let ret = stmt.execute(&[
+        let mut stmt = conn.prepare(
+            "INSERT INTO domain \
+            (timestamp, domain, domain_type, client, domain_group, reply_code, query_time, ping_time, is_blocked, is_cached) \
+            VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10)")?;
+        let ret = stmt.execute(rusqlite::params![
             &data.timestamp.to_string(),
             &data.domain,
             &data.domain_type.to_string(),
             &data.client,
             &data.domain_group,
-            &data.reply_code.to_string(),
+            &data.reply_code,
+            &data.query_time,
+            &data.ping_time,
+            &(data.is_blocked as i32),
+            &(data.is_cached as i32)
         ]);
 
         if let Err(e) = ret {
@@ -275,6 +376,114 @@ impl DB {
         Ok(ret.unwrap() as u64)
     }
 
+    pub fn get_client_top_list(&self, count: u32) -> Result<Vec<ClientQueryCount>, Box<dyn Error>> {
+        let mut ret = Vec::new();
+        let conn = self.conn.lock().unwrap();
+        if conn.as_ref().is_none() {
+            return Ok(ret);
+        }
+
+        let conn = conn.as_ref().unwrap();
+        let mut stmt = conn.prepare(
+            "SELECT client, COUNT(*) FROM domain GROUP BY client ORDER BY COUNT(*) DESC LIMIT ?",
+        )?;
+        let rows = stmt.query_map([count.to_string()], |row| {
+            Ok(ClientQueryCount {
+                client_ip: row.get(0)?,
+                count: row.get(1)?,
+            })
+        });
+
+        if let Ok(rows) = rows {
+            for row in rows {
+                if let Ok(row) = row {
+                    ret.push(row);
+                }
+            }
+        }
+
+        Ok(ret)
+    }
+
+    pub fn get_hourly_query_count(
+        &self,
+        past_hours: u32,
+    ) -> Result<Vec<HourlyQueryCount>, Box<dyn Error>> {
+        let mut ret = Vec::new();
+        let conn = self.conn.lock().unwrap();
+        if conn.as_ref().is_none() {
+            return Ok(ret);
+        }
+
+        let seconds = 3600 * past_hours - utils::seconds_until_next_hour() as u32;
+
+        let conn = conn.as_ref().unwrap();
+        let mut stmt = conn.prepare(
+            "SELECT \
+                    strftime('%Y-%m-%d %H:00:00', datetime(timestamp / 1000, 'unixepoch', 'localtime')) AS hour, \
+                    COUNT(*) AS query_count \
+                 FROM \
+                    domain \
+                 WHERE \
+                    timestamp >= strftime('%s', 'now', 'utc') * 1000 - ? * 1000 \
+                 GROUP BY \
+                    hour \
+                 ORDER BY \
+                    hour DESC;\
+                 ",
+        )?;
+
+        let rows = stmt.query_map([seconds.to_string()], |row| {
+            Ok(HourlyQueryCount {
+                hour: row.get(0)?,
+                query_count: row.get(1)?,
+            })
+        });
+
+        if let Ok(rows) = rows {
+            for row in rows {
+                if let Ok(row) = row {
+                    ret.push(row);
+                }
+            }
+        }
+
+        Ok(ret)
+    }
+
+    pub fn get_domain_top_list(&self, count: u32) -> Result<Vec<DomainQueryCount>, Box<dyn Error>> {
+        let mut ret = Vec::new();
+        let conn = self.conn.lock().unwrap();
+        if conn.as_ref().is_none() {
+            return Ok(ret);
+        }
+
+        let conn = conn.as_ref().unwrap();
+        let mut stmt = conn.prepare(
+            "SELECT domain, COUNT(*) FROM domain GROUP BY domain ORDER BY COUNT(*) DESC LIMIT ?",
+        )?;
+        let rows = stmt.query_map([count.to_string()], |row| {
+            Ok(DomainQueryCount {
+                domain: row.get(0)?,
+                count: row.get(1)?,
+            })
+        });
+
+        if let Err(e) = rows {
+            return Err(Box::new(e));
+        }
+
+        if let Ok(rows) = rows {
+            for row in rows {
+                if let Ok(row) = row {
+                    ret.push(row);
+                }
+            }
+        }
+
+        Ok(ret)
+    }
+
     pub fn get_domain_list(
         &self,
         param: &DomainListGetParam,
@@ -353,6 +562,30 @@ impl DB {
             sql_param.push(v.to_string());
         }
 
+        if let Some(v) = &param.is_blocked {
+            if !sql_where.is_empty() {
+                sql_where.push_str(" AND ");
+            }
+
+            if *v {
+                sql_where.push_str("is_blocked = 1");
+            } else {
+                sql_where.push_str("is_blocked = 0");
+            }
+        }
+
+        if let Some(v) = &param.is_cached {
+            if !sql_where.is_empty() {
+                sql_where.push_str(" AND ");
+            }
+
+            if *v {
+                sql_where.push_str("is_cached = 1");
+            } else {
+                sql_where.push_str("is_cached = 0");
+            }
+        }
+
         if let Some(v) = &param.order {
             if v.eq_ignore_ascii_case("asc") {
                 sql_order.push_str(" ORDER BY id ASC");
@@ -366,7 +599,7 @@ impl DB {
         }
 
         let mut sql = String::new();
-        sql.push_str("SELECT id, timestamp, domain, domain_type, client, domain_group, reply_code FROM domain");
+        sql.push_str("SELECT id, timestamp, domain, domain_type, client, domain_group, reply_code, query_time, ping_time, is_blocked, is_cached FROM domain");
 
         if !sql_where.is_empty() {
             sql.push_str(" WHERE ");
@@ -387,7 +620,7 @@ impl DB {
             return Err("get_domain_list error".into());
         }
 
-        let mut stmt = stmt.unwrap();
+        let mut stmt = stmt?;
 
         let rows = stmt.query_map(rusqlite::params_from_iter(sql_param), |row| {
             Ok(DomainData {
@@ -398,9 +631,17 @@ impl DB {
                 client: row.get(4)?,
                 domain_group: row.get(5)?,
                 reply_code: row.get(6)?,
+                query_time: row.get(7)?,
+                ping_time: row.get(8)?,
+                is_blocked: row.get(9)?,
+                is_cached: row.get(10)?,
             })
         });
 
+        if let Err(e) = rows {
+            return Err(Box::new(e));
+        }
+
         if let Ok(rows) = rows {
             for row in rows {
                 if let Ok(row) = row {
@@ -412,6 +653,23 @@ impl DB {
         Ok(ret)
     }
 
+    pub fn insert_client(&self, client_ip: &str) -> Result<(), Box<dyn Error>> {
+        let conn = self.conn.lock().unwrap();
+        if conn.as_ref().is_none() {
+            return Ok(());
+        }
+
+        let conn = conn.as_ref().unwrap();
+        let mut stmt = conn.prepare("INSERT OR IGNORE INTO client (client_ip) VALUES (?1)")?;
+        let ret = stmt.execute(rusqlite::params![client_ip]);
+
+        if let Err(e) = ret {
+            return Err(Box::new(e));
+        }
+
+        Ok(())
+    }
+
     pub fn get_client_list(&self) -> Result<Vec<ClientData>, Box<dyn Error>> {
         let conn = self.conn.lock().unwrap();
         if conn.as_ref().is_none() {

+ 362 - 109
plugin/smartdns-ui/src/http_api_msg.rs

@@ -16,14 +16,16 @@
  * along with this program.  If not, see <http://www.gnu.org/licenses/>.
  */
 
+use crate::data_server::*;
 use crate::db::*;
 use crate::smartdns::LogLevel;
 use serde_json::json;
+use std::collections::HashMap;
 use std::error::Error;
 
 #[derive(Debug)]
 pub struct AuthUser {
-    pub user: String,
+    pub username: String,
     pub password: String,
 }
 
@@ -35,9 +37,9 @@ pub struct TokenResponse {
 
 pub fn api_msg_parse_auth(data: &str) -> Result<AuthUser, Box<dyn Error>> {
     let v: serde_json::Value = serde_json::from_str(data)?;
-    let user = v["user"].as_str();
-    if user.is_none() {
-        return Err("user not found".into());
+    let username = v["username"].as_str();
+    if username.is_none() {
+        return Err("username not found".into());
     }
     let password = v["password"].as_str();
     if password.is_none() {
@@ -45,14 +47,40 @@ pub fn api_msg_parse_auth(data: &str) -> Result<AuthUser, Box<dyn Error>> {
     }
 
     Ok(AuthUser {
-        user: user.unwrap().to_string(),
+        username: username.unwrap().to_string(),
         password: password.unwrap().to_string(),
     })
 }
 
+pub fn api_msg_parse_auth_password_change(data: &str) -> Result<(String, String), Box<dyn Error>> {
+    let v: serde_json::Value = serde_json::from_str(data)?;
+    let old_password = v["old_password"].as_str();
+    if old_password.is_none() {
+        return Err("old_password not found".into());
+    }
+    let password = v["password"].as_str();
+    if password.is_none() {
+        return Err("password not found".into());
+    }
+
+    Ok((
+        old_password.unwrap().to_string(),
+        password.unwrap().to_string(),
+    ))
+}
+
+pub fn api_msg_gen_auth_password_change(old_password: &str, password: &str) -> String {
+    let json_str = json!({
+        "old_password": old_password,
+        "password": password,
+    });
+
+    json_str.to_string()
+}
+
 pub fn api_msg_gen_auth_login(auth: &AuthUser) -> String {
     let json_str = json!({
-        "user": auth.user,
+        "username": auth.username,
         "password": auth.password,
     });
 
@@ -77,35 +105,62 @@ pub fn api_msg_gen_count(count: i64) -> String {
     json_str.to_string()
 }
 
-pub fn api_msg_parse_domain(data: &str) -> Result<DomainData, Box<dyn Error>> {
-    let v: serde_json::Value = serde_json::from_str(data)?;
-    let id = v["id"].as_u64();
+pub fn api_msg_parse_json_object_domain_value(
+    data: &serde_json::Value,
+) -> Result<DomainData, Box<dyn Error>> {
+    let id = data["id"].as_u64();
     if id.is_none() {
         return Err("id not found".into());
     }
-    let timestamp = v["timestamp"].as_u64();
+
+    let timestamp = data["timestamp"].as_u64();
     if timestamp.is_none() {
         return Err("timestamp not found".into());
     }
-    let domain = v["domain"].as_str();
+
+    let domain = data["domain"].as_str();
     if domain.is_none() {
         return Err("domain not found".into());
     }
-    let domain_type = v["domain-type"].as_u64();
+
+    let domain_type = data["domain_type"].as_u64();
     if domain_type.is_none() {
-        return Err("domain-type not found".into());
+        return Err("domain_type not found".into());
     }
-    let client = v["client"].as_str();
+
+    let client = data["client"].as_str();
     if client.is_none() {
         return Err("client not found".into());
     }
-    let domain_group = v["domain-group"].as_str();
+
+    let domain_group = data["domain_group"].as_str();
     if domain_group.is_none() {
-        return Err("domain-group not found".into());
+        return Err("domain_group not found".into());
     }
-    let reply_code = v["reply-code"].as_u64();
+
+    let reply_code = data["reply_code"].as_u64();
     if reply_code.is_none() {
-        return Err("reply-code not found".into());
+        return Err("reply_code not found".into());
+    }
+
+    let query_time = data["query_time"].as_i64();
+    if query_time.is_none() {
+        return Err("query_time not found".into());
+    }
+
+    let ping_time = data["ping_time"].as_f64();
+    if ping_time.is_none() {
+        return Err("ping_time not found".into());
+    }
+
+    let is_blocked = data["is_blocked"].as_bool();
+    if is_blocked.is_none() {
+        return Err("is_blocked not found".into());
+    }
+
+    let is_cached = data["is_cached"].as_bool();
+    if is_cached.is_none() {
+        return Err("is_cached not found".into());
     }
 
     Ok(DomainData {
@@ -116,94 +171,70 @@ pub fn api_msg_parse_domain(data: &str) -> Result<DomainData, Box<dyn Error>> {
         client: client.unwrap().to_string(),
         domain_group: domain_group.unwrap().to_string(),
         reply_code: reply_code.unwrap() as u16,
+        query_time: query_time.unwrap() as i32,
+        ping_time: ping_time.unwrap(),
+        is_blocked: is_blocked.unwrap(),
+        is_cached: is_cached.unwrap(),
     })
 }
 
-pub fn api_msg_gen_domain(domain: &DomainData) -> String {
-    let json_str = json!({
+pub fn api_msg_parse_domain(data: &str) -> Result<DomainData, Box<dyn Error>> {
+    let v: serde_json::Value = serde_json::from_str(data)?;
+    api_msg_parse_json_object_domain_value(&v)
+}
+
+pub fn api_msg_gen_json_object_domain(domain: &DomainData) -> serde_json::Value {
+    json!({
         "id": domain.id,
         "timestamp": domain.timestamp,
         "domain": domain.domain,
-        "domain-type": domain.domain_type,
+        "domain_type": domain.domain_type,
         "client": domain.client,
-        "domain-group": domain.domain_group,
-        "reply-code": domain.reply_code,
-    });
+        "domain_group": domain.domain_group,
+        "reply_code": domain.reply_code,
+        "query_time": domain.query_time,
+        "ping_time": domain.ping_time,
+        "is_blocked": domain.is_blocked,
+        "is_cached": domain.is_cached,
+    })
+}
 
+pub fn api_msg_gen_domain(domain: &DomainData) -> String {
+    let json_str = api_msg_gen_json_object_domain(domain);
     json_str.to_string()
 }
 
 pub fn api_msg_parse_domain_list(data: &str) -> Result<Vec<DomainData>, Box<dyn Error>> {
     let v: serde_json::Value = serde_json::from_str(data)?;
-    let list_count = v["list-count"].as_u64();
+    let list_count = v["list_count"].as_u64();
     if list_count.is_none() {
-        return Err("list-count not found".into());
+        return Err("list_count not found".into());
     }
     let list_count = list_count.unwrap();
     let mut domain_list = Vec::new();
     for i in 0..list_count {
-        let domain = &v["domian-list"][i as usize];
-        let id = domain["id"].as_u64();
-        if id.is_none() {
-            return Err("id not found".into());
-        }
-        let timestamp = domain["timestamp"].as_u64();
-        if timestamp.is_none() {
-            return Err("timestamp not found".into());
-        }
-        let domain_str = domain["domain"].as_str();
-        if domain_str.is_none() {
-            return Err("domain not found".into());
-        }
-        let domain_type = domain["domain-type"].as_u64();
-        if domain_type.is_none() {
-            return Err("domain-type not found".into());
-        }
-        let client = domain["client"].as_str();
-        if client.is_none() {
-            return Err("client not found".into());
-        }
-        let domain_group = domain["domain-group"].as_str();
-        if domain_group.is_none() {
-            return Err("domain-group not found".into());
-        }
-        let reply_code = domain["reply-code"].as_u64();
-        if reply_code.is_none() {
-            return Err("reply-code not found".into());
-        }
-
-        domain_list.push(DomainData {
-            id: id.unwrap(),
-            timestamp: timestamp.unwrap(),
-            domain: domain_str.unwrap().to_string(),
-            domain_type: domain_type.unwrap() as u32,
-            client: client.unwrap().to_string(),
-            domain_group: domain_group.unwrap().to_string(),
-            reply_code: reply_code.unwrap() as u16,
-        });
+        let domain_object = &v["domain_list"][i as usize];
+        let domain_data = api_msg_parse_json_object_domain_value(domain_object)?;
+        domain_list.push(domain_data);
     }
 
     Ok(domain_list)
 }
 
-pub fn api_msg_gen_domain_list(domain_list: Vec<DomainData>, total_page: u32) -> String {
+pub fn api_msg_gen_domain_list(
+    domain_list: &Vec<DomainData>,
+    total_page: u32,
+    total_count: u32,
+) -> String {
     let json_str = json!({
-        "list-count": domain_list.len(),
-        "total-page": total_page,
-        "domian-list":
+        "list_count": domain_list.len(),
+        "total_page": total_page,
+        "total_count": total_count,
+        "domain_list":
             domain_list
                 .iter()
                 .map(|x| {
-                    let s = json!({
-                        "id": x.id,
-                        "timestamp": x.timestamp,
-                        "domain": x.domain,
-                        "domain-type": x.domain_type,
-                        "client": x.client,
-                        "domain-group": x.domain_group,
-                        "reply-code": x.reply_code,
-                    });
-                    s
+                    api_msg_gen_json_object_domain(x)
                 })
                 .collect::<Vec<serde_json::Value>>()
 
@@ -212,16 +243,45 @@ pub fn api_msg_gen_domain_list(domain_list: Vec<DomainData>, total_page: u32) ->
     json_str.to_string()
 }
 
-pub fn api_msg_gen_client_list(client_list: Vec<ClientData>) -> String {
+pub fn api_msg_parse_client_list(data: &str) -> Result<Vec<ClientData>, Box<dyn Error>> {
+    let v: serde_json::Value = serde_json::from_str(data)?;
+    let list_count = v["list_count"].as_u64();
+    if list_count.is_none() {
+        return Err("list_count not found".into());
+    }
+    let list_count = list_count.unwrap();
+    let mut client_list = Vec::new();
+    for i in 0..list_count {
+        let client_object = &v["client-list"][i as usize];
+        let id = client_object["id"].as_u64();
+        if id.is_none() {
+            return Err("id not found".into());
+        }
+
+        let client_ip = client_object["client_ip"].as_str();
+        if client_ip.is_none() {
+            return Err("client_ip not found".into());
+        }
+
+        client_list.push(ClientData {
+            id: id.unwrap() as u32,
+            client_ip: client_ip.unwrap().to_string(),
+        });
+    }
+
+    Ok(client_list)
+}
+
+pub fn api_msg_gen_client_list(client_list: &Vec<ClientData>) -> String {
     let json_str = json!({
-        "list-count": client_list.len(),
+        "list_count": client_list.len(),
         "client-list":
             client_list
                 .iter()
                 .map(|x| {
                     let s = json!({
                         "id": x.id,
-                        "client-ip": x.client_ip,
+                        "client_ip": x.client_ip,
                     });
                     s
                 })
@@ -235,7 +295,8 @@ pub fn api_msg_gen_client_list(client_list: Vec<ClientData>) -> String {
 pub fn api_msg_auth_token(token: &str, expired: &str) -> String {
     let json_str = json!({
         "token": token,
-        "expires-in": expired,
+        "token_type": "Bearer",
+        "expires_in": expired,
     });
 
     json_str.to_string()
@@ -247,9 +308,9 @@ pub fn api_msg_parse_auth_token(data: &str) -> Result<TokenResponse, Box<dyn Err
     if token.is_none() {
         return Err("token not found".into());
     }
-    let expired = v["expires-in"].as_str();
+    let expired = v["expires_in"].as_str();
     if expired.is_none() {
-        return Err("expires-in not found".into());
+        return Err("expires_in not found".into());
     }
 
     Ok(TokenResponse {
@@ -260,7 +321,7 @@ pub fn api_msg_parse_auth_token(data: &str) -> Result<TokenResponse, Box<dyn Err
 
 pub fn api_msg_gen_cache_number(cache_number: i32) -> String {
     let json_str = json!({
-        "cache-number": cache_number,
+        "cache_number": cache_number,
     });
 
     json_str.to_string()
@@ -268,9 +329,9 @@ pub fn api_msg_gen_cache_number(cache_number: i32) -> String {
 
 pub fn api_msg_parse_cache_number(data: &str) -> Result<i32, Box<dyn Error>> {
     let v: serde_json::Value = serde_json::from_str(data)?;
-    let cache_number = v["cache-number"].as_i64();
+    let cache_number = v["cache_number"].as_i64();
     if cache_number.is_none() {
-        return Err("cache-number not found".into());
+        return Err("cache_number not found".into());
     }
 
     Ok(cache_number.unwrap() as i32)
@@ -296,34 +357,23 @@ pub fn api_msg_parse_error(data: &str) -> Result<String, Box<dyn Error>> {
 
 pub fn api_msg_parse_loglevel(data: &str) -> Result<LogLevel, Box<dyn Error>> {
     let v: serde_json::Value = serde_json::from_str(data)?;
-    let loglevel = v["log-level"].as_str();
+    let loglevel = v["log_level"].as_str();
     if loglevel.is_none() {
         return Err("loglevel not found".into());
     }
 
-    let loglevel = loglevel.unwrap();
-    match loglevel {
-        "debug" => Ok(LogLevel::DEBUG),
-        "info" => Ok(LogLevel::INFO),
-        "notice" => Ok(LogLevel::NOTICE),
-        "warn" => Ok(LogLevel::WARN),
-        "error" => Ok(LogLevel::ERROR),
-        "fatal" => Ok(LogLevel::FATAL),
-        _ => Err("loglevel not found".into()),
+    let ret = loglevel.unwrap().try_into();
+    if ret.is_err() {
+        return Err("log level is invalid".into());
     }
+
+    Ok(ret.unwrap())
 }
 
 pub fn api_msg_gen_loglevel(loglevel: LogLevel) -> String {
-    let loglevel = match loglevel {
-        LogLevel::DEBUG => "debug",
-        LogLevel::INFO => "info",
-        LogLevel::NOTICE => "notice",
-        LogLevel::WARN => "warn",
-        LogLevel::ERROR => "error",
-        LogLevel::FATAL => "fatal",
-    };
+    let loglevel_str = loglevel.to_string();
     let json_str = json!({
-        "log-level": loglevel,
+        "log_level": loglevel_str,
     });
 
     json_str.to_string()
@@ -332,7 +382,7 @@ pub fn api_msg_gen_loglevel(loglevel: LogLevel) -> String {
 pub fn api_msg_gen_version(smartdns_version: &str, ui_version: &str) -> String {
     let json_str = json!({
         "smartdns": smartdns_version,
-        "smartdns-ui": ui_version,
+        "smartdns_ui": ui_version,
     });
 
     json_str.to_string()
@@ -344,10 +394,213 @@ pub fn api_msg_parse_version(data: &str) -> Result<(String, String), Box<dyn Err
     if smartdns.is_none() {
         return Err("smartdns not found".into());
     }
-    let ui = v["smartdns-ui"].as_str();
+    let ui = v["smartdns_ui"].as_str();
     if ui.is_none() {
         return Err("ui not found".into());
     }
 
     Ok((smartdns.unwrap().to_string(), ui.unwrap().to_string()))
+}
+
+pub fn api_msg_gen_key_value(data: &HashMap<String, String>) -> String {
+    let mut json_map = serde_json::Map::new();
+
+    for (key, value) in data {
+        json_map.insert(key.clone(), serde_json::Value::String(value.clone()));
+    }
+
+    serde_json::Value::Object(json_map).to_string()
+}
+
+pub fn api_msg_parse_key_value(data: &str) -> Result<HashMap<String, String>, Box<dyn Error>> {
+    let v: serde_json::Value = serde_json::from_str(data)?;
+    let mut conf_map = HashMap::new();
+
+    if let serde_json::Value::Object(map) = v {
+        for (key, value) in map {
+            if let serde_json::Value::String(value_str) = value {
+                conf_map.insert(key, value_str);
+            }
+        }
+    }
+
+    Ok(conf_map)
+}
+
+pub fn api_msg_gen_top_client_list(client_list: &Vec<ClientQueryCount>) -> String {
+    let json_str = json!({
+        "client_top_list":
+            client_list
+                .iter()
+                .map(|x| {
+                    let s = json!({
+                        "client_ip": x.client_ip,
+                        "query_count": x.count,
+                    });
+                    s
+                })
+                .collect::<Vec<serde_json::Value>>()
+    });
+
+    json_str.to_string()
+}
+
+pub fn api_msg_parse_top_client_list(data: &str) -> Result<Vec<ClientQueryCount>, Box<dyn Error>> {
+    let v: serde_json::Value = serde_json::from_str(data)?;
+    let mut client_list = Vec::new();
+    let top_list = v["client_top_list"].as_array();
+    if top_list.is_none() {
+        return Err("list_count not found".into());
+    }
+
+    for item in top_list.unwrap() {
+        let client_ip = item["client_ip"].as_str();
+        if client_ip.is_none() {
+            return Err("client_ip not found".into());
+        }
+
+        let query_count = item["query_count"].as_u64();
+        if query_count.is_none() {
+            return Err("query_count not found".into());
+        }
+
+        client_list.push(ClientQueryCount {
+            client_ip: client_ip.unwrap().to_string(),
+            count: query_count.unwrap() as u32,
+        });
+    }
+
+    Ok(client_list)
+}
+
+pub fn api_msg_gen_top_domain_list(domain_list: &Vec<DomainQueryCount>) -> String {
+    let json_str = json!({
+        "domain_top_list":
+            domain_list
+                .iter()
+                .map(|x| {
+                    let s = json!({
+                        "domain": x.domain,
+                        "query_count": x.count,
+                    });
+                    s
+                })
+                .collect::<Vec<serde_json::Value>>()
+    });
+
+    json_str.to_string()
+}
+
+pub fn api_msg_parse_top_domain_list(data: &str) -> Result<Vec<DomainQueryCount>, Box<dyn Error>> {
+    let v: serde_json::Value = serde_json::from_str(data)?;
+    let mut domain_list = Vec::new();
+    let top_list = v["domain_top_list"].as_array();
+    if top_list.is_none() {
+        return Err("list_count not found".into());
+    }
+
+    for item in top_list.unwrap() {
+        let domain = item["domain"].as_str();
+        if domain.is_none() {
+            return Err("domain not found".into());
+        }
+
+        let query_count = item["query_count"].as_u64();
+        if query_count.is_none() {
+            return Err("query_count not found".into());
+        }
+
+        domain_list.push(DomainQueryCount {
+            domain: domain.unwrap().to_string(),
+            count: query_count.unwrap() as u32,
+        });
+    }
+
+    Ok(domain_list)
+}
+
+pub fn api_msg_gen_stats_overview(data: &OverviewData) -> String {
+    let json_str = json!({
+        "total_query_count": data.total_query_count,
+        "block_query_count": data.block_query_count,
+        "avg_query_time": data.avg_query_time,
+        "cache_hit_rate": data.cache_hit_rate,
+    });
+
+    json_str.to_string()
+}
+
+pub fn api_msg_parse_stats_overview(data: &str) -> Result<OverviewData, Box<dyn Error>> {
+    let v: serde_json::Value = serde_json::from_str(data)?;
+    let total_query_count = v["total_query_count"].as_u64();
+    if total_query_count.is_none() {
+        return Err("total_query_count not found".into());
+    }
+
+    let block_query_count = v["block_query_count"].as_u64();
+    if block_query_count.is_none() {
+        return Err("block_query_count not found".into());
+    }
+
+    let avg_query_time = v["avg_query_time"].as_f64();
+    if avg_query_time.is_none() {
+        return Err("avg_query_time not found".into());
+    }
+
+    let cache_hit_rate = v["cache_hit_rate"].as_f64();
+    if cache_hit_rate.is_none() {
+        return Err("cache_hit_rate not found".into());
+    }
+
+    Ok(OverviewData {
+        total_query_count: total_query_count.unwrap() as u64,
+        block_query_count: block_query_count.unwrap() as u64,
+        avg_query_time: avg_query_time.unwrap(),
+        cache_hit_rate: cache_hit_rate.unwrap(),
+    })
+}
+
+pub fn api_msg_gen_hourly_query_count(data: &Vec<HourlyQueryCount>) -> String {
+    let json_str = json!({
+        "hourly_query_count":
+            data
+                .iter()
+                .map(|x| {
+                    let s = json!({
+                        "hour": x.hour,
+                        "query_count": x.query_count,
+                    });
+                    s
+                })
+                .collect::<Vec<serde_json::Value>>()
+    });
+    json_str.to_string()
+}
+
+pub fn api_msg_parse_hourly_query_count(data: &str) -> Result<Vec<HourlyQueryCount>, Box<dyn Error>> {
+    let v: serde_json::Value = serde_json::from_str(data)?;
+    let mut hourly_query_count = Vec::new();
+    let hourly_list = v["hourly_query_count"].as_array();
+    if hourly_list.is_none() {
+        return Err("hourly_query_count not found".into());
+    }
+
+    for item in hourly_list.unwrap() {
+        let hour = item["hour"].as_str();
+        if hour.is_none() {
+            return Err("hour not found".into());
+        }
+
+        let query_count = item["query_count"].as_u64();
+        if query_count.is_none() {
+            return Err("query_count not found".into());
+        }
+
+        hourly_query_count.push(HourlyQueryCount {
+            hour: hour.unwrap().to_string(),
+            query_count: query_count.unwrap() as u32,
+        });
+    }
+
+    Ok(hourly_query_count)
 }

+ 305 - 60
plugin/smartdns-ui/src/http_server.rs

@@ -16,57 +16,92 @@
  * along with this program.  If not, see <http://www.gnu.org/licenses/>.
  */
 
+extern crate cfg_if;
+
 use crate::data_server::*;
 use crate::dns_log;
 use crate::http_api_msg::*;
 use crate::http_jwt::*;
 use crate::http_server_api::*;
-use crate::smartdns;
 use crate::smartdns::*;
 
 use bytes::Bytes;
 use http_body_util::Full;
 use hyper::body;
+use hyper::header::HeaderValue;
 use hyper::server::conn::http1;
 use hyper::StatusCode;
 use hyper::{service::service_fn, Request, Response};
 use hyper_util::rt::TokioIo;
-use rustls_pemfile;
 use std::convert::Infallible;
 use std::error::Error;
 use std::fs::Metadata;
-use std::io::BufReader;
 use std::net::SocketAddr;
 use std::path::PathBuf;
 use std::path::{Component, Path};
+use std::sync::MutexGuard;
 use std::sync::{Arc, Mutex};
 use std::thread;
 use std::time::Duration;
+use std::time::Instant;
 use tokio::fs::read;
 use tokio::net::TcpListener;
 use tokio::net::TcpStream;
 use tokio::sync::mpsc;
-use tokio_rustls::{rustls, TlsAcceptor};
+cfg_if::cfg_if! {
+    if #[cfg(feature = "https")] {
+        use rustls_pemfile;
+        use std::io::BufReader;
+        use tokio_rustls::{rustls, TlsAcceptor};
+    }
+}
+
+const HTTP_SERVER_DEFAULT_PASSWORD: &str = "password";
+const HTTP_SERVER_DEFAULT_USERNAME: &str = "admin";
+const HTTP_SERVER_DEFAULT_WWW_ROOT: &str = "/usr/local/shared/smartdns/www";
+const HTTP_SERVER_DEFAULT_IP: &str = "http://0.0.0.0:6080";
 
 #[derive(Clone)]
 pub struct HttpServerConfig {
     pub http_ip: String,
     pub http_root: String,
-    pub user: String,
+    pub username: String,
     pub password: String,
     pub token_expired_time: u32,
+    pub enable_cors: bool,
 }
 
 impl HttpServerConfig {
     pub fn new() -> Self {
         HttpServerConfig {
-            http_ip: "http://0.0.0.0:8080".to_string(),
-            http_root: "/usr/local/shared/smartdns/wwww".to_string(),
-            user: "admin".to_string(),
-            password: "password".to_string(),
+            http_ip: HTTP_SERVER_DEFAULT_IP.to_string(),
+            http_root: HTTP_SERVER_DEFAULT_WWW_ROOT.to_string(),
+            username: HTTP_SERVER_DEFAULT_USERNAME.to_string(),
+            password: HTTP_SERVER_DEFAULT_PASSWORD.to_string(),
             token_expired_time: 600,
+            enable_cors: false,
         }
     }
+
+    pub fn load_config(&mut self, data_server: Arc<DataServer>) -> Result<(), Box<dyn Error>> {
+        if let Some(password) = data_server.get_server_config("smartdns-ui.password") {
+            self.password = password;
+        }
+
+        if let Some(username) = data_server.get_server_config("smartdns-ui.username") {
+            self.username = username;
+        }
+
+        if let Some(enable_cors) = data_server.get_server_config("smartdns-ui.cors-enable") {
+            if enable_cors.eq_ignore_ascii_case("yes") || enable_cors.eq_ignore_ascii_case("true") {
+                self.enable_cors = true;
+            } else {
+                self.enable_cors = false;
+            }
+        }
+
+        Ok(())
+    }
 }
 
 pub struct HttpServerControl {
@@ -150,10 +185,11 @@ pub struct HttpServer {
     conf: Mutex<HttpServerConfig>,
     notify_tx: Option<mpsc::Sender<()>>,
     notify_rx: Mutex<Option<mpsc::Receiver<()>>>,
-    data_server: Mutex<Arc<DataServer>>,
+    data_server: Mutex<Option<Arc<DataServer>>>,
     api: API,
     local_addr: Mutex<Option<SocketAddr>>,
     mime_map: std::collections::HashMap<&'static str, &'static str>,
+    login_attempts: Mutex<(i32, Instant)>,
 }
 
 #[allow(dead_code)]
@@ -163,9 +199,10 @@ impl HttpServer {
             conf: Mutex::new(HttpServerConfig::new()),
             notify_tx: None,
             notify_rx: Mutex::new(None),
-            data_server: Mutex::new(Arc::new(DataServer::new())),
+            data_server: Mutex::new(None),
             api: API::new(),
             local_addr: Mutex::new(None),
+            login_attempts: Mutex::new((0, Instant::now())),
             mime_map: std::collections::HashMap::from([
                 ("htm", "text/html"),
                 ("html", "text/html"),
@@ -199,6 +236,42 @@ impl HttpServer {
         conf.clone()
     }
 
+    pub fn get_conf_mut(&self) -> MutexGuard<HttpServerConfig> {
+        self.conf.lock().unwrap()
+    }
+
+    pub fn login_attempts_reset(&self) {
+        let mut attempts = self.login_attempts.lock().unwrap();
+        attempts.0 = 0;
+        attempts.1 = Instant::now();
+    }
+
+    pub fn login_attempts_check(&self) -> bool {
+        let mut attempts = self.login_attempts.lock().unwrap();
+
+        if attempts.0 == 0 {
+            attempts.1 = Instant::now();
+        }
+
+        attempts.0 += 1;
+
+        if attempts.0 > 5 {
+            let now = Instant::now();
+            let duration = now.duration_since(attempts.1);
+            if duration.as_secs() < 60 {
+                if duration.as_secs() < 30 {
+                    attempts.1 = Instant::now();
+                }
+                return false;
+            }
+
+            attempts.0 = 0;
+            attempts.1 = now;
+        }
+
+        true
+    }
+
     pub fn get_local_addr(&self) -> Option<SocketAddr> {
         let local_addr = self.local_addr.lock().unwrap();
         local_addr.clone()
@@ -218,29 +291,122 @@ impl HttpServer {
 
     fn set_data_server(&self, data_server: Arc<DataServer>) -> Result<(), Box<dyn Error>> {
         let mut _data_server = self.data_server.lock().unwrap();
-        *_data_server = data_server.clone();
+        *_data_server = Some(data_server);
         Ok(())
     }
 
     pub fn get_data_server(&self) -> Arc<DataServer> {
         let data_server = self.data_server.lock().unwrap();
+        let data_server = data_server.as_ref().unwrap();
         Arc::clone(&*data_server)
     }
 
-    pub fn auth_token_is_valid(&self, req: &Request<body::Incoming>) -> bool {
-        let token = req.headers().get("Authorization");
+    pub fn get_token_from_header(
+        req: &Request<body::Incoming>,
+    ) -> Result<Option<String>, Box<dyn Error>> {
+        let token: String;
+        let header_auth = req.headers().get("Authorization");
+        if header_auth.is_none() {
+            let cookie = req.headers().get("Cookie");
+            if cookie.is_none() {
+                return Ok(None);
+            }
+
+            let cookie = cookie.unwrap().to_str();
+            if let Err(_) = cookie {
+                return Ok(None);
+            }
+
+            let cookies = cookie.unwrap().split(';').collect::<Vec<&str>>();
+            let token_cookie = cookies.iter().find(|c| c.trim().starts_with("token="));
+            if token_cookie.is_none() {
+                return Ok(None);
+            }
+
+            let token_cookie = token_cookie.unwrap().trim().strip_prefix("token=");
+            if token_cookie.is_none() {
+                return Ok(None);
+            }
+
+            let data = urlencoding::decode(token_cookie.unwrap());
+            if let Err(_) = data {
+                return Ok(None);
+            }
+
+            let data = data.unwrap();
+            token = data.to_string();
+        } else {
+            let auth = header_auth.unwrap().to_str();
+            if let Err(_) = auth {
+                return Ok(None);
+            }
+
+            token = auth.unwrap().to_string();
+        }
+
+        let token_type = "Bearer";
+        if !token.starts_with(token_type) {
+            return Err("Invalid authorization type".into());
+        }
+
+        let token = token.strip_prefix(token_type).unwrap().trim();
+
+        Ok(Some(token.to_string()))
+    }
+
+    pub fn auth_token_is_valid(
+        &self,
+        req: &Request<body::Incoming>,
+    ) -> Result<bool, Box<dyn Error>> {
+        let token = HttpServer::get_token_from_header(req)?;
+
         if token.is_none() {
-            return false;
+            return Ok(false);
         }
 
-        let token = token.unwrap().to_str().unwrap();
+        let token = token.unwrap();
         let conf = self.conf.lock().unwrap();
-        let jwt = Jwt::new(&conf.user, &conf.password, "", conf.token_expired_time);
+        let jwt = Jwt::new(&conf.username, &conf.password, "", conf.token_expired_time);
+        if !jwt.is_token_valid(token.as_str()) {
+            return Ok(false);
+        }
+        Ok(true)
+    }
+
+    fn server_add_cors_header(
+        &self,
+        origin: &Option<hyper::header::HeaderValue>,
+        response: &mut Response<Full<Bytes>>,
+    ) {
+        if self.get_conf().enable_cors {
+            if let Some(origin) = origin {
+                response
+                    .headers_mut()
+                    .insert("Access-Control-Allow-Origin", origin.clone());
+            } else {
+                response
+                    .headers_mut()
+                    .insert("Access-Control-Allow-Origin", "*".parse().unwrap());
+            }
+
+            response.headers_mut().insert(
+                "Access-Control-Allow-Methods",
+                "GET, POST, PUT, DELETE, OPTIONS, PATCH".parse().unwrap(),
+            );
+
+            response.headers_mut().insert(
+                "Access-Control-Allow-Headers",
+                "Content-Type, Authorization, Set-Cookie".parse().unwrap(),
+            );
+
+            response
+                .headers_mut()
+                .insert("Access-Control-Allow-Credentials", "true".parse().unwrap());
 
-        if !jwt.is_token_valid(token) {
-            return false;
+            response
+                .headers_mut()
+                .insert("Access-Control-Max-Age", "600".parse().unwrap());
         }
-        true
     }
 
     async fn server_handle_http_api_request(
@@ -248,6 +414,12 @@ impl HttpServer {
         req: Request<body::Incoming>,
         _path: PathBuf,
     ) -> Result<Response<Full<Bytes>>, Infallible> {
+        let mut origin: Option<HeaderValue> = None;
+
+        if let Some(o) = req.headers().get("Origin") {
+            origin = Some(o.clone());
+        }
+
         let error_response = |code: StatusCode, msg: &str| {
             let bytes = Bytes::from(api_msg_error(msg));
             let mut response = Response::new(Full::new(bytes));
@@ -258,14 +430,36 @@ impl HttpServer {
                 .headers_mut()
                 .insert("Cache-Control", "no-cache".parse().unwrap());
             *response.status_mut() = code;
+
+            this.server_add_cors_header(&origin, &mut response);
             Ok(response)
         };
 
         dns_log!(LogLevel::DEBUG, "api request: {:?}", req.uri());
+
+        if req.method() == hyper::Method::OPTIONS {
+            let mut response = Response::new(Full::new(Bytes::from("")));
+            response
+                .headers_mut()
+                .insert("Content-Type", "application/json".parse().unwrap());
+            response
+                .headers_mut()
+                .insert("Cache-Control", "no-cache".parse().unwrap());
+            this.server_add_cors_header(&origin, &mut response);
+            return Ok(response);
+        }
+
         match this.api.get_router(req.method(), req.uri().path()) {
             Some((router, param)) => {
-                if router.auth && !this.auth_token_is_valid(&req) {
-                    return error_response(StatusCode::UNAUTHORIZED, "Please login.");
+                if router.auth {
+                    let is_token_valid = this.auth_token_is_valid(&req);
+                    if let Err(e) = is_token_valid {
+                        return error_response(StatusCode::BAD_REQUEST, e.to_string().as_str());
+                    }
+
+                    if !is_token_valid.unwrap() {
+                        return error_response(StatusCode::UNAUTHORIZED, "Please login.");
+                    }
                 }
 
                 if router.method != req.method() {
@@ -285,6 +479,9 @@ impl HttpServer {
                             resp.headers_mut()
                                 .insert("Cache-Control", "no-cache".parse().unwrap());
                         }
+
+                        this.server_add_cors_header(&origin, &mut resp);
+
                         Ok(resp)
                     }
                     Err(e) => Ok(e.to_response()),
@@ -308,6 +505,7 @@ impl HttpServer {
         req: Request<body::Incoming>,
     ) -> Result<Response<Full<Bytes>>, Infallible> {
         let path = PathBuf::from(req.uri().path());
+        let mut is_404 = false;
         let www_root = {
             let conf = this.conf.lock().unwrap();
             PathBuf::from(conf.http_root.clone())
@@ -332,16 +530,34 @@ impl HttpServer {
 
         dns_log!(LogLevel::DEBUG, "page request: {:?}", req.uri());
         let mut filepath = www_root.join(path);
-        let mut path = req.uri().path().to_string();
+        let uri_path = req.uri().path().to_string();
+        let mut path = uri_path.clone();
+
+        if !filepath.exists() || filepath.is_dir() {
+            let suffix = filepath.extension();
+            if suffix.is_none() && !uri_path.ends_with("/") {
+                let check_filepath = filepath.with_extension("html");
+                if check_filepath.exists() {
+                    filepath = check_filepath;
+                    path = format!("{}.html", uri_path);
+                }
+            }
 
-        if !filepath.exists() {
-            filepath = www_root.join("index.html");
-            path = format!("{}/index.html", path);
-        }
+            if filepath.is_dir() {
+                filepath = filepath.join("index.html");
+                path = format!("{}/index.html", uri_path);
+            }
 
-        if filepath.is_dir() {
-            filepath = filepath.join("index.html");
-            path = format!("{}/index.html", path);
+            if !filepath.exists() {
+                filepath = www_root.join("404.html");
+                path = "/404.html".to_string();
+                if !filepath.exists() {
+                    filepath = www_root.join("index.html");
+                    path = format!("/index.html");
+                } else {
+                    is_404 = true;
+                }
+            }
         }
 
         let mut file_meta: Option<Metadata> = None;
@@ -391,7 +607,13 @@ impl HttpServer {
                     let etag = fn_get_etag(&file_meta.as_ref().unwrap());
                     header.insert("ETag", etag.parse().unwrap());
                 }
-                *response.status_mut() = StatusCode::OK;
+
+                if is_404 {
+                    *response.status_mut() = StatusCode::NOT_FOUND;
+                } else {
+                    *response.status_mut() = StatusCode::OK;
+                }
+
                 Ok(response)
             }
             Err(_) => {
@@ -420,6 +642,7 @@ impl HttpServer {
         });
     }
 
+    #[cfg(feature = "https")]
     async fn https_server_handle_conn(
         this: Arc<HttpServer>,
         stream: tokio_rustls::server::TlsStream<TcpStream>,
@@ -440,6 +663,7 @@ impl HttpServer {
         });
     }
 
+    #[cfg(feature = "https")]
     async fn handle_tls_accept(this: Arc<HttpServer>, acceptor: TlsAcceptor, stream: TcpStream) {
         tokio::task::spawn(async move {
             let acceptor_future = acceptor.accept(stream);
@@ -477,30 +701,44 @@ impl HttpServer {
         }
 
         let url = addr.parse::<url::Url>()?;
-        let mut acceptor = None;
-        if url.scheme() == "https" {
-            let cert_info = smartdns::Plugin::smartdns_get_cert()?;
-
-            dns_log!(
-                LogLevel::DEBUG,
-                "cert: {}, key: {}",
-                cert_info.cert,
-                cert_info.key
-            );
-            let cert_chain: Result<Vec<rustls::pki_types::CertificateDer<'_>>, _> =
-                rustls_pemfile::certs(&mut BufReader::new(std::fs::File::open(cert_info.cert)?))
-                    .collect();
-            let cert_chain = cert_chain.unwrap_or_else(|_| Vec::new());
-            let key_der = rustls_pemfile::private_key(&mut BufReader::new(std::fs::File::open(
-                cert_info.key,
-            )?))?
-            .unwrap();
-
-            let config = rustls::ServerConfig::builder()
-                .with_no_client_auth()
-                .with_single_cert(cert_chain, key_der)?;
-            acceptor = Some(TlsAcceptor::from(Arc::new(config)));
+
+        cfg_if::cfg_if! {
+            if #[cfg(feature = "https")]
+            {
+                let mut acceptor = None;
+                if url.scheme() == "https" {
+                    #[cfg(feature = "https")]
+                    let cert_info = Plugin::smartdns_get_cert()?;
+
+                    dns_log!(
+                        LogLevel::DEBUG,
+                        "cert: {}, key: {}",
+                        cert_info.cert,
+                        cert_info.key
+                    );
+                    let cert_chain: Result<Vec<rustls::pki_types::CertificateDer<'_>>, _> =
+                        rustls_pemfile::certs(&mut BufReader::new(std::fs::File::open(
+                            cert_info.cert,
+                        )?))
+                        .collect();
+                    let cert_chain = cert_chain.unwrap_or_else(|_| Vec::new());
+                    let key_der = rustls_pemfile::private_key(&mut BufReader::new(
+                        std::fs::File::open(cert_info.key)?,
+                    ))?
+                    .unwrap();
+
+                    let config = rustls::ServerConfig::builder()
+                        .with_no_client_auth()
+                        .with_single_cert(cert_chain, key_der)?;
+                    acceptor = Some(TlsAcceptor::from(Arc::new(config)));
+                }
+            } else {
+                if url.scheme() == "https" {
+                    return Err("https is not supported.".into());
+                }
+            }
         }
+
         let host = url.host_str().unwrap_or("0.0.0.0");
         let port = url.port().unwrap_or(80);
         let sock_addr = format!("{}:{}", host, port).parse::<SocketAddr>()?;
@@ -527,12 +765,19 @@ impl HttpServer {
                             ka = ka.with_interval(Duration::from_secs(10));
                             sock_ref.set_tcp_keepalive(&ka)?;
                             sock_ref.set_nonblocking(true)?;
-                            if acceptor.is_some() {
-                                let acceptor = acceptor.clone().unwrap().clone();
-                                let this_clone = this.clone();
-                                HttpServer::handle_tls_accept(this_clone, acceptor, stream).await;
-                            } else {
-                                HttpServer::http_server_handle_conn(this.clone(), stream).await;
+                            cfg_if::cfg_if! {
+                                if #[cfg(feature = "https")]
+                                {
+                                    if acceptor.is_some() {
+                                        let acceptor = acceptor.clone().unwrap().clone();
+                                        let this_clone = this.clone();
+                                        HttpServer::handle_tls_accept(this_clone, acceptor, stream).await;
+                                    } else {
+                                        HttpServer::http_server_handle_conn(this.clone(), stream).await;
+                                    }
+                                } else  {
+                                    HttpServer::http_server_handle_conn(this.clone(), stream).await;
+                                }
                             }
                         }
                         Err(e) => {

+ 279 - 26
plugin/smartdns-ui/src/http_server_api.rs

@@ -21,7 +21,7 @@ use crate::http_api_msg::*;
 use crate::http_error::*;
 use crate::http_jwt::*;
 use crate::http_server::*;
-use crate::http_server_log_stream;
+use crate::http_server_stream;
 use crate::smartdns;
 use crate::smartdns::*;
 use crate::Plugin;
@@ -37,6 +37,9 @@ use std::pin::Pin;
 use std::sync::Arc;
 use url::form_urlencoded;
 
+const PASSWORD_CONFIG_KEY: &str = "smartdns-ui.password";
+const REST_API_PATH: &str = "/api";
+
 type APIRouteFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
 type APIRouterFun = fn(
     this: Arc<HttpServer>,
@@ -73,6 +76,8 @@ impl API {
         api.register(Method::PUT, "/api/cache/flush",  true, APIRoute!(API::api_cache_flush));
         api.register(Method::GET, "/api/cache/count",  true, APIRoute!(API::api_cache_count));
         api.register(Method::POST, "/api/auth/login",  false, APIRoute!(API::api_auth_login));
+        api.register(Method::POST, "/api/auth/logout",  false, APIRoute!(API::api_auth_logout));
+        api.register(Method::PUT, "/api/auth/password",  false, APIRoute!(API::api_auth_change_password));
         api.register(Method::POST, "/api/auth/refresh",  true, APIRoute!(API::api_auth_refresh));
         api.register(Method::GET, "/api/domain",  true, APIRoute!(API::api_domain_get_list));
         api.register(Method::DELETE, "/api/domain",  true, APIRoute!(API::api_domain_delete_list));
@@ -84,6 +89,13 @@ impl API {
         api.register(Method::PUT, "/api/log/level", true, APIRoute!(API::api_log_set_level));
         api.register(Method::GET, "/api/log/level", true, APIRoute!(API::api_log_get_level));
         api.register(Method::GET, "/api/server/version", false, APIRoute!(API::api_server_version));
+        api.register(Method::GET, "/api/config/settings", true, APIRoute!(API::api_config_get_settings));
+        api.register(Method::PUT, "/api/config/settings", true, APIRoute!(API::api_config_set_settings));
+        api.register(Method::GET, "/api/stats/top/client", true, APIRoute!(API::api_stats_get_top_client));
+        api.register(Method::GET, "/api/stats/top/domain", true, APIRoute!(API::api_stats_get_top_domain));
+        api.register(Method::GET, "/api/stats/overview", true, APIRoute!(API::api_stats_get_overview));
+        api.register(Method::GET, "/api/stats/hourly-query-count", true, APIRoute!(API::api_stats_get_hourly_query_count));
+        api.register(Method::GET, "/api/tool/term", true, APIRoute!(API::api_tool_term));
         api
     }
 
@@ -212,43 +224,59 @@ impl API {
         _param: APIRouteParam,
         req: Request<body::Incoming>,
     ) -> Result<Response<Full<Bytes>>, HttpError> {
-        let token = req.headers().get("Authorization");
+        let token = HttpServer::get_token_from_header(&req)?;
+        let unauth_response =
+            || API::response_error(StatusCode::UNAUTHORIZED, "Incorrect username or password.");
+
         if token.is_none() {
-            return API::response_error(
-                StatusCode::UNAUTHORIZED,
-                "Incorrect username or password.",
-            );
+            return unauth_response();
         }
 
+        let token = token.unwrap();
         let conf = this.get_conf();
-
         let jtw = Jwt::new(
-            &conf.user.as_str(),
+            &conf.username.as_str(),
             conf.password.as_str(),
             "",
             conf.token_expired_time,
         );
 
-        let token = token.unwrap().to_str().unwrap();
-        let token_new = jtw.refresh_token(token);
+        let calim = jtw.decode_token(token.as_str());
+        if calim.is_err() {
+            return unauth_response();
+        }
+
+        let token_new = jtw.refresh_token(token.as_str());
         if token_new.is_err() {
-            return API::response_error(
-                StatusCode::UNAUTHORIZED,
-                "Incorrect username or password.",
-            );
+            return unauth_response();
         }
+
         let token_new = token_new.unwrap();
-        API::response_build(
+        let mut resp = API::response_build(
             StatusCode::OK,
             api_msg_auth_token(&token_new.token, &token_new.expire),
-        )
+        );
+
+        let cookie_token = format!("Bearer {}", token_new.token);
+        let token_urlencode = urlencoding::encode(cookie_token.as_str());
+        let cookie = format!(
+            "token={}; HttpOnly; Max-Age={}; Path={}",
+            token_urlencode, token_new.expire, REST_API_PATH
+        );
+
+        resp.as_mut()
+            .unwrap()
+            .headers_mut()
+            .insert(hyper::header::SET_COOKIE, cookie.parse().unwrap());
+
+        resp
     }
 
     /// Login
     /// API: POST /api/auth/login
     ///     body:
     /// {
-    ///   "user": "admin"
+    ///   "username": "admin"
     ///   "password": "password"
     /// }
     async fn api_auth_login(
@@ -265,24 +293,121 @@ impl API {
         let conf = this.get_conf();
         let userinfo = userinfo.unwrap();
 
-        if userinfo.user != conf.user || userinfo.password != conf.password {
+        if !this.login_attempts_check() {
+            return API::response_error(
+                StatusCode::FORBIDDEN,
+                "Too many login attempts, please try again later.",
+            );
+        }
+
+        if userinfo.username != conf.username || userinfo.password != conf.password {
             return API::response_error(
                 StatusCode::UNAUTHORIZED,
                 "Incorrect username or password.",
             );
         }
 
+        this.login_attempts_reset();
+
         let jtw = Jwt::new(
-            userinfo.user.as_str(),
+            userinfo.username.as_str(),
             conf.password.as_str(),
             "",
             conf.token_expired_time,
         );
         let token = jtw.encode_token();
-        API::response_build(
+        let mut resp = API::response_build(
             StatusCode::OK,
             api_msg_auth_token(&token.token, &token.expire),
-        )
+        );
+
+        let cookie_token = format!("Bearer {}", token.token);
+        let token_urlencode = urlencoding::encode(cookie_token.as_str());
+        let cookie = format!(
+            "token={}; HttpOnly; Max-Age={}; Path={}",
+            token_urlencode, token.expire, REST_API_PATH
+        );
+
+        resp.as_mut()
+            .unwrap()
+            .headers_mut()
+            .insert(hyper::header::SET_COOKIE, cookie.parse().unwrap());
+
+        resp
+    }
+
+    async fn api_auth_logout(
+        _this: Arc<HttpServer>,
+        _param: APIRouteParam,
+        _req: Request<body::Incoming>,
+    ) -> Result<Response<Full<Bytes>>, HttpError> {
+        let mut response = Response::new(Full::new(Bytes::from("")));
+
+        let cookie = format!("token=none; HttpOnly; Max-Age=1; Path={}", REST_API_PATH);
+
+        response
+            .headers_mut()
+            .insert(hyper::header::SET_COOKIE, cookie.parse().unwrap());
+        *response.status_mut() = StatusCode::NO_CONTENT;
+        Ok(response)
+    }
+
+    async fn api_auth_change_password(
+        this: Arc<HttpServer>,
+        _param: APIRouteParam,
+        req: Request<body::Incoming>,
+    ) -> Result<Response<Full<Bytes>>, HttpError> {
+        let unauth_response =
+            || API::response_error(StatusCode::UNAUTHORIZED, "Incorrect username or password.");
+        let token = HttpServer::get_token_from_header(&req)?;
+        let whole_body = String::from_utf8(req.into_body().collect().await?.to_bytes().into())?;
+        if token.is_none() {
+            return unauth_response();
+        }
+
+        let password_info = api_msg_parse_auth_password_change(whole_body.as_str());
+        if let Err(e) = password_info {
+            return API::response_error(StatusCode::BAD_REQUEST, e.to_string().as_str());
+        }
+
+        let password_info = password_info.unwrap();
+        if password_info.0 == password_info.1 {
+            return API::response_error(
+                StatusCode::BAD_REQUEST,
+                "The new password is the same as the old password.",
+            );
+        }
+
+        let token = token.unwrap();
+        let mut conf = this.get_conf_mut();
+        let jtw = Jwt::new(
+            &conf.username.as_str(),
+            password_info.0.as_str(),
+            "",
+            conf.token_expired_time,
+        );
+
+        if !this.login_attempts_check() {
+            return API::response_error(
+                StatusCode::FORBIDDEN,
+                "Too many login attempts, please try again later.",
+            );
+        }
+
+        let calim = jtw.decode_token(token.as_str());
+        if calim.is_err() {
+            return API::response_error(StatusCode::FORBIDDEN, "Incorrect password.");
+        }
+
+        let data_server = this.get_data_server();
+        conf.password = password_info.1.clone();
+        let ret = data_server.set_config(PASSWORD_CONFIG_KEY, password_info.1.as_str());
+        if let Err(e) = ret {
+            return API::response_error(StatusCode::INTERNAL_SERVER_ERROR, e.to_string().as_str());
+        }
+
+        this.login_attempts_reset();
+        API::response_build(StatusCode::NO_CONTENT, "".to_string())
     }
 
     /// Restart the service <br>
@@ -455,7 +580,9 @@ impl API {
         if list_count % page_size != 0 {
             total_page += 1;
         }
-        let body = api_msg_gen_domain_list(domain_list, total_page);
+
+        let total_count = data_server.get_domain_list_count();
+        let body = api_msg_gen_domain_list(&domain_list, total_page, total_count);
 
         API::response_build(StatusCode::OK, body)
     }
@@ -498,7 +625,7 @@ impl API {
     ) -> Result<Response<Full<Bytes>>, HttpError> {
         let data_server = this.get_data_server();
         let client_list: Vec<ClientData> = data_server.get_client_list()?;
-        let body = api_msg_gen_client_list(client_list);
+        let body = api_msg_gen_client_list(&client_list);
 
         API::response_build(StatusCode::OK, body)
     }
@@ -513,7 +640,7 @@ impl API {
                 .map_err(|e| HttpError::new(StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
 
             tokio::spawn(async move {
-                if let Err(e) = http_server_log_stream::serve_log_stream(websocket).await {
+                if let Err(e) = http_server_stream::serve_log_stream(websocket).await {
                     eprintln!("Error in websocket connection: {e}");
                 }
             });
@@ -525,7 +652,7 @@ impl API {
     }
 
     async fn api_log_set_level(
-        _this: Arc<HttpServer>,
+        this: Arc<HttpServer>,
         _param: APIRouteParam,
         _req: Request<body::Incoming>,
     ) -> Result<Response<Full<Bytes>>, HttpError> {
@@ -537,6 +664,8 @@ impl API {
 
         let level = level.unwrap();
         dns_log_set_level(level);
+        let data_server = this.get_data_server();
+        _ = data_server.set_config("log-level", level.to_string().as_str());
         API::response_build(StatusCode::NO_CONTENT, "".to_string())
     }
 
@@ -550,7 +679,7 @@ impl API {
         API::response_build(StatusCode::OK, msg)
     }
 
-    async  fn api_server_version(
+    async fn api_server_version(
         _this: Arc<HttpServer>,
         _param: APIRouteParam,
         _req: Request<body::Incoming>,
@@ -560,4 +689,128 @@ impl API {
         let msg = api_msg_gen_version(server_version, ui_version);
         API::response_build(StatusCode::OK, msg)
     }
+
+    async fn api_config_get_settings(
+        this: Arc<HttpServer>,
+        _param: APIRouteParam,
+        _req: Request<body::Incoming>,
+    ) -> Result<Response<Full<Bytes>>, HttpError> {
+        let data_server = this.get_data_server();
+        let settings = data_server.get_config_list();
+        if settings.is_err() {
+            return API::response_error(StatusCode::NOT_FOUND, "Not found");
+        }
+
+        let mut settings = settings.unwrap();
+        let pass = settings.get(PASSWORD_CONFIG_KEY);
+        if pass.is_some() {
+            let pass = "********".to_string();
+            settings.insert(PASSWORD_CONFIG_KEY.to_string(), pass);
+        }
+        let msg = api_msg_gen_key_value(&settings);
+        API::response_build(StatusCode::OK, msg)
+    }
+
+    async fn api_config_set_settings(
+        this: Arc<HttpServer>,
+        _param: APIRouteParam,
+        req: Request<body::Incoming>,
+    ) -> Result<Response<Full<Bytes>>, HttpError> {
+        let data_server = this.get_data_server();
+        let whole_body = String::from_utf8(req.into_body().collect().await?.to_bytes().into())?;
+        let settings = api_msg_parse_key_value(whole_body.as_str());
+        if let Err(e) = settings {
+            return API::response_error(StatusCode::BAD_REQUEST, e.to_string().as_str());
+        }
+
+        let settings = settings.unwrap();
+        for (key, value) in settings {
+            if key == PASSWORD_CONFIG_KEY {
+                continue;
+            }
+            let ret = data_server.set_config(key.as_str(), value.as_str());
+            if let Err(e) = ret {
+                return API::response_error(
+                    StatusCode::INTERNAL_SERVER_ERROR,
+                    e.to_string().as_str(),
+                );
+            }
+        }
+
+        API::response_build(StatusCode::NO_CONTENT, "".to_string())
+    }
+
+    async fn api_stats_get_top_client(
+        this: Arc<HttpServer>,
+        _param: APIRouteParam,
+        _req: Request<body::Incoming>,
+    ) -> Result<Response<Full<Bytes>>, HttpError> {
+        let data_server = this.get_data_server();
+        let params = API::get_params(&_req);
+        let count = API::params_get_value_default(&params, "count", 10 as u32)?;
+        let client_list = data_server.get_top_client_top_list(count)?;
+        let body = api_msg_gen_top_client_list(&client_list);
+
+        API::response_build(StatusCode::OK, body)
+    }
+
+    async fn api_stats_get_top_domain(
+        this: Arc<HttpServer>,
+        _param: APIRouteParam,
+        _req: Request<body::Incoming>,
+    ) -> Result<Response<Full<Bytes>>, HttpError> {
+        let data_server = this.get_data_server();
+        let params = API::get_params(&_req);
+        let count = API::params_get_value_default(&params, "count", 10 as u32)?;
+        let domain_list = data_server.get_top_domain_top_list(count)?;
+        let body = api_msg_gen_top_domain_list(&domain_list);
+
+        API::response_build(StatusCode::OK, body)
+    }
+
+    async fn api_stats_get_overview(
+        this: Arc<HttpServer>,
+        _param: APIRouteParam,
+        _req: Request<body::Incoming>,
+    ) -> Result<Response<Full<Bytes>>, HttpError> {
+        let data_server = this.get_data_server();
+        let overview = data_server.get_overview()?;
+        let body = api_msg_gen_stats_overview(&overview);
+        API::response_build(StatusCode::OK, body)
+    }
+
+    async fn api_stats_get_hourly_query_count(
+        this: Arc<HttpServer>,
+        _param: APIRouteParam,
+        _req: Request<body::Incoming>,
+    ) -> Result<Response<Full<Bytes>>, HttpError> {
+        let params = API::get_params(&_req);
+        let past_hours = API::params_get_value_default(&params, "past_hours", 24 as u32)?;
+        let data_server = this.get_data_server();
+        let hourly_query_count = data_server.get_hourly_query_count(past_hours)?;
+        let body = api_msg_gen_hourly_query_count(&hourly_query_count);
+        API::response_build(StatusCode::OK, body)
+    }
+
+
+    async fn api_tool_term(
+        _this: Arc<HttpServer>,
+        _param: APIRouteParam,
+        mut req: Request<body::Incoming>,
+    ) -> Result<Response<Full<Bytes>>, HttpError> {
+        if hyper_tungstenite::is_upgrade_request(&req) {
+            let (response, websocket) = hyper_tungstenite::upgrade(&mut req, None)
+                .map_err(|e| HttpError::new(StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
+
+            tokio::spawn(async move {
+                if let Err(e) = http_server_stream::serve_term(websocket).await {
+                    eprintln!("Error in websocket connection: {e}");
+                }
+            });
+
+            Ok(response)
+        } else {
+            return API::response_error(StatusCode::BAD_REQUEST, "Need websocket upgrade.");
+        }
+    }
 }

+ 0 - 48
plugin/smartdns-ui/src/http_server_log_stream.rs

@@ -1,48 +0,0 @@
-/*************************************************************************
- *
- * Copyright (C) 2018-2024 Ruilin Peng (Nick) <[email protected]>.
- *
- * smartdns is free software: you can redistribute it and/or modify
- * it under the terms of the GNU General Public License as published by
- * the Free Software Foundation, either version 3 of the License, or
- * (at your option) any later version.
- *
- * smartdns is distributed in the hope that it will be useful,
- * but WITHOUT ANY WARRANTY; without even the implied warranty of
- * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
- * GNU General Public License for more details.
- *
- * You should have received a copy of the GNU General Public License
- * along with this program.  If not, see <http://www.gnu.org/licenses/>.
- */
-
-use futures::stream::StreamExt;
-use hyper_tungstenite::{tungstenite, HyperWebsocket};
-use tungstenite::Message;
-
-type Error = Box<dyn std::error::Error + Send + Sync + 'static>;
-
-pub async fn serve_log_stream(websocket: HyperWebsocket) -> Result<(), Error> {
-    let mut websocket = websocket.await?;
-    loop {
-        tokio::select! {
-            msg = websocket.next() => {
-                let message = msg.ok_or("websocket closed")??;
-                match message {
-                    Message::Text(_msg) => {}
-                    Message::Binary(_msg) => {}
-                    Message::Ping(_msg) => {}
-                    Message::Pong(_msg) => {}
-                    Message::Close(_msg) => {
-                        break;
-                    }
-                    Message::Frame(_msg) => {
-                        unreachable!();
-                    }
-                }
-            }
-        }
-    }
-
-    Ok(())
-}

+ 261 - 0
plugin/smartdns-ui/src/http_server_stream.rs

@@ -0,0 +1,261 @@
+/*************************************************************************
+ *
+ * Copyright (C) 2018-2024 Ruilin Peng (Nick) <[email protected]>.
+ *
+ * smartdns is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * smartdns is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ */
+
+use futures::sink::SinkExt;
+use futures::stream::StreamExt;
+use std::os::fd::AsRawFd;
+use tokio_fd::AsyncFd;
+
+use hyper_tungstenite::{tungstenite, HyperWebsocket};
+use nix::libc::*;
+use tokio::io::{AsyncReadExt, AsyncWriteExt};
+use tungstenite::Message;
+
+use crate::dns_log;
+use crate::smartdns::LogLevel;
+
+type Error = Box<dyn std::error::Error + Send + Sync + 'static>;
+
+pub async fn serve_log_stream(websocket: HyperWebsocket) -> Result<(), Error> {
+    let mut websocket = websocket.await?;
+
+    loop {
+        tokio::select! {
+            msg = websocket.next() => {
+                let message = msg.ok_or("websocket closed")??;
+                match message {
+                    Message::Text(_msg) => {}
+                    Message::Binary(_msg) => {}
+                    Message::Ping(_msg) => {}
+                    Message::Pong(_msg) => {}
+                    Message::Close(_msg) => {
+                        break;
+                    }
+                    Message::Frame(_msg) => {
+                        unreachable!();
+                    }
+                }
+            }
+        }
+    }
+
+    Ok(())
+}
+
+enum TermMessageType {
+    Data,
+    Err,
+    Resize,
+}
+
+impl TryFrom<u8> for TermMessageType {
+    type Error = ();
+
+    fn try_from(value: u8) -> Result<Self, Self::Error> {
+        match value {
+            0 => Ok(TermMessageType::Data),
+            1 => Ok(TermMessageType::Err),
+            2 => Ok(TermMessageType::Resize),
+            _ => Err(()),
+        }
+    }
+}
+
+#[cfg(target_os = "linux")]
+pub async fn serve_term(websocket: HyperWebsocket) -> Result<(), Error> {
+    type WsType =
+        hyper_tungstenite::WebSocketStream<hyper_util::rt::TokioIo<hyper::upgrade::Upgraded>>;
+    let mut websocket = websocket.await?;
+
+    let (pid, asyncfd) = unsafe {
+        let mut fd_master: std::os::fd::RawFd = 0;
+        let mut ws = winsize {
+            ws_row: 24,
+            ws_col: 80,
+            ws_xpixel: 0,
+            ws_ypixel: 0,
+        };
+        let pid = forkpty(
+            &mut fd_master,
+            std::ptr::null_mut(),
+            std::ptr::null_mut(),
+            &mut ws,
+        );
+        if pid < 0 {
+            return Err("forkpty failed".into());
+        }
+
+        if pid == 0 {
+            let _ = ioctl(0, TIOCSCTTY, 1);
+            for i in 3..1024 {
+                close(i);
+            }
+            use std::ffi::CString;
+
+            let find_cmd = |cmd: &str| -> Result<String, Box<dyn std::error::Error>> {
+                let env_path = std::env::var("PATH")?;
+                let paths = env_path.split(':');
+
+                for path in paths {
+                    let cmd_path = format!("{}/{}", path, cmd);
+                    if std::fs::metadata(&cmd_path).is_ok() {
+                        return Ok(cmd_path);
+                    }
+                }
+
+                Err("command not found".into())
+            };
+
+            let su_path = find_cmd("su");
+            let login_path = find_cmd("login");
+
+            if su_path.is_ok() {
+                let uid = getuid();
+                let pw = getpwuid(uid);
+
+                if pw.is_null() {
+                    return Err("getpwuid failed".into());
+                }
+
+                let arg0 = CString::new("su").unwrap();
+                let arg1 = CString::new("-").unwrap();
+                let arg2 = (*pw).pw_name;
+                let login_message =
+                    format!("Login as {}", std::ffi::CStr::from_ptr(arg2).to_str()?);
+                println!("{}", login_message);
+
+                let cmd_path = CString::new(su_path.unwrap()).unwrap();
+                let args = [arg0.as_ptr(), arg1.as_ptr(), arg2, std::ptr::null()];
+                let _ = execv(cmd_path.as_ptr(), args.as_ptr());
+            } else if login_path.is_ok() {
+                let arg0 = CString::new("login").unwrap();
+                let cmd_path = CString::new(login_path.unwrap()).unwrap();
+                let args = [arg0.as_ptr()];
+                let _ = execv(cmd_path.as_ptr(), args.as_ptr());
+            }
+
+            println!("Failed to execute `su` or `login`");
+            exit(1);
+        }
+
+        (pid, AsyncFd::try_from(fd_master))
+    };
+
+    if let Err(e) = asyncfd {
+        return Err(e.into());
+    }
+
+    let send_error_msg = |ws: &mut WsType, msg: &str| {
+        let mut buf = [0u8; 4096];
+        buf[0] = TermMessageType::Err as u8;
+        buf[1..msg.len() + 1].copy_from_slice(msg.as_bytes());
+        let msg = Message::Binary(buf[..msg.len() + 1].to_vec());
+        let _ = ws.send(msg);
+
+        let msg = Message::Close(None);
+        let _ = ws.send(msg);
+    };
+
+    let mut asyncfd = asyncfd.unwrap();
+    loop {
+        let mut buf = [0u8; 4096];
+        let (data_type, data_buf) = buf.split_at_mut(1);
+        let data_len;
+        tokio::select! {
+            n = asyncfd.read(data_buf) => {
+                match n {
+                    Ok(n) => {
+                        if n == 0 {
+                            websocket.send(Message::Close(None)).await?;
+                            dns_log!(LogLevel::ERROR, "EOF");
+                            break;
+                        }
+                        data_len = n + 1;
+                        data_type[0] = TermMessageType::Data as u8;
+                        let msg = Message::Binary(buf[..data_len].to_vec());
+                        websocket.send(msg).await?;
+                    }
+                    Err(e) => {
+                        send_error_msg(&mut websocket, e.to_string().as_str());
+                        dns_log!(LogLevel::ERROR, "Error: {}", e.to_string().as_str());
+                        break;
+                    }
+                }
+            }
+            msg = websocket.next() => {
+                let message = msg.ok_or("websocket closed")??;
+                match message {
+                    Message::Text(msg) => {
+                        asyncfd.write(msg.as_bytes()).await?;
+                    }
+                    Message::Binary(msg) => {
+                        if msg.len() == 0 {
+                            continue;
+                        }
+
+                        let msg_type = TermMessageType::try_from(msg[0]);
+                        if msg_type.is_err() {
+                            send_error_msg(&mut websocket, "invalid message type");
+                            break;
+                        }
+
+                        let msg_type = msg_type.unwrap();
+                        let msg = &msg[1..];
+
+                        match msg_type {
+                            TermMessageType::Resize => {
+                                let ws = winsize {
+                                    ws_col: u16::from_be_bytes(msg[0..2].try_into().unwrap()),
+                                    ws_row: u16::from_be_bytes(msg[2..4].try_into().unwrap()),
+                                    ws_xpixel: 0,
+                                    ws_ypixel: 0,
+                                };
+                                unsafe {
+                                    let _ = ioctl(asyncfd.as_raw_fd(), TIOCSWINSZ, &ws);
+                                }
+                            }
+                            TermMessageType::Data => {
+                                asyncfd.write(msg).await?;
+                            }
+                            _ => {
+                                continue;
+                            }
+                        }
+                    }
+                    Message::Ping(_msg) => {}
+                    Message::Pong(_msg) => {}
+                    Message::Close(_msg) => {
+                        dns_log!(LogLevel::DEBUG, "Peer term closed");
+                        break;
+                    }
+                    Message::Frame(_msg) => {
+                        unreachable!();
+                    }
+                }
+            }
+
+        }
+    }
+
+    unsafe {
+        let _ = kill(pid, SIGKILL);
+        let _ = waitpid(pid, std::ptr::null_mut(), 0);
+    }
+
+    Ok(())
+}

+ 17 - 1
plugin/smartdns-ui/src/lib.rs

@@ -23,11 +23,14 @@ pub mod http_error;
 pub mod http_jwt;
 pub mod http_server;
 pub mod http_server_api;
-pub mod http_server_log_stream;
+pub mod http_server_stream;
 pub mod plugin;
+pub mod data_stats;
+pub mod utils;
 pub mod smartdns;
 
 use ctor::ctor;
+use ctor::dtor;
 #[cfg(not(test))]
 use plugin::*;
 use smartdns::*;
@@ -40,6 +43,13 @@ fn lib_init_ops() {
     }
 }
 
+#[cfg(not(test))]
+fn lib_deinit_ops() {
+    unsafe {
+        PLUGIN.clear_operation();
+    }
+}
+
 #[cfg(test)]
 fn lib_init_smartdns_lib() {
     smartdns::dns_log_set_level(LogLevel::DEBUG);
@@ -53,3 +63,9 @@ fn lib_init() {
     #[cfg(test)]
     lib_init_smartdns_lib();
 }
+
+#[dtor]
+fn lib_deinit() {
+    #[cfg(not(test))]
+    lib_deinit_ops();
+}

+ 16 - 20
plugin/smartdns-ui/src/plugin.rs

@@ -70,22 +70,13 @@ impl SmartdnsPlugin {
             }
         };
 
-        let www_root =
-            Plugin::dns_conf_plugin_config("smartdns-ui.www-root", "/usr/share/smartdns/www");
-        self.http_conf.http_root = www_root;
-
-        let user = Plugin::dns_conf_plugin_config("smartdns-ui.user", "");
-        let password = Plugin::dns_conf_plugin_config("smartdns-ui.password", "");
-        let ip = Plugin::dns_conf_plugin_config("smartdns-ui.ip", "");
-        if user.len() > 0  {
-            self.http_conf.user = user;
+        let www_root = Plugin::dns_conf_plugin_config("smartdns-ui.www-root");
+        if let Some(www_root) = www_root {
+            self.http_conf.http_root = www_root;
         }
 
-        if  password.len() > 0 {
-            self.http_conf.password = password;
-        }
-
-        if ip.len() > 0 {
+        let ip = Plugin::dns_conf_plugin_config("smartdns-ui.ip");
+        if let Some(ip) = ip {
             self.http_conf.http_ip = ip;
         }
 
@@ -118,9 +109,18 @@ impl SmartdnsPlugin {
         Ok(())
     }
 
+    pub fn load_config(&mut self) -> Result<(), Box<dyn Error>> {
+        let data_server = self.get_data_server();
+        self.data_conf.load_config(data_server.clone())?;
+        self.http_conf.load_config(data_server.clone())?;
+        Ok(())
+    }
+
     pub fn start(&mut self, args: &Vec<String>) -> Result<(), Box<dyn Error>> {
         self.parser_args(args)?;
-        self.data_server_ctl.start_data_server(&self.data_conf)?;
+        self.data_server_ctl.init_db(&self.data_conf)?;
+        self.load_config()?;
+        self.data_server_ctl.start_data_server()?;
         self.http_server_ctl
             .start_http_server(&self.http_conf, self.data_server_ctl.get_data_server())?;
 
@@ -135,11 +135,7 @@ impl SmartdnsPlugin {
     pub fn query_complete(&self, request: &mut DnsRequest) {
         let ret = self.data_server_ctl.send_request(request);
         if let Err(e) = ret {
-            dns_log!(
-                LogLevel::ERROR,
-                "send data to data server error: {}",
-                e.to_string()
-            );
+            dns_log!(LogLevel::ERROR, "send data to data server error: {}", e);
             return;
         }
     }

+ 151 - 28
plugin/smartdns-ui/src/smartdns.rs

@@ -21,25 +21,15 @@
 #![allow(non_snake_case)]
 #![allow(dead_code)]
 #![allow(unused_imports)]
-mod smartdns_c {
-    use libc::gid_t;
-    use libc::in6_addr;
-    use libc::in_addr;
-    use libc::sockaddr;
-    use libc::sockaddr_storage;
-    use libc::socklen_t;
-    use libc::time_t;
-    use libc::timeval;
-    use libc::tm;
-    use libc::uid_t;
-    use u32 as u_int;
+#![allow(improper_ctypes)]
+pub mod smartdns_c {
     include!(concat!(env!("OUT_DIR"), "/smartdns_bindings.rs"));
-
 }
 
-extern crate libc;
 use std::error::Error;
 use std::ffi::CString;
+use std::fmt;
+use std::os::raw::*;
 
 #[repr(C)]
 #[derive(Copy, Clone, Debug, PartialEq)]
@@ -53,6 +43,25 @@ pub enum LogLevel {
     FATAL = 5,
 }
 
+impl From<LogLevel> for u32 {
+    fn from(level: LogLevel) -> u32 {
+        level as u32
+    }
+}
+
+impl ToString for LogLevel {
+    fn to_string(&self) -> String {
+        match self {
+            LogLevel::DEBUG => "debug".to_string(),
+            LogLevel::INFO => "info".to_string(),
+            LogLevel::NOTICE => "notice".to_string(),
+            LogLevel::WARN => "warn".to_string(),
+            LogLevel::ERROR => "error".to_string(),
+            LogLevel::FATAL => "fatal".to_string(),
+        }
+    }
+}
+
 impl TryFrom<u32> for LogLevel {
     type Error = ();
 
@@ -69,6 +78,29 @@ impl TryFrom<u32> for LogLevel {
     }
 }
 
+impl TryFrom<&str> for LogLevel {
+    type Error = ();
+
+    fn try_from(value: &str) -> Result<Self, Self::Error> {
+        match value.to_lowercase().as_str() {
+            "debug" => Ok(LogLevel::DEBUG),
+            "info" => Ok(LogLevel::INFO),
+            "notice" => Ok(LogLevel::NOTICE),
+            "warn" => Ok(LogLevel::WARN),
+            "error" => Ok(LogLevel::ERROR),
+            "fatal" => Ok(LogLevel::FATAL),
+            _ => Err(()),
+        }
+    }
+}
+
+impl TryFrom<String> for LogLevel {
+    type Error = ();
+    fn try_from(value: String) -> Result<Self, Self::Error> {
+        LogLevel::try_from(value.as_str())
+    }
+}
+
 #[macro_export]
 macro_rules! dns_log {
     ($level:expr, $($arg:tt)*) => {
@@ -248,18 +280,37 @@ impl DnsRequest {
         unsafe { smartdns_c::dns_server_request_get_rcode(self.request) as u16 }
     }
 
-    pub fn get_query_time(&self) -> u64 {
+    pub fn get_query_time(&self) -> i32 {
         unsafe { smartdns_c::dns_server_request_get_query_time(self.request) }
     }
 
+    pub fn get_query_timestamp(&self) -> u64 {
+        unsafe { smartdns_c::dns_server_request_get_query_timestamp(self.request) }
+    }
+
+    pub fn get_ping_time(&self) -> f64 {
+        let v = unsafe { smartdns_c::dns_server_request_get_ping_time(self.request) };
+        let mut ping_time = v as f64;
+        ping_time = (ping_time * 10.0).round() / 10.0;
+        ping_time
+    }
+
+    pub fn get_is_blocked(&self) -> bool {
+        unsafe { smartdns_c::dns_server_request_is_blocked(self.request) != 0 }
+    }
+
+    pub fn get_is_cached(&self) -> bool {
+        unsafe { smartdns_c::dns_server_request_is_cached(self.request) != 0 }
+    }
+
     pub fn get_remote_addr(&self) -> String {
         unsafe {
             let addr = smartdns_c::dns_server_request_get_remote_addr(self.request);
             let mut buf = [0u8; 1024];
             let retstr = smartdns_c::get_host_by_addr(
-                buf.as_mut_ptr(),
+                buf.as_mut_ptr() as *mut c_char,
                 buf.len() as i32,
-                addr as *const libc::sockaddr,
+                addr as *const smartdns_c::sockaddr,
             );
             if retstr.is_null() {
                 return String::new();
@@ -277,9 +328,9 @@ impl DnsRequest {
             let addr = smartdns_c::dns_server_request_get_local_addr(self.request);
             let mut buf = [0u8; 1024];
             let retstr = smartdns_c::get_host_by_addr(
-                buf.as_mut_ptr(),
+                buf.as_mut_ptr() as *mut c_char,
                 buf.len() as i32,
-                addr as *const libc::sockaddr,
+                addr as *const smartdns_c::sockaddr,
             );
             if retstr.is_null() {
                 return String::new();
@@ -291,6 +342,14 @@ impl DnsRequest {
             addr
         }
     }
+
+    pub fn is_prefetch_request(&self) -> bool {
+        unsafe { smartdns_c::dns_server_request_is_prefetch(self.request) != 0 }
+    }
+
+    pub fn is_dualstack_request(&self) -> bool {
+        unsafe { smartdns_c::dns_server_request_is_dualstack(self.request) != 0 }
+    }
 }
 
 impl Drop for DnsRequest {
@@ -345,6 +404,10 @@ impl Plugin {
         self.ops = Some(ops);
     }
 
+    pub fn clear_operation(&mut self) {
+        self.ops = None;
+    }
+
     pub fn smartdns_exit(status: i32) {
         unsafe {
             smartdns_c::smartdns_exit(status);
@@ -362,17 +425,17 @@ impl Plugin {
             let mut key = [0u8; 4096];
             let mut cert = [0u8; 4096];
             let ret = smartdns_c::smartdns_get_cert(
-                key.as_mut_ptr() as *mut libc::c_char,
-                cert.as_mut_ptr() as *mut libc::c_char,
+                key.as_mut_ptr() as *mut c_char,
+                cert.as_mut_ptr() as *mut c_char,
             );
             if ret != 0 {
                 return Err("get cert error".to_string());
             }
 
-            let key = std::ffi::CStr::from_ptr(key.as_ptr() as *const libc::c_char)
+            let key = std::ffi::CStr::from_ptr(key.as_ptr() as *const c_char)
                 .to_string_lossy()
                 .into_owned();
-            let cert = std::ffi::CStr::from_ptr(cert.as_ptr() as *const libc::c_char)
+            let cert = std::ffi::CStr::from_ptr(cert.as_ptr() as *const c_char)
                 .to_string_lossy()
                 .into_owned();
             Ok(SmartdnsCert {
@@ -414,18 +477,30 @@ impl Plugin {
     }
 
     #[allow(dead_code)]
-    pub fn dns_conf_plugin_config(key: &str, default: &str) -> String {
+    pub fn dns_conf_plugin_config(key: &str) -> Option<String> {
         let key = CString::new(key).expect("Failed to convert to CString");
         unsafe {
             let value = smartdns_c::smartdns_plugin_get_config(key.as_ptr());
             if value.is_null() {
-                return default.to_string();
+                return None;
             }
 
-            std::ffi::CStr::from_ptr(value)
-                .to_string_lossy()
-                .into_owned()
+            Some(
+                std::ffi::CStr::from_ptr(value)
+                    .to_string_lossy()
+                    .into_owned(),
+            )
+        }
+    }
+
+    #[allow(dead_code)]
+    pub fn dns_conf_plugin_config_default(key: &str, default_val: &String) -> String {
+        let v = Plugin::dns_conf_plugin_config(key);
+        if let Some(v) = v {
+            return v;
         }
+
+        default_val.clone()
     }
 
     #[allow(dead_code)]
@@ -454,6 +529,54 @@ impl Plugin {
     }
 }
 
+pub struct Stats {}
+
+impl Stats {
+    pub fn get_avg_process_time() -> f64 {
+        unsafe {
+            let v = smartdns_c::dns_stats_avg_time_get();
+            let mut process_time = v as f64;
+            process_time = (process_time * 10.0).round() / 10.0;
+            process_time
+        }
+    }
+
+    pub fn get_request_total() -> u64 {
+        unsafe { smartdns_c::dns_stats_request_total_get() }
+    }
+
+    pub fn get_request_success() -> u64 {
+        unsafe { smartdns_c::dns_stats_request_success_get() }
+    }
+
+    pub fn get_request_from_client() -> u64 {
+        unsafe { smartdns_c::dns_stats_request_from_client_get() }
+    }
+
+    pub fn get_request_blocked() -> u64 {
+        unsafe { smartdns_c::dns_stats_request_blocked_get() }
+    }
+
+    pub fn get_cache_hit() -> u64 {
+        unsafe { smartdns_c::dns_stats_cache_hit_get() }
+    }
+
+    pub fn get_cache_hit_rate() -> f64 {
+        unsafe {
+            let v = smartdns_c::dns_stats_cache_hit_rate_get() as f64;
+            let mut cache_hit_rate = v as f64;
+            cache_hit_rate = (cache_hit_rate * 10.0).round() / 10.0;
+            cache_hit_rate
+        }
+    }
+}
+
+impl fmt::Display for Stats {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        write!(f, "Stats")
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;

+ 52 - 0
plugin/smartdns-ui/src/utils.rs

@@ -0,0 +1,52 @@
+/*************************************************************************
+ *
+ * Copyright (C) 2018-2024 Ruilin Peng (Nick) <[email protected]>.
+ *
+ * smartdns is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * smartdns is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ */
+
+pub fn parse_value<T>(value: Option<String>, min: T, max: T, default: T) -> T
+where
+    T: PartialOrd + std::str::FromStr,
+{
+    if value.is_none() {
+        return default;
+    }
+
+    let value = value.unwrap().parse::<T>();
+    if let Err(_) = value {
+        return default;
+    }
+
+    let mut value = value.unwrap_or_else(|_| default);
+
+    if value < min {
+        value = min;
+    }
+
+    if value > max {
+        value = max;
+    }
+
+    value
+}
+
+
+pub fn seconds_until_next_hour() -> u64 {
+    let now = chrono::Local::now();
+    let minutes = chrono::Timelike::minute(&now);
+    let seconds = chrono::Timelike::second(&now);
+    let remaining_seconds = 3600 - (minutes * 60 + seconds) as u64;
+    remaining_seconds
+}

+ 26 - 11
plugin/smartdns-ui/tests/common/client.rs

@@ -25,6 +25,8 @@ use tungstenite::*;
 pub struct TestClient {
     url: String,
     token: Option<http_api_msg::TokenResponse>,
+    client: reqwest::blocking::Client,
+    no_auth_header: bool,
 }
 
 impl TestClient {
@@ -32,19 +34,27 @@ impl TestClient {
         let client = TestClient {
             url: url.clone(),
             token: None,
+            client: reqwest::blocking::ClientBuilder::new()
+                .danger_accept_invalid_certs(true)
+                .build()
+                .unwrap(),
+            no_auth_header: false,
         };
 
         client
     }
 
-    pub fn login(&mut self, user: &str, password: &str) -> Result<String, Box<dyn Error>> {
+    pub fn set_with_auth_header(&mut self, with_auth_header: bool) {
+        self.no_auth_header = with_auth_header;
+    }
+
+    pub fn login(&mut self, username: &str, password: &str) -> Result<String, Box<dyn Error>> {
         let url = self.url.clone() + "/api/auth/login";
         let body = http_api_msg::api_msg_gen_auth_login(&http_api_msg::AuthUser {
-            user: user.to_string(),
+            username: username.to_string(),
             password: password.to_string(),
         });
-        let client = reqwest::blocking::Client::new();
-        let resp = client.post(&url).body(body).send()?;
+        let resp = self.client.post(&url).body(body).send()?;
         let text = resp.text()?;
 
         let token = http_api_msg::api_msg_parse_auth_token(&text)?;
@@ -52,19 +62,24 @@ impl TestClient {
         Ok(text)
     }
 
+    pub fn logout(&mut self) -> Result<String, Box<dyn Error>> {
+        let url = self.url.clone() + "/api/auth/logout";
+        let resp = self.client.post(&url).send()?;
+        let text = resp.text()?;
+        self.token = None;
+        Ok(text)
+    }
+
     fn prep_request(
         &self,
         method: reqwest::Method,
         path: &str,
     ) -> Result<reqwest::blocking::RequestBuilder, Box<dyn Error>> {
         let url = self.url.clone() + path;
-        let client = reqwest::blocking::ClientBuilder::new()
-            .danger_accept_invalid_certs(true)
-            .build()?;
-        let mut req = client.request(method, url);
+        let mut req = self.client.request(method, url);
         if let Some(token) = &self.token {
-            if self.token.is_some() {
-                req = req.header("Authorization", format!("{}", token.token));
+            if self.token.is_some() && !self.no_auth_header {
+                req = req.header("Authorization", format!("Bearer {}", token.token));
             }
         }
         Ok(req)
@@ -116,7 +131,7 @@ impl TestClient {
         if let Some(token) = &self.token {
             if self.token.is_some() {
                 request_builder =
-                    request_builder.with_header("Authorization", format!("{}", token.token));
+                    request_builder.with_header("Authorization", format!("Bearer {}", token.token));
             }
         }
 

+ 4 - 0
plugin/smartdns-ui/tests/common/server.rs

@@ -231,6 +231,10 @@ impl TestServer {
             client: "127.0.0.1".to_string(),
             domain_group: "default".to_string(),
             reply_code: 0,
+            query_time: 0,
+            ping_time: -0.1 as f64,
+            is_blocked: false,
+            is_cached: false,
         }
     }
 

+ 232 - 3
plugin/smartdns-ui/tests/restapi_test.rs

@@ -30,7 +30,7 @@ async fn test_rest_api_login() {
 
     let c = reqwest::Client::new();
     let body = json!({
-        "user": "admin",
+        "username": "admin",
         "password": "password",
     });
 
@@ -64,6 +64,32 @@ async fn test_rest_api_login() {
     assert_eq!(calims.user, "admin");
 }
 
+
+#[test]
+fn test_rest_api_logout() {
+    let mut server = common::TestServer::new();
+    server.set_log_level(LogLevel::DEBUG);
+    assert!(server.start().is_ok());
+
+    let mut client = common::TestClient::new(&server.get_host());
+    client.set_with_auth_header(false);
+    let res = client.login("admin", "password");
+    assert!(res.is_ok());
+
+    let c = client.get("/api/cache/count");
+    assert!(c.is_ok());
+    let (code, _) = c.unwrap();
+    assert_eq!(code, 200);
+
+    let ret = client.logout();
+    assert!(ret.is_ok());
+
+    let c = client.get("/api/cache/count");
+    assert!(c.is_ok());
+    let (code, _) = c.unwrap();
+    assert_eq!(code, 401);
+}
+
 #[tokio::test]
 async fn test_rest_api_login_incorrect() {
     let mut server = common::TestServer::new();
@@ -72,7 +98,7 @@ async fn test_rest_api_login_incorrect() {
 
     let c = reqwest::Client::new();
     let body = json!({
-        "user": "admin",
+        "username": "admin",
         "password": "wrongpassword",
     });
 
@@ -92,6 +118,31 @@ async fn test_rest_api_login_incorrect() {
     assert_eq!(result.unwrap(), "Incorrect username or password.");
 }
 
+#[test]
+fn test_rest_api_change_password() {
+    let mut server = common::TestServer::new();
+    server.set_log_level(LogLevel::DEBUG);
+    assert!(server.start().is_ok());
+
+    let mut client = common::TestClient::new(&server.get_host());
+    let res = client.login("admin", "password");
+    assert!(res.is_ok());
+    let password_msg = http_api_msg::api_msg_gen_auth_password_change("wrong_oldpassword", "newpassword");
+    let c = client.put("/api/auth/password", password_msg.as_str());
+    assert!(c.is_ok());
+    let (code, _) = c.unwrap();
+    assert_eq!(code, 403);
+
+    let password_msg = http_api_msg::api_msg_gen_auth_password_change("password", "newpassword");
+    let c = client.put("/api/auth/password", password_msg.as_str());
+    assert!(c.is_ok());
+    let (code, _) = c.unwrap();
+    assert_eq!(code, 204);
+
+    let res = client.login("admin", "password");
+    assert!(!res.is_ok());    
+}
+
 #[test]
 fn test_rest_api_cache_count() {
     let mut server = common::TestServer::new();
@@ -317,7 +368,6 @@ fn test_rest_api_delete_domain_by_id() {
 fn test_rest_api_server_version() {
     let mut server = common::TestServer::new();
     server.set_log_level(LogLevel::DEBUG);
-    server.enable_mock_server();
     assert!(server.start().is_ok());
 
     let client = common::TestClient::new(&server.get_host());
@@ -354,3 +404,182 @@ fn test_rest_api_https_server() {
     assert_eq!(version.1, env!("CARGO_PKG_VERSION"));
 }
 
+#[test]
+fn test_rest_api_settings() {
+    let mut server = common::TestServer::new();
+    server.set_log_level(LogLevel::DEBUG);
+    assert!(server.start().is_ok());
+
+    let mut client = common::TestClient::new(&server.get_host());
+
+    let res = client.login("admin", "password");
+    assert!(res.is_ok());
+
+    let c = client.get("/api/config/settings");
+    assert!(c.is_ok());
+    let (code, body) = c.unwrap();
+    assert_eq!(code, 200);
+    let settings = http_api_msg::api_msg_parse_key_value(&body);
+    assert!(settings.is_ok());
+
+    let mut settings = std::collections::HashMap::new();
+    settings.insert("key1".to_string(), "value1".to_string());
+    settings.insert("key2".to_string(), "value2".to_string());
+    let body = http_api_msg::api_msg_gen_key_value(&settings);
+    let c = client.put("/api/config/settings", body.as_str());
+    assert!(c.is_ok());
+    let (code, _) = c.unwrap();
+    assert_eq!(code, 204);
+
+    let c = client.get("/api/config/settings");
+    assert!(c.is_ok());
+    let (code, body) = c.unwrap();
+    assert_eq!(code, 200);
+    let settings = http_api_msg::api_msg_parse_key_value(&body);
+    assert!(settings.is_ok());
+    let settings = settings.unwrap();
+    assert_eq!(settings.len(), 2);
+    assert_eq!(settings["key1"], "value1");
+}
+
+#[test]
+fn test_rest_api_get_client() {
+    let mut server = common::TestServer::new();
+    server.set_log_level(LogLevel::DEBUG);
+    assert!(server.start().is_ok());
+
+    let record = server.new_mock_domain_record();
+    for i in 0..1024 {
+        let mut record = record.clone();
+        record.domain = format!("{}.com", i);
+        record.client = format!("client-{}", i);
+        assert!(server.add_domain_record(&record).is_ok());
+    }
+
+    let mut client = common::TestClient::new(&server.get_host());
+    let res = client.login("admin", "password");
+    assert!(res.is_ok());
+
+    let c = client.get("/api/client");
+    assert!(c.is_ok());
+    let (code, body) = c.unwrap();
+    assert_eq!(code, 200);
+    let list = http_api_msg::api_msg_parse_client_list(&body);
+    assert!(list.is_ok());
+    let list = list.unwrap();
+    assert_eq!(list.len(), 1024);
+}
+
+#[test]
+fn test_rest_api_stats_top() {
+    let mut server = common::TestServer::new();
+    server.set_log_level(LogLevel::DEBUG);
+    assert!(server.start().is_ok());
+
+    let record = server.new_mock_domain_record();
+    for i in 0..1024 {
+        let mut record = record.clone();
+        if i < 512 {
+            record.domain = format!("a.com");
+            record.client = format!("192.168.1.1");
+        } else if i < 512 + 256 + 128 {
+            record.domain = format!("b.com");
+            record.client = format!("192.168.1.2");
+        } else {
+            record.domain = format!("c.com");
+            record.client = format!("192.168.1.3");
+        }
+        assert!(server.add_domain_record(&record).is_ok());
+    }
+
+    let mut client = common::TestClient::new(&server.get_host());
+    let res = client.login("admin", "password");
+    assert!(res.is_ok());
+
+    let c = client.get("/api/stats/top/client");
+    assert!(c.is_ok());
+    let (code, body) = c.unwrap();
+    assert_eq!(code, 200);
+    let list = http_api_msg::api_msg_parse_top_client_list(&body);
+    assert!(list.is_ok());
+    let list = list.unwrap();
+    assert_eq!(list.len(), 3);
+    assert_eq!(list[0].client_ip, "192.168.1.1");
+    assert_eq!(list[0].count, 512);
+    assert_eq!(list[2].client_ip, "192.168.1.3");
+    assert_eq!(list[2].count, 128);
+
+    let c = client.get("/api/stats/top/domain");
+    assert!(c.is_ok());
+    let (code, body) = c.unwrap();
+    assert_eq!(code, 200);
+    let list = http_api_msg::api_msg_parse_top_domain_list(&body);
+    assert!(list.is_ok());
+    let list = list.unwrap();
+    assert_eq!(list.len(), 3);
+    assert_eq!(list[0].domain, "a.com");
+    assert_eq!(list[0].count, 512);
+    assert_eq!(list[2].domain, "c.com");
+    assert_eq!(list[2].count, 128);
+}
+
+
+#[test]
+fn test_rest_api_stats_overview() {
+    let mut server = common::TestServer::new();
+    server.set_log_level(LogLevel::DEBUG);
+    server.enable_mock_server();
+    assert!(server.start().is_ok());
+
+    let mut client = common::TestClient::new(&server.get_host());
+    let res = client.login("admin", "password");
+    assert!(res.is_ok());
+
+    unsafe {
+        smartdns_ui::smartdns::smartdns_c::dns_stats.avg_time.avg_time = 22.0 as f32;
+        smartdns_ui::smartdns::smartdns_c::dns_stats.request.blocked_count = 10;
+        smartdns_ui::smartdns::smartdns_c::dns_stats.request.total = 15;
+    }
+
+    let c = client.get("/api/stats/overview");
+    assert!(c.is_ok());
+    let (code, body) = c.unwrap();
+    assert_eq!(code, 200);
+    let overview = http_api_msg::api_msg_parse_stats_overview(&body);
+    assert!(overview.is_ok());
+    let overview = overview.unwrap();
+    assert_eq!(overview.avg_query_time, 22.0 as f64);
+    assert_eq!(overview.cache_hit_rate, 0 as f64);
+    assert_eq!(overview.total_query_count, 15);
+    assert_eq!(overview.block_query_count, 10);
+}
+
+
+#[test]
+fn test_rest_api_get_hourly_query_count() {
+    let mut server = common::TestServer::new();
+    server.set_log_level(LogLevel::DEBUG);
+    assert!(server.start().is_ok());
+
+    let record = server.new_mock_domain_record();
+    for i in 0..1024 {
+        let mut record = record.clone();
+        record.domain = format!("{}.com", i);
+        record.client = format!("client-{}", i);
+        assert!(server.add_domain_record(&record).is_ok());
+    }
+
+    let mut client = common::TestClient::new(&server.get_host());
+    let res = client.login("admin", "password");
+    assert!(res.is_ok());
+
+    let c = client.get("/api/stats/hourly-query-count");
+    assert!(c.is_ok());
+    let (code, body) = c.unwrap();
+    assert_eq!(code, 200);
+    let list = http_api_msg::api_msg_parse_hourly_query_count(&body);
+    assert!(list.is_ok());
+    let list = list.unwrap();
+    assert_eq!(list.len(), 1);
+    assert_eq!(list[0].query_count, 1024);
+}

+ 138 - 8
src/dns_server.c

@@ -25,6 +25,7 @@
 #include "dns_cache.h"
 #include "dns_client.h"
 #include "dns_conf.h"
+#include "dns_stats.h"
 #include "dns_plugin.h"
 #include "fast_ping.h"
 #include "hashtable.h"
@@ -192,6 +193,7 @@ struct dns_server_post_context {
 	int skip_notify_count;
 	int select_all_best_ip;
 	int no_release_parent;
+	int is_cache_reply;
 };
 
 typedef enum dns_server_client_status {
@@ -342,6 +344,8 @@ struct dns_request {
 
 	int is_mdns_lookup;
 
+	int is_cache_reply;
+
 	struct dns_srv_records *srv_records;
 
 	atomic_t notified;
@@ -392,7 +396,8 @@ struct dns_request {
 
 	void *private_data;
 
-	uint64_t query_time;
+	uint64_t query_timestamp;
+	int query_time;
 };
 
 /* dns server data */
@@ -690,6 +695,7 @@ static int _dns_server_is_return_soa_qtype(struct dns_request *request, dns_type
 	if (rule_flag) {
 		flags = rule_flag->flags;
 		if (flags & DOMAIN_FLAG_ADDR_SOA) {
+			dns_stats_request_blocked_inc();
 			return 1;
 		}
 
@@ -701,6 +707,7 @@ static int _dns_server_is_return_soa_qtype(struct dns_request *request, dns_type
 		switch (qtype) {
 		case DNS_T_A:
 			if (flags & DOMAIN_FLAG_ADDR_IPV4_SOA) {
+				dns_stats_request_blocked_inc();
 				return 1;
 			}
 
@@ -711,6 +718,7 @@ static int _dns_server_is_return_soa_qtype(struct dns_request *request, dns_type
 			break;
 		case DNS_T_AAAA:
 			if (flags & DOMAIN_FLAG_ADDR_IPV6_SOA) {
+				dns_stats_request_blocked_inc();
 				return 1;
 			}
 
@@ -721,6 +729,7 @@ static int _dns_server_is_return_soa_qtype(struct dns_request *request, dns_type
 			break;
 		case DNS_T_HTTPS:
 			if (flags & DOMAIN_FLAG_ADDR_HTTPS_SOA) {
+				dns_stats_request_blocked_inc();
 				return 1;
 			}
 
@@ -853,6 +862,10 @@ static void _dns_server_audit_log(struct dns_server_post_context *context)
 	struct dns_request *request = context->request;
 	int has_soa = request->has_soa;
 
+	if (atomic_read(&request->notified) == 1) {
+		request->query_time = get_tick_count() - request->send_tick;
+	}
+
 	if (dns_audit == NULL || !dns_conf.audit_enable || context->do_audit == 0) {
 		return;
 	}
@@ -950,9 +963,8 @@ static void _dns_server_audit_log(struct dns_server_post_context *context)
 				 tm.min, tm.sec, tm.usec / 1000);
 	}
 
-	tlog_printf(dns_audit, "%s%s query %s, type %d, time %lums, speed: %.1fms, group %s, result %s\n", req_time,
-				req_host, request->domain, request->qtype, get_tick_count() - request->send_tick,
-				((float)request->ping_time) / 10,
+	tlog_printf(dns_audit, "%s%s query %s, type %d, time %dms, speed: %.1fms, group %s, result %s\n", req_time,
+				req_host, request->domain, request->qtype, request->query_time, ((float)request->ping_time) / 10,
 				request->dns_group_name[0] != '\0' ? request->dns_group_name : DNS_SERVER_GROUP_DEFAULT, req_result);
 }
 
@@ -2456,8 +2468,10 @@ static int _dns_server_reply_all_pending_list(struct dns_request *request, struc
 		req->dualstack_selection_has_ip = request->dualstack_selection_has_ip;
 		req->dualstack_selection_ping_time = request->dualstack_selection_ping_time;
 		req->ping_time = request->ping_time;
+		req->is_cache_reply = request->is_cache_reply;
 		_dns_server_get_answer(&context_pending);
 
+		context_pending.is_cache_reply = context->is_cache_reply;
 		context_pending.do_cache = 0;
 		context_pending.do_audit = context->do_audit;
 		context_pending.do_reply = context->do_reply;
@@ -2465,6 +2479,11 @@ static int _dns_server_reply_all_pending_list(struct dns_request *request, struc
 		context_pending.do_ipset = 0;
 		context_pending.reply_ttl = request->ip_ttl;
 		context_pending.no_release_parent = 0;
+
+		if (context_pending.is_cache_reply) {
+			dns_stats_cache_hit_inc();
+		}
+
 		_dns_server_reply_passthrough(&context_pending);
 
 		req->request_pending_list = NULL;
@@ -2644,7 +2663,6 @@ out:
 	context.skip_notify_count = 1;
 	context.select_all_best_ip = with_all_ips;
 	context.no_release_parent = 1;
-
 	_dns_request_post(&context);
 	return _dns_server_reply_all_pending_list(request, &context);
 }
@@ -2896,7 +2914,7 @@ static void _dns_server_request_release_complete(struct dns_request *request, in
 	list_del_init(&request->pending_list);
 	pthread_mutex_unlock(&server.request_pending_lock);
 
-	if (do_complete) {
+	if (do_complete && atomic_read(&request->plugin_complete_called) == 0) {
 		/* Select max hit ip address, and return to client */
 		_dns_server_select_possible_ipaddress(request);
 		_dns_server_complete_with_multi_ipaddress(request);
@@ -2925,6 +2943,13 @@ static void _dns_server_request_release_complete(struct dns_request *request, in
 	}
 	pthread_mutex_unlock(&request->ip_map_lock);
 
+	if (request->rcode == DNS_RC_NOERROR) {
+		dns_stats_request_success_inc();
+	}
+
+	if (request->conn) {
+		dns_stats_avg_time_add(request->query_time);
+	}
 	_dns_server_delete_request(request);
 }
 
@@ -2942,46 +2967,136 @@ static void _dns_server_request_get(struct dns_request *request)
 
 const struct sockaddr *dns_server_request_get_remote_addr(struct dns_request *request)
 {
+	if (request->conn == NULL) {
+		return NULL;
+	}
+
 	return &request->addr;
 }
 
 const struct sockaddr *dns_server_request_get_local_addr(struct dns_request *request)
 {
+	if (request == NULL) {
+		return NULL;
+	}
+
 	return (struct sockaddr *)&request->localaddr;
 }
 
 const char *dns_server_request_get_group_name(struct dns_request *request)
 {
+	if (request == NULL) {
+		return NULL;
+	}
+
 	return request->dns_group_name;
 }
 
 const char *dns_server_request_get_domain(struct dns_request *request)
 {
+	if (request == NULL) {
+		return NULL;
+	}
+
 	return request->domain;
 }
 
 int dns_server_request_get_qtype(struct dns_request *request)
 {
+	if (request == NULL) {
+		return 0;
+	}
+
 	return request->qtype;
 }
 
 int dns_server_request_get_qclass(struct dns_request *request)
 {
+	if (request == NULL) {
+		return 0;
+	}
+
 	return request->qclass;
 }
 
-uint64_t dns_server_request_get_query_time(struct dns_request *request)
+int dns_server_request_get_query_time(struct dns_request *request)
 {
+	if (request == NULL) {
+		return -1;
+	}
+
 	return request->query_time;
 }
 
+uint64_t dns_server_request_get_query_timestamp(struct dns_request *request)
+{
+	if (request == NULL) {
+		return 0;
+	}
+
+	return request->query_timestamp;
+}
+
+float dns_server_request_get_ping_time(struct dns_request *request)
+{
+	if (request == NULL) {
+		return 0;
+	}
+
+	return (float)request->ping_time / 10;
+}
+
+int dns_server_request_is_prefetch(struct dns_request *request)
+{
+	if (request == NULL) {
+		return 0;
+	}
+
+	return request->prefetch;
+}
+
+int dns_server_request_is_dualstack(struct dns_request *request)
+{
+	if (request == NULL) {
+		return 0;
+	}
+
+	return request->dualstack_selection_query;
+}
+
+int dns_server_request_is_blocked(struct dns_request *request)
+{
+	if (request == NULL) {
+		return 0;
+	}
+
+	return _dns_server_is_return_soa(request);
+}
+
+int dns_server_request_is_cached(struct dns_request *request)
+{
+	if (request == NULL) {
+		return 0;
+	}
+
+	return request->is_cache_reply;
+}
+
 int dns_server_request_get_id(struct dns_request *request)
 {
+	if (request == NULL) {
+		return 0;
+	}
+
 	return request->id;
 }
 
 int dns_server_request_get_rcode(struct dns_request *request)
 {
+	if (request == NULL) {
+		return DNS_RC_SERVFAIL;
+	}
+
 	return request->rcode;
 }
 
@@ -2997,11 +3112,19 @@ void dns_server_request_put(struct dns_request *request)
 
 void dns_server_request_set_private(struct dns_request *request, void *private_data)
 {
+	if (request == NULL) {
+		return;
+	}
+
 	request->private_data = private_data;
 }
 
 void *dns_server_request_get_private(struct dns_request *request)
 {
+	if (request == NULL) {
+		return NULL;
+	}
+
 	return request->private_data;
 }
 
@@ -3102,13 +3225,14 @@ static struct dns_request *_dns_server_new_request(void)
 	request->conf = dns_server_get_default_rule_group();
 	request->check_order_list = &dns_conf.default_check_orders;
 	request->response_mode = dns_conf.default_response_mode;
-	request->query_time = get_utc_time_ms();
+	request->query_timestamp = get_utc_time_ms();
 	INIT_LIST_HEAD(&request->list);
 	INIT_LIST_HEAD(&request->pending_list);
 	INIT_LIST_HEAD(&request->check_list);
 	hash_init(request->ip_map);
 	_dns_server_request_get(request);
 	atomic_add(1, &server.request_num);
+	dns_stats_request_inc();
 
 	return request;
 errout:
@@ -6282,11 +6406,14 @@ static int _dns_server_process_cache_packet(struct dns_request *request, struct
 		}
 	}
 
+	request->is_cache_reply = 1;
+	dns_stats_cache_hit_inc();
 	request->rcode = context.packet->head.rcode;
 	context.do_cache = 0;
 	context.do_ipset = do_ipset;
 	context.do_audit = 1;
 	context.do_reply = 1;
+	context.is_cache_reply = 1;
 	context.reply_ttl = _dns_server_get_expired_ttl_reply(request, dns_cache);
 	ret = _dns_server_reply_passthrough(&context);
 out:
@@ -7231,6 +7358,7 @@ static int _dns_server_recv(struct dns_server_conn_head *conn, unsigned char *in
 	_dns_server_request_set_client(request, conn);
 	_dns_server_request_set_client_addr(request, from, from_len);
 	_dns_server_request_set_id(request, packet->head.id);
+	dns_stats_request_from_client_inc();
 
 	if (_dns_server_parser_request(request, packet) != 0) {
 		tlog(TLOG_DEBUG, "parser request failed.");
@@ -8532,6 +8660,8 @@ static void _dns_server_period_run_second(void)
 	}
 
 	_dns_server_save_cache_to_file();
+
+	dns_stats_period_run_second();
 }
 
 static void _dns_server_period_run(unsigned int msec)

+ 13 - 1
src/dns_server.h

@@ -87,7 +87,19 @@ int dns_server_request_get_id(struct dns_request *request);
 
 int dns_server_request_get_rcode(struct dns_request *request);
 
-uint64_t dns_server_request_get_query_time(struct dns_request *request);
+uint64_t dns_server_request_get_query_timestamp(struct dns_request *request);
+
+int dns_server_request_get_query_time(struct dns_request *request);
+
+float dns_server_request_get_ping_time(struct dns_request *request);
+
+int dns_server_request_is_prefetch(struct dns_request *request);
+
+int dns_server_request_is_dualstack(struct dns_request *request);
+
+int dns_server_request_is_blocked(struct dns_request *request);
+
+int dns_server_request_is_cached(struct dns_request *request);
 
 void dns_server_request_get(struct dns_request *request);
 

+ 116 - 0
src/dns_stats.c

@@ -0,0 +1,116 @@
+/*************************************************************************
+ *
+ * Copyright (C) 2018-2024 Ruilin Peng (Nick) <[email protected]>.
+ *
+ * smartdns is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * smartdns is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ */
+
+#include "dns_stats.h"
+#include "stddef.h"
+#include "string.h"
+
+struct dns_stats dns_stats;
+
+#define SAMPLE_PERIOD 5
+
+void dns_stats_avg_time_update(void)
+{
+	uint64_t total = stats_read_and_set(&dns_stats.avg_time.total, 0);
+	uint64_t count = total >> 32;
+	uint64_t time = total & 0xFFFFFFFF;
+
+	if (count == 0) {
+		return;
+	}
+
+	float sample_avg = (float)time / count;
+
+	if (dns_stats.avg_time.avg_time == 0) {
+		dns_stats.avg_time.avg_time = sample_avg;
+	} else {
+		int base = 1000;
+		if (count > 100) {
+			count = 100;
+		}
+
+		float weight_new = (float)count / base;
+		float weight_prev = 1.0 - weight_new;
+
+		dns_stats.avg_time.avg_time = (dns_stats.avg_time.avg_time * weight_prev) + (sample_avg * weight_new);
+	}
+}
+
+void dns_stats_period_run_second(void)
+{
+	static int last_total = 0;
+	last_total++;
+	if (last_total % SAMPLE_PERIOD == 0) {
+		dns_stats_avg_time_update();
+		dns_stats_avg_time_get();
+	}
+}
+
+float dns_stats_avg_time_get(void)
+{
+	return dns_stats.avg_time.avg_time;
+}
+
+uint64_t dns_stats_request_total_get(void)
+{
+	return stats_read(&dns_stats.request.total);
+}
+
+uint64_t dns_stats_request_success_get(void)
+{
+	return stats_read(&dns_stats.request.success_count);
+}
+
+uint64_t dns_stats_request_from_client_get(void)
+{
+	return stats_read(&dns_stats.request.from_client_count);
+}
+
+uint64_t dns_stats_request_blocked_get(void)
+{
+	return stats_read(&dns_stats.request.blocked_count);
+}
+
+uint64_t dns_stats_cache_hit_get(void)
+{
+	return stats_read(&dns_stats.cache.hit_count);
+}
+
+float dns_stats_cache_hit_rate_get(void)
+{
+	uint64_t total = stats_read(&dns_stats.request.from_client_count);
+	uint64_t hit = stats_read(&dns_stats.cache.hit_count);
+
+	if (total == 0) {
+		return 0;
+	}
+
+	return (float)(hit * 100) / total;
+}
+
+int dns_stats_init(void)
+{
+	memset(&dns_stats, 0, sizeof(dns_stats));
+	return 0;
+}
+
+void dns_stats_exit(void)
+{	
+	memset(&dns_stats, 0, sizeof(dns_stats));
+	return;
+}

+ 144 - 0
src/dns_stats.h

@@ -0,0 +1,144 @@
+/*************************************************************************
+ *
+ * Copyright (C) 2018-2024 Ruilin Peng (Nick) <[email protected]>.
+ *
+ * smartdns is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * smartdns is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ */
+
+#ifndef SMART_DNS_STATS_H
+#define SMART_DNS_STATS_H
+
+#include "atomic.h"
+#include <stdint.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif /*__cplusplus */
+
+struct dns_request_avg_time {
+	uint64_t total; /* Hight 4 bytes, count, Low 4 bytes time*/
+	float avg_time;
+};
+
+struct dns_request_stats {
+	uint64_t total;
+	uint64_t success_count;
+	uint64_t from_client_count;
+	uint64_t blocked_count;
+};
+
+struct dns_cache_stats {
+	uint64_t hit_count;
+};
+
+struct dns_stats {
+	struct dns_request_stats request;
+	struct dns_cache_stats cache;
+	struct dns_request_avg_time avg_time;
+};
+
+extern struct dns_stats dns_stats;
+
+static inline uint64_t stats_read(const uint64_t *s)
+{
+	return READ_ONCE((*s));
+}
+
+static inline uint64_t stats_read_and_set(uint64_t *s, uint64_t v)
+{
+	return __sync_lock_test_and_set(s, v);
+}
+
+static inline void stats_set(uint64_t *s, uint64_t v)
+{
+	*s = v;
+}
+
+static inline void stats_add(uint64_t *s, uint64_t v)
+{
+	(void)__sync_add_and_fetch(s, v);
+}
+
+static inline void stats_inc(uint64_t *s)
+{
+	(void)__sync_add_and_fetch(s, 1);
+}
+
+static inline void stats_sub(uint64_t *s, uint64_t v)
+{
+	(void)__sync_sub_and_fetch(s, v);
+}
+
+static inline void stats_dec(uint64_t *s)
+{
+	(void)__sync_sub_and_fetch(s, 1);
+}
+
+static inline void dns_stats_request_inc(void)
+{
+	stats_inc(&dns_stats.request.total);
+}
+
+static inline void dns_stats_request_success_inc(void)
+{
+	stats_inc(&dns_stats.request.success_count);
+}
+
+static inline void dns_stats_request_from_client_inc(void)
+{
+	stats_inc(&dns_stats.request.from_client_count);
+}
+
+static inline void dns_stats_request_blocked_inc(void)
+{
+	stats_inc(&dns_stats.request.blocked_count);
+}
+
+static inline void dns_stats_cache_hit_inc(void)
+{
+	stats_inc(&dns_stats.cache.hit_count);
+}
+
+static inline void dns_stats_avg_time_add(uint64_t time)
+{
+	uint64_t total = (uint64_t)1 << 32 | time; 
+	stats_add(&dns_stats.avg_time.total, total);
+}
+
+float dns_stats_avg_time_get(void);
+
+uint64_t dns_stats_request_total_get(void);
+
+uint64_t dns_stats_request_success_get(void);
+
+uint64_t dns_stats_request_from_client_get(void);
+
+uint64_t dns_stats_request_blocked_get(void);
+
+uint64_t dns_stats_cache_hit_get(void);
+
+float dns_stats_cache_hit_rate_get(void);
+
+void dns_stats_avg_time_update(void);
+
+void dns_stats_period_run_second(void);
+
+int dns_stats_init(void);
+
+void dns_stats_exit(void);
+
+#ifdef __cplusplus
+}
+#endif /*__cplusplus */
+#endif

+ 12 - 1
src/main.c

@@ -17,8 +17,19 @@
  */
 
 #include "smartdns.h"
+#include <errno.h>
+#include <stdio.h>
+#include <string.h>
 
 int main(int argc, char *argv[])
 {
-    return smartdns_main(argc, argv);
+	const char *smartdns_workdir = getenv("SMARTDNS_WORKDIR");
+	if (smartdns_workdir != NULL) {
+		if (chdir(smartdns_workdir) != 0) {
+			fprintf(stderr, "chdir to %s failed: %s\n", smartdns_workdir, strerror(errno));
+			return 1;
+		}
+	}
+
+	return smartdns_main(argc, argv);
 }

+ 7 - 0
src/smartdns.c

@@ -619,6 +619,12 @@ static int _smartdns_init(void)
 		tlog(TLOG_ERROR, "add proxy servers failed.");
 	}
 
+	ret = dns_stats_init();
+	if (ret != 0) {
+		tlog(TLOG_ERROR, "start dns stats failed.\n");
+		goto errout;
+	}
+
 	ret = dns_server_init();
 	if (ret != 0) {
 		tlog(TLOG_ERROR, "start dns server failed.\n");
@@ -661,6 +667,7 @@ static void _smartdns_exit(void)
 	proxy_exit();
 	fast_ping_exit();
 	dns_server_exit();
+	dns_stats_exit();
 	_smartdns_destroy_ssl();
 	dns_timer_destroy();
 	tlog_exit();

+ 1 - 0
src/smartdns.h

@@ -26,6 +26,7 @@
 #include "dns_plugin.h"
 #include "dns_server.h"
 #include "fast_ping.h"
+#include "dns_stats.h"
 #include "util.h"
 
 #ifdef __cplusplus