diff --git a/3rdparty/dmlc-core b/3rdparty/dmlc-core index dadcd97fdceb..649be18a8c55 160000 --- a/3rdparty/dmlc-core +++ b/3rdparty/dmlc-core @@ -1 +1 @@ -Subproject commit dadcd97fdceb5f395e963b2a637f6ed377f59fc4 +Subproject commit 649be18a8c55c48517861d67158a45dec54992ee diff --git a/src/io/image_iter_common.h b/src/io/image_iter_common.h index 56822888a445..8580ff8f9f9c 100644 --- a/src/io/image_iter_common.h +++ b/src/io/image_iter_common.h @@ -348,6 +348,7 @@ struct PrefetcherParam : public dmlc::Parameter { .add_enum("float32", mshadow::kFloat32) .add_enum("float64", mshadow::kFloat64) .add_enum("float16", mshadow::kFloat16) + .add_enum("int64", mshadow::kInt64) .add_enum("int32", mshadow::kInt32) .add_enum("uint8", mshadow::kUint8) .set_default(dmlc::optional()) diff --git a/src/io/iter_csv.cc b/src/io/iter_csv.cc index ca3f042f45a3..5fd149535be2 100644 --- a/src/io/iter_csv.cc +++ b/src/io/iter_csv.cc @@ -174,15 +174,21 @@ class CSVIter: public IIterator { for (const auto& arg : kwargs) { if (arg.first == "dtype") { dtype_has_value = true; - if (arg.second == "int32" || arg.second == "float32") { - target_dtype = (arg.second == "int32") ? mshadow::kInt32 : mshadow::kFloat32; + if (arg.second == "int32") { + target_dtype = mshadow::kInt32; + } else if (arg.second == "int64") { + target_dtype = mshadow::kInt64; + } else if (arg.second == "float32") { + target_dtype = mshadow::kFloat32; } else { CHECK(false) << arg.second << " is not supported for CSVIter"; } } } if (dtype_has_value && target_dtype == mshadow::kInt32) { - iterator_.reset(reinterpret_cast(new CSVIterTyped())); + iterator_.reset(reinterpret_cast(new CSVIterTyped())); + } else if (dtype_has_value && target_dtype == mshadow::kInt64) { + iterator_.reset(reinterpret_cast(new CSVIterTyped())); } else if (!dtype_has_value || target_dtype == mshadow::kFloat32) { iterator_.reset(reinterpret_cast(new CSVIterTyped())); } @@ -229,8 +235,8 @@ If ``data_csv = 'data/'`` is set, then all the files in this directory will be r ``reset()`` is expected to be called only after a complete pass of data. By default, the CSVIter parses all entries in the data file as float32 data type, -if `dtype` argument is set to be 'int32' then CSVIter will parse all entries in the file -as int32 data type. +if `dtype` argument is set to be 'int32' or 'int64' then CSVIter will parse all entries in the file +as int32 or int64 data type accordingly. Examples::