| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611 | // Go MySQL Driver - A MySQL-Driver for Go's database/sql package//// Copyright 2016 The Go-MySQL-Driver Authors. All rights reserved.//// This Source Code Form is subject to the terms of the Mozilla Public// License, v. 2.0. If a copy of the MPL was not distributed with this file,// You can obtain one at http://mozilla.org/MPL/2.0/.package mysqlimport (	"bytes"	"crypto/rsa"	"crypto/tls"	"errors"	"fmt"	"net"	"net/url"	"sort"	"strconv"	"strings"	"time")var (	errInvalidDSNUnescaped       = errors.New("invalid DSN: did you forget to escape a param value?")	errInvalidDSNAddr            = errors.New("invalid DSN: network address not terminated (missing closing brace)")	errInvalidDSNNoSlash         = errors.New("invalid DSN: missing the slash separating the database name")	errInvalidDSNUnsafeCollation = errors.New("invalid DSN: interpolateParams can not be used with unsafe collations"))// Config is a configuration parsed from a DSN string.// If a new Config is created instead of being parsed from a DSN string,// the NewConfig function should be used, which sets default values.type Config struct {	User             string            // Username	Passwd           string            // Password (requires User)	Net              string            // Network type	Addr             string            // Network address (requires Net)	DBName           string            // Database name	Params           map[string]string // Connection parameters	Collation        string            // Connection collation	Loc              *time.Location    // Location for time.Time values	MaxAllowedPacket int               // Max packet size allowed	ServerPubKey     string            // Server public key name	pubKey           *rsa.PublicKey    // Server public key	TLSConfig        string            // TLS configuration name	tls              *tls.Config       // TLS configuration	Timeout          time.Duration     // Dial timeout	ReadTimeout      time.Duration     // I/O read timeout	WriteTimeout     time.Duration     // I/O write timeout	AllowAllFiles           bool // Allow all files to be used with LOAD DATA LOCAL INFILE	AllowCleartextPasswords bool // Allows the cleartext client side plugin	AllowNativePasswords    bool // Allows the native password authentication method	AllowOldPasswords       bool // Allows the old insecure password method	ClientFoundRows         bool // Return number of matching rows instead of rows changed	ColumnsWithAlias        bool // Prepend table alias to column names	InterpolateParams       bool // Interpolate placeholders into query string	MultiStatements         bool // Allow multiple statements in one query	ParseTime               bool // Parse time values to time.Time	RejectReadOnly          bool // Reject read-only connections}// NewConfig creates a new Config and sets default values.func NewConfig() *Config {	return &Config{		Collation:            defaultCollation,		Loc:                  time.UTC,		MaxAllowedPacket:     defaultMaxAllowedPacket,		AllowNativePasswords: true,	}}func (cfg *Config) normalize() error {	if cfg.InterpolateParams && unsafeCollations[cfg.Collation] {		return errInvalidDSNUnsafeCollation	}	// Set default network if empty	if cfg.Net == "" {		cfg.Net = "tcp"	}	// Set default address if empty	if cfg.Addr == "" {		switch cfg.Net {		case "tcp":			cfg.Addr = "127.0.0.1:3306"		case "unix":			cfg.Addr = "/tmp/mysql.sock"		default:			return errors.New("default addr for network '" + cfg.Net + "' unknown")		}	} else if cfg.Net == "tcp" {		cfg.Addr = ensureHavePort(cfg.Addr)	}	if cfg.tls != nil {		if cfg.tls.ServerName == "" && !cfg.tls.InsecureSkipVerify {			host, _, err := net.SplitHostPort(cfg.Addr)			if err == nil {				cfg.tls.ServerName = host			}		}	}	return nil}// FormatDSN formats the given Config into a DSN string which can be passed to// the driver.func (cfg *Config) FormatDSN() string {	var buf bytes.Buffer	// [username[:password]@]	if len(cfg.User) > 0 {		buf.WriteString(cfg.User)		if len(cfg.Passwd) > 0 {			buf.WriteByte(':')			buf.WriteString(cfg.Passwd)		}		buf.WriteByte('@')	}	// [protocol[(address)]]	if len(cfg.Net) > 0 {		buf.WriteString(cfg.Net)		if len(cfg.Addr) > 0 {			buf.WriteByte('(')			buf.WriteString(cfg.Addr)			buf.WriteByte(')')		}	}	// /dbname	buf.WriteByte('/')	buf.WriteString(cfg.DBName)	// [?param1=value1&...¶mN=valueN]	hasParam := false	if cfg.AllowAllFiles {		hasParam = true		buf.WriteString("?allowAllFiles=true")	}	if cfg.AllowCleartextPasswords {		if hasParam {			buf.WriteString("&allowCleartextPasswords=true")		} else {			hasParam = true			buf.WriteString("?allowCleartextPasswords=true")		}	}	if !cfg.AllowNativePasswords {		if hasParam {			buf.WriteString("&allowNativePasswords=false")		} else {			hasParam = true			buf.WriteString("?allowNativePasswords=false")		}	}	if cfg.AllowOldPasswords {		if hasParam {			buf.WriteString("&allowOldPasswords=true")		} else {			hasParam = true			buf.WriteString("?allowOldPasswords=true")		}	}	if cfg.ClientFoundRows {		if hasParam {			buf.WriteString("&clientFoundRows=true")		} else {			hasParam = true			buf.WriteString("?clientFoundRows=true")		}	}	if col := cfg.Collation; col != defaultCollation && len(col) > 0 {		if hasParam {			buf.WriteString("&collation=")		} else {			hasParam = true			buf.WriteString("?collation=")		}		buf.WriteString(col)	}	if cfg.ColumnsWithAlias {		if hasParam {			buf.WriteString("&columnsWithAlias=true")		} else {			hasParam = true			buf.WriteString("?columnsWithAlias=true")		}	}	if cfg.InterpolateParams {		if hasParam {			buf.WriteString("&interpolateParams=true")		} else {			hasParam = true			buf.WriteString("?interpolateParams=true")		}	}	if cfg.Loc != time.UTC && cfg.Loc != nil {		if hasParam {			buf.WriteString("&loc=")		} else {			hasParam = true			buf.WriteString("?loc=")		}		buf.WriteString(url.QueryEscape(cfg.Loc.String()))	}	if cfg.MultiStatements {		if hasParam {			buf.WriteString("&multiStatements=true")		} else {			hasParam = true			buf.WriteString("?multiStatements=true")		}	}	if cfg.ParseTime {		if hasParam {			buf.WriteString("&parseTime=true")		} else {			hasParam = true			buf.WriteString("?parseTime=true")		}	}	if cfg.ReadTimeout > 0 {		if hasParam {			buf.WriteString("&readTimeout=")		} else {			hasParam = true			buf.WriteString("?readTimeout=")		}		buf.WriteString(cfg.ReadTimeout.String())	}	if cfg.RejectReadOnly {		if hasParam {			buf.WriteString("&rejectReadOnly=true")		} else {			hasParam = true			buf.WriteString("?rejectReadOnly=true")		}	}	if len(cfg.ServerPubKey) > 0 {		if hasParam {			buf.WriteString("&serverPubKey=")		} else {			hasParam = true			buf.WriteString("?serverPubKey=")		}		buf.WriteString(url.QueryEscape(cfg.ServerPubKey))	}	if cfg.Timeout > 0 {		if hasParam {			buf.WriteString("&timeout=")		} else {			hasParam = true			buf.WriteString("?timeout=")		}		buf.WriteString(cfg.Timeout.String())	}	if len(cfg.TLSConfig) > 0 {		if hasParam {			buf.WriteString("&tls=")		} else {			hasParam = true			buf.WriteString("?tls=")		}		buf.WriteString(url.QueryEscape(cfg.TLSConfig))	}	if cfg.WriteTimeout > 0 {		if hasParam {			buf.WriteString("&writeTimeout=")		} else {			hasParam = true			buf.WriteString("?writeTimeout=")		}		buf.WriteString(cfg.WriteTimeout.String())	}	if cfg.MaxAllowedPacket != defaultMaxAllowedPacket {		if hasParam {			buf.WriteString("&maxAllowedPacket=")		} else {			hasParam = true			buf.WriteString("?maxAllowedPacket=")		}		buf.WriteString(strconv.Itoa(cfg.MaxAllowedPacket))	}	// other params	if cfg.Params != nil {		var params []string		for param := range cfg.Params {			params = append(params, param)		}		sort.Strings(params)		for _, param := range params {			if hasParam {				buf.WriteByte('&')			} else {				hasParam = true				buf.WriteByte('?')			}			buf.WriteString(param)			buf.WriteByte('=')			buf.WriteString(url.QueryEscape(cfg.Params[param]))		}	}	return buf.String()}// ParseDSN parses the DSN string to a Configfunc ParseDSN(dsn string) (cfg *Config, err error) {	// New config with some default values	cfg = NewConfig()	// [user[:password]@][net[(addr)]]/dbname[?param1=value1¶mN=valueN]	// Find the last '/' (since the password or the net addr might contain a '/')	foundSlash := false	for i := len(dsn) - 1; i >= 0; i-- {		if dsn[i] == '/' {			foundSlash = true			var j, k int			// left part is empty if i <= 0			if i > 0 {				// [username[:password]@][protocol[(address)]]				// Find the last '@' in dsn[:i]				for j = i; j >= 0; j-- {					if dsn[j] == '@' {						// username[:password]						// Find the first ':' in dsn[:j]						for k = 0; k < j; k++ {							if dsn[k] == ':' {								cfg.Passwd = dsn[k+1 : j]								break							}						}						cfg.User = dsn[:k]						break					}				}				// [protocol[(address)]]				// Find the first '(' in dsn[j+1:i]				for k = j + 1; k < i; k++ {					if dsn[k] == '(' {						// dsn[i-1] must be == ')' if an address is specified						if dsn[i-1] != ')' {							if strings.ContainsRune(dsn[k+1:i], ')') {								return nil, errInvalidDSNUnescaped							}							return nil, errInvalidDSNAddr						}						cfg.Addr = dsn[k+1 : i-1]						break					}				}				cfg.Net = dsn[j+1 : k]			}			// dbname[?param1=value1&...¶mN=valueN]			// Find the first '?' in dsn[i+1:]			for j = i + 1; j < len(dsn); j++ {				if dsn[j] == '?' {					if err = parseDSNParams(cfg, dsn[j+1:]); err != nil {						return					}					break				}			}			cfg.DBName = dsn[i+1 : j]			break		}	}	if !foundSlash && len(dsn) > 0 {		return nil, errInvalidDSNNoSlash	}	if err = cfg.normalize(); err != nil {		return nil, err	}	return}// parseDSNParams parses the DSN "query string"// Values must be url.QueryEscape'edfunc parseDSNParams(cfg *Config, params string) (err error) {	for _, v := range strings.Split(params, "&") {		param := strings.SplitN(v, "=", 2)		if len(param) != 2 {			continue		}		// cfg params		switch value := param[1]; param[0] {		// Disable INFILE whitelist / enable all files		case "allowAllFiles":			var isBool bool			cfg.AllowAllFiles, isBool = readBool(value)			if !isBool {				return errors.New("invalid bool value: " + value)			}		// Use cleartext authentication mode (MySQL 5.5.10+)		case "allowCleartextPasswords":			var isBool bool			cfg.AllowCleartextPasswords, isBool = readBool(value)			if !isBool {				return errors.New("invalid bool value: " + value)			}		// Use native password authentication		case "allowNativePasswords":			var isBool bool			cfg.AllowNativePasswords, isBool = readBool(value)			if !isBool {				return errors.New("invalid bool value: " + value)			}		// Use old authentication mode (pre MySQL 4.1)		case "allowOldPasswords":			var isBool bool			cfg.AllowOldPasswords, isBool = readBool(value)			if !isBool {				return errors.New("invalid bool value: " + value)			}		// Switch "rowsAffected" mode		case "clientFoundRows":			var isBool bool			cfg.ClientFoundRows, isBool = readBool(value)			if !isBool {				return errors.New("invalid bool value: " + value)			}		// Collation		case "collation":			cfg.Collation = value			break		case "columnsWithAlias":			var isBool bool			cfg.ColumnsWithAlias, isBool = readBool(value)			if !isBool {				return errors.New("invalid bool value: " + value)			}		// Compression		case "compress":			return errors.New("compression not implemented yet")		// Enable client side placeholder substitution		case "interpolateParams":			var isBool bool			cfg.InterpolateParams, isBool = readBool(value)			if !isBool {				return errors.New("invalid bool value: " + value)			}		// Time Location		case "loc":			if value, err = url.QueryUnescape(value); err != nil {				return			}			cfg.Loc, err = time.LoadLocation(value)			if err != nil {				return			}		// multiple statements in one query		case "multiStatements":			var isBool bool			cfg.MultiStatements, isBool = readBool(value)			if !isBool {				return errors.New("invalid bool value: " + value)			}		// time.Time parsing		case "parseTime":			var isBool bool			cfg.ParseTime, isBool = readBool(value)			if !isBool {				return errors.New("invalid bool value: " + value)			}		// I/O read Timeout		case "readTimeout":			cfg.ReadTimeout, err = time.ParseDuration(value)			if err != nil {				return			}		// Reject read-only connections		case "rejectReadOnly":			var isBool bool			cfg.RejectReadOnly, isBool = readBool(value)			if !isBool {				return errors.New("invalid bool value: " + value)			}		// Server public key		case "serverPubKey":			name, err := url.QueryUnescape(value)			if err != nil {				return fmt.Errorf("invalid value for server pub key name: %v", err)			}			if pubKey := getServerPubKey(name); pubKey != nil {				cfg.ServerPubKey = name				cfg.pubKey = pubKey			} else {				return errors.New("invalid value / unknown server pub key name: " + name)			}		// Strict mode		case "strict":			panic("strict mode has been removed. See https://github.com/go-sql-driver/mysql/wiki/strict-mode")		// Dial Timeout		case "timeout":			cfg.Timeout, err = time.ParseDuration(value)			if err != nil {				return			}		// TLS-Encryption		case "tls":			boolValue, isBool := readBool(value)			if isBool {				if boolValue {					cfg.TLSConfig = "true"					cfg.tls = &tls.Config{}				} else {					cfg.TLSConfig = "false"				}			} else if vl := strings.ToLower(value); vl == "skip-verify" || vl == "preferred" {				cfg.TLSConfig = vl				cfg.tls = &tls.Config{InsecureSkipVerify: true}			} else {				name, err := url.QueryUnescape(value)				if err != nil {					return fmt.Errorf("invalid value for TLS config name: %v", err)				}				if tlsConfig := getTLSConfigClone(name); tlsConfig != nil {					cfg.TLSConfig = name					cfg.tls = tlsConfig				} else {					return errors.New("invalid value / unknown config name: " + name)				}			}		// I/O write Timeout		case "writeTimeout":			cfg.WriteTimeout, err = time.ParseDuration(value)			if err != nil {				return			}		case "maxAllowedPacket":			cfg.MaxAllowedPacket, err = strconv.Atoi(value)			if err != nil {				return			}		default:			// lazy init			if cfg.Params == nil {				cfg.Params = make(map[string]string)			}			if cfg.Params[param[0]], err = url.QueryUnescape(value); err != nil {				return			}		}	}	return}func ensureHavePort(addr string) string {	if _, _, err := net.SplitHostPort(addr); err != nil {		return net.JoinHostPort(addr, "3306")	}	return addr}
 |