|
|
@@ -1,37 +1,23 @@
|
|
|
-package middleware
|
|
|
+package jsrt
|
|
|
|
|
|
import (
|
|
|
"bytes"
|
|
|
- "context"
|
|
|
"crypto/tls"
|
|
|
"encoding/json"
|
|
|
"fmt"
|
|
|
"io"
|
|
|
- "maps"
|
|
|
"net/http"
|
|
|
- "net/url"
|
|
|
"one-api/common"
|
|
|
"one-api/model"
|
|
|
"os"
|
|
|
- "strconv"
|
|
|
"strings"
|
|
|
"sync"
|
|
|
"time"
|
|
|
|
|
|
"github.com/dop251/goja"
|
|
|
"github.com/gin-gonic/gin"
|
|
|
- "gorm.io/gorm"
|
|
|
)
|
|
|
|
|
|
-/// Runtime 配置
|
|
|
-type JSRuntimeConfig struct {
|
|
|
- Enabled bool `json:"enabled"`
|
|
|
- MaxVMCount int `json:"max_vm_count"`
|
|
|
- ScriptTimeout time.Duration `json:"script_timeout"`
|
|
|
- PreScriptPath string `json:"pre_script_path"`
|
|
|
- PostScriptPath string `json:"post_script_path"`
|
|
|
- FetchTimeout time.Duration `json:"fetch_timeout"`
|
|
|
-}
|
|
|
|
|
|
/// 池化
|
|
|
type JSRuntimePool struct {
|
|
|
@@ -43,211 +29,11 @@ type JSRuntimePool struct {
|
|
|
httpClient *http.Client
|
|
|
}
|
|
|
|
|
|
-/// 上下文
|
|
|
-type JSContext struct {
|
|
|
- Method string `json:"method"`
|
|
|
- URL string `json:"url"`
|
|
|
- Headers map[string]string `json:"headers"`
|
|
|
- Body any `json:"body"`
|
|
|
- UserAgent string `json:"userAgent"`
|
|
|
- RemoteIP string `json:"remoteIP"`
|
|
|
- Extra map[string]any `json:"extra"`
|
|
|
-}
|
|
|
-
|
|
|
-type JSResponse struct {
|
|
|
- StatusCode int `json:"statusCode"`
|
|
|
- Headers map[string]string `json:"headers"`
|
|
|
- Body string `json:"body"`
|
|
|
-}
|
|
|
-
|
|
|
-type JSDatabase struct {
|
|
|
- db *gorm.DB
|
|
|
-}
|
|
|
-
|
|
|
-type JSFetchRequest struct {
|
|
|
- Method string `json:"method"`
|
|
|
- URL string `json:"url"`
|
|
|
- Headers map[string]string `json:"headers"`
|
|
|
- Body string `json:"body"`
|
|
|
- Timeout int `json:"timeout"`
|
|
|
-}
|
|
|
-
|
|
|
-type JSFetchResponse struct {
|
|
|
- Status int `json:"status"`
|
|
|
- StatusText string `json:"statusText"`
|
|
|
- Headers map[string]string `json:"headers"`
|
|
|
- Body string `json:"body"`
|
|
|
- OK bool `json:"ok"`
|
|
|
-}
|
|
|
-
|
|
|
-type responseWriter struct {
|
|
|
- gin.ResponseWriter
|
|
|
- body *bytes.Buffer
|
|
|
- statusCode int
|
|
|
- headerMap http.Header
|
|
|
- written bool
|
|
|
- mu sync.RWMutex
|
|
|
-}
|
|
|
-
|
|
|
var (
|
|
|
jsRuntimePool *JSRuntimePool
|
|
|
jsPoolOnce sync.Once
|
|
|
- jsConfig = JSRuntimeConfig{}
|
|
|
-)
|
|
|
-
|
|
|
-const (
|
|
|
- defaultPreScriptPath = "scripts/pre_process.js"
|
|
|
- defaultPostScriptPath = "scripts/post_process.js"
|
|
|
- defaultScriptTimeout = 5 * time.Second
|
|
|
- defaultFetchTimeout = 10 * time.Second
|
|
|
- defaultMaxVMCount = 8
|
|
|
)
|
|
|
|
|
|
-func init() {
|
|
|
- if enabled := os.Getenv("JS_RUNTIME_ENABLED"); enabled != "" {
|
|
|
- jsConfig.Enabled = enabled == "true"
|
|
|
- }
|
|
|
-
|
|
|
- if maxCount := os.Getenv("JS_MAX_VM_COUNT"); maxCount != "" {
|
|
|
- if count, err := strconv.Atoi(maxCount); err == nil && count > 0 {
|
|
|
- jsConfig.MaxVMCount = count
|
|
|
- }
|
|
|
- } else {
|
|
|
- jsConfig.MaxVMCount = defaultMaxVMCount
|
|
|
- }
|
|
|
-
|
|
|
- if timeout := os.Getenv("JS_SCRIPT_TIMEOUT"); timeout != "" {
|
|
|
- if t, err := time.ParseDuration(timeout + "s"); err == nil && t > 0 {
|
|
|
- jsConfig.ScriptTimeout = t
|
|
|
- }
|
|
|
- } else {
|
|
|
- jsConfig.ScriptTimeout = defaultScriptTimeout
|
|
|
- }
|
|
|
-
|
|
|
- if fetchTimeout := os.Getenv("JS_FETCH_TIMEOUT"); fetchTimeout != "" {
|
|
|
- if t, err := time.ParseDuration(fetchTimeout + "s"); err == nil && t > 0 {
|
|
|
- jsConfig.FetchTimeout = t
|
|
|
- }
|
|
|
- } else {
|
|
|
- jsConfig.FetchTimeout = defaultFetchTimeout
|
|
|
- }
|
|
|
-
|
|
|
- jsConfig.PreScriptPath = os.Getenv("JS_PREPROCESS_SCRIPT_PATH")
|
|
|
- if jsConfig.PreScriptPath == "" {
|
|
|
- jsConfig.PreScriptPath = defaultPreScriptPath
|
|
|
- }
|
|
|
-
|
|
|
- jsConfig.PostScriptPath = os.Getenv("JS_POSTPROCESS_SCRIPT_PATH")
|
|
|
- if jsConfig.PostScriptPath == "" {
|
|
|
- jsConfig.PostScriptPath = defaultPostScriptPath
|
|
|
- }
|
|
|
-}
|
|
|
-
|
|
|
-func parseBodyByType(bodyBytes []byte, contentType string) any {
|
|
|
- if len(bodyBytes) == 0 {
|
|
|
- return ""
|
|
|
- }
|
|
|
-
|
|
|
- bodyStr := string(bodyBytes)
|
|
|
- contentLower := strings.ToLower(contentType)
|
|
|
-
|
|
|
- switch {
|
|
|
- case strings.Contains(contentLower, "application/json"):
|
|
|
- var jsonObj any
|
|
|
- if err := json.Unmarshal(bodyBytes, &jsonObj); err == nil {
|
|
|
- return jsonObj
|
|
|
- }
|
|
|
- return bodyStr
|
|
|
-
|
|
|
- case strings.Contains(contentLower, "application/x-www-form-urlencoded"):
|
|
|
- if values, err := url.ParseQuery(bodyStr); err == nil {
|
|
|
- result := make(map[string]string, len(values))
|
|
|
- for k, v := range values {
|
|
|
- if len(v) > 0 {
|
|
|
- result[k] = v[0]
|
|
|
- }
|
|
|
- }
|
|
|
- return result
|
|
|
- }
|
|
|
- return bodyStr
|
|
|
-
|
|
|
- case strings.Contains(contentLower, "multipart/form-data"):
|
|
|
- return bodyBytes
|
|
|
-
|
|
|
- case strings.Contains(contentLower, "text/"):
|
|
|
- return bodyStr
|
|
|
-
|
|
|
- default:
|
|
|
- // 尝试JSON解析
|
|
|
- var jsonObj any
|
|
|
- if json.Unmarshal(bodyBytes, &jsonObj) == nil {
|
|
|
- return jsonObj
|
|
|
- }
|
|
|
-
|
|
|
- // 尝试form解析
|
|
|
- if values, err := url.ParseQuery(bodyStr); err == nil && len(values) > 0 {
|
|
|
- result := make(map[string]string, len(values))
|
|
|
- for k, v := range values {
|
|
|
- if len(v) > 0 {
|
|
|
- result[k] = v[0]
|
|
|
- }
|
|
|
- }
|
|
|
- return result
|
|
|
- }
|
|
|
-
|
|
|
- return bodyStr
|
|
|
- }
|
|
|
-}
|
|
|
-
|
|
|
-func createJSContext(c *gin.Context) *JSContext {
|
|
|
- var bodyBytes []byte
|
|
|
- if c.Request != nil && c.Request.Body != nil {
|
|
|
- bodyBytes, _ = io.ReadAll(c.Request.Body)
|
|
|
- c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
|
|
- }
|
|
|
-
|
|
|
- // headers map
|
|
|
- headers := make(map[string]string)
|
|
|
- if c.Request != nil && c.Request.Header != nil {
|
|
|
- for key, values := range c.Request.Header {
|
|
|
- if len(values) > 0 {
|
|
|
- headers[key] = values[0]
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- method := ""
|
|
|
- url := ""
|
|
|
- userAgent := ""
|
|
|
- remoteIP := ""
|
|
|
- contentType := ""
|
|
|
-
|
|
|
- if c.Request != nil {
|
|
|
- method = c.Request.Method
|
|
|
- if c.Request.URL != nil {
|
|
|
- url = c.Request.URL.String()
|
|
|
- }
|
|
|
- userAgent = c.Request.UserAgent()
|
|
|
- contentType = c.ContentType()
|
|
|
- }
|
|
|
-
|
|
|
- if c != nil {
|
|
|
- remoteIP = c.ClientIP()
|
|
|
- }
|
|
|
-
|
|
|
- parsedBody := parseBodyByType(bodyBytes, contentType)
|
|
|
-
|
|
|
- return &JSContext{
|
|
|
- Method: method,
|
|
|
- URL: url,
|
|
|
- Headers: headers,
|
|
|
- Body: parsedBody,
|
|
|
- UserAgent: userAgent,
|
|
|
- RemoteIP: remoteIP,
|
|
|
- Extra: make(map[string]any),
|
|
|
- }
|
|
|
-}
|
|
|
-
|
|
|
func NewJSRuntimePool(maxSize int) *JSRuntimePool {
|
|
|
// 创建HTTP客户端
|
|
|
httpClient := &http.Client{
|
|
|
@@ -362,7 +148,7 @@ func (p *JSRuntimePool) setupGlobals(vm *goja.Runtime) {
|
|
|
// 数据库
|
|
|
vm.Set("db", &JSDatabase{db: model.DB})
|
|
|
|
|
|
- // 定时器 (简化版)
|
|
|
+ // 定时器
|
|
|
vm.Set("setTimeout", func(fn func(), delay int) {
|
|
|
go func() {
|
|
|
time.Sleep(time.Duration(delay) * time.Millisecond)
|
|
|
@@ -371,129 +157,6 @@ func (p *JSRuntimePool) setupGlobals(vm *goja.Runtime) {
|
|
|
})
|
|
|
}
|
|
|
|
|
|
-func (p *JSRuntimePool) fetch(url string, options ...any) *JSFetchResponse {
|
|
|
- req := &JSFetchRequest{
|
|
|
- Method: "GET",
|
|
|
- URL: url,
|
|
|
- Headers: make(map[string]string),
|
|
|
- Timeout: int(jsConfig.FetchTimeout.Seconds()),
|
|
|
- }
|
|
|
-
|
|
|
- // 解析选项
|
|
|
- if len(options) > 0 && options[0] != nil {
|
|
|
- if optMap, ok := options[0].(map[string]any); ok {
|
|
|
- if method, exists := optMap["method"]; exists {
|
|
|
- if methodStr, ok := method.(string); ok {
|
|
|
- req.Method = strings.ToUpper(methodStr)
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- if headers, exists := optMap["headers"]; exists {
|
|
|
- if headersMap, ok := headers.(map[string]any); ok {
|
|
|
- for k, v := range headersMap {
|
|
|
- if vStr, ok := v.(string); ok {
|
|
|
- req.Headers[k] = vStr
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- if body, exists := optMap["body"]; exists {
|
|
|
- switch v := body.(type) {
|
|
|
- case string:
|
|
|
- req.Body = v
|
|
|
- case map[string]any:
|
|
|
- if bodyBytes, err := json.Marshal(v); err == nil {
|
|
|
- req.Body = string(bodyBytes)
|
|
|
- req.Headers["Content-Type"] = "application/json"
|
|
|
- }
|
|
|
- default:
|
|
|
- req.Body = fmt.Sprintf("%v", body)
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- if timeout, exists := optMap["timeout"]; exists {
|
|
|
- if timeoutNum, ok := timeout.(float64); ok {
|
|
|
- req.Timeout = int(timeoutNum)
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- // 创建HTTP请求
|
|
|
- var bodyReader io.Reader
|
|
|
- if req.Body != "" {
|
|
|
- bodyReader = strings.NewReader(req.Body)
|
|
|
- }
|
|
|
-
|
|
|
- httpReq, err := http.NewRequest(req.Method, req.URL, bodyReader)
|
|
|
- if err != nil {
|
|
|
- return &JSFetchResponse{
|
|
|
- Status: 0,
|
|
|
- StatusText: err.Error(),
|
|
|
- Headers: make(map[string]string),
|
|
|
- Body: "",
|
|
|
- OK: false,
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- // 设置请求头
|
|
|
- for k, v := range req.Headers {
|
|
|
- httpReq.Header.Set(k, v)
|
|
|
- }
|
|
|
-
|
|
|
- // 设置默认User-Agent
|
|
|
- if httpReq.Header.Get("User-Agent") == "" {
|
|
|
- httpReq.Header.Set("User-Agent", "JS-Runtime-Fetch/1.0")
|
|
|
- }
|
|
|
-
|
|
|
- // 创建带超时的上下文
|
|
|
- ctx, cancel := context.WithTimeout(context.Background(), time.Duration(req.Timeout)*time.Second)
|
|
|
- defer cancel()
|
|
|
- httpReq = httpReq.WithContext(ctx)
|
|
|
-
|
|
|
- // 执行请求
|
|
|
- resp, err := p.httpClient.Do(httpReq)
|
|
|
- if err != nil {
|
|
|
- return &JSFetchResponse{
|
|
|
- Status: 0,
|
|
|
- StatusText: err.Error(),
|
|
|
- Headers: make(map[string]string),
|
|
|
- Body: "",
|
|
|
- OK: false,
|
|
|
- }
|
|
|
- }
|
|
|
- defer resp.Body.Close()
|
|
|
-
|
|
|
- // 读取响应体
|
|
|
- bodyBytes, err := io.ReadAll(resp.Body)
|
|
|
- if err != nil {
|
|
|
- return &JSFetchResponse{
|
|
|
- Status: resp.StatusCode,
|
|
|
- StatusText: resp.Status,
|
|
|
- Headers: make(map[string]string),
|
|
|
- Body: "",
|
|
|
- OK: resp.StatusCode >= 200 && resp.StatusCode < 300,
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- // 构建响应头
|
|
|
- headers := make(map[string]string)
|
|
|
- for k, v := range resp.Header {
|
|
|
- if len(v) > 0 {
|
|
|
- headers[k] = v[0]
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- return &JSFetchResponse{
|
|
|
- Status: resp.StatusCode,
|
|
|
- StatusText: resp.Status,
|
|
|
- Headers: headers,
|
|
|
- Body: string(bodyBytes),
|
|
|
- OK: resp.StatusCode >= 200 && resp.StatusCode < 300,
|
|
|
- }
|
|
|
-}
|
|
|
-
|
|
|
func (p *JSRuntimePool) loadScripts(vm *goja.Runtime) {
|
|
|
p.mu.RLock()
|
|
|
defer p.mu.RUnlock()
|
|
|
@@ -546,68 +209,6 @@ done:
|
|
|
common.SysLog("JavaScript scripts reloaded")
|
|
|
}
|
|
|
|
|
|
-func (jsdb *JSDatabase) Query(sql string, args ...any) []map[string]any {
|
|
|
- if jsdb.db == nil {
|
|
|
- common.SysError("JS DB is nil")
|
|
|
- return nil
|
|
|
- }
|
|
|
-
|
|
|
- rows, err := jsdb.db.Raw(sql, args...).Rows()
|
|
|
- if err != nil {
|
|
|
- common.SysError("JS DB Query Error: " + err.Error())
|
|
|
- return nil
|
|
|
- }
|
|
|
- defer rows.Close()
|
|
|
-
|
|
|
- columns, err := rows.Columns()
|
|
|
- if err != nil {
|
|
|
- common.SysError("JS DB Columns Error: " + err.Error())
|
|
|
- return nil
|
|
|
- }
|
|
|
-
|
|
|
- results := make([]map[string]any, 0, 100)
|
|
|
- for rows.Next() {
|
|
|
- values := make([]any, len(columns))
|
|
|
- valuePtrs := make([]any, len(columns))
|
|
|
- for i := range values {
|
|
|
- valuePtrs[i] = &values[i]
|
|
|
- }
|
|
|
-
|
|
|
- if err := rows.Scan(valuePtrs...); err != nil {
|
|
|
- common.SysError("JS DB Scan Error: " + err.Error())
|
|
|
- continue
|
|
|
- }
|
|
|
-
|
|
|
- row := make(map[string]any, len(columns))
|
|
|
- for i, col := range columns {
|
|
|
- val := values[i]
|
|
|
- if b, ok := val.([]byte); ok {
|
|
|
- row[col] = string(b)
|
|
|
- } else {
|
|
|
- row[col] = val
|
|
|
- }
|
|
|
- }
|
|
|
- results = append(results, row)
|
|
|
- }
|
|
|
-
|
|
|
- return results
|
|
|
-}
|
|
|
-
|
|
|
-func (jsdb *JSDatabase) Exec(sql string, args ...any) map[string]any {
|
|
|
- if jsdb.db == nil {
|
|
|
- return map[string]any{
|
|
|
- "rowsAffected": int64(0),
|
|
|
- "error": "database is nil",
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- result := jsdb.db.Exec(sql, args...)
|
|
|
- return map[string]any{
|
|
|
- "rowsAffected": result.RowsAffected,
|
|
|
- "error": result.Error,
|
|
|
- }
|
|
|
-}
|
|
|
-
|
|
|
func initJSRuntimePool() *JSRuntimePool {
|
|
|
jsPoolOnce.Do(func() {
|
|
|
jsRuntimePool = NewJSRuntimePool(jsConfig.MaxVMCount)
|
|
|
@@ -837,59 +438,6 @@ func (p *JSRuntimePool) hasPostProcessFunction() bool {
|
|
|
return postProcessFunc != nil && !goja.IsUndefined(postProcessFunc)
|
|
|
}
|
|
|
|
|
|
-func newResponseWriter(w gin.ResponseWriter) *responseWriter {
|
|
|
- return &responseWriter{
|
|
|
- ResponseWriter: w,
|
|
|
- body: &bytes.Buffer{},
|
|
|
- statusCode: 200,
|
|
|
- headerMap: make(http.Header),
|
|
|
- written: false,
|
|
|
- }
|
|
|
-}
|
|
|
-
|
|
|
-func (w *responseWriter) Write(data []byte) (int, error) {
|
|
|
- w.mu.Lock()
|
|
|
- defer w.mu.Unlock()
|
|
|
-
|
|
|
- if !w.written {
|
|
|
- w.WriteHeader(200)
|
|
|
- }
|
|
|
- return w.body.Write(data)
|
|
|
-}
|
|
|
-
|
|
|
-func (w *responseWriter) WriteString(s string) (int, error) {
|
|
|
- w.mu.Lock()
|
|
|
- defer w.mu.Unlock()
|
|
|
-
|
|
|
- if !w.written {
|
|
|
- w.WriteHeader(200)
|
|
|
- }
|
|
|
- return w.body.WriteString(s)
|
|
|
-}
|
|
|
-
|
|
|
-func (w *responseWriter) WriteHeader(statusCode int) {
|
|
|
- w.mu.Lock()
|
|
|
- defer w.mu.Unlock()
|
|
|
-
|
|
|
- if w.written {
|
|
|
- return
|
|
|
- }
|
|
|
- w.statusCode = statusCode
|
|
|
- w.written = true
|
|
|
-
|
|
|
- maps.Copy(w.headerMap, w.ResponseWriter.Header())
|
|
|
-}
|
|
|
-
|
|
|
-func (w *responseWriter) Header() http.Header {
|
|
|
- w.mu.RLock()
|
|
|
- defer w.mu.RUnlock()
|
|
|
-
|
|
|
- if w.headerMap == nil {
|
|
|
- w.headerMap = make(http.Header)
|
|
|
- }
|
|
|
- return w.headerMap
|
|
|
-}
|
|
|
-
|
|
|
func JSRuntimeMiddleware() gin.HandlerFunc {
|
|
|
if !jsConfig.Enabled {
|
|
|
return func(c *gin.Context) {
|