main.go 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393
  1. package main
  2. import (
  3. "bytes"
  4. "flag"
  5. "fmt"
  6. "go/ast"
  7. "go/format"
  8. "go/parser"
  9. "go/token"
  10. "os"
  11. "regexp"
  12. "strconv"
  13. "strings"
  14. "text/template"
  15. )
  16. var output string
  17. type field struct {
  18. Name string
  19. IsBasic bool
  20. IsSlice bool
  21. IsMap bool
  22. FieldType string
  23. KeyType string
  24. Encoder string
  25. Convert string
  26. Max int
  27. }
// headerTpl is the preamble written once, before any generated methods, when
// -output is "code" (see main). Package is filled with the package name of
// the parsed input file. The imports are the ones the methods generated from
// encodeTpl refer to: bytes, io and the xdr package.
  28. var headerTpl = template.Must(template.New("header").Parse(`package {{.Package}}
  29. import (
  30. "bytes"
  31. "io"
  32. "github.com/calmh/syncthing/xdr"
  33. )
  34. `))
// encodeTpl expands, for the struct type TypeName with fields Fields
// ([]field), into six methods: EncodeXDR, MarshalXDR and encodeXDR on the
// value receiver, and DecodeXDR, UnmarshalXDR and decodeXDR on the pointer
// receiver. The exported ones wrap an io.Writer / io.Reader / byte slice in
// an xdr Writer or Reader and call the unexported one.
//
// Fields are encoded and decoded in slice order. For each field:
//   - Convert non-empty: written as Convert(value) with xw.Write<Encoder>,
//     read back with xr.Read<Encoder> and converted to FieldType.
//   - other IsBasic fields: written/read directly; when Max >= 1 the length
//     is checked on encode (returning xdr.ErrElementSizeExceeded) and read
//     with xr.Read<Encoder>Max on decode.
//   - non-basic fields: delegated to the field type's encodeXDR/decodeXDR.
//   - IsSlice fields: a uint32 element count (rejected with
//     xdr.ErrElementSizeExceeded when Max >= 1 and exceeded), then every
//     element as above, with no per-element Max check.
//
// The raw output is not gofmt-clean: generateCode collapses the blank lines
// the template actions leave behind, turns each "//+n" marker into a blank
// line between methods, and runs go/format over the result.
//
// NOTE(review): when Max is 0 the decoder sizes the slice directly from the
// count read off the stream (up to 2^32-1 elements), so untrusted input can
// force very large allocations.
// NOTE(review): MarshalXDR discards the error returned by encodeXDR.
  35. var encodeTpl = template.Must(template.New("encoder").Parse(`
  36. func (o {{.TypeName}}) EncodeXDR(w io.Writer) (int, error) {
  37. var xw = xdr.NewWriter(w)
  38. return o.encodeXDR(xw)
  39. }//+n
  40. func (o {{.TypeName}}) MarshalXDR() []byte {
  41. var buf bytes.Buffer
  42. var xw = xdr.NewWriter(&buf)
  43. o.encodeXDR(xw)
  44. return buf.Bytes()
  45. }//+n
  46. func (o {{.TypeName}}) encodeXDR(xw *xdr.Writer) (int, error) {
  47. {{range $field := .Fields}}
  48. {{if not $field.IsSlice}}
  49. {{if ne $field.Convert ""}}
  50. xw.Write{{$field.Encoder}}({{$field.Convert}}(o.{{$field.Name}}))
  51. {{else if $field.IsBasic}}
  52. {{if ge $field.Max 1}}
  53. if len(o.{{$field.Name}}) > {{$field.Max}} {
  54. return xw.Tot(), xdr.ErrElementSizeExceeded
  55. }
  56. {{end}}
  57. xw.Write{{$field.Encoder}}(o.{{$field.Name}})
  58. {{else}}
  59. o.{{$field.Name}}.encodeXDR(xw)
  60. {{end}}
  61. {{else}}
  62. {{if ge $field.Max 1}}
  63. if len(o.{{$field.Name}}) > {{$field.Max}} {
  64. return xw.Tot(), xdr.ErrElementSizeExceeded
  65. }
  66. {{end}}
  67. xw.WriteUint32(uint32(len(o.{{$field.Name}})))
  68. for i := range o.{{$field.Name}} {
  69. {{if ne $field.Convert ""}}
  70. xw.Write{{$field.Encoder}}({{$field.Convert}}(o.{{$field.Name}}[i]))
  71. {{else if $field.IsBasic}}
  72. xw.Write{{$field.Encoder}}(o.{{$field.Name}}[i])
  73. {{else}}
  74. o.{{$field.Name}}[i].encodeXDR(xw)
  75. {{end}}
  76. }
  77. {{end}}
  78. {{end}}
  79. return xw.Tot(), xw.Error()
  80. }//+n
  81. func (o *{{.TypeName}}) DecodeXDR(r io.Reader) error {
  82. xr := xdr.NewReader(r)
  83. return o.decodeXDR(xr)
  84. }//+n
  85. func (o *{{.TypeName}}) UnmarshalXDR(bs []byte) error {
  86. var buf = bytes.NewBuffer(bs)
  87. var xr = xdr.NewReader(buf)
  88. return o.decodeXDR(xr)
  89. }//+n
  90. func (o *{{.TypeName}}) decodeXDR(xr *xdr.Reader) error {
  91. {{range $field := .Fields}}
  92. {{if not $field.IsSlice}}
  93. {{if ne $field.Convert ""}}
  94. o.{{$field.Name}} = {{$field.FieldType}}(xr.Read{{$field.Encoder}}())
  95. {{else if $field.IsBasic}}
  96. {{if ge $field.Max 1}}
  97. o.{{$field.Name}} = xr.Read{{$field.Encoder}}Max({{$field.Max}})
  98. {{else}}
  99. o.{{$field.Name}} = xr.Read{{$field.Encoder}}()
  100. {{end}}
  101. {{else}}
  102. (&o.{{$field.Name}}).decodeXDR(xr)
  103. {{end}}
  104. {{else}}
  105. _{{$field.Name}}Size := int(xr.ReadUint32())
  106. {{if ge $field.Max 1}}
  107. if _{{$field.Name}}Size > {{$field.Max}} {
  108. return xdr.ErrElementSizeExceeded
  109. }
  110. {{end}}
  111. o.{{$field.Name}} = make([]{{$field.FieldType}}, _{{$field.Name}}Size)
  112. for i := range o.{{$field.Name}} {
  113. {{if ne $field.Convert ""}}
  114. o.{{$field.Name}}[i] = {{$field.FieldType}}(xr.Read{{$field.Encoder}}())
  115. {{else if $field.IsBasic}}
  116. o.{{$field.Name}}[i] = xr.Read{{$field.Encoder}}()
  117. {{else}}
  118. (&o.{{$field.Name}}[i]).decodeXDR(xr)
  119. {{end}}
  120. }
  121. {{end}}
  122. {{end}}
  123. return xr.Error()
  124. }`))
  125. var maxRe = regexp.MustCompile(`\Wmax:(\d+)`)
  126. type typeSet struct {
  127. Type string
  128. Encoder string
  129. }
  130. var xdrEncoders = map[string]typeSet{
  131. "int16": typeSet{"uint16", "Uint16"},
  132. "uint16": typeSet{"", "Uint16"},
  133. "int32": typeSet{"uint32", "Uint32"},
  134. "uint32": typeSet{"", "Uint32"},
  135. "int64": typeSet{"uint64", "Uint64"},
  136. "uint64": typeSet{"", "Uint64"},
  137. "int": typeSet{"uint64", "Uint64"},
  138. "string": typeSet{"", "String"},
  139. "[]byte": typeSet{"", "Bytes"},
  140. "bool": typeSet{"", "Bool"},
  141. }
  142. func handleStruct(name string, t *ast.StructType) {
  143. var fs []field
  144. for _, sf := range t.Fields.List {
  145. if len(sf.Names) == 0 {
  146. // We don't handle anonymous fields
  147. continue
  148. }
  149. fn := sf.Names[0].Name
  150. var max = 0
  151. if sf.Comment != nil {
  152. c := sf.Comment.List[0].Text
  153. if m := maxRe.FindStringSubmatch(c); m != nil {
  154. max, _ = strconv.Atoi(m[1])
  155. }
  156. }
  157. var f field
  158. switch ft := sf.Type.(type) {
  159. case *ast.Ident:
  160. tn := ft.Name
  161. if enc, ok := xdrEncoders[tn]; ok {
  162. f = field{
  163. Name: fn,
  164. IsBasic: true,
  165. FieldType: tn,
  166. Encoder: enc.Encoder,
  167. Convert: enc.Type,
  168. Max: max,
  169. }
  170. } else {
  171. f = field{
  172. Name: fn,
  173. IsBasic: false,
  174. FieldType: tn,
  175. Max: max,
  176. }
  177. }
  178. case *ast.ArrayType:
  179. if ft.Len != nil {
  180. // We don't handle arrays
  181. continue
  182. }
  183. tn := ft.Elt.(*ast.Ident).Name
  184. if enc, ok := xdrEncoders["[]"+tn]; ok {
  185. f = field{
  186. Name: fn,
  187. IsBasic: true,
  188. FieldType: tn,
  189. Encoder: enc.Encoder,
  190. Convert: enc.Type,
  191. Max: max,
  192. }
  193. } else if enc, ok := xdrEncoders[tn]; ok {
  194. f = field{
  195. Name: fn,
  196. IsBasic: true,
  197. IsSlice: true,
  198. FieldType: tn,
  199. Encoder: enc.Encoder,
  200. Convert: enc.Type,
  201. Max: max,
  202. }
  203. } else {
  204. f = field{
  205. Name: fn,
  206. IsBasic: false,
  207. IsSlice: true,
  208. FieldType: tn,
  209. Max: max,
  210. }
  211. }
  212. }
  213. fs = append(fs, f)
  214. }
  215. switch output {
  216. case "code":
  217. generateCode(name, fs)
  218. case "diagram":
  219. generateDiagram(name, fs)
  220. case "xdr":
  221. generateXdr(name, fs)
  222. }
  223. }
  224. func generateCode(name string, fs []field) {
  225. var buf bytes.Buffer
  226. err := encodeTpl.Execute(&buf, map[string]interface{}{"TypeName": name, "Fields": fs})
  227. if err != nil {
  228. panic(err)
  229. }
  230. bs := regexp.MustCompile(`(\s*\n)+`).ReplaceAll(buf.Bytes(), []byte("\n"))
  231. bs = bytes.Replace(bs, []byte("//+n"), []byte("\n"), -1)
  232. bs, err = format.Source(bs)
  233. if err != nil {
  234. panic(err)
  235. }
  236. fmt.Println(string(bs))
  237. }
  238. func generateDiagram(sn string, fs []field) {
  239. fmt.Println(sn + " Structure:")
  240. fmt.Println()
  241. fmt.Println(" 0 1 2 3")
  242. fmt.Println(" 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1")
  243. line := "+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+"
  244. fmt.Println(line)
  245. for _, f := range fs {
  246. tn := f.FieldType
  247. sl := f.IsSlice
  248. if sl {
  249. fmt.Printf("| %s |\n", center("Number of "+f.Name, 61))
  250. fmt.Println(line)
  251. }
  252. switch tn {
  253. case "uint16":
  254. fmt.Printf("| %s | %s |\n", center(f.Name, 29), center("0x0000", 29))
  255. fmt.Println(line)
  256. case "uint32":
  257. fmt.Printf("| %s |\n", center(f.Name, 61))
  258. fmt.Println(line)
  259. case "int64", "uint64":
  260. fmt.Printf("| %-61s |\n", "")
  261. fmt.Printf("+ %s +\n", center(f.Name+" (64 bits)", 61))
  262. fmt.Printf("| %-61s |\n", "")
  263. fmt.Println(line)
  264. case "string", "byte": // XXX We assume slice of byte!
  265. fmt.Printf("| %s |\n", center("Length of "+f.Name, 61))
  266. fmt.Println(line)
  267. fmt.Printf("/ %61s /\n", "")
  268. fmt.Printf("\\ %s \\\n", center(f.Name+" (variable length)", 61))
  269. fmt.Printf("/ %61s /\n", "")
  270. fmt.Println(line)
  271. default:
  272. if sl {
  273. tn = "Zero or more " + tn + " Structures"
  274. fmt.Printf("/ %s /\n", center("", 61))
  275. fmt.Printf("\\ %s \\\n", center(tn, 61))
  276. fmt.Printf("/ %s /\n", center("", 61))
  277. } else {
  278. fmt.Printf("| %s |\n", center(tn, 61))
  279. }
  280. fmt.Println(line)
  281. }
  282. }
  283. fmt.Println()
  284. fmt.Println()
  285. }
  286. func generateXdr(sn string, fs []field) {
  287. fmt.Printf("struct %s {\n", sn)
  288. for _, f := range fs {
  289. tn := f.FieldType
  290. fn := f.Name
  291. suf := ""
  292. if f.IsSlice {
  293. suf = "<>"
  294. }
  295. switch tn {
  296. case "uint16":
  297. fmt.Printf("\tunsigned short %s%s;\n", fn, suf)
  298. case "uint32":
  299. fmt.Printf("\tunsigned int %s%s;\n", fn, suf)
  300. case "int64":
  301. fmt.Printf("\thyper %s%s;\n", fn, suf)
  302. case "uint64":
  303. fmt.Printf("\tunsigned hyper %s%s;\n", fn, suf)
  304. case "string":
  305. fmt.Printf("\tstring %s<>;\n", fn)
  306. case "byte":
  307. fmt.Printf("\topaque %s<>;\n", fn)
  308. default:
  309. fmt.Printf("\t%s %s%s;\n", tn, fn, suf)
  310. }
  311. }
  312. fmt.Println("}")
  313. fmt.Println()
  314. }
  315. func center(s string, w int) string {
  316. w -= len(s)
  317. l := w / 2
  318. r := l
  319. if l+r < w {
  320. r++
  321. }
  322. return strings.Repeat(" ", l) + s + strings.Repeat(" ", r)
  323. }
  324. func inspector(fset *token.FileSet) func(ast.Node) bool {
  325. return func(n ast.Node) bool {
  326. switch n := n.(type) {
  327. case *ast.TypeSpec:
  328. switch t := n.Type.(type) {
  329. case *ast.StructType:
  330. name := n.Name.Name
  331. handleStruct(name, t)
  332. }
  333. return false
  334. default:
  335. return true
  336. }
  337. }
  338. }
  339. func main() {
  340. flag.StringVar(&output, "output", "code", "code,xdr,diagram")
  341. flag.Parse()
  342. fname := flag.Arg(0)
  343. // Create the AST by parsing src.
  344. fset := token.NewFileSet() // positions are relative to fset
  345. f, err := parser.ParseFile(fset, fname, nil, parser.ParseComments)
  346. if err != nil {
  347. panic(err)
  348. }
  349. //ast.Print(fset, f)
  350. if output == "code" {
  351. headerTpl.Execute(os.Stdout, map[string]string{"Package": f.Name.Name})
  352. }
  353. i := inspector(fset)
  354. ast.Inspect(f, i)
  355. }