diff --git a/core/src/main/scala/com/github/mrpowers/spark/daria/sql/types/StructTypeHelpers.scala b/core/src/main/scala/com/github/mrpowers/spark/daria/sql/types/StructTypeHelpers.scala index f2a9a8b..8b01c04 100644 --- a/core/src/main/scala/com/github/mrpowers/spark/daria/sql/types/StructTypeHelpers.scala +++ b/core/src/main/scala/com/github/mrpowers/spark/daria/sql/types/StructTypeHelpers.scala @@ -39,27 +39,27 @@ object StructTypeHelpers { }) } - private def schemaToSortedSelectExpr[A](schema: StructType, f: StructField => A, baseField: String = "")(implicit ord: Ordering[A]): Seq[Column] = { - def handleNestedType(t: DataType, name: String, outerCol: Column): Column = + private def schemaToSortedSelectExpr[A](schema: StructType, f: StructField => A)(implicit ord: Ordering[A]): Seq[Column] = { + def handleNestedType(t: DataType, name: String, outerCol: Column, firstLevel: Boolean = false): Column = t match { case st: StructType => - val sortedFields = st.fields.sortBy(f) struct( - sortedFields.map(f => - f.dataType match { - case st: StructType => - handleNestedType(st, f.name, outerCol(f.name)).as(f.name) - case ArrayType(innerType: StructType, _) => - handleArrayType(f.dataType, name, outerCol(f.name)).as(f.name) - case _ => - handleNestedType(f.dataType, f.name, outerCol).as(f.name) - } - ): _* + st.fields + .sortBy(f) + .map(field => + handleNestedType( + field.dataType, + field.name, + field.dataType match { + case StructType(_) | ArrayType(_: StructType, _) => outerCol(field.name) + case _ => outerCol + } + ).as(field.name) + ): _* ).as(name) - case ArrayType(_, _) => - handleArrayType(t, name, outerCol).as(name) - case _ => - outerCol(name) + case ArrayType(_, _) => handleArrayType(t, name, outerCol).as(name) + case _ if firstLevel => outerCol + case _ if !firstLevel => outerCol(name) } // For handling reordering of nested arrays @@ -74,16 +74,7 @@ object StructTypeHelpers { val result = schema.fields.sortBy(f).foldLeft(Seq.empty[Column]) { case (acc, field) => - val name = field.name - val sortedCol = field.dataType match { - case st: StructType => - handleNestedType(st, name, col(name)) - case arr: ArrayType => - handleArrayType(arr, name, col(name)) - case _ => col(name) - } - - acc :+ sortedCol + acc :+ handleNestedType(field.dataType, field.name, col(field.name), firstLevel = true) } result }