Skip to content

Commit

Permalink
Merge pull request #101 from ecmwf/feature/local_latlon_grid
Browse files Browse the repository at this point in the history
Feature/local latlon grid
  • Loading branch information
mathleur authored Feb 1, 2024
2 parents b4bb7ea + c666d0f commit f4573ac
Show file tree
Hide file tree
Showing 26 changed files with 2,703 additions and 2,298 deletions.
534 changes: 5 additions & 529 deletions polytope/datacube/datacube_axis.py

Large diffs are not rendered by default.

7 changes: 7 additions & 0 deletions polytope/datacube/index_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,13 @@ def leaves_with_ancestors(self):
self._collect_leaf_nodes(leaves)
return leaves

def pprint_2(self, level=0):
if self.axis.name == "root":
print("\n")
print("\t" * level + "\u21b3" + str(self))
for child in self.children:
child.pprint_2(level + 1)

def _collect_leaf_nodes_old(self, leaves):
if len(self.children) == 0:
leaves.append(self)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
import bisect
import math
from copy import deepcopy
from typing import List

from .datacube_cyclic import DatacubeAxisCyclic


def cyclic(cls):
if cls.is_cyclic:

def update_range():
for transform in cls.transformations:
if isinstance(transform, DatacubeAxisCyclic):
transformation = transform
cls.range = transformation.range

def to_intervals(range):
update_range()
if range[0] == -math.inf:
range[0] = cls.range[0]
if range[1] == math.inf:
range[1] = cls.range[1]
axis_lower = cls.range[0]
axis_upper = cls.range[1]
axis_range = axis_upper - axis_lower
lower = range[0]
upper = range[1]
intervals = []
if lower < axis_upper:
# In this case, we want to go from lower to the first remapped cyclic axis upper
# or the asked upper range value.
# For example, if we have cyclic range [0,360] and we want to break [-270,180] into intervals,
# we first want to obtain [-270, 0] as the first range, where 0 is the remapped cyclic axis upper
# but if we wanted to break [-270, -180] into intervals, we would want to get [-270,-180],
# where -180 is the asked upper range value.
loops = int((axis_upper - lower) / axis_range)
remapped_up = axis_upper - (loops) * axis_range
new_upper = min(upper, remapped_up)
else:
# In this case, since lower >= axis_upper, we need to either go to the asked upper range
# or we need to go to the first remapped cyclic axis upper which is higher than lower
new_upper = min(axis_upper + axis_range, upper)
while new_upper < lower:
new_upper = min(new_upper + axis_range, upper)
intervals.append([lower, new_upper])
# Now that we have established what the first interval should be, we should just jump from cyclic range
# to cyclic range until we hit the asked upper range value.
new_up = deepcopy(new_upper)
while new_up < upper:
new_upper = new_up
new_up = min(upper, new_upper + axis_range)
intervals.append([new_upper, new_up])
# Once we have added all the in-between ranges, we need to add the last interval
intervals.append([new_up, upper])
return intervals

def _remap_range_to_axis_range(range):
update_range()
axis_lower = cls.range[0]
axis_upper = cls.range[1]
axis_range = axis_upper - axis_lower
lower = range[0]
upper = range[1]
if lower < axis_lower:
# In this case we need to calculate the number of loops between the axis lower
# and the lower to recenter the lower
loops = int((axis_lower - lower - cls.tol) / axis_range)
return_lower = lower + (loops + 1) * axis_range
return_upper = upper + (loops + 1) * axis_range
elif lower >= axis_upper:
# In this case we need to calculate the number of loops between the axis upper
# and the lower to recenter the lower
loops = int((lower - axis_upper) / axis_range)
return_lower = lower - (loops + 1) * axis_range
return_upper = upper - (loops + 1) * axis_range
else:
# In this case, the lower value is already in the right range
return_lower = lower
return_upper = upper
return [return_lower, return_upper]

def _remap_val_to_axis_range(value):
return_range = _remap_range_to_axis_range([value, value])
return return_range[0]

def remap(range: List):
update_range()
if cls.range[0] - cls.tol <= range[0] <= cls.range[1] + cls.tol:
if cls.range[0] - cls.tol <= range[1] <= cls.range[1] + cls.tol:
# If we are already in the cyclic range, return it
return [range]
elif abs(range[0] - range[1]) <= 2 * cls.tol:
# If we have a range that is just one point, then it should still be counted
# and so we should take a small interval around it to find values inbetween
range = [
_remap_val_to_axis_range(range[0]) - cls.tol,
_remap_val_to_axis_range(range[0]) + cls.tol,
]
return [range]
range_intervals = cls.to_intervals(range)
ranges = []
for interval in range_intervals:
if abs(interval[0] - interval[1]) > 0:
# If the interval is not just a single point, we remap it to the axis range
range = _remap_range_to_axis_range([interval[0], interval[1]])
up = range[1]
low = range[0]
if up < low:
# Make sure we remap in the right order
ranges.append([up - cls.tol, low + cls.tol])
else:
ranges.append([low - cls.tol, up + cls.tol])
return ranges

old_find_indexes = cls.find_indexes

def find_indexes(path, datacube):
return old_find_indexes(path, datacube)

old_unmap_path_key = cls.unmap_path_key

