Quellcode durchsuchen

feat: dalle 格式支持自定义参数

Xyfacai vor 6 Monaten
Ursprung
Commit
cd7594f623
1 geänderte Dateien mit 90 neuen und 1 gelöschten Zeilen
  1. 90 1
      dto/dalle.go

+ 90 - 1
dto/dalle.go

@@ -1,6 +1,9 @@
 package dto
 
-import "encoding/json"
+import (
+	"encoding/json"
+	"reflect"
+)
 
 type ImageRequest struct {
 	Model          string          `json:"model"`
@@ -15,6 +18,58 @@ type ImageRequest struct {
 	Background     string          `json:"background,omitempty"`
 	Moderation     string          `json:"moderation,omitempty"`
 	OutputFormat   string          `json:"output_format,omitempty"`
+	// 用匿名字段接住额外的字段
+	Extra map[string]json.RawMessage `json:"-"`
+}
+
+func (r *ImageRequest) UnmarshalJSON(data []byte) error {
+	// 先解析成 map[string]interface{}
+	var rawMap map[string]json.RawMessage
+	if err := json.Unmarshal(data, &rawMap); err != nil {
+		return err
+	}
+
+	// 用 struct tag 获取所有已定义字段名
+	knownFields := GetJSONFieldNames(reflect.TypeOf(*r))
+
+	// 再正常解析已定义字段
+	type Alias ImageRequest
+	var known Alias
+	if err := json.Unmarshal(data, &known); err != nil {
+		return err
+	}
+	*r = ImageRequest(known)
+
+	// 提取多余字段
+	r.Extra = make(map[string]json.RawMessage)
+	for k, v := range rawMap {
+		if _, ok := knownFields[k]; !ok {
+			r.Extra[k] = v
+		}
+	}
+	return nil
+}
+
+func (r ImageRequest) MarshalJSON() ([]byte, error) {
+	// 将已定义字段转为 map
+	type Alias ImageRequest
+	alias := Alias(r)
+	base, err := json.Marshal(alias)
+	if err != nil {
+		return nil, err
+	}
+
+	var baseMap map[string]json.RawMessage
+	if err := json.Unmarshal(base, &baseMap); err != nil {
+		return nil, err
+	}
+
+	// 合并 ExtraFields
+	for k, v := range r.Extra {
+		baseMap[k] = v
+	}
+
+	return json.Marshal(baseMap)
 }
 
 type ImageResponse struct {
@@ -26,3 +81,37 @@ type ImageData struct {
 	B64Json       string `json:"b64_json"`
 	RevisedPrompt string `json:"revised_prompt"`
 }
+
+func GetJSONFieldNames(t reflect.Type) map[string]struct{} {
+	fields := make(map[string]struct{})
+	for i := 0; i < t.NumField(); i++ {
+		field := t.Field(i)
+
+		// 跳过匿名字段(例如 ExtraFields)
+		if field.Anonymous {
+			continue
+		}
+
+		tag := field.Tag.Get("json")
+		if tag == "-" || tag == "" {
+			continue
+		}
+
+		// 取逗号前字段名(排除 omitempty 等)
+		name := tag
+		if commaIdx := indexComma(tag); commaIdx != -1 {
+			name = tag[:commaIdx]
+		}
+		fields[name] = struct{}{}
+	}
+	return fields
+}
+
+func indexComma(s string) int {
+	for i := 0; i < len(s); i++ {
+		if s[i] == ',' {
+			return i
+		}
+	}
+	return -1
+}