From 1ce776ae54ec99a0a601bae8a8ba392170938b92 Mon Sep 17 00:00:00 2001 From: Chaitanya Gohel Date: Tue, 31 Oct 2023 17:28:24 +0530 Subject: [PATCH] Enabling sort optimizatin back for half_float with custom comparators Signed-off-by: Chaitanya Gohel --- CHANGELOG.md | 3 +- .../test/search/260_sort_mixed.yml | 29 +++-- .../test/search/90_search_after.yml | 68 +++++++++++ .../opensearch/search/sort/FieldSortIT.java | 60 ++++++++++ .../fielddata/IndexNumericFieldData.java | 5 +- .../FloatValuesComparatorSource.java | 4 +- .../HalfFloatValuesComparatorSource.java | 55 +++++++++ .../comparators/HalfFloatComparator.java | 111 ++++++++++++++++++ 8 files changed, 324 insertions(+), 11 deletions(-) create mode 100644 server/src/main/java/org/opensearch/index/fielddata/fieldcomparator/HalfFloatValuesComparatorSource.java create mode 100644 server/src/main/java/org/opensearch/index/search/comparators/HalfFloatComparator.java diff --git a/CHANGELOG.md b/CHANGELOG.md index 046693421d07b..f18c807db0fb9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - [AdmissionControl] Added changes for AdmissionControl Interceptor and AdmissionControlService for RateLimiting ([#9286](https://github.com/opensearch-project/OpenSearch/pull/9286)) - GHA to verify checklist items completion in PR descriptions ([#10800](https://github.com/opensearch-project/OpenSearch/pull/10800)) - [Remote cluster state] Restore cluster state version during remote state auto restore ([#10853](https://github.com/opensearch-project/OpenSearch/pull/10853)) +- Add back half_float BKD based sort query optimization ([#11024](https://github.com/opensearch-project/OpenSearch/pull/11024)) ### Dependencies - Bump `log4j-core` from 2.18.0 to 2.19.0 @@ -141,4 +142,4 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Security [Unreleased 3.0]: https://github.com/opensearch-project/OpenSearch/compare/2.x...HEAD -[Unreleased 2.x]: https://github.com/opensearch-project/OpenSearch/compare/2.12...2.x \ No newline at end of file +[Unreleased 2.x]: https://github.com/opensearch-project/OpenSearch/compare/2.12...2.x diff --git a/rest-api-spec/src/main/resources/rest-api-spec/test/search/260_sort_mixed.yml b/rest-api-spec/src/main/resources/rest-api-spec/test/search/260_sort_mixed.yml index ba2b18eb3b6d0..ac14f9bf7a241 100644 --- a/rest-api-spec/src/main/resources/rest-api-spec/test/search/260_sort_mixed.yml +++ b/rest-api-spec/src/main/resources/rest-api-spec/test/search/260_sort_mixed.yml @@ -20,6 +20,16 @@ properties: counter: type: double + + - do: + indices.create: + index: test_3 + body: + mappings: + properties: + counter: + type: half_float + - do: bulk: refresh: true @@ -33,6 +43,12 @@ - index: _index: test_2 - counter: 184.4 + - index: + _index: test_3 + - counter: 187.4 + - index: + _index: test_3 + - counter: 194.4 - do: search: @@ -40,8 +56,8 @@ rest_total_hits_as_int: true body: sort: [{ counter: desc }] - - match: { hits.total: 3 } - - length: { hits.hits: 3 } + - match: { hits.total: 5 } + - length: { hits.hits: 5 } - match: { hits.hits.0._index: test_2 } - match: { hits.hits.0._source.counter: 1223372036854775800.23 } - match: { hits.hits.0.sort.0: 1223372036854775800.23 } @@ -55,14 +71,13 @@ rest_total_hits_as_int: true body: sort: [{ counter: asc }] - - match: { hits.total: 3 } - - length: { hits.hits: 3 } + - match: { hits.total: 5 } + - length: { hits.hits: 5 } - match: { hits.hits.0._index: test_2 } - match: { hits.hits.0._source.counter: 184.4 } - match: { hits.hits.0.sort.0: 184.4 } - - match: { hits.hits.1._index: test_1 } - - match: { hits.hits.1._source.counter: 223372036854775800 } - - match: { hits.hits.1.sort.0: 223372036854775800 } + - match: { hits.hits.1._index: test_3 } + - match: { hits.hits.1._source.counter: 187.4 } --- "search across indices with mixed long, double and unsigned_long numeric types": diff --git a/rest-api-spec/src/main/resources/rest-api-spec/test/search/90_search_after.yml b/rest-api-spec/src/main/resources/rest-api-spec/test/search/90_search_after.yml index 55e1566656faf..18ef0e49b1288 100644 --- a/rest-api-spec/src/main/resources/rest-api-spec/test/search/90_search_after.yml +++ b/rest-api-spec/src/main/resources/rest-api-spec/test/search/90_search_after.yml @@ -320,3 +320,71 @@ - length: { hits.hits: 1 } - match: { hits.hits.0._index: test } - match: { hits.hits.0._source.population: null } + +--- +"half float": + - do: + indices.create: + index: test + body: + mappings: + properties: + population: + type: half_float + - do: + bulk: + refresh: true + index: test + body: | + {"index":{}} + {"population": 184.4} + {"index":{}} + {"population": 194.4} + + - do: + search: + index: test + rest_total_hits_as_int: true + body: + size: 1 + sort: [ { population: asc } ] + - match: { hits.total: 2 } + - length: { hits.hits: 1 } + - match: { hits.hits.0._index: test } + - match: { hits.hits.0._source.population: 184.4 } + + # search_after with the sort + - do: + search: + index: test + rest_total_hits_as_int: true + body: + size: 1 + sort: [ { population: asc } ] + search_after: [ 184.4 ] + - match: { hits.total: 2 } + - length: { hits.hits: 1 } + - match: { hits.hits.0._index: test } + - match: { hits.hits.0._source.population: 194.4 } + + # search_after with the sort with missing + - do: + bulk: + refresh: true + index: test + body: | + {"index":{}} + {"population": null} + - do: + search: + index: test + rest_total_hits_as_int: true + body: + "size": 5 + "sort": [ { "population": { "order": "asc", "missing": "_last" } } ] + search_after: [ 200 ] # making it out of min/max so only missing value hit is qualified + + - match: { hits.total: 3 } + - length: { hits.hits: 1 } + - match: { hits.hits.0._index: test } + - match: { hits.hits.0._source.population: null } diff --git a/server/src/internalClusterTest/java/org/opensearch/search/sort/FieldSortIT.java b/server/src/internalClusterTest/java/org/opensearch/search/sort/FieldSortIT.java index bee242b933dfd..d4980a64a3977 100644 --- a/server/src/internalClusterTest/java/org/opensearch/search/sort/FieldSortIT.java +++ b/server/src/internalClusterTest/java/org/opensearch/search/sort/FieldSortIT.java @@ -605,6 +605,9 @@ public void testSimpleSorts() throws Exception { .startObject("float_value") .field("type", "float") .endObject() + .startObject("half_float_value") + .field("type", "half_float") + .endObject() .startObject("double_value") .field("type", "double") .endObject() @@ -628,6 +631,7 @@ public void testSimpleSorts() throws Exception { .field("long_value", i) .field("unsigned_long_value", UNSIGNED_LONG_BASE.add(BigInteger.valueOf(10000 * i))) .field("float_value", 0.1 * i) + .field("half_float_value", 0.1 * i) .field("double_value", 0.1 * i) .endObject() ); @@ -794,6 +798,28 @@ public void testSimpleSorts() throws Exception { assertThat(searchResponse.toString(), not(containsString("error"))); + // HALF_FLOAT + size = 1 + random.nextInt(10); + searchResponse = client().prepareSearch().setQuery(matchAllQuery()).setSize(size).addSort("half_float_value", SortOrder.ASC).get(); + + assertHitCount(searchResponse, 10L); + assertThat(searchResponse.getHits().getHits().length, equalTo(size)); + for (int i = 0; i < size; i++) { + assertThat(searchResponse.getHits().getAt(i).getId(), equalTo(Integer.toString(i))); + } + + assertThat(searchResponse.toString(), not(containsString("error"))); + size = 1 + random.nextInt(10); + searchResponse = client().prepareSearch().setQuery(matchAllQuery()).setSize(size).addSort("half_float_value", SortOrder.DESC).get(); + + assertHitCount(searchResponse, 10); + assertThat(searchResponse.getHits().getHits().length, equalTo(size)); + for (int i = 0; i < size; i++) { + assertThat(searchResponse.getHits().getAt(i).getId(), equalTo(Integer.toString(9 - i))); + } + + assertThat(searchResponse.toString(), not(containsString("error"))); + // DOUBLE size = 1 + random.nextInt(10); searchResponse = client().prepareSearch().setQuery(matchAllQuery()).setSize(size).addSort("double_value", SortOrder.ASC).get(); @@ -1330,6 +1356,9 @@ public void testSortMVField() throws Exception { .startObject("float_values") .field("type", "float") .endObject() + .startObject("half_float_values") + .field("type", "float") + .endObject() .startObject("double_values") .field("type", "double") .endObject() @@ -1351,6 +1380,7 @@ public void testSortMVField() throws Exception { .array("short_values", 1, 5, 10, 8) .array("byte_values", 1, 5, 10, 8) .array("float_values", 1f, 5f, 10f, 8f) + .array("half_float_values", 1f, 5f, 10f, 8f) .array("double_values", 1d, 5d, 10d, 8d) .array("string_values", "01", "05", "10", "08") .endObject() @@ -1365,6 +1395,7 @@ public void testSortMVField() throws Exception { .array("short_values", 11, 15, 20, 7) .array("byte_values", 11, 15, 20, 7) .array("float_values", 11f, 15f, 20f, 7f) + .array("half_float_values", 11f, 15f, 20f, 7f) .array("double_values", 11d, 15d, 20d, 7d) .array("string_values", "11", "15", "20", "07") .endObject() @@ -1379,6 +1410,7 @@ public void testSortMVField() throws Exception { .array("short_values", 2, 1, 3, -4) .array("byte_values", 2, 1, 3, -4) .array("float_values", 2f, 1f, 3f, -4f) + .array("half_float_values", 2f, 1f, 3f, -4f) .array("double_values", 2d, 1d, 3d, -4d) .array("string_values", "02", "01", "03", "!4") .endObject() @@ -1585,6 +1617,34 @@ public void testSortMVField() throws Exception { assertThat(searchResponse.getHits().getAt(2).getId(), equalTo(Integer.toString(3))); assertThat(((Number) searchResponse.getHits().getAt(2).getSortValues()[0]).floatValue(), equalTo(3f)); + searchResponse = client().prepareSearch().setQuery(matchAllQuery()).setSize(10).addSort("half_float_values", SortOrder.ASC).get(); + + assertThat(searchResponse.getHits().getTotalHits().value, equalTo(3L)); + assertThat(searchResponse.getHits().getHits().length, equalTo(3)); + + assertThat(searchResponse.getHits().getAt(0).getId(), equalTo(Integer.toString(3))); + assertThat(((Number) searchResponse.getHits().getAt(0).getSortValues()[0]).floatValue(), equalTo(-4f)); + + assertThat(searchResponse.getHits().getAt(1).getId(), equalTo(Integer.toString(1))); + assertThat(((Number) searchResponse.getHits().getAt(1).getSortValues()[0]).floatValue(), equalTo(1f)); + + assertThat(searchResponse.getHits().getAt(2).getId(), equalTo(Integer.toString(2))); + assertThat(((Number) searchResponse.getHits().getAt(2).getSortValues()[0]).floatValue(), equalTo(7f)); + + searchResponse = client().prepareSearch().setQuery(matchAllQuery()).setSize(10).addSort("half_float_values", SortOrder.DESC).get(); + + assertThat(searchResponse.getHits().getTotalHits().value, equalTo(3L)); + assertThat(searchResponse.getHits().getHits().length, equalTo(3)); + + assertThat(searchResponse.getHits().getAt(0).getId(), equalTo(Integer.toString(2))); + assertThat(((Number) searchResponse.getHits().getAt(0).getSortValues()[0]).floatValue(), equalTo(20f)); + + assertThat(searchResponse.getHits().getAt(1).getId(), equalTo(Integer.toString(1))); + assertThat(((Number) searchResponse.getHits().getAt(1).getSortValues()[0]).floatValue(), equalTo(10f)); + + assertThat(searchResponse.getHits().getAt(2).getId(), equalTo(Integer.toString(3))); + assertThat(((Number) searchResponse.getHits().getAt(2).getSortValues()[0]).floatValue(), equalTo(3f)); + searchResponse = client().prepareSearch().setQuery(matchAllQuery()).setSize(10).addSort("double_values", SortOrder.ASC).get(); assertThat(searchResponse.getHits().getTotalHits().value, equalTo(3L)); diff --git a/server/src/main/java/org/opensearch/index/fielddata/IndexNumericFieldData.java b/server/src/main/java/org/opensearch/index/fielddata/IndexNumericFieldData.java index 6fc074fe0de95..b0ff944d014de 100644 --- a/server/src/main/java/org/opensearch/index/fielddata/IndexNumericFieldData.java +++ b/server/src/main/java/org/opensearch/index/fielddata/IndexNumericFieldData.java @@ -42,6 +42,7 @@ import org.opensearch.index.fielddata.IndexFieldData.XFieldComparatorSource.Nested; import org.opensearch.index.fielddata.fieldcomparator.DoubleValuesComparatorSource; import org.opensearch.index.fielddata.fieldcomparator.FloatValuesComparatorSource; +import org.opensearch.index.fielddata.fieldcomparator.HalfFloatValuesComparatorSource; import org.opensearch.index.fielddata.fieldcomparator.IntValuesComparatorSource; import org.opensearch.index.fielddata.fieldcomparator.LongValuesComparatorSource; import org.opensearch.index.fielddata.fieldcomparator.UnsignedLongValuesComparatorSource; @@ -220,6 +221,8 @@ private XFieldComparatorSource comparatorSource( final XFieldComparatorSource source; switch (targetNumericType) { case HALF_FLOAT: + source = new HalfFloatValuesComparatorSource(this, missingValue, sortMode, nested); + break; case FLOAT: source = new FloatValuesComparatorSource(this, missingValue, sortMode, nested); break; @@ -242,7 +245,7 @@ private XFieldComparatorSource comparatorSource( assert !targetNumericType.isFloatingPoint(); source = new IntValuesComparatorSource(this, missingValue, sortMode, nested); } - if (targetNumericType != getNumericType() || getNumericType() == NumericType.HALF_FLOAT) { + if (targetNumericType != getNumericType()) { source.disableSkipping(); // disable skipping logic for cast of sort field } return source; diff --git a/server/src/main/java/org/opensearch/index/fielddata/fieldcomparator/FloatValuesComparatorSource.java b/server/src/main/java/org/opensearch/index/fielddata/fieldcomparator/FloatValuesComparatorSource.java index 04a34cd418520..abf872dd235e5 100644 --- a/server/src/main/java/org/opensearch/index/fielddata/fieldcomparator/FloatValuesComparatorSource.java +++ b/server/src/main/java/org/opensearch/index/fielddata/fieldcomparator/FloatValuesComparatorSource.java @@ -61,7 +61,7 @@ */ public class FloatValuesComparatorSource extends IndexFieldData.XFieldComparatorSource { - private final IndexNumericFieldData indexFieldData; + protected final IndexNumericFieldData indexFieldData; public FloatValuesComparatorSource( IndexNumericFieldData indexFieldData, @@ -78,7 +78,7 @@ public SortField.Type reducedType() { return SortField.Type.FLOAT; } - private NumericDoubleValues getNumericDocValues(LeafReaderContext context, float missingValue) throws IOException { + protected NumericDoubleValues getNumericDocValues(LeafReaderContext context, float missingValue) throws IOException { final SortedNumericDoubleValues values = indexFieldData.load(context).getDoubleValues(); if (nested == null) { return FieldData.replaceMissing(sortMode.select(values), missingValue); diff --git a/server/src/main/java/org/opensearch/index/fielddata/fieldcomparator/HalfFloatValuesComparatorSource.java b/server/src/main/java/org/opensearch/index/fielddata/fieldcomparator/HalfFloatValuesComparatorSource.java new file mode 100644 index 0000000000000..f63c4110bbdda --- /dev/null +++ b/server/src/main/java/org/opensearch/index/fielddata/fieldcomparator/HalfFloatValuesComparatorSource.java @@ -0,0 +1,55 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.index.fielddata.fieldcomparator; + +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.NumericDocValues; +import org.apache.lucene.search.FieldComparator; +import org.apache.lucene.search.LeafFieldComparator; +import org.opensearch.index.fielddata.IndexNumericFieldData; +import org.opensearch.index.search.comparators.HalfFloatComparator; +import org.opensearch.search.MultiValueMode; + +import java.io.IOException; + +/** + * Comparator source for half_float values. + * + * @opensearch.internal + */ +public class HalfFloatValuesComparatorSource extends FloatValuesComparatorSource { + public HalfFloatValuesComparatorSource( + IndexNumericFieldData indexFieldData, + Object missingValue, + MultiValueMode sortMode, + Nested nested + ) { + super(indexFieldData, missingValue, sortMode, nested); + } + + @Override + public FieldComparator newComparator(String fieldname, int numHits, boolean enableSkipping, boolean reversed) { + assert indexFieldData == null || fieldname.equals(indexFieldData.getFieldName()); + + final float fMissingValue = (Float) missingObject(missingValue, reversed); + // NOTE: it's important to pass null as a missing value in the constructor so that + // the comparator doesn't check docsWithField since we replace missing values in select() + return new HalfFloatComparator(numHits, fieldname, null, reversed, enableSkipping && this.enableSkipping) { + @Override + public LeafFieldComparator getLeafComparator(LeafReaderContext context) throws IOException { + return new HalfFloatLeafComparator(context) { + @Override + protected NumericDocValues getNumericDocValues(LeafReaderContext context, String field) throws IOException { + return HalfFloatValuesComparatorSource.this.getNumericDocValues(context, fMissingValue).getRawFloatValues(); + } + }; + } + }; + } +} diff --git a/server/src/main/java/org/opensearch/index/search/comparators/HalfFloatComparator.java b/server/src/main/java/org/opensearch/index/search/comparators/HalfFloatComparator.java new file mode 100644 index 0000000000000..685c0bb9bbf3d --- /dev/null +++ b/server/src/main/java/org/opensearch/index/search/comparators/HalfFloatComparator.java @@ -0,0 +1,111 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.index.search.comparators; + +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.sandbox.document.HalfFloatPoint; +import org.apache.lucene.search.LeafFieldComparator; +import org.apache.lucene.search.comparators.NumericComparator; + +import java.io.IOException; + +/** + * The comparator for unsigned long numeric type. + * Comparator based on {@link Float#compare} for {@code numHits}. This comparator provides a + * skipping functionality – an iterator that can skip over non-competitive documents. + */ +public class HalfFloatComparator extends NumericComparator { + private final float[] values; + protected float topValue; + protected float bottom; + + public HalfFloatComparator(int numHits, String field, Float missingValue, boolean reverse, boolean enableSkipping) { + super(field, missingValue != null ? missingValue : 0.0f, reverse, enableSkipping, HalfFloatPoint.BYTES); + values = new float[numHits]; + } + + @Override + public int compare(int slot1, int slot2) { + return Float.compare(values[slot1], values[slot2]); + } + + @Override + public void setTopValue(Float value) { + super.setTopValue(value); + topValue = value; + } + + @Override + public Float value(int slot) { + return Float.valueOf(values[slot]); + } + + @Override + public LeafFieldComparator getLeafComparator(LeafReaderContext context) throws IOException { + return new HalfFloatLeafComparator(context); + } + + /** Leaf comparator for {@link HalfFloatComparator} that provides skipping functionality */ + public class HalfFloatLeafComparator extends NumericLeafComparator { + + public HalfFloatLeafComparator(LeafReaderContext context) throws IOException { + super(context); + } + + private float getValueForDoc(int doc) throws IOException { + if (docValues.advanceExact(doc)) { + return Float.intBitsToFloat((int) docValues.longValue()); + } else { + return missingValue; + } + } + + @Override + public void setBottom(int slot) throws IOException { + bottom = values[slot]; + super.setBottom(slot); + } + + @Override + public int compareBottom(int doc) throws IOException { + return Float.compare(bottom, getValueForDoc(doc)); + } + + @Override + public int compareTop(int doc) throws IOException { + return Float.compare(topValue, getValueForDoc(doc)); + } + + @Override + public void copy(int slot, int doc) throws IOException { + values[slot] = getValueForDoc(doc); + super.copy(slot, doc); + } + + @Override + protected int compareMissingValueWithBottomValue() { + return Float.compare(missingValue, bottom); + } + + @Override + protected int compareMissingValueWithTopValue() { + return Float.compare(missingValue, topValue); + } + + @Override + protected void encodeBottom(byte[] packedValue) { + HalfFloatPoint.encodeDimension(bottom, packedValue, 0); + } + + @Override + protected void encodeTop(byte[] packedValue) { + HalfFloatPoint.encodeDimension(topValue, packedValue, 0); + } + } +}