Skip to content

Commit

Permalink
feat: precache wrapped rpcs (#553)
Browse files Browse the repository at this point in the history
During transport construction, cache the wrapped methods that the
client will eventually use when invoking rpcs.

This has a ~7.4% time impact in synthetic benchmarks.
  • Loading branch information
software-dov authored Jul 27, 2020
1 parent 5239ca8 commit 2f2fb5d
Show file tree
Hide file tree
Showing 10 changed files with 109 additions and 61 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,7 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
client_cert_source=client_options.client_cert_source,
)


{% for method in service.methods.values() -%}
def {{ method.name|snake_case }}(self,
{%- if not method.client_streaming %}
Expand Down Expand Up @@ -307,25 +308,7 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):

# Wrap the RPC method; this adds retry and timeout information,
# and friendly error handling.
rpc = gapic_v1.method.wrap_method(
self._transport.{{ method.name|snake_case }},
{%- if method.retry %}
default_retry=retries.Retry(
{% if method.retry.initial_backoff %}initial={{ method.retry.initial_backoff }},{% endif %}
{% if method.retry.max_backoff %}maximum={{ method.retry.max_backoff }},{% endif %}
{% if method.retry.backoff_multiplier %}multiplier={{ method.retry.backoff_multiplier }},{% endif %}
predicate=retries.if_exception_type(
{%- filter sort_lines %}
{%- for ex in method.retry.retryable_exceptions %}
exceptions.{{ ex.__name__ }},
{%- endfor %}
{%- endfilter %}
),
),
{%- endif %}
default_timeout={{ method.timeout }},
client_info=_client_info,
)
rpc = self._transport._wrapped_methods[self._transport.{{ method.name|snake_case}}]
{%- if method.field_headers %}

# Certain fields should be provided within the metadata header;
Expand Down Expand Up @@ -381,16 +364,6 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
{% endfor %}


try:
_client_info = gapic_v1.client_info.ClientInfo(
gapic_version=pkg_resources.get_distribution(
'{{ api.naming.warehouse_package_name }}',
).version,
)
except pkg_resources.DistributionNotFound:
_client_info = gapic_v1.client_info.ClientInfo()


__all__ = (
'{{ service.client_name }}',
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
{% block content %}
import abc
import typing
import pkg_resources

from google import auth
from google.api_core import gapic_v1 # type: ignore
{%- if service.has_lro %}
from google.api_core import operations_v1 # type: ignore
{%- endif %}
Expand All @@ -17,6 +19,16 @@ from google.auth import credentials # type: ignore
{% endfor -%}
{% endfilter %}

try:
_client_info = gapic_v1.client_info.ClientInfo(
gapic_version=pkg_resources.get_distribution(
'{{ api.naming.warehouse_package_name }}',
).version,
)
except pkg_resources.DistributionNotFound:
_client_info = gapic_v1.client_info.ClientInfo()


class {{ service.name }}Transport(metaclass=abc.ABCMeta):
"""Abstract transport class for {{ service.name }}."""

Expand Down Expand Up @@ -54,6 +66,37 @@ class {{ service.name }}Transport(metaclass=abc.ABCMeta):

# Save the credentials.
self._credentials = credentials

# Lifted into its own function so it can be stubbed out during tests.
self._prep_wrapped_messages()

def _prep_wrapped_messages(self):
# Precomputed wrapped methods
self._wrapped_methods = {
{% for method in service.methods.values() -%}
self.{{ method.name|snake_case }}: gapic_v1.method.wrap_method(
self.{{ method.name|snake_case }},
{%- if method.retry %}
default_retry=retries.Retry(
{% if method.retry.initial_backoff %}initial={{ method.retry.initial_backoff }},{% endif %}
{% if method.retry.max_backoff %}maximum={{ method.retry.max_backoff }},{% endif %}
{% if method.retry.backoff_multiplier %}multiplier={{ method.retry.backoff_multiplier }},{% endif %}
predicate=retries.if_exception_type(
{%- filter sort_lines %}
{%- for ex in method.retry.retryable_exceptions %}
exceptions.{{ ex.__name__ }},
{%- endfor %}
{%- endfilter %}
),
),
{%- endif %}
default_timeout={{ method.timeout }},
client_info=_client_info,
),
{% endfor %} {# precomputed wrappers loop #}
}


{%- if service.has_lro %}

@property
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,10 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport):
scopes=self.AUTH_SCOPES,
)

