Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Increase modularization & Removed multi-processing (feature in the future) #32

Merged
merged 6 commits into from
Jan 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,25 @@

repos:
- repo: https://github.com/ambv/black
rev: 23.1.0
rev: 23.12.1
hooks:
- id: black
args: ["stemflow", "--line-length=120", "--target-version=py37"]

- repo: https://github.com/pycqa/flake8
rev: 6.0.0
rev: 7.0.0
hooks:
- id: flake8
args: ["--select=C,E,F,W,B,B950", "--max-line-length=120", "--ignore=E203,E501,W503,F401,F403"]

- repo: https://github.com/timothycrosley/isort
rev: 5.12.0
rev: 5.13.2
hooks:
- id: isort
args: ["-l 120", "--profile", "black", "."]

- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
rev: v4.5.0
hooks:
- id: check-yaml
exclude: recipe/meta.yaml
Expand Down
87 changes: 31 additions & 56 deletions stemflow/gridding/QTree.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
import pandas as pd

from ..utils.generate_soft_colors import generate_soft_color
from ..utils.jitterrotation.jitterrotator import JitterRotator
from ..utils.validation import check_random_state
from .Q_blocks import Node, Point
from .Q_blocks import QNode, QPoint

os.environ["MKL_NUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"
Expand All @@ -23,7 +24,7 @@


def recursive_subdivide(
node: Node,
node: QNode,
grid_len_lon_upper_threshold: Union[float, int],
grid_len_lon_lower_threshold: Union[float, int],
grid_len_lat_upper_threshold: Union[float, int],
Expand Down Expand Up @@ -53,7 +54,7 @@ def recursive_subdivide(
h_ = float(node.height / 2)

p = contains(node.x0, node.y0, w_, h_, node.points)
x1 = Node(node.x0, node.y0, w_, h_, p)
x1 = QNode(node.x0, node.y0, w_, h_, p)
recursive_subdivide(
x1,
grid_len_lon_upper_threshold,
Expand All @@ -64,7 +65,7 @@ def recursive_subdivide(
)

p = contains(node.x0, node.y0 + h_, w_, h_, node.points)
x2 = Node(node.x0, node.y0 + h_, w_, h_, p)
x2 = QNode(node.x0, node.y0 + h_, w_, h_, p)
recursive_subdivide(
x2,
grid_len_lon_upper_threshold,
Expand All @@ -75,7 +76,7 @@ def recursive_subdivide(
)

p = contains(node.x0 + w_, node.y0, w_, h_, node.points)
x3 = Node(node.x0 + w_, node.y0, w_, h_, p)
x3 = QNode(node.x0 + w_, node.y0, w_, h_, p)
recursive_subdivide(
x3,
grid_len_lon_upper_threshold,
Expand All @@ -86,7 +87,7 @@ def recursive_subdivide(
)

p = contains(node.x0 + w_, node.y0 + h_, w_, h_, node.points)
x4 = Node(node.x0 + w_, node.y0 + h_, w_, h_, p)
x4 = QNode(node.x0 + w_, node.y0 + h_, w_, h_, p)
recursive_subdivide(
x4,
grid_len_lon_upper_threshold,
Expand Down Expand Up @@ -204,17 +205,12 @@ def add_lon_lat_data(self, indexes: Sequence, x_array: Sequence, y_array: Sequen
if not len(x_array) == len(y_array) or not len(x_array) == len(indexes):
raise ValueError("input longitude and latitude and indexes not in same length!")

data = np.array([x_array, y_array]).T
angle = self.rotation_angle
r = angle / 360
theta = r * np.pi * 2
rotation_matrix = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
data = data @ rotation_matrix
lon_new = (data[:, 0] + self.calibration_point_x_jitter).tolist()
lat_new = (data[:, 1] + self.calibration_point_y_jitter).tolist()
lon_new, lat_new = JitterRotator.rotate_jitter(
x_array, y_array, self.rotation_angle, self.calibration_point_x_jitter, self.calibration_point_y_jitter
)

for index, lon, lat in zip(indexes, lon_new, lat_new):
self.points.append(Point(index, lon, lat))
self.points.append(QPoint(index, lon, lat))

def generate_gridding_params(self):
"""generate the gridding params after data are added
Expand All @@ -233,15 +229,15 @@ def generate_gridding_params(self):

self.left_bottom_point = (left_bottom_point_x, left_bottom_point_y)
if self.lon_lat_equal_grid is True:
self.root = Node(
self.root = QNode(
left_bottom_point_x,
left_bottom_point_y,
max(self.grid_length_x, self.grid_length_y),
max(self.grid_length_x, self.grid_length_y),
self.points,
)
elif self.lon_lat_equal_grid is False:
self.root = Node(
self.root = QNode(
left_bottom_point_x, left_bottom_point_y, self.grid_length_x, self.grid_length_y, self.points
)
else:
Expand Down Expand Up @@ -272,58 +268,37 @@ def graph(self, scatter: bool = True, ax=None):

c = find_children(self.root)

# areas = set()
# width_set = set()
# height_set = set()
# for el in c:
# areas.add(el.width * el.height)
# width_set.add(el.width)
# height_set.add(el.height)

theta = -(self.rotation_angle / 360) * np.pi * 2
rotation_matrix = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])

for n in c:
xy0_trans = np.array([[n.x0, n.y0]])
if self.calibration_point_x_jitter:
new_x = xy0_trans[:, 0] - self.calibration_point_x_jitter
else:
new_x = xy0_trans[:, 0]

if self.calibration_point_y_jitter:
new_y = xy0_trans[:, 1] - self.calibration_point_y_jitter
else:
new_y = xy0_trans[:, 1]
new_xy = np.array([[new_x[0], new_y[0]]]) @ rotation_matrix
new_x = new_xy[:, 0]
new_y = new_xy[:, 1]
old_x, old_y = JitterRotator.inverse_jitter_rotate(
[n.x0], [n.y0], self.rotation_angle, self.calibration_point_x_jitter, self.calibration_point_y_jitter
)

if ax is None:
plt.gcf().gca().add_patch(
patches.Rectangle(
(new_x, new_y), n.width, n.height, fill=False, angle=self.rotation_angle, color=the_color
(old_x, old_y), n.width, n.height, fill=False, angle=self.rotation_angle, color=the_color
)
)
else:
ax.add_patch(
patches.Rectangle(
(new_x, new_y), n.width, n.height, fill=False, angle=self.rotation_angle, color=the_color
(old_x, old_y), n.width, n.height, fill=False, angle=self.rotation_angle, color=the_color
)
)

x = np.array([point.x for point in self.points]) - self.calibration_point_x_jitter
y = np.array([point.y for point in self.points]) - self.calibration_point_y_jitter

data = np.array([x, y]).T @ rotation_matrix
if scatter:
old_x, old_y = JitterRotator.inverse_jitter_rotate(
[point.x for point in self.points],
[point.y for point in self.points],
self.rotation_angle,
self.calibration_point_x_jitter,
self.calibration_point_y_jitter,
)

if ax is None:
plt.scatter(
data[:, 0].tolist(), data[:, 1].tolist(), s=0.2, c="tab:blue", alpha=0.7
) # plots the points as red dots
plt.scatter(old_x, old_y, s=0.2, c="tab:blue", alpha=0.7) # plots the points as red dots
else:
ax.scatter(
data[:, 0].tolist(), data[:, 1].tolist(), s=0.2, c="tab:blue", alpha=0.7
) # plots the points as red dots
ax.scatter(old_x, old_y, s=0.2, c="tab:blue", alpha=0.7) # plots the points as red dots
return

def get_final_result(self) -> pandas.core.frame.DataFrame:
Expand All @@ -333,21 +308,21 @@ def get_final_result(self) -> pandas.core.frame.DataFrame:
results (DataFrame): A pandas dataframe containing the gridding information
"""
all_grids = find_children(self.root)
point_indexes_list = []
# point_indexes_list = []
point_grid_width_list = []
point_grid_height_list = []
point_grid_points_number_list = []
calibration_point_list = []
for grid in all_grids:
point_indexes_list.append([point.index for point in grid.points])
# point_indexes_list.append([point.index for point in grid.points])
point_grid_width_list.append(grid.width)
point_grid_height_list.append(grid.height)
point_grid_points_number_list.append(len(grid.points))
calibration_point_list.append((round(grid.x0, 6), round(grid.y0, 6)))

result = pd.DataFrame(
{
"checklist_indexes": point_indexes_list,
# "checklist_indexes": point_indexes_list,
"stixel_indexes": list(range(len(point_grid_width_list))),
"stixel_width": point_grid_width_list,
"stixel_height": point_grid_height_list,
Expand Down
13 changes: 7 additions & 6 deletions stemflow/gridding/Q_blocks.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
"""I call this Q_blocks because they are essential blocks for QTree methods"""

from typing import Tuple, Union
from collections.abc import Sequence
from typing import List, Tuple, Union

from ..utils.sphere.coordinate_transform import lonlat_spherical_transformer
from ..utils.sphere.distance import spherical_distance_from_coordinates


class Point:
class QPoint:
"""A Point class for recording data points"""

def __init__(self, index, x, y):
Expand All @@ -15,7 +16,7 @@ def __init__(self, index, x, y):
self.index = index


class Node:
class QNode:
"""A tree-like division node class"""

def __init__(
Expand All @@ -24,7 +25,7 @@ def __init__(
y0: Union[float, int],
w: Union[float, int],
h: Union[float, int],
points: list[Point],
points: Sequence,
):
self.x0 = x0
self.y0 = y0
Expand All @@ -43,7 +44,7 @@ def get_points(self):
return self.points


class Grid:
class QGrid:
"""Grid class for STEM (fixed gird size)"""

def __init__(self, x_index, y_index, x_range, y_range):
Expand Down Expand Up @@ -76,7 +77,7 @@ def __init__(
inclination2: Union[float, int],
azimuth3: Union[float, int],
inclination3: Union[float, int],
points: list[Sphere_Point],
points: Sequence,
):
self.x0 = x0
self.y0 = y0
Expand Down
Loading