utils.go 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267
  1. package gorm
  2. import (
  3. "bytes"
  4. "database/sql/driver"
  5. "fmt"
  6. "reflect"
  7. "regexp"
  8. "runtime"
  9. "strings"
  10. "sync"
  11. "time"
  12. )
  13. // NowFunc returns current time, this function is exported in order to be able
  14. // to give the flexibility to the developer to customize it according to their
  15. // needs, e.g:
  16. // gorm.NowFunc = func() time.Time {
  17. // return time.Now().UTC()
  18. // }
  19. var NowFunc = func() time.Time {
  20. return time.Now()
  21. }
  22. // Copied from golint
  23. var commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UI", "UID", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"}
  24. var commonInitialismsReplacer *strings.Replacer
  25. var goSrcRegexp = regexp.MustCompile(`jinzhu/gorm/.*.go`)
  26. var goTestRegexp = regexp.MustCompile(`jinzhu/gorm/.*test.go`)
  27. func init() {
  28. var commonInitialismsForReplacer []string
  29. for _, initialism := range commonInitialisms {
  30. commonInitialismsForReplacer = append(commonInitialismsForReplacer, initialism, strings.Title(strings.ToLower(initialism)))
  31. }
  32. commonInitialismsReplacer = strings.NewReplacer(commonInitialismsForReplacer...)
  33. }
  34. type safeMap struct {
  35. m map[string]string
  36. l *sync.RWMutex
  37. }
  38. func (s *safeMap) Set(key string, value string) {
  39. s.l.Lock()
  40. defer s.l.Unlock()
  41. s.m[key] = value
  42. }
  43. func (s *safeMap) Get(key string) string {
  44. s.l.RLock()
  45. defer s.l.RUnlock()
  46. return s.m[key]
  47. }
  48. func newSafeMap() *safeMap {
  49. return &safeMap{l: new(sync.RWMutex), m: make(map[string]string)}
  50. }
  51. var smap = newSafeMap()
  52. type strCase bool
  53. const (
  54. lower strCase = false
  55. upper strCase = true
  56. )
  57. // ToDBName convert string to db name
  58. func ToDBName(name string) string {
  59. if v := smap.Get(name); v != "" {
  60. return v
  61. }
  62. if name == "" {
  63. return ""
  64. }
  65. var (
  66. value = commonInitialismsReplacer.Replace(name)
  67. buf = bytes.NewBufferString("")
  68. lastCase, currCase, nextCase strCase
  69. )
  70. for i, v := range value[:len(value)-1] {
  71. nextCase = strCase(value[i+1] >= 'A' && value[i+1] <= 'Z')
  72. if i > 0 {
  73. if currCase == upper {
  74. if lastCase == upper && nextCase == upper {
  75. buf.WriteRune(v)
  76. } else {
  77. if value[i-1] != '_' && value[i+1] != '_' {
  78. buf.WriteRune('_')
  79. }
  80. buf.WriteRune(v)
  81. }
  82. } else {
  83. buf.WriteRune(v)
  84. }
  85. } else {
  86. currCase = upper
  87. buf.WriteRune(v)
  88. }
  89. lastCase = currCase
  90. currCase = nextCase
  91. }
  92. buf.WriteByte(value[len(value)-1])
  93. s := strings.ToLower(buf.String())
  94. smap.Set(name, s)
  95. return s
  96. }
  97. // SQL expression
  98. type expr struct {
  99. expr string
  100. args []interface{}
  101. }
  102. // Expr generate raw SQL expression, for example:
  103. // DB.Model(&product).Update("price", gorm.Expr("price * ? + ?", 2, 100))
  104. func Expr(expression string, args ...interface{}) *expr {
  105. return &expr{expr: expression, args: args}
  106. }
  107. func indirect(reflectValue reflect.Value) reflect.Value {
  108. for reflectValue.Kind() == reflect.Ptr {
  109. reflectValue = reflectValue.Elem()
  110. }
  111. return reflectValue
  112. }
  113. func toQueryMarks(primaryValues [][]interface{}) string {
  114. var results []string
  115. for _, primaryValue := range primaryValues {
  116. var marks []string
  117. for _, _ = range primaryValue {
  118. marks = append(marks, "?")
  119. }
  120. if len(marks) > 1 {
  121. results = append(results, fmt.Sprintf("(%v)", strings.Join(marks, ",")))
  122. } else {
  123. results = append(results, strings.Join(marks, ""))
  124. }
  125. }
  126. return strings.Join(results, ",")
  127. }
  128. func toQueryCondition(scope *Scope, columns []string) string {
  129. var newColumns []string
  130. for _, column := range columns {
  131. newColumns = append(newColumns, scope.Quote(column))
  132. }
  133. if len(columns) > 1 {
  134. return fmt.Sprintf("(%v)", strings.Join(newColumns, ","))
  135. }
  136. return strings.Join(newColumns, ",")
  137. }
  138. func toQueryValues(values [][]interface{}) (results []interface{}) {
  139. for _, value := range values {
  140. for _, v := range value {
  141. results = append(results, v)
  142. }
  143. }
  144. return
  145. }
  146. func fileWithLineNum() string {
  147. for i := 2; i < 15; i++ {
  148. _, file, line, ok := runtime.Caller(i)
  149. if ok && (!goSrcRegexp.MatchString(file) || goTestRegexp.MatchString(file)) {
  150. return fmt.Sprintf("%v:%v", file, line)
  151. }
  152. }
  153. return ""
  154. }
  155. func isBlank(value reflect.Value) bool {
  156. return reflect.DeepEqual(value.Interface(), reflect.Zero(value.Type()).Interface())
  157. }
  158. func toSearchableMap(attrs ...interface{}) (result interface{}) {
  159. if len(attrs) > 1 {
  160. if str, ok := attrs[0].(string); ok {
  161. result = map[string]interface{}{str: attrs[1]}
  162. }
  163. } else if len(attrs) == 1 {
  164. if attr, ok := attrs[0].(map[string]interface{}); ok {
  165. result = attr
  166. }
  167. if attr, ok := attrs[0].(interface{}); ok {
  168. result = attr
  169. }
  170. }
  171. return
  172. }
  173. func equalAsString(a interface{}, b interface{}) bool {
  174. return toString(a) == toString(b)
  175. }
  176. func toString(str interface{}) string {
  177. if values, ok := str.([]interface{}); ok {
  178. var results []string
  179. for _, value := range values {
  180. results = append(results, toString(value))
  181. }
  182. return strings.Join(results, "_")
  183. } else if bytes, ok := str.([]byte); ok {
  184. return string(bytes)
  185. } else if reflectValue := reflect.Indirect(reflect.ValueOf(str)); reflectValue.IsValid() {
  186. return fmt.Sprintf("%v", reflectValue.Interface())
  187. }
  188. return ""
  189. }
  190. func makeSlice(elemType reflect.Type) interface{} {
  191. if elemType.Kind() == reflect.Slice {
  192. elemType = elemType.Elem()
  193. }
  194. sliceType := reflect.SliceOf(elemType)
  195. slice := reflect.New(sliceType)
  196. slice.Elem().Set(reflect.MakeSlice(sliceType, 0, 0))
  197. return slice.Interface()
  198. }
  199. func strInSlice(a string, list []string) bool {
  200. for _, b := range list {
  201. if b == a {
  202. return true
  203. }
  204. }
  205. return false
  206. }
  207. // getValueFromFields return given fields's value
  208. func getValueFromFields(value reflect.Value, fieldNames []string) (results []interface{}) {
  209. // If value is a nil pointer, Indirect returns a zero Value!
  210. // Therefor we need to check for a zero value,
  211. // as FieldByName could panic
  212. if indirectValue := reflect.Indirect(value); indirectValue.IsValid() {
  213. for _, fieldName := range fieldNames {
  214. if fieldValue := indirectValue.FieldByName(fieldName); fieldValue.IsValid() {
  215. result := fieldValue.Interface()
  216. if r, ok := result.(driver.Valuer); ok {
  217. result, _ = r.Value()
  218. }
  219. results = append(results, result)
  220. }
  221. }
  222. }
  223. return
  224. }
  225. func addExtraSpaceIfExist(str string) string {
  226. if str != "" {
  227. return " " + str
  228. }
  229. return ""
  230. }