Resolve some codefactor issues (#756)
* remove unnecessary pass statements

* use isinstance for type checks

* remove unnecessary else/elif after return

* remove unnecessary return statements

* move doc string to top

* merge isinstance calls

* remove unnecessary else/elif after raise

* use list comprehension

* do not use len without comparison

* add missing shebang

* revert isinstance check back to type

broke tests, because bool is actually a subclass of int (see the short sketch below)

* add missing period to doc string

* Fix default ckpt path when logger exists (#771)

* rename logging -> loggers (#767)

* move logging >> loggers

* add warning

* fix tests

* logging alias

* formatting

* formatting

* use isinstance for type checks

* revert isinstance check back to type

broke tests, because bool is actually a subclass of int

* add more detail to tbptt example (#755)

* add more detail to tbptt example

* warn user about new arg in training_step

Co-authored-by: Vadim Bereznyuk <kuynzereb@gmail.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Jeremy Jordan <13970565+jeremyjordan@users.noreply.github.com>
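A minimal illustration (not part of the commit) of the pitfall behind the "revert isinstance check back to type" bullets above: bool is a subclass of int, so an isinstance check against int also accepts booleans, while a type() comparison does not.

flag = True

print(isinstance(flag, int))  # True  -- a bool passes the isinstance check
print(type(flag) is int)      # False -- the type() comparison rejects bool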
4 people committed Feb 1, 2020
1 parent 5e97e66 commit 472f394
Showing 15 changed files with 46 additions and 81 deletions.
2 changes: 2 additions & 0 deletions .run_local_tests.sh
@@ -1,3 +1,5 @@
#!/usr/bin/env bash

# use this to run tests
rm -rf _ckpt_*
rm -rf tests/save_dir*
8 changes: 3 additions & 5 deletions pytorch_lightning/callbacks/pt_callbacks.py
@@ -27,7 +27,7 @@ def set_params(self, params):
self.params = params

def set_model(self, model):
if type(model) is LightningDistributedDataParallel:
if isinstance(model, LightningDistributedDataParallel):
model = model.module
self.model = model

@@ -43,7 +43,6 @@ def on_epoch_begin(self, epoch, logs=None):
on_epoch_begin(epoch=2, logs={'val_loss': 0.2})
"""
pass

def on_epoch_end(self, epoch, logs=None):
pass
@@ -56,7 +55,6 @@ def on_batch_begin(self, batch, logs=None):
batch (Tensor): current batch tensor
logs (dict): key-value pairs of quantities to monitor
"""
pass

def on_batch_end(self, batch, logs=None):
pass
@@ -143,7 +141,7 @@ def check_metrics(self, logs):
if monitor_val is None:
if self.strict:
raise RuntimeError(error_msg)
elif self.verbose > 0:
if self.verbose > 0:
warnings.warn(error_msg, RuntimeWarning)

return False
@@ -399,7 +397,7 @@ def __init__(self, scheduling: dict):
if minimal_epoch < 1:
msg = f"Epochs indexing from 1, epoch {minimal_epoch} cannot be interpreted correct"
raise IndexError(msg)
elif minimal_epoch != 1: # if user didnt define first epoch accumulation factor
if minimal_epoch != 1: # if user didnt define first epoch accumulation factor
scheduling.update({1: 1})

self.scheduling = scheduling
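A rough standalone sketch, not the library's own code, of the scheduling-dict validation shown in the GradientAccumulationScheduler hunk above; the dict literal is an example input. Epochs are indexed from 1, and a factor of 1 is filled in for epoch 1 when the user did not define one.

scheduling = {5: 2, 10: 4}  # accumulate 2 batches from epoch 5, 4 from epoch 10

minimal_epoch = min(scheduling.keys())
if minimal_epoch < 1:
    raise IndexError(f"Epochs indexing from 1, epoch {minimal_epoch} cannot be interpreted correct")
if minimal_epoch != 1:  # user did not define a first-epoch accumulation factor
    scheduling.update({1: 1})

print(scheduling)  # {5: 2, 10: 4, 1: 1}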
11 changes: 0 additions & 11 deletions pytorch_lightning/core/hooks.py
@@ -32,22 +32,19 @@ def on_sanity_check_start(self):
.. warning:: will be deprecated.
:return:
"""
pass

def on_train_start(self):
"""Called at the beginning of training before sanity check
:return:
"""
# do something at the start of training
pass

def on_train_end(self):
"""
Called at the end of training before logger experiment is closed
:return:
"""
# do something at the end of training
pass

def on_batch_start(self, batch):
"""Called in the training loop before anything happens for that batch.
@@ -56,32 +53,26 @@ def on_batch_start(self, batch):
:return:
"""
# do something when the batch starts
pass

def on_batch_end(self):
"""Called in the training loop after the batch."""
# do something when the batch ends
pass

def on_epoch_start(self):
"""Called in the training loop at the very beginning of the epoch."""
# do something when the epoch starts
pass

def on_epoch_end(self):
"""Called in the training loop at the very end of the epoch."""
# do something when the epoch ends
pass

def on_pre_performance_check(self):
"""Called at the very beginning of the validation loop."""
# do something before validation starts
pass

def on_post_performance_check(self):
"""Called at the very end of the validation loop."""
# do something before validation end
pass

def on_before_zero_grad(self, optimizer):
"""Called after optimizer.step() and before optimizer.zero_grad()
@@ -99,7 +90,6 @@ def on_before_zero_grad(self, optimizer):
:return:
"""
# do something with the optimizer or inspect it.
pass

def on_after_backward(self):
"""Called after loss.backward() and before optimizers do anything.
@@ -122,7 +112,6 @@ def on_after_backward(self):
global_step=self.trainer.global_step)
"""
pass

def backward(self, use_amp, loss, optimizer, optimizer_idx):
"""Override backward with your own implementation if you need to
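A minimal sketch of how the no-op hooks above are typically overridden in user code; the LitModel class is hypothetical and this is only an illustration, not code from this commit.

import pytorch_lightning as pl


class LitModel(pl.LightningModule):
    def on_train_start(self):
        # runs once, before the first training epoch
        print("training is starting")

    def on_epoch_end(self):
        # runs at the very end of every training epoch
        print("epoch finished")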
12 changes: 2 additions & 10 deletions pytorch_lightning/core/lightning.py
@@ -253,7 +253,6 @@ def training_step(self, batch, batch_idx, hiddens):
You can also return a -1 instead of a dict to stop the current loop. This is useful if you want to
break out of the current training epoch early.
"""
pass

def validation_step(self, *args, **kwargs):
r"""
@@ -326,7 +325,6 @@ def validation_step(self, batch, batch_idx, dataset_idx):
.. note:: When the validation_step is called, the model has been put in eval mode and PyTorch gradients
have been disabled. At the end of validation, model goes back to training mode and gradients are enabled.
"""
pass

def test_step(self, *args, **kwargs):
"""return whatever outputs will need to be aggregated in test_end
@@ -395,7 +393,6 @@ def test_step(self, batch, batch_idx, dataset_idx):
The `dataset_idx` corresponds to the order of datasets returned in `test_dataloader`.
"""
pass

def validation_end(self, outputs):
"""Outputs has the appended output after each validation step.
@@ -467,7 +464,6 @@ def validation_end(self, outputs):
return results
"""
pass

def test_end(self, outputs):
"""Outputs has the appended output after each test step.
@@ -532,7 +528,6 @@ def test_end(self, outputs):
return results
"""
pass

def configure_ddp(self, model, device_ids):
r"""
@@ -842,8 +837,7 @@ def tbptt_split_batch(self, batch, split_size):
Each returned batch split is passed separately to training_step(...).
"""
time_dims = [len(x[0]) for x in batch if isinstance(
x, torch.Tensor) or isinstance(x, collections.Sequence)]
time_dims = [len(x[0]) for x in batch if isinstance(x, (torch.Tensor, collections.Sequence))]
assert len(time_dims) >= 1, "Unable to determine batch time dimension"
assert all(x == time_dims[0] for x in time_dims), "Batch time dimension length is ambiguous"

@@ -1192,7 +1186,6 @@ def on_load_checkpoint(self, checkpoint):
.. note:: Lighting auto-restores global step, epoch, and all training state including amp scaling.
No need for you to restore anything regarding training.
"""
pass

def on_save_checkpoint(self, checkpoint):
r"""
@@ -1216,7 +1209,6 @@ def on_save_checkpoint(self, checkpoint):
for you to store anything about training.
"""
pass


def load_hparams_from_tags_csv(tags_csv):
@@ -1236,7 +1228,7 @@ def load_hparams_from_tags_csv(tags_csv):
def convert(val):
constructors = [int, float, str]

if type(val) is str:
if isinstance(val, str):
if val.lower() == 'true':
return True
if val.lower() == 'false':
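The convert helper above is cut off by the collapsed region; the sketch below is a guess at the kind of string-to-value conversion it performs (booleans first, then the first constructor that succeeds), not the file's actual code.

def convert(val):
    constructors = [int, float, str]

    if isinstance(val, str):
        if val.lower() == 'true':
            return True
        if val.lower() == 'false':
            return False
    for c in constructors:
        try:
            return c(val)
        except ValueError:
            continue
    return val


print(convert('false'))  # False
print(convert('42'))     # 42
print(convert('adam'))   # 'adam'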
10 changes: 2 additions & 8 deletions pytorch_lightning/core/memory.py
@@ -77,7 +77,7 @@ def get_variable_sizes(self):
if isinstance(input_, (list, tuple)): # pragma: no cover
in_size = []
for x in input_:
if type(x) is list:
if isinstance(x, list):
in_size.append(len(x))
else:
in_size.append(x.size())
@@ -97,7 +97,6 @@ def get_variable_sizes(self):
self.in_sizes = in_sizes
self.out_sizes = out_sizes
assert len(in_sizes) == len(out_sizes)
return

def get_layer_names(self):
'''Collect Layer Names'''
@@ -112,21 +111,17 @@ def get_layer_names(self):

self.layer_names = names
self.layer_types = layer_types
return

def get_parameter_sizes(self):
'''Get sizes of all parameters in `model`'''
mods = self.named_modules()
sizes = []
for _, m in mods:
p = list(m.parameters())
modsz = []
for j in range(len(p)):
modsz.append(np.array(p[j].size()))
modsz = [np.array(param.size()) for param in p]
sizes.append(modsz)

self.param_sizes = sizes
return

def get_parameter_nums(self):
'''Get number of parameters in each layer'''
@@ -137,7 +132,6 @@ def get_parameter_nums(self):
all_params += np.prod(p)
param_nums.append(all_params)
self.param_nums = param_nums
return

def make_summary(self):
'''
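A small illustration of what the list comprehension introduced in get_parameter_sizes computes for a single layer; the nn.Linear layer is just an example and not part of the commit.

import numpy as np
import torch.nn as nn

layer = nn.Linear(in_features=3, out_features=2)
p = list(layer.parameters())

# one size array per parameter tensor, as in the comprehension above
modsz = [np.array(param.size()) for param in p]
print(modsz)  # [array([2, 3]), array([2])] -- weight and bias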
4 changes: 0 additions & 4 deletions pytorch_lightning/core/saving.py
@@ -7,14 +7,12 @@ def on_load_checkpoint(self, checkpoint):
:param checkpoint:
:return:
"""
pass

def on_save_checkpoint(self, checkpoint):
"""
Give the model a chance to add something to the checkpoint.
state_dict is already there
"""
pass

# -------------------------
# OPTIONAL HOOKS
@@ -24,11 +22,9 @@ def on_hpc_save(self, checkpoint):
Hook to do whatever you need right before Slurm manager saves the model
:return:
"""
pass

def on_hpc_load(self, checkpoint):
"""
Hook to do whatever you need right before Slurm manager loads the model
:return:
"""
pass
3 changes: 0 additions & 3 deletions pytorch_lightning/loggers/base.py
@@ -43,18 +43,15 @@ def log_hyperparams(self, params):

def save(self):
"""Save log data."""
pass

def finalize(self, status):
"""Do any processing that is necessary to finalize an experiment.
:param status: Status that the experiment finished with (e.g. success, failed, aborted)
"""
pass

def close(self):
"""Do any cleanup that is necessary to close an experiment."""
pass

@property
def rank(self):
4 changes: 2 additions & 2 deletions pytorch_lightning/loggers/tensorboard.py
@@ -143,5 +143,5 @@ def _get_next_version(self):

if len(existing_versions) == 0:
return 0
else:
return max(existing_versions) + 1

return max(existing_versions) + 1
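The same logic as a tiny standalone sketch, to illustrate the "remove unnecessary else/elif after return" cleanup; the helper name and inputs are illustrative only.

def next_version(existing_versions):
    if len(existing_versions) == 0:
        return 0
    return max(existing_versions) + 1  # no else needed after the early return


print(next_version([]))         # 0
print(next_version([0, 1, 4]))  # 5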
8 changes: 4 additions & 4 deletions pytorch_lightning/overrides/data_parallel.py
@@ -25,7 +25,7 @@ def get_a_var(obj): # pragma: no cover
if isinstance(obj, torch.Tensor):
return obj

if isinstance(obj, list) or isinstance(obj, tuple):
if isinstance(obj, (list, tuple)):
for result in map(get_a_var, obj):
if isinstance(result, torch.Tensor):
return result
@@ -56,10 +56,10 @@ def forward(self, *inputs, **kwargs):
# lightning
if self.module.training:
return self.module.training_step(*inputs[0], **kwargs[0])
elif self.module.testing:
if self.module.testing:
return self.module.test_step(*inputs[0], **kwargs[0])
else:
return self.module.validation_step(*inputs[0], **kwargs[0])

return self.module.validation_step(*inputs[0], **kwargs[0])

replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
outputs = self.parallel_apply(replicas, inputs, kwargs)
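A one-line illustration of the "merge isinstance calls" change used in get_a_var above: a single isinstance call with a tuple of types behaves exactly like the two chained checks it replaces.

obj = (1, 2, 3)

merged = isinstance(obj, (list, tuple))
chained = isinstance(obj, list) or isinstance(obj, tuple)
assert merged == chained
print(merged)  # True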
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/distrib_data_parallel.py
@@ -246,7 +246,7 @@ def set_nvidia_flags(self, is_slurm_managing_tasks, data_parallel_device_ids):

# when slurm is managing the task it sets the visible devices
if not is_slurm_managing_tasks:
if type(data_parallel_device_ids) is int:
if isinstance(data_parallel_device_ids, int):
id_str = ','.join(str(x) for x in list(range(data_parallel_device_ids)))
os.environ["CUDA_VISIBLE_DEVICES"] = id_str
else:
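A minimal sketch of the id-string construction in set_nvidia_flags when the device specification is an int; the value 3 is only an example (e.g. a Trainer configured with gpus=3).

import os

data_parallel_device_ids = 3

if isinstance(data_parallel_device_ids, int):
    id_str = ','.join(str(x) for x in list(range(data_parallel_device_ids)))
    os.environ["CUDA_VISIBLE_DEVICES"] = id_str

print(os.environ["CUDA_VISIBLE_DEVICES"])  # 0,1,2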