engine.go 34 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099
  1. package mitm
  2. import (
  3. "bufio"
  4. "bytes"
  5. "context"
  6. "crypto/tls"
  7. "io"
  8. "math"
  9. "mime"
  10. "net"
  11. "net/http"
  12. "net/url"
  13. "os"
  14. "path/filepath"
  15. "strings"
  16. "time"
  17. "unicode"
  18. "github.com/sagernet/sing-box/adapter"
  19. "github.com/sagernet/sing-box/common/dialer"
  20. sTLS "github.com/sagernet/sing-box/common/tls"
  21. "github.com/sagernet/sing-box/option"
  22. "github.com/sagernet/sing/common"
  23. "github.com/sagernet/sing/common/atomic"
  24. E "github.com/sagernet/sing/common/exceptions"
  25. F "github.com/sagernet/sing/common/format"
  26. "github.com/sagernet/sing/common/logger"
  27. M "github.com/sagernet/sing/common/metadata"
  28. N "github.com/sagernet/sing/common/network"
  29. "github.com/sagernet/sing/common/ntp"
  30. sHTTP "github.com/sagernet/sing/protocol/http"
  31. "github.com/sagernet/sing/service"
  32. "golang.org/x/net/http2"
  33. )
  34. var _ adapter.MITMEngine = (*Engine)(nil)
  35. type Engine struct {
  36. ctx context.Context
  37. logger logger.ContextLogger
  38. connection adapter.ConnectionManager
  39. certificate adapter.CertificateStore
  40. script adapter.ScriptManager
  41. timeFunc func() time.Time
  42. http2Enabled bool
  43. }
  44. func NewEngine(ctx context.Context, logger logger.ContextLogger, options option.MITMOptions) (*Engine, error) {
  45. engine := &Engine{
  46. ctx: ctx,
  47. logger: logger,
  48. http2Enabled: options.HTTP2Enabled,
  49. }
  50. return engine, nil
  51. }
  52. func (e *Engine) Start(stage adapter.StartStage) error {
  53. switch stage {
  54. case adapter.StartStateInitialize:
  55. e.connection = service.FromContext[adapter.ConnectionManager](e.ctx)
  56. e.certificate = service.FromContext[adapter.CertificateStore](e.ctx)
  57. e.script = service.FromContext[adapter.ScriptManager](e.ctx)
  58. e.timeFunc = ntp.TimeFuncFromContext(e.ctx)
  59. if e.timeFunc == nil {
  60. e.timeFunc = time.Now
  61. }
  62. }
  63. return nil
  64. }
  65. func (e *Engine) Close() error {
  66. return nil
  67. }
  68. func (e *Engine) NewConnection(ctx context.Context, this N.Dialer, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
  69. if e.certificate.TLSDecryptionEnabled() && metadata.ClientHello != nil {
  70. err := e.newTLS(ctx, this, conn, metadata, onClose)
  71. if err != nil {
  72. e.logger.ErrorContext(ctx, err)
  73. } else {
  74. e.logger.DebugContext(ctx, "connection closed")
  75. }
  76. if onClose != nil {
  77. onClose(err)
  78. }
  79. return
  80. } else if metadata.HTTPRequest != nil {
  81. err := e.newHTTP1(ctx, this, conn, nil, metadata)
  82. if err != nil {
  83. e.logger.ErrorContext(ctx, err)
  84. } else {
  85. e.logger.DebugContext(ctx, "connection closed")
  86. }
  87. if onClose != nil {
  88. onClose(err)
  89. }
  90. return
  91. } else {
  92. e.logger.DebugContext(ctx, "HTTP and TLS not detected, skipped")
  93. }
  94. metadata.MITM = nil
  95. e.connection.NewConnection(ctx, this, conn, metadata, onClose)
  96. }
  97. func (e *Engine) newTLS(ctx context.Context, this N.Dialer, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) error {
  98. acceptHTTP := len(metadata.ClientHello.SupportedProtos) == 0 || common.Contains(metadata.ClientHello.SupportedProtos, "http/1.1")
  99. acceptH2 := e.http2Enabled && common.Contains(metadata.ClientHello.SupportedProtos, "h2")
  100. if !acceptHTTP && !acceptH2 {
  101. metadata.MITM = nil
  102. e.logger.DebugContext(ctx, "unsupported application protocol: ", strings.Join(metadata.ClientHello.SupportedProtos, ","))
  103. e.connection.NewConnection(ctx, this, conn, metadata, onClose)
  104. return nil
  105. }
  106. var nextProtos []string
  107. if acceptH2 {
  108. nextProtos = append(nextProtos, "h2")
  109. } else if acceptHTTP {
  110. nextProtos = append(nextProtos, "http/1.1")
  111. }
  112. var (
  113. maxVersion uint16
  114. minVersion uint16
  115. )
  116. for _, version := range metadata.ClientHello.SupportedVersions {
  117. maxVersion = common.Max(maxVersion, version)
  118. minVersion = common.Min(minVersion, version)
  119. }
  120. serverName := metadata.ClientHello.ServerName
  121. if serverName == "" && metadata.Destination.IsIP() {
  122. serverName = metadata.Destination.Addr.String()
  123. }
  124. tlsConfig := &tls.Config{
  125. Time: e.timeFunc,
  126. ServerName: serverName,
  127. NextProtos: nextProtos,
  128. MinVersion: minVersion,
  129. MaxVersion: maxVersion,
  130. GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
  131. return sTLS.GenerateKeyPair(e.certificate.TLSDecryptionCertificate(), e.certificate.TLSDecryptionPrivateKey(), e.timeFunc, serverName)
  132. },
  133. }
  134. tlsConn := tls.Server(conn, tlsConfig)
  135. err := tlsConn.HandshakeContext(ctx)
  136. if err != nil {
  137. return E.Cause(err, "TLS handshake failed for ", metadata.ClientHello.ServerName, ", ", strings.Join(metadata.ClientHello.SupportedProtos, ", "))
  138. }
  139. if tlsConn.ConnectionState().NegotiatedProtocol == "h2" {
  140. return e.newHTTP2(ctx, this, tlsConn, tlsConfig, metadata, onClose)
  141. } else {
  142. return e.newHTTP1(ctx, this, tlsConn, tlsConfig, metadata)
  143. }
  144. }
  145. func (e *Engine) newHTTP1(ctx context.Context, this N.Dialer, conn net.Conn, tlsConfig *tls.Config, metadata adapter.InboundContext) error {
  146. options := metadata.MITM
  147. defer conn.Close()
  148. reader := bufio.NewReader(conn)
  149. request, err := sHTTP.ReadRequest(reader)
  150. if err != nil {
  151. return E.Cause(err, "read HTTP request")
  152. }
  153. rawRequestURL := request.URL
  154. if tlsConfig != nil {
  155. rawRequestURL.Scheme = "https"
  156. } else {
  157. rawRequestURL.Scheme = "http"
  158. }
  159. if rawRequestURL.Host == "" {
  160. rawRequestURL.Host = request.Host
  161. }
  162. requestURL := rawRequestURL.String()
  163. request.RequestURI = ""
  164. var (
  165. requestMatch bool
  166. requestScript adapter.SurgeScript
  167. requestScriptOptions option.MITMRouteSurgeScriptOptions
  168. )
  169. match:
  170. for _, scriptOptions := range options.Script {
  171. script, loaded := e.script.Script(scriptOptions.Tag)
  172. if !loaded {
  173. e.logger.WarnContext(ctx, "script not found: ", scriptOptions.Tag)
  174. continue
  175. }
  176. surgeScript, isSurge := script.(adapter.SurgeScript)
  177. if !isSurge {
  178. e.logger.WarnContext(ctx, "specified script/", script.Type(), "[", script.Tag(), "] is not a Surge script")
  179. continue
  180. }
  181. for _, pattern := range scriptOptions.Pattern {
  182. if pattern.Build().MatchString(requestURL) {
  183. e.logger.DebugContext(ctx, "match script/", surgeScript.Type(), "[", surgeScript.Tag(), "]")
  184. requestScript = surgeScript
  185. requestScriptOptions = scriptOptions
  186. requestMatch = true
  187. break match
  188. }
  189. }
  190. }
  191. var body []byte
  192. if options.Print && request.ContentLength > 0 && request.ContentLength <= 131072 {
  193. body, err = io.ReadAll(request.Body)
  194. if err != nil {
  195. return E.Cause(err, "read HTTP request body")
  196. }
  197. request.Body = io.NopCloser(bytes.NewReader(body))
  198. }
  199. if options.Print {
  200. e.printRequest(ctx, request, body)
  201. }
  202. if requestScript != nil {
  203. if body == nil && requestScriptOptions.RequiresBody && request.ContentLength > 0 && (requestScriptOptions.MaxSize == 0 && request.ContentLength <= 131072 || request.ContentLength <= requestScriptOptions.MaxSize) {
  204. body, err = io.ReadAll(request.Body)
  205. if err != nil {
  206. return E.Cause(err, "read HTTP request body")
  207. }
  208. request.Body = io.NopCloser(bytes.NewReader(body))
  209. }
  210. var result *adapter.HTTPRequestScriptResult
  211. result, err = requestScript.ExecuteHTTPRequest(ctx, time.Duration(requestScriptOptions.Timeout), request, body, requestScriptOptions.BinaryBodyMode, requestScriptOptions.Arguments)
  212. if err != nil {
  213. return E.Cause(err, "execute script/", requestScript.Type(), "[", requestScript.Tag(), "]")
  214. }
  215. if result.Response != nil {
  216. if result.Response.Status == 0 {
  217. result.Response.Status = http.StatusOK
  218. }
  219. response := &http.Response{
  220. StatusCode: result.Response.Status,
  221. Status: http.StatusText(result.Response.Status),
  222. Proto: request.Proto,
  223. ProtoMajor: request.ProtoMajor,
  224. ProtoMinor: request.ProtoMinor,
  225. Header: result.Response.Headers,
  226. Body: io.NopCloser(bytes.NewReader(result.Response.Body)),
  227. }
  228. err = response.Write(conn)
  229. if err != nil {
  230. return E.Cause(err, "write fake response body")
  231. }
  232. return nil
  233. } else {
  234. if result.URL != "" {
  235. var newURL *url.URL
  236. newURL, err = url.Parse(result.URL)
  237. if err != nil {
  238. return E.Cause(err, "parse updated request URL")
  239. }
  240. request.URL = newURL
  241. newDestination := M.ParseSocksaddrHostPortStr(newURL.Hostname(), newURL.Port())
  242. if newDestination.Port == 0 {
  243. newDestination.Port = metadata.Destination.Port
  244. }
  245. metadata.Destination = newDestination
  246. if tlsConfig != nil {
  247. tlsConfig.ServerName = newURL.Hostname()
  248. }
  249. }
  250. for key, values := range result.Headers {
  251. request.Header[key] = values
  252. }
  253. if newHost := result.Headers.Get("Host"); newHost != "" {
  254. request.Host = newHost
  255. request.Header.Del("Host")
  256. }
  257. if result.Body != nil {
  258. body = result.Body
  259. request.Body = io.NopCloser(bytes.NewReader(body))
  260. request.ContentLength = int64(len(body))
  261. }
  262. }
  263. }
  264. if !requestMatch {
  265. for i, rule := range options.SurgeURLRewrite {
  266. if !rule.Pattern.MatchString(requestURL) {
  267. continue
  268. }
  269. e.logger.DebugContext(ctx, "match url_rewrite[", i, "] => ", rule.String())
  270. if rule.Reject {
  271. return E.New("request rejected by url_rewrite")
  272. } else if rule.Redirect {
  273. w := new(simpleResponseWriter)
  274. http.Redirect(w, request, rule.Destination.String(), http.StatusFound)
  275. err = w.Build(request).Write(conn)
  276. if err != nil {
  277. return E.Cause(err, "write url_rewrite 302 response")
  278. }
  279. return nil
  280. }
  281. requestMatch = true
  282. request.URL = rule.Destination
  283. newDestination := M.ParseSocksaddrHostPortStr(rule.Destination.Hostname(), rule.Destination.Port())
  284. if newDestination.Port == 0 {
  285. newDestination.Port = metadata.Destination.Port
  286. }
  287. metadata.Destination = newDestination
  288. if tlsConfig != nil {
  289. tlsConfig.ServerName = rule.Destination.Hostname()
  290. }
  291. break
  292. }
  293. for i, rule := range options.SurgeHeaderRewrite {
  294. if rule.Response {
  295. continue
  296. }
  297. if !rule.Pattern.MatchString(requestURL) {
  298. continue
  299. }
  300. requestMatch = true
  301. e.logger.DebugContext(ctx, "match header_rewrite[", i, "] => ", rule.String())
  302. switch {
  303. case rule.Add:
  304. if strings.ToLower(rule.Key) == "host" {
  305. request.Host = rule.Value
  306. continue
  307. }
  308. request.Header.Add(rule.Key, rule.Value)
  309. case rule.Delete:
  310. request.Header.Del(rule.Key)
  311. case rule.Replace:
  312. if request.Header.Get(rule.Key) != "" {
  313. request.Header.Set(rule.Key, rule.Value)
  314. }
  315. case rule.ReplaceRegex:
  316. if value := request.Header.Get(rule.Key); value != "" {
  317. request.Header.Set(rule.Key, rule.Match.ReplaceAllString(value, rule.Value))
  318. }
  319. }
  320. }
  321. for i, rule := range options.SurgeBodyRewrite {
  322. if rule.Response {
  323. continue
  324. }
  325. if !rule.Pattern.MatchString(requestURL) {
  326. continue
  327. }
  328. requestMatch = true
  329. e.logger.DebugContext(ctx, "match body_rewrite[", i, "] => ", rule.String())
  330. if body == nil {
  331. if request.ContentLength <= 0 {
  332. e.logger.WarnContext(ctx, "body replace skipped due to non-fixed content length")
  333. break
  334. } else if request.ContentLength > 131072 {
  335. e.logger.WarnContext(ctx, "body replace skipped due to large content length: ", request.ContentLength)
  336. break
  337. }
  338. body, err = io.ReadAll(request.Body)
  339. if err != nil {
  340. return E.Cause(err, "read HTTP request body")
  341. }
  342. }
  343. for mi := 0; i < len(rule.Match); i++ {
  344. body = rule.Match[mi].ReplaceAll(body, []byte(rule.Replace[i]))
  345. }
  346. request.Body = io.NopCloser(bytes.NewReader(body))
  347. request.ContentLength = int64(len(body))
  348. }
  349. }
  350. if !requestMatch {
  351. for i, rule := range options.SurgeMapLocal {
  352. if !rule.Pattern.MatchString(requestURL) {
  353. continue
  354. }
  355. requestMatch = true
  356. e.logger.DebugContext(ctx, "match map_local[", i, "] => ", rule.String())
  357. var (
  358. statusCode = http.StatusOK
  359. headers = make(http.Header)
  360. )
  361. if rule.StatusCode > 0 {
  362. statusCode = rule.StatusCode
  363. }
  364. switch {
  365. case rule.File:
  366. resource, err := os.ReadFile(rule.Data)
  367. if err != nil {
  368. return E.Cause(err, "open map local source")
  369. }
  370. mimeType := mime.TypeByExtension(filepath.Ext(rule.Data))
  371. if mimeType == "" {
  372. mimeType = "application/octet-stream"
  373. }
  374. headers.Set("Content-Type", mimeType)
  375. body = resource
  376. case rule.Text:
  377. headers.Set("Content-Type", "text/plain")
  378. body = []byte(rule.Data)
  379. case rule.TinyGif:
  380. headers.Set("Content-Type", "image/gif")
  381. body = surgeTinyGif()
  382. case rule.Base64:
  383. headers.Set("Content-Type", "application/octet-stream")
  384. body = rule.Base64Data
  385. }
  386. response := &http.Response{
  387. StatusCode: statusCode,
  388. Status: http.StatusText(statusCode),
  389. Proto: request.Proto,
  390. ProtoMajor: request.ProtoMajor,
  391. ProtoMinor: request.ProtoMinor,
  392. Header: headers,
  393. Body: io.NopCloser(bytes.NewReader(body)),
  394. }
  395. err = response.Write(conn)
  396. if err != nil {
  397. return E.Cause(err, "write map local response")
  398. }
  399. return nil
  400. }
  401. }
  402. ctx = adapter.WithContext(ctx, &metadata)
  403. var innerErr atomic.TypedValue[error]
  404. httpClient := &http.Client{
  405. Transport: &http.Transport{
  406. DialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
  407. if len(metadata.DestinationAddresses) > 0 || metadata.Destination.IsIP() {
  408. return dialer.DialSerialNetwork(ctx, this, N.NetworkTCP, metadata.Destination, metadata.DestinationAddresses, metadata.NetworkStrategy, metadata.NetworkType, metadata.FallbackNetworkType, metadata.FallbackDelay)
  409. } else {
  410. return this.DialContext(ctx, N.NetworkTCP, metadata.Destination)
  411. }
  412. },
  413. TLSClientConfig: tlsConfig,
  414. },
  415. CheckRedirect: func(req *http.Request, via []*http.Request) error {
  416. return http.ErrUseLastResponse
  417. },
  418. }
  419. defer httpClient.CloseIdleConnections()
  420. requestCtx, cancel := context.WithCancel(ctx)
  421. defer cancel()
  422. response, err := httpClient.Do(request.WithContext(requestCtx))
  423. if err != nil {
  424. cancel()
  425. return E.Errors(innerErr.Load(), err)
  426. }
  427. var (
  428. responseScript adapter.SurgeScript
  429. responseMatch bool
  430. responseScriptOptions option.MITMRouteSurgeScriptOptions
  431. )
  432. matchResponse:
  433. for _, scriptOptions := range options.Script {
  434. script, loaded := e.script.Script(scriptOptions.Tag)
  435. if !loaded {
  436. e.logger.WarnContext(ctx, "script not found: ", scriptOptions.Tag)
  437. continue
  438. }
  439. surgeScript, isSurge := script.(adapter.SurgeScript)
  440. if !isSurge {
  441. e.logger.WarnContext(ctx, "specified script/", script.Type(), "[", script.Tag(), "] is not a Surge script")
  442. continue
  443. }
  444. for _, pattern := range scriptOptions.Pattern {
  445. if pattern.Build().MatchString(requestURL) {
  446. e.logger.DebugContext(ctx, "match script/", surgeScript.Type(), "[", surgeScript.Tag(), "]")
  447. responseScript = surgeScript
  448. responseScriptOptions = scriptOptions
  449. responseMatch = true
  450. break matchResponse
  451. }
  452. }
  453. }
  454. var responseBody []byte
  455. if options.Print && response.ContentLength > 0 && response.ContentLength <= 131072 {
  456. responseBody, err = io.ReadAll(response.Body)
  457. if err != nil {
  458. return E.Cause(err, "read HTTP response body")
  459. }
  460. response.Body = io.NopCloser(bytes.NewReader(responseBody))
  461. }
  462. if options.Print {
  463. e.printResponse(ctx, request, response, responseBody)
  464. }
  465. if responseScript != nil {
  466. if responseBody == nil && responseScriptOptions.RequiresBody && response.ContentLength > 0 && (responseScriptOptions.MaxSize == 0 && response.ContentLength <= 131072 || response.ContentLength <= responseScriptOptions.MaxSize) {
  467. responseBody, err = io.ReadAll(response.Body)
  468. if err != nil {
  469. return E.Cause(err, "read HTTP response body")
  470. }
  471. response.Body = io.NopCloser(bytes.NewReader(responseBody))
  472. }
  473. var result *adapter.HTTPResponseScriptResult
  474. result, err = responseScript.ExecuteHTTPResponse(ctx, time.Duration(responseScriptOptions.Timeout), request, response, responseBody, responseScriptOptions.BinaryBodyMode, responseScriptOptions.Arguments)
  475. if err != nil {
  476. return E.Cause(err, "execute script/", responseScript.Type(), "[", responseScript.Tag(), "]")
  477. }
  478. if result.Status > 0 {
  479. response.Status = http.StatusText(result.Status)
  480. response.StatusCode = result.Status
  481. }
  482. for key, values := range result.Headers {
  483. response.Header[key] = values
  484. }
  485. if result.Body != nil {
  486. response.Body.Close()
  487. responseBody = result.Body
  488. response.Body = io.NopCloser(bytes.NewReader(responseBody))
  489. response.ContentLength = int64(len(responseBody))
  490. }
  491. }
  492. if !responseMatch {
  493. for i, rule := range options.SurgeHeaderRewrite {
  494. if !rule.Response {
  495. continue
  496. }
  497. if !rule.Pattern.MatchString(requestURL) {
  498. continue
  499. }
  500. responseMatch = true
  501. e.logger.DebugContext(ctx, "match header_rewrite[", i, "] => ", rule.String())
  502. switch {
  503. case rule.Add:
  504. response.Header.Add(rule.Key, rule.Value)
  505. case rule.Delete:
  506. response.Header.Del(rule.Key)
  507. case rule.Replace:
  508. if response.Header.Get(rule.Key) != "" {
  509. response.Header.Set(rule.Key, rule.Value)
  510. }
  511. case rule.ReplaceRegex:
  512. if value := response.Header.Get(rule.Key); value != "" {
  513. response.Header.Set(rule.Key, rule.Match.ReplaceAllString(value, rule.Value))
  514. }
  515. }
  516. }
  517. for i, rule := range options.SurgeBodyRewrite {
  518. if !rule.Response {
  519. continue
  520. }
  521. if !rule.Pattern.MatchString(requestURL) {
  522. continue
  523. }
  524. responseMatch = true
  525. e.logger.DebugContext(ctx, "match body_rewrite[", i, "] => ", rule.String())
  526. if responseBody == nil {
  527. if response.ContentLength <= 0 {
  528. e.logger.WarnContext(ctx, "body replace skipped due to non-fixed content length")
  529. break
  530. } else if response.ContentLength > 131072 {
  531. e.logger.WarnContext(ctx, "body replace skipped due to large content length: ", request.ContentLength)
  532. break
  533. }
  534. responseBody, err = io.ReadAll(response.Body)
  535. if err != nil {
  536. return E.Cause(err, "read HTTP request body")
  537. }
  538. }
  539. for mi := 0; i < len(rule.Match); i++ {
  540. responseBody = rule.Match[mi].ReplaceAll(responseBody, []byte(rule.Replace[i]))
  541. }
  542. response.Body = io.NopCloser(bytes.NewReader(responseBody))
  543. response.ContentLength = int64(len(responseBody))
  544. }
  545. }
  546. if !options.Print && !requestMatch && !responseMatch {
  547. e.logger.WarnContext(ctx, "request not modified")
  548. }
  549. err = response.Write(conn)
  550. if err != nil {
  551. return E.Errors(E.Cause(err, "write HTTP response"), innerErr.Load())
  552. } else if innerErr.Load() != nil {
  553. return E.Cause(innerErr.Load(), "write HTTP response")
  554. }
  555. return nil
  556. }
  557. func (e *Engine) newHTTP2(ctx context.Context, this N.Dialer, conn net.Conn, tlsConfig *tls.Config, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) error {
  558. httpTransport := &http.Transport{
  559. ForceAttemptHTTP2: true,
  560. DialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
  561. ctx = adapter.WithContext(ctx, &metadata)
  562. if len(metadata.DestinationAddresses) > 0 || metadata.Destination.IsIP() {
  563. return dialer.DialSerialNetwork(ctx, this, N.NetworkTCP, metadata.Destination, metadata.DestinationAddresses, metadata.NetworkStrategy, metadata.NetworkType, metadata.FallbackNetworkType, metadata.FallbackDelay)
  564. } else {
  565. return this.DialContext(ctx, N.NetworkTCP, metadata.Destination)
  566. }
  567. },
  568. TLSClientConfig: tlsConfig,
  569. }
  570. err := http2.ConfigureTransport(httpTransport)
  571. if err != nil {
  572. return E.Cause(err, "configure HTTP/2 transport")
  573. }
  574. handler := &engineHandler{
  575. Engine: e,
  576. conn: conn,
  577. tlsConfig: tlsConfig,
  578. dialer: this,
  579. metadata: metadata,
  580. httpClient: &http.Client{
  581. Transport: httpTransport,
  582. CheckRedirect: func(req *http.Request, via []*http.Request) error {
  583. return http.ErrUseLastResponse
  584. },
  585. },
  586. onClose: onClose,
  587. }
  588. http2Server := &http2.Server{
  589. MaxReadFrameSize: math.MaxUint32,
  590. }
  591. http2Server.ServeConn(conn, &http2.ServeConnOpts{
  592. Context: ctx,
  593. Handler: handler,
  594. })
  595. return nil
  596. }
  597. type engineHandler struct {
  598. *Engine
  599. conn net.Conn
  600. tlsConfig *tls.Config
  601. dialer N.Dialer
  602. metadata adapter.InboundContext
  603. onClose N.CloseHandlerFunc
  604. httpClient *http.Client
  605. }
  606. func (e *engineHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
  607. err := e.serveHTTP(request.Context(), writer, request)
  608. if err != nil {
  609. if E.IsClosedOrCanceled(err) {
  610. e.logger.DebugContext(request.Context(), E.Cause(err, "connection closed"))
  611. } else {
  612. e.logger.ErrorContext(request.Context(), err)
  613. }
  614. }
  615. }
  616. func (e *engineHandler) serveHTTP(ctx context.Context, writer http.ResponseWriter, request *http.Request) error {
  617. options := e.metadata.MITM
  618. rawRequestURL := request.URL
  619. rawRequestURL.Scheme = "https"
  620. if rawRequestURL.Host == "" {
  621. rawRequestURL.Host = request.Host
  622. }
  623. requestURL := rawRequestURL.String()
  624. request.RequestURI = ""
  625. var (
  626. requestMatch bool
  627. requestScript adapter.SurgeScript
  628. requestScriptOptions option.MITMRouteSurgeScriptOptions
  629. )
  630. match:
  631. for _, scriptOptions := range options.Script {
  632. script, loaded := e.script.Script(scriptOptions.Tag)
  633. if !loaded {
  634. e.logger.WarnContext(ctx, "script not found: ", scriptOptions.Tag)
  635. continue
  636. }
  637. surgeScript, isSurge := script.(adapter.SurgeScript)
  638. if !isSurge {
  639. e.logger.WarnContext(ctx, "specified script/", script.Type(), "[", script.Tag(), "] is not a Surge script")
  640. continue
  641. }
  642. for _, pattern := range scriptOptions.Pattern {
  643. if pattern.Build().MatchString(requestURL) {
  644. e.logger.DebugContext(ctx, "match script/", surgeScript.Type(), "[", surgeScript.Tag(), "]")
  645. requestScript = surgeScript
  646. requestScriptOptions = scriptOptions
  647. requestMatch = true
  648. break match
  649. }
  650. }
  651. }
  652. var (
  653. body []byte
  654. err error
  655. )
  656. if options.Print && request.ContentLength > 0 && request.ContentLength <= 131072 {
  657. body, err = io.ReadAll(request.Body)
  658. if err != nil {
  659. return E.Cause(err, "read HTTP request body")
  660. }
  661. request.Body.Close()
  662. request.Body = io.NopCloser(bytes.NewReader(body))
  663. }
  664. if options.Print {
  665. e.printRequest(ctx, request, body)
  666. }
  667. if requestScript != nil {
  668. if body == nil && requestScriptOptions.RequiresBody && request.ContentLength > 0 && (requestScriptOptions.MaxSize == 0 && request.ContentLength <= 131072 || request.ContentLength <= requestScriptOptions.MaxSize) {
  669. body, err = io.ReadAll(request.Body)
  670. if err != nil {
  671. return E.Cause(err, "read HTTP request body")
  672. }
  673. request.Body.Close()
  674. request.Body = io.NopCloser(bytes.NewReader(body))
  675. }
  676. result, err := requestScript.ExecuteHTTPRequest(ctx, time.Duration(requestScriptOptions.Timeout), request, body, requestScriptOptions.BinaryBodyMode, requestScriptOptions.Arguments)
  677. if err != nil {
  678. return E.Cause(err, "execute script/", requestScript.Type(), "[", requestScript.Tag(), "]")
  679. }
  680. if result.Response != nil {
  681. if result.Response.Status == 0 {
  682. result.Response.Status = http.StatusOK
  683. }
  684. for key, values := range result.Response.Headers {
  685. writer.Header()[key] = values
  686. }
  687. writer.WriteHeader(result.Response.Status)
  688. if result.Response.Body != nil {
  689. _, err = writer.Write(result.Response.Body)
  690. if err != nil {
  691. return E.Cause(err, "write fake response body")
  692. }
  693. }
  694. return nil
  695. } else {
  696. if result.URL != "" {
  697. var newURL *url.URL
  698. newURL, err = url.Parse(result.URL)
  699. if err != nil {
  700. return E.Cause(err, "parse updated request URL")
  701. }
  702. request.URL = newURL
  703. newDestination := M.ParseSocksaddrHostPortStr(newURL.Hostname(), newURL.Port())
  704. if newDestination.Port == 0 {
  705. newDestination.Port = e.metadata.Destination.Port
  706. }
  707. e.metadata.Destination = newDestination
  708. e.tlsConfig.ServerName = newURL.Hostname()
  709. }
  710. for key, values := range result.Headers {
  711. request.Header[key] = values
  712. }
  713. if newHost := result.Headers.Get("Host"); newHost != "" {
  714. request.Host = newHost
  715. request.Header.Del("Host")
  716. }
  717. if result.Body != nil {
  718. io.Copy(io.Discard, request.Body)
  719. request.Body = io.NopCloser(bytes.NewReader(result.Body))
  720. request.ContentLength = int64(len(result.Body))
  721. }
  722. }
  723. }
  724. if !requestMatch {
  725. for i, rule := range options.SurgeURLRewrite {
  726. if !rule.Pattern.MatchString(requestURL) {
  727. continue
  728. }
  729. e.logger.DebugContext(ctx, "match url_rewrite[", i, "] => ", rule.String())
  730. if rule.Reject {
  731. return E.New("request rejected by url_rewrite")
  732. } else if rule.Redirect {
  733. http.Redirect(writer, request, rule.Destination.String(), http.StatusFound)
  734. return nil
  735. }
  736. requestMatch = true
  737. request.URL = rule.Destination
  738. newDestination := M.ParseSocksaddrHostPortStr(rule.Destination.Hostname(), rule.Destination.Port())
  739. if newDestination.Port == 0 {
  740. newDestination.Port = e.metadata.Destination.Port
  741. }
  742. e.metadata.Destination = newDestination
  743. e.tlsConfig.ServerName = rule.Destination.Hostname()
  744. break
  745. }
  746. for i, rule := range options.SurgeHeaderRewrite {
  747. if rule.Response {
  748. continue
  749. }
  750. if !rule.Pattern.MatchString(requestURL) {
  751. continue
  752. }
  753. requestMatch = true
  754. e.logger.DebugContext(ctx, "match header_rewrite[", i, "] => ", rule.String())
  755. switch {
  756. case rule.Add:
  757. if strings.ToLower(rule.Key) == "host" {
  758. request.Host = rule.Value
  759. continue
  760. }
  761. request.Header.Add(rule.Key, rule.Value)
  762. case rule.Delete:
  763. request.Header.Del(rule.Key)
  764. case rule.Replace:
  765. if request.Header.Get(rule.Key) != "" {
  766. request.Header.Set(rule.Key, rule.Value)
  767. }
  768. case rule.ReplaceRegex:
  769. if value := request.Header.Get(rule.Key); value != "" {
  770. request.Header.Set(rule.Key, rule.Match.ReplaceAllString(value, rule.Value))
  771. }
  772. }
  773. }
  774. for i, rule := range options.SurgeBodyRewrite {
  775. if rule.Response {
  776. continue
  777. }
  778. if !rule.Pattern.MatchString(requestURL) {
  779. continue
  780. }
  781. requestMatch = true
  782. e.logger.DebugContext(ctx, "match body_rewrite[", i, "] => ", rule.String())
  783. var body []byte
  784. if request.ContentLength <= 0 {
  785. e.logger.WarnContext(ctx, "body replace skipped due to non-fixed content length")
  786. break
  787. } else if request.ContentLength > 131072 {
  788. e.logger.WarnContext(ctx, "body replace skipped due to large content length: ", request.ContentLength)
  789. break
  790. }
  791. body, err := io.ReadAll(request.Body)
  792. if err != nil {
  793. return E.Cause(err, "read HTTP request body")
  794. }
  795. request.Body.Close()
  796. for mi := 0; i < len(rule.Match); i++ {
  797. body = rule.Match[mi].ReplaceAll(body, []byte(rule.Replace[i]))
  798. }
  799. request.Body = io.NopCloser(bytes.NewReader(body))
  800. request.ContentLength = int64(len(body))
  801. }
  802. }
  803. if !requestMatch {
  804. for i, rule := range options.SurgeMapLocal {
  805. if !rule.Pattern.MatchString(requestURL) {
  806. continue
  807. }
  808. requestMatch = true
  809. e.logger.DebugContext(ctx, "match map_local[", i, "] => ", rule.String())
  810. go func() {
  811. io.Copy(io.Discard, request.Body)
  812. request.Body.Close()
  813. }()
  814. var (
  815. statusCode = http.StatusOK
  816. headers = make(http.Header)
  817. body []byte
  818. )
  819. if rule.StatusCode > 0 {
  820. statusCode = rule.StatusCode
  821. }
  822. switch {
  823. case rule.File:
  824. resource, err := os.ReadFile(rule.Data)
  825. if err != nil {
  826. return E.Cause(err, "open map local source")
  827. }
  828. mimeType := mime.TypeByExtension(filepath.Ext(rule.Data))
  829. if mimeType == "" {
  830. mimeType = "application/octet-stream"
  831. }
  832. headers.Set("Content-Type", mimeType)
  833. body = resource
  834. case rule.Text:
  835. headers.Set("Content-Type", "text/plain")
  836. body = []byte(rule.Data)
  837. case rule.TinyGif:
  838. headers.Set("Content-Type", "image/gif")
  839. body = surgeTinyGif()
  840. case rule.Base64:
  841. headers.Set("Content-Type", "application/octet-stream")
  842. body = rule.Base64Data
  843. }
  844. for key, values := range headers {
  845. writer.Header()[key] = values
  846. }
  847. writer.WriteHeader(statusCode)
  848. _, err = writer.Write(body)
  849. if err != nil {
  850. return E.Cause(err, "write map local response")
  851. }
  852. return nil
  853. }
  854. }
  855. requestCtx, cancel := context.WithCancel(ctx)
  856. defer cancel()
  857. response, err := e.httpClient.Do(request.WithContext(requestCtx))
  858. if err != nil {
  859. cancel()
  860. return E.Cause(err, "exchange request")
  861. }
  862. var (
  863. responseScript adapter.SurgeScript
  864. responseMatch bool
  865. responseScriptOptions option.MITMRouteSurgeScriptOptions
  866. )
  867. matchResponse:
  868. for _, scriptOptions := range options.Script {
  869. script, loaded := e.script.Script(scriptOptions.Tag)
  870. if !loaded {
  871. e.logger.WarnContext(ctx, "script not found: ", scriptOptions.Tag)
  872. continue
  873. }
  874. surgeScript, isSurge := script.(adapter.SurgeScript)
  875. if !isSurge {
  876. e.logger.WarnContext(ctx, "specified script/", script.Type(), "[", script.Tag(), "] is not a Surge script")
  877. continue
  878. }
  879. for _, pattern := range scriptOptions.Pattern {
  880. if pattern.Build().MatchString(requestURL) {
  881. e.logger.DebugContext(ctx, "match script/", surgeScript.Type(), "[", surgeScript.Tag(), "]")
  882. responseScript = surgeScript
  883. responseScriptOptions = scriptOptions
  884. responseMatch = true
  885. break matchResponse
  886. }
  887. }
  888. }
  889. var responseBody []byte
  890. if options.Print && response.ContentLength > 0 && response.ContentLength <= 131072 {
  891. responseBody, err = io.ReadAll(response.Body)
  892. if err != nil {
  893. return E.Cause(err, "read HTTP response body")
  894. }
  895. response.Body.Close()
  896. response.Body = io.NopCloser(bytes.NewReader(responseBody))
  897. }
  898. if options.Print {
  899. e.printResponse(ctx, request, response, responseBody)
  900. }
  901. if responseScript != nil {
  902. if responseBody == nil && responseScriptOptions.RequiresBody && response.ContentLength > 0 && (responseScriptOptions.MaxSize == 0 && response.ContentLength <= 131072 || response.ContentLength <= responseScriptOptions.MaxSize) {
  903. responseBody, err = io.ReadAll(response.Body)
  904. if err != nil {
  905. return E.Cause(err, "read HTTP response body")
  906. }
  907. response.Body.Close()
  908. response.Body = io.NopCloser(bytes.NewReader(responseBody))
  909. }
  910. var result *adapter.HTTPResponseScriptResult
  911. result, err = responseScript.ExecuteHTTPResponse(ctx, time.Duration(responseScriptOptions.Timeout), request, response, responseBody, responseScriptOptions.BinaryBodyMode, responseScriptOptions.Arguments)
  912. if err != nil {
  913. return E.Cause(err, "execute script/", responseScript.Type(), "[", responseScript.Tag(), "]")
  914. }
  915. if result.Status > 0 {
  916. response.Status = http.StatusText(result.Status)
  917. response.StatusCode = result.Status
  918. }
  919. for key, values := range result.Headers {
  920. response.Header[key] = values
  921. }
  922. if result.Body != nil {
  923. response.Body.Close()
  924. response.Body = io.NopCloser(bytes.NewReader(result.Body))
  925. response.ContentLength = int64(len(result.Body))
  926. }
  927. }
  928. if !responseMatch {
  929. for i, rule := range options.SurgeHeaderRewrite {
  930. if !rule.Response {
  931. continue
  932. }
  933. if !rule.Pattern.MatchString(requestURL) {
  934. continue
  935. }
  936. responseMatch = true
  937. e.logger.DebugContext(ctx, "match header_rewrite[", i, "] => ", rule.String())
  938. switch {
  939. case rule.Add:
  940. response.Header.Add(rule.Key, rule.Value)
  941. case rule.Delete:
  942. response.Header.Del(rule.Key)
  943. case rule.Replace:
  944. if response.Header.Get(rule.Key) != "" {
  945. response.Header.Set(rule.Key, rule.Value)
  946. }
  947. case rule.ReplaceRegex:
  948. if value := response.Header.Get(rule.Key); value != "" {
  949. response.Header.Set(rule.Key, rule.Match.ReplaceAllString(value, rule.Value))
  950. }
  951. }
  952. }
  953. for i, rule := range options.SurgeBodyRewrite {
  954. if !rule.Response {
  955. continue
  956. }
  957. if !rule.Pattern.MatchString(requestURL) {
  958. continue
  959. }
  960. responseMatch = true
  961. e.logger.DebugContext(ctx, "match body_rewrite[", i, "] => ", rule.String())
  962. if responseBody == nil {
  963. if response.ContentLength <= 0 {
  964. e.logger.WarnContext(ctx, "body replace skipped due to non-fixed content length")
  965. break
  966. } else if response.ContentLength > 131072 {
  967. e.logger.WarnContext(ctx, "body replace skipped due to large content length: ", request.ContentLength)
  968. break
  969. }
  970. responseBody, err = io.ReadAll(response.Body)
  971. if err != nil {
  972. return E.Cause(err, "read HTTP request body")
  973. }
  974. response.Body.Close()
  975. }
  976. for mi := 0; i < len(rule.Match); i++ {
  977. responseBody = rule.Match[mi].ReplaceAll(responseBody, []byte(rule.Replace[i]))
  978. }
  979. response.Body = io.NopCloser(bytes.NewReader(responseBody))
  980. response.ContentLength = int64(len(responseBody))
  981. }
  982. }
  983. if !options.Print && !requestMatch && !responseMatch {
  984. e.logger.WarnContext(ctx, "request not modified")
  985. }
  986. for key, values := range response.Header {
  987. writer.Header()[key] = values
  988. }
  989. writer.WriteHeader(response.StatusCode)
  990. _, err = io.Copy(writer, response.Body)
  991. response.Body.Close()
  992. if err != nil {
  993. return E.Cause(err, "write HTTP response")
  994. }
  995. return nil
  996. }
  997. func (e *Engine) printRequest(ctx context.Context, request *http.Request, body []byte) {
  998. var builder strings.Builder
  999. builder.WriteString(F.ToString(request.Proto, " ", request.Method, " ", request.URL))
  1000. builder.WriteString("\n")
  1001. if request.URL.Hostname() != "" && request.URL.Hostname() != request.Host {
  1002. builder.WriteString("Host: ")
  1003. builder.WriteString(request.Host)
  1004. builder.WriteString("\n")
  1005. }
  1006. for key, values := range request.Header {
  1007. for _, value := range values {
  1008. builder.WriteString(key)
  1009. builder.WriteString(": ")
  1010. builder.WriteString(value)
  1011. builder.WriteString("\n")
  1012. }
  1013. }
  1014. if len(body) > 0 {
  1015. builder.WriteString("\n")
  1016. if !bytes.ContainsFunc(body, func(r rune) bool {
  1017. return !unicode.IsPrint(r) && !unicode.IsSpace(r)
  1018. }) {
  1019. builder.Write(body)
  1020. } else {
  1021. builder.WriteString("(body not printable)")
  1022. }
  1023. }
  1024. e.logger.InfoContext(ctx, "request: ", builder.String())
  1025. }
  1026. func (e *Engine) printResponse(ctx context.Context, request *http.Request, response *http.Response, body []byte) {
  1027. var builder strings.Builder
  1028. builder.WriteString(F.ToString(response.Proto, " ", response.Status, " ", request.URL))
  1029. builder.WriteString("\n")
  1030. for key, values := range response.Header {
  1031. for _, value := range values {
  1032. builder.WriteString(key)
  1033. builder.WriteString(": ")
  1034. builder.WriteString(value)
  1035. builder.WriteString("\n")
  1036. }
  1037. }
  1038. if len(body) > 0 {
  1039. builder.WriteString("\n")
  1040. if !bytes.ContainsFunc(body, func(r rune) bool {
  1041. return !unicode.IsPrint(r) && !unicode.IsSpace(r)
  1042. }) {
  1043. builder.Write(body)
  1044. } else {
  1045. builder.WriteString("(body not printable)")
  1046. }
  1047. }
  1048. e.logger.InfoContext(ctx, "response: ", builder.String())
  1049. }
  1050. type simpleResponseWriter struct {
  1051. statusCode int
  1052. header http.Header
  1053. body bytes.Buffer
  1054. }
  1055. func (w *simpleResponseWriter) Build(request *http.Request) *http.Response {
  1056. return &http.Response{
  1057. StatusCode: w.statusCode,
  1058. Status: http.StatusText(w.statusCode),
  1059. Proto: request.Proto,
  1060. ProtoMajor: request.ProtoMajor,
  1061. ProtoMinor: request.ProtoMinor,
  1062. Header: w.header,
  1063. Body: io.NopCloser(&w.body),
  1064. }
  1065. }
  1066. func (w *simpleResponseWriter) Header() http.Header {
  1067. if w.header == nil {
  1068. w.header = make(http.Header)
  1069. }
  1070. return w.header
  1071. }
  1072. func (w *simpleResponseWriter) Write(b []byte) (int, error) {
  1073. return w.body.Write(b)
  1074. }
  1075. func (w *simpleResponseWriter) WriteHeader(statusCode int) {
  1076. w.statusCode = statusCode
  1077. }