Skip to content

Commit

Permalink
better batching to fdb backend
Browse files Browse the repository at this point in the history
  • Loading branch information
mathleur committed Jan 24, 2024
1 parent 73e183f commit b03be25
Showing 1 changed file with 149 additions and 7 deletions.
156 changes: 149 additions & 7 deletions polytope/datacube/backends/fdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,14 @@ def __init__(self, config={}, axis_options={}):
val = self._axes[name].type
self._check_and_add_axes(options, name, val)

def get(self, requests: IndexTree, leaf_path={}):
def get(self, requests: IndexTree):
    """Fill the request tree with values using one batched FDB extraction.

    The tree is first walked to collect every FDB request together with the
    information needed to decode its output, then ``self.fdb.extract`` is
    called once with all collected requests, and finally the returned values
    are written back onto the tree nodes.

    Args:
        requests: root of the request tree to populate.
    """
    batched_requests = []
    decoding_info = []
    # Both lists are filled in place by the traversal.
    self.get_fdb_requests(requests, batched_requests, decoding_info)
    extracted_values = self.fdb.extract(batched_requests)
    self.assign_fdb_output_to_nodes(extracted_values, decoding_info)

def old_get(self, requests: IndexTree, leaf_path={}):
# First when request node is root, go to its children
if requests.axis.name == "root":
for c in requests.children:
Expand All @@ -61,13 +68,115 @@ def get(self, requests: IndexTree, leaf_path={}):
for c in requests.children:
self.get(c, leaf_path)

def get_fdb_requests(self, requests: IndexTree, fdb_requests=[], fdb_requests_decoding_info=[], leaf_path={}):
# First when request node is root, go to its children
if requests.axis.name == "root":
for c in requests.children:
self.get_fdb_requests(c, fdb_requests, fdb_requests_decoding_info)
# If request node has no children, we have a leaf so need to assign fdb values to it
else:
key_value_path = {requests.axis.name: requests.value}
ax = requests.axis
(key_value_path, leaf_path, self.unwanted_path) = ax.unmap_path_key(
key_value_path, leaf_path, self.unwanted_path
)
leaf_path.update(key_value_path)
if len(requests.children[0].children[0].children) == 0:
# remap this last key
# TODO: here, find the fdb_requests and associated nodes to which to add results

(path, range_lengths, current_start_idxs, fdb_node_ranges, lat_length) = self.get_2nd_last_values(
requests, leaf_path
)
(original_indices, sorted_request_ranges) = self.sort_fdb_request_ranges(
range_lengths, current_start_idxs, lat_length
)
fdb_requests.append(tuple((path, sorted_request_ranges)))
fdb_requests_decoding_info.append(
tuple((original_indices, fdb_node_ranges, lat_length, range_lengths, current_start_idxs))
)

# Otherwise remap the path for this key and iterate again over children
else:
for c in requests.children:
self.get_fdb_requests(c, fdb_requests, fdb_requests_decoding_info, leaf_path)

