diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6c955b6..593bf7b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -3,25 +3,25 @@ repos: - repo: https://github.com/ambv/black - rev: 23.1.0 + rev: 23.12.1 hooks: - id: black args: ["stemflow", "--line-length=120", "--target-version=py37"] - repo: https://github.com/pycqa/flake8 - rev: 6.0.0 + rev: 7.0.0 hooks: - id: flake8 args: ["--select=C,E,F,W,B,B950", "--max-line-length=120", "--ignore=E203,E501,W503,F401,F403"] - repo: https://github.com/timothycrosley/isort - rev: 5.12.0 + rev: 5.13.2 hooks: - id: isort args: ["-l 120", "--profile", "black", "."] - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.4.0 + rev: v4.5.0 hooks: - id: check-yaml exclude: recipe/meta.yaml diff --git a/stemflow/gridding/QTree.py b/stemflow/gridding/QTree.py index 41c37da..c7ac814 100644 --- a/stemflow/gridding/QTree.py +++ b/stemflow/gridding/QTree.py @@ -12,8 +12,9 @@ import pandas as pd from ..utils.generate_soft_colors import generate_soft_color +from ..utils.jitterrotation.jitterrotator import JitterRotator from ..utils.validation import check_random_state -from .Q_blocks import Node, Point +from .Q_blocks import QNode, QPoint os.environ["MKL_NUM_THREADS"] = "1" os.environ["NUMEXPR_NUM_THREADS"] = "1" @@ -23,7 +24,7 @@ def recursive_subdivide( - node: Node, + node: QNode, grid_len_lon_upper_threshold: Union[float, int], grid_len_lon_lower_threshold: Union[float, int], grid_len_lat_upper_threshold: Union[float, int], @@ -53,7 +54,7 @@ def recursive_subdivide( h_ = float(node.height / 2) p = contains(node.x0, node.y0, w_, h_, node.points) - x1 = Node(node.x0, node.y0, w_, h_, p) + x1 = QNode(node.x0, node.y0, w_, h_, p) recursive_subdivide( x1, grid_len_lon_upper_threshold, @@ -64,7 +65,7 @@ def recursive_subdivide( ) p = contains(node.x0, node.y0 + h_, w_, h_, node.points) - x2 = Node(node.x0, node.y0 + h_, w_, h_, p) + x2 = QNode(node.x0, node.y0 + h_, w_, h_, p) recursive_subdivide( x2, grid_len_lon_upper_threshold, @@ -75,7 +76,7 @@ def recursive_subdivide( ) p = contains(node.x0 + w_, node.y0, w_, h_, node.points) - x3 = Node(node.x0 + w_, node.y0, w_, h_, p) + x3 = QNode(node.x0 + w_, node.y0, w_, h_, p) recursive_subdivide( x3, grid_len_lon_upper_threshold, @@ -86,7 +87,7 @@ def recursive_subdivide( ) p = contains(node.x0 + w_, node.y0 + h_, w_, h_, node.points) - x4 = Node(node.x0 + w_, node.y0 + h_, w_, h_, p) + x4 = QNode(node.x0 + w_, node.y0 + h_, w_, h_, p) recursive_subdivide( x4, grid_len_lon_upper_threshold, @@ -204,17 +205,12 @@ def add_lon_lat_data(self, indexes: Sequence, x_array: Sequence, y_array: Sequen if not len(x_array) == len(y_array) or not len(x_array) == len(indexes): raise ValueError("input longitude and latitude and indexes not in same length!") - data = np.array([x_array, y_array]).T - angle = self.rotation_angle - r = angle / 360 - theta = r * np.pi * 2 - rotation_matrix = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]) - data = data @ rotation_matrix - lon_new = (data[:, 0] + self.calibration_point_x_jitter).tolist() - lat_new = (data[:, 1] + self.calibration_point_y_jitter).tolist() + lon_new, lat_new = JitterRotator.rotate_jitter( + x_array, y_array, self.rotation_angle, self.calibration_point_x_jitter, self.calibration_point_y_jitter + ) for index, lon, lat in zip(indexes, lon_new, lat_new): - self.points.append(Point(index, lon, lat)) + self.points.append(QPoint(index, lon, lat)) def generate_gridding_params(self): """generate the gridding params after data are added @@ -233,7 +229,7 @@ def generate_gridding_params(self): self.left_bottom_point = (left_bottom_point_x, left_bottom_point_y) if self.lon_lat_equal_grid is True: - self.root = Node( + self.root = QNode( left_bottom_point_x, left_bottom_point_y, max(self.grid_length_x, self.grid_length_y), @@ -241,7 +237,7 @@ def generate_gridding_params(self): self.points, ) elif self.lon_lat_equal_grid is False: - self.root = Node( + self.root = QNode( left_bottom_point_x, left_bottom_point_y, self.grid_length_x, self.grid_length_y, self.points ) else: @@ -272,58 +268,37 @@ def graph(self, scatter: bool = True, ax=None): c = find_children(self.root) - # areas = set() - # width_set = set() - # height_set = set() - # for el in c: - # areas.add(el.width * el.height) - # width_set.add(el.width) - # height_set.add(el.height) - - theta = -(self.rotation_angle / 360) * np.pi * 2 - rotation_matrix = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]) - for n in c: - xy0_trans = np.array([[n.x0, n.y0]]) - if self.calibration_point_x_jitter: - new_x = xy0_trans[:, 0] - self.calibration_point_x_jitter - else: - new_x = xy0_trans[:, 0] - - if self.calibration_point_y_jitter: - new_y = xy0_trans[:, 1] - self.calibration_point_y_jitter - else: - new_y = xy0_trans[:, 1] - new_xy = np.array([[new_x[0], new_y[0]]]) @ rotation_matrix - new_x = new_xy[:, 0] - new_y = new_xy[:, 1] + old_x, old_y = JitterRotator.inverse_jitter_rotate( + [n.x0], [n.y0], self.rotation_angle, self.calibration_point_x_jitter, self.calibration_point_y_jitter + ) if ax is None: plt.gcf().gca().add_patch( patches.Rectangle( - (new_x, new_y), n.width, n.height, fill=False, angle=self.rotation_angle, color=the_color + (old_x, old_y), n.width, n.height, fill=False, angle=self.rotation_angle, color=the_color ) ) else: ax.add_patch( patches.Rectangle( - (new_x, new_y), n.width, n.height, fill=False, angle=self.rotation_angle, color=the_color + (old_x, old_y), n.width, n.height, fill=False, angle=self.rotation_angle, color=the_color ) ) - x = np.array([point.x for point in self.points]) - self.calibration_point_x_jitter - y = np.array([point.y for point in self.points]) - self.calibration_point_y_jitter - - data = np.array([x, y]).T @ rotation_matrix if scatter: + old_x, old_y = JitterRotator.inverse_jitter_rotate( + [point.x for point in self.points], + [point.y for point in self.points], + self.rotation_angle, + self.calibration_point_x_jitter, + self.calibration_point_y_jitter, + ) + if ax is None: - plt.scatter( - data[:, 0].tolist(), data[:, 1].tolist(), s=0.2, c="tab:blue", alpha=0.7 - ) # plots the points as red dots + plt.scatter(old_x, old_y, s=0.2, c="tab:blue", alpha=0.7) # plots the points as red dots else: - ax.scatter( - data[:, 0].tolist(), data[:, 1].tolist(), s=0.2, c="tab:blue", alpha=0.7 - ) # plots the points as red dots + ax.scatter(old_x, old_y, s=0.2, c="tab:blue", alpha=0.7) # plots the points as red dots return def get_final_result(self) -> pandas.core.frame.DataFrame: @@ -333,13 +308,13 @@ def get_final_result(self) -> pandas.core.frame.DataFrame: results (DataFrame): A pandas dataframe containing the gridding information """ all_grids = find_children(self.root) - point_indexes_list = [] + # point_indexes_list = [] point_grid_width_list = [] point_grid_height_list = [] point_grid_points_number_list = [] calibration_point_list = [] for grid in all_grids: - point_indexes_list.append([point.index for point in grid.points]) + # point_indexes_list.append([point.index for point in grid.points]) point_grid_width_list.append(grid.width) point_grid_height_list.append(grid.height) point_grid_points_number_list.append(len(grid.points)) @@ -347,7 +322,7 @@ def get_final_result(self) -> pandas.core.frame.DataFrame: result = pd.DataFrame( { - "checklist_indexes": point_indexes_list, + # "checklist_indexes": point_indexes_list, "stixel_indexes": list(range(len(point_grid_width_list))), "stixel_width": point_grid_width_list, "stixel_height": point_grid_height_list, diff --git a/stemflow/gridding/Q_blocks.py b/stemflow/gridding/Q_blocks.py index 3e0b918..0229ac6 100644 --- a/stemflow/gridding/Q_blocks.py +++ b/stemflow/gridding/Q_blocks.py @@ -1,12 +1,13 @@ """I call this Q_blocks because they are essential blocks for QTree methods""" -from typing import Tuple, Union +from collections.abc import Sequence +from typing import List, Tuple, Union from ..utils.sphere.coordinate_transform import lonlat_spherical_transformer from ..utils.sphere.distance import spherical_distance_from_coordinates -class Point: +class QPoint: """A Point class for recording data points""" def __init__(self, index, x, y): @@ -15,7 +16,7 @@ def __init__(self, index, x, y): self.index = index -class Node: +class QNode: """A tree-like division node class""" def __init__( @@ -24,7 +25,7 @@ def __init__( y0: Union[float, int], w: Union[float, int], h: Union[float, int], - points: list[Point], + points: Sequence, ): self.x0 = x0 self.y0 = y0 @@ -43,7 +44,7 @@ def get_points(self): return self.points -class Grid: +class QGrid: """Grid class for STEM (fixed gird size)""" def __init__(self, x_index, y_index, x_range, y_range): @@ -76,7 +77,7 @@ def __init__( inclination2: Union[float, int], azimuth3: Union[float, int], inclination3: Union[float, int], - points: list[Sphere_Point], + points: Sequence, ): self.x0 = x0 self.y0 = y0 diff --git a/stemflow/gridding/QuadGrid.py b/stemflow/gridding/QuadGrid.py index 593537c..dfc24aa 100644 --- a/stemflow/gridding/QuadGrid.py +++ b/stemflow/gridding/QuadGrid.py @@ -12,8 +12,9 @@ import pandas as pd from ..utils.generate_soft_colors import generate_soft_color +from ..utils.jitterrotation.jitterrotator import JitterRotator from ..utils.validation import check_random_state -from .Q_blocks import Grid, Node, Point +from .Q_blocks import QGrid, QNode, QPoint os.environ["MKL_NUM_THREADS"] = "1" os.environ["NUMEXPR_NUM_THREADS"] = "1" @@ -64,7 +65,7 @@ def add_lon_lat_data(self, indexes: Sequence, x_array: Sequence, y_array: Sequen lat_new = (data[:, 1] + self.calibration_point_y_jitter).tolist() for index, lon, lat in zip(indexes, lon_new, lat_new): - self.points.append(Point(index, lon, lat)) + self.points.append(QPoint(index, lon, lat)) def generate_gridding_params(self): """For completeness""" @@ -85,10 +86,10 @@ def subdivide(self): ymin = np.min(y_list) ymax = np.max(y_list) - self.x_start = xmin - self.grid_len + self.calibration_point_x_jitter - self.x_end = xmax + self.grid_len + self.calibration_point_x_jitter - self.y_start = ymin - self.grid_len + self.calibration_point_y_jitter - self.y_end = ymax + self.grid_len + self.calibration_point_y_jitter + self.x_start = xmin - self.grid_len + self.x_end = xmax + self.grid_len + self.y_start = ymin - self.grid_len + self.y_end = ymax + self.grid_len x_grids = np.arange(self.x_start, self.x_end, self.grid_len) y_grids = np.arange(self.y_start, self.y_end, self.grid_len) @@ -99,7 +100,7 @@ def subdivide(self): self.grids = [] for i in range(len(x_grids) - 1): for j in range(len(y_grids) - 1): - gird = Grid(i, j, (x_grids[i], x_grids[i + 1]), (y_grids[j], y_grids[j + 1])) + gird = QGrid(i, j, (x_grids[i], x_grids[i + 1]), (y_grids[j], y_grids[j + 1])) self.grids.append(gird) # Use numpy.digitize to bin points into grids @@ -112,30 +113,27 @@ def subdivide(self): grid.points = [self.points[i] for i in indices] def graph(self, scatter: bool = True, ax=None): - the_color = generate_soft_color() + """plot gridding - theta = -(self.rotation_angle / 360) * np.pi * 2 - rotation_matrix = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]) + Args: + scatter: Whether add scatterplot of data points + """ - for grid in self.grids: - xy0_trans = np.array([[grid.x_range[0], grid.y_range[0]]]) - if self.calibration_point_x_jitter: - new_x = xy0_trans[:, 0] - self.calibration_point_x_jitter - else: - new_x = xy0_trans[:, 0] + the_color = generate_soft_color() - if self.calibration_point_y_jitter: - new_y = xy0_trans[:, 1] - self.calibration_point_y_jitter - else: - new_y = xy0_trans[:, 1] - new_xy = np.array([[new_x[0], new_y[0]]]) @ rotation_matrix - new_x = new_xy[:, 0] - new_y = new_xy[:, 1] + for grid in self.grids: + old_x, old_y = JitterRotator.inverse_jitter_rotate( + [grid.x_range[0]], + [grid.y_range[0]], + self.rotation_angle, + self.calibration_point_x_jitter, + self.calibration_point_y_jitter, + ) if ax is None: plt.gcf().gca().add_patch( patches.Rectangle( - (new_x, new_y), + (old_x, old_y), self.grid_len, self.grid_len, fill=False, @@ -146,7 +144,7 @@ def graph(self, scatter: bool = True, ax=None): else: ax.add_patch( patches.Rectangle( - (new_x, new_y), + (old_x, old_y), self.grid_len, self.grid_len, fill=False, @@ -155,19 +153,19 @@ def graph(self, scatter: bool = True, ax=None): ) ) - x = np.array([point.x for point in self.points]) - self.calibration_point_x_jitter - y = np.array([point.y for point in self.points]) - self.calibration_point_y_jitter - - data = np.array([x, y]).T @ rotation_matrix if scatter: + old_x, old_y = JitterRotator.inverse_jitter_rotate( + [point.x for point in self.points], + [point.y for point in self.points], + self.rotation_angle, + self.calibration_point_x_jitter, + self.calibration_point_y_jitter, + ) + if ax is None: - plt.scatter( - data[:, 0].tolist(), data[:, 1].tolist(), s=0.2, c="tab:blue", alpha=0.7 - ) # plots the points as red dots + plt.scatter(old_x, old_y, s=0.2, c="tab:blue", alpha=0.7) # plots the points as red dots else: - ax.scatter( - data[:, 0].tolist(), data[:, 1].tolist(), s=0.2, c="tab:blue", alpha=0.7 - ) # plots the points as red dots + ax.scatter(old_x, old_y, s=0.2, c="tab:blue", alpha=0.7) # plots the points as red dots return @@ -178,21 +176,21 @@ def get_final_result(self) -> pandas.core.frame.DataFrame: results (DataFrame): A pandas dataframe containing the gridding information """ - point_indexes_list = [] + # point_indexes_list = [] point_grid_width_list = [] point_grid_height_list = [] point_grid_points_number_list = [] calibration_point_list = [] for grid in self.grids: - point_indexes_list.append([point.index for point in grid.points]) + # point_indexes_list.append([point.index for point in grid.points]) point_grid_width_list.append(self.grid_len) point_grid_height_list.append(self.grid_len) point_grid_points_number_list.append(len(grid.points)) - calibration_point_list.append((round(grid.x_range[0], 6), round(grid.x_range[0], 6))) + calibration_point_list.append((round(grid.x_range[0], 6), round(grid.y_range[0], 6))) result = pd.DataFrame( { - "checklist_indexes": point_indexes_list, + # "checklist_indexes": point_indexes_list, "stixel_indexes": list(range(len(point_grid_width_list))), "stixel_width": point_grid_width_list, "stixel_height": point_grid_height_list, diff --git a/stemflow/gridding/__init__.py b/stemflow/gridding/__init__.py index a714bae..903b937 100644 --- a/stemflow/gridding/__init__.py +++ b/stemflow/gridding/__init__.py @@ -1 +1 @@ -from .Q_blocks import Grid, Node, Point +from .Q_blocks import QGrid, QNode, QPoint diff --git a/stemflow/manually_testing.py b/stemflow/manually_testing.py index 6edafe8..ecfe376 100644 --- a/stemflow/manually_testing.py +++ b/stemflow/manually_testing.py @@ -131,6 +131,8 @@ def make_AdaSTEM_model1(fold_, min_req, ensemble_models_disk_saver, ensemble_mod Spatio1="longitude", Spatio2="latitude", Temporal1="DOY", + temporal_bin_start_jitter="adaptive", + spatio_bin_jitter_magnitude="adaptive", use_temporal_to_train=True, ensemble_models_disk_saver=ensemble_models_disk_saver, ensemble_models_disk_saving_dir=ensemble_models_disk_saving_dir, @@ -165,6 +167,8 @@ def make_AdaSTEM_model2(fold_, min_req, ensemble_models_disk_saver, ensemble_mod Spatio1="longitude", Spatio2="latitude", Temporal1="DOY", + temporal_bin_start_jitter="adaptive", + spatio_bin_jitter_magnitude="adaptive", use_temporal_to_train=True, ensemble_models_disk_saver=ensemble_models_disk_saver, ensemble_models_disk_saving_dir=ensemble_models_disk_saving_dir, @@ -213,6 +217,7 @@ def run_mini_test( """ # print("Start Running Mini-test...") + start_time = time.time() from xgboost import XGBClassifier, XGBRegressor from stemflow.model.AdaSTEM import AdaSTEM, AdaSTEMClassifier, AdaSTEMRegressor @@ -300,7 +305,8 @@ def run_mini_test( [ i for i in importances_by_points.columns - if i not in ["DOY", "longitude", "latitude", "longitude_new", "latitude_new"] + if i + not in [model.Temporal1, model.Spatio1, model.Spatio2, f"{model.Spatio1}_new", f"{model.Spatio2}_new"] ] ] .mean() @@ -374,7 +380,6 @@ def run_mini_test( assert os.path.exists(os.path.join(tmp_dir, "error_plot.pdf")) # 11.Evaluation - # %% print("Predicting on test set...") pred = model.predict(X_test) @@ -383,7 +388,7 @@ def run_mini_test( # %% perc = np.sum(np.isnan(pred.flatten())) / len(pred.flatten()) print(f"Percentage not predictable {round(perc*100, 2)}%") - assert perc < 0.05 + assert perc < 0.5 # %% pred_df = pd.DataFrame( @@ -417,6 +422,9 @@ def run_mini_test( assert not os.path.exists(tmp_dir) print("Finish!") + end_time = time.time() + time_use = end_time - start_time + print(f"Total time use: {time_use}") return model diff --git a/stemflow/model/AdaSTEM.py b/stemflow/model/AdaSTEM.py index b804e56..1085a05 100644 --- a/stemflow/model/AdaSTEM.py +++ b/stemflow/model/AdaSTEM.py @@ -28,14 +28,28 @@ roc_auc_score, ) from tqdm import tqdm +from tqdm.auto import tqdm as tqdm_auto # from ..utils.quadtree import get_ensemble_quadtree +from ..utils.validation import ( + check_base_model, + check_njobs, + check_prediciton_aggregation, + check_prediction_return, + check_random_state, + check_spatio_bin_jitter_magnitude, + check_task, + check_temporal_bin_start_jitter, + check_verbosity, + check_X_test, + check_X_train, + check_y_train, +) +from ..utils.wrapper import model_wrapper from .dummy_model import dummy_model1 - -# from ..utils.validation import check_random_state +from .Hurdle import Hurdle from .static_func_AdaSTEM import ( # predict_one_ensemble - _monkey_patched_predict_proba, assign_points_to_one_ensemble, get_model_and_stixel_specific_x_names, predict_one_stixel, @@ -66,8 +80,8 @@ def __init__( temporal_end: Union[float, int] = 366, temporal_step: Union[float, int] = 20, temporal_bin_interval: Union[float, int] = 50, - temporal_bin_start_jitter: Union[float, int, str] = "random", - spatio_bin_jitter_magnitude: Union[float, int] = 100, + temporal_bin_start_jitter: Union[float, int, str] = "adaptive", + spatio_bin_jitter_magnitude: Union[float, int] = "adaptive", save_gridding_plot: bool = True, save_tmp: bool = False, save_dir: str = "./", @@ -117,10 +131,10 @@ def __init__( size of the sliding window. Defaults to 50. temporal_bin_start_jitter: jitter of the start of the sliding window. - If 'random', a random jitter of range (-bin_interval, 0) will be generated - for the start. Defaults to 'random'. + If 'adaptive', a random jitter of range (-bin_interval, 0) will be generated + for the start. Defaults to 'adaptive'. spatio_bin_jitter_magnitude: - jitter of the spatial gridding. Defaults to 10. + jitter of the spatial gridding. Defaults to 'adaptive'. save_gridding_plot: Whether ot save gridding plots. Defaults to True. save_tmp: @@ -179,89 +193,63 @@ def __init__( feature importance dataframe for each stixel. """ - # save base model + # 1. Base model + check_base_model(base_model) + base_model = model_wrapper(base_model) self.base_model = base_model - self.Spatio1 = Spatio1 - self.Spatio2 = Spatio2 - self.Temporal1 = Temporal1 - self.use_temporal_to_train = use_temporal_to_train - self.subset_x_names = subset_x_names - self.ensemble_models_disk_saver = ensemble_models_disk_saver - self.ensemble_models_disk_saving_dir = ensemble_models_disk_saving_dir - if self.ensemble_models_disk_saver: - self.saving_code = np.random.randint(1, 1e8, 1) - - for func in ["fit", "predict"]: - if func not in dir(self.base_model): - raise AttributeError(f"input base model must have method '{func}'!") - - self.base_model = self.model_wrapper(self.base_model) + # 2. Model params + check_task(task) self.task = task - if self.task not in ["regression", "classification", "hurdle"]: - raise AttributeError( - f"task type must be one of 'regression', 'classification', or 'hurdle'! Now it is {self.task}" - ) - if self.task == "hurdle": - warnings.warn( - "You have chosen HURDLE task. The goal is to first conduct classification, and then apply regression on points with *positive values*" - ) + self.Temporal1 = Temporal1 + self.Spatio1 = Spatio1 + self.Spatio2 = Spatio2 + # 3. Gridding params self.ensemble_fold = ensemble_fold self.min_ensemble_required = min_ensemble_required - self.grid_len_upper_threshold = ( self.grid_len_lon_upper_threshold ) = self.grid_len_lat_upper_threshold = grid_len_upper_threshold self.grid_len_lower_threshold = ( self.grid_len_lon_lower_threshold ) = self.grid_len_lat_lower_threshold = grid_len_lower_threshold - self.points_lower_threshold = points_lower_threshold self.temporal_start = temporal_start self.temporal_end = temporal_end self.temporal_step = temporal_step self.temporal_bin_interval = temporal_bin_interval - self.spatio_bin_jitter_magnitude = spatio_bin_jitter_magnitude - self.plot_xlims = plot_xlims - self.plot_ylims = plot_ylims - # validate temporal_bin_start_jitter - if type(temporal_bin_start_jitter) not in [str, float, int]: - raise AttributeError( - f"Input temporal_bin_start_jitter should be 'random', float or int, got {type(temporal_bin_start_jitter)}" - ) - if type(temporal_bin_start_jitter) == str: - if not temporal_bin_start_jitter == "random": - raise AttributeError( - f"The input temporal_bin_start_jitter as string should only be 'random'. Other options include float or int. Got {temporal_bin_start_jitter}" - ) + check_spatio_bin_jitter_magnitude(spatio_bin_jitter_magnitude) + self.spatio_bin_jitter_magnitude = spatio_bin_jitter_magnitude + check_temporal_bin_start_jitter(temporal_bin_start_jitter) self.temporal_bin_start_jitter = temporal_bin_start_jitter - # + # 4. Training params if stixel_training_size_threshold is None: self.stixel_training_size_threshold = points_lower_threshold else: self.stixel_training_size_threshold = stixel_training_size_threshold + self.use_temporal_to_train = use_temporal_to_train + self.subset_x_names = subset_x_names + self.sample_weights_for_classifier = sample_weights_for_classifier - self.save_gridding_plot = save_gridding_plot + # 5. Multi-threading params (not implemented yet) + check_njobs(njobs) + self.njobs = njobs + + # 6. Plotting params + self.plot_xlims = plot_xlims + self.plot_ylims = plot_ylims self.save_tmp = save_tmp self.save_dir = save_dir + self.save_gridding_plot = save_gridding_plot - # validate njobs setting - if not isinstance(njobs, int): - raise TypeError(f"njobs is not a integer. Got {njobs}.") - elif njobs > 1: - raise NotImplementedError("Multi-thread processing is not implemented yet.") - - my_cpu_count = cpu_count() - if njobs > my_cpu_count: - raise ValueError(f"Setting of njobs ({njobs}) exceed the maximum ({my_cpu_count}).") - else: - self.njobs = njobs - - # - self.sample_weights_for_classifier = sample_weights_for_classifier + # X. miscellaneous + self.ensemble_models_disk_saver = ensemble_models_disk_saver + self.ensemble_models_disk_saving_dir = ensemble_models_disk_saving_dir + if self.ensemble_models_disk_saver: + self.saving_code = np.random.randint(1, 1e8, 1) if not verbosity == 0: self.verbosity = 1 @@ -311,7 +299,7 @@ def split(self, X_train: pd.core.frame.DataFrame, verbosity: Union[None, int] = temporal_bin_start_jitter=self.temporal_bin_start_jitter, spatio_bin_jitter_magnitude=self.spatio_bin_jitter_magnitude, save_gridding_plot=self.save_gridding_plot, - njobs=1, # self.njobs, + njobs=self.njobs, verbosity=verbosity, plot_xlims=self.plot_xlims, plot_ylims=self.plot_ylims, @@ -319,43 +307,113 @@ def split(self, X_train: pd.core.frame.DataFrame, verbosity: Union[None, int] = ax=ax, ) - self.grid_dict = {} - for ensemble_index in self.ensemble_df.ensemble_index.unique(): - this_ensemble = self.ensemble_df[self.ensemble_df.ensemble_index == ensemble_index] + def store_x_names(self, X_train): + # store x_names + self.x_names = list(X_train.columns) + if not self.use_temporal_to_train: + if self.Temporal1 in list(self.x_names): + del self.x_names[self.x_names.index(self.Temporal1)] - this_ensemble_gird_info = {} - this_ensemble_gird_info["checklist_index"] = [] - this_ensemble_gird_info["stixel"] = [] - for index, line in this_ensemble.iterrows(): - this_ensemble_gird_info["checklist_index"].extend(line["checklist_indexes"]) - this_ensemble_gird_info["stixel"].extend([line["unique_stixel_id"]] * len(line["checklist_indexes"])) + # if 'geometry' in self.x_names: + # del self.x_names[self.x_names.index('geometry')] - cores = pd.DataFrame(this_ensemble_gird_info) - cores2 = pd.DataFrame(list(X_train.index), columns=["data_point_index"]) - cores = pd.merge(cores, cores2, left_on="checklist_index", right_on="data_point_index", how="right") + for i in [self.Spatio1, self.Spatio2]: + if i in self.x_names: + del self.x_names[self.x_names.index(i)] - self.grid_dict[ensemble_index] = cores.stixel.values + def stixel_fitting(self, stixel): + """Fit one stixel - return self.grid_dict + Args: + stixel (gpd.geodataframe.GeoDataFrame): data sjoined with ensemble_df. + For a single stixel. + """ - def model_wrapper(self, model: BaseEstimator) -> BaseEstimator: - """wrap a predict_proba function for those models who don't have + ensemble_index = int(stixel["ensemble_index"].iloc[0]) + unique_stixel_id = stixel["unique_stixel_id"].iloc[0] + name = f"{ensemble_index}_{unique_stixel_id}" + + model, stixel_specific_x_names, status = train_one_stixel( + stixel_training_size_threshold=self.stixel_training_size_threshold, + x_names=self.x_names, + task=self.task, + base_model=self.base_model, + sample_weights_for_classifier=self.sample_weights_for_classifier, + subset_x_names=self.subset_x_names, + stixel_X_train=stixel, + ) - Args: - model: - Input model + if not status == "Success": + # print(f'Fitting: {ensemble_index}. Not pass: {status}') + pass - Returns: - Wrapped model that has a `predict_proba` method + else: + self.model_dict[f"{name}_model"] = model + self.stixel_specific_x_names[name] = stixel_specific_x_names + + def SAC_ensemble_training(self, index_df, data): + # Calculate the start indices for the sliding window + unique_start_indices = np.sort(index_df[f"{self.Temporal1}_start"].unique()) + # training, window by window + for start in unique_start_indices: + window_data_df = data[ + (data[self.Temporal1] >= start) & (data[self.Temporal1] < start + self.temporal_bin_interval) + ] + window_data_df = transform_pred_set_to_STEM_quad(self.Spatio1, self.Spatio2, window_data_df, index_df) + window_index_df = index_df[index_df[f"{self.Temporal1}_start"] == start] + + # Merge + def find_belonged_points(df, df_a): + return df_a[ + (df_a[f"{self.Spatio1}_new"] >= df["stixel_calibration_point_transformed_left_bound"].iloc[0]) + & (df_a[f"{self.Spatio1}_new"] < df["stixel_calibration_point_transformed_right_bound"].iloc[0]) + & (df_a[f"{self.Spatio2}_new"] >= df["stixel_calibration_point_transformed_lower_bound"].iloc[0]) + & (df_a[f"{self.Spatio2}_new"] < df["stixel_calibration_point_transformed_upper_bound"].iloc[0]) + ] + + query_results = ( + window_index_df[ + [ + "ensemble_index", + "unique_stixel_id", + "stixel_calibration_point_transformed_left_bound", + "stixel_calibration_point_transformed_right_bound", + "stixel_calibration_point_transformed_lower_bound", + "stixel_calibration_point_transformed_upper_bound", + ] + ] + .groupby(["ensemble_index", "unique_stixel_id"]) + .apply(find_belonged_points, df_a=window_data_df) + ) + + if len(query_results) == 0: + """All points fall out of the grids""" + continue + + # train + ( + query_results.reset_index(drop=False, level=[0, 1]) + .dropna(subset="unique_stixel_id") + .groupby("unique_stixel_id") + .apply(lambda stixel: self.stixel_fitting(stixel)) + ) + + def SAC_training(self, ensemble_df, data, verbosity): + """Training with the whole input data + + split (S), apply(A), combine (C). Ensemble level """ - if "predict_proba" in dir(model): - return model - else: - warnings.warn("predict_proba function not in base_model. Monkey patching one.") - model.predict_proba = _monkey_patched_predict_proba - return model + if verbosity > 0: + tqdm_auto.pandas(desc="training", postfix=None) + ensemble_df.groupby("ensemble_index").progress_apply( + lambda ensemble: self.SAC_ensemble_training(index_df=ensemble, data=data) + ) + else: + ensemble_df.groupby("ensemble_index").apply( + lambda ensemble: self.SAC_ensemble_training(index_df=ensemble, data=data) + ) def fit( self, @@ -376,146 +434,139 @@ def fit( TypeError: y_train is not a type of np.ndarray or pd.core.frame.DataFrame """ # - if verbosity is None: - verbosity = self.verbosity - elif verbosity == 0: - verbosity = 0 - else: - verbosity = 1 - - # check type - type_X_train = type(X_train) - - if not type_X_train == pd.core.frame.DataFrame: - raise TypeError(f"Input X_train should be type 'pd.core.frame.DataFrame'. Got {str(type_X_train)}") - - type_y_train = type(y_train) - if not (isinstance(y_train, np.ndarray) or isinstance(y_train, pd.core.frame.DataFrame)): - raise TypeError( - f"Input y_train should be type 'pd.core.frame.DataFrame' or 'np.ndarray'. Got {str(type_y_train)}" - ) - - # store x_names - self.x_names = list(X_train.columns) - for i in [self.Spatio1, self.Spatio2]: - if i in list(self.x_names): - del self.x_names[self.x_names.index(i)] - if not self.use_temporal_to_train: - if self.Temporal1 in list(self.x_names): - del self.x_names[self.x_names.index(self.Temporal1)] + verbosity = check_verbosity(self, verbosity) + check_X_train(X_train) + check_y_train(y_train) + self.store_x_names(X_train) # quadtree X_train = X_train.reset_index(drop=True) # I reset index here!! caution! X_train["true_y"] = np.array(y_train).flatten() - _ = self.split(X_train, verbosity=verbosity, ax=ax) + self.split(X_train, verbosity=verbosity, ax=ax) # define model dict self.model_dict = {} # stixel specific x_names list self.stixel_specific_x_names = {} - # Training function for each stixel - if not self.njobs > 1: + if self.njobs > 1: + raise NotImplementedError("Multi-threading not implemented yet.") + else: # single processing - func_ = ( - tqdm(self.ensemble_df.iterrows(), total=len(self.ensemble_df), desc="training: ") - if verbosity > 0 - else self.ensemble_df.iterrows() - ) + self.SAC_training(self.ensemble_df, X_train, verbosity) - current_ensemble_index = None - tmp_model_dict = {} - for index, line in func_: - ensemble_index = line["ensemble_index"] - if current_ensemble_index is None: - current_ensemble_index = ensemble_index - else: - if current_ensemble_index != ensemble_index: - # All models of the previous ensembles are trained. It's time to save them. - if self.ensemble_models_disk_saver: - with open( - os.path.join( - f"{self.ensemble_models_disk_saving_dir}", - f"trained_model_ensemble_{current_ensemble_index}_{self.saving_code}_.pkl", - ), - "wb", - ) as f: - pickle.dump(tmp_model_dict, f) - - tmp_model_dict = {} - current_ensemble_index = ensemble_index - - unique_stixel_id = line["unique_stixel_id"] - name = f"{ensemble_index}_{unique_stixel_id}" - checklist_indexes = line["checklist_indexes"] - model, stixel_specific_x_names = train_one_stixel( - stixel_training_size_threshold=self.stixel_training_size_threshold, - x_names=self.x_names, - task=self.task, - base_model=self.base_model, - sample_weights_for_classifier=self.sample_weights_for_classifier, - subset_x_names=self.subset_x_names, - X_train_copy=X_train, - checklist_indexes=checklist_indexes, - ) + return self - if model is None: - continue - else: - if self.ensemble_models_disk_saver: - tmp_model_dict[f"{name}_model"] = model - else: - self.model_dict[f"{name}_model"] = model + def stixel_predict(self, stixel): + """Predict one stixel - if len(stixel_specific_x_names) == 0: - continue - else: - self.stixel_specific_x_names[name] = stixel_specific_x_names + Args: + stixel (pd.core.frame.DataFrame): data joined with ensemble_df. + For a single stixel. + """ + ensemble_index = stixel["ensemble_index"].iloc[0] + unique_stixel_id = stixel["unique_stixel_id"].iloc[0] + + model_x_names_tuple = get_model_and_stixel_specific_x_names( + self.model_dict, + ensemble_index, + unique_stixel_id, + self.stixel_specific_x_names, + self.x_names, + ) + + if model_x_names_tuple[0] is None: + return None + pred = predict_one_stixel(stixel, self.task, model_x_names_tuple) + + if pred is None: + return None else: - # multi-processing - ensemble_index_list = self.ensemble_df["ensemble_index"].values - unique_stixel_id_list = self.ensemble_df["unique_stixel_id"].values - name_list = [ - f"{ensemble_index}_{unique_stixel_id}" - for ensemble_index, unique_stixel_id in zip(ensemble_index_list, unique_stixel_id_list) + return pred + + def SAC_ensemble_predict(self, index_df, data): + """Predict one ensemble + + Args: + index_df (pd.core.frame.DataFrame): ensemble data (model.ensemble_df) + data (pd.core.frame.DataFrame): input covariates to predict + """ + + temp_start = index_df[f"{self.Temporal1}_start"].min() + temp_end = index_df[f"{self.Temporal1}_end"].max() + + # Calculate the start indices for the sliding window + start_indices = np.arange(temp_start, temp_end, self.temporal_step) + + # prediction, window by window + window_prediction_list = [] + for start in start_indices: + window_data_df = data[ + (data[self.Temporal1] >= start) & (data[self.Temporal1] < start + self.temporal_bin_interval) ] - checklist_indexes = self.ensemble_df["checklist_indexes"] + window_data_df = transform_pred_set_to_STEM_quad(self.Spatio1, self.Spatio2, window_data_df, index_df) + window_index_df = index_df[index_df[f"{self.Temporal1}_start"] == start] + + def find_belonged_points(df, df_a): + return df_a[ + (df_a[f"{self.Spatio1}_new"] >= df["stixel_calibration_point_transformed_left_bound"].iloc[0]) + & (df_a[f"{self.Spatio1}_new"] < df["stixel_calibration_point_transformed_right_bound"].iloc[0]) + & (df_a[f"{self.Spatio2}_new"] >= df["stixel_calibration_point_transformed_lower_bound"].iloc[0]) + & (df_a[f"{self.Spatio2}_new"] < df["stixel_calibration_point_transformed_upper_bound"].iloc[0]) + ] + + query_results = ( + window_index_df[ + [ + "ensemble_index", + "unique_stixel_id", + "stixel_calibration_point_transformed_left_bound", + "stixel_calibration_point_transformed_right_bound", + "stixel_calibration_point_transformed_lower_bound", + "stixel_calibration_point_transformed_upper_bound", + ] + ] + .groupby(["ensemble_index", "unique_stixel_id"]) + .apply(find_belonged_points, df_a=window_data_df) + ) - with Pool(self.njobs) as p: - plain_args_iterator = zip( - repeat(self.stixel_training_size_threshold), - repeat(self.x_names), - repeat(self.task), - repeat(self.base_model), - repeat(self.sample_weights_for_classifier), - repeat(self.subset_x_names), - repeat(X_train), - checklist_indexes, - ) - if verbosity > 0: - args_iterator = tqdm(plain_args_iterator, total=len(checklist_indexes)) - else: - args_iterator = plain_args_iterator + if len(query_results) == 0: + """All points fall out of the grids""" + continue - tmp_res = p.starmap(train_one_stixel, args_iterator) + # predict + window_prediction = ( + query_results.reset_index(drop=False, level=[0, 1]) + .dropna(subset="unique_stixel_id") + .groupby("unique_stixel_id") + .apply(lambda stixel: self.stixel_predict(stixel)) + ) - # Store model and stixel specific x_names - for res, name in zip(tmp_res, name_list): - model_ = res[0] - stixel_specific_x_names_ = res[1] + window_prediction_list.append(window_prediction) - if model_ is None: - continue - else: - self.model_dict[f"{name}_model"] = model_ + ensemble_prediction = pd.concat(window_prediction_list, axis=0) + ensemble_prediction = ensemble_prediction.droplevel(0, axis=0) + ensemble_prediction = ensemble_prediction.groupby("index").mean().reset_index(drop=False) + return ensemble_prediction - if len(stixel_specific_x_names_) == 0: - continue - else: - self.stixel_specific_x_names[name] = stixel_specific_x_names_ + def SAC_predict(self, ensemble_df, data, verbosity): + """split (S), apply(A), combine (C). Ensemble level""" - return self + if verbosity > 0: + tqdm_auto.pandas(desc="predicting", postfix=None) + pred = ensemble_df.groupby("ensemble_index").progress_apply( + lambda ensemble: self.SAC_ensemble_predict(index_df=ensemble, data=data) + ) + else: + pred = ensemble_df.groupby("ensemble_index").apply( + lambda ensemble: self.SAC_ensemble_predict(index_df=ensemble, data=data) + ) + + # pred = pred.reset_index(drop=False) + pred = pred.droplevel(1, axis=0).reset_index(drop=False) + pred = pred.pivot_table(index="index", columns="ensemble_index", values="pred") + return pred def predict_proba( self, @@ -559,162 +610,13 @@ def predict_proba( Return numpy.ndarray of shape (n_samples, n_ensembles) """ - type_X_test = type(X_test) - if not type_X_test == pd.core.frame.DataFrame: - raise TypeError(f"Input X_test should be type 'pd.core.frame.DataFrame'. Got {type_X_test}") - # - if aggregation not in ["mean", "median"]: - raise ValueError(f"aggregation must be one of 'mean' and 'median'. Got {aggregation}") - - if not isinstance(return_by_separate_ensembles, bool): - type_return_by_separate_ensembles = str(type(return_by_separate_ensembles)) - raise TypeError(f"return_by_separate_ensembles must be bool. Got {type_return_by_separate_ensembles}") - else: - if return_by_separate_ensembles and return_std: - warnings("return_by_separate_ensembles == True. Automatically setting return_std=False") - return_std = False - - if verbosity is None: - verbosity = self.verbosity - - # predict - X_test_copy = X_test.copy() - - round_res_list = [] - - for ensemble in list(self.ensemble_df.ensemble_index.unique()): - this_ensemble = self.ensemble_df[self.ensemble_df.ensemble_index == ensemble] - this_ensemble.loc[:, "stixel_calibration_point_transformed_left_bound"] = [ - i[0] for i in this_ensemble["stixel_calibration_point(transformed)"] - ] - - this_ensemble.loc[:, "stixel_calibration_point_transformed_lower_bound"] = [ - i[1] for i in this_ensemble["stixel_calibration_point(transformed)"] - ] - - this_ensemble.loc[:, "stixel_calibration_point_transformed_right_bound"] = ( - this_ensemble["stixel_calibration_point_transformed_left_bound"] + this_ensemble["stixel_width"] - ) - - this_ensemble.loc[:, "stixel_calibration_point_transformed_upper_bound"] = ( - this_ensemble["stixel_calibration_point_transformed_lower_bound"] + this_ensemble["stixel_height"] - ) - - X_test_copy = transform_pred_set_to_STEM_quad(self.Spatio1, self.Spatio2, X_test_copy, this_ensemble) - this_ensemble_index = list(this_ensemble["ensemble_index"].values)[0] - - # pred each stixel - if not njobs > 1: - # single process - res_list = [] - - temp_bin_start_list = np.unique(this_ensemble[f"{self.Temporal1}_start"]) - iter_func = ( - temp_bin_start_list - if verbosity == 0 - else tqdm( - temp_bin_start_list, total=len(temp_bin_start_list), desc=f"predicting ensemble {ensemble} " - ) - ) - this_ensemble_model_dict = None - if self.ensemble_models_disk_saver: - with open( - os.path.join( - f"{self.ensemble_models_disk_saving_dir}", - f"trained_model_ensemble_{this_ensemble_index}_{self.saving_code}.pkl", - ), - "rb", - ) as f: - this_ensemble_model_dict = pickle.load(f) - else: - this_ensemble_model_dict = self.model_dict - - for temp_bin_start in iter_func: - # query the ensemble and sub_X_test for this temporal bin - sub_temp_ensemble = this_ensemble[this_ensemble[f"{self.Temporal1}_start"] == temp_bin_start] - sub_temp_X_test_copy = X_test_copy[ - (X_test_copy[self.Temporal1] >= sub_temp_ensemble[f"{self.Temporal1}_start"].values[0]) - & (X_test_copy[self.Temporal1] < sub_temp_ensemble[f"{self.Temporal1}_end"].values[0]) - ] - - for index, stixel in sub_temp_ensemble.iterrows(): - model_x_names_tuple = get_model_and_stixel_specific_x_names( - this_ensemble_model_dict, - ensemble, - stixel["unique_stixel_id"], - self.stixel_specific_x_names, - self.x_names, - ) - - if model_x_names_tuple[0] is None: - continue - - res = predict_one_stixel( - sub_temp_X_test_copy, - self.Temporal1, - self.Spatio1, - self.Spatio2, - stixel[f"{self.Temporal1}_start"], - stixel[f"{self.Temporal1}_end"], - stixel["stixel_calibration_point_transformed_left_bound"], - stixel["stixel_calibration_point_transformed_right_bound"], - stixel["stixel_calibration_point_transformed_lower_bound"], - stixel["stixel_calibration_point_transformed_upper_bound"], - self.x_names, - self.task, - model_x_names_tuple, - ) - - if res is None: - continue - - res_list.append(res) - else: - # # multi-processing - # with Pool(njobs) as p: - # plain_args_iterator = zip( - # repeat(X_test_copy), - # repeat(self.Temporal1), - # repeat(self.Spatio1), - # repeat(self.Spatio2), - # this_ensemble[f"{self.Temporal1}_start"], - # this_ensemble[f"{self.Temporal1}_end"], - # this_ensemble["stixel_calibration_point_transformed_left_bound"], - # this_ensemble["stixel_calibration_point_transformed_right_bound"], - # this_ensemble["stixel_calibration_point_transformed_lower_bound"], - # this_ensemble["stixel_calibration_point_transformed_upper_bound"], - # repeat(self.x_names), - # repeat(self.task), - # [ - # get_model_and_stixel_specific_x_names( - # self.model_dict, ensemble, grid_index, self.stixel_specific_x_names, self.x_names - # ) - # for grid_index in this_ensemble["unique_stixel_id"] - # ], - # ) - # if verbosity > 0: - # args_iterator = tqdm( - # plain_args_iterator, total=len(this_ensemble), desc=f"predicting ensemble {ensemble} " - # ) - # else: - # args_iterator = plain_args_iterator - - # res_list = p.starmap(predict_one_stixel, args_iterator) - raise NotImplementedError("Multi-threading for prediction is not implemented yet") - - try: - res_list = pd.concat(res_list, axis=0) - except Exception as e: - print(e) - res_list = pd.DataFrame({"index": list(X_test.index), "pred": [np.nan] * len(X_test.index)}).set_index( - "index" - ) + check_X_test(X_test) + check_prediciton_aggregation(aggregation) + return_by_separate_ensembles, return_std = check_prediction_return(return_by_separate_ensembles, return_std) + verbosity = check_verbosity(self, verbosity) + check_njobs(njobs) - res_list = res_list.reset_index(drop=False).groupby("index").mean() - round_res_list.append(res_list) - - # only sites that meet the minimum ensemble requirement are kept - res = pd.concat([df["pred"] for df in round_res_list], axis=1) + res = self.SAC_predict(self.ensemble_df, X_test, verbosity=verbosity) # Experimental Function if return_by_separate_ensembles: @@ -722,25 +624,25 @@ def predict_proba( new_res = new_res.merge(res, left_on="index", right_on="index", how="left") return new_res.values + # Aggregate if aggregation == "mean": res_mean = res.mean(axis=1, skipna=True) # mean of all grid model that predicts this stixel elif aggregation == "median": res_mean = res.median(axis=1, skipna=True) - res_std = res.std(axis=1, skipna=True) + # Nan count res_nan_count = res.isnull().sum(axis=1) - res_not_nan_count = len(round_res_list) - res_nan_count - pred_mean = np.where(res_not_nan_count.values < self.min_ensemble_required, np.nan, res_mean.values) - pred_std = np.where(res_not_nan_count.values < self.min_ensemble_required, np.nan, res_std.values) + pred_mean = np.where(res_nan_count.values >= self.min_ensemble_required, np.nan, res_mean.values) + pred_std = np.where(res_nan_count.values >= self.min_ensemble_required, np.nan, res_std.values) res = pd.DataFrame({"index": list(res_mean.index), "pred_mean": pred_mean, "pred_std": pred_std}).set_index( "index" ) + # Preparing output (formatting) new_res = pd.DataFrame({"index": list(X_test.index)}).set_index("index") - new_res = new_res.merge(res, left_on="index", right_on="index", how="left") nan_count = np.sum(np.isnan(new_res["pred_mean"].values)) @@ -941,9 +843,12 @@ def calculate_feature_importances(self): # generate feature importance dict feature_importance_list = [] - for index, ensemble_row in self.ensemble_df.drop("checklist_indexes", axis=1).iterrows(): + for index, ensemble_row in self.ensemble_df[ + self.ensemble_df["stixel_checklist_count"] >= self.stixel_training_size_threshold + ].iterrows(): if ensemble_row["stixel_checklist_count"] < self.stixel_training_size_threshold: continue + try: ensemble_index = ensemble_row["ensemble_index"] stixel_index = ensemble_row["unique_stixel_id"] @@ -952,15 +857,23 @@ def calculate_feature_importances(self): if isinstance(the_model, dummy_model1): importance_dict = dict(zip(self.x_names, [1 / len(self.x_names)] * len(self.x_names))) + elif isinstance(the_model, Hurdle): + if "feature_importances_" in the_model.__dir__(): + importance_dict = dict(zip(x_names, the_model.feature_importances_)) + else: + if isinstance(the_model.classifier, dummy_model1): + importance_dict = dict(zip(self.x_names, [1 / len(self.x_names)] * len(self.x_names))) + else: + importance_dict = dict(zip(x_names, the_model.classifier.feature_importances_)) else: - feature_imp = the_model.feature_importances_ - importance_dict = dict(zip(x_names, feature_imp)) + importance_dict = dict(zip(x_names, the_model.feature_importances_)) importance_dict["stixel_index"] = stixel_index feature_importance_list.append(importance_dict) except Exception as e: warnings.warn(f"{e}") + # print(e) continue self.feature_importances_ = ( @@ -1146,8 +1059,8 @@ def __init__( temporal_end=366, temporal_step=20, temporal_bin_interval=50, - temporal_bin_start_jitter="random", - spatio_bin_jitter_magnitude=100, + temporal_bin_start_jitter="adaptive", + spatio_bin_jitter_magnitude="adaptive", save_gridding_plot=False, save_tmp=False, save_dir="./", @@ -1305,8 +1218,8 @@ def __init__( temporal_end=366, temporal_step=20, temporal_bin_interval=50, - temporal_bin_start_jitter="random", - spatio_bin_jitter_magnitude=10, + temporal_bin_start_jitter="adaptive", + spatio_bin_jitter_magnitude="adaptive", save_gridding_plot=False, save_tmp=False, save_dir="./", diff --git a/stemflow/model/STEM.py b/stemflow/model/STEM.py index 820bf89..eb6d8a6 100644 --- a/stemflow/model/STEM.py +++ b/stemflow/model/STEM.py @@ -32,8 +32,8 @@ def __init__( temporal_end: Union[float, int] = 366, temporal_step: Union[float, int] = 20, temporal_bin_interval: Union[float, int] = 50, - temporal_bin_start_jitter: Union[float, int, str] = "random", - spatio_bin_jitter_magnitude: Union[float, int] = 100, + temporal_bin_start_jitter: Union[float, int, str] = "adaptive", + spatio_bin_jitter_magnitude: Union[float, int] = "adaptive", save_gridding_plot: bool = True, save_tmp: bool = False, save_dir: str = "./", @@ -81,10 +81,10 @@ def __init__( size of the sliding window. Defaults to 50. temporal_bin_start_jitter: jitter of the start of the sliding window. - If 'random', a random jitter of range (-bin_interval, 0) will be generated - for the start. Defaults to 'random'. + If 'adaptive', a random jitter of range (-bin_interval, 0) will be generated + for the start. Defaults to 'adaptive'. spatio_bin_jitter_magnitude: - jitter of the spatial gridding. Defaults to 10. + jitter of the spatial gridding. Defaults to 'adaptive. save_gridding_plot: Whether ot save gridding plots. Defaults to True. save_tmp: @@ -215,8 +215,8 @@ def __init__( temporal_end=366, temporal_step=20, temporal_bin_interval=50, - temporal_bin_start_jitter="random", - spatio_bin_jitter_magnitude=100, + temporal_bin_start_jitter="adaptive", + spatio_bin_jitter_magnitude="adaptive", save_gridding_plot=False, save_tmp=False, save_dir="./", @@ -304,8 +304,8 @@ def __init__( temporal_end=366, temporal_step=20, temporal_bin_interval=50, - temporal_bin_start_jitter="random", - spatio_bin_jitter_magnitude=10, + temporal_bin_start_jitter="adaptive", + spatio_bin_jitter_magnitude="adaptive", save_gridding_plot=False, save_tmp=False, save_dir="./", diff --git a/stemflow/model/__init__.py b/stemflow/model/__init__.py index 36c8c76..6c0763f 100644 --- a/stemflow/model/__init__.py +++ b/stemflow/model/__init__.py @@ -1,6 +1 @@ -from .static_func_AdaSTEM import ( - _monkey_patched_predict_proba, - assign_points_to_one_ensemble, - train_one_stixel, - transform_pred_set_to_STEM_quad, -) +from .static_func_AdaSTEM import assign_points_to_one_ensemble, train_one_stixel, transform_pred_set_to_STEM_quad diff --git a/stemflow/model/static_func_AdaSTEM.py b/stemflow/model/static_func_AdaSTEM.py index 52db1ab..5bf7661 100644 --- a/stemflow/model/static_func_AdaSTEM.py +++ b/stemflow/model/static_func_AdaSTEM.py @@ -17,28 +17,12 @@ from sklearn.base import BaseEstimator from sklearn.utils import class_weight +from ..utils.jitterrotation.jitterrotator import JitterRotator from .dummy_model import dummy_model1 # warnings.filterwarnings("ignore") -def _monkey_patched_predict_proba( - model: BaseEstimator, X_train: Union[pd.core.frame.DataFrame, np.ndarray] -) -> np.ndarray: - """the monkey patching predict_proba method - - Args: - model: the input model - X_train: input training data - - Returns: - predicted proba - """ - pred = model.predict(X_train) - pred = np.array(pred).reshape(-1, 1) - return np.concatenate([np.zeros(shape=pred.shape), pred], axis=1) - - def train_one_stixel( stixel_training_size_threshold: int, x_names: Union[list, np.ndarray], @@ -46,8 +30,7 @@ def train_one_stixel( base_model: BaseEstimator, sample_weights_for_classifier: bool, subset_x_names: bool, - X_train_copy: pd.core.frame.DataFrame, - checklist_indexes: list, + stixel_X_train: pd.core.frame.DataFrame, ) -> Tuple[Union[None, BaseEstimator], list]: """Train one stixel @@ -58,30 +41,28 @@ def train_one_stixel( base_model (BaseEstimator): Base model estimator. sample_weights_for_classifier (bool): Whether to balance the sample weights in classifier for imbalanced samples. subset_x_names (bool): Whether to only store variables with std > 0 for each stixel. - X_train_copy (pd.core.frame.DataFrame): Input training dataframe. - checklist_indexes (list): Each element is a list that contain all checklist indexes for this stixel. + sub_X_train (pd.core.frame.DataFrame): Input training dataframe for THE stixel. Returns: tuple[Union[None, BaseEstimator], list]: trained_model, stixel_specific_x_names """ - sub_X_train = X_train_copy.iloc[checklist_indexes, :] - if len(sub_X_train) < stixel_training_size_threshold: # threshold - return (None, []) + if len(stixel_X_train) < stixel_training_size_threshold: # threshold + return (None, [], "Not_Enough_Data") - sub_y_train = sub_X_train.iloc[:, -1] - sub_X_train = sub_X_train[x_names] + sub_y_train = stixel_X_train["true_y"] + sub_X_train = stixel_X_train[x_names] unique_sub_y_train_binary = np.unique(np.where(sub_y_train > 0, 1, 0)) # nan check - nan_count = np.sum(np.isnan(sub_y_train)) + np.sum(np.isnan(sub_y_train)) + nan_count = np.sum(np.isnan(np.array(sub_X_train))) + np.sum(np.isnan(sub_y_train)) if nan_count > 0: - return (None, []) + return (None, [], "Contain_Nan") # fit if (not task == "regression") and (len(unique_sub_y_train_binary) == 1): trained_model = dummy_model1(float(unique_sub_y_train_binary[0])) - return (trained_model, []) + return (trained_model, [], "Success") else: # Remove the variables that have no variation stixel_specific_x_names = x_names.copy() @@ -93,7 +74,7 @@ def train_one_stixel( # continue, if no variable left if len(stixel_specific_x_names) == 0: - return (None, []) + return (None, [], "x_names_length_zero") # now we are sure to fit a model trained_model = copy.deepcopy(base_model) @@ -108,16 +89,18 @@ def train_one_stixel( except Exception as e: print(e) - return (None, []) + # raise + return (None, [], "Base_model_fitting_error(non-regression, balanced weight)") else: try: trained_model.fit(sub_X_train[stixel_specific_x_names], sub_y_train) except Exception as e: print(e) - return (None, []) + # raise + return (None, [], "Base_model_fitting_error(regression)") - return (trained_model, stixel_specific_x_names) + return (trained_model, stixel_specific_x_names, "Success") def assign_points_to_one_ensemble( @@ -160,19 +143,21 @@ def assign_points_to_one_ensemble( this_ensemble["stixel_calibration_point_transformed_lower_bound"] + this_ensemble["stixel_height"] ) - Sample_ST_df = transform_pred_set_to_STEM_quad(Spatio1, Spatio2, Sample_ST_df.reset_index(drop=True), this_ensemble) + Sample_ST_df_ = transform_pred_set_to_STEM_quad( + Spatio1, Spatio2, Sample_ST_df.reset_index(drop=True), this_ensemble + ) # pred each stixel res_list = [] for index, line in this_ensemble.iterrows(): stixel_index = line["unique_stixel_id"] - sub_Sample_ST_df = Sample_ST_df[ - (Sample_ST_df[Temporal1] >= line[f"{Temporal1}_start"]) - & (Sample_ST_df[Temporal1] < line[f"{Temporal1}_end"]) - & (Sample_ST_df[f"{Spatio1}_new"] >= line["stixel_calibration_point_transformed_left_bound"]) - & (Sample_ST_df[f"{Spatio1}_new"] <= line["stixel_calibration_point_transformed_right_bound"]) - & (Sample_ST_df[f"{Spatio2}_new"] >= line["stixel_calibration_point_transformed_lower_bound"]) - & (Sample_ST_df[f"{Spatio2}_new"] <= line["stixel_calibration_point_transformed_upper_bound"]) + sub_Sample_ST_df = Sample_ST_df_[ + (Sample_ST_df_[Temporal1] >= line[f"{Temporal1}_start"]) + & (Sample_ST_df_[Temporal1] < line[f"{Temporal1}_end"]) + & (Sample_ST_df_[f"{Spatio1}_new"] >= line["stixel_calibration_point_transformed_left_bound"]) + & (Sample_ST_df_[f"{Spatio1}_new"] <= line["stixel_calibration_point_transformed_right_bound"]) + & (Sample_ST_df_[f"{Spatio2}_new"] >= line["stixel_calibration_point_transformed_lower_bound"]) + & (Sample_ST_df_[f"{Spatio2}_new"] <= line["stixel_calibration_point_transformed_upper_bound"]) ] if len(sub_Sample_ST_df) == 0: @@ -223,25 +208,18 @@ def transform_pred_set_to_STEM_quad( """ - x_array = X_train[Spatio1] - y_array = X_train[Spatio2] - coord = np.array([x_array, y_array]).T - angle = float(ensemble_info.iloc[0, :]["rotation"]) - r = angle / 360 - theta = r * np.pi * 2 - rotation_matrix = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]) - - coord = coord @ rotation_matrix - calibration_point_x_jitter = float(ensemble_info.iloc[0, :]["space_jitter(first rotate by zero then add this)"][0]) - calibration_point_y_jitter = float(ensemble_info.iloc[0, :]["space_jitter(first rotate by zero then add this)"][1]) - - long_new = (coord[:, 0] + calibration_point_x_jitter).tolist() - lat_new = (coord[:, 1] + calibration_point_y_jitter).tolist() + angle = float(ensemble_info["rotation"].iloc[0]) + calibration_point_x_jitter = float(ensemble_info["space_jitter(first rotate by zero then add this)"].iloc[0][0]) + calibration_point_y_jitter = float(ensemble_info["space_jitter(first rotate by zero then add this)"].iloc[0][1]) - X_train[f"{Spatio1}_new"] = long_new - X_train[f"{Spatio2}_new"] = lat_new + X_train_ = X_train.copy() + a, b = JitterRotator.rotate_jitter( + X_train[Spatio1], X_train[Spatio2], angle, calibration_point_x_jitter, calibration_point_y_jitter + ) + X_train_[f"{Spatio1}_new"] = a + X_train_[f"{Spatio2}_new"] = b - return X_train + return X_train_ def get_model_by_name(model_dict: dict, ensemble: str, grid_index: str) -> Union[None, BaseEstimator]: @@ -312,17 +290,7 @@ def get_model_and_stixel_specific_x_names( def predict_one_stixel( - X_test_copy: pd.core.frame.DataFrame, - Temporal1: str, - Spatio1: str, - Spatio2: str, - Temporal1_start: Union[float, int], - Temporal1_end: Union[float, int], - stixel_calibration_point_transformed_left_bound: Union[float, int], - stixel_calibration_point_transformed_right_bound: Union[float, int], - stixel_calibration_point_transformed_lower_bound: Union[float, int], - stixel_calibration_point_transformed_upper_bound: Union[float, int], - x_names: list, + X_test_stixel: pd.core.frame.DataFrame, task: str, model_x_names_tuple: Tuple[Union[None, BaseEstimator], list], ) -> pd.core.frame.DataFrame: @@ -330,16 +298,6 @@ def predict_one_stixel( Args: X_test_copy (pd.core.frame.DataFrame): Input testing variables - Temporal1 (str): Temporal variable name 1 - Spatio1 (str): Spatio variable name 1 - Spatio2 (str): Spatio variable name 2 - Temporal1_start (Union[float, int]): Starting point of Temporal variable for the sliding window. - Temporal1_end (Union[float, int]): Ending point of Temporal variable for the sliding window. - stixel_calibration_point_transformed_left_bound (Union[float, int]): The left bound of the stixel after transformation. - stixel_calibration_point_transformed_right_bound (Union[float, int]): The right bound of the stixel after transformation. - stixel_calibration_point_transformed_lower_bound (Union[float, int]): The lower bound of the stixel after transformation. - stixel_calibration_point_transformed_upper_bound (Union[float, int]): The upper bound of the stixel after transformation. - x_names (list): Total x_names. All variables. task (str): One of 'regression', 'classification' and 'hurdle' model_x_names_tuple (tuple[Union[None, BaseEstimator], list]): A tuple of (model, stixel_specific_x_names) @@ -349,24 +307,15 @@ def predict_one_stixel( if model_x_names_tuple[0] is None: return None - X_test_copy = X_test_copy[ - (X_test_copy[Temporal1] >= Temporal1_start) - & (X_test_copy[Temporal1] <= Temporal1_end) - & (X_test_copy[f"{Spatio1}_new"] >= stixel_calibration_point_transformed_left_bound) - & (X_test_copy[f"{Spatio1}_new"] <= stixel_calibration_point_transformed_right_bound) - & (X_test_copy[f"{Spatio2}_new"] >= stixel_calibration_point_transformed_lower_bound) - & (X_test_copy[f"{Spatio2}_new"] <= stixel_calibration_point_transformed_upper_bound) - ] - - if len(X_test_copy) == 0: + if len(X_test_stixel) == 0: return None # get test data if task == "regression": - pred = model_x_names_tuple[0].predict(np.array(X_test_copy[model_x_names_tuple[1]])) + pred = model_x_names_tuple[0].predict(np.array(X_test_stixel[model_x_names_tuple[1]])) else: - pred = model_x_names_tuple[0].predict_proba(np.array(X_test_copy[model_x_names_tuple[1]]))[:, 1] + pred = model_x_names_tuple[0].predict_proba(np.array(X_test_stixel[model_x_names_tuple[1]]))[:, 1] - res = pd.DataFrame({"index": list(X_test_copy.index), "pred": np.array(pred).flatten()}).set_index("index") + res = pd.DataFrame({"index": list(X_test_stixel.index), "pred": np.array(pred).flatten()}).set_index("index") return res diff --git a/stemflow/utils/jitterrotation/jitterrotator.py b/stemflow/utils/jitterrotation/jitterrotator.py new file mode 100644 index 0000000..eb1bae9 --- /dev/null +++ b/stemflow/utils/jitterrotation/jitterrotator.py @@ -0,0 +1,123 @@ +from typing import Tuple, Union + +import numpy as np + +# import geopandas as gpd + + +class JitterRotator: + def __init__(): + pass + + # @classmethod + # def rotate_jitter_gpd(cls, + # df: gpd.geodataframe.GeoDataFrame, + # rotation_angle: Union[int, float], + # calibration_point_x_jitter: Union[int, float], + # calibration_point_y_jitter: Union[int, float] + # ) -> gpd.geodataframe.GeoDataFrame: + # """Rotate Normal lng, lat to jittered, rotated space + + # Args: + # x_array (np.ndarray): input lng/x + # y_array (np.ndarray): input lat/y + # rotation_angle (Union[int, float]): rotation angle + # calibration_point_x_jitter (Union[int, float]): calibration_point_x_jitter + # calibration_point_y_jitter (Union[int, float]): calibration_point_y_jitter + + # Returns: + # tuple(np.ndarray, np.ndarray): newx, newy + # """ + # transformed_series = df.rotate( + # rotation_angle, origin=(0,0) + # ).affine_transform( + # [1,0,0,1,calibration_point_x_jitter,calibration_point_y_jitter] + # ) + + # df1 = gpd.GeoDataFrame(df, geometry=transformed_series) + + # return df1 + + # @classmethod + # def inverse_jitter_rotate_gpd(cls, + # df_rotated: gpd.geodataframe.GeoDataFrame, + # rotation_angle: Union[int, float], + # calibration_point_x_jitter: Union[int, float], + # calibration_point_y_jitter: Union[int, float] + # ) -> gpd.geodataframe.GeoDataFrame: + # """reverse jitter and rotation + + # Args: + # x_array_rotated (np.ndarray): input lng/x + # y_array_rotated (np.ndarray): input lng/x + # rotation_angle (Union[int, float]): rotation angle + # calibration_point_x_jitter (Union[int, float]): calibration_point_x_jitter + # calibration_point_y_jitter (Union[int, float]): calibration_point_y_jitter + # """ + + # return df_rotated.affine_transform( + # [1,0,0,1,-calibration_point_x_jitter,-calibration_point_y_jitter] + # ).rotate( + # -rotation_angle, origin=(0,0) + # ) + + @classmethod + def rotate_jitter( + cls, + x_array: np.ndarray, + y_array: np.ndarray, + rotation_angle: Union[int, float], + calibration_point_x_jitter: Union[int, float], + calibration_point_y_jitter: Union[int, float], + ): + """Rotate Normal lng, lat to jittered, rotated space + + Args: + x_array (np.ndarray): input lng/x + y_array (np.ndarray): input lat/y + rotation_angle (Union[int, float]): rotation angle + calibration_point_x_jitter (Union[int, float]): calibration_point_x_jitter + calibration_point_y_jitter (Union[int, float]): calibration_point_y_jitter + + Returns: + tuple(np.ndarray, np.ndarray): newx, newy + """ + data = np.array([x_array, y_array]).T + angle = rotation_angle + r = angle / 360 + theta = r * np.pi * 2 + rotation_matrix = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]) + data = data @ rotation_matrix + lon_new = (data[:, 0] + calibration_point_x_jitter).tolist() + lat_new = (data[:, 1] + calibration_point_y_jitter).tolist() + return lon_new, lat_new + + @classmethod + def inverse_jitter_rotate( + cls, + x_array_rotated: np.ndarray, + y_array_rotated: np.ndarray, + rotation_angle: Union[int, float], + calibration_point_x_jitter: Union[int, float], + calibration_point_y_jitter: Union[int, float], + ): + """reverse jitter and rotation + + Args: + x_array_rotated (np.ndarray): input lng/x + y_array_rotated (np.ndarray): input lng/x + rotation_angle (Union[int, float]): rotation angle + calibration_point_x_jitter (Union[int, float]): calibration_point_x_jitter + calibration_point_y_jitter (Union[int, float]): calibration_point_y_jitter + """ + theta = -(rotation_angle / 360) * np.pi * 2 + rotation_matrix = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]) + + back_jitter_data = np.array( + [ + np.array(x_array_rotated) - calibration_point_x_jitter, + np.array(y_array_rotated) - calibration_point_y_jitter, + ] + ).T + back_rotated = back_jitter_data @ rotation_matrix + return back_rotated[:, 0].flatten(), back_rotated[:, 1].flatten() diff --git a/stemflow/utils/quadtree.py b/stemflow/utils/quadtree.py index 69b00c8..7ac3823 100644 --- a/stemflow/utils/quadtree.py +++ b/stemflow/utils/quadtree.py @@ -19,6 +19,7 @@ from ..gridding.QTree import QTree from ..gridding.QuadGrid import QuadGrid +from .validation import check_transform_spatio_bin_jitter_magnitude, check_transform_temporal_bin_start_jitter # from tqdm.contrib.concurrent import process_map # from .generate_soft_colors import generate_soft_color @@ -37,7 +38,7 @@ def generate_temporal_bins( end: Union[float, int], step: Union[float, int], bin_interval: Union[float, int], - temporal_bin_start_jitter: Union[float, int, str] = "random", + temporal_bin_start_jitter: Union[float, int, str] = "adaptive", ) -> list: """Generate random temporal bins that splits the data @@ -52,7 +53,7 @@ def generate_temporal_bins( size of the sliding window temporal_bin_start_jitter: jitter of the start of the sliding window. - If 'random', a random jitter of range (-bin_interval, 0) will be generated + If 'adaptive', a random jitter of range (-bin_interval, 0) will be generated for the start. Returns: @@ -62,10 +63,7 @@ def generate_temporal_bins( bin_interval = bin_interval # 50 step = step # 20 - if type(temporal_bin_start_jitter) == str and temporal_bin_start_jitter == "random": - jit = np.random.uniform(low=0, high=bin_interval) - elif type(temporal_bin_start_jitter) in [int, float]: - jit = temporal_bin_start_jitter + jit = check_transform_temporal_bin_start_jitter(temporal_bin_start_jitter, bin_interval) start = start - jit bin_list = [] @@ -82,89 +80,6 @@ def generate_temporal_bins( return bin_list -# def generate_one_ensemble( -# ensemble_count, -# spatio_bin_jitter_magnitude, -# temporal_start, -# temporal_end, -# temporal_step, -# temporal_bin_interval, -# temporal_bin_start_jitter, -# data, -# Temporal1, -# grid_len_lon_upper_threshold, -# grid_len_lon_lower_threshold, -# grid_len_lat_upper_threshold, -# grid_len_lat_lower_threshold, -# points_lower_threshold, -# Spatio1, -# Spatio2, -# save_gridding_plot, -# ax, -# ): -# this_ensemble = [] -# rotation_angle = np.random.uniform(0, 360) -# calibration_point_x_jitter = np.random.uniform(-spatio_bin_jitter_magnitude, spatio_bin_jitter_magnitude) -# calibration_point_y_jitter = np.random.uniform(-spatio_bin_jitter_magnitude, spatio_bin_jitter_magnitude) - -# # print(f'ensemble_count: {ensemble_count}') - -# temporal_bins = generate_temporal_bins( -# start=temporal_start, -# end=temporal_end, -# step=temporal_step, -# bin_interval=temporal_bin_interval, -# temporal_bin_start_jitter=temporal_bin_start_jitter, -# ) - -# for time_block_index, bin_ in enumerate(temporal_bins): -# time_start = bin_[0] -# time_end = bin_[1] -# sub_data = data[(data[Temporal1] >= time_start) & (data[Temporal1] < time_end)] - -# if len(sub_data) == 0: -# continue - -# QT_obj = QTree( -# grid_len_lon_upper_threshold=grid_len_lon_upper_threshold, -# grid_len_lon_lower_threshold=grid_len_lon_lower_threshold, -# grid_len_lat_upper_threshold=grid_len_lat_upper_threshold, -# grid_len_lat_lower_threshold=grid_len_lat_lower_threshold, -# points_lower_threshold=points_lower_threshold, -# lon_lat_equal_grid=True, -# rotation_angle=rotation_angle, -# calibration_point_x_jitter=calibration_point_x_jitter, -# calibration_point_y_jitter=calibration_point_y_jitter, -# ) - -# # Give the data and indexes. The indexes should be used to assign points data so that base model can run on those points, -# # You need to generate the splitting parameters once giving the data. Like the calibration point and min,max. - -# QT_obj.add_lon_lat_data(sub_data.index, sub_data[Spatio1].values, sub_data[Spatio2].values) -# QT_obj.generate_gridding_params() - -# # Call subdivide to precess -# QT_obj.subdivide() -# this_slice = QT_obj.get_final_result() - -# if save_gridding_plot: -# if time_block_index == int(len(temporal_bins) / 2): -# QT_obj.graph(scatter=False, ax=ax) - -# this_slice["ensemble_index"] = ensemble_count -# this_slice[f"{Temporal1}_start"] = time_start -# this_slice[f"{Temporal1}_end"] = time_end -# this_slice[f"{Temporal1}_start"] = round(this_slice[f"{Temporal1}_start"], 1) -# this_slice[f"{Temporal1}_end"] = round(this_slice[f"{Temporal1}_end"], 1) -# this_slice["unique_stixel_id"] = [ -# str(time_block_index) + "_" + str(i) + "_" + str(k) -# for i, k in zip(this_slice["ensemble_index"].values, this_slice["stixel_indexes"].values) -# ] -# this_ensemble.append(this_slice) - -# return pd.concat(this_ensemble, axis=0) - - def get_ensemble_quadtree( data: pandas.core.frame.DataFrame, Spatio1: str = "longitude", @@ -181,8 +96,8 @@ def get_ensemble_quadtree( temporal_end: Union[float, int] = 366, temporal_step: Union[float, int] = 20, temporal_bin_interval: Union[float, int] = 50, - temporal_bin_start_jitter: Union[float, int, str] = "random", - spatio_bin_jitter_magnitude: Union[float, int] = 10, + temporal_bin_start_jitter: Union[float, int, str] = "adaptive", + spatio_bin_jitter_magnitude: Union[float, int] = "adaptive", save_gridding_plot: bool = True, njobs: int = 1, verbosity: int = 1, @@ -228,7 +143,7 @@ def get_ensemble_quadtree( size of the sliding window temporal_bin_start_jitter: jitter of the start of the sliding window. - If 'random', a random jitter of range (-bin_interval, 0) will be generated + If 'adaptive', a adaptive jitter of range (-bin_interval, 0) will be generated for the start. spatio_bin_jitter_magnitude: jitter of the spatial gridding. @@ -251,6 +166,9 @@ def get_ensemble_quadtree( 2. grid plot. np.nan if save_gridding_plot=False
""" + spatio_bin_jitter_magnitude = check_transform_spatio_bin_jitter_magnitude( + data, Spatio1, Spatio2, spatio_bin_jitter_magnitude + ) ensemble_all_df_list = [] @@ -265,27 +183,6 @@ def get_ensemble_quadtree( pass if njobs > 1 and isinstance(njobs, int): - # partial_generate_one_ensemble = partial( - # generate_one_ensemble, - # spatio_bin_jitter_magnitude=spatio_bin_jitter_magnitude, - # temporal_start=temporal_start, - # temporal_end=temporal_end, - # temporal_step=temporal_step, - # temporal_bin_interval=temporal_bin_interval, - # temporal_bin_start_jitter=temporal_bin_start_jitter, - # data=data, - # Temporal1=Temporal1, - # grid_len_lon_upper_threshold=grid_len_lon_upper_threshold, - # grid_len_lon_lower_threshold=grid_len_lon_lower_threshold, - # grid_len_lat_upper_threshold=grid_len_lat_upper_threshold, - # grid_len_lat_lower_threshold=grid_len_lat_lower_threshold, - # points_lower_threshold=points_lower_threshold, - # Spatio1=Spatio1, - # Spatio2=Spatio2, - # save_gridding_plot=save_gridding_plot, - # ) - - # ensemble_all_df_list = process_map(partial_generate_one_ensemble, list(range(size)), max_workers=njobs) raise NotImplementedError("Multi-threading for ensemble generation is not implemented yet.") else: @@ -298,7 +195,6 @@ def get_ensemble_quadtree( calibration_point_y_jitter = np.random.uniform(-spatio_bin_jitter_magnitude, spatio_bin_jitter_magnitude) # print(f'ensemble_count: {ensemble_count}') - temporal_bins = generate_temporal_bins( start=temporal_start, end=temporal_end, @@ -365,6 +261,20 @@ def get_ensemble_quadtree( ensemble_all_df_list.append(this_slice) ensemble_df = pd.concat(ensemble_all_df_list).reset_index(drop=True) + ensemble_df.loc[:, "stixel_calibration_point_transformed_left_bound"] = [ + i[0] for i in ensemble_df["stixel_calibration_point(transformed)"] + ] + ensemble_df.loc[:, "stixel_calibration_point_transformed_lower_bound"] = [ + i[1] for i in ensemble_df["stixel_calibration_point(transformed)"] + ] + ensemble_df.loc[:, "stixel_calibration_point_transformed_right_bound"] = ( + ensemble_df["stixel_calibration_point_transformed_left_bound"] + ensemble_df["stixel_width"] + ) + ensemble_df.loc[:, "stixel_calibration_point_transformed_upper_bound"] = ( + ensemble_df["stixel_calibration_point_transformed_lower_bound"] + ensemble_df["stixel_height"] + ) + ensemble_df = ensemble_df.reset_index(drop=True) + del ensemble_all_df_list if not save_path == "": diff --git a/stemflow/utils/validation.py b/stemflow/utils/validation.py index ef5618a..0ad84f2 100644 --- a/stemflow/utils/validation.py +++ b/stemflow/utils/validation.py @@ -1,6 +1,8 @@ +import warnings from typing import Union import numpy as np +import pandas as pd def check_random_state(seed: Union[None, int, np.random.RandomState]) -> np.random.RandomState: @@ -23,3 +25,131 @@ def check_random_state(seed: Union[None, int, np.random.RandomState]) -> np.rand if isinstance(seed, np.random.RandomState): return seed raise ValueError("%r cannot be used to seed a numpy.random.RandomState instance" % seed) + + +def check_task(task): + if task not in ["regression", "classification", "hurdle"]: + raise AttributeError(f"task type must be one of 'regression', 'classification', or 'hurdle'! Now it is {task}") + if task == "hurdle": + warnings.warn( + "You have chosen HURDLE task. The goal is to first conduct classification, and then apply regression on points with *positive values*" + ) + + +def check_base_model(base_model): + for func in ["fit", "predict"]: + if func not in dir(base_model): + raise AttributeError(f"input base model must have method '{func}'!") + + +def check_njobs(njobs): + # validate njobs setting + if not isinstance(njobs, int): + raise TypeError(f"njobs is not a integer. Got {njobs}.") + elif njobs > 1: + raise NotImplementedError("Multi-thread processing is not implemented yet.") + + # my_cpu_count = cpu_count() + # if njobs > my_cpu_count: + # raise ValueError(f"Setting of njobs ({njobs}) exceed the maximum ({my_cpu_count}).") + else: + pass + + +def check_verbosity(self, verbosity): + if verbosity is None: + verbosity = self.verbosity + elif verbosity == 0: + verbosity = 0 + else: + verbosity = 1 + return verbosity + + +def check_spatio_bin_jitter_magnitude(spatio_bin_jitter_magnitude): + if isinstance(spatio_bin_jitter_magnitude, (int, float)): + pass + elif isinstance(spatio_bin_jitter_magnitude, str): + if spatio_bin_jitter_magnitude == "adaptive": + pass + else: + raise ValueError("spatio_bin_jitter_magnitude string must be adaptive!") + else: + raise ValueError("spatio_bin_jitter_magnitude string must be one of [int, float, 'adaptive']!") + + +def check_transform_spatio_bin_jitter_magnitude(data, Spatio1, Spatio2, spatio_bin_jitter_magnitude): + check_spatio_bin_jitter_magnitude(spatio_bin_jitter_magnitude) + if isinstance(spatio_bin_jitter_magnitude, str): + if spatio_bin_jitter_magnitude == "adaptive": + jit = max(data[Spatio1].max() - data[Spatio1].min(), data[Spatio2].max() - data[Spatio2].min()) + return jit + return spatio_bin_jitter_magnitude + + +def check_temporal_bin_start_jitter(temporal_bin_start_jitter): + # validate temporal_bin_start_jitter + if not isinstance(temporal_bin_start_jitter, (str, float, int)): + raise AttributeError( + f"Input temporal_bin_start_jitter should be 'adaptive', float or int, got {type(temporal_bin_start_jitter)}" + ) + if isinstance(temporal_bin_start_jitter, str): + if not temporal_bin_start_jitter == "adaptive": + raise AttributeError( + f"The input temporal_bin_start_jitter as string should only be 'adaptive'. Other options include float or int. Got {temporal_bin_start_jitter}" + ) + + +def check_transform_temporal_bin_start_jitter(temporal_bin_start_jitter, bin_interval): + check_temporal_bin_start_jitter(temporal_bin_start_jitter) + if isinstance(temporal_bin_start_jitter, str): + if temporal_bin_start_jitter == "adaptive": + jit = np.random.uniform(low=0, high=bin_interval) + elif type(temporal_bin_start_jitter) in [int, float]: + jit = temporal_bin_start_jitter + + return jit + + +def check_X_train(X_train): + # check type + + type_X_train = type(X_train) + if not isinstance(X_train, pd.core.frame.DataFrame): + raise TypeError(f"Input X should be type 'pd.core.frame.DataFrame'. Got {str(type_X_train)}") + + +def check_y_train(y_train): + type_y_train = str(type(y_train)) + if not isinstance(y_train, (pd.core.frame.DataFrame, pd.core.frame.Series, np.ndarray)): + raise TypeError( + f"Input y_train should be type 'pd.core.frame.DataFrame' or 'pd.core.frame.Series', or 'np.ndarray'. Got {str(type_y_train)}" + ) + + +def check_X_test(X_test): + check_X_train(X_test) + + +def check_prediciton_aggregation(aggregation): + if aggregation not in ["mean", "median"]: + raise ValueError(f"aggregation must be one of 'mean' and 'median'. Got {aggregation}") + + +def check_prediction_return(return_by_separate_ensembles, return_std): + if not isinstance(return_by_separate_ensembles, bool): + type_return_by_separate_ensembles = str(type(return_by_separate_ensembles)) + raise TypeError(f"return_by_separate_ensembles must be bool. Got {type_return_by_separate_ensembles}") + else: + if return_by_separate_ensembles and return_std: + warnings("return_by_separate_ensembles == True. Automatically setting return_std=False") + return_std = False + return return_by_separate_ensembles, return_std + + +def check_X_y_shape_match(X, y): + # check shape match + y_size = np.array(y).flatten().shape[0] + X_size = X.shape[0] + if not y_size == X_size: + raise ValueError(f"The shape of X and y should match. Got X: {X_size}, y: {y_size}") diff --git a/stemflow/utils/wrapper.py b/stemflow/utils/wrapper.py new file mode 100644 index 0000000..e4b0b60 --- /dev/null +++ b/stemflow/utils/wrapper.py @@ -0,0 +1,43 @@ +import warnings +from typing import Union + +import numpy as np +import pandas as pd +from sklearn.base import BaseEstimator + + +def _monkey_patched_predict_proba( + model: BaseEstimator, X_train: Union[pd.core.frame.DataFrame, np.ndarray] +) -> np.ndarray: + """the monkey patching predict_proba method + + Args: + model: the input model + X_train: input training data + + Returns: + predicted proba + """ + pred = model.predict(X_train) + pred = np.array(pred).reshape(-1, 1) + return np.concatenate([np.zeros(shape=pred.shape), pred], axis=1) + + +def model_wrapper(model: BaseEstimator) -> BaseEstimator: + """wrap a predict_proba function for those models who don't have + + Args: + model: + Input model + + Returns: + Wrapped model that has a `predict_proba` method + + """ + if "predict_proba" in dir(model): + return model + else: + warnings.warn("predict_proba function not in base_model. Monkey patching one.") + + model.predict_proba = _monkey_patched_predict_proba + return model