Skip to content

Commit

Permalink
Makes sure KNNVectorValues aren't recreated unnecessarily when
Browse files Browse the repository at this point in the history
quantization isn't needed

Signed-off-by: Tejas Shah <shatejas@amazon.com>
  • Loading branch information
shatejas committed Sep 20, 2024
1 parent 30adbe4 commit e348524
Showing 1 changed file with 38 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.function.Supplier;

import static org.opensearch.knn.common.FieldInfoExtractor.extractVectorDataType;
import static org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory.getVectorValues;
Expand Down Expand Up @@ -82,19 +83,19 @@ public void flush(int maxDoc, final Sorter.DocMap sortMap) throws IOException {
for (final NativeEngineFieldVectorsWriter<?> field : fields) {
final FieldInfo fieldInfo = field.getFieldInfo();
final VectorDataType vectorDataType = extractVectorDataType(fieldInfo);
int totalLiveDocs = getLiveDocs(getVectorValues(vectorDataType, field.getDocsWithField(), field.getVectors()));
int totalLiveDocs = field.getVectors().size();
if (totalLiveDocs > 0) {
KNNVectorValues<?> knnVectorValues = getVectorValues(vectorDataType, field.getDocsWithField(), field.getVectors());

final QuantizationState quantizationState = train(field.getFieldInfo(), knnVectorValues, totalLiveDocs);
final Supplier<KNNVectorValues<?>> knnVectorValuesSupplier = () -> getVectorValues(
vectorDataType,
field.getDocsWithField(),
field.getVectors()
);
final QuantizationState quantizationState = train(field.getFieldInfo(), knnVectorValuesSupplier, totalLiveDocs);
final NativeIndexWriter writer = NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState);

knnVectorValues = getVectorValues(vectorDataType, field.getDocsWithField(), field.getVectors());
final KNNVectorValues<?> knnVectorValues = knnVectorValuesSupplier.get();

StopWatch stopWatch = new StopWatch().start();

writer.flushIndex(knnVectorValues, totalLiveDocs);

long time_in_millis = stopWatch.stop().totalTime().millis();
KNNGraphValue.REFRESH_TOTAL_TIME_IN_MILLIS.incrementBy(time_in_millis);
log.debug("Flush took {} ms for vector field [{}]", time_in_millis, fieldInfo.getName());
Expand All @@ -110,17 +111,20 @@ public void mergeOneField(final FieldInfo fieldInfo, final MergeState mergeState
flatVectorsWriter.mergeOneField(fieldInfo, mergeState);

final VectorDataType vectorDataType = extractVectorDataType(fieldInfo);
int totalLiveDocs = getLiveDocs(getKNNVectorValuesForMerge(vectorDataType, fieldInfo, mergeState));
final Supplier<KNNVectorValues<?>> knnVectorValuesSupplier = () -> getKNNVectorValuesForMerge(
vectorDataType,
fieldInfo,
mergeState
);
int totalLiveDocs = getLiveDocs(knnVectorValuesSupplier.get());
if (totalLiveDocs == 0) {
log.debug("[Merge] No live docs for field {}", fieldInfo.getName());
return;
}

KNNVectorValues<?> knnVectorValues = getKNNVectorValuesForMerge(vectorDataType, fieldInfo, mergeState);
final QuantizationState quantizationState = train(fieldInfo, knnVectorValues, totalLiveDocs);
final QuantizationState quantizationState = train(fieldInfo, knnVectorValuesSupplier, totalLiveDocs);
final NativeIndexWriter writer = NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState);

knnVectorValues = getKNNVectorValuesForMerge(vectorDataType, fieldInfo, mergeState);
final KNNVectorValues<?> knnVectorValues = knnVectorValuesSupplier.get();

StopWatch stopWatch = new StopWatch().start();

Expand Down Expand Up @@ -191,27 +195,36 @@ private <T> KNNVectorValues<T> getKNNVectorValuesForMerge(
final VectorDataType vectorDataType,
final FieldInfo fieldInfo,
final MergeState mergeState
) throws IOException {
switch (fieldInfo.getVectorEncoding()) {
case FLOAT32:
FloatVectorValues mergedFloats = MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState);
return getVectorValues(vectorDataType, mergedFloats);
case BYTE:
ByteVectorValues mergedBytes = MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState);
return getVectorValues(vectorDataType, mergedBytes);
default:
throw new IllegalStateException("Unsupported vector encoding [" + fieldInfo.getVectorEncoding() + "]");
) {
try {
switch (fieldInfo.getVectorEncoding()) {
case FLOAT32:
FloatVectorValues mergedFloats = MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState);
return getVectorValues(vectorDataType, mergedFloats);
case BYTE:
ByteVectorValues mergedBytes = MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState);
return getVectorValues(vectorDataType, mergedBytes);
default:
throw new IllegalStateException("Unsupported vector encoding [" + fieldInfo.getVectorEncoding() + "]");
}
} catch (final IOException e) {
log.error("Unable to merge vectors for field [{}]", fieldInfo.getName(), e);
throw new IllegalStateException("Unable to merge vectors for field [" + fieldInfo.getName() + "]", e);
}
}

private QuantizationState train(final FieldInfo fieldInfo, final KNNVectorValues<?> knnVectorValues, final int totalLiveDocs)
throws IOException {
private QuantizationState train(
final FieldInfo fieldInfo,
final Supplier<KNNVectorValues<?>> knnVectorValuesSupplier,
final int totalLiveDocs
) throws IOException {

final QuantizationService quantizationService = QuantizationService.getInstance();
final QuantizationParams quantizationParams = quantizationService.getQuantizationParams(fieldInfo);
QuantizationState quantizationState = null;
if (quantizationParams != null && totalLiveDocs > 0) {
initQuantizationStateWriterIfNecessary();
KNNVectorValues<?> knnVectorValues = knnVectorValuesSupplier.get();
quantizationState = quantizationService.train(quantizationParams, knnVectorValues, totalLiveDocs);
quantizationStateWriter.writeState(fieldInfo.getFieldNumber(), quantizationState);
}
Expand Down

0 comments on commit e348524

Please sign in to comment.