diff --git a/server/src/main/java/org/elasticsearch/index/similarity/SimilarityService.java b/server/src/main/java/org/elasticsearch/index/similarity/SimilarityService.java index eaed2169f11c0..3fd908b3d1478 100644 --- a/server/src/main/java/org/elasticsearch/index/similarity/SimilarityService.java +++ b/server/src/main/java/org/elasticsearch/index/similarity/SimilarityService.java @@ -19,15 +19,22 @@ package org.elasticsearch.index.similarity; +import org.apache.logging.log4j.LogManager; +import org.apache.lucene.index.FieldInvertState; +import org.apache.lucene.index.IndexOptions; +import org.apache.lucene.search.CollectionStatistics; +import org.apache.lucene.search.Explanation; +import org.apache.lucene.search.TermStatistics; import org.apache.lucene.search.similarities.BM25Similarity; import org.apache.lucene.search.similarities.BooleanSimilarity; import org.apache.lucene.search.similarities.ClassicSimilarity; import org.apache.lucene.search.similarities.PerFieldSimilarityWrapper; import org.apache.lucene.search.similarities.Similarity; +import org.apache.lucene.search.similarities.Similarity.SimScorer; +import org.apache.lucene.util.BytesRef; import org.elasticsearch.Version; import org.elasticsearch.common.TriFunction; import org.elasticsearch.common.logging.DeprecationLogger; -import org.elasticsearch.common.logging.Loggers; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.index.AbstractIndexComponent; import org.elasticsearch.index.IndexModule; @@ -44,7 +51,7 @@ public final class SimilarityService extends AbstractIndexComponent { - private static final DeprecationLogger DEPRECATION_LOGGER = new DeprecationLogger(Loggers.getLogger(SimilarityService.class)); + private static final DeprecationLogger DEPRECATION_LOGGER = new DeprecationLogger(LogManager.getLogger(SimilarityService.class)); public static final String DEFAULT_SIMILARITY = "BM25"; private static final String CLASSIC_SIMILARITY = "classic"; private static final Map>> DEFAULTS; @@ -132,6 +139,7 @@ public SimilarityService(IndexSettings indexSettings, ScriptService scriptServic TriFunction defaultFactory = BUILT_IN.get(typeName); TriFunction factory = similarities.getOrDefault(typeName, defaultFactory); final Similarity similarity = factory.apply(providerSettings, indexSettings.getIndexVersionCreated(), scriptService); + validateSimilarity(indexSettings.getIndexVersionCreated(), similarity); providers.put(name, () -> similarity); } for (Map.Entry>> entry : DEFAULTS.entrySet()) { @@ -182,4 +190,79 @@ public Similarity get(String name) { return (fieldType != null && fieldType.similarity() != null) ? fieldType.similarity().get() : defaultSimilarity; } } + + static void validateSimilarity(Version indexCreatedVersion, Similarity similarity) { + validateScoresArePositive(indexCreatedVersion, similarity); + validateScoresDoNotDecreaseWithFreq(indexCreatedVersion, similarity); + validateScoresDoNotIncreaseWithNorm(indexCreatedVersion, similarity); + } + + private static void validateScoresArePositive(Version indexCreatedVersion, Similarity similarity) { + CollectionStatistics collectionStats = new CollectionStatistics("some_field", 1200, 1100, 3000, 2000); + TermStatistics termStats = new TermStatistics(new BytesRef("some_value"), 100, 130); + SimScorer scorer = similarity.scorer(2f, collectionStats, termStats); + FieldInvertState state = new FieldInvertState(indexCreatedVersion.major, "some_field", + IndexOptions.DOCS_AND_FREQS, 20, 20, 0, 50, 10, 3); // length = 20, no overlap + final long norm = similarity.computeNorm(state); + for (int freq = 1; freq <= 10; ++freq) { + float score = scorer.score(freq, norm); + if (score < 0) { + fail(indexCreatedVersion, "Similarities should not return negative scores:\n" + + scorer.explain(Explanation.match(freq, "term freq"), norm)); + } + } + } + + private static void validateScoresDoNotDecreaseWithFreq(Version indexCreatedVersion, Similarity similarity) { + CollectionStatistics collectionStats = new CollectionStatistics("some_field", 1200, 1100, 3000, 2000); + TermStatistics termStats = new TermStatistics(new BytesRef("some_value"), 100, 130); + SimScorer scorer = similarity.scorer(2f, collectionStats, termStats); + FieldInvertState state = new FieldInvertState(indexCreatedVersion.major, "some_field", + IndexOptions.DOCS_AND_FREQS, 20, 20, 0, 50, 10, 3); // length = 20, no overlap + final long norm = similarity.computeNorm(state); + float previousScore = 0; + for (int freq = 1; freq <= 10; ++freq) { + float score = scorer.score(freq, norm); + if (score < previousScore) { + fail(indexCreatedVersion, "Similarity scores should not decrease when term frequency increases:\n" + + scorer.explain(Explanation.match(freq - 1, "term freq"), norm) + "\n" + + scorer.explain(Explanation.match(freq, "term freq"), norm)); + } + previousScore = score; + } + } + + private static void validateScoresDoNotIncreaseWithNorm(Version indexCreatedVersion, Similarity similarity) { + CollectionStatistics collectionStats = new CollectionStatistics("some_field", 1200, 1100, 3000, 2000); + TermStatistics termStats = new TermStatistics(new BytesRef("some_value"), 100, 130); + SimScorer scorer = similarity.scorer(2f, collectionStats, termStats); + + long previousNorm = 0; + float previousScore = Float.MAX_VALUE; + for (int length = 1; length <= 10; ++length) { + FieldInvertState state = new FieldInvertState(indexCreatedVersion.major, "some_field", + IndexOptions.DOCS_AND_FREQS, length, length, 0, 50, 10, 3); // length = 20, no overlap + final long norm = similarity.computeNorm(state); + if (Long.compareUnsigned(previousNorm, norm) > 0) { + // esoteric similarity, skip this check + break; + } + float score = scorer.score(1, norm); + if (score > previousScore) { + fail(indexCreatedVersion, "Similarity scores should not increase when norm increases:\n" + + scorer.explain(Explanation.match(1, "term freq"), norm - 1) + "\n" + + scorer.explain(Explanation.match(1, "term freq"), norm)); + } + previousScore = score; + previousNorm = norm; + } + } + + private static void fail(Version indexCreatedVersion, String message) { + if (indexCreatedVersion.onOrAfter(Version.V_7_0_0_alpha1)) { + throw new IllegalArgumentException(message); + } else if (indexCreatedVersion.onOrAfter(Version.V_6_5_0)) { + DEPRECATION_LOGGER.deprecated(message); + } + } } diff --git a/server/src/test/java/org/elasticsearch/index/similarity/SimilarityServiceTests.java b/server/src/test/java/org/elasticsearch/index/similarity/SimilarityServiceTests.java index 5d18a595e9687..eb769f3e77029 100644 --- a/server/src/test/java/org/elasticsearch/index/similarity/SimilarityServiceTests.java +++ b/server/src/test/java/org/elasticsearch/index/similarity/SimilarityServiceTests.java @@ -18,12 +18,18 @@ */ package org.elasticsearch.index.similarity; +import org.apache.lucene.index.FieldInvertState; +import org.apache.lucene.search.CollectionStatistics; +import org.apache.lucene.search.TermStatistics; import org.apache.lucene.search.similarities.BM25Similarity; import org.apache.lucene.search.similarities.BooleanSimilarity; +import org.apache.lucene.search.similarities.Similarity; +import org.elasticsearch.Version; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.index.IndexSettings; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.IndexSettingsModule; +import org.hamcrest.Matchers; import java.util.Collections; @@ -56,4 +62,75 @@ public void testOverrideDefaultSimilarity() { SimilarityService service = new SimilarityService(indexSettings, null, Collections.emptyMap()); assertTrue(service.getDefaultSimilarity() instanceof BooleanSimilarity); } + + public void testSimilarityValidation() { + Similarity negativeScoresSim = new Similarity() { + + @Override + public long computeNorm(FieldInvertState state) { + return state.getLength(); + } + + @Override + public SimScorer scorer(float boost, CollectionStatistics collectionStats, TermStatistics... termStats) { + return new SimScorer() { + + @Override + public float score(float freq, long norm) { + return -1; + } + + }; + } + }; + IllegalArgumentException e = expectThrows(IllegalArgumentException.class, + () -> SimilarityService.validateSimilarity(Version.V_7_0_0_alpha1, negativeScoresSim)); + assertThat(e.getMessage(), Matchers.containsString("Similarities should not return negative scores")); + + Similarity decreasingScoresWithFreqSim = new Similarity() { + + @Override + public long computeNorm(FieldInvertState state) { + return state.getLength(); + } + + @Override + public SimScorer scorer(float boost, CollectionStatistics collectionStats, TermStatistics... termStats) { + return new SimScorer() { + + @Override + public float score(float freq, long norm) { + return 1 / (freq + norm); + } + + }; + } + }; + e = expectThrows(IllegalArgumentException.class, + () -> SimilarityService.validateSimilarity(Version.V_7_0_0_alpha1, decreasingScoresWithFreqSim)); + assertThat(e.getMessage(), Matchers.containsString("Similarity scores should not decrease when term frequency increases")); + + Similarity increasingScoresWithNormSim = new Similarity() { + + @Override + public long computeNorm(FieldInvertState state) { + return state.getLength(); + } + + @Override + public SimScorer scorer(float boost, CollectionStatistics collectionStats, TermStatistics... termStats) { + return new SimScorer() { + + @Override + public float score(float freq, long norm) { + return freq + norm; + } + + }; + } + }; + e = expectThrows(IllegalArgumentException.class, + () -> SimilarityService.validateSimilarity(Version.V_7_0_0_alpha1, increasingScoresWithNormSim)); + assertThat(e.getMessage(), Matchers.containsString("Similarity scores should not increase when norm increases")); + } }