Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: fix mTLS logic #374

Merged
merged 2 commits into from
Apr 10, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore
DEFAULT_ENDPOINT
)
DEFAULT_OPTIONS = ClientOptions.ClientOptions(api_endpoint=DEFAULT_ENDPOINT)

@classmethod
def from_service_account_file(cls, filename: str, *args, **kwargs):
Expand Down Expand Up @@ -126,7 +125,7 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
def __init__(self, *,
credentials: credentials.Credentials = None,
transport: Union[str, {{ service.name }}Transport] = None,
client_options: ClientOptions = DEFAULT_OPTIONS,
client_options: ClientOptions = None,
) -> None:
"""Instantiate the {{ (service.client_name|snake_case).replace('_', ' ') }}.

Expand All @@ -143,12 +142,10 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
(1) The ``api_endpoint`` property can be used to override the
default endpoint provided by the client.
(2) If ``transport`` argument is None, ``client_options`` can be
used to create a mutual TLS transport. If ``api_endpoint`` is
provided and different from the default endpoint, or the
``client_cert_source`` property is provided, mutual TLS
transport will be created if client SSL credentials are found.
Client SSL credentials are obtained from ``client_cert_source``
or application default SSL credentials.
used to create a mutual TLS transport. If ``client_cert_source``
is provided, mutual TLS transport will be created with the given
``api_endpoint`` or the default mTLS endpoint, and the client
SSL credentials obtained from ``client_cert_source``.

Raises:
google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport
Expand All @@ -157,10 +154,6 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
if isinstance(client_options, dict):
client_options = ClientOptions.from_dict(client_options)

# Set default api endpoint if not set.
if client_options.api_endpoint is None:
client_options.api_endpoint = self.DEFAULT_ENDPOINT

# Save or instantiate the transport.
# Ordinarily, we provide the transport, but allowing a custom transport
# instance provides an extensibility point for unusual situations.
Expand All @@ -170,24 +163,37 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
raise ValueError('When providing a transport instance, '
'provide its credentials directly.')
self._transport = transport
elif transport is not None or (
client_options.api_endpoint == self.DEFAULT_ENDPOINT
elif client_options is None or (
client_options.api_endpoint == None
and client_options.client_cert_source is None
):
# Don't trigger mTLS.
# Don't trigger mTLS if we get an empty ClientOptions.
Transport = type(self).get_transport_class(transport)
self._transport = Transport(
credentials=credentials, host=client_options.api_endpoint
credentials=credentials, host=self.DEFAULT_ENDPOINT
)
else:
# Trigger mTLS. If the user overrides endpoint, use it as the mTLS
# endpoint, otherwise use the default mTLS endpoint.
option_endpoint = client_options.api_endpoint
api_mtls_endpoint = self.DEFAULT_MTLS_ENDPOINT if option_endpoint == self.DEFAULT_ENDPOINT else option_endpoint
# We have a non-empty ClientOptions. If client_cert_source is
# provided, trigger mTLS with user provided endpoint or the default
# mTLS endpoint.
if client_options.client_cert_source:
api_mtls_endpoint = (
client_options.api_endpoint
if client_options.api_endpoint
else self.DEFAULT_MTLS_ENDPOINT
)
else:
api_mtls_endpoint = None

api_endpoint = (
client_options.api_endpoint
if client_options.api_endpoint
else self.DEFAULT_ENDPOINT
)

self._transport = {{ service.name }}GrpcTransport(
arithmetic1728 marked this conversation as resolved.
Show resolved Hide resolved
credentials=credentials,
host=client_options.api_endpoint,
host=api_endpoint,
api_mtls_endpoint=api_mtls_endpoint,
client_cert_source=client_options.client_cert_source,
)
Expand Down
26 changes: 19 additions & 7 deletions gapic/templates/tests/unit/%name_%version/%sub/test_%service.py.j2
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,6 @@ def test_{{ service.client_name|snake_case }}_from_service_account_file():


def test_{{ service.client_name|snake_case }}_client_options():
# Check the default options have their expected values.
assert {{ service.client_name }}.DEFAULT_OPTIONS.api_endpoint == {% if service.host %}'{{ service.host }}'{% else %}None{% endif %}
assert {{ service.client_name }}.DEFAULT_OPTIONS.api_endpoint == {{ service.client_name }}.DEFAULT_ENDPOINT

# Check that if channel is provided we won't create a new one.
with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.{{ service.client_name }}.get_transport_class') as gtc:
transport = transports.{{ service.name }}GrpcTransport(
Expand All @@ -86,13 +82,14 @@ def test_{{ service.client_name|snake_case }}_client_options():
host=client.DEFAULT_ENDPOINT,
)

# Check mTLS is triggered with api endpoint override.
# Check mTLS is not triggered if api_endpoint is provided but
# client_cert_source is None.
options = client_options.ClientOptions(api_endpoint="squid.clam.whelk")
with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.transports.{{ service.name }}GrpcTransport.__init__') as grpc_transport:
grpc_transport.return_value = None
client = {{ service.client_name }}(client_options=options)
grpc_transport.assert_called_once_with(
api_mtls_endpoint="squid.clam.whelk",
api_mtls_endpoint=None,
client_cert_source=None,
credentials=None,
host="squid.clam.whelk",
Expand All @@ -112,14 +109,29 @@ def test_{{ service.client_name|snake_case }}_client_options():
host=client.DEFAULT_ENDPOINT,
)

# Check mTLS is triggered if api_endpoint and client_cert_source are provided.
options = client_options.ClientOptions(
api_endpoint="squid.clam.whelk",
client_cert_source=client_cert_source_callback
)
with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.transports.{{ service.name }}GrpcTransport.__init__') as grpc_transport:
grpc_transport.return_value = None
client = {{ service.client_name }}(client_options=options)
grpc_transport.assert_called_once_with(
api_mtls_endpoint="squid.clam.whelk",
client_cert_source=client_cert_source_callback,
credentials=None,
host="squid.clam.whelk",
)

def test_{{ service.client_name|snake_case }}_client_options_from_dict():
with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.transports.{{ service.name }}GrpcTransport.__init__') as grpc_transport:
grpc_transport.return_value = None
client = {{ service.client_name }}(
client_options={'api_endpoint': 'squid.clam.whelk'}
)
grpc_transport.assert_called_once_with(
api_mtls_endpoint="squid.clam.whelk",
api_mtls_endpoint=None,
client_cert_source=None,
credentials=None,
host="squid.clam.whelk",
Expand Down