compare.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520
  1. // Copyright (c) 2013, Vastech SA (PTY) LTD. All rights reserved.
  2. // http://github.com/gogo/protobuf
  3. //
  4. // Redistribution and use in source and binary forms, with or without
  5. // modification, are permitted provided that the following conditions are
  6. // met:
  7. //
  8. // * Redistributions of source code must retain the above copyright
  9. // notice, this list of conditions and the following disclaimer.
  10. // * Redistributions in binary form must reproduce the above
  11. // copyright notice, this list of conditions and the following disclaimer
  12. // in the documentation and/or other materials provided with the
  13. // distribution.
  14. //
  15. // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
  16. // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
  17. // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
  18. // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
  19. // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
  20. // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
  21. // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
  22. // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
  23. // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
  24. // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
  25. // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  26. package compare
  27. import (
  28. "github.com/gogo/protobuf/gogoproto"
  29. "github.com/gogo/protobuf/proto"
  30. descriptor "github.com/gogo/protobuf/protoc-gen-gogo/descriptor"
  31. "github.com/gogo/protobuf/protoc-gen-gogo/generator"
  32. "github.com/gogo/protobuf/vanity"
  33. )
  34. type plugin struct {
  35. *generator.Generator
  36. generator.PluginImports
  37. fmtPkg generator.Single
  38. bytesPkg generator.Single
  39. sortkeysPkg generator.Single
  40. }
  41. func NewPlugin() *plugin {
  42. return &plugin{}
  43. }
  44. func (p *plugin) Name() string {
  45. return "compare"
  46. }
  47. func (p *plugin) Init(g *generator.Generator) {
  48. p.Generator = g
  49. }
  50. func (p *plugin) Generate(file *generator.FileDescriptor) {
  51. p.PluginImports = generator.NewPluginImports(p.Generator)
  52. p.fmtPkg = p.NewImport("fmt")
  53. p.bytesPkg = p.NewImport("bytes")
  54. p.sortkeysPkg = p.NewImport("github.com/gogo/protobuf/sortkeys")
  55. for _, msg := range file.Messages() {
  56. if msg.DescriptorProto.GetOptions().GetMapEntry() {
  57. continue
  58. }
  59. if gogoproto.HasCompare(file.FileDescriptorProto, msg.DescriptorProto) {
  60. p.generateMessage(file, msg)
  61. }
  62. }
  63. }
  64. func (p *plugin) generateNullableField(fieldname string) {
  65. p.P(`if this.`, fieldname, ` != nil && that1.`, fieldname, ` != nil {`)
  66. p.In()
  67. p.P(`if *this.`, fieldname, ` != *that1.`, fieldname, `{`)
  68. p.In()
  69. p.P(`if *this.`, fieldname, ` < *that1.`, fieldname, `{`)
  70. p.In()
  71. p.P(`return -1`)
  72. p.Out()
  73. p.P(`}`)
  74. p.P(`return 1`)
  75. p.Out()
  76. p.P(`}`)
  77. p.Out()
  78. p.P(`} else if this.`, fieldname, ` != nil {`)
  79. p.In()
  80. p.P(`return 1`)
  81. p.Out()
  82. p.P(`} else if that1.`, fieldname, ` != nil {`)
  83. p.In()
  84. p.P(`return -1`)
  85. p.Out()
  86. p.P(`}`)
  87. }
  88. func (p *plugin) generateMsgNullAndTypeCheck(ccTypeName string) {
  89. p.P(`if that == nil {`)
  90. p.In()
  91. p.P(`if this == nil {`)
  92. p.In()
  93. p.P(`return 0`)
  94. p.Out()
  95. p.P(`}`)
  96. p.P(`return 1`)
  97. p.Out()
  98. p.P(`}`)
  99. p.P(``)
  100. p.P(`that1, ok := that.(*`, ccTypeName, `)`)
  101. p.P(`if !ok {`)
  102. p.In()
  103. p.P(`that2, ok := that.(`, ccTypeName, `)`)
  104. p.P(`if ok {`)
  105. p.In()
  106. p.P(`that1 = &that2`)
  107. p.Out()
  108. p.P(`} else {`)
  109. p.In()
  110. p.P(`return 1`)
  111. p.Out()
  112. p.P(`}`)
  113. p.Out()
  114. p.P(`}`)
  115. p.P(`if that1 == nil {`)
  116. p.In()
  117. p.P(`if this == nil {`)
  118. p.In()
  119. p.P(`return 0`)
  120. p.Out()
  121. p.P(`}`)
  122. p.P(`return 1`)
  123. p.Out()
  124. p.P(`} else if this == nil {`)
  125. p.In()
  126. p.P(`return -1`)
  127. p.Out()
  128. p.P(`}`)
  129. }
  130. func (p *plugin) generateField(file *generator.FileDescriptor, message *generator.Descriptor, field *descriptor.FieldDescriptorProto) {
  131. proto3 := gogoproto.IsProto3(file.FileDescriptorProto)
  132. fieldname := p.GetOneOfFieldName(message, field)
  133. repeated := field.IsRepeated()
  134. ctype := gogoproto.IsCustomType(field)
  135. nullable := gogoproto.IsNullable(field)
  136. // oneof := field.OneofIndex != nil
  137. if !repeated {
  138. if ctype {
  139. if nullable {
  140. p.P(`if that1.`, fieldname, ` == nil {`)
  141. p.In()
  142. p.P(`if this.`, fieldname, ` != nil {`)
  143. p.In()
  144. p.P(`return 1`)
  145. p.Out()
  146. p.P(`}`)
  147. p.Out()
  148. p.P(`} else if this.`, fieldname, ` == nil {`)
  149. p.In()
  150. p.P(`return -1`)
  151. p.Out()
  152. p.P(`} else if c := this.`, fieldname, `.Compare(*that1.`, fieldname, `); c != 0 {`)
  153. } else {
  154. p.P(`if c := this.`, fieldname, `.Compare(that1.`, fieldname, `); c != 0 {`)
  155. }
  156. p.In()
  157. p.P(`return c`)
  158. p.Out()
  159. p.P(`}`)
  160. } else {
  161. if field.IsMessage() || p.IsGroup(field) {
  162. if nullable {
  163. p.P(`if c := this.`, fieldname, `.Compare(that1.`, fieldname, `); c != 0 {`)
  164. } else {
  165. p.P(`if c := this.`, fieldname, `.Compare(&that1.`, fieldname, `); c != 0 {`)
  166. }
  167. p.In()
  168. p.P(`return c`)
  169. p.Out()
  170. p.P(`}`)
  171. } else if field.IsBytes() {
  172. p.P(`if c := `, p.bytesPkg.Use(), `.Compare(this.`, fieldname, `, that1.`, fieldname, `); c != 0 {`)
  173. p.In()
  174. p.P(`return c`)
  175. p.Out()
  176. p.P(`}`)
  177. } else if field.IsString() {
  178. if nullable && !proto3 {
  179. p.generateNullableField(fieldname)
  180. } else {
  181. p.P(`if this.`, fieldname, ` != that1.`, fieldname, `{`)
  182. p.In()
  183. p.P(`if this.`, fieldname, ` < that1.`, fieldname, `{`)
  184. p.In()
  185. p.P(`return -1`)
  186. p.Out()
  187. p.P(`}`)
  188. p.P(`return 1`)
  189. p.Out()
  190. p.P(`}`)
  191. }
  192. } else if field.IsBool() {
  193. if nullable && !proto3 {
  194. p.P(`if this.`, fieldname, ` != nil && that1.`, fieldname, ` != nil {`)
  195. p.In()
  196. p.P(`if *this.`, fieldname, ` != *that1.`, fieldname, `{`)
  197. p.In()
  198. p.P(`if !*this.`, fieldname, ` {`)
  199. p.In()
  200. p.P(`return -1`)
  201. p.Out()
  202. p.P(`}`)
  203. p.P(`return 1`)
  204. p.Out()
  205. p.P(`}`)
  206. p.Out()
  207. p.P(`} else if this.`, fieldname, ` != nil {`)
  208. p.In()
  209. p.P(`return 1`)
  210. p.Out()
  211. p.P(`} else if that1.`, fieldname, ` != nil {`)
  212. p.In()
  213. p.P(`return -1`)
  214. p.Out()
  215. p.P(`}`)
  216. } else {
  217. p.P(`if this.`, fieldname, ` != that1.`, fieldname, `{`)
  218. p.In()
  219. p.P(`if !this.`, fieldname, ` {`)
  220. p.In()
  221. p.P(`return -1`)
  222. p.Out()
  223. p.P(`}`)
  224. p.P(`return 1`)
  225. p.Out()
  226. p.P(`}`)
  227. }
  228. } else {
  229. if nullable && !proto3 {
  230. p.generateNullableField(fieldname)
  231. } else {
  232. p.P(`if this.`, fieldname, ` != that1.`, fieldname, `{`)
  233. p.In()
  234. p.P(`if this.`, fieldname, ` < that1.`, fieldname, `{`)
  235. p.In()
  236. p.P(`return -1`)
  237. p.Out()
  238. p.P(`}`)
  239. p.P(`return 1`)
  240. p.Out()
  241. p.P(`}`)
  242. }
  243. }
  244. }
  245. } else {
  246. p.P(`if len(this.`, fieldname, `) != len(that1.`, fieldname, `) {`)
  247. p.In()
  248. p.P(`if len(this.`, fieldname, `) < len(that1.`, fieldname, `) {`)
  249. p.In()
  250. p.P(`return -1`)
  251. p.Out()
  252. p.P(`}`)
  253. p.P(`return 1`)
  254. p.Out()
  255. p.P(`}`)
  256. p.P(`for i := range this.`, fieldname, ` {`)
  257. p.In()
  258. if ctype {
  259. p.P(`if c := this.`, fieldname, `[i].Compare(that1.`, fieldname, `[i]); c != 0 {`)
  260. p.In()
  261. p.P(`return c`)
  262. p.Out()
  263. p.P(`}`)
  264. } else {
  265. if p.IsMap(field) {
  266. m := p.GoMapType(nil, field)
  267. valuegoTyp, _ := p.GoType(nil, m.ValueField)
  268. valuegoAliasTyp, _ := p.GoType(nil, m.ValueAliasField)
  269. nullable, valuegoTyp, valuegoAliasTyp = generator.GoMapValueTypes(field, m.ValueField, valuegoTyp, valuegoAliasTyp)
  270. mapValue := m.ValueAliasField
  271. if mapValue.IsMessage() || p.IsGroup(mapValue) {
  272. if nullable && valuegoTyp == valuegoAliasTyp {
  273. p.P(`if c := this.`, fieldname, `[i].Compare(that1.`, fieldname, `[i]); c != 0 {`)
  274. } else {
  275. // Compare() has a pointer receiver, but map value is a value type
  276. a := `this.` + fieldname + `[i]`
  277. b := `that1.` + fieldname + `[i]`
  278. if valuegoTyp != valuegoAliasTyp {
  279. // cast back to the type that has the generated methods on it
  280. a = `(` + valuegoTyp + `)(` + a + `)`
  281. b = `(` + valuegoTyp + `)(` + b + `)`
  282. }
  283. p.P(`a := `, a)
  284. p.P(`b := `, b)
  285. if nullable {
  286. p.P(`if c := a.Compare(b); c != 0 {`)
  287. } else {
  288. p.P(`if c := (&a).Compare(&b); c != 0 {`)
  289. }
  290. }
  291. p.In()
  292. p.P(`return c`)
  293. p.Out()
  294. p.P(`}`)
  295. } else if mapValue.IsBytes() {
  296. p.P(`if c := `, p.bytesPkg.Use(), `.Compare(this.`, fieldname, `[i], that1.`, fieldname, `[i]); c != 0 {`)
  297. p.In()
  298. p.P(`return c`)
  299. p.Out()
  300. p.P(`}`)
  301. } else if mapValue.IsString() {
  302. p.P(`if this.`, fieldname, `[i] != that1.`, fieldname, `[i] {`)
  303. p.In()
  304. p.P(`if this.`, fieldname, `[i] < that1.`, fieldname, `[i] {`)
  305. p.In()
  306. p.P(`return -1`)
  307. p.Out()
  308. p.P(`}`)
  309. p.P(`return 1`)
  310. p.Out()
  311. p.P(`}`)
  312. } else {
  313. p.P(`if this.`, fieldname, `[i] != that1.`, fieldname, `[i] {`)
  314. p.In()
  315. p.P(`if this.`, fieldname, `[i] < that1.`, fieldname, `[i] {`)
  316. p.In()
  317. p.P(`return -1`)
  318. p.Out()
  319. p.P(`}`)
  320. p.P(`return 1`)
  321. p.Out()
  322. p.P(`}`)
  323. }
  324. } else if field.IsMessage() || p.IsGroup(field) {
  325. if nullable {
  326. p.P(`if c := this.`, fieldname, `[i].Compare(that1.`, fieldname, `[i]); c != 0 {`)
  327. p.In()
  328. p.P(`return c`)
  329. p.Out()
  330. p.P(`}`)
  331. } else {
  332. p.P(`if c := this.`, fieldname, `[i].Compare(&that1.`, fieldname, `[i]); c != 0 {`)
  333. p.In()
  334. p.P(`return c`)
  335. p.Out()
  336. p.P(`}`)
  337. }
  338. } else if field.IsBytes() {
  339. p.P(`if c := `, p.bytesPkg.Use(), `.Compare(this.`, fieldname, `[i], that1.`, fieldname, `[i]); c != 0 {`)
  340. p.In()
  341. p.P(`return c`)
  342. p.Out()
  343. p.P(`}`)
  344. } else if field.IsString() {
  345. p.P(`if this.`, fieldname, `[i] != that1.`, fieldname, `[i] {`)
  346. p.In()
  347. p.P(`if this.`, fieldname, `[i] < that1.`, fieldname, `[i] {`)
  348. p.In()
  349. p.P(`return -1`)
  350. p.Out()
  351. p.P(`}`)
  352. p.P(`return 1`)
  353. p.Out()
  354. p.P(`}`)
  355. } else if field.IsBool() {
  356. p.P(`if this.`, fieldname, `[i] != that1.`, fieldname, `[i] {`)
  357. p.In()
  358. p.P(`if !this.`, fieldname, `[i] {`)
  359. p.In()
  360. p.P(`return -1`)
  361. p.Out()
  362. p.P(`}`)
  363. p.P(`return 1`)
  364. p.Out()
  365. p.P(`}`)
  366. } else {
  367. p.P(`if this.`, fieldname, `[i] != that1.`, fieldname, `[i] {`)
  368. p.In()
  369. p.P(`if this.`, fieldname, `[i] < that1.`, fieldname, `[i] {`)
  370. p.In()
  371. p.P(`return -1`)
  372. p.Out()
  373. p.P(`}`)
  374. p.P(`return 1`)
  375. p.Out()
  376. p.P(`}`)
  377. }
  378. }
  379. p.Out()
  380. p.P(`}`)
  381. }
  382. }
  383. func (p *plugin) generateMessage(file *generator.FileDescriptor, message *generator.Descriptor) {
  384. ccTypeName := generator.CamelCaseSlice(message.TypeName())
  385. p.P(`func (this *`, ccTypeName, `) Compare(that interface{}) int {`)
  386. p.In()
  387. p.generateMsgNullAndTypeCheck(ccTypeName)
  388. oneofs := make(map[string]struct{})
  389. for _, field := range message.Field {
  390. oneof := field.OneofIndex != nil
  391. if oneof {
  392. fieldname := p.GetFieldName(message, field)
  393. if _, ok := oneofs[fieldname]; ok {
  394. continue
  395. } else {
  396. oneofs[fieldname] = struct{}{}
  397. }
  398. p.P(`if that1.`, fieldname, ` == nil {`)
  399. p.In()
  400. p.P(`if this.`, fieldname, ` != nil {`)
  401. p.In()
  402. p.P(`return 1`)
  403. p.Out()
  404. p.P(`}`)
  405. p.Out()
  406. p.P(`} else if this.`, fieldname, ` == nil {`)
  407. p.In()
  408. p.P(`return -1`)
  409. p.Out()
  410. p.P(`} else if c := this.`, fieldname, `.Compare(that1.`, fieldname, `); c != 0 {`)
  411. p.In()
  412. p.P(`return c`)
  413. p.Out()
  414. p.P(`}`)
  415. } else {
  416. p.generateField(file, message, field)
  417. }
  418. }
  419. if message.DescriptorProto.HasExtension() {
  420. fieldname := "XXX_extensions"
  421. if gogoproto.HasExtensionsMap(file.FileDescriptorProto, message.DescriptorProto) {
  422. p.P(`extkeys := make([]int32, 0, len(this.`, fieldname, `)+len(that1.`, fieldname, `))`)
  423. p.P(`for k, _ := range this.`, fieldname, ` {`)
  424. p.In()
  425. p.P(`extkeys = append(extkeys, k)`)
  426. p.Out()
  427. p.P(`}`)
  428. p.P(`for k, _ := range that1.`, fieldname, ` {`)
  429. p.In()
  430. p.P(`if _, ok := this.`, fieldname, `[k]; !ok {`)
  431. p.In()
  432. p.P(`extkeys = append(extkeys, k)`)
  433. p.Out()
  434. p.P(`}`)
  435. p.Out()
  436. p.P(`}`)
  437. p.P(p.sortkeysPkg.Use(), `.Int32s(extkeys)`)
  438. p.P(`for _, k := range extkeys {`)
  439. p.In()
  440. p.P(`if v, ok := this.`, fieldname, `[k]; ok {`)
  441. p.In()
  442. p.P(`if v2, ok := that1.`, fieldname, `[k]; ok {`)
  443. p.In()
  444. p.P(`if c := v.Compare(&v2); c != 0 {`)
  445. p.In()
  446. p.P(`return c`)
  447. p.Out()
  448. p.P(`}`)
  449. p.Out()
  450. p.P(`} else {`)
  451. p.In()
  452. p.P(`return 1`)
  453. p.Out()
  454. p.P(`}`)
  455. p.Out()
  456. p.P(`} else {`)
  457. p.In()
  458. p.P(`return -1`)
  459. p.Out()
  460. p.P(`}`)
  461. p.Out()
  462. p.P(`}`)
  463. } else {
  464. p.P(`if c := `, p.bytesPkg.Use(), `.Compare(this.`, fieldname, `, that1.`, fieldname, `); c != 0 {`)
  465. p.In()
  466. p.P(`return c`)
  467. p.Out()
  468. p.P(`}`)
  469. }
  470. }
  471. if gogoproto.HasUnrecognized(file.FileDescriptorProto, message.DescriptorProto) {
  472. fieldname := "XXX_unrecognized"
  473. p.P(`if c := `, p.bytesPkg.Use(), `.Compare(this.`, fieldname, `, that1.`, fieldname, `); c != 0 {`)
  474. p.In()
  475. p.P(`return c`)
  476. p.Out()
  477. p.P(`}`)
  478. }
  479. p.P(`return 0`)
  480. p.Out()
  481. p.P(`}`)
  482. //Generate Compare methods for oneof fields
  483. m := proto.Clone(message.DescriptorProto).(*descriptor.DescriptorProto)
  484. for _, field := range m.Field {
  485. oneof := field.OneofIndex != nil
  486. if !oneof {
  487. continue
  488. }
  489. ccTypeName := p.OneOfTypeName(message, field)
  490. p.P(`func (this *`, ccTypeName, `) Compare(that interface{}) int {`)
  491. p.In()
  492. p.generateMsgNullAndTypeCheck(ccTypeName)
  493. vanity.TurnOffNullableForNativeTypesWithoutDefaultsOnly(field)
  494. p.generateField(file, message, field)
  495. p.P(`return 0`)
  496. p.Out()
  497. p.P(`}`)
  498. }
  499. }
  500. func init() {
  501. generator.RegisterPlugin(NewPlugin())
  502. }