Skip to content

Commit

Permalink
Add minimal sanity checks to custom/scripted similarities.
Browse files Browse the repository at this point in the history
Lucene 8 introduced more constraints on similarities, in particular:
 - scores must not be negative,
 - scores must not decrease when term freq increases,
 - scores must not increase when norm (interpreted as an unsigned long)
   increases.

We can't check every single case, but could at least run some sanity checks.

Relates elastic#33309
  • Loading branch information
jpountz committed Sep 10, 2018
1 parent 77aeeda commit d161e9a
Show file tree
Hide file tree
Showing 2 changed files with 162 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,22 @@

package org.elasticsearch.index.similarity;

import org.apache.logging.log4j.LogManager;
import org.apache.lucene.index.FieldInvertState;
import org.apache.lucene.index.IndexOptions;
import org.apache.lucene.search.CollectionStatistics;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.TermStatistics;
import org.apache.lucene.search.similarities.BM25Similarity;
import org.apache.lucene.search.similarities.BooleanSimilarity;
import org.apache.lucene.search.similarities.ClassicSimilarity;
import org.apache.lucene.search.similarities.PerFieldSimilarityWrapper;
import org.apache.lucene.search.similarities.Similarity;
import org.apache.lucene.search.similarities.Similarity.SimScorer;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.Version;
import org.elasticsearch.common.TriFunction;
import org.elasticsearch.common.logging.DeprecationLogger;
import org.elasticsearch.common.logging.Loggers;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.index.AbstractIndexComponent;
import org.elasticsearch.index.IndexModule;
Expand All @@ -44,7 +51,7 @@

