Skip to content

Commit

Permalink
Update recurrent_event_data.py
Browse files Browse the repository at this point in the history
updated RecurrentEventData to accommodate for 2D x
  • Loading branch information
derrynknife committed Jul 28, 2023
1 parent a5a7266 commit 1e2d131
Showing 1 changed file with 14 additions and 3 deletions.
17 changes: 14 additions & 3 deletions surpyval/utils/recurrent_event_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,27 @@ def __init__(self, x, i, c, n):
self.n = np.atleast_1d(n)
self.items = list(set(self.i))

self.interarrival_times = self.find_interarrival_times(x, i)
if self.x.ndim == 1:
self.interarrival_times = self.find_interarrival_times(x, i)
else:
self.midpoints = self.x.mean(axis=1)
self._index = 0

def to_xrd(self, estimator="Nelson-Aalen"):
if not hasattr(self, "xrd"):
x_unique = np.unique(self.x)
# find the total number of times an event occurs at each x
if self.x.ndim == 2:
x_out = self.midpoints
else:
x_out = self.x

x_unique = np.unique(x_out)

d = np.array(
[
self.n[(self.x == xi) & (self.c == 0)].sum()
self.n[
(x_out == xi) & ((self.c == 0) | (self.c == 2))
].sum()
for xi in x_unique
]
)
Expand Down

0 comments on commit 1e2d131

Please sign in to comment.