Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow credentials expiry to be configured for service connectors #2704

Merged
merged 1 commit into from
May 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 92 additions & 2 deletions src/zenml/cli/service_connectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,65 @@ def prompt_expiration_time(
return expiration_seconds


def prompt_expires_at(
default: Optional[datetime] = None,
) -> Optional[datetime]:
"""Prompt the user for an expiration timestamp.

Args:
default: The default expiration time.

Returns:
The expiration time provided by the user.
"""
if default is None:
confirm = click.confirm(
"Are the credentials you configured temporary? If so, you'll be asked "
"to provide an expiration time in the next step.",
default=False,
)
if not confirm:
return None

while True:
default_str = ""
if default is not None:
seconds = int((default - datetime.utcnow()).total_seconds())
default_str = (
f" [{str(default)} i.e. in "
f"{seconds_to_human_readable(seconds)}]"
)

expires_at = click.prompt(
"Please enter the exact UTC date and time when the credentials "
f"will expire e.g. '2023-12-31 23:59:59'{default_str}",
type=click.DateTime(),
default=default,
show_default=False,
)

assert expires_at is not None
assert isinstance(expires_at, datetime)
if expires_at < datetime.utcnow():
cli_utils.warning(
"The expiration time must be in the future. Please enter a "
"later date and time."
)
continue

seconds = int((expires_at - datetime.utcnow()).total_seconds())

confirm = click.confirm(
f"Credentials will be valid until {str(expires_at)} UTC (i.e. "
f"in {seconds_to_human_readable(seconds)}. Keep this value?",
default=True,
)
if confirm:
break

return expires_at


@service_connector.command(
"register",
context_settings={"ignore_unknown_options": True},
Expand Down Expand Up @@ -367,6 +426,16 @@ def prompt_expiration_time(
required=False,
type=str,
)
@click.option(
"--expires-at",
"expires_at",
help="The exact UTC date and time when the credentials configured for this "
"connector will expire. Takes the form 'YYYY-MM-DD HH:MM:SS'. This is only "
"required if you are configuring a service connector with expiring "
"credentials.",
required=False,
type=click.DateTime(),
)
@click.option(
"--expires-skew-tolerance",
"expires_skew_tolerance",
Expand Down Expand Up @@ -444,6 +513,7 @@ def register_service_connector(
resource_type: Optional[str] = None,
resource_id: Optional[str] = None,
auth_method: Optional[str] = None,
expires_at: Optional[datetime] = None,
expires_skew_tolerance: Optional[int] = None,
expiration_seconds: Optional[int] = None,
no_verify: bool = False,
Expand All @@ -463,6 +533,8 @@ def register_service_connector(
resource_type: The type of resource to connect to.
resource_id: The ID of the resource to connect to.
auth_method: The authentication method to use.
expires_at: The exact UTC date and time when the credentials configured
for this connector will expire.
expires_skew_tolerance: The tolerance, in seconds, allowed when
determining when the credentials configured for or generated by
this connector will expire.
Expand Down Expand Up @@ -495,8 +567,6 @@ def register_service_connector(
# Parse the given labels
parsed_labels = cast(Dict[str, str], cli_utils.get_parsed_labels(labels))

expires_at: Optional[datetime] = None

if interactive:
# Get the list of available service connector types
connector_types = client.list_service_connector_types(
Expand Down Expand Up @@ -764,6 +834,9 @@ def register_service_connector(
default=auth_method_spec.default_expiration_seconds,
)

# Prompt for the time when the credentials will expire
expires_at = prompt_expires_at(expires_at)

try:
# Validate the connector configuration and fetch all available
# resources that are accessible with the provided configuration
Expand All @@ -781,6 +854,7 @@ def register_service_connector(
auth_method=auth_method,
resource_type=resource_type,
configuration=config_dict,
expires_at=expires_at,
expires_skew_tolerance=expires_skew_tolerance,
expiration_seconds=expiration_seconds,
auto_configure=False,
Expand Down Expand Up @@ -1176,6 +1250,14 @@ def describe_service_connector(
required=False,
type=str,
)
@click.option(
"--expires-at",
"expires_at",
help="The time at which the credentials configured for this connector "
"will expire.",
required=False,
type=click.DateTime(),
)
@click.option(
"--expires-skew-tolerance",
"expires_skew_tolerance",
Expand Down Expand Up @@ -1244,6 +1326,7 @@ def update_service_connector(
resource_type: Optional[str] = None,
resource_id: Optional[str] = None,
auth_method: Optional[str] = None,
expires_at: Optional[datetime] = None,
expires_skew_tolerance: Optional[int] = None,
expiration_seconds: Optional[int] = None,
no_verify: bool = False,
Expand All @@ -1264,6 +1347,8 @@ def update_service_connector(
resource_type: The type of resource to connect to.
resource_id: The ID of the resource to connect to.
auth_method: The authentication method to use.
expires_at: The time at which the credentials configured for this
connector will expire.
expires_skew_tolerance: The tolerance, in seconds, allowed when
determining when the credentials configured for or generated by
this connector will expire.
Expand Down Expand Up @@ -1425,6 +1510,9 @@ def update_service_connector(
or auth_method_spec.default_expiration_seconds,
)

# Prompt for the time when the credentials will expire
expires_at = prompt_expires_at(expires_at or connector.expires_at)

try:
# Validate the connector configuration and fetch all available
# resources that are accessible with the provided configuration
Expand All @@ -1442,6 +1530,7 @@ def update_service_connector(
# should be removed in the update if not set here
resource_type=resource_type or "",
configuration=config_dict,
expires_at=expires_at,
# Use zero value to indicate that the expiration time
# should be removed in the update if not set here
expiration_seconds=expiration_seconds or 0,
Expand Down Expand Up @@ -1534,6 +1623,7 @@ def update_service_connector(
# should be removed in the update if not set here
resource_id=resource_id or "",
description=description,
expires_at=expires_at,
# Use empty string to indicate that the expiration time
# should be removed in the update if not set here
expiration_seconds=expiration_seconds or 0,
Expand Down
2 changes: 1 addition & 1 deletion src/zenml/cli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1672,7 +1672,7 @@ def expires_in(
expires_at -= datetime.timedelta(seconds=skew_tolerance)
if expires_at < now:
return expired_str
return seconds_to_human_readable((expires_at - now).seconds)
return seconds_to_human_readable(int((expires_at - now).total_seconds()))
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NOTE: this was previously bugged because it only returned part of the seconds remaining up to the next day instead of the total number of seconds



def print_service_connectors_table(
Expand Down
4 changes: 4 additions & 0 deletions src/zenml/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4965,6 +4965,7 @@ def update_service_connector(
configuration: Optional[Dict[str, str]] = None,
resource_id: Optional[str] = None,
description: Optional[str] = None,
expires_at: Optional[datetime] = None,
expires_skew_tolerance: Optional[int] = None,
expiration_seconds: Optional[int] = None,
labels: Optional[Dict[str, Optional[str]]] = None,
Expand Down Expand Up @@ -5008,6 +5009,7 @@ def update_service_connector(
If set to the empty string, the existing resource ID will be
removed.
description: The description of the service connector.
expires_at: The new UTC expiration time of the service connector.
expires_skew_tolerance: The allowed expiration skew for the service
connector credentials.
expiration_seconds: The expiration time of the service connector.
Expand Down Expand Up @@ -5074,11 +5076,13 @@ def update_service_connector(
connector_type=connector.connector_type,
description=description or connector_model.description,
auth_method=auth_method or connector_model.auth_method,
expires_at=expires_at,
expires_skew_tolerance=expires_skew_tolerance,
expiration_seconds=expiration_seconds,
user=self.active_user.id,
workspace=self.active_workspace.id,
)

# Validate and configure the resources
if configuration is not None:
# The supplied configuration is a drop-in replacement for the
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,11 +142,6 @@ class AWSSecretKeyConfig(AWSBaseConfig, AWSSecretKey):
class STSTokenConfig(AWSBaseConfig, STSToken):
"""AWS STS token authentication configuration."""

expires_at: Optional[datetime.datetime] = Field(
default=None,
title="AWS STS Token Expiration",
)


class IAMRoleAuthenticationConfig(AWSSecretKeyConfig, AWSSessionPolicy):
"""AWS IAM authentication config."""
Expand Down Expand Up @@ -983,7 +978,7 @@ def _authenticate(
aws_session_token=cfg.aws_session_token.get_secret_value(),
region_name=cfg.region,
)
return session, cfg.expires_at
return session, self.expires_at
elif auth_method in [
AWSAuthenticationMethods.IAM_ROLE,
AWSAuthenticationMethods.SESSION_TOKEN,
Expand Down
Loading