Skip to content

Commit

Permalink
Remove deprecated vector functions. (#48725)
Browse files Browse the repository at this point in the history
Follow up to #48604. This PR removes the deprecated vector function signatures
of the form `cosineSimilarity(query, doc['field'])`.
  • Loading branch information
jtibshirani committed Oct 31, 2019
1 parent 23a4e4a commit 939c242
Show file tree
Hide file tree
Showing 7 changed files with 56 additions and 154 deletions.
10 changes: 9 additions & 1 deletion docs/reference/migration/migrate_8_0/search.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,12 @@ The `nested_filter` and `nested_path` options, deprecated in 6.x, have been remo

{es} will no longer prefer using shards in the same location (with the same awareness attribute values) to process
`_search` and `_get` requests. Adaptive replica selection (activated by default in this version) will route requests
more efficiently using the service time of prior inter-node communications.
more efficiently using the service time of prior inter-node communications.

[float]
==== Update to vector function signatures
The vector functions of the form `function(query, doc['field'])` were
deprecated in 7.6 and have been removed in 8.x. Use the form
`function(query, 'field')` instead, passing the field name as a string
rather than the field's doc values. For example, replace
`cosineSimilarity(query, doc['field'])` with
`cosineSimilarity(query, 'field')`.
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
setup:
- skip:
features: [headers, warnings]
features: headers
version: " - 7.2.99"
reason: "dense_vector functions were added from 7.3"

Expand Down Expand Up @@ -99,26 +99,3 @@ setup:
- match: {hits.hits.2._id: "1"}
- gte: {hits.hits.2._score: 0.78}
- lte: {hits.hits.2._score: 0.791}

---
"Deprecated function signature":
- do:
headers:
Content-Type: application/json
warnings:
- The vector functions of the form function(query, doc['field']) are deprecated, and the form function(query, 'field') should be used instead. For example, cosineSimilarity(query, doc['field']) is replaced by cosineSimilarity(query, 'field').
search:
rest_total_hits_as_int: true
body:
query:
script_score:
query: {match_all: {} }
script:
source: "cosineSimilarity(params.query_vector, doc['my_dense_vector'])"
params:
query_vector: [0.5, 111.3, -13.0, 14.8, -156.0]

- match: {hits.total: 3}
- match: {hits.hits.0._id: "3"}
- match: {hits.hits.1._id: "2"}
- match: {hits.hits.2._id: "1"}
Original file line number Diff line number Diff line change
Expand Up @@ -104,27 +104,3 @@ setup:
- match: {hits.hits.2._id: "1"}
- gte: {hits.hits.2._score: 0.78}
- lte: {hits.hits.2._score: 0.791}

---
"Deprecated function signature":
- do:
headers:
Content-Type: application/json
warnings:
- The [sparse_vector] field type is deprecated and will be removed in 8.0.
- The vector functions of the form function(query, doc['field']) are deprecated, and the form function(query, 'field') should be used instead. For example, cosineSimilarity(query, doc['field']) is replaced by cosineSimilarity(query, 'field').
search:
rest_total_hits_as_int: true
body:
query:
script_score:
query: {match_all: {} }
script:
source: "cosineSimilaritySparse(params.query_vector, doc['my_sparse_vector'])"
params:
query_vector: {"2": -0.5, "10" : 111.3, "50": -13.0, "113": 14.8, "4545": -156.0}

- match: {hits.total: 3}
- match: {hits.hits.0._id: "3"}
- match: {hits.hits.1._id: "2"}
- match: {hits.hits.2._id: "1"}
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,6 @@

public class ScoreScriptUtils {
private static final DeprecationLogger deprecationLogger = new DeprecationLogger(LogManager.getLogger(ScoreScriptUtils.class));
static final String DEPRECATION_MESSAGE = "The vector functions of the form function(query, doc['field']) are deprecated, and " +
"the form function(query, 'field') should be used instead. For example, cosineSimilarity(query, doc['field']) is replaced by " +
"cosineSimilarity(query, 'field').";

//**************FUNCTIONS FOR DENSE VECTORS
// Functions are implemented as classes to accept a hidden parameter scoreScript that contains some index settings.
Expand All @@ -43,7 +40,7 @@ public static class DenseVectorFunction {

public DenseVectorFunction(ScoreScript scoreScript,
List<Number> queryVector,
Object field) {
String field) {
this(scoreScript, queryVector, field, false);
}

Expand All @@ -56,9 +53,10 @@ public DenseVectorFunction(ScoreScript scoreScript,
*/
public DenseVectorFunction(ScoreScript scoreScript,
List<Number> queryVector,
Object field,
String field,
boolean normalizeQuery) {
this.scoreScript = scoreScript;
this.docValues = (DenseVectorScriptDocValues) scoreScript.getDoc().get(field);

this.queryVector = new float[queryVector.size()];
double queryMagnitude = 0.0;
Expand All @@ -74,17 +72,6 @@ public DenseVectorFunction(ScoreScript scoreScript,
this.queryVector[dim] /= queryMagnitude;
}
}

if (field instanceof String) {
String fieldName = (String) field;
docValues = (DenseVectorScriptDocValues) scoreScript.getDoc().get(fieldName);
} else if (field instanceof DenseVectorScriptDocValues) {
docValues = (DenseVectorScriptDocValues) field;
deprecationLogger.deprecatedAndMaybeLog("vector_function_signature", DEPRECATION_MESSAGE);
} else {
throw new IllegalArgumentException("For vector functions, the 'field' argument must be of type String or " +
"VectorScriptDocValues");
}
}

BytesRef getEncodedVector() {
Expand Down Expand Up @@ -112,7 +99,7 @@ BytesRef getEncodedVector() {
// Calculate l1 norm (Manhattan distance) between a query's dense vector and documents' dense vectors
public static final class L1Norm extends DenseVectorFunction {

public L1Norm(ScoreScript scoreScript, List<Number> queryVector, Object field) {
public L1Norm(ScoreScript scoreScript, List<Number> queryVector, String field) {
super(scoreScript, queryVector, field);
}

Expand All @@ -132,7 +119,7 @@ public double l1norm() {
// Calculate l2 norm (Euclidean distance) between a query's dense vector and documents' dense vectors
public static final class L2Norm extends DenseVectorFunction {

public L2Norm(ScoreScript scoreScript, List<Number> queryVector, Object field) {
public L2Norm(ScoreScript scoreScript, List<Number> queryVector, String field) {
super(scoreScript, queryVector, field);
}

Expand All @@ -152,7 +139,7 @@ public double l2norm() {
// Calculate a dot product between a query's dense vector and documents' dense vectors
public static final class DotProduct extends DenseVectorFunction {

public DotProduct(ScoreScript scoreScript, List<Number> queryVector, Object field) {
public DotProduct(ScoreScript scoreScript, List<Number> queryVector, String field) {
super(scoreScript, queryVector, field);
}

Expand All @@ -171,7 +158,7 @@ public double dotProduct() {
// Calculate cosine similarity between a query's dense vector and documents' dense vectors
public static final class CosineSimilarity extends DenseVectorFunction {

public CosineSimilarity(ScoreScript scoreScript, List<Number> queryVector, Object field) {
public CosineSimilarity(ScoreScript scoreScript, List<Number> queryVector, String field) {
super(scoreScript, queryVector, field, true);
}

Expand Down Expand Up @@ -214,8 +201,10 @@ public static class SparseVectorFunction {
// queryVector represents a map of dimensions to values
public SparseVectorFunction(ScoreScript scoreScript,
Map<String, Number> queryVector,
Object field) {
String field) {
this.scoreScript = scoreScript;
this.docValues = (SparseVectorScriptDocValues) scoreScript.getDoc().get(field);

//break vector into two arrays dims and values
int n = queryVector.size();
queryValues = new float[n];
Expand All @@ -232,18 +221,6 @@ public SparseVectorFunction(ScoreScript scoreScript,
}
// Sort dimensions in the ascending order and sort values in the same order as their corresponding dimensions
sortSparseDimsFloatValues(queryDims, queryValues, n);

if (field instanceof String) {
String fieldName = (String) field;
docValues = (SparseVectorScriptDocValues) scoreScript.getDoc().get(fieldName);
} else if (field instanceof SparseVectorScriptDocValues) {
docValues = (SparseVectorScriptDocValues) field;
deprecationLogger.deprecatedAndMaybeLog("vector_function_signature", DEPRECATION_MESSAGE);
} else {
throw new IllegalArgumentException("For vector functions, the 'field' argument must be of type String or " +
"VectorScriptDocValues");
}

deprecationLogger.deprecatedAndMaybeLog("sparse_vector_function", SparseVectorFieldMapper.DEPRECATION_MESSAGE);
}

Expand All @@ -264,8 +241,8 @@ BytesRef getEncodedVector() {

// Calculate l1 norm (Manhattan distance) between a query's sparse vector and documents' sparse vectors
public static final class L1NormSparse extends SparseVectorFunction {
public L1NormSparse(ScoreScript scoreScript,Map<String, Number> queryVector, Object docVector) {
super(scoreScript, queryVector, docVector);
public L1NormSparse(ScoreScript scoreScript,Map<String, Number> queryVector, String field) {
super(scoreScript, queryVector, field);
}

public double l1normSparse() {
Expand Down Expand Up @@ -303,8 +280,8 @@ public double l1normSparse() {

// Calculate l2 norm (Euclidean distance) between a query's sparse vector and documents' sparse vectors
public static final class L2NormSparse extends SparseVectorFunction {
public L2NormSparse(ScoreScript scoreScript, Map<String, Number> queryVector, Object docVector) {
super(scoreScript, queryVector, docVector);
public L2NormSparse(ScoreScript scoreScript, Map<String, Number> queryVector, String field) {
super(scoreScript, queryVector, field);
}

public double l2normSparse() {
Expand Down Expand Up @@ -345,8 +322,8 @@ public double l2normSparse() {

// Calculate a dot product between a query's sparse vector and documents' sparse vectors
public static final class DotProductSparse extends SparseVectorFunction {
public DotProductSparse(ScoreScript scoreScript, Map<String, Number> queryVector, Object docVector) {
super(scoreScript, queryVector, docVector);
public DotProductSparse(ScoreScript scoreScript, Map<String, Number> queryVector, String field) {
super(scoreScript, queryVector, field);
}

public double dotProductSparse() {
Expand All @@ -362,8 +339,8 @@ public double dotProductSparse() {
public static final class CosineSimilaritySparse extends SparseVectorFunction {
final double queryVectorMagnitude;

public CosineSimilaritySparse(ScoreScript scoreScript, Map<String, Number> queryVector, Object docVector) {
super(scoreScript, queryVector, docVector);
public CosineSimilaritySparse(ScoreScript scoreScript, Map<String, Number> queryVector, String field) {
super(scoreScript, queryVector, field);
double dotProduct = 0;
for (int i = 0; i< queryDims.length; i++) {
dotProduct += queryValues[i] * queryValues[i];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@ class org.elasticsearch.script.ScoreScript @no_import {
}

static_import {
double l1norm(org.elasticsearch.script.ScoreScript, List, Object) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$L1Norm
double l2norm(org.elasticsearch.script.ScoreScript, List, Object) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$L2Norm
double cosineSimilarity(org.elasticsearch.script.ScoreScript, List, Object) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$CosineSimilarity
double dotProduct(org.elasticsearch.script.ScoreScript, List, Object) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$DotProduct
double l1normSparse(org.elasticsearch.script.ScoreScript, Map, Object) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$L1NormSparse
double l2normSparse(org.elasticsearch.script.ScoreScript, Map, Object) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$L2NormSparse
double dotProductSparse(org.elasticsearch.script.ScoreScript, Map, Object) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$DotProductSparse
double cosineSimilaritySparse(org.elasticsearch.script.ScoreScript, Map, Object) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$CosineSimilaritySparse
double l1norm(org.elasticsearch.script.ScoreScript, List, String) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$L1Norm
double l2norm(org.elasticsearch.script.ScoreScript, List, String) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$L2Norm
double cosineSimilarity(org.elasticsearch.script.ScoreScript, List, String) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$CosineSimilarity
double dotProduct(org.elasticsearch.script.ScoreScript, List, String) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$DotProduct
double l1normSparse(org.elasticsearch.script.ScoreScript, Map, String) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$L1NormSparse
double l2normSparse(org.elasticsearch.script.ScoreScript, Map, String) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$L2NormSparse
double dotProductSparse(org.elasticsearch.script.ScoreScript, Map, String) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$DotProductSparse
double cosineSimilaritySparse(org.elasticsearch.script.ScoreScript, Map, String) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$CosineSimilaritySparse
}
Original file line number Diff line number Diff line change
Expand Up @@ -50,68 +50,48 @@ public void testDenseVectorFunctions() {
when(scoreScript._getIndexVersion()).thenReturn(indexVersion);
when(scoreScript.getDoc()).thenReturn(Collections.singletonMap(field, docValues));

testDotProduct(docValues, scoreScript);
testCosineSimilarity(docValues, scoreScript);
testL1Norm(docValues, scoreScript);
testL2Norm(docValues, scoreScript);
testDotProduct(scoreScript);
testCosineSimilarity(scoreScript);
testL1Norm(scoreScript);
testL2Norm(scoreScript);
}
}

private void testDotProduct(DenseVectorScriptDocValues docValues, ScoreScript scoreScript) {
private void testDotProduct(ScoreScript scoreScript) {
DotProduct function = new DotProduct(scoreScript, queryVector, field);
double result = function.dotProduct();
assertEquals("dotProduct result is not equal to the expected value!", 65425.624, result, 0.001);

DotProduct deprecatedFunction = new DotProduct(scoreScript, queryVector, docValues);
double deprecatedResult = deprecatedFunction.dotProduct();
assertEquals("dotProduct result is not equal to the expected value!", 65425.624, deprecatedResult, 0.001);
assertWarnings(ScoreScriptUtils.DEPRECATION_MESSAGE);

DotProduct invalidFunction = new DotProduct(scoreScript, invalidQueryVector, field);
IllegalArgumentException e = expectThrows(IllegalArgumentException.class, invalidFunction::dotProduct);
assertThat(e.getMessage(), containsString("query vector has a different number of dimensions [2] than the document vectors [5]"));
}

private void testCosineSimilarity(DenseVectorScriptDocValues docValues, ScoreScript scoreScript) {
private void testCosineSimilarity(ScoreScript scoreScript) {
CosineSimilarity function = new CosineSimilarity(scoreScript, queryVector, field);
double result = function.cosineSimilarity();
assertEquals("cosineSimilarity result is not equal to the expected value!", 0.790, result, 0.001);

CosineSimilarity deprecatedFunction = new CosineSimilarity(scoreScript, queryVector, docValues);
double deprecatedResult = deprecatedFunction.cosineSimilarity();
assertEquals("cosineSimilarity result is not equal to the expected value!", 0.790, deprecatedResult, 0.001);
assertWarnings(ScoreScriptUtils.DEPRECATION_MESSAGE);

CosineSimilarity invalidFunction = new CosineSimilarity(scoreScript, invalidQueryVector, field);
IllegalArgumentException e = expectThrows(IllegalArgumentException.class, invalidFunction::cosineSimilarity);
assertThat(e.getMessage(), containsString("query vector has a different number of dimensions [2] than the document vectors [5]"));
}

private void testL1Norm(DenseVectorScriptDocValues docValues, ScoreScript scoreScript) {
private void testL1Norm(ScoreScript scoreScript) {
L1Norm function = new L1Norm(scoreScript, queryVector, field);
double result = function.l1norm();
assertEquals("l1norm result is not equal to the expected value!", 485.184, result, 0.001);

L1Norm deprecatedFunction = new L1Norm(scoreScript, queryVector, docValues);
double deprecatedResult = deprecatedFunction.l1norm();
assertEquals("l1norm result is not equal to the expected value!", 485.184, deprecatedResult, 0.001);
assertWarnings(ScoreScriptUtils.DEPRECATION_MESSAGE);

L1Norm invalidFunction = new L1Norm(scoreScript, invalidQueryVector, field);
IllegalArgumentException e = expectThrows(IllegalArgumentException.class, invalidFunction::l1norm);
assertThat(e.getMessage(), containsString("query vector has a different number of dimensions [2] than the document vectors [5]"));
}

private void testL2Norm(DenseVectorScriptDocValues docValues, ScoreScript scoreScript) {
private void testL2Norm(ScoreScript scoreScript) {
L2Norm function = new L2Norm(scoreScript, queryVector, field);
double result = function.l2norm();
assertEquals("l2norm result is not equal to the expected value!", 301.361, result, 0.001);

L2Norm deprecatedFunction = new L2Norm(scoreScript, queryVector, docValues);
double deprecatedResult = deprecatedFunction.l2norm();
assertEquals("l2norm result is not equal to the expected value!", 301.361, deprecatedResult, 0.001);
assertWarnings(ScoreScriptUtils.DEPRECATION_MESSAGE);

L2Norm invalidFunction = new L2Norm(scoreScript, invalidQueryVector, field);
IllegalArgumentException e = expectThrows(IllegalArgumentException.class, invalidFunction::l2norm);
assertThat(e.getMessage(), containsString("query vector has a different number of dimensions [2] than the document vectors [5]"));
Expand Down
Loading

0 comments on commit 939c242

Please sign in to comment.