Skip to content

Commit

Permalink
chore: logins return Sessions (#8883)
Browse files Browse the repository at this point in the history
Logins used to return an intermediate UsernameTokenPair that itself was used to create Sessions. We did it this way because the job of login is to find/create a valid token. But because login itself is a user-facing command, and tokens are an implementation detail most people don't care about, this patch updates the login API to just return a Session (the thing most users cared about in the first place).

If a user really wants to get a token, they can always access the token in the Session.

Additionally as a part of this patch, the UsernameTokenPair class has been removed -- it was kind of an in-between abstraction to make it easier for callers of login to create sessions.

This change also contains new BaseSession.with_retry to modify retries (by creating new sessions). We needed some mechanism for users to set retry to something other than the default, and passing retry into login would have been a weird interface (what does the retry of the eventual session have to do with logging in?).
  • Loading branch information
wes-turner authored Mar 19, 2024
1 parent 93b6aa2 commit a603f4c
Show file tree
Hide file tree
Showing 33 changed files with 125 additions and 123 deletions.
3 changes: 1 addition & 2 deletions .circleci/scripts/wait_for_perf_migration_upload_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,7 @@ def _wait_for_master() -> None:

def _upload_migration_length(conn: extensions.connection) -> None:
cert = certs.Cert(noverify=True)
utp = authentication.login("http://127.0.0.1:8080", "admin", "", cert)
sess = api.Session("http://127.0.0.1:8080", utp, cert)
sess = authentication.login("http://127.0.0.1:8080", "admin", "", cert)

migration_start_log = None
migration_end_log = None
Expand Down
3 changes: 1 addition & 2 deletions e2e_tests/tests/api_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@ def cert() -> certs.Cert:
def make_session(username: str, password: str) -> api.Session:
master_url = conf.make_master_url()
# Use login instead of login_with_cache() to not touch auth.json on the filesystem.
utp = authentication.login(master_url, username, password, cert())
return api.Session(master_url, utp, cert(), max_retries=0)
return authentication.login(master_url, username, password, cert())


@functools.lru_cache(maxsize=1)
Expand Down
3 changes: 1 addition & 2 deletions e2e_tests/tests/deploy/test_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@ def mksess(host: str, port: int, username: str = "determined", password: str = "
"""Since this file frequently creates new masters, always create a fresh Session."""

master_url = api.canonicalize_master_url(f"http://{host}:{port}")
utp = authentication.login(master_url, username=username, password=password)
return api.Session(master_url, utp, cert=None, max_retries=0)
return authentication.login(master_url, username=username, password=password)


def det_deploy(subcommand: List) -> None:
Expand Down
3 changes: 1 addition & 2 deletions harness/determined/cli/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,13 +91,12 @@ def unauth_session(args: argparse.Namespace) -> api.UnauthSession:

def setup_session(args: argparse.Namespace) -> api.Session:
master_url = args.master
utp = authentication.login_with_cache(
return authentication.login_with_cache(
master_address=master_url,
requested_user=args.user,
password=None,
cert=cli.cert,
)
return api.Session(master_url, utp, cli.cert)


def require_feature_flag(feature_flag: str, error_message: str) -> Callable[..., Any]:
Expand Down
5 changes: 3 additions & 2 deletions harness/determined/cli/tunnel.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,9 @@
cert = certs.default_load(args.master_url, cert_file, args.cert_name, noverify)

if args.auth:
utp = authentication.login_with_cache(args.master_url, args.user, cert=cert)
sess: api.BaseSession = api.Session(args.master_url, utp, cert)
sess: api.BaseSession = authentication.login_with_cache(
args.master_url, args.user, cert=cert
)
else:
sess = api.UnauthSession(args.master_url, cert)

Expand Down
12 changes: 6 additions & 6 deletions harness/determined/cli/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,9 @@ def log_in_user(args: Namespace) -> None:
password = getpass.getpass(message)

token_store = authentication.TokenStore(args.master)
utp = authentication.login(args.master, username, password, cli.cert)
token_store.set_token(utp.username, utp.token)
token_store.set_active(utp.username)
sess = authentication.login(args.master, username, password, cli.cert)
token_store.set_token(sess.username, sess.token)
token_store.set_active(sess.username)


def log_out_user(args: Namespace) -> None:
Expand Down Expand Up @@ -127,9 +127,9 @@ def change_password(args: Namespace) -> None:
# password change so that the user doesn't have to do so manually.
if args.target_user is None:
token_store = authentication.TokenStore(args.master)
utp = authentication.login(args.master, username, password, cli.cert)
token_store.set_token(utp.username, utp.token)
token_store.set_active(utp.username)
sess = authentication.login(args.master, username, password, cli.cert)
token_store.set_token(sess.username, sess.token)
token_store.set_active(sess.username)


def link_with_agent_user(args: Namespace) -> None:
Expand Down
2 changes: 1 addition & 1 deletion harness/determined/common/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
workspace_by_name,
not_found_errs,
)
from determined.common.api.authentication import UsernameTokenPair, salt_and_hash
from determined.common.api.authentication import salt_and_hash
from determined.common.api.logs import (
pprint_logs,
trial_logs,
Expand Down
19 changes: 14 additions & 5 deletions harness/determined/common/api/_session.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
import abc
import copy
import json as _json
from typing import Any, Dict, Optional, Tuple, Union
from typing import Any, Dict, Optional, Tuple, TypeVar, Union

import requests
import urllib3

import determined as det
from determined.common import api
from determined.common import requests as det_requests
from determined.common.api import authentication, certs, errors
from determined.common.api import certs, errors

GeneralizedRetry = Union[urllib3.util.retry.Retry, int]
T = TypeVar("T", bound="BaseSession")

# Default retry logic
DEFAULT_MAX_RETRIES = urllib3.util.retry.Retry(
Expand Down Expand Up @@ -170,6 +172,12 @@ def put(
) -> requests.Response:
return self._do_request("PUT", path, params, json, data, headers, timeout, False)

def with_retry(self: T, max_retries: GeneralizedRetry) -> T:
"""Generate a new session with a different retry policy."""
new_session = copy.copy(self)
new_session._max_retries = max_retries
return new_session


class UnauthSession(BaseSession):
"""
Expand Down Expand Up @@ -229,7 +237,8 @@ class Session(BaseSession):
def __init__(
self,
master: str,
utp: authentication.UsernameTokenPair,
username: str,
token: str,
cert: Optional[certs.Cert],
max_retries: Optional[GeneralizedRetry] = DEFAULT_MAX_RETRIES,
) -> None:
Expand All @@ -241,8 +250,8 @@ def __init__(
)

self.master = master
self.username = utp.username
self.token = utp.token
self.username = username
self.token = token
self.cert = cert
self._max_retries = max_retries

Expand Down
48 changes: 26 additions & 22 deletions harness/determined/common/api/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,29 +34,30 @@ def get_det_password_from_env() -> Optional[str]:
return os.environ.get("DET_PASS")


class UsernameTokenPair:
def __init__(self, username: str, token: str):
self.username = username
self.token = token


def login(
master_address: str,
username: str,
password: str,
cert: Optional[certs.Cert] = None,
) -> UsernameTokenPair:
) -> "api.Session":
"""
Log in without considering or affecting the TokenStore on the file system.
This sends a login request to the master in order to obtain a new token that can sign future
requests to the master. This token is then baked into a new api.Session object for those future
communications.
Used as part of login_with_cache, and also useful in tests where you wish to not affect the
TokenStore.
Returns:
A new, logged-in api.Session (one that has a valid token).
"""
password = api.salt_and_hash(password)
unauth_session = api.UnauthSession(master=master_address, cert=cert, max_retries=0)
login = bindings.v1LoginRequest(username=username, password=password, isHashed=True)
r = bindings.post_Login(session=unauth_session, body=login)
return UsernameTokenPair(username, r.token)
return api.Session(master=master_address, username=username, token=r.token, cert=cert)


def default_load_user_password(
Expand Down Expand Up @@ -110,22 +111,27 @@ def login_with_cache(
requested_user: Optional[str] = None,
password: Optional[str] = None,
cert: Optional[certs.Cert] = None,
) -> UsernameTokenPair:
) -> "api.Session":
"""
Log in, preferring cached credentials in the TokenStore, if possible.
This is the login path for nearly all user-facing cases.
Unlike ``login``, this function may not send a login request to the master. It will instead
first attempt to find a valid token in the TokenStore, and only if that fails will it post a
login request to the master to generate a new one. As with ``login``, the token is then baked
into a new ``api.Session`` object to sign future communication with master.
There is also a special case for checking if the DET_USER_TOKEN is set in the environment (by
the determined-master). That must happen in this function because it is only used when no other
login tokens are active, but it must be considered before asking the user for a password.
As a somewhat surprising side-effect re-using an existing token from the cache, it is actually
possible in cache hit scenarios for an invalid password here to result in a valid login since
the password is only used in a cache miss.
As a somewhat surprising side-effect of re-using an existing token from the cache, it is
actually possible in cache hit scenarios for an invalid password here to result in a valid login
since the password is only used in a cache miss.
Returns:
The username and token of the logged in user.
A new, logged-in Session (one that has a valid token).
"""

token_store = TokenStore(master_address)
Expand All @@ -140,7 +146,7 @@ def login_with_cache(
token = None

if token is not None:
return UsernameTokenPair(user, token)
return api.Session(master=master_address, username=user, token=token, cert=cert)

# Special case: use token provided from the container environment if:
# - No token was obtained from the token store already,
Expand All @@ -156,14 +162,14 @@ def login_with_cache(
assert env_user
env_token = get_det_user_token_from_env()
assert env_token
return UsernameTokenPair(env_user, env_token)
return api.Session(master=master_address, username=env_user, token=env_token, cert=cert)

if password is None:
password = getpass.getpass(f"Password for user '{user}': ")

try:
utp = login(master_address, user, password, cert)
user, token = utp.username, utp.token
sess = login(master_address, user, password, cert)
user, token = sess.username, sess.token
except api.errors.ForbiddenException:
# Master will return a 403 if the user is not found, or if the password is incorrect.
# This is the right response to a failed explicit login attempt. But in the "fallback" case,
Expand All @@ -175,7 +181,7 @@ def login_with_cache(

token_store.set_token(user, token)

return UsernameTokenPair(user, token)
return sess


def logout(
Expand Down Expand Up @@ -210,8 +216,7 @@ def logout(

token_store.drop_user(user)

utp = UsernameTokenPair(user, token)
sess = api.Session(master=master_address, utp=utp, cert=cert)
sess = api.Session(master=master_address, username=user, token=token, cert=cert)
try:
bindings.post_Logout(sess)
except (api.errors.UnauthenticatedException, api.errors.APIException):
Expand All @@ -235,8 +240,7 @@ def _is_token_valid(master_address: str, token: str, cert: Optional[certs.Cert])
Find out whether the given token is valid by attempting to use it
on the "api/v1/me" endpoint.
"""
utp = UsernameTokenPair("username-doesnt-matter", token)
sess = api.Session(master_address, utp, cert)
sess = api.Session(master_address, username="ignored", token=token, cert=cert)
try:
r = sess.get("api/v1/me")
except (api.errors.UnauthenticatedException, api.errors.APIException):
Expand Down
8 changes: 6 additions & 2 deletions harness/determined/common/experimental/determined.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,7 @@ def __init__(
explicit_noverify=noverify,
)

utp = authentication.login_with_cache(self._master, user, password, cert=cert)
self._session = api.Session(self._master, utp, cert)
self._session = authentication.login_with_cache(self._master, user, password, cert=cert)

@classmethod
def _from_session(cls, session: api.Session) -> "Determined":
Expand Down Expand Up @@ -99,6 +98,11 @@ def get_session_username(self) -> str:
return self._session.username

def logout(self) -> None:
"""Log out of the current session.
This results in dropping any cached credentials and sending a request to master to
invalidate the session's token.
"""
authentication.logout(self._session.master, self._session.username, self._session.cert)

def list_users(self, active: Optional[bool] = None) -> List[user.User]:
Expand Down
5 changes: 3 additions & 2 deletions harness/determined/core/_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,8 +226,9 @@ def init(

# We are on the cluster.
cert = certs.default_load(info.master_url)
utp = authentication.login_with_cache(info.master_url, cert=cert)
session = api.Session(info.master_url, utp, cert, max_retries=util.get_max_retries_config())
session = authentication.login_with_cache(info.master_url, cert=cert).with_retry(
util.get_max_retries_config()
)

if distributed is None:
if len(info.container_addrs) > 1 or len(info.slot_ids) > 1:
Expand Down
3 changes: 1 addition & 2 deletions harness/determined/deploy/healthcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,7 @@ def wait_for_genai_url(
start_time = time.time()

# Hopefully we have an active session to this master, or we can make a default one.
utp = authentication.login_with_cache(master_url, cert=cert)
sess = api.Session(master_url, utp, cert)
sess = authentication.login_with_cache(master_url, cert=cert)

try:
while time.time() - start_time < timeout:
Expand Down
9 changes: 5 additions & 4 deletions harness/determined/exec/gc_checkpoints.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
The entrypoint for the GC checkpoints job container.
"""

import argparse
import json
import logging
Expand All @@ -12,7 +13,7 @@

import determined as det
from determined import errors, tensorboard
from determined.common import api, constants, storage
from determined.common import constants, storage
from determined.common.api import authentication, bindings, certs

logger = logging.getLogger("determined")
Expand All @@ -25,10 +26,10 @@ def patch_checkpoints(storage_ids_to_resources: Dict[str, Dict[str, int]]) -> No
info._to_file()

cert = certs.default_load(info.master_url)
utp = authentication.login_with_cache(info.master_url, cert=cert)
# With backoff retries for 64 seconds
max_retries = urllib3.util.retry.Retry(total=6, backoff_factor=0.5)
sess = api.Session(info.master_url, utp, cert, max_retries)
sess = authentication.login_with_cache(info.master_url, cert=cert).with_retry(
urllib3.util.retry.Retry(total=6, backoff_factor=0.5)
)

checkpoints = []
for storage_id, resources in storage_ids_to_resources.items():
Expand Down
5 changes: 2 additions & 3 deletions harness/determined/exec/launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import types

import determined as det
from determined.common import api, constants, storage
from determined.common import constants, storage
from determined.common.api import authentication, certs
from determined.exec import prep_container

Expand All @@ -22,8 +22,7 @@ def trigger_preemption(signum: int, frame: types.FrameType) -> None:
logger.info("SIGTERM: Preemption imminent.")
# Notify the master that we need to be preempted
cert = certs.default_load(info.master_url)
utp = authentication.login_with_cache(info.master_url, cert=cert)
sess = api.Session(info.master_url, utp, cert)
sess = authentication.login_with_cache(info.master_url, cert=cert)
sess.post(f"/api/v1/allocations/{info.allocation_id}/signals/pending_preemption")


Expand Down
6 changes: 3 additions & 3 deletions harness/determined/exec/prep_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,10 +310,10 @@ def do_proxy(sess: api.Session, allocation_id: str) -> None:
)

cert = certs.default_load(info.master_url)
utp = authentication.login_with_cache(info.master_url, cert=cert)
# With backoff retries for 64 seconds
max_retries = urllib3.util.retry.Retry(total=6, backoff_factor=0.5)
sess = api.Session(info.master_url, utp, cert, max_retries)
sess = authentication.login_with_cache(info.master_url, cert=cert).with_retry(
urllib3.util.retry.Retry(total=6, backoff_factor=0.5)
)

# Notify the Determined Master that the container is running.
# This should only be used on HPC clusters.
Expand Down
Loading

0 comments on commit a603f4c

Please sign in to comment.