// Package common holds shared helpers for reading and re-reading gin request
// bodies, typed access to gin context keys, and uniform JSON API responses.
package common

import (
	"bytes"
	"fmt"
	"io"
	"mime"
	"mime/multipart"
	"net/http"
	"net/url"
	"strings"
	"time"

	"github.com/QuantumNous/new-api/constant"
	"github.com/gin-gonic/gin"
	"github.com/pkg/errors"
)

// KeyRequestBody is the gin context key under which GetRequestBody caches the
// raw request body bytes.
const KeyRequestBody = "key_request_body"

// ErrRequestBodyTooLarge is returned (wrapped) when the request body exceeds
// the configured size limit.
var ErrRequestBodyTooLarge = errors.New("request body too large")

// IsRequestBodyTooLargeError reports whether err signals an oversized request
// body: either ErrRequestBodyTooLarge or an *http.MaxBytesError anywhere in
// its chain. A nil err yields false.
func IsRequestBodyTooLargeError(err error) bool {
	if err == nil {
		return false
	}
	if errors.Is(err, ErrRequestBodyTooLarge) {
		return true
	}
	var mbe *http.MaxBytesError
	return errors.As(err, &mbe)
}

// GetRequestBody returns the raw request body, reading it from c.Request.Body
// on first use and caching it in the context under KeyRequestBody so later
// calls return the same bytes without touching the stream again.
//
// The size limit comes from constant.MaxRequestBodyMB; a negative value
// disables the check. When the limit is exceeded the returned error wraps
// ErrRequestBodyTooLarge and nothing is cached. The original body stream is
// closed once it has been read.
func GetRequestBody(c *gin.Context) ([]byte, error) {
	if cached, exists := c.Get(KeyRequestBody); exists && cached != nil {
		if b, ok := cached.([]byte); ok {
			return b, nil
		}
	}

	// The stream is consumed exactly once here. A Close error is not
	// actionable for the caller, so it is deliberately ignored.
	defer func() { _ = c.Request.Body.Close() }()

	maxMB := constant.MaxRequestBodyMB
	if maxMB < 0 {
		// No limit.
		body, err := io.ReadAll(c.Request.Body)
		if err != nil {
			return nil, err
		}
		c.Set(KeyRequestBody, body)
		return body, nil
	}

	maxBytes := int64(maxMB) << 20
	tooLarge := errors.Wrap(ErrRequestBodyTooLarge, fmt.Sprintf("request body exceeds %d MB", maxMB))

	// Read one byte past the limit so an oversized body can be told apart
	// from a body that is exactly maxBytes long.
	body, err := io.ReadAll(io.LimitReader(c.Request.Body, maxBytes+1))
	if err != nil {
		if IsRequestBodyTooLargeError(err) {
			return nil, tooLarge
		}
		return nil, err
	}
	if int64(len(body)) > maxBytes {
		return nil, tooLarge
	}

	c.Set(KeyRequestBody, body)
	return body, nil
}

// UnmarshalBodyReusable decodes the request body into v according to the
// Content-Type header (JSON, URL-encoded form, or multipart form) and leaves
// c.Request.Body readable again afterwards.
//
// The body is restored as soon as it has been read, so it remains readable
// even when decoding fails. Unrecognised content types are not decoded and
// leave v untouched.
func UnmarshalBodyReusable(c *gin.Context, v any) error {
	requestBody, err := GetRequestBody(c)
	if err != nil {
		return err
	}
	// GetRequestBody drained and closed the original stream. Restore a fresh
	// reader before decoding so that a decode error does not leave the
	// request with an unreadable body.
	c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))

	//if DebugEnabled {
	//	println("UnmarshalBodyReusable request body:", string(requestBody))
	//}
	contentType := c.Request.Header.Get("Content-Type")
	switch {
	case strings.HasPrefix(contentType, "application/json"):
		return Unmarshal(requestBody, v)
	case strings.Contains(contentType, gin.MIMEPOSTForm):
		return parseFormData(requestBody, v)
	case strings.Contains(contentType, gin.MIMEMultipartPOSTForm):
		return parseMultipartFormData(c, requestBody, v)
	default:
		// Skip for now.
		// TODO: once non-JSON requests can carry a variant model, implement
		// decoding for them here.
		return nil
	}
}

// SetContextKey stores value in the gin context under key.
func SetContextKey(c *gin.Context, key constant.ContextKey, value any) {
	c.Set(string(key), value)
}

// GetContextKey returns the value stored under key and whether it exists.
func GetContextKey(c *gin.Context, key constant.ContextKey) (any, bool) {
	return c.Get(string(key))
}

// GetContextKeyString returns the value stored under key via c.GetString.
func GetContextKeyString(c *gin.Context, key constant.ContextKey) string {
	return c.GetString(string(key))
}

// GetContextKeyInt returns the value stored under key via c.GetInt.
func GetContextKeyInt(c *gin.Context, key constant.ContextKey) int {
	return c.GetInt(string(key))
}

// GetContextKeyBool returns the value stored under key via c.GetBool.
func GetContextKeyBool(c *gin.Context, key constant.ContextKey) bool {
	return c.GetBool(string(key))
}

// GetContextKeyStringSlice returns the value stored under key via
// c.GetStringSlice.
func GetContextKeyStringSlice(c *gin.Context, key constant.ContextKey) []string {
	return c.GetStringSlice(string(key))
}

