Skip to content

Commit

Permalink
Code cleanup: refactoring, type checking, and formatting (#1076)
Browse files Browse the repository at this point in the history
  • Loading branch information
mjdenkowski committed Dec 11, 2022
1 parent b4c8427 commit 4dba5a3
Show file tree
Hide file tree
Showing 18 changed files with 117 additions and 97 deletions.
8 changes: 8 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,14 @@ Note that Sockeye has checks in place to not translate with an old model that wa

Each version section may have subsections for: _Added_, _Changed_, _Removed_, _Deprecated_, and _Fixed_.

## [3.1.29]

### Changed

- Running `sockeye-evaluate` no longer applies text tokenization for TER (same behavior as other metrics).
- Turned on type checking for all `sockeye` modules except `test_utils` and addressed resulting type issues.
- Refactored code in various modules without changing user-level behavior.

## [3.1.28]

### Added
Expand Down
2 changes: 1 addition & 1 deletion sockeye/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

__version__ = '3.1.28'
__version__ = '3.1.29'
62 changes: 34 additions & 28 deletions sockeye/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,8 @@ def check_regular_directory(value_to_check):

def int_greater_or_equal(threshold: int) -> Callable:
"""
Returns a method that can be used in argument parsing to check that the int argument is greater or equal to `threshold`.
Returns a method that can be used in argument parsing to check that the int argument is greater or equal to
`threshold`.
:param threshold: The threshold that we assume the cli argument value is greater or equal to.
:return: A method that can be used as a type in argparse.
Expand All @@ -155,7 +156,8 @@ def check_greater_equal(value: str):

def float_greater_or_equal(threshold: float) -> Callable:
"""
Returns a method that can be used in argument parsing to check that the float argument is greater or equal to `threshold`.
Returns a method that can be used in argument parsing to check that the float argument is greater or equal to
`threshold`.
:param threshold: The threshold that we assume the cli argument value is greater or equal to.
:return: A method that can be used as a type in argparse.
Expand Down Expand Up @@ -571,7 +573,8 @@ def add_device_args(params):
device_params.add_argument('--tf32',
type=bool_str(),
default=True,
help='Globally enable transparent tf32 acceleration of float32 at the cost of reducing precision to 10 bits')
help='Globally enable transparent tf32 acceleration of float32 at the cost of reducing '
'precision to 10 bits. Default: %(default)s.')


def add_vocab_args(params):
Expand Down Expand Up @@ -829,22 +832,20 @@ def add_batch_args(params, default_batch_size=4096, default_batch_type=C.BATCH_T


def add_nvs_train_parameters(params):
params.add_argument(
'--bow-task-weight',
type=float_greater_or_equal(0.0),
default=1.0,
help=
'The weight of the auxiliary Bag-of-word (BOW) loss when --neural-vocab-selection is enabled. Default %(default)s.'
)

params.add_argument(
'--bow-task-pos-weight',
type=float_greater_or_equal(0.0),
default=10,
help='The weight of the positive class (the set of words present on the target side) for the BOW loss '
'when --neural-vocab-selection is set as x * num_negative_class / num_positive_class where x is the '
'--bow-task-pos-weight. Higher values will bias more towards recall, resulting in larger vocabularies '
'at test time trading off larger vocabularies for higher translation quality. Default %(default)s.')
params.add_argument('--bow-task-weight',
type=float_greater_or_equal(0.0),
default=1.0,
help='The weight of the auxiliary Bag-of-word (BOW) loss when --neural-vocab-selection is '
'enabled. Default %(default)s.')

params.add_argument('--bow-task-pos-weight',
type=float_greater_or_equal(0.0),
default=10,
help='The weight of the positive class (the set of words present on the target side) for the '
'BOW loss when --neural-vocab-selection is set as x * num_negative_class / '
'num_positive_class where x is the --bow-task-pos-weight. Higher values will bias more '
'towards recall, resulting in larger vocabularies at test time trading off larger '
'vocabularies for higher translation quality. Default %(default)s.')


def add_training_args(params):
Expand All @@ -866,16 +867,18 @@ def add_training_args(params):
type=str,
default=None,
choices=[C.LENGTH_TASK_RATIO, C.LENGTH_TASK_LENGTH],
help='If specified, adds an auxiliary task during training to predict source/target length ratios '
'(mean squared error loss), or absolute lengths (Poisson) loss. Default %(default)s.')
help='If specified, adds an auxiliary task during training to predict source/target '
'length ratios (mean squared error loss), or absolute lengths (Poisson) loss. '
'Default %(default)s.')
train_params.add_argument('--length-task-weight',
type=float_greater_or_equal(0.0),
default=1.0,
help='The weight of the auxiliary --length-task loss. Default %(default)s.')
train_params.add_argument('--length-task-layers',
type=int_greater_or_equal(1),
default=1,
help='Number of fully-connected layers for predicting the length ratio. Default %(default)s.')
help='Number of fully-connected layers for predicting the length ratio. '
'Default %(default)s.')

