diff --git a/polytope/datacube/backends/fdb.py b/polytope/datacube/backends/fdb.py index d0f4264f..7b2d1929 100644 --- a/polytope/datacube/backends/fdb.py +++ b/polytope/datacube/backends/fdb.py @@ -39,11 +39,18 @@ def __init__(self, config={}, axis_options={}): val = self._axes[name].type self._check_and_add_axes(options, name, val) - def get(self, requests: IndexTree, leaf_path={}): + def get(self, requests: IndexTree): + fdb_requests = [] + fdb_requests_decoding_info = [] + self.get_fdb_requests(requests, fdb_requests, fdb_requests_decoding_info) + output_values = self.fdb.extract(fdb_requests) + self.assign_fdb_output_to_nodes(output_values, fdb_requests_decoding_info) + + def get_fdb_requests(self, requests: IndexTree, fdb_requests=[], fdb_requests_decoding_info=[], leaf_path={}): # First when request node is root, go to its children if requests.axis.name == "root": for c in requests.children: - self.get(c) + self.get_fdb_requests(c, fdb_requests, fdb_requests_decoding_info) # If request node has no children, we have a leaf so need to assign fdb values to it else: key_value_path = {requests.axis.name: requests.value} @@ -54,12 +61,23 @@ def get(self, requests: IndexTree, leaf_path={}): leaf_path.update(key_value_path) if len(requests.children[0].children[0].children) == 0: # remap this last key - self.get_2nd_last_values(requests, leaf_path) + # TODO: here, find the fdb_requests and associated nodes to which to add results + + (path, range_lengths, current_start_idxs, fdb_node_ranges, lat_length) = self.get_2nd_last_values( + requests, leaf_path + ) + (original_indices, sorted_request_ranges) = self.sort_fdb_request_ranges( + range_lengths, current_start_idxs, lat_length + ) + fdb_requests.append(tuple((path, sorted_request_ranges))) + fdb_requests_decoding_info.append( + tuple((original_indices, fdb_node_ranges, lat_length, range_lengths, current_start_idxs)) + ) # Otherwise remap the path for this key and iterate again over children else: for c in requests.children: - self.get(c, leaf_path) + self.get_fdb_requests(c, fdb_requests, fdb_requests_decoding_info, leaf_path) def get_2nd_last_values(self, requests, leaf_path={}): # In this function, we recursively loop over the last two layers of the tree and store the indices of the @@ -68,6 +86,7 @@ def get_2nd_last_values(self, requests, leaf_path={}): if len(self.nearest_search) != 0: first_ax_name = requests.children[0].axis.name second_ax_name = requests.children[0].children[0].axis.name + # TODO: throw error if first_ax_name or second_ax_name not in self.nearest_search.keys() nearest_pts = [ [lat_val, lon_val] for (lat_val, lon_val) in zip( @@ -123,7 +142,10 @@ def get_2nd_last_values(self, requests, leaf_path={}): (range_lengths[i], current_start_idxs[i], fdb_node_ranges[i]) = self.get_last_layer_before_leaf( lat_child, leaf_path, range_length, current_start_idx, fdb_range_nodes ) - self.give_fdb_val_to_node(leaf_path, range_lengths, current_start_idxs, fdb_node_ranges, lat_length) + # TODO: do we need to return all of this? + leaf_path_copy = deepcopy(leaf_path) + leaf_path_copy.pop("values") + return (leaf_path_copy, range_lengths, current_start_idxs, fdb_node_ranges, lat_length) def get_last_layer_before_leaf(self, requests, leaf_path, range_l, current_idx, fdb_range_n): i = 0 @@ -155,27 +177,31 @@ def get_last_layer_before_leaf(self, requests, leaf_path, range_l, current_idx, current_idx[i] = current_start_idx return (range_l, current_idx, fdb_range_n) - def give_fdb_val_to_node(self, leaf_path, range_lengths, current_start_idx, fdb_range_nodes, lat_length): - (output_values, original_indices) = self.find_fdb_values( - leaf_path, range_lengths, current_start_idx, lat_length - ) - new_fdb_range_nodes = [] - new_range_lengths = [] - for j in range(lat_length): - for i in range(len(range_lengths[j])): - if current_start_idx[j][i] is not None: - new_fdb_range_nodes.append(fdb_range_nodes[j][i]) - new_range_lengths.append(range_lengths[j][i]) - sorted_fdb_range_nodes = [new_fdb_range_nodes[i] for i in original_indices] - sorted_range_lengths = [new_range_lengths[i] for i in original_indices] - for i in range(len(sorted_fdb_range_nodes)): - for k in range(sorted_range_lengths[i]): - n = sorted_fdb_range_nodes[i][k] - n.result = output_values[0][0][i][0][k] - - def find_fdb_values(self, path, range_lengths, current_start_idx, lat_length): - path.pop("values") - fdb_requests = [] + def assign_fdb_output_to_nodes(self, output_values, fdb_requests_decoding_info): + for k in range(len(output_values)): + request_output_values = output_values[k] + ( + original_indices, + fdb_node_ranges, + lat_length, + range_lengths, + current_start_idxs, + ) = fdb_requests_decoding_info[k] + new_fdb_range_nodes = [] + new_range_lengths = [] + for j in range(lat_length): + for i in range(len(range_lengths[j])): + if current_start_idxs[j][i] is not None: + new_fdb_range_nodes.append(fdb_node_ranges[j][i]) + new_range_lengths.append(range_lengths[j][i]) + sorted_fdb_range_nodes = [new_fdb_range_nodes[i] for i in original_indices] + sorted_range_lengths = [new_range_lengths[i] for i in original_indices] + for i in range(len(sorted_fdb_range_nodes)): + for j in range(sorted_range_lengths[i]): + n = sorted_fdb_range_nodes[i][j] + n.result = request_output_values[0][i][0][j] + + def sort_fdb_request_ranges(self, range_lengths, current_start_idx, lat_length): interm_request_ranges = [] for i in range(lat_length): for j in range(len(range_lengths[i])): @@ -185,12 +211,7 @@ def find_fdb_values(self, path, range_lengths, current_start_idx, lat_length): request_ranges_with_idx = list(enumerate(interm_request_ranges)) sorted_list = sorted(request_ranges_with_idx, key=lambda x: x[1][0]) original_indices, sorted_request_ranges = zip(*sorted_list) - fdb_requests.append(tuple((path, sorted_request_ranges))) - print("REQUEST TO FDB") - print(fdb_requests) - output_values = self.fdb.extract(fdb_requests) - print(output_values) - return (output_values, original_indices) + return (original_indices, sorted_request_ranges) def datacube_natural_indexes(self, axis, subarray): indexes = subarray[axis.name]