Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add recall test with small dataset #2080

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 25 additions & 15 deletions src/test/java/org/opensearch/knn/integ/BinaryIndexIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@

import java.io.IOException;
import java.net.URL;
import java.util.Arrays;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;

import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW;

Expand All @@ -41,9 +44,11 @@ public static void setUpClass() throws IOException {
}
URL testIndexVectors = BinaryIndexIT.class.getClassLoader().getResource("data/test_vectors_binary_1000x128.json");
URL testQueries = BinaryIndexIT.class.getClassLoader().getResource("data/test_queries_binary_100x128.csv");
URL groundTruthValues = BinaryIndexIT.class.getClassLoader().getResource("data/test_ground_truth_binary_100.csv");
assert testIndexVectors != null;
assert testQueries != null;
testData = new TestUtils.TestData(testIndexVectors.getPath(), testQueries.getPath());
assert groundTruthValues != null;
testData = new TestUtils.TestData(testIndexVectors.getPath(), testQueries.getPath(), groundTruthValues.getPath());
}

@After
Expand Down Expand Up @@ -83,18 +88,19 @@ public void testFaissHnswBinary_whenSmallDataSet_thenCreateIngestQueryWorks() {
}

@SneakyThrows
public void testFaissHnswBinary_when1000Data_thenCreateIngestQueryWorks() {
public void testFaissHnswBinary_when1000Data_thenRecallIsAboveNinePointZero() {
// Create Index
createKnnHnswBinaryIndex(KNNEngine.FAISS, INDEX_NAME, FIELD_NAME, 128);
ingestTestData(INDEX_NAME, FIELD_NAME);

int k = 10;
int k = 100;
for (int i = 0; i < testData.queries.length; i++) {
// Query
List<KNNResult> knnResults = runKnnQuery(INDEX_NAME, FIELD_NAME, testData.queries[i], k);

// Validate
assertEquals(k, knnResults.size());
float recall = getRecall(
Set.of(Arrays.copyOf(testData.groundTruthValues[i], k)),
knnResults.stream().map(KNNResult::getDocId).collect(Collectors.toSet())
);
assertTrue("Recall: " + recall, recall > 0.1);
}
}

Expand All @@ -109,6 +115,18 @@ public void testFaissHnswBinary_whenRadialSearch_thenThrowException() {
assertTrue(e.getMessage(), e.getMessage().contains("Binary data type does not support radial search"));
}

/**
 * Computes recall, i.e. the fraction of ground-truth ids that are present in the retrieved ids.
 *
 * <p>Neither argument is modified, so unmodifiable sets (for example {@code Set.of(...)}) are
 * accepted for both parameters.
 *
 * @param truth  ids expected for the query; must not be empty
 * @param result ids actually returned for the query
 * @return {@code |truth ∩ result| / |truth|}, a value in {@code [0, 1]}
 * @throws IllegalArgumentException if {@code truth} is empty, since recall is undefined in that case
 */
private static float getRecall(final Set<String> truth, final Set<String> result) {
    if (truth.isEmpty()) {
        // Dividing by zero would silently produce NaN and hide a broken ground-truth file.
        throw new IllegalArgumentException("Ground truth must not be empty");
    }

    // Count the hits with a read-only pass instead of retainAll(), which would strip the
    // caller's set down to the intersection (and fail outright on an unmodifiable set).
    final long relevantRetrieved = result.stream().filter(truth::contains).count();

    return (float) relevantRetrieved / truth.size();
}

private List<KNNResult> runRnnQuery(
final String indexName,
final String fieldName,
Expand Down Expand Up @@ -171,12 +189,4 @@ private void createKnnHnswBinaryIndex(final KNNEngine knnEngine, final String in

createKnnIndex(indexName, knnIndexMapping);
}

/**
 * Converts a float vector to a byte vector of the same length by casting every component.
 *
 * <p>The cast is a plain Java narrowing conversion: the fractional part is dropped and only
 * the low 8 bits of the integer value are kept, so values outside the byte range wrap around.
 *
 * @param vector the source components
 * @return a new array holding {@code (byte) vector[i]} at each index
 */
private byte[] toByte(final float[] vector) {
    final byte[] narrowed = new byte[vector.length];
    int position = 0;
    for (final float component : vector) {
        narrowed[position++] = (byte) component;
    }
    return narrowed;
}
}
139 changes: 139 additions & 0 deletions src/test/java/org/opensearch/knn/integ/IndexIT.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.integ;

import com.google.common.primitives.Floats;
import lombok.SneakyThrows;
import lombok.extern.log4j.Log4j2;
import org.apache.commons.lang.ArrayUtils;
import org.apache.hc.core5.http.io.entity.EntityUtils;
import org.junit.After;
import org.junit.BeforeClass;
import org.opensearch.client.Response;
import org.opensearch.knn.KNNJsonIndexMappingsBuilder;
import org.opensearch.knn.KNNJsonQueryBuilder;
import org.opensearch.knn.KNNRestTestCase;
import org.opensearch.knn.KNNResult;
import org.opensearch.knn.TestUtils;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.engine.KNNEngine;

import java.io.IOException;
import java.net.URL;
import java.util.Arrays;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;

import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW;

/**
 * Integration tests that build a float k-NN index from the bundled test vectors and check
 * the recall of the search results against the bundled ground-truth file.
 */
@Log4j2
public class IndexIT extends KNNRestTestCase {
    /** Recall every query has to exceed; the comparison is strict ({@code recall > MIN_RECALL}). */
    private static final double MIN_RECALL = 0.9;

    /** Index vectors, queries and ground truth shared by all tests; loaded once in {@link #setUpClass()}. */
    private static TestUtils.TestData testData;

    /**
     * Loads the test vectors, queries and ground truth from the test classpath.
     *
     * @throws IOException if the test data cannot be read
     * @throws IllegalStateException if the class loader or any of the resources is unavailable
     */
    @BeforeClass
    public static void setUpClass() throws IOException {
        final ClassLoader classLoader = IndexIT.class.getClassLoader();
        if (classLoader == null) {
            throw new IllegalStateException("ClassLoader of IndexIT Class is null");
        }
        testData = new TestUtils.TestData(
            getResourcePath(classLoader, "data/test_vectors_1000x128.json"),
            getResourcePath(classLoader, "data/test_queries_100x128.csv"),
            getResourcePath(classLoader, "data/test_ground_truth_l2_100.csv")
        );
    }

    /**
     * Deletes the test index after each test. Failures are only logged so that a problem
     * during clean-up does not mask the outcome of the test itself.
     */
    @After
    public void cleanUp() {
        try {
            deleteKNNIndex(INDEX_NAME);
        } catch (Exception e) {
            log.error(e);
        }
    }

    /**
     * Indexes the test vectors into a Faiss HNSW index and asserts that, for every query, the
     * recall of the top-{@code k} results against the ground truth is above {@link #MIN_RECALL}.
     *
     * <p>Note: the method name reads "NinePointZero" although the threshold is 0.9; the name is
     * kept unchanged so existing references to the test keep working.
     */
    @SneakyThrows
    public void testFaissHnsw_when1000Data_thenRecallIsAboveNinePointZero() {
        // Create and populate the index
        createKnnHnswIndex(KNNEngine.FAISS, INDEX_NAME, FIELD_NAME, 128);
        ingestTestData(INDEX_NAME, FIELD_NAME);

        int k = 100;
        for (int i = 0; i < testData.queries.length; i++) {
            List<KNNResult> knnResults = runKnnQuery(INDEX_NAME, FIELD_NAME, testData.queries[i], k);
            float recall = getRecall(
                Set.of(Arrays.copyOf(testData.groundTruthValues[i], k)),
                knnResults.stream().map(KNNResult::getDocId).collect(Collectors.toSet())
            );
            assertTrue("Recall: " + recall, recall > MIN_RECALL);
        }
    }

    /**
     * Resolves a classpath resource to its path.
     *
     * <p>An explicit check is used instead of {@code assert} because assertions are skipped when
     * the JVM runs without {@code -ea}, which would surface later as a bare NullPointerException.
     *
     * @param classLoader  loader used to look the resource up
     * @param resourceName classpath-relative resource name
     * @return the path component of the resource URL
     * @throws IllegalStateException if the resource does not exist
     */
    private static String getResourcePath(final ClassLoader classLoader, final String resourceName) {
        final URL resource = classLoader.getResource(resourceName);
        if (resource == null) {
            throw new IllegalStateException("Test resource not found on classpath: " + resourceName);
        }
        return resource.getPath();
    }

    /**
     * Computes recall, i.e. the fraction of ground-truth ids that are present in the retrieved ids.
     * Neither argument is modified.
     *
     * @param truth  ids expected for the query; must not be empty
     * @param result ids actually returned for the query
     * @return {@code |truth ∩ result| / |truth|}
     * @throws IllegalArgumentException if {@code truth} is empty, since recall is undefined in that case
     */
    private static float getRecall(final Set<String> truth, final Set<String> result) {
        if (truth.isEmpty()) {
            // Dividing by zero would silently produce NaN and hide a broken ground-truth file.
            throw new IllegalArgumentException("Ground truth must not be empty");
        }

        // Read-only count of the hits; retainAll() would alter the caller's set and fails on
        // unmodifiable sets.
        final long relevantRetrieved = result.stream().filter(truth::contains).count();

        return (float) relevantRetrieved / truth.size();
    }

    /**
     * Runs a k-NN query for a single vector and parses the search response.
     *
     * @param indexName   index to search
     * @param fieldName   vector field to query
     * @param queryVector query vector
     * @param k           number of neighbors requested; also passed as the result size of the search
     * @return the parsed search hits
     * @throws Exception if the request or the response parsing fails
     */
    private List<KNNResult> runKnnQuery(final String indexName, final String fieldName, final float[] queryVector, final int k)
        throws Exception {
        String query = KNNJsonQueryBuilder.builder()
            .fieldName(fieldName)
            .vector(ArrayUtils.toObject(queryVector))
            .k(k)
            .build()
            .getQueryString();
        Response response = searchKNNIndex(indexName, query, k);
        return parseSearchResponse(EntityUtils.toString(response.getEntity()), fieldName);
    }

    /**
     * Indexes every document of the test data set and verifies the resulting document count.
     *
     * @param indexName index to write to
     * @param fieldName vector field that receives the vectors
     * @throws Exception if indexing, refreshing or counting fails
     */
    private void ingestTestData(final String indexName, final String fieldName) throws Exception {
        // Index the test data
        for (int i = 0; i < testData.indexData.docs.length; i++) {
            addKnnDoc(
                indexName,
                Integer.toString(testData.indexData.docs[i]),
                fieldName,
                Floats.asList(testData.indexData.vectors[i]).toArray()
            );
        }

        // Refresh first so the count below sees every document that was just added
        refreshAllIndices();
        assertEquals(testData.indexData.docs.length, getDocCount(indexName));
    }

    /**
     * Creates a k-NN index whose vector field uses HNSW with L2 space and float vectors.
     *
     * @param knnEngine engine backing the method
     * @param indexName name of the index to create
     * @param fieldName name of the vector field
     * @param dimension dimension of the vectors
     * @throws IOException if the index creation request fails
     */
    private void createKnnHnswIndex(final KNNEngine knnEngine, final String indexName, final String fieldName, final int dimension)
        throws IOException {
        KNNJsonIndexMappingsBuilder.Method method = KNNJsonIndexMappingsBuilder.Method.builder()
            .methodName(METHOD_HNSW)
            .spaceType(SpaceType.L2.getValue())
            .engine(knnEngine.getName())
            .build();

        String knnIndexMapping = KNNJsonIndexMappingsBuilder.builder()
            .fieldName(fieldName)
            .dimension(dimension)
            .vectorDataType(VectorDataType.FLOAT.getValue())
            .method(method)
            .build()
            .getIndexMapping();

        createKnnIndex(indexName, knnIndexMapping);
    }
}
8 changes: 8 additions & 0 deletions src/test/resources/data/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,14 @@ test_queries_100x128.csv and packing 8 bits to 1 byte with ends up with 16 lengt
For quantization technique, we calculated the median(49935.95941056451) of all values in test_vectors_1000x128.json
and converted it as 0 if it is less than the median and 1 if it is equal to or larger than the median.

# test_ground_truth_binary_100.csv
The file contains the ground truth for the queries in test_queries_binary_100x128.csv against the data in
test_vectors_binary_1000x128.json, computed using Hamming distance.

# test_ground_truth_l2_100.csv
The file contains the ground truth for the queries in test_queries_100x128.csv against the data in test_vectors_1000x128.json,
computed using L2 distance.

# test_vectors_nested_1000x128.json
The file contains simulated data representing a nested field.
Consecutive ids are assigned for data from same parent document.
Expand Down
Loading
Loading