Преглед изворни кода

Slightly clean up XDR generator

Jakob Borg пре 11 година
родитељ
комит
2d4b89a8e9
1 измењених фајлова са 139 додато и 112 уклоњено
  1. 139 112
      xdr/cmd/genxdr/main.go

+ 139 - 112
xdr/cmd/genxdr/main.go

@@ -19,21 +19,30 @@ import (
 	"text/template"
 )
 
-var output string
-
-type field struct {
+type fieldInfo struct {
 	Name      string
-	IsBasic   bool
-	IsSlice   bool
-	IsMap     bool
-	FieldType string
-	KeyType   string
-	Encoder   string
-	Convert   string
-	Max       int
+	IsBasic   bool   // handled by one the native Read/WriteUint64 etc functions
+	IsSlice   bool   // field is a slice of FieldType
+	FieldType string // original type of field, i.e. "int"
+	Encoder   string // the encoder name, i.e. "Uint64" for Read/WriteUint64
+	Convert   string // what to convert to when encoding, i.e. "uint64"
+	Max       int    // max size for slices and strings
+}
+
+type structInfo struct {
+	Name   string
+	Fields []fieldInfo
 }
 
-var headerTpl = template.Must(template.New("header").Parse(`package {{.Package}}
+var headerTpl = template.Must(template.New("header").Parse(`// Copyright (C) 2014 Jakob Borg and Contributors (see the CONTRIBUTORS file).
+// All rights reserved. Use of this source code is governed by an MIT-style
+// license that can be found in the LICENSE file.
+
+// ************************************************************
+// This file is automatically generated by genxdr. Do not edit.
+// ************************************************************
+
+package {{.Package}}
 
 import (
 	"bytes"
@@ -61,37 +70,37 @@ func (o {{.TypeName}}) AppendXDR(bs []byte) []byte {
 }//+n
 
 func (o {{.TypeName}}) encodeXDR(xw *xdr.Writer) (int, error) {
-	{{range $field := .Fields}}
-	{{if not $field.IsSlice}}
-		{{if ne $field.Convert ""}}
-		xw.Write{{$field.Encoder}}({{$field.Convert}}(o.{{$field.Name}}))
-		{{else if $field.IsBasic}}
-		{{if ge $field.Max 1}}
-		if len(o.{{$field.Name}}) > {{$field.Max}} {
-			return xw.Tot(), xdr.ErrElementSizeExceeded
-		}
-		{{end}}
-		xw.Write{{$field.Encoder}}(o.{{$field.Name}})
+	{{range $fieldInfo := .Fields}}
+		{{if not $fieldInfo.IsSlice}}
+			{{if ne $fieldInfo.Convert ""}}
+				xw.Write{{$fieldInfo.Encoder}}({{$fieldInfo.Convert}}(o.{{$fieldInfo.Name}}))
+			{{else if $fieldInfo.IsBasic}}
+				{{if ge $fieldInfo.Max 1}}
+					if len(o.{{$fieldInfo.Name}}) > {{$fieldInfo.Max}} {
+						return xw.Tot(), xdr.ErrElementSizeExceeded
+					}
+				{{end}}
+				xw.Write{{$fieldInfo.Encoder}}(o.{{$fieldInfo.Name}})
+			{{else}}
+				o.{{$fieldInfo.Name}}.encodeXDR(xw)
+			{{end}}
 		{{else}}
-		o.{{$field.Name}}.encodeXDR(xw)
-		{{end}}
-	{{else}}
-	{{if ge $field.Max 1}}
-	if len(o.{{$field.Name}}) > {{$field.Max}} {
-		return xw.Tot(), xdr.ErrElementSizeExceeded
-	}
-	{{end}}
-	xw.WriteUint32(uint32(len(o.{{$field.Name}})))
-	for i := range o.{{$field.Name}} {
-		{{if ne $field.Convert ""}}
-		xw.Write{{$field.Encoder}}({{$field.Convert}}(o.{{$field.Name}}[i]))
-		{{else if $field.IsBasic}}
-		xw.Write{{$field.Encoder}}(o.{{$field.Name}}[i])
-		{{else}}
-		o.{{$field.Name}}[i].encodeXDR(xw)
+			{{if ge $fieldInfo.Max 1}}
+				if len(o.{{$fieldInfo.Name}}) > {{$fieldInfo.Max}} {
+					return xw.Tot(), xdr.ErrElementSizeExceeded
+				}
+			{{end}}
+			xw.WriteUint32(uint32(len(o.{{$fieldInfo.Name}})))
+			for i := range o.{{$fieldInfo.Name}} {
+			{{if ne $fieldInfo.Convert ""}}
+				xw.Write{{$fieldInfo.Encoder}}({{$fieldInfo.Convert}}(o.{{$fieldInfo.Name}}[i]))
+			{{else if $fieldInfo.IsBasic}}
+				xw.Write{{$fieldInfo.Encoder}}(o.{{$fieldInfo.Name}}[i])
+			{{else}}
+				o.{{$fieldInfo.Name}}[i].encodeXDR(xw)
+			{{end}}
+			}
 		{{end}}
-	}
-	{{end}}
 	{{end}}
 	return xw.Tot(), xw.Error()
 }//+n
@@ -108,37 +117,37 @@ func (o *{{.TypeName}}) UnmarshalXDR(bs []byte) error {
 }//+n
 
 func (o *{{.TypeName}}) decodeXDR(xr *xdr.Reader) error {
-	{{range $field := .Fields}}
-	{{if not $field.IsSlice}}
-		{{if ne $field.Convert ""}}
-		o.{{$field.Name}} = {{$field.FieldType}}(xr.Read{{$field.Encoder}}())
-		{{else if $field.IsBasic}}
-		{{if ge $field.Max 1}}
-		o.{{$field.Name}} = xr.Read{{$field.Encoder}}Max({{$field.Max}})
-		{{else}}
-		o.{{$field.Name}} = xr.Read{{$field.Encoder}}()
-		{{end}}
+	{{range $fieldInfo := .Fields}}
+		{{if not $fieldInfo.IsSlice}}
+			{{if ne $fieldInfo.Convert ""}}
+				o.{{$fieldInfo.Name}} = {{$fieldInfo.FieldType}}(xr.Read{{$fieldInfo.Encoder}}())
+			{{else if $fieldInfo.IsBasic}}
+				{{if ge $fieldInfo.Max 1}}
+					o.{{$fieldInfo.Name}} = xr.Read{{$fieldInfo.Encoder}}Max({{$fieldInfo.Max}})
+				{{else}}
+					o.{{$fieldInfo.Name}} = xr.Read{{$fieldInfo.Encoder}}()
+				{{end}}
+			{{else}}
+				(&o.{{$fieldInfo.Name}}).decodeXDR(xr)
+			{{end}}
 		{{else}}
-		(&o.{{$field.Name}}).decodeXDR(xr)
-		{{end}}
-	{{else}}
-	_{{$field.Name}}Size := int(xr.ReadUint32())
-	{{if ge $field.Max 1}}
-	if _{{$field.Name}}Size > {{$field.Max}} {
-		return xdr.ErrElementSizeExceeded
-	}
-	{{end}}
-	o.{{$field.Name}} = make([]{{$field.FieldType}}, _{{$field.Name}}Size)
-	for i := range o.{{$field.Name}} {
-		{{if ne $field.Convert ""}}
-		o.{{$field.Name}}[i] = {{$field.FieldType}}(xr.Read{{$field.Encoder}}())
-		{{else if $field.IsBasic}}
-		o.{{$field.Name}}[i] = xr.Read{{$field.Encoder}}()
-		{{else}}
-		(&o.{{$field.Name}}[i]).decodeXDR(xr)
+			_{{$fieldInfo.Name}}Size := int(xr.ReadUint32())
+			{{if ge $fieldInfo.Max 1}}
+				if _{{$fieldInfo.Name}}Size > {{$fieldInfo.Max}} {
+					return xdr.ErrElementSizeExceeded
+				}
+			{{end}}
+			o.{{$fieldInfo.Name}} = make([]{{$fieldInfo.FieldType}}, _{{$fieldInfo.Name}}Size)
+			for i := range o.{{$fieldInfo.Name}} {
+				{{if ne $fieldInfo.Convert ""}}
+					o.{{$fieldInfo.Name}}[i] = {{$fieldInfo.FieldType}}(xr.Read{{$fieldInfo.Encoder}}())
+				{{else if $fieldInfo.IsBasic}}
+					o.{{$fieldInfo.Name}}[i] = xr.Read{{$fieldInfo.Encoder}}()
+				{{else}}
+					(&o.{{$fieldInfo.Name}}[i]).decodeXDR(xr)
+				{{end}}
+			}
 		{{end}}
-	}
-	{{end}}
 	{{end}}
 	return xr.Error()
 }`))
@@ -163,8 +172,9 @@ var xdrEncoders = map[string]typeSet{
 	"bool":   typeSet{"", "Bool"},
 }
 
-func handleStruct(name string, t *ast.StructType) {
-	var fs []field
+func handleStruct(t *ast.StructType) []fieldInfo {
+	var fs []fieldInfo
+
 	for _, sf := range t.Fields.List {
 		if len(sf.Names) == 0 {
 			// We don't handle anonymous fields
@@ -183,12 +193,12 @@ func handleStruct(name string, t *ast.StructType) {
 			}
 		}
 
-		var f field
+		var f fieldInfo
 		switch ft := sf.Type.(type) {
 		case *ast.Ident:
 			tn := ft.Name
 			if enc, ok := xdrEncoders[tn]; ok {
-				f = field{
+				f = fieldInfo{
 					Name:      fn,
 					IsBasic:   true,
 					FieldType: tn,
@@ -197,7 +207,7 @@ func handleStruct(name string, t *ast.StructType) {
 					Max:       max,
 				}
 			} else {
-				f = field{
+				f = fieldInfo{
 					Name:      fn,
 					IsBasic:   false,
 					FieldType: tn,
@@ -213,7 +223,7 @@ func handleStruct(name string, t *ast.StructType) {
 
 			tn := ft.Elt.(*ast.Ident).Name
 			if enc, ok := xdrEncoders["[]"+tn]; ok {
-				f = field{
+				f = fieldInfo{
 					Name:      fn,
 					IsBasic:   true,
 					FieldType: tn,
@@ -222,7 +232,7 @@ func handleStruct(name string, t *ast.StructType) {
 					Max:       max,
 				}
 			} else if enc, ok := xdrEncoders[tn]; ok {
-				f = field{
+				f = fieldInfo{
 					Name:      fn,
 					IsBasic:   true,
 					IsSlice:   true,
@@ -232,7 +242,7 @@ func handleStruct(name string, t *ast.StructType) {
 					Max:       max,
 				}
 			} else {
-				f = field{
+				f = fieldInfo{
 					Name:      fn,
 					IsBasic:   false,
 					IsSlice:   true,
@@ -245,17 +255,13 @@ func handleStruct(name string, t *ast.StructType) {
 		fs = append(fs, f)
 	}
 
-	switch output {
-	case "code":
-		generateCode(name, fs)
-	case "diagram":
-		generateDiagram(name, fs)
-	case "xdr":
-		generateXdr(name, fs)
-	}
+	return fs
 }
 
-func generateCode(name string, fs []field) {
+func generateCode(s structInfo) {
+	name := s.Name
+	fs := s.Fields
+
 	var buf bytes.Buffer
 	err := encodeTpl.Execute(&buf, map[string]interface{}{"TypeName": name, "Fields": fs})
 	if err != nil {
@@ -272,7 +278,16 @@ func generateCode(name string, fs []field) {
 	fmt.Println(string(bs))
 }
 
-func generateDiagram(sn string, fs []field) {
+func uncamelize(s string) string {
+	return regexp.MustCompile("[a-z][A-Z]").ReplaceAllStringFunc(s, func(camel string) string {
+		return camel[:1] + " " + camel[1:]
+	})
+}
+
+func generateDiagram(s structInfo) {
+	sn := s.Name
+	fs := s.Fields
+
 	fmt.Println(sn + " Structure:")
 	fmt.Println()
 	fmt.Println(" 0                   1                   2                   3")
@@ -283,28 +298,32 @@ func generateDiagram(sn string, fs []field) {
 	for _, f := range fs {
 		tn := f.FieldType
 		sl := f.IsSlice
+		name := uncamelize(f.Name)
 
 		if sl {
-			fmt.Printf("| %s |\n", center("Number of "+f.Name, 61))
+			fmt.Printf("| %s |\n", center("Number of "+name, 61))
 			fmt.Println(line)
 		}
 		switch tn {
+		case "bool":
+			fmt.Printf("| %s |V|\n", center(name+" (V=0 or 1)", 59))
+			fmt.Println(line)
 		case "uint16":
-			fmt.Printf("| %s | %s |\n", center(f.Name, 29), center("0x0000", 29))
+			fmt.Printf("| %s | %s |\n", center("0x0000", 29), center(name, 29))
 			fmt.Println(line)
 		case "uint32":
-			fmt.Printf("| %s |\n", center(f.Name, 61))
+			fmt.Printf("| %s |\n", center(name, 61))
 			fmt.Println(line)
 		case "int64", "uint64":
 			fmt.Printf("| %-61s |\n", "")
-			fmt.Printf("+ %s +\n", center(f.Name+" (64 bits)", 61))
+			fmt.Printf("+ %s +\n", center(name+" (64 bits)", 61))
 			fmt.Printf("| %-61s |\n", "")
 			fmt.Println(line)
 		case "string", "byte": // XXX We assume slice of byte!
-			fmt.Printf("| %s |\n", center("Length of "+f.Name, 61))
+			fmt.Printf("| %s |\n", center("Length of "+name, 61))
 			fmt.Println(line)
 			fmt.Printf("/ %61s /\n", "")
-			fmt.Printf("\\ %s \\\n", center(f.Name+" (variable length)", 61))
+			fmt.Printf("\\ %s \\\n", center(name+" (variable length)", 61))
 			fmt.Printf("/ %61s /\n", "")
 			fmt.Println(line)
 		default:
@@ -323,30 +342,35 @@ func generateDiagram(sn string, fs []field) {
 	fmt.Println()
 }
 
-func generateXdr(sn string, fs []field) {
+func generateXdr(s structInfo) {
+	sn := s.Name
+	fs := s.Fields
+
 	fmt.Printf("struct %s {\n", sn)
 
 	for _, f := range fs {
 		tn := f.FieldType
 		fn := f.Name
 		suf := ""
+		l := ""
+		if f.Max > 0 {
+			l = strconv.Itoa(f.Max)
+		}
 		if f.IsSlice {
-			suf = "<>"
+			suf = "<" + l + ">"
 		}
 
 		switch tn {
-		case "uint16":
-			fmt.Printf("\tunsigned short %s%s;\n", fn, suf)
-		case "uint32":
+		case "uint16", "uint32":
 			fmt.Printf("\tunsigned int %s%s;\n", fn, suf)
 		case "int64":
 			fmt.Printf("\thyper %s%s;\n", fn, suf)
 		case "uint64":
 			fmt.Printf("\tunsigned hyper %s%s;\n", fn, suf)
 		case "string":
-			fmt.Printf("\tstring %s<>;\n", fn)
+			fmt.Printf("\tstring %s<%s>;\n", fn, l)
 		case "byte":
-			fmt.Printf("\topaque %s<>;\n", fn)
+			fmt.Printf("\topaque %s<%s>;\n", fn, l)
 		default:
 			fmt.Printf("\t%s %s%s;\n", tn, fn, suf)
 		}
@@ -365,14 +389,15 @@ func center(s string, w int) string {
 	return strings.Repeat(" ", l) + s + strings.Repeat(" ", r)
 }
 
-func inspector(fset *token.FileSet) func(ast.Node) bool {
+func inspector(structs *[]structInfo) func(ast.Node) bool {
 	return func(n ast.Node) bool {
 		switch n := n.(type) {
 		case *ast.TypeSpec:
 			switch t := n.Type.(type) {
 			case *ast.StructType:
 				name := n.Name.Name
-				handleStruct(name, t)
+				fs := handleStruct(t)
+				*structs = append(*structs, structInfo{name, fs})
 			}
 			return false
 		default:
@@ -382,23 +407,25 @@ func inspector(fset *token.FileSet) func(ast.Node) bool {
 }
 
 func main() {
-	flag.StringVar(&output, "output", "code", "code,xdr,diagram")
 	flag.Parse()
 	fname := flag.Arg(0)
 
-	// Create the AST by parsing src.
-	fset := token.NewFileSet() // positions are relative to fset
+	fset := token.NewFileSet()
 	f, err := parser.ParseFile(fset, fname, nil, parser.ParseComments)
 	if err != nil {
 		panic(err)
 	}
 
-	//ast.Print(fset, f)
+	var structs []structInfo
+	i := inspector(&structs)
+	ast.Inspect(f, i)
 
-	if output == "code" {
-		headerTpl.Execute(os.Stdout, map[string]string{"Package": f.Name.Name})
+	headerTpl.Execute(os.Stdout, map[string]string{"Package": f.Name.Name})
+	for _, s := range structs {
+		fmt.Printf("\n/*\n\n")
+		generateDiagram(s)
+		generateXdr(s)
+		fmt.Printf("*/\n")
+		generateCode(s)
 	}
-
-	i := inspector(fset)
-	ast.Inspect(f, i)
 }