Commit 14ac734

Merge branch 'develop' into fix_test_match_matrix_tensor_op

xingmingyyj authored Dec 26, 2023
2 parents 315238a + 96b9068

Showing 356 changed files with 6,564 additions and 3,018 deletions.
2 changes: 1 addition & 1 deletion cmake/external/xpu.cmake
@@ -31,7 +31,7 @@ endif()
 if(NOT DEFINED XPU_XHPC_BASE_DATE)
   set(XPU_XHPC_BASE_DATE "20231215")
 endif()
-set(XPU_XCCL_BASE_VERSION "1.1.7.1")
+set(XPU_XCCL_BASE_VERSION "1.1.8.1")
 if(NOT DEFINED XPU_XFT_BASE_VERSION)
   set(XPU_XFT_BASE_VERSION "20230602")
 endif()
57 changes: 42 additions & 15 deletions paddle/cinn/common/integer_set.cc
@@ -58,6 +58,14 @@ std::optional<bool> SymbolicExprAnalyzer::ProveEQ(const ir::Expr& lhs,
   if (diff.is_constant()) {
     return diff.get_constant() == 0;
   }
+  ir::Expr diff_lower_bound = LowerBound(diff);
+  VLOG(6) << "lower bound of " << diff << " = " << diff_lower_bound;
+  ir::Expr diff_upper_bound = UpperBound(diff);
+  VLOG(6) << "upper bound of " << diff << " = " << diff_upper_bound;
+  if (diff_lower_bound.is_constant() && diff_upper_bound.is_constant() &&
+      diff_lower_bound.get_constant() == diff_upper_bound.get_constant()) {
+    return diff_lower_bound.get_constant() == 0;
+  }
   std::optional<bool> prove_gt = ProveGT(lhs, rhs);
   if (prove_gt.has_value() && prove_gt.value()) {
     return false;
@@ -71,22 +79,11 @@ std::optional<bool> SymbolicExprAnalyzer::ProveEQ(const ir::Expr& lhs,
 
 std::optional<bool> SymbolicExprAnalyzer::ProveNE(const ir::Expr& lhs,
                                                   const ir::Expr& rhs) const {
-  if (lhs == rhs) {
-    return false;
-  }
-  ir::Expr diff = AutoSimplify(ir::Sub::Make(lhs, rhs), var_intervals_);
-  if (diff.is_constant()) {
-    return diff.get_constant() != 0;
-  }
-  std::optional<bool> prove_gt = ProveGT(lhs, rhs);
-  if (prove_gt.has_value() && prove_gt.value()) {
-    return true;
-  }
-  std::optional<bool> prove_lt = ProveLT(lhs, rhs);
-  if (prove_lt.has_value() && prove_lt.value()) {
-    return true;
+  std::optional<bool> prove_eq = ProveEQ(lhs, rhs);
+  if (!prove_eq.has_value()) {
+    return std::nullopt;
   }
-  return std::nullopt;
+  return !prove_eq.value();
 }
 
 std::optional<bool> SymbolicExprAnalyzer::ProveGE(const ir::Expr& lhs,
@@ -456,5 +453,35 @@ std::optional<bool> SingleIntervalIntSet::ProveSuperSet(
   return std::nullopt;
 }
 
+ir::Expr EnhancedSimplifyModExpr(
+    ir::Expr e,
+    const absl::flat_hash_map<std::string, CasInterval>& var_intervals) {
+  struct Mutator : public ir::IRMutator<ir::Expr*> {
+    explicit Mutator(
+        const absl::flat_hash_map<std::string, CasInterval>& var_intervals)
+        : var_intervals_(var_intervals), analyzer_(var_intervals_) {}
+
+    void operator()(ir::Expr* expr) { Visit(expr); }
+    void Visit(ir::Expr* expr) { ir::IRMutator<>::Visit(expr, expr); }
+
+   private:
+    void Visit(const ir::Mod* op, ir::Expr* expr) override {
+      std::optional<bool> prove_lt = analyzer_.ProveLT(op->a(), op->b());
+      if (prove_lt.has_value() && prove_lt.value()) {
+        *expr = op->a();
+      }
+    }
+
+   private:
+    const absl::flat_hash_map<std::string, CasInterval>& var_intervals_;
+    SymbolicExprAnalyzer analyzer_;
+  };
+
+  Mutator mutator(var_intervals);
+  ir::Expr copied = ir::ir_utils::IRCopy(e);
+  mutator(&copied);
+  return copied;
+}
+
 }  // namespace common
 }  // namespace cinn
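
The new EnhancedSimplifyModExpr folds a % b down to a whenever the analyzer can prove a < b from the variables' intervals. A minimal usage sketch (not part of the commit; it assumes CasInterval can be constructed from integer bounds, as elsewhere in CINN's CAS utilities):

#include "paddle/cinn/common/integer_set.h"

void SimplifyModDemo() {
  // Assumed setup: i is known to lie in [0, 1023].
  absl::flat_hash_map<std::string, cinn::common::CasInterval> intervals;
  intervals.emplace("i", cinn::common::CasInterval(0, 1023));

  ir::Var i(ir::Expr(0), ir::Expr(1023), "i");
  ir::Expr e = ir::Mod::Make(ir::Expr(i), ir::Expr(1024));

  // ProveLT(i, 1024) succeeds under the interval, so the Mod node is
  // replaced by its left operand and `simplified` is just `i`.
  ir::Expr simplified = cinn::common::EnhancedSimplifyModExpr(e, intervals);
}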
20 changes: 17 additions & 3 deletions paddle/cinn/common/integer_set.h
@@ -58,9 +58,10 @@ struct SymbolicExprLimit {
 // The set consisting of all integers in the interval from min to max
 class SingleIntervalIntSet {
  public:
-  explicit SingleIntervalIntSet(const ir::Expr& min,
-                                const ir::Expr& max,
-                                cas_intervals_t var_intervals = {});
+  explicit SingleIntervalIntSet(
+      const ir::Expr& min = SymbolicExprLimit::positive_inf,
+      const ir::Expr& max = SymbolicExprLimit::negative_inf,
+      cas_intervals_t var_intervals = {});
   SingleIntervalIntSet(const SingleIntervalIntSet& set) = default;
   SingleIntervalIntSet(SingleIntervalIntSet&& set) = default;
   SingleIntervalIntSet& operator=(const SingleIntervalIntSet& set) = default;
@@ -92,5 +93,18 @@ class SingleIntervalIntSet {
   cas_intervals_t var_intervals_;
 };
 
+std::optional<bool> ProveEQ(const SingleIntervalIntSet& lhs,
+                            const SingleIntervalIntSet& rhs);
+std::optional<SingleIntervalIntSet> ProvedUnion(const SingleIntervalIntSet& a,
+                                                const SingleIntervalIntSet& b);
+std::optional<SingleIntervalIntSet> ProvedIntersect(
+    const SingleIntervalIntSet& a, const SingleIntervalIntSet& b);
+cas_intervals_t MergeVarIntervals(const SingleIntervalIntSet& a,
+                                  const SingleIntervalIntSet& b);
+
+ir::Expr EnhancedSimplifyModExpr(
+    ir::Expr e,
+    const absl::flat_hash_map<std::string, CasInterval>& var_intervals);
+
 }  // namespace common
 }  // namespace cinn
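
Two details worth noting in this header. The new constructor defaults invert the bounds (min = positive_inf, max = negative_inf), which appears to encode the empty set, matching the ProveEmpty queries used in the tests. And the set relations are free functions returning std::optional, so callers must distinguish "proved false" from "could not prove". An illustrative sketch under those assumptions (the bounds here are made up; the declarations are from the header above):

using cinn::common::SingleIntervalIntSet;

void IntSetDemo() {
  SingleIntervalIntSet a{ir::Expr(0), ir::Expr(7)};
  SingleIntervalIntSet b{ir::Expr(4), ir::Expr(15)};

  std::optional<bool> eq = ProveEQ(a, b);  // expect false here
  if (!eq.has_value()) {
    // The analyzer could not decide; treat as unknown, not as unequal.
  }

  // "Proved" set operations: nullopt means no single covering interval
  // could be established, not that the result is empty.
  std::optional<SingleIntervalIntSet> u = ProvedUnion(a, b);      // [0, 15]
  std::optional<SingleIntervalIntSet> n = ProvedIntersect(a, b);  // [4, 7]
}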
13 changes: 13 additions & 0 deletions paddle/cinn/common/integer_set_test.cc
@@ -278,5 +278,18 @@ TEST(SingleIntervalIntSet, case_1) {
       ProvedIntersect(set_0, single_point).value().ProveEmpty().value());
 }
 
+TEST(SingleIntervalIntSet, case_2) {
+  ir::Var S = ir::Var(ir::Expr(0), ir::Expr(0), "S");
+
+  SingleIntervalIntSet set_0{S, S + Expr(1)};
+  SingleIntervalIntSet set_1{Expr(0), Expr(1)};
+  SingleIntervalIntSet set_2{Expr(0), Expr(2)};
+
+  EXPECT_TRUE(ProveEQ(set_0, set_1).value());
+  EXPECT_FALSE(ProveEQ(set_0, set_2).value());
+  EXPECT_TRUE(set_0.ProveSubSet(set_2).value());
+  EXPECT_TRUE(set_2.ProveSuperSet(set_0).value());
+}
+
 }  // namespace common
 }  // namespace cinn
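
The case works because S carries the interval [0, 0]: the analyzer can prove S == 0 and S + 1 == 1, so set_0 has the same proved bounds as set_1 ([0, 1]), while set_2 ([0, 2]) merely contains it, giving the subset/superset results.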
2 changes: 2 additions & 0 deletions paddle/cinn/ir/group_schedule/CMakeLists.txt
@@ -3,3 +3,5 @@ core_gather_headers()
 gather_srcs(cinnapi_src SRCS base_group_scheduler.cc)
 gather_srcs(cinnapi_src SRCS st_shape_group_scheduler.cc)
 gather_srcs(cinnapi_src SRCS dy_shape_group_scheduler.cc)
+
+add_subdirectory(tactic)
8 changes: 8 additions & 0 deletions paddle/cinn/ir/group_schedule/base_group_scheduler.cc
@@ -33,5 +33,13 @@ std::unique_ptr<GroupScheduler> GroupScheduler::Make(
   }
 }
 
+std::unordered_set<std::string> GroupScheduler::OutputTensorNames() const {
+  std::unordered_set<std::string> output_tensor_names{output_tensor_names_};
+  for (ir::ScheduleBlockNode* node : schedule_block_graph_->EndPoints()) {
+    output_tensor_names.insert(node->id());
+  }
+  return output_tensor_names;
+}
+
 }  // namespace ir
 }  // namespace cinn
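
This definition is hoisted out of StaticShapeGroupScheduler (removed further down) into the GroupScheduler base class so the new dynamic-shape scheduler can reuse it: a group's outputs are the user-specified output tensor names plus the end-point blocks of the schedule-block graph.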
2 changes: 2 additions & 0 deletions paddle/cinn/ir/group_schedule/base_group_scheduler.h
@@ -48,6 +48,8 @@ class GroupScheduler {
 
   virtual std::vector<std::pair<SymbolicPredicate, ir::Expr>> GetIRs() = 0;
 
+  std::unordered_set<std::string> OutputTensorNames() const;
+
  protected:
   ir::IRSchedule* ir_sch_;
   const std::unordered_set<std::string>& output_tensor_names_;
20 changes: 20 additions & 0 deletions paddle/cinn/ir/group_schedule/dy_shape_group_scheduler.cc
@@ -13,10 +13,16 @@
 // limitations under the License.
 
 #include "paddle/cinn/ir/group_schedule/dy_shape_group_scheduler.h"
+#include "paddle/cinn/ir/group_schedule/tactic/arrange_storage_tactic.h"
 
 namespace cinn {
 namespace ir {
 
+void DynamicShapeGroupScheduler::Init() {
+  std::unordered_set<std::string> output_names = OutputTensorNames();
+  tactics_.emplace_back(new ArrangeStorageTactic(output_names));
+}
+
 void DynamicShapeGroupScheduler::Schedule() {
   // Fake schedule for test
   std::vector<Expr> all_blocks = ir_sch_->GetAllBlocks();
@@ -36,12 +42,26 @@ void DynamicShapeGroupScheduler::Schedule() {
   auto splited_loops1 = ir_sch_->Split(block0_loops[0], {1024, -1});
 
   ir_sch_->Bind(splited_loops1[0], "threadIdx.x");
+
+  ApplyTactics();
+
   ir::Expr predicate1 = ir::LE::Make(Expr(1023), Expr(1024));
   std::unique_ptr<ir::IRSchedule> new_ir_sch1 =
       std::make_unique<ir::IRSchedule>(*ir_sch_);
   ir_schs_.emplace_back(predicate1, std::move(new_ir_sch1));
 }
 
+void DynamicShapeGroupScheduler::ApplyTactics() {
+  schedule_block_graph_->Update(*ir_sch_);
+  for (const auto& tactic : tactics_) {
+    auto ApplyTacticFunc = [&](ir::ScheduleBlockNode* node) {
+      tactic->Apply(ir_sch_, node->id());
+    };
+    schedule_block_graph_->DFSTopoWalk(ApplyTacticFunc);
+    schedule_block_graph_->Update(*ir_sch_);
+  }
+}
+
 std::vector<std::pair<SymbolicPredicate, ir::Expr>>
 DynamicShapeGroupScheduler::GetIRs() {
   std::vector<std::pair<SymbolicPredicate, ir::Expr>> irs;
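
ApplyTactics walks the schedule-block graph in DFS topological order, applies each registered tactic to every block, and calls Update after each pass, since a tactic may restructure the IR. A hedged sketch of what a tactic could look like; the real base class lives in tactic/schedule_tactic.h and its interface is only inferred here from the call tactic->Apply(ir_sch_, node->id()):

// Hypothetical tactic, not part of the commit. Assumes ScheduleTactic
// declares `virtual void Apply(ir::IRSchedule* sch, const std::string& block_id)`.
class NoopLoggingTactic final : public ScheduleTactic {
 public:
  void Apply(ir::IRSchedule* sch, const std::string& block_id) override {
    // Inspect the loops surrounding this schedule block; a real tactic
    // would rewrite them (split, fuse, bind, arrange storage, ...).
    std::vector<ir::Expr> loops = sch->GetLoops(block_id);
    VLOG(6) << "block " << block_id << " has " << loops.size() << " loops";
  }
};

// Registration would mirror Init() above:
//   tactics_.emplace_back(new NoopLoggingTactic());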
11 changes: 10 additions & 1 deletion paddle/cinn/ir/group_schedule/dy_shape_group_scheduler.h
@@ -14,6 +14,7 @@
 
 #pragma once
 #include "paddle/cinn/ir/group_schedule/base_group_scheduler.h"
+#include "paddle/cinn/ir/group_schedule/tactic/schedule_tactic.h"
 
 namespace cinn {
 namespace ir {
@@ -28,15 +29,23 @@ class DynamicShapeGroupScheduler : public GroupScheduler {
       ir::IRSchedule* ir_sch,
       const std::unordered_set<std::string>& output_tensor_names,
       const cinn::common::Target& target)
-      : GroupScheduler(ir_sch, output_tensor_names, target) {}
+      : GroupScheduler(ir_sch, output_tensor_names, target) {
+    Init();
+  }
 
   void Schedule() override;
 
   std::vector<std::pair<SymbolicPredicate, ir::Expr>> GetIRs() override;
 
+ private:
+  void Init();
+
+  void ApplyTactics();
+
  private:
   std::vector<std::pair<SymbolicPredicate, std::unique_ptr<ir::IRSchedule>>>
       ir_schs_;
+  std::vector<std::unique_ptr<ScheduleTactic>> tactics_;
 };
 
 }  // namespace ir
9 changes: 0 additions & 9 deletions paddle/cinn/ir/group_schedule/st_shape_group_scheduler.cc
@@ -205,15 +205,6 @@ ir::ScheduleBlockNode* StaticShapeGroupScheduler::FindGlobalMasterNode() const {
   return master;
 }
 
-std::unordered_set<std::string> StaticShapeGroupScheduler::OutputTensorNames()
-    const {
-  std::unordered_set<std::string> output_tensor_names{output_tensor_names_};
-  for (ir::ScheduleBlockNode* node : schedule_block_graph_->EndPoints()) {
-    output_tensor_names.insert(node->id());
-  }
-  return output_tensor_names;
-}
-
 void StaticShapeGroupScheduler::DoLoopAlignment() {
   VLOG(5) << "[Start LoopAlignment] func body: "
           << ir_sch_->GetModule().GetExprs().front();
3 changes: 0 additions & 3 deletions paddle/cinn/ir/group_schedule/st_shape_group_scheduler.h
@@ -94,9 +94,6 @@ class StaticShapeGroupScheduler : public GroupScheduler {
   // throughout the entire IR.
   void UpdateBlockOrder();
 
-  // Get output tensor names of group.
-  std::unordered_set<std::string> OutputTensorNames() const;
-
   /**
    * @brief Determine whether the graph level dependency is still maintained
    * after the schedule_block is placed in the insert position of target_loop.
3 changes: 3 additions & 0 deletions paddle/cinn/ir/group_schedule/tactic/CMakeLists.txt
@@ -0,0 +1,3 @@
+core_gather_headers()
+
+gather_srcs(cinnapi_src SRCS arrange_storage_tactic.cc)