assertion_compare.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489
  1. package assert
  2. import (
  3. "bytes"
  4. "fmt"
  5. "reflect"
  6. "time"
  7. )
  8. // Deprecated: CompareType has only ever been for internal use and has accidentally been published since v1.6.0. Do not use it.
  9. type CompareType = compareResult
  10. type compareResult int
  11. const (
  12. compareLess compareResult = iota - 1
  13. compareEqual
  14. compareGreater
  15. )
  16. var (
  17. intType = reflect.TypeOf(int(1))
  18. int8Type = reflect.TypeOf(int8(1))
  19. int16Type = reflect.TypeOf(int16(1))
  20. int32Type = reflect.TypeOf(int32(1))
  21. int64Type = reflect.TypeOf(int64(1))
  22. uintType = reflect.TypeOf(uint(1))
  23. uint8Type = reflect.TypeOf(uint8(1))
  24. uint16Type = reflect.TypeOf(uint16(1))
  25. uint32Type = reflect.TypeOf(uint32(1))
  26. uint64Type = reflect.TypeOf(uint64(1))
  27. uintptrType = reflect.TypeOf(uintptr(1))
  28. float32Type = reflect.TypeOf(float32(1))
  29. float64Type = reflect.TypeOf(float64(1))
  30. stringType = reflect.TypeOf("")
  31. timeType = reflect.TypeOf(time.Time{})
  32. bytesType = reflect.TypeOf([]byte{})
  33. )
  34. func compare(obj1, obj2 interface{}, kind reflect.Kind) (compareResult, bool) {
  35. obj1Value := reflect.ValueOf(obj1)
  36. obj2Value := reflect.ValueOf(obj2)
  37. // throughout this switch we try and avoid calling .Convert() if possible,
  38. // as this has a pretty big performance impact
  39. switch kind {
  40. case reflect.Int:
  41. {
  42. intobj1, ok := obj1.(int)
  43. if !ok {
  44. intobj1 = obj1Value.Convert(intType).Interface().(int)
  45. }
  46. intobj2, ok := obj2.(int)
  47. if !ok {
  48. intobj2 = obj2Value.Convert(intType).Interface().(int)
  49. }
  50. if intobj1 > intobj2 {
  51. return compareGreater, true
  52. }
  53. if intobj1 == intobj2 {
  54. return compareEqual, true
  55. }
  56. if intobj1 < intobj2 {
  57. return compareLess, true
  58. }
  59. }
  60. case reflect.Int8:
  61. {
  62. int8obj1, ok := obj1.(int8)
  63. if !ok {
  64. int8obj1 = obj1Value.Convert(int8Type).Interface().(int8)
  65. }
  66. int8obj2, ok := obj2.(int8)
  67. if !ok {
  68. int8obj2 = obj2Value.Convert(int8Type).Interface().(int8)
  69. }
  70. if int8obj1 > int8obj2 {
  71. return compareGreater, true
  72. }
  73. if int8obj1 == int8obj2 {
  74. return compareEqual, true
  75. }
  76. if int8obj1 < int8obj2 {
  77. return compareLess, true
  78. }
  79. }
  80. case reflect.Int16:
  81. {
  82. int16obj1, ok := obj1.(int16)
  83. if !ok {
  84. int16obj1 = obj1Value.Convert(int16Type).Interface().(int16)
  85. }
  86. int16obj2, ok := obj2.(int16)
  87. if !ok {
  88. int16obj2 = obj2Value.Convert(int16Type).Interface().(int16)
  89. }
  90. if int16obj1 > int16obj2 {
  91. return compareGreater, true
  92. }
  93. if int16obj1 == int16obj2 {
  94. return compareEqual, true
  95. }
  96. if int16obj1 < int16obj2 {
  97. return compareLess, true
  98. }
  99. }
  100. case reflect.Int32:
  101. {
  102. int32obj1, ok := obj1.(int32)
  103. if !ok {
  104. int32obj1 = obj1Value.Convert(int32Type).Interface().(int32)
  105. }
  106. int32obj2, ok := obj2.(int32)
  107. if !ok {
  108. int32obj2 = obj2Value.Convert(int32Type).Interface().(int32)
  109. }
  110. if int32obj1 > int32obj2 {
  111. return compareGreater, true
  112. }
  113. if int32obj1 == int32obj2 {
  114. return compareEqual, true
  115. }
  116. if int32obj1 < int32obj2 {
  117. return compareLess, true
  118. }
  119. }
  120. case reflect.Int64:
  121. {
  122. int64obj1, ok := obj1.(int64)
  123. if !ok {
  124. int64obj1 = obj1Value.Convert(int64Type).Interface().(int64)
  125. }
  126. int64obj2, ok := obj2.(int64)
  127. if !ok {
  128. int64obj2 = obj2Value.Convert(int64Type).Interface().(int64)
  129. }
  130. if int64obj1 > int64obj2 {
  131. return compareGreater, true
  132. }
  133. if int64obj1 == int64obj2 {
  134. return compareEqual, true
  135. }
  136. if int64obj1 < int64obj2 {
  137. return compareLess, true
  138. }
  139. }
  140. case reflect.Uint:
  141. {
  142. uintobj1, ok := obj1.(uint)
  143. if !ok {
  144. uintobj1 = obj1Value.Convert(uintType).Interface().(uint)
  145. }
  146. uintobj2, ok := obj2.(uint)
  147. if !ok {
  148. uintobj2 = obj2Value.Convert(uintType).Interface().(uint)
  149. }
  150. if uintobj1 > uintobj2 {
  151. return compareGreater, true
  152. }
  153. if uintobj1 == uintobj2 {
  154. return compareEqual, true
  155. }
  156. if uintobj1 < uintobj2 {
  157. return compareLess, true
  158. }
  159. }
  160. case reflect.Uint8:
  161. {
  162. uint8obj1, ok := obj1.(uint8)
  163. if !ok {
  164. uint8obj1 = obj1Value.Convert(uint8Type).Interface().(uint8)
  165. }
  166. uint8obj2, ok := obj2.(uint8)
  167. if !ok {
  168. uint8obj2 = obj2Value.Convert(uint8Type).Interface().(uint8)
  169. }
  170. if uint8obj1 > uint8obj2 {
  171. return compareGreater, true
  172. }
  173. if uint8obj1 == uint8obj2 {
  174. return compareEqual, true
  175. }
  176. if uint8obj1 < uint8obj2 {
  177. return compareLess, true
  178. }
  179. }
  180. case reflect.Uint16:
  181. {
  182. uint16obj1, ok := obj1.(uint16)
  183. if !ok {
  184. uint16obj1 = obj1Value.Convert(uint16Type).Interface().(uint16)
  185. }
  186. uint16obj2, ok := obj2.(uint16)
  187. if !ok {
  188. uint16obj2 = obj2Value.Convert(uint16Type).Interface().(uint16)
  189. }
  190. if uint16obj1 > uint16obj2 {
  191. return compareGreater, true
  192. }
  193. if uint16obj1 == uint16obj2 {
  194. return compareEqual, true
  195. }
  196. if uint16obj1 < uint16obj2 {
  197. return compareLess, true
  198. }
  199. }
  200. case reflect.Uint32:
  201. {
  202. uint32obj1, ok := obj1.(uint32)
  203. if !ok {
  204. uint32obj1 = obj1Value.Convert(uint32Type).Interface().(uint32)
  205. }
  206. uint32obj2, ok := obj2.(uint32)
  207. if !ok {
  208. uint32obj2 = obj2Value.Convert(uint32Type).Interface().(uint32)
  209. }
  210. if uint32obj1 > uint32obj2 {
  211. return compareGreater, true
  212. }
  213. if uint32obj1 == uint32obj2 {
  214. return compareEqual, true
  215. }
  216. if uint32obj1 < uint32obj2 {
  217. return compareLess, true
  218. }
  219. }
  220. case reflect.Uint64:
  221. {
  222. uint64obj1, ok := obj1.(uint64)
  223. if !ok {
  224. uint64obj1 = obj1Value.Convert(uint64Type).Interface().(uint64)
  225. }
  226. uint64obj2, ok := obj2.(uint64)
  227. if !ok {
  228. uint64obj2 = obj2Value.Convert(uint64Type).Interface().(uint64)
  229. }
  230. if uint64obj1 > uint64obj2 {
  231. return compareGreater, true
  232. }
  233. if uint64obj1 == uint64obj2 {
  234. return compareEqual, true
  235. }
  236. if uint64obj1 < uint64obj2 {
  237. return compareLess, true
  238. }
  239. }
  240. case reflect.Float32:
  241. {
  242. float32obj1, ok := obj1.(float32)
  243. if !ok {
  244. float32obj1 = obj1Value.Convert(float32Type).Interface().(float32)
  245. }
  246. float32obj2, ok := obj2.(float32)
  247. if !ok {
  248. float32obj2 = obj2Value.Convert(float32Type).Interface().(float32)
  249. }
  250. if float32obj1 > float32obj2 {
  251. return compareGreater, true
  252. }
  253. if float32obj1 == float32obj2 {
  254. return compareEqual, true
  255. }
  256. if float32obj1 < float32obj2 {
  257. return compareLess, true
  258. }
  259. }
  260. case reflect.Float64:
  261. {
  262. float64obj1, ok := obj1.(float64)
  263. if !ok {
  264. float64obj1 = obj1Value.Convert(float64Type).Interface().(float64)
  265. }
  266. float64obj2, ok := obj2.(float64)
  267. if !ok {
  268. float64obj2 = obj2Value.Convert(float64Type).Interface().(float64)
  269. }
  270. if float64obj1 > float64obj2 {
  271. return compareGreater, true
  272. }
  273. if float64obj1 == float64obj2 {
  274. return compareEqual, true
  275. }
  276. if float64obj1 < float64obj2 {
  277. return compareLess, true
  278. }
  279. }
  280. case reflect.String:
  281. {
  282. stringobj1, ok := obj1.(string)
  283. if !ok {
  284. stringobj1 = obj1Value.Convert(stringType).Interface().(string)
  285. }
  286. stringobj2, ok := obj2.(string)
  287. if !ok {
  288. stringobj2 = obj2Value.Convert(stringType).Interface().(string)
  289. }
  290. if stringobj1 > stringobj2 {
  291. return compareGreater, true
  292. }
  293. if stringobj1 == stringobj2 {
  294. return compareEqual, true
  295. }
  296. if stringobj1 < stringobj2 {
  297. return compareLess, true
  298. }
  299. }
  300. // Check for known struct types we can check for compare results.
  301. case reflect.Struct:
  302. {
  303. // All structs enter here. We're not interested in most types.
  304. if !obj1Value.CanConvert(timeType) {
  305. break
  306. }
  307. // time.Time can be compared!
  308. timeObj1, ok := obj1.(time.Time)
  309. if !ok {
  310. timeObj1 = obj1Value.Convert(timeType).Interface().(time.Time)
  311. }
  312. timeObj2, ok := obj2.(time.Time)
  313. if !ok {
  314. timeObj2 = obj2Value.Convert(timeType).Interface().(time.Time)
  315. }
  316. if timeObj1.Before(timeObj2) {
  317. return compareLess, true
  318. }
  319. if timeObj1.Equal(timeObj2) {
  320. return compareEqual, true
  321. }
  322. return compareGreater, true
  323. }
  324. case reflect.Slice:
  325. {
  326. // We only care about the []byte type.
  327. if !obj1Value.CanConvert(bytesType) {
  328. break
  329. }
  330. // []byte can be compared!
  331. bytesObj1, ok := obj1.([]byte)
  332. if !ok {
  333. bytesObj1 = obj1Value.Convert(bytesType).Interface().([]byte)
  334. }
  335. bytesObj2, ok := obj2.([]byte)
  336. if !ok {
  337. bytesObj2 = obj2Value.Convert(bytesType).Interface().([]byte)
  338. }
  339. return compareResult(bytes.Compare(bytesObj1, bytesObj2)), true
  340. }
  341. case reflect.Uintptr:
  342. {
  343. uintptrObj1, ok := obj1.(uintptr)
  344. if !ok {
  345. uintptrObj1 = obj1Value.Convert(uintptrType).Interface().(uintptr)
  346. }
  347. uintptrObj2, ok := obj2.(uintptr)
  348. if !ok {
  349. uintptrObj2 = obj2Value.Convert(uintptrType).Interface().(uintptr)
  350. }
  351. if uintptrObj1 > uintptrObj2 {
  352. return compareGreater, true
  353. }
  354. if uintptrObj1 == uintptrObj2 {
  355. return compareEqual, true
  356. }
  357. if uintptrObj1 < uintptrObj2 {
  358. return compareLess, true
  359. }
  360. }
  361. }
  362. return compareEqual, false
  363. }
  364. // Greater asserts that the first element is greater than the second
  365. //
  366. // assert.Greater(t, 2, 1)
  367. // assert.Greater(t, float64(2), float64(1))
  368. // assert.Greater(t, "b", "a")
  369. func Greater(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool {
  370. if h, ok := t.(tHelper); ok {
  371. h.Helper()
  372. }
  373. return compareTwoValues(t, e1, e2, []compareResult{compareGreater}, "\"%v\" is not greater than \"%v\"", msgAndArgs...)
  374. }
  375. // GreaterOrEqual asserts that the first element is greater than or equal to the second
  376. //
  377. // assert.GreaterOrEqual(t, 2, 1)
  378. // assert.GreaterOrEqual(t, 2, 2)
  379. // assert.GreaterOrEqual(t, "b", "a")
  380. // assert.GreaterOrEqual(t, "b", "b")
  381. func GreaterOrEqual(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool {
  382. if h, ok := t.(tHelper); ok {
  383. h.Helper()
  384. }
  385. return compareTwoValues(t, e1, e2, []compareResult{compareGreater, compareEqual}, "\"%v\" is not greater than or equal to \"%v\"", msgAndArgs...)
  386. }
  387. // Less asserts that the first element is less than the second
  388. //
  389. // assert.Less(t, 1, 2)
  390. // assert.Less(t, float64(1), float64(2))
  391. // assert.Less(t, "a", "b")
  392. func Less(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool {
  393. if h, ok := t.(tHelper); ok {
  394. h.Helper()
  395. }
  396. return compareTwoValues(t, e1, e2, []compareResult{compareLess}, "\"%v\" is not less than \"%v\"", msgAndArgs...)
  397. }
  398. // LessOrEqual asserts that the first element is less than or equal to the second
  399. //
  400. // assert.LessOrEqual(t, 1, 2)
  401. // assert.LessOrEqual(t, 2, 2)
  402. // assert.LessOrEqual(t, "a", "b")
  403. // assert.LessOrEqual(t, "b", "b")
  404. func LessOrEqual(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool {
  405. if h, ok := t.(tHelper); ok {
  406. h.Helper()
  407. }
  408. return compareTwoValues(t, e1, e2, []compareResult{compareLess, compareEqual}, "\"%v\" is not less than or equal to \"%v\"", msgAndArgs...)
  409. }
  410. // Positive asserts that the specified element is positive
  411. //
  412. // assert.Positive(t, 1)
  413. // assert.Positive(t, 1.23)
  414. func Positive(t TestingT, e interface{}, msgAndArgs ...interface{}) bool {
  415. if h, ok := t.(tHelper); ok {
  416. h.Helper()
  417. }
  418. zero := reflect.Zero(reflect.TypeOf(e))
  419. return compareTwoValues(t, e, zero.Interface(), []compareResult{compareGreater}, "\"%v\" is not positive", msgAndArgs...)
  420. }
  421. // Negative asserts that the specified element is negative
  422. //
  423. // assert.Negative(t, -1)
  424. // assert.Negative(t, -1.23)
  425. func Negative(t TestingT, e interface{}, msgAndArgs ...interface{}) bool {
  426. if h, ok := t.(tHelper); ok {
  427. h.Helper()
  428. }
  429. zero := reflect.Zero(reflect.TypeOf(e))
  430. return compareTwoValues(t, e, zero.Interface(), []compareResult{compareLess}, "\"%v\" is not negative", msgAndArgs...)
  431. }
  432. func compareTwoValues(t TestingT, e1 interface{}, e2 interface{}, allowedComparesResults []compareResult, failMessage string, msgAndArgs ...interface{}) bool {
  433. if h, ok := t.(tHelper); ok {
  434. h.Helper()
  435. }
  436. e1Kind := reflect.ValueOf(e1).Kind()
  437. e2Kind := reflect.ValueOf(e2).Kind()
  438. if e1Kind != e2Kind {
  439. return Fail(t, "Elements should be the same type", msgAndArgs...)
  440. }
  441. compareResult, isComparable := compare(e1, e2, e1Kind)
  442. if !isComparable {
  443. return Fail(t, fmt.Sprintf("Can not compare type \"%s\"", reflect.TypeOf(e1)), msgAndArgs...)
  444. }
  445. if !containsValue(allowedComparesResults, compareResult) {
  446. return Fail(t, fmt.Sprintf(failMessage, e1, e2), msgAndArgs...)
  447. }
  448. return true
  449. }
  450. func containsValue(values []compareResult, value compareResult) bool {
  451. for _, v := range values {
  452. if v == value {
  453. return true
  454. }
  455. }
  456. return false
  457. }