Skip to content

Commit

Permalink
fix test bug and add docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
chenyangkang committed Jan 24, 2024
1 parent 000013a commit 0974836
Show file tree
Hide file tree
Showing 15 changed files with 185 additions and 178 deletions.
6 changes: 3 additions & 3 deletions stemflow/model/SphereAdaSTEM.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def __init__(
ensemble_fold: int = 10,
min_ensemble_required: int = 7,
grid_len_upper_threshold: Union[float, int] = 8000,
grid_len_lower_threshold: Union[float, int] = 100,
grid_len_lower_threshold: Union[float, int] = 500,
points_lower_threshold: int = 50,
stixel_training_size_threshold: int = None,
temporal_start: Union[float, int] = 1,
Expand Down Expand Up @@ -600,7 +600,7 @@ def __init__(
ensemble_fold=10,
min_ensemble_required=7,
grid_len_upper_threshold=8000,
grid_len_lower_threshold=100,
grid_len_lower_threshold=500,
points_lower_threshold=50,
stixel_training_size_threshold=None,
temporal_start=1,
Expand Down Expand Up @@ -759,7 +759,7 @@ def __init__(
ensemble_fold=10,
min_ensemble_required=7,
grid_len_upper_threshold=8000,
grid_len_lower_threshold=100,
grid_len_lower_threshold=500,
points_lower_threshold=50,
stixel_training_size_threshold=None,
temporal_start=1,
Expand Down
1 change: 0 additions & 1 deletion stemflow/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +0,0 @@
from .validation import check_random_state
62 changes: 6 additions & 56 deletions stemflow/utils/jitterrotation/jitterrotator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,65 +2,13 @@

import numpy as np

# import geopandas as gpd


class JitterRotator:
"""2D jitter rotator."""

def __init__():
pass

# @classmethod
# def rotate_jitter_gpd(cls,
# df: gpd.geodataframe.GeoDataFrame,
# rotation_angle: Union[int, float],
# calibration_point_x_jitter: Union[int, float],
# calibration_point_y_jitter: Union[int, float]
# ) -> gpd.geodataframe.GeoDataFrame:
# """Rotate Normal lng, lat to jittered, rotated space

# Args:
# x_array (np.ndarray): input lng/x
# y_array (np.ndarray): input lat/y
# rotation_angle (Union[int, float]): rotation angle
# calibration_point_x_jitter (Union[int, float]): calibration_point_x_jitter
# calibration_point_y_jitter (Union[int, float]): calibration_point_y_jitter

# Returns:
# tuple(np.ndarray, np.ndarray): newx, newy
# """
# transformed_series = df.rotate(
# rotation_angle, origin=(0,0)
# ).affine_transform(
# [1,0,0,1,calibration_point_x_jitter,calibration_point_y_jitter]
# )

# df1 = gpd.GeoDataFrame(df, geometry=transformed_series)

# return df1

# @classmethod
# def inverse_jitter_rotate_gpd(cls,
# df_rotated: gpd.geodataframe.GeoDataFrame,
# rotation_angle: Union[int, float],
# calibration_point_x_jitter: Union[int, float],
# calibration_point_y_jitter: Union[int, float]
# ) -> gpd.geodataframe.GeoDataFrame:
# """reverse jitter and rotation

# Args:
# x_array_rotated (np.ndarray): input lng/x
# y_array_rotated (np.ndarray): input lng/x
# rotation_angle (Union[int, float]): rotation angle
# calibration_point_x_jitter (Union[int, float]): calibration_point_x_jitter
# calibration_point_y_jitter (Union[int, float]): calibration_point_y_jitter
# """

# return df_rotated.affine_transform(
# [1,0,0,1,-calibration_point_x_jitter,-calibration_point_y_jitter]
# ).rotate(
# -rotation_angle, origin=(0,0)
# )

@classmethod
def rotate_jitter(
cls,
Expand Down Expand Up @@ -124,10 +72,12 @@ def inverse_jitter_rotate(


class Sphere_Jitterrotator:
"""3D jitter rotator"""

def __init__(self) -> None:
pass

def rotate_jitter(point: np.ndarray, axis: np.ndarray, angle: Union[float, int]):
def rotate_jitter(point: np.ndarray, axis: np.ndarray, angle: Union[float, int]) -> np.ndarray:
"""_summary_
Args:
Expand All @@ -136,7 +86,7 @@ def rotate_jitter(point: np.ndarray, axis: np.ndarray, angle: Union[float, int])
angle (Union[float, int]): angle in degree
Returns:
_type_: _description_
np.ndarray: _description_
"""
u = np.array(axis)
u = u / np.linalg.norm(u)
Expand Down
15 changes: 2 additions & 13 deletions stemflow/utils/quadtree.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,10 @@
# import libraries
"A function module to get quadtree results for 2D indexing system. Returns ensemble_df and plotting axes."

import os
import warnings

# from collections.abc import Sequence
# from functools import partial
# from itertools import repeat
# from multiprocessing import Pool
from typing import Tuple, Union

import matplotlib

# import matplotlib.patches as patches
import matplotlib.pyplot as plt # plotting libraries
import numpy as np
import pandas
Expand All @@ -21,11 +15,6 @@
from ..gridding.QuadGrid import QuadGrid
from .validation import check_transform_spatio_bin_jitter_magnitude, check_transform_temporal_bin_start_jitter

# from tqdm.contrib.concurrent import process_map
# from .generate_soft_colors import generate_soft_color
# from .validation import check_random_state


os.environ["MKL_NUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"
Expand Down
46 changes: 41 additions & 5 deletions stemflow/utils/sphere/Icosahedron.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
"Functions for the initial icosahedron in spherical indexing system"

import numpy as np

from .coordinate_transform import lonlat_cartesian_3D_transformer


def get_Icosahedron_vertices():
def get_Icosahedron_vertices() -> np.ndarray:
"""Return the 12 vertices of icosahedron
Returns:
np.ndarray: (n_vertices, 3D_coordinates)
"""
phi = (1 + np.sqrt(5)) / 2
vertices = np.array(
[
Expand All @@ -24,7 +31,17 @@ def get_Icosahedron_vertices():
return vertices


def calc_and_judge_distance(v1, v2, v3):
def calc_and_judge_distance(v1: np.ndarray, v2: np.ndarray, v3: np.ndarray) -> bool:
"""Determine if the three points have same distance with each other
Args:
v1 (np.ndarray): point 1
v2 (np.ndarray): point 1
v3 (np.ndarray): point 1
Returns:
bool: Whether have same pair-wise distance
"""
d1 = np.sum((np.array(v1) - np.array(v2)) ** 2) ** (1 / 2)
d2 = np.sum((np.array(v1) - np.array(v3)) ** 2) ** (1 / 2)
d3 = np.sum((np.array(v2) - np.array(v3)) ** 2) ** (1 / 2)
Expand All @@ -34,7 +51,12 @@ def calc_and_judge_distance(v1, v2, v3):
return False


def get_Icosahedron_faces():
def get_Icosahedron_faces() -> np.ndarray:
"""Get icosahedron faces
Returns:
np.ndarray: shape (20,3,3). (faces, point, 3d_dimension)
"""
vertices = get_Icosahedron_vertices()

face_list = []
Expand All @@ -51,7 +73,12 @@ def get_Icosahedron_faces():
return face_list


def get_earth_Icosahedron_vertices_and_faces_lonlat():
def get_earth_Icosahedron_vertices_and_faces_lonlat() -> [np.ndarray, np.ndarray]:
"""Get vertices and faces in lon, lat
Returns:
[np.ndarray, np.ndarray]: vertices, faces
"""
# earth_radius_km=6371.0
# get Icosahedron vertices and faces
vertices = get_Icosahedron_vertices()
Expand All @@ -68,7 +95,16 @@ def get_earth_Icosahedron_vertices_and_faces_lonlat():
return np.stack([vertices_lng, vertices_lat], axis=-1), np.stack([faces_lng, faces_lat], axis=-1)


def get_earth_Icosahedron_vertices_and_faces_3D(radius=1):
def get_earth_Icosahedron_vertices_and_faces_3D(radius=1) -> [np.ndarray, np.ndarray]:
"""Get vertices and faces in lon, lat
Args:
radius (Union[int, float]): radius of earth in km.
Returns:
[np.ndarray, np.ndarray]: vertices, faces
"""

# earth_radius_km=6371.0
# get Icosahedron vertices and faces
vertices = get_Icosahedron_vertices()
Expand Down
57 changes: 53 additions & 4 deletions stemflow/utils/sphere/coordinate_transform.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,29 @@
from collections.abc import Sequence
from typing import Tuple, Union

import numpy as np

from ...gridding.Q_blocks import QPoint_3D


class lonlat_cartesian_3D_transformer:
"""Transformer between longitude,latitude and 3d dimension (x,y,z)."""

def __init__(self) -> None:
pass

def transform(lng, lat, radius=6371):
def transform(lng: np.ndarray, lat: np.ndarray, radius: float = 6371.0) -> Tuple[np.ndarray, np.ndarray]:
"""Transform lng, lat to x,y,z
Args:
lng (np.ndarray): lng
lat (np.ndarray): lat
radius (float, optional): radius of earth in km. Defaults to 6371.
Returns:
Tuple[np.ndarray, np.ndarray]: x,y,z
"""

# Convert latitude and longitude from degrees to radians
lat_rad = np.radians(lat)
lng_rad = np.radians(lng)
Expand All @@ -21,15 +35,38 @@ def transform(lng, lat, radius=6371):

return x, y, z

def inverse_transform(x, y, z, r=None):
def inverse_transform(
x: np.ndarray, y: np.ndarray, z: np.ndarray, r: float = None
) -> Tuple[np.ndarray, np.ndarray]:
"""transform x,y,z to lon, lat
Args:
x (np.ndarray): x
y (np.ndarray): y
z (np.ndarray): z
r (float, optional): Radius of your spherical coordinate. If not given, calculate from x,y,z. Defaults to None.
Returns:
Tuple[np.ndarray, np.ndarray]: longitude, latitude
"""
if r is None:
r = np.sqrt(x**2 + y**2 + z**2)
latitude = np.degrees(np.arcsin(z / r))
longitude = np.degrees(np.arctan2(y, x))
return longitude, latitude


def get_midpoint_3D(p1, p2, radius=6371):
def get_midpoint_3D(p1: QPoint_3D, p2: QPoint_3D, radius: float = 6371.0) -> QPoint_3D:
"""Get the mid-point of three QPoint_3D objet (vector)
Args:
p1 (QPoint_3D): p1
p2 (QPoint_3D): p2
radius (float, optional): radius of earth in km. Defaults to 6371.0.
Returns:
QPoint_3D: mid-point.
"""
v1 = np.array([p1.x, p1.y, p1.z])
v2 = np.array([p2.x, p2.y, p2.z])

Expand All @@ -41,7 +78,19 @@ def get_midpoint_3D(p1, p2, radius=6371):
return p3


def continuous_interpolation_3D_plotting(p1, p2, radius=6371):
def continuous_interpolation_3D_plotting(
p1: np.ndarray, p2: np.ndarray, radius: float = 6371.0
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""interpolate 10 points on earth surface between the given two points. For plotting.
Args:
p1 (np.ndarray): p1
p2 (np.ndarray): p2
radius (float, optional): radius of earth in km. Defaults to 6371.0.
Returns:
Tuple[np.ndarray, np.ndarray, np.ndarray]: 10 x, 10 y, 10 z
"""
v1 = np.array([p1[0], p1[1], p1[2]])
v2 = np.array([p2[0], p2[1], p2[2]])

Expand Down
39 changes: 25 additions & 14 deletions stemflow/utils/sphere/discriminant_formula.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,40 @@
import numpy as np

# def sign(target, p2, p3):
# return np.sign((target[:,0] - p3[0]) * (p2[1] - p3[1]) - (p2[0] - p3[0]) * (target[:,1] - p3[1]))
from typing import Union

# def point_in_triangle(targets, p1, p2, p3):
import numpy as np

# d1 = sign(targets, p1, p2)
# d2 = sign(targets, p2, p3)
# d3 = sign(targets, p3, p1)

# signs = np.column_stack([d1<0,d2<0,d3<0])
# has_neg = signs.sum(axis=1)
# has_pos = -signs.sum(axis=1)
# return np.logical_not(np.logical_and(has_neg, has_pos))
def is_point_inside_triangle(point: np.ndarray, A: np.ndarray, B: np.ndarray, C: np.ndarray) -> np.ndarray:
"""Check if a point is inside a triangle
Args:
point (np.ndarray): point in vector. Shape (X, dimension).
A (np.ndarray): point A of triangle. Shape (dimension).
B (np.ndarray): point B of triangle. Shape (dimension).
C (np.ndarray): point C of triangle. Shape (dimension).
def is_point_inside_triangle(point, A, B, C):
Returns:
np.ndarray: inside or not
"""
u = np.cross(C - B, point - B) @ np.cross(C - B, A - B)
v = np.cross(A - C, point - C) @ np.cross(A - C, B - C)
w = np.cross(B - A, point - A) @ np.cross(B - A, C - A)

return (u >= 0) & (v >= 0) & (w >= 0)


def intersect_triangle_plane(P0, V, A, B, C):
def intersect_triangle_plane(P0: np.ndarray, V: np.ndarray, A: np.ndarray, B: np.ndarray, C: np.ndarray) -> np.ndarray:
"""Get if the ray go through the triangle of A,B,C
Args:
P0 (np.ndarray): start point of ray
V (np.ndarray): A point that the ray go through
A (np.ndarray): point A of triangle. Shape (dimension).
B (np.ndarray): point A of triangle. Shape (dimension).
C (np.ndarray): point A of triangle. Shape (dimension).
Returns:
np.ndarray: Whether the point go through triangle ABC
"""
# Calculate the normal vector of the plane
N = np.cross(B - A, C - A)

Expand Down
Loading

0 comments on commit 0974836

Please sign in to comment.