| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174 |
- package dto
- import (
- "encoding/json"
- "reflect"
- "strings"
- "github.com/QuantumNous/new-api/common"
- "github.com/QuantumNous/new-api/types"
- "github.com/gin-gonic/gin"
- )
// ImageRequest is the JSON body of an image request.
//
// Most optional parameters are held as json.RawMessage, so the value supplied
// by the client is kept in its encoded form rather than being decoded into a
// concrete Go type.
type ImageRequest struct {
	// Model is carried under the "model" key.
	Model string `json:"model"`
	// Prompt is carried under "prompt" and is tagged binding:"required".
	Prompt string `json:"prompt" binding:"required"`
	// N is carried under "n"; a zero value is omitted when encoding.
	N uint `json:"n,omitempty"`
	// Size, Quality and ResponseFormat are plain strings, omitted when empty.
	Size           string `json:"size,omitempty"`
	Quality        string `json:"quality,omitempty"`
	ResponseFormat string `json:"response_format,omitempty"`
	// The fields below keep the raw JSON value and are omitted when empty.
	Style             json.RawMessage `json:"style,omitempty"`
	User              json.RawMessage `json:"user,omitempty"`
	ExtraFields       json.RawMessage `json:"extra_fields,omitempty"`
	Background        json.RawMessage `json:"background,omitempty"`
	Moderation        json.RawMessage `json:"moderation,omitempty"`
	OutputFormat      json.RawMessage `json:"output_format,omitempty"`
	OutputCompression json.RawMessage `json:"output_compression,omitempty"`
	PartialImages     json.RawMessage `json:"partial_images,omitempty"`
	// Stream bool `json:"stream,omitempty"`
	// Watermark is a pointer so that a missing key (nil) stays distinguishable
	// from an explicit false.
	Watermark *bool `json:"watermark,omitempty"`
	// Extra receives additional request parameters that have no dedicated
	// field. It is tagged "-" and is therefore never touched by the default
	// JSON encoder or decoder.
	Extra map[string]json.RawMessage `json:"-"`
}
- func (i *ImageRequest) UnmarshalJSON(data []byte) error {
- // 先解析成 map[string]interface{}
- var rawMap map[string]json.RawMessage
- if err := common.Unmarshal(data, &rawMap); err != nil {
- return err
- }
- // 用 struct tag 获取所有已定义字段名
- knownFields := GetJSONFieldNames(reflect.TypeOf(*i))
- // 再正常解析已定义字段
- type Alias ImageRequest
- var known Alias
- if err := common.Unmarshal(data, &known); err != nil {
- return err
- }
- *i = ImageRequest(known)
- // 提取多余字段
- i.Extra = make(map[string]json.RawMessage)
- for k, v := range rawMap {
- if _, ok := knownFields[k]; !ok {
- i.Extra[k] = v
- }
- }
- return nil
- }
- // 序列化时需要重新把字段平铺
- func (r ImageRequest) MarshalJSON() ([]byte, error) {
- // 将已定义字段转为 map
- type Alias ImageRequest
- alias := Alias(r)
- base, err := common.Marshal(alias)
- if err != nil {
- return nil, err
- }
- var baseMap map[string]json.RawMessage
- if err := common.Unmarshal(base, &baseMap); err != nil {
- return nil, err
- }
- // 不能合并ExtraFields!!!!!!!!
- // 合并 ExtraFields
- //for k, v := range r.Extra {
- // if _, exists := baseMap[k]; !exists {
- // baseMap[k] = v
- // }
- //}
- return common.Marshal(baseMap)
- }
// GetJSONFieldNames returns the set of JSON object keys under which the
// fields of struct type t are encoded, following the naming rules of
// encoding/json struct tags.
//
// t may be a struct type or a pointer to one. A nil type, or any type that is
// not a struct after pointer indirection, yields an empty set instead of a
// panic.
//
// Per field:
//   - embedded (anonymous) and unexported fields are skipped;
//   - a tag of exactly "-" skips the field;
//   - the key is the part of the tag before the first comma;
//   - when that part is empty (no json tag, or a tag such as ",omitempty")
//     the Go field name is used.
//
// The returned map is never nil.
func GetJSONFieldNames(t reflect.Type) map[string]struct{} {
	names := make(map[string]struct{})
	if t == nil {
		return names
	}
	for t.Kind() == reflect.Ptr {
		t = t.Elem()
	}
	if t.Kind() != reflect.Struct {
		return names
	}

	for idx := 0; idx < t.NumField(); idx++ {
		field := t.Field(idx)
		// PkgPath is non-empty only for unexported fields, which
		// encoding/json never reads or writes.
		if field.Anonymous || field.PkgPath != "" {
			continue
		}

		tag := field.Tag.Get("json")
		if tag == "-" {
			continue
		}

		// Drop options such as omitempty; only the name part matters here.
		name, _, _ := strings.Cut(tag, ",")
		if name == "" {
			name = field.Name
		}
		names[name] = struct{}{}
	}
	return names
}
// indexComma returns the byte offset of the first ',' in s, or -1 when s
// contains no comma.
func indexComma(s string) int {
	return strings.IndexByte(s, ',')
}
- func (i *ImageRequest) GetTokenCountMeta() *types.TokenCountMeta {
- var sizeRatio = 1.0
- var qualityRatio = 1.0
- if strings.HasPrefix(i.Model, "dall-e") {
- // Size
- if i.Size == "256x256" {
- sizeRatio = 0.4
- } else if i.Size == "512x512" {
- sizeRatio = 0.45
- } else if i.Size == "1024x1024" {
- sizeRatio = 1
- } else if i.Size == "1024x1792" || i.Size == "1792x1024" {
- sizeRatio = 2
- }
- if i.Model == "dall-e-3" && i.Quality == "hd" {
- qualityRatio = 2.0
- if i.Size == "1024x1792" || i.Size == "1792x1024" {
- qualityRatio = 1.5
- }
- }
- }
- // not support token count for dalle
- return &types.TokenCountMeta{
- CombineText: i.Prompt,
- MaxTokens: 1584,
- ImagePriceRatio: sizeRatio * qualityRatio * float64(i.N),
- }
- }
- func (i *ImageRequest) IsStream(c *gin.Context) bool {
- return false
- }
- func (i *ImageRequest) SetModelName(modelName string) {
- if modelName != "" {
- i.Model = modelName
- }
- }
- type ImageResponse struct {
- Data []ImageData `json:"data"`
- Created int64 `json:"created"`
- Extra any `json:"extra,omitempty"`
- }
// ImageData is a single entry of ImageResponse.Data.
//
// None of the fields uses omitempty, so all three keys are always present in
// the encoded JSON, even when their value is the empty string.
type ImageData struct {
	// Url is encoded under "url".
	Url string `json:"url"`
	// B64Json is encoded under "b64_json".
	B64Json string `json:"b64_json"`
	// RevisedPrompt is encoded under "revised_prompt".
	RevisedPrompt string `json:"revised_prompt"`
}
|