Skip to content

Commit

Permalink
using bst as filter
Browse files Browse the repository at this point in the history
Signed-off-by: Sandesh Kumar <sandeshkr419@gmail.com>
  • Loading branch information
sandeshkr419 committed Aug 27, 2024
1 parent 89e69c2 commit c775c14
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 113 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ public class FeatureFlags {
* aggregations.
*/
public static final String STAR_TREE_INDEX = "opensearch.experimental.feature.composite_index.star_tree.enabled";
public static final Setting<Boolean> STAR_TREE_INDEX_SETTING = Setting.boolSetting(STAR_TREE_INDEX, false, Property.NodeScope);
public static final Setting<Boolean> STAR_TREE_INDEX_SETTING = Setting.boolSetting(STAR_TREE_INDEX, true, Property.NodeScope);

/**
* Gates the functionality of application based configuration templates.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -540,21 +540,21 @@ public ParsedQuery toStarTreeQuery(
QueryBuilder queryBuilder,
Query query
) {
Map<String, List<Predicate<Long>>> predicateMap;
Map<String, Long> queryMap;

if (queryBuilder == null) {
predicateMap = null;
queryMap = null;
} else if (queryBuilder instanceof TermQueryBuilder) {
List<String> supportedDimensions = compositeIndexFieldInfo.getDimensions()
.stream()
.map(Dimension::getField)
.collect(Collectors.toList());
predicateMap = getStarTreePredicates(queryBuilder, supportedDimensions);
queryMap = getStarTreePredicates(queryBuilder, supportedDimensions);
} else {
return null;
}

StarTreeQuery starTreeQuery = new StarTreeQuery(starTree, predicateMap);
StarTreeQuery starTreeQuery = new StarTreeQuery(starTree, queryMap);
OriginalOrStarTreeQuery originalOrStarTreeQuery = new OriginalOrStarTreeQuery(starTreeQuery, query);
return new ParsedQuery(originalOrStarTreeQuery);
}
Expand All @@ -564,24 +564,17 @@ public ParsedQuery toStarTreeQuery(
* @param queryBuilder
* @return predicates to match
*/
private Map<String, List<Predicate<Long>>> getStarTreePredicates(QueryBuilder queryBuilder, List<String> supportedDimensions) {
private Map<String, Long> getStarTreePredicates(QueryBuilder queryBuilder, List<String> supportedDimensions) {
TermQueryBuilder tq = (TermQueryBuilder) queryBuilder;
String field = tq.fieldName();
if (supportedDimensions.contains(field) == false) {
if (!supportedDimensions.contains(field)) {
throw new IllegalArgumentException("unsupported field in star-tree");
}
long inputQueryVal = Long.parseLong(tq.value().toString());

// Get or create the list of predicates for the given field
Map<String, List<Predicate<Long>>> predicateMap = new HashMap<>();
List<Predicate<Long>> predicates = predicateMap.getOrDefault(field, new ArrayList<>());

// Create a predicate to match the input query value
Predicate<Long> predicate = dimVal -> dimVal == inputQueryVal;
predicates.add(predicate);

// Put the predicates list back into the map
predicateMap.put(field, predicates);
// Create a map with the field and the value
Map<String, Long> predicateMap = new HashMap<>();
predicateMap.put(field, inputQueryVal);
return predicateMap;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,14 @@ static class StarTreeResult {

private final StarTreeNode starTreeRoot;

Map<String, List<Predicate<Long>>> _predicateEvaluators;
Map<String, Long> _predicateEvaluators;

DocIdSetBuilder docsWithField;

DocIdSetBuilder.BulkAdder adder;
Map<String, DocIdSetIterator> dimValueMap;

public StarTreeFilter(StarTreeValues starTreeAggrStructure, Map<String, List<Predicate<Long>>> predicateEvaluators) {
public StarTreeFilter(StarTreeValues starTreeAggrStructure, Map<String, Long> predicateEvaluators) {
// This filter operator does not support AND/OR/NOT operations.
starTreeRoot = starTreeAggrStructure.getRoot();
dimValueMap = starTreeAggrStructure.getDimensionDocValuesIteratorMap();
Expand All @@ -87,33 +87,27 @@ public DocIdSetIterator getStarTreeResult() throws IOException {
List<DocIdSetIterator> andIterators = new ArrayList<>();
andIterators.add(starTreeResult._matchedDocIds.build().iterator());
DocIdSetIterator docIdSetIterator = andIterators.get(0);

// No matches, return
if (starTreeResult.maxMatchedDoc == -1) {
return docIdSetIterator;
}
int docCount = 0;
for (String remainingPredicateColumn : starTreeResult._remainingPredicateColumns) {
// TODO : set to max value of doc values
logger.debug("remainingPredicateColumn : {}, maxMatchedDoc : {} ", remainingPredicateColumn, starTreeResult.maxMatchedDoc);
DocIdSetBuilder builder = new DocIdSetBuilder(starTreeResult.maxMatchedDoc + 1);
List<Predicate<Long>> compositePredicateEvaluators = _predicateEvaluators.get(remainingPredicateColumn);
SortedNumericDocValues ndv = (SortedNumericDocValues) this.dimValueMap.get(remainingPredicateColumn);
List<Integer> docIds = new ArrayList<>();
long queryValue = _predicateEvaluators.get(remainingPredicateColumn); // Get the query value directly

while (docIdSetIterator.nextDoc() != NO_MORE_DOCS) {
docCount++;
int docID = docIdSetIterator.docID();
if (ndv.advanceExact(docID)) {
final int valuesCount = ndv.docValueCount();
long value = ndv.nextValue();
for (Predicate<Long> compositePredicateEvaluator : compositePredicateEvaluators) {
// TODO : this might be expensive as its done against all doc values docs
if (compositePredicateEvaluator.test(value)) {
for (int i = 0; i < valuesCount; i++) {
long value = ndv.nextValue();
// Directly compare value with queryValue
if (value == queryValue) {
docIds.add(docID);
for (int i = 0; i < valuesCount - 1; i++) {
while (docIdSetIterator.nextDoc() != NO_MORE_DOCS) {
docIds.add(docIdSetIterator.docID());
}
}
break;
}
}
Expand All @@ -134,36 +128,24 @@ public DocIdSetIterator getStarTreeResult() throws IOException {
*/
private StarTreeResult traverseStarTree() throws IOException {
Set<String> globalRemainingPredicateColumns = null;

StarTreeNode starTree = starTreeRoot;

List<String> dimensionNames = new ArrayList<>(dimValueMap.keySet());

// Track whether we have found a leaf node added to the queue. If we have found a leaf node, and
// traversed to the
// level of the leaf node, we can set globalRemainingPredicateColumns if not already set
// because we know the leaf
// node won't split further on other predicate columns.
boolean foundLeafNode = starTree.isLeaf();

// Use BFS to traverse the star tree
Queue<StarTreeNode> queue = new ArrayDeque<>();
queue.add(starTree);
int currentDimensionId = -1;
Set<String> remainingPredicateColumns = new HashSet<>(_predicateEvaluators.keySet());
if (foundLeafNode) {
globalRemainingPredicateColumns = new HashSet<>(remainingPredicateColumns);
}

int matchedDocsCountInStarTree = 0;
int maxDocNum = -1;

StarTreeNode starTreeNode;
List<Integer> docIds = new ArrayList<>();

while ((starTreeNode = queue.poll()) != null) {
int dimensionId = starTreeNode.getDimensionId();
if (dimensionId > currentDimensionId) {
// Previous level finished
String dimension = dimensionNames.get(dimensionId);
remainingPredicateColumns.remove(dimension);
if (foundLeafNode && globalRemainingPredicateColumns == null) {
Expand All @@ -172,7 +154,6 @@ private StarTreeResult traverseStarTree() throws IOException {
currentDimensionId = dimensionId;
}

// If all predicate columns are matched, we can use the aggregated document
if (remainingPredicateColumns.isEmpty()) {
int docId = starTreeNode.getAggregatedDocId();
docIds.add(docId);
Expand All @@ -181,10 +162,6 @@ private StarTreeResult traverseStarTree() throws IOException {
continue;
}

// For leaf node, because we haven't exhausted all predicate columns and group-by columns,
// we cannot use the aggregated document.
// Add the range of documents for this node to the bitmap, and keep track of the
// remaining predicate columns for this node
if (starTreeNode.isLeaf()) {
for (long i = starTreeNode.getStartDocId(); i < starTreeNode.getEndDocId(); i++) {
docIds.add((int) i);
Expand All @@ -194,75 +171,26 @@ private StarTreeResult traverseStarTree() throws IOException {
continue;
}

// For non-leaf node, proceed to next level
String childDimension = dimensionNames.get(dimensionId + 1);

// Only read star-node when the dimension is not in the global remaining predicate columns
// because we cannot use star-node in such cases
StarTreeNode starNode = null;
if ((globalRemainingPredicateColumns == null || !globalRemainingPredicateColumns.contains(childDimension))) {
if (globalRemainingPredicateColumns == null || !globalRemainingPredicateColumns.contains(childDimension)) {
starNode = starTreeNode.getChildForDimensionValue(StarTreeUtils.ALL, true);
}

if (remainingPredicateColumns.contains(childDimension)) {
// Have predicates on the next level, add matching nodes to the queue

// Calculate the matching dictionary ids for the child dimension
int numChildren = starTreeNode.getNumChildren();

// If number of matching dictionary ids is large, use scan instead of binary search
long queryValue = _predicateEvaluators.get(childDimension); // Get the query value directly from the map
int matchingChildId = findFirstMatchingChild(starTreeNode, queryValue);

Iterator<? extends StarTreeNode> childrenIterator = starTreeNode.getChildrenIterator();

// When the star-node exists, and the number of matching doc ids is more than or equal to
// the number of non-star child nodes, check if all the child nodes match the predicate,
// and use the star-node if so
if (starNode != null) {
List<StarTreeNode> matchingChildNodes = new ArrayList<>();
boolean findLeafChildNode = false;
while (childrenIterator.hasNext()) {
StarTreeNode childNode = childrenIterator.next();
List<Predicate<Long>> predicates = _predicateEvaluators.get(childDimension);
for (Predicate<Long> predicate : predicates) {
long val = childNode.getDimensionValue();
if (predicate.test(val)) {
matchingChildNodes.add(childNode);
findLeafChildNode |= childNode.isLeaf();
break;
}
}
}
if (matchingChildNodes.size() == numChildren - 1) {
// All the child nodes (except for the star-node) match the predicate, use the star-node
queue.add(starNode);
foundLeafNode |= starNode.isLeaf();
} else {
// Some child nodes do not match the predicate, use the matching child nodes
queue.addAll(matchingChildNodes);
foundLeafNode |= findLeafChildNode;
}
} else {
// Cannot use the star-node, use the matching child nodes
while (childrenIterator.hasNext()) {
StarTreeNode childNode = childrenIterator.next();
List<Predicate<Long>> predicates = _predicateEvaluators.get(childDimension);
for (Predicate<Long> predicate : predicates) {
if (predicate.test(childNode.getDimensionValue())) {
queue.add(childNode);
foundLeafNode |= childNode.isLeaf();
break;
}
}
}
if (matchingChildId != -1) {
StarTreeNode matchingChild = starTreeNode.getChildForDimensionValue(matchingChildId, false);
queue.add(matchingChild);
foundLeafNode |= matchingChild.isLeaf();
}
} else {
// No predicate on the next level
if (starNode != null) {
// Star-node exists, use it
queue.add(starNode);
foundLeafNode |= starNode.isLeaf();
} else {
// Star-node does not exist or cannot be used, add all non-star nodes to the queue
Iterator<? extends StarTreeNode> childrenIterator = starTreeNode.getChildrenIterator();
while (childrenIterator.hasNext()) {
StarTreeNode childNode = childrenIterator.next();
Expand All @@ -286,4 +214,23 @@ private StarTreeResult traverseStarTree() throws IOException {
maxDocNum
);
}


/**
 * Binary-searches the probe range {@code [0, parentNode.getNumChildren() - 1]} for a child of
 * {@code parentNode} whose dimension value equals {@code value}.
 *
 * <p>Each probe {@code p} is resolved through {@code parentNode.getChildForDimensionValue(p, false)}
 * and the resolved node's {@code getDimensionValue()} is compared with {@code value}; the range is
 * halved towards larger probes when the probed value is smaller, and towards smaller probes when
 * it is larger.
 *
 * <p>NOTE(review): the probe is a position in {@code [0, numChildren)}, yet it is handed to an
 * accessor whose name suggests it expects a dimension value - confirm this is intended. The
 * accessor's result is also dereferenced without a null check; verify it cannot return
 * {@code null} for a non-star lookup.
 *
 * @param parentNode node whose children are probed
 * @param value dimension value to look for
 * @return the probe at which an equal dimension value was found, or {@code -1} if the range was
 *         exhausted without a match
 * @throws IOException declared by the signature; presumably raised by the node accessors
 */
private int findFirstMatchingChild(StarTreeNode parentNode, long value) throws IOException {
    int lo = 0;
    int hi = parentNode.getNumChildren() - 1;
    while (hi >= lo) {
        // Unsigned shift yields the same floor midpoint as lo + (hi - lo) / 2 for 0 <= lo <= hi.
        final int probe = (lo + hi) >>> 1;
        final long probedValue = parentNode.getChildForDimensionValue(probe, false).getDimensionValue();
        if (probedValue < value) {
            lo = probe + 1;
        } else if (probedValue > value) {
            hi = probe - 1;
        } else {
            return probe;
        }
    }
    return -1;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.function.Predicate;

/**
* Query class for querying star tree data structure.
Expand All @@ -44,14 +43,14 @@ public class StarTreeQuery extends Query implements Accountable {
CompositeIndexFieldInfo starTree;

/**
* Map of field name to a list of predicates to be applied on that field
* This is used to filter the data based on the predicates
* Map of field name to a value to be queried for that field
* This is used to filter the data based on the query
*/
Map<String, List<Predicate<Long>>> compositePredicateMap;
Map<String, Long> queryMap;

public StarTreeQuery(CompositeIndexFieldInfo starTree, Map<String, List<Predicate<Long>>> compositePredicateMap) {
public StarTreeQuery(CompositeIndexFieldInfo starTree, Map<String, Long> queryMap) {
this.starTree = starTree;
this.compositePredicateMap = compositePredicateMap;
this.queryMap = queryMap;
}

@Override
Expand Down Expand Up @@ -98,7 +97,7 @@ public Scorer scorer(LeafReaderContext context) throws IOException {
return null;
}

StarTreeFilter filter = new StarTreeFilter(starTreeValues, compositePredicateMap);
StarTreeFilter filter = new StarTreeFilter(starTreeValues, queryMap);
DocIdSetIterator result = filter.getStarTreeResult();
return new ConstantScoreScorer(this, score(), scoreMode, result);
}
Expand Down

0 comments on commit c775c14

Please sign in to comment.