
Commit

Ignoring Gemm -> Q/Dq -> Softmax in GemmToQLinearMatMul transform (#1343)
corey-nm committed Jan 26, 2023
1 parent c7cf6b6 commit ba84f56
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions src/sparseml/exporters/transforms/gemm_to_qlinearmatmul.py
@@ -12,8 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import logging
-
 from onnx import ModelProto, helper, numpy_helper
 
 from sparseml.exporters.transforms.onnx_transform import OnnxTransform
@@ -31,15 +29,14 @@
 
 __all__ = ["GemmToQLinearMatMul"]
 
-_LOGGER = logging.getLogger(__name__)
-
 
 class GemmToQLinearMatMul(OnnxTransform):
     """
     Transforms Gemm nodes to QLinearMatMul.
     NOTE: Does not match if the structure is
-    `Gemm -> QuantizeLinear -> DequantizeLinear -> Gemm`
+    1. `Gemm -> QuantizeLinear -> DequantizeLinear -> Gemm`
+    2. `Gemm -> QuantizeLinear -> DequantizeLinear -> Softmax`
     Transforms
     ```
@@ -93,7 +90,10 @@ def transform(self, model: ModelProto) -> ModelProto:
             output_dequant = match.children[0][1]
             if output_dequant is not None:
                 output_dequant_child = graph.get_node_single_child(output_dequant)
-                if output_dequant_child and output_dequant_child.op_type == "Gemm":
+                if output_dequant_child and output_dequant_child.op_type in {
+                    "Gemm",
+                    "Softmax",
+                }:
                     # output quant is not a QDQ block for the current Gemm Node, but
                     # the input QDQ block for a new Gemm block; this Gemm should be
                     # skipped and processed by _convert_quantizable_gemm_no_activations
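
For context, here is a minimal sketch, not taken from the commit, of the kind of node pattern the transform now also ignores. It is built with the `onnx.helper` API (the module already imports `helper`); every tensor and node name below is made up for illustration, and initializers for the weights, scales, and zero points are omitted:

```python
from onnx import TensorProto, helper

# Hypothetical Gemm -> QuantizeLinear -> DequantizeLinear -> Softmax chain.
# With this change, GemmToQLinearMatMul skips the Gemm because the Q/DQ pair
# is treated as the input quantization block of the following node (here a
# Softmax) rather than as the Gemm's own output QDQ block.
gemm = helper.make_node("Gemm", ["x", "weight", "bias"], ["gemm_out"])
quant = helper.make_node("QuantizeLinear", ["gemm_out", "scale", "zero_point"], ["q_out"])
dequant = helper.make_node("DequantizeLinear", ["q_out", "scale", "zero_point"], ["dq_out"])
softmax = helper.make_node("Softmax", ["dq_out"], ["y"])

graph = helper.make_graph(
    [gemm, quant, dequant, softmax],
    "gemm_qdq_softmax_example",
    inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, [1, 16])],
    outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [1, 16])],
)
model = helper.make_model(graph)
```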
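A hedged usage sketch follows, assuming the base transform class exposes an `apply` wrapper around the `transform` method shown in the diff (the export pipeline may invoke it differently); the import path follows the changed file, and the model file names are placeholders:

```python
import onnx

from sparseml.exporters.transforms.gemm_to_qlinearmatmul import GemmToQLinearMatMul

# Load an ONNX model containing quantized Gemm blocks, run the transform, and
# save the result. Gemm nodes whose output Q/DQ feeds another Gemm or a Softmax
# are now left for a later conversion pass instead of becoming QLinearMatMul.
model = onnx.load("model.onnx")  # placeholder path
model = GemmToQLinearMatMul().apply(model)  # `apply` assumed to wrap validation + transform
onnx.save(model, "model_qlinear.onnx")  # placeholder path
```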
