|
@@ -30,9 +30,12 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
|
|
// read packet header
|
|
// read packet header
|
|
data, err := mc.buf.readNext(4)
|
|
data, err := mc.buf.readNext(4)
|
|
if err != nil {
|
|
if err != nil {
|
|
|
|
+ if cerr := mc.canceled.Value(); cerr != nil {
|
|
|
|
+ return nil, cerr
|
|
|
|
+ }
|
|
errLog.Print(err)
|
|
errLog.Print(err)
|
|
mc.Close()
|
|
mc.Close()
|
|
- return nil, driver.ErrBadConn
|
|
|
|
|
|
+ return nil, ErrInvalidConn
|
|
}
|
|
}
|
|
|
|
|
|
// packet length [24 bit]
|
|
// packet length [24 bit]
|
|
@@ -54,7 +57,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
|
|
if prevData == nil {
|
|
if prevData == nil {
|
|
errLog.Print(ErrMalformPkt)
|
|
errLog.Print(ErrMalformPkt)
|
|
mc.Close()
|
|
mc.Close()
|
|
- return nil, driver.ErrBadConn
|
|
|
|
|
|
+ return nil, ErrInvalidConn
|
|
}
|
|
}
|
|
|
|
|
|
return prevData, nil
|
|
return prevData, nil
|
|
@@ -63,9 +66,12 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
|
|
// read packet body [pktLen bytes]
|
|
// read packet body [pktLen bytes]
|
|
data, err = mc.buf.readNext(pktLen)
|
|
data, err = mc.buf.readNext(pktLen)
|
|
if err != nil {
|
|
if err != nil {
|
|
|
|
+ if cerr := mc.canceled.Value(); cerr != nil {
|
|
|
|
+ return nil, cerr
|
|
|
|
+ }
|
|
errLog.Print(err)
|
|
errLog.Print(err)
|
|
mc.Close()
|
|
mc.Close()
|
|
- return nil, driver.ErrBadConn
|
|
|
|
|
|
+ return nil, ErrInvalidConn
|
|
}
|
|
}
|
|
|
|
|
|
// return data if this was the last packet
|
|
// return data if this was the last packet
|
|
@@ -125,11 +131,20 @@ func (mc *mysqlConn) writePacket(data []byte) error {
|
|
|
|
|
|
// Handle error
|
|
// Handle error
|
|
if err == nil { // n != len(data)
|
|
if err == nil { // n != len(data)
|
|
|
|
+ mc.cleanup()
|
|
errLog.Print(ErrMalformPkt)
|
|
errLog.Print(ErrMalformPkt)
|
|
} else {
|
|
} else {
|
|
|
|
+ if cerr := mc.canceled.Value(); cerr != nil {
|
|
|
|
+ return cerr
|
|
|
|
+ }
|
|
|
|
+ if n == 0 && pktLen == len(data)-4 {
|
|
|
|
+ // only for the first loop iteration when nothing was written yet
|
|
|
|
+ return errBadConnNoWrite
|
|
|
|
+ }
|
|
|
|
+ mc.cleanup()
|
|
errLog.Print(err)
|
|
errLog.Print(err)
|
|
}
|
|
}
|
|
- return driver.ErrBadConn
|
|
|
|
|
|
+ return ErrInvalidConn
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
@@ -263,7 +278,7 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
|
|
if data == nil {
|
|
if data == nil {
|
|
// can not take the buffer. Something must be wrong with the connection
|
|
// can not take the buffer. Something must be wrong with the connection
|
|
errLog.Print(ErrBusyBuffer)
|
|
errLog.Print(ErrBusyBuffer)
|
|
- return driver.ErrBadConn
|
|
|
|
|
|
+ return errBadConnNoWrite
|
|
}
|
|
}
|
|
|
|
|
|
// ClientFlags [32 bit]
|
|
// ClientFlags [32 bit]
|
|
@@ -341,7 +356,9 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
|
|
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse
|
|
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse
|
|
func (mc *mysqlConn) writeOldAuthPacket(cipher []byte) error {
|
|
func (mc *mysqlConn) writeOldAuthPacket(cipher []byte) error {
|
|
// User password
|
|
// User password
|
|
- scrambleBuff := scrambleOldPassword(cipher, []byte(mc.cfg.Passwd))
|
|
|
|
|
|
+ // https://dev.mysql.com/doc/internals/en/old-password-authentication.html
|
|
|
|
+ // Old password authentication only need and will need 8-byte challenge.
|
|
|
|
+ scrambleBuff := scrambleOldPassword(cipher[:8], []byte(mc.cfg.Passwd))
|
|
|
|
|
|
// Calculate the packet length and add a tailing 0
|
|
// Calculate the packet length and add a tailing 0
|
|
pktLen := len(scrambleBuff) + 1
|
|
pktLen := len(scrambleBuff) + 1
|
|
@@ -349,7 +366,7 @@ func (mc *mysqlConn) writeOldAuthPacket(cipher []byte) error {
|
|
if data == nil {
|
|
if data == nil {
|
|
// can not take the buffer. Something must be wrong with the connection
|
|
// can not take the buffer. Something must be wrong with the connection
|
|
errLog.Print(ErrBusyBuffer)
|
|
errLog.Print(ErrBusyBuffer)
|
|
- return driver.ErrBadConn
|
|
|
|
|
|
+ return errBadConnNoWrite
|
|
}
|
|
}
|
|
|
|
|
|
// Add the scrambled password [null terminated string]
|
|
// Add the scrambled password [null terminated string]
|
|
@@ -368,7 +385,7 @@ func (mc *mysqlConn) writeClearAuthPacket() error {
|
|
if data == nil {
|
|
if data == nil {
|
|
// can not take the buffer. Something must be wrong with the connection
|
|
// can not take the buffer. Something must be wrong with the connection
|
|
errLog.Print(ErrBusyBuffer)
|
|
errLog.Print(ErrBusyBuffer)
|
|
- return driver.ErrBadConn
|
|
|
|
|
|
+ return errBadConnNoWrite
|
|
}
|
|
}
|
|
|
|
|
|
// Add the clear password [null terminated string]
|
|
// Add the clear password [null terminated string]
|
|
@@ -381,7 +398,9 @@ func (mc *mysqlConn) writeClearAuthPacket() error {
|
|
// Native password authentication method
|
|
// Native password authentication method
|
|
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse
|
|
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse
|
|
func (mc *mysqlConn) writeNativeAuthPacket(cipher []byte) error {
|
|
func (mc *mysqlConn) writeNativeAuthPacket(cipher []byte) error {
|
|
- scrambleBuff := scramblePassword(cipher, []byte(mc.cfg.Passwd))
|
|
|
|
|
|
+ // https://dev.mysql.com/doc/internals/en/secure-password-authentication.html
|
|
|
|
+ // Native password authentication only need and will need 20-byte challenge.
|
|
|
|
+ scrambleBuff := scramblePassword(cipher[0:20], []byte(mc.cfg.Passwd))
|
|
|
|
|
|
// Calculate the packet length and add a tailing 0
|
|
// Calculate the packet length and add a tailing 0
|
|
pktLen := len(scrambleBuff)
|
|
pktLen := len(scrambleBuff)
|
|
@@ -389,7 +408,7 @@ func (mc *mysqlConn) writeNativeAuthPacket(cipher []byte) error {
|
|
if data == nil {
|
|
if data == nil {
|
|
// can not take the buffer. Something must be wrong with the connection
|
|
// can not take the buffer. Something must be wrong with the connection
|
|
errLog.Print(ErrBusyBuffer)
|
|
errLog.Print(ErrBusyBuffer)
|
|
- return driver.ErrBadConn
|
|
|
|
|
|
+ return errBadConnNoWrite
|
|
}
|
|
}
|
|
|
|
|
|
// Add the scramble
|
|
// Add the scramble
|
|
@@ -410,7 +429,7 @@ func (mc *mysqlConn) writeCommandPacket(command byte) error {
|
|
if data == nil {
|
|
if data == nil {
|
|
// can not take the buffer. Something must be wrong with the connection
|
|
// can not take the buffer. Something must be wrong with the connection
|
|
errLog.Print(ErrBusyBuffer)
|
|
errLog.Print(ErrBusyBuffer)
|
|
- return driver.ErrBadConn
|
|
|
|
|
|
+ return errBadConnNoWrite
|
|
}
|
|
}
|
|
|
|
|
|
// Add command byte
|
|
// Add command byte
|
|
@@ -429,7 +448,7 @@ func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error {
|
|
if data == nil {
|
|
if data == nil {
|
|
// can not take the buffer. Something must be wrong with the connection
|
|
// can not take the buffer. Something must be wrong with the connection
|
|
errLog.Print(ErrBusyBuffer)
|
|
errLog.Print(ErrBusyBuffer)
|
|
- return driver.ErrBadConn
|
|
|
|
|
|
+ return errBadConnNoWrite
|
|
}
|
|
}
|
|
|
|
|
|
// Add command byte
|
|
// Add command byte
|
|
@@ -450,7 +469,7 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error {
|
|
if data == nil {
|
|
if data == nil {
|
|
// can not take the buffer. Something must be wrong with the connection
|
|
// can not take the buffer. Something must be wrong with the connection
|
|
errLog.Print(ErrBusyBuffer)
|
|
errLog.Print(ErrBusyBuffer)
|
|
- return driver.ErrBadConn
|
|
|
|
|
|
+ return errBadConnNoWrite
|
|
}
|
|
}
|
|
|
|
|
|
// Add command byte
|
|
// Add command byte
|
|
@@ -484,25 +503,26 @@ func (mc *mysqlConn) readResultOK() ([]byte, error) {
|
|
if len(data) > 1 {
|
|
if len(data) > 1 {
|
|
pluginEndIndex := bytes.IndexByte(data, 0x00)
|
|
pluginEndIndex := bytes.IndexByte(data, 0x00)
|
|
plugin := string(data[1:pluginEndIndex])
|
|
plugin := string(data[1:pluginEndIndex])
|
|
- cipher := data[pluginEndIndex+1 : len(data)-1]
|
|
|
|
|
|
+ cipher := data[pluginEndIndex+1:]
|
|
|
|
|
|
- if plugin == "mysql_old_password" {
|
|
|
|
|
|
+ switch plugin {
|
|
|
|
+ case "mysql_old_password":
|
|
// using old_passwords
|
|
// using old_passwords
|
|
return cipher, ErrOldPassword
|
|
return cipher, ErrOldPassword
|
|
- } else if plugin == "mysql_clear_password" {
|
|
|
|
|
|
+ case "mysql_clear_password":
|
|
// using clear text password
|
|
// using clear text password
|
|
return cipher, ErrCleartextPassword
|
|
return cipher, ErrCleartextPassword
|
|
- } else if plugin == "mysql_native_password" {
|
|
|
|
|
|
+ case "mysql_native_password":
|
|
// using mysql default authentication method
|
|
// using mysql default authentication method
|
|
return cipher, ErrNativePassword
|
|
return cipher, ErrNativePassword
|
|
- } else {
|
|
|
|
|
|
+ default:
|
|
return cipher, ErrUnknownPlugin
|
|
return cipher, ErrUnknownPlugin
|
|
}
|
|
}
|
|
- } else {
|
|
|
|
- // https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::OldAuthSwitchRequest
|
|
|
|
- return nil, ErrOldPassword
|
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+ // https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::OldAuthSwitchRequest
|
|
|
|
+ return nil, ErrOldPassword
|
|
|
|
+
|
|
default: // Error otherwise
|
|
default: // Error otherwise
|
|
return nil, mc.handleErrorPacket(data)
|
|
return nil, mc.handleErrorPacket(data)
|
|
}
|
|
}
|
|
@@ -550,6 +570,22 @@ func (mc *mysqlConn) handleErrorPacket(data []byte) error {
|
|
// Error Number [16 bit uint]
|
|
// Error Number [16 bit uint]
|
|
errno := binary.LittleEndian.Uint16(data[1:3])
|
|
errno := binary.LittleEndian.Uint16(data[1:3])
|
|
|
|
|
|
|
|
+ // 1792: ER_CANT_EXECUTE_IN_READ_ONLY_TRANSACTION
|
|
|
|
+ // 1290: ER_OPTION_PREVENTS_STATEMENT (returned by Aurora during failover)
|
|
|
|
+ if (errno == 1792 || errno == 1290) && mc.cfg.RejectReadOnly {
|
|
|
|
+ // Oops; we are connected to a read-only connection, and won't be able
|
|
|
|
+ // to issue any write statements. Since RejectReadOnly is configured,
|
|
|
|
+ // we throw away this connection hoping this one would have write
|
|
|
|
+ // permission. This is specifically for a possible race condition
|
|
|
|
+ // during failover (e.g. on AWS Aurora). See README.md for more.
|
|
|
|
+ //
|
|
|
|
+ // We explicitly close the connection before returning
|
|
|
|
+ // driver.ErrBadConn to ensure that `database/sql` purges this
|
|
|
|
+ // connection and initiates a new one for next statement next time.
|
|
|
|
+ mc.Close()
|
|
|
|
+ return driver.ErrBadConn
|
|
|
|
+ }
|
|
|
|
+
|
|
pos := 3
|
|
pos := 3
|
|
|
|
|
|
// SQL State [optional: # + 5bytes string]
|
|
// SQL State [optional: # + 5bytes string]
|
|
@@ -584,19 +620,12 @@ func (mc *mysqlConn) handleOkPacket(data []byte) error {
|
|
|
|
|
|
// server_status [2 bytes]
|
|
// server_status [2 bytes]
|
|
mc.status = readStatus(data[1+n+m : 1+n+m+2])
|
|
mc.status = readStatus(data[1+n+m : 1+n+m+2])
|
|
- if err := mc.discardResults(); err != nil {
|
|
|
|
- return err
|
|
|
|
|
|
+ if mc.status&statusMoreResultsExists != 0 {
|
|
|
|
+ return nil
|
|
}
|
|
}
|
|
|
|
|
|
// warning count [2 bytes]
|
|
// warning count [2 bytes]
|
|
- if !mc.strict {
|
|
|
|
- return nil
|
|
|
|
- }
|
|
|
|
|
|
|
|
- pos := 1 + n + m + 2
|
|
|
|
- if binary.LittleEndian.Uint16(data[pos:pos+2]) > 0 {
|
|
|
|
- return mc.getWarnings()
|
|
|
|
- }
|
|
|
|
return nil
|
|
return nil
|
|
}
|
|
}
|
|
|
|
|
|
@@ -671,11 +700,14 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) {
|
|
|
|
|
|
// Filler [uint8]
|
|
// Filler [uint8]
|
|
// Charset [charset, collation uint8]
|
|
// Charset [charset, collation uint8]
|
|
|
|
+ pos += n + 1 + 2
|
|
|
|
+
|
|
// Length [uint32]
|
|
// Length [uint32]
|
|
- pos += n + 1 + 2 + 4
|
|
|
|
|
|
+ columns[i].length = binary.LittleEndian.Uint32(data[pos : pos+4])
|
|
|
|
+ pos += 4
|
|
|
|
|
|
// Field type [uint8]
|
|
// Field type [uint8]
|
|
- columns[i].fieldType = data[pos]
|
|
|
|
|
|
+ columns[i].fieldType = fieldType(data[pos])
|
|
pos++
|
|
pos++
|
|
|
|
|
|
// Flags [uint16]
|
|
// Flags [uint16]
|
|
@@ -698,6 +730,10 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) {
|
|
func (rows *textRows) readRow(dest []driver.Value) error {
|
|
func (rows *textRows) readRow(dest []driver.Value) error {
|
|
mc := rows.mc
|
|
mc := rows.mc
|
|
|
|
|
|
|
|
+ if rows.rs.done {
|
|
|
|
+ return io.EOF
|
|
|
|
+ }
|
|
|
|
+
|
|
data, err := mc.readPacket()
|
|
data, err := mc.readPacket()
|
|
if err != nil {
|
|
if err != nil {
|
|
return err
|
|
return err
|
|
@@ -707,15 +743,11 @@ func (rows *textRows) readRow(dest []driver.Value) error {
|
|
if data[0] == iEOF && len(data) == 5 {
|
|
if data[0] == iEOF && len(data) == 5 {
|
|
// server_status [2 bytes]
|
|
// server_status [2 bytes]
|
|
rows.mc.status = readStatus(data[3:])
|
|
rows.mc.status = readStatus(data[3:])
|
|
- err = rows.mc.discardResults()
|
|
|
|
- if err == nil {
|
|
|
|
- err = io.EOF
|
|
|
|
- } else {
|
|
|
|
- // connection unusable
|
|
|
|
- rows.mc.Close()
|
|
|
|
|
|
+ rows.rs.done = true
|
|
|
|
+ if !rows.HasNextResultSet() {
|
|
|
|
+ rows.mc = nil
|
|
}
|
|
}
|
|
- rows.mc = nil
|
|
|
|
- return err
|
|
|
|
|
|
+ return io.EOF
|
|
}
|
|
}
|
|
if data[0] == iERR {
|
|
if data[0] == iERR {
|
|
rows.mc = nil
|
|
rows.mc = nil
|
|
@@ -736,7 +768,7 @@ func (rows *textRows) readRow(dest []driver.Value) error {
|
|
if !mc.parseTime {
|
|
if !mc.parseTime {
|
|
continue
|
|
continue
|
|
} else {
|
|
} else {
|
|
- switch rows.columns[i].fieldType {
|
|
|
|
|
|
+ switch rows.rs.columns[i].fieldType {
|
|
case fieldTypeTimestamp, fieldTypeDateTime,
|
|
case fieldTypeTimestamp, fieldTypeDateTime,
|
|
fieldTypeDate, fieldTypeNewDate:
|
|
fieldTypeDate, fieldTypeNewDate:
|
|
dest[i], err = parseDateTime(
|
|
dest[i], err = parseDateTime(
|
|
@@ -808,14 +840,7 @@ func (stmt *mysqlStmt) readPrepareResultPacket() (uint16, error) {
|
|
// Reserved [8 bit]
|
|
// Reserved [8 bit]
|
|
|
|
|
|
// Warning count [16 bit uint]
|
|
// Warning count [16 bit uint]
|
|
- if !stmt.mc.strict {
|
|
|
|
- return columnCount, nil
|
|
|
|
- }
|
|
|
|
|
|
|
|
- // Check for warnings count > 0, only available in MySQL > 4.1
|
|
|
|
- if len(data) >= 12 && binary.LittleEndian.Uint16(data[10:12]) > 0 {
|
|
|
|
- return columnCount, stmt.mc.getWarnings()
|
|
|
|
- }
|
|
|
|
return columnCount, nil
|
|
return columnCount, nil
|
|
}
|
|
}
|
|
return 0, err
|
|
return 0, err
|
|
@@ -900,7 +925,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
|
|
if data == nil {
|
|
if data == nil {
|
|
// can not take the buffer. Something must be wrong with the connection
|
|
// can not take the buffer. Something must be wrong with the connection
|
|
errLog.Print(ErrBusyBuffer)
|
|
errLog.Print(ErrBusyBuffer)
|
|
- return driver.ErrBadConn
|
|
|
|
|
|
+ return errBadConnNoWrite
|
|
}
|
|
}
|
|
|
|
|
|
// command [1 byte]
|
|
// command [1 byte]
|
|
@@ -959,7 +984,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
|
|
// build NULL-bitmap
|
|
// build NULL-bitmap
|
|
if arg == nil {
|
|
if arg == nil {
|
|
nullMask[i/8] |= 1 << (uint(i) & 7)
|
|
nullMask[i/8] |= 1 << (uint(i) & 7)
|
|
- paramTypes[i+i] = fieldTypeNULL
|
|
|
|
|
|
+ paramTypes[i+i] = byte(fieldTypeNULL)
|
|
paramTypes[i+i+1] = 0x00
|
|
paramTypes[i+i+1] = 0x00
|
|
continue
|
|
continue
|
|
}
|
|
}
|
|
@@ -967,7 +992,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
|
|
// cache types and values
|
|
// cache types and values
|
|
switch v := arg.(type) {
|
|
switch v := arg.(type) {
|
|
case int64:
|
|
case int64:
|
|
- paramTypes[i+i] = fieldTypeLongLong
|
|
|
|
|
|
+ paramTypes[i+i] = byte(fieldTypeLongLong)
|
|
paramTypes[i+i+1] = 0x00
|
|
paramTypes[i+i+1] = 0x00
|
|
|
|
|
|
if cap(paramValues)-len(paramValues)-8 >= 0 {
|
|
if cap(paramValues)-len(paramValues)-8 >= 0 {
|
|
@@ -983,7 +1008,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
|
|
}
|
|
}
|
|
|
|
|
|
case float64:
|
|
case float64:
|
|
- paramTypes[i+i] = fieldTypeDouble
|
|
|
|
|
|
+ paramTypes[i+i] = byte(fieldTypeDouble)
|
|
paramTypes[i+i+1] = 0x00
|
|
paramTypes[i+i+1] = 0x00
|
|
|
|
|
|
if cap(paramValues)-len(paramValues)-8 >= 0 {
|
|
if cap(paramValues)-len(paramValues)-8 >= 0 {
|
|
@@ -999,7 +1024,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
|
|
}
|
|
}
|
|
|
|
|
|
case bool:
|
|
case bool:
|
|
- paramTypes[i+i] = fieldTypeTiny
|
|
|
|
|
|
+ paramTypes[i+i] = byte(fieldTypeTiny)
|
|
paramTypes[i+i+1] = 0x00
|
|
paramTypes[i+i+1] = 0x00
|
|
|
|
|
|
if v {
|
|
if v {
|
|
@@ -1011,7 +1036,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
|
|
case []byte:
|
|
case []byte:
|
|
// Common case (non-nil value) first
|
|
// Common case (non-nil value) first
|
|
if v != nil {
|
|
if v != nil {
|
|
- paramTypes[i+i] = fieldTypeString
|
|
|
|
|
|
+ paramTypes[i+i] = byte(fieldTypeString)
|
|
paramTypes[i+i+1] = 0x00
|
|
paramTypes[i+i+1] = 0x00
|
|
|
|
|
|
if len(v) < mc.maxAllowedPacket-pos-len(paramValues)-(len(args)-(i+1))*64 {
|
|
if len(v) < mc.maxAllowedPacket-pos-len(paramValues)-(len(args)-(i+1))*64 {
|
|
@@ -1029,11 +1054,11 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
|
|
|
|
|
|
// Handle []byte(nil) as a NULL value
|
|
// Handle []byte(nil) as a NULL value
|
|
nullMask[i/8] |= 1 << (uint(i) & 7)
|
|
nullMask[i/8] |= 1 << (uint(i) & 7)
|
|
- paramTypes[i+i] = fieldTypeNULL
|
|
|
|
|
|
+ paramTypes[i+i] = byte(fieldTypeNULL)
|
|
paramTypes[i+i+1] = 0x00
|
|
paramTypes[i+i+1] = 0x00
|
|
|
|
|
|
case string:
|
|
case string:
|
|
- paramTypes[i+i] = fieldTypeString
|
|
|
|
|
|
+ paramTypes[i+i] = byte(fieldTypeString)
|
|
paramTypes[i+i+1] = 0x00
|
|
paramTypes[i+i+1] = 0x00
|
|
|
|
|
|
if len(v) < mc.maxAllowedPacket-pos-len(paramValues)-(len(args)-(i+1))*64 {
|
|
if len(v) < mc.maxAllowedPacket-pos-len(paramValues)-(len(args)-(i+1))*64 {
|
|
@@ -1048,20 +1073,22 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
|
|
}
|
|
}
|
|
|
|
|
|
case time.Time:
|
|
case time.Time:
|
|
- paramTypes[i+i] = fieldTypeString
|
|
|
|
|
|
+ paramTypes[i+i] = byte(fieldTypeString)
|
|
paramTypes[i+i+1] = 0x00
|
|
paramTypes[i+i+1] = 0x00
|
|
|
|
|
|
- var val []byte
|
|
|
|
|
|
+ var a [64]byte
|
|
|
|
+ var b = a[:0]
|
|
|
|
+
|
|
if v.IsZero() {
|
|
if v.IsZero() {
|
|
- val = []byte("0000-00-00")
|
|
|
|
|
|
+ b = append(b, "0000-00-00"...)
|
|
} else {
|
|
} else {
|
|
- val = []byte(v.In(mc.cfg.Loc).Format(timeFormat))
|
|
|
|
|
|
+ b = v.In(mc.cfg.Loc).AppendFormat(b, timeFormat)
|
|
}
|
|
}
|
|
|
|
|
|
paramValues = appendLengthEncodedInteger(paramValues,
|
|
paramValues = appendLengthEncodedInteger(paramValues,
|
|
- uint64(len(val)),
|
|
|
|
|
|
+ uint64(len(b)),
|
|
)
|
|
)
|
|
- paramValues = append(paramValues, val...)
|
|
|
|
|
|
+ paramValues = append(paramValues, b...)
|
|
|
|
|
|
default:
|
|
default:
|
|
return fmt.Errorf("can not convert type: %T", arg)
|
|
return fmt.Errorf("can not convert type: %T", arg)
|
|
@@ -1097,8 +1124,6 @@ func (mc *mysqlConn) discardResults() error {
|
|
if err := mc.readUntilEOF(); err != nil {
|
|
if err := mc.readUntilEOF(); err != nil {
|
|
return err
|
|
return err
|
|
}
|
|
}
|
|
- } else {
|
|
|
|
- mc.status &^= statusMoreResultsExists
|
|
|
|
}
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
return nil
|
|
@@ -1116,20 +1141,17 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
|
|
// EOF Packet
|
|
// EOF Packet
|
|
if data[0] == iEOF && len(data) == 5 {
|
|
if data[0] == iEOF && len(data) == 5 {
|
|
rows.mc.status = readStatus(data[3:])
|
|
rows.mc.status = readStatus(data[3:])
|
|
- err = rows.mc.discardResults()
|
|
|
|
- if err == nil {
|
|
|
|
- err = io.EOF
|
|
|
|
- } else {
|
|
|
|
- // connection unusable
|
|
|
|
- rows.mc.Close()
|
|
|
|
|
|
+ rows.rs.done = true
|
|
|
|
+ if !rows.HasNextResultSet() {
|
|
|
|
+ rows.mc = nil
|
|
}
|
|
}
|
|
- rows.mc = nil
|
|
|
|
- return err
|
|
|
|
|
|
+ return io.EOF
|
|
}
|
|
}
|
|
|
|
+ mc := rows.mc
|
|
rows.mc = nil
|
|
rows.mc = nil
|
|
|
|
|
|
// Error otherwise
|
|
// Error otherwise
|
|
- return rows.mc.handleErrorPacket(data)
|
|
|
|
|
|
+ return mc.handleErrorPacket(data)
|
|
}
|
|
}
|
|
|
|
|
|
// NULL-bitmap, [(column-count + 7 + 2) / 8 bytes]
|
|
// NULL-bitmap, [(column-count + 7 + 2) / 8 bytes]
|
|
@@ -1145,14 +1167,14 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
|
|
}
|
|
}
|
|
|
|
|
|
// Convert to byte-coded string
|
|
// Convert to byte-coded string
|
|
- switch rows.columns[i].fieldType {
|
|
|
|
|
|
+ switch rows.rs.columns[i].fieldType {
|
|
case fieldTypeNULL:
|
|
case fieldTypeNULL:
|
|
dest[i] = nil
|
|
dest[i] = nil
|
|
continue
|
|
continue
|
|
|
|
|
|
// Numeric Types
|
|
// Numeric Types
|
|
case fieldTypeTiny:
|
|
case fieldTypeTiny:
|
|
- if rows.columns[i].flags&flagUnsigned != 0 {
|
|
|
|
|
|
+ if rows.rs.columns[i].flags&flagUnsigned != 0 {
|
|
dest[i] = int64(data[pos])
|
|
dest[i] = int64(data[pos])
|
|
} else {
|
|
} else {
|
|
dest[i] = int64(int8(data[pos]))
|
|
dest[i] = int64(int8(data[pos]))
|
|
@@ -1161,7 +1183,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
|
|
continue
|
|
continue
|
|
|
|
|
|
case fieldTypeShort, fieldTypeYear:
|
|
case fieldTypeShort, fieldTypeYear:
|
|
- if rows.columns[i].flags&flagUnsigned != 0 {
|
|
|
|
|
|
+ if rows.rs.columns[i].flags&flagUnsigned != 0 {
|
|
dest[i] = int64(binary.LittleEndian.Uint16(data[pos : pos+2]))
|
|
dest[i] = int64(binary.LittleEndian.Uint16(data[pos : pos+2]))
|
|
} else {
|
|
} else {
|
|
dest[i] = int64(int16(binary.LittleEndian.Uint16(data[pos : pos+2])))
|
|
dest[i] = int64(int16(binary.LittleEndian.Uint16(data[pos : pos+2])))
|
|
@@ -1170,7 +1192,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
|
|
continue
|
|
continue
|
|
|
|
|
|
case fieldTypeInt24, fieldTypeLong:
|
|
case fieldTypeInt24, fieldTypeLong:
|
|
- if rows.columns[i].flags&flagUnsigned != 0 {
|
|
|
|
|
|
+ if rows.rs.columns[i].flags&flagUnsigned != 0 {
|
|
dest[i] = int64(binary.LittleEndian.Uint32(data[pos : pos+4]))
|
|
dest[i] = int64(binary.LittleEndian.Uint32(data[pos : pos+4]))
|
|
} else {
|
|
} else {
|
|
dest[i] = int64(int32(binary.LittleEndian.Uint32(data[pos : pos+4])))
|
|
dest[i] = int64(int32(binary.LittleEndian.Uint32(data[pos : pos+4])))
|
|
@@ -1179,7 +1201,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
|
|
continue
|
|
continue
|
|
|
|
|
|
case fieldTypeLongLong:
|
|
case fieldTypeLongLong:
|
|
- if rows.columns[i].flags&flagUnsigned != 0 {
|
|
|
|
|
|
+ if rows.rs.columns[i].flags&flagUnsigned != 0 {
|
|
val := binary.LittleEndian.Uint64(data[pos : pos+8])
|
|
val := binary.LittleEndian.Uint64(data[pos : pos+8])
|
|
if val > math.MaxInt64 {
|
|
if val > math.MaxInt64 {
|
|
dest[i] = uint64ToString(val)
|
|
dest[i] = uint64ToString(val)
|
|
@@ -1193,7 +1215,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
|
|
continue
|
|
continue
|
|
|
|
|
|
case fieldTypeFloat:
|
|
case fieldTypeFloat:
|
|
- dest[i] = float32(math.Float32frombits(binary.LittleEndian.Uint32(data[pos : pos+4])))
|
|
|
|
|
|
+ dest[i] = math.Float32frombits(binary.LittleEndian.Uint32(data[pos : pos+4]))
|
|
pos += 4
|
|
pos += 4
|
|
continue
|
|
continue
|
|
|
|
|
|
@@ -1233,10 +1255,10 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
|
|
case isNull:
|
|
case isNull:
|
|
dest[i] = nil
|
|
dest[i] = nil
|
|
continue
|
|
continue
|
|
- case rows.columns[i].fieldType == fieldTypeTime:
|
|
|
|
|
|
+ case rows.rs.columns[i].fieldType == fieldTypeTime:
|
|
// database/sql does not support an equivalent to TIME, return a string
|
|
// database/sql does not support an equivalent to TIME, return a string
|
|
var dstlen uint8
|
|
var dstlen uint8
|
|
- switch decimals := rows.columns[i].decimals; decimals {
|
|
|
|
|
|
+ switch decimals := rows.rs.columns[i].decimals; decimals {
|
|
case 0x00, 0x1f:
|
|
case 0x00, 0x1f:
|
|
dstlen = 8
|
|
dstlen = 8
|
|
case 1, 2, 3, 4, 5, 6:
|
|
case 1, 2, 3, 4, 5, 6:
|
|
@@ -1244,7 +1266,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
|
|
default:
|
|
default:
|
|
return fmt.Errorf(
|
|
return fmt.Errorf(
|
|
"protocol error, illegal decimals value %d",
|
|
"protocol error, illegal decimals value %d",
|
|
- rows.columns[i].decimals,
|
|
|
|
|
|
+ rows.rs.columns[i].decimals,
|
|
)
|
|
)
|
|
}
|
|
}
|
|
dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], dstlen, true)
|
|
dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], dstlen, true)
|
|
@@ -1252,10 +1274,10 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
|
|
dest[i], err = parseBinaryDateTime(num, data[pos:], rows.mc.cfg.Loc)
|
|
dest[i], err = parseBinaryDateTime(num, data[pos:], rows.mc.cfg.Loc)
|
|
default:
|
|
default:
|
|
var dstlen uint8
|
|
var dstlen uint8
|
|
- if rows.columns[i].fieldType == fieldTypeDate {
|
|
|
|
|
|
+ if rows.rs.columns[i].fieldType == fieldTypeDate {
|
|
dstlen = 10
|
|
dstlen = 10
|
|
} else {
|
|
} else {
|
|
- switch decimals := rows.columns[i].decimals; decimals {
|
|
|
|
|
|
+ switch decimals := rows.rs.columns[i].decimals; decimals {
|
|
case 0x00, 0x1f:
|
|
case 0x00, 0x1f:
|
|
dstlen = 19
|
|
dstlen = 19
|
|
case 1, 2, 3, 4, 5, 6:
|
|
case 1, 2, 3, 4, 5, 6:
|
|
@@ -1263,7 +1285,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
|
|
default:
|
|
default:
|
|
return fmt.Errorf(
|
|
return fmt.Errorf(
|
|
"protocol error, illegal decimals value %d",
|
|
"protocol error, illegal decimals value %d",
|
|
- rows.columns[i].decimals,
|
|
|
|
|
|
+ rows.rs.columns[i].decimals,
|
|
)
|
|
)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
@@ -1279,7 +1301,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
|
|
|
|
|
|
// Please report if this happens!
|
|
// Please report if this happens!
|
|
default:
|
|
default:
|
|
- return fmt.Errorf("unknown field type %d", rows.columns[i].fieldType)
|
|
|
|
|
|
+ return fmt.Errorf("unknown field type %d", rows.rs.columns[i].fieldType)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
|