diff --git a/CHANGELOG.md b/CHANGELOG.md index a6580dfbf64e..b0c3a818beb4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -72,6 +72,8 @@ To release a new version, please update the changelog as followed: ### Added - Named tensors tuple module's output for graph construction. ([PR #268](https://github.com/NVIDIA/NeMo/pull/268)) - @stasbel +- Introduced the `deprecated` decorator. +([PR #298](https://github.com/NVIDIA/NeMo/pull/298)) - @tkornuta-nvidia ### Changed - Additional Collections Repositories merged into core `nemo_toolkit` package. ([PR #284](https://github.com/NVIDIA/NeMo/pull/284)) - @stasbel - NeMo is not longer using pep8 code style rules. Code style rules are now enforced with `isort` and `black` incorporated into CI checks. ([PR #286](https://github.com/NVIDIA/NeMo/pull/286)) - @stasbel +- Major cleanup of Neural Module constructors (init), aiming at increasing the framework robustness: cleanup of NeuralModule initialization logic, refactor of trainer/actions (getting rid of local_params), fixes of several examples and unit tests, extraction and storing of initial parameters (init_params). +([PR #309](https://github.com/NVIDIA/NeMo/pull/309)) - @tkornuta-nvidia + ### Dependencies Update +- Added dependency on `wrapt` (required by the new `deprecated` decorator) - @tkornuta-nvidia, @DEKHTIARJonathan ### Deprecated ### Fixed +- Critical fix of the training action on CPU +([PR #308](https://github.com/NVIDIA/NeMo/pull/309)) - @tkornuta-nvidia ### Removed diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index c4fcd2b1397b..3b3c46f5dcae 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -48,6 +48,6 @@ There are several tools to automatically format your code to be PEP 8 compliant, ## Nemo style 1. If you import a module from the same collection, use relative path instead of absolute path. For example, inside ``nemo_nlp``, use ``.utils`` instead of ``nemo_nelp.utils``. -1. Before accessing something, always make sure that it exists. E.g. right now, in ``actions.py``, there's this line of code ``batch_size=dl_nm.local_parameters["batch_size"]`` but nowhere in the codebase we check that ``batch_size`` is passed into datalayer. +1. Before accessing something, always make sure that it exists. 1. Right inheritance. For example, if a module doesn't have any trainable weights, don't inherit from TrainableNM. 1. Naming consistency, both within NeMo and between NeMo and external literature. E.g. use the name ``logits`` for ``log_probs``, ``hidden_size`` for ``d_model``.
diff --git a/docs/docs_zh/sources/source/nlp/ner.rst b/docs/docs_zh/sources/source/nlp/ner.rst index 90cd306f62a2..ea9af287efc0 100644 --- a/docs/docs_zh/sources/source/nlp/ner.rst +++ b/docs/docs_zh/sources/source/nlp/ner.rst @@ -90,7 +90,7 @@ text.txt 每一行包含文本序列,其中词以空格来进行分隔。label label_ids = train_data_layer.dataset.label_ids num_classes = len(label_ids) - hidden_size = bert_model.local_parameters["hidden_size"] + hidden_size = bert_model.hidden_size ner_classifier = nemo_nlp.TokenClassifier(hidden_size=hidden_size, num_classes=num_classes, dropout=CLASSIFICATION_DROPOUT) @@ -217,8 +217,8 @@ text.txt 每一行包含文本序列,其中词以空格来进行分隔。label tokenizer = NemoBertTokenizer(pretrained_model="scibert_scivocab_cased") bert_model = nemo_nlp.huggingface.BERT( - pretrained_model_name="scibert_scivocab_cased", - factory=neural_factory) + pretrained_model_name="scibert_scivocab_cased" + ) 如果你想使用 TensorFlow 训练好的模型,例如 BioBERT ,你需要首先使用 Hugging Face 提供的 `model conversion script`_ 进行模型转换,再在 NeMo 中使用这个模型。 diff --git a/docs/docs_zh/sources/source/tutorials/custommodules.rst b/docs/docs_zh/sources/source/tutorials/custommodules.rst index 361c4bd7b95e..31b39181f8f0 100644 --- a/docs/docs_zh/sources/source/tutorials/custommodules.rst +++ b/docs/docs_zh/sources/source/tutorials/custommodules.rst @@ -48,8 +48,8 @@ .. code-block:: python - def __init__(self, *, module_params, ..., **kwargs) - super().__init__(**kwargs) + def __init__(self, module_params, ...) + super().__init__() (4) 实现 ``torch.nn.Module`` 模块里的 ``forward`` 方法 @@ -76,11 +76,11 @@ 0: AxisType(BatchTag), 1: AxisType(ChannelTag)})} - def __init__(self, **kwargs): + def __init__(self, dim): # (3) 调用基类构造函数 - TrainableNM.__init__(self, **kwargs) + super().__init__() # Neural Modules 的特定部分,剩下的是 PyTorch 代码 - self._dim = self.local_parameters["dim"] + self._dim = dim self.fc1 = nn.Linear(self._dim, 1) t.nn.init.xavier_uniform_(self.fc1.weight) self._device = t.device( @@ -115,8 +115,8 @@ def output_ports(self): return {...} - def __init__(self, *, module_params, .., **kwargs) - TrainableNM.__init__(self, **kwargs) + def __init__(self, module_params, ...) + super().__init__() (4) 修改 ``forward`` 方法,使得它的输入参数和你的输入端口名字匹配。 @@ -162,11 +162,11 @@ "label": NeuralType({0: AxisType(BatchTag)}), } - def __init__(self, **kwargs): - DataLayerNM.__init__(self, **kwargs) + def __init__(self, input_size, path): + super().__init__() - self._input_size = kwargs["input_size"] - self._path = kwargs["path"] + self._input_size = input_size + self._path = path self._transforms = transforms.Compose([ transforms.RandomResizedCrop(self._input_size), @@ -216,9 +216,9 @@ Example def output_ports(self): return {"loss": NeuralType(None)} - def __init__(self, **kwargs): + def __init__(self): # 神经模块 API - super().__init__(**kwargs) + super().__init__() # 结束神经模块 API self._criterion = torch.nn.CrossEntropyLoss() @@ -226,5 +226,3 @@ Example # 你需要实现这个方法 def _loss_function(self, **kwargs): return self._criterion(*(kwargs.values())) - - diff --git a/docs/sources/source/nlp/joint_intent_slot_filling.rst b/docs/sources/source/nlp/joint_intent_slot_filling.rst index 0b4f1284d912..830a110eff41 100644 --- a/docs/sources/source/nlp/joint_intent_slot_filling.rst +++ b/docs/sources/source/nlp/joint_intent_slot_filling.rst @@ -59,7 +59,7 @@ This will tokenize text following the mapping of the original BERT model. .. 
code-block:: python from transformers import BertTokenizer - hidden_size = pretrained_bert_model.local_parameters["hidden_size"] + hidden_size = pretrained_bert_model.hidden_size tokenizer = BertTokenizer.from_pretrained(args.pretrained_bert_model) Next, we define all Neural Modules participating in our joint intent slot filling classification pipeline. @@ -79,7 +79,8 @@ Next, we define all Neural Modules participating in our joint intent slot fillin .. code-block:: python pretrained_bert_model = nemo_nlp.huggingface.BERT( - pretrained_model_name=args.pretrained_bert_model, factory=nf) + pretrained_model_name=args.pretrained_bert_model + ) hidden_states = pretrained_bert_model(input_ids=ids, token_type_ids=type_ids, attention_mask=input_mask) @@ -256,4 +257,4 @@ References .. bibliography:: nlp_all.bib :style: plain :labelprefix: NLP-SLOT - :keyprefix: nlp-slot- \ No newline at end of file + :keyprefix: nlp-slot- diff --git a/docs/sources/source/nlp/ner.rst b/docs/sources/source/nlp/ner.rst index 497f53c82744..c139a09fa912 100644 --- a/docs/sources/source/nlp/ner.rst +++ b/docs/sources/source/nlp/ner.rst @@ -94,7 +94,7 @@ We need to create the classifier to sit on top of the pretrained model and defin .. code-block:: python - hidden_size = bert_model.local_parameters["hidden_size"] + hidden_size = bert_model.hidden_size ner_classifier = nemo_nlp.TokenClassifier(hidden_size=hidden_size, num_classes=num_classes, dropout=CLASSIFICATION_DROPOUT) diff --git a/docs/sources/source/nlp/punctuation.rst b/docs/sources/source/nlp/punctuation.rst index 6ded5c6e06d2..e958f58935d3 100644 --- a/docs/sources/source/nlp/punctuation.rst +++ b/docs/sources/source/nlp/punctuation.rst @@ -116,7 +116,7 @@ Now, create the train and evaluation data layers: punct_label_ids = train_data_layer.dataset.punct_label_ids capit_label_ids = train_data_layer.dataset.capit_label_ids - hidden_size = bert_model.local_parameters["hidden_size"] + hidden_size = bert_model.hidden_size # Note that you need to specify punct_label_ids and capit_label_ids - mapping form labels # to label_ids generated during creation of the train_data_layer to make sure that diff --git a/docs/sources/source/nlp/question_answering.rst b/docs/sources/source/nlp/question_answering.rst index 08264bf0020e..266e74799a7a 100644 --- a/docs/sources/source/nlp/question_answering.rst +++ b/docs/sources/source/nlp/question_answering.rst @@ -61,7 +61,7 @@ This will tokenize text following the mapping of the original BERT model. .. code-block:: python from nemo.collections.nlp import NemoBertTokenizer - hidden_size = pretrained_bert_model.local_parameters["hidden_size"] + hidden_size = pretrained_bert_model.hidden_size tokenizer = NemoBertTokenizer(args.pretrained_bert_model) Next, we define all Neural Modules participating in our question answering classification pipeline. diff --git a/docs/sources/source/tutorials/custommodules.rst b/docs/sources/source/tutorials/custommodules.rst index 28eac88ca127..9bf308fdbef3 100644 --- a/docs/sources/source/tutorials/custommodules.rst +++ b/docs/sources/source/tutorials/custommodules.rst @@ -48,8 +48,8 @@ Defining a module from scratch .. code-block:: python - def __init__(self, *, module_params, ..., **kwargs) - super().__init__(**kwargs) + def __init__(self, module_params, ...) 
+ super().__init__() (4) Implement ``forward`` method from ``torch.nn.Module`` @@ -76,11 +76,11 @@ Example 1 0: AxisType(BatchTag), 1: AxisType(ChannelTag)})} - def __init__(self, **kwargs): + def __init__(self, dim): # (3) Call base constructor - TrainableNM.__init__(self, **kwargs) + super().__init__() # And of Neural Modules specific part. Rest is PyTorch code - self._dim = self.local_parameters["dim"] + self._dim = dim self.fc1 = nn.Linear(self._dim, 1) t.nn.init.xavier_uniform_(self.fc1.weight) self._device = t.device( @@ -116,8 +116,8 @@ Converting from PyTorch's nn.Module def output_ports(self): return {...} - def __init__(self, *, module_params, .., **kwargs) - TrainableNM.__init__(self, **kwargs) + def __init__(self, module_params, ...) + super().__init__() (4) Modify ``forward`` method so that its input arguments match your input port names exactly. @@ -167,11 +167,11 @@ This example wraps PyTorch's *ImageFolder* dataset into a neural module data lay "label": NeuralType({0: AxisType(BatchTag)}), } - def __init__(self, **kwargs): - DataLayerNM.__init__(self, **kwargs) + def __init__(self, input_size, path): + super().__init__() - self._input_size = kwargs["input_size"] - self._path = kwargs["path"] + self._input_size = input_size + self._path = path self._transforms = transforms.Compose([ transforms.RandomResizedCrop(self._input_size), @@ -223,9 +223,9 @@ Example def output_ports(self): return {"loss": NeuralType(None)} - def __init__(self, **kwargs): + def __init__(self): # Neural Module API specific - super().__init__(**kwargs) + super().__init__() # End of Neural Module API specific self._criterion = torch.nn.CrossEntropyLoss() @@ -233,5 +233,3 @@ Example # You need to implement this function def _loss_function(self, **kwargs): return self._criterion(*(kwargs.values())) - - diff --git a/examples/applications/asr_service/app/__init__.py b/examples/applications/asr_service/app/__init__.py index abcd6abf1015..a31e50d7ef94 100644 --- a/examples/applications/asr_service/app/__init__.py +++ b/examples/applications/asr_service/app/__init__.py @@ -32,7 +32,7 @@ # Instantiate necessary Neural Modules # Note that data layer is missing from here neural_factory = nemo.core.NeuralModuleFactory(placement=nemo.core.DeviceType.GPU, backend=nemo.core.Backend.PyTorch) -data_preprocessor = nemo_asr.AudioToMelSpectrogramPreprocessor(factory=neural_factory) +data_preprocessor = nemo_asr.AudioToMelSpectrogramPreprocessor() jasper_encoder = nemo_asr.JasperEncoder( jasper=jasper_model_definition['JasperEncoder']['jasper'], activation=jasper_model_definition['JasperEncoder']['activation'], diff --git a/examples/asr/jasper.py b/examples/asr/jasper.py index c42bf9321a1f..bb00ffa304ef 100644 --- a/examples/asr/jasper.py +++ b/examples/asr/jasper.py @@ -136,9 +136,7 @@ def create_all_dags(args, neural_factory): ) jasper_decoder = nemo_asr.JasperDecoderForCTC( - feat_in=jasper_params["JasperEncoder"]["jasper"][-1]["filters"], - num_classes=len(vocab), - factory=neural_factory, + feat_in=jasper_params["JasperEncoder"]["jasper"][-1]["filters"], num_classes=len(vocab) ) ctc_loss = nemo_asr.CTCLossNM(num_classes=len(vocab)) diff --git a/examples/asr/jasper_aishell.py b/examples/asr/jasper_aishell.py index a6115d6c8f77..67bafeafdf00 100644 --- a/examples/asr/jasper_aishell.py +++ b/examples/asr/jasper_aishell.py @@ -137,9 +137,7 @@ def create_all_dags(args, neural_factory): ) jasper_decoder = nemo_asr.JasperDecoderForCTC( - feat_in=jasper_params["JasperEncoder"]["jasper"][-1]["filters"], - 
num_classes=len(vocab), - factory=neural_factory, + feat_in=jasper_params["JasperEncoder"]["jasper"][-1]["filters"], num_classes=len(vocab) ) ctc_loss = nemo_asr.CTCLossNM(num_classes=len(vocab)) diff --git a/examples/asr/notebooks/2_Online_ASR_Microphone_Demo.ipynb b/examples/asr/notebooks/2_Online_ASR_Microphone_Demo.ipynb index 38486280ec13..4a842b3a4365 100644 --- a/examples/asr/notebooks/2_Online_ASR_Microphone_Demo.ipynb +++ b/examples/asr/notebooks/2_Online_ASR_Microphone_Demo.ipynb @@ -132,8 +132,8 @@ " \"a_sig_length\": NeuralType({0: AxisType(BatchTag)}),\n", " }\n", "\n", - " def __init__(self, **kwargs):\n", - " DataLayerNM.__init__(self, **kwargs)\n", + " def __init__(self):\n", + " super().__init__()\n", " self.output = True\n", " \n", " def __iter__(self):\n", diff --git a/examples/image/gan.py b/examples/image/gan.py index 6a01e822830e..61aee8c252f7 100644 --- a/examples/image/gan.py +++ b/examples/image/gan.py @@ -18,7 +18,7 @@ parser.add_argument( "--train_dataset", # set default=os.getcwd() unless your are running test - default="/home/mrjenkins/TestData", + default="~/TestData/mnist", type=str, ) parser.add_argument("--amp_opt_level", choices=['O0', 'O1', 'O2', 'O3'], default='O0') @@ -44,7 +44,7 @@ batch_size=batch_size, shuffle=True, train=True, root=args.train_dataset ) -generator = nemo_simple_gan.SimpleGenerator(batch_size=batch_size) +generator = nemo_simple_gan.SimpleGenerator() discriminator = nemo_simple_gan.SimpleDiscriminator() neg_disc_loss = nemo_simple_gan.DiscriminatorLoss(neg=True) disc_loss = nemo_simple_gan.DiscriminatorLoss() diff --git a/examples/nlp/BERTPretrainingTutorial.ipynb b/examples/nlp/BERTPretrainingTutorial.ipynb index a9d82a21b5ee..6c62a495db50 100644 --- a/examples/nlp/BERTPretrainingTutorial.ipynb +++ b/examples/nlp/BERTPretrainingTutorial.ipynb @@ -133,8 +133,8 @@ " num_attention_heads=NUM_HEADS,\n", " intermediate_size=D_INNER,\n", " max_position_embeddings=MAX_SEQ_LENGTH,\n", - " hidden_act=HIDDEN_ACT,\n", - " factory=neural_factory)" + " hidden_act=HIDDEN_ACT\n", + ")" ] }, { @@ -167,22 +167,21 @@ "metadata": {}, "outputs": [], "source": [ - "import os\n", "train_data_layer = nemo_nlp.BertPretrainingDataLayer(\n", " tokenizer=tokenizer,\n", " dataset=os.path.join(\"data/lm/wikitext-2\", \"train.txt\"),\n", " max_seq_length=MAX_SEQ_LENGTH,\n", " mask_probability=MASK_PROBABILITY,\n", - " batch_size=BATCH_SIZE,\n", - " factory=neural_factory)\n", + " batch_size=BATCH_SIZE\n", + ")\n", "\n", "eval_data_layer = nemo_nlp.BertPretrainingDataLayer(\n", " tokenizer=tokenizer,\n", " dataset=os.path.join(\"data/lm/wikitext-2\", \"valid.txt\"),\n", " max_seq_length=MAX_SEQ_LENGTH,\n", " mask_probability=MASK_PROBABILITY,\n", - " batch_size=BATCH_SIZE_EVAL,\n", - " factory=neural_factory)" + " batch_size=BATCH_SIZE_EVAL\n", + ")" ] }, { @@ -301,7 +300,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.5" + "version": "3.7.4" } }, "nbformat": 4, diff --git a/examples/nlp/NERWithBERT.ipynb b/examples/nlp/NERWithBERT.ipynb index 83e0f9811fbf..19cf18f8389b 100644 --- a/examples/nlp/NERWithBERT.ipynb +++ b/examples/nlp/NERWithBERT.ipynb @@ -99,8 +99,7 @@ "label_ids = train_data_layer.dataset.label_ids\n", "num_classes = len(label_ids)\n", "\n", - "hidden_size = bert_model.local_parameters[\"hidden_size\"]\n", - "ner_classifier = nemo_nlp.TokenClassifier(hidden_size=hidden_size,\n", + "ner_classifier = nemo_nlp.TokenClassifier(hidden_size=bert_model.hidden_size,\n", " num_classes=num_classes,\n", " 
dropout=CLASSIFICATION_DROPOUT)\n", "\n", @@ -204,9 +203,9 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3.7.4 64-bit", "language": "python", - "name": "python3" + "name": "python37464bitc56e562f54084a24b5afed5459c99218" }, "language_info": { "codemirror_mode": { diff --git a/examples/nlp/PunctuationWithBERT.ipynb b/examples/nlp/PunctuationWithBERT.ipynb index 4d92cdba7bac..58d0f57f8edb 100644 --- a/examples/nlp/PunctuationWithBERT.ipynb +++ b/examples/nlp/PunctuationWithBERT.ipynb @@ -142,19 +142,17 @@ "punct_label_ids = train_data_layer.dataset.punct_label_ids\n", "capit_label_ids = train_data_layer.dataset.capit_label_ids\n", "\n", - "hidden_size = bert_model.local_parameters[\"hidden_size\"]\n", - "\n", "\n", "# Define classifier for Punctuation and Capitalization tasks\n", "punct_classifier = nemo_nlp.TokenClassifier(\n", - " hidden_size=hidden_size,\n", + " hidden_size=bert_model.hidden_size,\n", " num_classes=len(punct_label_ids),\n", " dropout=CLASSIFICATION_DROPOUT,\n", " num_layers=PUNCT_NUM_FC_LAYERS,\n", " name='Punctuation')\n", "\n", "capit_classifier = nemo_nlp.TokenClassifier(\n", - " hidden_size=hidden_size,\n", + " hidden_size=bert_model.hidden_size,\n", " num_classes=len(capit_label_ids),\n", " dropout=CLASSIFICATION_DROPOUT,\n", " name='Capitalization')\n", diff --git a/examples/nlp/asr_postprocessor.py b/examples/nlp/asr_postprocessor.py index e29969fe95f3..f65de6e8becc 100644 --- a/examples/nlp/asr_postprocessor.py +++ b/examples/nlp/asr_postprocessor.py @@ -66,7 +66,7 @@ tokens_to_add = vocab_size - tokenizer.vocab_size zeros_transform = nemo.backends.pytorch.common.ZerosLikeNM() -encoder = nemo_nlp.huggingface.BERT(pretrained_model_name=args.pretrained_model, local_rank=args.local_rank) +encoder = nemo_nlp.huggingface.BERT(pretrained_model_name=args.pretrained_model) device = encoder.bert.embeddings.word_embeddings.weight.get_device() zeros = torch.zeros((tokens_to_add, args.d_model)).to(device=device) encoder.bert.embeddings.word_embeddings.weight.data = torch.cat( @@ -92,7 +92,7 @@ t_log_softmax = nemo_nlp.TokenClassifier(args.d_model, num_classes=vocab_size, num_layers=1, log_softmax=True) -loss_fn = nemo_nlp.PaddedSmoothedCrossEntropyLossNM(pad_id=tokenizer.pad_id(), smoothing=0.1) +loss_fn = nemo_nlp.PaddedSmoothedCrossEntropyLossNM(pad_id=tokenizer.pad_id(), label_smoothing=0.1) beam_search = nemo_nlp.BeamSearchTranslatorNM( decoder=decoder, diff --git a/examples/nlp/glue_with_BERT.py b/examples/nlp/glue_with_BERT.py index a30bec8d0021..d7dcc8bc87b7 100644 --- a/examples/nlp/glue_with_BERT.py +++ b/examples/nlp/glue_with_BERT.py @@ -240,7 +240,7 @@ model.restore_from(args.bert_checkpoint) -hidden_size = model.local_parameters["hidden_size"] +hidden_size = model.hidden_size # uses [CLS] token for classification (the first token) if args.task_name == 'sts-b': @@ -268,8 +268,8 @@ def create_pipeline( processor=processor, evaluate=evaluate, batch_size=batch_size, - num_workers=0, - local_rank=local_rank, + # num_workers=0, + # local_rank=local_rank, tokenizer=tokenizer, data_dir=args.data_dir, max_seq_length=max_seq_length, diff --git a/examples/nlp/joint_intent_slot_infer.py b/examples/nlp/joint_intent_slot_infer.py index 6a005ab15759..d2f3efaf8c68 100644 --- a/examples/nlp/joint_intent_slot_infer.py +++ b/examples/nlp/joint_intent_slot_infer.py @@ -36,7 +36,7 @@ nemo_nlp.huggingface.BERT.list_pretrained_models() """ pretrained_bert_model = nemo_nlp.huggingface.BERT(pretrained_model_name=args.pretrained_bert_model) 
-hidden_size = pretrained_bert_model.local_parameters["hidden_size"] +hidden_size = pretrained_bert_model.hidden_size tokenizer = BertTokenizer.from_pretrained(args.pretrained_bert_model) data_desc = JointIntentSlotDataDesc(args.data_dir, args.do_lower_case, args.dataset_name) @@ -51,8 +51,8 @@ max_seq_length=args.max_seq_length, shuffle=False, batch_size=args.batch_size, - num_workers=0, - local_rank=args.local_rank, + # num_workers=0, + # local_rank=args.local_rank, ) classifier = nemo_nlp.JointIntentSlotClassifier( diff --git a/examples/nlp/joint_intent_slot_infer_b1.py b/examples/nlp/joint_intent_slot_infer_b1.py index 69e69c2f47ca..089a2c06820e 100644 --- a/examples/nlp/joint_intent_slot_infer_b1.py +++ b/examples/nlp/joint_intent_slot_infer_b1.py @@ -30,9 +30,9 @@ See the list of pretrained models, call: nemo_nlp.huggingface.BERT.list_pretrained_models() """ -pretrained_bert_model = nemo_nlp.huggingface.BERT(pretrained_model_name=args.pretrained_bert_model, factory=nf) +pretrained_bert_model = nemo_nlp.huggingface.BERT(pretrained_model_name=args.pretrained_bert_model) tokenizer = BertTokenizer.from_pretrained(args.pretrained_bert_model) -hidden_size = pretrained_bert_model.local_parameters["hidden_size"] +hidden_size = pretrained_bert_model.hidden_size data_desc = JointIntentSlotDataDesc(args.data_dir, args.do_lower_case, args.dataset_name) diff --git a/examples/nlp/joint_intent_slot_with_bert.py b/examples/nlp/joint_intent_slot_with_bert.py index 665f1701b62c..8e0d5874f226 100644 --- a/examples/nlp/joint_intent_slot_with_bert.py +++ b/examples/nlp/joint_intent_slot_with_bert.py @@ -76,7 +76,7 @@ else: pretrained_bert_model = nemo_nlp.huggingface.BERT(pretrained_model_name=args.pretrained_bert_model, factory=nf) -hidden_size = pretrained_bert_model.local_parameters["hidden_size"] +hidden_size = pretrained_bert_model.hidden_size data_desc = JointIntentSlotDataDesc( args.data_dir, args.do_lower_case, args.dataset_name, args.none_slot_label, args.pad_label, diff --git a/examples/nlp/punctuation_capitalization.py b/examples/nlp/punctuation_capitalization.py index 0ca47dde7acc..cf2a2d20cda6 100644 --- a/examples/nlp/punctuation_capitalization.py +++ b/examples/nlp/punctuation_capitalization.py @@ -140,7 +140,7 @@ model.restore_from(args.bert_checkpoint) nemo.logging.info(f"Model restored from {args.bert_checkpoint}") -hidden_size = model.local_parameters["hidden_size"] +hidden_size = model.hidden_size punct_classifier = "TokenClassifier" punct_loss = "TokenClassificationLoss" diff --git a/examples/nlp/punctuation_capitalization_infer.py b/examples/nlp/punctuation_capitalization_infer.py index 25d08e67ad7d..2456e64408f2 100644 --- a/examples/nlp/punctuation_capitalization_infer.py +++ b/examples/nlp/punctuation_capitalization_infer.py @@ -78,7 +78,7 @@ nemo_nlp.huggingface.BERT.list_pretrained_models() """ pretrained_bert_model = nemo_nlp.huggingface.BERT(pretrained_model_name=args.pretrained_bert_model) -hidden_size = pretrained_bert_model.local_parameters["hidden_size"] +hidden_size = pretrained_bert_model.hidden_size tokenizer = NemoBertTokenizer(args.pretrained_bert_model) data_layer = nemo_nlp.BertTokenClassificationInferDataLayer( diff --git a/examples/nlp/sentence_classification_with_bert.py b/examples/nlp/sentence_classification_with_bert.py index 62bce6491ee5..2cd622e65ac3 100644 --- a/examples/nlp/sentence_classification_with_bert.py +++ b/examples/nlp/sentence_classification_with_bert.py @@ -68,12 +68,12 @@ """ if args.bert_checkpoint and args.bert_config: - 
pretrained_bert_model = nemo_nlp.huggingface.BERT(config_filename=args.bert_config, factory=nf) + pretrained_bert_model = nemo_nlp.huggingface.BERT(config_filename=args.bert_config) pretrained_bert_model.restore_from(args.bert_checkpoint) else: - pretrained_bert_model = nemo_nlp.huggingface.BERT(pretrained_model_name=args.pretrained_bert_model, factory=nf) + pretrained_bert_model = nemo_nlp.huggingface.BERT(pretrained_model_name=args.pretrained_bert_model) -hidden_size = pretrained_bert_model.local_parameters["hidden_size"] +hidden_size = pretrained_bert_model.hidden_size tokenizer = BertTokenizer.from_pretrained(args.pretrained_bert_model) data_desc = SentenceClassificationDataDesc(args.dataset_name, args.data_dir, args.do_lower_case) @@ -102,8 +102,8 @@ def create_pipeline(num_samples=-1, batch_size=32, num_gpus=1, local_rank=0, mod num_samples=num_samples, shuffle=shuffle, batch_size=batch_size, - num_workers=0, - local_rank=local_rank, + # num_workers=0, + # local_rank=local_rank, ) ids, type_ids, input_mask, labels = data_layer() diff --git a/examples/nlp/squad.py b/examples/nlp/squad.py index da8e4dd7f9d5..627b8bd00300 100755 --- a/examples/nlp/squad.py +++ b/examples/nlp/squad.py @@ -311,7 +311,7 @@ def create_pipeline( """ model = nemo_nlp.huggingface.BERT(pretrained_model_name=args.pretrained_bert_model) - hidden_size = model.local_parameters["hidden_size"] + hidden_size = model.hidden_size qa_head = nemo_nlp.TokenClassifier(hidden_size=hidden_size, num_classes=2, num_layers=1, log_softmax=False) squad_loss = nemo_nlp.QuestionAnsweringLoss() diff --git a/examples/nlp/token_classification.py b/examples/nlp/token_classification.py index a6c782d1f214..43749c299e05 100644 --- a/examples/nlp/token_classification.py +++ b/examples/nlp/token_classification.py @@ -137,7 +137,7 @@ model.restore_from(args.bert_checkpoint) nemo.logging.info(f"Model restored from {args.bert_checkpoint}") -hidden_size = model.local_parameters["hidden_size"] +hidden_size = model.hidden_size classifier = "TokenClassifier" task_loss = "TokenClassificationLoss" diff --git a/examples/nlp/token_classification_infer.py b/examples/nlp/token_classification_infer.py index 4205909f41cc..ae272f86d210 100644 --- a/examples/nlp/token_classification_infer.py +++ b/examples/nlp/token_classification_infer.py @@ -56,7 +56,7 @@ nemo_nlp.huggingface.BERT.list_pretrained_models() """ pretrained_bert_model = nemo_nlp.huggingface.BERT(pretrained_model_name=args.pretrained_bert_model) -hidden_size = pretrained_bert_model.local_parameters["hidden_size"] +hidden_size = pretrained_bert_model.hidden_size tokenizer = NemoBertTokenizer(args.pretrained_bert_model) data_layer = nemo_nlp.BertTokenClassificationInferDataLayer( diff --git a/examples/start_here/chatbot_example.py b/examples/start_here/chatbot_example.py index 47d240318613..c5107411525d 100644 --- a/examples/start_here/chatbot_example.py +++ b/examples/start_here/chatbot_example.py @@ -6,59 +6,46 @@ logging = nemo.logging -# Get Data data_file = "movie_data.txt" + +# Download the data file. 
if not os.path.isfile(data_file): with gzip.open("../../tests/data/movie_lines.txt.gz", 'rb') as f_in: with open(data_file, 'wb') as f_out: shutil.copyfileobj(f_in, f_out) -# Configuration -config = { - "corpus_name": "cornell", - "datafile": data_file, - "attn_model": 'dot', - "hidden_size": 512, - "encoder_n_layers": 2, - "decoder_n_layers": 2, - "dropout": 0.1, - "voc_size": 6104 + 3, - "batch_size": 128, - # "num_epochs": 15, - # 3 is too small - used for test - "num_epochs": 3, - "optimizer_kind": "adam", - "learning_rate": 0.0003, - "tb_log_dir": "ChatBot", -} - -# instantiate neural factory + +# Instantiate the neural factory nf = nemo.core.NeuralModuleFactory() # To use CPU-only do: -# from nemo.core import DeviceType -# nf = nemo.core.NeuralModuleFactory(placement=DeviceType.CPU) +# nf = nemo.core.NeuralModuleFactory(placement=nemo.core.DeviceType.CPU) -# instantiate neural modules -dl = nemo.tutorials.DialogDataLayer(**config) -encoder = nemo.tutorials.EncoderRNN(**config) -decoder = nemo.tutorials.LuongAttnDecoderRNN(**config) +# Instantiate all required neural modules. +dl = nemo.tutorials.DialogDataLayer(batch_size=128, corpus_name="cornell", datafile=data_file) +encoder = nemo.tutorials.EncoderRNN(voc_size=(6104 + 3), encoder_n_layers=2, hidden_size=512, dropout=0.1) +decoder = nemo.tutorials.LuongAttnDecoderRNN( + attn_model="dot", hidden_size=512, voc_size=(6104 + 3), decoder_n_layers=2, dropout=0.1 +) L = nemo.tutorials.MaskedXEntropyLoss() -decoderInfer = nemo.tutorials.GreedyLuongAttnDecoderRNN(**config) -# PARAMETER SHARING: between training and auto-regressive inference decoders +decoderInfer = nemo.tutorials.GreedyLuongAttnDecoderRNN( + attn_model="dot", hidden_size=512, voc_size=(6104 + 3), decoder_n_layers=2, dropout=0.1, max_dec_steps=10 +) + +# PARAMETER SHARING: between training and auto-regressive inference decoders. decoderInfer.tie_weights_with(decoder, list(decoder.get_weights().keys())) -# express activations flow +# Connect the modules - express activations flow for training. src, src_lengths, tgt, mask, max_tgt_length = dl() encoder_outputs, encoder_hidden = encoder(input_seq=src, input_lengths=src_lengths) outputs, hidden = decoder(targets=tgt, encoder_outputs=encoder_outputs, max_target_len=max_tgt_length) loss = L(predictions=outputs, target=tgt, mask=mask) -# run inference decoder to generate predictions +# Run inference decoder to generate predictions. outputs_inf, _ = decoderInfer(encoder_outputs=encoder_outputs) -# define callback function which prints intermediate results to console +# Define the callback function which prints intermediate results to console. def outputs2words(tensors, vocab): source_ids = tensors[1][:, 0].cpu().numpy().tolist() response_ids = tensors[2][:, 0].cpu().numpy().tolist() @@ -73,14 +60,15 @@ def outputs2words(tensors, vocab): logging.info(f"SOURCE: {source} <---> PREDICTED RESPONSE: {response} " f"<---> TARGET: {target}") +# Create simple callback. 
callback = nemo.core.SimpleLossLoggerCallback( tensors=[loss, src, outputs_inf, tgt], print_func=lambda x: outputs2words(x, dl.voc.index2word), ) -# start training +# Start training nf.train( tensors_to_optimize=[loss], callbacks=[callback], optimizer="adam", - optimization_params={"num_epochs": config["num_epochs"], "lr": 0.001}, + optimization_params={"num_epochs": 3, "lr": 0.001}, ) diff --git a/examples/start_here/simplest_example.py b/examples/start_here/simplest_example.py index d9ad47198087..0bf3fb795dac 100644 --- a/examples/start_here/simplest_example.py +++ b/examples/start_here/simplest_example.py @@ -4,18 +4,16 @@ logging = nemo.logging nf = nemo.core.NeuralModuleFactory() - # To use CPU-only do: -# from nemo.core import DeviceType -# nf = nemo.core.NeuralModuleFactory(placement=DeviceType.CPU) +# nf = nemo.core.NeuralModuleFactory(placement=nemo.core.DeviceType.CPU) -# instantiate necessary neural modules -# RealFunctionDataLayer defaults to f=torch.sin, sampling from x=[-4, 4] +# Instantiate the necessary neural modules. +# RealFunctionDataLayer defaults to f_name="sin", sampling from x=[-4, 4] dl = nemo.tutorials.RealFunctionDataLayer(n=10000, batch_size=128) fx = nemo.tutorials.TaylorNet(dim=4) loss = nemo.tutorials.MSELoss() -# describe activation's flow +# Describe the activation flow. x, y = dl() p = fx(x=x) lss = loss(predictions=p, target=y) @@ -25,7 +23,5 @@ tensors=[lss], print_func=lambda x: logging.info(f'Train Loss: {str(x[0].item())}'), ) -# Invoke "train" action -nf.train( - [lss], callbacks=[callback], optimization_params={"num_epochs": 3, "lr": 0.0003}, optimizer="sgd", -) +# Invoke "train" action. +nf.train([lss], callbacks=[callback], optimization_params={"num_epochs": 3, "lr": 0.0003}, optimizer="sgd") diff --git a/examples/start_here/simplest_example_configuration_import.py b/examples/start_here/simplest_example_configuration_import.py new file mode 100644 index 000000000000..2233afbd58a1 --- /dev/null +++ b/examples/start_here/simplest_example_configuration_import.py @@ -0,0 +1,48 @@ +# TODO: actually fill this +# ! /usr/bin/python +# -*- coding: utf-8 -*- + +# Copyright 2019 NVIDIA. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +import nemo +from nemo.core import DeviceType + +# Run on CPU. +nf = nemo.core.NeuralModuleFactory(placement=DeviceType.CPU) + + +# instantiate necessary neural modules +# RealFunctionDataLayer defaults to f=torch.sin, sampling from x=[-4, 4] +# dl = nemo.tutorials.RealFunctionDataLayer(n=10000, f_name="cos", x=[-4, 4], batch_size=128) +dl = nemo.tutorials.RealFunctionDataLayer(n=100, f_name="cos", x_lo=-1, x_hi=1, batch_size=128) + + +fx = nemo.tutorials.TaylorNet(dim=4) +loss = nemo.tutorials.MSELoss() + +# describe activation's flow +x, y = dl() +p = fx(x=x) +lss = loss(predictions=p, target=y) + +# SimpleLossLoggerCallback will print loss values to console. 
+callback = nemo.core.SimpleLossLoggerCallback( + tensors=[lss], print_func=lambda x: nemo.logging.info(f'Train Loss: {str(x[0].item())}') +) + + +# Invoke "train" action +nf.train([lss], callbacks=[callback], optimization_params={"num_epochs": 3, "lr": 0.0003}, optimizer="sgd") diff --git a/examples/tts/configs/tacotron2.yaml b/examples/tts/configs/tacotron2.yaml index 8406dc97eb7c..a405d4b85b90 100644 --- a/examples/tts/configs/tacotron2.yaml +++ b/examples/tts/configs/tacotron2.yaml @@ -31,7 +31,7 @@ AudioToMelSpectrogramPreprocessor: n_fft: *n_fft frame_splicing: 1 dither: 0. - feat_type: "logfbank" + #feat_type: "logfbank" stft_conv: true sample_rate: *sr highfreq: *fmax diff --git a/examples/tts/tts_infer.py b/examples/tts/tts_infer.py index 94545eab5776..81e828d680e8 100644 --- a/examples/tts/tts_infer.py +++ b/examples/tts/tts_infer.py @@ -137,7 +137,7 @@ def create_infer_dags( labels=tacotron2_params['labels'], batch_size=infer_batch_size, num_workers=cpu_per_dl, - load_audio=False, + # load_audio=False, bos_id=len(tacotron2_params['labels']), eos_id=len(tacotron2_params['labels']) + 1, pad_id=len(tacotron2_params['labels']) + 2, diff --git a/nemo/backends/pytorch/actions.py b/nemo/backends/pytorch/actions.py index d1ce14bc3d47..f7061318305c 100644 --- a/nemo/backends/pytorch/actions.py +++ b/nemo/backends/pytorch/actions.py @@ -1,4 +1,5 @@ # Copyright (c) 2019 NVIDIA Corporation +import copy import importlib import itertools import json @@ -492,8 +493,8 @@ def _eval(self, tensors_2_evaluate, callback, step, verbose=False): """ with torch.no_grad(): # each call chain corresponds to a tensor in tensors_2_evaluate - dl_nm = None call_chain, _ = self.__get_top_sorted_modules_and_dataloader(hook=tensors_2_evaluate) + # "Retrieve" data layer from call chain. 
dl_nm = call_chain[0][0] # Prepare eval_dataloader @@ -511,13 +512,15 @@ def _eval(self, tensors_2_evaluate, callback, step, verbose=False): # ) # ) if dl_nm.dataset is not None: - sampler = torch.utils.data.distributed.DistributedSampler(dl_nm.dataset) + sampler = torch.utils.data.distributed.DistributedSampler( + dataset=dl_nm.dataset, shuffle=dl_nm.shuffle + ) eval_dataloader = torch.utils.data.DataLoader( dataset=dl_nm.dataset, sampler=sampler, - num_workers=dl_nm.local_parameters.get("num_workers", os.cpu_count()), - batch_size=dl_nm.local_parameters["batch_size"], - shuffle=(sampler is None), + num_workers=dl_nm.num_workers, + batch_size=dl_nm.batch_size, + shuffle=False, ) else: eval_dataloader = dl_nm.data_iterator @@ -529,9 +532,9 @@ def _eval(self, tensors_2_evaluate, callback, step, verbose=False): eval_dataloader = torch.utils.data.DataLoader( dataset=dl_nm.dataset, sampler=None, # not distributed sampler - num_workers=call_chain[0][0].local_parameters.get("num_workers", os.cpu_count()), - batch_size=call_chain[0][0].local_parameters["batch_size"], - shuffle=call_chain[0][0].local_parameters.get("shuffle", False), + num_workers=dl_nm.num_workers, + batch_size=dl_nm.batch_size, + shuffle=dl_nm.shuffle, ) else: eval_dataloader = dl_nm.data_iterator @@ -666,13 +669,15 @@ def _infer( # ) # ) if dl_nm.dataset is not None: - sampler = torch.utils.data.distributed.DistributedSampler(dl_nm.dataset) + sampler = torch.utils.data.distributed.DistributedSampler( + dataset=dl_nm.dataset, shuffle=dl_nm.shuffle + ) eval_dataloader = torch.utils.data.DataLoader( dataset=dl_nm.dataset, sampler=sampler, - num_workers=dl_nm.local_parameters.get("num_workers", os.cpu_count()), - batch_size=dl_nm.local_parameters["batch_size"], - shuffle=(sampler is None), + num_workers=dl_nm.num_workers, + batch_size=dl_nm.batch_size, + shuffle=False, ) else: eval_dataloader = dl_nm.data_iterator @@ -685,9 +690,9 @@ def _infer( eval_dataloader = torch.utils.data.DataLoader( dataset=dl_nm.dataset, sampler=None, # not distributed sampler - num_workers=call_chain[0][0].local_parameters.get("num_workers", os.cpu_count()), - batch_size=call_chain[0][0].local_parameters["batch_size"], - shuffle=call_chain[0][0].local_parameters.get("shuffle", False), + num_workers=dl_nm.num_workers, + batch_size=dl_nm.batch_size, + shuffle=dl_nm.shuffle, ) else: eval_dataloader = dl_nm.data_iterator @@ -945,17 +950,17 @@ def __extract_dynamic_axes(port_name: str, ntype: NeuralType, dynamic_axes: defa if len(dynamic_axes) == 0: dynamic_axes = None - local_parameters = {} - if module._local_parameters is not None: - for key, value in module._local_parameters.items(): - local_parameters[key] = value + # Make a deep copy of init parameters. + init_params_copy = copy.deepcopy(module._init_params) # Remove NeMo-related things from the module # We need to change __call__ method. Note that this will change the # whole class, not just this object! Which is why we need to repair it # in the finally block type(module).__call__ = torch.nn.Module.__call__ - module._local_parameters = None + + # Reset standard instance field - making the file (probably) lighter. 
+ module._init_params = None module._placement = None module._factory = None module._device = None @@ -1006,12 +1011,12 @@ def __extract_dynamic_axes(port_name: str, ntype: NeuralType, dynamic_axes: defa elif d_format == DeploymentFormat.PYTORCH: torch.save(module.state_dict(), output) with open(output + ".json", 'w') as outfile: - json.dump(local_parameters, outfile) + json.dump(init_params_copy, outfile) else: raise NotImplementedError(f"Not supported deployment format: {d_format}") except Exception as e: # nopep8 - logging.error(f'ERROR: module export failed for {module} ' f'with exception {e}') + logging.error(f'module export failed for {module} ' f'with exception {e}') finally: def __old_call__(self, force_pt=False, *input, **kwargs): @@ -1179,13 +1184,15 @@ def train( # "optimizers") logging.info("Doing distributed training") if t_dataset is not None: - train_sampler = torch.utils.data.distributed.DistributedSampler(t_dataset) + train_sampler = torch.utils.data.distributed.DistributedSampler( + dataset=t_dataset, shuffle=dataNM.shuffle + ) train_dataloader = torch.utils.data.DataLoader( dataset=t_dataset, sampler=train_sampler, - num_workers=dataNM.local_parameters.get("num_workers", os.cpu_count()), - batch_size=dataNM.local_parameters["batch_size"], - shuffle=(train_sampler is None), + num_workers=dataNM.num_workers, + batch_size=dataNM.batch_size, + shuffle=False, ) else: train_dataloader = dataNM.data_iterator @@ -1229,9 +1236,9 @@ def train( train_dataloader = torch.utils.data.DataLoader( dataset=t_dataset, sampler=None, - num_workers=dataNM.local_parameters.get("num_workers", os.cpu_count()), - batch_size=dataNM.local_parameters["batch_size"], - shuffle=dataNM.local_parameters.get("shuffle", True), + num_workers=dataNM.num_workers, + batch_size=dataNM.batch_size, + shuffle=dataNM.shuffle, ) else: train_dataloader = dataNM.data_iterator @@ -1310,7 +1317,7 @@ def train( ): if stop_on_nan_loss: raise ValueError('Loss is NaN or inf - exiting') - logging.warning('WARNING: Loss is NaN or inf') + logging.warning('Loss is NaN or inf') curr_optimizer.zero_grad() nan = True break diff --git a/nemo/backends/pytorch/common/losses.py b/nemo/backends/pytorch/common/losses.py index 295c09ba1ce4..f79917720bec 100644 --- a/nemo/backends/pytorch/common/losses.py +++ b/nemo/backends/pytorch/common/losses.py @@ -64,18 +64,11 @@ def output_ports(self): return {"loss": NeuralType(None)} def __init__( - self, - pad_id=0, - smoothing_coef=0.0, - sample_wise=False, - aux_ctc=False, - ctc_initial_coef=0.1, - ctc_blank_id=None, - **kwargs + self, pad_id=0, smoothing_coef=0.0, sample_wise=False, aux_ctc=False, ctc_initial_coef=0.1, ctc_blank_id=None ): assert (not aux_ctc) or (ctc_blank_id is not None), "Should be a blank id if using CTC loss" - super().__init__(**kwargs) + super().__init__() self.pad_id = pad_id self.smoothing_coef = smoothing_coef @@ -152,8 +145,8 @@ def output_ports(self): """ return {"loss": NeuralType(None)} - def __init__(self, weight=None, **kwargs): - LossNM.__init__(self, **kwargs) + def __init__(self, weight=None): + super().__init__() if weight: weight = torch.FloatTensor(weight).to(self._device) self._criterion = nn.CrossEntropyLoss(weight=weight) @@ -188,8 +181,8 @@ def output_ports(self): """ return {"loss": NeuralType(None)} - def __init__(self, **kwargs): - LossNM.__init__(self, **kwargs) + def __init__(self): + super().__init__() self._criterion = nn.MSELoss() def _loss_function(self, preds, labels): diff --git a/nemo/backends/pytorch/common/other.py 
b/nemo/backends/pytorch/common/other.py index 982abd100446..80d43dadae15 100644 --- a/nemo/backends/pytorch/common/other.py +++ b/nemo/backends/pytorch/common/other.py @@ -51,8 +51,8 @@ def output_ports(self): """ return {"combined": None} - def __init__(self, mode="add", **kwargs): - TrainableNM.__init__(self, **kwargs) + def __init__(self, mode="add"): + super().__init__() self._mode = mode def forward(self, x1, x2): @@ -94,8 +94,8 @@ def output_ports(self): "indices": NeuralType({0: AxisType(BatchTag)}), } - def __init__(self, **kwargs): - TrainableNM.__init__(self, **kwargs) + def __init__(self): + super().__init__() # this method is key method you need to overwrite from PyTorch # nn.Module's API @@ -107,8 +107,8 @@ def forward(self, x): class TableLookUp(NeuralModule): """Performs a table lookup. For example, convert class ids to names""" - def __init__(self, ids2classes=None, **kwargs): - NeuralModule.__init__(self, **kwargs) + def __init__(self, ids2classes=None): + NeuralModule.__init__(self) if ids2classes is None: ids2classes = {} @@ -220,9 +220,8 @@ def output_ports(self): """ return {"classes": None} - def __init__(self, detokenizer=None, **kwargs): - NeuralModule.__init__(self, **kwargs) - # self._sp_decoder = self.local_parameters.get("sp_decoder", {}) + def __init__(self, detokenizer=None): + NeuralModule.__init__(self) self._detokenizer = detokenizer def __call__(self, force_pt=False, *input, **kwargs): @@ -276,8 +275,8 @@ def output_ports(self): """ return {"outputs": NeuralType({0: AxisType(TimeTag), 1: AxisType(BatchTag), 2: AxisType(ChannelTag),})} - def __init__(self, *, voc_size, hidden_size, dropout=0.0, **kwargs): - TrainableNM.__init__(self, **kwargs) + def __init__(self, voc_size, hidden_size, dropout=0.0): + super().__init__() self.voc_size = voc_size self.hidden_size = hidden_size @@ -312,8 +311,8 @@ def output_ports(self): """ return {"outputs": None} - def __init__(self, *, from_dim, to_dim, dropout=0.0, **kwargs): - TrainableNM.__init__(self, **kwargs) + def __init__(self, from_dim, to_dim, dropout=0.0): + super().__init__() self.from_dim = from_dim self.to_dim = to_dim @@ -352,8 +351,8 @@ def output_ports(self): """ return {"input_type_ids": NeuralType({0: AxisType(BatchTag), 1: AxisType(TimeTag),})} - def __init__(self, **kwargs): - TrainableNM.__init__(self, **kwargs) + def __init__(self): + super().__init__() def forward(self, input_type_ids): return torch.zeros_like(input_type_ids).long() diff --git a/nemo/backends/pytorch/common/rnn.py b/nemo/backends/pytorch/common/rnn.py index c7f6fc66f5bc..4b8e994223eb 100644 --- a/nemo/backends/pytorch/common/rnn.py +++ b/nemo/backends/pytorch/common/rnn.py @@ -111,9 +111,8 @@ def __init__( rnn_type='gru', n_layers=2, tie_emb_out_weights=True, - **kwargs ): - super().__init__(**kwargs) + super().__init__() self.bos_id = bos_id self.attention_type = attention_type diff --git a/nemo/backends/pytorch/common/search.py b/nemo/backends/pytorch/common/search.py index 812c22ce2cfd..350fdb3dff5c 100644 --- a/nemo/backends/pytorch/common/search.py +++ b/nemo/backends/pytorch/common/search.py @@ -66,8 +66,8 @@ def output_ports(self): 'attention_weights': NeuralType({0: AxisType(BatchTag), 1: AxisType(TimeTag), 2: AxisType(TimeTag),}), } - def __init__(self, decoder, pad_id, bos_id, eos_id, max_len, batch_size=None, **kwargs): - super().__init__(**kwargs) + def __init__(self, decoder, pad_id, bos_id, eos_id, max_len, batch_size=None): + super().__init__() self.decoder = decoder self.pad_id = pad_id @@ -118,8 +118,8 @@ class 
BeamSearch(GreedySearch): """ - def __init__(self, decoder, pad_id, bos_id, eos_id, max_len, batch_size=None, beam_size=8, **kwargs): - super().__init__(decoder, pad_id, bos_id, eos_id, max_len, batch_size, **kwargs) + def __init__(self, decoder, pad_id, bos_id, eos_id, max_len, batch_size=None, beam_size=8): + super().__init__(decoder, pad_id, bos_id, eos_id, max_len, batch_size) self.beam_size = beam_size diff --git a/nemo/backends/pytorch/common/zero_data.py b/nemo/backends/pytorch/common/zero_data.py index 8b7b2c08ce6a..0c7b14fe1a11 100644 --- a/nemo/backends/pytorch/common/zero_data.py +++ b/nemo/backends/pytorch/common/zero_data.py @@ -67,10 +67,10 @@ class ZerosDataLayer(DataLayerNM): Defaults to None. """ - def __init__(self, *, size, output_ports, dtype, batch_size, shapes=None, **kwargs): - DataLayerNM.__init__(self, **kwargs) - self._size = size + def __init__(self, size, output_ports, dtype, batch_size, shapes=None): self._output_ports = output_ports + DataLayerNM.__init__(self) + self._size = size self._type = dtype self._batch_size = batch_size self._shapes = shapes @@ -97,7 +97,3 @@ def data_iterator(self): @property def dataset(self): return self._dataset - - @property - def batch_size(self): - return self._batch_size diff --git a/nemo/backends/pytorch/module_wrapper.py b/nemo/backends/pytorch/module_wrapper.py index c233f77ee01e..f439a847411d 100644 --- a/nemo/backends/pytorch/module_wrapper.py +++ b/nemo/backends/pytorch/module_wrapper.py @@ -10,8 +10,8 @@ class TrainableNeuralModuleWrapper(NeuralModule, nn.Module): """This class wraps an instance of Pytorch's nn.Module and returns NeuralModule's instance.""" - def __init__(self, pt_nn_module, input_ports_dict, output_ports_dict, **kwargs): - NeuralModule.__init__(self, **kwargs) + def __init__(self, pt_nn_module, input_ports_dict, output_ports_dict): + NeuralModule.__init__(self) nn.Module.__init__(self) self._input_ports = input_ports_dict self._output_ports = output_ports_dict diff --git a/nemo/backends/pytorch/nm.py b/nemo/backends/pytorch/nm.py index 2cd70c3695d5..0a92cfe5cdc9 100644 --- a/nemo/backends/pytorch/nm.py +++ b/nemo/backends/pytorch/nm.py @@ -1,4 +1,5 @@ # Copyright (c) 2019 NVIDIA Corporation +import os from abc import abstractmethod from typing import Dict, List, Optional, Set, Tuple @@ -20,19 +21,29 @@ class TrainableNM(NeuralModule, nn.Module): .. code-block:: python - def __init__(self, **kwargs): - TrainableNM.__init__(self, **kwargs) - .... # you code + def __init__(self): + super().__init__() + .... # your code Then make sure that your forward(..) method accepts arguments named like input ports. 
+ + Args: + pretrained_model_name (str): name of pretrained model to use in order + to initialize this neural module + """ - def __init__(self, **kwargs): - NeuralModule.__init__(self, **kwargs) # For NeuralModule API + def __init__(self, pretrained_model_name=None): + + NeuralModule.__init__(self) # For NeuralModule API nn.Module.__init__(self) # For PyTorch API + self._device = get_cuda_device(self.placement) + # Store pretrained model name (to be removed/changed) + self._pretrained_model_name = pretrained_model_name + def __call__(self, *input, force_pt=False, **kwargs): pt_call = len(input) > 0 or force_pt if pt_call: @@ -119,8 +130,8 @@ def num_weights(self): class NonTrainableNM(NeuralModule): - def __init__(self, **kwargs): - NeuralModule.__init__(self, **kwargs) # For NeuralModule API + def __init__(self): + NeuralModule.__init__(self) # For NeuralModule API self._device = get_cuda_device(self.placement) def __call__(self, force_pt=False, *input, **kwargs): @@ -179,13 +190,22 @@ class DataLayerNM(NeuralModule): data_iterator property to return iterator over the dataset. """ - def __init__(self, **kwargs): + def __init__(self): + NeuralModule.__init__(self) # For NeuralModule API + self._device = get_cuda_device(self.placement) + # if 'batch_size' not in kwargs: # nemo.logging.warning("No batch_size specified in the data layer. " # "Setting batch_size to 1.") # kwargs['batch_size'] = 1 - NeuralModule.__init__(self, **kwargs) # For NeuralModule API - self._device = get_cuda_device(self.placement) + + # Set default values of variables used by the trainer/passed to DataLoader. + # NOTE: That also means that those are parameters of DataLoader/trainer, not DataLayer. + # Thus those fields will be removed from DataLayer and moved to trainer configuration + # (when the time for that comes ;)) + self._batch_size = 1 + self._num_workers = os.cpu_count() # Use all CPUs by default. + self._shuffle = True # Shuffle by default. @property def input_ports(self): @@ -269,14 +289,44 @@ def data_iterator(self): If this is implemented, `dataset` property should return None. """ + @property + def batch_size(self): + """ Property returning the batch size. """ + return self._batch_size + + # @batch_size.setter + # def batch_size(self, bs): + # """ Property setting the batch size. """ + # self._batch_size = bs + + @property + def shuffle(self): + """ Property returning the shuffle flag. """ + return self._shuffle + + # @shuffle.setter + # def shuffle(self, sh): + # """ Property setting the shuffle flag. """ + # self._shuffle = sh + + @property + def num_workers(self): + """ Property returning the number of workers. """ + return self._num_workers + + # @num_workers.setter + # def num_workers(self, nw): + # """ Property setting the number of workers. """ + # self._num_workers = nw + class LossNM(NeuralModule): """A helper Base class for creating Pytorch-based loss function modules. You must implement _loss_function method.
""" - def __init__(self, **kwargs): - NeuralModule.__init__(self, **kwargs) # For NeuralModule API + def __init__(self): + NeuralModule.__init__(self) # For NeuralModule API self._device = get_cuda_device(self.placement) def get_weights(self): diff --git a/nemo/backends/pytorch/torchvision/data/image_folder.py b/nemo/backends/pytorch/torchvision/data/image_folder.py index 8f762e6bfbbc..5c4946b5cdd5 100644 --- a/nemo/backends/pytorch/torchvision/data/image_folder.py +++ b/nemo/backends/pytorch/torchvision/data/image_folder.py @@ -38,8 +38,8 @@ def output_ports(self): "label": NeuralType({0: AxisType(BatchTag)}), } - def __init__(self, *, input_size=32, batch_size, path, shuffle=True, is_eval=False, **kwargs): - DataLayerNM.__init__(self, **kwargs) + def __init__(self, batch_size, path, input_size=32, shuffle=True, is_eval=False): + super().__init__() self._input_size = input_size self._batch_size = batch_size diff --git a/nemo/backends/pytorch/tutorials/chatbot/modules.py b/nemo/backends/pytorch/tutorials/chatbot/modules.py index de98c5799edb..33fec674ba02 100644 --- a/nemo/backends/pytorch/tutorials/chatbot/modules.py +++ b/nemo/backends/pytorch/tutorials/chatbot/modules.py @@ -50,8 +50,8 @@ def output_ports(self): "max_tgt_lengths": NeuralType(None), } - def __init__(self, *, batch_size, corpus_name, datafile, min_count=3, **kwargs): - DataLayerNM.__init__(self, **kwargs) + def __init__(self, batch_size, corpus_name, datafile, min_count=3): + super().__init__() self._batch_size = batch_size self._corpus_name = corpus_name @@ -129,8 +129,8 @@ def output_ports(self): "hidden": NeuralType({0: AxisType(BatchTag), 1: AxisType(ChannelTag)}), } - def __init__(self, *, voc_size, encoder_n_layers, hidden_size, dropout, bidirectional=True, **kwargs): - TrainableNM.__init__(self, **kwargs) + def __init__(self, voc_size, encoder_n_layers, hidden_size, dropout, bidirectional=True): + super().__init__() self.voc_size = voc_size self.n_layers = encoder_n_layers @@ -217,8 +217,8 @@ def output_ports(self): "hidden": NeuralType({0: AxisType(BatchTag), 1: AxisType(ChannelTag)}), } - def __init__(self, *, attn_model, hidden_size, voc_size, decoder_n_layers, dropout, **kwargs): - TrainableNM.__init__(self, **kwargs) + def __init__(self, attn_model, hidden_size, voc_size, decoder_n_layers, dropout): + super().__init__() self.attn_model = attn_model self.hidden_size = hidden_size @@ -360,8 +360,8 @@ def output_ports(self): """ return {"loss": NeuralType(None)} - def __init__(self, **kwargs): - LossNM.__init__(self, **kwargs) + def __init__(self): + super().__init__() self._device = t.device("cuda" if self.placement == DeviceType.GPU else "cpu") @@ -416,13 +416,12 @@ def output_ports(self): "hidden": NeuralType({0: AxisType(BatchTag), 1: AxisType(ChannelTag)}), } - def __init__(self, *, attn_model, hidden_size, voc_size, decoder_n_layers, dropout, max_dec_steps=10, **kwargs): - TrainableNM.__init__(self, **kwargs) + def __init__(self, attn_model, hidden_size, voc_size, decoder_n_layers, dropout, max_dec_steps=10): + super().__init__() self.attn_model = attn_model self.hidden_size = hidden_size self.voc_size = voc_size - # self.local_parameters["output_size"] self.output_size = voc_size self.n_layers = decoder_n_layers self.dropout = dropout diff --git a/nemo/backends/pytorch/tutorials/toys.py b/nemo/backends/pytorch/tutorials/toys.py index a6929f9d3b43..cf43c475543e 100644 --- a/nemo/backends/pytorch/tutorials/toys.py +++ b/nemo/backends/pytorch/tutorials/toys.py @@ -32,11 +32,11 @@ def output_ports(self): """ 
return {"y_pred": NeuralType({0: AxisType(BatchTag), 1: AxisType(ChannelTag)})} - def __init__(self, *, dim, **kwargs): + def __init__(self, dim): # Part specific for Neural Modules API: # (1) call base constructor # (2) define input and output ports - TrainableNM.__init__(self, **kwargs) + super().__init__() # And of Neural Modules specific part. Rest is Pytorch code self._dim = dim @@ -87,11 +87,11 @@ def output_ports(self): """ return {"y_pred": NeuralType({0: AxisType(BatchTag), 1: AxisType(ChannelTag)}, optional=True)} - def __init__(self, *, dim, **kwargs): + def __init__(self, dim): # Part specific for Neural Modules API: # (1) call base constructor # (2) define input and output ports - TrainableNM.__init__(self, **kwargs) + super().__init__() # And of Neural Modules specific part. Rest is Pytorch code self._dim = dim @@ -121,9 +121,10 @@ class RealFunctionDataLayer(DataLayerNM): Args: n: Total number of samples batch_size: Size of each batch per iteration - f: A lambda of the function to apply to each x value to get labels. + f_name: Name of the function that will be applied to each x value to get labels. Must take a torch tensor as input, and output a torch tensor of the same shape. Defaults to torch.sin(). + [Options: sin | cos] x_lo: Lower bound of domain to sample x_hi: Upper bound of domain to sample """ @@ -150,15 +151,31 @@ def output_ports(self): "y": NeuralType({0: AxisType(BatchTag), 1: AxisType(ChannelTag)}), } - def __init__(self, *, n, batch_size, f=t.sin, x_lo=-4, x_hi=4, **kwargs): - DataLayerNM.__init__(self, **kwargs) + def __init__(self, batch_size, f_name="sin", n=1000, x_lo=-4, x_hi=4): + """ + Creates a datalayer returning (x-y) pairs, with n points from a given range. + + Args: + batch_size: size of batch + f_name: name of function ["sin" | "cos"] + n: number of points + x_lo: lower boundary along x axis + x_hi: higher boundary along x axis + """ + super().__init__() + + # Dicionary with handled functions. 
+ handled_funcs = {"sin": t.sin, "cos": t.cos} + + # Get function - raises an exception if function is not handled + func = handled_funcs[f_name] self._n = n self._batch_size = batch_size self._device = t.device("cuda" if self.placement == DeviceType.GPU else "cpu") x_data = t.tensor(np.random.uniform(low=x_lo, high=x_hi, size=self._n)).unsqueeze(-1).to(self._device) - y_data = f(x_data) + y_data = func(x_data) self._data_iterator = t_utils.DataLoader( t_utils.TensorDataset(x_data.float(), y_data.float()), batch_size=self._batch_size, @@ -202,8 +219,8 @@ def output_ports(self): """ return {"loss": NeuralType(None)} - def __init__(self, **kwargs): - LossNM.__init__(self, **kwargs) + def __init__(self): + super().__init__() self._criterion = nn.MSELoss() def _loss_function(self, **kwargs): @@ -239,8 +256,8 @@ def output_ports(self): """ return {"loss": NeuralType(None)} - def __init__(self, **kwargs): - LossNM.__init__(self, **kwargs) + def __init__(self): + super().__init__() self._criterion = nn.L1Loss() def _loss_function(self, **kwargs): @@ -274,9 +291,9 @@ def output_ports(self): """ return {"loss": NeuralType(None)} - def __init__(self, **kwargs): + def __init__(self): # Neural Module API specific - NeuralModule.__init__(self, **kwargs) + NeuralModule.__init__(self) # End of Neural Module API specific self._criterion = nn.CrossEntropyLoss() @@ -326,9 +343,9 @@ def output_ports(self): """ return {"loss": NeuralType(None)} - def __init__(self, **kwargs): + def __init__(self): # Neural Module API specific - NeuralModule.__init__(self, **kwargs) + NeuralModule.__init__(self) # You need to implement this function def _loss_function(self, **kwargs): diff --git a/nemo/collections/asr/audio_preprocessing.py b/nemo/collections/asr/audio_preprocessing.py index 119a031e5a3f..94476839a1f3 100644 --- a/nemo/collections/asr/audio_preprocessing.py +++ b/nemo/collections/asr/audio_preprocessing.py @@ -55,8 +55,8 @@ class AudioPreprocessor(NonTrainableNM): transforming the wav files to features. """ - def __init__(self, win_length, hop_length, **kwargs): - super().__init__(**kwargs) + def __init__(self, win_length, hop_length): + super().__init__() self.win_length = win_length self.hop_length = hop_length @@ -161,7 +161,6 @@ def output_ports(self): def __init__( self, - *, sample_rate=16000, window_size=0.02, window_stride=0.01, @@ -170,7 +169,6 @@ def __init__( n_fft=None, window="hann", normalized=True, - **kwargs, ): if not HAVE_TORCHAUDIO: raise ModuleNotFoundError( @@ -189,7 +187,7 @@ def __init__( if window_stride: n_window_stride = int(window_stride * sample_rate) - super().__init__(n_window_size, n_window_stride, **kwargs) + super().__init__(n_window_size, n_window_stride) self.win_length = n_window_size self.hop_length = n_window_stride @@ -326,7 +324,6 @@ def output_ports(self): def __init__( self, - *, sample_rate=16000, window_size=0.02, window_stride=0.01, @@ -348,7 +345,6 @@ def __init__( stft_conv=False, pad_value=0, mag_power=2.0, - **kwargs, ): if window_size and n_window_size: raise ValueError(f"{self} received both window_size and " f"n_window_size. 
Only one should be specified.") @@ -361,7 +357,7 @@ def __init__( if window_stride: n_window_stride = int(window_stride * sample_rate) - super().__init__(n_window_size, n_window_stride, **kwargs) + super().__init__(n_window_size, n_window_stride) self.featurizer = FilterbankFeatures( sample_rate=sample_rate, @@ -478,7 +474,6 @@ def output_ports(self): def __init__( self, - *, sample_rate=16000, window_size=0.02, window_stride=0.01, @@ -493,7 +488,6 @@ def __init__( dct_type=2, norm='ortho', log=True, - **kwargs, ): if not HAVE_TORCHAUDIO: raise ModuleNotFoundError( @@ -513,7 +507,7 @@ def __init__( if window_stride: n_window_stride = int(window_stride * sample_rate) - super().__init__(n_window_size, n_window_stride, **kwargs) + super().__init__(n_window_size, n_window_stride) mel_kwargs = {} @@ -615,7 +609,6 @@ def output_ports(self): def __init__( self, - *, freq_masks=0, time_masks=0, freq_width=10, @@ -624,9 +617,8 @@ def __init__( rect_time=5, rect_freq=20, rng=None, - **kwargs, ): - NonTrainableNM.__init__(self, **kwargs) + super().__init__() if rect_masks > 0: self.spec_cutout = SpecCutout(rect_masks=rect_masks, rect_time=rect_time, rect_freq=rect_freq, rng=rng,) @@ -717,8 +709,8 @@ def output_ports(self): "out_y_len": NeuralType({0: AxisType(BatchTag)}), } - def __init__(self, *, mult_batch=1, **kwargs): - NonTrainableNM.__init__(self, **kwargs) + def __init__(self, mult_batch=1): + super().__init__() self.mult = mult_batch @torch.no_grad() diff --git a/nemo/collections/asr/beam_search_decoder.py b/nemo/collections/asr/beam_search_decoder.py index 7c48eb61e88e..6bb985a98e5c 100644 --- a/nemo/collections/asr/beam_search_decoder.py +++ b/nemo/collections/asr/beam_search_decoder.py @@ -7,6 +7,7 @@ from nemo.backends.pytorch.nm import NonTrainableNM from nemo.core import DeviceType from nemo.core.neural_types import AxisType, BatchTag, ChannelTag, NeuralType, TimeTag +from nemo.utils.helpers import get_cuda_device class BeamSearchDecoderWithLM(NonTrainableNM): @@ -65,9 +66,7 @@ def output_ports(self): """ return {"predictions": NeuralType(None)} - def __init__( - self, *, vocab, beam_width, alpha, beta, lm_path, num_cpus, cutoff_prob=1.0, cutoff_top_n=40, **kwargs - ): + def __init__(self, vocab, beam_width, alpha, beta, lm_path, num_cpus, cutoff_prob=1.0, cutoff_top_n=40): try: from ctc_decoders import Scorer @@ -79,11 +78,10 @@ def __init__( "from nemo/scripts/install_decoders.py" ) - super().__init__( - # Override default placement from neural factory - placement=DeviceType.CPU, - **kwargs - ) + super().__init__() + # Override the default placement from neural factory and set placement/device to be CPU. 
+ self._placement = DeviceType.CPU + self._device = get_cuda_device(self._placement) if self._factory.world_size > 1: raise ValueError("BeamSearchDecoderWithLM does not run in distributed mode") diff --git a/nemo/collections/asr/data_layer.py b/nemo/collections/asr/data_layer.py index 9aca09afd070..44b1cca9c9b6 100644 --- a/nemo/collections/asr/data_layer.py +++ b/nemo/collections/asr/data_layer.py @@ -108,7 +108,6 @@ def output_ports(self): def __init__( self, - *, manifest_filepath, labels, batch_size, @@ -125,10 +124,8 @@ def __init__( drop_last=False, shuffle=True, num_workers=0, - # perturb_config=None, - **kwargs, ): - super().__init__(**kwargs) + super().__init__() self._featurizer = WaveformFeaturizer(sample_rate=sample_rate, int_values=int_values, augmentor=None) @@ -245,7 +242,6 @@ def output_ports(self): def __init__( self, - *, kaldi_dir, labels, batch_size, @@ -255,9 +251,8 @@ def __init__( drop_last=False, shuffle=True, num_workers=0, - **kwargs, ): - super().__init__(**kwargs) + super().__init__() # Set up dataset dataset_params = { @@ -382,9 +377,8 @@ def __init__( drop_last=False, num_workers=0, shuffle=True, - **kwargs, ): - super().__init__(**kwargs) + super().__init__() # Set up dataset dataset_params = { diff --git a/nemo/collections/asr/greedy_ctc_decoder.py b/nemo/collections/asr/greedy_ctc_decoder.py index 03eb9862c47b..b9b416b8983a 100644 --- a/nemo/collections/asr/greedy_ctc_decoder.py +++ b/nemo/collections/asr/greedy_ctc_decoder.py @@ -34,8 +34,8 @@ def output_ports(self): """ return {"predictions": NeuralType({0: AxisType(BatchTag), 1: AxisType(TimeTag)})} - def __init__(self, **kwargs): - TrainableNM.__init__(self, **kwargs) + def __init__(self): + super().__init__() def forward(self, log_probs): with torch.no_grad(): diff --git a/nemo/collections/asr/jasper.py b/nemo/collections/asr/jasper.py index a363813c7952..db75e0793643 100644 --- a/nemo/collections/asr/jasper.py +++ b/nemo/collections/asr/jasper.py @@ -124,7 +124,6 @@ def output_ports(self): def __init__( self, - *, jasper, activation, feat_in, @@ -134,9 +133,8 @@ def __init__( conv_mask=True, frame_splicing=1, init_mode='xavier_uniform', - **kwargs ): - TrainableNM.__init__(self, **kwargs) + super().__init__() activation = jasper_activations[activation]() feat_in = feat_in * frame_splicing @@ -234,8 +232,8 @@ def output_ports(self): """ return {"output": NeuralType({0: AxisType(BatchTag), 1: AxisType(TimeTag), 2: AxisType(ChannelTag),})} - def __init__(self, *, feat_in, num_classes, init_mode="xavier_uniform", **kwargs): - TrainableNM.__init__(self, **kwargs) + def __init__(self, feat_in, num_classes, init_mode="xavier_uniform"): + super().__init__() self._feat_in = feat_in # Add 1 for blank char diff --git a/nemo/collections/asr/las/misc.py b/nemo/collections/asr/las/misc.py index c1402f517a34..a1a1a855e419 100644 --- a/nemo/collections/asr/las/misc.py +++ b/nemo/collections/asr/las/misc.py @@ -43,8 +43,8 @@ def output_ports(self): """ return {'tensor': NeuralType({0: AxisType(BatchTag), 1: AxisType(TimeTag), 2: AxisType(ChannelTag),})} - def __init__(self, in_channels, out_channels, **kwargs): - super().__init__(**kwargs) + def __init__(self, in_channels, out_channels): + super().__init__() self.icnn = nn.Conv1d(in_channels, out_channels, kernel_size=1, bias=True) self.bn = nn.BatchNorm1d(out_channels) diff --git a/nemo/collections/asr/losses.py b/nemo/collections/asr/losses.py index 47dbaac2b6da..f43a30791079 100644 --- a/nemo/collections/asr/losses.py +++ b/nemo/collections/asr/losses.py @@ -53,10 
+53,9 @@ def output_ports(self): """ return {"loss": NeuralType(None)} - def __init__(self, *, num_classes, **kwargs): - LossNM.__init__(self, **kwargs) + def __init__(self, num_classes): + super().__init__() - # self._blank = self.local_parameters.get('blank', 0) self._blank = num_classes self._criterion = nn.CTCLoss(blank=self._blank, reduction='none') diff --git a/nemo/collections/asr/parts/features.py b/nemo/collections/asr/parts/features.py index 792758f84575..1c2fa1b28d42 100644 --- a/nemo/collections/asr/parts/features.py +++ b/nemo/collections/asr/parts/features.py @@ -107,7 +107,6 @@ class FilterbankFeatures(nn.Module): def __init__( self, - *, sample_rate=16000, n_window_size=320, n_window_stride=160, diff --git a/nemo/collections/nlp/data/data_layers.py b/nemo/collections/nlp/data/data_layers.py index 05fb44c34590..36dac97ec98d 100644 --- a/nemo/collections/nlp/data/data_layers.py +++ b/nemo/collections/nlp/data/data_layers.py @@ -44,13 +44,15 @@ class TextDataLayer(DataLayerNM): Args: dataset_type: type of dataset used for this datalayer dataset_params (dict): all the params for the dataset + batch_size: size of batch """ - def __init__(self, dataset_type, dataset_params, **kwargs): - super().__init__(**kwargs) + def __init__(self, dataset_type, dataset_params, batch_size): + super().__init__() if isinstance(dataset_type, str): dataset_type = getattr(sys.modules[__name__], dataset_type) self._dataset = dataset_type(**dataset_params) + self._batch_size = batch_size def __len__(self): return len(self._dataset) @@ -115,9 +117,7 @@ def __init__( shuffle=False, batch_size=64, dataset_type=BertSentenceClassificationDataset, - **kwargs ): - kwargs['batch_size'] = batch_size dataset_params = { 'input_file': input_file, 'tokenizer': tokenizer, @@ -125,7 +125,7 @@ def __init__( 'num_samples': num_samples, 'shuffle': shuffle, } - super().__init__(dataset_type, dataset_params, **kwargs) + super().__init__(dataset_type, dataset_params, batch_size) class BertJointIntentSlotDataLayer(TextDataLayer): @@ -207,9 +207,7 @@ def __init__( ignore_extra_tokens=False, ignore_start_end=False, dataset_type=BertJointIntentSlotDataset, - **kwargs ): - kwargs['batch_size'] = batch_size dataset_params = { 'input_file': input_file, 'slot_file': slot_file, @@ -221,7 +219,7 @@ def __init__( 'ignore_extra_tokens': ignore_extra_tokens, 'ignore_start_end': ignore_start_end, } - super().__init__(dataset_type, dataset_params, **kwargs) + super().__init__(dataset_type, dataset_params, batch_size) class BertJointIntentSlotInferDataLayer(TextDataLayer): @@ -281,16 +279,13 @@ def output_ports(self): "subtokens_mask": NeuralType({0: AxisType(BatchTag), 1: AxisType(TimeTag)}), } - def __init__( - self, queries, tokenizer, max_seq_length, batch_size=1, dataset_type=BertJointIntentSlotInferDataset, **kwargs - ): - kwargs['batch_size'] = batch_size + def __init__(self, queries, tokenizer, max_seq_length, batch_size=1, dataset_type=BertJointIntentSlotInferDataset): dataset_params = { 'queries': queries, 'tokenizer': tokenizer, 'max_seq_length': max_seq_length, } - super().__init__(dataset_type, dataset_params, **kwargs) + super().__init__(dataset_type, dataset_params, batch_size) class LanguageModelingDataLayer(TextDataLayer): @@ -333,7 +328,7 @@ def output_ports(self): } def __init__( - self, dataset, tokenizer, max_seq_length, batch_step=128, dataset_type=LanguageModelingDataset, **kwargs + self, dataset, tokenizer, max_seq_length, batch_size, batch_step=128, dataset_type=LanguageModelingDataset ): dataset_params = { 
'dataset': dataset, @@ -341,7 +336,7 @@ def __init__( 'max_seq_length': max_seq_length, 'batch_step': batch_step, } - super().__init__(dataset_type, dataset_params, **kwargs) + super().__init__(dataset_type, dataset_params, batch_size) class BertTokenClassificationDataLayer(TextDataLayer): @@ -403,9 +398,7 @@ def __init__( ignore_start_end=False, use_cache=False, dataset_type=BertTokenClassificationDataset, - **kwargs ): - kwargs['batch_size'] = batch_size dataset_params = { 'text_file': text_file, 'label_file': label_file, @@ -419,7 +412,7 @@ def __init__( 'ignore_start_end': ignore_start_end, 'use_cache': use_cache, } - super().__init__(dataset_type, dataset_params, **kwargs) + super().__init__(dataset_type, dataset_params, batch_size) class BertTokenClassificationInferDataLayer(TextDataLayer): @@ -462,21 +455,14 @@ def output_ports(self): } def __init__( - self, - queries, - tokenizer, - max_seq_length, - batch_size=1, - dataset_type=BertTokenClassificationInferDataset, - **kwargs + self, queries, tokenizer, max_seq_length, batch_size=1, dataset_type=BertTokenClassificationInferDataset ): - kwargs['batch_size'] = batch_size dataset_params = { 'queries': queries, 'tokenizer': tokenizer, 'max_seq_length': max_seq_length, } - super().__init__(dataset_type, dataset_params, **kwargs) + super().__init__(dataset_type, dataset_params, batch_size) class BertPunctuationCapitalizationDataLayer(TextDataLayer): @@ -546,9 +532,7 @@ def __init__( ignore_start_end=False, use_cache=False, dataset_type=BertPunctuationCapitalizationDataset, - **kwargs ): - kwargs['batch_size'] = batch_size dataset_params = { 'text_file': text_file, 'label_file': label_file, @@ -563,7 +547,7 @@ def __init__( 'ignore_start_end': ignore_start_end, 'use_cache': use_cache, } - super().__init__(dataset_type, dataset_params, **kwargs) + super().__init__(dataset_type, dataset_params, batch_size) class BertPunctuationCapitalizationInferDataLayer(TextDataLayer): @@ -606,21 +590,14 @@ def output_ports(self): } def __init__( - self, - queries, - tokenizer, - max_seq_length, - batch_size=1, - dataset_type=BertTokenClassificationInferDataset, - **kwargs + self, queries, tokenizer, max_seq_length, batch_size=1, dataset_type=BertTokenClassificationInferDataset, ): - kwargs['batch_size'] = batch_size dataset_params = { 'queries': queries, 'tokenizer': tokenizer, 'max_seq_length': max_seq_length, } - super().__init__(dataset_type, dataset_params, **kwargs) + super().__init__(dataset_type, dataset_params, batch_size) class BertQuestionAnsweringDataLayer(TextDataLayer): @@ -696,9 +673,7 @@ def __init__( mode="train", batch_size=64, dataset_type=SquadDataset, - **kwargs ): - kwargs['batch_size'] = batch_size dataset_params = { 'data_dir': data_dir, 'mode': mode, @@ -709,7 +684,7 @@ def __init__( 'doc_stride': doc_stride, } - super().__init__(dataset_type, dataset_params, **kwargs) + super().__init__(dataset_type, dataset_params, batch_size) class BertPretrainingDataLayer(TextDataLayer): @@ -771,10 +746,7 @@ def output_ports(self): "labels": NeuralType({0: AxisType(BatchTag)}), } - def __init__( - self, tokenizer, dataset, max_seq_length, mask_probability, short_seq_prob=0.1, batch_size=64, **kwargs - ): - kwargs['batch_size'] = batch_size + def __init__(self, tokenizer, dataset, max_seq_length, mask_probability, short_seq_prob=0.1, batch_size=64): dataset_params = { 'tokenizer': tokenizer, 'dataset': dataset, @@ -782,7 +754,7 @@ def __init__( 'mask_probability': mask_probability, 'short_seq_prob': short_seq_prob, } - 
super().__init__(BertPretrainingDataset, dataset_params, **kwargs) + super().__init__(BertPretrainingDataset, dataset_params, batch_size) class BertPretrainingPreprocessedDataLayer(DataLayerNM): @@ -844,7 +816,9 @@ def output_ports(self): "labels": NeuralType({0: AxisType(BatchTag)}), } - def __init__(self, dataset, max_pred_length, batch_size=64, training=True, **kwargs): + def __init__(self, dataset, max_pred_length, batch_size=64, training=True): + super().__init__() + self._batch_size = batch_size if os.path.isdir(dataset): self.files = [ @@ -854,7 +828,6 @@ def __init__(self, dataset, max_pred_length, batch_size=64, training=True, **kwa self.files = [dataset] self.files.sort() self.num_files = len(self.files) - self.batch_size = batch_size self.max_pred_length = max_pred_length self.training = training total_length = 0 @@ -863,7 +836,6 @@ def __init__(self, dataset, max_pred_length, batch_size=64, training=True, **kwa total_length += len(fp['input_ids']) fp.close() self.total_length = total_length - super().__init__(**kwargs) def _collate_fn(self, x): num_components = len(x[0]) @@ -979,10 +951,10 @@ def __init__( tokenizer_tgt, dataset_src, dataset_tgt, + batch_size=64, tokens_in_batch=1024, clean=False, dataset_type=TranslationDataset, - **kwargs ): dataset_params = { 'tokenizer_src': tokenizer_src, @@ -992,7 +964,7 @@ def __init__( 'tokens_in_batch': tokens_in_batch, 'clean': clean, } - super().__init__(dataset_type, dataset_params, **kwargs) + super().__init__(dataset_type, dataset_params, batch_size) if self._placement == nemo.core.DeviceType.AllGpu: sampler = pt_data.distributed.DistributedSampler(self._dataset) @@ -1075,9 +1047,7 @@ def __init__( shuffle=False, batch_size=64, dataset_type=GLUEDataset, - **kwargs ): - kwargs['batch_size'] = batch_size dataset_params = { 'data_dir': data_dir, 'output_mode': 'classification', @@ -1088,7 +1058,7 @@ def __init__( 'max_seq_length': max_seq_length, } - super().__init__(dataset_type, dataset_params, **kwargs) + super().__init__(dataset_type, dataset_params, batch_size) class GlueDataLayerRegression(TextDataLayer): @@ -1144,9 +1114,7 @@ def __init__( shuffle=False, batch_size=64, dataset_type=GLUEDataset, - **kwargs ): - kwargs['batch_size'] = batch_size dataset_params = { 'data_dir': data_dir, 'output_mode': 'regression', @@ -1157,4 +1125,4 @@ def __init__( 'max_seq_length': max_seq_length, } - super().__init__(dataset_type, dataset_params, **kwargs) + super().__init__(dataset_type, dataset_params, batch_size) diff --git a/nemo/collections/nlp/data/tokenizers/gpt2_tokenizer.py b/nemo/collections/nlp/data/tokenizers/gpt2_tokenizer.py index 7c7417c9f0c7..60e6c3cf3cd5 100644 --- a/nemo/collections/nlp/data/tokenizers/gpt2_tokenizer.py +++ b/nemo/collections/nlp/data/tokenizers/gpt2_tokenizer.py @@ -12,7 +12,6 @@ def __init__( errors='replace', bos_token="<|endoftext|>", eos_token="<|endoftext|>", - **kwargs ): if pretrained_model: self.tokenizer = GPT2Tokenizer.from_pretrained(pretrained_model) diff --git a/nemo/collections/nlp/huggingface/bert.py b/nemo/collections/nlp/huggingface/bert.py index 684b6d93048a..616c07f60ce0 100644 --- a/nemo/collections/nlp/huggingface/bert.py +++ b/nemo/collections/nlp/huggingface/bert.py @@ -68,7 +68,6 @@ def output_ports(self): def __init__( self, - *, pretrained_model_name=None, config_filename=None, vocab_size=None, @@ -78,9 +77,8 @@ def __init__( intermediate_size=3072, hidden_act="gelu", max_position_embeddings=512, - **kwargs ): - TrainableNM.__init__(self, **kwargs) + super().__init__() # Check that 
only one of pretrained_model_name, config_filename, and # vocab_size was passed in @@ -99,6 +97,7 @@ def __init__( + "BERT constructor." ) + # TK: The following code checks the same once again. if vocab_size is not None: config = BertConfig( vocab_size_or_config_json_file=vocab_size, @@ -125,8 +124,23 @@ def __init__( self.add_module("bert", model) self.config = model.config - for key, value in self.config.to_dict().items(): - self._local_parameters[key] = value + + # TK: storing config name in init_params instead. + # for key, value in self.config.to_dict().items(): + # self._local_parameters[key] = value + + # Store the only value that will be used externally - hidden_size. + self._hidden_size = hidden_size + + @property + def hidden_size(self): + """ + Property returning hidden size. + + Returns: + Hidden size. + """ + return self._hidden_size @staticmethod def list_pretrained_models() -> Optional[List[PretrainedModelInfo]]: diff --git a/nemo/collections/nlp/modules/classifiers.py b/nemo/collections/nlp/modules/classifiers.py index 4aa171e693f4..0d6259cdd31a 100644 --- a/nemo/collections/nlp/modules/classifiers.py +++ b/nemo/collections/nlp/modules/classifiers.py @@ -273,10 +273,8 @@ def output_ports(self): "slot_logits": NeuralType({0: AxisType(BatchTag), 1: AxisType(TimeTag), 2: AxisType(ChannelTag),}), } - def __init__( - self, hidden_size, num_intents, num_slots, dropout=0.0, use_transformer_pretrained=True, **kwargs, - ): - super().__init__(**kwargs) + def __init__(self, hidden_size, num_intents, num_slots, dropout=0.0, use_transformer_pretrained=True): + super().__init__() self.dropout = nn.Dropout(dropout) self.slot_mlp = MultiLayerPerceptron( hidden_size, diff --git a/nemo/collections/nlp/modules/losses.py b/nemo/collections/nlp/modules/losses.py index 2c1584ddfd98..34912f609fa4 100644 --- a/nemo/collections/nlp/modules/losses.py +++ b/nemo/collections/nlp/modules/losses.py @@ -75,8 +75,8 @@ def output_ports(self): "end_logits": NeuralType({0: AxisType(BatchTag), 1: AxisType(TimeTag)}), } - def __init__(self, **kwargs): - LossNM.__init__(self, **kwargs) + def __init__(self): + super().__init__() def _loss_function(self, **kwargs): logits = kwargs['logits'] @@ -145,8 +145,8 @@ def output_ports(self): """ return {"loss": NeuralType(None)} - def __init__(self, label_smoothing=0.0, **kwargs): - LossNM.__init__(self, **kwargs) + def __init__(self, label_smoothing=0.0): + super().__init__() self._criterion = SmoothedCrossEntropyLoss(label_smoothing) def _loss_function(self, logits, output_ids, output_mask): @@ -182,11 +182,10 @@ def output_ports(self): """ return {"loss": NeuralType(None)} - def __init__(self, *, num_inputs=2, **kwargs): + def __init__(self, num_inputs=2): + super().__init__() # Store number of inputs/losses. 
self.num_losses = num_inputs - # kwargs["create_port_args"] = {"num_losses": num_inputs} - LossNM.__init__(self, **kwargs) def _loss_function(self, **kwargs): values = [kwargs[x] for x in sorted(kwargs.keys())] @@ -244,8 +243,8 @@ def output_ports(self): """ return {"loss": NeuralType(None)} - def __init__(self, num_classes, class_weights=None, **kwargs): - LossNM.__init__(self, **kwargs) + def __init__(self, num_classes, class_weights=None): + super().__init__() if class_weights: class_weights = torch.FloatTensor(class_weights).to(self._device) @@ -330,14 +329,9 @@ def output_ports(self): return {"loss": NeuralType(None)} def __init__( - self, - num_slots, - slot_classes_loss_weights=None, - intent_classes_loss_weights=None, - intent_loss_weight=0.6, - **kwargs + self, num_slots, slot_classes_loss_weights=None, intent_classes_loss_weights=None, intent_loss_weight=0.6 ): - LossNM.__init__(self, **kwargs) + super().__init__() self.num_slots = num_slots self.intent_loss_weight = intent_loss_weight self.slot_classes_loss_weights = slot_classes_loss_weights @@ -413,15 +407,14 @@ def output_ports(self): """ return {"loss": NeuralType(None)} - def __init__(self, **kwargs): - LossNM.__init__(self, **kwargs) + def __init__(self, pad_id, label_smoothing=0, predict_last_k=0): + super().__init__() - loss_params = { - "label_smoothing": self.local_parameters.get("label_smoothing", 0), - "predict_last_k": self.local_parameters.get("predict_last_k", 0), - } + # Create the loss function object. + loss_params = {"label_smoothing": label_smoothing, "predict_last_k": predict_last_k} self._loss_fn = SmoothedCrossEntropyLoss(**loss_params) - self._pad_id = self.local_parameters['pad_id'] + # Store padding. + self._pad_id = pad_id def _loss_function(self, logits, target_ids): target_mask = mask_padded_tokens(target_ids, self._pad_id).to(logits.dtype) diff --git a/nemo/collections/nlp/modules/transformer_nm.py b/nemo/collections/nlp/modules/transformer_nm.py index 0fef4622c482..e8e9897a825b 100644 --- a/nemo/collections/nlp/modules/transformer_nm.py +++ b/nemo/collections/nlp/modules/transformer_nm.py @@ -95,9 +95,8 @@ def __init__( learn_positional_encodings=False, hidden_act='relu', mask_future=False, - **kwargs ): - TrainableNM.__init__(self, **kwargs) + super().__init__() self.embedding_layer = TransformerEmbedding( vocab_size=vocab_size, @@ -211,9 +210,8 @@ def __init__( attn_layer_dropout=0.0, learn_positional_encodings=False, hidden_act='relu', - **kwargs ): - TrainableNM.__init__(self, **kwargs) + super().__init__() self.embedding_layer = TransformerEmbedding( vocab_size=vocab_size, @@ -280,8 +278,8 @@ def output_ports(self): """ return {"output_ids": NeuralType({0: AxisType(BatchTag), 1: AxisType(TimeTag)})} - def __init__(self, decoder, log_softmax, max_seq_length, pad_token, bos_token, eos_token, batch_size=1, **kwargs): - TrainableNM.__init__(self, **kwargs) + def __init__(self, decoder, log_softmax, max_seq_length, pad_token, bos_token, eos_token, batch_size=1): + super().__init__() self.generator = GreedySequenceGenerator( decoder.embedding_layer, @@ -370,9 +368,8 @@ def __init__( beam_size=4, max_delta_length=50, length_penalty=0, - **kwargs ): - TrainableNM.__init__(self, **kwargs) + super().__init__() self.generator = BeamSearchSequenceGenerator( decoder.embedding_layer, diff --git a/nemo/collections/simple_gan/gan.py b/nemo/collections/simple_gan/gan.py index 47cf4f49121b..16e83bbdf5c5 100644 --- a/nemo/collections/simple_gan/gan.py +++ b/nemo/collections/simple_gan/gan.py @@ -48,8 +48,8 @@ 
def output_ports(self): """ return {"decision": NeuralType({0: AxisType(BatchTag), 1: AxisType(ChannelTag, 1)})} - def __init__(self, **kwargs): - super().__init__(**kwargs) + def __init__(self): + super().__init__() self.layers = torch.nn.Sequential( torch.nn.Conv2d(1, 64, 3, padding=1), torch.nn.ReLU(), @@ -123,8 +123,8 @@ def output_ports(self): ) } - def __init__(self, **kwargs): - super().__init__(**kwargs) + def __init__(self): + super().__init__() self.layers = torch.nn.Sequential( torch.nn.ConvTranspose2d(64, 128, 3, stride=2), torch.nn.ReLU(), @@ -174,8 +174,8 @@ def output_ports(self): """ return {"loss": NeuralType(None)} - def __init__(self, neg=False, **kwargs): - super().__init__(**kwargs) + def __init__(self, neg=False): + super().__init__() self.neg = neg def _loss(self, decision): @@ -233,8 +233,8 @@ def output_ports(self): """ return {"loss": NeuralType(None)} - def __init__(self, lambda_, **kwargs): - super().__init__(**kwargs) + def __init__(self, lambda_): + super().__init__() self.lambda_ = lambda_ def _loss(self, interpolated_image, interpolated_decision): @@ -328,8 +328,8 @@ def output_ports(self): ) } - def __init__(self, **kwargs): - super().__init__(**kwargs) + def __init__(self): + super().__init__() def forward(self, image1, image2): alpha = torch.rand(image1.shape[0], 1).unsqueeze(-1).unsqueeze(-1) @@ -372,8 +372,8 @@ def output_ports(self): ) } - def __init__(self, *, batch_size, **kwargs): - DataLayerNM.__init__(self, **kwargs) + def __init__(self, batch_size): + super().__init__() self._batch_size = batch_size class DummyDataset(torch.utils.data.Dataset): @@ -457,9 +457,9 @@ def output_ports(self): "label": NeuralType({0: AxisType(BatchTag)}), } - def __init__(self, *, batch_size, root, train=True, shuffle=True, **kwargs): + def __init__(self, batch_size, root, train=True, shuffle=True): + super().__init__() self._input_size = (28, 28) - DataLayerNM.__init__(self, **kwargs) self._batch_size = batch_size self._train = train @@ -467,7 +467,7 @@ def __init__(self, *, batch_size, root, train=True, shuffle=True, **kwargs): self._root = root self._transforms = transforms.Compose([transforms.ToTensor()]) - self._dataset = datasets.MNIST(root=self._root, train=self._train, download=True, transform=self._transforms,) + self._dataset = datasets.MNIST(root=self._root, train=self._train, download=True, transform=self._transforms) class DatasetWrapper(Dataset): def __init__(self, dataset): diff --git a/nemo/collections/tts/data_layers.py b/nemo/collections/tts/data_layers.py index 1a68956e6350..cad859fb10cb 100644 --- a/nemo/collections/tts/data_layers.py +++ b/nemo/collections/tts/data_layers.py @@ -64,7 +64,6 @@ def output_ports(self): def __init__( self, - *, manifest_filepath, batch_size, min_duration=0.1, @@ -74,9 +73,8 @@ def __init__( shuffle=True, num_workers=0, n_segments=0, - **kwargs ): - DataLayerNM.__init__(self, **kwargs) + super().__init__() self._dataset = AudioOnlyDataset( manifest_filepath=manifest_filepath, diff --git a/nemo/collections/tts/tacotron2_modules.py b/nemo/collections/tts/tacotron2_modules.py index 8c496aa3fe4e..0613311d3dc4 100644 --- a/nemo/collections/tts/tacotron2_modules.py +++ b/nemo/collections/tts/tacotron2_modules.py @@ -60,8 +60,8 @@ def output_ports(self): ) } - def __init__(self, n_symbols, symbols_embedding_dim: int = 512, **kwargs): - super().__init__(**kwargs) + def __init__(self, n_symbols, symbols_embedding_dim: int = 512): + super().__init__() self.embedding = nn.Embedding(n_symbols, symbols_embedding_dim) 
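For illustration, a minimal sketch (assuming the toy TaylorNet module, as used in the tests below) of how module construction looks after this constructor cleanup: factory/placement/**kwargs are gone, and the module picks up the default NeuralModuleFactory on its own.

import nemo

# Create the factory first - it becomes the default factory that new modules query.
nf = nemo.core.NeuralModuleFactory()

# Old style (no longer supported): TaylorNet(dim=4, factory=nf, placement=...)
# New style - only the module's own parameters are passed.
module = nemo.backends.pytorch.tutorials.TaylorNet(dim=4)
print(module.placement)  # placement inherited from the default factory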
self.to(self._device) @@ -123,9 +123,9 @@ def output_ports(self): } def __init__( - self, encoder_n_convolutions: int = 5, encoder_embedding_dim: int = 512, encoder_kernel_size: int = 3, **kwargs + self, encoder_n_convolutions: int = 5, encoder_embedding_dim: int = 512, encoder_kernel_size: int = 3 ): - super().__init__(**kwargs) + super().__init__() self.encoder = Encoder( encoder_n_convolutions=encoder_n_convolutions, encoder_embedding_dim=encoder_embedding_dim, @@ -254,9 +254,8 @@ def __init__( attention_location_n_filters: int = 32, attention_location_kernel_size: int = 31, prenet_p_dropout: float = 0.5, - **kwargs ): - super().__init__(**kwargs) + super().__init__() self.decoder = Decoder( n_mel_channels=n_mel_channels, n_frames_per_step=n_frames_per_step, @@ -450,9 +449,8 @@ def __init__( postnet_kernel_size: int = 5, postnet_n_convolutions: int = 5, p_dropout: float = 0.5, - **kwargs ): - super().__init__(**kwargs) + super().__init__() self.postnet = Postnet( n_mel_channels=n_mel_channels, postnet_embedding_dim=postnet_embedding_dim, @@ -547,8 +545,8 @@ def output_ports(self): """ return {"loss": NeuralType(None)} - def __init__(self, pad_value: float = -11.52, **kwargs): - super().__init__(**kwargs) + def __init__(self, pad_value: float = -11.52): + super().__init__() self.pad_value = pad_value def _loss_function(self, **kwargs): diff --git a/nemo/collections/tts/waveglow_modules.py b/nemo/collections/tts/waveglow_modules.py index 954689b0bd3d..5e13ae73faf9 100644 --- a/nemo/collections/tts/waveglow_modules.py +++ b/nemo/collections/tts/waveglow_modules.py @@ -94,9 +94,8 @@ def __init__( n_wn_layers: int = 8, n_wn_channels: int = 512, wn_kernel_size: int = 3, - **kwargs ): - super().__init__(**kwargs) + super().__init__() wavenet_config = { "n_layers": n_wn_layers, "n_channels": n_wn_channels, @@ -197,7 +196,6 @@ def __init__( n_wn_channels: int = 512, wn_kernel_size: int = 3, sigma: float = 0.6, - **kwargs ): self._sigma = sigma super().__init__( @@ -209,7 +207,6 @@ def __init__( n_wn_layers=n_wn_layers, n_wn_channels=n_wn_channels, wn_kernel_size=wn_kernel_size, - **kwargs ) self._removed_weight_norm = False @@ -287,8 +284,8 @@ def output_ports(self): """ return {"loss": NeuralType(None)} - def __init__(self, sigma: float = 1.0, **kwargs): - super().__init__(**kwargs) + def __init__(self, sigma: float = 1.0): + super().__init__() self.sigma = sigma def _loss_function(self, **kwargs): diff --git a/nemo/core/neural_factory.py b/nemo/core/neural_factory.py index 086af2a04fbf..ede5195b3909 100644 --- a/nemo/core/neural_factory.py +++ b/nemo/core/neural_factory.py @@ -418,8 +418,10 @@ def __name_import(name): return mod @deprecated(version=0.11) - def __get_pytorch_module(self, name, params, collection, pretrained): - params["factory"] = self + def __get_pytorch_module(self, name, collection, params, pretrained): + # TK: "factory" is not passed as parameter anymore. + # params["factory"] = self + if collection == "toys" or collection == "tutorials" or collection == "other": constructor = NeuralModuleFactory.__name_import("nemo.backends.pytorch.tutorials." 
+ name) elif collection == "nemo_nlp": @@ -461,7 +463,7 @@ def __get_pytorch_module(self, name, params, collection, pretrained): if num_classes is not None: pt_model.fc = nn.Linear(512, params["num_classes"]) return mw.TrainableNeuralModuleWrapper( - pt_nn_module=pt_model, input_ports_dict=input_ports, output_ports_dict=output_ports, **params, + pt_nn_module=pt_model, input_ports_dict=input_ports, output_ports_dict=output_ports, ) elif _nm_name == "resnet50": input_ports = { @@ -481,7 +483,7 @@ def __get_pytorch_module(self, name, params, collection, pretrained): if num_classes is not None: pt_model.fc = nn.Linear(2048, params["num_classes"]) return mw.TrainableNeuralModuleWrapper( - pt_nn_module=pt_model, input_ports_dict=input_ports, output_ports_dict=output_ports, **params, + pt_nn_module=pt_model, input_ports_dict=input_ports, output_ports_dict=output_ports, ) else: collection_path = "nemo.collections." + collection + "." + name @@ -489,13 +491,14 @@ def __get_pytorch_module(self, name, params, collection, pretrained): if name == "BERT" and pretrained is True: params["pretrained"] = True - if "placement" not in params: - params["placement"] = self._placement + # TK: "placement" is not passed as parameter anymore. + # if "placement" not in params: + # params["placement"] = self._placement instance = constructor(**params) return instance @deprecated(version=0.11) - def get_module(self, name, params, collection, pretrained=False): + def get_module(self, name, collection, params, pretrained=False): """ Creates NeuralModule instance @@ -511,22 +514,24 @@ def get_module(self, name, params, collection, pretrained=False): Returns: NeuralModule instance """ - if params is not None and "optimization_level" in params: - if params["optimization_level"] != self._optim_level: - nemo.logging.warning( - "Module's {0} requested optimization level {1} is" - "different from the one specified by factory - {2}." - "Using: {3} for this module".format( - name, params["optimization_level"], self._optim_level, params["optimization_level"], - ) - ) - else: - if params is None: - params = {} - params["optimization_level"] = self._optim_level + + # TK: "optimization_level" is not passed as parameter anymore. + # if params is not None and "optimization_level" in params: + # if params["optimization_level"] != self._optim_level: + # nemo.logging.warning( + # "Module's {0} requested optimization level {1} is" + # "different from the one specified by factory - {2}." 
+ # "Using: {3} for this module".format( + # name, params["optimization_level"], self._optim_level, params["optimization_level"], + # ) + # ) + # else: + # if params is None: + # params = {} + # params["optimization_level"] = self._optim_level if self._backend == Backend.PyTorch: - return self.__get_pytorch_module(name=name, params=params, collection=collection, pretrained=pretrained,) + return self.__get_pytorch_module(name=name, collection=collection, params=params, pretrained=pretrained,) else: return None @@ -743,8 +748,8 @@ def placement(self): def optim_level(self): return self._optim_level - @deprecated(version=0.11, explanation="Please use ``nemo.logging instead``") @property + @deprecated(version=0.11, explanation="Please use ``nemo.logging instead``") def logger(self): return nemo.logging diff --git a/nemo/core/neural_modules.py b/nemo/core/neural_modules.py index 94f0c7e5635b..609d3c567a30 100644 --- a/nemo/core/neural_modules.py +++ b/nemo/core/neural_modules.py @@ -7,10 +7,10 @@ from abc import ABC, abstractmethod from collections import namedtuple from enum import Enum -from inspect import getargvalues, stack +from inspect import getargvalues, getfullargspec, stack from typing import Dict, List, Optional, Set, Tuple -from .neural_factory import DeviceType, Optimization +import nemo from .neural_types import ( CanNotInferResultNeuralType, NeuralPortNameMismatchError, @@ -19,7 +19,6 @@ NeuralTypeComparisonResult, NmTensor, ) -from nemo import logging from nemo.core import NeuralModuleFactory from nemo.utils.decorators.deprecated import deprecated @@ -38,47 +37,132 @@ class WeightShareTransform(Enum): class NeuralModule(ABC): """Abstract class that every Neural Module must inherit from. - - Args: - pretrained_model_name (str): name of pretrained model to use in order - to initialize this neural module - factory (NeuralModuleFactory): :class:`NeuralModuleFactory` which - created or which should mange this instance. Required for - multi-gpu training. - placement (DeviceType): (default:None) where this module should - be placed. If provided, this parameter takes precedence over - whatever is specified in factory. """ - def __init__( - self, *, pretrained_model_name=None, factory=None, placement=None, **kwargs, - ): - self._pretrained_model_name = pretrained_model_name - self._local_parameters = self.update_local_params() + def __init__(self): - default_factory = NeuralModuleFactory.get_default_factory() - if (factory is None) and (default_factory is not None): - factory = default_factory + # Get default factory. + self._factory = NeuralModuleFactory.get_default_factory() # Set module properties from factory else use defaults - self._placement = factory.placement if factory is not None else DeviceType.GPU - self._opt_level = factory.optim_level if factory is not None else Optimization.mxprO0 + self._placement = self._factory.placement + # If one needs to change that should override it manually. - # Update module properties using overrides if overrides exist - if placement is not None: - self._placement = placement + # Optimization level. + self._opt_level = self._factory.optim_level - self._factory = factory + # Get object UUID. self._uuid = str(uuid.uuid4()) - # if kwargs: - # logging.warning( - # "When constructing {}. The base " - # "NeuralModule class received the following unused " - # "arguments:".format(self.__class__.__name__)) - # logging.warning("{}".format(kwargs.keys())) + # Retrieve dictionary of parameters (keys, values) passed to init. 
+ self._init_params = self.__extract_init_params() + + # Print the types of the values. + # for key, value in self._init_params.items(): + # print("{}: {} ({})".format(key, value, type(value))) + + # Validate the parameters. + # self._validate_params(self._init_params) + + @property + def init_params(self) -> Optional[Dict]: + """ + Property returning parameters used to instantiate the module. + + Returns: + Dictionary containing parameters used to instantiate the module. + """ + return self._init_params + + def __extract_init_params(self): + """ + Retrieves the dictionary of parameters (keys, values) passed to the constructor of a class derived + (also indirectly) from the Neural Module class. + + Returns: + Dictionary containing parameters passed to init(). + """ + # Get names of arguments of the original module init method. + init_keys = getfullargspec(type(self).__init__).args + + # Remove self. + if "self" in init_keys: + init_keys.remove("self") - @deprecated() + # Create list of params. + init_params = {}.fromkeys(init_keys) + + # Retrieve values of those params from the call list. + for frame in stack()[1:]: + localvars = getargvalues(frame[0]).locals + # print("localvars: ", localvars) + for key in init_keys: + # Found the variable! + if key in localvars.keys(): + # Save the value. + init_params[key] = localvars[key] + + # Return parameters. + return init_params + + # TODO: if part of the public API, this should not start with _; if hidden, it should start with __ + def _validate_params(self, params): + """ + Checks whether the dictionary contains only parameters of primitive types (string, int, float, bool) + or (possibly nested) lists/dicts of such primitive types. + + Args: + params: dictionary of parameters. + + Returns: + True if all parameters were ok, False otherwise. + """ + ok = True + + # Iterate over parameters and check them one by one. + for key, variable in params.items(): + if not self.__is_of_allowed_type(variable): + nemo.logging.warning( + "{} contains variable {} of type {}, which is not allowed.".format( + key, variable, type(variable) + ) + ) + ok = False + + # Return the result. + return ok + + def __is_of_allowed_type(self, var): + """ + A recursive function that checks whether a given variable is of an allowed type. + + Args: + var: variable to be checked. + + Returns: + True if the variable (and all of its nested elements) is of an allowed type, False otherwise. + """ + var_type = type(var) + + # If this is list - check its elements. + if var_type == list: + for list_var in var: + if not self.__is_of_allowed_type(list_var): + return False + + # If this is dict - check its elements. + elif var_type == dict: + for _, dict_var in var.items(): + if not self.__is_of_allowed_type(dict_var): + return False + + elif var_type not in (str, int, float, bool): + return False + + # Everything seems to be ok.
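For illustration, a minimal sketch of what the init_params capture above records for a derived module; the module and argument values are chosen arbitrarily, and the exact dictionary depends on the constructor's defaults.

import nemo

nf = nemo.core.NeuralModuleFactory()  # a default factory must exist before modules are built
dl = nemo.backends.pytorch.tutorials.RealFunctionDataLayer(batch_size=128, f_name="cos", n=10000)

# init_params holds the constructor arguments, including bound defaults, roughly:
# {'batch_size': 128, 'f_name': 'cos', 'n': 10000, 'x_lo': -4, 'x_hi': 4}
print(dl.init_params)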
+ return True + + @deprecated(version=0.11) @staticmethod def create_ports(**kwargs): """ Deprecated method, to be remoted in the next release.""" @@ -362,13 +446,15 @@ def placement(self): return self._placement @property + @deprecated(version=0.11) def local_parameters(self) -> Optional[Dict]: """Get module's parameters Returns: module's parameters """ - return self._local_parameters + return self._init_params + # return self._local_parameters @property def unique_instance_id(self): @@ -392,33 +478,3 @@ def num_weights(self): """Number of module's weights """ pass - - @staticmethod - def update_local_params(): - """ - Loops through the call chain of class initializations and stops at the - first class that is not an instance of Neural Module. At each step of - the loop, the class contructor arguments are added to a dictionary - containing the local parameters used to construct the Neural Module - - Returns: - A dictionary containing all parameters passed to the module's init - chain. - """ - local_parameters = {} - for frame in stack()[1:]: - posname, kwname, localvars = getargvalues(frame[0])[-3:] - # Check if caller is a Neural Module - if "self" in localvars and isinstance(localvars["self"], NeuralModule): - if posname is not None: - raise ValueError("NeuralModules cannot accept `*` " "positional arguments.") - # Get func arg dict - localvars.update(localvars.pop(kwname, [])) - del localvars["self"] - local_parameters.update(localvars) - # Else we have rearched the end of the init callchain and we are - # done - else: - break - - return local_parameters diff --git a/nemo/utils/decorators/deprecated.py b/nemo/utils/decorators/deprecated.py index e55d3733ffa7..862d99ac7ba1 100644 --- a/nemo/utils/decorators/deprecated.py +++ b/nemo/utils/decorators/deprecated.py @@ -16,61 +16,56 @@ 'deprecated', ] +import functools + +import wrapt + import nemo +# Remember which deprecation warnings have been printed already. +_PRINTED_WARNING = {} + -class deprecated(object): +def deprecated(wrapped=None, version=None, explanation=None): """ Decorator class used for indicating that a function is deprecated and going to be removed. Tracks down which functions printed the warning and will print it only once per function. """ - # Static variable - list of names of functions that we already printed - # the warning for. - warned_functions = [] + if wrapped is None: + return functools.partial(deprecated, version=version, explanation=explanation) - def __init__(self, version=None, explanation=None): + @wrapt.decorator + def wrapper(wrapped, instance, args, kwargs): """ - Constructor. Stores version and explanation into local variables. + Method prints the adequate warning (only once per function) when + required and calls the function func, passing the original arguments, + i.e. version and explanation. Args: version: Version in which the function will be removed (optional) explanation: Additional explanation (optional), e.g. use method ``blabla instead``. - - """ - self.version = version - self.explanation = explanation - - def __call__(self, func): - """ - Method prints the adequate warning (only once per function) when - required and calls the function func, passing the original arguments. """ - def wrapper(*args, **kwargs): - """ - Function prints the adequate warning and calls the function func, - passing the original arguments. - """ - # Check if we already warned about that function. - if func.__name__ not in deprecated.warned_functions: - # Add to list so we won't print it again. 
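For illustration, a hedged usage sketch of the wrapt-based decorator being rewritten here; the function name and explanation text are invented for the example.

from nemo.utils.decorators.deprecated import deprecated

@deprecated(version=0.11, explanation="Please use ``new_helper`` instead.")
def old_helper():
    return 42

old_helper()  # logs: Function ``old_helper`` is deprecated. It is going to be removed in the 0.11 version. ...
old_helper()  # the warning is printed only once per function

# The bare form (``@deprecated`` without arguments) is also supported.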
- deprecated.warned_functions.append(func.__name__) + # Check if we already warned about that function. + if wrapped.__name__ not in _PRINTED_WARNING.keys(): + # Add to list so we won't print it again. + _PRINTED_WARNING[wrapped.__name__] = True - # Prepare the warning message. - msg = "Function ``{}`` is deprecated.".format(func.__name__) + # Prepare the warning message. + msg = "Function ``{}`` is deprecated.".format(wrapped.__name__) - # Optionally, add version and alternative. - if self.version is not None: - msg = msg + " It is going to be removed in " - msg = msg + "the {} version.".format(self.version) + # Optionally, add version and alternative. + if version is not None: + msg = msg + " It is going to be removed in " + msg = msg + "the {} version.".format(version) - if self.explanation is not None: - msg = msg + " " + self.explanation + if explanation is not None: + msg = msg + " " + explanation - # Display the deprecated warning. - nemo.logging.warning(msg) + # Display the deprecated warning. + nemo.logging.warning(msg) - # Call the function. - return func(*args, **kwargs) + # Call the function. + return wrapped(*args, **kwargs) - return wrapper + return wrapper(wrapped) diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 77f09161de68..9528b238b88d 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -3,4 +3,5 @@ pandas tensorboardX torch torchvision -wget \ No newline at end of file +wget +wrapt diff --git a/requirements/requirements_test.txt b/requirements/requirements_test.txt index 9017d9f7536b..192a3a3fddc7 100644 --- a/requirements/requirements_test.txt +++ b/requirements/requirements_test.txt @@ -1,4 +1,5 @@ parameterized pytest black -isort[requirements] \ No newline at end of file +isort[requirements] +wrapt diff --git a/tests/asr/test_asr.py b/tests/asr/test_asr.py index f3aeb12ee81c..b77b5cd582b5 100644 --- a/tests/asr/test_asr.py +++ b/tests/asr/test_asr.py @@ -18,6 +18,7 @@ import os import shutil import tarfile +import unittest from ruamel.yaml import YAML @@ -28,6 +29,8 @@ from tests.common_setup import NeMoUnitTest logging = nemo.logging + + freq = 16000 @@ -180,11 +183,11 @@ def test_pytorch_audio_dataset(self): def test_dataloader(self): batch_size = 4 dl = nemo_asr.AudioToTextDataLayer( - featurizer_config=self.featurizer_config, + # featurizer_config=self.featurizer_config, manifest_filepath=self.manifest_filepath, labels=self.labels, batch_size=batch_size, - placement=DeviceType.GPU, + # placement=DeviceType.GPU, drop_last=True, ) for ind, data in enumerate(dl.data_iterator): @@ -254,21 +257,21 @@ def test_kaldi_dataloader(self): def test_trim_silence(self): batch_size = 4 normal_dl = nemo_asr.AudioToTextDataLayer( - featurizer_config=self.featurizer_config, + # featurizer_config=self.featurizer_config, manifest_filepath=self.manifest_filepath, labels=self.labels, batch_size=batch_size, - placement=DeviceType.GPU, + # placement=DeviceType.GPU, drop_last=True, shuffle=False, ) trimmed_dl = nemo_asr.AudioToTextDataLayer( - featurizer_config=self.featurizer_config, + # featurizer_config=self.featurizer_config, manifest_filepath=self.manifest_filepath, trim_silence=True, labels=self.labels, batch_size=batch_size, - placement=DeviceType.GPU, + # placement=DeviceType.GPU, drop_last=True, shuffle=False, ) @@ -279,11 +282,11 @@ def test_trim_silence(self): def test_audio_preprocessors(self): batch_size = 5 dl = nemo_asr.AudioToTextDataLayer( - featurizer_config=self.featurizer_config, + # 
featurizer_config=self.featurizer_config, manifest_filepath=self.manifest_filepath, labels=self.labels, batch_size=batch_size, - placement=DeviceType.GPU, + # placement=DeviceType.GPU, drop_last=True, shuffle=False, ) @@ -322,17 +325,17 @@ def test_audio_preprocessors(self): self.assertTrue(spec[0].shape[1] == 201) # n_fft // 2 + 1 bins self.assertTrue(mfcc[0].shape[1] == 15) + # @unittest.skip("Init parameters of nemo_asr.AudioToMelSpectrogramPreprocessor are invalid") def test_jasper_training(self): with open(os.path.abspath(os.path.join(os.path.dirname(__file__), "../data/jasper_smaller.yaml"))) as file: jasper_model_definition = self.yaml.load(file) dl = nemo_asr.AudioToTextDataLayer( - featurizer_config=self.featurizer_config, + # featurizer_config=self.featurizer_config, manifest_filepath=self.manifest_filepath, labels=self.labels, batch_size=4, ) pre_process_params = { - 'int_values': False, 'frame_splicing': 1, 'features': 64, 'window_size': 0.02, @@ -366,25 +369,22 @@ def test_jasper_training(self): tensors=[loss], print_func=lambda x: logging.info(f'Train Loss: {str(x[0].item())}'), ) # Instantiate an optimizer to perform `train` action - neural_factory = nemo.core.NeuralModuleFactory( - backend=nemo.core.Backend.PyTorch, local_rank=None, create_tb_writer=False, - ) - optimizer = neural_factory.get_trainer() + optimizer = self.nf.get_trainer() optimizer.train( [loss], callbacks=[callback], optimizer="sgd", optimization_params={"num_epochs": 10, "lr": 0.0003}, ) + # @unittest.skip("Init parameters of nemo_asr.AudioToMelSpectrogramPreprocessor are invalid") def test_double_jasper_training(self): with open(os.path.abspath(os.path.join(os.path.dirname(__file__), "../data/jasper_smaller.yaml"))) as file: jasper_model_definition = self.yaml.load(file) dl = nemo_asr.AudioToTextDataLayer( - featurizer_config=self.featurizer_config, + # featurizer_config=self.featurizer_config, manifest_filepath=self.manifest_filepath, labels=self.labels, batch_size=4, ) pre_process_params = { - 'int_values': False, 'frame_splicing': 1, 'features': 64, 'window_size': 0.02, @@ -429,32 +429,23 @@ def test_double_jasper_training(self): tensors=[loss], print_func=lambda x: logging.info(str(x[0].item())) ) # Instantiate an optimizer to perform `train` action - neural_factory = nemo.core.NeuralModuleFactory( - backend=nemo.core.Backend.PyTorch, local_rank=None, create_tb_writer=False, - ) - optimizer = neural_factory.get_trainer() + optimizer = self.nf.get_trainer() optimizer.train( [loss], callbacks=[callback], optimizer="sgd", optimization_params={"num_epochs": 10, "lr": 0.0003}, ) + # @unittest.skip("Init parameters of nemo_asr.AudioToMelSpectrogramPreprocessor are invalid") def test_quartznet_training(self): with open(os.path.abspath(os.path.join(os.path.dirname(__file__), "../data/quartznet_test.yaml"))) as f: quartz_model_definition = self.yaml.load(f) - dl = nemo_asr.AudioToTextDataLayer( - featurizer_config=self.featurizer_config, - manifest_filepath=self.manifest_filepath, - labels=self.labels, - batch_size=4, - ) + dl = nemo_asr.AudioToTextDataLayer(manifest_filepath=self.manifest_filepath, labels=self.labels, batch_size=4,) pre_process_params = { - 'int_values': False, 'frame_splicing': 1, 'features': 64, 'window_size': 0.02, 'n_fft': 512, 'dither': 1e-05, 'window': 'hann', - 'feat_type': 'logfbank', 'sample_rate': 16000, 'normalize': 'per_feature', 'window_stride': 0.01, @@ -481,10 +472,7 @@ def test_quartznet_training(self): tensors=[loss], print_func=lambda x: logging.info(f'Train Loss: 
{str(x[0].item())}'), ) # Instantiate an optimizer to perform `train` action - neural_factory = nemo.core.NeuralModuleFactory( - backend=nemo.core.Backend.PyTorch, local_rank=None, create_tb_writer=False, - ) - optimizer = neural_factory.get_trainer() + optimizer = self.nf.get_trainer() optimizer.train( [loss], callbacks=[callback], optimizer="sgd", optimization_params={"num_epochs": 10, "lr": 0.0003}, ) @@ -492,14 +480,8 @@ def test_quartznet_training(self): def test_stft_conv(self): with open(os.path.abspath(os.path.join(os.path.dirname(__file__), "../data/jasper_smaller.yaml"))) as file: jasper_model_definition = self.yaml.load(file) - dl = nemo_asr.AudioToTextDataLayer( - featurizer_config=self.featurizer_config, - manifest_filepath=self.manifest_filepath, - labels=self.labels, - batch_size=4, - ) + dl = nemo_asr.AudioToTextDataLayer(manifest_filepath=self.manifest_filepath, labels=self.labels, batch_size=4,) pre_process_params = { - 'int_values': False, 'frame_splicing': 1, 'features': 64, 'window_size': 0.02, @@ -535,10 +517,7 @@ def test_stft_conv(self): tensors=[loss], print_func=lambda x: logging.info(str(x[0].item())) ) # Instantiate an optimizer to perform `train` action - neural_factory = nemo.core.NeuralModuleFactory( - backend=nemo.core.Backend.PyTorch, local_rank=None, create_tb_writer=False, - ) - optimizer = neural_factory.get_trainer() + optimizer = self.nf.get_trainer() optimizer.train( [loss], callbacks=[callback], optimizer="sgd", optimization_params={"num_epochs": 10, "lr": 0.0003}, ) @@ -546,14 +525,8 @@ def test_stft_conv(self): def test_clas(self): with open('examples/asr/experimental/configs/garnet_an4.yaml') as file: cfg = self.yaml.load(file) - dl = nemo_asr.AudioToTextDataLayer( - featurizer_config=self.featurizer_config, - manifest_filepath=self.manifest_filepath, - labels=self.labels, - batch_size=4, - ) + dl = nemo_asr.AudioToTextDataLayer(manifest_filepath=self.manifest_filepath, labels=self.labels, batch_size=4,) pre_process_params = { - 'int_values': False, 'frame_splicing': 1, 'features': 64, 'window_size': 0.02, @@ -575,7 +548,19 @@ def test_clas(self): in_channels=cfg['encoder']['jasper'][-1]['filters'], out_channels=cfg['decoder']['hidden_size'], ) decoder = nemo.backends.pytorch.common.DecoderRNN( - voc_size=len(self.labels), bos_id=0, **cfg['decoder'] # fictive + voc_size=len(self.labels), + bos_id=0, + hidden_size=cfg['decoder']['hidden_size'], + attention_method=cfg['decoder']['attention_method'], + attention_type=cfg['decoder']['attention_type'], + in_dropout=cfg['decoder']['in_dropout'], + gru_dropout=cfg['decoder']['gru_dropout'], + attn_dropout=cfg['decoder']['attn_dropout'], + teacher_forcing=cfg['decoder']['teacher_forcing'], + curriculum_learning=cfg['decoder']['curriculum_learning'], + rnn_type=cfg['decoder']['rnn_type'], + n_layers=cfg['decoder']['n_layers'], + tie_emb_out_weights=cfg['decoder']['tie_emb_out_weights'], ) loss = nemo.backends.pytorch.common.SequenceLoss() @@ -592,10 +577,7 @@ def test_clas(self): tensors=[loss], print_func=lambda x: logging.info(str(x[0].item())) ) # Instantiate an optimizer to perform `train` action - neural_factory = nemo.core.NeuralModuleFactory( - backend=nemo.core.Backend.PyTorch, local_rank=None, create_tb_writer=False, - ) - optimizer = neural_factory.get_trainer() + optimizer = self.nf.get_trainer() optimizer.train( [loss], callbacks=[callback], optimizer="sgd", optimization_params={"num_epochs": 10, "lr": 0.0003}, ) @@ -603,14 +585,8 @@ def test_clas(self): def test_jasper_eval(self): with 
open(os.path.abspath(os.path.join(os.path.dirname(__file__), "../data/jasper_smaller.yaml"))) as file: jasper_model_definition = self.yaml.load(file) - dl = nemo_asr.AudioToTextDataLayer( - featurizer_config=self.featurizer_config, - manifest_filepath=self.manifest_filepath, - labels=self.labels, - batch_size=4, - ) + dl = nemo_asr.AudioToTextDataLayer(manifest_filepath=self.manifest_filepath, labels=self.labels, batch_size=4,) pre_process_params = { - 'int_values': False, 'frame_splicing': 1, 'features': 64, 'window_size': 0.02, @@ -652,11 +628,4 @@ def test_jasper_eval(self): user_epochs_done_callback=process_evaluation_epoch, ) # Instantiate an optimizer to perform `train` action - neural_factory = nemo.core.NeuralModuleFactory( - backend=nemo.core.Backend.PyTorch, local_rank=None, create_tb_writer=False, - ) - neural_factory.eval(callbacks=[eval_callback]) - - -# if __name__ == '__main__': -# unittest.main() + self.nf.eval(callbacks=[eval_callback]) diff --git a/tests/asr/test_weight_share.py b/tests/asr/test_weight_share.py index 4bc1d6125da9..e4e0ce8247f4 100644 --- a/tests/asr/test_weight_share.py +++ b/tests/asr/test_weight_share.py @@ -19,6 +19,7 @@ import os import shutil import tarfile +import unittest from typing import Dict import numpy as np @@ -172,13 +173,13 @@ def test_freeze_unfreeze_TrainableNM(self): with open(path) as file: jasper_model_definition = self.yaml.load(file) dl = nemo_asr.AudioToTextDataLayer( - featurizer_config=self.featurizer_config, + # featurizer_config=self.featurizer_config, manifest_filepath=self.manifest_filepath, labels=self.labels, batch_size=4, ) pre_process_params = { - 'int_values': False, + #'int_values': False, 'frame_splicing': 1, 'features': 64, 'window_size': 0.02, @@ -213,20 +214,16 @@ def test_freeze_unfreeze_TrainableNM(self): callback = nemo.core.SimpleLossLoggerCallback( tensors=[loss], print_func=lambda x: logging.info(f'Train Loss: {str(x[0].item())}'), ) - # Instantiate an optimizer to perform `train` action - neural_factory = nemo.core.NeuralModuleFactory( - backend=nemo.core.Backend.PyTorch, local_rank=None, create_tb_writer=False, - ) - optimizer = neural_factory.get_trainer() + optimizer = self.nf.get_trainer() optimizer.train( [loss], callbacks=[callback], optimizer="sgd", optimization_params={"num_epochs": 2, "lr": 0.0003}, ) + # @unittest.skip( + # "Tests fails at get_pytorch_module() that will be changed in next PR anyway. \ + # Besides, quite sure this test is not related with ASR :]" + # ) def test_freeze_unfreeze_Wrapper(self): - neural_factory = nemo.core.NeuralModuleFactory( - backend=nemo.core.Backend.PyTorch, placement=nemo.core.DeviceType.GPU, create_tb_writer=False, - ) - dl_train = nemo.backends.pytorch.ZerosDataLayer( size=40, dtype=[torch.FloatTensor, torch.LongTensor], @@ -244,12 +241,14 @@ def test_freeze_unfreeze_Wrapper(self): }, ) + # WHY THE HELL THIS TEST IS IN ASR!!!!??? 
+ # NOTICE: pretrain=True argument - resnet = neural_factory.get_module( + resnet = self.nf.get_module( name="resnet18", params={"num_classes": 2}, collection="torchvision", pretrained=True, ) - L_train = neural_factory.get_module(name="CrossEntropyLoss", collection="toys", params={}) + L_train = self.nf.get_module(name="CrossEntropyLoss", collection="toys", params={}) # NOTICE: Freeze all Neural Module's weights resnet.freeze() @@ -264,10 +263,9 @@ def test_freeze_unfreeze_Wrapper(self): tensors=[train_loss], print_func=lambda x: logging.info(f'Train Loss: {str(x[0].item())}'), ) # Instantiate an optimizer to perform `train` action - neural_factory = nemo.core.NeuralModuleFactory( - backend=nemo.core.Backend.PyTorch, local_rank=None, create_tb_writer=False, - ) - optimizer = neural_factory.get_trainer() + optimizer = self.nf.get_trainer() optimizer.train( [train_loss], callbacks=[callback], optimizer="sgd", optimization_params={"num_epochs": 2, "lr": 0.0003}, ) + + # WHERE IS ACTUALLY THE TEST?? ARE WE CHECKING ANYTHING?? diff --git a/tests/asr/test_zeroDS.py b/tests/asr/test_zeroDS.py index 6197d74cc014..3b6b15dba4a6 100644 --- a/tests/asr/test_zeroDS.py +++ b/tests/asr/test_zeroDS.py @@ -88,9 +88,6 @@ def tearDownClass(cls) -> None: def test_simple_train(self): logging.info("Simplest train test with ZeroDL") - neural_factory = nemo.core.neural_factory.NeuralModuleFactory( - backend=nemo.core.Backend.PyTorch, create_tb_writer=False - ) trainable_module = nemo.backends.pytorch.tutorials.TaylorNet(dim=4) data_source = nemo.backends.pytorch.common.ZerosDataLayer( size=10000, @@ -109,7 +106,7 @@ def test_simple_train(self): callback = nemo.core.SimpleLossLoggerCallback( tensors=[loss_tensor], print_func=lambda x: logging.info(f'Train Loss: {str(x[0].item())}'), ) - neural_factory.train( + self.nf.train( [loss_tensor], callbacks=[callback], optimization_params={"num_epochs": 3, "lr": 0.0003}, optimizer="sgd", ) @@ -157,9 +154,6 @@ def test_asr_with_zero_ds(self): tensors=[loss], print_func=lambda x: logging.info(f'Train Loss: {str(x[0].item())}'), ) # Instantiate an optimizer to perform `train` action - neural_factory = nemo.core.NeuralModuleFactory( - backend=nemo.core.Backend.PyTorch, local_rank=None, create_tb_writer=False, - ) - neural_factory.train( + self.nf.train( [loss], callbacks=[callback], optimization_params={"num_epochs": 2, "lr": 0.0003}, optimizer="sgd", ) diff --git a/tests/common_setup.py b/tests/common_setup.py index 6d9a1e30f234..eb31d7c38625 100644 --- a/tests/common_setup.py +++ b/tests/common_setup.py @@ -25,7 +25,11 @@ class NeMoUnitTest(unittest.TestCase): def setUp(self) -> None: - nemo.core.neural_factory.NeuralModuleFactory.reset_default_factory() - logging.info("---------------------------------------------------------") - logging.info(self._testMethodName) - logging.info("---------------------------------------------------------") + """ Default setup - instantiates Neural Factory. """ + # Initialize the default Neural Factory - on GPU. + self.nf = nemo.core.NeuralModuleFactory(placement=nemo.core.DeviceType.GPU) + # Reset loggers. + self.nf._exp_manager.reset_loggers() + + # Print standard header. 
+        logging.info("-" * 20 + " " + self._testMethodName + " " + "-" * 20)
diff --git a/tests/nlp/test_squad.py b/tests/nlp/test_squad.py
index 6da31b90ac7f..46a04c301dd4 100644
--- a/tests/nlp/test_squad.py
+++ b/tests/nlp/test_squad.py
@@ -92,7 +92,7 @@ def test_squad_v1(self):
             backend=nemo.core.Backend.PyTorch, local_rank=None, create_tb_writer=False,
         )
         model = nemo_nlp.huggingface.BERT(pretrained_model_name=pretrained_bert_model)
-        hidden_size = model.local_parameters["hidden_size"]
+        hidden_size = model.hidden_size
         qa_head = nemo_nlp.TokenClassifier(hidden_size=hidden_size, num_classes=2, num_layers=1, log_softmax=False,)
         squad_loss = nemo_nlp.QuestionAnsweringLoss()
@@ -199,7 +199,7 @@ def test_squad_v2(self):
             backend=nemo.core.Backend.PyTorch, local_rank=None, create_tb_writer=False,
         )
         model = nemo_nlp.huggingface.BERT(pretrained_model_name=pretrained_bert_model)
-        hidden_size = model.local_parameters["hidden_size"]
+        hidden_size = model.hidden_size
         qa_head = nemo_nlp.TokenClassifier(hidden_size=hidden_size, num_classes=2, num_layers=1, log_softmax=False,)
         squad_loss = nemo_nlp.QuestionAnsweringLoss()
diff --git a/tests/test_deploy_export.py b/tests/test_deploy_export.py
index a2194807512a..8d51161d92e9 100644
--- a/tests/test_deploy_export.py
+++ b/tests/test_deploy_export.py
@@ -29,8 +29,12 @@
 class TestDeployExport(NeMoUnitTest):
-    def setUp(self) -> None:
-        self.nf = nemo.core.NeuralModuleFactory(placement=nemo.core.DeviceType.GPU)
+    def setUp(self):
+        """ Sets up the neural factory so it uses the GPU instead of the CPU. """
+        NeMoUnitTest.setUp(self)
+
+        # Perform computations on GPU.
+        self.nf._placement = nemo.core.DeviceType.GPU

     def __test_export_route(self, module, out_name, mode, input_example=None):
         out = Path(out_name)
@@ -46,7 +50,7 @@ def __test_export_route(self, module, out_name, mode, input_example=None):
         os.remove(out)

     def test_simple_module_export(self):
-        simplest_module = nemo.backends.pytorch.tutorials.TaylorNet(dim=4, factory=self.nf)
+        simplest_module = nemo.backends.pytorch.tutorials.TaylorNet(dim=4)
         self.__test_export_route(
             module=simplest_module,
             out_name="simple.pt",
diff --git a/tests/test_deprecated.py b/tests/test_deprecated.py
index 7dabb2f40d89..9f41fa21340a 100644
--- a/tests/test_deprecated.py
+++ b/tests/test_deprecated.py
@@ -1,16 +1,21 @@
-# Copyright (C) NVIDIA. All Rights Reserved.
+# ! /usr/bin/python
+# -*- coding: utf-8 -*-
+
+# =============================================================================
+# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
-# http://www.apache.org/licenses/LICENSE-2.0
+# http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
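For reference, the pattern the test changes above converge on: the NeMoUnitTest base class now creates the factory once in setUp(), and individual tests use self.nf instead of instantiating their own NeuralModuleFactory. A minimal sketch of a test written against that pattern follows; it is illustrative only (the class and test names are hypothetical) and is not part of this change.

import nemo
from tests.common_setup import NeMoUnitTest


class MyFactoryBasedTest(NeMoUnitTest):
    def setUp(self) -> None:
        # The base class creates self.nf (the default NeuralModuleFactory).
        super().setUp()
        # Optionally adjust placement, mirroring test_deploy_export.py.
        self.nf._placement = nemo.core.DeviceType.GPU

    def test_get_module(self):
        # Modules come from the shared factory rather than a locally created one.
        instance = self.nf.get_module(name="TaylorNet", collection="toys", params={"dim": 4})
        self.assertIsNotNone(instance)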
+# =============================================================================

 from io import StringIO

@@ -20,12 +25,12 @@
 from tests.common_setup import NeMoUnitTest


-class DeprecatedTestCase(NeMoUnitTest):
+class DeprecatedTest(NeMoUnitTest):
     def test_say_whee_deprecated(self):
         """ Tests whether both std and err streams return the right values when function is deprecated."""

-        @deprecated()
+        @deprecated
         def say_whee():
             print("Whee!")

@@ -44,7 +49,7 @@ def test_say_wow_twice_deprecated(self):
         """ Tests whether both std and err streams return the right values when a deprecated is called twice."""

-        @deprecated()
+        @deprecated
         def say_wow():
             print("Woooow!")
diff --git a/tests/test_infer.py b/tests/test_infer.py
index 8e83ca1a0f2b..c6faeb8cdcec 100644
--- a/tests/test_infer.py
+++ b/tests/test_infer.py
@@ -25,8 +25,8 @@
 class AddsTen(NonTrainableNM):
-    def __init__(self, **kwargs):
-        super().__init__(**kwargs)
+    def __init__(self):
+        super().__init__()

     @property
     def input_ports(self):
@@ -41,8 +41,8 @@ def forward(self, mod_in):
 class SubtractsTen(NonTrainableNM):
-    def __init__(self, **kwargs):
-        super().__init__(**kwargs)
+    def __init__(self):
+        super().__init__()

     @property
     def input_ports(self):
diff --git a/tests/test_neural_factory.py b/tests/test_neural_factory.py
index 83db0f16e4c8..d9d0aa0baeb1 100644
--- a/tests/test_neural_factory.py
+++ b/tests/test_neural_factory.py
@@ -21,28 +21,17 @@
 class TestNeuralFactory(NeMoUnitTest):
-    def test_creation(self):
-        neural_factory = nemo.core.NeuralModuleFactory(
-            backend=nemo.core.Backend.PyTorch, local_rank=None, create_tb_writer=False,
-        )
-        instance = neural_factory.get_module(name="TaylorNet", collection="toys", params={"dim": 4})
+    def test_create_single_module(self):
+        instance = self.nf.get_module(name="TaylorNet", collection="toys", params={"dim": 4})
         self.assertTrue(isinstance(instance, nemo.backends.pytorch.tutorials.TaylorNet))

-    def test_simple_example(self):
-        neural_factory = nemo.core.neural_factory.NeuralModuleFactory(
-            backend=nemo.core.Backend.PyTorch, local_rank=None, create_tb_writer=False,
-        )
-        dl = neural_factory.get_module(
+    def test_create_simple_graph(self):
+        dl = self.nf.get_module(
             name="RealFunctionDataLayer", collection="toys", params={"n": 10000, "batch_size": 128},
         )
-        fx = neural_factory.get_module(name="TaylorNet", collection="toys", params={"dim": 4})
-        loss = neural_factory.get_module(name="MSELoss", collection="toys", params={})
+        fx = self.nf.get_module(name="TaylorNet", collection="toys", params={"dim": 4})
+        loss = self.nf.get_module(name="MSELoss", collection="toys", params={})

         x, y = dl()
         y_pred = fx(x=x)
-        loss_tensor = loss(predictions=y_pred, target=y)
-
-        optimizer = neural_factory.get_trainer()
-        optimizer.train(
-            [loss_tensor], optimizer="sgd", optimization_params={"lr": 1e-3, "num_epochs": 1},
-        )
+        _ = loss(predictions=y_pred, target=y)
diff --git a/tests/test_neural_modules_initialization.py b/tests/test_neural_modules_initialization.py
new file mode 100644
index 000000000000..e6d5c29a4827
--- /dev/null
+++ b/tests/test_neural_modules_initialization.py
@@ -0,0 +1,83 @@
+# ! /usr/bin/python
+# -*- coding: utf-8 -*-
+
+# =============================================================================
+# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
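The updated tests in test_deprecated.py now apply the decorator bare (@deprecated instead of @deprecated()). A minimal sketch of a decorator that supports this usage is shown below; it is illustrative only, assumes a plain functools-based implementation, and NeMo's actual implementation may differ.

import functools
import warnings


def deprecated(func):
    """Wrap func and emit a DeprecationWarning on every call."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        warnings.warn("Function {} is deprecated.".format(func.__name__), DeprecationWarning)
        return func(*args, **kwargs)

    return wrapper


@deprecated
def say_whee():
    print("Whee!")


say_whee()  # Prints "Whee!" and emits a DeprecationWarning.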
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+
+
+import nemo
+from tests.common_setup import NeMoUnitTest
+
+
+class MockupModule(nemo.core.NeuralModule):
+    """
+    Mockup component class.
+    """
+
+    def __init__(self):
+        nemo.core.NeuralModule.__init__(self)
+
+
+class NeuralModuleConfigTest(NeMoUnitTest):
+    """
+    Class testing methods related to validation of Neural Module configuration (initialization) parameters.
+    """
+
+    def setUp(self) -> None:
+        super().setUp()
+
+        # Mockup abstract methods.
+        MockupModule.__abstractmethods__ = set()
+
+        # Create object.
+        self.module = MockupModule()
+
+    def test_build_in_types(self):
+        """ Tests whether built-in types are handled."""
+
+        params = {"int": 123, "float": 12.4, "string": "ala ma kota", "bool": True}
+
+        # Check error output.
+        self.assertEqual(self.module._validate_params(params), True)
+
+    def test_nested_dict(self):
+        """ Tests whether (nested) dicts are handled."""
+
+        params = {
+            "dict_outer": {
+                "dict_inner_1": {"int": 123, "float": 12.4, "string": "ala ma kota", "bool": True},
+                "dict_inner_2": {"int": 123, "float": 12.4, "string": "ala ma kota", "bool": True},
+            }
+        }
+
+        # Check error output.
+        self.assertEqual(self.module._validate_params(params), True)
+
+    def test_nested_list(self):
+        """ Tests whether (nested) lists are handled."""
+
+        params = {"list_outer": [[1, 2, 3, 4]]}
+
+        # Check error output.
+        self.assertEqual(self.module._validate_params(params), True)
+
+    def test_nested_mix(self):
+        """ Tests whether mixed nesting of lists and dicts is handled."""
+
+        params = {"list_outer": [{"int": 123, "float": 12.4, "string": "ala ma kota", "bool": True}]}
+
+        # Check error output.
+        self.assertEqual(self.module._validate_params(params), True)
diff --git a/tests/test_neural_modules.py b/tests/test_neural_modules_pytorch.py
similarity index 59%
rename from tests/test_neural_modules.py
rename to tests/test_neural_modules_pytorch.py
index da5a7b48ad27..13ff0226262b 100644
--- a/tests/test_neural_modules.py
+++ b/tests/test_neural_modules_pytorch.py
@@ -1,6 +1,7 @@
 # ! /usr/bin/python
 # -*- coding: utf-8 -*-

+# =============================================================================
 # Copyright 2019 NVIDIA. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -16,87 +17,53 @@
 # limitations under the License.
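The new NeuralModuleConfigTest cases above exercise validation of constructor parameters built from built-in types, including nested dicts and lists. A recursive check along those lines is sketched below; it is illustrative only, the helper name is hypothetical, and the real _validate_params() in NeMo may be implemented differently.

def params_are_valid(params):
    """Return True if params holds only built-in types, possibly nested inside dicts and lists."""
    if params is None or isinstance(params, (bool, int, float, str)):
        return True
    if isinstance(params, dict):
        return all(params_are_valid(value) for value in params.values())
    if isinstance(params, (list, tuple)):
        return all(params_are_valid(value) for value in params)
    return False


# Mirrors the nested-mix test case above.
assert params_are_valid({"list_outer": [{"int": 123, "float": 12.4, "string": "ala ma kota", "bool": True}]})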
# ============================================================================= +import unittest + import nemo from nemo.backends.pytorch.nm import TrainableNM from tests.common_setup import NeMoUnitTest class TestNM1(TrainableNM): - def __init__(self, var1, var2=2, var3=3, **kwargs): - super(TestNM1, self).__init__(**kwargs) - - @property - def input_ports(self): - """Returns definitions of module input ports.""" - return {} - - @property - def output_ports(self): - """Returns definitions of module output ports.""" - return {} - - def foward(self): - pass + def __init__(self, var1=1, var2=2, var3=3): + super(TestNM1, self).__init__() class TestNM2(TestNM1): - def __init__(self, var2, **kwargs): - super(TestNM2, self).__init__(**kwargs) - - @property - def input_ports(self): - """Returns definitions of module input ports.""" - return {} - - @property - def output_ports(self): - """Returns definitions of module output ports.""" - return {} + def __init__(self, var2): + super(TestNM2, self).__init__(var2=var2) - def foward(self): - pass +class TestNeuralModulesPT(NeMoUnitTest): + def setUp(self) -> None: + super().setUp() -class BrokenNM(TrainableNM): - def __init__(self, var2, *error, **kwargs): - super(BrokenNM, self).__init__(**kwargs) - - @property - def input_ports(self): - """Returns definitions of module input ports.""" - return {} - - @property - def output_ports(self): - """Returns definitions of module output ports.""" - return {} - - def foward(self): - pass + # Mockup abstract methods. + TestNM1.__abstractmethods__ = set() + TestNM2.__abstractmethods__ = set() + def test_default_init_params(self): + simple_nm = TestNM1(var1=1) + init_params = simple_nm.init_params + self.assertEqual(init_params["var1"], 1) + self.assertEqual(init_params["var2"], 2) + self.assertEqual(init_params["var3"], 3) -class TestNeuralModulesPT(NeMoUnitTest): - def test_simple_local_params(self): + def test_simple_init_params(self): simple_nm = TestNM1(var1=10, var3=30) - local_params = simple_nm.local_parameters - self.assertEqual(local_params["var1"], 10) - self.assertEqual(local_params["var2"], 2) - self.assertEqual(local_params["var3"], 30) - - def test_nested_local_params(self): - simple_nm = TestNM2(25, var1="hello") - local_params = simple_nm.local_parameters - self.assertEqual(local_params["var1"], "hello") - self.assertEqual(local_params["var2"], 25) - self.assertEqual(local_params["var3"], 3) - - def test_posarg_check(self): - with self.assertRaises(ValueError): - NM = BrokenNM(8) + init_params = simple_nm.init_params + self.assertEqual(init_params["var1"], 10) + self.assertEqual(init_params["var2"], 2) + self.assertEqual(init_params["var3"], 30) + + def test_nested_init_params(self): + simple_nm = TestNM2(var2="hello") + init_params = simple_nm.init_params + self.assertEqual(init_params["var2"], "hello") def test_constructor_TaylorNet(self): tn = nemo.backends.pytorch.tutorials.TaylorNet(dim=4) - self.assertEqual(tn.local_parameters["dim"], 4) + self.assertEqual(tn.init_params["dim"], 4) def test_call_TaylorNet(self): x_tg = nemo.core.neural_modules.NmTensor(