callback_query_preload.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387
  1. package gorm
  2. import (
  3. "errors"
  4. "fmt"
  5. "reflect"
  6. "strconv"
  7. "strings"
  8. )
  9. // preloadCallback used to preload associations
  10. func preloadCallback(scope *Scope) {
  11. if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip {
  12. return
  13. }
  14. if _, ok := scope.Get("gorm:auto_preload"); ok {
  15. autoPreload(scope)
  16. }
  17. if scope.Search.preload == nil || scope.HasError() {
  18. return
  19. }
  20. var (
  21. preloadedMap = map[string]bool{}
  22. fields = scope.Fields()
  23. )
  24. for _, preload := range scope.Search.preload {
  25. var (
  26. preloadFields = strings.Split(preload.schema, ".")
  27. currentScope = scope
  28. currentFields = fields
  29. )
  30. for idx, preloadField := range preloadFields {
  31. var currentPreloadConditions []interface{}
  32. if currentScope == nil {
  33. continue
  34. }
  35. // if not preloaded
  36. if preloadKey := strings.Join(preloadFields[:idx+1], "."); !preloadedMap[preloadKey] {
  37. // assign search conditions to last preload
  38. if idx == len(preloadFields)-1 {
  39. currentPreloadConditions = preload.conditions
  40. }
  41. for _, field := range currentFields {
  42. if field.Name != preloadField || field.Relationship == nil {
  43. continue
  44. }
  45. switch field.Relationship.Kind {
  46. case "has_one":
  47. currentScope.handleHasOnePreload(field, currentPreloadConditions)
  48. case "has_many":
  49. currentScope.handleHasManyPreload(field, currentPreloadConditions)
  50. case "belongs_to":
  51. currentScope.handleBelongsToPreload(field, currentPreloadConditions)
  52. case "many_to_many":
  53. currentScope.handleManyToManyPreload(field, currentPreloadConditions)
  54. default:
  55. scope.Err(errors.New("unsupported relation"))
  56. }
  57. preloadedMap[preloadKey] = true
  58. break
  59. }
  60. if !preloadedMap[preloadKey] {
  61. scope.Err(fmt.Errorf("can't preload field %s for %s", preloadField, currentScope.GetModelStruct().ModelType))
  62. return
  63. }
  64. }
  65. // preload next level
  66. if idx < len(preloadFields)-1 {
  67. currentScope = currentScope.getColumnAsScope(preloadField)
  68. if currentScope != nil {
  69. currentFields = currentScope.Fields()
  70. }
  71. }
  72. }
  73. }
  74. }
  75. func autoPreload(scope *Scope) {
  76. for _, field := range scope.Fields() {
  77. if field.Relationship == nil {
  78. continue
  79. }
  80. if val, ok := field.TagSettings["PRELOAD"]; ok {
  81. if preload, err := strconv.ParseBool(val); err != nil {
  82. scope.Err(errors.New("invalid preload option"))
  83. return
  84. } else if !preload {
  85. continue
  86. }
  87. }
  88. scope.Search.Preload(field.Name)
  89. }
  90. }
  91. func (scope *Scope) generatePreloadDBWithConditions(conditions []interface{}) (*DB, []interface{}) {
  92. var (
  93. preloadDB = scope.NewDB()
  94. preloadConditions []interface{}
  95. )
  96. for _, condition := range conditions {
  97. if scopes, ok := condition.(func(*DB) *DB); ok {
  98. preloadDB = scopes(preloadDB)
  99. } else {
  100. preloadConditions = append(preloadConditions, condition)
  101. }
  102. }
  103. return preloadDB, preloadConditions
  104. }
  105. // handleHasOnePreload used to preload has one associations
  106. func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) {
  107. relation := field.Relationship
  108. // get relations's primary keys
  109. primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value)
  110. if len(primaryKeys) == 0 {
  111. return
  112. }
  113. // preload conditions
  114. preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
  115. // find relations
  116. query := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys))
  117. values := toQueryValues(primaryKeys)
  118. if relation.PolymorphicType != "" {
  119. query += fmt.Sprintf(" AND %v = ?", scope.Quote(relation.PolymorphicDBName))
  120. values = append(values, relation.PolymorphicValue)
  121. }
  122. results := makeSlice(field.Struct.Type)
  123. scope.Err(preloadDB.Where(query, values...).Find(results, preloadConditions...).Error)
  124. // assign find results
  125. var (
  126. resultsValue = indirect(reflect.ValueOf(results))
  127. indirectScopeValue = scope.IndirectValue()
  128. )
  129. if indirectScopeValue.Kind() == reflect.Slice {
  130. for j := 0; j < indirectScopeValue.Len(); j++ {
  131. for i := 0; i < resultsValue.Len(); i++ {
  132. result := resultsValue.Index(i)
  133. foreignValues := getValueFromFields(result, relation.ForeignFieldNames)
  134. if indirectValue := indirect(indirectScopeValue.Index(j)); equalAsString(getValueFromFields(indirectValue, relation.AssociationForeignFieldNames), foreignValues) {
  135. indirectValue.FieldByName(field.Name).Set(result)
  136. break
  137. }
  138. }
  139. }
  140. } else {
  141. for i := 0; i < resultsValue.Len(); i++ {
  142. result := resultsValue.Index(i)
  143. scope.Err(field.Set(result))
  144. }
  145. }
  146. }
  147. // handleHasManyPreload used to preload has many associations
  148. func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) {
  149. relation := field.Relationship
  150. // get relations's primary keys
  151. primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value)
  152. if len(primaryKeys) == 0 {
  153. return
  154. }
  155. // preload conditions
  156. preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
  157. // find relations
  158. query := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys))
  159. values := toQueryValues(primaryKeys)
  160. if relation.PolymorphicType != "" {
  161. query += fmt.Sprintf(" AND %v = ?", scope.Quote(relation.PolymorphicDBName))
  162. values = append(values, relation.PolymorphicValue)
  163. }
  164. results := makeSlice(field.Struct.Type)
  165. scope.Err(preloadDB.Where(query, values...).Find(results, preloadConditions...).Error)
  166. // assign find results
  167. var (
  168. resultsValue = indirect(reflect.ValueOf(results))
  169. indirectScopeValue = scope.IndirectValue()
  170. )
  171. if indirectScopeValue.Kind() == reflect.Slice {
  172. preloadMap := make(map[string][]reflect.Value)
  173. for i := 0; i < resultsValue.Len(); i++ {
  174. result := resultsValue.Index(i)
  175. foreignValues := getValueFromFields(result, relation.ForeignFieldNames)
  176. preloadMap[toString(foreignValues)] = append(preloadMap[toString(foreignValues)], result)
  177. }
  178. for j := 0; j < indirectScopeValue.Len(); j++ {
  179. object := indirect(indirectScopeValue.Index(j))
  180. objectRealValue := getValueFromFields(object, relation.AssociationForeignFieldNames)
  181. f := object.FieldByName(field.Name)
  182. if results, ok := preloadMap[toString(objectRealValue)]; ok {
  183. f.Set(reflect.Append(f, results...))
  184. } else {
  185. f.Set(reflect.MakeSlice(f.Type(), 0, 0))
  186. }
  187. }
  188. } else {
  189. scope.Err(field.Set(resultsValue))
  190. }
  191. }
  192. // handleBelongsToPreload used to preload belongs to associations
  193. func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{}) {
  194. relation := field.Relationship
  195. // preload conditions
  196. preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
  197. // get relations's primary keys
  198. primaryKeys := scope.getColumnAsArray(relation.ForeignFieldNames, scope.Value)
  199. if len(primaryKeys) == 0 {
  200. return
  201. }
  202. // find relations
  203. results := makeSlice(field.Struct.Type)
  204. scope.Err(preloadDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.AssociationForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, preloadConditions...).Error)
  205. // assign find results
  206. var (
  207. resultsValue = indirect(reflect.ValueOf(results))
  208. indirectScopeValue = scope.IndirectValue()
  209. )
  210. for i := 0; i < resultsValue.Len(); i++ {
  211. result := resultsValue.Index(i)
  212. if indirectScopeValue.Kind() == reflect.Slice {
  213. value := getValueFromFields(result, relation.AssociationForeignFieldNames)
  214. for j := 0; j < indirectScopeValue.Len(); j++ {
  215. object := indirect(indirectScopeValue.Index(j))
  216. if equalAsString(getValueFromFields(object, relation.ForeignFieldNames), value) {
  217. object.FieldByName(field.Name).Set(result)
  218. }
  219. }
  220. } else {
  221. scope.Err(field.Set(result))
  222. }
  223. }
  224. }
  225. // handleManyToManyPreload used to preload many to many associations
  226. func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface{}) {
  227. var (
  228. relation = field.Relationship
  229. joinTableHandler = relation.JoinTableHandler
  230. fieldType = field.Struct.Type.Elem()
  231. foreignKeyValue interface{}
  232. foreignKeyType = reflect.ValueOf(&foreignKeyValue).Type()
  233. linkHash = map[string][]reflect.Value{}
  234. isPtr bool
  235. )
  236. if fieldType.Kind() == reflect.Ptr {
  237. isPtr = true
  238. fieldType = fieldType.Elem()
  239. }
  240. var sourceKeys = []string{}
  241. for _, key := range joinTableHandler.SourceForeignKeys() {
  242. sourceKeys = append(sourceKeys, key.DBName)
  243. }
  244. // preload conditions
  245. preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
  246. // generate query with join table
  247. newScope := scope.New(reflect.New(fieldType).Interface())
  248. preloadDB = preloadDB.Table(newScope.TableName()).Model(newScope.Value)
  249. if len(preloadDB.search.selects) == 0 {
  250. preloadDB = preloadDB.Select("*")
  251. }
  252. preloadDB = joinTableHandler.JoinWith(joinTableHandler, preloadDB, scope.Value)
  253. // preload inline conditions
  254. if len(preloadConditions) > 0 {
  255. preloadDB = preloadDB.Where(preloadConditions[0], preloadConditions[1:]...)
  256. }
  257. rows, err := preloadDB.Rows()
  258. if scope.Err(err) != nil {
  259. return
  260. }
  261. defer rows.Close()
  262. columns, _ := rows.Columns()
  263. for rows.Next() {
  264. var (
  265. elem = reflect.New(fieldType).Elem()
  266. fields = scope.New(elem.Addr().Interface()).Fields()
  267. )
  268. // register foreign keys in join tables
  269. var joinTableFields []*Field
  270. for _, sourceKey := range sourceKeys {
  271. joinTableFields = append(joinTableFields, &Field{StructField: &StructField{DBName: sourceKey, IsNormal: true}, Field: reflect.New(foreignKeyType).Elem()})
  272. }
  273. scope.scan(rows, columns, append(fields, joinTableFields...))
  274. scope.New(elem.Addr().Interface()).
  275. InstanceSet("gorm:skip_query_callback", true).
  276. callCallbacks(scope.db.parent.callbacks.queries)
  277. var foreignKeys = make([]interface{}, len(sourceKeys))
  278. // generate hashed forkey keys in join table
  279. for idx, joinTableField := range joinTableFields {
  280. if !joinTableField.Field.IsNil() {
  281. foreignKeys[idx] = joinTableField.Field.Elem().Interface()
  282. }
  283. }
  284. hashedSourceKeys := toString(foreignKeys)
  285. if isPtr {
  286. linkHash[hashedSourceKeys] = append(linkHash[hashedSourceKeys], elem.Addr())
  287. } else {
  288. linkHash[hashedSourceKeys] = append(linkHash[hashedSourceKeys], elem)
  289. }
  290. }
  291. if err := rows.Err(); err != nil {
  292. scope.Err(err)
  293. }
  294. // assign find results
  295. var (
  296. indirectScopeValue = scope.IndirectValue()
  297. fieldsSourceMap = map[string][]reflect.Value{}
  298. foreignFieldNames = []string{}
  299. )
  300. for _, dbName := range relation.ForeignFieldNames {
  301. if field, ok := scope.FieldByName(dbName); ok {
  302. foreignFieldNames = append(foreignFieldNames, field.Name)
  303. }
  304. }
  305. if indirectScopeValue.Kind() == reflect.Slice {
  306. for j := 0; j < indirectScopeValue.Len(); j++ {
  307. object := indirect(indirectScopeValue.Index(j))
  308. key := toString(getValueFromFields(object, foreignFieldNames))
  309. fieldsSourceMap[key] = append(fieldsSourceMap[key], object.FieldByName(field.Name))
  310. }
  311. } else if indirectScopeValue.IsValid() {
  312. key := toString(getValueFromFields(indirectScopeValue, foreignFieldNames))
  313. fieldsSourceMap[key] = append(fieldsSourceMap[key], indirectScopeValue.FieldByName(field.Name))
  314. }
  315. for source, link := range linkHash {
  316. for i, field := range fieldsSourceMap[source] {
  317. //If not 0 this means Value is a pointer and we already added preloaded models to it
  318. if fieldsSourceMap[source][i].Len() != 0 {
  319. continue
  320. }
  321. field.Set(reflect.Append(fieldsSourceMap[source][i], link...))
  322. }
  323. }
  324. }