# infersent

This is a TensorFlow implementation of InferSent. For simplicity, this project implements only the BiLSTM with max-pooling model, based on *Supervised Learning of Universal Sentence Representations from Natural Language Inference Data*.
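The core idea of the encoder can be sketched without TensorFlow: a BiLSTM produces one hidden vector per token, and the sentence embedding is the element-wise maximum over the time axis. The NumPy snippet below is a hypothetical illustration of that pooling step only (the BiLSTM outputs are faked with random values; names and shapes are assumptions, not this repo's API):

```python
import numpy as np

def max_pool_sentence(hidden_states):
    """hidden_states: (seq_len, 2 * lstm_dim) BiLSTM outputs for one sentence.
    Returns a fixed-size sentence embedding via element-wise max over time."""
    return hidden_states.max(axis=0)

# Fake BiLSTM outputs: 12 tokens, forward+backward states of dim 4 each.
rng = np.random.default_rng(0)
states = rng.normal(size=(12, 8))
embedding = max_pool_sentence(states)
assert embedding.shape == (8,)  # sentence vector size is independent of seq_len
```

Because the pooling is over the time axis, sentences of any length map to a vector of the same size, which is what lets the classifier operate on fixed-size inputs.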

## How to train the model?

1. Download the SNLI dataset and GloVe vectors trained on Common Crawl 840B with 300 dimensions.

2. Pre-process the SNLI data, create the training and validation datasets, and store them in TFRecord files:

   ```shell
   python preprocess_dataset.py \
     --glove_file /path/to/your/glove \
     --input_files /path/to/your/snli/snli_1.0_train.jsonl,/path/to/your/snli/snli_1.0_dev.jsonl,/path/to/your/snli/snli_1.0_test.jsonl \
     --output_dir /path/to/save/tfrecords
   ```

3. Train the InferModel:

   ```shell
   python train.py \
     --glove_file /path/to/your/glove \
     --input_train_file_pattern "/path/to/save/tfrecords/train-?????-of-?????" \
     --input_valid_file_pattern "/path/to/save/tfrecords/valid-?????-of-?????" \
     --train_dir /path/to/save/checkpoints
   ```
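The pre-processing step above reads the SNLI `.jsonl` files. Each line is a JSON object with a premise (`sentence1`), a hypothesis (`sentence2`), and a `gold_label`; examples whose gold label is `-` (no annotator consensus) are conventionally dropped. A minimal sketch of that parsing, assuming the field names of the official SNLI release (the helper name and label mapping here are illustrative, not this repo's code):

```python
import json

# Standard 3-way NLI label mapping (assumed; check configurations.py for
# the mapping this project actually uses).
LABELS = {"entailment": 0, "neutral": 1, "contradiction": 2}

def parse_snli_line(line):
    """Parse one SNLI .jsonl line; return (premise, hypothesis, label_id)
    or None when the example has no gold-label consensus ("-")."""
    ex = json.loads(line)
    if ex["gold_label"] not in LABELS:
        return None
    return ex["sentence1"], ex["sentence2"], LABELS[ex["gold_label"]]

sample = '{"gold_label": "entailment", "sentence1": "A man eats.", "sentence2": "Someone eats."}'
assert parse_snli_line(sample) == ("A man eats.", "Someone eats.", 0)
```

The `train-?????-of-?????` patterns passed to `train.py` are glob patterns matching the sharded TFRecord files written by the pre-processing script.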

## Experiment result

With the default settings in configurations.py, I obtained a dev accuracy of 83.72% at epoch 5 and 83.37% at epoch 10.

## Tips for fine-tuning

Don't apply too much dropout. When the classifier and encoder dropout rates are both set to 0.5, both train and dev accuracy drop below 70%.
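For intuition on why a 0.5 rate in both places is aggressive, here is a minimal NumPy sketch of inverted dropout, the standard scheme (this is a generic illustration, not this repo's TensorFlow code): each activation is zeroed with probability `rate` during training and survivors are rescaled by `1/(1-rate)`, so stacking two 0.5-rate layers silences roughly three quarters of the signal path on average.

```python
import numpy as np

def dropout(x, rate, rng, training=True):
    """Inverted dropout: zero activations with probability `rate` at train
    time and rescale survivors; identity at eval time."""
    if not training or rate == 0.0:
        return x
    mask = rng.random(x.shape) >= rate
    return x * mask / (1.0 - rate)

rng = np.random.default_rng(0)
x = np.ones((4, 5))
y = dropout(x, rate=0.5, rng=rng)
# Surviving activations are rescaled to 2.0; the rest are zeroed.
assert set(np.unique(y)).issubset({0.0, 2.0})
assert np.array_equal(dropout(x, rate=0.5, rng=rng, training=False), x)
```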