dialect_common.go 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152
  1. package gorm
  2. import (
  3. "database/sql"
  4. "fmt"
  5. "reflect"
  6. "regexp"
  7. "strconv"
  8. "strings"
  9. "time"
  10. )
  11. // DefaultForeignKeyNamer contains the default foreign key name generator method
  12. type DefaultForeignKeyNamer struct {
  13. }
  14. type commonDialect struct {
  15. db *sql.DB
  16. DefaultForeignKeyNamer
  17. }
  18. func init() {
  19. RegisterDialect("common", &commonDialect{})
  20. }
  21. func (commonDialect) GetName() string {
  22. return "common"
  23. }
  24. func (s *commonDialect) SetDB(db *sql.DB) {
  25. s.db = db
  26. }
  27. func (commonDialect) BindVar(i int) string {
  28. return "$$" // ?
  29. }
  30. func (commonDialect) Quote(key string) string {
  31. return fmt.Sprintf(`"%s"`, key)
  32. }
  33. func (s *commonDialect) DataTypeOf(field *StructField) string {
  34. var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s)
  35. if sqlType == "" {
  36. switch dataValue.Kind() {
  37. case reflect.Bool:
  38. sqlType = "BOOLEAN"
  39. case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
  40. if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok {
  41. sqlType = "INTEGER AUTO_INCREMENT"
  42. } else {
  43. sqlType = "INTEGER"
  44. }
  45. case reflect.Int64, reflect.Uint64:
  46. if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok {
  47. sqlType = "BIGINT AUTO_INCREMENT"
  48. } else {
  49. sqlType = "BIGINT"
  50. }
  51. case reflect.Float32, reflect.Float64:
  52. sqlType = "FLOAT"
  53. case reflect.String:
  54. if size > 0 && size < 65532 {
  55. sqlType = fmt.Sprintf("VARCHAR(%d)", size)
  56. } else {
  57. sqlType = "VARCHAR(65532)"
  58. }
  59. case reflect.Struct:
  60. if _, ok := dataValue.Interface().(time.Time); ok {
  61. sqlType = "TIMESTAMP"
  62. }
  63. default:
  64. if _, ok := dataValue.Interface().([]byte); ok {
  65. if size > 0 && size < 65532 {
  66. sqlType = fmt.Sprintf("BINARY(%d)", size)
  67. } else {
  68. sqlType = "BINARY(65532)"
  69. }
  70. }
  71. }
  72. }
  73. if sqlType == "" {
  74. panic(fmt.Sprintf("invalid sql type %s (%s) for commonDialect", dataValue.Type().Name(), dataValue.Kind().String()))
  75. }
  76. if strings.TrimSpace(additionalType) == "" {
  77. return sqlType
  78. }
  79. return fmt.Sprintf("%v %v", sqlType, additionalType)
  80. }
  81. func (s commonDialect) HasIndex(tableName string, indexName string) bool {
  82. var count int
  83. s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", s.CurrentDatabase(), tableName, indexName).Scan(&count)
  84. return count > 0
  85. }
  86. func (s commonDialect) RemoveIndex(tableName string, indexName string) error {
  87. _, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v", indexName))
  88. return err
  89. }
  90. func (s commonDialect) HasForeignKey(tableName string, foreignKeyName string) bool {
  91. return false
  92. }
  93. func (s commonDialect) HasTable(tableName string) bool {
  94. var count int
  95. s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", s.CurrentDatabase(), tableName).Scan(&count)
  96. return count > 0
  97. }
  98. func (s commonDialect) HasColumn(tableName string, columnName string) bool {
  99. var count int
  100. s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", s.CurrentDatabase(), tableName, columnName).Scan(&count)
  101. return count > 0
  102. }
  103. func (s commonDialect) CurrentDatabase() (name string) {
  104. s.db.QueryRow("SELECT DATABASE()").Scan(&name)
  105. return
  106. }
  107. func (commonDialect) LimitAndOffsetSQL(limit, offset interface{}) (sql string) {
  108. if limit != nil {
  109. if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit > 0 {
  110. sql += fmt.Sprintf(" LIMIT %d", parsedLimit)
  111. }
  112. }
  113. if offset != nil {
  114. if parsedOffset, err := strconv.ParseInt(fmt.Sprint(offset), 0, 0); err == nil && parsedOffset > 0 {
  115. sql += fmt.Sprintf(" OFFSET %d", parsedOffset)
  116. }
  117. }
  118. return
  119. }
  120. func (commonDialect) SelectFromDummyTable() string {
  121. return ""
  122. }
  123. func (commonDialect) LastInsertIDReturningSuffix(tableName, columnName string) string {
  124. return ""
  125. }
  126. func (DefaultForeignKeyNamer) BuildForeignKeyName(tableName, field, dest string) string {
  127. keyName := fmt.Sprintf("%s_%s_%s_foreign", tableName, field, dest)
  128. keyName = regexp.MustCompile("(_*[^a-zA-Z]+_*|_+)").ReplaceAllString(keyName, "_")
  129. return keyName
  130. }