Minimal Build for On-Device Training #16326

Merged
merged 34 commits into from
Jun 22, 2023

Conversation

@baijumeswani baijumeswani commented Jun 12, 2023

🛠️ Changes in this pull request:

This pull request introduces two significant changes to the project:

  • Changing the on-device training checkpoint format: The current implementation stores the on-device training checkpoint as a sequence of tensors spread across multiple files inside a checkpoint folder, which can be inefficient in both storage and performance. In this PR, I have changed the checkpoint format to use a flatbuffer table so that the checkpoint is saved to a single file, giving a more compact and efficient representation. The changes for this are twofold:

    • Add the checkpoint flatbuffer schema from which the necessary checkpoint source files are generated.
    • Update the checkpoint saving and loading functionality to use the new format.
  • Adding support for the onnxruntime minimal build: To support scenarios where binary size is a constraint, I made changes so that the training build works with the minimal build.
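
To make the single-file idea concrete, here is a minimal sketch of packing several named tensors into one versioned byte blob and reading them back. It is not the flatbuffer schema from this PR: it uses Python's standard `struct` module instead of FlatBuffers, and the `MAGIC` identifier, `VERSION` value, byte layout, and the function names `pack_checkpoint` / `unpack_checkpoint` are all assumptions made up for illustration.

```python
import struct

MAGIC = b"CKPT"  # hypothetical 4-byte file identifier (assumption)
VERSION = 1      # hypothetical format version (assumption)
HEADER = "<4sII"  # magic, version, tensor count (little-endian, no padding)


def pack_checkpoint(tensors):
    """Serialize {name: [float, ...]} into one contiguous byte string."""
    chunks = [struct.pack(HEADER, MAGIC, VERSION, len(tensors))]
    for name, values in tensors.items():
        encoded = name.encode("utf-8")
        chunks.append(struct.pack("<I", len(encoded)))  # name length
        chunks.append(encoded)                          # name bytes
        chunks.append(struct.pack("<I", len(values)))   # element count
        chunks.append(struct.pack(f"<{len(values)}f", *values))  # float32 data
    return b"".join(chunks)


def unpack_checkpoint(data):
    """Parse bytes produced by pack_checkpoint back into {name: [float, ...]}."""
    magic, version, count = struct.unpack_from(HEADER, data, 0)
    if magic != MAGIC:
        raise ValueError("not a checkpoint blob")
    if version != VERSION:
        raise ValueError(f"unsupported checkpoint version {version}")
    offset = struct.calcsize(HEADER)
    tensors = {}
    for _ in range(count):
        (name_len,) = struct.unpack_from("<I", data, offset)
        offset += 4
        name = data[offset:offset + name_len].decode("utf-8")
        offset += name_len
        (num_values,) = struct.unpack_from("<I", data, offset)
        offset += 4
        tensors[name] = list(struct.unpack_from(f"<{num_values}f", data, offset))
        offset += 4 * num_values
    return tensors
```

Writing the result of `pack_checkpoint` to disk yields one file instead of a folder of per-tensor files, and the version field in the header lets a loader reject blobs it does not understand. A real flatbuffer table additionally allows fields to be read in place without a separate parsing pass, which this sketch does not attempt.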

🔍 Open Issues:

  • To extract the optimizer type, the existing implementation reloaded the ONNX optimizer model and parsed it. This is no longer possible, since the model can now be in either ONNX or ORT format. One idea is to do the same for the ORT format optimizer model; this needs some investigation.
  • Changes to the offline tooling to generate ORT format training artifacts.
  • End-to-end training example showcasing the use of the minimal training build.
  • Add support for exporting the model for inferencing in a minimal build.
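
For the optimizer type item, one possible direction is to separate "read the op types out of the optimizer model" from "decide which optimizer those op types imply", so that only the first step depends on the file format. The sketch below covers only the second step and assumes a format-specific loader (not shown, hypothetical) has already produced a list of op type strings. The op names in `OP_TO_OPTIMIZER`, the `OptimizerType` enum, and `infer_optimizer_type` are illustrative assumptions, not confirmed details of the implementation.

```python
from enum import Enum


class OptimizerType(Enum):
    ADAMW = "AdamW"
    SGD = "SGD"


# Assumed mapping from an optimizer op type to the optimizer it implies.
# The op names are illustrative; check the actual optimizer graph before relying on them.
OP_TO_OPTIMIZER = {
    "AdamWOptimizer": OptimizerType.ADAMW,
    "SGDOptimizerV2": OptimizerType.SGD,
}


def infer_optimizer_type(op_types):
    """Return the single optimizer type implied by the given op type strings.

    `op_types` is expected to come from a format-specific reader (ONNX or ORT),
    which is outside the scope of this sketch.
    """
    found = {OP_TO_OPTIMIZER[op] for op in op_types if op in OP_TO_OPTIMIZER}
    if len(found) != 1:
        raise ValueError(f"expected exactly one optimizer op, found {len(found)}")
    return found.pop()
```

With this split, supporting the ORT format would only require a second reader that yields op type strings; the mapping logic stays the same for both formats.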

@baijumeswani baijumeswani added the training label (issues related to ONNX Runtime training; typically submitted using template) Jun 12, 2023
@baijumeswani baijumeswani changed the title from "Baijumeswani/training minimal build" to "Minimal Build for On-Device Training" Jun 12, 2023
@baijumeswani baijumeswani requested a review from a team as a code owner June 13, 2023 22:32

@pengwa pengwa left a comment

Not very familiar with the flatbuffer format, so I just have a few general comments.

@pengwa pengwa left a comment

Minor comments.

@pengwa pengwa left a comment

I suggest renaming "requires_grad" to "requires_grad_params" so that the name better describes what it holds.

Since it is part of the schema, if we want to make this change, it is probably better to do it now (instead of bumping the version next time). Any thoughts?

Besides that, LGTM.

edgchen1 previously approved these changes Jun 22, 2023

@edgchen1 edgchen1 left a comment

a few comments, looks good overall

pengwa previously approved these changes Jun 22, 2023
@baijumeswani baijumeswani merged commit 10ba1e2 into main Jun 22, 2023
87 of 91 checks passed
@baijumeswani baijumeswani deleted the baijumeswani/training-minimal-build branch June 22, 2023 19:27
@baijumeswani

Thank you for the valuable feedback @edgchen1 @pengwa @skottmckay @askhade 😄

6 participants