Explorar el Código

(jsrt) opt.: script load strategy

lollipopkit🏳️‍⚧️ hace 5 meses
padre
commit
155f67e960

+ 2 - 5
.env.example

@@ -82,8 +82,5 @@
 # JS_MAX_VM_COUNT=
 # 运行超时时间(单位:秒,默认:5)
 # JS_SCRIPT_TIMEOUT=
-# 预处理脚本路径(默认:scripts/pre_process.js)
-# JS_PREPROCESS_SCRIPT_PATH=
-# 后处理脚本路径(默认:scripts/post_process.js)
-# JS_POSTPROCESS_SCRIPT_PATH=
-
+# 脚本文件夹(默认:scripts/)
+# JS_SCRIPT_PATH=

+ 1 - 1
docker-compose.yml

@@ -11,6 +11,7 @@ services:
     volumes:
       - ./data:/data
       - ./logs:/app/logs
+      - ${JS_SCRIPT_DIR:-./scripts}:/app/scripts
     environment:
       - SQL_DSN=root:123456@tcp(mysql:3306)/new-api  # Point to the mysql service
       - REDIS_CONN_STRING=redis://redis
@@ -21,7 +22,6 @@ services:
     #      - NODE_TYPE=slave  # Uncomment for slave node in multi-node deployment
     #      - SYNC_FREQUENCY=60  # Uncomment if regular database syncing is needed
     #      - FRONTEND_BASE_URL=https://openai.justsong.cn  # Uncomment for multi-node deployment with front-end URL
-
     depends_on:
       - redis
       - mysql

+ 12 - 19
middleware/jsrt/cfg.go