self._stubs = {} # type: Dict[str, Callable]

# Run the base constructor.
super().__init__(host=host, credentials=credentials)
self._stubs = {} # type: Dict[str, Callable]


@classmethod
Expand Down
2 changes: 1 addition & 1 deletion gapic/ads-templates/setup.py.j2
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ setuptools.setup(
'google-api-core >= 1.17.0, < 2.0.0dev',
'googleapis-common-protos >= 1.5.8',
'grpcio >= 1.10.0',
'proto-plus >= 1.1.0',
'proto-plus >= 1.4.0',
{%- if api.requires_package(('google', 'iam', 'v1')) %}
'grpc-google-iam-v1',
{%- endif %}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def test_{{ service.client_name|snake_case }}_client_options():


def test_{{ service.client_name|snake_case }}_client_options_from_dict():
with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.transports.{{ service.name }}GrpcTransport.__init__') as grpc_transport:
with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.transports.{{ service.name }}Transport.__init__') as grpc_transport:
grpc_transport.return_value = None
client = {{ service.client_name }}(
client_options={'api_endpoint': 'squid.clam.whelk'}
Expand Down Expand Up @@ -556,9 +556,11 @@ def test_transport_grpc_default():

def test_{{ service.name|snake_case }}_base_transport():
# Instantiate the base transport.
transport = transports.{{ service.name }}Transport(
credentials=credentials.AnonymousCredentials(),
)
with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.transports.{{ service.name }}GrpcTransport.__init__') as Transport:
Transport.return_value = None
transport = transports.{{ service.name }}Transport(
credentials=credentials.AnonymousCredentials(),
)

# Every method on the transport should just blindly
# raise NotImplementedError.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
quota_project_id=client_options.quota_project_id,
)


{% for method in service.methods.values() -%}
def {{ method.name|snake_case }}(self,
{%- if not method.client_streaming %}
Expand Down Expand Up @@ -317,25 +318,7 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):

# Wrap the RPC method; this adds retry and timeout information,
# and friendly error handling.
rpc = gapic_v1.method.wrap_method(
self._transport.{{ method.name|snake_case }},
{%- if method.retry %}
default_retry=retries.Retry(
{% if method.retry.initial_backoff %}initial={{ method.retry.initial_backoff }},{% endif %}
{% if method.retry.max_backoff %}maximum={{ method.retry.max_backoff }},{% endif %}
{% if method.retry.backoff_multiplier %}multiplier={{ method.retry.backoff_multiplier }},{% endif %}
predicate=retries.if_exception_type(
{%- filter sort_lines %}
{%- for ex in method.retry.retryable_exceptions %}
exceptions.{{ ex.__name__ }},
{%- endfor %}
{%- endfilter %}
),
),
{%- endif %}
default_timeout={{ method.timeout }},
client_info=_client_info,
)
rpc = self._transport._wrapped_methods[self._transport.{{ method.name|snake_case}}]
{%- if method.field_headers %}

# Certain fields should be provided within the metadata header;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
{% block content %}
import abc
import typing
import pkg_resources

from google import auth
from google.api_core import exceptions # type: ignore
from google.api_core import gapic_v1 # type: ignore
{%- if service.has_lro %}
from google.api_core import operations_v1 # type: ignore
{%- endif %}
Expand All @@ -22,6 +24,15 @@ from google.iam.v1 import policy_pb2 as policy # type: ignore
{% endif %}
{% endfilter %}

try:
_client_info = gapic_v1.client_info.ClientInfo(
gapic_version=pkg_resources.get_distribution(
'{{ api.naming.warehouse_package_name }}',
).version,
)
except pkg_resources.DistributionNotFound:
_client_info = gapic_v1.client_info.ClientInfo()

class {{ service.name }}Transport(abc.ABC):
"""Abstract transport class for {{ service.name }}."""

Expand Down Expand Up @@ -79,6 +90,38 @@ class {{ service.name }}Transport(abc.ABC):

# Save the credentials.
self._credentials = credentials

# Lifted into its own function so it can be stubbed out during tests.
self._prep_wrapped_messages()


def _prep_wrapped_messages(self):
# Precompute the wrapped methods.
self._wrapped_methods = {
{% for method in service.methods.values() -%}
self.{{ method.name|snake_case }}: gapic_v1.method.wrap_method(
self.{{ method.name|snake_case }},
{%- if method.retry %}
default_retry=retries.Retry(
{% if method.retry.initial_backoff %}initial={{ method.retry.initial_backoff }},{% endif %}
{% if method.retry.max_backoff %}maximum={{ method.retry.max_backoff }},{% endif %}
{% if method.retry.backoff_multiplier %}multiplier={{ method.retry.backoff_multiplier }},{% endif %}
predicate=retries.if_exception_type(
{%- filter sort_lines %}
{%- for ex in method.retry.retryable_exceptions %}
exceptions.{{ ex.__name__ }},
{%- endfor %}
{%- endfilter %}
),
),
{%- endif %}
default_timeout={{ method.timeout }},
client_info=_client_info,
),
{% endfor %} {# precomputed wrappers loop #}
}


{%- if service.has_lro %}

@property
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,8 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport):
quota_project_id=quota_project_id,
)

