|
@@ -1,15 +1,16 @@
|
|
package gorm
|
|
package gorm
|
|
|
|
|
|
import (
|
|
import (
|
|
- "bytes"
|
|
|
|
"database/sql"
|
|
"database/sql"
|
|
"database/sql/driver"
|
|
"database/sql/driver"
|
|
"errors"
|
|
"errors"
|
|
"fmt"
|
|
"fmt"
|
|
- "reflect"
|
|
|
|
"regexp"
|
|
"regexp"
|
|
|
|
+ "strconv"
|
|
"strings"
|
|
"strings"
|
|
"time"
|
|
"time"
|
|
|
|
+
|
|
|
|
+ "reflect"
|
|
)
|
|
)
|
|
|
|
|
|
// Scope contain current operation's information when you perform any operation on the database
|
|
// Scope contain current operation's information when you perform any operation on the database
|
|
@@ -115,9 +116,6 @@ func (scope *Scope) Fields() []*Field {
|
|
if isStruct {
|
|
if isStruct {
|
|
fieldValue := indirectScopeValue
|
|
fieldValue := indirectScopeValue
|
|
for _, name := range structField.Names {
|
|
for _, name := range structField.Names {
|
|
- if fieldValue.Kind() == reflect.Ptr && fieldValue.IsNil() {
|
|
|
|
- fieldValue.Set(reflect.New(fieldValue.Type().Elem()))
|
|
|
|
- }
|
|
|
|
fieldValue = reflect.Indirect(fieldValue).FieldByName(name)
|
|
fieldValue = reflect.Indirect(fieldValue).FieldByName(name)
|
|
}
|
|
}
|
|
fields = append(fields, &Field{StructField: structField, Field: fieldValue, IsBlank: isBlank(fieldValue)})
|
|
fields = append(fields, &Field{StructField: structField, Field: fieldValue, IsBlank: isBlank(fieldValue)})
|
|
@@ -462,7 +460,7 @@ func (scope *Scope) callMethod(methodName string, reflectValue reflect.Value) {
|
|
var (
|
|
var (
|
|
columnRegexp = regexp.MustCompile("^[a-zA-Z\\d]+(\\.[a-zA-Z\\d]+)*$") // only match string like `name`, `users.name`
|
|
columnRegexp = regexp.MustCompile("^[a-zA-Z\\d]+(\\.[a-zA-Z\\d]+)*$") // only match string like `name`, `users.name`
|
|
isNumberRegexp = regexp.MustCompile("^\\s*\\d+\\s*$") // match if string is number
|
|
isNumberRegexp = regexp.MustCompile("^\\s*\\d+\\s*$") // match if string is number
|
|
- comparisonRegexp = regexp.MustCompile("(?i) (=|<>|(>|<)(=?)|LIKE|IS|IN) ")
|
|
|
|
|
|
+ comparisonRegexp = regexp.MustCompile("(?i) (=|<>|>|<|LIKE|IS|IN) ")
|
|
countingQueryRegexp = regexp.MustCompile("(?i)^count(.+)$")
|
|
countingQueryRegexp = regexp.MustCompile("(?i)^count(.+)$")
|
|
)
|
|
)
|
|
|
|
|
|
@@ -523,143 +521,134 @@ func (scope *Scope) primaryCondition(value interface{}) string {
|
|
return fmt.Sprintf("(%v.%v = %v)", scope.QuotedTableName(), scope.Quote(scope.PrimaryKey()), value)
|
|
return fmt.Sprintf("(%v.%v = %v)", scope.QuotedTableName(), scope.Quote(scope.PrimaryKey()), value)
|
|
}
|
|
}
|
|
|
|
|
|
-func (scope *Scope) buildCondition(clause map[string]interface{}, include bool) (str string) {
|
|
|
|
- var (
|
|
|
|
- quotedTableName = scope.QuotedTableName()
|
|
|
|
- quotedPrimaryKey = scope.Quote(scope.PrimaryKey())
|
|
|
|
- equalSQL = "="
|
|
|
|
- inSQL = "IN"
|
|
|
|
- )
|
|
|
|
-
|
|
|
|
- // If building not conditions
|
|
|
|
- if !include {
|
|
|
|
- equalSQL = "<>"
|
|
|
|
- inSQL = "NOT IN"
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
|
|
+func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str string) {
|
|
switch value := clause["query"].(type) {
|
|
switch value := clause["query"].(type) {
|
|
- case sql.NullInt64:
|
|
|
|
- return fmt.Sprintf("(%v.%v %s %v)", quotedTableName, quotedPrimaryKey, equalSQL, value.Int64)
|
|
|
|
- case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
|
|
|
|
- return fmt.Sprintf("(%v.%v %s %v)", quotedTableName, quotedPrimaryKey, equalSQL, value)
|
|
|
|
- case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string, []interface{}:
|
|
|
|
- if !include && reflect.ValueOf(value).Len() == 0 {
|
|
|
|
- return
|
|
|
|
- }
|
|
|
|
- str = fmt.Sprintf("(%v.%v %s (?))", quotedTableName, quotedPrimaryKey, inSQL)
|
|
|
|
- clause["args"] = []interface{}{value}
|
|
|
|
case string:
|
|
case string:
|
|
if isNumberRegexp.MatchString(value) {
|
|
if isNumberRegexp.MatchString(value) {
|
|
- return fmt.Sprintf("(%v.%v %s %v)", quotedTableName, quotedPrimaryKey, equalSQL, scope.AddToVars(value))
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- if value != "" {
|
|
|
|
- if !include {
|
|
|
|
- if comparisonRegexp.MatchString(value) {
|
|
|
|
- str = fmt.Sprintf("NOT (%v)", value)
|
|
|
|
- } else {
|
|
|
|
- str = fmt.Sprintf("(%v.%v NOT IN (?))", quotedTableName, scope.Quote(value))
|
|
|
|
- }
|
|
|
|
- } else {
|
|
|
|
- str = fmt.Sprintf("(%v)", value)
|
|
|
|
- }
|
|
|
|
|
|
+ return scope.primaryCondition(scope.AddToVars(value))
|
|
|
|
+ } else if value != "" {
|
|
|
|
+ str = fmt.Sprintf("(%v)", value)
|
|
}
|
|
}
|
|
|
|
+ case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, sql.NullInt64:
|
|
|
|
+ return scope.primaryCondition(scope.AddToVars(value))
|
|
|
|
+ case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string, []interface{}:
|
|
|
|
+ str = fmt.Sprintf("(%v.%v IN (?))", scope.QuotedTableName(), scope.Quote(scope.PrimaryKey()))
|
|
|
|
+ clause["args"] = []interface{}{value}
|
|
case map[string]interface{}:
|
|
case map[string]interface{}:
|
|
var sqls []string
|
|
var sqls []string
|
|
for key, value := range value {
|
|
for key, value := range value {
|
|
if value != nil {
|
|
if value != nil {
|
|
- sqls = append(sqls, fmt.Sprintf("(%v.%v %s %v)", quotedTableName, scope.Quote(key), equalSQL, scope.AddToVars(value)))
|
|
|
|
|
|
+ sqls = append(sqls, fmt.Sprintf("(%v.%v = %v)", scope.QuotedTableName(), scope.Quote(key), scope.AddToVars(value)))
|
|
} else {
|
|
} else {
|
|
- if !include {
|
|
|
|
- sqls = append(sqls, fmt.Sprintf("(%v.%v IS NOT NULL)", quotedTableName, scope.Quote(key)))
|
|
|
|
- } else {
|
|
|
|
- sqls = append(sqls, fmt.Sprintf("(%v.%v IS NULL)", quotedTableName, scope.Quote(key)))
|
|
|
|
- }
|
|
|
|
|
|
+ sqls = append(sqls, fmt.Sprintf("(%v.%v IS NULL)", scope.QuotedTableName(), scope.Quote(key)))
|
|
}
|
|
}
|
|
}
|
|
}
|
|
return strings.Join(sqls, " AND ")
|
|
return strings.Join(sqls, " AND ")
|
|
case interface{}:
|
|
case interface{}:
|
|
var sqls []string
|
|
var sqls []string
|
|
newScope := scope.New(value)
|
|
newScope := scope.New(value)
|
|
-
|
|
|
|
- if len(newScope.Fields()) == 0 {
|
|
|
|
- scope.Err(fmt.Errorf("invalid query condition: %v", value))
|
|
|
|
- return
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
for _, field := range newScope.Fields() {
|
|
for _, field := range newScope.Fields() {
|
|
if !field.IsIgnored && !field.IsBlank {
|
|
if !field.IsIgnored && !field.IsBlank {
|
|
- sqls = append(sqls, fmt.Sprintf("(%v.%v %s %v)", quotedTableName, scope.Quote(field.DBName), equalSQL, scope.AddToVars(field.Field.Interface())))
|
|
|
|
|
|
+ sqls = append(sqls, fmt.Sprintf("(%v.%v = %v)", scope.QuotedTableName(), scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
|
|
}
|
|
}
|
|
}
|
|
}
|
|
return strings.Join(sqls, " AND ")
|
|
return strings.Join(sqls, " AND ")
|
|
- default:
|
|
|
|
- scope.Err(fmt.Errorf("invalid query condition: %v", value))
|
|
|
|
- return
|
|
|
|
}
|
|
}
|
|
|
|
|
|
- replacements := []string{}
|
|
|
|
args := clause["args"].([]interface{})
|
|
args := clause["args"].([]interface{})
|
|
for _, arg := range args {
|
|
for _, arg := range args {
|
|
- var err error
|
|
|
|
switch reflect.ValueOf(arg).Kind() {
|
|
switch reflect.ValueOf(arg).Kind() {
|
|
case reflect.Slice: // For where("id in (?)", []int64{1,2})
|
|
case reflect.Slice: // For where("id in (?)", []int64{1,2})
|
|
- if scanner, ok := interface{}(arg).(driver.Valuer); ok {
|
|
|
|
- arg, err = scanner.Value()
|
|
|
|
- replacements = append(replacements, scope.AddToVars(arg))
|
|
|
|
- } else if b, ok := arg.([]byte); ok {
|
|
|
|
- replacements = append(replacements, scope.AddToVars(b))
|
|
|
|
- } else if as, ok := arg.([][]interface{}); ok {
|
|
|
|
- var tempMarks []string
|
|
|
|
- for _, a := range as {
|
|
|
|
- var arrayMarks []string
|
|
|
|
- for _, v := range a {
|
|
|
|
- arrayMarks = append(arrayMarks, scope.AddToVars(v))
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- if len(arrayMarks) > 0 {
|
|
|
|
- tempMarks = append(tempMarks, fmt.Sprintf("(%v)", strings.Join(arrayMarks, ",")))
|
|
|
|
- }
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- if len(tempMarks) > 0 {
|
|
|
|
- replacements = append(replacements, strings.Join(tempMarks, ","))
|
|
|
|
- }
|
|
|
|
|
|
+ if bytes, ok := arg.([]byte); ok {
|
|
|
|
+ str = strings.Replace(str, "?", scope.AddToVars(bytes), 1)
|
|
} else if values := reflect.ValueOf(arg); values.Len() > 0 {
|
|
} else if values := reflect.ValueOf(arg); values.Len() > 0 {
|
|
var tempMarks []string
|
|
var tempMarks []string
|
|
for i := 0; i < values.Len(); i++ {
|
|
for i := 0; i < values.Len(); i++ {
|
|
tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface()))
|
|
tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface()))
|
|
}
|
|
}
|
|
- replacements = append(replacements, strings.Join(tempMarks, ","))
|
|
|
|
|
|
+ str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1)
|
|
} else {
|
|
} else {
|
|
- replacements = append(replacements, scope.AddToVars(Expr("NULL")))
|
|
|
|
|
|
+ str = strings.Replace(str, "?", scope.AddToVars(Expr("NULL")), 1)
|
|
}
|
|
}
|
|
default:
|
|
default:
|
|
if valuer, ok := interface{}(arg).(driver.Valuer); ok {
|
|
if valuer, ok := interface{}(arg).(driver.Valuer); ok {
|
|
- arg, err = valuer.Value()
|
|
|
|
|
|
+ arg, _ = valuer.Value()
|
|
}
|
|
}
|
|
|
|
|
|
- replacements = append(replacements, scope.AddToVars(arg))
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- if err != nil {
|
|
|
|
- scope.Err(err)
|
|
|
|
|
|
+ str = strings.Replace(str, "?", scope.AddToVars(arg), 1)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
+ return
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string) {
|
|
|
|
+ var notEqualSQL string
|
|
|
|
+ var primaryKey = scope.PrimaryKey()
|
|
|
|
|
|
- buff := bytes.NewBuffer([]byte{})
|
|
|
|
- i := 0
|
|
|
|
- for _, s := range str {
|
|
|
|
- if s == '?' {
|
|
|
|
- buff.WriteString(replacements[i])
|
|
|
|
- i++
|
|
|
|
|
|
+ switch value := clause["query"].(type) {
|
|
|
|
+ case string:
|
|
|
|
+ if isNumberRegexp.MatchString(value) {
|
|
|
|
+ id, _ := strconv.Atoi(value)
|
|
|
|
+ return fmt.Sprintf("(%v <> %v)", scope.Quote(primaryKey), id)
|
|
|
|
+ } else if comparisonRegexp.MatchString(value) {
|
|
|
|
+ str = fmt.Sprintf(" NOT (%v) ", value)
|
|
|
|
+ notEqualSQL = fmt.Sprintf("NOT (%v)", value)
|
|
|
|
+ } else {
|
|
|
|
+ str = fmt.Sprintf("(%v.%v NOT IN (?))", scope.QuotedTableName(), scope.Quote(value))
|
|
|
|
+ notEqualSQL = fmt.Sprintf("(%v.%v <> ?)", scope.QuotedTableName(), scope.Quote(value))
|
|
|
|
+ }
|
|
|
|
+ case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, sql.NullInt64:
|
|
|
|
+ return fmt.Sprintf("(%v.%v <> %v)", scope.QuotedTableName(), scope.Quote(primaryKey), value)
|
|
|
|
+ case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string:
|
|
|
|
+ if reflect.ValueOf(value).Len() > 0 {
|
|
|
|
+ str = fmt.Sprintf("(%v.%v NOT IN (?))", scope.QuotedTableName(), scope.Quote(primaryKey))
|
|
|
|
+ clause["args"] = []interface{}{value}
|
|
} else {
|
|
} else {
|
|
- buff.WriteRune(s)
|
|
|
|
|
|
+ return ""
|
|
|
|
+ }
|
|
|
|
+ case map[string]interface{}:
|
|
|
|
+ var sqls []string
|
|
|
|
+ for key, value := range value {
|
|
|
|
+ if value != nil {
|
|
|
|
+ sqls = append(sqls, fmt.Sprintf("(%v.%v <> %v)", scope.QuotedTableName(), scope.Quote(key), scope.AddToVars(value)))
|
|
|
|
+ } else {
|
|
|
|
+ sqls = append(sqls, fmt.Sprintf("(%v.%v IS NOT NULL)", scope.QuotedTableName(), scope.Quote(key)))
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ return strings.Join(sqls, " AND ")
|
|
|
|
+ case interface{}:
|
|
|
|
+ var sqls []string
|
|
|
|
+ var newScope = scope.New(value)
|
|
|
|
+ for _, field := range newScope.Fields() {
|
|
|
|
+ if !field.IsBlank {
|
|
|
|
+ sqls = append(sqls, fmt.Sprintf("(%v.%v <> %v)", scope.QuotedTableName(), scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
|
|
|
|
+ }
|
|
}
|
|
}
|
|
|
|
+ return strings.Join(sqls, " AND ")
|
|
}
|
|
}
|
|
|
|
|
|
- str = buff.String()
|
|
|
|
-
|
|
|
|
|
|
+ args := clause["args"].([]interface{})
|
|
|
|
+ for _, arg := range args {
|
|
|
|
+ switch reflect.ValueOf(arg).Kind() {
|
|
|
|
+ case reflect.Slice: // For where("id in (?)", []int64{1,2})
|
|
|
|
+ if bytes, ok := arg.([]byte); ok {
|
|
|
|
+ str = strings.Replace(str, "?", scope.AddToVars(bytes), 1)
|
|
|
|
+ } else if values := reflect.ValueOf(arg); values.Len() > 0 {
|
|
|
|
+ var tempMarks []string
|
|
|
|
+ for i := 0; i < values.Len(); i++ {
|
|
|
|
+ tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface()))
|
|
|
|
+ }
|
|
|
|
+ str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1)
|
|
|
|
+ } else {
|
|
|
|
+ str = strings.Replace(str, "?", scope.AddToVars(Expr("NULL")), 1)
|
|
|
|
+ }
|
|
|
|
+ default:
|
|
|
|
+ if scanner, ok := interface{}(arg).(driver.Valuer); ok {
|
|
|
|
+ arg, _ = scanner.Value()
|
|
|
|
+ }
|
|
|
|
+ str = strings.Replace(notEqualSQL, "?", scope.AddToVars(arg), 1)
|
|
|
|
+ }
|
|
|
|
+ }
|
|
return
|
|
return
|
|
}
|
|
}
|
|
|
|
|
|
@@ -672,7 +661,6 @@ func (scope *Scope) buildSelectQuery(clause map[string]interface{}) (str string)
|
|
}
|
|
}
|
|
|
|
|
|
args := clause["args"].([]interface{})
|
|
args := clause["args"].([]interface{})
|
|
- replacements := []string{}
|
|
|
|
for _, arg := range args {
|
|
for _, arg := range args {
|
|
switch reflect.ValueOf(arg).Kind() {
|
|
switch reflect.ValueOf(arg).Kind() {
|
|
case reflect.Slice:
|
|
case reflect.Slice:
|
|
@@ -681,28 +669,14 @@ func (scope *Scope) buildSelectQuery(clause map[string]interface{}) (str string)
|
|
for i := 0; i < values.Len(); i++ {
|
|
for i := 0; i < values.Len(); i++ {
|
|
tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface()))
|
|
tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface()))
|
|
}
|
|
}
|
|
- replacements = append(replacements, strings.Join(tempMarks, ","))
|
|
|
|
|
|
+ str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1)
|
|
default:
|
|
default:
|
|
if valuer, ok := interface{}(arg).(driver.Valuer); ok {
|
|
if valuer, ok := interface{}(arg).(driver.Valuer); ok {
|
|
arg, _ = valuer.Value()
|
|
arg, _ = valuer.Value()
|
|
}
|
|
}
|
|
- replacements = append(replacements, scope.AddToVars(arg))
|
|
|
|
|
|
+ str = strings.Replace(str, "?", scope.AddToVars(arg), 1)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
-
|
|
|
|
- buff := bytes.NewBuffer([]byte{})
|
|
|
|
- i := 0
|
|
|
|
- for pos := range str {
|
|
|
|
- if str[pos] == '?' {
|
|
|
|
- buff.WriteString(replacements[i])
|
|
|
|
- i++
|
|
|
|
- } else {
|
|
|
|
- buff.WriteByte(str[pos])
|
|
|
|
- }
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- str = buff.String()
|
|
|
|
-
|
|
|
|
return
|
|
return
|
|
}
|
|
}
|
|
|
|
|
|
@@ -726,19 +700,19 @@ func (scope *Scope) whereSQL() (sql string) {
|
|
}
|
|
}
|
|
|
|
|
|
for _, clause := range scope.Search.whereConditions {
|
|
for _, clause := range scope.Search.whereConditions {
|
|
- if sql := scope.buildCondition(clause, true); sql != "" {
|
|
|
|
|
|
+ if sql := scope.buildWhereCondition(clause); sql != "" {
|
|
andConditions = append(andConditions, sql)
|
|
andConditions = append(andConditions, sql)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
for _, clause := range scope.Search.orConditions {
|
|
for _, clause := range scope.Search.orConditions {
|
|
- if sql := scope.buildCondition(clause, true); sql != "" {
|
|
|
|
|
|
+ if sql := scope.buildWhereCondition(clause); sql != "" {
|
|
orConditions = append(orConditions, sql)
|
|
orConditions = append(orConditions, sql)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
for _, clause := range scope.Search.notConditions {
|
|
for _, clause := range scope.Search.notConditions {
|
|
- if sql := scope.buildCondition(clause, false); sql != "" {
|
|
|
|
|
|
+ if sql := scope.buildNotCondition(clause); sql != "" {
|
|
andConditions = append(andConditions, sql)
|
|
andConditions = append(andConditions, sql)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
@@ -812,7 +786,7 @@ func (scope *Scope) havingSQL() string {
|
|
|
|
|
|
var andConditions []string
|
|
var andConditions []string
|
|
for _, clause := range scope.Search.havingConditions {
|
|
for _, clause := range scope.Search.havingConditions {
|
|
- if sql := scope.buildCondition(clause, true); sql != "" {
|
|
|
|
|
|
+ if sql := scope.buildWhereCondition(clause); sql != "" {
|
|
andConditions = append(andConditions, sql)
|
|
andConditions = append(andConditions, sql)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
@@ -828,7 +802,7 @@ func (scope *Scope) havingSQL() string {
|
|
func (scope *Scope) joinsSQL() string {
|
|
func (scope *Scope) joinsSQL() string {
|
|
var joinConditions []string
|
|
var joinConditions []string
|
|
for _, clause := range scope.Search.joinConditions {
|
|
for _, clause := range scope.Search.joinConditions {
|
|
- if sql := scope.buildCondition(clause, true); sql != "" {
|
|
|
|
|
|
+ if sql := scope.buildWhereCondition(clause); sql != "" {
|
|
joinConditions = append(joinConditions, strings.TrimSuffix(strings.TrimPrefix(sql, "("), ")"))
|
|
joinConditions = append(joinConditions, strings.TrimSuffix(strings.TrimPrefix(sql, "("), ")"))
|
|
}
|
|
}
|
|
}
|
|
}
|
|
@@ -950,34 +924,14 @@ func (scope *Scope) initialize() *Scope {
|
|
return scope
|
|
return scope
|
|
}
|
|
}
|
|
|
|
|
|
-func (scope *Scope) isQueryForColumn(query interface{}, column string) bool {
|
|
|
|
- queryStr := strings.ToLower(fmt.Sprint(query))
|
|
|
|
- if queryStr == column {
|
|
|
|
- return true
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- if strings.HasSuffix(queryStr, "as "+column) {
|
|
|
|
- return true
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- if strings.HasSuffix(queryStr, "as "+scope.Quote(column)) {
|
|
|
|
- return true
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- return false
|
|
|
|
-}
|
|
|
|
-
|
|
|
|
func (scope *Scope) pluck(column string, value interface{}) *Scope {
|
|
func (scope *Scope) pluck(column string, value interface{}) *Scope {
|
|
dest := reflect.Indirect(reflect.ValueOf(value))
|
|
dest := reflect.Indirect(reflect.ValueOf(value))
|
|
|
|
+ scope.Search.Select(column)
|
|
if dest.Kind() != reflect.Slice {
|
|
if dest.Kind() != reflect.Slice {
|
|
scope.Err(fmt.Errorf("results should be a slice, not %s", dest.Kind()))
|
|
scope.Err(fmt.Errorf("results should be a slice, not %s", dest.Kind()))
|
|
return scope
|
|
return scope
|
|
}
|
|
}
|
|
|
|
|
|
- if query, ok := scope.Search.selects["query"]; !ok || !scope.isQueryForColumn(query, column) {
|
|
|
|
- scope.Search.Select(column)
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
rows, err := scope.rows()
|
|
rows, err := scope.rows()
|
|
if scope.Err(err) == nil {
|
|
if scope.Err(err) == nil {
|
|
defer rows.Close()
|
|
defer rows.Close()
|
|
@@ -996,12 +950,7 @@ func (scope *Scope) pluck(column string, value interface{}) *Scope {
|
|
|
|
|
|
func (scope *Scope) count(value interface{}) *Scope {
|
|
func (scope *Scope) count(value interface{}) *Scope {
|
|
if query, ok := scope.Search.selects["query"]; !ok || !countingQueryRegexp.MatchString(fmt.Sprint(query)) {
|
|
if query, ok := scope.Search.selects["query"]; !ok || !countingQueryRegexp.MatchString(fmt.Sprint(query)) {
|
|
- if len(scope.Search.group) != 0 {
|
|
|
|
- scope.Search.Select("count(*) FROM ( SELECT count(*) as name ")
|
|
|
|
- scope.Search.group += " ) AS count_table"
|
|
|
|
- } else {
|
|
|
|
- scope.Search.Select("count(*)")
|
|
|
|
- }
|
|
|
|
|
|
+ scope.Search.Select("count(*)")
|
|
}
|
|
}
|
|
scope.Search.ignoreOrderQuery = true
|
|
scope.Search.ignoreOrderQuery = true
|
|
scope.Err(scope.row().Scan(value))
|
|
scope.Err(scope.row().Scan(value))
|
|
@@ -1044,6 +993,18 @@ func (scope *Scope) changeableField(field *Field) bool {
|
|
return true
|
|
return true
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+func (scope *Scope) shouldSaveAssociations() bool {
|
|
|
|
+ if saveAssociations, ok := scope.Get("gorm:save_associations"); ok {
|
|
|
|
+ if v, ok := saveAssociations.(bool); ok && !v {
|
|
|
|
+ return false
|
|
|
|
+ }
|
|
|
|
+ if v, ok := saveAssociations.(string); ok && (v != "skip") {
|
|
|
|
+ return false
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ return true && !scope.HasError()
|
|
|
|
+}
|
|
|
|
+
|
|
func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope {
|
|
func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope {
|
|
toScope := scope.db.NewScope(value)
|
|
toScope := scope.db.NewScope(value)
|
|
tx := scope.db.Set("gorm:association:source", scope.Value)
|
|
tx := scope.db.Set("gorm:association:source", scope.Value)
|
|
@@ -1098,7 +1059,7 @@ func (scope *Scope) getTableOptions() string {
|
|
if !ok {
|
|
if !ok {
|
|
return ""
|
|
return ""
|
|
}
|
|
}
|
|
- return " " + tableOptions.(string)
|
|
|
|
|
|
+ return tableOptions.(string)
|
|
}
|
|
}
|
|
|
|
|
|
func (scope *Scope) createJoinTable(field *StructField) {
|
|
func (scope *Scope) createJoinTable(field *StructField) {
|
|
@@ -1131,7 +1092,7 @@ func (scope *Scope) createJoinTable(field *StructField) {
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
- scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v, PRIMARY KEY (%v))%s", scope.Quote(joinTable), strings.Join(sqlTypes, ","), strings.Join(primaryKeys, ","), scope.getTableOptions())).Error)
|
|
|
|
|
|
+ scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v, PRIMARY KEY (%v)) %s", scope.Quote(joinTable), strings.Join(sqlTypes, ","), strings.Join(primaryKeys, ","), scope.getTableOptions())).Error)
|
|
}
|
|
}
|
|
scope.NewDB().Table(joinTable).AutoMigrate(joinTableHandler)
|
|
scope.NewDB().Table(joinTable).AutoMigrate(joinTableHandler)
|
|
}
|
|
}
|
|
@@ -1166,19 +1127,19 @@ func (scope *Scope) createTable() *Scope {
|
|
primaryKeyStr = fmt.Sprintf(", PRIMARY KEY (%v)", strings.Join(primaryKeys, ","))
|
|
primaryKeyStr = fmt.Sprintf(", PRIMARY KEY (%v)", strings.Join(primaryKeys, ","))
|
|
}
|
|
}
|
|
|
|
|
|
- scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v %v)%s", scope.QuotedTableName(), strings.Join(tags, ","), primaryKeyStr, scope.getTableOptions())).Exec()
|
|
|
|
|
|
+ scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v %v) %s", scope.QuotedTableName(), strings.Join(tags, ","), primaryKeyStr, scope.getTableOptions())).Exec()
|
|
|
|
|
|
scope.autoIndex()
|
|
scope.autoIndex()
|
|
return scope
|
|
return scope
|
|
}
|
|
}
|
|
|
|
|
|
func (scope *Scope) dropTable() *Scope {
|
|
func (scope *Scope) dropTable() *Scope {
|
|
- scope.Raw(fmt.Sprintf("DROP TABLE %v%s", scope.QuotedTableName(), scope.getTableOptions())).Exec()
|
|
|
|
|
|
+ scope.Raw(fmt.Sprintf("DROP TABLE %v", scope.QuotedTableName())).Exec()
|
|
return scope
|
|
return scope
|
|
}
|
|
}
|
|
|
|
|
|
func (scope *Scope) modifyColumn(column string, typ string) {
|
|
func (scope *Scope) modifyColumn(column string, typ string) {
|
|
- scope.db.AddError(scope.Dialect().ModifyColumn(scope.QuotedTableName(), scope.Quote(column), typ))
|
|
|
|
|
|
+ scope.Raw(fmt.Sprintf("ALTER TABLE %v MODIFY %v %v", scope.QuotedTableName(), scope.Quote(column), typ)).Exec()
|
|
}
|
|
}
|
|
|
|
|
|
func (scope *Scope) dropColumn(column string) {
|
|
func (scope *Scope) dropColumn(column string) {
|
|
@@ -1204,8 +1165,7 @@ func (scope *Scope) addIndex(unique bool, indexName string, column ...string) {
|
|
}
|
|
}
|
|
|
|
|
|
func (scope *Scope) addForeignKey(field string, dest string, onDelete string, onUpdate string) {
|
|
func (scope *Scope) addForeignKey(field string, dest string, onDelete string, onUpdate string) {
|
|
- // Compatible with old generated key
|
|
|
|
- keyName := scope.Dialect().BuildKeyName(scope.TableName(), field, dest, "foreign")
|
|
|
|
|
|
+ keyName := scope.Dialect().BuildForeignKeyName(scope.TableName(), field, dest)
|
|
|
|
|
|
if scope.Dialect().HasForeignKey(scope.TableName(), keyName) {
|
|
if scope.Dialect().HasForeignKey(scope.TableName(), keyName) {
|
|
return
|
|
return
|
|
@@ -1214,16 +1174,6 @@ func (scope *Scope) addForeignKey(field string, dest string, onDelete string, on
|
|
scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.quoteIfPossible(keyName), scope.quoteIfPossible(field), dest, onDelete, onUpdate)).Exec()
|
|
scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.quoteIfPossible(keyName), scope.quoteIfPossible(field), dest, onDelete, onUpdate)).Exec()
|
|
}
|
|
}
|
|
|
|
|
|
-func (scope *Scope) removeForeignKey(field string, dest string) {
|
|
|
|
- keyName := scope.Dialect().BuildKeyName(scope.TableName(), field, dest)
|
|
|
|
-
|
|
|
|
- if !scope.Dialect().HasForeignKey(scope.TableName(), keyName) {
|
|
|
|
- return
|
|
|
|
- }
|
|
|
|
- var query = `ALTER TABLE %s DROP CONSTRAINT %s;`
|
|
|
|
- scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.quoteIfPossible(keyName))).Exec()
|
|
|
|
-}
|
|
|
|
-
|
|
|
|
func (scope *Scope) removeIndex(indexName string) {
|
|
func (scope *Scope) removeIndex(indexName string) {
|
|
scope.Dialect().RemoveIndex(scope.TableName(), indexName)
|
|
scope.Dialect().RemoveIndex(scope.TableName(), indexName)
|
|
}
|
|
}
|
|
@@ -1259,7 +1209,7 @@ func (scope *Scope) autoIndex() *Scope {
|
|
|
|
|
|
for _, name := range names {
|
|
for _, name := range names {
|
|
if name == "INDEX" || name == "" {
|
|
if name == "INDEX" || name == "" {
|
|
- name = scope.Dialect().BuildKeyName("idx", scope.TableName(), field.DBName)
|
|
|
|
|
|
+ name = fmt.Sprintf("idx_%v_%v", scope.TableName(), field.DBName)
|
|
}
|
|
}
|
|
indexes[name] = append(indexes[name], field.DBName)
|
|
indexes[name] = append(indexes[name], field.DBName)
|
|
}
|
|
}
|
|
@@ -1270,7 +1220,7 @@ func (scope *Scope) autoIndex() *Scope {
|
|
|
|
|
|
for _, name := range names {
|
|
for _, name := range names {
|
|
if name == "UNIQUE_INDEX" || name == "" {
|
|
if name == "UNIQUE_INDEX" || name == "" {
|
|
- name = scope.Dialect().BuildKeyName("uix", scope.TableName(), field.DBName)
|
|
|
|
|
|
+ name = fmt.Sprintf("uix_%v_%v", scope.TableName(), field.DBName)
|
|
}
|
|
}
|
|
uniqueIndexes[name] = append(uniqueIndexes[name], field.DBName)
|
|
uniqueIndexes[name] = append(uniqueIndexes[name], field.DBName)
|
|
}
|
|
}
|
|
@@ -1278,15 +1228,11 @@ func (scope *Scope) autoIndex() *Scope {
|
|
}
|
|
}
|
|
|
|
|
|
for name, columns := range indexes {
|
|
for name, columns := range indexes {
|
|
- if db := scope.NewDB().Table(scope.TableName()).Model(scope.Value).AddIndex(name, columns...); db.Error != nil {
|
|
|
|
- scope.db.AddError(db.Error)
|
|
|
|
- }
|
|
|
|
|
|
+ scope.NewDB().Model(scope.Value).AddIndex(name, columns...)
|
|
}
|
|
}
|
|
|
|
|
|
for name, columns := range uniqueIndexes {
|
|
for name, columns := range uniqueIndexes {
|
|
- if db := scope.NewDB().Table(scope.TableName()).Model(scope.Value).AddUniqueIndex(name, columns...); db.Error != nil {
|
|
|
|
- scope.db.AddError(db.Error)
|
|
|
|
- }
|
|
|
|
|
|
+ scope.NewDB().Model(scope.Value).AddUniqueIndex(name, columns...)
|
|
}
|
|
}
|
|
|
|
|
|
return scope
|
|
return scope
|