Skip to content

Commit

Permalink
feat: refactoring of lightgbm code in preparation for single dataset …
Browse files Browse the repository at this point in the history
…mode (#1088)
  • Loading branch information
imatiach-msft authored Jun 16, 2021
1 parent e7d4eca commit e8a97ed
Show file tree
Hide file tree
Showing 6 changed files with 188 additions and 95 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ class SharedSingleton[T: ClassTag](constructor: => T) extends AnyRef with Serial
}

def get: T = instance

}

object SharedSingleton {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,19 @@ object LightGBMConstants {
/** The default num iterations for prediction
*/
val DefaultNumIterations: Int = -1
/** The number of retries for network initialization of native lightgbm
*/
val NetworkRetries: Int = 3
/**
* Delay prior to exponential backoff for network initialization
*/
val InitialDelay: Long = 1000L
}

/** Outcome reported by the driver after handling one incoming worker connection.
  *
  * The values are declared in a fixed order, so their ids are 0 (Finished),
  * 1 (EmptyTask) and 2 (Connected).
  */
object ConnectionState extends Enumeration {
  type ConnectionState = Value
  val Finished: Value = Value
  val EmptyTask: Value = Value
  val Connected: Value = Value
}
136 changes: 76 additions & 60 deletions src/main/scala/com/microsoft/ml/spark/lightgbm/LightGBMUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@ import org.apache.spark.{SparkEnv, TaskContext}
import org.apache.spark.ml.PipelineModel
import org.apache.spark.ml.attribute._
import org.apache.spark.ml.linalg.SparseVector
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.{DataFrame, Dataset}
import org.slf4j.Logger

import ConnectionState._

import scala.collection.immutable.HashSet
import scala.collection.mutable.ListBuffer
import scala.concurrent.duration.{Duration, SECONDS}
Expand Down Expand Up @@ -65,15 +67,6 @@ object LightGBMUtils {
featurizer.fit(dataset)
}

/** Loads a booster from its model string and returns the resulting booster pointer.
  *
  * The load result code is passed to LightGBMUtils.validate together with the
  * label "Booster LoadFromString".
  *
  * @param lgbModelString The model string to load the booster from.
  * @return The pointer value read from the booster out-handle.
  */
def getBoosterPtrFromModelString(lgbModelString: String): SWIGTYPE_p_void = {
  val boosterOutPtr = lightgbmlib.voidpp_handle()
  val numItersOut = lightgbmlib.new_intp()
  try {
    LightGBMUtils.validate(
      lightgbmlib.LGBM_BoosterLoadModelFromString(lgbModelString, numItersOut, boosterOutPtr),
      "Booster LoadFromString")
    lightgbmlib.voidpp_value(boosterOutPtr)
  } finally {
    // The iteration-count out-parameter is never read here, so release it on
    // every path (including a failed validate) instead of leaking it.
    lightgbmlib.delete_intp(numItersOut)
  }
}

def getCategoricalIndexes(df: DataFrame,
featuresCol: String,
slotNames: Array[String],
Expand Down Expand Up @@ -108,6 +101,53 @@ object LightGBMUtils {
categoricalColumnIndexes.union(categoricalIndexes).distinct
}

/** Writes the connection list, terminated by a newline, to every registered socket.
  *
  * @param hostAndPorts Pairs of (socket, connection string); only the socket is used here.
  * @param allConnections The text sent to each socket.
  */
def sendDataToExecutors(hostAndPorts: ListBuffer[(Socket, String)], allConnections: String): Unit = {
  hostAndPorts.iterator.map(_._1).foreach { taskSocket =>
    val out = new BufferedWriter(new OutputStreamWriter(taskSocket.getOutputStream))
    out.write(allConnections + "\n")
    // Flush explicitly: the writer is buffered and is not closed here.
    out.flush()
  }
}

/** Closes every registered task socket and then the driver's server socket.
  *
  * @param log The logger used to announce the shutdown.
  * @param hostAndPorts Pairs of (socket, connection string); each socket is closed.
  * @param driverServerSocket The server socket to close once the task sockets are handled.
  */
def closeConnections(log: Logger, hostAndPorts: ListBuffer[(Socket, String)],
                     driverServerSocket: ServerSocket): Unit = {
  log.info("driver closing all sockets and server socket")
  try {
    hostAndPorts.foreach { case (taskSocket, _) => taskSocket.close() }
  } finally {
    // Always release the server socket, even if closing a task socket threw.
    driverServerSocket.close()
  }
}

/** Logs the received task string and appends (socket, string) to the connection list.
  *
  * @param hostAndPorts The list that receives the new (socket, comm) pair.
  * @param log The logger.
  * @param comm The line received from the task.
  * @param driverSocket The socket the line was received on.
  */
def addSocketAndComm(hostAndPorts: ListBuffer[(Socket, String)], log: Logger,
                     comm: String, driverSocket: Socket): Unit = {
  log.info(s"driver received socket from task: $comm")
  hostAndPorts += (driverSocket -> comm)
}

/** Handles the connection to a task from the driver.
  *
  * Accepts one connection, reads a single line from it and classifies it:
  * the finished status yields Finished, the ignore status yields EmptyTask,
  * and any other line is registered in hostAndPorts and yields Connected.
  *
  * The accepted socket is kept open only when it was registered in
  * hostAndPorts; on every other path (finished, ignore, or a failure while
  * reading) it is closed here because nothing else holds a reference to it.
  *
  * @param driverServerSocket The driver socket.
  * @param log The log4j logger.
  * @param hostAndPorts A list of host and ports of connected tasks.
  * @return The connection status, can be finished for barrier mode, empty task or connected.
  */
def handleConnection(driverServerSocket: ServerSocket, log: Logger,
                     hostAndPorts: ListBuffer[(Socket, String)]): ConnectionState = {
  log.info("driver accepting a new connection...")
  val driverSocket = driverServerSocket.accept()
  // Tracks whether ownership of the socket moved to hostAndPorts.
  var registered = false
  try {
    val reader = new BufferedReader(new InputStreamReader(driverSocket.getInputStream))
    val comm = reader.readLine()
    if (comm == LightGBMConstants.FinishedStatus) {
      log.info("driver received all tasks from barrier stage")
      Finished
    } else if (comm == LightGBMConstants.IgnoreStatus) {
      log.info("driver received ignore status from task")
      EmptyTask
    } else {
      addSocketAndComm(hostAndPorts, log, comm, driverSocket)
      registered = true
      Connected
    }
  } finally {
    // Status-only connections are never written to again, so release them now.
    if (!registered) driverSocket.close()
  }
}

/**
* Opens a socket communications channel on the driver, starts a thread that
* waits for the host:port from the executors, and then sends back the
Expand All @@ -134,64 +174,33 @@ object LightGBMUtils {
val hostAndPorts = ListBuffer[(Socket, String)]()
if (barrierExecutionMode) {
log.info(s"driver using barrier execution mode")
var finished = false
while (!finished) {
log.info("driver accepting a new connection...")
val driverSocket = driverServerSocket.accept()
val reader = new BufferedReader(new InputStreamReader(driverSocket.getInputStream))
val comm = reader.readLine()
if (comm == LightGBMConstants.FinishedStatus) {
log.info("driver received all tasks from barrier stage")
finished = true
} else if (comm == LightGBMConstants.IgnoreStatus) {
log.info("driver received ignore status from task")
} else {
log.info(s"driver received socket from task: $comm")
val socketAndComm = (driverSocket, comm)
hostAndPorts += socketAndComm
}
}
def connectToWorkers: Boolean = handleConnection(driverServerSocket, log,
hostAndPorts) == Finished || connectToWorkers
connectToWorkers
} else {
log.info(s"driver expecting $numTasks connections...")
while (hostAndPorts.size + emptyTaskCounter < numTasks) {
log.info("driver accepting a new connection...")
val driverSocket = driverServerSocket.accept()
val reader = new BufferedReader(new InputStreamReader(driverSocket.getInputStream))
val comm = reader.readLine()
if (comm == LightGBMConstants.IgnoreStatus) {
log.info("driver received ignore status from task")
emptyTaskCounter += 1
} else {
log.info(s"driver received socket from task: $comm")
val socketAndComm = (driverSocket, comm)
hostAndPorts += socketAndComm
}
val connectionResult = handleConnection(driverServerSocket, log, hostAndPorts)
if (connectionResult == ConnectionState.EmptyTask) emptyTaskCounter += 1
}
}
// Concatenate with commas, eg: host1:port1,host2:port2, ... etc
val allConnections = hostAndPorts.map(_._2).mkString(",")
log.info(s"driver writing back to all connections: $allConnections")
// Send data back to all threads on executors
hostAndPorts.foreach(hostAndPort => {
val writer = new BufferedWriter(new OutputStreamWriter(hostAndPort._1.getOutputStream))
writer.write(allConnections + "\n")
writer.flush()
})
log.info("driver closing all sockets and server socket")
hostAndPorts.foreach(_._1.close())
driverServerSocket.close()
// Send data back to all tasks and helper tasks on executors
sendDataToExecutors(hostAndPorts, allConnections)
closeConnections(log, hostAndPorts, driverServerSocket)
}
val host = ClusterUtil.getDriverHost(df)
val port = driverServerSocket.getLocalPort
log.info(s"driver waiting for connections on host: $host and port: $port")
(host, port, f)
}

/** Returns an integer ID for the current node.
*
* @return In cluster, returns the executor id. In local case, returns the task id.
/** Returns an integer ID for the current worker.
* @return In cluster, returns the executor id. In local case, returns the partition id.
*/
def getId(): Int = {
def getWorkerId(): Int = {
val executorId = SparkEnv.get.executorId
val ctx = TaskContext.get
val partId = ctx.partitionId
Expand All @@ -201,14 +210,21 @@ object LightGBMUtils {
idAsInt
}

/** Copies a matrix of doubles into an array allocated via lightgbmlib.new_doubleArray.
  *
  * Values are laid out row by row: element (row, col) is stored at
  * row * numCols + col, where numCols is the length of the first row.
  * The flat index is computed as a Long so it cannot overflow Int on large
  * matrices (the allocation size is already computed as a Long).
  *
  * @param numRows The number of rows used to size the allocation
  *                (presumably rowsAsDoubleArray.length; verify against callers).
  * @param rowsAsDoubleArray The rows to copy; must be non-empty because the
  *                          column count is read from its first row.
  * @return The array as a void pointer together with the typed double pointer.
  */
def generateData(numRows: Int, rowsAsDoubleArray: Array[Array[Double]]):
    (SWIGTYPE_p_void, SWIGTYPE_p_double) = {
  val numCols = rowsAsDoubleArray.head.length
  val data = lightgbmlib.new_doubleArray(numCols.toLong * numRows.toLong)
  rowsAsDoubleArray.zipWithIndex.foreach { case (row, rowIndex) =>
    // Widen before multiplying: rowIndex * numCols can exceed Int.MaxValue.
    val rowOffset = rowIndex.toLong * numCols
    row.zipWithIndex.foreach { case (value, colIndex) =>
      lightgbmlib.doubleArray_setitem(data, rowOffset + colIndex, value)
    }
  }
  (lightgbmlib.double_to_voidp_ptr(data), data)
}
/** Checks whether the current SparkEnv's executor id is "driver".
  * @return True if the executor id equals "driver", which this code treats as local mode.
  */
def isLocalExecution(): Boolean = SparkEnv.get.executorId == "driver"

/** Returns the task attempt id of the current TaskContext.
  * @return The value of taskAttemptId() for the running task.
  */
def getTaskId(): Long = TaskContext.get.taskAttemptId()

def getNumRowsForChunksArray(numRows: Int, chunkSize: Int): SWIGTYPE_p_int = {
Expand Down
59 changes: 33 additions & 26 deletions src/main/scala/com/microsoft/ml/spark/lightgbm/TrainUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import com.microsoft.ml.spark.downloader.FaultToleranceUtils
import com.microsoft.ml.spark.lightgbm.booster.LightGBMBooster
import com.microsoft.ml.spark.lightgbm.dataset.LightGBMDataset
import com.microsoft.ml.spark.lightgbm.params.{ClassifierTrainParams, TrainParams}
import com.microsoft.ml.spark.lightgbm.swig.SwigUtils
import org.apache.spark.{BarrierTaskContext, TaskContext}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.ml.attribute._
Expand Down Expand Up @@ -328,31 +327,22 @@ private object TrainUtils extends Serializable {
log: Logger,
iters: Int): Boolean = {
var isFinished = false
val isFinishedPtr = lightgbmlib.new_intp()
try {
val result =
if (trainParams.objectiveParams.fobj.isDefined) {
val classification = trainParams.isInstanceOf[ClassifierTrainParams]
val (gradient, hessian) = trainParams.objectiveParams.fobj.get.getGradient(
booster.innerPredict(0, classification), booster.trainDataset.get)
val gradPtr = SwigUtils.floatArrayToNative(gradient)
val hessPtr = SwigUtils.floatArrayToNative(hessian)
lightgbmlib.LGBM_BoosterUpdateOneIterCustom(booster.boosterHandler.boosterPtr,
gradPtr, hessPtr, isFinishedPtr)
isFinished = booster.updateOneIterationCustom(gradient, hessian)
} else {
lightgbmlib.LGBM_BoosterUpdateOneIter(booster.boosterHandler.boosterPtr, isFinishedPtr)
isFinished = booster.updateOneIteration()
}
LightGBMUtils.validate(result, "Booster Update One Iter")
isFinished = lightgbmlib.intp_value(isFinishedPtr) == 1
log.info("LightGBM running iteration: " + iters + " with is finished: " + isFinished)
} catch {
case e: java.lang.Exception =>
log.warn("LightGBM reached early termination on one task," +
" stopping training on task. This message should rarely occur." +
" Inner exception: " + e.toString)
isFinished = true
} finally {
lightgbmlib.delete_intp(isFinishedPtr)
}
isFinished
}
Expand All @@ -374,7 +364,7 @@ private object TrainUtils extends Serializable {
val newLearningRate = getLearningRate(batchIndex, partitionId, iters, log, trainParams,
learningRate)
if (newLearningRate != learningRate) {
log.info(s"LightGBM task calling LGBM_BoosterResetParameter to reset learningRate" +
log.info(s"LightGBM task calling booster.resetParameter to reset learningRate" +
s" (newLearningRate: $newLearningRate)")
booster.resetParameter(s"learning_rate=$newLearningRate")
learningRate = newLearningRate
Expand Down Expand Up @@ -521,7 +511,7 @@ private object TrainUtils extends Serializable {
}

private def findOpenPort(defaultListenPort: Int, numTasksPerExec: Int, log: Logger): Socket = {
val basePort = defaultListenPort + (LightGBMUtils.getId() * numTasksPerExec)
val basePort = defaultListenPort + (LightGBMUtils.getWorkerId() * numTasksPerExec)
if (basePort > LightGBMConstants.MaxPort) {
throw new Exception(s"Error: port $basePort out of range, possibly due to too many executors or unknown error")
}
Expand Down Expand Up @@ -645,24 +635,41 @@ private object TrainUtils extends Serializable {
mainPort.toInt
}

def trainLightGBM(batchIndex: Int, networkParams: NetworkParams, columnParams: ColumnParams,
validationData: Option[Broadcast[Array[Row]]], log: Logger,
trainParams: TrainParams, numTasksPerExec: Int, schema: StructType)
(inputRows: Iterator[Row]): Iterator[LightGBMBooster] = {
val emptyPartition = !inputRows.hasNext
// Ideally we would start the socket connections in the C layer, this opens us up for
// race conditions in case other applications open sockets on cluster, but usually this
// should not be a problem
val (nodes, localListenPort) = using(findOpenPort(networkParams.defaultListenPort, numTasksPerExec, log)) {
/** Retrieve the network nodes and current port information.
*
* Establish local socket connection.
*
* Note: Ideally we would start the socket connections in the C layer, this opens us up for
* race conditions in case other applications open sockets on cluster, but usually this
* should not be a problem
*
* @param networkParams The network parameters.
* @param numTasksPerExec The number of tasks per executor.
* @param log The logger.
* @param isEnabledWorker True if the current worker is enabled, including whether the partition
* was enabled and this is the chosen worker to initialize the network connection.
* @return A tuple containing the string with all nodes and the current worker's open socket connection.
*/
def getNetworkInfo(networkParams: NetworkParams, numTasksPerExec: Int,
log: Logger, isEnabledWorker: Boolean): (String, Int) = {
using(findOpenPort(networkParams.defaultListenPort, numTasksPerExec, log)) {
openPort =>
val localListenPort = openPort.getLocalPort
// Initialize the native library
LightGBMUtils.initializeNativeLibrary()
log.info(s"LightGBM task connecting to host: ${networkParams.addr} and port: ${networkParams.port}")
FaultToleranceUtils.retryWithTimeout() {
(getNetworkInitNodes(networkParams, localListenPort, log, emptyPartition), localListenPort)
(getNetworkInitNodes(networkParams, localListenPort, log, !isEnabledWorker), localListenPort)
}
}.get
}

def trainLightGBM(batchIndex: Int, networkParams: NetworkParams, columnParams: ColumnParams,
validationData: Option[Broadcast[Array[Row]]], log: Logger,
trainParams: TrainParams, numTasksPerExec: Int, schema: StructType)
(inputRows: Iterator[Row]): Iterator[LightGBMBooster] = {
val emptyPartition = !inputRows.hasNext
// Initialize the native library
LightGBMUtils.initializeNativeLibrary()
val (nodes, localListenPort) = getNetworkInfo(networkParams, numTasksPerExec, log, !emptyPartition)

if (emptyPartition) {
log.warn("LightGBM task encountered empty partition, for best performance ensure no partitions empty")
Expand Down
Loading

0 comments on commit e8a97ed

Please sign in to comment.