script_surge.go 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345
  1. package script
  2. import (
  3. "context"
  4. "net/http"
  5. "sync"
  6. "time"
  7. "unsafe"
  8. "github.com/sagernet/sing-box/adapter"
  9. C "github.com/sagernet/sing-box/constant"
  10. "github.com/sagernet/sing-box/log"
  11. "github.com/sagernet/sing-box/option"
  12. "github.com/sagernet/sing-box/script/jsc"
  13. "github.com/sagernet/sing-box/script/modules/surge"
  14. "github.com/sagernet/sing/common"
  15. E "github.com/sagernet/sing/common/exceptions"
  16. F "github.com/sagernet/sing/common/format"
  17. "github.com/sagernet/sing/common/logger"
  18. "github.com/adhocore/gronx"
  19. "github.com/dop251/goja"
  20. )
  // defaultSurgeScriptTimeout bounds a single script execution whenever the
  // caller passes a zero timeout to one of the Execute* methods.
  const defaultSurgeScriptTimeout = time.Second * 10
  22. var _ adapter.SurgeScript = (*SurgeScript)(nil)
  23. type SurgeScript struct {
  24. ctx context.Context
  25. logger logger.ContextLogger
  26. tag string
  27. source Source
  28. cronExpression string
  29. cronTimeout time.Duration
  30. cronArguments []string
  31. cronTimer *time.Timer
  32. cronDone chan struct{}
  33. }
  34. func NewSurgeScript(ctx context.Context, logger logger.ContextLogger, options option.Script) (adapter.Script, error) {
  35. source, err := NewSource(ctx, logger, options)
  36. if err != nil {
  37. return nil, err
  38. }
  39. cronOptions := common.PtrValueOrDefault(options.SurgeOptions.CronOptions)
  40. if cronOptions.Expression != "" {
  41. if !gronx.IsValid(cronOptions.Expression) {
  42. return nil, E.New("invalid cron expression: ", cronOptions.Expression)
  43. }
  44. }
  45. return &SurgeScript{
  46. ctx: ctx,
  47. logger: logger,
  48. tag: options.Tag,
  49. source: source,
  50. cronExpression: cronOptions.Expression,
  51. cronTimeout: time.Duration(cronOptions.Timeout),
  52. cronArguments: cronOptions.Arguments,
  53. cronDone: make(chan struct{}),
  54. }, nil
  55. }
  56. func (s *SurgeScript) Type() string {
  57. return C.ScriptTypeSurge
  58. }
  59. func (s *SurgeScript) Tag() string {
  60. return s.tag
  61. }
  62. func (s *SurgeScript) StartContext(ctx context.Context, startContext *adapter.HTTPStartContext) error {
  63. return s.source.StartContext(ctx, startContext)
  64. }
  65. func (s *SurgeScript) PostStart() error {
  66. err := s.source.PostStart()
  67. if err != nil {
  68. return err
  69. }
  70. if s.cronExpression != "" {
  71. go s.loopCronEvents()
  72. }
  73. return nil
  74. }
  75. func (s *SurgeScript) loopCronEvents() {
  76. s.logger.Debug("starting event")
  77. err := s.ExecuteGeneric(s.ctx, "cron", s.cronTimeout, s.cronArguments)
  78. if err != nil {
  79. s.logger.Error(E.Cause(err, "running event"))
  80. }
  81. nextTick, err := gronx.NextTick(s.cronExpression, false)
  82. if err != nil {
  83. s.logger.Error(E.Cause(err, "determine next tick"))
  84. return
  85. }
  86. s.cronTimer = time.NewTimer(nextTick.Sub(time.Now()))
  87. s.logger.Debug("next event at: ", nextTick.Format(log.DefaultTimeFormat))
  88. for {
  89. select {
  90. case <-s.ctx.Done():
  91. return
  92. case <-s.cronDone:
  93. return
  94. case <-s.cronTimer.C:
  95. s.logger.Debug("starting event")
  96. err = s.ExecuteGeneric(s.ctx, "cron", s.cronTimeout, s.cronArguments)
  97. if err != nil {
  98. s.logger.Error(E.Cause(err, "running event"))
  99. }
  100. nextTick, err = gronx.NextTick(s.cronExpression, false)
  101. if err != nil {
  102. s.logger.Error(E.Cause(err, "determine next tick"))
  103. return
  104. }
  105. s.cronTimer.Reset(nextTick.Sub(time.Now()))
  106. s.logger.Debug("configured next event at: ", nextTick)
  107. }
  108. }
  109. }
  110. func (s *SurgeScript) Close() error {
  111. err := s.source.Close()
  112. if s.cronTimer != nil {
  113. s.cronTimer.Stop()
  114. close(s.cronDone)
  115. }
  116. return err
  117. }
  118. func (s *SurgeScript) ExecuteGeneric(ctx context.Context, scriptType string, timeout time.Duration, arguments []string) error {
  119. program := s.source.Program()
  120. if program == nil {
  121. return E.New("invalid script")
  122. }
  123. ctx, cancel := context.WithCancelCause(ctx)
  124. defer cancel(nil)
  125. runtime := NewRuntime(ctx, cancel)
  126. SetModules(runtime, ctx, s.logger, cancel, s.tag)
  127. surge.Enable(runtime, scriptType, arguments)
  128. if timeout == 0 {
  129. timeout = defaultSurgeScriptTimeout
  130. }
  131. ctx, timeoutCancel := context.WithTimeout(ctx, timeout)
  132. defer timeoutCancel()
  133. done := make(chan struct{})
  134. doneFunc := common.OnceFunc(func() {
  135. close(done)
  136. })
  137. runtime.Set("done", func(call goja.FunctionCall) goja.Value {
  138. doneFunc()
  139. return goja.Undefined()
  140. })
  141. var (
  142. access sync.Mutex
  143. scriptErr error
  144. )
  145. go func() {
  146. _, err := runtime.RunProgram(program)
  147. if err != nil {
  148. access.Lock()
  149. scriptErr = err
  150. access.Unlock()
  151. doneFunc()
  152. }
  153. }()
  154. select {
  155. case <-ctx.Done():
  156. runtime.Interrupt(ctx.Err())
  157. return ctx.Err()
  158. case <-done:
  159. access.Lock()
  160. defer access.Unlock()
  161. if scriptErr != nil {
  162. runtime.Interrupt(scriptErr)
  163. } else {
  164. runtime.Interrupt("script done")
  165. }
  166. }
  167. return scriptErr
  168. }
  169. func (s *SurgeScript) ExecuteHTTPRequest(ctx context.Context, timeout time.Duration, request *http.Request, body []byte, binaryBody bool, arguments []string) (*adapter.HTTPRequestScriptResult, error) {
  170. program := s.source.Program()
  171. if program == nil {
  172. return nil, E.New("invalid script")
  173. }
  174. ctx, cancel := context.WithCancelCause(ctx)
  175. defer cancel(nil)
  176. runtime := NewRuntime(ctx, cancel)
  177. SetModules(runtime, ctx, s.logger, cancel, s.tag)
  178. surge.Enable(runtime, "http-request", arguments)
  179. if timeout == 0 {
  180. timeout = defaultSurgeScriptTimeout
  181. }
  182. ctx, timeoutCancel := context.WithTimeout(ctx, timeout)
  183. defer timeoutCancel()
  184. runtime.ClearInterrupt()
  185. requestObject := runtime.NewObject()
  186. requestObject.Set("url", request.URL.String())
  187. requestObject.Set("method", request.Method)
  188. requestObject.Set("headers", jsc.HeadersToValue(runtime, request.Header))
  189. if !binaryBody {
  190. requestObject.Set("body", string(body))
  191. } else {
  192. requestObject.Set("body", jsc.NewUint8Array(runtime, body))
  193. }
  194. requestObject.Set("id", F.ToString(uintptr(unsafe.Pointer(request))))
  195. runtime.Set("request", requestObject)
  196. done := make(chan struct{})
  197. doneFunc := common.OnceFunc(func() {
  198. close(done)
  199. })
  200. var (
  201. access sync.Mutex
  202. result adapter.HTTPRequestScriptResult
  203. scriptErr error
  204. )
  205. runtime.Set("done", func(call goja.FunctionCall) goja.Value {
  206. defer doneFunc()
  207. resultObject := jsc.AssertObject(runtime, call.Argument(0), "done() argument", true)
  208. if resultObject == nil {
  209. panic(runtime.NewGoError(E.New("request rejected by script")))
  210. }
  211. access.Lock()
  212. defer access.Unlock()
  213. result.URL = jsc.AssertString(runtime, resultObject.Get("url"), "url", true)
  214. result.Headers = jsc.AssertHTTPHeader(runtime, resultObject.Get("headers"), "headers")
  215. result.Body = jsc.AssertStringBinary(runtime, resultObject.Get("body"), "body", true)
  216. responseObject := jsc.AssertObject(runtime, resultObject.Get("response"), "response", true)
  217. if responseObject != nil {
  218. result.Response = &adapter.HTTPRequestScriptResponse{
  219. Status: int(jsc.AssertInt(runtime, responseObject.Get("status"), "status", true)),
  220. Headers: jsc.AssertHTTPHeader(runtime, responseObject.Get("headers"), "headers"),
  221. Body: jsc.AssertStringBinary(runtime, responseObject.Get("body"), "body", true),
  222. }
  223. }
  224. return goja.Undefined()
  225. })
  226. go func() {
  227. _, err := runtime.RunProgram(program)
  228. if err != nil {
  229. access.Lock()
  230. scriptErr = err
  231. access.Unlock()
  232. doneFunc()
  233. }
  234. }()
  235. select {
  236. case <-ctx.Done():
  237. runtime.Interrupt(ctx.Err())
  238. return nil, ctx.Err()
  239. case <-done:
  240. access.Lock()
  241. defer access.Unlock()
  242. if scriptErr != nil {
  243. runtime.Interrupt(scriptErr)
  244. } else {
  245. runtime.Interrupt("script done")
  246. }
  247. }
  248. return &result, scriptErr
  249. }
  250. func (s *SurgeScript) ExecuteHTTPResponse(ctx context.Context, timeout time.Duration, request *http.Request, response *http.Response, body []byte, binaryBody bool, arguments []string) (*adapter.HTTPResponseScriptResult, error) {
  251. program := s.source.Program()
  252. if program == nil {
  253. return nil, E.New("invalid script")
  254. }
  255. ctx, cancel := context.WithCancelCause(ctx)
  256. defer cancel(nil)
  257. runtime := NewRuntime(ctx, cancel)
  258. SetModules(runtime, ctx, s.logger, cancel, s.tag)
  259. surge.Enable(runtime, "http-response", arguments)
  260. if timeout == 0 {
  261. timeout = defaultSurgeScriptTimeout
  262. }
  263. ctx, timeoutCancel := context.WithTimeout(ctx, timeout)
  264. defer timeoutCancel()
  265. runtime.ClearInterrupt()
  266. requestObject := runtime.NewObject()
  267. requestObject.Set("url", request.URL.String())
  268. requestObject.Set("method", request.Method)
  269. requestObject.Set("headers", jsc.HeadersToValue(runtime, request.Header))
  270. requestObject.Set("id", F.ToString(uintptr(unsafe.Pointer(request))))
  271. runtime.Set("request", requestObject)
  272. responseObject := runtime.NewObject()
  273. responseObject.Set("status", response.StatusCode)
  274. responseObject.Set("headers", jsc.HeadersToValue(runtime, response.Header))
  275. if !binaryBody {
  276. responseObject.Set("body", string(body))
  277. } else {
  278. responseObject.Set("body", jsc.NewUint8Array(runtime, body))
  279. }
  280. runtime.Set("response", responseObject)
  281. done := make(chan struct{})
  282. doneFunc := common.OnceFunc(func() {
  283. close(done)
  284. })
  285. var (
  286. access sync.Mutex
  287. result adapter.HTTPResponseScriptResult
  288. scriptErr error
  289. )
  290. runtime.Set("done", func(call goja.FunctionCall) goja.Value {
  291. resultObject := jsc.AssertObject(runtime, call.Argument(0), "done() argument", true)
  292. if resultObject == nil {
  293. panic(runtime.NewGoError(E.New("response rejected by script")))
  294. }
  295. access.Lock()
  296. defer access.Unlock()
  297. result.Status = int(jsc.AssertInt(runtime, resultObject.Get("status"), "status", true))
  298. result.Headers = jsc.AssertHTTPHeader(runtime, resultObject.Get("headers"), "headers")
  299. result.Body = jsc.AssertStringBinary(runtime, resultObject.Get("body"), "body", true)
  300. doneFunc()
  301. return goja.Undefined()
  302. })
  303. go func() {
  304. _, err := runtime.RunProgram(program)
  305. if err != nil {
  306. access.Lock()
  307. scriptErr = err
  308. access.Unlock()
  309. doneFunc()
  310. }
  311. }()
  312. select {
  313. case <-ctx.Done():
  314. runtime.Interrupt(ctx.Err())
  315. return nil, ctx.Err()
  316. case <-done:
  317. access.Lock()
  318. defer access.Unlock()
  319. if scriptErr != nil {
  320. runtime.Interrupt(scriptErr)
  321. } else {
  322. runtime.Interrupt("script done")
  323. }
  324. return &result, scriptErr
  325. }
  326. }