forked from ray-project/ray
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[air/docs] checkpoints (ray-project#25901)
- Loading branch information
1 parent
1abe908
commit 92efc85
Showing
7 changed files
with
253 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
Checkpoints | ||
=========== | ||
|
||
Checkpoints are the common format for models that are used across different components of the Ray AI Runtime. | ||
|
||
.. image:: images/checkpoints.jpg | ||
|
||
What exactly is a checkpoint? | ||
----------------------------- | ||
|
||
The Checkpoint object is a serializable reference to a model. The model can represented in one of three ways: | ||
|
||
- a directory located on local (on-disk) storage | ||
- a directory located on external storage (e.g. cloud storage) | ||
- an in-memory dictionary | ||
|
||
The flexibility provided in the Checkpoint model representation is useful in distributed environments, | ||
where you may want to recreate the same model on multiple nodes in your Ray cluster for inference | ||
or across different Ray clusters. | ||
|
||
|
||
Creating a checkpoint | ||
--------------------- | ||
|
||
There are two ways of generating a checkpoint. | ||
|
||
The first way is to generate it from a pretrained model. Each framework that AIR supports has a ``to_air_checkpoint`` method that can be used to generate an AIR checkpoint: | ||
|
||
.. literalinclude:: doc_code/checkpoint_usage.py | ||
:language: python | ||
:start-after: __checkpoint_quick_start__ | ||
:end-before: __checkpoint_quick_end__ | ||
|
||
|
||
Another way is to retrieve it from the results of a Trainer or a Tuner. | ||
|
||
.. literalinclude:: doc_code/checkpoint_usage.py | ||
:language: python | ||
:start-after: __use_trainer_checkpoint_start__ | ||
:end-before: __use_trainer_checkpoint_end__ | ||
|
||
What can I do with a checkpoint? | ||
-------------------------------- | ||
|
||
Checkpoints can be used to instantiate a :class:`Predictor`, :class:`BatchPredictor`, or :class:`PredictorDeployment`. | ||
Upon usage, the model held by the Checkpoint will be instantiated in memory and used for inference. | ||
|
||
Below is an example using a checkpoint in the :class:`BatchPredictor` for scalable batch inference: | ||
|
||
.. literalinclude:: doc_code/checkpoint_usage.py | ||
:language: python | ||
:start-after: __batch_pred_start__ | ||
:end-before: __batch_pred_end__ | ||
|
||
Below is an example using a checkpoint in a service for online inference via :class:`PredictorDeployment`: | ||
|
||
.. literalinclude:: doc_code/checkpoint_usage.py | ||
:language: python | ||
:start-after: __online_inference_start__ | ||
:end-before: __online_inference_end__ | ||
|
||
The Checkpoint object has methods to translate between different checkpoint storage locations. | ||
With this flexibility, Checkpoint objects can be serialized and used in different contexts | ||
(e.g., on a different process or a different machine): | ||
|
||
.. literalinclude:: doc_code/checkpoint_usage.py | ||
:language: python | ||
:start-after: __basic_checkpoint_start__ | ||
:end-before: __basic_checkpoint_end__ | ||
|
||
|
||
Example: Using Checkpoints with MLflow | ||
-------------------------------------- | ||
|
||
MLflow has its own `checkpoint format <https://www.mlflow.org/docs/latest/models.html>`__ called the "MLflow Model". It is a standard format for packaging machine learning models that can be used in a variety of downstream tools. | ||
|
||
Below is an example of using MLflow models as a Ray AIR Checkpoint. | ||
|
||
.. literalinclude:: doc_code/checkpoint_mlflow.py | ||
:language: python | ||
:start-after: __mlflow_checkpoint_start__ | ||
:end-before: __mlflow_checkpoint_end__ | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
# flake8: noqa | ||
# isort: skip_file | ||
|
||
# __mlflow_checkpoint_start__ | ||
from ray.air.checkpoint import Checkpoint | ||
from sklearn.ensemble import RandomForestClassifier | ||
import mlflow.sklearn | ||
|
||
# Create an sklearn classifier | ||
clf = RandomForestClassifier(max_depth=7, random_state=0) | ||
# ... e.g. train model with clf.fit() | ||
# Save model using MLflow | ||
mlflow.sklearn.save_model(clf, "model_directory") | ||
|
||
# Create checkpoint object from path | ||
checkpoint = Checkpoint.from_directory("model_directory") | ||
|
||
# Write it to some other directory | ||
checkpoint.to_directory("other_directory") | ||
# You can also use `checkpoint.to_uri/from_uri` to | ||
# read from/write to cloud storage | ||
|
||
# We can now use MLflow to re-load the model | ||
clf = mlflow.sklearn.load_model("other_directory") | ||
|
||
# It is guaranteed that the original data was recovered | ||
assert isinstance(clf, RandomForestClassifier) | ||
# __mlflow_checkpoint_end__ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,130 @@ | ||
# flake8: noqa | ||
# isort: skip_file | ||
|
||
# __checkpoint_quick_start__ | ||
from ray.train.tensorflow import to_air_checkpoint | ||
import tensorflow as tf | ||
|
||
# This can be a trained model. | ||
def build_model() -> tf.keras.Model: | ||
model = tf.keras.Sequential( | ||
[ | ||
tf.keras.layers.InputLayer(input_shape=(1,)), | ||
tf.keras.layers.Dense(1), | ||
] | ||
) | ||
return model | ||
|
||
|
||
model = build_model() | ||
|
||
checkpoint = to_air_checkpoint(model) | ||
# __checkpoint_quick_end__ | ||
|
||
|
||
# __use_trainer_checkpoint_start__ | ||
import pandas as pd | ||
import ray | ||
from ray.air import train_test_split | ||
from ray.train.xgboost import XGBoostTrainer | ||
|
||
|
||
bc_df = pd.read_csv( | ||
"https://air-example-data.s3.us-east-2.amazonaws.com/breast_cancer.csv" | ||
) | ||
dataset = ray.data.from_pandas(bc_df) | ||
# Optionally, read directly from s3 | ||
# dataset = ray.data.read_csv("s3://air-example-data/breast_cancer.csv") | ||
|
||
# Split data into train and validation. | ||
train_dataset, valid_dataset = train_test_split(dataset, test_size=0.3) | ||
|
||
trainer = XGBoostTrainer( | ||
scaling_config={"num_workers": 2}, | ||
label_column="target", | ||
params={ | ||
"objective": "binary:logistic", | ||
"eval_metric": ["logloss", "error"], | ||
}, | ||
datasets={"train": train_dataset}, | ||
num_boost_round=5, | ||
) | ||
|
||
result = trainer.fit() | ||
checkpoint = result.checkpoint | ||
# __use_trainer_checkpoint_end__ | ||
|
||
# __batch_pred_start__ | ||
from ray.train.batch_predictor import BatchPredictor | ||
from ray.train.xgboost import XGBoostPredictor | ||
|
||
# Create a test dataset by dropping the target column. | ||
test_dataset = valid_dataset.map_batches( | ||
lambda df: df.drop("target", axis=1), batch_format="pandas" | ||
) | ||
|
||
batch_predictor = BatchPredictor.from_checkpoint(checkpoint, XGBoostPredictor) | ||
|
||
# Bulk batch prediction. | ||
batch_predictor.predict(test_dataset) | ||
# __batch_pred_end__ | ||
|
||
|
||
# __online_inference_start__ | ||
import requests | ||
from fastapi import Request | ||
import pandas as pd | ||
|
||
from ray import serve | ||
from ray.serve import PredictorDeployment | ||
from ray.serve.http_adapters import json_request | ||
|
||
|
||
async def adapter(request: Request): | ||
content = await request.json() | ||
print(content) | ||
return pd.DataFrame.from_dict(content) | ||
|
||
|
||
serve.start(detached=True) | ||
deployment = PredictorDeployment.options(name="XGBoostService") | ||
|
||
deployment.deploy( | ||
XGBoostPredictor, checkpoint, batching_params=False, http_adapter=adapter | ||
) | ||
|
||
print(deployment.url) | ||
|
||
sample_input = test_dataset.take(1) | ||
sample_input = dict(sample_input[0]) | ||
|
||
output = requests.post(deployment.url, json=[sample_input]).json() | ||
print(output) | ||
# __online_inference_end__ | ||
|
||
# __basic_checkpoint_start__ | ||
from ray.air.checkpoint import Checkpoint | ||
|
||
# Create checkpoint data dict | ||
checkpoint_data = {"data": 123} | ||
|
||
# Create checkpoint object from data | ||
checkpoint = Checkpoint.from_dict(checkpoint_data) | ||
|
||
# Save checkpoint to a directory on the file system. | ||
path = checkpoint.to_directory() | ||
|
||
# This path can then be passed around, | ||
# # e.g. to a different function or a different script. | ||
# You can also use `checkpoint.to_uri/from_uri` to | ||
# read from/write to cloud storage | ||
|
||
# In another function or script, recover Checkpoint object from path | ||
checkpoint = Checkpoint.from_directory(path) | ||
|
||
# Convert into dictionary again | ||
recovered_data = checkpoint.to_dict() | ||
|
||
# It is guaranteed that the original data has been recovered | ||
assert recovered_data == checkpoint_data | ||
# __basic_checkpoint_end__ |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters