
Commit

client: attachment unpack and more efficient submission (#788)
* client: add Attachment.unpack()
* client: catch payload size, skip_dupe_check flag, count processed
tschaume committed Oct 15, 2021
1 parent 5fc80f0 commit 7ab62d3
Showing 2 changed files with 88 additions and 44 deletions.
131 changes: 87 additions & 44 deletions mpcontribs-client/mpcontribs/client/__init__.py
@@ -48,12 +48,13 @@
from pint.unit import UnitDefinition
from pint.converters import ScaleConverter
from pint.errors import DimensionalityError
from datetime import datetime

RETRIES = 3
MAX_WORKERS = 8
MAX_ELEMS = 10
MAX_BYTES = 1200 * 1024
MEGABYTES = 1024 * 1024
MAX_BYTES = 1.2 * MEGABYTES
MAX_PAYLOAD = 15 * MEGABYTES
DEFAULT_HOST = "contribs-api.materialsproject.org"
BULMA = "is-narrow is-fullwidth has-background-light"
PROVIDERS = {"github", "google", "facebook", "microsoft", "amazon"}
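The single MAX_BYTES constant is replaced by three size limits: MEGABYTES is one mebibyte (1,048,576 bytes), MAX_BYTES stays at roughly 1.2 MiB per serialized component, and the new MAX_PAYLOAD caps each request body at 15 MiB (15,728,640 bytes). A minimal sketch of how such a cap is applied before submitting a chunk, mirroring the size check added further down in this diff (fits_in_one_request is a hypothetical helper, not part of the client):

    import ujson

    MEGABYTES = 1024 * 1024
    MAX_PAYLOAD = 15 * MEGABYTES  # 15,728,640 bytes per request body

    def fits_in_one_request(chunk: list) -> bool:
        # serialize the chunk the same way submit_contributions does before POSTing
        payload = ujson.dumps(chunk).encode("utf-8")
        return len(payload) < MAX_PAYLOAD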
@@ -177,8 +178,10 @@ def _response_hook(resp, *args, **kwargs):

elif content_type == "application/gzip":
resp.result = resp.content
resp.count = 1
else:
print("ERROR", resp.status_code)
resp.count = 0


class FidoClientGlobalHeaders(FidoClient):
@@ -261,6 +264,14 @@ def decode(self) -> str:
"""Decode base64-encoded content of attachment"""
return b64decode(self["content"], validate=True)

def unpack(self) -> str:
unpacked = self.decode()

if self["mime"] == "application/gzip":
unpacked = gzip.decompress(unpacked).decode("utf-8")

return unpacked

def write(self, outdir: Union[str, Path] = None) -> Path:
"""Write attachment to file using its name
@@ -346,10 +357,11 @@ def _run_futures(futures, total: int = 0, timeout: int = -1, desc=None):

if hasattr(future, "track_id"):
tid = future.track_id
responses[tid] = {}
if hasattr(response, "result"):
responses[tid] = response.result
elif hasattr(response, "count"):
responses[tid] = {"count": response.count}
responses[tid]["result"] = response.result
if hasattr(response, "count"):
responses[tid]["count"] = response.count

elapsed = time.perf_counter() - start
timed_out = timeout > 0 and elapsed > timeout
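_run_futures now returns one dict per tracked future instead of a bare value: the parsed body (if any) lives under "result" and the processed-row count under "count", so a response can carry both. Callers throughout this diff are updated to the new shape; a sketch of the access pattern (variable names are illustrative):

    responses = _run_futures(futures, total=ncontribs, timeout=timeout, desc="Submit")

    for track_id, resp in responses.items():
        body = resp.get("result")      # parsed response payload, if present
        count = resp.get("count", 0)   # number of contributions processed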
@@ -905,15 +917,15 @@ def get_totals(
query = query or {}
skip_keys = {"per_page", "_fields", "format"}
query = {k: v for k, v in query.items() if k not in skip_keys}
query["_fields"] = ["id"]
query["_fields"] = [] # only need totals -> explicitly request no fields
queries = self._split_query(query, resource=resource, op=op) # don't paginate
result = {"total_count": 0, "total_pages": 0}
futures = [self._get_future(i, q, rel_url=resource) for i, q in enumerate(queries)]
responses = _run_futures(futures, timeout=timeout, desc="Totals")

for resp in responses.values():
for k in result:
result[k] += resp[k]
result[k] += resp["result"][k]

return result["total_count"], result["total_pages"]

@@ -1033,7 +1045,7 @@ def get_all_ids(
responses = _run_futures(futures, timeout=timeout, desc="Identifiers")

for resp in responses.values():
for contrib in resp["data"]:
for contrib in resp["result"]["data"]:
project = contrib["project"]
data_id_field = data_id_fields.get(project)

@@ -1237,7 +1249,8 @@ def submit_contributions(
ignore_dupes: bool = False,
retry: bool = False,
per_request: int = 100,
timeout: int = -1
timeout: int = -1,
skip_dupe_check: bool = False
):
"""Submit a list of contributions
@@ -1267,6 +1280,7 @@ def submit_contributions(
retry (bool): keep trying until all contributions successfully submitted
per_request (int): number of contributions to submit per request
timeout (int): cancel remaining requests if timeout exceeded (in seconds)
skip_dupe_check (bool): skip duplicate check for contribution identifiers
"""
if not contributions or not isinstance(contributions, list):
print("Please provide list of contributions to submit.")
@@ -1306,9 +1320,10 @@ def submit_contributions(
id2project[cid] = project_name

existing = defaultdict(dict)
unique_identifiers = defaultdict(dict)
project_names = list(project_names)

if len(collect_ids) != len(contributions):
if not skip_dupe_check and len(collect_ids) != len(contributions):
unique_identifiers = self.get_unique_identifiers_flags(projects=project_names)
existing = defaultdict(dict, self.get_all_ids(
dict(project__in=project_names), include=COMPONENTS, timeout=timeout
@@ -1328,8 +1343,8 @@ def submit_contributions(
update = "id" in contrib
project_name = id2project[contrib["id"]] if update else contrib["project"]
if (
not update and unique_identifiers[project_name]
and contrib["identifier"] in existing[project_name].get("identifiers", {})
not update and unique_identifiers.get(project_name)
and contrib["identifier"] in existing.get(project_name, {}).get("identifiers", {})
):
continue

@@ -1434,7 +1449,7 @@ def submit_contributions(

dupe = bool(
digest in digests[project_name][component] or
digest in existing[project_name].get(component, {}).get("md5s", set())
digest in existing.get(project_name, {}).get(component, {}).get("md5s", [])
)

if not ignore_dupes and dupe:
@@ -1458,59 +1473,87 @@ def submit_contributions(
self.session.close()
self.session = get_session(max_workers=2)

def post_future(chunk):
return self.session.post(
def post_future(track_id, payload):
future = self.session.post(
f"{self.url}/contributions/",
headers=self.headers,
data=ujson.dumps(chunk).encode("utf-8"),
data=payload,
)
setattr(future, "track_id", track_id)
return future

def put_future(cdct):
pk = cdct.pop("id")
return self.session.put(
def put_future(pk, payload):
future = self.session.put(
f"{self.url}/contributions/{pk}/",
headers=self.headers,
data=ujson.dumps(cdct).encode("utf-8"),
data=payload,
)
setattr(future, "track_id", pk)
return future

for project_name in project_names:
ncontribs = len(contribs[project_name])
total += ncontribs
retries = 0

while contribs[project_name]:
futures = []
for chunk in grouper(per_page, contribs[project_name]):
for idx, chunk in enumerate(grouper(per_page, contribs[project_name])):
post_chunk = []
for c in chunk:
if "id" in c:
futures.append(put_future(c))
pk = c.pop("id")
payload = ujson.dumps(c).encode("utf-8")
if len(payload) < MAX_PAYLOAD:
futures.append(put_future(pk, payload))
else:
print(
f"SKIPPED update of {project_name}/{pk}: too large."
)
else:
post_chunk.append(c)

if post_chunk:
futures.append(post_future(post_chunk))
payload = ujson.dumps(post_chunk).encode("utf-8")
if len(payload) < MAX_PAYLOAD:
futures.append(post_future(idx, payload))
else:
print(
f"SKIPPED {project_name}/{idx}: too large, reduce per_request"
)

responses = _run_futures(futures, total=ncontribs, timeout=timeout, desc="Submit")
processed = sum(r.count for r in responses.values())
if not futures:
break # nothing to do

responses = _run_futures(
futures, total=ncontribs, timeout=timeout, desc="Submit"
)
processed = sum(r["count"] for r in responses.values())
total_processed += processed
print("PROCESSED", processed)

if processed != ncontribs:
if retry and unique_identifiers[project_name]:
existing[project_name] = self.get_all_ids(
dict(project=project_name), include=COMPONENTS, timeout=timeout
).get(project_name, {"identifiers": set()})
unique_identifiers[project_name] = self.projects.get_entry(
pk=project_name, _fields=["unique_identifiers"]
).result()["unique_identifiers"]
contribs[project_name] = [
c for c in contribs[project_name]
if c["identifier"] not in existing[project_name]["identifiers"]
]
else:
contribs[project_name] = [] # abort retrying
if retry and not unique_identifiers[project_name]:
print("Please resubmit failed contributions manually.")

if processed != ncontribs and retry and retries < RETRIES and \
unique_identifiers.get(project_name):
existing[project_name] = self.get_all_ids(
dict(project=project_name), include=COMPONENTS, timeout=timeout
).get(project_name, {"identifiers": set()})
unique_identifiers[project_name] = self.projects.get_entry(
pk=project_name, _fields=["unique_identifiers"]
).result()["unique_identifiers"]
existing_ids = existing.get(project_name, {}).get("identifiers", [])
contribs[project_name] = [
c for c in contribs[project_name]
if c["identifier"] not in existing_ids
]
retries += 1
else:
contribs[project_name] = [] # abort retrying
if processed != ncontribs and retry:
if retries >= RETRIES:
print(f"{project_name}: Tried {RETRIES} times - abort.")
elif not unique_identifiers.get(project_name):
print(
f"{project_name}: resubmit failed contributions manually"
)

toc = time.perf_counter()
dt = (toc - tic) / 60
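With the reworked loop, a failed pass is retried at most RETRIES (3) times per project, and only when that project enforces unique identifiers; otherwise the client asks the user to resubmit failed contributions manually. A usage sketch of the retry path (the timeout value is an example):

    client.submit_contributions(
        contributions,
        retry=True,   # drop already-submitted identifiers and retry, up to RETRIES passes
        timeout=600,  # cancel remaining requests after 10 minutes
    )
    # each pass prints "PROCESSED <n>"; after RETRIES unsuccessful passes:
    # "my_project: Tried 3 times - abort."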
@@ -1730,7 +1773,7 @@ def _download_resource(
responses = _run_futures(futures, timeout=timeout)

for path, resp in responses.items():
path.write_bytes(resp)
path.write_bytes(resp["result"])
paths.append(path)

return paths
1 change: 1 addition & 0 deletions mpcontribs-client/tox.ini
@@ -23,6 +23,7 @@ omit = *test_*.py

[testenv]
deps =
flake8<4
pytest
pytest-flake8
pytest-pycodestyle
