| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586 | 
							- // Go MySQL Driver - A MySQL-Driver for Go's database/sql package
 
- //
 
- // Copyright 2016 The Go-MySQL-Driver Authors. All rights reserved.
 
- //
 
- // This Source Code Form is subject to the terms of the Mozilla Public
 
- // License, v. 2.0. If a copy of the MPL was not distributed with this file,
 
- // You can obtain one at http://mozilla.org/MPL/2.0/.
 
- package mysql
 
- import (
 
- 	"bytes"
 
- 	"crypto/tls"
 
- 	"errors"
 
- 	"fmt"
 
- 	"net"
 
- 	"net/url"
 
- 	"sort"
 
- 	"strconv"
 
- 	"strings"
 
- 	"time"
 
- )
 
// Errors reported for DSN strings that cannot be parsed or validated.
var (
	// Returned by ParseDSN when the address part has a ')' somewhere but
	// not directly in front of the slash.
	errInvalidDSNUnescaped = errors.New("invalid DSN: did you forget to escape a param value?")

	// Returned by ParseDSN when the address part is opened with '(' but
	// never closed.
	errInvalidDSNAddr = errors.New("invalid DSN: network address not terminated (missing closing brace)")

	// Returned by ParseDSN when a non-empty DSN contains no '/'.
	errInvalidDSNNoSlash = errors.New("invalid DSN: missing the slash separating the database name")

	// Returned by Config.normalize when interpolateParams is enabled
	// together with a collation listed in unsafeCollations.
	errInvalidDSNUnsafeCollation = errors.New("invalid DSN: interpolateParams can not be used with unsafe collations")
)
 
// Config holds the connection settings that a DSN string describes.
// When a Config is built by hand rather than obtained from ParseDSN,
// start from NewConfig so that the default values are set.
type Config struct {
	// User is the username.
	User string

	// Passwd is the password (requires User).
	Passwd string

	// Net is the network type.
	Net string

	// Addr is the network address (requires Net).
	Addr string

	// DBName is the database name.
	DBName string

	// Params holds the remaining connection parameters.
	Params map[string]string

	// Collation is the connection collation.
	Collation string

	// Loc is the location used for time.Time values.
	Loc *time.Location

	// MaxAllowedPacket is the maximum packet size allowed.
	MaxAllowedPacket int

	// TLSConfig is the name of the TLS configuration.
	TLSConfig string

	// tls is the TLS configuration itself.
	tls *tls.Config

	// Timeout is the dial timeout.
	Timeout time.Duration

	// ReadTimeout is the I/O read timeout.
	ReadTimeout time.Duration

	// WriteTimeout is the I/O write timeout.
	WriteTimeout time.Duration

	// AllowAllFiles allows all files to be used with LOAD DATA LOCAL INFILE.
	AllowAllFiles bool

	// AllowCleartextPasswords allows the cleartext client side plugin.
	AllowCleartextPasswords bool

	// AllowNativePasswords allows the native password authentication method.
	AllowNativePasswords bool

	// AllowOldPasswords allows the old insecure password method.
	AllowOldPasswords bool

	// ClientFoundRows returns the number of matching rows instead of rows changed.
	ClientFoundRows bool

	// ColumnsWithAlias prepends the table alias to column names.
	ColumnsWithAlias bool

	// InterpolateParams interpolates placeholders into the query string.
	InterpolateParams bool

	// MultiStatements allows multiple statements in one query.
	MultiStatements bool

	// ParseTime parses time values to time.Time.
	ParseTime bool

	// RejectReadOnly rejects read-only connections.
	RejectReadOnly bool
}
 
- // NewConfig creates a new Config and sets default values.
 
- func NewConfig() *Config {
 
- 	return &Config{
 
- 		Collation:            defaultCollation,
 
- 		Loc:                  time.UTC,
 
- 		MaxAllowedPacket:     defaultMaxAllowedPacket,
 
- 		AllowNativePasswords: true,
 
- 	}
 
- }
 
- func (cfg *Config) normalize() error {
 
- 	if cfg.InterpolateParams && unsafeCollations[cfg.Collation] {
 
- 		return errInvalidDSNUnsafeCollation
 
- 	}
 
- 	// Set default network if empty
 
- 	if cfg.Net == "" {
 
- 		cfg.Net = "tcp"
 
- 	}
 
- 	// Set default address if empty
 
- 	if cfg.Addr == "" {
 
- 		switch cfg.Net {
 
- 		case "tcp":
 
- 			cfg.Addr = "127.0.0.1:3306"
 
- 		case "unix":
 
- 			cfg.Addr = "/tmp/mysql.sock"
 
- 		default:
 
- 			return errors.New("default addr for network '" + cfg.Net + "' unknown")
 
- 		}
 
- 	} else if cfg.Net == "tcp" {
 
- 		cfg.Addr = ensureHavePort(cfg.Addr)
 
- 	}
 
- 	return nil
 
- }
 