@@ -8,12 +8,11 @@ import (
 
 // Runtime 配置
 type JSRuntimeConfig struct {
-	Enabled        bool          `json:"enabled"`
-	MaxVMCount     int           `json:"max_vm_count"`
-	ScriptTimeout  time.Duration `json:"script_timeout"`
-	PreScriptPath  string        `json:"pre_script_path"`
-	PostScriptPath string        `json:"post_script_path"`
-	FetchTimeout   time.Duration `json:"fetch_timeout"`
+	Enabled       bool          `json:"enabled"`
+	MaxVMCount    int           `json:"max_vm_count"`
+	ScriptTimeout time.Duration `json:"script_timeout"`
+	ScriptDir     string        `json:"script_dir"`
+	FetchTimeout  time.Duration `json:"fetch_timeout"`
 }
 
 var (
@@ -21,11 +20,10 @@ var (
 )
 
 const (
-	defaultPreScriptPath  = "scripts/pre_process.js"
-	defaultPostScriptPath = "scripts/post_process.js"
-	defaultScriptTimeout  = 5 * time.Second
-	defaultFetchTimeout   = 10 * time.Second
-	defaultMaxVMCount     = 8
+	defaultScriptDir     = "scripts/"
+	defaultScriptTimeout = 5 * time.Second
+	defaultFetchTimeout  = 10 * time.Second
+	defaultMaxVMCount    = 8
 )
 
 func loadCfg() {
@@ -57,13 +55,8 @@ func loadCfg() {
 		jsConfig.FetchTimeout = defaultFetchTimeout
 	}
 
-	jsConfig.PreScriptPath = os.Getenv("JS_PREPROCESS_SCRIPT_PATH")
-	if jsConfig.PreScriptPath == "" {
-		jsConfig.PreScriptPath = defaultPreScriptPath
-	}
-
-	jsConfig.PostScriptPath = os.Getenv("JS_POSTPROCESS_SCRIPT_PATH")
-	if jsConfig.PostScriptPath == "" {
-		jsConfig.PostScriptPath = defaultPostScriptPath
+	jsConfig.ScriptDir = os.Getenv("JS_SCRIPT_DIR")
+	if jsConfig.ScriptDir == "" {
+		jsConfig.ScriptDir = defaultScriptDir
 	}
 }

+ 69 - 28
middleware/jsrt/jsrt.go

@@ -10,6 +10,7 @@ import (
 	"one-api/common"
 	"one-api/model"
 	"os"
+	"path/filepath"
 	"strings"
 	"sync"
 	"time"
@@ -23,7 +24,7 @@ type JSRuntimePool struct {
 	pool       chan *goja.Runtime
 	maxSize    int
 	createFunc func() *goja.Runtime
-	scripts    map[string]string
+	scripts    string
 	mu         sync.RWMutex
 	httpClient *http.Client
 }
@@ -50,7 +51,7 @@ func NewJSRuntimePool(maxSize int) *JSRuntimePool {
 	pool := &JSRuntimePool{
 		pool:       make(chan *goja.Runtime, maxSize),
 		maxSize:    maxSize,
-		scripts:    make(map[string]string),
+		scripts:    "",
 		httpClient: httpClient,
 	}
 
@@ -161,31 +162,76 @@ func (p *JSRuntimePool) loadScripts(vm *goja.Runtime) {
 	p.mu.RLock()
 	defer p.mu.RUnlock()
 
-	// 加载预处理脚本
-	if script, exists := p.scripts["pre"]; exists {
-		if _, err := vm.RunString(script); err != nil {
-			common.SysError("Failed to load pre_process.js: " + err.Error())
+	// 如果已经缓存了合并的脚本,直接使用
+	if p.scripts != "" {
+		if _, err := vm.RunString(p.scripts); err != nil {
+			common.SysError("Failed to load cached scripts: " + err.Error())
 		}
-	} else if preScript, err := os.ReadFile(jsConfig.PreScriptPath); err == nil {
-		p.scripts["pre"] = string(preScript)
-		if _, err = vm.RunString(string(preScript)); err != nil {
-			common.SysError("Failed to load pre_process.js: " + err.Error())
-		} else {
-			common.SysLog("Loaded pre_process.js")
+		return
+	}
+
+	// 首次加载时,读取 scripts/ 文件夹中的所有脚本
+	p.mu.RUnlock()
+	p.mu.Lock()
+	defer func() {
+		p.mu.Unlock()
+		p.mu.RLock()
+	}()
+
+	if p.scripts != "" {
+		if _, err := vm.RunString(p.scripts); err != nil {
+			common.SysError("Failed to load cached scripts: " + err.Error())
 		}
+		return
+	}
+
+	// 读取所有脚本文件
+	var combinedScript strings.Builder
+	scriptDir := jsConfig.ScriptDir
+
+	// 检查目录是否存在
+	if _, err := os.Stat(scriptDir); os.IsNotExist(err) {
+		common.SysLog("Scripts directory does not exist: " + scriptDir)
+		return
 	}
 
-	// 加载后处理脚本
-	if script, exists := p.scripts["post"]; exists {
-		if _, err := vm.RunString(script); err != nil {
-			common.SysError("Failed to load post_process.js: " + err.Error())
+	// 读取目录中的所有 .js 文件
+	files, err := filepath.Glob(filepath.Join(scriptDir, "*.js"))
+	if err != nil {
+		common.SysError("Failed to read scripts directory: " + err.Error())
+		return
+	}
+
+	if len(files) == 0 {
+		common.SysLog("No JavaScript files found in: " + scriptDir)
+		return
+	}
+
+	// 按文件名排序以确保加载顺序一致
+	for _, file := range files {
+		content, err := os.ReadFile(file)
+		if err != nil {
+			common.SysError("Failed to read script file " + file + ": " + err.Error())
+			continue
 		}
-	} else if postScript, err := os.ReadFile(jsConfig.PostScriptPath); err == nil {
-		p.scripts["post"] = string(postScript)
-		if _, err = vm.RunString(string(postScript)); err != nil {
-			common.SysError("Failed to load post_process.js: " + err.Error())
+
+		// 添加文件注释和内容
+		combinedScript.WriteString("// File: " + filepath.Base(file) + "\n")
+		combinedScript.WriteString(string(content))
+		combinedScript.WriteString("\n\n")
+
+		common.SysLog("Loaded script: " + filepath.Base(file))
+	}
+
+	// 缓存合并后的脚本
+	p.scripts = combinedScript.String()
+
+	// 执行脚本
+	if p.scripts != "" {
+		if _, err := vm.RunString(p.scripts); err != nil {
+			common.SysError("Failed to load combined scripts: " + err.Error())
 		} else {
-			common.SysLog("Loaded post_process.js")
+			common.SysLog("Successfully loaded and combined all JavaScript files from: " + scriptDir)
 		}
 	}
 }
@@ -195,7 +241,7 @@ func (p *JSRuntimePool) ReloadScripts() {
 	defer p.mu.Unlock()
 
 	// 清空缓存的脚本
-	p.scripts = make(map[string]string)
+	p.scripts = ""
 
 	// 清空VM池,强制重新创建
 	for {
@@ -227,7 +273,7 @@ func validateGinContext(c *gin.Context) error {
 	return nil
 }
 
-func (p *JSRuntimePool) executeWithTimeout(vm *goja.Runtime, fn func() (goja.Value, error)) (goja.Value, error) {
+func (p *JSRuntimePool) executeWithTimeout(_ *goja.Runtime, fn func() (goja.Value, error)) (goja.Value, error) {
 	type result struct {
 		value goja.Value
 		err   error
@@ -451,16 +497,13 @@ func JSRuntimeMiddleware() *gin.HandlerFunc {
 		start := time.Now()
 
 		// 预处理
-		common.SysLog("JS Runtime PreProcessing Request: " + c.Request.Method + " " + c.Request.URL.String())
 		if err := pool.PreProcessRequest(c); err != nil {
 			common.SysError("JS Runtime PreProcess Error: " + err.Error())
 			return
 		}
-		common.SysLog("JS Runtime PreProcessing Completed")
 
 		// 后处理
 		if pool.hasPostProcessFunction() {
-			common.SysLog("JS Runtime PostProcessing Response")
 			writer := newResponseWriter(c.Writer)
 			c.Writer = writer
 
@@ -495,11 +538,9 @@ func JSRuntimeMiddleware() *gin.HandlerFunc {
 			} else {
 				// 没有响应体时,恢复原始writer
 				c.Writer = writer.ResponseWriter
-				common.SysLog("JS Runtime PostProcessing Completed with no body")
 			}
 		} else {
 			c.Next()
-			common.SysLog("JS Runtime PostProcessing Skipped: No postProcessResponse function defined")
 		}
 
 		// 记录处理时间

+ 15 - 0
scripts/01_utils.js

@@ -0,0 +1,15 @@
+// Utility functions for JavaScript runtime
+
+function logWithTimestamp(message) {
+    const timestamp = new Date().toISOString();
+    console.log(`[${timestamp}] ${message}`);
+}
+
+function safeJsonParse(str, defaultValue = null) {
+    try {
+        return JSON.parse(str);
+    } catch (e) {
+        console.error('JSON parse error:', e.message);
+        return defaultValue;
+    }
+}

+ 5 - 0
scripts/02_pre_process.js

@@ -0,0 +1,5 @@
+// Pre-processing function for incoming requests
+
+function preProcessRequest(req) {
+    logWithTimestamp('Pre-processing request: ' + req.method + ' ' + req.url);
+}

+ 5 - 0
scripts/03_post_process.js

@@ -0,0 +1,5 @@
+// Post-processing function for outgoing responses
+
+function postProcessResponse(req, res) {
+    logWithTimestamp('Post-processing response with status: ' + res.statusCode);
+}

+ 67 - 61
docs/jsrt/JS_RUNTIME.md → scripts/README.md

@@ -1,73 +1,79 @@
-# JavaScript Runtime 中间件
+# JavaScript Runtime Scripts
+
+本目录包含 JavaScript Runtime 中间件使用的脚本文件。
+
+## 脚本加载
+
+- 系统会自动读取 `scripts/` 目录下的所有 `.js` 文件
+- 脚本按文件名字母顺序加载
+- 建议使用数字前缀来控制加载顺序(如:`01_utils.js`, `02_pre_process.js`)
+- 所有脚本会被合并到一个 JavaScript 运行时环境中
 
 ## 配置
 
-将 JavaScript 脚本放置在项目根目录的 `scripts/` 文件夹中:
-
-- `scripts/pre_process.js` - 请求预处理脚本
-- `scripts/post_process.js` - 响应后处理脚本
-
-## API 参考
-
-### 预处理函数
-
-```javascript
-function preProcessRequest(req) {
-    // req 包含以下属性:
-    // - method: 请求方法 (GET, POST, etc.)
-    // - url: 请求URL
-    // - headers: 请求头 (object)
-    // - body: 请求体 (object/string/ArrayBuffer)
-    // - remoteIP: 客户端IP
-    // - extra: 额外数据 (object)
-    
-    // 返回值:
-    // - undefined: 继续正常处理
-    // - object: 修改请求或阻止请求
-    //   - block: true/false - 是否阻止请求
-    //   - statusCode: 状态码
-    //   - message: 错误消息
-    //   - headers: 修改的请求头 (object)
-    //   - body: 修改的请求体
-}
-```
+通过环境变量配置:
 
-### 后处理函数
-
-```javascript
-function postProcessResponse(req, response) {
-    // ctx: 请求上下文 (同预处理)
-    // response 包含以下属性:
-    // - statusCode: 响应状态码
-    // - headers: 响应头 (object)
-    // - body: 响应体
-    
-    // 返回值:
-    // - undefined: 保持原始响应
-    // - object: 修改响应
-    //   - statusCode: 新的状态码
-    //   - headers: 修改的响应头
-    //   - body: 修改的响应体
-}
-```
+- `JS_RUNTIME_ENABLED=true` - 启用 JavaScript Runtime
+- `JS_SCRIPT_DIR=scripts/` - 脚本目录路径
+- `JS_MAX_VM_COUNT=8` - 最大虚拟机数量
+- `JS_SCRIPT_TIMEOUT=5s` - 脚本执行超时时间
+- `JS_FETCH_TIMEOUT=10s` - HTTP 请求超时时间
 
-### 数据库对象
+更多的详细配置可以在 `.env.example` 文件中找到,并在实际使用时重命名为 `.env`。
 
-```javascript
-// 查询数据库
-var results = db.Query("SELECT * FROM users WHERE id = ?", 123);
+## 必需的函数
 
-// 执行 SQL
-var result = db.Exec("UPDATE users SET last_login = NOW() WHERE id = ?", 123);
-// result 包含: { rowsAffected: number, error: any }
-```
+脚本中必须定义以下两个函数:
+
+### 1. preProcessRequest(req)
+
+在请求被转发到后端 API 之前调用。
+
+**参数:**
+
+- `req`: 请求对象,包含 `method`, `url`, `headers`, `body` 等属性
+
+**返回值:**
+返回一个对象,可包含以下属性:
+
+- `block`: boolean - 是否阻止请求继续执行
+- `statusCode`: number - 阻止请求时返回的状态码
+- `message`: string - 阻止请求时返回的错误消息
+- `headers`: object - 要修改或添加的请求头
+- `body`: any - 修改后的请求体
+
+### 2. postProcessResponse(req, res)
+
+在响应返回给客户端之前调用。
+
+**参数:**
+
+- `req`: 原始请求对象
+- `res`: 响应对象,包含 `statusCode`, `headers`, `body` 等属性
+
+**返回值:**
+返回一个对象,可包含以下属性:
+
+- `statusCode`: number - 修改后的状态码
+- `headers`: object - 要修改或添加的响应头
+- `body`: string - 修改后的响应体
+
+## 可用的全局对象和函数
+
+- `console.log()`, `console.error()`, `console.warn()` - 日志输出
+- `JSON.parse()`, `JSON.stringify()` - JSON 处理
+- `fetch(url, options)` - HTTP 请求
+- `db` - 主数据库连接
+- `logdb` - 日志数据库连接
+- `setTimeout(fn, delay)` - 定时器
+
+## 示例脚本
 
-### 全局对象
+参考现有的示例脚本:
 
-- `console.log()` - 输出日志
-- `console.error()` - 输出错误日志
-- `JSON.parse()` - 解析 JSON
-- `JSON.stringify()` - 序列化为 JSON
+- `01_utils.js` - 工具函数
+- `02_pre_process.js` - 请求预处理
+- `03_post_process.js` - 响应后处理
 
 ## 使用示例
 

+ 0 - 9
scripts/post_process.js

@@ -1,9 +0,0 @@
-// 后处理
-// 在请求处理完成后执行的函数
-//
-// @param {Object} ctx - 请求上下文对象
-// @param {Object} response - 响应对象(包含状态码、头部和正文等)
-// @returns {Object|undefined} - 返回修改后的响应对象或 undefined
-function postProcessResponse(ctx, response) {
-    return undefined;
-}

+ 0 - 164
scripts/pre_process.js

@@ -1,164 +0,0 @@
-// 请求预处理
-// 在请求被处理之前执行的函数
-//
-// @param {Object} req - 请求对象
-// @returns {Object|undefined} - 返回修改后的请求对象或 undefined
-// 
-// 参考: [JS Rt](./middleware/jsrt/req.go) 里的 `JSReq`
-function preProcessRequest(req) {
-    // 例子:基于数据库的速率限制
-    // if (req.url.includes("/v1/chat/completions")) {
-    //     try {
-    //         // Check recent requests from this IP
-    //         var recentRequests = db.query(
-    //             "SELECT COUNT(*) as count FROM logs WHERE created_at > ? AND ip = ?",
-    //             Math.floor(Date.now() / 1000) - 60, // last minute
-    //             req.remoteIP
-    //         );
-
-    //         if (recentRequests && recentRequests.length > 0 && recentRequests[0].count > 10) {
-    //             console.log("速率限制 IP:", req.remoteIP);
-    //             return {
-    //                 block: true,
-    //                 statusCode: 429,
-    //                 message: "超过速率限制"
-    //             };
-    //         }
-    //     } catch (e) {
-    //         console.error("Ratelimit 数据库错误:", e);
-    //     }
-    // }
-
-    // 例子:修改请求
-    // if (req.url.includes("/chat/completions")) {
-    //     try {
-    //         var bodyObj = req.body;
-
-    //         let firstMsg = { // 需要新建一个对象,不能修改原有对象
-    //             role: "user",
-    //             content: "喵呜🐱~嘻嘻"
-    //         };
-    //         bodyObj.messages[0] = firstMsg;
-    //         console.log("Modified first message:", JSON.stringify(firstMsg));
-    //         console.log("Modified body:", JSON.stringify(bodyObj));
-
-    //         return {
-    //             body: bodyObj,
-    //             headers: {
-    //                 ...req.headers,
-    //                 "X-Modified-Body": "true"
-    //             }
-    //         };
-    //     } catch (e) {
-    //         console.error("Failed to modify request body:", {
-    //             message: e.message,
-    //             stack: e.stack,
-    //             bodyType: typeof req.body,
-    //             url: req.url
-    //         });
-    //     }
-    // }
-
-    // 例子:读取最近一条日志,新增 jsrt 日志,并输出日志总数
-    // try {
-    //     // 1. 读取最近一条日志
-    //     var recentLogs = logdb.query(
-    //         "SELECT id, user_id, username, content, created_at FROM logs ORDER BY id DESC LIMIT 1"
-    //     );
-
-    //     var recentLog = null;
-    //     if (recentLogs && recentLogs.length > 0) {
-    //         recentLog = recentLogs[0];
-    //         console.log("最近一条日志:", JSON.stringify(recentLog));
-    //     }
-
-    //     // 2. 新增一条 jsrt 日志
-    //     var currentTimestamp = Math.floor(Date.now() / 1000);
-    //     var jsrtLogContent = "JSRT 预处理中间件执行 - " + req.URL + " - " + new Date().toISOString();
-
-    //     var insertResult = logdb.exec(
-    //         "INSERT INTO logs (user_id, username, created_at, type, content) VALUES (?, ?, ?, ?, ?)",
-    //         req.UserID || 0,
-    //         req.Username || "jsrt-system",
-    //         currentTimestamp,
-    //         4, // LogTypeSystem
-    //         jsrtLogContent
-    //     );
-
-    //     if (insertResult.error) {
-    //         console.error("插入 JSRT 日志失败:", insertResult.error);
-    //     } else {
-    //         console.log("成功插入 JSRT 日志,影响行数:", insertResult.rowsAffected);
-    //     }
-
-    //     // 3. 输出日志总数
-    //     var totalLogsResult = logdb.query("SELECT COUNT(*) as total FROM logs");
-    //     var totalLogs = 0;
-    //     if (totalLogsResult && totalLogsResult.length > 0) {
-    //         totalLogs = totalLogsResult[0].total;
-    //     }
-
-    //     console.log("当前日志总数:", totalLogs);
-    //     console.log("JSRT 日志管理示例执行完成");
-
-    // } catch (e) {
-    //     console.error("JSRT 日志管理示例执行失败:", {
-    //         message: e.message,
-    //         stack: e.stack,
-    //         url: req.URL
-    //     });
-    // }
-
-    // 例子:使用 fetch 调用外部 API
-    // if (req.url.includes("/api/uptime/status")) {
-    //     try {
-    //         // 使用 httpbin.org/ip 测试 fetch 功能
-    //         var response = fetch("https://httpbin.org/ip", {
-    //             method: "GET",
-    //             timeout: 5, // 5秒超时
-    //             headers: {
-    //                 "User-Agent": "OneAPI-JSRT/1.0"
-    //             }
-    //         });
-
-    //         if (response.Error.length === 0) {
-    //             // 解析响应体
-    //             var ipData = JSON.parse(response.Body);
-
-    //             // 可以根据获取到的 IP 信息进行后续处理
-    //             if (ipData.origin) {
-    //                 console.log("外部 IP 地址:", ipData.origin);
-
-    //                 // 示例:记录 IP 信息到数据库
-    //                 var currentTimestamp = Math.floor(Date.now() / 1000);
-    //                 var logContent = "Fetch 示例 - 外部 IP: " + ipData.origin + " - " + new Date().toISOString();
-
-    //                 var insertResult = logdb.exec(
-    //                     "INSERT INTO logs (user_id, username, created_at, type, content) VALUES (?, ?, ?, ?, ?)",
-    //                     0,
-    //                     "jsrt-fetch",
-    //                     currentTimestamp,
-    //                     4, // LogTypeSystem
-    //                     logContent
-    //                 );
-
-    //                 if (insertResult.error) {
-    //                     console.error("记录 IP 信息失败:", insertResult.error);
-    //                 } else {
-    //                     console.log("成功记录 IP 信息到数据库");
-    //                 }
-    //             }
-    //         } else {
-    //             console.error("Fetch 失败 ", response.Status, " ", response.Error);
-    //         }
-    //     } catch (e) {
-    //         console.error("Fetch 失败:", {
-    //             message: e.message,
-    //             stack: e.stack,
-    //             url: req.url
-    //         });
-    //     }
-    // }
-
-    return undefined; // 跳过处理,继续执行下一个中间件或路由
-}