diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala b/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala index f83bcc4207b2..4c6a64d99034 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala @@ -125,7 +125,8 @@ private[mxnet] class MXDataIter(private[mxnet] val handle: DataIterHandle, checkCall(_LIB.mxDataIterNext(handle, next)) if (next.value > 0) { currentBatch = new DataBatch(data = getData(), label = getLabel(), - index = getIndex(), pad = getPad()) + index = getIndex(), pad = getPad(), + dtype = currentBatch.dtype, layout = currentBatch.layout) } else { currentBatch = null }