OPT with quantizable MatMuls #78

Merged: 1 commit merged into main on Jun 16, 2023
Conversation

@natuan commented on May 26, 2023

This change enables quantization of the torch.bmm operations in OPT models through a SparseML recipe.

The wrapped MatMuls corresponding to the (key, query) and (prob, value) pairs in an OPTDecoderLayer now look like this:

        (11): OPTDecoderLayer(
          (self_attn): OPTAttention(
            <...>
            (attn_weights_bmm): QuantizableBatchMatMul(
              (left_input): BMMLeftInput_QK()
              (right_input): BMMRightInput_QK()
              (output): BMMOutput_QK()
            )
            (attn_output_bmm): QuantizableBatchMatMul(
              (left_input): BMMLeftInput_PV()
              (right_input): BMMRightInput_PV()
              (output): BMMOutput_PV()
            )
          )
          <...>
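
For context, a minimal sketch of the wrapper pattern the dump above reflects (class and attribute names mirror the dump, but this is an illustration of the idea, not necessarily the exact code added by this PR): the input and output markers subclass nn.Identity, so they are no-ops in the forward pass until a quantization recipe attaches observers to them.

    import torch
    from torch import nn

    # Illustrative sketch only: marker modules subclass nn.Identity, so they do
    # nothing by themselves, but their distinct class names let a recipe target
    # each bmm operand and the output individually.
    class BMMLeftInput_QK(nn.Identity):
        pass

    class BMMRightInput_QK(nn.Identity):
        pass

    class BMMOutput_QK(nn.Identity):
        pass

    class QuantizableBatchMatMul(nn.Module):
        """Wraps torch.bmm so each operand and the output pass through a
        named submodule that quantization recipes can match."""

        def __init__(self, left_input, right_input, output):
            super().__init__()
            self.left_input = left_input
            self.right_input = right_input
            self.output = output

        def forward(self, a, b):
            return self.output(torch.bmm(self.left_input(a), self.right_input(b)))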

The quantization of these MatMuls can be configured individually for their inputs and outputs through recipes, as in the following example:

    ignore: ["lm_head", "Embedding", "OPTLearnedPositionalEmbedding", "QuantizableBatchMatMul"]
    scheme_overrides:
      BMMLeftInput_QK:
        input_activations:
          num_bits: 8
          symmetric: True
        output_activations: null
      BMMRightInput_QK:
        input_activations:
          num_bits: 8
          symmetric: False
        output_activations: null
      BMMOutput_QK:
        input_activations: null
        output_activations: null
      BMMLeftInput_PV:
        input_activations:
          num_bits: 8
          symmetric: False
        output_activations: null
      BMMRightInput_PV:
        input_activations:
          num_bits: 8
          symmetric: True
        output_activations: null
      BMMOutput_PV:
        input_activations: null
        output_activations: null
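
For illustration, here is a sketch of applying such a recipe to an OPT model. The recipe path and model name are placeholders, and the SparseML ScheduledModifierManager one-shot flow is an assumption here, not something taken from this PR:

    from sparseml.pytorch.optim import ScheduledModifierManager
    from transformers import OPTForCausalLM

    # Placeholder model and recipe path; the one-shot apply() flow is assumed.
    model = OPTForCausalLM.from_pretrained("facebook/opt-1.3b")
    manager = ScheduledModifierManager.from_yaml("quantization_recipe.yaml")
    manager.apply(model)  # wraps the marker modules with QuantStub/FakeQuantize

    # Inspect a decoder layer to see the structure shown below.
    print(model.model.decoder.layers[11].self_attn)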

The resulting layer portion after applying the above recipe looks like this:

            (attn_weights_bmm): QuantizableBatchMatMul(
              (left_input): QuantWrapper(
                (quant): QuantStub(
                  (activation_post_process): FakeQuantize(
                    fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8), quant_min=-128, quant_max=127, dtype=torch.qint8, qscheme=torch.per_tensor_symmetric, ch_axis=-1, scale=tensor([1.]), zero_point=tensor([0], dtype=torch.int32)
                    (activation_post_process): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)
                  )
                )
                (dequant): DeQuantStub()
                (module): BMMLeftInput_QK(
                  (activation_post_process): Identity()
                )
              )
              (right_input): QuantWrapper(
                (quant): QuantStub(
                  (activation_post_process): FakeQuantize(
                    fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8), quant_min=-128, quant_max=127, dtype=torch.qint8, qscheme=torch.per_tensor_affine, ch_axis=-1, scale=tensor([1.]), zero_point=tensor([0], dtype=torch.int32)
                    (activation_post_process): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)
                  )
                )
                (dequant): DeQuantStub()
                (module): BMMRightInput_QK(
                  (activation_post_process): Identity()
                )
              )
              (output): BMMOutput_QK(
                (activation_post_process): Identity()
              )
            )
            (attn_output_bmm): QuantizableBatchMatMul(
              (left_input): QuantWrapper(
                (quant): QuantStub(
                  (activation_post_process): FakeQuantize(
                    fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8), quant_min=-128, quant_max=127, dtype=torch.qint8, qscheme=torch.per_tensor_affine, ch_axis=-1, scale=tensor([1.]), zero_point=tensor([0], dtype=torch.int32)
                    (activation_post_process): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)
                  )
                )
                (dequant): DeQuantStub()
                (module): BMMLeftInput_PV(
                  (activation_post_process): Identity()
                )
              )
              (right_input): QuantWrapper(
                (quant): QuantStub(
                  (activation_post_process): FakeQuantize(
                    fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8), quant_min=-128, quant_max=127, dtype=torch.qint8, qscheme=torch.per_tensor_symmetric, ch_axis=-1, scale=tensor([1.]), zero_point=tensor([0], dtype=torch.int32)
                    (activation_post_process): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)
                  )
                )
                (dequant): DeQuantStub()
                (module): BMMRightInput_PV(
                  (activation_post_process): Identity()
                )
              )
              (output): BMMOutput_PV(
                (activation_post_process): Identity()
              )
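
As a quick sanity check (hypothetical, continuing the sketch above), the recipe's symmetric/asymmetric choices can be verified against the qscheme reported by each FakeQuantize module:

    import torch

    attn = model.model.decoder.layers[11].self_attn
    qk_left = attn.attn_weights_bmm.left_input.quant.activation_post_process
    qk_right = attn.attn_weights_bmm.right_input.quant.activation_post_process
    assert qk_left.qscheme == torch.per_tensor_symmetric   # symmetric: True
    assert qk_right.qscheme == torch.per_tensor_affine     # symmetric: False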

@HuggingFaceDocBuilderDev commented

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@natuan requested review from bfineran, anmarques, dbogunowicz and a team on May 26, 2023 16:34
@anmarques (Member) left a comment


Can you clarify that you tested this? I didn't know that simply inheriting from an identity module was enough for the new class to not be seen as identity 😄

@natuan (Author) commented on May 30, 2023

Can you clarify that you tested this? I didn't know that simply inheriting from an identity module was enough for the new class to not be seen as identity 😄

I've updated the description to show how it looks before and after. It's still confusing that the QuantizableMatMuls need to be part of the ignore list for this to work.

@natuan merged commit 517260f into main on Jun 16, 2023
1 of 2 checks passed
@natuan deleted the OPT_Quantizable_MatMuls branch on June 16, 2023 17:17