Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow overwrite on upload retry in remote uploader downloader #3310

Merged
merged 6 commits into from
May 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions composer/loggers/remote_uploader_downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,8 +673,8 @@ def _upload_worker(

# defining as a function-in-function to use decorator notation with num_attempts as an argument
@retry(ObjectStoreTransientError, num_attempts=num_attempts)
def upload_file():
if not overwrite:
def upload_file(retry_index: int = 0):
if retry_index == 0 and not overwrite:
try:
remote_backend.get_object_size(remote_file_name)
except FileNotFoundError:
Expand Down
14 changes: 8 additions & 6 deletions composer/utils/retrying.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import collections.abc
import functools
import inspect
import logging
import random
import time
Expand Down Expand Up @@ -46,18 +47,16 @@ def retry( # type: ignore

Attempts are spaced out with ``initial_backoff + 2**num_attempts + random.random() * max_jitter`` seconds.

Optionally, the decorated function can specify `retry_index` as an argument to receive the current attempt number.

Example:
.. testcode::

from composer.utils import retry

num_tries = 0

@retry(RuntimeError, num_attempts=3, initial_backoff=0.1)
def flaky_function():
global num_tries
if num_tries < 2:
num_tries += 1
def flaky_function(retry_index: int):
if retry_index < 2:
raise RuntimeError("Called too soon!")
return "Third time's a charm."

Expand All @@ -84,9 +83,12 @@ def wrapped_func(func: TCallable) -> TCallable:

@functools.wraps(func)
def new_func(*args: Any, **kwargs: Any):
retry_index_param = 'retry_index'
i = 0
while True:
try:
if retry_index_param in inspect.signature(func).parameters:
kwargs[retry_index_param] = i
return func(*args, **kwargs)
except exc_class as e:
log.debug(f'Attempt {i} failed. Exception type: {type(e)}, message: {str(e)}.')
Expand Down
60 changes: 59 additions & 1 deletion tests/loggers/test_remote_uploader_downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from composer.core import Event, State
from composer.loggers import Logger, RemoteUploaderDownloader
from composer.utils.object_store.object_store import ObjectStore
from composer.utils.object_store.object_store import ObjectStore, ObjectStoreTransientError


class DummyObjectStore(ObjectStore):
Expand Down Expand Up @@ -190,6 +190,64 @@ def test_remote_uploader_downloader_no_overwrite(
)


def test_allow_overwrite_on_retry(tmp_path: pathlib.Path, dummy_state: State):
    """Verify that a partially-uploaded file can be overwritten on a retry attempt.

    Exercises the ``retry_index``-aware ``upload_file`` path in
    ``RemoteUploaderDownloader``: the backend fails the first two uploads and
    reports a non-zero object size afterwards (simulating a partial upload), so
    the third attempt must proceed even though ``overwrite=False`` was requested.
    """
    # Create a single sample file on disk to upload.
    file_path = tmp_path / 'samples' / 'sample'
    os.makedirs(tmp_path / 'samples')
    with open(file_path, 'w') as f:
        f.write('sample')

    # Dummy object store that fails the first two uploads
    # This tests that the remote uploader downloader allows overwriting a partially uploaded file on a retry.
    class RetryDummyObjectStore(DummyObjectStore):

        def __init__(
            self,
            dir: Optional[pathlib.Path] = None,
            always_fail: bool = False,
            **kwargs: Dict[str, Any],
        ) -> None:
            # Counts upload attempts; drives both the forced failures in
            # upload_object and the fake object size in get_object_size.
            self._retry = 0
            super().__init__(dir, always_fail, **kwargs)

        def upload_object(
            self,
            object_name: str,
            filename: Union[str, pathlib.Path],
            callback: Optional[Callable[[int, int], None]] = None,
        ) -> None:
            # Fail the first two attempts with a transient error so the
            # @retry decorator in the uploader worker re-invokes us.
            if self._retry < 2:
                self._retry += 1  # Takes two retries to upload the file
                raise ObjectStoreTransientError('Retry this')
            self._retry += 1
            return super().upload_object(object_name, filename, callback)

        def get_object_size(self, object_name: str) -> int:
            if self._retry > 0:
                return 1  # The 0th upload resulted in a partial upload
            return super().get_object_size(object_name)

    # Force the 'fork' start method so the patched classes below are inherited
    # by the uploader worker processes (patches would not survive 'spawn').
    fork_context = multiprocessing.get_context('fork')
    with patch('composer.loggers.remote_uploader_downloader.S3ObjectStore', RetryDummyObjectStore):
        with patch('composer.loggers.remote_uploader_downloader.multiprocessing.get_context', lambda _: fork_context):
            remote_uploader_downloader = RemoteUploaderDownloader(
                # NOTE(review): the stray apostrophe before 'object_store_backend'
                # looks unintentional but is harmless here — the backend_kwargs
                # 'dir' below is what actually determines the storage location.
                bucket_uri=f"s3://{tmp_path}/'object_store_backend",
                backend_kwargs={
                    'dir': tmp_path / 'object_store_backend',
                },
                num_concurrent_uploads=4,
                upload_staging_folder=str(tmp_path / 'staging_folder'),
                use_procs=True,
                num_attempts=3,  # two forced failures + one success
            )
            logger = Logger(dummy_state, destinations=[remote_uploader_downloader])

            remote_uploader_downloader.run_event(Event.INIT, dummy_state, logger)
            remote_file_name = 'remote_file_name'
            # overwrite=False: success despite the partial upload proves that
            # retries are allowed to overwrite.
            remote_uploader_downloader.upload_file(dummy_state, remote_file_name, file_path, overwrite=False)
            # close()/post_close() drain the upload queue; an unrecovered
            # failure in the worker would surface here.
            remote_uploader_downloader.close(dummy_state, logger=logger)
            remote_uploader_downloader.post_close()


@pytest.mark.parametrize('use_procs', [True, False])
def test_race_with_overwrite(tmp_path: pathlib.Path, use_procs: bool, dummy_state: State):
# Test a race condition with the object store logger where multiple files with the same name are logged in rapid succession
Expand Down
Loading