Skip to content

Commit

Permalink
Merge pull request #38 from derrynknife/random-survival-forest
Browse files Browse the repository at this point in the history
Random survival forest
  • Loading branch information
derrynknife authored Aug 1, 2023
2 parents 0e484ac + aaca459 commit f3ed2be
Show file tree
Hide file tree
Showing 55 changed files with 31,083 additions and 722 deletions.
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,20 @@ repos:

# Organise imports
- repo: https://github.com/PyCQA/isort
rev: '5.11.4'
rev: '5.12.0'
hooks:
- id: isort
args: ["-l=79"]

# Update all old python syntax
- repo: https://github.com/asottile/pyupgrade
rev: 'v3.1.0'
rev: 'v3.3.1'
hooks:
- id: pyupgrade

# Format with Black
- repo: https://github.com/psf/black
rev: '22.10.0'
rev: '23.1.0'
hooks:
- id: black

Expand Down
5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,10 @@ module = [
'lifelines',
'reliability',
'sphinx_rtd_theme',
'setuptools'
'setuptools',
'sklearn.*',
'sksurv.*',
'joblib.*'
]
ignore_missing_imports = true

Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@
lifelines==0.27.4
numba==0.56.4
numpy-indexed==0.3.5
reliability==0.8.6
reliability==0.8.6
matplotlib==3.6
3 changes: 2 additions & 1 deletion requirements_dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@ flake8-pyproject
mypy
pytest
coverage
pre-commit
pre-commit
scikit-survival
10 changes: 9 additions & 1 deletion surpyval/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,25 @@
Gamma,
Gauss,
Gumbel,
InstantlyOccurs,
Logistic,
LogLogistic,
LogNormal,
MixtureModel,
NeverOccurs,
Normal,
Parametric,
Rayleigh,
Uniform,
Weibull,
)
from surpyval.regression import CoxPH, ExponentialPH, WeibullPH
from surpyval.regression import (
CoxPH,
ExponentialPH,
RandomSurvivalForest,
SurvivalTree,
WeibullPH,
)
from surpyval.utils import (
fs_to_xcn,
fs_to_xrd,
Expand Down
1 change: 0 additions & 1 deletion surpyval/competing_risks/competing_risks.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@ def cif(self, x, event):
def fit_from_df(
cls, df, x_col, e_col, c_col=None, n_col=None, method="Nelson-Aalen"
):

x, c, n, e = validate_cr_df_inputs(df, x_col, e_col, c_col, n_col)
model = cls.fit(x, e, c, n, method)
model.df = df
Expand Down
1 change: 0 additions & 1 deletion surpyval/competing_risks/fine_gray.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,6 @@ class FineGray_:
# return log_like, jac, hess

def fit(self, x, Z, e, c=None, n=None):

x, Z, e, c, n = validate_fine_gray_inputs(x, Z, e, c, n)

unique_e = list(set(e))
Expand Down
135 changes: 107 additions & 28 deletions surpyval/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
import importlib.resources

import numpy as np
import pandas as pd
from pkg_resources import resource_filename # type: ignore

from surpyval.utils.recurrent_utils import handle_xicn

data_module = importlib.import_module("surpyval.datasets")


class BoforsSteel_:
Expand Down Expand Up @@ -50,10 +56,10 @@ class BoforsSteel_:
"""

def __init__(self):
self.data = pd.read_csv(
resource_filename("surpyval", "datasets/bofors_steel.csv"),
engine="python",
)
with importlib.resources.path(
data_module, "bofors_steel.csv"
) as data_path:
self.data = pd.read_csv(data_path, engine="python")

def __repr__(self):
return """
Expand All @@ -67,10 +73,8 @@ def __repr__(self):

class Boston_:
    """Container for the Boston data set shipped with the package.

    On construction, ``boston.csv`` is read from the ``surpyval.datasets``
    package into a pandas ``DataFrame`` stored on ``self.data``.
    """

    def __init__(self):
        # Only the importlib.resources lookup is kept: the older
        # pkg_resources.resource_filename read loaded the same CSV a second
        # time and relied on an import this module no longer carries.
        with importlib.resources.path(data_module, "boston.csv") as data_path:
            self.data = pd.read_csv(data_path, engine="python")


class Bearing_:
Expand Down Expand Up @@ -114,10 +118,8 @@ def __repr__(self):

class Heart_:
def __init__(self):
self.data = pd.read_csv(
resource_filename("surpyval", "datasets/heart.csv"),
engine="python",
)
with importlib.resources.path(data_module, "heart.csv") as data_path:
self.data = pd.read_csv(data_path, engine="python")

def __repr__(self):
return """
Expand All @@ -128,9 +130,8 @@ def __repr__(self):

class Lung_:
def __init__(self):
self.data = pd.read_csv(
resource_filename("surpyval", "datasets/lung.csv"), engine="python"
)
with importlib.resources.path(data_module, "lung.csv") as data_path:
self.data = pd.read_csv(data_path, engine="python")

def __repr__(self):
return """
Expand All @@ -141,14 +142,12 @@ def __repr__(self):

class Rossi_:
def __init__(self):
self.data = pd.read_csv(
resource_filename("surpyval", "datasets/rossi.csv"),
engine="python",
)
self.time_varying_data = pd.read_csv(
resource_filename("surpyval", "datasets/rossi_tv.csv"),
engine="python",
)
with importlib.resources.path(data_module, "rossi.csv") as data_path:
self.data = pd.read_csv(data_path, engine="python")
with importlib.resources.path(
data_module, "rossi_tv.csv"
) as data_path:
self.time_varying_data = pd.read_csv(data_path, engine="python")

def __repr__(self):
return """
Expand All @@ -159,10 +158,8 @@ def __repr__(self):

class Tires_:
def __init__(self):
self.data = pd.read_csv(
resource_filename("surpyval", "datasets/tires.csv"),
engine="python",
)
with importlib.resources.path(data_module, "tires.csv") as data_path:
self.data = pd.read_csv(data_path, engine="python")

def __repr__(self):
return """
Expand All @@ -176,10 +173,92 @@ def __repr__(self):
"""


class RecurrentDataExample1_:
    """Recurrent-event example data assembled from hard-coded values.

    Thirty-three observations spread over six items are handed to
    ``handle_xicn`` (with ``as_recurrent_data=True``) and the result is
    stored on ``self.data``.
    """

    def __init__(self):
        # How many observations belong to items 1..6, in that order.
        obs_per_item = (9, 7, 9, 3, 3, 2)

        # Observed values, grouped one item per row group.
        # fmt: off
        observed = np.array(
            [
                2227.08, 2733.229, 3524.214, 5568.634, 5886.165,
                5946.301, 6018.219, 7202.724, 8760,
                772.9542, 1034.458, 3011.114, 3121.458, 3624.158,
                3758.296, 5000,
                900.9855, 1289.95, 2689.878, 3928.824, 4328.317,
                4704.24, 5052.586, 5473.171, 6200,
                411.407, 1122.74, 1300,
                688.897, 915.101, 2650,
                105.824, 500,
            ]
        )
        # fmt: on

        # Item label (1-based) repeated once per observation of that item.
        item_ids = [
            label
            for label, count in enumerate(obs_per_item, start=1)
            for _ in range(count)
        ]

        # Flag vector: 0 for every observation of an item except its last,
        # which is flagged 1.
        flags = [
            flag
            for count in obs_per_item
            for flag in [0] * (count - 1) + [1]
        ]

        # One count per observation, same shape and dtype as the values.
        counts = np.ones_like(observed)

        self.data = handle_xicn(
            observed, item_ids, flags, counts, as_recurrent_data=True
        )

    def __repr__(self):
        citation = """
Data from:
"Modeling and Analysis of Repairable Systems with General Repair."
Mettas and Zhao (2005).
"""
        return citation.strip()


class RecurrentDataExample2_:
    """Recurrent-event example data built from twelve hard-coded values.

    The running total of the listed values is passed to ``handle_xicn``
    (with ``as_recurrent_data=True``); the result is kept on ``self.data``.
    """

    def __init__(self):
        increments = [3, 6, 11, 5, 16, 9, 19, 22, 37, 23, 31, 45]
        # handle_xicn receives the cumulative sums, not the raw values.
        running_total = np.cumsum(increments)
        self.data = handle_xicn(running_total, as_recurrent_data=True)

    def __repr__(self):
        citation = """
Data from:
"G1-Renewal Process as Repairable System Model."
Kaminskiy and Krivtsov (2010).
"""
        return citation.strip()


# Public data set objects. Each name below is a single shared instance of the
# corresponding ``*_`` class, created once when this module is imported.
BoforsSteel = BoforsSteel_()
Bearing = Bearing_()
Boston = Boston_()
Heart = Heart_()
Lung = Lung_()
Rossi = Rossi_()
Tires = Tires_()
RecurrentDataExample1 = RecurrentDataExample1_()
RecurrentDataExample2 = RecurrentDataExample2_()
Loading

0 comments on commit f3ed2be

Please sign in to comment.