Enable pickle of model with TensorFlow 2.11 #1040
Conversation
This is required for model reloading to work correctly. Otherwise there is a mismatch between the reloaded model and the variables it expects.
```python
super(BaseModel, self).__init__(**kwargs)

# Initializing model control flags controlled by MetricsComputeCallback()
self._should_compute_train_metrics_for_batch = tf.Variable(
```
This needs to be moved to the `__init__` method (from the `compile` method) so that this variable is created when reloading the model.
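The reasoning can be illustrated with a plain-Python analogy (class and attribute names below are hypothetical sketches, not Merlin code): `from_config`-style reconstruction runs `__init__` but never `compile`, so any state created in `compile` is missing on a reloaded instance.

```python
# Hypothetical sketch: why the variable must live in __init__, not compile().
class LazyFlagModel:
    def compile(self):
        # Flag created lazily; a from_config-style reload never calls
        # compile(), so a reloaded instance is missing this attribute.
        self.should_compute_train_metrics_for_batch = True

class EagerFlagModel:
    def __init__(self):
        # Flag created eagerly, so it exists on every instance,
        # reloaded or not.
        self.should_compute_train_metrics_for_batch = True

reloaded_lazy = LazyFlagModel()    # simulates a reload: only __init__ runs
reloaded_eager = EagerFlagModel()
assert not hasattr(reloaded_lazy, "should_compute_train_metrics_for_batch")
assert hasattr(reloaded_eager, "should_compute_train_metrics_for_batch")
```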
```diff
@@ -1343,6 +1347,9 @@ def fit(
     x = _maybe_convert_merlin_dataset(x, batch_size, **kwargs)
     self._maybe_set_schema(x)

+    if hasattr(x, "batch_size"):
+        self._batch_size = x.batch_size
```
We need to save the batch size used during fit so that, when we reload the model, we pass it inputs of the same shape. Some tests of the two-tower model required this, for example. This might indicate something to investigate further; I'm not sure the batch size should matter when re-loading the model.
merlin/models/tf/models/base.py
```python
inputs = model.get_sample_inputs(batch_size=batch_size)
if inputs:
    model(inputs)
```
This is the important part of this PR. We're calling the model with some sample data that matches the model's input schema. This has the side-effect of building all the layers (creating all the relevant variables), so that the variables can be reloaded correctly with the new v3 Keras saving_lib.
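As a rough analogy in plain Python (hypothetical class, not Merlin or Keras code; real Keras layers defer variable creation similarly), calling a deferred-build layer once with sample inputs is what forces its weights into existence so a loader has something to restore:

```python
# Hypothetical sketch of Keras-style deferred building: weights do not
# exist until the layer is first called with real inputs.
class LazyDense:
    def __init__(self, units):
        self.units = units
        self.weights = None  # not created yet

    def build(self, input_dim):
        # Create the variables now that the input shape is known.
        self.weights = [[0.0] * self.units for _ in range(input_dim)]

    def __call__(self, inputs):
        if self.weights is None:
            self.build(len(inputs))
        # One output per unit: dot product of inputs with each weight column.
        return [sum(x * w for x, w in zip(inputs, col))
                for col in zip(*self.weights)]

layer = LazyDense(units=4)
assert layer.weights is None       # nothing for a loader to restore yet
layer([1.0, 2.0, 3.0])             # sample call builds all variables
assert len(layer.weights) == 3     # weights now exist and can be reloaded
```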
```diff
@@ -63,6 +63,8 @@ def compute_output_shape(self, input_shape):
     col_schema_shape = self.schema[name].shape
     if col_schema_shape.is_list:
         max_seq_length = col_schema_shape.dims[1].max
+        if max_seq_length is not None:
+            max_seq_length = int(max_seq_length)
```
When serializing/deserializing the schema, we get back a float value for the shape dimension. If/when we fix that in core, this line can be removed.
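A minimal sketch of the guard (the function name below is hypothetical; the real change is inline in `compute_output_shape`): after a schema round-trip the dimension can come back as a float, while TensorFlow shape APIs expect ints or `None`.

```python
def normalize_max_seq_length(max_seq_length):
    # Schema (de)serialization can turn an int dim like 20 into 20.0;
    # cast it back, but leave an unknown (None) dimension untouched.
    if max_seq_length is not None:
        max_seq_length = int(max_seq_length)
    return max_seq_length

assert normalize_max_seq_length(20.0) == 20
assert normalize_max_seq_length(None) is None
```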
```diff
 output = mm.ContrastiveOutput(
     DotProduct(),
     post=ContrastiveSampleWeight(
-        pos_class_weight=tf.random.uniform(shape=(1000,)),
-        neg_class_weight=tf.random.uniform(shape=(1000,)),
+        pos_class_weight=tf.random.uniform(shape=(item_id_cardinality,)),
```
This test was failing randomly: you could get unlucky and draw an item id equal to the max value of 1000, which falls outside a weight vector of shape `(1000,)`.
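A plain-Python sketch of the failure mode (sizes are illustrative, matching the 1000 in the test; the cardinality value of 1001 below is an assumption for ids 0..1000): a weight vector of length 1000 covers ids 0..999 only, so an id of exactly 1000 indexes out of bounds.

```python
import random

weights = [random.random() for _ in range(1000)]  # covers ids 0..999 only

# Drawing the max item id of 1000 indexes past the end of the vector:
out_of_bounds = False
try:
    weights[1000]
except IndexError:
    out_of_bounds = True
assert out_of_bounds

# Sizing by the item-id cardinality (assumed 1001 here: ids 0..1000)
# makes every possible id a valid index:
item_id_cardinality = 1001
weights = [random.random() for _ in range(item_id_cardinality)]
assert 0.0 <= weights[1000] < 1.0
```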
```diff
@@ -251,7 +251,7 @@ def test_with_model(self, run_eagerly, music_streaming_data):
     layer,
     tf.keras.layers.Dense(1),
     mm.BinaryClassificationTask("click"),
-    schema=music_streaming_data.schema,
+    schema=music_streaming_data.schema.select_by_name("item_recency"),
```
This test started failing after the change to `Model.from_config`, because we were passing an inconsistent schema to the model compared with what it expected.
I've got the tests passing locally now. The GitHub Actions jobs are now stuck in a queued state though.

I think it's just taking time to catch up with all the actions that got queued by previous pushes. You can go into the Actions tab and cancel those to speed it up.

I said I'd got this working. However, actually this test in …
```python
if self.input_schema is not None:
    inputs = {}
    for column in self.input_schema:
```
Maybe the issue is schema filtering/propagation? If, for example, I change these lines to go through `self.schema` instead of `input_schema`, i.e.:

```python
if self.schema is not None:
    inputs = {}
    for column in self.schema:
```

and change the notebook to use `sub_schema` in the model by passing `sub_schema`:

```python
model = mm.Model(deep_dlrm_interaction, binary_task, schema=sub_schema)
```

I can get all tests to succeed.
Unfortunately, the test passing is not equivalent to it running correctly. Using `self.schema` instead of `self.input_schema` results in an exception being thrown inside the model's `from_config` method, where we call `model(inputs)`. This is because we end up passing the target `rating_binary` as an input, which the model does not accept. That exception is caught somewhere in TensorFlow/Keras, which seems to result in essentially skipping this model call and reverting to the previous mechanism for building the model during load.
It turns out that this new saving v3 lib in Keras seems to be used only during pickle, not for a regular saved model produced with `model.save()`. I've added a condition to apply this model build in `from_config` based on the threadlocal that is set in there.
This means that most models will be pickleable, although we only explicitly test this in one test, `test_pickle`. However, there is at least one kind of model that isn't pickleable (the one in the "define your own architecture" notebook).
This same saving_lib that is being used for pickle is now the main saving lib in 2.12. So I think we'll have some more work to do to get that working reliably for all models.
This seems like a good solution. I believe it's the same with the upstream Keras models, i.e., most but not all Keras models are picklable, but the recommended way is to use the `save()` method, not pickle.
Looks like GPU CI was testing against tensorflow/keras 2.12, where the `saving_lib` moved from `keras.saving.experimental.saving_lib` to `keras.saving.saving_lib`. I've added an extra import for this and it seems to work. This feels kind of brittle in the sense that it could break easily in future versions when the temporary `_SAVING_V3_ENABLED` variable is removed, unless we have some way to detect whether the model is being saved with the new Keras format (used for pickle) vs the SavedModel format.
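The version-tolerant import can follow the usual first-importable-module pattern (the helper below is a generic sketch, not the PR's actual code; the keras paths in the usage comment are the two locations mentioned above):

```python
import importlib

def first_importable(*paths):
    """Return the first module in `paths` that imports successfully, else None."""
    for path in paths:
        try:
            return importlib.import_module(path)
        except ImportError:
            continue
    return None

# Hypothetical usage for the saving_lib move between keras 2.11 and 2.12:
# saving_lib = first_importable(
#     "keras.saving.saving_lib",               # keras >= 2.12
#     "keras.saving.experimental.saving_lib",  # keras 2.11
# )
```

A try/except around the two imports works equally well; the helper just keeps the fallback order in one place.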
Perhaps we can figure out if we'd like to try supporting the Keras native format properly; it does seem to have uncovered some bugs in various model layers, so from that perspective it seems worthwhile.
Supports #1016
Goals ⚽
Enable pickle of model with TensorFlow 2.11
Implementation Details 🚧
- `from_config` method builds the model, so that every layer is instantiated with the variables it needs.
- Moved `should_compute_train_metrics_for_batch` to the `__init__` method of `Model` so that it creates the variable correctly when reloading the model.

Testing Details 🔍
- `test_pickle` was failing with TensorFlow 2.11