Skip to content

Commit

Permalink
Merge branch 'main' into BWC_Tests_k-NN
Browse files Browse the repository at this point in the history
  • Loading branch information
naveentatikonda committed Nov 23, 2021
2 parents ca038d5 + c929559 commit 3c0e4b9
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 3 deletions.
7 changes: 7 additions & 0 deletions jni/src/faiss_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,13 @@ jbyteArray knn_jni::faiss_wrapper::TrainIndex(knn_jni::JNIUtilInterface * jniUti
std::unique_ptr<faiss::Index> indexWriter;
indexWriter.reset(faiss::index_factory((int) dimensionJ, indexDescriptionCpp.c_str(), metric));

// Related to https://github.com/facebookresearch/faiss/issues/1621. HNSWPQ defaults to l2 even when metric is
// passed in. This updates it to the correct metric.
indexWriter->metric_type = metric;
if (auto * indexHnswPq = dynamic_cast<faiss::IndexHNSWPQ*>(indexWriter.get())) {
indexHnswPq->storage->metric_type = metric;
}

// Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread
if(parametersCpp.find(knn_jni::INDEX_THREAD_QUANTITY) != parametersCpp.end()) {
auto threadCount = jniUtil->ConvertJavaObjectToCppInteger(env, parametersCpp[knn_jni::INDEX_THREAD_QUANTITY]);
Expand Down
8 changes: 7 additions & 1 deletion jni/tests/faiss_wrapper_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,16 @@ TEST(FaissCreateIndexFromTemplateTest, BasicAssertions) {
jniEnv, reinterpret_cast<jobjectArray>(&vectors)))
.WillRepeatedly(Return(vectors.size()));

std::string spaceType = knn_jni::L2;
std::unordered_map<std::string, jobject> parametersMap;
parametersMap[knn_jni::SPACE_TYPE] = (jobject) &spaceType;

knn_jni::faiss_wrapper::CreateIndexFromTemplate(
&mockJNIUtil, jniEnv, reinterpret_cast<jintArray>(&ids),
reinterpret_cast<jobjectArray>(&vectors), (jstring)&indexPath,
reinterpret_cast<jbyteArray>(&(vectorIoWriter.data)));
reinterpret_cast<jbyteArray>(&(vectorIoWriter.data)),
(jobject) &parametersMap
);

// Make sure index can be loaded
std::unique_ptr<faiss::Index> index(test_util::FaissLoadIndex(indexPath));
Expand Down
4 changes: 2 additions & 2 deletions src/test/java/org/opensearch/knn/index/FaissIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ public void testEndToEnd_fromMethod() throws IOException, InterruptedException {
String fieldName = "test-field-1";

KNNMethod hnswMethod = KNNEngine.FAISS.getMethod(KNNConstants.METHOD_HNSW);
SpaceType spaceType = SpaceType.INNER_PRODUCT;
SpaceType spaceType = SpaceType.L2;

List<Integer> mValues = ImmutableList.of(16, 32, 64, 128);
List<Integer> efConstructionValues = ImmutableList.of(16, 32, 64, 128);
Expand Down Expand Up @@ -115,7 +115,7 @@ public void testEndToEnd_fromMethod() throws IOException, InterruptedException {
List<Float> actualScores = parseSearchResponseScore(responseBody, fieldName);
for (int j = 0; j < k; j++) {
float[] primitiveArray = Floats.toArray(Arrays.stream(knnResults.get(j).getVector()).collect(Collectors.toList()));
assertEquals(KNNEngine.FAISS.score(KNNScoringUtil.innerProduct(testData.queries[i], primitiveArray),
assertEquals(KNNEngine.FAISS.score(KNNScoringUtil.l2Squared(testData.queries[i], primitiveArray),
spaceType), actualScores.get(j), 0.0001);
}
}
Expand Down

0 comments on commit 3c0e4b9

Please sign in to comment.