diff --git a/paddle/fluid/pir/transforms/fusion/fused_weight_only_linear_pass.cc b/paddle/fluid/pir/transforms/fusion/fused_weight_only_linear_pass.cc new file mode 100644 index 0000000000000..823c7bdc8f81b --- /dev/null +++ b/paddle/fluid/pir/transforms/fusion/fused_weight_only_linear_pass.cc @@ -0,0 +1,141 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/pir/transforms/fusion/fused_weight_only_linear_pass.h" +#include "paddle/fluid/pir/drr/api/drr_pattern_base.h" +#include "paddle/fluid/platform/device/gpu/gpu_info.h" +#include "paddle/fluid/platform/place.h" +#include "paddle/pir/pass/pass.h" +#include "paddle/pir/pass/pass_registry.h" +#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h" + +namespace { + +inline int getSMVersion() { + int sm_version = 80; +#if defined(PADDLE_WITH_CUDA) + sm_version = paddle::platform::GetGPUComputeCapability( + paddle::platform::GetCurrentDeviceId()); +#endif + return sm_version; +} + +class FusedWeightOnlyLinearPattern + : public pir::drr::DrrPatternBase { + public: + void operator()(pir::drr::DrrPatternContext *ctx) const override { + // + // Source Pattern. + // + pir::drr::SourcePattern src = ctx->SourcePattern(); + const auto &matmul = + src.Op("pd_op.matmul", + {{"transpose_x", src.Attr("matmul_transpose_x")}, + {"transpose_y", src.Attr("matmul_transpose_y")}}); + src.Tensor("matmul_out") = matmul(src.Tensor("x"), src.Tensor("w")); + + const auto &add = src.Op("pd_op.add"); + src.Tensor("add_out") = add(src.Tensor("matmul_out"), src.Tensor("bias")); + + // + // Constraints. + // + src.RequireNativeCall([](const pir::drr::MatchContext &match_ctx) -> bool { + bool matmul_trans_x = match_ctx.Attr("matmul_transpose_x"); + bool matmul_trans_y = match_ctx.Attr("matmul_transpose_y"); + if (matmul_trans_x || matmul_trans_y) return false; + + if (!(match_ctx.Tensor("w").Shape().size() == 2 && + match_ctx.Tensor("x").Shape().size() >= 2 && + match_ctx.Tensor("bias").Shape().size() == 1)) { + return false; + } + + return true; + }); + // + // Result Pattern. + // + pir::drr::ResultPattern res = src.ResultPattern(); + + // quantize weight + const auto &weight_only_int8_attr = + res.Attr([](const pir::drr::MatchContext &match_ctx) -> std::any { + return "weight_only_int8"; + }); + // int arch = getSMVersion(); + const auto &weight_quantize_arch_attr = + res.Attr([&](const pir::drr::MatchContext &match_ctx) -> std::any { + return 80; + }); + + const auto &weight_quantize = res.Op( + "pd_op.weight_quantize", + {{"algo", weight_only_int8_attr}, {"arch", weight_quantize_arch_attr}}); + weight_quantize({&res.Tensor("w")}, + {&res.Tensor("quanted_weight_tensor"), + &res.Tensor("weight_scale_tensor")}); + + const auto &weight_dtype_attr = + res.Attr([](const pir::drr::MatchContext &match_ctx) -> std::any { + return "int8"; + }); + + const auto &weight_only_linear_arch_attr = res.Attr( + [&](const pir::drr::MatchContext &match_ctx) -> int { return 80; }); + const auto &weight_only_linear = + res.Op("pd_op.weight_only_linear", + {{"weight_dtype", weight_dtype_attr}, + {"arch", weight_only_linear_arch_attr}}); + weight_only_linear({&res.Tensor("x"), + &res.Tensor("quanted_weight_tensor"), + &res.Tensor("bias"), + &res.Tensor("weight_scale_tensor")}, + {&res.Tensor("add_out")}); + } +}; + +class FusedWeightOnlyLinearPass : public pir::PatternRewritePass { + public: + FusedWeightOnlyLinearPass() + : pir::PatternRewritePass("fused_weight_only_linear_pass", 4) {} + + pir::RewritePatternSet InitializePatterns(pir::IrContext *context) override { + pir::RewritePatternSet ps(context); + ps.Add(FusedWeightOnlyLinearPattern().Build(context)); + return ps; + } + + bool CanApplyOn(pir::Operation *op) const override { + int sm_vesion = getSMVersion(); + if (sm_vesion != 70 && sm_vesion != 80 && sm_vesion != 86 && + sm_vesion != 75) { + return false; + } + return op->num_regions() > 0; + } + + private: + pir::FrozenRewritePatternSet patterns_; +}; + +} // namespace + +namespace pir { +std::unique_ptr CreateFusedWeightOnlyLinearPass() { + return std::make_unique(); +} +} // namespace pir + +REGISTER_IR_PASS(fused_weight_only_linear_pass, FusedWeightOnlyLinearPass); diff --git a/paddle/fluid/pir/transforms/fusion/fused_weight_only_linear_pass.h b/paddle/fluid/pir/transforms/fusion/fused_weight_only_linear_pass.h new file mode 100644 index 0000000000000..b616355d15f29 --- /dev/null +++ b/paddle/fluid/pir/transforms/fusion/fused_weight_only_linear_pass.h @@ -0,0 +1,26 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "paddle/pir/core/dll_decl.h" + +namespace pir { + +class Pass; + +IR_API std::unique_ptr CreateFusedWeightOnlyLinearPass(); + +} // namespace pir diff --git a/paddle/fluid/pybind/pir.cc b/paddle/fluid/pybind/pir.cc index f8cd9b46a4cab..89d4cf5c025a7 100644 --- a/paddle/fluid/pybind/pir.cc +++ b/paddle/fluid/pybind/pir.cc @@ -41,6 +41,7 @@ #include "paddle/fluid/pir/transforms/dead_code_elimination_pass.h" #include "paddle/fluid/pir/transforms/fusion/fused_dropout_add_pass.h" #include "paddle/fluid/pir/transforms/fusion/fused_linear_param_grad_add_pass.h" +#include "paddle/fluid/pir/transforms/fusion/fused_weight_only_linear_pass.h" #include "paddle/fluid/pir/transforms/infer_symbolic_shape_pass.h" #include "paddle/fluid/pir/transforms/inplace_pass.h" #include "paddle/fluid/pir/transforms/replace_fetch_with_shadow_output_pass.h" @@ -96,6 +97,7 @@ USE_PIR_PASS(dead_code_elimination_pass); USE_PIR_PASS(attention_fuse_pass); USE_PIR_PASS(fused_gemm_epilogue_pass); USE_PIR_PASS(fused_dropout_add_pass); +USE_PIR_PASS(fused_weight_only_linear_pass); USE_PIR_PASS(fused_linear_param_grad_add_pass); USE_PIR_PASS(inplace_pass); USE_PIR_PASS(replace_fetch_with_shadow_output_pass); diff --git a/paddle/pir/pass/pass.h b/paddle/pir/pass/pass.h index a9e9881ec9fd4..cc5e4a1dcbd83 100644 --- a/paddle/pir/pass/pass.h +++ b/paddle/pir/pass/pass.h @@ -59,6 +59,7 @@ struct PassInfo { // opt_level=1: constant fold, cse, memory optimize, etc. // opt_level=2: the fusion logical pass. // opt_level=3: layout, etc. + // opt_level=4: the radical optimization. uint8_t opt_level; // The list which pass depends on. diff --git a/test/ir/pir/fused_pass/pass_test.py b/test/ir/pir/fused_pass/pass_test.py index 5f7ca010d359c..1409111a3085a 100644 --- a/test/ir/pir/fused_pass/pass_test.py +++ b/test/ir/pir/fused_pass/pass_test.py @@ -33,7 +33,7 @@ def run_pir_pass(self): if not isinstance(self.pass_list, list): self.pass_list = [self.pass_list] - pm = pir.PassManager() + pm = pir.PassManager(opt_level=4) for pass_name in self.pass_list: pm.add_pass(pass_name) diff --git a/test/ir/pir/fused_pass/test_fused_weight_only_linear_pass.py b/test/ir/pir/fused_pass/test_fused_weight_only_linear_pass.py new file mode 100644 index 0000000000000..731d59d23aeb1 --- /dev/null +++ b/test/ir/pir/fused_pass/test_fused_weight_only_linear_pass.py @@ -0,0 +1,96 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from pass_test import PassTest + +import paddle +from paddle.base import core + +np.random.seed(2013) + +import os +import re + + +def get_cuda_version(): + result = os.popen("nvcc --version").read() + regex = r'release (\S+),' + match = re.search(regex, result) + if match: + num = str(match.group(1)) + integer, decimal = num.split('.') + return int(integer) * 1000 + int(float(decimal) * 10) + else: + return -1 + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or get_cuda_version() < 11020 + or paddle.device.cuda.get_device_capability()[0] < 8, + "weight_only_linear requires CUDA >= 11.2 and CUDA_ARCH >= 8", +) +class TestFusedWeightOnlyLinearPass_Fp32(PassTest): + def build_ir_progam(self): + with paddle.pir_utils.IrGuard(): + self.pir_program = paddle.static.Program() + with paddle.pir.core.program_guard(self.pir_program): + x = paddle.static.data( + name='x', shape=[3, 64, 64], dtype=self.dtype + ) + w = paddle.static.data( + name="w", shape=[64, 64], dtype=self.dtype + ) + bias_ = paddle.static.data( + name="bias", shape=[64], dtype=self.dtype + ) + bias = paddle.assign(bias_) + res1 = paddle.matmul(x=x, y=w) + out = paddle.add(res1, bias) + + self.pass_list = ['fused_weight_only_linear_pass'] + self.feeds = { + "x": np.random.random((3, 64, 64)).astype(self.dtype), + "w": np.random.random((64, 64)).astype(self.dtype), + "bias": np.random.random(64).astype(self.dtype), + } + self.fetch_list = [out] + self.valid_op_map = { + "pd_op.weight_only_linear": 1, + "pd_op.weight_quantize": 1, + "pd_op.matmul": 0, + "pd_op.add": 0, + } + + def setUp(self): + self.place_runtime = "gpu" + self.dtype = 'float32' + self.build_ir_progam() + + def test_check_output(self): + self.check_pass_correct() + + +class TestFusedWeightOnlyLinearPass_Fp16(TestFusedWeightOnlyLinearPass_Fp32): + def setUp(self): + self.place_runtime = "gpu" + self.dtype = 'float16' + self.build_ir_progam() + + +if __name__ == "__main__": + unittest.main()