Skip to content

Commit

Permalink
Star stree request/response changes
Browse files Browse the repository at this point in the history
Signed-off-by: Sandesh Kumar <sandeshkr419@gmail.com>
  • Loading branch information
sandeshkr419 committed Aug 13, 2024
1 parent 781a2d4 commit da0056b
Show file tree
Hide file tree
Showing 14 changed files with 712 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ public class FeatureFlags {
* aggregations.
*/
public static final String STAR_TREE_INDEX = "opensearch.experimental.feature.composite_index.star_tree.enabled";
public static final Setting<Boolean> STAR_TREE_INDEX_SETTING = Setting.boolSetting(STAR_TREE_INDEX, false, Property.NodeScope);
public static final Setting<Boolean> STAR_TREE_INDEX_SETTING = Setting.boolSetting(STAR_TREE_INDEX, true, Property.NodeScope);

private static final List<Setting<Boolean>> ALL_FEATURE_FLAG_SETTINGS = List.of(
REMOTE_STORE_MIGRATION_EXPERIMENTAL_SETTING,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,11 @@
import org.opensearch.index.IndexSortConfig;
import org.opensearch.index.analysis.IndexAnalyzers;
import org.opensearch.index.cache.bitset.BitsetFilterCache;
import org.opensearch.index.codec.composite.CompositeIndexFieldInfo;
import org.opensearch.index.compositeindex.datacube.Metric;
import org.opensearch.index.compositeindex.datacube.MetricStat;
import org.opensearch.index.fielddata.IndexFieldData;
import org.opensearch.index.mapper.CompositeDataCubeFieldType;
import org.opensearch.index.mapper.ContentPath;
import org.opensearch.index.mapper.DerivedFieldResolver;
import org.opensearch.index.mapper.DerivedFieldResolverFactory;
Expand All @@ -73,12 +77,17 @@
import org.opensearch.script.ScriptContext;
import org.opensearch.script.ScriptFactory;
import org.opensearch.script.ScriptService;
import org.opensearch.search.aggregations.AggregatorFactory;
import org.opensearch.search.aggregations.metrics.SumAggregatorFactory;
import org.opensearch.search.aggregations.support.AggregationUsageService;
import org.opensearch.search.aggregations.support.ValuesSourceRegistry;
import org.opensearch.search.lookup.SearchLookup;
import org.opensearch.search.startree.OriginalOrStarTreeQuery;
import org.opensearch.search.startree.StarTreeQuery;
import org.opensearch.transport.RemoteClusterAware;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
Expand All @@ -89,6 +98,7 @@
import java.util.function.LongSupplier;
import java.util.function.Predicate;
import java.util.function.Supplier;
import java.util.stream.Collectors;

import static java.util.Collections.emptyList;
import static java.util.Collections.emptyMap;
Expand Down Expand Up @@ -522,6 +532,66 @@ private ParsedQuery toQuery(QueryBuilder queryBuilder, CheckedFunction<QueryBuil
}
}

public ParsedQuery toStarTreeQuery(CompositeIndexFieldInfo starTree, QueryBuilder queryBuilder, Query query) {
Map<String, List<Predicate<Long>>> predicateMap = getStarTreePredicates(queryBuilder);
StarTreeQuery starTreeQuery = new StarTreeQuery(starTree, predicateMap, null);
OriginalOrStarTreeQuery originalOrStarTreeQuery = new OriginalOrStarTreeQuery(starTreeQuery, query);
return new ParsedQuery(originalOrStarTreeQuery);
}

/**
* Parse query body to star-tree predicates
* @param queryBuilder
* @return
*/
private Map<String, List<Predicate<Long>>> getStarTreePredicates(QueryBuilder queryBuilder) {
// Assuming the following variables have been initialized:
Map<String, List<Predicate<Long>>> predicateMap = new HashMap<>();

// Check if the query builder is an instance of TermQueryBuilder
if (queryBuilder instanceof TermQueryBuilder) {
TermQueryBuilder tq = (TermQueryBuilder) queryBuilder;
String field = tq.fieldName();
long inputQueryVal = Long.parseLong(tq.value().toString());

// Get or create the list of predicates for the given field
List<Predicate<Long>> predicates = predicateMap.getOrDefault(field, new ArrayList<>());

// Create a predicate to match the input query value
Predicate<Long> predicate = dimVal -> dimVal == inputQueryVal;
predicates.add(predicate);

// Put the predicates list back into the map
predicateMap.put(field, predicates);
} else {
throw new IllegalArgumentException("The query is not a term query");
}
return predicateMap;

}

public boolean validateStarTreeMetricSuport(CompositeDataCubeFieldType compositeIndexFieldInfo, AggregatorFactory aggregatorFactory) {
String field = null;
Map<String, List<MetricStat>> supportedMetrics = compositeIndexFieldInfo.getMetrics()
.stream()
.collect(Collectors.toMap(Metric::getField, Metric::getMetrics));

// Existing support only for MetricAggregators without sub-aggregations
if (aggregatorFactory.getSubFactories().getFactories().length != 0) {
return false;
}

// TODO: increment supported aggregation type
if (aggregatorFactory instanceof SumAggregatorFactory) {
field = ((SumAggregatorFactory) aggregatorFactory).getField();
if (supportedMetrics.containsKey(field) && supportedMetrics.get(field).contains(MetricStat.SUM)) {
return true;
}
}

return false;
}

public Index index() {
return indexSettings.getIndex();
}
Expand Down
55 changes: 52 additions & 3 deletions server/src/main/java/org/opensearch/search/SearchService.java
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,16 @@
import org.opensearch.index.IndexNotFoundException;
import org.opensearch.index.IndexService;
import org.opensearch.index.IndexSettings;
import org.opensearch.index.codec.composite.CompositeIndexFieldInfo;
import org.opensearch.index.engine.Engine;
import org.opensearch.index.mapper.CompositeDataCubeFieldType;
import org.opensearch.index.mapper.DerivedFieldResolver;
import org.opensearch.index.mapper.DerivedFieldResolverFactory;
import org.opensearch.index.mapper.StarTreeMapper;
import org.opensearch.index.query.InnerHitContextBuilder;
import org.opensearch.index.query.MatchAllQueryBuilder;
import org.opensearch.index.query.MatchNoneQueryBuilder;
import org.opensearch.index.query.ParsedQuery;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryRewriteContext;
import org.opensearch.index.query.QueryShardContext;
Expand All @@ -97,11 +101,13 @@
import org.opensearch.script.ScriptService;
import org.opensearch.search.aggregations.AggregationInitializationException;
import org.opensearch.search.aggregations.AggregatorFactories;
import org.opensearch.search.aggregations.AggregatorFactory;
import org.opensearch.search.aggregations.InternalAggregation;
import org.opensearch.search.aggregations.InternalAggregation.ReduceContext;
import org.opensearch.search.aggregations.MultiBucketConsumerService;
import org.opensearch.search.aggregations.SearchContextAggregations;
import org.opensearch.search.aggregations.pipeline.PipelineAggregator.PipelineTree;
import org.opensearch.search.aggregations.support.ValuesSourceAggregatorFactory;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.search.collapse.CollapseContext;
import org.opensearch.search.dfs.DfsPhase;
Expand Down Expand Up @@ -1314,6 +1320,11 @@ private void parseSource(DefaultSearchContext context, SearchSourceBuilder sourc
context.evaluateRequestShouldUseConcurrentSearch();
return;
}
// Can be marked false for majority cases for which star-tree cannot be used
// Will save checking the criteria later and we can have a limit on what search requests are supported
// As we increment the cases where star-tree can be used, this can be set back to true
boolean canUseStarTree = context.mapperService().isCompositeIndexPresent();

SearchShardTarget shardTarget = context.shardTarget();
QueryShardContext queryShardContext = context.getQueryShardContext();
context.from(source.from());
Expand All @@ -1339,9 +1350,7 @@ private void parseSource(DefaultSearchContext context, SearchSourceBuilder sourc
if (source.sorts() != null) {
try {
Optional<SortAndFormats> optionalSort = SortBuilder.buildSort(source.sorts(), context.getQueryShardContext());
if (optionalSort.isPresent()) {
context.sort(optionalSort.get());
}
optionalSort.ifPresent(context::sort);
} catch (IOException e) {
throw new SearchException(shardTarget, "failed to create sort elements", e);
}
Expand Down Expand Up @@ -1496,6 +1505,46 @@ private void parseSource(DefaultSearchContext context, SearchSourceBuilder sourc
if (source.profile()) {
context.setProfilers(new Profilers(context.searcher(), context.shouldUseConcurrentSearch()));
}

if (canUseStarTree) {
try {
setStarTreeQuery(context, queryShardContext, source);
logger.info("using star tree");
} catch (IOException e) {
logger.info("not using star tree");
}
}
}

private boolean setStarTreeQuery(SearchContext context, QueryShardContext queryShardContext, SearchSourceBuilder source)
throws IOException {

if (source.aggregations() == null) {
return false;
}

// TODO: Support for multiple startrees
CompositeDataCubeFieldType compositeMappedFieldType = (StarTreeMapper.StarTreeFieldType) context.mapperService()
.getCompositeFieldTypes()
.iterator()
.next();
CompositeIndexFieldInfo starTree = new CompositeIndexFieldInfo(
compositeMappedFieldType.name(),
compositeMappedFieldType.getCompositeIndexType()
);

ParsedQuery parsedQuery = queryShardContext.toStarTreeQuery(starTree, source.query(), context.query());
AggregatorFactory aggregatorFactory = context.aggregations().factories().getFactories()[0];
if (!(aggregatorFactory instanceof ValuesSourceAggregatorFactory
&& aggregatorFactory.getSubFactories().getFactories().length == 0)) {
return false;
}

if (queryShardContext.validateStarTreeMetricSuport(compositeMappedFieldType, aggregatorFactory)) {
context.parsedQuery(parsedQuery);
}

return true;
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ public static Builder builder() {
return new Builder();
}

private AggregatorFactories(AggregatorFactory[] factories) {
public AggregatorFactories(AggregatorFactory[] factories) {
this.factories = factories;
}

Expand Down Expand Up @@ -661,4 +661,8 @@ public PipelineTree buildPipelineTree() {
return new PipelineTree(subTrees, aggregators);
}
}

public AggregatorFactory[] getFactories() {
return factories;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -127,4 +127,8 @@ protected boolean supportsConcurrentSegmentSearch() {
public boolean evaluateChildFactories() {
return factories.allFactoriesSupportConcurrentSearch();
}

public AggregatorFactories getSubFactories() {
return factories;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,20 @@

package org.opensearch.search.aggregations.metrics;

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.SegmentReader;
import org.opensearch.common.lucene.Lucene;
import org.opensearch.common.util.Comparators;
import org.opensearch.index.codec.composite.CompositeIndexFieldInfo;
import org.opensearch.index.codec.composite.CompositeIndexReader;
import org.opensearch.index.codec.composite.datacube.startree.StarTreeValues;
import org.opensearch.search.aggregations.Aggregator;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.sort.SortOrder;

import java.io.IOException;
import java.util.Map;
import java.util.concurrent.atomic.AtomicReference;

/**
* Base class to aggregate all docs into a single numeric metric value.
Expand Down Expand Up @@ -107,4 +114,14 @@ public BucketComparator bucketComparator(String key, SortOrder order) {
return (lhs, rhs) -> Comparators.compareDiscardNaN(metric(key, lhs), metric(key, rhs), order == SortOrder.ASC);
}
}

protected StarTreeValues getStarTreeValues(LeafReaderContext ctx, CompositeIndexFieldInfo starTree) throws IOException {
SegmentReader reader = Lucene.segmentReader(ctx.reader());
if (!(reader.getDocValuesReader() instanceof CompositeIndexReader)) return null;
CompositeIndexReader starTreeDocValuesReader = (CompositeIndexReader) reader.getDocValuesReader();
StarTreeValues values = (StarTreeValues) starTreeDocValuesReader.getCompositeIndexValues(starTree);
final AtomicReference<StarTreeValues> aggrVal = new AtomicReference<>(null);

return values;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
import org.opensearch.common.lease.Releasables;
import org.opensearch.common.util.BigArrays;
import org.opensearch.common.util.DoubleArray;
import org.opensearch.index.codec.composite.CompositeIndexFieldInfo;
import org.opensearch.index.codec.composite.datacube.startree.StarTreeValues;
import org.opensearch.index.fielddata.SortedNumericDoubleValues;
import org.opensearch.search.DocValueFormat;
import org.opensearch.search.aggregations.Aggregator;
Expand All @@ -45,6 +47,8 @@
import org.opensearch.search.aggregations.support.ValuesSource;
import org.opensearch.search.aggregations.support.ValuesSourceConfig;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.startree.OriginalOrStarTreeQuery;
import org.opensearch.search.startree.StarTreeQuery;

import java.io.IOException;
import java.util.Map;
Expand All @@ -56,13 +60,13 @@
*/
public class SumAggregator extends NumericMetricsAggregator.SingleValue {

private final ValuesSource.Numeric valuesSource;
private final DocValueFormat format;
protected final ValuesSource.Numeric valuesSource;
protected final DocValueFormat format;

private DoubleArray sums;
private DoubleArray compensations;
protected DoubleArray sums;
protected DoubleArray compensations;

SumAggregator(
public SumAggregator(
String name,
ValuesSourceConfig valuesSourceConfig,
SearchContext context,
Expand All @@ -86,6 +90,14 @@ public ScoreMode scoreMode() {

@Override
public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, final LeafBucketCollector sub) throws IOException {
if (context.query() instanceof OriginalOrStarTreeQuery && ((OriginalOrStarTreeQuery) context.query()).isStarTreeUsed()) {
StarTreeQuery starTreeQuery = ((OriginalOrStarTreeQuery) context.query()).getStarTreeQuery();
return getStarTreeLeafCollector(ctx, sub, starTreeQuery.getStarTree());
}
return getDefaultLeafCollector(ctx, sub);
}

private LeafBucketCollector getDefaultLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException {
if (valuesSource == null) {
return LeafBucketCollector.NO_OP_COLLECTOR;
}
Expand Down Expand Up @@ -118,6 +130,28 @@ public void collect(int doc, long bucket) throws IOException {
};
}

private LeafBucketCollector getStarTreeLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub, CompositeIndexFieldInfo starTree)
throws IOException {
final BigArrays bigArrays = context.bigArrays();
final CompensatedSum kahanSummation = new CompensatedSum(0, 0);

StarTreeValues starTreeValues = getStarTreeValues(ctx, starTree);

//
String fieldName = ((ValuesSource.Numeric.FieldData) valuesSource).getIndexFieldName();

return new LeafBucketCollectorBase(sub, starTreeValues) {
@Override
public void collect(int doc, long bucket) throws IOException {
// TODO: Fix the response for collecting star tree sum
sums = bigArrays.grow(sums, bucket + 1);
compensations = bigArrays.grow(compensations, bucket + 1);
compensations.set(bucket, kahanSummation.delta());
sums.set(bucket, kahanSummation.value());
}
};
}

@Override
public double metric(long owningBucketOrd) {
if (valuesSource == null || owningBucketOrd >= sums.size()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
*
* @opensearch.internal
*/
class SumAggregatorFactory extends ValuesSourceAggregatorFactory {
public class SumAggregatorFactory extends ValuesSourceAggregatorFactory {

SumAggregatorFactory(
String name,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -625,6 +625,10 @@ public SortedNumericDocValues longValues(LeafReaderContext context) {
public SortedNumericDoubleValues doubleValues(LeafReaderContext context) {
return indexFieldData.load(context).getDoubleValues();
}

public String getIndexFieldName() {
return indexFieldData.getFieldName();
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,4 +102,16 @@ protected abstract Aggregator doCreateInternal(
public String getStatsSubtype() {
return config.valueSourceType().typeName();
}

public String getField() {
return config.fieldContext().field();
}

public String getAggregationName() {
return name;
}

public ValuesSourceConfig getConfig() {
return config;
}
}
Loading

0 comments on commit da0056b

Please sign in to comment.