Skip to content

Commit

Permalink
Merge pull request #130 from ecmwf/feature/different_tree_compression…
Browse files Browse the repository at this point in the history
…_merged_develop

Feature/different tree compression merged develop
  • Loading branch information
mathleur authored May 7, 2024
2 parents 9606d51 + 6a28c43 commit ace3975
Show file tree
Hide file tree
Showing 43 changed files with 490 additions and 201 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,5 @@ site
example_eo
example_mri
.mypy_cache
*.req
*.req
polytope_python.egg-info
3 changes: 0 additions & 3 deletions performance/performance_many_num_steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,6 @@
# Create a dataarray with 3 labelled axes using different index types
options = {
"values": {"mapper": {"type": "octahedral", "resolution": 1280, "axes": ["latitude", "longitude"]}},
# "date": {"merge": {"with": "time", "linkers": ["T", "00"]}},
# "step": {"type_change": "int"},
# "number": {"type_change": "int"},
"longitude": {"cyclic": [0, 360]},
"latitude": {"reverse": {True}},
}
Expand Down
65 changes: 58 additions & 7 deletions polytope/datacube/backends/datacube.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import argparse
import logging
from abc import ABC, abstractmethod
from typing import Any
from typing import Any, List, Literal, Optional, Union

import xarray as xr
from conflator import ConfigModel, Conflator
from pydantic import ConfigDict

from ...utility.combinatorics import validate_axes
from ..datacube_axis import DatacubeAxis
Expand Down Expand Up @@ -35,6 +38,7 @@ def __init__(self, axis_options=None, datacube_options=None):
self.compressed_grid_axes = []
self.compressed_axes = []
self.merged_axes = []
self.unwanted_path = {}

@abstractmethod
def get(self, requests: IndexTree) -> Any:
Expand All @@ -50,15 +54,16 @@ def validate(self, axes):

def _create_axes(self, name, values, transformation_type_key, transformation_options):
# first check what the final axes are for this axis name given transformations
transformation_options = transformation_type_key
final_axis_names = DatacubeAxisTransformation.get_final_axes(
name, transformation_type_key, transformation_options
name, transformation_type_key.name, transformation_options
)
transformation = DatacubeAxisTransformation.create_transform(
name, transformation_type_key, transformation_options
name, transformation_type_key.name, transformation_options
)

# do not compress merged axes
if transformation_type_key == "merge":
if transformation_type_key.name == "merge":
self.merged_axes.append(name)
self.merged_axes.append(final_axis_names)

