Add AWS S3 i/o #2175

Closed. Wants to merge 6 commits.
13 changes: 12 additions & 1 deletion pytorch_lightning/callbacks/model_checkpoint.py
@@ -16,6 +16,7 @@
from pytorch_lightning import _logger as log
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities import rank_zero_warn, rank_zero_only
import pytorch_lightning.utilities.cloud_io as cloud_io


class ModelCheckpoint(Callback):
@@ -95,7 +96,7 @@ class ModelCheckpoint(Callback):

    def __init__(self, filepath: Optional[str] = None, monitor: str = 'val_loss', verbose: bool = False,
                 save_last: bool = False, save_top_k: int = 1, save_weights_only: bool = False,
                 mode: str = 'auto', period: int = 1, prefix: str = ''):
                 mode: str = 'auto', period: int = 1, prefix: str = '', remove_non_top_k_s3_files: bool = True):
Review comment (Member):
Here I would keep it synced, so I would drop `remove_non_top_k_s3_files`.

        super().__init__()
        if save_top_k > 0 and filepath is not None and os.path.isdir(filepath) and len(os.listdir(filepath)) > 0:
            rank_zero_warn(
@@ -109,6 +110,10 @@ def __init__(self, filepath: Optional[str] = None, monitor: str = 'val_loss', ...
        self.save_to_s3 = False  # default; flipped below when an S3 path is given
        if filepath is None:  # will be determined by trainer at runtime
            self.dirpath, self.filename = None, None
        else:
            if cloud_io.is_s3_path(filepath):
                self.save_to_s3 = True
                self.bucket, filepath = cloud_io.parse_s3_path(filepath)

            if os.path.isdir(filepath):
                self.dirpath, self.filename = filepath, '{epoch}'
            else:
@@ -127,6 +132,7 @@ def __init__(self, filepath: Optional[str] = None, monitor: str = 'val_loss', ...
        self.best_model_score = 0
        self.best_model_path = ''
        self.save_function = None
        self.remove_non_top_k_s3_files = remove_non_top_k_s3_files

        torch_inf = torch.tensor(np.Inf)
        mode_dict = {
@@ -158,6 +164,9 @@ def kth_best_model(self):
    def _del_model(self, filepath):
        if os.path.isfile(filepath):
            os.remove(filepath)
        if self.save_to_s3 and self.remove_non_top_k_s3_files:
            cloud_io.remove_checkpoint_from_s3(self.bucket, filepath)

    def _save_model(self, filepath):
        # make paths
@@ -168,6 +177,8 @@ def _save_model(self, filepath):
            self.save_function(filepath, self.save_weights_only)
        else:
            raise ValueError(".save_function() not set")
        if self.save_to_s3:
            cloud_io.save_checkpoint_to_s3(self.bucket, filepath)

    def check_monitor_top_k(self, current):
        less_than_k_models = len(self.best_k_models) < self.save_top_k
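With these changes, `ModelCheckpoint` accepts an S3 destination directly. A minimal sketch of the intended usage, with a hypothetical bucket name; note that, as the diff implements it, the key portion of the path doubles as the local staging path:

    from pytorch_lightning.callbacks import ModelCheckpoint

    # 'my-bucket' is a hypothetical bucket: checkpoints are written locally
    # under 'checkpoints/' and mirrored to s3://my-bucket/checkpoints/.
    # Files dropped from the top k are also deleted from S3 unless
    # remove_non_top_k_s3_files=False.
    checkpoint_callback = ModelCheckpoint(
        filepath='s3://my-bucket/checkpoints/',
        monitor='val_loss',
        save_top_k=3,
    )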
29 changes: 16 additions & 13 deletions pytorch_lightning/core/saving.py
@@ -40,19 +40,20 @@ def load_from_metrics(cls, weights_path, tags_csv, map_location=None):
"""
rank_zero_warn(
"`load_from_metrics` method has been unified with `load_from_checkpoint` in v0.7.0."
" The deprecated method will be removed in v0.9.0.", DeprecationWarning
" The deprecated method will be removed in v0.9.0.",
DeprecationWarning,
)
return cls.load_from_checkpoint(weights_path, tags_csv=tags_csv, map_location=map_location)

    @classmethod
    def load_from_checkpoint(
        cls,
        checkpoint_path: str,
        *args,
        map_location: Optional[Union[Dict[str, str], str, torch.device, int, Callable]] = None,
        hparams_file: Optional[str] = None,
        tags_csv: Optional[str] = None,  # backward compatible, todo: remove in v0.9.0
        **kwargs,
    ):
r"""
Primary way of loading a model from a checkpoint. When Lightning saves a checkpoint
@@ -136,10 +137,12 @@ def load_from_checkpoint(
            pretrained_model.freeze()
            y_hat = pretrained_model(x)
        """
        if map_location is not None:
            checkpoint = pl_load(checkpoint_path, map_location=map_location)
        else:
            checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage)
        if not map_location:

            def map_location(storage, loc):
                return storage

        checkpoint = pl_load(checkpoint_path, map_location=map_location)

        # add the hparams from csv file to checkpoint
        if tags_csv is not None:
@@ -193,7 +196,7 @@ def _load_model_state(cls, checkpoint: Dict[str, Any], *args, **kwargs):
            if args_name in init_args_name:
                kwargs.update({args_name: model_args})
            else:
                args = (model_args, ) + args
                args = (model_args,) + args

        # load the state_dict on the model automatically
        model = cls(*args, **kwargs)
98 changes: 96 additions & 2 deletions pytorch_lightning/utilities/cloud_io.py
@@ -1,12 +1,106 @@
import logging
import os
import os.path as osp
from pathlib import Path
from typing import Tuple
from urllib.parse import urlparse

import torch
from torch.hub import _get_torch_home

logger = logging.getLogger(__name__)

torch_cache_home = _get_torch_home()
default_cache_path = osp.join(torch_cache_home, "pl_checkpoints")


def try_import_boto3():
    try:
        import boto3
    except ImportError:
        raise ImportError('Could not import `boto3`. Please `pip install boto3` and try again.')
    # return the module: a plain `import` inside this function is local to its
    # scope, so callers must bind the result instead of relying on a global name
    return boto3


def load(path_or_url: str, map_location=None):
    parsed = urlparse(path_or_url)
    if parsed.scheme == '' or Path(path_or_url).is_file():
        # no scheme or local file
        return torch.load(path_or_url, map_location=map_location)
    elif parsed.scheme == 's3':
        # AWS S3 file
        filepath = download_checkpoint_from_s3(path_or_url)
        return torch.load(filepath, map_location=map_location)
    # generic URL
    return torch.hub.load_state_dict_from_url(path_or_url, map_location=map_location)


def is_s3_path(path: str):
    """Check whether the path points to S3 (i.e. starts with ``s3://``)."""
    return path.startswith("s3://")


def parse_s3_path(s3_path: str) -> Tuple[str, str]:
    """
    Return the bucket and key from an S3 path.
    Example: "s3://my-bucket/folder/checkpoint.ckpt" -> ("my-bucket", "folder/checkpoint.ckpt")
    """
    parsed = urlparse(s3_path, allow_fragments=True)
    assert parsed.scheme == 's3', f'{s3_path} is not a valid AWS S3 path. Needs to start with `s3://`'
    bucket, key = parsed.netloc, parsed.path
    if key.startswith('/'):
        key = key[1:]
    return bucket, key


def save_checkpoint_to_s3(bucket_name, key):
    """
    Upload a single checkpoint to an S3 bucket.
    Args:
        bucket_name: The name of the bucket we want to save to.
        key: The rest of the S3 path; also the local path of the file to upload.
    Returns:
        None
    """
    boto3 = try_import_boto3()
    bucket = boto3.resource("s3").Bucket(bucket_name)
    # the local staging path and the S3 key are intentionally the same here
    bucket.upload_file(Filename=key, Key=key)


def download_checkpoint_from_s3(path_or_url: str, overwrite: bool = False) -> str:
    """
    Download a file from S3 and save it in the default cache path under its original S3 key.
    Returns the filepath where the object has been downloaded.
    """
    boto3 = try_import_boto3()

    # e.g. "s3://bucket-name/folder/checkpoint.ckpt" -> ("bucket-name", "folder/checkpoint.ckpt")
    bucket_name, key = parse_s3_path(path_or_url)

    # ("folder", "checkpoint.ckpt")
    directory, filename = osp.split(key)

    # make the cache directory, e.g. '/Users/johnDoe/.cache/torch/pl_checkpoints/folder'
    directory_to_make = osp.join(default_cache_path, directory)
    os.makedirs(directory_to_make, exist_ok=True)

    # file we will download to, e.g. '/Users/johnDoe/.cache/torch/pl_checkpoints/folder/checkpoint.ckpt'
    filepath = osp.join(directory_to_make, filename)

    def _download():
        bucket = boto3.resource("s3").Bucket(bucket_name)
        bucket.download_file(Key=key, Filename=filepath)

    # skip the download when a cached copy exists, unless asked to overwrite
    if overwrite or not osp.exists(filepath):
        _download()
    return filepath


def remove_checkpoint_from_s3(bucket, key):
    """Remove a single object from S3."""
    boto3 = try_import_boto3()
    s3 = boto3.resource("s3")
    s3.Object(bucket, key).delete()
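
On the loading side, `cloud_io.load` dispatches on the URL scheme, so an `s3://` checkpoint path can be passed wherever a local path is accepted. A hedged sketch, assuming a hypothetical bucket and key and `boto3` credentials configured in the environment:

    import pytorch_lightning.utilities.cloud_io as cloud_io

    # downloads to <torch cache>/pl_checkpoints/exp1/epoch=3.ckpt on first use;
    # later calls reuse the cached copy unless overwrite is requested
    checkpoint = cloud_io.load('s3://my-bucket/exp1/epoch=3.ckpt')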
3 changes: 2 additions & 1 deletion requirements/devel.txt
@@ -7,4 +7,5 @@
# extended list of dependencies for development, linting, and tests
-r ./test.txt

cloudpickle>=1.2
boto3
1 change: 1 addition & 0 deletions requirements/test.txt
@@ -9,3 +9,4 @@ check-manifest
twine==1.13.0
black==19.10b0
pre-commit>=1.0
moto
20 changes: 20 additions & 0 deletions tests/conftest.py
@@ -5,6 +5,9 @@

import pytest
import torch.multiprocessing as mp
import boto3
from moto import mock_s3
import os


def pytest_configure(config):
@@ -55,3 +58,20 @@ class ThreadingHTTPServer(ThreadingMixIn, HTTPServer):
    server_thread.start()
    yield server.server_address
    server.shutdown()


@pytest.fixture(scope='function')
def aws_credentials():
    """Mocked AWS credentials for moto."""
    os.environ['AWS_ACCESS_KEY_ID'] = 'testing'
    os.environ['AWS_SECRET_ACCESS_KEY'] = 'testing'
    os.environ['AWS_SECURITY_TOKEN'] = 'testing'
    os.environ['AWS_SESSION_TOKEN'] = 'testing'


@pytest.fixture(scope='function')
def s3(aws_credentials):
    with mock_s3():
        s3 = boto3.client('s3')
        s3.create_bucket(Bucket='testing')
        yield s3
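
A sketch of how these fixtures might be exercised against the new `cloud_io` helpers; the test name, local file, and key are made up for illustration, and moto keeps all boto3 traffic in-process:

    import os

    import pytorch_lightning.utilities.cloud_io as cloud_io


    def test_download_checkpoint_from_s3(s3, tmpdir):
        # stage a fake checkpoint in the mocked 'testing' bucket
        local_file = tmpdir.join('epoch=0.ckpt')
        local_file.write('fake-checkpoint')
        s3.upload_file(str(local_file), 'testing', 'ckpts/epoch=0.ckpt')

        # the helper should download it into the default cache path
        filepath = cloud_io.download_checkpoint_from_s3('s3://testing/ckpts/epoch=0.ckpt')
        assert os.path.isfile(filepath)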