Pin PTL, bump omegaconf (#1049)
* Pin PTL, bump omegaconf

Signed-off-by: smajumdar <titu1994@gmail.com>

* Patch config preservation

Signed-off-by: smajumdar <titu1994@gmail.com>

* Patch config preservation

Signed-off-by: smajumdar <titu1994@gmail.com>

* Patch config preservation

Signed-off-by: smajumdar <titu1994@gmail.com>

* Correct preservation of decoder config after update

Signed-off-by: smajumdar <titu1994@gmail.com>

* Bump hydra-core

Signed-off-by: smajumdar <titu1994@gmail.com>

* Hyper parameter saving patch

Signed-off-by: smajumdar <titu1994@gmail.com>
titu1994 authored Aug 20, 2020
1 parent cf0d554 commit fcc1d99
Showing 5 changed files with 21 additions and 11 deletions.
7 changes: 5 additions & 2 deletions nemo/collections/asr/models/classification_models.py
@@ -16,7 +16,7 @@
from typing import Dict, List, Optional, Union

import torch
- from omegaconf import DictConfig, ListConfig
+ from omegaconf import DictConfig, ListConfig, OmegaConf
from pytorch_lightning import Trainer

from nemo.collections.asr.data.audio_to_text import AudioLabelDataset
@@ -246,7 +246,10 @@ def change_labels(self, new_labels: List[str]):

# Update config
self._cfg.labels = new_labels
- self._cfg.decoder.params = new_decoder_config
+
+ OmegaConf.set_struct(self._cfg.decoder, False)
+ self._cfg.decoder = new_decoder_config
+ OmegaConf.set_struct(self._cfg.decoder, True)

if 'train_ds' in self._cfg and self._cfg.train_ds is not None:
self._cfg.train_ds.labels = new_labels
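
The hunk above replaces the whole `decoder` node instead of only `decoder.params`, briefly dropping OmegaConf struct mode so the node can be swapped wholesale, then re-enabling it. A minimal, self-contained sketch of that pattern, using a made-up config in place of a NeMo model's `self._cfg` (the key names below are illustrative, not NeMo's actual schema):

from omegaconf import OmegaConf

# Made-up model config, standing in for self._cfg in the diff above.
cfg = OmegaConf.create({"labels": ["a", "b"], "decoder": {"params": {"num_classes": 2}}})
OmegaConf.set_struct(cfg, True)  # struct mode: unknown keys raise instead of being added silently

new_decoder_config = OmegaConf.create({"params": {"num_classes": 3}, "feat_in": 1024})

# Relax struct mode on the node being replaced, assign the new sub-config,
# then restore struct mode so later typos still raise.
OmegaConf.set_struct(cfg.decoder, False)
cfg.decoder = new_decoder_config
OmegaConf.set_struct(cfg.decoder, True)

print(OmegaConf.to_yaml(cfg))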
9 changes: 6 additions & 3 deletions nemo/collections/asr/models/ctc_bpe_models.py
@@ -256,17 +256,20 @@ def change_vocabulary(self, new_tokenizer_dir: str, new_tokenizer_type: str):
# Override number of classes if placeholder provided
logging.info(
"\nReplacing old number of classes ({}) with new number of classes - {}".format(
- decoder_config.params['num_classes'], len(vocabulary)
+ decoder_config['params']['num_classes'], len(vocabulary)
)
)
- decoder_config.params['num_classes'] = len(vocabulary)
+ decoder_config['params']['num_classes'] = len(vocabulary)

del self.decoder
self.decoder = EncDecCTCModelBPE.from_config_dict(decoder_config)
self._wer = WERBPE(tokenizer=self.tokenizer, batch_dim_index=0, use_cer=False, ctc_decode=True)

# Update config
- self._cfg.decoder.params = decoder_config
+ OmegaConf.set_struct(self._cfg.decoder, False)
+ self._cfg.decoder = decoder_config
+ OmegaConf.set_struct(self._cfg.decoder, True)

logging.info(f"Changed tokenizer to {self.decoder.vocabulary} vocabulary.")


7 changes: 5 additions & 2 deletions nemo/collections/asr/models/ctc_models.py
@@ -19,7 +19,7 @@
from typing import Dict, List, Optional, Union

import torch
- from omegaconf import DictConfig
+ from omegaconf import DictConfig, OmegaConf
from pytorch_lightning import Trainer

from nemo.collections.asr.data.audio_to_text import AudioToCharDataset, TarredAudioToCharDataset
@@ -154,7 +154,10 @@ def change_vocabulary(self, new_vocabulary: List[str]):
self._wer = WER(vocabulary=self.decoder.vocabulary, batch_dim_index=0, use_cer=False, ctc_decode=True)

# Update config
- self._cfg.decoder.params = new_decoder_config
+ OmegaConf.set_struct(self._cfg.decoder, False)
+ self._cfg.decoder = new_decoder_config
+ OmegaConf.set_struct(self._cfg.decoder, True)

logging.info(f"Changed decoder to output to {self.decoder.vocabulary} vocabulary.")

def _setup_dataloader_from_config(self, config: Optional[Dict]):
3 changes: 2 additions & 1 deletion nemo/core/classes/modelPT.py
@@ -78,7 +78,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):

self._cfg = config

- self.save_hyperparameters(self._cfg)
+ self.save_hyperparameters(OmegaConf.to_container(self._cfg, resolve=True))
self._train_dl = None
self._validation_dl = None
self._test_dl = None
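
The `save_hyperparameters` change hands Lightning a plain, resolved dictionary instead of a live `DictConfig`, so the hparams stored in the checkpoint no longer carry OmegaConf objects. A standalone sketch of just the conversion step, with a hypothetical config standing in for `self._cfg` (the Lightning call is shown only as a comment):

from omegaconf import OmegaConf

# Hypothetical model config with an interpolation, standing in for self._cfg.
cfg = OmegaConf.create({
    "sample_rate": 16000,
    "train_ds": {"sample_rate": "${sample_rate}", "batch_size": 32},
})

# resolve=True expands interpolations such as ${sample_rate};
# the result is a plain dict/list structure that pickles cleanly.
hparams = OmegaConf.to_container(cfg, resolve=True)
assert hparams["train_ds"]["sample_rate"] == 16000

# Inside ModelPT.__init__ the resolved container is what gets passed on:
#     self.save_hyperparameters(hparams)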
@@ -231,6 +231,7 @@ def load_from_checkpoint(
Loads ModelPT from checkpoint, with some maintenance of restoration.
For documentation, please refer to LightningModule.load_from_checkpoin() documentation.
"""
+ # TODO (@titu1994): When PTL 0.9+ is supported, add `strict=False` flag to constructor
checkpoint = None
try:
cls.__set_model_restore_state(is_being_restored=True)
6 changes: 3 additions & 3 deletions requirements/requirements.txt
@@ -1,12 +1,12 @@
numpy>=1.18.2
onnx>=1.7.0
- pytorch-lightning>=0.8.5
+ pytorch-lightning==0.8.5
python-dateutil
torch
wget
wrapt
ruamel.yaml
scikit-learn
- omegaconf==2.0.1rc11
- hydra-core==1.0.0rc3
+ omegaconf==2.0.1rc12
+ hydra-core==1.0.0rc4
transformers>=2.11.0