add_nvs_train_parameters(train_params)

Expand Down Expand Up @@ -1088,7 +1091,8 @@ def add_training_args(params):

train_params.add_argument('--keep-initializations',
action="store_true",
help='In addition to keeping the last n params files, also keep params from checkpoint 0.')
help='In addition to keeping the last n params files, also keep params from checkpoint '
'0.')

train_params.add_argument('--cache-last-best-params',
required=False,
Expand Down Expand Up @@ -1349,7 +1353,8 @@ def add_inference_args(params):

decode_params.add_argument('--skip-nvs',
action='store_true',
help='Manually turn off Neural Vocabulary Selection (NVS) to do a softmax over the full target vocabulary.',
help='Manually turn off Neural Vocabulary Selection (NVS) to do a softmax over the full '
'target vocabulary.',
default=False)

decode_params.add_argument('--nvs-thresh',
Expand Down Expand Up @@ -1406,13 +1411,14 @@ def add_brevity_penalty_args(params):
params.add_argument('--brevity-penalty-weight',
default=1.0,
type=float_greater_or_equal(0.0),
help='Scaler for the brevity penalty in beam search: weight * log(BP) + score. Default: %(default)s')
help='Scaler for the brevity penalty in beam search: weight * log(BP) + score. '
'Default: %(default)s')
params.add_argument('--brevity-penalty-constant-length-ratio',
default=0.0,
type=float_greater_or_equal(0.0),
help='Has effect if --brevity-penalty-type is set to \'constant\'. If positive, overrides the length '
'ratio, used for brevity penalty calculation, for all inputs. If zero, uses the average of length '
'ratios from the training data over all models. Default: %(default)s.')
help='Has effect if --brevity-penalty-type is set to \'constant\'. If positive, overrides the '
'length ratio, used for brevity penalty calculation, for all inputs. If zero, uses the '
'average of length ratios from the training data over all models. Default: %(default)s.')