def get_2nd_last_values(self, requests, leaf_path=None):
    """Collect the index ranges stored in the last two layers below ``requests``.

    When ``self.nearest_search`` is non-empty, the two layers are first pruned so
    that only the found points nearest to the requested points remain. Then, for
    every remaining child of ``requests``, ``get_last_layer_before_leaf`` is
    asked for the range lengths, range start indices and the nodes of each range.

    Args:
        requests: node whose children and grandchildren form the last two layers
            (named lat/lon here, following the variable names used in this code).
        leaf_path: dict with the axis-name -> value path accumulated so far. It
            is updated in place; a fresh dict is used when omitted.

    Returns:
        ``(path, range_lengths, current_start_idxs, fdb_node_ranges, lat_length)``
        where ``path`` is a deep copy of ``leaf_path`` without its ``"values"``
        key and the three lists hold one entry per child of ``requests``.

    Raises:
        KeyError: if ``leaf_path`` has no ``"values"`` key after the traversal.
    """
    if leaf_path is None:
        leaf_path = {}

    # TODO: here find nearest point first before retrieving etc
    if len(self.nearest_search) != 0:
        first_ax_name = requests.children[0].axis.name
        second_ax_name = requests.children[0].children[0].axis.name
        # TODO: throw error if first_ax_name or second_ax_name not in self.nearest_search.keys()
        nearest_pts = [
            [lat_val, lon_val]
            for (lat_val, lon_val) in zip(
                self.nearest_search[first_ax_name][0], self.nearest_search[second_ax_name][0]
            )
        ]
        # first collect the lat lon points found
        found_latlon_pts = [
            [lat_child.value, lon_child.value]
            for lat_child in requests.children
            for lon_child in lat_child.children
        ]
        # now find the nearest lat lon to the points requested
        nearest_latlons = [nearest_pt(found_latlon_pts, pt) for pt in nearest_pts]
        nearest_lats = [latlon[0] for latlon in nearest_latlons]
        # TODO: now combine with the rest of the function....
        # Remove the branches that do not fit. Iterate over snapshots because
        # remove_branch() is called on children while we loop over them.
        for lat_child in list(requests.children):
            if lat_child.value not in nearest_lats:
                lat_child.remove_branch()
                continue
            possible_lons = [latlon[1] for latlon in nearest_latlons if latlon[0] == lat_child.value]
            for lon_child in list(lat_child.children):
                if lon_child.value not in possible_lons:
                    lon_child.remove_branch()

    lat_length = len(requests.children)
    range_lengths = [False] * lat_length
    current_start_idxs = [False] * lat_length
    fdb_node_ranges = [False] * lat_length
    for i, lat_child in enumerate(requests.children):
        lon_length = len(lat_child.children)
        range_length = [1] * lon_length
        current_start_idx = [None] * lon_length
        # One independent row per possible range. ``[[x] * n] * n`` would make
        # every row the same list object (deepcopy keeps that aliasing), so
        # nodes written for one range would overwrite those of the others.
        fdb_range_nodes = deepcopy([[IndexTree.root] * lon_length for _ in range(lon_length)])
        key_value_path = {lat_child.axis.name: lat_child.value}
        ax = lat_child.axis
        (key_value_path, leaf_path, self.unwanted_path) = ax.unmap_path_key(
            key_value_path, leaf_path, self.unwanted_path
        )
        leaf_path.update(key_value_path)
        (range_lengths[i], current_start_idxs[i], fdb_node_ranges[i]) = self.get_last_layer_before_leaf(
            lat_child, leaf_path, range_length, current_start_idx, fdb_range_nodes
        )
    # TODO: do we need to return all of this?
    # The returned path is a copy so later updates to leaf_path do not alter it.
    leaf_path_copy = deepcopy(leaf_path)
    leaf_path_copy.pop("values")
    return (leaf_path_copy, range_lengths, current_start_idxs, fdb_node_ranges, lat_length)

