Skip to content

Commit

Permalink
fix bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
chenyangkang committed Jan 24, 2024
1 parent 0974836 commit a4b5861
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 9 deletions.
2 changes: 1 addition & 1 deletion stemflow/model_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from numpy import ndarray
from pandas.core.frame import DataFrame

from .utils import check_random_state
from .utils.validation import check_random_state


def ST_train_test_split(
Expand Down
20 changes: 12 additions & 8 deletions stemflow/utils/jitterrotation/jitterrotator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def rotate_jitter(
rotation_angle: Union[int, float],
calibration_point_x_jitter: Union[int, float],
calibration_point_y_jitter: Union[int, float],
):
) -> Tuple[np.ndarray, np.ndarray]:
"""Rotate Normal lng, lat to jittered, rotated space
Args:
Expand All @@ -28,7 +28,7 @@ def rotate_jitter(
calibration_point_y_jitter (Union[int, float]): calibration_point_y_jitter
Returns:
tuple(np.ndarray, np.ndarray): newx, newy
Tuple[np.ndarray, np.ndarray]: newx, newy
"""
data = np.array([x_array, y_array]).T
angle = rotation_angle
Expand All @@ -48,7 +48,7 @@ def inverse_jitter_rotate(
rotation_angle: Union[int, float],
calibration_point_x_jitter: Union[int, float],
calibration_point_y_jitter: Union[int, float],
):
) -> Tuple[np.ndarray, np.ndarray]:
"""reverse jitter and rotation
Args:
Expand All @@ -57,6 +57,10 @@ def inverse_jitter_rotate(
rotation_angle (Union[int, float]): rotation angle
calibration_point_x_jitter (Union[int, float]): calibration_point_x_jitter
calibration_point_y_jitter (Union[int, float]): calibration_point_y_jitter
Returns:
Tuple[np.ndarray, np.ndarray]: newx, newy
"""
theta = -(rotation_angle / 360) * np.pi * 2
rotation_matrix = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
Expand All @@ -78,15 +82,15 @@ def __init__(self) -> None:
pass

def rotate_jitter(point: np.ndarray, axis: np.ndarray, angle: Union[float, int]) -> np.ndarray:
"""_summary_
"""rotate_jitter 3d points
Args:
point (np.ndarray): shape of (X, 3)
axis (np.ndarray): shape of (3,)
angle (Union[float, int]): angle in degree
Returns:
np.ndarray: _description_
np.ndarray: rotated_point
"""
u = np.array(axis)
u = u / np.linalg.norm(u)
Expand Down Expand Up @@ -118,16 +122,16 @@ def rotate_jitter(point: np.ndarray, axis: np.ndarray, angle: Union[float, int])
rotated_point = np.dot(point, rotation_matrix)
return rotated_point

def inverse_rotate_jitter(point: np.ndarray, axis: np.ndarray, angle: Union[float, int]):
"""_summary_
def inverse_rotate_jitter(point: np.ndarray, axis: np.ndarray, angle: Union[float, int]) -> np.ndarray:
"""inverse rotate_jitter 3d points
Args:
point (np.ndarray): shape of (X, 3)
axis (np.ndarray): shape of (3,)
angle (Union[float, int]): angle in degree
Returns:
_type_: _description_
np.ndarray: inverse rotated point
"""
u = np.array(axis)
u = u / np.linalg.norm(u)
Expand Down

0 comments on commit a4b5861

Please sign in to comment.