dialect_mysql.go 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146
  1. package gorm
  2. import (
  3. "crypto/sha1"
  4. "fmt"
  5. "reflect"
  6. "regexp"
  7. "strings"
  8. "time"
  9. "unicode/utf8"
  10. )
  11. type mysql struct {
  12. commonDialect
  13. }
  14. func init() {
  15. RegisterDialect("mysql", &mysql{})
  16. }
  17. func (mysql) GetName() string {
  18. return "mysql"
  19. }
  20. func (mysql) Quote(key string) string {
  21. return fmt.Sprintf("`%s`", key)
  22. }
  23. // Get Data Type for MySQL Dialect
  24. func (s *mysql) DataTypeOf(field *StructField) string {
  25. var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s)
  26. // MySQL allows only one auto increment column per table, and it must
  27. // be a KEY column.
  28. if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok {
  29. if _, ok = field.TagSettings["INDEX"]; !ok && !field.IsPrimaryKey {
  30. delete(field.TagSettings, "AUTO_INCREMENT")
  31. }
  32. }
  33. if sqlType == "" {
  34. switch dataValue.Kind() {
  35. case reflect.Bool:
  36. sqlType = "boolean"
  37. case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32:
  38. if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
  39. field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
  40. sqlType = "int AUTO_INCREMENT"
  41. } else {
  42. sqlType = "int"
  43. }
  44. case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
  45. if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
  46. field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
  47. sqlType = "int unsigned AUTO_INCREMENT"
  48. } else {
  49. sqlType = "int unsigned"
  50. }
  51. case reflect.Int64:
  52. if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
  53. field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
  54. sqlType = "bigint AUTO_INCREMENT"
  55. } else {
  56. sqlType = "bigint"
  57. }
  58. case reflect.Uint64:
  59. if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
  60. field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
  61. sqlType = "bigint unsigned AUTO_INCREMENT"
  62. } else {
  63. sqlType = "bigint unsigned"
  64. }
  65. case reflect.Float32, reflect.Float64:
  66. sqlType = "double"
  67. case reflect.String:
  68. if size > 0 && size < 65532 {
  69. sqlType = fmt.Sprintf("varchar(%d)", size)
  70. } else {
  71. sqlType = "longtext"
  72. }
  73. case reflect.Struct:
  74. if _, ok := dataValue.Interface().(time.Time); ok {
  75. if _, ok := field.TagSettings["NOT NULL"]; ok {
  76. sqlType = "timestamp"
  77. } else {
  78. sqlType = "timestamp NULL"
  79. }
  80. }
  81. default:
  82. if _, ok := dataValue.Interface().([]byte); ok {
  83. if size > 0 && size < 65532 {
  84. sqlType = fmt.Sprintf("varbinary(%d)", size)
  85. } else {
  86. sqlType = "longblob"
  87. }
  88. }
  89. }
  90. }
  91. if sqlType == "" {
  92. panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", dataValue.Type().Name(), dataValue.Kind().String()))
  93. }
  94. if strings.TrimSpace(additionalType) == "" {
  95. return sqlType
  96. }
  97. return fmt.Sprintf("%v %v", sqlType, additionalType)
  98. }
  99. func (s mysql) RemoveIndex(tableName string, indexName string) error {
  100. _, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName)))
  101. return err
  102. }
  103. func (s mysql) HasForeignKey(tableName string, foreignKeyName string) bool {
  104. var count int
  105. s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE CONSTRAINT_SCHEMA=? AND TABLE_NAME=? AND CONSTRAINT_NAME=? AND CONSTRAINT_TYPE='FOREIGN KEY'", s.CurrentDatabase(), tableName, foreignKeyName).Scan(&count)
  106. return count > 0
  107. }
  108. func (s mysql) CurrentDatabase() (name string) {
  109. s.db.QueryRow("SELECT DATABASE()").Scan(&name)
  110. return
  111. }
  112. func (mysql) SelectFromDummyTable() string {
  113. return "FROM DUAL"
  114. }
  115. func (s mysql) BuildForeignKeyName(tableName, field, dest string) string {
  116. keyName := s.commonDialect.BuildForeignKeyName(tableName, field, dest)
  117. if utf8.RuneCountInString(keyName) <= 64 {
  118. return keyName
  119. }
  120. h := sha1.New()
  121. h.Write([]byte(keyName))
  122. bs := h.Sum(nil)
  123. // sha1 is 40 digits, keep first 24 characters of destination
  124. destRunes := []rune(regexp.MustCompile("(_*[^a-zA-Z]+_*|_+)").ReplaceAllString(dest, "_"))
  125. if len(destRunes) > 24 {
  126. destRunes = destRunes[:24]
  127. }
  128. return fmt.Sprintf("%s%x", string(destRunes), bs)
  129. }