From b03be25656d3009fa907d985c64c937026b2fb17 Mon Sep 17 00:00:00 2001 From: Mathilde Leuridan Date: Wed, 24 Jan 2024 10:21:45 +0100 Subject: [PATCH] better batching to fdb backend --- polytope/datacube/backends/fdb.py | 156 ++++++++++++++++++++++++++++-- 1 file changed, 149 insertions(+), 7 deletions(-) diff --git a/polytope/datacube/backends/fdb.py b/polytope/datacube/backends/fdb.py index d0f4264f..8e930012 100644 --- a/polytope/datacube/backends/fdb.py +++ b/polytope/datacube/backends/fdb.py @@ -39,7 +39,14 @@ def __init__(self, config={}, axis_options={}): val = self._axes[name].type self._check_and_add_axes(options, name, val) - def get(self, requests: IndexTree, leaf_path={}): + def get(self, requests: IndexTree): + fdb_requests = [] + fdb_requests_decoding_info = [] + self.get_fdb_requests(requests, fdb_requests, fdb_requests_decoding_info) + output_values = self.fdb.extract(fdb_requests) + self.assign_fdb_output_to_nodes(output_values, fdb_requests_decoding_info) + + def old_get(self, requests: IndexTree, leaf_path={}): # First when request node is root, go to its children if requests.axis.name == "root": for c in requests.children: @@ -61,6 +68,39 @@ def get(self, requests: IndexTree, leaf_path={}): for c in requests.children: self.get(c, leaf_path) + def get_fdb_requests(self, requests: IndexTree, fdb_requests=[], fdb_requests_decoding_info=[], leaf_path={}): + # First when request node is root, go to its children + if requests.axis.name == "root": + for c in requests.children: + self.get_fdb_requests(c, fdb_requests, fdb_requests_decoding_info) + # If request node has no children, we have a leaf so need to assign fdb values to it + else: + key_value_path = {requests.axis.name: requests.value} + ax = requests.axis + (key_value_path, leaf_path, self.unwanted_path) = ax.unmap_path_key( + key_value_path, leaf_path, self.unwanted_path + ) + leaf_path.update(key_value_path) + if len(requests.children[0].children[0].children) == 0: + # remap this last key + # TODO: here, find the fdb_requests and associated nodes to which to add results + + (path, range_lengths, current_start_idxs, fdb_node_ranges, lat_length) = self.get_2nd_last_values( + requests, leaf_path + ) + (original_indices, sorted_request_ranges) = self.sort_fdb_request_ranges( + range_lengths, current_start_idxs, lat_length + ) + fdb_requests.append(tuple((path, sorted_request_ranges))) + fdb_requests_decoding_info.append( + tuple((original_indices, fdb_node_ranges, lat_length, range_lengths, current_start_idxs)) + ) + + # Otherwise remap the path for this key and iterate again over children + else: + for c in requests.children: + self.get_fdb_requests(c, fdb_requests, fdb_requests_decoding_info, leaf_path) + def get_2nd_last_values(self, requests, leaf_path={}): # In this function, we recursively loop over the last two layers of the tree and store the indices of the # request ranges in those layers @@ -68,6 +108,75 @@ def get_2nd_last_values(self, requests, leaf_path={}): if len(self.nearest_search) != 0: first_ax_name = requests.children[0].axis.name second_ax_name = requests.children[0].children[0].axis.name + # TODO: throw error if first_ax_name or second_ax_name not in self.nearest_search.keys() + nearest_pts = [ + [lat_val, lon_val] + for (lat_val, lon_val) in zip( + self.nearest_search[first_ax_name][0], self.nearest_search[second_ax_name][0] + ) + ] + # first collect the lat lon points found + found_latlon_pts = [] + for lat_child in requests.children: + for lon_child in lat_child.children: + found_latlon_pts.append([lat_child.value, lon_child.value]) + # now find the nearest lat lon to the points requested + nearest_latlons = [] + for pt in nearest_pts: + nearest_latlon = nearest_pt(found_latlon_pts, pt) + nearest_latlons.append(nearest_latlon) + # TODO: now combine with the rest of the function.... + # TODO: need to remove the branches that do not fit + lat_children_values = [child.value for child in requests.children] + for i in range(len(lat_children_values)): + lat_child_val = lat_children_values[i] + lat_child = [child for child in requests.children if child.value == lat_child_val][0] + if lat_child.value not in [latlon[0] for latlon in nearest_latlons]: + lat_child.remove_branch() + else: + possible_lons = [latlon[1] for latlon in nearest_latlons if latlon[0] == lat_child.value] + lon_children_values = [child.value for child in lat_child.children] + for j in range(len(lon_children_values)): + lon_child_val = lon_children_values[j] + lon_child = [child for child in lat_child.children if child.value == lon_child_val][0] + if lon_child.value not in possible_lons: + lon_child.remove_branch() + + lat_length = len(requests.children) + range_lengths = [False] * lat_length + current_start_idxs = [False] * lat_length + fdb_node_ranges = [False] * lat_length + for i in range(len(requests.children)): + lat_child = requests.children[i] + lon_length = len(lat_child.children) + range_lengths[i] = [1] * lon_length + current_start_idxs[i] = [None] * lon_length + fdb_node_ranges[i] = [[IndexTree.root] * lon_length] * lon_length + range_length = deepcopy(range_lengths[i]) + current_start_idx = deepcopy(current_start_idxs[i]) + fdb_range_nodes = deepcopy(fdb_node_ranges[i]) + key_value_path = {lat_child.axis.name: lat_child.value} + ax = lat_child.axis + (key_value_path, leaf_path, self.unwanted_path) = ax.unmap_path_key( + key_value_path, leaf_path, self.unwanted_path + ) + leaf_path.update(key_value_path) + (range_lengths[i], current_start_idxs[i], fdb_node_ranges[i]) = self.get_last_layer_before_leaf( + lat_child, leaf_path, range_length, current_start_idx, fdb_range_nodes + ) + # TODO: do we need to return all of this? + leaf_path_copy = deepcopy(leaf_path) + leaf_path_copy.pop("values") + return (leaf_path_copy, range_lengths, current_start_idxs, fdb_node_ranges, lat_length) + + def old_get_2nd_last_values(self, requests, leaf_path={}): + # In this function, we recursively loop over the last two layers of the tree and store the indices of the + # request ranges in those layers + # TODO: here find nearest point first before retrieving etc + if len(self.nearest_search) != 0: + first_ax_name = requests.children[0].axis.name + second_ax_name = requests.children[0].children[0].axis.name + # TODO: throw error if first_ax_name or second_ax_name not in self.nearest_search.keys() nearest_pts = [ [lat_val, lon_val] for (lat_val, lon_val) in zip( @@ -155,6 +264,30 @@ def get_last_layer_before_leaf(self, requests, leaf_path, range_l, current_idx, current_idx[i] = current_start_idx return (range_l, current_idx, fdb_range_n) + def assign_fdb_output_to_nodes(self, output_values, fdb_requests_decoding_info): + for k in range(len(output_values)): + request_output_values = output_values[k] + ( + original_indices, + fdb_node_ranges, + lat_length, + range_lengths, + current_start_idxs, + ) = fdb_requests_decoding_info[k] + new_fdb_range_nodes = [] + new_range_lengths = [] + for j in range(lat_length): + for i in range(len(range_lengths[j])): + if current_start_idxs[j][i] is not None: + new_fdb_range_nodes.append(fdb_node_ranges[j][i]) + new_range_lengths.append(range_lengths[j][i]) + sorted_fdb_range_nodes = [new_fdb_range_nodes[i] for i in original_indices] + sorted_range_lengths = [new_range_lengths[i] for i in original_indices] + for i in range(len(sorted_fdb_range_nodes)): + for j in range(sorted_range_lengths[i]): + n = sorted_fdb_range_nodes[i][j] + n.result = request_output_values[0][i][0][j] + def give_fdb_val_to_node(self, leaf_path, range_lengths, current_start_idx, fdb_range_nodes, lat_length): (output_values, original_indices) = self.find_fdb_values( leaf_path, range_lengths, current_start_idx, lat_length @@ -169,9 +302,21 @@ def give_fdb_val_to_node(self, leaf_path, range_lengths, current_start_idx, fdb_ sorted_fdb_range_nodes = [new_fdb_range_nodes[i] for i in original_indices] sorted_range_lengths = [new_range_lengths[i] for i in original_indices] for i in range(len(sorted_fdb_range_nodes)): - for k in range(sorted_range_lengths[i]): - n = sorted_fdb_range_nodes[i][k] - n.result = output_values[0][0][i][0][k] + for j in range(sorted_range_lengths[i]): + n = sorted_fdb_range_nodes[i][j] + n.result = output_values[0][0][i][0][j] + + def sort_fdb_request_ranges(self, range_lengths, current_start_idx, lat_length): + interm_request_ranges = [] + for i in range(lat_length): + for j in range(len(range_lengths[i])): + if current_start_idx[i][j] is not None: + current_request_ranges = (current_start_idx[i][j], current_start_idx[i][j] + range_lengths[i][j]) + interm_request_ranges.append(current_request_ranges) + request_ranges_with_idx = list(enumerate(interm_request_ranges)) + sorted_list = sorted(request_ranges_with_idx, key=lambda x: x[1][0]) + original_indices, sorted_request_ranges = zip(*sorted_list) + return (original_indices, sorted_request_ranges) def find_fdb_values(self, path, range_lengths, current_start_idx, lat_length): path.pop("values") @@ -186,10 +331,7 @@ def find_fdb_values(self, path, range_lengths, current_start_idx, lat_length): sorted_list = sorted(request_ranges_with_idx, key=lambda x: x[1][0]) original_indices, sorted_request_ranges = zip(*sorted_list) fdb_requests.append(tuple((path, sorted_request_ranges))) - print("REQUEST TO FDB") - print(fdb_requests) output_values = self.fdb.extract(fdb_requests) - print(output_values) return (output_values, original_indices) def datacube_natural_indexes(self, axis, subarray):