Skip to content

Commit

Permalink
star tree parsing approach
Browse files Browse the repository at this point in the history
  • Loading branch information
sandeshkr419 committed Jul 17, 2024
1 parent 781a2d4 commit a138757
Show file tree
Hide file tree
Showing 15 changed files with 1,204 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ public class FeatureFlags {
* aggregations.
*/
public static final String STAR_TREE_INDEX = "opensearch.experimental.feature.composite_index.star_tree.enabled";
public static final Setting<Boolean> STAR_TREE_INDEX_SETTING = Setting.boolSetting(STAR_TREE_INDEX, false, Property.NodeScope);
public static final Setting<Boolean> STAR_TREE_INDEX_SETTING = Setting.boolSetting(STAR_TREE_INDEX, true, Property.NodeScope);

private static final List<Setting<Boolean>> ALL_FEATURE_FLAG_SETTINGS = List.of(
REMOTE_STORE_MIGRATION_EXPERIMENTAL_SETTING,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@
import org.opensearch.search.aggregations.support.AggregationUsageService;
import org.opensearch.search.aggregations.support.ValuesSourceRegistry;
import org.opensearch.search.lookup.SearchLookup;
import org.opensearch.search.query.startree.StarTreeQuery;
import org.opensearch.transport.RemoteClusterAware;

import java.io.IOException;
Expand Down Expand Up @@ -498,6 +499,12 @@ public boolean indexSortedOnField(String field) {
return indexSortConfig.hasPrimarySortOnField(field);
}

/**
 * Builds a {@link StarTreeQuery} from the supplied predicates and group-by columns and
 * wraps it in a {@link ParsedQuery}.
 *
 * @param compositePredicateMap predicates handed unchanged to the {@link StarTreeQuery} constructor
 * @param groupByColumns        group-by columns handed unchanged to the {@link StarTreeQuery} constructor
 * @return a new {@link ParsedQuery} wrapping the newly created {@link StarTreeQuery}
 */
public ParsedQuery toStarTreeQuery(Map<String, List<Predicate<Long>>> compositePredicateMap,
    Set<String> groupByColumns) {
    return new ParsedQuery(new StarTreeQuery(compositePredicateMap, groupByColumns));
}

public ParsedQuery toQuery(QueryBuilder queryBuilder) {
return toQuery(queryBuilder, q -> {
Query query = q.toQuery(this);
Expand Down
43 changes: 43 additions & 0 deletions server/src/main/java/org/opensearch/search/SearchService.java
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,8 @@
import org.opensearch.search.aggregations.MultiBucketConsumerService;
import org.opensearch.search.aggregations.SearchContextAggregations;
import org.opensearch.search.aggregations.pipeline.PipelineAggregator.PipelineTree;
import org.opensearch.search.aggregations.startree.StarTreeAggregator;
import org.opensearch.search.aggregations.startree.StarTreeAggregatorFactory;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.search.collapse.CollapseContext;
import org.opensearch.search.dfs.DfsPhase;
Expand Down Expand Up @@ -148,6 +150,7 @@
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
Expand Down Expand Up @@ -1314,12 +1317,19 @@ private void parseSource(DefaultSearchContext context, SearchSourceBuilder sourc
context.evaluateRequestShouldUseConcurrentSearch();
return;
}

// Can be marked false for majority cases for which star-tree cannot be used
// Will save checking the criteria later and we can have a limit on what search requests are supported
// As we increment the cases where star-tree can be used, this can be set back to true
boolean canUseStarTree = context.mapperService().isCompositeIndexPresent();

SearchShardTarget shardTarget = context.shardTarget();
QueryShardContext queryShardContext = context.getQueryShardContext();
context.from(source.from());
context.size(source.size());
Map<String, InnerHitContextBuilder> innerHitBuilders = new HashMap<>();
if (source.query() != null) {
canUseStarTree = false;
InnerHitContextBuilder.extractInnerHits(source.query(), innerHitBuilders);
context.parsedQuery(queryShardContext.toQuery(source.query()));
}
Expand Down Expand Up @@ -1496,6 +1506,39 @@ private void parseSource(DefaultSearchContext context, SearchSourceBuilder sourc
if (source.profile()) {
context.setProfilers(new Profilers(context.searcher(), context.shouldUseConcurrentSearch()));
}

if (canUseStarTree) {
try {
if (setStarTreeQuery(context, queryShardContext, source)) {
logger.info("Star Tree will be used in execution");
};
} catch (IOException e) {
logger.info("Cannot use star-tree");
}

}
}

/**
 * Replaces the parsed query and the aggregations on {@code context} with star-tree
 * equivalents when the request carries aggregations.
 * <p>
 * NOTE(review): this is work in progress - the aggregation name, field and metric are
 * hard-coded placeholders and the predicate map is passed as {@code null}.
 *
 * @param context           search context whose parsed query and aggregations are overwritten
 * @param queryShardContext shard context used to build the star-tree query and the aggregator factory
 * @param source            search source; only {@code source.aggregations()} is inspected here
 * @return {@code false} when {@code source} has no aggregations (nothing is modified),
 *         {@code true} once the star-tree query and aggregations have been set
 * @throws IOException declared by the signature; presumably raised while building the factory - TODO confirm
 */
private boolean setStarTreeQuery(SearchContext context, QueryShardContext queryShardContext, SearchSourceBuilder source) throws IOException {

    // TODO: (finish)
    // 1. Check criteria for star-tree query / aggregation formation
    // 2. Set StarTree Query & Star Tree Aggregator here

    // Without aggregations there is nothing to rewrite.
    if (source.aggregations() == null) {
        return false;
    }

    // Placeholder name shared by the group-by column set and the aggregator factory.
    final String aggregationName = "sum_status";

    // The query is installed first, before the factory is built, so that order is kept.
    ParsedQuery starTreeParsedQuery = queryShardContext.toStarTreeQuery(null, Set.of(aggregationName));
    context.parsedQuery(starTreeParsedQuery);

    StarTreeAggregatorFactory starTreeFactory = new StarTreeAggregatorFactory(
        aggregationName,
        queryShardContext,
        null,
        null,
        null,
        List.of("status"),
        List.of("sum")
    );
    AggregatorFactories starTreeFactories = new AggregatorFactories(new StarTreeAggregatorFactory[] { starTreeFactory });

    context.aggregations(new SearchContextAggregations(starTreeFactories, multiBucketConsumerService.create()));
    return true;
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@
import org.opensearch.core.common.breaker.CircuitBreaker;
import org.opensearch.core.common.breaker.CircuitBreakingException;
import org.opensearch.core.indices.breaker.CircuitBreakerService;
import org.opensearch.index.compositeindex.datacube.MetricStat;
import org.opensearch.index.compositeindex.datacube.startree.aggregators.ValueAggregator;
import org.opensearch.index.compositeindex.datacube.startree.aggregators.ValueAggregatorFactory;
import org.opensearch.search.SearchShardTarget;
import org.opensearch.search.aggregations.support.ValuesSourceConfig;
import org.opensearch.search.internal.SearchContext;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -249,13 +249,13 @@ public boolean test(AggregatorFactory o) {
}
};

private AggregatorFactory[] factories;
protected AggregatorFactory[] factories;

public static Builder builder() {
return new Builder();
}

private AggregatorFactories(AggregatorFactory[] factories) {
public AggregatorFactories(AggregatorFactory[] factories) {
this.factories = factories;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,266 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.search.aggregations.startree;

import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.search.aggregations.Aggregation;
import org.opensearch.search.aggregations.InternalAggregation;
import org.opensearch.search.aggregations.InternalAggregations;
import org.opensearch.search.aggregations.InternalMultiBucketAggregation;
import org.opensearch.search.aggregations.support.CoreValuesSourceType;
import org.opensearch.search.aggregations.support.ValueType;
import org.opensearch.search.aggregations.support.ValuesSourceType;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;

public class InternalStarTree<B extends InternalStarTree.Bucket, R extends InternalStarTree<B, R>> extends InternalMultiBucketAggregation<
R,
B> {
static final InternalStarTree.Factory FACTORY = new InternalStarTree.Factory();

public static class Bucket extends InternalMultiBucketAggregation.InternalBucket {
public long sum;
public InternalAggregations aggregations;
private final String key;

public Bucket(String key, long sum, InternalAggregations aggregations) {
this.key = key;
this.sum = sum;
this.aggregations = aggregations;
}

@Override
public String getKey() {
return getKeyAsString();
}

@Override
public String getKeyAsString() {
return key;
}

@Override
public long getDocCount() {
return sum;
}

@Override
public InternalAggregations getAggregations() {
return aggregations;
}

protected InternalStarTree.Factory<? extends InternalStarTree.Bucket, ?> getFactory() {
return FACTORY;
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(Aggregation.CommonFields.KEY.getPreferredName(), key);
// TODO : this is hack ( we are mapping bucket.noofdocs to sum )
builder.field("SUM", sum);
aggregations.toXContentInternal(builder, params);
builder.endObject();
return builder;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(key);
out.writeVLong(sum);
aggregations.writeTo(out);
}

@Override
public boolean equals(Object other) {
if (this == other) {
return true;
}
if (other == null || getClass() != other.getClass()) {
return false;
}
InternalStarTree.Bucket that = (InternalStarTree.Bucket) other;
return Objects.equals(sum, that.sum) && Objects.equals(aggregations, that.aggregations) && Objects.equals(key, that.key);
}

@Override
public int hashCode() {
return Objects.hash(getClass(), sum, aggregations, key);
}
}

public static class Factory<B extends Bucket, R extends InternalStarTree<B, R>> {
public ValuesSourceType getValueSourceType() {
return CoreValuesSourceType.NUMERIC;
}

public ValueType getValueType() {
return ValueType.NUMERIC;
}

@SuppressWarnings("unchecked")
public R create(String name, List<B> ranges, Map<String, Object> metadata) {
return (R) new InternalStarTree<B, R>(name, ranges, metadata);
}

@SuppressWarnings("unchecked")
public B createBucket(String key, long docCount, InternalAggregations aggregations) {
return (B) new InternalStarTree.Bucket(key, docCount, aggregations);
}

@SuppressWarnings("unchecked")
public R create(List<B> ranges, R prototype) {
return (R) new InternalStarTree<B, R>(prototype.name, ranges, prototype.metadata);
}

@SuppressWarnings("unchecked")
public B createBucket(InternalAggregations aggregations, B prototype) {
// TODO : prototype.getDocCount() -- is mapped to sum - change this
return (B) new InternalStarTree.Bucket(prototype.getKey(), prototype.getDocCount(), aggregations);
}
}

public InternalStarTree.Factory<B, R> getFactory() {
return FACTORY;
}

private final List<B> ranges;

public InternalStarTree(String name, List<B> ranges, Map<String, Object> metadata) {
super(name, metadata);
this.ranges = ranges;
}

/**
* Read from a stream.
*/
public InternalStarTree(StreamInput in) throws IOException {
super(in);
int size = in.readVInt();
List<B> ranges = new ArrayList<>(size);
for (int i = 0; i < size; i++) {
String key = in.readString();
ranges.add(getFactory().createBucket(key, in.readVLong(), InternalAggregations.readFrom(in)));
}
this.ranges = ranges;
}

@Override
protected void doWriteTo(StreamOutput out) throws IOException {
out.writeVInt(ranges.size());
for (B bucket : ranges) {
bucket.writeTo(out);
}
}

@Override
public String getWriteableName() {
return "startree";
}

@Override
public List<B> getBuckets() {
return ranges;
}

public R create(List<B> buckets) {
return getFactory().create(buckets, (R) this);
}

@Override
public B createBucket(InternalAggregations aggregations, B prototype) {
return getFactory().createBucket(aggregations, prototype);
}

@Override
public InternalAggregation reduce(List<InternalAggregation> aggregations, ReduceContext reduceContext) {
Map<String, List<B>> bucketsMap = new HashMap<>();

for (InternalAggregation aggregation : aggregations) {
InternalStarTree<B, R> filters = (InternalStarTree<B, R>) aggregation;
int i = 0;
for (B bucket : filters.ranges) {
String key = bucket.getKey();
List<B> sameRangeList = bucketsMap.get(key);
if (sameRangeList == null) {
sameRangeList = new ArrayList<>(aggregations.size());
bucketsMap.put(key, sameRangeList);
}
sameRangeList.add(bucket);
}
}

ArrayList<B> reducedBuckets = new ArrayList<>(bucketsMap.size());

for (List<B> sameRangeList : bucketsMap.values()) {
B reducedBucket = reduceBucket(sameRangeList, reduceContext);
if (reducedBucket.getDocCount() >= 1) {
reducedBuckets.add(reducedBucket);
}
}
reduceContext.consumeBucketsAndMaybeBreak(reducedBuckets.size());
Collections.sort(reducedBuckets, Comparator.comparing(InternalStarTree.Bucket::getKey));

return getFactory().create(name, reducedBuckets, getMetadata());
}

@Override
protected B reduceBucket(List<B> buckets, ReduceContext context) {
assert buckets.size() > 0;

B reduced = null;
List<InternalAggregations> aggregationsList = new ArrayList<>(buckets.size());
for (B bucket : buckets) {
if (reduced == null) {
reduced = (B) new Bucket(bucket.getKey(), bucket.getDocCount(), bucket.getAggregations());
} else {
reduced.sum += bucket.sum;
}
aggregationsList.add(bucket.getAggregations());
}
reduced.aggregations = InternalAggregations.reduce(aggregationsList, context);
return reduced;
}

@Override
public XContentBuilder doXContentBody(XContentBuilder builder, Params params) throws IOException {
builder.startArray(CommonFields.BUCKETS.getPreferredName());

for (B range : ranges) {
range.toXContent(builder, params);
}
builder.endArray();
return builder;
}

@Override
public int hashCode() {
return Objects.hash(super.hashCode(), ranges);
}

@Override
public boolean equals(Object obj) {
if (this == obj) return true;
if (obj == null || getClass() != obj.getClass()) return false;
if (super.equals(obj) == false) return false;

InternalStarTree<?, ?> that = (InternalStarTree<?, ?>) obj;
return Objects.equals(ranges, that.ranges);
}

}
Loading

0 comments on commit a138757

Please sign in to comment.