public final class SimilarityService extends AbstractIndexComponent {

private static final DeprecationLogger DEPRECATION_LOGGER = new DeprecationLogger(Loggers.getLogger(SimilarityService.class));
private static final DeprecationLogger DEPRECATION_LOGGER = new DeprecationLogger(LogManager.getLogger(SimilarityService.class));
public static final String DEFAULT_SIMILARITY = "BM25";
private static final String CLASSIC_SIMILARITY = "classic";
private static final Map<String, Function<Version, Supplier<Similarity>>> DEFAULTS;
Expand Down Expand Up @@ -132,6 +139,7 @@ public SimilarityService(IndexSettings indexSettings, ScriptService scriptServic
TriFunction<Settings, Version, ScriptService, Similarity> defaultFactory = BUILT_IN.get(typeName);
TriFunction<Settings, Version, ScriptService, Similarity> factory = similarities.getOrDefault(typeName, defaultFactory);
final Similarity similarity = factory.apply(providerSettings, indexSettings.getIndexVersionCreated(), scriptService);
validateSimilarity(indexSettings.getIndexVersionCreated(), similarity);
providers.put(name, () -> similarity);
}
for (Map.Entry<String, Function<Version, Supplier<Similarity>>> entry : DEFAULTS.entrySet()) {
Expand Down Expand Up @@ -182,4 +190,79 @@ public Similarity get(String name) {
return (fieldType != null && fieldType.similarity() != null) ? fieldType.similarity().get() : defaultSimilarity;
}
}

static void validateSimilarity(Version indexCreatedVersion, Similarity similarity) {
validateScoresArePositive(indexCreatedVersion, similarity);
validateScoresDoNotDecreaseWithFreq(indexCreatedVersion, similarity);
validateScoresDoNotIncreaseWithNorm(indexCreatedVersion, similarity);
}

private static void validateScoresArePositive(Version indexCreatedVersion, Similarity similarity) {
CollectionStatistics collectionStats = new CollectionStatistics("some_field", 1200, 1100, 3000, 2000);
TermStatistics termStats = new TermStatistics(new BytesRef("some_value"), 100, 130);
SimScorer scorer = similarity.scorer(2f, collectionStats, termStats);
FieldInvertState state = new FieldInvertState(indexCreatedVersion.major, "some_field",
IndexOptions.DOCS_AND_FREQS, 20, 20, 0, 50, 10, 3); // length = 20, no overlap
final long norm = similarity.computeNorm(state);
for (int freq = 1; freq <= 10; ++freq) {
float score = scorer.score(freq, norm);
if (score < 0) {
fail(indexCreatedVersion, "Similarities should not return negative scores:\n" +
scorer.explain(Explanation.match(freq, "term freq"), norm));
}
}
}

private static void validateScoresDoNotDecreaseWithFreq(Version indexCreatedVersion, Similarity similarity) {
CollectionStatistics collectionStats = new CollectionStatistics("some_field", 1200, 1100, 3000, 2000);
TermStatistics termStats = new TermStatistics(new BytesRef("some_value"), 100, 130);
SimScorer scorer = similarity.scorer(2f, collectionStats, termStats);
FieldInvertState state = new FieldInvertState(indexCreatedVersion.major, "some_field",
IndexOptions.DOCS_AND_FREQS, 20, 20, 0, 50, 10, 3); // length = 20, no overlap
final long norm = similarity.computeNorm(state);
float previousScore = 0;
for (int freq = 1; freq <= 10; ++freq) {
float score = scorer.score(freq, norm);
if (score < previousScore) {
fail(indexCreatedVersion, "Similarity scores should not decrease when term frequency increases:\n" +
scorer.explain(Explanation.match(freq - 1, "term freq"), norm) + "\n" +
scorer.explain(Explanation.match(freq, "term freq"), norm));
}
previousScore = score;
}
}

private static void validateScoresDoNotIncreaseWithNorm(Version indexCreatedVersion, Similarity similarity) {
CollectionStatistics collectionStats = new CollectionStatistics("some_field", 1200, 1100, 3000, 2000);
TermStatistics termStats = new TermStatistics(new BytesRef("some_value"), 100, 130);
SimScorer scorer = similarity.scorer(2f, collectionStats, termStats);

long previousNorm = 0;
float previousScore = Float.MAX_VALUE;
for (int length = 1; length <= 10; ++length) {
FieldInvertState state = new FieldInvertState(indexCreatedVersion.major, "some_field",
IndexOptions.DOCS_AND_FREQS, length, length, 0, 50, 10, 3); // length = 20, no overlap
final long norm = similarity.computeNorm(state);
if (Long.compareUnsigned(previousNorm, norm) > 0) {
// esoteric similarity, skip this check
break;
}
float score = scorer.score(1, norm);
if (score > previousScore) {
fail(indexCreatedVersion, "Similarity scores should not increase when norm increases:\n" +
scorer.explain(Explanation.match(1, "term freq"), norm - 1) + "\n" +
scorer.explain(Explanation.match(1, "term freq"), norm));
}
previousScore = score;
previousNorm = norm;
}
}

private static void fail(Version indexCreatedVersion, String message) {
if (indexCreatedVersion.onOrAfter(Version.V_7_0_0_alpha1)) {
throw new IllegalArgumentException(message);
} else if (indexCreatedVersion.onOrAfter(Version.V_6_5_0)) {
DEPRECATION_LOGGER.deprecated(message);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,18 @@
*/
package org.elasticsearch.index.similarity;

import org.apache.lucene.index.FieldInvertState;
import org.apache.lucene.search.CollectionStatistics;
import org.apache.lucene.search.TermStatistics;
import org.apache.lucene.search.similarities.BM25Similarity;
import org.apache.lucene.search.similarities.BooleanSimilarity;
import org.apache.lucene.search.similarities.Similarity;
import org.elasticsearch.Version;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.index.IndexSettings;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.IndexSettingsModule;
import org.hamcrest.Matchers;

import java.util.Collections;

Expand Down Expand Up @@ -56,4 +62,75 @@ public void testOverrideDefaultSimilarity() {
SimilarityService service = new SimilarityService(indexSettings, null, Collections.emptyMap());
assertTrue(service.getDefaultSimilarity() instanceof BooleanSimilarity);
}

public void testSimilarityValidation() {
Similarity negativeScoresSim = new Similarity() {

@Override
public long computeNorm(FieldInvertState state) {
return state.getLength();
}

@Override
public SimScorer scorer(float boost, CollectionStatistics collectionStats, TermStatistics... termStats) {
return new SimScorer() {

@Override
public float score(float freq, long norm) {
return -1;
}

};
}
};
IllegalArgumentException e = expectThrows(IllegalArgumentException.class,
() -> SimilarityService.validateSimilarity(Version.V_7_0_0_alpha1, negativeScoresSim));
assertThat(e.getMessage(), Matchers.containsString("Similarities should not return negative scores"));

Similarity decreasingScoresWithFreqSim = new Similarity() {

@Override
public long computeNorm(FieldInvertState state) {
return state.getLength();
}

@Override
public SimScorer scorer(float boost, CollectionStatistics collectionStats, TermStatistics... termStats) {
return new SimScorer() {

@Override
public float score(float freq, long norm) {
return 1 / (freq + norm);
}

};
}
};
e = expectThrows(IllegalArgumentException.class,
() -> SimilarityService.validateSimilarity(Version.V_7_0_0_alpha1, decreasingScoresWithFreqSim));
assertThat(e.getMessage(), Matchers.containsString("Similarity scores should not decrease when term frequency increases"));

Similarity increasingScoresWithNormSim = new Similarity() {

@Override
public long computeNorm(FieldInvertState state) {
return state.getLength();
}

@Override
public SimScorer scorer(float boost, CollectionStatistics collectionStats, TermStatistics... termStats) {
return new SimScorer() {

@Override
public float score(float freq, long norm) {
return freq + norm;
}

};
}
};
e = expectThrows(IllegalArgumentException.class,
() -> SimilarityService.validateSimilarity(Version.V_7_0_0_alpha1, increasingScoresWithNormSim));
assertThat(e.getMessage(), Matchers.containsString("Similarity scores should not increase when norm increases"));
}
}

0 comments on commit d161e9a

Please sign in to comment.