http.go 7.1 KB


  1. package http
  2. import (
  3. "bufio"
  4. "bytes"
  5. "context"
  6. "io"
  7. "net"
  8. "net/http"
  9. "strings"
  10. "time"
  11. "github.com/xtls/xray-core/common"
  12. "github.com/xtls/xray-core/common/buf"
  13. "github.com/xtls/xray-core/common/errors"
  14. )
  // Wire-format markers used when scanning for and emitting HTTP headers.
  const (
  	// CRLF terminates every line of an HTTP header.
  	CRLF = "\r\n"

  	// ENDING is the pair of consecutive line endings that separates the HTTP
  	// header from the body.
  	ENDING = CRLF + CRLF
  )

  // maxHeaderLength is the maximum accepted length of an HTTP header. It is a
  // safety precaution against DDoS attacks.
  const maxHeaderLength = 8192
  // ErrHeaderToLong is returned by HeaderReader.Read when no end-of-header
  // marker was found within the allowed header length.
  var ErrHeaderToLong = errors.New("Header too long.")

  // ErrHeaderMisMatch is returned by HeaderReader.Read when the request path
  // matches none of the URIs of the expected request.
  var ErrHeaderMisMatch = errors.New("Header Mismatch.")
  27. type Reader interface {
  28. Read(io.Reader) (*buf.Buffer, error)
  29. }
  // Writer emits a header onto a stream.
  type Writer interface {
  	// Write sends the header to w and reports any failure.
  	Write(w io.Writer) error
  }
  33. type NoOpReader struct{}
  34. func (NoOpReader) Read(io.Reader) (*buf.Buffer, error) {
  35. return nil, nil
  36. }
  // NoOpWriter is a Writer that does nothing.
  type NoOpWriter struct{}

  // Write ignores the stream and always returns nil.
  func (NoOpWriter) Write(_ io.Writer) error {
  	return nil
  }
  41. type HeaderReader struct {
  42. req *http.Request
  43. expectedHeader *RequestConfig
  44. }
  45. func (h *HeaderReader) ExpectThisRequest(expectedHeader *RequestConfig) *HeaderReader {
  46. h.expectedHeader = expectedHeader
  47. return h
  48. }
  49. func (h *HeaderReader) Read(reader io.Reader) (*buf.Buffer, error) {
  50. buffer := buf.New()
  51. totalBytes := int32(0)
  52. endingDetected := false
  53. var headerBuf bytes.Buffer
  54. for totalBytes < maxHeaderLength {
  55. _, err := buffer.ReadFrom(reader)
  56. if err != nil {
  57. buffer.Release()
  58. return nil, err
  59. }
  60. if n := bytes.Index(buffer.Bytes(), []byte(ENDING)); n != -1 {
  61. headerBuf.Write(buffer.BytesRange(0, int32(n+len(ENDING))))
  62. buffer.Advance(int32(n + len(ENDING)))
  63. endingDetected = true
  64. break
  65. }
  66. lenEnding := int32(len(ENDING))
  67. if buffer.Len() >= lenEnding {
  68. totalBytes += buffer.Len() - lenEnding
  69. headerBuf.Write(buffer.BytesRange(0, buffer.Len()-lenEnding))
  70. leftover := buffer.BytesFrom(-lenEnding)
  71. buffer.Clear()
  72. copy(buffer.Extend(lenEnding), leftover)
  73. if _, err := readRequest(bufio.NewReader(bytes.NewReader(headerBuf.Bytes()))); err != io.ErrUnexpectedEOF {
  74. return nil, err
  75. }
  76. }
  77. }
  78. if !endingDetected {
  79. buffer.Release()
  80. return nil, ErrHeaderToLong
  81. }
  82. if h.expectedHeader == nil {
  83. if buffer.IsEmpty() {
  84. buffer.Release()
  85. return nil, nil
  86. }
  87. return buffer, nil
  88. }
  89. // Parse the request
  90. if req, err := readRequest(bufio.NewReader(bytes.NewReader(headerBuf.Bytes()))); err != nil {
  91. return nil, err
  92. } else {
  93. h.req = req
  94. }
  95. // Check req
  96. path := h.req.URL.Path
  97. hasThisURI := false
  98. for _, u := range h.expectedHeader.Uri {
  99. if u == path {
  100. hasThisURI = true
  101. }
  102. }
  103. if !hasThisURI {
  104. return nil, ErrHeaderMisMatch
  105. }
  106. if buffer.IsEmpty() {
  107. buffer.Release()
  108. return nil, nil
  109. }
  110. return buffer, nil
  111. }
  112. type HeaderWriter struct {
  113. header *buf.Buffer
  114. }
  115. func NewHeaderWriter(header *buf.Buffer) *HeaderWriter {
  116. return &HeaderWriter{
  117. header: header,
  118. }
  119. }
  120. func (w *HeaderWriter) Write(writer io.Writer) error {
  121. if w.header == nil {
  122. return nil
  123. }
  124. err := buf.WriteAllBytes(writer, w.header.Bytes(), nil)
  125. w.header.Release()
  126. w.header = nil
  127. return err
  128. }
  129. type Conn struct {
  130. net.Conn
  131. readBuffer *buf.Buffer
  132. oneTimeReader Reader
  133. oneTimeWriter Writer
  134. errorWriter Writer
  135. errorMismatchWriter Writer
  136. errorTooLongWriter Writer
  137. errReason error
  138. }
  139. func NewConn(conn net.Conn, reader Reader, writer Writer, errorWriter Writer, errorMismatchWriter Writer, errorTooLongWriter Writer) *Conn {
  140. return &Conn{
  141. Conn: conn,
  142. oneTimeReader: reader,
  143. oneTimeWriter: writer,
  144. errorWriter: errorWriter,
  145. errorMismatchWriter: errorMismatchWriter,
  146. errorTooLongWriter: errorTooLongWriter,
  147. }
  148. }
  149. func (c *Conn) Read(b []byte) (int, error) {
  150. if c.oneTimeReader != nil {
  151. buffer, err := c.oneTimeReader.Read(c.Conn)
  152. if err != nil {
  153. c.errReason = err
  154. return 0, err
  155. }
  156. c.readBuffer = buffer
  157. c.oneTimeReader = nil
  158. }
  159. if !c.readBuffer.IsEmpty() {
  160. nBytes, _ := c.readBuffer.Read(b)
  161. if c.readBuffer.IsEmpty() {
  162. c.readBuffer.Release()
  163. c.readBuffer = nil
  164. }
  165. return nBytes, nil
  166. }
  167. return c.Conn.Read(b)
  168. }
  169. // Write implements io.Writer.
  170. func (c *Conn) Write(b []byte) (int, error) {
  171. if c.oneTimeWriter != nil {
  172. err := c.oneTimeWriter.Write(c.Conn)
  173. c.oneTimeWriter = nil
  174. if err != nil {
  175. return 0, err
  176. }
  177. }
  178. return c.Conn.Write(b)
  179. }
  180. // Close implements net.Conn.Close().
  181. func (c *Conn) Close() error {
  182. if c.oneTimeWriter != nil && c.errorWriter != nil {
  183. // Connection is being closed but header wasn't sent. This means the client request
  184. // is probably not valid. Sending back a server error header in this case.
  185. // Write response based on error reason
  186. switch c.errReason {
  187. case ErrHeaderMisMatch:
  188. c.errorMismatchWriter.Write(c.Conn)
  189. case ErrHeaderToLong:
  190. c.errorTooLongWriter.Write(c.Conn)
  191. default:
  192. c.errorWriter.Write(c.Conn)
  193. }
  194. }
  195. return c.Conn.Close()
  196. }
  197. func formResponseHeader(config *ResponseConfig) *HeaderWriter {
  198. header := buf.New()
  199. common.Must2(header.WriteString(strings.Join([]string{config.GetFullVersion(), config.GetStatusValue().Code, config.GetStatusValue().Reason}, " ")))
  200. common.Must2(header.WriteString(CRLF))
  201. headers := config.PickHeaders()
  202. for _, h := range headers {
  203. common.Must2(header.WriteString(h))
  204. common.Must2(header.WriteString(CRLF))
  205. }
  206. if !config.HasHeader("Date") {
  207. common.Must2(header.WriteString("Date: "))
  208. common.Must2(header.WriteString(time.Now().Format(http.TimeFormat)))
  209. common.Must2(header.WriteString(CRLF))
  210. }
  211. common.Must2(header.WriteString(CRLF))
  212. return &HeaderWriter{
  213. header: header,
  214. }
  215. }
  216. type Authenticator struct {
  217. config *Config
  218. }
  219. func (a Authenticator) GetClientWriter() *HeaderWriter {
  220. header := buf.New()
  221. config := a.config.Request
  222. common.Must2(header.WriteString(strings.Join([]string{config.GetMethodValue(), config.PickURI(), config.GetFullVersion()}, " ")))
  223. common.Must2(header.WriteString(CRLF))
  224. headers := config.PickHeaders()
  225. for _, h := range headers {
  226. common.Must2(header.WriteString(h))
  227. common.Must2(header.WriteString(CRLF))
  228. }
  229. common.Must2(header.WriteString(CRLF))
  230. return &HeaderWriter{
  231. header: header,
  232. }
  233. }
  234. func (a Authenticator) GetServerWriter() *HeaderWriter {
  235. return formResponseHeader(a.config.Response)
  236. }
  237. func (a Authenticator) Client(conn net.Conn) net.Conn {
  238. if a.config.Request == nil && a.config.Response == nil {
  239. return conn
  240. }
  241. var reader Reader = NoOpReader{}
  242. if a.config.Request != nil {
  243. reader = new(HeaderReader)
  244. }
  245. var writer Writer = NoOpWriter{}
  246. if a.config.Response != nil {
  247. writer = a.GetClientWriter()
  248. }
  249. return NewConn(conn, reader, writer, NoOpWriter{}, NoOpWriter{}, NoOpWriter{})
  250. }
  251. func (a Authenticator) Server(conn net.Conn) net.Conn {
  252. if a.config.Request == nil && a.config.Response == nil {
  253. return conn
  254. }
  255. return NewConn(conn, new(HeaderReader).ExpectThisRequest(a.config.Request), a.GetServerWriter(),
  256. formResponseHeader(resp400),
  257. formResponseHeader(resp404),
  258. formResponseHeader(resp400))
  259. }
  260. func NewAuthenticator(ctx context.Context, config *Config) (Authenticator, error) {
  261. return Authenticator{
  262. config: config,
  263. }, nil
  264. }
  265. func init() {
  266. common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
  267. return NewAuthenticator(ctx, config.(*Config))
  268. }))
  269. }