[PIR & Inference] Add fused_weight_only_linear_pass #59366

Merged · 16 commits · Dec 4, 2023
141 changes: 141 additions & 0 deletions paddle/fluid/pir/transforms/fusion/fused_weight_only_linear_pass.cc
@@ -0,0 +1,141 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/pir/transforms/fusion/fused_weight_only_linear_pass.h"
#include "paddle/fluid/pir/drr/api/drr_pattern_base.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/pir/pass/pass.h"
#include "paddle/pir/pass/pass_registry.h"
#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h"

namespace {

inline int getSMVersion() {
Contributor:

Call the platform::GetGPUComputeCapability(platform::GetCurrentDeviceId()) interface here.

Contributor Author:

Done.

int sm_version = 80;
#if defined(PADDLE_WITH_CUDA)
sm_version = paddle::platform::GetGPUComputeCapability(
paddle::platform::GetCurrentDeviceId());
#endif
Contributor:

How about adding an #else branch that uses PADDLE_THROW to report that the current Paddle build was not compiled with CUDA? Wouldn't that be friendlier? (A sketch of this variant appears right after the function below.)

return sm_version;
}
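
A minimal sketch of the reviewer's suggestion, hedged: it assumes the usual PADDLE_THROW / phi::errors::Unimplemented enforce utilities (e.g. from paddle/phi/core/enforce.h) are visible in this translation unit, and the error message is illustrative:

inline int getSMVersion() {
#if defined(PADDLE_WITH_CUDA)
  return paddle::platform::GetGPUComputeCapability(
      paddle::platform::GetCurrentDeviceId());
#else
  // Assumed enforce utilities; fail loudly instead of silently
  // defaulting to SM 80 on a CPU-only build.
  PADDLE_THROW(phi::errors::Unimplemented(
      "fused_weight_only_linear_pass needs Paddle compiled with CUDA."));
#endif
}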

class FusedWeightOnlyLinearPattern
: public pir::drr::DrrPatternBase<FusedWeightOnlyLinearPattern> {
public:
void operator()(pir::drr::DrrPatternContext *ctx) const override {
//
// Source Pattern.
//
pir::drr::SourcePattern src = ctx->SourcePattern();
const auto &matmul =
src.Op("pd_op.matmul",
{{"transpose_x", src.Attr("matmul_transpose_x")},
{"transpose_y", src.Attr("matmul_transpose_y")}});
src.Tensor("matmul_out") = matmul(src.Tensor("x"), src.Tensor("w"));

const auto &add = src.Op("pd_op.add");
src.Tensor("add_out") = add(src.Tensor("matmul_out"), src.Tensor("bias"));

//
// Constraints.
//
src.RequireNativeCall([](const pir::drr::MatchContext &match_ctx) -> bool {
bool matmul_trans_x = match_ctx.Attr<bool>("matmul_transpose_x");
Contributor:

If the SM version is not among the supported ones, the constraint needs to return false so that your pass does not take effect.

Contributor Author:

Done.

bool matmul_trans_y = match_ctx.Attr<bool>("matmul_transpose_y");
if (matmul_trans_x || matmul_trans_y) return false;

if (!(match_ctx.Tensor("w").Shape().size() == 2 &&
match_ctx.Tensor("x").Shape().size() >= 2 &&
match_ctx.Tensor("bias").Shape().size() == 1)) {
return false;
}

return true;
});
//
// Result Pattern.
//
pir::drr::ResultPattern res = src.ResultPattern();

// quantize weight
const auto &weight_only_int8_attr =
res.Attr([](const pir::drr::MatchContext &match_ctx) -> std::any {
return "weight_only_int8";
});
// int arch = getSMVersion();
Contributor Author:

Sorry, there is a typo here: this line should be uncommented and the hardcoded 80 below changed to this arch. I will fix it right away.

Contributor:

Fix it in the next PR.

Contributor Author:

OK, thanks.

Contributor:

Remove this commented-out line.

const auto &weight_quantize_arch_attr =
res.Attr([&](const pir::drr::MatchContext &match_ctx) -> std::any {
return 80;
Contributor:

The open-source version currently ships the weight-only linear kernel only for the SM 80 architecture. Later this will be split: SM 70 gets a special weight-only kernel, while SM 75/80/86/89 share one. If the value is hardcoded here, I think a TODO comment should be added. (See the sketch after the pattern class below.)
});

const auto &weight_quantize = res.Op(
"pd_op.weight_quantize",
{{"algo", weight_only_int8_attr}, {"arch", weight_quantize_arch_attr}});
weight_quantize({&res.Tensor("w")},
{&res.Tensor("quanted_weight_tensor"),
&res.Tensor("weight_scale_tensor")});

const auto &weight_dtype_attr =
res.Attr([](const pir::drr::MatchContext &match_ctx) -> std::any {
return "int8";
});

const auto &weight_only_linear_arch_attr = res.Attr(
Contributor:

Don't write this attr twice; reusing the earlier one would keep things consistent, otherwise a later change may update one copy and miss the other. (A sketch applying this suggestion follows the pattern class below.)

[&](const pir::drr::MatchContext &match_ctx) -> int { return 80; });
const auto &weight_only_linear =
res.Op("pd_op.weight_only_linear",
{{"weight_dtype", weight_dtype_attr},
{"arch", weight_only_linear_arch_attr}});
weight_only_linear({&res.Tensor("x"),
&res.Tensor("quanted_weight_tensor"),
&res.Tensor("bias"),
&res.Tensor("weight_scale_tensor")},
{&res.Tensor("add_out")});
}
};
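
Putting the review suggestions above together — gate the match on the detected SM version, and compute one shared arch attribute instead of hardcoding 80 twice — an illustrative rewrite of the pattern body might look like the following. This is a sketch, not the code merged in this PR; it reuses getSMVersion() from above and assumes res.Attr accepts an int wrapped in std::any, matching the lambdas already in this file:

  void operator()(pir::drr::DrrPatternContext *ctx) const override {
    pir::drr::SourcePattern src = ctx->SourcePattern();
    const auto &matmul =
        src.Op("pd_op.matmul",
               {{"transpose_x", src.Attr("matmul_transpose_x")},
                {"transpose_y", src.Attr("matmul_transpose_y")}});
    src.Tensor("matmul_out") = matmul(src.Tensor("x"), src.Tensor("w"));
    const auto &add = src.Op("pd_op.add");
    src.Tensor("add_out") = add(src.Tensor("matmul_out"), src.Tensor("bias"));

    src.RequireNativeCall([](const pir::drr::MatchContext &match_ctx) -> bool {
      // Reject unsupported architectures so the pattern never fires on them.
      int sm_version = getSMVersion();
      if (sm_version != 70 && sm_version != 75 && sm_version != 80 &&
          sm_version != 86) {
        return false;
      }
      if (match_ctx.Attr<bool>("matmul_transpose_x") ||
          match_ctx.Attr<bool>("matmul_transpose_y")) {
        return false;
      }
      return match_ctx.Tensor("w").Shape().size() == 2 &&
             match_ctx.Tensor("x").Shape().size() >= 2 &&
             match_ctx.Tensor("bias").Shape().size() == 1;
    });

    pir::drr::ResultPattern res = src.ResultPattern();
    const auto &weight_only_int8_attr =
        res.Attr([](const pir::drr::MatchContext &match_ctx) -> std::any {
          return "weight_only_int8";
        });
    const auto &weight_dtype_attr =
        res.Attr([](const pir::drr::MatchContext &match_ctx) -> std::any {
          return "int8";
        });
    // One shared arch attribute, computed from the device rather than
    // hardcoded twice, per the review comments.
    // TODO: open-source weight-only linear currently covers SM 80 only;
    // SM 70 will get a special kernel and SM 75/80/86/89 will share one.
    const auto &arch_attr =
        res.Attr([](const pir::drr::MatchContext &match_ctx) -> std::any {
          return getSMVersion();
        });

    const auto &weight_quantize =
        res.Op("pd_op.weight_quantize",
               {{"algo", weight_only_int8_attr}, {"arch", arch_attr}});
    weight_quantize({&res.Tensor("w")},
                    {&res.Tensor("quanted_weight_tensor"),
                     &res.Tensor("weight_scale_tensor")});
    const auto &weight_only_linear =
        res.Op("pd_op.weight_only_linear",
               {{"weight_dtype", weight_dtype_attr}, {"arch", arch_attr}});
    weight_only_linear({&res.Tensor("x"),
                        &res.Tensor("quanted_weight_tensor"),
                        &res.Tensor("bias"),
                        &res.Tensor("weight_scale_tensor")},
                       {&res.Tensor("add_out")});
  }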

class FusedWeightOnlyLinearPass : public pir::PatternRewritePass {
public:
FusedWeightOnlyLinearPass()
: pir::PatternRewritePass("fused_weight_only_linear_pass", 4) {}

pir::RewritePatternSet InitializePatterns(pir::IrContext *context) override {
pir::RewritePatternSet ps(context);
ps.Add(FusedWeightOnlyLinearPattern().Build(context));
return ps;
}

bool CanApplyOn(pir::Operation *op) const override {
int sm_version = getSMVersion();
if (sm_version != 70 && sm_version != 80 && sm_version != 86 &&
Contributor:

For now, only allow SM 80 in CanApplyOn, and add some comments explaining why. (A sketch of this variant follows the class below.)

sm_version != 75) {
return false;
}
return op->num_regions() > 0;
}

private:
pir::FrozenRewritePatternSet patterns_;
};
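
The CanApplyOn suggestion above, restricting the pass to SM 80 with an explanatory comment, might be sketched as follows (a hedged illustration, not the merged code):

  bool CanApplyOn(pir::Operation *op) const override {
    // Restrict to SM 80 for now: the open-source weight-only linear
    // kernels currently target only that architecture (see review above).
    // TODO: widen to 70/75/86/89 once the corresponding kernels land.
    if (getSMVersion() != 80) {
      return false;
    }
    return op->num_regions() > 0;
  }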

} // namespace

namespace pir {
std::unique_ptr<Pass> CreateFusedWeightOnlyLinearPass() {
return std::make_unique<FusedWeightOnlyLinearPass>();
}
} // namespace pir

REGISTER_IR_PASS(fused_weight_only_linear_pass, FusedWeightOnlyLinearPass);
26 changes: 26 additions & 0 deletions paddle/fluid/pir/transforms/fusion/fused_weight_only_linear_pass.h
@@ -0,0 +1,26 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <memory>
#include "paddle/pir/core/dll_decl.h"

namespace pir {

class Pass;

IR_API std::unique_ptr<Pass> CreateFusedWeightOnlyLinearPass();

} // namespace pir
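
As a usage note, the factory declared here would typically be run through a pir::PassManager. A minimal sketch, assuming the usual PassManager(IrContext*, opt_level) constructor and pass_manager.h header path, and that the program has already been lowered to PIR:

#include "paddle/fluid/pir/transforms/fusion/fused_weight_only_linear_pass.h"
#include "paddle/pir/core/ir_context.h"
#include "paddle/pir/core/program.h"
#include "paddle/pir/pass/pass_manager.h"

void ApplyFusedWeightOnlyLinear(pir::Program *program) {
  // opt_level 4 matches the level this pass registers with, so the
  // manager will not filter it out (see the pass.h change below).
  pir::PassManager pm(pir::IrContext::Instance(), /*opt_level=*/4);
  pm.AddPass(pir::CreateFusedWeightOnlyLinearPass());
  pm.Run(program);
}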
2 changes: 2 additions & 0 deletions paddle/fluid/pybind/pir.cc
@@ -41,6 +41,7 @@
#include "paddle/fluid/pir/transforms/dead_code_elimination_pass.h"
#include "paddle/fluid/pir/transforms/fusion/fused_dropout_add_pass.h"
#include "paddle/fluid/pir/transforms/fusion/fused_linear_param_grad_add_pass.h"
#include "paddle/fluid/pir/transforms/fusion/fused_weight_only_linear_pass.h"
#include "paddle/fluid/pir/transforms/infer_symbolic_shape_pass.h"
#include "paddle/fluid/pir/transforms/inplace_pass.h"
#include "paddle/fluid/pir/transforms/replace_fetch_with_shadow_output_pass.h"
@@ -96,6 +97,7 @@ USE_PIR_PASS(dead_code_elimination_pass);
USE_PIR_PASS(attention_fuse_pass);
USE_PIR_PASS(fused_gemm_epilogue_pass);
USE_PIR_PASS(fused_dropout_add_pass);
USE_PIR_PASS(fused_weight_only_linear_pass);
USE_PIR_PASS(fused_linear_param_grad_add_pass);
USE_PIR_PASS(inplace_pass);
USE_PIR_PASS(replace_fetch_with_shadow_output_pass);
1 change: 1 addition & 0 deletions paddle/pir/pass/pass.h
@@ -59,6 +59,7 @@ struct PassInfo {
// opt_level=1: constant fold, cse, memory optimize, etc.
// opt_level=2: the fusion logical pass.
// opt_level=3: layout, etc.
// opt_level=4: radical optimizations.
uint8_t opt_level;

// The list which pass depends on.
2 changes: 1 addition & 1 deletion test/ir/pir/fused_pass/pass_test.py
@@ -33,7 +33,7 @@ def run_pir_pass(self):
if not isinstance(self.pass_list, list):
self.pass_list = [self.pass_list]

pm = pir.PassManager()
pm = pir.PassManager(opt_level=4)
for pass_name in self.pass_list:
pm.add_pass(pass_name)

96 changes: 96 additions & 0 deletions test/ir/pir/fused_pass/test_fused_weight_only_linear_pass.py
@@ -0,0 +1,96 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import re
import unittest

import numpy as np
from pass_test import PassTest

import paddle
from paddle.base import core

np.random.seed(2013)

def get_cuda_version():
result = os.popen("nvcc --version").read()
regex = r'release (\S+),'
match = re.search(regex, result)
if match:
num = str(match.group(1))
integer, decimal = num.split('.')
return int(integer) * 1000 + int(float(decimal) * 10)
else:
return -1


@unittest.skipIf(
not core.is_compiled_with_cuda()
or get_cuda_version() < 11020
Contributor:

The missed coverage is probably because this test was skipped: the coverage CI's CUDA version is 10.2. Does this unit test pass for you locally?

Contributor Author:

The unit test passes locally for me.

Contributor Author:

> The missed coverage is probably because this test was skipped: the coverage CI's CUDA version is 10.2. Does this unit test pass for you locally?

So should I manually add a C++ unit test instead? weight_only_linear really does require CUDA version >= 11.2, and changing the CUDA version of the coverage-CI cluster does not seem realistic.

Contributor:

If it can be verified to pass, the coverage CI can be exempted.

Contributor Author:

> If it can be verified to pass, the coverage CI can be exempted.

OK.

Contributor:

Let's look at the windows-inference CI result; its CUDA version is 11.2, so there is no need to add a C++ unit test. If the windows-inference CI runs this unit test and it passes but coverage is still insufficient, the coverage CI can be exempted.

or paddle.device.cuda.get_device_capability()[0] < 8,
"weight_only_linear requires CUDA >= 11.2 and CUDA_ARCH >= 8",
)
class TestFusedWeightOnlyLinearPass_Fp32(PassTest):
def build_ir_program(self):
with paddle.pir_utils.IrGuard():
self.pir_program = paddle.static.Program()
with paddle.pir.core.program_guard(self.pir_program):
x = paddle.static.data(
name='x', shape=[3, 64, 64], dtype=self.dtype
)
w = paddle.static.data(
name="w", shape=[64, 64], dtype=self.dtype
)
bias_ = paddle.static.data(
name="bias", shape=[64], dtype=self.dtype
)
bias = paddle.assign(bias_)
res1 = paddle.matmul(x=x, y=w)
out = paddle.add(res1, bias)

self.pass_list = ['fused_weight_only_linear_pass']
self.feeds = {
"x": np.random.random((3, 64, 64)).astype(self.dtype),
"w": np.random.random((64, 64)).astype(self.dtype),
"bias": np.random.random(64).astype(self.dtype),
}
self.fetch_list = [out]
self.valid_op_map = {
"pd_op.weight_only_linear": 1,
"pd_op.weight_quantize": 1,
"pd_op.matmul": 0,
"pd_op.add": 0,
}

def setUp(self):
self.place_runtime = "gpu"
self.dtype = 'float32'
self.build_ir_program()

def test_check_output(self):
self.check_pass_correct()


class TestFusedWeightOnlyLinearPass_Fp16(TestFusedWeightOnlyLinearPass_Fp32):
def setUp(self):
self.place_runtime = "gpu"
self.dtype = 'float16'
self.build_ir_program()


if __name__ == "__main__":
unittest.main()