
Commit

Merge pull request #307 from NVIDIA/neural_type_system2
new version of neural type system
okuchaiev authored Feb 12, 2020
2 parents 33d14dc + 196a248 commit f7b534a
Showing 75 changed files with 1,969 additions and 3,272 deletions.
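The central change, visible throughout the diffs below, is the move from dict-of-`AxisType` port declarations to `NeuralType` objects built from a tuple of axis shorthands plus an element type. A minimal sketch of the new style, using only constructors and imports that appear in this diff (illustrative, not a full API reference):

```python
# Sketch of the neural type style introduced by this PR, based on usage in the diffs below.
from nemo.core.neural_types import LabelsType, LogitsType, LossType, NeuralType

# Old style (removed): NeuralType({0: AxisType(BatchTag), 1: AxisType(ChannelTag)})
# New style: a tuple of axis shorthands plus an element type.
logits = NeuralType(axes=('B', 'D'), elements_type=LogitsType())
labels = NeuralType(axes=tuple('B'), elements_type=LabelsType())  # tuple('B') == ('B',)
loss = NeuralType(elements_type=LossType())  # no axes: the element type alone suffices
```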
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -70,12 +70,16 @@ To release a new version, please update the changelog as followed:
 ## [Unreleased]

 ### Added
+- New Neural Type System and its tests.
+([PR #307](https://github.com/NVIDIA/NeMo/pull/307)) - @okuchaiev
 - Named tensors tuple module's output for graph construction.
 ([PR #268](https://github.com/NVIDIA/NeMo/pull/268)) - @stasbel
 - Introduced the `deprecated` decorator.
 ([PR #298](https://github.com/NVIDIA/NeMo/pull/298)) - @tkornuta-nvidia

 ### Changed
+- All collections changed to use New Neural Type System.
+([PR #307](https://github.com/NVIDIA/NeMo/pull/307)) - @okuchaiev
 - Additional Collections Repositories merged into core `nemo_toolkit` package.
 ([PR #289](https://github.com/NVIDIA/NeMo/pull/289)) - @DEKHTIARJonathan
 - Refactor manifest files parsing and processing for re-using.
4 changes: 3 additions & 1 deletion examples/start_here/chatbot_example.py
@@ -65,10 +65,12 @@ def outputs2words(tensors, vocab):
     tensors=[loss, src, outputs_inf, tgt], print_func=lambda x: outputs2words(x, dl.voc.index2word),
 )

+num_epochs = 1
+logging.info(f"Training only for {num_epochs}. Train longer (~10-20) for convergence.")
 # Start training
 nf.train(
     tensors_to_optimize=[loss],
     callbacks=[callback],
     optimizer="adam",
-    optimization_params={"num_epochs": 3, "lr": 0.001},
+    optimization_params={"num_epochs": num_epochs, "lr": 0.001},
 )
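As the added log message suggests, the default single epoch is only a smoke test. A hedged sketch of a convergence run (the value 15 is an assumption within the suggested 10-20 range):

```python
num_epochs = 15  # illustrative; the log message above suggests ~10-20 epochs
nf.train(
    tensors_to_optimize=[loss],
    callbacks=[callback],
    optimizer="adam",
    optimization_params={"num_epochs": num_epochs, "lr": 0.001},
)
```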
8 changes: 4 additions & 4 deletions nemo/backends/pytorch/actions.py
@@ -919,10 +919,10 @@ def __module_export(
         dynamic_axes = defaultdict(list)

         def __extract_dynamic_axes(port_name: str, ntype: NeuralType, dynamic_axes: defaultdict):
-            if ntype.axis2type:
-                for axis_id, axistype in ntype.axis2type.items():
-                    if issubclass(axistype.semantics, BatchTag) or issubclass(axistype.semantics, TimeTag):
-                        dynamic_axes[port_name].append(axis_id)
+            if ntype.axes:
+                for ind, axis in enumerate(ntype.axes):
+                    if axis.kind == AxisKind.Batch or axis.kind == AxisKind.Time:
+                        dynamic_axes[port_name].append(ind)

         # This is a hack for Jasper to Jarvis export -- need re-design for this
         inputs_to_drop = set()
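The rewritten `__extract_dynamic_axes` walks the new `ntype.axes` tuple positionally and records batch and time dimensions, producing the dict shape that `torch.onnx.export` expects for its `dynamic_axes` argument. A self-contained sketch of the same logic; the `Axis` class here is a hypothetical stand-in for NeMo's axis descriptor:

```python
from collections import defaultdict
from enum import Enum


class AxisKind(Enum):
    # stand-in for nemo.core.neural_types.AxisKind
    Batch = 0
    Time = 1
    Dimension = 2


class Axis:
    # hypothetical stand-in for NeMo's axis descriptor; only .kind is needed here
    def __init__(self, kind):
        self.kind = kind


def extract_dynamic_axes(port_name, axes, dynamic_axes):
    # Mirrors the diff above: any position whose kind is Batch or Time is dynamic.
    for ind, axis in enumerate(axes):
        if axis.kind in (AxisKind.Batch, AxisKind.Time):
            dynamic_axes[port_name].append(ind)


dynamic_axes = defaultdict(list)
extract_dynamic_axes(
    'audio_signal',
    [Axis(AxisKind.Batch), Axis(AxisKind.Dimension), Axis(AxisKind.Time)],
    dynamic_axes,
)
print(dict(dynamic_axes))  # {'audio_signal': [0, 2]} -> usable as ONNX dynamic_axes
```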
43 changes: 9 additions & 34 deletions nemo/backends/pytorch/common/losses.py
@@ -2,7 +2,7 @@
 from torch import nn

 from nemo.backends.pytorch.nm import LossNM
-from nemo.core.neural_types import AxisType, BatchTag, ChannelTag, NeuralType, RegressionTag, TimeTag
+from nemo.core.neural_types import LabelsType, LogitsType, LossType, NeuralType, RegressionValuesType

 __all__ = ['SequenceLoss', 'CrossEntropyLoss', 'MSELoss']

@@ -34,24 +34,8 @@ class SequenceLoss(LossNM):
     @property
     def input_ports(self):
         """Returns definitions of module input ports.
-        log_probs:
-            0: AxisType(BatchTag)
-            1: AxisType(TimeTag)
-            2: AxisType(ChannelTag)
-        targets:
-            0: AxisType(BatchTag)
-            1: AxisType(TimeTag)
         """
-        return {
-            'log_probs': NeuralType({0: AxisType(BatchTag), 1: AxisType(TimeTag), 2: AxisType(ChannelTag),}),
-            'targets': NeuralType({0: AxisType(BatchTag), 1: AxisType(TimeTag)}),
-        }
+        return {'log_probs': NeuralType(axes=('B', 'T', 'D')), 'targets': NeuralType(axes=('B', 'T'))}

     @property
     def output_ports(self):
@@ -61,7 +45,7 @@ def output_ports(self):
             NeuralType(None)
         """
-        return {"loss": NeuralType(None)}
+        return {"loss": NeuralType(elements_type=LossType())}

     def __init__(
         self, pad_id=0, smoothing_coef=0.0, sample_wise=False, aux_ctc=False, ctc_initial_coef=0.1, ctc_blank_id=None
@@ -121,19 +105,10 @@ class CrossEntropyLoss(LossNM):
     @property
     def input_ports(self):
         """Returns definitions of module input ports.
-        logits:
-            0: AxisType(BatchTag)
-            1: AxisType(ChannelTag)
-        labels:
-            0: AxisType(BatchTag)
         """
         return {
-            "logits": NeuralType({0: AxisType(BatchTag), 1: AxisType(ChannelTag)}),
-            "labels": NeuralType({0: AxisType(BatchTag),}),
+            "logits": NeuralType(axes=('B', 'D'), elements_type=LogitsType()),
+            "labels": NeuralType(axes=tuple('B'), elements_type=LabelsType()),
         }

     @property
@@ -143,7 +118,7 @@ def output_ports(self):
         loss:
             NeuralType(None)
         """
-        return {"loss": NeuralType(None)}
+        return {"loss": NeuralType(elements_type=LossType())}

     def __init__(self, weight=None):
         super().__init__()
@@ -168,8 +143,8 @@ def input_ports(self):
             0: AxisType(RegressionTag)
         """
         return {
-            "preds": NeuralType({0: AxisType(RegressionTag)}),
-            "labels": NeuralType({0: AxisType(RegressionTag)}),
+            "preds": NeuralType(tuple('B'), RegressionValuesType()),
+            "labels": NeuralType(tuple('B'), LabelsType()),
         }

     @property
@@ -179,7 +154,7 @@ def output_ports(self):
         loss:
             NeuralType(None)
         """
-        return {"loss": NeuralType(None)}
+        return {"loss": NeuralType(elements_type=LossType())}

     def __init__(self):
         super().__init__()
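All three losses follow the same pattern after this change: input ports pair an axes tuple with an element type, and every loss output becomes `NeuralType(elements_type=LossType())`. A hedged sketch of a custom loss module written against the new system; the class is hypothetical, and the assumption that `LossNM` subclasses implement `_loss_function` follows the structure of this file:

```python
import torch

from nemo.backends.pytorch.nm import LossNM
from nemo.core.neural_types import LabelsType, LogitsType, LossType, NeuralType


class MyCrossEntropyLoss(LossNM):
    """Hypothetical module; only the port style is taken from this diff."""

    @property
    def input_ports(self):
        return {
            "logits": NeuralType(axes=('B', 'D'), elements_type=LogitsType()),
            "labels": NeuralType(axes=tuple('B'), elements_type=LabelsType()),
        }

    @property
    def output_ports(self):
        # Loss outputs no longer use NeuralType(None); the element type carries the semantics.
        return {"loss": NeuralType(elements_type=LossType())}

    def _loss_function(self, logits, labels):
        # Illustrative body, assuming LossNM dispatches tensors to _loss_function by port name.
        return torch.nn.functional.cross_entropy(logits, labels)
```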