Skip to content

Commit

Permalink
Created notebook for refactoring news
Browse files Browse the repository at this point in the history
  • Loading branch information
iuliivasilev committed Aug 3, 2024
1 parent e883428 commit aafccaf
Showing 1 changed file with 310 additions and 0 deletions.
310 changes: 310 additions & 0 deletions demonstration/Articles/Refactoring.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,310 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 27,
"id": "7ca4db35",
"metadata": {},
"outputs": [],
"source": [
"from lifelines.utils import concordance_index"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "fc02d133",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from lifelines.utils.btree import _BTree\n",
"\n",
"def concordance_index(event_times, predicted_scores, event_observed=None) -> float:\n",
" if event_observed is None:\n",
" event_observed = np.ones(event_times.shape[0], dtype=float)\n",
" \n",
" num_correct, num_tied, num_pairs = _concordance_summary_statistics(event_times, predicted_scores, event_observed)\n",
"\n",
" if num_pairs == 0:\n",
" raise ZeroDivisionError(\"No admissable pairs in the dataset.\")\n",
" return (num_correct + num_tied / 2) / num_pairs\n",
"\n",
"\n",
"def _concordance_summary_statistics(event_times, predicted_event_times, event_observed):\n",
" if np.logical_not(event_observed).all():\n",
" return (0, 0, 0)\n",
"\n",
" died_mask = event_observed.astype(bool)\n",
" died_truth = event_times[died_mask]\n",
" ix = np.argsort(died_truth)\n",
" died_truth = died_truth[ix]\n",
" died_pred = predicted_event_times[died_mask][ix]\n",
"\n",
" censored_truth = event_times[~died_mask]\n",
" ix = np.argsort(censored_truth)\n",
" censored_truth = censored_truth[ix]\n",
" censored_pred = predicted_event_times[~died_mask][ix]\n",
"\n",
" censored_ix = 0\n",
" died_ix = 0\n",
" times_to_compare = _BTree(np.unique(died_pred))\n",
" print(np.unique(died_pred), times_to_compare)\n",
" num_pairs = np.int64(0)\n",
" num_correct = np.int64(0)\n",
" num_tied = np.int64(0)\n",
"\n",
" # we iterate through cases sorted by exit time:\n",
" # - First, all cases that died at time t0. We add these to the sortedlist of died times.\n",
" # - Then, all cases that were censored at time t0. We DON'T add these since they are NOT\n",
" # comparable to subsequent elements.\n",
" while True:\n",
" has_more_censored = censored_ix < len(censored_truth)\n",
" has_more_died = died_ix < len(died_truth)\n",
" # Should we look at some censored indices next, or died indices?\n",
" if has_more_censored and (not has_more_died or died_truth[died_ix] > censored_truth[censored_ix]):\n",
" pairs, correct, tied, next_ix = _handle_pairs(censored_truth, censored_pred, censored_ix, times_to_compare)\n",
" censored_ix = next_ix\n",
" elif has_more_died and (not has_more_censored or died_truth[died_ix] <= censored_truth[censored_ix]):\n",
" pairs, correct, tied, next_ix = _handle_pairs(died_truth, died_pred, died_ix, times_to_compare)\n",
" for pred in died_pred[died_ix:next_ix]:\n",
" times_to_compare.insert(pred)\n",
" died_ix = next_ix\n",
" else:\n",
" assert not (has_more_died or has_more_censored)\n",
" break\n",
"\n",
" num_pairs += pairs\n",
" num_correct += correct\n",
" num_tied += tied\n",
"\n",
" return (num_correct, num_tied, num_pairs)\n",
"\n",
"\n",
"def _handle_pairs(truth, pred, first_ix, times_to_compare):\n",
" \"\"\"\n",
" Handle all pairs that exited at the same time as truth[first_ix].\n",
"\n",
" Returns\n",
" -------\n",
" (pairs, correct, tied, next_ix)\n",
" new_pairs: The number of new comparisons performed\n",
" new_correct: The number of comparisons correctly predicted\n",
" next_ix: The next index that needs to be handled\n",
" \"\"\"\n",
" next_ix = first_ix\n",
" while next_ix < len(truth) and truth[next_ix] == truth[first_ix]:\n",
" next_ix += 1\n",
" pairs = len(times_to_compare) * (next_ix - first_ix)\n",
" correct = np.int64(0)\n",
" tied = np.int64(0)\n",
" for i in range(first_ix, next_ix):\n",
" rank, count = times_to_compare.rank(pred[i])\n",
" correct += rank\n",
" tied += count\n",
"\n",
" return (pairs, correct, tied, next_ix)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "fb905b58",
"metadata": {},
"outputs": [],
"source": [
"def _concordance_index(risk, T, E, include_ties=True):\n",
" N = len(risk)\n",
" censored_survival = []\n",
" C = 0\n",
" w = 0\n",
" weightedPairs = 0\n",
" weightedConcPairs = 0\n",
"\n",
" print(T, E, risk)\n",
" for i in range(N):\n",
" if E[i] == 1:\n",
" for j in range(i + 1, N):\n",
" if T[i] < T[j] or (T[i] == T[j] and E[j] == 0):\n",
" weightedPairs += 1\n",
" if risk[i] > risk[j]:\n",
" weightedConcPairs += 1\n",
" elif include_ties:\n",
" weightedConcPairs += 1 / 2\n",
" C = weightedConcPairs / weightedPairs\n",
" C = max(C, 1 - C)\n",
"\n",
" return {\n",
" 'C': C,\n",
" 'nb_pairs': 2 * weightedPairs,\n",
" 'nb_concordant_pairs': 2 * weightedConcPards\n",
" }\n",
"\n",
"\n",
"def concordance_index(true_time, pred_time, event, include_ties = True, additional_results=False, **kwargs):\n",
" order = np.argsort(-true_time)\n",
" pred_time = pred_time[order]\n",
" true_time = true_time[order]\n",
" event = event[order]\n",
"\n",
" # Calculating th c-index\n",
" results = _concordance_index(pred_time, true_time, event, include_ties)\n",
"\n",
" if not additional_results:\n",
" return results[0]\n",
" return results"
]
},
{
"cell_type": "code",
"execution_count": 66,
"id": "47b7fb0d",
"metadata": {},
"outputs": [],
"source": [
"def concordance_index_self(T, P, E):\n",
" \"\"\"\n",
" Calculates the concordance index (C-index) for survival analysis.\n",
"\n",
" Args:\n",
" T: Array of true event times.\n",
" P: Array of predicted event times.\n",
" E: Array of event indicators (1 if event occurred, 0 if censored).\n",
"\n",
" Returns:\n",
" The concordance index.\n",
" \"\"\"\n",
" order = np.argsort(T)\n",
" P = P[order]\n",
" T = T[order]\n",
" E = E[order]\n",
" \n",
" n = len(T)\n",
" concordant_pairs = 0\n",
" total_pairs = 0\n",
" for i in range(n):\n",
" for j in range(i + 1, n):\n",
" if E[i] == 1 and T[i] <= T[j]:\n",
" total_pairs += 1\n",
" if P[i] < P[j]:\n",
" concordant_pairs += 1\n",
" elif P[i] == P[j]:\n",
" concordant_pairs += 0.5\n",
" if total_pairs == 0:\n",
" return 0\n",
" return concordant_pairs / total_pairs\n"
]
},
{
"cell_type": "code",
"execution_count": 67,
"id": "0547a83f",
"metadata": {},
"outputs": [],
"source": [
"a = (np.array([10, 20, 30, 40]), \n",
" np.array([20, 19, 29, 39]), \n",
" np.array([1, 0, 1, 0]))\n",
"assert concordance_index(*a) == concordance_index_self(*a)"
]
},
{
"cell_type": "code",
"execution_count": 71,
"id": "b529d42e",
"metadata": {},
"outputs": [],
"source": [
"for i in range(1000):\n",
" a = np.random.rand(100)*100\n",
" b = np.random.rand(100)*100\n",
" e = np.round(np.random.rand(100))\n",
" \n",
" assert concordance_index(a, b, e) == concordance_index_self(a, b, e), f\"{a}, {b}, {e}\""
]
},
{
"cell_type": "code",
"execution_count": 72,
"id": "314a5f66",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(0.5041072200605274, 0.5041072200605274)"
]
},
"execution_count": 72,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"concordance_index(a, b, e)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2e6df15a",
"metadata": {},
"outputs": [],
"source": [
"concordance_index_self(a, b, e)"
]
},
{
"cell_type": "code",
"execution_count": 52,
"id": "7af657dc",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(array([43.54496772, 76.8872837 , 62.34868269, 76.74680946]),\n",
" array([44.25585109, 1.58026784, 94.81912392, 9.99180181]),\n",
" array([1., 1., 1., 1.]))"
]
},
"execution_count": 52,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"a, b, e"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a4d58b6d",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.4"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

0 comments on commit aafccaf

Please sign in to comment.