Skip to content

Commit

Permalink
Add ST CV
Browse files Browse the repository at this point in the history
  • Loading branch information
chenyangkang committed Sep 5, 2023
1 parent f21e3b2 commit 9c277ba
Show file tree
Hide file tree
Showing 13 changed files with 837 additions and 638 deletions.
430 changes: 430 additions & 0 deletions 01.AdaSTEM_demo.ipynb

Large diffs are not rendered by default.

5 changes: 0 additions & 5 deletions BirdSTEM/dataset/load_test_data.py
Original file line number Diff line number Diff line change
@@ -1,6 +1 @@

####
# input: speices name
# output: training set, tesing set, prediction set
####

43 changes: 42 additions & 1 deletion BirdSTEM/model/Hurdle.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def fit(self, X_train, y_train, sample_weight=None):
self.classifier.fit(new_dat[:,:-1], np.where(new_dat[:,-1]>0, 1, 0), sample_weight=sample_weight)
else:
self.classifier.fit(new_dat[:,:-1], np.where(new_dat[:,-1]>0, 1, 0))
self.regressor.fit(new_dat[new_dat[:,-1]>0,:][:,:-1], new_dat[new_dat[:,-1]>0,:][:,-1])
self.regressor.fit(new_dat[new_dat[:,-1]>0,:][:,:-1], np.array(new_dat[new_dat[:,-1]>0,:][:,-1]))

def predict(self, X_test):
cls_res = self.classifier.predict(X_test)
Expand All @@ -60,3 +60,44 @@ def predict_proba(self, X_test):
# res = cls_res * reg_res
# return res


class Hurdle_for_AdaSTEM(BaseEstimator):
def __init__(self, classifier, regressor):
'''
The input classifier should have function:
1. predict
and the regressor should have
1. predict
'''
self.classifier = classifier
self.regressor = regressor


def fit(self, X_train, y_train):
'''
y_train should be a continued feature
'''
binary_ =np.unique(np.where(y_train>0, 1, 0))
if len(binary_)==1:
warnings.warn('Warning: only one class presented. Replace with dummy classifier & regressor.')
self.classifier = dummy_model1(binary_[0])
self.regressor = dummy_model1(binary_[0])
return

X_train['y_train'] = y_train

self.classifier.fit(X_train.iloc[:,:-1], np.where(X_train.iloc[:,-1].values>0, 1, 0))
self.regressor.fit(X_train[X_train['y_train']>0].iloc[:,:-1], np.array(X_train[X_train['y_train']>0].iloc[:,-1]))

def predict(self, X_test):
cls_res = self.classifier.predict(X_test)
reg_res = self.regressor.predict(X_test)
# reg_res = np.where(reg_res>=0, reg_res, 0) ### we constrain the reg value to be positive
res = np.where(cls_res<0.5, 0, cls_res)
res = np.where(cls_res>0.5, reg_res, cls_res)
return res.reshape(-1,1)

def predict_proba(self, X_test):
return self.predict(self, X_test)
207 changes: 207 additions & 0 deletions BirdSTEM/model_selection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
from pandas.core.frame import DataFrame
from numpy import ndarray
import numpy as np
import pandas as pd
from .utils import check_random_state

def ST_train_test_split(X: DataFrame, y,
Spatio1: str = 'longitude', Spatio2: str = 'latitude', Temporal1: str = 'DOY',
Spatio_blocks_count = 10,
Temporal_blocks_count = 10,
test_size = 0.3,
random_state = None,
) -> (DataFrame, ndarray):
"""Spatial Temporal train-test split
Parameters
----------
X: DataFrame
y: DataFrame or numpy array
Spatio1: str
column name of spatial indicator 1
Spatio2: str
column name of spatial indicator 2
Temporal1: str
column name of temporal indicator 1
Spatio_blocks_count: int
How many block to split for spatio indicators
Temporal_blocks_count: int
How many block to split for temporal indicators
test_size: float
Fraction of test set in terms of blocks count
random_state: int
random state for choosing testing blocks
Returns
---------
X_train: DataFrame
X_test: DataFrame
y_train: np.darray
y_test: np.darray
"""
# random seed
rng = check_random_state(random_state)

# validate
if not isinstance(X, DataFrame):
type_x = str(type(X))
raise TypeError(f'X input should be pandas.core.frame.DataFrame, Got {type_x}')
if not (isinstance(y, DataFrame) or isinstance(y, ndarray)):
type_y = str(type(y))
raise TypeError(f'y input should be pandas.core.frame.DataFrame or numpy.ndarray, Got {type_y}')

# check shape match
y_size = np.array(y).flatten().shape[0]
X_size = X.shape[0]
if not y_size == X_size:
raise ValueError(f'The shape of X and y should match. Got X: {X_size}, y: {y_size}')

