Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Transformer Head Pruner #3884

Merged
merged 67 commits into from
Jul 28, 2021
Merged

Transformer Head Pruner #3884

merged 67 commits into from
Jul 28, 2021

Conversation

xiaowu0162
Copy link
Contributor

@xiaowu0162 xiaowu0162 commented Jun 29, 2021

This pr adds a pruner for pruning attention heads in transformers.
To-do:

  • basic: module matching, name-based weight grouping, group-aware maskers using weight norm as criteria
  • graph-based weight grouping
  • maskers relying on activation
  • maskers relying on gradient
  • global sort in maskers
  • support iterative pruning from scratch / integration with pruning scheduler for iterative features
  • example
  • docs

@xiaowu0162

This comment has been minimized.

@xiaowu0162 xiaowu0162 marked this pull request as ready for review July 9, 2021 08:16
break
except:
continue
if layer_idx is not None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a better way to get the index of the attention head? the first integer may be not strong.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This layer_idx is the layer index of the BERT encoder. Here I include these lines of code only to show the user how they may take advantage of the pruned_heads dict inside pruner to get the pruned heads for each group, and then match each group to the original layer, and finally call the built-in transformers _prune_heads() function to do model speedup. This is meant to be a temporary workaround before we can properly handle speedup for transformers.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If our speedup code after refactor can handle transformer, then I will replace these lines with our speedup methods (maybe in a separate pr)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since the users are aware of the naming of their own model to prune, I think they can also use their own rules to match layers to groups

and include `model, optimizer, criterion, epoch` as function arguments.
criterion: function
Function used to calculate the loss between the target and the output.
For example, you can use ``torch.nn.CrossEntropyLoss()`` as input.
Copy link
Contributor

@zheng-ningxin zheng-ningxin Jul 26, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Feel like that the TransformerHeadPruner is too heavy. I prefer to locate this pruner as a one-shot pruner, which means we do not need handle with the num_iteration, optimizer, trainer, criterion, things. That's much clearer. All those finetuning related things we can offload to the outer search algorithms. We can discuss with Quanlu @QuanluZhang .

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, we can further discuss on that. One challenge is that this does not fit well in our current compression V1 framework (since the current iterative and dependency aware pruner are limited to convolutions), and compression V2 is not ready yet. My initial thought was to first integrate all these logic in one pruner (because of empirically good performance compared to one-shot pruning), and then factor out when compression V2 is ready.

@QuanluZhang QuanluZhang merged commit aa2cc92 into microsoft:master Jul 28, 2021
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants