Skip to content

Commit

Permalink
Update store and load
Browse files Browse the repository at this point in the history
  • Loading branch information
harupy committed Feb 23, 2023
1 parent a1b12b6 commit 8c69ae0
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class XGBoostClassificationModelOp extends SimpleSparkOp[XGBoostClassificationMo
withValue("infer_batch_size", Value.int(obj.getOrDefault(obj.inferBatchSize))).
withValue("use_external_memory", Value.boolean(obj.getOrDefault(obj.useExternalMemory))).
withValue("allow_non_zero_for_missing", Value.boolean(obj.getOrDefault(obj.allowNonZeroForMissing)))
withValue("objective", Value.stirng(obj.getOrDefault(obj.objective)))
}

override def load(model: Model)
Expand All @@ -57,6 +58,7 @@ class XGBoostClassificationModelOp extends SimpleSparkOp[XGBoostClassificationMo
model.getValue("allow_non_zero_for_missing").map(o => xgb.setAllowNonZeroForMissing(o.getBoolean))
model.getValue("infer_batch_size").map(o => xgb.setInferBatchSize(o.getInt))
model.getValue("use_external_memory").map(o => xgb.set(xgb.useExternalMemory, o.getBoolean))
model.getValue("objective").map(o => xgb.setObjective(o.getString))
xgb
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class XGBoostClassificationModelParitySpec extends SparkParityBase {
)

// These params are not needed for making predictions, so we don't serialize them
override val unserializedParams = Set("labelCol", "evalMetric", "objective")
override val unserializedParams = Set("labelCol", "evalMetric")

override val excludedColsForComparison = Array[String]("prediction")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class XGBoostRegressionModelParitySpec extends SparkParityBase {
)

// These params are not needed for making predictions, so we don't serialize them
override val unserializedParams = Set("labelCol", "evalMetric", "objective")
override val unserializedParams = Set("labelCol", "evalMetric")

val dataset: DataFrame = {
import spark.sqlContext.implicits._
Expand Down

0 comments on commit 8c69ae0

Please sign in to comment.