callback_query_preload.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346
  1. package gorm
  2. import (
  3. "errors"
  4. "fmt"
  5. "reflect"
  6. "strings"
  7. )
  8. // preloadCallback used to preload associations
  9. func preloadCallback(scope *Scope) {
  10. if scope.Search.preload == nil || scope.HasError() {
  11. return
  12. }
  13. var (
  14. preloadedMap = map[string]bool{}
  15. fields = scope.Fields()
  16. )
  17. for _, preload := range scope.Search.preload {
  18. var (
  19. preloadFields = strings.Split(preload.schema, ".")
  20. currentScope = scope
  21. currentFields = fields
  22. )
  23. for idx, preloadField := range preloadFields {
  24. var currentPreloadConditions []interface{}
  25. if currentScope == nil {
  26. continue
  27. }
  28. // if not preloaded
  29. if preloadKey := strings.Join(preloadFields[:idx+1], "."); !preloadedMap[preloadKey] {
  30. // assign search conditions to last preload
  31. if idx == len(preloadFields)-1 {
  32. currentPreloadConditions = preload.conditions
  33. }
  34. for _, field := range currentFields {
  35. if field.Name != preloadField || field.Relationship == nil {
  36. continue
  37. }
  38. switch field.Relationship.Kind {
  39. case "has_one":
  40. currentScope.handleHasOnePreload(field, currentPreloadConditions)
  41. case "has_many":
  42. currentScope.handleHasManyPreload(field, currentPreloadConditions)
  43. case "belongs_to":
  44. currentScope.handleBelongsToPreload(field, currentPreloadConditions)
  45. case "many_to_many":
  46. currentScope.handleManyToManyPreload(field, currentPreloadConditions)
  47. default:
  48. scope.Err(errors.New("unsupported relation"))
  49. }
  50. preloadedMap[preloadKey] = true
  51. break
  52. }
  53. if !preloadedMap[preloadKey] {
  54. scope.Err(fmt.Errorf("can't preload field %s for %s", preloadField, currentScope.GetModelStruct().ModelType))
  55. return
  56. }
  57. }
  58. // preload next level
  59. if idx < len(preloadFields)-1 {
  60. currentScope = currentScope.getColumnAsScope(preloadField)
  61. if currentScope != nil {
  62. currentFields = currentScope.Fields()
  63. }
  64. }
  65. }
  66. }
  67. }
  68. func (scope *Scope) generatePreloadDBWithConditions(conditions []interface{}) (*DB, []interface{}) {
  69. var (
  70. preloadDB = scope.NewDB()
  71. preloadConditions []interface{}
  72. )
  73. for _, condition := range conditions {
  74. if scopes, ok := condition.(func(*DB) *DB); ok {
  75. preloadDB = scopes(preloadDB)
  76. } else {
  77. preloadConditions = append(preloadConditions, condition)
  78. }
  79. }
  80. return preloadDB, preloadConditions
  81. }
  82. // handleHasOnePreload used to preload has one associations
  83. func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) {
  84. relation := field.Relationship
  85. // get relations's primary keys
  86. primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value)
  87. if len(primaryKeys) == 0 {
  88. return
  89. }
  90. // preload conditions
  91. preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
  92. // find relations
  93. query := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys))
  94. values := toQueryValues(primaryKeys)
  95. if relation.PolymorphicType != "" {
  96. query += fmt.Sprintf(" AND %v = ?", scope.Quote(relation.PolymorphicDBName))
  97. values = append(values, relation.PolymorphicValue)
  98. }
  99. results := makeSlice(field.Struct.Type)
  100. scope.Err(preloadDB.Where(query, values...).Find(results, preloadConditions...).Error)
  101. // assign find results
  102. var (
  103. resultsValue = indirect(reflect.ValueOf(results))
  104. indirectScopeValue = scope.IndirectValue()
  105. )
  106. if indirectScopeValue.Kind() == reflect.Slice {
  107. for j := 0; j < indirectScopeValue.Len(); j++ {
  108. for i := 0; i < resultsValue.Len(); i++ {
  109. result := resultsValue.Index(i)
  110. foreignValues := getValueFromFields(result, relation.ForeignFieldNames)
  111. if indirectValue := indirect(indirectScopeValue.Index(j)); equalAsString(getValueFromFields(indirectValue, relation.AssociationForeignFieldNames), foreignValues) {
  112. indirectValue.FieldByName(field.Name).Set(result)
  113. break
  114. }
  115. }
  116. }
  117. } else {
  118. for i := 0; i < resultsValue.Len(); i++ {
  119. result := resultsValue.Index(i)
  120. scope.Err(field.Set(result))
  121. }
  122. }
  123. }
  124. // handleHasManyPreload used to preload has many associations
  125. func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) {
  126. relation := field.Relationship
  127. // get relations's primary keys
  128. primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value)
  129. if len(primaryKeys) == 0 {
  130. return
  131. }
  132. // preload conditions
  133. preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
  134. // find relations
  135. query := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys))
  136. values := toQueryValues(primaryKeys)
  137. if relation.PolymorphicType != "" {
  138. query += fmt.Sprintf(" AND %v = ?", scope.Quote(relation.PolymorphicDBName))
  139. values = append(values, relation.PolymorphicValue)
  140. }
  141. results := makeSlice(field.Struct.Type)
  142. scope.Err(preloadDB.Where(query, values...).Find(results, preloadConditions...).Error)
  143. // assign find results
  144. var (
  145. resultsValue = indirect(reflect.ValueOf(results))
  146. indirectScopeValue = scope.IndirectValue()
  147. )
  148. if indirectScopeValue.Kind() == reflect.Slice {
  149. preloadMap := make(map[string][]reflect.Value)
  150. for i := 0; i < resultsValue.Len(); i++ {
  151. result := resultsValue.Index(i)
  152. foreignValues := getValueFromFields(result, relation.ForeignFieldNames)
  153. preloadMap[toString(foreignValues)] = append(preloadMap[toString(foreignValues)], result)
  154. }
  155. for j := 0; j < indirectScopeValue.Len(); j++ {
  156. object := indirect(indirectScopeValue.Index(j))
  157. objectRealValue := getValueFromFields(object, relation.AssociationForeignFieldNames)
  158. f := object.FieldByName(field.Name)
  159. if results, ok := preloadMap[toString(objectRealValue)]; ok {
  160. f.Set(reflect.Append(f, results...))
  161. } else {
  162. f.Set(reflect.MakeSlice(f.Type(), 0, 0))
  163. }
  164. }
  165. } else {
  166. scope.Err(field.Set(resultsValue))
  167. }
  168. }
  169. // handleBelongsToPreload used to preload belongs to associations
  170. func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{}) {
  171. relation := field.Relationship
  172. // preload conditions
  173. preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
  174. // get relations's primary keys
  175. primaryKeys := scope.getColumnAsArray(relation.ForeignFieldNames, scope.Value)
  176. if len(primaryKeys) == 0 {
  177. return
  178. }
  179. // find relations
  180. results := makeSlice(field.Struct.Type)
  181. scope.Err(preloadDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.AssociationForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, preloadConditions...).Error)
  182. // assign find results
  183. var (
  184. resultsValue = indirect(reflect.ValueOf(results))
  185. indirectScopeValue = scope.IndirectValue()
  186. )
  187. for i := 0; i < resultsValue.Len(); i++ {
  188. result := resultsValue.Index(i)
  189. if indirectScopeValue.Kind() == reflect.Slice {
  190. value := getValueFromFields(result, relation.AssociationForeignFieldNames)
  191. for j := 0; j < indirectScopeValue.Len(); j++ {
  192. object := indirect(indirectScopeValue.Index(j))
  193. if equalAsString(getValueFromFields(object, relation.ForeignFieldNames), value) {
  194. object.FieldByName(field.Name).Set(result)
  195. }
  196. }
  197. } else {
  198. scope.Err(field.Set(result))
  199. }
  200. }
  201. }
  202. // handleManyToManyPreload used to preload many to many associations
  203. func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface{}) {
  204. var (
  205. relation = field.Relationship
  206. joinTableHandler = relation.JoinTableHandler
  207. fieldType = field.Struct.Type.Elem()
  208. foreignKeyValue interface{}
  209. foreignKeyType = reflect.ValueOf(&foreignKeyValue).Type()
  210. linkHash = map[string][]reflect.Value{}
  211. isPtr bool
  212. )
  213. if fieldType.Kind() == reflect.Ptr {
  214. isPtr = true
  215. fieldType = fieldType.Elem()
  216. }
  217. var sourceKeys = []string{}
  218. for _, key := range joinTableHandler.SourceForeignKeys() {
  219. sourceKeys = append(sourceKeys, key.DBName)
  220. }
  221. // preload conditions
  222. preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
  223. // generate query with join table
  224. newScope := scope.New(reflect.New(fieldType).Interface())
  225. preloadDB = preloadDB.Table(newScope.TableName()).Model(newScope.Value).Select("*")
  226. preloadDB = joinTableHandler.JoinWith(joinTableHandler, preloadDB, scope.Value)
  227. // preload inline conditions
  228. if len(preloadConditions) > 0 {
  229. preloadDB = preloadDB.Where(preloadConditions[0], preloadConditions[1:]...)
  230. }
  231. rows, err := preloadDB.Rows()
  232. if scope.Err(err) != nil {
  233. return
  234. }
  235. defer rows.Close()
  236. columns, _ := rows.Columns()
  237. for rows.Next() {
  238. var (
  239. elem = reflect.New(fieldType).Elem()
  240. fields = scope.New(elem.Addr().Interface()).Fields()
  241. )
  242. // register foreign keys in join tables
  243. var joinTableFields []*Field
  244. for _, sourceKey := range sourceKeys {
  245. joinTableFields = append(joinTableFields, &Field{StructField: &StructField{DBName: sourceKey, IsNormal: true}, Field: reflect.New(foreignKeyType).Elem()})
  246. }
  247. scope.scan(rows, columns, append(fields, joinTableFields...))
  248. var foreignKeys = make([]interface{}, len(sourceKeys))
  249. // generate hashed forkey keys in join table
  250. for idx, joinTableField := range joinTableFields {
  251. if !joinTableField.Field.IsNil() {
  252. foreignKeys[idx] = joinTableField.Field.Elem().Interface()
  253. }
  254. }
  255. hashedSourceKeys := toString(foreignKeys)
  256. if isPtr {
  257. linkHash[hashedSourceKeys] = append(linkHash[hashedSourceKeys], elem.Addr())
  258. } else {
  259. linkHash[hashedSourceKeys] = append(linkHash[hashedSourceKeys], elem)
  260. }
  261. }
  262. // assign find results
  263. var (
  264. indirectScopeValue = scope.IndirectValue()
  265. fieldsSourceMap = map[string][]reflect.Value{}
  266. foreignFieldNames = []string{}
  267. )
  268. for _, dbName := range relation.ForeignFieldNames {
  269. if field, ok := scope.FieldByName(dbName); ok {
  270. foreignFieldNames = append(foreignFieldNames, field.Name)
  271. }
  272. }
  273. if indirectScopeValue.Kind() == reflect.Slice {
  274. for j := 0; j < indirectScopeValue.Len(); j++ {
  275. object := indirect(indirectScopeValue.Index(j))
  276. key := toString(getValueFromFields(object, foreignFieldNames))
  277. fieldsSourceMap[key] = append(fieldsSourceMap[key], object.FieldByName(field.Name))
  278. }
  279. } else if indirectScopeValue.IsValid() {
  280. key := toString(getValueFromFields(indirectScopeValue, foreignFieldNames))
  281. fieldsSourceMap[key] = append(fieldsSourceMap[key], indirectScopeValue.FieldByName(field.Name))
  282. }
  283. for source, link := range linkHash {
  284. for i, field := range fieldsSourceMap[source] {
  285. //If not 0 this means Value is a pointer and we already added preloaded models to it
  286. if fieldsSourceMap[source][i].Len() != 0 {
  287. continue
  288. }
  289. field.Set(reflect.Append(fieldsSourceMap[source][i], link...))
  290. }
  291. }
  292. }