From 04a318d89806e0aaa645c2e261dfd93d0ed200bc Mon Sep 17 00:00:00 2001 From: TensorFlow Lattice Authors Date: Tue, 16 Feb 2021 14:28:15 -0800 Subject: [PATCH] Internal change PiperOrigin-RevId: 357805569 Change-Id: I3a68210ed3aa5dedd248a8a38ffaf67e9a8dfdf1 --- docs/tutorials/shape_constraints.ipynb | 16 +- examples/custom_estimators_uci_heart.py | 19 +- setup.py | 2 +- tensorflow_lattice/python/BUILD | 7 + tensorflow_lattice/python/configs.py | 66 ++- tensorflow_lattice/python/estimators.py | 324 +++++++++++--- tensorflow_lattice/python/estimators_test.py | 34 +- .../kronecker_factored_lattice_layer.py | 353 +++++++++++++-- .../python/kronecker_factored_lattice_lib.py | 335 ++++++++++++-- .../python/kronecker_factored_lattice_test.py | 273 +++++++++-- tensorflow_lattice/python/lattice_layer.py | 63 +-- tensorflow_lattice/python/lattice_lib.py | 25 ++ tensorflow_lattice/python/model_info.py | 55 ++- tensorflow_lattice/python/premade.py | 13 + tensorflow_lattice/python/premade_lib.py | 422 +++++++++++++----- tensorflow_lattice/python/premade_test.py | 141 +++++- .../python/pwl_calibration_layer.py | 99 +++- .../python/pwl_calibration_lib.py | 31 +- .../python/pwl_calibration_test.py | 54 ++- tensorflow_lattice/python/rtl_layer.py | 138 +++++- tensorflow_lattice/python/rtl_lib.py | 39 +- tensorflow_lattice/python/visualization.py | 2 + 22 files changed, 2072 insertions(+), 439 deletions(-) diff --git a/docs/tutorials/shape_constraints.ipynb b/docs/tutorials/shape_constraints.ipynb index 63d2149..e224678 100644 --- a/docs/tutorials/shape_constraints.ipynb +++ b/docs/tutorials/shape_constraints.ipynb @@ -154,9 +154,9 @@ }, "outputs": [], "source": [ - "NUM_EPOCHS = 500\n", + "NUM_EPOCHS = 1000\n", "BATCH_SIZE = 64\n", - "LEARNING_RATE=0.001" + "LEARNING_RATE=0.01" ] }, { @@ -376,9 +376,9 @@ "\n", "# Generate datasets.\n", "np.random.seed(42)\n", - "data_train = sample_dataset(2000, testing_set=False)\n", - "data_val = sample_dataset(1000, testing_set=False)\n", - "data_test = sample_dataset(1000, testing_set=True)\n", + "data_train = sample_dataset(500, testing_set=False)\n", + "data_val = sample_dataset(500, testing_set=False)\n", + "data_test = sample_dataset(500, testing_set=True)\n", "\n", "# Plotting dataset densities.\n", "figsize(12, 5)\n", @@ -536,9 +536,9 @@ " feature_columns=feature_columns,\n", " # Hyper-params optimized on validation set.\n", " n_batches_per_layer=1,\n", - " max_depth=3,\n", - " n_trees=20,\n", - " min_node_weight=0.1,\n", + " max_depth=2,\n", + " n_trees=50,\n", + " learning_rate=0.05,\n", " config=tf.estimator.RunConfig(tf_random_seed=42),\n", ")\n", "gbt_estimator.train(input_fn=train_input_fn)\n", diff --git a/examples/custom_estimators_uci_heart.py b/examples/custom_estimators_uci_heart.py index 6fbab8c..9e45d31 100644 --- a/examples/custom_estimators_uci_heart.py +++ b/examples/custom_estimators_uci_heart.py @@ -18,7 +18,7 @@ This example trains a TFL custom estimators on the UCI heart dataset. 
Example usage: -custom_estimators_uci_heart --num_epochs=40 +custom_estimators_uci_heart --num_epochs=5000 """ from __future__ import absolute_import @@ -38,7 +38,7 @@ FLAGS = flags.FLAGS flags.DEFINE_float('learning_rate', 0.01, 'Learning rate.') flags.DEFINE_integer('batch_size', 100, 'Batch size.') -flags.DEFINE_integer('num_epochs', 200, 'Number of training epoch.') +flags.DEFINE_integer('num_epochs', 2000, 'Number of training epoch.') def main(_): @@ -121,7 +121,7 @@ def model_fn(features, labels, mode, config): tfl.layers.PWLCalibration( input_keypoints=[0.0, 1.0, 2.0, 3.0], output_min=0.0, - output_max=lattice_sizes[0] - 1.0, + output_max=lattice_sizes[2] - 1.0, # You can specify TFL regularizers as tuple # ('regularizer name', l1, l2). kernel_regularizer=('hessian', 0.0, 1e-4), @@ -130,7 +130,7 @@ def model_fn(features, labels, mode, config): tfl.layers.CategoricalCalibration( num_buckets=3, output_min=0.0, - output_max=lattice_sizes[1] - 1.0, + output_max=lattice_sizes[3] - 1.0, # Categorical monotonicity can be partial order. # (i, j) indicates that we must have output(i) <= output(i). # Make sure to set the lattice monotonicity to 1 for this dimension. @@ -138,8 +138,15 @@ def model_fn(features, labels, mode, config): )(inputs['thal']), ]) output = tfl.layers.Lattice( - lattice_sizes=lattice_sizes, monotonicities=lattice_monotonicities)( - lattice_input) + lattice_sizes=lattice_sizes, + monotonicities=lattice_monotonicities, + # Add a kernel_initializer so that the Lattice is not initialized as a + # flat plane. The output_min and output_max could be arbitrary, as long + # as output_min < output_max. + kernel_initializer=tfl.lattice_layer.RandomMonotonicInitializer( + lattice_sizes=lattice_sizes, output_min=-10, output_max=10), + )( + lattice_input) training = (mode == tf.estimator.ModeKeys.TRAIN) model = tf.keras.Model(inputs=inputs, outputs=output) diff --git a/setup.py b/setup.py index 49bdc24..2cef6fd 100644 --- a/setup.py +++ b/setup.py @@ -27,7 +27,7 @@ # This version number should always be that of the *next* (unreleased) version. # Immediately after uploading a package to PyPI, you should increment the # version number and push to gitHub. 
-__version__ = "2.0.7" +__version__ = "2.0.8" if "--release" in sys.argv: sys.argv.remove("--release") diff --git a/tensorflow_lattice/python/BUILD b/tensorflow_lattice/python/BUILD index aefa309..3c41bf3 100644 --- a/tensorflow_lattice/python/BUILD +++ b/tensorflow_lattice/python/BUILD @@ -117,6 +117,7 @@ py_library( deps = [ ":categorical_calibration_layer", ":configs", + ":kronecker_factored_lattice_layer", ":lattice_layer", ":linear_layer", ":model_info", @@ -197,6 +198,7 @@ py_test( srcs_version = "PY2AND3", deps = [ ":kronecker_factored_lattice_layer", + ":kronecker_factored_lattice_lib", ":test_utils", # absl/logging dep, # absl/testing:parameterized dep, @@ -329,6 +331,7 @@ py_library( ":aggregation_layer", ":categorical_calibration_layer", ":configs", + ":kronecker_factored_lattice_layer", ":lattice_layer", ":parallel_combination_layer", ":premade_lib", @@ -346,6 +349,8 @@ py_library( ":aggregation_layer", ":categorical_calibration_layer", ":configs", + ":kronecker_factored_lattice_layer", + ":kronecker_factored_lattice_lib", ":lattice_layer", ":lattice_lib", ":linear_layer", @@ -372,6 +377,7 @@ py_test( ":premade", ":premade_lib", # absl/logging dep, + # absl/testing:parameterized dep, # numpy dep, # tensorflow dep, ], @@ -439,6 +445,7 @@ py_library( srcs = ["rtl_layer.py"], srcs_version = "PY2AND3", deps = [ + ":kronecker_factored_lattice_layer", ":lattice_layer", ":rtl_lib", # tensorflow:tensorflow_no_contrib dep, diff --git a/tensorflow_lattice/python/configs.py b/tensorflow_lattice/python/configs.py index a561079..12c2d72 100644 --- a/tensorflow_lattice/python/configs.py +++ b/tensorflow_lattice/python/configs.py @@ -267,6 +267,8 @@ def __init__(self, num_lattices=None, lattice_rank=None, interpolation='hypercube', + parameterization='all_vertices', + num_terms=2, separate_calibrators=True, use_linear_combination=False, use_bias=False, @@ -305,6 +307,34 @@ def __init__(self, 'simplex' uses d+1 parameters and thus scales better. For details see `tfl.lattice_lib.evaluate_with_simplex_interpolation` and `tfl.lattice_lib.evaluate_with_hypercube_interpolation`. + parameterization: The parameterization of the lattice function class to + use. A lattice function is uniquely determined by specifying its value + on every lattice vertex. A parameterization scheme is a mapping from a + vector of parameters to a multidimensional array of lattice vertex + values. It can be one of: + - String `'all_vertices'`: This is the "traditional" parameterization + that keeps one scalar parameter per lattice vertex where the mapping + is essentially the identity map. With this scheme, the number of + parameters scales exponentially with the number of inputs to the + lattice. The underlying lattices used will be `tfl.layers.Lattice` + layers. + - String `'kronecker_factored'`: With this parameterization, for each + lattice input i we keep a collection of `num_terms` vectors each + having `feature_configs[0].lattice_size` entries (note that all + features must have the same lattice size). To obtain the tensor of + lattice vertex values, for `t=1,2,...,num_terms` we compute the + outer product of the `t'th` vector in each collection, multiply by a + per-term scale, and sum the resulting tensors. Finally, we add a + single shared bias parameter to each entry in the sum. With this + scheme, the number of parameters grows linearly with `lattice_rank` + (assuming lattice sizes and `num_terms` are held constant). 
+ Currently, only monotonicity shape constraint and bound constraint + are supported for this scheme. Regularization is not currently + supported. The underlying lattices used will be + `tfl.layers.KroneckerFactoredLattice` layers. + num_terms: The number of terms in a lattice using `'kronecker_factored'` + parameterization. Ignored if parameterization is set to + `'all_vertices'`. separate_calibrators: If features should be separately calibrated for each lattice in the ensemble. use_linear_combination: If set to true, a linear combination layer will be @@ -375,12 +405,15 @@ class CalibratedLatticeConfig(_Config, _HasFeatureConfigs, def __init__(self, feature_configs=None, interpolation='hypercube', + parameterization='all_vertices', + num_terms=2, regularizer_configs=None, output_min=None, output_max=None, output_calibration=False, output_calibration_num_keypoints=10, - output_initialization='quantiles'): + output_initialization='quantiles', + random_seed=0): """Initializes a `CalibratedLatticeConfig` instance. Args: @@ -392,6 +425,34 @@ def __init__(self, 'simplex' uses d+1 parameters and thus scales better. For details see `tfl.lattice_lib.evaluate_with_simplex_interpolation` and `tfl.lattice_lib.evaluate_with_hypercube_interpolation`. + parameterization: The parameterization of the lattice function class to + use. A lattice function is uniquely determined by specifying its value + on every lattice vertex. A parameterization scheme is a mapping from a + vector of parameters to a multidimensional array of lattice vertex + values. It can be one of: + - String `'all_vertices'`: This is the "traditional" parameterization + that keeps one scalar parameter per lattice vertex where the mapping + is essentially the identity map. With this scheme, the number of + parameters scales exponentially with the number of inputs to the + lattice. The underlying lattice used will be a `tfl.layers.Lattice` + layer. + - String `'kronecker_factored'`: With this parameterization, for each + lattice input i we keep a collection of `num_terms` vectors each + having `feature_configs[0].lattice_size` entries (note that all + features must have the same lattice size). To obtain the tensor of + lattice vertex values, for `t=1,2,...,num_terms` we compute the + outer product of the `t'th` vector in each collection, multiply by a + per-term scale, and sum the resulting tensors. Finally, we add a + single shared bias parameter to each entry in the sum. With this + scheme, the number of parameters grows linearly with + `len(feature_configs)` (assuming lattice sizes and `num_terms` are + held constant). Currently, only monotonicity shape constraint and + bound constraint are supported for this scheme. Regularization is + not currently supported. The underlying lattice used will be a + `tfl.layers.KroneckerFactoredLattice` layer. + num_terms: The number of terms in a lattice using `'kronecker_factored'` + parameterization. Ignored if parameterization is set to + `'all_vertices'`. regularizer_configs: A list of `tfl.configs.RegularizerConfig` instances that apply global regularization. output_min: Lower bound constraint on the output of the model. @@ -410,6 +471,9 @@ def __init__(self, - String `'uniform'`: Output is initliazed uniformly in label range. - A list of numbers: To be used for initialization of the output lattice or output calibrator. + random_seed: Random seed to use for initialization of a lattice with + `'kronecker_factored'` parameterization. Ignored if parameterization is + set to `'all_vertices'`. 
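
# Illustration: a minimal NumPy sketch of the 'kronecker_factored'
# parameterization described above. Variable names and shapes are purely
# illustrative and do not match the layer's internal memory layout.
import numpy as np

lattice_size, dims, num_terms = 2, 3, 2
rng = np.random.default_rng(0)
# One vector of `lattice_size` entries per (input dimension, term).
factors = rng.uniform(size=(dims, num_terms, lattice_size))
scale = rng.uniform(size=num_terms)  # per-term scale
bias = 0.25                          # single shared bias

vertices = np.zeros((lattice_size,) * dims)
for t in range(num_terms):
  outer = factors[0, t]
  for d in range(1, dims):
    # Outer product of the t'th vector of every input dimension.
    outer = np.multiply.outer(outer, factors[d, t])
  vertices += scale[t] * outer
vertices += bias  # the shared bias is added to every entry of the sum
# `vertices` now holds lattice_size**dims vertex values reconstructed from only
# dims * num_terms * lattice_size + num_terms + 1 parameters, which is why the
# parameter count grows linearly in the number of lattice inputs.
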
""" super(CalibratedLatticeConfig, self).__init__(locals()) diff --git a/tensorflow_lattice/python/estimators.py b/tensorflow_lattice/python/estimators.py index 88f266a..b052898 100644 --- a/tensorflow_lattice/python/estimators.py +++ b/tensorflow_lattice/python/estimators.py @@ -73,6 +73,7 @@ from . import categorical_calibration_layer from . import configs +from . import kronecker_factored_lattice_layer as kfll from . import lattice_layer from . import linear_layer from . import model_info @@ -1433,8 +1434,251 @@ def _create_lattice_nodes(sess, ops, graph, submodel_input_nodes): return lattice_nodes -def _create_rtl_lattice_nodes(sess, ops, graph, calibration_nodes_map): - """Returns a map from lattice_submodel_index to LatticeNode.""" +def _create_kronecker_factored_lattice_nodes(sess, ops, graph, + submodel_input_nodes): + """Returns a map from submodel_idx to KroneckerFactoredLatticeNode.""" + kfl_nodes = {} + # KroneckerFactoredLattice kernel weights. + # {KFL_LAYER_NAME}_{submodel_idx}/{KFL_KERNEL_NAME} + kfl_kernel_op_re = '^{}_(.*)/{}/Read/ReadVariableOp$'.format( + premade_lib.KFL_LAYER_NAME, + kfll.KFL_KERNEL_NAME, + ) + for kfl_kernel_op, submodel_idx in _match_op(ops, kfl_kernel_op_re): + kfl_kernel = sess.run( + graph.get_operation_by_name(kfl_kernel_op).outputs[0]).flatten() + + # KroneckerFactoredLattice scale. + # {KFL_LAYER_NAME}_{submodel_idx}/{KFL_SCALE_NAME} + kfl_scale_op_name = '{}_{}/{}/Read/ReadVariableOp'.format( + premade_lib.KFL_LAYER_NAME, submodel_idx, kfll.KFL_SCALE_NAME) + kfl_scale = sess.run( + graph.get_operation_by_name(kfl_scale_op_name).outputs[0]).flatten() + + # KroneckerFactoredLattice bias. + # {KFL_LAYER_NAME}_{submodel_idx}/{KFL_BIAS_NAME} + kfl_bias_op_name = '{}_{}/{}/Read/ReadVariableOp'.format( + premade_lib.KFL_LAYER_NAME, submodel_idx, kfll.KFL_BIAS_NAME) + kfl_bias = sess.run( + graph.get_operation_by_name(kfl_bias_op_name).outputs[0]).flatten() + + # Lattice sizes. + # {KFL_LAYER_NAME}_{submodel_idx}/{LATTICE_SIZES_NAME} + lattice_sizes_op_name = '{}_{}/{}'.format(premade_lib.KFL_LAYER_NAME, + submodel_idx, + kfll.LATTICE_SIZES_NAME) + lattice_sizes = sess.run( + graph.get_operation_by_name( + lattice_sizes_op_name).outputs[0]) + + # Units. + # {KFL_LAYER_NAME}_{submodel_idx}/{UNITS_NAME} + units_op_name = '{}_{}/{}'.format(premade_lib.KFL_LAYER_NAME, + submodel_idx, kfll.UNITS_NAME) + units = sess.run( + graph.get_operation_by_name(units_op_name).outputs[0]) + + # Dims. + # {KFL_LAYER_NAME}_{submodel_idx}/{DIMS_NAME} + dims_op_name = '{}_{}/{}'.format(premade_lib.KFL_LAYER_NAME, + submodel_idx, kfll.DIMS_NAME) + dims = sess.run( + graph.get_operation_by_name(dims_op_name).outputs[0]) + + # Num terms. + # {KFL_LAYER_NAME}_{submodel_idx}/{NUM_TERMS_NAME} + num_terms_op_name = '{}_{}/{}'.format(premade_lib.KFL_LAYER_NAME, + submodel_idx, kfll.NUM_TERMS_NAME) + num_terms = sess.run( + graph.get_operation_by_name(num_terms_op_name).outputs[0]) + + # Shape the flat weights, scale, and bias parameters based on the calculated + # lattice_sizes, units, dims, and num_terms. + weights = np.reshape(kfl_kernel, + (1, lattice_sizes, units * dims, num_terms)) + scale = np.reshape(kfl_scale, (units, num_terms)) + bias = np.reshape(kfl_bias, (units)) + + # Sort input nodes by input index. 
+ input_nodes = [ + node for _, node in sorted(submodel_input_nodes[submodel_idx]) + ] + + kfl_node = model_info.KroneckerFactoredLatticeNode( + input_nodes=input_nodes, weights=weights, scale=scale, bias=bias) + kfl_nodes[submodel_idx] = kfl_node + return kfl_nodes + + +def _create_rtl_submodel_kronecker_factored_lattice_nodes( + sess, ops, graph, flattened_calibration_nodes, submodel_idx, submodel_key): + """Returns next key and map from key+unit to KroneckerFactoredLatticeNode.""" + submodel_kfl_nodes = {} + # KroneckerFactoredLattice kernel weights + # {RTL_LAYER_NAME}_{submodel_idx}/ + # {RTL_KFL_NAME}_{monotonicities}/{KFL_KERNEL_NAME} + kfl_kernel_op_re = '^{}_{}/{}_(.*)/{}/Read/ReadVariableOp$'.format( + premade_lib.RTL_LAYER_NAME, + submodel_idx, + rtl_layer.RTL_KFL_NAME, + kfll.KFL_KERNEL_NAME, + ) + for kfl_kernel_op, monotonicities in _match_op(ops, kfl_kernel_op_re): + kfl_kernel = sess.run( + graph.get_operation_by_name(kfl_kernel_op).outputs[0]).flatten() + + # KroneckerFactoredLattice scale. + # {RTL_LAYER_NAME}_{submodel_idx}/ + # {RTL_KFL_NAME}_{monotonicities}/{KFL_SCALE_NAME} + kfl_scale_op_name = '{}_{}/{}_{}/{}/Read/ReadVariableOp'.format( + premade_lib.RTL_LAYER_NAME, + submodel_idx, + rtl_layer.RTL_KFL_NAME, + monotonicities, + kfll.KFL_SCALE_NAME, + ) + kfl_scale = sess.run( + graph.get_operation_by_name(kfl_scale_op_name).outputs[0]).flatten() + + # KroneckerFactoredLattice bias. + # {RTL_LAYER_NAME}_{submodel_idx}/ + # {RTL_KFL_NAME}_{monotonicities}/{KFL_BIAS_NAME} + kfl_bias_op_name = '{}_{}/{}_{}/{}/Read/ReadVariableOp'.format( + premade_lib.RTL_LAYER_NAME, + submodel_idx, + rtl_layer.RTL_KFL_NAME, + monotonicities, + kfll.KFL_BIAS_NAME, + ) + kfl_bias = sess.run( + graph.get_operation_by_name(kfl_bias_op_name).outputs[0]).flatten() + + # Lattice sizes. + # {RTL_LAYER_NAME}_{submodel_idx}/ + # {RTL_KFL_NAME}_{monotonicities}/{LATTICE_SIZES_NAME} + lattice_sizes_op_name = '{}_{}/{}_{}/{}'.format( + premade_lib.RTL_LAYER_NAME, + submodel_idx, + rtl_layer.RTL_KFL_NAME, + monotonicities, + kfll.LATTICE_SIZES_NAME, + ) + lattice_sizes = sess.run( + graph.get_operation_by_name(lattice_sizes_op_name).outputs[0]) + + # Dims. + # {RTL_LAYER_NAME}_{submodel_idx}/ + # {RTL_KFL_NAME}_{monotonicities}/{DIMS_NAME} + dims_op_name = '{}_{}/{}_{}/{}'.format( + premade_lib.RTL_LAYER_NAME, + submodel_idx, + rtl_layer.RTL_KFL_NAME, + monotonicities, + kfll.DIMS_NAME, + ) + dims = sess.run(graph.get_operation_by_name(dims_op_name).outputs[0]) + + # Num terms. + # {RTL_LAYER_NAME}_{submodel_idx}/ + # {RTL_KFL_NAME}_{monotonicities}/{NUM_TERMS_NAME} + num_terms_op_name = '{}_{}/{}_{}/{}'.format( + premade_lib.RTL_LAYER_NAME, + submodel_idx, + rtl_layer.RTL_KFL_NAME, + monotonicities, + kfll.NUM_TERMS_NAME, + ) + num_terms = sess.run( + graph.get_operation_by_name(num_terms_op_name).outputs[0]) + + # inputs_for_units + # {RTL_LAYER_NAME}_{submodel_index}/ + # {INPUTS_FOR_UNITS_PREFIX}_{monotonicities} + inputs_for_units_op_name = '{}_{}/{}_{}'.format( + premade_lib.RTL_LAYER_NAME, submodel_idx, + rtl_layer.INPUTS_FOR_UNITS_PREFIX, monotonicities) + inputs_for_units = sess.run( + graph.get_operation_by_name(inputs_for_units_op_name).outputs[0]) + + # Make a unique kfl for each unit. + units = inputs_for_units.shape[0] + for i in range(units): + # Shape the flat weights, scale, and bias parameters based on the + # calculated lattice_sizes, units, dims, and num_terms. 
+ weights = np.reshape(kfl_kernel, + (1, lattice_sizes, units * dims, num_terms)) + scale = np.reshape(kfl_scale, (units, num_terms)) + bias = np.reshape(kfl_bias, (units)) + + # Gather input nodes for lattice node. + indices = inputs_for_units[i] + input_nodes = [flattened_calibration_nodes[index] for index in indices] + + kfl_node = model_info.KroneckerFactoredLatticeNode( + input_nodes=input_nodes, weights=weights, scale=scale, bias=bias) + submodel_kfl_nodes[submodel_key] = kfl_node + submodel_key += 1 + return submodel_key, submodel_kfl_nodes + + +def _create_rtl_submodel_lattice_nodes(sess, ops, graph, + flattened_calibration_nodes, + submodel_idx, submodel_key): + """Returns next key and map from key+unit to LatticeNode.""" + submodel_lattice_nodes = {} + # Lattice kernel weights. + # {RTL_LAYER_NAME}_{submodel_idx}/ + # {RTL_LATTICE_NAME}_{monotonicities}/{LATTICE_KERNEL_NAME} + lattice_kernel_op_re = '^{}_{}/{}_(.*)/{}/Read/ReadVariableOp$'.format( + premade_lib.RTL_LAYER_NAME, + submodel_idx, + rtl_layer.RTL_LATTICE_NAME, + lattice_layer.LATTICE_KERNEL_NAME, + ) + for lattice_kernel_op, monotonicities in _match_op(ops, lattice_kernel_op_re): + lattice_kernel = sess.run( + graph.get_operation_by_name(lattice_kernel_op).outputs[0]) + + # Lattice sizes. + # {RTL_LAYER_NAME}_{submodel_idx}/ + # {RTL_LATTICE_NAME}_{monotonicities}/{LATTICE_SIZES_NAME} + lattice_sizes_op_name = '{}_{}/{}_{}/{}'.format( + premade_lib.RTL_LAYER_NAME, submodel_idx, rtl_layer.RTL_LATTICE_NAME, + monotonicities, lattice_layer.LATTICE_SIZES_NAME) + lattice_sizes = sess.run( + graph.get_operation_by_name( + lattice_sizes_op_name).outputs[0]).flatten() + + # inputs_for_units + # {RTL_LAYER_NAME}_{submodel_index}/ + # {INPUTS_FOR_UNITS_PREFIX}_{monotonicities} + inputs_for_units_op_name = '{}_{}/{}_{}'.format( + premade_lib.RTL_LAYER_NAME, submodel_idx, + rtl_layer.INPUTS_FOR_UNITS_PREFIX, monotonicities) + inputs_for_units = sess.run( + graph.get_operation_by_name(inputs_for_units_op_name).outputs[0]) + + # Make a unique lattice for each unit. + units = inputs_for_units.shape[0] + for i in range(units): + # Shape the flat lattice parameters based on the calculated lattice + # sizes. + weights = np.reshape(lattice_kernel[:, i], lattice_sizes) + + # Gather input nodes for lattice node. + indices = inputs_for_units[i] + input_nodes = [flattened_calibration_nodes[index] for index in indices] + + lattice_node = model_info.LatticeNode( + input_nodes=input_nodes, weights=weights) + submodel_lattice_nodes[submodel_key] = lattice_node + submodel_key += 1 + return submodel_key, submodel_lattice_nodes + + +def _create_rtl_lattice_nodes(sess, ops, graph, calibration_nodes_map, + kronecker_factored): + """Returns a map from lattice_submodel_index to lattice type Node.""" lattice_nodes = {} lattice_submodel_index = 0 # Feature name in concat op. @@ -1458,57 +1702,14 @@ def _create_rtl_lattice_nodes(sess, ops, graph, calibration_nodes_map): for feature_name in names_in_flattened_order: flattened_calibration_nodes.extend(calibration_nodes_map[feature_name]) - # Lattice kernel weights. - # {RTL_LAYER_NAME}_{submodel_idx}/ - # {RTL_LATTICE_NAME}_{monotonicities}/{LATTICE_KERNEL_NAME} - lattice_kernel_op_re = '^{}_{}/{}_(.*)/{}/Read/ReadVariableOp$'.format( - premade_lib.RTL_LAYER_NAME, - submodel_idx, - rtl_layer.RTL_LATTICE_NAME, - lattice_layer.LATTICE_KERNEL_NAME, - ) - for lattice_kernel_op, monotonicities in _match_op(ops, - lattice_kernel_op_re): - # Lattice kernel weights. 
- lattice_kernel = sess.run( - graph.get_operation_by_name(lattice_kernel_op).outputs[0]) - - # Lattice sizes. - # {RTL_LAYER_NAME}_{submodel_idx}/ - # {RTL_LATTICE_NAME}_{monotonicities}/{LATTICE_SIZES_NAME} - lattice_sizes_op_name = '{}_{}/{}_{}/{}'.format( - premade_lib.RTL_LAYER_NAME, submodel_idx, rtl_layer.RTL_LATTICE_NAME, - monotonicities, lattice_layer.LATTICE_SIZES_NAME) - - lattice_sizes = sess.run( - graph.get_operation_by_name( - lattice_sizes_op_name).outputs[0]).flatten() - - # inputs_for_units - # {RTL_LAYER_NAME}_{submodel_index}/ - # {INPUTS_FOR_UNITS_PREFIX}_{monotonicities} - inputs_for_units_op_name = '{}_{}/{}_{}'.format( - premade_lib.RTL_LAYER_NAME, submodel_idx, - rtl_layer.INPUTS_FOR_UNITS_PREFIX, monotonicities) - - inputs_for_units = sess.run( - graph.get_operation_by_name(inputs_for_units_op_name).outputs[0]) - - # Make a unique lattice for each unit. - units = inputs_for_units.shape[0] - for i in range(units): - # Shape the flat lattice parameters based on the calculated lattice - # sizes. - weights = np.reshape(lattice_kernel[:, i], lattice_sizes) - - # Gather input nodes for lattice node. - indices = inputs_for_units[i] - input_nodes = [flattened_calibration_nodes[index] for index in indices] - - lattice_node = model_info.LatticeNode( - input_nodes=input_nodes, weights=weights) - lattice_nodes[lattice_submodel_index] = lattice_node - lattice_submodel_index += 1 + if kronecker_factored: + node_fn = _create_rtl_submodel_kronecker_factored_lattice_nodes + else: + node_fn = _create_rtl_submodel_lattice_nodes + lattice_submodel_index, submodel_lattice_nodes = node_fn( + sess, ops, graph, flattened_calibration_nodes, submodel_idx, + lattice_submodel_index) + lattice_nodes.update(submodel_lattice_nodes) return lattice_nodes @@ -1595,6 +1796,7 @@ def _create_output_calibration_node(sess, ops, graph, input_node): return output_calibration_node +# TODO: add support for KFL in RTL Layer def get_model_graph(saved_model_path, tag='serve'): """Returns all layers and parameters used in a saved model as a graph. @@ -1666,12 +1868,24 @@ def get_model_graph(saved_model_path, tag='serve'): submodel_output_nodes.update(lattice_nodes) nodes.extend(lattice_nodes.values()) + # Ensemble Kronecker Factored Lattice nodes. + kfl_nodes = _create_kronecker_factored_lattice_nodes( + sess, ops, graph, submodel_input_nodes) + submodel_output_nodes.update(kfl_nodes) + nodes.extend(kfl_nodes.values()) + # RTL Lattice nodes. - rtl_lattice_nodes = _create_rtl_lattice_nodes(sess, ops, graph, - calibration_nodes_map) + rtl_lattice_nodes = _create_rtl_lattice_nodes( + sess, ops, graph, calibration_nodes_map, kronecker_factored=False) submodel_output_nodes.update(rtl_lattice_nodes) nodes.extend(rtl_lattice_nodes.values()) + # RTL Kronecker Factored Lattice nodes. + kfl_rtl_nodes = _create_rtl_lattice_nodes( + sess, ops, graph, calibration_nodes_map, kronecker_factored=True) + submodel_output_nodes.update(kfl_rtl_nodes) + nodes.extend(kfl_rtl_nodes.values()) + # Output combination node. 
model_output_node = _create_output_combination_node(sess, ops, graph, submodel_output_nodes) diff --git a/tensorflow_lattice/python/estimators_test.py b/tensorflow_lattice/python/estimators_test.py index 3e4bd6b..40ffa63 100644 --- a/tensorflow_lattice/python/estimators_test.py +++ b/tensorflow_lattice/python/estimators_test.py @@ -595,17 +595,18 @@ def testCalibratedLinearEstimator(self, feature_names, output_calibration, self.assertLess(results['average_loss'], average_loss) @parameterized.parameters( - ('random', 5, 6, False, True), - ('random', 4, 5, True, False), - ('rtl_layer', 5, 6, False, True), - ('rtl_layer', 4, 5, True, False), + ('random', 5, 6, 'all_vertices', False, True), + ('random', 4, 5, 'kronecker_factored', True, False), + ('rtl_layer', 5, 6, 'kronecker_factored', False, True), + ('rtl_layer', 4, 5, 'all_vertices', True, False), ) def testCalibratedLatticeEnsembleModelInfo(self, lattices, num_lattices, - lattice_rank, separate_calibrators, + lattice_rank, parameterization, + separate_calibrators, output_calibration): self._ResetAllBackends() feature_configs = copy.deepcopy(self.heart_feature_configs) - if lattices == 'rtl_layer': + if lattices == 'rtl_layer' or parameterization == 'kronecker_factored': # RTL Layer only supports monotonicity and bound constraints. for feature_config in feature_configs: feature_config.lattice_size = 2 @@ -618,6 +619,7 @@ def testCalibratedLatticeEnsembleModelInfo(self, lattices, num_lattices, lattices=lattices, num_lattices=num_lattices, lattice_rank=lattice_rank, + parameterization=parameterization, separate_calibrators=separate_calibrators, output_calibration=output_calibration, ) @@ -707,10 +709,12 @@ def testCalibratedLatticeEnsembleFix2dConstraintViolations( self.assertCountEqual(lattice, expected_lattice) @parameterized.parameters( - ('linear', True), - ('lattice', False), + ('linear', None, True), + ('lattice', 'all_vertices', False), + ('lattice', 'kronecker_factored', False), ) - def testCalibratedModelInfo(self, model_type, output_calibration): + def testCalibratedModelInfo(self, model_type, parameterization, + output_calibration): self._ResetAllBackends() if model_type == 'linear': model_config = configs.CalibratedLinearConfig( @@ -718,8 +722,18 @@ def testCalibratedModelInfo(self, model_type, output_calibration): output_calibration=output_calibration, ) else: + feature_configs = copy.deepcopy(self.heart_feature_configs) + if parameterization == 'kronecker_factored': + # RTL Layer only supports monotonicity and bound constraints. + for feature_config in feature_configs: + feature_config.lattice_size = 2 + feature_config.unimodality = 'none' + feature_config.reflects_trust_in = None + feature_config.dominates = None + feature_config.regularizer_configs = None model_config = configs.CalibratedLatticeConfig( - feature_configs=self.heart_feature_configs, + feature_configs=feature_configs, + parameterization=parameterization, output_calibration=output_calibration, ) estimator = estimators.CannedClassifier( diff --git a/tensorflow_lattice/python/kronecker_factored_lattice_layer.py b/tensorflow_lattice/python/kronecker_factored_lattice_layer.py index 7162e31..09d4efa 100644 --- a/tensorflow_lattice/python/kronecker_factored_lattice_layer.py +++ b/tensorflow_lattice/python/kronecker_factored_lattice_layer.py @@ -23,13 +23,21 @@ from __future__ import division from __future__ import print_function +import functools +import inspect + from . import kronecker_factored_lattice_lib as kfl_lib from . 
import utils +import tensorflow as tf from tensorflow import keras +DIMS_NAME = "dims" KFL_SCALE_NAME = "kronecker_factored_lattice_scale" KFL_BIAS_NAME = "kronecker_factored_lattice_bias" KFL_KERNEL_NAME = "kronecker_factored_lattice_kernel" +LATTICE_SIZES_NAME = "lattice_sizes" +NUM_TERMS_NAME = "num_terms" +UNITS_NAME = "units" # TODO: add support for different lattice_sizes for each input @@ -62,6 +70,8 @@ class KroneckerFactoredLattice(keras.layers.Layer): * **Monotonicity:** constrains the function to be increasing in the corresponding dimension. + There are upper and lower bound constraints on the output. + Input shape: - if `units == 1`: tensor of shape: `(batch_size, ..., dims)` or list of `dims` tensors of same shape: `(batch_size, ..., 1)` @@ -105,9 +115,11 @@ def __init__(self, units=1, num_terms=2, monotonicities=None, + output_min=None, + output_max=None, clip_inputs=True, - satisfy_constraints_at_every_step=True, - kernel_initializer="random_monotonic_initializer", + kernel_initializer="kfl_random_monotonic_initializer", + scale_initializer="scale_initializer", **kwargs): # pyformat: disable """Initializes an instance of `KroneckerFactoredLattice`. @@ -122,14 +134,22 @@ def __init__(self, be monotonic in the corresponding feature, using 'increasing' or 1 to indicate increasing monotonicity and 'none' or 0 to indicate no monotonicity constraints. + output_min: None or lower bound of the output. + output_max: None or upper bound of the output. clip_inputs: If inputs should be clipped to the input range of the Kronecker-Factored Lattice. - satisfy_constraints_at_every_step: Whether to strictly enforce constraints - after every gradient update by applying an imprecise projection. kernel_initializer: None or one of: - - `'random_monotonic_initializer'`: initializes parameters as uniform + - `'kfl_random_monotonic_initializer'`: initializes parameters as uniform random functions that are monotonic in monotonic dimensions. - Any Keras initializer object. + scale_initializer: None or one of: + - `'scale_initializer'`: Initializes scale depending on output_min and + output_max. If both output_min and output_max are set, scale is + initialized to half their difference, alternating signs for each term. + If only output_min is set, scale is initialized to 1 for each term. If + only output_max is set, scale is initialized to -1 for each term. + Otherwise scale is initialized to alternate between 1 and -1 for each + term. **kwargs: Other args passed to `tf.keras.layers.Layer` initializer. 
Raises: @@ -137,19 +157,31 @@ def __init__(self, """ # pyformat: enable kfl_lib.verify_hyperparameters( - lattice_sizes=lattice_sizes, units=units, num_terms=num_terms) + lattice_sizes=lattice_sizes, + units=units, + num_terms=num_terms, + output_min=output_min, + output_max=output_max) super(KroneckerFactoredLattice, self).__init__(**kwargs) self.lattice_sizes = lattice_sizes self.units = units self.num_terms = num_terms self.monotonicities = monotonicities + self.output_min = output_min + self.output_max = output_max self.clip_inputs = clip_inputs - self.satisfy_constraints_at_every_step = satisfy_constraints_at_every_step self.kernel_initializer = create_kernel_initializer( kernel_initializer_id=kernel_initializer, - monotonicities=self.monotonicities) + monotonicities=self.monotonicities, + output_min=self.output_min, + output_max=self.output_max) + + self.scale_initializer = create_scale_initializer( + scale_initializer_id=scale_initializer, + output_min=self.output_min, + output_max=self.output_max) def build(self, input_shape): """Standard Keras build() method.""" @@ -163,41 +195,70 @@ def build(self, input_shape): else: dims = input_shape.as_list()[-1] + if self.output_min is not None or self.output_max is not None: + scale_constraints = ScaleConstraints( + output_min=self.output_min, output_max=self.output_max) + else: + scale_constraints = None self.scale = self.add_weight( KFL_SCALE_NAME, shape=[self.units, self.num_terms], - initializer=ScaleInitializer(), + initializer=self.scale_initializer, + constraint=scale_constraints, dtype=self.dtype) self.bias = self.add_weight( KFL_BIAS_NAME, shape=[self.units], - initializer="zeros", + initializer=BiasInitializer(self.output_min, self.output_max), + trainable=(self.output_min is None and self.output_max is None), dtype=self.dtype) - if self.monotonicities: + if (self.monotonicities or self.output_min is not None or + self.output_max is not None): constraints = KroneckerFactoredLatticeConstraints( units=self.units, scale=self.scale, monotonicities=self.monotonicities, - satisfy_constraints_at_every_step=self - .satisfy_constraints_at_every_step) + output_min=self.output_min, + output_max=self.output_max) else: constraints = None # Note that the first dimension of shape is 1 to work with - # tf.nn.depthwise_conv2d. + # tf.nn.depthwise_conv2d. We also provide scale to the __call__ method + # of the initializer using partial functions if it accepts scale. + parameters = inspect.signature(self.kernel_initializer).parameters.keys() + if "scale" in parameters: + kernel_initializer = functools.partial( + self.kernel_initializer, scale=self.scale.initialized_value()) + else: + kernel_initializer = self.kernel_initializer self.kernel = self.add_weight( KFL_KERNEL_NAME, shape=[1, self.lattice_sizes, self.units * dims, self.num_terms], - initializer=self.kernel_initializer, + initializer=kernel_initializer, constraint=constraints, dtype=self.dtype) - self._final_constraints = KroneckerFactoredLatticeConstraints( + self._final_kernel_constraints = KroneckerFactoredLatticeConstraints( units=self.units, scale=self.scale, monotonicities=self.monotonicities, - satisfy_constraints_at_every_step=True) + output_min=self.output_min, + output_max=self.output_max) + + self._final_scale_constraints = ScaleConstraints( + output_min=self.output_min, output_max=self.output_max) + + # These tensors are meant for book keeping. Note that this slightly + # increases the size of the graph. 
+ self.lattice_sizes_tensor = tf.constant( + self.lattice_sizes, dtype=tf.int32, name=LATTICE_SIZES_NAME) + self.units_tensor = tf.constant( + self.units, dtype=tf.int32, name=UNITS_NAME) + self.dims_tensor = tf.constant(dims, dtype=tf.int32, name=DIMS_NAME) + self.num_terms_tensor = tf.constant( + self.num_terms, dtype=tf.int32, name=NUM_TERMS_NAME) super(KroneckerFactoredLattice, self).build(input_shape) @@ -230,28 +291,33 @@ def get_config(self): "units": self.units, "num_terms": self.num_terms, "monotonicities": self.monotonicities, + "output_min": self.output_min, + "output_max": self.output_max, "clip_inputs": self.clip_inputs, - "satisfy_constraints_at_every_step": - self.satisfy_constraints_at_every_step, "kernel_initializer": keras.initializers.serialize(self.kernel_initializer), + "scale_initializer": + keras.initializers.serialize(self.scale_initializer), } # pyformat: disable config.update(super(KroneckerFactoredLattice, self).get_config()) return config + # TODO: can we remove this now that we always project at every step? def finalize_constraints(self): """Ensures layers weights strictly satisfy constraints. Applies approximate projection to strictly satisfy specified constraints. - If `monotonic_at_every_step == True` there is no need to call this function. Returns: - In eager mode directly updates weights and returns variable which stores - them. In graph mode returns `assign_add` op which has to be executed to - updates weights. + In eager mode directly updates kernel and scale and returns the variables + which store them. In graph mode returns a `group` op containing the + `assign_add` ops which have to be executed to update the kernel and scale. """ - return self.kernel.assign_add( - self._final_constraints(self.kernel) - self.kernel) + finalize_kernel = self.kernel.assign_add( + self._final_kernel_constraints(self.kernel) - self.kernel) + finalize_scale = self.scale.assign_add( + self._final_scale_constraints(self.scale) - self.scale) + return tf.group([finalize_kernel, finalize_scale]) def assert_constraints(self, eps=1e-6): """Asserts that weights satisfy all constraints. @@ -271,10 +337,17 @@ def assert_constraints(self, eps=1e-6): scale=self.scale, monotonicities=utils.canonicalize_monotonicities( self.monotonicities, allow_decreasing=False), + output_min=self.output_min, + output_max=self.output_max, eps=eps) -def create_kernel_initializer(kernel_initializer_id, monotonicities): +def create_kernel_initializer(kernel_initializer_id, + monotonicities, + output_min, + output_max, + init_min=None, + init_max=None): """Returns a kernel Keras initializer object from its id. This function is used to convert the 'kernel_initializer' parameter in the @@ -286,32 +359,83 @@ def create_kernel_initializer(kernel_initializer_id, monotonicities): parameter in the constructor of `tfl.layers.KroneckerFactoredLattice`. monotonicities: See the documentation of the same parameter in the constructor of `tfl.layers.KroneckerFactoredLattice`. + output_min: See the documentation of the same parameter in the constructor + of `tfl.layers.KroneckerFactoredLattice`. + output_max: See the documentation of the same parameter in the constructor + of `tfl.layers.KroneckerFactoredLattice`. + init_min: None or lower bound of kernel initialization. If set, init_max + must also be set. Ignored if kernel_initializer_id is a Keras object. + init_max: None or upper bound of kernel initialization. If set, init_min + must also be set. Ignored if kernel_initializer_id is a Keras object. 
Returns: The Keras initializer object for the `tfl.layers.KroneckerFactoredLattice` kernel variable. + + Raises: + ValueError: If only one of init_{min/max} is set. """ + if init_min is None and init_max is None: + init_min, init_max = kfl_lib.default_init_params(output_min, output_max) + elif init_min is not None and init_max is not None: + # We have nothing to set here. + pass + else: + raise ValueError("Both or neither of init_{min/max} must be set") + # Construct initializer. if kernel_initializer_id in [ - "random_monotonic_initializer", "RandomMonotonicInitializer" + "kfl_random_monotonic_initializer", "KFLRandomMonotonicInitializer" ]: - return RandomMonotonicInitializer(monotonicities) + return KFLRandomMonotonicInitializer( + monotonicities=monotonicities, init_min=init_min, init_max=init_max) else: # This is needed for Keras deserialization logic to be aware of our custom # objects. with keras.utils.custom_object_scope({ - "RandomMonotonicInitializer": RandomMonotonicInitializer, + "KFLRandomMonotonicInitializer": KFLRandomMonotonicInitializer, }): return keras.initializers.get(kernel_initializer_id) -class RandomMonotonicInitializer(keras.initializers.Initializer): +def create_scale_initializer(scale_initializer_id, output_min, output_max): + """Returns a scale Keras initializer object from its id. + + This function is used to convert the 'scale_initializer' parameter in the + constructor of tfl.layers.KroneckerFactoredLattice into the corresponding + initializer object. + + Args: + scale_initializer_id: See the documentation of the 'scale_initializer' + parameter in the constructor of `tfl.layers.KroneckerFactoredLattice`. + output_min: See the documentation of the same parameter in the constructor + of `tfl.layers.KroneckerFactoredLattice`. + output_max: See the documentation of the same parameter in the constructor + of `tfl.layers.KroneckerFactoredLattice`. + + Returns: + The Keras initializer object for the `tfl.layers.KroneckerFactoredLattice` + scale variable. + """ + # Construct initializer. + if scale_initializer_id in ["scale_initializer", "ScaleInitializer"]: + return ScaleInitializer(output_min=output_min, output_max=output_max) + else: + # This is needed for Keras deserialization logic to be aware of our custom + # objects. + with keras.utils.custom_object_scope({ + "ScaleInitializer": ScaleInitializer, + }): + return keras.initializers.get(scale_initializer_id) + + +class KFLRandomMonotonicInitializer(keras.initializers.Initializer): # pyformat: disable """Initializes a `tfl.layers.KroneckerFactoredLattice` as random monotonic.""" # pyformat: enable def __init__(self, monotonicities, init_min=0.5, init_max=1.5, seed=None): - """Initializes an instance of `RandomMonotonicInitializer`. + """Initializes an instance of `KFLRandomMonotonicInitializer`. Args: monotonicities: Monotonic dimensions for initialization. Does not need to @@ -325,17 +449,19 @@ def __init__(self, monotonicities, init_min=0.5, init_max=1.5, seed=None): self.init_max = init_max self.seed = seed - def __call__(self, shape, dtype=None, partition_info=None): + def __call__(self, shape, scale, dtype=None, **kwargs): """Returns weights of `tfl.layers.KroneckerFactoredLattice` layer. Args: shape: Must be: `(1, lattice_sizes, units * dims, num_terms)`. + scale: Scale variable of shape: `(units, num_terms)`. dtype: Standard Keras initializer param. - partition_info: Standard Keras initializer param. Not used. + **kwargs: Other args passed to `tf.keras.initializers.Initializer` + __call__ method. 
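
# Illustration: a minimal NumPy sketch of the monotonic random initialization
# performed by `KFLRandomMonotonicInitializer`. Keypoints are sampled uniformly
# in [init_min, init_max] and the factors of every monotonic dimension are then
# sorted along the lattice axis so they are non-decreasing. The sign-of-scale
# handling done by the real initializer is omitted; the values are
# illustrative only.
import numpy as np

rng = np.random.default_rng(0)
lattice_size, dims = 3, 2
monotonicities = [1, 0]  # dimension 0 monotonic, dimension 1 unconstrained
init_min, init_max = 0.5, 1.5
factors = rng.uniform(init_min, init_max, size=(dims, lattice_size))
factors = np.stack(
    [np.sort(f) if m else f for m, f in zip(monotonicities, factors)])
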
""" - del partition_info - return kfl_lib.random_monotonic_initializer( + return kfl_lib.kfl_random_monotonic_initializer( shape=shape, + scale=scale, monotonicities=utils.canonicalize_monotonicities( self.monotonicities, allow_decreasing=False), init_min=self.init_min, @@ -356,20 +482,94 @@ def get_config(self): class ScaleInitializer(keras.initializers.Initializer): # pyformat: disable - """Initializes scale to alternate between 1 and -1 for each term.""" + """Initializes scale depending on output_min and output_max. + + If both output_min and output_max are set, scale is initialized to half their + difference, alternating signs for each term. If only output_min is set, scale + is initialized to 1 for each term. If only output_max is set, scale is + initialized to -1 for each term. Otherwise scale is initialized to alternate + between 1 and -1 for each term. + """ # pyformat: enable - def __call__(self, shape, dtype=None, partition_info=None): + def __init__(self, output_min, output_max): + """Initializes an instance of `ScaleInitializer`. + + Args: + output_min: None or minimum layer output. + output_max: None or maximum layer output. + """ + self.output_min = output_min + self.output_max = output_max + + def __call__(self, shape, dtype=None, **kwargs): """Returns weights of `tfl.layers.KroneckerFactoredLattice` scale. Args: shape: Must be: `(units, num_terms)`. dtype: Standard Keras initializer param. - partition_info: Standard Keras initializer param. Not used. + **kwargs: Other args passed to `tf.keras.initializers.Initializer` + __call__ method. """ - del partition_info units, num_terms = shape - return kfl_lib.scale_initializer(units=units, num_terms=num_terms) + return kfl_lib.scale_initializer( + units=units, + num_terms=num_terms, + output_min=self.output_min, + output_max=self.output_max) + + def get_config(self): + """Standard Keras config for serializaion.""" + config = { + "output_min": self.output_min, + "output_max": self.output_max, + } # pyformat: disable + return config + + +class BiasInitializer(keras.initializers.Initializer): + # pyformat: disable + """Initializes bias depending on output_min and output_max. + + If both output_min and output_max are set, bias is initialized to their + average. If only output_min is set, bias is initialized to output_min. If only + output_max is set, bias is initialized to output_max. Otherwise bias is + initialized to zeros. + """ + # pyformat: enable + + def __init__(self, output_min, output_max): + """Initializes an instance of `BiasInitializer`. + + Args: + output_min: None or minimum layer output. + output_max: None or maximum layer output. + """ + self.output_min = output_min + self.output_max = output_max + + def __call__(self, shape, dtype=None, **kwargs): + """Returns weights of `tfl.layers.KroneckerFactoredLattice` bias. + + Args: + shape: Must be: `(units, num_terms)`. + dtype: Standard Keras initializer param. + **kwargs: Other args passed to `tf.keras.initializers.Initializer` + __call__ method. 
+ """ + return kfl_lib.bias_initializer( + units=shape[0], + output_min=self.output_min, + output_max=self.output_max, + dtype=dtype) + + def get_config(self): + """Standard Keras config for serializaion.""" + config = { + "output_min": self.output_min, + "output_max": self.output_max, + } # pyformat: disable + return config class KroneckerFactoredLatticeConstraints(keras.constraints.Constraint): @@ -388,7 +588,8 @@ def __init__(self, units, scale, monotonicities=None, - satisfy_constraints_at_every_step=True): + output_min=None, + output_max=None): """Initializes an instance of `KroneckerFactoredLatticeConstraints`. Args: @@ -397,15 +598,18 @@ def __init__(self, scale: Scale variable of shape: `(units, num_terms)`. monotonicities: Same meaning as corresponding parameter of `KroneckerFactoredLattice`. - satisfy_constraints_at_every_step: Whether to use approximate projection - to ensure that constratins are strictly satisfied. + output_min: Same meaning as corresponding parameter of + `KroneckerFactoredLattice`. + output_max: Same meaning as corresponding parameter of + `KroneckerFactoredLattice`. """ self.units = units self.scale = scale self.monotonicities = utils.canonicalize_monotonicities( monotonicities, allow_decreasing=False) self.num_constraint_dims = utils.count_non_zeros(self.monotonicities) - self.satisfy_constraints_at_every_step = satisfy_constraints_at_every_step + self.output_min = output_min + self.output_max = output_max def __call__(self, w): """Applies constraints to `w`. @@ -417,12 +621,14 @@ def __call__(self, w): Returns: Constrained and projected w. """ - if self.num_constraint_dims and self.satisfy_constraints_at_every_step: - w = kfl_lib.finalize_constraints( + if self.num_constraint_dims: + w = kfl_lib.finalize_weight_constraints( w, units=self.units, scale=self.scale, - monotonicities=self.monotonicities) + monotonicities=self.monotonicities, + output_min=self.output_min, + output_max=self.output_max) return w def get_config(self): @@ -431,6 +637,55 @@ def get_config(self): "units": self.units, "scale": self.scale, "monotonicities": self.monotonicities, - "satisfy_constraints_at_every_step": - self.satisfy_constraints_at_every_step, + "output_min": self.output_min, + "output_max": self.output_max, + } # pyformat: disable + + +class ScaleConstraints(keras.constraints.Constraint): + # pyformat: disable + """Constraints for `tfl.layers.KroneckerFactoredLattice` scale. + + Constraints the scale variable to be between + `[output_min-output_max, output_max-output_min]` such that the final output + of the layer is within the desired `[output_min, output_max]` range, assuming + bias is properly fixed to be `output_min`. + + Attributes: + - All `__init__` arguments. + """ + # pyformat: enable + + def __init__(self, output_min=None, output_max=None): + """Initializes an instance of `ScaleConstraints`. + + Args: + output_min: Same meaning as corresponding parameter of + `KroneckerFactoredLattice`. + output_max: Same meaning as corresponding parameter of + `KroneckerFactoredLattice`. + """ + self.output_min = output_min + self.output_max = output_max + + def __call__(self, scale): + """Applies constraints to `scale`. + + Args: + scale: Kronecker-Factored Lattice scale tensor of shape: `(units, + num_terms)`. + + Returns: + Constrained and clipped scale. 
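
# Illustration: how the scale and bias initialization rules described above
# play out for a hypothetical output range, calling the library helpers
# directly (module path assumes the in-repo layout; exact dtypes may differ).
from tensorflow_lattice.python import kronecker_factored_lattice_lib as kfl_lib

scale = kfl_lib.scale_initializer(
    units=1, num_terms=4, output_min=0.0, output_max=10.0)
# -> [[ 5., -5.,  5., -5.]]: half the output range, with alternating signs.
bias = kfl_lib.bias_initializer(units=1, output_min=0.0, output_max=10.0)
# -> [5.]: the midpoint of the output range. When either bound is set the layer
# also marks the bias as non-trainable, so the scale constraints alone keep the
# output inside the requested range.
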
+ """ + if self.output_min is not None or self.output_max is not None: + scale = kfl_lib.finalize_scale_constraints( + scale, output_min=self.output_min, output_max=self.output_max) + return scale + + def get_config(self): + """Standard Keras config for serialization.""" + return { + "output_min": self.output_min, + "output_max": self.output_max, } # pyformat: disable diff --git a/tensorflow_lattice/python/kronecker_factored_lattice_lib.py b/tensorflow_lattice/python/kronecker_factored_lattice_lib.py index 22c56f3..435cc28 100644 --- a/tensorflow_lattice/python/kronecker_factored_lattice_lib.py +++ b/tensorflow_lattice/python/kronecker_factored_lattice_lib.py @@ -149,12 +149,26 @@ def evaluate_with_hypercube_interpolation(inputs, scale, bias, kernel, units, return results -def random_monotonic_initializer(shape, - monotonicities, - init_min=0.5, - init_max=1.5, - dtype=tf.float32, - seed=None): +def default_init_params(output_min, output_max): + """Returns default initialization bounds depending on layer output bounds. + + Args: + output_min: None or minimum layer output. + output_max: None or maximum layer output. + """ + if output_min is None and output_max is None: + return 0.5, 1.5 + else: + return 0.0, 1.0 + + +def kfl_random_monotonic_initializer(shape, + scale, + monotonicities, + init_min=0.5, + init_max=1.5, + dtype=tf.float32, + seed=None): """Returns a uniformly random sampled monotonic weight tensor. - The uniform random monotonic function will initilaize the lattice parameters @@ -165,6 +179,7 @@ def random_monotonic_initializer(shape, Args: shape: Shape of weights to initialize. Must be: `(1, lattice_sizes, units * dims, num_terms)`. + scale: Scale variable of shape: `(units, num_terms)`. monotonicities: None or list or tuple of length dims of elements of {0,1} which represents monotonicity constraints per dimension. 1 stands for increasing (non-decreasing in fact), 0 for no monotonicity constraints. @@ -175,7 +190,7 @@ def random_monotonic_initializer(shape, Returns: Kronecker-Factored Lattice weights tensor of shape: - `(1, lattice_sizes, units * dims, num_terms)`. + `(1, lattice_sizes, units * dims, num_terms)`. """ # Sample from the uniform distribution. weights = tf.random.uniform( @@ -186,37 +201,87 @@ def random_monotonic_initializer(shape, _, lattice_sizes, units_times_dims, num_terms = shape if units_times_dims % dims != 0: raise ValueError( - "len(monotonicities) is {}, which does not evenly divide shape[2]" + "len(monotonicities) is {}, which does not evenly divide shape[2]." "len(monotonicities) should be equal to `dims`, and shape[2] " "should be equal to units * dims.".format(dims)) units = units_times_dims // dims weights = tf.reshape(weights, [-1, lattice_sizes, units, dims, num_terms]) + # Make all dimensions monotonically increasing with respect to the sign of + # scale. + direction = tf.expand_dims(tf.sign(scale), axis=1) # Now we can unstack each dimension. - weights = tf.unstack(weights, axis=3) + weights = tf.unstack(direction * weights, axis=3) monotonic_weights = [ tf.sort(weight, axis=1) if monotonicity else weight for monotonicity, weight in zip(monotonicities, weights) ] # Restack, reshape, and return weights weights = tf.stack(monotonic_weights, axis=3) - weights = tf.reshape(weights, shape) + weights = tf.reshape(direction * weights, shape) return weights -def scale_initializer(units, num_terms): - """Initializes scale to alternate between 1 and -1 for each term. 
+def scale_initializer(units, num_terms, output_min, output_max): + """Initializes scale depending on output_min and output_max. + + If both output_min and output_max are set, scale is initialized to half their + difference, alternating signs for each term. If only output_min is set, scale + is initialized to 1 for each term. If only output_max is set, scale is + initialized to -1 for each term. Otherwise scale is initialized to alternate + between 1 and -1 for each term. Args: - units: Output dimension of the layer. Each of units scale will be - initialized identically. + units: Output dimension of the layer. Each unit's scale will be initialized + identically. num_terms: Number of independently trained submodels per unit, the outputs of which are averaged to get the final output. + output_min: None or minimum layer output. + output_max: None or maximum layer output. Returns: Kronecker-Factored Lattice scale of shape: `(units, num_terms)`. """ + if output_min is not None and output_max is None: + return np.ones([units, num_terms]) + if output_min is None and output_max is not None: + return -np.ones([units, num_terms]) + # Both or neither bounds are set, so we alternate sign. signs = (np.arange(num_terms) % -2) * 2 + 1 - return np.tile(signs, [units, 1]) + scale = np.tile(signs, [units, 1]) + if output_min is not None and output_max is not None: + scale = scale * ((output_max - output_min) / 2.0) + return scale + + +def bias_initializer(units, output_min, output_max, dtype=tf.float32): + """Initializes bias depending on output_min and output_max. + + If both output_min and output_max are set, bias is initialized to their + average. If only output_min is set, bias is initialized to output_min. If only + output_max is set, bias is initialized to output_max. Otherwise bias is + initialized to zeros. + + Args: + units: Output dimension of the layer. Each of units bias will be initialized + identically. + output_min: None or minimum layer output. + output_max: None or maximum layer output. + dtype: dtype + + Returns: + Kronecker-Factored Lattice bias of shape: `(units)`. + """ + if output_min is not None and output_max is not None: + return tf.constant( + (output_min + output_max) / 2.0, shape=[units], dtype=dtype) + elif output_min is not None: + return tf.constant(output_min, shape=[units], dtype=dtype) + elif output_max is not None: + # In this case, weights will be nonnegative and scale will be nonpositive so + # we add output_max to interpolation output to achieve proper bound. + return tf.constant(output_max, shape=[units], dtype=dtype) + else: + return tf.zeros(shape=[units], dtype=dtype) def _approximately_project_monotonicity(weights, units, scale, monotonicities): @@ -278,7 +343,56 @@ def _approximately_project_monotonicity(weights, units, scale, monotonicities): return weights -def finalize_constraints(weights, units, scale, monotonicities): +def _approximately_project_bounds(weights, units, output_min, output_max): + """Approximately projects to strictly meet bound constraints. + + For more details, see _approximately_project_bounds in lattice_lib.py. + + Args: + weights: Tensor with weights of shape `(1, lattice_sizes, units * dims, + num_terms)`. + units: Number of units per input dimension. + output_min: None or minimum layer output. + output_max: None or maximum layer output. + + Returns: + Tensor with projected weights matching shape of input weights. 
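
# Illustration: a numeric NumPy sketch of the dims'th-root bound projection
# implemented below, for a single unit and term with dims=2 and lattice_size=2
# (numbers are illustrative).
import numpy as np

w = np.array([[1.0, 3.0],   # keypoints of dimension 0, max |w| = 3.0
              [0.5, 2.0]])  # keypoints of dimension 1, max |w| = 2.0
max_output = np.prod(np.max(np.abs(w), axis=1))  # 6.0 > 1.0, so we project
per_dim_factor = np.maximum(max_output, 1.0)**(1.0 / w.shape[0])  # ~2.449
w_projected = w / per_dim_factor
# The product of the per-dimension maxima is now 1.0, so the term's output
# magnitude is at most 1 before scale and bias map it into
# [output_min, output_max].
assert np.prod(np.max(np.abs(w_projected), axis=1)) <= 1.0 + 1e-6
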
+ """ + if output_min is None and output_max is None: + return weights + + # We project by the dims'th root projection factor of the weights, ultimately + # projecting each term into the range [-1,1], but only if both output_min and + # output_max are specified. Otherwise, we restrict the weights to be + # nonnegative and the interpolation will do a final shift to respect the + # one-sided bound. + if output_min is not None and output_max is not None: + # Recall that w.shape is (1, lattice_sizes, units * dims, num_terms). + weights_shape = weights.get_shape().as_list() + _, lattice_sizes, units_times_dims, num_terms = weights_shape + assert units_times_dims % units == 0 + dims = units_times_dims // units + weights = tf.reshape(weights, [-1, lattice_sizes, units, dims, num_terms]) + max_keypoint_values = tf.reduce_max(tf.abs(weights), axis=1, keepdims=True) + max_output_value = tf.reduce_prod( + max_keypoint_values, axis=3, keepdims=True) + full_projection_factor = tf.maximum(max_output_value, 1.0) + individual_projection_factor = tf.pow(full_projection_factor, 1.0 / dims) + weights = weights / individual_projection_factor + # We must reshape to get our final projected weights. + weights = tf.reshape(weights, weights_shape) + else: + weights = tf.maximum(weights, 0) + + return weights + + +# Note: this function must not depend on the result of projecting scale. +# Currently this function depends on the sign of scale, but the scale projection +# will not flip the sign of scale (only make it 0 in the worse case), which will +# not cause any issues. +def finalize_weight_constraints(weights, units, scale, monotonicities, + output_min, output_max): """Approximately projects weights to strictly satisfy all constraints. This projeciton guarantees that constraints are strictly met, but it is not @@ -296,30 +410,72 @@ def finalize_constraints(weights, units, scale, monotonicities): monotonicities: List or tuple of length dims of elements of {0,1} which represents monotonicity constraints per dimension. 1 stands for increasing (non-decreasing in fact), 0 for no monotonicity constraints. + output_min: None or minimum layer output. + output_max: None or maximum layer output. Returns: Projected weights tensor of same shape as `weights`. """ - if utils.count_non_zeros(monotonicities) == 0: - return weights - - # TODO: in the case of only one monotonic dimension, we only have to - # constrain the non-monotonic dimensions to be positive. - # There must be monotonicity constraints, so we need all positive weights. - weights = tf.maximum(weights, 0) + if utils.count_non_zeros(monotonicities) > 0: + # TODO: in the case of only one monotonic dimension, we only have to + # constrain the non-monotonic dimensions to be positive. + # There must be monotonicity constraints, so we need all nonnegative + # weights. + weights = tf.maximum(weights, 0) + weights = _approximately_project_monotonicity( + weights=weights, + units=units, + scale=scale, + monotonicities=monotonicities) - # Project monotonicity constraints. 
-  weights = _approximately_project_monotonicity(weights, units, scale,
-                                                monotonicities)
+  if output_min is not None or output_max is not None:
+    weights = _approximately_project_bounds(
+        weights=weights,
+        units=units,
+        output_min=output_min,
+        output_max=output_max)
 
   return weights
 
 
+# Note: we cannot rely on the weights projection occurring always before or
+# always after the scale projection, so this function must not result in a
+# projection that would ultimately change the results of the weights projection.
+# Currently the weights projection depends on the sign of scale, so this
+# function does not change the sign (only makes scale 0 in the worst case),
+# which will not cause any issues.
+def finalize_scale_constraints(scale, output_min, output_max):
+  """Clips scale to strictly satisfy all constraints.
+
+  Args:
+    scale: Scale variable of shape: `(units, num_terms)`.
+    output_min: None or minimum layer output.
+    output_max: None or maximum layer output.
+
+  Returns:
+    Clipped scale tensor of same shape as `scale`.
+  """
+  if output_min is not None and output_max is not None:
+    bound = (output_max - output_min) / 2.0
+    scale = tf.clip_by_value(scale, clip_value_min=-bound, clip_value_max=bound)
+  elif output_min is not None:
+    # In this case, we need scale to be nonnegative to properly shift by bias
+    # and satisfy the one-sided max bound.
+    scale = tf.maximum(scale, 0)
+  elif output_max is not None:
+    # In this case, we need scale to be nonpositive to properly mirror and shift
+    # by bias and satisfy the one-sided min bound.
+    scale = tf.minimum(scale, 0)
+  return scale
+
+
 def verify_hyperparameters(lattice_sizes=None,
                            units=None,
                            num_terms=None,
                            input_shape=None,
-                           monotonicities=None):
+                           monotonicities=None,
+                           output_min=None,
+                           output_max=None):
   """Verifies that all given hyperparameters are consistent.
 
   This function does not inspect weights themselves. Only their shape. Use
@@ -337,6 +493,8 @@ def verify_hyperparameters(lattice_sizes=None,
       `monotonicities` is set.
     monotonicities: Monotonicities hyperparameter of `KroneckerFactoredLattice`
       layer. Useful only if `input_shape` is set.
+    output_min: Minimum output of `KroneckerFactoredLattice` layer.
+    output_max: Maximum output of `KroneckerFactoredLattice` layer.
 
   Raises:
     ValueError: If lattice_sizes < 2.
@@ -384,6 +542,12 @@ def verify_hyperparameters(lattice_sizes=None,
                      "'units'. 'units': %s, 'input_shape: %s" %
                      (units, input_shape))
 
+  if output_min is not None and output_max is not None:
+    if output_min >= output_max:
+      raise ValueError("'output_min' must be strictly less than 'output_max'. "
+                       "'output_min': %f, 'output_max': %f" %
+                       (output_min, output_max))
+
 
 def _assert_monotonicity_constraints(weights, units, scale, monotonicities,
                                      eps):
@@ -433,7 +597,110 @@ def _assert_monotonicity_constraints(weights, units, scale, monotonicities,
   return monotonicity_asserts
 
 
-def assert_constraints(weights, units, scale, monotonicities, eps=1e-6):
+def _assert_bound_constraints(weights, units, scale, output_min, output_max,
+                              eps):
+  """Asserts that weights satisfy bound constraints.
+
+  Args:
+    weights: `KroneckerFactoredLattice` weights tensor of shape: `(1,
+      lattice_sizes, units * dims, num_terms)`.
+    units: Number of units per input dimension.
+    scale: Scale variable of shape: `(units, num_terms)`.
+    output_min: None or minimum layer output.
+    output_max: None or maximum layer output.
+    eps: Allowed constraints violation.
+
+  Returns:
+    List of bound assertion ops in graph mode or directly executes
+    assertions in eager mode and returns a list of NoneType elements.
+  """
+  bound_asserts = []
+
+  # Recall that w.shape is (1, lattice_sizes, units * dims, num_terms).
+  weights_shape = weights.get_shape().as_list()
+  _, lattice_sizes, units_times_dims, num_terms = weights_shape
+  assert units_times_dims % units == 0
+  dims = units_times_dims // units
+  weights = tf.reshape(weights, [-1, lattice_sizes, units, dims, num_terms])
+
+  # If both bounds are specified, we must also have that the maximum output is
+  # between -1 and 1.
+  if output_min is not None and output_max is not None:
+    for term, term_weights in enumerate(tf.unstack(weights, axis=4)):
+      max_keypoint_values = tf.reduce_max(
+          tf.abs(term_weights), axis=1, keepdims=True)
+      max_output_values = tf.reduce_prod(
+          max_keypoint_values, axis=3, keepdims=True)
+      for unit, unit_max_output_value in enumerate(
+          tf.unstack(max_output_values, axis=2)):
+        diff = tf.squeeze(1 - unit_max_output_value)
+        bound_asserts.append(
+            tf.Assert(
+                diff >= -eps,
+                data=[
+                    "Bound violation (max output greater than 1)", "Diff", diff,
+                    "Epsilon", eps, "Maximum output value",
+                    unit_max_output_value, "Term index", term, "Unit", unit,
+                    "Weights", weights
+                ]))
+  else:
+    # If only one bound is specified, we must have that all of our weights are
+    # nonnegative at this point. There can be no allowed epsilon error here
+    # because of the effect of a negative value.
+    total_negative_weights = tf.reduce_sum(tf.cast(weights < 0, tf.int32))
+    bound_asserts.append(
+        tf.Assert(
+            total_negative_weights <= 0,
+            data=[
+                "Bound violation (negative weights)",
+                "Number of negative weights", total_negative_weights, "Weights",
+                weights
+            ]))
+
+  # If both bounds are specified, scale must be between
+  # -(output_max-output_min)/2 and (output_max-output_min)/2. If only output_min
+  # is specified, then scale must be nonnegative. If only output_max is
+  # specified, then scale must be nonpositive.
+  if output_min is not None and output_max is not None:
+    bound = (output_max - output_min) / 2.0
+    below_bound_scales = tf.reduce_sum(tf.cast(scale < -bound, tf.int32))
+    above_bound_scale = tf.reduce_sum(tf.cast(scale > bound, tf.int32))
+    bound_asserts.append(
+        tf.Assert(
+            below_bound_scales + above_bound_scale <= 0,
+            data=[
+                "Bound violation (scale out of bounds)", "Bound", bound,
+                "Scale", scale
+            ]))
+  elif output_min is not None:
+    negative_scales = tf.reduce_sum(tf.cast(scale < 0, tf.int32))
+    bound_asserts.append(
+        tf.Assert(
+            negative_scales <= 0,
+            data=[
+                "Bound violation (only output_min specified with negative "
+                "scale values)", "Scale", scale
+            ]))
+  elif output_max is not None:
+    positive_scales = tf.reduce_sum(tf.cast(scale > 0, tf.int32))
+    bound_asserts.append(
+        tf.Assert(
+            positive_scales <= 0,
+            data=[
+                "Bound violation (only output_max specified with positive "
+                "scale values)", "Scale", scale
+            ]))
+
+  return bound_asserts
+
+
+def assert_constraints(weights,
+                       units,
+                       scale,
+                       monotonicities,
+                       output_min,
+                       output_max,
+                       eps=1e-6):
   """Asserts that weights satisfy constraints.
 
   Args:
@@ -442,6 +709,8 @@ def assert_constraints(weights, units, scale, monotonicities, eps=1e-6):
     units: Number of units per input dimension.
     scale: Scale variable of shape: `(units, num_terms)`.
     monotonicities: Monotonicity constraints.
+    output_min: None or minimum layer output.
+    output_max: None or maximum layer output.
     eps: Allowed constraints violation.
Returns: @@ -459,4 +728,14 @@ def assert_constraints(weights, units, scale, monotonicities, eps=1e-6): eps=eps) asserts.extend(monotonicity_asserts) + if output_min is not None or output_max is not None: + bound_asserts = _assert_bound_constraints( + weights=weights, + units=units, + scale=scale, + output_min=output_min, + output_max=output_max, + eps=eps) + asserts.extend(bound_asserts) + return asserts diff --git a/tensorflow_lattice/python/kronecker_factored_lattice_test.py b/tensorflow_lattice/python/kronecker_factored_lattice_test.py index 824be60..f31e34a 100644 --- a/tensorflow_lattice/python/kronecker_factored_lattice_test.py +++ b/tensorflow_lattice/python/kronecker_factored_lattice_test.py @@ -25,6 +25,7 @@ import tensorflow as tf from tensorflow import keras from tensorflow_lattice.python import kronecker_factored_lattice_layer as kfll +from tensorflow_lattice.python import kronecker_factored_lattice_lib as kfl_lib from tensorflow_lattice.python import test_utils @@ -162,9 +163,9 @@ def _GetTrainingInputsAndLabels(self, config): Returns: Tuple `(training_inputs, training_labels, raw_training_inputs)` where - `training_inputs` and `training_labels` are data for training and - `raw_training_inputs` are representation of training_inputs for - visualisation. + `training_inputs` and `training_labels` are data for training and + `raw_training_inputs` are representation of training_inputs for + visualisation. """ raw_training_inputs = config["x_generator"]( num_points=config["num_training_records"], @@ -185,10 +186,12 @@ def _SetDefaults(self, config): config.setdefault("units", 1) config.setdefault("num_terms", 2) config.setdefault("monotonicities", None) + config.setdefault("output_min", None) + config.setdefault("output_max", None) config.setdefault("signal_name", "TEST") - config.setdefault("satisfy_constraints_at_every_step", True) config.setdefault("target_monotonicity_diff", 0.0) config.setdefault("lattice_index", 0) + config.setdefault("scale_initializer", "scale_initializer") return config @@ -237,9 +240,10 @@ def _TrainModel(self, config, plot_path=None): units=units, num_terms=config["num_terms"], monotonicities=config["monotonicities"], - satisfy_constraints_at_every_step=config[ - "satisfy_constraints_at_every_step"], + output_min=config["output_min"], + output_max=config["output_max"], kernel_initializer=config["kernel_initializer"], + scale_initializer=config["scale_initializer"], input_shape=input_shape, dtype=tf.float32) model = keras.models.Sequential() @@ -264,6 +268,8 @@ def _TrainModel(self, config, plot_path=None): if tf.executing_eagerly(): tf.print("final weights: ", keras_layer.kernel) + tf.print("final scale: ", keras_layer.scale) + tf.print("final bias: ", keras_layer.bias) assetion_ops = keras_layer.assert_constraints( eps=-config["target_monotonicity_diff"]) if not tf.executing_eagerly() and assetion_ops: @@ -275,7 +281,7 @@ def testMonotonicityOneD(self): if self.disable_all: return monotonicities = [1] - kernel_initializer = kfll.RandomMonotonicInitializer( + kernel_initializer = kfll.KFLRandomMonotonicInitializer( monotonicities=monotonicities, seed=self.seed) config = { "lattice_sizes": 20, @@ -290,11 +296,11 @@ def testMonotonicityOneD(self): "kernel_initializer": kernel_initializer, } # pyformat: disable loss = self._TrainModel(config) - self.assertAlmostEqual(loss, 0.118006, delta=self.loss_eps) + self.assertAlmostEqual(loss, 0.114794, delta=self.loss_eps) self._TestEnsemble(config) monotonicities = ["increasing"] - kernel_initializer = 
kfll.RandomMonotonicInitializer( + kernel_initializer = kfll.KFLRandomMonotonicInitializer( monotonicities=monotonicities, seed=self.seed) config = { "lattice_sizes": 20, @@ -309,11 +315,11 @@ def testMonotonicityOneD(self): "kernel_initializer": kernel_initializer, } # pyformat: disable loss = self._TrainModel(config) - self.assertAlmostEqual(loss, 2.842038, delta=self.loss_eps) + self.assertAlmostEqual(loss, 2.841594, delta=self.loss_eps) self._TestEnsemble(config) monotonicities = [1] - kernel_initializer = kfll.RandomMonotonicInitializer( + kernel_initializer = kfll.KFLRandomMonotonicInitializer( monotonicities=monotonicities, seed=self.seed) config = { "lattice_sizes": 5, @@ -337,7 +343,7 @@ def testMonotonicityTwoD(self): if self.disable_all: return monotonicities = [1, 1] - kernel_initializer = kfll.RandomMonotonicInitializer( + kernel_initializer = kfll.KFLRandomMonotonicInitializer( monotonicities=monotonicities, seed=self.seed) config = { "lattice_sizes": 21, @@ -352,11 +358,11 @@ def testMonotonicityTwoD(self): "kernel_initializer": kernel_initializer, } # pyformat: disable loss = self._TrainModel(config) - self.assertAlmostEqual(loss, 0.530398, delta=self.loss_eps) + self.assertAlmostEqual(loss, 0.405031, delta=self.loss_eps) self._TestEnsemble(config) monotonicities = ["none", "increasing"] - kernel_initializer = kfll.RandomMonotonicInitializer( + kernel_initializer = kfll.KFLRandomMonotonicInitializer( monotonicities=monotonicities, seed=self.seed) config = { "lattice_sizes": 21, @@ -371,11 +377,11 @@ def testMonotonicityTwoD(self): "kernel_initializer": kernel_initializer, } # pyformat: disable loss = self._TrainModel(config) - self.assertAlmostEqual(loss, 0.595422, delta=self.loss_eps) + self.assertAlmostEqual(loss, 0.209862, delta=self.loss_eps) self._TestEnsemble(config) monotonicities = [1, 0] - kernel_initializer = kfll.RandomMonotonicInitializer( + kernel_initializer = kfll.KFLRandomMonotonicInitializer( monotonicities=monotonicities, seed=self.seed) config = { "lattice_sizes": 21, @@ -390,11 +396,11 @@ def testMonotonicityTwoD(self): "kernel_initializer": kernel_initializer, } # pyformat: disable loss = self._TrainModel(config) - self.assertAlmostEqual(loss, 0.362752, delta=self.loss_eps) + self.assertAlmostEqual(loss, 0.417009, delta=self.loss_eps) self._TestEnsemble(config) monotonicities = [1, 1] - kernel_initializer = kfll.RandomMonotonicInitializer( + kernel_initializer = kfll.KFLRandomMonotonicInitializer( monotonicities=monotonicities, seed=self.seed) config = { "lattice_sizes": 2, @@ -409,7 +415,7 @@ def testMonotonicityTwoD(self): "kernel_initializer": kernel_initializer, } # pyformat: disable loss = self._TrainModel(config) - self.assertAlmostEqual(loss, 0.051138, delta=self.loss_eps) + self.assertAlmostEqual(loss, 0.050983, delta=self.loss_eps) self._TestEnsemble(config) def testMonotonicity5d(self): @@ -434,7 +440,7 @@ def testMonotonicity5d(self): self.assertAlmostEqual(loss, 0.000524, delta=self.loss_eps) monotonicities = [1, 1, 1, 1, 1] - kernel_initializer = kfll.RandomMonotonicInitializer( + kernel_initializer = kfll.KFLRandomMonotonicInitializer( monotonicities=monotonicities, seed=self.seed) config = { "lattice_sizes": 2, @@ -449,11 +455,11 @@ def testMonotonicity5d(self): "kernel_initializer": kernel_initializer, } # pyformat: disable loss = self._TrainModel(config) - self.assertAlmostEqual(loss, 0.015825, delta=self.loss_eps) + self.assertAlmostEqual(loss, 0.014968, delta=self.loss_eps) self._TestEnsemble(config) monotonicities = [1, "increasing", 
1, 1] - kernel_initializer = kfll.RandomMonotonicInitializer( + kernel_initializer = kfll.KFLRandomMonotonicInitializer( monotonicities=monotonicities, seed=self.seed) config = { "lattice_sizes": 3, @@ -468,7 +474,7 @@ def testMonotonicity5d(self): "kernel_initializer": kernel_initializer, } # pyformat: disable loss = self._TrainModel(config) - self.assertAlmostEqual(loss, 0.376523, delta=self.loss_eps) + self.assertAlmostEqual(loss, 0.377255, delta=self.loss_eps) self._TestEnsemble(config) @parameterized.parameters( @@ -530,7 +536,7 @@ def testMonotonicity10dSinOfSum(self): if self.disable_all: return monotonicities = [1] * 10 - kernel_initializer = kfll.RandomMonotonicInitializer( + kernel_initializer = kfll.KFLRandomMonotonicInitializer( monotonicities=monotonicities, seed=self.seed) config = { "lattice_sizes": 2, @@ -545,19 +551,19 @@ def testMonotonicity10dSinOfSum(self): "kernel_initializer": kernel_initializer, } # pyformat: disable loss = self._TrainModel(config) - self.assertAlmostEqual(loss, 0.183625, delta=self.loss_eps) + self.assertAlmostEqual(loss, 0.190541, delta=self.loss_eps) monotonicities = [0, 1, 0, 1, 1, 0, 1, 1, 1, 0] - kernel_initializer = kfll.RandomMonotonicInitializer( + kernel_initializer = kfll.KFLRandomMonotonicInitializer( monotonicities=monotonicities, seed=self.seed) config["monotonicities"] = monotonicities config["kernel_initializer"] = kernel_initializer loss = self._TrainModel(config) - self.assertAlmostEqual(loss, 0.190994, delta=self.loss_eps) + self.assertAlmostEqual(loss, 0.194174, delta=self.loss_eps) @parameterized.parameters( # Custom TFL initializer: - ("random_monotonic_initializer", 2.668374), + ("kfl_random_monotonic_initializer", 2.668374), # Standard Keras initializer: (keras.initializers.Constant(value=1.5), 2.140740), # Standard Keras initializer specified as string constant: @@ -566,8 +572,8 @@ def testMonotonicity10dSinOfSum(self): def testInitializerType(self, initializer, expected_loss): if self.disable_all: return - if initializer == "random_monotonic_initializer": - initializer = kfll.RandomMonotonicInitializer( + if initializer == "kfl_random_monotonic_initializer": + initializer = kfll.KFLRandomMonotonicInitializer( monotonicities=None, seed=self.seed) config = { "lattice_sizes": 3, @@ -615,6 +621,173 @@ def testAssertMonotonicity(self): with self.assertRaises(tf.errors.InvalidArgumentError): self._TrainModel(config) + @parameterized.parameters( + (-1, 1, + kfll.KFLRandomMonotonicInitializer( + monotonicities=None, init_min=-10, + init_max=10), "scale_initializer"), + (None, 1, + kfll.KFLRandomMonotonicInitializer( + monotonicities=None, init_min=-10, + init_max=10), "scale_initializer"), + (-1, None, + kfll.KFLRandomMonotonicInitializer( + monotonicities=None, init_min=-10, + init_max=10), "scale_initializer"), + (-1, 1, "kfl_random_monotonic_initializer", + tf.keras.initializers.Constant(value=-100)), + (None, 1, "kfl_random_monotonic_initializer", + tf.keras.initializers.Constant(value=100)), + (-1, None, "kfl_random_monotonic_initializer", + tf.keras.initializers.Constant(value=-100)), + ) + def testAssertBounds(self, output_min, output_max, kernel_initializer, + scale_initializer): + if self.disable_all: + return + # Specify random initializer that ensures initial output can be out of + # bounds and do 0 training iterations so no projections are executed. 
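+    # With num_training_epoch set to 0, the out-of-range initial kernel and
+    # scale values are never projected, so the assert_constraints call inside
+    # _TrainModel is expected to raise InvalidArgumentError.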
+    config = {
+        "lattice_sizes": 2,
+        "input_dims": 2,
+        "num_training_records": 100,
+        "num_training_epoch": 0,
+        "optimizer": tf.keras.optimizers.Adagrad,
+        "learning_rate": 0.15,
+        "x_generator": self._TwoDMeshGrid,
+        "y_function": self._ScaledSum,
+        "monotonicities": [0, 0],
+        "output_min": output_min,
+        "output_max": output_max,
+        "kernel_initializer": kernel_initializer,
+        "scale_initializer": scale_initializer,
+    }
+    with self.assertRaises(tf.errors.InvalidArgumentError):
+      self._TrainModel(config)
+
+  @parameterized.parameters(
+      (2, 1, -3, -1, 4.868363),
+      (2, 2, 0, 1, 0.169257),
+      (1, 2, -5, 5, 0.011738),
+      (1, 10, -1, 1, 0.680978),
+      (1, 3, None, None, 0.035185),
+      (1, 2, None, 5, 0.011590),
+      (3, 3, 0, None, 0.010172),
+      (4, 2, None, -2, 10.178278),
+  )
+  def testOutputBounds(self, units, input_dims, output_min, output_max,
+                       expected_loss):
+    if self.disable_all:
+      return
+    monotonicities = [1] * input_dims
+    kernel_initializer = kfll.KFLRandomMonotonicInitializer(
+        monotonicities=monotonicities, seed=self.seed)
+    config = {
+        "lattice_sizes": 4,
+        "units": units,
+        "input_dims": input_dims,
+        "num_training_records": 900,
+        "num_training_epoch": 100,
+        "optimizer": tf.keras.optimizers.Adagrad,
+        "learning_rate": 1.0,
+        "x_generator": self._ScatterXUniformly,
+        "y_function": self._SinPlusX,
+        "monotonicities": monotonicities,
+        "output_min": output_min,
+        "output_max": output_max,
+        "kernel_initializer": kernel_initializer,
+        # This is the epsilon error allowed when asserting constraints,
+        # including bounds. We include this to ensure that the bound constraint
+        # assertions do not fail due to numerical errors.
+        "target_monotonicity_diff": -1e-6,
+    }  # pyformat: disable
+    loss = self._TrainModel(config)
+    self.assertAlmostEqual(loss, expected_loss, delta=self.loss_eps)
+    self._TestEnsemble(config)
+
+  @parameterized.parameters(
+      (2, 1, 3, 2, -1, 1),
+      (2, 2, 2, 1, None, 1),
+      (2, 1, 3, 4, -1, None),
+      (3, 3, 4, 2, -50, 2),
+      (4, 4, 2, 4, -1.5, 2.3),
+      (2, 2, 2, 2, None, None),
+  )
+  # Note: dims must be at least 1
+  def testConstraints(self, lattice_sizes, units, dims, num_terms, output_min,
+                      output_max):
+    if self.disable_all:
+      return
+    # Run our test for 100 iterations to minimize the chance we pass by chance.
+    for _ in range(100):
+      # Create 100 random inputs that are frozen in all but the increasing
+      # dimension, which increases uniformly from 0 to lattice_sizes-1.
+      batch_size = 100
+      random_vals = [
+          np.random.uniform(0, lattice_sizes - 1) for _ in range(dims - 1)
+      ]
+      increasing_dim = np.random.randint(0, dims)
+      step_size = (lattice_sizes - 1) / batch_size
+      values = [
+          np.roll([0.0 + (i * step_size)] + random_vals, increasing_dim)
+          for i in range(batch_size)
+      ]
+      if units > 1:
+        values = [[value] * units for value in values]
+        shape = [batch_size, units, dims]
+      else:
+        shape = [batch_size, dims]
+      inputs = tf.constant(values, dtype=tf.float32, shape=shape)
+
+      # Create our weights, constrain them, and evaluate our function on our
+      # constructed inputs.
+      init_min = -1.5 if output_min is None else output_min
+      init_max = 1.5 if output_max is None else output_max
+
+      # Offset the initialization bounds to increase likelihood of breaking
+      # constraints.
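+      # With an offset of 100 the kernel is drawn far outside
+      # [init_min, init_max], so the unconstrained weights almost surely
+      # violate the bound constraints before projection.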
+ offset = 100 + kernel = tf.random.uniform([1, lattice_sizes, units * dims, num_terms], + minval=init_min - offset, + maxval=init_max + offset) + scale = tf.random.uniform([units, num_terms], + minval=init_min, + maxval=init_max) + bias = kfl_lib.bias_initializer( + units, output_min, output_max, dtype=tf.float32) + + scale_constraint = kfll.ScaleConstraints(output_min, output_max) + constrained_scale = scale_constraint(scale) + + monotonicities = [np.random.randint(0, 2) for _ in range(dims)] + monotonicities[increasing_dim] = 1 + kernel_constraint = kfll.KroneckerFactoredLatticeConstraints( + units, constrained_scale, monotonicities, output_min, output_max) + constrained_kernel = kernel_constraint(kernel) + + outputs = kfl_lib.evaluate_with_hypercube_interpolation( + inputs=inputs, + scale=constrained_scale, + bias=bias, + kernel=constrained_kernel, + units=units, + num_terms=num_terms, + lattice_sizes=lattice_sizes, + clip_inputs=True) + + # Check that outputs are inside our bounds + min_check = float("-inf") if output_min is None else output_min + self.assertEqual(tf.reduce_sum(tf.cast(outputs < min_check, tf.int32)), 0) + max_check = float("+inf") if output_max is None else output_max + self.assertEqual(tf.reduce_sum(tf.cast(outputs > max_check, tf.int32)), 0) + # Check that we satisfy monotonicity constraints. Note that by + # construction the outputs should already be in sorted order. + sorted_outputs = tf.sort(outputs, axis=0) + # We use close equality instead of strict equality because of numerical + # errors that result in nearly identical arrays failing a strict check + # after sorting. + self.assertAllClose(outputs, sorted_outputs, rtol=1e-6, atol=1e-6) + def testInputOutOfBounds(self): if self.disable_all: return @@ -633,7 +806,7 @@ def testInputOutOfBounds(self): self.assertAlmostEqual(loss, 0.018726, delta=self.loss_eps) self._TestEnsemble(config) - kernel_initializer = kfll.RandomMonotonicInitializer( + kernel_initializer = kfll.KFLRandomMonotonicInitializer( monotonicities=None, seed=self.seed) config = { "lattice_sizes": 2, @@ -655,7 +828,7 @@ def testHighDimensionsStressTest(self): return monotonicities = [0] * 16 monotonicities[3], monotonicities[4], monotonicities[10] = (1, 1, 1) - kernel_initializer = kfll.RandomMonotonicInitializer( + kernel_initializer = kfll.KFLRandomMonotonicInitializer( monotonicities=monotonicities, seed=self.seed) config = { "lattice_sizes": 2, @@ -673,14 +846,14 @@ def testHighDimensionsStressTest(self): "target_monotonicity_diff": -1e-5, } # pyformat: disable loss = self._TrainModel(config) - self.assertAlmostEqual(loss, 0.251715, delta=self.loss_eps) + self.assertAlmostEqual(loss, 0.257097, delta=self.loss_eps) @parameterized.parameters( - (2, 5, 2, 49), - (2, 6, 4, 49), - (2, 9, 2, 49), - (3, 5, 4, 56), - (3, 9, 2, 56), + (2, 5, 2, 57), + (2, 6, 4, 57), + (2, 9, 2, 57), + (3, 5, 4, 63), + (3, 9, 2, 63), ) def testGraphSize(self, lattice_sizes, input_dims, num_terms, expected_graph_size): @@ -703,13 +876,16 @@ def testGraphSize(self, lattice_sizes, input_dims, num_terms, @parameterized.parameters( ("random_uniform", tf.keras.initializers.RandomUniform), - ("random_monotonic_initializer", kfll.RandomMonotonicInitializer)) + ("kfl_random_monotonic_initializer", kfll.KFLRandomMonotonicInitializer)) def testCreateKernelInitializer(self, kernel_initializer_id, expected_type): self.assertEqual( expected_type, type( kfll.create_kernel_initializer( - kernel_initializer_id, monotonicities=None))) + kernel_initializer_id, + monotonicities=None, + 
output_min=None, + output_max=None))) # We test that the scale variable attribute of our KroneckerFactoredLattice # is the same object as the scale contained in the constraint on the kernel, @@ -718,8 +894,9 @@ def testCreateKernelInitializer(self, kernel_initializer_id, expected_type): # across all uses of the object. def testSavingLoadingScale(self): # Create simple x --> x^2 dataset. - train_data = [[float(x), float(x)**2] for x in range(100)] + train_data = [[[float(x)], float(x)**2] for x in range(100)] train_x, train_y = zip(*train_data) + train_x, train_y = np.array(train_x), np.array(train_y) # Construct simple single lattice model. Must have monotonicities specified # or constraint will be None. keras_layer = kfll.KroneckerFactoredLattice( @@ -743,7 +920,15 @@ def testSavingLoadingScale(self): "KroneckerFactoredLattice": kfll.KroneckerFactoredLattice, "KroneckerFactoredLatticeConstraint": - kfll.KroneckerFactoredLatticeConstraints + kfll.KroneckerFactoredLatticeConstraints, + "KFLRandomMonotonicInitializer": + kfll.KFLRandomMonotonicInitializer, + "ScaleInitializer": + kfll.ScaleInitializer, + "ScaleConstraints": + kfll.ScaleConstraints, + "BiasInitializer": + kfll.BiasInitializer, }) # Extract loaded layer. loaded_keras_layer = loaded_model.layers[0] @@ -751,10 +936,6 @@ def testSavingLoadingScale(self): loaded_layer_scale = loaded_keras_layer.scale loaded_constraint_scale = loaded_keras_layer.kernel.constraint.scale self.assertIs(loaded_layer_scale, loaded_constraint_scale) - # Train for another epoch and test equality of all updated elements just to - # be safe. - loaded_model.fit(train_x, train_y) - self.assertAllEqual(loaded_layer_scale, loaded_constraint_scale) @parameterized.parameters( (1, 3, 1), diff --git a/tensorflow_lattice/python/lattice_layer.py b/tensorflow_lattice/python/lattice_layer.py index 0bac43e..236b8e8 100644 --- a/tensorflow_lattice/python/lattice_layer.py +++ b/tensorflow_lattice/python/lattice_layer.py @@ -532,9 +532,15 @@ def assert_constraints(self, eps=1e-6): eps=eps) -def create_kernel_initializer(kernel_initializer_id, lattice_sizes, - monotonicities, output_min, output_max, - unimodalities, joint_unimodalities): +def create_kernel_initializer(kernel_initializer_id, + lattice_sizes, + monotonicities, + output_min, + output_max, + unimodalities, + joint_unimodalities, + init_min=None, + init_max=None): """Returns a kernel Keras initializer object from its id. This function is used to convert the 'kernel_initializer' parameter in the @@ -555,29 +561,20 @@ def create_kernel_initializer(kernel_initializer_id, lattice_sizes, constructor of tfl.Lattice. joint_unimodalities: See the documentation of the same parameter in the constructor of tfl.Lattice. + init_min: None or lower bound of kernel initialization. If set, init_max + must also be set. + init_max: None or upper bound of kernel initialization. If set, init_min + must also be set. Returns: The Keras initializer object for the tfl.Lattice kernel variable. - """ - - def default_params(output_min, output_max): - """Return reasonable default parameters if not defined explicitly.""" - if output_min is not None: - output_init_min = output_min - elif output_max is not None: - output_init_min = min(0.0, output_max) - else: - output_init_min = 0.0 - if output_max is not None: - output_init_max = output_max - elif output_min is not None: - output_init_max = max(1.0, output_min) - else: - output_init_max = 1.0 - - # Return our min and max. 
- return output_init_min, output_init_max + Raises: + ValueError: If only one of init_{min/max} is set. + """ + if ((init_min is not None and init_max is None) or + (init_min is None and init_max is not None)): + raise ValueError("Both or neither of init_{min/max} must be set") def do_joint_unimodalities_contain_all_features(joint_unimodalities): if (joint_unimodalities is None) or (len(joint_unimodalities) != 1): @@ -597,23 +594,27 @@ def do_joint_unimodalities_contain_all_features(joint_unimodalities): all_unimodalities[dim] = direction if kernel_initializer_id in ["linear_initializer", "LinearInitializer"]: - output_init_min, output_init_max = default_params(output_min, output_max) + if init_min is None and init_max is None: + init_min, init_max = lattice_lib.default_init_params( + output_min, output_max) return LinearInitializer( lattice_sizes=lattice_sizes, monotonicities=monotonicities, - output_min=output_init_min, - output_max=output_init_max, + output_min=init_min, + output_max=init_max, unimodalities=all_unimodalities) elif kernel_initializer_id in [ "random_monotonic_initializer", "RandomMonotonicInitializer" ]: - output_init_min, output_init_max = default_params(output_min, output_max) + if init_min is None and init_max is None: + init_min, init_max = lattice_lib.default_init_params( + output_min, output_max) return RandomMonotonicInitializer( lattice_sizes=lattice_sizes, - output_min=output_init_min, - output_max=output_init_max, + output_min=init_min, + output_max=init_max, unimodalities=all_unimodalities) elif kernel_initializer_id in [ "random_uniform_or_linear_initializer", "RandomUniformOrLinearInitializer" @@ -621,10 +622,12 @@ def do_joint_unimodalities_contain_all_features(joint_unimodalities): if do_joint_unimodalities_contain_all_features(joint_unimodalities): return create_kernel_initializer("random_uniform", lattice_sizes, monotonicities, output_min, output_max, - unimodalities, joint_unimodalities) + unimodalities, joint_unimodalities, + init_min, init_max) return create_kernel_initializer("linear_initializer", lattice_sizes, monotonicities, output_min, output_max, - unimodalities, joint_unimodalities) + unimodalities, joint_unimodalities, + init_min, init_max) else: # This is needed for Keras deserialization logic to be aware of our custom # objects. diff --git a/tensorflow_lattice/python/lattice_lib.py b/tensorflow_lattice/python/lattice_lib.py index 27fbe9e..23e6733 100644 --- a/tensorflow_lattice/python/lattice_lib.py +++ b/tensorflow_lattice/python/lattice_lib.py @@ -413,6 +413,31 @@ def _bucketize_consequtive_equal_dims(inputs, lattice_sizes): return zip(inputs, bucket_sizes, bucket_dim_sizes) +def default_init_params(output_min, output_max): + """Returns reasonable default parameters if not defined explicitly. + + Args: + output_min: None or minimum layer output. + output_max: None or maximum layer output. + """ + if output_min is not None: + init_min = output_min + elif output_max is not None: + init_min = min(0.0, output_max) + else: + init_min = 0.0 + + if output_max is not None: + init_max = output_max + elif output_min is not None: + init_max = max(1.0, output_min) + else: + init_max = 1.0 + + # Return our min and max. 
+ return init_min, init_max + + def linear_initializer(lattice_sizes, output_min, output_max, diff --git a/tensorflow_lattice/python/model_info.py b/tensorflow_lattice/python/model_info.py index 465ea94..9a631f5 100644 --- a/tensorflow_lattice/python/model_info.py +++ b/tensorflow_lattice/python/model_info.py @@ -33,8 +33,8 @@ class ModelGraph( describe model structure and parameters. Attributes: - nodes: List of all the nodes in the model. - output_node: The output node of the model. + nodes: List of all the nodes in the model. + output_node: The output node of the model. """ @@ -44,9 +44,9 @@ class InputFeatureNode( """Input features to the model. Attributes: - name: Name of the input feature. - is_categorical: If the feature is categorical. - vocabulary_list: Category values for categorical features or None. + name: Name of the input feature. + is_categorical: If the feature is categorical. + vocabulary_list: Category values for categorical features or None. """ @@ -58,11 +58,11 @@ class PWLCalibrationNode( """Represetns a PWL calibration layer. Attributes: - input_node: Input node for the calibration. - input_keypoints: Input keypoints for PWL calibration. - output_keypoints: Output keypoints for PWL calibration. - default_input: Default/missing input value or None. - default_output: Default/missing output value or None. + input_node: Input node for the calibration. + input_keypoints: Input keypoints for PWL calibration. + output_keypoints: Output keypoints for PWL calibration. + default_input: Default/missing input value or None. + default_output: Default/missing output value or None. """ @@ -72,10 +72,10 @@ class CategoricalCalibrationNode( """Represetns a categorical calibration layer. Attributes: - input_node: Input node for the calibration. - output_values: Output calibration values. If the calibrated feature has - default/missing values, the last value will be for default/missing. - default_input: Default/missing input value or None. + input_node: Input node for the calibration. + output_values: Output calibration values. If the calibrated feature has + default/missing values, the last value will be for default/missing. + default_input: Default/missing input value or None. """ @@ -85,9 +85,9 @@ class LinearNode( """Represents a linear layer. Attributes: - input_nodes: List of input nodes to the linear layer. - coefficients: Linear weights. - bias: Bias term for the linear layer. + input_nodes: List of input nodes to the linear layer. + coefficients: Linear weights. + bias: Bias term for the linear layer. """ @@ -96,8 +96,23 @@ class LatticeNode( """Represetns a lattice layer. Attributes: - input_nodes: List of input nodes to the lattice layer. - weights: Lattice parameters. + input_nodes: List of input nodes to the lattice layer. + weights: Lattice parameters. + """ + + +class KroneckerFactoredLatticeNode( + collections.namedtuple('KroneckerFactoredLatticeNode', + ['input_nodes', 'weights', 'scale', 'bias'])): + """Represents a kronecker-factored lattice layer. + + Attributes: + input_nodes: List of input nodes to the kronecker-factored lattice layer. + weights: Kronecker-factored lattice kernel parameters of shape + `(1, lattice_sizes, units * dims, num_terms)`. + scale: Kronecker-factored lattice scale parameters of shape + `(units, num_terms)`. + bias: Kronecker-factored lattice bias parameters of shape `(units)`. """ @@ -105,5 +120,5 @@ class MeanNode(collections.namedtuple('MeanNode', ['input_nodes'])): """Represents an averaging layer. 
Attributes: - input_nodes: List of input nodes to the average layer. + input_nodes: List of input nodes to the average layer. """ diff --git a/tensorflow_lattice/python/premade.py b/tensorflow_lattice/python/premade.py index 7f538b8..cadb92e 100644 --- a/tensorflow_lattice/python/premade.py +++ b/tensorflow_lattice/python/premade.py @@ -39,6 +39,7 @@ from . import aggregation_layer from . import categorical_calibration_layer from . import configs +from . import kronecker_factored_lattice_layer as kfll from . import lattice_layer from . import linear_layer from . import parallel_combination_layer @@ -528,6 +529,8 @@ def get_custom_objects(custom_objects=None): configs.AggregateFunctionConfig, 'Aggregation': aggregation_layer.Aggregation, + 'BiasInitializer': + kfll.BiasInitializer, 'CalibratedLatticeEnsemble': CalibratedLatticeEnsemble, 'CalibratedLattice': @@ -548,6 +551,12 @@ def get_custom_objects(custom_objects=None): configs.DominanceConfig, 'FeatureConfig': configs.FeatureConfig, + 'KFLRandomMonotonicInitializer': + kfll.KFLRandomMonotonicInitializer, + 'KroneckerFactoredLattice': + kfll.KroneckerFactoredLattice, + 'KroneckerFactoredLatticeConstraints': + kfll.KroneckerFactoredLatticeConstraints, 'LaplacianRegularizer': lattice_layer.LaplacianRegularizer, 'Lattice': @@ -574,6 +583,10 @@ def get_custom_objects(custom_objects=None): configs.RegularizerConfig, 'RTL': rtl_layer.RTL, + 'ScaleConstraints': + kfll.ScaleConstraints, + 'ScaleInitializer': + kfll.ScaleInitializer, 'TorsionRegularizer': lattice_layer.TorsionRegularizer, 'TrustConfig': diff --git a/tensorflow_lattice/python/premade_lib.py b/tensorflow_lattice/python/premade_lib.py index 2583555..28097f2 100644 --- a/tensorflow_lattice/python/premade_lib.py +++ b/tensorflow_lattice/python/premade_lib.py @@ -25,6 +25,8 @@ from . import aggregation_layer from . import categorical_calibration_layer from . import configs +from . import kronecker_factored_lattice_layer as kfll +from . import kronecker_factored_lattice_lib as kfl_lib from . import lattice_layer from . import lattice_lib from . import linear_layer @@ -42,6 +44,7 @@ AGGREGATION_LAYER_NAME = 'tfl_aggregation' CALIB_LAYER_NAME = 'tfl_calib' INPUT_LAYER_NAME = 'tfl_input' +KFL_LAYER_NAME = 'tfl_kronecker_factored_lattice' LATTICE_LAYER_NAME = 'tfl_lattice' LINEAR_LAYER_NAME = 'tfl_linear' OUTPUT_LINEAR_COMBINATION_LAYER_NAME = 'tfl_output_linear_combination' @@ -141,8 +144,19 @@ def _output_range(layer_output_range, model_config, feature_config=None): elif layer_output_range == LayerOutputRange.MODEL_OUTPUT: output_min = model_config.output_min output_max = model_config.output_max - output_init_min = np.min(model_config.output_initialization) - output_init_max = np.max(model_config.output_initialization) + # Note: due to the multiplicative nature of KroneckerFactoredLattice layers, + # the initialization min/max do not correspond directly to the output + # min/max. Thus we follow the same scheme as the KroneckerFactoredLattice + # lattice layer to properly initialize the kernel and scale such that + # the output does in fact respect the requested bounds. 
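+    # For all other model configs (including the 'all_vertices'
+    # parameterization), the initialization range is taken directly from
+    # model_config.output_initialization below.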
+ if ((isinstance(model_config, configs.CalibratedLatticeEnsembleConfig) or + isinstance(model_config, configs.CalibratedLatticeConfig)) and + model_config.parameterization == 'kronecker_factored'): + output_init_min, output_init_max = kfl_lib.default_init_params( + output_min, output_max) + else: + output_init_min = np.min(model_config.output_initialization) + output_init_max = np.max(model_config.output_initialization) elif layer_output_range == LayerOutputRange.INPUT_TO_FINAL_CALIBRATION: output_init_min = output_min = 0.0 output_init_max = output_max = 1.0 @@ -512,7 +526,13 @@ def build_lattice_layer(lattice_input, feature_configs, model_config, dtype: dtype Returns: - A `tfl.layers.Lattice` instance. + A `tfl.layers.Lattice` instance if `model_config.parameterization` is set to + `'all_vertices'` or a `tfl.layers.KroneckerFactoredLattice` instance if + set to `'kronecker_factored'`. + + Raises: + ValueError: If `model_config.parameterization` is not one of + `'all_vertices'` or `'kronecker_factored'`. """ layer_name = '{}_{}'.format(LATTICE_LAYER_NAME, submodel_index) @@ -558,28 +578,54 @@ def build_lattice_layer(lattice_input, feature_configs, model_config, monotonic_dominances = _dominance_constraints_from_feature_configs( feature_configs) - kernel_initializer = lattice_layer.LinearInitializer( - lattice_sizes=lattice_sizes, - monotonicities=lattice_monotonicities, - unimodalities=lattice_unimodalities, - output_min=output_init_min, - output_max=output_init_max) - return lattice_layer.Lattice( - lattice_sizes=lattice_sizes, - monotonicities=lattice_monotonicities, - unimodalities=lattice_unimodalities, - edgeworth_trusts=edgeworth_trusts, - trapezoid_trusts=trapezoid_trusts, - monotonic_dominances=monotonic_dominances, - output_min=output_min, - output_max=output_max, - clip_inputs=False, - interpolation=model_config.interpolation, - kernel_regularizer=lattice_regularizers, - kernel_initializer=kernel_initializer, - dtype=dtype, - name=layer_name)( - lattice_input) + if model_config.parameterization == 'all_vertices': + layer_name = '{}_{}'.format(LATTICE_LAYER_NAME, submodel_index) + kernel_initializer = lattice_layer.LinearInitializer( + lattice_sizes=lattice_sizes, + monotonicities=lattice_monotonicities, + unimodalities=lattice_unimodalities, + output_min=output_init_min, + output_max=output_init_max) + return lattice_layer.Lattice( + lattice_sizes=lattice_sizes, + monotonicities=lattice_monotonicities, + unimodalities=lattice_unimodalities, + edgeworth_trusts=edgeworth_trusts, + trapezoid_trusts=trapezoid_trusts, + monotonic_dominances=monotonic_dominances, + output_min=output_min, + output_max=output_max, + clip_inputs=False, + interpolation=model_config.interpolation, + kernel_regularizer=lattice_regularizers, + kernel_initializer=kernel_initializer, + dtype=dtype, + name=layer_name)( + lattice_input) + elif model_config.parameterization == 'kronecker_factored': + layer_name = '{}_{}'.format(KFL_LAYER_NAME, submodel_index) + kernel_initializer = kfll.KFLRandomMonotonicInitializer( + monotonicities=lattice_monotonicities, + init_min=output_init_min, + init_max=output_init_max, + seed=model_config.random_seed) + scale_initializer = kfll.ScaleInitializer( + output_min=output_min, output_max=output_max) + return kfll.KroneckerFactoredLattice( + lattice_sizes=lattice_sizes[0], + num_terms=model_config.num_terms, + monotonicities=lattice_monotonicities, + output_min=output_min, + output_max=output_max, + clip_inputs=False, + kernel_initializer=kernel_initializer, + 
        scale_initializer=scale_initializer,
+        dtype=dtype,
+        name=layer_name)(
+            lattice_input)
+  else:
+    raise ValueError('Unknown type of parameterization: {}'.format(
+        model_config.parameterization))
 
 
 def build_lattice_ensemble_layer(submodels_inputs, model_config, dtype):
@@ -634,6 +680,10 @@ def build_rtl_layer(calibration_outputs, model_config, submodel_index,
 
   Returns:
     A `tfl.layers.RTL` instance.
+
+  Raises:
+    ValueError: If `model_config.parameterization` is not one of
+      `'all_vertices'` or `'kronecker_factored'`.
   """
   layer_name = '{}_{}'.format(RTL_LAYER_NAME, submodel_index)
@@ -658,18 +708,26 @@ def build_rtl_layer(calibration_outputs, model_config, submodel_index,
       rtl_inputs['unconstrained'].append(calibration_output)
 
   lattice_size = model_config.feature_configs[0].lattice_size
-  kernel_initializer = lattice_layer.RandomMonotonicInitializer(
-      lattice_sizes=[lattice_size] * model_config.lattice_rank,
-      output_min=output_init_min,
-      output_max=output_init_max)
+  if model_config.parameterization == 'all_vertices':
+    kernel_initializer = 'random_monotonic_initializer'
+  elif model_config.parameterization == 'kronecker_factored':
+    kernel_initializer = 'kfl_random_monotonic_initializer'
+  else:
+    raise ValueError('Unknown type of parameterization: {}'.format(
+        model_config.parameterization))
   return rtl_layer.RTL(
       num_lattices=model_config.num_lattices,
       lattice_rank=model_config.lattice_rank,
      lattice_size=lattice_size,
       output_min=output_min,
       output_max=output_max,
+      init_min=output_init_min,
+      init_max=output_init_max,
+      random_seed=model_config.random_seed,
       clip_inputs=False,
       interpolation=model_config.interpolation,
+      parameterization=model_config.parameterization,
+      num_terms=model_config.num_terms,
       kernel_regularizer=lattice_regularizers,
       kernel_initializer=kernel_initializer,
       average_outputs=average_outputs,
@@ -981,6 +1039,10 @@ def construct_prefitting_model_config(model_config, feature_names=None):
 
   # Make a copy of the model config provided and set all pairs covered.
   prefitting_model_config = copy.deepcopy(model_config)
+  # Set parameterization of prefitting model to 'all_vertices' to extract
+  # crystals using a normal lattice, because we do not have laplacian/torsion
+  # regularizers for KFL. This should still extract good feature combinations.
+  prefitting_model_config.parameterization = 'all_vertices'
   _set_all_pairs_cover_lattices(
       prefitting_model_config=prefitting_model_config,
       feature_names=feature_names)
@@ -1284,107 +1346,235 @@ def set_crystals_lattice_ensemble(model_config,
   ] for lattice in lattices]
 
 
-def verify_config(model_config):
-  """Verifies that the model_config and feature_configs are fully specified.
+def _verify_ensemble_config(model_config):
+  """Verifies that an ensemble model and feature configs are properly specified.
 
   Args:
     model_config: Model configuration object describing model architecture.
       Should be one of the model configs in `tfl.configs`.
+
+  Raises:
+    ValueError: If `model_config.lattices` is set to 'rtl_layer' and
+      `model_config.num_lattices` is not specified.
+    ValueError: If `model_config.num_lattices < 2`.
+    ValueError: If `model_config.lattices` is set to 'rtl_layer' and
+      `lattice_size` is not the same for all features.
+    ValueError: If `model_config.lattices` is set to 'rtl_layer' and
+      there are features with unimodality constraints.
+    ValueError: If `model_config.lattices` is set to 'rtl_layer' and
+      there are features with trust constraints.
+    ValueError: If `model_config.lattices` is set to 'rtl_layer' and
+      there are features with dominance constraints.
+    ValueError: If `model_config.lattices` is set to 'rtl_layer' and
+      there are per-feature lattice regularizers.
+    ValueError: If `model_config.lattices` is not iterable or contains
+      non-string values.
+    ValueError: If `model_config.lattices` is not set to 'rtl_layer' or a fully
+      specified list of lists of feature names.
   """
-  if isinstance(model_config, configs.CalibratedLatticeEnsembleConfig):
+  if model_config.lattices == 'rtl_layer':
+    # RTL must have num_lattices specified and >= 2.
+    if model_config.num_lattices is None:
+      raise ValueError('model_config.num_lattices must be specified when '
+                       'model_config.lattices is set to \'rtl_layer\'.')
+    if model_config.num_lattices < 2:
+      raise ValueError(
+          'CalibratedLatticeEnsemble must have >= 2 lattices. For single '
+          'lattice models, use CalibratedLattice instead.')
+    # Check that all lattices sizes for all features are the same.
+    if any(feature_config.lattice_size !=
+           model_config.feature_configs[0].lattice_size
+           for feature_config in model_config.feature_configs):
+      raise ValueError('RTL Layer must have the same lattice size for all '
+                       'features.')
+    # Check that there are only monotonicity and bound constraints.
+    if any(
+        feature_config.unimodality != 'none' and feature_config.unimodality != 0
+        for feature_config in model_config.feature_configs):
+      raise ValueError(
+          'RTL Layer does not currently support unimodality constraints.')
+    if any(feature_config.reflects_trust_in is not None
+           for feature_config in model_config.feature_configs):
+      raise ValueError(
+          'RTL Layer does not currently support trust constraints.')
+    if any(feature_config.dominates is not None
+           for feature_config in model_config.feature_configs):
+      raise ValueError(
+          'RTL Layer does not currently support dominance constraints.')
+    # Check that there are no per-feature lattice regularizers.
+    for feature_config in model_config.feature_configs:
+      for regularizer_config in feature_config.regularizer_configs or []:
+        if not regularizer_config.name.startswith(
+            _INPUT_CALIB_REGULARIZER_PREFIX):
+          raise ValueError(
+              'RTL Layer does not currently support per-feature lattice '
+              'regularizers.')
+  elif isinstance(model_config.lattices, list):
     # Make sure there are more than one lattice. If not, tell user to use
     # CalibratedLattice instead.
-    if model_config.lattices == 'rtl_layer':
-      # RTL must have num_lattices specified and >= 2.
-      if model_config.num_lattices is None:
-        raise ValueError('model_config.num_lattices must be specified when '
-                         'model_config.lattices is set to \'rtl_layer\'.')
-      if model_config.num_lattices < 2:
-        raise ValueError(
-            'CalibratedLatticeEnsemble must have >= 2 lattices. For single '
-            'lattice models, use CalibratedLattice instead.')
-      # Check that all lattices sizes for all features are the same.
-      if any(feature_config.lattice_size !=
-             model_config.feature_configs[0].lattice_size
-             for feature_config in model_config.feature_configs):
-        raise ValueError('Lattice sizes must be the same for all features.')
-      # Check that there are only monotonicity and bound constraints.
- if any(feature_config.unimodality != 'none' - for feature_config in model_config.feature_configs): - raise ValueError( - 'RTL Layer does not currently support unimodality constraints.') - if any(feature_config.reflects_trust_in is not None - for feature_config in model_config.feature_configs): + if len(model_config.lattices) < 2: + raise ValueError( + 'CalibratedLatticeEnsemble must have >= 2 lattices. For single ' + 'lattice models, use CalibratedLattice instead.') + for lattice in model_config.lattices: + if (not np.iterable(lattice) or + any(not isinstance(x, str) for x in lattice)): raise ValueError( - 'RTL Layer does not currently support trust constraints.') - if any(feature_config.dominates is not None - for feature_config in model_config.feature_configs): + 'Lattices are not fully specified for ensemble config.') + else: + raise ValueError( + 'Lattices are not fully specified for ensemble config. Lattices must ' + 'be set to \'rtl_layer\' or be fully specified as a list of lists of ' + 'feature names.') + + +def _verify_kronecker_factored_config(model_config): + """Verifies that a kronecker_factored model_config is properly specified. + + Args: + model_config: Model configuration object describing model architecture. + Should be one of the model configs in `tfl.configs`. + + Raises: + ValueError: If there are lattice regularizers. + ValueError: If there are per-feature lattice regularizers. + ValueError: If there are unimodality constraints. + ValueError: If there are trust constraints. + ValueError: If there are dominance constraints. + """ + for regularizer_config in model_config.regularizer_configs or []: + if not regularizer_config.name.startswith(_INPUT_CALIB_REGULARIZER_PREFIX): + raise ValueError( + 'KroneckerFactoredLattice layer does not currently support ' + 'lattice regularizers.') + for feature_config in model_config.feature_configs: + for regularizer_config in feature_config.regularizer_configs or []: + if not regularizer_config.name.startswith( + _INPUT_CALIB_REGULARIZER_PREFIX): raise ValueError( - 'RTL Layer does not currently support dominance constraints.') - # Check that there are no per-feature lattice regularizers. - for feature_config in model_config.feature_configs: - for regularizer_config in feature_config.regularizer_configs or []: - if not regularizer_config.name.startswith( - _INPUT_CALIB_REGULARIZER_PREFIX): - raise ValueError( - 'RTL Layer does not currently support per-feature lattice ' - 'regularizers.') - elif isinstance(model_config.lattices, list): - if len(model_config.lattices) < 2: + 'KroneckerFactoredLattice layer does not currently support ' + 'per-feature lattice regularizers.') + # Check that all lattices sizes for all features are the same. + if any(feature_config.lattice_size != + model_config.feature_configs[0].lattice_size + for feature_config in model_config.feature_configs): + raise ValueError('KroneckerFactoredLattice layer must have the same ' + 'lattice size for all features.') + # Check that there are only monotonicity and bound constraints. 
+ if any( + feature_config.unimodality != 'none' and feature_config.unimodality != 0 + for feature_config in model_config.feature_configs): + raise ValueError( + 'KroneckerFactoredLattice layer does not currently support unimodality ' + 'constraints.') + if any(feature_config.reflects_trust_in is not None + for feature_config in model_config.feature_configs): + raise ValueError( + 'KroneckerFactoredLattice layer does not currently support trust ' + 'constraints.') + if any(feature_config.dominates is not None + for feature_config in model_config.feature_configs): + raise ValueError( + 'KroneckerFactoredLattice layer does not currently support dominance ' + 'constraints.') + + +def _verify_aggregate_function_config(model_config): + """Verifies that an aggregate function model_config is properly specified. + + Args: + model_config: Model configuration object describing model architecture. + Should be one of the model configs in `tfl.configs`. + + Raises: + ValueError: If `middle_dimension < 1`. + ValueError: If `model_config.middle_monotonicity` is not None and + `model_config.middle_calibration` is not True. + """ + if model_config.middle_dimension < 1: + raise ValueError('Middle dimension must be at least 1: {}'.format( + model_config.middle_dimension)) + if (model_config.middle_monotonicity is not None and + not model_config.middle_calibration): + raise ValueError( + 'middle_calibration must be true when middle_monotonicity is ' + 'specified.') + + +def _verify_feature_config(feature_config): + """Verifies that feature_config is properly specified. + + Args: + feature_config: Feature configuration object describing an input feature to + a model. Should be an instance of `tfl.configs.FeatureConfig`. + + Raises: + ValueError: If `feature_config.pwl_calibration_input_keypoints` is not + iterable or contains non-{int/float} values for a numerical feature. + ValueError: If `feature_config.monotonicity` is not an iterable for a + categorical feature. + ValueError: If any element in `feature_config.monotonicity` is not an + iterable for a categorical feature. + ValueError: If any value in any element in `feature_config.monotonicity` is + not an int for a categorical feature. + ValueError: If any value in any element in `feature_config.monotonicity` is + not in the range `[0, feature_config.num_buckets]` for a categorical + feature. + """ + if not feature_config.num_buckets: + # Validate PWL Calibration configuration. + if (not np.iterable(feature_config.pwl_calibration_input_keypoints) or + any(not isinstance(x, (int, float)) + for x in feature_config.pwl_calibration_input_keypoints)): + raise ValueError('Input keypoints are invalid for feature {}: {}'.format( + feature_config.name, feature_config.pwl_calibration_input_keypoints)) + elif feature_config.monotonicity and feature_config.monotonicity != 'none': + # Validate Categorical Calibration configuration. + if not np.iterable(feature_config.monotonicity): + raise ValueError('Monotonicity is not a list for feature {}: {}'.format( + feature_config.name, feature_config.monotonicity)) + for i, t in enumerate(feature_config.monotonicity): + if not np.iterable(t): raise ValueError( - 'CalibratedLatticeEnsemble must have >= 2 lattices. 
For single ' - 'lattice models, use CalibratedLattice instead.') - for lattice in model_config.lattices: - if (not np.iterable(lattice) or - any(not isinstance(x, str) for x in lattice)): + 'Element {} is not a list/tuple for feature {} monotonicty: {}' + .format(i, feature_config.name, t)) + for j, val in enumerate(t): + if not isinstance(val, int): raise ValueError( - 'Lattices are not fully specified for ensemble config.') - else: - raise ValueError( - 'Lattices are not fully specified for ensemble config. Lattices must ' - 'be set to \'rtl_layer\' or be fully specified as a list of lists of ' - 'feature names.') - if isinstance(model_config, configs.AggregateFunctionConfig): - if model_config.middle_dimension < 1: - raise ValueError('Middle dimension must be at least 1: {}'.format( - model_config.middle_dimension)) - if (model_config.middle_monotonicity is not None and - not model_config.middle_calibration): - raise ValueError( - 'middle_calibration must be true when middle_monotonicity is ' - 'specified.') + 'Element {} for list/tuple {} for feature {} monotonicity is ' + 'not an index: {}'.format(j, i, feature_config.name, val)) + if val < 0 or val >= feature_config.num_buckets: + raise ValueError( + 'Element {} for list/tuple {} for feature {} monotonicity is ' + 'an invalid index not in range [0, num_buckets - 1]: {}'.format( + j, i, feature_config.name, val)) + + +def verify_config(model_config): + """Verifies that the model_config and feature_configs are properly specified. + + Args: + model_config: Model configuration object describing model architecture. + Should be one of the model configs in `tfl.configs`. + + Raises: + ValueError: If `model_config.feature_configs` is None. + ValueError: If `model_config.output_initialization` is not iterable or + contains non-{int/float} values. + + """ if model_config.feature_configs is None: raise ValueError('Feature configs must be fully specified.') + if isinstance(model_config, configs.CalibratedLatticeEnsembleConfig): + _verify_ensemble_config(model_config) + if ((isinstance(model_config, configs.CalibratedLatticeEnsembleConfig) or + isinstance(model_config, configs.CalibratedLatticeConfig)) and + model_config.parameterization == 'kronecker_factored'): + _verify_kronecker_factored_config(model_config) + if isinstance(model_config, configs.AggregateFunctionConfig): + _verify_aggregate_function_config(model_config) for feature_config in model_config.feature_configs: - if not feature_config.num_buckets: - # Validate PWL Calibration configuration. - if (not np.iterable(feature_config.pwl_calibration_input_keypoints) or - any(not isinstance(x, (int, float)) - for x in feature_config.pwl_calibration_input_keypoints)): - raise ValueError( - 'Input keypoints are invalid for feature {}: {}'.format( - feature_config.name, - feature_config.pwl_calibration_input_keypoints)) - elif feature_config.monotonicity and feature_config.monotonicity != 'none': - # Validate Categorical Calibration configuration. 
- if not np.iterable(feature_config.monotonicity): - raise ValueError('Monotonicity is not a list for feature {}: {}'.format( - feature_config.name, feature_config.monotonicity)) - for i, t in enumerate(feature_config.monotonicity): - if not np.iterable(t): - raise ValueError( - 'Element {} is not a list/tuple for feature {} monotonicty: {}' - .format(i, feature_config.name, t)) - for j, val in enumerate(t): - if not isinstance(val, int): - raise ValueError( - 'Element {} for list/tuple {} for feature {} monotonicity is ' - 'not an index: {}'.format(j, i, feature_config.name, val)) - if val < 0 or val >= feature_config.num_buckets: - raise ValueError( - 'Element {} for list/tuple {} for feature {} monotonicity is ' - 'an invalid index not in range [0, num_buckets - 1]: {}'.format( - j, i, feature_config.name, val)) + _verify_feature_config(feature_config) if (not np.iterable(model_config.output_initialization) or any(not isinstance(x, (int, float)) for x in model_config.output_initialization)): diff --git a/tensorflow_lattice/python/premade_test.py b/tensorflow_lattice/python/premade_test.py index acf3fce..88d4630 100644 --- a/tensorflow_lattice/python/premade_test.py +++ b/tensorflow_lattice/python/premade_test.py @@ -22,6 +22,7 @@ import tempfile from absl import logging +from absl.testing import parameterized import numpy as np import pandas as pd import tensorflow as tf @@ -95,7 +96,7 @@ ] -class PremadeTest(tf.test.TestCase): +class PremadeTest(parameterized.TestCase, tf.test.TestCase): """Tests for TFL premade.""" def setUp(self): @@ -497,24 +498,44 @@ def testAggregateFromConfig(self): json.dumps(model.get_config(), sort_keys=True, cls=self.Encoder), json.dumps(loaded_model.get_config(), sort_keys=True, cls=self.Encoder)) - def testCalibratedLatticeEnsembleCrystals(self): + @parameterized.parameters( + ('hypercube', 'all_vertices', 0, 0.85), + ('simplex', 'all_vertices', 0, 0.89), + ('hypercube', 'kronecker_factored', 2, 0.82), + ('hypercube', 'kronecker_factored', 4, 0.82), + ) + def testCalibratedLatticeEnsembleCrystals(self, interpolation, + parameterization, num_terms, + expected_minimum_auc): # Construct model. self._ResetAllBackends() + crystals_feature_configs = copy.deepcopy(self.heart_feature_configs) model_config = configs.CalibratedLatticeEnsembleConfig( regularizer_configs=[ configs.RegularizerConfig(name='torsion', l2=1e-4), configs.RegularizerConfig(name='output_calib_hessian', l2=1e-4), ], - feature_configs=self.heart_feature_configs, + feature_configs=crystals_feature_configs, lattices='crystals', num_lattices=6, lattice_rank=5, + interpolation=interpolation, + parameterization=parameterization, + num_terms=num_terms, separate_calibrators=True, output_calibration=False, output_min=self.heart_min_label, output_max=self.heart_max_label - self.numerical_error_epsilon, output_initialization=[self.heart_min_label, self.heart_max_label], ) + if parameterization == 'kronecker_factored': + model_config.regularizer_configs = None + for feature_config in model_config.feature_configs: + feature_config.lattice_size = 2 + feature_config.unimodality = 'none' + feature_config.reflects_trust_in = None + feature_config.dominates = None + feature_config.regularizer_configs = None # Perform prefitting steps. 
prefitting_model_config = premade_lib.construct_prefitting_model_config( model_config) @@ -548,9 +569,16 @@ def testCalibratedLatticeEnsembleCrystals(self): self.heart_test_x, self.heart_test_y, verbose=False) logging.info('Calibrated lattice ensemble crystals classifier results:') logging.info(results) - self.assertGreater(results[1], 0.85) - - def testCalibratedLatticeEnsembleRTL(self): + self.assertGreater(results[1], expected_minimum_auc) + + @parameterized.parameters( + ('hypercube', 'all_vertices', 0, 0.85), + ('simplex', 'all_vertices', 0, 0.88), + ('hypercube', 'kronecker_factored', 2, 0.86), + ('hypercube', 'kronecker_factored', 4, 0.9), + ) + def testCalibratedLatticeEnsembleRTL(self, interpolation, parameterization, + num_terms, expected_minimum_auc): # Construct model. self._ResetAllBackends() rtl_feature_configs = copy.deepcopy(self.heart_feature_configs) @@ -569,12 +597,18 @@ def testCalibratedLatticeEnsembleRTL(self): lattices='rtl_layer', num_lattices=6, lattice_rank=5, + interpolation=interpolation, + parameterization=parameterization, + num_terms=num_terms, separate_calibrators=True, output_calibration=False, output_min=self.heart_min_label, output_max=self.heart_max_label - self.numerical_error_epsilon, output_initialization=[self.heart_min_label, self.heart_max_label], ) + # We must remove all regularization if using 'kronecker_factored'. + if parameterization == 'kronecker_factored': + model_config.regularizer_configs = None # Construct and train final model model = premade.CalibratedLatticeEnsemble(model_config) model.compile( @@ -591,15 +625,72 @@ def testCalibratedLatticeEnsembleRTL(self): self.heart_test_x, self.heart_test_y, verbose=False) logging.info('Calibrated lattice ensemble rtl classifier results:') logging.info(results) - self.assertGreater(results[1], 0.85) + self.assertGreater(results[1], expected_minimum_auc) + + @parameterized.parameters( + ('hypercube', 'all_vertices', 0, 0.81), + ('simplex', 'all_vertices', 0, 0.81), + ('hypercube', 'kronecker_factored', 2, 0.79), + ('hypercube', 'kronecker_factored', 4, 0.8), + ) + def testCalibratedLattice(self, interpolation, parameterization, num_terms, + expected_minimum_auc): + # Construct model configuration. 
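[Editor's note, not part of the patch] The `'kronecker_factored'` parameterization exercised by these tests carries the restrictions checked by `_verify_kronecker_factored_config` above: lattice sizes of 2, no unimodality, trust, or dominance constraints, and no regularizers. Condensed into one standalone config, that looks roughly like the following sketch; feature names and keypoints are hypothetical.

    import tensorflow_lattice as tfl

    # Sketch: a CalibratedLatticeConfig using 'kronecker_factored'. Note the
    # restrictions: lattice_size must be 2 and no regularizers or
    # unimodality/trust/dominance constraints are allowed.
    model_config = tfl.configs.CalibratedLatticeConfig(
        feature_configs=[
            tfl.configs.FeatureConfig(
                name='age',
                lattice_size=2,
                monotonicity='increasing',
                pwl_calibration_input_keypoints=[20.0, 40.0, 60.0, 80.0]),
            tfl.configs.FeatureConfig(
                name='chol',
                lattice_size=2,
                pwl_calibration_input_keypoints=[100.0, 200.0, 300.0]),
        ],
        parameterization='kronecker_factored',
        num_terms=2,
        output_min=0.0,
        output_max=1.0,
        output_initialization=[0.0, 1.0])
    tfl.premade_lib.verify_config(model_config)  # Should pass these checks.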
+ self._ResetAllBackends() + lattice_feature_configs = copy.deepcopy(self.heart_feature_configs[:5]) + model_config = configs.CalibratedLatticeConfig( + feature_configs=lattice_feature_configs, + interpolation=interpolation, + parameterization=parameterization, + num_terms=num_terms, + regularizer_configs=[ + configs.RegularizerConfig(name='torsion', l2=1e-4), + configs.RegularizerConfig(name='output_calib_hessian', l2=1e-4), + ], + output_min=self.heart_min_label, + output_max=self.heart_max_label, + output_calibration=False, + output_initialization=[self.heart_min_label, self.heart_max_label], + ) + if parameterization == 'kronecker_factored': + model_config.regularizer_configs = None + for feature_config in model_config.feature_configs: + feature_config.lattice_size = 2 + feature_config.unimodality = 'none' + feature_config.reflects_trust_in = None + feature_config.dominates = None + feature_config.regularizer_configs = None + # Construct and train final model + model = premade.CalibratedLattice(model_config) + model.compile( + loss=tf.keras.losses.BinaryCrossentropy(), + metrics=tf.keras.metrics.AUC(), + optimizer=tf.keras.optimizers.Adam(0.01)) + model.fit( + self.heart_train_x[:5], + self.heart_train_y, + batch_size=100, + epochs=200, + verbose=False) + results = model.evaluate( + self.heart_test_x[:5], self.heart_test_y, verbose=False) + logging.info('Calibrated lattice classifier results:') + logging.info(results) + self.assertGreater(results[1], expected_minimum_auc) - def testLatticeEnsembleH5FormatSaveLoad(self): + @parameterized.parameters( + ('all_vertices', 0), + ('kronecker_factored', 2), + ) + def testLatticeEnsembleH5FormatSaveLoad(self, parameterization, num_terms): model_config = configs.CalibratedLatticeEnsembleConfig( feature_configs=copy.deepcopy(feature_configs), lattices=[['numerical_1', 'categorical'], ['numerical_2', 'categorical']], num_lattices=2, lattice_rank=2, + parameterization=parameterization, + num_terms=num_terms, separate_calibrators=True, regularizer_configs=[ configs.RegularizerConfig('calib_hessian', l2=1e-3), @@ -610,6 +701,14 @@ def testLatticeEnsembleH5FormatSaveLoad(self): output_calibration=True, output_calibration_num_keypoints=5, output_initialization=[-1.0, 1.0]) + if parameterization == 'kronecker_factored': + model_config.regularizer_configs = None + for feature_config in model_config.feature_configs: + feature_config.lattice_size = 2 + feature_config.unimodality = 'none' + feature_config.reflects_trust_in = None + feature_config.dominates = None + feature_config.regularizer_configs = None model = premade.CalibratedLatticeEnsemble(model_config) # Compile and fit model. 
model.compile(loss='mse', optimizer=tf.keras.optimizers.Adam(0.1)) @@ -623,7 +722,11 @@ def testLatticeEnsembleH5FormatSaveLoad(self): model.predict(fake_data['eval_xs']), loaded_model.predict(fake_data['eval_xs'])) - def testLatticeEnsembleRTLH5FormatSaveLoad(self): + @parameterized.parameters( + ('all_vertices', 0), + ('kronecker_factored', 2), + ) + def testLatticeEnsembleRTLH5FormatSaveLoad(self, parameterization, num_terms): rtl_feature_configs = copy.deepcopy(feature_configs) for feature_config in rtl_feature_configs: feature_config.lattice_size = 2 @@ -636,6 +739,8 @@ def testLatticeEnsembleRTLH5FormatSaveLoad(self): lattices='rtl_layer', num_lattices=2, lattice_rank=2, + parameterization=parameterization, + num_terms=num_terms, separate_calibrators=True, regularizer_configs=[ configs.RegularizerConfig('calib_hessian', l2=1e-3), @@ -646,6 +751,8 @@ def testLatticeEnsembleRTLH5FormatSaveLoad(self): output_calibration=True, output_calibration_num_keypoints=5, output_initialization=[-1.0, 1.0]) + if parameterization == 'kronecker_factored': + model_config.regularizer_configs = None model = premade.CalibratedLatticeEnsemble(model_config) # Compile and fit model. model.compile(loss='mse', optimizer=tf.keras.optimizers.Adam(0.1)) @@ -659,9 +766,15 @@ def testLatticeEnsembleRTLH5FormatSaveLoad(self): model.predict(fake_data['eval_xs']), loaded_model.predict(fake_data['eval_xs'])) - def testLatticeH5FormatSaveLoad(self): + @parameterized.parameters( + ('all_vertices', 0), + ('kronecker_factored', 2), + ) + def testLatticeH5FormatSaveLoad(self, parameterization, num_terms): model_config = configs.CalibratedLatticeConfig( feature_configs=copy.deepcopy(feature_configs), + parameterization=parameterization, + num_terms=num_terms, regularizer_configs=[ configs.RegularizerConfig('calib_wrinkle', l2=1e-3), configs.RegularizerConfig('torsion', l2=1e-3), @@ -671,6 +784,14 @@ def testLatticeH5FormatSaveLoad(self): output_calibration=True, output_calibration_num_keypoints=6, output_initialization=[0.0, 1.0]) + if parameterization == 'kronecker_factored': + model_config.regularizer_configs = None + for feature_config in model_config.feature_configs: + feature_config.lattice_size = 2 + feature_config.unimodality = 'none' + feature_config.reflects_trust_in = None + feature_config.dominates = None + feature_config.regularizer_configs = None model = premade.CalibratedLattice(model_config) # Compile and fit model. model.compile(loss='mse', optimizer=tf.keras.optimizers.Adam(0.1)) diff --git a/tensorflow_lattice/python/pwl_calibration_layer.py b/tensorflow_lattice/python/pwl_calibration_layer.py index bfe7e3a..15bf687 100644 --- a/tensorflow_lattice/python/pwl_calibration_layer.py +++ b/tensorflow_lattice/python/pwl_calibration_layer.py @@ -37,6 +37,7 @@ MISSING_INPUT_VALUE_NAME = "missing_input_value" PWL_CALIBRATION_KERNEL_NAME = "pwl_calibration_kernel" PWL_CALIBRATION_MISSING_OUTPUT_NAME = "pwl_calibration_missing_output" +INTERPOLATION_LOGITS_NAME = "interpolation_logits" class PWLCalibration(keras.layers.Layer): @@ -109,6 +110,7 @@ def __init__(self, missing_output_value=None, num_projection_iterations=8, split_outputs=False, + input_keypoints_type="fixed", **kwargs): # pyformat: disable """Initializes an instance of `PWLCalibration`. @@ -177,6 +179,11 @@ class input shape description for more details. `tfl.pwl_calibration_lib.project_all_constraints` for more details. split_outputs: Whether to split the output tensor into a list of outputs for each unit. Ignored if units < 2. 
+ input_keypoints_type: One of "fixed" or "learned_interior". If + "learned_interior", keypoints are initialized to the values in + `input_keypoints` but then allowed to vary during training, with the + exception of the first and last keypoint location which are fixed. + Convexity can only be imposed with "fixed". **kwargs: Other args passed to `tf.keras.layers.Layer` initializer. Raises: @@ -204,6 +211,10 @@ class input shape description for more details. raise ValueError("'input_keypoints' can't be None") if monotonicity is None: raise ValueError("'monotonicity' can't be None. Did you mean '0'?") + if convexity not in ("none", + 0) and input_keypoints_type == "learned_interior": + raise ValueError("Cannot set input_keypoints_type to 'learned_interior'" + " and impose convexity constraints.") self.input_keypoints = input_keypoints self.units = units @@ -277,25 +288,40 @@ class input shape description for more details. self.missing_output_value = missing_output_value self.num_projection_iterations = num_projection_iterations self.split_outputs = split_outputs + self.input_keypoints_type = input_keypoints_type def build(self, input_shape): """Standard Keras build() method.""" input_keypoints = np.array(self.input_keypoints) # Don't need last keypoint for interpolation because we need only beginnings # of intervals. - self._interpolation_keypoints = tf.constant( - input_keypoints[:-1], - dtype=self.dtype, - name=INTERPOLATION_KEYPOINTS_NAME) - self._lengths = tf.constant( - input_keypoints[1:] - input_keypoints[:-1], - dtype=self.dtype, - name=LENGTHS_NAME) + if self.input_keypoints_type == "fixed": + self._interpolation_keypoints = tf.constant( + input_keypoints[:-1], + dtype=self.dtype, + name=INTERPOLATION_KEYPOINTS_NAME) + self._lengths = tf.constant( + input_keypoints[1:] - input_keypoints[:-1], + dtype=self.dtype, + name=LENGTHS_NAME) + else: + self._keypoint_min = input_keypoints[0] + self._keypoint_range = input_keypoints[-1] - input_keypoints[0] + # Logits are initialized such that they will recover the scaled keypoint + # gaps in input_keypoints. + initial_logits = np.log( + (input_keypoints[1:] - input_keypoints[:-1]) / self._keypoint_range) + tiled_logits = np.tile(initial_logits, self.units) + self.interpolation_logits = self.add_weight( + INTERPOLATION_LOGITS_NAME, + shape=[self.units, len(input_keypoints) - 1], + initializer=tf.constant_initializer(tiled_logits), + dtype=self.dtype) constraints = PWLCalibrationConstraints( monotonicity=self.monotonicity, convexity=self.convexity, - lengths=self._lengths, + lengths=self._lengths if self.input_keypoints_type == "fixed" else None, output_min=self.output_min, output_max=self.output_max, output_min_constraints=self._output_min_constraints, @@ -400,22 +426,36 @@ def call(self, inputs): raise ValueError("Shape of input tensor for PWLCalibration layer must be " "[-1, units] or [-1, 1]. It is: " + str(inputs.shape)) - if inputs.dtype != self._interpolation_keypoints.dtype: + if self.input_keypoints_type == "fixed": + keypoints_dtype = self._interpolation_keypoints.dtype + else: + keypoints_dtype = self.interpolation_logits.dtype + if inputs.dtype != keypoints_dtype: raise ValueError("dtype(%s) of input to PWLCalibration layer does not " "correspond to dtype(%s) of keypoints. You can enforce " "dtype of keypoints by explicitly providing 'dtype' " "parameter to layer constructor or by passing keypoints " "in such format which by default will be converted into " - "desired one." 
% - (inputs.dtype, self._interpolation_keypoints.dtype)) + "desired one." % (inputs.dtype, keypoints_dtype)) # Here is calibration. Everything else is handling of missing. - if inputs.shape[1] > 1: - # Add dimension to multi dim input to get shape [batch_size, units, 1]. - # Interpolation will have shape [batch_size, units, weights]. + if inputs.shape[1] > 1 or (self.input_keypoints_type == "learned_interior" + and self.units > 1): + # Interpolation will have shape [batch_size, units, weights] in these + # cases. To prepare for that, we add a dimension to the input here to get + # shape [batch_size, units, 1] or [batch_size, 1, 1] if 1d input. inputs_to_calibration = tf.expand_dims(inputs, -1) else: inputs_to_calibration = inputs + if self.input_keypoints_type == "learned_interior": + self._lengths = tf.multiply( + tf.nn.softmax(self.interpolation_logits, axis=1), + self._keypoint_range, + name=LENGTHS_NAME) + self._interpolation_keypoints = tf.add( + tf.cumsum(self._lengths, axis=1, exclusive=True), + self._keypoint_min, + name=INTERPOLATION_KEYPOINTS_NAME) interpolation_weights = pwl_calibration_lib.compute_interpolation_weights( inputs_to_calibration, self._interpolation_keypoints, self._lengths) if self.is_cyclic: @@ -428,7 +468,7 @@ def call(self, inputs): bias_and_heights = self.kernel # bias_and_heights has shape [weight, units]. - if inputs.shape[1] > 1: + if len(interpolation_weights.shape) > 2: # Multi dim input has interpolation shape [batch_size, units, weights]. result = tf.reduce_sum( interpolation_weights * tf.transpose(bias_and_heights), axis=-1) @@ -481,6 +521,7 @@ def get_config(self): "missing_input_value": self.missing_input_value, "num_projection_iterations": self.num_projection_iterations, "split_outputs": self.split_outputs, + "input_keypoints_type": self.input_keypoints_type, } # pyformat: disable config.update(super(PWLCalibration, self).get_config()) return config @@ -530,12 +571,36 @@ def assert_constraints(self, eps=1e-6): return asserts def keypoints_outputs(self): - """Returns tensor which corresponds to outputs of layer for keypoints.""" + """Returns tensor of keypoint outputs of shape [num_weights, num_units].""" kp_outputs = tf.cumsum(self.kernel) if self.is_cyclic: kp_outputs = tf.concat([kp_outputs, kp_outputs[0:1]], axis=0) return kp_outputs + def keypoints_inputs(self): + """Returns tensor of keypoint inputs of shape [num_weights, num_units].""" + # We don't store the last keypoint in self._interpolation_keypoints since + # it is not needed for training or evaluation, but we re-add it here to + # align with the keypoints_outputs function. 
+ if self.input_keypoints_type == "fixed": + all_keypoints = tf.concat([ + self._interpolation_keypoints, + self._interpolation_keypoints[-1:] + self._lengths[-1:] + ], + axis=0) + return tf.stack([all_keypoints] * self.units, axis=1) + else: + lengths = tf.nn.softmax( + self.interpolation_logits, axis=-1) * self._keypoint_range + interpolation_keypoints = tf.cumsum( + lengths, axis=-1, exclusive=True) + self._keypoint_min + all_keypoints = tf.concat([ + interpolation_keypoints, + interpolation_keypoints[:, -1:] + lengths[:, -1:] + ], + axis=1) + return tf.transpose(all_keypoints) + class UniformOutputInitializer(keras.initializers.Initializer): # pyformat: disable diff --git a/tensorflow_lattice/python/pwl_calibration_lib.py b/tensorflow_lattice/python/pwl_calibration_lib.py index 659bd22..b378366 100644 --- a/tensorflow_lattice/python/pwl_calibration_lib.py +++ b/tensorflow_lattice/python/pwl_calibration_lib.py @@ -97,21 +97,34 @@ def compute_interpolation_weights(inputs, keypoints, lengths): """Computes weights for PWL calibration. Args: - inputs: Tensor of shape: `(D0, D1, ..., DN, 1)` which represents inputs to - to the pwl function. A typical shape is: `(batch_size, 1)`. - keypoints: Rank-1 tensor of shape `(num_keypoints - 1)` which represents - left keypoint of pieces of piecewise linear function along X axis. - lengths: Rank-1 tensor of shape `(num_keypoints - 1)` which represents - lengths of pieces of piecewise linear function along X axis. + inputs: Tensor of shape: `(batch_size, 1)`, `(batch_size, units, 1)` or + `(batch_size, 1, 1)`. For multi-unit calibration, broadcasting will be used + if needed. + keypoints: Tensor of shape `(num_keypoints-1)` or `(units, num_keypoints-1)` + which represents left keypoint of pieces of piecewise linear function + along X axis. + lengths: Tensor of shape `(num_keypoints-1)` or `(units, num_keypoints-1)` + which represents lengths of pieces of piecewise linear function along X + axis. Returns: - Interpolation weights tensor of shape: `(D0, D1, ..., DN, num_keypoints)`. + Interpolation weights tensor of shape: `(batch_size, num_keypoints)` or + `(batch_size, units, num_keypoints)`. """ weights = (inputs - keypoints) / lengths weights = tf.minimum(weights, 1.0) weights = tf.maximum(weights, 0.0) - # Prepend 1.0 at the beginning to add bias unconditionally. - return tf.concat([tf.ones_like(inputs), weights], axis=-1) + # Prepend 1.0 at the beginning to add bias unconditionally. Worth testing + # different strategies, including those commented out, on different hardware. + if len(keypoints.shape) == 1: + return tf.concat([tf.ones_like(inputs), weights], axis=-1) + else: + shape = tf.concat([tf.shape(weights)[:-1], [1]], axis=0) + return tf.concat([tf.ones(shape), weights], axis=-1) + # return tf.concat([tf.ones_like(weights)[..., :1], weights], axis=-1) + # return tf.concat([tf.ones_like(weights[..., :1]), weights], axis=-1) + # paddings = [[0, 0]] * (len(weights.shape) - 1) + [[1, 0]] + # return tf.pad(weights, paddings, constant_values=1.) 
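[Editor's note, not part of the patch] To make the new `input_keypoints_type='learned_interior'` mode and the `keypoints_inputs()` accessor concrete, here is a minimal usage sketch. The keypoint grid, unit count, and input values are arbitrary; the first and last keypoints stay fixed while the interior ones are re-learned during training.

    import numpy as np
    import tensorflow as tf
    import tensorflow_lattice as tfl

    calibrator = tfl.layers.PWLCalibration(
        input_keypoints=np.linspace(0.0, 1.0, num=5),
        units=2,
        output_min=0.0,
        output_max=1.0,
        monotonicity='increasing',
        # Convexity constraints would be rejected in this mode.
        input_keypoints_type='learned_interior')

    x = np.random.uniform(0.0, 1.0, size=(32, 2)).astype(np.float32)
    y = calibrator(x)  # Shape (32, 2); building the layer creates the logits.

    # Per-unit keypoint positions and values, each of shape
    # [num_keypoints, units]; useful for plotting the learned calibration.
    print(calibrator.keypoints_inputs().shape)   # (5, 2)
    print(calibrator.keypoints_outputs().shape)  # (5, 2)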
def linear_initializer(shape, diff --git a/tensorflow_lattice/python/pwl_calibration_test.py b/tensorflow_lattice/python/pwl_calibration_test.py index cc7496a..4d24d89 100644 --- a/tensorflow_lattice/python/pwl_calibration_test.py +++ b/tensorflow_lattice/python/pwl_calibration_test.py @@ -152,6 +152,7 @@ def _SetDefaults(self, config): config.setdefault("constraint_assertion_eps", 1e-6) config.setdefault("model_dir", "/tmp/test_pwl_model_dir/") config.setdefault("dtype", tf.float32) + config.setdefault("input_keypoints_type", "fixed") if "input_keypoints" not in config: # If "input_keypoints" are provided - other params referred by code below @@ -229,7 +230,8 @@ def _TrainModel(self, config, plot_path=None): impute_missing=config["impute_missing"], missing_output_value=config["missing_output_value"], missing_input_value=config["missing_input_value"], - num_projection_iterations=config["num_projection_iterations"])) + num_projection_iterations=config["num_projection_iterations"], + input_keypoints_type=config["input_keypoints_type"])) if len(calibration_layers) == 1: if config["use_separate_missing"]: model.add( @@ -510,11 +512,15 @@ def testSonnetNumProjectionIterations(self): self._AssertSonnetEquivalentToKeras(config) @parameterized.parameters( - (1, False, 0.001022), - (3, False, 0.000543), - (3, True, 0.000987), + (1, False, 0.001022, "fixed"), + (3, False, 0.000543, "fixed"), + (3, True, 0.000987, "fixed"), + (1, False, 0.000393, "learned_interior"), + (3, False, 0.000427, "learned_interior"), + (3, True, 0.000577, "learned_interior"), ) - def testUnconstrainedNoMissingValue(self, units, one_d_input, expected_loss): + def testUnconstrainedNoMissingValue(self, units, one_d_input, expected_loss, + input_keypoints_type): if self._disable_all: return config = { @@ -532,6 +538,7 @@ def testUnconstrainedNoMissingValue(self, units, one_d_input, expected_loss): "input_max": 1.0, "output_min": None, "output_max": None, + "input_keypoints_type": input_keypoints_type, } loss = self._TrainModel(config) self.assertAlmostEqual(loss, expected_loss, delta=self._loss_eps) @@ -1368,6 +1375,43 @@ def testOutputShape(self): self.assertAllEqual(output_shape, pwl_1.compute_output_shape(input_a.shape)) self.assertAllEqual(output_shape, [o.shape for o in output]) + @parameterized.parameters(("fixed", 1, 1), ("fixed", 1, 2), ("fixed", 2, 2), + ("learned_interior", 1, 1), + ("learned_interior", 1, 2), + ("learned_interior", 2, 2)) + def testKeypointsInputs(self, input_keypoints_type, input_dims, output_units): + if self._disable_all: + return + + input_keypoints = [0, 0.5, 1] + expected_function_output = np.array([[0.0] * output_units, + [0.5] * output_units, + [1.0] * output_units]) + + # Check after layer build + pwl = keras_layer.PWLCalibration( + input_keypoints=input_keypoints, + units=output_units, + input_keypoints_type=input_keypoints_type) + pwl.build(input_shape=[10, input_dims]) + self.assertAllEqual(expected_function_output, pwl.keypoints_inputs()) + + # Check after Keras model compile + model = keras.models.Sequential() + model.add(tf.keras.layers.Input(shape=[input_dims], dtype=tf.float32)) + model.add(pwl) + model.compile(loss=keras.losses.mean_squared_error) + self.assertAllEqual(expected_function_output, pwl.keypoints_inputs()) + + # Check after Keras model fit; look for change in learned case. 
+ train_x = np.random.uniform(size=(10, input_dims)) + train_y = train_x[:, 0]**2 + model.fit(train_x, train_y, batch_size=len(train_x), epochs=5, verbose=0) + if input_keypoints_type == "fixed": + self.assertAllEqual(expected_function_output, pwl.keypoints_inputs()) + else: + self.assertNotAllEqual(expected_function_output, pwl.keypoints_inputs()) + if __name__ == "__main__": tf.test.main() diff --git a/tensorflow_lattice/python/rtl_layer.py b/tensorflow_lattice/python/rtl_layer.py index 34983e5..95cb772 100644 --- a/tensorflow_lattice/python/rtl_layer.py +++ b/tensorflow_lattice/python/rtl_layer.py @@ -28,6 +28,7 @@ import collections import itertools +from . import kronecker_factored_lattice_layer as kfll from . import lattice_layer from . import rtl_lib @@ -40,6 +41,7 @@ _MAX_RTL_SWAPS = 10000 _RTLInput = collections.namedtuple('_RTLInput', ['monotonicity', 'group', 'input_index']) +RTL_KFL_NAME = 'rtl_kronecker_factored_lattice' RTL_LATTICE_NAME = 'rtl_lattice' INPUTS_FOR_UNITS_PREFIX = 'inputs_for_lattice' RTL_CONCAT_NAME = 'rtl_concat' @@ -125,12 +127,16 @@ def __init__(self, lattice_size=2, output_min=None, output_max=None, + init_min=None, + init_max=None, separate_outputs=False, random_seed=42, num_projection_iterations=10, monotonic_at_every_step=True, clip_inputs=True, interpolation='hypercube', + parameterization='all_vertices', + num_terms=2, avoid_intragroup_interaction=True, kernel_initializer='random_monotonic_initializer', kernel_regularizer=None, @@ -145,12 +151,15 @@ def __init__(self, lattice_size: Number of lattice vertices per dimension (minimum is 2). output_min: None or lower bound of the output. output_max: None or upper bound of the output. + init_min: None or lower bound of lattice kernel initialization. + init_max: None or upper bound of lattice kernel initialization. separate_outputs: If set to true, the output will be a dict in the same format as the input to the layer, ready to be passed to another RTL layer. If false, the output will be a single tensor of shape (batch_size, num_lattices). See output shape for details. random_seed: Random seed for the randomized feature arrangement in the - ensemble. + ensemble. Also used for initialization of lattices using + `'kronecker_factored'` parameterization. num_projection_iterations: Number of iterations of Dykstra projections algorithm. Projection updates will be closer to a true projection (with respect to the L2 norm) with higher number of iterations. Increasing @@ -170,6 +179,35 @@ def __init__(self, 'simplex' uses d+1 parameters and thus scales better. For details see `tfl.lattice_lib.evaluate_with_simplex_interpolation` and `tfl.lattice_lib.evaluate_with_hypercube_interpolation`. + parameterization: The parameterization of the lattice function class to + use. A lattice function is uniquely determined by specifying its value + on every lattice vertex. A parameterization scheme is a mapping from a + vector of parameters to a multidimensional array of lattice vertex + values. It can be one of: + - String `'all_vertices'`: This is the "traditional" parameterization + that keeps one scalar parameter per lattice vertex where the mapping + is essentially the identity map. With this scheme, the number of + parameters scales exponentially with the number of inputs to the + lattice. The underlying lattices used will be `tfl.layers.Lattice` + layers. 
+ - String `'kronecker_factored'`: With this parameterization, for each + lattice input i we keep a collection of `num_terms` vectors each + having `feature_configs[0].lattice_size` entries (note that the + lattice size of the first feature will be used as the lattice size + for all other features as well). To obtain the tensor of lattice + vertex values, for `t=1,2,...,num_terms` we compute the outer + product of the `t'th` vector in each collection, multiply by a + per-term scale, and sum the resulting tensors. Finally, we add a + single shared bias parameter to each entry in the sum. With this + scheme, the number of parameters grows linearly with `lattice_rank` + (assuming lattice sizes and `num_terms` are held constant). + Currently, only monotonicity shape constraint and bound constraint + are supported for this scheme. Regularization is not currently + supported. The underlying lattices used will be + `tfl.layers.KroneckerFactoredLattice` layers. + num_terms: The number of terms in a lattice using `'kronecker_factored'` + parameterization. Ignored if parameterization is set to + `'all_vertices'`. avoid_intragroup_interaction: If set to true, the RTL algorithm will try to avoid having inputs from the same group in the same lattice. kernel_initializer: One of: @@ -179,13 +217,25 @@ def __init__(self, that minimum possible output is equal to output_min and maximum possible output is equal to output_max. See `tfl.lattice_layer.LinearInitializer` class docstring for more - details. + details. This initialization is not supported when using the + `'kronecker_factored'` parameterization. - `'random_monotonic_initializer'`: initialize parameters uniformly at random such that all parameters are monotonically increasing for each input. Parameters will be sampled uniformly at random from the range + `[init_min, init_max]` if specified, otherwise `[output_min, output_max]`. See `tfl.lattice_layer.RandomMonotonicInitializer` class docstring for - more details. + more details. This initialization is not supported when using the + `'kronecker_factored'` parameterization. + - `'kfl_random_monotonic_initializer'`: initialize parameters uniformly + at random such that all parameters are monotonically increasing for + each monotonic input. Parameters will be sampled uniformly at random + from the range `[init_min, init_max]` if specified. Otherwise, the + initialization range will be algorithmically determined depending on + output_{min/max}. See `tfl.layers.KroneckerFactoredLattice` and + `tfl.kronecker_factored_lattice.KFLRandomMonotonicInitializer` class + docstrings for more details. This initialization is not supported when + using `'all_vertices'` parameterization. kernel_regularizer: None or a single element or a list of following: - Tuple `('torsion', l1, l2)` or List `['torsion', l1, l2]` where l1 and l2 represent corresponding regularization amount for graph Torsion @@ -203,26 +253,34 @@ def __init__(self, Raises: ValueError: If layer hyperparameters are invalid. + ValueError: If `parameterization` is not one of `'all_vertices'` or + `'kronecker_factored'`. 
""" # pyformat: enable rtl_lib.verify_hyperparameters( lattice_size=lattice_size, output_min=output_min, output_max=output_max, - kernel_regularizer=kernel_regularizer, - interpolation=interpolation) + interpolation=interpolation, + parameterization=parameterization, + kernel_initializer=kernel_initializer, + kernel_regularizer=kernel_regularizer) super(RTL, self).__init__(**kwargs) self.num_lattices = num_lattices self.lattice_rank = lattice_rank self.lattice_size = lattice_size self.output_min = output_min self.output_max = output_max + self.init_min = init_min + self.init_max = init_max self.separate_outputs = separate_outputs self.random_seed = random_seed self.num_projection_iterations = num_projection_iterations self.monotonic_at_every_step = monotonic_at_every_step self.clip_inputs = clip_inputs self.interpolation = interpolation + self.parameterization = parameterization + self.num_terms = num_terms self.avoid_intragroup_interaction = avoid_intragroup_interaction self.kernel_initializer = kernel_initializer self.kernel_regularizer = kernel_regularizer @@ -257,21 +315,57 @@ def build(self, input_shape): inputs_for_units, dtype=tf.int32, name=inputs_for_units_name) ]): units = len(inputs_for_units) - layer_name = '{}_{}'.format(RTL_LATTICE_NAME, monotonicities_str) - self._lattice_layers[str(monotonicities)] = lattice_layer.Lattice( - lattice_sizes=[self.lattice_size] * self.lattice_rank, - units=units, - monotonicities=monotonicities, - output_min=self.output_min, - output_max=self.output_max, - num_projection_iterations=self.num_projection_iterations, - monotonic_at_every_step=self.monotonic_at_every_step, - clip_inputs=self.clip_inputs, - interpolation=self.interpolation, - kernel_initializer=self.kernel_initializer, - kernel_regularizer=kernel_regularizer, - name=layer_name, - ) + if self.parameterization == 'all_vertices': + layer_name = '{}_{}'.format(RTL_LATTICE_NAME, monotonicities_str) + lattice_sizes = [self.lattice_size] * self.lattice_rank + kernel_initializer = lattice_layer.create_kernel_initializer( + kernel_initializer_id=self.kernel_initializer, + lattice_sizes=lattice_sizes, + monotonicities=monotonicities, + output_min=self.output_min, + output_max=self.output_max, + unimodalities=None, + joint_unimodalities=None, + init_min=self.init_min, + init_max=self.init_max) + self._lattice_layers[str(monotonicities)] = lattice_layer.Lattice( + lattice_sizes=lattice_sizes, + units=units, + monotonicities=monotonicities, + output_min=self.output_min, + output_max=self.output_max, + num_projection_iterations=self.num_projection_iterations, + monotonic_at_every_step=self.monotonic_at_every_step, + clip_inputs=self.clip_inputs, + interpolation=self.interpolation, + kernel_initializer=kernel_initializer, + kernel_regularizer=kernel_regularizer, + name=layer_name, + ) + elif self.parameterization == 'kronecker_factored': + layer_name = '{}_{}'.format(RTL_KFL_NAME, monotonicities_str) + kernel_initializer = kfll.create_kernel_initializer( + kernel_initializer_id=self.kernel_initializer, + monotonicities=monotonicities, + output_min=self.output_min, + output_max=self.output_max, + init_min=self.init_min, + init_max=self.init_max) + self._lattice_layers[str( + monotonicities)] = kfll.KroneckerFactoredLattice( + lattice_sizes=self.lattice_size, + units=units, + num_terms=self.num_terms, + monotonicities=monotonicities, + output_min=self.output_min, + output_max=self.output_max, + clip_inputs=self.clip_inputs, + kernel_initializer=kernel_initializer, + 
scale_initializer='scale_initializer', + name=layer_name) + else: + raise ValueError('Unknown type of parameterization: {}'.format( + self.parameterization)) super(RTL, self).build(input_shape) def call(self, x, **kwargs): @@ -358,12 +452,16 @@ def get_config(self): 'lattice_size': self.lattice_size, 'output_min': self.output_min, 'output_max': self.output_max, + 'init_min': self.init_min, + 'init_max': self.init_max, 'separate_outputs': self.separate_outputs, 'random_seed': self.random_seed, 'num_projection_iterations': self.num_projection_iterations, 'monotonic_at_every_step': self.monotonic_at_every_step, 'clip_inputs': self.clip_inputs, 'interpolation': self.interpolation, + 'parameterization': self.parameterization, + 'num_terms': self.num_terms, 'avoid_intragroup_interaction': self.avoid_intragroup_interaction, 'kernel_initializer': self.kernel_initializer, 'kernel_regularizer': self.kernel_regularizer, diff --git a/tensorflow_lattice/python/rtl_lib.py b/tensorflow_lattice/python/rtl_lib.py index 4aa0842..b7b5113 100644 --- a/tensorflow_lattice/python/rtl_lib.py +++ b/tensorflow_lattice/python/rtl_lib.py @@ -24,8 +24,10 @@ def verify_hyperparameters(lattice_size, input_shape=None, output_min=None, output_max=None, - kernel_regularizer=None, - interpolation="hypercube"): + interpolation="hypercube", + parameterization="all_vertices", + kernel_initializer=None, + kernel_regularizer=None): """Verifies that all given hyperparameters are consistent. See `tfl.layers.RTL` class level comment for detailed description of @@ -36,17 +38,24 @@ def verify_hyperparameters(lattice_size, input_shape: Shape of layer input. output_min: Minimum output of `RTL` layer. output_max: Maximum output of `RTL` layer. - kernel_regularizer: Regularizers to check against. interpolation: One of 'simplex' or 'hypercube' interpolation. + parameterization: One of 'all_vertices' or 'kronecker_factored' + parameterizations. + kernel_initializer: Initizlier to check against. + kernel_regularizer: Regularizers to check against. Raises: ValueError: If lattice_size < 2. KeyError: If input_shape is a dict with incorrect keys. ValueError: If output_min >= output_max. + ValueError: If interpolation is not one of 'simplex' or 'hypercube'. + ValueError: If parameterization is 'kronecker_factored' and + kernel_initializer is 'linear_initializer'. + ValueError: If parameterization is 'kronecker_factored' and + kernel_regularizer is not None. ValueError: If kernel_regularizer contains a tuple with len != 3. ValueError: If kernel_regularizer contains a tuple with non-float l1 value. ValueError: If kernel_regularizer contains a tuple with non-flaot l2 value. - ValueError: If interpolation is not one of 'simplex' or 'hypercube'. """ if lattice_size < 2: @@ -66,6 +75,24 @@ def verify_hyperparameters(lattice_size, "'output_min': %f, 'output_max': %f" % (output_min, output_max)) + if interpolation not in ["hypercube", "simplex"]: + raise ValueError("RTL interpolation type should be either 'simplex' " + "or 'hypercube': %s" % interpolation) + + if (parameterization == "kronecker_factored" and + kernel_initializer == "linear_initializer"): + raise ValueError("'kronecker_factored' parameterization does not currently " + "support linear iniitalization. 'parameterization': %s, " + "'kernel_initializer': %s" % + (parameterization, kernel_initializer)) + + if (parameterization == "kronecker_factored" and + kernel_regularizer is not None): + raise ValueError("'kronecker_factored' parameterization does not currently " + "support regularization. 
'parameterization': %s, " + "'kernel_regularizer': %s" % + (parameterization, kernel_regularizer)) + if kernel_regularizer: if isinstance(kernel_regularizer, list): regularizers = kernel_regularizer @@ -84,7 +111,3 @@ def verify_hyperparameters(lattice_size, raise ValueError( "Regularizer l2 must be a single float. Given: {}".format( type(l2))) - - if interpolation not in ["hypercube", "simplex"]: - raise ValueError("RTL interpolation type should be either 'simplex' " - "or 'hypercube': %s" % interpolation) diff --git a/tensorflow_lattice/python/visualization.py b/tensorflow_lattice/python/visualization.py index 969452e..7667302 100644 --- a/tensorflow_lattice/python/visualization.py +++ b/tensorflow_lattice/python/visualization.py @@ -283,6 +283,8 @@ def _node_name(node): return 'Linear' if isinstance(node, model_info.LatticeNode): return 'Lattice' + if isinstance(node, model_info.KroneckerFactoredLatticeNode): + return 'KroneckerFactoredLattice' if isinstance(node, model_info.MeanNode): return 'Average' return str(type(node))