package common import ( "bytes" "errors" "io" "mime" "mime/multipart" "net/http" "net/url" "strings" "time" "github.com/QuantumNous/new-api/constant" "github.com/gin-gonic/gin" ) const KeyRequestBody = "key_request_body" func GetRequestBody(c *gin.Context) ([]byte, error) { requestBody, _ := c.Get(KeyRequestBody) if requestBody != nil { return requestBody.([]byte), nil } requestBody, err := io.ReadAll(c.Request.Body) if err != nil { return nil, err } _ = c.Request.Body.Close() c.Set(KeyRequestBody, requestBody) return requestBody.([]byte), nil } func UnmarshalBodyReusable(c *gin.Context, v any) error { requestBody, err := GetRequestBody(c) if err != nil { return err } //if DebugEnabled { // println("UnmarshalBodyReusable request body:", string(requestBody)) //} contentType := c.Request.Header.Get("Content-Type") if strings.HasPrefix(contentType, "application/json") { err = Unmarshal(requestBody, v) } else if strings.Contains(contentType, gin.MIMEPOSTForm) { err = parseFormData(requestBody, v) } else if strings.Contains(contentType, gin.MIMEMultipartPOSTForm) { err = parseMultipartFormData(c, requestBody, v) } else { // skip for now // TODO: someday non json request have variant model, we will need to implementation this } if err != nil { return err } // Reset request body c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) return nil } func SetContextKey(c *gin.Context, key constant.ContextKey, value any) { c.Set(string(key), value) } func GetContextKey(c *gin.Context, key constant.ContextKey) (any, bool) { return c.Get(string(key)) } func GetContextKeyString(c *gin.Context, key constant.ContextKey) string { return c.GetString(string(key)) } func GetContextKeyInt(c *gin.Context, key constant.ContextKey) int { return c.GetInt(string(key)) } func GetContextKeyBool(c *gin.Context, key constant.ContextKey) bool { return c.GetBool(string(key)) } func GetContextKeyStringSlice(c *gin.Context, key constant.ContextKey) []string { return c.GetStringSlice(string(key)) } func GetContextKeyStringMap(c *gin.Context, key constant.ContextKey) map[string]any { return c.GetStringMap(string(key)) } func GetContextKeyTime(c *gin.Context, key constant.ContextKey) time.Time { return c.GetTime(string(key)) } func GetContextKeyType[T any](c *gin.Context, key constant.ContextKey) (T, bool) { if value, ok := c.Get(string(key)); ok { if v, ok := value.(T); ok { return v, true } } var t T return t, false } func ApiError(c *gin.Context, err error) { c.JSON(http.StatusOK, gin.H{ "success": false, "message": err.Error(), }) } func ApiErrorMsg(c *gin.Context, msg string) { c.JSON(http.StatusOK, gin.H{ "success": false, "message": msg, }) } func ApiSuccess(c *gin.Context, data any) { c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", "data": data, }) } func ParseMultipartFormReusable(c *gin.Context) (*multipart.Form, error) { requestBody, err := GetRequestBody(c) if err != nil { return nil, err } contentType := c.Request.Header.Get("Content-Type") boundary, err := parseBoundary(contentType) if err != nil { return nil, err } reader := multipart.NewReader(bytes.NewReader(requestBody), boundary) form, err := reader.ReadForm(multipartMemoryLimit()) if err != nil { return nil, err } // Reset request body c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) return form, nil } func processFormMap(formMap map[string]any, v any) error { jsonData, err := Marshal(formMap) if err != nil { return err } err = Unmarshal(jsonData, v) if err != nil { return err } return nil } func parseFormData(data []byte, v any) error { values, err := url.ParseQuery(string(data)) if err != nil { return err } formMap := make(map[string]any) for key, vals := range values { if len(vals) == 1 { formMap[key] = vals[0] } else { formMap[key] = vals } } return processFormMap(formMap, v) } func parseMultipartFormData(c *gin.Context, data []byte, v any) error { contentType := c.Request.Header.Get("Content-Type") boundary, err := parseBoundary(contentType) if err != nil { if errors.Is(err, errBoundaryNotFound) { return Unmarshal(data, v) // Fallback to JSON } return err } reader := multipart.NewReader(bytes.NewReader(data), boundary) form, err := reader.ReadForm(multipartMemoryLimit()) if err != nil { return err } defer form.RemoveAll() formMap := make(map[string]any) for key, vals := range form.Value { if len(vals) == 1 { formMap[key] = vals[0] } else { formMap[key] = vals } } return processFormMap(formMap, v) } var errBoundaryNotFound = errors.New("multipart boundary not found") // parseBoundary extracts the multipart boundary from the Content-Type header using mime.ParseMediaType func parseBoundary(contentType string) (string, error) { if contentType == "" { return "", errBoundaryNotFound } // Boundary-UUID / boundary-------xxxxxx _, params, err := mime.ParseMediaType(contentType) if err != nil { return "", err } boundary, ok := params["boundary"] if !ok || boundary == "" { return "", errBoundaryNotFound } return boundary, nil } // multipartMemoryLimit returns the configured multipart memory limit in bytes func multipartMemoryLimit() int64 { limitMB := constant.MaxFileDownloadMB if limitMB <= 0 { limitMB = 32 } return int64(limitMB) << 20 }