Skip to content

Commit

Permalink
Fix ordering of calc_map_internal overload.
Browse files Browse the repository at this point in the history
  • Loading branch information
DrTodd13 committed Feb 5, 2024
1 parent 35590ea commit f4d58d5
Showing 1 changed file with 22 additions and 22 deletions.
44 changes: 22 additions & 22 deletions ramba/shardview_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,28 @@ def has_index(sv, index):
# return (_index_start(sv)<=index).all() and (index<_stop(sv)).all()


def calc_map_internal(sl_i, sv_s, sv_e, sv_st):
if sl_i.step is None:
s = min(max(sl_i.start, sv_s), sv_e)
e = min(max(sl_i.stop, sv_s), sv_e)
sz = e-s
si = s-sl_i.start
st = sv_st
else:
if sl_i.step>0:
s = min(max(sl_i.start, sv_s + (sl_i.start-sv_s)%sl_i.step), sv_e)
e = min(max(sl_i.stop - (sl_i.stop-1-sl_i.start)%sl_i.step, sv_s), sv_e)
si = max(0,(s-sl_i.start)//sl_i.step)
else:
s = min(max(sl_i.stop + 1 + (sl_i.start-sl_i.stop-1)%abs(sl_i.step), sv_s+(sl_i.start-sv_s)%abs(sl_i.step)), sv_e)
e = min(max(sl_i.start+1, sv_s), sv_e)
si = max(0,int(np.ceil((e-1-sl_i.start)/sl_i.step)))
sz = int(np.ceil((e-s)/abs(sl_i.step)))
st = sv_st*sl_i.step
e = s + (sz-1)*abs(sl_i.step)+1
return s, max(s,e-1), sz, si, st


@overload(calc_map_internal, nopython=True, cache=True)
def calc_map_internal(sl_i, sv_s, sv_e, sv_st):
if isinstance(sl_i,numba.types.SliceType):
Expand Down Expand Up @@ -426,28 +448,6 @@ def impl(sl_i, sv_s, sv_e, sv_st):
raise numba.core.errors.TypingError("ERR: slice contains something unexpected!", type(sl_i))


def calc_map_internal(sl_i, sv_s, sv_e, sv_st):
if sl_i.step is None:
s = min(max(sl_i.start, sv_s), sv_e)
e = min(max(sl_i.stop, sv_s), sv_e)
sz = e-s
si = s-sl_i.start
st = sv_st
else:
if sl_i.step>0:
s = min(max(sl_i.start, sv_s + (sl_i.start-sv_s)%sl_i.step), sv_e)
e = min(max(sl_i.stop - (sl_i.stop-1-sl_i.start)%sl_i.step, sv_s), sv_e)
si = max(0,(s-sl_i.start)//sl_i.step)
else:
s = min(max(sl_i.stop + 1 + (sl_i.start-sl_i.stop-1)%abs(sl_i.step), sv_s+(sl_i.start-sv_s)%abs(sl_i.step)), sv_e)
e = min(max(sl_i.start+1, sv_s), sv_e)
si = max(0,int(np.ceil((e-1-sl_i.start)/sl_i.step)))
sz = int(np.ceil((e-s)/abs(sl_i.step)))
st = sv_st*sl_i.step
e = s + (sz-1)*abs(sl_i.step)+1
return s, max(s,e-1), sz, si, st


@numba.njit(fastmath=fastmath, cache=True)
def mapslice(sv, sl):
# assert len(sl) == len_size(sv) # This assert causes compilation failure at i+=1; likely due to Numba bug; need to revisit
Expand Down

0 comments on commit f4d58d5

Please sign in to comment.