Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/different tree compression merged develop #130

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
609f6fa
add conflator config for transformation options
mathleur Feb 19, 2024
29f2b51
clean up
mathleur Feb 19, 2024
355d296
try to add nested conflator config
mathleur Feb 22, 2024
d6e58be
make conflator config work with fdb backend
mathleur Feb 27, 2024
27b3c57
make conflator config work with xarray backend
mathleur Feb 27, 2024
fafa4cd
black
mathleur Feb 27, 2024
0eca856
add conflator to requirements
mathleur Feb 27, 2024
056fdf7
upgrade pytest version in test requirements
mathleur Feb 27, 2024
317158b
fix small xarray bug
mathleur Mar 28, 2024
e7e47f3
make xarray backend get recursive
mathleur Apr 2, 2024
77225c6
clean up
mathleur Apr 2, 2024
70b6361
add performance example
mathleur Apr 3, 2024
1f13af7
try to make conflator config not a yaml
mathleur Apr 3, 2024
09248bc
add np.float32 and np.int32 as axis lookups
mathleur Apr 17, 2024
4c822e6
merged conflator_config branch with develop
mathleur Apr 17, 2024
f262052
merge develop
mathleur Apr 17, 2024
43edfcd
fix test_axis_mappers
mathleur Apr 17, 2024
d4e34f1
fix error in test
mathleur Apr 17, 2024
07e1637
fix formatting
mathleur Apr 18, 2024
1430c7c
fix flake8
mathleur Apr 18, 2024
af371ae
small fixes
mathleur Apr 18, 2024
ba94de6
fix numpy to_list problem
mathleur Apr 18, 2024
b6db873
fix almost all tests
mathleur Apr 18, 2024
5cc03ae
add conflator transformation options as dictionaries
mathleur Apr 19, 2024
6e78f82
black
mathleur Apr 19, 2024
37ccbea
fix conflator branch in requirements
mathleur Apr 19, 2024
0148249
fix conflator branch in requirements
mathleur Apr 19, 2024
3032fe7
Merge pull request #122 from ecmwf/feature/small-fix
mathleur Apr 30, 2024
13cb534
Merge pull request #127 from ecmwf/feature/improved_xarray_backend
mathleur Apr 30, 2024
c853a02
merge develop
mathleur Apr 30, 2024
c97f278
Merge pull request #124 from ecmwf/feature/merged_conflator_config_de…
mathleur Apr 30, 2024
b0dc5ff
merge develop
mathleur May 7, 2024
0d617a8
fix problems
mathleur May 7, 2024
3a46769
fix compression in test
mathleur May 7, 2024
1d1c43f
black
mathleur May 7, 2024
6a28c43
flake8
mathleur May 7, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,5 @@ site
example_eo
example_mri
.mypy_cache
*.req
*.req
polytope_python.egg-info
3 changes: 0 additions & 3 deletions performance/performance_many_num_steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,6 @@
# Create a dataarray with 3 labelled axes using different index types
options = {
"values": {"mapper": {"type": "octahedral", "resolution": 1280, "axes": ["latitude", "longitude"]}},
# "date": {"merge": {"with": "time", "linkers": ["T", "00"]}},
# "step": {"type_change": "int"},
# "number": {"type_change": "int"},
"longitude": {"cyclic": [0, 360]},
"latitude": {"reverse": {True}},
}
Expand Down
65 changes: 58 additions & 7 deletions polytope/datacube/backends/datacube.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import argparse
import logging
from abc import ABC, abstractmethod
from typing import Any
from typing import Any, List, Literal, Optional, Union

import xarray as xr
from conflator import ConfigModel, Conflator
from pydantic import ConfigDict

from ...utility.combinatorics import validate_axes
from ..datacube_axis import DatacubeAxis
Expand Down Expand Up @@ -35,6 +38,7 @@ def __init__(self, axis_options=None, datacube_options=None):
self.compressed_grid_axes = []
self.compressed_axes = []
self.merged_axes = []
self.unwanted_path = {}

@abstractmethod
def get(self, requests: IndexTree) -> Any:
Expand All @@ -50,15 +54,16 @@ def validate(self, axes):

def _create_axes(self, name, values, transformation_type_key, transformation_options):
# first check what the final axes are for this axis name given transformations
transformation_options = transformation_type_key
final_axis_names = DatacubeAxisTransformation.get_final_axes(
name, transformation_type_key, transformation_options
name, transformation_type_key.name, transformation_options
)
transformation = DatacubeAxisTransformation.create_transform(
name, transformation_type_key, transformation_options
name, transformation_type_key.name, transformation_options
)

# do not compress merged axes
if transformation_type_key == "merge":
if transformation_type_key.name == "merge":
self.merged_axes.append(name)
self.merged_axes.append(final_axis_names)

