Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[MXNET-531] Custom Operator Example for Scala (#11401)
Browse files Browse the repository at this point in the history
Update Custom Operator Example to use new Symbol.api
  • Loading branch information
lanking520 authored and nswamy committed Jul 18, 2018
1 parent 896271b commit 072dd5a
Show file tree
Hide file tree
Showing 10 changed files with 354 additions and 243 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,5 @@ package org.apache.mxnet
* Main code will be generated during compile time through Macros
*/
object NDArrayAPI extends NDArrayAPIBase {
// TODO: Implement CustomOp for NDArray
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,19 @@
*/
package org.apache.mxnet

import scala.collection.mutable


@AddSymbolAPIs(false)
/**
* typesafe Symbol API: Symbol.api._
* Main code will be generated during compile time through Macros
*/
object SymbolAPI extends SymbolAPIBase {
def Custom (op_type : String, kwargs : mutable.Map[String, Any],
name : String = null, attr : Map[String, String] = null) : Symbol = {
val map = kwargs
map.put("op_type", op_type)
Symbol.createSymbolGeneral("Custom", name, attr, Seq(), map.toMap)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,8 @@

package org.apache.mxnetexamples.customop

import org.apache.mxnet.Shape
import org.apache.mxnet.IO
import org.apache.mxnet.DataIter
import org.apache.mxnet.{DataIter, IO, Shape}

/**
* @author Depeng Liang
*/
object Data {
// return train and val iterators for mnist
def mnistIterator(dataPath: String, batchSize: Int, inputShape: Shape): (DataIter, DataIter) = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,37 +17,26 @@

package org.apache.mxnetexamples.customop

import org.apache.mxnet.Callback.Speedometer
import org.apache.mxnet.DType.DType
import org.apache.mxnet.{Accuracy, Context, CustomOp, CustomOpProp, NDArray, Operator, Shape, Symbol, Xavier}
import org.apache.mxnet.optimizer.RMSProp
import org.kohsuke.args4j.{CmdLineParser, Option}
import org.slf4j.LoggerFactory

import scala.collection.JavaConverters._
import org.apache.mxnet.Symbol
import org.apache.mxnet.DType.DType
import org.apache.mxnet.DataIter
import org.apache.mxnet.DataBatch
import org.apache.mxnet.NDArray
import org.apache.mxnet.Shape
import org.apache.mxnet.EvalMetric
import org.apache.mxnet.Context
import org.apache.mxnet.Xavier
import org.apache.mxnet.optimizer.RMSProp
import org.apache.mxnet.CustomOp
import org.apache.mxnet.CustomOpProp
import org.apache.mxnet.Operator
import org.apache.mxnet.optimizer.SGD
import org.apache.mxnet.Accuracy
import org.apache.mxnet.Callback.Speedometer
import scala.collection.mutable

/**
* Example of CustomOp
* @author Depeng Liang
*/
* Example of CustomOp
*/
object ExampleCustomOp {
private val logger = LoggerFactory.getLogger(classOf[ExampleCustomOp])

class Softmax(_param: Map[String, String]) extends CustomOp {

override def forward(sTrain: Boolean, req: Array[String],
inData: Array[NDArray], outData: Array[NDArray], aux: Array[NDArray]): Unit = {
override def forward(sTrain: Boolean, req: Array[String], inData: Array[NDArray],
outData: Array[NDArray], aux: Array[NDArray]): Unit = {
val xShape = inData(0).shape
val x = inData(0).toArray.grouped(xShape(1)).toArray
val yArr = x.map { it =>
Expand All @@ -63,8 +52,8 @@ object ExampleCustomOp {
}

override def backward(req: Array[String], outGrad: Array[NDArray],
inData: Array[NDArray], outData: Array[NDArray],
inGrad: Array[NDArray], aux: Array[NDArray]): Unit = {
inData: Array[NDArray], outData: Array[NDArray],
inGrad: Array[NDArray], aux: Array[NDArray]): Unit = {
val l = inData(1).toArray.map(_.toInt)
val oShape = outData(0).shape
val yArr = outData(0).toArray.grouped(oShape(1)).toArray
Expand All @@ -86,24 +75,121 @@ object ExampleCustomOp {
override def listOutputs(): Array[String] = Array("output")

override def inferShape(inShape: Array[Shape]):
(Array[Shape], Array[Shape], Array[Shape]) = {
(Array[Shape], Array[Shape], Array[Shape]) = {
val dataShape = inShape(0)
val labelShape = Shape(dataShape(0))
val outputShape = dataShape
(Array(dataShape, labelShape), Array(outputShape), null)
}

override def inferType(inType: Array[DType]):
(Array[DType], Array[DType], Array[DType]) = {
(Array[DType], Array[DType], Array[DType]) = {
(inType, inType.take(1), null)
}

override def createOperator(ctx: String, inShapes: Array[Array[Int]],
inDtypes: Array[Int]): CustomOp = new Softmax(this.kwargs)
inDtypes: Array[Int]): CustomOp = new Softmax(this.kwargs)
}

Operator.register("softmax", new SoftmaxProp)

def test(dataPath : String, ctx : Context) : Float = {
val data = Symbol.Variable("data")
val label = Symbol.Variable("label")
val fc1 = Symbol.api.FullyConnected(data = Some(data), num_hidden = 128, name = "fc1")
val act1 = Symbol.api.Activation (data = Some(fc1), "relu", name = "relu")
val fc2 = Symbol.api.FullyConnected(Some(act1), None, None, 64, name = "fc2")
val act2 = Symbol.api.Activation(data = Some(fc2), "relu", name = "relu2")
val fc3 = Symbol.api.FullyConnected(Some(act2), None, None, 10, name = "fc3")
val kwargs = mutable.Map[String, Any]("label" -> label, "data" -> fc3)
val mlp = Symbol.api.Custom(op_type = "softmax", name = "softmax", kwargs = kwargs)

val (trainIter, testIter) =
Data.mnistIterator(dataPath, batchSize = 100, inputShape = Shape(784))

val datasAndLabels = trainIter.provideData ++ trainIter.provideLabel
val (argShapes, outputShapes, auxShapes) = mlp.inferShape(datasAndLabels)

val initializer = new Xavier(factorType = "in", magnitude = 2.34f)
val argNames = mlp.listArguments()
val argDict = argNames.zip(argShapes.map(s => NDArray.empty(s, ctx))).toMap

val gradDict = argNames.zip(argShapes).filter { case (name, shape) =>
!datasAndLabels.contains(name)
}.map(x => x._1 -> NDArray.empty(x._2, ctx) ).toMap

argDict.foreach { case (name, ndArray) =>
if (!datasAndLabels.contains(name)) {
initializer.initWeight(name, ndArray)
}
}

val executor = mlp.bind(ctx, argDict, gradDict)
val lr = 0.001f
val opt = new RMSProp(learningRate = lr, wd = 0.00001f)
val paramsGrads = gradDict.toList.zipWithIndex.map { case ((name, grad), idx) =>
(idx, name, grad, opt.createState(idx, argDict(name)))
}

val evalMetric = new Accuracy
val batchEndCallback = new Speedometer(100, 100)
val numEpoch = 10
var validationAcc = 0.0f

for (epoch <- 0 until numEpoch) {
val tic = System.currentTimeMillis
evalMetric.reset()
var nBatch = 0
var epochDone = false

trainIter.reset()
while (!epochDone) {
var doReset = true
while (doReset && trainIter.hasNext) {
val dataBatch = trainIter.next()
argDict("data").set(dataBatch.data(0))
argDict("label").set(dataBatch.label(0))
executor.forward(isTrain = true)
executor.backward()
paramsGrads.foreach { case (idx, name, grad, optimState) =>
opt.update(idx, argDict(name), grad, optimState)
}
evalMetric.update(dataBatch.label, executor.outputs)
nBatch += 1
batchEndCallback.invoke(epoch, nBatch, evalMetric)
}
if (doReset) {
trainIter.reset()
}
epochDone = true
}
val (name, value) = evalMetric.get
name.zip(value).foreach { case (n, v) =>
logger.info(s"Epoch[$epoch] Train-accuracy=$v")
}
val toc = System.currentTimeMillis
logger.info(s"Epoch[$epoch] Time cost=${toc - tic}")

evalMetric.reset()
testIter.reset()
while (testIter.hasNext) {
val evalBatch = testIter.next()
argDict("data").set(evalBatch.data(0))
argDict("label").set(evalBatch.label(0))
executor.forward(isTrain = true)
evalMetric.update(evalBatch.label, executor.outputs)
evalBatch.dispose()
}
val (names, values) = evalMetric.get
names.zip(values).foreach { case (n, v) =>
logger.info(s"Epoch[$epoch] Validation-accuracy=$v")
validationAcc = Math.max(validationAcc, v)
}
}
executor.dispose()
validationAcc
}

def main(args: Array[String]): Unit = {
val leop = new ExampleCustomOp
val parser: CmdLineParser = new CmdLineParser(leop)
Expand All @@ -115,98 +201,8 @@ object ExampleCustomOp {

val dataName = Array("data")
val labelName = Array("softmax_label")
test(leop.dataPath, ctx)

val data = Symbol.Variable("data")
val label = Symbol.Variable("label")
val fc1 = Symbol.FullyConnected("fc1")()(Map("data" -> data, "num_hidden" -> 128))
val act1 = Symbol.Activation("relu1")()(Map("data" -> fc1, "act_type" -> "relu"))
val fc2 = Symbol.FullyConnected("fc2")()(Map("data" -> act1, "num_hidden" -> 64))
val act2 = Symbol.Activation("relu2")()(Map("data" -> fc2, "act_type" -> "relu"))
val fc3 = Symbol.FullyConnected("fc3")()(Map("data" -> act2, "num_hidden" -> 10))
val mlp = Symbol.Custom("softmax")()(Map("data" -> fc3,
"label" -> label, "op_type" -> "softmax"))

val (trainIter, testIter) =
Data.mnistIterator(leop.dataPath, batchSize = 100, inputShape = Shape(784))

val datasAndLabels = trainIter.provideData ++ trainIter.provideLabel
val (argShapes, outputShapes, auxShapes) = mlp.inferShape(datasAndLabels)

val initializer = new Xavier(factorType = "in", magnitude = 2.34f)
val argNames = mlp.listArguments()
val argDict = argNames.zip(argShapes.map(s => NDArray.empty(s, ctx))).toMap

val gradDict = argNames.zip(argShapes).filter { case (name, shape) =>
!datasAndLabels.contains(name)
}.map(x => x._1 -> NDArray.empty(x._2, ctx) ).toMap

argDict.foreach { case (name, ndArray) =>
if (!datasAndLabels.contains(name)) {
initializer.initWeight(name, ndArray)
}
}

val executor = mlp.bind(ctx, argDict, gradDict)
val lr = 0.001f
val opt = new RMSProp(learningRate = lr, wd = 0.00001f)
val paramsGrads = gradDict.toList.zipWithIndex.map { case ((name, grad), idx) =>
(idx, name, grad, opt.createState(idx, argDict(name)))
}

val evalMetric = new Accuracy
val batchEndCallback = new Speedometer(100, 100)
val numEpoch = 20

for (epoch <- 0 until numEpoch) {
val tic = System.currentTimeMillis
evalMetric.reset()
var nBatch = 0
var epochDone = false

trainIter.reset()
while (!epochDone) {
var doReset = true
while (doReset && trainIter.hasNext) {
val dataBatch = trainIter.next()
argDict("data").set(dataBatch.data(0))
argDict("label").set(dataBatch.label(0))
executor.forward(isTrain = true)
executor.backward()
paramsGrads.foreach { case (idx, name, grad, optimState) =>
opt.update(idx, argDict(name), grad, optimState)
}
evalMetric.update(dataBatch.label, executor.outputs)
nBatch += 1
batchEndCallback.invoke(epoch, nBatch, evalMetric)
}
if (doReset) {
trainIter.reset()
}
epochDone = true
}
val (name, value) = evalMetric.get
name.zip(value).foreach { case (n, v) =>
logger.info(s"Epoch[$epoch] Train-accuracy=$v")
}
val toc = System.currentTimeMillis
logger.info(s"Epoch[$epoch] Time cost=${toc - tic}")

evalMetric.reset()
testIter.reset()
while (testIter.hasNext) {
val evalBatch = testIter.next()
argDict("data").set(evalBatch.data(0))
argDict("label").set(evalBatch.label(0))
executor.forward(isTrain = true)
evalMetric.update(evalBatch.label, executor.outputs)
evalBatch.dispose()
}
val (names, values) = evalMetric.get
names.zip(values).foreach { case (n, v) =>
logger.info(s"Epoch[$epoch] Validation-accuracy=$v")
}
}
executor.dispose()
} catch {
case ex: Exception => {
logger.error(ex.getMessage, ex)
Expand Down
Loading

0 comments on commit 072dd5a

Please sign in to comment.