-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Created notebook for refactoring news
- Loading branch information
1 parent
e883428
commit aafccaf
Showing
1 changed file
with
310 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,310 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 27, | ||
"id": "7ca4db35", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from lifelines.utils import concordance_index" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"id": "fc02d133", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import numpy as np\n", | ||
"from lifelines.utils.btree import _BTree\n", | ||
"\n", | ||
"def concordance_index(event_times, predicted_scores, event_observed=None) -> float:\n", | ||
" if event_observed is None:\n", | ||
" event_observed = np.ones(event_times.shape[0], dtype=float)\n", | ||
" \n", | ||
" num_correct, num_tied, num_pairs = _concordance_summary_statistics(event_times, predicted_scores, event_observed)\n", | ||
"\n", | ||
" if num_pairs == 0:\n", | ||
" raise ZeroDivisionError(\"No admissable pairs in the dataset.\")\n", | ||
" return (num_correct + num_tied / 2) / num_pairs\n", | ||
"\n", | ||
"\n", | ||
"def _concordance_summary_statistics(event_times, predicted_event_times, event_observed):\n", | ||
" if np.logical_not(event_observed).all():\n", | ||
" return (0, 0, 0)\n", | ||
"\n", | ||
" died_mask = event_observed.astype(bool)\n", | ||
" died_truth = event_times[died_mask]\n", | ||
" ix = np.argsort(died_truth)\n", | ||
" died_truth = died_truth[ix]\n", | ||
" died_pred = predicted_event_times[died_mask][ix]\n", | ||
"\n", | ||
" censored_truth = event_times[~died_mask]\n", | ||
" ix = np.argsort(censored_truth)\n", | ||
" censored_truth = censored_truth[ix]\n", | ||
" censored_pred = predicted_event_times[~died_mask][ix]\n", | ||
"\n", | ||
" censored_ix = 0\n", | ||
" died_ix = 0\n", | ||
" times_to_compare = _BTree(np.unique(died_pred))\n", | ||
" print(np.unique(died_pred), times_to_compare)\n", | ||
" num_pairs = np.int64(0)\n", | ||
" num_correct = np.int64(0)\n", | ||
" num_tied = np.int64(0)\n", | ||
"\n", | ||
" # we iterate through cases sorted by exit time:\n", | ||
" # - First, all cases that died at time t0. We add these to the sortedlist of died times.\n", | ||
" # - Then, all cases that were censored at time t0. We DON'T add these since they are NOT\n", | ||
" # comparable to subsequent elements.\n", | ||
" while True:\n", | ||
" has_more_censored = censored_ix < len(censored_truth)\n", | ||
" has_more_died = died_ix < len(died_truth)\n", | ||
" # Should we look at some censored indices next, or died indices?\n", | ||
" if has_more_censored and (not has_more_died or died_truth[died_ix] > censored_truth[censored_ix]):\n", | ||
" pairs, correct, tied, next_ix = _handle_pairs(censored_truth, censored_pred, censored_ix, times_to_compare)\n", | ||
" censored_ix = next_ix\n", | ||
" elif has_more_died and (not has_more_censored or died_truth[died_ix] <= censored_truth[censored_ix]):\n", | ||
" pairs, correct, tied, next_ix = _handle_pairs(died_truth, died_pred, died_ix, times_to_compare)\n", | ||
" for pred in died_pred[died_ix:next_ix]:\n", | ||
" times_to_compare.insert(pred)\n", | ||
" died_ix = next_ix\n", | ||
" else:\n", | ||
" assert not (has_more_died or has_more_censored)\n", | ||
" break\n", | ||
"\n", | ||
" num_pairs += pairs\n", | ||
" num_correct += correct\n", | ||
" num_tied += tied\n", | ||
"\n", | ||
" return (num_correct, num_tied, num_pairs)\n", | ||
"\n", | ||
"\n", | ||
"def _handle_pairs(truth, pred, first_ix, times_to_compare):\n", | ||
" \"\"\"\n", | ||
" Handle all pairs that exited at the same time as truth[first_ix].\n", | ||
"\n", | ||
" Returns\n", | ||
" -------\n", | ||
" (pairs, correct, tied, next_ix)\n", | ||
" new_pairs: The number of new comparisons performed\n", | ||
" new_correct: The number of comparisons correctly predicted\n", | ||
" next_ix: The next index that needs to be handled\n", | ||
" \"\"\"\n", | ||
" next_ix = first_ix\n", | ||
" while next_ix < len(truth) and truth[next_ix] == truth[first_ix]:\n", | ||
" next_ix += 1\n", | ||
" pairs = len(times_to_compare) * (next_ix - first_ix)\n", | ||
" correct = np.int64(0)\n", | ||
" tied = np.int64(0)\n", | ||
" for i in range(first_ix, next_ix):\n", | ||
" rank, count = times_to_compare.rank(pred[i])\n", | ||
" correct += rank\n", | ||
" tied += count\n", | ||
"\n", | ||
" return (pairs, correct, tied, next_ix)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 22, | ||
"id": "fb905b58", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def _concordance_index(risk, T, E, include_ties=True):\n", | ||
" N = len(risk)\n", | ||
" censored_survival = []\n", | ||
" C = 0\n", | ||
" w = 0\n", | ||
" weightedPairs = 0\n", | ||
" weightedConcPairs = 0\n", | ||
"\n", | ||
" print(T, E, risk)\n", | ||
" for i in range(N):\n", | ||
" if E[i] == 1:\n", | ||
" for j in range(i + 1, N):\n", | ||
" if T[i] < T[j] or (T[i] == T[j] and E[j] == 0):\n", | ||
" weightedPairs += 1\n", | ||
" if risk[i] > risk[j]:\n", | ||
" weightedConcPairs += 1\n", | ||
" elif include_ties:\n", | ||
" weightedConcPairs += 1 / 2\n", | ||
" C = weightedConcPairs / weightedPairs\n", | ||
" C = max(C, 1 - C)\n", | ||
"\n", | ||
" return {\n", | ||
" 'C': C,\n", | ||
" 'nb_pairs': 2 * weightedPairs,\n", | ||
" 'nb_concordant_pairs': 2 * weightedConcPards\n", | ||
" }\n", | ||
"\n", | ||
"\n", | ||
"def concordance_index(true_time, pred_time, event, include_ties = True, additional_results=False, **kwargs):\n", | ||
" order = np.argsort(-true_time)\n", | ||
" pred_time = pred_time[order]\n", | ||
" true_time = true_time[order]\n", | ||
" event = event[order]\n", | ||
"\n", | ||
" # Calculating th c-index\n", | ||
" results = _concordance_index(pred_time, true_time, event, include_ties)\n", | ||
"\n", | ||
" if not additional_results:\n", | ||
" return results[0]\n", | ||
" return results" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 66, | ||
"id": "47b7fb0d", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def concordance_index_self(T, P, E):\n", | ||
" \"\"\"\n", | ||
" Calculates the concordance index (C-index) for survival analysis.\n", | ||
"\n", | ||
" Args:\n", | ||
" T: Array of true event times.\n", | ||
" P: Array of predicted event times.\n", | ||
" E: Array of event indicators (1 if event occurred, 0 if censored).\n", | ||
"\n", | ||
" Returns:\n", | ||
" The concordance index.\n", | ||
" \"\"\"\n", | ||
" order = np.argsort(T)\n", | ||
" P = P[order]\n", | ||
" T = T[order]\n", | ||
" E = E[order]\n", | ||
" \n", | ||
" n = len(T)\n", | ||
" concordant_pairs = 0\n", | ||
" total_pairs = 0\n", | ||
" for i in range(n):\n", | ||
" for j in range(i + 1, n):\n", | ||
" if E[i] == 1 and T[i] <= T[j]:\n", | ||
" total_pairs += 1\n", | ||
" if P[i] < P[j]:\n", | ||
" concordant_pairs += 1\n", | ||
" elif P[i] == P[j]:\n", | ||
" concordant_pairs += 0.5\n", | ||
" if total_pairs == 0:\n", | ||
" return 0\n", | ||
" return concordant_pairs / total_pairs\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 67, | ||
"id": "0547a83f", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"a = (np.array([10, 20, 30, 40]), \n", | ||
" np.array([20, 19, 29, 39]), \n", | ||
" np.array([1, 0, 1, 0]))\n", | ||
"assert concordance_index(*a) == concordance_index_self(*a)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 71, | ||
"id": "b529d42e", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"for i in range(1000):\n", | ||
" a = np.random.rand(100)*100\n", | ||
" b = np.random.rand(100)*100\n", | ||
" e = np.round(np.random.rand(100))\n", | ||
" \n", | ||
" assert concordance_index(a, b, e) == concordance_index_self(a, b, e), f\"{a}, {b}, {e}\"" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 72, | ||
"id": "314a5f66", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"(0.5041072200605274, 0.5041072200605274)" | ||
] | ||
}, | ||
"execution_count": 72, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"concordance_index(a, b, e)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "2e6df15a", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"concordance_index_self(a, b, e)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 52, | ||
"id": "7af657dc", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"(array([43.54496772, 76.8872837 , 62.34868269, 76.74680946]),\n", | ||
" array([44.25585109, 1.58026784, 94.81912392, 9.99180181]),\n", | ||
" array([1., 1., 1., 1.]))" | ||
] | ||
}, | ||
"execution_count": 52, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"a, b, e" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "a4d58b6d", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3 (ipykernel)", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.10.4" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |