// encoder.go

  1. package apijson
import (
	"bytes"
	"encoding/json"
	"fmt"
	"math"
	"reflect"
	"sort"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/sst/opencode-sdk-go/internal/param"
	"github.com/tidwall/sjson"
)
  15. var encoders sync.Map // map[encoderEntry]encoderFunc
  16. func Marshal(value interface{}) ([]byte, error) {
  17. e := &encoder{dateFormat: time.RFC3339}
  18. return e.marshal(value)
  19. }
  20. func MarshalRoot(value interface{}) ([]byte, error) {
  21. e := &encoder{root: true, dateFormat: time.RFC3339}
  22. return e.marshal(value)
  23. }
  24. type encoder struct {
  25. dateFormat string
  26. root bool
  27. }
  28. type encoderFunc func(value reflect.Value) ([]byte, error)
  29. type encoderField struct {
  30. tag parsedStructTag
  31. fn encoderFunc
  32. idx []int
  33. }
  34. type encoderEntry struct {
  35. reflect.Type
  36. dateFormat string
  37. root bool
  38. }
  39. func (e *encoder) marshal(value interface{}) ([]byte, error) {
  40. val := reflect.ValueOf(value)
  41. if !val.IsValid() {
  42. return nil, nil
  43. }
  44. typ := val.Type()
  45. enc := e.typeEncoder(typ)
  46. return enc(val)
  47. }
  48. func (e *encoder) typeEncoder(t reflect.Type) encoderFunc {
  49. entry := encoderEntry{
  50. Type: t,
  51. dateFormat: e.dateFormat,
  52. root: e.root,
  53. }
  54. if fi, ok := encoders.Load(entry); ok {
  55. return fi.(encoderFunc)
  56. }
  57. // To deal with recursive types, populate the map with an
  58. // indirect func before we build it. This type waits on the
  59. // real func (f) to be ready and then calls it. This indirect
  60. // func is only used for recursive types.
  61. var (
  62. wg sync.WaitGroup
  63. f encoderFunc
  64. )
  65. wg.Add(1)
  66. fi, loaded := encoders.LoadOrStore(entry, encoderFunc(func(v reflect.Value) ([]byte, error) {
  67. wg.Wait()
  68. return f(v)
  69. }))
  70. if loaded {
  71. return fi.(encoderFunc)
  72. }
  73. // Compute the real encoder and replace the indirect func with it.
  74. f = e.newTypeEncoder(t)
  75. wg.Done()
  76. encoders.Store(entry, f)
  77. return f
  78. }
  79. func marshalerEncoder(v reflect.Value) ([]byte, error) {
  80. return v.Interface().(json.Marshaler).MarshalJSON()
  81. }
  82. func indirectMarshalerEncoder(v reflect.Value) ([]byte, error) {
  83. return v.Addr().Interface().(json.Marshaler).MarshalJSON()
  84. }
  85. func (e *encoder) newTypeEncoder(t reflect.Type) encoderFunc {
  86. if t.ConvertibleTo(reflect.TypeOf(time.Time{})) {
  87. return e.newTimeTypeEncoder()
  88. }
  89. if !e.root && t.Implements(reflect.TypeOf((*json.Marshaler)(nil)).Elem()) {
  90. return marshalerEncoder
  91. }
  92. if !e.root && reflect.PointerTo(t).Implements(reflect.TypeOf((*json.Marshaler)(nil)).Elem()) {
  93. return indirectMarshalerEncoder
  94. }
  95. e.root = false
  96. switch t.Kind() {
  97. case reflect.Pointer:
  98. inner := t.Elem()
  99. innerEncoder := e.typeEncoder(inner)
  100. return func(v reflect.Value) ([]byte, error) {
  101. if !v.IsValid() || v.IsNil() {
  102. return nil, nil
  103. }
  104. return innerEncoder(v.Elem())
  105. }
  106. case reflect.Struct:
  107. return e.newStructTypeEncoder(t)
  108. case reflect.Array:
  109. fallthrough
  110. case reflect.Slice:
  111. return e.newArrayTypeEncoder(t)
  112. case reflect.Map:
  113. return e.newMapEncoder(t)
  114. case reflect.Interface:
  115. return e.newInterfaceEncoder()
  116. default:
  117. return e.newPrimitiveTypeEncoder(t)
  118. }
  119. }
  120. func (e *encoder) newPrimitiveTypeEncoder(t reflect.Type) encoderFunc {
  121. switch t.Kind() {
  122. // Note that we could use `gjson` to encode these types but it would complicate our
  123. // code more and this current code shouldn't cause any issues
  124. case reflect.String:
  125. return func(v reflect.Value) ([]byte, error) {
  126. return json.Marshal(v.Interface())
  127. }
  128. case reflect.Bool:
  129. return func(v reflect.Value) ([]byte, error) {
  130. if v.Bool() {
  131. return []byte("true"), nil
  132. }
  133. return []byte("false"), nil
  134. }
  135. case reflect.Int, reflect.Int16, reflect.Int32, reflect.Int64:
  136. return func(v reflect.Value) ([]byte, error) {
  137. return []byte(strconv.FormatInt(v.Int(), 10)), nil
  138. }
  139. case reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uint64:
  140. return func(v reflect.Value) ([]byte, error) {
  141. return []byte(strconv.FormatUint(v.Uint(), 10)), nil
  142. }
  143. case reflect.Float32:
  144. return func(v reflect.Value) ([]byte, error) {
  145. return []byte(strconv.FormatFloat(v.Float(), 'f', -1, 32)), nil
  146. }
  147. case reflect.Float64:
  148. return func(v reflect.Value) ([]byte, error) {
  149. return []byte(strconv.FormatFloat(v.Float(), 'f', -1, 64)), nil
  150. }
  151. default:
  152. return func(v reflect.Value) ([]byte, error) {
  153. return nil, fmt.Errorf("unknown type received at primitive encoder: %s", t.String())
  154. }
  155. }
  156. }
  157. func (e *encoder) newArrayTypeEncoder(t reflect.Type) encoderFunc {
  158. itemEncoder := e.typeEncoder(t.Elem())
  159. return func(value reflect.Value) ([]byte, error) {
  160. json := []byte("[]")
  161. for i := 0; i < value.Len(); i++ {
  162. var value, err = itemEncoder(value.Index(i))
  163. if err != nil {
  164. return nil, err
  165. }
  166. if value == nil {
  167. // Assume that empty items should be inserted as `null` so that the output array
  168. // will be the same length as the input array
  169. value = []byte("null")
  170. }
  171. json, err = sjson.SetRawBytes(json, "-1", value)
  172. if err != nil {
  173. return nil, err
  174. }
  175. }
  176. return json, nil
  177. }
  178. }
  179. func (e *encoder) newStructTypeEncoder(t reflect.Type) encoderFunc {
  180. if t.Implements(reflect.TypeOf((*param.FieldLike)(nil)).Elem()) {
  181. return e.newFieldTypeEncoder(t)
  182. }
  183. encoderFields := []encoderField{}
  184. extraEncoder := (*encoderField)(nil)
  185. // This helper allows us to recursively collect field encoders into a flat
  186. // array. The parameter `index` keeps track of the access patterns necessary
  187. // to get to some field.
  188. var collectEncoderFields func(r reflect.Type, index []int)
  189. collectEncoderFields = func(r reflect.Type, index []int) {
  190. for i := 0; i < r.NumField(); i++ {
  191. idx := append(index, i)
  192. field := t.FieldByIndex(idx)
  193. if !field.IsExported() {
  194. continue
  195. }
  196. // If this is an embedded struct, traverse one level deeper to extract
  197. // the field and get their encoders as well.
  198. if field.Anonymous {
  199. collectEncoderFields(field.Type, idx)
  200. continue
  201. }
  202. // If json tag is not present, then we skip, which is intentionally
  203. // different behavior from the stdlib.
  204. ptag, ok := parseJSONStructTag(field)
  205. if !ok {
  206. continue
  207. }
  208. // We only want to support unexported field if they're tagged with
  209. // `extras` because that field shouldn't be part of the public API. We
  210. // also want to only keep the top level extras
  211. if ptag.extras && len(index) == 0 {
  212. extraEncoder = &encoderField{ptag, e.typeEncoder(field.Type.Elem()), idx}
  213. continue
  214. }
  215. if ptag.name == "-" {
  216. continue
  217. }
  218. dateFormat, ok := parseFormatStructTag(field)
  219. oldFormat := e.dateFormat
  220. if ok {
  221. switch dateFormat {
  222. case "date-time":
  223. e.dateFormat = time.RFC3339
  224. case "date":
  225. e.dateFormat = "2006-01-02"
  226. }
  227. }
  228. encoderFields = append(encoderFields, encoderField{ptag, e.typeEncoder(field.Type), idx})
  229. e.dateFormat = oldFormat
  230. }
  231. }
  232. collectEncoderFields(t, []int{})
  233. // Ensure deterministic output by sorting by lexicographic order
  234. sort.Slice(encoderFields, func(i, j int) bool {
  235. return encoderFields[i].tag.name < encoderFields[j].tag.name
  236. })
  237. return func(value reflect.Value) (json []byte, err error) {
  238. json = []byte("{}")
  239. for _, ef := range encoderFields {
  240. field := value.FieldByIndex(ef.idx)
  241. encoded, err := ef.fn(field)
  242. if err != nil {
  243. return nil, err
  244. }
  245. if encoded == nil {
  246. continue
  247. }
  248. json, err = sjson.SetRawBytes(json, ef.tag.name, encoded)
  249. if err != nil {
  250. return nil, err
  251. }
  252. }
  253. if extraEncoder != nil {
  254. json, err = e.encodeMapEntries(json, value.FieldByIndex(extraEncoder.idx))
  255. if err != nil {
  256. return nil, err
  257. }
  258. }
  259. return
  260. }
  261. }
  262. func (e *encoder) newFieldTypeEncoder(t reflect.Type) encoderFunc {
  263. f, _ := t.FieldByName("Value")
  264. enc := e.typeEncoder(f.Type)
  265. return func(value reflect.Value) (json []byte, err error) {
  266. present := value.FieldByName("Present")
  267. if !present.Bool() {
  268. return nil, nil
  269. }
  270. null := value.FieldByName("Null")
  271. if null.Bool() {
  272. return []byte("null"), nil
  273. }
  274. raw := value.FieldByName("Raw")
  275. if !raw.IsNil() {
  276. return e.typeEncoder(raw.Type())(raw)
  277. }
  278. return enc(value.FieldByName("Value"))
  279. }
  280. }
  281. func (e *encoder) newTimeTypeEncoder() encoderFunc {
  282. format := e.dateFormat
  283. return func(value reflect.Value) (json []byte, err error) {
  284. return []byte(`"` + value.Convert(reflect.TypeOf(time.Time{})).Interface().(time.Time).Format(format) + `"`), nil
  285. }
  286. }
  287. func (e encoder) newInterfaceEncoder() encoderFunc {
  288. return func(value reflect.Value) ([]byte, error) {
  289. value = value.Elem()
  290. if !value.IsValid() {
  291. return nil, nil
  292. }
  293. return e.typeEncoder(value.Type())(value)
  294. }
  295. }
  296. // Given a []byte of json (may either be an empty object or an object that already contains entries)
  297. // encode all of the entries in the map to the json byte array.
  298. func (e *encoder) encodeMapEntries(json []byte, v reflect.Value) ([]byte, error) {
  299. type mapPair struct {
  300. key []byte
  301. value reflect.Value
  302. }
  303. pairs := []mapPair{}
  304. keyEncoder := e.typeEncoder(v.Type().Key())
  305. iter := v.MapRange()
  306. for iter.Next() {
  307. var encodedKeyString string
  308. if iter.Key().Type().Kind() == reflect.String {
  309. encodedKeyString = iter.Key().String()
  310. } else {
  311. var err error
  312. encodedKeyBytes, err := keyEncoder(iter.Key())
  313. if err != nil {
  314. return nil, err
  315. }
  316. encodedKeyString = string(encodedKeyBytes)
  317. }
  318. encodedKey := []byte(sjsonReplacer.Replace(encodedKeyString))
  319. pairs = append(pairs, mapPair{key: encodedKey, value: iter.Value()})
  320. }
  321. // Ensure deterministic output
  322. sort.Slice(pairs, func(i, j int) bool {
  323. return bytes.Compare(pairs[i].key, pairs[j].key) < 0
  324. })
  325. elementEncoder := e.typeEncoder(v.Type().Elem())
  326. for _, p := range pairs {
  327. encodedValue, err := elementEncoder(p.value)
  328. if err != nil {
  329. return nil, err
  330. }
  331. if len(encodedValue) == 0 {
  332. continue
  333. }
  334. json, err = sjson.SetRawBytes(json, string(p.key), encodedValue)
  335. if err != nil {
  336. return nil, err
  337. }
  338. }
  339. return json, nil
  340. }
  341. func (e *encoder) newMapEncoder(t reflect.Type) encoderFunc {
  342. return func(value reflect.Value) ([]byte, error) {
  343. json := []byte("{}")
  344. var err error
  345. json, err = e.encodeMapEntries(json, value)
  346. if err != nil {
  347. return nil, err
  348. }
  349. return json, nil
  350. }
  351. }
  352. // If we want to set a literal key value into JSON using sjson, we need to make sure it doesn't have
  353. // special characters that sjson interprets as a path.
  354. var sjsonReplacer *strings.Replacer = strings.NewReplacer(".", "\\.", ":", "\\:", "*", "\\*")