| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186 | package mainimport (	"fmt"	"log"	"os"	"sort"	"strings"	"github.com/olekukonko/tablewriter"	"github.com/pkg/errors"	"gopkg.in/DATA-DOG/go-sqlmock.v2"	"gorm.io/driver/mysql"	"gorm.io/driver/postgres"	"gorm.io/driver/sqlite"	"gorm.io/gorm"	"gorm.io/gorm/clause"	"gorm.io/gorm/schema"	"gogs.io/gogs/internal/db")//go:generate go run main.go ../../../docs/dev/database_schema.mdfunc main() {	w, err := os.Create(os.Args[1])	if err != nil {		log.Fatalf("Failed to create file: %v", err)	}	defer func() { _ = w.Close() }()	conn, _, err := sqlmock.New()	if err != nil {		log.Fatalf("Failed to get mock connection: %v", err)	}	defer func() { _ = conn.Close() }()	dialectors := []gorm.Dialector{		postgres.New(postgres.Config{			Conn: conn,		}),		mysql.New(mysql.Config{			Conn:                      conn,			SkipInitializeWithVersion: true,		}),		sqlite.Open(""),	}	collected := make([][]*tableInfo, 0, len(dialectors))	for i, dialector := range dialectors {		tableInfos, err := generate(dialector)		if err != nil {			log.Fatalf("Failed to get table info of %d: %v", i, err)		}		collected = append(collected, tableInfos)	}	for i, ti := range collected[0] {		_, _ = w.WriteString(`# Table "` + ti.Name + `"`)		_, _ = w.WriteString("\n\n")		_, _ = w.WriteString("```\n")		table := tablewriter.NewWriter(w)		table.SetHeader([]string{"Field", "Column", "PostgreSQL", "MySQL", "SQLite3"})		table.SetBorder(false)		for j, f := range ti.Fields {			table.Append([]string{				f.Name, f.Column,				strings.ToUpper(f.Type),                         // PostgreSQL				strings.ToUpper(collected[1][i].Fields[j].Type), // MySQL				strings.ToUpper(collected[2][i].Fields[j].Type), // SQLite3			})		}		table.Render()		_, _ = w.WriteString("\n")		_, _ = w.WriteString("Primary keys: ")		_, _ = w.WriteString(strings.Join(ti.PrimaryKeys, ", "))		_, _ = w.WriteString("\n")		if len(ti.Indexes) > 0 {			_, _ = w.WriteString("Indexes: \n")			for _, index := range ti.Indexes {				_, _ = w.WriteString(fmt.Sprintf("\t%q", index.Name))				if index.Class != "" {					_, _ = w.WriteString(fmt.Sprintf(" %s", index.Class))				}				if index.Type != "" {					_, _ = w.WriteString(fmt.Sprintf(", %s", index.Type))				}				if len(index.Fields) > 0 {					fields := make([]string, len(index.Fields))					for i := range index.Fields {						fields[i] = index.Fields[i].DBName					}					_, _ = w.WriteString(fmt.Sprintf(" (%s)", strings.Join(fields, ", ")))				}				_, _ = w.WriteString("\n")			}		}		_, _ = w.WriteString("```\n\n")	}}type tableField struct {	Name   string	Column string	Type   string}type tableInfo struct {	Name        string	Fields      []*tableField	PrimaryKeys []string	Indexes     []schema.Index}// This function is derived from gorm.io/gorm/migrator/migrator.go:Migrator.CreateTable.func generate(dialector gorm.Dialector) ([]*tableInfo, error) {	conn, err := gorm.Open(dialector,		&gorm.Config{			SkipDefaultTransaction: true,			NamingStrategy: schema.NamingStrategy{				SingularTable: true,			},			DryRun:               true,			DisableAutomaticPing: true,		},	)	if err != nil {		return nil, errors.Wrap(err, "open database")	}	m := conn.Migrator().(interface {		RunWithValue(value interface{}, fc func(*gorm.Statement) error) error		FullDataTypeOf(*schema.Field) clause.Expr	})	tableInfos := make([]*tableInfo, 0, len(db.Tables))	for _, table := range db.Tables {		err = m.RunWithValue(table, func(stmt *gorm.Statement) error {			fields := make([]*tableField, 0, len(stmt.Schema.DBNames))			for _, field := range stmt.Schema.Fields {				if field.DBName == "" {					continue				}				fields = append(fields, &tableField{					Name:   field.Name,					Column: field.DBName,					Type:   m.FullDataTypeOf(field).SQL,				})			}			primaryKeys := make([]string, 0, len(stmt.Schema.PrimaryFields))			if len(stmt.Schema.PrimaryFields) > 0 {				for _, field := range stmt.Schema.PrimaryFields {					primaryKeys = append(primaryKeys, field.DBName)				}			}			var indexes []schema.Index			for _, index := range stmt.Schema.ParseIndexes() {				indexes = append(indexes, index)			}			sort.Slice(indexes, func(i, j int) bool {				return indexes[i].Name < indexes[j].Name			})			tableInfos = append(tableInfos, &tableInfo{				Name:        stmt.Table,				Fields:      fields,				PrimaryKeys: primaryKeys,				Indexes:     indexes,			})			return nil		})		if err != nil {			return nil, errors.Wrap(err, "gather table information")		}	}	return tableInfos, nil}
 |