Skip to content

Commit

Permalink
Embedded new CI to survivors
Browse files Browse the repository at this point in the history
  • Loading branch information
iuliivasilev committed Aug 5, 2024
1 parent 00fbd32 commit 0ebebcb
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 16 deletions.
32 changes: 17 additions & 15 deletions demonstration/Articles/Refactoring.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@
},
{
"cell_type": "code",
"execution_count": 27,
"execution_count": 33,
"id": "8ad4da3a",
"metadata": {},
"outputs": [],
Expand All @@ -222,7 +222,7 @@
"# concordant_pairs += 0.5\n",
" return concordant_pairs, total_pairs\n",
"\n",
"def concordance_index_self(T, P, E):\n",
"def concordance_index_self(event_times, predicted_scores, event_observed=None):\n",
" \"\"\"\n",
" Calculates the concordance index (C-index) for survival analysis.\n",
"\n",
Expand All @@ -234,11 +234,13 @@
" Returns:\n",
" The concordance index.\n",
" \"\"\"\n",
" order = np.argsort(T)\n",
" P = np.asarray(P)[order]\n",
" T = np.asarray(T)[order]\n",
" E = np.asarray(E)[order]\n",
" concordant_pairs, total_pairs = count_pairs(T, P, E)\n",
" if event_observed is None:\n",
" event_observed = np.ones(len(event_times))\n",
" order = np.argsort(event_times)\n",
" predicted_scores = np.asarray(predicted_scores)[order]\n",
" event_times = np.asarray(event_times)[order]\n",
" event_observed = np.asarray(event_observed)[order]\n",
" concordant_pairs, total_pairs = count_pairs(event_times, predicted_scores, event_observed)\n",
" \n",
" if total_pairs == 0:\n",
" return 0\n",
Expand All @@ -247,7 +249,7 @@
},
{
"cell_type": "code",
"execution_count": 28,
"execution_count": 39,
"id": "051083de",
"metadata": {},
"outputs": [],
Expand All @@ -264,7 +266,7 @@
},
{
"cell_type": "code",
"execution_count": 29,
"execution_count": 40,
"id": "7f5b0643",
"metadata": {},
"outputs": [],
Expand All @@ -274,12 +276,12 @@
" b = np.random.rand(100)*100\n",
" e = np.round(np.random.rand(100))\n",
" \n",
" assert concordance_index(a, b, e) == concordance_index_self(a, b, e), f\"{a}, {b}, {e}\""
" assert concordance_index(a, b) == concordance_index_self(a, b), f\"{a}, {b}, {e}\""
]
},
{
"cell_type": "code",
"execution_count": 30,
"execution_count": 36,
"id": "d6fdf195",
"metadata": {},
"outputs": [],
Expand All @@ -296,15 +298,15 @@
},
{
"cell_type": "code",
"execution_count": 31,
"execution_count": 37,
"id": "d7ba9fdc",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"652 ms ± 7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
"684 ms ± 4.06 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
]
}
],
Expand All @@ -314,15 +316,15 @@
},
{
"cell_type": "code",
"execution_count": 32,
"execution_count": 38,
"id": "3efaefab",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"21.9 ms ± 231 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
"22.7 ms ± 128 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
]
}
],
Expand Down
41 changes: 40 additions & 1 deletion survivors/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import numpy as np
from numba import njit, jit
from lifelines import KaplanMeierFitter, NelsonAalenFitter
from lifelines.utils import concordance_index
# from lifelines.utils import concordance_index

from .constants import TIME_NAME, CENS_NAME

Expand Down Expand Up @@ -498,6 +498,45 @@ def bic(k, n, time, cens, sf, cumhf, bins):
return k*np.log(n) - 2*loglikelihood(time, cens, sf, cumhf, bins)


@njit
def count_pairs(T, P, E):
n = len(T)
concordant_pairs = 0
total_pairs = 0
for i in range(n):
for j in range(i + 1, n):
if E[i] == 1 and T[i] <= T[j]:
total_pairs += 1
concordant_pairs += P[i] < P[j]
concordant_pairs += 0.5 * (P[i] == P[j])
return concordant_pairs, total_pairs


def concordance_index_self(event_times, predicted_scores, event_observed=None):
"""
Calculates the concordance index (C-index) for survival analysis.
Args:
event_times: Array of true event times.
predicted_scores: Array of predicted event times.
event_observed: Array of event indicators (1 if event occurred, 0 if censored).
Returns:
The concordance index.
"""
if event_observed is None:
event_observed = np.ones(len(event_times))
order = np.argsort(event_times)
predicted_scores = np.asarray(predicted_scores)[order]
event_times = np.asarray(event_times)[order]
event_observed = np.asarray(event_observed)[order]
concordant_pairs, total_pairs = count_pairs(event_times, predicted_scores, event_observed)

if total_pairs == 0:
return 0
return concordant_pairs / total_pairs


""" ESTIMATE FUNCTION """


Expand Down

0 comments on commit 0ebebcb

Please sign in to comment.