address.go 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258
  1. package protocol
  2. import (
  3. "io"
  4. "github.com/xtls/xray-core/common"
  5. "github.com/xtls/xray-core/common/buf"
  6. "github.com/xtls/xray-core/common/net"
  7. "github.com/xtls/xray-core/common/serial"
  8. )
  9. type AddressOption func(*option)
  10. func PortThenAddress() AddressOption {
  11. return func(p *option) {
  12. p.portFirst = true
  13. }
  14. }
  15. func AddressFamilyByte(b byte, f net.AddressFamily) AddressOption {
  16. if b >= 16 {
  17. panic("address family byte too big")
  18. }
  19. return func(p *option) {
  20. p.addrTypeMap[b] = f
  21. p.addrByteMap[f] = b
  22. }
  23. }
  24. type AddressTypeParser func(byte) byte
  25. func WithAddressTypeParser(atp AddressTypeParser) AddressOption {
  26. return func(p *option) {
  27. p.typeParser = atp
  28. }
  29. }
  30. type AddressSerializer interface {
  31. ReadAddressPort(buffer *buf.Buffer, input io.Reader) (net.Address, net.Port, error)
  32. WriteAddressPort(writer io.Writer, addr net.Address, port net.Port) error
  33. }
  34. const afInvalid = 255
  35. type option struct {
  36. addrTypeMap [16]net.AddressFamily
  37. addrByteMap [16]byte
  38. portFirst bool
  39. typeParser AddressTypeParser
  40. }
  41. // NewAddressParser creates a new AddressParser
  42. func NewAddressParser(options ...AddressOption) AddressSerializer {
  43. var o option
  44. for i := range o.addrByteMap {
  45. o.addrByteMap[i] = afInvalid
  46. }
  47. for i := range o.addrTypeMap {
  48. o.addrTypeMap[i] = net.AddressFamily(afInvalid)
  49. }
  50. for _, opt := range options {
  51. opt(&o)
  52. }
  53. ap := &addressParser{
  54. addrByteMap: o.addrByteMap,
  55. addrTypeMap: o.addrTypeMap,
  56. }
  57. if o.typeParser != nil {
  58. ap.typeParser = o.typeParser
  59. }
  60. if o.portFirst {
  61. return portFirstAddressParser{ap: ap}
  62. }
  63. return portLastAddressParser{ap: ap}
  64. }
  65. type portFirstAddressParser struct {
  66. ap *addressParser
  67. }
  68. func (p portFirstAddressParser) ReadAddressPort(buffer *buf.Buffer, input io.Reader) (net.Address, net.Port, error) {
  69. if buffer == nil {
  70. buffer = buf.New()
  71. defer buffer.Release()
  72. }
  73. port, err := readPort(buffer, input)
  74. if err != nil {
  75. return nil, 0, err
  76. }
  77. addr, err := p.ap.readAddress(buffer, input)
  78. if err != nil {
  79. return nil, 0, err
  80. }
  81. return addr, port, nil
  82. }
  83. func (p portFirstAddressParser) WriteAddressPort(writer io.Writer, addr net.Address, port net.Port) error {
  84. if err := writePort(writer, port); err != nil {
  85. return err
  86. }
  87. return p.ap.writeAddress(writer, addr)
  88. }
  89. type portLastAddressParser struct {
  90. ap *addressParser
  91. }
  92. func (p portLastAddressParser) ReadAddressPort(buffer *buf.Buffer, input io.Reader) (net.Address, net.Port, error) {
  93. if buffer == nil {
  94. buffer = buf.New()
  95. defer buffer.Release()
  96. }
  97. addr, err := p.ap.readAddress(buffer, input)
  98. if err != nil {
  99. return nil, 0, err
  100. }
  101. port, err := readPort(buffer, input)
  102. if err != nil {
  103. return nil, 0, err
  104. }
  105. return addr, port, nil
  106. }
  107. func (p portLastAddressParser) WriteAddressPort(writer io.Writer, addr net.Address, port net.Port) error {
  108. if err := p.ap.writeAddress(writer, addr); err != nil {
  109. return err
  110. }
  111. return writePort(writer, port)
  112. }
  113. func readPort(b *buf.Buffer, reader io.Reader) (net.Port, error) {
  114. if _, err := b.ReadFullFrom(reader, 2); err != nil {
  115. return 0, err
  116. }
  117. return net.PortFromBytes(b.BytesFrom(-2)), nil
  118. }
  119. func writePort(writer io.Writer, port net.Port) error {
  120. return common.Error2(serial.WriteUint16(writer, port.Value()))
  121. }
  // maybeIPPrefix reports whether b could start an IP literal: an opening
