Skip to content

Commit

Permalink
Star Tree Search request/response changes
Browse files Browse the repository at this point in the history
Signed-off-by: Sandesh Kumar <sandeshkr419@gmail.com>
  • Loading branch information
sandeshkr419 committed Aug 25, 2024
1 parent b99c73a commit d812134
Show file tree
Hide file tree
Showing 26 changed files with 1,195 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ public List<SequentialDocValuesIterator> getMetricReaders(SegmentWriteState stat
continue;
}
FieldInfo metricFieldInfo = state.fieldInfos.fieldInfo(metric.getField());
// if (metricStat != MetricStat.COUNT) {
// if (metricStat != MetricStat.VALUE_COUNT) {
if (metricFieldInfo == null) {
metricFieldInfo = StarTreeUtils.getFieldInfo(metric.getField(), 1);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,12 @@
import org.opensearch.index.IndexSortConfig;
import org.opensearch.index.analysis.IndexAnalyzers;
import org.opensearch.index.cache.bitset.BitsetFilterCache;
import org.opensearch.index.codec.composite.CompositeIndexFieldInfo;
import org.opensearch.index.compositeindex.datacube.Dimension;
import org.opensearch.index.compositeindex.datacube.Metric;
import org.opensearch.index.compositeindex.datacube.MetricStat;
import org.opensearch.index.fielddata.IndexFieldData;
import org.opensearch.index.mapper.CompositeDataCubeFieldType;
import org.opensearch.index.mapper.ContentPath;
import org.opensearch.index.mapper.DerivedFieldResolver;
import org.opensearch.index.mapper.DerivedFieldResolverFactory;
Expand All @@ -73,12 +78,22 @@
import org.opensearch.script.ScriptContext;
import org.opensearch.script.ScriptFactory;
import org.opensearch.script.ScriptService;
import org.opensearch.search.aggregations.AggregatorFactory;
import org.opensearch.search.aggregations.metrics.AvgAggregatorFactory;
import org.opensearch.search.aggregations.metrics.MaxAggregatorFactory;
import org.opensearch.search.aggregations.metrics.MinAggregatorFactory;
import org.opensearch.search.aggregations.metrics.SumAggregatorFactory;
import org.opensearch.search.aggregations.metrics.ValueCountAggregatorFactory;
import org.opensearch.search.aggregations.support.AggregationUsageService;
import org.opensearch.search.aggregations.support.ValuesSourceAggregatorFactory;
import org.opensearch.search.aggregations.support.ValuesSourceRegistry;
import org.opensearch.search.lookup.SearchLookup;
import org.opensearch.search.startree.OriginalOrStarTreeQuery;
import org.opensearch.search.startree.StarTreeQuery;
import org.opensearch.transport.RemoteClusterAware;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
Expand All @@ -89,6 +104,7 @@
import java.util.function.LongSupplier;
import java.util.function.Predicate;
import java.util.function.Supplier;
import java.util.stream.Collectors;

import static java.util.Collections.emptyList;
import static java.util.Collections.emptyMap;
Expand Down Expand Up @@ -522,6 +538,80 @@ private ParsedQuery toQuery(QueryBuilder queryBuilder, CheckedFunction<QueryBuil
}
}

public ParsedQuery toStarTreeQuery(
CompositeIndexFieldInfo starTree,
CompositeDataCubeFieldType compositeIndexFieldInfo,
QueryBuilder queryBuilder,
Query query
) {
Map<String, List<Predicate<Long>>> predicateMap;

if (queryBuilder == null) {
predicateMap = null;
} else if (queryBuilder instanceof TermQueryBuilder) {
List<String> supportedDimensions = compositeIndexFieldInfo.getDimensions()
.stream()
.map(Dimension::getField)
.collect(Collectors.toList());
predicateMap = getStarTreePredicates(queryBuilder, supportedDimensions);
} else {
return null;
}

StarTreeQuery starTreeQuery = new StarTreeQuery(starTree, predicateMap);
OriginalOrStarTreeQuery originalOrStarTreeQuery = new OriginalOrStarTreeQuery(starTreeQuery, query);
return new ParsedQuery(originalOrStarTreeQuery);
}

/**
* Parse query body to star-tree predicates
* @param queryBuilder
* @return predicates to match
*/
private Map<String, List<Predicate<Long>>> getStarTreePredicates(QueryBuilder queryBuilder, List<String> supportedDimensions) {
TermQueryBuilder tq = (TermQueryBuilder) queryBuilder;
String field = tq.fieldName();
if (supportedDimensions.contains(field) == false) {
throw new IllegalArgumentException("unsupported field in star-tree");
}
long inputQueryVal = Long.parseLong(tq.value().toString());

// Get or create the list of predicates for the given field
Map<String, List<Predicate<Long>>> predicateMap = new HashMap<>();
List<Predicate<Long>> predicates = predicateMap.getOrDefault(field, new ArrayList<>());

// Create a predicate to match the input query value
Predicate<Long> predicate = dimVal -> dimVal == inputQueryVal;
predicates.add(predicate);

// Put the predicates list back into the map
predicateMap.put(field, predicates);
return predicateMap;
}

