decode.go 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296
  1. package gocsv
  2. import (
  3. "encoding/csv"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "reflect"
  8. )
  // Decoder is the contract for a CSV source that hands back all of its rows
  // in a single call. readTo consumes this interface.
  type Decoder interface {
  	// getCSVRows returns every row of the CSV input, or the error that
  	// prevented them from being read.
  	getCSVRows() ([][]string, error)
  }
  // SimpleDecoder is the contract for a CSV source that is consumed one row
  // at a time. readEach consumes this interface and treats io.EOF as the
  // end-of-input signal.
  type SimpleDecoder interface {
  	// getCSVRow returns the next row of the CSV input, or an error.
  	getCSVRow() ([]string, error)
  }
  17. type decoder struct {
  18. in io.Reader
  19. csvDecoder *csvDecoder
  20. }
  21. func newDecoder(in io.Reader) *decoder {
  22. return &decoder{in: in}
  23. }
  24. func (decode *decoder) getCSVRows() ([][]string, error) {
  25. return getCSVReader(decode.in).ReadAll()
  26. }
  27. func (decode *decoder) getCSVRow() ([]string, error) {
  28. if decode.csvDecoder == nil {
  29. decode.csvDecoder = &csvDecoder{getCSVReader(decode.in)}
  30. }
  31. return decode.csvDecoder.Read()
  32. }
  // CSVReader is the subset of a CSV reader that this package relies on:
  // reading a single record and reading all remaining records.
  type CSVReader interface {
  	Read() ([]string, error)
  	ReadAll() ([][]string, error)
  }

  // csvDecoder wraps a CSVReader so that it satisfies both Decoder and
  // SimpleDecoder.
  type csvDecoder struct {
  	CSVReader
  }

  // getCSVRows forwards to the wrapped reader's ReadAll.
  func (dec csvDecoder) getCSVRows() ([][]string, error) {
  	rows, err := dec.CSVReader.ReadAll()
  	return rows, err
  }

  // getCSVRow forwards to the wrapped reader's Read.
  func (dec csvDecoder) getCSVRow() ([]string, error) {
  	row, err := dec.CSVReader.Read()
  	return row, err
  }
  46. func maybeMissingStructFields(structInfo []fieldInfo, headers []string) error {
  47. if len(structInfo) == 0 {
  48. return nil
  49. }
  50. headerMap := make(map[string]struct{}, len(headers))
  51. for idx := range headers {
  52. headerMap[headers[idx]] = struct{}{}
  53. }
  54. for _, info := range structInfo {
  55. found := false
  56. for _, key := range info.keys {
  57. if _, ok := headerMap[key]; ok {
  58. found = true
  59. break
  60. }
  61. }
  62. if !found {
  63. return fmt.Errorf("found unmatched struct field with tags %v", info.keys)
  64. }
  65. }
  66. return nil
  67. }
  // maybeDoubleHeaderNames returns an error naming the first header that
  // occurs more than once in headers, scanning left to right; it returns nil
  // when all header names are distinct.
  func maybeDoubleHeaderNames(headers []string) error {
  	seen := make(map[string]struct{}, len(headers))
  	for i := 0; i < len(headers); i++ {
  		name := headers[i]
  		if _, dup := seen[name]; dup {
  			return fmt.Errorf("Repeated header name: %v", name)
  		}
  		seen[name] = struct{}{}
  	}
  	return nil
  }
  79. func readTo(decoder Decoder, out interface{}) error {
  80. outValue, outType := getConcreteReflectValueAndType(out) // Get the concrete type (not pointer) (Slice<?> or Array<?>)
  81. if err := ensureOutType(outType); err != nil {
  82. return err
  83. }
  84. outInnerWasPointer, outInnerType := getConcreteContainerInnerType(outType) // Get the concrete inner type (not pointer) (Container<"?">)
  85. if err := ensureOutInnerType(outInnerType); err != nil {
  86. return err
  87. }
  88. csvRows, err := decoder.getCSVRows() // Get the CSV csvRows
  89. if err != nil {
  90. return err
  91. }
  92. if len(csvRows) == 0 {
  93. return errors.New("empty csv file given")
  94. }
  95. if err := ensureOutCapacity(&outValue, len(csvRows)); err != nil { // Ensure the container is big enough to hold the CSV content
  96. return err
  97. }
  98. outInnerStructInfo := getStructInfo(outInnerType) // Get the inner struct info to get CSV annotations
  99. if len(outInnerStructInfo.Fields) == 0 {
  100. return errors.New("no csv struct tags found")
  101. }
  102. headers := csvRows[0]
  103. body := csvRows[1:]
  104. csvHeadersLabels := make(map[int]*fieldInfo, len(outInnerStructInfo.Fields)) // Used to store the correspondance header <-> position in CSV
  105. headerCount := map[string]int{}
  106. for i, csvColumnHeader := range headers {
  107. curHeaderCount := headerCount[csvColumnHeader]
  108. if fieldInfo := getCSVFieldPosition(csvColumnHeader, outInnerStructInfo, curHeaderCount); fieldInfo != nil {
  109. csvHeadersLabels[i] = fieldInfo
  110. if ShouldAlignDuplicateHeadersWithStructFieldOrder {
  111. curHeaderCount++
  112. headerCount[csvColumnHeader] = curHeaderCount
  113. }
  114. }
  115. }
  116. if FailIfUnmatchedStructTags {
  117. if err := maybeMissingStructFields(outInnerStructInfo.Fields, headers); err != nil {
  118. return err
  119. }
  120. }
  121. if FailIfDoubleHeaderNames {
  122. if err := maybeDoubleHeaderNames(headers); err != nil {
  123. return err
  124. }
  125. }
  126. for i, csvRow := range body {
  127. outInner := createNewOutInner(outInnerWasPointer, outInnerType)
  128. for j, csvColumnContent := range csvRow {
  129. if fieldInfo, ok := csvHeadersLabels[j]; ok { // Position found accordingly to header name
  130. if err := setInnerField(&outInner, outInnerWasPointer, fieldInfo.IndexChain, csvColumnContent, fieldInfo.omitEmpty); err != nil { // Set field of struct
  131. return &csv.ParseError{
  132. Line: i + 2, //add 2 to account for the header & 0-indexing of arrays
  133. Column: j + 1,
  134. Err: err,
  135. }
  136. }
  137. }
  138. }
  139. outValue.Index(i).Set(outInner)
  140. }
  141. return nil
  142. }
  143. func readEach(decoder SimpleDecoder, c interface{}) error {
  144. headers, err := decoder.getCSVRow()
  145. if err != nil {
  146. return err
  147. }
  148. outValue, outType := getConcreteReflectValueAndType(c) // Get the concrete type (not pointer) (Slice<?> or Array<?>)
  149. if err := ensureOutType(outType); err != nil {
  150. return err
  151. }
  152. defer outValue.Close()
  153. outInnerWasPointer, outInnerType := getConcreteContainerInnerType(outType) // Get the concrete inner type (not pointer) (Container<"?">)
  154. if err := ensureOutInnerType(outInnerType); err != nil {
  155. return err
  156. }
  157. outInnerStructInfo := getStructInfo(outInnerType) // Get the inner struct info to get CSV annotations
  158. if len(outInnerStructInfo.Fields) == 0 {
  159. return errors.New("no csv struct tags found")
  160. }
  161. csvHeadersLabels := make(map[int]*fieldInfo, len(outInnerStructInfo.Fields)) // Used to store the correspondance header <-> position in CSV
  162. headerCount := map[string]int{}
  163. for i, csvColumnHeader := range headers {
  164. curHeaderCount := headerCount[csvColumnHeader]
  165. if fieldInfo := getCSVFieldPosition(csvColumnHeader, outInnerStructInfo, curHeaderCount); fieldInfo != nil {
  166. csvHeadersLabels[i] = fieldInfo
  167. if ShouldAlignDuplicateHeadersWithStructFieldOrder {
  168. curHeaderCount++
  169. headerCount[csvColumnHeader] = curHeaderCount
  170. }
  171. }
  172. }
  173. if err := maybeMissingStructFields(outInnerStructInfo.Fields, headers); err != nil {
  174. if FailIfUnmatchedStructTags {
  175. return err
  176. }
  177. }
  178. if FailIfDoubleHeaderNames {
  179. if err := maybeDoubleHeaderNames(headers); err != nil {
  180. return err
  181. }
  182. }
  183. i := 0
  184. for {
  185. line, err := decoder.getCSVRow()
  186. if err == io.EOF {
  187. break
  188. } else if err != nil {
  189. return err
  190. }
  191. outInner := createNewOutInner(outInnerWasPointer, outInnerType)
  192. for j, csvColumnContent := range line {
  193. if fieldInfo, ok := csvHeadersLabels[j]; ok { // Position found accordingly to header name
  194. if err := setInnerField(&outInner, outInnerWasPointer, fieldInfo.IndexChain, csvColumnContent, fieldInfo.omitEmpty); err != nil { // Set field of struct
  195. return &csv.ParseError{
  196. Line: i + 2, //add 2 to account for the header & 0-indexing of arrays
  197. Column: j + 1,
  198. Err: err,
  199. }
  200. }
  201. }
  202. }
  203. outValue.Send(outInner)
  204. i++
  205. }
  206. return nil
  207. }
  // ensureOutType returns nil when outType is a slice, a channel or an array,
  // and an error naming the rejected type otherwise.
  //
  // The type name is passed as a format argument rather than spliced into the
  // format string, so a '%' inside it (for example in a struct tag of an
  // anonymous struct) is reported verbatim instead of being parsed as a verb.
  func ensureOutType(outType reflect.Type) error {
  	switch outType.Kind() {
  	case reflect.Slice, reflect.Chan, reflect.Array:
  		return nil
  	}
  	return fmt.Errorf("cannot use %s, only slice or array supported", outType.String())
  }
  // ensureOutInnerType returns nil when outInnerType is a struct, and an error
  // naming the rejected type otherwise.
  //
  // The type name is passed as a format argument rather than spliced into the
  // format string, so a '%' inside it (for example in a struct tag of an
  // anonymous struct) is reported verbatim instead of being parsed as a verb.
  func ensureOutInnerType(outInnerType reflect.Type) error {
  	if outInnerType.Kind() == reflect.Struct {
  		return nil
  	}
  	return fmt.Errorf("cannot use %s, only struct supported", outInnerType.String())
  }
  // ensureOutCapacity checks that out can hold csvLen-1 elements (every CSV
  // row except the header).
  //
  //   - Array: returns an error when it is shorter than required.
  //   - Slice: when shorter than required, it is replaced by a new slice of
  //     exactly the required length if out is addressable; otherwise an error
  //     is returned. A slice that is already long enough is left untouched.
  //   - Any other kind: no check is performed and nil is returned.
  func ensureOutCapacity(out *reflect.Value, csvLen int) error {
  	needed := csvLen - 1 // the header row does not occupy an element

  	switch out.Kind() {
  	case reflect.Array:
  		if out.Len() < needed {
  			return fmt.Errorf("array capacity problem: cannot store %d %s in %s", csvLen-1, out.Type().Elem().String(), out.Type().String())
  		}

  	case reflect.Slice:
  		if out.Len() >= needed {
  			break // already big enough
  		}
  		if !out.CanAddr() {
  			// Too short and cannot be replaced in place.
  			return fmt.Errorf("slice capacity problem and is not addressable (did you forget &?)")
  		}
  		grown := reflect.MakeSlice(out.Type(), needed, needed)
  		out.Set(grown)
  	}
  	return nil
  }
  243. func getCSVFieldPosition(key string, structInfo *structInfo, curHeaderCount int) *fieldInfo {
  244. matchedFieldCount := 0
  245. for _, field := range structInfo.Fields {
  246. if field.matchesKey(key) {
  247. if matchedFieldCount >= curHeaderCount {
  248. return &field
  249. } else {
  250. matchedFieldCount++
  251. }
  252. }
  253. }
  254. return nil
  255. }
  // createNewOutInner allocates a zero value of outInnerType. It returns the
  // pointer to it when outInnerWasPointer is true, and the addressable value
  // itself otherwise.
  func createNewOutInner(outInnerWasPointer bool, outInnerType reflect.Type) reflect.Value {
  	ptr := reflect.New(outInnerType)
  	if !outInnerWasPointer {
  		return ptr.Elem()
  	}
  	return ptr
  }
  262. func setInnerField(outInner *reflect.Value, outInnerWasPointer bool, index []int, value string, omitEmpty bool) error {
  263. oi := *outInner
  264. if outInnerWasPointer {
  265. oi = outInner.Elem()
  266. }
  267. return setField(oi.FieldByIndex(index), value, omitEmpty)
  268. }