From 1c75176947730de8322acf6ad996096625e92e3a Mon Sep 17 00:00:00 2001 From: TensorFlow Lattice Authors Date: Mon, 15 Jun 2020 15:41:53 -0700 Subject: [PATCH] Internal change PiperOrigin-RevId: 316560241 Change-Id: I04a3766631c4193d850c7ad829efb241bf0c526a --- docs/_book.yaml | 10 +- .../tutorials/aggregate_function_models.ipynb | 614 ++++++++++++++++++ docs/tutorials/premade_models.ipynb | 9 +- setup.py | 3 +- tensorflow_lattice/BUILD | 2 + tensorflow_lattice/__init__.py | 2 + tensorflow_lattice/python/BUILD | 18 +- .../python/categorical_calibration_layer.py | 4 +- tensorflow_lattice/python/configs.py | 31 + tensorflow_lattice/python/estimators.py | 45 +- tensorflow_lattice/python/lattice_layer.py | 223 +++++-- tensorflow_lattice/python/lattice_lib.py | 188 +++++- tensorflow_lattice/python/lattice_test.py | 252 ++++++- tensorflow_lattice/python/linear_layer.py | 4 +- tensorflow_lattice/python/premade.py | 50 +- tensorflow_lattice/python/premade_lib.py | 53 +- tensorflow_lattice/python/premade_test.py | 151 ++++- .../python/pwl_calibration_layer.py | 8 +- .../python/pwl_calibration_lib.py | 4 +- .../python/pwl_calibration_sonnet_module.py | 543 ++++++++++++++++ .../python/pwl_calibration_test.py | 234 ++++++- tensorflow_lattice/python/rtl_layer.py | 78 ++- tensorflow_lattice/python/test_utils.py | 94 +++ tensorflow_lattice/sonnet_modules/__init__.py | 17 + 24 files changed, 2470 insertions(+), 167 deletions(-) create mode 100644 docs/tutorials/aggregate_function_models.ipynb create mode 100644 tensorflow_lattice/python/pwl_calibration_sonnet_module.py create mode 100644 tensorflow_lattice/sonnet_modules/__init__.py diff --git a/docs/_book.yaml b/docs/_book.yaml index a029800..686b055 100644 --- a/docs/_book.yaml +++ b/docs/_book.yaml @@ -18,18 +18,20 @@ upper_tabs: - title: Install path: /lattice/install - heading: Tutorials - - title: Shape constraints + - title: Shape Constraints path: /lattice/tutorials/shape_constraints - - title: Ethical constraints for ML fairness + - title: Ethical Constraints for ML Fairness path: /lattice/tutorials/shape_constraints_for_ethics - - title: Keras layers + - title: Keras Layers and Custom Models path: /lattice/tutorials/keras_layers - - title: Keras premade models + - title: Keras Premade Models path: /lattice/tutorials/premade_models - title: Canned Estimators path: /lattice/tutorials/canned_estimators - title: Custom Estimators path: /lattice/tutorials/custom_estimators + - title: Aggregate Function Models + path: /lattice/tutorials/aggregate_function_models - name: API skip_translation: true diff --git a/docs/tutorials/aggregate_function_models.ipynb b/docs/tutorials/aggregate_function_models.ipynb new file mode 100644 index 0000000..dc5b2e9 --- /dev/null +++ b/docs/tutorials/aggregate_function_models.ipynb @@ -0,0 +1,614 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "RYmPh1qB_KO2" + }, + "source": [ + "##### Copyright 2020 The TensorFlow Authors." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "oMRm3czy9tLh" + }, + "outputs": [], + "source": [ + "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "#\n", + "# https://www.apache.org/licenses/LICENSE-2.0\n", + "#\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "ooXoR4kx_YL9" + }, + "source": [ + "# TF Lattice Aggregate Function Models" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "BR6XNYEXEgSU" + }, + "source": [ + "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n", + " \u003ctd\u003e\n", + " \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/lattice/tutorials/aggregate_function_learning_models\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n", + " \u003c/td\u003e\n", + " \u003ctd\u003e\n", + " \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/lattice/blob/master/docs/tutorials/aggregate_function_learning_models.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n", + " \u003c/td\u003e\n", + " \u003ctd\u003e\n", + " \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/lattice/blob/master/docs/tutorials/aggregate_function_learning_models.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n", + " \u003c/td\u003e\n", + " \u003ctd\u003e\n", + " \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/lattice/docs/tutorials/aggregate_function_learning_models.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n", + " \u003c/td\u003e\n", + "\u003c/table\u003e" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "-ZfQWUmfEsyZ" + }, + "source": [ + "## Overview\n", + "\n", + "TFL Premade Aggregate Function Models are quick and easy ways to build TFL `tf.keras.model` instances for learning complex aggregation functions. This guide outlines the steps needed to construct a TFL Premade Aggregate Function Model and train/test it. 
" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "L0lgWoB6Gmk1" + }, + "source": [ + "## Setup\n", + "\n", + "Installing TF Lattice package:" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "ivwKrEdLGphZ" + }, + "outputs": [], + "source": [ + "#@test {\"skip\": true}\n", + "!pip install tensorflow-lattice pydot" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "VQsRKS4wGrMu" + }, + "source": [ + "Importing required packages:" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "j41-kd4MGtDS" + }, + "outputs": [], + "source": [ + "import tensorflow as tf\n", + "\n", + "import collections\n", + "import logging\n", + "import numpy as np\n", + "import pandas as pd\n", + "import sys\n", + "import tensorflow_lattice as tfl\n", + "logging.disable(sys.maxsize)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "ZHPohKjBIFG5" + }, + "source": [ + "Downloading the Puzzles dataset:" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "VjYHpw2dSfHH" + }, + "outputs": [], + "source": [ + "train_dataframe = pd.read_csv(\n", + " 'https://github.com/raw/wbakst/puzzles_data/master/train.csv')\n", + "train_dataframe.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "UOsgu3eIEur6" + }, + "outputs": [], + "source": [ + "test_dataframe = pd.read_csv(\n", + " 'https://github.com/raw/wbakst/puzzles_data/master/test.csv')\n", + "test_dataframe.head()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "XG7MPCyzVr22" + }, + "source": [ + "Extract and convert features and labels" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "bYdJicq5bBuz" + }, + "outputs": [], + "source": [ + "# Features:\n", + "# - star_rating rating out of 5 stars (1-5)\n", + "# - word_count number of words in the review\n", + "# - is_amazon 1 = reviewed on amazon; 0 = reviewed on artifact website\n", + "# - includes_photo if the review includes a photo of the puzzle\n", + "# - num_helpful number of people that found this review helpful\n", + "# - num_reviews total number of reviews for this puzzle (we construct)\n", + "#\n", + "# This ordering of feature names will be the exact same order that we construct\n", + "# our model to expect.\n", + "feature_names = [\n", + " 'star_rating', 'word_count', 'is_amazon', 'includes_photo', 'num_helpful',\n", + " 'num_reviews'\n", + "]" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "kx0ZX2HR-4qb" + }, + "outputs": [], + "source": [ + "def extract_features(dataframe, label_name):\n", + " # First we extract flattened features.\n", + " flattened_features = {\n", + " feature_name: dataframe[feature_name].values.astype(float)\n", + " for feature_name in feature_names[:-1]\n", + " }\n", + "\n", + " # Construct mapping from puzzle name to feature.\n", + " star_rating = collections.defaultdict(list)\n", + " word_count = collections.defaultdict(list)\n", + " is_amazon = collections.defaultdict(list)\n", + " includes_photo = collections.defaultdict(list)\n", + " num_helpful = collections.defaultdict(list)\n", + " 
labels = {}\n", + "\n", + " # Extract each review.\n", + " for i in range(len(dataframe)):\n", + " row = dataframe.iloc[i]\n", + " puzzle_name = row['puzzle_name']\n", + " star_rating[puzzle_name].append(float(row['star_rating']))\n", + " word_count[puzzle_name].append(float(row['word_count']))\n", + " is_amazon[puzzle_name].append(float(row['is_amazon']))\n", + " includes_photo[puzzle_name].append(float(row['includes_photo']))\n", + " num_helpful[puzzle_name].append(float(row['num_helpful']))\n", + " labels[puzzle_name] = float(row[label_name])\n", + "\n", + " # Organize data into list of list of features.\n", + " names = list(star_rating.keys())\n", + " star_rating = [star_rating[name] for name in names]\n", + " word_count = [word_count[name] for name in names]\n", + " is_amazon = [is_amazon[name] for name in names]\n", + " includes_photo = [includes_photo[name] for name in names]\n", + " num_helpful = [num_helpful[name] for name in names]\n", + " num_reviews = [[len(ratings)] * len(ratings) for ratings in star_rating]\n", + " labels = [labels[name] for name in names]\n", + "\n", + " # Flatten num_reviews\n", + " flattened_features['num_reviews'] = [len(reviews) for reviews in num_reviews]\n", + "\n", + " # Convert data into ragged tensors.\n", + " star_rating = tf.ragged.constant(star_rating)\n", + " word_count = tf.ragged.constant(word_count)\n", + " is_amazon = tf.ragged.constant(is_amazon)\n", + " includes_photo = tf.ragged.constant(includes_photo)\n", + " num_helpful = tf.ragged.constant(num_helpful)\n", + " num_reviews = tf.ragged.constant(num_reviews)\n", + " labels = tf.constant(labels)\n", + "\n", + " # Now we can return our extracted data.\n", + " return (star_rating, word_count, is_amazon, includes_photo, num_helpful,\n", + " num_reviews), labels, flattened_features" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "Nd6j_J5CbNiz" + }, + "outputs": [], + "source": [ + "train_xs, train_ys, flattened_features = extract_features(train_dataframe, 'Sales12-18MonthsAgo')\n", + "test_xs, test_ys, _ = extract_features(test_dataframe, 'SalesLastSixMonths')" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "KfHHhCRsHejl" + }, + "outputs": [], + "source": [ + "# Let's define our label minimum and maximum.\n", + "min_label, max_label = float(np.min(train_ys)), float(np.max(train_ys))\n", + "min_label, max_label = float(np.min(train_ys)), float(np.max(train_ys))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "9TwqlRirIhAq" + }, + "source": [ + "Setting the default values used for training in this guide:" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "GckmXFzRIhdD" + }, + "outputs": [], + "source": [ + "LEARNING_RATE = 0.1\n", + "BATCH_SIZE = 128\n", + "NUM_EPOCHS = 500\n", + "MIDDLE_DIM = 3\n", + "MIDDLE_LATTICE_SIZE = 2\n", + "MIDDLE_KEYPOINTS = 16\n", + "OUTPUT_KEYPOINTS = 8" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "TpDKon4oIh2W" + }, + "source": [ + "## Feature Configs\n", + "\n", + "Feature calibration and per-feature configurations are set using [tfl.configs.FeatureConfig](https://www.tensorflow.org/lattice/api_docs/python/tfl/configs/FeatureConfig). 
Feature configurations include monotonicity constraints, per-feature regularization (see [tfl.configs.RegularizerConfig](https://www.tensorflow.org/lattice/api_docs/python/tfl/configs/RegularizerConfig)), and lattice sizes for lattice models.\n",
+ "\n",
+ "Note that we must fully specify the feature config for any feature that we want our model to recognize. Otherwise the model will have no way of knowing that such a feature exists. For aggregation models, these features will automatically be considered and properly handled as ragged."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "_IMwcDh7Xs5n"
+ },
+ "source": [
+ "### Compute Quantiles\n",
+ "\n",
+ "Although the default setting for `pwl_calibration_input_keypoints` in `tfl.configs.FeatureConfig` is 'quantiles', for premade models we have to manually define the input keypoints. To do so, we first define our own helper function for computing quantiles."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "l0uYl9ZpXtW1"
+ },
+ "outputs": [],
+ "source": [
+ "def compute_quantiles(features,\n",
+ "                      num_keypoints=10,\n",
+ "                      clip_min=None,\n",
+ "                      clip_max=None,\n",
+ "                      missing_value=None):\n",
+ "  # Clip min and max if desired.\n",
+ "  if clip_min is not None:\n",
+ "    features = np.maximum(features, clip_min)\n",
+ "    features = np.append(features, clip_min)\n",
+ "  if clip_max is not None:\n",
+ "    features = np.minimum(features, clip_max)\n",
+ "    features = np.append(features, clip_max)\n",
+ "  # Make features unique.\n",
+ "  unique_features = np.unique(features)\n",
+ "  # Remove missing values if specified.\n",
+ "  if missing_value is not None:\n",
+ "    unique_features = np.delete(unique_features,\n",
+ "                                np.where(unique_features == missing_value))\n",
+ "  # Compute and return quantiles over unique non-missing feature values.\n",
+ "  return np.quantile(\n",
+ "      unique_features,\n",
+ "      np.linspace(0., 1., num=num_keypoints),\n",
+ "      interpolation='nearest').astype(float)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "9oYZdVeWEhf2"
+ },
+ "source": [
+ "### Defining Our Feature Configs\n",
+ "\n",
+ "Now that we can compute our quantiles, we define a feature config for each feature that we want our model to take as input."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "rEYlSXhTEmoh"
+ },
+ "outputs": [],
+ "source": [
+ "# Feature configs are used to specify how each feature is calibrated and used.\n",
+ "feature_configs = [\n",
+ "    tfl.configs.FeatureConfig(\n",
+ "        name='star_rating',\n",
+ "        lattice_size=2,\n",
+ "        monotonicity='increasing',\n",
+ "        pwl_calibration_num_keypoints=5,\n",
+ "        pwl_calibration_input_keypoints=compute_quantiles(\n",
+ "            flattened_features['star_rating'], num_keypoints=5),\n",
+ "    ),\n",
+ "    tfl.configs.FeatureConfig(\n",
+ "        name='word_count',\n",
+ "        lattice_size=2,\n",
+ "        monotonicity='increasing',\n",
+ "        pwl_calibration_num_keypoints=5,\n",
+ "        pwl_calibration_input_keypoints=compute_quantiles(\n",
+ "            flattened_features['word_count'], num_keypoints=5),\n",
+ "    ),\n",
+ "    tfl.configs.FeatureConfig(\n",
+ "        name='is_amazon',\n",
+ "        lattice_size=2,\n",
+ "        num_buckets=2,\n",
+ "    ),\n",
+ "    tfl.configs.FeatureConfig(\n",
+ "        name='includes_photo',\n",
+ "        lattice_size=2,\n",
+ "        num_buckets=2,\n",
+ "    ),\n",
+ "    tfl.configs.FeatureConfig(\n",
+ "        name='num_helpful',\n",
+ "        lattice_size=2,\n",
+ "        monotonicity='increasing',\n",
+ "        pwl_calibration_num_keypoints=5,\n",
+ "        pwl_calibration_input_keypoints=compute_quantiles(\n",
+ "            flattened_features['num_helpful'], num_keypoints=5),\n",
+ "        # Larger num_helpful indicates more trust in star_rating.\n",
+ "        reflects_trust_in=[\n",
+ "            tfl.configs.TrustConfig(\n",
+ "                feature_name=\"star_rating\", trust_type=\"trapezoid\"),\n",
+ "        ],\n",
+ "    ),\n",
+ "    tfl.configs.FeatureConfig(\n",
+ "        name='num_reviews',\n",
+ "        lattice_size=2,\n",
+ "        monotonicity='increasing',\n",
+ "        pwl_calibration_num_keypoints=5,\n",
+ "        pwl_calibration_input_keypoints=compute_quantiles(\n",
+ "            flattened_features['num_reviews'], num_keypoints=5),\n",
+ "    )\n",
+ "]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "9zoPJRBvPdcH"
+ },
+ "source": [
+ "## Aggregate Function Model\n",
+ "\n",
+ "To construct a TFL premade model, first construct a model configuration from [tfl.configs](https://www.tensorflow.org/lattice/api_docs/python/tfl/configs). An aggregate function model is constructed using the [tfl.configs.AggregateFunctionConfig](https://www.tensorflow.org/lattice/api_docs/python/tfl/configs/AggregateFunctionConfig). It applies piecewise-linear and categorical calibration, followed by a lattice model on each dimension of the ragged input. It then applies an aggregation layer over the output for each dimension. This is then followed by an optional output piecewise-linear calibration."
+ ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "l_4J7EjSPiP3" + }, + "outputs": [], + "source": [ + "# Model config defines the model structure for the aggregate function model.\n", + "aggregate_function_model_config = tfl.configs.AggregateFunctionConfig(\n", + " feature_configs=feature_configs,\n", + " middle_dimension=MIDDLE_DIM,\n", + " middle_lattice_size=MIDDLE_LATTICE_SIZE,\n", + " middle_calibration=True,\n", + " middle_calibration_num_keypoints=MIDDLE_KEYPOINTS,\n", + " middle_monotonicity='increasing',\n", + " output_min=min_label,\n", + " output_max=max_label,\n", + " output_calibration=True,\n", + " output_calibration_num_keypoints=OUTPUT_KEYPOINTS,\n", + " output_initialization=np.linspace(\n", + " min_label, max_label, num=OUTPUT_KEYPOINTS))\n", + "# An AggregateFunction premade model constructed from the given model config.\n", + "aggregate_function_model = tfl.premade.AggregateFunction(\n", + " aggregate_function_model_config)\n", + "# Let's plot our model.\n", + "tf.keras.utils.plot_model(\n", + " aggregate_function_model, show_layer_names=False, rankdir='LR')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "4F7AwiXgWhe2" + }, + "source": [ + "The output of each Aggregation layer is the averaged output of a calibrated lattice over the ragged inputs. Here is the model used inside the first Aggregation layer:" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "UM7XF6UIWo4T" + }, + "outputs": [], + "source": [ + "aggregation_layers = [\n", + " layer for layer in aggregate_function_model.layers\n", + " if isinstance(layer, tfl.layers.Aggregation)\n", + "]\n", + "tf.keras.utils.plot_model(\n", + " aggregation_layers[0].model, show_layer_names=False, rankdir='LR')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "0ohYOftgTZhq" + }, + "source": [ + "Now, as with any other [tf.keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model), we compile and fit the model to our data." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "uB9di3-lTfMy" + }, + "outputs": [], + "source": [ + "aggregate_function_model.compile(\n", + " loss='mae',\n", + " optimizer=tf.keras.optimizers.Adam(LEARNING_RATE))\n", + "aggregate_function_model.fit(\n", + " train_xs, train_ys, epochs=NUM_EPOCHS, batch_size=BATCH_SIZE, verbose=False)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "pwZtGDR-Tzur" + }, + "source": [ + "After training our model, we can evaluate it on our test set." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "RWj1YfubT0NE" + }, + "outputs": [], + "source": [ + "print('Test Set Evaluation...')\n", + "print(aggregate_function_model.evaluate(test_xs, test_ys))" + ] + } + ], + "metadata": { + "colab": { + "collapsed_sections": [], + "name": "aggregate_function_models.ipynb", + "private_outputs": true, + "provenance": [ + { + "file_id": "1ohMV9lhzSWZq3aH27fBAZ1Oj3wy19PI0", + "timestamp": 1588637142053 + } + ], + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/docs/tutorials/premade_models.ipynb b/docs/tutorials/premade_models.ipynb index 00a6cf0..db1e9ad 100644 --- a/docs/tutorials/premade_models.ipynb +++ b/docs/tutorials/premade_models.ipynb @@ -132,7 +132,6 @@ "import pandas as pd\n", "import sys\n", "import tensorflow_lattice as tfl\n", - "from tensorflow import feature_column as fc\n", "logging.disable(sys.maxsize)" ] }, @@ -326,9 +325,9 @@ "id": "WLSfZ5G7-YT_" }, "source": [ - "### Compute Quatiles\n", + "### Compute Quantiles\n", "\n", - "Although the default setting for `pwl_calibration_input_keypoints` in `tfl.configs.FeatureConfig` is 'quatiles', for premade models we have to manually define the input keypoints. To do so, we first define our own helper function for computing quantiles. " + "Although the default setting for `pwl_calibration_input_keypoints` in `tfl.configs.FeatureConfig` is 'quantiles', for premade models we have to manually define the input keypoints. To do so, we first define our own helper function for computing quantiles. " ] }, { @@ -545,7 +544,7 @@ }, "outputs": [], "source": [ - "# Model config defined the model structure for the estimator.\n", + "# Model config defines the model structure for the premade model.\n", "linear_model_config = tfl.configs.CalibratedLinearConfig(\n", " feature_configs=feature_configs[:5],\n", " use_bias=True,\n", @@ -943,7 +942,7 @@ "crystals_ensemble_model.fit(\n", " train_xs, train_ys, epochs=NUM_EPOCHS, batch_size=BATCH_SIZE, verbose=False)\n", "print('Test Set Evaluation...')\n", - "print(random_ensemble_model.evaluate(test_xs, test_ys))" + "print(crystals_ensemble_model.evaluate(test_xs, test_ys))" ] } ], diff --git a/setup.py b/setup.py index a29e6ba..597705e 100644 --- a/setup.py +++ b/setup.py @@ -27,7 +27,7 @@ # This version number should always be that of the *next* (unreleased) version. # Immediately after uploading a package to PyPI, you should increment the # version number and push to gitHub. -__version__ = "2.0.4" +__version__ = "2.0.5" if "--release" in sys.argv: sys.argv.remove("--release") @@ -45,6 +45,7 @@ "sklearn", "matplotlib", "graphviz", + "dm-sonnet", ] # Part of the visualization code uses colabtools and IPython libraries. 
These diff --git a/tensorflow_lattice/BUILD b/tensorflow_lattice/BUILD index ea3955c..90b09f0 100644 --- a/tensorflow_lattice/BUILD +++ b/tensorflow_lattice/BUILD @@ -28,6 +28,7 @@ py_library( srcs = [ "__init__.py", "layers/__init__.py", + "sonnet_modules/__init__.py", ], srcs_version = "PY2AND3", deps = [ @@ -46,6 +47,7 @@ py_library( "//tensorflow_lattice/python:premade_lib", "//tensorflow_lattice/python:pwl_calibration_layer", "//tensorflow_lattice/python:pwl_calibration_lib", + "//tensorflow_lattice/python:pwl_calibration_sonnet_module", "//tensorflow_lattice/python:rtl_layer", "//tensorflow_lattice/python:test_utils", "//tensorflow_lattice/python:visualization", diff --git a/tensorflow_lattice/__init__.py b/tensorflow_lattice/__init__.py index 0c58c39..aef8487 100644 --- a/tensorflow_lattice/__init__.py +++ b/tensorflow_lattice/__init__.py @@ -20,6 +20,7 @@ from __future__ import absolute_import import tensorflow_lattice.layers +import tensorflow_lattice.sonnet_modules from tensorflow_lattice.python import aggregation_layer from tensorflow_lattice.python import categorical_calibration_layer @@ -35,6 +36,7 @@ from tensorflow_lattice.python import premade from tensorflow_lattice.python import premade_lib from tensorflow_lattice.python import pwl_calibration_layer +from tensorflow_lattice.python import pwl_calibration_sonnet_module from tensorflow_lattice.python import pwl_calibration_lib from tensorflow_lattice.python import test_utils from tensorflow_lattice.python import visualization diff --git a/tensorflow_lattice/python/BUILD b/tensorflow_lattice/python/BUILD index c72240e..a572b17 100644 --- a/tensorflow_lattice/python/BUILD +++ b/tensorflow_lattice/python/BUILD @@ -32,6 +32,18 @@ py_library( ], ) +py_library( + name = "pwl_calibration_sonnet_module", + srcs = ["pwl_calibration_sonnet_module.py"], + srcs_version = "PY2AND3", + deps = [ + ":pwl_calibration_lib", + # absl/logging dep, + # sonnet dep, + # tensorflow:tensorflow_no_contrib dep, + ], +) + py_library( name = "pwl_calibration_lib", srcs = ["pwl_calibration_lib.py"], @@ -52,10 +64,12 @@ py_test( deps = [ ":parallel_combination_layer", ":pwl_calibration_layer", + ":pwl_calibration_sonnet_module", ":test_utils", # absl/logging dep, # absl/testing:parameterized dep, # numpy dep, + # sonnet dep, # tensorflow dep, ], ) @@ -287,6 +301,8 @@ py_library( ":visualization", # absl/logging dep, # numpy dep, + # sonnet dep, + # tensorflow dep, ], ) @@ -377,7 +393,7 @@ py_library( py_test( name = "premade_test", - size = "enormous", + size = "large", srcs = ["premade_test.py"], python_version = "PY3", # shard_count = 10, diff --git a/tensorflow_lattice/python/categorical_calibration_layer.py b/tensorflow_lattice/python/categorical_calibration_layer.py index 574ea94..f8b516a 100644 --- a/tensorflow_lattice/python/categorical_calibration_layer.py +++ b/tensorflow_lattice/python/categorical_calibration_layer.py @@ -248,13 +248,13 @@ def assert_constraints(self, eps=1e-6): In graph mode builds and returns list of assertion ops. Note that ops will be created at the moment when this function is being called. - In eager mode directly executes assetions. + In eager mode directly executes assertions. Args: eps: Allowed constraints violation. Returns: - List of assertion ops in graph mode or immideately asserts in eager mode. + List of assertion ops in graph mode or immediately asserts in eager mode. 
""" return categorical_calibration_lib.assert_constraints( weights=self.kernel, diff --git a/tensorflow_lattice/python/configs.py b/tensorflow_lattice/python/configs.py index 502b26f..0b1d6b4 100644 --- a/tensorflow_lattice/python/configs.py +++ b/tensorflow_lattice/python/configs.py @@ -249,7 +249,10 @@ def __init__(self, lattices='random', num_lattices=None, lattice_rank=None, + interpolation='hypercube', separate_calibrators=True, + use_linear_combination=False, + use_bias=False, regularizer_configs=None, output_min=None, output_max=None, @@ -276,8 +279,18 @@ def __init__(self, lattices are not explicitly provided. lattice_rank: Number of features in each lattice. Must be provided if lattices are not explicitly provided. + interpolation: One of 'hypercube' or 'simplex' interpolation. For a + d-dimensional lattice, 'hypercube' interpolates 2^d parameters, whereas + 'simplex' uses d+1 parameters and thus scales better. For details see + `tfl.lattice_lib.evaluate_with_simplex_interpolation` and + `tfl.lattice_lib.evaluate_with_hypercube_interpolation`. separate_calibrators: If features should be separately calibrated for each lattice in the ensemble. + use_linear_combination: If set to true, a linear combination layer will be + used to combine ensemble outputs. Otherwise an averaging layer will be + used. If output is bounded or output calibration is used, then this + layer will be a weighted average. + use_bias: If a bias term should be used for the linear combination. regularizer_configs: A list of `tfl.configs.RegularizerConfig` instances that apply global regularization. output_min: Lower bound constraint on the output of the model. @@ -340,6 +353,7 @@ class CalibratedLatticeConfig(_Config, _HasFeatureConfigs, def __init__(self, feature_configs=None, + interpolation='hypercube', regularizer_configs=None, output_min=None, output_max=None, @@ -352,6 +366,11 @@ def __init__(self, feature_configs: A list of `tfl.configs.FeatureConfig` instances that specify configurations for each feature. If a configuration is not provided for a feature, a default configuration will be used. + interpolation: One of 'hypercube' or 'simplex' interpolation. For a + d-dimensional lattice, 'hypercube' interpolates 2^d parameters, whereas + 'simplex' uses d+1 parameters and thus scales better. For details see + `tfl.lattice_lib.evaluate_with_simplex_interpolation` and + `tfl.lattice_lib.evaluate_with_hypercube_interpolation`. regularizer_configs: A list of `tfl.configs.RegularizerConfig` instances that apply global regularization. output_min: Lower bound constraint on the output of the model. @@ -481,6 +500,8 @@ def __init__(self, middle_calibration=False, middle_calibration_num_keypoints=10, middle_monotonicity=None, + middle_lattice_interpolation='hypercube', + aggregation_lattice_interpolation='hypercube', output_min=None, output_max=None, output_calibration=False, @@ -507,6 +528,16 @@ def __init__(self, monotonic, using 'increasing' or 1 to indicate increasing monotonicity, 'decreasing' or -1 to indicate decreasing monotonicity, and 'none' or 0 to indicate no monotonicity constraints. + middle_lattice_interpolation: One of 'hypercube' or 'simplex'. For a + d-dimensional lattice, 'hypercube' interpolates 2^d parameters, whereas + 'simplex' uses d+1 parameters and thus scales better. For details see + `tfl.lattice_lib.evaluate_with_simplex_interpolation` and + `tfl.lattice_lib.evaluate_with_hypercube_interpolation`. + aggregation_lattice_interpolation: One of 'hypercube' or 'simplex'. 
For a + d-dimensional lattice, 'hypercube' interpolates 2^d parameters, whereas + 'simplex' uses d+1 parameters and thus scales better. For details see + `tfl.lattice_lib.evaluate_with_simplex_interpolation` and + `tfl.lattice_lib.evaluate_with_hypercube_interpolation`. output_min: Lower bound constraint on the output of the model. output_max: Upper bound constraint on the output of the model. output_calibration: If a piecewise-linear calibration should be used on diff --git a/tensorflow_lattice/python/estimators.py b/tensorflow_lattice/python/estimators.py index 66b5239..646fe49 100644 --- a/tensorflow_lattice/python/estimators.py +++ b/tensorflow_lattice/python/estimators.py @@ -1393,9 +1393,10 @@ def get_model_graph(saved_model_path, tag='serve'): g.get_operation_by_name(linear_kernel_op).outputs[0]).flatten() # Bias term. - # {LINEAR_LAYER_NAME}/{LINEAR_LAYER_BIAS_NAME} - bias_op = '{}/{}/Read/ReadVariableOp'.format( + # {LINEAR_LAYER_NAME}_{submodel_idx}/{LINEAR_LAYER_BIAS_NAME} + bias_op = '{}_{}/{}/Read/ReadVariableOp'.format( premade_lib.LINEAR_LAYER_NAME, + submodel_idx, linear_layer.LINEAR_LAYER_BIAS_NAME, ) if bias_op in ops: @@ -1448,9 +1449,9 @@ def get_model_graph(saved_model_path, tag='serve'): submodel_output_nodes[submodel_idx] = lattice_node nodes.append(lattice_node) - ################### - # Create mean node. - ################### + ##################################################### + # Create output linear combination or averaging node. + ##################################################### # Mean node is only added for ensemble models. if len(submodel_output_nodes) > 1: @@ -1458,9 +1459,37 @@ def get_model_graph(saved_model_path, tag='serve'): submodel_output_nodes[idx] for idx in sorted(submodel_output_nodes.keys(), key=int) ] - average_node = model_info.MeanNode(input_nodes=input_nodes) - nodes.append(average_node) - model_output_node = average_node + + # Linear coefficients. + # {LINEAR_LAYER_COMBINATION_NAME}/{LINEAR_LAYER_KERNEL_NAME} + linear_combination_kernel_op = '{}/{}/Read/ReadVariableOp'.format( + premade_lib.OUTPUT_LINEAR_COMBINATION_LAYER_NAME, + linear_layer.LINEAR_LAYER_KERNEL_NAME, + ) + if linear_combination_kernel_op in ops: + coefficients = sess.run( + g.get_operation_by_name( + linear_combination_kernel_op).outputs[0]).flatten() + + # Bias term. 
+ # {OUTPUT_LINEAR_COMBINATION_LAYER_NAME}/{LINEAR_LAYER_BIAS_NAME} + bias_op = '{}/{}/Read/ReadVariableOp'.format( + premade_lib.OUTPUT_LINEAR_COMBINATION_LAYER_NAME, + linear_layer.LINEAR_LAYER_BIAS_NAME, + ) + if bias_op in ops: + bias = sess.run(g.get_operation_by_name(bias_op).outputs[0]) + else: + bias = 0.0 + + linear_combination_node = model_info.LinearNode( + input_nodes=input_nodes, coefficients=coefficients, bias=bias) + nodes.append(linear_combination_node) + model_output_node = linear_combination_node + else: + average_node = model_info.MeanNode(input_nodes=input_nodes) + nodes.append(average_node) + model_output_node = average_node else: model_output_node = list(submodel_output_nodes.values())[0] diff --git a/tensorflow_lattice/python/lattice_layer.py b/tensorflow_lattice/python/lattice_layer.py index 6d7f52d..fa7d3a4 100644 --- a/tensorflow_lattice/python/lattice_layer.py +++ b/tensorflow_lattice/python/lattice_layer.py @@ -164,7 +164,8 @@ def __init__(self, num_projection_iterations=10, monotonic_at_every_step=True, clip_inputs=True, - kernel_initializer="linear_initializer", + interpolation="hypercube", + kernel_initializer="random_uniform_or_linear_initializer", kernel_regularizer=None, **kwargs): # pyformat: disable @@ -230,6 +231,11 @@ def __init__(self, num_projection_iterations parameter is likely to hurt convergence. clip_inputs: If inputs should be clipped to the input range of the lattice. + interpolation: One of 'hypercube' or 'simplex' interpolation. For a + d-dimensional lattice, 'hypercube' interpolates 2^d parameters, whereas + 'simplex' uses d+1 parameters and thus scales better. For details see + `tfl.lattice_lib.evaluate_with_simplex_interpolation` and + `tfl.lattice_lib.evaluate_with_hypercube_interpolation`. kernel_initializer: None or one of: - `'linear_initializer'`: initialize parameters to form a linear function with positive and equal coefficients for monotonic dimensions @@ -244,6 +250,10 @@ def __init__(self, `[output_min, output_max]`. See `tfl.lattice_layer.RandomMonotonicInitializer` class docstring for more details. + - `random_uniform_or_linear_initializer`: if the lattice has a single + joint unimodality constraint group encompassing all features then use + the Keras 'random_uniform' initializer; otherwise, use TFL's + 'linear_initializer'. - Any Keras initializer object. kernel_regularizer: None or a single element or a list of following: - Tuple `('torsion', l1, l2)` where l1 and l2 represent corresponding @@ -264,7 +274,8 @@ def __init__(self, lattice_lib.verify_hyperparameters( lattice_sizes=lattice_sizes, monotonicities=monotonicities, - unimodalities=unimodalities) + unimodalities=unimodalities, + interpolation=interpolation) super(Lattice, self).__init__(**kwargs) self.lattice_sizes = lattice_sizes @@ -308,61 +319,16 @@ def __init__(self, self.num_projection_iterations = num_projection_iterations self.monotonic_at_every_step = monotonic_at_every_step self.clip_inputs = clip_inputs + self.interpolation = interpolation - def default_params(output_min, output_max): - """Return reasonable default parameters if not defined explicitly.""" - if output_min is not None: - output_init_min = output_min - elif output_max is not None: - output_init_min = min(0.0, output_max) - else: - output_init_min = 0.0 - if output_max is not None: - output_init_max = output_max - elif output_min is not None: - output_init_max = max(1.0, output_min) - else: - output_init_max = 1.0 - # Return our min and max. 
- return output_init_min, output_init_max - - # Initialize joint unimodalities identical to regular ones. - all_unimodalities = [0] * len(lattice_sizes) - if self.unimodalities: - for i, value in enumerate(self.unimodalities): - if value: - all_unimodalities[i] = value - if self.joint_unimodalities: - for dimensions, direction in self.joint_unimodalities: - for dim in dimensions: - all_unimodalities[dim] = direction - - if kernel_initializer in ["linear_initializer", "LinearInitializer"]: - output_init_min, output_init_max = default_params(output_min, output_max) - - self.kernel_initializer = LinearInitializer( - lattice_sizes=lattice_sizes, - monotonicities=monotonicities, - output_min=output_init_min, - output_max=output_init_max, - unimodalities=all_unimodalities) - elif kernel_initializer in ["random_monotonic_initializer", - "RandomMonotonicInitializer"]: - output_init_min, output_init_max = default_params(output_min, output_max) - - self.kernel_initializer = RandomMonotonicInitializer( - lattice_sizes=lattice_sizes, - output_min=output_init_min, - output_max=output_init_max, - unimodalities=all_unimodalities) - else: - # This is needed for Keras deserialization logic to be aware of our custom - # objects. - with keras.utils.custom_object_scope({ - "LinearInitializer": LinearInitializer, - "RandomMonotonicInitializer": RandomMonotonicInitializer, - }): - self.kernel_initializer = keras.initializers.get(kernel_initializer) + self.kernel_initializer = create_kernel_initializer( + kernel_initializer, + self.lattice_sizes, + self.monotonicities, + self.output_min, + self.output_max, + self.unimodalities, + self.joint_unimodalities) self.kernel_regularizer = [] if kernel_regularizer: @@ -468,25 +434,27 @@ def build(self, input_shape): def call(self, inputs): """Standard Keras call() method.""" - interpolation_weights = lattice_lib.compute_interpolation_weights( - inputs=inputs, - lattice_sizes=self.lattice_sizes, - clip_inputs=self.clip_inputs) - # Use control dependencies to save lattice sizes as graph constant for # visualisation toolbox to be able to recove it from saved graph. # Wrap this constant into pure op since in TF 2.0 there are issues passing # tensors into control_dependencies. 
with tf.control_dependencies([tf.identity(self.lattice_sizes_tensor)]): - if self.units == 1: - # Weights shape: (batch-size, ..., prod(lattice_sizes)) - # Kernel shape: (prod(lattice_sizes), 1) - return tf.matmul(interpolation_weights, self.kernel) + if self.interpolation == "simplex": + return lattice_lib.evaluate_with_simplex_interpolation( + inputs=inputs, + kernel=self.kernel, + units=self.units, + lattice_sizes=self.lattice_sizes, + clip_inputs=self.clip_inputs) + elif self.interpolation == "hypercube": + return lattice_lib.evaluate_with_hypercube_interpolation( + inputs=inputs, + kernel=self.kernel, + units=self.units, + lattice_sizes=self.lattice_sizes, + clip_inputs=self.clip_inputs) else: - # Weights shape: (batch-size, ..., units, prod(lattice_sizes)) - # Kernel shape: (prod(lattice_sizes), units) - return tf.reduce_sum( - interpolation_weights * tf.transpose(self.kernel), axis=-1) + raise ValueError("Unknown interpolation type: %s" % self.interpolation) def compute_output_shape(self, input_shape): """Standard Keras compute_output_shape() method.""" @@ -516,6 +484,7 @@ def get_config(self): "num_projection_iterations": self.num_projection_iterations, "monotonic_at_every_step": self.monotonic_at_every_step, "clip_inputs": self.clip_inputs, + "interpolation": self.interpolation, "kernel_initializer": keras.initializers.serialize(self.kernel_initializer), "kernel_regularizer": @@ -542,13 +511,13 @@ def assert_constraints(self, eps=1e-6): """Asserts that weights satisfy all constraints. In graph mode builds and returns list of assertion ops. - In eager mode directly executes assetions. + In eager mode directly executes assertions. Args: eps: allowed constraints violation. Returns: - List of assertion ops in graph mode or immideately asserts in eager mode. + List of assertion ops in graph mode or immediately asserts in eager mode. """ return lattice_lib.assert_constraints( weights=self.kernel, @@ -566,6 +535,118 @@ def assert_constraints(self, eps=1e-6): eps=eps) +def create_kernel_initializer( + kernel_initializer_id, + lattice_sizes, + monotonicities, + output_min, + output_max, + unimodalities, + joint_unimodalities): + """Returns a kernel Keras initializer object from its id. + + This function is used to convert the 'kernel_initializer' parameter in the + constructor of tfl.Lattice into the corresponding initializer object. + + Args: + kernel_initializer_id: See the documentation of the 'kernel_initializer' + parameter in the constructor of tfl.Lattice. + lattice_sizes: See the documentation of the same parameter in the + constructor of tfl.Lattice. + monotonicities: See the documentation of the same parameter in the + constructor of tfl.Lattice. + output_min: See the documentation of the same parameter in the + constructor of tfl.Lattice. + output_max: See the documentation of the same parameter in the + constructor of tfl.Lattice. + unimodalities: See the documentation of the same parameter in the + constructor of tfl.Lattice. + joint_unimodalities: See the documentation of the same parameter in the + constructor of tfl.Lattice. + Returns: + The Keras initializer object for the tfl.Lattice kernel variable. 
+ """ + def default_params(output_min, output_max): + """Return reasonable default parameters if not defined explicitly.""" + if output_min is not None: + output_init_min = output_min + elif output_max is not None: + output_init_min = min(0.0, output_max) + else: + output_init_min = 0.0 + + if output_max is not None: + output_init_max = output_max + elif output_min is not None: + output_init_max = max(1.0, output_min) + else: + output_init_max = 1.0 + + # Return our min and max. + return output_init_min, output_init_max + + def do_joint_unimodalities_contain_all_features(joint_unimodalities): + if (joint_unimodalities is None) or (len(joint_unimodalities) != 1): + return False + [joint_unimodalities] = joint_unimodalities + return set(joint_unimodalities[0]) == set(range(len(lattice_sizes))) + + # Initialize joint unimodalities identical to regular ones. + all_unimodalities = [0] * len(lattice_sizes) + if unimodalities: + for i, value in enumerate(unimodalities): + if value: + all_unimodalities[i] = value + if joint_unimodalities: + for dimensions, direction in joint_unimodalities: + for dim in dimensions: + all_unimodalities[dim] = direction + + if kernel_initializer_id in ["linear_initializer", "LinearInitializer"]: + output_init_min, output_init_max = default_params(output_min, output_max) + + return LinearInitializer( + lattice_sizes=lattice_sizes, + monotonicities=monotonicities, + output_min=output_init_min, + output_max=output_init_max, + unimodalities=all_unimodalities) + elif kernel_initializer_id in ["random_monotonic_initializer", + "RandomMonotonicInitializer"]: + output_init_min, output_init_max = default_params(output_min, output_max) + + return RandomMonotonicInitializer( + lattice_sizes=lattice_sizes, + output_min=output_init_min, + output_max=output_init_max, + unimodalities=all_unimodalities) + elif kernel_initializer_id in ["random_uniform_or_linear_initializer", + "RandomUniformOrLinearInitializer"]: + if do_joint_unimodalities_contain_all_features(joint_unimodalities): + return create_kernel_initializer("random_uniform", + lattice_sizes, + monotonicities, + output_min, + output_max, + unimodalities, + joint_unimodalities) + return create_kernel_initializer("linear_initializer", + lattice_sizes, + monotonicities, + output_min, + output_max, + unimodalities, + joint_unimodalities) + else: + # This is needed for Keras deserialization logic to be aware of our custom + # objects. + with keras.utils.custom_object_scope({ + "LinearInitializer": LinearInitializer, + "RandomMonotonicInitializer": RandomMonotonicInitializer, + }): + return keras.initializers.get(kernel_initializer_id) + + class LinearInitializer(keras.initializers.Initializer): # pyformat: disable """Initializes a `tfl.layers.Lattice` as linear function. diff --git a/tensorflow_lattice/python/lattice_lib.py b/tensorflow_lattice/python/lattice_lib.py index 4bbbc3f..80d9771 100644 --- a/tensorflow_lattice/python/lattice_lib.py +++ b/tensorflow_lattice/python/lattice_lib.py @@ -24,12 +24,167 @@ from absl import logging import numpy as np import six - import tensorflow as tf +def evaluate_with_simplex_interpolation(inputs, kernel, units, lattice_sizes, + clip_inputs): + """Evaluates a lattice using simplex interpolation. + + Within each cell of the lattice, we partition the hypercube into d! simplices, + where each simplex has d+1 vertices. Each simplex (relative to the lower + corner of the hypercube) includes the all-zeros vertex, a vertex with a + single one, a vertex with two ones, ... and the all-ones vertex. 
+ For example, for a three-dimensional unit hypercube the 3! = 6 simplices are: + + [0,0,0], [0,0,1], [0,1,1], [1,1,1] + [0,0,0], [0,0,1], [1,0,1], [1,1,1] + [0,0,0], [0,1,0], [0,1,1], [1,1,1] + [0,0,0], [0,1,0], [1,1,0], [1,1,1] + [0,0,0], [1,0,0], [1,1,0], [1,1,1] + [0,0,0], [1,0,0], [1,0,1], [1,1,1] + + A point x in the hypercube is contained in the simplex corresponding to the + order of x's components. For example, x = [0.4,0.2,0.8] is contained in the + simplex specified by [2,0,1] (second in the above list). The weight associated + with each vertex in the simplex is the difference between the decreasingly + sorted cooredinates of the input. For details, see e.g. "Dissection of the + hypercube into simplices", D.G. Mead, Proceedings of the AMS, 76:2, Sep. 1979. + + Args: + inputs: Tensor of shape: `(batch_size, ..., len(lattice_sizes))` or list of + `len(lattice_sizes)` tensors of same shape `(batch_size, ..., 1)` which + represents points to apply lattice interpolation to. A typical shape is + `(batch_size, len(lattice_sizes))`. + kernel: Lattice kernel of shape (num_params_per_lattice, units). + units: Output dimension of the lattice. + lattice_sizes: List or tuple of integers which represents lattice sizes of + layer for which interpolation is being computed. + clip_inputs: Whether inputs should be clipped to the input range of the + lattice. + """ + if isinstance(inputs, list): + inputs = tf.concat(inputs, axis=-1) + + if clip_inputs: + inputs = _clip_onto_lattice_range( + inputs=inputs, lattice_sizes=lattice_sizes) + + lattice_rank = len(lattice_sizes) + input_dim = len(inputs.shape) + all_size_2 = all(size == 2 for size in lattice_sizes) + + if input_dim == 2: + units = 1 + else: + units = int(inputs.shape[-2]) + + # Strides are the changes in the global index (index into the flattened + # parameters) when moving across each dimension. + # E.g. for 2x2x2, strides are [4, 2, 1]. + strides = tf.constant( + np.cumprod([1] + lattice_sizes[::-1][:-1])[::-1], tf.int32) + + if not all_size_2: + # Find offset (into flattened parameters) for the lower corner of the + # hypercube where input lands in. + lower_corner_coordinates = tf.cast(inputs, tf.int32) + # Avoid the corner case of landing on the outermost edge. + lower_corner_coordinates = tf.minimum(lower_corner_coordinates, + np.array(lattice_sizes) - 2) + + # Multiplying coordinates by strides and summing up gives out the index into + # the flattened parameter tensor. + # Note: Alternative method using tf.tensordot + tf.expand_dims is slower. + lower_corner_offset = tf.reduce_sum( + lower_corner_coordinates * strides, axis=-1, keepdims=True) + + # Continue simplex interpolation with the residuals + inputs = inputs - tf.cast(lower_corner_coordinates, inputs.dtype) + + # Get sorted values and indicies. + # TODO: investigate if there is a way to avoid sorting twice. + sorted_indices = tf.argsort(inputs, direction="DESCENDING") + sorted_inputs = tf.sort(inputs, direction="DESCENDING") + + # Simplex interpolation weights are the deltas between residuals. + no_padding_dims = [[0, 0]] * (input_dim - 1) + sorted_inputs_padded_left = tf.pad( + sorted_inputs, no_padding_dims + [[1, 0]], constant_values=1.) + sorted_inputs_padded_right = tf.pad( + sorted_inputs, no_padding_dims + [[0, 1]], constant_values=0.) + weights = sorted_inputs_padded_left - sorted_inputs_padded_right + + # Calculate cumsum over the strides of sorted dimensions to get index of + # simplex vertices into the flattened lattice parameters. 
+ sorted_strides = tf.gather(strides, sorted_indices) + if all_size_2: + # Lower corner offset is 0 for 2^d lattices. + corner_offset_and_sorted_strides = tf.pad(sorted_strides, + no_padding_dims + [[1, 0]]) + else: + corner_offset_and_sorted_strides = tf.concat( + [lower_corner_offset, sorted_strides], axis=-1) + indices = tf.cumsum(corner_offset_and_sorted_strides, axis=-1) + + # Get parameters values of simplex indicies. + if units == 1: + gathered_params = tf.gather(tf.reshape(kernel, [-1]), indices) + else: + # We now have two tensors 'indices' and 'weights' of shape (batch, units). + # The kernel is of shape (num_params_per_lattice, units). + # In order to use tf.gather, we need to convert 'indices' so that they are + # indices into the flattened parameter tensor. + # Note: Alternative method that uses a transpose on the parameters instead + # of a multiply on the indices is slower with typical batch sizes. + unit_offset = tf.constant([[i] * (lattice_rank + 1) for i in range(units)]) + flat_indices = indices * units + unit_offset + gathered_params = tf.gather(tf.reshape(kernel, [-1]), flat_indices) + + # Dot product with interpolation weights. + # Note: Alternative method using tf.einsum is slightly slower on CPU. + return tf.reduce_sum( + tf.multiply(gathered_params, weights), axis=-1, keepdims=(units == 1)) + + +def evaluate_with_hypercube_interpolation(inputs, kernel, units, lattice_sizes, + clip_inputs): + """Evaluates a lattice using hypercube interpolation. + + Lattice function is multi-linearly interpolated between the 2^d vertices of a + hypercube. This interpolation method is typically slower than simplex + interpolation, since each value is interpolated from 2^d hypercube corners, + rather than d+1 simplex corners. For details, see e.g. "Dissection of the + hypercube into simplices", D.G. Mead, Proceedings of the AMS, 76:2, Sep. 1979. + + Args: + inputs: Tensor of shape: `(batch_size, ..., len(lattice_sizes))` or list of + `len(lattice_sizes)` tensors of same shape `(batch_size, ..., 1)` which + represents points to apply lattice interpolation to. A typical shape is + `(batch_size, len(lattice_sizes))`. + kernel: Lattice kernel of shape (num_params_per_lattice, units). + units: Output dimension of the lattice. + lattice_sizes: List or tuple of integers which represents lattice sizes of + layer for which interpolation is being computed. + clip_inputs: Whether inputs should be clipped to the input range of the + lattice. + """ + interpolation_weights = compute_interpolation_weights( + inputs=inputs, lattice_sizes=lattice_sizes, clip_inputs=clip_inputs) + + if units == 1: + # Weights shape: (batch-size, ..., prod(lattice_sizes)) + # Kernel shape: (prod(lattice_sizes), 1) + return tf.matmul(interpolation_weights, kernel) + else: + # Weights shape: (batch-size, ..., units, prod(lattice_sizes)) + # Kernel shape: (prod(lattice_sizes), units) + return tf.reduce_sum(interpolation_weights * tf.transpose(kernel), axis=-1) + + +# TODO: Rename and update usage. def compute_interpolation_weights(inputs, lattice_sizes, clip_inputs=True): - """Computes weights for lattice interpolation. + """Computes weights for hypercube lattice interpolation. 
Running time: `O(batch_size * prod(lattice_sizes))` @@ -63,6 +218,14 @@ def compute_interpolation_weights(inputs, lattice_sizes, clip_inputs=True): input_dtype = inputs.dtype verify_hyperparameters(lattice_sizes=lattice_sizes, input_shape=input_shape) + # Special case: 2^d lattice with input passed in as a single tensor + if all(size == 2 for size in lattice_sizes) and not isinstance(inputs, list): + w = tf.stack([(1.0 - inputs), inputs], axis=-1) + if clip_inputs: + w = tf.clip_by_value(w, clip_value_min=0, clip_value_max=1) + one_d_interpolation_weights = tf.unstack(w, axis=-2) + return batch_outer_operation(one_d_interpolation_weights, operation="auto") + if clip_inputs: inputs = _clip_onto_lattice_range( inputs=inputs, lattice_sizes=lattice_sizes) @@ -116,6 +279,19 @@ def batch_outer_operation(list_of_tensors, operation="auto"): Returns: Tensor of shape: `(batch_size, ..., mul_i(k[i]))`. """ + # Alternative implementation using tf.einsum creates fewer graph nodes. + # This is slightly slower on CPU as of 2020/5, but the timing results might + # change with different setup/platform/hardware. + # Create a formula for outer product. e.g. '...a,...b,...c->...abc' + # if operation == "auto": + # n = len(list_of_tensors) + # chars = string.ascii_lowercase[:n] + # eqn = ",".join(["..." + c for c in chars]) + "->..." + "".join(chars) + # result = tf.einsum(eqn, *list_of_tensors) + # result_shape = [-1] + [int(size) for size in result.shape[1:]] + # output_shape = result_shape[:-n] + [np.prod(result_shape[-n:])] + # return tf.reshape(result, shape=output_shape) + if len(list_of_tensors) == 1: return list_of_tensors[0] @@ -2073,7 +2249,8 @@ def verify_hyperparameters(lattice_sizes, output_min=None, output_max=None, regularization_amount=None, - regularization_info=""): + regularization_info="", + interpolation="hypercube"): """Verifies that all given hyperparameters are consistent. This function does not inspect weights themselves. Only their shape. Use @@ -2102,6 +2279,7 @@ def verify_hyperparameters(lattice_sizes, output_max: Maximum output of `Lattice` layer. regularization_amount: Regularization amount for regularizers. regularization_info: String which describes `regularization_amount`. + interpolation: One of 'simplex' or 'hypercube' interpolation. Raises: ValueError: If something is inconsistent. @@ -2283,6 +2461,10 @@ def verify_hyperparameters(lattice_sizes, "l1: %s, lattice sizes: %s" % (regularization_info, regularization_amount, lattice_sizes)) + if interpolation not in ["hypercube", "simplex"]: + raise ValueError("Lattice interpolation type should be either 'simplex' " + "or 'hypercube': %s" % interpolation) + # TODO: investigate whether eps should be bigger. 
def assert_constraints(weights, diff --git a/tensorflow_lattice/python/lattice_test.py b/tensorflow_lattice/python/lattice_test.py index 66914c4..0cd34ba 100644 --- a/tensorflow_lattice/python/lattice_test.py +++ b/tensorflow_lattice/python/lattice_test.py @@ -214,6 +214,7 @@ def _SetDefaults(self, config): config.setdefault("kernel_regularizer", None) config.setdefault("units", 1) config.setdefault("lattice_index", 0) + config.setdefault("interpolation", "hypercube") return config @@ -268,6 +269,7 @@ def _TrainModel(self, config, plot_path=None): output_max=config["output_max"], num_projection_iterations=config["num_projection_iterations"], monotonic_at_every_step=config["monotonic_at_every_step"], + interpolation=config["interpolation"], kernel_initializer=config["kernel_initializer"], kernel_regularizer=config["kernel_regularizer"], input_shape=input_shape, @@ -1749,6 +1751,119 @@ def testUnconstrained(self): self.assertAlmostEqual(loss, 0.004216, delta=self.loss_eps) self._TestEnsemble(config) + config = { + "lattice_sizes": [20], + "interpolation": "simplex", + "num_training_records": 100, + "num_training_epoch": 200, + "optimizer": tf.keras.optimizers.Adagrad, + "learning_rate": 0.15, + "x_generator": self._ScatterXUniformly, + "y_function": self._Sin, + "kernel_initializer": keras.initializers.Zeros, + } # pyformat: disable + loss = self._TrainModel(config) + self.assertAlmostEqual(loss, 0.000917, delta=self.loss_eps) + self._TestEnsemble(config) + + config = { + "lattice_sizes": [2], + "interpolation": "simplex", + "num_training_records": 100, + "num_training_epoch": 50, + "optimizer": tf.keras.optimizers.Adagrad, + "learning_rate": 0.15, + "x_generator": self._ScatterXUniformly, + "y_function": self._Square, + } # pyformat: disable + loss = self._TrainModel(config) + self.assertAlmostEqual(loss, 0.004277, delta=self.loss_eps) + self._TestEnsemble(config) + + config = { + "lattice_sizes": [2, 2], + "interpolation": "simplex", + "num_training_records": 100, + "num_training_epoch": 100, + "optimizer": tf.keras.optimizers.Adagrad, + "learning_rate": 0.15, + "x_generator": self._ScatterXUniformly, + "y_function": self._Max, + } # pyformat: disable + loss = self._TrainModel(config) + self.assertAlmostEqual(loss, 5e-06, delta=self.loss_eps) + self._TestEnsemble(config) + + config = { + "lattice_sizes": [2] * 6, + "interpolation": "simplex", + "num_training_records": 100, + "num_training_epoch": 300, + "optimizer": tf.keras.optimizers.Adagrad, + "learning_rate": 30.0, + "x_generator": self._ScatterXUniformly, + "y_function": self._PseudoLinear, + } # pyformat: disable + loss = self._TrainModel(config) + self.assertAlmostEqual(loss, 0.08056, delta=self.loss_eps) + self._TestEnsemble(config) + + config = { + "lattice_sizes": [2, 3, 4], + "interpolation": "simplex", + "num_training_records": 100, + "num_training_epoch": 200, + "optimizer": tf.keras.optimizers.Adagrad, + "learning_rate": 10.0, + "x_generator": self._ScatterXUniformly, + "y_function": self._PseudoLinear, + } # pyformat: disable + loss = self._TrainModel(config) + self.assertAlmostEqual(loss, 0.04316, delta=self.loss_eps) + self._TestEnsemble(config) + + config = { + "lattice_sizes": [4, 5], + "interpolation": "simplex", + "num_training_records": 100, + "num_training_epoch": 100, + "optimizer": tf.keras.optimizers.Adagrad, + "learning_rate": 10.0, + "x_generator": self._ScatterXUniformly, + "y_function": self._WeightedSum, + } # pyformat: disable + loss = self._TrainModel(config) + self.assertAlmostEqual(loss, 0.0, 
delta=self.loss_eps) + self._TestEnsemble(config) + + config = { + "lattice_sizes": [2, 3, 4, 5], + "interpolation": "simplex", + "num_training_records": 100, + "num_training_epoch": 200, + "optimizer": tf.keras.optimizers.Adagrad, + "learning_rate": 30.0, + "x_generator": self._ScatterXUniformly, + "y_function": self._Max, + } # pyformat: disable + loss = self._TrainModel(config) + self.assertAlmostEqual(loss, 0.000122, delta=self.loss_eps) + self._TestEnsemble(config) + + config = { + "lattice_sizes": [2, 3, 4, 5], + "interpolation": "simplex", + "num_training_records": 100, + "num_training_epoch": 200, + "optimizer": tf.keras.optimizers.Adagrad, + "learning_rate": 30.0, + "x_generator": self._ScatterXUniformly, + "y_function": self._WeightedSum, + } # pyformat: disable + loss = self._TrainModel(config) + self.assertAlmostEqual(loss, 0.003793, delta=self.loss_eps) + self._TestEnsemble(config) + @parameterized.parameters( ([2, 3, 4], 6.429155), ([2, 3, 3], 13.390955), @@ -1775,10 +1890,10 @@ def testEqaulySizedDimsOptimization(self, lattice_sizes, expected_loss): self.assertAlmostEqual(loss, expected_loss, delta=self.loss_eps) @parameterized.parameters( - ([2, 2, 2, 2, 2, 2], 92), + ([2, 2, 2, 2, 2, 2], 81), ([2, 2, 3, 2, 3, 2], 117), ([2, 2, 2, 2, 3, 3], 102), - ([2, 2, 2, 2, 2, 2, 2, 2, 2], 125), + ([2, 2, 2, 2, 2, 2, 2, 2, 2], 114), ([2, 2, 2, 2, 2, 2, 3, 3, 3], 135), ) def testGraphSize(self, lattice_sizes, expected_graph_size): @@ -1798,6 +1913,139 @@ def testGraphSize(self, lattice_sizes, expected_graph_size): self.assertLessEqual(graph_size, expected_graph_size) + @parameterized.parameters( + ("random_uniform_or_linear_initializer", [3, 3, 3], + [([0, 1, 2], "peak")], + tf.keras.initializers.RandomUniform), + ("random_uniform_or_linear_initializer", [3, 3, 3], + [([0, 1, 2], "valley")], + tf.keras.initializers.RandomUniform), + ("random_uniform_or_linear_initializer", [3, 3, 3], + [([0, 1], "valley")], + ll.LinearInitializer), + ("random_uniform_or_linear_initializer", [3, 3, 3], + [([0, 1], "valley"), ([2], "peak")], + ll.LinearInitializer), + ("random_uniform_or_linear_initializer", [3, 3, 3], + None, + ll.LinearInitializer), + ("linear_initializer", [3, 3, 3], + [([0, 1], "valley")], + ll.LinearInitializer), + ("random_monotonic_initializer", [3, 3, 3], + [([0, 1], "valley")], + ll.RandomMonotonicInitializer)) + def testCreateKernelInitializer( + self, + kernel_initializer_id, lattice_sizes, joint_unimodalities, expected_type): + self.assertEqual( + expected_type, + type(ll.create_kernel_initializer( + kernel_initializer_id, + lattice_sizes, + monotonicities=None, + output_min=0.0, + output_max=1.0, + unimodalities=None, + joint_unimodalities=joint_unimodalities))) + + @parameterized.parameters( + # Single Unit + ( + [2, 2], + [[0.], [1.], [2.], [3.]], + [[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]], + [[0.], [1.], [2.], [3.]], + ), + ( + [3, 2], + [[-0.4], [0.9], [0.4], [-0.6], [-0.8], [0.6]], + [[0.8, 0.3], [0.3, 0.8], [2.0, 0.0], [2.0, 0.5], [2.0, 1.0]], + [[-0.06], [0.19], [-0.8], [-0.1], [0.6]], + ), + ( + [2, 2, 2, 2, 2], + [[-0.2], [-0.7], [-0.8], [0.8], [-0.3], [-0.6], [0.4], [0.5], [-0.3], + [0.3], [0.9], [0.4], [0.3], [-0.7], [0.1], [0.8], [-0.7], [-0.6], + [0.9], [-0.2], [0.3], [0.2], [0.9], [-0.1], [-0.6], [0.8], [0.4], + [1], [0.5], [0.2], [0.8], [-0.8]], + [[0.1, 0.2, 0.3, 0.4, 0.5], [0.5, 0.4, 0.3, 0.2, 0.1]], + [[-0.04], [-0.18]], + ), + ( + [3, 2, 2], + [[0], [1], [0.5], [0.1], [-0.5], [-0.9], [0.6], [-0.7], [-0.4], [0.2], + [0], [0.8]], + [[0.1, 0.2, 0.3], 
[0.3, 0.2, 0.1], [1.1, 0.2, 0.3], [1.7, 0.2, 0.1]], + [[0.04], [-0.06], [-0.43], [-0.27]], + ), + # Multi Unit + ( + [2, 2], + [ + [1., 11., 111.], + [2., 22., 222.], + [3., 33., 333.], + [4., 44., 444.], + ], + [ + [[0.0, 0.0], [0.0, 0.0], [1.0, 1.0]], + [[0.0, 1.0], [0.0, 1.0], [1.0, 0.0]], + [[1.0, 0.0], [1.0, 0.0], [0.0, 1.0]], + [[1.0, 1.0], [1.0, 1.0], [0.0, 0.0]], + ], + [ + [1., 11., 444.], + [2., 22., 333.], + [3., 33., 222.], + [4., 44., 111.], + ], + ), + ( + [3, 2], + [ + [-0.4, -4, -40, -400], + [0.9, 9, 90, 900], + [0.4, 4, 40, 400], + [-0.6, -6, -60, -600], + [-0.8, -8, -80, -800], + [0.6, 6, 60, 600], + ], + [ + [[0.8, 0.3], [2.0, 1.0], [0.8, 0.3], [2.0, 1.0]], + [[0.3, 0.8], [2.0, 0.5], [0.3, 0.8], [2.0, 0.5]], + [[2.0, 0.0], [2.0, 0.0], [2.0, 0.0], [2.0, 0.0]], + [[2.0, 0.5], [0.3, 0.8], [2.0, 0.5], [0.3, 0.8]], + [[2.0, 1.0], [0.8, 0.3], [2.0, 1.0], [0.8, 0.3]], + ], + [ + [-0.06, 6., -6., 600.], + [0.19, -1., 19., -100.], + [-0.8, -8., -80., -800.], + [-0.1, 1.9, -10., 190.], + [0.6, -0.6, 60., -60.], + ], + ), + ) + def testSimplexInterpolation(self, lattice_sizes, kernel, inputs, + expected_outputs): + if self.disable_all: + return + + kernel = tf.constant(kernel, dtype=tf.float32) + inputs = tf.constant(inputs, dtype=tf.float32) + units = int(kernel.shape[1]) + model = tf.keras.models.Sequential([ + ll.Lattice( + lattice_sizes, + units=units, + interpolation="simplex", + kernel_initializer=tf.keras.initializers.Constant(kernel), + ), + ]) + outputs = model.predict(inputs) + self.assertAllClose(outputs, expected_outputs) + if __name__ == "__main__": tf.test.main() diff --git a/tensorflow_lattice/python/linear_layer.py b/tensorflow_lattice/python/linear_layer.py index c464dbd..6fba8a2 100644 --- a/tensorflow_lattice/python/linear_layer.py +++ b/tensorflow_lattice/python/linear_layer.py @@ -305,13 +305,13 @@ def assert_constraints(self, eps=1e-4): """Asserts that weights satisfy all constraints. In graph mode builds and returns list of assertion ops. - In eager mode directly executes assetions. + In eager mode directly executes assertions. Args: eps: Allowed constraints violation. Returns: - List of assertion ops in graph mode or immideately asserts in eager mode. + List of assertion ops in graph mode or immediately asserts in eager mode. 
""" return linear_lib.assert_constraints( weights=self.kernel, diff --git a/tensorflow_lattice/python/premade.py b/tensorflow_lattice/python/premade.py index 49e350a..1dd28ce 100644 --- a/tensorflow_lattice/python/premade.py +++ b/tensorflow_lattice/python/premade.py @@ -134,7 +134,13 @@ def __init__(self, model_config=None, dtype=tf.float32, **kwargs): dtype=dtype)) if len(lattice_outputs) > 1: - averaged_lattice_output = tf.keras.layers.Average()(lattice_outputs) + if model_config.use_linear_combination: + averaged_lattice_output = premade_lib.build_linear_combination_layer( + ensemble_outputs=lattice_outputs, + model_config=model_config, + dtype=dtype) + else: + averaged_lattice_output = tf.keras.layers.Average()(lattice_outputs) else: averaged_lattice_output = lattice_outputs[0] if model_config.output_calibration: @@ -163,7 +169,6 @@ def get_config(self): @classmethod def from_config(cls, config, custom_objects=None): - custom_objects = _extend_custom_objects(custom_objects) model = super(CalibratedLatticeEnsemble, cls).from_config( config, custom_objects=custom_objects) try: @@ -279,7 +284,6 @@ def get_config(self): @classmethod def from_config(cls, config, custom_objects=None): - custom_objects = _extend_custom_objects(custom_objects) model = super(CalibratedLattice, cls).from_config( config, custom_objects=custom_objects) try: @@ -363,8 +367,7 @@ def __init__(self, model_config=None, dtype=tf.float32, **kwargs): weighted_average = ( model_config.output_min is not None or - model_config.output_max is not None or - model_config.output_calibration) + model_config.output_max is not None or model_config.output_calibration) linear_output = premade_lib.build_linear_layer( linear_input=submodels_inputs[0], feature_configs=model_config.feature_configs, @@ -399,7 +402,6 @@ def get_config(self): @classmethod def from_config(cls, config, custom_objects=None): - custom_objects = _extend_custom_objects(custom_objects) model = super(CalibratedLinear, cls).from_config( config, custom_objects=custom_objects) try: @@ -463,9 +465,7 @@ def __init__(self, model_config=None, dtype=tf.float32, **kwargs): premade_lib.verify_config(model_config) # Get feature configs and construct model. input_layer = premade_lib.build_input_layer( - feature_configs=model_config.feature_configs, - dtype=dtype, - ragged=True) + feature_configs=model_config.feature_configs, dtype=dtype, ragged=True) # We need to construct middle_dimension calibrated_lattices for the # aggregation layer. Note that we cannot do this in premade_lib because @@ -475,6 +475,7 @@ def __init__(self, model_config=None, dtype=tf.float32, **kwargs): # aggregation. calibrated_lattice_config = configs.CalibratedLatticeConfig( feature_configs=model_config.feature_configs, + interpolation=model_config.aggregation_lattice_interpolation, regularizer_configs=model_config.regularizer_configs, output_min=-1.0, output_max=1.0, @@ -487,8 +488,14 @@ def __init__(self, model_config=None, dtype=tf.float32, **kwargs): premade_lib.LayerOutputRange.INPUT_TO_FINAL_CALIBRATION if model_config.output_calibration else premade_lib.LayerOutputRange.MODEL_OUTPUT) + # Change input layer into a list based on model_config.feature_configs. + # This is the order of inputs expected by calibrated_lattice_models. 
+ inputs = [ + input_layer[feature_config.name] + for feature_config in model_config.feature_configs + ] aggregation_output = premade_lib.build_aggregation_layer( - aggregation_input_layer=input_layer, + aggregation_input_layer=inputs, model_config=model_config, calibrated_lattice_models=calibrated_lattice_models, layer_output_range=aggregation_layer_output_range, @@ -504,12 +511,9 @@ def __init__(self, model_config=None, dtype=tf.float32, **kwargs): model_output = aggregation_output # Define inputs and initialize model. - inputs = [ - input_layer[feature_config.name] - for feature_config in model_config.feature_configs - ] - super(AggregateFunction, self).__init__( - inputs=inputs, outputs=model_output) + kwargs['inputs'] = inputs + kwargs['outputs'] = model_output + super(AggregateFunction, self).__init__(**kwargs) def get_config(self): """Returns a configuration dictionary.""" @@ -520,7 +524,6 @@ def get_config(self): @classmethod def from_config(cls, config, custom_objects=None): - custom_objects = _extend_custom_objects(custom_objects) model = super(AggregateFunction, cls).from_config( config, custom_objects=custom_objects) try: @@ -535,8 +538,17 @@ def from_config(cls, config, custom_objects=None): return model -def _extend_custom_objects(custom_objects): - """Extends the given custom_objects mapping with TFL objects.""" +def get_custom_objects(custom_objects=None): + """Creates and returns a dictionary mapping names to custom objects. + + Args: + custom_objects: Optional dictionary mapping names (strings) to custom + classes or functions to be considered during deserialization. If provided, + the returned mapping will be extended to contain this one. + + Returns: + A dictionary mapping names (strings) to tensorflow lattice custom objects. + """ tfl_custom_objects = { 'AggregateFunction': AggregateFunction, diff --git a/tensorflow_lattice/python/premade_lib.py b/tensorflow_lattice/python/premade_lib.py index fc3be54..101d70d 100644 --- a/tensorflow_lattice/python/premade_lib.py +++ b/tensorflow_lattice/python/premade_lib.py @@ -19,6 +19,7 @@ import collections import copy +import enum import itertools from . import aggregation_layer @@ -31,7 +32,6 @@ from . import pwl_calibration_lib from absl import logging -import enum import numpy as np import six @@ -43,6 +43,7 @@ INPUT_LAYER_NAME = 'tfl_input' LATTICE_LAYER_NAME = 'tfl_lattice' LINEAR_LAYER_NAME = 'tfl_linear' +OUTPUT_LINEAR_COMBINATION_LAYER_NAME = 'tfl_output_linear_combination' OUTPUT_CALIB_LAYER_NAME = 'tfl_output_calib' # Prefix for passthrough (identity) nodes for shared calibration. @@ -296,7 +297,9 @@ def build_aggregation_layer(aggregation_input_layer, model_config, """Creates an aggregation layer using the given calibrated lattice models. Args: - aggregation_input_layer: A mapping from feature name to `tf.keras.Input`. + aggregation_input_layer: A list or a mapping from feature name to + `tf.keras.Input`, in the order or format expected by + `calibrated_lattice_models`. model_config: Model configuration object describing model architecture. Should be one of the model configs in `tfl.configs`. 
calibrated_lattice_models: A list of calibrated lattice models of size @@ -359,6 +362,7 @@ def build_aggregation_layer(aggregation_input_layer, model_config, output_min=output_min, output_max=output_max, clip_inputs=False, + interpolation=model_config.middle_lattice_interpolation, kernel_initializer=kernel_initializer, dtype=dtype, name=lattice_layer_name, @@ -527,6 +531,7 @@ def build_lattice_layer(lattice_input, feature_configs, model_config, output_min=output_min, output_max=output_max, clip_inputs=False, + interpolation=model_config.interpolation, kernel_regularizer=lattice_regularizers, kernel_initializer=kernel_initializer, dtype=dtype, @@ -534,6 +539,50 @@ def build_lattice_layer(lattice_input, feature_configs, model_config, lattice_input) +def build_linear_combination_layer(ensemble_outputs, model_config, dtype): + """Creates a `tfl.layers.Linear` layer initialized to be an average. + + Args: + ensemble_outputs: Ensemble outputs to be linearly combined. + model_config: Model configuration object describing model architecture. + Should be one of the model configs in `tfl.configs`. + dtype: dtype + + Returns: + A `tfl.layers.Linear` instance. + """ + if isinstance(ensemble_outputs, list): + num_input_dims = len(ensemble_outputs) + linear_input = tf.keras.layers.Concatenate(axis=1)(ensemble_outputs) + else: + num_input_dims = int(ensemble_outputs.shape[1]) + linear_input = ensemble_outputs + kernel_initializer = tf.keras.initializers.Constant(1.0 / num_input_dims) + bias_initializer = tf.keras.initializers.Constant(0) + + if (not model_config.output_calibration and + model_config.output_min is None and model_config.output_max is None): + normalization_order = None + else: + # We need to use weighted average to keep the output range. + normalization_order = 1 + # Bias term cannot be used when this layer should have bounded output. + if model_config.use_bias: + raise ValueError('Cannot use a bias term in linear combination with ' + 'output bounds or output calibration') + + return linear_layer.Linear( + num_input_dims=num_input_dims, + monotonicities=['increasing'] * num_input_dims, + normalization_order=normalization_order, + use_bias=model_config.use_bias, + kernel_initializer=kernel_initializer, + bias_initializer=bias_initializer, + dtype=dtype, + name=OUTPUT_LINEAR_COMBINATION_LAYER_NAME)( + linear_input) + + def build_output_calibration_layer(output_calibration_input, model_config, dtype): """Creates a monotonic output calibration layer with inputs range [0, 1]. 
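The `build_linear_combination_layer` helper above replaces the plain `tf.keras.layers.Average` with a trainable `tfl.layers.Linear` whose kernel starts at `1/n`, so the ensemble initially behaves like an average and can then learn the combination weights. A rough standalone sketch of that construction (tensor values and variable names here are illustrative only, not taken from the patch):

```python
import tensorflow as tf
import tensorflow_lattice as tfl

# Pretend these are the outputs of three lattices in an ensemble.
ensemble_outputs = [tf.constant([[0.2]]), tf.constant([[0.6]]), tf.constant([[1.0]])]
linear_input = tf.keras.layers.Concatenate(axis=1)(ensemble_outputs)
combiner = tfl.layers.Linear(
    num_input_dims=3,
    monotonicities=['increasing'] * 3,
    normalization_order=1,  # keeps the result a weighted average (bounded output)
    use_bias=False,
    kernel_initializer=tf.keras.initializers.Constant(1.0 / 3))
print(combiner(linear_input))  # ~[[0.6]], i.e. the plain average, before training
```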
diff --git a/tensorflow_lattice/python/premade_test.py b/tensorflow_lattice/python/premade_test.py index 47e7921..91a55ef 100644 --- a/tensorflow_lattice/python/premade_test.py +++ b/tensorflow_lattice/python/premade_test.py @@ -20,6 +20,7 @@ import copy import json +import tempfile from absl import logging import numpy as np import pandas as pd @@ -28,6 +29,13 @@ from tensorflow_lattice.python import premade from tensorflow_lattice.python import premade_lib + +fake_data = { + 'train_xs': [np.array([1]), np.array([3]), np.array([0])], + 'train_ys': np.array([1]), + 'eval_xs': [np.array([2]), np.array([30]), np.array([-3])] +} + unspecified_feature_configs = [ configs.FeatureConfig( name='numerical_1', @@ -403,7 +411,7 @@ def testLatticeEnsembleFromConfig(self): output_initialization=[-1.0, 1.0]) model = premade.CalibratedLatticeEnsemble(model_config) loaded_model = premade.CalibratedLatticeEnsemble.from_config( - model.get_config()) + model.get_config(), custom_objects=premade.get_custom_objects()) self.assertEqual( json.dumps(model.get_config(), sort_keys=True, cls=self.Encoder), json.dumps(loaded_model.get_config(), sort_keys=True, cls=self.Encoder)) @@ -421,7 +429,28 @@ def testLatticeFromConfig(self): output_calibration_num_keypoints=6, output_initialization=[0.0, 1.0]) model = premade.CalibratedLattice(model_config) - loaded_model = premade.CalibratedLattice.from_config(model.get_config()) + loaded_model = premade.CalibratedLattice.from_config( + model.get_config(), custom_objects=premade.get_custom_objects()) + self.assertEqual( + json.dumps(model.get_config(), sort_keys=True, cls=self.Encoder), + json.dumps(loaded_model.get_config(), sort_keys=True, cls=self.Encoder)) + + def testLatticeSimplexFromConfig(self): + model_config = configs.CalibratedLatticeConfig( + feature_configs=copy.deepcopy(feature_configs), + regularizer_configs=[ + configs.RegularizerConfig('calib_wrinkle', l2=1e-3), + configs.RegularizerConfig('torsion', l2=1e-3), + ], + output_min=0.0, + output_max=1.0, + interpolation='simplex', + output_calibration=True, + output_calibration_num_keypoints=6, + output_initialization=[0.0, 1.0]) + model = premade.CalibratedLattice(model_config) + loaded_model = premade.CalibratedLattice.from_config( + model.get_config(), custom_objects=premade.get_custom_objects()) self.assertEqual( json.dumps(model.get_config(), sort_keys=True, cls=self.Encoder), json.dumps(loaded_model.get_config(), sort_keys=True, cls=self.Encoder)) @@ -440,7 +469,8 @@ def testLinearFromConfig(self): output_calibration_num_keypoints=6, output_initialization=[0.0, 1.0]) model = premade.CalibratedLinear(model_config) - loaded_model = premade.CalibratedLinear.from_config(model.get_config()) + loaded_model = premade.CalibratedLinear.from_config( + model.get_config(), custom_objects=premade.get_custom_objects()) self.assertEqual( json.dumps(model.get_config(), sort_keys=True, cls=self.Encoder), json.dumps(loaded_model.get_config(), sort_keys=True, cls=self.Encoder)) @@ -460,7 +490,8 @@ def testAggregateFromConfig(self): output_calibration_num_keypoints=8, output_initialization=[0.0, 1.0]) model = premade.AggregateFunction(model_config) - loaded_model = premade.AggregateFunction.from_config(model.get_config()) + loaded_model = premade.AggregateFunction.from_config( + model.get_config(), custom_objects=premade.get_custom_objects()) self.assertEqual( json.dumps(model.get_config(), sort_keys=True, cls=self.Encoder), json.dumps(loaded_model.get_config(), sort_keys=True, cls=self.Encoder)) @@ -518,6 +549,118 @@ def 
testCalibratedLatticeEnsembleCrystals(self): logging.info(results) self.assertGreater(results[1], 0.85) + def testLatticeEnsembleH5FormatSaveLoad(self): + model_config = configs.CalibratedLatticeEnsembleConfig( + feature_configs=copy.deepcopy(feature_configs), + lattices=[['numerical_1', 'categorical'], + ['numerical_2', 'categorical']], + num_lattices=2, + lattice_rank=2, + separate_calibrators=True, + regularizer_configs=[ + configs.RegularizerConfig('calib_hessian', l2=1e-3), + configs.RegularizerConfig('torsion', l2=1e-4), + ], + output_min=-1.0, + output_max=1.0, + output_calibration=True, + output_calibration_num_keypoints=5, + output_initialization=[-1.0, 1.0]) + model = premade.CalibratedLatticeEnsemble(model_config) + # Compile and fit model. + model.compile(loss='mse', optimizer=tf.keras.optimizers.Adam(0.1)) + model.fit(fake_data['train_xs'], fake_data['train_ys']) + # Save model using H5 format. + with tempfile.NamedTemporaryFile(suffix='.h5') as f: + tf.keras.models.save_model(model, f.name) + loaded_model = tf.keras.models.load_model( + f.name, custom_objects=premade.get_custom_objects()) + self.assertAllClose( + model.predict(fake_data['eval_xs']), + loaded_model.predict(fake_data['eval_xs'])) + + def testLatticeH5FormatSaveLoad(self): + model_config = configs.CalibratedLatticeConfig( + feature_configs=copy.deepcopy(feature_configs), + regularizer_configs=[ + configs.RegularizerConfig('calib_wrinkle', l2=1e-3), + configs.RegularizerConfig('torsion', l2=1e-3), + ], + output_min=0.0, + output_max=1.0, + output_calibration=True, + output_calibration_num_keypoints=6, + output_initialization=[0.0, 1.0]) + model = premade.CalibratedLattice(model_config) + # Compile and fit model. + model.compile(loss='mse', optimizer=tf.keras.optimizers.Adam(0.1)) + model.fit(fake_data['train_xs'], fake_data['train_ys']) + # Save model using H5 format. + with tempfile.NamedTemporaryFile(suffix='.h5') as f: + tf.keras.models.save_model(model, f.name) + loaded_model = tf.keras.models.load_model( + f.name, custom_objects=premade.get_custom_objects()) + self.assertAllClose( + model.predict(fake_data['eval_xs']), + loaded_model.predict(fake_data['eval_xs'])) + + def testLinearH5FormatSaveLoad(self): + model_config = configs.CalibratedLinearConfig( + feature_configs=copy.deepcopy(feature_configs), + regularizer_configs=[ + configs.RegularizerConfig('calib_hessian', l2=1e-4), + configs.RegularizerConfig('torsion', l2=1e-3), + ], + use_bias=True, + output_min=0.0, + output_max=1.0, + output_calibration=True, + output_calibration_num_keypoints=6, + output_initialization=[0.0, 1.0]) + model = premade.CalibratedLinear(model_config) + # Compile and fit model. + model.compile(loss='mse', optimizer=tf.keras.optimizers.Adam(0.1)) + model.fit(fake_data['train_xs'], fake_data['train_ys']) + # Save model using H5 format. 
+ with tempfile.NamedTemporaryFile(suffix='.h5') as f: + tf.keras.models.save_model(model, f.name) + loaded_model = tf.keras.models.load_model( + f.name, custom_objects=premade.get_custom_objects()) + self.assertAllClose( + model.predict(fake_data['eval_xs']), + loaded_model.predict(fake_data['eval_xs'])) + + def testAggregateH5FormatSaveLoad(self): + model_config = configs.AggregateFunctionConfig( + feature_configs=feature_configs, + regularizer_configs=[ + configs.RegularizerConfig('calib_hessian', l2=1e-4), + configs.RegularizerConfig('torsion', l2=1e-3), + ], + middle_calibration=True, + middle_monotonicity='increasing', + output_min=0.0, + output_max=1.0, + output_calibration=True, + output_calibration_num_keypoints=8, + output_initialization=[0.0, 1.0]) + model = premade.AggregateFunction(model_config) + # Compile and fit model. + model.compile(loss='mse', optimizer=tf.keras.optimizers.Adam(0.1)) + model.fit(fake_data['train_xs'], fake_data['train_ys']) + # Save model using H5 format. + with tempfile.NamedTemporaryFile(suffix='.h5') as f: + # Note: because of naming clashes in the optimizer, we cannot include it + # when saving in HDF5. The keras team has informed us that we should not + # push to support this since SavedModel format is the new default and no + # new HDF5 functionality is desired. + tf.keras.models.save_model(model, f.name, include_optimizer=False) + loaded_model = tf.keras.models.load_model( + f.name, custom_objects=premade.get_custom_objects()) + self.assertAllClose( + model.predict(fake_data['eval_xs']), + loaded_model.predict(fake_data['eval_xs'])) + if __name__ == '__main__': tf.test.main() diff --git a/tensorflow_lattice/python/pwl_calibration_layer.py b/tensorflow_lattice/python/pwl_calibration_layer.py index 236ce31..626b94e 100644 --- a/tensorflow_lattice/python/pwl_calibration_layer.py +++ b/tensorflow_lattice/python/pwl_calibration_layer.py @@ -396,7 +396,7 @@ def call(self, inputs): if inputs.dtype != self._interpolation_keypoints.dtype: raise ValueError("dtype(%s) of input to PWLCalibration layer does not " "correspond to dtype(%s) of keypoints. You can enforce " - "dtype of keypoints be explicitly providing 'dtype' " + "dtype of keypoints by explicitly providing 'dtype' " "parameter to layer constructor or by passing keypoints " "in such format which by default will be converted into " "desired one." % @@ -475,13 +475,13 @@ def assert_constraints(self, eps=1e-6): In graph mode builds and returns list of assertion ops. Note that ops will be created at the moment when this function is being called. - In eager mode directly executes assetions. + In eager mode directly executes assertions. Args: eps: Allowed constraints violation. Returns: - List of assertion ops in graph mode or immideately asserts in eager mode. + List of assertion ops in graph mode or immediately asserts in eager mode. """ # Assert by computing outputs for keypoints and testing them against # constraints. @@ -569,7 +569,7 @@ def __call__(self, shape, dtype=None, partition_info=None): """Returns weights of PWL calibration layer. Args: - shape: Must be rank-2 tensor with of shape `(k, units)` where `k >= 2`. + shape: Must be a collection of the form `(k, units)` where `k >= 2`. dtype: Standard Keras initializer param. partition_info: Standard Keras initializer param. 
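The new H5 round-trip tests above all follow one pattern: save with `tf.keras.models.save_model` and reload with the TFL custom objects returned by `premade.get_custom_objects()`. A small sketch of that pattern (the helper name and path are illustrative; `model` is any compiled premade TFL model):

```python
import tensorflow as tf
from tensorflow_lattice.python import premade

def save_and_reload(model, path='/tmp/tfl_model.h5'):
  """Round-trips a premade TFL model through the HDF5 format."""
  # As in the AggregateFunction test above, drop the optimizer if HDF5
  # serialization runs into optimizer naming clashes.
  tf.keras.models.save_model(model, path, include_optimizer=False)
  return tf.keras.models.load_model(
      path, custom_objects=premade.get_custom_objects())
```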
diff --git a/tensorflow_lattice/python/pwl_calibration_lib.py b/tensorflow_lattice/python/pwl_calibration_lib.py index a3eff70..e859bbc 100644 --- a/tensorflow_lattice/python/pwl_calibration_lib.py +++ b/tensorflow_lattice/python/pwl_calibration_lib.py @@ -121,7 +121,7 @@ def linear_initializer(shape, dtype=None): """Initializes PWL calibration layer to represent linear function. - PWL calibration layer weights have shape `(knum_keypoints, units)`. First row + PWL calibration layer weights have shape `(num_keypoints, units)`. First row represents bias. All remaining represent delta in y-value compare to previous point. Aka heights of segments. @@ -469,7 +469,7 @@ def project_all_constraints(weights, For all combinations of constraints except the case where bounds constraints are specified without monotonicity constraints we properly project into - nearest point with respect to L2 norm. For later case we use a heuristic to + nearest point with respect to L2 norm. For latter case we use a heuristic to map input point into some feasible point with no guarantees on how close this point is to the true projection. diff --git a/tensorflow_lattice/python/pwl_calibration_sonnet_module.py b/tensorflow_lattice/python/pwl_calibration_sonnet_module.py new file mode 100644 index 0000000..fbc044a --- /dev/null +++ b/tensorflow_lattice/python/pwl_calibration_sonnet_module.py @@ -0,0 +1,543 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Piecewise linear calibration layer. + +Sonnet (v2) implementation of tensorflow lattice pwl calibration module. Module +takes single or multi-dimensional input and transforms it using piecewise linear +functions following monotonicity, convexity/concavity and bounds constraints if +specified. +""" + +# TODO: Add built-in regularizers like laplacian, hessian, etc. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from . import pwl_calibration_lib + +from absl import logging +import numpy as np +import sonnet as snt +import tensorflow as tf + +INTERPOLATION_KEYPOINTS_NAME = "interpolation_keypoints" +LENGTHS_NAME = "lengths" +MISSING_INPUT_VALUE_NAME = "missing_input_value" +PWL_CALIBRATION_KERNEL_NAME = "pwl_calibration_kernel" +PWL_CALIBRATION_MISSING_OUTPUT_NAME = "pwl_calibration_missing_output" + + +class PWLCalibration(snt.Module): + # pyformat: disable + """Piecewise linear calibration layer. + + Module takes input of shape `(batch_size, units)` or `(batch_size, 1)` and + transforms it using `units` number of piecewise linear functions following + monotonicity, convexity and bounds constraints if specified. If multi + dimensional input is provides, each output will be for the corresponding + input, otherwise all PWL functions will act on the same input. All units share + the same configuration, but each has their separate set of trained + parameters. + + Input shape: + Single input should be a rank-2 tensor with shape: `(batch_size, units)` or + `(batch_size, 1)`. 
The input can also be a list of two tensors of the same + shape where the first tensor is the regular input tensor and the second is the + `is_missing` tensor. In the `is_missing` tensor, 1.0 represents missing input + and 0.0 represents available input. + + Output shape: + Rank-2 tensor with shape: `(batch_size, units)`. + + Attributes: + - All `__init__` arguments. + kernel: TF variable which stores weights of piecewise linear function. + missing_output: TF variable which stores output learned for missing input. + Or TF Constant which stores `missing_output_value` if one is provided. + Available only if `impute_missing` is True. + + Example: + + ```python + calibrator = tfl.sonnet_modules.PWLCalibration( + # Key-points of piecewise-linear function. + input_keypoints=np.linspace(1., 4., num=4), + # Output can be bounded, e.g. when this layer feeds into a lattice. + output_min=0.0, + output_max=2.0, + # You can specify monotonicity and other shape constraints for the layer. + monotonicity='increasing', + ) + ``` + """ + # pyformat: enable + + def __init__(self, + input_keypoints, + units=1, + output_min=None, + output_max=None, + clamp_min=False, + clamp_max=False, + monotonicity="none", + convexity="none", + is_cyclic=False, + kernel_init="equal_heights", + impute_missing=False, + missing_input_value=None, + missing_output_value=None, + num_projection_iterations=8, + **kwargs): + # pyformat: disable + """Initializes an instance of `PWLCalibration`. + + Args: + input_keypoints: Ordered list of keypoints of piecewise linear function. + Can be anything accepted by tf.convert_to_tensor(). + units: Output dimension of the layer. See class comments for details. + output_min: Minimum output of calibrator. + output_max: Maximum output of calibrator. + clamp_min: For monotonic calibrators ensures that output_min is reached. + clamp_max: For monotonic calibrators ensures that output_max is reached. + monotonicity: Constraints piecewise linear function to be monotonic using + 'increasing' or 1 to indicate increasing monotonicity, 'decreasing' or + -1 to indicate decreasing monotonicity and 'none' or 0 to indicate no + monotonicity constraints. + convexity: Constraints piecewise linear function to be convex or concave. + Convexity is indicated by 'convex' or 1, concavity is indicated by + 'concave' or -1, 'none' or 0 indicates no convexity/concavity + constraints. + Concavity together with increasing monotonicity as well as convexity + together with decreasing monotonicity results in diminishing return + constraints. + Consider increasing the value of `num_projection_iterations` if + convexity is specified, especially with larger number of keypoints. + is_cyclic: Whether the output for last keypoint should be identical to + output for first keypoint. This is useful for features such as + "time of day" or "degree of turn". If inputs are discrete and exactly + match keypoints then is_cyclic will have an effect only if TFL + regularizers are being used. + kernel_init: None or one of: + - String `"equal_heights"`: For pieces of pwl function to have equal + heights. + - String `"equal_slopes"`: For pieces of pwl function to have equal + slopes. + - Any Sonnet initializer object. If you are passing such object make + sure that you know how this module uses the variables. + impute_missing: Whether to learn an output for cases where input data is + missing. If set to True, either `missing_input_value` should be + initialized, or the `call()` method should get pair of tensors. 
See + class input shape description for more details. + missing_input_value: If set, all inputs which are equal to this value will + be considered as missing. Can not be set if `impute_missing` is False. + missing_output_value: If set, instead of learning output for missing + inputs, simply maps them into this value. Can not be set if + `impute_missing` is False. + num_projection_iterations: Number of iterations of the Dykstra's + projection algorithm. Constraints are strictly satisfied at the end of + each update, but the update will be closer to a true L2 projection with + higher number of iterations. See + `tfl.pwl_calibration_lib.project_all_constraints` for more details. + **kwargs: Other args passed to `snt.Module` initializer. + + Raises: + ValueError: If layer hyperparameters are invalid. + """ + # pyformat: enable + super(PWLCalibration, self).__init__(**kwargs) + + pwl_calibration_lib.verify_hyperparameters( + input_keypoints=input_keypoints, + output_min=output_min, + output_max=output_max, + monotonicity=monotonicity, + convexity=convexity, + is_cyclic=is_cyclic) + if missing_input_value is not None and not impute_missing: + raise ValueError("'missing_input_value' is specified, but " + "'impute_missing' is set to False. " + "'missing_input_value': " + str(missing_input_value)) + if missing_output_value is not None and not impute_missing: + raise ValueError("'missing_output_value' is specified, but " + "'impute_missing' is set to False. " + "'missing_output_value': " + str(missing_output_value)) + if input_keypoints is None: + raise ValueError("'input_keypoints' can't be None") + if monotonicity is None: + raise ValueError("'monotonicity' can't be None. Did you mean '0'?") + + self.input_keypoints = input_keypoints + self.units = units + self.output_min = output_min + self.output_max = output_max + self.clamp_min = clamp_min + self.clamp_max = clamp_max + (self._output_init_min, self._output_init_max, self._output_min_constraints, + self._output_max_constraints + ) = pwl_calibration_lib.convert_all_constraints(self.output_min, + self.output_max, + self.clamp_min, + self.clamp_max) + + self.monotonicity = monotonicity + self.convexity = convexity + self.is_cyclic = is_cyclic + + if kernel_init == "equal_heights": + self.kernel_init = _UniformOutputInitializer( + output_min=self._output_init_min, + output_max=self._output_init_max, + monotonicity=self.monotonicity) + elif kernel_init == "equal_slopes": + self.kernel_init = _UniformOutputInitializer( + output_min=self._output_init_min, + output_max=self._output_init_max, + monotonicity=self.monotonicity, + keypoints=self.input_keypoints) + + self.impute_missing = impute_missing + self.missing_input_value = missing_input_value + self.missing_output_value = missing_output_value + self.num_projection_iterations = num_projection_iterations + + @snt.once + def _create_parameters_once(self, inputs): + """Creates the variables that will be reused on each call of the module.""" + self.dtype = tf.convert_to_tensor(self.input_keypoints).dtype + input_keypoints = np.array(self.input_keypoints) + # Don't need last keypoint for interpolation because we need only beginnings + # of intervals. 
+ self._interpolation_keypoints = tf.constant( + input_keypoints[:-1], + dtype=self.dtype, + name=INTERPOLATION_KEYPOINTS_NAME) + self._lengths = tf.constant( + input_keypoints[1:] - input_keypoints[:-1], + dtype=self.dtype, + name=LENGTHS_NAME) + + constraints = _PWLCalibrationConstraints( + monotonicity=self.monotonicity, + convexity=self.convexity, + lengths=self._lengths, + output_min=self.output_min, + output_max=self.output_max, + output_min_constraints=self._output_min_constraints, + output_max_constraints=self._output_max_constraints, + num_projection_iterations=self.num_projection_iterations) + + # If 'is_cyclic' is specified - last weight will be computed from previous + # weights in order to connect last keypoint with first. + num_weights = input_keypoints.size - self.is_cyclic + + # PWL calibration layer kernel is units-column matrix. First row of matrix + # represents bias. All remaining represent delta in y-value compare to + # previous point. Aka heights of segments. + self.kernel = tf.Variable( + initial_value=self.kernel_init([num_weights, self.units], + dtype=self.dtype), + name=PWL_CALIBRATION_KERNEL_NAME, + constraint=constraints) + + if self.impute_missing: + if self.missing_input_value is not None: + self._missing_input_value_tensor = tf.constant( + self.missing_input_value, + dtype=self.dtype, + name=MISSING_INPUT_VALUE_NAME) + else: + self._missing_input_value_tensor = None + + if self.missing_output_value is not None: + self.missing_output = tf.constant( + self.missing_output_value, shape=[1, self.units], dtype=self.dtype) + else: + missing_init = (self._output_init_min + self._output_init_max) / 2.0 + missing_constraints = _NaiveBoundsConstraints( + lower_bound=self.output_min, upper_bound=self.output_max) + initializer = snt.initializers.Constant(missing_init) + self.missing_output = tf.Variable( + initial_value=initializer([1, self.units], self.dtype), + name=PWL_CALIBRATION_MISSING_OUTPUT_NAME, + constraint=missing_constraints) + + def __call__(self, inputs): + """Standard Sonnet __call__() method.. + + Args: + inputs: Either input tensor or list of 2 elements: input tensor and + `is_missing` tensor. + + Returns: + Calibrated input tensor. + + Raises: + ValueError: If `is_missing` tensor specified incorrectly. + """ + self._create_parameters_once(inputs) + is_missing = None + if isinstance(inputs, list): + # Only 2 element lists are allowed. When such list is given - second + # element represents 'is_missing' tensor encoded as float value. + if not self.impute_missing: + raise ValueError("Multiple inputs for PWLCalibration module assume " + "regular input tensor and 'is_missing' tensor, but " + "this instance of a layer is not configured to handle " + "missing value. See 'impute_missing' parameter.") + if len(inputs) > 2: + raise ValueError("Multiple inputs for PWLCalibration module assume " + "normal input tensor and 'is_missing' tensor, but more" + " than 2 tensors given. 'inputs': " + str(inputs)) + if len(inputs) == 2: + inputs, is_missing = inputs + if is_missing.shape.as_list() != inputs.shape.as_list(): + raise ValueError( + "is_missing shape %s does not match inputs shape %s for " + "PWLCalibration module" % + (str(is_missing.shape), str(inputs.shape))) + else: + [inputs] = inputs + if len(inputs.shape) != 2 or (inputs.shape[1] != self.units and + inputs.shape[1] != 1): + raise ValueError("Shape of input tensor for PWLCalibration module must" + " be [-1, units] or [-1, 1]. 
It is: " + + str(inputs.shape)) + + if self._interpolation_keypoints.dtype != inputs.dtype: + raise ValueError("dtype(%s) of input to PWLCalibration module does not " + "correspond to dtype(%s) of keypoints. You can enforce " + "dtype of keypoints by passing keypoints " + "in such format which by default will be converted into " + "the desired one." % + (inputs.dtype, self._interpolation_keypoints.dtype)) + # Here is calibration. Everything else is handling of missing. + if inputs.shape[1] > 1: + # Add dimension to multi dim input to get shape [batch_size, units, 1]. + # Interpolation will have shape [batch_size, units, weights]. + inputs_to_calibration = tf.expand_dims(inputs, -1) + else: + inputs_to_calibration = inputs + interpolation_weights = pwl_calibration_lib.compute_interpolation_weights( + inputs_to_calibration, self._interpolation_keypoints, self._lengths) + if self.is_cyclic: + # Need to add such last height to make all heights to sum up to 0.0 in + # order to make calibrator cyclic. + bias_and_heights = tf.concat( + [self.kernel, -tf.reduce_sum(self.kernel[1:], + axis=0, keepdims=True)], + axis=0) + else: + bias_and_heights = self.kernel + + # bias_and_heights has shape [weight, units]. + if inputs.shape[1] > 1: + # Multi dim input has interpolation shape [batch_size, units, weights]. + result = tf.reduce_sum( + interpolation_weights * tf.transpose(bias_and_heights), axis=-1) + else: + # Single dim input has interpolation shape [batch_size, weights]. + result = tf.matmul(interpolation_weights, bias_and_heights) + + if self.impute_missing: + if is_missing is None: + if self.missing_input_value is None: + raise ValueError("PWLCalibration layer is configured to impute " + "missing but no 'missing_input_value' specified and " + "'is_missing' tensor is not given.") + assert self._missing_input_value_tensor is not None + is_missing = tf.cast( + tf.equal(inputs, self._missing_input_value_tensor), + dtype=self.dtype) + result = is_missing * self.missing_output + (1.0 - is_missing) * result + return result + + +class _UniformOutputInitializer(snt.initializers.Initializer): + # pyformat: disable + """Initializes PWL calibration layer to represent linear function. + + PWL calibration layer weights are one-d tensor. First element of tensor + represents bias. All remaining represent delta in y-value compare to previous + point. Aka heights of segments. + + Attributes: + - All `__init__` arguments. + """ + # pyformat: enable + + def __init__(self, output_min, output_max, monotonicity, keypoints=None): + # pyformat: disable + """Initializes an instance of `_UniformOutputInitializer`. + + Args: + output_min: Minimum value of PWL calibration output after initialization. + output_max: Maximum value of PWL calibration output after initialization. + monotonicity: + - if 'none' or 'increasing', the returned function will go from + `(input_min, output_min)` to `(input_max, output_max)`. + - if 'decreasing', the returned function will go from + `(input_min, output_max)` to `(input_max, output_min)`. + keypoints: + - if not provided (None or []), all pieces of returned function + will have equal heights (i.e. `y[i+1] - y[i]` is constant). + - if provided, all pieces of returned function will have equal slopes + (i.e. `(y[i+1] - y[i]) / (x[i+1] - x[i])` is constant). 
+ """ + # pyformat: enable + pwl_calibration_lib.verify_hyperparameters( + input_keypoints=keypoints, + output_min=output_min, + output_max=output_max, + monotonicity=monotonicity) + self.output_min = output_min + self.output_max = output_max + self.monotonicity = monotonicity + self.keypoints = keypoints + + def __call__(self, shape, dtype): + """Returns weights of PWL calibration layer. + + Args: + shape: Must be a collection of the form `(k, units)` where `k >= 2`. + dtype: Standard Sonnet initializer param. + + Returns: + Weights of PWL calibration layer. + + Raises: + ValueError: If requested shape is invalid for PWL calibration layer + weights. + """ + return pwl_calibration_lib.linear_initializer( + shape=shape, + output_min=self.output_min, + output_max=self.output_max, + monotonicity=pwl_calibration_lib.canonicalize_monotonicity( + self.monotonicity), + keypoints=self.keypoints, + dtype=dtype) + + +class _PWLCalibrationConstraints(object): + # pyformat: disable + """Monotonicity and bounds constraints for PWL calibration layer. + + Applies an approximate L2 projection to the weights of a PWLCalibration layer + such that the result satisfies the specified constraints. + + Attributes: + - All `__init__` arguments. + """ + # pyformat: enable + + def __init__( + self, + monotonicity="none", + convexity="none", + lengths=None, + output_min=None, + output_max=None, + output_min_constraints=pwl_calibration_lib.BoundConstraintsType.NONE, + output_max_constraints=pwl_calibration_lib.BoundConstraintsType.NONE, + num_projection_iterations=8): + """Initializes an instance of `PWLCalibration`. + + Args: + monotonicity: Same meaning as corresponding parameter of `PWLCalibration`. + convexity: Same meaning as corresponding parameter of `PWLCalibration`. + lengths: Lengths of pieces of piecewise linear function. Needed only if + convexity is specified. + output_min: Minimum possible output of pwl function. + output_max: Maximum possible output of pwl function. + output_min_constraints: A `tfl.pwl_calibration_lib.BoundConstraintsType` + describing the constraints on the layer's minimum value. + output_max_constraints: A `tfl.pwl_calibration_lib.BoundConstraintsType` + describing the constraints on the layer's maximum value. + num_projection_iterations: Same meaning as corresponding parameter of + `PWLCalibration`. + """ + pwl_calibration_lib.verify_hyperparameters( + output_min=output_min, + output_max=output_max, + monotonicity=monotonicity, + convexity=convexity, + lengths=lengths) + self.monotonicity = monotonicity + self.convexity = convexity + self.lengths = lengths + self.output_min = output_min + self.output_max = output_max + self.output_min_constraints = output_min_constraints + self.output_max_constraints = output_max_constraints + self.num_projection_iterations = num_projection_iterations + + canonical_convexity = pwl_calibration_lib.canonicalize_convexity( + self.convexity) + canonical_monotonicity = pwl_calibration_lib.canonicalize_monotonicity( + self.monotonicity) + if (canonical_convexity != 0 and canonical_monotonicity == 0 and + (output_min_constraints != pwl_calibration_lib.BoundConstraintsType.NONE + or output_max_constraints != + pwl_calibration_lib.BoundConstraintsType.NONE)): + logging.warning("Convexity constraints are specified with bounds " + "constraints, but without monotonicity. Such combination " + "might lead to convexity being slightly violated. 
" + "Consider increasing num_projection_iterations to " + "reduce violation.") + + def __call__(self, w): + """Applies constraints to w.""" + return pwl_calibration_lib.project_all_constraints( + weights=w, + monotonicity=pwl_calibration_lib.canonicalize_monotonicity( + self.monotonicity), + output_min=self.output_min, + output_max=self.output_max, + output_min_constraints=self.output_min_constraints, + output_max_constraints=self.output_max_constraints, + convexity=pwl_calibration_lib.canonicalize_convexity( + self.convexity), + lengths=self.lengths, + num_projection_iterations=self.num_projection_iterations) + + +class _NaiveBoundsConstraints(object): + # pyformat: disable + """Naively clips all elements of tensor to be within bounds. + + This constraint is used only for the weight tensor for missing output value. + + Attributes: + - All `__init__` arguments. + """ + # pyformat: enable + + def __init__(self, lower_bound=None, upper_bound=None): + """Initializes an instance of `_NaiveBoundsConstraints`. + + Args: + lower_bound: Lower bound to clip variable values to. + upper_bound: Upper bound to clip variable values to. + """ + self.lower_bound = lower_bound + self.upper_bound = upper_bound + + def __call__(self, w): + """Applies constraints to w.""" + if self.lower_bound is not None: + w = tf.maximum(w, self.lower_bound) + if self.upper_bound is not None: + w = tf.minimum(w, self.upper_bound) + return w diff --git a/tensorflow_lattice/python/pwl_calibration_test.py b/tensorflow_lattice/python/pwl_calibration_test.py index 3fb7939..8e6d70c 100644 --- a/tensorflow_lattice/python/pwl_calibration_test.py +++ b/tensorflow_lattice/python/pwl_calibration_test.py @@ -14,6 +14,10 @@ """Tests for PWL calibration layer. This test should be run with "-c opt" since otherwise it's slow. +Also, to only run a subset of the tests (useful when developing a new test or +set of tests), change the initialization of the _disable_all boolean to 'True' +in the SetUp method, and comment out the check for this boolean in those tests +that you want to run. """ from __future__ import absolute_import @@ -28,7 +32,8 @@ import tensorflow as tf from tensorflow import keras from tensorflow_lattice.python import parallel_combination_layer as parallel_combination -from tensorflow_lattice.python import pwl_calibration_layer as pwl_calibraion +from tensorflow_lattice.python import pwl_calibration_layer as keras_layer +from tensorflow_lattice.python import pwl_calibration_sonnet_module as sonnet_module from tensorflow_lattice.python import pwl_calibration_lib as pwl_lib from tensorflow_lattice.python import test_utils @@ -145,10 +150,12 @@ def _SetDefaults(self, config): config.setdefault("num_projection_iterations", 8) config.setdefault("constraint_assertion_eps", 1e-6) config.setdefault("model_dir", "/tmp/test_pwl_model_dir/") + config.setdefault("dtype", tf.float32) if "input_keypoints" not in config: # If "input_keypoints" are provided - other params referred by code below - # might be not available. + # might be not available, so we make sure it exists before executing + # this code. 
config.setdefault("input_keypoints", np.linspace(start=config["input_min"], stop=config["input_max"], @@ -203,7 +210,7 @@ def _TrainModel(self, config, plot_path=None): calibration_layers = [] for _ in range(num_calibration_layers): calibration_layers.append( - pwl_calibraion.PWLCalibration( + keras_layer.PWLCalibration( units=pwl_calibration_units, dtype=tf.float32, input_keypoints=config["input_keypoints"], @@ -278,6 +285,225 @@ def _InverseAndTrain(self, config): inversed_loss = self._TrainModel(inversed_config) return inversed_loss + def _CreateTrainingData(self, config): + training_inputs = config["x_generator"]( + units=config["units"], + num_points=config["num_training_records"], + input_min=config["input_keypoints"][0], + input_max=config["input_keypoints"][-1], + missing_probability=config["missing_probability"], + missing_input_value=config["missing_input_value"]) + training_labels = [config["y_function"](x) for x in training_inputs] + training_inputs = tf.convert_to_tensor(training_inputs, dtype=tf.float32) + training_labels = tf.convert_to_tensor(training_labels, dtype=tf.float32) + return (training_inputs, training_labels) + + def _CreateKerasLayer(self, config): + missing_input_value = config["missing_input_value"] + if config["use_separate_missing"]: + # We use 'config["missing_input_value"]' to create the is_missing tensor, + # and we want the model to use the is_missing tensor so we don't pass + # a missing_input_value to the model. + missing_input_value=None + return keras_layer.PWLCalibration( + input_keypoints=config["input_keypoints"], + units=config["units"], + output_min=config["output_min"], + output_max=config["output_max"], + clamp_min=config["clamp_min"], + clamp_max=config["clamp_max"], + monotonicity=config["monotonicity"], + convexity=config["convexity"], + is_cyclic=config["is_cyclic"], + kernel_initializer=config["initializer"], + kernel_regularizer=config["kernel_regularizer"], + impute_missing=config["impute_missing"], + missing_output_value=config["missing_output_value"], + missing_input_value=missing_input_value, + num_projection_iterations=config["num_projection_iterations"], + dtype=config["dtype"]) + + def _CreateSonnetModule(self, config): + missing_input_value = config["missing_input_value"] + if config["use_separate_missing"]: + # We use 'config["missing_input_value"]' to create the is_missing tensor, + # and we want the model to use the is_missing tensor so we don't pass + # a missing_input_value to the model. 
+ missing_input_value=None + return sonnet_module.PWLCalibration( + input_keypoints=config["input_keypoints"], + units=config["units"], + output_min=config["output_min"], + output_max=config["output_max"], + clamp_min=config["clamp_min"], + clamp_max=config["clamp_max"], + monotonicity=config["monotonicity"], + convexity=config["convexity"], + is_cyclic=config["is_cyclic"], + kernel_init=config["initializer"], + impute_missing=config["impute_missing"], + missing_input_value=missing_input_value, + missing_output_value=config["missing_output_value"], + num_projection_iterations=config["num_projection_iterations"]) + + def _AssertSonnetEquivalentToKeras(self, config): + training_inputs, training_labels = self._CreateTrainingData(config) + keras_layer_ctor = lambda: self._CreateKerasLayer(config) + sonnet_module_ctor = lambda: self._CreateSonnetModule(config) + test_utils.assert_sonnet_equivalent_to_keras( + test=self, + sonnet_module_ctor=sonnet_module_ctor, + keras_layer_ctor=keras_layer_ctor, + training_inputs=training_inputs, + training_labels=training_labels, + ) + + def _createConfig(self, **kwargs): + config = dict(kwargs) + return self._SetDefaults(config) + + def testSonnetDefaultValues(self): + """Compares the sonnet module with default values to the keras layer.""" + if self._disable_all: + return + config = self._createConfig( + input_keypoints=[0, 0.25, 0.5, 1.0], + units=10, + x_generator=self._ScatterXUniformly, + y_function=self._SmallWaves, + num_training_records=100, + ) + self._AssertSonnetEquivalentToKeras(config) + + def testSonnetOutputMinOutputMax(self): + if self._disable_all: + return + config = self._createConfig( + input_keypoints=[0, 0.25, 0.5, 1.0], + units=10, + x_generator=self._ScatterXUniformly, + y_function=self._SmallWaves, + num_training_records=100, + output_min=1.0, + output_max=10.0, + ) + self._AssertSonnetEquivalentToKeras(config) + + def testSonnetClampMinClampMax(self): + if self._disable_all: + return + config = self._createConfig( + input_keypoints=[0, 0.25, 0.5, 1.0], + units=10, + x_generator=self._ScatterXUniformly, + y_function=self._SmallWaves, + num_training_records=100, + clamp_min=1.0, + output_max=10.0, + ) + self._AssertSonnetEquivalentToKeras(config) + + def testSonnetMonotonicity(self): + if self._disable_all: + return + for monotonicity in ["increasing", 1, "decreasing", -1]: + config = self._createConfig( + input_keypoints=[0, 0.25, 0.5, 1.0], + units=10, + x_generator=self._ScatterXUniformly, + y_function=self._SmallWaves, + num_training_records=100, + monotonicity=monotonicity, + ) + self._AssertSonnetEquivalentToKeras(config) + + def testSonnetConvexity(self): + if self._disable_all: + return + for convexity in ["convex", 1, "concave", -1]: + config = self._createConfig( + input_keypoints=[0, 0.25, 0.5, 1.0], + units=10, + x_generator=self._ScatterXUniformly, + y_function=self._SmallWaves, + num_training_records=100, + convexity=convexity, + ) + self._AssertSonnetEquivalentToKeras(config) + + def testSonnetIsCyclic(self): + if self._disable_all: + return + config = self._createConfig( + input_keypoints=[0, 0.25, 0.5, 1.0], + units=10, + x_generator=self._ScatterXUniformly, + y_function=self._SmallWaves, + num_training_records=100, + is_cyclic=True, + ) + self._AssertSonnetEquivalentToKeras(config) + + def testSonnetKernelInit(self): + if self._disable_all: + return + # kernel_init="equal_heights" is the default and is tested in + # testSonnetDefaultValues, so we don't test it here. 
+ for kernel_init in [None, "equal_slopes"]: + config = self._createConfig( + input_keypoints=[0, 0.25, 0.5, 1.0], + units=10, + x_generator=self._ScatterXUniformly, + y_function=self._SmallWaves, + num_training_records=100, + kernel_init=kernel_init, + ) + self._AssertSonnetEquivalentToKeras(config) + + def testSonnetMissingInputValue(self): + if self._disable_all: + return + config = self._createConfig( + input_keypoints=[0, 0.25, 0.5, 1.0], + units=10, + x_generator=self._ScatterXUniformly, + y_function=self._SmallWaves, + num_training_records=100, + impute_missing=True, + missing_input_value=3, + missing_probability=0.5, + ) + self._AssertSonnetEquivalentToKeras(config) + + def testSonnetMissingOutputValue(self): + if self._disable_all: + return + config = self._createConfig( + input_keypoints=[0, 0.25, 0.5, 1.0], + units=10, + x_generator=self._ScatterXUniformly, + y_function=self._SmallWaves, + num_training_records=100, + impute_missing=True, + missing_input_value=3, + missing_probability=0.5, + missing_output_value=10, + ) + self._AssertSonnetEquivalentToKeras(config) + + def testSonnetNumProjectionIterations(self): + if self._disable_all: + return + config = self._createConfig( + input_keypoints=[0, 0.25, 0.5, 1.0], + units=10, + x_generator=self._ScatterXUniformly, + y_function=self._SmallWaves, + num_training_records=100, + num_projection_iterations=2, + ) + self._AssertSonnetEquivalentToKeras(config) + @parameterized.parameters( (1, False, 0.001022), (3, False, 0.000543), @@ -1075,7 +1301,7 @@ def testRegularizers(self, units, regularizer, pure_reg_loss, training_loss): def testAssertMonotonicity(self): if self._disable_all: return - decreasing_initializer = pwl_calibraion.UniformOutputInitializer( + decreasing_initializer = keras_layer.UniformOutputInitializer( output_min=0.0, output_max=1.0, monotonicity=-1) # Specify decreasing initializer and do 0 training iterations so no # projections are being executed. diff --git a/tensorflow_lattice/python/rtl_layer.py b/tensorflow_lattice/python/rtl_layer.py index 97e9cf7..1ffc5fe 100644 --- a/tensorflow_lattice/python/rtl_layer.py +++ b/tensorflow_lattice/python/rtl_layer.py @@ -48,17 +48,22 @@ class RTL(keras.layers.Layer): takes in a collection of monotonic and unconstrained features and randomly arranges them into lattices of a given rank. The input is taken as "groups", and inputs from the same group will not be used in the same lattice. E.g. the - input can be the ouput of a calibration layer with multiple units applied to + input can be the output of a calibration layer with multiple units applied to the same input feature. If there are more slots in the RTL than the number of inputs, inputs will be repeatedly used. Repeats will be approximately uniform - accross all inputs. + across all inputs. Input shape: - A dict with keys in `['unconstrained', 'increasing']`, and the values either - a list of tensors of shape (batch_size, D_i), or a single tensor of shape - (batch_size, D) that will be split into a list of D tensors of size - (batch_size, 1). Each tensor in the list is considered a "group" of features - that the RTL layer should try not to use in the same lattice. + One of: + - A dict with keys in `['unconstrained', 'increasing']`, and the values + either a list of tensors of shape (batch_size, D_i), or a single tensor + of shape (batch_size, D) that will be conceptually split into a list of D + tensors of size (batch_size, 1). 
Each tensor in the list is considered a + "group" of features that the RTL layer should try not to use in the same + lattice. + - A single tensor of shape (batch_size, D), which is considered to be + unconstrained and will be conceptually split into a list of D tensors of + size (batch_size, 1). Output shape: If `separate_outputs == True`, the output will be in the same format as the @@ -119,6 +124,8 @@ def __init__(self, num_projection_iterations=10, monotonic_at_every_step=True, clip_inputs=True, + interpolation='hypercube', + avoid_intragroup_interaction=True, kernel_initializer='random_monotonic_initializer', kernel_regularizer=None, **kwargs): @@ -151,6 +158,13 @@ def __init__(self, num_projection_iterations parameter is likely to hurt convergence. clip_inputs: If inputs should be clipped to the input range of the lattice. + interpolation: One of 'hypercube' or 'simplex' interpolation. For a + d-dimensional lattice, 'hypercube' interpolates 2^d parameters, whereas + 'simplex' uses d+1 parameters and thus scales better. For details see + `tfl.lattice_lib.evaluate_with_simplex_interpolation` and + `tfl.lattice_lib.evaluate_with_hypercube_interpolation`. + avoid_intragroup_interaction: If set to true, the RTL algorithm will try + to avoid having inputs from the same group in the same lattice. kernel_initializer: One of: - `'linear_initializer'`: initialize parameters to form a linear function with positive and equal coefficients for monotonic dimensions @@ -190,6 +204,8 @@ def __init__(self, self.num_projection_iterations = num_projection_iterations self.monotonic_at_every_step = monotonic_at_every_step self.clip_inputs = clip_inputs + self.interpolation = interpolation + self.avoid_intragroup_interaction = avoid_intragroup_interaction self.kernel_initializer = kernel_initializer self.kernel_regularizer = kernel_regularizer @@ -200,7 +216,7 @@ def build(self, input_shape): self._lattice_layers = {} for monotonicities, inputs_for_units in self._rtl_structure: units = len(inputs_for_units) - self._lattice_layers[monotonicities] = lattice_layer.Lattice( + self._lattice_layers[str(monotonicities)] = lattice_layer.Lattice( lattice_sizes=[self.lattice_size] * self.lattice_rank, units=units, monotonicities=monotonicities, @@ -209,6 +225,7 @@ def build(self, input_shape): num_projection_iterations=self.num_projection_iterations, monotonic_at_every_step=self.monotonic_at_every_step, clip_inputs=self.clip_inputs, + interpolation=self.interpolation, kernel_initializer=self.kernel_initializer, kernel_regularizer=self.kernel_regularizer, ) @@ -217,40 +234,33 @@ def build(self, input_shape): def call(self, x, **kwargs): """Standard Keras call() method.""" if not isinstance(x, dict): - raise ValueError('Input to the RTL layer must be dict') + x = {'unconstrained': x} + # Flatten the input. # The order for flattening should match the order in _get_rtl_structure. 
input_tensors = [] for input_key in sorted(x.keys()): items = x[input_key] - if not isinstance(items, list): - items = [items] - for tensor in items: - dim = tensor.shape.as_list()[1] - if dim == 1: - input_tensors.append(tensor) - else: - input_tensors.extend(tf.split(tensor, dim, axis=1)) + if isinstance(items, list): + input_tensors.extend(items) + else: + input_tensors.append(items) + if len(input_tensors) == 1: + flattened_input = input_tensors[0] + else: + flattened_input = tf.concat(input_tensors, axis=1) # outputs_for_monotonicity[0] are non-monotonic outputs # outputs_for_monotonicity[1] are monotonic outputs outputs_for_monotonicity = [[], []] for monotonicities, inputs_for_units in self._rtl_structure: - # Create inputs to lattice layer by concatenating all the inputs. - lattice_inputs = [] - for inputs_for_unit in inputs_for_units: - # Concat into (-1, lattice_rank) for a single lattice - lattice_inputs.append( - tf.concat([input_tensors[i] for i in inputs_for_unit], axis=1)) - if len(lattice_inputs) > 1: - # Stack into (-1, units, lattice_rank) for multi-unit lattice layer - lattice_inputs = tf.stack(lattice_inputs, axis=1) - else: - lattice_inputs = lattice_inputs[0] + if len(inputs_for_units) == 1: + inputs_for_units = inputs_for_units[0] + lattice_inputs = tf.gather(flattened_input, inputs_for_units, axis=1) output_monotonicity = max(monotonicities) # Call each lattice layer and store based on output monotonicy. outputs_for_monotonicity[output_monotonicity].append( - self._lattice_layers[monotonicities](lattice_inputs)) + self._lattice_layers[str(monotonicities)](lattice_inputs)) if self.separate_outputs: separate_outputs = {} @@ -302,6 +312,8 @@ def get_config(self): 'num_projection_iterations': self.num_projection_iterations, 'monotonic_at_every_step': self.monotonic_at_every_step, 'clip_inputs': self.clip_inputs, + 'interpolation': self.interpolation, + 'avoid_intragroup_interaction': self.avoid_intragroup_interaction, 'kernel_initializer': self.kernel_initializer, 'kernel_regularizer': self.kernel_regularizer, }) @@ -325,13 +337,13 @@ def assert_constraints(self, eps=1e-6): """Asserts that weights satisfy all constraints. In graph mode builds and returns a list of assertion ops. - In eager mode directly executes assetions. + In eager mode directly executes assertions. Args: eps: allowed constraints violation. Returns: - List of assertion ops in graph mode or immideately asserts in eager mode. + List of assertion ops in graph mode or immediately asserts in eager mode. """ assertions = [] for layer in self._lattice_layers.values(): @@ -354,7 +366,7 @@ def _get_rtl_structure(self, input_shape): indices into the flattened input to the layer. """ if not isinstance(input_shape, dict): - raise ValueError('Input to the RTL layer must be dict') + input_shape = {'unconstrained': input_shape} # Calculate the flattened input to the RTL layer. rtl_inputs will be a list # of _RTLInput items, each including information about the monotonicity, @@ -411,7 +423,7 @@ def _get_rtl_structure(self, input_shape): # group is used in each lattice. 
     changed = True
     iteration = 0
-    while changed:
+    while changed and self.avoid_intragroup_interaction:
       if iteration > _MAX_RTL_SWAPS:
         logging.info('Some lattices in the RTL layer might use features from '
                      'the same input group')
diff --git a/tensorflow_lattice/python/test_utils.py b/tensorflow_lattice/python/test_utils.py
index c18a566..229e534 100644
--- a/tensorflow_lattice/python/test_utils.py
+++ b/tensorflow_lattice/python/test_utils.py
@@ -22,6 +22,8 @@
 from . import visualization
 from absl import logging
 import numpy as np
+import sonnet as snt
+import tensorflow as tf


 class TimeTracker(object):
@@ -125,6 +127,98 @@ def run_training_loop(config,
   return loss


+def assert_sonnet_equivalent_to_keras(
+    test, sonnet_module_ctor, keras_layer_ctor,
+    training_inputs, training_labels,
+    epsilon=1e-4):
+  """Asserts that a Sonnet module is "equivalent" to a Keras layer.
+
+  Creates a Sonnet module and a Keras layer using the given constructors. It
+  then uses both models to evaluate the given 'training_inputs' tensor and
+  asserts that the results are equal. It then trains both models and asserts
+  that the final loss (w.r.t the given 'training_labels') and the post-training
+  predictions of both models on 'training_inputs' are also equal.
+
+  Args:
+    test: a tf.test.TestCase whose 'assert...' methods are used for assertion.
+    sonnet_module_ctor: A callable that takes no arguments and returns the
+      Sonnet module to use.
+    keras_layer_ctor: A callable that takes no arguments and returns the
+      Keras layer to use.
+    training_inputs: Tensor of shape (batch_size, ...) containing the
+      training inputs.
+    training_labels: Tensor of shape (batch_size, ...) containing the
+      training labels.
+    epsilon: float. Sensitivity of comparison. Comparison of model predictions
+      and losses is done using test.assertNear and test.assertNDArrayNear.
+      This is the value to pass as the 'err' parameter to these assertion
+      methods.
+  """
+  # This function assumes we're executing eagerly.
+  test.assertTrue(tf.executing_eagerly())
+
+  num_training_epochs = 10
+  num_training_inputs = training_inputs.shape[0]
+
+  # Create Keras model.
+  keras_model = tf.keras.models.Sequential(layers=[keras_layer_ctor()])
+  keras_model.compile(
+      loss=tf.keras.losses.mean_squared_error,
+      optimizer=tf.keras.optimizers.SGD(learning_rate=0.1))
+
+  keras_preds_pre_training = keras_model.predict(
+      x=training_inputs, batch_size=num_training_inputs)
+
+  # Train the Keras model.
+  keras_model.fit(x=training_inputs, y=training_labels,
+                  batch_size=num_training_inputs,
+                  epochs=num_training_epochs,
+                  verbose=0)
+  keras_loss = keras_model.evaluate(x=training_inputs, y=training_labels,
+                                    batch_size=num_training_inputs,
+                                    verbose=0)
+  keras_preds_post_training = keras_model.predict(
+      x=training_inputs, batch_size=num_training_inputs)
+
+  # Create the Sonnet model.
+  sonnet_module = sonnet_module_ctor()
+
+  sonnet_preds_pre_training = sonnet_module(training_inputs).numpy()
+
+  # Train the Sonnet model.
+  sonnet_optimizer = snt.optimizers.SGD(learning_rate=0.1)
+  mse_loss = tf.keras.losses.MeanSquaredError()
+  for _ in range(num_training_epochs):
+    with tf.GradientTape() as tape:
+      preds = sonnet_module(training_inputs)
+      sonnet_loss = mse_loss(training_labels, preds)
+
+    params = sonnet_module.trainable_variables
+    grads = tape.gradient(sonnet_loss, params)
+    sonnet_optimizer.apply(grads, params)
+    # We need to apply constraints explicitly in Sonnet.
+    for var in params:
+      if var.constraint:
+        var.assign(var.constraint(var))
+
+  sonnet_preds_post_training = sonnet_module(training_inputs).numpy()
+  sonnet_loss = mse_loss(training_labels, sonnet_preds_post_training).numpy()
+  # Note that corresponding initializers between Sonnet and Keras may have
+  # different default values for their construction arguments. E.g.
+  # https://sonnet.readthedocs.io/en/latest/api.html#randomuniform
+  # https://www.tensorflow.org/api_docs/python/tf/keras/initializers/RandomUniform
+  # which may cause these assertions to fail. When comparing models make
+  # sure the initializers behave the same.
+  test.assertNDArrayNear(sonnet_preds_pre_training,
+                         keras_preds_pre_training,
+                         err=epsilon)
+  test.assertNear(sonnet_loss, keras_loss, err=epsilon)
+  test.assertNDArrayNear(
+      sonnet_preds_post_training,
+      keras_preds_post_training,
+      err=epsilon)
+
+
 def two_dim_mesh_grid(num_points, x_min, y_min, x_max, y_max):
   """Generates uniform 2-d mesh grid for 3-d surfaces visualisation via pyplot.
diff --git a/tensorflow_lattice/sonnet_modules/__init__.py b/tensorflow_lattice/sonnet_modules/__init__.py
new file mode 100644
index 0000000..0db2b21
--- /dev/null
+++ b/tensorflow_lattice/sonnet_modules/__init__.py
@@ -0,0 +1,17 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""'sonnet_modules' namespace for TFL layers."""
+
+from tensorflow_lattice.python.pwl_calibration_sonnet_module import PWLCalibration
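
Reviewer note: a minimal sketch of the two RTL behaviors changed above, namely accepting a plain tensor input (now treated as a single 'unconstrained' group instead of raising) and the new `interpolation` argument. This is illustrative only and not part of the patch; the layer sizes and feature tensors are made up, and it assumes the layer is exposed as `tfl.layers.RTL` as in the existing API.

import tensorflow as tf
import tensorflow_lattice as tfl

# A plain (batch_size, D) tensor is now accepted and treated as one
# 'unconstrained' group; previously the layer required a dict input.
features = tf.random.uniform((32, 4))
rtl = tfl.layers.RTL(
    num_lattices=6,
    lattice_rank=2,
    interpolation='simplex')  # new argument; 'hypercube' remains the default
outputs = rtl(features)  # (32, 6) since separate_outputs defaults to False

# The dict form still works, e.g. to mark some inputs as monotonic.
rtl_monotonic = tfl.layers.RTL(num_lattices=6, lattice_rank=2)
monotonic_outputs = rtl_monotonic(
    {'unconstrained': features[:, :2], 'increasing': features[:, 2:]})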
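A similarly hypothetical sketch of the new Sonnet calibration module exported by `sonnet_modules/__init__.py` above. It assumes the top-level `tensorflow_lattice` package re-exports the `sonnet_modules` namespace (as the `__init__.py` change in this patch suggests); the keypoints, data, and optimizer settings are illustrative. The training loop mirrors `assert_sonnet_equivalent_to_keras`, including the explicit re-application of variable constraints that Sonnet does not perform automatically.

import numpy as np
import sonnet as snt
import tensorflow as tf
import tensorflow_lattice as tfl

# Monotonic piecewise-linear calibrator as a Sonnet module.
calibrator = tfl.sonnet_modules.PWLCalibration(
    input_keypoints=np.linspace(0.0, 1.0, num=5),
    output_min=0.0,
    output_max=1.0,
    monotonicity='increasing')

x = tf.random.uniform((128, 1))  # illustrative inputs
y = tf.sin(4.0 * x)              # illustrative labels

optimizer = snt.optimizers.SGD(learning_rate=0.1)
mse = tf.keras.losses.MeanSquaredError()
for _ in range(10):
  with tf.GradientTape() as tape:
    loss = mse(y, calibrator(x))
  params = calibrator.trainable_variables
  grads = tape.gradient(loss, params)
  optimizer.apply(grads, params)
  # Unlike Keras, Sonnet does not apply variable constraints after each step,
  # so the monotonicity projection is re-applied by hand.
  for var in params:
    if var.constraint:
      var.assign(var.constraint(var))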