From 69b4927e587ef88e6e290ac27507e18256f9e8d9 Mon Sep 17 00:00:00 2001 From: Mark Kurtz Date: Tue, 29 Jun 2021 13:01:53 -0400 Subject: [PATCH] Upgrade version to 0.5.1 for bug fix release (#312) * Update sparsifying_bert_using_recipes.md (#299) Fix wrong link to tutorial images * BERT pruning tutorial clean up (#300) * Disable save ckpt for BERT tutorial command (#301) * Add output for eval in tutorial (#302) * Rewrite readme for hugging face transformers integration (#303) * Rewrite readme for hugging face transformers integration * Update integrations/huggingface-transformers/README.md Co-authored-by: Jeannie Finks <74554921+jeanniefinks@users.noreply.github.com> * Update integrations/huggingface-transformers/README.md Co-authored-by: Jeannie Finks <74554921+jeanniefinks@users.noreply.github.com> * Update integrations/huggingface-transformers/README.md Co-authored-by: Jeannie Finks <74554921+jeanniefinks@users.noreply.github.com> * Update integrations/huggingface-transformers/README.md Co-authored-by: Jeannie Finks <74554921+jeanniefinks@users.noreply.github.com> * Update integrations/huggingface-transformers/README.md Co-authored-by: Jeannie Finks <74554921+jeanniefinks@users.noreply.github.com> * Update integrations/huggingface-transformers/README.md Co-authored-by: Jeannie Finks <74554921+jeanniefinks@users.noreply.github.com> * Update integrations/huggingface-transformers/README.md Co-authored-by: Jeannie Finks <74554921+jeanniefinks@users.noreply.github.com> * update from review Co-authored-by: Jeannie Finks <74554921+jeanniefinks@users.noreply.github.com> * Passage retrieval compression (#297) * adding IR elastic stuff * adding data download and modified es dense ranking * adding Doc2query * adding DPR code * updating doc2quyery code * adding msmarco eval scri[t * making dataset HF compatible * making dataset HF compatible * running doc2query t5 * model running * working on integrating * done with yaml recipe for all prunable layers * fixing config spacing for pruning yaml * work on dataset making * updaed thedownload data script and model training * running doc2query but missing the work for pruning * fixing issues in pruning * moving around DPR * added optimal lobotomizing project * adding to readme for baseline * new structures * cleaning up structure and pushing baseline numbers * moving sparse_ml_utils.py to src Co-authored-by: Mark Kurtz * Update example commands for hugging face integration (#306) * fix: correct minor typo (#307) * Phased pruning (#311) * Update example commands for hugging face integration * Phased pruning implementation * Update for quality * Upgrade version to 0.5.1 for bug fix release Co-authored-by: Tuan Nguyen Co-authored-by: Jeannie Finks <74554921+jeanniefinks@users.noreply.github.com> Co-authored-by: spacemanidol Co-authored-by: Rahul Tuli --- CONTRIBUTING.md | 2 +- .../huggingface-transformers/README.md | 475 +--- .../sparsifying_bert_using_recipes.md | 16 +- research/information_retrieval/DPR/README.md | 48 + .../information_retrieval/DPR/conf/README.md | 65 + .../DPR/conf/biencoder_train_cfg.yaml | 47 + .../DPR/conf/ctx_sources/default_sources.yaml | 6 + .../conf/datasets/encoder_train_default.yaml | 46 + .../DPR/conf/datasets/retriever_default.yaml | 33 + .../DPR/conf/dense_retriever.yaml | 71 + .../DPR/conf/encoder/hf_bert.yaml | 24 + .../DPR/conf/extractive_reader_train_cfg.yaml | 73 + .../DPR/conf/gen_embs.yaml | 52 + .../DPR/conf/train/biencoder_default.yaml | 27 + .../DPR/conf/train/biencoder_local.yaml | 27 + 
.../DPR/conf/train/biencoder_nq.yaml | 27 + .../conf/train/extractive_reader_default.yaml | 21 + .../DPR/dense_retriever.py | 460 ++++ .../DPR/download_data.py | 550 ++++ .../information_retrieval/DPR/dpr/__init__.py | 0 .../DPR/dpr/data/__init__.py | 0 .../DPR/dpr/data/biencoder_data.py | 613 +++++ .../DPR/dpr/data/download_data.py | 550 ++++ .../DPR/dpr/data/qa_validation.py | 216 ++ .../DPR/dpr/data/reader_data.py | 646 +++++ .../DPR/dpr/data/retriever_data.py | 337 +++ .../DPR/dpr/data/tables.py | 674 +++++ .../DPR/dpr/indexer/faiss_indexers.py | 254 ++ .../DPR/dpr/models/__init__.py | 91 + .../DPR/dpr/models/biencoder.py | 452 ++++ .../DPR/dpr/models/fairseq_models.py | 60 + .../DPR/dpr/models/hf_models.py | 336 +++ .../DPR/dpr/models/pytext_models.py | 130 + .../DPR/dpr/models/reader.py | 236 ++ .../information_retrieval/DPR/dpr/options.py | 110 + .../DPR/dpr/utils/__init__.py | 0 .../DPR/dpr/utils/conf_utils.py | 28 + .../DPR/dpr/utils/data_utils.py | 333 +++ .../DPR/dpr/utils/dist_utils.py | 96 + .../DPR/dpr/utils/model_utils.py | 174 ++ .../DPR/dpr/utils/tokenizers.py | 241 ++ .../DPR/generate_dense_embeddings.py | 172 ++ .../DPR/model_config.yml | 24 + .../DPR/ms_marco_eval.py | 177 ++ .../DPR/requirements.txt | 14 + .../DPR/train_config.yml | 27 + .../DPR/train_dense_encoder.py | 836 ++++++ research/information_retrieval/README.md | 18 + .../information_retrieval/doc2query/README.md | 47 + .../doc2query/indexes/init.txt | 0 .../doc2query/outputs/bm25_baseline.txt | 4 + .../doc2query/outputs/init.txt | 0 .../doc2query/recipes/90sparse.yaml | 2310 +++++++++++++++++ .../doc2query/recipes/noprune.yaml | 6 + .../doc2query/requirements.txt | 136 + .../doc2query/sparseml_utils.py | 121 + .../doc2query/src/augment_collection.py | 103 + .../src/convert_doc_collection_to_jsonl.py | 20 + .../doc2query/src/distill_doc2query.py | 64 + .../doc2query/src/make_doc2query_data.py | 66 + .../doc2query/src/msmarco_passage_eval.py | 185 ++ .../doc2query/src/run_doc2query.py | 804 ++++++ .../doc2query/src/sparseml_utils.py | 121 + .../elastic_integration/README.md | 37 + .../elastic_integration/chunker.py | 36 + .../elastic_integration/dense_document.py | 36 + .../elastic_integration/dense_ranking.py | 122 + .../elastic_integration/requirements.txt | 8 + .../elastic_integration/run_ranker.py | 40 + research/optimal_lobotomizing/README.md | 17 + research/optimal_lobotomizing/data/init.txt | 0 research/optimal_lobotomizing/scripts/init.sh | 0 research/optimal_lobotomizing/src/init.py | 0 .../pytorch/optim/modifier_pruning.py | 53 + src/sparseml/version.py | 2 +- .../pytorch/optim/test_modifier_pruning.py | 28 +- 76 files changed, 12875 insertions(+), 406 deletions(-) create mode 100644 research/information_retrieval/DPR/README.md create mode 100644 research/information_retrieval/DPR/conf/README.md create mode 100644 research/information_retrieval/DPR/conf/biencoder_train_cfg.yaml create mode 100644 research/information_retrieval/DPR/conf/ctx_sources/default_sources.yaml create mode 100644 research/information_retrieval/DPR/conf/datasets/encoder_train_default.yaml create mode 100644 research/information_retrieval/DPR/conf/datasets/retriever_default.yaml create mode 100644 research/information_retrieval/DPR/conf/dense_retriever.yaml create mode 100644 research/information_retrieval/DPR/conf/encoder/hf_bert.yaml create mode 100644 research/information_retrieval/DPR/conf/extractive_reader_train_cfg.yaml create mode 100644 research/information_retrieval/DPR/conf/gen_embs.yaml create mode 100644 
research/information_retrieval/DPR/conf/train/biencoder_default.yaml create mode 100644 research/information_retrieval/DPR/conf/train/biencoder_local.yaml create mode 100644 research/information_retrieval/DPR/conf/train/biencoder_nq.yaml create mode 100644 research/information_retrieval/DPR/conf/train/extractive_reader_default.yaml create mode 100644 research/information_retrieval/DPR/dense_retriever.py create mode 100644 research/information_retrieval/DPR/download_data.py create mode 100644 research/information_retrieval/DPR/dpr/__init__.py create mode 100644 research/information_retrieval/DPR/dpr/data/__init__.py create mode 100644 research/information_retrieval/DPR/dpr/data/biencoder_data.py create mode 100644 research/information_retrieval/DPR/dpr/data/download_data.py create mode 100644 research/information_retrieval/DPR/dpr/data/qa_validation.py create mode 100644 research/information_retrieval/DPR/dpr/data/reader_data.py create mode 100644 research/information_retrieval/DPR/dpr/data/retriever_data.py create mode 100644 research/information_retrieval/DPR/dpr/data/tables.py create mode 100644 research/information_retrieval/DPR/dpr/indexer/faiss_indexers.py create mode 100644 research/information_retrieval/DPR/dpr/models/__init__.py create mode 100644 research/information_retrieval/DPR/dpr/models/biencoder.py create mode 100644 research/information_retrieval/DPR/dpr/models/fairseq_models.py create mode 100644 research/information_retrieval/DPR/dpr/models/hf_models.py create mode 100644 research/information_retrieval/DPR/dpr/models/pytext_models.py create mode 100644 research/information_retrieval/DPR/dpr/models/reader.py create mode 100644 research/information_retrieval/DPR/dpr/options.py create mode 100644 research/information_retrieval/DPR/dpr/utils/__init__.py create mode 100644 research/information_retrieval/DPR/dpr/utils/conf_utils.py create mode 100644 research/information_retrieval/DPR/dpr/utils/data_utils.py create mode 100644 research/information_retrieval/DPR/dpr/utils/dist_utils.py create mode 100644 research/information_retrieval/DPR/dpr/utils/model_utils.py create mode 100644 research/information_retrieval/DPR/dpr/utils/tokenizers.py create mode 100644 research/information_retrieval/DPR/generate_dense_embeddings.py create mode 100644 research/information_retrieval/DPR/model_config.yml create mode 100644 research/information_retrieval/DPR/ms_marco_eval.py create mode 100644 research/information_retrieval/DPR/requirements.txt create mode 100644 research/information_retrieval/DPR/train_config.yml create mode 100644 research/information_retrieval/DPR/train_dense_encoder.py create mode 100644 research/information_retrieval/README.md create mode 100644 research/information_retrieval/doc2query/README.md create mode 100644 research/information_retrieval/doc2query/indexes/init.txt create mode 100644 research/information_retrieval/doc2query/outputs/bm25_baseline.txt create mode 100644 research/information_retrieval/doc2query/outputs/init.txt create mode 100644 research/information_retrieval/doc2query/recipes/90sparse.yaml create mode 100644 research/information_retrieval/doc2query/recipes/noprune.yaml create mode 100644 research/information_retrieval/doc2query/requirements.txt create mode 100644 research/information_retrieval/doc2query/sparseml_utils.py create mode 100644 research/information_retrieval/doc2query/src/augment_collection.py create mode 100644 research/information_retrieval/doc2query/src/convert_doc_collection_to_jsonl.py create mode 100644 
research/information_retrieval/doc2query/src/distill_doc2query.py create mode 100644 research/information_retrieval/doc2query/src/make_doc2query_data.py create mode 100644 research/information_retrieval/doc2query/src/msmarco_passage_eval.py create mode 100644 research/information_retrieval/doc2query/src/run_doc2query.py create mode 100644 research/information_retrieval/doc2query/src/sparseml_utils.py create mode 100644 research/information_retrieval/elastic_integration/README.md create mode 100644 research/information_retrieval/elastic_integration/chunker.py create mode 100644 research/information_retrieval/elastic_integration/dense_document.py create mode 100644 research/information_retrieval/elastic_integration/dense_ranking.py create mode 100644 research/information_retrieval/elastic_integration/requirements.txt create mode 100644 research/information_retrieval/elastic_integration/run_ranker.py create mode 100644 research/optimal_lobotomizing/README.md create mode 100644 research/optimal_lobotomizing/data/init.txt create mode 100644 research/optimal_lobotomizing/scripts/init.sh create mode 100644 research/optimal_lobotomizing/src/init.py diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index bc86b75180d..f8df89154a3 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -77,7 +77,7 @@ For documentation edits, include: ## Question or Problem -Sign up or log in: **Deep Sparse Community** [Discourse Forum](https://https://discuss.neuralmagic.com/) and/or [Slack](https://join.slack.com/t/discuss-neuralmagic/shared_invite/zt-q1a1cnvo-YBoICSIw3L1dmQpjBeDurQ). We are growing the community member by member and happy to see you there. Post all other questions including support or how to contribute. Don’t forget to search through existing discussions to avoid duplication! Thanks! +Sign up or log in: **Deep Sparse Community** [Discourse Forum](https://discuss.neuralmagic.com/) and/or [Slack](https://join.slack.com/t/discuss-neuralmagic/shared_invite/zt-q1a1cnvo-YBoICSIw3L1dmQpjBeDurQ). We are growing the community member by member and happy to see you there. Post all other questions including support or how to contribute. Don’t forget to search through existing discussions to avoid duplication! Thanks! ## Developing SparseML diff --git a/integrations/huggingface-transformers/README.md b/integrations/huggingface-transformers/README.md index 29eedafe73f..8936599e38f 100644 --- a/integrations/huggingface-transformers/README.md +++ b/integrations/huggingface-transformers/README.md @@ -13,423 +13,106 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. --> -# Transformers-SparseML Integration -This folder contains an example on how to use SparseML with transformers. -We focus on Question Answering and use a modified implementation from the BERT Stanford Question Answering Dataset (SQuAD) in transformers. -Using various pruning configuration files we demonstrate the effect unstructured pruning can have on SQuAD. The example code is based on the transformers SQuAD implementation focused on BERT on the SQuAD1.0 dataset. It runs in 120 min (with BERT-base) on a single Tesla V100 16GB. -## Installation and Requirements -These example scripts require SparseML, transformers, torch, datasets and associated libraries. 
To install run the following command +# SparseML Hugging Face Transformers Integration -```bash -pip install sparseml[torch] torch transformers datasets -``` +This directory combines the SparseML recipe-driven approach with the +[huggingface/transformers](https://github.com/huggingface/transformers) repository. +By integrating the robust training flows in the `transformers` repository with the SparseML code base, +we enable model sparsification techniques on popular NLP models such as [BERT](https://arxiv.org/abs/1810.04805) +creating smaller and faster deployable versions. +The techniques include, but are not limted to: +- Pruning +- Quantization +- Pruning and Quantization +- Sparse Transfer Learning -## Usage -To custom-prune a model first go to the `prune-config.yaml` file and modify the parameters to your needs. We have provided a range of pruning configurations in the `prune_config_files` folder. -`!EpochRangeModifier` controls how long the model trains for and each `!GMPruningModifier` modifies controls how each portion is pruned. You can modify `end_epoch` to control how long the pruning regime lasts and `final_sparsity` and `init_sparsity` define the speed at which the module is pruned and the final sparsity. -### Training -```bash -python run_qa.py \ - --model_name_or_path bert-base-uncased \ - --dataset_name squad \ - --do_train \ - --per_device_train_batch_size 16 \ - --learning_rate 3e-5 \ - --max_seq_length 384 \ - --doc_stride 128 \ - --output_dir bert-base-uncased-90-1shot/ \ - --overwrite_output_dir \ - --cache_dir cache \ - --preprocessing_num_workers 4 \ - --seed 42 \ - --num_train_epochs 2 \ - --nm_prune_config recipes/90sparsity1shot.yaml - --fp16 -``` +## Highlights -#### Evaluation -```bash -python run_qa.py \ - --model_name_or_path bert-base-uncased-99sparsity-10total8gmp/ \ - --dataset_name squad \ - --do_eval \ - --per_device_eval_batch_size 16 \ - --output_dir bert-base-uncased-99sparsity-10total8gmp/ \ - --overwrite_output_dir \ - --cache_dir cache \ - --preprocessing_num_workers 4 \ -``` -#### ONNX Export -```bash -python run_qa.py \ - --model_name_or_path bert-base-uncased-99sparsity-10total8gmp/ - --do_eval \ - --dataset_name squad \ - --do_onnx_export \ - --onnx_export_path bert-base-uncased-99sparsity-10total8gmp/ \ - --cache_dir cache \ - --preprocessing_num_workers 4 \ -``` - -## SQUAD Performance -To demonstrate the effect that various pruning regimes and techniques can have, we prune the same bert-base-uncased model to five different sparsities (0,80,90,95,99) using three pruning methodologies: -- one shot (prune to desired weights before fine tune then fine tune for 1 epoch), -- GMP 1 epoch (prune to desired sparsity over an epoch then stabilize over another epoch), and -- GMP 8 epochs (prune to desired sparsity over 8 epochs then stabilize over another 2 epochs). - -It is worth noting that we are pruning all layers uniformly and we believe further gains can be achieved by targeted pruning of individual layers. +Coming soon! 
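To make the recipe-driven flow described above more concrete, here is a minimal sketch of how a SparseML recipe typically attaches to an ordinary PyTorch training loop through the `ScheduledModifierManager` / `ScheduledOptimizer` API from `sparseml.pytorch.optim` (the same API shown in the older README content removed further down in this diff). The recipe path and the small stand-in model are placeholders, not files shipped with this integration:

```python
# Minimal sketch: wrap an existing optimizer with a SparseML recipe so its
# modifiers (pruning, quantization, learning-rate schedules, ...) run on every step.
import torch
from torch.nn import Linear, MSELoss

from sparseml.pytorch.optim.manager import ScheduledModifierManager
from sparseml.pytorch.optim.optimizer import ScheduledOptimizer

model = Linear(128, 2)  # stand-in for a real model such as BERT
data = [(torch.randn(8, 128), torch.randn(8, 2)) for _ in range(100)]
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

# "recipes/example_recipe.yaml" is a placeholder path to a recipe file.
manager = ScheduledModifierManager.from_yaml("recipes/example_recipe.yaml")
optimizer = ScheduledOptimizer(
    optimizer, model, manager, steps_per_epoch=len(data), loggers=None
)

loss_fn = MSELoss()
for epoch in range(2):  # train for however many epochs the recipe prescribes
    for inputs, targets in data:
        optimizer.zero_grad()
        loss = loss_fn(model(inputs), targets)
        loss.backward()
        optimizer.step()  # the wrapped optimizer also advances the recipe schedule
```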
-train/exact_match 76.04541 -wandb: train/f1 84.5742 +## Tutorials +- [Sparsifying BERT Models Using Recipes](https://github.com/neuralmagic/sparseml/blob/main/integrations/huggingface-transformers/tutorials/sparsifying_bert_using_recipes.md) -| base model name | sparsity | total train epochs | prunned | one shot |pruning epochs| F1 Score | EM Score | -|-----------------------|---------- |-----------------------|---------|----------|--------------|---------- |-----------| -| bert-base-uncased |0 |1 |no |no |0 |84.574 |76.045 | -| bert-base-uncased |0 |2 |no |no |0 |88.002 |80.634 | -| bert-base-uncased |0 |10 |no |no |0 |87.603 |79.130 | -| bert-base-uncased |80 |1 |yes |yes |0 |25.141 |15.998 | -| bert-base-uncased |80 |2 |yes |no |0 |66.964 |53.879 | -| bert-base-uncased |80 |10 |yes |no |8 |83.951 |74.409 | -| bert-base-uncased |90 |1 |yes |yes |0 |16.064 |07.786 | -| bert-base-uncased |90 |2 |yes |no |0 |64.185 |50.946 | -| bert-base-uncased |90 |10 |yes |no |8 |79.091 |68.184 | -| bert-base-uncased |95 |1 |yes |yes |0 |10.501 |04.929 | -| bert-base-uncased |95 |2 |yes |no |0 |24.445 |14.437 | -| bert-base-uncased |95 |10 |yes |no |8 |72.761 |60.407 | -| bert-base-uncased |97 |10 |yes |no |6 |70.260 |57.021 | -| bert-base-uncased |99 |1 |yes |yes |0 |09.685 |03.614 | -| bert-base-uncased |99 |2 |yes |no |0 |17.433 |07.871 | -| bert-base-uncased |99 |10 |yes |no |8 |47.306 |32.564 | +## Installation -## Training with Distillation -In addition to a simple QA model we provide an implementation which can leverage teacher-student distillation. The usage of the distillation code is virually identical to the non-distilled model but the commands are as follows: - -#### Training +To begin, run the following command in the root directory of this integration (`cd integrations/huggingface-transformers`): ```bash -python run_distill_qa.py \ - --teacher_model_name_or_path spacemanidol/neuralmagic-bert-squad-12layer-0sparse\ - --student_model_name_or_path bert-base-uncased \ - --dataset_name squad \ - --do_train \ - --per_device_train_batch_size 16 \ - --learning_rate 3e-5 \ - --max_seq_length 384 \ - --doc_stride 128 \ - --output_dir distill_2epoch/ \ - --overwrite_output_dir \ - --cache_dir cache \ - --preprocessing_num_workers 4 \ - --seed 42 \ - --num_train_epochs 2 \ - --nm_prune_config recipes/noprune2epoch.yaml - --fp16 +bash setup_integration.sh ``` -#### Evaluation -```bash -python run_qa.py \ - --model_name_or_path bert-base-uncased-99sparsity-10total8gmp/ \ - --dataset_name squad \ - --do_eval \ - --per_device_eval_batch_size 16 \ - --output_dir bert-base-uncased-99sparsity-10total8gmp/ \ - --overwrite_output_dir \ - --cache_dir cache \ - --preprocessing_num_workers 4 \ -``` -#### ONNX Export -```bash -python run_qa.py \ - --model_name_or_path bert-base-uncased-99sparsity-10total8gmp/ - --do_eval \ - --dataset_name squad \ - --do_onnx_export \ - --onnx_export_path bert-base-uncased-99sparsity-10total8gmp/ \ - --cache_dir cache \ - --preprocessing_num_workers 4 \ -``` -### Distillation Results -Sparsity 80, 90, 97 -| base model name | sparsity |Distilled| prunned |train epochs|pruning epochs| F1 Score | EM Score | -|-----------------------|---------- |---------|---------|------------|--------------|----------|----------| -| bert-base-uncased |0 |no |no |2 |0 |88.32442 |81.10690 | -| bert-base-uncased |80 |no |no |30 |18 |84.06276 |74.63576 | -| bert-base-uncased |90 |no |no |30 |18 |79.64549 |68.50520 | -| bert-base-uncased |97 |no |no |30 |18 |70.42570 |57.29423 | -| bert-base-uncased 
|0 |yes |no |2 |0 |89.02277 |82.03406 | -| bert-base-uncased |80 |yes |yes |30 |18 |88.03192 |80.81362 | -| bert-base-uncased |90 |yes |yes |30 |18 |85.63751 |77.41721 | -| bert-base-uncased |97 |yes |yes |30 |18 |75.01276 |63.94513 | +The `setup_integration.sh` file will clone the transformers repository with the SparseML integration as a subfolder. +After the repo has successfully cloned, transformers and datasets will be installed along with any necessary dependencies. -### Distillation, Pruning, Layer Dropping -To explore the effect of model pruning compared to layer dropping, we train models to sparsity to match the amount of parameters in models with layers dropped. Results feature both with and without distillation. For distillation we use hard distillation and a a trained teacher model which is trained on SQuAD for 2 epochs and achieves an 88.32442/81.10690 F1/EM. A 9-layer model is roughly equivalent to 20% sparsity, 6-layer to 40%, 3-layer to 60%, 1-layer to 72%. +It is recommended to run Python 3.8 as some of the scripts within the transformers repository require it. -| base model name | sparsity | params |Distilled| prunned | layers |pruning epochs| F1 Score | EM Score | -|-----------------------|---------- |-----------------------|---------|---------|----------|--------------|----------|-----------| -| bert-base-uncased |0 |108,893,186 |no |no |12 |0 |88.32442 |81.10690 | -| bert-base-uncased |0 |87,629,570 |no |no |9 |0 |86.70732 |78.81740 | -| bert-base-uncased |0 |66,365,954 |no |no |6 |0 |81.63629 |72.66793 | -| bert-base-uncased |0 |45,102,338 |no |no |3 |0 |51.75267 |39.11069 | -| bert-base-uncased |0 |30,926,594 |no |no |1 |0 |26.22600 |17.32261 | -| bert-base-uncased |20 |108,893,186 |no |yes |12 |18 |87.19622 |79.16746 | -| bert-base-uncased |40 |108,893,186 |no |yes |12 |18 |86.27294 |78.07947 | -| bert-base-uncased |60 |108,893,186 |no |yes |12 |18 |86.44120 |77.94702 | -| bert-base-uncased |72 |108,893,186 |no |yes |12 |18 |85.49873 |76.43330 | -| bert-base-uncased |80 |66,365,954 |no |yes |6 |18 |77.86777 |67.07663 | -| bert-base-uncased |90 |66,365,954 |no |yes |6 |18 |73.51963 |61.22044 | -| bert-base-uncased |97 |66,365,954 |no |yes |6 |18 |67.27468 |53.85998 | -| bert-base-uncased |0 |108,893,186 |yes |no |12 |0 |89.02277 |82.03406 | -| bert-base-uncased |0 |87,629,570 |yes |no |9 |0 |87.94176 |80.46358 | -| bert-base-uncased |0 |66,365,954 |yes |no |6 |0 |83.45530 |75.03311 | -| bert-base-uncased |0 |45,102,338 |yes |no |3 |0 |43.82823 |33.05581 | -| bert-base-uncased |0 |30,926,594 |yes |no |1 |0 |28.10105 |18.50520 | -| bert-base-uncased |20 |108,893,186 |yes |yes |12 |18 |89.55543 |82.74361 | -| bert-base-uncased |40 |108,893,186 |yes |yes |12 |18 |89.76856 |83.05581 | -| bert-base-uncased |60 |108,893,186 |yes |yes |12 |18 |89.38194 |82.28950 | -| bert-base-uncased |72 |108,893,186 |yes |yes |12 |18 |89.10581 |83.03690 | -| bert-base-uncased |80 |66,365,954 |yes |yes |6 |18 |84.69427 |76.56575 | -| bert-base-uncased |90 |66,365,954 |yes |yes |6 |18 |80.53862 |71.00284 | -| bert-base-uncased |97 |66,365,954 |yes |yes |6 |18 |72.36219 |60.82308 | +## Quick Tour -## QQP, MNLI, GLUE Tasks -Similar to our modifications to SQUAD, we can prune models for GLUE tasks with minimal changes. Building on the [run_glue.py](https://github.com/huggingface/transformers/blob/master/examples/pytorch/text-classification/run_glue.py) transformers implementation we update the scripts to prune with sparseml and add a distillation trainer. 
-To replicate our experiments you can use the following commonds: +Recipes encode the instructions and hyperparameters for sparsifying a model using modifiers to the training process. +The modifiers can range from pruning and quantization to learning rate and weight decay. +When appropriately combined, it becomes possible to create highly sparse and accurate models. -Training without distillation for MNLI where model is pruned to 80% sparsity. To add distillation just include the command ```sh --teacher_model_name_or_path ``` -```sh -python run_glue.py \ - --student_model_name_or_path bert-base-cased \ - --task_name MNLI \ +This integration adds a `--recipe` argument to the [`run_qa.py` script](https://github.com/neuralmagic/transformers/blob/master/examples/pytorch/question-answering/run_qa.py) among others. +The argument loads an appropriate recipe while preserving the rest of the training pipeline. +Popular recipes used with this argument are found in the [`recipes` folder](./recipes). +Distillation arguments to support student-teacher distillation are additionally added to the scripts as they help improve the recovery while sparsifying. +Otherwise, all other arguments and functionality remain the same as the original repository. + +For example, pruning and quantizing a model on the SQuAD dataset can be done by running the following command from within the root of this integration's folder: +```bash +python transformers/examples/pytorch/question-answering/run_qa.py \ + --model_name_or_path bert-base-uncased \ + --dataset_name squad \ --do_train \ - --max_seq_length 128 \ - --per_device_train_batch_size 32 \ - --learning_rate 2e-5 \ - --nm_prune_config recipes/80sparselong.yaml - --output_dir /80sparseMNLI + --do_eval \ + --evaluation_strategy epoch \ + --per_device_train_batch_size 16 \ + --learning_rate 5e-5 \ + --max_seq_length 384 \ + --doc_stride 128 \ + --output_dir MODELS_DIR/bert-base-12layers_prune80 \ + --cache_dir cache \ + --preprocessing_num_workers 6 \ + --fp16 \ + --num_train_epochs 30 \ + --recipe recipes/bert-base-12layers_prune80.md \ + --onnx_export_path MODELS_DIR/bert-base-12layers_prune80/onnx \ + --save_strategy epoch \ + --save_total_limit 2 ``` -eval/accuracy 0.90762 -wandb: eval/f1 0.8745 - -### Results -We see similair results for QQP and MNLI as in SQUAD but the effect of model distillation is more muted. -| base model name | sparsity |Distilled| prunned |train epochs|pruning epochs| QQP Accuracy | QQP F1 | MNLI Accuracy | -|-----------------------|---------- |---------|---------|------------|--------------|--------------|----------|---------------| -| bert-base-uncased |0 |no |no |3 |0 |91.47 |88.49 |84.42 | -| bert-base-uncased |80 |no |no |30 |18 |90.76 |87.45 |81.34 | -| bert-base-uncased |90 |no |no |30 |18 |89.38 |85.30 |78.76 | -| bert-base-uncased |97 |no |no |30 |18 |87.57 |83.34 |73.42 | -| bert-base-uncased |0 |yes |no |3 |0 |63.18 |00.00 |32.95 | -| bert-base-uncased |80 |yes |yes |30 |18 |72.12 |43.26 |32.95 | -| bert-base-uncased |90 |yes |yes |30 |18 |69.70 |34.76 |32.95 | -| bert-base-uncased |97 |yes |yes |30 |18 |63.83 |33.99 |32.95 | - -## How to Integrate SparseML with Other Transformers Projects -For any other projects using Hugging Face's transformers there are essentially four components to modify: imports and needed function, loading SparseML, modifying training script, and ONNX export. 
- -First, take your existing project and add the following imports and functions: -```python -from transformers.optimization import ( - Adafactor, - AdamW, - get_constant_schedule, - get_constant_schedule_with_warmup, - get_cosine_schedule_with_warmup, - get_cosine_with_hard_restarts_schedule_with_warmup, - get_linear_schedule_with_warmup, - get_polynomial_decay_schedule_with_warmup, -) +### Structure -from sparseml.pytorch.optim.manager import ScheduledModifierManager -from sparseml.pytorch.optim.optimizer import ScheduledOptimizer -from sparseml.pytorch.utils import ModuleExporter +The following table lays out the root-level files and folders along with a description for each. -def load_optimizer(model, args): - no_decay = ["bias", "LayerNorm.weight"] - optimizer_grouped_parameters = [ - { - "params": [ - p - for n, p in model.named_parameters() - if not any(nd in n for nd in no_decay) - ], - "weight_decay": args.weight_decay, - }, - { - "params": [ - p - for n, p in model.named_parameters() - if any(nd in n for nd in no_decay) - ], - "weight_decay": 0.0, - }, - ] - optimizer_cls = AdamW - optimizer_kwargs = { - "betas": (args.adam_beta1, args.adam_beta2), - "eps": args.adam_epsilon, - } - optimizer_kwargs["lr"] = args.learning_rate - return optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs) +| Folder/File Name | Description | +|----------------------|-----------------------------------------------------------------------------------------------------------------------| +| recipes | Typical recipes for sparsifying NLP models along with any downloaded recipes from the SparseZoo. | +| tutorials | Tutorial walkthroughs for how to sparsify NLP models using recipes. | +| transformers | Integration repository folder used to train and sparsify NLP models (`setup_integration.sh` must run first). | +| README.md | Readme file. | +| setup_integration.sh | Setup file for the integration run from the command line. 
| -def convert_example_to_features(example, tokenizer, max_seq_length, doc_stride, max_query_length): - Feature = collections.namedtuple( - "Feature", - [ - "unique_id", - "tokens", - "example_index", - "token_to_orig_map", - "token_is_max_context", - ], - ) - extra = [] - unique_id = 0 - query_tokens = tokenizer.tokenize(example["question"])[0:max_query_length] - tok_to_orig_index = [] - orig_to_tok_index = [] - all_doc_tokens = [] - for (i, token) in enumerate(example["context"]): - orig_to_tok_index.append(len(all_doc_tokens)) - sub_tokens = tokenizer.tokenize(token) - for sub_token in sub_tokens: - tok_to_orig_index.append(i) - all_doc_tokens.append(sub_token) - max_tokens_for_doc = max_seq_length - len(query_tokens) - 3 - _DocSpan = collections.namedtuple("DocSpan", ["start", "length"]) - doc_spans = [] - start_offset = 0 - while start_offset < len(all_doc_tokens): - length = len(all_doc_tokens) - start_offset - if length > max_tokens_for_doc: - length = max_tokens_for_doc - doc_spans.append(_DocSpan(start=start_offset, length=length)) - if start_offset + length == len(all_doc_tokens): - break - start_offset += min(length, doc_stride) - for (doc_span_index, doc_span) in enumerate(doc_spans): - tokens = [] - token_to_orig_map = {} - token_is_max_context = {} - segment_ids = [] - tokens.append("[CLS]") - segment_ids.append(0) - for token in query_tokens: - tokens.append(token) - segment_ids.append(0) - tokens.append("[SEP]") - segment_ids.append(0) - for i in range(doc_span.length): - split_token_index = doc_span.start + i - token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index] - is_max_context = _check_is_max_context( - doc_spans, doc_span_index, split_token_index - ) - token_is_max_context[len(tokens)] = is_max_context - tokens.append(all_doc_tokens[split_token_index]) - segment_ids.append(1) - tokens.append("[SEP]") - segment_ids.append(1) - input_ids = tokenizer.convert_tokens_to_ids(tokens) - input_mask = [1] * len(input_ids) - while len(input_ids) < max_seq_length: - input_ids.append(0) - input_mask.append(0) - segment_ids.append(0) - feature = Feature( - unique_id=unique_id, - tokens=tokens, - example_index=0, - token_to_orig_map=token_to_orig_map, - token_is_max_context=token_is_max_context, - ) - extra.append(feature) - unique_id += 1 - # extra is used as additional data but sparseml doesn't support it - return ( - torch.from_numpy(np.array([np.array(input_ids, dtype=np.int64)])), - torch.from_numpy(np.array([np.array(input_mask, dtype=np.int64)])), - torch.from_numpy(np.array([np.array(segment_ids, dtype=np.int64)])), - ) +### Exporting for Inference +After sparsifying a model, the `run_qa.py` script can be run with the `--onnx_export_path` argument to convert the model into an [ONNX](https://onnx.ai/) deployment format. +The export process is modified such that the quantized and pruned models are corrected and folded properly. 
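As a rough sanity check of an exported model, the resulting `model.onnx` can be loaded with any ONNX-compatible runtime; the sketch below uses `onnxruntime` with dummy inputs. The model path assumes the `MODELS_DIR/bert-base-12layers_prune80/onnx` location used elsewhere in this README, input names are read from the session rather than hard-coded, and the int64 dtype is an assumption that matches typical BERT token inputs:

```python
# Quick load-and-run check of an exported ONNX model with dummy inputs.
import numpy as np
import onnxruntime as ort

# Path mirrors the --onnx_export_path used elsewhere in this README.
session = ort.InferenceSession("MODELS_DIR/bert-base-12layers_prune80/onnx/model.onnx")

seq_len = 384  # the --max_seq_length used when exporting
# Build dummy int64 token tensors for every graph input (e.g. input_ids,
# attention_mask, token_type_ids); real inputs would come from the tokenizer.
dummy = {
    graph_input.name: np.zeros((1, seq_len), dtype=np.int64)
    for graph_input in session.get_inputs()
}
outputs = session.run(None, dummy)
print([output.shape for output in outputs])  # start/end logits for QA models
```

The DeepSparse Engine mentioned below consumes the same ONNX file for accelerated CPU inference on the sparsified model.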
-def _check_is_max_context(doc_spans, cur_span_index, position): - best_score = None - best_span_index = None - for (span_index, doc_span) in enumerate(doc_spans): - end = doc_span.start + doc_span.length - 1 - if position < doc_span.start: - continue - if position > end: - continue - num_left_context = position - doc_span.start - num_right_context = end - position - score = min(num_left_context, num_right_context) + 0.01 * doc_span.length - if best_score is None or score > best_score: - best_score = score - best_span_index = span_index - return cur_span_index == best_span_index -``` -We add some SparseML arguments: -```python -@dataclass -class DataTrainingArguments: - """ - Arguments pertaining to what data we are going to input our model for training and eval. - """ - #################################################################################### - # Start SparseML Integration - #################################################################################### - nm_prune_config: Optional[str] = field( - default='recipes/noprune1epoch.yaml', metadata={"help": "The input file name for the Neural Magic pruning config"} - ) - do_onnx_export: bool = field( - default=False, metadata={"help": "Export model to onnx"} - ) - onnx_export_path: Optional[str] = field( - default='onnx-export', metadata={"help": "The filename and path which will be where onnx model is outputed"} - ) - #################################################################################### - # End SparseML Integration - #################################################################################### -``` -Use the code below to load SparseML optimizers: -```python -## Neural Magic Integration here. -optim = load_optimizer(model, training_args) -steps_per_epoch = math.ceil(len(train_dataset) / (training_args.per_device_train_batch_size*training_args._n_gpu)) -manager = ScheduledModifierManager.from_yaml(data_args.nm_prune_config) -optim = ScheduledOptimizer(optim, model, manager, steps_per_epoch=steps_per_epoch, loggers=None) -``` -Modify the Hugging Face trainer to take the SparseML optimzier as shown below: -```python -# Initialize our Trainer and continue to use your regular transformers trainer -trainer = QuestionAnsweringTrainer( - model=model, - args=training_args, - train_dataset=train_dataset if training_args.do_train else None, - eval_dataset=validation_dataset if training_args.do_eval else None, - eval_examples=datasets["validation"] if training_args.do_eval else None, - tokenizer=tokenizer, - data_collator=data_collator, - post_process_function=post_processing_function, - compute_metrics=compute_metrics, - optimizers=(optim, None), # This is what is new. -) -``` -Finally, export the model. It is worth noting that you will have to create a sample batch which will be task-dependent. 
The code shown below is specific for SQuAD-style Question Answering: -```python -exporter = ModuleExporter( - model, output_dir='onnx-export' -) -sample_batch = convert_example_to_features( - datasets["validation"][0], - tokenizer, - data_args.max_seq_length, - data_args.doc_stride, - data_args.max_query_length, -) -exporter.export_onnx(sample_batch=sample_batch) +For example, the following command can be run from within the integration's folder to export a trained/sparsified model's checkpoint: +```bash +python transformers/examples/pytorch/question-answering/run_qa.py \ + --model_name_or_path MODELS_DIR/bert-base-12layers_prune80 \ + --dataset_name squad \ + --do_eval \ + --per_device_eval_batch_size 64 \ + --max_seq_length 384 \ + --doc_stride 128 \ + --output_dir MODELS_DIR/bert-base-12layers_prune80/eval \ + --cache_dir cache \ + --preprocessing_num_workers 6 \ + --onnx_export_path MODELS_DIR/bert-base-12layers_prune80/onnx ``` + +The DeepSparse Engine [accepts ONNX formats](https://docs.neuralmagic.com/sparseml/source/onnx_export.html) and is engineered to significantly speed up inference on CPUs for the sparsified models from this integration. +Examples for loading, benchmarking, and deploying can be found in the [DeepSparse repository here](https://github.com/neuralmagic/deepsparse). diff --git a/integrations/huggingface-transformers/tutorials/sparsifying_bert_using_recipes.md b/integrations/huggingface-transformers/tutorials/sparsifying_bert_using_recipes.md index 1f07f4ab322..75fca317925 100644 --- a/integrations/huggingface-transformers/tutorials/sparsifying_bert_using_recipes.md +++ b/integrations/huggingface-transformers/tutorials/sparsifying_bert_using_recipes.md @@ -32,8 +32,8 @@ Working through this tutorial, you will experience how Neural Magic recipes simp All the results listed in this tutorials are available publically through a [Weights and Biases project](https://wandb.ai/neuralmagic/sparse-bert-squad?workspace=user-neuralmagic).


## Need Help? @@ -60,7 +60,7 @@ python transformers/examples/pytorch/question-answering/run_qa.py \ --fp16 \ --num_train_epochs 2 \ --warmup_steps 5400 \ - --report_to wandb + --save_strategy epoch ``` If the command runs successfully, you should have a model folder called `bert-base-12layers` in the provided model directory `MODELS_DIR`. @@ -76,7 +76,7 @@ Using the teacher model `bert-base-12layers` above, you can now train and prune Additionally, you will use the argument `--onnx_export_path` to specify the destination folder for the exported ONNX model. The resulting exported model could then be used for inference with the `DeepSparse Engine`. -The following command prunes the model in 30 epochs to 80% sparsity of the encoder layers: +The following command prunes the model in 30 epochs to 80% sparsity of the encoder layers, saving two checkpoints during training: ```bash python transformers/examples/pytorch/question-answering/run_qa.py \ @@ -97,9 +97,10 @@ python transformers/examples/pytorch/question-answering/run_qa.py \ --preprocessing_num_workers 6 \ --fp16 \ --num_train_epochs 30 \ - --recipe ../recipes/bert-base-12layers_prune80.md \ + --recipe recipes/bert-base-12layers_prune80.md \ --onnx_export_path MODELS_DIR/bert-base-12layers_prune80/onnx \ - --report_to wandb + --save_strategy epoch \ + --save_total_limit 2 ``` The directory `recipes` contains information about recipes and training commands used to produce our BERT pruned models on the SQuAD dataset. @@ -141,9 +142,10 @@ python transformers/examples/pytorch/question-answering/run_qa.py \ --per_device_eval_batch_size 64 \ --max_seq_length 384 \ --doc_stride 128 \ + --output_dir MODELS_DIR/bert-base-12layers_prune80/eval \ --cache_dir cache \ --preprocessing_num_workers 6 \ - --onnx_export_path MODELS_DIR/bert-base-12layers_prune80/onnx \ + --onnx_export_path MODELS_DIR/bert-base-12layers_prune80/onnx ``` If it runs successfully, you will have the converted `model.onnx` in `MODELS_DIR/bert-base-12layers_prune80/onnx`. You can now run it in ONNX-compatible inference engines such as [DeepSparse](https://github.com/neuralmagic/deepsparse). The `DeepSparse Engine` is explicitly coded to support running sparsified models for significant improvements in inference performance. diff --git a/research/information_retrieval/DPR/README.md b/research/information_retrieval/DPR/README.md new file mode 100644 index 00000000000..a4988f8d716 --- /dev/null +++ b/research/information_retrieval/DPR/README.md @@ -0,0 +1,48 @@ +# Compressing DPR +Author: @spacemanidol + +Methods +1. Varying models +2. Sturctured Pruning +3. Unstructured Pruning +4. Dimensionality Reduction +## Usage +batch_size: 4 +dev_batch_size: 16 +adam_eps: 1e-8 +adam_betas: (0.9, 0.999) +max_grad_norm: 2.0 +log_batch_step: 1 +train_rolling_loss_step: 100 +weight_decay: 0.0 +learning_rate: 2e-5 +# Linear warmup over warmup_steps. +warmup_steps: 1237 + +# Number of updates steps to accumulate before performing a backward/update pass. +gradient_accumulation_steps: 1 + +# Total number of training epochs to perform. 
+num_train_epochs: 40 +eval_per_epoch: 1 +hard_negatives: 1 +other_negatives: 0 +val_av_rank_hard_neg: 30 +val_av_rank_other_neg: 30 +val_av_rank_bsz: 128 +val_av_rank_max_qs: 10000 + +https://www.dropbox.com/s/lvvpsx0cjk4vemv/collection.tar.gz?dl=1 +https://www.dropbox.com/s/hq6xjhswiz60siu/queries.dev.small.tsv?dl=1 +https://www.dropbox.com/s/khsplt2fhqwjs0v/qrels.dev.small.tsv?dl=1 +https://www.dropbox.com/s/uzkvv4gpj3a596a/predicted_queries_topk_sampling.zip?dl=1 +https://www.dropbox.com/s/nc1drdkjpxxsngg/run.dev.small.tsv?dl=1 +## Results + +| Top-k passages | Original DPR NQ model | New DPR model | +| ------------- |:-------------:| -----:| +| 1 | 45.87 | 52.47 | +| 5 | 68.14 | 72.24 | +| 20 | 79.97 | 81.33 | +| 100 | 85.87 | 87.29 | +### requirements.txt diff --git a/research/information_retrieval/DPR/conf/README.md b/research/information_retrieval/DPR/conf/README.md new file mode 100644 index 00000000000..ad0044c5d85 --- /dev/null +++ b/research/information_retrieval/DPR/conf/README.md @@ -0,0 +1,65 @@ +## Hydra + +[Hydra](https://github.com/facebookresearch/hydra) is an open-source Python +framework that simplifies the development of research and other complex +applications. The key feature is the ability to dynamically create a +hierarchical configuration by composition and override it through config files +and the command line. + +## DPR configuration +All DPR tools configuration parameters are now split between different config groups and you can either modify them in the config files or override from command line. + +Each tools's (train_dense_encoder.py, generate_dense_embeddings.py, dense_retriever.py and train_reader.py) main method has now a hydra @hydra.main decorator with the name of the configuration file in the conf/ dir. +For example, dense_retriever.py takes all its parameters from conf/dense_retriever.yaml file. +Every tool's configuration files refers to other configuration files via "defaults:" parameter. +It is called a [configuration group](https://hydra.cc/docs/tutorials/structured_config/config_groups) in Hydra. + +Let's take a look at dense_retriever.py's configuration: + + +```yaml + +defaults: + - encoder: hf_bert + - datasets: retriever_default + - ctx_sources: default_sources + +indexers: + flat: + _target_: dpr.indexer.faiss_indexers.DenseFlatIndexer + + hnsw: + _target_: dpr.indexer.faiss_indexers.DenseHNSWFlatIndexer + + hnsw_sq: + _target_: dpr.indexer.faiss_indexers.DenseHNSWSQIndexer + +... +qa_dataset: +... +ctx_datatsets: +... +indexer: flat +... + +``` + +" - encoder: " - a configuration group that contains all parameters to instantiate the encoder. The actual parameters are located in conf/encoder/hf_bert.yaml file. +If you want to override some of them, you can either +- Modify that config file +- Create a new config group file under conf/encoder/ folder and enable to use it by providing encoder={your file name} command line argument +- Override specific parameter from command line. For example: encoder.sequence_length=300 + +" - datasets:" - a configuration group that contains a list of all possible sources of queries for evaluation. One can find them in conf/datasets/retriever_default.yaml file. +One should specify the dataset to use by providing qa_dataset parameter in order to use one of them during evaluation. For example, if you want to run the retriever on NQ test set, set qa_dataset=nq_test as a command line parameter. + +It is much easier now to use custom datasets, without the need to convert them to DPR format. 
Just define your own class that provides relevant __getitem__(), __len__() and load_data() methods (inherit from QASrc). + +" - ctx_sources: " - a configuration group that contains a list of all possible passage sources. One can find them in conf/ctx_sources/default_sources.yaml file. +One should specify a list of names of the passages datasets as ctx_datatsets parameter. For example, if you want to use dpr's old wikipedia passages, set ctx_datatsets=[dpr_wiki]. +Please note that this parameter is a list and you can effectively concatenate different passage source into one. In order to use multiple sources at once, one also needs to provide relevant embeddings files in encoded_ctx_files parameter, which is also a list. + + +"indexers:" - a parameters map that defines various indexes. The actual index is selected by indexer parameter which is 'flat' by default but you can use loss index types by setting indexer=hnsw or indexer=hnsw_sq in the command line. + +Please refer to the configuration files comments for every parameter. diff --git a/research/information_retrieval/DPR/conf/biencoder_train_cfg.yaml b/research/information_retrieval/DPR/conf/biencoder_train_cfg.yaml new file mode 100644 index 00000000000..82a35d400b1 --- /dev/null +++ b/research/information_retrieval/DPR/conf/biencoder_train_cfg.yaml @@ -0,0 +1,47 @@ + +# configuration groups +defaults: + - encoder: hf_bert + - train: biencoder_default + - datasets: encoder_train_default + +train_datasets: +dev_datasets: +output_dir: +train_sampling_rates: +loss_scale_factors: + +# Whether to lower case the input text. Set True for uncased models, False for the cased ones. +do_lower_case: True + +fix_ctx_encoder: False +val_av_rank_start_epoch: 30 +seed: 12345 +checkpoint_file_name: dpr_biencoder + +# A trained bi-encoder checkpoint file to initialize the model +model_file: + +# TODO: move to a conf group +# local_rank for distributed training on gpus +local_rank: -1 +global_loss_buf_sz: 592000 +device: +distributed_world_size: +distributed_port: +no_cuda: False +n_gpu: +fp16: True + +# For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']." 
+# "See details at https://nvidia.github.io/apex/amp.html +fp16_opt_level: O1 + +# tokens which won't be slit by tokenizer +special_tokens: + +ignore_checkpoint_offset: False +ignore_checkpoint_optimizer: False + +# set to >1 to enable multiple query encoders +multi_q_encoder: False diff --git a/research/information_retrieval/DPR/conf/ctx_sources/default_sources.yaml b/research/information_retrieval/DPR/conf/ctx_sources/default_sources.yaml new file mode 100644 index 00000000000..75659a16929 --- /dev/null +++ b/research/information_retrieval/DPR/conf/ctx_sources/default_sources.yaml @@ -0,0 +1,6 @@ +# @package _group_ + +dpr_wiki: + _target_: dpr.data.retriever_data.CsvCtxSrc + file: data.wikipedia_split.psgs_w100 + id_prefix: 'wiki:' diff --git a/research/information_retrieval/DPR/conf/datasets/encoder_train_default.yaml b/research/information_retrieval/DPR/conf/datasets/encoder_train_default.yaml new file mode 100644 index 00000000000..916c516d66d --- /dev/null +++ b/research/information_retrieval/DPR/conf/datasets/encoder_train_default.yaml @@ -0,0 +1,46 @@ +# @package _group_ + +nq_train: + _target_: dpr.data.biencoder_data.JsonQADataset + file: data.retriever.nq-train + +nq_train_hn1: + _target_: dpr.data.biencoder_data.JsonQADataset + file: data.retriever.nq-adv-hn-train + +nq_dev: + _target_: dpr.data.biencoder_data.JsonQADataset + file: data.retriever.nq-dev + +trivia_train: + _target_: dpr.data.biencoder_data.JsonQADataset + file: data.retriever.trivia-train + +trivia_dev: + _target_: dpr.data.biencoder_data.JsonQADataset + file: data.retriever.trivia-dev + +squad1_train: + _target_: dpr.data.biencoder_data.JsonQADataset + file: data.retriever.squad1-train + +squad1_dev: + _target_: dpr.data.biencoder_data.JsonQADataset + file: data.retriever.squad1-dev + +webq_train: + _target_: dpr.data.biencoder_data.JsonQADataset + file: data.retriever.webq-train + +webq_dev: + _target_: dpr.data.biencoder_data.JsonQADataset + file: data.retriever.webq-dev + +curatedtrec_train: + _target_: dpr.data.biencoder_data.JsonQADataset + file: data.retriever.curatedtrec-train + +curatedtrec_dev: + _target_: dpr.data.biencoder_data.JsonQADataset + file: data.retriever.curatedtrec-dev + diff --git a/research/information_retrieval/DPR/conf/datasets/retriever_default.yaml b/research/information_retrieval/DPR/conf/datasets/retriever_default.yaml new file mode 100644 index 00000000000..cbe194ea5f5 --- /dev/null +++ b/research/information_retrieval/DPR/conf/datasets/retriever_default.yaml @@ -0,0 +1,33 @@ +# @package _group_ + +nq_test: + _target_: dpr.data.retriever_data.CsvQASrc + file: data.retriever.qas.nq-test + +nq_train: + _target_: dpr.data.retriever_data.CsvQASrc + file: data.retriever.qas.nq-train + +nq_dev: + _target_: dpr.data.retriever_data.CsvQASrc + file: data.retriever.qas.nq-dev + +trivia_test: + _target_: dpr.data.retriever_data.CsvQASrc + file: data.retriever.qas.trivia-test + +trivia_train: + _target_: dpr.data.retriever_data.CsvQASrc + file: data.retriever.qas.trivia-train + +trivia_dev: + _target_: dpr.data.retriever_data.CsvQASrc + file: data.retriever.qas.trivia-dev + +webq_test: + _target_: dpr.data.retriever_data.CsvQASrc + file: data.retriever.qas.webq-test + +curatedtrec_test: + _target_: dpr.data.retriever_data.CsvQASrc + file: data.retriever.qas.curatedtrec-test diff --git a/research/information_retrieval/DPR/conf/dense_retriever.yaml b/research/information_retrieval/DPR/conf/dense_retriever.yaml new file mode 100644 index 00000000000..9fb74ba5c56 --- /dev/null +++ 
b/research/information_retrieval/DPR/conf/dense_retriever.yaml @@ -0,0 +1,71 @@ +defaults: + - encoder: hf_bert # defines encoder initialization parameters + - datasets: retriever_default # contains a list of all possible sources of queries for evaluation. Specific set is selected by qa_dataset parameter + - ctx_sources: default_sources # contains a list of all possible passage sources. Specific passages sources selected by ctx_datatsets parameter + +indexers: + flat: + _target_: dpr.indexer.faiss_indexers.DenseFlatIndexer + + hnsw: + _target_: dpr.indexer.faiss_indexers.DenseHNSWFlatIndexer + + hnsw_sq: + _target_: dpr.indexer.faiss_indexers.DenseHNSWSQIndexer + +# the name of the queries dataset from the 'datasets' config group +qa_dataset: + +# a list of names of the passages datasets from the 'ctx_sources' config group +ctx_datatsets: + +#Glob paths to encoded passages (from generate_dense_embeddings tool) +encoded_ctx_files: [] + +out_file: +# "regex" or "string" +match: string +n_docs: 100 +validation_workers: 16 + +# Batch size to generate query embeddings +batch_size: 128 + +# Whether to lower case the input text. Set True for uncased models, False for the cased ones. +do_lower_case: True + +# The attribute name of encoder to use for queries. Options for the BiEncoder model: question_model, ctx_model +# question_model is used if this param is empty +encoder_path: + +# path to the FAISS index location - it is only needed if you want to serialize faiss index to files or read from them +# (instead of using encoded_ctx_files) +# it should point to either directory or a common index files prefix name +# if there is no index at the specific location, the index will be created from encoded_ctx_files +index_path: + +kilt_out_file: + +# A trained bi-encoder checkpoint file to initialize the model +model_file: + +validate_as_tables: False +rpc_retriever_cfg_file: +indexer: flat + +# tokens which won't be slit by tokenizer +special_tokens: + +# TODO: move to a conf group +# local_rank for distributed training on gpus +local_rank: -1 +global_loss_buf_sz: 150000 +device: +distributed_world_size: +no_cuda: False +n_gpu: +fp16: False + +# For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']." +# "See details at https://nvidia.github.io/apex/amp.html +fp16_opt_level: O1 diff --git a/research/information_retrieval/DPR/conf/encoder/hf_bert.yaml b/research/information_retrieval/DPR/conf/encoder/hf_bert.yaml new file mode 100644 index 00000000000..93ee638ac84 --- /dev/null +++ b/research/information_retrieval/DPR/conf/encoder/hf_bert.yaml @@ -0,0 +1,24 @@ +# @package _group_ + +# model type. 
One of [hf_bert, pytext_bert, fairseq_roberta] +encoder_model_type: hf_bert + +# HuggingFace's config name for model initialization +pretrained_model_cfg: bert-base-uncased + +# Some encoders need to be initialized from a file +pretrained_file: + +# Extra linear layer on top of standard bert/roberta encoder +projection_dim: 0 + +# Max length of the encoder input sequence +sequence_length: 256 + +dropout: 0.1 + +# whether to fix (don't update) context encoder during training or not +fix_ctx_encoder: False + +# if False, the model won't load pre-trained BERT weights +pretrained: True \ No newline at end of file diff --git a/research/information_retrieval/DPR/conf/extractive_reader_train_cfg.yaml b/research/information_retrieval/DPR/conf/extractive_reader_train_cfg.yaml new file mode 100644 index 00000000000..a947136d431 --- /dev/null +++ b/research/information_retrieval/DPR/conf/extractive_reader_train_cfg.yaml @@ -0,0 +1,73 @@ +# extractive reader configuration + +defaults: + - encoder: hf_bert + - train: extractive_reader_default + +# A trained reader checkpoint file to initialize the model +model_file: + +# Whether to lower case the input text. Set True for uncased models, False for the cased ones. +do_lower_case: True + +seed: 42 + +# glob expression for train data files +train_files: + +# glob expression for dev data files +dev_files: + +# Total amount of positive and negative passages per question +passages_per_question: 24 + +# Total amount of positive and negative passages per question for evaluation +passages_per_question_predict: 50 + +# The output directory where the model checkpoints will be written to +output_dir: + +# Max amount of answer spans to marginalize per singe passage +max_n_answers: 10 + +# The maximum length of an answer that can be generated. This is needed because the start +# and end predictions are not conditioned on one another +max_answer_length: 10 + +# Top retrieval passages thresholds to analyze prediction results for +eval_top_docs: + - 50 + +checkpoint_file_name: dpr_extractive_reader + +# Path to a file to write prediction results to +prediction_results_file: + +# Enables fully resumable mode +fully_resumable: False + +# File with the original train dataset passages (json format) +gold_passages_src: + +# File with the original dataset passages (json format) +gold_passages_src_dev: + +# num of threads to pre-process data. +num_workers: 16 + +# TODO: move to a conf group +# local_rank for distributed training on gpus +local_rank: -1 +global_loss_buf_sz: 150000 +device: +distributed_world_size: +no_cuda: False +n_gpu: +fp16: False + +# For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']." +# "See details at https://nvidia.github.io/apex/amp.html +fp16_opt_level: O1 + +# a list of tokens to avoid tokenization +special_tokens: \ No newline at end of file diff --git a/research/information_retrieval/DPR/conf/gen_embs.yaml b/research/information_retrieval/DPR/conf/gen_embs.yaml new file mode 100644 index 00000000000..bc30881ba00 --- /dev/null +++ b/research/information_retrieval/DPR/conf/gen_embs.yaml @@ -0,0 +1,52 @@ +defaults: + - encoder: hf_bert + - ctx_sources: default_sources + +# A trained bi-encoder checkpoint file to initialize the model +model_file: + +# Name of the all-passages resource +ctx_src: + +# which (ctx or query) encoder to be used for embedding generation +encoder_type: ctx + +# output .tsv file path to write results to +out_file: + +# Whether to lower case the input text. 
Set True for uncased models, False for the cased ones. +do_lower_case: True + +# Number(0-based) of data shard to process +shard_id: 0 + +# Total amount of data shards +num_shards: 1 + +# Batch size for the passage encoder forward pass (works in DataParallel mode) +batch_size: 32 + +tables_as_passages: False + +# tokens which won't be slit by tokenizer +special_tokens: + +tables_chunk_sz: 100 + +# TODO +tables_split_type: type1 + + +# TODO: move to a conf group +# local_rank for distributed training on gpus +local_rank: -1 +device: +distributed_world_size: +distributed_port: +no_cuda: False +n_gpu: +fp16: False + +# For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']." +# "See details at https://nvidia.github.io/apex/amp.html +fp16_opt_level: O1 \ No newline at end of file diff --git a/research/information_retrieval/DPR/conf/train/biencoder_default.yaml b/research/information_retrieval/DPR/conf/train/biencoder_default.yaml new file mode 100644 index 00000000000..601e0f53fda --- /dev/null +++ b/research/information_retrieval/DPR/conf/train/biencoder_default.yaml @@ -0,0 +1,27 @@ +# @package _group_ + +batch_size: 4 +dev_batch_size: 4 +adam_eps: 1e-8 +adam_betas: (0.9, 0.999) +max_grad_norm: 1.0 +log_batch_step: 100 +train_rolling_loss_step: 100 +weight_decay: 0.0 +learning_rate: 1e-5 + +# Linear warmup over warmup_steps. +warmup_steps: 100 + +# Number of updates steps to accumulate before performing a backward/update pass. +gradient_accumulation_steps: 1 + +# Total number of training epochs to perform. +num_train_epochs: 40 +eval_per_epoch: 1 +hard_negatives: 1 +other_negatives: 0 +val_av_rank_hard_neg: 30 +val_av_rank_other_neg: 30 +val_av_rank_bsz: 128 +val_av_rank_max_qs: 10000 diff --git a/research/information_retrieval/DPR/conf/train/biencoder_local.yaml b/research/information_retrieval/DPR/conf/train/biencoder_local.yaml new file mode 100644 index 00000000000..69696df1660 --- /dev/null +++ b/research/information_retrieval/DPR/conf/train/biencoder_local.yaml @@ -0,0 +1,27 @@ +# @package _group_ + +batch_size: 4 +dev_batch_size: 16 +adam_eps: 1e-8 +adam_betas: (0.9, 0.999) +max_grad_norm: 2.0 +log_batch_step: 1 +train_rolling_loss_step: 100 +weight_decay: 0.0 +learning_rate: 2e-5 + +# Linear warmup over warmup_steps. +warmup_steps: 1237 + +# Number of updates steps to accumulate before performing a backward/update pass. +gradient_accumulation_steps: 1 + +# Total number of training epochs to perform. +num_train_epochs: 40 +eval_per_epoch: 1 +hard_negatives: 1 +other_negatives: 0 +val_av_rank_hard_neg: 30 +val_av_rank_other_neg: 30 +val_av_rank_bsz: 128 +val_av_rank_max_qs: 10000 diff --git a/research/information_retrieval/DPR/conf/train/biencoder_nq.yaml b/research/information_retrieval/DPR/conf/train/biencoder_nq.yaml new file mode 100644 index 00000000000..e26f533c44a --- /dev/null +++ b/research/information_retrieval/DPR/conf/train/biencoder_nq.yaml @@ -0,0 +1,27 @@ +# @package _group_ + +batch_size: 4 +dev_batch_size: 64 +adam_eps: 1e-8 +adam_betas: (0.9, 0.999) +max_grad_norm: 2.0 +log_batch_step: 100 +train_rolling_loss_step: 100 +weight_decay: 0.0 +learning_rate: 2e-5 + +# Linear warmup over warmup_steps. +warmup_steps: 1237 + +# Number of updates steps to accumulate before performing a backward/update pass. +gradient_accumulation_steps: 1 + +# Total number of training epochs to perform. 
+num_train_epochs: 40 +eval_per_epoch: 1 +hard_negatives: 1 +other_negatives: 0 +val_av_rank_hard_neg: 30 +val_av_rank_other_neg: 30 +val_av_rank_bsz: 128 +val_av_rank_max_qs: 10000 diff --git a/research/information_retrieval/DPR/conf/train/extractive_reader_default.yaml b/research/information_retrieval/DPR/conf/train/extractive_reader_default.yaml new file mode 100644 index 00000000000..62778c49808 --- /dev/null +++ b/research/information_retrieval/DPR/conf/train/extractive_reader_default.yaml @@ -0,0 +1,21 @@ +# @package _group_ + +eval_step: 2000 +batch_size: 16 +dev_batch_size: 72 +adam_eps: 1e-8 +adam_betas: (0.9, 0.999) +max_grad_norm: 1.0 +log_batch_step: 100 +train_rolling_loss_step: 100 +weight_decay: 0.0 +learning_rate: 1e-5 + +# Linear warmup over warmup_steps. +warmup_steps: 0 + +# Number of updates steps to accumulate before performing a backward/update pass. +gradient_accumulation_steps: 1 + +# Total number of training epochs to perform. +num_train_epochs: 100000 diff --git a/research/information_retrieval/DPR/dense_retriever.py b/research/information_retrieval/DPR/dense_retriever.py new file mode 100644 index 00000000000..7ee4eb0346d --- /dev/null +++ b/research/information_retrieval/DPR/dense_retriever.py @@ -0,0 +1,460 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" + Command line tool to get dense results and validate them +""" + +import glob +import json +import logging +import pickle +import time +from typing import List, Tuple, Dict, Iterator + +import hydra +import numpy as np +import torch +from omegaconf import DictConfig, OmegaConf +from torch import Tensor as T +from torch import nn + +from dpr.data.biencoder_data import RepTokenSelector +from dpr.data.qa_validation import calculate_matches, calculate_chunked_matches +from dpr.data.retriever_data import KiltCsvCtxSrc, TableChunk +from dpr.indexer.faiss_indexers import ( + DenseIndexer, +) +from dpr.models import init_biencoder_components +from dpr.models.biencoder import BiEncoder, _select_span_with_token +from dpr.options import setup_logger, setup_cfg_gpu, set_cfg_params_from_state +from dpr.utils.data_utils import Tensorizer +from dpr.utils.model_utils import ( + setup_for_distributed_mode, + get_model_obj, + load_states_from_checkpoint, +) + +logger = logging.getLogger() +setup_logger(logger) + + +def generate_question_vectors( + question_encoder: torch.nn.Module, + tensorizer: Tensorizer, + questions: List[str], + bsz: int, + query_token: str = None, + selector: RepTokenSelector = None, +) -> T: + n = len(questions) + query_vectors = [] + + with torch.no_grad(): + for j, batch_start in enumerate(range(0, n, bsz)): + batch_questions = questions[batch_start : batch_start + bsz] + + if query_token: + # TODO: tmp workaround for EL, remove or revise + if query_token == "[START_ENT]": + batch_token_tensors = [ + _select_span_with_token(q, tensorizer, token_str=query_token) + for q in batch_questions + ] + else: + batch_token_tensors = [ + tensorizer.text_to_tensor(" ".join([query_token, q])) + for q in batch_questions + ] + else: + batch_token_tensors = [ + tensorizer.text_to_tensor(q) for q in batch_questions + ] + + q_ids_batch = torch.stack(batch_token_tensors, dim=0).cuda() + q_seg_batch = torch.zeros_like(q_ids_batch).cuda() + q_attn_mask = tensorizer.get_attn_mask(q_ids_batch) + + if selector: + rep_positions = 
selector.get_positions(q_ids_batch, tensorizer) + + _, out, _ = BiEncoder.get_representation( + question_encoder, + q_ids_batch, + q_seg_batch, + q_attn_mask, + representation_token_pos=rep_positions, + ) + else: + _, out, _ = question_encoder(q_ids_batch, q_seg_batch, q_attn_mask) + + query_vectors.extend(out.cpu().split(1, dim=0)) + + if len(query_vectors) % 100 == 0: + logger.info("Encoded queries %d", len(query_vectors)) + + query_tensor = torch.cat(query_vectors, dim=0) + logger.info("Total encoded queries tensor %s", query_tensor.size()) + assert query_tensor.size(0) == len(questions) + return query_tensor + + +class DenseRetriever(object): + def __init__( + self, question_encoder: nn.Module, batch_size: int, tensorizer: Tensorizer + ): + self.question_encoder = question_encoder + self.batch_size = batch_size + self.tensorizer = tensorizer + self.selector = None + + def generate_question_vectors( + self, questions: List[str], query_token: str = None + ) -> T: + + bsz = self.batch_size + self.question_encoder.eval() + return generate_question_vectors( + self.question_encoder, + self.tensorizer, + questions, + bsz, + query_token=query_token, + selector=self.selector, + ) + + +class LocalFaissRetriever(DenseRetriever): + """ + Does passage retrieving over the provided index and question encoder + """ + + def __init__( + self, + question_encoder: nn.Module, + batch_size: int, + tensorizer: Tensorizer, + index: DenseIndexer, + ): + super().__init__(question_encoder, batch_size, tensorizer) + self.index = index + + def index_encoded_data( + self, + vector_files: List[str], + buffer_size: int, + path_id_prefixes: List = None, + ): + """ + Indexes encoded passages takes form a list of files + :param vector_files: file names to get passages vectors from + :param buffer_size: size of a buffer (amount of passages) to send for the indexing at once + :return: + """ + buffer = [] + for i, item in enumerate( + iterate_encoded_files(vector_files, path_id_prefixes=path_id_prefixes) + ): + buffer.append(item) + if 0 < buffer_size == len(buffer): + self.index.index_data(buffer) + buffer = [] + self.index.index_data(buffer) + logger.info("Data indexing completed.") + + def get_top_docs( + self, query_vectors: np.array, top_docs: int = 100 + ) -> List[Tuple[List[object], List[float]]]: + """ + Does the retrieval of the best matching passages given the query vectors batch + :param query_vectors: + :param top_docs: + :return: + """ + time0 = time.time() + results = self.index.search_knn(query_vectors, top_docs) + logger.info("index search time: %f sec.", time.time() - time0) + self.index = None + return results + + +def validate( + passages: Dict[object, Tuple[str, str]], + answers: List[List[str]], + result_ctx_ids: List[Tuple[List[object], List[float]]], + workers_num: int, + match_type: str, +) -> List[List[bool]]: + match_stats = calculate_matches( + passages, answers, result_ctx_ids, workers_num, match_type + ) + top_k_hits = match_stats.top_k_hits + + logger.info("Validation results: top k documents hits %s", top_k_hits) + top_k_hits = [v / len(result_ctx_ids) for v in top_k_hits] + logger.info("Validation results: top k documents hits accuracy %s", top_k_hits) + return match_stats.questions_doc_hits + + +def save_results( + passages: Dict[object, Tuple[str, str]], + questions: List[str], + answers: List[List[str]], + top_passages_and_scores: List[Tuple[List[object], List[float]]], + per_question_hits: List[List[bool]], + out_file: str, +): + # join passages text with the result ids, their questions 
and assigning has|no answer labels + merged_data = [] + # assert len(per_question_hits) == len(questions) == len(answers) + for i, q in enumerate(questions): + q_answers = answers[i] + results_and_scores = top_passages_and_scores[i] + hits = per_question_hits[i] + docs = [passages[doc_id] for doc_id in results_and_scores[0]] + scores = [str(score) for score in results_and_scores[1]] + ctxs_num = len(hits) + + merged_data.append( + { + "question": q, + "answers": q_answers, + "ctxs": [ + { + "id": results_and_scores[0][c], + "title": docs[c][1], + "text": docs[c][0], + "score": scores[c], + "has_answer": hits[c], + } + for c in range(ctxs_num) + ], + } + ) + + with open(out_file, "w") as writer: + writer.write(json.dumps(merged_data, indent=4) + "\n") + logger.info("Saved results * scores to %s", out_file) + + +def iterate_encoded_files( + vector_files: list, path_id_prefixes: List = None +) -> Iterator[Tuple]: + for i, file in enumerate(vector_files): + logger.info("Reading file %s", file) + id_prefix = None + if path_id_prefixes: + id_prefix = path_id_prefixes[i] + with open(file, "rb") as reader: + doc_vectors = pickle.load(reader) + for doc in doc_vectors: + doc = list(doc) + if id_prefix and not str(doc[0]).startswith(id_prefix): + doc[0] = id_prefix + str(doc[0]) + yield doc + + +def validate_tables( + passages: Dict[object, TableChunk], + answers: List[List[str]], + result_ctx_ids: List[Tuple[List[object], List[float]]], + workers_num: int, + match_type: str, +) -> List[List[bool]]: + match_stats = calculate_chunked_matches( + passages, answers, result_ctx_ids, workers_num, match_type + ) + top_k_chunk_hits = match_stats.top_k_chunk_hits + top_k_table_hits = match_stats.top_k_table_hits + + logger.info("Validation results: top k documents hits %s", top_k_chunk_hits) + top_k_hits = [v / len(result_ctx_ids) for v in top_k_chunk_hits] + logger.info("Validation results: top k table chunk hits accuracy %s", top_k_hits) + + logger.info("Validation results: top k tables hits %s", top_k_table_hits) + top_k_table_hits = [v / len(result_ctx_ids) for v in top_k_table_hits] + logger.info("Validation results: top k tables accuracy %s", top_k_table_hits) + + return match_stats.top_k_chunk_hits + + +@hydra.main(config_path="conf", config_name="dense_retriever") +def main(cfg: DictConfig): + cfg = setup_cfg_gpu(cfg) + logger.info("CFG (after gpu configuration):") + logger.info("%s", OmegaConf.to_yaml(cfg)) + + saved_state = load_states_from_checkpoint(cfg.model_file) + set_cfg_params_from_state(saved_state.encoder_params, cfg) + + tensorizer, encoder, _ = init_biencoder_components( + cfg.encoder.encoder_model_type, cfg, inference_only=True + ) + + encoder_path = cfg.encoder_path + if encoder_path: + logger.info("Selecting encoder: %s", encoder_path) + encoder = getattr(encoder, encoder_path) + else: + logger.info("Selecting standard question encoder") + encoder = encoder.question_model + + encoder, _ = setup_for_distributed_mode( + encoder, None, cfg.device, cfg.n_gpu, cfg.local_rank, cfg.fp16 + ) + encoder.eval() + + # load weights from the model file + model_to_load = get_model_obj(encoder) + logger.info("Loading saved model state ...") + + encoder_prefix = (encoder_path if encoder_path else "question_model") + "." 
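+    # The checkpoint holds the full bi-encoder state; keep only the keys that
+    # belong to the selected sub-module (question_model by default) and strip
+    # the prefix so they match the standalone encoder's state dict.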
+ prefix_len = len(encoder_prefix) + + logger.info("Encoder state prefix %s", encoder_prefix) + question_encoder_state = { + key[prefix_len:]: value + for (key, value) in saved_state.model_dict.items() + if key.startswith(encoder_prefix) + } + # TODO: long term HF state compatibility fix + model_to_load.load_state_dict(question_encoder_state, strict=False) + vector_size = model_to_load.get_out_size() + logger.info("Encoder vector_size=%d", vector_size) + + # get questions & answers + questions = [] + question_answers = [] + + if not cfg.qa_dataset: + logger.warning("Please specify qa_dataset to use") + return + + ds_key = cfg.qa_dataset + logger.info("qa_dataset: %s", ds_key) + + qa_src = hydra.utils.instantiate(cfg.datasets[ds_key]) + qa_src.load_data() + + for ds_item in qa_src.data: + question, answers = ds_item.query, ds_item.answers + questions.append(question) + question_answers.append(answers) + + index = hydra.utils.instantiate(cfg.indexers[cfg.indexer]) + logger.info("Index class %s ", type(index)) + index_buffer_sz = index.buffer_size + index.init_index(vector_size) + retriever = LocalFaissRetriever(encoder, cfg.batch_size, tensorizer, index) + + logger.info("Using special token %s", qa_src.special_query_token) + questions_tensor = retriever.generate_question_vectors( + questions, query_token=qa_src.special_query_token + ) + + if qa_src.selector: + logger.info("Using custom representation token selector") + retriever.selector = qa_src.selector + + id_prefixes = [] + ctx_sources = [] + for ctx_src in cfg.ctx_datatsets: + ctx_src = hydra.utils.instantiate(cfg.ctx_sources[ctx_src]) + id_prefixes.append(ctx_src.id_prefix) + ctx_sources.append(ctx_src) + + logger.info("id_prefixes per dataset: %s", id_prefixes) + + # index all passages + ctx_files_patterns = cfg.encoded_ctx_files + index_path = cfg.index_path + + logger.info("ctx_files_patterns: %s", ctx_files_patterns) + if ctx_files_patterns: + assert len(ctx_files_patterns) == len( + id_prefixes + ), "ctx len={} pref leb={}".format(len(ctx_files_patterns), len(id_prefixes)) + else: + assert ( + index_path + ), "Either encoded_ctx_files or index_path parameter should be set." + + input_paths = [] + path_id_prefixes = [] + for i, pattern in enumerate(ctx_files_patterns): + pattern_files = glob.glob(pattern) + pattern_id_prefix = id_prefixes[i] + input_paths.extend(pattern_files) + path_id_prefixes.extend([pattern_id_prefix] * len(pattern_files)) + + logger.info("Embeddings files id prefixes: %s", path_id_prefixes) + + if index_path and index.index_exists(index_path): + logger.info("Index path: %s", index_path) + retriever.index.deserialize(index_path) + else: + logger.info("Reading all passages data from files: %s", input_paths) + retriever.index_encoded_data( + input_paths, index_buffer_sz, path_id_prefixes=path_id_prefixes + ) + if index_path: + retriever.index.serialize(index_path) + + # get top k results + top_ids_and_scores = retriever.get_top_docs(questions_tensor.numpy(), cfg.n_docs) + + # we no longer need the index + retriever = None + + all_passages = {} + for ctx_src in ctx_sources: + ctx_src.load_data_to(all_passages) + + if len(all_passages) == 0: + raise RuntimeError( + "No passages data found. Please specify ctx_file param properly." 
+ ) + + if cfg.validate_as_tables: + questions_doc_hits = validate_tables( + all_passages, + question_answers, + top_ids_and_scores, + cfg.validation_workers, + cfg.match, + ) + else: + questions_doc_hits = validate( + all_passages, + question_answers, + top_ids_and_scores, + cfg.validation_workers, + cfg.match, + ) + + if cfg.out_file: + save_results( + all_passages, + questions, + question_answers, + top_ids_and_scores, + questions_doc_hits, + cfg.out_file, + ) + + if cfg.kilt_out_file: + kilt_ctx = next( + iter([ctx for ctx in ctx_sources if isinstance(ctx, KiltCsvCtxSrc)]), None + ) + if not kilt_ctx: + raise RuntimeError("No Kilt compatible context file provided") + assert hasattr(cfg, "kilt_out_file") + kilt_ctx.convert_to_kilt(qa_src.kilt_gold_file, cfg.out_file, cfg.kilt_out_file) + + +if __name__ == "__main__": + main() diff --git a/research/information_retrieval/DPR/download_data.py b/research/information_retrieval/DPR/download_data.py new file mode 100644 index 00000000000..d362e222efb --- /dev/null +++ b/research/information_retrieval/DPR/download_data.py @@ -0,0 +1,550 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" + Command line tool to download various preprocessed data sources & checkpoints for DPR +""" + +import argparse +import gzip +import logging +import os +import pathlib +import wget + +from typing import Tuple + +logger = logging.getLogger(__name__) + +# TODO: move to hydra config group + +NQ_LICENSE_FILES = [ + "https://dl.fbaipublicfiles.com/dpr/nq_license/LICENSE", + "https://dl.fbaipublicfiles.com/dpr/nq_license/README", +] + +RESOURCES_MAP = { + "data.wikipedia_split.psgs_w100": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/wikipedia_split/psgs_w100.tsv.gz", + "original_ext": ".tsv", + "compressed": True, + "desc": "Entire wikipedia passages set obtain by splitting all pages into 100-word segments (no overlap)", + }, + "data.retriever.nq-dev": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/data/retriever/biencoder-nq-dev.json.gz", + "original_ext": ".json", + "compressed": True, + "desc": "NQ dev subset with passages pools for the Retriever train time validation", + "license_files": NQ_LICENSE_FILES, + }, + "data.retriever.nq-train": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/data/retriever/biencoder-nq-train.json.gz", + "original_ext": ".json", + "compressed": True, + "desc": "NQ train subset with passages pools for the Retriever training", + "license_files": NQ_LICENSE_FILES, + }, + "data.retriever.nq-adv-hn-train": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/data/retriever/biencoder-nq-adv-hn-train.json.gz", + "original_ext": ".json", + "compressed": True, + "desc": "NQ train subset with hard negative passages mined using the baseline DPR NQ encoders & wikipedia index", + "license_files": NQ_LICENSE_FILES, + }, + "data.retriever.trivia-dev": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/data/retriever/biencoder-trivia-dev.json.gz", + "original_ext": ".json", + "compressed": True, + "desc": "TriviaQA dev subset with passages pools for the Retriever train time validation", + }, + "data.retriever.trivia-train": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/data/retriever/biencoder-trivia-train.json.gz", + "original_ext": ".json", + "compressed": True, + "desc": "TriviaQA train subset with passages pools for the Retriever training", + }, + 
"data.retriever.squad1-train": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/data/retriever/biencoder-squad1-train.json.gz", + "original_ext": ".json", + "compressed": True, + "desc": "SQUAD 1.1 train subset with passages pools for the Retriever training", + }, + "data.retriever.squad1-dev": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/data/retriever/biencoder-squad1-dev.json.gz", + "original_ext": ".json", + "compressed": True, + "desc": "SQUAD 1.1 dev subset with passages pools for the Retriever train time validation", + }, + "data.retriever.webq-train": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/data/retriever/biencoder-webquestions-train.json.gz", + "original_ext": ".json", + "compressed": True, + "desc": "WebQuestions dev subset with passages pools for the Retriever train time validation", + }, + "data.retriever.webq-dev": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/data/retriever/biencoder-webquestions-dev.json.gz", + "original_ext": ".json", + "compressed": True, + "desc": "WebQuestions dev subset with passages pools for the Retriever train time validation", + }, + "data.retriever.curatedtrec-train": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/data/retriever/biencoder-curatedtrec-train.json.gz", + "original_ext": ".json", + "compressed": True, + "desc": "CuratedTrec dev subset with passages pools for the Retriever train time validation", + }, + "data.retriever.curatedtrec-dev": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/data/retriever/biencoder-curatedtrec-dev.json.gz", + "original_ext": ".json", + "compressed": True, + "desc": "CuratedTrec dev subset with passages pools for the Retriever train time validation", + }, + "data.retriever.qas.nq-dev": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/data/retriever/nq-dev.qa.csv", + "original_ext": ".csv", + "compressed": False, + "desc": "NQ dev subset for Retriever validation and IR results generation", + "license_files": NQ_LICENSE_FILES, + }, + "data.retriever.qas.nq-test": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/data/retriever/nq-test.qa.csv", + "original_ext": ".csv", + "compressed": False, + "desc": "NQ test subset for Retriever validation and IR results generation", + "license_files": NQ_LICENSE_FILES, + }, + "data.retriever.qas.nq-train": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/data/retriever/nq-train.qa.csv", + "original_ext": ".csv", + "compressed": False, + "desc": "NQ train subset for Retriever validation and IR results generation", + "license_files": NQ_LICENSE_FILES, + }, + # + "data.retriever.qas.trivia-dev": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/data/retriever/trivia-dev.qa.csv.gz", + "original_ext": ".csv", + "compressed": True, + "desc": "Trivia dev subset for Retriever validation and IR results generation", + }, + "data.retriever.qas.trivia-test": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/data/retriever/trivia-test.qa.csv.gz", + "original_ext": ".csv", + "compressed": True, + "desc": "Trivia test subset for Retriever validation and IR results generation", + }, + "data.retriever.qas.trivia-train": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/data/retriever/trivia-train.qa.csv.gz", + "original_ext": ".csv", + "compressed": True, + "desc": "Trivia train subset for Retriever validation and IR results generation", + }, + "data.retriever.qas.squad1-test": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/data/retriever/squad1-test.qa.csv", + "original_ext": ".csv", + "compressed": False, + "desc": "Trivia test subset for Retriever validation 
and IR results generation", + }, + "data.retriever.qas.webq-test": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/data/retriever/webquestions-test.qa.csv", + "original_ext": ".csv", + "compressed": False, + "desc": "WebQuestions test subset for Retriever validation and IR results generation", + }, + "data.retriever.qas.curatedtrec-test": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/data/retriever/curatedtrec-test.qa.csv", + "original_ext": ".csv", + "compressed": False, + "desc": "CuratedTrec test subset for Retriever validation and IR results generation", + }, + "data.gold_passages_info.nq_train": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/data/nq_gold_info/nq-train_gold_info.json.gz", + "original_ext": ".json", + "compressed": True, + "desc": "Original NQ (our train subset) gold positive passages and alternative question tokenization", + "license_files": NQ_LICENSE_FILES, + }, + "data.gold_passages_info.nq_dev": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/data/nq_gold_info/nq-dev_gold_info.json.gz", + "original_ext": ".json", + "compressed": True, + "desc": "Original NQ (our dev subset) gold positive passages and alternative question tokenization", + "license_files": NQ_LICENSE_FILES, + }, + "data.gold_passages_info.nq_test": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/data/nq_gold_info/nq-test_gold_info.json.gz", + "original_ext": ".json", + "compressed": True, + "desc": "Original NQ (our test, original dev subset) gold positive passages and alternative question " + "tokenization", + "license_files": NQ_LICENSE_FILES, + }, + "pretrained.fairseq.roberta-base.dict": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/pretrained/fairseq/roberta/dict.txt", + "original_ext": ".txt", + "compressed": False, + "desc": "Dictionary for pretrained fairseq roberta model", + }, + "pretrained.fairseq.roberta-base.model": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/pretrained/fairseq/roberta/model.pt", + "original_ext": ".pt", + "compressed": False, + "desc": "Weights for pretrained fairseq roberta base model", + }, + "pretrained.pytext.bert-base.model": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/pretrained/pytext/bert/bert-base-uncased.pt", + "original_ext": ".pt", + "compressed": False, + "desc": "Weights for pretrained pytext bert base model", + }, + "data.retriever_results.nq.single.wikipedia_passages": { + "s3_url": [ + "https://dl.fbaipublicfiles.com/dpr/data/wiki_encoded/single/nq/wiki_passages_{}".format( + i + ) + for i in range(50) + ], + "original_ext": ".pkl", + "compressed": False, + "desc": "Encoded wikipedia files using a biencoder checkpoint(" + "checkpoint.retriever.single.nq.bert-base-encoder) trained on NQ dataset ", + }, + "data.retriever_results.nq.single-adv-hn.wikipedia_passages": { + "s3_url": [ + "https://dl.fbaipublicfiles.com/dpr/data/wiki_encoded/single-adv-hn/nq/wiki_passages_{}".format( + i + ) + for i in range(50) + ], + "original_ext": ".pkl", + "compressed": False, + "desc": "Encoded wikipedia files using a single-adv-hn checkpoint(" + "checkpoint.retriever.single.nq.bert-base-encoder) trained on NQ dataset ", + }, + "data.retriever_results.nq.single-adv-hn.wikipedia_passages": { + "s3_url": [ + "https://dl.fbaipublicfiles.com/dpr/data/wiki_encoded/single-adv-hn/nq/wiki_passages_{}".format( + i + ) + for i in range(50) + ], + "original_ext": ".pkl", + "compressed": False, + "desc": "Encoded wikipedia files using a biencoder checkpoint(" + "checkpoint.retriever.single-adv-hn.nq.bert-base-encoder) trained on NQ dataset + adversarial 
hard negatives", + }, + "data.retriever_results.nq.single.test": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/data/retriever_results/single/nq-test.json.gz", + "original_ext": ".json", + "compressed": True, + "desc": "Retrieval results of NQ test dataset for the encoder trained on NQ", + "license_files": NQ_LICENSE_FILES, + }, + "data.retriever_results.nq.single.dev": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/data/retriever_results/single/nq-dev.json.gz", + "original_ext": ".json", + "compressed": True, + "desc": "Retrieval results of NQ dev dataset for the encoder trained on NQ", + "license_files": NQ_LICENSE_FILES, + }, + "data.retriever_results.nq.single.train": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/data/retriever_results/single/nq-train.json.gz", + "original_ext": ".json", + "compressed": True, + "desc": "Retrieval results of NQ train dataset for the encoder trained on NQ", + "license_files": NQ_LICENSE_FILES, + }, + "data.retriever_results.nq.single-adv-hn.test": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/data/retriever_results/single-adv-hn/nq-test.json.gz", + "original_ext": ".json", + "compressed": True, + "desc": "Retrieval results of NQ test dataset for the encoder trained on NQ + adversarial hard negatives", + "license_files": NQ_LICENSE_FILES, + }, + "checkpoint.retriever.single.nq.bert-base-encoder": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/checkpoint/retriever/single/nq/hf_bert_base.cp", + "original_ext": ".cp", + "compressed": False, + "desc": "Biencoder weights trained on NQ data and HF bert-base-uncased model", + }, + "checkpoint.retriever.multiset.bert-base-encoder": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/checkpoint/retriver/multiset/hf_bert_base.cp", + "original_ext": ".cp", + "compressed": False, + "desc": "Biencoder weights trained on multi set data and HF bert-base-uncased model", + }, + "checkpoint.retriever.single-adv-hn.nq.bert-base-encoder": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/checkpoint/retriver/single-adv-hn/nq/hf_bert_base.cp", + "original_ext": ".cp", + "compressed": False, + "desc": "Biencoder weights trained on the original DPR NQ data combined with adversarial hard negatives (See data.retriever.nq-adv-hn-train resource). 
" + "The model is HF bert-base-uncased", + }, + "data.reader.nq.single.train": { + "s3_url": [ + "https://dl.fbaipublicfiles.com/dpr/data/reader/nq/single/train.{}.pkl".format( + i + ) + for i in range(8) + ], + "original_ext": ".pkl", + "compressed": False, + "desc": "Reader model NQ train dataset input data preprocessed from retriever results (also trained on NQ)", + "license_files": NQ_LICENSE_FILES, + }, + "data.reader.nq.single.dev": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/data/reader/nq/single/dev.0.pkl", + "original_ext": ".pkl", + "compressed": False, + "desc": "Reader model NQ dev dataset input data preprocessed from retriever results (also trained on NQ)", + "license_files": NQ_LICENSE_FILES, + }, + "data.reader.nq.single.test": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/data/reader/nq/single/test.0.pkl", + "original_ext": ".pkl", + "compressed": False, + "desc": "Reader model NQ test dataset input data preprocessed from retriever results (also trained on NQ)", + "license_files": NQ_LICENSE_FILES, + }, + "data.reader.trivia.multi-hybrid.train": { + "s3_url": [ + "https://dl.fbaipublicfiles.com/dpr/data/reader/trivia/multi-hybrid/train.{}.pkl".format( + i + ) + for i in range(8) + ], + "original_ext": ".pkl", + "compressed": False, + "desc": "Reader model Trivia train dataset input data preprocessed from hybrid retriever results " + "(where dense part is trained on multiset)", + }, + "data.reader.trivia.multi-hybrid.dev": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/data/reader/trivia/multi-hybrid/dev.0.pkl", + "original_ext": ".pkl", + "compressed": False, + "desc": "Reader model Trivia dev dataset input data preprocessed from hybrid retriever results " + "(where dense part is trained on multiset)", + }, + "data.reader.trivia.multi-hybrid.test": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/data/reader/trivia/multi-hybrid/test.0.pkl", + "original_ext": ".pkl", + "compressed": False, + "desc": "Reader model Trivia test dataset input data preprocessed from hybrid retriever results " + "(where dense part is trained on multiset)", + }, + "checkpoint.reader.nq-single.hf-bert-base": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/checkpoint/reader/nq-single/hf_bert_base.cp", + "original_ext": ".cp", + "compressed": False, + "desc": "Reader weights trained on NQ-single retriever results and HF bert-base-uncased model", + }, + "checkpoint.reader.nq-trivia-hybrid.hf-bert-base": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/checkpoint/reader/nq-trivia-hybrid/hf_bert_base.cp", + "original_ext": ".cp", + "compressed": False, + "desc": "Reader weights trained on Trivia multi hybrid retriever results and HF bert-base-uncased model", + }, + # extra checkpoints for EfficientQA competition + "checkpoint.reader.nq-single-subset.hf-bert-base": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/checkpoint/reader/nq-single-seen_only/hf_bert_base.cp", + "original_ext": ".cp", + "compressed": False, + "desc": "Reader weights trained on NQ-single retriever results and HF bert-base-uncased model, when only Wikipedia pages seen during training are considered", + }, + "checkpoint.reader.nq-tfidf.hf-bert-base": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/checkpoint/reader/nq-drqa/hf_bert_base.cp", + "original_ext": ".cp", + "compressed": False, + "desc": "Reader weights trained on TFIDF results and HF bert-base-uncased model", + }, + "checkpoint.reader.nq-tfidf-subset.hf-bert-base": { + "s3_url": 
"https://dl.fbaipublicfiles.com/dpr/checkpoint/reader/nq-drqa-seen_only/hf_bert_base.cp", + "original_ext": ".cp", + "compressed": False, + "desc": "Reader weights trained on TFIDF results and HF bert-base-uncased model, when only Wikipedia pages seen during training are considered", + }, + # retrieval indexes + "indexes.single.nq.full.index": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/checkpoint/indexes/single/nq/full.index.dpr", + "original_ext": ".dpr", + "compressed": False, + "desc": "DPR index on NQ-single retriever", + }, + "indexes.single.nq.full.index_meta": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/checkpoint/indexes/single/nq/full.index_meta.dpr", + "original_ext": ".dpr", + "compressed": False, + "desc": "DPR index on NQ-single retriever (metadata)", + }, + "indexes.single.nq.subset.index": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/checkpoint/indexes/single/nq/seen_only.index.dpr", + "original_ext": ".dpr", + "compressed": False, + "desc": "DPR index on NQ-single retriever when only Wikipedia pages seen during training are considered", + }, + "indexes.single.nq.subset.index_meta": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/checkpoint/indexes/single/nq/seen_only.index_meta.dpr", + "original_ext": ".dpr", + "compressed": False, + "desc": "DPR index on NQ-single retriever when only Wikipedia pages seen during training are considered (metadata)", + }, + "indexes.tfidf.nq.full": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/checkpoint/indexes/drqa/nq/full-tfidf.npz", + "original_ext": ".npz", + "compressed": False, + "desc": "TFIDF index", + }, + "indexes.tfidf.nq.subset": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/checkpoint/indexes/drqa/nq/seen_only-tfidf.npz", + "original_ext": ".npz", + "compressed": False, + "desc": "TFIDF index when only Wikipedia pages seen during training are considered", + }, +} + + +def unpack(gzip_file: str, out_file: str): + logger.info("Uncompressing %s", gzip_file) + input = gzip.GzipFile(gzip_file, "rb") + s = input.read() + input.close() + output = open(out_file, "wb") + output.write(s) + output.close() + logger.info(" Saved to %s", out_file) + + +def download_resource( + s3_url: str, original_ext: str, compressed: bool, resource_key: str, out_dir: str +) -> Tuple[str, str]: + logger.info("Requested resource from %s", s3_url) + path_names = resource_key.split(".") + + if out_dir: + root_dir = out_dir + else: + # since hydra overrides the location for the 'current dir' for every run and we don't want to duplicate + # resources multiple times, remove the current folder's volatile part + root_dir = os.path.abspath("./") + if "/outputs/" in root_dir: + root_dir = root_dir[: root_dir.index("/outputs/")] + + logger.info("Download root_dir %s", root_dir) + + save_root = os.path.join( + root_dir, "downloads", *path_names[:-1] + ) # last segment is for file name + + pathlib.Path(save_root).mkdir(parents=True, exist_ok=True) + + local_file_uncompressed = os.path.abspath( + os.path.join(save_root, path_names[-1] + original_ext) + ) + logger.info("File to be downloaded as %s", local_file_uncompressed) + + if os.path.exists(local_file_uncompressed): + logger.info("File already exist %s", local_file_uncompressed) + return save_root, local_file_uncompressed + + local_file = os.path.abspath( + os.path.join( + save_root, path_names[-1] + (".tmp" if compressed else original_ext) + ) + ) + + wget.download(s3_url, out=local_file) + + logger.info("Downloaded to %s", local_file) + + if compressed: + uncompressed_file = 
os.path.join(save_root, path_names[-1] + original_ext) + unpack(local_file, uncompressed_file) + os.remove(local_file) + local_file = uncompressed_file + return save_root, local_file + + +def download_file(s3_url: str, out_dir: str, file_name: str): + logger.info("Loading from %s", s3_url) + local_file = os.path.join(out_dir, file_name) + + if os.path.exists(local_file): + logger.info("File already exist %s", local_file) + return + + wget.download(s3_url, out=local_file) + logger.info("Downloaded to %s", local_file) + + +def download(resource_key: str, out_dir: str = None): + if resource_key not in RESOURCES_MAP: + # match by prefix + resources = [k for k in RESOURCES_MAP.keys() if k.startswith(resource_key)] + if resources: + for key in resources: + download(key, out_dir) + else: + logger.info("no resources found for specified key") + return [] + download_info = RESOURCES_MAP[resource_key] + + s3_url = download_info["s3_url"] + + save_root_dir = None + data_files = [] + if isinstance(s3_url, list): + for i, url in enumerate(s3_url): + save_root_dir, local_file = download_resource( + url, + download_info["original_ext"], + download_info["compressed"], + "{}_{}".format(resource_key, i), + out_dir, + ) + data_files.append(local_file) + else: + save_root_dir, local_file = download_resource( + s3_url, + download_info["original_ext"], + download_info["compressed"], + resource_key, + out_dir, + ) + data_files.append(local_file) + + license_files = download_info.get("license_files", None) + if license_files: + download_file(license_files[0], save_root_dir, "LICENSE") + download_file(license_files[1], save_root_dir, "README") + return data_files + + +def main(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--output_dir", + default="./", + type=str, + help="The output directory to download file", + ) + parser.add_argument( + "--resource", + type=str, + help="Resource name. See RESOURCES_MAP for all possible values", + ) + args = parser.parse_args() + if args.resource: + download(args.resource, args.output_dir) + else: + print("Please specify resource value. 
Possible options are:") + for k, v in RESOURCES_MAP.items(): + print("Resource key=%s : %s", k, v["desc"]) + + +if __name__ == "__main__": + main() diff --git a/research/information_retrieval/DPR/dpr/__init__.py b/research/information_retrieval/DPR/dpr/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/research/information_retrieval/DPR/dpr/data/__init__.py b/research/information_retrieval/DPR/dpr/data/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/research/information_retrieval/DPR/dpr/data/biencoder_data.py b/research/information_retrieval/DPR/dpr/data/biencoder_data.py new file mode 100644 index 00000000000..9193d47a776 --- /dev/null +++ b/research/information_retrieval/DPR/dpr/data/biencoder_data.py @@ -0,0 +1,613 @@ +import collections +import csv +import glob +import logging +import os +import random +from typing import Dict, List, Tuple + +import hydra +import jsonlines +import numpy as np +import torch +from omegaconf import DictConfig +from torch import Tensor as T + +from dpr.data.tables import Table +from dpr.utils.data_utils import read_data_from_json_files, Tensorizer + +logger = logging.getLogger(__name__) +BiEncoderPassage = collections.namedtuple("BiEncoderPassage", ["text", "title"]) + + +class BiEncoderSample(object): + query: str + positive_passages: List[BiEncoderPassage] + negative_passages: List[BiEncoderPassage] + hard_negative_passages: List[BiEncoderPassage] + + +class RepTokenSelector(object): + def get_positions(self, input_ids: T, tenzorizer: Tensorizer): + raise NotImplementedError + + +class RepStaticPosTokenSelector(RepTokenSelector): + def __init__(self, static_position: int = 0): + self.static_position = static_position + + def get_positions(self, input_ids: T, tenzorizer: Tensorizer): + return self.static_position + + +class RepSpecificTokenSelector(RepTokenSelector): + def __init__(self, token: str = "[CLS]"): + self.token = token + self.token_id = None + + def get_positions(self, input_ids: T, tenzorizer: Tensorizer): + if not self.token_id: + self.token_id = tenzorizer.get_token_id(self.token) + token_indexes = (input_ids == self.token_id).nonzero() + # check if all samples in input_ids has index presence and out a default value otherwise + bsz = input_ids.size(0) + if bsz == token_indexes.size(0): + return token_indexes + + token_indexes_result = [] + found_idx_cnt = 0 + for i in range(bsz): + if ( + found_idx_cnt < token_indexes.size(0) + and token_indexes[found_idx_cnt][0] == i + ): + # this samples has the special token + token_indexes_result.append(token_indexes[found_idx_cnt]) + found_idx_cnt += 1 + else: + logger.warning("missing special token %s", input_ids[i]) + + token_indexes_result.append( + torch.tensor([i, 0]).to(input_ids.device) + ) # setting 0-th token, i.e. 
CLS for BERT as the special one + token_indexes_result = torch.stack(token_indexes_result, dim=0) + return token_indexes_result + + +DEFAULT_SELECTOR = RepStaticPosTokenSelector() + + +class Dataset(torch.utils.data.Dataset): + def __init__( + self, + selector: DictConfig = None, + special_token: str = None, + shuffle_positives: bool = False, + query_special_suffix: str = None, + encoder_type: str = None, + ): + if selector: + self.selector = hydra.utils.instantiate(selector) + else: + self.selector = DEFAULT_SELECTOR + self.special_token = special_token + self.encoder_type = encoder_type + self.shuffle_positives = shuffle_positives + self.query_special_suffix = query_special_suffix + + def load_data(self): + raise NotImplementedError + + def __getitem__(self, index) -> BiEncoderSample: + raise NotImplementedError + + def _process_query(self, query: str): + # as of now, always normalize query + query = normalize_question(query) + if self.query_special_suffix and not query.endswith(self.query_special_suffix): + query += self.query_special_suffix + + return query + + +def get_dpr_files(source_name) -> List[str]: + if os.path.exists(source_name) or glob.glob(source_name): + return glob.glob(source_name) + else: + # try to use data downloader + from dpr.data.download_data import download + + return download(source_name) + + +class JsonQADataset(Dataset): + def __init__( + self, + file: str, + selector: DictConfig = None, + special_token: str = None, + encoder_type: str = None, + shuffle_positives: bool = False, + normalize: bool = False, + query_special_suffix: str = None, + ): + super().__init__( + selector, + special_token=special_token, + encoder_type=encoder_type, + shuffle_positives=shuffle_positives, + query_special_suffix=query_special_suffix, + ) + self.file = file + self.data_files = [] + self.data = [] + self.normalize = normalize + logger.info("Data files: %s", self.data_files) + + def load_data(self): + self.data_files = get_dpr_files(self.file) + data = read_data_from_json_files(self.data_files) + # filter those without positive ctx + self.data = [r for r in data if len(r["positive_ctxs"]) > 0] + logger.info("Total cleaned data size: {}".format(len(self.data))) + + def __getitem__(self, index) -> BiEncoderSample: + json_sample = self.data[index] + r = BiEncoderSample() + r.query = self._process_query(json_sample["question"]) + + positive_ctxs = json_sample["positive_ctxs"] + negative_ctxs = ( + json_sample["negative_ctxs"] if "negative_ctxs" in json_sample else [] + ) + hard_negative_ctxs = ( + json_sample["hard_negative_ctxs"] + if "hard_negative_ctxs" in json_sample + else [] + ) + + for ctx in positive_ctxs + negative_ctxs + hard_negative_ctxs: + if "title" not in ctx: + ctx["title"] = None + + def create_passage(ctx: dict): + return BiEncoderPassage( + normalize_passage(ctx["text"]) if self.normalize else ctx["text"], + ctx["title"], + ) + + r.positive_passages = [create_passage(ctx) for ctx in positive_ctxs] + r.negative_passages = [create_passage(ctx) for ctx in negative_ctxs] + r.hard_negative_passages = [create_passage(ctx) for ctx in hard_negative_ctxs] + return r + + def __len__(self): + return len(self.data) + + def get_qas(self) -> Tuple[List[str], List[str]]: + return [s["question"] for s in self.data], [s["answers"] for s in self.data] + + def get_qas_range( + self, start_idx: int, end_idx: int + ) -> Tuple[List[str], List[str]]: + return ( + [s["question"] for s in self.data[start_idx:end_idx]], + [s["answers"] for s in self.data[start_idx:end_idx]], + ) + + +def 
normalize_passage(ctx_text: str): + ctx_text = ctx_text.replace("\n", " ").replace("’", "'") + return ctx_text + + +def normalize_question(question: str) -> str: + question = question.replace("’", "'") + return question + + +class Cell: + def __init__(self): + self.value_tokens: List[str] = [] + self.type: str = "" + self.nested_tables: List[Table] = [] + + def __str__(self): + return " ".join(self.value_tokens) + + def to_dpr_json(self, cell_idx: int): + r = {"col": cell_idx} + r["value"] = str(self) + return r + + +class Row: + def __init__(self): + self.cells: List[Cell] = [] + + def __str__(self): + return "| ".join([str(c) for c in self.cells]) + + def visit(self, tokens_function, row_idx: int): + for i, c in enumerate(self.cells): + if c.value_tokens: + tokens_function(c.value_tokens, row_idx, i) + + def to_dpr_json(self, row_idx: int): + r = {"row": row_idx} + r["columns"] = [c.to_dpr_json(i) for i, c in enumerate(self.cells)] + return r + + +class Table(object): + def __init__(self, caption=""): + self.caption = caption + self.body: List[Row] = [] + self.key = None + self.gold_match = False + + def __str__(self): + table_str = ": {}\n".format(self.caption) + table_str += " rows:\n" + for i, r in enumerate(self.body): + table_str += " row #{}: {}\n".format(i, str(r)) + + return table_str + + def get_key(self) -> str: + if not self.key: + self.key = str(self) + return self.key + + def visit(self, tokens_function, include_caption: bool = False) -> bool: + if include_caption: + tokens_function(self.caption, -1, -1) + for i, r in enumerate(self.body): + r.visit(tokens_function, i) + + def to_dpr_json(self): + r = { + "caption": self.caption, + "rows": [r.to_dpr_json(i) for i, r in enumerate(self.body)], + } + if self.gold_match: + r["gold_match"] = 1 + return r + + +class NQTableParser(object): + def __init__(self, tokens, is_html_mask, title): + self.tokens = tokens + self.is_html_mask = is_html_mask + self.max_idx = len(self.tokens) + self.all_tables = [] + + self.current_table: Table = None + self.tables_stack = collections.deque() + self.title = title + + def parse(self) -> List[Table]: + self.all_tables = [] + self.tables_stack = collections.deque() + + for i in range(self.max_idx): + + t = self.tokens[i] + + if not self.is_html_mask[i]: + # cell content + self._on_content(t) + continue + + if "": + self._on_table_end() + elif "": + self._onRowEnd() + elif "", ""]: + self._on_cell_end() + + return self.all_tables + + def _on_table_start(self): + caption = self.title + parent_table = self.current_table + if parent_table: + self.tables_stack.append(parent_table) + + caption = parent_table.caption + if parent_table.body and parent_table.body[-1].cells: + current_cell = self.current_table.body[-1].cells[-1] + caption += " | " + " ".join(current_cell.value_tokens) + + t = Table() + t.caption = caption + self.current_table = t + self.all_tables.append(t) + + def _on_table_end(self): + t = self.current_table + if t: + if self.tables_stack: # t is a nested table + self.current_table = self.tables_stack.pop() + if self.current_table.body: + current_cell = self.current_table.body[-1].cells[-1] + current_cell.nested_tables.append(t) + else: + logger.error("table end without table object") + + def _onRowStart(self): + self.current_table.body.append(Row()) + + def _onRowEnd(self): + pass + + def _onCellStart(self): + current_row = self.current_table.body[-1] + current_row.cells.append(Cell()) + + def _on_cell_end(self): + pass + + def _on_content(self, token): + if self.current_table.body: + 
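+            # the table already has rows: append the token to the most
+            # recently opened cell of the last row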
current_row = self.current_table.body[-1] + current_cell = current_row.cells[-1] + current_cell.value_tokens.append(token) + else: # tokens outside of row/cells. Just append to the table caption. + self.current_table.caption += " " + token + + +def read_nq_tables_jsonl(path: str) -> Dict[str, Table]: + tables_with_issues = 0 + single_row_tables = 0 + nested_tables = 0 + regular_tables = 0 + total_tables = 0 + total_rows = 0 + tables_dict = {} + + with jsonlines.open(path, mode="r") as jsonl_reader: + for jline in jsonl_reader: + tokens = jline["tokens"] + + if "( hide ) This section has multiple issues" in " ".join(tokens): + tables_with_issues += 1 + continue + + mask = jline["html_mask"] + # page_url = jline["doc_url"] + title = jline["title"] + p = NQTableParser(tokens, mask, title) + tables = p.parse() + + # table = parse_table(tokens, mask) + + nested_tables += len(tables[1:]) + + for t in tables: + total_tables += 1 + + # calc amount of non empty rows + non_empty_rows = sum( + [ + 1 + for r in t.body + if r.cells and any([True for c in r.cells if c.value_tokens]) + ] + ) + + if non_empty_rows <= 1: + single_row_tables += 1 + else: + regular_tables += 1 + total_rows += len(t.body) + + if t.get_key() not in tables_dict: + tables_dict[t.get_key()] = t + + if len(tables_dict) % 1000 == 0: + logger.info("tables_dict %d", len(tables_dict)) + + logger.info("regular tables %d", regular_tables) + logger.info("tables_with_issues %d", tables_with_issues) + logger.info("single_row_tables %d", single_row_tables) + logger.info("nested_tables %d", nested_tables) + return tables_dict + + +def get_table_string_for_answer_check(table: Table): # this doesn't use caption + table_text = "" + for r in table.body: + table_text += " . ".join([" ".join(c.value_tokens) for c in r.cells]) + table_text += " . 
" + return table_text + + +class JsonLTablesQADataset(Dataset): + def __init__( + self, + file: str, + is_train_set: bool, + selector: DictConfig = None, + shuffle_positives: bool = False, + max_negatives: int = 1, + seed: int = 0, + max_len=100, + split_type: str = "type1", + ): + super().__init__(selector, shuffle_positives=shuffle_positives) + self.data_files = glob.glob(file) + self.data = [] + self.is_train_set = is_train_set + self.max_negatives = max_negatives + self.rnd = random.Random(seed) + self.max_len = max_len + self.linearize_func = JsonLTablesQADataset.get_lin_func(split_type) + + def load_data(self): + data = [] + for path in self.data_files: + with jsonlines.open(path, mode="r") as jsonl_reader: + data += [jline for jline in jsonl_reader] + + # filter those without positive ctx + self.data = [r for r in data if len(r["positive_ctxs"]) > 0] + logger.info("Total cleaned data size: {}".format(len(self.data))) + + def __getitem__(self, index) -> BiEncoderSample: + json_sample = self.data[index] + r = BiEncoderSample() + r.query = json_sample["question"] + positive_ctxs = json_sample["positive_ctxs"] + hard_negative_ctxs = json_sample["hard_negative_ctxs"] + + if self.shuffle_positives: + self.rnd.shuffle(positive_ctxs) + + if self.is_train_set: + self.rnd.shuffle(hard_negative_ctxs) + positive_ctxs = positive_ctxs[0:1] + hard_negative_ctxs = hard_negative_ctxs[0 : self.max_negatives] + + r.positive_passages = [ + BiEncoderPassage(self.linearize_func(self, ctx, True), ctx["caption"]) + for ctx in positive_ctxs + ] + r.negative_passages = [] + r.hard_negative_passages = [ + BiEncoderPassage(self.linearize_func(self, ctx, False), ctx["caption"]) + for ctx in hard_negative_ctxs + ] + return r + + def __len__(self): + return len(self.data) + + @classmethod + def get_lin_func(cls, split_type: str): + f = { + "type1": JsonLTablesQADataset._linearize_table, + } + return f[split_type] + + @classmethod + def split_table(cls, t: dict, max_length: int): + rows = t["rows"] + header = None + header_len = 0 + start_row = 0 + + # get the first non empty row as the "header" + for i, r in enumerate(rows): + row_lin, row_len = JsonLTablesQADataset._linearize_row(r) + if len(row_lin) > 1: # TODO: change to checking cell value tokens + header = row_lin + header_len += row_len + start_row = i + break + + chunks = [] + current_rows = [header] + current_len = header_len + + for i in range(start_row + 1, len(rows)): + row_lin, row_len = JsonLTablesQADataset._linearize_row(rows[i]) + if len(row_lin) > 1: # TODO: change to checking cell value tokens + current_rows.append(row_lin) + current_len += row_len + if current_len >= max_length: + # linearize chunk + linearized_str = "\n".join(current_rows) + "\n" + chunks.append(linearized_str) + current_rows = [header] + current_len = header_len + + if len(current_rows) > 1: + linearized_str = "\n".join(current_rows) + "\n" + chunks.append(linearized_str) + return chunks + + def _linearize_table(self, t: dict, is_positive: bool) -> str: + rows = t["rows"] + selected_rows = set() + rows_linearized = [] + total_words_len = 0 + + # get the first non empty row as the "header" + for i, r in enumerate(rows): + row_lin, row_len = JsonLTablesQADataset._linearize_row(r) + if len(row_lin) > 1: # TODO: change to checking cell value tokens + selected_rows.add(i) + rows_linearized.append(row_lin) + total_words_len += row_len + break + + # split to chunks + if is_positive: + row_idx_with_answers = [ap[0] for ap in t["answer_pos"]] + + if self.shuffle_positives: + 
self.rnd.shuffle(row_idx_with_answers) + for i in row_idx_with_answers: + if i not in selected_rows: + row_lin, row_len = JsonLTablesQADataset._linearize_row(rows[i]) + selected_rows.add(i) + rows_linearized.append(row_lin) + total_words_len += row_len + if total_words_len >= self.max_len: + break + + if total_words_len < self.max_len: # append random rows + + if self.is_train_set: + rows_indexes = np.random.permutation(range(len(rows))) + else: + rows_indexes = [*range(len(rows))] + + for i in rows_indexes: + if i not in selected_rows: + row_lin, row_len = JsonLTablesQADataset._linearize_row(rows[i]) + if len(row_lin) > 1: # TODO: change to checking cell value tokens + selected_rows.add(i) + rows_linearized.append(row_lin) + total_words_len += row_len + if total_words_len >= self.max_len: + break + + linearized_str = "" + for r in rows_linearized: + linearized_str += r + "\n" + + return linearized_str + + @classmethod + def _linearize_row(cls, row: dict) -> Tuple[str, int]: + cell_values = [c["value"] for c in row["columns"]] + total_words = sum(len(c.split(" ")) for c in cell_values) + return ", ".join([c["value"] for c in row["columns"]]), total_words + + +def split_tables_to_chunks( + tables_dict: Dict[str, Table], max_table_len: int, split_type: str = "type1" +) -> List[Tuple[int, str, str, int]]: + tables_as_dicts = [t.to_dpr_json() for k, t in tables_dict.items()] + chunks = [] + chunk_id = 0 + for i, t in enumerate(tables_as_dicts): + # TODO: support other types + assert split_type == "type1" + table_chunks = JsonLTablesQADataset.split_table(t, max_table_len) + title = t["caption"] + for c in table_chunks: + # chunk id , text, title, external_id + chunks.append((chunk_id, c, title, i)) + chunk_id += 1 + if i % 1000 == 0: + logger.info("Splitted %d tables to %d chunks", i, len(chunks)) + return chunks diff --git a/research/information_retrieval/DPR/dpr/data/download_data.py b/research/information_retrieval/DPR/dpr/data/download_data.py new file mode 100644 index 00000000000..d362e222efb --- /dev/null +++ b/research/information_retrieval/DPR/dpr/data/download_data.py @@ -0,0 +1,550 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
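+
+# NOTE: identical copy of download_data.py at the DPR integration root;
+# biencoder_data.get_dpr_files() imports download() from this module when a
+# dataset path is not found locally. Example invocation (of the root script):
+#   python download_data.py --resource data.retriever.qas.nq-dev --output_dir ./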
+ +""" + Command line tool to download various preprocessed data sources & checkpoints for DPR +""" + +import argparse +import gzip +import logging +import os +import pathlib +import wget + +from typing import Tuple + +logger = logging.getLogger(__name__) + +# TODO: move to hydra config group + +NQ_LICENSE_FILES = [ + "https://dl.fbaipublicfiles.com/dpr/nq_license/LICENSE", + "https://dl.fbaipublicfiles.com/dpr/nq_license/README", +] + +RESOURCES_MAP = { + "data.wikipedia_split.psgs_w100": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/wikipedia_split/psgs_w100.tsv.gz", + "original_ext": ".tsv", + "compressed": True, + "desc": "Entire wikipedia passages set obtain by splitting all pages into 100-word segments (no overlap)", + }, + "data.retriever.nq-dev": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/data/retriever/biencoder-nq-dev.json.gz", + "original_ext": ".json", + "compressed": True, + "desc": "NQ dev subset with passages pools for the Retriever train time validation", + "license_files": NQ_LICENSE_FILES, + }, + "data.retriever.nq-train": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/data/retriever/biencoder-nq-train.json.gz", + "original_ext": ".json", + "compressed": True, + "desc": "NQ train subset with passages pools for the Retriever training", + "license_files": NQ_LICENSE_FILES, + }, + "data.retriever.nq-adv-hn-train": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/data/retriever/biencoder-nq-adv-hn-train.json.gz", + "original_ext": ".json", + "compressed": True, + "desc": "NQ train subset with hard negative passages mined using the baseline DPR NQ encoders & wikipedia index", + "license_files": NQ_LICENSE_FILES, + }, + "data.retriever.trivia-dev": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/data/retriever/biencoder-trivia-dev.json.gz", + "original_ext": ".json", + "compressed": True, + "desc": "TriviaQA dev subset with passages pools for the Retriever train time validation", + }, + "data.retriever.trivia-train": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/data/retriever/biencoder-trivia-train.json.gz", + "original_ext": ".json", + "compressed": True, + "desc": "TriviaQA train subset with passages pools for the Retriever training", + }, + "data.retriever.squad1-train": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/data/retriever/biencoder-squad1-train.json.gz", + "original_ext": ".json", + "compressed": True, + "desc": "SQUAD 1.1 train subset with passages pools for the Retriever training", + }, + "data.retriever.squad1-dev": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/data/retriever/biencoder-squad1-dev.json.gz", + "original_ext": ".json", + "compressed": True, + "desc": "SQUAD 1.1 dev subset with passages pools for the Retriever train time validation", + }, + "data.retriever.webq-train": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/data/retriever/biencoder-webquestions-train.json.gz", + "original_ext": ".json", + "compressed": True, + "desc": "WebQuestions dev subset with passages pools for the Retriever train time validation", + }, + "data.retriever.webq-dev": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/data/retriever/biencoder-webquestions-dev.json.gz", + "original_ext": ".json", + "compressed": True, + "desc": "WebQuestions dev subset with passages pools for the Retriever train time validation", + }, + "data.retriever.curatedtrec-train": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/data/retriever/biencoder-curatedtrec-train.json.gz", + "original_ext": ".json", + "compressed": True, + "desc": "CuratedTrec dev subset with 
passages pools for the Retriever train time validation", + }, + "data.retriever.curatedtrec-dev": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/data/retriever/biencoder-curatedtrec-dev.json.gz", + "original_ext": ".json", + "compressed": True, + "desc": "CuratedTrec dev subset with passages pools for the Retriever train time validation", + }, + "data.retriever.qas.nq-dev": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/data/retriever/nq-dev.qa.csv", + "original_ext": ".csv", + "compressed": False, + "desc": "NQ dev subset for Retriever validation and IR results generation", + "license_files": NQ_LICENSE_FILES, + }, + "data.retriever.qas.nq-test": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/data/retriever/nq-test.qa.csv", + "original_ext": ".csv", + "compressed": False, + "desc": "NQ test subset for Retriever validation and IR results generation", + "license_files": NQ_LICENSE_FILES, + }, + "data.retriever.qas.nq-train": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/data/retriever/nq-train.qa.csv", + "original_ext": ".csv", + "compressed": False, + "desc": "NQ train subset for Retriever validation and IR results generation", + "license_files": NQ_LICENSE_FILES, + }, + # + "data.retriever.qas.trivia-dev": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/data/retriever/trivia-dev.qa.csv.gz", + "original_ext": ".csv", + "compressed": True, + "desc": "Trivia dev subset for Retriever validation and IR results generation", + }, + "data.retriever.qas.trivia-test": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/data/retriever/trivia-test.qa.csv.gz", + "original_ext": ".csv", + "compressed": True, + "desc": "Trivia test subset for Retriever validation and IR results generation", + }, + "data.retriever.qas.trivia-train": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/data/retriever/trivia-train.qa.csv.gz", + "original_ext": ".csv", + "compressed": True, + "desc": "Trivia train subset for Retriever validation and IR results generation", + }, + "data.retriever.qas.squad1-test": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/data/retriever/squad1-test.qa.csv", + "original_ext": ".csv", + "compressed": False, + "desc": "Trivia test subset for Retriever validation and IR results generation", + }, + "data.retriever.qas.webq-test": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/data/retriever/webquestions-test.qa.csv", + "original_ext": ".csv", + "compressed": False, + "desc": "WebQuestions test subset for Retriever validation and IR results generation", + }, + "data.retriever.qas.curatedtrec-test": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/data/retriever/curatedtrec-test.qa.csv", + "original_ext": ".csv", + "compressed": False, + "desc": "CuratedTrec test subset for Retriever validation and IR results generation", + }, + "data.gold_passages_info.nq_train": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/data/nq_gold_info/nq-train_gold_info.json.gz", + "original_ext": ".json", + "compressed": True, + "desc": "Original NQ (our train subset) gold positive passages and alternative question tokenization", + "license_files": NQ_LICENSE_FILES, + }, + "data.gold_passages_info.nq_dev": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/data/nq_gold_info/nq-dev_gold_info.json.gz", + "original_ext": ".json", + "compressed": True, + "desc": "Original NQ (our dev subset) gold positive passages and alternative question tokenization", + "license_files": NQ_LICENSE_FILES, + }, + "data.gold_passages_info.nq_test": { + "s3_url": 
"https://dl.fbaipublicfiles.com/dpr/data/nq_gold_info/nq-test_gold_info.json.gz", + "original_ext": ".json", + "compressed": True, + "desc": "Original NQ (our test, original dev subset) gold positive passages and alternative question " + "tokenization", + "license_files": NQ_LICENSE_FILES, + }, + "pretrained.fairseq.roberta-base.dict": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/pretrained/fairseq/roberta/dict.txt", + "original_ext": ".txt", + "compressed": False, + "desc": "Dictionary for pretrained fairseq roberta model", + }, + "pretrained.fairseq.roberta-base.model": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/pretrained/fairseq/roberta/model.pt", + "original_ext": ".pt", + "compressed": False, + "desc": "Weights for pretrained fairseq roberta base model", + }, + "pretrained.pytext.bert-base.model": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/pretrained/pytext/bert/bert-base-uncased.pt", + "original_ext": ".pt", + "compressed": False, + "desc": "Weights for pretrained pytext bert base model", + }, + "data.retriever_results.nq.single.wikipedia_passages": { + "s3_url": [ + "https://dl.fbaipublicfiles.com/dpr/data/wiki_encoded/single/nq/wiki_passages_{}".format( + i + ) + for i in range(50) + ], + "original_ext": ".pkl", + "compressed": False, + "desc": "Encoded wikipedia files using a biencoder checkpoint(" + "checkpoint.retriever.single.nq.bert-base-encoder) trained on NQ dataset ", + }, + "data.retriever_results.nq.single-adv-hn.wikipedia_passages": { + "s3_url": [ + "https://dl.fbaipublicfiles.com/dpr/data/wiki_encoded/single-adv-hn/nq/wiki_passages_{}".format( + i + ) + for i in range(50) + ], + "original_ext": ".pkl", + "compressed": False, + "desc": "Encoded wikipedia files using a single-adv-hn checkpoint(" + "checkpoint.retriever.single.nq.bert-base-encoder) trained on NQ dataset ", + }, + "data.retriever_results.nq.single-adv-hn.wikipedia_passages": { + "s3_url": [ + "https://dl.fbaipublicfiles.com/dpr/data/wiki_encoded/single-adv-hn/nq/wiki_passages_{}".format( + i + ) + for i in range(50) + ], + "original_ext": ".pkl", + "compressed": False, + "desc": "Encoded wikipedia files using a biencoder checkpoint(" + "checkpoint.retriever.single-adv-hn.nq.bert-base-encoder) trained on NQ dataset + adversarial hard negatives", + }, + "data.retriever_results.nq.single.test": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/data/retriever_results/single/nq-test.json.gz", + "original_ext": ".json", + "compressed": True, + "desc": "Retrieval results of NQ test dataset for the encoder trained on NQ", + "license_files": NQ_LICENSE_FILES, + }, + "data.retriever_results.nq.single.dev": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/data/retriever_results/single/nq-dev.json.gz", + "original_ext": ".json", + "compressed": True, + "desc": "Retrieval results of NQ dev dataset for the encoder trained on NQ", + "license_files": NQ_LICENSE_FILES, + }, + "data.retriever_results.nq.single.train": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/data/retriever_results/single/nq-train.json.gz", + "original_ext": ".json", + "compressed": True, + "desc": "Retrieval results of NQ train dataset for the encoder trained on NQ", + "license_files": NQ_LICENSE_FILES, + }, + "data.retriever_results.nq.single-adv-hn.test": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/data/retriever_results/single-adv-hn/nq-test.json.gz", + "original_ext": ".json", + "compressed": True, + "desc": "Retrieval results of NQ test dataset for the encoder trained on NQ + adversarial hard negatives", + 
"license_files": NQ_LICENSE_FILES, + }, + "checkpoint.retriever.single.nq.bert-base-encoder": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/checkpoint/retriever/single/nq/hf_bert_base.cp", + "original_ext": ".cp", + "compressed": False, + "desc": "Biencoder weights trained on NQ data and HF bert-base-uncased model", + }, + "checkpoint.retriever.multiset.bert-base-encoder": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/checkpoint/retriver/multiset/hf_bert_base.cp", + "original_ext": ".cp", + "compressed": False, + "desc": "Biencoder weights trained on multi set data and HF bert-base-uncased model", + }, + "checkpoint.retriever.single-adv-hn.nq.bert-base-encoder": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/checkpoint/retriver/single-adv-hn/nq/hf_bert_base.cp", + "original_ext": ".cp", + "compressed": False, + "desc": "Biencoder weights trained on the original DPR NQ data combined with adversarial hard negatives (See data.retriever.nq-adv-hn-train resource). " + "The model is HF bert-base-uncased", + }, + "data.reader.nq.single.train": { + "s3_url": [ + "https://dl.fbaipublicfiles.com/dpr/data/reader/nq/single/train.{}.pkl".format( + i + ) + for i in range(8) + ], + "original_ext": ".pkl", + "compressed": False, + "desc": "Reader model NQ train dataset input data preprocessed from retriever results (also trained on NQ)", + "license_files": NQ_LICENSE_FILES, + }, + "data.reader.nq.single.dev": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/data/reader/nq/single/dev.0.pkl", + "original_ext": ".pkl", + "compressed": False, + "desc": "Reader model NQ dev dataset input data preprocessed from retriever results (also trained on NQ)", + "license_files": NQ_LICENSE_FILES, + }, + "data.reader.nq.single.test": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/data/reader/nq/single/test.0.pkl", + "original_ext": ".pkl", + "compressed": False, + "desc": "Reader model NQ test dataset input data preprocessed from retriever results (also trained on NQ)", + "license_files": NQ_LICENSE_FILES, + }, + "data.reader.trivia.multi-hybrid.train": { + "s3_url": [ + "https://dl.fbaipublicfiles.com/dpr/data/reader/trivia/multi-hybrid/train.{}.pkl".format( + i + ) + for i in range(8) + ], + "original_ext": ".pkl", + "compressed": False, + "desc": "Reader model Trivia train dataset input data preprocessed from hybrid retriever results " + "(where dense part is trained on multiset)", + }, + "data.reader.trivia.multi-hybrid.dev": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/data/reader/trivia/multi-hybrid/dev.0.pkl", + "original_ext": ".pkl", + "compressed": False, + "desc": "Reader model Trivia dev dataset input data preprocessed from hybrid retriever results " + "(where dense part is trained on multiset)", + }, + "data.reader.trivia.multi-hybrid.test": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/data/reader/trivia/multi-hybrid/test.0.pkl", + "original_ext": ".pkl", + "compressed": False, + "desc": "Reader model Trivia test dataset input data preprocessed from hybrid retriever results " + "(where dense part is trained on multiset)", + }, + "checkpoint.reader.nq-single.hf-bert-base": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/checkpoint/reader/nq-single/hf_bert_base.cp", + "original_ext": ".cp", + "compressed": False, + "desc": "Reader weights trained on NQ-single retriever results and HF bert-base-uncased model", + }, + "checkpoint.reader.nq-trivia-hybrid.hf-bert-base": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/checkpoint/reader/nq-trivia-hybrid/hf_bert_base.cp", + 
"original_ext": ".cp", + "compressed": False, + "desc": "Reader weights trained on Trivia multi hybrid retriever results and HF bert-base-uncased model", + }, + # extra checkpoints for EfficientQA competition + "checkpoint.reader.nq-single-subset.hf-bert-base": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/checkpoint/reader/nq-single-seen_only/hf_bert_base.cp", + "original_ext": ".cp", + "compressed": False, + "desc": "Reader weights trained on NQ-single retriever results and HF bert-base-uncased model, when only Wikipedia pages seen during training are considered", + }, + "checkpoint.reader.nq-tfidf.hf-bert-base": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/checkpoint/reader/nq-drqa/hf_bert_base.cp", + "original_ext": ".cp", + "compressed": False, + "desc": "Reader weights trained on TFIDF results and HF bert-base-uncased model", + }, + "checkpoint.reader.nq-tfidf-subset.hf-bert-base": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/checkpoint/reader/nq-drqa-seen_only/hf_bert_base.cp", + "original_ext": ".cp", + "compressed": False, + "desc": "Reader weights trained on TFIDF results and HF bert-base-uncased model, when only Wikipedia pages seen during training are considered", + }, + # retrieval indexes + "indexes.single.nq.full.index": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/checkpoint/indexes/single/nq/full.index.dpr", + "original_ext": ".dpr", + "compressed": False, + "desc": "DPR index on NQ-single retriever", + }, + "indexes.single.nq.full.index_meta": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/checkpoint/indexes/single/nq/full.index_meta.dpr", + "original_ext": ".dpr", + "compressed": False, + "desc": "DPR index on NQ-single retriever (metadata)", + }, + "indexes.single.nq.subset.index": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/checkpoint/indexes/single/nq/seen_only.index.dpr", + "original_ext": ".dpr", + "compressed": False, + "desc": "DPR index on NQ-single retriever when only Wikipedia pages seen during training are considered", + }, + "indexes.single.nq.subset.index_meta": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/checkpoint/indexes/single/nq/seen_only.index_meta.dpr", + "original_ext": ".dpr", + "compressed": False, + "desc": "DPR index on NQ-single retriever when only Wikipedia pages seen during training are considered (metadata)", + }, + "indexes.tfidf.nq.full": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/checkpoint/indexes/drqa/nq/full-tfidf.npz", + "original_ext": ".npz", + "compressed": False, + "desc": "TFIDF index", + }, + "indexes.tfidf.nq.subset": { + "s3_url": "https://dl.fbaipublicfiles.com/dpr/checkpoint/indexes/drqa/nq/seen_only-tfidf.npz", + "original_ext": ".npz", + "compressed": False, + "desc": "TFIDF index when only Wikipedia pages seen during training are considered", + }, +} + + +def unpack(gzip_file: str, out_file: str): + logger.info("Uncompressing %s", gzip_file) + input = gzip.GzipFile(gzip_file, "rb") + s = input.read() + input.close() + output = open(out_file, "wb") + output.write(s) + output.close() + logger.info(" Saved to %s", out_file) + + +def download_resource( + s3_url: str, original_ext: str, compressed: bool, resource_key: str, out_dir: str +) -> Tuple[str, str]: + logger.info("Requested resource from %s", s3_url) + path_names = resource_key.split(".") + + if out_dir: + root_dir = out_dir + else: + # since hydra overrides the location for the 'current dir' for every run and we don't want to duplicate + # resources multiple times, remove the current folder's volatile part + root_dir = 
os.path.abspath("./") + if "/outputs/" in root_dir: + root_dir = root_dir[: root_dir.index("/outputs/")] + + logger.info("Download root_dir %s", root_dir) + + save_root = os.path.join( + root_dir, "downloads", *path_names[:-1] + ) # last segment is for file name + + pathlib.Path(save_root).mkdir(parents=True, exist_ok=True) + + local_file_uncompressed = os.path.abspath( + os.path.join(save_root, path_names[-1] + original_ext) + ) + logger.info("File to be downloaded as %s", local_file_uncompressed) + + if os.path.exists(local_file_uncompressed): + logger.info("File already exist %s", local_file_uncompressed) + return save_root, local_file_uncompressed + + local_file = os.path.abspath( + os.path.join( + save_root, path_names[-1] + (".tmp" if compressed else original_ext) + ) + ) + + wget.download(s3_url, out=local_file) + + logger.info("Downloaded to %s", local_file) + + if compressed: + uncompressed_file = os.path.join(save_root, path_names[-1] + original_ext) + unpack(local_file, uncompressed_file) + os.remove(local_file) + local_file = uncompressed_file + return save_root, local_file + + +def download_file(s3_url: str, out_dir: str, file_name: str): + logger.info("Loading from %s", s3_url) + local_file = os.path.join(out_dir, file_name) + + if os.path.exists(local_file): + logger.info("File already exist %s", local_file) + return + + wget.download(s3_url, out=local_file) + logger.info("Downloaded to %s", local_file) + + +def download(resource_key: str, out_dir: str = None): + if resource_key not in RESOURCES_MAP: + # match by prefix + resources = [k for k in RESOURCES_MAP.keys() if k.startswith(resource_key)] + if resources: + for key in resources: + download(key, out_dir) + else: + logger.info("no resources found for specified key") + return [] + download_info = RESOURCES_MAP[resource_key] + + s3_url = download_info["s3_url"] + + save_root_dir = None + data_files = [] + if isinstance(s3_url, list): + for i, url in enumerate(s3_url): + save_root_dir, local_file = download_resource( + url, + download_info["original_ext"], + download_info["compressed"], + "{}_{}".format(resource_key, i), + out_dir, + ) + data_files.append(local_file) + else: + save_root_dir, local_file = download_resource( + s3_url, + download_info["original_ext"], + download_info["compressed"], + resource_key, + out_dir, + ) + data_files.append(local_file) + + license_files = download_info.get("license_files", None) + if license_files: + download_file(license_files[0], save_root_dir, "LICENSE") + download_file(license_files[1], save_root_dir, "README") + return data_files + + +def main(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--output_dir", + default="./", + type=str, + help="The output directory to download file", + ) + parser.add_argument( + "--resource", + type=str, + help="Resource name. See RESOURCES_MAP for all possible values", + ) + args = parser.parse_args() + if args.resource: + download(args.resource, args.output_dir) + else: + print("Please specify resource value. Possible options are:") + for k, v in RESOURCES_MAP.items(): + print("Resource key=%s : %s", k, v["desc"]) + + +if __name__ == "__main__": + main() diff --git a/research/information_retrieval/DPR/dpr/data/qa_validation.py b/research/information_retrieval/DPR/dpr/data/qa_validation.py new file mode 100644 index 00000000000..d05bae6ae57 --- /dev/null +++ b/research/information_retrieval/DPR/dpr/data/qa_validation.py @@ -0,0 +1,216 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. 
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+ Set of utilities for Q&A results validation tasks - Retriever passage validation and Reader predicted answer validation
+"""
+
+import collections
+import logging
+import string
+import unicodedata
+from multiprocessing import Pool as ProcessPool
+
+import regex as re
+from functools import partial
+from typing import Tuple, List, Dict
+
+from dpr.data.retriever_data import TableChunk
+from dpr.utils.tokenizers import SimpleTokenizer
+
+logger = logging.getLogger(__name__)
+
+QAMatchStats = collections.namedtuple(
+    "QAMatchStats", ["top_k_hits", "questions_doc_hits"]
+)
+
+QATableMatchStats = collections.namedtuple(
+    "QATableMatchStats", ["top_k_chunk_hits", "top_k_table_hits", "questions_doc_hits"]
+)
+
+
+def calculate_matches(
+    all_docs: Dict[object, Tuple[str, str]],
+    answers: List[List[str]],
+    closest_docs: List[Tuple[List[object], List[float]]],
+    workers_num: int,
+    match_type: str,
+) -> QAMatchStats:
+    """
+    Evaluates the presence of answers in the set of documents. This function is supposed to be used with a large
+    collection of documents and results. It internally forks multiple sub-processes for evaluation and then merges
+    the results.
+    :param all_docs: dictionary of the entire documents database. doc_id -> (doc_text, title)
+    :param answers: list of answer lists. One list per question
+    :param closest_docs: document ids of the top results along with their scores
+    :param workers_num: number of parallel worker processes used for matching
+    :param match_type: type of answer matching. Refer to has_answer code for available options
+    :return: matching information tuple.
+    top_k_hits - a list where the index is the number of top documents retrieved and the value is the total number of
+    valid matches across the entire dataset.
+ questions_doc_hits - more detailed info with answer matches for every question and every retrieved document + """ + global dpr_all_documents + dpr_all_documents = all_docs + logger.info("dpr_all_documents size %d", len(dpr_all_documents)) + + tok_opts = {} + tokenizer = SimpleTokenizer(**tok_opts) + + processes = ProcessPool(processes=workers_num) + logger.info("Matching answers in top docs...") + get_score_partial = partial( + check_answer, match_type=match_type, tokenizer=tokenizer + ) + + questions_answers_docs = zip(answers, closest_docs) + scores = processes.map(get_score_partial, questions_answers_docs) + + logger.info("Per question validation results len=%d", len(scores)) + + n_docs = len(closest_docs[0][0]) + top_k_hits = [0] * n_docs + for question_hits in scores: + best_hit = next((i for i, x in enumerate(question_hits) if x), None) + if best_hit is not None: + top_k_hits[best_hit:] = [v + 1 for v in top_k_hits[best_hit:]] + + return QAMatchStats(top_k_hits, scores) + + +def check_answer(questions_answers_docs, tokenizer, match_type) -> List[bool]: + """Search through all the top docs to see if they have any of the answers.""" + answers, (doc_ids, doc_scores) = questions_answers_docs + + global dpr_all_documents + hits = [] + + for i, doc_id in enumerate(doc_ids): + doc = dpr_all_documents[doc_id] + text = doc[0] + + answer_found = False + if text is None: # cannot find the document for some reason + logger.warning("no doc in db") + hits.append(False) + continue + + if has_answer(answers, text, tokenizer, match_type): + answer_found = True + hits.append(answer_found) + return hits + + +def has_answer(answers, text, tokenizer, match_type) -> bool: + """Check if a document contains an answer string. + If `match_type` is string, token matching is done between the text and answer. + If `match_type` is regex, we search the whole text with the regex. 
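+
+    Illustrative example (hypothetical values, added for clarity; not exercised by the code):
+        has_answer(["Barack Obama"], "Barack Obama was the 44th president.", SimpleTokenizer(), "string")
+        -> True, because the normalized answer tokens appear as a contiguous sub-sequence of the text tokens.
+        has_answer([r"4[0-9]th president"], "Barack Obama was the 44th president.", None, "regex")
+        -> True, because the pattern is found somewhere in the normalized text (the tokenizer is not used here).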
+ """ + text = _normalize(text) + + if match_type == "string": + # Answer is a list of possible strings + text = tokenizer.tokenize(text).words(uncased=True) + + for single_answer in answers: + single_answer = _normalize(single_answer) + single_answer = tokenizer.tokenize(single_answer) + single_answer = single_answer.words(uncased=True) + + for i in range(0, len(text) - len(single_answer) + 1): + if single_answer == text[i : i + len(single_answer)]: + return True + + elif match_type == "regex": + # Answer is a regex + for single_answer in answers: + single_answer = _normalize(single_answer) + if regex_match(text, single_answer): + return True + return False + + +def regex_match(text, pattern): + """Test if a regex pattern is contained within a text.""" + try: + pattern = re.compile(pattern, flags=re.IGNORECASE + re.UNICODE + re.MULTILINE) + except BaseException: + return False + return pattern.search(text) is not None + + +# function for the reader model answer validation +def exact_match_score(prediction, ground_truth): + return _normalize_answer(prediction) == _normalize_answer(ground_truth) + + +def _normalize_answer(s): + def remove_articles(text): + return re.sub(r"\b(a|an|the)\b", " ", text) + + def white_space_fix(text): + return " ".join(text.split()) + + def remove_punc(text): + exclude = set(string.punctuation) + return "".join(ch for ch in text if ch not in exclude) + + def lower(text): + return text.lower() + + return white_space_fix(remove_articles(remove_punc(lower(s)))) + + +def _normalize(text): + return unicodedata.normalize("NFD", text) + + +def calculate_chunked_matches( + all_docs: Dict[object, TableChunk], + answers: List[List[str]], + closest_docs: List[Tuple[List[object], List[float]]], + workers_num: int, + match_type: str, +) -> QATableMatchStats: + global dpr_all_documents + dpr_all_documents = all_docs + + global dpr_all_tables + dpr_all_tables = {} + + for key, table_chunk in all_docs.items(): + table_str, title, table_id = table_chunk + table_chunks = dpr_all_tables.get(table_id, []) + table_chunks.append((table_str, title)) + dpr_all_tables[table_id] = table_chunks + + tok_opts = {} + tokenizer = SimpleTokenizer(**tok_opts) + + processes = ProcessPool(processes=workers_num) + + logger.info("Matching answers in top docs...") + get_score_partial = partial( + check_chunked_docs_answer, match_type=match_type, tokenizer=tokenizer + ) + questions_answers_docs = zip(answers, closest_docs) + scores = processes.map(get_score_partial, questions_answers_docs) + logger.info("Per question validation results len=%d", len(scores)) + + n_docs = len(closest_docs[0][0]) + top_k_hits = [0] * n_docs + top_k_orig_hits = [0] * n_docs + for s in scores: + question_hits, question_orig_doc_hits = s + best_hit = next((i for i, x in enumerate(question_hits) if x), None) + if best_hit is not None: + top_k_hits[best_hit:] = [v + 1 for v in top_k_hits[best_hit:]] + + best_hit = next((i for i, x in enumerate(question_orig_doc_hits) if x), None) + if best_hit is not None: + top_k_orig_hits[best_hit:] = [v + 1 for v in top_k_orig_hits[best_hit:]] + + return QATableMatchStats(top_k_hits, top_k_orig_hits, scores) diff --git a/research/information_retrieval/DPR/dpr/data/reader_data.py b/research/information_retrieval/DPR/dpr/data/reader_data.py new file mode 100644 index 00000000000..5395b521445 --- /dev/null +++ b/research/information_retrieval/DPR/dpr/data/reader_data.py @@ -0,0 +1,646 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. 
+# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" + Set of utilities for the Reader model related data processing tasks +""" + +import collections +import glob +import json +import logging +import math +import multiprocessing +import os +import pickle +import torch + +from functools import partial +from typing import Tuple, List, Dict, Iterable, Optional +from torch import Tensor as T +from tqdm import tqdm + +from dpr.utils.data_utils import Tensorizer, read_serialized_data_from_files + +logger = logging.getLogger() + + +class ReaderPassage(object): + """ + Container to collect and cache all Q&A passages related attributes before generating the reader input + """ + + def __init__( + self, + id=None, + text: str = None, + title: str = None, + score=None, + has_answer: bool = None, + ): + self.id = id + # string passage representations + self.passage_text = text + self.title = title + self.score = score + self.has_answer = has_answer + self.passage_token_ids = None + # offset of the actual passage (i.e. not a question or may be title) in the sequence_ids + self.passage_offset = None + self.answers_spans = None + # passage token ids + self.sequence_ids = None + + def on_serialize(self): + # store only final sequence_ids and the ctx offset + self.sequence_ids = self.sequence_ids.numpy() + self.passage_text = None + self.title = None + self.passage_token_ids = None + + def on_deserialize(self): + self.sequence_ids = torch.tensor(self.sequence_ids) + + +class ReaderSample(object): + """ + Container to collect all Q&A passages data per singe question + """ + + def __init__( + self, + question: str, + answers: List, + positive_passages: List[ReaderPassage] = [], + negative_passages: List[ReaderPassage] = [], + passages: List[ReaderPassage] = [], + ): + self.question = question + self.answers = answers + self.positive_passages = positive_passages + self.negative_passages = negative_passages + self.passages = passages + + def on_serialize(self): + for passage in self.passages + self.positive_passages + self.negative_passages: + passage.on_serialize() + + def on_deserialize(self): + for passage in self.passages + self.positive_passages + self.negative_passages: + passage.on_deserialize() + + +class ExtractiveReaderDataset(torch.utils.data.Dataset): + def __init__( + self, + files: str, + is_train: bool, + gold_passages_src: str, + tensorizer: Tensorizer, + run_preprocessing: bool, + num_workers: int, + ): + self.files = files + self.data = [] + self.is_train = is_train + self.gold_passages_src = gold_passages_src + self.tensorizer = tensorizer + self.run_preprocessing = run_preprocessing + self.num_workers = num_workers + + def __getitem__(self, index): + return self.data[index] + + def __len__(self): + return len(self.data) + + def load_data( + self, + ): + data_files = glob.glob(self.files) + logger.info("Data files: %s", data_files) + if not data_files: + raise RuntimeError("No Data files found") + preprocessed_data_files = self._get_preprocessed_files(data_files) + self.data = read_serialized_data_from_files(preprocessed_data_files) + + def _get_preprocessed_files( + self, + data_files: List, + ): + + serialized_files = [file for file in data_files if file.endswith(".pkl")] + if serialized_files: + return serialized_files + assert len(data_files) == 1, "Only 1 source file pre-processing is supported." 
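+
+        # Note (summary of the flow below): the single remaining file is expected to be the
+        # retriever's .json results; preprocessing shards it into out_file_prefix.{chunk}.pkl
+        # files which later runs can reuse instead of re-doing the conversion.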
+ + # data may have been serialized and cached before, try to find ones from same dir + def _find_cached_files(path: str): + dir_path, base_name = os.path.split(path) + base_name = base_name.replace(".json", "") + out_file_prefix = os.path.join(dir_path, base_name) + out_file_pattern = out_file_prefix + "*.pkl" + return glob.glob(out_file_pattern), out_file_prefix + + serialized_files, out_file_prefix = _find_cached_files(data_files[0]) + if serialized_files: + logger.info("Found preprocessed files. %s", serialized_files) + return serialized_files + + logger.info( + "Data are not preprocessed for reader training. Start pre-processing ..." + ) + + # start pre-processing and save results + def _run_preprocessing(tensorizer: Tensorizer): + # temporarily disable auto-padding to save disk space usage of serialized files + tensorizer.set_pad_to_max(False) + serialized_files = convert_retriever_results( + self.is_train, + data_files[0], + out_file_prefix, + self.gold_passages_src, + self.tensorizer, + num_workers=self.num_workers, + ) + tensorizer.set_pad_to_max(True) + return serialized_files + + if self.run_preprocessing: + serialized_files = _run_preprocessing(self.tensorizer) + # TODO: check if pytorch process group is initialized + # torch.distributed.barrier() + else: + # torch.distributed.barrier() + serialized_files = _find_cached_files(data_files[0]) + return serialized_files + + +SpanPrediction = collections.namedtuple( + "SpanPrediction", + [ + "prediction_text", + "span_score", + "relevance_score", + "passage_index", + "passage_token_ids", + ], +) + +# configuration for reader model passage selection +ReaderPreprocessingCfg = collections.namedtuple( + "ReaderPreprocessingCfg", + [ + "use_tailing_sep", + "skip_no_positves", + "include_gold_passage", + "gold_page_only_positives", + "max_positives", + "max_negatives", + "min_negatives", + "max_retriever_passages", + ], +) + +DEFAULT_PREPROCESSING_CFG_TRAIN = ReaderPreprocessingCfg( + use_tailing_sep=False, + skip_no_positves=True, + include_gold_passage=False, + gold_page_only_positives=True, + max_positives=20, + max_negatives=50, + min_negatives=150, + max_retriever_passages=200, +) + +DEFAULT_EVAL_PASSAGES = 100 + + +def preprocess_retriever_data( + samples: List[Dict], + gold_info_file: Optional[str], + tensorizer: Tensorizer, + cfg: ReaderPreprocessingCfg = DEFAULT_PREPROCESSING_CFG_TRAIN, + is_train_set: bool = True, +) -> Iterable[ReaderSample]: + """ + Converts retriever results into reader training data. + :param samples: samples from the retriever's json file results + :param gold_info_file: optional path for the 'gold passages & questions' file. 
Required to get best results for NQ + :param tensorizer: Tensorizer object for text to model input tensors conversions + :param cfg: ReaderPreprocessingCfg object with positive and negative passage selection parameters + :param is_train_set: if the data should be processed as a train set + :return: iterable of ReaderSample objects which can be consumed by the reader model + """ + sep_tensor = tensorizer.get_pair_separator_ids() # separator can be a multi token + + gold_passage_map, canonical_questions = ( + _get_gold_ctx_dict(gold_info_file) if gold_info_file else ({}, {}) + ) + + no_positive_passages = 0 + positives_from_gold = 0 + + def create_reader_sample_ids(sample: ReaderPassage, question: str): + question_and_title = tensorizer.text_to_tensor( + sample.title, title=question, add_special_tokens=True + ) + if sample.passage_token_ids is None: + sample.passage_token_ids = tensorizer.text_to_tensor( + sample.passage_text, add_special_tokens=False + ) + + all_concatenated, shift = _concat_pair( + question_and_title, + sample.passage_token_ids, + tailing_sep=sep_tensor if cfg.use_tailing_sep else None, + ) + + sample.sequence_ids = all_concatenated + sample.passage_offset = shift + assert shift > 1 + if sample.has_answer and is_train_set: + sample.answers_spans = [ + (span[0] + shift, span[1] + shift) for span in sample.answers_spans + ] + return sample + + for sample in samples: + question = sample["question"] + + if question in canonical_questions: + question = canonical_questions[question] + + positive_passages, negative_passages = _select_reader_passages( + sample, + question, + tensorizer, + gold_passage_map, + cfg.gold_page_only_positives, + cfg.max_positives, + cfg.max_negatives, + cfg.min_negatives, + cfg.max_retriever_passages, + cfg.include_gold_passage, + is_train_set, + ) + # create concatenated sequence ids for each passage and adjust answer spans + positive_passages = [ + create_reader_sample_ids(s, question) for s in positive_passages + ] + negative_passages = [ + create_reader_sample_ids(s, question) for s in negative_passages + ] + + if is_train_set and len(positive_passages) == 0: + no_positive_passages += 1 + if cfg.skip_no_positves: + continue + + if next(iter(ctx for ctx in positive_passages if ctx.score == -1), None): + positives_from_gold += 1 + + if is_train_set: + yield ReaderSample( + question, + sample["answers"], + positive_passages=positive_passages, + negative_passages=negative_passages, + ) + else: + yield ReaderSample(question, sample["answers"], passages=negative_passages) + + logger.info("no positive passages samples: %d", no_positive_passages) + logger.info("positive passages from gold samples: %d", positives_from_gold) + + +def convert_retriever_results( + is_train_set: bool, + input_file: str, + out_file_prefix: str, + gold_passages_file: str, + tensorizer: Tensorizer, + num_workers: int = 8, +) -> List[str]: + """ + Converts the file with dense retriever(or any compatible file format) results into the reader input data and + serializes them into a set of files. + Conversion splits the input data into multiple chunks and processes them in parallel. Each chunk results are stored + in a separate file with name out_file_prefix.{number}.pkl + :param is_train_set: if the data should be processed for a train set (i.e. with answer span detection) + :param input_file: path to a json file with data to convert + :param out_file_prefix: output path prefix. + :param gold_passages_file: optional path for the 'gold passages & questions' file. 
Required to get best results for NQ + :param tensorizer: Tensorizer object for text to model input tensors conversions + :param num_workers: the number of parallel processes for conversion + :return: names of files with serialized results + """ + with open(input_file, "r", encoding="utf-8") as f: + samples = json.loads("".join(f.readlines())) + logger.info( + "Loaded %d questions + retrieval results from %s", len(samples), input_file + ) + workers = multiprocessing.Pool(num_workers) + ds_size = len(samples) + step = max(math.ceil(ds_size / num_workers), 1) + chunks = [samples[i : i + step] for i in range(0, ds_size, step)] + chunks = [(i, chunks[i]) for i in range(len(chunks))] + + logger.info("Split data into %d chunks", len(chunks)) + + processed = 0 + _parse_batch = partial( + _preprocess_reader_samples_chunk, + out_file_prefix=out_file_prefix, + gold_passages_file=gold_passages_file, + tensorizer=tensorizer, + is_train_set=is_train_set, + ) + serialized_files = [] + for file_name in workers.map(_parse_batch, chunks): + processed += 1 + serialized_files.append(file_name) + logger.info("Chunks processed %d", processed) + logger.info("Data saved to %s", file_name) + logger.info("Preprocessed data stored in %s", serialized_files) + return serialized_files + + +def get_best_spans( + tensorizer: Tensorizer, + start_logits: List, + end_logits: List, + ctx_ids: List, + max_answer_length: int, + passage_idx: int, + relevance_score: float, + top_spans: int = 1, +) -> List[SpanPrediction]: + """ + Finds the best answer span for the extractive Q&A model + """ + scores = [] + for (i, s) in enumerate(start_logits): + for (j, e) in enumerate(end_logits[i : i + max_answer_length]): + scores.append(((i, i + j), s + e)) + + scores = sorted(scores, key=lambda x: x[1], reverse=True) + + chosen_span_intervals = [] + best_spans = [] + + for (start_index, end_index), score in scores: + assert start_index <= end_index + length = end_index - start_index + 1 + assert length <= max_answer_length + + if any( + [ + start_index <= prev_start_index <= prev_end_index <= end_index + or prev_start_index <= start_index <= end_index <= prev_end_index + for (prev_start_index, prev_end_index) in chosen_span_intervals + ] + ): + continue + + # extend bpe subtokens to full tokens + start_index, end_index = _extend_span_to_full_words( + tensorizer, ctx_ids, (start_index, end_index) + ) + + predicted_answer = tensorizer.to_string(ctx_ids[start_index : end_index + 1]) + best_spans.append( + SpanPrediction( + predicted_answer, score, relevance_score, passage_idx, ctx_ids + ) + ) + chosen_span_intervals.append((start_index, end_index)) + + if len(chosen_span_intervals) == top_spans: + break + return best_spans + + +def _select_reader_passages( + sample: Dict, + question: str, + tensorizer: Tensorizer, + gold_passage_map: Dict[str, ReaderPassage], + gold_page_only_positives: bool, + max_positives: int, + max1_negatives: int, + max2_negatives: int, + max_retriever_passages: int, + include_gold_passage: bool, + is_train_set: bool, +) -> Tuple[List[ReaderPassage], List[ReaderPassage]]: + answers = sample["answers"] + + ctxs = [ReaderPassage(**ctx) for ctx in sample["ctxs"]][0:max_retriever_passages] + answers_token_ids = [ + tensorizer.text_to_tensor(a, add_special_tokens=False) for a in answers + ] + + if is_train_set: + positive_samples = list(filter(lambda ctx: ctx.has_answer, ctxs)) + negative_samples = list(filter(lambda ctx: not ctx.has_answer, ctxs)) + else: + positive_samples = [] + negative_samples = ctxs + + 
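+    # Positive selection (summarized from the logic below): when gold_page_only_positives is
+    # set, prefer positives whose title matches the gold Wikipedia page for this question;
+    # if none of those contain an answer span, fall back to any retrieved passage with an
+    # answer span, capped at max_positives.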
positive_ctxs_from_gold_page = ( + list( + filter( + lambda ctx: _is_from_gold_wiki_page( + gold_passage_map, ctx.title, question + ), + positive_samples, + ) + ) + if gold_page_only_positives + else [] + ) + + def find_answer_spans(ctx: ReaderPassage): + if ctx.has_answer: + if ctx.passage_token_ids is None: + ctx.passage_token_ids = tensorizer.text_to_tensor( + ctx.passage_text, add_special_tokens=False + ) + + answer_spans = [ + _find_answer_positions(ctx.passage_token_ids, answers_token_ids[i]) + for i in range(len(answers)) + ] + + # flatten spans list + answer_spans = [item for sublist in answer_spans for item in sublist] + answers_spans = list(filter(None, answer_spans)) + ctx.answers_spans = answers_spans + + if not answers_spans: + logger.warning( + "No answer found in passage id=%s text=%s, answers=%s, question=%s", + ctx.id, + ctx.passage_text, + answers, + question, + ) + + ctx.has_answer = bool(answers_spans) + + return ctx + + # check if any of the selected ctx+ has answer spans + selected_positive_ctxs = list( + filter( + lambda ctx: ctx.has_answer, + [find_answer_spans(ctx) for ctx in positive_ctxs_from_gold_page], + ) + ) + + if not selected_positive_ctxs: # fallback to positive ctx not from gold pages + selected_positive_ctxs = list( + filter( + lambda ctx: ctx.has_answer, + [find_answer_spans(ctx) for ctx in positive_samples], + ) + )[0:max_positives] + + # optionally include gold passage itself if it is still not in the positives list + if include_gold_passage and question in gold_passage_map: + gold_passage = gold_passage_map[question] + included_gold_passage = next( + iter(ctx for ctx in selected_positive_ctxs if ctx.id == gold_passage.id), + None, + ) + if not included_gold_passage: + gold_passage = find_answer_spans(gold_passage) + if not gold_passage.has_answer: + logger.warning("No answer found in gold passage %s", gold_passage) + else: + selected_positive_ctxs.append(gold_passage) + + max_negatives = ( + min(max(10 * len(selected_positive_ctxs), max1_negatives), max2_negatives) + if is_train_set + else DEFAULT_EVAL_PASSAGES + ) + negative_samples = negative_samples[0:max_negatives] + return selected_positive_ctxs, negative_samples + + +def _find_answer_positions(ctx_ids: T, answer: T) -> List[Tuple[int, int]]: + c_len = ctx_ids.size(0) + a_len = answer.size(0) + answer_occurences = [] + for i in range(0, c_len - a_len + 1): + if (answer == ctx_ids[i : i + a_len]).all(): + answer_occurences.append((i, i + a_len - 1)) + return answer_occurences + + +def _concat_pair(t1: T, t2: T, middle_sep: T = None, tailing_sep: T = None): + middle = [middle_sep] if middle_sep else [] + r = [t1] + middle + [t2] + ([tailing_sep] if tailing_sep else []) + return torch.cat(r, dim=0), t1.size(0) + len(middle) + + +def _get_gold_ctx_dict(file: str) -> Tuple[Dict[str, ReaderPassage], Dict[str, str]]: + gold_passage_infos = ( + {} + ) # question|question_tokens -> ReaderPassage (with title and gold ctx) + + # original NQ dataset has 2 forms of same question - original, and tokenized. + # Tokenized form is not fully consisted with the original question if tokenized by some encoder tokenizers + # Specifically, this is the case for the BERT tokenizer. + # Depending of which form was used for retriever training and results generation, it may be useful to convert + # all questions to the canonical original representation. 
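+    # For example, an original NQ question "who sings does he love me with reba?" may show up
+    # in tokenized form as "who sings does he love me with reba ?" (detached punctuation);
+    # both keys are mapped to the same gold passage info below.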
+ original_questions = {} # question from tokens -> original question (NQ only) + + with open(file, "r", encoding="utf-8") as f: + logger.info("Reading file %s" % file) + data = json.load(f)["data"] + + for sample in data: + question = sample["question"] + question_from_tokens = ( + sample["question_tokens"] if "question_tokens" in sample else question + ) + original_questions[question_from_tokens] = question + title = sample["title"].lower() + context = sample["context"] # Note: This one is cased + rp = ReaderPassage(sample["example_id"], text=context, title=title) + if question in gold_passage_infos: + logger.info("Duplicate question %s", question) + rp_exist = gold_passage_infos[question] + logger.info( + "Duplicate question gold info: title new =%s | old title=%s", + title, + rp_exist.title, + ) + logger.info("Duplicate question gold info: new ctx =%s ", context) + logger.info( + "Duplicate question gold info: old ctx =%s ", rp_exist.passage_text + ) + + gold_passage_infos[question] = rp + gold_passage_infos[question_from_tokens] = rp + return gold_passage_infos, original_questions + + +def _is_from_gold_wiki_page( + gold_passage_map: Dict[str, ReaderPassage], passage_title: str, question: str +): + gold_info = gold_passage_map.get(question, None) + if gold_info: + return passage_title.lower() == gold_info.title.lower() + return False + + +def _extend_span_to_full_words( + tensorizer: Tensorizer, tokens: List[int], span: Tuple[int, int] +) -> Tuple[int, int]: + start_index, end_index = span + max_len = len(tokens) + while start_index > 0 and tensorizer.is_sub_word_id(tokens[start_index]): + start_index -= 1 + + while end_index < max_len - 1 and tensorizer.is_sub_word_id(tokens[end_index + 1]): + end_index += 1 + + return start_index, end_index + + +def _preprocess_reader_samples_chunk( + samples: List, + out_file_prefix: str, + gold_passages_file: str, + tensorizer: Tensorizer, + is_train_set: bool, +) -> str: + chunk_id, samples = samples + logger.info("Start batch %d", len(samples)) + iterator = preprocess_retriever_data( + samples, + gold_passages_file, + tensorizer, + is_train_set=is_train_set, + ) + + results = [] + + iterator = tqdm(iterator) + for i, r in enumerate(iterator): + r.on_serialize() + results.append(r) + + out_file = out_file_prefix + "." 
+ str(chunk_id) + ".pkl" + with open(out_file, mode="wb") as f: + logger.info("Serialize %d results to %s", len(results), out_file) + pickle.dump(results, f) + return out_file diff --git a/research/information_retrieval/DPR/dpr/data/retriever_data.py b/research/information_retrieval/DPR/dpr/data/retriever_data.py new file mode 100644 index 00000000000..628e6c6ce75 --- /dev/null +++ b/research/information_retrieval/DPR/dpr/data/retriever_data.py @@ -0,0 +1,337 @@ +import collections +import csv +import json +import logging +import pickle +from typing import Dict + +import hydra +import jsonlines +import torch +from omegaconf import DictConfig + +from dpr.data.biencoder_data import ( + BiEncoderPassage, + normalize_passage, + normalize_question, + get_dpr_files, + read_nq_tables_jsonl, + split_tables_to_chunks, +) + +logger = logging.getLogger(__name__) +QASample = collections.namedtuple("QuerySample", ["query", "id", "answers"]) +TableChunk = collections.namedtuple("TableChunk", ["text", "title", "table_id"]) + + +class RetrieverData(torch.utils.data.Dataset): + def __init__(self, file: str): + """ + :param file: - real file name or the resource name as they are defined in download_data.py + """ + self.file = file + self.data_files = [] + + def load_data(self): + self.data_files = get_dpr_files(self.file) + assert ( + len(self.data_files) == 1 + ), "RetrieverData source currently works with single files only. Files specified: {}".format( + self.data_files + ) + self.file = self.data_files[0] + + +class QASrc(RetrieverData): + def __init__( + self, + file: str, + selector: DictConfig = None, + special_query_token: str = None, + query_special_suffix: str = None, + ): + super().__init__(file) + self.data = None + self.selector = hydra.utils.instantiate(selector) if selector else None + self.special_query_token = special_query_token + self.query_special_suffix = query_special_suffix + + def __getitem__(self, index) -> QASample: + return self.data[index] + + def __len__(self): + return len(self.data) + + def _process_question(self, question: str): + # as of now, always normalize query + question = normalize_question(question) + if self.query_special_suffix and not question.endswith( + self.query_special_suffix + ): + question += self.query_special_suffix + return question + + +class CsvQASrc(QASrc): + def __init__( + self, + file: str, + question_col: int = 0, + answers_col: int = 1, + id_col: int = -1, + selector: DictConfig = None, + special_query_token: str = None, + query_special_suffix: str = None, + ): + super().__init__(file, selector, special_query_token, query_special_suffix) + self.question_col = question_col + self.answers_col = answers_col + self.id_col = id_col + + def load_data(self): + super().load_data() + data = [] + with open(self.file) as ifile: + reader = csv.reader(ifile, delimiter="\t") + for row in reader: + question = row[self.question_col] + answers = eval(row[self.answers_col]) + id = None + if self.id_col >= 0: + id = row[self.id_col] + data.append(QASample(self._process_question(question), id, answers)) + self.data = data + + +class JsonlQASrc(QASrc): + def __init__( + self, + file: str, + selector: DictConfig = None, + question_attr: str = "question", + answers_attr: str = "answers", + id_attr: str = "id", + special_query_token: str = None, + query_special_suffix: str = None, + ): + super().__init__(file, selector, special_query_token, query_special_suffix) + self.question_attr = question_attr + self.answers_attr = answers_attr + self.id_attr = id_attr + + def 
load_data(self): + super().load_data() + data = [] + with jsonlines.open(self.file, mode="r") as jsonl_reader: + for jline in jsonl_reader: + question = jline[self.question_attr] + answers = jline[self.answers_attr] if self.answers_attr in jline else [] + id = None + if self.id_attr in jline: + id = jline[self.id_attr] + data.append(QASample(self._process_question(question), id, answers)) + self.data = data + + +class KiltCsvQASrc(CsvQASrc): + def __init__( + self, + file: str, + kilt_gold_file: str, + question_col: int = 0, + answers_col: int = 1, + id_col: int = -1, + selector: DictConfig = None, + special_query_token: str = None, + query_special_suffix: str = None, + ): + super().__init__( + file, + question_col, + answers_col, + id_col, + selector, + special_query_token, + query_special_suffix, + ) + self.kilt_gold_file = kilt_gold_file + + +class KiltJsonlQASrc(JsonlQASrc): + def __init__( + self, + file: str, + kilt_gold_file: str, + question_attr: str = "input", + answers_attr: str = "answer", + id_attr: str = "id", + selector: DictConfig = None, + special_query_token: str = None, + query_special_suffix: str = None, + ): + super().__init__( + file, + selector, + question_attr, + answers_attr, + id_attr, + special_query_token, + query_special_suffix, + ) + self.kilt_gold_file = kilt_gold_file + + def load_data(self): + super().load_data() + data = [] + with jsonlines.open(self.file, mode="r") as jsonl_reader: + for jline in jsonl_reader: + question = jline[self.question_attr] + out = jline["output"] + answers = [o["answer"] for o in out if "answer" in o] + id = None + if self.id_attr in jline: + id = jline[self.id_attr] + data.append(QASample(self._process_question(question), id, answers)) + self.data = data + + +class TTS_ASR_QASrc(QASrc): + def __init__(self, file: str, trans_file: str): + super().__init__(file) + self.trans_file = trans_file + + def load_data(self): + super().load_data() + orig_data_dict = {} + with open(self.file, "r") as ifile: + reader = csv.reader(ifile, delimiter="\t") + id = 0 + for row in reader: + question = row[0] + answers = eval(row[1]) + orig_data_dict[id] = (question, answers) + id += 1 + data = [] + with open(self.trans_file, "r") as tfile: + reader = csv.reader(tfile, delimiter="\t") + for r in reader: + row_str = r[0] + idx = row_str.index("(None-") + q_id = int(row_str[idx + len("(None-") : -1]) + orig_data = orig_data_dict[q_id] + answers = orig_data[1] + q = row_str[:idx].strip().lower() + data.append(QASample(q, idx, answers)) + self.data = data + + +class CsvCtxSrc(RetrieverData): + def __init__( + self, + file: str, + id_col: int = 0, + text_col: int = 1, + title_col: int = 2, + id_prefix: str = None, + normalize: bool = False, + ): + super().__init__(file) + self.text_col = text_col + self.title_col = title_col + self.id_col = id_col + self.id_prefix = id_prefix + self.normalize = normalize + + def load_data_to(self, ctxs: Dict[object, BiEncoderPassage]): + super().load_data() + with open(self.file) as ifile: + reader = csv.reader(ifile, delimiter="\t") + for row in reader: + if row[self.id_col] == "id": + continue + if self.id_prefix: + sample_id = self.id_prefix + str(row[self.id_col]) + else: + sample_id = row[self.id_col] + passage = row[self.text_col] + if self.normalize: + passage = normalize_passage(passage) + ctxs[sample_id] = BiEncoderPassage(passage, row[self.title_col]) + + +class KiltCsvCtxSrc(CsvCtxSrc): + def __init__( + self, + file: str, + mapping_file: str, + id_col: int = 0, + text_col: int = 1, + title_col: int = 2, + 
id_prefix: str = None, + normalize: bool = False, + ): + super().__init__( + file, id_col, text_col, title_col, id_prefix, normalize=normalize + ) + self.mapping_file = mapping_file + + def convert_to_kilt(self, kilt_gold_file, dpr_output, kilt_out_file): + logger.info("Converting to KILT format file: %s", dpr_output) + + with open(dpr_output, "rt") as fin: + dpr_output = json.load(fin) + + with jsonlines.open(kilt_gold_file, "r") as reader: + kilt_gold_file = list(reader) + assert len(kilt_gold_file) == len(dpr_output) + map_path = self.mapping_file + with open(map_path, "rb") as fin: + mapping = pickle.load(fin) + + with jsonlines.open(kilt_out_file, mode="w") as writer: + for dpr_entry, kilt_gold_entry in zip(dpr_output, kilt_gold_file): + assert dpr_entry["question"] == kilt_gold_entry["input"] + provenance = [] + for ctx in dpr_entry["ctxs"]: + wikipedia_id, end_paragraph_id = mapping[int(ctx["id"])] + provenance.append( + { + "wikipedia_id": wikipedia_id, + "end_paragraph_id": end_paragraph_id, + } + ) + kilt_entry = { + "id": kilt_gold_entry["id"], + "input": dpr_entry["question"], + "output": [{"provenance": provenance}], + } + writer.write(kilt_entry) + + logger.info("Saved KILT formatted results to: %s", kilt_out_file) + + +class JsonlTablesCtxSrc(object): + def __init__( + self, + file: str, + tables_chunk_sz: int = 100, + split_type: str = "type1", + id_prefix: str = None, + ): + self.tables_chunk_sz = tables_chunk_sz + self.split_type = split_type + self.file = file + self.id_prefix = id_prefix + + def load_data_to(self, ctxs: Dict): + docs = {} + logger.info("Parsing Tables data from: %s", self.file) + tables_dict = read_nq_tables_jsonl(self.file) + table_chunks = split_tables_to_chunks( + tables_dict, self.tables_chunk_sz, split_type=self.split_type + ) + for chunk in table_chunks: + sample_id = self.id_prefix + str(chunk[0]) + docs[sample_id] = TableChunk(chunk[1], chunk[2], chunk[3]) + logger.info("Loaded %d tables chunks", len(docs)) + ctxs.update(docs) diff --git a/research/information_retrieval/DPR/dpr/data/tables.py b/research/information_retrieval/DPR/dpr/data/tables.py new file mode 100644 index 00000000000..f9622ab1916 --- /dev/null +++ b/research/information_retrieval/DPR/dpr/data/tables.py @@ -0,0 +1,674 @@ +import collections +import csv +import json +import logging +import re +import unicodedata + +import jsonlines +import spacy as spacy +from typing import List, Dict + + +logger = logging.getLogger() +logger.setLevel(logging.INFO) +if logger.hasHandlers(): + logger.handlers.clear() + +log_formatter = logging.Formatter("%(asctime)s [%(levelname)s] %(name)s: %(message)s") +console = logging.StreamHandler() +console.setFormatter(log_formatter) + +logger.addHandler(console) + + +class Cell: + def __init__(self): + self.value_tokens: List[str] = [] + self.type: str = "" + self.nested_tables: List[Table] = [] + + def __str__(self): + return " ".join(self.value_tokens) + + def to_dpr_json(self, cell_idx: int): + r = {"col": cell_idx} + r["value"] = str(self) + return r + + +class Row: + def __init__(self): + self.cells: List[Cell] = [] + + def __str__(self): + return "| ".join([str(c) for c in self.cells]) + + def visit(self, tokens_function, row_idx: int): + for i, c in enumerate(self.cells): + if c.value_tokens: + tokens_function(c.value_tokens, row_idx, i) + + def to_dpr_json(self, row_idx: int): + r = {"row": row_idx} + r["columns"] = [c.to_dpr_json(i) for i, c in enumerate(self.cells)] + return r + + +class Table(object): + def __init__(self, caption=""): + 
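+        # A Table is a caption plus a list of Rows; key caches str(self) (used by get_key()
+        # for de-duplication in read_nq_tables_jsonl) and gold_match marks gold-annotated tables.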
        self.caption = caption
+        self.body: List[Row] = []
+        self.key = None
+        self.gold_match = False
+
+    def __str__(self):
+        table_str = "<T>: {}\n".format(self.caption)
+        table_str += " rows:\n"
+        for i, r in enumerate(self.body):
+            table_str += " row #{}: {}\n".format(i, str(r))
+
+        return table_str
+
+    def get_key(self) -> str:
+        if not self.key:
+            self.key = str(self)
+        return self.key
+
+    def visit(self, tokens_function, include_caption: bool = False) -> bool:
+        if include_caption:
+            tokens_function(self.caption, -1, -1)
+        for i, r in enumerate(self.body):
+            r.visit(tokens_function, i)
+
+    def to_dpr_json(self):
+        r = {
+            "caption": self.caption,
+            "rows": [r.to_dpr_json(i) for i, r in enumerate(self.body)],
+        }
+        if self.gold_match:
+            r["gold_match"] = 1
+        return r
+
+
+class NQTableParser(object):
+    def __init__(self, tokens, is_html_mask, title):
+        self.tokens = tokens
+        self.is_html_mask = is_html_mask
+        self.max_idx = len(self.tokens)
+        self.all_tables = []
+
+        self.current_table: Table = None
+        self.tables_stack = collections.deque()
+        self.title = title
+
+    def parse(self) -> List[Table]:
+        self.all_tables = []
+        self.tables_stack = collections.deque()
+
+        for i in range(self.max_idx):
+
+            t = self.tokens[i]
+
+            if not self.is_html_mask[i]:
+                # cell content
+                self._on_content(t)
+                continue
+
+            if "<Table" in t:
+                self._on_table_start()
+            elif t == "</Table>":
+                self._on_table_end()
+            elif "<Tr" in t:
+                self._onRowStart()
+            elif t == "</Tr>":
+                self._onRowEnd()
+            elif "<Td" in t or "<Th" in t:
+                self._onCellStart()
+            elif t in ["</Td>", "</Th>"]:
+                self._on_cell_end()
+
+        return self.all_tables
+
+    def _on_table_start(self):
+        caption = self.title
+        parent_table = self.current_table
+        if parent_table:
+            self.tables_stack.append(parent_table)
+
+            caption = parent_table.caption
+            if parent_table.body and parent_table.body[-1].cells:
+                current_cell = self.current_table.body[-1].cells[-1]
+                caption += " | " + " ".join(current_cell.value_tokens)
+
+        t = Table()
+        t.caption = caption
+        self.current_table = t
+        self.all_tables.append(t)
+
+    def _on_table_end(self):
+        t = self.current_table
+        if t:
+            if self.tables_stack:  # t is a nested table
+                self.current_table = self.tables_stack.pop()
+                if self.current_table.body:
+                    current_cell = self.current_table.body[-1].cells[-1]
+                    current_cell.nested_tables.append(t)
+        else:
+            logger.error("table end without table object")
+
+    def _onRowStart(self):
+        self.current_table.body.append(Row())
+
+    def _onRowEnd(self):
+        pass
+
+    def _onCellStart(self):
+        current_row = self.current_table.body[-1]
+        current_row.cells.append(Cell())
+
+    def _on_cell_end(self):
+        pass
+
+    def _on_content(self, token):
+        if self.current_table.body:
+            current_row = self.current_table.body[-1]
+            current_cell = current_row.cells[-1]
+            current_cell.value_tokens.append(token)
+        else:  # tokens outside of row/cells. Just append to the table caption.
+ self.current_table.caption += " " + token + + +def read_nq_tables_jsonl(path: str, out_file: str = None) -> Dict[str, Table]: + tables_with_issues = 0 + single_row_tables = 0 + nested_tables = 0 + regular_tables = 0 + total_tables = 0 + total_rows = 0 + tables_dict = {} + + with jsonlines.open(path, mode="r") as jsonl_reader: + for jline in jsonl_reader: + tokens = jline["tokens"] + + if "( hide ) This section has multiple issues" in " ".join(tokens): + tables_with_issues += 1 + continue + # if '' in tokens[1:]: + # nested_tables += 1 + + mask = jline["html_mask"] + page_url = jline["doc_url"] + title = jline["title"] + # logger.info('Table from page %s', title) + # logger.info('tokens len %s', len(tokens)) + # logger.info('tokens %s', tokens) + # logger.info('page_url %s', page_url) + p = NQTableParser(tokens, mask, title) + tables = p.parse() + + # logger.info('parsed tables %d', len(tables)) + + # table = parse_table(tokens, mask) + nested_tables += len(tables[1:]) + + for t in tables: + # logger.info('Table: %s', t) + total_tables += 1 + + # calc amount of non empty rows + non_empty_rows = sum( + [ + 1 + for r in t.body + if r.cells and any([True for c in r.cells if c.value_tokens]) + ] + ) + + if non_empty_rows <= 1: + single_row_tables += 1 + else: + regular_tables += 1 + total_rows += len(t.body) + + if t.get_key() not in tables_dict: + tables_dict[t.get_key()] = t + + if len(tables_dict) % 1000 == 0: + logger.info("tables_dict %d", len(tables_dict)) + + print("regular tables", regular_tables) + print("tables_with_issues", tables_with_issues) + print("single_row_tables", single_row_tables) + print("nested_tables", nested_tables) + if out_file: + convert_to_csv_for_lucene(tables_dict, out_file) + return tables_dict + + +def get_table_string_for_answer_check(table: Table): # this doesn't use caption + table_text = "" + for r in table.body: + table_text += " . ".join([" ".join(c.value_tokens) for c in r.cells]) + table_text += " . 
" + return table_text + + +def convert_to_csv_for_lucene(tables_dict, out_file: str): + id = 0 + with open(out_file, "w", newline="") as csvfile: + writer = csv.writer(csvfile, delimiter="\t") + for _, v in tables_dict.items(): + id += 1 + # strip all + table_text = get_table_string_for_answer_check(v) + writer.writerow([id, table_text, v.caption]) + logger.info("Saved to %s", out_file) + + +def convert_jsonl_to_qas_tsv(path, out): + results = [] + with jsonlines.open(path, mode="r") as jsonl_reader: + for jline in jsonl_reader: + q = jline["question"] + answers = [] + if "short_answers" in jline: + answers = jline["short_answers"] + + results.append((q, answers)) + + with open(out, "w", newline="") as csvfile: + writer = csv.writer(csvfile, delimiter="\t") + for r in results: + writer.writerow([r[0], r[1]]) + + logger.info("Saved to %s", out) + + +nlp = spacy.load("en_core_web_sm", disable=["parser", "tagger", "ner", "entity_ruler"]) + + +def tokenize(text): + doc = nlp(text) + return [token.text.lower() for token in doc] + + +def normalize(text): + """Resolve different type of unicode encodings.""" + return unicodedata.normalize("NFD", text) + + +def prepare_answers(answers) -> List[List[str]]: + r = [] + for single_answer in answers: + single_answer = normalize(single_answer) + single_answer = single_answer.lower().split(" ") # tokenize(single_answer) + r.append(single_answer) + return r + + +def has_prepared_answer(prep_answers: List[List[str]], text): + """Check if a document contains an answer string.""" + text = normalize(text) + # Answer is a list of possible strings + text = tokenize(text) + for single_answer in prep_answers: + for i in range(0, len(text) - len(single_answer) + 1): + if single_answer == text[i : i + len(single_answer)]: + return True + return False + + +def has_prepared_answer2(prep_answers: List[List[str]], text: List[str]): + text = [normalize(token).lower() for token in text] + + # text = [item for sublist in text for item in sublist] + + # text = ' '.join(text) + # text = normalize(text) + # text = tokenize(text) + + for single_answer in prep_answers: + for i in range(0, len(text) - len(single_answer) + 1): + if single_answer == text[i : i + len(single_answer)]: + return True + return False + + +def has_answer(answers, text, regMatxh=False): + """Check if a document contains an answer string.""" + + text = normalize(text) + + if regMatxh: + single_answer = normalize(answers[0]) + if regex_match(text, single_answer): + return True + else: + # Answer is a list of possible strings + text = tokenize(text) + + for single_answer in answers: + single_answer = normalize(single_answer) + single_answer = tokenize(single_answer) + + for i in range(0, len(text) - len(single_answer) + 1): + if single_answer == text[i : i + len(single_answer)]: + return True + return False + + +def convert_search_res_to_dpr_and_eval( + res_file, all_tables_file_jsonl, nq_table_file, out_file, gold_res_file: str = None +): + db = {} + id = 0 + tables_dict = read_nq_tables_jsonl(all_tables_file_jsonl) + for _, v in tables_dict.items(): + id += 1 + db[id] = v + + logger.info("db size %s", len(db)) + total = 0 + dpr_results = {} + import torch + + bm25_per_topk_hits = torch.tensor([0] * 100) + qas = [] + with open(res_file) as tsvfile: + reader = csv.reader(tsvfile, delimiter="\t") + # file format: id, text + for row in reader: + total += 1 + q = row[0] + answers = eval(row[1]) + + prep_answers = prepare_answers(answers) + qas.append((q, prep_answers)) + # logger.info('question %s', q) + + 
question_hns = [] + question_positives = [] + answers_table_links = [] + + for k, bm25result in enumerate(row[2:]): + score, id = bm25result.split(",") + table = db[int(id)] + + answer_locations = [] + + def check_answer(tokens, row_idx: int, cell_idx: int): + if has_prepared_answer2(prep_answers, tokens): + answer_locations.append((row_idx, cell_idx)) + + # logger.info('table %s', table) + + # get string representation to find answer + if (len(question_positives) >= 10 and len(question_hns) >= 10) or ( + len(question_hns) >= 30 + ): + break + + # table_str = get_table_string_for_answer_check(table) + table.visit(check_answer) + has_answer = len(answer_locations) > 0 + + if has_answer: + # has_answer(answers, table.key) + # has_answer(answers, get_table_string_for_answer_check(table)) + # bm25_per_topk_hits[k:] += 1 + + question_positives.append(table) + answers_table_links.append(answer_locations) + # break + else: + question_hns.append(table) + + dpr_results[q] = (question_positives, question_hns, answers_table_links) + if len(dpr_results) % 100 == 0: + logger.info("dpr_results %s", len(dpr_results)) + + logger.info("dpr_results size %s", len(dpr_results)) + logger.info("total %s", total) + logger.info("bm25_per_topk_hits %s", bm25_per_topk_hits) + + if gold_res_file: + logger.info("Processing gold_res_file") + with open(gold_res_file) as cFile: + csvReader = csv.reader(cFile, delimiter=",") + for row in csvReader: + q_id = int(row[0]) + qas_tuple = qas[q_id] + prep_answers = qas_tuple[1] + question_gold_positive_match = None + q = qas_tuple[0] + + # logger.info("q=%s q_id=%s", q, q_id) + answers_links = None + for field in row[1:]: + psg_id = int(field.split()[0]) + # logger.info("psg_id=%s", psg_id) + # if psg_id >= len(db): + # continue + table = db[psg_id] + answer_locations = [] + + def check_answer(tokens, row_idx: int, cell_idx: int): + if has_prepared_answer2(prep_answers, tokens): + answer_locations.append((row_idx, cell_idx)) + + table.visit(check_answer) + has_answer = len(answer_locations) > 0 + if has_answer and question_gold_positive_match is None: + question_gold_positive_match = table + question_gold_positive_match.gold_match = True + answers_links = answer_locations + + if question_gold_positive_match is None: + logger.info("No gold match for q=%s, q_id=%s", q, q_id) + else: # inject into ctx+ at the first position + question_positives, hns, ans_links = dpr_results[q] + question_positives.insert(0, question_gold_positive_match) + ans_links.insert(0, answers_links) + + # return + out_results = [] + with jsonlines.open(nq_table_file, mode="r") as jsonl_reader: + for jline in jsonl_reader: + q = jline["question"] + gold_positive_table = jline["contexts"][0] + mask = gold_positive_table["html_mask"] + # page_url = jline['doc_url'] + title = jline["title"] + p = NQTableParser(gold_positive_table["tokens"], mask, title) + tables = p.parse() + # select the one with the answer(s) + prep_answers = prepare_answers(jline["short_answers"]) + + tables_with_answers = [] + tables_answer_locations = [] + + for t in tables: + answer_locations = [] + + def check_answer(tokens, row_idx: int, cell_idx: int): + if has_prepared_answer2(prep_answers, tokens): + answer_locations.append((row_idx, cell_idx)) + + t.visit(check_answer) + has_answer = len(answer_locations) > 0 + if has_answer: + tables_with_answers.append(t) + tables_answer_locations.append(answer_locations) + + if not tables_with_answers: + logger.info("No answer in gold table(s) for q=%s", q) + # 
tables_with_answers.append(tables[0]) + + positive_ctxs, hard_neg_ctxs, answers_table_links = dpr_results[q] + positive_ctxs = positive_ctxs + tables_with_answers + tables_answer_locations = answers_table_links + tables_answer_locations + assert len(positive_ctxs) == len(tables_answer_locations) + positive_ctxs = [t.to_dpr_json() for t in positive_ctxs] + + # set has_answer attributes + for i, ctx_json in enumerate(positive_ctxs): + answer_links = tables_answer_locations[i] + ctx_json["answer_pos"] = answer_links + hard_neg_ctxs = [t.to_dpr_json() for t in hard_neg_ctxs] + out_results.append( + { + "question": q, + "id": jline["example_id"], + "answers": jline["short_answers"], + "positive_ctxs": positive_ctxs, + "hard_negative_ctxs": hard_neg_ctxs, + } + ) + + logger.info("out_results size %s", len(out_results)) + + with jsonlines.open( + out_file, mode="w" + ) as writer: # encoding="utf-8", .encode('utf-8') + for r in out_results: + writer.write(r) + + # with open(out_file, "w") as writer: + # writer.write(json.dumps(out_results, indent=4) + "\n") # indent=4 + + logger.info("Saved to %s", out_file) + + +def convert_long_ans_to_dpr(nq_table_file, out_file): + out_results = [] + with jsonlines.open(nq_table_file, mode="r") as jsonl_reader: + for jline in jsonl_reader: + q = jline["question"] + + gold_positive_table = jline["contexts"] + + mask = gold_positive_table["la_ans_tokens_html_mask"] + # page_url = jline['doc_url'] + title = jline["title"] + + p = NQTableParser(gold_positive_table["la_ans_tokens"], mask, title) + tables = p.parse() + # select the one with the answer(s) + + positive_ctxs = [tables[0].to_dpr_json()] + + out_results.append( + { + "question": q, + "id": jline["example_id"], + "answers": [], + "positive_ctxs": positive_ctxs, + "hard_negative_ctxs": [], + } + ) + + logger.info("out_results size %s", len(out_results)) + + with jsonlines.open( + out_file, mode="w" + ) as writer: # encoding="utf-8", .encode('utf-8') + for r in out_results: + writer.write(r) + + logger.info("Saved to %s", out_file) + + +def parse_qa_csv_file(location): + res = [] + with open(location) as ifile: + reader = csv.reader(ifile, delimiter="\t") + for row in reader: + question = row[0] + answers = eval(row[1]) + res.append((question, answers)) + return res + + +def calc_questions_overlap(tables_file, regular_file, dev_file): + tab_questions = set() + + with jsonlines.open(tables_file, mode="r") as jsonl_reader: + logger.info("Reading file %s" % tables_file) + for jline in jsonl_reader: + q = jline["question"] + tab_questions.add(q) + + reg_questions = set() + + if regular_file[-4:] == ".csv": + qas = parse_qa_csv_file(regular_file) + for qa in qas: + reg_questions.add(qa[0]) + else: + with open(regular_file, "r", encoding="utf-8") as f: + logger.info("Reading file %s" % regular_file) + data = json.load(f) + for item in data: + q = item["question"] + reg_questions.add(q) + if dev_file: + if dev_file[-4:] == ".csv": + qas = parse_qa_csv_file(dev_file) + for qa in qas: + reg_questions.add(qa[0]) + else: + with open(dev_file, "r", encoding="utf-8") as f: + logger.info("Reading file %s" % dev_file) + data = json.load(f) + for item in data: + q = item["question"] + reg_questions.add(q) + + logger.info("tab_questions %d", len(tab_questions)) + logger.info("reg_questions %d", len(reg_questions)) + logger.info("overlap %d", len(tab_questions.intersection(reg_questions))) + + +def convert_train_jsonl_to_ctxmatch(path: str, out_file: str): + def get_table_string_for_ctx_match(table: dict): # this doesn't use 
caption + table_text = table["caption"] + " . " + for r in table["rows"]: + table_text += " . ".join([c["value"] for c in r["columns"]]) + table_text += " . " + return table_text + + results = [] + with jsonlines.open(path, mode="r") as jsonl_reader: + for jline in jsonl_reader: + if len(jline["positive_ctxs"]) == 0: + continue + ctx_pos = jline["positive_ctxs"][0] + table_str = get_table_string_for_ctx_match(ctx_pos) + q = jline["question"] + results.append((q, table_str)) + + if len(results) % 1000 == 0: + logger.info("results %d", len(results)) + + shards_sz = 3000 + shard = 0 + + for s in range(0, len(results), shards_sz): + chunk = results[s : s + shards_sz] + shard_file = out_file + ".shard_{}".format(shard) + with jsonlines.open(shard_file, mode="w") as writer: + logger.info("Saving to %s", shard_file) + for i, item in enumerate(chunk): + writer.write({"id": s + i, "question": item[0], "context": item[1]}) + shard += 1 + + +def regex_match(text, pattern): + """Test if a regex pattern is contained within a text.""" + try: + pattern = re.compile(pattern, flags=re.IGNORECASE + re.UNICODE + re.MULTILINE) + except BaseException: + return False + return pattern.search(text) is not None diff --git a/research/information_retrieval/DPR/dpr/indexer/faiss_indexers.py b/research/information_retrieval/DPR/dpr/indexer/faiss_indexers.py new file mode 100644 index 00000000000..edb793709e7 --- /dev/null +++ b/research/information_retrieval/DPR/dpr/indexer/faiss_indexers.py @@ -0,0 +1,254 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" + FAISS-based index components for dense retriever +""" + +import faiss +import logging +import numpy as np +import os +import pickle + +from typing import List, Tuple + +logger = logging.getLogger() + + +class DenseIndexer(object): + def __init__(self, buffer_size: int = 50000): + self.buffer_size = buffer_size + self.index_id_to_db_id = [] + self.index = None + + def init_index(self, vector_sz: int): + raise NotImplementedError + + def index_data(self, data: List[Tuple[object, np.array]]): + raise NotImplementedError + + def get_index_name(self): + raise NotImplementedError + + def search_knn( + self, query_vectors: np.array, top_docs: int + ) -> List[Tuple[List[object], List[float]]]: + raise NotImplementedError + + def serialize(self, file: str): + logger.info("Serializing index to %s", file) + + if os.path.isdir(file): + index_file = os.path.join(file, "index.dpr") + meta_file = os.path.join(file, "index_meta.dpr") + else: + index_file = file + ".index.dpr" + meta_file = file + ".index_meta.dpr" + + faiss.write_index(self.index, index_file) + with open(meta_file, mode="wb") as f: + pickle.dump(self.index_id_to_db_id, f) + + def get_files(self, path: str): + if os.path.isdir(path): + index_file = os.path.join(path, "index.dpr") + meta_file = os.path.join(path, "index_meta.dpr") + else: + index_file = path + ".{}.dpr".format(self.get_index_name()) + meta_file = path + ".{}_meta.dpr".format(self.get_index_name()) + return index_file, meta_file + + def index_exists(self, path: str): + index_file, meta_file = self.get_files(path) + return os.path.isfile(index_file) and os.path.isfile(meta_file) + + def deserialize(self, path: str): + logger.info("Loading index from %s", path) + index_file, meta_file = self.get_files(path) + + self.index = faiss.read_index(index_file) + logger.info( + 
"Loaded index of type %s and size %d", type(self.index), self.index.ntotal + ) + + with open(meta_file, "rb") as reader: + self.index_id_to_db_id = pickle.load(reader) + assert ( + len(self.index_id_to_db_id) == self.index.ntotal + ), "Deserialized index_id_to_db_id should match faiss index size" + + def _update_id_mapping(self, db_ids: List) -> int: + self.index_id_to_db_id.extend(db_ids) + return len(self.index_id_to_db_id) + + +class DenseFlatIndexer(DenseIndexer): + def __init__(self, buffer_size: int = 50000): + super(DenseFlatIndexer, self).__init__(buffer_size=buffer_size) + + def init_index(self, vector_sz: int): + self.index = faiss.IndexFlatIP(vector_sz) + + def index_data(self, data: List[Tuple[object, np.array]]): + n = len(data) + # indexing in batches is beneficial for many faiss index types + for i in range(0, n, self.buffer_size): + db_ids = [t[0] for t in data[i : i + self.buffer_size]] + vectors = [ + np.reshape(t[1], (1, -1)) for t in data[i : i + self.buffer_size] + ] + vectors = np.concatenate(vectors, axis=0) + total_data = self._update_id_mapping(db_ids) + self.index.add(vectors) + logger.info("data indexed %d", total_data) + + indexed_cnt = len(self.index_id_to_db_id) + logger.info("Total data indexed %d", indexed_cnt) + + def search_knn( + self, query_vectors: np.array, top_docs: int + ) -> List[Tuple[List[object], List[float]]]: + scores, indexes = self.index.search(query_vectors, top_docs) + # convert to external ids + db_ids = [ + [self.index_id_to_db_id[i] for i in query_top_idxs] + for query_top_idxs in indexes + ] + result = [(db_ids[i], scores[i]) for i in range(len(db_ids))] + return result + + def get_index_name(self): + return "flat_index" + + +class DenseHNSWFlatIndexer(DenseIndexer): + """ + Efficient index for retrieval. Note: default settings are for hugh accuracy but also high RAM usage + """ + + def __init__( + self, + buffer_size: int = 1e9, + store_n: int = 512, + ef_search: int = 128, + ef_construction: int = 200, + ): + super(DenseHNSWFlatIndexer, self).__init__(buffer_size=buffer_size) + self.store_n = store_n + self.ef_search = ef_search + self.ef_construction = ef_construction + self.phi = 0 + + def init_index(self, vector_sz: int): + # IndexHNSWFlat supports L2 similarity only + # so we have to apply DOT -> L2 similairy space conversion with the help of an extra dimension + index = faiss.IndexHNSWFlat(vector_sz + 1, self.store_n) + index.hnsw.efSearch = self.ef_search + index.hnsw.efConstruction = self.ef_construction + self.index = index + + def index_data(self, data: List[Tuple[object, np.array]]): + n = len(data) + + # max norm is required before putting all vectors in the index to convert inner product similarity to L2 + if self.phi > 0: + raise RuntimeError( + "DPR HNSWF index needs to index all data at once," + "results will be unpredictable otherwise." 
+ ) + phi = 0 + for i, item in enumerate(data): + id, doc_vector = item[0:2] + norms = (doc_vector ** 2).sum() + phi = max(phi, norms) + logger.info("HNSWF DotProduct -> L2 space phi={}".format(phi)) + self.phi = phi + + # indexing in batches is beneficial for many faiss index types + bs = int(self.buffer_size) + for i in range(0, n, bs): + db_ids = [t[0] for t in data[i : i + bs]] + vectors = [np.reshape(t[1], (1, -1)) for t in data[i : i + bs]] + + norms = [(doc_vector ** 2).sum() for doc_vector in vectors] + aux_dims = [np.sqrt(phi - norm) for norm in norms] + hnsw_vectors = [ + np.hstack((doc_vector, aux_dims[i].reshape(-1, 1))) + for i, doc_vector in enumerate(vectors) + ] + hnsw_vectors = np.concatenate(hnsw_vectors, axis=0) + self.train(hnsw_vectors) + + self._update_id_mapping(db_ids) + self.index.add(hnsw_vectors) + logger.info("data indexed %d", len(self.index_id_to_db_id)) + indexed_cnt = len(self.index_id_to_db_id) + logger.info("Total data indexed %d", indexed_cnt) + + def train(self, vectors: np.array): + pass + + def search_knn( + self, query_vectors: np.array, top_docs: int + ) -> List[Tuple[List[object], List[float]]]: + + aux_dim = np.zeros(len(query_vectors), dtype="float32") + query_nhsw_vectors = np.hstack((query_vectors, aux_dim.reshape(-1, 1))) + logger.info("query_hnsw_vectors %s", query_nhsw_vectors.shape) + scores, indexes = self.index.search(query_nhsw_vectors, top_docs) + # convert to external ids + db_ids = [ + [self.index_id_to_db_id[i] for i in query_top_idxs] + for query_top_idxs in indexes + ] + result = [(db_ids[i], scores[i]) for i in range(len(db_ids))] + return result + + def deserialize(self, file: str): + super(DenseHNSWFlatIndexer, self).deserialize(file) + # to trigger exception on subsequent indexing + self.phi = 1 + + def get_index_name(self): + return "hnsw_index" + + +class DenseHNSWSQIndexer(DenseHNSWFlatIndexer): + """ + Efficient index for retrieval. Note: default settings are for hugh accuracy but also high RAM usage + """ + + def __init__( + self, + buffer_size: int = 1e10, + store_n: int = 128, + ef_search: int = 128, + ef_construction: int = 200, + ): + super(DenseHNSWSQIndexer, self).__init__( + buffer_size=buffer_size, + store_n=store_n, + ef_search=ef_search, + ef_construction=ef_construction, + ) + + def init_index(self, vector_sz: int): + # IndexHNSWFlat supports L2 similarity only + # so we have to apply DOT -> L2 similairy space conversion with the help of an extra dimension + index = faiss.IndexHNSWSQ( + vector_sz + 1, faiss.ScalarQuantizer.QT_8bit, self.store_n + ) + index.hnsw.efSearch = self.ef_search + index.hnsw.efConstruction = self.ef_construction + self.index = index + + def train(self, vectors: np.array): + self.index.train(vectors) + + def get_index_name(self): + return "hnswsq_index" diff --git a/research/information_retrieval/DPR/dpr/models/__init__.py b/research/information_retrieval/DPR/dpr/models/__init__.py new file mode 100644 index 00000000000..cebca5c1021 --- /dev/null +++ b/research/information_retrieval/DPR/dpr/models/__init__.py @@ -0,0 +1,91 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
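For orientation, a minimal usage sketch of the DenseFlatIndexer defined above. The embedding dimensionality, passage ids, and query vectors are illustrative placeholders, not values from this patch; the sketch only relies on the methods the class itself exposes (init_index, index_data, search_knn).

import numpy as np

from dpr.indexer.faiss_indexers import DenseFlatIndexer

vector_sz = 768  # hypothetical embedding dimensionality
passage_ids = ["psg:1", "psg:2", "psg:3"]
passage_vectors = np.random.rand(len(passage_ids), vector_sz).astype("float32")

indexer = DenseFlatIndexer(buffer_size=50000)
indexer.init_index(vector_sz)
# index_data takes (external_id, vector) tuples and keeps the id mapping internally.
indexer.index_data(list(zip(passage_ids, passage_vectors)))

# search_knn returns, per query, a tuple of (external ids, scores) for the top documents.
queries = np.random.rand(2, vector_sz).astype("float32")
for db_ids, scores in indexer.search_knn(queries, top_docs=2):
    print(db_ids, scores)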
+ +import importlib + +""" + 'Router'-like set of methods for component initialization with lazy imports +""" + + +def init_hf_bert_biencoder(args, **kwargs): + if importlib.util.find_spec("transformers") is None: + raise RuntimeError('Please install transformers lib') + from .hf_models import get_bert_biencoder_components + return get_bert_biencoder_components(args, **kwargs) + + +def init_hf_bert_reader(args, **kwargs): + if importlib.util.find_spec("transformers") is None: + raise RuntimeError('Please install transformers lib') + from .hf_models import get_bert_reader_components + return get_bert_reader_components(args, **kwargs) + + +def init_pytext_bert_biencoder(args, **kwargs): + if importlib.util.find_spec("pytext") is None: + raise RuntimeError('Please install pytext lib') + from .pytext_models import get_bert_biencoder_components + return get_bert_biencoder_components(args, **kwargs) + + +def init_fairseq_roberta_biencoder(args, **kwargs): + if importlib.util.find_spec("fairseq") is None: + raise RuntimeError('Please install fairseq lib') + from .fairseq_models import get_roberta_biencoder_components + return get_roberta_biencoder_components(args, **kwargs) + + +def init_hf_bert_tenzorizer(args, **kwargs): + if importlib.util.find_spec("transformers") is None: + raise RuntimeError('Please install transformers lib') + from .hf_models import get_bert_tensorizer + return get_bert_tensorizer(args) + + +def init_hf_roberta_tenzorizer(args, **kwargs): + if importlib.util.find_spec("transformers") is None: + raise RuntimeError('Please install transformers lib') + from .hf_models import get_roberta_tensorizer + return get_roberta_tensorizer(args) + + +BIENCODER_INITIALIZERS = { + 'hf_bert': init_hf_bert_biencoder, + 'pytext_bert': init_pytext_bert_biencoder, + 'fairseq_roberta': init_fairseq_roberta_biencoder, +} + +READER_INITIALIZERS = { + 'hf_bert': init_hf_bert_reader, +} + +TENSORIZER_INITIALIZERS = { + 'hf_bert': init_hf_bert_tenzorizer, + 'hf_roberta': init_hf_roberta_tenzorizer, + 'pytext_bert': init_hf_bert_tenzorizer, # using HF's code as of now + 'fairseq_roberta': init_hf_roberta_tenzorizer, # using HF's code as of now +} + + +def init_comp(initializers_dict, type, args, **kwargs): + if type in initializers_dict: + return initializers_dict[type](args, **kwargs) + else: + raise RuntimeError('unsupported model type: {}'.format(type)) + + +def init_biencoder_components(encoder_type: str, args, **kwargs): + return init_comp(BIENCODER_INITIALIZERS, encoder_type, args, **kwargs) + + +def init_reader_components(encoder_type: str, args, **kwargs): + return init_comp(READER_INITIALIZERS, encoder_type, args, **kwargs) + + +def init_tenzorizer(encoder_type: str, args, **kwargs): + return init_comp(TENSORIZER_INITIALIZERS, encoder_type, args, **kwargs) diff --git a/research/information_retrieval/DPR/dpr/models/biencoder.py b/research/information_retrieval/DPR/dpr/models/biencoder.py new file mode 100644 index 00000000000..f8183a4fa9d --- /dev/null +++ b/research/information_retrieval/DPR/dpr/models/biencoder.py @@ -0,0 +1,452 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
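The module above is only a registry with lazy imports; the following hedged sketch shows how a caller resolves the HuggingFace BERT bi-encoder through it. The config values are illustrative stand-ins for the project's Hydra configuration and only cover the fields that get_bert_biencoder_components and get_bert_tensorizer (defined in hf_models.py below) read in this code path.

from omegaconf import OmegaConf

from dpr.models import init_biencoder_components

# Illustrative configuration only; real runs take these values from the Hydra configs.
cfg = OmegaConf.create(
    {
        "do_lower_case": True,
        "special_tokens": None,
        "fix_ctx_encoder": False,
        "encoder": {
            "pretrained_model_cfg": "bert-base-uncased",
            "projection_dim": 0,
            "dropout": 0.1,
            "pretrained": True,
            "sequence_length": 256,
        },
        "train": {"learning_rate": 2e-5, "adam_eps": 1e-8, "weight_decay": 0.0},
    }
)

# Returns (tensorizer, biencoder, optimizer); pass inference_only=True to skip the optimizer.
tensorizer, biencoder, optimizer = init_biencoder_components("hf_bert", cfg)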
+ +""" +BiEncoder component + loss function for 'all-in-batch' training +""" + +import collections +import logging +import random +from typing import Tuple, List + +import numpy as np +import torch +import torch.nn.functional as F +from torch import Tensor as T +from torch import nn + +from dpr.data.biencoder_data import BiEncoderSample +from dpr.utils.data_utils import Tensorizer +from dpr.utils.model_utils import CheckpointState + +logger = logging.getLogger(__name__) + +BiEncoderBatch = collections.namedtuple( + "BiENcoderInput", + [ + "question_ids", + "question_segments", + "context_ids", + "ctx_segments", + "is_positive", + "hard_negatives", + "encoder_type", + ], +) +# TODO: it is only used by _select_span_with_token. Move them to utils +rnd = random.Random(0) + + +def dot_product_scores(q_vectors: T, ctx_vectors: T) -> T: + """ + calculates q->ctx scores for every row in ctx_vector + :param q_vector: + :param ctx_vector: + :return: + """ + # q_vector: n1 x D, ctx_vectors: n2 x D, result n1 x n2 + r = torch.matmul(q_vectors, torch.transpose(ctx_vectors, 0, 1)) + return r + + +def cosine_scores(q_vector: T, ctx_vectors: T): + # q_vector: n1 x D, ctx_vectors: n2 x D, result n1 x n2 + return F.cosine_similarity(q_vector, ctx_vectors, dim=1) + + +class BiEncoder(nn.Module): + """Bi-Encoder model component. Encapsulates query/question and context/passage encoders.""" + + def __init__( + self, + question_model: nn.Module, + ctx_model: nn.Module, + fix_q_encoder: bool = False, + fix_ctx_encoder: bool = False, + ): + super(BiEncoder, self).__init__() + self.question_model = question_model + self.ctx_model = ctx_model + self.fix_q_encoder = fix_q_encoder + self.fix_ctx_encoder = fix_ctx_encoder + + @staticmethod + def get_representation( + sub_model: nn.Module, + ids: T, + segments: T, + attn_mask: T, + fix_encoder: bool = False, + representation_token_pos=0, + ) -> (T, T, T): + sequence_output = None + pooled_output = None + hidden_states = None + if ids is not None: + if fix_encoder: + with torch.no_grad(): + sequence_output, pooled_output, hidden_states = sub_model( + ids, + segments, + attn_mask, + representation_token_pos=representation_token_pos, + ) + + if sub_model.training: + sequence_output.requires_grad_(requires_grad=True) + pooled_output.requires_grad_(requires_grad=True) + else: + sequence_output, pooled_output, hidden_states = sub_model( + ids, + segments, + attn_mask, + representation_token_pos=representation_token_pos, + ) + + return sequence_output, pooled_output, hidden_states + + def forward( + self, + question_ids: T, + question_segments: T, + question_attn_mask: T, + context_ids: T, + ctx_segments: T, + ctx_attn_mask: T, + encoder_type: str = None, + representation_token_pos=0, + ) -> Tuple[T, T]: + q_encoder = ( + self.question_model + if encoder_type is None or encoder_type == "question" + else self.ctx_model + ) + _q_seq, q_pooled_out, _q_hidden = self.get_representation( + q_encoder, + question_ids, + question_segments, + question_attn_mask, + self.fix_q_encoder, + representation_token_pos=representation_token_pos, + ) + + ctx_encoder = ( + self.ctx_model + if encoder_type is None or encoder_type == "ctx" + else self.question_model + ) + _ctx_seq, ctx_pooled_out, _ctx_hidden = self.get_representation( + ctx_encoder, context_ids, ctx_segments, ctx_attn_mask, self.fix_ctx_encoder + ) + + return q_pooled_out, ctx_pooled_out + + # TODO delete once moved to the new method + @classmethod + def create_biencoder_input( + cls, + samples: List, + tensorizer: Tensorizer, + 
insert_title: bool, + num_hard_negatives: int = 0, + num_other_negatives: int = 0, + shuffle: bool = True, + shuffle_positives: bool = False, + hard_neg_fallback: bool = True, + ) -> BiEncoderBatch: + """ + Creates a batch of the biencoder training tuple. + :param samples: list of data items (from json) to create the batch for + :param tensorizer: components to create model input tensors from a text sequence + :param insert_title: enables title insertion at the beginning of the context sequences + :param num_hard_negatives: amount of hard negatives per question (taken from samples' pools) + :param num_other_negatives: amount of other negatives per question (taken from samples' pools) + :param shuffle: shuffles negative passages pools + :param shuffle_positives: shuffles positive passages pools + :return: BiEncoderBatch tuple + """ + question_tensors = [] + ctx_tensors = [] + positive_ctx_indices = [] + hard_neg_ctx_indices = [] + + for sample in samples: + # ctx+ & [ctx-] composition + # as of now, take the first(gold) ctx+ only + if shuffle and shuffle_positives: + positive_ctxs = sample["positive_ctxs"] + positive_ctx = positive_ctxs[np.random.choice(len(positive_ctxs))] + else: + positive_ctx = sample["positive_ctxs"][0] + + neg_ctxs = sample["negative_ctxs"] + hard_neg_ctxs = sample["hard_negative_ctxs"] + + if shuffle: + random.shuffle(neg_ctxs) + random.shuffle(hard_neg_ctxs) + + if hard_neg_fallback and len(hard_neg_ctxs) == 0: + hard_neg_ctxs = neg_ctxs[0:num_hard_negatives] + + neg_ctxs = neg_ctxs[0:num_other_negatives] + hard_neg_ctxs = hard_neg_ctxs[0:num_hard_negatives] + + all_ctxs = [positive_ctx] + neg_ctxs + hard_neg_ctxs + hard_negatives_start_idx = 1 + hard_negatives_end_idx = 1 + len(hard_neg_ctxs) + + current_ctxs_len = len(ctx_tensors) + + sample_ctxs_tensors = [ + tensorizer.text_to_tensor( + ctx["text"], + title=ctx["title"] if (insert_title and "title" in ctx) else None, + ) + for ctx in all_ctxs + ] + + ctx_tensors.extend(sample_ctxs_tensors) + positive_ctx_indices.append(current_ctxs_len) + hard_neg_ctx_indices.append( + [ + i + for i in range( + current_ctxs_len + hard_negatives_start_idx, + current_ctxs_len + hard_negatives_end_idx, + ) + ] + ) + + question_tensors.append(tensorizer.text_to_tensor(question)) + + ctxs_tensor = torch.cat([ctx.view(1, -1) for ctx in ctx_tensors], dim=0) + questions_tensor = torch.cat([q.view(1, -1) for q in question_tensors], dim=0) + + ctx_segments = torch.zeros_like(ctxs_tensor) + question_segments = torch.zeros_like(questions_tensor) + + return BiEncoderBatch( + questions_tensor, + question_segments, + ctxs_tensor, + ctx_segments, + positive_ctx_indices, + hard_neg_ctx_indices, + "question", + ) + + @classmethod + def create_biencoder_input2( + cls, + samples: List[BiEncoderSample], + tensorizer: Tensorizer, + insert_title: bool, + num_hard_negatives: int = 0, + num_other_negatives: int = 0, + shuffle: bool = True, + shuffle_positives: bool = False, + hard_neg_fallback: bool = True, + query_token: str = None, + ) -> BiEncoderBatch: + """ + Creates a batch of the biencoder training tuple. 
+ :param samples: list of BiEncoderSample-s to create the batch for + :param tensorizer: components to create model input tensors from a text sequence + :param insert_title: enables title insertion at the beginning of the context sequences + :param num_hard_negatives: amount of hard negatives per question (taken from samples' pools) + :param num_other_negatives: amount of other negatives per question (taken from samples' pools) + :param shuffle: shuffles negative passages pools + :param shuffle_positives: shuffles positive passages pools + :return: BiEncoderBatch tuple + """ + question_tensors = [] + ctx_tensors = [] + positive_ctx_indices = [] + hard_neg_ctx_indices = [] + + for sample in samples: + # ctx+ & [ctx-] composition + # as of now, take the first(gold) ctx+ only + + if shuffle and shuffle_positives: + positive_ctxs = sample.positive_passages + positive_ctx = positive_ctxs[np.random.choice(len(positive_ctxs))] + else: + positive_ctx = sample.positive_passages[0] + + neg_ctxs = sample.negative_passages + hard_neg_ctxs = sample.hard_negative_passages + question = sample.query + # question = normalize_question(sample.query) + + if shuffle: + random.shuffle(neg_ctxs) + random.shuffle(hard_neg_ctxs) + + if hard_neg_fallback and len(hard_neg_ctxs) == 0: + hard_neg_ctxs = neg_ctxs[0:num_hard_negatives] + + neg_ctxs = neg_ctxs[0:num_other_negatives] + hard_neg_ctxs = hard_neg_ctxs[0:num_hard_negatives] + + all_ctxs = [positive_ctx] + neg_ctxs + hard_neg_ctxs + hard_negatives_start_idx = 1 + hard_negatives_end_idx = 1 + len(hard_neg_ctxs) + + current_ctxs_len = len(ctx_tensors) + + sample_ctxs_tensors = [ + tensorizer.text_to_tensor( + ctx.text, title=ctx.title if (insert_title and ctx.title) else None + ) + for ctx in all_ctxs + ] + + ctx_tensors.extend(sample_ctxs_tensors) + positive_ctx_indices.append(current_ctxs_len) + hard_neg_ctx_indices.append( + [ + i + for i in range( + current_ctxs_len + hard_negatives_start_idx, + current_ctxs_len + hard_negatives_end_idx, + ) + ] + ) + + if query_token: + # TODO: tmp workaround for EL, remove or revise + if query_token == "[START_ENT]": + query_span = _select_span_with_token( + question, tensorizer, token_str=query_token + ) + question_tensors.append(query_span) + else: + question_tensors.append( + tensorizer.text_to_tensor(" ".join([query_token, question])) + ) + else: + question_tensors.append(tensorizer.text_to_tensor(question)) + + ctxs_tensor = torch.cat([ctx.view(1, -1) for ctx in ctx_tensors], dim=0) + questions_tensor = torch.cat([q.view(1, -1) for q in question_tensors], dim=0) + + ctx_segments = torch.zeros_like(ctxs_tensor) + question_segments = torch.zeros_like(questions_tensor) + + return BiEncoderBatch( + questions_tensor, + question_segments, + ctxs_tensor, + ctx_segments, + positive_ctx_indices, + hard_neg_ctx_indices, + "question", + ) + + def load_state(self, saved_state: CheckpointState): + # TODO: make a long term HF compatibility fix + if "question_model.embeddings.position_ids" in saved_state.model_dict: + del saved_state.model_dict["question_model.embeddings.position_ids"] + del saved_state.model_dict["ctx_model.embeddings.position_ids"] + self.load_state_dict(saved_state.model_dict) + + def get_state_dict(self): + return self.state_dict() + + +class BiEncoderNllLoss(object): + def calc( + self, + q_vectors: T, + ctx_vectors: T, + positive_idx_per_question: list, + hard_negative_idx_per_question: list = None, + loss_scale: float = None, + ) -> Tuple[T, int]: + """ + Computes nll loss for the given lists of question and 
ctx vectors. + Note that although hard_negative_idx_per_question in not currently in use, one can use it for the + loss modifications. For example - weighted NLL with different factors for hard vs regular negatives. + :return: a tuple of loss value and amount of correct predictions per batch + """ + scores = self.get_scores(q_vectors, ctx_vectors) + + if len(q_vectors.size()) > 1: + q_num = q_vectors.size(0) + scores = scores.view(q_num, -1) + + softmax_scores = F.log_softmax(scores, dim=1) + + loss = F.nll_loss( + softmax_scores, + torch.tensor(positive_idx_per_question).to(softmax_scores.device), + reduction="mean", + ) + + max_score, max_idxs = torch.max(softmax_scores, 1) + correct_predictions_count = ( + max_idxs == torch.tensor(positive_idx_per_question).to(max_idxs.device) + ).sum() + + if loss_scale: + loss.mul_(loss_scale) + + return loss, correct_predictions_count + + @staticmethod + def get_scores(q_vector: T, ctx_vectors: T) -> T: + f = BiEncoderNllLoss.get_similarity_function() + return f(q_vector, ctx_vectors) + + @staticmethod + def get_similarity_function(): + return dot_product_scores + + +def _select_span_with_token( + text: str, tensorizer: Tensorizer, token_str: str = "[START_ENT]" +) -> T: + id = tensorizer.get_token_id(token_str) + query_tensor = tensorizer.text_to_tensor(text) + + if id not in query_tensor: + query_tensor_full = tensorizer.text_to_tensor(text, apply_max_len=False) + token_indexes = (query_tensor_full == id).nonzero() + if token_indexes.size(0) > 0: + start_pos = token_indexes[0, 0].item() + # add some randomization to avoid overfitting to a specific token position + + left_shit = int(tensorizer.max_length / 2) + rnd_shift = int((rnd.random() - 0.5) * left_shit / 2) + left_shit += rnd_shift + + query_tensor = query_tensor_full[start_pos - left_shit :] + cls_id = tensorizer.tokenizer.cls_token_id + if query_tensor[0] != cls_id: + query_tensor = torch.cat([torch.tensor([cls_id]), query_tensor], dim=0) + + from dpr.models.reader import _pad_to_len + + query_tensor = _pad_to_len( + query_tensor, tensorizer.get_pad_id(), tensorizer.max_length + ) + query_tensor[-1] = tensorizer.tokenizer.sep_token_id + # logger.info('aligned query_tensor %s', query_tensor) + + assert id in query_tensor, "query_tensor={}".format(query_tensor) + return query_tensor + else: + raise RuntimeError( + "[START_ENT] toke not found for Entity Linking sample query={}".format( + text + ) + ) + else: + return query_tensor diff --git a/research/information_retrieval/DPR/dpr/models/fairseq_models.py b/research/information_retrieval/DPR/dpr/models/fairseq_models.py new file mode 100644 index 00000000000..dd8a6513c13 --- /dev/null +++ b/research/information_retrieval/DPR/dpr/models/fairseq_models.py @@ -0,0 +1,60 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
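To make the 'all-in-batch' objective above concrete, a small hedged sketch: with the default dot-product similarity, every question is scored against every context row in the batch, and NLL loss is taken against the index of that question's positive context. All tensors here are random placeholders.

import torch

from dpr.models.biencoder import BiEncoderNllLoss

torch.manual_seed(0)
q_vectors = torch.randn(2, 768)     # 2 questions in the batch
ctx_vectors = torch.randn(6, 768)   # 3 contexts per question: 1 positive + 2 negatives
positive_idx_per_question = [0, 3]  # row index of each question's positive context

loss, correct = BiEncoderNllLoss().calc(q_vectors, ctx_vectors, positive_idx_per_question)
print(loss.item(), int(correct))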
+ +""" +Encoder model wrappers based on Fairseq code +""" + +import logging +from typing import Tuple + +from fairseq.models.roberta.hub_interface import RobertaHubInterface +from fairseq.models.roberta.model import RobertaModel as FaiseqRobertaModel +from fairseq.optim.adam import FairseqAdam +from torch import Tensor as T +from torch import nn + +from dpr.models.hf_models import get_roberta_tensorizer +from .biencoder import BiEncoder + +logger = logging.getLogger(__name__) + + +def get_roberta_biencoder_components(args, inference_only: bool = False, **kwargs): + question_encoder = RobertaEncoder.from_pretrained(args.pretrained_file) + ctx_encoder = RobertaEncoder.from_pretrained(args.pretrained_file) + biencoder = BiEncoder(question_encoder, ctx_encoder) + optimizer = get_fairseq_adamw_optimizer(biencoder, args) if not inference_only else None + + tensorizer = get_roberta_tensorizer(args) + + return tensorizer, biencoder, optimizer + + +def get_fairseq_adamw_optimizer(model: nn.Module, args): + setattr(args, 'lr', [args.learning_rate]) + return FairseqAdam(args, model.parameters()).optimizer + + +class RobertaEncoder(nn.Module): + + def __init__(self, fairseq_roberta_hub: RobertaHubInterface): + super(RobertaEncoder, self).__init__() + self.fairseq_roberta = fairseq_roberta_hub + + @classmethod + def from_pretrained(cls, pretrained_dir_path: str): + model = FaiseqRobertaModel.from_pretrained(pretrained_dir_path) + return cls(model) + + def forward(self, input_ids: T, token_type_ids: T, attention_mask: T) -> Tuple[T, ...]: + roberta_out = self.fairseq_roberta.extract_features(input_ids) + cls_out = roberta_out[:, 0, :] + return roberta_out, cls_out, None + + def get_out_size(self): + raise NotImplementedError diff --git a/research/information_retrieval/DPR/dpr/models/hf_models.py b/research/information_retrieval/DPR/dpr/models/hf_models.py new file mode 100644 index 00000000000..45571aee380 --- /dev/null +++ b/research/information_retrieval/DPR/dpr/models/hf_models.py @@ -0,0 +1,336 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +""" +Encoder model wrappers based on HuggingFace code +""" + +import logging +from typing import Tuple + +import torch +from torch import Tensor as T +from torch import nn +from transformers.modeling_bert import BertConfig, BertModel +from transformers.optimization import AdamW +from transformers.tokenization_bert import BertTokenizer +from transformers.tokenization_roberta import RobertaTokenizer + +from dpr.models.biencoder import BiEncoder +from dpr.utils.data_utils import Tensorizer +from .reader import Reader + +logger = logging.getLogger(__name__) + + +def get_bert_biencoder_components(cfg, inference_only: bool = False, **kwargs): + dropout = cfg.encoder.dropout if hasattr(cfg.encoder, "dropout") else 0.0 + question_encoder = HFBertEncoder.init_encoder( + cfg.encoder.pretrained_model_cfg, + projection_dim=cfg.encoder.projection_dim, + dropout=dropout, + pretrained=cfg.encoder.pretrained, + **kwargs + ) + ctx_encoder = HFBertEncoder.init_encoder( + cfg.encoder.pretrained_model_cfg, + projection_dim=cfg.encoder.projection_dim, + dropout=dropout, + pretrained=cfg.encoder.pretrained, + **kwargs + ) + + fix_ctx_encoder = cfg.fix_ctx_encoder if hasattr(cfg, "fix_ctx_encoder") else False + + biencoder = BiEncoder( + question_encoder, ctx_encoder, fix_ctx_encoder=fix_ctx_encoder + ) + + optimizer = ( + get_optimizer( + biencoder, + learning_rate=cfg.train.learning_rate, + adam_eps=cfg.train.adam_eps, + weight_decay=cfg.train.weight_decay, + ) + if not inference_only + else None + ) + + tensorizer = get_bert_tensorizer(cfg) + return tensorizer, biencoder, optimizer + + +def get_bert_reader_components(cfg, inference_only: bool = False, **kwargs): + dropout = cfg.encoder.dropout if hasattr(cfg.encoder, "dropout") else 0.0 + encoder = HFBertEncoder.init_encoder( + cfg.encoder.pretrained_model_cfg, + projection_dim=cfg.encoder.projection_dim, + dropout=dropout, + pretrained=cfg.encoder.pretrained, + **kwargs + ) + + hidden_size = encoder.config.hidden_size + reader = Reader(encoder, hidden_size) + + optimizer = ( + get_optimizer( + reader, + learning_rate=cfg.train.learning_rate, + adam_eps=cfg.train.adam_eps, + weight_decay=cfg.train.weight_decay, + ) + if not inference_only + else None + ) + + tensorizer = get_bert_tensorizer(cfg) + return tensorizer, reader, optimizer + + +def get_bert_tensorizer(cfg, tokenizer=None): + sequence_length = cfg.encoder.sequence_length + pretrained_model_cfg = cfg.encoder.pretrained_model_cfg + + if not tokenizer: + tokenizer = get_bert_tokenizer( + pretrained_model_cfg, do_lower_case=cfg.do_lower_case + ) + if cfg.special_tokens: + _add_special_tokens(tokenizer, cfg.special_tokens) + + return BertTensorizer(tokenizer, sequence_length) + + +def _add_special_tokens(tokenizer, special_tokens): + logger.info("Adding special tokens %s", special_tokens) + special_tokens_num = len(special_tokens) + # TODO: this is a hack-y logic that uses some private tokenizer structure which can be changed in HF code + assert special_tokens_num < 50 + unused_ids = [ + tokenizer.vocab["[unused{}]".format(i)] for i in range(special_tokens_num) + ] + logger.info("Utilizing the following unused token ids %s", unused_ids) + + for idx, id in enumerate(unused_ids): + del tokenizer.vocab["[unused{}]".format(idx)] + tokenizer.vocab[special_tokens[idx]] = id + tokenizer.ids_to_tokens[id] = special_tokens[idx] + + tokenizer._additional_special_tokens = list(special_tokens) + logger.info( + "Added special tokenizer.additional_special_tokens %s", + tokenizer.additional_special_tokens, + ) + 
logger.info("Tokenizer's all_special_tokens %s", tokenizer.all_special_tokens) + + +def get_roberta_tensorizer(args, tokenizer=None): + if not tokenizer: + tokenizer = get_roberta_tokenizer( + args.pretrained_model_cfg, do_lower_case=args.do_lower_case + ) + return RobertaTensorizer(tokenizer, args.sequence_length) + + +def get_optimizer( + model: nn.Module, + learning_rate: float = 1e-5, + adam_eps: float = 1e-8, + weight_decay: float = 0.0, +) -> torch.optim.Optimizer: + no_decay = ["bias", "LayerNorm.weight"] + + optimizer_grouped_parameters = [ + { + "params": [ + p + for n, p in model.named_parameters() + if not any(nd in n for nd in no_decay) + ], + "weight_decay": weight_decay, + }, + { + "params": [ + p + for n, p in model.named_parameters() + if any(nd in n for nd in no_decay) + ], + "weight_decay": 0.0, + }, + ] + optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate, eps=adam_eps) + return optimizer + + +def get_bert_tokenizer(pretrained_cfg_name: str, do_lower_case: bool = True): + return BertTokenizer.from_pretrained( + pretrained_cfg_name, do_lower_case=do_lower_case + ) + + +def get_roberta_tokenizer(pretrained_cfg_name: str, do_lower_case: bool = True): + # still uses HF code for tokenizer since they are the same + return RobertaTokenizer.from_pretrained( + pretrained_cfg_name, do_lower_case=do_lower_case + ) + + +class HFBertEncoder(BertModel): + def __init__(self, config, project_dim: int = 0): + BertModel.__init__(self, config) + assert config.hidden_size > 0, "Encoder hidden_size can't be zero" + self.encode_proj = ( + nn.Linear(config.hidden_size, project_dim) if project_dim != 0 else None + ) + self.init_weights() + + @classmethod + def init_encoder( + cls, + cfg_name: str, + projection_dim: int = 0, + dropout: float = 0.1, + pretrained: bool = True, + **kwargs + ) -> BertModel: + cfg = BertConfig.from_pretrained(cfg_name if cfg_name else "bert-base-uncased") + if dropout != 0: + cfg.attention_probs_dropout_prob = dropout + cfg.hidden_dropout_prob = dropout + + if pretrained: + return cls.from_pretrained( + cfg_name, config=cfg, project_dim=projection_dim, **kwargs + ) + else: + return HFBertEncoder(cfg, project_dim=projection_dim) + + def forward( + self, + input_ids: T, + token_type_ids: T, + attention_mask: T, + representation_token_pos=0, + ) -> Tuple[T, ...]: + if self.config.output_hidden_states: + sequence_output, pooled_output, hidden_states = super().forward( + input_ids=input_ids, + token_type_ids=token_type_ids, + attention_mask=attention_mask, + ) + else: + hidden_states = None + sequence_output, pooled_output = super().forward( + input_ids=input_ids, + token_type_ids=token_type_ids, + attention_mask=attention_mask, + ) + + if isinstance(representation_token_pos, int): + pooled_output = sequence_output[:, representation_token_pos, :] + else: # treat as a tensor + bsz = sequence_output.size(0) + assert ( + representation_token_pos.size(0) == bsz + ), "query bsz={} while representation_token_pos bsz={}".format( + bsz, representation_token_pos.size(0) + ) + pooled_output = torch.stack( + [ + sequence_output[i, representation_token_pos[i, 1], :] + for i in range(bsz) + ] + ) + + if self.encode_proj: + pooled_output = self.encode_proj(pooled_output) + return sequence_output, pooled_output, hidden_states + + def get_out_size(self): + if self.encode_proj: + return self.encode_proj.out_features + return self.config.hidden_size + + +class BertTensorizer(Tensorizer): + def __init__( + self, tokenizer: BertTokenizer, max_length: int, pad_to_max: bool = 
True + ): + self.tokenizer = tokenizer + self.max_length = max_length + self.pad_to_max = pad_to_max + + def text_to_tensor( + self, + text: str, + title: str = None, + add_special_tokens: bool = True, + apply_max_len: bool = True, + ): + text = text.strip() + # tokenizer automatic padding is explicitly disabled since its inconsistent behavior + # TODO: move max len to methods params? + + if title: + token_ids = self.tokenizer.encode( + title, + text_pair=text, + add_special_tokens=add_special_tokens, + max_length=self.max_length if apply_max_len else 10000, + pad_to_max_length=False, + truncation=True, + ) + else: + token_ids = self.tokenizer.encode( + text, + add_special_tokens=add_special_tokens, + max_length=self.max_length if apply_max_len else 10000, + pad_to_max_length=False, + truncation=True, + ) + + seq_len = self.max_length + if self.pad_to_max and len(token_ids) < seq_len: + token_ids = token_ids + [self.tokenizer.pad_token_id] * ( + seq_len - len(token_ids) + ) + if len(token_ids) >= seq_len: + token_ids = token_ids[0:seq_len] if apply_max_len else token_ids + token_ids[-1] = self.tokenizer.sep_token_id + + return torch.tensor(token_ids) + + def get_pair_separator_ids(self) -> T: + return torch.tensor([self.tokenizer.sep_token_id]) + + def get_pad_id(self) -> int: + return self.tokenizer.pad_token_id + + def get_attn_mask(self, tokens_tensor: T) -> T: + return tokens_tensor != self.get_pad_id() + + def is_sub_word_id(self, token_id: int): + token = self.tokenizer.convert_ids_to_tokens([token_id])[0] + return token.startswith("##") or token.startswith(" ##") + + def to_string(self, token_ids, skip_special_tokens=True): + return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens) + + def set_pad_to_max(self, do_pad: bool): + self.pad_to_max = do_pad + + def get_token_id(self, token: str) -> int: + return self.tokenizer.vocab[token] + + +class RobertaTensorizer(BertTensorizer): + def __init__(self, tokenizer, max_length: int, pad_to_max: bool = True): + super(RobertaTensorizer, self).__init__( + tokenizer, max_length, pad_to_max=pad_to_max + ) diff --git a/research/information_retrieval/DPR/dpr/models/pytext_models.py b/research/information_retrieval/DPR/dpr/models/pytext_models.py new file mode 100644 index 00000000000..97ccc920e97 --- /dev/null +++ b/research/information_retrieval/DPR/dpr/models/pytext_models.py @@ -0,0 +1,130 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
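A minimal sketch of the tensorizer contract established above: text (optionally paired with a title) is encoded, truncated to max_length, padded with the tokenizer's pad id, and terminated with a [SEP] token, while the attention mask is simply "not padding". It assumes the pre-4.x transformers release targeted by this file's transformers.modeling_bert imports and downloads bert-base-uncased.

from dpr.models.hf_models import BertTensorizer, get_bert_tokenizer

tokenizer = get_bert_tokenizer("bert-base-uncased", do_lower_case=True)
tensorizer = BertTensorizer(tokenizer, max_length=32)

# Title and text are encoded as a sentence pair, then padded/truncated to length 32.
ids = tensorizer.text_to_tensor("who wrote the tragedy of hamlet", title="Hamlet")
mask = tensorizer.get_attn_mask(ids)
print(ids.shape, int(mask.sum()))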
+ +""" +Encoder model wrappers based on HuggingFace code +""" + +import logging +from typing import Tuple + +import torch +from pytext.models.representations.transformer_sentence_encoder import TransformerSentenceEncoder +from pytext.optimizer.optimizers import AdamW +from torch import Tensor as T +from torch import nn + +from .biencoder import BiEncoder + +logger = logging.getLogger(__name__) + + +def get_bert_biencoder_components(args, inference_only: bool = False): + # since bert tokenizer is the same in HF and pytext/fairseq, just use HF's implementation here for now + from .hf_models import get_tokenizer, BertTensorizer + + tokenizer = get_tokenizer(args.pretrained_model_cfg, do_lower_case=args.do_lower_case) + + question_encoder = PytextBertEncoder.init_encoder(args.pretrained_file, + projection_dim=args.projection_dim, dropout=args.dropout, + vocab_size=tokenizer.vocab_size, + padding_idx=tokenizer.pad_token_type_id + ) + + ctx_encoder = PytextBertEncoder.init_encoder(args.pretrained_file, + projection_dim=args.projection_dim, dropout=args.dropout, + vocab_size=tokenizer.vocab_size, + padding_idx=tokenizer.pad_token_type_id + ) + + biencoder = BiEncoder(question_encoder, ctx_encoder) + + optimizer = get_optimizer(biencoder, + learning_rate=args.learning_rate, + adam_eps=args.adam_eps, weight_decay=args.weight_decay, + ) if not inference_only else None + + tensorizer = BertTensorizer(tokenizer, args.sequence_length) + return tensorizer, biencoder, optimizer + + +def get_optimizer(model: nn.Module, learning_rate: float = 1e-5, adam_eps: float = 1e-8, + weight_decay: float = 0.0) -> torch.optim.Optimizer: + cfg = AdamW.Config() + cfg.lr = learning_rate + cfg.weight_decay = weight_decay + cfg.eps = adam_eps + optimizer = AdamW.from_config(cfg, model) + return optimizer + + +def get_pytext_bert_base_cfg(): + cfg = TransformerSentenceEncoder.Config() + cfg.embedding_dim = 768 + cfg.ffn_embedding_dim = 3072 + cfg.num_encoder_layers = 12 + cfg.num_attention_heads = 12 + cfg.num_segments = 2 + cfg.use_position_embeddings = True + cfg.offset_positions_by_padding = True + cfg.apply_bert_init = True + cfg.encoder_normalize_before = True + cfg.activation_fn = "gelu" + cfg.projection_dim = 0 + cfg.max_seq_len = 512 + cfg.multilingual = False + cfg.freeze_embeddings = False + cfg.n_trans_layers_to_freeze = 0 + cfg.use_torchscript = False + return cfg + + +class PytextBertEncoder(TransformerSentenceEncoder): + + def __init__(self, config: TransformerSentenceEncoder.Config, + padding_idx: int, + vocab_size: int, + projection_dim: int = 0, + *args, + **kwarg + ): + + TransformerSentenceEncoder.__init__(self, config, False, padding_idx, vocab_size, *args, **kwarg) + + assert config.embedding_dim > 0, 'Encoder hidden_size can\'t be zero' + self.encode_proj = nn.Linear(config.embedding_dim, projection_dim) if projection_dim != 0 else None + + @classmethod + def init_encoder(cls, pretrained_file: str = None, projection_dim: int = 0, dropout: float = 0.1, + vocab_size: int = 0, + padding_idx: int = 0, **kwargs): + cfg = get_pytext_bert_base_cfg() + + if dropout != 0: + cfg.dropout = dropout + cfg.attention_dropout = dropout + cfg.activation_dropout = dropout + + encoder = cls(cfg, padding_idx, vocab_size, projection_dim, **kwargs) + + if pretrained_file: + logger.info('Loading pre-trained pytext encoder state from %s', pretrained_file) + state = torch.load(pretrained_file) + encoder.load_state_dict(state) + return encoder + + def forward(self, input_ids: T, token_type_ids: T, attention_mask: T) -> 
Tuple[T, ...]: + pooled_output = super().forward((input_ids, attention_mask, token_type_ids, None))[0] + if self.encode_proj: + pooled_output = self.encode_proj(pooled_output) + + return None, pooled_output, None + + def get_out_size(self): + if self.encode_proj: + return self.encode_proj.out_features + return self.representation_dim diff --git a/research/information_retrieval/DPR/dpr/models/reader.py b/research/information_retrieval/DPR/dpr/models/reader.py new file mode 100644 index 00000000000..761efc0b26c --- /dev/null +++ b/research/information_retrieval/DPR/dpr/models/reader.py @@ -0,0 +1,236 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +The reader model code + its utilities (loss computation and input batch tensor generator) +""" + +import collections +import logging +from typing import List + +import numpy as np +import torch +import torch.nn as nn +from torch import Tensor as T +from torch.nn import CrossEntropyLoss + +from dpr.data.reader_data import ReaderSample, ReaderPassage +from dpr.utils.model_utils import init_weights + +logger = logging.getLogger() + +ReaderBatch = collections.namedtuple('ReaderBatch', ['input_ids', 'start_positions', 'end_positions', 'answers_mask']) + + +class Reader(nn.Module): + + def __init__(self, encoder: nn.Module, hidden_size): + super(Reader, self).__init__() + self.encoder = encoder + self.qa_outputs = nn.Linear(hidden_size, 2) + self.qa_classifier = nn.Linear(hidden_size, 1) + init_weights([self.qa_outputs, self.qa_classifier]) + + def forward(self, input_ids: T, attention_mask: T, start_positions=None, end_positions=None, answer_mask=None): + # notations: N - number of questions in a batch, M - number of passages per questions, L - sequence length + N, M, L = input_ids.size() + start_logits, end_logits, relevance_logits = self._forward(input_ids.view(N * M, L), + attention_mask.view(N * M, L)) + if self.training: + return compute_loss(start_positions, end_positions, answer_mask, start_logits, end_logits, relevance_logits, + N, M) + + return start_logits.view(N, M, L), end_logits.view(N, M, L), relevance_logits.view(N, M) + + def _forward(self, input_ids, attention_mask): + # TODO: provide segment values + sequence_output, _pooled_output, _hidden_states = self.encoder(input_ids, None, attention_mask) + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1) + end_logits = end_logits.squeeze(-1) + rank_logits = self.qa_classifier(sequence_output[:, 0, :]) + return start_logits, end_logits, rank_logits + + +def compute_loss(start_positions, end_positions, answer_mask, start_logits, end_logits, relevance_logits, N, M): + start_positions = start_positions.view(N * M, -1) + end_positions = end_positions.view(N * M, -1) + answer_mask = answer_mask.view(N * M, -1) + + start_logits = start_logits.view(N * M, -1) + end_logits = end_logits.view(N * M, -1) + relevance_logits = relevance_logits.view(N * M) + + answer_mask = answer_mask.type(torch.FloatTensor).cuda() + + ignored_index = start_logits.size(1) + start_positions.clamp_(0, ignored_index) + end_positions.clamp_(0, ignored_index) + loss_fct = CrossEntropyLoss(reduce=False, ignore_index=ignored_index) + + # compute switch loss + relevance_logits = relevance_logits.view(N, M) + switch_labels = torch.zeros(N, dtype=torch.long).cuda() 
+ switch_loss = torch.sum(loss_fct(relevance_logits, switch_labels)) + + # compute span loss + start_losses = [(loss_fct(start_logits, _start_positions) * _span_mask) + for (_start_positions, _span_mask) + in zip(torch.unbind(start_positions, dim=1), torch.unbind(answer_mask, dim=1))] + + end_losses = [(loss_fct(end_logits, _end_positions) * _span_mask) + for (_end_positions, _span_mask) + in zip(torch.unbind(end_positions, dim=1), torch.unbind(answer_mask, dim=1))] + loss_tensor = torch.cat([t.unsqueeze(1) for t in start_losses], dim=1) + \ + torch.cat([t.unsqueeze(1) for t in end_losses], dim=1) + + loss_tensor = loss_tensor.view(N, M, -1).max(dim=1)[0] + span_loss = _calc_mml(loss_tensor) + return span_loss + switch_loss + + +def create_reader_input(pad_token_id: int, + samples: List[ReaderSample], + passages_per_question: int, + max_length: int, + max_n_answers: int, + is_train: bool, + shuffle: bool, + ) -> ReaderBatch: + """ + Creates a reader batch instance out of a list of ReaderSample-s + :param pad_token_id: id of the padding token + :param samples: list of samples to create the batch for + :param passages_per_question: amount of passages for every question in a batch + :param max_length: max model input sequence length + :param max_n_answers: max num of answers per single question + :param is_train: if the samples are for a train set + :param shuffle: should passages selection be randomized + :return: ReaderBatch instance + """ + input_ids = [] + start_positions = [] + end_positions = [] + answers_masks = [] + empty_sequence = torch.Tensor().new_full((max_length,), pad_token_id, dtype=torch.long) + + for sample in samples: + positive_ctxs = sample.positive_passages + negative_ctxs = sample.negative_passages if is_train else sample.passages + + sample_tensors = _create_question_passages_tensors(positive_ctxs, + negative_ctxs, + passages_per_question, + empty_sequence, + max_n_answers, + pad_token_id, + is_train, + is_random=shuffle) + if not sample_tensors: + logger.warning('No valid passages combination for question=%s ', sample.question) + continue + sample_input_ids, starts_tensor, ends_tensor, answer_mask = sample_tensors + input_ids.append(sample_input_ids) + if is_train: + start_positions.append(starts_tensor) + end_positions.append(ends_tensor) + answers_masks.append(answer_mask) + input_ids = torch.cat([ids.unsqueeze(0) for ids in input_ids], dim=0) + + if is_train: + start_positions = torch.stack(start_positions, dim=0) + end_positions = torch.stack(end_positions, dim=0) + answers_masks = torch.stack(answers_masks, dim=0) + + return ReaderBatch(input_ids, start_positions, end_positions, answers_masks) + + +def _calc_mml(loss_tensor): + marginal_likelihood = torch.sum(torch.exp( + - loss_tensor - 1e10 * (loss_tensor == 0).float()), 1) + return -torch.sum(torch.log(marginal_likelihood + + torch.ones(loss_tensor.size(0)).cuda() * (marginal_likelihood == 0).float())) + + +def _pad_to_len(seq: T, pad_id: int, max_len: int): + s_len = seq.size(0) + if s_len > max_len: + return seq[0: max_len] + return torch.cat([seq, torch.Tensor().new_full((max_len - s_len,), pad_id, dtype=torch.long)], dim=0) + + +def _get_answer_spans(idx, positives: List[ReaderPassage], max_len: int): + positive_a_spans = positives[idx].answers_spans + return [span for span in positive_a_spans if (span[0] < max_len and span[1] < max_len)] + + +def _get_positive_idx(positives: List[ReaderPassage], max_len: int, is_random: bool): + # select just one positive + positive_idx = np.random.choice(len(positives)) 
if is_random else 0 + + if not _get_answer_spans(positive_idx, positives, max_len): + # question may be too long, find the first positive with at least one valid span + positive_idx = next((i for i in range(len(positives)) if _get_answer_spans(i, positives, max_len)), + None) + return positive_idx + + +def _create_question_passages_tensors(positives: List[ReaderPassage], negatives: List[ReaderPassage], total_size: int, + empty_ids: T, + max_n_answers: int, + pad_token_id: int, + is_train: bool, + is_random: bool = True): + max_len = empty_ids.size(0) + if is_train: + # select just one positive + positive_idx = _get_positive_idx(positives, max_len, is_random) + if positive_idx is None: + return None + + positive_a_spans = _get_answer_spans(positive_idx, positives, max_len)[0: max_n_answers] + + answer_starts = [span[0] for span in positive_a_spans] + answer_ends = [span[1] for span in positive_a_spans] + + assert all(s < max_len for s in answer_starts) + assert all(e < max_len for e in answer_ends) + + positive_input_ids = _pad_to_len(positives[positive_idx].sequence_ids, pad_token_id, max_len) + + answer_starts_tensor = torch.zeros((total_size, max_n_answers)).long() + answer_starts_tensor[0, 0:len(answer_starts)] = torch.tensor(answer_starts) + + answer_ends_tensor = torch.zeros((total_size, max_n_answers)).long() + answer_ends_tensor[0, 0:len(answer_ends)] = torch.tensor(answer_ends) + + answer_mask = torch.zeros((total_size, max_n_answers), dtype=torch.long) + answer_mask[0, 0:len(answer_starts)] = torch.tensor([1 for _ in range(len(answer_starts))]) + + positives_selected = [positive_input_ids] + + else: + positives_selected = [] + answer_starts_tensor = None + answer_ends_tensor = None + answer_mask = None + + positives_num = len(positives_selected) + negative_idxs = np.random.permutation(range(len(negatives))) if is_random else range( + len(negatives) - positives_num) + + negative_idxs = negative_idxs[:total_size - positives_num] + + negatives_selected = [_pad_to_len(negatives[i].sequence_ids, pad_token_id, max_len) for i in negative_idxs] + + while len(negatives_selected) < total_size - positives_num: + negatives_selected.append(empty_ids.clone()) + + input_ids = torch.stack([t for t in positives_selected + negatives_selected], dim=0) + return input_ids, answer_starts_tensor, answer_ends_tensor, answer_mask diff --git a/research/information_retrieval/DPR/dpr/options.py b/research/information_retrieval/DPR/dpr/options.py new file mode 100644 index 00000000000..e55ee7a7f3d --- /dev/null +++ b/research/information_retrieval/DPR/dpr/options.py @@ -0,0 +1,110 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
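As a shape-level sketch of the reader above: at inference the model scores every passage of every question, returning per-token start/end logits and a per-passage relevance logit, while training goes through compute_loss (which expects CUDA tensors). Everything below is a toy example; the real inputs are built by create_reader_input from ReaderSample objects, and the sizes here are arbitrary.

import torch

from dpr.models.hf_models import HFBertEncoder
from dpr.models.reader import Reader

encoder = HFBertEncoder.init_encoder("bert-base-uncased", projection_dim=0, dropout=0.1)
reader = Reader(encoder, hidden_size=encoder.config.hidden_size).eval()

N, M, L = 1, 2, 16  # questions, passages per question, tokens per passage
input_ids = torch.randint(1000, 2000, (N, M, L))
attention_mask = torch.ones(N, M, L, dtype=torch.long)

with torch.no_grad():
    start_logits, end_logits, relevance_logits = reader(input_ids, attention_mask)

# relevance_logits has shape (N, M) and ranks passages per question;
# start_logits / end_logits have shape (N, M, L) and score answer span boundaries.
best_passage = relevance_logits.argmax(dim=1)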
+ +""" +Command line arguments utils +""" + + +import logging +import numpy as np +import os +import random +import socket +import torch + +from omegaconf import DictConfig + +logger = logging.getLogger() + +# TODO: to be merged with conf_utils.py + + +def set_cfg_params_from_state(state: dict, cfg: DictConfig): + """ + Overrides some of the encoder config parameters from a give state object + """ + if not state: + return + cfg.do_lower_case = state["do_lower_case"] + cfg.encoder.pretrained_model_cfg = state["pretrained_model_cfg"] + cfg.encoder.encoder_model_type = state["encoder_model_type"] + cfg.encoder.pretrained_file = state["pretrained_file"] + cfg.encoder.projection_dim = state["projection_dim"] + cfg.encoder.sequence_length = state["sequence_length"] + + +def get_encoder_params_state_from_cfg(cfg: DictConfig): + """ + Selects the param values to be saved in a checkpoint, so that a trained model can be used for downstream + tasks without the need to specify these parameter again + :return: Dict of params to memorize in a checkpoint + """ + return { + "do_lower_case": cfg.do_lower_case, + "pretrained_model_cfg": cfg.encoder.pretrained_model_cfg, + "encoder_model_type": cfg.encoder.encoder_model_type, + "pretrained_file": cfg.encoder.pretrained_file, + "projection_dim": cfg.encoder.projection_dim, + "sequence_length": cfg.encoder.sequence_length, + } + + +def set_seed(args): + seed = args.seed + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if args.n_gpu > 0: + torch.cuda.manual_seed_all(seed) + + +def setup_cfg_gpu(cfg): + """ + Setup params for CUDA, GPU & distributed training + """ + logger.info("args.local_rank %s", cfg.local_rank) + ws = os.environ.get("WORLD_SIZE") + cfg.distributed_world_size = int(ws) if ws else 1 + logger.info("WORLD_SIZE %s", ws) + if cfg.local_rank == -1 or cfg.no_cuda: # single-node multi-gpu (or cpu) mode + device = str( + torch.device( + "cuda" if torch.cuda.is_available() and not cfg.no_cuda else "cpu" + ) + ) + cfg.n_gpu = torch.cuda.device_count() + else: # distributed mode + torch.cuda.set_device(cfg.local_rank) + device = str(torch.device("cuda", cfg.local_rank)) + torch.distributed.init_process_group(backend="nccl") + cfg.n_gpu = 1 + + cfg.device = device + + logger.info( + "Initialized host %s as d.rank %d on device=%s, n_gpu=%d, world size=%d", + socket.gethostname(), + cfg.local_rank, + cfg.device, + cfg.n_gpu, + cfg.distributed_world_size, + ) + logger.info("16-bits training: %s ", cfg.fp16) + return cfg + + +def setup_logger(logger): + logger.setLevel(logging.INFO) + if logger.hasHandlers(): + logger.handlers.clear() + log_formatter = logging.Formatter( + "[%(thread)s] %(asctime)s [%(levelname)s] %(name)s: %(message)s" + ) + console = logging.StreamHandler() + console.setFormatter(log_formatter) + logger.addHandler(console) diff --git a/research/information_retrieval/DPR/dpr/utils/__init__.py b/research/information_retrieval/DPR/dpr/utils/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/research/information_retrieval/DPR/dpr/utils/conf_utils.py b/research/information_retrieval/DPR/dpr/utils/conf_utils.py new file mode 100644 index 00000000000..3b5d544c7c9 --- /dev/null +++ b/research/information_retrieval/DPR/dpr/utils/conf_utils.py @@ -0,0 +1,28 @@ +import logging + +import hydra +from omegaconf import DictConfig + +logger = logging.getLogger(__name__) + + +class BiencoderDatasetsCfg(object): + def __init__(self, cfg: DictConfig): + datasets = cfg.datasets + self.train_datasets_names = 
cfg.train_datasets + logger.info("train_datasets: %s", self.train_datasets_names) + if self.train_datasets_names: + self.train_datasets = [ + hydra.utils.instantiate(datasets[ds_name]) + for ds_name in self.train_datasets_names + ] + else: + self.train_datasets = [] + if cfg.dev_datasets: + self.dev_datasets_names = cfg.dev_datasets + logger.info("dev_datasets: %s", self.dev_datasets_names) + self.dev_datasets = [ + hydra.utils.instantiate(datasets[ds_name]) + for ds_name in self.dev_datasets_names + ] + self.sampling_rates = cfg.train_sampling_rates diff --git a/research/information_retrieval/DPR/dpr/utils/data_utils.py b/research/information_retrieval/DPR/dpr/utils/data_utils.py new file mode 100644 index 00000000000..53c23695479 --- /dev/null +++ b/research/information_retrieval/DPR/dpr/utils/data_utils.py @@ -0,0 +1,333 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +Utilities for general purpose data processing +""" +import json +import logging +import pickle +import random + +import itertools +import math +import torch +from torch import Tensor as T +from typing import List, Iterator, Callable, Tuple + +logger = logging.getLogger() + + +def read_serialized_data_from_files(paths: List[str]) -> List: + results = [] + for i, path in enumerate(paths): + with open(path, "rb") as reader: + logger.info("Reading file %s", path) + data = pickle.load(reader) + results.extend(data) + logger.info("Aggregated data size: {}".format(len(results))) + logger.info("Total data size: {}".format(len(results))) + return results + + +def read_data_from_json_files(paths: List[str]) -> List: + results = [] + for i, path in enumerate(paths): + with open(path, "r", encoding="utf-8") as f: + logger.info("Reading file %s" % path) + data = json.load(f) + results = data + logger.info("Aggregated data size: {}".format(len(results))) + return results + + +class ShardedDataIterator(object): + """ + General purpose data iterator to be used for Pytorch's DDP mode where every node should handle its own part of + the data. + Instead of cutting data shards by their min size, it sets the amount of iterations by the maximum shard size. + It fills the extra sample by just taking first samples in a shard. + It can also optionally enforce identical batch size for all iterations (might be useful for DP mode). 
+ """ + + def __init__( + self, + data: torch.utils.data.Dataset, + shard_id: int = 0, + num_shards: int = 1, + batch_size: int = 1, + shuffle=True, + shuffle_seed: int = 0, + offset: int = 0, + strict_batch_size: bool = False, + ): + + self.data = data + total_size = len(data) + + self.shards_num = max(num_shards, 1) + self.shard_id = max(shard_id, 0) + + samples_per_shard = math.ceil(total_size / self.shards_num) + + self.shard_start_idx = self.shard_id * samples_per_shard + + self.shard_end_idx = min(self.shard_start_idx + samples_per_shard, total_size) + + if strict_batch_size: + self.max_iterations = math.ceil(samples_per_shard / batch_size) + else: + self.max_iterations = int(samples_per_shard / batch_size) + + logger.info( + "samples_per_shard=%d, shard_start_idx=%d, shard_end_idx=%d, max_iterations=%d", + samples_per_shard, + self.shard_start_idx, + self.shard_end_idx, + self.max_iterations, + ) + + self.iteration = offset # to track in-shard iteration status + self.shuffle = shuffle + self.batch_size = batch_size + self.shuffle_seed = shuffle_seed + self.strict_batch_size = strict_batch_size + + def total_data_len(self) -> int: + return len(self.data) + + def iterations_num(self) -> int: + return self.max_iterations - self.iteration + + def max_iterations_num(self) -> int: + return self.max_iterations + + def get_iteration(self) -> int: + return self.iteration + + def apply(self, visitor_func: Callable): + for sample in self.data: + visitor_func(sample) + + def get_shard_indices(self, epoch: int): + indices = list(range(len(self.data))) + if self.shuffle: + # to be able to resume, same shuffling should be used when starting from a failed/stopped iteration + epoch_rnd = random.Random(self.shuffle_seed + epoch) + epoch_rnd.shuffle(indices) + shard_indices = indices[self.shard_start_idx : self.shard_end_idx] + return shard_indices + + # TODO: merge with iterate_ds_sampled_data + def iterate_ds_data(self, epoch: int = 0) -> Iterator[List]: + # if resuming iteration somewhere in the middle of epoch, one needs to adjust max_iterations + max_iterations = self.max_iterations - self.iteration + shard_indices = self.get_shard_indices(epoch) + + for i in range( + self.iteration * self.batch_size, len(shard_indices), self.batch_size + ): + items_idxs = shard_indices[i : i + self.batch_size] + if self.strict_batch_size and len(items_idxs) < self.batch_size: + logger.debug("Extending batch to max size") + items_idxs.extend(shard_indices[0 : self.batch_size - len(items)]) + self.iteration += 1 + items = [self.data[idx] for idx in items_idxs] + yield items + + # some shards may done iterating while the others are at the last batch. 
Just return the first batch + while self.iteration < max_iterations: + logger.debug("Fulfilling non complete shard=".format(self.shard_id)) + self.iteration += 1 + items_idxs = shard_indices[0 : self.batch_size] + items = [self.data[idx] for idx in items_idxs] + yield items + + logger.info( + "Finished iterating, iteration={}, shard={}".format( + self.iteration, self.shard_id + ) + ) + # reset the iteration status + self.iteration = 0 + + def iterate_ds_sampled_data( + self, num_iterations: int, epoch: int = 0 + ) -> Iterator[List]: + self.iteration = 0 + shard_indices = self.get_shard_indices(epoch) + cycle_it = itertools.cycle(shard_indices) + for i in range(num_iterations): + items_idxs = [next(cycle_it) for _ in range(self.batch_size)] + self.iteration += 1 + items = [self.data[idx] for idx in items_idxs] + yield items + + logger.info( + "Finished iterating, iteration={}, shard={}".format( + self.iteration, self.shard_id + ) + ) + # TODO: reset the iteration status? + self.iteration = 0 + + def get_dataset(self) -> torch.utils.data.Dataset: + return self.data + + +class MultiSetDataIterator(object): + """ + Iterator over multiple data sources. Useful when all samples form a single batch should be from the same dataset. + """ + + def __init__( + self, + datasets: List[ShardedDataIterator], + shuffle_seed: int = 0, + shuffle=True, + sampling_rates: List = [], + rank: int = 0, + ): + self.iterables = datasets + data_lengths = [it.total_data_len() for it in datasets] + self.total_data = sum(data_lengths) + logger.info("rank=%d; Multi set data sizes %s", rank, data_lengths) + logger.info("rank=%d; Multi set total data %s", rank, self.total_data) + logger.info("rank=%d; Multi set sampling_rates %s", rank, sampling_rates) + self.shuffle_seed = shuffle_seed + self.shuffle = shuffle + self.iteration = 0 + self.rank = rank + + if sampling_rates: + self.max_its_pr_ds = [ + int(ds.max_iterations_num() * sampling_rates[i]) + for i, ds in enumerate(datasets) + ] + else: + self.max_its_pr_ds = [ds.max_iterations_num() for ds in datasets] + + self.max_iterations = sum(self.max_its_pr_ds) + logger.info( + "rank=%d; Multi set max_iterations per dataset %s", rank, self.max_its_pr_ds + ) + logger.info("rank=%d; Multi set max_iterations %d", rank, self.max_iterations) + + def total_data_len(self) -> int: + return self.total_data + + def get_max_iterations(self): + return self.max_iterations + + def iterate_ds_data(self, epoch: int = 0) -> Iterator[Tuple[List, int]]: + + logger.info("rank=%d; Iteration start", self.rank) + logger.info( + "rank=%d; Multi set iteration: iteration ptr per set: %s", + self.rank, + [it.get_iteration() for it in self.iterables], + ) + + data_src_indices = [] + iterators = [] + for source, src_its in enumerate(self.max_its_pr_ds): + logger.info( + "rank=%d; Multi set iteration: source %d, batches to be taken: %s", + self.rank, + source, + src_its, + ) + data_src_indices.extend([source] * src_its) + + iterators.append( + self.iterables[source].iterate_ds_sampled_data(src_its, epoch=epoch) + ) + + if self.shuffle: + # to be able to resume, same shuffling should be used when starting from a failed/stopped iteration + epoch_rnd = random.Random(self.shuffle_seed + epoch) + epoch_rnd.shuffle(data_src_indices) + + logger.info( + "rank=%d; data_src_indices len=%d", self.rank, len(data_src_indices) + ) + for i, source_idx in enumerate(data_src_indices): + it = iterators[source_idx] + next_item = next(it, None) + if next_item is not None: + self.iteration += 1 + yield (next_item, 
source_idx) + else: + logger.warning( + "rank=%d; Next item in the source %s is None", self.rank, source_idx + ) + + logger.info("rank=%d; last iteration %d", self.rank, self.iteration) + + logger.info( + "rank=%d; Multi set iteration finished: iteration per set: %s", + self.rank, + [it.iteration for it in self.iterables], + ) + [next(it, None) for it in iterators] + + # TODO: clear iterators in some non-hacky way + for it in self.iterables: + it.iteration = 0 + logger.info( + "rank=%d; Multi set iteration finished after next: iteration per set: %s", + self.rank, + [it.iteration for it in self.iterables], + ) + # reset the iteration status + self.iteration = 0 + + def get_iteration(self) -> int: + return self.iteration + + def get_dataset(self, ds_id: int) -> torch.utils.data.Dataset: + return self.iterables[ds_id].get_dataset() + + def get_datasets(self) -> List[torch.utils.data.Dataset]: + return [it.get_dataset() for it in self.iterables] + + +class Tensorizer(object): + """ + Component for all text to model input data conversions and related utility methods + """ + + # Note: title, if present, is supposed to be put before text (i.e. optional title + document body) + def text_to_tensor( + self, + text: str, + title: str = None, + add_special_tokens: bool = True, + apply_max_len: bool = True, + ): + raise NotImplementedError + + def get_pair_separator_ids(self) -> T: + raise NotImplementedError + + def get_pad_id(self) -> int: + raise NotImplementedError + + def get_attn_mask(self, tokens_tensor: T): + raise NotImplementedError + + def is_sub_word_id(self, token_id: int): + raise NotImplementedError + + def to_string(self, token_ids, skip_special_tokens=True): + raise NotImplementedError + + def set_pad_to_max(self, pad: bool): + raise NotImplementedError + + def get_token_id(self, token: str) -> int: + raise NotImplementedError diff --git a/research/information_retrieval/DPR/dpr/utils/dist_utils.py b/research/information_retrieval/DPR/dpr/utils/dist_utils.py new file mode 100644 index 00000000000..3b0bf85c3a5 --- /dev/null +++ b/research/information_retrieval/DPR/dpr/utils/dist_utils.py @@ -0,0 +1,96 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +Utilities for distributed model training +""" + +import pickle + +import torch +import torch.distributed as dist + + +def get_rank(): + return dist.get_rank() + + +def get_world_size(): + return dist.get_world_size() + + +def get_default_group(): + return dist.group.WORLD + + +def all_reduce(tensor, group=None): + if group is None: + group = get_default_group() + return dist.all_reduce(tensor, group=group) + + +def all_gather_list(data, group=None, max_size=16384): + """Gathers arbitrary data from all nodes into a list. + Similar to :func:`~torch.distributed.all_gather` but for arbitrary Python + data. Note that *data* must be picklable. 
+ Args: + data (Any): data from the local worker to be gathered on other workers + group (optional): group of the collective + """ + SIZE_STORAGE_BYTES = 4 # int32 to encode the payload size + + enc = pickle.dumps(data) + enc_size = len(enc) + + if enc_size + SIZE_STORAGE_BYTES > max_size: + raise ValueError( + 'encoded data exceeds max_size, this can be fixed by increasing buffer size: {}'.format(enc_size)) + + rank = get_rank() + world_size = get_world_size() + buffer_size = max_size * world_size + + if not hasattr(all_gather_list, '_buffer') or \ + all_gather_list._buffer.numel() < buffer_size: + all_gather_list._buffer = torch.cuda.ByteTensor(buffer_size) + all_gather_list._cpu_buffer = torch.ByteTensor(max_size).pin_memory() + + buffer = all_gather_list._buffer + buffer.zero_() + cpu_buffer = all_gather_list._cpu_buffer + + assert enc_size < 256 ** SIZE_STORAGE_BYTES, 'Encoded object size should be less than {} bytes'.format( + 256 ** SIZE_STORAGE_BYTES) + + size_bytes = enc_size.to_bytes(SIZE_STORAGE_BYTES, byteorder='big') + + cpu_buffer[0:SIZE_STORAGE_BYTES] = torch.ByteTensor(list(size_bytes)) + cpu_buffer[SIZE_STORAGE_BYTES: enc_size + SIZE_STORAGE_BYTES] = torch.ByteTensor(list(enc)) + + start = rank * max_size + size = enc_size + SIZE_STORAGE_BYTES + buffer[start: start + size].copy_(cpu_buffer[:size]) + + all_reduce(buffer, group=group) + + try: + result = [] + for i in range(world_size): + out_buffer = buffer[i * max_size: (i + 1) * max_size] + size = int.from_bytes(out_buffer[0:SIZE_STORAGE_BYTES], byteorder='big') + if size > 0: + result.append(pickle.loads(bytes(out_buffer[SIZE_STORAGE_BYTES: size + SIZE_STORAGE_BYTES].tolist()))) + return result + except pickle.UnpicklingError: + raise Exception( + 'Unable to unpickle data from other workers. all_gather_list requires all ' + 'workers to enter the function together, so this error usually indicates ' + 'that the workers have fallen out of sync somehow. Workers can fall out of ' + 'sync if one of them runs out of memory, or if there are other conditions ' + 'in your training script that can cause one worker to finish an epoch ' + 'while other workers are still iterating over their portions of the data.' + ) diff --git a/research/information_retrieval/DPR/dpr/utils/model_utils.py b/research/information_retrieval/DPR/dpr/utils/model_utils.py new file mode 100644 index 00000000000..ec8455f5620 --- /dev/null +++ b/research/information_retrieval/DPR/dpr/utils/model_utils.py @@ -0,0 +1,174 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
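+# Model-related helpers shared by the DPR training and inference scripts:
+# checkpoint (de)serialization, distributed/FP16 setup, device placement,
+# linear warmup LR scheduling, and weight initialization utilities.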
+ +import collections +import glob +import logging +import os +from typing import List + +import torch +from torch import nn +from torch.optim.lr_scheduler import LambdaLR +from torch.serialization import default_restore_location + +logger = logging.getLogger() + +CheckpointState = collections.namedtuple( + "CheckpointState", + [ + "model_dict", + "optimizer_dict", + "scheduler_dict", + "offset", + "epoch", + "encoder_params", + ], +) + + +def setup_for_distributed_mode( + model: nn.Module, + optimizer: torch.optim.Optimizer, + device: object, + n_gpu: int = 1, + local_rank: int = -1, + fp16: bool = False, + fp16_opt_level: str = "O1", +) -> (nn.Module, torch.optim.Optimizer): + model.to(device) + if fp16: + try: + import apex + from apex import amp + + apex.amp.register_half_function(torch, "einsum") + except ImportError: + raise ImportError( + "Please install apex from https://www.github.com/nvidia/apex to use fp16 training." + ) + + model, optimizer = amp.initialize(model, optimizer, opt_level=fp16_opt_level) + + if n_gpu > 1: + model = torch.nn.DataParallel(model) + + if local_rank != -1: + model = torch.nn.parallel.DistributedDataParallel( + model, + device_ids=[device if device else local_rank], + output_device=local_rank, + find_unused_parameters=True, + ) + return model, optimizer + + +def move_to_cuda(sample): + if len(sample) == 0: + return {} + + def _move_to_cuda(maybe_tensor): + if torch.is_tensor(maybe_tensor): + return maybe_tensor.cuda() + elif isinstance(maybe_tensor, dict): + return {key: _move_to_cuda(value) for key, value in maybe_tensor.items()} + elif isinstance(maybe_tensor, list): + return [_move_to_cuda(x) for x in maybe_tensor] + elif isinstance(maybe_tensor, tuple): + return [_move_to_cuda(x) for x in maybe_tensor] + else: + return maybe_tensor + + return _move_to_cuda(sample) + + +def move_to_device(sample, device): + if len(sample) == 0: + return {} + + def _move_to_device(maybe_tensor, device): + if torch.is_tensor(maybe_tensor): + return maybe_tensor.to(device) + elif isinstance(maybe_tensor, dict): + return { + key: _move_to_device(value, device) + for key, value in maybe_tensor.items() + } + elif isinstance(maybe_tensor, list): + return [_move_to_device(x, device) for x in maybe_tensor] + elif isinstance(maybe_tensor, tuple): + return [_move_to_device(x, device) for x in maybe_tensor] + else: + return maybe_tensor + + return _move_to_device(sample, device) + + +def get_schedule_linear( + optimizer, + warmup_steps, + total_training_steps, + steps_shift=0, + last_epoch=-1, +): + + """Create a schedule with a learning rate that decreases linearly after + linearly increasing during a warmup period. 
+ """ + + def lr_lambda(current_step): + current_step += steps_shift + if current_step < warmup_steps: + return float(current_step) / float(max(1, warmup_steps)) + return max( + 1e-7, + float(total_training_steps - current_step) + / float(max(1, total_training_steps - warmup_steps)), + ) + + return LambdaLR(optimizer, lr_lambda, last_epoch) + + +def init_weights(modules: List): + for module in modules: + if isinstance(module, (nn.Linear, nn.Embedding)): + module.weight.data.normal_(mean=0.0, std=0.02) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + +def get_model_obj(model: nn.Module): + return model.module if hasattr(model, "module") else model + + +def get_model_file(args, file_prefix) -> str: + if args.model_file and os.path.exists(args.model_file): + return args.model_file + + out_cp_files = ( + glob.glob(os.path.join(args.output_dir, file_prefix + "*")) + if args.output_dir + else [] + ) + logger.info("Checkpoint files %s", out_cp_files) + model_file = None + + if len(out_cp_files) > 0: + model_file = max(out_cp_files, key=os.path.getctime) + return model_file + + +def load_states_from_checkpoint(model_file: str) -> CheckpointState: + logger.info("Reading saved model from %s", model_file) + state_dict = torch.load( + model_file, map_location=lambda s, l: default_restore_location(s, "cpu") + ) + logger.info("model_state_dict keys %s", state_dict.keys()) + return CheckpointState(**state_dict) diff --git a/research/information_retrieval/DPR/dpr/utils/tokenizers.py b/research/information_retrieval/DPR/dpr/utils/tokenizers.py new file mode 100644 index 00000000000..c7fa30d2fb1 --- /dev/null +++ b/research/information_retrieval/DPR/dpr/utils/tokenizers.py @@ -0,0 +1,241 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +""" +Most of the tokenizers code here is copied from DrQA codebase to avoid adding extra dependency +""" + +import copy +import logging + +import regex +import spacy + +logger = logging.getLogger(__name__) + + +class Tokens(object): + """A class to represent a list of tokenized text.""" + TEXT = 0 + TEXT_WS = 1 + SPAN = 2 + POS = 3 + LEMMA = 4 + NER = 5 + + def __init__(self, data, annotators, opts=None): + self.data = data + self.annotators = annotators + self.opts = opts or {} + + def __len__(self): + """The number of tokens.""" + return len(self.data) + + def slice(self, i=None, j=None): + """Return a view of the list of tokens from [i, j).""" + new_tokens = copy.copy(self) + new_tokens.data = self.data[i: j] + return new_tokens + + def untokenize(self): + """Returns the original text (with whitespace reinserted).""" + return ''.join([t[self.TEXT_WS] for t in self.data]).strip() + + def words(self, uncased=False): + """Returns a list of the text of each token + + Args: + uncased: lower cases text + """ + if uncased: + return [t[self.TEXT].lower() for t in self.data] + else: + return [t[self.TEXT] for t in self.data] + + def offsets(self): + """Returns a list of [start, end) character offsets of each token.""" + return [t[self.SPAN] for t in self.data] + + def pos(self): + """Returns a list of part-of-speech tags of each token. + Returns None if this annotation was not included. 
+ """ + if 'pos' not in self.annotators: + return None + return [t[self.POS] for t in self.data] + + def lemmas(self): + """Returns a list of the lemmatized text of each token. + Returns None if this annotation was not included. + """ + if 'lemma' not in self.annotators: + return None + return [t[self.LEMMA] for t in self.data] + + def entities(self): + """Returns a list of named-entity-recognition tags of each token. + Returns None if this annotation was not included. + """ + if 'ner' not in self.annotators: + return None + return [t[self.NER] for t in self.data] + + def ngrams(self, n=1, uncased=False, filter_fn=None, as_strings=True): + """Returns a list of all ngrams from length 1 to n. + + Args: + n: upper limit of ngram length + uncased: lower cases text + filter_fn: user function that takes in an ngram list and returns + True or False to keep or not keep the ngram + as_string: return the ngram as a string vs list + """ + + def _skip(gram): + if not filter_fn: + return False + return filter_fn(gram) + + words = self.words(uncased) + ngrams = [(s, e + 1) + for s in range(len(words)) + for e in range(s, min(s + n, len(words))) + if not _skip(words[s:e + 1])] + + # Concatenate into strings + if as_strings: + ngrams = ['{}'.format(' '.join(words[s:e])) for (s, e) in ngrams] + + return ngrams + + def entity_groups(self): + """Group consecutive entity tokens with the same NER tag.""" + entities = self.entities() + if not entities: + return None + non_ent = self.opts.get('non_ent', 'O') + groups = [] + idx = 0 + while idx < len(entities): + ner_tag = entities[idx] + # Check for entity tag + if ner_tag != non_ent: + # Chomp the sequence + start = idx + while (idx < len(entities) and entities[idx] == ner_tag): + idx += 1 + groups.append((self.slice(start, idx).untokenize(), ner_tag)) + else: + idx += 1 + return groups + + +class Tokenizer(object): + """Base tokenizer class. + Tokenizers implement tokenize, which should return a Tokens class. + """ + + def tokenize(self, text): + raise NotImplementedError + + def shutdown(self): + pass + + def __del__(self): + self.shutdown() + + +class SimpleTokenizer(Tokenizer): + ALPHA_NUM = r'[\p{L}\p{N}\p{M}]+' + NON_WS = r'[^\p{Z}\p{C}]' + + def __init__(self, **kwargs): + """ + Args: + annotators: None or empty set (only tokenizes). + """ + self._regexp = regex.compile( + '(%s)|(%s)' % (self.ALPHA_NUM, self.NON_WS), + flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE + ) + if len(kwargs.get('annotators', {})) > 0: + logger.warning('%s only tokenizes! Skipping annotators: %s' % + (type(self).__name__, kwargs.get('annotators'))) + self.annotators = set() + + def tokenize(self, text): + data = [] + matches = [m for m in self._regexp.finditer(text)] + for i in range(len(matches)): + # Get text + token = matches[i].group() + + # Get whitespace + span = matches[i].span() + start_ws = span[0] + if i + 1 < len(matches): + end_ws = matches[i + 1].span()[0] + else: + end_ws = span[1] + + # Format data + data.append(( + token, + text[start_ws: end_ws], + span, + )) + return Tokens(data, self.annotators) + + +class SpacyTokenizer(Tokenizer): + + def __init__(self, **kwargs): + """ + Args: + annotators: set that can include pos, lemma, and ner. + model: spaCy model to use (either path, or keyword like 'en'). 
+ """ + model = kwargs.get("model", "en_core_web_sm") + self.annotators = copy.deepcopy(kwargs.get('annotators', set())) + nlp_kwargs = {'parser': False} + if not any([p in self.annotators for p in ['lemma', 'pos', 'ner']]): + nlp_kwargs['tagger'] = False + if 'ner' not in self.annotators: + nlp_kwargs['entity'] = False + self.nlp = spacy.load(model, **nlp_kwargs) + + def tokenize(self, text): + # We don't treat new lines as tokens. + clean_text = text.replace('\n', ' ') + tokens = self.nlp.tokenizer(clean_text) + if any([p in self.annotators for p in ['lemma', 'pos', 'ner']]): + self.nlp.tagger(tokens) + if 'ner' in self.annotators: + self.nlp.entity(tokens) + + data = [] + for i in range(len(tokens)): + # Get whitespace + start_ws = tokens[i].idx + if i + 1 < len(tokens): + end_ws = tokens[i + 1].idx + else: + end_ws = tokens[i].idx + len(tokens[i].text) + + data.append(( + tokens[i].text, + text[start_ws: end_ws], + (tokens[i].idx, tokens[i].idx + len(tokens[i].text)), + tokens[i].tag_, + tokens[i].lemma_, + tokens[i].ent_type_, + )) + + # Set special option for non-entity tag: '' vs 'O' in spaCy + return Tokens(data, self.annotators, opts={'non_ent': ''}) diff --git a/research/information_retrieval/DPR/generate_dense_embeddings.py b/research/information_retrieval/DPR/generate_dense_embeddings.py new file mode 100644 index 00000000000..fe7b9b6f91a --- /dev/null +++ b/research/information_retrieval/DPR/generate_dense_embeddings.py @@ -0,0 +1,172 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" + Command line tool that produces embeddings for a large documents base based on the pretrained ctx & question encoders + Supposed to be used in a 'sharded' way to speed up the process. 
+""" +import logging +import math +import os +import pathlib +import pickle +from typing import List, Tuple + +import hydra +import numpy as np +import torch +from omegaconf import DictConfig, OmegaConf +from torch import nn + +from dpr.data.biencoder_data import BiEncoderPassage +from dpr.models import init_biencoder_components +from dpr.options import set_cfg_params_from_state, setup_cfg_gpu, setup_logger + +from dpr.utils.data_utils import Tensorizer +from dpr.utils.model_utils import ( + setup_for_distributed_mode, + get_model_obj, + load_states_from_checkpoint, + move_to_device, +) + +logger = logging.getLogger() +setup_logger(logger) + + +def gen_ctx_vectors( + cfg: DictConfig, + ctx_rows: List[Tuple[object, BiEncoderPassage]], + model: nn.Module, + tensorizer: Tensorizer, + insert_title: bool = True, +) -> List[Tuple[object, np.array]]: + n = len(ctx_rows) + bsz = cfg.batch_size + total = 0 + results = [] + for j, batch_start in enumerate(range(0, n, bsz)): + batch = ctx_rows[batch_start : batch_start + bsz] + batch_token_tensors = [ + tensorizer.text_to_tensor( + ctx[1].text, title=ctx[1].title if insert_title else None + ) + for ctx in batch + ] + + ctx_ids_batch = move_to_device( + torch.stack(batch_token_tensors, dim=0), cfg.device + ) + ctx_seg_batch = move_to_device(torch.zeros_like(ctx_ids_batch), cfg.device) + ctx_attn_mask = move_to_device( + tensorizer.get_attn_mask(ctx_ids_batch), cfg.device + ) + with torch.no_grad(): + _, out, _ = model(ctx_ids_batch, ctx_seg_batch, ctx_attn_mask) + out = out.cpu() + + ctx_ids = [r[0] for r in batch] + extra_info = [] + if len(batch[0]) > 3: + extra_info = [r[3:] for r in batch] + + assert len(ctx_ids) == out.size(0) + total += len(ctx_ids) + + # TODO: refactor to avoid 'if' + if extra_info: + results.extend( + [ + (ctx_ids[i], out[i].view(-1).numpy(), *extra_info[i]) + for i in range(out.size(0)) + ] + ) + else: + results.extend( + [(ctx_ids[i], out[i].view(-1).numpy()) for i in range(out.size(0))] + ) + + if total % 10 == 0: + logger.info("Encoded passages %d", total) + return results + +@hydra.main(config_path="conf", config_name="gen_embs") +def main(cfg: DictConfig): + + assert cfg.model_file, "Please specify encoder checkpoint as model_file param" + assert cfg.ctx_src, "Please specify passages source as ctx_src param" + print(os.getcwd()) + cfg = setup_cfg_gpu(cfg) + + saved_state = load_states_from_checkpoint(cfg.model_file) + set_cfg_params_from_state(saved_state.encoder_params, cfg) + + logger.info("CFG:") + logger.info("%s", OmegaConf.to_yaml(cfg)) + + tensorizer, encoder, _ = init_biencoder_components( + cfg.encoder.encoder_model_type, cfg, inference_only=True + ) + + encoder = encoder.ctx_model if cfg.encoder_type == "ctx" else encoder.question_model + + encoder, _ = setup_for_distributed_mode( + encoder, + None, + cfg.device, + cfg.n_gpu, + cfg.local_rank, + cfg.fp16, + cfg.fp16_opt_level, + ) + encoder.eval() + + # load weights from the model file + model_to_load = get_model_obj(encoder) + logger.info("Loading saved model state ...") + logger.debug("saved model keys =%s", saved_state.model_dict.keys()) + + prefix_len = len("ctx_model.") + ctx_state = { + key[prefix_len:]: value + for (key, value) in saved_state.model_dict.items() + if key.startswith("ctx_model.") + } + model_to_load.load_state_dict(ctx_state) + + logger.info("reading data source: %s", cfg.ctx_src) + + ctx_src = hydra.utils.instantiate(cfg.ctx_sources[cfg.ctx_src]) + all_passages_dict = {} + ctx_src.load_data_to(all_passages_dict) + all_passages = [(k, v) 
for k, v in all_passages_dict.items()] + + shard_size = math.ceil(len(all_passages) / cfg.num_shards) + start_idx = cfg.shard_id * shard_size + end_idx = start_idx + shard_size + + logger.info( + "Producing encodings for passages range: %d to %d (out of total %d)", + start_idx, + end_idx, + len(all_passages), + ) + shard_passages = all_passages[start_idx:end_idx] + + data = gen_ctx_vectors(cfg, shard_passages, encoder, tensorizer, True) + + file = cfg.out_file + "_" + str(cfg.shard_id) + pathlib.Path(os.path.dirname(file)).mkdir(parents=True, exist_ok=True) + logger.info("Writing results to %s" % file) + with open(file, mode="wb") as f: + pickle.dump(data, f) + + logger.info("Total passages processed %d. Written to %s", len(data), file) + + +if __name__ == "__main__": + main() diff --git a/research/information_retrieval/DPR/model_config.yml b/research/information_retrieval/DPR/model_config.yml new file mode 100644 index 00000000000..93ee638ac84 --- /dev/null +++ b/research/information_retrieval/DPR/model_config.yml @@ -0,0 +1,24 @@ +# @package _group_ + +# model type. One of [hf_bert, pytext_bert, fairseq_roberta] +encoder_model_type: hf_bert + +# HuggingFace's config name for model initialization +pretrained_model_cfg: bert-base-uncased + +# Some encoders need to be initialized from a file +pretrained_file: + +# Extra linear layer on top of standard bert/roberta encoder +projection_dim: 0 + +# Max length of the encoder input sequence +sequence_length: 256 + +dropout: 0.1 + +# whether to fix (don't update) context encoder during training or not +fix_ctx_encoder: False + +# if False, the model won't load pre-trained BERT weights +pretrained: True \ No newline at end of file diff --git a/research/information_retrieval/DPR/ms_marco_eval.py b/research/information_retrieval/DPR/ms_marco_eval.py new file mode 100644 index 00000000000..2ca08902c54 --- /dev/null +++ b/research/information_retrieval/DPR/ms_marco_eval.py @@ -0,0 +1,177 @@ +""" +This module computes evaluation metrics for MSMARCO dataset on the ranking task. Intenral hard coded eval files version. DO NOT PUBLISH! +Command line: +python msmarco_eval_ranking.py + +Creation Date : 06/12/2018 +Last Modified : 4/09/2019 +Authors : Daniel Campos , Rutger van Haasteren +""" +import sys +import statistics + +from collections import Counter + +MaxMRRRank = 10 + +def load_reference_from_stream(f): + """Load Reference reference relevant passages + Args:f (stream): stream to load. + Returns:qids_to_relevant_passageids (dict): dictionary mapping from query_id (int) to relevant passages (list of ints). + """ + qids_to_relevant_passageids = {} + for l in f: + try: + l = l.strip().split('\t') + qid = int(l[0]) + if qid in qids_to_relevant_passageids: + pass + else: + qids_to_relevant_passageids[qid] = [] + qids_to_relevant_passageids[qid].append(int(l[2])) + except: + raise IOError('\"%s\" is not valid format' % l) + return qids_to_relevant_passageids + +def load_reference(path_to_reference): + """Load Reference reference relevant passages + Args:path_to_reference (str): path to a file to load. + Returns:qids_to_relevant_passageids (dict): dictionary mapping from query_id (int) to relevant passages (list of ints). + """ + with open(path_to_reference,'r') as f: + qids_to_relevant_passageids = load_reference_from_stream(f) + return qids_to_relevant_passageids + +def load_candidate_from_stream(f): + """Load candidate data from a stream. + Args:f (stream): stream to load. 
+ Returns:qid_to_ranked_candidate_passages (dict): dictionary mapping from query_id (int) to a list of 1000 passage ids(int) ranked by relevance and importance + """ + qid_to_ranked_candidate_passages = {} + for l in f: + try: + l = l.strip().split('\t') + qid = int(l[0]) + pid = int(l[1]) + rank = int(l[2]) + if qid in qid_to_ranked_candidate_passages: + pass + else: + # By default, all PIDs in the list of 1000 are 0. Only override those that are given + tmp = [0] * 1000 + qid_to_ranked_candidate_passages[qid] = tmp + qid_to_ranked_candidate_passages[qid][rank-1]=pid + except: + raise IOError('\"%s\" is not valid format' % l) + return qid_to_ranked_candidate_passages + +def load_candidate(path_to_candidate): + """Load candidate data from a file. + Args:path_to_candidate (str): path to file to load. + Returns:qid_to_ranked_candidate_passages (dict): dictionary mapping from query_id (int) to a list of 1000 passage ids(int) ranked by relevance and importance + """ + + with open(path_to_candidate,'r') as f: + qid_to_ranked_candidate_passages = load_candidate_from_stream(f) + return qid_to_ranked_candidate_passages + +def quality_checks_qids(qids_to_relevant_passageids, qids_to_ranked_candidate_passages): + """Perform quality checks on the dictionaries + + Args: + p_qids_to_relevant_passageids (dict): dictionary of query-passage mapping + Dict as read in with load_reference or load_reference_from_stream + p_qids_to_ranked_candidate_passages (dict): dictionary of query-passage candidates + Returns: + bool,str: Boolean whether allowed, message to be shown in case of a problem + """ + message = '' + allowed = True + + # Create sets of the QIDs for the submitted and reference queries + candidate_set = set(qids_to_ranked_candidate_passages.keys()) + ref_set = set(qids_to_relevant_passageids.keys()) + + # Check that we do not have multiple passages per query + for qid in qids_to_ranked_candidate_passages: + # Remove all zeros from the candidates + duplicate_pids = set([item for item, count in Counter(qids_to_ranked_candidate_passages[qid]).items() if count > 1]) + + if len(duplicate_pids-set([0])) > 0: + message = "Cannot rank a passage multiple times for a single query. QID={qid}, PID={pid}".format( + qid=qid, pid=list(duplicate_pids)[0]) + allowed = False + + return allowed, message + +def compute_metrics(qids_to_relevant_passageids, qids_to_ranked_candidate_passages): + """Compute MRR metric + Args: + p_qids_to_relevant_passageids (dict): dictionary of query-passage mapping + Dict as read in with load_reference or load_reference_from_stream + p_qids_to_ranked_candidate_passages (dict): dictionary of query-passage candidates + Returns: + dict: dictionary of metrics {'MRR': } + """ + all_scores = {} + MRR = 0 + qids_with_relevant_passages = 0 + ranking = [] + j = 0 + for qid in qids_to_ranked_candidate_passages: + if qid in qids_to_relevant_passageids: + ranking.append(0) + target_pid = qids_to_relevant_passageids[qid] + candidate_pid = qids_to_ranked_candidate_passages[qid] + for i in range(0,MaxMRRRank): + if candidate_pid[i] in target_pid: + MRR += 1/(i + 1) + j += 1 + ranking.pop() + ranking.append(i+1) + break + if len(ranking) == 0: + raise IOError("No matching QIDs found. 
 Are you sure you are scoring the evaluation set?")
+    denominator = len(qids_to_ranked_candidate_passages) #qids_to_relevant_passageids
+    MRR = MRR/denominator
+    all_scores['MRR @10'] = MRR
+    all_scores['Recall @10'] = j/denominator
+    all_scores['QueriesRanked'] = denominator
+    return all_scores
+
+def compute_metrics_from_files(path_to_reference, path_to_candidate, perform_checks=True):
+    """Compute MRR metric
+    Args:
+    p_path_to_reference_file (str): path to reference file.
+        Reference file should contain lines in the following format:
+            QUERYID\tPASSAGEID
+            Where PASSAGEID is a relevant passage for a query. Note QUERYID can repeat on different lines with different PASSAGEIDs
+    p_path_to_candidate_file (str): path to candidate file.
+        Candidate file should contain lines in the following format:
+            QUERYID\tPASSAGEID1\tRank
+            If a user wishes to use the TREC format please run the script with a -t flag at the end. If this flag is used the expected format is
+            QUERYID\tITER\tDOCNO\tRANK\tSIM\tRUNID
+            Where the values are separated by tabs and ranked in order of relevance
+    Returns:
+        dict: dictionary of metrics {'MRR': }
+    """
+    qids_to_relevant_passageids = load_reference(path_to_reference)
+    qids_to_ranked_candidate_passages = load_candidate(path_to_candidate)
+    if perform_checks:
+        allowed, message = quality_checks_qids(qids_to_relevant_passageids, qids_to_ranked_candidate_passages)
+        if message != '': print(message)
+    return compute_metrics(qids_to_relevant_passageids, qids_to_ranked_candidate_passages)
+
+def main():
+    """Command line:
+    python msmarco_eval_ranking.py
+    """
+    path_to_candidate = sys.argv[2]
+    path_to_reference = sys.argv[1]
+    metrics = compute_metrics_from_files(path_to_reference, path_to_candidate)
+    print('#####################')
+    for metric in sorted(metrics):
+        print('{}: {}'.format(metric, metrics[metric]))
+    print('#####################')
+if __name__ == '__main__':
+    main()
\ No newline at end of file
diff --git a/research/information_retrieval/DPR/requirements.txt b/research/information_retrieval/DPR/requirements.txt
new file mode 100644
index 00000000000..247ca4dfdef
--- /dev/null
+++ b/research/information_retrieval/DPR/requirements.txt
@@ -0,0 +1,14 @@
+transformers
+torch
+faiss
+tqdm
+elasticsearch
+streamlit
+requests
+sparseml
+faiss-cpu
+filelock
+numpy
+regex
+spacy
+sparseml
\ No newline at end of file
diff --git a/research/information_retrieval/DPR/train_config.yml b/research/information_retrieval/DPR/train_config.yml
new file mode 100644
index 00000000000..69696df1660
--- /dev/null
+++ b/research/information_retrieval/DPR/train_config.yml
@@ -0,0 +1,27 @@
+# @package _group_
+
+batch_size: 4
+dev_batch_size: 16
+adam_eps: 1e-8
+adam_betas: (0.9, 0.999)
+max_grad_norm: 2.0
+log_batch_step: 1
+train_rolling_loss_step: 100
+weight_decay: 0.0
+learning_rate: 2e-5
+
+# Linear warmup over warmup_steps.
+warmup_steps: 1237
+
+# Number of updates steps to accumulate before performing a backward/update pass.
+gradient_accumulation_steps: 1
+
+# Total number of training epochs to perform.
+num_train_epochs: 40 +eval_per_epoch: 1 +hard_negatives: 1 +other_negatives: 0 +val_av_rank_hard_neg: 30 +val_av_rank_other_neg: 30 +val_av_rank_bsz: 128 +val_av_rank_max_qs: 10000 diff --git a/research/information_retrieval/DPR/train_dense_encoder.py b/research/information_retrieval/DPR/train_dense_encoder.py new file mode 100644 index 00000000000..73818e45c3e --- /dev/null +++ b/research/information_retrieval/DPR/train_dense_encoder.py @@ -0,0 +1,836 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +""" + Pipeline to train DPR Biencoder +""" + +import logging +import math +import os +import random +import sys +import time +from typing import Tuple + +import hydra +import torch +from omegaconf import DictConfig, OmegaConf +from torch import Tensor as T +from torch import nn + +from dpr.models import init_biencoder_components +from dpr.models.biencoder import BiEncoder, BiEncoderNllLoss, BiEncoderBatch +from dpr.options import ( + setup_cfg_gpu, + set_seed, + get_encoder_params_state_from_cfg, + set_cfg_params_from_state, + setup_logger, +) +from dpr.utils.conf_utils import BiencoderDatasetsCfg +from dpr.utils.data_utils import ( + ShardedDataIterator, + Tensorizer, + MultiSetDataIterator, +) +from dpr.utils.dist_utils import all_gather_list +from dpr.utils.model_utils import ( + setup_for_distributed_mode, + move_to_device, + get_schedule_linear, + CheckpointState, + get_model_file, + get_model_obj, + load_states_from_checkpoint, +) + +logger = logging.getLogger() +setup_logger(logger) + + +class BiEncoderTrainer(object): + """ + BiEncoder training pipeline component. Can be used to initiate or resume training and validate the trained model + using either binary classification's NLL loss or average rank of the question's gold passages across dataset + provided pools of negative passages. For full IR accuracy evaluation, please see generate_dense_embeddings.py + and dense_retriever.py CLI tools. 
+ """ + + def __init__(self, cfg: DictConfig): + self.shard_id = cfg.local_rank if cfg.local_rank != -1 else 0 + self.distributed_factor = cfg.distributed_world_size or 1 + + logger.info("***** Initializing components for training *****") + + # if model file is specified, encoder parameters from saved state should be used for initialization + model_file = get_model_file(cfg, cfg.checkpoint_file_name) + saved_state = None + if model_file: + saved_state = load_states_from_checkpoint(model_file) + set_cfg_params_from_state(saved_state.encoder_params, cfg) + + tensorizer, model, optimizer = init_biencoder_components( + cfg.encoder.encoder_model_type, cfg + ) + + model, optimizer = setup_for_distributed_mode( + model, + optimizer, + cfg.device, + cfg.n_gpu, + cfg.local_rank, + cfg.fp16, + cfg.fp16_opt_level, + ) + self.biencoder = model + self.optimizer = optimizer + self.tensorizer = tensorizer + self.start_epoch = 0 + self.start_batch = 0 + self.scheduler_state = None + self.best_validation_result = None + self.best_cp_name = None + self.cfg = cfg + self.ds_cfg = BiencoderDatasetsCfg(cfg) + + if saved_state: + self._load_saved_state(saved_state) + + self.dev_iterator = None + + def get_data_iterator( + self, + batch_size: int, + is_train_set: bool, + shuffle=True, + shuffle_seed: int = 0, + offset: int = 0, + rank: int = 0, + ): + + hydra_datasets = ( + self.ds_cfg.train_datasets if is_train_set else self.ds_cfg.dev_datasets + ) + sampling_rates = self.ds_cfg.sampling_rates + + logger.info( + "Initializing task/set data %s", + self.ds_cfg.train_datasets_names + if is_train_set + else self.ds_cfg.dev_datasets_names, + ) + + # randomized data loading to avoid file system congestion + datasets_list = [ds for ds in hydra_datasets] + rnd = random.Random(rank) + rnd.shuffle(datasets_list) + [ds.load_data() for ds in datasets_list] + + sharded_iterators = [ + ShardedDataIterator( + ds, + shard_id=self.shard_id, + num_shards=self.distributed_factor, + batch_size=batch_size, + shuffle=shuffle, + shuffle_seed=shuffle_seed, + offset=offset, + ) + for ds in hydra_datasets + ] + + return MultiSetDataIterator( + sharded_iterators, + shuffle_seed, + shuffle, + sampling_rates=sampling_rates if is_train_set else [1], + rank=rank, + ) + + def run_train(self): + cfg = self.cfg + + train_iterator = self.get_data_iterator( + cfg.train.batch_size, + True, + shuffle=True, + shuffle_seed=cfg.seed, + offset=self.start_batch, + rank=cfg.local_rank, + ) + max_iterations = train_iterator.get_max_iterations() + logger.info(" Total iterations per epoch=%d", max_iterations) + if max_iterations == 0: + logger.warning("No data found for training.") + return + + updates_per_epoch = ( + train_iterator.max_iterations // cfg.train.gradient_accumulation_steps + ) + + total_updates = updates_per_epoch * cfg.train.num_train_epochs + logger.info(" Total updates=%d", total_updates) + warmup_steps = cfg.train.warmup_steps + + if self.scheduler_state: + # TODO: ideally we'd want to just call + # scheduler.load_state_dict(self.scheduler_state) + # but it doesn't work properly as of now + + logger.info("Loading scheduler state %s", self.scheduler_state) + shift = int(self.scheduler_state["last_epoch"]) + logger.info("Steps shift %d", shift) + scheduler = get_schedule_linear( + self.optimizer, + warmup_steps, + total_updates, + steps_shift=shift, + ) + else: + scheduler = get_schedule_linear( + self.optimizer, warmup_steps, total_updates + ) + + eval_step = math.ceil(updates_per_epoch / cfg.train.eval_per_epoch) + logger.info(" Eval step 
= %d", eval_step) + logger.info("***** Training *****") + + for epoch in range(self.start_epoch, int(cfg.train.num_train_epochs)): + logger.info("***** Epoch %d *****", epoch) + self._train_epoch(scheduler, epoch, eval_step, train_iterator) + + if cfg.local_rank in [-1, 0]: + logger.info( + "Training finished. Best validation checkpoint %s", self.best_cp_name + ) + + def validate_and_save(self, epoch: int, iteration: int, scheduler): + cfg = self.cfg + # for distributed mode, save checkpoint for only one process + save_cp = cfg.local_rank in [-1, 0] + + if epoch == cfg.val_av_rank_start_epoch: + self.best_validation_result = None + + if not cfg.dev_datasets: + validation_loss = 0 + else: + if epoch >= cfg.val_av_rank_start_epoch: + validation_loss = self.validate_average_rank() + else: + validation_loss = self.validate_nll() + + if save_cp: + cp_name = self._save_checkpoint(scheduler, epoch, iteration) + logger.info("Saved checkpoint to %s", cp_name) + + if validation_loss < (self.best_validation_result or validation_loss + 1): + self.best_validation_result = validation_loss + self.best_cp_name = cp_name + logger.info("New Best validation checkpoint %s", cp_name) + + def validate_nll(self) -> float: + logger.info("NLL validation ...") + cfg = self.cfg + self.biencoder.eval() + + if not self.dev_iterator: + self.dev_iterator = self.get_data_iterator( + cfg.train.dev_batch_size, False, shuffle=False, rank=cfg.local_rank + ) + data_iterator = self.dev_iterator + + total_loss = 0.0 + start_time = time.time() + total_correct_predictions = 0 + num_hard_negatives = cfg.train.hard_negatives + num_other_negatives = cfg.train.other_negatives + log_result_step = cfg.train.log_batch_step + batches = 0 + dataset = 0 + + for i, samples_batch in enumerate(data_iterator.iterate_ds_data()): + if isinstance(samples_batch, Tuple): + samples_batch, dataset = samples_batch + logger.info("Eval step: %d ,rnk=%s", i, cfg.local_rank) + biencoder_input = BiEncoder.create_biencoder_input2( + samples_batch, + self.tensorizer, + True, + num_hard_negatives, + num_other_negatives, + shuffle=False, + ) + + # get the token to be used for representation selection + ds_cfg = self.ds_cfg.dev_datasets[dataset] + rep_positions = ds_cfg.selector.get_positions( + biencoder_input.question_ids, self.tensorizer + ) + encoder_type = ds_cfg.encoder_type + + loss, correct_cnt = _do_biencoder_fwd_pass( + self.biencoder, + biencoder_input, + self.tensorizer, + cfg, + encoder_type=encoder_type, + rep_positions=rep_positions, + ) + total_loss += loss.item() + total_correct_predictions += correct_cnt + batches += 1 + if (i + 1) % log_result_step == 0: + logger.info( + "Eval step: %d , used_time=%f sec., loss=%f ", + i, + time.time() - start_time, + loss.item(), + ) + + total_loss = total_loss / batches + total_samples = batches * cfg.train.dev_batch_size * self.distributed_factor + correct_ratio = float(total_correct_predictions / total_samples) + logger.info( + "NLL Validation: loss = %f. correct prediction ratio %d/%d ~ %f", + total_loss, + total_correct_predictions, + total_samples, + correct_ratio, + ) + return total_loss + + def validate_average_rank(self) -> float: + """ + Validates biencoder model using each question's gold passage's rank across the set of passages from the dataset. + It generates vectors for specified amount of negative passages from each question (see --val_av_rank_xxx params) + and stores them in RAM as well as question vectors. 
+ Then the similarity scores are calculted for the entire + num_questions x (num_questions x num_passages_per_question) matrix and sorted per quesrtion. + Each question's gold passage rank in that sorted list of scores is averaged across all the questions. + :return: averaged rank number + """ + logger.info("Average rank validation ...") + + cfg = self.cfg + self.biencoder.eval() + distributed_factor = self.distributed_factor + + if not self.dev_iterator: + self.dev_iterator = self.get_data_iterator( + cfg.train.dev_batch_size, False, shuffle=False, rank=cfg.local_rank + ) + data_iterator = self.dev_iterator + + sub_batch_size = cfg.train.val_av_rank_bsz + sim_score_f = BiEncoderNllLoss.get_similarity_function() + q_represenations = [] + ctx_represenations = [] + positive_idx_per_question = [] + + num_hard_negatives = cfg.train.val_av_rank_hard_neg + num_other_negatives = cfg.train.val_av_rank_other_neg + + log_result_step = cfg.train.log_batch_step + dataset = 0 + for i, samples_batch in enumerate(data_iterator.iterate_ds_data()): + # samples += 1 + if ( + len(q_represenations) + > cfg.train.val_av_rank_max_qs / distributed_factor + ): + break + + if isinstance(samples_batch, Tuple): + samples_batch, dataset = samples_batch + + biencoder_input = BiEncoder.create_biencoder_input2( + samples_batch, + self.tensorizer, + True, + num_hard_negatives, + num_other_negatives, + shuffle=False, + ) + total_ctxs = len(ctx_represenations) + ctxs_ids = biencoder_input.context_ids + ctxs_segments = biencoder_input.ctx_segments + bsz = ctxs_ids.size(0) + + # get the token to be used for representation selection + ds_cfg = self.ds_cfg.dev_datasets[dataset] + encoder_type = ds_cfg.encoder_type + rep_positions = ds_cfg.selector.get_positions( + biencoder_input.question_ids, self.tensorizer + ) + + # split contexts batch into sub batches since it is supposed to be too large to be processed in one batch + for j, batch_start in enumerate(range(0, bsz, sub_batch_size)): + + q_ids, q_segments = ( + (biencoder_input.question_ids, biencoder_input.question_segments) + if j == 0 + else (None, None) + ) + + if j == 0 and cfg.n_gpu > 1 and q_ids.size(0) == 1: + # if we are in DP (but not in DDP) mode, all model input tensors should have batch size >1 or 0, + # otherwise the other input tensors will be split but only the first split will be called + continue + + ctx_ids_batch = ctxs_ids[batch_start : batch_start + sub_batch_size] + ctx_seg_batch = ctxs_segments[ + batch_start : batch_start + sub_batch_size + ] + + q_attn_mask = self.tensorizer.get_attn_mask(q_ids) + ctx_attn_mask = self.tensorizer.get_attn_mask(ctx_ids_batch) + with torch.no_grad(): + q_dense, ctx_dense = self.biencoder( + q_ids, + q_segments, + q_attn_mask, + ctx_ids_batch, + ctx_seg_batch, + ctx_attn_mask, + encoder_type=encoder_type, + representation_token_pos=rep_positions, + ) + + if q_dense is not None: + q_represenations.extend(q_dense.cpu().split(1, dim=0)) + + ctx_represenations.extend(ctx_dense.cpu().split(1, dim=0)) + + batch_positive_idxs = biencoder_input.is_positive + positive_idx_per_question.extend( + [total_ctxs + v for v in batch_positive_idxs] + ) + + if (i + 1) % log_result_step == 0: + logger.info( + "Av.rank validation: step %d, computed ctx_vectors %d, q_vectors %d", + i, + len(ctx_represenations), + len(q_represenations), + ) + + ctx_represenations = torch.cat(ctx_represenations, dim=0) + q_represenations = torch.cat(q_represenations, dim=0) + + logger.info( + "Av.rank validation: total q_vectors size=%s", 
q_represenations.size() + ) + logger.info( + "Av.rank validation: total ctx_vectors size=%s", ctx_represenations.size() + ) + + q_num = q_represenations.size(0) + assert q_num == len(positive_idx_per_question) + + scores = sim_score_f(q_represenations, ctx_represenations) + values, indices = torch.sort(scores, dim=1, descending=True) + + rank = 0 + for i, idx in enumerate(positive_idx_per_question): + # aggregate the rank of the known gold passage in the sorted results for each question + gold_idx = (indices[i] == idx).nonzero() + rank += gold_idx.item() + + if distributed_factor > 1: + # each node calcuated its own rank, exchange the information between node and calculate the "global" average rank + # NOTE: the set of passages is still unique for every node + eval_stats = all_gather_list([rank, q_num], max_size=100) + for i, item in enumerate(eval_stats): + remote_rank, remote_q_num = item + if i != cfg.local_rank: + rank += remote_rank + q_num += remote_q_num + + av_rank = float(rank / q_num) + logger.info( + "Av.rank validation: average rank %s, total questions=%d", av_rank, q_num + ) + return av_rank + + def _train_epoch( + self, + scheduler, + epoch: int, + eval_step: int, + train_data_iterator: MultiSetDataIterator, + ): + + cfg = self.cfg + rolling_train_loss = 0.0 + epoch_loss = 0 + epoch_correct_predictions = 0 + + log_result_step = cfg.train.log_batch_step + rolling_loss_step = cfg.train.train_rolling_loss_step + num_hard_negatives = cfg.train.hard_negatives + num_other_negatives = cfg.train.other_negatives + seed = cfg.seed + self.biencoder.train() + epoch_batches = train_data_iterator.max_iterations + data_iteration = 0 + + dataset = 0 + for i, samples_batch in enumerate( + train_data_iterator.iterate_ds_data(epoch=epoch) + ): + if isinstance(samples_batch, Tuple): + samples_batch, dataset = samples_batch + + ds_cfg = self.ds_cfg.train_datasets[dataset] + special_token = ds_cfg.special_token + encoder_type = ds_cfg.encoder_type + shuffle_positives = ds_cfg.shuffle_positives + + # to be able to resume shuffled ctx- pools + data_iteration = train_data_iterator.get_iteration() + random.seed(seed + epoch + data_iteration) + + biencoder_batch = BiEncoder.create_biencoder_input2( + samples_batch, + self.tensorizer, + True, + num_hard_negatives, + num_other_negatives, + shuffle=True, + shuffle_positives=shuffle_positives, + query_token=special_token, + ) + + # get the token to be used for representation selection + from dpr.data.biencoder_data import DEFAULT_SELECTOR + + selector = ds_cfg.selector if ds_cfg else DEFAULT_SELECTOR + + rep_positions = selector.get_positions( + biencoder_batch.question_ids, self.tensorizer + ) + + loss_scale = ( + cfg.loss_scale_factors[dataset] if cfg.loss_scale_factors else None + ) + loss, correct_cnt = _do_biencoder_fwd_pass( + self.biencoder, + biencoder_batch, + self.tensorizer, + cfg, + encoder_type=encoder_type, + rep_positions=rep_positions, + loss_scale=loss_scale, + ) + + epoch_correct_predictions += correct_cnt + epoch_loss += loss.item() + rolling_train_loss += loss.item() + + if cfg.fp16: + from apex import amp + + with amp.scale_loss(loss, self.optimizer) as scaled_loss: + scaled_loss.backward() + if cfg.train.max_grad_norm > 0: + torch.nn.utils.clip_grad_norm_( + amp.master_params(self.optimizer), cfg.train.max_grad_norm + ) + else: + loss.backward() + if cfg.train.max_grad_norm > 0: + torch.nn.utils.clip_grad_norm_( + self.biencoder.parameters(), cfg.train.max_grad_norm + ) + + if (i + 1) % cfg.train.gradient_accumulation_steps == 0: + 
self.optimizer.step() + scheduler.step() + self.biencoder.zero_grad() + + if i % log_result_step == 0: + lr = self.optimizer.param_groups[0]["lr"] + logger.info( + "Epoch: %d: Step: %d/%d, loss=%f, lr=%f", + epoch, + data_iteration, + epoch_batches, + loss.item(), + lr, + ) + + if (i + 1) % rolling_loss_step == 0: + logger.info("Train batch %d", data_iteration) + latest_rolling_train_av_loss = rolling_train_loss / rolling_loss_step + logger.info( + "Avg. loss per last %d batches: %f", + rolling_loss_step, + latest_rolling_train_av_loss, + ) + rolling_train_loss = 0.0 + + if data_iteration % eval_step == 0: + logger.info( + "rank=%d, Validation: Epoch: %d Step: %d/%d", + cfg.local_rank, + epoch, + data_iteration, + epoch_batches, + ) + self.validate_and_save( + epoch, train_data_iterator.get_iteration(), scheduler + ) + self.biencoder.train() + + logger.info("Epoch finished on %d", cfg.local_rank) + self.validate_and_save(epoch, data_iteration, scheduler) + + epoch_loss = (epoch_loss / epoch_batches) if epoch_batches > 0 else 0 + logger.info("Av Loss per epoch=%f", epoch_loss) + logger.info("epoch total correct predictions=%d", epoch_correct_predictions) + + def _save_checkpoint(self, scheduler, epoch: int, offset: int) -> str: + cfg = self.cfg + model_to_save = get_model_obj(self.biencoder) + cp = os.path.join(cfg.output_dir, cfg.checkpoint_file_name + "." + str(epoch)) + meta_params = get_encoder_params_state_from_cfg(cfg) + state = CheckpointState( + model_to_save.get_state_dict(), + self.optimizer.state_dict(), + scheduler.state_dict(), + offset, + epoch, + meta_params, + ) + torch.save(state._asdict(), cp) + logger.info("Saved checkpoint at %s", cp) + return cp + + def _load_saved_state(self, saved_state: CheckpointState): + epoch = saved_state.epoch + # offset is currently ignored since all checkpoints are made after full epochs + offset = saved_state.offset + if offset == 0: # epoch has been completed + epoch += 1 + logger.info("Loading checkpoint @ batch=%s and epoch=%s", offset, epoch) + + if self.cfg.ignore_checkpoint_offset: + self.start_epoch = 0 + self.start_batch = 0 + else: + self.start_epoch = epoch + # TODO: offset doesn't work for multiset currently + self.start_batch = 0 # offset + + model_to_load = get_model_obj(self.biencoder) + logger.info("Loading saved model state ...") + + model_to_load.load_state(saved_state) + + if not self.cfg.ignore_checkpoint_optimizer: + if saved_state.optimizer_dict: + logger.info("Loading saved optimizer state ...") + self.optimizer.load_state_dict(saved_state.optimizer_dict) + + if saved_state.scheduler_dict: + self.scheduler_state = saved_state.scheduler_dict + + +def _calc_loss( + cfg, + loss_function, + local_q_vector, + local_ctx_vectors, + local_positive_idxs, + local_hard_negatives_idxs: list = None, + loss_scale: float = None, +) -> Tuple[T, bool]: + """ + Calculates In-batch negatives schema loss and supports to run it in DDP mode by exchanging the representations + across all the nodes. 
+ """ + distributed_world_size = cfg.distributed_world_size or 1 + if distributed_world_size > 1: + q_vector_to_send = ( + torch.empty_like(local_q_vector).cpu().copy_(local_q_vector).detach_() + ) + ctx_vector_to_send = ( + torch.empty_like(local_ctx_vectors).cpu().copy_(local_ctx_vectors).detach_() + ) + + global_question_ctx_vectors = all_gather_list( + [ + q_vector_to_send, + ctx_vector_to_send, + local_positive_idxs, + local_hard_negatives_idxs, + ], + max_size=cfg.global_loss_buf_sz, + ) + + global_q_vector = [] + global_ctxs_vector = [] + + # ctxs_per_question = local_ctx_vectors.size(0) + positive_idx_per_question = [] + hard_negatives_per_question = [] + + total_ctxs = 0 + + for i, item in enumerate(global_question_ctx_vectors): + q_vector, ctx_vectors, positive_idx, hard_negatives_idxs = item + + if i != cfg.local_rank: + global_q_vector.append(q_vector.to(local_q_vector.device)) + global_ctxs_vector.append(ctx_vectors.to(local_q_vector.device)) + positive_idx_per_question.extend([v + total_ctxs for v in positive_idx]) + hard_negatives_per_question.extend( + [[v + total_ctxs for v in l] for l in hard_negatives_idxs] + ) + else: + global_q_vector.append(local_q_vector) + global_ctxs_vector.append(local_ctx_vectors) + positive_idx_per_question.extend( + [v + total_ctxs for v in local_positive_idxs] + ) + hard_negatives_per_question.extend( + [[v + total_ctxs for v in l] for l in local_hard_negatives_idxs] + ) + total_ctxs += ctx_vectors.size(0) + global_q_vector = torch.cat(global_q_vector, dim=0) + global_ctxs_vector = torch.cat(global_ctxs_vector, dim=0) + + else: + global_q_vector = local_q_vector + global_ctxs_vector = local_ctx_vectors + positive_idx_per_question = local_positive_idxs + hard_negatives_per_question = local_hard_negatives_idxs + + loss, is_correct = loss_function.calc( + global_q_vector, + global_ctxs_vector, + positive_idx_per_question, + hard_negatives_per_question, + loss_scale=loss_scale, + ) + + return loss, is_correct + + +def _do_biencoder_fwd_pass( + model: nn.Module, + input: BiEncoderBatch, + tensorizer: Tensorizer, + cfg, + encoder_type: str, + rep_positions=0, + loss_scale: float = None, +) -> Tuple[torch.Tensor, int]: + + input = BiEncoderBatch(**move_to_device(input._asdict(), cfg.device)) + + q_attn_mask = tensorizer.get_attn_mask(input.question_ids) + ctx_attn_mask = tensorizer.get_attn_mask(input.context_ids) + + if model.training: + model_out = model( + input.question_ids, + input.question_segments, + q_attn_mask, + input.context_ids, + input.ctx_segments, + ctx_attn_mask, + encoder_type=encoder_type, + representation_token_pos=rep_positions, + ) + else: + with torch.no_grad(): + model_out = model( + input.question_ids, + input.question_segments, + q_attn_mask, + input.context_ids, + input.ctx_segments, + ctx_attn_mask, + encoder_type=encoder_type, + representation_token_pos=rep_positions, + ) + + local_q_vector, local_ctx_vectors = model_out + + loss_function = BiEncoderNllLoss() + + loss, is_correct = _calc_loss( + cfg, + loss_function, + local_q_vector, + local_ctx_vectors, + input.is_positive, + input.hard_negatives, + loss_scale=loss_scale, + ) + + is_correct = is_correct.sum().item() + + if cfg.n_gpu > 1: + loss = loss.mean() + if cfg.train.gradient_accumulation_steps > 1: + loss = loss / cfg.gradient_accumulation_steps + return loss, is_correct + + +@hydra.main(config_path="conf", config_name="biencoder_train_cfg") +def main(cfg: DictConfig): + if cfg.train.gradient_accumulation_steps < 1: + raise ValueError( + "Invalid 
gradient_accumulation_steps parameter: {}, should be >= 1".format( + cfg.train.gradient_accumulation_steps + ) + ) + + if cfg.output_dir is not None: + os.makedirs(cfg.output_dir, exist_ok=True) + + cfg = setup_cfg_gpu(cfg) + set_seed(cfg) + + if cfg.local_rank in [-1, 0]: + logger.info("CFG (after gpu configuration):") + logger.info("%s", OmegaConf.to_yaml(cfg)) + + trainer = BiEncoderTrainer(cfg) + + if cfg.train_datasets and len(cfg.train_datasets) > 0: + trainer.run_train() + elif cfg.model_file and cfg.dev_datasets: + logger.info( + "No train files are specified. Run 2 types of validation for the specified model file" + ) + trainer.validate_nll() + trainer.validate_average_rank() + else: + logger.warning( + "Neither train_file nor (model_file & dev_file) parameters are specified. Nothing to do." + ) + + +if __name__ == "__main__": + logger.info("Sys.argv: %s", sys.argv) + hydra_formatted_args = [] + # convert the cli params added by torch.distributed.launch into Hydra format + for arg in sys.argv: + if arg.startswith("--"): + hydra_formatted_args.append(arg[len("--") :]) + else: + hydra_formatted_args.append(arg) + logger.info("Hydra formatted Sys.argv: %s", hydra_formatted_args) + sys.argv = hydra_formatted_args + + main() diff --git a/research/information_retrieval/README.md b/research/information_retrieval/README.md new file mode 100644 index 00000000000..5fcfef65691 --- /dev/null +++ b/research/information_retrieval/README.md @@ -0,0 +1,18 @@ +# Compressing Neural Methods for Information Retrieval +Author: @spacemanidol + +Neural methods for information retrieval have shown tremendous promise. Leveraging language models like BERT and T5, single-stage and multi-stage systems have exploded and in some cases are able to outperform traditional sparse search (BM25) by 2-3x. +Despite the improvement in quality, neural methods prove tricky to use in production. Large model size makes index generation difficult and expensive and requires large GPU clusters. In this folder we explore how compression methods like structured pruning, unstructured pruning, and distillation can be used with the Neural Magic framework to bridge the gap from research to production. + +We experiment with compressing and optimizing both sparse search, by doing query prediction and expansion with a T5 model, and dense retrieval, using BERT-based bi-encoders. The goal of these experiments is to determine if neural models can be made deployable for any type of workload without using GPUs. + +### Doc2Query +Fill out when project is done + +### DPR + +### Elastic Search Implementation + +## Results + + diff --git a/research/information_retrieval/doc2query/README.md b/research/information_retrieval/doc2query/README.md new file mode 100644 index 00000000000..942361a0c8a --- /dev/null +++ b/research/information_retrieval/doc2query/README.md @@ -0,0 +1,47 @@ +# Doc2Query Compressed + +Author: @spacemanidol + +Doc2query introduced a simple and direct method to integrate neural information retrieval in the context of traditional keyword search. Instead of introducing a neural ranking engine at query time, neural methods are moved to index generation time. +A sequence-to-sequence model is trained with the input being passages (short context windows) and the target being the relevant query. Since the MSMARCO corpus features over 500,000 relevant passages, methods like T5 can be leveraged.
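+
+As a rough illustration of the idea (a sketch for this README, not one of the scripts in this folder), generating expansion queries for a single passage with a doc2query-style T5 model could look like the snippet below; the checkpoint path is a placeholder for whatever fine-tuned model the training steps later in this document produce.
+
+```python
+# Sketch only: assumes a T5 checkpoint fine-tuned on MSMARCO passage -> query pairs.
+from transformers import T5ForConditionalGeneration, T5Tokenizer
+
+model_name = "path/to/doc2query-t5-checkpoint"  # placeholder, not a published model name
+tokenizer = T5Tokenizer.from_pretrained(model_name)
+model = T5ForConditionalGeneration.from_pretrained(model_name)
+
+passage = "The Manhattan Project was a research and development undertaking during World War II."
+input_ids = tokenizer(passage, return_tensors="pt").input_ids
+
+# Sample a handful of plausible queries and append them to the passage before BM25 indexing.
+outputs = model.generate(input_ids, max_length=64, do_sample=True, top_k=10, num_return_sequences=3)
+expanded_passage = passage + " " + " ".join(tokenizer.decode(o, skip_special_tokens=True) for o in outputs)
+print(expanded_passage)
+```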
+Unfortunately, without compression, existing T5 models take index generation from 10 minutes (16 threads on a 14 core Intel(R) Xeon(R) Gold 5120 CPU @ 2.20GHz) to > using 4 16 GB V100s. +## Results + +| Method | Sparsity | MRR @10 MSMARCO Dev | Latency(s) per 1000 queries | Index Generation (hh:mm:ss) | Citation | +|--------------|----------|---------------------|-----------------------------|---------------------|----------------| +|BM25(Anserini)|0 |0.1874 |79.85 |00:10:16 | | + + +### Baseline +Download the data: +```sh +cd data +wget https://msmarco.blob.core.windows.net/msmarcoranking/collectionandqueries.tar.gz +tar -xzvf collectionandqueries.tar.gz +rm collectionandqueries.tar.gz +cat queries.dev.tsv queries.train.tsv queries.eval.tsv > queries.tsv +``` + +To format the collection file, build a simple index, run on the MSMARCO dev set, and evaluate, which should produce the output below: +```sh +mkdir data/base_collection +python src/convert_doc_collection_to_jsonl.py --collection_path data/collection.tsv --output_path data/base_collection/collection +python -m pyserini.index -collection JsonCollection -generator DefaultLuceneDocumentGenerator \ + -threads 16 -input data/base_collection \ + -index indexes/msmarco-passage-baseline -storePositions -storeDocvectors -storeRaw +python -m pyserini.search --topics data/queries.dev.small.tsv \ + --index indexes/msmarco-passage-baseline \ + --output runs/run.msmarco-passage.bm25baseline.tsv \ + --bm25 --output-format msmarco --hits 1000 --k1 0.82 --b 0.68 +python src/msmarco_passage_eval.py data/qrels.dev.small.tsv runs/run.msmarco-passage.bm25baseline.tsv +##################### +MRR @10: 0.18741227770955543 +QueriesRanked: 6980 +##################### +``` + +### Doc2Query + +Format the data for training: + +```sh + python src/make_doc2query_data.py --collection_file data/collection.tsv --query_file data/queries.tsv --train_qrel_file data/qrels.train.tsv --dev_qrel_file data/qrels.dev.tsv --output_file_prefix data/doc_query_ diff --git a/research/information_retrieval/doc2query/indexes/init.txt b/research/information_retrieval/doc2query/indexes/init.txt new file mode 100644 index 00000000000..e69de29bb2d diff --git a/research/information_retrieval/doc2query/outputs/bm25_baseline.txt b/research/information_retrieval/doc2query/outputs/bm25_baseline.txt new file mode 100644 index 00000000000..7c25e0a0890 --- /dev/null +++ b/research/information_retrieval/doc2query/outputs/bm25_baseline.txt @@ -0,0 +1,4 @@ +##################### +MRR @10: 0.18741227770955543 +QueriesRanked: 6980 +##################### diff --git a/research/information_retrieval/doc2query/outputs/init.txt b/research/information_retrieval/doc2query/outputs/init.txt new file mode 100644 index 00000000000..e69de29bb2d diff --git a/research/information_retrieval/doc2query/recipes/90sparse.yaml b/research/information_retrieval/doc2query/recipes/90sparse.yaml new file mode 100644 index 00000000000..dfc38d244a3 --- /dev/null +++ b/research/information_retrieval/doc2query/recipes/90sparse.yaml @@ -0,0 +1,2310 @@ +version: 1.1.0 + +modifiers: + - !EpochRangeModifier + end_epoch: 5 + start_epoch: 0.0 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['encoder.block.0.layer.0.SelfAttention.q.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type:
unstructured + params: ['encoder.block.0.layer.0.SelfAttention.k.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['encoder.block.0.layer.0.SelfAttention.v.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['encoder.block.0.layer.0.SelfAttention.o.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['encoder.block.0.layer.1.DenseReluDense.wi.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['encoder.block.0.layer.1.DenseReluDense.wo.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.0.layer.0.SelfAttention.q.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.0.layer.0.SelfAttention.k.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.0.layer.0.SelfAttention.v.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.0.layer.0.SelfAttention.o.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.0.layer.1.EncDecAttention.q.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.0.layer.1.EncDecAttention.k.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.0.layer.1.EncDecAttention.v.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.0.layer.1.EncDecAttention.o.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: 
cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.0.layer.2.DenseReluDense.wi.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.0.layer.2.DenseReluDense.wo.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['encoder.block.1.layer.0.SelfAttention.q.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['encoder.block.1.layer.0.SelfAttention.k.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['encoder.block.1.layer.0.SelfAttention.v.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['encoder.block.1.layer.0.SelfAttention.o.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['encoder.block.1.layer.1.DenseReluDense.wi.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['encoder.block.1.layer.1.DenseReluDense.wo.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.1.layer.0.SelfAttention.q.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.1.layer.0.SelfAttention.k.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.1.layer.0.SelfAttention.v.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.1.layer.0.SelfAttention.o.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.1.layer.1.EncDecAttention.q.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + 
final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.1.layer.1.EncDecAttention.k.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.1.layer.1.EncDecAttention.v.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.1.layer.1.EncDecAttention.o.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.1.layer.2.DenseReluDense.wi.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.1.layer.2.DenseReluDense.wo.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['encoder.block.2.layer.0.SelfAttention.q.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['encoder.block.2.layer.0.SelfAttention.k.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['encoder.block.2.layer.0.SelfAttention.v.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['encoder.block.2.layer.0.SelfAttention.o.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['encoder.block.2.layer.1.DenseReluDense.wi.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['encoder.block.2.layer.1.DenseReluDense.wo.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.2.layer.0.SelfAttention.q.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.2.layer.0.SelfAttention.k.weight'] + start_epoch: 1 + 
update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.2.layer.0.SelfAttention.v.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.2.layer.0.SelfAttention.o.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.2.layer.1.EncDecAttention.q.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.2.layer.1.EncDecAttention.k.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.2.layer.1.EncDecAttention.v.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.2.layer.1.EncDecAttention.o.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.2.layer.2.DenseReluDense.wi.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.2.layer.2.DenseReluDense.wo.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['encoder.block.3.layer.0.SelfAttention.q.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['encoder.block.3.layer.0.SelfAttention.k.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['encoder.block.3.layer.0.SelfAttention.v.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['encoder.block.3.layer.0.SelfAttention.o.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: 
['encoder.block.3.layer.1.DenseReluDense.wi.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['encoder.block.3.layer.1.DenseReluDense.wo.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.3.layer.0.SelfAttention.q.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.3.layer.0.SelfAttention.k.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.3.layer.0.SelfAttention.v.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.3.layer.0.SelfAttention.o.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.3.layer.1.EncDecAttention.q.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.3.layer.1.EncDecAttention.k.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.3.layer.1.EncDecAttention.v.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.3.layer.1.EncDecAttention.o.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.3.layer.2.DenseReluDense.wi.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.3.layer.2.DenseReluDense.wo.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['encoder.block.4.layer.0.SelfAttention.q.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: 
True + log_types: __ALL__ + mask_type: unstructured + params: ['encoder.block.4.layer.0.SelfAttention.k.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['encoder.block.4.layer.0.SelfAttention.v.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['encoder.block.4.layer.0.SelfAttention.o.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['encoder.block.4.layer.1.DenseReluDense.wi.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['encoder.block.4.layer.1.DenseReluDense.wo.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.4.layer.0.SelfAttention.q.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.4.layer.0.SelfAttention.k.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.4.layer.0.SelfAttention.v.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.4.layer.0.SelfAttention.o.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.4.layer.1.EncDecAttention.q.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.4.layer.1.EncDecAttention.k.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.4.layer.1.EncDecAttention.v.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.4.layer.1.EncDecAttention.o.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 
+ init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.4.layer.2.DenseReluDense.wi.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.4.layer.2.DenseReluDense.wo.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['encoder.block.5.layer.0.SelfAttention.q.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['encoder.block.5.layer.0.SelfAttention.k.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['encoder.block.5.layer.0.SelfAttention.v.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['encoder.block.5.layer.0.SelfAttention.o.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['encoder.block.5.layer.1.DenseReluDense.wi.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['encoder.block.5.layer.1.DenseReluDense.wo.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.5.layer.0.SelfAttention.q.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.5.layer.0.SelfAttention.k.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.5.layer.0.SelfAttention.v.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.5.layer.0.SelfAttention.o.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.5.layer.1.EncDecAttention.q.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - 
!GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.5.layer.1.EncDecAttention.k.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.5.layer.1.EncDecAttention.v.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.5.layer.1.EncDecAttention.o.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.5.layer.2.DenseReluDense.wi.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.5.layer.2.DenseReluDense.wo.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['encoder.block.6.layer.0.SelfAttention.q.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['encoder.block.6.layer.0.SelfAttention.k.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['encoder.block.6.layer.0.SelfAttention.v.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['encoder.block.6.layer.0.SelfAttention.o.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['encoder.block.6.layer.1.DenseReluDense.wi.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['encoder.block.6.layer.1.DenseReluDense.wo.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.6.layer.0.SelfAttention.q.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: 
['decoder.block.6.layer.0.SelfAttention.k.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.6.layer.0.SelfAttention.v.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.6.layer.0.SelfAttention.o.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.6.layer.1.EncDecAttention.q.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.6.layer.1.EncDecAttention.k.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.6.layer.1.EncDecAttention.v.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.6.layer.1.EncDecAttention.o.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.6.layer.2.DenseReluDense.wi.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.6.layer.2.DenseReluDense.wo.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['encoder.block.7.layer.0.SelfAttention.q.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['encoder.block.7.layer.0.SelfAttention.k.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['encoder.block.7.layer.0.SelfAttention.v.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['encoder.block.7.layer.0.SelfAttention.o.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: 
True + log_types: __ALL__ + mask_type: unstructured + params: ['encoder.block.7.layer.1.DenseReluDense.wi.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['encoder.block.7.layer.1.DenseReluDense.wo.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.7.layer.0.SelfAttention.q.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.7.layer.0.SelfAttention.k.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.7.layer.0.SelfAttention.v.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.7.layer.0.SelfAttention.o.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.7.layer.1.EncDecAttention.q.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.7.layer.1.EncDecAttention.k.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.7.layer.1.EncDecAttention.v.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.7.layer.1.EncDecAttention.o.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.7.layer.2.DenseReluDense.wi.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.7.layer.2.DenseReluDense.wo.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['encoder.block.8.layer.0.SelfAttention.q.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 
0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['encoder.block.8.layer.0.SelfAttention.k.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['encoder.block.8.layer.0.SelfAttention.v.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['encoder.block.8.layer.0.SelfAttention.o.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['encoder.block.8.layer.1.DenseReluDense.wi.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['encoder.block.8.layer.1.DenseReluDense.wo.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.8.layer.0.SelfAttention.q.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.8.layer.0.SelfAttention.k.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.8.layer.0.SelfAttention.v.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.8.layer.0.SelfAttention.o.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.8.layer.1.EncDecAttention.q.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.8.layer.1.EncDecAttention.k.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.8.layer.1.EncDecAttention.v.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.8.layer.1.EncDecAttention.o.weight'] + start_epoch: 1 + update_frequency: 0.01 
+ + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.8.layer.2.DenseReluDense.wi.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.8.layer.2.DenseReluDense.wo.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['encoder.block.9.layer.0.SelfAttention.q.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['encoder.block.9.layer.0.SelfAttention.k.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['encoder.block.9.layer.0.SelfAttention.v.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['encoder.block.9.layer.0.SelfAttention.o.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['encoder.block.9.layer.1.DenseReluDense.wi.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['encoder.block.9.layer.1.DenseReluDense.wo.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.9.layer.0.SelfAttention.q.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.9.layer.0.SelfAttention.k.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.9.layer.0.SelfAttention.v.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.9.layer.0.SelfAttention.o.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: 
['decoder.block.9.layer.1.EncDecAttention.q.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.9.layer.1.EncDecAttention.k.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.9.layer.1.EncDecAttention.v.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.9.layer.1.EncDecAttention.o.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.9.layer.2.DenseReluDense.wi.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.9.layer.2.DenseReluDense.wo.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['encoder.block.10.layer.0.SelfAttention.q.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['encoder.block.10.layer.0.SelfAttention.k.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['encoder.block.10.layer.0.SelfAttention.v.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['encoder.block.10.layer.0.SelfAttention.o.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['encoder.block.10.layer.1.DenseReluDense.wi.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['encoder.block.10.layer.1.DenseReluDense.wo.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.10.layer.0.SelfAttention.q.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + 
leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.10.layer.0.SelfAttention.k.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.10.layer.0.SelfAttention.v.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.10.layer.0.SelfAttention.o.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.10.layer.1.EncDecAttention.q.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.10.layer.1.EncDecAttention.k.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.10.layer.1.EncDecAttention.v.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.10.layer.1.EncDecAttention.o.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.10.layer.2.DenseReluDense.wi.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.10.layer.2.DenseReluDense.wo.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['encoder.block.11.layer.0.SelfAttention.q.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['encoder.block.11.layer.0.SelfAttention.k.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['encoder.block.11.layer.0.SelfAttention.v.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['encoder.block.11.layer.0.SelfAttention.o.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + 
end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['encoder.block.11.layer.1.DenseReluDense.wi.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['encoder.block.11.layer.1.DenseReluDense.wo.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.11.layer.0.SelfAttention.q.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.11.layer.0.SelfAttention.k.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.11.layer.0.SelfAttention.v.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.11.layer.0.SelfAttention.o.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.11.layer.1.EncDecAttention.q.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.11.layer.1.EncDecAttention.k.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.11.layer.1.EncDecAttention.v.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.11.layer.1.EncDecAttention.o.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.11.layer.2.DenseReluDense.wi.weight'] + start_epoch: 1 + update_frequency: 0.01 + + - !GMPruningModifier + end_epoch: 4 + final_sparsity: 0.90 + init_sparsity: 0.00 + inter_func: cubic + leave_enabled: True + log_types: __ALL__ + mask_type: unstructured + params: ['decoder.block.11.layer.2.DenseReluDense.wo.weight'] + start_epoch: 1 + update_frequency: 0.01 diff --git a/research/information_retrieval/doc2query/recipes/noprune.yaml b/research/information_retrieval/doc2query/recipes/noprune.yaml new file mode 100644 index 00000000000..b4a205e1c08 --- /dev/null +++ 
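Each GMPruningModifier entry in the 90sparse.yaml recipe above targets a single prunable T5 weight matrix and ramps its sparsity from 0.00 to 0.90 between epoch 1 and epoch 4 with cubic interpolation, refreshing the unstructured mask every 0.01 epochs. A minimal sketch of how such a recipe is consumed during training, using the same sparseml 0.x calls that appear later in this patch (the wrapper function and its arguments are illustrative, not part of the patch):

    import math

    from sparseml.pytorch.optim.manager import ScheduledModifierManager
    from sparseml.pytorch.optim.optimizer import ScheduledOptimizer

    def wrap_with_pruning_recipe(model, optimizer, recipe_path, num_train_samples, batch_size, n_gpu=1):
        # One step per optimizer update; the manager maps the epoch-based schedule
        # (start_epoch, end_epoch, update_frequency) onto these steps.
        steps_per_epoch = math.ceil(num_train_samples / (batch_size * max(n_gpu, 1)))
        manager = ScheduledModifierManager.from_yaml(recipe_path)
        # The wrapped optimizer applies and updates the pruning masks as training proceeds.
        return ScheduledOptimizer(optimizer, model, manager, steps_per_epoch=steps_per_epoch)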
b/research/information_retrieval/doc2query/recipes/noprune.yaml @@ -0,0 +1,6 @@ +version: 1.1.0 + +modifiers: + - !EpochRangeModifier + end_epoch: 5 + start_epoch: 0.0 diff --git a/research/information_retrieval/doc2query/requirements.txt b/research/information_retrieval/doc2query/requirements.txt new file mode 100644 index 00000000000..14565701cf8 --- /dev/null +++ b/research/information_retrieval/doc2query/requirements.txt @@ -0,0 +1,136 @@ +absl-py==0.13.0 +argon2-cffi==20.1.0 +async-generator==1.10 +attrs==21.2.0 +backcall==0.2.0 +bleach==3.3.0 +blis==0.7.4 +catalogue==2.0.4 +certifi==2021.5.30 +cffi==1.14.5 +chardet==4.0.0 +click==7.1.2 +configparser==5.0.2 +cycler==0.10.0 +cymem==2.0.5 +Cython==0.29.23 +datasets==1.8.0 +decorator==4.4.2 +defusedxml==0.7.1 +dill==0.3.4 +docker-pycreds==0.4.0 +entrypoints==0.3 +filelock==3.0.12 +flatbuffers==2.0 +fsspec==2021.6.0 +gitdb==4.0.7 +GitPython==3.1.14 +huggingface-hub==0.0.8 +idna==2.10 +imageio==2.9.0 +ipykernel==5.5.5 +ipython==7.24.1 +ipython-genutils==0.2.0 +ipywidgets==7.6.3 +jedi==0.18.0 +Jinja2==3.0.1 +joblib==1.0.1 +jsonschema==3.2.0 +jupyter==1.0.0 +jupyter-client==6.1.12 +jupyter-console==6.4.0 +jupyter-core==4.7.1 +jupyterlab-pygments==0.1.2 +jupyterlab-widgets==1.0.0 +kiwisolver==1.3.1 +MarkupSafe==2.0.1 +matplotlib==3.4.2 +matplotlib-inline==0.1.2 +merge-args==0.1.4 +mistune==0.8.4 +multiprocess==0.70.12.2 +murmurhash==1.0.5 +nbclient==0.5.3 +nbconvert==6.0.7 +nbformat==5.1.3 +nest-asyncio==1.5.1 +networkx==2.5.1 +nltk==3.6.2 +notebook==6.4.0 +numpy==1.20.3 +onnx==1.7.0 +onnxruntime==1.8.0 +packaging==20.9 +pandas==1.2.4 +pandocfilters==1.4.3 +parso==0.8.2 +pathtools==0.1.2 +pathy==0.5.2 +pexpect==4.8.0 +pickleshare==0.7.5 +Pillow==8.2.0 +preshed==3.0.5 +progressbar2==3.53.1 +prometheus-client==0.11.0 +promise==2.3 +prompt-toolkit==3.0.18 +protobuf==3.17.3 +psutil==5.8.0 +ptyprocess==0.7.0 +pyarrow==3.0.0 +pycparser==2.20 +pydantic==1.7.4 +Pygments==2.9.0 +pyjnius==1.3.0 +pyparsing==2.4.7 +pyrsistent==0.17.3 +pyserini==0.12.0 +python-dateutil==2.8.1 +python-utils==2.5.6 +pytz==2021.1 +PyWavelets==1.1.1 +PyYAML==5.4.1 +pyzmq==22.1.0 +qtconsole==5.1.0 +QtPy==1.9.0 +regex==2021.4.4 +requests==2.25.1 +rouge-score==0.0.4 +sacremoses==0.0.45 +scikit-image==0.18.1 +scikit-learn==0.24.2 +scipy==1.6.3 +Send2Trash==1.5.0 +sentencepiece==0.1.96 +sentry-sdk==1.1.0 +shortuuid==1.0.1 +six==1.16.0 +smart-open==3.0.0 +smmap==4.0.0 +spacy==3.0.6 +spacy-legacy==3.0.6 +sparseml==0.4.0 +sparsezoo==0.4.0 +srsly==2.4.1 +subprocess32==3.5.4 +terminado==0.10.1 +testpath==0.5.0 +thinc==8.0.5 +threadpoolctl==2.1.0 +tifffile==2021.6.14 +tokenizers==0.10.3 +toposort==1.6 +torch==1.8.0 +tornado==6.1 +tqdm==4.61.1 +traitlets==5.0.5 +transformers==4.7.0 +typer==0.3.2 +typing-extensions==3.10.0.0 +urllib3==1.26.5 +wandb==0.10.32 +wasabi==0.8.2 +wcwidth==0.2.5 +webencodings==0.5.1 +widgetsnbextension==3.5.1 +xxhash==2.0.2 diff --git a/research/information_retrieval/doc2query/sparseml_utils.py b/research/information_retrieval/doc2query/sparseml_utils.py new file mode 100644 index 00000000000..b6f8b6dbb27 --- /dev/null +++ b/research/information_retrieval/doc2query/sparseml_utils.py @@ -0,0 +1,121 @@ +import math + +import torch +import torch.nn.functional as F + +from sparseml.pytorch.optim.manager import ScheduledModifierManager +from sparseml.pytorch.optim.optimizer import ScheduledOptimizer +from sparseml.pytorch.utils import ModuleExporter, logger +from trainer_qa import QuestionAnsweringTrainer + + +class SparseMLQATrainer(QuestionAnsweringTrainer): + """ + 
Question Answering trainer with SparseML integration + + :param recipe: recipe for model sparsification + :param teacher: teacher model for distillation + :param distill_hardness: ratio of loss by teacher targets (between 0 and 1) + :param distill_temperature: temperature for distillation + :param args, kwargs: arguments passed into parent class + """ + + def __init__(self, recipe, teacher=None, distill_hardness=0.5, distill_temperature=2.0, *args, **kwargs): + super().__init__(*args, **kwargs) + self.recipe = recipe + self.teacher = teacher + self.distill_hardness = distill_hardness + self.distill_temperature = distill_temperature + self.criterion = torch.nn.CrossEntropyLoss() + + self.manager = None + self.loggers = None + if self.recipe is not None: + loggers = [] + if "wandb" in self.args.report_to: + loggers.append(logger.WANDBLogger()) + self.loggers = loggers + + def create_optimizer(self): + """ + Create optimizer customized using SparseML + """ + super().create_optimizer() + if self.recipe is None: + return + steps_per_epoch = math.ceil( + len(self.train_dataset) / (self.args.per_device_train_batch_size * self.args._n_gpu) + ) + self.manager = ScheduledModifierManager.from_yaml(self.recipe) + self.args.num_train_epochs = float(self.manager.max_epochs) + if hasattr(self, "scaler"): + self.manager.initialize(self.model, epoch=0.0, loggers=self.loggers) + self.scaler = self.manager.modify( + self.model, self.optimizer, steps_per_epoch=steps_per_epoch, wrap_optim=self.scaler + ) + else: + self.optimizer = ScheduledOptimizer( + self.optimizer, self.model, self.manager, steps_per_epoch=steps_per_epoch, loggers=self.loggers + ) + + def compute_loss(self, model, inputs, return_outputs=False): + """ + Computing loss using teacher/student distillation + """ + if self.recipe is None or self.teacher is None: + return super().compute_loss(model, inputs, return_outputs=return_outputs) + + outputs = model(**inputs) + if self.teacher is None: + loss = outputs["loss"] + else: + input_device = inputs["input_ids"].device + self.teacher = self.teacher.to(input_device) + start_logits_student = outputs["start_logits"] + end_logits_student = outputs["end_logits"] + start_logits_label = inputs["start_positions"] + end_logits_label = inputs["end_positions"] + with torch.no_grad(): + teacher_output = self.teacher( + input_ids=inputs["input_ids"], + token_type_ids=inputs["token_type_ids"], + attention_mask=inputs["attention_mask"], + ) + start_logits_teacher = teacher_output["start_logits"] + end_logits_teacher = teacher_output["end_logits"] + loss_start = ( + F.kl_div( + input=F.log_softmax(start_logits_student / self.distill_temperature, dim=-1), + target=F.softmax(start_logits_teacher / self.distill_temperature, dim=-1), + reduction="batchmean", + ) + * (self.distill_temperature ** 2) + ) + loss_end = ( + F.kl_div( + input=F.log_softmax(end_logits_student / self.distill_temperature, dim=-1), + target=F.softmax(end_logits_teacher / self.distill_temperature, dim=-1), + reduction="batchmean", + ) + * (self.distill_temperature ** 2) + ) + teacher_loss = (loss_start + loss_end) / 2.0 + loss_start = self.criterion(start_logits_student, start_logits_label) + loss_end = self.criterion(end_logits_student, end_logits_label) + label_loss = (loss_start + loss_end) / 2.0 + loss = ((1 - self.distill_hardness) * label_loss) + (self.distill_hardness * teacher_loss) + return (loss, outputs) if return_outputs else loss + + +def export_model(model, dataloader, output_dir): + """ + Export a trained model to ONNX + :param 
model: trained model + :param dataloader: dataloader to get sample batch + :param output_dir: output directory for ONNX model + """ + exporter = ModuleExporter(model, output_dir=output_dir) + for _, sample_batch in enumerate(dataloader): + sample_input = (sample_batch["input_ids"], sample_batch["attention_mask"], sample_batch["token_type_ids"]) + exporter.export_onnx(sample_batch=sample_input, convert_qat=True) + break \ No newline at end of file diff --git a/research/information_retrieval/doc2query/src/augment_collection.py b/research/information_retrieval/doc2query/src/augment_collection.py new file mode 100644 index 00000000000..754b79edbf1 --- /dev/null +++ b/research/information_retrieval/doc2query/src/augment_collection.py @@ -0,0 +1,103 @@ +import argparse +import os +import json + +import transformers +from filelock import FileLock +from transformers import ( + AutoConfig, + AutoModelForSeq2SeqLM, + AutoTokenizer, + DataCollatorForSeq2Seq, + HfArgumentParser, + Seq2SeqTrainer, + Seq2SeqTrainingArguments, + set_seed, +) + +def load_qid2query(filename): + qid2query = {} + with open(filename, 'r') as f: + for l in f: + l = l.strip().split('\t') + qid2query[int(l[0])] = l[1] + return qid2query + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--collection_file", + default="data/collection.tsv", + type=str, + help="The msmarco passage collection file", + ) + parser.add_argument( + "--model_name_or_path", + type=str, + default=None, + help="Doc2Query predictions", + ) + parser.add_argument( + "--augmented_collection_file", + type=str, + default="data/augmented_collection.jsonl", + help="The output_file for augmented doc 2 query index", + ) + parser.add_argument( + "--beam_size", + type=int, + default=3, + help="number of queries to generate per passage", + ) + parser.add_argument( + "--max_length", + type=int, + default=32, + help="length of document queries", + ) + parser.add_argument( + '--no_cuda', + action="store_true", + help="Use this to not use cuda") + args = parser.parse_args() + print("Loading collection") + collection = load_qid2query(args.collection_file) + print("Collection loaded") + device='cuda' + if args.no_cuda: + device='cpu' + + print("Loading model") + config = AutoConfig.from_pretrained(args.model_name_or_path,) + tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path,) + model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name_or_path) + model.to(device) + model.resize_token_embeddings(len(tokenizer)) + print("Model Loaded") + print("Augmenting passages") + augmentations = 0 + #TODO Introduce batching at inference time as right now runs 1 by 1 + with open(args.augmented_collection_file, 'w') as w: + for doc_id in collection: + if augmentations % 5000 == 0: + print("{} passages augmented".format(augmentations)) + document_text = collection[doc_id] + input_ids = tokenizer.encode(document_text, return_tensors='pt').to(device) + outputs = model.generate( + input_ids=input_ids, + max_length=args.max_length, + do_sample=True, + top_k=10, + num_return_sequences=args.beam_size) + query_augment = '' + for i in range(args.beam_size): + query_augment += ' ' + query_augment += tokenizer.decode(outputs[i], skip_special_tokens=True) + output_dict = {'id': doc_id, 'contents': document_text + query_augment} + w.write(json.dumps(output_dict) + '\n') + augmentations += 1 + +if __name__ == "__main__": + main() + diff --git a/research/information_retrieval/doc2query/src/convert_doc_collection_to_jsonl.py 
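The augment_collection.py script above generates beam_size queries per passage with top-k sampling (k=10) and appends them to the passage text before writing one JSON document per line; as its TODO notes, generation currently runs one passage at a time. A batched variant might look like the following sketch (the helper name and batching scheme are illustrative, not part of the patch):

    def generate_query_augments(model, tokenizer, passages, device, beam_size=3, max_length=32):
        # Tokenize a list of passages at once instead of calling encode() per passage.
        inputs = tokenizer(passages, return_tensors="pt", padding=True, truncation=True).to(device)
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_length=max_length,
            do_sample=True,
            top_k=10,
            num_return_sequences=beam_size,
        )
        decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)
        # generate() returns the beam_size samples for each input as consecutive rows.
        return [" ".join(decoded[i * beam_size:(i + 1) * beam_size]) for i in range(len(passages))]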
b/research/information_retrieval/doc2query/src/convert_doc_collection_to_jsonl.py new file mode 100644 index 00000000000..98128816f4a --- /dev/null +++ b/research/information_retrieval/doc2query/src/convert_doc_collection_to_jsonl.py @@ -0,0 +1,20 @@ +import os +import json +import argparse + +def convert_collection(args): + with open(args.output_path, 'w', encoding='utf-8') as w: + with open(args.collection_path, encoding='utf-8') as f: + for i, line in enumerate(f): + id, body = line.split('\t') + output_dict = {'id': id, 'contents': body} + w.write(json.dumps(output_dict) + '\n') + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Convert MSMARCO tsv passage collection into jsonl files for Anserini.') + parser.add_argument('--collection_path', required=True, help='Path to MS MARCO tsv collection.') + parser.add_argument('--output_path', required=True, help='Output filename.') + args = parser.parse_args() + convert_collection(args) + print('Done!') diff --git a/research/information_retrieval/doc2query/src/distill_doc2query.py b/research/information_retrieval/doc2query/src/distill_doc2query.py new file mode 100644 index 00000000000..365ff6a84a6 --- /dev/null +++ b/research/information_retrieval/doc2query/src/distill_doc2query.py @@ -0,0 +1,64 @@ +# neuralmagic: no copyright +# flake8: noqa +# fmt: off +# isort: skip_file +#!/usr/bin/env python +# coding=utf-8 +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Union + +import torch +from torch import nn +import torch.nn.functional as F +from torch import Tensor + +from transformers import Trainer, is_datasets_available, is_torch_tpu_available +from transformers.trainer_utils import PredictionOutput + +class DistillGlueTrainer(Trainer): + def __init__(self, *args, eval_examples=None, post_process_function=None, teacher=None, loss=None, batch_size=8, max_sequence_length=384,distill_hardness=1.0, temperature=2.0, **kwargs): + super().__init__(*args, **kwargs) + self.eval_examples = eval_examples + self.post_process_function = post_process_function + self.loss = loss + self.teacher = teacher + self.batch_size = batch_size + self.temperature = temperature + self.distill_hardness = distill_hardness + self.criterion = nn.CrossEntropyLoss() + self.max_sequence_length = max_sequence_length + if self.teacher is None: + self.distill_hardness = 0 + + def compute_loss(self, model, inputs, return_outputs=False): + """ + How the loss is computed by Trainer. Modified for Distilation using student teacher framework modified for distilation. 
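+        The intended blend mirrors the SparseMLQATrainer earlier in this patch:
+        loss = (1 - distill_hardness) * label_loss + distill_hardness * temperature**2 * KL(student_T, teacher_T),
+        where the _T distributions are softmax(logits / temperature). Note that
+        torch.nn.functional.kl_div expects log-probabilities for its input argument,
+        so student logits are normally passed through log_softmax before the KL term.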
+ """ + outputs = model(**inputs) + loss = outputs["loss"] + logits_student = outputs["logits"] + if self.teacher is not None: + input_device = inputs["input_ids"].device + self.teacher = self.teacher.to(input_device) + with torch.no_grad(): + teacher_outputs = self.teacher( + input_ids=inputs["input_ids"], + token_type_ids=inputs["token_type_ids"], + attention_mask=inputs["attention_mask"], + ) + logits_teacher = teacher_outputs["logits"] + loss_distill = F.kl_div( input=logits_student, target=logits_teacher, reduction="batchmean",) * (self.temperature ** 2) + loss = ((1-self.distill_hardness) * loss) + torch.abs((self.distill_hardness * loss_distill)) + return (loss, outputs) if return_outputs else loss diff --git a/research/information_retrieval/doc2query/src/make_doc2query_data.py b/research/information_retrieval/doc2query/src/make_doc2query_data.py new file mode 100644 index 00000000000..57862c38538 --- /dev/null +++ b/research/information_retrieval/doc2query/src/make_doc2query_data.py @@ -0,0 +1,66 @@ +import argparse +import os +import json +def load_qid2query(filename): + qid2query = {} + with open(filename, 'r') as f: + for l in f: + l = l.strip().split('\t') + qid2query[int(l[0])] = l[1] + return qid2query + +def load_qrels(filename, collection, qid2query): + qrels = {} + with open(filename, 'r') as f: + for l in f: + l = l.strip().split('\t') + qrels[qid2query[int(l[0])]] = collection[int(l[2])] + return qrels + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--collection_file", + default="collection.tsv", + type=str, + help="The msmarco passage collection file", + ) + parser.add_argument( + "--query_file", + type=str, + default="queries.tsv", + help="Qid to query for all msmarco queries", + ) + parser.add_argument( + "--train_qrel_file", + type=str, + default="qrels.train.tsv", + help="The input file in TSV form of doc2query", + ) + parser.add_argument( + "--dev_qrel_file", + type=str, + default="qrels.dev.tsv", + help="The input file in TSV form of doc2query", + ) + parser.add_argument( + "--output_file_prefix", + type=str, + default="doc_query_", + help="The input file in TSV form of doc2query", + ) + args = parser.parse_args() + collection = load_qid2query(args.collection_file) + qid2query = load_qid2query(args.query_file) + train_qrels = load_qrels(args.train_qrel_file, collection, qid2query) + dev_qrels = load_qrels(args.dev_qrel_file, collection, qid2query) + with open(args.output_file_prefix+"train.json",'w') as w: + for qrel in train_qrels: + w.write("{}\n".format(json.dumps({"input":train_qrels[qrel], "target":qrel}))) + with open(args.output_file_prefix+"dev.json",'w') as w: + for qrel in dev_qrels: + w.write("{}\n".format(json.dumps({"input":dev_qrels[qrel], "target":qrel}))) + + +if __name__ == "__main__": + main() diff --git a/research/information_retrieval/doc2query/src/msmarco_passage_eval.py b/research/information_retrieval/doc2query/src/msmarco_passage_eval.py new file mode 100644 index 00000000000..4ed0eb64581 --- /dev/null +++ b/research/information_retrieval/doc2query/src/msmarco_passage_eval.py @@ -0,0 +1,185 @@ +""" +This module computes evaluation metrics for MSMARCO dataset on the ranking task. 
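+
+Worked example (toy numbers): if the reference marks passage 42 as relevant for query 1 and
+the candidate ranking for query 1 is [13, 42, 5, ...], the first relevant hit sits at rank 2,
+so query 1 contributes 1/2. MRR@10 averages these reciprocal ranks over all judged queries,
+counting 0 for any query with no relevant passage in its top 10.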
+Command line: +python msmarco_eval_ranking.py + +Creation Date : 06/12/2018 +Last Modified : 1/21/2019 +Authors : Daniel Campos , Rutger van Haasteren +""" +import re +import sys +import statistics + +from collections import Counter + +MaxMRRRank = 10 + +def load_reference_from_stream(f): + """Load Reference reference relevant passages + Args:f (stream): stream to load. + Returns:qids_to_relevant_passageids (dict): dictionary mapping from query_id (int) to relevant passages (list of ints). + """ + qids_to_relevant_passageids = {} + for l in f: + try: + l = re.split('[\t\s]', l.strip()) + qid = int(l[0]) + if qid in qids_to_relevant_passageids: + pass + else: + qids_to_relevant_passageids[qid] = [] + qids_to_relevant_passageids[qid].append(int(l[2])) + except: + raise IOError('\"%s\" is not valid format' % l) + return qids_to_relevant_passageids + +def load_reference(path_to_reference): + """Load Reference reference relevant passages + Args:path_to_reference (str): path to a file to load. + Returns:qids_to_relevant_passageids (dict): dictionary mapping from query_id (int) to relevant passages (list of ints). + """ + with open(path_to_reference,'r') as f: + qids_to_relevant_passageids = load_reference_from_stream(f) + return qids_to_relevant_passageids + +def load_candidate_from_stream(f): + """Load candidate data from a stream. + Args:f (stream): stream to load. + Returns:qid_to_ranked_candidate_passages (dict): dictionary mapping from query_id (int) to a list of 1000 passage ids(int) ranked by relevance and importance + """ + qid_to_ranked_candidate_passages = {} + for l in f: + try: + l = l.strip().split('\t') + qid = int(l[0]) + pid = int(l[1]) + rank = int(l[2]) + if qid in qid_to_ranked_candidate_passages: + pass + else: + # By default, all PIDs in the list of 1000 are 0. Only override those that are given + tmp = [0] * 1000 + qid_to_ranked_candidate_passages[qid] = tmp + qid_to_ranked_candidate_passages[qid][rank-1]=pid + except: + raise IOError('\"%s\" is not valid format' % l) + return qid_to_ranked_candidate_passages + +def load_candidate(path_to_candidate): + """Load candidate data from a file. + Args:path_to_candidate (str): path to file to load. 
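+        Each line of the candidate file is expected to be tab-separated as QUERYID\tPASSAGEID\tRANK,
+        e.g. 188714\t1000052\t1 (toy values for illustration).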
+ Returns:qid_to_ranked_candidate_passages (dict): dictionary mapping from query_id (int) to a list of 1000 passage ids(int) ranked by relevance and importance + """ + + with open(path_to_candidate,'r') as f: + qid_to_ranked_candidate_passages = load_candidate_from_stream(f) + return qid_to_ranked_candidate_passages + +def quality_checks_qids(qids_to_relevant_passageids, qids_to_ranked_candidate_passages): + """Perform quality checks on the dictionaries + + Args: + p_qids_to_relevant_passageids (dict): dictionary of query-passage mapping + Dict as read in with load_reference or load_reference_from_stream + p_qids_to_ranked_candidate_passages (dict): dictionary of query-passage candidates + Returns: + bool,str: Boolean whether allowed, message to be shown in case of a problem + """ + message = '' + allowed = True + + # Create sets of the QIDs for the submitted and reference queries + candidate_set = set(qids_to_ranked_candidate_passages.keys()) + ref_set = set(qids_to_relevant_passageids.keys()) + + # Check that we do not have multiple passages per query + for qid in qids_to_ranked_candidate_passages: + # Remove all zeros from the candidates + duplicate_pids = set([item for item, count in Counter(qids_to_ranked_candidate_passages[qid]).items() if count > 1]) + + if len(duplicate_pids-set([0])) > 0: + message = "Cannot rank a passage multiple times for a single query. QID={qid}, PID={pid}".format( + qid=qid, pid=list(duplicate_pids)[0]) + allowed = False + + return allowed, message + +def compute_metrics(qids_to_relevant_passageids, qids_to_ranked_candidate_passages): + """Compute MRR metric + Args: + p_qids_to_relevant_passageids (dict): dictionary of query-passage mapping + Dict as read in with load_reference or load_reference_from_stream + p_qids_to_ranked_candidate_passages (dict): dictionary of query-passage candidates + Returns: + dict: dictionary of metrics {'MRR': } + """ + all_scores = {} + MRR = 0 + qids_with_relevant_passages = 0 + ranking = [] + for qid in qids_to_ranked_candidate_passages: + if qid in qids_to_relevant_passageids: + ranking.append(0) + target_pid = qids_to_relevant_passageids[qid] + candidate_pid = qids_to_ranked_candidate_passages[qid] + for i in range(0,MaxMRRRank): + if candidate_pid[i] in target_pid: + MRR += 1/(i + 1) + ranking.pop() + ranking.append(i+1) + break + if len(ranking) == 0: + raise IOError("No matching QIDs found. Are you sure you are scoring the evaluation set?") + + MRR = MRR/len(qids_to_relevant_passageids) + all_scores['MRR @10'] = MRR + all_scores['QueriesRanked'] = len(qids_to_ranked_candidate_passages) + return all_scores + +def compute_metrics_from_files(path_to_reference, path_to_candidate, perform_checks=True): + """Compute MRR metric + Args: + p_path_to_reference_file (str): path to reference file. + Reference file should contain lines in the following format: + QUERYID\tPASSAGEID + Where PASSAGEID is a relevant passage for a query. Note QUERYID can repeat on different lines with different PASSAGEIDs + p_path_to_candidate_file (str): path to candidate file. + Candidate file sould contain lines in the following format: + QUERYID\tPASSAGEID1\tRank + If a user wishes to use the TREC format please run the script with a -t flag at the end. 
        If this flag is used the expected format is
+        QUERYID\tITER\tDOCNO\tRANK\tSIM\tRUNID
+        Where the values are separated by tabs and ranked in order of relevance
+    Returns:
+        dict: dictionary of metrics {'MRR': }
+    """
+
+    qids_to_relevant_passageids = load_reference(path_to_reference)
+    qids_to_ranked_candidate_passages = load_candidate(path_to_candidate)
+    if perform_checks:
+        allowed, message = quality_checks_qids(qids_to_relevant_passageids, qids_to_ranked_candidate_passages)
+        if message != '': print(message)
+
+    return compute_metrics(qids_to_relevant_passageids, qids_to_ranked_candidate_passages)
+
+def main():
+    """Command line:
+    python msmarco_eval_ranking.py
+    """
+
+    if len(sys.argv) == 3:
+        path_to_reference = sys.argv[1]
+        path_to_candidate = sys.argv[2]
+        metrics = compute_metrics_from_files(path_to_reference, path_to_candidate)
+        print('#####################')
+        for metric in sorted(metrics):
+            print('{}: {}'.format(metric, metrics[metric]))
+        print('#####################')
+
+    else:
+        print('Usage: msmarco_eval_ranking.py ')
+        exit()
+
+if __name__ == '__main__':
+    main()
+
diff --git a/research/information_retrieval/doc2query/src/run_doc2query.py b/research/information_retrieval/doc2query/src/run_doc2query.py
new file mode 100644
index 00000000000..dde59532183
--- /dev/null
+++ b/research/information_retrieval/doc2query/src/run_doc2query.py
@@ -0,0 +1,804 @@
+# neuralmagic: no copyright
+# flake8: noqa
+# fmt: off
+# isort: skip_file
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Example script for integrating sparseml with the transformers library to perform model distillation and pruning on GLUE tasks.
+This script is adapted from Hugging Face's implementation for the GLUEDataset.
+Hugging Face's original implementation is regularly updated and can be found at https://github.com/huggingface/transformers/blob/master/examples/pytorch/text-classification/run_glue.py
+This script will:
+- Load transformer based models
+- Load a sparseml training and pruning optimizer
+- Train on Target GLUE Task
+- Evaluate on GLUE
+- Export model to onnx.
+##########
+Command help:
+usage: run_glue.py [-h] \
+    [--teacher_model_name_or_path] \
+    [--student_model_name_or_path] \
+    [--task_name] \
+    [--temperature] \
+    [--distill_hardness] \
+    [--dataset_name] \
+    [--num_train_epochs] \
+    [--do_train] \
+    [--do_eval] \
+    [--per_device_train_batch_size] \
+    [--per_device_eval_batch_size] \
+    [--learning_rate] \
+    [--output_dir] \
+    [--overwrite_output_dir] \
+    [--cache_dir] \
+    [--preprocessing_num_workers] \
+    [--seed] \
+    [--nm_prune_config] \
+    [--do_onnx_export] \
+    [--onnx_export_path] \
+    [--layers_to_keep] \
+
+Train, prune, and evaluate a transformer-based question answering model on squad.
+    -h, --help show this help message and exit
+    --teacher_model_name_or_path The name or path of the model which will be used for distillation.
+        Note, this model needs to be trained for the QA task already.
+    --student_model_name_or_path The path to the transformers model you wish to train
+        or the name of the pretrained language model you wish
+        to use. ex: bert-base-uncased.
+    --task_name The name of the GLUE task which the model will train and evaluate on.
+    --temperature Hyperparameter which controls model distillation
+    --distill_hardness Hyperparameter which controls how much of the loss comes from teacher vs training labels
+    --dataset_name The name of the dataset you want to use to train your model. ex: squad for using SQuAD.
+    --num_train_epochs Parameter to control how many training epochs you wish
+        your model to train.
+    --do_train Boolean denoting if the model should be trained
+        or not. Default is false.
+    --do_eval Boolean denoting if the model should be evaluated
+        or not. Default is false.
+    --per_device_train_batch_size Size of each training batch based on samples per GPU.
+        24 will fit in an 11gb GPU, 32 in a 16gb.
+    --per_device_eval_batch_size Size of each evaluation batch based on samples per GPU.
+        24 will fit in an 11gb GPU, 32 in a 16gb.
+    --learning_rate Learning rate initial float value. ex: 3e-5.
+    --output_dir Path to which model checkpoints and outputs should be saved.
+    --overwrite_output_dir Boolean to define if the output directory should be overwritten.
+    --cache_dir Directory in which cached transformer files (datasets, models, tokenizers) are saved for fast loading.
+    --preprocessing_num_workers The number of CPU workers used to process datasets
+    --seed Int which determines the random seed for training/shuffling
+    --nm_prune_config Path to the neural magic prune configuration file. examples can
+        be found in prune_config_files but are customized for bert-base-uncased.
+    --do_onnx_export Boolean denoting if the model should be exported to onnx
+    --onnx_export_path Path where the onnx model will be exported. ex: onnx-export
+    --layers_to_keep Number of layers to keep from the original model. Layers are dropped before training
+    --max_seq_length Int for the max sequence length to be parsed for glue tasks ex: 128 tokens.
+ +########## +Example command for training a 95% sparse BERT SQUAD model for 1 epoch without distilation on the Quora Duplicate Question Task: +python examples/transformers/run_glue.py \ + --teacher_model_name_or_path NONE + --student_model_name_or_path bert-base-uncased \ + --task_name QQP + --dataset_name squad \ + --num_train_epochs 1 \ + --do_train \ + --do_eval \ + --per_device_train_batch_size 12 \ + --per_device_eval_batch_size 12 \ + --learning_rate 3e-5 \ + --max_seq_length 128 \ + --doc_stride 128 \ + --output_dir 95sparsity1epoch/ \ + --overwrite_output_dir \ + --cache_dir cache \ + --preprocessing_num_workers 8 \ + --seed 42 \ + --nm_prune_config prune_config_files/95sparsity1epoch.yaml \ + --do_onnx_export \ + --onnx_export_path 95sparsity1epoch/ \ + --distill_hardness 1.0 \ + --temperature 2.0 \ + --layers_to_keep 12 \ +""" +import logging +import os +import sys +from dataclasses import dataclass, field +from typing import Optional +import random +import math + +import nltk +import wandb +import numpy as np +from datasets import load_dataset, load_metric + +import transformers +from filelock import FileLock +from transformers import ( + AutoConfig, + AutoModelForSeq2SeqLM, + AutoTokenizer, + DataCollatorForSeq2Seq, + HfArgumentParser, + Seq2SeqTrainer, + Seq2SeqTrainingArguments, + set_seed, +) + + +from transformers.optimization import ( + Adafactor, + AdamW, + get_constant_schedule, + get_constant_schedule_with_warmup, + get_cosine_schedule_with_warmup, + get_cosine_with_hard_restarts_schedule_with_warmup, + get_linear_schedule_with_warmup, + get_polynomial_decay_schedule_with_warmup, +) + +from transformers.file_utils import is_offline_mode +from transformers.trainer_utils import get_last_checkpoint +from transformers.utils import check_min_version + +from sparseml.pytorch.optim.manager import ScheduledModifierManager +from sparseml.pytorch.optim.optimizer import ScheduledOptimizer +from sparseml.pytorch.utils import ModuleExporter, logger + +logger = logging.getLogger(__name__) + +try: + nltk.data.find("tokenizers/punkt") +except (LookupError, OSError): + if is_offline_mode(): + raise LookupError( + "Offline mode: run this script without TRANSFORMERS_OFFLINE first to download nltk data files" + ) + with FileLock(".lock") as lock: + nltk.download("punkt", quiet=True) + +logger = logging.getLogger(__name__) + +@dataclass +class ModelArguments: + """ + Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 
+ """ + + model_name_or_path: str = field( + default='t5-base', metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} + ) + config_name: Optional[str] = field( + default='t5-base', metadata={"help": "Pretrained config name or path if not the same as model_name"} + ) + tokenizer_name: Optional[str] = field( + default='t5-base', metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} + ) + cache_dir: Optional[str] = field( + default=None, + metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"}, + ) + use_fast_tokenizer: bool = field( + default=True, + metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, + ) + model_revision: str = field( + default="main", + metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, + ) + use_auth_token: bool = field( + default=False, + metadata={ + "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script " + "with private models)." + }, + ) + + +@dataclass +class DataTrainingArguments: + """ + Arguments pertaining to what data we are going to input our model for training and eval. + """ + #################################################################################### + # Start SparseML Integration + #################################################################################### + nm_prune_config: Optional[str] = field( + default='90sparse.yaml', metadata={"help": "The input file name for the Neural Magic pruning config"} + ) + do_onnx_export: bool = field( + default=False, metadata={"help": "Export model to onnx"} + ) + onnx_export_path: Optional[str] = field( + default='onnx-export', metadata={"help": "The filename and path which will be where onnx model is outputed"} + ) + layers_to_keep: int = field( + default=12, metadata={"help":"How many layers to keep for the model"} + ) + #################################################################################### + # End SparseML Integration + #################################################################################### + dataset_name: Optional[str] = field( + default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} + ) + dataset_config_name: Optional[str] = field( + default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} + ) + text_column: Optional[str] = field( + default=None, + metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."}, + ) + summary_column: Optional[str] = field( + default=None, + metadata={"help": "The name of the column in the datasets containing the summaries (for summarization)."}, + ) + train_file: Optional[str] = field( + default="data/doc_query_train.json", metadata={"help": "The input training data file (a jsonlines or csv file)."} + ) + validation_file: Optional[str] = field( + default='data/doc_query_dev.json', + metadata={ + "help": "An optional input evaluation data file to evaluate the metrics (rouge) on " + "(a jsonlines or csv file)." + }, + ) + test_file: Optional[str] = field( + default=None, + metadata={ + "help": "An optional input test data file to evaluate the metrics (rouge) on " "(a jsonlines or csv file)." 
+ }, + ) + overwrite_cache: bool = field( + default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} + ) + preprocessing_num_workers: Optional[int] = field( + default=None, + metadata={"help": "The number of processes to use for the preprocessing."}, + ) + max_source_length: Optional[int] = field( + default=256, + metadata={ + "help": "The maximum total input sequence length after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded." + }, + ) + max_target_length: Optional[int] = field( + default=32, + metadata={ + "help": "The maximum total sequence length for target text after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded." + }, + ) + val_max_target_length: Optional[int] = field( + default=32, + metadata={ + "help": "The maximum total sequence length for validation target text after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`." + "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used " + "during ``evaluate`` and ``predict``." + }, + ) + pad_to_max_length: bool = field( + default=False, + metadata={ + "help": "Whether to pad all samples to model maximum sentence length. " + "If False, will pad the samples dynamically when batching to the maximum length in the batch. More " + "efficient on GPU but very bad for TPU." + }, + ) + max_train_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + }, + ) + max_eval_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this " + "value if set." + }, + ) + max_predict_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this " + "value if set." + }, + ) + num_beams: Optional[int] = field( + default=None, + metadata={ + "help": "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, " + "which is used during ``evaluate`` and ``predict``." + }, + ) + ignore_pad_token_for_loss: bool = field( + default=True, + metadata={ + "help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not." + }, + ) + source_prefix: Optional[str] = field( + default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."} + ) + + def __post_init__(self): + if self.dataset_name is None and self.train_file is None and self.validation_file is None: + raise ValueError("Need either a dataset name or a training/validation file.") + else: + if self.train_file is not None: + extension = self.train_file.split(".")[-1] + assert extension in ["csv", "json"], "`train_file` should be a csv or a json file." + if self.validation_file is not None: + extension = self.validation_file.split(".")[-1] + assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file." 
+ if self.val_max_target_length is None: + self.val_max_target_length = self.max_target_length + +#################################################################################### +# Start SparseML Integration +#################################################################################### +def load_optimizer(model, args): + no_decay = ["bias", "LayerNorm.weight"] + optimizer_grouped_parameters = [ + { + "params": [ + p + for n, p in model.named_parameters() + if not any(nd in n for nd in no_decay) + ], + "weight_decay": args.weight_decay, + }, + { + "params": [ + p + for n, p in model.named_parameters() + if any(nd in n for nd in no_decay) + ], + "weight_decay": 0.0, + }, + ] + optimizer_cls = AdamW + optimizer_kwargs = { + "betas": (args.adam_beta1, args.adam_beta2), + "eps": args.adam_epsilon, + } + optimizer_kwargs["lr"] = args.learning_rate + return optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs) + +def convert_example_to_features(example, tokenizer, max_seq_length, sentence1_key, sentence2_key): + tokens = [] + segment_ids = [] + tokens.append("[CLS]") + segment_ids.append(0) + for t in tokenizer.tokenize(example[sentence1_key])[:int(max_seq_length/2)]: + tokens.append(t) + segment_ids.append(0) + tokens.append("[SEP]") + segment_ids.append(0) + if sentence1_key != None: + for t in tokenizer.tokenize(example[sentence2_key])[:int(max_seq_length/2)]: + tokens.append(t) + segment_ids.append(0) + tokens.append("[SEP]") + segment_ids.append(1) + input_ids = tokenizer.convert_tokens_to_ids(tokens) + input_mask = [1] * len(input_ids) + while len(input_ids) < max_seq_length: + input_ids.append(0) + input_mask.append(0) + segment_ids.append(0) + return ( + torch.from_numpy(np.array([np.array(input_ids, dtype=np.int64)])), + torch.from_numpy(np.array([np.array(input_mask, dtype=np.int64)])), + torch.from_numpy(np.array([np.array(segment_ids, dtype=np.int64)])), + ) + +def main(): + # See all possible arguments in src/transformers/training_args.py + # or by passing the --help flag to this script. + # We now keep distinct sets of args, for a cleaner separation of concerns. + + parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments)) + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + # If we pass only one argument to the script and it's the path to a json file, + # let's parse it to get our arguments. + model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) + else: + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + + if data_args.source_prefix is None and model_args.model_name_or_path in [ + "t5-small", + "t5-base", + "t5-large", + "t5-3b", + "t5-11b", + ]: + logger.warning( + "You're running a t5 model but didn't provide a source prefix, which is the expected, e.g. with " + "`--source_prefix 'summarize: ' `" + ) + + # Detecting last checkpoint. + last_checkpoint = None + if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: + last_checkpoint = get_last_checkpoint(training_args.output_dir) + if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: + raise ValueError( + f"Output directory ({training_args.output_dir}) already exists and is not empty. " + "Use --overwrite_output_dir to overcome." + ) + elif last_checkpoint is not None and training_args.resume_from_checkpoint is None: + logger.info( + f"Checkpoint detected, resuming training at {last_checkpoint}. 
To avoid this behavior, change " + "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." + ) + + # Setup logging + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + handlers=[logging.StreamHandler(sys.stdout)], + ) + logger.setLevel(logging.INFO) + + # Log on each process the small summary: + logger.warning( + f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" + + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" + ) + logger.info(f"Training/evaluation parameters {training_args}") + + # Set seed before initializing model. + set_seed(training_args.seed) + + # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below) + # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ + # (the dataset will be downloaded automatically from the datasets Hub). + # + # For CSV/JSON files this script will use the first column for the full texts and the second column for the + # summaries (unless you specify column names for this with the `text_column` and `summary_column` arguments). + # + # In distributed training, the load_dataset function guarantee that only one local process can concurrently + # download the dataset. + if data_args.dataset_name is not None: + # Downloading and loading a dataset from the hub. + datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir) + else: + data_files = {} + if data_args.train_file is not None: + data_files["train"] = data_args.train_file + extension = data_args.train_file.split(".")[-1] + if data_args.validation_file is not None: + data_files["validation"] = data_args.validation_file + extension = data_args.validation_file.split(".")[-1] + if data_args.test_file is not None: + data_files["test"] = data_args.test_file + extension = data_args.test_file.split(".")[-1] + datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir) + # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at + # https://huggingface.co/docs/datasets/loading_datasets.html. + + # Load pretrained model and tokenizer + # + # Distributed training: + # The .from_pretrained methods guarantee that only one local process can concurrently + # download model & vocab. 
+ config = AutoConfig.from_pretrained( + model_args.config_name if model_args.config_name else model_args.model_name_or_path, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + tokenizer = AutoTokenizer.from_pretrained( + model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, + cache_dir=model_args.cache_dir, + use_fast=model_args.use_fast_tokenizer, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + model = AutoModelForSeq2SeqLM.from_pretrained( + model_args.model_name_or_path, + from_tf=bool(".ckpt" in model_args.model_name_or_path), + config=config, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + model.resize_token_embeddings(len(tokenizer)) + print(model) + if model.config.decoder_start_token_id is None: + raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined") + + prefix = data_args.source_prefix if data_args.source_prefix is not None else "" + + # Preprocessing the datasets. + # We need to tokenize inputs and targets. + if training_args.do_train: + column_names = datasets["train"].column_names + elif training_args.do_eval: + column_names = datasets["validation"].column_names + elif training_args.do_predict: + column_names = datasets["test"].column_names + else: + logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.") + return + + # Get the column names for input/target. + if data_args.text_column is None: + text_column = dataset_columns[0] if dataset_columns is not None else column_names[0] + else: + text_column = data_args.text_column + if text_column not in column_names: + raise ValueError( + f"--text_column' value '{data_args.text_column}' needs to be one of: {', '.join(column_names)}" + ) + if data_args.summary_column is None: + summary_column = dataset_columns[1] if dataset_columns is not None else column_names[1] + else: + summary_column = data_args.summary_column + if summary_column not in column_names: + raise ValueError( + f"--summary_column' value '{data_args.summary_column}' needs to be one of: {', '.join(column_names)}" + ) + + # Temporarily set max_target_length for training. + max_target_length = data_args.max_target_length + padding = "max_length" if data_args.pad_to_max_length else False + + if training_args.label_smoothing_factor > 0 and not hasattr(model, "prepare_decoder_input_ids_from_labels"): + logger.warning( + "label_smoothing is enabled but the `prepare_decoder_input_ids_from_labels` method is not defined for" + f"`{model.__class__.__name__}`. This will lead to loss being calculated twice and will take up more memory" + ) + + def preprocess_function(examples): + inputs = examples['input'] + targets = examples['target'] + inputs = [prefix + inp for inp in inputs] + model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True) + + # Setup the tokenizer for targets + with tokenizer.as_target_tokenizer(): + labels = tokenizer(targets, max_length=max_target_length, padding=padding, truncation=True) + + # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore + # padding in the loss. 
+ if padding == "max_length" and data_args.ignore_pad_token_for_loss: + labels["input_ids"] = [ + [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"] + ] + + model_inputs["labels"] = labels["input_ids"] + return model_inputs + if training_args.do_train: + if "train" not in datasets: + raise ValueError("--do_train requires a train dataset") + train_dataset = datasets["train"] + if data_args.max_train_samples is not None: + train_dataset = train_dataset.select(range(data_args.max_train_samples)) + train_dataset = train_dataset.map( + preprocess_function, + batched=True, + num_proc=data_args.preprocessing_num_workers, + remove_columns=column_names, + load_from_cache_file=not data_args.overwrite_cache, + ) + if training_args.do_eval: + max_target_length = data_args.val_max_target_length + if "validation" not in datasets: + raise ValueError("--do_eval requires a validation dataset") + eval_dataset = datasets["validation"] + if data_args.max_eval_samples is not None: + eval_dataset = eval_dataset.select(range(data_args.max_eval_samples)) + eval_dataset = eval_dataset.map( + preprocess_function, + batched=True, + num_proc=data_args.preprocessing_num_workers, + remove_columns=column_names, + load_from_cache_file=not data_args.overwrite_cache, + ) + + if training_args.do_predict: + max_target_length = data_args.val_max_target_length + if "test" not in datasets: + raise ValueError("--do_predict requires a test dataset") + predict_dataset = datasets["test"] + if data_args.max_predict_samples is not None: + predict_dataset = predict_dataset.select(range(data_args.max_predict_samples)) + predict_dataset = predict_dataset.map( + preprocess_function, + batched=True, + num_proc=data_args.preprocessing_num_workers, + remove_columns=column_names, + load_from_cache_file=not data_args.overwrite_cache, + ) + # Data collator + label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id + data_collator = DataCollatorForSeq2Seq( + tokenizer, + model=model, + label_pad_token_id=label_pad_token_id, + pad_to_multiple_of=8 if training_args.fp16 else None, + ) + + # Metric + metric = load_metric("rouge") + + def postprocess_text(preds, labels): + preds = [pred.strip() for pred in preds] + labels = [label.strip() for label in labels] + + # rougeLSum expects newline after each sentence + preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds] + labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels] + + return preds, labels + + def compute_metrics(eval_preds): + preds, labels = eval_preds + if isinstance(preds, tuple): + preds = preds[0] + decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True) + if data_args.ignore_pad_token_for_loss: + # Replace -100 in the labels as we can't decode them. 
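+            # -100 is the ignore_index used by the loss, so it never corresponds to a real
+            # token id; mapping it back to pad_token_id lets batch_decode run cleanly.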
+ labels = np.where(labels != -100, labels, tokenizer.pad_token_id) + decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True) + + # Some simple post-processing + decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels) + + result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True) + # Extract a few results from ROUGE + result = {key: value.mid.fmeasure * 100 for key, value in result.items()} + + prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds] + result["gen_len"] = np.mean(prediction_lens) + result = {k: round(v, 4) for k, v in result.items()} + return result + + #################################################################################### + # Start SparseML Integration + #################################################################################### + import pdb +# pdb.set_trace() + if training_args.do_train: + optim = load_optimizer(model, training_args) + steps_per_epoch = math.ceil(len(train_dataset) / (training_args.per_device_train_batch_size*training_args._n_gpu)) + manager = ScheduledModifierManager.from_yaml(data_args.nm_prune_config) + training_args.num_train_epochs = float(manager.max_epochs) + optim = ScheduledOptimizer(optim, model, manager, steps_per_epoch=steps_per_epoch, loggers=None) + + # Initialize our Trainer + trainer = Seq2SeqTrainer( + model=model, + args=training_args, + train_dataset=train_dataset if training_args.do_train else None, + eval_dataset=eval_dataset if training_args.do_eval else None, + tokenizer=tokenizer, + data_collator=data_collator, + compute_metrics=compute_metrics if training_args.predict_with_generate else None, + optimizers=(optim, None) if training_args.do_train else (None, None), + ) + + # Training + if training_args.do_train: + checkpoint = None + if training_args.resume_from_checkpoint is not None: + checkpoint = training_args.resume_from_checkpoint + elif last_checkpoint is not None: + checkpoint = last_checkpoint + train_result = trainer.train(resume_from_checkpoint=checkpoint) + trainer.save_model() # Saves the tokenizer too for easy upload + + metrics = train_result.metrics + max_train_samples = ( + data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset) + ) + metrics["train_samples"] = min(max_train_samples, len(train_dataset)) + + trainer.log_metrics("train", metrics) + trainer.save_metrics("train", metrics) + trainer.save_state() + + # Evaluation + results = {} + if training_args.do_eval: + logger.info("*** Evaluate ***") + + metrics = trainer.evaluate( + max_length=data_args.val_max_target_length, num_beams=data_args.num_beams, metric_key_prefix="eval" + ) + max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) + metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) + + trainer.log_metrics("eval", metrics) + trainer.save_metrics("eval", metrics) + + if training_args.do_predict: + logger.info("*** Predict ***") + + predict_results = trainer.predict( + predict_dataset, + metric_key_prefix="predict", + max_length=data_args.val_max_target_length, + num_beams=data_args.num_beams, + ) + metrics = predict_results.metrics + max_predict_samples = ( + data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset) + ) + metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset)) + + trainer.log_metrics("predict", metrics) + trainer.save_metrics("predict", 
metrics) + + if trainer.is_world_process_zero(): + if training_args.predict_with_generate: + predictions = tokenizer.batch_decode( + predict_results.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True + ) + predictions = [pred.strip() for pred in predictions] + output_prediction_file = os.path.join(training_args.output_dir, "generated_predictions.txt") + with open(output_prediction_file, "w") as writer: + writer.write("\n".join(predictions)) + + if training_args.push_to_hub: + kwargs = {"finetuned_from": model_args.model_name_or_path, "tags": "summarization"} + if data_args.dataset_name is not None: + kwargs["dataset_tags"] = data_args.dataset_name + if data_args.dataset_config_name is not None: + kwargs["dataset_args"] = data_args.dataset_config_name + kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}" + else: + kwargs["dataset"] = data_args.dataset_name + + trainer.push_to_hub(**kwargs) + + #################################################################################### + # Start SparseML Integration + #################################################################################### + if data_args.do_onnx_export: + logger.info("*** Export to ONNX ***") + os.environ["TOKENIZERS_PARALLELISM"] = "false" + exporter = ModuleExporter( + student_model, output_dir=data_args.onnx_export_path + ) + sample_batch = convert_example_to_features( + datasets["train"][0], + tokenizer, + data_args.max_seq_length, + sentence1_key, + sentence2_key, + ) + exporter.export_onnx(sample_batch=sample_batch) + #################################################################################### + # End SparseML Integration + #################################################################################### + return results + +def _mp_fn(index): + # For xla_spawn (TPUs) + main() + +if __name__ == "__main__": + main() diff --git a/research/information_retrieval/doc2query/src/sparseml_utils.py b/research/information_retrieval/doc2query/src/sparseml_utils.py new file mode 100644 index 00000000000..b6f8b6dbb27 --- /dev/null +++ b/research/information_retrieval/doc2query/src/sparseml_utils.py @@ -0,0 +1,121 @@ +import math + +import torch +import torch.nn.functional as F + +from sparseml.pytorch.optim.manager import ScheduledModifierManager +from sparseml.pytorch.optim.optimizer import ScheduledOptimizer +from sparseml.pytorch.utils import ModuleExporter, logger +from trainer_qa import QuestionAnsweringTrainer + + +class SparseMLQATrainer(QuestionAnsweringTrainer): + """ + Question Answering trainer with SparseML integration + + :param recipe: recipe for model sparsification + :param teacher: teacher model for distillation + :param distill_hardness: ratio of loss by teacher targets (between 0 and 1) + :param distill_temperature: temperature for distillation + :param args, kwargs: arguments passed into parent class + """ + + def __init__(self, recipe, teacher=None, distill_hardness=0.5, distill_temperature=2.0, *args, **kwargs): + super().__init__(*args, **kwargs) + self.recipe = recipe + self.teacher = teacher + self.distill_hardness = distill_hardness + self.distill_temperature = distill_temperature + self.criterion = torch.nn.CrossEntropyLoss() + + self.manager = None + self.loggers = None + if self.recipe is not None: + loggers = [] + if "wandb" in self.args.report_to: + loggers.append(logger.WANDBLogger()) + self.loggers = loggers + + def create_optimizer(self): + """ + Create optimizer customized using SparseML + """ + super().create_optimizer() + if 
self.recipe is None: + return + steps_per_epoch = math.ceil( + len(self.train_dataset) / (self.args.per_device_train_batch_size * self.args._n_gpu) + ) + self.manager = ScheduledModifierManager.from_yaml(self.recipe) + self.args.num_train_epochs = float(self.manager.max_epochs) + if hasattr(self, "scaler"): + self.manager.initialize(self.model, epoch=0.0, loggers=self.loggers) + self.scaler = self.manager.modify( + self.model, self.optimizer, steps_per_epoch=steps_per_epoch, wrap_optim=self.scaler + ) + else: + self.optimizer = ScheduledOptimizer( + self.optimizer, self.model, self.manager, steps_per_epoch=steps_per_epoch, loggers=self.loggers + ) + + def compute_loss(self, model, inputs, return_outputs=False): + """ + Computing loss using teacher/student distillation + """ + if self.recipe is None or self.teacher is None: + return super().compute_loss(model, inputs, return_outputs=return_outputs) + + outputs = model(**inputs) + if self.teacher is None: + loss = outputs["loss"] + else: + input_device = inputs["input_ids"].device + self.teacher = self.teacher.to(input_device) + start_logits_student = outputs["start_logits"] + end_logits_student = outputs["end_logits"] + start_logits_label = inputs["start_positions"] + end_logits_label = inputs["end_positions"] + with torch.no_grad(): + teacher_output = self.teacher( + input_ids=inputs["input_ids"], + token_type_ids=inputs["token_type_ids"], + attention_mask=inputs["attention_mask"], + ) + start_logits_teacher = teacher_output["start_logits"] + end_logits_teacher = teacher_output["end_logits"] + loss_start = ( + F.kl_div( + input=F.log_softmax(start_logits_student / self.distill_temperature, dim=-1), + target=F.softmax(start_logits_teacher / self.distill_temperature, dim=-1), + reduction="batchmean", + ) + * (self.distill_temperature ** 2) + ) + loss_end = ( + F.kl_div( + input=F.log_softmax(end_logits_student / self.distill_temperature, dim=-1), + target=F.softmax(end_logits_teacher / self.distill_temperature, dim=-1), + reduction="batchmean", + ) + * (self.distill_temperature ** 2) + ) + teacher_loss = (loss_start + loss_end) / 2.0 + loss_start = self.criterion(start_logits_student, start_logits_label) + loss_end = self.criterion(end_logits_student, end_logits_label) + label_loss = (loss_start + loss_end) / 2.0 + loss = ((1 - self.distill_hardness) * label_loss) + (self.distill_hardness * teacher_loss) + return (loss, outputs) if return_outputs else loss + + +def export_model(model, dataloader, output_dir): + """ + Export a trained model to ONNX + :param model: trained model + :param dataloader: dataloader to get sample batch + :param output_dir: output directory for ONNX model + """ + exporter = ModuleExporter(model, output_dir=output_dir) + for _, sample_batch in enumerate(dataloader): + sample_input = (sample_batch["input_ids"], sample_batch["attention_mask"], sample_batch["token_type_ids"]) + exporter.export_onnx(sample_batch=sample_input, convert_qat=True) + break \ No newline at end of file diff --git a/research/information_retrieval/elastic_integration/README.md b/research/information_retrieval/elastic_integration/README.md new file mode 100644 index 00000000000..26b4528696b --- /dev/null +++ b/research/information_retrieval/elastic_integration/README.md @@ -0,0 +1,37 @@ +# Dense Information Retrieval Integration with Elastic Search +Author: @spacemanidol +This folder contains information on how to leverage sparse Dense Information Retrieval methods with sparseml + +To run any research projects, cd into the desired project's 
directory and install from the requirements.txt file using the following:
+```bash
+pip install -r requirements.txt
+```
+
+If you run into issues, first try creating a fresh virtual environment, activating it, and reinstalling the requirements:
+```bash
+virtualenv -p python3 venv
+source venv/bin/activate
+```
+
+If there are continued issues, contact the author(s) listed at the top of each project's README.
+
+## Setup
+### Elasticsearch
+First, you will need to set up an active Elasticsearch cluster. Instructions for how to do this can be found [here](https://www.elastic.co/guide/en/elasticsearch/reference/current/targz.html); once installed, Elasticsearch can be started as shown below.
+```bash
+./bin/elasticsearch
+```
+## Usage
+See the example at the end of this README for a minimal sketch of the DPR-based scoring that the indexing code builds on.
+### Index Generation
+#### Passage Collection
+#### Document Collection
+
+## Results
+### Index Generation
+Speed breakdowns
+
+### Retrieval comparison
+
+MSMARCO Passage Ranking
+MSMARCO Document Ranking
+NQ
+TriviaQA
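+
+## Example: scoring a query/passage pair with DPR
+The snippet below is a minimal sketch of the dense scoring used by `dense_ranking.py`: both DPR encoders are loaded and a single query/passage pair is scored with the same inner-product similarity that backs the FAISS index. The dense `facebook/dpr-*` checkpoints and the example strings are placeholders only; swap in sparsified encoders as they become available.
+```python
+import torch
+from transformers import (DPRContextEncoder, DPRContextEncoderTokenizer,
+                          DPRQuestionEncoder, DPRQuestionEncoderTokenizer)
+
+# Load the context (passage) and question encoders
+ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
+ctx_model = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base", return_dict=True)
+q_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
+q_model = DPRQuestionEncoder.from_pretrained("facebook/dpr-question_encoder-single-nq-base", return_dict=True)
+
+passage = "Elasticsearch is a distributed search and analytics engine."
+question = "What is Elasticsearch?"
+
+with torch.no_grad():
+    # pooler_output is the fixed-size embedding used for retrieval
+    p_emb = ctx_model(**ctx_tokenizer(passage, return_tensors="pt")).pooler_output
+    q_emb = q_model(**q_tokenizer(question, return_tensors="pt")).pooler_output
+
+# Inner-product relevance score, matching the faiss.IndexFlatIP used in dense_ranking.py
+print(torch.matmul(q_emb, p_emb.T).item())
+```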
diff --git a/research/information_retrieval/elastic_integration/chunker.py b/research/information_retrieval/elastic_integration/chunker.py
new file mode 100644
index 00000000000..134d10b7e44
--- /dev/null
+++ b/research/information_retrieval/elastic_integration/chunker.py
@@ -0,0 +1,36 @@
+# neuralmagic: no copyright
+# flake8: noqa
+# fmt: off
+# isort: skip_file
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import re
+import logging
+from typing import List
+from transformers import DPRReaderTokenizer
+
+class DocumentChunker:
+    def __init__(self, tokenizer, max_tokens=512, max_query_tokens=30, document_chunks=5):
+        self.tokenizer = tokenizer  # e.g. DPRReaderTokenizer.from_pretrained('facebook/dpr-reader-single-nq-base')
+        self.max_tokens = max_tokens
+        self.document_chunks = document_chunks
+        self.max_query_tokens = max_query_tokens
+        self.max_doc_tokens = self.max_tokens - self.max_query_tokens - 2  # [SEP] and [CLS]
+
+    def chunk_doc(self, document):
+        # This is where the logic to split documents goes. We keep the first chunk, the last chunk, and a classified best chunk.
+        chunks = []
+        return chunks
diff --git a/research/information_retrieval/elastic_integration/dense_document.py b/research/information_retrieval/elastic_integration/dense_document.py
new file mode 100644
index 00000000000..4d021c8f91d
--- /dev/null
+++ b/research/information_retrieval/elastic_integration/dense_document.py
@@ -0,0 +1,36 @@
+# neuralmagic: no copyright
+# flake8: noqa
+# fmt: off
+# isort: skip_file
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import json
+
+class DenseDocument:
+    title: str
+    body: str
+    chunks: list
+
+    def __init__(self, title: str, body: str, chunks: list):
+        self.title = title
+        self.body = body
+        self.chunks = chunks
+
+    def to_dict(self):
+        return {'title': self.title, 'body': self.body, 'chunks': self.chunks}
+
+    def __repr__(self):
+        pretty = json.dumps(self.to_dict())
+        return pretty
diff --git a/research/information_retrieval/elastic_integration/dense_ranking.py b/research/information_retrieval/elastic_integration/dense_ranking.py
new file mode 100644
index 00000000000..10a3ee67d5c
--- /dev/null
+++ b/research/information_retrieval/elastic_integration/dense_ranking.py
@@ -0,0 +1,122 @@
+# neuralmagic: no copyright
+# flake8: noqa
+# fmt: off
+# isort: skip_file
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
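+"""
+Dense, sparse, and hybrid passage ranking utilities.
+
+DenseIndex builds two indexes over a document collection: a BM25 index stored
+in Elasticsearch and a FAISS inner-product index over DPR context-encoder
+embeddings. dense_search queries the FAISS index, sparse_search queries
+Elasticsearch, and hybrid_search merges the two result sets.
+"""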
+import re
+import logging
+import torch
+import faiss
+from tqdm import tqdm
+from typing import List
+from elasticsearch import Elasticsearch
+from transformers import (DPRContextEncoder, DPRContextEncoderTokenizer,
+                          DPRQuestionEncoder, DPRQuestionEncoderTokenizer)
+from chunker import DocumentChunker
+from dense_document import DenseDocument
+
+
+class DenseIndex():
+    def __init__(self, documents, context_tokenizer, context_model, query_tokenizer, query_model, index_name='dense-index', dimension=768):
+        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
+        self.index_name = index_name
+        self.dimension = dimension
+        self.context_tokenizer = context_tokenizer
+        self.context_model = context_model
+        self.query_tokenizer = query_tokenizer
+        self.query_model = query_model
+        self.chunker = DocumentChunker(context_tokenizer)
+        self.faiss_index = faiss.IndexFlatIP(self.dimension)
+        self.prep_index()
+        self.generate_index(documents)
+
+    def prep_index(self):
+        self.es = Elasticsearch()
+        if self.es.indices.exists(self.index_name):
+            logging.warning(f'Deleting old index for {self.index_name}.')
+            self.es.indices.delete(self.index_name)
+        self.es.indices.create(index=self.index_name)
+
+    def generate_index(self, documents, max_passages: int = 5):  # Works for passages because the passages dataset will only have 1 chunk
+        self.documents = documents
+        self.doc_bodies = [doc.body for doc in self.documents]
+        self.passage_to_doc = {}  # passage_id to doc_id
+        self.passages = []
+        doc_id = 0
+        passage_id = 0
+        for doc_counter, doc_body in tqdm(enumerate(self.doc_bodies), total=len(self.doc_bodies)):
+            self.es.create(self.index_name, id=doc_id, body={'document': doc_body})
+            passages = self.chunker.chunk_doc(doc_body)
+            for i in range(min(len(passages), max_passages)):  # NEED to add a chunking strategy: first P, last P, best P
+                self.passages.append(passages[i])
+                input_ids = self.context_tokenizer(passages[i], return_tensors='pt')['input_ids']
+                self.faiss_index.add(self.context_model(input_ids).pooler_output.detach().numpy())
+                self.passage_to_doc[passage_id] = doc_id
+                passage_id += 1
+            doc_id += 1
+
+    def dense_search(self, query: str, k: int = 10):
+        input_ids = self.query_tokenizer(query, return_tensors='pt')['input_ids']
+        vec_dists, vec_ids = self.faiss_index.search(self.query_model(input_ids).pooler_output.detach().numpy(), k=k)
+        vec_dists, vec_ids = list(vec_dists[0]), list(vec_ids[0])
+        vec_dists = list(map(float, vec_dists))
+        results = []
+        for dist, passage_id in zip(vec_dists, vec_ids):
+            document_id = self.passage_to_doc[passage_id]
+            result = {
+                'document': self.documents[document_id],
+                'document_id': document_id,
+                'passage': self.passages[passage_id],
+                'passage_id': int(passage_id),
+                'faiss_dist': dist
+            }
+            results.append(result)
+        return results
+
+    def sparse_search(self, query: str, k: int = 10):
+        body = {
+            'size': k,
+            'query': {
+                'match': {
+                    'document': query  # match against the 'document' field written in generate_index
+                }
+            }
+        }
+        results = self.es.search(index=self.index_name, body=body)
+        hits = results['hits']['hits']
+        return hits
+
+    def hybrid_search(self, query: str, k: int = 10, dense_weight: float = 0.5):
+        results_index = {}
+        for sparse_result in self.sparse_search(query):
+            id, score = sparse_result['_id'], sparse_result['_score']
+            id = int(id)
+            results_index[id] = {'elastic_score': score}
+        for dense_result in self.dense_search(query):
+            id, score = dense_result['passage_id'], dense_result['faiss_dist']
+            if id in results_index:
+                results_index[id]['faiss_dist'] = score
+            else:
+                results_index[id] = {'faiss_dist': score, 'elastic_score': 0}
+        results = []
+        for passage_id, scores in results_index.items():
+            document_id = self.passage_to_doc[passage_id]
+            document = self.documents[document_id]
+            doc_profile = document.to_dict()
+            result = {
+                'document': doc_profile,
+                'document_id': document_id,
+                'passage': self.passages[passage_id],
+                'passage_id': int(passage_id),
+                'scores': scores
+            }
+            results.append(result)
+        return results
diff --git a/research/information_retrieval/elastic_integration/requirements.txt b/research/information_retrieval/elastic_integration/requirements.txt
new file mode 100644
index 00000000000..e95495fa482
--- /dev/null
+++ b/research/information_retrieval/elastic_integration/requirements.txt
@@ -0,0 +1,8 @@
+transformers
+torch
+faiss
+tqdm
+elasticsearch
+streamlit
+requests
+sparseml
\ No newline at end of file
diff --git a/research/information_retrieval/elastic_integration/run_ranker.py b/research/information_retrieval/elastic_integration/run_ranker.py
new file mode 100644
index 00000000000..89af103cf99
--- /dev/null
+++ b/research/information_retrieval/elastic_integration/run_ranker.py
@@ -0,0 +1,40 @@
+# neuralmagic: no copyright
+# flake8: noqa
+# fmt: off
+# isort: skip_file
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Example script for integrating sparse ranking models with Elasticsearch.
+##########
+Command help:
+usage: run_ranker.py [-h] \
+
+##########
+Example command for generating a dense, Elasticsearch-compatible index using a sparse bi-encoder model:
+python research/information_retrieval/run_ranker.py \
+
+"""
+from transformers import (DPRContextEncoder, DPRContextEncoderTokenizer,
+                          DPRQuestionEncoder, DPRQuestionEncoderTokenizer)
+
+
+def main():
+    context_tokenizer = DPRContextEncoderTokenizer.from_pretrained(
+        'facebook/dpr-ctx_encoder-single-nq-base')
+    context_model = DPRContextEncoder.from_pretrained(
+        'facebook/dpr-ctx_encoder-single-nq-base', return_dict=True)
+    question_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
+        'facebook/dpr-question_encoder-single-nq-base')
+    question_model = DPRQuestionEncoder.from_pretrained(
+        'facebook/dpr-question_encoder-single-nq-base', return_dict=True)
+    # The encoders above are loaded but not yet wired into an index; see dense_ranking.py for DenseIndex.
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/research/optimal_lobotomizing/README.md b/research/optimal_lobotomizing/README.md
new file mode 100644
index 00000000000..d7fd400ae77
--- /dev/null
+++ b/research/optimal_lobotomizing/README.md
@@ -0,0 +1,17 @@
+# Optimal Lobotomizing: Exploring the effects of model compression on memorization in language models
+Author: @spacemanidol
+
+Language models have proven to be incredibly effective methods for language understanding and generation. Because they are trained on massive textual datasets, they can also memorize portions of that data; this project explores how model compression affects such memorization.
+
+### Method
+
+### Prep and Data Gen
+1. Find datasets that focus on memorization for decoder and encoder models (e.g., GPT-Neo)
+### Experiments
+1. Train models
+2. Prune attention heads
+3. Prune layers
+4. Apply unstructured pruning
+
+
+
diff --git a/research/optimal_lobotomizing/data/init.txt b/research/optimal_lobotomizing/data/init.txt
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/research/optimal_lobotomizing/scripts/init.sh b/research/optimal_lobotomizing/scripts/init.sh
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/research/optimal_lobotomizing/src/init.py b/research/optimal_lobotomizing/src/init.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/src/sparseml/pytorch/optim/modifier_pruning.py b/src/sparseml/pytorch/optim/modifier_pruning.py
index 0b1bd58684e..6fb55a6467c 100644
--- a/src/sparseml/pytorch/optim/modifier_pruning.py
+++ b/src/sparseml/pytorch/optim/modifier_pruning.py
@@ -493,6 +493,9 @@ class GMPruningModifier(_PruningParamsModifier):
         immediately after or doing some other prune
     :param inter_func: the type of interpolation function to use:
         [linear, cubic, inverse_cubic]
+    :param phased: True to enable a phased approach where pruning will
+        turn on and off with the update_frequency. Starts with pruning on
+        at start_epoch, off at start_epoch + update_frequency, and so on.
     :param log_types: The loggers to allow the learning rate to be logged to,
         default is __ALL__
     :param mask_type: String to define type of sparsity (options: ['unstructured',
@@ -514,6 +517,7 @@ def __init__(
         params: Union[str, List[str]],
         leave_enabled: bool = True,
         inter_func: str = "cubic",
+        phased: bool = False,
         log_types: Union[str, List[str]] = ALL_TOKEN,
         mask_type: Union[str, List[int], PruningMaskCreator] = "unstructured",
         global_sparsity: bool = False,
@@ -531,6 +535,7 @@ def __init__(
         self._final_sparsity = final_sparsity
         self._leave_enabled = convert_to_bool(leave_enabled)
         self._inter_func = inter_func
+        self._phased = phased
         self._mask_type = mask_type
         self._mask_creator = (
             mask_type
@@ -612,6 +617,24 @@ def inter_func(self, value: str):
         self._inter_func = value
         self.validate()
 
+    @ModifierProp()
+    def phased(self) -> bool:
+        """
+        :return: True to enable a phased approach where pruning will
+            turn on and off with the update_frequency. Starts with pruning on
+            at start_epoch, off at start_epoch + update_frequency, and so on.
+        """
+        return self._phased
+
+    @phased.setter
+    def phased(self, value: bool):
+        """
+        :param value: True to enable a phased approach where pruning will
+            turn on and off with the update_frequency
+        """
+        self._phased = value
+        self.validate()
+
     @ModifierProp()
     def mask_type(self) -> Union[str, List[int], PruningMaskCreator]:
         """
@@ -763,6 +786,16 @@ def _check_mask_update(self, module: Module, epoch: float, steps_per_epoch: int)
             self._final_sparsity,
             self._inter_func,
         )
+
+        # when phased pruning is enabled, pruning only applies during even
+        # update phases; during odd phases the applied sparsity is reset to 0.0
+        if self.phased and not self.end_pending(epoch, steps_per_epoch):
+            # adjust for phased pruning: start=on, start+update=off
+            phase = math.floor((epoch - self.start_epoch) / self.update_frequency)
+            if phase % 2 != 0:
+                # odd update phase, turn sparsity off
+                self._applied_sparsity = 0.0
+
         self._module_masks.set_param_masks_from_sparsity(self._applied_sparsity)
 
         if self.end_pending(epoch, steps_per_epoch):
@@ -843,6 +876,9 @@ class MagnitudePruningModifier(GMPruningModifier):
         immediately after or doing some other prune
     :param inter_func: the type of interpolation function to use:
         [linear, cubic, inverse_cubic]
+    :param phased: True to enable a phased approach where pruning will
+        turn on and off with the update_frequency.
Starts with pruning on + at start_epoch, off at start_epoch + update_frequency, and so on. :param log_types: The loggers to allow the learning rate to be logged to, default is __ALL__ :param mask_type: String to define type of sparsity (options: ['unstructured', @@ -860,6 +896,7 @@ def __init__( params: Union[str, List[str]], leave_enabled: bool = True, inter_func: str = "cubic", + phased: bool = False, log_types: Union[str, List[str]] = ALL_TOKEN, mask_type: Union[str, List[int], PruningMaskCreator] = "unstructured", ): @@ -872,6 +909,7 @@ def __init__( params=params, leave_enabled=leave_enabled, inter_func=inter_func, + phased=phased, log_types=log_types, mask_type=mask_type, global_sparsity=False, @@ -933,6 +971,9 @@ class MovementPruningModifier(GMPruningModifier): immediately after or doing some other prune :param inter_func: the type of interpolation function to use: [linear, cubic, inverse_cubic] + :param phased: True to enable a phased approach where pruning will + turn on and off with the update_frequency. Starts with pruning on + at start_epoch, off at start_epoch + update_frequency, and so on. :param log_types: The loggers to allow the learning rate to be logged to, default is __ALL__ :param mask_type: String to define type of sparsity (options: ['unstructured', @@ -950,6 +991,7 @@ def __init__( params: Union[str, List[str]], leave_enabled: bool = True, inter_func: str = "cubic", + phased: bool = False, log_types: Union[str, List[str]] = ALL_TOKEN, mask_type: Union[str, List[int], PruningMaskCreator] = "unstructured", ): @@ -962,6 +1004,7 @@ def __init__( params=params, leave_enabled=leave_enabled, inter_func=inter_func, + phased=phased, log_types=log_types, mask_type=mask_type, global_sparsity=False, @@ -1024,6 +1067,9 @@ class GlobalMagnitudePruningModifier(GMPruningModifier): immediately after or doing some other prune :param inter_func: the type of interpolation function to use: [linear, cubic, inverse_cubic] + :param phased: True to enable a phased approach where pruning will + turn on and off with the update_frequency. Starts with pruning on + at start_epoch, off at start_epoch + update_frequency, and so on. :param log_types: The loggers to allow the learning rate to be logged to, default is __ALL__ :param mask_type: String to define type of sparsity (options: ['unstructured', @@ -1043,6 +1089,7 @@ def __init__( params: Union[str, List[str]] = ALL_PRUNABLE_TOKEN, leave_enabled: bool = True, inter_func: str = "cubic", + phased: bool = False, log_types: Union[str, List[str]] = ALL_TOKEN, mask_type: Union[str, List[int], PruningMaskCreator] = "unstructured", score_type: Union[str, MFACOptions] = "magnitude", @@ -1056,6 +1103,7 @@ def __init__( params=params, leave_enabled=leave_enabled, inter_func=inter_func, + phased=phased, log_types=log_types, mask_type=mask_type, global_sparsity=True, @@ -1115,6 +1163,9 @@ class MFACPruningModifier(GMPruningModifier): immediately after or doing some other prune :param inter_func: the type of interpolation function to use: [linear, cubic, inverse_cubic] + :param phased: True to enable a phased approach where pruning will + turn on and off with the update_frequency. Starts with pruning on + at start_epoch, off at start_epoch + update_frequency, and so on. 
:param log_types: The loggers to allow the learning rate to be logged to, default is __ALL__ :param mask_type: String to define type of sparsity (options: ['unstructured', @@ -1139,6 +1190,7 @@ def __init__( params: Union[str, List[str]], leave_enabled: bool = True, inter_func: str = "cubic", + phased: bool = False, log_types: Union[str, List[str]] = ALL_TOKEN, mask_type: Union[str, List[int], PruningMaskCreator] = "unstructured", mfac_options: Dict[str, Any] = None, @@ -1152,6 +1204,7 @@ def __init__( params=params, leave_enabled=leave_enabled, inter_func=inter_func, + phased=phased, log_types=log_types, mask_type=mask_type, global_sparsity=True, diff --git a/src/sparseml/version.py b/src/sparseml/version.py index 1bf8a9d0f39..4d71deafab5 100644 --- a/src/sparseml/version.py +++ b/src/sparseml/version.py @@ -19,7 +19,7 @@ from datetime import date -version_base = "0.5.0" +version_base = "0.5.1" is_release = False # change to True to set the generated version as a release version diff --git a/tests/sparseml/pytorch/optim/test_modifier_pruning.py b/tests/sparseml/pytorch/optim/test_modifier_pruning.py index 9ef2c98614a..9b02d1c651e 100644 --- a/tests/sparseml/pytorch/optim/test_modifier_pruning.py +++ b/tests/sparseml/pytorch/optim/test_modifier_pruning.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import math import os import pytest @@ -246,6 +247,16 @@ def test_constant_pruning_yaml(): inter_func="cubic", mask_type=[1, 4], ), + lambda: GMPruningModifier( + params=["__ALL_PRUNABLE__"], + init_sparsity=0.9, + final_sparsity=0.9, + start_epoch=10.0, + end_epoch=25.0, + update_frequency=2.0, + inter_func="cubic", + phased=True, + ), ], scope="function", ) @@ -294,7 +305,22 @@ def test_lifecycle( epoch += modifier.update_frequency assert modifier.update_ready(epoch, test_steps_per_epoch) modifier.scheduled_update(model, optimizer, epoch, test_steps_per_epoch) - assert modifier.applied_sparsity > last_sparsity + + if not modifier.phased: + assert modifier.applied_sparsity > last_sparsity + else: + pruned_on = ( + math.floor( + (epoch - modifier.start_epoch) / modifier.update_frequency + ) + % 2 + == 0 + ) + if pruned_on: + assert modifier.applied_sparsity >= last_sparsity + else: + assert modifier.applied_sparsity == 0 + last_sparsity = modifier.applied_sparsity _ = model(test_batch) # check forward pass
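
For reference, a minimal recipe exercising the new `phased` flag might look like the sketch below; the modifier layout follows SparseML's recipe format, but the hyperparameter values are illustrative only.

```yaml
modifiers:
    - !GMPruningModifier
        params: __ALL_PRUNABLE__
        init_sparsity: 0.05
        final_sparsity: 0.85
        start_epoch: 1.0
        end_epoch: 20.0
        update_frequency: 1.0
        inter_func: cubic
        mask_type: unstructured
        phased: True
```

With `phased: True` and these values, pruning is applied during even update phases ([1.0, 2.0), [3.0, 4.0), ...) and the applied sparsity is reset to 0.0 during odd phases, matching the `_check_mask_update` logic added above.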