Skip to content

Commit

Permalink
rename accelerator_backend -> accelerator (#6034)
Browse files Browse the repository at this point in the history
* rename accelerator backend

* rename new additions from master

* add proper deprecation

* pep8

* warning match

* add missing warning type
  • Loading branch information
awaelchli authored Feb 18, 2021
1 parent 02ac4b0 commit 6cc1a06
Show file tree
Hide file tree
Showing 24 changed files with 101 additions and 92 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated `.get_model()` with explicit `.lightning_module` property ([#6035](https://github.com/PyTorchLightning/pytorch-lightning/pull/6035))


- Deprecated Trainer attribute `accelerator_backend` in favor of `accelerator` ([#6034](https://github.com/PyTorchLightning/pytorch-lightning/pull/6034))



### Removed

- Removed deprecated checkpoint argument `filepath` ([#5321](https://github.com/PyTorchLightning/pytorch-lightning/pull/5321))
Expand Down
2 changes: 1 addition & 1 deletion pl_examples/basic_examples/conv_sequential_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,6 @@ def instantiate_datamodule(args):
trainer.fit(model, cifar10_dm)
trainer.test(model, datamodule=cifar10_dm)

if trainer.accelerator_backend.rpc_enabled:
if trainer.accelerator.rpc_enabled:
# Called at the end of trainer to ensure all processes are killed
trainer.training_type_plugin.exit_rpc_process()
7 changes: 1 addition & 6 deletions pytorch_lightning/core/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors
from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin
from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.parsing import AttributeDict, collect_init_args, get_init_args

Expand Down Expand Up @@ -448,11 +447,7 @@ def all_gather(
the output will also be a collection with tensors of this shape.
"""
group = group if group is not None else torch.distributed.group.WORLD
if self.trainer.accelerator_backend is not None:
all_gather = self.trainer.accelerator_backend.all_gather
else:
all_gather = all_gather_ddp_if_available

all_gather = self.trainer.accelerator.all_gather
data = convert_to_tensors(data, device=self.device)
all_gather = partial(all_gather, group=group, sync_grads=sync_grads)
return apply_to_collection(data, torch.Tensor, all_gather)
Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/core/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def __optimizer_step(self, closure: Optional[Callable] = None, profiler_name: st
model = trainer.lightning_module

with trainer.profiler.profile(profiler_name):
trainer.accelerator_backend.optimizer_step(optimizer, self._optimizer_idx, lambda_closure=closure, **kwargs)
trainer.accelerator.optimizer_step(optimizer, self._optimizer_idx, lambda_closure=closure, **kwargs)

if self._trainer.train_loop.automatic_optimization:
trainer.train_loop.on_before_zero_grad(optimizer)
Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/plugins/training_type/deepspeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def init_deepspeed(self):
self._format_config()
self._config_initialized = True

precision = self.lightning_module.trainer.accelerator_backend.precision
precision = self.lightning_module.trainer.accelerator.precision
model = LightningDeepSpeedModule(pl_module=self.model, precision=precision)

if self.lightning_module.trainer.training:
Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/plugins/training_type/tpu_spawn.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def new_process(self, process_idx: int, trainer, mp_queue) -> None:
trainer.progress_bar_callback.disable()

self.model_to_device()
trainer.accelerator_backend.setup_optimizers(trainer)
trainer.accelerator.setup_optimizers(trainer)
trainer.precision_plugin.connect(self._model, None, None)

# replace trainer save_checkpoint to use `xm.save`
Expand Down
5 changes: 2 additions & 3 deletions pytorch_lightning/trainer/connectors/checkpoint_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,8 +219,7 @@ def hpc_save(self, folderpath: str, logger):

model.on_hpc_save(checkpoint)

if self.trainer.accelerator_backend:
checkpoint = self.trainer.accelerator_backend.on_save(checkpoint)
checkpoint = self.trainer.accelerator.on_save(checkpoint)

# do the actual save
# TODO: fix for anything with multiprocess DP, DDP, DDP2
Expand Down Expand Up @@ -286,7 +285,7 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
optimizer_states = []
for i, optimizer in enumerate(self.trainer.optimizers):
# Rely on accelerator to dump optimizer state
optimizer_state = self.trainer.accelerator_backend.optimizer_state(optimizer)
optimizer_state = self.trainer.accelerator.optimizer_state(optimizer)
optimizer_states.append(optimizer_state)

checkpoint['optimizer_states'] = optimizer_states
Expand Down
5 changes: 2 additions & 3 deletions pytorch_lightning/trainer/data_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class TrainerDataLoadingMixin(ABC):
limit_val_batches: Union[int, float]
limit_test_batches: Union[int, float]
replace_sampler_ddp: bool
accelerator_backend: Accelerator
accelerator: Accelerator
num_nodes: int
num_processes: int
distributed_backend: Optional[str]
Expand Down Expand Up @@ -398,8 +398,7 @@ def request_dataloader(self, dataloader_fx: Callable) -> DataLoader:
dataloader = dataloader_fx()
dataloader = self._flatten_dl_only(dataloader)

if self.accelerator_backend is not None:
self.accelerator_backend.barrier('get_dataloaders')
self.accelerator.barrier('get_dataloaders')
return dataloader

def _flatten_dl_only(self, dataloaders):
Expand Down
12 changes: 11 additions & 1 deletion pytorch_lightning/trainer/deprecated_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.accelerators import Accelerator
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector
from pytorch_lightning.trainer.states import RunningStage
Expand Down Expand Up @@ -133,10 +134,19 @@ def use_single_gpu(self, val: bool) -> None:
self.accelerator_connector._device_type = DeviceType.GPU


class DeprecatedModelAttributes:
class DeprecatedTrainerAttributes:

accelerator: Accelerator
lightning_module = LightningModule

@property
def accelerator_backend(self) -> Accelerator:
rank_zero_warn(
"The `Trainer.accelerator_backend` attribute is deprecated in favor of `Trainer.accelerator`"
" since 1.2 and will be removed in v1.4.", DeprecationWarning
)
return self.accelerator

def get_model(self) -> LightningModule:
rank_zero_warn(
"The use of `Trainer.get_model()` is deprecated in favor of `Trainer.lightning_module`"
Expand Down
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/evaluation_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,11 +157,11 @@ def evaluation_step(self, batch, batch_idx, dataloader_idx):
if self.testing:
model_ref._current_fx_name = "test_step"
with self.trainer.profiler.profile("test_step"):
output = self.trainer.accelerator_backend.test_step(args)
output = self.trainer.accelerator.test_step(args)
else:
model_ref._current_fx_name = "validation_step"
with self.trainer.profiler.profile("validation_step"):
output = self.trainer.accelerator_backend.validation_step(args)
output = self.trainer.accelerator.validation_step(args)

# capture any logged information
self.trainer.logger_connector.cache_logged_metrics()
Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/predict_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def predict(self, batch, batch_idx, dataloader_idx):
model_ref = self.trainer.lightning_module

model_ref._current_fx_name = "predict"
predictions = self.trainer.accelerator_backend.predict(args)
predictions = self.trainer.accelerator.predict(args)
self._predictions[dataloader_idx].append(predictions)
self.trainer._progress_bar_callback.on_predict_batch_end(
self.trainer, model_ref, predictions, batch, batch_idx, dataloader_idx
Expand Down
9 changes: 2 additions & 7 deletions pytorch_lightning/trainer/properties.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,6 @@ class TrainerProperties(ABC):
def accelerator(self) -> Accelerator:
return self.accelerator_connector.accelerator

@property
def accelerator_backend(self) -> Accelerator:
# for backward compatibility
return self.accelerator

@property
def distributed_backend(self) -> Optional[str]:
# for backward compatibility
Expand Down Expand Up @@ -138,7 +133,7 @@ def log_dir(self) -> Optional[str]:
else:
dirpath = getattr(self.logger, 'log_dir' if isinstance(self.logger, TensorBoardLogger) else 'save_dir')

dirpath = self.accelerator_backend.broadcast(dirpath)
dirpath = self.accelerator.broadcast(dirpath)
return dirpath

@property
Expand Down Expand Up @@ -360,7 +355,7 @@ def lightning_optimizers(self) -> List[LightningOptimizer]:

@property
def lightning_module(self) -> LightningModule:
return self.accelerator_backend.lightning_module
return self.accelerator.lightning_module

@property
def optimizers(self) -> Optional[List[Optimizer]]:
Expand Down
26 changes: 13 additions & 13 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
from pytorch_lightning.trainer.connectors.slurm_connector import SLURMConnector
from pytorch_lightning.trainer.connectors.training_trick_connector import TrainingTricksConnector
from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin
from pytorch_lightning.trainer.deprecated_api import DeprecatedDistDeviceAttributes, DeprecatedModelAttributes
from pytorch_lightning.trainer.deprecated_api import DeprecatedDistDeviceAttributes, DeprecatedTrainerAttributes
from pytorch_lightning.trainer.evaluation_loop import EvaluationLoop
from pytorch_lightning.trainer.logging import TrainerLoggingMixin
from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin
Expand Down Expand Up @@ -80,7 +80,7 @@ class Trainer(
TrainerTrainingTricksMixin,
TrainerDataLoadingMixin,
DeprecatedDistDeviceAttributes,
DeprecatedModelAttributes,
DeprecatedTrainerAttributes,
):

@overwrite_by_env_vars
Expand Down Expand Up @@ -470,7 +470,7 @@ def fit(
# ----------------------------
self.call_setup_hook(model)
self.call_hook("on_before_accelerator_backend_setup", model)
self.accelerator_backend.setup(self, model)
self.accelerator.setup(self, model)
self.setup_trainer(model)

# ----------------------------
Expand Down Expand Up @@ -533,24 +533,24 @@ def fit(

self._set_running_stage(None, model)

return self.accelerator_backend.results or 1
return self.accelerator.results or 1

def pre_dispatch(self):
self.accelerator_backend.pre_dispatch()
self.accelerator.pre_dispatch()

def post_dispatch(self):
self.accelerator_backend.post_dispatch()
self.accelerator_backend.teardown()
self.accelerator.post_dispatch()
self.accelerator.teardown()

def dispatch(self):
if self.testing:
self.accelerator_backend.start_testing(self)
self.accelerator.start_testing(self)

elif self.predicting:
self.accelerator_backend.start_predicting(self)
self.accelerator.start_predicting(self)

else:
self.accelerator_backend.start_training(self)
self.accelerator.start_training(self)

def train_or_test_or_predict(self):
if self.testing:
Expand Down Expand Up @@ -949,7 +949,7 @@ def __test_using_best_weights(self, ckpt_path, test_dataloaders):
)
return {}
if not self._device_type == DeviceType.TPU:
self.accelerator_backend.barrier()
self.accelerator.barrier()

ckpt = pl_load(ckpt_path, map_location=lambda storage, loc: storage)
model.load_state_dict(ckpt['state_dict'])
Expand Down Expand Up @@ -1109,8 +1109,8 @@ def call_hook(self, hook_name, *args, **kwargs):

# if the PL module doesn't have the hook then call the accelerator
# used to auto-reduce things for the user with Results obj
elif hasattr(self.accelerator_backend, hook_name):
accelerator_hook = getattr(self.accelerator_backend, hook_name)
elif hasattr(self.accelerator, hook_name):
accelerator_hook = getattr(self.accelerator, hook_name)
output = accelerator_hook(*args, **kwargs)

if not skip:
Expand Down
12 changes: 6 additions & 6 deletions pytorch_lightning/trainer/training_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,8 +290,8 @@ def training_step(self, split_batch, batch_idx, opt_idx, hiddens):
model_ref._current_fx_name = 'training_step'
model_ref._results = Result()
with self.trainer.profiler.profile("training_step"):
training_step_output = self.trainer.accelerator_backend.training_step(args)
self.trainer.accelerator_backend.post_training_step()
training_step_output = self.trainer.accelerator.training_step(args)
self.trainer.accelerator.post_training_step()

self.trainer.logger_connector.cache_logged_metrics()

Expand Down Expand Up @@ -438,14 +438,14 @@ def on_before_zero_grad(self, optimizer):
self.trainer.call_hook('on_before_zero_grad', optimizer)

def optimizer_zero_grad(self, batch_idx, optimizer, opt_idx):
self.trainer.accelerator_backend.optimizer_zero_grad(self.trainer.current_epoch, batch_idx, optimizer, opt_idx)
self.trainer.accelerator.optimizer_zero_grad(self.trainer.current_epoch, batch_idx, optimizer, opt_idx)

def track_and_norm_grad(self, optimizer):
# track gradient norms
grad_norm_dic = self._track_gradient_norm()

# clip gradients
self.trainer.accelerator_backend.clip_gradients(optimizer, self.trainer.gradient_clip_val)
self.trainer.accelerator.clip_gradients(optimizer, self.trainer.gradient_clip_val)
self._cur_grad_norm_dict = grad_norm_dic

def _track_gradient_norm(self):
Expand Down Expand Up @@ -769,9 +769,9 @@ def backward(self, result, optimizer, opt_idx, *args, **kwargs):

# backward can be called manually in the training loop
if isinstance(result, torch.Tensor):
self.trainer.accelerator_backend.backward(result, optimizer, opt_idx, should_accumulate, *args, **kwargs)
self.trainer.accelerator.backward(result, optimizer, opt_idx, should_accumulate, *args, **kwargs)
else:
result.closure_loss = self.trainer.accelerator_backend.backward(
result.closure_loss = self.trainer.accelerator.backward(
result.closure_loss, optimizer, opt_idx, should_accumulate, *args, **kwargs
)

Expand Down
Loading

0 comments on commit 6cc1a06

Please sign in to comment.