Skip to content

Commit

Permalink
Inserted new CI to core code
Browse files Browse the repository at this point in the history
  • Loading branch information
iuliivasilev committed Aug 5, 2024
1 parent b76e5e1 commit a8cfcb6
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 7 deletions.
5 changes: 2 additions & 3 deletions survivors/ensemble/base_ensemble.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import pandas as pd
import numpy as np
from lifelines.utils import concordance_index
from sklearn.metrics import roc_auc_score

from .. import metrics as metr
Expand Down Expand Up @@ -161,7 +160,7 @@ def separate_score_oob(self):
X_v = self.oob[i]
if self.ens_metric_name == "CI":
pred = self.models[i].predict(X_v, target=cnt.TIME_NAME)
score = concordance_index(X_v[cnt.TIME_NAME], pred)
score = metr.concordance_index(X_v[cnt.TIME_NAME], pred)
else:
pred = self.models[i].predict_at_times(X_v, bins=self.bins, mode="surv")
y_true = cnt.get_y(X_v[cnt.CENS_NAME], X_v[cnt.TIME_NAME])
Expand All @@ -182,7 +181,7 @@ def aggregate_score_selfoob(self, bins=None):
if self.ens_metric_name in ["CI", "roc"]:
pred = pd.concat(self.list_pred_oob, axis=1).mean(axis=1)
if self.ens_metric_name == "CI":
return concordance_index(target_time, pred)
return metr.concordance_index(target_time, pred)
return roc_auc_score(target_cens, pred)

if is_ibs:
Expand Down
8 changes: 4 additions & 4 deletions survivors/ensemble/base_ensemble_iter.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pandas as pd
import numpy as np
from lifelines.utils import concordance_index
# from lifelines.utils import concordance_index
from sklearn.metrics import roc_auc_score

from .. import metrics as metr
Expand Down Expand Up @@ -157,7 +157,7 @@ def separate_score_oob(self):
X_v = self.oob[i]
if self.ens_metric_name == "CI":
pred = self.models[i].predict(X_v, target=cnt.TIME_NAME)
score = concordance_index(X_v[cnt.TIME_NAME], pred)
score = metr.concordance_index(X_v[cnt.TIME_NAME], pred)
else:
pred = self.models[i].predict_at_times(X_v, bins=self.bins, mode="surv")
y_true = cnt.get_y(X_v[cnt.CENS_NAME], X_v[cnt.TIME_NAME])
Expand All @@ -178,7 +178,7 @@ def aggregate_score_selfoob(self, bins=None):
if self.ens_metric_name in ["CI", "roc"]:
pred = pd.concat(self.list_pred_oob, axis=1).mean(axis=1)
if self.ens_metric_name == "CI":
return concordance_index(target_time, pred)
return metr.concordance_index(target_time, pred)
return roc_auc_score(target_cens, pred)

if is_ibs:
Expand Down Expand Up @@ -246,7 +246,7 @@ def aggregate_score_selfoob(self):
target_cens = join_oob[cnt.CENS_NAME]

if self.ens_metric_name == "CI":
return concordance_index(target_time, pred)
return metr.concordance_index(target_time, pred)
elif self.ens_metric_name == "roc":
return roc_auc_score(target_cens, pred)
elif is_ibs:
Expand Down

0 comments on commit a8cfcb6

Please sign in to comment.