Extra specs is an opaque valid JSON that can be set on a pool and which will be passed along to the provider as part of instance bootstrap params. This field is meant to allow operators to send extra configuration values to external or built-in providers. The extra specs is not interpreted or useful in any way to garm itself, but it may be useful to the provider which interacts with the IaaS. The extra specs are not meant to be used for secrets. Adding sensitive information to this field is highly discouraged. This field is meant as a means to add fine tuning knobs to the providers, on a per pool basis. Signed-off-by: Gabriel Adrian Samfira <gsamfira@cloudbasesolutions.com>
1349 lines
32 KiB
Go
1349 lines
32 KiB
Go
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
|
|
//
|
|
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
|
|
//
|
|
// This Source Code Form is subject to the terms of the Mozilla Public
|
|
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
|
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
|
|
|
package mysql
|
|
|
|
import (
|
|
"bytes"
|
|
"crypto/tls"
|
|
"database/sql/driver"
|
|
"encoding/binary"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"math"
|
|
"time"
|
|
)
|
|
|
|
// Packets documentation:
|
|
// http://dev.mysql.com/doc/internals/en/client-server-protocol.html
|
|
|
|
// Read packet to buffer 'data'
|
|
func (mc *mysqlConn) readPacket() ([]byte, error) {
|
|
var prevData []byte
|
|
for {
|
|
// read packet header
|
|
data, err := mc.buf.readNext(4)
|
|
if err != nil {
|
|
if cerr := mc.canceled.Value(); cerr != nil {
|
|
return nil, cerr
|
|
}
|
|
errLog.Print(err)
|
|
mc.Close()
|
|
return nil, ErrInvalidConn
|
|
}
|
|
|
|
// packet length [24 bit]
|
|
pktLen := int(uint32(data[0]) | uint32(data[1])<<8 | uint32(data[2])<<16)
|
|
|
|
// check packet sync [8 bit]
|
|
if data[3] != mc.sequence {
|
|
if data[3] > mc.sequence {
|
|
return nil, ErrPktSyncMul
|
|
}
|
|
return nil, ErrPktSync
|
|
}
|
|
mc.sequence++
|
|
|
|
// packets with length 0 terminate a previous packet which is a
|
|
// multiple of (2^24)-1 bytes long
|
|
if pktLen == 0 {
|
|
// there was no previous packet
|
|
if prevData == nil {
|
|
errLog.Print(ErrMalformPkt)
|
|
mc.Close()
|
|
return nil, ErrInvalidConn
|
|
}
|
|
|
|
return prevData, nil
|
|
}
|
|
|
|
// read packet body [pktLen bytes]
|
|
data, err = mc.buf.readNext(pktLen)
|
|
if err != nil {
|
|
if cerr := mc.canceled.Value(); cerr != nil {
|
|
return nil, cerr
|
|
}
|
|
errLog.Print(err)
|
|
mc.Close()
|
|
return nil, ErrInvalidConn
|
|
}
|
|
|
|
// return data if this was the last packet
|
|
if pktLen < maxPacketSize {
|
|
// zero allocations for non-split packets
|
|
if prevData == nil {
|
|
return data, nil
|
|
}
|
|
|
|
return append(prevData, data...), nil
|
|
}
|
|
|
|
prevData = append(prevData, data...)
|
|
}
|
|
}
|
|
|
|
// Write packet buffer 'data'
|
|
func (mc *mysqlConn) writePacket(data []byte) error {
|
|
pktLen := len(data) - 4
|
|
|
|
if pktLen > mc.maxAllowedPacket {
|
|
return ErrPktTooLarge
|
|
}
|
|
|
|
// Perform a stale connection check. We only perform this check for
|
|
// the first query on a connection that has been checked out of the
|
|
// connection pool: a fresh connection from the pool is more likely
|
|
// to be stale, and it has not performed any previous writes that
|
|
// could cause data corruption, so it's safe to return ErrBadConn
|
|
// if the check fails.
|
|
if mc.reset {
|
|
mc.reset = false
|
|
conn := mc.netConn
|
|
if mc.rawConn != nil {
|
|
conn = mc.rawConn
|
|
}
|
|
var err error
|
|
if mc.cfg.CheckConnLiveness {
|
|
if mc.cfg.ReadTimeout != 0 {
|
|
err = conn.SetReadDeadline(time.Now().Add(mc.cfg.ReadTimeout))
|
|
}
|
|
if err == nil {
|
|
err = connCheck(conn)
|
|
}
|
|
}
|
|
if err != nil {
|
|
errLog.Print("closing bad idle connection: ", err)
|
|
mc.Close()
|
|
return driver.ErrBadConn
|
|
}
|
|
}
|
|
|
|
for {
|
|
var size int
|
|
if pktLen >= maxPacketSize {
|
|
data[0] = 0xff
|
|
data[1] = 0xff
|
|
data[2] = 0xff
|
|
size = maxPacketSize
|
|
} else {
|
|
data[0] = byte(pktLen)
|
|
data[1] = byte(pktLen >> 8)
|
|
data[2] = byte(pktLen >> 16)
|
|
size = pktLen
|
|
}
|
|
data[3] = mc.sequence
|
|
|
|
// Write packet
|
|
if mc.writeTimeout > 0 {
|
|
if err := mc.netConn.SetWriteDeadline(time.Now().Add(mc.writeTimeout)); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
n, err := mc.netConn.Write(data[:4+size])
|
|
if err == nil && n == 4+size {
|
|
mc.sequence++
|
|
if size != maxPacketSize {
|
|
return nil
|
|
}
|
|
pktLen -= size
|
|
data = data[size:]
|
|
continue
|
|
}
|
|
|
|
// Handle error
|
|
if err == nil { // n != len(data)
|
|
mc.cleanup()
|
|
errLog.Print(ErrMalformPkt)
|
|
} else {
|
|
if cerr := mc.canceled.Value(); cerr != nil {
|
|
return cerr
|
|
}
|
|
if n == 0 && pktLen == len(data)-4 {
|
|
// only for the first loop iteration when nothing was written yet
|
|
return errBadConnNoWrite
|
|
}
|
|
mc.cleanup()
|
|
errLog.Print(err)
|
|
}
|
|
return ErrInvalidConn
|
|
}
|
|
}
|
|
|
|
/******************************************************************************
|
|
* Initialization Process *
|
|
******************************************************************************/
|
|
|
|
// Handshake Initialization Packet
|
|
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake
|
|
func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err error) {
|
|
data, err = mc.readPacket()
|
|
if err != nil {
|
|
// for init we can rewrite this to ErrBadConn for sql.Driver to retry, since
|
|
// in connection initialization we don't risk retrying non-idempotent actions.
|
|
if err == ErrInvalidConn {
|
|
return nil, "", driver.ErrBadConn
|
|
}
|
|
return
|
|
}
|
|
|
|
if data[0] == iERR {
|
|
return nil, "", mc.handleErrorPacket(data)
|
|
}
|
|
|
|
// protocol version [1 byte]
|
|
if data[0] < minProtocolVersion {
|
|
return nil, "", fmt.Errorf(
|
|
"unsupported protocol version %d. Version %d or higher is required",
|
|
data[0],
|
|
minProtocolVersion,
|
|
)
|
|
}
|
|
|
|
// server version [null terminated string]
|
|
// connection id [4 bytes]
|
|
pos := 1 + bytes.IndexByte(data[1:], 0x00) + 1 + 4
|
|
|
|
// first part of the password cipher [8 bytes]
|
|
authData := data[pos : pos+8]
|
|
|
|
// (filler) always 0x00 [1 byte]
|
|
pos += 8 + 1
|
|
|
|
// capability flags (lower 2 bytes) [2 bytes]
|
|
mc.flags = clientFlag(binary.LittleEndian.Uint16(data[pos : pos+2]))
|
|
if mc.flags&clientProtocol41 == 0 {
|
|
return nil, "", ErrOldProtocol
|
|
}
|
|
if mc.flags&clientSSL == 0 && mc.cfg.TLS != nil {
|
|
if mc.cfg.AllowFallbackToPlaintext {
|
|
mc.cfg.TLS = nil
|
|
} else {
|
|
return nil, "", ErrNoTLS
|
|
}
|
|
}
|
|
pos += 2
|
|
|
|
if len(data) > pos {
|
|
// character set [1 byte]
|
|
// status flags [2 bytes]
|
|
// capability flags (upper 2 bytes) [2 bytes]
|
|
// length of auth-plugin-data [1 byte]
|
|
// reserved (all [00]) [10 bytes]
|
|
pos += 1 + 2 + 2 + 1 + 10
|
|
|
|
// second part of the password cipher [mininum 13 bytes],
|
|
// where len=MAX(13, length of auth-plugin-data - 8)
|
|
//
|
|
// The web documentation is ambiguous about the length. However,
|
|
// according to mysql-5.7/sql/auth/sql_authentication.cc line 538,
|
|
// the 13th byte is "\0 byte, terminating the second part of
|
|
// a scramble". So the second part of the password cipher is
|
|
// a NULL terminated string that's at least 13 bytes with the
|
|
// last byte being NULL.
|
|
//
|
|
// The official Python library uses the fixed length 12
|
|
// which seems to work but technically could have a hidden bug.
|
|
authData = append(authData, data[pos:pos+12]...)
|
|
pos += 13
|
|
|
|
// EOF if version (>= 5.5.7 and < 5.5.10) or (>= 5.6.0 and < 5.6.2)
|
|
// \NUL otherwise
|
|
if end := bytes.IndexByte(data[pos:], 0x00); end != -1 {
|
|
plugin = string(data[pos : pos+end])
|
|
} else {
|
|
plugin = string(data[pos:])
|
|
}
|
|
|
|
// make a memory safe copy of the cipher slice
|
|
var b [20]byte
|
|
copy(b[:], authData)
|
|
return b[:], plugin, nil
|
|
}
|
|
|
|
// make a memory safe copy of the cipher slice
|
|
var b [8]byte
|
|
copy(b[:], authData)
|
|
return b[:], plugin, nil
|
|
}
|
|
|
|
// Client Authentication Packet
|
|
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
|
|
func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string) error {
|
|
// Adjust client flags based on server support
|
|
clientFlags := clientProtocol41 |
|
|
clientSecureConn |
|
|
clientLongPassword |
|
|
clientTransactions |
|
|
clientLocalFiles |
|
|
clientPluginAuth |
|
|
clientMultiResults |
|
|
mc.flags&clientLongFlag
|
|
|
|
if mc.cfg.ClientFoundRows {
|
|
clientFlags |= clientFoundRows
|
|
}
|
|
|
|
// To enable TLS / SSL
|
|
if mc.cfg.TLS != nil {
|
|
clientFlags |= clientSSL
|
|
}
|
|
|
|
if mc.cfg.MultiStatements {
|
|
clientFlags |= clientMultiStatements
|
|
}
|
|
|
|
// encode length of the auth plugin data
|
|
var authRespLEIBuf [9]byte
|
|
authRespLen := len(authResp)
|
|
authRespLEI := appendLengthEncodedInteger(authRespLEIBuf[:0], uint64(authRespLen))
|
|
if len(authRespLEI) > 1 {
|
|
// if the length can not be written in 1 byte, it must be written as a
|
|
// length encoded integer
|
|
clientFlags |= clientPluginAuthLenEncClientData
|
|
}
|
|
|
|
pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + len(authRespLEI) + len(authResp) + 21 + 1
|
|
|
|
// To specify a db name
|
|
if n := len(mc.cfg.DBName); n > 0 {
|
|
clientFlags |= clientConnectWithDB
|
|
pktLen += n + 1
|
|
}
|
|
|
|
// Calculate packet length and get buffer with that size
|
|
data, err := mc.buf.takeSmallBuffer(pktLen + 4)
|
|
if err != nil {
|
|
// cannot take the buffer. Something must be wrong with the connection
|
|
errLog.Print(err)
|
|
return errBadConnNoWrite
|
|
}
|
|
|
|
// ClientFlags [32 bit]
|
|
data[4] = byte(clientFlags)
|
|
data[5] = byte(clientFlags >> 8)
|
|
data[6] = byte(clientFlags >> 16)
|
|
data[7] = byte(clientFlags >> 24)
|
|
|
|
// MaxPacketSize [32 bit] (none)
|
|
data[8] = 0x00
|
|
data[9] = 0x00
|
|
data[10] = 0x00
|
|
data[11] = 0x00
|
|
|
|
// Charset [1 byte]
|
|
var found bool
|
|
data[12], found = collations[mc.cfg.Collation]
|
|
if !found {
|
|
// Note possibility for false negatives:
|
|
// could be triggered although the collation is valid if the
|
|
// collations map does not contain entries the server supports.
|
|
return errors.New("unknown collation")
|
|
}
|
|
|
|
// Filler [23 bytes] (all 0x00)
|
|
pos := 13
|
|
for ; pos < 13+23; pos++ {
|
|
data[pos] = 0
|
|
}
|
|
|
|
// SSL Connection Request Packet
|
|
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest
|
|
if mc.cfg.TLS != nil {
|
|
// Send TLS / SSL request packet
|
|
if err := mc.writePacket(data[:(4+4+1+23)+4]); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Switch to TLS
|
|
tlsConn := tls.Client(mc.netConn, mc.cfg.TLS)
|
|
if err := tlsConn.Handshake(); err != nil {
|
|
return err
|
|
}
|
|
mc.rawConn = mc.netConn
|
|
mc.netConn = tlsConn
|
|
mc.buf.nc = tlsConn
|
|
}
|
|
|
|
// User [null terminated string]
|
|
if len(mc.cfg.User) > 0 {
|
|
pos += copy(data[pos:], mc.cfg.User)
|
|
}
|
|
data[pos] = 0x00
|
|
pos++
|
|
|
|
// Auth Data [length encoded integer]
|
|
pos += copy(data[pos:], authRespLEI)
|
|
pos += copy(data[pos:], authResp)
|
|
|
|
// Databasename [null terminated string]
|
|
if len(mc.cfg.DBName) > 0 {
|
|
pos += copy(data[pos:], mc.cfg.DBName)
|
|
data[pos] = 0x00
|
|
pos++
|
|
}
|
|
|
|
pos += copy(data[pos:], plugin)
|
|
data[pos] = 0x00
|
|
pos++
|
|
|
|
// Send Auth packet
|
|
return mc.writePacket(data[:pos])
|
|
}
|
|
|
|
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse
|
|
func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte) error {
|
|
pktLen := 4 + len(authData)
|
|
data, err := mc.buf.takeSmallBuffer(pktLen)
|
|
if err != nil {
|
|
// cannot take the buffer. Something must be wrong with the connection
|
|
errLog.Print(err)
|
|
return errBadConnNoWrite
|
|
}
|
|
|
|
// Add the auth data [EOF]
|
|
copy(data[4:], authData)
|
|
return mc.writePacket(data)
|
|
}
|
|
|
|
/******************************************************************************
|
|
* Command Packets *
|
|
******************************************************************************/
|
|
|
|
func (mc *mysqlConn) writeCommandPacket(command byte) error {
|
|
// Reset Packet Sequence
|
|
mc.sequence = 0
|
|
|
|
data, err := mc.buf.takeSmallBuffer(4 + 1)
|
|
if err != nil {
|
|
// cannot take the buffer. Something must be wrong with the connection
|
|
errLog.Print(err)
|
|
return errBadConnNoWrite
|
|
}
|
|
|
|
// Add command byte
|
|
data[4] = command
|
|
|
|
// Send CMD packet
|
|
return mc.writePacket(data)
|
|
}
|
|
|
|
func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error {
|
|
// Reset Packet Sequence
|
|
mc.sequence = 0
|
|
|
|
pktLen := 1 + len(arg)
|
|
data, err := mc.buf.takeBuffer(pktLen + 4)
|
|
if err != nil {
|
|
// cannot take the buffer. Something must be wrong with the connection
|
|
errLog.Print(err)
|
|
return errBadConnNoWrite
|
|
}
|
|
|
|
// Add command byte
|
|
data[4] = command
|
|
|
|
// Add arg
|
|
copy(data[5:], arg)
|
|
|
|
// Send CMD packet
|
|
return mc.writePacket(data)
|
|
}
|
|
|
|
func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error {
|
|
// Reset Packet Sequence
|
|
mc.sequence = 0
|
|
|
|
data, err := mc.buf.takeSmallBuffer(4 + 1 + 4)
|
|
if err != nil {
|
|
// cannot take the buffer. Something must be wrong with the connection
|
|
errLog.Print(err)
|
|
return errBadConnNoWrite
|
|
}
|
|
|
|
// Add command byte
|
|
data[4] = command
|
|
|
|
// Add arg [32 bit]
|
|
data[5] = byte(arg)
|
|
data[6] = byte(arg >> 8)
|
|
data[7] = byte(arg >> 16)
|
|
data[8] = byte(arg >> 24)
|
|
|
|
// Send CMD packet
|
|
return mc.writePacket(data)
|
|
}
|
|
|
|
/******************************************************************************
|
|
* Result Packets *
|
|
******************************************************************************/
|
|
|
|
func (mc *mysqlConn) readAuthResult() ([]byte, string, error) {
|
|
data, err := mc.readPacket()
|
|
if err != nil {
|
|
return nil, "", err
|
|
}
|
|
|
|
// packet indicator
|
|
switch data[0] {
|
|
|
|
case iOK:
|
|
return nil, "", mc.handleOkPacket(data)
|
|
|
|
case iAuthMoreData:
|
|
return data[1:], "", err
|
|
|
|
case iEOF:
|
|
if len(data) == 1 {
|
|
// https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::OldAuthSwitchRequest
|
|
return nil, "mysql_old_password", nil
|
|
}
|
|
pluginEndIndex := bytes.IndexByte(data, 0x00)
|
|
if pluginEndIndex < 0 {
|
|
return nil, "", ErrMalformPkt
|
|
}
|
|
plugin := string(data[1:pluginEndIndex])
|
|
authData := data[pluginEndIndex+1:]
|
|
return authData, plugin, nil
|
|
|
|
default: // Error otherwise
|
|
return nil, "", mc.handleErrorPacket(data)
|
|
}
|
|
}
|
|
|
|
// Returns error if Packet is not an 'Result OK'-Packet
|
|
func (mc *mysqlConn) readResultOK() error {
|
|
data, err := mc.readPacket()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if data[0] == iOK {
|
|
return mc.handleOkPacket(data)
|
|
}
|
|
return mc.handleErrorPacket(data)
|
|
}
|
|
|
|
// Result Set Header Packet
|
|
// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::Resultset
|
|
func (mc *mysqlConn) readResultSetHeaderPacket() (int, error) {
|
|
data, err := mc.readPacket()
|
|
if err == nil {
|
|
switch data[0] {
|
|
|
|
case iOK:
|
|
return 0, mc.handleOkPacket(data)
|
|
|
|
case iERR:
|
|
return 0, mc.handleErrorPacket(data)
|
|
|
|
case iLocalInFile:
|
|
return 0, mc.handleInFileRequest(string(data[1:]))
|
|
}
|
|
|
|
// column count
|
|
num, _, n := readLengthEncodedInteger(data)
|
|
if n-len(data) == 0 {
|
|
return int(num), nil
|
|
}
|
|
|
|
return 0, ErrMalformPkt
|
|
}
|
|
return 0, err
|
|
}
|
|
|
|
// Error Packet
|
|
// http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-ERR_Packet
|
|
func (mc *mysqlConn) handleErrorPacket(data []byte) error {
|
|
if data[0] != iERR {
|
|
return ErrMalformPkt
|
|
}
|
|
|
|
// 0xff [1 byte]
|
|
|
|
// Error Number [16 bit uint]
|
|
errno := binary.LittleEndian.Uint16(data[1:3])
|
|
|
|
// 1792: ER_CANT_EXECUTE_IN_READ_ONLY_TRANSACTION
|
|
// 1290: ER_OPTION_PREVENTS_STATEMENT (returned by Aurora during failover)
|
|
if (errno == 1792 || errno == 1290) && mc.cfg.RejectReadOnly {
|
|
// Oops; we are connected to a read-only connection, and won't be able
|
|
// to issue any write statements. Since RejectReadOnly is configured,
|
|
// we throw away this connection hoping this one would have write
|
|
// permission. This is specifically for a possible race condition
|
|
// during failover (e.g. on AWS Aurora). See README.md for more.
|
|
//
|
|
// We explicitly close the connection before returning
|
|
// driver.ErrBadConn to ensure that `database/sql` purges this
|
|
// connection and initiates a new one for next statement next time.
|
|
mc.Close()
|
|
return driver.ErrBadConn
|
|
}
|
|
|
|
me := &MySQLError{Number: errno}
|
|
|
|
pos := 3
|
|
|
|
// SQL State [optional: # + 5bytes string]
|
|
if data[3] == 0x23 {
|
|
copy(me.SQLState[:], data[4:4+5])
|
|
pos = 9
|
|
}
|
|
|
|
// Error Message [string]
|
|
me.Message = string(data[pos:])
|
|
|
|
return me
|
|
}
|
|
|
|
func readStatus(b []byte) statusFlag {
|
|
return statusFlag(b[0]) | statusFlag(b[1])<<8
|
|
}
|
|
|
|
// Ok Packet
|
|
// http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-OK_Packet
|
|
func (mc *mysqlConn) handleOkPacket(data []byte) error {
|
|
var n, m int
|
|
|
|
// 0x00 [1 byte]
|
|
|
|
// Affected rows [Length Coded Binary]
|
|
mc.affectedRows, _, n = readLengthEncodedInteger(data[1:])
|
|
|
|
// Insert id [Length Coded Binary]
|
|
mc.insertId, _, m = readLengthEncodedInteger(data[1+n:])
|
|
|
|
// server_status [2 bytes]
|
|
mc.status = readStatus(data[1+n+m : 1+n+m+2])
|
|
if mc.status&statusMoreResultsExists != 0 {
|
|
return nil
|
|
}
|
|
|
|
// warning count [2 bytes]
|
|
|
|
return nil
|
|
}
|
|
|
|
// Read Packets as Field Packets until EOF-Packet or an Error appears
|
|
// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-Protocol::ColumnDefinition41
|
|
func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) {
|
|
columns := make([]mysqlField, count)
|
|
|
|
for i := 0; ; i++ {
|
|
data, err := mc.readPacket()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// EOF Packet
|
|
if data[0] == iEOF && (len(data) == 5 || len(data) == 1) {
|
|
if i == count {
|
|
return columns, nil
|
|
}
|
|
return nil, fmt.Errorf("column count mismatch n:%d len:%d", count, len(columns))
|
|
}
|
|
|
|
// Catalog
|
|
pos, err := skipLengthEncodedString(data)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Database [len coded string]
|
|
n, err := skipLengthEncodedString(data[pos:])
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
pos += n
|
|
|
|
// Table [len coded string]
|
|
if mc.cfg.ColumnsWithAlias {
|
|
tableName, _, n, err := readLengthEncodedString(data[pos:])
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
pos += n
|
|
columns[i].tableName = string(tableName)
|
|
} else {
|
|
n, err = skipLengthEncodedString(data[pos:])
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
pos += n
|
|
}
|
|
|
|
// Original table [len coded string]
|
|
n, err = skipLengthEncodedString(data[pos:])
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
pos += n
|
|
|
|
// Name [len coded string]
|
|
name, _, n, err := readLengthEncodedString(data[pos:])
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
columns[i].name = string(name)
|
|
pos += n
|
|
|
|
// Original name [len coded string]
|
|
n, err = skipLengthEncodedString(data[pos:])
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
pos += n
|
|
|
|
// Filler [uint8]
|
|
pos++
|
|
|
|
// Charset [charset, collation uint8]
|
|
columns[i].charSet = data[pos]
|
|
pos += 2
|
|
|
|
// Length [uint32]
|
|
columns[i].length = binary.LittleEndian.Uint32(data[pos : pos+4])
|
|
pos += 4
|
|
|
|
// Field type [uint8]
|
|
columns[i].fieldType = fieldType(data[pos])
|
|
pos++
|
|
|
|
// Flags [uint16]
|
|
columns[i].flags = fieldFlag(binary.LittleEndian.Uint16(data[pos : pos+2]))
|
|
pos += 2
|
|
|
|
// Decimals [uint8]
|
|
columns[i].decimals = data[pos]
|
|
//pos++
|
|
|
|
// Default value [len coded binary]
|
|
//if pos < len(data) {
|
|
// defaultVal, _, err = bytesToLengthCodedBinary(data[pos:])
|
|
//}
|
|
}
|
|
}
|
|
|
|
// Read Packets as Field Packets until EOF-Packet or an Error appears
|
|
// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::ResultsetRow
|
|
func (rows *textRows) readRow(dest []driver.Value) error {
|
|
mc := rows.mc
|
|
|
|
if rows.rs.done {
|
|
return io.EOF
|
|
}
|
|
|
|
data, err := mc.readPacket()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// EOF Packet
|
|
if data[0] == iEOF && len(data) == 5 {
|
|
// server_status [2 bytes]
|
|
rows.mc.status = readStatus(data[3:])
|
|
rows.rs.done = true
|
|
if !rows.HasNextResultSet() {
|
|
rows.mc = nil
|
|
}
|
|
return io.EOF
|
|
}
|
|
if data[0] == iERR {
|
|
rows.mc = nil
|
|
return mc.handleErrorPacket(data)
|
|
}
|
|
|
|
// RowSet Packet
|
|
var (
|
|
n int
|
|
isNull bool
|
|
pos int = 0
|
|
)
|
|
|
|
for i := range dest {
|
|
// Read bytes and convert to string
|
|
dest[i], isNull, n, err = readLengthEncodedString(data[pos:])
|
|
pos += n
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if isNull {
|
|
dest[i] = nil
|
|
continue
|
|
}
|
|
|
|
if !mc.parseTime {
|
|
continue
|
|
}
|
|
|
|
// Parse time field
|
|
switch rows.rs.columns[i].fieldType {
|
|
case fieldTypeTimestamp,
|
|
fieldTypeDateTime,
|
|
fieldTypeDate,
|
|
fieldTypeNewDate:
|
|
if dest[i], err = parseDateTime(dest[i].([]byte), mc.cfg.Loc); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Reads Packets until EOF-Packet or an Error appears. Returns count of Packets read
|
|
func (mc *mysqlConn) readUntilEOF() error {
|
|
for {
|
|
data, err := mc.readPacket()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
switch data[0] {
|
|
case iERR:
|
|
return mc.handleErrorPacket(data)
|
|
case iEOF:
|
|
if len(data) == 5 {
|
|
mc.status = readStatus(data[3:])
|
|
}
|
|
return nil
|
|
}
|
|
}
|
|
}
|
|
|
|
/******************************************************************************
|
|
* Prepared Statements *
|
|
******************************************************************************/
|
|
|
|
// Prepare Result Packets
|
|
// http://dev.mysql.com/doc/internals/en/com-stmt-prepare-response.html
|
|
func (stmt *mysqlStmt) readPrepareResultPacket() (uint16, error) {
|
|
data, err := stmt.mc.readPacket()
|
|
if err == nil {
|
|
// packet indicator [1 byte]
|
|
if data[0] != iOK {
|
|
return 0, stmt.mc.handleErrorPacket(data)
|
|
}
|
|
|
|
// statement id [4 bytes]
|
|
stmt.id = binary.LittleEndian.Uint32(data[1:5])
|
|
|
|
// Column count [16 bit uint]
|
|
columnCount := binary.LittleEndian.Uint16(data[5:7])
|
|
|
|
// Param count [16 bit uint]
|
|
stmt.paramCount = int(binary.LittleEndian.Uint16(data[7:9]))
|
|
|
|
// Reserved [8 bit]
|
|
|
|
// Warning count [16 bit uint]
|
|
|
|
return columnCount, nil
|
|
}
|
|
return 0, err
|
|
}
|
|
|
|
// http://dev.mysql.com/doc/internals/en/com-stmt-send-long-data.html
|
|
func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error {
|
|
maxLen := stmt.mc.maxAllowedPacket - 1
|
|
pktLen := maxLen
|
|
|
|
// After the header (bytes 0-3) follows before the data:
|
|
// 1 byte command
|
|
// 4 bytes stmtID
|
|
// 2 bytes paramID
|
|
const dataOffset = 1 + 4 + 2
|
|
|
|
// Cannot use the write buffer since
|
|
// a) the buffer is too small
|
|
// b) it is in use
|
|
data := make([]byte, 4+1+4+2+len(arg))
|
|
|
|
copy(data[4+dataOffset:], arg)
|
|
|
|
for argLen := len(arg); argLen > 0; argLen -= pktLen - dataOffset {
|
|
if dataOffset+argLen < maxLen {
|
|
pktLen = dataOffset + argLen
|
|
}
|
|
|
|
stmt.mc.sequence = 0
|
|
// Add command byte [1 byte]
|
|
data[4] = comStmtSendLongData
|
|
|
|
// Add stmtID [32 bit]
|
|
data[5] = byte(stmt.id)
|
|
data[6] = byte(stmt.id >> 8)
|
|
data[7] = byte(stmt.id >> 16)
|
|
data[8] = byte(stmt.id >> 24)
|
|
|
|
// Add paramID [16 bit]
|
|
data[9] = byte(paramID)
|
|
data[10] = byte(paramID >> 8)
|
|
|
|
// Send CMD packet
|
|
err := stmt.mc.writePacket(data[:4+pktLen])
|
|
if err == nil {
|
|
data = data[pktLen-dataOffset:]
|
|
continue
|
|
}
|
|
return err
|
|
|
|
}
|
|
|
|
// Reset Packet Sequence
|
|
stmt.mc.sequence = 0
|
|
return nil
|
|
}
|
|
|
|
// Execute Prepared Statement
|
|
// http://dev.mysql.com/doc/internals/en/com-stmt-execute.html
|
|
func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
|
|
if len(args) != stmt.paramCount {
|
|
return fmt.Errorf(
|
|
"argument count mismatch (got: %d; has: %d)",
|
|
len(args),
|
|
stmt.paramCount,
|
|
)
|
|
}
|
|
|
|
const minPktLen = 4 + 1 + 4 + 1 + 4
|
|
mc := stmt.mc
|
|
|
|
// Determine threshold dynamically to avoid packet size shortage.
|
|
longDataSize := mc.maxAllowedPacket / (stmt.paramCount + 1)
|
|
if longDataSize < 64 {
|
|
longDataSize = 64
|
|
}
|
|
|
|
// Reset packet-sequence
|
|
mc.sequence = 0
|
|
|
|
var data []byte
|
|
var err error
|
|
|
|
if len(args) == 0 {
|
|
data, err = mc.buf.takeBuffer(minPktLen)
|
|
} else {
|
|
data, err = mc.buf.takeCompleteBuffer()
|
|
// In this case the len(data) == cap(data) which is used to optimise the flow below.
|
|
}
|
|
if err != nil {
|
|
// cannot take the buffer. Something must be wrong with the connection
|
|
errLog.Print(err)
|
|
return errBadConnNoWrite
|
|
}
|
|
|
|
// command [1 byte]
|
|
data[4] = comStmtExecute
|
|
|
|
// statement_id [4 bytes]
|
|
data[5] = byte(stmt.id)
|
|
data[6] = byte(stmt.id >> 8)
|
|
data[7] = byte(stmt.id >> 16)
|
|
data[8] = byte(stmt.id >> 24)
|
|
|
|
// flags (0: CURSOR_TYPE_NO_CURSOR) [1 byte]
|
|
data[9] = 0x00
|
|
|
|
// iteration_count (uint32(1)) [4 bytes]
|
|
data[10] = 0x01
|
|
data[11] = 0x00
|
|
data[12] = 0x00
|
|
data[13] = 0x00
|
|
|
|
if len(args) > 0 {
|
|
pos := minPktLen
|
|
|
|
var nullMask []byte
|
|
if maskLen, typesLen := (len(args)+7)/8, 1+2*len(args); pos+maskLen+typesLen >= cap(data) {
|
|
// buffer has to be extended but we don't know by how much so
|
|
// we depend on append after all data with known sizes fit.
|
|
// We stop at that because we deal with a lot of columns here
|
|
// which makes the required allocation size hard to guess.
|
|
tmp := make([]byte, pos+maskLen+typesLen)
|
|
copy(tmp[:pos], data[:pos])
|
|
data = tmp
|
|
nullMask = data[pos : pos+maskLen]
|
|
// No need to clean nullMask as make ensures that.
|
|
pos += maskLen
|
|
} else {
|
|
nullMask = data[pos : pos+maskLen]
|
|
for i := range nullMask {
|
|
nullMask[i] = 0
|
|
}
|
|
pos += maskLen
|
|
}
|
|
|
|
// newParameterBoundFlag 1 [1 byte]
|
|
data[pos] = 0x01
|
|
pos++
|
|
|
|
// type of each parameter [len(args)*2 bytes]
|
|
paramTypes := data[pos:]
|
|
pos += len(args) * 2
|
|
|
|
// value of each parameter [n bytes]
|
|
paramValues := data[pos:pos]
|
|
valuesCap := cap(paramValues)
|
|
|
|
for i, arg := range args {
|
|
// build NULL-bitmap
|
|
if arg == nil {
|
|
nullMask[i/8] |= 1 << (uint(i) & 7)
|
|
paramTypes[i+i] = byte(fieldTypeNULL)
|
|
paramTypes[i+i+1] = 0x00
|
|
continue
|
|
}
|
|
|
|
if v, ok := arg.(json.RawMessage); ok {
|
|
arg = []byte(v)
|
|
}
|
|
// cache types and values
|
|
switch v := arg.(type) {
|
|
case int64:
|
|
paramTypes[i+i] = byte(fieldTypeLongLong)
|
|
paramTypes[i+i+1] = 0x00
|
|
|
|
if cap(paramValues)-len(paramValues)-8 >= 0 {
|
|
paramValues = paramValues[:len(paramValues)+8]
|
|
binary.LittleEndian.PutUint64(
|
|
paramValues[len(paramValues)-8:],
|
|
uint64(v),
|
|
)
|
|
} else {
|
|
paramValues = append(paramValues,
|
|
uint64ToBytes(uint64(v))...,
|
|
)
|
|
}
|
|
|
|
case uint64:
|
|
paramTypes[i+i] = byte(fieldTypeLongLong)
|
|
paramTypes[i+i+1] = 0x80 // type is unsigned
|
|
|
|
if cap(paramValues)-len(paramValues)-8 >= 0 {
|
|
paramValues = paramValues[:len(paramValues)+8]
|
|
binary.LittleEndian.PutUint64(
|
|
paramValues[len(paramValues)-8:],
|
|
uint64(v),
|
|
)
|
|
} else {
|
|
paramValues = append(paramValues,
|
|
uint64ToBytes(uint64(v))...,
|
|
)
|
|
}
|
|
|
|
case float64:
|
|
paramTypes[i+i] = byte(fieldTypeDouble)
|
|
paramTypes[i+i+1] = 0x00
|
|
|
|
if cap(paramValues)-len(paramValues)-8 >= 0 {
|
|
paramValues = paramValues[:len(paramValues)+8]
|
|
binary.LittleEndian.PutUint64(
|
|
paramValues[len(paramValues)-8:],
|
|
math.Float64bits(v),
|
|
)
|
|
} else {
|
|
paramValues = append(paramValues,
|
|
uint64ToBytes(math.Float64bits(v))...,
|
|
)
|
|
}
|
|
|
|
case bool:
|
|
paramTypes[i+i] = byte(fieldTypeTiny)
|
|
paramTypes[i+i+1] = 0x00
|
|
|
|
if v {
|
|
paramValues = append(paramValues, 0x01)
|
|
} else {
|
|
paramValues = append(paramValues, 0x00)
|
|
}
|
|
|
|
case []byte:
|
|
// Common case (non-nil value) first
|
|
if v != nil {
|
|
paramTypes[i+i] = byte(fieldTypeString)
|
|
paramTypes[i+i+1] = 0x00
|
|
|
|
if len(v) < longDataSize {
|
|
paramValues = appendLengthEncodedInteger(paramValues,
|
|
uint64(len(v)),
|
|
)
|
|
paramValues = append(paramValues, v...)
|
|
} else {
|
|
if err := stmt.writeCommandLongData(i, v); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
continue
|
|
}
|
|
|
|
// Handle []byte(nil) as a NULL value
|
|
nullMask[i/8] |= 1 << (uint(i) & 7)
|
|
paramTypes[i+i] = byte(fieldTypeNULL)
|
|
paramTypes[i+i+1] = 0x00
|
|
|
|
case string:
|
|
paramTypes[i+i] = byte(fieldTypeString)
|
|
paramTypes[i+i+1] = 0x00
|
|
|
|
if len(v) < longDataSize {
|
|
paramValues = appendLengthEncodedInteger(paramValues,
|
|
uint64(len(v)),
|
|
)
|
|
paramValues = append(paramValues, v...)
|
|
} else {
|
|
if err := stmt.writeCommandLongData(i, []byte(v)); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
case time.Time:
|
|
paramTypes[i+i] = byte(fieldTypeString)
|
|
paramTypes[i+i+1] = 0x00
|
|
|
|
var a [64]byte
|
|
var b = a[:0]
|
|
|
|
if v.IsZero() {
|
|
b = append(b, "0000-00-00"...)
|
|
} else {
|
|
b, err = appendDateTime(b, v.In(mc.cfg.Loc))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
paramValues = appendLengthEncodedInteger(paramValues,
|
|
uint64(len(b)),
|
|
)
|
|
paramValues = append(paramValues, b...)
|
|
|
|
default:
|
|
return fmt.Errorf("cannot convert type: %T", arg)
|
|
}
|
|
}
|
|
|
|
// Check if param values exceeded the available buffer
|
|
// In that case we must build the data packet with the new values buffer
|
|
if valuesCap != cap(paramValues) {
|
|
data = append(data[:pos], paramValues...)
|
|
if err = mc.buf.store(data); err != nil {
|
|
errLog.Print(err)
|
|
return errBadConnNoWrite
|
|
}
|
|
}
|
|
|
|
pos += len(paramValues)
|
|
data = data[:pos]
|
|
}
|
|
|
|
return mc.writePacket(data)
|
|
}
|
|
|
|
func (mc *mysqlConn) discardResults() error {
|
|
for mc.status&statusMoreResultsExists != 0 {
|
|
resLen, err := mc.readResultSetHeaderPacket()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if resLen > 0 {
|
|
// columns
|
|
if err := mc.readUntilEOF(); err != nil {
|
|
return err
|
|
}
|
|
// rows
|
|
if err := mc.readUntilEOF(); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// http://dev.mysql.com/doc/internals/en/binary-protocol-resultset-row.html
|
|
func (rows *binaryRows) readRow(dest []driver.Value) error {
|
|
data, err := rows.mc.readPacket()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// packet indicator [1 byte]
|
|
if data[0] != iOK {
|
|
// EOF Packet
|
|
if data[0] == iEOF && len(data) == 5 {
|
|
rows.mc.status = readStatus(data[3:])
|
|
rows.rs.done = true
|
|
if !rows.HasNextResultSet() {
|
|
rows.mc = nil
|
|
}
|
|
return io.EOF
|
|
}
|
|
mc := rows.mc
|
|
rows.mc = nil
|
|
|
|
// Error otherwise
|
|
return mc.handleErrorPacket(data)
|
|
}
|
|
|
|
// NULL-bitmap, [(column-count + 7 + 2) / 8 bytes]
|
|
pos := 1 + (len(dest)+7+2)>>3
|
|
nullMask := data[1:pos]
|
|
|
|
for i := range dest {
|
|
// Field is NULL
|
|
// (byte >> bit-pos) % 2 == 1
|
|
if ((nullMask[(i+2)>>3] >> uint((i+2)&7)) & 1) == 1 {
|
|
dest[i] = nil
|
|
continue
|
|
}
|
|
|
|
// Convert to byte-coded string
|
|
switch rows.rs.columns[i].fieldType {
|
|
case fieldTypeNULL:
|
|
dest[i] = nil
|
|
continue
|
|
|
|
// Numeric Types
|
|
case fieldTypeTiny:
|
|
if rows.rs.columns[i].flags&flagUnsigned != 0 {
|
|
dest[i] = int64(data[pos])
|
|
} else {
|
|
dest[i] = int64(int8(data[pos]))
|
|
}
|
|
pos++
|
|
continue
|
|
|
|
case fieldTypeShort, fieldTypeYear:
|
|
if rows.rs.columns[i].flags&flagUnsigned != 0 {
|
|
dest[i] = int64(binary.LittleEndian.Uint16(data[pos : pos+2]))
|
|
} else {
|
|
dest[i] = int64(int16(binary.LittleEndian.Uint16(data[pos : pos+2])))
|
|
}
|
|
pos += 2
|
|
continue
|
|
|
|
case fieldTypeInt24, fieldTypeLong:
|
|
if rows.rs.columns[i].flags&flagUnsigned != 0 {
|
|
dest[i] = int64(binary.LittleEndian.Uint32(data[pos : pos+4]))
|
|
} else {
|
|
dest[i] = int64(int32(binary.LittleEndian.Uint32(data[pos : pos+4])))
|
|
}
|
|
pos += 4
|
|
continue
|
|
|
|
case fieldTypeLongLong:
|
|
if rows.rs.columns[i].flags&flagUnsigned != 0 {
|
|
val := binary.LittleEndian.Uint64(data[pos : pos+8])
|
|
if val > math.MaxInt64 {
|
|
dest[i] = uint64ToString(val)
|
|
} else {
|
|
dest[i] = int64(val)
|
|
}
|
|
} else {
|
|
dest[i] = int64(binary.LittleEndian.Uint64(data[pos : pos+8]))
|
|
}
|
|
pos += 8
|
|
continue
|
|
|
|
case fieldTypeFloat:
|
|
dest[i] = math.Float32frombits(binary.LittleEndian.Uint32(data[pos : pos+4]))
|
|
pos += 4
|
|
continue
|
|
|
|
case fieldTypeDouble:
|
|
dest[i] = math.Float64frombits(binary.LittleEndian.Uint64(data[pos : pos+8]))
|
|
pos += 8
|
|
continue
|
|
|
|
// Length coded Binary Strings
|
|
case fieldTypeDecimal, fieldTypeNewDecimal, fieldTypeVarChar,
|
|
fieldTypeBit, fieldTypeEnum, fieldTypeSet, fieldTypeTinyBLOB,
|
|
fieldTypeMediumBLOB, fieldTypeLongBLOB, fieldTypeBLOB,
|
|
fieldTypeVarString, fieldTypeString, fieldTypeGeometry, fieldTypeJSON:
|
|
var isNull bool
|
|
var n int
|
|
dest[i], isNull, n, err = readLengthEncodedString(data[pos:])
|
|
pos += n
|
|
if err == nil {
|
|
if !isNull {
|
|
continue
|
|
} else {
|
|
dest[i] = nil
|
|
continue
|
|
}
|
|
}
|
|
return err
|
|
|
|
case
|
|
fieldTypeDate, fieldTypeNewDate, // Date YYYY-MM-DD
|
|
fieldTypeTime, // Time [-][H]HH:MM:SS[.fractal]
|
|
fieldTypeTimestamp, fieldTypeDateTime: // Timestamp YYYY-MM-DD HH:MM:SS[.fractal]
|
|
|
|
num, isNull, n := readLengthEncodedInteger(data[pos:])
|
|
pos += n
|
|
|
|
switch {
|
|
case isNull:
|
|
dest[i] = nil
|
|
continue
|
|
case rows.rs.columns[i].fieldType == fieldTypeTime:
|
|
// database/sql does not support an equivalent to TIME, return a string
|
|
var dstlen uint8
|
|
switch decimals := rows.rs.columns[i].decimals; decimals {
|
|
case 0x00, 0x1f:
|
|
dstlen = 8
|
|
case 1, 2, 3, 4, 5, 6:
|
|
dstlen = 8 + 1 + decimals
|
|
default:
|
|
return fmt.Errorf(
|
|
"protocol error, illegal decimals value %d",
|
|
rows.rs.columns[i].decimals,
|
|
)
|
|
}
|
|
dest[i], err = formatBinaryTime(data[pos:pos+int(num)], dstlen)
|
|
case rows.mc.parseTime:
|
|
dest[i], err = parseBinaryDateTime(num, data[pos:], rows.mc.cfg.Loc)
|
|
default:
|
|
var dstlen uint8
|
|
if rows.rs.columns[i].fieldType == fieldTypeDate {
|
|
dstlen = 10
|
|
} else {
|
|
switch decimals := rows.rs.columns[i].decimals; decimals {
|
|
case 0x00, 0x1f:
|
|
dstlen = 19
|
|
case 1, 2, 3, 4, 5, 6:
|
|
dstlen = 19 + 1 + decimals
|
|
default:
|
|
return fmt.Errorf(
|
|
"protocol error, illegal decimals value %d",
|
|
rows.rs.columns[i].decimals,
|
|
)
|
|
}
|
|
}
|
|
dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], dstlen)
|
|
}
|
|
|
|
if err == nil {
|
|
pos += int(num)
|
|
continue
|
|
} else {
|
|
return err
|
|
}
|
|
|
|
// Please report if this happens!
|
|
default:
|
|
return fmt.Errorf("unknown field type %d", rows.rs.columns[i].fieldType)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|