def unmap_path_key(key_value_path, leaf_path, unwanted_path):
value = key_value_path[cls.name]
for transform in cls.transformations:
if isinstance(transform, DatacubeAxisCyclic):
if cls.name == transform.name:
new_val = _remap_val_to_axis_range(value)
key_value_path[cls.name] = new_val
key_value_path, leaf_path, unwanted_path = old_unmap_path_key(key_value_path, leaf_path, unwanted_path)
return (key_value_path, leaf_path, unwanted_path)

old_unmap_to_datacube = cls.unmap_to_datacube

def unmap_to_datacube(path, unmapped_path):
(path, unmapped_path) = old_unmap_to_datacube(path, unmapped_path)
return (path, unmapped_path)

old_find_indices_between = cls.find_indices_between

def find_indices_between(index_ranges, low, up, datacube, method=None):
update_range()
indexes_between_ranges = []

if method != "surrounding" or method != "nearest":
return old_find_indices_between(index_ranges, low, up, datacube, method)
else:
for indexes in index_ranges:
if cls.name in datacube.complete_axes:
start = indexes.searchsorted(low, "left")
end = indexes.searchsorted(up, "right")
else:
start = bisect.bisect_left(indexes, low)
end = bisect.bisect_right(indexes, up)

if start - 1 < 0:
index_val_found = indexes[-1:][0]
indexes_between_ranges.append([index_val_found])
if end + 1 > len(indexes):
index_val_found = indexes[:2][0]
indexes_between_ranges.append([index_val_found])
start = max(start - 1, 0)
end = min(end + 1, len(indexes))
if cls.name in datacube.complete_axes:
indexes_between = indexes[start:end].to_list()
else:
indexes_between = indexes[start:end]
indexes_between_ranges.append(indexes_between)
return indexes_between_ranges

def offset(range):
# We first unpad the range by the axis tolerance to make sure that
# we find the wanted range of the cyclic axis since we padded by the axis tolerance before.
# Also, it's safer that we find the offset of a value inside the range instead of on the border
unpadded_range = [range[0] + 1.5 * cls.tol, range[1] - 1.5 * cls.tol]
cyclic_range = _remap_range_to_axis_range(unpadded_range)
offset = unpadded_range[0] - cyclic_range[0]
return offset

cls.to_intervals = to_intervals
cls.remap = remap
cls.offset = offset
cls.find_indexes = find_indexes
cls.unmap_to_datacube = unmap_to_datacube
cls.find_indices_between = find_indices_between
cls.unmap_path_key = unmap_path_key

return cls
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .datacube_transformations import DatacubeAxisTransformation
from ..datacube_transformations import DatacubeAxisTransformation


class DatacubeAxisCyclic(DatacubeAxisTransformation):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
from copy import deepcopy
from importlib import import_module

from ..datacube_transformations import DatacubeAxisTransformation


class DatacubeMapper(DatacubeAxisTransformation):
# Needs to implements DatacubeAxisTransformation methods

def __init__(self, name, mapper_options):
self.transformation_options = mapper_options
self.grid_type = mapper_options["type"]
self.grid_resolution = mapper_options["resolution"]
self.grid_axes = mapper_options["axes"]
self.local_area = []
if "local" in mapper_options.keys():
self.local_area = mapper_options["local"]
self.old_axis = name
self._final_transformation = self.generate_final_transformation()
self._final_mapped_axes = self._final_transformation._mapped_axes
self._axis_reversed = self._final_transformation._axis_reversed

def generate_final_transformation(self):
map_type = _type_to_datacube_mapper_lookup[self.grid_type]
module = import_module("polytope.datacube.transformations.datacube_mappers.mapper_types." + self.grid_type)
constructor = getattr(module, map_type)
transformation = deepcopy(constructor(self.old_axis, self.grid_axes, self.grid_resolution, self.local_area))
return transformation

def blocked_axes(self):
return []

def unwanted_axes(self):
return [self._final_mapped_axes[0]]

def transformation_axes_final(self):
final_axes = self._final_mapped_axes
return final_axes

# Needs to also implement its own methods

def change_val_type(self, axis_name, values):
# the new axis_vals created will be floats
return [0.0]

def _mapped_axes(self):
# NOTE: Each of the mapper method needs to call it's sub mapper method
final_axes = self._final_mapped_axes
return final_axes

def _base_axis(self):
pass

def _resolution(self):
pass

def first_axis_vals(self):
return self._final_transformation.first_axis_vals()

def second_axis_vals(self, first_val):
return self._final_transformation.second_axis_vals(first_val)

def map_first_axis(self, lower, upper):
return self._final_transformation.map_first_axis(lower, upper)

def map_second_axis(self, first_val, lower, upper):
return self._final_transformation.map_second_axis(first_val, lower, upper)

def find_second_idx(self, first_val, second_val):
return self._final_transformation.find_second_idx(first_val, second_val)

def unmap_first_val_to_start_line_idx(self, first_val):
return self._final_transformation.unmap_first_val_to_start_line_idx(first_val)

def unmap(self, first_val, second_val):
return self._final_transformation.unmap(first_val, second_val)


_type_to_datacube_mapper_lookup = {
"octahedral": "OctahedralGridMapper",
"healpix": "HealpixGridMapper",
"regular": "RegularGridMapper",
"reduced_ll": "ReducedLatLonMapper",
"local_regular": "LocalRegularGridMapper",
}
Loading

0 comments on commit f4573ac

Please sign in to comment.