From 2831a27f40be21ac54e8a74a456a5ca1ea351e67 Mon Sep 17 00:00:00 2001
From: Matt Gardner
Date: Thu, 27 Feb 2020 08:06:49 -0800
Subject: [PATCH] initial test, passing

---
 allennlp/commands/train.py            | 92 +++++++++++++++++++++++++++
 allennlp/tests/commands/train_test.py | 21 ++++++
 2 files changed, 113 insertions(+)

diff --git a/allennlp/commands/train.py b/allennlp/commands/train.py
index bd6a772065c..e62c22dc9aa 100644
--- a/allennlp/commands/train.py
+++ b/allennlp/commands/train.py
@@ -698,3 +698,95 @@ def from_partial_objects(
 
 
 TrainModel.register("default", constructor="from_partial_objects")(TrainModel)
+
+
+def train_with_lightning(
+    params: Params
+):
+    from copy import deepcopy
+    dataset_reader = DatasetReader.from_params(params['dataset_reader'])
+    datasets = training_util.read_all_datasets(
+        train_data_path=params['train_data_path'],
+        dataset_reader=dataset_reader,
+        validation_dataset_reader=dataset_reader,
+        validation_data_path=params['validation_data_path'],
+        test_data_path=params.get('test_data_path'),
+    )
+
+    datasets_for_vocab_creation = params.get('datasets_for_vocab_creation')
+    if datasets_for_vocab_creation:
+        for key in datasets_for_vocab_creation:
+            if key not in datasets:
+                raise ConfigurationError(f"invalid 'dataset_for_vocab_creation' {key}")
+
+    instance_generator = (
+        instance
+        for key, dataset in datasets.items()
+        if not datasets_for_vocab_creation or key in datasets_for_vocab_creation
+        for instance in dataset
+    )
+
+    vocabulary = Vocabulary.from_params(
+        params=params.get('vocabulary', Params({})),
+        instances=instance_generator
+    )
+    model = Model.from_params(params=params['model'], vocab=vocabulary)
+
+    for dataset in datasets.values():
+        dataset.index_with(model.vocab)
+
+    data_loader = DataLoader.from_params(
+        params=deepcopy(params['data_loader']),
+        dataset=datasets["train"]
+    )
+    validation_data = datasets.get("validation")
+    if validation_data is not None:
+        validation_data_loader = DataLoader.from_params(
+            params=deepcopy(params.get('validation_data_loader', params['data_loader'])),
+            dataset=validation_data
+        )
+    else:
+        validation_data_loader = None
+
+    test_data = datasets.get("test")
+    if test_data is not None:
+        test_data_loader = DataLoader.from_params(
+            params=deepcopy(params.get('validation_data_loader', params['data_loader'])),
+            dataset=test_data
+        )
+    else:
+        test_data_loader = None
+
+
+    import pytorch_lightning
+    class LightningModule(pytorch_lightning.LightningModule):
+        def forward(self, **kwargs):
+            return model(**kwargs)
+
+        def training_step(self, batch, batch_idx):
+            print(f"\n\nBATCH: {batch}\n\n")
+            # log needs to be separated out here, but presumably your model code can do that.
+            return {'loss': model(**batch)['loss']}
+
+        def validation_step(self, batch, batch_idx):
+            return {'val_loss': model(**batch)['loss']}
+
+        def configure_optimizers(self):
+            from allennlp.training.optimizers import Optimizer
+            parameters = [[n, p] for n, p in model.named_parameters() if p.requires_grad]
+            return Optimizer.from_params(
+                params=params['trainer']['optimizer'],
+                model_parameters=parameters
+            )
+
+        @pytorch_lightning.data_loader
+        def train_dataloader(self):
+            return data_loader
+
+        @pytorch_lightning.data_loader
+        def val_dataloader(self):
+            return validation_data_loader
+
+    module = LightningModule()
+    trainer = pytorch_lightning.Trainer()
+    trainer.fit(module)
diff --git a/allennlp/tests/commands/train_test.py b/allennlp/tests/commands/train_test.py
index 6f4dea1ea94..cb6bf3878b3 100644
--- a/allennlp/tests/commands/train_test.py
+++ b/allennlp/tests/commands/train_test.py
@@ -9,6 +9,7 @@
 import pytest
 import torch
 
+from allennlp.commands import train
 from allennlp.commands.train import Train, train_model, train_model_from_args, TrainModel
 from allennlp.common import Params
 from allennlp.common.checks import ConfigurationError
@@ -83,6 +84,26 @@ def test_train_model(self):
             recover=True,
         )
 
+    def test_train_model_with_lightning(self):
+        params = lambda: Params(
+            {
+                "model": {
+                    "type": "simple_tagger",
+                    "text_field_embedder": {
+                        "token_embedders": {"tokens": {"type": "embedding", "embedding_dim": 5}}
+                    },
+                    "encoder": {"type": "lstm", "input_size": 5, "hidden_size": 7, "num_layers": 2},
+                },
+                "dataset_reader": {"type": "sequence_tagging"},
+                "train_data_path": SEQUENCE_TAGGING_DATA_PATH,
+                "validation_data_path": SEQUENCE_TAGGING_DATA_PATH,
+                "data_loader": {"batch_size": 2},
+                "trainer": {"num_epochs": 2, "optimizer": "adam"},
+            }
+        )
+
+        train.train_with_lightning(params())
+
     @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Need multiple GPUs.")
     def test_train_model_distributed(self):
         params = lambda: Params(
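
For reference, a minimal sketch (not part of the patch) of how the new entry point might be invoked outside the test suite, assuming a standard AllenNLP experiment configuration on disk; the config path below is hypothetical:

    # Load an AllenNLP config from a file and hand it to the Lightning-based entry point.
    # "experiments/simple_tagger.jsonnet" is a hypothetical path, not one from the patch.
    from allennlp.common import Params
    from allennlp.commands.train import train_with_lightning

    params = Params.from_file("experiments/simple_tagger.jsonnet")
    train_with_lightning(params)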