Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add headless auth #508

Merged
merged 1 commit into from
Mar 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 20 additions & 11 deletions safety/auth/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def render_successful_login(auth: Auth,


@auth_app.command(name=CMD_LOGIN_NAME, help=CLI_AUTH_LOGIN_HELP)
def login(ctx: typer.Context):
def login(ctx: typer.Context, headless: bool = False):
"""
Authenticate Safety CLI with your safetycli.com account using your default browser.
"""
Expand All @@ -105,29 +105,38 @@ def login(ctx: typer.Context):
fail_if_authenticated(ctx, with_msg=MSG_FAIL_LOGIN_AUTHED)

console.print()
brief_msg: str = "Redirecting your browser to log in; once authenticated, " \
"return here to start using Safety"

uri, initial_state = get_authorization_data(client=ctx.obj.auth.client,
code_verifier=ctx.obj.auth.code_verifier,
organization=ctx.obj.auth.org)
info = None

if ctx.obj.auth.org:
brief_msg: str = "Redirecting your browser to log in; once authenticated, " \
"return here to start using Safety"

if ctx.obj.auth.org:
console.print(f"Logging into [bold]{ctx.obj.auth.org.name}[/bold] " \
"organization.")


if headless:
brief_msg = "Running in headless mode. Please copy and open the following URL in a browser"


uri, initial_state = get_authorization_data(client=ctx.obj.auth.client,
code_verifier=ctx.obj.auth.code_verifier,
organization=ctx.obj.auth.org, headless=headless)
click.secho(brief_msg)
click.echo()

info = process_browser_callback(uri,
initial_state=initial_state, ctx=ctx)
info = process_browser_callback(uri, initial_state=initial_state, ctx=ctx, headless=headless)


if info:
if info.get("email", None):
organization = None
if ctx.obj.auth.org and ctx.obj.auth.org.name:
organization = ctx.obj.auth.org.name
ctx.obj.auth.refresh_from(info)
if headless:
console.print()

render_successful_login(ctx.obj.auth, organization=organization)

console.print()
Expand All @@ -149,7 +158,7 @@ def login(ctx: typer.Context):
else:
msg += "Error logging into Safety."

msg += " Please try again, or use [bold]`safety auth –help`[/bold] " \
msg += " Please try again, or use [bold]`safety auth -–help`[/bold] " \
"for more information[/red]"

console.print(msg, emoji=True)
Expand Down
5 changes: 3 additions & 2 deletions safety/auth/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import json

from typing import Any, Dict, Optional, Tuple, Union
from urllib.parse import urlencode

from authlib.oidc.core import CodeIDToken
from authlib.jose import jwt
Expand All @@ -17,9 +18,9 @@

def get_authorization_data(client, code_verifier: str,
organization: Optional[Organization] = None,
sign_up: bool = False, ensure_auth: bool = False) -> Tuple[str, str]:
sign_up: bool = False, ensure_auth: bool = False, headless: bool = False) -> Tuple[str, str]:

kwargs = {'sign_up': sign_up, 'locale': 'en', 'ensure_auth': ensure_auth}
kwargs = {'sign_up': sign_up, 'locale': 'en', 'ensure_auth': ensure_auth, 'headless': headless}
if organization:
kwargs['organization'] = organization.id

Expand Down
133 changes: 85 additions & 48 deletions safety/auth/server.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import http.server
import json
import logging
import socket
import sys
Expand All @@ -13,6 +14,8 @@

from safety.auth.constants import AUTH_SERVER_URL, CLI_AUTH_SUCCESS, CLI_LOGOUT_SUCCESS, HOST
from safety.auth.main import save_auth_config
from authlib.integrations.base_client.errors import OAuthError
from rich.prompt import Prompt

LOG = logging.getLogger(__name__)

Expand All @@ -33,40 +36,49 @@ def find_available_port():

return None

def auth_process(code: str, state: str, initial_state: str, code_verifier, client):
err = None

if initial_state is None or initial_state != state:
err = "The state parameter value provided does not match the expected " \
"value. The state parameter is used to protect against Cross-Site " \
"Request Forgery (CSRF) attacks. For security reasons, the " \
"authorization process cannot proceed with an invalid state " \
"parameter value. Please try again, ensuring that the state " \
"parameter value provided in the authorization request matches " \
"the value returned in the callback."

