Skip to content

Commit

Permalink
zenodo integration
Browse files Browse the repository at this point in the history
  • Loading branch information
seanmacavaney committed Sep 25, 2024
1 parent 98a8246 commit 8d791d4
Show file tree
Hide file tree
Showing 3 changed files with 129 additions and 6 deletions.
131 changes: 126 additions & 5 deletions pyterrier/_artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import sys
import tarfile
import tempfile
import requests
from datetime import datetime
from hashlib import sha256
from pathlib import Path
from typing import Any, Dict, Iterator, Optional, Tuple, Union
Expand Down Expand Up @@ -222,6 +224,67 @@ def from_dataset(cls, dataset: str, variant: str, *, expected_sha256: Optional[s
branch=f'{dataset}.{variant}',
expected_sha256=expected_sha256)

def to_zenodo(self, *, pretty_name: Optional[str] = None, sandbox: bool = False) -> None:
"""Upload this artifact to Zenodo.
Args:
pretty_name: The human-readable name of the artifact.
sandbox: Whether to perform a test upload to the Zenodo sandbox.
"""
if sandbox:
base_url = 'https://sandbox.zenodo.org/api'
else:
base_url = 'https://zenodo.org/api'

access_token = os.environ.get('ZENODO_TOKEN')
params = {'access_token': access_token}

with tempfile.TemporaryDirectory() as d:
r = requests.post(f'{base_url}/deposit/depositions', params=params, json={})
r.raise_for_status()
deposit_data = r.json()
sys.stderr.write("Created {}\n".format(deposit_data['links']['html']))
try:
metadata = {}
sys.stderr.write("Building package.\n")
self.build_package(os.path.join(d, 'artifact.tar.lz4'), metadata_out=metadata)
z_meta = {
'metadata': self._zenodo_metadata(
pretty_name=pretty_name,
zenodo_id=deposit_data['id'],
metadata=metadata,
),
}
r = requests.put(deposit_data['links']['latest_draft'], params=params, json=z_meta)
r.raise_for_status()
sys.stderr.write("Uploading...\n")
for file in sorted(os.listdir(d)):
file_path = os.path.join(d, file)
with open(file_path, 'rb') as fin, \
pt.io.TqdmReader(fin, total=os.path.getsize(file_path), desc=file) as fin:
r = requests.put(
'{}/{}'.format(deposit_data['links']['bucket'], file),
params={'access_token': access_token},
data=fin)
r.raise_for_status()
except:
sys.stderr.write("Discarding {}\n".format(deposit_data['links']['html']))
requests.post(deposit_data['links']['discard'], params=params, json={})
raise
sys.stderr.write("Upload complete. Please complete the form at {} to publish this artifact.\n".format(
deposit_data['links']['html']))

@classmethod
def from_zenodo(cls, zenodo_id: str, *, expected_sha256: Optional[str] = None) -> 'Artifact':
"""Load an artifact from Zenodo.
Args:
zenodo_id: The Zenodo record ID of the artifact.
expected_sha256: The expected SHA-256 hash of the artifact. If provided, the downloaded artifact will be
verified against this hash and an error will be raised if the hash does not match.
"""
return cls.from_url(f'zenodo:{zenodo_id}', expected_sha256=expected_sha256)

def _package_files(self) -> Iterator[Tuple[str, Union[str, io.BytesIO]]]:
has_pt_meta_file = False
for root, dirs, files in os.walk(self.path):
Expand Down Expand Up @@ -315,6 +378,57 @@ def _hf_readme(self,
```
'''

def _zenodo_metadata(self, *, zenodo_id: str, pretty_name: Optional[str] = None, metadata: Dict) -> Optional[str]:
description = f'''
<h2>Description</h2>
<p>
<i>TODO: What is the artifact?</i>
</p>
<h2>Usage</h2>
<pre>
# Load the artifact
import pyterrier as pt
artifact = pt.Artifact.from_zenodo({str(zenodo_id)!r})
# TODO: Show how you use the artifact
</pre>
<h2>Benchmarks</h2>
<p>
<i>TODO: Provide benchmarks for the artifact.</i>
</p>
<h2>Reproduction</h2>
<pre>
# TODO: Show how you constructed the artifact.
</pre>
<h2>Metadata</h2>
<pre>
{json.dumps(metadata, indent=2)}
</pre>
'''
tags = ['pyterrier', 'pyterrier-artifact']
if 'type' in metadata:
tags.append('pyterrier-artifact.{type}'.format(**metadata))
if 'type' in metadata and 'format' in metadata:
tags.append('pyterrier-artifact.{type}.{format}'.format(**metadata))
metadata = {
'description': description,
'upload_type': 'other',
'publisher': 'Zenodo',
'publication_date': datetime.today().strftime('%Y-%m-%d'),
'keywords': tags,
}
if pretty_name:
metadata['title'] = pretty_name
return metadata

def build_package(
self,
package_path: Optional[str] = None,
Expand Down Expand Up @@ -398,7 +512,7 @@ def manage_maxsize(_: None):

if isinstance(file, io.BytesIO):
file.seek(0)
if rel_path == 'pt_meta.json':
if rel_path == 'pt_meta.json' and metadata_out is not None:
metadata_out.update(json.load(file))
file.seek(0)
with pt.io.CallbackReader(file, manage_maxsize) as fin:
Expand All @@ -407,7 +521,7 @@ def manage_maxsize(_: None):
with open(file, 'rb') as fin, \
pt.io.CallbackReader(fin, manage_maxsize) as fin:
tarout.addfile(tar_record, fin)
if rel_path == 'pt_meta.json':
if rel_path == 'pt_meta.json' and metadata_out is not None:
with open(file, 'rb') as fin:
metadata_out.update(json.load(fin))

Expand All @@ -416,9 +530,9 @@ def manage_maxsize(_: None):

metadata['expected_sha256'] = sha256_fout.hexdigest()

metadata_out = stack.enter_context(pt.io.finalized_open(f'{package_path}.json', 't'))
json.dump(metadata, metadata_out)
metadata_out.write('\n')
metadata_outf = stack.enter_context(pt.io.finalized_open(f'{package_path}.json', 't'))
json.dump(metadata, metadata_outf)
metadata_outf.write('\n')

if chunk_num == 0:
# no chunking was actually done, can use provided name directly
Expand Down Expand Up @@ -513,9 +627,16 @@ def _path_repr(path: str) -> str:


def _hf_url_resolver(parsed_url: ParseResult) -> str:
# paths like: hf:macavaney/msmarco-passage.terrier
org_repo = parsed_url.path
# default to ref=main, but allow user to specify another branch, hash, etc, with abc/xyz@branch
ref = 'main'
if '@' in org_repo:
org_repo, ref = org_repo.split('@', 1)
return f'https://huggingface.co/datasets/{org_repo}/resolve/{ref}/artifact.tar.lz4'


def _zenodo_url_resolver(parsed_url: ParseResult) -> str:
# paths like: zenodo:111952
zenodo_id = parsed_url.path
return f'https://zenodo.org/records/{zenodo_id}/files/artifact.tar.lz4'
1 change: 1 addition & 0 deletions pyterrier/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,6 +619,7 @@ def close(self) -> None:
"""Close the reader and the progress bar."""
super().close()
self.reader.close()
self.pbar.close()


class CallbackReader(_NosyReader):
Expand Down
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,8 @@ def get_version(rel_path):
'terrier = pyterrier.terrier._metadata_adapter:terrier_artifact_metadata_adapter',
],
'pyterrier.artifact.url_protocol_resolver': [
'hf = pyterrier._artifact:_hf_url_resolver'
'hf = pyterrier._artifact:_hf_url_resolver',
'zenodo = pyterrier._artifact:_zenodo_url_resolver',
],
},
install_requires=requirements,
Expand Down

0 comments on commit 8d791d4

Please sign in to comment.