// GetContextKeyStringMap returns the value stored under key via
// c.GetStringMap.
func GetContextKeyStringMap(c *gin.Context, key constant.ContextKey) map[string]any {
	return c.GetStringMap(string(key))
}

// GetContextKeyTime returns the value stored under key via c.GetTime.
func GetContextKeyTime(c *gin.Context, key constant.ContextKey) time.Time {
	return c.GetTime(string(key))
}

// GetContextKeyType returns the value stored under key asserted to type T.
// The boolean is false, and the zero value of T is returned, when the key is
// absent or the stored value is not a T.
func GetContextKeyType[T any](c *gin.Context, key constant.ContextKey) (T, bool) {
	if value, ok := c.Get(string(key)); ok {
		if v, ok := value.(T); ok {
			return v, true
		}
	}
	var zero T
	return zero, false
}

// ApiError writes a JSON failure response (HTTP 200, "success": false) whose
// message is err.Error(). err must be non-nil.
func ApiError(c *gin.Context, err error) {
	c.JSON(http.StatusOK, gin.H{
		"success": false,
		"message": err.Error(),
	})
}

// ApiErrorMsg writes a JSON failure response (HTTP 200, "success": false)
// carrying msg as the message.
func ApiErrorMsg(c *gin.Context, msg string) {
	c.JSON(http.StatusOK, gin.H{
		"success": false,
		"message": msg,
	})
}

// ApiSuccess writes a JSON success response (HTTP 200, "success": true) with
// an empty message and data as the payload.
func ApiSuccess(c *gin.Context, data any) {
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
		"data":    data,
	})
}

// ParseMultipartFormReusable parses the request body as a multipart form and
// leaves c.Request.Body readable again afterwards.
//
// The body is restored as soon as it has been read, so it remains readable
// even when the boundary is missing or the form cannot be parsed. On success
// the caller owns the returned form and should call its RemoveAll method when
// done, because ReadForm may spill parts to temporary files.
func ParseMultipartFormReusable(c *gin.Context) (*multipart.Form, error) {
	requestBody, err := GetRequestBody(c)
	if err != nil {
		return nil, err
	}
	// Restore the drained body before parsing so that error paths below do
	// not leave the request with an unreadable body.
	c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))

	boundary, err := parseBoundary(c.Request.Header.Get("Content-Type"))
	if err != nil {
		return nil, err
	}

	reader := multipart.NewReader(bytes.NewReader(requestBody), boundary)
	form, err := reader.ReadForm(multipartMemoryLimit())
	if err != nil {
		return nil, err
	}
	return form, nil
}

// processFormMap converts formMap into v by encoding it with Marshal and
// decoding the result with Unmarshal.
func processFormMap(formMap map[string]any, v any) error {
	jsonData, err := Marshal(formMap)
	if err != nil {
		return err
	}
	return Unmarshal(jsonData, v)
}

// flattenFormValues turns multi-valued form data into a generic map: a key
// with exactly one value maps to that string, any other key maps to its
// []string unchanged.
func flattenFormValues(values map[string][]string) map[string]any {
	flat := make(map[string]any, len(values))
	for key, vals := range values {
		if len(vals) == 1 {
			flat[key] = vals[0]
		} else {
			flat[key] = vals
		}
	}
	return flat
}

// parseFormData decodes URL-encoded form data into v.
func parseFormData(data []byte, v any) error {
	values, err := url.ParseQuery(string(data))
	if err != nil {
		return err
	}
	return processFormMap(flattenFormValues(values), v)
}

// parseMultipartFormData decodes the text fields of a multipart form in data
// into v; file parts are ignored. When the Content-Type header carries no
// boundary, data is decoded as JSON instead.
func parseMultipartFormData(c *gin.Context, data []byte, v any) error {
	boundary, err := parseBoundary(c.Request.Header.Get("Content-Type"))
	if err != nil {
		if errors.Is(err, errBoundaryNotFound) {
			return Unmarshal(data, v) // Fallback to JSON
		}
		return err
	}

	reader := multipart.NewReader(bytes.NewReader(data), boundary)
	form, err := reader.ReadForm(multipartMemoryLimit())
	if err != nil {
		return err
	}
	// Only the text values are used; drop any temporary files ReadForm made.
	defer form.RemoveAll()

	return processFormMap(flattenFormValues(form.Value), v)
}

// errBoundaryNotFound is returned by parseBoundary when the Content-Type
// header is empty or has no boundary parameter.
var errBoundaryNotFound = errors.New("multipart boundary not found")

// parseBoundary extracts the multipart boundary from the Content-Type header
// using mime.ParseMediaType. It returns errBoundaryNotFound when the header is
// empty or lacks a non-empty boundary parameter, and the mime parsing error
// when the header is malformed.
func parseBoundary(contentType string) (string, error) {
	if contentType == "" {
		return "", errBoundaryNotFound
	}
	// Boundary-UUID / boundary-------xxxxxx
	_, params, err := mime.ParseMediaType(contentType)
	if err != nil {
		return "", err
	}
	boundary, ok := params["boundary"]
	if !ok || boundary == "" {
		return "", errBoundaryNotFound
	}
	return boundary, nil
}

// multipartMemoryLimit returns the multipart memory limit in bytes, taken
// from constant.MaxFileDownloadMB and defaulting to 32 MB when that value is
// not positive.
func multipartMemoryLimit() int64 {
	limitMB := constant.MaxFileDownloadMB
	if limitMB <= 0 {
		limitMB = 32
	}
	return int64(limitMB) << 20
}