body.go 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191
  1. package common
  2. import (
  3. "bytes"
  4. "context"
  5. "errors"
  6. "fmt"
  7. "io"
  8. "net/http"
  9. "strings"
  10. "github.com/bytedance/sonic"
  11. "github.com/bytedance/sonic/ast"
  12. )
  // requestBodyKey is the context key under which SetRequestBody stores the
  // buffered request body. It is an unexported type, so no other package can
  // construct an equal key.
  type requestBodyKey struct{}

  // Upper bounds, in bytes, applied when buffering bodies in memory.
  const (
  	// MaxRequestBodySize caps request bodies read by GetRequestBody and
  	// GetRequestBodyReusable (50 MiB).
  	MaxRequestBodySize = 50 << 20
  	// MaxResponseBodySize caps response bodies read by GetResponseBody (50 MiB).
  	MaxResponseBodySize = 50 << 20
  )
  // LimitReader returns a reader that reads from r but delivers at most n bytes
  // in total. After n bytes have been read, further reads return
  // ErrLimitedReaderExceeded if r still has data, or r's own result (typically
  // io.EOF) if it does not. Unlike io.LimitReader, which reports a plain EOF at
  // the limit, this lets callers tell "input fits" apart from "input is too long".
  func LimitReader(r io.Reader, n int64) io.Reader { return &LimitedReader{R: r, N: n} }

  // LimitedReader reads from R, allowing at most N more bytes.
  // Each Read decrements N by the number of bytes returned.
  type LimitedReader struct {
  	R io.Reader // underlying reader
  	N int64     // remaining byte budget
  }

  // ErrLimitedReaderExceeded is returned by LimitedReader.Read when the
  // underlying reader still has data after the byte budget is used up.
  var ErrLimitedReaderExceeded = errors.New("limited reader exceeded")

  // Read implements io.Reader.
  //
  // While budget remains, at most N bytes are requested from R. Once the budget
  // is exhausted, R is probed for one more byte:
  //   - If R yields data, the input is longer than allowed and
  //     ErrLimitedReaderExceeded is returned (the probed byte is discarded).
  //   - Otherwise R's result is passed through, so an input of exactly the
  //     budgeted size ends with R's EOF rather than a spurious "exceeded" error.
  func (l *LimitedReader) Read(p []byte) (n int, err error) {
  	if l.N <= 0 {
  		var probe [1]byte
  		m, perr := l.R.Read(probe[:])
  		if m > 0 {
  			return 0, ErrLimitedReaderExceeded
  		}
  		return 0, perr
  	}
  	if int64(len(p)) > l.N {
  		p = p[:l.N]
  	}
  	n, err = l.R.Read(p)
  	l.N -= int64(n)
  	return n, err
  }
  35. func GetBodyLimit(body io.Reader, contentLength, n int64) ([]byte, error) {
  36. var (
  37. buf []byte
  38. err error
  39. )
  40. if contentLength <= 0 {
  41. buf, err = io.ReadAll(LimitReader(body, n))
  42. if err != nil {
  43. if errors.Is(err, ErrLimitedReaderExceeded) {
  44. return nil, fmt.Errorf("body too large, max: %d", n)
  45. }
  46. return nil, fmt.Errorf("body read failed: %w", err)
  47. }
  48. } else {
  49. if contentLength > n {
  50. return nil, fmt.Errorf("body too large: %d, max: %d", contentLength, n)
  51. }
  52. buf = make([]byte, contentLength)
  53. _, err = io.ReadFull(body, buf)
  54. }
  55. if err != nil {
  56. return nil, fmt.Errorf("body read failed: %w", err)
  57. }
  58. return buf, nil
  59. }
  60. func GetRequestBodyLimit(req *http.Request, n int64) ([]byte, error) {
  61. return GetBodyLimit(req.Body, req.ContentLength, n)
  62. }
  63. func GetRequestBody(req *http.Request) ([]byte, error) {
  64. return GetRequestBodyLimit(req, MaxRequestBodySize)
  65. }
  66. func SetRequestBody(req *http.Request, body []byte) {
  67. ctx := req.Context()
  68. bufCtx := context.WithValue(ctx, requestBodyKey{}, body)
  69. *req = *req.WithContext(bufCtx)
  70. }
  // IsJSONContentType reports whether the Content-Type header value ct names a
  // JSON media type, i.e. its type/subtype ends in "/json" (such as
  // "application/json").
  //
  // Only the part before the first ';' is examined, so parameters such as
  // charset are ignored and cannot cause false matches. Surrounding whitespace
  // is trimmed. The comparison is case-insensitive because media types are
  // case-insensitive (RFC 9110 §8.3.1). Structured-syntax types such as
  // "application/vnd.api+json" are not matched.
  func IsJSONContentType(ct string) bool {
  	mediaType, _, _ := strings.Cut(ct, ";")
  	mediaType = strings.ToLower(strings.TrimSpace(mediaType))
  	return strings.HasSuffix(mediaType, "/json")
  }
  75. func GetRequestBodyReusable(req *http.Request) ([]byte, error) {
  76. contentType := req.Header.Get("Content-Type")
  77. if strings.HasPrefix(contentType, "application/x-www-form-urlencoded") ||
  78. strings.HasPrefix(contentType, "multipart/form-data") {
  79. return nil, nil
  80. }
  81. requestBody := req.Context().Value(requestBodyKey{})
  82. if requestBody != nil {
  83. b, _ := requestBody.([]byte)
  84. return b, nil
  85. }
  86. var (
  87. buf []byte
  88. err error
  89. )
  90. defer func() {
  91. req.Body.Close()
  92. if err == nil {
  93. req.Body = io.NopCloser(bytes.NewBuffer(buf))
  94. }
  95. }()
  96. if req.ContentLength <= 0 ||
  97. IsJSONContentType(contentType) {
  98. buf, err = io.ReadAll(LimitReader(req.Body, MaxRequestBodySize))
  99. if err != nil {
  100. if errors.Is(err, ErrLimitedReaderExceeded) {
  101. return nil, fmt.Errorf("request body too large, max: %d", MaxRequestBodySize)
  102. }
  103. return nil, fmt.Errorf("request body read failed: %w", err)
  104. }
  105. } else {
  106. if req.ContentLength > MaxRequestBodySize {
  107. return nil, fmt.Errorf("request body too large: %d, max: %d", req.ContentLength, MaxRequestBodySize)
  108. }
  109. buf = make([]byte, req.ContentLength)
  110. _, err = io.ReadFull(req.Body, buf)
  111. }
  112. if err != nil {
  113. return nil, fmt.Errorf("request body read failed: %w", err)
  114. }
  115. SetRequestBody(req, buf)
  116. return buf, nil
  117. }
  118. func UnmarshalRequestReusable(req *http.Request, v any) error {
  119. requestBody, err := GetRequestBodyReusable(req)
  120. if err != nil {
  121. return err
  122. }
  123. return sonic.Unmarshal(requestBody, &v)
  124. }
  125. func UnmarshalRequest2NodeReusable(req *http.Request, path ...any) (ast.Node, error) {
  126. requestBody, err := GetRequestBodyReusable(req)
  127. if err != nil {
  128. return ast.Node{}, err
  129. }
  130. return sonic.Get(requestBody, path...)
  131. }
  132. func GetResponseBodyLimit(resp *http.Response, n int64) ([]byte, error) {
  133. return GetBodyLimit(resp.Body, resp.ContentLength, n)
  134. }
  135. func GetResponseBody(resp *http.Response) ([]byte, error) {
  136. return GetResponseBodyLimit(resp, MaxResponseBodySize)
  137. }
  138. func UnmarshalResponse(resp *http.Response, v any) error {
  139. responseBody, err := GetResponseBody(resp)
  140. if err != nil {
  141. return err
  142. }
  143. return sonic.Unmarshal(responseBody, &v)
  144. }
  145. func UnmarshalResponse2Node(resp *http.Response, path ...any) (ast.Node, error) {
  146. responseBody, err := GetResponseBody(resp)
  147. if err != nil {
  148. return ast.Node{}, err
  149. }
  150. return sonic.Get(responseBody, path...)
  151. }