encoder.go 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383
  1. package apiform
  2. import (
  3. "fmt"
  4. "io"
  5. "mime/multipart"
  6. "net/textproto"
  7. "path"
  8. "reflect"
  9. "sort"
  10. "strconv"
  11. "strings"
  12. "sync"
  13. "time"
  14. "github.com/sst/opencode-sdk-go/internal/param"
  15. )
  16. var encoders sync.Map // map[encoderEntry]encoderFunc
  17. func Marshal(value interface{}, writer *multipart.Writer) error {
  18. e := &encoder{dateFormat: time.RFC3339}
  19. return e.marshal(value, writer)
  20. }
  21. func MarshalRoot(value interface{}, writer *multipart.Writer) error {
  22. e := &encoder{root: true, dateFormat: time.RFC3339}
  23. return e.marshal(value, writer)
  24. }
  25. type encoder struct {
  26. dateFormat string
  27. root bool
  28. }
  29. type encoderFunc func(key string, value reflect.Value, writer *multipart.Writer) error
  30. type encoderField struct {
  31. tag parsedStructTag
  32. fn encoderFunc
  33. idx []int
  34. }
  35. type encoderEntry struct {
  36. reflect.Type
  37. dateFormat string
  38. root bool
  39. }
  40. func (e *encoder) marshal(value interface{}, writer *multipart.Writer) error {
  41. val := reflect.ValueOf(value)
  42. if !val.IsValid() {
  43. return nil
  44. }
  45. typ := val.Type()
  46. enc := e.typeEncoder(typ)
  47. return enc("", val, writer)
  48. }
  49. func (e *encoder) typeEncoder(t reflect.Type) encoderFunc {
  50. entry := encoderEntry{
  51. Type: t,
  52. dateFormat: e.dateFormat,
  53. root: e.root,
  54. }
  55. if fi, ok := encoders.Load(entry); ok {
  56. return fi.(encoderFunc)
  57. }
  58. // To deal with recursive types, populate the map with an
  59. // indirect func before we build it. This type waits on the
  60. // real func (f) to be ready and then calls it. This indirect
  61. // func is only used for recursive types.
  62. var (
  63. wg sync.WaitGroup
  64. f encoderFunc
  65. )
  66. wg.Add(1)
  67. fi, loaded := encoders.LoadOrStore(entry, encoderFunc(func(key string, v reflect.Value, writer *multipart.Writer) error {
  68. wg.Wait()
  69. return f(key, v, writer)
  70. }))
  71. if loaded {
  72. return fi.(encoderFunc)
  73. }
  74. // Compute the real encoder and replace the indirect func with it.
  75. f = e.newTypeEncoder(t)
  76. wg.Done()
  77. encoders.Store(entry, f)
  78. return f
  79. }
  80. func (e *encoder) newTypeEncoder(t reflect.Type) encoderFunc {
  81. if t.ConvertibleTo(reflect.TypeOf(time.Time{})) {
  82. return e.newTimeTypeEncoder()
  83. }
  84. if t.ConvertibleTo(reflect.TypeOf((*io.Reader)(nil)).Elem()) {
  85. return e.newReaderTypeEncoder()
  86. }
  87. e.root = false
  88. switch t.Kind() {
  89. case reflect.Pointer:
  90. inner := t.Elem()
  91. innerEncoder := e.typeEncoder(inner)
  92. return func(key string, v reflect.Value, writer *multipart.Writer) error {
  93. if !v.IsValid() || v.IsNil() {
  94. return nil
  95. }
  96. return innerEncoder(key, v.Elem(), writer)
  97. }
  98. case reflect.Struct:
  99. return e.newStructTypeEncoder(t)
  100. case reflect.Slice, reflect.Array:
  101. return e.newArrayTypeEncoder(t)
  102. case reflect.Map:
  103. return e.newMapEncoder(t)
  104. case reflect.Interface:
  105. return e.newInterfaceEncoder()
  106. default:
  107. return e.newPrimitiveTypeEncoder(t)
  108. }
  109. }
  110. func (e *encoder) newPrimitiveTypeEncoder(t reflect.Type) encoderFunc {
  111. switch t.Kind() {
  112. // Note that we could use `gjson` to encode these types but it would complicate our
  113. // code more and this current code shouldn't cause any issues
  114. case reflect.String:
  115. return func(key string, v reflect.Value, writer *multipart.Writer) error {
  116. return writer.WriteField(key, v.String())
  117. }
  118. case reflect.Bool:
  119. return func(key string, v reflect.Value, writer *multipart.Writer) error {
  120. if v.Bool() {
  121. return writer.WriteField(key, "true")
  122. }
  123. return writer.WriteField(key, "false")
  124. }
  125. case reflect.Int, reflect.Int16, reflect.Int32, reflect.Int64:
  126. return func(key string, v reflect.Value, writer *multipart.Writer) error {
  127. return writer.WriteField(key, strconv.FormatInt(v.Int(), 10))
  128. }
  129. case reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uint64:
  130. return func(key string, v reflect.Value, writer *multipart.Writer) error {
  131. return writer.WriteField(key, strconv.FormatUint(v.Uint(), 10))
  132. }
  133. case reflect.Float32:
  134. return func(key string, v reflect.Value, writer *multipart.Writer) error {
  135. return writer.WriteField(key, strconv.FormatFloat(v.Float(), 'f', -1, 32))
  136. }
  137. case reflect.Float64:
  138. return func(key string, v reflect.Value, writer *multipart.Writer) error {
  139. return writer.WriteField(key, strconv.FormatFloat(v.Float(), 'f', -1, 64))
  140. }
  141. default:
  142. return func(key string, v reflect.Value, writer *multipart.Writer) error {
  143. return fmt.Errorf("unknown type received at primitive encoder: %s", t.String())
  144. }
  145. }
  146. }
  147. func (e *encoder) newArrayTypeEncoder(t reflect.Type) encoderFunc {
  148. itemEncoder := e.typeEncoder(t.Elem())
  149. return func(key string, v reflect.Value, writer *multipart.Writer) error {
  150. if key != "" {
  151. key = key + "."
  152. }
  153. for i := 0; i < v.Len(); i++ {
  154. err := itemEncoder(key+strconv.Itoa(i), v.Index(i), writer)
  155. if err != nil {
  156. return err
  157. }
  158. }
  159. return nil
  160. }
  161. }
  162. func (e *encoder) newStructTypeEncoder(t reflect.Type) encoderFunc {
  163. if t.Implements(reflect.TypeOf((*param.FieldLike)(nil)).Elem()) {
  164. return e.newFieldTypeEncoder(t)
  165. }
  166. encoderFields := []encoderField{}
  167. extraEncoder := (*encoderField)(nil)
  168. // This helper allows us to recursively collect field encoders into a flat
  169. // array. The parameter `index` keeps track of the access patterns necessary
  170. // to get to some field.
  171. var collectEncoderFields func(r reflect.Type, index []int)
  172. collectEncoderFields = func(r reflect.Type, index []int) {
  173. for i := 0; i < r.NumField(); i++ {
  174. idx := append(index, i)
  175. field := t.FieldByIndex(idx)
  176. if !field.IsExported() {
  177. continue
  178. }
  179. // If this is an embedded struct, traverse one level deeper to extract
  180. // the field and get their encoders as well.
  181. if field.Anonymous {
  182. collectEncoderFields(field.Type, idx)
  183. continue
  184. }
  185. // If json tag is not present, then we skip, which is intentionally
  186. // different behavior from the stdlib.
  187. ptag, ok := parseFormStructTag(field)
  188. if !ok {
  189. continue
  190. }
  191. // We only want to support unexported field if they're tagged with
  192. // `extras` because that field shouldn't be part of the public API. We
  193. // also want to only keep the top level extras
  194. if ptag.extras && len(index) == 0 {
  195. extraEncoder = &encoderField{ptag, e.typeEncoder(field.Type.Elem()), idx}
  196. continue
  197. }
  198. if ptag.name == "-" {
  199. continue
  200. }
  201. dateFormat, ok := parseFormatStructTag(field)
  202. oldFormat := e.dateFormat
  203. if ok {
  204. switch dateFormat {
  205. case "date-time":
  206. e.dateFormat = time.RFC3339
  207. case "date":
  208. e.dateFormat = "2006-01-02"
  209. }
  210. }
  211. encoderFields = append(encoderFields, encoderField{ptag, e.typeEncoder(field.Type), idx})
  212. e.dateFormat = oldFormat
  213. }
  214. }
  215. collectEncoderFields(t, []int{})
  216. // Ensure deterministic output by sorting by lexicographic order
  217. sort.Slice(encoderFields, func(i, j int) bool {
  218. return encoderFields[i].tag.name < encoderFields[j].tag.name
  219. })
  220. return func(key string, value reflect.Value, writer *multipart.Writer) error {
  221. if key != "" {
  222. key = key + "."
  223. }
  224. for _, ef := range encoderFields {
  225. field := value.FieldByIndex(ef.idx)
  226. err := ef.fn(key+ef.tag.name, field, writer)
  227. if err != nil {
  228. return err
  229. }
  230. }
  231. if extraEncoder != nil {
  232. err := e.encodeMapEntries(key, value.FieldByIndex(extraEncoder.idx), writer)
  233. if err != nil {
  234. return err
  235. }
  236. }
  237. return nil
  238. }
  239. }
  240. func (e *encoder) newFieldTypeEncoder(t reflect.Type) encoderFunc {
  241. f, _ := t.FieldByName("Value")
  242. enc := e.typeEncoder(f.Type)
  243. return func(key string, value reflect.Value, writer *multipart.Writer) error {
  244. present := value.FieldByName("Present")
  245. if !present.Bool() {
  246. return nil
  247. }
  248. null := value.FieldByName("Null")
  249. if null.Bool() {
  250. return nil
  251. }
  252. raw := value.FieldByName("Raw")
  253. if !raw.IsNil() {
  254. return e.typeEncoder(raw.Type())(key, raw, writer)
  255. }
  256. return enc(key, value.FieldByName("Value"), writer)
  257. }
  258. }
  259. func (e *encoder) newTimeTypeEncoder() encoderFunc {
  260. format := e.dateFormat
  261. return func(key string, value reflect.Value, writer *multipart.Writer) error {
  262. return writer.WriteField(key, value.Convert(reflect.TypeOf(time.Time{})).Interface().(time.Time).Format(format))
  263. }
  264. }
  265. func (e encoder) newInterfaceEncoder() encoderFunc {
  266. return func(key string, value reflect.Value, writer *multipart.Writer) error {
  267. value = value.Elem()
  268. if !value.IsValid() {
  269. return nil
  270. }
  271. return e.typeEncoder(value.Type())(key, value, writer)
  272. }
  273. }
  274. var quoteEscaper = strings.NewReplacer("\\", "\\\\", `"`, "\\\"")
  275. func escapeQuotes(s string) string {
  276. return quoteEscaper.Replace(s)
  277. }
  278. func (e *encoder) newReaderTypeEncoder() encoderFunc {
  279. return func(key string, value reflect.Value, writer *multipart.Writer) error {
  280. reader := value.Convert(reflect.TypeOf((*io.Reader)(nil)).Elem()).Interface().(io.Reader)
  281. filename := "anonymous_file"
  282. contentType := "application/octet-stream"
  283. if named, ok := reader.(interface{ Filename() string }); ok {
  284. filename = named.Filename()
  285. } else if named, ok := reader.(interface{ Name() string }); ok {
  286. filename = path.Base(named.Name())
  287. }
  288. if typed, ok := reader.(interface{ ContentType() string }); ok {
  289. contentType = typed.ContentType()
  290. }
  291. // Below is taken almost 1-for-1 from [multipart.CreateFormFile]
  292. h := make(textproto.MIMEHeader)
  293. h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="%s"; filename="%s"`, escapeQuotes(key), escapeQuotes(filename)))
  294. h.Set("Content-Type", contentType)
  295. filewriter, err := writer.CreatePart(h)
  296. if err != nil {
  297. return err
  298. }
  299. _, err = io.Copy(filewriter, reader)
  300. return err
  301. }
  302. }
  303. // Given a []byte of json (may either be an empty object or an object that already contains entries)
  304. // encode all of the entries in the map to the json byte array.
  305. func (e *encoder) encodeMapEntries(key string, v reflect.Value, writer *multipart.Writer) error {
  306. type mapPair struct {
  307. key string
  308. value reflect.Value
  309. }
  310. if key != "" {
  311. key = key + "."
  312. }
  313. pairs := []mapPair{}
  314. iter := v.MapRange()
  315. for iter.Next() {
  316. if iter.Key().Type().Kind() == reflect.String {
  317. pairs = append(pairs, mapPair{key: iter.Key().String(), value: iter.Value()})
  318. } else {
  319. return fmt.Errorf("cannot encode a map with a non string key")
  320. }
  321. }
  322. // Ensure deterministic output
  323. sort.Slice(pairs, func(i, j int) bool {
  324. return pairs[i].key < pairs[j].key
  325. })
  326. elementEncoder := e.typeEncoder(v.Type().Elem())
  327. for _, p := range pairs {
  328. err := elementEncoder(key+string(p.key), p.value, writer)
  329. if err != nil {
  330. return err
  331. }
  332. }
  333. return nil
  334. }
  335. func (e *encoder) newMapEncoder(t reflect.Type) encoderFunc {
  336. return func(key string, value reflect.Value, writer *multipart.Writer) error {
  337. return e.encodeMapEntries(key, value, writer)
  338. }
  339. }