
[PIR & Inference] Add fused_weight_only_linear_pass #59366

Merged
merged 16 commits into PaddlePaddle:develop on Dec 4, 2023

Conversation

Wanglongzhi2001
Contributor

PR types

New features

PR changes

APIs

Description

Add a PIR pass that converts the matmul op into the weight_only_linear op.
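For readers skimming this thread, here is a minimal sketch of the source-pattern half of such a DRR (declarative rewrite rule) pass. The class name and op strings follow the snippets quoted in the review below, but this is an illustration, not the merged file:

class FusedWeightOnlyLinearPattern
    : public pir::drr::DrrPatternBase<FusedWeightOnlyLinearPattern> {
 public:
  void operator()(pir::drr::DrrPatternContext *ctx) const override {
    // Source pattern: match add(matmul(x, w), bias).
    pir::drr::SourcePattern src = ctx->SourcePattern();
    const auto &matmul =
        src.Op("pd_op.matmul",
               {{"transpose_x", src.Attr("matmul_transpose_x")},
                {"transpose_y", src.Attr("matmul_transpose_y")}});
    src.Tensor("matmul_out") = matmul(src.Tensor("x"), src.Tensor("w"));
    const auto &add = src.Op("pd_op.add");
    src.Tensor("add_out") = add(src.Tensor("matmul_out"), src.Tensor("bias"));
    // Constraints and the weight_quantize + weight_only_linear result
    // pattern follow; see the review comments below.
  }
};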


paddle-bot bot commented Nov 25, 2023

Your PR has been submitted. Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

@paddle-bot paddle-bot bot added the contributor External developers label Nov 25, 2023
@Wanglongzhi2001 Wanglongzhi2001 changed the title [Inference]Add matmul_to_weight_only_linear_pass [Inference] Add matmul_to_weight_only_linear_pass Nov 25, 2023
@@ -0,0 +1,135 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Contributor

Move the pass implementation into the current fusion directory.

}
};

class MatmulToWeightOnlyLinearPass : public pir::Pass {
Contributor

Implement this by inheriting from pir::PatternRewritePass.

Comment on lines 111 to 120
void Run(pir::Operation *op) override {
pir::GreedyRewriteConfig cfg;
cfg.use_top_down_traversal = true;
cfg.max_iterations = 10;
pir::ApplyPatternsGreedily(op->region(0), patterns_, cfg);
}

bool CanApplyOn(pir::Operation *op) const override {
return op->isa<::pir::ModuleOp>() && op->num_regions() > 0;
}
Contributor

After inheriting from pir::PatternRewritePass, these two interfaces are no longer needed.
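A minimal sketch of the simplified shape after that refactor (assuming pir::PatternRewritePass exposes an InitializePatterns hook, which the constructor quoted later in this thread suggests):

class FusedWeightOnlyLinearPass : public pir::PatternRewritePass {
 public:
  FusedWeightOnlyLinearPass()
      : pir::PatternRewritePass("fused_weight_only_linear_pass", 2) {}

  // Only the pattern set needs to be supplied; Run and CanApplyOn
  // come from the base class.
  pir::RewritePatternSet InitializePatterns(pir::IrContext *context) override {
    pir::RewritePatternSet ps(context);
    ps.Add(FusedWeightOnlyLinearPattern().Build(context));
    return ps;
  }
};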

@@ -0,0 +1,86 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Contributor

For the unit test, follow the pattern of test/ir/pir/fused_pass/test_conv2d_fuse_pass.py.

@@ -1138,3 +1139,75 @@ TEST(constant_folding, ConstantFolding_Combine) {
CHECK_EQ(pm.Run(&program), true);
// EXPECT_EQ(program.block()->size(), 6u);
}

void BuildWeightOnlyLinearProgram(pir::Program *program,
Contributor

A Python unit test is enough; please delete this C++ unit test.

Contributor Author

With my current code, if I don't add a C++ unit test, the C++ coverage check in CI fails.

Contributor

Implement it as the new comments suggest; if ci-coverage doesn't pass once it finishes running, I'll take a look. In theory there shouldn't be any code left uncovered.

src.Tensor("add_out") = add(src.Tensor("matmul_out"), src.Tensor("bias"));

//
// Constraints.
Contributor

Add a constraint on the SM version under Constraints; see the check in python/paddle/nn/quant/quantized_linear.py.
[screenshot: the SM version check in quantized_linear.py]
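A sketch of the kind of constraint being requested, mirroring the SM allowlist in quantized_linear.py (the list of supported versions is taken from the later discussion in this thread):

src.RequireNativeCall([](const pir::drr::MatchContext &match_ctx) -> bool {
  // Only rewrite on architectures the weight_only_linear kernel supports.
  int sm_version = getSMVersion();
  return sm_version == 70 || sm_version == 75 || sm_version == 80 ||
         sm_version == 86;
});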

});

const auto &weight_only_linear_arch_attr = res.Attr(
[](const pir::drr::MatchContext &match_ctx) -> int { return 80; });
Contributor

Is only the SM 80 architecture supported at the moment?

Contributor Author

No. I wrote 80 only because the default value of this parameter of the op is 80.

- op : weight_only_linear
args : (Tensor x, Tensor weight, Tensor bias, Tensor weight_scale, str weight_dtype, int arch = 80)
output : Tensor(out)
infer_meta :
func : WeightOnlyLinearInferMeta
kernel :
func : weight_only_linear
data_type : x
optional: bias
backward: weight_only_linear_grad

Contributor Author

The logic here was indeed not quite right. I've changed it to use the detected current architecture, consistent with the Python API's behavior when arch takes its default value of None.

paddle::framework::Scope *scope) {
pir::Builder builder = pir::Builder(ctx, program->block());

pir::Type fp32_dtype = pir::Float32Type::get(ctx);
Contributor

The unit test should also cover the float16 case.

Contributor Author

Done.


namespace {

inline int getSMVersion() {
Contributor

Use the platform::GetGPUComputeCapability(platform::GetCurrentDeviceId()) interface here.

Contributor Author

Done.

// Constraints.
//
src.RequireNativeCall([](const pir::drr::MatchContext &match_ctx) -> bool {
bool matmul_trans_x = match_ctx.Attr<bool>("matmul_transpose_x");
Contributor

If the SM version is not among the supported ones, the constraint needs to return false so that your pass does not take effect.

Contributor Author

Done.

class FusedWeightOnlyLinearPass : public pir::PatternRewritePass {
public:
FusedWeightOnlyLinearPass()
: pir::PatternRewritePass("fused_weight_only_linear_pass", 2) {}
Contributor

Set opt_level to 4 to indicate a fairly aggressive optimization, and add a comment in pass.h documenting opt_level=4.
[screenshot: the opt_level documentation in pass.h]

Contributor Author

Done.

Contributor Author
Wanglongzhi2001 commented Nov 30, 2023

@yuanlehome After I change opt_level to 4 here, the unit test no longer passes, and raising the opt_level of other passes breaks their unit tests too. But in the PIR source code I don't see any logic where opt_level affects pass execution, so let's keep opt_level at 2 here for now.

Contributor

Still set it to 4, and change the logic here instead.
[screenshot: the logic to change]

@Wanglongzhi2001
Contributor Author

@yuanlehome Hi, after making the changes, the C++ coverage check didn't pass.
[two screenshots of the coverage report]

Comment on lines 65 to 69
int sm_vesion = getSMVersion();
if (sm_vesion != 70 || sm_vesion != 80 || sm_vesion != 86 ||
sm_vesion != 75) {
return false;
}
Contributor

Sorry, I just remembered: this check should go in the CanApplyOn interface, which exists specifically to restrict where a pass applies.

Contributor

As for the coverage CI: once CI finishes, don't rush to push another commit; I'll check which lines aren't covered.

Contributor Author

OK.

Contributor Author

@yuanlehome CI has finished; could you take a look at the CI coverage?

Contributor

OK, I see it. The coverage check seems to be a bit broken; I'll look into fixing it tomorrow.


@unittest.skipIf(
not core.is_compiled_with_cuda()
or get_cuda_version() < 11020
Contributor

The missing coverage is probably because the test was skipped: the coverage CI's CUDA version is 10.2. Does this unit test pass for you locally?

Contributor Author

It passes locally for me.

Contributor Author

The missing coverage is probably because the test was skipped: the coverage CI's CUDA version is 10.2. Does this unit test pass for you locally?

So should I manually add a C++ unit test after all? weight_only_linear really does require CUDA version >= 11.2, and changing the CUDA version of the coverage CI cluster doesn't seem realistic.

Contributor

If it can be verified to pass, the coverage CI can be exempted.

Contributor Author

If it can be verified to pass, the coverage CI can be exempted.

OK.

Contributor

Let's go by the windows-inference CI result; its CUDA version is 11.2, so there's no need to add a C++ unit test. If the windows-inference CI runs this unit test and it passes, but coverage is still insufficient, the coverage CI can be exempted.

@Wanglongzhi2001
Contributor Author

Wanglongzhi2001 commented Dec 4, 2023

@yuanlehome CI has passed; please take a look. However, because my local machine doesn't have enough GPU memory, I haven't yet benchmarked this pass with real inference on PaddleNLP's weight-only large models such as Llama.

Contributor
@yuanlehome left a comment

LGTM

@yuanlehome
Contributor

@yuanlehome CI has passed; please take a look. However, because my local machine doesn't have enough GPU memory, I haven't yet benchmarked this pass with real inference on PaddleNLP's weight-only large models such as Llama.

My understanding is that you can equally well verify this pass on any model that contains a matmul op.

@Wanglongzhi2001
Contributor Author

@yuanlehome CI has passed; please take a look. However, because my local machine doesn't have enough GPU memory, I haven't yet benchmarked this pass with real inference on PaddleNLP's weight-only large models such as Llama.

My understanding is that you can equally well verify this pass on any model that contains a matmul op.

OK, thanks. I originally wanted to measure the speedup on a large model.

Contributor
@wanghuancoder left a comment

LGTM

@yuanlehome
Contributor

@yuanlehome CI has passed; please take a look. However, because my local machine doesn't have enough GPU memory, I haven't yet benchmarked this pass with real inference on PaddleNLP's weight-only large models such as Llama.

My understanding is that you can equally well verify this pass on any model that contains a matmul op.

OK, thanks. I originally wanted to measure the speedup on a large model.

If the machine is the limitation, you can ask your mentor to verify it.

@yuanlehome yuanlehome changed the title [Inference] Add matmul_to_weight_only_linear_pass [PIR & Inference] Add fused_weight_only_linear_pass Dec 4, 2023
res.Attr([](const pir::drr::MatchContext &match_ctx) -> std::any {
return "weight_only_int8";
});
// int arch = getSMVersion();
Contributor Author

Really sorry, there's a typo here: this line should be uncommented, and the 80 below replaced with this arch. I'll fix it right away.
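The intended corrected form, per this comment, would be along these lines (sketch):

int arch = getSMVersion();
const auto &weight_quantize_arch_attr =
    res.Attr([arch](const pir::drr::MatchContext &match_ctx) -> std::any {
      // Use the detected architecture instead of the hardcoded 80.
      return arch;
    });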

Contributor

Fix it in the next PR.

Contributor Author

OK, thanks.

#if defined(PADDLE_WITH_CUDA)
sm_version = paddle::platform::GetGPUComputeCapability(
paddle::platform::GetCurrentDeviceId());
#endif
Contributor

What about adding an #else branch that uses PADDLE_THROW to report that this build of Paddle was compiled without CUDA?

That might be friendlier.
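A sketch of that suggestion (assuming phi::errors is available in this file, as it is across Paddle):

#if defined(PADDLE_WITH_CUDA)
  sm_version = paddle::platform::GetGPUComputeCapability(
      paddle::platform::GetCurrentDeviceId());
#else
  // Fail loudly rather than silently keeping the default value.
  PADDLE_THROW(phi::errors::Unimplemented(
      "fused_weight_only_linear_pass requires Paddle built with CUDA."));
#endif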

res.Attr([](const pir::drr::MatchContext &match_ctx) -> std::any {
return "weight_only_int8";
});
// int arch = getSMVersion();
Contributor

remove

// int arch = getSMVersion();
const auto &weight_quantize_arch_attr =
res.Attr([&](const pir::drr::MatchContext &match_ctx) -> std::any {
return 80;
Contributor

The open-source version currently only has the weight-only linear kernel for the SM 80 architecture.

Later it actually splits: SM 70 has a special weight-only kernel, while SM 75/80/86/89 share another.

If this is hardcoded here, I think a TODO comment should be added.


bool CanApplyOn(pir::Operation *op) const override {
int sm_vesion = getSMVersion();
if (sm_vesion != 70 && sm_vesion != 80 && sm_vesion != 86 &&
Contributor

For now, only allow SM 80 in CanApplyOn, and add some comments.
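That is, something along these lines (sketch; the TODO wording is illustrative):

bool CanApplyOn(pir::Operation *op) const override {
  // TODO: only SM 80 is enabled for now. SM 70 needs a special
  // weight-only kernel and SM 75/86/89 share another; relax this
  // check once those kernels land in the open-source build.
  if (getSMVersion() != 80) return false;
  return op->isa<pir::ModuleOp>() && op->num_regions() > 0;
}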

return "int8";
});

const auto &weight_only_linear_arch_attr = res.Attr(
Contributor

Don't write this attr a second time; wouldn't reusing the earlier one be better for consistency? Otherwise a later edit might miss one of them.
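That is, define the arch attribute once and pass it to both ops (sketch):

const auto &arch_attr =
    res.Attr([](const pir::drr::MatchContext &match_ctx) -> std::any {
      return getSMVersion();
    });
// Reuse arch_attr for both pd_op.weight_quantize and
// pd_op.weight_only_linear so the two values cannot drift apart.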

@yuanlehome
Contributor

Please address the newly added comments together in the next PR; let's merge this one first.

@Wanglongzhi2001
Contributor Author

Please address the newly added comments together in the next PR; let's merge this one first.

OK.

@yuanlehome yuanlehome merged commit 1fe4974 into PaddlePaddle:develop Dec 4, 2023
29 checks passed
@yuanlehome
Contributor

yuanlehome commented Dec 4, 2023

Please address the newly added comments together in the next PR; let's merge this one first.

OK.

In the next PR's description, remember to reference this PR and mark it as a follow-up implementation. (Ideally submit and merge it before Dec 10.)

@Wanglongzhi2001
Contributor Author

Please address the newly added comments together in the next PR; let's merge this one first.

OK.

In the next PR's description, remember to reference this PR and mark it as a follow-up implementation.

OK.

SigureMo pushed a commit to gouzil/Paddle that referenced this pull request Dec 5, 2023
* [Inference]Add matmul_to_weight_only_linear_pass

* fix test and rename pass

* fix the comment of test

* fix ci

* fix: fix test

* refactor: refactor pass and test

* refactor: refactor pass

* refactor: add fp16 test

* refactor: refactor pass

* refactor: refactor the opt_level

* fix: fix typo

* fix: fix ci compile error when without gpu

* refactor: refactor pass and test

* fix: fix conflict

* fix: fix conflict

* refactor: refactor opt_level in pass_test to 4