- // FormatDSN formats the given Config into a DSN string which can be passed to
 
- // the driver.
 
- func (cfg *Config) FormatDSN() string {
 
- 	var buf bytes.Buffer
 
- 	// [username[:password]@]
 
- 	if len(cfg.User) > 0 {
 
- 		buf.WriteString(cfg.User)
 
- 		if len(cfg.Passwd) > 0 {
 
- 			buf.WriteByte(':')
 
- 			buf.WriteString(cfg.Passwd)
 
- 		}
 
- 		buf.WriteByte('@')
 
- 	}
 
- 	// [protocol[(address)]]
 
- 	if len(cfg.Net) > 0 {
 
- 		buf.WriteString(cfg.Net)
 
- 		if len(cfg.Addr) > 0 {
 
- 			buf.WriteByte('(')
 
- 			buf.WriteString(cfg.Addr)
 
- 			buf.WriteByte(')')
 
- 		}
 
- 	}
 
- 	// /dbname
 
- 	buf.WriteByte('/')
 
- 	buf.WriteString(cfg.DBName)
 
- 	// [?param1=value1&...¶mN=valueN]
 
- 	hasParam := false
 
- 	if cfg.AllowAllFiles {
 
- 		hasParam = true
 
- 		buf.WriteString("?allowAllFiles=true")
 
- 	}
 
- 	if cfg.AllowCleartextPasswords {
 
- 		if hasParam {
 
- 			buf.WriteString("&allowCleartextPasswords=true")
 
- 		} else {
 
- 			hasParam = true
 
- 			buf.WriteString("?allowCleartextPasswords=true")
 
- 		}
 
- 	}
 
- 	if !cfg.AllowNativePasswords {
 
- 		if hasParam {
 
- 			buf.WriteString("&allowNativePasswords=false")
 
- 		} else {
 
- 			hasParam = true
 
- 			buf.WriteString("?allowNativePasswords=false")
 
- 		}
 
- 	}
 
- 	if cfg.AllowOldPasswords {
 
- 		if hasParam {
 
- 			buf.WriteString("&allowOldPasswords=true")
 
- 		} else {
 
- 			hasParam = true
 
- 			buf.WriteString("?allowOldPasswords=true")
 
- 		}
 
- 	}
 
- 	if cfg.ClientFoundRows {
 
- 		if hasParam {
 
- 			buf.WriteString("&clientFoundRows=true")
 
- 		} else {
 
- 			hasParam = true
 
- 			buf.WriteString("?clientFoundRows=true")
 
- 		}
 
- 	}
 
- 	if col := cfg.Collation; col != defaultCollation && len(col) > 0 {
 
- 		if hasParam {
 
- 			buf.WriteString("&collation=")
 
- 		} else {
 
- 			hasParam = true
 
- 			buf.WriteString("?collation=")
 
- 		}
 
- 		buf.WriteString(col)
 
- 	}
 
- 	if cfg.ColumnsWithAlias {
 
- 		if hasParam {
 
- 			buf.WriteString("&columnsWithAlias=true")
 
- 		} else {
 
- 			hasParam = true
 
- 			buf.WriteString("?columnsWithAlias=true")
 
- 		}
 
- 	}
 
- 	if cfg.InterpolateParams {
 
- 		if hasParam {
 
- 			buf.WriteString("&interpolateParams=true")
 
- 		} else {
 
- 			hasParam = true
 
- 			buf.WriteString("?interpolateParams=true")
 
- 		}
 
- 	}
 
- 	if cfg.Loc != time.UTC && cfg.Loc != nil {
 
- 		if hasParam {
 
- 			buf.WriteString("&loc=")
 
- 		} else {
 
- 			hasParam = true
 
- 			buf.WriteString("?loc=")
 
- 		}
 
- 		buf.WriteString(url.QueryEscape(cfg.Loc.String()))
 
- 	}
 
- 	if cfg.MultiStatements {
 
- 		if hasParam {
 
- 			buf.WriteString("&multiStatements=true")
 
- 		} else {
 
- 			hasParam = true
 
- 			buf.WriteString("?multiStatements=true")
 
- 		}
 
- 	}
 
- 	if cfg.ParseTime {
 
- 		if hasParam {
 
- 			buf.WriteString("&parseTime=true")
 
- 		} else {
 
- 			hasParam = true
 
- 			buf.WriteString("?parseTime=true")
 
- 		}
 
- 	}
 
- 	if cfg.ReadTimeout > 0 {
 
- 		if hasParam {
 
- 			buf.WriteString("&readTimeout=")
 
- 		} else {
 
- 			hasParam = true
 
- 			buf.WriteString("?readTimeout=")
 
- 		}
 
- 		buf.WriteString(cfg.ReadTimeout.String())
 
- 	}
 
- 	if cfg.RejectReadOnly {
 
- 		if hasParam {
 
- 			buf.WriteString("&rejectReadOnly=true")
 
- 		} else {
 
- 			hasParam = true
 
- 			buf.WriteString("?rejectReadOnly=true")
 
- 		}
 
- 	}
 
- 	if cfg.Timeout > 0 {
 
- 		if hasParam {
 
- 			buf.WriteString("&timeout=")
 
- 		} else {
 
- 			hasParam = true
 
- 			buf.WriteString("?timeout=")
 
- 		}
 
- 		buf.WriteString(cfg.Timeout.String())
 
- 	}
 
- 	if len(cfg.TLSConfig) > 0 {
 
- 		if hasParam {
 
- 			buf.WriteString("&tls=")
 
- 		} else {
 
- 			hasParam = true
 
- 			buf.WriteString("?tls=")
 
- 		}
 
- 		buf.WriteString(url.QueryEscape(cfg.TLSConfig))
 
- 	}
 
- 	if cfg.WriteTimeout > 0 {
 
- 		if hasParam {
 
- 			buf.WriteString("&writeTimeout=")
 
- 		} else {
 
- 			hasParam = true
 
- 			buf.WriteString("?writeTimeout=")
 
- 		}
 
- 		buf.WriteString(cfg.WriteTimeout.String())
 
- 	}
 
- 	if cfg.MaxAllowedPacket != defaultMaxAllowedPacket {
 
- 		if hasParam {
 
- 			buf.WriteString("&maxAllowedPacket=")
 
- 		} else {
 
- 			hasParam = true
 
- 			buf.WriteString("?maxAllowedPacket=")
 
- 		}
 
- 		buf.WriteString(strconv.Itoa(cfg.MaxAllowedPacket))
 
- 	}
 
- 	// other params
 
- 	if cfg.Params != nil {
 
- 		var params []string
 
- 		for param := range cfg.Params {
 
- 			params = append(params, param)
 
- 		}
 
- 		sort.Strings(params)
 
- 		for _, param := range params {
 
- 			if hasParam {
 
- 				buf.WriteByte('&')
 
- 			} else {
 
- 				hasParam = true
 
- 				buf.WriteByte('?')
 
- 			}
 
- 			buf.WriteString(param)
 
- 			buf.WriteByte('=')
 
- 			buf.WriteString(url.QueryEscape(cfg.Params[param]))
 
- 		}
 
- 	}
 
- 	return buf.String()
 
- }
 
