Skip to content

Commit

Permalink
Merge branch 'dev' into set-epoch-on-batch-sampler
Browse files Browse the repository at this point in the history
  • Loading branch information
mvpatel2000 authored Apr 2, 2024
2 parents 14eb4f9 + b0d33c4 commit e821428
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 24 deletions.
7 changes: 5 additions & 2 deletions composer/core/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,11 @@ def sigterm_handler(signal, frame):
sys.exit(128 + signal)


signal.signal(signal.SIGTERM, sigterm_handler)
signal.signal(signal.SIGINT, sigterm_handler)
try:
signal.signal(signal.SIGTERM, sigterm_handler)
signal.signal(signal.SIGINT, sigterm_handler)
except ValueError:
log.warning('Failed to set signal handler. Checkpoints may not be flushed if the process is killed.')


def _get_default_passes():
Expand Down
4 changes: 2 additions & 2 deletions composer/trainer/mosaic_fsdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ def patch_pytorch():
from torch.distributed.fsdp import _runtime_utils
_runtime_utils._validate_and_get_hybrid_shard_state = lambda *args, **kwargs: None

elif version.parse(torch.__version__) < version.parse('2.3.0'):
# Monkey patch for torch < 2.2.2 ie torch == 2.2.1
elif version.parse(torch.__version__) < version.parse('2.2.9'):
# Monkey patch for torch < 2.3.0 ie torch == 2.2.1/2.2.2 currently
pass

elif version.parse(torch.__version__) < version.parse('2.3.1'):
Expand Down
22 changes: 2 additions & 20 deletions composer/utils/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,16 +227,7 @@ def __init__(
metadata_destination = os.path.join(self.destination_path, '.metadata')
if dist.get_local_rank() == 0:
metadata_path = str(Path(source_path) / Path('.metadata'))
if isinstance(object_store, ObjectStore):
object_store.download_object(
object_name=metadata_path,
filename=metadata_destination,
)
else:
object_store.download_file(
remote_file_name=metadata_path,
destination=metadata_destination,
)
download_object_or_file(metadata_path, metadata_destination, object_store)
dist.barrier()

# FileSystemReader takes in a root directory in its constructor, which is the dir where
Expand Down Expand Up @@ -385,16 +376,7 @@ def is_checkpoint_legacy_sharded(object_store: Optional[Union[LoggerDestination,
_, _, metadata_path = parse_uri(metadata_path)
with tempfile.TemporaryDirectory() as temp_dir:
metadata_destination = os.path.join(str(temp_dir), '.metadata')
if isinstance(object_store, ObjectStore):
object_store.download_object(
object_name=metadata_path,
filename=metadata_destination,
)
else:
object_store.download_file(
remote_file_name=metadata_path,
destination=metadata_destination,
)
download_object_or_file(metadata_path, metadata_destination, object_store)
return False
except FileNotFoundError:
return True
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ def package_files(prefix: str, directory: str, extension: str):
'cryptography==41.0.5',
'pytest-httpserver>=1.0.4,<1.1',
'setuptools<=59.5.0',
'pillow==9.3.0', # Matches the Pillow version listed in the Dockerfile
]

extra_deps['system_metrics_monitor'] = {
Expand Down
14 changes: 14 additions & 0 deletions tests/test_engine.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

import importlib
import logging
import os
import subprocess
import sys
import textwrap
import threading
from pathlib import Path
from typing import List
from unittest.mock import Mock
Expand Down Expand Up @@ -323,3 +325,15 @@ def test_logging(
('composer.core.engine', 10, 'Post-closing callback EventCounterCallback'),
('composer.core.engine', 10, 'Engine closed.'),
]


def _worker():
import composer.core.engine
importlib.reload(composer.core.engine)


def test_graceful_fallback_when_signal_handler_cannot_be_set():
# https://github.com/mosaicml/composer/issues/3151#issue-2205981731
t = threading.Thread(target=_worker)
t.start()
t.join()

0 comments on commit e821428

Please sign in to comment.