self._stubs = {} # type: Dict[str, Callable]

# Run the base constructor.
super().__init__(
host=host,
Expand All @@ -127,8 +129,6 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport):
quota_project_id=quota_project_id,
)

self._stubs = {} # type: Dict[str, Callable]

@classmethod
def create_channel(cls,
host: str{% if service.host %} = '{{ service.host }}'{% endif %},
Expand Down
2 changes: 1 addition & 1 deletion gapic/templates/setup.py.j2
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ setuptools.setup(
install_requires=(
'google-api-core[grpc] >= 1.22.0, < 2.0.0dev',
'libcst >= 0.2.5',
'proto-plus >= 1.1.0',
'proto-plus >= 1.4.0',
{%- if api.requires_package(('google', 'iam', 'v1')) or opts.add_iam_methods %}
'grpc-google-iam-v1',
{%- endif %}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1008,9 +1008,11 @@ def test_{{ service.name|snake_case }}_base_transport_error():

def test_{{ service.name|snake_case }}_base_transport():
# Instantiate the base transport.
transport = transports.{{ service.name }}Transport(
credentials=credentials.AnonymousCredentials(),
)
with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.transports.{{ service.name }}Transport.__init__') as Transport:
Transport.return_value = None
transport = transports.{{ service.name }}Transport(
credentials=credentials.AnonymousCredentials(),
)

# Every method on the transport should just blindly
# raise NotImplementedError.
Expand Down Expand Up @@ -1038,7 +1040,8 @@ def test_{{ service.name|snake_case }}_base_transport():

def test_{{ service.name|snake_case }}_base_transport_with_credentials_file():
# Instantiate the base transport with a credentials file
with mock.patch.object(auth, 'load_credentials_from_file') as load_creds:
with mock.patch.object(auth, 'load_credentials_from_file') as load_creds, mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.transports.{{ service.name }}Transport._prep_wrapped_messages') as Transport:
Transport.return_value = None
load_creds.return_value = (credentials.AnonymousCredentials(), None)
transport = transports.{{ service.name }}Transport(
credentials_file="credentials.json",
Expand Down

0 comments on commit 2f2fb5d

Please sign in to comment.