From 034aaaa32cb59bbde8652800810d8b2a7b3b70ef Mon Sep 17 00:00:00 2001 From: TensorFlow Lattice Authors Date: Thu, 30 Jul 2020 13:46:26 -0700 Subject: [PATCH] Internal change PiperOrigin-RevId: 324074661 Change-Id: I98daeac3af4e94760aa940baf3df0dd221c4c591 --- build_docs.py | 1 + .../tutorials/aggregate_function_models.ipynb | 2 +- docs/tutorials/canned_estimators.ipynb | 87 +- docs/tutorials/custom_estimators.ipynb | 2 +- docs/tutorials/keras_layers.ipynb | 8 +- docs/tutorials/premade_models.ipynb | 132 ++- docs/tutorials/shape_constraints.ipynb | 4 +- .../shape_constraints_for_ethics.ipynb | 4 +- examples/canned_estimators_uci_heart.py | 4 +- examples/keras_sequential_uci_heart.py | 6 +- setup.py | 2 +- tensorflow_lattice/BUILD | 3 + tensorflow_lattice/__init__.py | 8 +- tensorflow_lattice/layers/__init__.py | 1 + tensorflow_lattice/python/BUILD | 393 +++++---- .../python/categorical_calibration_lib.py | 4 +- tensorflow_lattice/python/configs.py | 21 + tensorflow_lattice/python/estimators.py | 803 ++++++++++-------- tensorflow_lattice/python/estimators_test.py | 89 +- tensorflow_lattice/python/internal_utils.py | 15 +- .../python/internal_utils_test.py | 20 +- .../kronecker_factored_lattice_layer.py | 428 ++++++++++ .../python/kronecker_factored_lattice_lib.py | 409 +++++++++ .../python/kronecker_factored_lattice_test.py | 760 +++++++++++++++++ tensorflow_lattice/python/lattice_layer.py | 135 ++- tensorflow_lattice/python/lattice_lib.py | 162 +--- tensorflow_lattice/python/lattice_test.py | 81 +- tensorflow_lattice/python/linear_layer.py | 19 +- tensorflow_lattice/python/linear_lib.py | 195 ++--- tensorflow_lattice/python/linear_test.py | 3 +- tensorflow_lattice/python/premade.py | 97 +-- tensorflow_lattice/python/premade_lib.py | 313 ++++++- tensorflow_lattice/python/premade_test.py | 81 +- .../python/pwl_calibration_layer.py | 19 +- .../python/pwl_calibration_lib.py | 65 +- .../python/pwl_calibration_sonnet_module.py | 23 +- .../python/pwl_calibration_test.py | 49 +- tensorflow_lattice/python/rtl_layer.py | 91 +- tensorflow_lattice/python/rtl_lib.py | 90 ++ tensorflow_lattice/python/utils.py | 242 ++++++ tensorflow_lattice/python/utils_test.py | 177 ++++ tensorflow_lattice/python/visualization.py | 1 - 42 files changed, 3848 insertions(+), 1201 deletions(-) create mode 100644 tensorflow_lattice/python/kronecker_factored_lattice_layer.py create mode 100644 tensorflow_lattice/python/kronecker_factored_lattice_lib.py create mode 100644 tensorflow_lattice/python/kronecker_factored_lattice_test.py create mode 100644 tensorflow_lattice/python/rtl_lib.py create mode 100644 tensorflow_lattice/python/utils.py create mode 100644 tensorflow_lattice/python/utils_test.py diff --git a/build_docs.py b/build_docs.py index 2b544b6..4d18e66 100644 --- a/build_docs.py +++ b/build_docs.py @@ -64,6 +64,7 @@ def main(_): 'tfl': ['python'], 'tfl.aggregation_layer': ['Aggregation'], 'tfl.categorical_calibration_layer': ['CategoricalCalibration'], + 'tfl.kronecker_factored_lattice_layer': ['KroneckerFactoredLattice'], 'tfl.lattice_layer': ['Lattice'], 'tfl.linear_layer': ['Linear'], 'tfl.pwl_calibration_layer': ['PWLCalibration'], diff --git a/docs/tutorials/aggregate_function_models.ipynb b/docs/tutorials/aggregate_function_models.ipynb index dc5b2e9..3ece46b 100644 --- a/docs/tutorials/aggregate_function_models.ipynb +++ b/docs/tutorials/aggregate_function_models.ipynb @@ -476,7 +476,7 @@ "source": [ "## Aggregate Function Model\n", "\n", - "To construct a TFL premade model, first construct a model configuration from [tfl.configs](https://www.tensorflow.org/lattice/api_docs/python/tfl/configs). An aggregate function model is constructed using the [tfl.configs.AggregateFunctionConfig](https://www.tensorflow.org/lattice/api_docs/python/tfl/configs/AggregateFunctionConfig). It applies piecewise-linear and categorical calibration, followed by a lattice model on each dimension of the ragged input. It then applies an aggregation layer over the output for each dimension. This is then followed by an optional output piecewise-lienar calibration." + "To construct a TFL premade model, first construct a model configuration from [tfl.configs](https://www.tensorflow.org/lattice/api_docs/python/tfl/configs). An aggregate function model is constructed using the [tfl.configs.AggregateFunctionConfig](https://www.tensorflow.org/lattice/api_docs/python/tfl/configs/AggregateFunctionConfig). It applies piecewise-linear and categorical calibration, followed by a lattice model on each dimension of the ragged input. It then applies an aggregation layer over the output for each dimension. This is then followed by an optional output piecewise-linear calibration." ] }, { diff --git a/docs/tutorials/canned_estimators.ipynb b/docs/tutorials/canned_estimators.ipynb index 4d32b49..413fd03 100644 --- a/docs/tutorials/canned_estimators.ipynb +++ b/docs/tutorials/canned_estimators.ipynb @@ -12,7 +12,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "cellView": "form", "colab": {}, @@ -101,7 +101,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", @@ -125,7 +125,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "cellView": "both", "colab": {}, @@ -136,6 +136,7 @@ "source": [ "import tensorflow as tf\n", "\n", + "import copy\n", "import logging\n", "import numpy as np\n", "import pandas as pd\n", @@ -157,7 +158,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "cellView": "both", "colab": {}, @@ -190,7 +191,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "cellView": "both", "colab": {}, @@ -219,7 +220,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", @@ -285,7 +286,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", @@ -341,7 +342,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", @@ -450,7 +451,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", @@ -499,7 +500,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", @@ -559,7 +560,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", @@ -568,7 +569,7 @@ "outputs": [], "source": [ "# This is random lattice ensemble model with separate calibration:\n", - "# model output is the average output of separatly calibrated lattices.\n", + "# model output is the average output of separately calibrated lattices.\n", "model_config = tfl.configs.CalibratedLatticeEnsembleConfig(\n", " feature_configs=feature_configs,\n", " num_lattices=5,\n", @@ -589,6 +590,60 @@ "tfl.visualization.draw_model_graph(model_graph, calibrator_dpi=15)" ] }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "7uyO8s97FGJM" + }, + "source": [ + "### RTL Layer Random Lattice Ensemble\n", + "\n", + "The following model config uses a `tfl.layers.RTL` layer that uses a random subset of features for each lattice. We note that `tfl.layers.RTL` only supports monotonicity constraints and must have the same lattice size for all features and no per-feature regularization. Note that using a `tfl.layers.RTL` layer lets you scale to much larger ensembles than using separate `tfl.layers.Lattice` instances." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "8v7dKg-FF7iz" + }, + "outputs": [], + "source": [ + "# Make sure our feature configs have the same lattice size, no per-feature\n", + "# regularization, and only monotonicity constraints.\n", + "rtl_layer_feature_configs = copy.deepcopy(feature_configs)\n", + "for feature_config in rtl_layer_feature_configs:\n", + " feature_config.lattice_size = 2\n", + " feature_config.unimodality = 'none'\n", + " feature_config.reflects_trust_in = None\n", + " feature_config.dominates = None\n", + " feature_config.regularizer_configs = None\n", + "# This is RTL layer ensemble model with separate calibration:\n", + "# model output is the average output of separately calibrated lattices.\n", + "model_config = tfl.configs.CalibratedLatticeEnsembleConfig(\n", + " lattices='rtl_layer',\n", + " feature_configs=rtl_layer_feature_configs,\n", + " num_lattices=5,\n", + " lattice_rank=3)\n", + "# A CannedClassifier is constructed from the given model config.\n", + "estimator = tfl.estimators.CannedClassifier(\n", + " feature_columns=feature_columns,\n", + " model_config=model_config,\n", + " feature_analysis_input_fn=feature_analysis_input_fn,\n", + " optimizer=tf.keras.optimizers.Adam(LEARNING_RATE),\n", + " config=tf.estimator.RunConfig(tf_random_seed=42))\n", + "estimator.train(input_fn=train_input_fn)\n", + "results = estimator.evaluate(input_fn=test_input_fn)\n", + "print('Random ensemble test AUC: {}'.format(results['auc']))\n", + "saved_model_path = estimator.export_saved_model(estimator.model_dir,\n", + " serving_input_fn)\n", + "model_graph = tfl.estimators.get_model_graph(saved_model_path)\n", + "tfl.visualization.draw_model_graph(model_graph, calibrator_dpi=15)" + ] + }, { "cell_type": "markdown", "metadata": { @@ -605,7 +660,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", @@ -634,7 +689,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", @@ -643,7 +698,7 @@ "outputs": [], "source": [ "# This is Crystals ensemble model with separate calibration: model output is\n", - "# the average output of separatly calibrated lattices.\n", + "# the average output of separately calibrated lattices.\n", "model_config = tfl.configs.CalibratedLatticeEnsembleConfig(\n", " feature_configs=feature_configs,\n", " lattices='crystals',\n", @@ -680,7 +735,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", diff --git a/docs/tutorials/custom_estimators.ipynb b/docs/tutorials/custom_estimators.ipynb index 2c71272..4adb4e2 100644 --- a/docs/tutorials/custom_estimators.ipynb +++ b/docs/tutorials/custom_estimators.ipynb @@ -303,7 +303,7 @@ "\n", "There are several ways to create a custom estimator. Here we will construct a `model_fn` that calls a Keras model on the parsed input tensors. To parse the input features, you can use `tf.feature_column.input_layer`, `tf.keras.layers.DenseFeatures`, or `tfl.estimators.transform_features`. If you use the latter, you will not need to wrap categorical features with dense feature columns, and the resulting tensors will not be concatenated, which makes it easier to use the features in the calibration layers.\n", "\n", - "To construct a model, you can mix and match TFL layers or any other Keras layers. Here we create a calibrated lattice Keras model out of TFL layers and impose several monotonicity constraints. When then use the Keras model to create the custom estimator.\n" + "To construct a model, you can mix and match TFL layers or any other Keras layers. Here we create a calibrated lattice Keras model out of TFL layers and impose several monotonicity constraints. We then use the Keras model to create the custom estimator.\n" ] }, { diff --git a/docs/tutorials/keras_layers.ipynb b/docs/tutorials/keras_layers.ipynb index 81bfbad..f98fc00 100644 --- a/docs/tutorials/keras_layers.ipynb +++ b/docs/tutorials/keras_layers.ipynb @@ -237,7 +237,7 @@ "id": "W3DnEKWvQYXm" }, "source": [ - "We use a `tfl.layers.ParallelCombination` layer to group together calibration layers which have to be executed in paralel in order to be able to create a Sequential model.\n" + "We use a `tfl.layers.ParallelCombination` layer to group together calibration layers which have to be executed in parallel in order to be able to create a Sequential model.\n" ] }, { @@ -260,7 +260,7 @@ "id": "BPZsSUZiQiwc" }, "source": [ - "We create a calibration layer for each feature and add it to the parallel combination layer. For numeric features we use `tfl.layers.PWLCalibration` and for categorical features we use `tfl.layers.CategoricalCalibration`." + "We create a calibration layer for each feature and add it to the parallel combination layer. For numeric features we use `tfl.layers.PWLCalibration`, and for categorical features we use `tfl.layers.CategoricalCalibration`." ] }, { @@ -282,7 +282,7 @@ " training_data_df['age'].min(), training_data_df['age'].max(), num=5),\n", " # You need to ensure that input keypoints have same dtype as layer input.\n", " # You can do it by setting dtype here or by providing keypoints in such\n", - " # format which will be converted to deisred tf.dtype by default.\n", + " # format which will be converted to desired tf.dtype by default.\n", " dtype=tf.float32,\n", " # Output range must correspond to expected lattice input range.\n", " output_min=0.0,\n", @@ -542,7 +542,7 @@ " training_data_df['age'].min(), training_data_df['age'].max(), num=5),\n", " # You need to ensure that input keypoints have same dtype as layer input.\n", " # You can do it by setting dtype here or by providing keypoints in such\n", - " # format which will be converted to deisred tf.dtype by default.\n", + " # format which will be converted to desired tf.dtype by default.\n", " dtype=tf.float32,\n", " # Output range must correspond to expected lattice input range.\n", " output_min=0.0,\n", diff --git a/docs/tutorials/premade_models.ipynb b/docs/tutorials/premade_models.ipynb index db1e9ad..9ae3157 100644 --- a/docs/tutorials/premade_models.ipynb +++ b/docs/tutorials/premade_models.ipynb @@ -12,7 +12,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "cellView": "form", "colab": {}, @@ -93,7 +93,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", @@ -117,7 +117,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", @@ -127,6 +127,7 @@ "source": [ "import tensorflow as tf\n", "\n", + "import copy\n", "import logging\n", "import numpy as np\n", "import pandas as pd\n", @@ -147,7 +148,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", @@ -176,7 +177,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", @@ -214,7 +215,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", @@ -248,7 +249,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", @@ -262,7 +263,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", @@ -290,7 +291,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", @@ -332,7 +333,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", @@ -379,7 +380,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", @@ -509,7 +510,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", @@ -536,7 +537,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", @@ -576,7 +577,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", @@ -604,7 +605,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", @@ -632,7 +633,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", @@ -671,7 +672,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", @@ -715,7 +716,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", @@ -756,7 +757,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", @@ -788,7 +789,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", @@ -830,7 +831,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", @@ -848,6 +849,91 @@ "print(random_ensemble_model.evaluate(test_xs, test_ys))" ] }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "ZhJWe7fZIs4-" + }, + "source": [ + "### RTL Layer Random Lattice Ensemble\n", + "\n", + "When using a random lattice ensemble, you can specify that the model use a single `tfl.layers.RTL` layer. We note that `tfl.layers.RTL` only supports monotonicity constraints and must have the same lattice size for all features and no per-feature regularization. Note that using a `tfl.layers.RTL` layer lets you scale to much larger ensembles than using separate `tfl.layers.Lattice` instances.\n", + "\n", + "This example creates a calibrated lattice ensemble model with 5 lattices and 3 features per lattice." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "0PC9oRFYJMF_" + }, + "outputs": [], + "source": [ + "# Make sure our feature configs have the same lattice size, no per-feature\n", + "# regularization, and only monotonicity constraints.\n", + "rtl_layer_feature_configs = copy.deepcopy(feature_configs)\n", + "for feature_config in rtl_layer_feature_configs:\n", + " feature_config.lattice_size = 2\n", + " feature_config.unimodality = 'none'\n", + " feature_config.reflects_trust_in = None\n", + " feature_config.dominates = None\n", + " feature_config.regularizer_configs = None\n", + "# This is a calibrated lattice ensemble model: inputs are calibrated, then\n", + "# combined non-linearly and averaged using multiple lattice layers.\n", + "rtl_layer_ensemble_model_config = tfl.configs.CalibratedLatticeEnsembleConfig(\n", + " feature_configs=rtl_layer_feature_configs,\n", + " lattices='rtl_layer',\n", + " num_lattices=5,\n", + " lattice_rank=3,\n", + " output_min=min_label,\n", + " output_max=max_label - numerical_error_epsilon,\n", + " output_initialization=[min_label, max_label],\n", + " random_seed=42)\n", + "# A CalibratedLatticeEnsemble premade model constructed from the given\n", + "# model config. Note that we do not have to specify the lattices by calling\n", + "# a helper function (like before with random) because the RTL Layer will take\n", + "# care of that for us.\n", + "rtl_layer_ensemble_model = tfl.premade.CalibratedLatticeEnsemble(\n", + " rtl_layer_ensemble_model_config)\n", + "# Let's plot our model.\n", + "tf.keras.utils.plot_model(\n", + " rtl_layer_ensemble_model, show_layer_names=False, rankdir='LR')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "yWdxZpS0JWag" + }, + "source": [ + "As before, we compile, fit, and evaluate our model." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "HQdkkWwqJW8p" + }, + "outputs": [], + "source": [ + "rtl_layer_ensemble_model.compile(\n", + " loss=tf.keras.losses.BinaryCrossentropy(),\n", + " metrics=[tf.keras.metrics.AUC()],\n", + " optimizer=tf.keras.optimizers.Adam(LEARNING_RATE))\n", + "rtl_layer_ensemble_model.fit(\n", + " train_xs, train_ys, epochs=NUM_EPOCHS, batch_size=BATCH_SIZE, verbose=False)\n", + "print('Test Set Evaluation...')\n", + "print(rtl_layer_ensemble_model.evaluate(test_xs, test_ys))" + ] + }, { "cell_type": "markdown", "metadata": { @@ -866,7 +952,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", @@ -927,7 +1013,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", diff --git a/docs/tutorials/shape_constraints.ipynb b/docs/tutorials/shape_constraints.ipynb index f39bf90..6004f03 100644 --- a/docs/tutorials/shape_constraints.ipynb +++ b/docs/tutorials/shape_constraints.ipynb @@ -372,7 +372,7 @@ "\n", "In practice, users often do not go through all search results. This means that users will likely only see restaurants already considered \"good\" by the current ranking model in use. As a result, \"good\" restaurants are more frequently impressed and over-represented in the training datasets. When using more features, the training dataset can have large gaps in \"bad\" parts of the feature space.\n", "\n", - "When the model is used for ranking, it is often evaluated on all relevant results with a more uniform distribution that is not well-represented by the training dataset. A flexible and complicated model might fail in this case due to overfitting the over-represented data points and thus lacking generalizability. We handle this issue by applying domain knowledge to add *shape constraints* that guide the model to make reasonable predictions when it cannot pick them up from the training dataset.\n", + "When the model is used for ranking, it is often evaluated on all relevant results with a more uniform distribution that is not well-represented by the training dataset. A flexible and complicated model might fail in this case due to overfitting the over-represented data points and thus lack generalizability. We handle this issue by applying domain knowledge to add *shape constraints* that guide the model to make reasonable predictions when it cannot pick them up from the training dataset.\n", "\n", "In this example, the training dataset mostly consists of user interactions with good and popular restaurants. The testing dataset has a uniform distribution to simulate the evaluation setting discussed above. Note that such testing dataset will not be available in a real problem setting." ] @@ -1008,7 +1008,7 @@ "id": "HHpp4goLvuPi" }, "source": [ - "The calibrators are now smooth and the overall estimated CTR better matches the ground truth. This is reflected both in the testing metric and in the contour plots." + "The calibrators are now smooth, and the overall estimated CTR better matches the ground truth. This is reflected both in the testing metric and in the contour plots." ] }, { diff --git a/docs/tutorials/shape_constraints_for_ethics.ipynb b/docs/tutorials/shape_constraints_for_ethics.ipynb index d96f315..6ac3121 100644 --- a/docs/tutorials/shape_constraints_for_ethics.ipynb +++ b/docs/tutorials/shape_constraints_for_ethics.ipynb @@ -406,7 +406,7 @@ "bar.\n", "\n", "We will first train a calibrated linear model without any constraints. Then, we\n", - "will train a calibrated linear model with monotonicity constraints, and observe\n", + "will train a calibrated linear model with monotonicity constraints and observe\n", "the difference in the model output and accuracy." ] }, @@ -1064,7 +1064,7 @@ "person defaulted on a loan.\n", "\n", "We will first train a calibrated linear model without any constraints. Then, we\n", - "will train a calibrated linear model with monotonicity constraints, and observe\n", + "will train a calibrated linear model with monotonicity constraints and observe\n", "the difference in the model output and accuracy." ] }, diff --git a/examples/canned_estimators_uci_heart.py b/examples/canned_estimators_uci_heart.py index 51597c3..881b3e2 100644 --- a/examples/canned_estimators_uci_heart.py +++ b/examples/canned_estimators_uci_heart.py @@ -265,7 +265,7 @@ def main(_): estimator.export_saved_model(estimator.model_dir, serving_input_fn))) # This is random lattice ensemble model with separate calibration: - # model output is the average output of separatly calibrated lattices. + # model output is the average output of separately calibrated lattices. model_config = configs.CalibratedLatticeEnsembleConfig( feature_configs=feature_configs, num_lattices=6, @@ -290,7 +290,7 @@ def main(_): estimator.export_saved_model(estimator.model_dir, serving_input_fn))) # This is Crystals ensemble model with separate calibration: model output is - # the average output of separatly calibrated lattices. + # the average output of separately calibrated lattices. # Crystals algorithm first trains a prefitting model and uses the interactions # between features to form the final lattice ensemble. model_config = configs.CalibratedLatticeEnsembleConfig( diff --git a/examples/keras_sequential_uci_heart.py b/examples/keras_sequential_uci_heart.py index 3c721ec..80ce0c7 100644 --- a/examples/keras_sequential_uci_heart.py +++ b/examples/keras_sequential_uci_heart.py @@ -114,7 +114,7 @@ def main(_): lattice_sizes = [3, 2, 2, 2, 2, 2, 2] # Use ParallelCombination helper layer to group togehter calibration layers - # which have to be executed in paralel in order to be able to use Sequential + # which have to be executed in parallel in order to be able to use Sequential # model. Alternatively you can use functional API. combined_calibrators = tfl.layers.ParallelCombination() @@ -131,7 +131,7 @@ def main(_): num=5), # You need to ensure that input keypoints have same dtype as layer input. # You can do it by setting dtype here or by providing keypoints in such - # format which will be converted to deisred tf.dtype by default. + # format which will be converted to desired tf.dtype by default. dtype=tf.float32, # Output range must correspond to expected lattice input range. output_min=0.0, @@ -196,7 +196,7 @@ def main(_): # Monotonicity of calibrator can be 'decreasing'. Note that corresponding # lattice dimension must have 'increasing' monotonicity regardless of # monotonicity direction of calibrator. - # Its not some weird configuration hack. Its just how math works :) + # It's not some weird configuration hack. It's just how math works :) monotonicity='decreasing', # Convexity together with decreasing monotonicity result in diminishing # return constraint. diff --git a/setup.py b/setup.py index 597705e..96b9807 100644 --- a/setup.py +++ b/setup.py @@ -27,7 +27,7 @@ # This version number should always be that of the *next* (unreleased) version. # Immediately after uploading a package to PyPI, you should increment the # version number and push to gitHub. -__version__ = "2.0.5" +__version__ = "2.0.6" if "--release" in sys.argv: sys.argv.remove("--release") diff --git a/tensorflow_lattice/BUILD b/tensorflow_lattice/BUILD index 90b09f0..78d86f7 100644 --- a/tensorflow_lattice/BUILD +++ b/tensorflow_lattice/BUILD @@ -37,6 +37,8 @@ py_library( "//tensorflow_lattice/python:categorical_calibration_lib", "//tensorflow_lattice/python:configs", "//tensorflow_lattice/python:estimators", + "//tensorflow_lattice/python:kronecker_factored_lattice_layer", + "//tensorflow_lattice/python:kronecker_factored_lattice_lib", "//tensorflow_lattice/python:lattice_layer", "//tensorflow_lattice/python:lattice_lib", "//tensorflow_lattice/python:linear_layer", @@ -50,6 +52,7 @@ py_library( "//tensorflow_lattice/python:pwl_calibration_sonnet_module", "//tensorflow_lattice/python:rtl_layer", "//tensorflow_lattice/python:test_utils", + "//tensorflow_lattice/python:utils", "//tensorflow_lattice/python:visualization", ], ) diff --git a/tensorflow_lattice/__init__.py b/tensorflow_lattice/__init__.py index aef8487..a0470dd 100644 --- a/tensorflow_lattice/__init__.py +++ b/tensorflow_lattice/__init__.py @@ -20,13 +20,14 @@ from __future__ import absolute_import import tensorflow_lattice.layers -import tensorflow_lattice.sonnet_modules from tensorflow_lattice.python import aggregation_layer from tensorflow_lattice.python import categorical_calibration_layer from tensorflow_lattice.python import categorical_calibration_lib from tensorflow_lattice.python import configs from tensorflow_lattice.python import estimators +from tensorflow_lattice.python import kronecker_factored_lattice_layer +from tensorflow_lattice.python import kronecker_factored_lattice_lib from tensorflow_lattice.python import lattice_layer from tensorflow_lattice.python import lattice_lib from tensorflow_lattice.python import linear_layer @@ -36,7 +37,10 @@ from tensorflow_lattice.python import premade from tensorflow_lattice.python import premade_lib from tensorflow_lattice.python import pwl_calibration_layer -from tensorflow_lattice.python import pwl_calibration_sonnet_module from tensorflow_lattice.python import pwl_calibration_lib +from tensorflow_lattice.python import pwl_calibration_sonnet_module from tensorflow_lattice.python import test_utils +from tensorflow_lattice.python import utils from tensorflow_lattice.python import visualization + +import tensorflow_lattice.sonnet_modules diff --git a/tensorflow_lattice/layers/__init__.py b/tensorflow_lattice/layers/__init__.py index 8a20a70..8ae90fc 100644 --- a/tensorflow_lattice/layers/__init__.py +++ b/tensorflow_lattice/layers/__init__.py @@ -16,6 +16,7 @@ from tensorflow_lattice.python.aggregation_layer import Aggregation from tensorflow_lattice.python.categorical_calibration_layer import CategoricalCalibration +from tensorflow_lattice.python.kronecker_factored_lattice_layer import KroneckerFactoredLattice from tensorflow_lattice.python.lattice_layer import Lattice from tensorflow_lattice.python.linear_layer import Linear from tensorflow_lattice.python.parallel_combination_layer import ParallelCombination diff --git a/tensorflow_lattice/python/BUILD b/tensorflow_lattice/python/BUILD index a572b17..aefa309 100644 --- a/tensorflow_lattice/python/BUILD +++ b/tensorflow_lattice/python/BUILD @@ -21,126 +21,182 @@ package( licenses(["notice"]) +# Build rules are alphabetized. Please add new rules alphabetically +# to maintain the ordering. py_library( - name = "pwl_calibration_layer", - srcs = ["pwl_calibration_layer.py"], + name = "aggregation_layer", + srcs = ["aggregation_layer.py"], srcs_version = "PY2AND3", deps = [ - ":pwl_calibration_lib", - # absl/logging dep, - # tensorflow:tensorflow_no_contrib dep, + # tensorflow dep, + ], +) + +py_test( + name = "aggregation_test", + srcs = ["aggregation_test.py"], + python_version = "PY3", + srcs_version = "PY2AND3", + deps = [ + ":aggregation_layer", + # tensorflow dep, ], ) py_library( - name = "pwl_calibration_sonnet_module", - srcs = ["pwl_calibration_sonnet_module.py"], + name = "categorical_calibration_layer", + srcs = ["categorical_calibration_layer.py"], srcs_version = "PY2AND3", deps = [ - ":pwl_calibration_lib", - # absl/logging dep, - # sonnet dep, + ":categorical_calibration_lib", # tensorflow:tensorflow_no_contrib dep, ], ) py_library( - name = "pwl_calibration_lib", - srcs = ["pwl_calibration_lib.py"], + name = "categorical_calibration_lib", + srcs = ["categorical_calibration_lib.py"], srcs_version = "PY2AND3", deps = [ + ":internal_utils", # enum dep, # tensorflow:tensorflow_no_contrib dep, ], ) py_test( - name = "pwl_calibration_test", + name = "categorical_calibration_test", size = "large", - srcs = ["pwl_calibration_test.py"], + srcs = ["categorical_calibration_test.py"], python_version = "PY3", - # shard_count = 12, + # shard_count = 4, srcs_version = "PY2AND3", deps = [ + ":categorical_calibration_layer", ":parallel_combination_layer", - ":pwl_calibration_layer", - ":pwl_calibration_sonnet_module", ":test_utils", # absl/logging dep, # absl/testing:parameterized dep, # numpy dep, - # sonnet dep, # tensorflow dep, ], ) py_library( - name = "linear_layer", - srcs = ["linear_layer.py"], + name = "configs", + srcs = ["configs.py"], srcs_version = "PY2AND3", deps = [ - ":linear_lib", - # tensorflow:tensorflow_no_contrib dep, + # absl/logging dep, + # tensorflow dep, ], ) -py_library( - name = "linear_lib", - srcs = ["linear_lib.py"], +py_test( + name = "configs_test", + size = "small", + srcs = ["configs_test.py"], + python_version = "PY3", srcs_version = "PY2AND3", deps = [ - ":internal_utils", - # tensorflow:tensorflow_no_contrib dep, + ":categorical_calibration_layer", + ":configs", + ":lattice_layer", + ":linear_layer", + ":premade", + ":pwl_calibration_layer", + # absl/logging dep, + # tensorflow dep, ], ) py_library( - name = "categorical_calibration_layer", - srcs = ["categorical_calibration_layer.py"], + name = "estimators", + srcs = ["estimators.py"], srcs_version = "PY2AND3", deps = [ - ":categorical_calibration_lib", - # tensorflow:tensorflow_no_contrib dep, + ":categorical_calibration_layer", + ":configs", + ":lattice_layer", + ":linear_layer", + ":model_info", + ":premade", + ":premade_lib", + ":pwl_calibration_layer", + ":rtl_layer", + # absl/logging dep, + # tensorflow dep, + ], +) + +py_test( + name = "estimators_test", + size = "enormous", + srcs = ["estimators_test.py"], + python_version = "PY3", + # shard_count = 10, + srcs_version = "PY2AND3", + deps = [ + ":configs", + ":estimators", + ":model_info", + # absl/logging dep, + # sklearn dep, + # tensorflow dep, ], ) py_library( - name = "categorical_calibration_lib", - srcs = ["categorical_calibration_lib.py"], + name = "internal_utils", + srcs = ["internal_utils.py"], srcs_version = "PY2AND3", deps = [ - ":internal_utils", - # enum dep, - # tensorflow:tensorflow_no_contrib dep, + # tensorflow dep, ], ) py_test( - name = "categorical_calibration_test", - size = "large", - srcs = ["categorical_calibration_test.py"], + name = "internal_utils_test", + srcs = ["internal_utils_test.py"], python_version = "PY3", - # shard_count = 4, srcs_version = "PY2AND3", deps = [ - ":categorical_calibration_layer", - ":parallel_combination_layer", - ":test_utils", - # absl/logging dep, - # absl/testing:parameterized dep, + ":internal_utils", + # tensorflow dep, + ], +) + +py_library( + name = "kronecker_factored_lattice_layer", + srcs = ["kronecker_factored_lattice_layer.py"], + srcs_version = "PY2AND3", + deps = [ + ":kronecker_factored_lattice_lib", + ":utils", + # tensorflow dep, + ], +) + +py_library( + name = "kronecker_factored_lattice_lib", + srcs = ["kronecker_factored_lattice_lib.py"], + srcs_version = "PY2AND3", + deps = [ + ":utils", # numpy dep, # tensorflow dep, ], ) py_test( - name = "linear_test", + name = "kronecker_factored_lattice_test", size = "large", - srcs = ["linear_test.py"], + srcs = ["kronecker_factored_lattice_test.py"], python_version = "PY3", + # shard_count = 12, srcs_version = "PY2AND3", deps = [ - ":linear_layer", + ":kronecker_factored_lattice_layer", ":test_utils", # absl/logging dep, # absl/testing:parameterized dep, @@ -154,9 +210,9 @@ py_library( srcs = ["lattice_layer.py"], srcs_version = "PY2AND3", deps = [ - ":categorical_calibration_layer", ":lattice_lib", ":pwl_calibration_layer", + ":utils", # tensorflow:tensorflow_no_contrib dep, ], ) @@ -166,6 +222,7 @@ py_library( srcs = ["lattice_lib.py"], srcs_version = "PY2AND3", deps = [ + ":utils", # absl/logging dep, # numpy dep, # tensorflow:tensorflow_no_contrib dep, @@ -190,27 +247,37 @@ py_test( ) py_library( - name = "parallel_combination_layer", - srcs = ["parallel_combination_layer.py"], + name = "linear_layer", + srcs = ["linear_layer.py"], srcs_version = "PY2AND3", deps = [ - ":categorical_calibration_layer", - ":lattice_layer", - ":linear_layer", - ":pwl_calibration_layer", + ":linear_lib", + ":utils", + # tensorflow:tensorflow_no_contrib dep, + ], +) + +py_library( + name = "linear_lib", + srcs = ["linear_lib.py"], + srcs_version = "PY2AND3", + deps = [ + ":internal_utils", + ":utils", # tensorflow:tensorflow_no_contrib dep, ], ) py_test( - name = "parallel_combination_test", + name = "linear_test", size = "large", - srcs = ["parallel_combination_test.py"], + srcs = ["linear_test.py"], python_version = "PY3", srcs_version = "PY2AND3", deps = [ - ":lattice_layer", - ":parallel_combination_layer", + ":linear_layer", + ":test_utils", + ":utils", # absl/logging dep, # absl/testing:parameterized dep, # numpy dep, @@ -219,25 +286,34 @@ py_test( ) py_library( - name = "rtl_layer", - srcs = ["rtl_layer.py"], + name = "model_info", + srcs = ["model_info.py"], + srcs_version = "PY2AND3", + deps = [], +) + +py_library( + name = "parallel_combination_layer", + srcs = ["parallel_combination_layer.py"], srcs_version = "PY2AND3", deps = [ + ":categorical_calibration_layer", ":lattice_layer", + ":linear_layer", + ":pwl_calibration_layer", # tensorflow:tensorflow_no_contrib dep, ], ) py_test( - name = "rtl_test", + name = "parallel_combination_test", size = "large", - srcs = ["rtl_test.py"], + srcs = ["parallel_combination_test.py"], python_version = "PY3", srcs_version = "PY2AND3", deps = [ - ":linear_layer", - ":pwl_calibration_layer", - ":rtl_layer", + ":lattice_layer", + ":parallel_combination_layer", # absl/logging dep, # absl/testing:parameterized dep, # numpy dep, @@ -246,191 +322,198 @@ py_test( ) py_library( - name = "configs", - srcs = ["configs.py"], + name = "premade", + srcs = ["premade.py"], srcs_version = "PY2AND3", deps = [ + ":aggregation_layer", + ":categorical_calibration_layer", + ":configs", + ":lattice_layer", + ":parallel_combination_layer", + ":premade_lib", + ":pwl_calibration_layer", # absl/logging dep, # tensorflow dep, ], ) -py_test( - name = "configs_test", - size = "small", - srcs = ["configs_test.py"], - python_version = "PY3", +py_library( + name = "premade_lib", + srcs = ["premade_lib.py"], srcs_version = "PY2AND3", deps = [ + ":aggregation_layer", ":categorical_calibration_layer", ":configs", ":lattice_layer", + ":lattice_lib", ":linear_layer", - ":premade", ":pwl_calibration_layer", + ":rtl_layer", + ":utils", # absl/logging dep, - # tensorflow dep, - ], -) - -py_library( - name = "internal_utils", - srcs = ["internal_utils.py"], - srcs_version = "PY2AND3", - deps = [ + # enum dep, + # numpy dep, + # six dep, # tensorflow dep, ], ) py_test( - name = "internal_utils_test", - srcs = ["internal_utils_test.py"], + name = "premade_test", + size = "large", + srcs = ["premade_test.py"], python_version = "PY3", + # shard_count = 10, srcs_version = "PY2AND3", deps = [ - ":internal_utils", + ":configs", + ":premade", + ":premade_lib", + # absl/logging dep, + # numpy dep, # tensorflow dep, ], ) py_library( - name = "test_utils", - srcs = ["test_utils.py"], + name = "pwl_calibration_layer", + srcs = ["pwl_calibration_layer.py"], srcs_version = "PY2AND3", deps = [ - ":visualization", + ":pwl_calibration_lib", + ":utils", # absl/logging dep, - # numpy dep, - # sonnet dep, - # tensorflow dep, + # tensorflow:tensorflow_no_contrib dep, ], ) py_library( - name = "visualization", - srcs = ["visualization.py"], + name = "pwl_calibration_lib", + srcs = ["pwl_calibration_lib.py"], srcs_version = "PY2AND3", deps = [ - ":model_info", - # graphviz dep, - # matplotlib dep, - # mpl_toolkits/mplot3d dep, - # numpy dep, + ":utils", + # enum dep, + # tensorflow:tensorflow_no_contrib dep, ], ) py_library( - name = "estimators", - srcs = ["estimators.py"], + name = "pwl_calibration_sonnet_module", + srcs = ["pwl_calibration_sonnet_module.py"], srcs_version = "PY2AND3", deps = [ - ":categorical_calibration_layer", - ":configs", - ":lattice_layer", - ":linear_layer", - ":model_info", - ":premade", - ":premade_lib", - ":pwl_calibration_layer", + ":pwl_calibration_lib", + ":utils", # absl/logging dep, - # tensorflow dep, + # sonnet dep, + # tensorflow:tensorflow_no_contrib dep, ], ) py_test( - name = "estimators_test", - size = "enormous", - srcs = ["estimators_test.py"], + name = "pwl_calibration_test", + size = "large", + srcs = ["pwl_calibration_test.py"], python_version = "PY3", - # shard_count = 10, + # shard_count = 12, srcs_version = "PY2AND3", deps = [ - ":configs", - ":estimators", - ":model_info", + ":parallel_combination_layer", + ":pwl_calibration_layer", + ":pwl_calibration_sonnet_module", + ":test_utils", + ":utils", # absl/logging dep, - # sklearn dep, + # absl/testing:parameterized dep, + # numpy dep, + # sonnet dep, # tensorflow dep, ], ) py_library( - name = "premade", - srcs = ["premade.py"], + name = "rtl_layer", + srcs = ["rtl_layer.py"], srcs_version = "PY2AND3", deps = [ - ":aggregation_layer", - ":categorical_calibration_layer", - ":configs", ":lattice_layer", - ":parallel_combination_layer", - ":premade_lib", - ":pwl_calibration_layer", - # absl/logging dep, - # tensorflow dep, + ":rtl_lib", + # tensorflow:tensorflow_no_contrib dep, ], ) py_library( - name = "premade_lib", - srcs = ["premade_lib.py"], + name = "rtl_lib", + srcs = ["rtl_lib.py"], srcs_version = "PY2AND3", deps = [ - ":aggregation_layer", - ":categorical_calibration_layer", - ":configs", - ":lattice_layer", - ":lattice_lib", - ":linear_layer", - ":pwl_calibration_layer", - # absl/logging dep, - # enum dep, - # numpy dep, # six dep, - # tensorflow dep, ], ) py_test( - name = "premade_test", + name = "rtl_test", size = "large", - srcs = ["premade_test.py"], + srcs = ["rtl_test.py"], python_version = "PY3", - # shard_count = 10, srcs_version = "PY2AND3", deps = [ - ":configs", - ":premade", - ":premade_lib", + ":linear_layer", + ":pwl_calibration_layer", + ":rtl_layer", # absl/logging dep, + # absl/testing:parameterized dep, # numpy dep, # tensorflow dep, ], ) py_library( - name = "aggregation_layer", - srcs = ["aggregation_layer.py"], + name = "test_utils", + srcs = ["test_utils.py"], srcs_version = "PY2AND3", deps = [ + ":visualization", + # absl/logging dep, + # numpy dep, + # sonnet dep, # tensorflow dep, ], ) +py_library( + name = "utils", + srcs = ["utils.py"], + srcs_version = "PY2AND3", + deps = [ + # six dep, + ], +) + py_test( - name = "aggregation_test", - srcs = ["aggregation_test.py"], + name = "utils_test", + srcs = ["utils_test.py"], python_version = "PY3", srcs_version = "PY2AND3", deps = [ - ":aggregation_layer", + ":utils", + # absl/testing:parameterized dep, # tensorflow dep, ], ) py_library( - name = "model_info", - srcs = ["model_info.py"], + name = "visualization", + srcs = ["visualization.py"], srcs_version = "PY2AND3", - deps = [], + deps = [ + ":model_info", + # graphviz dep, + # matplotlib dep, + # mpl_toolkits/mplot3d dep, + # numpy dep, + ], ) diff --git a/tensorflow_lattice/python/categorical_calibration_lib.py b/tensorflow_lattice/python/categorical_calibration_lib.py index 613a0cc..2fcc40b 100644 --- a/tensorflow_lattice/python/categorical_calibration_lib.py +++ b/tensorflow_lattice/python/categorical_calibration_lib.py @@ -17,7 +17,7 @@ from __future__ import division from __future__ import print_function -from . import internal_utils as iu +from . import internal_utils import tensorflow as tf @@ -51,7 +51,7 @@ def project(weights, output_min, output_max, monotonicities): if monotonicities: projected_weights = ( - iu.approximately_project_categorical_partial_monotonicities( + internal_utils.approximately_project_categorical_partial_monotonicities( projected_weights, monotonicities)) if output_min is not None: diff --git a/tensorflow_lattice/python/configs.py b/tensorflow_lattice/python/configs.py index 0b1d6b4..a561079 100644 --- a/tensorflow_lattice/python/configs.py +++ b/tensorflow_lattice/python/configs.py @@ -219,6 +219,23 @@ class CalibratedLatticeEnsembleConfig(_Config, _HasFeatureConfigs, feature_analysis_input_fn=feature_analysis_input_fn) estimator.train(input_fn=train_input_fn) ``` + You can also construct a random ensemble (RTL) using a `tfl.layers.RTL` + layer so long as all features have the same lattice size: + ```python + model_config = tfl.configs.CalibratedLatticeEnsembleConfig( + lattices='rtl_layer', + num_lattices=6, # number of lattices + lattice_rank=5, # number of features in each lattice + feature_configs=[...], + ) + feature_analysis_input_fn = create_input_fn(num_epochs=1, ...) + train_input_fn = create_input_fn(num_epochs=100, ...) + estimator = tfl.estimators.CannedClassifier( + feature_columns=feature_columns, + model_config=model_config, + feature_analysis_input_fn=feature_analysis_input_fn) + estimator.train(input_fn=train_input_fn) + ``` To create a Crystals model, you will need to provide a *prefitting_input_fn* to the estimator constructor. This input_fn is used to train the prefitting model, as described above. The prefitting model does not need to be fully @@ -271,6 +288,10 @@ def __init__(self, lattices: Should be one of the following: - String `'random'` indicating that the features in each lattice should be selected randomly + - String `'rtl_layer'` indicating that the features in each lattice + should be selected randomly using a `tfl.layers.RTL` layer. Note that + using a `tfl.layers.RTL` layer scales better than using separate + `tfl.layers.Lattice` instances for the ensemble. - String `'crystals'` to use a heuristic to construct the lattice ensemble based on pairwise feature interactions - An explicit list of list of feature names to be used in each lattice diff --git a/tensorflow_lattice/python/estimators.py b/tensorflow_lattice/python/estimators.py index 646fe49..6fdcd2a 100644 --- a/tensorflow_lattice/python/estimators.py +++ b/tensorflow_lattice/python/estimators.py @@ -79,6 +79,7 @@ from . import premade from . import premade_lib from . import pwl_calibration_layer +from . import rtl_layer from absl import logging import numpy as np @@ -129,7 +130,7 @@ def _poll_for_file(filename): while not tf.io.gfile.exists(filename): time.sleep(_POLL_INTERVAL_SECS) if time.time() - start > _MAX_WAIT_TIME: - raise WaitTimeOutError('Waiting for file {} timed-out'.filename) + raise WaitTimeOutError('Waiting for file {} timed-out'.format(filename)) def transform_features(features, feature_columns=None): @@ -159,9 +160,17 @@ def transform_features(features, feature_columns=None): feature_column.name] = feature_column._transform_feature(features) elif (isinstance(feature_column, fc._CategoricalColumn) or isinstance(feature_column, fc2.CategoricalColumn)): - parsed_features[ - feature_column.name] = feature_column._transform_feature( - features).values + if feature_column.num_oov_buckets: + # If oov buckets are used, missing values are assigned to the last + # oov bucket. + default_value = feature_column.num_buckets - 1 + else: + default_value = feature_column.default_value + parsed_features[feature_column.name] = tf.reshape( + tf.sparse.to_dense( + sp_input=feature_column._transform_feature(features), + default_value=default_value), + shape=[-1, 1]) else: raise ValueError( 'Unsupported feature_column: {}'.format(feature_column)) @@ -476,6 +485,9 @@ def _finalize_model_structure(model_config, label_dimension, feature_columns, _ENSEMBLE_STRUCTURE_FILE) if ((config is None or config.is_chief) and not tf.io.gfile.exists(ensemble_structure_filename)): + if model_config.lattices not in ['random', 'crystals', 'rtl_layer']: + raise ValueError('Unsupported ensemble structure: {}'.format( + model_config.lattices)) if model_config.lattices == 'random': premade_lib.set_random_lattice_ensemble(model_config, feature_names) elif model_config.lattices == 'crystals': @@ -490,10 +502,10 @@ def _finalize_model_structure(model_config, label_dimension, feature_columns, prefitting_steps=prefitting_steps, config=config, dtype=dtype) - else: - raise ValueError('Unsupported ensemble structure: {}'.format( - model_config.lattices)) - if model_config.fix_ensemble_for_2d_constraints: + if (model_config.fix_ensemble_for_2d_constraints and + model_config.lattices != 'rtl_layer'): + # Note that we currently only support monotonicity and bound constraints + # for RTL. _fix_ensemble_for_2d_constraints(model_config, feature_names) # Save lattices to file as the chief worker. @@ -552,6 +564,8 @@ def _update_by_feature_columns(model_config, feature_columns): feature_config.vocabulary_list = feature_column.vocabulary_list feature_config.num_buckets = feature_column.num_buckets if feature_column.num_oov_buckets: + # A positive num_oov_buckets can not be specified with default_value. + # See tf.feature_column.categorical_column_with_vocabulary_list. feature_config.default_value = None else: # We add a bucket at the end for the default_value, since num_buckets @@ -1157,382 +1171,501 @@ def __init__(self, def _match_op(ops, regex): """Returns ops that match given regex along with the matched sections.""" matches = [] + prog = re.compile(regex) for op in ops: - op_matches = re.findall(regex, op) + op_matches = prog.findall(op) if op_matches: matches.append((op, op_matches[0])) return matches -def get_model_graph(saved_model_path, tag='serve'): - """Returns all layers and parameters used in a saved model as a graph. - - The returned graph is not a TF graph, rather a graph of python object that - encodes the model structure and includes trained model parameters. The graph - can be used by the `tfl.visualization` module for plotting and other - visualization and analysis. +def _create_feature_nodes(sess, ops, graph): + """Returns a map from feature name to InputFeatureNode.""" + # Extract list of features from the graph. + # {FEATURES_SCOPE}/{feature_name} + feature_nodes = {} + feature_op_re = '{}/(.*)'.format(FEATURES_SCOPE) + for (_, feature_name) in _match_op(ops, feature_op_re): + category_table_op = 'transform/{}_lookup/Const'.format(feature_name) + if category_table_op in ops: + is_categorical = True + vocabulary_list = sess.run( + graph.get_operation_by_name(category_table_op).outputs[0]) + # Replace byte types with their string values. + vocabulary_list = [ + str(x.decode()) if isinstance(x, bytes) else str(x) + for x in vocabulary_list + ] + else: + is_categorical = False + vocabulary_list = None + + feature_node = model_info.InputFeatureNode( + name=feature_name, + is_categorical=is_categorical, + vocabulary_list=vocabulary_list) + feature_nodes[feature_name] = feature_node + return feature_nodes + + +def _create_categorical_calibration_nodes(sess, ops, graph, feature_nodes): + """Returns a map from feature_name to list of CategoricalCalibrationNode.""" + categorical_calibration_nodes = collections.defaultdict(list) + # Get calibrator output values. We need to call the read variable op. + # {CALIB_LAYER_NAME}_{feature_name}/ + # {CATEGORICAL_CALIBRATION_KERNEL_NAME}/Read/ReadVariableOp + kernel_op_re = '^{}_(.*)/{}/Read/ReadVariableOp$'.format( + premade_lib.CALIB_LAYER_NAME, + categorical_calibration_layer.CATEGORICAL_CALIBRATION_KERNEL_NAME, + ) + for kernel_op, feature_name in _match_op(ops, kernel_op_re): + output_values = sess.run(graph.get_operation_by_name(kernel_op).outputs[0]) + + # Get default input value if defined. + # {CALIB_LAYER_NAME}_{feature_name}/ + # {DEFAULT_INPUT_VALUE_NAME} + default_input_value_op = '^{}_{}/{}$'.format( + premade_lib.CALIB_LAYER_NAME, + feature_name, + categorical_calibration_layer.DEFAULT_INPUT_VALUE_NAME, + ) + if default_input_value_op in ops: + default_input = sess.run( + graph.get_operation_by_name(default_input_value_op).outputs[0]) + else: + default_input = None + + # Create one calibration node per output dimension of the calibrator. + for calibration_output_idx in range(output_values.shape[1]): + categorical_calibration_node = model_info.CategoricalCalibrationNode( + input_node=feature_nodes[feature_name], + output_values=output_values[:, calibration_output_idx], + default_input=default_input) + categorical_calibration_nodes[feature_name].append( + categorical_calibration_node) + return categorical_calibration_nodes + + +def _create_pwl_calibration_nodes(sess, ops, graph, feature_nodes): + """Returns a map from feature_name to list of PWLCalibrationNode.""" + pwl_calibration_nodes = collections.defaultdict(list) + # Calculate input keypoints. + # We extract lengh (deltas between keypoints) and kernel interpolation + # keypoints (which does not include the last keypoint), and then + # construct the full keypoints list using both. + + # Lengths (deltas between keypoints). + # {CALIB_LAYER_NAME}_{feature_name}/{LENGTHS_NAME} + lengths_op_re = '^{}_(.*)/{}$'.format( + premade_lib.CALIB_LAYER_NAME, + pwl_calibration_layer.LENGTHS_NAME, + ) + for lengths_op, feature_name in _match_op(ops, lengths_op_re): + # Interpolation keypoints does not inlcude the last input keypoint. + # {CALIB_LAYER_NAME}_{feature_name}/{INTERPOLATION_KEYPOINTS_NAME} + keypoints_op = '{}_{}/{}'.format( + premade_lib.CALIB_LAYER_NAME, + feature_name, + pwl_calibration_layer.INTERPOLATION_KEYPOINTS_NAME, + ) - Example: + # Output keypoints. We need to call the varible read op. + # {CALIB_LAYER_NAME}_{feature_name}/{PWL_CALIBRATION_KERNEL_NAME} + kernel_op = '{}_{}/{}/Read/ReadVariableOp'.format( + premade_lib.CALIB_LAYER_NAME, + feature_name, + pwl_calibration_layer.PWL_CALIBRATION_KERNEL_NAME, + ) - ```python - model_graph = estimators.get_model_graph(saved_model_path) - visualization.plot_feature_calibrator(model_graph, "feature_name") - visualization.plot_all_calibrators(model_graph) - visualization.draw_model_graph(model_graph) - ``` + (lengths, keypoints, kernel) = sess.run( + (graph.get_operation_by_name(lengths_op).outputs[0], + graph.get_operation_by_name(keypoints_op).outputs[0], + graph.get_operation_by_name(kernel_op).outputs[0])) + output_keypoints = np.cumsum(kernel, axis=0) - Args: - saved_model_path: Path to the saved model. - tag: Saved model tag for loading. + # Add the last keypoint to the keypoint list. + # TODO: handle cyclic PWL layers. + input_keypoints = np.append(keypoints, keypoints[-1] + lengths[-1]) - Returns: - A `model_info.ModelGraph` object that includes the model graph. - """ - # List of all the nodes in the model. - nodes = [] + # Get missing/default input value if present: + # {CALIB_LAYER_NAME}_{feature_name}/{MISSING_INPUT_VALUE_NAME} + default_input_value_op = '{}_{}/{}'.format( + premade_lib.CALIB_LAYER_NAME, + feature_name, + pwl_calibration_layer.MISSING_INPUT_VALUE_NAME, + ) + if default_input_value_op in ops: + default_input = sess.run( + graph.get_operation_by_name(default_input_value_op).outputs[0])[0] + else: + default_input = None - # Dict from feature name to corresponding InputFeatureNode object. - feature_nodes = {} + # Find corresponding default/missing output if present. + # {CALIB_LAYER_NAME}_{feature_name}/{PWL_CALIBRATION_MISSING_OUTPUT_NAME} + default_output_op = '{}_{}/{}/Read/ReadVariableOp'.format( + premade_lib.CALIB_LAYER_NAME, + feature_name, + pwl_calibration_layer.PWL_CALIBRATION_MISSING_OUTPUT_NAME, + ) + if default_output_op in ops: + default_output = sess.run( + graph.get_operation_by_name(default_output_op).outputs[0]) + else: + default_output = None - # Dict from submodel index to a list of calibrated inputs for the submodel. - submodel_input_nodes = collections.defaultdict(list) + # Create one calibration node per output dimension of the calibrator. + for calibration_output_idx in range(output_keypoints.shape[1]): + pwl_calibration_node = model_info.PWLCalibrationNode( + input_node=feature_nodes[feature_name], + input_keypoints=input_keypoints, + output_keypoints=output_keypoints[:, calibration_output_idx], + default_input=default_input, + default_output=(None if default_output is None else + default_output[:, calibration_output_idx])) + pwl_calibration_nodes[feature_name].append(pwl_calibration_node) + return pwl_calibration_nodes - # Dict from submodel index to the output node of the submodel. - submodel_output_nodes = {} - tf.compat.v1.reset_default_graph() - with tf.compat.v1.Session() as sess: - tf.compat.v1.saved_model.loader.load(sess, [tag], saved_model_path) - g = tf.compat.v1.get_default_graph() - ops = [op.name for op in g.get_operations()] - - ############################# - # Create input feature nodes. - ############################# - - # Extract list of features from the graph. - # {FEATURES_SCOPE}/{feature_name} - feature_op_re = '{}/(.*)'.format(FEATURES_SCOPE) - for (_, feature_name) in _match_op(ops, feature_op_re): - category_table_op = 'transform/{}_lookup/Const'.format(feature_name) - if category_table_op in ops: - is_categorical = True - vocabulary_list = sess.run( - g.get_operation_by_name(category_table_op).outputs[0]) - # Replace byte types with their string values. - vocabulary_list = [ - str(x.decode()) if isinstance(x, bytes) else str(x) - for x in vocabulary_list - ] - else: - is_categorical = False - vocabulary_list = None +def _create_submodel_input_map(ops, calibration_nodes_map): + """Returns a map from submodel_idx to a list of calibration nodes.""" + submodel_input_nodes = collections.defaultdict(list) + for feature_name, calibration_nodes in calibration_nodes_map.items(): + # Identity passthrough ops that pass this calibration to each submodel. + # {CALIB_PASSTHROUGH_NAME}_{feature_name}_ + # {calibration_output_idx}_{submodel_idx}_{submodel_input_idx} + shared_calib_passthrough_op_re = r'^{}_{}_(\d*)_(\d*)_(\d*)$'.format( + premade_lib.CALIB_PASSTHROUGH_NAME, feature_name) + for _, (calibration_output_idx, submodel_idx, + submodel_input_idx) in _match_op(ops, + shared_calib_passthrough_op_re): + submodel_input_nodes[submodel_idx].append( + (submodel_input_idx, calibration_nodes[int(calibration_output_idx)])) + return submodel_input_nodes + + +def _create_linear_nodes(sess, ops, graph, submodel_input_nodes): + """Returns a map from submodel_idx to LinearNode.""" + linear_nodes = {} + # Linear coefficients. + # {LINEAR_LAYER_NAME}_{submodel_idx}/{LINEAR_LAYER_KERNEL_NAME} + linear_kernel_op_re = '^{}_(.*)/{}/Read/ReadVariableOp$'.format( + premade_lib.LINEAR_LAYER_NAME, + linear_layer.LINEAR_LAYER_KERNEL_NAME, + ) + for linear_kernel_op, submodel_idx in _match_op(ops, linear_kernel_op_re): + coefficients = sess.run( + graph.get_operation_by_name(linear_kernel_op).outputs[0]).flatten() + + # Bias term. + # {LINEAR_LAYER_NAME}_{submodel_idx}/{LINEAR_LAYER_BIAS_NAME} + bias_op = '{}_{}/{}/Read/ReadVariableOp'.format( + premade_lib.LINEAR_LAYER_NAME, + submodel_idx, + linear_layer.LINEAR_LAYER_BIAS_NAME, + ) + if bias_op in ops: + bias = sess.run(graph.get_operation_by_name(bias_op).outputs[0]) + else: + bias = 0.0 - feature_node = model_info.InputFeatureNode( - name=feature_name, - is_categorical=is_categorical, - vocabulary_list=vocabulary_list) - feature_nodes[feature_name] = feature_node - nodes.append(feature_node) + # Sort input nodes by input index. + input_nodes = [ + node for _, node in sorted(submodel_input_nodes[submodel_idx]) + ] - ####################################### - # Create categorical calibration nodes. - ####################################### + linear_node = model_info.LinearNode( + input_nodes=input_nodes, coefficients=coefficients, bias=bias) + linear_nodes[submodel_idx] = linear_node + return linear_nodes + + +def _create_lattice_nodes(sess, ops, graph, submodel_input_nodes): + """Returns a map from submodel_idx to LatticeNode.""" + lattice_nodes = {} + # Lattice weights. + # {LATTICE_LAYER_NAME}_{submodel_idx}/{LATTICE_KERNEL_NAME} + lattice_kernel_op_re = '^{}_(.*)/{}/Read/ReadVariableOp$'.format( + premade_lib.LATTICE_LAYER_NAME, + lattice_layer.LATTICE_KERNEL_NAME, + ) + for lattice_kernel_op, submodel_idx in _match_op(ops, lattice_kernel_op_re): + lattice_kernel = sess.run( + graph.get_operation_by_name(lattice_kernel_op).outputs[0]).flatten() + + # Lattice sizes. + # {Lattice_LAYER_NAME}_{submodel_idx}/{LATTICE_SIZES_NAME} + lattice_sizes_op_name = '{}_{}/{}'.format(premade_lib.LATTICE_LAYER_NAME, + submodel_idx, + lattice_layer.LATTICE_SIZES_NAME) + lattice_sizes = sess.run( + graph.get_operation_by_name( + lattice_sizes_op_name).outputs[0]).flatten() + + # Shape the flat lattice parameters based on the calculated lattice sizes. + weights = np.reshape(lattice_kernel, lattice_sizes) + + # Sort input nodes by input index. + input_nodes = [ + node for _, node in sorted(submodel_input_nodes[submodel_idx]) + ] - # Get calibrator output values. We need to call the read variable op. - # {CALIB_LAYER_NAME}_{feature_name}/ - # {CATEGORICAL_CALIBRATION_KERNEL_NAME}/Read/ReadVariableOp - kernel_op_re = '^{}_(.*)/{}/Read/ReadVariableOp$'.format( - premade_lib.CALIB_LAYER_NAME, - categorical_calibration_layer.CATEGORICAL_CALIBRATION_KERNEL_NAME, - ) - for kernel_op, feature_name in _match_op(ops, kernel_op_re): - output_values = sess.run(g.get_operation_by_name(kernel_op).outputs[0]) - - # Get default input value if defined. - # {CALIB_LAYER_NAME}_{feature_name}/ - # {DEFAULT_INPUT_VALUE_NAME} - default_input_value_op = '^{}_{}/{}$'.format( - premade_lib.CALIB_LAYER_NAME, - feature_name, - categorical_calibration_layer.DEFAULT_INPUT_VALUE_NAME, - ) - if default_input_value_op in ops: - default_input = sess.run( - g.get_operation_by_name(default_input_value_op).outputs[0]) - else: - default_input = None - - # Create one calibration node per output dimension of the calibrator. - categorical_calibration_nodes = [] - for calibration_output_idx in range(output_values.shape[1]): - categorical_calibration_node = model_info.CategoricalCalibrationNode( - input_node=feature_nodes[feature_name], - output_values=output_values[:, calibration_output_idx], - default_input=default_input) - categorical_calibration_nodes.append(categorical_calibration_node) - nodes.append(categorical_calibration_node) - - # Identity passthrough ops that pass this calibration to each submodel. - # {CALIB_PASSTHROUGH_NAME}_{feature_name}_ - # {calibration_output_idx}_{submodel_idx}_{submodel_input_idx} - shared_calib_passthrough_op_re = r'^{}_{}_(\d*)_(\d*)_(\d*)$'.format( - premade_lib.CALIB_PASSTHROUGH_NAME, feature_name) - for op, (calibration_output_idx, submodel_idx, - submodel_input_idx) in _match_op(ops, - shared_calib_passthrough_op_re): - submodel_input_nodes[submodel_idx].append( - (submodel_input_idx, - categorical_calibration_nodes[int(calibration_output_idx)])) - - ############################### - # Create PWL calibration nodes. - ############################### - - # Calculate input keypoints. - # We extract lengh (deltas between keypoints) and kernel interpolation - # keypoints (which does not include the last keypoint), and then - # construct the full keypoints list using both. - - # Lengths (deltas between keypoints). - # {CALIB_LAYER_NAME}_{feature_name}/{LENGTHS_NAME} - lengths_op_re = '^{}_(.*)/{}$'.format( - premade_lib.CALIB_LAYER_NAME, - pwl_calibration_layer.LENGTHS_NAME, + lattice_node = model_info.LatticeNode( + input_nodes=input_nodes, weights=weights) + lattice_nodes[submodel_idx] = lattice_node + return lattice_nodes + + +def _create_rtl_lattice_nodes(sess, ops, graph, calibration_nodes_map): + """Returns a map from lattice_submodel_index to LatticeNode.""" + lattice_nodes = {} + lattice_submodel_index = 0 + # Feature name in concat op. + # {RTL_INPUT_NAME}_{feature_name}:0 + feature_name_prog = re.compile('^{}_(.*):0$'.format( + premade_lib.RTL_INPUT_NAME)) + # RTL Layer identified by single concat op per submodel. + # {RTL_LAYER_NAME}_{submodel_idx}/RTL_CONCAT_NAME + rtl_layer_concat_op_re = '^{}_(.*)/{}$'.format(premade_lib.RTL_LAYER_NAME, + rtl_layer.RTL_CONCAT_NAME) + for concat_op_name, submodel_idx in _match_op(ops, rtl_layer_concat_op_re): + # First we reconstruct the flattened calibration outputs for this submodel. + concat_op = graph.get_operation_by_name(concat_op_name) + input_names = [input_tensor.name for input_tensor in concat_op.inputs] + names_in_flattened_order = [] + for input_name in input_names: + match = feature_name_prog.match(input_name) + if match: + names_in_flattened_order.append(match.group(1)) + flattened_calibration_nodes = [] + for feature_name in names_in_flattened_order: + flattened_calibration_nodes.extend(calibration_nodes_map[feature_name]) + + # Lattice kernel weights. + # {RTL_LAYER_NAME}_{submodel_idx}/ + # {RTL_LATTICE_NAME}_{monotonicities}/{LATTICE_KERNEL_NAME} + lattice_kernel_op_re = '^{}_{}/{}_(.*)/{}/Read/ReadVariableOp$'.format( + premade_lib.RTL_LAYER_NAME, + submodel_idx, + rtl_layer.RTL_LATTICE_NAME, + lattice_layer.LATTICE_KERNEL_NAME, ) - for lengths_op, feature_name in _match_op(ops, lengths_op_re): - # Interpolation keypoints does not inlcude the last input keypoint. - # {CALIB_LAYER_NAME}_{feature_name}/{INTERPOLATION_KEYPOINTS_NAME} - keypoints_op = '{}_{}/{}'.format( - premade_lib.CALIB_LAYER_NAME, - feature_name, - pwl_calibration_layer.INTERPOLATION_KEYPOINTS_NAME, - ) + for lattice_kernel_op, monotonicities in _match_op(ops, + lattice_kernel_op_re): + # Lattice kernel weights. + lattice_kernel = sess.run( + graph.get_operation_by_name(lattice_kernel_op).outputs[0]) - # Output keypoints. We need to call the varible read op. - # {CALIB_LAYER_NAME}_{feature_name}/{PWL_CALIBRATION_KERNEL_NAME} - kernel_op = '{}_{}/{}/Read/ReadVariableOp'.format( - premade_lib.CALIB_LAYER_NAME, - feature_name, - pwl_calibration_layer.PWL_CALIBRATION_KERNEL_NAME, - ) + # Lattice sizes. + # {RTL_LAYER_NAME}_{submodel_idx}/ + # {RTL_LATTICE_NAME}_{monotonicities}/{LATTICE_SIZES_NAME} + lattice_sizes_op_name = '{}_{}/{}_{}/{}'.format( + premade_lib.RTL_LAYER_NAME, submodel_idx, rtl_layer.RTL_LATTICE_NAME, + monotonicities, lattice_layer.LATTICE_SIZES_NAME) - (lengths, keypoints, kernel) = sess.run( - (g.get_operation_by_name(lengths_op).outputs[0], - g.get_operation_by_name(keypoints_op).outputs[0], - g.get_operation_by_name(kernel_op).outputs[0])) - output_keypoints = np.cumsum(kernel, axis=0) - - # Add the last keypoint to the keypoint list. - # TODO: handle cyclic PWL layers. - input_keypoints = np.append(keypoints, keypoints[-1] + lengths[-1]) - - # Get missing/default input value if present: - # {CALIB_LAYER_NAME}_{feature_name}/{MISSING_INPUT_VALUE_NAME} - default_input_value_op = '{}_{}/{}'.format( - premade_lib.CALIB_LAYER_NAME, - feature_name, - pwl_calibration_layer.MISSING_INPUT_VALUE_NAME, - ) - if default_input_value_op in ops: - default_input = sess.run( - g.get_operation_by_name(default_input_value_op).outputs[0])[0] - else: - default_input = None - - # Find corresponding default/missing output if present. - # {CALIB_LAYER_NAME}_{feature_name}/{PWL_CALIBRATION_MISSING_OUTPUT_NAME} - default_output_op = '{}_{}/{}/Read/ReadVariableOp'.format( - premade_lib.CALIB_LAYER_NAME, - feature_name, - pwl_calibration_layer.PWL_CALIBRATION_MISSING_OUTPUT_NAME, - ) - if default_output_op in ops: - default_output = sess.run( - g.get_operation_by_name(default_output_op).outputs[0]) - else: - default_output = None - - # Create one calibration node per output dimension of the calibrator. - pwl_calibration_nodes = [] - for calibration_output_idx in range(output_keypoints.shape[1]): - pwl_calibration_node = model_info.PWLCalibrationNode( - input_node=feature_nodes[feature_name], - input_keypoints=input_keypoints, - output_keypoints=output_keypoints[:, calibration_output_idx], - default_input=default_input, - default_output=(None if default_output is None else - default_output[:, calibration_output_idx])) - pwl_calibration_nodes.append(pwl_calibration_node) - nodes.append(pwl_calibration_node) - - # Identity passthrough ops that pass this calibration to each submodel. - # {CALIB_PASSTHROUGH_NAME}_{feature_name}_ - # {calibration_output_idx}_{submodel_idx}_{submodel_input_idx} - shared_calib_passthrough_op_re = r'^{}_{}_(\d*)_(\d*)_(\d*)$'.format( - premade_lib.CALIB_PASSTHROUGH_NAME, feature_name) - for op, (calibration_output_idx, submodel_idx, - submodel_input_idx) in _match_op(ops, - shared_calib_passthrough_op_re): - submodel_input_nodes[submodel_idx].append( - (submodel_input_idx, - pwl_calibration_nodes[int(calibration_output_idx)])) - - ###################### - # Create linear nodes. - ###################### + lattice_sizes = sess.run( + graph.get_operation_by_name( + lattice_sizes_op_name).outputs[0]).flatten() + + # inputs_for_units + # {RTL_LAYER_NAME}_{submodel_index}/ + # {INPUTS_FOR_UNITS_PREFIX}_{monotonicities} + inputs_for_units_op_name = '{}_{}/{}_{}'.format( + premade_lib.RTL_LAYER_NAME, submodel_idx, + rtl_layer.INPUTS_FOR_UNITS_PREFIX, monotonicities) + + inputs_for_units = sess.run( + graph.get_operation_by_name(inputs_for_units_op_name).outputs[0]) + + # Make a unique lattice for each unit. + units = inputs_for_units.shape[0] + for i in range(units): + # Shape the flat lattice parameters based on the calculated lattice + # sizes. + weights = np.reshape(lattice_kernel[:, i], lattice_sizes) + + # Gather input nodes for lattice node. + indices = inputs_for_units[i] + input_nodes = [flattened_calibration_nodes[index] for index in indices] + + lattice_node = model_info.LatticeNode( + input_nodes=input_nodes, weights=weights) + lattice_nodes[lattice_submodel_index] = lattice_node + lattice_submodel_index += 1 + return lattice_nodes + + +def _create_output_combination_node(sess, ops, graph, submodel_output_nodes): + """Returns None, a LinearNode, or a MeanNode.""" + output_combination_node = None + # Mean node is only added for ensemble models. + if len(submodel_output_nodes) > 1: + input_nodes = [ + submodel_output_nodes[idx] + for idx in sorted(submodel_output_nodes.keys(), key=int) + ] # Linear coefficients. - # {LINEAR_LAYER_NAME}_{submodel_idx}/{LINEAR_LAYER_KERNEL_NAME} - linear_kernel_op_re = '^{}_(.*)/{}/Read/ReadVariableOp$'.format( - premade_lib.LINEAR_LAYER_NAME, + # {LINEAR_LAYER_COMBINATION_NAME}/{LINEAR_LAYER_KERNEL_NAME} + linear_combination_kernel_op = '{}/{}/Read/ReadVariableOp'.format( + premade_lib.OUTPUT_LINEAR_COMBINATION_LAYER_NAME, linear_layer.LINEAR_LAYER_KERNEL_NAME, ) - for linear_kernel_op, submodel_idx in _match_op(ops, linear_kernel_op_re): + if linear_combination_kernel_op in ops: coefficients = sess.run( - g.get_operation_by_name(linear_kernel_op).outputs[0]).flatten() + graph.get_operation_by_name( + linear_combination_kernel_op).outputs[0]).flatten() # Bias term. - # {LINEAR_LAYER_NAME}_{submodel_idx}/{LINEAR_LAYER_BIAS_NAME} - bias_op = '{}_{}/{}/Read/ReadVariableOp'.format( - premade_lib.LINEAR_LAYER_NAME, - submodel_idx, + # {OUTPUT_LINEAR_COMBINATION_LAYER_NAME}/{LINEAR_LAYER_BIAS_NAME} + bias_op = '{}/{}/Read/ReadVariableOp'.format( + premade_lib.OUTPUT_LINEAR_COMBINATION_LAYER_NAME, linear_layer.LINEAR_LAYER_BIAS_NAME, ) if bias_op in ops: - bias = sess.run(g.get_operation_by_name(bias_op).outputs[0]) + bias = sess.run(graph.get_operation_by_name(bias_op).outputs[0]) else: bias = 0.0 - # Sort input nodes by input index. - input_nodes = [ - node for _, node in sorted(submodel_input_nodes[submodel_idx]) - ] - - linear_node = model_info.LinearNode( + linear_combination_node = model_info.LinearNode( input_nodes=input_nodes, coefficients=coefficients, bias=bias) - submodel_output_nodes[submodel_idx] = linear_node - nodes.append(linear_node) - - ####################### - # Create lattice nodes. - ####################### - - # Lattice weights. - # {Lattice_LAYER_NAME}_{submodel_idx}/{LATTICE_KERNEL_NAME} - lattice_kernel_op_re = '^{}_(.*)/{}/Read/ReadVariableOp$'.format( - premade_lib.LATTICE_LAYER_NAME, - lattice_layer.LATTICE_KERNEL_NAME, + output_combination_node = linear_combination_node + else: + average_node = model_info.MeanNode(input_nodes=input_nodes) + output_combination_node = average_node + return output_combination_node + + +def _create_output_calibration_node(sess, ops, graph, input_node): + """Returns a PWLCalibrationNode.""" + output_calibration_node = None + # Lengths (deltas between keypoints). + # {OUTPUT_CALIB_LAYER_NAME}/{LENGTHS_NAME} + lengths_op = '{}/{}'.format( + premade_lib.OUTPUT_CALIB_LAYER_NAME, + pwl_calibration_layer.LENGTHS_NAME, + ) + if lengths_op in ops: + # Interpolation keypoints does not inlcude the last input keypoint. + # {OUTPUT_CALIB_LAYER_NAME}/{INTERPOLATION_KEYPOINTS_NAME} + keypoints_op = '{}/{}'.format( + premade_lib.OUTPUT_CALIB_LAYER_NAME, + pwl_calibration_layer.INTERPOLATION_KEYPOINTS_NAME, ) - for lattice_kernel_op, submodel_idx in _match_op(ops, lattice_kernel_op_re): - lattice_kernel = sess.run( - g.get_operation_by_name(lattice_kernel_op).outputs[0]).flatten() - # Lattice sizes. - # {Lattice_LAYER_NAME}_{submodel_idx}/{LATTICE_SIZES_NAME} - lattice_sizes_op_name = '{}_{}/{}'.format( - premade_lib.LATTICE_LAYER_NAME, submodel_idx, - lattice_layer.LATTICE_SIZES_NAME) - lattice_sizes = sess.run( - g.get_operation_by_name(lattice_sizes_op_name).outputs[0]).flatten() + # Output keypoints. We need to call the varible read op. + # {OUTPUT_CALIB_LAYER_NAME}/{PWL_CALIBRATION_KERNEL_NAME} + kernel_op = '{}/{}/Read/ReadVariableOp'.format( + premade_lib.OUTPUT_CALIB_LAYER_NAME, + pwl_calibration_layer.PWL_CALIBRATION_KERNEL_NAME, + ) - # Shape the flat lattice parameters based on the calculated lattice sizes. - weights = np.reshape(lattice_kernel, lattice_sizes) + (lengths, keypoints, kernel) = sess.run( + (graph.get_operation_by_name(lengths_op).outputs[0], + graph.get_operation_by_name(keypoints_op).outputs[0], + graph.get_operation_by_name(kernel_op).outputs[0])) + output_keypoints = np.cumsum(kernel.flatten()) - # Sort input nodes by input index. - input_nodes = [ - node for _, node in sorted(submodel_input_nodes[submodel_idx]) - ] + # Add the last keypoint to the keypoint list. + input_keypoints = np.append(keypoints, keypoints[-1] + lengths[-1]) - lattice_node = model_info.LatticeNode( - input_nodes=input_nodes, weights=weights) - submodel_output_nodes[submodel_idx] = lattice_node - nodes.append(lattice_node) + output_calibration_node = model_info.PWLCalibrationNode( + input_node=input_node, + input_keypoints=input_keypoints, + output_keypoints=output_keypoints, + default_input=None, + default_output=None) + return output_calibration_node - ##################################################### - # Create output linear combination or averaging node. - ##################################################### - # Mean node is only added for ensemble models. - if len(submodel_output_nodes) > 1: - input_nodes = [ - submodel_output_nodes[idx] - for idx in sorted(submodel_output_nodes.keys(), key=int) - ] +def get_model_graph(saved_model_path, tag='serve'): + """Returns all layers and parameters used in a saved model as a graph. - # Linear coefficients. - # {LINEAR_LAYER_COMBINATION_NAME}/{LINEAR_LAYER_KERNEL_NAME} - linear_combination_kernel_op = '{}/{}/Read/ReadVariableOp'.format( - premade_lib.OUTPUT_LINEAR_COMBINATION_LAYER_NAME, - linear_layer.LINEAR_LAYER_KERNEL_NAME, - ) - if linear_combination_kernel_op in ops: - coefficients = sess.run( - g.get_operation_by_name( - linear_combination_kernel_op).outputs[0]).flatten() - - # Bias term. - # {OUTPUT_LINEAR_COMBINATION_LAYER_NAME}/{LINEAR_LAYER_BIAS_NAME} - bias_op = '{}/{}/Read/ReadVariableOp'.format( - premade_lib.OUTPUT_LINEAR_COMBINATION_LAYER_NAME, - linear_layer.LINEAR_LAYER_BIAS_NAME, - ) - if bias_op in ops: - bias = sess.run(g.get_operation_by_name(bias_op).outputs[0]) - else: - bias = 0.0 + The returned graph is not a TF graph, rather a graph of python object that + encodes the model structure and includes trained model parameters. The graph + can be used by the `tfl.visualization` module for plotting and other + visualization and analysis. - linear_combination_node = model_info.LinearNode( - input_nodes=input_nodes, coefficients=coefficients, bias=bias) - nodes.append(linear_combination_node) - model_output_node = linear_combination_node - else: - average_node = model_info.MeanNode(input_nodes=input_nodes) - nodes.append(average_node) - model_output_node = average_node - else: - model_output_node = list(submodel_output_nodes.values())[0] + Example: - ##################################### - # Create output PWL calibration node. - ##################################### + ```python + model_graph = estimators.get_model_graph(saved_model_path) + visualization.plot_feature_calibrator(model_graph, "feature_name") + visualization.plot_all_calibrators(model_graph) + visualization.draw_model_graph(model_graph) + ``` - # Lengths (deltas between keypoints). - # {OUTPUT_CALIB_LAYER_NAME}/{LENGTHS_NAME} - lengths_op = '{}/{}'.format( - premade_lib.OUTPUT_CALIB_LAYER_NAME, - pwl_calibration_layer.LENGTHS_NAME, - ) - if lengths_op in ops: - # Interpolation keypoints does not inlcude the last input keypoint. - # {OUTPUT_CALIB_LAYER_NAME}/{INTERPOLATION_KEYPOINTS_NAME} - keypoints_op = '{}/{}'.format( - premade_lib.OUTPUT_CALIB_LAYER_NAME, - pwl_calibration_layer.INTERPOLATION_KEYPOINTS_NAME, - ) + Args: + saved_model_path: Path to the saved model. + tag: Saved model tag for loading. - # Output keypoints. We need to call the varible read op. - # {OUTPUT_CALIB_LAYER_NAME}/{PWL_CALIBRATION_KERNEL_NAME} - kernel_op = '{}/{}/Read/ReadVariableOp'.format( - premade_lib.OUTPUT_CALIB_LAYER_NAME, - pwl_calibration_layer.PWL_CALIBRATION_KERNEL_NAME, - ) + Returns: + A `model_info.ModelGraph` object that includes the model graph. + """ + # List of all the nodes in the model. + nodes = [] - (lengths, keypoints, kernel) = sess.run( - (g.get_operation_by_name(lengths_op).outputs[0], - g.get_operation_by_name(keypoints_op).outputs[0], - g.get_operation_by_name(kernel_op).outputs[0])) - output_keypoints = np.cumsum(kernel.flatten()) + # Dict from submodel index to the output node of the submodel. + submodel_output_nodes = {} - # Add the last keypoint to the keypoint list. - input_keypoints = np.append(keypoints, keypoints[-1] + lengths[-1]) + tf.compat.v1.reset_default_graph() + with tf.compat.v1.Session() as sess: + tf.compat.v1.saved_model.loader.load(sess, [tag], saved_model_path) + graph = tf.compat.v1.get_default_graph() + ops = [op.name for op in graph.get_operations()] + + # Dict from feature name to corresponding InputFeatureNode object. + feature_nodes = _create_feature_nodes(sess, ops, graph) + nodes.extend(feature_nodes.values()) + + # Categorical Calibration Nodes. + categorical_calibration_nodes = _create_categorical_calibration_nodes( + sess, ops, graph, feature_nodes) + for calibration_nodes in categorical_calibration_nodes.values(): + nodes.extend(calibration_nodes) + + # PWL Calibration Nodes. + pwl_calibration_nodes = _create_pwl_calibration_nodes( + sess, ops, graph, feature_nodes) + for calibration_nodes in pwl_calibration_nodes.values(): + nodes.extend(calibration_nodes) + + # Dict from feature name to list of calibration nodes (Categorical and PWL). + calibration_nodes_map = {} + calibration_nodes_map.update(categorical_calibration_nodes) + calibration_nodes_map.update(pwl_calibration_nodes) + # Dict from submodel index to a list of calibrated inputs for the submodel. + submodel_input_nodes = _create_submodel_input_map(ops, + calibration_nodes_map) + + # Linear nodes + linear_nodes = _create_linear_nodes(sess, ops, graph, submodel_input_nodes) + submodel_output_nodes.update(linear_nodes) + nodes.extend(linear_nodes.values()) + + # Ensemble Lattice nodes. + lattice_nodes = _create_lattice_nodes(sess, ops, graph, + submodel_input_nodes) + submodel_output_nodes.update(lattice_nodes) + nodes.extend(lattice_nodes.values()) + + # RTL Lattice nodes. + rtl_lattice_nodes = _create_rtl_lattice_nodes(sess, ops, graph, + calibration_nodes_map) + submodel_output_nodes.update(rtl_lattice_nodes) + nodes.extend(rtl_lattice_nodes.values()) + + # Output combination node. + model_output_node = _create_output_combination_node(sess, ops, graph, + submodel_output_nodes) + if model_output_node: + nodes.append(model_output_node) + else: + model_output_node = list(submodel_output_nodes.values())[0] - output_calibration_node = model_info.PWLCalibrationNode( - input_node=model_output_node, - input_keypoints=input_keypoints, - output_keypoints=output_keypoints, - default_input=None, - default_output=None) + # Output calibration node. + output_calibration_node = _create_output_calibration_node( + sess, ops, graph, model_output_node) + if output_calibration_node: nodes.append(output_calibration_node) model_output_node = output_calibration_node diff --git a/tensorflow_lattice/python/estimators_test.py b/tensorflow_lattice/python/estimators_test.py index 13b871b..7064f58 100644 --- a/tensorflow_lattice/python/estimators_test.py +++ b/tensorflow_lattice/python/estimators_test.py @@ -17,6 +17,8 @@ from __future__ import division from __future__ import print_function +import copy + from absl import logging from absl.testing import parameterized import numpy as np @@ -210,9 +212,10 @@ def setUp(self): pwl_calibration_input_keypoints='uniform', pwl_calibration_always_monotonic=False, reflects_trust_in=[ - configs.TrustConfig(feature_name='RM', - trust_type='edgeworth', - direction='negative'), + configs.TrustConfig( + feature_name='RM', + trust_type='edgeworth', + direction='negative'), ], regularizer_configs=[ configs.RegularizerConfig(name='calib_wrinkle', l2=1e-4), @@ -245,8 +248,8 @@ def setUp(self): name='LSTAT', monotonicity=-1, dominates=[ - configs.DominanceConfig(feature_name='AGE', - dominance_type='monotonic'), + configs.DominanceConfig( + feature_name='AGE', dominance_type='monotonic'), ], ), ] @@ -293,6 +296,10 @@ def _GetBostonTestInputFn(self, **kwargs): 'age', 'sex', 'cp', 'trestbps', 'chol', 'fbs', 'restecg', 'thalach', 'exang', 'oldpeak', 'slope', 'ca', 'thal' ], 'crystals', 6, 5, True, False, 0.85), + ([ + 'age', 'sex', 'cp', 'trestbps', 'chol', 'fbs', 'restecg', 'thalach', + 'exang', 'oldpeak', 'slope', 'ca', 'thal' + ], 'rtl_layer', 6, 5, True, False, 0.85), ) def testCalibratedLatticeEnsembleClassifier(self, feature_names, lattices, num_lattices, lattice_rank, @@ -307,6 +314,15 @@ def testCalibratedLatticeEnsembleClassifier(self, feature_names, lattices, feature_config for feature_config in self.heart_feature_configs if feature_config.name in feature_names ] + if lattices == 'rtl_layer': + # RTL Layer only supports monotonicity and bound constraints. + feature_configs = copy.deepcopy(feature_configs) + for feature_config in feature_configs: + feature_config.lattice_size = 2 + feature_config.unimodality = 'none' + feature_config.reflects_trust_in = None + feature_config.dominates = None + feature_config.regularizer_configs = None model_config = configs.CalibratedLatticeEnsembleConfig( regularizer_configs=[ configs.RegularizerConfig(name='torsion', l2=1e-4), @@ -418,6 +434,10 @@ def testCalibratedLinearClassifier(self, feature_names, output_calibration, 'CRIM', 'ZN', 'INDUS', 'CHAS', 'NOX', 'RM', 'AGE', 'DIS', 'RAD', 'TAX', 'PTRATIO', 'B', 'LSTAT' ], 'crystals', 6, 5, True, False, 50.0), + ([ + 'CRIM', 'ZN', 'INDUS', 'CHAS', 'NOX', 'RM', 'AGE', 'DIS', 'RAD', + 'TAX', 'PTRATIO', 'B', 'LSTAT' + ], 'rtl_layer', 6, 5, True, False, 50.0), ) def testCalibratedLatticeEnsembleRegressor(self, feature_names, lattices, num_lattices, lattice_rank, @@ -432,6 +452,15 @@ def testCalibratedLatticeEnsembleRegressor(self, feature_names, lattices, feature_config for feature_config in self.boston_feature_configs if feature_config.name in feature_names ] + if lattices == 'rtl_layer': + # RTL Layer only supports monotonicity and bound constraints. + feature_configs = copy.deepcopy(feature_configs) + for feature_config in feature_configs: + feature_config.lattice_size = 2 + feature_config.unimodality = 'none' + feature_config.reflects_trust_in = None + feature_config.dominates = None + feature_config.regularizer_configs = None model_config = configs.CalibratedLatticeEnsembleConfig( regularizer_configs=[ configs.RegularizerConfig(name='torsion', l2=1e-5), @@ -565,15 +594,27 @@ def testCalibratedLinearEstimator(self, feature_names, output_calibration, self.assertLess(results['average_loss'], average_loss) @parameterized.parameters( - (5, 6, False, True), - (4, 5, True, False), + ('random', 5, 6, False, True), + ('random', 4, 5, True, False), + ('rtl_layer', 5, 6, False, True), + ('rtl_layer', 4, 5, True, False), ) - def testCalibratedLatticeEnsembleModelInfo(self, num_lattices, lattice_rank, - separate_calibrators, + def testCalibratedLatticeEnsembleModelInfo(self, lattices, num_lattices, + lattice_rank, separate_calibrators, output_calibration): self._ResetAllBackends() + feature_configs = copy.deepcopy(self.heart_feature_configs) + if lattices == 'rtl_layer': + # RTL Layer only supports monotonicity and bound constraints. + for feature_config in feature_configs: + feature_config.lattice_size = 2 + feature_config.unimodality = 'none' + feature_config.reflects_trust_in = None + feature_config.dominates = None + feature_config.regularizer_configs = None model_config = configs.CalibratedLatticeEnsembleConfig( - feature_configs=self.heart_feature_configs, + feature_configs=feature_configs, + lattices=lattices, num_lattices=num_lattices, lattice_rank=lattice_rank, separate_calibrators=separate_calibrators, @@ -611,19 +652,16 @@ def testCalibratedLatticeEnsembleModelInfo(self, num_lattices, lattice_rank, self.assertLen(model.nodes, expected_num_nodes) @parameterized.parameters( - (['ZN', 'INDUS', 'RM'], 'random', 3, 1, - [['ZN', 'RM'], ['RM'], ['INDUS']]), - (['ZN', 'INDUS', 'RM'], 'crystals', 3, 1, - [['RM'], ['INDUS'], ['ZN', 'RM']]), - (['RM', 'LSTAT', 'AGE'], 'crystals', 3, 1, - [['LSTAT'], ['LSTAT', 'AGE'], ['RM']]), + (['ZN', 'INDUS', 'RM'], 'random', 3, 1, [['ZN', 'RM'], ['RM'], ['INDUS'] + ]), + (['ZN', 'INDUS', 'RM'], 'crystals', 3, 1, [['RM'], ['INDUS'], + ['ZN', 'RM']]), + (['RM', 'LSTAT', 'AGE'], 'crystals', 3, 1, [['LSTAT'], ['LSTAT', 'AGE'], + ['RM']]), ) - def testCalibratedLatticeEnsembleFix2dConstraintViolations(self, - feature_names, - lattices, - num_lattices, - lattice_rank, - expected_lattices): + def testCalibratedLatticeEnsembleFix2dConstraintViolations( + self, feature_names, lattices, num_lattices, lattice_rank, + expected_lattices): self._ResetAllBackends() feature_columns = [ feature_column for feature_column in self.boston_feature_columns @@ -652,8 +690,7 @@ def testCalibratedLatticeEnsembleFix2dConstraintViolations(self, # Serving input fn is used to create saved models. serving_input_fn = ( tf.estimator.export.build_parsing_serving_input_receiver_fn( - feature_spec=fc.make_parse_example_spec(feature_columns)) - ) + feature_spec=fc.make_parse_example_spec(feature_columns))) saved_model_path = estimator.export_saved_model(estimator.model_dir, serving_input_fn) logging.info('Model exported to %s', saved_model_path) @@ -661,8 +698,8 @@ def testCalibratedLatticeEnsembleFix2dConstraintViolations(self, lattices = [] for node in model.nodes: if isinstance(node, model_info.LatticeNode): - lattices.append([input_node.input_node.name - for input_node in node.input_nodes]) + lattices.append( + [input_node.input_node.name for input_node in node.input_nodes]) self.assertLen(lattices, len(expected_lattices)) for lattice, expected_lattice in zip(lattices, expected_lattices): diff --git a/tensorflow_lattice/python/internal_utils.py b/tensorflow_lattice/python/internal_utils.py index 24cec35..706e7bf 100644 --- a/tensorflow_lattice/python/internal_utils.py +++ b/tensorflow_lattice/python/internal_utils.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Internal helpers shared by multiple modules in TFL. Note that this module is not expected to be used by TFL users, and that it is @@ -72,8 +71,8 @@ def _min_projection(weights, sorted_indices, key_less_than_values, step): sorted_indices: Topologically sorted list of indices based on the monotonicity constraints. key_less_than_values: A defaultdict from index to a list of indices, such - that for `j` in `key_less_than_values[i]` we must have - `weight[i] <= weight[j]`. + that for `j` in `key_less_than_values[i]` we must have `weight[i] <= + weight[j]`. step: A value defining if we should apply a full projection (`step == 1`) or a partial projection (`step < 1`). @@ -125,8 +124,8 @@ def _max_projection(weights, sorted_indices, key_greater_than_values, step): return projected_weights -def approximately_project_categorical_partial_monotonicities(weights, - monotonicities): +def approximately_project_categorical_partial_monotonicities( + weights, monotonicities): """Returns an approximation L2 projection for categorical monotonicities. Categorical monotonocities are monotonicity constraints applied to the real @@ -151,8 +150,7 @@ def approximately_project_categorical_partial_monotonicities(weights, projected_weights = tf.unstack(weights) # A 0.5 min projection followed by a full max projection. - projected_weights_min_max = _min_projection(projected_weights, - sorted_indices, + projected_weights_min_max = _min_projection(projected_weights, sorted_indices, key_less_than_values, 0.5) projected_weights_min_max = _max_projection(projected_weights_min_max, sorted_indices, @@ -160,8 +158,7 @@ def approximately_project_categorical_partial_monotonicities(weights, projected_weights_min_max = tf.stack(projected_weights_min_max) # A 0.5 max projection followed by a full min projection. - projected_weights_max_min = _max_projection(projected_weights, - sorted_indices, + projected_weights_max_min = _max_projection(projected_weights, sorted_indices, key_greater_than_values, 0.5) projected_weights_max_min = _min_projection(projected_weights_max_min, sorted_indices, diff --git a/tensorflow_lattice/python/internal_utils_test.py b/tensorflow_lattice/python/internal_utils_test.py index 1ada0f9..7e9bb17 100644 --- a/tensorflow_lattice/python/internal_utils_test.py +++ b/tensorflow_lattice/python/internal_utils_test.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Tests for Tensorflow Lattice utility functions.""" from __future__ import absolute_import @@ -21,7 +20,7 @@ from absl.testing import parameterized import numpy as np import tensorflow as tf -from tensorflow_lattice.python import internal_utils as iu +from tensorflow_lattice.python import internal_utils class InternalUtilsTest(parameterized.TestCase, tf.test.TestCase): @@ -30,22 +29,21 @@ def _ResetAllBackends(self): tf.compat.v1.reset_default_graph() @parameterized.parameters( - ([3., 4.], [(0, 1)], [3., 4.]), - ([4., 3.], [(0, 1)], [3.5, 3.5]), - ([1., 0.], [(0, 1)], [0.5, 0.5]), - ([-1., 0.], [(1, 0)], [-0.5, -0.5]), - ([4., 3., 2., 1., 0.], [(0, 1), (1, 2), (2, 3), (3, 4)], - [2., 2., 2., 2., 2.])) + ([3., 4.], [(0, 1)], [3., 4.]), ([4., 3.], [(0, 1)], [3.5, 3.5]), + ([1., 0.], [(0, 1)], [0.5, 0.5]), ([-1., 0.], [(1, 0)], [-0.5, -0.5]), + ([4., 3., 2., 1., 0.], [(0, 1), (1, 2), (2, 3), + (3, 4)], [2., 2., 2., 2., 2.])) def testApproximatelyProjectCategoricalPartialMonotonicities( self, weights, monotonicities, expected_projected_weights): self._ResetAllBackends() weights = tf.Variable(weights) projected_weights = ( - iu.approximately_project_categorical_partial_monotonicities( + internal_utils.approximately_project_categorical_partial_monotonicities( weights, monotonicities)) self.evaluate(tf.compat.v1.global_variables_initializer()) - self.assertAllClose(self.evaluate(projected_weights), - np.array(expected_projected_weights)) + self.assertAllClose( + self.evaluate(projected_weights), np.array(expected_projected_weights)) + if __name__ == '__main__': tf.test.main() diff --git a/tensorflow_lattice/python/kronecker_factored_lattice_layer.py b/tensorflow_lattice/python/kronecker_factored_lattice_layer.py new file mode 100644 index 0000000..66a117b --- /dev/null +++ b/tensorflow_lattice/python/kronecker_factored_lattice_layer.py @@ -0,0 +1,428 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Kronecker-Factored Lattice layer with monotonicity constraints. + +Keras implementation of tensorflow Kronecker-Factored Lattice layer. This layer +takes one or more d-dimensional input(s) and combines them using a +Kronecker-Factored Lattice function, satisfying monotonicity constraints if +specified. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from . import kronecker_factored_lattice_lib as kfl_lib +from . import utils +from tensorflow import keras + +KFL_SCALE_NAME = "kronecker_factored_lattice_scale" +KFL_BIAS_NAME = "kronecker_factored_lattice_bias" +KFL_KERNEL_NAME = "kronecker_factored_lattice_kernel" + + +# TODO: add support for different lattice_sizes for each input +# dimension. +class KroneckerFactoredLattice(keras.layers.Layer): + # pyformat: disable + """Kronecker-Factored Lattice layer. + + A Kronecker-Factored Lattice is a reparameterization of a Lattice using + kronecker-facotrization, which gives us linear time and space complexity. + While the underlying representation is different, the input-output behavior + remains the same. + + A Kronecker-Factored Lattice consists of 'units' lattices. Each unit computes + the function described below on a distinct 'dims'-dimensional vector x taken + from the input tensor. Each unit has its own set of parameters. The function + each unit computes is given by: + + f(x) = b + (1/num_terms) * sum_{t=1}^{num_terms} scale_t * prod_{d=1}^{dims} PLF(x[d];w[d]) + + where bias and each scale_t are scalar parameters, w[d] is a + 'lattice_size'-dimensional vector of parameters, and PLF(;w) denotes the + one-dimensional piecewise-linear function with domain [0, lattice_sizes-1] + whose graph consists of lattice_sizes-1 linear segments interpolating the + points (i, w[i]), for i=0,1,...,lattice_size-1. + + There is currently one type of constraint on the shape of the learned + function. + + * **Monotonicity:** constrains the function to be increasing in the + corresponding dimension. + + Input shape: + - if `units == 1`: tensor of shape: `(batch_size, ..., dims)` + or list of `dims` tensors of same shape: `(batch_size, ..., 1)` + - if `units > 1`: tensor of shape: `(batch_size, ..., units, dims)` or list + of `dims` tensors of same shape: `(batch_size, ..., units, 1)` + + A typical shape is: `(batch_size, len(monotonicities))` + + Output shape: + Tensor of shape: `(batch_size, ..., units)` + + Attributes: + - All `__init__` arguments. + scale: A tensor of shape `(units, num_terms)`. Contains the `scale_t` + parameter for each unit for each term. + bias: A tensor of shape `(units)`. Contains the `b` parameter for each unit. + kernel: The `w` weights parameter of the Kronecker-Factored Lattice of + shape: `(1, lattice_sizes, units * dims, num_terms)`. Note that the kernel + is unit-major in its second to last dimension. + + Example: + + ```python + kfl = tfl.layers.KroneckerFactoredLattice( + # Number of vertices along each dimension. + lattice_sizes=2, + # Number of output units. + units=2, + # Number of independently trained submodels per unit, the outputs + # of which are averaged to get the final output. + num_terms=4, + # You can specify monotonicity constraints. + monotonicities=['increasing', 'none', 'increasing', 'increasing', + 'increasing', 'increasing', 'increasing']) + ``` + """ + # pyformat: enable + + def __init__(self, + lattice_sizes, + units=1, + num_terms=2, + monotonicities=None, + clip_inputs=True, + satisfy_constraints_at_every_step=True, + kernel_initializer="random_monotonic_initializer", + **kwargs): + # pyformat: disable + """Initializes an instance of `KroneckerFactoredLattice`. + + Args: + lattice_sizes: Number of vertices per dimension (minimum is 2). + units: Output dimension of the layer. See class comments for details. + num_terms: Number of independently trained submodels per unit, the outputs + of which are averaged to get the final output. + monotonicities: None or list or tuple of same length as input dimension of + {'none', 'increasing', 0, 1} which specifies if the model output should + be monotonic in the corresponding feature, using 'increasing' or 1 to + indicate increasing monotonicity and 'none' or 0 to indicate no + monotonicity constraints. + clip_inputs: If inputs should be clipped to the input range of the + Kronecker-Factored Lattice. + satisfy_constraints_at_every_step: Whether to strictly enforce constraints + after every gradient update by applying an imprecise projection. + kernel_initializer: None or one of: + - `'random_monotonic_initializer'`: initializes parameters as uniform + random functions that are monotonic in monotonic dimensions. + - Any Keras initializer object. + **kwargs: Other args passed to `tf.keras.layers.Layer` initializer. + + Raises: + ValueError: If layer hyperparameters are invalid. + """ + # pyformat: enable + kfl_lib.verify_hyperparameters( + lattice_sizes=lattice_sizes, units=units, num_terms=num_terms) + super(KroneckerFactoredLattice, self).__init__(**kwargs) + + self.lattice_sizes = lattice_sizes + self.units = units + self.num_terms = num_terms + self.monotonicities = monotonicities + self.clip_inputs = clip_inputs + self.satisfy_constraints_at_every_step = satisfy_constraints_at_every_step + + self.kernel_initializer = create_kernel_initializer( + kernel_initializer_id=kernel_initializer, + monotonicities=self.monotonicities) + + def build(self, input_shape): + """Standard Keras build() method.""" + kfl_lib.verify_hyperparameters( + units=self.units, + input_shape=input_shape, + monotonicities=self.monotonicities) + # input_shape: (batch, ..., units, dims) + if isinstance(input_shape, list): + dims = len(input_shape) + else: + dims = input_shape.as_list()[-1] + + self.scale = self.add_weight( + KFL_SCALE_NAME, + shape=[self.units, self.num_terms], + initializer=ScaleInitializer(), + dtype=self.dtype) + self.bias = self.add_weight( + KFL_BIAS_NAME, + shape=[self.units], + initializer="zeros", + dtype=self.dtype) + + if self.monotonicities: + constraints = KroneckerFactoredLatticeConstraints( + units=self.units, + scale=self.scale, + monotonicities=self.monotonicities, + satisfy_constraints_at_every_step=self + .satisfy_constraints_at_every_step) + else: + constraints = None + + # Note that the first dimension of shape is 1 to work with + # tf.nn.depthwise_conv2d. + self.kernel = self.add_weight( + KFL_KERNEL_NAME, + shape=[1, self.lattice_sizes, self.units * dims, self.num_terms], + initializer=self.kernel_initializer, + constraint=constraints, + dtype=self.dtype) + + self._final_constraints = KroneckerFactoredLatticeConstraints( + units=self.units, + scale=self.scale, + monotonicities=self.monotonicities, + satisfy_constraints_at_every_step=True) + + super(KroneckerFactoredLattice, self).build(input_shape) + + def call(self, inputs): + """Standard Keras call() method.""" + return kfl_lib.evaluate_with_hypercube_interpolation( + inputs=inputs, + scale=self.scale, + bias=self.bias, + kernel=self.kernel, + units=self.units, + num_terms=self.num_terms, + lattice_sizes=self.lattice_sizes, + clip_inputs=self.clip_inputs) + + def compute_output_shape(self, input_shape): + """Standard Keras compute_output_shape() method.""" + if isinstance(input_shape, list): + input_shape = input_shape[0] + if self.units == 1: + return tuple(input_shape[:-1]) + (1,) + else: + # Second to last dimension must be equal to 'units'. Nothing to append. + return input_shape[:-1] + + def get_config(self): + """Standard Keras config for serialization.""" + config = { + "lattice_sizes": self.lattice_sizes, + "units": self.units, + "num_terms": self.num_terms, + "monotonicities": self.monotonicities, + "clip_inputs": self.clip_inputs, + "satisfy_constraints_at_every_step": + self.satisfy_constraints_at_every_step, + "kernel_initializer": + keras.initializers.serialize(self.kernel_initializer), + } # pyformat: disable + config.update(super(KroneckerFactoredLattice, self).get_config()) + return config + + def finalize_constraints(self): + """Ensures layers weights strictly satisfy constraints. + + Applies approximate projection to strictly satisfy specified constraints. + If `monotonic_at_every_step == True` there is no need to call this function. + + Returns: + In eager mode directly updates weights and returns variable which stores + them. In graph mode returns `assign_add` op which has to be executed to + updates weights. + """ + return self.kernel.assign_add( + self._final_constraints(self.kernel) - self.kernel) + + def assert_constraints(self, eps=1e-6): + """Asserts that weights satisfy all constraints. + + In graph mode builds and returns list of assertion ops. + In eager mode directly executes assertions. + + Args: + eps: allowed constraints violation. + + Returns: + List of assertion ops in graph mode or immediately asserts in eager mode. + """ + return kfl_lib.assert_constraints( + weights=self.kernel, + units=self.units, + scale=self.scale, + monotonicities=utils.canonicalize_monotonicities( + self.monotonicities, allow_decreasing=False), + eps=eps) + + +def create_kernel_initializer(kernel_initializer_id, monotonicities): + """Returns a kernel Keras initializer object from its id. + + This function is used to convert the 'kernel_initializer' parameter in the + constructor of tfl.layers.KroneckerFactoredLattice into the corresponding + initializer object. + + Args: + kernel_initializer_id: See the documentation of the 'kernel_initializer' + parameter in the constructor of `tfl.layers.KroneckerFactoredLattice`. + monotonicities: See the documentation of the same parameter in the + constructor of `tfl.layers.KroneckerFactoredLattice`. + + Returns: + The Keras initializer object for the `tfl.layers.KroneckerFactoredLattice` + kernel variable. + """ + # Construct initializer. + if kernel_initializer_id in [ + "random_monotonic_initializer", "RandomMonotonicInitializer" + ]: + return RandomMonotonicInitializer(monotonicities) + else: + # This is needed for Keras deserialization logic to be aware of our custom + # objects. + with keras.utils.custom_object_scope({ + "RandomMonotonicInitializer": RandomMonotonicInitializer, + }): + return keras.initializers.get(kernel_initializer_id) + + +class RandomMonotonicInitializer(keras.initializers.Initializer): + # pyformat: disable + """Initializes a `tfl.layers.KroneckerFactoredLattice` as random monotonic.""" + # pyformat: enable + + def __init__(self, monotonicities, seed=None): + """Initializes an instance of `RandomMonotonicInitializer`. + + Args: + monotonicities: Monotonic dimensions for initialization. Does not need to + match `monotonicities` of `tfl.layers.KroneckerFactoredLattice`. + seed: A Python integer. Used to create a random seed for the distribution. + """ + self.monotonicities = monotonicities + self.seed = seed + + def __call__(self, shape, dtype=None, partition_info=None): + """Returns weights of `tfl.layers.KroneckerFactoredLattice` layer. + + Args: + shape: Must be: `(1, lattice_sizes, units * dims, num_terms)`. + dtype: Standard Keras initializer param. + partition_info: Standard Keras initializer param. Not used. + """ + del partition_info + return kfl_lib.random_monotonic_initializer( + shape=shape, + monotonicities=utils.canonicalize_monotonicities( + self.monotonicities, allow_decreasing=False), + dtype=dtype, + seed=self.seed) + + def get_config(self): + """Standard Keras config for serializaion.""" + config = { + "monotonicities": self.monotonicities, + "seed": self.seed, + } # pyformat: disable + return config + + +class ScaleInitializer(keras.initializers.Initializer): + # pyformat: disable + """Initializes scale to alternate between 1 and -1 for each term.""" + # pyformat: enable + + def __call__(self, shape, dtype=None, partition_info=None): + """Returns weights of `tfl.layers.KroneckerFactoredLattice` scale. + + Args: + shape: Must be: `(units, num_terms)`. + dtype: Standard Keras initializer param. + partition_info: Standard Keras initializer param. Not used. + """ + del partition_info + units, num_terms = shape + return kfl_lib.scale_initializer(units=units, num_terms=num_terms) + + +class KroneckerFactoredLatticeConstraints(keras.constraints.Constraint): + # pyformat: disable + """Constraints for `tfl.layers.KroneckerFactoredLattice` layer. + + Applies all constraints to the Kronecker-Factored Lattice weights. See + `tfl.layers.KroneckerFactoredLattice` for more details. + + Attributes: + - All `__init__` arguments. + """ + # pyformat: enable + + def __init__(self, + units, + scale, + monotonicities=None, + satisfy_constraints_at_every_step=True): + """Initializes an instance of `KroneckerFactoredLatticeConstraints`. + + Args: + units: Same meaning as corresponding parameter of + `KroneckerFactoredLattice`. + scale: Scale variable of shape: `(units, num_terms)`. + monotonicities: Same meaning as corresponding parameter of + `KroneckerFactoredLattice`. + satisfy_constraints_at_every_step: Whether to use approximate projection + to ensure that constratins are strictly satisfied. + """ + self.units = units + self.scale = scale + self.monotonicities = utils.canonicalize_monotonicities( + monotonicities, allow_decreasing=False) + self.num_constraint_dims = utils.count_non_zeros(self.monotonicities) + self.satisfy_constraints_at_every_step = satisfy_constraints_at_every_step + + def __call__(self, w): + """Applies constraints to `w`. + + Args: + w: Kronecker-Factored Lattice weights tensor of shape: `(1, lattice_sizes, + units * dims, num_terms)`. + + Returns: + Constrained and projected w. + """ + if self.num_constraint_dims and self.satisfy_constraints_at_every_step: + w = kfl_lib.finalize_constraints( + w, + units=self.units, + scale=self.scale, + monotonicities=self.monotonicities) + return w + + def get_config(self): + """Standard Keras config for serialization.""" + return { + "units": self.units, + "scale": self.scale, + "monotonicities": self.monotonicities, + "satisfy_constraints_at_every_step": + self.satisfy_constraints_at_every_step, + } # pyformat: disable diff --git a/tensorflow_lattice/python/kronecker_factored_lattice_lib.py b/tensorflow_lattice/python/kronecker_factored_lattice_lib.py new file mode 100644 index 0000000..979cbb4 --- /dev/null +++ b/tensorflow_lattice/python/kronecker_factored_lattice_lib.py @@ -0,0 +1,409 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Algorithm implementations required for Kronecker-Factored Lattice layer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from . import utils +import numpy as np +import tensorflow as tf + + +def evaluate_with_hypercube_interpolation(inputs, scale, bias, kernel, units, + num_terms, lattice_sizes, + clip_inputs): + """Evaluates a Kronecker-Factored Lattice using hypercube interpolation. + + Kronecker-Factored Lattice function is the product of the piece-wise linear + interpolation weights for each dimension of the input. + + Args: + inputs: Tensor representing points to apply lattice interpolation to. If + units = 1, tensor should be of shape: `(batch_size, ..., dims)` or list of + `dims` tensors of same shape `(batch_size, ..., 1)`. If units > 1, + tensor + should be of shape: `(batch_size, ..., units, dims)` or list of `dims` + tensors of same shape `(batch_size, ..., units, 1)`. A typical shape is + `(batch_size, dims)`. + scale: Kronecker-Factored Lattice scale of shape `(units, num_terms)`. + bias: Kronecker-Factored Lattice bias of shape `(units)`. + kernel: Kronecker-Factored Lattice kernel of shape `(1, lattice_sizes, units + * dims, num_terms)`. + units: Output dimension of the Kronecker-Factored Lattice. + num_terms: Number of independently trained submodels per unit, the outputs + of which are averaged to get the final output. + lattice_sizes: Number of vertices per dimension. + clip_inputs: If inputs should be clipped to the input range of the + Kronecker-Factored Lattice. + + Returns: + Tensor of shape: `(batch_size, ..., units)`. + """ + # Convert list of tensors to single tensor object. + if isinstance(inputs, list): + inputs = tf.stack(inputs, axis=-1) + if clip_inputs: + inputs = tf.clip_by_value(inputs, 0.0, lattice_sizes - 1.0) + + inputs_shape = inputs.get_shape().as_list() + dims = inputs_shape[-1] + # Compute total dimension size before units excluding batch to squeeze into + # one axis. + idx = -1 if units == 1 else -2 + rows = int(np.prod(inputs_shape[1:idx])) + inputs = tf.reshape(inputs, [-1, rows, units * dims]) + + # interpolation_weights.shape: (batch, rows, lattice_sizes, units * dims). + # interpolation_weights[m,n,i,j] should be the interpolation weight of the + # (m,n,j) input in the i'th vertex, i.e. 0 if dist(input[m,n,j], i) >= 1, + # otherwise 1 - dist(input[m,n,j], i), where `dist(...)` denotes the Euclidean + # distance between scalars. + if lattice_sizes == 2: + interpolation_weights = tf.stack([1 - inputs, inputs], axis=-2) + else: + vertices = tf.constant( + list(range(lattice_sizes)), + shape=(lattice_sizes, 1), + dtype=inputs.dtype) + interpolation_weights = vertices - tf.expand_dims(inputs, axis=-2) + interpolation_weights = 1 - tf.minimum(tf.abs(interpolation_weights), 1) + + # dotprod.shape: (batch, rows, 1, units * dims * num_terms) + dotprod = tf.nn.depthwise_conv2d( + interpolation_weights, kernel, [1, 1, 1, 1], padding="VALID") + dotprod = tf.reshape(dotprod, [-1, rows, units, dims, num_terms]) + + prod = tf.reduce_prod(dotprod, axis=-2) + + results = scale * prod + # Average across terms for each unit. + results = tf.reduce_mean(results, axis=-1) + results = results + bias + + # results.shape: (batch, rows, units) + results_shape = [-1] + inputs_shape[1:-1] + if units == 1: + results_shape.append(1) + results = tf.reshape(results, results_shape) + return results + + +def random_monotonic_initializer(shape, + monotonicities, + dtype=tf.float32, + seed=None): + """Returns a uniformly random sampled monotonic weight tensor. + + - The uniform random monotonic function will initilaize the lattice parameters + uniformly at random and make it such that the parameters are monotonically + increasing for each input. + - The random parameters will be sampled from `[0, 1]` + + Args: + shape: Shape of weights to initialize. Must be: `(1, lattice_sizes, units * + dims, num_terms)`. + monotonicities: None or list or tuple of length dims of elements of {0,1} + which represents monotonicity constraints per dimension. 1 stands for + increasing (non-decreasing in fact), 0 for no monotonicity constraints. + dtype: dtype + seed: A Python integer. Used to create a random seed for the distribution. + + Returns: + Kronecker-Factored Lattice weights tensor of shape: + `(1, lattice_sizes, units * dims, num_terms)`. + """ + # Sample from the uniform distribution. + weights = tf.random.uniform(shape, dtype=dtype, seed=seed) + if utils.count_non_zeros(monotonicities) > 0: + # To sort, we must first reshape and unstack our weights. + dims = len(monotonicities) + _, lattice_sizes, units_times_dims, num_terms = shape + if units_times_dims % dims != 0: + raise ValueError( + "len(monotonicities) is {}, which does not evenly divide shape[2]" + "len(monotonicities) should be equal to `dims`, and shape[2] " + "should be equal to units * dims.".format(dims)) + units = units_times_dims // dims + weights = tf.reshape(weights, [-1, lattice_sizes, units, dims, num_terms]) + # Now we can unstack each dimension. + weights = tf.unstack(weights, axis=3) + monotonic_weights = [ + tf.sort(weight) if monotonicity else weight + for monotonicity, weight in zip(monotonicities, weights) + ] + # Restack, reshape, and return weights + weights = tf.stack(monotonic_weights, axis=3) + weights = tf.reshape(weights, shape) + return weights + + +def scale_initializer(units, num_terms): + """Initializes scale to alternate between 1 and -1 for each term. + + Args: + units: Output dimension of the layer. Each of units scale will be + initialized identically. + num_terms: Number of independently trained submodels per unit, the outputs + of which are averaged to get the final output. + + Returns: + Kronecker-Factored Lattice scale of shape: `(units, num_terms)`. + """ + signs = (np.arange(num_terms) % -2) * 2 + 1 + return np.tile(signs, [units, 1]) + + +def _approximately_project_monotonicity(weights, units, scale, monotonicities): + """Approximately projects to strictly meet monotonicity constraints. + + For more details, see _approximately_project_monotonicity in lattice_lib.py. + + Args: + weights: Tensor with weights of shape `(1, lattice_sizes, units * dims, + num_terms)`. + units: Number of units per input dimension. + scale: Scale variable of shape: `(units, num_terms)`. + monotonicities: List or tuple of length dims of elements of {0,1} which + represents monotonicity constraints per dimension. 1 stands for increasing + (non-decreasing in fact), 0 for no monotonicity constraints. + + Returns: + Tensor with projected weights matching shape of input weights. + """ + # Recall that w.shape is (1, lattice_sizes, units * dims, num_terms). + weights_shape = weights.get_shape().as_list() + _, lattice_sizes, units_times_dims, num_terms = weights_shape + assert units_times_dims % units == 0 + dims = units_times_dims // units + weights = tf.reshape(weights, [-1, lattice_sizes, units, dims, num_terms]) + + # Extract the sign of scale to determine the projection direction. + direction = tf.expand_dims(tf.sign(scale), axis=1) + + # TODO: optimize for case where all dims are monotonic and we won't + # need to unstack. + # Unstack our weights such that we have the weight for each dimension. We + # multiply by direction such that we always project the weights to be + # increasing. + weights = tf.unstack(direction * weights, axis=3) + projected = [] + for weight, monotonicity in zip(weights, monotonicities): + if monotonicity: + # First we go forward to find the maximum projection. + max_projection = tf.unstack(weight, axis=1) + for i in range(1, len(max_projection)): + max_projection[i] = tf.maximum(max_projection[i], max_projection[i - 1]) + # Find the halfway projection to find the minimum projection. + half_projection = (weight + tf.stack(max_projection, axis=1)) / 2.0 + # Now we go backwards to find the minimum projection. + min_projection = tf.unstack(half_projection, axis=1) + for i in range(len(min_projection) - 2, -1, -1): + min_projection[i] = tf.minimum(min_projection[i], min_projection[i + 1]) + # Restack our weight from the minimum projection. + weight = tf.stack(min_projection, axis=1) + # Add our projected weight to our running list. + projected.append(weight) + # Restack our final projected weights. We multiply by direction such that if + # direction is negative we end up with decreasing weights. + weights = direction * tf.stack(projected, axis=3) + + # Reshape projected weights into original shape and return them. + weights = tf.reshape(weights, weights_shape) + return weights + + +def finalize_constraints(weights, units, scale, monotonicities): + """Approximately projects weights to strictly satisfy all constraints. + + This projeciton guarantees that constraints are strictly met, but it is not + an exact projection w.r.t. the L2 norm. The computational cost is + `O(num_monotonic_dims * num_lattice_weights)`. + + See helper functions `_approximately_project_*` for details of the individual + projection algorithms for each set of constraints. + + Args: + weights: Kronecker-Factored Lattice weights tensor of shape: `(1, + lattice_sizes, units * dims, num_terms)`. + units: Number of units per input dimension. + scale: Scale variable of shape: `(units, num_terms)`. + monotonicities: List or tuple of length dims of elements of {0,1} which + represents monotonicity constraints per dimension. 1 stands for increasing + (non-decreasing in fact), 0 for no monotonicity constraints. + + Returns: + Projected weights tensor of same shape as `weights`. + """ + if utils.count_non_zeros(monotonicities) == 0: + return weights + + # TODO: in the case of only one monotonic dimension, we only have to + # constrain the non-monotonic dimensions to be positive. + # There must be monotonicity constraints, so we need all positive weights. + weights = tf.maximum(weights, 0) + + # Project monotonicity constraints. + weights = _approximately_project_monotonicity(weights, units, scale, + monotonicities) + + return weights + + +def verify_hyperparameters(lattice_sizes=None, + units=None, + num_terms=None, + input_shape=None, + monotonicities=None): + """Verifies that all given hyperparameters are consistent. + + This function does not inspect weights themselves. Only their shape. Use + `assert_constraints()` to assert actual weights against constraints. + + See `tfl.layers.KroneckerFactoredLattice` class level comment for detailed + description of arguments. + + Args: + lattice_sizes: Lattice size to check against. + units: Units hyperparameter of `KroneckerFactoredLattice` layer. + num_terms: Number of independently trained submodels hyperparameter of + `KroneckerFactoredLattice` layer. + input_shape: Shape of layer input. Useful only if `units` and/or + `monotonicities` is set. + monotonicities: Monotonicities hyperparameter of `KroneckerFactoredLattice` + layer. Useful only if `input_shape` is set. + + Raises: + ValueError: If lattice_sizes < 2. + ValueError: If units < 1. + ValueError: If num_terms < 1. + ValueError: If len(monotonicities) does not match number of inputs. + """ + if lattice_sizes and lattice_sizes < 2: + raise ValueError("Lattice size must be at least 2. Given: %s" % + lattice_sizes) + + if units and units < 1: + raise ValueError("Units must be at least 1. Given: %s" % units) + + if num_terms and num_terms < 1: + raise ValueError("Number of terms must be at least 1. Given: %s" % + num_terms) + + # input_shape: (batch, ..., units, dims) + if input_shape: + # It also raises errors if monotonicities is specified incorrectly. + monotonicities = utils.canonicalize_monotonicities( + monotonicities, allow_decreasing=False) + # Extract shape to check units and dims to check monotonicity + if isinstance(input_shape, list): + dims = len(input_shape) + # Check monotonicity. + if monotonicities and len(monotonicities) != dims: + raise ValueError("If input is provided as list of tensors, their number" + " must match monotonicities. 'input_list': %s, " + "'monotonicities': %s" % (input_shape, monotonicities)) + shape = input_shape[0] + else: + dims = input_shape.as_list()[-1] + # Check monotonicity. + if monotonicities and len(monotonicities) != dims: + raise ValueError("Last dimension of input shape must have same number " + "of elements as 'monotonicities'. 'input shape': %s, " + "'monotonicities': %s" % (input_shape, monotonicities)) + shape = input_shape + if units and units > 1 and (len(shape) < 3 or shape[-2] != units): + raise ValueError("If 'units' > 1 then input shape of " + "KroneckerFactoredLattice layer must have rank at least " + "3 where the second from the last dimension is equal to " + "'units'. 'units': %s, 'input_shape: %s" % + (units, input_shape)) + + +def _assert_monotonicity_constraints(weights, units, scale, monotonicities, + eps): + """Asserts that weights satisfy monotonicity constraints. + + Args: + weights: `KroneckerFactoredLattice` weights tensor of shape: `(1, + lattice_sizes, units * dims, num_terms)`. + units: Number of units per input dimension. + scale: Scale variable of shape: `(units, num_terms)`. + monotonicities: Monotonicity constraints. + eps: Allowed constraints violation. + + Returns: + List of monotonicity assertion ops in graph mode or directly executes + assertions in eager mode and returns a list of NoneType elements. + """ + monotonicity_asserts = [] + + # Recall that w.shape is (1, lattice_sizes, units * dims, num_terms). + weights_shape = weights.get_shape().as_list() + _, lattice_sizes, units_times_dims, num_terms = weights_shape + assert units_times_dims % units == 0 + dims = units_times_dims // units + weights = tf.reshape(weights, [-1, lattice_sizes, units, dims, num_terms]) + + # Extract the sign of scale to determine the assertion direction. + direction = tf.expand_dims(tf.sign(scale), axis=1) + + # Unstack our weights given our extracted sign. + weights = tf.unstack(direction * weights, axis=3) + for i, (weight, monotonicity) in enumerate(zip(weights, monotonicities)): + if monotonicity: + keypoints = tf.unstack(weight, axis=1) + for j in range(1, len(keypoints)): + diff = tf.reduce_min(keypoints[j] - keypoints[j - 1]) + monotonicity_asserts.append( + tf.Assert( + diff >= -eps, + data=[ + "Monotonicity violation", "Feature index:", i, + "Min monotonicity diff:", diff, "Upper layer number:", j, + "Epsilon:", eps, "Keypoints:", keypoints[j], + keypoints[j - 1] + ])) + + return monotonicity_asserts + + +def assert_constraints(weights, units, scale, monotonicities, eps=1e-6): + """Asserts that weights satisfy constraints. + + Args: + weights: `KroneckerFactoredLattice` weights tensor of shape: `(1, + lattice_sizes, units * dims, num_terms)`. + units: Number of units per input dimension. + scale: Scale variable of shape: `(units, num_terms)`. + monotonicities: Monotonicity constraints. + eps: Allowed constraints violation. + + Returns: + List of assertion ops in graph mode or directly executes assertions in eager + mode. + """ + asserts = [] + + if monotonicities: + monotonicity_asserts = _assert_monotonicity_constraints( + weights=weights, + units=units, + scale=scale, + monotonicities=monotonicities, + eps=eps) + asserts.extend(monotonicity_asserts) + + return asserts diff --git a/tensorflow_lattice/python/kronecker_factored_lattice_test.py b/tensorflow_lattice/python/kronecker_factored_lattice_test.py new file mode 100644 index 0000000..18dc0e4 --- /dev/null +++ b/tensorflow_lattice/python/kronecker_factored_lattice_test.py @@ -0,0 +1,760 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for KroneckerFactoredLattice Layer.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math + +import tempfile +from absl import logging +from absl.testing import parameterized +import numpy as np +import tensorflow as tf +from tensorflow import keras +from tensorflow_lattice.python import kronecker_factored_lattice_layer as kfll +from tensorflow_lattice.python import test_utils + + +class KroneckerFactoredLatticeTest(parameterized.TestCase, tf.test.TestCase): + + def setUp(self): + super(KroneckerFactoredLatticeTest, self).setUp() + self.disable_all = False + self.disable_ensembles = False + self.loss_eps = 0.0001 + self.small_eps = 1e-6 + self.seed = 42 + + def _ResetAllBackends(self): + keras.backend.clear_session() + tf.compat.v1.reset_default_graph() + + def _ScatterXUniformly(self, num_points, lattice_sizes, input_dims): + """Deterministically generates num_point random points within lattice.""" + np.random.seed(41) + x = [] + for _ in range(num_points): + point = [ + np.random.random() * (lattice_sizes - 1.0) for _ in range(input_dims) + ] + x.append(np.asarray(point)) + if input_dims == 1: + x.sort() + return x + + def _ScatterXUniformlyExtendedRange(self, num_points, lattice_sizes, + input_dims): + """Extends every dimension by 1.0 on both sides and generates points.""" + np.random.seed(41) + x = [] + for _ in range(num_points): + point = [ + np.random.random() * (lattice_sizes + 1.0) - 1.0 + for _ in range(input_dims) + ] + x.append(np.asarray(point)) + if input_dims == 1: + x.sort() + return x + + def _SameValueForAllDims(self, num_points, lattice_sizes, input_dims): + """Generates random point with same value for every dimension.""" + np.random.seed(41) + x = [] + for _ in range(num_points): + rand = np.random.random() * (lattice_sizes - 1.0) + point = [rand] * input_dims + x.append(np.asarray(point)) + if input_dims == 1: + x.sort() + return x + + def _TwoDMeshGrid(self, num_points, lattice_sizes, input_dims): + """Mesh grid for visualisation of 3-d surfaces via pyplot.""" + if input_dims != 2: + raise ValueError("2-d mesh grid is possible only for 2-d lattice. Lattice" + " dimension given: %s" % input_dims) + return test_utils.two_dim_mesh_grid( + num_points=num_points, + x_min=0.0, + y_min=0.0, + x_max=lattice_sizes - 1.0, + y_max=lattice_sizes - 1.0) + + def _TwoDMeshGridExtendedRange(self, num_points, lattice_sizes, input_dims): + """Mesh grid extended by 1.0 on every side.""" + if input_dims != 2: + raise ValueError("2-d mesh grid is possible only for 2-d lattice. Lattice" + " dimension given: %s" % input_dims) + return test_utils.two_dim_mesh_grid( + num_points=num_points, + x_min=-1.0, + y_min=-1.0, + x_max=lattice_sizes, + y_max=lattice_sizes) + + def _Sin(self, x): + return math.sin(x[0]) + + def _SinPlusX(self, x): + return math.sin(x[0]) + x[0] / 3.0 + + def _SinPlusLargeX(self, x): + return math.sin(x[0]) + x[0] + + def _SinPlusXNd(self, x): + res = 0.0 + for y in x: + res = res + math.sin(y) + y / 5.0 + return res + + def _SinOfSum(self, x): + return math.sin(sum(x)) + + def _Square(self, x): + return x[0]**2 + + def _Max(self, x): + return np.amax(x) + + def _ScaledSum(self, x): + result = 0.0 + for y in x: + result += y / len(x) + return result + + def _GetNonMonotonicInitializer(self, weights): + """Tiles given weights along 'units' dimension.""" + dims = len(weights) + + def Initializer(shape, dtype): + _, lattice_sizes, num_inputs, num_terms = shape + units = num_inputs // dims + # Create expanded weights, tile, reshape, return. + return tf.reshape( + tf.tile( + tf.constant( + weights, + shape=[1, lattice_sizes, 1, dims, num_terms], + dtype=dtype), + multiples=[1, 1, units, 1, 1]), shape) + + return Initializer + + def _GetTrainingInputsAndLabels(self, config): + """Generates training inputs and labels. + + Args: + config: Dictionary with config for this unit test. + + Returns: + Tuple `(training_inputs, training_labels, raw_training_inputs)` where + `training_inputs` and `training_labels` are data for training and + `raw_training_inputs` are representation of training_inputs for + visualisation. + """ + raw_training_inputs = config["x_generator"]( + num_points=config["num_training_records"], + lattice_sizes=config["lattice_sizes"], + input_dims=config["input_dims"]) + + if isinstance(raw_training_inputs, tuple): + # This means that raw inputs are 2-d mesh grid. Convert them into list of + # 2-d points. + training_inputs = list(np.dstack(raw_training_inputs).reshape((-1, 2))) + else: + training_inputs = raw_training_inputs + + training_labels = [config["y_function"](x) for x in training_inputs] + return training_inputs, training_labels, raw_training_inputs + + def _SetDefaults(self, config): + config.setdefault("units", 1) + config.setdefault("num_terms", 2) + config.setdefault("monotonicities", None) + config.setdefault("signal_name", "TEST") + config.setdefault("satisfy_constraints_at_every_step", True) + config.setdefault("target_monotonicity_diff", 0.0) + config.setdefault("lattice_index", 0) + + return config + + def _TestEnsemble(self, config): + """Verifies that 'units > 1' lattice produces same output as 'units==1'.""" + # Note that the initialization of the lattice must be the same across the + # units dimension (otherwise the loss will be different). + if self.disable_ensembles: + return + config = dict(config) + config["num_training_epoch"] = 3 + config["kernel_initializer"] = "constant" + losses = [] + for units, lattice_index in [(1, 0), (3, 0), (3, 2)]: + config["units"] = units + config["lattice_index"] = lattice_index + losses.append(self._TrainModel(config)) + self.assertAlmostEqual(min(losses), max(losses), delta=self.loss_eps) + + def _TrainModel(self, config, plot_path=None): + logging.info("Testing config:") + logging.info(config) + config = self._SetDefaults(config) + self._ResetAllBackends() + + training_inputs, training_labels, raw_training_inputs = ( + self._GetTrainingInputsAndLabels(config)) + + units = config["units"] + input_dims = config["input_dims"] + lattice_sizes = config["lattice_sizes"] + if units > 1: + # In order to test multi 'units' lattice replecate inputs 'units' times + # and later use just one out of 'units' outputs in order to ensure that + # multi 'units' lattice trains exactly similar to single 'units' one. + training_inputs = [ + np.tile(np.expand_dims(x, axis=0), reps=[units, 1]) + for x in training_inputs + ] + input_shape = (units, input_dims) + else: + input_shape = (input_dims,) + + keras_layer = kfll.KroneckerFactoredLattice( + lattice_sizes=lattice_sizes, + units=units, + num_terms=config["num_terms"], + monotonicities=config["monotonicities"], + satisfy_constraints_at_every_step=config[ + "satisfy_constraints_at_every_step"], + kernel_initializer=config["kernel_initializer"], + input_shape=input_shape, + dtype=tf.float32) + model = keras.models.Sequential() + model.add(keras_layer) + + # When we use multi-unit lattices, we only extract a single lattice for + # testing. + if units > 1: + lattice_index = config["lattice_index"] + model.add( + keras.layers.Lambda(lambda x: x[:, lattice_index:lattice_index + 1])) + + optimizer = config["optimizer"](learning_rate=config["learning_rate"]) + model.compile(loss=keras.losses.mean_squared_error, optimizer=optimizer) + + training_data = (training_inputs, training_labels, raw_training_inputs) + loss = test_utils.run_training_loop( + config=config, + training_data=training_data, + keras_model=model, + plot_path=plot_path) + + if tf.executing_eagerly(): + tf.print("final weights: ", keras_layer.kernel) + assetion_ops = keras_layer.assert_constraints( + eps=-config["target_monotonicity_diff"]) + if not tf.executing_eagerly() and assetion_ops: + tf.compat.v1.keras.backend.get_session().run(assetion_ops) + + return loss + + def testMonotonicityOneD(self): + if self.disable_all: + return + monotonicities = [1] + kernel_initializer = kfll.RandomMonotonicInitializer( + monotonicities=monotonicities, seed=self.seed) + config = { + "lattice_sizes": 20, + "input_dims": 1, + "num_training_records": 128, + "num_training_epoch": 50, + "optimizer": tf.keras.optimizers.Adagrad, + "learning_rate": 1.0, + "x_generator": self._ScatterXUniformly, + "y_function": self._SinPlusX, + "monotonicities": monotonicities, + "kernel_initializer": kernel_initializer, + } # pyformat: disable + loss = self._TrainModel(config) + self.assertAlmostEqual(loss, 0.123856, delta=self.loss_eps) + self._TestEnsemble(config) + + monotonicities = ["increasing"] + kernel_initializer = kfll.RandomMonotonicInitializer( + monotonicities=monotonicities, seed=self.seed) + config = { + "lattice_sizes": 20, + "input_dims": 1, + "num_training_records": 100, + "num_training_epoch": 50, + "optimizer": tf.keras.optimizers.Adagrad, + "learning_rate": 1.0, + "x_generator": self._ScatterXUniformly, + "y_function": lambda x: -self._SinPlusX(x), + "monotonicities": monotonicities, + "kernel_initializer": kernel_initializer, + } # pyformat: disable + loss = self._TrainModel(config) + self.assertAlmostEqual(loss, 2.841356, delta=self.loss_eps) + self._TestEnsemble(config) + + monotonicities = [1] + kernel_initializer = kfll.RandomMonotonicInitializer( + monotonicities=monotonicities, seed=self.seed) + config = { + "lattice_sizes": 5, + "input_dims": 1, + "num_training_records": 100, + "num_training_epoch": 200, + "optimizer": tf.keras.optimizers.Adagrad, + "learning_rate": 1.0, + "x_generator": self._ScatterXUniformly, + "y_function": self._SinPlusLargeX, + "monotonicities": monotonicities, + "kernel_initializer": kernel_initializer, + # Target function is strictly increasing. + "target_monotonicity_diff": 1e-6, + } # pyformat: disable + loss = self._TrainModel(config) + self.assertAlmostEqual(loss, 0.000780, delta=self.loss_eps) + + def testMonotonicityTwoD(self): + if self.disable_all: + return + monotonicities = [1, 1] + kernel_initializer = kfll.RandomMonotonicInitializer( + monotonicities=monotonicities, seed=self.seed) + config = { + "lattice_sizes": 21, + "input_dims": 2, + "num_training_records": 900, + "num_training_epoch": 100, + "optimizer": tf.keras.optimizers.Adagrad, + "learning_rate": 1.0, + "x_generator": self._TwoDMeshGrid, + "y_function": self._SinPlusXNd, + "monotonicities": monotonicities, + "kernel_initializer": kernel_initializer, + } # pyformat: disable + loss = self._TrainModel(config) + self.assertAlmostEqual(loss, 0.562003, delta=self.loss_eps) + self._TestEnsemble(config) + + monotonicities = ["none", "increasing"] + kernel_initializer = kfll.RandomMonotonicInitializer( + monotonicities=monotonicities, seed=self.seed) + config = { + "lattice_sizes": 21, + "input_dims": 2, + "num_training_records": 900, + "num_training_epoch": 100, + "optimizer": tf.keras.optimizers.Adagrad, + "learning_rate": 1.0, + "x_generator": self._TwoDMeshGrid, + "y_function": self._SinPlusXNd, + "monotonicities": monotonicities, + "kernel_initializer": kernel_initializer, + } # pyformat: disable + loss = self._TrainModel(config) + self.assertAlmostEqual(loss, 0.222727, delta=self.loss_eps) + self._TestEnsemble(config) + + monotonicities = [1, 0] + kernel_initializer = kfll.RandomMonotonicInitializer( + monotonicities=monotonicities, seed=self.seed) + config = { + "lattice_sizes": 21, + "input_dims": 2, + "num_training_records": 900, + "num_training_epoch": 100, + "optimizer": tf.keras.optimizers.Adagrad, + "learning_rate": 1.5, + "x_generator": self._TwoDMeshGrid, + "y_function": self._SinPlusXNd, + "monotonicities": monotonicities, + "kernel_initializer": kernel_initializer, + } # pyformat: disable + loss = self._TrainModel(config) + self.assertAlmostEqual(loss, 0.498311, delta=self.loss_eps) + self._TestEnsemble(config) + + monotonicities = [1, 1] + kernel_initializer = kfll.RandomMonotonicInitializer( + monotonicities=monotonicities, seed=self.seed) + config = { + "lattice_sizes": 2, + "input_dims": 2, + "num_training_records": 100, + "num_training_epoch": 20, + "optimizer": tf.keras.optimizers.Adagrad, + "learning_rate": 1.0, + "x_generator": self._TwoDMeshGrid, + "y_function": lambda x: -self._ScaledSum(x), + "monotonicities": monotonicities, + "kernel_initializer": kernel_initializer, + } # pyformat: disable + loss = self._TrainModel(config) + self.assertAlmostEqual(loss, 0.050929, delta=self.loss_eps) + self._TestEnsemble(config) + + def testMonotonicity5d(self): + if self.disable_all: + return + config = { + "lattice_sizes": 2, + "input_dims": 5, + "num_training_records": 100, + "num_training_epoch": 200, + "optimizer": tf.keras.optimizers.Adagrad, + "learning_rate": 1.0, + "x_generator": self._ScatterXUniformly, + "y_function": self._ScaledSum, + "monotonicities": [1, 1, 1, 1, 1], + "kernel_initializer": keras.initializers.Constant(value=0.5), + # Function is strictly increasing everywhere, so request monotonicity + # diff to be strictly positive. + "target_monotonicity_diff": 0.08, + } # pyformat: disable + loss = self._TrainModel(config) + self.assertAlmostEqual(loss, 0.000524, delta=self.loss_eps) + + monotonicities = [1, 1, 1, 1, 1] + kernel_initializer = kfll.RandomMonotonicInitializer( + monotonicities=monotonicities, seed=self.seed) + config = { + "lattice_sizes": 2, + "input_dims": 5, + "num_training_records": 100, + "num_training_epoch": 40, + "optimizer": tf.keras.optimizers.Adagrad, + "learning_rate": 1.0, + "x_generator": self._ScatterXUniformly, + "y_function": lambda x: -self._ScaledSum(x), + "monotonicities": monotonicities, + "kernel_initializer": kernel_initializer, + } # pyformat: disable + loss = self._TrainModel(config) + self.assertAlmostEqual(loss, 0.015019, delta=self.loss_eps) + self._TestEnsemble(config) + + monotonicities = [1, "increasing", 1, 1] + kernel_initializer = kfll.RandomMonotonicInitializer( + monotonicities=monotonicities, seed=self.seed) + config = { + "lattice_sizes": 3, + "input_dims": 4, + "num_training_records": 100, + "num_training_epoch": 100, + "optimizer": tf.keras.optimizers.Adagrad, + "learning_rate": 1.0, + "x_generator": self._ScatterXUniformly, + "y_function": self._SinOfSum, + "monotonicities": monotonicities, + "kernel_initializer": kernel_initializer, + } # pyformat: disable + loss = self._TrainModel(config) + self.assertAlmostEqual(loss, 0.377306, delta=self.loss_eps) + self._TestEnsemble(config) + + @parameterized.parameters( + ([0, 1, 1],), + ([1, 0, 1],), + ([1, 1, 0],), + ) + def testMonotonicityEquivalence(self, monotonicities): + if self.disable_all: + return + config = { + "lattice_sizes": 3, + "input_dims": 3, + "monotonicities": monotonicities, + "num_training_records": 100, + "num_training_epoch": 50, + "optimizer": tf.keras.optimizers.Adagrad, + "learning_rate": 10.0, + "x_generator": self._SameValueForAllDims, + "y_function": self._SinOfSum, + "kernel_initializer": "zeros", + } # pyformat: disable + loss = self._TrainModel(config) + self.assertAlmostEqual(loss, 0.550760, delta=self.loss_eps) + self._TestEnsemble(config) + + def testMonotonicity10dAlmostMonotone(self): + if self.disable_all: + return + np.random.seed(4411) + num_weights = 1024 + weights = [1.0 * i / num_weights for i in range(num_weights)] + for _ in range(10): + i = int(np.random.random() * num_weights) + weights[i] = 0.0 + + config = { + "lattice_sizes": 2, + "input_dims": 10, + "num_terms": 128, + "num_training_records": 1000, + "num_training_epoch": 100, + "optimizer": tf.keras.optimizers.Adagrad, + "learning_rate": 1.0, + "x_generator": self._ScatterXUniformly, + "y_function": test_utils.get_hypercube_interpolation_fn(weights), + "monotonicities": [1] * 10, + "kernel_initializer": "zeros", + } # pyformat: disable + loss = self._TrainModel(config) + self.assertAlmostEqual(loss, 0.025735, delta=self.loss_eps) + + config["monotonicities"] = [0, 1, 0, 1, 1, 0, 1, 1, 1, 0] + loss = self._TrainModel(config) + self.assertAlmostEqual(loss, 0.025735, delta=self.loss_eps) + self._TestEnsemble(config) + + def testMonotonicity10dSinOfSum(self): + if self.disable_all: + return + monotonicities = [1] * 10 + kernel_initializer = kfll.RandomMonotonicInitializer( + monotonicities=monotonicities, seed=self.seed) + config = { + "lattice_sizes": 2, + "input_dims": 10, + "num_training_records": 1000, + "num_training_epoch": 100, + "optimizer": tf.keras.optimizers.Adagrad, + "learning_rate": 1.0, + "x_generator": self._ScatterXUniformly, + "y_function": self._SinOfSum, + "monotonicities": [1] * 10, + "kernel_initializer": kernel_initializer, + } # pyformat: disable + loss = self._TrainModel(config) + self.assertAlmostEqual(loss, 0.179642, delta=self.loss_eps) + + monotonicities = [0, 1, 0, 1, 1, 0, 1, 1, 1, 0] + kernel_initializer = kfll.RandomMonotonicInitializer( + monotonicities=monotonicities, seed=self.seed) + config["monotonicities"] = monotonicities + config["kernel_initializer"] = kernel_initializer + loss = self._TrainModel(config) + self.assertAlmostEqual(loss, 0.181125, delta=self.loss_eps) + + @parameterized.parameters( + # Custom TFL initializer: + ("random_monotonic_initializer", 2.405433), + # Standard Keras initializer: + (keras.initializers.Constant(value=1.5), 2.140740), + # Standard Keras initializer specified as string constant: + ("zeros", 2.140740), + ) + def testInitializerType(self, initializer, expected_loss): + if self.disable_all: + return + if initializer == "random_monotonic_initializer": + initializer = kfll.RandomMonotonicInitializer( + monotonicities=None, seed=self.seed) + config = { + "lattice_sizes": 3, + "input_dims": 2, + "num_training_records": 100, + "num_training_epoch": 0, + "optimizer": tf.keras.optimizers.Adagrad, + "learning_rate": 1.0, + "x_generator": self._TwoDMeshGrid, + "y_function": self._Max, + "kernel_initializer": initializer, + } # pyformat: disable + loss = self._TrainModel(config) + self.assertAlmostEqual(loss, expected_loss, delta=self.loss_eps) + self._TestEnsemble(config) + + def testAssertMonotonicity(self): + if self.disable_all: + return + # Specify non monotonic initializer and do 0 training iterations so no + # projections are being executed. + config = { + "lattice_sizes": 2, + "input_dims": 2, + "num_training_records": 100, + "num_training_epoch": 0, + "optimizer": tf.keras.optimizers.Adagrad, + "learning_rate": 0.15, + "x_generator": self._TwoDMeshGrid, + "y_function": self._ScaledSum, + "monotonicities": [0, 0], + "kernel_initializer": self._GetNonMonotonicInitializer( + weights=[ + [[4.0, 3.0], [4.0, 3.0]], + [[2.0, 1.0], [2.0, 1.0]] + ]) + } # pyformat: disable + loss = self._TrainModel(config) + self.assertAlmostEqual(loss, 4.458333, delta=self.loss_eps) + + for monotonicity in [[0, 1], [1, 0], [1, 1]]: + for units in [1, 3]: + config["monotonicities"] = monotonicity + config["units"] = units + with self.assertRaises(tf.errors.InvalidArgumentError): + self._TrainModel(config) + + def testInputOutOfBounds(self): + if self.disable_all: + return + config = { + "lattice_sizes": 6, + "input_dims": 1, + "num_training_records": 100, + "num_training_epoch": 20, + "optimizer": tf.keras.optimizers.Adagrad, + "learning_rate": 1.0, + "x_generator": self._ScatterXUniformlyExtendedRange, + "y_function": self._Sin, + "kernel_initializer": keras.initializers.Zeros(), + } # pyformat: disable + loss = self._TrainModel(config) + self.assertAlmostEqual(loss, 0.018726, delta=self.loss_eps) + self._TestEnsemble(config) + + kernel_initializer = kfll.RandomMonotonicInitializer( + monotonicities=None, seed=self.seed) + config = { + "lattice_sizes": 2, + "input_dims": 2, + "num_training_records": 100, + "num_training_epoch": 20, + "optimizer": tf.keras.optimizers.Adagrad, + "learning_rate": 1.0, + "x_generator": self._TwoDMeshGridExtendedRange, + "y_function": self._SinOfSum, + "kernel_initializer": kernel_initializer, + } # pyformat: disable + loss = self._TrainModel(config) + self.assertAlmostEqual(loss, 0.130816, delta=self.loss_eps) + self._TestEnsemble(config) + + def testHighDimensionsStressTest(self): + if self.disable_all: + return + monotonicities = [0] * 16 + monotonicities[3], monotonicities[4], monotonicities[10] = (1, 1, 1) + kernel_initializer = kfll.RandomMonotonicInitializer( + monotonicities=monotonicities, seed=self.seed) + config = { + "lattice_sizes": 2, + "input_dims": 16, + "num_terms": 128, + "units": 2, + "monotonicities": monotonicities, + "num_training_records": 100, + "num_training_epoch": 3, + "optimizer": tf.keras.optimizers.Adagrad, + "learning_rate": 1.0, + "x_generator": self._ScatterXUniformly, + "y_function": self._SinOfSum, + "kernel_initializer": kernel_initializer, + "target_monotonicity_diff": -1e-5, + } # pyformat: disable + loss = self._TrainModel(config) + self.assertAlmostEqual(loss, 0.224262, delta=self.loss_eps) + + @parameterized.parameters( + (2, 5, 2, 46), + (2, 6, 4, 46), + (2, 9, 2, 46), + (3, 5, 4, 56), + (3, 9, 2, 56), + ) + def testGraphSize(self, lattice_sizes, input_dims, num_terms, + expected_graph_size): + # If this test failed then you modified core lattice interpolation logic in + # a way which increases number of ops in the graph. Or maybe Keras team + # changed something under the hood. Please ensure that this increase is + # unavoidable and try to minimize it. + if self.disable_all: + return + tf.compat.v1.disable_eager_execution() + tf.compat.v1.reset_default_graph() + + layer = kfll.KroneckerFactoredLattice( + lattice_sizes=lattice_sizes, num_terms=num_terms) + input_tensor = tf.ones(shape=(1, input_dims)) + layer(input_tensor) + graph_size = len(tf.compat.v1.get_default_graph().as_graph_def().node) + + self.assertLessEqual(graph_size, expected_graph_size) + + @parameterized.parameters( + ("random_uniform", tf.keras.initializers.RandomUniform), + ("random_monotonic_initializer", kfll.RandomMonotonicInitializer)) + def testCreateKernelInitializer(self, kernel_initializer_id, expected_type): + self.assertEqual( + expected_type, + type( + kfll.create_kernel_initializer( + kernel_initializer_id, monotonicities=None))) + + # We test that the scale variable attribute of our KroneckerFactoredLattice + # is the same object as the scale contained in the constraint on the kernel, + # both before and after save/load. We test this because we must make sure that + # any updates to the scale variable (before/after save/load) are consistent + # across all uses of the object. + def testSavingLoadingScale(self): + # Create simple x --> x^2 dataset. + train_data = [[float(x), float(x)**2] for x in range(100)] + train_x, train_y = zip(*train_data) + # Construct simple single lattice model. Must have monotonicities specified + # or constraint will be None. + keras_layer = kfll.KroneckerFactoredLattice( + lattice_sizes=2, monotonicities=[1]) + model = keras.models.Sequential() + model.add(keras_layer) + # Compile and fit the model. + model.compile( + loss="mse", optimizer=keras.optimizers.Adam(learning_rate=0.1)) + model.fit(train_x, train_y) + # Extract scale from layer and constraint before save. + layer_scale = keras_layer.scale + constraint_scale = keras_layer.kernel.constraint.scale + self.assertIs(layer_scale, constraint_scale) + # Save and load the model. + with tempfile.NamedTemporaryFile(suffix=".h5") as f: + keras.models.save_model(model, f.name) + loaded_model = keras.models.load_model( + f.name, + custom_objects={ + "KroneckerFactoredLattice": + kfll.KroneckerFactoredLattice, + "KroneckerFactoredLatticeConstraint": + kfll.KroneckerFactoredLatticeConstraints + }) + # Extract loaded layer. + loaded_keras_layer = loaded_model.layers[0] + # Extract scale from layer and constraint after load. + loaded_layer_scale = loaded_keras_layer.scale + loaded_constraint_scale = loaded_keras_layer.kernel.constraint.scale + self.assertIs(loaded_layer_scale, loaded_constraint_scale) + # Train for another epoch and test equality of all updated elements just to + # be safe. + loaded_model.fit(train_x, train_y) + self.assertAllEqual(loaded_layer_scale, loaded_constraint_scale) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow_lattice/python/lattice_layer.py b/tensorflow_lattice/python/lattice_layer.py index fa7d3a4..0bac43e 100644 --- a/tensorflow_lattice/python/lattice_layer.py +++ b/tensorflow_lattice/python/lattice_layer.py @@ -23,6 +23,7 @@ from __future__ import print_function from . import lattice_lib +from . import utils import six import tensorflow as tf from tensorflow import keras @@ -35,7 +36,7 @@ class Lattice(keras.layers.Layer): # pyformat: disable """Lattice layer. - Layer performs interpolation using one of `units` d-dimension lattices with + Layer performs interpolation using one of `units` d-dimensional lattices with arbitrary number of keypoints per dimension. There are trainable weights associated with lattice vertices. Input to this layer is considered to be a d-dimensional point within the lattice. If point coincides with one of the @@ -322,12 +323,8 @@ def __init__(self, self.interpolation = interpolation self.kernel_initializer = create_kernel_initializer( - kernel_initializer, - self.lattice_sizes, - self.monotonicities, - self.output_min, - self.output_max, - self.unimodalities, + kernel_initializer, self.lattice_sizes, self.monotonicities, + self.output_min, self.output_max, self.unimodalities, self.joint_unimodalities) self.kernel_regularizer = [] @@ -435,7 +432,7 @@ def build(self, input_shape): def call(self, inputs): """Standard Keras call() method.""" # Use control dependencies to save lattice sizes as graph constant for - # visualisation toolbox to be able to recove it from saved graph. + # visualisation toolbox to be able to recover it from saved graph. # Wrap this constant into pure op since in TF 2.0 there are issues passing # tensors into control_dependencies. with tf.control_dependencies([tf.identity(self.lattice_sizes_tensor)]): @@ -522,10 +519,10 @@ def assert_constraints(self, eps=1e-6): return lattice_lib.assert_constraints( weights=self.kernel, lattice_sizes=self.lattice_sizes, - monotonicities=lattice_lib.canonicalize_monotonicities( - self.monotonicities), - edgeworth_trusts=lattice_lib.canonicalize_trust(self.edgeworth_trusts), - trapezoid_trusts=lattice_lib.canonicalize_trust(self.trapezoid_trusts), + monotonicities=utils.canonicalize_monotonicities( + self.monotonicities, allow_decreasing=False), + edgeworth_trusts=utils.canonicalize_trust(self.edgeworth_trusts), + trapezoid_trusts=utils.canonicalize_trust(self.trapezoid_trusts), monotonic_dominances=self.monotonic_dominances, range_dominances=self.range_dominances, joint_monotonicities=self.joint_monotonicities, @@ -535,14 +532,9 @@ def assert_constraints(self, eps=1e-6): eps=eps) -def create_kernel_initializer( - kernel_initializer_id, - lattice_sizes, - monotonicities, - output_min, - output_max, - unimodalities, - joint_unimodalities): +def create_kernel_initializer(kernel_initializer_id, lattice_sizes, + monotonicities, output_min, output_max, + unimodalities, joint_unimodalities): """Returns a kernel Keras initializer object from its id. This function is used to convert the 'kernel_initializer' parameter in the @@ -550,22 +542,24 @@ def create_kernel_initializer( Args: kernel_initializer_id: See the documentation of the 'kernel_initializer' - parameter in the constructor of tfl.Lattice. + parameter in the constructor of tfl.Lattice. lattice_sizes: See the documentation of the same parameter in the - constructor of tfl.Lattice. + constructor of tfl.Lattice. monotonicities: See the documentation of the same parameter in the - constructor of tfl.Lattice. - output_min: See the documentation of the same parameter in the - constructor of tfl.Lattice. - output_max: See the documentation of the same parameter in the - constructor of tfl.Lattice. + constructor of tfl.Lattice. + output_min: See the documentation of the same parameter in the constructor + of tfl.Lattice. + output_max: See the documentation of the same parameter in the constructor + of tfl.Lattice. unimodalities: See the documentation of the same parameter in the - constructor of tfl.Lattice. + constructor of tfl.Lattice. joint_unimodalities: See the documentation of the same parameter in the - constructor of tfl.Lattice. + constructor of tfl.Lattice. + Returns: The Keras initializer object for the tfl.Lattice kernel variable. """ + def default_params(output_min, output_max): """Return reasonable default parameters if not defined explicitly.""" if output_min is not None: @@ -611,8 +605,9 @@ def do_joint_unimodalities_contain_all_features(joint_unimodalities): output_min=output_init_min, output_max=output_init_max, unimodalities=all_unimodalities) - elif kernel_initializer_id in ["random_monotonic_initializer", - "RandomMonotonicInitializer"]: + elif kernel_initializer_id in [ + "random_monotonic_initializer", "RandomMonotonicInitializer" + ]: output_init_min, output_init_max = default_params(output_min, output_max) return RandomMonotonicInitializer( @@ -620,23 +615,16 @@ def do_joint_unimodalities_contain_all_features(joint_unimodalities): output_min=output_init_min, output_max=output_init_max, unimodalities=all_unimodalities) - elif kernel_initializer_id in ["random_uniform_or_linear_initializer", - "RandomUniformOrLinearInitializer"]: + elif kernel_initializer_id in [ + "random_uniform_or_linear_initializer", "RandomUniformOrLinearInitializer" + ]: if do_joint_unimodalities_contain_all_features(joint_unimodalities): - return create_kernel_initializer("random_uniform", - lattice_sizes, - monotonicities, - output_min, - output_max, - unimodalities, - joint_unimodalities) - return create_kernel_initializer("linear_initializer", - lattice_sizes, - monotonicities, - output_min, - output_max, - unimodalities, - joint_unimodalities) + return create_kernel_initializer("random_uniform", lattice_sizes, + monotonicities, output_min, output_max, + unimodalities, joint_unimodalities) + return create_kernel_initializer("linear_initializer", lattice_sizes, + monotonicities, output_min, output_max, + unimodalities, joint_unimodalities) else: # This is needed for Keras deserialization logic to be aware of our custom # objects. @@ -714,10 +702,9 @@ def __call__(self, shape, dtype=None, partition_info=None): del partition_info return lattice_lib.linear_initializer( lattice_sizes=self.lattice_sizes, - monotonicities=lattice_lib.canonicalize_monotonicities( - self.monotonicities), - unimodalities=lattice_lib.canonicalize_unimodalities( - self.unimodalities), + monotonicities=utils.canonicalize_monotonicities( + self.monotonicities, allow_decreasing=False), + unimodalities=utils.canonicalize_unimodalities(self.unimodalities), output_min=self.output_min, output_max=self.output_max, units=shape[1], @@ -749,11 +736,7 @@ class RandomMonotonicInitializer(keras.initializers.Initializer): """ # pyformat: enable - def __init__(self, - lattice_sizes, - output_min, - output_max, - unimodalities=None): + def __init__(self, lattice_sizes, output_min, output_max, unimodalities=None): """Initializes an instance of `RandomMonotonicInitializer`. Args: @@ -843,8 +826,7 @@ def __init__(self, range_dominances: Same meaning as corresponding parameter of `Lattice`. joint_monotonicities: Same meaning as corresponding parameter of `Lattice`. - joint_unimodalities: Same meaning as corresponding parameter of - `Lattice`. + joint_unimodalities: Same meaning as corresponding parameter of `Lattice`. output_min: Minimum possible output. output_max: Maximum possible output. num_projection_iterations: Same meaning as corresponding parameter of @@ -867,10 +849,11 @@ def __init__(self, joint_unimodalities=joint_unimodalities) self.lattice_sizes = lattice_sizes - self.monotonicities = monotonicities - self.unimodalities = unimodalities - self.edgeworth_trusts = edgeworth_trusts - self.trapezoid_trusts = trapezoid_trusts + self.monotonicities = utils.canonicalize_monotonicities( + monotonicities, allow_decreasing=False) + self.unimodalities = utils.canonicalize_unimodalities(unimodalities) + self.edgeworth_trusts = utils.canonicalize_trust(edgeworth_trusts) + self.trapezoid_trusts = utils.canonicalize_trust(trapezoid_trusts) self.monotonic_dominances = monotonic_dominances self.range_dominances = range_dominances self.joint_monotonicities = joint_monotonicities @@ -879,31 +862,23 @@ def __init__(self, self.output_max = output_max self.num_projection_iterations = num_projection_iterations self.enforce_strict_monotonicity = enforce_strict_monotonicity + self.num_constraint_dims = utils.count_non_zeros(self.monotonicities, + self.unimodalities) def __call__(self, w): """Applies constraints to `w`.""" - canonical_monotonicities = lattice_lib.canonicalize_monotonicities( - self.monotonicities) - canonical_unimodalities = lattice_lib.canonicalize_unimodalities( - self.unimodalities) - canonical_edgeworth_trusts = lattice_lib.canonicalize_trust( - self.edgeworth_trusts) - canonical_trapezoid_trusts = lattice_lib.canonicalize_trust( - self.trapezoid_trusts) - num_constraint_dims = lattice_lib.count_non_zeros(canonical_monotonicities, - canonical_unimodalities) # No need to separately check for trust constraints and monotonic dominance, # since monotonicity is required to impose them. The only exception is joint # monotonicity. - if (num_constraint_dims > 0 or self.joint_monotonicities or + if (self.num_constraint_dims > 0 or self.joint_monotonicities or self.joint_unimodalities): w = lattice_lib.project_by_dykstra( w, lattice_sizes=self.lattice_sizes, - monotonicities=canonical_monotonicities, - unimodalities=canonical_unimodalities, - edgeworth_trusts=canonical_edgeworth_trusts, - trapezoid_trusts=canonical_trapezoid_trusts, + monotonicities=self.monotonicities, + unimodalities=self.unimodalities, + edgeworth_trusts=self.edgeworth_trusts, + trapezoid_trusts=self.trapezoid_trusts, monotonic_dominances=self.monotonic_dominances, range_dominances=self.range_dominances, joint_monotonicities=self.joint_monotonicities, @@ -913,9 +888,9 @@ def __call__(self, w): w = lattice_lib.finalize_constraints( w, lattice_sizes=self.lattice_sizes, - monotonicities=canonical_monotonicities, - edgeworth_trusts=canonical_edgeworth_trusts, - trapezoid_trusts=canonical_trapezoid_trusts, + monotonicities=self.monotonicities, + edgeworth_trusts=self.edgeworth_trusts, + trapezoid_trusts=self.trapezoid_trusts, output_min=self.output_min, output_max=self.output_max) # TODO: come up with a better solution than separately applying diff --git a/tensorflow_lattice/python/lattice_lib.py b/tensorflow_lattice/python/lattice_lib.py index 80d9771..82c6907 100644 --- a/tensorflow_lattice/python/lattice_lib.py +++ b/tensorflow_lattice/python/lattice_lib.py @@ -21,6 +21,8 @@ import copy import itertools import math + +from . import utils from absl import logging import numpy as np import six @@ -62,6 +64,9 @@ def evaluate_with_simplex_interpolation(inputs, kernel, units, lattice_sizes, layer for which interpolation is being computed. clip_inputs: Whether inputs should be clipped to the input range of the lattice. + + Returns: + Tensor of shape: `(batch_size, ..., units)`. """ if isinstance(inputs, list): inputs = tf.concat(inputs, axis=-1) @@ -74,11 +79,6 @@ def evaluate_with_simplex_interpolation(inputs, kernel, units, lattice_sizes, input_dim = len(inputs.shape) all_size_2 = all(size == 2 for size in lattice_sizes) - if input_dim == 2: - units = 1 - else: - units = int(inputs.shape[-2]) - # Strides are the changes in the global index (index into the flattened # parameters) when moving across each dimension. # E.g. for 2x2x2, strides are [4, 2, 1]. @@ -158,16 +158,23 @@ def evaluate_with_hypercube_interpolation(inputs, kernel, units, lattice_sizes, hypercube into simplices", D.G. Mead, Proceedings of the AMS, 76:2, Sep. 1979. Args: - inputs: Tensor of shape: `(batch_size, ..., len(lattice_sizes))` or list of - `len(lattice_sizes)` tensors of same shape `(batch_size, ..., 1)` which - represents points to apply lattice interpolation to. A typical shape is - `(batch_size, len(lattice_sizes))`. + inputs: Tensor representing points to apply lattice interpolation to. If + units = 1, tensor should be of shape: `(batch_size, ..., + len(lattice_sizes))` or list of `len(lattice_sizes)` tensors of same + shape `(batch_size, ..., 1)`. + If units > 1, tensor should be of shape: `(batch_size, ..., units, + len(lattice_sizes))` or list of `len(lattice_sizes)` tensors of same + shape `(batch_size, ..., units, 1)`. A typical shape is `(batch_size, + len(lattice_sizes))`. kernel: Lattice kernel of shape (num_params_per_lattice, units). units: Output dimension of the lattice. lattice_sizes: List or tuple of integers which represents lattice sizes of layer for which interpolation is being computed. clip_inputs: Whether inputs should be clipped to the input range of the lattice. + + Returns: + Tensor of shape: `(batch_size, ..., units)`. """ interpolation_weights = compute_interpolation_weights( inputs=inputs, lattice_sizes=lattice_sizes, clip_inputs=clip_inputs) @@ -271,10 +278,9 @@ def batch_outer_operation(list_of_tensors, operation="auto"): Args: list_of_tensors: List of tensors of same shape `(batch_size, ..., k[i])` where everything expect `k_i` matches. - operation: - - binary TF operation which supports broadcasting to be applied. + operation: - binary TF operation which supports broadcasting to be applied. - string "auto" in order to apply tf.multiply for first several tensors - and tf.matmul for remaining. + and tf.matmul for remaining. Returns: Tensor of shape: `(batch_size, ..., mul_i(k[i]))`. @@ -455,7 +461,7 @@ def linear_initializer(lattice_sizes, if unimodalities is None: unimodalities = [0] * len(lattice_sizes) - num_constraint_dims = count_non_zeros(monotonicities, unimodalities) + num_constraint_dims = utils.count_non_zeros(monotonicities, unimodalities) if num_constraint_dims == 0: monotonicities = [1] * len(lattice_sizes) num_constraint_dims = len(lattice_sizes) @@ -1026,7 +1032,7 @@ def finalize_constraints(weights, Returns: Projected weights tensor of same shape as `weights`. """ - if count_non_zeros(monotonicities) == 0: + if utils.count_non_zeros(monotonicities) == 0: return weights units = weights.shape[1] if units > 1: @@ -1709,8 +1715,8 @@ def _project_partial_joint_unimodality(weights, lattice_sizes, vertices=all_vertices) -def _project_onto_hyperplane(weights, joint_unimodalities, - hyperplane, vertices): +def _project_onto_hyperplane(weights, joint_unimodalities, hyperplane, + vertices): """Projects onto hyperplane. Args: @@ -1749,8 +1755,9 @@ def _project_onto_hyperplane(weights, joint_unimodalities, # 4) Come up with better option. dimensions, direction = joint_unimodalities layers = _unstack_nd(weights, dims=dimensions) - affected_weights = [_get_element(lists=layers, indices=position) - for position in vertices] + affected_weights = [ + _get_element(lists=layers, indices=position) for position in vertices + ] affected_weights = tf.stack(affected_weights, axis=-1) violation = tf.reduce_sum(affected_weights * hyperplane, axis=-1) @@ -1838,7 +1845,7 @@ def project_by_dykstra(weights, """ if num_iterations == 0: return weights - if (count_non_zeros(monotonicities, unimodalities) == 0 and + if (utils.count_non_zeros(monotonicities, unimodalities) == 0 and not joint_monotonicities and not joint_unimodalities and not range_dominances): return weights @@ -1955,8 +1962,7 @@ def body(iteration, weights, last_change): rolled_back_weights = weights - last_change[ ("RANGE_DOMINANCE", constraint, constraint_group)] weights = _project_partial_range_dominance(rolled_back_weights, - lattice_sizes, - constraint, + lattice_sizes, constraint, constraint_group) last_change[("RANGE_DOMINANCE", constraint, constraint_group)] = weights - rolled_back_weights @@ -2273,8 +2279,7 @@ def verify_hyperparameters(lattice_sizes, range_dominances: Range dominances hyperparameter of `Lattice` layer. joint_monotonicities: Joint monotonicities hyperparameter of `Lattice` layer. - joint_unimodalities: Joint unimodalities hyperparameter of `Lattice` - layer. + joint_unimodalities: Joint unimodalities hyperparameter of `Lattice` layer. output_min: Minimum output of `Lattice` layer. output_max: Maximum output of `Lattice` layer. regularization_amount: Regularization amount for regularizers. @@ -2290,14 +2295,15 @@ def verify_hyperparameters(lattice_sizes, lattice_sizes) # It also raises errors if monotonicities specified incorrectly. - monotonicities = canonicalize_monotonicities(monotonicities) + monotonicities = utils.canonicalize_monotonicities( + monotonicities, allow_decreasing=False) if monotonicities is not None: if len(monotonicities) != len(lattice_sizes): raise ValueError("If provided 'monotonicities' should have same number " "of elements as 'lattice_sizes'. 'monotonicities': %s," "'lattice_sizes: %s" % (monotonicities, lattice_sizes)) - unimodalities = canonicalize_unimodalities(unimodalities) + unimodalities = utils.canonicalize_unimodalities(unimodalities) if unimodalities is not None: if len(unimodalities) != len(lattice_sizes): raise ValueError("If provided 'unimodalities' should have same number " @@ -2318,8 +2324,8 @@ def verify_hyperparameters(lattice_sizes, "'monotonicities': %s, 'unimodalities': %s" % (i, monotonicities, unimodalities)) - all_trusts = canonicalize_trust((edgeworth_trusts or []) + - (trapezoid_trusts or [])) or [] + all_trusts = utils.canonicalize_trust((edgeworth_trusts or []) + + (trapezoid_trusts or [])) or [] main_dims, cond_dims, trapezoid_cond_dims = set(), set(), set() dim_pairs_direction = {} for i, constraint in enumerate(all_trusts): @@ -2669,108 +2675,6 @@ def assert_constraints(weights, return asserts -def count_non_zeros(*iterables): - """Returns total number of non 0 elements in given iterables.""" - result = 0 - for iterable in iterables: - if iterable is not None: - result += [element != 0 for element in iterable].count(True) - return result - - -def canonicalize_monotonicities(monotonicities): - """Converts string constants representing monotonicities into integers. - - Args: - monotonicities: monotonicities hyperparameter of `Lattice` layer. - - Raises: - ValueError if one of monotonicities is invalid. - - Returns: - monotonicities represented as 0 or 1. - """ - if monotonicities: - canonicalized = [] - for item in monotonicities: - if item in [0, 1]: - canonicalized.append(item) - elif isinstance(item, six.string_types) and item.lower() == "increasing": - canonicalized.append(1) - elif isinstance(item, six.string_types) and item.lower() == "none": - canonicalized.append(0) - else: - raise ValueError("'monotonicities' elements must be from: [0, 1, " - "'increasing', 'none']. Given: %s" % monotonicities) - return canonicalized - return None - - -def canonicalize_unimodalities(unimodalities): - """Converts string constants representing unimodalities into integers. - - Args: - unimodalities: unimodalities hyperparameter of `Lattice` layer. - - Raises: - ValueError if one of unimodalities is invalid. - - Returns: - unimodalities represented as -1, 0 or 1. - """ - if not unimodalities: - return None - canonicalized = [] - for item in unimodalities: - if item in [-1, 0, 1]: - canonicalized.append(item) - elif isinstance(item, six.string_types) and item.lower() == "valley": - canonicalized.append(1) - elif isinstance(item, six.string_types) and item.lower() == "peak": - canonicalized.append(-1) - elif isinstance(item, six.string_types) and item.lower() == "none": - canonicalized.append(0) - else: - raise ValueError("'unimodalities' elements must be from: [-1, 0, 1, " - "'peak', 'none', 'valley']. Given: %s" % unimodalities) - return canonicalized - - -def canonicalize_trust(trusts): - """Converts string constants representing trust direction into integers. - - Args: - trusts: edgeworth_trusts or trapezoid_trusts hyperparameter of `Lattice` - layer. - - Raises: - ValueError if one of trust constraints is invalid. - - Returns: - Trust constraints with direction represented as 0 or 1. - """ - if trusts: - canonicalized = [] - for item in trusts: - if len(item) != 3: - raise ValueError("Trust constraints must consist of 3 elements. Seeing " - "constraint tuple %s" % item) - direction = item[2] - if direction in [-1, 1]: - canonicalized.append(item) - elif (isinstance(direction, six.string_types) and - direction.lower() == "positive"): - canonicalized.append((item[0], item[1], 1)) - elif (isinstance(direction, six.string_types) and - direction.lower() == "negative"): - canonicalized.append((item[0], item[1], -1)) - else: - raise ValueError("trust constraint direction must be from: [-1, 1, " - "'negative', 'positive']. Given: %s" % direction) - return canonicalized - return None - - def _unstack_nested_lists(tensor_or_list, axis): """Unstacks tensors stored within nested list.""" if isinstance(tensor_or_list, list): diff --git a/tensorflow_lattice/python/lattice_test.py b/tensorflow_lattice/python/lattice_test.py index 0cd34ba..4d358a9 100644 --- a/tensorflow_lattice/python/lattice_test.py +++ b/tensorflow_lattice/python/lattice_test.py @@ -220,6 +220,8 @@ def _SetDefaults(self, config): def _TestEnsemble(self, config): """Verifies that 'units > 1' lattice produces same output as 'units==1'.""" + # Note that the initialization of the lattice must be the same across the + # units dimension (otherwise the loss will be different). if self.disable_ensembles: return config = dict(config) @@ -277,6 +279,8 @@ def _TrainModel(self, config, plot_path=None): model = keras.models.Sequential() model.add(keras_layer) + # When we use multi-unit lattices, we only extract a single lattice for + # testing. if units > 1: lattice_index = config["lattice_index"] model.add( @@ -904,6 +908,7 @@ def testSimpleJointMonotonicity2D(self, joint_monotonicities, expected_loss): def testJointUnimodality1D(self, joint_unimodalities, expected_loss): if self.disable_all: return + def _Sin(x): result = math.sin(x[0]) # Make test exactly symmetric for both unimodality directions. @@ -964,10 +969,11 @@ def testJointUnimodality2DWshaped(self, joint_unimodalities, expected_loss): return center = (3, 3) + def WShaped2dFunction(x): distance = lambda x1, y1, x2, y2: ((x2 - x1)**2 + (y2 - y1)**2)**0.5 d = distance(x[0], x[1], center[0], center[1]) - t = (d - 0.6*center[0])**2 + t = (d - 0.6 * center[0])**2 return min(t, 6.0 - t) config = { @@ -1011,9 +1017,9 @@ def testJointUnimodality2OutOf4D(self, joint_unimodalities): def WShaped2dFunction(x): distance = lambda x1, y1, x2, y2: ((x2 - x1)**2 + (y2 - y1)**2)**0.5 - d = distance(x[center_indices[0]], x[center_indices[1]], - center[0], center[1]) - t = (d - 0.6*center[0])**2 + d = distance(x[center_indices[0]], x[center_indices[1]], center[0], + center[1]) + t = (d - 0.6 * center[0])**2 return min(t, 4.5 - t) def _DistributeXUniformly(num_points, lattice_sizes): @@ -1024,8 +1030,10 @@ def _DistributeXUniformly(num_points, lattice_sizes): for j in range(0, lattice_sizes[1] * points_per_vertex + 1): for k in range(0, lattice_sizes[2] * points_per_vertex + 1): for l in range(0, lattice_sizes[3] * points_per_vertex + 1): - p = [i / float(points_per_vertex), j / float(points_per_vertex), - k / float(points_per_vertex), l / float(points_per_vertex)] + p = [ + i / float(points_per_vertex), j / float(points_per_vertex), + k / float(points_per_vertex), l / float(points_per_vertex) + ] result.append(p) return result @@ -1914,40 +1922,37 @@ def testGraphSize(self, lattice_sizes, expected_graph_size): self.assertLessEqual(graph_size, expected_graph_size) @parameterized.parameters( - ("random_uniform_or_linear_initializer", [3, 3, 3], - [([0, 1, 2], "peak")], - tf.keras.initializers.RandomUniform), - ("random_uniform_or_linear_initializer", [3, 3, 3], - [([0, 1, 2], "valley")], - tf.keras.initializers.RandomUniform), - ("random_uniform_or_linear_initializer", [3, 3, 3], - [([0, 1], "valley")], - ll.LinearInitializer), - ("random_uniform_or_linear_initializer", [3, 3, 3], - [([0, 1], "valley"), ([2], "peak")], - ll.LinearInitializer), - ("random_uniform_or_linear_initializer", [3, 3, 3], - None, - ll.LinearInitializer), - ("linear_initializer", [3, 3, 3], - [([0, 1], "valley")], - ll.LinearInitializer), - ("random_monotonic_initializer", [3, 3, 3], - [([0, 1], "valley")], - ll.RandomMonotonicInitializer)) - def testCreateKernelInitializer( - self, - kernel_initializer_id, lattice_sizes, joint_unimodalities, expected_type): + ("random_uniform_or_linear_initializer", [3, 3, 3], [ + ([0, 1, 2], "peak") + ], tf.keras.initializers.RandomUniform), + ("random_uniform_or_linear_initializer", [3, 3, 3], [ + ([0, 1, 2], "valley") + ], tf.keras.initializers.RandomUniform), + ("random_uniform_or_linear_initializer", [3, 3, 3], [ + ([0, 1], "valley") + ], ll.LinearInitializer), + ("random_uniform_or_linear_initializer", [3, 3, 3], [ + ([0, 1], "valley"), ([2], "peak") + ], ll.LinearInitializer), ("random_uniform_or_linear_initializer", + [3, 3, 3], None, ll.LinearInitializer), + ("linear_initializer", [3, 3, 3], [ + ([0, 1], "valley") + ], ll.LinearInitializer), ("random_monotonic_initializer", [3, 3, 3], [ + ([0, 1], "valley") + ], ll.RandomMonotonicInitializer)) + def testCreateKernelInitializer(self, kernel_initializer_id, lattice_sizes, + joint_unimodalities, expected_type): self.assertEqual( expected_type, - type(ll.create_kernel_initializer( - kernel_initializer_id, - lattice_sizes, - monotonicities=None, - output_min=0.0, - output_max=1.0, - unimodalities=None, - joint_unimodalities=joint_unimodalities))) + type( + ll.create_kernel_initializer( + kernel_initializer_id, + lattice_sizes, + monotonicities=None, + output_min=0.0, + output_max=1.0, + unimodalities=None, + joint_unimodalities=joint_unimodalities))) @parameterized.parameters( # Single Unit diff --git a/tensorflow_lattice/python/linear_layer.py b/tensorflow_lattice/python/linear_layer.py index 6fba8a2..2ea0f19 100644 --- a/tensorflow_lattice/python/linear_layer.py +++ b/tensorflow_lattice/python/linear_layer.py @@ -23,6 +23,7 @@ from __future__ import print_function from . import linear_lib +from . import utils import numpy as np import tensorflow as tf from tensorflow import keras @@ -239,8 +240,8 @@ def build(self, input_shape): constraint=None, dtype=self.dtype) - input_min = linear_lib.canonicalize_input_bounds(self.input_min) - input_max = linear_lib.canonicalize_input_bounds(self.input_max) + input_min = utils.canonicalize_input_bounds(self.input_min) + input_max = utils.canonicalize_input_bounds(self.input_max) if ((input_min and input_min.count(None) < len(input_min)) or (input_max and input_max.count(None) < len(input_max))): lower_bounds = [val if val is not None else -np.inf @@ -315,12 +316,11 @@ def assert_constraints(self, eps=1e-4): """ return linear_lib.assert_constraints( weights=self.kernel, - monotonicities=linear_lib.canonicalize_monotonicities( - self.monotonicities), + monotonicities=utils.canonicalize_monotonicities(self.monotonicities), monotonic_dominances=self.monotonic_dominances, range_dominances=self.range_dominances, - input_min=linear_lib.canonicalize_input_bounds(self.input_min), - input_max=linear_lib.canonicalize_input_bounds(self.input_max), + input_min=utils.canonicalize_input_bounds(self.input_min), + input_max=utils.canonicalize_input_bounds(self.input_max), normalization_order=self.normalization_order, eps=eps) @@ -393,12 +393,11 @@ def __call__(self, w): """ return linear_lib.project( weights=w, - monotonicities=linear_lib.canonicalize_monotonicities( - self.monotonicities), + monotonicities=utils.canonicalize_monotonicities(self.monotonicities), monotonic_dominances=self.monotonic_dominances, range_dominances=self.range_dominances, - input_min=linear_lib.canonicalize_input_bounds(self.input_min), - input_max=linear_lib.canonicalize_input_bounds(self.input_max), + input_min=utils.canonicalize_input_bounds(self.input_min), + input_max=utils.canonicalize_input_bounds(self.input_max), normalization_order=self.normalization_order) def get_config(self): diff --git a/tensorflow_lattice/python/linear_lib.py b/tensorflow_lattice/python/linear_lib.py index 986c386..8e0f0a7 100644 --- a/tensorflow_lattice/python/linear_lib.py +++ b/tensorflow_lattice/python/linear_lib.py @@ -11,22 +11,26 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Implementation of algorithms required for Linear layer.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -from . import internal_utils as iu +from . import internal_utils +from . import utils import six import tensorflow as tf _NORMALIZATION_EPS = 1e-8 -def project(weights, monotonicities, monotonic_dominances=None, - range_dominances=None, input_min=None, input_max=None, +def project(weights, + monotonicities, + monotonic_dominances=None, + range_dominances=None, + input_min=None, + input_max=None, normalization_order=None): """Applies constraints to weights. @@ -56,12 +60,13 @@ def project(weights, monotonicities, monotonic_dominances=None, Returns: 'weights' with monotonicity constraints and normalization applied to it. """ - verify_hyperparameters(weights_shape=weights.shape, - monotonicities=monotonicities, - monotonic_dominances=monotonic_dominances, - range_dominances=range_dominances, - input_min=input_min, - input_max=input_max) + verify_hyperparameters( + weights_shape=weights.shape, + monotonicities=monotonicities, + monotonic_dominances=monotonic_dominances, + range_dominances=range_dominances, + input_min=input_min, + input_max=input_max) if any(monotonicities): if 1 in monotonicities: inverted_increasing_mask = tf.constant( @@ -82,7 +87,8 @@ def project(weights, monotonicities, monotonic_dominances=None, if monotonic_dominances: monotonic_dominances = [(j, i) for i, j in monotonic_dominances] - weights = iu.approximately_project_categorical_partial_monotonicities(weights, monotonic_dominances) + weights = internal_utils.approximately_project_categorical_partial_monotonicities( + weights, monotonic_dominances) if range_dominances: range_dominances = [(j, i) for i, j in range_dominances] @@ -92,22 +98,28 @@ def project(weights, monotonicities, monotonic_dominances=None, scalings[dim] *= upper - lower scalings = tf.constant(scalings, dtype=weights.dtype, shape=weights.shape) weights *= scalings - weights = iu.approximately_project_categorical_partial_monotonicities( + weights = internal_utils.approximately_project_categorical_partial_monotonicities( weights, range_dominances) weights /= scalings if normalization_order: norm = tf.norm(weights, ord=normalization_order) - weights = tf.cond(norm < _NORMALIZATION_EPS, - true_fn=lambda: weights, - false_fn=lambda: weights / norm) + weights = tf.cond( + norm < _NORMALIZATION_EPS, + true_fn=lambda: weights, + false_fn=lambda: weights / norm) return weights -def assert_constraints(weights, monotonicities, monotonic_dominances, - range_dominances, input_min, input_max, - normalization_order, eps=1e-4): +def assert_constraints(weights, + monotonicities, + monotonic_dominances, + range_dominances, + input_min, + input_max, + normalization_order, + eps=1e-4): """Asserts that weights satisfy constraints. Args: @@ -138,29 +150,29 @@ def assert_constraints(weights, monotonicities, monotonic_dominances, if any(monotonicities): # Create constant specifying shape explicitly because otherwise due to # weights shape ending with dimesion of size 1 broadcasting will hurt us. - monotonicities_constant = tf.constant(monotonicities, - shape=weights.shape, - dtype=weights.dtype) + monotonicities_constant = tf.constant( + monotonicities, shape=weights.shape, dtype=weights.dtype) diff = tf.reduce_min(weights * monotonicities_constant) asserts.append( - tf.Assert(diff >= -eps, - data=["Monotonicity violation", - "Monotonicities:", monotonicities, - "Min monotonicity diff:", diff, - "Epsilon:", eps, - "Weights:", weights], - summarize=weights.shape[0])) + tf.Assert( + diff >= -eps, + data=[ + "Monotonicity violation", "Monotonicities:", monotonicities, + "Min monotonicity diff:", diff, "Epsilon:", eps, "Weights:", + weights + ], + summarize=weights.shape[0])) for dominant_dim, weak_dim in monotonic_dominances or []: diff = tf.reduce_min(weights[dominant_dim] - weights[weak_dim]) asserts.append( - tf.Assert(diff >= -eps, - data=["Monotonic dominance violation", - "Dominant dim:", dominant_dim, - "Weak dim:", weak_dim, - "Epsilon:", eps, - "Weights:", weights], - summarize=weights.shape[0])) + tf.Assert( + diff >= -eps, + data=[ + "Monotonic dominance violation", "Dominant dim:", dominant_dim, + "Weak dim:", weak_dim, "Epsilon:", eps, "Weights:", weights + ], + summarize=weights.shape[0])) if range_dominances: scalings = [-1.0 if m == -1 else 1.0 for m in monotonicities] @@ -171,27 +183,29 @@ def assert_constraints(weights, monotonicities, monotonic_dominances, diff = tf.reduce_min(scalings[dominant_dim] * weights[dominant_dim] - scalings[weak_dim] * weights[weak_dim]) asserts.append( - tf.Assert(diff >= -eps, - data=["Range dominance violation", - "Dominant dim:", dominant_dim, - "Weak dim:", weak_dim, - "Epsilon:", eps, - "Weights:", weights, - "Scalings:", scalings], - summarize=weights.shape[0])) + tf.Assert( + diff >= -eps, + data=[ + "Range dominance violation", "Dominant dim:", dominant_dim, + "Weak dim:", weak_dim, "Epsilon:", eps, "Weights:", weights, + "Scalings:", scalings + ], + summarize=weights.shape[0])) if normalization_order: norm = tf.norm(weights, ord=normalization_order) asserts.append( # Norm can be either 0.0 or 1.0, because if all weights are close to 0.0 # we can't scale them to get norm 1.0. - tf.Assert(tf.logical_or(tf.abs(norm - 1.0) < eps, - tf.abs(norm) < _NORMALIZATION_EPS), - data=["Normalization order violation", - "Norm:", norm, - "Epsilon:", eps, - "Weights:", weights], - summarize=weights.shape[0])) + tf.Assert( + tf.logical_or( + tf.abs(norm - 1.0) < eps, + tf.abs(norm) < _NORMALIZATION_EPS), + data=[ + "Normalization order violation", "Norm:", norm, "Epsilon:", eps, + "Weights:", weights + ], + summarize=weights.shape[0])) return asserts @@ -237,16 +251,16 @@ def verify_hyperparameters(num_input_dims=None, ValueError: If something is inconsistent. """ # It also raises errors if monotonicities specified incorrectly. - monotonicities = canonicalize_monotonicities(monotonicities) - input_min = canonicalize_input_bounds(input_min) - input_max = canonicalize_input_bounds(input_max) + monotonicities = utils.canonicalize_monotonicities(monotonicities) + input_min = utils.canonicalize_input_bounds(input_min) + input_max = utils.canonicalize_input_bounds(input_max) if monotonicities is not None and num_input_dims is not None: if len(monotonicities) != num_input_dims: raise ValueError("Number of elements in 'monotonicities' must be equal to" " num_input_dims. monotoniticites: %s, " - "len(monotonicities): %d, num_input_dims: %d" - % (monotonicities, len(monotonicities), num_input_dims)) + "len(monotonicities): %d, num_input_dims: %d" % + (monotonicities, len(monotonicities), num_input_dims)) if weights_shape is not None: if len(weights_shape) != 2 or weights_shape[1] != 1: @@ -257,13 +271,15 @@ def verify_hyperparameters(num_input_dims=None, "correspond to number of weights. Weights shape: %s, " "monotonicities: %s" % (weights_shape, monotonicities)) if input_min is not None and weights_shape[0] != len(input_min): - raise ValueError("Number of elements in 'input_min' does not correspond " - "to number of weights. Weights shape: %s, input_min: %s" - % (weights_shape, input_min)) + raise ValueError( + "Number of elements in 'input_min' does not correspond " + "to number of weights. Weights shape: %s, input_min: %s" % + (weights_shape, input_min)) if input_max is not None and weights_shape[0] != len(input_max): - raise ValueError("Number of elements in 'input_max' does not correspond " - "to number of weights. Weights shape: %s, input_max: %s" - % (weights_shape, input_max)) + raise ValueError( + "Number of elements in 'input_max' does not correspond " + "to number of weights. Weights shape: %s, input_max: %s" % + (weights_shape, input_max)) for dim, (lower, upper) in enumerate(zip(input_min or [], input_max or [])): if lower is not None and upper is not None and lower > upper: @@ -353,60 +369,3 @@ def verify_hyperparameters(num_input_dims=None, raise ValueError("Cannot have both monotonic and range dominance " "constraints specified on the same dimension. " "Dimension %d is set by both." % (dim)) - - -def canonicalize_monotonicities(monotonicities): - """Converts string constants representing monotonicities into integers. - - Args: - monotonicities: monotonicities hyperparameter of `Lattice` layer. - - Raises: - ValueError if one of monotonicities is invalid. - - Returns: - monotonicities represented as 0 or 1. - """ - if monotonicities: - canonicalized = [] - for item in monotonicities: - if item in [-1, 0, 1]: - canonicalized.append(item) - elif isinstance(item, six.string_types) and item.lower() == "decreasing": - canonicalized.append(-1) - elif isinstance(item, six.string_types) and item.lower() == "none": - canonicalized.append(0) - elif isinstance(item, six.string_types) and item.lower() == "increasing": - canonicalized.append(1) - else: - raise ValueError("'monotonicities' elements must be from: [-1, 0, 1, " - "'decreasing', 'none', 'increasing']. " - "Given: %s" % monotonicities) - return canonicalized - return None - - -def canonicalize_input_bounds(input_bounds): - """Converts string constant 'none' representing unspecified bound into None. - - Args: - input_bounds: input_min or input_max hyperparameter of `Linear` layer. - - Raises: - ValueError if one of elements in input_bounds is invalid. - - Returns: - input_bounds represented as float or None. - """ - if input_bounds: - canonicalized = [] - for item in input_bounds: - if isinstance(item, float) or item is None: - canonicalized.append(item) - elif isinstance(item, six.string_types) and item.lower() == "none": - canonicalized.append(None) - else: - raise ValueError("Both 'input_min' and 'input_max' elements must be " - "either float or 'none'. Given: %s" % input_bounds) - return canonicalized - return None diff --git a/tensorflow_lattice/python/linear_test.py b/tensorflow_lattice/python/linear_test.py index 4da42b1..784d468 100644 --- a/tensorflow_lattice/python/linear_test.py +++ b/tensorflow_lattice/python/linear_test.py @@ -28,6 +28,7 @@ from tensorflow_lattice.python import linear_layer as linl from tensorflow_lattice.python import linear_lib from tensorflow_lattice.python import test_utils +from tensorflow_lattice.python import utils _DISABLE_ALL = False _LOSS_EPS = 0.0001 @@ -196,7 +197,7 @@ def _NegateAndTrain(self, config): if isinstance(config["monotonicities"], list): negated_config["monotonicities"] = [ -monotonicity for monotonicity in - linear_lib.canonicalize_monotonicities(config["monotonicities"]) + utils.canonicalize_monotonicities(config["monotonicities"]) ] else: negated_config["monotonicities"] = -config["monotonicities"] diff --git a/tensorflow_lattice/python/premade.py b/tensorflow_lattice/python/premade.py index 1dd28ce..0595ea6 100644 --- a/tensorflow_lattice/python/premade.py +++ b/tensorflow_lattice/python/premade.py @@ -44,6 +44,7 @@ from . import parallel_combination_layer from . import premade_lib from . import pwl_calibration_layer +from . import rtl_layer from absl import logging import tensorflow as tf @@ -102,47 +103,23 @@ def __init__(self, model_config=None, dtype=tf.float32, **kwargs): input_layer = premade_lib.build_input_layer( feature_configs=model_config.feature_configs, dtype=dtype) - submodels_inputs = premade_lib.build_calibration_layers( + lattice_outputs = premade_lib.build_calibrated_lattice_ensemble_layer( calibration_input_layer=input_layer, - feature_configs=model_config.feature_configs, model_config=model_config, - layer_output_range=premade_lib.LayerOutputRange.INPUT_TO_LATTICE, - submodels=model_config.lattices, - separate_calibrators=model_config.separate_calibrators, dtype=dtype) - lattice_outputs = [] - for submodel_index, (lattice_feature_names, lattice_input) in enumerate( - zip(model_config.lattices, submodels_inputs)): - lattice_feature_configs = [ - model_config.feature_config_by_name(feature_name) - for feature_name in lattice_feature_names - ] - - lattice_layer_output_range = ( - premade_lib.LayerOutputRange.INPUT_TO_FINAL_CALIBRATION - if model_config.output_calibration else - premade_lib.LayerOutputRange.MODEL_OUTPUT) - lattice_outputs.append( - premade_lib.build_lattice_layer( - lattice_input=lattice_input, - feature_configs=lattice_feature_configs, - model_config=model_config, - layer_output_range=lattice_layer_output_range, - submodel_index=submodel_index, - is_inside_ensemble=True, - dtype=dtype)) - - if len(lattice_outputs) > 1: - if model_config.use_linear_combination: - averaged_lattice_output = premade_lib.build_linear_combination_layer( - ensemble_outputs=lattice_outputs, - model_config=model_config, - dtype=dtype) - else: - averaged_lattice_output = tf.keras.layers.Average()(lattice_outputs) + if model_config.use_linear_combination: + averaged_lattice_output = premade_lib.build_linear_combination_layer( + ensemble_outputs=lattice_outputs, + model_config=model_config, + dtype=dtype) else: - averaged_lattice_output = lattice_outputs[0] + if isinstance(lattice_outputs, list): + averaged_lattice_output = tf.keras.layers.Average()(lattice_outputs) + else: + averaged_lattice_output = tf.reduce_mean( + lattice_outputs, axis=-1, keepdims=True) + if model_config.output_calibration: model_output = premade_lib.build_output_calibration_layer( output_calibration_input=averaged_lattice_output, @@ -235,7 +212,6 @@ def __init__(self, model_config=None, dtype=tf.float32, **kwargs): feature_configs=model_config.feature_configs, dtype=dtype) submodels_inputs = premade_lib.build_calibration_layers( calibration_input_layer=input_layer, - feature_configs=model_config.feature_configs, model_config=model_config, layer_output_range=premade_lib.LayerOutputRange.INPUT_TO_LATTICE, submodels=[[ @@ -355,7 +331,6 @@ def __init__(self, model_config=None, dtype=tf.float32, **kwargs): premade_lib.LayerOutputRange.MODEL_OUTPUT) submodels_inputs = premade_lib.build_calibration_layers( calibration_input_layer=input_layer, - feature_configs=model_config.feature_configs, model_config=model_config, layer_output_range=calibration_layer_output_range, submodels=[[ @@ -552,34 +527,32 @@ def get_custom_objects(custom_objects=None): tfl_custom_objects = { 'AggregateFunction': AggregateFunction, + 'AggregateFunctionConfig': + configs.AggregateFunctionConfig, + 'Aggregation': + aggregation_layer.Aggregation, 'CalibratedLatticeEnsemble': CalibratedLatticeEnsemble, 'CalibratedLattice': CalibratedLattice, + 'CalibratedLatticeConfig': + configs.CalibratedLatticeConfig, + 'CalibratedLatticeEnsembleConfig': + configs.CalibratedLatticeEnsembleConfig, 'CalibratedLinear': CalibratedLinear, + 'CalibratedLinearConfig': + configs.CalibratedLinearConfig, 'CategoricalCalibration': categorical_calibration_layer.CategoricalCalibration, 'CategoricalCalibrationConstraints': categorical_calibration_layer.CategoricalCalibrationConstraints, - 'FeatureConfig': - configs.FeatureConfig, - 'RegularizerConfig': - configs.RegularizerConfig, - 'TrustConfig': - configs.TrustConfig, 'DominanceConfig': configs.DominanceConfig, - 'CalibratedLatticeEnsembleConfig': - configs.CalibratedLatticeEnsembleConfig, - 'CalibratedLatticeConfig': - configs.CalibratedLatticeConfig, - 'CalibratedLinearConfig': - configs.CalibratedLinearConfig, - 'AggregateFunctionConfig': - configs.AggregateFunctionConfig, - 'Aggregation': - aggregation_layer.Aggregation, + 'FeatureConfig': + configs.FeatureConfig, + 'LaplacianRegularizer': + lattice_layer.LaplacianRegularizer, 'Lattice': lattice_layer.Lattice, 'LatticeConstraints': @@ -588,14 +561,26 @@ def get_custom_objects(custom_objects=None): linear_layer.Linear, 'LinearConstraints': linear_layer.LinearConstraints, + 'LinearInitializer': + lattice_layer.LinearInitializer, + 'NaiveBoundsConstraints': + pwl_calibration_layer.NaiveBoundsConstraints, 'ParallelCombination': parallel_combination_layer.ParallelCombination, 'PWLCalibration': pwl_calibration_layer.PWLCalibration, 'PWLCalibrationConstraints': pwl_calibration_layer.PWLCalibrationConstraints, - 'NaiveBoundsConstraints': - pwl_calibration_layer.NaiveBoundsConstraints, + 'RandomMonotonicInitializer': + lattice_layer.RandomMonotonicInitializer, + 'RegularizerConfig': + configs.RegularizerConfig, + 'RTL': + rtl_layer.RTL, + 'TorsionRegularizer': + lattice_layer.TorsionRegularizer, + 'TrustConfig': + configs.TrustConfig, } if custom_objects is not None: tfl_custom_objects.update(custom_objects) diff --git a/tensorflow_lattice/python/premade_lib.py b/tensorflow_lattice/python/premade_lib.py index 101d70d..058dfd0 100644 --- a/tensorflow_lattice/python/premade_lib.py +++ b/tensorflow_lattice/python/premade_lib.py @@ -29,7 +29,8 @@ from . import lattice_lib from . import linear_layer from . import pwl_calibration_layer -from . import pwl_calibration_lib +from . import rtl_layer +from . import utils from absl import logging import numpy as np @@ -45,6 +46,8 @@ LINEAR_LAYER_NAME = 'tfl_linear' OUTPUT_LINEAR_COMBINATION_LAYER_NAME = 'tfl_output_linear_combination' OUTPUT_CALIB_LAYER_NAME = 'tfl_output_calib' +RTL_LAYER_NAME = 'tfl_rtl' +RTL_INPUT_NAME = 'tfl_rtl_input' # Prefix for passthrough (identity) nodes for shared calibration. # These nodes pass shared calibrated values to submodels in an ensemble. @@ -174,44 +177,31 @@ def build_input_layer(feature_configs, dtype, ragged=False): return input_layer -def build_calibration_layers(calibration_input_layer, feature_configs, - model_config, layer_output_range, submodels, - separate_calibrators, dtype): - """Creates a calibration layer for `submodels` as list of list of features. +def build_multi_unit_calibration_layers(calibration_input_layer, + calibration_output_units, model_config, + layer_output_range, + output_single_tensor, dtype): + """Creates a mapping from feature names to calibration outputs. Args: calibration_input_layer: A mapping from feature name to `tf.keras.Input`. - feature_configs: A list of `tfl.configs.FeatureConfig` instances that - specify configurations for each feature. + calibration_output_units: A mapping from feature name to units. model_config: Model configuration object describing model architecture. Should be one of the model configs in `tfl.configs`. layer_output_range: A `tfl.premade_lib.LayerOutputRange` enum. - submodels: A list of list of feature names. - separate_calibrators: If features should be separately calibrated for each - lattice in an ensemble. + output_single_tensor: If output for each feature should be a single tensor. dtype: dtype Returns: - A list of list of Tensors representing a calibration layer for `submodels`. + A mapping from feature name to calibration output Tensors. """ - # Create a list of (feature_name, calibration_output_idx) pairs for each - # submodel. When using shared calibration, all submodels will have - # calibration_output_idx = 0. - submodels_input_features = [] - calibration_last_index = collections.defaultdict(int) - for submodel in submodels: - submodel_input_features = [] - submodels_input_features.append(submodel_input_features) - for feature_name in submodel: - submodel_input_features.append( - (feature_name, calibration_last_index[feature_name])) - if separate_calibrators: - calibration_last_index[feature_name] += 1 - calibration_output = {} - for feature_config in feature_configs: - feature_name = feature_config.name - units = max(calibration_last_index[feature_name], 1) + for feature_name, units in calibration_output_units.items(): + if units == 0: + raise ValueError( + 'Feature {} is not used. Calibration output units is 0.'.format( + feature_name)) + feature_config = model_config.feature_config_by_name(feature_name) calibration_input = calibration_input_layer[feature_name] layer_name = '{}_{}'.format(CALIB_LAYER_NAME, feature_name) @@ -238,13 +228,14 @@ def build_calibration_layers(calibration_input_layer, feature_configs, kernel_regularizer = _input_calibration_regularizers( model_config, feature_config) monotonicity = feature_config.monotonicity - if (pwl_calibration_lib.canonicalize_monotonicity(monotonicity) == 0 and + if (utils.canonicalize_monotonicity(monotonicity) == 0 and feature_config.pwl_calibration_always_monotonic): monotonicity = 1 kernel_initializer = pwl_calibration_layer.UniformOutputInitializer( output_min=output_init_min, output_max=output_init_max, - monotonicity=monotonicity) + monotonicity=monotonicity, + keypoints=feature_config.pwl_calibration_input_keypoints) calibrated = ( pwl_calibration_layer.PWLCalibration( units=units, @@ -261,10 +252,58 @@ def build_calibration_layers(calibration_input_layer, feature_configs, convexity=feature_config.pwl_calibration_convexity, dtype=dtype, name=layer_name)(calibration_input)) - if units == 1: + if output_single_tensor: + calibration_output[feature_name] = calibrated + elif units == 1: calibration_output[feature_name] = [calibrated] else: calibration_output[feature_name] = tf.split(calibrated, units, axis=1) + return calibration_output + + +def build_calibration_layers(calibration_input_layer, model_config, + layer_output_range, submodels, + separate_calibrators, dtype): + """Creates a calibration layer for `submodels` as list of list of features. + + Args: + calibration_input_layer: A mapping from feature name to `tf.keras.Input`. + model_config: Model configuration object describing model architecture. + Should be one of the model configs in `tfl.configs`. + layer_output_range: A `tfl.premade_lib.LayerOutputRange` enum. + submodels: A list of list of feature names. + separate_calibrators: If features should be separately calibrated for each + lattice in an ensemble. + dtype: dtype + + Returns: + A list of list of Tensors representing a calibration layer for `submodels`. + """ + # Create a list of (feature_name, calibration_output_idx) pairs for each + # submodel. When using shared calibration, all submodels will have + # calibration_output_idx = 0. + submodels_input_features = [] + calibration_last_index = collections.defaultdict(int) + for submodel in submodels: + submodel_input_features = [] + submodels_input_features.append(submodel_input_features) + for feature_name in submodel: + submodel_input_features.append( + (feature_name, calibration_last_index[feature_name])) + if separate_calibrators: + calibration_last_index[feature_name] += 1 + + # This is to account for shared calibration. + calibration_output_units = { + name: max(index, 1) for name, index in calibration_last_index.items() + } + calibration_output = build_multi_unit_calibration_layers( + calibration_input_layer=calibration_input_layer, + calibration_output_units=calibration_output_units, + model_config=model_config, + layer_output_range=layer_output_range, + output_single_tensor=False, + dtype=dtype) # Create passthrough nodes for each submodel input so that we can recover # the model structure for plotting and analysis. @@ -336,7 +375,7 @@ def build_aggregation_layer(aggregation_input_layer, model_config, dtype=np.float32), output_min=0.0, output_max=lattice_sizes[i] - 1.0, - monotonicity=pwl_calibration_lib.canonicalize_monotonicity( + monotonicity=utils.canonicalize_monotonicity( model_config.middle_monotonicity), kernel_regularizer=_middle_calibration_regularizers(model_config), dtype=dtype, @@ -485,7 +524,8 @@ def build_lattice_layer(lattice_input, feature_configs, model_config, lattice_unimodalities = [ feature_config.unimodality for feature_config in feature_configs ] - lattice_regularizers = _lattice_regularizers(model_config, feature_configs) + lattice_regularizers = _lattice_regularizers(model_config, + feature_configs) or None # Construct trust constraints within this lattice. edgeworth_trusts = [] @@ -539,6 +579,156 @@ def build_lattice_layer(lattice_input, feature_configs, model_config, lattice_input) +def build_lattice_ensemble_layer(submodels_inputs, model_config, dtype): + """Creates an ensemble of `tfl.layers.Lattice` layers. + + Args: + submodels_inputs: List of inputs to each of the lattice layers in the + ensemble. The order corresponds to the elements of model_config.lattices. + model_config: Model configuration object describing model architecture. + Should be one of the model configs in `tfl.configs`. + dtype: dtype + + Returns: + A list of `tfl.layers.Lattice` instances. + """ + lattice_outputs = [] + for submodel_index, (lattice_feature_names, lattice_input) in enumerate( + zip(model_config.lattices, submodels_inputs)): + lattice_feature_configs = [ + model_config.feature_config_by_name(feature_name) + for feature_name in lattice_feature_names + ] + lattice_layer_output_range = ( + LayerOutputRange.INPUT_TO_FINAL_CALIBRATION + if model_config.output_calibration else LayerOutputRange.MODEL_OUTPUT) + lattice_outputs.append( + build_lattice_layer( + lattice_input=lattice_input, + feature_configs=lattice_feature_configs, + model_config=model_config, + layer_output_range=lattice_layer_output_range, + submodel_index=submodel_index, + is_inside_ensemble=True, + dtype=dtype)) + return lattice_outputs + + +def build_rtl_layer(calibration_outputs, model_config, submodel_index, dtype): + """Creates a `tfl.layers.RTL` layer. + + This function expects that all features defined in + model_config.feature_configs are used and present in calibration_outputs. + + Args: + calibration_outputs: A mapping from feature name to calibration output. + model_config: Model configuration object describing model architecture. + Should be one of the model configs in `tfl.configs`. + submodel_index: Corresponding index into submodels. + dtype: dtype + + Returns: + A `tfl.layers.RTL` instance. + """ + layer_name = '{}_{}'.format(RTL_LAYER_NAME, submodel_index) + + rtl_layer_output_range = ( + LayerOutputRange.INPUT_TO_FINAL_CALIBRATION + if model_config.output_calibration else LayerOutputRange.MODEL_OUTPUT) + + (output_min, output_max, output_init_min, + output_init_max) = _output_range(rtl_layer_output_range, model_config) + + lattice_regularizers = _lattice_regularizers( + model_config, model_config.feature_configs) or None + + rtl_inputs = collections.defaultdict(list) + for feature_config in model_config.feature_configs: + passthrough_name = '{}_{}'.format(RTL_INPUT_NAME, feature_config.name) + calibration_output = tf.identity( + calibration_outputs[feature_config.name], name=passthrough_name) + if feature_config.monotonicity in [1, -1, 'increasing', 'decreasing']: + rtl_inputs['increasing'].append(calibration_output) + else: + rtl_inputs['unconstrained'].append(calibration_output) + + lattice_size = model_config.feature_configs[0].lattice_size + kernel_initializer = lattice_layer.RandomMonotonicInitializer( + lattice_sizes=[lattice_size] * model_config.lattice_rank, + output_min=output_init_min, + output_max=output_init_max) + return rtl_layer.RTL( + num_lattices=model_config.num_lattices, + lattice_rank=model_config.lattice_rank, + lattice_size=lattice_size, + output_min=output_min, + output_max=output_max, + clip_inputs=False, + interpolation=model_config.interpolation, + kernel_regularizer=lattice_regularizers, + kernel_initializer=kernel_initializer, + dtype=dtype, + name=layer_name)( + rtl_inputs) + + +def build_calibrated_lattice_ensemble_layer(calibration_input_layer, + model_config, dtype): + """Creates a calibration layer followed by a lattice ensemble layer. + + Args: + calibration_input_layer: A mapping from feature name to `tf.keras.Input`. + model_config: Model configuration object describing model architecture. + Should be one of the model configs in `tfl.configs`. + dtype: dtype + + Returns: + A `tfl.layers.RTL` instance if model_config.lattices is 'rtl_layer. + Otherwise a list of `tfl.layers.Lattice` instances. + """ + if model_config.lattices == 'rtl_layer': + num_features = len(model_config.feature_configs) + units = [1] * num_features + if model_config.separate_calibrators: + num_inputs = model_config.num_lattices * model_config.lattice_rank + # We divide the number of inputs semi-evenly by the number of features. + for i in range(num_features): + units[i] = ((i + 1) * num_inputs // num_features - + i * num_inputs // num_features) + calibration_output_units = { + feature_config.name: units[i] + for i, feature_config in enumerate(model_config.feature_configs) + } + calibration_outputs = build_multi_unit_calibration_layers( + calibration_input_layer=calibration_input_layer, + calibration_output_units=calibration_output_units, + model_config=model_config, + layer_output_range=LayerOutputRange.INPUT_TO_LATTICE, + output_single_tensor=True, + dtype=dtype) + + lattice_outputs = build_rtl_layer( + calibration_outputs=calibration_outputs, + model_config=model_config, + submodel_index=0, + dtype=dtype) + else: + submodels_inputs = build_calibration_layers( + calibration_input_layer=calibration_input_layer, + model_config=model_config, + layer_output_range=LayerOutputRange.INPUT_TO_LATTICE, + submodels=model_config.lattices, + separate_calibrators=model_config.separate_calibrators, + dtype=dtype) + + lattice_outputs = build_lattice_ensemble_layer( + submodels_inputs=submodels_inputs, + model_config=model_config, + dtype=dtype) + + return lattice_outputs + + def build_linear_combination_layer(ensemble_outputs, model_config, dtype): """Creates a `tfl.layers.Linear` layer initialized to be an average. @@ -1090,13 +1280,58 @@ def verify_config(model_config): Should be one of the model configs in `tfl.configs`. """ if isinstance(model_config, configs.CalibratedLatticeEnsembleConfig): - if not isinstance(model_config.lattices, list): - raise ValueError('Lattices are not fully specified for ensemble config.') - for lattice in model_config.lattices: - if (not np.iterable(lattice) or - any(not isinstance(x, str) for x in lattice)): + # Make sure there are more than one lattice. If not, tell user to use + # CalibratedLattice instead. + if model_config.lattices == 'rtl_layer': + # RTL must have num_lattices specified and >= 2. + if model_config.num_lattices is None: + raise ValueError('model_config.num_lattices must be specified when ' + 'model_config.lattices is set to \'rtl_layer\'.') + if model_config.num_lattices < 2: + raise ValueError( + 'CalibratedLatticeEnsemble must have >= 2 lattices. For single ' + 'lattice models, use CalibratedLattice instead.') + # Check that all lattices sizes for all features are the same. + if any(feature_config.lattice_size != + model_config.feature_configs[0].lattice_size + for feature_config in model_config.feature_configs): + raise ValueError('Lattice sizes must be the same for all features.') + # Check that there are only monotonicity and bound constraints. + if any(feature_config.unimodality != 'none' + for feature_config in model_config.feature_configs): raise ValueError( - 'Lattices are not fully specified for ensemble config.') + 'RTL Layer does not currently support unimodality constraints.') + if any(feature_config.reflects_trust_in is not None + for feature_config in model_config.feature_configs): + raise ValueError( + 'RTL Layer does not currently support trust constraints.') + if any(feature_config.dominates is not None + for feature_config in model_config.feature_configs): + raise ValueError( + 'RTL Layer does not currently support dominance constraints.') + # Check that there are no per-feature lattice regularizers. + for feature_config in model_config.feature_configs: + for regularizer_config in feature_config.regularizer_configs or []: + if not regularizer_config.name.startswith( + _INPUT_CALIB_REGULARIZER_PREFIX): + raise ValueError( + 'RTL Layer does not currently support per-feature lattice ' + 'regularizers.') + elif isinstance(model_config.lattices, list): + if len(model_config.lattices) < 2: + raise ValueError( + 'CalibratedLatticeEnsemble must have >= 2 lattices. For single ' + 'lattice models, use CalibratedLattice instead.') + for lattice in model_config.lattices: + if (not np.iterable(lattice) or + any(not isinstance(x, str) for x in lattice)): + raise ValueError( + 'Lattices are not fully specified for ensemble config.') + else: + raise ValueError( + 'Lattices are not fully specified for ensemble config. Lattices must ' + 'be set to \'rtl_layer\' or be fully specified as a list of lists of ' + 'feature names.') if isinstance(model_config, configs.AggregateFunctionConfig): if model_config.middle_dimension < 1: raise ValueError('Middle dimension must be at least 1: {}'.format( diff --git a/tensorflow_lattice/python/premade_test.py b/tensorflow_lattice/python/premade_test.py index 91a55ef..0b1e11a 100644 --- a/tensorflow_lattice/python/premade_test.py +++ b/tensorflow_lattice/python/premade_test.py @@ -545,7 +545,50 @@ def testCalibratedLatticeEnsembleCrystals(self): verbose=False) results = model.evaluate( self.heart_test_x, self.heart_test_y, verbose=False) - logging.info('Calibrated lattice ensemble classifier results:') + logging.info('Calibrated lattice ensemble crystals classifier results:') + logging.info(results) + self.assertGreater(results[1], 0.85) + + def testCalibratedLatticeEnsembleRTL(self): + # Construct model. + self._ResetAllBackends() + rtl_feature_configs = copy.deepcopy(self.heart_feature_configs) + for feature_config in rtl_feature_configs: + feature_config.lattice_size = 2 + feature_config.unimodality = 'none' + feature_config.reflects_trust_in = None + feature_config.dominates = None + feature_config.regularizer_configs = None + model_config = configs.CalibratedLatticeEnsembleConfig( + regularizer_configs=[ + configs.RegularizerConfig(name='torsion', l2=1e-4), + configs.RegularizerConfig(name='output_calib_hessian', l2=1e-4), + ], + feature_configs=rtl_feature_configs, + lattices='rtl_layer', + num_lattices=6, + lattice_rank=5, + separate_calibrators=True, + output_calibration=False, + output_min=self.heart_min_label, + output_max=self.heart_max_label - self.numerical_error_epsilon, + output_initialization=[self.heart_min_label, self.heart_max_label], + ) + # Construct and train final model + model = premade.CalibratedLatticeEnsemble(model_config) + model.compile( + loss=tf.keras.losses.BinaryCrossentropy(), + metrics=tf.keras.metrics.AUC(), + optimizer=tf.keras.optimizers.Adam(0.01)) + model.fit( + self.heart_train_x, + self.heart_train_y, + batch_size=100, + epochs=200, + verbose=False) + results = model.evaluate( + self.heart_test_x, self.heart_test_y, verbose=False) + logging.info('Calibrated lattice ensemble rtl classifier results:') logging.info(results) self.assertGreater(results[1], 0.85) @@ -579,6 +622,42 @@ def testLatticeEnsembleH5FormatSaveLoad(self): model.predict(fake_data['eval_xs']), loaded_model.predict(fake_data['eval_xs'])) + def testLatticeEnsembleRTLH5FormatSaveLoad(self): + rtl_feature_configs = copy.deepcopy(feature_configs) + for feature_config in rtl_feature_configs: + feature_config.lattice_size = 2 + feature_config.unimodality = 'none' + feature_config.reflects_trust_in = None + feature_config.dominates = None + feature_config.regularizer_configs = None + model_config = configs.CalibratedLatticeEnsembleConfig( + feature_configs=copy.deepcopy(rtl_feature_configs), + lattices='rtl_layer', + num_lattices=2, + lattice_rank=2, + separate_calibrators=True, + regularizer_configs=[ + configs.RegularizerConfig('calib_hessian', l2=1e-3), + configs.RegularizerConfig('torsion', l2=1e-4), + ], + output_min=-1.0, + output_max=1.0, + output_calibration=True, + output_calibration_num_keypoints=5, + output_initialization=[-1.0, 1.0]) + model = premade.CalibratedLatticeEnsemble(model_config) + # Compile and fit model. + model.compile(loss='mse', optimizer=tf.keras.optimizers.Adam(0.1)) + model.fit(fake_data['train_xs'], fake_data['train_ys']) + # Save model using H5 format. + with tempfile.NamedTemporaryFile(suffix='.h5') as f: + tf.keras.models.save_model(model, f.name) + loaded_model = tf.keras.models.load_model( + f.name, custom_objects=premade.get_custom_objects()) + self.assertAllClose( + model.predict(fake_data['eval_xs']), + loaded_model.predict(fake_data['eval_xs'])) + def testLatticeH5FormatSaveLoad(self): model_config = configs.CalibratedLatticeConfig( feature_configs=copy.deepcopy(feature_configs), diff --git a/tensorflow_lattice/python/pwl_calibration_layer.py b/tensorflow_lattice/python/pwl_calibration_layer.py index 626b94e..b8e35c3 100644 --- a/tensorflow_lattice/python/pwl_calibration_layer.py +++ b/tensorflow_lattice/python/pwl_calibration_layer.py @@ -24,6 +24,7 @@ from __future__ import print_function from . import pwl_calibration_lib +from . import utils from absl import logging import numpy as np @@ -493,8 +494,7 @@ def assert_constraints(self, eps=1e-6): asserts = pwl_calibration_lib.assert_constraints( outputs=outputs, - monotonicity=pwl_calibration_lib.canonicalize_monotonicity( - self.monotonicity), + monotonicity=utils.canonicalize_monotonicity(self.monotonicity), output_min=self.output_min, output_max=self.output_max, clamp_min=self.clamp_min, @@ -584,8 +584,7 @@ def __call__(self, shape, dtype=None, partition_info=None): shape=shape, output_min=self.output_min, output_max=self.output_max, - monotonicity=pwl_calibration_lib.canonicalize_monotonicity( - self.monotonicity), + monotonicity=utils.canonicalize_monotonicity(self.monotonicity), keypoints=self.keypoints, dtype=dtype) @@ -652,10 +651,8 @@ def __init__( self.output_max_constraints = output_max_constraints self.num_projection_iterations = num_projection_iterations - canonical_convexity = pwl_calibration_lib.canonicalize_convexity( - self.convexity) - canonical_monotonicity = pwl_calibration_lib.canonicalize_monotonicity( - self.monotonicity) + canonical_convexity = utils.canonicalize_convexity(self.convexity) + canonical_monotonicity = utils.canonicalize_monotonicity(self.monotonicity) if (canonical_convexity != 0 and canonical_monotonicity == 0 and (output_min_constraints != pwl_calibration_lib.BoundConstraintsType.NONE or output_max_constraints != @@ -670,14 +667,12 @@ def __call__(self, w): """Applies constraints to w.""" return pwl_calibration_lib.project_all_constraints( weights=w, - monotonicity=pwl_calibration_lib.canonicalize_monotonicity( - self.monotonicity), + monotonicity=utils.canonicalize_monotonicity(self.monotonicity), output_min=self.output_min, output_max=self.output_max, output_min_constraints=self.output_min_constraints, output_max_constraints=self.output_max_constraints, - convexity=pwl_calibration_lib.canonicalize_convexity( - self.convexity), + convexity=utils.canonicalize_convexity(self.convexity), lengths=self.lengths, num_projection_iterations=self.num_projection_iterations) diff --git a/tensorflow_lattice/python/pwl_calibration_lib.py b/tensorflow_lattice/python/pwl_calibration_lib.py index e859bbc..659bd22 100644 --- a/tensorflow_lattice/python/pwl_calibration_lib.py +++ b/tensorflow_lattice/python/pwl_calibration_lib.py @@ -19,12 +19,13 @@ import collections import copy -from enum import Enum +import enum +from . import utils import six import tensorflow as tf -class BoundConstraintsType(Enum): +class BoundConstraintsType(enum.Enum): """Type of bound constraints for PWL calibration. - NONE: no constraints. @@ -918,8 +919,8 @@ def verify_hyperparameters(input_keypoints=None, "They are: ({}, {})".format(output_min, output_max)) # It also raises errors if monotonicities specified incorrectly. - monotonicity = canonicalize_monotonicity(monotonicity) - convexity = canonicalize_convexity(convexity) + monotonicity = utils.canonicalize_monotonicity(monotonicity) + convexity = utils.canonicalize_convexity(convexity) if is_cyclic and (monotonicity or convexity): raise ValueError("'is_cyclic' can not be specified together with " @@ -940,59 +941,3 @@ def verify_hyperparameters(input_keypoints=None, raise ValueError("Number of lengths must be equal to number of weights " "minus one. Lengths: %s, weights_shape: %s" % (lengths, weights_shape)) - - -def canonicalize_monotonicity(monotonicity): - """Converts string constants representing monotonicity into integers. - - Args: - monotonicity: monotonicity hyperparameter of `PWLCalibration` layer. - - Raises: - ValueError if monotonicity is invalid. - - Returns: - monotonicity represented as -1, 0 or 1. - """ - if monotonicity is None: - return None - - if monotonicity in [-1, 0, 1]: - return monotonicity - elif isinstance(monotonicity, six.string_types): - if monotonicity.lower() == "decreasing": - return -1 - if monotonicity.lower() == "none": - return 0 - if monotonicity.lower() == "increasing": - return 1 - raise ValueError("'monotonicities' must be from: [-1, 0, 1, 'decreasing', " - "'none', 'increasing']. Given: %s" % monotonicity) - - -def canonicalize_convexity(convexity): - """Converts string constants representing convexity into integers. - - Args: - convexity: convexity hyperparameter of `PWLCalibration` layer. - - Raises: - ValueError if convexity is invalid. - - Returns: - convexity represented as -1, 0 or 1. - """ - if convexity is None: - return None - - if convexity in [-1, 0, 1]: - return convexity - elif isinstance(convexity, six.string_types): - if convexity.lower() == "concave": - return -1 - if convexity.lower() == "none": - return 0 - if convexity.lower() == "convex": - return 1 - raise ValueError("'convexity' must be from: [-1, 0, 1, 'concave', " - "'none', 'convex']. Given: %s" % convexity) diff --git a/tensorflow_lattice/python/pwl_calibration_sonnet_module.py b/tensorflow_lattice/python/pwl_calibration_sonnet_module.py index fbc044a..c3ca32f 100644 --- a/tensorflow_lattice/python/pwl_calibration_sonnet_module.py +++ b/tensorflow_lattice/python/pwl_calibration_sonnet_module.py @@ -26,6 +26,7 @@ from __future__ import print_function from . import pwl_calibration_lib +from . import utils from absl import logging import numpy as np @@ -312,8 +313,8 @@ def __call__(self, inputs): if len(inputs.shape) != 2 or (inputs.shape[1] != self.units and inputs.shape[1] != 1): raise ValueError("Shape of input tensor for PWLCalibration module must" - " be [-1, units] or [-1, 1]. It is: " - + str(inputs.shape)) + " be [-1, units] or [-1, 1]. It is: " + + str(inputs.shape)) if self._interpolation_keypoints.dtype != inputs.dtype: raise ValueError("dtype(%s) of input to PWLCalibration module does not " @@ -335,8 +336,7 @@ def __call__(self, inputs): # Need to add such last height to make all heights to sum up to 0.0 in # order to make calibrator cyclic. bias_and_heights = tf.concat( - [self.kernel, -tf.reduce_sum(self.kernel[1:], - axis=0, keepdims=True)], + [self.kernel, -tf.reduce_sum(self.kernel[1:], axis=0, keepdims=True)], axis=0) else: bias_and_heights = self.kernel @@ -424,8 +424,7 @@ def __call__(self, shape, dtype): shape=shape, output_min=self.output_min, output_max=self.output_max, - monotonicity=pwl_calibration_lib.canonicalize_monotonicity( - self.monotonicity), + monotonicity=utils.canonicalize_monotonicity(self.monotonicity), keypoints=self.keypoints, dtype=dtype) @@ -483,10 +482,8 @@ def __init__( self.output_max_constraints = output_max_constraints self.num_projection_iterations = num_projection_iterations - canonical_convexity = pwl_calibration_lib.canonicalize_convexity( - self.convexity) - canonical_monotonicity = pwl_calibration_lib.canonicalize_monotonicity( - self.monotonicity) + canonical_convexity = utils.canonicalize_convexity(self.convexity) + canonical_monotonicity = utils.canonicalize_monotonicity(self.monotonicity) if (canonical_convexity != 0 and canonical_monotonicity == 0 and (output_min_constraints != pwl_calibration_lib.BoundConstraintsType.NONE or output_max_constraints != @@ -501,14 +498,12 @@ def __call__(self, w): """Applies constraints to w.""" return pwl_calibration_lib.project_all_constraints( weights=w, - monotonicity=pwl_calibration_lib.canonicalize_monotonicity( - self.monotonicity), + monotonicity=utils.canonicalize_monotonicity(self.monotonicity), output_min=self.output_min, output_max=self.output_max, output_min_constraints=self.output_min_constraints, output_max_constraints=self.output_max_constraints, - convexity=pwl_calibration_lib.canonicalize_convexity( - self.convexity), + convexity=utils.canonicalize_convexity(self.convexity), lengths=self.lengths, num_projection_iterations=self.num_projection_iterations) diff --git a/tensorflow_lattice/python/pwl_calibration_test.py b/tensorflow_lattice/python/pwl_calibration_test.py index 8e6d70c..bd3d6ae 100644 --- a/tensorflow_lattice/python/pwl_calibration_test.py +++ b/tensorflow_lattice/python/pwl_calibration_test.py @@ -33,9 +33,10 @@ from tensorflow import keras from tensorflow_lattice.python import parallel_combination_layer as parallel_combination from tensorflow_lattice.python import pwl_calibration_layer as keras_layer -from tensorflow_lattice.python import pwl_calibration_sonnet_module as sonnet_module from tensorflow_lattice.python import pwl_calibration_lib as pwl_lib +from tensorflow_lattice.python import pwl_calibration_sonnet_module as sonnet_module from tensorflow_lattice.python import test_utils +from tensorflow_lattice.python import utils class CalibrateWithSeparateMissing(tf.keras.layers.Layer): @@ -51,8 +52,8 @@ def __init__(self, calibration_layer, missing_input_value): self.missing_input_value = missing_input_value def call(self, x): - is_missing = tf.cast(tf.equal(x, self.missing_input_value), - dtype=tf.float32) + is_missing = tf.cast( + tf.equal(x, self.missing_input_value), dtype=tf.float32) return self.calibration_layer([x, is_missing]) @@ -156,10 +157,12 @@ def _SetDefaults(self, config): # If "input_keypoints" are provided - other params referred by code below # might be not available, so we make sure it exists before executing # this code. - config.setdefault("input_keypoints", - np.linspace(start=config["input_min"], - stop=config["input_max"], - num=config["num_keypoints"])) + config.setdefault( + "input_keypoints", + np.linspace( + start=config["input_min"], + stop=config["input_max"], + num=config["num_keypoints"])) return config def _TrainModel(self, config, plot_path=None): @@ -229,17 +232,19 @@ def _TrainModel(self, config, plot_path=None): num_projection_iterations=config["num_projection_iterations"])) if len(calibration_layers) == 1: if config["use_separate_missing"]: - model.add(CalibrateWithSeparateMissing( - calibration_layer=calibration_layers[0], - missing_input_value=config["missing_input_value"])) + model.add( + CalibrateWithSeparateMissing( + calibration_layer=calibration_layers[0], + missing_input_value=config["missing_input_value"])) else: model.add(calibration_layers[0]) else: model.add(parallel_combination.ParallelCombination(calibration_layers)) if config["units"] > 1: - model.add(keras.layers.Lambda( - lambda x: tf.reduce_mean(x, axis=1, keepdims=True))) + model.add( + keras.layers.Lambda( + lambda x: tf.reduce_mean(x, axis=1, keepdims=True))) model.compile( loss=keras.losses.mean_squared_error, @@ -278,9 +283,9 @@ def _InverseAndTrain(self, config): inversed_config["clamp_min"] = config["clamp_max"] inversed_config["clamp_max"] = config["clamp_min"] - inversed_config["monotonicity"] = -pwl_lib.canonicalize_monotonicity( + inversed_config["monotonicity"] = -utils.canonicalize_monotonicity( config["monotonicity"]) - inversed_config["convexity"] = -pwl_lib.canonicalize_convexity( + inversed_config["convexity"] = -utils.canonicalize_convexity( config["convexity"]) inversed_loss = self._TrainModel(inversed_config) return inversed_loss @@ -304,7 +309,7 @@ def _CreateKerasLayer(self, config): # We use 'config["missing_input_value"]' to create the is_missing tensor, # and we want the model to use the is_missing tensor so we don't pass # a missing_input_value to the model. - missing_input_value=None + missing_input_value = None return keras_layer.PWLCalibration( input_keypoints=config["input_keypoints"], units=config["units"], @@ -329,7 +334,7 @@ def _CreateSonnetModule(self, config): # We use 'config["missing_input_value"]' to create the is_missing tensor, # and we want the model to use the is_missing tensor so we don't pass # a missing_input_value to the model. - missing_input_value=None + missing_input_value = None return sonnet_module.PWLCalibration( input_keypoints=config["input_keypoints"], units=config["units"], @@ -1111,14 +1116,14 @@ def testConvexityDifferentNumKeypoints(self, units, num_keypoints, @parameterized.parameters( (1, "increasing", None, 0.055837), (1, "decreasing", None, 0.046657), - (1, "none", 0.0, 0.027777), - (1, "increasing", 0.0, 0.065516), - (1, "decreasing", 0.0, 0.057453), + (1, "none", 0.0, 0.027777), + (1, "increasing", 0.0, 0.065516), + (1, "decreasing", 0.0, 0.057453), (3, "increasing", None, 0.022467), (3, "decreasing", None, 0.019012), - (3, "none", 0.0, 0.014693), - (3, "increasing", 0.0, 0.026284), - (3, "decreasing", 0.0, 0.025498), + (3, "none", 0.0, 0.014693), + (3, "increasing", 0.0, 0.026284), + (3, "decreasing", 0.0, 0.025498), ) def testConvexityWithMonotonicityAndBounds(self, units, monotonicity, output_max, expected_loss): diff --git a/tensorflow_lattice/python/rtl_layer.py b/tensorflow_lattice/python/rtl_layer.py index 1ffc5fe..4eed46a 100644 --- a/tensorflow_lattice/python/rtl_layer.py +++ b/tensorflow_lattice/python/rtl_layer.py @@ -29,15 +29,20 @@ import itertools from . import lattice_layer +from . import rtl_lib from absl import logging import numpy as np +import six import tensorflow as tf from tensorflow import keras _MAX_RTL_SWAPS = 10000 _RTLInput = collections.namedtuple('_RTLInput', ['monotonicity', 'group', 'input_index']) +RTL_LATTICE_NAME = 'rtl_lattice' +INPUTS_FOR_UNITS_PREFIX = 'inputs_for_lattice' +RTL_CONCAT_NAME = 'rtl_concat' class RTL(keras.layers.Layer): @@ -180,19 +185,28 @@ def __init__(self, `tfl.lattice_layer.RandomMonotonicInitializer` class docstring for more details. kernel_regularizer: None or a single element or a list of following: - - Tuple `('torsion', l1, l2)` where l1 and l2 represent corresponding - regularization amount for graph Torsion regularizer. l1 and l2 can - either be single floats or lists of floats to specify different - regularization amount for every dimension. - - Tuple `('laplacian', l1, l2)` where l1 and l2 represent corresponding - regularization amount for graph Laplacian regularizer. l1 and l2 can - either be single floats or lists of floats to specify different - regularization amount for every dimension. + - Tuple `('torsion', l1, l2)` or List `['torsion', l1, l2]` where l1 and + l2 represent corresponding regularization amount for graph Torsion + regularizer. l1 and l2 must be single floats. Lists of floats to + specify different regularization amount for every dimension is not + currently supported. + - Tuple `('laplacian', l1, l2)` or List `['laplacian', l1, l2]` where l1 + and l2 represent corresponding regularization amount for graph + Laplacian regularizer. l1 and l2 must be single floats. Lists of + floats to specify different regularization amount for every dimension + is not currently supported. **kwargs: Other args passed to `tf.keras.layers.Layer` initializer. Raises: ValueError: If layer hyperparameters are invalid. """ + # pyformat: enable + rtl_lib.verify_hyperparameters( + lattice_size=lattice_size, + output_min=output_min, + output_max=output_max, + kernel_regularizer=kernel_regularizer, + interpolation=interpolation) super(RTL, self).__init__(**kwargs) self.num_lattices = num_lattices self.lattice_rank = lattice_rank @@ -211,24 +225,48 @@ def __init__(self, def build(self, input_shape): """Standard Keras build() method.""" + rtl_lib.verify_hyperparameters( + lattice_size=self.lattice_size, input_shape=input_shape) + # Convert kernel regularizers to proper form (tuples). + kernel_regularizer = self.kernel_regularizer + if isinstance(self.kernel_regularizer, list): + if isinstance(self.kernel_regularizer[0], six.string_types): + kernel_regularizer = tuple(self.kernel_regularizer) + else: + kernel_regularizer = [tuple(r) for r in self.kernel_regularizer] self._rtl_structure = self._get_rtl_structure(input_shape) # dict from monotonicities to the lattice layers with those monotonicities. self._lattice_layers = {} for monotonicities, inputs_for_units in self._rtl_structure: - units = len(inputs_for_units) - self._lattice_layers[str(monotonicities)] = lattice_layer.Lattice( - lattice_sizes=[self.lattice_size] * self.lattice_rank, - units=units, - monotonicities=monotonicities, - output_min=self.output_min, - output_max=self.output_max, - num_projection_iterations=self.num_projection_iterations, - monotonic_at_every_step=self.monotonic_at_every_step, - clip_inputs=self.clip_inputs, - interpolation=self.interpolation, - kernel_initializer=self.kernel_initializer, - kernel_regularizer=self.kernel_regularizer, - ) + monotonicities_str = ''.join( + [str(monotonicity) for monotonicity in monotonicities]) + # Passthrough names for reconstructing model graph. + inputs_for_units_name = '{}_{}'.format(INPUTS_FOR_UNITS_PREFIX, + monotonicities_str) + # Use control dependencies to save inputs_for_units as graph constant for + # visualisation toolbox to be able to recover it from saved graph. + # Wrap this constant into pure op since in TF 2.0 there are issues passing + # tensors into control_dependencies. + with tf.control_dependencies([ + tf.constant( + inputs_for_units, dtype=tf.int32, name=inputs_for_units_name) + ]): + units = len(inputs_for_units) + layer_name = '{}_{}'.format(RTL_LATTICE_NAME, monotonicities_str) + self._lattice_layers[str(monotonicities)] = lattice_layer.Lattice( + lattice_sizes=[self.lattice_size] * self.lattice_rank, + units=units, + monotonicities=monotonicities, + output_min=self.output_min, + output_max=self.output_max, + num_projection_iterations=self.num_projection_iterations, + monotonic_at_every_step=self.monotonic_at_every_step, + clip_inputs=self.clip_inputs, + interpolation=self.interpolation, + kernel_initializer=self.kernel_initializer, + kernel_regularizer=kernel_regularizer, + name=layer_name, + ) super(RTL, self).build(input_shape) def call(self, x, **kwargs): @@ -248,7 +286,7 @@ def call(self, x, **kwargs): if len(input_tensors) == 1: flattened_input = input_tensors[0] else: - flattened_input = tf.concat(input_tensors, axis=1) + flattened_input = tf.concat(input_tensors, axis=1, name=RTL_CONCAT_NAME) # outputs_for_monotonicity[0] are non-monotonic outputs # outputs_for_monotonicity[1] are monotonic outputs @@ -393,8 +431,11 @@ def _get_rtl_structure(self, input_shape): for shape in shapes: for _ in range(shape[1]): - rtl_inputs.append(_RTLInput( - monotonicity=monotonicity, group=group, input_index=input_index)) + rtl_inputs.append( + _RTLInput( + monotonicity=monotonicity, + group=group, + input_index=input_index)) input_index += 1 group += 1 diff --git a/tensorflow_lattice/python/rtl_lib.py b/tensorflow_lattice/python/rtl_lib.py new file mode 100644 index 0000000..4aa0842 --- /dev/null +++ b/tensorflow_lattice/python/rtl_lib.py @@ -0,0 +1,90 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Implementation of algorithms required for RTL layer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import six + + +def verify_hyperparameters(lattice_size, + input_shape=None, + output_min=None, + output_max=None, + kernel_regularizer=None, + interpolation="hypercube"): + """Verifies that all given hyperparameters are consistent. + + See `tfl.layers.RTL` class level comment for detailed description of + arguments. + + Args: + lattice_size: Lattice size to check againts. + input_shape: Shape of layer input. + output_min: Minimum output of `RTL` layer. + output_max: Maximum output of `RTL` layer. + kernel_regularizer: Regularizers to check against. + interpolation: One of 'simplex' or 'hypercube' interpolation. + + Raises: + ValueError: If lattice_size < 2. + KeyError: If input_shape is a dict with incorrect keys. + ValueError: If output_min >= output_max. + ValueError: If kernel_regularizer contains a tuple with len != 3. + ValueError: If kernel_regularizer contains a tuple with non-float l1 value. + ValueError: If kernel_regularizer contains a tuple with non-flaot l2 value. + ValueError: If interpolation is not one of 'simplex' or 'hypercube'. + + """ + if lattice_size < 2: + raise ValueError( + "Lattice size must be at least 2. Given: {}".format(lattice_size)) + + if input_shape: + if isinstance(input_shape, dict): + for key in input_shape: + if key not in ["unconstrained", "increasing"]: + raise KeyError("Input shape keys should be either 'unconstrained' " + "or 'increasing', but seeing: {}".format(key)) + + if output_min is not None and output_max is not None: + if output_min >= output_max: + raise ValueError("'output_min' must be not greater than 'output_max'. " + "'output_min': %f, 'output_max': %f" % + (output_min, output_max)) + + if kernel_regularizer: + if isinstance(kernel_regularizer, list): + regularizers = kernel_regularizer + if isinstance(kernel_regularizer[0], six.string_types): + regularizers = [kernel_regularizer] + for regularizer in regularizers: + if len(regularizer) != 3: + raise ValueError("Regularizer tuples/lists must have three elements " + "(type, l1, and l2). Given: {}".format(regularizer)) + _, l1, l2 = regularizer + if not isinstance(l1, float): + raise ValueError( + "Regularizer l1 must be a single float. Given: {}".format( + type(l1))) + if not isinstance(l2, float): + raise ValueError( + "Regularizer l2 must be a single float. Given: {}".format( + type(l2))) + + if interpolation not in ["hypercube", "simplex"]: + raise ValueError("RTL interpolation type should be either 'simplex' " + "or 'hypercube': %s" % interpolation) diff --git a/tensorflow_lattice/python/utils.py b/tensorflow_lattice/python/utils.py new file mode 100644 index 0000000..aeef9d5 --- /dev/null +++ b/tensorflow_lattice/python/utils.py @@ -0,0 +1,242 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Helpers shared by multiple modules in TFL.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import six + + +# TODO: update library not to explicitly check if None so we can return +# an empty list instead of None for these canonicalization methods. +def canonicalize_convexity(convexity): + """Converts string constants representing convexity into integers. + + Args: + convexity: The convexity hyperparameter of `tfl.layers.PWLCalibration` + layer. + + Returns: + convexity represented as -1, 0, 1, or None. + + Raises: + ValueError: If convexity is not in the set + {-1, 0, 1, 'concave', 'none', 'convex'}. + """ + if convexity is None: + return None + + if convexity in [-1, 0, 1]: + return convexity + elif isinstance(convexity, six.string_types): + if convexity.lower() == "concave": + return -1 + if convexity.lower() == "none": + return 0 + if convexity.lower() == "convex": + return 1 + raise ValueError("'convexity' must be from: [-1, 0, 1, 'concave', " + "'none', 'convex']. Given: {}".format(convexity)) + + +def canonicalize_input_bounds(input_bounds): + """Converts string constant 'none' representing unspecified bound into None. + + Args: + input_bounds: The input_min or input_max hyperparameter of + `tfl.layers.Linear` layer. + + Returns: + A list of [val, val, ...] where val can be a float or None, or the value + None if input_bounds is None. + + Raises: + ValueError: If one of elements in input_bounds is not a float, None or + 'none'. + """ + if input_bounds: + canonicalized = [] + for item in input_bounds: + if isinstance(item, float) or item is None: + canonicalized.append(item) + elif isinstance(item, six.string_types) and item.lower() == "none": + canonicalized.append(None) + else: + raise ValueError("Both 'input_min' and 'input_max' elements must be " + "either int, float, None, or 'none'. Given: {}".format( + input_bounds)) + return canonicalized + return None + + +def canonicalize_monotonicity(monotonicity, allow_decreasing=True): + """Converts string constants representing monotonicity into integers. + + Args: + monotonicity: The monotonicities hyperparameter of a `tfl.layers` Layer + (e.g. `tfl.layers.PWLCalibration`). + allow_decreasing: If decreasing monotonicity is considered a valid + monotonicity. + + Returns: + monotonicity represented as -1, 0, 1, or None. + + Raises: + ValueError: If monotonicity is not in the set + {-1, 0, 1, 'decreasing', 'none', 'increasing'} and allow_decreasing is + True. + ValueError: If monotonicity is not in the set {0, 1, 'none', 'increasing'} + and allow_decreasing is False. + """ + if monotonicity is None: + return None + + if monotonicity in [-1, 0, 1]: + if not allow_decreasing and monotonicity == -1: + raise ValueError( + "'monotonicities' must be from: [0, 1, 'none', 'increasing']. " + "Given: {}".format(monotonicity)) + return monotonicity + elif isinstance(monotonicity, six.string_types): + if monotonicity.lower() == "decreasing": + if not allow_decreasing: + raise ValueError( + "'monotonicities' must be from: [0, 1, 'none', 'increasing']. " + "Given: {}".format(monotonicity)) + return -1 + if monotonicity.lower() == "none": + return 0 + if monotonicity.lower() == "increasing": + return 1 + raise ValueError("'monotonicities' must be from: [-1, 0, 1, 'decreasing', " + "'none', 'increasing']. Given: {}".format(monotonicity)) + + +def canonicalize_monotonicities(monotonicities, allow_decreasing=True): + """Converts string constants representing monotonicities into integers. + + Args: + monotonicities: monotonicities hyperparameter of a `tfl.layers` Layer (e.g. + `tfl.layers.Lattice`). + allow_decreasing: If decreasing monotonicity is considered a valid + monotonicity. + + Returns: + A list of monotonicities represented as -1, 0, 1, or the value None + if monotonicities is None. + + Raises: + ValueError: If one of monotonicities is not in the set + {-1, 0, 1, 'decreasing', 'none', 'increasing'} and allow_decreasing is + True. + ValueError: If one of monotonicities is not in the set + {0, 1, 'none', 'increasing'} and allow_decreasing is False. + """ + if monotonicities: + return [ + canonicalize_monotonicity( + monotonicity, allow_decreasing=allow_decreasing) + for monotonicity in monotonicities + ] + return None + + +def canonicalize_trust(trusts): + """Converts string constants representing trust direction into integers. + + Args: + trusts: edgeworth_trusts or trapezoid_trusts hyperparameter of + `tfl.layers.Lattice` layer. + + Returns: + A list of trust constraint tuples of the form + (feature_a, feature_b, direction) where direction can be -1 or 1, or the + value None if trusts is None. + + Raises: + ValueError: If one of trust constraints does not have 3 elements. + ValueError: If one of trust constraints' direction is not in the set + {-1, 1, 'negative', 'positive'}. + """ + if trusts: + canonicalized = [] + for trust in trusts: + if len(trust) != 3: + raise ValueError("Trust constraints must consist of 3 elements. Seeing " + "constraint tuple {}".format(trust)) + feature_a, feature_b, direction = trust + if direction in [-1, 1]: + canonicalized.append(trust) + elif (isinstance(direction, six.string_types) and + direction.lower() == "negative"): + canonicalized.append((feature_a, feature_b, -1)) + elif (isinstance(direction, six.string_types) and + direction.lower() == "positive"): + canonicalized.append((feature_a, feature_b, 1)) + else: + raise ValueError("trust constraint direction must be from: [-1, 1, " + "'negative', 'positive']. Given: {}".format(direction)) + return canonicalized + return None + + +def canonicalize_unimodalities(unimodalities): + """Converts string constants representing unimodalities into integers. + + Args: + unimodalities: unimodalities hyperparameter of `tfl.layers.Lattice` layer. + + Returns: + A list of unimodalities represented as -1, 0, 1, or the value None if + unimodalities is None. + + Raises: + ValueError: If one of unimodalities is not in the set + {-1, 0, 1, 'peak', 'none', 'valley'}. + """ + if not unimodalities: + return None + canonicalized = [] + for unimodality in unimodalities: + if unimodality in [-1, 0, 1]: + canonicalized.append(unimodality) + elif isinstance(unimodality, + six.string_types) and unimodality.lower() == "peak": + canonicalized.append(-1) + elif isinstance(unimodality, + six.string_types) and unimodality.lower() == "none": + canonicalized.append(0) + elif isinstance(unimodality, + six.string_types) and unimodality.lower() == "valley": + canonicalized.append(1) + else: + raise ValueError( + "'unimodalities' elements must be from: [-1, 0, 1, 'peak', 'none', " + "'valley']. Given: {}".format(unimodalities)) + return canonicalized + + +def count_non_zeros(*iterables): + """Returns total number of non 0 elements in given iterables. + + Args: + *iterables: Any number of the value None or iterables of numeric values. + """ + result = 0 + for iterable in iterables: + if iterable is not None: + result += sum(1 for element in iterable if element != 0) + return result diff --git a/tensorflow_lattice/python/utils_test.py b/tensorflow_lattice/python/utils_test.py new file mode 100644 index 0000000..7a2d067 --- /dev/null +++ b/tensorflow_lattice/python/utils_test.py @@ -0,0 +1,177 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for Tensorflow Lattice utility functions.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized +import tensorflow as tf +from tensorflow_lattice.python import utils + + +class UtilsTest(parameterized.TestCase, tf.test.TestCase): + + @parameterized.parameters((-1, -1), (0, 0), (1, 1), ("concave", -1), + ("none", 0), ("convex", 1)) + def testCanonicalizeConvexity(self, convexity, + expected_canonicalized_convexity): + canonicalized_convexity = utils.canonicalize_convexity(convexity) + self.assertEqual(canonicalized_convexity, expected_canonicalized_convexity) + + @parameterized.parameters((-2), (0.5), (3), ("invalid_convexity"), + ("concaves"), ("nonw"), ("conve")) + def testInvalidConvexity(self, invalid_convexity): + error_message = ( + "'convexity' must be from: [-1, 0, 1, 'concave', 'none', 'convex']. " + "Given: {}").format(invalid_convexity) + with self.assertRaisesWithLiteralMatch(ValueError, error_message): + utils.canonicalize_convexity(invalid_convexity) + + # Note: must use mapping format because otherwise input parameter list is + # considered multiple parameters (not just a single list parameter). + @parameterized.parameters( + { + "input_bounds": [0.0, -3.0], + "expected_canonicalized_input_bounds": [0.0, -3.0] + }, { + "input_bounds": [float("-inf"), 0.12345], + "expected_canonicalized_input_bounds": [float("-inf"), 0.12345] + }, { + "input_bounds": ["none", None], + "expected_canonicalized_input_bounds": [None, None] + }) + def testCanonicalizeInputBounds(self, input_bounds, + expected_canonicalized_input_bounds): + canonicalized_input_bounds = utils.canonicalize_input_bounds(input_bounds) + self.assertAllEqual(canonicalized_input_bounds, + expected_canonicalized_input_bounds) + + @parameterized.parameters({"invalid_input_bounds": [0, 1.0, 2.0]}, + {"invalid_input_bounds": [None, "nonw"]}) + def testInvalidInputBounds(self, invalid_input_bounds): + error_message = ( + "Both 'input_min' and 'input_max' elements must be either int, float, " + "None, or 'none'. Given: {}").format(invalid_input_bounds) + with self.assertRaisesWithLiteralMatch(ValueError, error_message): + utils.canonicalize_input_bounds(invalid_input_bounds) + + @parameterized.parameters((-1, -1), (0, 0), (1, 1), ("decreasing", -1), + ("none", 0), ("increasing", 1)) + def testCanonicalizeMonotonicity(self, monotonicity, + expected_canonicalized_monotonicity): + canonicalized_monotonicity = utils.canonicalize_monotonicity(monotonicity) + self.assertEqual(canonicalized_monotonicity, + expected_canonicalized_monotonicity) + + @parameterized.parameters((-2), (0.5), (3), ("invalid_monotonicity"), + ("decrease"), ("increase")) + def testInvalidMonotonicity(self, invalid_monotonicity): + error_message = ( + "'monotonicities' must be from: [-1, 0, 1, 'decreasing', 'none', " + "'increasing']. Given: {}").format(invalid_monotonicity) + with self.assertRaisesWithLiteralMatch(ValueError, error_message): + utils.canonicalize_monotonicity(invalid_monotonicity) + + @parameterized.parameters(("decreasing"), (-1)) + def testInvalidDecreasingMonotonicity(self, invalid_monotonicity): + error_message = ( + "'monotonicities' must be from: [0, 1, 'none', 'increasing']. " + "Given: {}").format(invalid_monotonicity) + with self.assertRaisesWithLiteralMatch(ValueError, error_message): + utils.canonicalize_monotonicity( + invalid_monotonicity, allow_decreasing=False) + + # Note: since canonicalize_monotonicities calls canonicalize_monotonicity, + # the above test for invalidity is sufficient. + @parameterized.parameters(([-1, 0, 1], [-1, 0, 1]), + (["decreasing", "none", "increasing"], [-1, 0, 1]), + (["decreasing", -1], [-1, -1]), + (["none", 0], [0, 0]), (["increasing", 1], [1, 1])) + def testCanonicalizeMonotonicities(self, monotonicities, + expected_canonicalized_monotonicities): + canonicalized_monotonicities = utils.canonicalize_monotonicities( + monotonicities) + self.assertAllEqual(canonicalized_monotonicities, + expected_canonicalized_monotonicities) + + @parameterized.parameters(([("a", "b", -1), ("b", "c", 1)], [("a", "b", -1), + ("b", "c", 1)]), + ([("a", "b", "negative"), + ("b", "c", "positive")], [("a", "b", -1), + ("b", "c", 1)])) + def testCanonicalizeTrust(self, trusts, expected_canonicalized_trusts): + canonicalized_trusts = utils.canonicalize_trust(trusts) + self.assertAllEqual(canonicalized_trusts, expected_canonicalized_trusts) + + # Note 1: this test assumes the first trust in the list has the incorrect + # direction. A list with a single trust tuple is sufficient. + # Note 2: must use mapping format because otherwise input parameter list is + # considered multiple parameters (not just a single list parameter). + @parameterized.parameters({"invalid_trusts": [("a", "b", 0)]}, + {"invalid_trusts": [("a", "b", "negativ")]}) + def testInvalidTrustDirection(self, invalid_trusts): + error_message = ( + "trust constraint direction must be from: [-1, 1, 'negative', " + "'positive']. Given: {}").format(invalid_trusts[0][2]) + with self.assertRaisesWithLiteralMatch(ValueError, error_message): + utils.canonicalize_trust(invalid_trusts) + + # Note 1: this test assumes the first trust in the list has the incorrect + # size. A list with a single trust tuple is sufficient. + # Note 2: must use mapping format because otherwise input parameter list is + # considered multiple parameters (not just a single list parameter). + @parameterized.parameters({"invalid_trusts": [("a", 1)]}, + {"invalid_trusts": [("a", "b", -1, 1)]}) + def testInvalidTrustLength(self, invalid_trusts): + error_message = ( + "Trust constraints must consist of 3 elements. Seeing constraint " + "tuple {}").format(invalid_trusts[0]) + with self.assertRaisesWithLiteralMatch(ValueError, error_message): + utils.canonicalize_trust(invalid_trusts) + + @parameterized.parameters(([0, 1, 1, 0], [1, 0], 3), + ([0, 0, 0], [0, 0, 0], 0), + ([-1, 0, 0, 1], [0, 0], 2), + (None, [1, 1, 1, 1, 1], 5)) + def testCountNonZeros(self, monotonicities, unimodalities, + expected_non_zeros): + non_zeros = utils.count_non_zeros(monotonicities, unimodalities) + self.assertEqual(non_zeros, expected_non_zeros) + + @parameterized.parameters( + ([-1, 0, 1], [-1, 0, 1]), (["peak", "none", "valley"], [-1, 0, 1]), + (["peak", -1], [-1, -1]), (["none", 0], [0, 0]), (["valley", 1], [1, 1])) + def testCanonicalizeUnimodalities(self, unimodalities, + expected_canonicalized_unimodalities): + canonicalized_unimodalities = utils.canonicalize_unimodalities( + unimodalities) + self.assertAllEqual(canonicalized_unimodalities, + expected_canonicalized_unimodalities) + + # Note: must use mapping format because otherwise input parameter list is + # considered multiple parameters (not just a single list parameter). + @parameterized.parameters({"invalid_unimodalities": ["vally", 0]}, + {"invalid_unimodalities": [-1, 0, 2]}) + def testInvalidUnimoadlities(self, invalid_unimodalities): + error_message = ( + "'unimodalities' elements must be from: [-1, 0, 1, 'peak', 'none', " + "'valley']. Given: {}").format(invalid_unimodalities) + with self.assertRaisesWithLiteralMatch(ValueError, error_message): + utils.canonicalize_unimodalities(invalid_unimodalities) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow_lattice/python/visualization.py b/tensorflow_lattice/python/visualization.py index 0763fc6..969452e 100644 --- a/tensorflow_lattice/python/visualization.py +++ b/tensorflow_lattice/python/visualization.py @@ -503,7 +503,6 @@ def plot_outputs(inputs, outputs_map, file_path=None, figsize=(20, 20)): Pyplot object containing visualisation. """ legend = [] - plt.cla() if isinstance(inputs, tuple): figure = plt.figure(figsize=figsize) axes = figure.gca(projection='3d')