diff --git a/pyterrier/_artifact.py b/pyterrier/_artifact.py index 705b349a..b4c7e6eb 100644 --- a/pyterrier/_artifact.py +++ b/pyterrier/_artifact.py @@ -6,6 +6,8 @@ import sys import tarfile import tempfile +import requests +from datetime import datetime from hashlib import sha256 from pathlib import Path from typing import Any, Dict, Iterator, Optional, Tuple, Union @@ -222,6 +224,67 @@ def from_dataset(cls, dataset: str, variant: str, *, expected_sha256: Optional[s branch=f'{dataset}.{variant}', expected_sha256=expected_sha256) + def to_zenodo(self, *, pretty_name: Optional[str] = None, sandbox: bool = False) -> None: + """Upload this artifact to Zenodo. + + Args: + pretty_name: The human-readable name of the artifact. + sandbox: Whether to perform a test upload to the Zenodo sandbox. + """ + if sandbox: + base_url = 'https://sandbox.zenodo.org/api' + else: + base_url = 'https://zenodo.org/api' + + access_token = os.environ.get('ZENODO_TOKEN') + params = {'access_token': access_token} + + with tempfile.TemporaryDirectory() as d: + r = requests.post(f'{base_url}/deposit/depositions', params=params, json={}) + r.raise_for_status() + deposit_data = r.json() + sys.stderr.write("Created {}\n".format(deposit_data['links']['html'])) + try: + metadata = {} + sys.stderr.write("Building package.\n") + self.build_package(os.path.join(d, 'artifact.tar.lz4'), metadata_out=metadata) + z_meta = { + 'metadata': self._zenodo_metadata( + pretty_name=pretty_name, + zenodo_id=deposit_data['id'], + metadata=metadata, + ), + } + r = requests.put(deposit_data['links']['latest_draft'], params=params, json=z_meta) + r.raise_for_status() + sys.stderr.write("Uploading...\n") + for file in sorted(os.listdir(d)): + file_path = os.path.join(d, file) + with open(file_path, 'rb') as fin, \ + pt.io.TqdmReader(fin, total=os.path.getsize(file_path), desc=file) as fin: + r = requests.put( + '{}/{}'.format(deposit_data['links']['bucket'], file), + params={'access_token': access_token}, + data=fin) + r.raise_for_status() + except: + sys.stderr.write("Discarding {}\n".format(deposit_data['links']['html'])) + requests.post(deposit_data['links']['discard'], params=params, json={}) + raise + sys.stderr.write("Upload complete. Please complete the form at {} to publish this artifact.\n".format( + deposit_data['links']['html'])) + + @classmethod + def from_zenodo(cls, zenodo_id: str, *, expected_sha256: Optional[str] = None) -> 'Artifact': + """Load an artifact from Zenodo. + + Args: + zenodo_id: The Zenodo record ID of the artifact. + expected_sha256: The expected SHA-256 hash of the artifact. If provided, the downloaded artifact will be + verified against this hash and an error will be raised if the hash does not match. + """ + return cls.from_url(f'zenodo:{zenodo_id}', expected_sha256=expected_sha256) + def _package_files(self) -> Iterator[Tuple[str, Union[str, io.BytesIO]]]: has_pt_meta_file = False for root, dirs, files in os.walk(self.path): @@ -315,6 +378,57 @@ def _hf_readme(self, ``` ''' + def _zenodo_metadata(self, *, zenodo_id: str, pretty_name: Optional[str] = None, metadata: Dict) -> Optional[str]: + description = f''' +

Description

+ +

+TODO: What is the artifact? +

+ +

Usage

+ +
+# Load the artifact
+import pyterrier as pt
+artifact = pt.Artifact.from_zenodo({str(zenodo_id)!r})
+# TODO: Show how you use the artifact
+
+ +

Benchmarks

+ +

+TODO: Provide benchmarks for the artifact. +

+ +

Reproduction

+ +
+# TODO: Show how you constructed the artifact.
+
+ +

Metadata

+ +
+{json.dumps(metadata, indent=2)}
+
+''' + tags = ['pyterrier', 'pyterrier-artifact'] + if 'type' in metadata: + tags.append('pyterrier-artifact.{type}'.format(**metadata)) + if 'type' in metadata and 'format' in metadata: + tags.append('pyterrier-artifact.{type}.{format}'.format(**metadata)) + metadata = { + 'description': description, + 'upload_type': 'other', + 'publisher': 'Zenodo', + 'publication_date': datetime.today().strftime('%Y-%m-%d'), + 'keywords': tags, + } + if pretty_name: + metadata['title'] = pretty_name + return metadata + def build_package( self, package_path: Optional[str] = None, @@ -398,7 +512,7 @@ def manage_maxsize(_: None): if isinstance(file, io.BytesIO): file.seek(0) - if rel_path == 'pt_meta.json': + if rel_path == 'pt_meta.json' and metadata_out is not None: metadata_out.update(json.load(file)) file.seek(0) with pt.io.CallbackReader(file, manage_maxsize) as fin: @@ -407,7 +521,7 @@ def manage_maxsize(_: None): with open(file, 'rb') as fin, \ pt.io.CallbackReader(fin, manage_maxsize) as fin: tarout.addfile(tar_record, fin) - if rel_path == 'pt_meta.json': + if rel_path == 'pt_meta.json' and metadata_out is not None: with open(file, 'rb') as fin: metadata_out.update(json.load(fin)) @@ -416,9 +530,9 @@ def manage_maxsize(_: None): metadata['expected_sha256'] = sha256_fout.hexdigest() - metadata_out = stack.enter_context(pt.io.finalized_open(f'{package_path}.json', 't')) - json.dump(metadata, metadata_out) - metadata_out.write('\n') + metadata_outf = stack.enter_context(pt.io.finalized_open(f'{package_path}.json', 't')) + json.dump(metadata, metadata_outf) + metadata_outf.write('\n') if chunk_num == 0: # no chunking was actually done, can use provided name directly @@ -513,9 +627,16 @@ def _path_repr(path: str) -> str: def _hf_url_resolver(parsed_url: ParseResult) -> str: + # paths like: hf:macavaney/msmarco-passage.terrier org_repo = parsed_url.path # default to ref=main, but allow user to specify another branch, hash, etc, with abc/xyz@branch ref = 'main' if '@' in org_repo: org_repo, ref = org_repo.split('@', 1) return f'https://huggingface.co/datasets/{org_repo}/resolve/{ref}/artifact.tar.lz4' + + +def _zenodo_url_resolver(parsed_url: ParseResult) -> str: + # paths like: zenodo:111952 + zenodo_id = parsed_url.path + return f'https://zenodo.org/records/{zenodo_id}/files/artifact.tar.lz4' diff --git a/pyterrier/io.py b/pyterrier/io.py index e8fb2c7d..f4c7e5e9 100644 --- a/pyterrier/io.py +++ b/pyterrier/io.py @@ -619,6 +619,7 @@ def close(self) -> None: """Close the reader and the progress bar.""" super().close() self.reader.close() + self.pbar.close() class CallbackReader(_NosyReader): diff --git a/setup.py b/setup.py index a9d8b066..ffdffa92 100644 --- a/setup.py +++ b/setup.py @@ -90,7 +90,8 @@ def get_version(rel_path): 'terrier = pyterrier.terrier._metadata_adapter:terrier_artifact_metadata_adapter', ], 'pyterrier.artifact.url_protocol_resolver': [ - 'hf = pyterrier._artifact:_hf_url_resolver' + 'hf = pyterrier._artifact:_hf_url_resolver', + 'zenodo = pyterrier._artifact:_zenodo_url_resolver', ], }, install_requires=requirements,