Skip to content

Commit

Permalink
Add support for configuring HNSW parameters
Browse files Browse the repository at this point in the history
This PR extends the dense_vector type to allow configure HNSW params in
`index_options`:
`m` – max number of connections for each  node,
`ef_construction` – number  of candidate neighbors to track while searching
the graph for each newly inserted node.

```
"mappings": {
  "properties": {
    "my_vector": {
      "type": "dense_vector",
      "dims": 128,
      "index": true,
      "similarity": "l2_norm",
      "index_options": {
        "type" : "hnsw",
        "m" : 15,
        "ef_construction" : 50
      }
    }
  }
}
```

index_options as an object, and all parameters underneath are optional.
If  `m` or `ef_contruction` are not provided, the default values from the
current codec will be used.

Relates to #78473
  • Loading branch information
mayya-sharipova committed Oct 14, 2021
1 parent 0089bd0 commit 5004872
Show file tree
Hide file tree
Showing 10 changed files with 252 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ public CodecService(@Nullable MapperService mapperService) {
codecs.put(BEST_COMPRESSION_CODEC, new Lucene90Codec(Lucene90Codec.Mode.BEST_COMPRESSION));
} else {
codecs.put(DEFAULT_CODEC,
new PerFieldMappingPostingFormatCodec(Lucene90Codec.Mode.BEST_SPEED, mapperService));
new PerFieldMappingCodec(Lucene90Codec.Mode.BEST_SPEED, mapperService));
codecs.put(BEST_COMPRESSION_CODEC,
new PerFieldMappingPostingFormatCodec(Lucene90Codec.Mode.BEST_COMPRESSION, mapperService));
new PerFieldMappingCodec(Lucene90Codec.Mode.BEST_COMPRESSION, mapperService));
}
codecs.put(LUCENE_DEFAULT_CODEC, Codec.getDefault());
for (String codec : Codec.availableCodecs()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,31 +10,32 @@

import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.DocValuesFormat;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.PostingsFormat;
import org.apache.lucene.codecs.lucene90.Lucene90Codec;
import org.apache.lucene.codecs.lucene90.Lucene90DocValuesFormat;
import org.elasticsearch.common.lucene.Lucene;
import org.elasticsearch.index.mapper.MapperService;

/**
* {@link PerFieldMappingPostingFormatCodec This postings format} is the default
* {@link PostingsFormat} for Elasticsearch. It utilizes the
* {@link MapperService} to lookup a {@link PostingsFormat} per field. This
* allows users to change the low level postings format for individual fields
* per index in real time via the mapping API. If no specific postings format is
* configured for a specific field the default postings format is used.
* {@link PerFieldMappingCodec This postings format} is the default
* {@link PostingsFormat} and {@link KnnVectorsFormat} for Elasticsearch. It utilizes the
* {@link MapperService} to lookup a {@link PostingsFormat} and {@link KnnVectorsFormat} per field. This
* allows users to change the low level postings format and vectors format for individual fields
* per index in real time via the mapping API. If no specific postings format or vector format is
* configured for a specific field the default postings or vector format is used.
*/
public class PerFieldMappingPostingFormatCodec extends Lucene90Codec {
public class PerFieldMappingCodec extends Lucene90Codec {
private final MapperService mapperService;

private final DocValuesFormat docValuesFormat = new Lucene90DocValuesFormat();

static {
assert Codec.forName(Lucene.LATEST_CODEC).getClass().isAssignableFrom(PerFieldMappingPostingFormatCodec.class) :
"PerFieldMappingPostingFormatCodec must subclass the latest " + "lucene codec: " + Lucene.LATEST_CODEC;
assert Codec.forName(Lucene.LATEST_CODEC).getClass().isAssignableFrom(PerFieldMappingCodec.class) :
"PerFieldMappingCodec must subclass the latest " + "lucene codec: " + Lucene.LATEST_CODEC;
}

public PerFieldMappingPostingFormatCodec(Mode compressionMode, MapperService mapperService) {
public PerFieldMappingCodec(Mode compressionMode, MapperService mapperService) {
super(compressionMode);
this.mapperService = mapperService;
}
Expand All @@ -48,6 +49,15 @@ public PostingsFormat getPostingsFormatForField(String field) {
return format;
}

@Override
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
KnnVectorsFormat format = mapperService.mappingLookup().getKnnVectorsFormatForField(field);
if (format == null) {
return super.getKnnVectorsFormatForField(field);
}
return format;
}

@Override
public DocValuesFormat getDocValuesFormatForField(String field) {
return docValuesFormat;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

package org.elasticsearch.index.mapper;

import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.PostingsFormat;
import org.elasticsearch.cluster.metadata.DataStream;
import org.elasticsearch.index.IndexSettings;
Expand Down Expand Up @@ -228,6 +229,20 @@ public PostingsFormat getPostingsFormat(String field) {
return completionFields.contains(field) ? CompletionFieldMapper.postingsFormat() : null;
}

/**
* Returns the knn vectors format for a particular field
* @param field the field to retrieve a knn vectors format for
* @return the knn vectors format for the field, or {@code null} if the default format should be used
*/
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
Mapper fieldMapper = fieldMappers.get(field);
if (fieldMapper instanceof VectorFieldMapper) {
return ((VectorFieldMapper) fieldMapper).getKnnVectorsFormatForField();
} else {
return null;
}
}

void checkLimits(IndexSettings settings) {
checkFieldLimit(settings.getMappingTotalFieldsLimit());
checkObjectDepthLimit(settings.getMappingDepthLimit());
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/
package org.elasticsearch.index.mapper;

import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.lucene90.Lucene90HnswVectorsFormat;
import org.elasticsearch.common.xcontent.support.XContentMapValues;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentBuilder;

import java.io.IOException;
import java.util.Map;
import java.util.Objects;

import static org.apache.lucene.codecs.lucene90.Lucene90HnswVectorsFormat.DEFAULT_MAX_CONN;
import static org.apache.lucene.codecs.lucene90.Lucene90HnswVectorsFormat.DEFAULT_BEAM_WIDTH;

/**
* Field mapper for a vector field for ann search.
*/

public abstract class VectorFieldMapper extends FieldMapper {
public static final IndexOptions DEFAULT_INDEX_OPTIONS = new HNSWIndexOptions(DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH);
protected final IndexOptions indexOptions;

protected VectorFieldMapper(String simpleName, MappedFieldType mappedFieldType, MultiFields multiFields, CopyTo copyTo,
IndexOptions indexOptions) {
super(simpleName, mappedFieldType, multiFields, copyTo);
this.indexOptions = indexOptions;
}

/**
* Returns the knn vectors format that is customly set up for this field or {@code null} if
* the format is not set up or if the set up format matches the default format.
* @return the knn vectors format for the field, or {@code null} if the default format should be used
*/
public KnnVectorsFormat getKnnVectorsFormatForField() {
if (indexOptions == null && indexOptions == DEFAULT_INDEX_OPTIONS) {
return null;
} else {
HNSWIndexOptions hnswIndexOptions = (HNSWIndexOptions) indexOptions;
return new Lucene90HnswVectorsFormat(hnswIndexOptions.m, hnswIndexOptions.efConstruction);
}
}

public static IndexOptions parseVectorIndexOptions(String fieldName, Object propNode) {
if (propNode == null) {
return null;
}
Map<?, ?> indexOptionsMap = (Map<?, ?>) propNode;
String type = XContentMapValues.nodeStringValue(indexOptionsMap.remove("type"), "hnsw");
if (type.equals("hnsw")) {
return HNSWIndexOptions.parseIndexOptions(fieldName, indexOptionsMap);
} else {
throw new MapperParsingException("Unknown vector index options type [" + type + "] for field [" + fieldName + "]");
}
}

public abstract static class IndexOptions implements ToXContent {
protected final String type;
public IndexOptions(String type) {
this.type = type;
}
}

public static class HNSWIndexOptions extends IndexOptions {
private final int m;
private final int efConstruction;

public HNSWIndexOptions(int m, int efConstruction) {
super("hnsw");
this.m = m;
this.efConstruction = efConstruction;
}

public int m() {
return m;
}

public int efConstruction() {
return efConstruction;
}

public static IndexOptions parseIndexOptions(String fieldName, Map<?, ?> indexOptionsMap) {
int m = XContentMapValues.nodeIntegerValue(indexOptionsMap.remove("m"), DEFAULT_MAX_CONN);
int efConstruction = XContentMapValues.nodeIntegerValue(indexOptionsMap.remove("ef_construction"), DEFAULT_BEAM_WIDTH);
MappingParser.checkNoRemainingFields(fieldName, indexOptionsMap);
if (m == DEFAULT_MAX_CONN && efConstruction == DEFAULT_BEAM_WIDTH) {
return VectorFieldMapper.DEFAULT_INDEX_OPTIONS;
} else {
return new HNSWIndexOptions(m, efConstruction);
}
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field("type", type);
builder.field("m", m);
builder.field("ef_construction", efConstruction);
builder.endObject();
return builder;
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
HNSWIndexOptions that = (HNSWIndexOptions) o;
return m == that.m && efConstruction == that.efConstruction;
}

@Override
public int hashCode() {
return Objects.hash(type, m, efConstruction);
}

@Override
public String toString() {
return "{type=" + type + ", m=" + m + ", ef_construction=" + efConstruction + " }";
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ public class CodecTests extends ESTestCase {

public void testResolveDefaultCodecs() throws Exception {
CodecService codecService = createCodecService();
assertThat(codecService.codec("default"), instanceOf(PerFieldMappingPostingFormatCodec.class));
assertThat(codecService.codec("default"), instanceOf(PerFieldMappingCodec.class));
assertThat(codecService.codec("default"), instanceOf(Lucene90Codec.class));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.lucene.Lucene;
import org.elasticsearch.common.unit.Fuzziness;
import org.elasticsearch.index.codec.PerFieldMappingCodec;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentFactory;
Expand All @@ -38,7 +39,6 @@
import org.elasticsearch.index.analysis.IndexAnalyzers;
import org.elasticsearch.index.analysis.NamedAnalyzer;
import org.elasticsearch.index.codec.CodecService;
import org.elasticsearch.index.codec.PerFieldMappingPostingFormatCodec;
import org.hamcrest.FeatureMatcher;
import org.hamcrest.Matcher;
import org.hamcrest.Matchers;
Expand Down Expand Up @@ -122,8 +122,8 @@ public void testPostingsFormat() throws IOException {
MapperService mapperService = createMapperService(fieldMapping(this::minimalMapping));
CodecService codecService = new CodecService(mapperService);
Codec codec = codecService.codec("default");
assertThat(codec, instanceOf(PerFieldMappingPostingFormatCodec.class));
PerFieldMappingPostingFormatCodec perFieldCodec = (PerFieldMappingPostingFormatCodec) codec;
assertThat(codec, instanceOf(PerFieldMappingCodec.class));
PerFieldMappingCodec perFieldCodec = (PerFieldMappingCodec) codec;
assertThat(perFieldCodec.getPostingsFormatForField("field"), instanceOf(Completion90PostingsFormat.class));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ setup:
dims: 5
index: true
similarity: dot_product
index_options:
type: hnsw
m: 15
ef_construction: 80
- do:
index:
index: test-index
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ setup:
dims: 3
index: true
similarity: l2_norm
index_options:
type: hnsw
m: 15

---
"Indexing of Dense vectors should error when dims don't match defined in the mapping":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import org.apache.lucene.search.Query;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.Version;
import org.elasticsearch.index.mapper.VectorFieldMapper;
import org.elasticsearch.xcontent.XContentParser.Token;
import org.elasticsearch.common.xcontent.support.XContentMapValues;
import org.elasticsearch.index.fielddata.IndexFieldData;
Expand Down Expand Up @@ -47,7 +48,7 @@
/**
* A {@link FieldMapper} for indexing a dense vector of floats.
*/
public class DenseVectorFieldMapper extends FieldMapper {
public class DenseVectorFieldMapper extends VectorFieldMapper {

public static final String CONTENT_TYPE = "dense_vector";
public static short MAX_DIMS_COUNT = 2048; //maximum allowed number of dimensions
Expand All @@ -73,6 +74,8 @@ public static class Builder extends FieldMapper.Builder {
private final Parameter<Boolean> indexed = Parameter.indexParam(m -> toType(m).indexed, false);
private final Parameter<VectorSimilarity> similarity = Parameter.enumParam(
"similarity", false, m -> toType(m).similarity, null, VectorSimilarity.class);
private final Parameter<IndexOptions> indexOptions = new Parameter<>("index_options", false, () -> null,
(n, c, o) -> VectorFieldMapper.parseVectorIndexOptions(n, o), m -> toType(m).indexOptions);
private final Parameter<Map<String, String>> meta = Parameter.metaParam();

final Version indexVersionCreated;
Expand All @@ -84,11 +87,13 @@ public Builder(String name, Version indexVersionCreated) {
this.indexed.requiresParameters(similarity);
this.similarity.setSerializerCheck((id, ic, v) -> v != null);
this.similarity.requiresParameters(indexed);
this.indexOptions.requiresParameters(indexed);
this.indexOptions.setSerializerCheck((id, ic, v) -> v != null);
}

@Override
protected List<Parameter<?>> getParameters() {
return List.of(dims, indexed, similarity, meta);
return List.of(dims, indexed, similarity, indexOptions, meta);
}

@Override
Expand All @@ -102,7 +107,8 @@ public DenseVectorFieldMapper build(MapperBuilderContext context) {
similarity.getValue(),
indexVersionCreated,
multiFieldsBuilder.build(this, context),
copyTo.build());
copyTo.build(),
indexOptions.getValue());
}
}

Expand Down Expand Up @@ -187,10 +193,10 @@ public Query termQuery(Object value, SearchExecutionContext context) {
private final VectorSimilarity similarity;
private final Version indexCreatedVersion;

private DenseVectorFieldMapper(String simpleName, MappedFieldType mappedFieldType, int dims,
boolean indexed, VectorSimilarity similarity,
Version indexCreatedVersion, MultiFields multiFields, CopyTo copyTo) {
super(simpleName, mappedFieldType, multiFields, copyTo);
private DenseVectorFieldMapper(String simpleName, MappedFieldType mappedFieldType, int dims, boolean indexed,
VectorSimilarity similarity, Version indexCreatedVersion, MultiFields multiFields,
CopyTo copyTo, VectorFieldMapper.IndexOptions indexOptions) {
super(simpleName, mappedFieldType, multiFields, copyTo, indexOptions);
this.dims = dims;
this.indexed = indexed;
this.similarity = similarity;
Expand Down
Loading

0 comments on commit 5004872

Please sign in to comment.