chore: canonicalize master urls everywhere [MLG-878] (#8670)
We had a lot of ancient master url helper functions that all needed
revisiting.

The main philosophical update is a transition from "validate and
canonicalize every master url every time" to "validate and canonicalize
user input and assume the url is canonicalized everywhere else".

This is similar to the networking concept of a "dumb network with smart
edges", which allows network internals to be simple and scalable.  In
the same way, our code internals can be simple and easy to write, and
only when receiving user input do we need to worry about validation or
canonicalization.

Issues with the old code:

  - parse_master_url() wasn't just parsing, it was also doing some amount
    of canonicalization (but not enough).

  - parse_master_url() actually returned a ParsedURL object but (almost)
    nobody used it directly; it was mostly a step of make_url().

  - make_url() was essentially doing a urllib.parse.urljoin() on its
    arguments, where the first argument was the master_url and,
    confusingly, the second argument was allowed to be an absolute
    url, causing the first argument to be ignored.  Even more
    confusingly, there was a test that ensured that codepath worked,
    even though the test said it was "unexpected".  And lastly, that
    confusing functionality was never used, since we never let users
    pass values directly to that function.

  - make_url_new() claimed to deprecate make_url() (even though they are
    both internal functions) but make_url_new() was only used in the
    `det dev` subcommands, causing those commands to exhibit different
    behavior than the rest of the system.  Also, make_url_new()
    explicitly preserved the most confusing parts of make_url().

  - master_url was optional all over the place (even on internal
    functions), and we were willing to read from environment variables
    in way too many places.  Confusingly, there were zero instances of
    calling these internal functions without a master_url, so the
    complexity was pointless.
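
The urljoin() behavior described above can be reproduced with the standard library alone. This is a minimal sketch, not the removed make_url() code; the host names are made up for illustration.

```python
from urllib.parse import urljoin

# Hypothetical master url, used only for illustration.
master_url = "http://master.example:8080/"

# A relative second argument is resolved against the master url.
relative = urljoin(master_url, "api/v1/experiments")

# An absolute second argument replaces the base entirely,
# so master_url is silently ignored.
absolute = urljoin(master_url, "https://other.example/api/v1/experiments")

print(relative)  # http://master.example:8080/api/v1/experiments
print(absolute)  # https://other.example/api/v1/experiments
```

The second call is the "first argument ignored" case: nothing in the result comes from master_url.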

Solutions in the new code:

   - There is a canonicalize_master_url() that must be applied to all
     user-provided values before proceeding to the internal parts of the
     system.  Happily, this only occurs in four places:
       - In the CLI.
       - In the det.cli.tunnel module (a different CLI, basically).
       - In the python SDK.
       - When building a ClusterInfo object.

   - Delete make_url() entirely.  It was always confusing what it did,
     creating a cargo-cult situation where you had to use it but you
     didn't know why.  Replace every use of make_url() with simple
     f-string concatenation of the master_url and the api path.

   - Design the canonical form of the master url specifically to make
     f-string concatenation of the master url and the path effective.

   - Bump the TokenStore's auth.json to v2, which is just like v1 except
     that the master urls must be canonicalized; this is the actual fix
     for MLG-878.
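
To make the "smart edges" idea concrete, here is a rough sketch of canonicalizing once at the boundary and then using plain f-string concatenation internally. The function below is hypothetical and is not the canonicalize_master_url() from this commit; the default scheme, the no-trailing-slash rule, and the API path are assumptions chosen for illustration.

```python
from urllib.parse import urlparse


def canonicalize_url_sketch(url: str) -> str:
    """Hypothetical canonicalizer; NOT the real canonicalize_master_url()."""
    # Assumption: bare "host:port" input defaults to the http scheme.
    if "://" not in url:
        url = f"http://{url}"
    parsed = urlparse(url)
    if parsed.scheme not in ("http", "https") or not parsed.hostname:
        raise ValueError(f"invalid master url: {url!r}")
    # Assumption: the canonical form has no trailing slash, so that
    # f"{master_url}/{path}" never produces a double slash.
    path = parsed.path.rstrip("/")
    return f"{parsed.scheme}://{parsed.netloc}{path}"


# Edge of the system: user input is canonicalized exactly once.
master_url = canonicalize_url_sketch("localhost:8080/")

# Internals: simple concatenation, no further validation.
# "api/v1/me" is an illustrative path, not taken from the commit.
print(f"{master_url}/api/v1/me")
```

Whatever the real canonical form is, the point of the design is the same: if every stored url has one predictable shape, joining a path is a one-line f-string.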
rb-determined-ai authored Feb 22, 2024
1 parent e3709bd commit 72d54be
Showing 31 changed files with 507 additions and 348 deletions.
2 changes: 1 addition & 1 deletion .circleci/scripts/wait_for_master.py
@@ -25,7 +25,7 @@ def _wait_for_master(address: str) -> None:

def main() -> None:
parser = argparse.ArgumentParser(description="Wait for master helper.")
parser.add_argument("address", help="Master address.")
parser.add_argument("address", type=api.canonicalize_master_url, help="Master address.")
args = parser.parse_args()
_wait_for_master(args.address)
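
The `type=` hook used in this hunk can be tried in isolation. The sketch below assumes a stand-in `canonicalize` function (hypothetical, much simpler than api.canonicalize_master_url) and shows that argparse applies `type` both to values given on the command line and to string defaults.

```python
import argparse


def canonicalize(url: str) -> str:
    # Hypothetical stand-in for api.canonicalize_master_url.
    if "://" not in url:
        url = f"http://{url}"
    return url.rstrip("/")


parser = argparse.ArgumentParser(description="type= demo")
# Positional value is converted by `canonicalize` during parsing.
parser.add_argument("address", type=canonicalize)
# A string default is also passed through `type` when the flag is absent.
parser.add_argument("--master", type=canonicalize, default="localhost:8080/")

args = parser.parse_args(["example.org:8080/"])
print(args.address)  # http://example.org:8080
print(args.master)   # http://localhost:8080
```

With this arrangement, code that reads `args.address` never sees a raw user-provided value.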

4 changes: 2 additions & 2 deletions e2e_tests/tests/config.py
@@ -80,8 +80,8 @@ def load_config(config_path: str) -> Any:
return config


def make_master_url(suffix: str = "") -> str:
return f"{MASTER_SCHEME}://{MASTER_IP}:{MASTER_PORT}/{suffix}"
def make_master_url() -> str:
return api.canonicalize_master_url(f"{MASTER_SCHEME}://{MASTER_IP}:{MASTER_PORT}")


def set_global_batch_size(config: Dict[Any, Any], batch_size: int) -> Dict[Any, Any]:
2 changes: 1 addition & 1 deletion e2e_tests/tests/deploy/test_local.py
@@ -20,7 +20,7 @@ def mksess(host: str, port: int, username: str = "determined", password: str = "
Since this file frequently creates new masters, always create a fresh Session.
"""

master_url = f"http://{host}:{port}"
master_url = api.canonicalize_master_url(f"http://{host}:{port}")
utp = authentication.login(master_url, username=username, password=password)
return api.Session(master_url, utp, cert=None)

3 changes: 2 additions & 1 deletion harness/determined/_info.py
@@ -3,6 +3,7 @@
from typing import Any, Dict, Iterable, List, Optional, Union

from determined import gpu
from determined.common import api

DEFAULT_RENDEZVOUS_INFO_PATH = "/run/determined/info/rendezvous.json"
DEFAULT_TRIAL_INFO_PATH = "/run/determined/info/trial.json"
@@ -209,7 +210,7 @@ def __init__(
resources_info: Optional[ResourcesInfo] = None,
):
#: The url for reaching the master.
self.master_url = master_url
self.master_url = api.canonicalize_master_url(master_url)

#: The unique identifier for this cluster.
self.cluster_id = cluster_id
6 changes: 3 additions & 3 deletions harness/determined/cli/_util.py
@@ -6,7 +6,7 @@

from determined import cli
from determined.cli import errors, render
from determined.common import api, declarative_argparse, util
from determined.common import api, declarative_argparse
from determined.common.api import authentication, bindings

output_format_args: Dict[str, declarative_argparse.Arg] = {
@@ -85,12 +85,12 @@ def make_pagination_args(


def unauth_session(args: argparse.Namespace) -> api.UnauthSession:
master_url = args.master or util.get_default_master_address()
master_url = args.master
return api.UnauthSession(master=master_url, cert=cli.cert)


def setup_session(args: argparse.Namespace) -> api.Session:
master_url = args.master or util.get_default_master_address()
master_url = args.master
utp = authentication.login_with_cache(
master_address=master_url,
requested_user=args.user,
35 changes: 13 additions & 22 deletions harness/determined/cli/cli.py
@@ -11,6 +11,7 @@
Namespace,
)
from typing import List, Sequence, Union, cast
from urllib import parse

import argcomplete
import argcomplete.completers
@@ -20,8 +21,9 @@
from termcolor import colored

import determined as det
import determined.errors
from determined import cli
from determined.cli import render
from determined.cli import errors, render
from determined.cli.agent import args_description as agent_args_description
from determined.cli.checkpoint import args_description as checkpoint_args_description
from determined.cli.command import args_description as command_args_description
@@ -48,30 +50,20 @@
from determined.cli.version import args_description as version_args_description
from determined.cli.version import check_version
from determined.cli.workspace import args_description as workspace_args_description
from determined.common import api, yaml
from determined.common import api, util, yaml
from determined.common.api import bindings, certs
from determined.common.check import check_not_none
from determined.common.declarative_argparse import (
Arg,
ArgsDescription,
Cmd,
add_args,
generate_aliases,
)
from determined.common.util import (
chunks,
debug_mode,
get_default_master_address,
safe_load_yaml_with_exceptions,
)
from determined.errors import EnterpriseOnlyError

from .errors import CliError, FeatureFlagDisabled


def preview_search(args: Namespace) -> None:
sess = cli.setup_session(args)
experiment_config = safe_load_yaml_with_exceptions(args.config_file)
experiment_config = util.safe_load_yaml_with_exceptions(args.config_file)
args.config_file.close()

if "searcher" not in experiment_config:
@@ -134,7 +126,8 @@ def render_sequence(sequence: List[str]) -> str:
"--master",
help="master address",
metavar="address",
default=get_default_master_address(),
type=api.canonicalize_master_url,
default=api.get_default_master_url(),
),
Arg(
"-v",
@@ -188,7 +181,7 @@ def make_parser() -> ArgumentParser:


def die(message: str, always_print_traceback: bool = False, exit_code: int = 1) -> None:
if always_print_traceback or debug_mode():
if always_print_traceback or util.debug_mode():
import traceback

traceback.print_exc(file=sys.stderr)
@@ -242,9 +235,7 @@ def main(
# cert, so allow the user to store and trust the current cert. (It could also mean
# that we tried to talk HTTPS on the HTTP port, but distinguishing that based on the
# exception is annoying, and we'll figure that out in the next step anyway.)
addr = api.parse_master_address(parsed_args.master)
check_not_none(addr.hostname)
check_not_none(addr.port)
addr = parse.urlparse(parsed_args.master)
try:
ctx = SSL.Context(SSL.TLSv1_2_METHOD)
conn = SSL.Connection(ctx, socket.socket())
@@ -269,7 +260,7 @@ def main(
# Compute the fingerprint of the certificate; this is the same as the output of
# `openssl x509 -fingerprint -sha256 -inform pem -noout -in <cert>`.
cert_hash = hashlib.sha256(ssl.PEM_cert_to_DER_cert(cert_pem_data[0])).hexdigest()
cert_fingerprint = ":".join(chunks(cert_hash, 2))
cert_fingerprint = ":".join(util.chunks(cert_hash, 2))

if not render.yes_or_no(
"The master sent an untrusted certificate chain with this SHA256 fingerprint:\n"
@@ -297,11 +288,11 @@ def main(
"Failed to login: Attempted to read a corrupted token cache. "
"The store has been deleted; please try again."
)
except EnterpriseOnlyError as e:
except det.errors.EnterpriseOnlyError as e:
die(f"Determined Enterprise Edition is required for this functionality: {e}")
except FeatureFlagDisabled as e:
except errors.FeatureFlagDisabled as e:
die(f"Master does not support this operation: {e}")
except CliError as e:
except errors.CliError as e:
die(e.message, exit_code=e.exit_code)
except ArgumentError as e:
die(e.message, exit_code=2)
7 changes: 4 additions & 3 deletions harness/determined/cli/dev.py
@@ -18,7 +18,6 @@
from determined.cli import errors, render
from determined.common.api import bindings
from determined.common.api import errors as api_errors
from determined.common.api import request
from determined.common.declarative_argparse import Arg, Cmd


@@ -34,15 +33,17 @@ def curl(args: Namespace) -> None:
sys.exit(1)

parsed = parse.urlparse(args.path)
if parsed.scheme:
if parsed.scheme or parsed.netloc:
raise errors.CliError(
"path argument does not support absolute URLs."
+ " Set the host path through `det` command"
)

relpath = args.path.lstrip("/")

cmd: List[str] = [
"curl",
request.make_url_new(args.master, args.path),
f"{args.master}/{relpath}",
"-H",
f"Authorization: Bearer {sess.token}",
"-s",
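
The extra `parsed.netloc` check in the dev.py hunk above matters for scheme-relative inputs such as `//host/path`, which have no scheme but still name a host. A small sketch of that check follows; `is_relative_path` and `join_to_master` are hypothetical helper names and the hosts are made up.

```python
from urllib.parse import urlparse


def is_relative_path(path: str) -> bool:
    # Hypothetical helper mirroring the `scheme or netloc` test in the hunk.
    parsed = urlparse(path)
    return not (parsed.scheme or parsed.netloc)


def join_to_master(master_url: str, path: str) -> str:
    # Assumes master_url is already canonical (no trailing slash).
    if not is_relative_path(path):
        raise ValueError("path argument does not support absolute URLs")
    return f"{master_url}/{path.lstrip('/')}"


print(is_relative_path("/api/v1/me"))              # True
print(is_relative_path("https://other.example/"))  # False: has a scheme
print(is_relative_path("//other.example/api"))     # False: netloc only
print(join_to_master("http://localhost:8080", "/api/v1/me"))
```

A check on `parsed.scheme` alone would have accepted the third input.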
30 changes: 15 additions & 15 deletions harness/determined/cli/notebook.py
@@ -1,3 +1,4 @@
import webbrowser
from argparse import ONE_OR_MORE, FileType, Namespace
from functools import partial
from pathlib import Path
@@ -8,7 +9,7 @@
from determined import cli
from determined.cli import ntsc, render, task
from determined.common import api, context
from determined.common.api import bindings, request
from determined.common.api import bindings
from determined.common.check import check_none
from determined.common.declarative_argparse import Arg, ArgsDescription, Cmd, Group

@@ -51,18 +52,18 @@ def start_notebook(args: Namespace) -> None:
cli.wait_ntsc_ready(sess, api.NTSC_Kind.notebook, nb.id)

assert nb.serviceAddress is not None, "missing tensorboard serviceAddress"
nb_path = request.make_interactive_task_url(
nb_path = ntsc.make_interactive_task_url(
task_id=nb.id,
service_address=nb.serviceAddress,
description=nb.description,
resource_pool=nb.resourcePool,
task_type="jupyter-lab",
currentSlotsExceeded=currentSlotsExceeded,
)
url = api.make_url(args.master, nb_path)
url = f"{args.master}/{nb_path}"
if not args.no_browser:
api.browser_open(args.master, nb_path)
print(colored("Jupyter Notebook is running at: {}".format(url), "green"))
webbrowser.open(url)
print(colored(f"Jupyter Notebook is running at: {url}", "green"))


def open_notebook(args: Namespace) -> None:
@@ -75,18 +76,17 @@ def open_notebook(args: Namespace) -> None:
nb = bindings.get_GetNotebook(sess, notebookId=notebook_id).notebook
assert nb.serviceAddress is not None, "missing tensorboard serviceAddress"

api.browser_open(
args.master,
request.make_interactive_task_url(
task_id=nb.id,
service_address=nb.serviceAddress,
description=nb.description,
resource_pool=nb.resourcePool,
task_type="jupyter-lab",
currentSlotsExceeded=False,
),
nb_path = ntsc.make_interactive_task_url(
task_id=nb.id,
service_address=nb.serviceAddress,
description=nb.description,
resource_pool=nb.resourcePool,
task_type="jupyter-lab",
currentSlotsExceeded=False,
)

webbrowser.open(f"{args.master}/{nb_path}")


args_description: ArgsDescription = [
Cmd(
33 changes: 33 additions & 0 deletions harness/determined/cli/ntsc.py
@@ -3,12 +3,14 @@
import base64
import json
import operator
import os
import re
from argparse import Namespace
from collections import OrderedDict
from functools import reduce
from pathlib import Path
from typing import IO, Any, Dict, Iterable, List, Optional, Tuple, Union
from urllib import parse

from termcolor import colored

@@ -435,3 +437,34 @@ def launch_command(
body["workspaceId"] = workspace_id

return sess.post(endpoint, json=body).json()


def make_interactive_task_url(
task_id: str,
service_address: str,
description: str,
resource_pool: str,
task_type: str,
currentSlotsExceeded: bool,
) -> str:
wait_path = (
"/jupyter-lab/{}/events".format(task_id)
if task_type == "jupyter-lab"
else "/tensorboard/{}/events?tail=1".format(task_id)
)
wait_path_url = service_address + wait_path
public_url = os.environ.get("PUBLIC_URL", "/det")
wait_page_url = "{}/wait/{}/{}?eventUrl={}&serviceAddr={}".format(
public_url, task_type, task_id, wait_path_url, service_address
)
task_web_url = "{}/interactive/{}/{}/{}/{}/{}?{}".format(
public_url,
task_id,
task_type,
parse.quote(description),
resource_pool,
parse.quote_plus(wait_page_url),
f"currentSlotsExceeded={str(currentSlotsExceeded).lower()}",
)
# Return a relative path that can be joined to the master_url with a simple "/" separator.
return task_web_url.lstrip("/")
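
make_interactive_task_url() uses two different standard-library quoting functions. As a quick illustration of how they differ (the input strings are invented for this sketch):

```python
from urllib import parse

# quote() keeps "/" by default and encodes spaces as %20.
description = parse.quote("my notebook/v2")

# quote_plus() also encodes "/", "?", "=", and "&", and turns spaces into "+",
# which lets a whole url be embedded as a single path or query component.
embedded = parse.quote_plus("/det/wait?a=1&b=2 3")

print(description)  # my%20notebook/v2
print(embedded)     # %2Fdet%2Fwait%3Fa%3D1%26b%3D2+3
```

The final `lstrip("/")` in the function then yields a path that can be appended to a canonical master url with a single "/" separator.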
25 changes: 18 additions & 7 deletions harness/determined/cli/proxy.py
@@ -12,11 +12,12 @@
import urllib.request
from dataclasses import dataclass
from typing import Iterator, List, Optional
from urllib import parse

import lomond

from determined.common import api
from determined.common.api import bindings, certs, request
from determined.common.api import bindings, certs


@dataclass
@@ -138,8 +139,18 @@ def copy_from_websocket2(
f.close()


def maybe_upgrade_ws_scheme(master_address: str) -> str:
parsed = parse.urlparse(master_address)
if parsed.scheme == "https":
return parsed._replace(scheme="wss").geturl()
elif parsed.scheme == "http":
return parsed._replace(scheme="ws").geturl()
else:
return master_address


def http_connect_tunnel(sess: api.BaseSession, service: str) -> None:
parsed_master = request.parse_master_address(sess.master)
parsed_master = parse.urlparse(sess.master)
assert parsed_master.hostname is not None, f"Failed to parse master address: {sess.master}"

# The "lomond.WebSocket()" function does not honor the "no_proxy" or
@@ -154,8 +165,8 @@ def http_connect_tunnel(sess: api.BaseSession, service: str) -> None:
# specified, the default value is "None".
proxies = {} if urllib.request.proxy_bypass(parsed_master.hostname) else None # type: ignore

url = request.make_url(sess.master, f"proxy/{service}/")
ws = lomond.WebSocket(request.maybe_upgrade_ws_scheme(url), proxies=proxies)
url = f"{sess.master}/proxy/{service}/"
ws = lomond.WebSocket(maybe_upgrade_ws_scheme(url), proxies=proxies)
if isinstance(sess, api.Session):
ws.add_header(b"Authorization", f"Bearer {sess.token}".encode())

@@ -186,18 +197,18 @@ def _http_tunnel_listener(
sess: api.BaseSession,
tunnel: ListenerConfig,
) -> socketserver.ThreadingTCPServer:
parsed_master = request.parse_master_address(sess.master)
parsed_master = parse.urlparse(sess.master)
assert parsed_master.hostname is not None, f"Failed to parse master address: {sess.master}"

url = request.make_url(sess.master, f"proxy/{tunnel.service_id}/")
url = f"{sess.master}/proxy/{tunnel.service_id}/"

class TunnelHandler(socketserver.BaseRequestHandler):
def handle(self) -> None:
proxies = (
{} if urllib.request.proxy_bypass(parsed_master.hostname) else None # type: ignore
)

ws = lomond.WebSocket(request.maybe_upgrade_ws_scheme(url), proxies=proxies)
ws = lomond.WebSocket(maybe_upgrade_ws_scheme(url), proxies=proxies)
if isinstance(sess, api.Session):
ws.add_header(b"Authorization", f"Bearer {sess.token}".encode())
# We can't send data to the WebSocket before the connection becomes ready,
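
For reference, the maybe_upgrade_ws_scheme() helper moved into proxy.py can be exercised on its own. The function body below is copied from the hunk above with indentation restored; the master url and service id in the usage lines are made up.

```python
from urllib import parse


def maybe_upgrade_ws_scheme(master_address: str) -> str:
    parsed = parse.urlparse(master_address)
    if parsed.scheme == "https":
        return parsed._replace(scheme="wss").geturl()
    elif parsed.scheme == "http":
        return parsed._replace(scheme="ws").geturl()
    else:
        # Anything else (including an existing ws/wss url) is left alone.
        return master_address


# Hypothetical canonical master url and service id.
master = "https://master.example:8443"
service = "abc123"

url = f"{master}/proxy/{service}/"
print(maybe_upgrade_ws_scheme(url))  # wss://master.example:8443/proxy/abc123/
print(maybe_upgrade_ws_scheme("http://localhost:8080/proxy/abc123/"))
```

Only the scheme changes; host, port, and path pass through untouched.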