Skip to content

Commit

Permalink
Add feature importance utils
Browse files Browse the repository at this point in the history
  • Loading branch information
chenyangkang committed Sep 11, 2023
1 parent 2b6663d commit c09fefb
Show file tree
Hide file tree
Showing 9 changed files with 888 additions and 84 deletions.
687 changes: 642 additions & 45 deletions docs/Examples/01.AdaSTEM_demo.ipynb

Large diffs are not rendered by default.

18 changes: 15 additions & 3 deletions docs/Examples/04.Prediction_visualization.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@
},
{
"cell_type": "code",
"execution_count": 38,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -198,9 +198,21 @@
},
{
"cell_type": "code",
"execution_count": 39,
"execution_count": 2,
"metadata": {},
"outputs": [],
"outputs": [
{
"ename": "NameError",
"evalue": "name 'model' is not defined",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[2], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m pred \u001b[39m=\u001b[39m model\u001b[39m.\u001b[39mpredict(X_test)\n",
"\u001b[0;31mNameError\u001b[0m: name 'model' is not defined"
]
}
],
"source": [
"pred = model.predict(X_test)\n"
]
Expand Down
Binary file added docs/FTR_IPT_slope_mean.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
10 changes: 6 additions & 4 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ theme:
- content.code.copy
- search.share
- navigation.indexes
- navigation.expand


language: en
Expand Down Expand Up @@ -53,10 +54,11 @@ plugins:
handlers:
python:
options:
# inherited_members: true
docstring_style: google
docstring_section_style: list
show_object_full_path: true
show_root_full_path: true
# show_object_full_path: true
# show_root_full_path: true

extra:
social:
Expand All @@ -78,16 +80,16 @@ nav:
- API:
- stemflow.model:
- 'AdaSTEM': API/stemflow.model.AdaSTEM.md
- 'dummy_model': API/stemflow.model.dummy_model.md
- 'Hurdle': API/stemflow.model.Hurdle.md
- 'dummy_model': API/stemflow.model.dummy_model.md
- stemflow.model_selection: API/stemflow.model_selection.md
- stemflow.utils:
- 'quadtree': API/stemflow.utils.quadtree.md
- 'plot_gif': API/stemflow.utils.plot_gif.md
- 'generate_soft_colors': API/stemflow.utils.generate_soft_colors.md


markdown_extensions:
- tables
- toc:
toc_depth : "1-1"
- pymdownx.highlight:
Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
with codecs.open(os.path.join(here, "README.md"), encoding="utf-8") as fh:
long_description = "\n" + fh.read()

VERSION = '0.0.3'
VERSION = '0.0.4'
DESCRIPTION = 'A package for Adaptive Spatio-Temporal Model (AdaSTEM) in python'
LONG_DESCRIPTION = 'stemflow is a toolkit for Adaptive Spatio-Temporal Model (AdaSTEM) in python. A typical usage is daily abundance estimation using eBird citizen science data. It leverages the "adjacency" information of surrounding target values in space and time, to predict the classes/continues values of target spatial-temporal point. In the demo, we use a two-step hurdle model as "base model", with XGBoostClassifier for occurence modeling and XGBoostRegressor for abundance modeling.'

Expand All @@ -32,7 +32,7 @@
keywords=['python', 'spatial-temporal model', 'ebird', 'citizen science', 'spatial temporal exploratory model',
'STEM','AdaSTEM','abundance','phenology'],
classifiers=[
"Development Status :: 1 - Planning",
"Development Status :: 5 - Production/Stable",
"Programming Language :: Python :: 3",
"Operating System :: Unix",
"Operating System :: MacOS :: MacOS X",
Expand Down
218 changes: 202 additions & 16 deletions stemflow/model/AdaSTEM.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,10 @@ def __init__(self,
Ensemble plot.
model_dict (dict):
Dictionary of {stixel_index: trained_model}.
grid_dict (dict):
An array of stixels assigned to each ensemble.
feature_importances_ (pd.core.frame.DataFrame):
Feature importance dataframe for each stixel.
"""
# save base model
Expand Down Expand Up @@ -361,7 +365,7 @@ def fit(self,
self.stixel_specific_x_names[name] = [i for i in self.stixel_specific_x_names[name] if not i in \
list(sub_X_train.columns[sub_X_train.std(axis=0)==0])]

# continue, if not variable left
# continue, if no variable left
if len(self.stixel_specific_x_names[name])==0:
continue

Expand All @@ -386,6 +390,9 @@ def fit(self,
except Exception as e:
warnings.warn(e)
continue

# Finally, calculate feature importance
self.calculate_feature_importances()


def predict_proba(self,
Expand Down Expand Up @@ -426,9 +433,8 @@ def predict_proba(self,
X_test_copy = X_test.copy()

round_res_list = []
ensemble_df = self.ensemble_df
for ensemble in list(ensemble_df.ensemble_index.unique()):
this_ensemble = ensemble_df[ensemble_df.ensemble_index==ensemble]
for ensemble in list(self.ensemble_df.ensemble_index.unique()):
this_ensemble = self.ensemble_df[self.ensemble_df.ensemble_index==ensemble]
this_ensemble['stixel_calibration_point_transformed_left_bound'] = \
[i[0] for i in this_ensemble['stixel_calibration_point(transformed)']]

Expand All @@ -452,11 +458,11 @@ def predict_proba(self,
for index,line in iter_func:
grid_index = line['unique_stixel_id']
sub_X_test = X_test_copy[
(X_test_copy.DOY>=line['DOY_start']) & (X_test_copy.DOY<=line['DOY_end']) & \
(X_test_copy.lon_new>=line['stixel_calibration_point_transformed_left_bound']) &\
(X_test_copy.lon_new<=line['stixel_calibration_point_transformed_right_bound']) &\
(X_test_copy.lat_new>=line['stixel_calibration_point_transformed_lower_bound']) &\
(X_test_copy.lat_new<=line['stixel_calibration_point_transformed_upper_bound'])
(X_test_copy[self.Temporal1]>=line[f'{self.Temporal1}_start']) & (X_test_copy[self.Temporal1]<=line[f'{self.Temporal1}_end']) & \
(X_test_copy[f'{self.Spatio1}_new']>=line['stixel_calibration_point_transformed_left_bound']) &\
(X_test_copy[f'{self.Spatio1}_new']<=line['stixel_calibration_point_transformed_right_bound']) &\
(X_test_copy[f'{self.Spatio2}_new']>=line['stixel_calibration_point_transformed_lower_bound']) &\
(X_test_copy[f'{self.Spatio2}_new']<=line['stixel_calibration_point_transformed_upper_bound'])
]

if len(sub_X_test)==0:
Expand All @@ -471,7 +477,10 @@ def predict_proba(self,

try:
model = self.model_dict[f'{ensemble}_{grid_index}_model']
stixel_specific_x_names = self.stixel_specific_x_names[f'{ensemble}_{grid_index}']
if isinstance(model, dummy_model1):
stixel_specific_x_names = self.x_names
else:
stixel_specific_x_names = self.stixel_specific_x_names[f'{ensemble}_{grid_index}']

if self.task=='regression':
pred = model.predict(np.array(sub_X_test[stixel_specific_x_names]))
Expand Down Expand Up @@ -583,8 +592,8 @@ def transform_pred_set_to_STEM_quad(self,
"""

x_array = X_train['longitude']
y_array = X_train['latitude']
x_array = X_train[self.Spatio1]
y_array = X_train[self.Spatio2]
coord = np.array([x_array, y_array]).T
angle = float(ensemble_info.iloc[0,:]['rotation'])
r = angle/360
Expand All @@ -603,8 +612,8 @@ def transform_pred_set_to_STEM_quad(self,
long_new = (coord[:,0] + calibration_point_x_jitter).tolist()
lat_new = (coord[:,1] + calibration_point_y_jitter).tolist()

X_train['lon_new'] = long_new
X_train['lat_new'] = lat_new
X_train[f'{self.Spatio1}_new'] = long_new
X_train[f'{self.Spatio2}_new'] = lat_new

return X_train

Expand Down Expand Up @@ -734,9 +743,186 @@ def score(self,
self.score_dict = score_dict
return self.score_dict

def calclate_feature_importance():
def calculate_feature_importances(self):
    """Compute feature importance values for each trained stixel.

    The result is stored in ``self.feature_importances_`` as a DataFrame with
    one row per stixel: a ``stixel_index`` column followed by one column per
    feature. Features a stixel did not use are filled with 0.

    Stixels whose ``stixel_checklist_count`` is below
    ``self.stixel_training_size_threshold`` are skipped. Stixels whose model or
    importances cannot be retrieved are skipped with a warning.

    If no stixel yields importances, ``self.feature_importances_`` is set to an
    empty DataFrame with the columns ``['stixel_index', *self.x_names]``.

    Attribute dependence:
        1. self.ensemble_df
        2. self.model_dict
        3. self.stixel_specific_x_names
        4. self.x_names
        5. The input base model should have attribute `feature_importances_`
    """
    feature_importance_list = []

    # 'checklist_indexes' is not read below, so it is dropped before iterating.
    stixel_rows = self.ensemble_df.drop('checklist_indexes', axis=1)

    for index, ensemble_row in stixel_rows.iterrows():
        if ensemble_row['stixel_checklist_count'] < self.stixel_training_size_threshold:
            continue

        try:
            ensemble_index = ensemble_row['ensemble_index']
            stixel_index = ensemble_row['unique_stixel_id']
            stixel_key = f'{ensemble_index}_{stixel_index}'
            the_model = self.model_dict[f'{stixel_key}_model']

            if isinstance(the_model, dummy_model1):
                # A dummy model has no learned importances: spread the weight
                # uniformly over all predictors. The stixel-specific name list
                # is deliberately not looked up here, so a missing entry
                # cannot cause this stixel to be skipped.
                uniform_weight = 1 / len(self.x_names)
                importance_dict = {x_name: uniform_weight for x_name in self.x_names}
            else:
                x_names = self.stixel_specific_x_names[stixel_key]
                importance_dict = dict(zip(x_names, the_model.feature_importances_))

        except Exception as e:
            # Best effort: one unusable stixel must not abort the whole pass.
            warnings.warn(f'Feature importance not calculated for ensemble_df row {index}: {e!r}')
            continue

        importance_dict['stixel_index'] = stixel_index
        feature_importance_list.append(importance_dict)

    if not feature_importance_list:
        # Without this guard, set_index('stixel_index') raises KeyError on an
        # empty frame.
        self.feature_importances_ = pd.DataFrame(columns=['stixel_index', *self.x_names])
        return

    self.feature_importances_ = (
        pd.DataFrame(feature_importance_list)
        .set_index('stixel_index')
        .reset_index(drop=False)
        .fillna(0)
    )

def assign_feature_importances_by_points(self,
                                         Sample_ST_df: Union[pd.core.frame.DataFrame, None] = None,
                                         verbosity: int = 0,
                                         aggregation: str = 'mean'
                                         ) -> pd.core.frame.DataFrame:
    """Assign feature importance to the input spatio-temporal points.

    Each point is matched to the stixels that contain it in every ensemble.
    Importances are averaged over the stixels within one ensemble and then
    aggregated across ensembles. Points covered by fewer than
    ``self.min_ensemble_required`` ensembles are dropped from the output.

    Args:
        Sample_ST_df (Union[pd.core.frame.DataFrame, None], optional):
            Dataframe that indicate the spatio-temporal points of interest.
            Must contain `self.Spatio1`, `self.Spatio2`, and `self.Temporal1` in columns.
            If None, the resolution will be:

            | variable|values|
            |---------|--------|
            |Spatio_var1|np.arange(-180,180,1)|
            |Spatio_var2|np.arange(-90,90,1)|
            |Temporal_var1|np.arange(1,366,7)|

            Defaults to None.
        verbosity (int, optional):
            Whether to show progressbar during assigning. 0 for No, otherwise Yes. Defaults to 0.
        aggregation (str, optional):
            One of 'mean' and 'median' to aggregate feature importance across ensembles.

    Raises:
        NameError:
            feature_importances_ attribute is not calculated. Try model.calculate_feature_importances() first.
        ValueError:
            aggregation not one of ['mean','median'].
        KeyError:
            One of [`self.Spatio1`, `self.Spatio2`, `self.Temporal1`] not found in `Sample_ST_df.columns`

    Returns:
        DataFrame of the sample points with feature importance assigned. It is
        empty (but keeps the importance columns) when no point could be
        assigned.
    """
    if not hasattr(self, 'feature_importances_'):
        raise NameError('feature_importances_ attribute is not calculated. Try model.calculate_feature_importances() first.')

    if aggregation not in ['mean', 'median']:
        raise ValueError("aggregation not one of ['mean','median'].")

    if Sample_ST_df is not None:
        for var_name in [self.Spatio1, self.Spatio2, self.Temporal1]:
            if var_name not in Sample_ST_df.columns:
                raise KeyError(f'{var_name} not found in Sample_ST_df.columns')
    else:
        # Default grid: 1-unit spatial steps, 7-unit temporal steps.
        Spatio_var1 = np.arange(-180, 180, 1)
        Spatio_var2 = np.arange(-90, 90, 1)
        Temporal_var1 = np.arange(1, 366, 7)
        new_Spatio_var1, new_Spatio_var2, new_Temporal_var1 = np.meshgrid(
            Spatio_var1, Spatio_var2, Temporal_var1
        )

        Sample_ST_df = pd.DataFrame({
            self.Temporal1: new_Temporal_var1.flatten(),
            self.Spatio1: new_Spatio_var1.flatten(),
            self.Spatio2: new_Spatio_var2.flatten()
        })

    # Index importances by stixel once instead of re-scanning the frame for
    # every stixel. keep='first' mirrors taking the first matching row.
    importance_lookup = (
        self.feature_importances_
        .drop_duplicates(subset='stixel_index', keep='first')
        .set_index('stixel_index')
    )
    importance_columns = list(importance_lookup.columns)

    left_bound = 'stixel_calibration_point_transformed_left_bound'
    lower_bound = 'stixel_calibration_point_transformed_lower_bound'
    right_bound = 'stixel_calibration_point_transformed_right_bound'
    upper_bound = 'stixel_calibration_point_transformed_upper_bound'

    # assign input spatio-temporal points to stixels
    round_res_list = []
    for ensemble in self.ensemble_df.ensemble_index.unique():
        # Explicit copy: columns are added below and must not be written onto
        # a slice of self.ensemble_df.
        this_ensemble = self.ensemble_df[self.ensemble_df.ensemble_index == ensemble].copy()
        calibration_points = this_ensemble['stixel_calibration_point(transformed)']
        this_ensemble[left_bound] = [point[0] for point in calibration_points]
        this_ensemble[lower_bound] = [point[1] for point in calibration_points]
        this_ensemble[right_bound] = this_ensemble[left_bound] + this_ensemble['stixel_width']
        this_ensemble[upper_bound] = this_ensemble[lower_bound] + this_ensemble['stixel_height']

        Sample_ST_df = self.transform_pred_set_to_STEM_quad(Sample_ST_df.reset_index(drop=True),
                                                            this_ensemble)

        temporal_values = Sample_ST_df[self.Temporal1]
        spatio1_values = Sample_ST_df[f'{self.Spatio1}_new']
        spatio2_values = Sample_ST_df[f'{self.Spatio2}_new']

        res_list = []
        iter_func = this_ensemble.iterrows() if verbosity == 0 else tqdm(this_ensemble.iterrows(),
                                                                         total=len(this_ensemble),
                                                                         desc=f'Processing {ensemble} ')

        for _, line in iter_func:
            stixel_index = line['unique_stixel_id']
            if stixel_index not in importance_lookup.index:
                # No importances recorded for this stixel.
                continue

            in_stixel = (
                (temporal_values >= line[f'{self.Temporal1}_start'])
                & (temporal_values <= line[f'{self.Temporal1}_end'])
                & (spatio1_values >= line[left_bound])
                & (spatio1_values <= line[right_bound])
                & (spatio2_values >= line[lower_bound])
                & (spatio2_values <= line[upper_bound])
            )
            sample_index = list(Sample_ST_df.index[in_stixel.to_numpy()])
            if len(sample_index) == 0:
                continue

            # Scalar importances are broadcast over every matched sample point.
            res_list.append(pd.DataFrame({
                'sample_index': sample_index,
                **importance_lookup.loc[stixel_index].to_dict()
            }))

        if not res_list:
            # pd.concat([]) raises ValueError; this ensemble contributes nothing.
            continue

        # A point may fall in several stixels of one ensemble; average them.
        ensemble_res = pd.concat(res_list, axis=0).groupby('sample_index').mean().reset_index(drop=False)
        round_res_list.append(ensemble_res)

    if round_res_list:
        round_res_df = pd.concat(round_res_list, axis=0)

        # One row per (ensemble, sample point), so the group size is the
        # number of ensembles available for that point.
        ensemble_available_count = round_res_df.groupby('sample_index').size()

        # Only points with at least self.min_ensemble_required ensembles available are used
        usable_sample = ensemble_available_count[ensemble_available_count >= self.min_ensemble_required]
        round_res_df = round_res_df[round_res_df['sample_index'].isin(list(usable_sample.index))]

        # aggregate across ensembles
        if aggregation == 'mean':
            feature_importances_across_ensembles = round_res_df.groupby('sample_index').mean()
        else:
            feature_importances_across_ensembles = round_res_df.groupby('sample_index').median()
    else:
        # Nothing could be assigned: keep the importance columns so the
        # (empty) result has the usual shape.
        feature_importances_across_ensembles = pd.DataFrame(columns=importance_columns, dtype=float)

    if self.use_temporal_to_train:
        # Avoid a name clash with the temporal coordinate column of Sample_ST_df.
        feature_importances_across_ensembles = feature_importances_across_ensembles.rename(
            columns={self.Temporal1: f'{self.Temporal1}_predictor'}
        )

    out_ = pd.concat([Sample_ST_df, feature_importances_across_ensembles], axis=1).dropna()
    return out_



raise NotImplementedError()



Expand Down
Binary file modified stemflow/model/__pycache__/AdaSTEM.cpython-39.pyc
Binary file not shown.
Binary file modified stemflow/utils/__pycache__/plot_gif.cpython-39.pyc
Binary file not shown.
Loading

0 comments on commit c09fefb

Please sign in to comment.