Expand All @@ -78,16 +83,16 @@ def _create_axes(self, name, values, transformation_type_key, transformation_opt
if self._axes is None or axis_name not in self._axes.keys():
DatacubeAxis.create_standard(axis_name, values, self)
# add transformation tag to axis, as well as transformation options for later
setattr(self._axes[axis_name], has_transform[transformation_type_key], True) # where has_transform is a
# factory inside datacube_transformations to set the has_transform, is_cyclic etc axis properties
setattr(self._axes[axis_name], has_transform[transformation_type_key.name], True) # where has_transform is
# a factory inside datacube_transformations to set the has_transform, is_cyclic etc axis properties
# add the specific transformation handled here to the relevant axes
# Modify the axis to update with the tag

if transformation not in self._axes[axis_name].transformations: # Avoids duplicates being stored
self._axes[axis_name].transformations.append(transformation)

def _add_all_transformation_axes(self, options, name, values):
for transformation_type_key in options.keys():
for transformation_type_key in options.transformations:
if transformation_type_key != "cyclic":
self.transformed_axes.append(name)
self._create_axes(name, values, transformation_type_key, options)
Expand Down Expand Up @@ -143,6 +148,52 @@ def remap_path(self, path: DatacubePath):
path[key] = self._axes[key].remap([value, value])[0][0]
return path

@staticmethod
def create_axes_config(axis_options):
    """Build the validated axis-transformation configuration.

    The configuration schema is declared locally as a set of ``ConfigModel``
    classes: one model per supported transformation (``cyclic``, ``mapper``,
    ``reverse``, ``type_change``, ``merge``), an ``AxisConfig`` pairing an
    axis name with its list of transformations, and a top-level ``Config``
    holding the list of ``AxisConfig`` entries.

    Args:
        axis_options: Mapping that may contain a ``"config"`` entry with the
            per-axis configuration. ``None`` is accepted and treated as an
            empty mapping.

    Returns:
        A ``Config`` instance. When ``axis_options["config"]`` is present and
        truthy it is built from that value; otherwise it is whatever
        ``Conflator(...).load()`` produced.
    """
    # Some backends default ``axis_options`` to None and forward it unchanged;
    # normalise here so the ``.get`` lookup below cannot fail on None.
    if axis_options is None:
        axis_options = {}

    class TransformationConfig(ConfigModel):
        # extra="forbid": unknown keys in a transformation entry are rejected
        # instead of being silently ignored.
        model_config = ConfigDict(extra="forbid")
        name: str = ""

    class CyclicConfig(TransformationConfig):
        name: Literal["cyclic"]
        range: List[float] = [0]

    class MapperConfig(TransformationConfig):
        name: Literal["mapper"]
        type: str = ""
        resolution: Union[int, List[int]] = 0
        axes: List[str] = [""]
        local: Optional[List[float]] = None

    class ReverseConfig(TransformationConfig):
        name: Literal["reverse"]
        is_reverse: bool = False

    class TypeChangeConfig(TransformationConfig):
        name: Literal["type_change"]
        type: str = "int"

    class MergeConfig(TransformationConfig):
        name: Literal["merge"]
        other_axis: str = ""
        linkers: List[str] = [""]

    # Each member pins ``name`` to a distinct literal, so an entry should only
    # validate against the model whose name it carries.
    action_subclasses_union = Union[CyclicConfig, MapperConfig, ReverseConfig, TypeChangeConfig, MergeConfig]

    class AxisConfig(ConfigModel):
        axis_name: str = ""
        transformations: list[action_subclasses_union]

    class Config(ConfigModel):
        config: list[AxisConfig] = []

    # The conflator load always runs (CLI parsing disabled); its result is the
    # fallback used when the caller supplies no explicit "config" entry.
    parser = argparse.ArgumentParser(allow_abbrev=False)
    axis_config = Conflator(app_name="polytope", model=Config, cli=False, argparser=parser).load()

    user_config = axis_options.get("config")
    if user_config:
        # An explicit configuration from the caller replaces the loaded one.
        axis_config = Config(config=user_config)

    return axis_config

@staticmethod
def create(datacube, axis_options: dict, datacube_options={}):
if isinstance(datacube, (xr.core.dataarray.DataArray, xr.core.dataset.Dataset)):
Expand Down
19 changes: 13 additions & 6 deletions polytope/datacube/backends/fdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from copy import deepcopy
from itertools import product

import numpy as np
import pygribjump as pygj

from ...utility.geometry import nearest_pt
Expand All @@ -19,6 +18,7 @@ def __init__(self, config=None, axis_options=None, datacube_options=None):
logging.info("Created an FDB datacube with options: " + str(axis_options))

self.unwanted_path = {}
self.axis_options = Datacube.create_axes_config(axis_options).config

partial_request = config
# Find values in the level 3 FDB datacube
Expand All @@ -31,15 +31,23 @@ def __init__(self, config=None, axis_options=None, datacube_options=None):
self.fdb_coordinates["values"] = []
for name, values in self.fdb_coordinates.items():
values.sort()
options = axis_options.get(name, None)
options = None
for opt in self.axis_options:
if opt.axis_name == name:
options = opt
# options = axis_options.get(name, None)
self._check_and_add_axes(options, name, values)
self.treated_axes.append(name)
self.complete_axes.append(name)

# add other options to axis which were just created above like "lat" for the mapper transformations for eg
for name in self._axes:
if name not in self.treated_axes:
options = axis_options.get(name, None)
options = None
for opt in self.axis_options:
if opt.axis_name == name:
options = opt
# options = axis_options.get(name, None)
val = self._axes[name].type
self._check_and_add_axes(options, name, val)

Expand All @@ -65,8 +73,6 @@ def get(self, requests: IndexTree):
interm_branch_tuple_values = []
for key in compressed_request[0].keys():
# remove the tuple of the request when we ask the fdb

# TODO: here, would need to take care of axes that are merged and unmerged, which need to be carefully decompressed
interm_branch_tuple_values.append(compressed_request[0][key])
request_combis = product(*interm_branch_tuple_values)

Expand All @@ -92,7 +98,8 @@ def get(self, requests: IndexTree):
# request[0][key] = request[0][key][0]
# branch_tuple_combi = product(*interm_branch_tuple_values)
# # TODO: now build the relevant requests from this and ask gj for them
# # TODO: then group the output values together to fit back with the original compressed request and continue
# # TODO: then group the output values together to fit back with the original compressed request
# # and continue
# new_requests = []
# for combi in branch_tuple_combi:
# new_request = {}
Expand Down
87 changes: 55 additions & 32 deletions polytope/datacube/backends/xarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,68 +3,91 @@
import numpy as np
import xarray as xr

from .datacube import Datacube, IndexTree
from .datacube import Datacube


class XArrayDatacube(Datacube):
"""Xarray arrays are labelled, axes can be defined as strings or integers (e.g. "time" or 0)."""

def __init__(self, dataarray: xr.DataArray, axis_options=None, datacube_options=None):
super().__init__(axis_options, datacube_options)
if axis_options is None:
axis_options = {}
if datacube_options is None:
datacube_options = {}
self.axis_options = Datacube.create_axes_config(axis_options).config
self.axis_counter = 0
self._axes = None
self.dataarray = dataarray

for name, values in dataarray.coords.variables.items():
options = None
for opt in self.axis_options:
if opt.axis_name == name:
options = opt
if name in dataarray.dims:
options = self.axis_options.get(name, None)
self._check_and_add_axes(options, name, values)
self.treated_axes.append(name)
self.complete_axes.append(name)
else:
if self.dataarray[name].dims == ():
options = self.axis_options.get(name, None)
self._check_and_add_axes(options, name, values)
self.treated_axes.append(name)
for name in dataarray.dims:
if name not in self.treated_axes:
options = self.axis_options.get(name, None)
options = None
for opt in self.axis_options:
if opt.axis_name == name:
options = opt
val = dataarray[name].values[0]
self._check_and_add_axes(options, name, val)
self.treated_axes.append(name)
# add other options to axis which were just created above like "lat" for the mapper transformations for eg
for name in self._axes:
if name not in self.treated_axes:
options = self.axis_options.get(name, None)
options = None
for opt in self.axis_options:
if opt.axis_name == name:
options = opt
val = self._axes[name].type
self._check_and_add_axes(options, name, val)

def get(self, requests: IndexTree):
for r in requests.leaves:
path = r.flatten()
if len(path.items()) == self.axis_counter:
# TODO: need to undo the tuples in the path into actual paths with a single value that xarray can read
unmapped_path = {}
path_copy = deepcopy(path)
for key in path_copy:
axis = self._axes[key]
key_value_path = {key: path_copy[key]}
(key_value_path, path, unmapped_path) = axis.unmap_path_key(key_value_path, path, unmapped_path)
path.update(key_value_path)

unmapped_path = {}
self.refit_path(path, unmapped_path, path)
for key in path:
path[key] = list(path[key])
for key in unmapped_path:
if isinstance(unmapped_path[key], tuple):
unmapped_path[key] = list(unmapped_path[key])

subxarray = self.dataarray.sel(path, method="nearest")
subxarray = subxarray.sel(unmapped_path)
value = subxarray.values
key = subxarray.name
r.result = (key, value)
def get(self, requests, leaf_path=None, axis_counter=0):
if leaf_path is None:
leaf_path = {}
if requests.axis.name == "root":
for c in requests.children:
self.get(c, leaf_path, axis_counter + 1)
else:
key_value_path = {requests.axis.name: requests.values}
ax = requests.axis
(key_value_path, leaf_path, self.unwanted_path) = ax.unmap_path_key(
key_value_path, leaf_path, self.unwanted_path
)
leaf_path.update(key_value_path)
if len(requests.children) != 0:
# We are not a leaf and we loop over
for c in requests.children:
self.get(c, leaf_path, axis_counter + 1)
else:
r.remove_branch()
if self.axis_counter != axis_counter:
requests.remove_branch()
else:
# We are at a leaf and need to assign value to it
leaf_path_copy = deepcopy(leaf_path)
unmapped_path = {}
self.refit_path(leaf_path_copy, unmapped_path, leaf_path)
for key in leaf_path_copy:
leaf_path_copy[key] = list(leaf_path_copy[key])
for key in unmapped_path:
if isinstance(unmapped_path[key], tuple):
unmapped_path[key] = list(unmapped_path[key])
subxarray = self.dataarray.sel(leaf_path_copy, method="nearest")
subxarray = subxarray.sel(unmapped_path)
# value = subxarray.item()
value = subxarray.values
key = subxarray.name
requests.result = (key, value)

def datacube_natural_indexes(self, axis, subarray):
if axis.name in self.complete_axes:
Expand Down
2 changes: 2 additions & 0 deletions polytope/datacube/datacube_axis.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,8 @@ def serialize(self, value):
np.datetime64: PandasTimestampDatacubeAxis(),
np.timedelta64: PandasTimedeltaDatacubeAxis(),
np.float64: FloatDatacubeAxis(),
np.float32: FloatDatacubeAxis(),
np.int32: IntDatacubeAxis(),
np.str_: UnsliceableDatacubeAxis(),
str: UnsliceableDatacubeAxis(),
np.object_: UnsliceableDatacubeAxis(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class DatacubeAxisCyclic(DatacubeAxisTransformation):
def __init__(self, name, cyclic_options):
self.name = name
self.transformation_options = cyclic_options
self.range = cyclic_options
self.range = cyclic_options.range

def generate_final_transformation(self):
return self
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@ class DatacubeMapper(DatacubeAxisTransformation):

def __init__(self, name, mapper_options):
self.transformation_options = mapper_options
self.grid_type = mapper_options["type"]
self.grid_resolution = mapper_options["resolution"]
self.grid_axes = mapper_options["axes"]
self.grid_type = mapper_options.type
self.grid_resolution = mapper_options.resolution
self.grid_axes = mapper_options.axes
self.local_area = []
if "local" in mapper_options.keys():
self.local_area = mapper_options["local"]
if mapper_options.local is not None:
self.local_area = mapper_options.local
self.old_axis = name
self._final_transformation = self.generate_final_transformation()
self._final_mapped_axes = self._final_transformation._mapped_axes
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ def __init__(self, name, merge_options):
self.transformation_options = merge_options
self.name = name
self._first_axis = name
self._second_axis = merge_options["with"]
self._linkers = merge_options["linkers"]
self._second_axis = merge_options.other_axis
self._linkers = merge_options.linkers

def blocked_axes(self):
return [self._second_axis]
Expand All @@ -33,7 +33,6 @@ def merged_values(self, datacube):
first_val = first_ax_vals[i]
for j in range(len(second_ax_vals)):
second_val = second_ax_vals[j]
# TODO: check that the first and second val are strings
val_to_add = pd.to_datetime("".join([first_val, linkers[0], second_val, linkers[1]]))
val_to_add = val_to_add.to_numpy()
val_to_add = val_to_add.astype("datetime64[s]")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ def create_transform(name, transformation_type_key, transformation_options):
file_name = ".datacube_" + transformation_file_name
module = import_module("polytope.datacube.transformations" + file_name + file_name)
constructor = getattr(module, transformation_type)
transformation_type_option = transformation_options[transformation_type_key]
# transformation_type_option = transformation_options[transformation_type_key]
transformation_type_option = transformation_options
new_transformation = deepcopy(constructor(name, transformation_type_option))

new_transformation.name = name
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class DatacubeAxisTypeChange(DatacubeAxisTransformation):
def __init__(self, name, type_options):
self.name = name
self.transformation_options = type_options
self.new_type = type_options
self.new_type = type_options.type
self._final_transformation = self.generate_final_transformation()

def generate_final_transformation(self):
Expand Down
Loading

0 comments on commit ace3975

Please sign in to comment.