def add_clamp_to_dtype_arg(params):
Expand Down
1 change: 0 additions & 1 deletion sockeye/beam_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -678,7 +678,6 @@ def __init__(self,
self.output_vocab_size = inference.model_output_vocab_size
self.output_factor_vocab_size = inference.model_output_factor_vocab_size
self._inference = inference
self.global_avoid_trie = None
assert inference._skip_softmax, "skipping softmax must be enabled for GreedySearch"
self.work_block = GreedyTop1()

Expand Down
7 changes: 2 additions & 5 deletions sockeye/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,6 @@
JSON_RESTRICT_LEXICON_KEY = "restrict_lexicon"
JSON_CONSTRAINTS_KEY = "constraints"
JSON_AVOID_KEY = "avoid"
JSON_ENCODING = "utf-8"

VERSION_NAME = "version"
CONFIG_NAME = "config"
Expand Down Expand Up @@ -285,7 +284,6 @@
DTYPE_BF16 = 'bfloat16'
DTYPE_FP16 = 'float16'
DTYPE_FP32 = 'float32'
DTYPE_TF32 = 'tf32'
DTYPE_INT8 = 'int8'
DTYPE_INT16 = 'int16'
DTYPE_INT32 = 'int32'
Expand Down Expand Up @@ -364,7 +362,6 @@
# sequence length count types
SEQ_LEN_IN_CHARACTERS = "char"
SEQ_LEN_IN_TOKENS = "token"
SEQ_LEN_IN_WORDS = "word" # use case: merge sub-words to original word before counting

# scoring
SCORING_TYPE_NEGLOGPROB = 'neglogprob'
Expand All @@ -383,12 +380,12 @@
BREVITY_PENALTY_LEARNED = 'learned'
BREVITY_PENALTY_NONE = 'none'

# k-nn
# k-nn
KNN_STATE_DATA_STORE_NAME = "keys.npy"
KNN_WORD_DATA_STORE_NAME = "vals.npy"
KNN_WORD_DATA_STORE_DTYPE = DTYPE_INT32
KNN_CONFIG_NAME = "config.yaml"
KNN_INDEX_NAME = "key_index"
KNN_EPSILON = 1e-6
DEFAULT_DATA_STORE_BLOCK_SIZE = 1024 * 1024
DEFAULT_KNN_LAMBDA = 0.8
DEFAULT_KNN_LAMBDA = 0.8
2 changes: 1 addition & 1 deletion sockeye/convert_deepspeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def convert_checkpoint_to_params(model_config_fname: str, checkpoint_dirname: st
model_config = model.SockeyeModel.load_config(model_config_fname)
sockeye_model = model.SockeyeModel(model_config)
# Gather the float32 params on CPU
state_dict = get_fp32_state_dict_from_zero1_checkpoint(checkpoint_dirname)
state_dict = dict(get_fp32_state_dict_from_zero1_checkpoint(checkpoint_dirname))
# Strip the first prefix from each param name to match the SockeyeModel
# Ex: 'model.encoder.layers...' -> 'encoder.layers...'
state_dict = {name[name.find('.') + 1:]: param for (name, param) in state_dict.items()}
Expand Down
7 changes: 4 additions & 3 deletions sockeye/data_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,8 +363,8 @@ def create_shards(source_fnames: List[str],
:param target_fnames: The path to the target text (and optional token-parallel factor files).
:param num_shards: The total number of shards.
:param output_prefix: The prefix under which the shard files will be created.
:return: List of tuples of source (and source factor) file names and target (and target factor) file names for each shard
and a flag of whether the returned file names are temporary and can be deleted.
:return: List of tuples of source (and source factor) file names and target (and target factor) file names for each
shard and a flag of whether the returned file names are temporary and can be deleted.
"""
if num_shards == 1:
return [(tuple(source_fnames), tuple(target_fnames))], True
Expand Down Expand Up @@ -595,7 +595,8 @@ def prepare_data(source_fnames: List[str],
pool: multiprocessing.pool.Pool = None,
shards: List[Tuple[Tuple[str, ...], Tuple[str, ...]]] = None):
"""
:param shards: List of num_shards shards of parallel source and target tuples which in turn contain tuples to shard data factor file paths.
:param shards: List of num_shards shards of parallel source and target tuples which in turn contain tuples to shard
data factor file paths.
"""
logger.info("Preparing data.")
# write vocabularies to data folder
Expand Down
25 changes: 0 additions & 25 deletions sockeye/device.py

This file was deleted.

3 changes: 2 additions & 1 deletion sockeye/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,8 @@ def forward(self, data: pt.Tensor, valid_length: pt.Tensor) -> Tuple[pt.Tensor,

_, max_len, __ = data.size()
# length_mask for source attention masking. Shape: (batch_size, max_len)
single_head_att_mask = layers.prepare_source_length_mask(valid_length, self.config.attention_heads, max_length=max_len, expand=False)
single_head_att_mask = layers.prepare_source_length_mask(valid_length, self.config.attention_heads,
max_length=max_len, expand=False)
# Shape: (batch_size, max_len) -> (batch_size * heads, 1, max_len)
att_mask = single_head_att_mask.unsqueeze(1).expand(-1, self.config.attention_heads, -1).reshape((-1, max_len)).unsqueeze(1)
att_mask = att_mask.expand(-1, max_len, -1)
Expand Down
12 changes: 6 additions & 6 deletions sockeye/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def raw_corpus_bleu(hypotheses: Iterable[str], references: Iterable[str],
:param offset: Smoothing constant.
:return: BLEU score as float between 0 and 1.
"""
return sacrebleu.raw_corpus_bleu(hypotheses, [references], smooth_value=offset).score / 100.0
return sacrebleu.raw_corpus_bleu(hypotheses, [references], smooth_value=offset).score / 100.0 # type: ignore


def raw_corpus_chrf(hypotheses: Iterable[str], references: Iterable[str]) -> float:
Expand All @@ -58,7 +58,7 @@ def raw_corpus_chrf(hypotheses: Iterable[str], references: Iterable[str]) -> flo
:param references: Reference stream.
:return: chrF score as float between 0 and 1.
"""
return sacrebleu.corpus_chrf(hypotheses, [references]).score
return sacrebleu.corpus_chrf(hypotheses, [references]).score # type: ignore


def raw_corpus_ter(hypotheses: Iterable[str], references: Iterable[str]) -> float:
Expand All @@ -69,8 +69,8 @@ def raw_corpus_ter(hypotheses: Iterable[str], references: Iterable[str]) -> floa
:param references: Reference stream.
:return: TER score as float between 0 and 1.
"""
ter = sacrebleu.metrics.TER(argparse.Namespace())
return ter.corpus_score(hypotheses, [references]).score
ter = sacrebleu.metrics.TER()
return ter.corpus_score(hypotheses, [references]).score # type: ignore


def raw_corpus_rouge1(hypotheses: Iterable[str], references: Iterable[str]) -> float:
Expand Down Expand Up @@ -186,8 +186,8 @@ def _print_mean_std_score(metrics: List[Tuple[str, Callable]], scores: Dict[str,
scores_mean_std = [] # type: List[str]
for name, _ in metrics:
if len(scores[name]) > 1:
score_mean = np.item(np.mean(scores[name]))
score_std = np.item(np.std(scores[name], ddof=1))
score_mean = np.mean(scores[name]).item()
score_std = np.std(scores[name], ddof=1).item()
scores_mean_std.append("%.3f\t%.3f" % (score_mean, score_std))
else:
score = scores[name][0]
Expand Down
12 changes: 6 additions & 6 deletions sockeye/generate_decoder_states.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import argparse
import logging
import os
from typing import Dict, List
from typing import List, Optional

import numpy as np
import torch as pt
Expand Down Expand Up @@ -50,7 +50,7 @@ def __init__(self,
self.num_dim = num_dim # dimension of a single entry
self.dtype = dtype
self.block_size = -1
self.mmap = None
self.mmap = None # type: Optional[np.memmap]
self.tail_idx = 0 # where the next entry should be inserted
self.size = 0 # size of storage already assigned

Expand Down Expand Up @@ -120,12 +120,12 @@ def __init__(self,
self.max_seq_len_target = max_seq_len_target

self.output_dir = output_dir
self.state_store_file = None
self.words_store_file = None
self.state_store_file = None # type: Optional[NumpyMemmapStorage]
self.words_store_file = None # type: Optional[NumpyMemmapStorage]

# info for KNNConfig
self.num_states = 0
self.dimension = None
self.dimension = None # type: Optional[int]
self.state_data_type = utils.get_numpy_dtype(state_data_type)
self.word_data_type = utils.get_numpy_dtype(word_data_type)

Expand Down Expand Up @@ -186,7 +186,7 @@ def generate_states_and_store(self,
trace_inputs = {'get_decoder_states': model_inputs}
self.traced_model = pt.jit.trace_module(self.model, trace_inputs, strict=False)
# shape: (batch, seq_len, hidden_dim)
decoder_states = self.traced_model.get_decoder_states(*model_inputs)
decoder_states = self.traced_model.get_decoder_states(*model_inputs) # type: ignore

# flatten batch and seq_len dimensions, remove pads on the target
pad_mask = (batch.target != C.PAD_ID)[:, :, 0] # shape: (batch, seq_len)
Expand Down
1 change: 0 additions & 1 deletion sockeye/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,6 @@ def get_max_output_length(input_length: int):
return max_input_len, get_max_output_length


BeamHistory = Dict[str, List]
Tokens = List[str]
TokenIds = List[List[int]] # each token id may contain multiple factors
SentenceId = Union[int, str]
Expand Down
Loading

0 comments on commit 4dba5a3

Please sign in to comment.