Skip to content

Commit

Permalink
[WB-4604] Artifact checkout API (#1901)
Browse files Browse the repository at this point in the history
  • Loading branch information
annirudh authored Feb 26, 2021
1 parent 181fc47 commit 9238851
Show file tree
Hide file tree
Showing 7 changed files with 119 additions and 22 deletions.
6 changes: 4 additions & 2 deletions standalone_tests/artifact_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def proc_version_reader(stop_queue, stats_queue, project_name, artifact_name, re
continue
print('Reader downloading: ', version)
try:
version.download('read-%s' % reader_id)
version.checkout('read-%s' % reader_id)
except:
stats_queue.put({'read_download_error': 1})
print('Reader caught error on version.download')
Expand Down Expand Up @@ -340,7 +340,9 @@ def main(argv):
stats_queue,
project_name,
artifact_name,
i))
i
)
)
p.start()
procs.append(p)

Expand Down
14 changes: 14 additions & 0 deletions tests/test_public_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,20 @@ def test_artifact_download(runner, mock_server, api):
else:
part = "mnist:v0"
assert path == os.path.join(".", "artifacts", part)
assert os.listdir(path) == ["digits.h5"]


def test_artifact_checkout(runner, mock_server, api):
    """Checkout must drop stray files and leave only the artifact's contents."""
    with runner.isolated_filesystem():
        checkout_dir = os.path.join(".", "artifacts", "mnist")

        # Seed the target directory with a file that is not in the artifact;
        # checkout is expected to remove it.
        os.makedirs(checkout_dir)
        stray_file = os.path.join(checkout_dir, "bogus")
        with open(stray_file, "w") as handle:
            handle.write("delete me, i'm a bogus file")

        artifact = api.artifact("entity/project/mnist:v0", type="dataset")
        result = artifact.checkout()

        # The default checkout root omits the version suffix.
        assert result == checkout_dir
        assert os.listdir(result) == ["digits.h5"]


def test_artifact_run_used(runner, mock_server, api):
Expand Down
75 changes: 55 additions & 20 deletions wandb/apis/public.py
Original file line number Diff line number Diff line change
Expand Up @@ -2916,6 +2916,55 @@ def download(self, root=None, recursive=False):

return dirpath

def checkout(self, root=None):
    """Make a directory mirror this artifact, deleting files that do not belong.

    Every file found under the target directory is looked up with
    `self.get_path`; files for which that lookup raises `KeyError` are
    removed from disk. The artifact is then downloaded into the directory.

    Arguments:
        root: (str, optional) Directory to check out into. When falsy, the
            value of `self._default_root(include_version=False)` is used.

    Returns:
        Whatever `self.download(root=...)` returns for the target directory.
    """
    target_dir = root or self._default_root(include_version=False)

    for current_dir, _subdirs, filenames in os.walk(target_dir):
        for filename in filenames:
            local_path = os.path.join(current_dir, filename)
            # Artifact entries are looked up by forward-slash relative path.
            relative_path = os.path.relpath(local_path, start=target_dir)
            artifact_path = util.to_forward_slash_path(relative_path)
            try:
                self.get_path(artifact_path)
            except KeyError:
                # File is not part of the artifact, remove it.
                os.remove(local_path)

    return self.download(root=target_dir)

def verify(self, root=None):
    """Check that a directory holds exactly this artifact's contents.

    First, every file under the directory must be resolvable through
    `self.get_path`. Second, every manifest entry without a reference must
    have an on-disk MD5 digest equal to the digest recorded in the manifest.
    Entries that carry a reference are not checked; they are only counted
    and reported in a printed warning.

    Arguments:
        root: (str, optional) Directory to verify. When falsy, the value of
            `self._default_root()` is used.

    Raises:
        ValueError: If a file on disk is not a member of the artifact, or if
            a file's digest does not match the manifest.
    """
    target_dir = root or self._default_root()
    manifest = self._load_manifest()

    # Pass 1: reject any file on disk that the artifact does not know about.
    for current_dir, _subdirs, filenames in os.walk(target_dir):
        for filename in filenames:
            local_path = os.path.join(current_dir, filename)
            relative_path = os.path.relpath(local_path, start=target_dir)
            artifact_path = util.to_forward_slash_path(relative_path)
            try:
                self.get_path(artifact_path)
            except KeyError:
                raise ValueError(
                    "Found file {} which is not a member of artifact {}".format(
                        local_path, self.name
                    )
                )

    # Pass 2: compare digests for entries stored as files; skip references.
    skipped_refs = 0
    for entry in manifest.entries.values():
        if entry.ref is not None:
            skipped_refs += 1
            continue
        local_digest = artifacts.md5_file_b64(os.path.join(target_dir, entry.path))
        if local_digest != entry.digest:
            raise ValueError("Digest mismatch for file: %s" % entry.path)

    if skipped_refs > 0:
        print("Warning: skipped verification of %s refs" % skipped_refs)

