
Neural Pointwise Ranking Baselines on MS MARCO Passage Retrieval - with TPU

This page contains instructions for running monoT5 on the MS MARCO passage ranking task.

We will focus on using monoT5-3B to rerank, since it is difficult to run such a large model without a TPU. We also mention the changes required to run monoT5-base for those with a more constrained compute budget.

Note that there are also separate documents for running MS MARCO ranking tasks on a regular GPU. Please see MS MARCO document ranking task, MS MARCO passage ranking task - Subset, and MS MARCO passage ranking task - Entire.

Prior to running this, we suggest looking at our first-stage BM25 ranking instructions. Using monoT5, we rerank the BM25 run files that contain ~1000 passages per query. monoT5 is a pointwise reranker: each document is scored independently using T5.
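For reference, PyGaggle also exposes monoT5 through a simple Python reranking API that runs on GPU/CPU rather than the TPU pipeline described below. A minimal sketch, assuming the MonoT5 class and its default checkpoint are available in your PyGaggle version:

from pygaggle.rerank.base import Query, Text
from pygaggle.rerank.transformer import MonoT5

# Pointwise reranking: each passage is scored independently against the query.
reranker = MonoT5()  # defaults to a monoT5-base checkpoint fine-tuned on MS MARCO
query = Query('what is the daily life of thai people')
passages = [('7744105', 'Daily life in Thailand revolves around ...'),
            ('2343221', 'Thai cuisine is known for ...')]
texts = [Text(text, {'docid': docid}, 0) for docid, text in passages]

reranked = reranker.rerank(query, texts)
reranked.sort(key=lambda x: x.score, reverse=True)
for text in reranked:
    print(text.metadata['docid'], text.score)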

Data Prep

Since we will use some scripts from PyGaggle to process data and evaluate results, we first install it from source.

git clone --recursive https://github.com/castorini/pygaggle.git
cd pygaggle
pip install .

We store all the files in the data/msmarco_passage directory.

export DATA_DIR=data/msmarco_passage
mkdir -p ${DATA_DIR}

We provide specific data prep instructions for the train and dev set.

Train Set

First, download the MS MARCO train triples:

cd ${DATA_DIR}
wget https://storage.googleapis.com/duobert_git/triples.train.small.tar.gz
tar -xvf triples.train.small.tar.gz
rm triples.train.small.tar.gz
cd ../../

Then convert the train triples file to the monoT5 input format:

python pygaggle/data/create_msmarco_monot5_train.py --triples_train ${DATA_DIR}/triples.train.small.tsv --output_to_t5 ${DATA_DIR}/query_doc_pairs.train.tsv
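For intuition, each training triple is expanded into two text-to-text examples, one positive and one negative. A sketch of the assumed tab-separated <input, target> format (the script itself is the source of truth):

# Assumed monoT5 training format: each (query, positive, negative) triple
# yields two tab-separated <input, target> lines for T5.
def triple_to_t5_examples(query, positive_doc, negative_doc):
    pos = f'Query: {query} Document: {positive_doc} Relevant:\ttrue'
    neg = f'Query: {query} Document: {negative_doc} Relevant:\tfalse'
    return pos, neg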

Next, copy the monoT5 input file to Google Storage; TPU training will read the data directly from there. Remember to first define GS_FOLDER:

export GS_FOLDER=<google storage folder to store input/output data>
gsutil cp ${DATA_DIR}/query_doc_pairs.train.tsv ${GS_FOLDER}/

This file is made available in our bucket.

Dev Set

We download the query, qrels, run and corpus files corresponding to the MS MARCO passage dev set.

The run file is generated by following Anserini's BM25 ranking instructions.

In short, the files are:

  • topics.msmarco-passage.dev-subset.txt: 6,980 queries from the MS MARCO dev set.
  • qrels.msmarco-passage.dev-subset.txt: 7,437 pairs of queries and relevant passage ids from the MS MARCO dev set.
  • run.dev.small.tsv: Approximately 6,980,000 pairs of dev set queries and passages retrieved with Anserini's BM25.
  • collection.tar.gz: All passages (8,841,823) in the MS MARCO passage corpus. In this tsv file, the first column is the passage id, and the second is the passage text.

A more detailed description of the data is available here.
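For reference, the corpus tsv can be loaded into a dictionary along these lines (a sketch; keep in mind that the ~8.8M passages take several GB of RAM):

# Load collection.tsv into {passage_id: passage_text}; the file is tab-separated
# with the passage id in the first column and the text in the second.
corpus = {}
with open('data/msmarco_passage/collection.tsv') as f:
    for line in f:
        passage_id, passage_text = line.rstrip('\n').split('\t', 1)
        corpus[passage_id] = passage_text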

Let's start.

cd ${DATA_DIR}
wget https://storage.googleapis.com/duobert_git/run.bm25.dev.small.tsv
wget https://github.com/raw/castorini/anserini/master/src/main/resources/topics-and-qrels/topics.msmarco-passage.dev-subset.txt
wget https://github.com/raw/castorini/anserini/master/src/main/resources/topics-and-qrels/qrels.msmarco-passage.dev-subset.txt
wget https://www.dropbox.com/s/m1n2wf80l1lb9j1/collection.tar.gz
tar -xvf collection.tar.gz
rm collection.tar.gz
mv run.bm25.dev.small.tsv run.dev.small.tsv
cd ../../

As a sanity check, we can evaluate the first-stage retrieved documents using the official MS MARCO evaluation script.

python tools/scripts/msmarco/msmarco_passage_eval.py ${DATA_DIR}/qrels.msmarco-passage.dev-subset.txt ${DATA_DIR}/run.dev.small.tsv

The output should be:

#####################
MRR @10: 0.18736452221767383
QueriesRanked: 6980
#####################

Then, we prepare the query-doc pairs in the monoT5 input format.

python pygaggle/data/create_msmarco_monot5_input.py --queries ${DATA_DIR}/topics.msmarco-passage.dev-subset.txt \
                                      --run ${DATA_DIR}/run.dev.small.tsv \
                                      --corpus ${DATA_DIR}/collection.tsv \
                                      --t5_input ${DATA_DIR}/query_doc_pairs.dev.small.txt \
                                      --t5_input_ids ${DATA_DIR}/query_doc_pair_ids.dev.small.tsv

We will get two output files here:

  • query_doc_pairs.dev.small.txt: The query-doc pairs for monoT5 input.
  • query_doc_pair_ids.dev.small.tsv: The query_ids and doc_ids that map to the query-doc pairs. We will use this to map query-doc pairs to their corresponding monoT5 output scores.

The files are made available in our bucket.
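The two files are line-aligned: line i of the ids file identifies the query-doc pair on line i of the input file. A hypothetical sanity check, assuming the input lines follow the 'Query: ... Document: ... Relevant:' template:

# Check that the ids file lines up with the pairs file, line by line.
with open('data/msmarco_passage/query_doc_pairs.dev.small.txt') as pairs, \
     open('data/msmarco_passage/query_doc_pair_ids.dev.small.tsv') as ids:
    for pair_line, id_line in zip(pairs, ids):
        query_id, doc_id = id_line.rstrip('\n').split('\t')
        assert pair_line.startswith('Query:')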

Note that the monoT5 input file might be too large to fit in the instance's memory. We thus split the input file into multiple files.

split --suffix-length 3 --numeric-suffixes --lines 1000000 ${DATA_DIR}/query_doc_pairs.dev.small.txt ${DATA_DIR}/query_doc_pairs.dev.small.txt

After splitting query_doc_pairs.dev.small.txt, we get 7 files (query_doc_pairs.dev.small.txt000 to query_doc_pairs.dev.small.txt006). Note that reranking might still run out of memory, in which case reduce the number of lines per split to fewer than 1,000,000.
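If in doubt, a quick sanity check (a sketch) confirms that the split files together contain exactly the lines of the original file:

import glob

# The sum of lines across split files should equal the original line count.
prefix = 'data/msmarco_passage/query_doc_pairs.dev.small.txt'
original = sum(1 for _ in open(prefix))
splits = sum(1 for path in sorted(glob.glob(prefix + '[0-9][0-9][0-9]'))
             for _ in open(path))
assert original == splits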

We copy these input files to Google Storage. TPU inference will read data directly from gs.

export GS_FOLDER=<google storage folder to store input/output data>
gsutil cp ${DATA_DIR}/query_doc_pairs.dev.small.txt??? ${GS_FOLDER}

These files can also be found in our bucket.

Start a VM with TPU on Google Cloud

Define environment variables.

export PROJECT_NAME=<gcloud project name>
export PROJECT_ID=<gcloud project id>
export INSTANCE_NAME=<name of vm to create>
export TPU_NAME=<name of tpu to create>

Create the VM.

gcloud beta compute --project=${PROJECT_NAME} instances create ${INSTANCE_NAME} --zone=europe-west4-a --machine-type=n1-standard-4 --subnet=default --network-tier=PREMIUM --maintenance-policy=MIGRATE --service-account=${PROJECT_ID}-compute@developer.gserviceaccount.com  --scopes=https://www.googleapis.com/auth/cloud-platform --image=debian-10-buster-v20201112 --image-project=debian-cloud --boot-disk-size=25GB --boot-disk-type=pd-standard --boot-disk-device-name=${INSTANCE_NAME} --reservation-affinity=any

It is possible that the image and machine type provided here are dated, so feel free to update them to whatever fits your needs. After the VM is created, we can ssh to the machine.
Make sure to initialize PROJECT_NAME and TPU_NAME from within the machine too. Then create a TPU.

curl -O https://dl.google.com/cloud_tpu/ctpu/latest/linux/ctpu && chmod a+x ctpu
./ctpu up --name=${TPU_NAME} --project=${PROJECT_NAME} --zone=europe-west4-a --tpu-size=v3-8 --tpu-only --noconf

Setup environment on VM

Install required tools including Miniconda.

sudo apt-get update
sudo apt-get install git gcc screen --yes
curl -O https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
bash ./Miniconda3-latest-Linux-x86_64.sh
source ~/.bashrc

Then create a Python virtual environment for the experiments and install dependencies.

conda init
conda create --yes --name py36 python=3.6
conda activate py36
conda install -c conda-forge httptools jsonnet --yes
pip install tensorflow tensorflow-text t5[gcp]
git clone https://github.com/castorini/mesh.git
pip install --editable mesh

Rerank with monoT5

Let's first define the model type and checkpoint.

export MODEL_NAME=<base or 3B>
export MODEL_DIR=gs://castorini/monot5/experiments/${MODEL_NAME}

Then run the following command to start the reranking process in the background and monitor the log:

for ITER in {000..006}; do
  echo "Running iter: $ITER" >> out.log_eval_exp
  nohup t5_mesh_transformer \
    --tpu="${TPU_NAME}" \
    --gcp_project="${PROJECT_NAME}" \
    --tpu_zone="europe-west4-a" \
    --model_dir="${MODEL_DIR}" \
    --gin_file="gs://t5-data/pretrained_models/${MODEL_NAME}/operative_config.gin" \
    --gin_file="infer.gin" \
    --gin_file="beam_search.gin" \
    --gin_param="utils.tpu_mesh_shape.tpu_topology = '2x2'" \
    --gin_param="infer_checkpoint_step = 1100000" \
    --gin_param="utils.run.sequence_length = {'inputs': 512, 'targets': 2}" \
    --gin_param="Bitransformer.decode.max_decode_length = 2" \
    --gin_param="input_filename = '${GS_FOLDER}/query_doc_pairs.dev.small.txt${ITER}'" \
    --gin_param="output_filename = '${GS_FOLDER}/query_doc_pair_scores.dev.small.txt${ITER}'" \
    --gin_param="utils.run.batch_size=('tokens_per_batch', 65536)" \
    --gin_param="Bitransformer.decode.beam_size = 1" \
    --gin_param="Bitransformer.decode.temperature = 0.0" \
    --gin_param="Unitransformer.sample_autoregressive.sampling_keep_top_k = -1" \
    >> out.log_eval_exp 2>&1
done &

tail -100f out.log_eval_exp

Using a TPU v3-8, reranking takes approximately 5 hours with monoT5-base and 35 hours with monoT5-3B.

Note that we strongly encourage you to run any of the long processes in screen to make sure they don't get interrupted.

Evaluate reranked results

After reranking is done, let's copy the results from GS to our working directory and concatenate all the score files back into one file.

gsutil cp ${GS_FOLDER}/query_doc_pair_scores.dev.small.txt???-1100000 ${DATA_DIR}/
cat ${DATA_DIR}/query_doc_pair_scores.dev.small.txt???-1100000 > ${DATA_DIR}/query_doc_pair_scores.dev.small.txt

Then we convert the monoT5 output to the required MS MARCO run format.

python pygaggle/data/convert_monot5_output_to_msmarco_run.py --t5_output ${DATA_DIR}/query_doc_pair_scores.dev.small.txt \
                                                --t5_output_ids ${DATA_DIR}/query_doc_pair_ids.dev.small.tsv \
                                                --mono_run ${DATA_DIR}/run.monot5_${MODEL_NAME}.dev.tsv
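Conceptually, the conversion pairs each monoT5 score with its (query_id, doc_id) line, sorts the passages for each query by score, and writes the standard MS MARCO run format. A simplified sketch, assuming the raw T5 outputs have already been parsed into one float score per line:

from collections import defaultdict

# Group (doc_id, score) by query, sort by score, and write query_id, doc_id, rank.
def write_msmarco_run(ids_path, scores, run_path):
    per_query = defaultdict(list)
    with open(ids_path) as f:
        for line, score in zip(f, scores):
            query_id, doc_id = line.rstrip('\n').split('\t')
            per_query[query_id].append((doc_id, score))
    with open(run_path, 'w') as out:
        for query_id, docs in per_query.items():
            docs.sort(key=lambda x: x[1], reverse=True)
            for rank, (doc_id, _) in enumerate(docs, start=1):
                out.write(f'{query_id}\t{doc_id}\t{rank}\n')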

Now we can evaluate the reranked results using the official MS MARCO evaluation script.

python tools/scripts/msmarco/msmarco_passage_eval.py ${DATA_DIR}/qrels.msmarco-passage.dev-subset.txt ${DATA_DIR}/run.monot5_${MODEL_NAME}.dev.tsv

In the case of monoT5-3B, the output should be:

#####################
MRR @10: 0.3983799517896949
QueriesRanked: 6980
#####################

In the case of monoT5-base, the output should be:

#####################
MRR @10: 0.38160657433938283
QueriesRanked: 6980
#####################

If you were able to replicate any of these results, please submit a PR adding to the replication log, along with the model(s) you replicated. Please mention in your PR if you note any differences.

Training a monoT5 reranker

We use the following environment variables:

export MODEL_NAME=<t5 pretrain model, e.g. base, large, 3B>
export GS_FOLDER=<gs folder to store checkpoints>
export PROJECT_NAME=<gcloud project name>
export TPU_NAME=<name of tpu to create>
export MODEL_INIT_CKPT=<initial model checkpoint, e.g. 999900>

Copy the pre-trained checkpoint to our target model folder.

echo "model_checkpoint_path: \"model.ckpt-${MODEL_INIT_CKPT}\"" > checkpoint
gsutil cp checkpoint ${GS_FOLDER}
gsutil cp gs://t5-data/pretrained_models/${MODEL_NAME}/model.ckpt-${MODEL_INIT_CKPT}* ${GS_FOLDER}

Finally, we can begin training.

nohup t5_mesh_transformer  \
  --tpu="${TPU_NAME}" \
  --gcp_project="${PROJECT_NAME}" \
  --tpu_zone="europe-west4-a" \
  --model_dir="${GS_FOLDER}" \
  --gin_param="init_checkpoint = 'gs://t5-data/pretrained_models/${MODEL_NAME}/model.ckpt-${MODEL_INIT_CKPT}'" \
  --gin_file="dataset.gin" \
  --gin_file="models/bi_v1.gin" \
  --gin_file="gs://t5-data/pretrained_models/${MODEL_NAME}/operative_config.gin" \
  --gin_param="utils.tpu_mesh_shape.model_parallelism = 1" \
  --gin_param="utils.tpu_mesh_shape.tpu_topology = '2x2'" \
  --gin_param="utils.run.train_dataset_fn = @t5.models.mesh_transformer.tsv_dataset_fn" \
  --gin_param="tsv_dataset_fn.filename = 'gs://castorini/monot5/data/query_doc_pairs.train.tsv'" \
  --gin_file="learning_rate_schedules/constant_0_001.gin" \
  --gin_param="run.train_steps = 1100000" \
  --gin_param="run.save_checkpoints_steps = 10000" \
  --gin_param="utils.run.batch_size=('tokens_per_batch', 65536)" \
  >> out.log_exp 2>&1 &

tail -100f out.log_exp

In the case of monoT5-3B, set utils.tpu_mesh_shape.model_parallelism to 8 instead of 1. Training monoT5-base, -large, and -3B takes approximately 12, 48, and 160 hours overall, respectively, on a single TPU v3-8.

Replication Log