gin.go 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
  1. package common
  2. import (
  3. "bytes"
  4. "context"
  5. "errors"
  6. "fmt"
  7. "io"
  8. "net/http"
  9. "strings"
  10. "sync"
  11. "github.com/bytedance/sonic"
  12. "github.com/bytedance/sonic/ast"
  13. "github.com/gin-gonic/gin"
  14. "github.com/sirupsen/logrus"
  15. )
  // RequestBodyKey is the context key under which SetRequestBody stores the
  // buffered request body; GetRequestBody looks it up to avoid re-reading.
  type RequestBodyKey struct{}

  // MaxRequestBodySize is the largest request body, in bytes, that
  // GetRequestBody will buffer in memory (50 MiB).
  const MaxRequestBodySize = 50 << 20
  // LimitReader returns a reader that yields at most n bytes from r. Unlike
  // io.LimitReader, input longer than n is reported as ErrLimitedReaderExceeded
  // instead of being silently truncated with io.EOF.
  func LimitReader(r io.Reader, n int64) io.Reader { return &LimitedReader{R: r, N: n} }

  // LimitedReader reads from R and fails with ErrLimitedReaderExceeded if R
  // holds more than N bytes. Input of exactly N bytes ends with R's own EOF.
  type LimitedReader struct {
  	R io.Reader // underlying reader
  	N int64     // bytes still allowed before the limit is hit
  }

  // ErrLimitedReaderExceeded is returned by LimitedReader.Read when the
  // underlying reader has more data than the configured limit.
  var ErrLimitedReaderExceeded = errors.New("limited reader exceeded")

  // Read implements io.Reader.
  //
  // It returns at most N more bytes from R. Once the budget is spent, it probes
  // R with a one-byte read instead of failing outright:
  //   - an error from R (normally io.EOF) is passed through, so an input of
  //     exactly N bytes completes cleanly;
  //   - any extra byte means the input is too long, and Read returns
  //     ErrLimitedReaderExceeded;
  //   - a probe that yields neither data nor error is reported as (0, nil),
  //     so the caller retries.
  func (l *LimitedReader) Read(p []byte) (int, error) {
  	if l.N <= 0 {
  		var probe [1]byte
  		n, err := l.R.Read(probe[:])
  		if n > 0 {
  			return 0, ErrLimitedReaderExceeded
  		}
  		return 0, err
  	}
  	if int64(len(p)) > l.N {
  		p = p[:l.N]
  	}
  	n, err := l.R.Read(p)
  	l.N -= int64(n)
  	return n, err
  }
  37. func SetRequestBody(req *http.Request, body []byte) {
  38. ctx := req.Context()
  39. bufCtx := context.WithValue(ctx, RequestBodyKey{}, body)
  40. *req = *req.WithContext(bufCtx)
  41. }
  42. func GetRequestBody(req *http.Request) ([]byte, error) {
  43. contentType := req.Header.Get("Content-Type")
  44. if contentType == "application/x-www-form-urlencoded" ||
  45. strings.HasPrefix(contentType, "multipart/form-data") {
  46. return nil, nil
  47. }
  48. requestBody := req.Context().Value(RequestBodyKey{})
  49. if requestBody != nil {
  50. b, _ := requestBody.([]byte)
  51. return b, nil
  52. }
  53. var buf []byte
  54. var err error
  55. defer func() {
  56. req.Body.Close()
  57. if err == nil {
  58. req.Body = io.NopCloser(bytes.NewBuffer(buf))
  59. }
  60. }()
  61. if req.ContentLength <= 0 ||
  62. strings.HasPrefix(contentType, "application/json") {
  63. buf, err = io.ReadAll(LimitReader(req.Body, MaxRequestBodySize))
  64. if err != nil {
  65. if errors.Is(err, ErrLimitedReaderExceeded) {
  66. return nil, fmt.Errorf("request body too large, max: %d", MaxRequestBodySize)
  67. }
  68. return nil, fmt.Errorf("request body read failed: %w", err)
  69. }
  70. } else {
  71. if req.ContentLength > MaxRequestBodySize {
  72. return nil, fmt.Errorf("request body too large: %d, max: %d", req.ContentLength, MaxRequestBodySize)
  73. }
  74. buf = make([]byte, req.ContentLength)
  75. _, err = io.ReadFull(req.Body, buf)
  76. }
  77. if err != nil {
  78. return nil, fmt.Errorf("request body read failed: %w", err)
  79. }
  80. SetRequestBody(req, buf)
  81. return buf, nil
  82. }
  83. func UnmarshalBodyReusable(req *http.Request, v any) error {
  84. requestBody, err := GetRequestBody(req)
  85. if err != nil {
  86. return err
  87. }
  88. return sonic.Unmarshal(requestBody, &v)
  89. }
  90. func UnmarshalBody2Node(req *http.Request) (ast.Node, error) {
  91. requestBody, err := GetRequestBody(req)
  92. if err != nil {
  93. return ast.Node{}, err
  94. }
  95. return sonic.Get(requestBody)
  96. }
  97. var fieldsPool = sync.Pool{
  98. New: func() any {
  99. return make(logrus.Fields, 6)
  100. },
  101. }
  102. func GetLogFields() logrus.Fields {
  103. fields, ok := fieldsPool.Get().(logrus.Fields)
  104. if !ok {
  105. panic(fmt.Sprintf("fields pool type error: %T, %v", fields, fields))
  106. }
  107. return fields
  108. }
  109. func PutLogFields(fields logrus.Fields) {
  110. clear(fields)
  111. fieldsPool.Put(fields)
  112. }
  113. func GetLogger(c *gin.Context) *logrus.Entry {
  114. if log, ok := c.Get("log"); ok {
  115. v, ok := log.(*logrus.Entry)
  116. if !ok {
  117. panic(fmt.Sprintf("log type error: %T, %v", v, v))
  118. }
  119. return v
  120. }
  121. entry := NewLogger()
  122. c.Set("log", entry)
  123. return entry
  124. }
  125. func NewLogger() *logrus.Entry {
  126. return &logrus.Entry{
  127. Logger: logrus.StandardLogger(),
  128. Data: GetLogFields(),
  129. }
  130. }