- // ParseDSN parses the DSN string to a Config
 
- func ParseDSN(dsn string) (cfg *Config, err error) {
 
- 	// New config with some default values
 
- 	cfg = NewConfig()
 
- 	// [user[:password]@][net[(addr)]]/dbname[?param1=value1¶mN=valueN]
 
- 	// Find the last '/' (since the password or the net addr might contain a '/')
 
- 	foundSlash := false
 
- 	for i := len(dsn) - 1; i >= 0; i-- {
 
- 		if dsn[i] == '/' {
 
- 			foundSlash = true
 
- 			var j, k int
 
- 			// left part is empty if i <= 0
 
- 			if i > 0 {
 
- 				// [username[:password]@][protocol[(address)]]
 
- 				// Find the last '@' in dsn[:i]
 
- 				for j = i; j >= 0; j-- {
 
- 					if dsn[j] == '@' {
 
- 						// username[:password]
 
- 						// Find the first ':' in dsn[:j]
 
- 						for k = 0; k < j; k++ {
 
- 							if dsn[k] == ':' {
 
- 								cfg.Passwd = dsn[k+1 : j]
 
- 								break
 
- 							}
 
- 						}
 
- 						cfg.User = dsn[:k]
 
- 						break
 
- 					}
 
- 				}
 
- 				// [protocol[(address)]]
 
- 				// Find the first '(' in dsn[j+1:i]
 
- 				for k = j + 1; k < i; k++ {
 
- 					if dsn[k] == '(' {
 
- 						// dsn[i-1] must be == ')' if an address is specified
 
- 						if dsn[i-1] != ')' {
 
- 							if strings.ContainsRune(dsn[k+1:i], ')') {
 
- 								return nil, errInvalidDSNUnescaped
 
- 							}
 
- 							return nil, errInvalidDSNAddr
 
- 						}
 
- 						cfg.Addr = dsn[k+1 : i-1]
 
- 						break
 
- 					}
 
- 				}
 
- 				cfg.Net = dsn[j+1 : k]
 
- 			}
 
- 			// dbname[?param1=value1&...¶mN=valueN]
 
- 			// Find the first '?' in dsn[i+1:]
 
- 			for j = i + 1; j < len(dsn); j++ {
 
- 				if dsn[j] == '?' {
 
- 					if err = parseDSNParams(cfg, dsn[j+1:]); err != nil {
 
- 						return
 
- 					}
 
- 					break
 
- 				}
 
- 			}
 
- 			cfg.DBName = dsn[i+1 : j]
 
- 			break
 
- 		}
 
- 	}
 
- 	if !foundSlash && len(dsn) > 0 {
 
- 		return nil, errInvalidDSNNoSlash
 
- 	}
 
- 	if err = cfg.normalize(); err != nil {
 
- 		return nil, err
 
- 	}
 
- 	return
 
- }
 
