Browse Source

webUI: add audit log API.

Nick Peng 1 month ago
parent
commit
762a7d8900

+ 17 - 0
plugin/smartdns-ui/src/data_server.rs

@@ -23,6 +23,8 @@ use crate::dns_log;
 use crate::plugin::SmartdnsPlugin;
 use crate::server_log::ServerLog;
 use crate::server_log::ServerLogMsg;
+use crate::server_log::ServerAuditLog;
+use crate::server_log::ServerAuditLogMsg;
 use crate::smartdns;
 use crate::smartdns::*;
 use crate::utils;
@@ -238,6 +240,10 @@ impl DataServerControl {
     pub fn server_log(&self, level: LogLevel, msg: &str, msg_len: i32) {
         self.data_server.server_log(level, msg, msg_len);
     }
+
+    pub fn server_audit_log(&self, msg: &str, msg_len: i32) {
+        self.data_server.server_audit_log(msg, msg_len);
+    }
 }
 
 impl Drop for DataServerControl {
@@ -256,6 +262,7 @@ pub struct DataServer {
     disable_handle_request: AtomicBool,
     stat: Arc<DataStats>,
     server_log: ServerLog,
+    server_audit_log: ServerAuditLog,
     plugin: Mutex<Weak<SmartdnsPlugin>>,
     whois: whois::WhoIs,
     startup_timestamp: u64,
@@ -275,6 +282,7 @@ impl DataServer {
             db: db.clone(),
             stat: DataStats::new(db, conf.clone()),
             server_log: ServerLog::new(),
+            server_audit_log: ServerAuditLog::new(),
             plugin: Mutex::new(Weak::new()),
             whois: whois::WhoIs::new(),
             startup_timestamp: get_utc_time_ms(),
@@ -622,6 +630,15 @@ impl DataServer {
     pub fn server_log(&self, level: LogLevel, msg: &str, msg_len: i32) {
         self.server_log.dispatch_log(level, msg, msg_len);
     }
+    
+    pub async fn get_audit_log_stream(&self) -> mpsc::Receiver<ServerAuditLogMsg> {
+        return self.server_audit_log.get_audit_log_stream().await;
+    }
+
+    pub fn server_audit_log(&self, msg: &str, msg_len: i32) {
+        self.server_audit_log
+            .dispatch_audit_log(msg, msg_len);
+    }
 
     fn server_check(&self) {
         let free_disk_space = self.get_free_disk_space();

+ 22 - 0
plugin/smartdns-ui/src/http_server_api.rs

@@ -92,6 +92,7 @@ impl API {
         api.register(Method::GET, "/api/log/stream", true, APIRoute!(API::api_log_stream));
         api.register(Method::PUT, "/api/log/level", true, APIRoute!(API::api_log_set_level));
         api.register(Method::GET, "/api/log/level", true, APIRoute!(API::api_log_get_level));
+        api.register(Method::GET, "/api/log/audit/stream", true, APIRoute!(API::api_audit_log_stream));
         api.register(Method::GET, "/api/server/version", false, APIRoute!(API::api_server_version));
         api.register(Method::GET, "/api/upstream-server", true, APIRoute!(API::api_upstream_server_get_list));
         api.register(Method::GET, "/api/config/settings", true, APIRoute!(API::api_config_get_settings));
@@ -833,6 +834,27 @@ impl API {
         }
     }
 
+    async fn api_audit_log_stream(
+        this: Arc<HttpServer>,
+        _param: APIRouteParam,
+        mut req: Request<body::Incoming>,
+    ) -> Result<Response<Full<Bytes>>, HttpError> {
+        if hyper_tungstenite::is_upgrade_request(&req) {
+            let (response, websocket) = hyper_tungstenite::upgrade(&mut req, None)
+                .map_err(|e| HttpError::new(StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
+
+            tokio::spawn(async move {
+                if let Err(e) = http_server_stream::serve_audit_log_stream(this, websocket).await {
+                    dns_log!(LogLevel::DEBUG, "Error in websocket connection: {e}");
+                }
+            });
+
+            Ok(response)
+        } else {
+            return API::response_error(StatusCode::BAD_REQUEST, "Need websocket upgrade.");
+        }
+    }
+
     async fn api_log_set_level(
         this: Arc<HttpServer>,
         _param: APIRouteParam,

+ 82 - 0
plugin/smartdns-ui/src/http_server_stream.rs

@@ -172,6 +172,88 @@ pub async fn serve_log_stream(
     Ok(())
 }
 
+pub async fn serve_audit_log_stream(
+    http_server: Arc<HttpServer>,
+    websocket: HyperWebsocket,
+) -> Result<(), Error> {
+    let mut websocket = websocket.await?;
+    let mut is_pause = false;
+
+    let data_server = http_server.get_data_server();
+    let mut log_stream = data_server.get_audit_log_stream().await;
+
+    loop {
+        tokio::select! {
+            msg = log_stream.recv() => {
+                if is_pause {
+                    continue;
+                }
+
+                match msg {
+                    Some(msg) => {
+                        let mut binary_msg = Vec::with_capacity(1 + msg.msg.len());
+                        binary_msg.push(0);
+                        binary_msg.extend_from_slice(msg.msg.as_bytes());
+                        let msg = Message::Binary(binary_msg.into());
+                        websocket.send(msg).await?;
+                    }
+                    None => {
+                        websocket.send(Message::Close(None)).await?;
+                        break;
+                    }
+                }
+            }
+
+            msg = websocket.next() => {
+                let message = msg.ok_or("websocket closed")??;
+                match message {
+                    Message::Text(_msg) => {}
+                    Message::Binary(msg) => {
+                        if msg.len() == 0 {
+                            continue;
+                        }
+
+                        let msg_type = msg[0];
+                        match msg_type {
+                            LOG_CONTROL_MESSAGE_TYPE => {
+                                if msg.len() < 2 {
+                                    continue;
+                                }
+                                let control_type = msg[1];
+                                match control_type {
+                                    LOG_CONTROL_PAUSE => {
+                                        is_pause = true;
+                                        continue;
+                                    }
+                                    LOG_CONTROL_RESUME => {
+                                        is_pause = false;
+                                        continue;
+                                    }
+                                    _ => {
+                                        continue;
+                                    }
+                                }
+                            }
+                            _ => {}
+                        }
+                    }
+                    Message::Ping(_msg) => {}
+                    Message::Pong(_msg) => {}
+                    Message::Close(_msg) => {
+                        websocket.send(Message::Close(None)).await?;
+                        break;
+                    }
+                    Message::Frame(_msg) => {
+                        unreachable!();
+                    }
+                }
+            }
+        }
+    }
+
+    Ok(())
+}
+
 pub async fn serve_metrics(
     data_server: Arc<DataServer>,
     websocket: HyperWebsocket,

+ 8 - 0
plugin/smartdns-ui/src/plugin.rs

@@ -180,6 +180,10 @@ impl SmartdnsPlugin {
     pub fn server_log(&self, level: LogLevel, msg: &str, msg_len: i32) {
         self.data_server_ctl.server_log(level, msg, msg_len);
     }
+
+    pub fn server_audit_log(&self, msg: &str, msg_len: i32) {
+        self.data_server_ctl.server_audit_log(msg, msg_len);
+    }
 }
 
 impl Drop for SmartdnsPlugin {
@@ -215,6 +219,10 @@ impl SmartdnsOperations for SmartdnsPluginImpl {
         self.plugin.server_log(level, msg, msg_len);
     }
 
+    fn server_audit_log(&self, msg: &str, msg_len: i32) {
+        self.plugin.server_audit_log(msg, msg_len);
+    }
+
     fn server_init(&mut self, args: &Vec<String>) -> Result<(), Box<dyn Error>> {
         self.plugin.start(args)
     }

+ 60 - 0
plugin/smartdns-ui/src/server_log.rs

@@ -80,3 +80,63 @@ impl ServerLog {
         }
     }
 }
+
+
+#[derive(Clone)]
+pub struct ServerAuditLogMsg {
+    pub msg: String,
+    pub len: i32,
+}
+
+pub struct ServerAuditLog {
+    streams: RwLock<Vec<mpsc::Sender<ServerAuditLogMsg>>>,
+}
+
+impl ServerAuditLog {
+    pub fn new() -> Self {
+        ServerAuditLog {
+            streams: RwLock::new(Vec::new()),
+        }
+    }
+
+    pub async fn get_audit_log_stream(&self) -> mpsc::Receiver<ServerAuditLogMsg> {
+        let (tx, rx) = mpsc::channel(4096);
+        self.streams.write().await.push(tx);
+        rx
+    }
+
+    pub fn dispatch_audit_log(&self, msg: &str, len: i32) {
+        let mut remove_list = Vec::new();
+
+        {
+            let streams = self.streams.blocking_read();
+            if streams.len() == 0 {
+                return;
+            }
+
+            let msg = ServerAuditLogMsg {
+                msg: msg.to_string(),
+                len,
+            };
+
+            for (i, stream) in streams.iter().enumerate() {
+                let ret = stream.try_send(msg.clone());
+                if let Err(e) = ret {
+                    match e {
+                        mpsc::error::TrySendError::Full(_) => {}
+                        mpsc::error::TrySendError::Closed(_) => {
+                            remove_list.push(i);
+                        }
+                    }
+                }
+            }
+        }
+
+        if remove_list.len() > 0 {
+            let mut streams = self.streams.blocking_write();
+            for i in remove_list.iter().rev() {
+                streams.remove(*i);
+            }
+        }
+    }
+}

+ 21 - 0
plugin/smartdns-ui/src/smartdns.rs

@@ -297,6 +297,7 @@ static SMARTDNS_OPS: smartdns_c::smartdns_operations = smartdns_c::smartdns_oper
     server_recv: None,
     server_query_complete: Some(dns_request_complete),
     server_log: Some(dns_server_log),
+    server_audit_log: Some(dns_server_audit_log),
 };
 
 #[no_mangle]
@@ -338,6 +339,25 @@ extern "C" fn dns_server_log(
     }
 }
 
+#[no_mangle]
+extern "C" fn dns_server_audit_log(msg: *const c_char, msg_len: i32) {
+    unsafe {
+        let plugin_addr = std::ptr::addr_of_mut!(PLUGIN);
+        let ops = (*plugin_addr).ops.as_ref();
+        if let None = ops {
+            return;
+        }
+
+        let raw_msg = std::slice::from_raw_parts(msg as *const u8, msg_len as usize + 1);
+        let msg = std::ffi::CStr::from_bytes_with_nul_unchecked(raw_msg)
+            .to_string_lossy()
+            .into_owned();
+
+        let ops = ops.unwrap();
+        ops.server_audit_log(msg.as_str(), msg_len as i32);
+    }
+}
+
 #[no_mangle]
 extern "C" fn dns_plugin_init(plugin: *mut smartdns_c::dns_plugin) -> i32 {
     unsafe {
@@ -728,6 +748,7 @@ unsafe impl Send for DnsUpstreamServer {}
 pub trait SmartdnsOperations {
     fn server_query_complete(&self, request: Box<dyn DnsRequest>);
     fn server_log(&self, level: LogLevel, msg: &str, msg_len: i32);
+    fn server_audit_log(&self, msg: &str, msg_len: i32);
     fn server_init(&mut self, args: &Vec<String>) -> Result<(), Box<dyn Error>>;
     fn server_exit(&mut self);
 }

+ 18 - 0
plugin/smartdns-ui/tests/restapi_test.rs

@@ -299,6 +299,24 @@ fn test_rest_api_get_domain() {
     assert_eq!(result[0].domain, "100.com");
 }
 
+#[test]
+fn test_rest_api_audit_log_stream() {
+    let mut server = common::TestServer::new();
+    server.set_log_level(LogLevel::DEBUG);
+    assert!(server.start().is_ok());
+
+    let mut client = common::TestClient::new(&server.get_host());
+    let res = client.login("admin", "password");
+    assert!(res.is_ok());
+    let socket = client.websocket("/api/log/audit/stream");
+    assert!(socket.is_ok());
+    let mut socket = socket.unwrap();
+
+    _ = socket.send(tungstenite::Message::Text("aaaa".to_string()));
+    _ = socket.close(None);
+}
+
+
 #[test]
 fn test_rest_api_get_by_id() {
     let mut server = common::TestServer::new();

+ 22 - 0
src/dns_plugin.c

@@ -132,6 +132,28 @@ void smartdns_plugin_func_server_log_callback(smartdns_log_level level, const ch
 	return;
 }
 
+void smartdns_plugin_func_server_audit_log_callback(const char *msg, int msg_len)
+{
+	struct dns_plugin_ops *chain = NULL;
+
+	if (unlikely(is_plugin_init == 0)) {
+		return;
+	}
+
+	pthread_rwlock_rdlock(&plugins.lock);
+	list_for_each_entry(chain, &plugins.list, list)
+	{
+		if (!chain->ops.server_audit_log) {
+			continue;
+		}
+
+		chain->ops.server_audit_log(msg, msg_len);
+	}
+	pthread_rwlock_unlock(&plugins.lock);
+
+	return;
+}
+
 int smartdns_operations_register(const struct smartdns_operations *operations)
 {
 	struct dns_plugin_ops *chain = NULL;

+ 15 - 3
src/dns_server/audit.c

@@ -18,6 +18,9 @@
 
 #include "audit.h"
 #include "dns_server.h"
+
+#include "smartdns/dns_plugin.h"
+
 #include <syslog.h>
 
 static tlog_log *dns_audit;
@@ -154,6 +157,17 @@ static int _dns_server_audit_syslog(struct tlog_log *log, const char *buff, int
 	return 0;
 }
 
+static int _dns_server_audit_output_callback(struct tlog_log *log, const char *buff, int bufflen)
+{
+	smartdns_plugin_func_server_audit_log_callback(buff, bufflen);
+
+	if (dns_conf.audit_syslog) {
+		return _dns_server_audit_syslog(log, buff, bufflen);
+	}
+
+	return tlog_write(log, buff, bufflen);
+}
+
 int _dns_server_audit_init(void)
 {
 	char *audit_file = SMARTDNS_AUDIT_FILE;
@@ -176,9 +190,7 @@ int _dns_server_audit_init(void)
 		return -1;
 	}
 
-	if (dns_conf.audit_syslog) {
-		tlog_reg_output_func(dns_audit, _dns_server_audit_syslog);
-	}
+	tlog_reg_output_func(dns_audit, _dns_server_audit_output_callback);
 
 	if (dns_conf.audit_file_mode > 0) {
 		tlog_set_permission(dns_audit, dns_conf.audit_file_mode, dns_conf.audit_file_mode);

+ 4 - 0
src/include/smartdns/dns_plugin.h

@@ -83,6 +83,8 @@ void smartdns_plugin_func_server_complete_request(struct dns_request *request);
 
 void smartdns_plugin_func_server_log_callback(smartdns_log_level level, const char *msg, int msg_len);
 
+void smartdns_plugin_func_server_audit_log_callback(const char *msg, int msg_len);
+
 struct smartdns_operations {
 	int (*server_recv)(struct dns_packet *packet, unsigned char *inpacket, int inpacket_len,
 					   struct sockaddr_storage *local, socklen_t local_len, struct sockaddr_storage *from,
@@ -90,6 +92,8 @@ struct smartdns_operations {
 	void (*server_query_complete)(struct dns_request *request);
 
 	void (*server_log)(smartdns_log_level level, const char *msg, int msg_len);
+
+	void (*server_audit_log)(const char *msg, int msg_len);
 };
 
 int smartdns_operations_register(const struct smartdns_operations *operations);