tls_credentials.go 2.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. package v2raygrpc
  2. import (
  3. "context"
  4. "net"
  5. "os"
  6. "github.com/sagernet/sing-box/common/tls"
  7. internal_credentials "github.com/sagernet/sing-box/transport/v2raygrpc/credentials"
  8. "google.golang.org/grpc/credentials"
  9. )
  10. type TLSTransportCredentials struct {
  11. config tls.Config
  12. }
  13. func NewTLSTransportCredentials(config tls.Config) credentials.TransportCredentials {
  14. return &TLSTransportCredentials{config}
  15. }
  16. func (c *TLSTransportCredentials) Info() credentials.ProtocolInfo {
  17. return credentials.ProtocolInfo{
  18. SecurityProtocol: "tls",
  19. SecurityVersion: "1.2",
  20. ServerName: c.config.ServerName(),
  21. }
  22. }
  23. func (c *TLSTransportCredentials) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
  24. cfg := c.config.Clone()
  25. if cfg.ServerName() == "" {
  26. serverName, _, err := net.SplitHostPort(authority)
  27. if err != nil {
  28. serverName = authority
  29. }
  30. cfg.SetServerName(serverName)
  31. }
  32. conn, err := tls.ClientHandshake(ctx, rawConn, cfg)
  33. if err != nil {
  34. return nil, nil, err
  35. }
  36. tlsInfo := credentials.TLSInfo{
  37. State: conn.ConnectionState(),
  38. CommonAuthInfo: credentials.CommonAuthInfo{
  39. SecurityLevel: credentials.PrivacyAndIntegrity,
  40. },
  41. }
  42. id := internal_credentials.SPIFFEIDFromState(conn.ConnectionState())
  43. if id != nil {
  44. tlsInfo.SPIFFEID = id
  45. }
  46. return internal_credentials.WrapSyscallConn(rawConn, conn), tlsInfo, nil
  47. }
  48. func (c *TLSTransportCredentials) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
  49. serverConfig, isServer := c.config.(tls.ServerConfig)
  50. if !isServer {
  51. return nil, nil, os.ErrInvalid
  52. }
  53. conn, err := tls.ServerHandshake(context.Background(), rawConn, serverConfig)
  54. if err != nil {
  55. rawConn.Close()
  56. return nil, nil, err
  57. }
  58. tlsInfo := credentials.TLSInfo{
  59. State: conn.ConnectionState(),
  60. CommonAuthInfo: credentials.CommonAuthInfo{
  61. SecurityLevel: credentials.PrivacyAndIntegrity,
  62. },
  63. }
  64. id := internal_credentials.SPIFFEIDFromState(conn.ConnectionState())
  65. if id != nil {
  66. tlsInfo.SPIFFEID = id
  67. }
  68. return internal_credentials.WrapSyscallConn(rawConn, conn), tlsInfo, nil
  69. }
  70. func (c *TLSTransportCredentials) Clone() credentials.TransportCredentials {
  71. return NewTLSTransportCredentials(c.config)
  72. }
  73. func (c *TLSTransportCredentials) OverrideServerName(serverNameOverride string) error {
  74. c.config.SetServerName(serverNameOverride)
  75. return nil
  76. }