فهرست منبع

Improve tls dialer and listener

世界 3 سال پیش
والد
کامیت
1f05420745
8فایلهای تغییر یافته به همراه265 افزوده شده و 28 حذف شده
  1. 56 2
      common/dialer/tls.go
  2. 1 1
      inbound/builder.go
  3. 17 4
      inbound/http.go
  4. 80 0
      inbound/tls.go
  5. 13 2
      inbound/vmess.go
  6. 9 5
      option/inbound.go
  7. 12 14
      option/outbound.go
  8. 77 0
      option/tls.go

+ 56 - 2
common/dialer/tls.go

@@ -59,7 +59,61 @@ func NewTLS(dialer N.Dialer, serverAddress string, options option.OutboundTLSOpt
 			return err
 		}
 	}
-
+	if len(options.ALPN) > 0 {
+		tlsConfig.NextProtos = options.ALPN
+	}
+	if options.MinVersion != "" {
+		minVersion, err := option.ParseTLSVersion(options.MinVersion)
+		if err != nil {
+			return nil, E.Cause(err, "parse min_version")
+		}
+		tlsConfig.MinVersion = minVersion
+	}
+	if options.MaxVersion != "" {
+		maxVersion, err := option.ParseTLSVersion(options.MaxVersion)
+		if err != nil {
+			return nil, E.Cause(err, "parse max_version")
+		}
+		tlsConfig.MaxVersion = maxVersion
+	}
+	if options.CipherSuites != nil {
+	find:
+		for _, cipherSuite := range options.CipherSuites {
+			for _, tlsCipherSuite := range tls.CipherSuites() {
+				if cipherSuite == tlsCipherSuite.Name {
+					tlsConfig.CipherSuites = append(tlsConfig.CipherSuites, tlsCipherSuite.ID)
+					continue find
+				}
+			}
+			return nil, E.New("unknown cipher_suite: ", cipherSuite)
+		}
+	}
+	var certificate []byte
+	if options.Certificate != "" {
+		certificate = []byte(options.Certificate)
+	} else if options.CertificatePath != "" {
+		content, err := os.ReadFile(options.CertificatePath)
+		if err != nil {
+			return nil, E.Cause(err, "read certificate")
+		}
+		certificate = content
+	}
+	if len(certificate) > 0 {
+		var certPool *x509.CertPool
+		if options.DisableSystemRoot {
+			certPool = x509.NewCertPool()
+		} else {
+			var err error
+			certPool, err = x509.SystemCertPool()
+			if err != nil {
+				return nil, E.Cause(err, "load system cert pool")
+			}
+		}
+		if !certPool.AppendCertsFromPEM([]byte(options.Certificate)) {
+			return nil, E.New("failed to parse certificate:\n\n", options.Certificate)
+		}
+		tlsConfig.RootCAs = certPool
+	}
 	return &TLSDialer{
 		dialer: dialer,
 		config: &tlsConfig,
@@ -75,7 +129,7 @@ func (d *TLSDialer) DialContext(ctx context.Context, network string, destination
 		return nil, err
 	}
 	tlsConn := tls.Client(conn, d.config)
-	ctx, cancel := context.WithTimeout(context.Background(), C.DefaultTCPTimeout)
+	ctx, cancel := context.WithTimeout(ctx, C.DefaultTCPTimeout)
 	defer cancel()
 	err = tlsConn.HandshakeContext(ctx)
 	return tlsConn, err

+ 1 - 1
inbound/builder.go

@@ -27,7 +27,7 @@ func New(ctx context.Context, router adapter.Router, logger log.ContextLogger, o
 	case C.TypeSocks:
 		return NewSocks(ctx, router, logger, options.Tag, options.SocksOptions), nil
 	case C.TypeHTTP:
-		return NewHTTP(ctx, router, logger, options.Tag, options.HTTPOptions), nil
+		return NewHTTP(ctx, router, logger, options.Tag, options.HTTPOptions)
 	case C.TypeMixed:
 		return NewMixed(ctx, router, logger, options.Tag, options.MixedOptions), nil
 	case C.TypeShadowsocks:

+ 17 - 4
inbound/http.go

@@ -3,12 +3,14 @@ package inbound
 import (
 	std_bufio "bufio"
 	"context"
+	"crypto/tls"
 	"net"
 
 	"github.com/sagernet/sing-box/adapter"
 	C "github.com/sagernet/sing-box/constant"
 	"github.com/sagernet/sing-box/log"
 	"github.com/sagernet/sing-box/option"
+	"github.com/sagernet/sing/common"
 	"github.com/sagernet/sing/common/auth"
 	M "github.com/sagernet/sing/common/metadata"
 	N "github.com/sagernet/sing/common/network"
@@ -20,11 +22,12 @@ var _ adapter.Inbound = (*HTTP)(nil)
 type HTTP struct {
 	myInboundAdapter
 	authenticator auth.Authenticator
+	tlsConfig     *tls.Config
 }
 
-func NewHTTP(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.HTTPMixedInboundOptions) *HTTP {
+func NewHTTP(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.HTTPMixedInboundOptions) (*HTTP, error) {
 	inbound := &HTTP{
-		myInboundAdapter{
+		myInboundAdapter: myInboundAdapter{
 			protocol:       C.TypeHTTP,
 			network:        []string{C.NetworkTCP},
 			ctx:            ctx,
@@ -34,13 +37,23 @@ func NewHTTP(ctx context.Context, router adapter.Router, logger log.ContextLogge
 			listenOptions:  options.ListenOptions,
 			setSystemProxy: options.SetSystemProxy,
 		},
-		auth.NewAuthenticator(options.Users),
+		authenticator: auth.NewAuthenticator(options.Users),
+	}
+	if options.TLS != nil {
+		tlsConfig, err := NewTLSConfig(common.PtrValueOrDefault(options.TLS))
+		if err != nil {
+			return nil, err
+		}
+		inbound.tlsConfig = tlsConfig
 	}
 	inbound.connHandler = inbound
-	return inbound
+	return inbound, nil
 }
 
 func (h *HTTP) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
+	if h.tlsConfig != nil {
+		conn = tls.Server(conn, h.tlsConfig)
+	}
 	return http.HandleConnection(ctx, conn, std_bufio.NewReader(conn), h.authenticator, h.upstreamUserHandler(metadata), M.Metadata{})
 }
 

+ 80 - 0
inbound/tls.go

@@ -0,0 +1,80 @@
+package inbound
+
+import (
+	"crypto/tls"
+	"os"
+
+	"github.com/sagernet/sing-box/option"
+	E "github.com/sagernet/sing/common/exceptions"
+)
+
+func NewTLSConfig(options option.InboundTLSOptions) (*tls.Config, error) {
+	if !options.Enabled {
+		return nil, nil
+	}
+	var tlsConfig tls.Config
+	if options.ServerName != "" {
+		tlsConfig.ServerName = options.ServerName
+	}
+	if len(options.ALPN) > 0 {
+		tlsConfig.NextProtos = options.ALPN
+	}
+	if options.MinVersion != "" {
+		minVersion, err := option.ParseTLSVersion(options.MinVersion)
+		if err != nil {
+			return nil, E.Cause(err, "parse min_version")
+		}
+		tlsConfig.MinVersion = minVersion
+	}
+	if options.MaxVersion != "" {
+		maxVersion, err := option.ParseTLSVersion(options.MaxVersion)
+		if err != nil {
+			return nil, E.Cause(err, "parse max_version")
+		}
+		tlsConfig.MaxVersion = maxVersion
+	}
+	if options.CipherSuites != nil {
+	find:
+		for _, cipherSuite := range options.CipherSuites {
+			for _, tlsCipherSuite := range tls.CipherSuites() {
+				if cipherSuite == tlsCipherSuite.Name {
+					tlsConfig.CipherSuites = append(tlsConfig.CipherSuites, tlsCipherSuite.ID)
+					continue find
+				}
+			}
+			return nil, E.New("unknown cipher_suite: ", cipherSuite)
+		}
+	}
+	var certificate []byte
+	if options.Certificate != "" {
+		certificate = []byte(options.Certificate)
+	} else if options.CertificatePath != "" {
+		content, err := os.ReadFile(options.CertificatePath)
+		if err != nil {
+			return nil, E.Cause(err, "read certificate")
+		}
+		certificate = content
+	}
+	var key []byte
+	if options.Key != "" {
+		key = []byte(options.Key)
+	} else if options.KeyPath != "" {
+		content, err := os.ReadFile(options.KeyPath)
+		if err != nil {
+			return nil, E.Cause(err, "read key")
+		}
+		key = content
+	}
+	if certificate == nil {
+		return nil, E.New("missing certificate")
+	}
+	if key == nil {
+		return nil, E.New("missing key")
+	}
+	keyPair, err := tls.X509KeyPair(certificate, key)
+	if err != nil {
+		return nil, E.Cause(err, "parse x509 key pair")
+	}
+	tlsConfig.Certificates = []tls.Certificate{keyPair}
+	return &tlsConfig, nil
+}

+ 13 - 2
inbound/vmess.go

@@ -2,6 +2,7 @@ package inbound
 
 import (
 	"context"
+	"crypto/tls"
 	"net"
 	"os"
 
@@ -20,8 +21,9 @@ var _ adapter.Inbound = (*VMess)(nil)
 
 type VMess struct {
 	myInboundAdapter
-	service *vmess.Service[int]
-	users   []option.VMessUser
+	service   *vmess.Service[int]
+	users     []option.VMessUser
+	tlsConfig *tls.Config
 }
 
 func NewVMess(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.VMessInboundOptions) (*VMess, error) {
@@ -46,12 +48,21 @@ func NewVMess(ctx context.Context, router adapter.Router, logger log.ContextLogg
 	if err != nil {
 		return nil, err
 	}
+	if options.TLS != nil {
+		inbound.tlsConfig, err = NewTLSConfig(common.PtrValueOrDefault(options.TLS))
+		if err != nil {
+			return nil, err
+		}
+	}
 	inbound.service = service
 	inbound.connHandler = inbound
 	return inbound, nil
 }
 
 func (h *VMess) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
+	if h.tlsConfig != nil {
+		conn = tls.Server(conn, h.tlsConfig)
+	}
 	return h.service.NewConnection(adapter.WithContext(log.ContextWithNewID(ctx), &metadata), conn, adapter.UpstreamMetadata(metadata))
 }
 

+ 9 - 5
option/inbound.go

@@ -126,14 +126,16 @@ func (o SocksInboundOptions) Equals(other SocksInboundOptions) bool {
 
 type HTTPMixedInboundOptions struct {
 	ListenOptions
-	Users          []auth.User `json:"users,omitempty"`
-	SetSystemProxy bool        `json:"set_system_proxy,omitempty"`
+	Users          []auth.User        `json:"users,omitempty"`
+	SetSystemProxy bool               `json:"set_system_proxy,omitempty"`
+	TLS            *InboundTLSOptions `json:"tls,omitempty"`
 }
 
 func (o HTTPMixedInboundOptions) Equals(other HTTPMixedInboundOptions) bool {
 	return o.ListenOptions == other.ListenOptions &&
 		common.ComparableSliceEquals(o.Users, other.Users) &&
-		o.SetSystemProxy == other.SetSystemProxy
+		o.SetSystemProxy == other.SetSystemProxy &&
+		common.PtrEquals(o.TLS, other.TLS)
 }
 
 type DirectInboundOptions struct {
@@ -174,12 +176,14 @@ type ShadowsocksDestination struct {
 
 type VMessInboundOptions struct {
 	ListenOptions
-	Users []VMessUser `json:"users,omitempty"`
+	Users []VMessUser        `json:"users,omitempty"`
+	TLS   *InboundTLSOptions `json:"tls,omitempty"`
 }
 
 func (o VMessInboundOptions) Equals(other VMessInboundOptions) bool {
 	return o.ListenOptions == other.ListenOptions &&
-		common.ComparableSliceEquals(o.Users, other.Users)
+		common.ComparableSliceEquals(o.Users, other.Users) &&
+		common.PtrEquals(o.TLS, other.TLS)
 }
 
 type VMessUser struct {

+ 12 - 14
option/outbound.go

@@ -140,13 +140,6 @@ type HTTPOutboundOptions struct {
 	TLSOptions *OutboundTLSOptions `json:"tls,omitempty"`
 }
 
-type OutboundTLSOptions struct {
-	Enabled    bool   `json:"enabled,omitempty"`
-	DisableSNI bool   `json:"disable_sni,omitempty"`
-	ServerName string `json:"server_name,omitempty"`
-	Insecure   bool   `json:"insecure,omitempty"`
-}
-
 type ShadowsocksOutboundOptions struct {
 	OutboundDialerOptions
 	ServerOptions
@@ -158,13 +151,18 @@ type ShadowsocksOutboundOptions struct {
 type VMessOutboundOptions struct {
 	OutboundDialerOptions
 	ServerOptions
-	UUID                string              `json:"uuid"`
-	Security            string              `json:"security"`
-	AlterId             int                 `json:"alter_id,omitempty"`
-	GlobalPadding       bool                `json:"global_padding,omitempty"`
-	AuthenticatedLength bool                `json:"authenticated_length,omitempty"`
-	Network             NetworkList         `json:"network,omitempty"`
-	TLSOptions          *OutboundTLSOptions `json:"tls,omitempty"`
+	UUID                string                 `json:"uuid"`
+	Security            string                 `json:"security"`
+	AlterId             int                    `json:"alter_id,omitempty"`
+	GlobalPadding       bool                   `json:"global_padding,omitempty"`
+	AuthenticatedLength bool                   `json:"authenticated_length,omitempty"`
+	Network             NetworkList            `json:"network,omitempty"`
+	TLSOptions          *OutboundTLSOptions    `json:"tls,omitempty"`
+	TransportOptions    *VMessTransportOptions `json:"transport,omitempty"`
+}
+
+type VMessTransportOptions struct {
+	Network string `json:"network,omitempty"`
 }
 
 type SelectorOutboundOptions struct {

+ 77 - 0
option/tls.go

@@ -0,0 +1,77 @@
+package option
+
+import (
+	"crypto/tls"
+
+	"github.com/sagernet/sing/common"
+	E "github.com/sagernet/sing/common/exceptions"
+)
+
+type InboundTLSOptions struct {
+	Enabled         bool     `json:"enabled,omitempty"`
+	ServerName      string   `json:"server_name,omitempty"`
+	ALPN            []string `json:"alpn,omitempty"`
+	MinVersion      string   `json:"min_version,omitempty"`
+	MaxVersion      string   `json:"max_version,omitempty"`
+	CipherSuites    []string `json:"cipher_suites,omitempty"`
+	Certificate     string   `json:"certificate,omitempty"`
+	CertificatePath string   `json:"certificate_path,omitempty"`
+	Key             string   `json:"key,omitempty"`
+	KeyPath         string   `json:"key_path,omitempty"`
+}
+
+func (o InboundTLSOptions) Equals(other InboundTLSOptions) bool {
+	return o.Enabled == other.Enabled &&
+		o.ServerName == other.ServerName &&
+		common.ComparableSliceEquals(o.ALPN, other.ALPN) &&
+		o.MinVersion == other.MinVersion &&
+		o.MaxVersion == other.MaxVersion &&
+		common.ComparableSliceEquals(o.CipherSuites, other.CipherSuites) &&
+		o.Certificate == other.Certificate &&
+		o.CertificatePath == other.CertificatePath &&
+		o.Key == other.Key &&
+		o.KeyPath == other.KeyPath
+}
+
+type OutboundTLSOptions struct {
+	Enabled           bool     `json:"enabled,omitempty"`
+	DisableSNI        bool     `json:"disable_sni,omitempty"`
+	ServerName        string   `json:"server_name,omitempty"`
+	Insecure          bool     `json:"insecure,omitempty"`
+	ALPN              []string `json:"alpn,omitempty"`
+	MinVersion        string   `json:"min_version,omitempty"`
+	MaxVersion        string   `json:"max_version,omitempty"`
+	CipherSuites      []string `json:"cipher_suites,omitempty"`
+	DisableSystemRoot bool     `json:"disable_system_root,omitempty"`
+	Certificate       string   `json:"certificate,omitempty"`
+	CertificatePath   string   `json:"certificate_path,omitempty"`
+}
+
+func (o OutboundTLSOptions) Equals(other OutboundTLSOptions) bool {
+	return o.Enabled == other.Enabled &&
+		o.DisableSNI == other.DisableSNI &&
+		o.ServerName == other.ServerName &&
+		o.Insecure == other.Insecure &&
+		common.ComparableSliceEquals(o.ALPN, other.ALPN) &&
+		o.MinVersion == other.MinVersion &&
+		o.MaxVersion == other.MaxVersion &&
+		common.ComparableSliceEquals(o.CipherSuites, other.CipherSuites) &&
+		o.DisableSystemRoot == other.DisableSystemRoot &&
+		o.Certificate == other.Certificate &&
+		o.CertificatePath == other.CertificatePath
+}
+
+func ParseTLSVersion(version string) (uint16, error) {
+	switch version {
+	case "1.0":
+		return tls.VersionTLS10, nil
+	case "1.1":
+		return tls.VersionTLS11, nil
+	case "1.2":
+		return tls.VersionTLS12, nil
+	case "1.3":
+		return tls.VersionTLS13, nil
+	default:
+		return 0, E.New("unknown tls version:", version)
+	}
+}