# indexing
Sindex1 = np.linspace(X[Spatio1].min(), X[Spatio1].max(), Spatio_blocks_count)
Sindex2 = np.linspace(X[Spatio2].min(), X[Spatio2].max(), Spatio_blocks_count)
Tindex1 = np.linspace(X[Temporal1].min(), X[Temporal1].max(), Temporal_blocks_count)

indexes = [str(a)+'_'+str(b)+'_'+str(c) for a,b,c in zip(
np.digitize(X[Spatio1],Sindex1),
np.digitize(X[Spatio2],Sindex2),
np.digitize(X[Temporal1],Tindex1)
)]

unique_indexes = list(np.unique(indexes))

# get test set record indexes
test_indexes = []
test_cell = list(rng.choice(unique_indexes, replace=False, size=int(len(unique_indexes)*test_size)))

for index, cell in enumerate(indexes):
if cell in test_cell:
test_indexes.append(index)

# get train set record indexes
train_indexes = list(set(range(len(indexes))) - set(test_indexes))

# get train test data
X_train = X.iloc[train_indexes, :]
y_train = np.array(y).flatten()[train_indexes].reshape(-1,1)
X_test = X.iloc[test_indexes, :]
y_test = np.array(y).flatten()[test_indexes].reshape(-1,1)

return X_train, X_test, y_train, y_test




def ST_CV(X: DataFrame, y,
Spatio1: str = 'longitude', Spatio2: str = 'latitude', Temporal1: str = 'DOY',
Spatio_blocks_count = 10,
Temporal_blocks_count = 10,
random_state = None,
CV=3,
):
"""Spatial Temporal train-test split
Parameters
----------
X: DataFrame
y: DataFrame or numpy array
Spatio1: str
column name of spatial indicator 1
Spatio2: str
column name of spatial indicator 2
Temporal1: str
column name of temporal indicator 1
Spatio_blocks_count: int
How many block to split for spatio indicators
Temporal_blocks_count: int
How many block to split for temporal indicators
test_size: float
Fraction of test set in terms of blocks count
random_state: int
random state for choosing testing blocks
CV: int
fold cross validation
Returns
---------
X_train: DataFrame
X_test: DataFrame
y_train: np.darray
y_test: np.darray
"""
# random seed
rng = check_random_state(random_state)

# validate
if not isinstance(X, DataFrame):
type_x = str(type(X))
raise TypeError(f'X input should be pandas.core.frame.DataFrame, Got {type_x}')
if not (isinstance(y, DataFrame) or isinstance(y, ndarray)):
type_y = str(type(y))
raise TypeError(f'y input should be pandas.core.frame.DataFrame or numpy.ndarray, Got {type_y}')
if not (isinstance(CV, int) and CV>0):
raise ValueError('CV should be a positive interger')

# check shape match
y_size = np.array(y).flatten().shape[0]
X_size = X.shape[0]
if not y_size == X_size:
raise ValueError(f'The shape of X and y should match. Got X: {X_size}, y: {y_size}')

# indexing
Sindex1 = np.linspace(X[Spatio1].min(), X[Spatio1].max(), Spatio_blocks_count)
Sindex2 = np.linspace(X[Spatio2].min(), X[Spatio2].max(), Spatio_blocks_count)
Tindex1 = np.linspace(X[Temporal1].min(), X[Temporal1].max(), Temporal_blocks_count)

indexes = [str(a)+'_'+str(b)+'_'+str(c) for a,b,c in zip(
np.digitize(X[Spatio1],Sindex1),
np.digitize(X[Spatio2],Sindex2),
np.digitize(X[Temporal1],Tindex1)
)]

unique_indexes = list(np.unique(indexes))
rng.shuffle(unique_indexes)
test_size = int(len(unique_indexes) * (1/CV))

for cv_count in range(CV):
# get test set record indexes
test_indexes = []
start = cv_count*test_size
end = np.min([(cv_count+1)*test_size, len(unique_indexes)+1])
test_cell = unique_indexes[start: end]

for index, cell in enumerate(indexes):
if cell in test_cell:
test_indexes.append(index)

# get train set record indexes
train_indexes = list(set(range(len(indexes))) - set(test_indexes))

# get train test data
X_train = X.iloc[train_indexes, :]
y_train = np.array(y).flatten()[train_indexes].reshape(-1,1)
X_test = X.iloc[test_indexes, :]
y_test = np.array(y).flatten()[test_indexes].reshape(-1,1)

yield X_train, X_test, y_train, y_test



2 changes: 0 additions & 2 deletions BirdSTEM/test.py

This file was deleted.

3 changes: 3 additions & 0 deletions BirdSTEM/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .validation import (
check_random_state
)
Loading

0 comments on commit 9c277ba

Please sign in to comment.