- // parseDSNParams parses the DSN "query string"
 
- // Values must be url.QueryEscape'ed
 
- func parseDSNParams(cfg *Config, params string) (err error) {
 
- 	for _, v := range strings.Split(params, "&") {
 
- 		param := strings.SplitN(v, "=", 2)
 
- 		if len(param) != 2 {
 
- 			continue
 
- 		}
 
- 		// cfg params
 
- 		switch value := param[1]; param[0] {
 
- 		// Disable INFILE whitelist / enable all files
 
- 		case "allowAllFiles":
 
- 			var isBool bool
 
- 			cfg.AllowAllFiles, isBool = readBool(value)
 
- 			if !isBool {
 
- 				return errors.New("invalid bool value: " + value)
 
- 			}
 
- 		// Use cleartext authentication mode (MySQL 5.5.10+)
 
- 		case "allowCleartextPasswords":
 
- 			var isBool bool
 
- 			cfg.AllowCleartextPasswords, isBool = readBool(value)
 
- 			if !isBool {
 
- 				return errors.New("invalid bool value: " + value)
 
- 			}
 
- 		// Use native password authentication
 
- 		case "allowNativePasswords":
 
- 			var isBool bool
 
- 			cfg.AllowNativePasswords, isBool = readBool(value)
 
- 			if !isBool {
 
- 				return errors.New("invalid bool value: " + value)
 
- 			}
 
- 		// Use old authentication mode (pre MySQL 4.1)
 
- 		case "allowOldPasswords":
 
- 			var isBool bool
 
- 			cfg.AllowOldPasswords, isBool = readBool(value)
 
- 			if !isBool {
 
- 				return errors.New("invalid bool value: " + value)
 
- 			}
 
- 		// Switch "rowsAffected" mode
 
- 		case "clientFoundRows":
 
- 			var isBool bool
 
- 			cfg.ClientFoundRows, isBool = readBool(value)
 
- 			if !isBool {
 
- 				return errors.New("invalid bool value: " + value)
 
- 			}
 
- 		// Collation
 
- 		case "collation":
 
- 			cfg.Collation = value
 
- 			break
 
- 		case "columnsWithAlias":
 
- 			var isBool bool
 
- 			cfg.ColumnsWithAlias, isBool = readBool(value)
 
- 			if !isBool {
 
- 				return errors.New("invalid bool value: " + value)
 
- 			}
 
- 		// Compression
 
- 		case "compress":
 
- 			return errors.New("compression not implemented yet")
 
- 		// Enable client side placeholder substitution
 
- 		case "interpolateParams":
 
- 			var isBool bool
 
- 			cfg.InterpolateParams, isBool = readBool(value)
 
- 			if !isBool {
 
- 				return errors.New("invalid bool value: " + value)
 
- 			}
 
- 		// Time Location
 
- 		case "loc":
 
- 			if value, err = url.QueryUnescape(value); err != nil {
 
- 				return
 
- 			}
 
- 			cfg.Loc, err = time.LoadLocation(value)
 
- 			if err != nil {
 
- 				return
 
- 			}
 
- 		// multiple statements in one query
 
- 		case "multiStatements":
 
- 			var isBool bool
 
- 			cfg.MultiStatements, isBool = readBool(value)
 
- 			if !isBool {
 
- 				return errors.New("invalid bool value: " + value)
 
- 			}
 
- 		// time.Time parsing
 
- 		case "parseTime":
 
- 			var isBool bool
 
- 			cfg.ParseTime, isBool = readBool(value)
 
- 			if !isBool {
 
- 				return errors.New("invalid bool value: " + value)
 
- 			}
 
- 		// I/O read Timeout
 
- 		case "readTimeout":
 
- 			cfg.ReadTimeout, err = time.ParseDuration(value)
 
- 			if err != nil {
 
- 				return
 
- 			}
 
- 		// Reject read-only connections
 
- 		case "rejectReadOnly":
 
- 			var isBool bool
 
- 			cfg.RejectReadOnly, isBool = readBool(value)
 
- 			if !isBool {
 
- 				return errors.New("invalid bool value: " + value)
 
- 			}
 
- 		// Strict mode
 
- 		case "strict":
 
- 			panic("strict mode has been removed. See https://github.com/go-sql-driver/mysql/wiki/strict-mode")
 
- 		// Dial Timeout
 
- 		case "timeout":
 
- 			cfg.Timeout, err = time.ParseDuration(value)
 
- 			if err != nil {
 
- 				return
 
- 			}
 
- 		// TLS-Encryption
 
- 		case "tls":
 
- 			boolValue, isBool := readBool(value)
 
- 			if isBool {
 
- 				if boolValue {
 
- 					cfg.TLSConfig = "true"
 
- 					cfg.tls = &tls.Config{}
 
- 					host, _, err := net.SplitHostPort(cfg.Addr)
 
- 					if err == nil {
 
- 						cfg.tls.ServerName = host
 
- 					}
 
- 				} else {
 
- 					cfg.TLSConfig = "false"
 
- 				}
 
- 			} else if vl := strings.ToLower(value); vl == "skip-verify" {
 
- 				cfg.TLSConfig = vl
 
- 				cfg.tls = &tls.Config{InsecureSkipVerify: true}
 
- 			} else {
 
- 				name, err := url.QueryUnescape(value)
 
- 				if err != nil {
 
- 					return fmt.Errorf("invalid value for TLS config name: %v", err)
 
- 				}
 
- 				if tlsConfig := getTLSConfigClone(name); tlsConfig != nil {
 
- 					if len(tlsConfig.ServerName) == 0 && !tlsConfig.InsecureSkipVerify {
 
- 						host, _, err := net.SplitHostPort(cfg.Addr)
 
- 						if err == nil {
 
- 							tlsConfig.ServerName = host
 
- 						}
 
- 					}
 
- 					cfg.TLSConfig = name
 
- 					cfg.tls = tlsConfig
 
- 				} else {
 
- 					return errors.New("invalid value / unknown config name: " + name)
 
- 				}
 
- 			}
 
- 		// I/O write Timeout
 
- 		case "writeTimeout":
 
- 			cfg.WriteTimeout, err = time.ParseDuration(value)
 
- 			if err != nil {
 
- 				return
 
- 			}
 
- 		case "maxAllowedPacket":
 
- 			cfg.MaxAllowedPacket, err = strconv.Atoi(value)
 
- 			if err != nil {
 
- 				return
 
- 			}
 
- 		default:
 
- 			// lazy init
 
- 			if cfg.Params == nil {
 
- 				cfg.Params = make(map[string]string)
 
- 			}
 
- 			if cfg.Params[param[0]], err = url.QueryUnescape(value); err != nil {
 
- 				return
 
- 			}
 
- 		}
 
- 	}
 
- 	return
 
- }
 
// ensureHavePort returns addr unchanged when net.SplitHostPort accepts it,
// and otherwise appends the default MySQL port 3306.
//
// A bracketed IPv6 literal without a port ("[::1]") has its brackets removed
// first, because net.JoinHostPort adds brackets itself around any host that
// contains a colon and would otherwise produce "[[::1]]:3306".
func ensureHavePort(addr string) string {
	if _, _, err := net.SplitHostPort(addr); err == nil {
		return addr
	}
	if n := len(addr); n >= 2 && addr[0] == '[' && addr[n-1] == ']' {
		addr = addr[1 : n-1]
	}
	return net.JoinHostPort(addr, "3306")
}
 
 
  |