session.go 1.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. package mux
  2. import (
  3. "io"
  4. "net"
  5. "github.com/sagernet/sing/common"
  6. "github.com/sagernet/sing/common/buf"
  7. "github.com/sagernet/sing/common/bufio"
  8. N "github.com/sagernet/sing/common/network"
  9. "github.com/sagernet/smux"
  10. )
  11. type abstractSession interface {
  12. Open() (net.Conn, error)
  13. Accept() (net.Conn, error)
  14. NumStreams() int
  15. Close() error
  16. IsClosed() bool
  17. }
  18. var _ abstractSession = (*smuxSession)(nil)
  19. type smuxSession struct {
  20. *smux.Session
  21. }
  22. func (s *smuxSession) Open() (net.Conn, error) {
  23. return s.OpenStream()
  24. }
  25. func (s *smuxSession) Accept() (net.Conn, error) {
  26. return s.AcceptStream()
  27. }
  28. type protocolConn struct {
  29. net.Conn
  30. protocol Protocol
  31. protocolWritten bool
  32. }
  33. func (c *protocolConn) Write(p []byte) (n int, err error) {
  34. if c.protocolWritten {
  35. return c.Conn.Write(p)
  36. }
  37. _buffer := buf.StackNewSize(2 + len(p))
  38. defer common.KeepAlive(_buffer)
  39. buffer := common.Dup(_buffer)
  40. defer buffer.Release()
  41. EncodeRequest(buffer, Request{
  42. Protocol: c.protocol,
  43. })
  44. common.Must(common.Error(buffer.Write(p)))
  45. n, err = c.Conn.Write(buffer.Bytes())
  46. if err == nil {
  47. n--
  48. }
  49. c.protocolWritten = true
  50. return n, err
  51. }
  52. func (c *protocolConn) ReadFrom(r io.Reader) (n int64, err error) {
  53. if !c.protocolWritten {
  54. return bufio.ReadFrom0(c, r)
  55. }
  56. return bufio.Copy(c.Conn, r)
  57. }
  58. func (c *protocolConn) Upstream() any {
  59. return c.Conn
  60. }
  61. type vectorisedProtocolConn struct {
  62. protocolConn
  63. N.VectorisedWriter
  64. }
  65. func (c *vectorisedProtocolConn) WriteVectorised(buffers []*buf.Buffer) error {
  66. if c.protocolWritten {
  67. return c.VectorisedWriter.WriteVectorised(buffers)
  68. }
  69. c.protocolWritten = true
  70. _buffer := buf.StackNewSize(2)
  71. defer common.KeepAlive(_buffer)
  72. buffer := common.Dup(_buffer)
  73. defer buffer.Release()
  74. EncodeRequest(buffer, Request{
  75. Protocol: c.protocol,
  76. })
  77. return c.VectorisedWriter.WriteVectorised(append([]*buf.Buffer{buffer}, buffers...))
  78. }