From f4d58d5814ba082be9821b64dde38bc317072730 Mon Sep 17 00:00:00 2001 From: "Todd A. Anderson" Date: Mon, 5 Feb 2024 14:33:56 -0800 Subject: [PATCH] Fix ordering of calc_map_internal overload. --- ramba/shardview_array.py | 44 ++++++++++++++++++++-------------------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/ramba/shardview_array.py b/ramba/shardview_array.py index 11a2ac7..0205af9 100644 --- a/ramba/shardview_array.py +++ b/ramba/shardview_array.py @@ -389,6 +389,28 @@ def has_index(sv, index): # return (_index_start(sv)<=index).all() and (index<_stop(sv)).all() +def calc_map_internal(sl_i, sv_s, sv_e, sv_st): + if sl_i.step is None: + s = min(max(sl_i.start, sv_s), sv_e) + e = min(max(sl_i.stop, sv_s), sv_e) + sz = e-s + si = s-sl_i.start + st = sv_st + else: + if sl_i.step>0: + s = min(max(sl_i.start, sv_s + (sl_i.start-sv_s)%sl_i.step), sv_e) + e = min(max(sl_i.stop - (sl_i.stop-1-sl_i.start)%sl_i.step, sv_s), sv_e) + si = max(0,(s-sl_i.start)//sl_i.step) + else: + s = min(max(sl_i.stop + 1 + (sl_i.start-sl_i.stop-1)%abs(sl_i.step), sv_s+(sl_i.start-sv_s)%abs(sl_i.step)), sv_e) + e = min(max(sl_i.start+1, sv_s), sv_e) + si = max(0,int(np.ceil((e-1-sl_i.start)/sl_i.step))) + sz = int(np.ceil((e-s)/abs(sl_i.step))) + st = sv_st*sl_i.step + e = s + (sz-1)*abs(sl_i.step)+1 + return s, max(s,e-1), sz, si, st + + @overload(calc_map_internal, nopython=True, cache=True) def calc_map_internal(sl_i, sv_s, sv_e, sv_st): if isinstance(sl_i,numba.types.SliceType): @@ -426,28 +448,6 @@ def impl(sl_i, sv_s, sv_e, sv_st): raise numba.core.errors.TypingError("ERR: slice contains something unexpected!", type(sl_i)) -def calc_map_internal(sl_i, sv_s, sv_e, sv_st): - if sl_i.step is None: - s = min(max(sl_i.start, sv_s), sv_e) - e = min(max(sl_i.stop, sv_s), sv_e) - sz = e-s - si = s-sl_i.start - st = sv_st - else: - if sl_i.step>0: - s = min(max(sl_i.start, sv_s + (sl_i.start-sv_s)%sl_i.step), sv_e) - e = min(max(sl_i.stop - (sl_i.stop-1-sl_i.start)%sl_i.step, sv_s), sv_e) - si = max(0,(s-sl_i.start)//sl_i.step) - else: - s = min(max(sl_i.stop + 1 + (sl_i.start-sl_i.stop-1)%abs(sl_i.step), sv_s+(sl_i.start-sv_s)%abs(sl_i.step)), sv_e) - e = min(max(sl_i.start+1, sv_s), sv_e) - si = max(0,int(np.ceil((e-1-sl_i.start)/sl_i.step))) - sz = int(np.ceil((e-s)/abs(sl_i.step))) - st = sv_st*sl_i.step - e = s + (sz-1)*abs(sl_i.step)+1 - return s, max(s,e-1), sz, si, st - - @numba.njit(fastmath=fastmath, cache=True) def mapslice(sv, sl): # assert len(sl) == len_size(sv) # This assert causes compilation failure at i+=1; likely due to Numba bug; need to revisit