
Add Evaluator support to update multiple accumulators #2894

Merged: 3 commits merged into deepjavalibrary:master on Dec 20, 2023

Conversation

petebankhead (Contributor)

Description

This PR proposes adding a method to the Evaluator abstract class:

    public void updateAccumulators(String[] keys, NDList labels, NDList predictions) {
        for (String key : keys) {
            updateAccumulator(key, labels, predictions);
        }
    }

Subclasses can then override this method to update their accumulators more efficiently.
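
For illustration, an overriding subclass might compute the expensive values once and then feed every accumulator from the same results. The following is a minimal sketch of the idea, not DJL's actual AbstractAccuracy implementation; the correctInstances and totalInstances counters are hypothetical Map<String, Long> fields invented for the example:

    @Override
    public void updateAccumulators(String[] keys, NDList labels, NDList predictions) {
        // Perform the comparison and the device-to-CPU transfer once...
        NDArray predicted = predictions.head().argMax(1);
        NDArray expected = labels.head().flatten().toType(DataType.INT64, false);
        long correct = predicted.eq(expected).countNonzero().getLong();
        long total = expected.size();
        // ...then update every accumulator from the same host-side values.
        for (String key : keys) {
            correctInstances.merge(key, correct, Long::sum); // hypothetical field
            totalInstances.merge(key, total, Long::sum); // hypothetical field
        }
    }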

The reason is that the use of EvaluatorTrainingListener can dominate training time, at least when using Apple Silicon + MPS with the recent PR #2873.

Part of the issue seems to be that updateAccumulator needs to be called multiple times for the different evaluators after each batch (e.g. accuracy and loss). This results in the same values being recalculated multiple times and transferred to the CPU repeatedly. The recalculation itself is quite fast, but the transfer to the CPU is slow.
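
To make this concrete, the listener-side call pattern before and after the change might look roughly like the following. This is a simplified sketch rather than the actual EvaluatorTrainingListener code, and the keys array stands in for the per-evaluator accumulator keys:

    // Before: one call per key, so each key repeats the computation
    // and its device-to-CPU transfer.
    for (Evaluator evaluator : trainer.getEvaluators()) {
        for (String key : keys) {
            evaluator.updateAccumulator(key, labels, predictions);
        }
    }

    // After: one call per evaluator, so an overriding subclass can
    // compute once and update all of its keys from the same values.
    for (Evaluator evaluator : trainer.getEvaluators()) {
        evaluator.updateAccumulators(keys, labels, predictions);
    }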

Example with MNIST

I see the following improvements when training with MNIST using MPS.

With the PR

08:21:32.873 [main] [INFO ] a.d.t.l.LoggingTrainingListener - Epoch 50 finished.
08:21:32.873 [main] [INFO ] a.d.t.l.LoggingTrainingListener - Train: Accuracy: 1.00, SoftmaxCrossEntropyLoss: 3.24E-03
08:21:32.873 [main] [INFO ] a.d.t.l.LoggingTrainingListener - Validate: Accuracy: 0.97, SoftmaxCrossEntropyLoss: 0.12
Time: 13.673

Without the PR

08:22:41.559 [main] [INFO ] a.d.t.l.LoggingTrainingListener - Epoch 50 finished.
08:22:41.559 [main] [INFO ] a.d.t.l.LoggingTrainingListener - Train: Accuracy: 1.00, SoftmaxCrossEntropyLoss: 3.24E-03
08:22:41.559 [main] [INFO ] a.d.t.l.LoggingTrainingListener - Validate: Accuracy: 0.97, SoftmaxCrossEntropyLoss: 0.12
Time: 17.899

Without logging

Removing TrainingListener.Defaults.logging(), I see

Time: 10.006

indicating that there is still a considerable overhead in the use of training listeners, but it is roughly halved with the changes here.

On the CPU

The MNIST example admittedly isn't the best benchmark, because it's much faster to use the CPU than MPS anyway. There are still modest improvements, though:

  • Previous behavior, with logging: Time: 6.059
  • With the PR, with logging: Time: 5.601

I see more substantial improvements with custom training for semantic segmentation, e.g. with U-Net, when MPS is much faster than using the CPU.

Without logging, the PR should have little or no effect:

  • Previous behavior, no logging: Time: 4.893
  • With the PR, no logging: Time: 4.974

Code

Adapted from the MNIST tutorial:

import ai.djl.Device;
import ai.djl.Model;
import ai.djl.basicdataset.cv.classification.Mnist;
import ai.djl.basicmodelzoo.basic.Mlp;
import ai.djl.engine.Engine;
import ai.djl.metric.Metrics;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.EasyTrain;
import ai.djl.training.Trainer;
import ai.djl.training.dataset.RandomAccessDataset;
import ai.djl.training.evaluator.Accuracy;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.Loss;

public class MNISTMinimal {

    public static void main(String[] args) {

        int batchSize = 1000;

        Device device = Device.of("mps", 0);
        try (NDManager manager = NDManager.newBaseManager(device, "PyTorch")) {

            Engine.getEngine("PyTorch").setRandomSeed(1243);

            Mnist mnist = Mnist.builder()
                    .optDevice(device)
                    .optManager(manager)
                    .setSampling(batchSize, true)
                    .build();

            mnist.prepare();

            Model model = Model.newInstance("mlp", device);
            model.setBlock(new Mlp(28 * 28, 10, new int[] {128, 64}));

            DefaultTrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
                    .optDevices(new Device[]{device})
                    .addEvaluator(new Accuracy()) // Use accuracy so we humans can understand how accurate the model is
                    .addTrainingListeners(TrainingListener.Defaults.logging());

            // Now that we have our training configuration, we should create a new trainer for our model
            Trainer trainer = model.newTrainer(config);

            trainer.setMetrics(new Metrics());
            trainer.initialize(new Shape(1, 28 * 28));

            // Deep learning is typically trained in epochs where each epoch trains the model on each item in the dataset once.
            int epoch = 50;

            long start = System.currentTimeMillis();
            // Split the dataset 60/40 into training and validation sets
            RandomAccessDataset[] split = mnist.randomSplit(6, 4);
            EasyTrain.fit(trainer, epoch, split[0], split[1]);
            long end = System.currentTimeMillis();

            System.out.println("Time: " + (end - start) / 1000.0);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }
}

Commit messages:

  • Improve the performance of EvaluatorTrainingListener by enabling evaluators to update multiple accumulators from the same labels and predictions, rather than needing to recompute values.
  • Aims to fix failing test
petebankhead (Contributor, Author)

I see from the failing test that my PR is more problematic than I realised. It may cause most existing subclasses of Loss or AbstractAccuracy (for example) to fail if they also override updateAccumulator, unless they also override updateAccumulators.

I've tried to fix all affected cases within DJL (there aren't very many) but I guess such a change could affect external subclasses that rely on the current behavior.

I think subclasses of Evaluator should be ok because of the default implementation of updateAccumulators.

I haven't been able to test with CUDA yet, but on MPS I'm seeing training times reduced by 20-40%. Is there a better way to achieve the performance improvements?
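
For illustration, here is a hypothetical external loss affected by the change. MyCustomLoss and its bookkeeping are invented for this example; the point is that a subclass overriding updateAccumulator now also needs to override updateAccumulators to keep its per-key behavior:

    import ai.djl.ndarray.NDArray;
    import ai.djl.ndarray.NDList;
    import ai.djl.training.loss.Loss;

    public class MyCustomLoss extends Loss {

        public MyCustomLoss() {
            super("MyCustomLoss");
        }

        @Override
        public NDArray evaluate(NDList labels, NDList predictions) {
            // Simple mean squared error, just to keep the example complete
            return labels.singletonOrThrow().sub(predictions.singletonOrThrow()).square().mean();
        }

        @Override
        public void updateAccumulator(String key, NDList labels, NDList predictions) {
            // Custom per-key bookkeeping would go here
            super.updateAccumulator(key, labels, predictions);
        }

        // Without this override, the batched implementation inherited from
        // Loss could bypass the custom updateAccumulator above.
        @Override
        public void updateAccumulators(String[] keys, NDList labels, NDList predictions) {
            for (String key : keys) {
                updateAccumulator(key, labels, predictions);
            }
        }
    }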

codecov-commenter commented Dec 17, 2023

Codecov Report

Attention: 1363 lines in your changes are missing coverage. Please review.

Comparison is base (bb5073f) 72.08% compared to head (f5959db) 72.27%.
Report is 935 commits behind head on master.

  File                                                      Patch %   Lines missing
  ...va/ai/djl/modality/nlp/generate/TextGenerator.java       2.81%   276
  .../java/ai/djl/modality/nlp/generate/SeqBatcher.java       0.75%   132
  ...ity/nlp/generate/ContrastiveSeqBatchScheduler.java       2.97%    98
  ...i/djl/modality/nlp/generate/SeqBatchScheduler.java       9.83%    55
  .../java/ai/djl/modality/cv/BufferedImageFactory.java      40.96%    47 (+2 partials)
  ...a/ai/djl/modality/nlp/generate/StepGeneration.java       2.04%    48
  api/src/main/java/ai/djl/ndarray/NDArray.java              43.42%    39 (+4 partials)
  ...n/java/ai/djl/modality/cv/output/CategoryMask.java      22.00%    39
  ...i/src/main/java/ai/djl/ndarray/NDArrayAdapter.java      71.21%    31 (+7 partials)
  .../cv/translator/SemanticSegmentationTranslator.java      37.50%    35
  ... and 76 more


Additional details and impacted files
@@             Coverage Diff              @@
##             master    #2894      +/-   ##
============================================
+ Coverage     72.08%   72.27%   +0.18%     
- Complexity     5126     7184    +2058     
============================================
  Files           473      708     +235     
  Lines         21970    32014   +10044     
  Branches       2351     3337     +986     
============================================
+ Hits          15838    23138    +7300     
- Misses         4925     7284    +2359     
- Partials       1207     1592     +385     


petebankhead (Contributor, Author)

Last one... I've tried on an old Windows computer with a GTX 1060 and PyTorch 2.0.1, and I see a ~15-20% improvement in performance (and ~45% from removing logging altogether, consistent with logging adding considerable overhead, as with MPS).

Previous

09:58:43.121 [main] [INFO ] a.d.t.l.LoggingTrainingListener - Epoch 50 finished.
09:58:43.121 [main] [INFO ] a.d.t.l.LoggingTrainingListener - Train: Accuracy: 1.00, SoftmaxCrossEntropyLoss: 3.09E-03
09:58:43.121 [main] [INFO ] a.d.t.l.LoggingTrainingListener - Validate: Accuracy: 0.97, SoftmaxCrossEntropyLoss: 0.12
Time: 10.782

With PR

09:59:48.296 [main] [INFO ] a.d.t.l.LoggingTrainingListener - Epoch 50 finished.
09:59:48.296 [main] [INFO ] a.d.t.l.LoggingTrainingListener - Train: Accuracy: 1.00, SoftmaxCrossEntropyLoss: 3.09E-03
09:59:48.296 [main] [INFO ] a.d.t.l.LoggingTrainingListener - Validate: Accuracy: 0.97, SoftmaxCrossEntropyLoss: 0.12
Time: 8.979

Without logging

Time: 5.894

Using the CPU of my tired old laptop, the PR reduces the training time slightly from 23.586 s to 21.54 s.

lanking520 (Member)

Thanks for your contribution! This is awesome. @zachgk, do you mind taking a look?

zachgk (Contributor) left a comment

This looks great! And I appreciate the thorough work.

We will have to document in the release notes that users will have to modify classes extending AbstractAccuracy, but I don't think that should be too common. One alternative might be to remove updateAccumulator() entirely. That would mean there won't be silent problems, but it would require all users to change their Evaluator and Loss classes. Overall, I think this is slightly better.

zachgk merged commit 1eb54c0 into deepjavalibrary:master on Dec 20, 2023
5 checks passed
petebankhead deleted the evaluators branch on December 20, 2023 at 06:21
frankfliu pushed a commit that referenced this pull request on Apr 26, 2024:
* Evaluator support to update multiple accumulators

Improve the performance of EvaluatorTrainingListener by enabling evaluators to update multiple accumulators from the same labels and predictions, rather than needing to recompute values.

* Fix formatting

* Update AbstractCompositeLoss.java

Aims to fix failing test