def old_get_2nd_last_values(self, requests, leaf_path={}):
# In this function, we recursively loop over the last two layers of the tree and store the indices of the
# request ranges in those layers
# TODO: here find nearest point first before retrieving etc
if len(self.nearest_search) != 0:
first_ax_name = requests.children[0].axis.name
second_ax_name = requests.children[0].children[0].axis.name
# TODO: throw error if first_ax_name or second_ax_name not in self.nearest_search.keys()
nearest_pts = [
[lat_val, lon_val]
for (lat_val, lon_val) in zip(
Expand Down Expand Up @@ -155,6 +264,30 @@ def get_last_layer_before_leaf(self, requests, leaf_path, range_l, current_idx,
current_idx[i] = current_start_idx
return (range_l, current_idx, fdb_range_n)

def assign_fdb_output_to_nodes(self, output_values, fdb_requests_decoding_info):
    """Write the extracted FDB values onto the ``result`` attribute of the tree nodes.

    Args:
        output_values: one entry per FDB request; the value for the ``j``-th
            node of the ``i``-th sorted range is read from ``entry[0][i][0][j]``.
        fdb_requests_decoding_info: one
            ``(original_indices, fdb_node_ranges, lat_length, range_lengths,
            current_start_idxs)`` tuple per request, aligned with
            ``output_values``.
    """
    for request_idx, request_output_values in enumerate(output_values):
        decoding_info = fdb_requests_decoding_info[request_idx]
        (original_indices, fdb_node_ranges, lat_length, range_lengths, current_start_idxs) = decoding_info

        # Keep only the ranges that actually have a start index.
        kept_ranges = []
        for lat_idx in range(lat_length):
            for lon_idx, length in enumerate(range_lengths[lat_idx]):
                if current_start_idxs[lat_idx][lon_idx] is None:
                    continue
                kept_ranges.append((fdb_node_ranges[lat_idx][lon_idx], length))

        # Put the ranges in the order in which they were sent to the FDB.
        ordered_ranges = [kept_ranges[idx] for idx in original_indices]

        for range_pos, (range_nodes, length) in enumerate(ordered_ranges):
            for offset in range(length):
                node = range_nodes[offset]
                node.result = request_output_values[0][range_pos][0][offset]

def give_fdb_val_to_node(self, leaf_path, range_lengths, current_start_idx, fdb_range_nodes, lat_length):
(output_values, original_indices) = self.find_fdb_values(
leaf_path, range_lengths, current_start_idx, lat_length
Expand All @@ -169,9 +302,21 @@ def give_fdb_val_to_node(self, leaf_path, range_lengths, current_start_idx, fdb_
sorted_fdb_range_nodes = [new_fdb_range_nodes[i] for i in original_indices]
sorted_range_lengths = [new_range_lengths[i] for i in original_indices]
for i in range(len(sorted_fdb_range_nodes)):
for k in range(sorted_range_lengths[i]):
n = sorted_fdb_range_nodes[i][k]
n.result = output_values[0][0][i][0][k]
for j in range(sorted_range_lengths[i]):
n = sorted_fdb_range_nodes[i][j]
n.result = output_values[0][0][i][0][j]

def sort_fdb_request_ranges(self, range_lengths, current_start_idx, lat_length):
    """Build the ``(start, end)`` request ranges and sort them by start index.

    Args:
        range_lengths: per first-layer child, the list of range lengths.
        current_start_idx: per first-layer child, the list of range start
            indices; ``None`` marks a slot without a range, which is skipped.
        lat_length: number of first-layer children to read from both lists.

    Returns:
        ``(original_indices, sorted_request_ranges)``: two tuples of equal
        length. ``sorted_request_ranges`` holds the half-open ``(start, end)``
        pairs ordered by ``start`` (ties keep their original order) and
        ``original_indices[k]`` is the position that the ``k``-th sorted range
        had before sorting. Both are empty tuples when there is no range at
        all, where unpacking ``zip(*[])`` would otherwise raise ``ValueError``.
    """
    request_ranges = []
    for lat_idx in range(lat_length):
        starts = current_start_idx[lat_idx]
        for lon_idx, length in enumerate(range_lengths[lat_idx]):
            start = starts[lon_idx]
            if start is not None:
                request_ranges.append((start, start + length))

    if not request_ranges:
        return ((), ())

    # Sort while remembering each range's pre-sort position.
    sorted_pairs = sorted(enumerate(request_ranges), key=lambda pair: pair[1][0])
    original_indices, sorted_request_ranges = zip(*sorted_pairs)
    return (original_indices, sorted_request_ranges)

def find_fdb_values(self, path, range_lengths, current_start_idx, lat_length):
path.pop("values")
Expand All @@ -186,10 +331,7 @@ def find_fdb_values(self, path, range_lengths, current_start_idx, lat_length):
sorted_list = sorted(request_ranges_with_idx, key=lambda x: x[1][0])
original_indices, sorted_request_ranges = zip(*sorted_list)
fdb_requests.append(tuple((path, sorted_request_ranges)))
print("REQUEST TO FDB")
print(fdb_requests)
output_values = self.fdb.extract(fdb_requests)
print(output_values)
return (output_values, original_indices)

def datacube_natural_indexes(self, axis, subarray):
Expand Down

0 comments on commit b03be25

Please sign in to comment.