def file(self, root=None):
"""Download a single file artifact to dir specified by the <root>
Expand Down Expand Up @@ -2943,8 +2992,12 @@ def _download_file(self, name, root):
# download file into cache and copy to target dir
return self.get_path(name).download(root)

def _default_root(self):
root = os.path.join(".", "artifacts", self.name)
def _default_root(self, include_version=True):
root = (
os.path.join(".", "artifacts", self.name)
if include_version
else os.path.join(".", "artifacts", self._sequence_name)
)
if platform.system() == "Windows":
head, tail = os.path.splitdrive(root)
root = head + tail.replace(":", "-")
Expand Down Expand Up @@ -2990,24 +3043,6 @@ def save(self):
)
return True

def verify(self, root=None):
    """Compare on-disk file digests against this artifact's manifest.

    Manifest entries without a reference are hashed from disk and compared
    with the recorded digest. Entries that carry a reference are skipped,
    and the number skipped is reported in a printed warning.

    Arguments:
        root: (str, optional) Directory holding the artifact files. When
            None, "./artifacts/<self.name>" is used.

    Raises:
        ValueError: If a file's digest does not match the manifest.
    """
    if root is None:
        dirpath = os.path.join(".", "artifacts", self.name)
    else:
        dirpath = root

    manifest = self._load_manifest()

    skipped_refs = 0
    for entry in manifest.entries.values():
        if entry.ref is not None:
            skipped_refs += 1
            continue
        local_digest = artifacts.md5_file_b64(os.path.join(dirpath, entry.path))
        if local_digest != entry.digest:
            raise ValueError("Digest mismatch for file: %s" % entry.path)

    if skipped_refs > 0:
        print("Warning: skipped verification of %s refs" % skipped_refs)

def wait(self):
    """Return this object unchanged; nothing is awaited here."""
    return self

Expand Down
15 changes: 15 additions & 0 deletions wandb/sdk/interface/artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,6 +573,21 @@ def download(self, root: Optional[str] = None, recursive: bool = False) -> str:
"""
raise NotImplementedError

def checkout(self, root: Optional[str] = None) -> str:
    """
    Replace the contents of `root` with this artifact's files.

    WARNING: Any file in `root` that is not included in the artifact will be
    DELETED.

    Arguments:
        root: (str, optional) The directory to replace with this artifact's files.

    Returns:
        (str): The path to the checked out contents.
    """
    raise NotImplementedError

def verify(self, root: Optional[str] = None):
"""
Verify that the actual contents of an artifact at a specified directory
Expand Down
8 changes: 8 additions & 0 deletions wandb/sdk/wandb_artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,14 @@ def download(self, root: str = None, recursive: bool = False):
"Cannot call download on an artifact before it has been logged or in offline mode"
)

def checkout(self, root: Optional[str] = None) -> str:
    """Forward the checkout to the logged artifact, if there is one.

    Arguments:
        root: (str, optional) Passed through as `root` to the logged
            artifact's `checkout`.

    Returns:
        The value returned by the logged artifact's `checkout`.

    Raises:
        ValueError: If `self._logged_artifact` is falsy.
    """
    if not self._logged_artifact:
        raise ValueError(
            "Cannot call checkout on an artifact before it has been logged or in offline mode"
        )

    return self._logged_artifact.checkout(root=root)

def verify(self, root: Optional[str] = None):
if self._logged_artifact:
return self._logged_artifact.verify(root=root)
Expand Down
15 changes: 15 additions & 0 deletions wandb/sdk_py27/interface/artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,6 +573,21 @@ def download(self, root = None, recursive = False):
"""
raise NotImplementedError

def checkout(self, root = None):
    """
    Replace the contents of `root` with this artifact's files.

    WARNING: Any file in `root` that is not included in the artifact will be
    DELETED.

    Arguments:
        root: (str, optional) The directory to replace with this artifact's files.

    Returns:
        (str): The path to the checked out contents.
    """
    raise NotImplementedError

def verify(self, root = None):
"""
Verify that the actual contents of an artifact at a specified directory
Expand Down
8 changes: 8 additions & 0 deletions wandb/sdk_py27/wandb_artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,14 @@ def download(self, root = None, recursive = False):
"Cannot call download on an artifact before it has been logged or in offline mode"
)

def checkout(self, root = None):
    """Forward the checkout to the logged artifact, if there is one.

    Arguments:
        root: (str, optional) Passed through as `root` to the logged
            artifact's `checkout`.

    Returns:
        The value returned by the logged artifact's `checkout`.

    Raises:
        ValueError: If `self._logged_artifact` is falsy.
    """
    if not self._logged_artifact:
        raise ValueError(
            "Cannot call checkout on an artifact before it has been logged or in offline mode"
        )

    return self._logged_artifact.checkout(root=root)

def verify(self, root = None):
if self._logged_artifact:
return self._logged_artifact.verify(root=root)
Expand Down

0 comments on commit 9238851

Please sign in to comment.