Skip to content

Commit

Permalink
client: remove shared state between instances (#1453)
Browse files Browse the repository at this point in the history
* rm shared state, cache version and session

* add url to get_version

* global session and executor
  • Loading branch information
tschaume committed Dec 1, 2022
1 parent e2ecf82 commit 000a5fb
Showing 1 changed file with 41 additions and 59 deletions.
100 changes: 41 additions & 59 deletions mpcontribs-client/mpcontribs/client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
from IPython.display import display, HTML, Image, FileLink
from boltons.iterutils import remap
from pymatgen.core import Structure as PmgStructure
from concurrent.futures import as_completed
from concurrent.futures import as_completed, ProcessPoolExecutor
from requests_futures.sessions import FuturesSession
from urllib3.util.retry import Retry
from filetype.types.archive import Gz
Expand Down Expand Up @@ -106,6 +106,8 @@

LOG_LEVEL = os.environ.get("MPCONTRIBS_CLIENT_LOG_LEVEL", "INFO")
log_level = getattr(logging, LOG_LEVEL.upper())
_session = requests.Session()
_executor = ProcessPoolExecutor(max_workers=MAX_WORKERS)


class LogFilter(logging.Filter):
Expand Down Expand Up @@ -206,7 +208,7 @@ def grouper(n, iterable):
yield chunk


def get_session(max_workers=MAX_WORKERS):
def get_session():
# TODO add Bad Gateway 502?
adapter_kwargs = dict(max_retries=Retry(
total=RETRIES,
Expand All @@ -215,7 +217,9 @@ def get_session(max_workers=MAX_WORKERS):
respect_retry_after_header=True,
status_forcelist=[429], # rate limit
))
s = FuturesSession(max_workers=max_workers, adapter_kwargs=adapter_kwargs)
s = FuturesSession(
session=_session, executor=_executor, adapter_kwargs=adapter_kwargs
)
s.hooks['response'].append(_response_hook)
return s

Expand Down Expand Up @@ -461,15 +465,10 @@ def _load(protocol, host, headers_json, project, version):
return swagger_spec

# expand regex-based query parameters for `data` columns
session = get_session()
query = {"name": project} if project else {}
query["_fields"] = ["columns"]
kwargs = dict(headers=headers, params=query)
future = session.get(f"{url}/projects/", **kwargs)
track_id = "get_columns"
setattr(future, "track_id", track_id)
resp = _run_futures([future], timeout=3, disable=True).get(track_id, {}).get("result")
session.close()
resp = _session.get(f"{url}/projects/", **kwargs).json()

if not resp or not resp["data"]:
raise ValueError(f"Failed to load projects for query {query}!")
Expand Down Expand Up @@ -511,6 +510,36 @@ def _load(protocol, host, headers_json, project, version):
return swagger_spec


@functools.lru_cache(maxsize=1)
def _version(url):
retries, max_retries = 0, 3
protocol = urlparse(url).scheme
is_mock_test = 'pytest' in sys.modules and protocol == "http"

if is_mock_test:
now = datetime.datetime.now()
return Version(
major=now.year, minor=now.month, patch=now.day,
prerelease=(str(now.hour), str(now.minute))
)
else:
while retries < max_retries:
try:
r = requests.get(f"{url}/healthcheck", timeout=2)
if r.status_code == 200:
return r.json().get("version")
else:
retries += 1
logger.warning(
f"Healthcheck for {url} failed ({r.status_code})! Wait 30s."
)
time.sleep(30)
except RequestException as ex:
retries += 1
logger.warning(f"Could not connect to {url} ({ex})! Wait 30s.")
time.sleep(30)


class Client(SwaggerClient):
"""client to connect to MPContribs API
Expand All @@ -520,10 +549,6 @@ class Client(SwaggerClient):
>>> from mpcontribs.client import Client
>>> client = Client()
"""
# Borg: https://www.oreilly.com/library/view/python-cookbook/0596001673/ch05s23.html
# NOTE bravado future doesn't work with concurrent.futures
_shared_state = {}

def __init__(
self,
apikey: str = None,
Expand All @@ -539,10 +564,9 @@ def __init__(
host (str): host address to connect to (or use MPCONTRIBS_API_HOST env var)
project (str): use this project for all operations (query, update, create, delete)
"""
# NOTE bravado future doesn't work with concurrent.futures
# - Kong forwards consumer headers when api-key used for auth
# - forward consumer headers when connecting through localhost
self.__dict__ = self._shared_state

if not host:
host = os.environ.get("MPCONTRIBS_API_HOST", DEFAULT_HOST)

Expand All @@ -567,45 +591,13 @@ def __init__(
if self.url not in VALID_URLS:
raise ValueError(f"{self.url} not a valid URL (one of {VALID_URLS})")

if "version" not in self.__dict__:
retries, max_retries = 0, 3
is_mock_test = 'pytest' in sys.modules and self.protocol == "http"

if is_mock_test:
now = datetime.datetime.now()
self.version = Version(
major=now.year, minor=now.month, patch=now.day,
prerelease=(str(now.hour), str(now.minute))
)
else:
while retries < max_retries:
try:
r = requests.get(f"{self.url}/healthcheck", timeout=2)
if r.status_code == 200:
self.version = r.json().get("version")
break
else:
retries += 1
logger.warning(
f"Healthcheck for {self.url} failed ({r.status_code})! Wait 30s."
)
time.sleep(30)
except RequestException as ex:
retries += 1
logger.warning(f"Could not connect to {self.url} ({ex})! Wait 30s.")
time.sleep(30)

if "session" not in self.__dict__:
self.session = get_session()

self.version = _version(self.url) # includes healthcheck
self.session = get_session()
super().__init__(self.cached_swagger_spec)

def __enter__(self):
return self

def __exit__(self, exc_type, exc_val, exc_tb):
return self.session.close()

@property
def cached_swagger_spec(self):
return _load(self.protocol, self.host, self.headers_json, self.project, self.version)
Expand Down Expand Up @@ -707,9 +699,6 @@ def _get_future(
op: str = "get",
data: dict = None
):
if self.session and self.session.executor._shutdown:
raise ValueError("Session closed. Use `with` statement.")

rname = rel_url.split("/", 1)[0]
resource = self.swagger_spec.resources[rname]
method = getattr(resource, f"{op}_entries").http_method
Expand Down Expand Up @@ -1706,12 +1695,7 @@ def submit_contributions(

# submit contributions
if contribs:
if self.session and self.session.executor._shutdown:
raise ValueError("Session closed. Use `with` statement.")

total, total_processed = 0, 0
self.session.close()
self.session = get_session(max_workers=4)

def post_future(track_id, payload):
future = self.session.post(
Expand Down Expand Up @@ -1804,8 +1788,6 @@ def put_future(pk, payload):
dt = (toc - tic) / 60
self.init_columns()
self._reinit()
self.session.close()
self.session = get_session()
logger.info(f"It took {dt:.1f}min to submit {total_processed}/{total} contributions.")
else:
logger.info("Nothing to submit.")
Expand Down

0 comments on commit 000a5fb

Please sign in to comment.