Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deprecate add_guard #195

Merged
merged 14 commits into from
May 23, 2022
30 changes: 4 additions & 26 deletions src/exo/API.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,30 +613,6 @@ def add_assertion(self, assertion):
p.instr, p.eff, p.srcinfo)
return Procedure(p, _provenance_eq_Procedure=None)

def add_guard(self, stmt_pat, iter_pat, value):
if not isinstance(stmt_pat, str):
raise TypeError("expected first arg to be a string")
if not isinstance(iter_pat, str):
raise TypeError("expected second arg to be a string")
if not isinstance(value, int):
raise TypeError("expected third arg to be an int")
# TODO: refine this analysis or re-think the directive...
# this is making sure that the condition will guarantee that the
# guarded statement runs on the first iteration
if value != 0:
raise TypeError("expected third arg to be 0")

iter_pat = iter_name_to_pattern(iter_pat)
iter_pat = self._find_stmt(iter_pat)
if not isinstance(iter_pat, LoopIR.Seq):
raise TypeError("expected the loop to be sequential")
stmts = self._find_stmt(stmt_pat, default_match_no=None)
loopir = self._loopir_proc
for s in stmts:
loopir = Schedules.DoAddGuard(loopir, s, iter_pat, value).result()

return Procedure(loopir, _provenance_eq_Procedure=self)

def bound_and_guard(self, loop):
"""
Replace
Expand Down Expand Up @@ -693,17 +669,19 @@ def fuse_if(self, if1, if2):

return Procedure(loopir, _provenance_eq_Procedure=self)

def add_loop(self, stmt, var, hi):
def add_loop(self, stmt, var, hi, *, guard=False):
if not isinstance(stmt, str):
raise TypeError("expected first arg to be a string")
if not isinstance(var, str):
raise TypeError("expected second arg to be a string")
if not isinstance(hi, int):
raise TypeError("currently, only constant bound is supported")
if not isinstance(guard, bool):
raise TypeError("guard needs to be True or False")
alexreinking marked this conversation as resolved.
Show resolved Hide resolved

stmt = self._find_stmt(stmt)
loopir = self._loopir_proc
loopir = Schedules.DoAddLoop(loopir, stmt, var, hi).result()
loopir = Schedules.DoAddLoop(loopir, stmt, var, hi, guard).result()

return Procedure(loopir, _provenance_eq_Procedure=self)

Expand Down
14 changes: 11 additions & 3 deletions src/exo/LoopIR_scheduling.py
Original file line number Diff line number Diff line change
Expand Up @@ -2412,10 +2412,11 @@ def map_stmts(self, stmts):


class _DoAddLoop(LoopIR_Rewrite):
def __init__(self, proc, stmt, var, hi):
def __init__(self, proc, stmt, var, hi, guard):
self.stmt = stmt
self.var = var
self.hi = hi
self.guard = guard

super().__init__(proc)

Expand All @@ -2427,8 +2428,15 @@ def map_s(self, s):
raise SchedulingError("expected stmt to be idempotent!")

sym = Sym(self.var)
hi = LoopIR.Const(self.hi, T.int, s.srcinfo)
ir = LoopIR.ForAll(sym, hi, [s], None, s.srcinfo)

new_s = s
if self.guard:
cond = LoopIR.BinOp('==', LoopIR.Read(sym, [], T.index, s.srcinfo),
LoopIR.Const(0, T.int, s.srcinfo), T.bool, s.srcinfo)
new_s = LoopIR.If(cond, [s], [], None, s.srcinfo)

hi = LoopIR.Const(self.hi, T.int, new_s.srcinfo)
ir = LoopIR.ForAll(sym, hi, [new_s], None, new_s.srcinfo)
return [ir]

return super().map_s(s)
Expand Down
Loading