Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for configuring HNSW parameters #79193

Merged
merged 4 commits into from
Oct 18, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ public CodecService(@Nullable MapperService mapperService) {
codecs.put(BEST_COMPRESSION_CODEC, new Lucene90Codec(Lucene90Codec.Mode.BEST_COMPRESSION));
} else {
codecs.put(DEFAULT_CODEC,
new PerFieldMappingPostingFormatCodec(Lucene90Codec.Mode.BEST_SPEED, mapperService));
new PerFieldMappingCodec(Lucene90Codec.Mode.BEST_SPEED, mapperService));
codecs.put(BEST_COMPRESSION_CODEC,
new PerFieldMappingPostingFormatCodec(Lucene90Codec.Mode.BEST_COMPRESSION, mapperService));
new PerFieldMappingCodec(Lucene90Codec.Mode.BEST_COMPRESSION, mapperService));
}
codecs.put(LUCENE_DEFAULT_CODEC, Codec.getDefault());
for (String codec : Codec.availableCodecs()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,31 +10,32 @@

import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.DocValuesFormat;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.PostingsFormat;
import org.apache.lucene.codecs.lucene90.Lucene90Codec;
import org.apache.lucene.codecs.lucene90.Lucene90DocValuesFormat;
import org.elasticsearch.common.lucene.Lucene;
import org.elasticsearch.index.mapper.MapperService;

/**
* {@link PerFieldMappingPostingFormatCodec This postings format} is the default
* {@link PostingsFormat} for Elasticsearch. It utilizes the
* {@link MapperService} to lookup a {@link PostingsFormat} per field. This
* allows users to change the low level postings format for individual fields
* per index in real time via the mapping API. If no specific postings format is
* configured for a specific field the default postings format is used.
* {@link PerFieldMappingCodec This postings format} is the default
mayya-sharipova marked this conversation as resolved.
Show resolved Hide resolved
* {@link PostingsFormat} and {@link KnnVectorsFormat} for Elasticsearch. It utilizes the
* {@link MapperService} to lookup a {@link PostingsFormat} and {@link KnnVectorsFormat} per field. This
* allows users to change the low level postings format and vectors format for individual fields
* per index in real time via the mapping API. If no specific postings format or vector format is
* configured for a specific field the default postings or vector format is used.
*/
public class PerFieldMappingPostingFormatCodec extends Lucene90Codec {
public class PerFieldMappingCodec extends Lucene90Codec {
mayya-sharipova marked this conversation as resolved.
Show resolved Hide resolved
private final MapperService mapperService;

private final DocValuesFormat docValuesFormat = new Lucene90DocValuesFormat();

static {
assert Codec.forName(Lucene.LATEST_CODEC).getClass().isAssignableFrom(PerFieldMappingPostingFormatCodec.class) :
"PerFieldMappingPostingFormatCodec must subclass the latest " + "lucene codec: " + Lucene.LATEST_CODEC;
assert Codec.forName(Lucene.LATEST_CODEC).getClass().isAssignableFrom(PerFieldMappingCodec.class) :
"PerFieldMappingCodec must subclass the latest " + "lucene codec: " + Lucene.LATEST_CODEC;
}

public PerFieldMappingPostingFormatCodec(Mode compressionMode, MapperService mapperService) {
public PerFieldMappingCodec(Mode compressionMode, MapperService mapperService) {
super(compressionMode);
this.mapperService = mapperService;
}
Expand All @@ -48,6 +49,15 @@ public PostingsFormat getPostingsFormatForField(String field) {
return format;
}

@Override
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
KnnVectorsFormat format = mapperService.mappingLookup().getKnnVectorsFormatForField(field);
if (format == null) {
return super.getKnnVectorsFormatForField(field);
}
return format;
}

@Override
public DocValuesFormat getDocValuesFormatForField(String field) {
return docValuesFormat;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

package org.elasticsearch.index.mapper;

import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.PostingsFormat;
import org.elasticsearch.cluster.metadata.DataStream;
import org.elasticsearch.index.IndexSettings;
Expand Down Expand Up @@ -228,6 +229,20 @@ public PostingsFormat getPostingsFormat(String field) {
return completionFields.contains(field) ? CompletionFieldMapper.postingsFormat() : null;
}

/**
* Returns the knn vectors format for a particular field
* @param field the field to retrieve a knn vectors format for
* @return the knn vectors format for the field, or {@code null} if the default format should be used
*/
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
Mapper fieldMapper = fieldMappers.get(field);
if (fieldMapper instanceof VectorFieldMapper) {
return ((VectorFieldMapper) fieldMapper).getKnnVectorsFormatForField();
} else {
return null;
}
}

void checkLimits(IndexSettings settings) {
checkFieldLimit(settings.getMappingTotalFieldsLimit());
checkObjectDepthLimit(settings.getMappingDepthLimit());
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/
package org.elasticsearch.index.mapper;

import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.lucene90.Lucene90HnswVectorsFormat;
import org.elasticsearch.common.xcontent.support.XContentMapValues;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentBuilder;

import java.io.IOException;
import java.util.Map;
import java.util.Objects;

import static org.apache.lucene.codecs.lucene90.Lucene90HnswVectorsFormat.DEFAULT_MAX_CONN;
import static org.apache.lucene.codecs.lucene90.Lucene90HnswVectorsFormat.DEFAULT_BEAM_WIDTH;

/**
* Field mapper for a vector field for ann search.
*/

public abstract class VectorFieldMapper extends FieldMapper {
public static final IndexOptions DEFAULT_INDEX_OPTIONS = new HNSWIndexOptions(DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH);
protected final IndexOptions indexOptions;

protected VectorFieldMapper(String simpleName, MappedFieldType mappedFieldType, MultiFields multiFields, CopyTo copyTo,
IndexOptions indexOptions) {
super(simpleName, mappedFieldType, multiFields, copyTo);
this.indexOptions = indexOptions;
}

/**
* Returns the knn vectors format that is customly set up for this field or {@code null} if
* the format is not set up or if the set up format matches the default format.
* @return the knn vectors format for the field, or {@code null} if the default format should be used
*/
public KnnVectorsFormat getKnnVectorsFormatForField() {
if (indexOptions == null && indexOptions == DEFAULT_INDEX_OPTIONS) {
return null;
} else {
HNSWIndexOptions hnswIndexOptions = (HNSWIndexOptions) indexOptions;
return new Lucene90HnswVectorsFormat(hnswIndexOptions.m, hnswIndexOptions.efConstruction);
}
}

public static IndexOptions parseVectorIndexOptions(String fieldName, Object propNode) {
if (propNode == null) {
return null;
}
Map<?, ?> indexOptionsMap = (Map<?, ?>) propNode;
String type = XContentMapValues.nodeStringValue(indexOptionsMap.remove("type"), "hnsw");
mayya-sharipova marked this conversation as resolved.
Show resolved Hide resolved
if (type.equals("hnsw")) {
return HNSWIndexOptions.parseIndexOptions(fieldName, indexOptionsMap);
} else {
throw new MapperParsingException("Unknown vector index options type [" + type + "] for field [" + fieldName + "]");
}
}

public abstract static class IndexOptions implements ToXContent {
protected final String type;
public IndexOptions(String type) {
this.type = type;
}
}

public static class HNSWIndexOptions extends IndexOptions {
mayya-sharipova marked this conversation as resolved.
Show resolved Hide resolved
private final int m;
private final int efConstruction;

public HNSWIndexOptions(int m, int efConstruction) {
super("hnsw");
this.m = m;
this.efConstruction = efConstruction;
}

public int m() {
return m;
}

public int efConstruction() {
return efConstruction;
}

public static IndexOptions parseIndexOptions(String fieldName, Map<?, ?> indexOptionsMap) {
int m = XContentMapValues.nodeIntegerValue(indexOptionsMap.remove("m"), DEFAULT_MAX_CONN);
int efConstruction = XContentMapValues.nodeIntegerValue(indexOptionsMap.remove("ef_construction"), DEFAULT_BEAM_WIDTH);
MappingParser.checkNoRemainingFields(fieldName, indexOptionsMap);
if (m == DEFAULT_MAX_CONN && efConstruction == DEFAULT_BEAM_WIDTH) {
return VectorFieldMapper.DEFAULT_INDEX_OPTIONS;
} else {
return new HNSWIndexOptions(m, efConstruction);
}
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field("type", type);
builder.field("m", m);
builder.field("ef_construction", efConstruction);
builder.endObject();
return builder;
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
HNSWIndexOptions that = (HNSWIndexOptions) o;
return m == that.m && efConstruction == that.efConstruction;
}

@Override
public int hashCode() {
return Objects.hash(type, m, efConstruction);
}

@Override
public String toString() {
return "{type=" + type + ", m=" + m + ", ef_construction=" + efConstruction + " }";
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ public class CodecTests extends ESTestCase {

public void testResolveDefaultCodecs() throws Exception {
CodecService codecService = createCodecService();
assertThat(codecService.codec("default"), instanceOf(PerFieldMappingPostingFormatCodec.class));
assertThat(codecService.codec("default"), instanceOf(PerFieldMappingCodec.class));
assertThat(codecService.codec("default"), instanceOf(Lucene90Codec.class));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.lucene.Lucene;
import org.elasticsearch.common.unit.Fuzziness;
import org.elasticsearch.index.codec.PerFieldMappingCodec;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentFactory;
Expand All @@ -38,7 +39,6 @@
import org.elasticsearch.index.analysis.IndexAnalyzers;
import org.elasticsearch.index.analysis.NamedAnalyzer;
import org.elasticsearch.index.codec.CodecService;
import org.elasticsearch.index.codec.PerFieldMappingPostingFormatCodec;
import org.hamcrest.FeatureMatcher;
import org.hamcrest.Matcher;
import org.hamcrest.Matchers;
Expand Down Expand Up @@ -122,8 +122,8 @@ public void testPostingsFormat() throws IOException {
MapperService mapperService = createMapperService(fieldMapping(this::minimalMapping));
CodecService codecService = new CodecService(mapperService);
Codec codec = codecService.codec("default");
assertThat(codec, instanceOf(PerFieldMappingPostingFormatCodec.class));
PerFieldMappingPostingFormatCodec perFieldCodec = (PerFieldMappingPostingFormatCodec) codec;
assertThat(codec, instanceOf(PerFieldMappingCodec.class));
PerFieldMappingCodec perFieldCodec = (PerFieldMappingCodec) codec;
assertThat(perFieldCodec.getPostingsFormatForField("field"), instanceOf(Completion90PostingsFormat.class));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ setup:
dims: 5
index: true
similarity: dot_product
index_options:
mayya-sharipova marked this conversation as resolved.
Show resolved Hide resolved
type: hnsw
m: 15
ef_construction: 80
- do:
index:
index: test-index
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ setup:
dims: 3
index: true
similarity: l2_norm
index_options:
type: hnsw
m: 15

---
"Indexing of Dense vectors should error when dims don't match defined in the mapping":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import org.apache.lucene.search.Query;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.Version;
import org.elasticsearch.index.mapper.VectorFieldMapper;
import org.elasticsearch.xcontent.XContentParser.Token;
import org.elasticsearch.common.xcontent.support.XContentMapValues;
import org.elasticsearch.index.fielddata.IndexFieldData;
Expand Down Expand Up @@ -47,7 +48,7 @@
/**
* A {@link FieldMapper} for indexing a dense vector of floats.
*/
public class DenseVectorFieldMapper extends FieldMapper {
public class DenseVectorFieldMapper extends VectorFieldMapper {

public static final String CONTENT_TYPE = "dense_vector";
public static short MAX_DIMS_COUNT = 2048; //maximum allowed number of dimensions
Expand All @@ -73,6 +74,8 @@ public static class Builder extends FieldMapper.Builder {
private final Parameter<Boolean> indexed = Parameter.indexParam(m -> toType(m).indexed, false);
private final Parameter<VectorSimilarity> similarity = Parameter.enumParam(
"similarity", false, m -> toType(m).similarity, null, VectorSimilarity.class);
private final Parameter<IndexOptions> indexOptions = new Parameter<>("index_options", false, () -> null,
(n, c, o) -> VectorFieldMapper.parseVectorIndexOptions(n, o), m -> toType(m).indexOptions);
private final Parameter<Map<String, String>> meta = Parameter.metaParam();

final Version indexVersionCreated;
Expand All @@ -84,11 +87,13 @@ public Builder(String name, Version indexVersionCreated) {
this.indexed.requiresParameters(similarity);
this.similarity.setSerializerCheck((id, ic, v) -> v != null);
this.similarity.requiresParameters(indexed);
this.indexOptions.requiresParameters(indexed);
this.indexOptions.setSerializerCheck((id, ic, v) -> v != null);
}

@Override
protected List<Parameter<?>> getParameters() {
return List.of(dims, indexed, similarity, meta);
return List.of(dims, indexed, similarity, indexOptions, meta);
}

@Override
Expand All @@ -102,7 +107,8 @@ public DenseVectorFieldMapper build(MapperBuilderContext context) {
similarity.getValue(),
indexVersionCreated,
multiFieldsBuilder.build(this, context),
copyTo.build());
copyTo.build(),
indexOptions.getValue());
}
}

Expand Down Expand Up @@ -187,10 +193,10 @@ public Query termQuery(Object value, SearchExecutionContext context) {
private final VectorSimilarity similarity;
private final Version indexCreatedVersion;

private DenseVectorFieldMapper(String simpleName, MappedFieldType mappedFieldType, int dims,
boolean indexed, VectorSimilarity similarity,
Version indexCreatedVersion, MultiFields multiFields, CopyTo copyTo) {
super(simpleName, mappedFieldType, multiFields, copyTo);
private DenseVectorFieldMapper(String simpleName, MappedFieldType mappedFieldType, int dims, boolean indexed,
VectorSimilarity similarity, Version indexCreatedVersion, MultiFields multiFields,
CopyTo copyTo, VectorFieldMapper.IndexOptions indexOptions) {
super(simpleName, mappedFieldType, multiFields, copyTo, indexOptions);
this.dims = dims;
this.indexed = indexed;
this.similarity = similarity;
Expand Down
Loading