// bracket (as in "[::1]") or an ASCII digit.
func maybeIPPrefix(b byte) bool {
	if b == '[' {
		return true
	}
	return '0' <= b && b <= '9'
}
  // isValidDomain reports whether d consists only of ASCII letters, digits,
// '-', '.' and '_'. The empty string is accepted.
//
// Scanning bytes is equivalent to scanning runes here: any non-ASCII byte
// is rejected either way.
func isValidDomain(d string) bool {
	for i := 0; i < len(d); i++ {
		switch c := d[i]; {
		case 'a' <= c && c <= 'z', 'A' <= c && c <= 'Z', '0' <= c && c <= '9':
		case c == '-', c == '.', c == '_':
		default:
			return false
		}
	}
	return true
}
  133. type addressParser struct {
  134. addrTypeMap [16]net.AddressFamily
  135. addrByteMap [16]byte
  136. typeParser AddressTypeParser
  137. }
  138. func (p *addressParser) readAddress(b *buf.Buffer, reader io.Reader) (net.Address, error) {
  139. if _, err := b.ReadFullFrom(reader, 1); err != nil {
  140. return nil, err
  141. }
  142. addrType := b.Byte(b.Len() - 1)
  143. if p.typeParser != nil {
  144. addrType = p.typeParser(addrType)
  145. }
  146. if addrType >= 16 {
  147. return nil, newError("unknown address type: ", addrType)
  148. }
  149. addrFamily := p.addrTypeMap[addrType]
  150. if addrFamily == net.AddressFamily(afInvalid) {
  151. return nil, newError("unknown address type: ", addrType)
  152. }
  153. switch addrFamily {
  154. case net.AddressFamilyIPv4:
  155. if _, err := b.ReadFullFrom(reader, 4); err != nil {
  156. return nil, err
  157. }
  158. return net.IPAddress(b.BytesFrom(-4)), nil
  159. case net.AddressFamilyIPv6:
  160. if _, err := b.ReadFullFrom(reader, 16); err != nil {
  161. return nil, err
  162. }
  163. return net.IPAddress(b.BytesFrom(-16)), nil
  164. case net.AddressFamilyDomain:
  165. if _, err := b.ReadFullFrom(reader, 1); err != nil {
  166. return nil, err
  167. }
  168. domainLength := int32(b.Byte(b.Len() - 1))
  169. if _, err := b.ReadFullFrom(reader, domainLength); err != nil {
  170. return nil, err
  171. }
  172. domain := string(b.BytesFrom(-domainLength))
  173. if maybeIPPrefix(domain[0]) {
  174. addr := net.ParseAddress(domain)
  175. if addr.Family().IsIP() {
  176. return addr, nil
  177. }
  178. }
  179. if !isValidDomain(domain) {
  180. return nil, newError("invalid domain name: ", domain)
  181. }
  182. return net.DomainAddress(domain), nil
  183. default:
  184. panic("impossible case")
  185. }
  186. }
  187. func (p *addressParser) writeAddress(writer io.Writer, address net.Address) error {
  188. tb := p.addrByteMap[address.Family()]
  189. if tb == afInvalid {
  190. return newError("unknown address family", address.Family())
  191. }
  192. switch address.Family() {
  193. case net.AddressFamilyIPv4, net.AddressFamilyIPv6:
  194. if _, err := writer.Write([]byte{tb}); err != nil {
  195. return err
  196. }
  197. if _, err := writer.Write(address.IP()); err != nil {
  198. return err
  199. }
  200. case net.AddressFamilyDomain:
  201. domain := address.Domain()
  202. if isDomainTooLong(domain) {
  203. return newError("Super long domain is not supported: ", domain)
  204. }
  205. if _, err := writer.Write([]byte{tb, byte(len(domain))}); err != nil {
  206. return err
  207. }
  208. if _, err := writer.Write([]byte(domain)); err != nil {
  209. return err
  210. }
  211. default:
  212. panic("Unknown family type.")
  213. }
  214. return nil
  215. }