public boolean validateStarTreeMetricSuport(CompositeDataCubeFieldType compositeIndexFieldInfo, AggregatorFactory aggregatorFactory) {
String field;
Map<String, List<MetricStat>> supportedMetrics = compositeIndexFieldInfo.getMetrics()
.stream()
.collect(Collectors.toMap(Metric::getField, Metric::getMetrics));

// Map to associate supported AggregatorFactory classes with their corresponding MetricStat
Map<Class<? extends ValuesSourceAggregatorFactory>, MetricStat> aggregatorStatMap = Map.of(
SumAggregatorFactory.class, MetricStat.SUM,
MaxAggregatorFactory.class, MetricStat.MAX,
MinAggregatorFactory.class, MetricStat.MIN,
ValueCountAggregatorFactory.class, MetricStat.VALUE_COUNT,
AvgAggregatorFactory.class, MetricStat.AVG
);

MetricStat metricStat = aggregatorStatMap.get(aggregatorFactory.getClass());
if (metricStat != null) {
field = ((ValuesSourceAggregatorFactory)aggregatorFactory).getField();
return supportedMetrics.containsKey(field) && supportedMetrics.get(field).contains(metricStat);
}
return false;
}

public Index index() {
return indexSettings.getIndex();
}
Expand Down
68 changes: 63 additions & 5 deletions server/src/main/java/org/opensearch/search/SearchService.java
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,16 @@
import org.opensearch.index.IndexNotFoundException;
import org.opensearch.index.IndexService;
import org.opensearch.index.IndexSettings;
import org.opensearch.index.codec.composite.CompositeIndexFieldInfo;
import org.opensearch.index.engine.Engine;
import org.opensearch.index.mapper.CompositeDataCubeFieldType;
import org.opensearch.index.mapper.DerivedFieldResolver;
import org.opensearch.index.mapper.DerivedFieldResolverFactory;
import org.opensearch.index.mapper.StarTreeMapper;
import org.opensearch.index.query.InnerHitContextBuilder;
import org.opensearch.index.query.MatchAllQueryBuilder;
import org.opensearch.index.query.MatchNoneQueryBuilder;
import org.opensearch.index.query.ParsedQuery;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryRewriteContext;
import org.opensearch.index.query.QueryShardContext;
Expand All @@ -97,11 +101,13 @@
import org.opensearch.script.ScriptService;
import org.opensearch.search.aggregations.AggregationInitializationException;
import org.opensearch.search.aggregations.AggregatorFactories;
import org.opensearch.search.aggregations.AggregatorFactory;
import org.opensearch.search.aggregations.InternalAggregation;
import org.opensearch.search.aggregations.InternalAggregation.ReduceContext;
import org.opensearch.search.aggregations.MultiBucketConsumerService;
import org.opensearch.search.aggregations.SearchContextAggregations;
import org.opensearch.search.aggregations.pipeline.PipelineAggregator.PipelineTree;
import org.opensearch.search.aggregations.support.ValuesSourceAggregatorFactory;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.search.collapse.CollapseContext;
import org.opensearch.search.dfs.DfsPhase;
Expand Down Expand Up @@ -162,6 +168,7 @@
import static org.opensearch.common.unit.TimeValue.timeValueHours;
import static org.opensearch.common.unit.TimeValue.timeValueMillis;
import static org.opensearch.common.unit.TimeValue.timeValueMinutes;
import static org.opensearch.search.internal.SearchContext.TRACK_TOTAL_HITS_DISABLED;

