Skip to content

Commit

Permalink
Merge branch 'dev' into set-epoch-on-batch-sampler
Browse files Browse the repository at this point in the history
  • Loading branch information
mvpatel2000 authored Mar 19, 2024
2 parents ebf08ba + ee13dea commit 660d7fd
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 13 deletions.
8 changes: 6 additions & 2 deletions composer/loggers/mosaicml_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,12 +271,16 @@ def dict_to_str(data: Dict[str, Any]):

def exception_to_json_serializable_dict(exc: Exception):
"""Converts exception into a JSON serializable dictionary for run metadata."""
default_exc_attrs = set(dir(Exception()))
exc_data = {'class': exc.__class__.__name__, 'message': str(exc), 'attributes': {}}

for attr in dir(exc):
if not attr.startswith('__') and attr not in ['args', 'with_traceback']:
# ignore the traceback and default args in exception object
# Exclude default attributes and special methods
if attr not in default_exc_attrs and not attr.startswith('__'):
try:
value = getattr(exc, attr)
if callable(value):
continue
if isinstance(value, (str, int, float, bool, list, dict, type(None))):
exc_data['attributes'][attr] = value
else:
Expand Down
8 changes: 4 additions & 4 deletions composer/models/initializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,19 +29,19 @@ def get_initializer(self) -> Callable[[torch.nn.Module], None]:
"""

def kaiming_normal(w: nn.Module):
if isinstance(w, torch.nn.Linear) or isinstance(w, torch.nn.Conv2d):
if isinstance(w, (torch.nn.Linear, torch.nn.Conv2d)):
torch.nn.init.kaiming_normal_(w.weight)

def kaiming_uniform(w: nn.Module):
if isinstance(w, torch.nn.Linear) or isinstance(w, torch.nn.Conv2d):
if isinstance(w, (torch.nn.Linear, torch.nn.Conv2d)):
torch.nn.init.kaiming_uniform_(w.weight)

def xavier_uniform(w: nn.Module):
if isinstance(w, torch.nn.Linear) or isinstance(w, torch.nn.Conv2d):
if isinstance(w, (torch.nn.Linear, torch.nn.Conv2d)):
torch.nn.init.xavier_uniform_(w.weight)

def xavier_normal(w: nn.Module):
if isinstance(w, torch.nn.Linear) or isinstance(w, torch.nn.Conv2d):
if isinstance(w, (torch.nn.Linear, torch.nn.Conv2d)):
torch.nn.init.xavier_normal_(w.weight)

def bn_ones(w: nn.Module):
Expand Down
5 changes: 2 additions & 3 deletions composer/utils/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -1038,7 +1038,6 @@ def _save_checkpoint(
'integrations': state._get_integrations_state_dict(),
'metadata': state._get_state_metadata(),
},
'rng': reproducibility.get_rng_state(),
}
else:
state_dict = {
Expand All @@ -1055,7 +1054,7 @@ def _save_checkpoint(
# Ensure state exists
state_dict['state'] = state_dict.get('state', {})

if state.fsdp_sharded_state_dict_enabled:
if state.fsdp_sharded_state_dict_enabled and not weights_only:
# Only rank 0 saves RNG
if dist.get_global_rank() > 0:
state_dict.pop('rng')
Expand All @@ -1064,7 +1063,7 @@ def _save_checkpoint(
# requires a top level state dict key for the optimizer.
# See https://github.com/pytorch/pytorch/blob/v2.0.1/torch/distributed/checkpoint/optimizer.py#L271
# for more info.
if version.parse(torch.__version__) < version.parse('2.2.9') and not weights_only:
if version.parse(torch.__version__) < version.parse('2.2.9'):
state_dict['optimizers'] = state_dict['state'].pop('optimizers')

log.debug('State dict created.')
Expand Down
2 changes: 1 addition & 1 deletion composer/utils/object_store/uc_object_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

def _wrap_errors(uri: str, e: Exception):
from databricks.sdk.core import DatabricksError
from databricks.sdk.errors.mapping import NotFound
from databricks.sdk.errors.platform import NotFound
if isinstance(e, DatabricksError):
if isinstance(e, NotFound) or e.error_code == _NOT_FOUND_ERROR_CODE: # type: ignore
raise FileNotFoundError(f'Object {uri} not found') from e
Expand Down
6 changes: 3 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def package_files(prefix: str, directory: str, extension: str):
'coolname>=1.1.0,<3',
'tabulate==0.9.0', # for auto-generating tables
'py-cpuinfo>=8.0.0,<10',
'packaging>=21.3.0,<23.3',
'packaging>=21.3.0,<24.1',
'importlib-metadata>=5.0.0,<7',
'mosaicml-cli>=0.5.25,<0.7',
]
Expand All @@ -102,7 +102,7 @@ def package_files(prefix: str, directory: str, extension: str):
# Should manually update dependency versions occasionally.
'custom_inherit==2.4.1',
'junitparser==3.1.2',
'coverage[toml]==7.4.3',
'coverage[toml]==7.4.4',
'fasteners==0.18', # object store tests require fasteners
'pytest==7.4.4',
'ipython==8.11.0',
Expand Down Expand Up @@ -226,7 +226,7 @@ def package_files(prefix: str, directory: str, extension: str):

extra_deps['pandas'] = ['pandas>=2.0.0,<3.0']

extra_deps['databricks'] = ['databricks-sdk==0.18.0']
extra_deps['databricks'] = ['databricks-sdk==0.22.0']

extra_deps['all'] = {dep for deps in extra_deps.values() for dep in deps}

Expand Down

0 comments on commit 660d7fd

Please sign in to comment.