statement.go 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178
  1. // Go MySQL Driver - A MySQL-Driver for Go's database/sql package
  2. //
  3. // Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
  4. //
  5. // This Source Code Form is subject to the terms of the Mozilla Public
  6. // License, v. 2.0. If a copy of the MPL was not distributed with this file,
  7. // You can obtain one at http://mozilla.org/MPL/2.0/.
  8. package mysql
  9. import (
  10. "database/sql/driver"
  11. "fmt"
  12. "io"
  13. "reflect"
  14. "strconv"
  15. )
  16. type mysqlStmt struct {
  17. mc *mysqlConn
  18. id uint32
  19. paramCount int
  20. }
  // Close sends comStmtClose for this statement's id over its connection and
// then detaches the statement from that connection, returning any write
// error.
//
// driver.Stmt.Close can be called more than once, so this has to be
// idempotent (see Issue #450 and golang/go#16019). When the statement is
// already detached or its connection is closed, nothing is sent and
// driver.ErrBadConn is returned.
func (stmt *mysqlStmt) Close() error {
	conn := stmt.mc
	detached := conn == nil || conn.closed.IsSet()
	if detached {
		return driver.ErrBadConn
	}

	writeErr := conn.writeCommandPacketUint32(comStmtClose, stmt.id)
	stmt.mc = nil
	return writeErr
}
  33. func (stmt *mysqlStmt) NumInput() int {
  34. return stmt.paramCount
  35. }
  36. func (stmt *mysqlStmt) ColumnConverter(idx int) driver.ValueConverter {
  37. return converter{}
  38. }
  // Exec executes the prepared statement with args and returns the affected
// row count and last insert ID recorded on the connection.
//
// Any result sets the statement produces are read through (columns, then
// rows) and remaining results are discarded, so the connection is left
// ready for the next command.
//
// It returns driver.ErrBadConn if the statement was already closed (Close
// sets stmt.mc to nil) or its connection is closed. If the execute packet
// cannot be written, the error is passed through mc.markBadConn.
func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
	// Guard use-after-Close: dereferencing a nil mc below would panic.
	if stmt.mc == nil {
		return nil, driver.ErrBadConn
	}
	if stmt.mc.closed.IsSet() {
		errLog.Print(ErrInvalidConn)
		return nil, driver.ErrBadConn
	}

	// Send command
	if err := stmt.writeExecutePacket(args); err != nil {
		return nil, stmt.mc.markBadConn(err)
	}

	mc := stmt.mc
	// Clear the counters so the values read back at the end reflect only this
	// execution; presumably they are filled in while the result packets below
	// are read — TODO confirm in packets.go.
	mc.affectedRows = 0
	mc.insertId = 0

	// Read Result
	resLen, err := mc.readResultSetHeaderPacket()
	if err != nil {
		return nil, err
	}

	if resLen > 0 {
		// Columns
		if err = mc.readUntilEOF(); err != nil {
			return nil, err
		}
		// Rows
		if err = mc.readUntilEOF(); err != nil {
			return nil, err
		}
	}

	if err = mc.discardResults(); err != nil {
		return nil, err
	}

	return &mysqlResult{
		affectedRows: int64(mc.affectedRows),
		insertId:     int64(mc.insertId),
	}, nil
}
  // Query executes the prepared statement with args and returns the
// resulting rows.
//
// On error it returns an untyped nil driver.Rows. Returning stmt.query's
// *binaryRows directly would wrap a nil pointer in a non-nil interface, and
// on a column-read failure would hand back a partially initialised row set
// alongside the error.
func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) {
	rows, err := stmt.query(args)
	if err != nil {
		return nil, err
	}
	return rows, nil
}
  // query executes the prepared statement with args and returns the rows it
// produced.
//
// When the result-set header reports columns, the column definitions are
// read and the returned rows stay attached to the connection. If reading
// the columns fails, the rows are returned together with that error;
// callers must check err first. When no columns are reported, the rows are
// marked done and not attached to the connection.
//
// It returns driver.ErrBadConn if the statement was already closed (Close
// sets stmt.mc to nil) or its connection is closed.
func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) {
	// Guard use-after-Close: dereferencing a nil mc below would panic.
	if stmt.mc == nil {
		return nil, driver.ErrBadConn
	}
	if stmt.mc.closed.IsSet() {
		errLog.Print(ErrInvalidConn)
		return nil, driver.ErrBadConn
	}

	// Send command
	if err := stmt.writeExecutePacket(args); err != nil {
		return nil, stmt.mc.markBadConn(err)
	}

	mc := stmt.mc

	// Read Result
	resLen, err := mc.readResultSetHeaderPacket()
	if err != nil {
		return nil, err
	}

	rows := new(binaryRows)
	if resLen > 0 {
		rows.mc = mc
		rows.rs.columns, err = mc.readColumns(resLen)
		return rows, err
	}

	// No columns: finish this result set and advance. Both nil and io.EOF
	// from NextResultSet are treated as success.
	rows.rs.done = true
	switch err := rows.NextResultSet(); err {
	case nil, io.EOF:
		return rows, nil
	default:
		return nil, err
	}
}
  109. type converter struct{}
  110. func (c converter) ConvertValue(v interface{}) (driver.Value, error) {
  111. if driver.IsValue(v) {
  112. return v, nil
  113. }
  114. if v != nil {
  115. if valuer, ok := v.(driver.Valuer); ok {
  116. return valuer.Value()
  117. }
  118. }
  119. rv := reflect.ValueOf(v)
  120. switch rv.Kind() {
  121. case reflect.Ptr:
  122. // indirect pointers
  123. if rv.IsNil() {
  124. return nil, nil
  125. }
  126. return c.ConvertValue(rv.Elem().Interface())
  127. case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
  128. return rv.Int(), nil
  129. case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32:
  130. return int64(rv.Uint()), nil
  131. case reflect.Uint64:
  132. u64 := rv.Uint()
  133. if u64 >= 1<<63 {
  134. return strconv.FormatUint(u64, 10), nil
  135. }
  136. return int64(u64), nil
  137. case reflect.Float32, reflect.Float64:
  138. return rv.Float(), nil
  139. case reflect.Bool:
  140. return rv.Bool(), nil
  141. case reflect.Slice:
  142. ek := rv.Type().Elem().Kind()
  143. if ek == reflect.Uint8 {
  144. return rv.Bytes(), nil
  145. }
  146. return nil, fmt.Errorf("unsupported type %T, a slice of %s", v, ek)
  147. case reflect.String:
  148. return rv.String(), nil
  149. }
  150. return nil, fmt.Errorf("unsupported type %T, a %s", v, rv.Kind())
  151. }