Expand All @@ -78,16 +83,16 @@ def _create_axes(self, name, values, transformation_type_key, transformation_opt
if self._axes is None or axis_name not in self._axes.keys():
DatacubeAxis.create_standard(axis_name, values, self)
# add transformation tag to axis, as well as transformation options for later
setattr(self._axes[axis_name], has_transform[transformation_type_key], True) # where has_transform is a
# factory inside datacube_transformations to set the has_transform, is_cyclic etc axis properties
setattr(self._axes[axis_name], has_transform[transformation_type_key.name], True) # where has_transform is
# a factory inside datacube_transformations to set the has_transform, is_cyclic etc axis properties
# add the specific transformation handled here to the relevant axes
# Modify the axis to update with the tag

if transformation not in self._axes[axis_name].transformations: # Avoids duplicates being stored
self._axes[axis_name].transformations.append(transformation)

def _add_all_transformation_axes(self, options, name, values):
for transformation_type_key in options.keys():
for transformation_type_key in options.transformations:
if transformation_type_key != "cyclic":
self.transformed_axes.append(name)
self._create_axes(name, values, transformation_type_key, options)
Expand Down Expand Up @@ -143,6 +148,52 @@ def remap_path(self, path: DatacubePath):
path[key] = self._axes[key].remap([value, value])[0][0]
return path

@staticmethod
def create_axes_config(axis_options):
class TransformationConfig(ConfigModel):
model_config = ConfigDict(extra="forbid")
name: str = ""

class CyclicConfig(TransformationConfig):
name: Literal["cyclic"]
range: List[float] = [0]

class MapperConfig(TransformationConfig):
name: Literal["mapper"]
type: str = ""
resolution: Union[int, List[int]] = 0
axes: List[str] = [""]
local: Optional[List[float]] = None

class ReverseConfig(TransformationConfig):
name: Literal["reverse"]
is_reverse: bool = False

class TypeChangeConfig(TransformationConfig):
name: Literal["type_change"]
type: str = "int"

class MergeConfig(TransformationConfig):
name: Literal["merge"]
other_axis: str = ""
linkers: List[str] = [""]

action_subclasses_union = Union[CyclicConfig, MapperConfig, ReverseConfig, TypeChangeConfig, MergeConfig]

class AxisConfig(ConfigModel):
axis_name: str = ""
transformations: list[action_subclasses_union]

class Config(ConfigModel):
config: list[AxisConfig] = []

parser = argparse.ArgumentParser(allow_abbrev=False)
axis_config = Conflator(app_name="polytope", model=Config, cli=False, argparser=parser).load()
if axis_options.get("config"):
axis_config = Config(config=axis_options.get("config"))

return axis_config

@staticmethod
def create(datacube, axis_options: dict, datacube_options={}):
if isinstance(datacube, (xr.core.dataarray.DataArray, xr.core.dataset.Dataset)):
Expand Down
19 changes: 13 additions & 6 deletions polytope/datacube/backends/fdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from copy import deepcopy
from itertools import product

import numpy as np
import pygribjump as pygj

from ...utility.geometry import nearest_pt
Expand All @@ -19,6 +18,7 @@ def __init__(self, config=None, axis_options=None, datacube_options=None):
logging.info("Created an FDB datacube with options: " + str(axis_options))

self.unwanted_path = {}
self.axis_options = Datacube.create_axes_config(axis_options).config

partial_request = config
# Find values in the level 3 FDB datacube
Expand All @@ -31,15 +31,23 @@ def __init__(self, config=None, axis_options=None, datacube_options=None):
self.fdb_coordinates["values"] = []
for name, values in self.fdb_coordinates.items():
values.sort()
options = axis_options.get(name, None)
options = None
for opt in self.axis_options:
if opt.axis_name == name:
options = opt
# options = axis_options.get(name, None)
self._check_and_add_axes(options, name, values)
self.treated_axes.append(name)
self.complete_axes.append(name)

# add other options to axis which were just created above like "lat" for the mapper transformations for eg
for name in self._axes:
if name not in self.treated_axes:
options = axis_options.get(name, None)
options = None
for opt in self.axis_options:
if opt.axis_name == name:
options = opt
# options = axis_options.get(name, None)
val = self._axes[name].type
self._check_and_add_axes(options, name, val)

Expand All @@ -65,8 +73,6 @@ def get(self, requests: IndexTree):
interm_branch_tuple_values = []
for key in compressed_request[0].keys():
# remove the tuple of the request when we ask the fdb

# TODO: here, would need to take care of axes that are merged and unmerged, which need to be carefully decompressed
interm_branch_tuple_values.append(compressed_request[0][key])
request_combis = product(*interm_branch_tuple_values)

Expand All @@ -92,7 +98,8 @@ def get(self, requests: IndexTree):
# request[0][key] = request[0][key][0]
# branch_tuple_combi = product(*interm_branch_tuple_values)
# # TODO: now build the relevant requests from this and ask gj for them
# # TODO: then group the output values together to fit back with the original compressed request and continue
# # TODO: then group the output values together to fit back with the original compressed request
# # and continue
# new_requests = []
# for combi in branch_tuple_combi:
# new_request = {}
Expand Down
87 changes: 55 additions & 32 deletions polytope/datacube/backends/xarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,68 +3,91 @@
import numpy as np
import xarray as xr

from .datacube import Datacube, IndexTree
from .datacube import Datacube


class XArrayDatacube(Datacube):
"""Xarray arrays are labelled, axes can be defined as strings or integers (e.g. "time" or 0)."""

def __init__(self, dataarray: xr.DataArray, axis_options=None, datacube_options=None):
super().__init__(axis_options, datacube_options)
if axis_options is None:
axis_options = {}
if datacube_options is None:
datacube_options = {}
self.axis_options = Datacube.create_axes_config(axis_options).config
self.axis_counter = 0
self._axes = None
self.dataarray = dataarray

for name, values in dataarray.coords.variables.items():
options = None
for opt in self.axis_options:
if opt.axis_name == name:
options = opt
if name in dataarray.dims:
options = self.axis_options.get(name, None)
self._check_and_add_axes(options, name, values)
self.treated_axes.append(name)
self.complete_axes.append(name)
else:
if self.dataarray[name].dims == ():
options = self.axis_options.get(name, None)
self._check_and_add_axes(options, name, values)
self.treated_axes.append(name)
for name in dataarray.dims:
if name not in self.treated_axes:
options = self.axis_options.get(name, None)
options = None
for opt in self.axis_options:
if opt.axis_name == name:
options = opt
val = dataarray[name].values[0]
self._check_and_add_axes(options, name, val)
self.treated_axes.append(name)
# add other options to axis which were just created above like "lat" for the mapper transformations for eg
for name in self._axes:
if name not in self.treated_axes:
options = self.axis_options.get(name, None)
options = None
for opt in self.axis_options:
if opt.axis_name == name:
options = opt
val = self._axes[name].type
self._check_and_add_axes(options, name, val)

def get(self, requests: IndexTree):
for r in requests.leaves:
path = r.flatten()
if len(path.items()) == self.axis_counter:
# TODO: need to undo the tuples in the path into actual paths with a single value that xarray can read
unmapped_path = {}
path_copy = deepcopy(path)
for key in path_copy:
axis = self._axes[key]
key_value_path = {key: path_copy[key]}
(key_value_path, path, unmapped_path) = axis.unmap_path_key(key_value_path, path, unmapped_path)
path.update(key_value_path)

unmapped_path = {}
self.refit_path(path, unmapped_path, path)
for key in path:
path[key] = list(path[key])
for key in unmapped_path:
if isinstance(unmapped_path[key], tuple):
unmapped_path[key] = list(unmapped_path[key])

subxarray = self.dataarray.sel(path, method="nearest")
subxarray = subxarray.sel(unmapped_path)
value = subxarray.values
key = subxarray.name
r.result = (key, value)
    def get(self, requests, leaf_path=None, axis_counter=0):
        """Recursively walk the request tree and attach datacube values to its leaves.

        Depth-first traversal of the index tree: each node contributes its
        (axis, values) pair to ``leaf_path``; at a leaf the accumulated path is
        used to select from ``self.dataarray`` and the result is stored on the
        node via ``requests.result``. Branches whose depth does not match the
        datacube's axis count are pruned.

        :param requests: index-tree node to process (the root on the first call).
        :param leaf_path: dict of axis-name -> values accumulated along the path
            from the root; created on the first call.
            NOTE(review): this dict is shared and mutated across sibling
            subtrees — each sibling overwrites its own axis key before
            descending; presumably safe because every level updates the same
            key, but confirm if axes can repeat along a branch.
        :param axis_counter: depth of the current node, compared against
            ``self.axis_counter`` to detect incomplete branches.
        """
        if leaf_path is None:
            leaf_path = {}
        if requests.axis.name == "root":
            # The synthetic root carries no axis value; just recurse.
            for c in requests.children:
                self.get(c, leaf_path, axis_counter + 1)
        else:
            key_value_path = {requests.axis.name: requests.values}
            ax = requests.axis
            # Undo axis transformations (mapper/merge/...) so the path keys and
            # values match the raw dataarray coordinates; unwanted_path collects
            # intermediate keys that must not reach xarray's .sel().
            (key_value_path, leaf_path, self.unwanted_path) = ax.unmap_path_key(
                key_value_path, leaf_path, self.unwanted_path
            )
            leaf_path.update(key_value_path)
            if len(requests.children) != 0:
                # We are not a leaf and we loop over
                for c in requests.children:
                    self.get(c, leaf_path, axis_counter + 1)
            else:
                if self.axis_counter != axis_counter:
                    # Leaf reached before all axes were specified: the branch is
                    # incomplete and cannot address a value, so drop it.
                    requests.remove_branch()
                else:
                    # We are at a leaf and need to assign value to it
                    # Work on a copy so refit_path can rewrite keys without
                    # corrupting the shared leaf_path used by sibling branches.
                    leaf_path_copy = deepcopy(leaf_path)
                    unmapped_path = {}
                    self.refit_path(leaf_path_copy, unmapped_path, leaf_path)
                    # xarray .sel needs lists, not tuples, for vectorised
                    # (compressed) selection along each axis.
                    for key in leaf_path_copy:
                        leaf_path_copy[key] = list(leaf_path_copy[key])
                    for key in unmapped_path:
                        if isinstance(unmapped_path[key], tuple):
                            unmapped_path[key] = list(unmapped_path[key])
                    # Nearest-neighbour match for the mapped axes, then exact
                    # selection for the remaining (unmapped) coordinates.
                    subxarray = self.dataarray.sel(leaf_path_copy, method="nearest")
                    subxarray = subxarray.sel(unmapped_path)
                    value = subxarray.values
                    key = subxarray.name
                    requests.result = (key, value)

def datacube_natural_indexes(self, axis, subarray):
if axis.name in self.complete_axes:
Expand Down
2 changes: 2 additions & 0 deletions polytope/datacube/datacube_axis.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,8 @@ def serialize(self, value):
np.datetime64: PandasTimestampDatacubeAxis(),
np.timedelta64: PandasTimedeltaDatacubeAxis(),
np.float64: FloatDatacubeAxis(),
np.float32: FloatDatacubeAxis(),
np.int32: IntDatacubeAxis(),
np.str_: UnsliceableDatacubeAxis(),
str: UnsliceableDatacubeAxis(),
np.object_: UnsliceableDatacubeAxis(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class DatacubeAxisCyclic(DatacubeAxisTransformation):
def __init__(self, name, cyclic_options):
self.name = name
self.transformation_options = cyclic_options
self.range = cyclic_options
self.range = cyclic_options.range

    def generate_final_transformation(self):
        # A cyclic transformation maps an axis onto itself, so no derived
        # transformation object is needed: this instance is its own final form.
        return self
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@ class DatacubeMapper(DatacubeAxisTransformation):

def __init__(self, name, mapper_options):
self.transformation_options = mapper_options
self.grid_type = mapper_options["type"]
self.grid_resolution = mapper_options["resolution"]
self.grid_axes = mapper_options["axes"]
self.grid_type = mapper_options.type
self.grid_resolution = mapper_options.resolution
self.grid_axes = mapper_options.axes
self.local_area = []
if "local" in mapper_options.keys():
self.local_area = mapper_options["local"]
if mapper_options.local is not None:
self.local_area = mapper_options.local
self.old_axis = name
self._final_transformation = self.generate_final_transformation()
self._final_mapped_axes = self._final_transformation._mapped_axes
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ def __init__(self, name, merge_options):
self.transformation_options = merge_options
self.name = name
self._first_axis = name
self._second_axis = merge_options["with"]
self._linkers = merge_options["linkers"]
self._second_axis = merge_options.other_axis
self._linkers = merge_options.linkers

def blocked_axes(self):
return [self._second_axis]
Expand All @@ -33,7 +33,6 @@ def merged_values(self, datacube):
first_val = first_ax_vals[i]
for j in range(len(second_ax_vals)):
second_val = second_ax_vals[j]
# TODO: check that the first and second val are strings
val_to_add = pd.to_datetime("".join([first_val, linkers[0], second_val, linkers[1]]))
val_to_add = val_to_add.to_numpy()
val_to_add = val_to_add.astype("datetime64[s]")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ def create_transform(name, transformation_type_key, transformation_options):
file_name = ".datacube_" + transformation_file_name
module = import_module("polytope.datacube.transformations" + file_name + file_name)
constructor = getattr(module, transformation_type)
transformation_type_option = transformation_options[transformation_type_key]
# transformation_type_option = transformation_options[transformation_type_key]
transformation_type_option = transformation_options
new_transformation = deepcopy(constructor(name, transformation_type_option))

new_transformation.name = name
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class DatacubeAxisTypeChange(DatacubeAxisTransformation):
    def __init__(self, name, type_options):
        """Set up a type-change transformation for the named axis."""
        # Axis whose index values will be converted to the new type.
        self.name = name
        self.transformation_options = type_options
        # Target type name comes from the pydantic config model
        # (defaults to "int" per TypeChangeConfig).
        self.new_type = type_options.type
        self._final_transformation = self.generate_final_transformation()

def generate_final_transformation(self):
Expand Down
Loading
Loading