diff --git a/gapic/templates/%namespace/%name_%version/%sub/services/%service/client.py.j2 b/gapic/templates/%namespace/%name_%version/%sub/services/%service/client.py.j2 index f0b7d05381..5c5fd5157b 100644 --- a/gapic/templates/%namespace/%name_%version/%sub/services/%service/client.py.j2 +++ b/gapic/templates/%namespace/%name_%version/%sub/services/%service/client.py.j2 @@ -91,7 +91,6 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta): DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore DEFAULT_ENDPOINT ) - DEFAULT_OPTIONS = ClientOptions.ClientOptions(api_endpoint=DEFAULT_ENDPOINT) @classmethod def from_service_account_file(cls, filename: str, *args, **kwargs): @@ -126,7 +125,7 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta): def __init__(self, *, credentials: credentials.Credentials = None, transport: Union[str, {{ service.name }}Transport] = None, - client_options: ClientOptions = DEFAULT_OPTIONS, + client_options: ClientOptions = None, ) -> None: """Instantiate the {{ (service.client_name|snake_case).replace('_', ' ') }}. @@ -143,12 +142,10 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta): (1) The ``api_endpoint`` property can be used to override the default endpoint provided by the client. (2) If ``transport`` argument is None, ``client_options`` can be - used to create a mutual TLS transport. If ``api_endpoint`` is - provided and different from the default endpoint, or the - ``client_cert_source`` property is provided, mutual TLS - transport will be created if client SSL credentials are found. - Client SSL credentials are obtained from ``client_cert_source`` - or application default SSL credentials. + used to create a mutual TLS transport. If ``client_cert_source`` + is provided, mutual TLS transport will be created with the given + ``api_endpoint`` or the default mTLS endpoint, and the client + SSL credentials obtained from ``client_cert_source``. Raises: google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport @@ -157,10 +154,6 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta): if isinstance(client_options, dict): client_options = ClientOptions.from_dict(client_options) - # Set default api endpoint if not set. - if client_options.api_endpoint is None: - client_options.api_endpoint = self.DEFAULT_ENDPOINT - # Save or instantiate the transport. # Ordinarily, we provide the transport, but allowing a custom transport # instance provides an extensibility point for unusual situations. @@ -170,24 +163,37 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta): raise ValueError('When providing a transport instance, ' 'provide its credentials directly.') self._transport = transport - elif transport is not None or ( - client_options.api_endpoint == self.DEFAULT_ENDPOINT + elif client_options is None or ( + client_options.api_endpoint == None and client_options.client_cert_source is None ): - # Don't trigger mTLS. + # Don't trigger mTLS if we get an empty ClientOptions. Transport = type(self).get_transport_class(transport) self._transport = Transport( - credentials=credentials, host=client_options.api_endpoint + credentials=credentials, host=self.DEFAULT_ENDPOINT ) else: - # Trigger mTLS. If the user overrides endpoint, use it as the mTLS - # endpoint, otherwise use the default mTLS endpoint. - option_endpoint = client_options.api_endpoint - api_mtls_endpoint = self.DEFAULT_MTLS_ENDPOINT if option_endpoint == self.DEFAULT_ENDPOINT else option_endpoint + # We have a non-empty ClientOptions. If client_cert_source is + # provided, trigger mTLS with user provided endpoint or the default + # mTLS endpoint. + if client_options.client_cert_source: + api_mtls_endpoint = ( + client_options.api_endpoint + if client_options.api_endpoint + else self.DEFAULT_MTLS_ENDPOINT + ) + else: + api_mtls_endpoint = None + + api_endpoint = ( + client_options.api_endpoint + if client_options.api_endpoint + else self.DEFAULT_ENDPOINT + ) self._transport = {{ service.name }}GrpcTransport( credentials=credentials, - host=client_options.api_endpoint, + host=api_endpoint, api_mtls_endpoint=api_mtls_endpoint, client_cert_source=client_options.client_cert_source, ) diff --git a/gapic/templates/tests/unit/%name_%version/%sub/test_%service.py.j2 b/gapic/templates/tests/unit/%name_%version/%sub/test_%service.py.j2 index e55cc99909..4e3e89a329 100644 --- a/gapic/templates/tests/unit/%name_%version/%sub/test_%service.py.j2 +++ b/gapic/templates/tests/unit/%name_%version/%sub/test_%service.py.j2 @@ -64,10 +64,6 @@ def test_{{ service.client_name|snake_case }}_from_service_account_file(): def test_{{ service.client_name|snake_case }}_client_options(): - # Check the default options have their expected values. - assert {{ service.client_name }}.DEFAULT_OPTIONS.api_endpoint == {% if service.host %}'{{ service.host }}'{% else %}None{% endif %} - assert {{ service.client_name }}.DEFAULT_OPTIONS.api_endpoint == {{ service.client_name }}.DEFAULT_ENDPOINT - # Check that if channel is provided we won't create a new one. with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.{{ service.client_name }}.get_transport_class') as gtc: transport = transports.{{ service.name }}GrpcTransport( @@ -86,13 +82,14 @@ def test_{{ service.client_name|snake_case }}_client_options(): host=client.DEFAULT_ENDPOINT, ) - # Check mTLS is triggered with api endpoint override. + # Check mTLS is not triggered if api_endpoint is provided but + # client_cert_source is None. options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.transports.{{ service.name }}GrpcTransport.__init__') as grpc_transport: grpc_transport.return_value = None client = {{ service.client_name }}(client_options=options) grpc_transport.assert_called_once_with( - api_mtls_endpoint="squid.clam.whelk", + api_mtls_endpoint=None, client_cert_source=None, credentials=None, host="squid.clam.whelk", @@ -112,6 +109,21 @@ def test_{{ service.client_name|snake_case }}_client_options(): host=client.DEFAULT_ENDPOINT, ) + # Check mTLS is triggered if api_endpoint and client_cert_source are provided. + options = client_options.ClientOptions( + api_endpoint="squid.clam.whelk", + client_cert_source=client_cert_source_callback + ) + with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.transports.{{ service.name }}GrpcTransport.__init__') as grpc_transport: + grpc_transport.return_value = None + client = {{ service.client_name }}(client_options=options) + grpc_transport.assert_called_once_with( + api_mtls_endpoint="squid.clam.whelk", + client_cert_source=client_cert_source_callback, + credentials=None, + host="squid.clam.whelk", + ) + def test_{{ service.client_name|snake_case }}_client_options_from_dict(): with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.transports.{{ service.name }}GrpcTransport.__init__') as grpc_transport: grpc_transport.return_value = None @@ -119,7 +131,7 @@ def test_{{ service.client_name|snake_case }}_client_options_from_dict(): client_options={'api_endpoint': 'squid.clam.whelk'} ) grpc_transport.assert_called_once_with( - api_mtls_endpoint="squid.clam.whelk", + api_mtls_endpoint=None, client_cert_source=None, credentials=None, host="squid.clam.whelk",