Skip to content

Commit

Permalink
add joblib as parallel backend
Browse files Browse the repository at this point in the history
  • Loading branch information
chenyangkang committed Jan 31, 2024
1 parent 5155e71 commit c42b0fd
Show file tree
Hide file tree
Showing 8 changed files with 508 additions and 548 deletions.
304 changes: 157 additions & 147 deletions stemflow/model/AdaSTEM.py

Large diffs are not rendered by default.

88 changes: 65 additions & 23 deletions stemflow/model/SphereAdaSTEM.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from types import MethodType
from typing import Callable, Tuple, Union

import joblib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from numpy import ndarray
Expand All @@ -16,14 +18,14 @@
from ..utils.sphere.discriminant_formula import intersect_triangle_plane

#
from ..utils.sphere_quadtree import get_ensemble_sphere_quadtree
from ..utils.sphere_quadtree import get_one_ensemble_sphere_quadtree
from ..utils.validation import (
check_base_model,
check_njobs,
check_prediciton_aggregation,
check_spatio_bin_jitter_magnitude,
check_task,
check_temporal_bin_start_jitter,
check_transform_njobs,
check_verbosity,
)
from ..utils.wrapper import model_wrapper
Expand Down Expand Up @@ -81,6 +83,7 @@ def __init__(
plot_ylims: Tuple[Union[float, int], Union[float, int]] = (-90, 90),
verbosity: int = 0,
plot_empty: bool = False,
radius: float = 6371.0,
):
"""Make a Spherical AdaSTEM object
Expand Down Expand Up @@ -152,6 +155,8 @@ def __init__(
0 to output nothing and everything otherwise.
plot_empty:
Whether to plot the empty grid
radius:
Radius of the Earth in km.
Raises:
AttributeError: Base model does not have method 'fit' or 'predict'
Expand Down Expand Up @@ -219,7 +224,11 @@ def __init__(
warnings.warn('the input Spatio1 is not "latitude"! Set to "latitude"')
self.Spatio2 = "latitude"

def split(self, X_train: pd.core.frame.DataFrame, verbosity: Union[None, int] = None, ax=None) -> dict:
self.radius = radius

def split(
self, X_train: pd.core.frame.DataFrame, verbosity: Union[None, int] = None, ax=None, njobs: int = 1
) -> dict:
"""QuadTree indexing the input data
Args:
Expand All @@ -230,10 +239,8 @@ def split(self, X_train: pd.core.frame.DataFrame, verbosity: Union[None, int] =
Returns:
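None. The gridding information of all ensembles is concatenated into one DataFrame and stored as self.ensemble_df; self.gridding_plot is set to `ax` when save_gridding_plot is enabled, otherwise to None.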
"""
if verbosity is None:
verbosity = self.verbosity

fold = self.ensemble_fold
verbosity = check_verbosity(self, verbosity)
njobs = check_transform_njobs(self, njobs)
save_path = os.path.join(self.save_dir, "ensemble_quadtree_df.csv") if self.save_tmp else ""

if "grid_len" not in self.__dir__():
Expand All @@ -243,30 +250,61 @@ def split(self, X_train: pd.core.frame.DataFrame, verbosity: Union[None, int] =
# We are using STEM
pass

self.ensemble_df, self.gridding_plot = get_ensemble_sphere_quadtree(
X_train[[self.Spatio1, self.Spatio2, self.Temporal1]],
Temporal1=self.Temporal1,
size=fold,
grid_len_upper_threshold=self.grid_len_upper_threshold,
grid_len_lower_threshold=self.grid_len_lower_threshold,
points_lower_threshold=self.points_lower_threshold,
partial_get_one_ensemble_sphere_quadtree = partial(
get_one_ensemble_sphere_quadtree,
data=X_train,
spatio_bin_jitter_magnitude=self.spatio_bin_jitter_magnitude,
temporal_start=self.temporal_start,
temporal_end=self.temporal_end,
temporal_step=self.temporal_step,
temporal_bin_interval=self.temporal_bin_interval,
temporal_bin_start_jitter=self.temporal_bin_start_jitter,
spatio_bin_jitter_magnitude=self.spatio_bin_jitter_magnitude,
save_gridding_plotly=self.save_gridding_plot, # currently only allow output plotly
Temporal1=self.Temporal1,
radius=self.radius,
grid_len_upper_threshold=self.grid_len_upper_threshold,
grid_len_lower_threshold=self.grid_len_lower_threshold,
points_lower_threshold=self.points_lower_threshold,
plot_empty=self.plot_empty,
save_gridding_plot=False,
njobs=self.njobs,
verbosity=verbosity,
plot_xlims=self.plot_xlims,
plot_ylims=self.plot_ylims,
save_path=save_path,
save_gridding_plotly=self.save_gridding_plot,
ax=ax,
plot_empty=self.plot_empty,
)

if njobs > 1 and isinstance(njobs, int):
parallel = joblib.Parallel(n_jobs=njobs, return_as="generator")
output_generator = parallel(
joblib.delayed(partial_get_one_ensemble_sphere_quadtree)(i) for i in list(range(self.ensemble_fold))
)
if verbosity > 0:
output_generator = tqdm(output_generator, total=self.ensemble_fold, desc="Generating Ensemble: ")

ensemble_all_df_list = [i for i in output_generator]

else:
iter_func_ = (
tqdm(range(self.ensemble_fold), total=self.ensemble_fold, desc="Generating Ensemble: ")
if verbosity > 0
else range(self.ensemble_fold)
)
ensemble_all_df_list = [
partial_get_one_ensemble_sphere_quadtree(ensemble_count) for ensemble_count in iter_func_
]

ensemble_df = pd.concat(ensemble_all_df_list).reset_index(drop=True)
ensemble_df = ensemble_df.reset_index(drop=True)

del ensemble_all_df_list

if not save_path == "":
ensemble_df.to_csv(save_path, index=False)
print(f"Saved! {save_path}")

if self.save_gridding_plot:
self.ensemble_df, self.gridding_plot = ensemble_df, ax

else:
self.ensemble_df, self.gridding_plot = ensemble_df, None

def SAC_ensemble_training(self, index_df: pd.core.frame.DataFrame, data: pd.core.frame.DataFrame):
"""A sub-module of SAC training function.
Train only one ensemble.
Expand All @@ -279,6 +317,7 @@ def SAC_ensemble_training(self, index_df: pd.core.frame.DataFrame, data: pd.core
# Calculate the start indices for the sliding window
unique_start_indices = np.sort(index_df[f"{self.Temporal1}_start"].unique())
# training, window by window
res_list = []
for start in unique_start_indices:
window_data_df = data[
(data[self.Temporal1] >= start) & (data[self.Temporal1] < start + self.temporal_bin_interval)
Expand Down Expand Up @@ -326,12 +365,15 @@ def find_belonged_points(df, df_a):
continue

# train
(
res = (
query_results.reset_index(drop=False, level=[0, 1])
.dropna(subset="unique_stixel_id")
.groupby("unique_stixel_id")
.apply(lambda stixel: self.stixel_fitting(stixel))
)
res_list.append(list(res))

return res_list

def SAC_ensemble_predict(
self, index_df: pd.core.frame.DataFrame, data: pd.core.frame.DataFrame
Expand Down
4 changes: 2 additions & 2 deletions stemflow/model/static_func_AdaSTEM.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,8 @@ def train_one_stixel(


def assign_points_to_one_ensemble(
ensemble_df: pd.core.frame.DataFrame,
ensemble: str,
ensemble_df: pd.core.frame.DataFrame,
Sample_ST_df: pd.core.frame.DataFrame,
Temporal1: str,
Spatio1: str,
Expand Down Expand Up @@ -191,8 +191,8 @@ def assign_points_to_one_ensemble(


def assign_points_to_one_ensemble_sphere(
ensemble_df: pd.core.frame.DataFrame,
ensemble: str,
ensemble_df: pd.core.frame.DataFrame,
Sample_ST_df: pd.core.frame.DataFrame,
Temporal1: str,
Spatio1: str,
Expand Down
Loading

0 comments on commit c42b0fd

Please sign in to comment.