/**
* The main search service
Expand Down Expand Up @@ -1314,6 +1321,10 @@ private void parseSource(DefaultSearchContext context, SearchSourceBuilder sourc
context.evaluateRequestShouldUseConcurrentSearch();
return;
}
// Can be marked false for majority cases for which star-tree cannot be used
// As we increment the cases where star-tree can be used, this can be set back to true
boolean canUseStarTree = context.mapperService().isCompositeIndexPresent();

SearchShardTarget shardTarget = context.shardTarget();
QueryShardContext queryShardContext = context.getQueryShardContext();
context.from(source.from());
Expand All @@ -1324,10 +1335,12 @@ private void parseSource(DefaultSearchContext context, SearchSourceBuilder sourc
context.parsedQuery(queryShardContext.toQuery(source.query()));
}
if (source.postFilter() != null) {
canUseStarTree = false;
InnerHitContextBuilder.extractInnerHits(source.postFilter(), innerHitBuilders);
context.parsedPostFilter(queryShardContext.toQuery(source.postFilter()));
}
if (innerHitBuilders.size() > 0) {
if (!innerHitBuilders.isEmpty()) {
canUseStarTree = false;
for (Map.Entry<String, InnerHitContextBuilder> entry : innerHitBuilders.entrySet()) {
try {
entry.getValue().build(context, context.innerHits());
Expand All @@ -1337,11 +1350,10 @@ private void parseSource(DefaultSearchContext context, SearchSourceBuilder sourc
}
}
if (source.sorts() != null) {
canUseStarTree = false;
try {
Optional<SortAndFormats> optionalSort = SortBuilder.buildSort(source.sorts(), context.getQueryShardContext());
if (optionalSort.isPresent()) {
context.sort(optionalSort.get());
}
optionalSort.ifPresent(context::sort);
} catch (IOException e) {
throw new SearchException(shardTarget, "failed to create sort elements", e);
}
Expand All @@ -1355,8 +1367,10 @@ private void parseSource(DefaultSearchContext context, SearchSourceBuilder sourc
}
if (source.trackTotalHitsUpTo() != null) {
context.trackTotalHitsUpTo(source.trackTotalHitsUpTo());
canUseStarTree = canUseStarTree && (source.trackTotalHitsUpTo() == TRACK_TOTAL_HITS_DISABLED);
}
if (source.minScore() != null) {
canUseStarTree = false;
context.minimumScore(source.minScore());
}
if (source.timeout() != null) {
Expand Down Expand Up @@ -1496,6 +1510,50 @@ private void parseSource(DefaultSearchContext context, SearchSourceBuilder sourc
if (source.profile()) {
context.setProfilers(new Profilers(context.searcher(), context.shouldUseConcurrentSearch()));
}

if (canUseStarTree) {
try {
setStarTreeQuery(context, queryShardContext, source);
logger.debug("can use star tree");
} catch (IOException e) {
logger.debug("not using star tree");
}
}
}

private boolean setStarTreeQuery(SearchContext context, QueryShardContext queryShardContext, SearchSourceBuilder source)
throws IOException {

if (source.aggregations() == null) {
return false;
}

// Current implementation assumes only single star-tree is supported
CompositeDataCubeFieldType compositeMappedFieldType = (StarTreeMapper.StarTreeFieldType) context.mapperService()
.getCompositeFieldTypes()
.iterator()
.next();
CompositeIndexFieldInfo starTree = new CompositeIndexFieldInfo(
compositeMappedFieldType.name(),
compositeMappedFieldType.getCompositeIndexType()
);

ParsedQuery newParsedQuery = queryShardContext.toStarTreeQuery(starTree, compositeMappedFieldType, source.query(), context.query());
if (newParsedQuery == null) {
return false;
}

for (AggregatorFactory aggregatorFactory : context.aggregations().factories().getFactories()) {
if (!(aggregatorFactory instanceof ValuesSourceAggregatorFactory
&& aggregatorFactory.getSubFactories().getFactories().length == 0)) {
return false;
}
if (queryShardContext.validateStarTreeMetricSuport(compositeMappedFieldType, aggregatorFactory) == false) {
return false;
}
}
context.parsedQuery(newParsedQuery);
return true;
}

/**
Expand Down Expand Up @@ -1655,7 +1713,7 @@ public static boolean canMatchSearchAfter(
&& minMax != null
&& primarySortField != null
&& primarySortField.missing() == null
&& Objects.equals(trackTotalHitsUpto, SearchContext.TRACK_TOTAL_HITS_DISABLED)) {
&& Objects.equals(trackTotalHitsUpto, TRACK_TOTAL_HITS_DISABLED)) {
final Object searchAfterPrimary = searchAfter.fields[0];
if (primarySortField.order() == SortOrder.DESC) {
if (minMax.compareMin(searchAfterPrimary) > 0) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ public static Builder builder() {
return new Builder();
}

private AggregatorFactories(AggregatorFactory[] factories) {
public AggregatorFactories(AggregatorFactory[] factories) {
this.factories = factories;
}

Expand Down Expand Up @@ -661,4 +661,8 @@ public PipelineTree buildPipelineTree() {
return new PipelineTree(subTrees, aggregators);
}
}

public AggregatorFactory[] getFactories() {
return factories;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -127,4 +127,8 @@ protected boolean supportsConcurrentSegmentSearch() {
public boolean evaluateChildFactories() {
return factories.allFactoriesSupportConcurrentSearch();
}

public AggregatorFactories getSubFactories() {
return factories;
}
}
Loading

0 comments on commit d812134

Please sign in to comment.