Browse Source

plugin: add plugin version check.

Nick Peng 1 month ago
parent
commit
8d197f1435

+ 5 - 0
plugin/demo/demo.c

@@ -47,3 +47,8 @@ int dns_plugin_exit(struct dns_plugin *plugin)
 	smartdns_operations_unregister(&demo_ops);
 	return 0;
 }
+
+int dns_plugin_api_version(void)
+{
+	return SMARTDNS_PLUGIN_API_VERSION;
+}

+ 15 - 1
plugin/smartdns-ui/src/http_server.rs

@@ -78,7 +78,6 @@ pub struct HttpServerConfig {
 
 impl HttpServerConfig {
     pub fn new() -> Self {
-
         let host_ip = if utils::is_ipv6_supported() {
             HTTP_SERVER_DEFAULT_IPV6.to_string()
         } else {
@@ -415,6 +414,21 @@ impl HttpServer {
         Err("Plugin is not set".into())
     }
 
+    pub fn is_https_server(&self) -> bool {
+        let http_ip = self.get_conf().http_ip;
+        if http_ip.parse::<url::Url>().is_err() {
+            return false;
+        }
+
+        let binding = http_ip.parse::<url::Url>().unwrap();
+        let scheme = binding.scheme();
+        if scheme == "https" {
+            return true;
+        }
+
+        false
+    }
+
     pub fn get_data_server(&self) -> Arc<DataServer> {
         self.get_plugin().unwrap().get_data_server()
     }

+ 12 - 2
plugin/smartdns-ui/src/http_server_api.rs

@@ -234,6 +234,7 @@ impl API {
         _param: APIRouteParam,
         req: Request<body::Incoming>,
     ) -> Result<Response<Full<Bytes>>, HttpError> {
+        let is_https = this.is_https_server();
         let token = HttpServer::get_token_from_header(&req)?;
         let unauth_response =
             || API::response_error(StatusCode::UNAUTHORIZED, "Incorrect username or password.");
@@ -269,11 +270,15 @@ impl API {
 
         let cookie_token = format!("Bearer {}", token_new.token);
         let token_urlencode = urlencoding::encode(cookie_token.as_str());
-        let cookie = format!(
+        let mut cookie = format!(
             "token={}; HttpOnly; Max-Age={}; Path={}",
             token_urlencode, token_new.expire, REST_API_PATH
         );
 
+        if is_https && conf.enable_cors {
+            cookie.push_str("; SameSite=None; Secure");
+        }
+
         resp.as_mut()
             .unwrap()
             .headers_mut()
@@ -294,6 +299,7 @@ impl API {
         _param: APIRouteParam,
         req: Request<body::Incoming>,
     ) -> Result<Response<Full<Bytes>>, HttpError> {
+        let is_https = this.is_https_server();
         let whole_body = String::from_utf8(req.into_body().collect().await?.to_bytes().into())?;
         let userinfo = api_msg_parse_auth(whole_body.as_str());
         if let Err(e) = userinfo {
@@ -335,11 +341,15 @@ impl API {
 
         let cookie_token = format!("Bearer {}", token.token);
         let token_urlencode = urlencoding::encode(cookie_token.as_str());
-        let cookie = format!(
+        let mut cookie = format!(
             "token={}; HttpOnly; Max-Age={}; Path={}",
             token_urlencode, token.expire, REST_API_PATH
         );
 
+        if is_https && conf.enable_cors {
+            cookie.push_str("; SameSite=None; Secure");
+        }
+
         resp.as_mut()
             .unwrap()
             .headers_mut()

+ 5 - 0
plugin/smartdns-ui/src/smartdns.rs

@@ -389,6 +389,11 @@ extern "C" fn dns_plugin_exit(_plugin: *mut smartdns_c::dns_plugin) -> i32 {
     return 0;
 }
 
+#[no_mangle]
+extern "C" fn dns_plugin_api_version() -> u32 {
+    smartdns_c::SMARTDNS_PLUGIN_API_VERSION
+}
+
 pub trait DnsRequest: Send + Sync {
     fn get_group_name(&self) -> String;
     fn get_domain(&self) -> String;

+ 33 - 1
src/dns_plugin.c

@@ -28,10 +28,10 @@
 #include "smartdns/util.h"
 #include <dlfcn.h>
 #include <limits.h>
+#include <pthread.h>
 #include <stdio.h>
 #include <stdlib.h>
 #include <string.h>
-#include <pthread.h>
 
 struct dns_plugin_ops {
 	struct list_head list;
@@ -213,8 +213,10 @@ static struct dns_plugin *_dns_plugin_get(const char *plugin_file)
 static int _dns_plugin_load_library(struct dns_plugin *plugin)
 {
 	void *handle = NULL;
+	dns_plugin_api_version_func version_func = NULL;
 	dns_plugin_init_func init_func = NULL;
 	dns_plugin_exit_func exit_func = NULL;
+	unsigned int api_version = 0;
 
 	tlog(TLOG_DEBUG, "load plugin %s", plugin->file);
 
@@ -224,6 +226,14 @@ static int _dns_plugin_load_library(struct dns_plugin *plugin)
 		return -1;
 	}
 
+	version_func = (dns_plugin_api_version_func)dlsym(handle, DNS_PLUGIN_API_VERSION_FUNC);
+	if (!version_func) {
+		tlog(TLOG_ERROR,
+			 "plugin %s has no api version function, maybe an old version plugin, please download latest version.",
+			 plugin->file);
+		goto errout;
+	}
+
 	init_func = (dns_plugin_init_func)dlsym(handle, DNS_PLUGIN_INIT_FUNC);
 	if (!init_func) {
 		tlog(TLOG_ERROR, "load plugin failed: %s", dlerror());
@@ -238,6 +248,28 @@ static int _dns_plugin_load_library(struct dns_plugin *plugin)
 		goto errout;
 	}
 
+	api_version = version_func();
+	if (SMARTDNS_PLUGIN_API_VERSION_MAJOR(api_version) !=
+		SMARTDNS_PLUGIN_API_VERSION_MAJOR(SMARTDNS_PLUGIN_API_VERSION)) {
+		tlog(TLOG_ERROR,
+			 "plugin %s api version %u.%u not compatible with smartdns api version %u.%u, please download matching "
+			 "version.",
+			 plugin->file, SMARTDNS_PLUGIN_API_VERSION_MAJOR(api_version),
+			 SMARTDNS_PLUGIN_API_VERSION_MINOR(api_version),
+			 SMARTDNS_PLUGIN_API_VERSION_MAJOR(SMARTDNS_PLUGIN_API_VERSION),
+			 SMARTDNS_PLUGIN_API_VERSION_MINOR(SMARTDNS_PLUGIN_API_VERSION));
+		goto errout;
+	} else if (SMARTDNS_PLUGIN_API_VERSION_MINOR(api_version) >
+			   SMARTDNS_PLUGIN_API_VERSION_MINOR(SMARTDNS_PLUGIN_API_VERSION)) {
+		tlog(TLOG_ERROR,
+			 "plugin %s api version %u.%u is newer than smartdns api version %u.%u, please download matching version.",
+			 plugin->file, SMARTDNS_PLUGIN_API_VERSION_MAJOR(api_version),
+			 SMARTDNS_PLUGIN_API_VERSION_MINOR(api_version),
+			 SMARTDNS_PLUGIN_API_VERSION_MAJOR(SMARTDNS_PLUGIN_API_VERSION),
+			 SMARTDNS_PLUGIN_API_VERSION_MINOR(SMARTDNS_PLUGIN_API_VERSION));
+		goto errout;
+	}
+
 	conf_getopt_reset();
 	int ret = init_func(plugin);
 	conf_getopt_reset();

+ 5 - 0
src/include/smartdns/dns_plugin.h

@@ -28,6 +28,10 @@ extern "C" {
 
 #define DNS_PLUGIN_INIT_FUNC "dns_plugin_init"
 #define DNS_PLUGIN_EXIT_FUNC "dns_plugin_exit"
+#define DNS_PLUGIN_API_VERSION_FUNC "dns_plugin_api_version"
+#define SMARTDNS_PLUGIN_API_VERSION 0x00000101
+#define SMARTDNS_PLUGIN_API_VERSION_MAJOR(v) ((v >> 8) & 0xFFFFFF)
+#define SMARTDNS_PLUGIN_API_VERSION_MINOR(v) (v & 0xFF)
 
 struct dns_plugin;
 struct dns_plugin_ops;
@@ -35,6 +39,7 @@ struct dns_request;
 
 typedef int (*dns_plugin_init_func)(struct dns_plugin *plugin);
 typedef int (*dns_plugin_exit_func)(struct dns_plugin *plugin);
+typedef unsigned int (*dns_plugin_api_version_func)(void);
 
 struct dns_plugin;
 int dns_plugin_init(struct dns_plugin *plugin);

+ 69 - 30
src/tlog.c

@@ -518,6 +518,64 @@ static int _tlog_need_drop(struct tlog_log *log)
     return ret;
 }
 
+static int _tlog_write_screen(struct tlog_log *log, struct tlog_loginfo *info, const char *buff, int bufflen)
+{
+    if (bufflen <= 0) {
+        return 0;
+    }
+
+    if (log->logscreen == 0) {
+        return 0;
+    }
+
+    if (info == NULL) {
+        return write(STDOUT_FILENO, buff, bufflen);;
+    }
+
+    return tlog_stdout_with_color(info->level, buff, bufflen);
+}
+
+static int _tlog_write_output_func(struct tlog_log *log, char *buff, int bufflen)
+{
+    if (log->logscreen && log != tlog.root) {
+        _tlog_write_screen(log, NULL, buff, bufflen);
+    }
+
+    if (log->output_func == NULL) {
+        return -1;
+    }
+
+    return log->output_func(log, buff, bufflen);
+}
+
+static void _tlog_output_warning(void)
+{
+    static int printed = 0;
+    if (printed) {
+        return;
+    }
+    printed = 1;
+    tlog_log *root = tlog.root;
+
+    const char warning_msg[] = ""
+    "TLOG ERROR: \n"
+    "  Do not call the tlog output function from within a registered tlog log output callback function.\n"
+    "  Recursively calling the log output function will cause tlog to fail to output logs and deadlock.\n";
+
+    if (root->logcount > 0 && root->logsize > 0 && root->logfile[0] != 0) {
+        int fd = open(root->logfile, O_APPEND | O_CREAT | O_WRONLY | O_CLOEXEC, root->file_perm);
+        if (fd >= 0) {
+            write(fd, warning_msg, sizeof(warning_msg) - 1);
+            close(fd);
+        }
+    }
+
+    /* if open log file failed, print to stderr */
+    fprintf(stderr, "\033[31;1m%s\033[0m\n", warning_msg);
+
+    return;
+}
+
 static int _tlog_vprintf(struct tlog_log *log, vprint_callback print_callback, void *userptr, const char *format, va_list ap)
 {
     int len;
@@ -554,6 +612,17 @@ static int _tlog_vprintf(struct tlog_log *log, vprint_callback print_callback, v
         buff[len - 5] = '.';
     }
 
+    /* 
+     Output log from tlog_worker thread context? this may crash from upper-level function.
+     Try call printf output log. 
+     */
+    if (tlog.tid == pthread_self()) {
+        _tlog_output_warning();
+        vprintf(format, ap);
+        printf("\n");
+        return -1;
+    }
+
     pthread_mutex_lock(&tlog.lock);
     do {
         if (log->end == log->start) {
@@ -1221,23 +1290,6 @@ static void _tlog_get_log_name_dir(struct tlog_log *log)
     pthread_mutex_unlock(&tlog.lock);
 }
 
-static int _tlog_write_screen(struct tlog_log *log, struct tlog_loginfo *info, const char *buff, int bufflen)
-{
-    if (bufflen <= 0) {
-        return 0;
-    }
-
-    if (log->logscreen == 0) {
-        return 0;
-    }
-
-    if (info == NULL) {
-        return write(STDOUT_FILENO, buff, bufflen);;
-    }
-
-    return tlog_stdout_with_color(info->level, buff, bufflen);
-}
-
 static int _tlog_write(struct tlog_log *log, const char *buff, int bufflen)
 {
     int len;
@@ -1516,19 +1568,6 @@ static void _tlog_wakeup_waiters(struct tlog_log *log)
     pthread_mutex_unlock(&log->lock);
 }
 
-static int _tlog_write_output_func(struct tlog_log *log, char *buff, int bufflen)
-{
-    if (log->logscreen && log != tlog.root) {
-        _tlog_write_screen(log, NULL, buff, bufflen);
-    }
-
-    if (log->output_func == NULL) {
-        return -1;
-    }
-
-    return log->output_func(log, buff, bufflen);
-}
-
 static void _tlog_write_one_segment_log(struct tlog_log *log, char *buff, int bufflen)
 {
     struct tlog_segment_head *segment_head = NULL;