1
0

tcp.go 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
  1. package tcp
  2. import (
  3. "context"
  4. "fmt"
  5. "io"
  6. "github.com/xtls/xray-core/common/buf"
  7. "github.com/xtls/xray-core/common/net"
  8. "github.com/xtls/xray-core/common/task"
  9. "github.com/xtls/xray-core/transport/internet"
  10. "github.com/xtls/xray-core/transport/pipe"
  11. )
// Server is a TCP server used in tests: it accepts connections on Listen:Port
// and, for each chunk read from a connection, writes the result of
// MsgProcessor back to the same connection (see handleConnection).
type Server struct {
	// Port is the port to bind. StartContext overwrites it with the port the
	// listener actually bound.
	Port net.Port
	// MsgProcessor is called with each chunk read from a connection; its result
	// is what gets written back to the peer.
	MsgProcessor func(msg []byte) []byte
	// NOTE(review): ShouldClose is not read anywhere in this file; confirm
	// whether any caller still relies on it.
	ShouldClose bool
	// SendFirst, when non-empty, is written to every accepted connection before
	// the connection is served.
	SendFirst []byte
	// Listen is the address to bind; when nil, StartContext uses net.LocalHostIP.
	Listen net.Address
	// listener is set by StartContext and closed by Close.
	listener net.Listener
}
  20. func (server *Server) Start() (net.Destination, error) {
  21. return server.StartContext(context.Background(), nil)
  22. }
  23. func (server *Server) StartContext(ctx context.Context, sockopt *internet.SocketConfig) (net.Destination, error) {
  24. listenerAddr := server.Listen
  25. if listenerAddr == nil {
  26. listenerAddr = net.LocalHostIP
  27. }
  28. listener, err := internet.ListenSystem(ctx, &net.TCPAddr{
  29. IP: listenerAddr.IP(),
  30. Port: int(server.Port),
  31. }, sockopt)
  32. if err != nil {
  33. return net.Destination{}, err
  34. }
  35. localAddr := listener.Addr().(*net.TCPAddr)
  36. server.Port = net.Port(localAddr.Port)
  37. server.listener = listener
  38. go server.acceptConnections(listener.(*net.TCPListener))
  39. return net.TCPDestination(net.IPAddress(localAddr.IP), net.Port(localAddr.Port)), nil
  40. }
  41. func (server *Server) acceptConnections(listener *net.TCPListener) {
  42. for {
  43. conn, err := listener.Accept()
  44. if err != nil {
  45. fmt.Printf("Failed accept TCP connection: %v\n", err)
  46. return
  47. }
  48. go server.handleConnection(conn)
  49. }
  50. }
  51. func (server *Server) handleConnection(conn net.Conn) {
  52. if len(server.SendFirst) > 0 {
  53. conn.Write(server.SendFirst)
  54. }
  55. pReader, pWriter := pipe.New(pipe.WithoutSizeLimit())
  56. err := task.Run(context.Background(), func() error {
  57. defer pWriter.Close()
  58. for {
  59. b := buf.New()
  60. if _, err := b.ReadFrom(conn); err != nil {
  61. if err == io.EOF {
  62. return nil
  63. }
  64. return err
  65. }
  66. copy(b.Bytes(), server.MsgProcessor(b.Bytes()))
  67. if err := pWriter.WriteMultiBuffer(buf.MultiBuffer{b}); err != nil {
  68. return err
  69. }
  70. }
  71. }, func() error {
  72. defer pReader.Interrupt()
  73. w := buf.NewWriter(conn)
  74. for {
  75. mb, err := pReader.ReadMultiBuffer()
  76. if err != nil {
  77. if err == io.EOF {
  78. return nil
  79. }
  80. return err
  81. }
  82. if err := w.WriteMultiBuffer(mb); err != nil {
  83. return err
  84. }
  85. }
  86. })
  87. if err != nil {
  88. fmt.Println("failed to transfer data: ", err.Error())
  89. }
  90. conn.Close()
  91. }
  92. func (server *Server) Close() error {
  93. return server.listener.Close()
  94. }