if err:
click.secho(f'Error: {err}', fg='red')
sys.exit(1)

try:
tokens = client.fetch_token(url=f'{AUTH_SERVER_URL}/oauth/token',
code_verifier=code_verifier,
client_id=client.client_id,
grant_type='authorization_code', code=code)

save_auth_config(access_token=tokens['access_token'],
id_token=tokens['id_token'],
refresh_token=tokens['refresh_token'])
return client.fetch_user_info()

except Exception as e:
LOG.exception(e)
sys.exit(1)

class CallbackHandler(http.server.BaseHTTPRequestHandler):
def auth(self, code: str, state: str, err, error_description):
initial_state = self.server.initial_state
ctx = self.server.ctx

if initial_state is None or initial_state != state:
err = "The state parameter value provided does not match the expected" \
"value. The state parameter is used to protect against Cross-Site " \
"Request Forgery (CSRF) attacks. For security reasons, the " \
"authorization process cannot proceed with an invalid state " \
"parameter value. Please try again, ensuring that the state " \
"parameter value provided in the authorization request matches " \
"the value returned in the callback."

if err:
click.secho(f'Error: {err}', fg='red')
sys.exit(1)
result = auth_process(code=code,
state=state,
initial_state=initial_state,
code_verifier=ctx.obj.auth.code_verifier,
client=ctx.obj.auth.client)

try:
tokens = ctx.obj.auth.client.fetch_token(url=f'{AUTH_SERVER_URL}/oauth/token',
code_verifier=ctx.obj.auth.code_verifier,
client_id=ctx.obj.auth.client.client_id,
grant_type='authorization_code', code=code)

save_auth_config(access_token=tokens['access_token'],
id_token=tokens['id_token'],
refresh_token=tokens['refresh_token'])
self.server.callback = ctx.obj.auth.client.fetch_user_info()

except Exception as e:
LOG.exception(e)
sys.exit(1)

self.server.callback = result
self.do_redirect(location=CLI_AUTH_SUCCESS, params={})

def logout(self):
Expand Down Expand Up @@ -132,27 +144,52 @@ def handle_timeout(self) -> None:
sys.exit(1)

try:
server = ThreadedHTTPServer((HOST, PORT), CallbackHandler)
server.initial_state = kwargs.get("initial_state", None)
server.timeout = kwargs.get("timeout", 600)
# timeout = kwargs.get("timeout", None)
# timeout = float(timeout) if timeout else None
server.ctx = kwargs.get("ctx", None)
server_thread = threading.Thread(target=server.handle_request)
server_thread.start()

target = f"{uri}&port={PORT}"
console.print(f"If the browser does not automatically open in 5 seconds, " \
"copy and paste this url into your browser: " \
f"[link={target}]{target}[/link]")
click.echo()

wait_msg = "waiting for browser authentication"

with console.status(wait_msg, spinner="bouncingBar"):
time.sleep(2)
click.launch(target)
server_thread.join()
headless = kwargs.get("headless", False)
initial_state = kwargs.get("initial_state", None)
ctx = kwargs.get("ctx", None)

message = "Copy and paste this url into your browser:"


if not headless:
server = ThreadedHTTPServer((HOST, PORT), CallbackHandler)
server.initial_state = initial_state
server.timeout = kwargs.get("timeout", 600)
server.ctx = ctx
server_thread = threading.Thread(target=server.handle_request)
server_thread.start()
message = f"If the browser does not automatically open in 5 seconds, " \
"copy and paste this url into your browser:"

target = uri if headless else f"{uri}&port={PORT}"
console.print(f"{message} [link={target}]{target}[/link]")
console.print()

if headless:

exchange_data = None
while not exchange_data:
auth_code_text = Prompt.ask("Paste the response here", default=None, console=console)
try:
exchange_data = json.loads(auth_code_text)
state = exchange_data["state"]
code = exchange_data["code"]
except Exception as e:
code = state = None

return auth_process(code=code,
state=state,
initial_state=initial_state,
code_verifier=ctx.obj.auth.code_verifier,
client=ctx.obj.auth.client)
else:

wait_msg = "waiting for browser authentication"

