diff --git a/composer/distributed/dist_strategy.py b/composer/distributed/dist_strategy.py
index ad08e172b8..1cc1044a02 100644
--- a/composer/distributed/dist_strategy.py
+++ b/composer/distributed/dist_strategy.py
@@ -209,7 +209,7 @@ def prepare_fsdp_module(
     Args:
         model (torch.nn.Module): The model to wrap.
         optimizers (torch.optim.Optimizer | Sequence[torch.optim.Optimizer], optional): The optimizer for `model`, assumed to have a single param group := model.parameters().
-        fsdp_config (dict[str, Any]): The FSDP config.
+        fsdp_config (FSDPConfig): The FSDP config.
         precision: (Precision): The precision being used by the Trainer, used to fill in defaults for FSDP `mixed_precision` settings.
         device (Device): The device being used by the Trainer.
         auto_microbatching (bool, optional): Whether or not auto microbatching is enabled.
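
For reference, a minimal sketch of calling `prepare_fsdp_module` with the typed config this diff documents. The import paths for `FSDPConfig`, `Precision`, and `DeviceGPU`, and the `sharding_strategy` field name, are assumptions based on typical Composer layout, not confirmed by this diff; only the parameter list comes from the docstring above.

import torch

from composer.core import Precision
from composer.devices import DeviceGPU
from composer.distributed.dist_strategy import prepare_fsdp_module
from composer.utils import FSDPConfig  # assumed import path

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# fsdp_config was previously a plain dict (dict[str, Any]); per this diff it
# is now a structured FSDPConfig object.
fsdp_config = FSDPConfig(sharding_strategy='FULL_SHARD')  # field name assumed

prepare_fsdp_module(
    model,
    optimizer,
    fsdp_config=fsdp_config,
    precision=Precision.AMP_BF16,
    device=DeviceGPU(),
    auto_microbatching=False,
)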