Skip to content

Commit

Permalink
Update recurrent_utils.py
Browse files Browse the repository at this point in the history
Improved runtime efficiency of handle_xicn function.
  • Loading branch information
derrynknife committed Aug 17, 2023
1 parent 64b45c4 commit 6d05d3b
Showing 1 changed file with 30 additions and 19 deletions.
49 changes: 30 additions & 19 deletions surpyval/utils/recurrent_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,6 @@ def handle_xicn(x, i=None, c=None, n=None, Z=None, as_recurrent_data=True):
"Counts greater than 1 must be intervally or left censored"
)

# Check that if censored, it is the highest value
for ii in set(i):
ci = c[i == ii]
if len(ci[ci == 1]) > 1:
raise ValueError(f"Item {ii} is right censored more than once.")
if len(ci[ci == -1]) > 1:
raise ValueError(f"Item {ii} is left censored more than once.")

# sort by item and x
if x.ndim == 2:
# Order 2D by the midpoint
Expand All @@ -74,17 +66,36 @@ def handle_xicn(x, i=None, c=None, n=None, Z=None, as_recurrent_data=True):
if Z is not None:
Z = Z[sort_order]

# Check that the x values for each item are monotonically increasing
for ii in set(i):
xi = x[i == ii]
ci = c[i == ii]
if xi.ndim == 2:
for first, second in zip(xi[:-1], xi[1:]):
if first[1] > second[0]:
raise ValueError(f"Item {ii} has overlapping intervals")
else:
if np.any(np.diff(xi) < 0):
raise ValueError(f"Item {ii} has non-monotonic x values")
unique_i, idx = np.unique(i, return_index=True)
censoring_by_i = np.split(c, idx)[1:]

for ii, arr in zip(unique_i, censoring_by_i):
if 1 in arr:
if (arr == 1).sum() > 1:
raise ValueError(
f"Item {ii} has more than one right censored time"
)
if arr[-1] != 1:
raise ValueError(
f"Item {ii} has right censored event which is not the last"
)
if -1 in arr:
if (arr == -1).sum() > 1:
raise ValueError(
f"Item {ii} has more than one left censored event"
)
if arr[0] != -1:
raise ValueError(
f"Item {ii} has left censored event that is not the first"
)

if x.ndim == 2:
times_by_i = np.split(x, idx)[1:]
for ii, arr in zip(unique_i, times_by_i):
starts = arr[1:][:, 0]
ends = arr[:-1][:, 1]
if (ends > starts).any():
raise ValueError(f"Item {ii} has overlapping intervals")

if as_recurrent_data:
data = RecurrentEventData(x, i, c, n)
Expand Down

0 comments on commit 6d05d3b

Please sign in to comment.