diff --git a/pom.xml b/pom.xml
index ecb15d4b84..2dead29b50 100644
--- a/pom.xml
+++ b/pom.xml
@@ -105,10 +105,18 @@
               <mainClass>io.anserini.index.IndexCollection</mainClass>
               <id>IndexCollection</id>
             </program>
+            <program>
+              <mainClass>io.anserini.index.IndexVectorCollection</mainClass>
+              <id>IndexVectorCollection</id>
+            </program>
             <program>
               <mainClass>io.anserini.search.SearchCollection</mainClass>
               <id>SearchCollection</id>
             </program>
+            <program>
+              <mainClass>io.anserini.search.SearchVectorCollection</mainClass>
+              <id>SearchVectorCollection</id>
+            </program>
             <program>
               <mainClass>io.anserini.search.SearchMsmarco</mainClass>
               <id>SearchMsmarco</id>
diff --git a/src/main/java/io/anserini/collection/VectorCollection.java b/src/main/java/io/anserini/collection/VectorCollection.java
new file mode 100644
index 0000000000..30ebbb74cf
--- /dev/null
+++ b/src/main/java/io/anserini/collection/VectorCollection.java
@@ -0,0 +1,93 @@
+/*
+ * Anserini: A Lucene toolkit for reproducible information retrieval research
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.anserini.collection;
+
+import com.fasterxml.jackson.databind.JsonNode;
+
+import java.io.IOException;
+import java.nio.file.Path;
+import java.util.HashMap;
+import java.util.Map;
+
+
+/**
+ * A document collection for encoded dense vectors for ANN (HNSW) search.
+ * The "vector" field is copied into the "contents" field for indexing.
+ */
+public class VectorCollection extends DocumentCollection<VectorCollection.Document> {
+  public VectorCollection(Path path) {
+    this.path = path;
+  }
+
+  @Override
+  public FileSegment<VectorCollection.Document> createFileSegment(Path p) throws IOException {
+    return new VectorCollection.Segment<>(p);
+  }
+
+  public static class Segment<T extends Document> extends JsonCollection.Segment<T> {
+    public Segment(Path path) throws IOException {
+      super(path);
+    }
+
+    @Override
+    protected Document createNewDocument(JsonNode json) {
+      return new Document(json);
+    }
+  }
+
+  public static class Document extends JsonCollection.Document {
+    private final String id;
+    private final String contents;
+    private final String raw;
+    private Map<String, String> fields;
+
+    public Document(JsonNode json) {
+      super();
+      this.raw = json.toPrettyString();
+      this.id = json.get("docid").asText();
+      this.contents = json.get("vector").toString();
+      // We're not going to index any other fields, so just initialize an empty map.
+      this.fields = new HashMap<>();
+    }
+
+    @Override
+    public String id() {
+      if (id == null) {
+        throw new RuntimeException("JSON document has no \"docid\" field!");
+      }
+      return id;
+    }
+
+    @Override
+    public String contents() {
+      if (contents == null) {
+        throw new RuntimeException("JSON document has no contents that could be parsed!");
+      }
+      return contents;
+    }
+
+    @Override
+    public String raw() {
+      return raw;
+    }
+
+    @Override
+    public Map<String, String> fields() {
+      return fields;
+    }
+  }
+}
diff --git a/src/main/java/io/anserini/index/IndexArgs.java b/src/main/java/io/anserini/index/IndexArgs.java
index 70957978da..4ae249c967 100644
--- a/src/main/java/io/anserini/index/IndexArgs.java
+++ b/src/main/java/io/anserini/index/IndexArgs.java
@@ -36,9 +36,12 @@ public class IndexArgs {
   // This is the name of the field in the Lucene document where the entity document is stored.
public static final String ENTITY = "entity"; + // This is the name of the field in the Lucene document where the vector document is stored. + public static final String VECTOR = "vector"; private static final int TIMEOUT = 600 * 1000; + // required arguments @Option(name = "-input", metaVar = "[path]", required = true, diff --git a/src/main/java/io/anserini/index/IndexVectorArgs.java b/src/main/java/io/anserini/index/IndexVectorArgs.java new file mode 100644 index 0000000000..96a247e617 --- /dev/null +++ b/src/main/java/io/anserini/index/IndexVectorArgs.java @@ -0,0 +1,120 @@ +/* + * Anserini: A Lucene toolkit for reproducible information retrieval research + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.anserini.index; + +import org.kohsuke.args4j.Option; +import org.kohsuke.args4j.spi.StringArrayOptionHandler; + + +public class IndexVectorArgs { + + // This is the name of the field in the Lucene document where the docid is stored. + public static final String ID = "id"; + + // This is the name of the field in the Lucene document that should be searched by default. + public static final String CONTENTS = "contents"; + + // This is the name of the field in the Lucene document where the raw document is stored. + public static final String RAW = "raw"; + + // This is the name of the field in the Lucene document where the vector document is stored. 
+ public static final String VECTOR = "vector"; + + private static final int TIMEOUT = 600 * 1000; + + + // required arguments + + @Option(name = "-input", metaVar = "[path]", required = true, + usage = "Location of input collection.") + public String input; + + @Option(name = "-threads", metaVar = "[num]", required = true, + usage = "Number of indexing threads.") + public int threads; + + @Option(name = "-collection", metaVar = "[class]", required = true, + usage = "Collection class in package 'io.anserini.collection'.") + public String collectionClass; + + @Option(name = "-generator", metaVar = "[class]", + usage = "Document generator class in package 'io.anserini.index.generator'.") + public String generatorClass = "DefaultLuceneDocumentGenerator"; + + // optional general arguments + + @Option(name = "-verbose", forbids = {"-quiet"}, + usage = "Enables verbose logging for each indexing thread; can be noisy if collection has many small file segments.") + public boolean verbose = false; + + @Option(name = "-quiet", forbids = {"-verbose"}, + usage = "Turns off all logging.") + public boolean quiet = false; + + // optional arguments + + @Option(name = "-index", metaVar = "[path]", usage = "Index path.") + public String index; + + @Option(name = "-fields", handler = StringArrayOptionHandler.class, + usage = "List of fields to index (space separated), in addition to the default 'contents' field.") + public String[] fields = new String[]{}; + + @Option(name = "-storePositions", + usage = "Boolean switch to index store term positions; needed for phrase queries.") + public boolean storePositions = false; + + @Option(name = "-storeDocvectors", + usage = "Boolean switch to store document vectors; needed for (pseudo) relevance feedback.") + public boolean storeDocvectors = false; + + @Option(name = "-storeContents", + usage = "Boolean switch to store document contents.") + public boolean storeContents = false; + + @Option(name = "-storeRaw", + usage = "Boolean switch to store raw source documents.") + public boolean storeRaw = false; + + @Option(name = "-optimize", + usage = "Boolean switch to optimize index (i.e., force merge) into a single segment; costly for large collections.") + public boolean optimize = false; + + @Option(name = "-uniqueDocid", + usage = "Removes duplicate documents with the same docid during indexing. 
This significantly slows indexing throughput " + + "but may be needed for tweet collections since the streaming API might deliver a tweet multiple times.") + public boolean uniqueDocid = false; + + @Option(name = "-memorybuffer", metaVar = "[mb]", + usage = "Memory buffer size (in MB).") + public int memorybufferSize = 2048; + + @Option(name = "-whitelist", metaVar = "[file]", + usage = "File containing list of docids, one per line; only these docids will be indexed.") + public String whitelist = null; + + + // Sharding options + + @Option(name = "-shard.count", metaVar = "[n]", + usage = "Number of shards to partition the document collection into.") + public int shardCount = -1; + + @Option(name = "-shard.current", metaVar = "[n]", + usage = "The current shard number to generate (indexed from 0).") + public int shardCurrent = -1; +} diff --git a/src/main/java/io/anserini/index/IndexVectorCollection.java b/src/main/java/io/anserini/index/IndexVectorCollection.java new file mode 100644 index 0000000000..30c8bc4862 --- /dev/null +++ b/src/main/java/io/anserini/index/IndexVectorCollection.java @@ -0,0 +1,382 @@ +/* + * Anserini: A Lucene toolkit for reproducible information retrieval research + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.anserini.index; + +import io.anserini.collection.DocumentCollection; +import io.anserini.collection.FileSegment; +import io.anserini.collection.SourceDocument; +import io.anserini.index.generator.EmptyDocumentException; +import io.anserini.index.generator.InvalidDocumentException; +import io.anserini.index.generator.LuceneDocumentGenerator; +import io.anserini.index.generator.SkippedDocumentException; +import org.apache.commons.io.FileUtils; +import org.apache.commons.lang3.time.DurationFormatUtils; +import org.apache.logging.log4j.Level; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.core.config.Configurator; +import org.apache.lucene.document.Document; +import org.apache.lucene.index.ConcurrentMergeScheduler; +import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.index.IndexWriterConfig; +import org.apache.lucene.index.Term; +import org.apache.lucene.store.Directory; +import org.apache.lucene.store.FSDirectory; +import org.kohsuke.args4j.CmdLineException; +import org.kohsuke.args4j.CmdLineParser; +import org.kohsuke.args4j.OptionHandlerFilter; +import org.kohsuke.args4j.ParserProperties; + +import java.io.File; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.concurrent.Executors; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicLong; + +public final class IndexVectorCollection { + private static final Logger LOG = LogManager.getLogger(IndexVectorCollection.class); + + // This is the default analyzer used, 
unless another stemming algorithm or language is specified. + public final class Counters { + /** + * Counter for successfully indexed documents. + */ + public AtomicLong indexed = new AtomicLong(); + + /** + * Counter for empty documents that are not indexed. Empty documents are not necessary errors; + * it could be the case, for example, that a document is comprised solely of stopwords. + */ + public AtomicLong empty = new AtomicLong(); + + /** + * Counter for unindexable documents. These are cases where {@link SourceDocument#indexable()} + * returns false. + */ + public AtomicLong unindexable = new AtomicLong(); + + /** + * Counter for skipped documents. These are cases documents are skipped as part of normal + * processing logic, e.g., using a whitelist, not indexing retweets or deleted tweets. + */ + public AtomicLong skipped = new AtomicLong(); + + /** + * Counter for unexpected errors. + */ + public AtomicLong errors = new AtomicLong(); + } + + private final class LocalIndexerThread extends Thread { + final private Path inputFile; + final private IndexWriter writer; + final private DocumentCollection collection; + private FileSegment fileSegment; + + private LocalIndexerThread(IndexWriter writer, DocumentCollection collection, Path inputFile) { + this.writer = writer; + this.collection = collection; + this.inputFile = inputFile; + setName(inputFile.getFileName().toString()); + } + + @Override + @SuppressWarnings("unchecked") + public void run() { + try { + LuceneDocumentGenerator generator = (LuceneDocumentGenerator) + generatorClass.getDeclaredConstructor(IndexVectorArgs.class).newInstance(args); + + // We keep track of two separate counts: the total count of documents in this file segment (cnt), + // and the number of documents in this current "batch" (batch). We update the global counter every + // 10k documents: this is so that we get intermediate updates, which is informative if a collection + // has only one file segment; see https://github.com/castorini/anserini/issues/683 + int cnt = 0; + int batch = 0; + + FileSegment segment = collection.createFileSegment(inputFile); + // in order to call close() and clean up resources in case of exception + this.fileSegment = segment; + + for (SourceDocument d : segment) { + if (!d.indexable()) { + counters.unindexable.incrementAndGet(); + continue; + } + + Document doc; + try { + doc = generator.createDocument(d); + } catch (EmptyDocumentException e1) { + counters.empty.incrementAndGet(); + continue; + } catch (SkippedDocumentException e2) { + counters.skipped.incrementAndGet(); + continue; + } catch (InvalidDocumentException e3) { + counters.errors.incrementAndGet(); + continue; + } + + if (whitelistDocids != null && !whitelistDocids.contains(d.id())) { + counters.skipped.incrementAndGet(); + continue; + } + + if (args.uniqueDocid) { + writer.updateDocument(new Term("id", d.id()), doc); + } else { + writer.addDocument(doc); + } + cnt++; + batch++; + + // And the counts from this batch, reset batch counter. + if (batch % 10000 == 0) { + counters.indexed.addAndGet(batch); + batch = 0; + } + } + + // Add the remaining documents. + counters.indexed.addAndGet(batch); + + int skipped = segment.getSkippedCount(); + if (skipped > 0) { + // When indexing tweets, this is normal, because there are delete messages that are skipped over. 
+ counters.skipped.addAndGet(skipped); + LOG.warn(inputFile.getParent().getFileName().toString() + File.separator + + inputFile.getFileName().toString() + ": " + skipped + " docs skipped."); + } + + if (segment.getErrorStatus()) { + counters.errors.incrementAndGet(); + LOG.error(inputFile.getParent().getFileName().toString() + File.separator + + inputFile.getFileName().toString() + ": error iterating through segment."); + } + + // Log at the debug level because this can be quite noisy if there are lots of file segments. + LOG.debug(inputFile.getParent().getFileName().toString() + File.separator + + inputFile.getFileName().toString() + ": " + cnt + " docs added."); + } catch (Exception e) { + LOG.error(Thread.currentThread().getName() + ": Unexpected Exception:", e); + } finally { + if (fileSegment != null) { + fileSegment.close(); + } + } + } + } + + private final IndexVectorArgs args; + private final Path collectionPath; + private final Set whitelistDocids; + private final Class collectionClass; + private final Class generatorClass; + private final DocumentCollection collection; + private final Counters counters; + private Path indexPath; + + @SuppressWarnings("unchecked") + public IndexVectorCollection(IndexVectorArgs args) throws Exception { + this.args = args; + + if (args.verbose) { + // If verbose logging enabled, changed default log level to DEBUG so we get per-thread logging messages. + Configurator.setRootLevel(Level.DEBUG); + LOG.info("Setting log level to " + Level.DEBUG); + } else if (args.quiet) { + // If quiet mode enabled, only report warnings and above. + Configurator.setRootLevel(Level.WARN); + } else { + // Otherwise, we get the standard set of log messages. + Configurator.setRootLevel(Level.INFO); + LOG.info("Setting log level to " + Level.INFO); + } + + LOG.info("Starting indexer..."); + LOG.info("============ Loading Parameters ============"); + LOG.info("DocumentCollection path: " + args.input); + LOG.info("CollectionClass: " + args.collectionClass); + LOG.info("Generator: " + args.generatorClass); + LOG.info("Threads: " + args.threads); + LOG.info("Store document \"contents\" field? " + args.storeContents); + LOG.info("Store document \"raw\" field? " + args.storeRaw); + LOG.info("Optimize (merge segments)? " + args.optimize); + LOG.info("Whitelist: " + args.whitelist); + LOG.info("Index path: " + args.index); + + if (args.index != null) { + this.indexPath = Paths.get(args.index); + if (!Files.exists(this.indexPath)) { + Files.createDirectories(this.indexPath); + } + } + + // Our documentation uses /path/to/foo as a convention: to make copy and paste of the commands work, we assume + // collections/ as the path location. + String pathStr = args.input; + if (pathStr.startsWith("/path/to")) { + pathStr = pathStr.replace("/path/to", "collections"); + } + collectionPath = Paths.get(pathStr); + if (!Files.exists(collectionPath) || !Files.isReadable(collectionPath) || !Files.isDirectory(collectionPath)) { + throw new RuntimeException("Document directory " + collectionPath.toString() + " does not exist or is not readable, please check the path"); + } + + this.generatorClass = Class.forName("io.anserini.index.generator." + args.generatorClass); + this.collectionClass = Class.forName("io.anserini.collection." + args.collectionClass); + + // Initialize the collection. 
+ collection = (DocumentCollection) this.collectionClass.getConstructor(Path.class).newInstance(collectionPath); + + if (args.whitelist != null) { + List lines = FileUtils.readLines(new File(args.whitelist), "utf-8"); + this.whitelistDocids = new HashSet<>(lines); + } else { + this.whitelistDocids = null; + } + + this.counters = new Counters(); + } + + public Counters run() throws IOException { + final long start = System.nanoTime(); + LOG.info("============ Indexing Collection ============"); + + int numThreads = args.threads; + IndexWriter writer = null; + + // Used for LocalIndexThread + if (indexPath != null) { + final Directory dir = FSDirectory.open(indexPath); + final IndexWriterConfig config = new IndexWriterConfig(); + config.setOpenMode(IndexWriterConfig.OpenMode.CREATE); + config.setRAMBufferSizeMB(args.memorybufferSize); + config.setUseCompoundFile(false); + config.setMergeScheduler(new ConcurrentMergeScheduler()); + writer = new IndexWriter(dir, config); + } + + final ThreadPoolExecutor executor = (ThreadPoolExecutor) Executors.newFixedThreadPool(numThreads); + LOG.info("Thread pool with " + numThreads + " threads initialized."); + + LOG.info("Initializing collection in " + collectionPath.toString()); + + List segmentPaths = collection.getSegmentPaths(); + // when we want sharding to be done + if (args.shardCount > 1) { + segmentPaths = collection.getSegmentPaths(args.shardCount, args.shardCurrent); + } + final int segmentCnt = segmentPaths.size(); + + LOG.info(String.format("%,d %s found", segmentCnt, (segmentCnt == 1 ? "file" : "files" ))); + LOG.info("Starting to index..."); + + for (int i = 0; i < segmentCnt; i++) { + executor.execute(new LocalIndexerThread(writer, collection, (Path) segmentPaths.get(i))); + } + + executor.shutdown(); + + try { + // Wait for existing tasks to terminate + while (!executor.awaitTermination(1, TimeUnit.MINUTES)) { + if (segmentCnt == 1) { + LOG.info(String.format("%,d documents indexed", counters.indexed.get())); + } else { + LOG.info(String.format("%.2f%% of files completed, %,d documents indexed", + (double) executor.getCompletedTaskCount() / segmentCnt * 100.0d, counters.indexed.get())); + } + } + } catch (InterruptedException ie) { + // (Re-)Cancel if current thread also interrupted + executor.shutdownNow(); + // Preserve interrupt status + Thread.currentThread().interrupt(); + } + + if (segmentCnt != executor.getCompletedTaskCount()) { + throw new RuntimeException("totalFiles = " + segmentCnt + + " is not equal to completedTaskCount = " + executor.getCompletedTaskCount()); + } + + long numIndexed = writer.getDocStats().maxDoc; + + // Do a final commit + try { + if (writer != null) { + writer.commit(); + if (args.optimize) { + writer.forceMerge(1); + } + } + } finally { + try { + if (writer != null) { + writer.close(); + } + } catch (IOException e) { + // It is possible that this happens... but nothing much we can do at this point, + // so just log the error and move on. + LOG.error(e); + } + } + + if (numIndexed != counters.indexed.get()) { + LOG.warn("Unexpected difference between number of indexed documents and index maxDoc."); + } + + LOG.info(String.format("Indexing Complete! 
%,d documents indexed", numIndexed)); + LOG.info("============ Final Counter Values ============"); + LOG.info(String.format("indexed: %,12d", counters.indexed.get())); + LOG.info(String.format("unindexable: %,12d", counters.unindexable.get())); + LOG.info(String.format("empty: %,12d", counters.empty.get())); + LOG.info(String.format("skipped: %,12d", counters.skipped.get())); + LOG.info(String.format("errors: %,12d", counters.errors.get())); + + final long durationMillis = TimeUnit.MILLISECONDS.convert(System.nanoTime() - start, TimeUnit.NANOSECONDS); + LOG.info(String.format("Total %,d documents indexed in %s", numIndexed, + DurationFormatUtils.formatDuration(durationMillis, "HH:mm:ss"))); + + return counters; + } + + public static void main(String[] args) throws Exception { + IndexVectorArgs indexCollectionArgs = new IndexVectorArgs(); + CmdLineParser parser = new CmdLineParser(indexCollectionArgs, ParserProperties.defaults().withUsageWidth(100)); + + try { + parser.parseArgument(args); + } catch (CmdLineException e) { + System.err.println(e.getMessage()); + parser.printUsage(System.err); + System.err.println("Example: " + IndexVectorCollection.class.getSimpleName() + + parser.printExample(OptionHandlerFilter.REQUIRED)); + return; + } + + new IndexVectorCollection(indexCollectionArgs).run(); + } +} diff --git a/src/main/java/io/anserini/index/generator/LuceneVectorDocumentGenerator.java b/src/main/java/io/anserini/index/generator/LuceneVectorDocumentGenerator.java new file mode 100644 index 0000000000..73437c9d60 --- /dev/null +++ b/src/main/java/io/anserini/index/generator/LuceneVectorDocumentGenerator.java @@ -0,0 +1,93 @@ +/* + * Anserini: A Lucene toolkit for reproducible information retrieval research + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.anserini.index.generator; + +import io.anserini.collection.SourceDocument; +import io.anserini.index.IndexVectorArgs; + +import java.util.ArrayList; +import java.util.Arrays; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.JsonMappingException; +import com.fasterxml.jackson.databind.ObjectMapper; + +import org.apache.lucene.document.Document; +import org.apache.lucene.document.Field; +import org.apache.lucene.document.KnnVectorField; +import org.apache.lucene.document.StoredField; +import org.apache.lucene.document.StringField; +import org.apache.lucene.index.VectorSimilarityFunction; + + +/** + * Converts a {@link SourceDocument} into a Lucene {@link Document}, ready to be indexed. 
+ *
+ * @param <T> type of the source document
+ */
+public class LuceneVectorDocumentGenerator<T extends SourceDocument> implements LuceneDocumentGenerator<T> {
+  protected IndexVectorArgs args;
+
+  protected LuceneVectorDocumentGenerator() {
+  }
+
+  /**
+   * Constructor with config.
+   *
+   * @param args configuration arguments
+   */
+  public LuceneVectorDocumentGenerator(IndexVectorArgs args) {
+    this.args = args;
+  }
+
+  private float[] convertJsonArray(String vectorString) throws JsonMappingException, JsonProcessingException {
+    ObjectMapper mapper = new ObjectMapper();
+    ArrayList<Float> denseVector = mapper.readValue(vectorString, new TypeReference<ArrayList<Float>>(){});
+    int length = denseVector.size();
+    float[] vector = new float[length];
+    int i = 0;
+    for (Float f : denseVector) {
+      vector[i++] = f;
+    }
+    return vector;
+  }
+
+  @Override
+  public Document createDocument(T src) throws InvalidDocumentException {
+    String id = src.id();
+    float[] contents;
+
+    try {
+      contents = convertJsonArray(src.contents());
+    } catch (Exception e) {
+      throw new InvalidDocumentException();
+    }
+
+    // Make a new, empty document.
+    final Document document = new Document();
+
+    // Store the collection docid.
+    document.add(new StringField(IndexVectorArgs.ID, id, Field.Store.YES));
+    // Store the dense vector for kNN (HNSW) search.
+    document.add(new KnnVectorField(IndexVectorArgs.VECTOR, contents, VectorSimilarityFunction.DOT_PRODUCT));
+    if (args.storeRaw) {
+      document.add(new StoredField(IndexVectorArgs.RAW, src.raw()));
+    }
+    return document;
+  }
+}
diff --git a/src/main/java/io/anserini/search/SearchVectorArgs.java b/src/main/java/io/anserini/search/SearchVectorArgs.java
new file mode 100644
index 0000000000..04e0e92a8d
--- /dev/null
+++ b/src/main/java/io/anserini/search/SearchVectorArgs.java
@@ -0,0 +1,103 @@
+/*
+ * Anserini: A Lucene toolkit for reproducible information retrieval research
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package io.anserini.search; + +import org.kohsuke.args4j.Option; +import org.kohsuke.args4j.spi.StringArrayOptionHandler; + + +public class SearchVectorArgs { + // required arguments + @Option(name = "-index", metaVar = "[path]", required = true, usage = "Path to Lucene index") + public String index; + + @Option(name = "-topics", metaVar = "[file]", handler = StringArrayOptionHandler.class, required = true, usage = "topics file") + public String[] topics; + + @Option(name = "-output", metaVar = "[file]", required = true, usage = "output file") + public String output; + + @Option(name = "-topicreader", required = true, usage = "TopicReader to use.") + public String topicReader; + + // optional arguments + @Option(name = "-querygenerator", usage = "QueryGenerator to use.") + public String queryGenerator = "BagOfWordsQueryGenerator"; + + @Option(name = "-threads", metaVar = "[int]", usage = "Number of threads to use for running different parameter configurations.") + public int threads = 1; + + @Option(name = "-parallelism", metaVar = "[int]", usage = "Number of threads to use for each individual parameter configuration.") + public int parallelism = 8; + + @Option(name = "-removeQuery", usage = "Remove docids that have the query id when writing final run output.") + public Boolean removeQuery = false; + + // Note that this option is set to false by default because duplicate documents usually indicate some underlying + // indexing issues, and we don't want to just eat errors silently. + @Option(name = "-removedups", usage = "Remove duplicate docids when writing final run output.") + public Boolean removedups = false; + + @Option(name = "-skipexists", usage = "When enabled, will skip if the run file exists") + public Boolean skipexists = false; + + @Option(name = "-hits", metaVar = "[number]", required = false, usage = "max number of hits to return") + public int hits = 1000; + + @Option(name = "-efSearch", metaVar = "[number]", required = false, usage = "efSearch parameter for HNSW search") + public int efSearch = 100; + + @Option(name = "-inmem", usage = "Boolean switch to read index in memory") + public Boolean inmem = false; + + @Option(name = "-topicfield", usage = "Which field of the query should be used, default \"title\"." + + " For TREC ad hoc topics, description or narrative can be used.") + public String topicfield = "title"; + + @Option(name = "-runtag", metaVar = "[tag]", usage = "runtag") + public String runtag = null; + + @Option(name = "-format", metaVar = "[output format]", usage = "Output format, default \"trec\", alternative \"msmarco\".") + public String format = "trec"; + + // --------------------------------------------- + // Simple built-in support for passage retrieval + // --------------------------------------------- + + // A simple approach to passage retrieval is to pre-segment documents in the corpus into passages and index those + // passages. At retrieval time, we retain only the max scoring passage from each document; this is often called MaxP, + // from Dai and Callan (SIGIR 2019) in the context of BERT, although the general approach dates back to Callan + // (SIGIR 1994), Hearst and Plaunt (SIGIR 1993), and lots of other papers from the 1990s and even earlier. + // + // One common convention is to label the passages of a docid as "docid.00000", "docid.00001", "docid.00002", ... + // We use this convention in CORD-19. Alternatively, in document expansion for the MS MARCO document corpus, we use + // '#' as the delimiter. 
+ // + // The options below control various aspects of this behavior. + + @Option(name = "-selectMaxPassage", usage = "Select and retain only the max scoring segment from each document.") + public Boolean selectMaxPassage = false; + + @Option(name = "-selectMaxPassage.delimiter", metaVar = "[regexp]", + usage = "The delimiter (as a regular regression) for splitting the segment id from the doc id.") + public String selectMaxPassage_delimiter = "\\."; + + @Option(name = "-selectMaxPassage.hits", metaVar = "[int]", + usage = "Maximum number of hits to return per topic after segment id removal. " + + "Note that this is different from '-hits', which specifies the number of hits including the segment id. ") + public int selectMaxPassage_hits = Integer.MAX_VALUE; +} diff --git a/src/main/java/io/anserini/search/SearchVectorCollection.java b/src/main/java/io/anserini/search/SearchVectorCollection.java new file mode 100644 index 0000000000..b92d42a381 --- /dev/null +++ b/src/main/java/io/anserini/search/SearchVectorCollection.java @@ -0,0 +1,335 @@ +/* + * Anserini: A Lucene toolkit for reproducible information retrieval research + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.anserini.search; + +import io.anserini.index.IndexVectorArgs; +import io.anserini.rerank.ScoredDocuments; +import io.anserini.search.query.VectorQueryGenerator; +import io.anserini.search.topicreader.TopicReader; +import org.apache.commons.lang3.time.DurationFormatUtils; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.KnnVectorQuery; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.Sort; +import org.apache.lucene.search.SortField; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TotalHits; +import org.apache.lucene.store.FSDirectory; +import org.apache.lucene.store.MMapDirectory; +import org.kohsuke.args4j.CmdLineException; +import org.kohsuke.args4j.CmdLineParser; +import org.kohsuke.args4j.OptionHandlerFilter; +import org.kohsuke.args4j.ParserProperties; + +import java.io.Closeable; +import java.io.File; +import java.io.IOException; +import java.io.PrintWriter; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.HashSet; +import java.util.Locale; +import java.util.Map; +import java.util.Set; +import java.util.SortedMap; +import java.util.TreeMap; +import java.util.concurrent.CompletionException; +import java.util.concurrent.ConcurrentSkipListMap; +import java.util.concurrent.Executors; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; + +/** + * Main entry point for search. 
+ */ +public final class SearchVectorCollection implements Closeable { + // These are the default tie-breaking rules for documents that end up with the same score with respect to a query. + // For most collections, docids are strings, and we break ties by lexicographic sort order. For tweets, docids are + // longs, and we break ties by reverse numerical sort order (i.e., most recent tweet first). This means that searching + // tweets requires a slightly different code path, which is enabled by the -searchtweets option in SearchVectorArgs. + public static final Sort BREAK_SCORE_TIES_BY_DOCID = + new Sort(SortField.FIELD_SCORE, new SortField(IndexVectorArgs.ID, SortField.Type.STRING_VAL)); + + private static final Logger LOG = LogManager.getLogger(SearchVectorCollection.class); + + private final SearchVectorArgs args; + private final IndexReader reader; + + private final class SearcherThread extends Thread { + final private IndexReader reader; + final private IndexSearcher searcher; + final private SortedMap> topics; + final private String outputPath; + final private String runTag; + + private SearcherThread(IndexReader reader, SortedMap> topics, String outputPath, String runTag) { + this.reader = reader; + this.topics = topics; + this.runTag = runTag; + this.outputPath = outputPath; + this.searcher = new IndexSearcher(this.reader); + setName(outputPath); + } + + @Override + public void run() { + try { + // A short descriptor of the ranking setup. + final String desc = String.format("ranker: kNN"); + // ThreadPool for parallelizing the execution of individual queries: + ThreadPoolExecutor executor = (ThreadPoolExecutor) Executors.newFixedThreadPool(args.parallelism); + // Data structure for holding the per-query results, with the qid as the key and the results (the lines that + // will go into the final run file) as the value. + ConcurrentSkipListMap results = new ConcurrentSkipListMap<>(); + AtomicInteger cnt = new AtomicInteger(); + + final long start = System.nanoTime(); + for (Map.Entry> entry : topics.entrySet()) { + K qid = entry.getKey(); + + // This is the per-query execution, in parallel. + executor.execute(() -> { + // This is for holding the results. + StringBuilder out = new StringBuilder(); + String queryString = entry.getValue().get(args.topicfield); + ScoredDocuments docs; + try { + docs = search(this.searcher, queryString); + } catch (IOException e) { + throw new CompletionException(e); + } + + // For removing duplicate docids. + Set docids = new HashSet<>(); + + int rank = 1; + for (int i = 0; i < docs.documents.length; i++) { + String docid = docs.documents[i].get(IndexVectorArgs.ID); + + if (args.selectMaxPassage) { + docid = docid.split(args.selectMaxPassage_delimiter)[0]; + } + + if (docids.contains(docid)) + continue; + + // Remove docids that are identical to the query id if flag is set. + if (args.removeQuery && docid.equals(qid)) + continue; + + if ("msmarco".equals(args.format)) { + // MS MARCO output format: + out.append(String.format(Locale.US, "%s\t%s\t%d\n", qid, docid, rank)); + } else { + // Standard TREC format: + // + the first column is the topic number. + // + the second column is currently unused and should always be "Q0". + // + the third column is the official document identifier of the retrieved document. + // + the fourth column is the rank the document is retrieved. + // + the fifth column shows the score (integer or floating point) that generated the ranking. 
+ // + the sixth column is called the "run tag" and should be a unique identifier for your + out.append(String.format(Locale.US, "%s Q0 %s %d %f %s\n", + qid, docid, rank, docs.scores[i], runTag)); + } + + // Note that this option is set to false by default because duplicate documents usually indicate some + // underlying indexing issues, and we don't want to just eat errors silently. + // + // However, we we're performing passage retrieval, i.e., with "selectMaxSegment", we *do* want to remove + // duplicates. + if (args.removedups || args.selectMaxPassage) { + docids.add(docid); + } + + rank++; + + if (args.selectMaxPassage && rank > args.selectMaxPassage_hits) { + break; + } + } + + results.put(qid, out.toString()); + int n = cnt.incrementAndGet(); + if (n % 100 == 0) { + LOG.info(String.format("%s: %d queries processed", desc, n)); + } + }); + } + + executor.shutdown(); + + try { + // Wait for existing tasks to terminate. + while (!executor.awaitTermination(1, TimeUnit.MINUTES)); + } catch (InterruptedException ie) { + // (Re-)Cancel if current thread also interrupted. + executor.shutdownNow(); + // Preserve interrupt status. + Thread.currentThread().interrupt(); + } + final long durationMillis = TimeUnit.MILLISECONDS.convert(System.nanoTime() - start, TimeUnit.NANOSECONDS); + + LOG.info(desc + ": " + topics.size() + " queries processed in " + + DurationFormatUtils.formatDuration(durationMillis, "HH:mm:ss") + + String.format(" = ~%.2f q/s", topics.size()/(durationMillis/1000.0))); + + // Now we write the results to a run file. + PrintWriter out = new PrintWriter(Files.newBufferedWriter(Paths.get(outputPath), StandardCharsets.UTF_8)); + + // This is the default case: just dump out the qids by their natural order. + for (K qid : results.keySet()) { + out.print(results.get(qid)); + } + out.flush(); + out.close(); + + } catch (Exception e) { + LOG.error(Thread.currentThread().getName() + ": Unexpected Exception: ", e); + } + } + } + + public SearchVectorCollection(SearchVectorArgs args) throws IOException { + this.args = args; + Path indexPath = Paths.get(args.index); + + if (!Files.exists(indexPath) || !Files.isDirectory(indexPath) || !Files.isReadable(indexPath)) { + throw new IllegalArgumentException(String.format("Index path '%s' does not exist or is not a directory.", args.index)); + } + + LOG.info("============ Initializing Searcher ============"); + LOG.info("Index: " + indexPath); + this.reader = args.inmem ? DirectoryReader.open(MMapDirectory.open(indexPath)) : + DirectoryReader.open(FSDirectory.open(indexPath)); + LOG.info("Vector Search:"); + LOG.info("Number of threads for running different parameter configurations: " + args.threads); + } + + @Override + public void close() throws IOException { + reader.close(); + } + + @SuppressWarnings("unchecked") + public void runTopics() throws IOException { + TopicReader tr; + SortedMap> topics = new TreeMap<>(); + for (String singleTopicsFile : args.topics) { + Path topicsFilePath = Paths.get(singleTopicsFile); + if (!Files.exists(topicsFilePath) || !Files.isRegularFile(topicsFilePath) || !Files.isReadable(topicsFilePath)) { + throw new IllegalArgumentException("Topics file : " + topicsFilePath + " does not exist or is not a (readable) file."); + } + try { + tr = (TopicReader) Class.forName("io.anserini.search.topicreader." 
+ args.topicReader + "TopicReader") + .getConstructor(Path.class).newInstance(topicsFilePath); + topics.putAll(tr.read()); + } catch (Exception e) { + e.printStackTrace(); + throw new IllegalArgumentException("Unable to load topic reader: " + args.topicReader); + } + } + + final String runTag = args.runtag == null ? "Anserini" : args.runtag; + LOG.info("runtag: " + runTag); + + final ThreadPoolExecutor executor = (ThreadPoolExecutor) Executors.newFixedThreadPool(args.threads); + + + LOG.info("============ Launching Search Threads ============"); + + String outputPath = args.output; + if (args.skipexists && new File(outputPath).exists()) { + LOG.info("Run already exists, skipping: " + outputPath); + } else { + executor.execute(new SearcherThread<>(reader, topics, outputPath, runTag)); + executor.shutdown(); + } + + try { + // Wait for existing tasks to terminate + while (!executor.awaitTermination(1, TimeUnit.MINUTES)) { + } + } catch (InterruptedException ie) { + // (Re-)Cancel if current thread also interrupted + executor.shutdownNow(); + // Preserve interrupt status + Thread.currentThread().interrupt(); + } + } + + public ScoredDocuments search(IndexSearcher searcher, String queryString) throws IOException { + KnnVectorQuery query; + VectorQueryGenerator generator; + try { + generator = (VectorQueryGenerator) Class.forName("io.anserini.search.query." + args.queryGenerator) + .getConstructor().newInstance(); + } catch (Exception e) { + e.printStackTrace(); + throw new IllegalArgumentException("Unable to load QueryGenerator: " + args.topicReader); + } + + // If fieldsMap isn't null, then it means that the -fields option is specified. In this case, we search across + // multiple fields with the associated boosts. + query = generator.buildQuery(IndexVectorArgs.VECTOR, queryString, args.efSearch); + + TopDocs rs = new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[]{}); + rs = searcher.search(query, args.hits); + ScoredDocuments scoredDocs; + scoredDocs = ScoredDocuments.fromTopDocs(rs, searcher); + + return scoredDocs; + } + + + public static void main(String[] args) throws Exception { + SearchVectorArgs searchArgs = new SearchVectorArgs(); + CmdLineParser parser = new CmdLineParser(searchArgs, ParserProperties.defaults().withUsageWidth(100)); + + try { + parser.parseArgument(args); + } catch (CmdLineException e) { + System.err.println(e.getMessage()); + parser.printUsage(System.err); + System.err.println("Example: SearchCollection" + parser.printExample(OptionHandlerFilter.REQUIRED)); + return; + } + + final long start = System.nanoTime(); + SearchVectorCollection searcher; + + // We're at top-level already inside a main; makes no sense to propagate exceptions further, so reformat the + // exception messages and display on console. 
+ try { + searcher = new SearchVectorCollection(searchArgs); + } catch (IllegalArgumentException e) { + System.err.println(e.getMessage()); + return; + } + + searcher.runTopics(); + searcher.close(); + final long durationMillis = TimeUnit.MILLISECONDS.convert(System.nanoTime() - start, TimeUnit.NANOSECONDS); + LOG.info("Total run time: " + DurationFormatUtils.formatDuration(durationMillis, "HH:mm:ss")); + } +} diff --git a/src/main/java/io/anserini/search/query/VectorQueryGenerator.java b/src/main/java/io/anserini/search/query/VectorQueryGenerator.java new file mode 100644 index 0000000000..71ed76d019 --- /dev/null +++ b/src/main/java/io/anserini/search/query/VectorQueryGenerator.java @@ -0,0 +1,48 @@ +/* + * Anserini: A Lucene toolkit for reproducible information retrieval research + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.anserini.search.query; + +import java.util.ArrayList; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.JsonMappingException; +import com.fasterxml.jackson.databind.ObjectMapper; + +import org.apache.lucene.search.KnnVectorQuery; + +public class VectorQueryGenerator { + + private float[] convertJsonArray(String vectorString) throws JsonMappingException, JsonProcessingException { + ObjectMapper mapper = new ObjectMapper(); + ArrayList denseVector = mapper.readValue(vectorString, new TypeReference>(){}); + int length = denseVector.size(); + float[] vector = new float[length]; + int i = 0; + for (Float f : denseVector) { + vector[i++] = f; + } + return vector; + } + + public KnnVectorQuery buildQuery(String field, String queryString, Integer topK) throws JsonMappingException, JsonProcessingException{ + float[] queryVector; + queryVector = convertJsonArray(queryString); + KnnVectorQuery knnQuery = new KnnVectorQuery(field, queryVector, topK); + return knnQuery; + } +} diff --git a/src/main/java/io/anserini/search/topicreader/JsonIntVectorTopicReader.java b/src/main/java/io/anserini/search/topicreader/JsonIntVectorTopicReader.java new file mode 100644 index 0000000000..830cfc1357 --- /dev/null +++ b/src/main/java/io/anserini/search/topicreader/JsonIntVectorTopicReader.java @@ -0,0 +1,51 @@ +/* + * Anserini: A Lucene toolkit for reproducible information retrieval research + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package io.anserini.search.topicreader;
+
+import java.io.BufferedReader;
+import java.io.IOException;
+import java.nio.file.Path;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.SortedMap;
+import java.util.TreeMap;
+
+import com.fasterxml.jackson.databind.JsonNode;
+import com.fasterxml.jackson.databind.ObjectMapper;
+
+public class JsonIntVectorTopicReader extends TopicReader<Integer> {
+
+  public JsonIntVectorTopicReader(Path topicFile) {
+    super(topicFile);
+  }
+
+  @Override
+  public SortedMap<Integer, Map<String, String>> read(BufferedReader reader) throws IOException {
+    SortedMap<Integer, Map<String, String>> map = new TreeMap<>();
+    String line;
+    ObjectMapper mapper = new ObjectMapper();
+    while ((line = reader.readLine()) != null) {
+      line = line.trim();
+      JsonNode lineNode = mapper.readerFor(JsonNode.class).readTree(line);
+      Integer topicID = lineNode.get("qid").asInt();
+      Map<String, String> fields = new HashMap<>();
+      fields.put("vector", lineNode.get("vector").toString());
+      map.put(topicID, fields);
+    }
+    return map;
+  }
+}
diff --git a/src/main/java/io/anserini/search/topicreader/JsonStringVectorTopicReader.java b/src/main/java/io/anserini/search/topicreader/JsonStringVectorTopicReader.java
new file mode 100644
index 0000000000..b5b4695503
--- /dev/null
+++ b/src/main/java/io/anserini/search/topicreader/JsonStringVectorTopicReader.java
@@ -0,0 +1,52 @@
+/*
+ * Anserini: A Lucene toolkit for reproducible information retrieval research
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+package io.anserini.search.topicreader;
+
+import java.io.BufferedReader;
+import java.io.IOException;
+import java.nio.file.Path;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.SortedMap;
+import java.util.TreeMap;
+
+import com.fasterxml.jackson.databind.JsonNode;
+import com.fasterxml.jackson.databind.ObjectMapper;
+
+public class JsonStringVectorTopicReader extends TopicReader<String> {
+
+  public JsonStringVectorTopicReader(Path topicFile) {
+    super(topicFile);
+  }
+
+  @Override
+  public SortedMap<String, Map<String, String>> read(BufferedReader reader) throws IOException {
+    SortedMap<String, Map<String, String>> map = new TreeMap<>();
+    String line;
+    ObjectMapper mapper = new ObjectMapper();
+    while ((line = reader.readLine()) != null) {
+      line = line.trim();
+      JsonNode lineNode = mapper.readerFor(JsonNode.class).readTree(line);
+      String topicID = lineNode.get("qid").asText();
+      Map<String, String> fields = new HashMap<>();
+      fields.put("vector", lineNode.get("vector").toString());
+      map.put(topicID, fields);
+    }
+    return map;
+  }
+}
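Usage note (not part of the patch): the new pipeline expects pre-encoded vectors on both sides. VectorCollection reads JSONL documents of the form {"docid": "...", "vector": [0.01, -0.02, ...]}, and the two topic readers read JSONL queries of the form {"qid": ..., "vector": [...]}. The sketch below is one way the pieces introduced in this diff could be driven programmatically; the class name VectorSearchDemo and all paths are hypothetical placeholders. Note that -topicfield must be set to "vector" (the key under which the topic readers store the query), and the query generator should be VectorQueryGenerator rather than the BagOfWordsQueryGenerator default. The same flows are available from the command line through the IndexVectorCollection and SearchVectorCollection main() entry points registered in pom.xml.

import io.anserini.index.IndexVectorArgs;
import io.anserini.index.IndexVectorCollection;
import io.anserini.search.SearchVectorArgs;
import io.anserini.search.SearchVectorCollection;

public class VectorSearchDemo {
  public static void main(String[] argv) throws Exception {
    // Index a directory of JSONL files; each line holds {"docid": ..., "vector": [...]}.
    IndexVectorArgs indexArgs = new IndexVectorArgs();
    indexArgs.input = "collections/my-dense-corpus";             // hypothetical path
    indexArgs.collectionClass = "VectorCollection";
    indexArgs.generatorClass = "LuceneVectorDocumentGenerator";
    indexArgs.index = "indexes/my-dense-index";                  // hypothetical path
    indexArgs.threads = 8;
    new IndexVectorCollection(indexArgs).run();

    // Search with pre-encoded query vectors; each topic line holds {"qid": ..., "vector": [...]}.
    SearchVectorArgs searchArgs = new SearchVectorArgs();
    searchArgs.index = "indexes/my-dense-index";
    searchArgs.topics = new String[] {"topics/my-dense-topics.jsonl"};  // hypothetical path
    searchArgs.topicReader = "JsonStringVector";    // resolves to JsonStringVectorTopicReader
    searchArgs.topicfield = "vector";               // topic readers store the query under "vector"
    searchArgs.queryGenerator = "VectorQueryGenerator";
    searchArgs.output = "runs/run.my-dense-index.txt";                  // hypothetical path
    searchArgs.hits = 1000;
    searchArgs.efSearch = 100;
    try (SearchVectorCollection searcher = new SearchVectorCollection(searchArgs)) {
      searcher.runTopics();
    }
  }
}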