123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240 |
- package mux
- import (
- "encoding/binary"
- "io"
- "net"
- C "github.com/sagernet/sing-box/constant"
- "github.com/sagernet/sing/common"
- "github.com/sagernet/sing/common/buf"
- E "github.com/sagernet/sing/common/exceptions"
- M "github.com/sagernet/sing/common/metadata"
- N "github.com/sagernet/sing/common/network"
- "github.com/sagernet/sing/common/rw"
- "github.com/sagernet/smux"
- "github.com/hashicorp/yamux"
- )
- var Destination = M.Socksaddr{
- Fqdn: "sp.mux.sing-box.arpa",
- Port: 444,
- }
// Protocol identifies a stream multiplexing implementation. It is a single
// byte so it can be written directly into the request header.
type Protocol byte

// Supported multiplexing protocols.
const (
	// ProtocolSMux selects the smux implementation.
	ProtocolSMux Protocol = 0
	// ProtocolYAMux selects the yamux implementation.
	ProtocolYAMux Protocol = 1
)
- func ParseProtocol(name string) (Protocol, error) {
- switch name {
- case "", "smux":
- return ProtocolSMux, nil
- case "yamux":
- return ProtocolYAMux, nil
- default:
- return ProtocolYAMux, E.New("unknown multiplex protocol: ", name)
- }
- }
- func (p Protocol) newServer(conn net.Conn) (abstractSession, error) {
- switch p {
- case ProtocolSMux:
- session, err := smux.Server(conn, smuxConfig())
- if err != nil {
- return nil, err
- }
- return &smuxSession{session}, nil
- case ProtocolYAMux:
- return yamux.Server(conn, yaMuxConfig())
- default:
- panic("unknown protocol")
- }
- }
- func (p Protocol) newClient(conn net.Conn) (abstractSession, error) {
- switch p {
- case ProtocolSMux:
- session, err := smux.Client(conn, smuxConfig())
- if err != nil {
- return nil, err
- }
- return &smuxSession{session}, nil
- case ProtocolYAMux:
- return yamux.Client(conn, yaMuxConfig())
- default:
- panic("unknown protocol")
- }
- }
- func smuxConfig() *smux.Config {
- config := smux.DefaultConfig()
- config.KeepAliveDisabled = true
- return config
- }
- func yaMuxConfig() *yamux.Config {
- config := yamux.DefaultConfig()
- config.LogOutput = io.Discard
- config.StreamCloseTimeout = C.TCPTimeout
- config.StreamOpenTimeout = C.TCPTimeout
- return config
- }
- func (p Protocol) String() string {
- switch p {
- case ProtocolSMux:
- return "smux"
- case ProtocolYAMux:
- return "yamux"
- default:
- return "unknown"
- }
- }
// version0 is the only request version ReadRequest accepts and the one
// EncodeRequest writes.
const version0 = 0
- type Request struct {
- Protocol Protocol
- }
- func ReadRequest(reader io.Reader) (*Request, error) {
- version, err := rw.ReadByte(reader)
- if err != nil {
- return nil, err
- }
- if version != version0 {
- return nil, E.New("unsupported version: ", version)
- }
- protocol, err := rw.ReadByte(reader)
- if err != nil {
- return nil, err
- }
- if protocol > byte(ProtocolYAMux) {
- return nil, E.New("unsupported protocol: ", protocol)
- }
- return &Request{Protocol: Protocol(protocol)}, nil
- }
- func EncodeRequest(buffer *buf.Buffer, request Request) {
- buffer.WriteByte(version0)
- buffer.WriteByte(byte(request.Protocol))
- }
// Bits of the 16-bit flags field in a stream request.
const (
	// flagUDP marks the stream as carrying UDP instead of TCP.
	flagUDP = 1 << 0
	// flagAddr sets StreamRequest.PacketAddr; it is only honoured together
	// with flagUDP when decoding.
	flagAddr = 1 << 1
)

// Status byte values of a stream response.
const (
	statusSuccess = iota
	// statusError is followed on the wire by an error message string.
	statusError
)
- type StreamRequest struct {
- Network string
- Destination M.Socksaddr
- PacketAddr bool
- }
- func ReadStreamRequest(reader io.Reader) (*StreamRequest, error) {
- var flags uint16
- err := binary.Read(reader, binary.BigEndian, &flags)
- if err != nil {
- return nil, err
- }
- destination, err := M.SocksaddrSerializer.ReadAddrPort(reader)
- if err != nil {
- return nil, err
- }
- var network string
- var udpAddr bool
- if flags&flagUDP == 0 {
- network = N.NetworkTCP
- } else {
- network = N.NetworkUDP
- udpAddr = flags&flagAddr != 0
- }
- return &StreamRequest{network, destination, udpAddr}, nil
- }
- func requestLen(request StreamRequest) int {
- var rLen int
- rLen += 1 // version
- rLen += 2 // flags
- rLen += M.SocksaddrSerializer.AddrPortLen(request.Destination)
- return rLen
- }
- func EncodeStreamRequest(request StreamRequest, buffer *buf.Buffer) {
- destination := request.Destination
- var flags uint16
- if request.Network == N.NetworkUDP {
- flags |= flagUDP
- }
- if request.PacketAddr {
- flags |= flagAddr
- if !destination.IsValid() {
- destination = Destination
- }
- }
- common.Must(
- binary.Write(buffer, binary.BigEndian, flags),
- M.SocksaddrSerializer.WriteAddrPort(buffer, destination),
- )
- }
// StreamResponse is the reply to a stream request.
type StreamResponse struct {
	// Status is the raw status byte (statusSuccess or statusError).
	Status uint8
	// Message is the error text; ReadStreamResponse fills it only when
	// Status is statusError.
	Message string
}
- func ReadStreamResponse(reader io.Reader) (*StreamResponse, error) {
- var response StreamResponse
- status, err := rw.ReadByte(reader)
- if err != nil {
- return nil, err
- }
- response.Status = status
- if status == statusError {
- response.Message, err = rw.ReadVString(reader)
- if err != nil {
- return nil, err
- }
- }
- return &response, nil
- }
// wrapStream embeds a net.Conn and overrides Read and Write so that their
// errors pass through wrapError.
type wrapStream struct {
	net.Conn
}
- func (w *wrapStream) Read(p []byte) (n int, err error) {
- n, err = w.Conn.Read(p)
- err = wrapError(err)
- return
- }
- func (w *wrapStream) Write(p []byte) (n int, err error) {
- n, err = w.Conn.Write(p)
- err = wrapError(err)
- return
- }
- func (w *wrapStream) WriteIsThreadUnsafe() {
- }
- func (w *wrapStream) Upstream() any {
- return w.Conn
- }
- func wrapError(err error) error {
- switch err {
- case yamux.ErrStreamClosed:
- return io.EOF
- default:
- return err
- }
- }
|