gin.go 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234
  1. package common
  2. import (
  3. "bytes"
  4. "errors"
  5. "io"
  6. "mime"
  7. "mime/multipart"
  8. "net/http"
  9. "net/url"
  10. "strings"
  11. "time"
  12. "github.com/QuantumNous/new-api/constant"
  13. "github.com/gin-gonic/gin"
  14. )
  15. const KeyRequestBody = "key_request_body"
  16. func GetRequestBody(c *gin.Context) ([]byte, error) {
  17. requestBody, _ := c.Get(KeyRequestBody)
  18. if requestBody != nil {
  19. return requestBody.([]byte), nil
  20. }
  21. requestBody, err := io.ReadAll(c.Request.Body)
  22. if err != nil {
  23. return nil, err
  24. }
  25. _ = c.Request.Body.Close()
  26. c.Set(KeyRequestBody, requestBody)
  27. return requestBody.([]byte), nil
  28. }
  // UnmarshalBodyReusable decodes the request body into v according to the
  // request's Content-Type, without permanently consuming the body:
  //   - application/json                  -> Unmarshal (JSON)
  //   - application/x-www-form-urlencoded -> parseFormData
  //   - multipart/form-data               -> parseMultipartFormData
  // Any other content type leaves v untouched and returns nil.
  //
  // GetRequestBody caches the raw body. c.Request.Body is then reset to a fresh
  // reader over those bytes before decoding, so downstream handlers can read
  // the body again even when decoding fails.
  func UnmarshalBodyReusable(c *gin.Context, v any) error {
  	requestBody, err := GetRequestBody(c)
  	if err != nil {
  		return err
  	}
  	// Restore the body up front. GetRequestBody drained the original reader,
  	// and the decoders below work on the cached bytes, not on c.Request.Body.
  	c.Request.Body = io.NopCloser(bytes.NewReader(requestBody))

  	// Media types are case-insensitive (RFC 9110 §8.3.1). Lowercase only this
  	// local copy: parseMultipartFormData re-reads the raw header, so the
  	// case-sensitive boundary value is unaffected.
  	contentType := strings.ToLower(c.Request.Header.Get("Content-Type"))
  	switch {
  	case strings.HasPrefix(contentType, "application/json"):
  		return Unmarshal(requestBody, v)
  	case strings.Contains(contentType, gin.MIMEPOSTForm):
  		return parseFormData(requestBody, v)
  	case strings.Contains(contentType, gin.MIMEMultipartPOSTForm):
  		return parseMultipartFormData(c, requestBody, v)
  	default:
  		// Other content types are not decoded yet.
  		// TODO: support them if a non-JSON request model ever needs it.
  		return nil
  	}
  }
  55. func SetContextKey(c *gin.Context, key constant.ContextKey, value any) {
  56. c.Set(string(key), value)
  57. }
  58. func GetContextKey(c *gin.Context, key constant.ContextKey) (any, bool) {
  59. return c.Get(string(key))
  60. }
  61. func GetContextKeyString(c *gin.Context, key constant.ContextKey) string {
  62. return c.GetString(string(key))
  63. }
  64. func GetContextKeyInt(c *gin.Context, key constant.ContextKey) int {
  65. return c.GetInt(string(key))
  66. }
  67. func GetContextKeyBool(c *gin.Context, key constant.ContextKey) bool {
  68. return c.GetBool(string(key))
  69. }
  70. func GetContextKeyStringSlice(c *gin.Context, key constant.ContextKey) []string {
  71. return c.GetStringSlice(string(key))
  72. }
  73. func GetContextKeyStringMap(c *gin.Context, key constant.ContextKey) map[string]any {
  74. return c.GetStringMap(string(key))
  75. }
  76. func GetContextKeyTime(c *gin.Context, key constant.ContextKey) time.Time {
  77. return c.GetTime(string(key))
  78. }
  79. func GetContextKeyType[T any](c *gin.Context, key constant.ContextKey) (T, bool) {
  80. if value, ok := c.Get(string(key)); ok {
  81. if v, ok := value.(T); ok {
  82. return v, true
  83. }
  84. }
  85. var t T
  86. return t, false
  87. }
  // ApiError writes a failed API response carrying err's message.
  // The HTTP status is always 200; failure is signalled by "success": false in
  // the JSON body. A nil err produces an empty message instead of panicking.
  func ApiError(c *gin.Context, err error) {
  	msg := ""
  	if err != nil {
  		msg = err.Error()
  	}
  	ApiErrorMsg(c, msg)
  }

  // ApiErrorMsg writes a failed API response with the given message, using
  // HTTP 200 and the body {"success": false, "message": msg}.
  func ApiErrorMsg(c *gin.Context, msg string) {
  	c.JSON(http.StatusOK, gin.H{
  		"success": false,
  		"message": msg,
  	})
  }

  // ApiSuccess writes a successful API response, using HTTP 200 and the body
  // {"success": true, "message": "", "data": data}.
  func ApiSuccess(c *gin.Context, data any) {
  	c.JSON(http.StatusOK, gin.H{
  		"success": true,
  		"message": "",
  		"data":    data,
  	})
  }
  // ParseMultipartFormReusable parses the request body as multipart/form-data
  // without permanently consuming it. GetRequestBody caches the raw body, and
  // c.Request.Body is reset to a fresh reader over it before parsing, so the
  // body stays readable even if parsing fails.
  //
  // The boundary comes from the Content-Type header. errBoundaryNotFound is
  // returned when the header is empty or has no boundary parameter.
  //
  // multipart.Reader.ReadForm may spill parts beyond its memory budget
  // (multipartMemoryLimit()) to temporary files. The caller owns the returned
  // form and should call form.RemoveAll() when done with it.
  func ParseMultipartFormReusable(c *gin.Context) (*multipart.Form, error) {
  	requestBody, err := GetRequestBody(c)
  	if err != nil {
  		return nil, err
  	}
  	// Restore the body before parsing, so an early error return below does not
  	// leave c.Request.Body drained for later handlers.
  	c.Request.Body = io.NopCloser(bytes.NewReader(requestBody))

  	boundary, err := parseBoundary(c.Request.Header.Get("Content-Type"))
  	if err != nil {
  		return nil, err
  	}
  	reader := multipart.NewReader(bytes.NewReader(requestBody), boundary)
  	form, err := reader.ReadForm(multipartMemoryLimit())
  	if err != nil {
  		return nil, err
  	}
  	return form, nil
  }
  126. func processFormMap(formMap map[string]any, v any) error {
  127. jsonData, err := Marshal(formMap)
  128. if err != nil {
  129. return err
  130. }
  131. err = Unmarshal(jsonData, v)
  132. if err != nil {
  133. return err
  134. }
  135. return nil
  136. }
  // flattenFormValues turns multi-valued form fields into the map shape that
  // processFormMap expects. A field with exactly one value maps to that
  // string; any other field maps to its full []string.
  func flattenFormValues(values map[string][]string) map[string]any {
  	formMap := make(map[string]any, len(values))
  	for key, vals := range values {
  		if len(vals) == 1 {
  			formMap[key] = vals[0]
  		} else {
  			formMap[key] = vals
  		}
  	}
  	return formMap
  }

  // parseFormData decodes an application/x-www-form-urlencoded body into v.
  // It returns any url.ParseQuery error unchanged.
  func parseFormData(data []byte, v any) error {
  	values, err := url.ParseQuery(string(data))
  	if err != nil {
  		return err
  	}
  	return processFormMap(flattenFormValues(values), v)
  }

  // parseMultipartFormData decodes the non-file fields of a
  // multipart/form-data body into v. The boundary is read from the request's
  // Content-Type header. If that header has no boundary, data is decoded with
  // Unmarshal (JSON) instead. Other header parse errors are returned as-is.
  func parseMultipartFormData(c *gin.Context, data []byte, v any) error {
  	boundary, err := parseBoundary(c.Request.Header.Get("Content-Type"))
  	if err != nil {
  		if errors.Is(err, errBoundaryNotFound) {
  			return Unmarshal(data, v) // fall back to JSON
  		}
  		return err
  	}
  	reader := multipart.NewReader(bytes.NewReader(data), boundary)
  	form, err := reader.ReadForm(multipartMemoryLimit())
  	if err != nil {
  		return err
  	}
  	// Only form.Value is used. RemoveAll deletes any temporary files ReadForm
  	// created for file parts; cleanup is best-effort, so its error is ignored.
  	defer func() { _ = form.RemoveAll() }()
  	return processFormMap(flattenFormValues(form.Value), v)
  }
  177. var errBoundaryNotFound = errors.New("multipart boundary not found")
  178. // parseBoundary extracts the multipart boundary from the Content-Type header using mime.ParseMediaType
  179. func parseBoundary(contentType string) (string, error) {
  180. if contentType == "" {
  181. return "", errBoundaryNotFound
  182. }
  183. // Boundary-UUID / boundary-------xxxxxx
  184. _, params, err := mime.ParseMediaType(contentType)
  185. if err != nil {
  186. return "", err
  187. }
  188. boundary, ok := params["boundary"]
  189. if !ok || boundary == "" {
  190. return "", errBoundaryNotFound
  191. }
  192. return boundary, nil
  193. }
  194. // multipartMemoryLimit returns the configured multipart memory limit in bytes
  195. func multipartMemoryLimit() int64 {
  196. limitMB := constant.MaxFileDownloadMB
  197. if limitMB <= 0 {
  198. limitMB = 32
  199. }
  200. return int64(limitMB) << 20
  201. }