
Commit

Switches Keras object serialization to new logic and changes the public API for deserialize_keras_object/serialize_keras_object to the new functions.

PiperOrigin-RevId: 512685517
nkovela1 authored and tensorflower-gardener committed Feb 27, 2023
1 parent 9d8baa9 commit 7eb8ef2
Showing 48 changed files with 692 additions and 194 deletions.
49 changes: 43 additions & 6 deletions keras/activations.py
@@ -15,12 +15,16 @@
"""Built-in activation functions."""

import sys
import types

import tensorflow.compat.v2 as tf

import keras.layers.activation as activation_layers
from keras import backend
from keras.saving import object_registration
from keras.saving import serialization_lib
from keras.saving.legacy import serialization as legacy_serialization
from keras.saving.legacy.saved_model import utils as saved_model_utils
from keras.utils import generic_utils

# isort: off
@@ -544,7 +548,7 @@ def serialize(activation, use_legacy_format=False):
>>> tf.keras.activations.serialize('abcd')
Traceback (most recent call last):
...
ValueError: ('Cannot serialize', 'abcd')
ValueError: Unknown activation function 'abcd' cannot be serialized.
Raises:
ValueError: The input function is not a valid one.
@@ -558,8 +562,35 @@ def serialize(activation, use_legacy_format=False):
if use_legacy_format:
return legacy_serialization.serialize_keras_object(activation)

# To be replaced by new serialization_lib
return legacy_serialization.serialize_keras_object(activation)
fn_config = serialization_lib.serialize_keras_object(activation)
if (
not tf.__internal__.tf2.enabled()
or saved_model_utils.in_tf_saved_model_scope()
):
return fn_config
if "config" not in fn_config:
raise ValueError(
f"Unknown activation function '{activation}' cannot be "
"serialized due to invalid function name. Make sure to use "
"an activation name that matches the references defined in "
"activations.py or use `@keras.utils.register_keras_serializable` "
"for any custom activations. "
f"config={fn_config}"
)
if not isinstance(activation, types.FunctionType):
# Case for additional custom activations represented by objects
return fn_config
if (
isinstance(fn_config["config"], str)
and fn_config["config"] not in globals()
):
# Case for custom activation functions from external activations modules
fn_config["config"] = object_registration.get_registered_name(
activation
)
return fn_config
    # Case for keras.activations builtins (simply return name)
    return fn_config["config"]


# Add additional globals so that deserialize() can find these common activation
@@ -592,7 +623,7 @@ def deserialize(name, custom_objects=None, use_legacy_format=False):
>>> tf.keras.activations.deserialize('abcd')
Traceback (most recent call last):
...
ValueError: Unknown activation function:abcd
ValueError: Unknown activation function 'abcd' cannot be deserialized.
Raises:
ValueError: `Unknown activation function` if the input string does not
@@ -617,14 +648,20 @@ def deserialize(name, custom_objects=None, use_legacy_format=False):
printable_module_name="activation function",
)

# To be replaced by new serialization_lib
return legacy_serialization.deserialize_keras_object(
returned_fn = serialization_lib.deserialize_keras_object(
name,
module_objects=activation_functions,
custom_objects=custom_objects,
printable_module_name="activation function",
)

if isinstance(returned_fn, str):
raise ValueError(
f"Unknown activation function '{name}' cannot be deserialized."
)

return returned_fn


@keras_export("keras.activations.get")
@tf.__internal__.dispatch.add_dispatch_support
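Net effect of the new serialize()/deserialize() logic above, sketched under the assumption that TF2 is enabled and we are outside a SavedModel scope; `shifted_relu` is a hypothetical custom activation used only for illustration.

import tensorflow as tf

# Built-in activations now serialize to their plain string name.
tf.keras.activations.serialize(tf.keras.activations.relu)   # -> 'relu'
tf.keras.activations.deserialize('relu')                     # -> the relu function

# Custom activations should be registered; otherwise serialize() raises the
# "Unknown activation function ... cannot be serialized" error shown above.
@tf.keras.utils.register_keras_serializable(package="my_pkg")
def shifted_relu(x):    # hypothetical custom activation
    return tf.nn.relu(x) - 0.5

# For a registered custom function, the returned config keeps the registered
# name (e.g. 'my_pkg>shifted_relu') so it can be resolved on deserialization.
tf.keras.activations.serialize(shifted_relu)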
4 changes: 2 additions & 2 deletions keras/api/golden/v1/tensorflow.keras.utils.pbtxt
@@ -42,7 +42,7 @@ tf_module {
}
member_method {
name: "deserialize_keras_object"
argspec: "args=[\'identifier\', \'module_objects\', \'custom_objects\', \'printable_module_name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'object\'], "
argspec: "args=[\'config\', \'custom_objects\', \'safe_mode\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'True\'], "
}
member_method {
name: "disable_interactive_logging"
@@ -114,7 +114,7 @@ tf_module {
}
member_method {
name: "serialize_keras_object"
argspec: "args=[\'instance\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'obj\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "to_categorical"
4 changes: 2 additions & 2 deletions keras/api/golden/v2/tensorflow.keras.utils.pbtxt
@@ -54,7 +54,7 @@ tf_module {
}
member_method {
name: "deserialize_keras_object"
argspec: "args=[\'identifier\', \'module_objects\', \'custom_objects\', \'printable_module_name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'object\'], "
argspec: "args=[\'config\', \'custom_objects\', \'safe_mode\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'True\'], "
}
member_method {
name: "disable_interactive_logging"
@@ -130,7 +130,7 @@ tf_module {
}
member_method {
name: "serialize_keras_object"
argspec: "args=[\'instance\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'obj\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "set_random_seed"
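The v1 and v2 golden-API updates above capture the new public signatures: serialize_keras_object(obj) and deserialize_keras_object(config, custom_objects=None, safe_mode=True, **kwargs). A minimal round trip with a stock layer (exact config contents vary by version):

import tensorflow as tf

layer = tf.keras.layers.Dense(4, activation="relu")
config = tf.keras.utils.serialize_keras_object(layer)   # object -> config dict
restored = tf.keras.utils.deserialize_keras_object(
    config, custom_objects=None, safe_mode=True          # new keyword arguments
)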
6 changes: 3 additions & 3 deletions keras/constraints.py
@@ -20,8 +20,8 @@

from keras import backend
from keras.saving.legacy import serialization as legacy_serialization
from keras.saving.legacy.serialization import deserialize_keras_object
from keras.saving.legacy.serialization import serialize_keras_object
from keras.saving.serialization_lib import deserialize_keras_object
from keras.saving.serialization_lib import serialize_keras_object

# isort: off
from tensorflow.python.util.tf_export import keras_export
@@ -389,7 +389,7 @@ def get(identifier):
return deserialize(identifier, use_legacy_format=use_legacy_format)
elif isinstance(identifier, str):
config = {"class_name": str(identifier), "config": {}}
return deserialize(config)
return get(config)
elif callable(identifier):
return identifier
else:
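The get() change above routes a bare string identifier back through get() after wrapping it in a legacy-style config dict. A short sketch of the identifier forms this is expected to accept (behavior assumed equivalent for the other constraint aliases):

import tensorflow as tf

tf.keras.constraints.get("non_neg")                                # string alias
tf.keras.constraints.get({"class_name": "NonNeg", "config": {}})   # config dict
tf.keras.constraints.get(tf.keras.constraints.NonNeg())            # instance passes through
tf.keras.constraints.get(None)                                     # -> None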
6 changes: 5 additions & 1 deletion keras/engine/base_layer.py
@@ -2295,7 +2295,7 @@ def get_build_config(self):
if self._build_input_shape is not None:

def convert_tensorshapes(x):
if isinstance(x, tf.TensorShape):
if isinstance(x, tf.TensorShape) and x._dims:
return tuple(x.as_list())
return x

@@ -3608,6 +3608,10 @@ def _make_op(self, inputs):
# Recreate constant in graph to add distribution context.
value = tf.get_static_value(constant)
if value is not None:
if isinstance(value, dict):
value = serialization_lib.deserialize_keras_object(
value
)
constant = tf.constant(value, name=node_def.input[index])
inputs.insert(index, constant)
# TODO(b/183990973): We should drop or consolidate these private api
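Two small robustness fixes above: get_build_config() now skips TensorShapes of unknown rank (whose internal _dims is None), and _make_op() deserializes constants whose static value is a config dict. The shape guard, sketched:

import tensorflow as tf

tf.TensorShape([None, 32]).as_list()   # -> [None, 32]; safe to store in the build config
# tf.TensorShape(None).as_list()       # would raise: as_list() is not defined on an
                                       # unknown TensorShape, hence the `x._dims` check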
18 changes: 13 additions & 5 deletions keras/engine/functional.py
@@ -15,7 +15,6 @@

"""A `Network` is way to compose layers: the topological form of a `Model`."""


import collections
import copy
import itertools
@@ -33,6 +32,7 @@
from keras.engine import node as node_module
from keras.engine import training as training_lib
from keras.engine import training_utils
from keras.saving import serialization_lib
from keras.saving.legacy import serialization
from keras.saving.legacy.saved_model import json_utils
from keras.saving.legacy.saved_model import network_serialization
@@ -1265,6 +1265,10 @@ def _should_skip_first_node(layer):
# Networks that are constructed with an Input layer/shape start with a
# pre-existing node linking their input to output. This node is excluded
# from the network config.
if not hasattr(layer, "_self_tracked_trackables"):
# Special case for serialization of Functional models without
# defined input shape argument.
return isinstance(layer, Functional)
if layer._self_tracked_trackables:
return (
isinstance(layer, Functional)
@@ -1428,7 +1432,10 @@ def process_node(layer, node_data):
# Call layer on its inputs, thus creating the node
# and building the layer if needed.
if input_tensors is not None:
if not layer._preserve_input_structure_in_config:
if (
not hasattr(layer, "_preserve_input_structure_in_config")
or not layer._preserve_input_structure_in_config
):
input_tensors = base_layer_utils.unnest_if_single_tensor(
input_tensors
)
@@ -1546,10 +1553,11 @@ def get_network_config(network, serialize_layer_fn=None, config=None):
Returns:
Config dictionary.
"""
serialize_layer_fn = (
serialize_layer_fn or serialization.serialize_keras_object
)
config = config or {}
serialize_obj_fn = serialization_lib.serialize_keras_object
if "module" not in config:
serialize_obj_fn = serialization.serialize_keras_object
serialize_layer_fn = serialize_layer_fn or serialize_obj_fn
config["name"] = network.name
node_conversion_map = {}
for layer in network.layers:
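get_network_config() above now chooses the serializer per config: new-style configs written by serialization_lib carry a "module" key, while legacy configs do not. A tiny sketch of that marker (the dict contents are illustrative only):

new_style = {"module": "keras.layers", "class_name": "Dense", "config": {"units": 4}}
legacy_style = {"class_name": "Dense", "config": {"units": 4}}

print("module" not in legacy_style)   # True  -> fall back to legacy serialization
print("module" not in new_style)      # False -> use serialization_lib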
11 changes: 8 additions & 3 deletions keras/engine/sequential.py
@@ -25,7 +25,7 @@
from keras.engine import input_layer
from keras.engine import training
from keras.engine import training_utils
from keras.saving.legacy import serialization
from keras.saving import serialization_lib
from keras.saving.legacy.saved_model import model_serialization
from keras.utils import generic_utils
from keras.utils import layer_utils
@@ -454,7 +454,9 @@ def get_config(self):
# filtered out of `self.layers`). Note that
# `self._self_tracked_trackables` is managed by the tracking
# infrastructure and should not be used.
layer_configs.append(serialization.serialize_keras_object(layer))
layer_configs.append(
serialization_lib.serialize_keras_object(layer)
)
config = training.Model.get_config(self)
config["name"] = self.name
config["layers"] = copy.deepcopy(layer_configs)
@@ -473,8 +475,11 @@ def from_config(cls, config, custom_objects=None):
layer_configs = config
model = cls(name=name)
for layer_config in layer_configs:
use_legacy_format = "module" not in layer_config
layer = layer_module.deserialize(
layer_config, custom_objects=custom_objects
layer_config,
custom_objects=custom_objects,
use_legacy_format=use_legacy_format,
)
model.add(layer)

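With the Sequential changes above, get_config() emits new-style layer configs and from_config() re-detects the format per layer via the same "module" check, so both legacy and new configs should round-trip. A minimal example:

import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(2, input_shape=(3,))])
config = model.get_config()                     # layer configs via serialization_lib
restored = tf.keras.Sequential.from_config(config)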
6 changes: 3 additions & 3 deletions keras/feature_column/base_feature_layer.py
@@ -27,7 +27,7 @@
import tensorflow.compat.v2 as tf

from keras.engine.base_layer import Layer
from keras.saving.legacy import serialization
from keras.saving import serialization_lib


class _BaseFeaturesLayer(Layer):
@@ -130,7 +130,7 @@ def get_config(self):
for fc in self._feature_columns
]
config = {"feature_columns": column_configs}
config["partitioner"] = serialization.serialize_keras_object(
config["partitioner"] = serialization_lib.serialize_keras_object(
self._partitioner
)

@@ -147,7 +147,7 @@ def from_config(cls, config, custom_objects=None):
)
for c in config["feature_columns"]
]
config_cp["partitioner"] = serialization.deserialize_keras_object(
config_cp["partitioner"] = serialization_lib.deserialize_keras_object(
config["partitioner"], custom_objects
)

2 changes: 1 addition & 1 deletion keras/initializers/BUILD
@@ -22,7 +22,7 @@ py_library(
"//:expect_tensorflow_installed",
"//keras:backend",
"//keras/dtensor:utils",
"//keras/saving:serialization",
"//keras/saving:serialization_lib",
"//keras/utils:generic_utils",
"//keras/utils:tf_inspect",
],
11 changes: 5 additions & 6 deletions keras/initializers/__init__.py
@@ -20,6 +20,7 @@

from keras.initializers import initializers
from keras.initializers import initializers_v1
from keras.saving import serialization_lib
from keras.saving.legacy import serialization as legacy_serialization
from keras.utils import generic_utils
from keras.utils import tf_inspect as inspect
@@ -138,8 +139,7 @@ def serialize(initializer, use_legacy_format=False):
if use_legacy_format:
return legacy_serialization.serialize_keras_object(initializer)

# To be replaced by new serialization_lib
return legacy_serialization.serialize_keras_object(initializer)
return serialization_lib.serialize_keras_object(initializer)


@keras_export("keras.initializers.deserialize")
@@ -154,8 +154,7 @@ def deserialize(config, custom_objects=None, use_legacy_format=False):
printable_module_name="initializer",
)

# To be replaced by new serialization_lib
return legacy_serialization.deserialize_keras_object(
return serialization_lib.deserialize_keras_object(
config,
module_objects=LOCAL.ALL_OBJECTS,
custom_objects=custom_objects,
@@ -203,8 +202,8 @@ def get(identifier):
use_legacy_format = "module" not in identifier
return deserialize(identifier, use_legacy_format=use_legacy_format)
elif isinstance(identifier, str):
identifier = str(identifier)
return deserialize(identifier)
config = {"class_name": str(identifier), "config": {}}
return get(config)
elif callable(identifier):
if inspect.isclass(identifier):
identifier = identifier()
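serialize()/deserialize() for initializers now go through serialization_lib by default, and get() wraps bare strings into a config dict, mirroring constraints.get(). A quick round trip (illustrative; the exact config layout depends on the format in use):

import tensorflow as tf

init = tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.05)
config = tf.keras.initializers.serialize(init)       # new-style config dict
restored = tf.keras.initializers.deserialize(config)

tf.keras.initializers.get("glorot_uniform")          # strings still resolve to an instance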
11 changes: 11 additions & 0 deletions keras/initializers/initializers.py
@@ -21,6 +21,7 @@

from keras import backend
from keras.dtensor import utils
from keras.saving import serialization_lib

# isort: off
from tensorflow.python.util.tf_export import keras_export
@@ -267,6 +268,16 @@ def __call__(self, shape, dtype=None, **kwargs):
def get_config(self):
return {"value": self.value}

@classmethod
def from_config(cls, config):
config.pop("dtype", None)
if "value" in config:
if isinstance(config["value"], dict):
config["value"] = serialization_lib.deserialize_keras_object(
config["value"]
)
return cls(**config)


@keras_export(
"keras.initializers.RandomUniform",
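The new Constant.from_config above tolerates a serialized (dict-valued) "value" and drops a stray "dtype" key; a plain round trip is unchanged:

import tensorflow as tf

init = tf.keras.initializers.Constant(3.0)
cfg = init.get_config()                                    # {'value': 3.0}
restored = tf.keras.initializers.Constant.from_config(cfg)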
27 changes: 14 additions & 13 deletions keras/layers/core/core_test.py
@@ -136,20 +136,21 @@ def test_dropout_with_savemodel(self):
@test_combinations.run_all_keras_modes
class LambdaLayerTest(test_combinations.TestCase):
def test_lambda(self):
test_utils.layer_test(
keras.layers.Lambda,
kwargs={"function": lambda x: x + 1},
input_shape=(3, 2),
)
with SafeModeScope(safe_mode=False):
test_utils.layer_test(
keras.layers.Lambda,
kwargs={"function": lambda x: x + 1},
input_shape=(3, 2),
)

test_utils.layer_test(
keras.layers.Lambda,
kwargs={
"function": lambda x, a, b: x * a + b,
"arguments": {"a": 0.6, "b": 0.4},
},
input_shape=(3, 2),
)
test_utils.layer_test(
keras.layers.Lambda,
kwargs={
"function": lambda x, a, b: x * a + b,
"arguments": {"a": 0.6, "b": 0.4},
},
input_shape=(3, 2),
)

# test serialization with function
def f(x):
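The test now wraps the lambda-based cases in SafeModeScope(safe_mode=False) because the new serialization refuses to deserialize arbitrary Python lambdas by default. A hedged sketch of the same idea outside the test harness; the import path for SafeModeScope is assumed from this change's serialization_lib:

import keras
from keras.saving.serialization_lib import SafeModeScope  # assumed location

layer = keras.layers.Lambda(lambda x: x + 1)
config = layer.get_config()            # serializing the lambda itself is allowed
with SafeModeScope(safe_mode=False):
    # Reconstructing a layer from a serialized lambda is only permitted with
    # safe mode disabled; with the default safe mode it is expected to raise.
    restored = keras.layers.Lambda.from_config(config)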

