This repository has been archived by the owner on May 1, 2023. It is now read-only.

Adding Early Exit branch

Neta Zmora edited this page Apr 12, 2020 · 2 revisions

Using the Early-Exit method requires attaching exit branches to an existing model. This can be done either by making static changes to the PyTorch model code, or by adding the branches dynamically at runtime.

In Distiller, distiller.EarlyExitMgr is a utility class that helps us make dynamic changes to the model, in order to attach exit branches. The EE implementation for the CIFAR10 ResNet architecture resnet_cifar_earlyexit.py is a good example of how this is done in Distiller:

  1. Define the exit branches.
  2. Attach the exit branches.
  3. Define the forward() method.

Defining an exit branch

An early-exit branch is simply a PyTorch sub-model. It can perform any processing you like, as long as its input matches the output of the layer it is attached to. The output of the branch must have the same form as the original model's output. For example, in a CIFAR10 image classification model the output is a vector of 10 class-probabilities, and this must also be the output of each branch.

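import torch.nn as nn

NUM_CLASSES = 10  # CIFAR10 has 10 classes

def get_exits_def():
    # Each entry pairs the name of the layer to branch from with the exit branch module
    exits_def = [('layer1.2.relu2', nn.Sequential(nn.AvgPool2d(3),
                                                  nn.Flatten(),
                                                  nn.Linear(1600, NUM_CLASSES)))]
    return exits_def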
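The input size of the branch's nn.Linear layer has to match the flattened feature map produced at the attachment point. The sketch below shows one way to arrive at the value 1600 by hand. It assumes that the output of layer1.2.relu2 has 16 channels at 32x32 resolution for a CIFAR10 input; this shape is an assumption made for illustration, so check it against your own model. The helper pooled_size is hypothetical and is not part of Distiller or PyTorch.

```python
def pooled_size(size, kernel, stride=None, padding=0):
    """Output length of one spatial dimension after pooling (floor mode)."""
    # nn.AvgPool2d uses the kernel size as the stride when none is given
    stride = kernel if stride is None else stride
    return (size + 2 * padding - kernel) // stride + 1


# Assumed feature-map shape at 'layer1.2.relu2' (illustrative, not read from the model)
channels, height, width = 16, 32, 32

# nn.AvgPool2d(3) shrinks each spatial dimension, nn.Flatten() merges all of them
flat_features = channels * pooled_size(height, 3) * pooled_size(width, 3)
print(flat_features)  # 1600
```

If you attach a branch to a different layer, the same arithmetic gives the in_features value that the branch's linear layer needs.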

Attaching the exit branches

Attaching exit branches is straightforward: instantiate a distiller.EarlyExitMgr and invoke its attach_exits method, passing the model we are attaching to and the exit branch definitions. Each definition pairs an exit branch with the fully-qualified name of the layer we are attaching it to, as returned by get_exits_def. In the example above, we are attaching to layer layer1.2.relu2.

ee_mgr = distiller.EarlyExitMgr()
ee_mgr.attach_exits(my_model, get_exits_def())
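To make the idea of a dynamic attachment more concrete, here is a minimal sketch of one way it could work in plain PyTorch: the named layer is replaced by a wrapper that runs the original layer, passes the result to the exit branch, and caches the branch output. This is not Distiller's implementation. ExitWrapper, attach_exit, and the toy model are hypothetical names introduced only for this illustration.

```python
import torch
import torch.nn as nn


class ExitWrapper(nn.Module):
    """Hypothetical wrapper: runs the wrapped layer, then feeds its output to an exit branch."""

    def __init__(self, wrapped, branch):
        super().__init__()
        self.wrapped = wrapped
        self.branch = branch
        self.output = None  # exit output cached by the latest forward pass

    def forward(self, x):
        y = self.wrapped(x)
        self.output = self.branch(y)  # cache the exit result as a side effect
        return y  # the main path continues unchanged


def attach_exit(model, layer_name, branch):
    """Replace the sub-module named `layer_name` with an ExitWrapper around it."""
    parent_name, _, child_name = layer_name.rpartition('.')
    modules = dict(model.named_modules())  # the empty name maps to the model itself
    wrapper = ExitWrapper(modules[layer_name], branch)
    setattr(modules[parent_name], child_name, wrapper)
    return wrapper


# Toy model with nested layer names, standing in for a real network
model = nn.Sequential()
model.add_module('features', nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.ReLU()))
model.add_module('head', nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(16, 10)))

exit_branch = nn.Sequential(nn.AvgPool2d(3), nn.Flatten(), nn.Linear(1600, 10))
point = attach_exit(model, 'features.1', exit_branch)

final_out = model(torch.randn(2, 3, 32, 32))
exit_out = point.output  # filled in while the model ran
```

Because the wrapper returns the wrapped layer's output untouched, the rest of the model behaves exactly as before; the exit result is only available through the cached attribute.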

Defining the forward() method of our model

The forward method of our new model (the original model, now with the attached exits) should return the output of the original model, plus the outputs of all the newly attached exits. Exits cache their outputs, so before computing new outputs we first clear the caches by invoking ee_mgr.delete_exits_outputs. We then run the forward method of the original model, and finally collect and return the newly cached outputs using ee_mgr.get_exits_outputs.

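import distiller

# ResNetCifar is the existing CIFAR10 ResNet model class that we are extending
class ResNetCifarEarlyExit(ResNetCifar):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Attach the exit branches once, at construction time
        self.ee_mgr = distiller.EarlyExitMgr()
        self.ee_mgr.attach_exits(self, get_exits_def())

    def forward(self, x):
        # Clear the exit outputs cached by the previous forward pass
        self.ee_mgr.delete_exits_outputs(self)
        # Run the input through the network (including exits)
        x = super().forward(x)
        # Return the exit outputs followed by the original model's output
        outputs = self.ee_mgr.get_exits_outputs(self) + [x]
        return outputs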
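Since forward() now returns a list instead of a single tensor, calling code has to separate the exit outputs from the final output. The sketch below shows one way to do this, relying only on the fact that the original model's output is appended last. The helper split_outputs is hypothetical, and the strings stand in for real output tensors.

```python
def split_outputs(outputs):
    """Separate the list returned by forward() into (exit_outputs, final_output)."""
    # The original model's output is always the last element of the list
    *exit_outputs, final_output = outputs
    return exit_outputs, final_output


# Placeholder values standing in for the tensors of a model with one attached exit
outputs = ['exit0_output', 'final_output']
exit_outputs, final_output = split_outputs(outputs)
```

With more exits attached, exit_outputs simply grows; the final output stays in last position.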