with console.status(wait_msg, spinner="bouncingBar"):
time.sleep(2)
click.launch(target)
server_thread.join()

except OSError as e:
if e.errno == socket.errno.EADDRINUSE:
Expand Down
4 changes: 2 additions & 2 deletions test_requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
pytest
pytest-cov
pytest==7.4.4
pytest-cov==4.1.0
setuptools>=65.5.1; python_version>="3.7"
setuptools; python_version=="3.6"
Click>=8.0.2
Expand Down
2 changes: 1 addition & 1 deletion tests/auth/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def test_auth_calls_login(self, process_browser_callback,
get_authorization_data.assert_called_once()
process_browser_callback.assert_called_once_with(auth_data[0],
initial_state=auth_data[1],
ctx=ANY)
ctx=ANY, headless=False)

expected = [
"",
Expand Down
6 changes: 4 additions & 2 deletions tests/auth/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ def test_get_authorization_data(self):
"sign_up": False,
"locale": "en",
"ensure_auth": False,
"organization": org_id
"organization": org_id,
"headless": False
}

client.create_authorization_url.assert_called_once_with(
Expand All @@ -42,7 +43,8 @@ def test_get_authorization_data(self):
kwargs = {
"sign_up": False,
"locale": "en",
"ensure_auth":False
"ensure_auth":False,
"headless": False
}

client.create_authorization_url.assert_called_once_with(
Expand Down
48 changes: 28 additions & 20 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,8 +204,7 @@ def test_validate_with_basic_policy_file(self):
result = self.runner.invoke(cli.cli, ['validate', 'policy_file', '3.0', '--path', path])
cleaned_stdout = click.unstyle(result.stdout)
msg = 'The Safety policy (3.0) file (Used for scan and system-scan commands) was successfully parsed with the following values:\n'
parsed = json.dumps(
{
parsed = {
"version": "3.0",
"scan": {
"max_depth": 6,
Expand All @@ -230,19 +229,19 @@ def test_validate_with_basic_policy_file(self):
},
"fail_scan": {
"dependency_vulnerabilities": {
"enabled": True,
"fail_on_any_of": {
"cvss_severity": [
"critical",
"high",
"medium"
],
"exploitability": [
"critical",
"high",
"medium"
]
}
"enabled": True,
"fail_on_any_of": {
"cvss_severity": [
"critical",
"high",
"medium",
],
"exploitability": [
"critical",
"high",
"medium",
]
}
}
},
"security_updates": {
Expand All @@ -252,12 +251,21 @@ def test_validate_with_basic_policy_file(self):
]
}
}
},
indent=2
) + '\n'
}

self.assertEqual(msg + parsed, cleaned_stdout)
self.assertEqual(result.exit_code, 0)
msg_stdout, parsed_policy = cleaned_stdout.split('\n', 1)
msg_stdout += '\n'
parsed_policy = json.loads(parsed_policy.replace('\n', ''))

fail_scan = parsed_policy.get("fail_scan", None)
self.assertIsNotNone(fail_scan)
fail_of_any = fail_scan["dependency_vulnerabilities"]["fail_on_any_of"]
fail_of_any["cvss_severity"] = sorted(fail_of_any["cvss_severity"])
fail_of_any["exploitability"] = sorted(fail_of_any["exploitability"])

self.assertEqual(msg, msg_stdout)
self.assertEqual(parsed, parsed_policy)
self.assertEqual(result.exit_code, 0)


def test_validate_with_policy_file_using_invalid_keyword(self):
Expand Down
2 changes: 2 additions & 0 deletions tests/test_safety.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,8 @@ def test_get_announcements_http_ok(self, get_used_options):
@patch("safety.util.get_used_options")
@patch.object(click, 'get_current_context', Mock(command=Mock(name=Mock(return_value='check'))))
def test_get_announcements_wrong_json_response_handling(self, get_used_options):
get_used_options.return_value = {}

# wrong JSON structure
announcements = {
"type": "notice",
Expand Down
4 changes: 2 additions & 2 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ isolated_build = true

[testenv]
deps =
pytest-cov
pytest
pytest-cov==4.1.0
pytest==7.4.4

commands =
pytest -rP tests/ --cov=safety/ --cov-report=html
Expand Down
Loading