diff --git a/examples/BUILD b/examples/BUILD index 29a9084..6d6ca86 100644 --- a/examples/BUILD +++ b/examples/BUILD @@ -13,6 +13,8 @@ # limitations under the License. # ============================================================================== +load("//third_party/bazel_rules/rules_python/python:py_binary.bzl", "py_binary") + licenses(["notice"]) package( diff --git a/setup.py b/setup.py index 60788e0..983a01e 100644 --- a/setup.py +++ b/setup.py @@ -27,7 +27,7 @@ # This version number should always be that of the *next* (unreleased) version. # Immediately after uploading a package to PyPI, you should increment the # version number and push to gitHub. -__version__ = "2.0.11" +__version__ = "2.0.12" if "--release" in sys.argv: sys.argv.remove("--release") @@ -45,7 +45,6 @@ "scikit-learn", "matplotlib", "graphviz", - "dm-sonnet", ] # Part of the visualization code uses colabtools and IPython libraries. These diff --git a/tensorflow_lattice/BUILD b/tensorflow_lattice/BUILD index 303d6e2..8e13739 100644 --- a/tensorflow_lattice/BUILD +++ b/tensorflow_lattice/BUILD @@ -28,7 +28,6 @@ py_library( srcs = [ "__init__.py", "layers/__init__.py", - "sonnet_modules/__init__.py", ], srcs_version = "PY2AND3", deps = [ @@ -36,6 +35,8 @@ py_library( "//tensorflow_lattice/python:categorical_calibration_layer", "//tensorflow_lattice/python:categorical_calibration_lib", "//tensorflow_lattice/python:cdf_layer", + "//tensorflow_lattice/python:conditional_cdf", + "//tensorflow_lattice/python:conditional_pwl_calibration", "//tensorflow_lattice/python:configs", "//tensorflow_lattice/python:estimators", "//tensorflow_lattice/python:kronecker_factored_lattice_layer", @@ -50,7 +51,6 @@ py_library( "//tensorflow_lattice/python:premade_lib", "//tensorflow_lattice/python:pwl_calibration_layer", "//tensorflow_lattice/python:pwl_calibration_lib", - "//tensorflow_lattice/python:pwl_calibration_sonnet_module", "//tensorflow_lattice/python:rtl_layer", "//tensorflow_lattice/python:test_utils", "//tensorflow_lattice/python:utils", diff --git a/tensorflow_lattice/__init__.py b/tensorflow_lattice/__init__.py index b87e5c3..f1259aa 100644 --- a/tensorflow_lattice/__init__.py +++ b/tensorflow_lattice/__init__.py @@ -20,11 +20,12 @@ from __future__ import absolute_import import tensorflow_lattice.layers - from tensorflow_lattice.python import aggregation_layer from tensorflow_lattice.python import categorical_calibration_layer from tensorflow_lattice.python import categorical_calibration_lib from tensorflow_lattice.python import cdf_layer +from tensorflow_lattice.python import conditional_cdf +from tensorflow_lattice.python import conditional_pwl_calibration from tensorflow_lattice.python import configs from tensorflow_lattice.python import estimators from tensorflow_lattice.python import kronecker_factored_lattice_layer @@ -39,9 +40,6 @@ from tensorflow_lattice.python import premade_lib from tensorflow_lattice.python import pwl_calibration_layer from tensorflow_lattice.python import pwl_calibration_lib -from tensorflow_lattice.python import pwl_calibration_sonnet_module from tensorflow_lattice.python import test_utils from tensorflow_lattice.python import utils from tensorflow_lattice.python import visualization - -import tensorflow_lattice.sonnet_modules diff --git a/tensorflow_lattice/python/BUILD b/tensorflow_lattice/python/BUILD index 68aefdc..c0f95fe 100644 --- a/tensorflow_lattice/python/BUILD +++ b/tensorflow_lattice/python/BUILD @@ -13,6 +13,9 @@ # limitations under the License. 
# ============================================================================== +load("//third_party/bazel_rules/rules_python/python:py_library.bzl", "py_library") +load("//third_party/bazel_rules/rules_python/python:py_test.bzl", "py_test") + package( default_visibility = [ "//tensorflow_lattice:__subpackages__", @@ -434,19 +437,6 @@ py_library( ], ) -py_library( - name = "pwl_calibration_sonnet_module", - srcs = ["pwl_calibration_sonnet_module.py"], - srcs_version = "PY2AND3", - deps = [ - ":pwl_calibration_lib", - ":utils", - # absl/logging dep, - # sonnet dep, - # tensorflow:tensorflow_no_contrib dep, - ], -) - py_test( name = "pwl_calibration_test", size = "large", @@ -457,14 +447,13 @@ py_test( deps = [ ":parallel_combination_layer", ":pwl_calibration_layer", - ":pwl_calibration_sonnet_module", ":test_utils", ":utils", # absl/logging dep, # absl/testing:parameterized dep, # numpy dep, - # sonnet dep, # tensorflow dep, + # tensorflow:tensorflow_no_contrib dep, ], ) @@ -514,8 +503,6 @@ py_library( ":visualization", # absl/logging dep, # numpy dep, - # sonnet dep, - # tensorflow dep, ], ) @@ -528,6 +515,42 @@ py_library( ], ) +py_library( + name = "conditional_pwl_calibration", + srcs = ["conditional_pwl_calibration.py"], + deps = [ + # numpy dep, + # tensorflow:tensorflow_no_contrib dep, + ], +) + +py_library( + name = "conditional_cdf", + srcs = ["conditional_cdf.py"], + deps = [ + # tensorflow:tensorflow_no_contrib dep, + ], +) + +py_test( + name = "conditional_cdf_test", + srcs = ["conditional_cdf_test.py"], + deps = [ + ":conditional_cdf", + # absl/testing:parameterized dep, + # tensorflow:tensorflow_no_contrib dep, + ], +) + +py_test( + name = "conditional_pwl_calibration_test", + srcs = ["conditional_pwl_calibration_test.py"], + deps = [ + ":conditional_pwl_calibration", + # tensorflow:tensorflow_no_contrib dep, + ], +) + py_test( name = "utils_test", srcs = ["utils_test.py"], diff --git a/tensorflow_lattice/python/aggregation_layer.py b/tensorflow_lattice/python/aggregation_layer.py index 9aadadf..d59f17e 100644 --- a/tensorflow_lattice/python/aggregation_layer.py +++ b/tensorflow_lattice/python/aggregation_layer.py @@ -76,11 +76,14 @@ def call(self, x): def get_config(self): """Standard Keras get_config() method.""" config = super(Aggregation, self).get_config().copy() - config.update({'model': tf.keras.utils.serialize_keras_object(self.model)}) + config.update( + {'model': tf.keras.utils.legacy.serialize_keras_object(self.model)} + ) return config @classmethod def from_config(cls, config, custom_objects=None): - model = tf.keras.utils.deserialize_keras_object( - config.pop('model'), custom_objects=custom_objects) + model = tf.keras.utils.legacy.deserialize_keras_object( + config.pop('model'), custom_objects=custom_objects + ) return cls(model, **config) diff --git a/tensorflow_lattice/python/categorical_calibration_layer.py b/tensorflow_lattice/python/categorical_calibration_layer.py index 2384f1b..996695e 100644 --- a/tensorflow_lattice/python/categorical_calibration_layer.py +++ b/tensorflow_lattice/python/categorical_calibration_layer.py @@ -249,9 +249,11 @@ def get_config(self): "output_max": self.output_max, "monotonicities": self.monotonicities, "kernel_initializer": - keras.initializers.serialize(self.kernel_initializer), + keras.initializers.serialize( + self.kernel_initializer, use_legacy_format=True), "kernel_regularizer": - [keras.regularizers.serialize(r) for r in self.kernel_regularizer], + [keras.regularizers.serialize(r, use_legacy_format=True) + for r in 
self.kernel_regularizer], "default_input_value": self.default_input_value, "split_outputs": self.split_outputs, } # pyformat: disable diff --git a/tensorflow_lattice/python/cdf_layer.py b/tensorflow_lattice/python/cdf_layer.py index 2f192e8..a82dc15 100644 --- a/tensorflow_lattice/python/cdf_layer.py +++ b/tensorflow_lattice/python/cdf_layer.py @@ -255,7 +255,8 @@ def get_config(self): "sparsity_factor": self.sparsity_factor, "kernel_initializer": - tf.keras.initializers.serialize(self.kernel_initializer), + tf.keras.initializers.serialize( + self.kernel_initializer, use_legacy_format=True), } config.update(super(CDF, self).get_config()) return config diff --git a/tensorflow_lattice/python/conditional_cdf.py b/tensorflow_lattice/python/conditional_cdf.py new file mode 100644 index 0000000..2b36547 --- /dev/null +++ b/tensorflow_lattice/python/conditional_cdf.py @@ -0,0 +1,275 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Implements CDF transformation with derived parameters (kernels). + +`cdf_fn` is similar to `tfl.layers.CDF`, which is an additive / multiplicative +average of a few shifted and scaled `sigmoid` or `relu6` basis functions, +with the difference that the functions are parametrized by the provided +parameters instead of learnable weights belonging to a `tfl.layers.CDF` layer. + +These parameters can be one of: + + - constants, + - trainable variables, + - outputs from other TF modules. + +For inputs of shape `(batch_size, input_dim)`, two sets of free-form +parameters are used to configure the CDF function: + +- `location_parameters` for where to place the sigmoid / relu6 transformation +basis, +- `scaling_parameters` (optional) for the horizontal scaling before applying +the transformation basis. +""" + +from typing import Optional, Union, Tuple +import tensorflow as tf + + +def _verify_cdf_params( + inputs: tf.Tensor, + location_parameters: tf.Tensor, + scaling_parameters: Optional[tf.Tensor], + units: int, + activation: str, + reduction: str, + sparsity_factor: int, +) -> None: + """Verifies the arguments of cdf_fn call. + + Args: + inputs: inputs to the CDF function. + location_parameters: parameters for deciding the locations of the + transformations. + scaling_parameters: parameters for deciding the horizontal scaling of the + transformations. + units: output dimension. + activation: either `sigmoid` or `relu6` for selecting the transformation. + reduction: either `mean`, `geometric_mean`, or `none` to specify whether to + perform averaging and which average to perform. + sparsity_factor: deciding the level of sparsity during reduction. + `input_dim` and `units` should both be divisible by `sparsity_factor`. + """ + if activation not in ("sigmoid", "relu6"): + raise ValueError( + f"activation = {activation} is not supported. Use 'sigmoid' or 'relu6'." + ) + if reduction not in ("mean", "geometric_mean", "none"): + raise ValueError( + f"reduction = {reduction} is not supported. Use 'mean'," + " 'geometric_mean' or 'none'." 
+    )
+
+  if len(inputs.shape) != 2:
+    raise ValueError(
+        f"inputs shape {inputs.shape} is not (batch_size, input_dim)."
+    )
+
+  input_dim = inputs.shape[1]
+  if units % sparsity_factor != 0:
+    raise ValueError(
+        f"units = {units} is not divisible by sparsity_factor ="
+        f" {sparsity_factor}."
+    )
+  if input_dim % sparsity_factor != 0:
+    raise ValueError(
+        f"input_dim = {input_dim} is not divisible by sparsity_factor ="
+        f" {sparsity_factor}."
+    )
+
+  if (
+      len(location_parameters.shape) != 4
+      or location_parameters.shape[1] != input_dim
+      or location_parameters.shape[3] != units // sparsity_factor
+  ):
+    raise ValueError(
+        "location_parameters shape"
+        f" {location_parameters.shape} is not (batch, input_dim, "
+        f"num_functions, units / sparsity_factor = {units // sparsity_factor})."
+    )
+
+  if scaling_parameters is not None:
+    try:
+      _ = tf.broadcast_to(
+          scaling_parameters,
+          location_parameters.shape,
+          name="cdf_fn_try_broadcasting",
+      )
+    except Exception as err:
+      raise ValueError(
+          "scaling_parameters and location_parameters likely"
+          " are not broadcastable. Shapes of scaling_parameters:"
+          f" {scaling_parameters.shape}, location_parameters:"
+          f" {location_parameters.shape}."
+      ) from err
+
+
+@tf.function
+def cdf_fn(
+    inputs: tf.Tensor,
+    location_parameters: tf.Tensor,
+    scaling_parameters: Optional[tf.Tensor] = None,
+    units: int = 1,
+    activation: str = "relu6",
+    reduction: str = "mean",
+    sparsity_factor: int = 1,
+    scaling_exp_transform_multiplier: Optional[float] = None,
+    return_derived_parameters: bool = False,
+) -> Union[tf.Tensor, Tuple[tf.Tensor, tf.Tensor, tf.Tensor]]:
+  r"""Maps `inputs` through a CDF function specified by keypoint parameters.
+
+  `cdf_fn` is similar to `tfl.layers.CDF`, which is an additive / multiplicative
+  average of a few shifted and scaled `sigmoid` or `relu6` basis functions,
+  with the difference that the functions are parametrized by the provided
+  parameters instead of learnable weights belonging to a `tfl.layers.CDF` layer.
+
+  These parameters can be one of:
+
+    - constants,
+    - trainable variables,
+    - outputs from other TF modules.
+
+  For inputs of shape `(batch_size, input_dim)`, two sets of free-form
+  parameters are used to configure the CDF function:
+
+  - `location_parameters` for where to place the sigmoid / relu6 transformation
+  basis,
+  - `scaling_parameters` (optional) for the horizontal scaling before applying
+  the transformation basis.
+
+  The transformation per dimension is `x -> activation(scale * (x - location))`,
+  where:
+
+  - `scale` (specified via `scaling_parameter`) is the input scaling for each
+  dimension and needs to be strictly positive for the CDF function to become
+  monotonic. If needed, you can set `scaling_exp_transform_multiplier` to get
+  `scale = exp(scaling_parameter * scaling_exp_transform_multiplier)`, which
+  guarantees strict positivity.
+  - `location` (specified via `location_parameter`) is the input shift. Notice
+  for `relu6` this is where the transformation starts to be nonzero, whereas for
+  `sigmoid` this is where the transformation hits 0.5.
+  - `activation` is either `sigmoid` or `relu6` (for `relu6 / 6`).
+
+  An optional `reduction` operation will compute the additive / multiplicative
+  average for the input dims after their individual CDF transformation. `mean`
+  and `geometric_mean` are supported if specified.
+
+  `sparsity_factor` decides the level of sparsity during reduction. For
+  instance, the default of `sparsity = 1` calculates the average of *all* input
+  dims, whereas `sparsity = 2` calculates the average of *every other* input
+  dim, and so on.
+
+  Input shape:
+    We denote `num_functions` as the number of `sigmoid` or `relu6 / 6` basis
+    functions used for each CDF transformation.
+
+    `inputs` should be:
+
+    - `(batch_size, input_dim)`.
+
+    `location_parameters` should be:
+
+    - `(batch_size, input_dim, num_functions, units // sparsity_factor)`.
+
+    `scaling_parameters` when provided should be broadcast friendly
+    with `location_parameters`, e.g. one of
+
+    - `(batch_size, input_dim, 1, 1)`,
+    - `(batch_size, input_dim, num_functions, 1)`,
+    - `(batch_size, input_dim, 1, units // sparsity_factor)`,
+    - `(batch_size, input_dim, num_functions, units // sparsity_factor)`.
+
+  Args:
+    inputs: inputs to the CDF function.
+    location_parameters: parameters for deciding the locations of the
+      transformations.
+    scaling_parameters: parameters for deciding the horizontal scaling of the
+      transformations.
+    units: output dimension.
+    activation: either `sigmoid` or `relu6` for selecting the transformation.
+    reduction: either `mean`, `geometric_mean`, or `none` to specify whether to
+      perform averaging and which average to perform.
+    sparsity_factor: deciding the level of sparsity during reduction.
+      `input_dim` and `units` should both be divisible by `sparsity_factor`.
+    scaling_exp_transform_multiplier: if provided, will be used inside an
+      exponential transformation for `scaling_parameters`. This can be useful if
+      `scaling_parameters` is free-form.
+    return_derived_parameters: Whether `location_parameters` and
+      `scaling_parameters` should be output along with the model output (e.g.
+      for loss function computation purposes).
+
+  Returns:
+    If `return_derived_parameters = False`:
+
+    - The CDF transformed outputs as a tensor with shape either
+      `(batch_size, units)` if `reduction = 'mean' / 'geometric_mean'`, or
+      `(batch_size, input_dim // sparsity_factor, units)` if
+      `reduction = 'none'`.
+
+    If `return_derived_parameters = True`:
+
+    - A tuple of three elements:
+
+      1. The CDF transformed outputs.
+      2. `location_parameters`.
+      3. `scaling_parameters`, with `exp` transformation applied if specified.
+  """
+
+  _verify_cdf_params(
+      inputs,
+      location_parameters,
+      scaling_parameters,
+      units,
+      activation,
+      reduction,
+      sparsity_factor,
+  )
+  input_dim = inputs.shape[1]
+  x = inputs[..., tf.newaxis, tf.newaxis] - location_parameters
+  if scaling_parameters is not None:
+    if scaling_exp_transform_multiplier is not None:
+      scaling_parameters = tf.math.exp(
+          scaling_parameters * scaling_exp_transform_multiplier
+      )
+    x *= scaling_parameters
+  else:
+    # For use when return_derived_parameters = True.
+    scaling_parameters = tf.ones_like(location_parameters, dtype=tf.float32)
+
+  # Shape: (batch, input_dim, 1, 1)
+  #   --> (batch, input_dim, num_functions, units / factor)
+  #   --> (batch, input_dim, units / factor).
+  if activation == "relu6":
+    result = tf.reduce_mean(tf.nn.relu6(x), axis=2) / 6
+  else:  # activation == "sigmoid":
+    result = tf.reduce_mean(tf.nn.sigmoid(x), axis=2)
+
+  if sparsity_factor != 1:
+    # Shape: (batch, input_dim, units / factor)
+    #   --> (batch, input_dim / factor, units).
+    result = tf.reshape(result, (-1, input_dim // sparsity_factor, units))
+
+  # Shape: (batch, input_dim / factor, units) --> (batch, units).
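+  # A worked instance of the reduction step (illustrative values, not part of
+  # the library): with input_dim = 4, sparsity_factor = 1 and units = 1,
+  # `result` here is (batch, 4, 1); "mean" averages the four per-dim CDF
+  # outputs, while "geometric_mean" multiplies them and takes the fourth root,
+  # computed in log space below with an epsilon for stability near zero.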
+ if reduction == "mean": + result = tf.reduce_mean(result, axis=1) + elif reduction == "geometric_mean": + # We use the log form so that we can add the epsilon term + # tf.pow(tf.reduce_prod(cdfs, axis=1), 1. / num_terms). + result = tf.math.exp(tf.reduce_mean(tf.math.log(result + 1e-8), axis=1)) + # Otherwise reduction == "none". + + if return_derived_parameters: + return (result, location_parameters, scaling_parameters) + else: + return result diff --git a/tensorflow_lattice/python/conditional_cdf_test.py b/tensorflow_lattice/python/conditional_cdf_test.py new file mode 100644 index 0000000..c9d3919 --- /dev/null +++ b/tensorflow_lattice/python/conditional_cdf_test.py @@ -0,0 +1,731 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""TF tests for conditional_cdf.py.""" + +from absl.testing import parameterized +import tensorflow as tf +from tensorflow_lattice.python.conditional_cdf import cdf_fn + +_EPSILON = 1e-4 + + +class CdfFnTest(parameterized.TestCase, tf.test.TestCase): + + def assertAllClose(self, x, y): + super().assertAllClose(x, y, atol=1e-4) + + @parameterized.named_parameters( + dict( + testcase_name="trivial", + inputs=[[-1.0], [0.0], [1.0]], + location_parameters=[ + [[[-1.0], [0.0], [1.0]]], + [[[-1.0], [0.0], [1.0]]], + [[[-1.0], [0.0], [1.0]]], + ], + scaling_parameters=[[[[1.0]]], [[[1.0]]], [[[1.0]]]], + reduction="none", + expected=[[[0.29604811]], [[0.5]], [[0.70395189]]], + ), + dict( + testcase_name="trivial_mean", + inputs=[[-1.0], [0.0], [1.0]], + location_parameters=[ + [[[-1.0], [0.0], [1.0]]], + [[[-1.0], [0.0], [1.0]]], + [[[-1.0], [0.0], [1.0]]], + ], + scaling_parameters=[[[[1.0]]], [[[1.0]]], [[[1.0]]]], + reduction="mean", + expected=[[0.29604811], [0.5], [0.70395189]], + ), + dict( + testcase_name="moderate", + inputs=[[-1.0, 0.0], [0.0, 1.0], [1.0, 2.0]], + location_parameters=[ + [[[-1.0], [0.0], [1.0]], [[-2.0], [0.0], [2.0]]], + [[[-1.0], [0.0], [1.0]], [[-3.0], [0.0], [3.0]]], + [[[-1.0], [0.0], [1.0]], [[-4.0], [0.0], [4.0]]], + ], + scaling_parameters=[ + [[[1.0]], [[1.0]]], + [[[1.0]], [[1.0]]], + [[[1.0]], [[1.0]]], + ], + reduction="none", + expected=[ + [[0.29604811], [0.5]], + [[0.5], [0.61075843]], + [[0.70395189], [0.66584245]], + ], + ), + dict( + testcase_name="moderate_scaling", + inputs=[[-1.0, 0.0], [0.0, 1.0], [1.0, 2.0]], + location_parameters=[ + [[[-1.0], [0.0], [1.0]], [[-2.0], [0.0], [2.0]]], + [[[-1.0], [0.0], [1.0]], [[-3.0], [0.0], [3.0]]], + [[[-1.0], [0.0], [1.0]], [[-4.0], [0.0], [4.0]]], + ], + scaling_parameters=[ + [[[1.0]], [[1.0]]], + [[[2.0]], [[2.0]]], + [[[5.0]], [[7.0]]], + ], + reduction="none", + expected=[ + [[0.29604811], [0.5]], + [[0.5], [0.632815979]], + [[0.8310872504], [0.6666666666]], + ], + ), + dict( + testcase_name="moderate_mean", + inputs=[[-1.0, 0.0], [0.0, 1.0], [1.0, 2.0]], + location_parameters=[ + [[[-1.0], [0.0], [1.0]], [[-2.0], [0.0], [2.0]]], + [[[-1.0], [0.0], [1.0]], [[-3.0], [0.0], [3.0]]], + [[[-1.0], [0.0], [1.0]], [[-4.0], [0.0], [4.0]]], + ], + 
scaling_parameters=None, + reduction="mean", + expected=[[0.398024055], [0.555379215], [0.684897170]], + ), + dict( + testcase_name="moderate_geometric_mean", + inputs=[[-1.0, 0.0], [0.0, 1.0], [1.0, 2.0]], + location_parameters=[ + [[[-1.0], [0.0], [1.0]], [[-2.0], [0.0], [2.0]]], + [[[-1.0], [0.0], [1.0]], [[-3.0], [0.0], [3.0]]], + [[[-1.0], [0.0], [1.0]], [[-4.0], [0.0], [4.0]]], + ], + scaling_parameters=[ + [[[1.0]], [[1.0]]], + [[[1.0]], [[1.0]]], + [[[1.0]], [[1.0]]], + ], + reduction="geometric_mean", + expected=[[0.38473894], [0.55261127], [0.68463206]], + ), + ) + def test_compute_sigmoid( + self, + inputs, + location_parameters, + scaling_parameters, + reduction, + expected, + ): + result = cdf_fn( + inputs=tf.constant(inputs, dtype=tf.float32), + location_parameters=tf.constant(location_parameters, dtype=tf.float32), + scaling_parameters=( + tf.constant(scaling_parameters, dtype=tf.float32) + if scaling_parameters is not None + else None + ), + units=1, + activation="sigmoid", + reduction=reduction, + ) + self.assertAllClose(result, expected) + + @parameterized.named_parameters( + dict( + testcase_name="trivial", + inputs=[[-1.0], [0.0], [1.0]], + location_parameters=[ + [[[-1.0], [0.0], [1.0]]], + [[[-1.0], [0.0], [1.0]]], + [[[-1.0], [0.0], [1.0]]], + ], + scaling_parameters=[[[[1.0]]], [[[1.0]]], [[[1.0]]]], + reduction="none", + expected=[[[0.0]], [[1.0 / 18]], [[3.0 / 18]]], + ), + dict( + testcase_name="trivial_none_scaling", + inputs=[[-1.0], [0.0], [1.0]], + location_parameters=[ + [[[-1.0], [0.0], [1.0]]], + [[[-1.0], [0.0], [1.0]]], + [[[-1.0], [0.0], [1.0]]], + ], + scaling_parameters=None, + reduction="none", + expected=[[[0.0]], [[1.0 / 18]], [[3.0 / 18]]], + ), + dict( + testcase_name="trivial_mean", + inputs=[[-1.0], [0.0], [1.0]], + location_parameters=[ + [[[-1.0], [0.0], [1.0]]], + [[[-1.0], [0.0], [1.0]]], + [[[-1.0], [0.0], [1.0]]], + ], + scaling_parameters=[[[[1.0]]], [[[1.0]]], [[[1.0]]]], + reduction="mean", + expected=[[0.0], [1.0 / 18], [3.0 / 18]], + ), + dict( + testcase_name="moderate", + inputs=[[-1.0, 0.0], [0.0, 1.0], [1.0, 2.0]], + location_parameters=[ + [[[-1.0], [0.0], [1.0]], [[-2.0], [0.0], [2.0]]], + [[[-1.0], [0.0], [1.0]], [[-3.0], [0.0], [3.0]]], + [[[-1.0], [0.0], [1.0]], [[-4.0], [0.0], [4.0]]], + ], + scaling_parameters=[ + [[[1.0]], [[1.0]]], + [[[1.0]], [[1.0]]], + [[[1.0]], [[1.0]]], + ], + reduction="none", + expected=[ + [[0.0], [2.0 / 18]], + [[1.0 / 18], [5.0 / 18]], + [[3.0 / 18], [8.0 / 18]], + ], + ), + dict( + testcase_name="moderate_none_scaling", + inputs=[[-1.0, 0.0], [0.0, 1.0], [1.0, 2.0]], + location_parameters=[ + [[[-1.0], [0.0], [1.0]], [[-2.0], [0.0], [2.0]]], + [[[-1.0], [0.0], [1.0]], [[-3.0], [0.0], [3.0]]], + [[[-1.0], [0.0], [1.0]], [[-4.0], [0.0], [4.0]]], + ], + scaling_parameters=None, + reduction="none", + expected=[ + [[0.0], [2.0 / 18]], + [[1.0 / 18], [5.0 / 18]], + [[3.0 / 18], [8.0 / 18]], + ], + ), + dict( + testcase_name="moderate_scaling", + inputs=[[-1.0, 0.0], [0.0, 1.0], [1.0, 2.0]], + location_parameters=[ + [[[-1.0], [0.0], [1.0]], [[-2.0], [0.0], [2.0]]], + [[[-1.0], [0.0], [1.0]], [[-3.0], [0.0], [3.0]]], + [[[-1.0], [0.0], [1.0]], [[-4.0], [0.0], [4.0]]], + ], + scaling_parameters=[ + [[[1.0]], [[1.0]]], + [[[2.0]], [[2.0]]], + [[[5.0]], [[0.5]]], + ], + reduction="none", + expected=[ + [[0.0], [2.0 / 18]], + [[2.0 / 18], [8.0 / 18]], + [[11.0 / 18], [4.0 / 18]], + ], + ), + dict( + testcase_name="moderate_mean", + inputs=[[-1.0, 0.0], [0.0, 1.0], [1.0, 2.0]], + location_parameters=[ + 
[[[-1.0], [0.0], [1.0]], [[-2.0], [0.0], [2.0]]], + [[[-1.0], [0.0], [1.0]], [[-3.0], [0.0], [3.0]]], + [[[-1.0], [0.0], [1.0]], [[-4.0], [0.0], [4.0]]], + ], + scaling_parameters=[ + [[[1.0]], [[1.0]]], + [[[1.0]], [[1.0]]], + [[[1.0]], [[1.0]]], + ], + reduction="mean", + expected=[[1.0 / 18], [3.0 / 18], [5.5 / 18]], + ), + dict( + testcase_name="moderate_geometric_mean", + inputs=[[-1.0, 0.0], [0.0, 1.0], [1.0, 2.0]], + location_parameters=[ + [[[-1.0], [0.0], [1.0]], [[-2.0], [0.0], [2.0]]], + [[[-1.0], [0.0], [1.0]], [[-3.0], [0.0], [3.0]]], + [[[-1.0], [0.0], [1.0]], [[-4.0], [0.0], [4.0]]], + ], + scaling_parameters=[ + [[[1.0]], [[1.0]]], + [[[1.0]], [[1.0]]], + [[[1.0]], [[1.0]]], + ], + reduction="geometric_mean", + expected=[[0.0], [2.23606797 / 18], [4.898979485 / 18]], + ), + ) + def test_compute_relu6( + self, + inputs, + location_parameters, + scaling_parameters, + reduction, + expected, + ): + result = cdf_fn( + inputs=tf.constant(inputs, dtype=tf.float32), + location_parameters=tf.constant(location_parameters, dtype=tf.float32), + scaling_parameters=( + tf.constant(scaling_parameters, dtype=tf.float32) + if scaling_parameters is not None + else None + ), + units=1, + activation="relu6", + reduction=reduction, + ) + self.assertAllClose(result, expected) + + @parameterized.named_parameters( + dict( + testcase_name="0.0", + scaling_exp_transform_multiplier=0.0, + expected=[[0.398024055], [0.555379215], [0.684897170]], + ), + dict( + testcase_name="1.0", + scaling_exp_transform_multiplier=1.0, + expected=[[0.344373118], [0.58323046], [0.6278357037]], + ), + dict( + testcase_name="-1.0", + scaling_exp_transform_multiplier=-1.0, + expected=[[0.4554976295], [0.51644151635], [0.66798191003]], + ), + ) + def test_scaling_exp_transformation( + self, scaling_exp_transform_multiplier, expected + ): + result = cdf_fn( + inputs=tf.constant([[-1.0, 0.0], [0.0, 1.0], [1.0, 2.0]]), + location_parameters=tf.constant([ + [[[-1.0], [0.0], [1.0]], [[-2.0], [0.0], [2.0]]], + [[[-1.0], [0.0], [1.0]], [[-3.0], [0.0], [3.0]]], + [[[-1.0], [0.0], [1.0]], [[-4.0], [0.0], [4.0]]], + ]), + scaling_parameters=tf.constant([ + [[[1.0]], [[1.0]]], + [[[0.0]], [[2.0]]], + [[[-1.0]], [[3.0]]], + ]), + reduction="mean", + activation="sigmoid", + scaling_exp_transform_multiplier=scaling_exp_transform_multiplier, + ) + self.assertAllClose(result, expected) + + @parameterized.named_parameters( + dict( + testcase_name="sigmoid_repeat", + inputs=[[0.0], [0.0], [0.0]], + location_parameters=[ + [[[-1.0], [0.0], [1.0]]], + [[[-1.0], [0.0], [1.0]]], + [[[-1.0], [0.0], [1.0]]], + ], + scaling_parameters=[[[[1.0]]], [[[1.0]]], [[[1.0]]]], + units=1, + activation="sigmoid", + sparsity_factor=1, + scaling_exp_transform_multiplier=None, + expected=[ + [ + [[[-0.06553731], [-0.08333334], [-0.06553732]]], + [[[-0.06553731], [-0.08333334], [-0.06553732]]], + [[[-0.06553731], [-0.08333334], [-0.06553732]]], + ], + [ + [[[-7.4505806e-09]]], + [[[-7.4505806e-09]]], + [[[-7.4505806e-09]]], + ], + ], + ), + dict( + testcase_name="sigmoid_trivial", + inputs=[[-1.0], [0.0], [1.0]], + location_parameters=[ + [[[-1.0], [0.0], [1.0]]], + [[[-1.0], [0.0], [1.0]]], + [[[-1.0], [0.0], [1.0]]], + ], + scaling_parameters=[[[[1.0]]], [[[1.0]]], [[[1.0]]]], + units=1, + activation="sigmoid", + sparsity_factor=1, + scaling_exp_transform_multiplier=None, + expected=[ + [ + [[[-0.04934135], [-0.03880439], [-0.0207221]]], + [[[-0.06553731], [-0.08333334], [-0.06553732]]], + [[[-0.04927362], [-0.09227023], [-0.11732531]]], + ], + 
[[[[-8.0248594e-02]]], [[[-7.4505806e-09]]], [[[1.9081746e-01]]]], + ], + ), + dict( + testcase_name="relu6", + inputs=[[-1.0], [0.0], [1.0]], + location_parameters=[ + [[[-1.0], [0.0], [1.0]]], + [[[-1.0], [0.0], [1.0]]], + [[[-1.0], [0.0], [1.0]]], + ], + scaling_parameters=[[[[1.0]]], [[[1.0]]], [[[1.0]]]], + units=1, + activation="relu6", + sparsity_factor=1, + scaling_exp_transform_multiplier=None, + expected=[ + [ + [[[-0.0], [-0.0], [-0.0]]], + [[[-0.00617284], [-0.0], [-0.0]]], + [[[-0.01851852], [-0.01851852], [-0.0]]], + ], + [[[[0.0]]], [[[0.00617284]]], [[[0.05555556]]]], + ], + ), + dict( + testcase_name="units_multiplier_sigmoid", + inputs=[[-1.0, 0.0], [0.0, 1.0], [1.0, 2.0]], + location_parameters=[ + [[[-1.0], [0.0], [1.0]], [[-2.0], [0.0], [2.0]]], + [[[-1.0], [0.0], [1.0]], [[-3.0], [0.0], [3.0]]], + [[[-1.0], [0.0], [1.0]], [[-4.0], [0.0], [4.0]]], + ], + scaling_parameters=[ + [[[1.0]], [[1.0]]], + [[[2.0]], [[2.0]]], + [[[5.0]], [[7.0]]], + ], + units=2, + activation="sigmoid", + sparsity_factor=2, + scaling_exp_transform_multiplier=0.0, + expected=[ + [ + [ + [[-0.04934135], [-0.03880439], [-0.0207221]], + [[-0.03499786], [-0.08333334], [-0.03499787]], + ], + [ + [[-0.06553731], [-0.08333334], [-0.06553732]], + [[-0.00719178], [-0.08005493], [-0.04275048]], + ], + [ + [[-0.04927362], [-0.09227023], [-0.11732531]], + [[-0.00109488], [-0.04660612], [-0.04660612]], + ], + ], + [[[[-0.0]], [[-0.0]]], [[[-0.0]], [[0.0]]], [[[0.0]], [[0.0]]]], + ], + ), + dict( + testcase_name="units_multiplier_relu6", + inputs=[[-1.0, 0.0], [0.0, 1.0], [1.0, 2.0]], + location_parameters=[ + [[[-1.0], [0.0], [1.0]], [[-2.0], [0.0], [2.0]]], + [[[-1.0], [0.0], [1.0]], [[-3.0], [0.0], [3.0]]], + [[[-1.0], [0.0], [1.0]], [[-4.0], [0.0], [4.0]]], + ], + scaling_parameters=[ + [[[1.0]], [[1.0]]], + [[[2.0]], [[2.0]]], + [[[5.0]], [[7.0]]], + ], + units=2, + activation="relu6", + sparsity_factor=2, + scaling_exp_transform_multiplier=0.01, + expected=[ + [ + [ + [[-0.000], [-0.000], [-0.000]], + [[-0.01259508], [-0.0], [-0.0]], + ], + [ + [[-0.00642476], [-0.0], [-0.0]], + [[-0.03212379], [-0.03212379], [-0.0]], + ], + [ + [[-0.02046613], [-0.02046613], [-0.0]], + [[-0.0], [-0.05392344], [-0.0]], + ], + ], + [ + [[[0.0000000e00]], [[2.5190154e-04]]], + [[[6.4247579e-05]], [[1.6061892e-03]]], + [[[6.1398384e-04]], [[1.0784689e-03]]], + ], + ], + ), + ) + def test_gradient( + self, + inputs, + location_parameters, + scaling_parameters, + units, + activation, + sparsity_factor, + scaling_exp_transform_multiplier, + expected, + ): + location_parameters = tf.Variable( + location_parameters, + trainable=True, + dtype=tf.float32, + name="location_parameters", + ) + scaling_parameters = tf.Variable( + scaling_parameters, + trainable=True, + dtype=tf.float32, + name="scaling_parameters", + ) + + with tf.GradientTape() as tape: + y = cdf_fn( + inputs=tf.constant(inputs, dtype=tf.float32), + location_parameters=location_parameters, + scaling_parameters=scaling_parameters, + reduction="mean", + units=units, + activation=activation, + sparsity_factor=sparsity_factor, + scaling_exp_transform_multiplier=scaling_exp_transform_multiplier, + ) + loss = tf.reduce_sum(y * y) + grads = tape.gradient(loss, [location_parameters, scaling_parameters]) + self.assertAllClose(grads, expected) + + @parameterized.named_parameters( + dict( + testcase_name="activation", + inputs=[[-1.0, 0.0], [0.0, 1.0], [1.0, 2.0]], + location_parameters=[ + [[[-1.0], [0.0], [1.0]], [[-2.0], [0.0], [2.0]]], + [[[-1.0], [0.0], [1.0]], [[-3.0], 
[0.0], [3.0]]], + [[[-1.0], [0.0], [1.0]], [[-4.0], [0.0], [4.0]]], + ], + scaling_parameters=[ + [[[1.0]], [[1.0]]], + [[[2.0]], [[2.0]]], + [[[5.0]], [[7.0]]], + ], + units=2, + activation="relu", + reduction="none", + sparsity_factor=2, + expected="activation = .* is not supported.*", + ), + dict( + testcase_name="reduction", + inputs=[[-1.0, 0.0], [0.0, 1.0], [1.0, 2.0]], + location_parameters=[ + [[[-1.0], [0.0], [1.0]], [[-2.0], [0.0], [2.0]]], + [[[-1.0], [0.0], [1.0]], [[-3.0], [0.0], [3.0]]], + [[[-1.0], [0.0], [1.0]], [[-4.0], [0.0], [4.0]]], + ], + scaling_parameters=None, + units=2, + activation="sigmoid", + reduction="some_reduction", + sparsity_factor=2, + expected="reduction = .* is not supported.*", + ), + dict( + testcase_name="input_shape", + inputs=[-1.0, 0.0], + location_parameters=[ + [[[-1.0], [0.0], [1.0]], [[-2.0], [0.0], [2.0]]], + [[[-1.0], [0.0], [1.0]], [[-3.0], [0.0], [3.0]]], + [[[-1.0], [0.0], [1.0]], [[-4.0], [0.0], [4.0]]], + ], + scaling_parameters=[ + [[[1.0]], [[1.0]]], + [[[2.0]], [[2.0]]], + [[[5.0]], [[7.0]]], + ], + units=2, + activation="sigmoid", + reduction="none", + sparsity_factor=2, + expected="inputs shape.*is not.*", + ), + dict( + testcase_name="units_and_sparsity_factor", + inputs=[[-1.0, 0.0], [0.0, 1.0], [1.0, 2.0]], + location_parameters=[ + [[[-1.0], [0.0], [1.0]], [[-2.0], [0.0], [2.0]]], + [[[-1.0], [0.0], [1.0]], [[-3.0], [0.0], [3.0]]], + [[[-1.0], [0.0], [1.0]], [[-4.0], [0.0], [4.0]]], + ], + scaling_parameters=None, + units=2, + activation="sigmoid", + reduction="mean", + sparsity_factor=3, + expected="units.*is not divisible by sparsity_factor.*", + ), + dict( + testcase_name="input_dim_and_sparsity_factor", + inputs=[[-1.0, 0.0], [0.0, 1.0], [1.0, 2.0]], + location_parameters=[ + [[[-1.0], [0.0], [1.0]], [[-2.0], [0.0], [2.0]]], + [[[-1.0], [0.0], [1.0]], [[-3.0], [0.0], [3.0]]], + [[[-1.0], [0.0], [1.0]], [[-4.0], [0.0], [4.0]]], + ], + scaling_parameters=[ + [[[1.0]], [[1.0]]], + [[[2.0]], [[2.0]]], + [[[5.0]], [[7.0]]], + ], + units=3, + activation="sigmoid", + reduction="mean", + sparsity_factor=3, + expected="input_dim.*is not divisible by sparsity_factor.*", + ), + dict( + testcase_name="location_parameters_shape_1", + inputs=[[-1.0, 0.0], [0.0, 1.0], [1.0, 2.0]], + location_parameters=[ + [[-1.0], [0.0], [1.0]], + [[-2.0], [0.0], [2.0]], + [[-1.0], [0.0], [1.0]], + [[-3.0], [0.0], [3.0]], + [[-1.0], [0.0], [1.0]], + [[-4.0], [0.0], [4.0]], + ], + scaling_parameters=[ + [[[1.0]], [[1.0]]], + [[[2.0]], [[2.0]]], + [[[5.0]], [[7.0]]], + ], + units=2, + activation="sigmoid", + reduction="mean", + sparsity_factor=2, + expected="location_parameters shape.*is not.*", + ), + dict( + testcase_name="location_parameters_shape_2", + inputs=[[-1.0, 0.0], [0.0, 1.0], [1.0, 2.0]], + location_parameters=[ + [[[-1.0], [0.0], [1.0]]], + [[[-1.0], [0.0], [1.0]]], + [[[-1.0], [0.0], [1.0]]], + ], + scaling_parameters=None, + units=2, + activation="sigmoid", + reduction="mean", + sparsity_factor=2, + expected="location_parameters shape.*is not.*", + ), + dict( + testcase_name="location_parameters_shape_3", + inputs=[[-1.0, 0.0], [0.0, 1.0], [1.0, 2.0]], + location_parameters=[ + [[[-1.0], [0.0], [1.0]], [[-2.0], [0.0], [2.0]]], + [[[-1.0], [0.0], [1.0]], [[-3.0], [0.0], [3.0]]], + [[[-1.0], [0.0], [1.0]], [[-4.0], [0.0], [4.0]]], + ], + scaling_parameters=None, + units=2, + activation="sigmoid", + reduction="mean", + sparsity_factor=1, + expected="location_parameters shape.*is not.*", + ), + dict( + 
testcase_name="location_and_scaling_shape_1",
+          inputs=[[-1.0, 0.0], [0.0, 1.0], [1.0, 2.0]],
+          location_parameters=[
+              [[[-1.0], [0.0], [1.0]], [[-2.0], [0.0], [2.0]]],
+              [[[-1.0], [0.0], [1.0]], [[-3.0], [0.0], [3.0]]],
+              [[[-1.0], [0.0], [1.0]], [[-4.0], [0.0], [4.0]]],
+          ],
+          scaling_parameters=[
+              [[[1.0]], [[1.0]]],
+              [[[2.0]], [[2.0]]],
+              [[[5.0]], [[7.0]]],
+              [[[5.0]], [[7.0]]],
+          ],
+          units=2,
+          activation="sigmoid",
+          reduction="mean",
+          sparsity_factor=2,
+          expected=(
+              "scaling_parameters and location_parameters"
+              " likely are not broadcastable.*"
+          ),
+      ),
+      dict(
+          testcase_name="location_and_scaling_shape_2",
+          inputs=[[-1.0, 0.0], [0.0, 1.0], [1.0, 2.0]],
+          location_parameters=[
+              [[[-1.0], [0.0], [1.0]], [[-2.0], [0.0], [2.0]]],
+              [[[-1.0], [0.0], [1.0]], [[-3.0], [0.0], [3.0]]],
+              [[[-1.0], [0.0], [1.0]], [[-4.0], [0.0], [4.0]]],
+          ],
+          scaling_parameters=[
+              [[[1.0, 1.0]], [[1.0, 1.0]]],
+              [[[2.0, 1.0]], [[2.0, 1.0]]],
+              [[[5.0, 1.0]], [[7.0, 1.0]]],
+          ],
+          units=2,
+          activation="sigmoid",
+          reduction="mean",
+          sparsity_factor=2,
+          expected=(
+              "scaling_parameters and location_parameters"
+              " likely are not broadcastable.*"
+          ),
+      ),
+  )
+  def test_raise(
+      self,
+      inputs,
+      location_parameters,
+      scaling_parameters,
+      units,
+      activation,
+      reduction,
+      sparsity_factor,
+      expected,
+  ):
+    with self.assertRaisesRegex(ValueError, expected):
+      _ = cdf_fn(
+          inputs=tf.constant(inputs, dtype=tf.float32),
+          location_parameters=tf.constant(
+              location_parameters, dtype=tf.float32
+          ),
+          scaling_parameters=(
+              tf.constant(scaling_parameters, dtype=tf.float32)
+              if scaling_parameters is not None
+              else None
+          ),
+          units=units,
+          reduction=reduction,
+          activation=activation,
+          sparsity_factor=sparsity_factor,
+      )
+
+
+if __name__ == "__main__":
+  tf.test.main()
diff --git a/tensorflow_lattice/python/conditional_pwl_calibration.py b/tensorflow_lattice/python/conditional_pwl_calibration.py
new file mode 100644
index 0000000..3dae723
--- /dev/null
+++ b/tensorflow_lattice/python/conditional_pwl_calibration.py
@@ -0,0 +1,495 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Implements PWLCalibration with derived parameters (kernels).
+
+`pwl_calibration_fn` is similar to `tfl.layers.PWLCalibration` with the key
+difference that the keypoints are decided by the given parameters instead
+of learnable weights belonging to a layer. These parameters can be one of:
+
+  - constants,
+  - trainable variables,
+  - outputs from other TF modules.
+
+For inputs of shape `(batch_size, units)`, two sets of parameters are required
+to configure the piece-wise linear calibrator in terms of its x and y values:
+
+  - `keypoint_input_parameters` for configuring the x values,
+  - `keypoint_output_parameters` for configuring the y values.
+
+This function is a general form of conditional calibration, in which one input
+variable is calibrated based on free-form parameters coming from other
+variables and their transformations.
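+
+As a minimal sketch (the values here are illustrative only), a single-unit
+calibrator with four keypoints over `[0.0, 1.0]` can be called as
+
+```python
+output = pwl_calibration_fn(
+    inputs=tf.constant([[0.5], [0.8]]),
+    # input_param_size = 4 keypoints - 2; zeros give equally spaced keypoints.
+    keypoint_input_parameters=tf.zeros((1, 1, 2)),
+    # output_param_size = 4 keypoints; zeros give a flat output at the midpoint
+    # of [keypoint_output_min, keypoint_output_max].
+    keypoint_output_parameters=tf.zeros((1, 1, 4)),
+)  # -> shape (2, 1)
+```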
+
+Shapes:
+The last dimension sizes of `keypoint_input_parameters` (input_param_size) and
+`keypoint_output_parameters` (output_param_size) depend on the number of
+keypoints used by the calibrator. We follow the relationships that
+
+  - input_param_size = # keypoints - 2, as the leftmost and rightmost keypoints
+    are given.
+  - output_param_size = # keypoints initially, and we then modify it by
+
+    1. if cyclic calibrator: output_param_size -= 1,
+    2. if clamp_min: output_param_size -= 1,
+    3. if clamp_max: output_param_size -= 1,
+    4. if need to learn how to impute missing: output_param_size += 1.
+
+The final shapes need to be broadcast friendly with `(batch_size, units, 1)`:
+
+  - `keypoint_input_parameters`:
+    `(1 or batch_size, 1 or units, input_param_size)`.
+  - `keypoint_output_parameters`:
+    `(1 or batch_size, 1 or units, output_param_size)`.
+"""
+
+from typing import Optional, Sequence, Tuple, Union
+
+import numpy as np
+import tensorflow as tf
+
+
+def _front_pad(x: tf.Tensor, constant_values: float) -> tf.Tensor:
+  return tf.pad(x, [[0, 0], [0, 0], [1, 0]], constant_values=constant_values)
+
+
+def default_keypoint_output_parameters(
+    num_keypoints: int,
+    units: int = 1,
+    monotonicity: str = "none",
+    is_cyclic: bool = False,
+    clamp_min: bool = False,
+    clamp_max: bool = False,
+    derived_missing_output: bool = False,
+) -> Optional[tf.Tensor]:
+  """Helper creating default `keypoint_output_parameters`.
+
+  Primarily used for testing.
+
+  Args:
+    num_keypoints: number of keypoints for inputs.
+    units: number of parallel calibrations on one input.
+    monotonicity: `none` or `increasing`, monotonicity of the calibration.
+    is_cyclic: whether the calibration is cyclic. Only works if `monotonicity
+      == none`.
+    clamp_min: whether the leftmost keypoint should be clamped. Only works if
+      `monotonicity == increasing`.
+    clamp_max: whether the rightmost keypoint should be clamped. Only works if
+      `monotonicity == increasing`.
+    derived_missing_output: whether to reserve a placeholder for the missing
+      output value.
+
+  Returns:
+    A tensor with a shape of `(1, units, output_param_size)`.
+
+  Raises:
+    `ValueError` if parsing failed.
+  """
+  if monotonicity == "none":
+    output_param_size = num_keypoints - is_cyclic + derived_missing_output
+    # default output = midpoint between
+    # keypoint_output_min and keypoint_output_max, flat.
+    return tf.zeros((1, units, output_param_size), dtype=tf.float32)
+  elif monotonicity == "increasing":
+    output_param_size = (
+        num_keypoints - clamp_min - clamp_max + derived_missing_output
+    )
+    # default output = equal increments between
+    # keypoint_output_min and keypoint_output_max.
+    return tf.zeros((1, units, output_param_size), dtype=tf.float32)
+  else:
+    raise ValueError(f"Unknown monotonicity: {monotonicity}")
+
+
+def default_keypoint_input_parameters(
+    num_keypoints: Optional[int] = None,
+    keypoints: Optional[Sequence[float]] = None,
+    units: int = 1,
+) -> Optional[tf.Tensor]:
+  """Helper creating default `keypoint_input_parameters`.
+
+  Primarily used for testing.
+
+  Args:
+    num_keypoints: number of keypoints. If provided, keypoints will be equally
+      spaced.
+    keypoints: sequence of increasing keypoints.
+    units: number of parallel calibrations on one input.
+
+  Returns:
+    A tensor with a shape of `(1, units, input_param_size)`.
+
+  Raises:
+    `ValueError` if parsing failed.
+ """ + if num_keypoints is not None and num_keypoints > 2: + return tf.zeros((1, units, num_keypoints - 2), dtype=tf.float32) + elif keypoints is not None: + keypoints = np.array(keypoints) + deltas = keypoints[1:] - keypoints[:-1] + if np.all(deltas > 0): + deltas = deltas / np.sum(deltas) + deltas = np.log(deltas / deltas[0])[1:] + deltas = tf.reshape(tf.constant(deltas, dtype=tf.float32), (1, 1, -1)) + return tf.tile(deltas, [1, units, 1]) + else: + raise ValueError("Neither num_keypoints nor keypoints is specified.") + + +def _verify_pwl_calibration( + inputs, + keypoint_input_parameters, + keypoint_output_parameters, + units, + keypoint_input_min, + keypoint_input_max, + keypoint_output_min, + keypoint_output_max, + clamp_min, + clamp_max, + monotonicity, + is_cyclic, + missing_input_value, + missing_output_value, +): + """Validates calibration arguments.""" + # Validate keypoint input_min and input_max. + if keypoint_input_min > keypoint_input_max: + raise ValueError( + f"keypoint_input_min = {keypoint_input_min} > keypoint_input_max =" + f" {keypoint_input_max}." + ) + + # Validate pwl shape arguments. + if monotonicity not in ("none", "increasing"): + raise ValueError( + "Monotonicity should be 'none' or 'increasing'. " + f"Given '{monotonicity}'." + ) + + if monotonicity == "none" and (clamp_min or clamp_max): + raise ValueError("Cannot clamp to min or max when monotonicity is 'none'.") + + if keypoint_output_min > keypoint_output_max: + raise ValueError( + f"keypoint_output_min = {keypoint_output_min} > keypoint_output_max =" + f" {keypoint_output_max}." + ) + + if monotonicity == "increasing" and is_cyclic: + raise ValueError("Monotonicity should be 'none' when is_cyclic=True.") + + # Validate missingness indicators. + if missing_output_value is not None and missing_input_value is None: + raise ValueError( + "missing_output_value is set, but missing_input_value is None" + ) + + # Validate parameter shapes. See module level doc string for details. + num_keypoints = ( + keypoint_input_parameters.shape[-1] + 2 + if keypoint_input_parameters is not None + else 0 + ) + output_param_size = ( + num_keypoints + - clamp_max + - clamp_min + - is_cyclic + + (missing_input_value is not None) + - (missing_output_value is not None) + ) + + if output_param_size <= 0: + raise ValueError( + f"Required keypoint_output_parameters per example = {output_param_size}" + " <= 0: Creating a trivial function, e.g. identity or constant." + ) + + if units > 1 and len(keypoint_output_parameters.shape) != 3: + raise ValueError( + "keypoint_output_parameters should be 3 dimensional when units > 1. " + f"Given {keypoint_output_parameters.shape}." + ) + if ( + len(keypoint_output_parameters.shape) == 3 + and keypoint_output_parameters.shape[1] != units + ): + raise ValueError( + "2nd dimension of keypoint_output_parameters does not match units, " + f"units = {units} vs keypoint_output_parameters = " + f"{keypoint_output_parameters.shape[1]}." + ) + if keypoint_output_parameters.shape[-1] != output_param_size: + raise ValueError( + "keypoint_output_parameters shape is " + f"{keypoint_output_parameters.shape} whose last dimension needs to be " + f"{output_param_size}." + ) + + # Validate input shape. + if inputs.shape[1] > 1 and inputs.shape[1] != units: + raise ValueError( + "2nd dimension of input shape does not match units > 1, " + f"Require (batch_size, 1) or (batch_size, units = {units})." + ) + + +def _compute_interpolation_weights(inputs, keypoints, lengths): + """Computes weights for PWL calibration. 
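+
+  A sketch of the rule implemented below: each linear piece contributes weight
+  `clip((x - left_keypoint) / length, 0.0, 1.0)`, so pieces fully to the left
+  of `x` saturate at 1, pieces to the right contribute 0, and the active piece
+  contributes fractionally; a constant 1 is front-padded to weight the base
+  (leftmost) output value.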
+ + Args: + inputs: Tensor of shape: `(batch_size, units, 1)`. For multi-unit + calibration, broadcasting will be used if needed. + keypoints: Tensor of shape `(num_keypoints-1)` which represents left + keypoint of pieces of piecewise linear function along X axis. + lengths: Tensor of shape `(num_keypoints-1)` which represents lengths of + pieces of piecewise linear function along X axis. + + Returns: + Interpolation weights tensor of shape: `(batch_size, units, num_keypoints)`. + """ + # weights always matches the shape of inputs. + weights = (inputs - keypoints) / lengths + weights = tf.clip_by_value(weights, 0.0, 1.0) + return _front_pad(weights, 1.0) + + +@tf.function +def pwl_calibration_fn( + inputs: tf.Tensor, + keypoint_input_parameters: Optional[tf.Tensor], + keypoint_output_parameters: tf.Tensor, + keypoint_input_min: float = 0.0, + keypoint_input_max: float = 1.0, + keypoint_output_min: float = 0.0, + keypoint_output_max: float = 1.0, + units: int = 1, + monotonicity: str = "none", + clamp_min: bool = False, + clamp_max: bool = False, + is_cyclic: bool = False, + missing_input_value: Optional[float] = None, + missing_output_value: Optional[float] = None, + return_derived_parameters: bool = False, +) -> Union[tf.Tensor, Tuple[tf.Tensor, tf.Tensor, tf.Tensor]]: + """Calibrates `inputs` using derived parameters (kernels). + + `pwl_calibration_fn` is similar to `tfl.layers.PWLCalibration` with the key + difference that the keypoints are decided by the given parameters instead + of learnable weights belonging to a layer. These parameters can be one of: + + - constants, + - trainable variables, + - outputs from other TF modules. + + Shapes: + The last dimension of `keypoint_input_parameters` (`input_param_size`) and + `keypoint_output_parameters` (`output_param_size`) depend on the number of + keypoints used by the calibrator. We follow the relationships that + + - `input_param_size = # keypoints - 2`, as the leftmost and rightmost + keypoints are given. + - `output_param_size = # keypoints` initially, and we then modify it by + + 1. if cyclic calibrator: `output_param_size -= 1`, + 2. if clamp_min: `output_param_size -= 1`, + 3. if clamp_max: `output_param_size -= 1`, + 4. if need to learn how to impute missing: `output_param_size += 1`. + + The final shapes need to be broadcast friendly with `(batch_size, units, 1)`: + + - `keypoint_input_parameters`: + `(1 or batch_size, 1 or units, input_param_size)`. + - `keypoint_output_parameters`: + `(1 or batch_size, 1 or units, output_param_size)`. + + Input shape: + `inputs` should be one of: + + - `(batch_size, 1)` if `units == 1`. + - `(batch_size, 1)` or `(batch_size, units)` if `units > 1`. + The former will be broadcast to match units. + + `keypoint_input_parameters` should be one of: + + - `None` if only the leftmost and the rightmost keypoints are required. + - `(1, input_param_size)`. + - `(batch_size, input_param_size)`. + - `(1, 1, input_param_size)`. + - `(batch_size, 1, input_param_size)`. + - `(1, units, input_param_size)`. + - `(batch_size, units, input_param_size)`. + + `keypoint_output_parameters` should be one of: + + - `(1, output_param_size)`. + - `(batch_size, output_param_size)`. + - `(1, 1, output_param_size)`. + - `(batch_size, 1, output_param_size)`. + - `(1, units, output_param_size)`. + - `(batch_size, units, output_param_size)`. + + Args: + inputs: inputs to the calibration fn. + keypoint_input_parameters: parameters for keypoint x's of calibration fn. 
+    keypoint_output_parameters: parameters for keypoint y's of calibration fn.
+    keypoint_input_min: the leftmost keypoint.
+    keypoint_input_max: the rightmost keypoint.
+    keypoint_output_min: lower bound of the fn output.
+    keypoint_output_max: upper bound of the fn output.
+    units: number of parallel calibrations on one input.
+    monotonicity: `none` or `increasing`. Whether the calibration is monotonic.
+    clamp_min: only applies when monotonicity == `increasing`. Whether to clamp
+      the LHS keypoint to the calibration `keypoint_output_min`.
+    clamp_max: only applies when monotonicity == `increasing`. Whether to clamp
+      the RHS keypoint to the calibration `keypoint_output_max`.
+    is_cyclic: only applies when monotonicity == `none`. Whether the LHS and
+      the RHS keypoints have the same calibration output.
+    missing_input_value: if set, use as the value indicating a missing input.
+    missing_output_value: if set, use as the output for `missing_input_value`.
+    return_derived_parameters: if True, return the derived kernel parameters
+      used for interpolation.
+
+  Returns:
+    If `return_derived_parameters = False`:
+
+    - The calibrated output as a tensor with shape `(batch_size, units)`.
+
+    If `return_derived_parameters = True`:
+
+    - A tuple of three elements:
+
+      1. The calibrated output as a tensor with shape `(batch_size, units)`.
+      2. The deltas between the keypoint x's with shape
+         `(batch_size, units, # keypoints - 1)`.
+      3. The initial value and the deltas between the keypoint y's, with
+         shape `(batch_size, units, # keypoints)`. Applying `cumsum` will
+         reconstruct the y values.
+  """
+  _verify_pwl_calibration(
+      inputs=inputs,
+      keypoint_input_parameters=keypoint_input_parameters,
+      keypoint_output_parameters=keypoint_output_parameters,
+      units=units,
+      keypoint_input_min=keypoint_input_min,
+      keypoint_input_max=keypoint_input_max,
+      keypoint_output_min=keypoint_output_min,
+      keypoint_output_max=keypoint_output_max,
+      clamp_min=clamp_min,
+      clamp_max=clamp_max,
+      monotonicity=monotonicity,
+      is_cyclic=is_cyclic,
+      missing_input_value=missing_input_value,
+      missing_output_value=missing_output_value,
+  )
+
+  if keypoint_input_parameters is None:
+    keypoint_input_parameters = tf.zeros((1, units, 1), dtype=tf.float32)
+  else:
+    if len(keypoint_input_parameters.shape) == 2:
+      keypoint_input_parameters = keypoint_input_parameters[:, tf.newaxis, :]
+    if keypoint_input_parameters.shape[1] == 1 and units > 1:
+      keypoint_input_parameters = tf.tile(
+          keypoint_input_parameters, [1, units, 1]
+      )
+    # Front-pad 0 to normalize softmax.
+    keypoint_input_parameters = _front_pad(keypoint_input_parameters, 0.0)
+
+  keypoint_deltas = tf.nn.softmax(keypoint_input_parameters, axis=-1) * (
+      keypoint_input_max - keypoint_input_min
+  )
+  # Front-pad `input_min` as the leftmost keypoint.
+  # Trim the rightmost keypoint not required for interpolation.
+  keypoints = (
+      tf.cumsum(keypoint_deltas, exclusive=True, axis=-1) + keypoint_input_min
+  )
+
+  # Rename since its value will be modified as part of the output.
+  kernel_outputs = keypoint_output_parameters
+  if len(kernel_outputs.shape) == 2:
+    kernel_outputs = kernel_outputs[:, tf.newaxis, :]
+  if kernel_outputs.shape[1] == 1 and units > 1:
+    kernel_outputs = tf.tile(kernel_outputs, [1, units, 1])
+
+  missing_output = None
+  if missing_input_value is not None:
+    if missing_output_value is None:
+      # The last parameter is used to derive the imputed output value after
+      # sigmoid and rescale.
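+      # That is, with p denoting that trailing free-form parameter:
+      #   missing_output = output_min + sigmoid(p) * (output_max - output_min).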
+ missing_output = keypoint_output_min + tf.sigmoid( + kernel_outputs[:, :, -1] + ) * (keypoint_output_max - keypoint_output_min) + kernel_outputs = kernel_outputs[:, :, :-1] + else: + missing_output = tf.fill( + kernel_outputs[:, :, -1].shape, missing_output_value + ) + + if monotonicity == "none": + kernel_outputs = ( + tf.sigmoid(kernel_outputs) * (keypoint_output_max - keypoint_output_min) + + keypoint_output_min + ) + if is_cyclic: + kernel_outputs = tf.concat( + [kernel_outputs, kernel_outputs[:, :, :1]], axis=-1 + ) + # Transform to [initial value, delta_0, delta_1,...]. + kernel_outputs = tf.concat( + [ + kernel_outputs[:, :, :1], + kernel_outputs[:, :, 1:] - kernel_outputs[:, :, :-1], + ], + axis=-1, + ) + else: # monotonicity == "increasing" + # Front-pad zero to normalize softmax. + kernel_outputs = _front_pad(kernel_outputs, 0.0) + kernel_outputs = tf.nn.softmax(kernel_outputs, axis=-1) * ( + keypoint_output_max - keypoint_output_min + ) + if clamp_min: + # Front-pad keypoint_output_min to the kernel_outputs. + kernel_outputs = _front_pad(kernel_outputs, keypoint_output_min) + else: + # Add keypoint_output_min to the LHS element in the kernel_outputs. + # TODO: test tf.tensor_scatter_nd_add. + kernel_outputs = tf.concat( + [ + kernel_outputs[:, :, :1] + keypoint_output_min, + kernel_outputs[:, :, 1:], + ], + axis=-1, + ) + if not clamp_max: + # Drop the RHS element in the kernel_outputs which made cumsum = 1. + kernel_outputs = kernel_outputs[:, :, :-1] + + if units > 1 and inputs.shape[-1] == 1: + inputs = tf.tile(inputs, [1, units]) + weights = _compute_interpolation_weights( + tf.reshape(inputs, (-1, units, 1)), keypoints, keypoint_deltas + ) + outputs = tf.reduce_sum(weights * kernel_outputs, axis=-1, keepdims=False) + + if missing_input_value is not None: + outputs = tf.where( + tf.equal(inputs, missing_input_value), missing_output, outputs + ) + + if return_derived_parameters: + return outputs, keypoint_deltas, kernel_outputs + else: + return outputs diff --git a/tensorflow_lattice/python/conditional_pwl_calibration_test.py b/tensorflow_lattice/python/conditional_pwl_calibration_test.py new file mode 100644 index 0000000..48bf66d --- /dev/null +++ b/tensorflow_lattice/python/conditional_pwl_calibration_test.py @@ -0,0 +1,544 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""TF tests for conditional_pwl_calibration.py."""
+
+import tensorflow as tf
+
+from tensorflow_lattice.python.conditional_pwl_calibration import default_keypoint_input_parameters
+from tensorflow_lattice.python.conditional_pwl_calibration import pwl_calibration_fn
+
+_EPSILON = 1e-4
+
+
+class PwlCalibrationFnTest(tf.test.TestCase):
+
+  def assertAllClose(self, x, y):
+    super().assertAllClose(x, y, rtol=_EPSILON, atol=_EPSILON)
+
+  def assertAllGreaterEqual(self, a, comparison_target):
+    super().assertAllGreaterEqual(a, comparison_target - _EPSILON)
+
+  def assertAllLessEqual(self, a, comparison_target):
+    super().assertAllLessEqual(a, comparison_target + _EPSILON)
+
+  def assertAllEqual(self, a, comparison_target):
+    super().assertAllInRange(
+        a, comparison_target - _EPSILON, comparison_target + _EPSILON
+    )
+
+  def setUp(self):
+    super().setUp()
+    self.kernel_4 = tf.constant(
+        [
+            [-0.38, -0.41, -0.34, -0.29],
+            [0.17, -0.32, 0.33, -0.1],
+        ],
+        dtype=tf.float32,
+    )
+    self.kernel_5 = tf.constant(
+        [
+            [-0.38, -0.41, -0.34, -0.29, 0.42],
+            [0.17, -0.32, 0.33, -0.1, -0.36],
+        ],
+        dtype=tf.float32,
+    )
+    self.multi_unit_kernel_4 = tf.constant(
+        [
+            [
+                [-0.26, 0.43, 0.49, 0.26],
+                [0.39, 0.42, -0.33, 0.41],
+                [0.28, 0.04, 0.46, 0.09],
+            ],
+            [
+                [-0.27, -0.23, 0.29, -0.12],
+                [-0.4, -0.24, -0.31, 0.01],
+                [0.03, 0.01, -0.42, -0.42],
+            ],
+        ],
+        dtype=tf.float32,
+    )
+    self.multi_unit_kernel_5 = tf.constant(
+        [
+            [
+                [-0.26, 0.43, 0.49, 0.26, -0.32],
+                [0.39, 0.42, -0.33, 0.41, 0.11],
+                [0.28, 0.04, 0.46, 0.09, -0.33],
+            ],
+            [
+                [-0.27, -0.23, 0.29, -0.12, 0.46],
+                [-0.4, -0.24, -0.31, 0.01, 0.21],
+                [0.03, 0.01, -0.42, -0.42, 0.37],
+            ],
+        ],
+        dtype=tf.float32,
+    )
+
+  def test_suite_none_monotonic(self):
+    """Tests non-monotonic calibration."""
+    # basic call
+    y = pwl_calibration_fn(
+        inputs=tf.constant([[0.5], [0.8]]),
+        keypoint_output_parameters=self.kernel_4,
+        keypoint_input_parameters=default_keypoint_input_parameters(
+            keypoints=[0.0, 0.1, 0.4, 1.0]
+        ),
+        keypoint_input_min=0.0,
+        keypoint_input_max=1.0,
+    )
+    self.assertAllClose(y, tf.constant([[0.41784188], [0.51060027]]))
+
+    # if is_cyclic, starting and ending keypoints give the same prediction
+    y1 = pwl_calibration_fn(
+        inputs=tf.constant([[0.5], [0.5]]),
+        keypoint_output_parameters=self.kernel_4,
+        is_cyclic=True,
+        keypoint_input_parameters=default_keypoint_input_parameters(
+            keypoints=[0.5, 0.6, 0.65, 0.7, 0.8]
+        ),
+        keypoint_input_min=0.5,
+        keypoint_input_max=0.8,
+    )
+    y2 = pwl_calibration_fn(
+        inputs=tf.constant([[0.8], [0.8]]),
+        keypoint_output_parameters=self.kernel_4,
+        is_cyclic=True,
+        keypoint_input_parameters=default_keypoint_input_parameters(
+            keypoints=[0.5, 0.6, 0.65, 0.7, 0.8]
+        ),
+        keypoint_input_min=0.5,
+        keypoint_input_max=0.8,
+    )
+    self.assertAllClose(y1, y2)
+
+    # basic multi-unit call, input needs broadcast
+    y = pwl_calibration_fn(
+        inputs=tf.constant([[0.5], [0.8]]),
+        units=3,
+        keypoint_output_parameters=self.multi_unit_kernel_4,
+        keypoint_input_parameters=default_keypoint_input_parameters(
+            keypoints=[0.0, 0.1, 0.4, 1.0]
+        ),
+        keypoint_input_min=0.0,
+        keypoint_input_max=1.0,
+    )
+    self.assertAllClose(
+        y,
+        tf.constant([
+            [0.6108614, 0.44871515, 0.5979259],
+            [0.50402266, 0.47603822, 0.39651677],
+        ]),
+    )
+
+    # basic multi-unit call
+    y = pwl_calibration_fn(
+        inputs=tf.constant([[0.5, 0.5, 0.5], [0.8, 0.8, 0.8]]),
+        units=3,
+        keypoint_output_parameters=self.multi_unit_kernel_4,
+        
+        keypoint_input_parameters=default_keypoint_input_parameters(
+            keypoints=[0.0, 0.1, 0.4, 1.0]
+        ),
+        keypoint_input_min=0.0,
+        keypoint_input_max=1.0,
+    )
+    self.assertAllClose(
+        y,
+        tf.constant([
+            [0.6108614, 0.44871515, 0.5979259],
+            [0.50402266, 0.47603822, 0.39651677],
+        ]),
+    )
+
+    # keypoint_output_min and keypoint_output_max scale the outputs correctly
+    y1 = pwl_calibration_fn(
+        inputs=tf.constant([[0.5, 0.5, 0.5], [0.8, 0.8, 0.8]]),
+        units=3,
+        keypoint_output_parameters=self.multi_unit_kernel_4,
+        keypoint_input_parameters=default_keypoint_input_parameters(
+            keypoints=[0.0, 0.1, 0.4, 1.0]
+        ),
+        keypoint_input_min=0.0,
+        keypoint_input_max=1.0,
+    )
+    y2 = pwl_calibration_fn(
+        inputs=tf.constant([[0.5, 0.5, 0.5], [0.8, 0.8, 0.8]]),
+        units=3,
+        keypoint_output_parameters=self.multi_unit_kernel_4,
+        keypoint_input_parameters=default_keypoint_input_parameters(
+            keypoints=[0.0, 0.1, 0.4, 1.0]
+        ),
+        keypoint_output_min=-1.0,
+        keypoint_output_max=10.0,
+        keypoint_input_min=0.0,
+        keypoint_input_max=1.0,
+    )
+    self.assertAllClose(y1 * 11.0 - 1.0, y2)
+
+    # multi-unit is_cyclic gives cyclic predictions
+    y1 = pwl_calibration_fn(
+        inputs=tf.constant([[-0.1], [1.1]]),
+        units=3,
+        is_cyclic=True,
+        keypoint_output_parameters=self.multi_unit_kernel_4,
+        keypoint_input_parameters=default_keypoint_input_parameters(
+            keypoints=[0.0, 0.1, 0.2, 0.5, 1.0]
+        ),
+        keypoint_input_min=0.0,
+        keypoint_input_max=1.0,
+    )
+    y2 = pwl_calibration_fn(
+        inputs=tf.constant([[-1.0, 0.0, 1.0], [-1.0, 0.0, 1.0]]),
+        units=3,
+        is_cyclic=True,
+        keypoint_output_parameters=self.multi_unit_kernel_4,
+        keypoint_input_parameters=default_keypoint_input_parameters(
+            keypoints=[0.0, 0.1, 0.2, 0.5, 1.0]
+        ),
+        keypoint_input_min=0.0,
+        keypoint_input_max=1.0,
+    )
+    self.assertAllClose(y1, y2)
+
+    # missing input with given missing output imputed correctly
+    y = pwl_calibration_fn(
+        inputs=tf.constant([[0.5], [-1.0]]),
+        keypoint_output_parameters=self.kernel_4,
+        keypoint_input_parameters=default_keypoint_input_parameters(
+            keypoints=[0.0, 0.1, 0.4, 1.0]
+        ),
+        keypoint_input_min=0.0,
+        keypoint_input_max=1.0,
+        missing_input_value=-1.0,
+        missing_output_value=3.0,
+    )
+    self.assertAllClose(y, tf.constant([[0.41784188], [3.0]]))
+
+    # missing input imputed correctly with derived missing output
+    y = pwl_calibration_fn(
+        inputs=tf.constant([[0.5], [-1.0]]),
+        keypoint_output_parameters=self.kernel_5,
+        keypoint_input_parameters=default_keypoint_input_parameters(
+            keypoints=[0.0, 0.1, 0.4, 1.0]
+        ),
+        keypoint_input_min=0.0,
+        keypoint_input_max=1.0,
+        missing_input_value=-1.0,
+    )
+    self.assertAllClose(y, tf.constant([[0.41784188], [0.41095957]]))
+
+  def test_suite_increasing_monotonic(self):
+    """Tests monotonic calibration."""
+    # basic call
+    y = pwl_calibration_fn(
+        inputs=tf.constant([[0.5], [0.8]]),
+        keypoint_output_parameters=self.kernel_4,
+        monotonicity='increasing',
+        keypoint_input_parameters=default_keypoint_input_parameters(
+            keypoints=[0.0, 0.1, 0.4, 1.0]
+        ),
+        keypoint_input_min=0.0,
+        keypoint_input_max=1.0,
+    )
+    self.assertAllClose(y, tf.constant([[0.64769804], [0.7371951]]))
+
+    # outputs are monotonic
+    y1 = pwl_calibration_fn(
+        inputs=tf.constant([[-0.5], [0.3]]),
+        keypoint_output_parameters=self.kernel_4,
+        monotonicity='increasing',
+        keypoint_input_parameters=default_keypoint_input_parameters(
+            keypoints=[0.0, 0.1, 0.4, 1.0]
+        ),
+        keypoint_input_min=0.0,
+        keypoint_input_max=1.0,
+    )
+    y2 = pwl_calibration_fn(
+        inputs=tf.constant([[0.5], [0.8]]),
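+        # same kernel as y1; only the query points move right, so the
+        # monotonic outputs must not decrease (asserted below)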
+        keypoint_output_parameters=self.kernel_4,
+        monotonicity='increasing',
+        keypoint_input_parameters=default_keypoint_input_parameters(
+            keypoints=[0.0, 0.1, 0.4, 1.0]
+        ),
+        keypoint_input_min=0.0,
+        keypoint_input_max=1.0,
+    )
+    y3 = pwl_calibration_fn(
+        inputs=tf.constant([[0.6], [1.2]]),
+        keypoint_output_parameters=self.kernel_4,
+        monotonicity='increasing',
+        keypoint_input_parameters=default_keypoint_input_parameters(
+            keypoints=[0.0, 0.1, 0.4, 1.0]
+        ),
+        keypoint_input_min=0.0,
+        keypoint_input_max=1.0,
+    )
+    self.assertAllGreaterEqual(y2 - y1, 0.0)
+    self.assertAllGreaterEqual(y3 - y2, 0.0)
+
+    # clamp_min works as expected
+    y = pwl_calibration_fn(
+        inputs=tf.constant([[0.0], [-0.2]]),
+        keypoint_output_parameters=self.multi_unit_kernel_4,
+        keypoint_input_parameters=default_keypoint_input_parameters(
+            num_keypoints=5
+        ),
+        keypoint_input_min=0.0,
+        keypoint_input_max=2.0,
+        monotonicity='increasing',
+        keypoint_output_min=-10.0,
+        clamp_min=True,
+        units=3,
+    )
+    self.assertAllEqual(y, -10.0)
+
+    # clamp_max works as expected
+    y = pwl_calibration_fn(
+        inputs=tf.constant([[2.0], [2.5]]),
+        keypoint_output_parameters=self.multi_unit_kernel_4,
+        keypoint_input_parameters=default_keypoint_input_parameters(
+            num_keypoints=5
+        ),
+        keypoint_input_min=0.0,
+        keypoint_input_max=2.0,
+        monotonicity='increasing',
+        keypoint_output_max=10.0,
+        clamp_max=True,
+        units=3,
+    )
+    self.assertAllEqual(y, 10.0)
+
+    # clamp_min and clamp_max work as expected together, min
+    y = pwl_calibration_fn(
+        inputs=tf.constant([[0.0, 0.0, -10.0], [-0.2, 0.0, -100.0]]),
+        keypoint_output_parameters=self.multi_unit_kernel_4,
+        keypoint_input_parameters=default_keypoint_input_parameters(
+            keypoints=[0.0, 0.1, 0.4, 1.0, 1.5, 2.0]
+        ),
+        keypoint_input_min=0.0,
+        keypoint_input_max=2.0,
+        monotonicity='increasing',
+        keypoint_output_min=-10.0,
+        clamp_min=True,
+        keypoint_output_max=5.0,
+        clamp_max=True,
+        units=3,
+    )
+    self.assertAllEqual(y, -10.0)
+
+    # clamp_min and clamp_max work as expected together, max
+    y = pwl_calibration_fn(
+        inputs=tf.constant([[2.0, 3.0, 4.0], [2.5, 2.5, 2.5]]),
+        keypoint_output_parameters=self.multi_unit_kernel_4,
+        units=3,
+        keypoint_input_parameters=default_keypoint_input_parameters(
+            keypoints=[0.0, 0.1, 0.4, 1.0, 1.5, 2.0]
+        ),
+        keypoint_input_min=0.0,
+        keypoint_input_max=2.0,
+        monotonicity='increasing',
+        keypoint_output_min=-10.0,
+        clamp_min=True,
+        keypoint_output_max=5.0,
+        clamp_max=True,
+    )
+    self.assertAllEqual(y, 5.0)
+
+    # clamp_min, clamp_max, missing_input_value and derived missing_output_value
+    # work as expected together
+    y = pwl_calibration_fn(
+        inputs=tf.constant([[0.0, 1.0, 2.0], [-0.5, 1.5, 2.5]]),
+        keypoint_output_parameters=self.multi_unit_kernel_5,
+        units=3,
+        keypoint_input_parameters=default_keypoint_input_parameters(
+            keypoints=[0.0, 0.1, 0.4, 1.0, 1.5, 2.0]
+        ),
+        keypoint_input_min=0.0,
+        keypoint_input_max=2.0,
+        monotonicity='increasing',
+        keypoint_output_min=-10.0,
+        clamp_min=True,
+        keypoint_output_max=5.0,
+        clamp_max=True,
+        missing_input_value=-1.0,
+    )
+    self.assertAllClose(
+        y, tf.constant([[-10.0, -0.3635044, 5.0], [-10.0, 1.3930602, 5.0]])
+    )
+
+    # clamp_min, clamp_max and missing_input_value work as expected together
+    y = pwl_calibration_fn(
+        inputs=tf.constant([[0.0, -1.0, 2.0], [-0.5, -1.0, 2.5]]),
+        keypoint_output_parameters=self.multi_unit_kernel_4,
+        units=3,
+        keypoint_input_parameters=default_keypoint_input_parameters(
+            keypoints=[0.0, 0.1, 0.4, 1.0, 1.5, 2.0]
+        ),
+        keypoint_input_min=0.0,
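+        # the -1.0 entries match missing_input_value and map directly to
+        # missing_output_value (3.0), bypassing calibration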
keypoint_input_max=2.0, + monotonicity='increasing', + keypoint_output_min=-10.0, + clamp_min=True, + keypoint_output_max=5.0, + clamp_max=True, + missing_input_value=-1.0, + missing_output_value=3.0, + ) + self.assertAllClose(y, tf.constant([[-10.0, 3.0, 5.0], [-10.0, 3.0, 5.0]])) + + def test_gradient_step(self): + """Tests gradient computation.""" + trainable = tf.Variable( + tf.zeros_like(self.multi_unit_kernel_5, dtype=tf.float32), + trainable=True, + name='trainable', + ) + + with tf.GradientTape() as tape: + y = pwl_calibration_fn( + inputs=tf.constant([[-1.0, 0.0, 1.0], [0.8, 2.0, 3.0]]), + keypoint_output_parameters=trainable, + units=3, + keypoint_input_parameters=default_keypoint_input_parameters( + keypoints=[0.0, 0.1, 0.4, 1.0, 2.0] + ), + keypoint_input_min=0.0, + keypoint_input_max=2.0, + monotonicity='increasing', + keypoint_output_max=10.0, + clamp_max=True, + missing_input_value=-1.0, + ) + loss = tf.reduce_mean(y * y) + grads = tape.gradient(loss, trainable) + self.assertAllClose( + grads, + tf.constant([ + [ + [0.0, 0.0, 0.0, 0.0, 4.166667], + [-0.26666668, -0.26666668, -0.26666668, -0.26666668, 0.0], + [1.0666668, 1.0666668, 1.0666668, -4.266667, 0.0], + ], + [ + [1.3037037, 1.3037037, -0.3259262, -3.5851853, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + ], + ]), + ) + + def test_suite_raises(self): + """Tests verifiable ValueErrors.""" + + with self.assertRaises(ValueError): + _ = pwl_calibration_fn( + inputs=tf.constant([[0.5], [0.8]]), + keypoint_output_parameters=self.kernel_4, + keypoint_input_parameters=default_keypoint_input_parameters( + keypoints=[0.0, 0.1, 0.1, 1.0] + ), + ) + with self.assertRaises(ValueError): + _ = pwl_calibration_fn( + inputs=tf.constant([[0.5], [0.8]]), + keypoint_output_parameters=self.kernel_4, + keypoint_input_parameters=default_keypoint_input_parameters( + keypoints=[0.0, 0.3, 0.1, 1.0] + ), + ) + with self.assertRaises(ValueError): + _ = pwl_calibration_fn( + inputs=tf.constant([[0.5], [0.8]]), + keypoint_output_parameters=self.kernel_4, + keypoint_input_parameters=default_keypoint_input_parameters( + keypoints=[0.0, 0.1, 0.2, 0.3, 1.0] + ), + ) + with self.assertRaises(ValueError): + _ = pwl_calibration_fn( + inputs=tf.constant([[0.5], [0.8]]), + keypoint_output_parameters=self.kernel_4, + is_cyclic=True, + keypoint_input_parameters=default_keypoint_input_parameters( + keypoints=[0.0, 0.1, 0.3, 1.0] + ), + ) + with self.assertRaises(ValueError): + _ = pwl_calibration_fn( + inputs=tf.constant([[0.5], [0.8]]), + keypoint_output_parameters=self.kernel_4, + units=3, + keypoint_input_parameters=default_keypoint_input_parameters( + keypoints=[0.0, 0.1, 0.5, 1.0] + ), + ) + with self.assertRaises(ValueError): + _ = pwl_calibration_fn( + inputs=tf.constant([[0.5], [0.8]]), + keypoint_output_parameters=self.kernel_4, + keypoint_input_parameters=default_keypoint_input_parameters( + keypoints=[0.0, 0.1, 0.4, 1.0] + ), + keypoint_output_min=1.0, + keypoint_output_max=0.0, + ) + with self.assertRaises(ValueError): + _ = pwl_calibration_fn( + inputs=tf.constant([[0.5], [0.8]]), + keypoint_output_parameters=self.kernel_4, + keypoint_input_parameters=default_keypoint_input_parameters( + keypoints=[0.0, 0.1, 0.4, 1.0] + ), + missing_output_value=1.0, + ) + with self.assertRaises(ValueError): + _ = pwl_calibration_fn( + inputs=tf.constant([[0.5], [0.8]]), + keypoint_output_parameters=self.kernel_4, + keypoint_input_parameters=default_keypoint_input_parameters( + keypoints=[0.0, 0.1, 0.4, 1.0] + ), + keypoint_input_min=1.0, + 
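+            # invalid: keypoint_input_min exceeds keypoint_input_max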
keypoint_input_max=0.0, + ) + with self.assertRaises(ValueError): + _ = pwl_calibration_fn( + inputs=tf.constant([[0.5], [0.8]]), + keypoint_output_parameters=self.multi_unit_kernel_4, + keypoint_input_parameters=default_keypoint_input_parameters( + keypoints=[0.0, 0.1, 0.4, 1.0] + ), + ) + with self.assertRaises(ValueError): + _ = pwl_calibration_fn( + inputs=tf.constant([[0.5], [0.8]]), + keypoint_output_parameters=self.kernel_5, + monotonicity='increasing', + keypoint_input_parameters=default_keypoint_input_parameters( + keypoints=[0.0, 0.1, 0.4, 1.0] + ), + ) + with self.assertRaises(ValueError): + _ = pwl_calibration_fn( + inputs=tf.constant([[0.5, 0.6, 0.7, 0.8], [0.0, 0.1, 0.2, 0.8]]), + units=3, + keypoint_output_parameters=self.multi_unit_kernel_5, + monotonicity='increasing', + keypoint_input_parameters=default_keypoint_input_parameters( + keypoints=[0.0, 0.1, 0.4, 0.7, 1.0] + ), + ) + + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow_lattice/python/configs.py b/tensorflow_lattice/python/configs.py index c5d0e73..f049efb 100644 --- a/tensorflow_lattice/python/configs.py +++ b/tensorflow_lattice/python/configs.py @@ -92,24 +92,24 @@ def get_config(self): config.pop('__class__') if 'feature_configs' in config and config['feature_configs'] is not None: config['feature_configs'] = [ - tf.keras.utils.serialize_keras_object(feature_config) + tf.keras.utils.legacy.serialize_keras_object(feature_config) for feature_config in config['feature_configs'] ] if 'regularizer_configs' in config and config[ 'regularizer_configs'] is not None: config['regularizer_configs'] = [ - tf.keras.utils.serialize_keras_object(regularizer_config) + tf.keras.utils.legacy.serialize_keras_object(regularizer_config) for regularizer_config in config['regularizer_configs'] ] if ('reflects_trust_in' in config and config['reflects_trust_in'] is not None): config['reflects_trust_in'] = [ - tf.keras.utils.serialize_keras_object(trust_config) + tf.keras.utils.legacy.serialize_keras_object(trust_config) for trust_config in config['reflects_trust_in'] ] if 'dominates' in config and config['dominates'] is not None: config['dominates'] = [ - tf.keras.utils.serialize_keras_object(dominance_config) + tf.keras.utils.legacy.serialize_keras_object(dominance_config) for dominance_config in config['dominates'] ] return config @@ -120,28 +120,32 @@ def deserialize_nested_configs(cls, config, custom_objects=None): config = copy.deepcopy(config) if 'feature_configs' in config and config['feature_configs'] is not None: config['feature_configs'] = [ - tf.keras.utils.deserialize_keras_object( - feature_config, custom_objects=custom_objects) + tf.keras.utils.legacy.deserialize_keras_object( + feature_config, custom_objects=custom_objects + ) for feature_config in config['feature_configs'] ] if 'regularizer_configs' in config and config[ 'regularizer_configs'] is not None: config['regularizer_configs'] = [ - tf.keras.utils.deserialize_keras_object( - regularizer_config, custom_objects=custom_objects) + tf.keras.utils.legacy.deserialize_keras_object( + regularizer_config, custom_objects=custom_objects + ) for regularizer_config in config['regularizer_configs'] ] if ('reflects_trust_in' in config and config['reflects_trust_in'] is not None): config['reflects_trust_in'] = [ - tf.keras.utils.deserialize_keras_object( - trust_config, custom_objects=custom_objects) + tf.keras.utils.legacy.deserialize_keras_object( + trust_config, custom_objects=custom_objects + ) for trust_config in config['reflects_trust_in'] ] if 
'dominates' in config and config['dominates'] is not None:
       config['dominates'] = [
-          tf.keras.utils.deserialize_keras_object(
-              dominance_config, custom_objects=custom_objects)
+          tf.keras.utils.legacy.deserialize_keras_object(
+              dominance_config, custom_objects=custom_objects
+          )
           for dominance_config in config['dominates']
       ]
     return config
diff --git a/tensorflow_lattice/python/estimators.py b/tensorflow_lattice/python/estimators.py
index 8e8e896..01a3ca3 100644
--- a/tensorflow_lattice/python/estimators.py
+++ b/tensorflow_lattice/python/estimators.py
@@ -90,7 +90,6 @@
 from tensorflow.python.feature_column import feature_column as fc  # pylint: disable=g-direct-tensorflow-import
 from tensorflow.python.feature_column import feature_column_v2 as fc2  # pylint: disable=g-direct-tensorflow-import
-from tensorflow.python.keras.utils import losses_utils  # pylint: disable=g-direct-tensorflow-import
 from tensorflow.python.training import training_util  # pylint: disable=g-direct-tensorflow-import
 from tensorflow_estimator.python.estimator import estimator as estimator_lib
 from tensorflow_estimator.python.estimator.canned import optimizers
@@ -915,7 +914,7 @@
                prefitting_steps=None,
                config=None,
                warm_start_from=None,
-               loss_reduction=losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE,
+               loss_reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE,
                loss_fn=None,
                dtype=tf.float32):
     """Initializes a `CannedClassifier` instance.
@@ -1085,7 +1084,7 @@
                prefitting_steps=None,
                config=None,
                warm_start_from=None,
-               loss_reduction=losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE,
+               loss_reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE,
                loss_fn=None,
                dtype=tf.float32):
     """Initializes a `CannedRegressor` instance.
diff --git a/tensorflow_lattice/python/kronecker_factored_lattice_layer.py b/tensorflow_lattice/python/kronecker_factored_lattice_layer.py
index 454c9b5..89ad6f1 100644
--- a/tensorflow_lattice/python/kronecker_factored_lattice_layer.py
+++ b/tensorflow_lattice/python/kronecker_factored_lattice_layer.py
@@ -231,8 +231,15 @@
     # of the initializer using partial functions if it accepts scale.
     parameters = inspect.signature(self.kernel_initializer).parameters.keys()
     if "scale" in parameters:
+      # tf.cond expects its true/false branches to be callables. read_value
+      # is already a bound method and can be passed as-is, but initial_value
+      # is a property, so it has to be wrapped in a lambda.
kernel_initializer = functools.partial( - self.kernel_initializer, scale=self.scale.initialized_value()) + self.kernel_initializer, + scale=tf.cond( + tf.compat.v1.is_variable_initialized(self.scale), + self.scale.read_value, + lambda: self.scale.initial_value)) else: kernel_initializer = self.kernel_initializer self.kernel = self.add_weight( @@ -297,9 +304,11 @@ def get_config(self): "output_max": self.output_max, "clip_inputs": self.clip_inputs, "kernel_initializer": - keras.initializers.serialize(self.kernel_initializer), + keras.initializers.serialize( + self.kernel_initializer, use_legacy_format=True), "scale_initializer": - keras.initializers.serialize(self.scale_initializer), + keras.initializers.serialize( + self.scale_initializer, use_legacy_format=True), } # pyformat: disable config.update(super(KroneckerFactoredLattice, self).get_config()) return config diff --git a/tensorflow_lattice/python/lattice_layer.py b/tensorflow_lattice/python/lattice_layer.py index a06c78d..3bb914d 100644 --- a/tensorflow_lattice/python/lattice_layer.py +++ b/tensorflow_lattice/python/lattice_layer.py @@ -487,9 +487,11 @@ def get_config(self): "clip_inputs": self.clip_inputs, "interpolation": self.interpolation, "kernel_initializer": - keras.initializers.serialize(self.kernel_initializer), + keras.initializers.serialize( + self.kernel_initializer, use_legacy_format=True), "kernel_regularizer": - [keras.regularizers.serialize(r) for r in self.kernel_regularizer], + [keras.regularizers.serialize(r, use_legacy_format=True) + for r in self.kernel_regularizer], } # pyformat: disable config.update(super(Lattice, self).get_config()) return config diff --git a/tensorflow_lattice/python/linear_layer.py b/tensorflow_lattice/python/linear_layer.py index 0a2314b..aaac196 100644 --- a/tensorflow_lattice/python/linear_layer.py +++ b/tensorflow_lattice/python/linear_layer.py @@ -297,16 +297,20 @@ def get_config(self): "input_min": self.input_min, "input_max": self.input_max, "kernel_initializer": - keras.initializers.serialize(self.kernel_initializer), + keras.initializers.serialize( + self.kernel_initializer, use_legacy_format=True), "kernel_regularizer": [ - keras.regularizers.serialize(r) for r in self.kernel_regularizer + keras.regularizers.serialize(r, use_legacy_format=True) + for r in self.kernel_regularizer ], } # pyformat: disable if self.use_bias: config["bias_initializer"] = keras.initializers.serialize( - self.bias_initializer) + self.bias_initializer, use_legacy_format=True + ) config["bias_regularizer"] = [ - keras.regularizers.serialize(r) for r in self.bias_regularizer + keras.regularizers.serialize(r, use_legacy_format=True) + for r in self.bias_regularizer ] config.update(super(Linear, self).get_config()) diff --git a/tensorflow_lattice/python/parallel_combination_layer.py b/tensorflow_lattice/python/parallel_combination_layer.py index 97e368e..ae0c711 100644 --- a/tensorflow_lattice/python/parallel_combination_layer.py +++ b/tensorflow_lattice/python/parallel_combination_layer.py @@ -101,7 +101,10 @@ def __init__(self, calibration_layers=None, single_output=True, **kwargs): categorical_calibration_layer.CategoricalCalibration, }): self.calibration_layers.append( - keras.layers.deserialize(calibration_layer)) + keras.layers.deserialize( + calibration_layer, use_legacy_format=True + ) + ) self.single_output = single_output def append(self, calibration_layer): @@ -152,8 +155,10 @@ def compute_output_shape(self, input_shape): def get_config(self): """Standard Keras config for serialization.""" config = 
{ - "calibration_layers": [keras.layers.serialize(layer) - for layer in self.calibration_layers], + "calibration_layers": [ + keras.layers.serialize(layer, use_legacy_format=True) + for layer in self.calibration_layers + ], "single_output": self.single_output, } # pyformat: disable config.update(super(ParallelCombination, self).get_config()) diff --git a/tensorflow_lattice/python/premade.py b/tensorflow_lattice/python/premade.py index cadb92e..d671a67 100644 --- a/tensorflow_lattice/python/premade.py +++ b/tensorflow_lattice/python/premade.py @@ -137,25 +137,21 @@ def __init__(self, model_config=None, dtype=tf.float32, **kwargs): def get_config(self): """Returns a configuration dictionary.""" - config = super(CalibratedLatticeEnsemble, self).get_config() - config['model_config'] = tf.keras.utils.serialize_keras_object( - self.model_config) + config = {'name': self.name, 'trainable': self.trainable} + config['model_config'] = tf.keras.utils.legacy.serialize_keras_object( + self.model_config + ) return config @classmethod def from_config(cls, config, custom_objects=None): - model = super(CalibratedLatticeEnsemble, cls).from_config( - config, custom_objects=custom_objects) - try: - model_config = tf.keras.utils.deserialize_keras_object( - config.get('model_config'), custom_objects=custom_objects) - premade_lib.verify_config(model_config) - model.model_config = model_config - except ValueError: - logging.warning( - 'Could not load model_config. Constructing model without it: %s', - str(config.get('model_config'))) - return model + model_config = tf.keras.utils.legacy.deserialize_keras_object( + config.get('model_config'), custom_objects=custom_objects + ) + premade_lib.verify_config(model_config) + return cls(model_config, + name=config.get('name', None), + trainable=config.get('trainable', True)) class CalibratedLattice(tf.keras.Model): @@ -251,25 +247,21 @@ def __init__(self, model_config=None, dtype=tf.float32, **kwargs): def get_config(self): """Returns a configuration dictionary.""" - config = super(CalibratedLattice, self).get_config() - config['model_config'] = tf.keras.utils.serialize_keras_object( - self.model_config) + config = {'name': self.name, 'trainable': self.trainable} + config['model_config'] = tf.keras.utils.legacy.serialize_keras_object( + self.model_config + ) return config @classmethod def from_config(cls, config, custom_objects=None): - model = super(CalibratedLattice, cls).from_config( - config, custom_objects=custom_objects) - try: - model_config = tf.keras.utils.deserialize_keras_object( - config.get('model_config'), custom_objects=custom_objects) - premade_lib.verify_config(model_config) - model.model_config = model_config - except ValueError: - logging.warning( - 'Could not load model_config. 
Constructing model without it: %s', - str(config.get('model_config'))) - return model + model_config = tf.keras.utils.legacy.deserialize_keras_object( + config.get('model_config'), custom_objects=custom_objects + ) + premade_lib.verify_config(model_config) + return cls(model_config, + name=config.get('name', None), + trainable=config.get('trainable', True)) class CalibratedLinear(tf.keras.Model): @@ -368,25 +360,21 @@ def __init__(self, model_config=None, dtype=tf.float32, **kwargs): def get_config(self): """Returns a configuration dictionary.""" - config = super(CalibratedLinear, self).get_config() - config['model_config'] = tf.keras.utils.serialize_keras_object( - self.model_config) + config = {'name': self.name, 'trainable': self.trainable} + config['model_config'] = tf.keras.utils.legacy.serialize_keras_object( + self.model_config + ) return config @classmethod def from_config(cls, config, custom_objects=None): - model = super(CalibratedLinear, cls).from_config( - config, custom_objects=custom_objects) - try: - model_config = tf.keras.utils.deserialize_keras_object( - config.get('model_config'), custom_objects=custom_objects) - premade_lib.verify_config(model_config) - model.model_config = model_config - except ValueError: - logging.warning( - 'Could not load model_config. Constructing model without it: %s', - str(config.get('model_config'))) - return model + model_config = tf.keras.utils.legacy.deserialize_keras_object( + config.get('model_config'), custom_objects=custom_objects + ) + premade_lib.verify_config(model_config) + return cls(model_config, + name=config.get('name', None), + trainable=config.get('trainable', True)) # TODO: add support for tf.map_fn and inputs of shape (B, ?, input_dim) @@ -490,25 +478,21 @@ def __init__(self, model_config=None, dtype=tf.float32, **kwargs): def get_config(self): """Returns a configuration dictionary.""" - config = super(AggregateFunction, self).get_config() - config['model_config'] = tf.keras.utils.serialize_keras_object( - self.model_config) + config = {'name': self.name, 'trainable': self.trainable} + config['model_config'] = tf.keras.utils.legacy.serialize_keras_object( + self.model_config + ) return config @classmethod def from_config(cls, config, custom_objects=None): - model = super(AggregateFunction, cls).from_config( - config, custom_objects=custom_objects) - try: - model_config = tf.keras.utils.deserialize_keras_object( - config.get('model_config'), custom_objects=custom_objects) - premade_lib.verify_config(model_config) - model.model_config = model_config - except ValueError: - logging.warning( - 'Could not load model_config. 
Constructing model without it: %s', - str(config.get('model_config'))) - return model + model_config = tf.keras.utils.legacy.deserialize_keras_object( + config.get('model_config'), custom_objects=custom_objects + ) + premade_lib.verify_config(model_config) + return cls(model_config, + name=config.get('name', None), + trainable=config.get('trainable', True)) def get_custom_objects(custom_objects=None): diff --git a/tensorflow_lattice/python/premade_lib.py b/tensorflow_lattice/python/premade_lib.py index cd84403..79d787a 100644 --- a/tensorflow_lattice/python/premade_lib.py +++ b/tensorflow_lattice/python/premade_lib.py @@ -1514,6 +1514,7 @@ def compute_feature_keypoints(feature_configs, keypoints=feature_config.pwl_calibration_input_keypoints, clip_min=feature_config.pwl_calibration_clip_min, clip_max=feature_config.pwl_calibration_clip_max, + default_value=feature_config.default_value, weights=weights, weight_reduction=weight_reduction, feature_name=feature_name, diff --git a/tensorflow_lattice/python/pwl_calibration_layer.py b/tensorflow_lattice/python/pwl_calibration_layer.py index a4c3190..30e96e3 100644 --- a/tensorflow_lattice/python/pwl_calibration_layer.py +++ b/tensorflow_lattice/python/pwl_calibration_layer.py @@ -515,9 +515,11 @@ def get_config(self): "convexity": self.convexity, "is_cyclic": self.is_cyclic, "kernel_initializer": - keras.initializers.serialize(self.kernel_initializer), + keras.initializers.serialize( + self.kernel_initializer, use_legacy_format=True), "kernel_regularizer": - [keras.regularizers.serialize(r) for r in self.kernel_regularizer], + [keras.regularizers.serialize(r, use_legacy_format=True) + for r in self.kernel_regularizer], "impute_missing": self.impute_missing, "missing_input_value": self.missing_input_value, "num_projection_iterations": self.num_projection_iterations, @@ -633,8 +635,7 @@ def __init__(self, output_min, output_max, monotonicity, keypoints=None): will have equal heights (i.e. `y[i+1] - y[i]` is constant). - if provided, all pieces of returned function will have equal slopes (i.e. `(y[i+1] - y[i]) / (x[i+1] - x[i])` is constant). - """ - # pyformat: enable + """ # pyformat: enable pwl_calibration_lib.verify_hyperparameters( input_keypoints=keypoints, output_min=output_min, @@ -687,8 +688,7 @@ class PWLCalibrationConstraints(keras.constraints.Constraint): Attributes: - All `__init__` arguments. - """ - # pyformat: enable + """ # pyformat: enable def __init__( self, @@ -778,8 +778,7 @@ class NaiveBoundsConstraints(keras.constraints.Constraint): Attributes: - All `__init__` arguments. - """ - # pyformat: enable + """ # pyformat: enable def __init__(self, lower_bound=None, upper_bound=None): """Initializes an instance of `NaiveBoundsConstraints`. @@ -822,8 +821,7 @@ class LaplacianRegularizer(keras.regularizers.Regularizer): Attributes: - All `__init__` arguments. - """ - # pyformat: enable + """ # pyformat: enable def __init__(self, l1=0.0, l2=0.0, is_cyclic=False): """Initializes an instance of `LaplacianRegularizer`. @@ -895,8 +893,7 @@ class HessianRegularizer(keras.regularizers.Regularizer): Attributes: - All `__init__` arguments. - """ - # pyformat: enable + """ # pyformat: enable def __init__(self, l1=0.0, l2=0.0, is_cyclic=False): """Initializes an instance of `HessianRegularizer`. @@ -976,8 +973,7 @@ class WrinkleRegularizer(keras.regularizers.Regularizer): Attributes: - All `__init__` arguments. 
- """ - # pyformat: enable + """ # pyformat: enable def __init__(self, l1=0.0, l2=0.0, is_cyclic=False): """Initializes an instance of `WrinkleRegularizer`. diff --git a/tensorflow_lattice/python/pwl_calibration_sonnet_module.py b/tensorflow_lattice/python/pwl_calibration_sonnet_module.py deleted file mode 100644 index c3ca32f..0000000 --- a/tensorflow_lattice/python/pwl_calibration_sonnet_module.py +++ /dev/null @@ -1,538 +0,0 @@ -# Copyright 2020 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Piecewise linear calibration layer. - -Sonnet (v2) implementation of tensorflow lattice pwl calibration module. Module -takes single or multi-dimensional input and transforms it using piecewise linear -functions following monotonicity, convexity/concavity and bounds constraints if -specified. -""" - -# TODO: Add built-in regularizers like laplacian, hessian, etc. - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from . import pwl_calibration_lib -from . import utils - -from absl import logging -import numpy as np -import sonnet as snt -import tensorflow as tf - -INTERPOLATION_KEYPOINTS_NAME = "interpolation_keypoints" -LENGTHS_NAME = "lengths" -MISSING_INPUT_VALUE_NAME = "missing_input_value" -PWL_CALIBRATION_KERNEL_NAME = "pwl_calibration_kernel" -PWL_CALIBRATION_MISSING_OUTPUT_NAME = "pwl_calibration_missing_output" - - -class PWLCalibration(snt.Module): - # pyformat: disable - """Piecewise linear calibration layer. - - Module takes input of shape `(batch_size, units)` or `(batch_size, 1)` and - transforms it using `units` number of piecewise linear functions following - monotonicity, convexity and bounds constraints if specified. If multi - dimensional input is provides, each output will be for the corresponding - input, otherwise all PWL functions will act on the same input. All units share - the same configuration, but each has their separate set of trained - parameters. - - Input shape: - Single input should be a rank-2 tensor with shape: `(batch_size, units)` or - `(batch_size, 1)`. The input can also be a list of two tensors of the same - shape where the first tensor is the regular input tensor and the second is the - `is_missing` tensor. In the `is_missing` tensor, 1.0 represents missing input - and 0.0 represents available input. - - Output shape: - Rank-2 tensor with shape: `(batch_size, units)`. - - Attributes: - - All `__init__` arguments. - kernel: TF variable which stores weights of piecewise linear function. - missing_output: TF variable which stores output learned for missing input. - Or TF Constant which stores `missing_output_value` if one is provided. - Available only if `impute_missing` is True. - - Example: - - ```python - calibrator = tfl.sonnet_modules.PWLCalibration( - # Key-points of piecewise-linear function. - input_keypoints=np.linspace(1., 4., num=4), - # Output can be bounded, e.g. when this layer feeds into a lattice. 
- output_min=0.0, - output_max=2.0, - # You can specify monotonicity and other shape constraints for the layer. - monotonicity='increasing', - ) - ``` - """ - # pyformat: enable - - def __init__(self, - input_keypoints, - units=1, - output_min=None, - output_max=None, - clamp_min=False, - clamp_max=False, - monotonicity="none", - convexity="none", - is_cyclic=False, - kernel_init="equal_heights", - impute_missing=False, - missing_input_value=None, - missing_output_value=None, - num_projection_iterations=8, - **kwargs): - # pyformat: disable - """Initializes an instance of `PWLCalibration`. - - Args: - input_keypoints: Ordered list of keypoints of piecewise linear function. - Can be anything accepted by tf.convert_to_tensor(). - units: Output dimension of the layer. See class comments for details. - output_min: Minimum output of calibrator. - output_max: Maximum output of calibrator. - clamp_min: For monotonic calibrators ensures that output_min is reached. - clamp_max: For monotonic calibrators ensures that output_max is reached. - monotonicity: Constraints piecewise linear function to be monotonic using - 'increasing' or 1 to indicate increasing monotonicity, 'decreasing' or - -1 to indicate decreasing monotonicity and 'none' or 0 to indicate no - monotonicity constraints. - convexity: Constraints piecewise linear function to be convex or concave. - Convexity is indicated by 'convex' or 1, concavity is indicated by - 'concave' or -1, 'none' or 0 indicates no convexity/concavity - constraints. - Concavity together with increasing monotonicity as well as convexity - together with decreasing monotonicity results in diminishing return - constraints. - Consider increasing the value of `num_projection_iterations` if - convexity is specified, especially with larger number of keypoints. - is_cyclic: Whether the output for last keypoint should be identical to - output for first keypoint. This is useful for features such as - "time of day" or "degree of turn". If inputs are discrete and exactly - match keypoints then is_cyclic will have an effect only if TFL - regularizers are being used. - kernel_init: None or one of: - - String `"equal_heights"`: For pieces of pwl function to have equal - heights. - - String `"equal_slopes"`: For pieces of pwl function to have equal - slopes. - - Any Sonnet initializer object. If you are passing such object make - sure that you know how this module uses the variables. - impute_missing: Whether to learn an output for cases where input data is - missing. If set to True, either `missing_input_value` should be - initialized, or the `call()` method should get pair of tensors. See - class input shape description for more details. - missing_input_value: If set, all inputs which are equal to this value will - be considered as missing. Can not be set if `impute_missing` is False. - missing_output_value: If set, instead of learning output for missing - inputs, simply maps them into this value. Can not be set if - `impute_missing` is False. - num_projection_iterations: Number of iterations of the Dykstra's - projection algorithm. Constraints are strictly satisfied at the end of - each update, but the update will be closer to a true L2 projection with - higher number of iterations. See - `tfl.pwl_calibration_lib.project_all_constraints` for more details. - **kwargs: Other args passed to `snt.Module` initializer. - - Raises: - ValueError: If layer hyperparameters are invalid. 
- """ - # pyformat: enable - super(PWLCalibration, self).__init__(**kwargs) - - pwl_calibration_lib.verify_hyperparameters( - input_keypoints=input_keypoints, - output_min=output_min, - output_max=output_max, - monotonicity=monotonicity, - convexity=convexity, - is_cyclic=is_cyclic) - if missing_input_value is not None and not impute_missing: - raise ValueError("'missing_input_value' is specified, but " - "'impute_missing' is set to False. " - "'missing_input_value': " + str(missing_input_value)) - if missing_output_value is not None and not impute_missing: - raise ValueError("'missing_output_value' is specified, but " - "'impute_missing' is set to False. " - "'missing_output_value': " + str(missing_output_value)) - if input_keypoints is None: - raise ValueError("'input_keypoints' can't be None") - if monotonicity is None: - raise ValueError("'monotonicity' can't be None. Did you mean '0'?") - - self.input_keypoints = input_keypoints - self.units = units - self.output_min = output_min - self.output_max = output_max - self.clamp_min = clamp_min - self.clamp_max = clamp_max - (self._output_init_min, self._output_init_max, self._output_min_constraints, - self._output_max_constraints - ) = pwl_calibration_lib.convert_all_constraints(self.output_min, - self.output_max, - self.clamp_min, - self.clamp_max) - - self.monotonicity = monotonicity - self.convexity = convexity - self.is_cyclic = is_cyclic - - if kernel_init == "equal_heights": - self.kernel_init = _UniformOutputInitializer( - output_min=self._output_init_min, - output_max=self._output_init_max, - monotonicity=self.monotonicity) - elif kernel_init == "equal_slopes": - self.kernel_init = _UniformOutputInitializer( - output_min=self._output_init_min, - output_max=self._output_init_max, - monotonicity=self.monotonicity, - keypoints=self.input_keypoints) - - self.impute_missing = impute_missing - self.missing_input_value = missing_input_value - self.missing_output_value = missing_output_value - self.num_projection_iterations = num_projection_iterations - - @snt.once - def _create_parameters_once(self, inputs): - """Creates the variables that will be reused on each call of the module.""" - self.dtype = tf.convert_to_tensor(self.input_keypoints).dtype - input_keypoints = np.array(self.input_keypoints) - # Don't need last keypoint for interpolation because we need only beginnings - # of intervals. - self._interpolation_keypoints = tf.constant( - input_keypoints[:-1], - dtype=self.dtype, - name=INTERPOLATION_KEYPOINTS_NAME) - self._lengths = tf.constant( - input_keypoints[1:] - input_keypoints[:-1], - dtype=self.dtype, - name=LENGTHS_NAME) - - constraints = _PWLCalibrationConstraints( - monotonicity=self.monotonicity, - convexity=self.convexity, - lengths=self._lengths, - output_min=self.output_min, - output_max=self.output_max, - output_min_constraints=self._output_min_constraints, - output_max_constraints=self._output_max_constraints, - num_projection_iterations=self.num_projection_iterations) - - # If 'is_cyclic' is specified - last weight will be computed from previous - # weights in order to connect last keypoint with first. - num_weights = input_keypoints.size - self.is_cyclic - - # PWL calibration layer kernel is units-column matrix. First row of matrix - # represents bias. All remaining represent delta in y-value compare to - # previous point. Aka heights of segments. 
- self.kernel = tf.Variable( - initial_value=self.kernel_init([num_weights, self.units], - dtype=self.dtype), - name=PWL_CALIBRATION_KERNEL_NAME, - constraint=constraints) - - if self.impute_missing: - if self.missing_input_value is not None: - self._missing_input_value_tensor = tf.constant( - self.missing_input_value, - dtype=self.dtype, - name=MISSING_INPUT_VALUE_NAME) - else: - self._missing_input_value_tensor = None - - if self.missing_output_value is not None: - self.missing_output = tf.constant( - self.missing_output_value, shape=[1, self.units], dtype=self.dtype) - else: - missing_init = (self._output_init_min + self._output_init_max) / 2.0 - missing_constraints = _NaiveBoundsConstraints( - lower_bound=self.output_min, upper_bound=self.output_max) - initializer = snt.initializers.Constant(missing_init) - self.missing_output = tf.Variable( - initial_value=initializer([1, self.units], self.dtype), - name=PWL_CALIBRATION_MISSING_OUTPUT_NAME, - constraint=missing_constraints) - - def __call__(self, inputs): - """Standard Sonnet __call__() method.. - - Args: - inputs: Either input tensor or list of 2 elements: input tensor and - `is_missing` tensor. - - Returns: - Calibrated input tensor. - - Raises: - ValueError: If `is_missing` tensor specified incorrectly. - """ - self._create_parameters_once(inputs) - is_missing = None - if isinstance(inputs, list): - # Only 2 element lists are allowed. When such list is given - second - # element represents 'is_missing' tensor encoded as float value. - if not self.impute_missing: - raise ValueError("Multiple inputs for PWLCalibration module assume " - "regular input tensor and 'is_missing' tensor, but " - "this instance of a layer is not configured to handle " - "missing value. See 'impute_missing' parameter.") - if len(inputs) > 2: - raise ValueError("Multiple inputs for PWLCalibration module assume " - "normal input tensor and 'is_missing' tensor, but more" - " than 2 tensors given. 'inputs': " + str(inputs)) - if len(inputs) == 2: - inputs, is_missing = inputs - if is_missing.shape.as_list() != inputs.shape.as_list(): - raise ValueError( - "is_missing shape %s does not match inputs shape %s for " - "PWLCalibration module" % - (str(is_missing.shape), str(inputs.shape))) - else: - [inputs] = inputs - if len(inputs.shape) != 2 or (inputs.shape[1] != self.units and - inputs.shape[1] != 1): - raise ValueError("Shape of input tensor for PWLCalibration module must" - " be [-1, units] or [-1, 1]. It is: " + - str(inputs.shape)) - - if self._interpolation_keypoints.dtype != inputs.dtype: - raise ValueError("dtype(%s) of input to PWLCalibration module does not " - "correspond to dtype(%s) of keypoints. You can enforce " - "dtype of keypoints by passing keypoints " - "in such format which by default will be converted into " - "the desired one." % - (inputs.dtype, self._interpolation_keypoints.dtype)) - # Here is calibration. Everything else is handling of missing. - if inputs.shape[1] > 1: - # Add dimension to multi dim input to get shape [batch_size, units, 1]. - # Interpolation will have shape [batch_size, units, weights]. - inputs_to_calibration = tf.expand_dims(inputs, -1) - else: - inputs_to_calibration = inputs - interpolation_weights = pwl_calibration_lib.compute_interpolation_weights( - inputs_to_calibration, self._interpolation_keypoints, self._lengths) - if self.is_cyclic: - # Need to add such last height to make all heights to sum up to 0.0 in - # order to make calibrator cyclic. 
- bias_and_heights = tf.concat( - [self.kernel, -tf.reduce_sum(self.kernel[1:], axis=0, keepdims=True)], - axis=0) - else: - bias_and_heights = self.kernel - - # bias_and_heights has shape [weight, units]. - if inputs.shape[1] > 1: - # Multi dim input has interpolation shape [batch_size, units, weights]. - result = tf.reduce_sum( - interpolation_weights * tf.transpose(bias_and_heights), axis=-1) - else: - # Single dim input has interpolation shape [batch_size, weights]. - result = tf.matmul(interpolation_weights, bias_and_heights) - - if self.impute_missing: - if is_missing is None: - if self.missing_input_value is None: - raise ValueError("PWLCalibration layer is configured to impute " - "missing but no 'missing_input_value' specified and " - "'is_missing' tensor is not given.") - assert self._missing_input_value_tensor is not None - is_missing = tf.cast( - tf.equal(inputs, self._missing_input_value_tensor), - dtype=self.dtype) - result = is_missing * self.missing_output + (1.0 - is_missing) * result - return result - - -class _UniformOutputInitializer(snt.initializers.Initializer): - # pyformat: disable - """Initializes PWL calibration layer to represent linear function. - - PWL calibration layer weights are one-d tensor. First element of tensor - represents bias. All remaining represent delta in y-value compare to previous - point. Aka heights of segments. - - Attributes: - - All `__init__` arguments. - """ - # pyformat: enable - - def __init__(self, output_min, output_max, monotonicity, keypoints=None): - # pyformat: disable - """Initializes an instance of `_UniformOutputInitializer`. - - Args: - output_min: Minimum value of PWL calibration output after initialization. - output_max: Maximum value of PWL calibration output after initialization. - monotonicity: - - if 'none' or 'increasing', the returned function will go from - `(input_min, output_min)` to `(input_max, output_max)`. - - if 'decreasing', the returned function will go from - `(input_min, output_max)` to `(input_max, output_min)`. - keypoints: - - if not provided (None or []), all pieces of returned function - will have equal heights (i.e. `y[i+1] - y[i]` is constant). - - if provided, all pieces of returned function will have equal slopes - (i.e. `(y[i+1] - y[i]) / (x[i+1] - x[i])` is constant). - """ - # pyformat: enable - pwl_calibration_lib.verify_hyperparameters( - input_keypoints=keypoints, - output_min=output_min, - output_max=output_max, - monotonicity=monotonicity) - self.output_min = output_min - self.output_max = output_max - self.monotonicity = monotonicity - self.keypoints = keypoints - - def __call__(self, shape, dtype): - """Returns weights of PWL calibration layer. - - Args: - shape: Must be a collection of the form `(k, units)` where `k >= 2`. - dtype: Standard Sonnet initializer param. - - Returns: - Weights of PWL calibration layer. - - Raises: - ValueError: If requested shape is invalid for PWL calibration layer - weights. - """ - return pwl_calibration_lib.linear_initializer( - shape=shape, - output_min=self.output_min, - output_max=self.output_max, - monotonicity=utils.canonicalize_monotonicity(self.monotonicity), - keypoints=self.keypoints, - dtype=dtype) - - -class _PWLCalibrationConstraints(object): - # pyformat: disable - """Monotonicity and bounds constraints for PWL calibration layer. - - Applies an approximate L2 projection to the weights of a PWLCalibration layer - such that the result satisfies the specified constraints. - - Attributes: - - All `__init__` arguments. 
- """ - # pyformat: enable - - def __init__( - self, - monotonicity="none", - convexity="none", - lengths=None, - output_min=None, - output_max=None, - output_min_constraints=pwl_calibration_lib.BoundConstraintsType.NONE, - output_max_constraints=pwl_calibration_lib.BoundConstraintsType.NONE, - num_projection_iterations=8): - """Initializes an instance of `PWLCalibration`. - - Args: - monotonicity: Same meaning as corresponding parameter of `PWLCalibration`. - convexity: Same meaning as corresponding parameter of `PWLCalibration`. - lengths: Lengths of pieces of piecewise linear function. Needed only if - convexity is specified. - output_min: Minimum possible output of pwl function. - output_max: Maximum possible output of pwl function. - output_min_constraints: A `tfl.pwl_calibration_lib.BoundConstraintsType` - describing the constraints on the layer's minimum value. - output_max_constraints: A `tfl.pwl_calibration_lib.BoundConstraintsType` - describing the constraints on the layer's maximum value. - num_projection_iterations: Same meaning as corresponding parameter of - `PWLCalibration`. - """ - pwl_calibration_lib.verify_hyperparameters( - output_min=output_min, - output_max=output_max, - monotonicity=monotonicity, - convexity=convexity, - lengths=lengths) - self.monotonicity = monotonicity - self.convexity = convexity - self.lengths = lengths - self.output_min = output_min - self.output_max = output_max - self.output_min_constraints = output_min_constraints - self.output_max_constraints = output_max_constraints - self.num_projection_iterations = num_projection_iterations - - canonical_convexity = utils.canonicalize_convexity(self.convexity) - canonical_monotonicity = utils.canonicalize_monotonicity(self.monotonicity) - if (canonical_convexity != 0 and canonical_monotonicity == 0 and - (output_min_constraints != pwl_calibration_lib.BoundConstraintsType.NONE - or output_max_constraints != - pwl_calibration_lib.BoundConstraintsType.NONE)): - logging.warning("Convexity constraints are specified with bounds " - "constraints, but without monotonicity. Such combination " - "might lead to convexity being slightly violated. " - "Consider increasing num_projection_iterations to " - "reduce violation.") - - def __call__(self, w): - """Applies constraints to w.""" - return pwl_calibration_lib.project_all_constraints( - weights=w, - monotonicity=utils.canonicalize_monotonicity(self.monotonicity), - output_min=self.output_min, - output_max=self.output_max, - output_min_constraints=self.output_min_constraints, - output_max_constraints=self.output_max_constraints, - convexity=utils.canonicalize_convexity(self.convexity), - lengths=self.lengths, - num_projection_iterations=self.num_projection_iterations) - - -class _NaiveBoundsConstraints(object): - # pyformat: disable - """Naively clips all elements of tensor to be within bounds. - - This constraint is used only for the weight tensor for missing output value. - - Attributes: - - All `__init__` arguments. - """ - # pyformat: enable - - def __init__(self, lower_bound=None, upper_bound=None): - """Initializes an instance of `_NaiveBoundsConstraints`. - - Args: - lower_bound: Lower bound to clip variable values to. - upper_bound: Upper bound to clip variable values to. 
- """ - self.lower_bound = lower_bound - self.upper_bound = upper_bound - - def __call__(self, w): - """Applies constraints to w.""" - if self.lower_bound is not None: - w = tf.maximum(w, self.lower_bound) - if self.upper_bound is not None: - w = tf.minimum(w, self.upper_bound) - return w diff --git a/tensorflow_lattice/python/pwl_calibration_test.py b/tensorflow_lattice/python/pwl_calibration_test.py index 51029f7..f4c491d 100644 --- a/tensorflow_lattice/python/pwl_calibration_test.py +++ b/tensorflow_lattice/python/pwl_calibration_test.py @@ -33,7 +33,6 @@ from tensorflow import keras from tensorflow_lattice.python import parallel_combination_layer as parallel_combination from tensorflow_lattice.python import pwl_calibration_layer as keras_layer -from tensorflow_lattice.python import pwl_calibration_sonnet_module as sonnet_module from tensorflow_lattice.python import test_utils from tensorflow_lattice.python import utils @@ -330,187 +329,6 @@ def _CreateKerasLayer(self, config): num_projection_iterations=config["num_projection_iterations"], dtype=config["dtype"]) - def _CreateSonnetModule(self, config): - missing_input_value = config["missing_input_value"] - if config["use_separate_missing"]: - # We use 'config["missing_input_value"]' to create the is_missing tensor, - # and we want the model to use the is_missing tensor so we don't pass - # a missing_input_value to the model. - missing_input_value = None - return sonnet_module.PWLCalibration( - input_keypoints=config["input_keypoints"], - units=config["units"], - output_min=config["output_min"], - output_max=config["output_max"], - clamp_min=config["clamp_min"], - clamp_max=config["clamp_max"], - monotonicity=config["monotonicity"], - convexity=config["convexity"], - is_cyclic=config["is_cyclic"], - kernel_init=config["initializer"], - impute_missing=config["impute_missing"], - missing_input_value=missing_input_value, - missing_output_value=config["missing_output_value"], - num_projection_iterations=config["num_projection_iterations"]) - - def _AssertSonnetEquivalentToKeras(self, config): - training_inputs, training_labels = self._CreateTrainingData(config) - keras_layer_ctor = lambda: self._CreateKerasLayer(config) - sonnet_module_ctor = lambda: self._CreateSonnetModule(config) - test_utils.assert_sonnet_equivalent_to_keras( - test=self, - sonnet_module_ctor=sonnet_module_ctor, - keras_layer_ctor=keras_layer_ctor, - training_inputs=training_inputs, - training_labels=training_labels, - ) - - def _createConfig(self, **kwargs): - config = dict(kwargs) - return self._SetDefaults(config) - - def testSonnetDefaultValues(self): - """Compares the sonnet module with default values to the keras layer.""" - if self._disable_all: - return - config = self._createConfig( - input_keypoints=[0, 0.25, 0.5, 1.0], - units=10, - x_generator=self._ScatterXUniformly, - y_function=self._SmallWaves, - num_training_records=100, - ) - self._AssertSonnetEquivalentToKeras(config) - - def testSonnetOutputMinOutputMax(self): - if self._disable_all: - return - config = self._createConfig( - input_keypoints=[0, 0.25, 0.5, 1.0], - units=10, - x_generator=self._ScatterXUniformly, - y_function=self._SmallWaves, - num_training_records=100, - output_min=1.0, - output_max=10.0, - ) - self._AssertSonnetEquivalentToKeras(config) - - def testSonnetClampMinClampMax(self): - if self._disable_all: - return - config = self._createConfig( - input_keypoints=[0, 0.25, 0.5, 1.0], - units=10, - x_generator=self._ScatterXUniformly, - y_function=self._SmallWaves, - 
num_training_records=100, - clamp_min=1.0, - output_max=10.0, - ) - self._AssertSonnetEquivalentToKeras(config) - - def testSonnetMonotonicity(self): - if self._disable_all: - return - for monotonicity in ["increasing", 1, "decreasing", -1]: - config = self._createConfig( - input_keypoints=[0, 0.25, 0.5, 1.0], - units=10, - x_generator=self._ScatterXUniformly, - y_function=self._SmallWaves, - num_training_records=100, - monotonicity=monotonicity, - ) - self._AssertSonnetEquivalentToKeras(config) - - def testSonnetConvexity(self): - if self._disable_all: - return - for convexity in ["convex", 1, "concave", -1]: - config = self._createConfig( - input_keypoints=[0, 0.25, 0.5, 1.0], - units=10, - x_generator=self._ScatterXUniformly, - y_function=self._SmallWaves, - num_training_records=100, - convexity=convexity, - ) - self._AssertSonnetEquivalentToKeras(config) - - def testSonnetIsCyclic(self): - if self._disable_all: - return - config = self._createConfig( - input_keypoints=[0, 0.25, 0.5, 1.0], - units=10, - x_generator=self._ScatterXUniformly, - y_function=self._SmallWaves, - num_training_records=100, - is_cyclic=True, - ) - self._AssertSonnetEquivalentToKeras(config) - - def testSonnetKernelInit(self): - if self._disable_all: - return - # kernel_init="equal_heights" is the default and is tested in - # testSonnetDefaultValues, so we don't test it here. - for kernel_init in [None, "equal_slopes"]: - config = self._createConfig( - input_keypoints=[0, 0.25, 0.5, 1.0], - units=10, - x_generator=self._ScatterXUniformly, - y_function=self._SmallWaves, - num_training_records=100, - kernel_init=kernel_init, - ) - self._AssertSonnetEquivalentToKeras(config) - - def testSonnetMissingInputValue(self): - if self._disable_all: - return - config = self._createConfig( - input_keypoints=[0, 0.25, 0.5, 1.0], - units=10, - x_generator=self._ScatterXUniformly, - y_function=self._SmallWaves, - num_training_records=100, - impute_missing=True, - missing_input_value=3, - missing_probability=0.5, - ) - self._AssertSonnetEquivalentToKeras(config) - - def testSonnetMissingOutputValue(self): - if self._disable_all: - return - config = self._createConfig( - input_keypoints=[0, 0.25, 0.5, 1.0], - units=10, - x_generator=self._ScatterXUniformly, - y_function=self._SmallWaves, - num_training_records=100, - impute_missing=True, - missing_input_value=3, - missing_probability=0.5, - missing_output_value=10, - ) - self._AssertSonnetEquivalentToKeras(config) - - def testSonnetNumProjectionIterations(self): - if self._disable_all: - return - config = self._createConfig( - input_keypoints=[0, 0.25, 0.5, 1.0], - units=10, - x_generator=self._ScatterXUniformly, - y_function=self._SmallWaves, - num_training_records=100, - num_projection_iterations=2, - ) - self._AssertSonnetEquivalentToKeras(config) - @parameterized.parameters( (1, False, 0.001022, "fixed"), (3, False, 0.000543, "fixed"), diff --git a/tensorflow_lattice/python/rtl_test.py b/tensorflow_lattice/python/rtl_test.py index 1ed981d..6a0aa51 100644 --- a/tensorflow_lattice/python/rtl_test.py +++ b/tensorflow_lattice/python/rtl_test.py @@ -204,6 +204,7 @@ def testRTLSaveLoad(self): model = tf.keras.Model( inputs=[input_c, input_d, input_e, input_f], outputs=outputs) model.compile(loss="mse") + model.use_legacy_config = True with tempfile.NamedTemporaryFile(suffix=".h5") as f: model.save(f.name) diff --git a/tensorflow_lattice/python/test_utils.py b/tensorflow_lattice/python/test_utils.py index 229e534..41182c1 100644 --- a/tensorflow_lattice/python/test_utils.py +++ 
b/tensorflow_lattice/python/test_utils.py @@ -19,11 +19,11 @@ from __future__ import print_function import time -from . import visualization + from absl import logging import numpy as np -import sonnet as snt -import tensorflow as tf + +from . import visualization class TimeTracker(object): @@ -127,98 +127,6 @@ def run_training_loop(config, return loss -def assert_sonnet_equivalent_to_keras( - test, sonnet_module_ctor, keras_layer_ctor, - training_inputs, training_labels, - epsilon=1e-4): - """Asserts that a Sonnet module is "equivalent" to a Keras layer. - - Creates a Sonnet module and a Keras layer using the given constructors. It - then uses both models to evaluate the given 'training_inputs' tensor and - asserts that the results are equal. It then trains both models and asserts - that the final loss (w.r.t the given 'training_labels') and the post-training - predictions of both models on 'training_inputs' are also equal. - - Args: - test: a tf.test.TestCase whose 'assert...' methods to use for assertion. - sonnet_module_ctor: A callable that takes no arguments and returns the - Sonnet module to use. - keras_layer_ctor: A callable that takes no arguments and returns the - Keras layer to use. - training_inputs: Tensor of shape (batch_size, ....) tensor containing the - training inputs. - training_labels: Tensor of shape (batch_size, ....). tensor containing the - training labels. - epsilon: float. Sensitivity of comparison. Comparison of model predictions - and losses are done using test.assertNear and test.assertNDArrayNear. - This is the value to pass as the 'err' parameter to these assertion - methods. - """ - # This function assumes we're executing eagerly. - test.assertTrue(tf.executing_eagerly()) - - num_training_epochs = 10 - num_training_inputs = training_inputs.shape[0] - - # Create Keras model. - keras_model = tf.keras.models.Sequential(layers=[keras_layer_ctor()]) - keras_model.compile( - loss=tf.keras.losses.mean_squared_error, - optimizer=tf.keras.optimizers.SGD(learning_rate=0.1)) - - keras_preds_pre_training = keras_model.predict( - x=training_inputs, batch_size=num_training_inputs) - - # Train the Keras model. - keras_model.fit(x=training_inputs, y=training_labels, - batch_size=num_training_inputs, - epochs=num_training_epochs, - verbose=0) - keras_loss = keras_model.evaluate(x=training_inputs, y=training_labels, - batch_size=num_training_inputs, - verbose=0) - keras_preds_post_training = keras_model.predict( - x=training_inputs, batch_size=num_training_inputs) - - # Create the Sonnet model - sonnet_module = sonnet_module_ctor() - - sonnet_preds_pre_training = sonnet_module(training_inputs).numpy() - - # Train the Sonnet model - sonnet_optimizer = snt.optimizers.SGD(learning_rate=0.1) - mse_loss = tf.keras.losses.MeanSquaredError() - for _ in range(num_training_epochs): - with tf.GradientTape() as tape: - preds = sonnet_module(training_inputs) - sonnet_loss = mse_loss(training_labels, preds) - - params = sonnet_module.trainable_variables - grads = tape.gradient(sonnet_loss, params) - sonnet_optimizer.apply(grads, params) - # We need to apply constraints explicitly in Sonnet - for var in params: - if var.constraint: - var.assign(var.constraint(var)) - - sonnet_preds_post_training = sonnet_module(training_inputs).numpy() - sonnet_loss = mse_loss(training_labels, sonnet_preds_post_training).numpy() - # Note that corresponding initializers between Sonnet and Keras may have - # different default values for their construction arguments. E.g. 
- # https://sonnet.readthedocs.io/en/latest/api.html#randomuniform - # https://www.tensorflow.org/api_docs/python/tf/keras/initializers/RandomUniform - # which may cause these assertions to fail. When comparing models make - # sure th initializers behave the same. - test.assertNDArrayNear(sonnet_preds_pre_training, - keras_preds_pre_training, - err=epsilon) - test.assertNear(sonnet_loss, keras_loss, err=epsilon) - test.assertNDArrayNear( - sonnet_preds_post_training, - keras_preds_post_training, - err=epsilon) - - def two_dim_mesh_grid(num_points, x_min, y_min, x_max, y_max): """Generates uniform 2-d mesh grid for 3-d surfaces visualisation via pyplot. diff --git a/tensorflow_lattice/sonnet_modules/__init__.py b/tensorflow_lattice/sonnet_modules/__init__.py deleted file mode 100644 index 0db2b21..0000000 --- a/tensorflow_lattice/sonnet_modules/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -# Copyright 2020 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""'sonnet_modules' namespace for TFL layers.""" - -from tensorflow_lattice.python.pwl_calibration_sonnet_module import PWLCalibration