Pin PTL, bump omegaconf #1049

Merged: 7 commits, Aug 20, 2020
7 changes: 5 additions & 2 deletions nemo/collections/asr/models/classification_models.py
@@ -16,7 +16,7 @@
from typing import Dict, List, Optional, Union

import torch
-from omegaconf import DictConfig, ListConfig
+from omegaconf import DictConfig, ListConfig, OmegaConf
from pytorch_lightning import Trainer

from nemo.collections.asr.data.audio_to_text import AudioLabelDataset
@@ -246,7 +246,10 @@ def change_labels(self, new_labels: List[str]):

# Update config
self._cfg.labels = new_labels
-self._cfg.decoder.params = new_decoder_config

+OmegaConf.set_struct(self._cfg.decoder, False)
Review comment: Use open_dict:

    with open_dict(self._cfg.decoder):
        self._cfg.decoder = new_decoder_config

+self._cfg.decoder = new_decoder_config
+OmegaConf.set_struct(self._cfg.decoder, True)

if 'train_ds' in self._cfg and self._cfg.train_ds is not None:
self._cfg.train_ds.labels = new_labels
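Note: both patterns above do the same thing. Struct mode makes a DictConfig reject keys it does not already contain, so the config has to be opened up before a decoder node with a different shape can be written in. A minimal standalone sketch (not NeMo code, with made-up config values) of the two patterns:

    from omegaconf import OmegaConf, open_dict

    cfg = OmegaConf.create({"decoder": {"params": {"num_classes": 28}}})
    OmegaConf.set_struct(cfg, True)  # struct mode is inherited by child nodes

    new_decoder_config = OmegaConf.create({"cls": "my.pkg.Decoder", "params": {"num_classes": 128}})

    # Pattern used in this PR: toggle struct mode off, assign, toggle it back on.
    OmegaConf.set_struct(cfg, False)
    cfg.decoder = new_decoder_config
    OmegaConf.set_struct(cfg, True)

    # Reviewer's suggestion: open_dict does the same toggling as a context manager
    # and restores the previous struct flag even if the body raises.
    with open_dict(cfg):
        cfg.decoder = new_decoder_config

The context-manager form is harder to get wrong because the flag is restored even on exceptions; the diff shown here keeps the explicit set_struct calls.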
9 changes: 6 additions & 3 deletions nemo/collections/asr/models/ctc_bpe_models.py
@@ -256,17 +256,20 @@ def change_vocabulary(self, new_tokenizer_dir: str, new_tokenizer_type: str):
# Override number of classes if placeholder provided
logging.info(
"\nReplacing old number of classes ({}) with new number of classes - {}".format(
-decoder_config.params['num_classes'], len(vocabulary)
+decoder_config['params']['num_classes'], len(vocabulary)
)
)
-decoder_config.params['num_classes'] = len(vocabulary)
+decoder_config['params']['num_classes'] = len(vocabulary)

del self.decoder
self.decoder = EncDecCTCModelBPE.from_config_dict(decoder_config)
self._wer = WERBPE(tokenizer=self.tokenizer, batch_dim_index=0, use_cer=False, ctc_decode=True)

# Update config
-self._cfg.decoder.params = decoder_config
+OmegaConf.set_struct(self._cfg.decoder, False)
+self._cfg.decoder = decoder_config
+OmegaConf.set_struct(self._cfg.decoder, True)

logging.info(f"Changed tokenizer to {self.decoder.vocabulary} vocabulary.")


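Note: the lookup switch from decoder_config.params[...] to decoder_config['params'][...] reads as a robustness fix: key-style indexing works whether decoder_config is a live DictConfig or has already been converted to a plain dict, while attribute access only exists on DictConfig. A small illustration (made-up values, not NeMo code):

    from omegaconf import OmegaConf

    as_dictconfig = OmegaConf.create({"params": {"num_classes": 28}})
    as_plain_dict = OmegaConf.to_container(as_dictconfig, resolve=True)

    print(as_dictconfig.params["num_classes"])     # attribute access: DictConfig only
    print(as_dictconfig["params"]["num_classes"])  # key access: works on both
    print(as_plain_dict["params"]["num_classes"])  # key access: works on both
    # as_plain_dict.params -> AttributeError: plain dicts have no attribute access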
7 changes: 5 additions & 2 deletions nemo/collections/asr/models/ctc_models.py
@@ -19,7 +19,7 @@
from typing import Dict, List, Optional, Union

import torch
-from omegaconf import DictConfig
+from omegaconf import DictConfig, OmegaConf
from pytorch_lightning import Trainer

from nemo.collections.asr.data.audio_to_text import AudioToCharDataset, TarredAudioToCharDataset
@@ -154,7 +154,10 @@ def change_vocabulary(self, new_vocabulary: List[str]):
self._wer = WER(vocabulary=self.decoder.vocabulary, batch_dim_index=0, use_cer=False, ctc_decode=True)

# Update config
-self._cfg.decoder.params = new_decoder_config
+OmegaConf.set_struct(self._cfg.decoder, False)
+self._cfg.decoder = new_decoder_config
+OmegaConf.set_struct(self._cfg.decoder, True)

logging.info(f"Changed decoder to output to {self.decoder.vocabulary} vocabulary.")

def _setup_dataloader_from_config(self, config: Optional[Dict]):
3 changes: 2 additions & 1 deletion nemo/core/classes/modelPT.py
@@ -78,7 +78,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):

self._cfg = config

-self.save_hyperparameters(self._cfg)
+self.save_hyperparameters(OmegaConf.to_container(self._cfg, resolve=True))
self._train_dl = None
self._validation_dl = None
self._test_dl = None
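Note: save_hyperparameters now receives a plain, resolved container rather than the live DictConfig. A minimal sketch (not NeMo code) of what OmegaConf.to_container(cfg, resolve=True) produces:

    from omegaconf import OmegaConf

    cfg = OmegaConf.create({"sample_rate": 16000, "train_ds": {"sample_rate": "${sample_rate}"}})

    container = OmegaConf.to_container(cfg, resolve=True)
    print(type(container))                       # <class 'dict'>
    print(container["train_ds"]["sample_rate"])  # 16000 -- the interpolation is resolved

A plain dict is something PTL 0.8.5 can serialize into the checkpoint hparams without knowing about OmegaConf, which appears to be the motivation for the change.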
@@ -231,6 +231,7 @@ def load_from_checkpoint(
Loads ModelPT from checkpoint, with some maintenance of restoration.
For documentation, please refer to LightningModule.load_from_checkpoint() documentation.
"""
+# TODO (@titu1994): When PTL 0.9+ is supported, add `strict=False` flag to constructor
checkpoint = None
try:
cls.__set_model_restore_state(is_being_restored=True)
6 changes: 3 additions & 3 deletions requirements/requirements.txt
@@ -1,12 +1,12 @@
numpy>=1.18.2
onnx>=1.7.0
-pytorch-lightning>=0.8.5
+pytorch-lightning==0.8.5
python-dateutil
torch
wget
wrapt
ruamel.yaml
scikit-learn
-omegaconf==2.0.1rc11
-hydra-core==1.0.0rc3
+omegaconf==2.0.1rc12
+hydra-core==1.0.0rc4
transformers>=2.11.0
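Note: pytorch-lightning is now pinned to exactly 0.8.5 (matching the PR title) while omegaconf and hydra-core move to the next release candidates. A throwaway snippet (not part of the PR) to confirm an environment matches the pins after pip install -r requirements/requirements.txt:

    import pytorch_lightning
    import omegaconf
    import hydra

    # Each package exposes a __version__ string; compare against the pins above.
    assert pytorch_lightning.__version__ == "0.8.5", pytorch_lightning.__version__
    assert omegaconf.__version__ == "2.0.1rc12", omegaconf.__version__
    assert hydra.__version__ == "1.0.0rc4", hydra.__version__
    print("pinned versions OK")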