
Enable pickle of model with TensorFlow 2.11 #1040

Merged

Conversation

@oliverholworthy oliverholworthy commented Mar 28, 2023

Supports #1016

Goals ⚽

Enable pickle of model with TensorFlow 2.11
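For context, a minimal usage sketch of what this enables. The schema, dataset, and block layout below are placeholders for illustration, not code from this PR:

```python
import pickle

import merlin.models.tf as mm

# `schema` and `train_ds` are placeholders for a real Merlin schema/dataset.
model = mm.Model(
    mm.InputBlock(schema),
    mm.MLPBlock([64, 32]),
    mm.BinaryClassificationTask("click"),
)
model.compile(optimizer="adam")
model.fit(train_ds, batch_size=128, epochs=1)

# This round-trip is what fails on TensorFlow/Keras 2.11 without this change.
restored = pickle.loads(pickle.dumps(model))
```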

Implementation Details 🚧

  • TensorFlow/Keras 2.11 enabled the v3 saving format, and it now appears to be expected that the from_config method builds the model, so that every layer is instantiated with the variables it needs.
  • Moved the creation of the should_compute_train_metrics_for_batch variable to the __init__ method of Model so that the variable is created correctly when reloading a model.

Testing Details 🔍

  • Existing test test_pickle was failing with TensorFlow 2.11

This is required for model reloading to work correctly. Otherwise there
is a mismatch between the reloaded model and the variables it expects.
@oliverholworthy oliverholworthy added the chore Maintenance for the repository label Mar 28, 2023
@oliverholworthy oliverholworthy added this to the Merlin 23.03 milestone Mar 28, 2023
@oliverholworthy oliverholworthy self-assigned this Mar 28, 2023
@github-actions

Documentation preview

https://nvidia-merlin.github.io/models/review/pr-1040

@edknv edknv mentioned this pull request Mar 28, 2023
super(BaseModel, self).__init__(**kwargs)

# Initializing model control flags controlled by MetricsComputeCallback()
self._should_compute_train_metrics_for_batch = tf.Variable(
Member Author

This needs to be moved to the __init__ method (from the compile method) so that this variable is created when reloading the model.
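Roughly the shape of the change described here; this is a sketch only, and the exact tf.Variable arguments are assumed rather than copied from the diff:

```python
import tensorflow as tf


class BaseModel(tf.keras.Model):
    def __init__(self, **kwargs):
        super(BaseModel, self).__init__(**kwargs)
        # Created in __init__ rather than in compile, so the variable already
        # exists when Keras rebuilds the model from its config during reload.
        self._should_compute_train_metrics_for_batch = tf.Variable(
            True, trainable=False, name="should_compute_train_metrics_for_batch"
        )
```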

@@ -1343,6 +1347,9 @@ def fit(
x = _maybe_convert_merlin_dataset(x, batch_size, **kwargs)
self._maybe_set_schema(x)

if hasattr(x, "batch_size"):
self._batch_size = x.batch_size
Member Author

We need to save the batch size used during fit so that when we reload the model we pass it inputs of the same shape. Some tests of the two-tower model required this, for example. This might indicate something to investigate further; I'm not sure the batch size we use when re-loading the model should actually matter.
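For illustration, a rough sketch of how a helper like get_sample_inputs could turn the input schema plus the recorded batch size into dummy inputs. The helper name and the column handling below are simplified assumptions, not the actual implementation:

```python
import tensorflow as tf


def sample_inputs_from_schema(input_schema, batch_size):
    """Build zero-valued sample inputs matching an input schema (sketch)."""
    inputs = {}
    for column in input_schema:
        # Scalar columns get shape (batch_size, 1); list/sequence columns
        # would also need a sequence length (omitted here for brevity).
        inputs[column.name] = tf.zeros((batch_size, 1), dtype=tf.float32)
    return inputs
```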


inputs = model.get_sample_inputs(batch_size=batch_size)
if inputs:
model(inputs)
Member Author

This is the important part of this PR. We're calling the model with some sample data that matches the input schema of the model. This has the side effect of building all the layers (creating all the relevant variables), so that the variables can be reloaded correctly with the new v3 Keras saving_lib.
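Conceptually, the from_config flow being described looks something like the sketch below (simplified; the saving_lib condition discussed later in this thread is omitted, and the stand-in class is only illustrative):

```python
import tensorflow as tf


class Model(tf.keras.Model):  # simplified stand-in for the Merlin Model class
    @classmethod
    def from_config(cls, config, custom_objects=None):
        model = super().from_config(config, custom_objects=custom_objects)

        # Calling the model once on sample data matching its input schema
        # builds every layer (creating all variables), so the new v3 Keras
        # saving_lib can restore weights into variables of the right shape.
        inputs = model.get_sample_inputs(batch_size=getattr(model, "_batch_size", None))
        if inputs:
            model(inputs)

        return model
```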

@@ -63,6 +63,8 @@ def compute_output_shape(self, input_shape):
col_schema_shape = self.schema[name].shape
if col_schema_shape.is_list:
max_seq_length = col_schema_shape.dims[1].max
if max_seq_length is not None:
max_seq_length = int(max_seq_length)
Member Author

When serializing/deserializing the schema, we get back a float value for the shape dimension. If/when we fix that in core, this line can be removed.

output = mm.ContrastiveOutput(
DotProduct(),
post=ContrastiveSampleWeight(
pos_class_weight=tf.random.uniform(shape=(1000,)),
neg_class_weight=tf.random.uniform(shape=(1000,)),
pos_class_weight=tf.random.uniform(shape=(item_id_cardinality,)),
Member Author

This test was failing intermittently: if you were unlucky enough to draw an item id equal to the max value of 1000, it was out of range for the hard-coded weight tensors of shape (1000,).

@@ -251,7 +251,7 @@ def test_with_model(self, run_eagerly, music_streaming_data):
layer,
tf.keras.layers.Dense(1),
mm.BinaryClassificationTask("click"),
schema=music_streaming_data.schema,
schema=music_streaming_data.schema.select_by_name("item_recency"),
Member Author

This test started failing after the change to Model.from_config because we were passing the model a schema that was inconsistent with what it actually expected.


@oliverholworthy
Member Author

I've got the tests passing locally now. The GitHub Actions jobs are stuck in a queued state, though.

@karlhigley
Contributor

I think it's just taking time to catch up with all the actions that got queued by previous pushes. You can go into the Actions tab and cancel those to speed it up.

@oliverholworthy
Member Author

> I've got the tests passing locally now.

I said I'd got this working. However, this test in tests/unit/tf/examples/test_06_advanced_own_architecture.py is now failing again. I realised that the notebook test was fixed accidentally, as a result of model.save somehow swallowing errors raised in Model.from_config. After the last 2 commits to correctly create the inputs, the call to model(inputs) no longer raises, which causes the reload to fail later when it finds variables with different shapes.

Comment on lines +1879 to +1881
if self.input_schema is not None:
inputs = {}
for column in self.input_schema:
Contributor

Maybe the issue is schema filtering/propagation? If, for example, I change these lines to go through self.schema instead of input_schema, i.e.,

        if self.schema is not None:
            inputs = {}
            for column in self.schema:

and change the notebook to pass sub_schema to the model:

model = mm.Model(deep_dlrm_interaction, binary_task, schema=sub_schema)

I can get all tests to succeed.

Member Author

Unfortunately, the test passing is not equivalent to it running correctly. Using self.schema instead of self.input_schema results in an exception being thrown inside the model's from_config method where we call model(inputs), because we end up passing the target rating_binary as an input, which the model does not accept. That exception is caught somewhere in TensorFlow/Keras, which seems to result in essentially skipping this model call and reverting to the previous mechanism for building the model during load.

Member Author

It turns out that this new v3 saving lib in Keras seems to be used only during pickle, and not for the regular SavedModel export we do with model.save(). I've added a condition to apply this model build in from_config, based on a threadlocal flag that is set in the saving lib.

8e55be1
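For reference, a sketch of the kind of check being described; _SAVING_V3_ENABLED is a private Keras detail, so its exact shape here is an assumption rather than a stable API:

```python
def _is_keras_v3_saving_in_progress():
    """Best-effort check for whether the new Keras v3 saving path is active."""
    try:
        from keras.saving.experimental import saving_lib  # Keras 2.11
    except ImportError:
        return False
    flag = getattr(saving_lib, "_SAVING_V3_ENABLED", None)
    # Assumed to be a threading.local whose `value` attribute is set while
    # saving_lib is serializing/deserializing a model (e.g. during pickle).
    return bool(getattr(flag, "value", False))
```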

Member Author

This means that most models will be pickleable, although we only explicitly test this in one test, test_pickle. However, there is at least one kind of model that isn't pickleable (the one in the "define your own architecture" notebook).

Member Author

This same saving_lib that is being used for pickle is now the main saving lib in 2.12, so I think we'll have some more work to do to get it working reliably for all models.

Contributor

This seems like a good solution. I believe it's the same with upstream Keras models, i.e., most but not all Keras models are picklable, but the recommended way is to use the save() method rather than pickle.

Contributor

It looks like test_pickle is failing after 8e55be1 and a916914 :(

Member Author

Looks like the GPU CI was testing against TensorFlow/Keras 2.12, where the saving_lib moved from keras.saving.experimental.saving_lib to keras.saving.saving_lib. I've added an extra import for this and it seems to work. This feels somewhat brittle, in the sense that it could break easily in future versions when the temporary _SAVING_V3_ENABLED variable is removed, unless we have some way to detect whether the model is being saved with the new Keras format (used for pickle) or with the SavedModel format.
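A minimal sketch of the version-tolerant import described above:

```python
try:
    # Keras 2.11: the v3 saving lib lives under the experimental namespace.
    from keras.saving.experimental import saving_lib
except ImportError:
    # Keras 2.12+: it moved to keras.saving.saving_lib.
    from keras.saving import saving_lib
```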

Member Author

Perhaps we can figure out whether we'd like to try supporting the Keras native format properly. It does seem to have uncovered some bugs in various model layers, so from that perspective it seems worthwhile.

@oliverholworthy oliverholworthy merged commit 210aade into NVIDIA-Merlin:main Mar 31, 2023