messagediff.go 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242
  1. package messagediff
  2. import (
  3. "fmt"
  4. "reflect"
  5. "sort"
  6. "strings"
  7. "unsafe"
  8. )
  9. // PrettyDiff does a deep comparison and returns the nicely formated results.
  10. func PrettyDiff(a, b interface{}) (string, bool) {
  11. d, equal := DeepDiff(a, b)
  12. var dstr []string
  13. for path, added := range d.Added {
  14. dstr = append(dstr, fmt.Sprintf("added: %s = %#v\n", path.String(), added))
  15. }
  16. for path, removed := range d.Removed {
  17. dstr = append(dstr, fmt.Sprintf("removed: %s = %#v\n", path.String(), removed))
  18. }
  19. for path, modified := range d.Modified {
  20. dstr = append(dstr, fmt.Sprintf("modified: %s = %#v\n", path.String(), modified))
  21. }
  22. sort.Strings(dstr)
  23. return strings.Join(dstr, ""), equal
  24. }
  25. // DeepDiff does a deep comparison and returns the results.
  26. func DeepDiff(a, b interface{}) (*Diff, bool) {
  27. d := newDiff()
  28. return d, d.diff(reflect.ValueOf(a), reflect.ValueOf(b), nil)
  29. }
  30. func newDiff() *Diff {
  31. return &Diff{
  32. Added: make(map[*Path]interface{}),
  33. Removed: make(map[*Path]interface{}),
  34. Modified: make(map[*Path]interface{}),
  35. visited: make(map[visit]bool),
  36. }
  37. }
  38. func (d *Diff) diff(aVal, bVal reflect.Value, path Path) bool {
  39. // The array underlying `path` could be modified in subsequent
  40. // calls. Make sure we have a local copy.
  41. localPath := make(Path, len(path))
  42. copy(localPath, path)
  43. // Validity checks. Should only trigger if nil is one of the original arguments.
  44. if !aVal.IsValid() && !bVal.IsValid() {
  45. return true
  46. }
  47. if !bVal.IsValid() {
  48. d.Modified[&localPath] = nil
  49. return false
  50. } else if !aVal.IsValid() {
  51. d.Modified[&localPath] = bVal.Interface()
  52. return false
  53. }
  54. if aVal.Type() != bVal.Type() {
  55. d.Modified[&localPath] = bVal.Interface()
  56. return false
  57. }
  58. kind := aVal.Kind()
  59. // Borrowed from the reflect package to handle recursive data structures.
  60. hard := func(k reflect.Kind) bool {
  61. switch k {
  62. case reflect.Array, reflect.Map, reflect.Slice, reflect.Struct:
  63. return true
  64. }
  65. return false
  66. }
  67. if aVal.CanAddr() && bVal.CanAddr() && hard(kind) {
  68. addr1 := unsafe.Pointer(aVal.UnsafeAddr())
  69. addr2 := unsafe.Pointer(bVal.UnsafeAddr())
  70. if uintptr(addr1) > uintptr(addr2) {
  71. // Canonicalize order to reduce number of entries in visited.
  72. // Assumes non-moving garbage collector.
  73. addr1, addr2 = addr2, addr1
  74. }
  75. // Short circuit if references are already seen.
  76. typ := aVal.Type()
  77. v := visit{addr1, addr2, typ}
  78. if d.visited[v] {
  79. return true
  80. }
  81. // Remember for later.
  82. d.visited[v] = true
  83. }
  84. // End of borrowed code.
  85. equal := true
  86. switch kind {
  87. case reflect.Map, reflect.Ptr, reflect.Func, reflect.Chan, reflect.Slice:
  88. if aVal.IsNil() && bVal.IsNil() {
  89. return true
  90. }
  91. if aVal.IsNil() || bVal.IsNil() {
  92. d.Modified[&localPath] = bVal.Interface()
  93. return false
  94. }
  95. }
  96. switch kind {
  97. case reflect.Array, reflect.Slice:
  98. aLen := aVal.Len()
  99. bLen := bVal.Len()
  100. for i := 0; i < min(aLen, bLen); i++ {
  101. localPath := append(localPath, SliceIndex(i))
  102. if eq := d.diff(aVal.Index(i), bVal.Index(i), localPath); !eq {
  103. equal = false
  104. }
  105. }
  106. if aLen > bLen {
  107. for i := bLen; i < aLen; i++ {
  108. localPath := append(localPath, SliceIndex(i))
  109. d.Removed[&localPath] = aVal.Index(i).Interface()
  110. equal = false
  111. }
  112. } else if aLen < bLen {
  113. for i := aLen; i < bLen; i++ {
  114. localPath := append(localPath, SliceIndex(i))
  115. d.Added[&localPath] = bVal.Index(i).Interface()
  116. equal = false
  117. }
  118. }
  119. case reflect.Map:
  120. for _, key := range aVal.MapKeys() {
  121. aI := aVal.MapIndex(key)
  122. bI := bVal.MapIndex(key)
  123. localPath := append(localPath, MapKey{key.Interface()})
  124. if !bI.IsValid() {
  125. d.Removed[&localPath] = aI.Interface()
  126. equal = false
  127. } else if eq := d.diff(aI, bI, localPath); !eq {
  128. equal = false
  129. }
  130. }
  131. for _, key := range bVal.MapKeys() {
  132. aI := aVal.MapIndex(key)
  133. if !aI.IsValid() {
  134. bI := bVal.MapIndex(key)
  135. localPath := append(localPath, MapKey{key.Interface()})
  136. d.Added[&localPath] = bI.Interface()
  137. equal = false
  138. }
  139. }
  140. case reflect.Struct:
  141. typ := aVal.Type()
  142. for i := 0; i < typ.NumField(); i++ {
  143. index := []int{i}
  144. field := typ.FieldByIndex(index)
  145. if field.Tag.Get("testdiff") == "ignore" { // skip fields marked to be ignored
  146. continue
  147. }
  148. localPath := append(localPath, StructField(field.Name))
  149. aI := unsafeReflectValue(aVal.FieldByIndex(index))
  150. bI := unsafeReflectValue(bVal.FieldByIndex(index))
  151. if eq := d.diff(aI, bI, localPath); !eq {
  152. equal = false
  153. }
  154. }
  155. case reflect.Ptr:
  156. equal = d.diff(aVal.Elem(), bVal.Elem(), localPath)
  157. default:
  158. if reflect.DeepEqual(aVal.Interface(), bVal.Interface()) {
  159. equal = true
  160. } else {
  161. d.Modified[&localPath] = bVal.Interface()
  162. equal = false
  163. }
  164. }
  165. return equal
  166. }
  // min returns the smaller of a and b.
  func min(a, b int) int {
  	if b < a {
  		return b
  	}
  	return a
  }
  // visit identifies a pair of addressable composite values (arrays, maps,
  // slices or structs) of the same type whose comparison diff has already
  // started. diff records these pairs in Diff.visited and treats a repeated
  // pair as equal, which lets it terminate on recursive data structures.
  // diff stores the lower of the two addresses in a1.
  // This is borrowed from the reflect package.
  type visit struct {
  	a1, a2 unsafe.Pointer // addresses of the two compared values
  	typ    reflect.Type   // type shared by both values
  }
  183. // Diff represents a change in a struct.
  184. type Diff struct {
  185. Added, Removed, Modified map[*Path]interface{}
  186. visited map[visit]bool
  187. }
  188. // Path represents a path to a changed datum.
  189. type Path []PathNode
  190. func (p Path) String() string {
  191. var out string
  192. for _, n := range p {
  193. out += n.String()
  194. }
  195. return out
  196. }
  // PathNode is a single step in a Path, such as a struct field, a map key or
  // a slice index. Its String method renders the step as it appears in a
  // printed path.
  type PathNode interface {
  	fmt.Stringer
  }
  // StructField is a path element naming a field of a struct.
  type StructField string

  // String returns the field name preceded by a dot, e.g. ".Name".
  func (n StructField) String() string {
  	return "." + string(n)
  }
  // MapKey is a path element identifying a map entry by its key.
  type MapKey struct {
  	Key interface{}
  }

  // String returns the key in Go syntax (%#v) wrapped in square brackets,
  // e.g. `["name"]` for a string key.
  func (n MapKey) String() string {
  	return "[" + fmt.Sprintf("%#v", n.Key) + "]"
  }
  // SliceIndex is a path element identifying an element of a slice or array
  // by its position.
  type SliceIndex int

  // String returns the decimal index wrapped in square brackets, e.g. "[3]".
  func (n SliceIndex) String() string {
  	// Convert to int first: formatting n itself with a default (%v-style)
  	// verb would invoke this String method again.
  	return "[" + fmt.Sprint(int(n)) + "]"
  }