From ca1b751fb4f17e98f8c4baee85e9861c8f6f8bf1 Mon Sep 17 00:00:00 2001 From: Stefan Nica Date: Tue, 14 May 2024 17:21:00 +0200 Subject: [PATCH] Remove expiry from STS token config and make it configurable via CLI --- src/zenml/cli/service_connectors.py | 94 ++++++++++++++++++- src/zenml/cli/utils.py | 2 +- src/zenml/client.py | 4 + .../aws_service_connector.py | 7 +- 4 files changed, 98 insertions(+), 9 deletions(-) diff --git a/src/zenml/cli/service_connectors.py b/src/zenml/cli/service_connectors.py index 2a6c1eabc5a..1197624e522 100644 --- a/src/zenml/cli/service_connectors.py +++ b/src/zenml/cli/service_connectors.py @@ -268,6 +268,65 @@ def prompt_expiration_time( return expiration_seconds +def prompt_expires_at( + default: Optional[datetime] = None, +) -> Optional[datetime]: + """Prompt the user for an expiration timestamp. + + Args: + default: The default expiration time. + + Returns: + The expiration time provided by the user. + """ + if default is None: + confirm = click.confirm( + "Are the credentials you configured temporary? If so, you'll be asked " + "to provide an expiration time in the next step.", + default=False, + ) + if not confirm: + return None + + while True: + default_str = "" + if default is not None: + seconds = int((default - datetime.utcnow()).total_seconds()) + default_str = ( + f" [{str(default)} i.e. in " + f"{seconds_to_human_readable(seconds)}]" + ) + + expires_at = click.prompt( + "Please enter the exact UTC date and time when the credentials " + f"will expire e.g. '2023-12-31 23:59:59'{default_str}", + type=click.DateTime(), + default=default, + show_default=False, + ) + + assert expires_at is not None + assert isinstance(expires_at, datetime) + if expires_at < datetime.utcnow(): + cli_utils.warning( + "The expiration time must be in the future. Please enter a " + "later date and time." + ) + continue + + seconds = int((expires_at - datetime.utcnow()).total_seconds()) + + confirm = click.confirm( + f"Credentials will be valid until {str(expires_at)} UTC (i.e. " + f"in {seconds_to_human_readable(seconds)}. Keep this value?", + default=True, + ) + if confirm: + break + + return expires_at + + @service_connector.command( "register", context_settings={"ignore_unknown_options": True}, @@ -367,6 +426,16 @@ def prompt_expiration_time( required=False, type=str, ) +@click.option( + "--expires-at", + "expires_at", + help="The exact UTC date and time when the credentials configured for this " + "connector will expire. Takes the form 'YYYY-MM-DD HH:MM:SS'. This is only " + "required if you are configuring a service connector with expiring " + "credentials.", + required=False, + type=click.DateTime(), +) @click.option( "--expires-skew-tolerance", "expires_skew_tolerance", @@ -444,6 +513,7 @@ def register_service_connector( resource_type: Optional[str] = None, resource_id: Optional[str] = None, auth_method: Optional[str] = None, + expires_at: Optional[datetime] = None, expires_skew_tolerance: Optional[int] = None, expiration_seconds: Optional[int] = None, no_verify: bool = False, @@ -463,6 +533,8 @@ def register_service_connector( resource_type: The type of resource to connect to. resource_id: The ID of the resource to connect to. auth_method: The authentication method to use. + expires_at: The exact UTC date and time when the credentials configured + for this connector will expire. expires_skew_tolerance: The tolerance, in seconds, allowed when determining when the credentials configured for or generated by this connector will expire. @@ -495,8 +567,6 @@ def register_service_connector( # Parse the given labels parsed_labels = cast(Dict[str, str], cli_utils.get_parsed_labels(labels)) - expires_at: Optional[datetime] = None - if interactive: # Get the list of available service connector types connector_types = client.list_service_connector_types( @@ -764,6 +834,9 @@ def register_service_connector( default=auth_method_spec.default_expiration_seconds, ) + # Prompt for the time when the credentials will expire + expires_at = prompt_expires_at(expires_at) + try: # Validate the connector configuration and fetch all available # resources that are accessible with the provided configuration @@ -781,6 +854,7 @@ def register_service_connector( auth_method=auth_method, resource_type=resource_type, configuration=config_dict, + expires_at=expires_at, expires_skew_tolerance=expires_skew_tolerance, expiration_seconds=expiration_seconds, auto_configure=False, @@ -1176,6 +1250,14 @@ def describe_service_connector( required=False, type=str, ) +@click.option( + "--expires-at", + "expires_at", + help="The time at which the credentials configured for this connector " + "will expire.", + required=False, + type=click.DateTime(), +) @click.option( "--expires-skew-tolerance", "expires_skew_tolerance", @@ -1244,6 +1326,7 @@ def update_service_connector( resource_type: Optional[str] = None, resource_id: Optional[str] = None, auth_method: Optional[str] = None, + expires_at: Optional[datetime] = None, expires_skew_tolerance: Optional[int] = None, expiration_seconds: Optional[int] = None, no_verify: bool = False, @@ -1264,6 +1347,8 @@ def update_service_connector( resource_type: The type of resource to connect to. resource_id: The ID of the resource to connect to. auth_method: The authentication method to use. + expires_at: The time at which the credentials configured for this + connector will expire. expires_skew_tolerance: The tolerance, in seconds, allowed when determining when the credentials configured for or generated by this connector will expire. @@ -1425,6 +1510,9 @@ def update_service_connector( or auth_method_spec.default_expiration_seconds, ) + # Prompt for the time when the credentials will expire + expires_at = prompt_expires_at(expires_at or connector.expires_at) + try: # Validate the connector configuration and fetch all available # resources that are accessible with the provided configuration @@ -1442,6 +1530,7 @@ def update_service_connector( # should be removed in the update if not set here resource_type=resource_type or "", configuration=config_dict, + expires_at=expires_at, # Use zero value to indicate that the expiration time # should be removed in the update if not set here expiration_seconds=expiration_seconds or 0, @@ -1534,6 +1623,7 @@ def update_service_connector( # should be removed in the update if not set here resource_id=resource_id or "", description=description, + expires_at=expires_at, # Use empty string to indicate that the expiration time # should be removed in the update if not set here expiration_seconds=expiration_seconds or 0, diff --git a/src/zenml/cli/utils.py b/src/zenml/cli/utils.py index 8d05f7424b1..a2aaf033bfb 100644 --- a/src/zenml/cli/utils.py +++ b/src/zenml/cli/utils.py @@ -1672,7 +1672,7 @@ def expires_in( expires_at -= datetime.timedelta(seconds=skew_tolerance) if expires_at < now: return expired_str - return seconds_to_human_readable((expires_at - now).seconds) + return seconds_to_human_readable(int((expires_at - now).total_seconds())) def print_service_connectors_table( diff --git a/src/zenml/client.py b/src/zenml/client.py index 4708bb69137..0169844fdf4 100644 --- a/src/zenml/client.py +++ b/src/zenml/client.py @@ -4965,6 +4965,7 @@ def update_service_connector( configuration: Optional[Dict[str, str]] = None, resource_id: Optional[str] = None, description: Optional[str] = None, + expires_at: Optional[datetime] = None, expires_skew_tolerance: Optional[int] = None, expiration_seconds: Optional[int] = None, labels: Optional[Dict[str, Optional[str]]] = None, @@ -5008,6 +5009,7 @@ def update_service_connector( If set to the empty string, the existing resource ID will be removed. description: The description of the service connector. + expires_at: The new UTC expiration time of the service connector. expires_skew_tolerance: The allowed expiration skew for the service connector credentials. expiration_seconds: The expiration time of the service connector. @@ -5074,11 +5076,13 @@ def update_service_connector( connector_type=connector.connector_type, description=description or connector_model.description, auth_method=auth_method or connector_model.auth_method, + expires_at=expires_at, expires_skew_tolerance=expires_skew_tolerance, expiration_seconds=expiration_seconds, user=self.active_user.id, workspace=self.active_workspace.id, ) + # Validate and configure the resources if configuration is not None: # The supplied configuration is a drop-in replacement for the diff --git a/src/zenml/integrations/aws/service_connectors/aws_service_connector.py b/src/zenml/integrations/aws/service_connectors/aws_service_connector.py index 851bf3f07a5..51fde3c17c8 100644 --- a/src/zenml/integrations/aws/service_connectors/aws_service_connector.py +++ b/src/zenml/integrations/aws/service_connectors/aws_service_connector.py @@ -142,11 +142,6 @@ class AWSSecretKeyConfig(AWSBaseConfig, AWSSecretKey): class STSTokenConfig(AWSBaseConfig, STSToken): """AWS STS token authentication configuration.""" - expires_at: Optional[datetime.datetime] = Field( - default=None, - title="AWS STS Token Expiration", - ) - class IAMRoleAuthenticationConfig(AWSSecretKeyConfig, AWSSessionPolicy): """AWS IAM authentication config.""" @@ -983,7 +978,7 @@ def _authenticate( aws_session_token=cfg.aws_session_token.get_secret_value(), region_name=cfg.region, ) - return session, cfg.expires_at + return session, self.expires_at elif auth_method in [ AWSAuthenticationMethods.IAM_ROLE, AWSAuthenticationMethods.SESSION_TOKEN,