Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: implement grpc transcode for rest transport and complete generated tests #999

Merged
merged 10 commits into from
Sep 30, 2021
14 changes: 14 additions & 0 deletions gapic/schema/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

import collections
import dataclasses
import json
import re
from itertools import chain
from typing import (Any, cast, Dict, FrozenSet, Iterable, List, Mapping,
Expand All @@ -39,6 +40,7 @@
from google.api import http_pb2
from google.api import resource_pb2
from google.api_core import exceptions # type: ignore
from google.api_core import path_template # type: ignore
from google.protobuf import descriptor_pb2 # type: ignore
from google.protobuf.json_format import MessageToDict # type: ignore

Expand Down Expand Up @@ -714,6 +716,18 @@ class HttpRule:
uri: str
body: Optional[str]

@property
def path_fields(self) -> List[Tuple[str, str]]:
"""return list of (name, template) tuples extracted from uri."""
return [(match.group("name"), match.group("template"))
for match in path_template._VARIABLE_RE.finditer(self.uri)]

@property
def sample_request(self) -> str:
"""return json dict for sample request matching the uri template."""
sample = utils.sample_from_path_fields(self.path_fields)
return json.dumps(sample)

@classmethod
def try_parse_http_rule(cls, http_rule) -> Optional['HttpRule']:
method = http_rule.WhichOneof("pattern")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,37 +1,39 @@
from google.auth.transport.requests import AuthorizedSession
import json # type: ignore
import grpc # type: ignore
from google.auth.transport.grpc import SslCredentials # type: ignore
from google.auth import credentials as ga_credentials # type: ignore
from google.api_core import exceptions as core_exceptions # type: ignore
from google.api_core import retry as retries # type: ignore
from google.api_core import rest_helpers # type: ignore
from google.api_core import path_template # type: ignore
from google.api_core import gapic_v1 # type: ignore
from google.api_core import operations_v1
from requests import __version__ as requests_version
from typing import Callable, Dict, Optional, Sequence, Tuple
import warnings
{% extends '_base.py.j2' %}

{% block content %}

import warnings
from typing import Callable, Dict, Optional, Sequence, Tuple
from requests import __version__ as requests_version

{% if service.has_lro %}
from google.api_core import operations_v1
{% endif %}
from google.api_core import gapic_v1 # type: ignore
from google.api_core import retry as retries # type: ignore
from google.api_core import exceptions as core_exceptions # type: ignore
from google.auth import credentials as ga_credentials # type: ignore
from google.auth.transport.grpc import SslCredentials # type: ignore

import grpc # type: ignore

from google.auth.transport.requests import AuthorizedSession

{# TODO(yon-mg): re-add python_import/ python_modules from removed diff/current grpc template code #}
{% filter sort_lines %}
{% for method in service.methods.values() %}
{{ method.input.ident.python_import }}
{{ method.output.ident.python_import }}
{{method.input.ident.python_import}}
{{method.output.ident.python_import}}
{% endfor %}
{% if opts.add_iam_methods %}
from google.iam.v1 import iam_policy_pb2 # type: ignore
from google.iam.v1 import policy_pb2 # type: ignore
{% endif %}
{% endfilter %}

from .base import {{ service.name }}Transport, DEFAULT_CLIENT_INFO as BASE_DEFAULT_CLIENT_INFO
from .base import {{service.name}}Transport, DEFAULT_CLIENT_INFO as BASE_DEFAULT_CLIENT_INFO


DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(
Expand All @@ -40,7 +42,7 @@ DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(
rest_version=requests_version,
)

class {{ service.name }}RestTransport({{ service.name }}Transport):
class {{service.name}}RestTransport({{service.name}}Transport):
"""REST backend transport for {{ service.name }}.

{{ service.meta.doc|rst(width=72, indent=4) }}
Expand All @@ -54,13 +56,15 @@ class {{ service.name }}RestTransport({{ service.name }}Transport):
{# TODO(yon-mg): handle mtls stuff if that's relevant for rest transport #}
def __init__(self, *,
host: str{% if service.host %} = '{{ service.host }}'{% endif %},
credentials: ga_credentials.Credentials = None,
credentials_file: str = None,
scopes: Sequence[str] = None,
client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None,
quota_project_id: Optional[str] = None,
client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
always_use_jwt_access: Optional[bool] = False,
credentials: ga_credentials.Credentials=None,
credentials_file: str=None,
scopes: Sequence[str]=None,
client_cert_source_for_mtls: Callable[[
], Tuple[bytes, bytes]]=None,
quota_project_id: Optional[str]=None,
client_info: gapic_v1.client_info.ClientInfo=DEFAULT_CLIENT_INFO,
always_use_jwt_access: Optional[bool]=False,
url_scheme: str='https',
) -> None:
"""Instantiate the transport.

Expand Down Expand Up @@ -88,6 +92,11 @@ class {{ service.name }}RestTransport({{ service.name }}Transport):
API requests. If ``None``, then default info will be used.
Generally, you only need to set this if you're developing
your own client library.
always_use_jwt_access (Optional[bool]): Whether self signed JWT should
be used for service account credentials.
url_scheme: the protocol scheme for the API endpoint. Normally
"https", but for testing or local servers,
"http" can be specified.
"""
# Run the base constructor
# TODO(yon-mg): resolve other ctor params i.e. scopes, quota, etc.
Expand All @@ -99,7 +108,8 @@ class {{ service.name }}RestTransport({{ service.name }}Transport):
client_info=client_info,
always_use_jwt_access=always_use_jwt_access,
)
self._session = AuthorizedSession(self._credentials, default_host=self.DEFAULT_HOST)
self._session = AuthorizedSession(
self._credentials, default_host=self.DEFAULT_HOST)
{% if service.has_lro %}
self._operations_client = None
{% endif %}
Expand Down Expand Up @@ -136,16 +146,17 @@ class {{ service.name }}RestTransport({{ service.name }}Transport):

# Return the client from cache.
return self._operations_client


{% endif %}
{% for method in service.methods.values() %}
{% if method.http_opt %}

def {{ method.name|snake_case }}(self,
request: {{ method.input.ident }}, *,
retry: retries.Retry = gapic_v1.method.DEFAULT,
timeout: float = None,
metadata: Sequence[Tuple[str, str]] = (),
) -> {{ method.output.ident }}:
{%- if method.http_options and not method.lro and not (method.server_streaming or method.client_streaming) %}
def _{{method.name | snake_case}}(self,
request: {{method.input.ident}}, *,
retry: retries.Retry=gapic_v1.method.DEFAULT,
timeout: float=None,
metadata: Sequence[Tuple[str, str]]=(),
) -> {{method.output.ident}}:
r"""Call the {{- ' ' -}}
{{ (method.name|snake_case).replace('_',' ')|wrap(
width=70, offset=45, indent=8) }}
Expand All @@ -168,62 +179,57 @@ class {{ service.name }}RestTransport({{ service.name }}Transport):
{% endif %}
"""

{# TODO(yon-mg): refactor when implementing grpc transcoding
- parse request pb & assign body, path params
- shove leftovers into query params
- make sure dotted nested fields preserved
- format url and send the request
#}
{% if 'body' in method.http_opt %}
http_options = [
{%- for rule in method.http_options %}{
'method': '{{ rule.method }}',
'uri': '{{ rule.uri }}',
{%- if rule.body %}
'body': '{{ rule.body }}',
{%- endif %}
},
{%- endfor %}]

request_kwargs = {{method.input.ident}}.to_dict(request)
transcoded_request = path_template.transcode(
http_options, **request_kwargs)

{% set body_spec = method.http_options[0].body %}
{%- if body_spec %}

# Jsonify the request body
{% if method.http_opt['body'] != '*' %}
body = {{ method.input.fields[method.http_opt['body']].type.ident }}.to_json(
request.{{ method.http_opt['body'] }},
body = {% if body_spec == '*' -%}
{{method.input.ident}}.to_json(
{{method.input.ident}}(transcoded_request['body']),
{%- else -%}
{{method.input.fields[body_spec].type.ident}}.to_json(
{{method.input.fields[body_spec].type.ident}}(
transcoded_request['body']),
{%- endif %}

including_default_value_fields=False,
use_integers_for_enums=False
)
{% else %}
body = {{ method.input.ident }}.to_json(
request,
use_integers_for_enums=False
)
{% endif %}
{% endif %}
{%- endif %}{# body_spec #}
kbandes marked this conversation as resolved.
Show resolved Hide resolved

{# TODO(yon-mg): Write helper method for handling grpc transcoding url #}
# TODO(yon-mg): need to handle grpc transcoding and parse url correctly
# current impl assumes basic case of grpc transcoding
url = 'https://{host}{{ method.http_opt['url'] }}'.format(
host=self._host,
{% for field in method.path_params %}
{{ field }}=request.{{ method.input.get_field(field).name }},
{% endfor %}
)
uri = transcoded_request['uri']
method = transcoded_request['method']

{# TODO(yon-mg): move all query param logic out of wrappers into here to handle
nested fields correctly (can't just use set of top level fields
#}
# TODO(yon-mg): handle nested fields correctly rather than using only top level fields
# not required for GCE
query_params = {}
{% for field in method.query_params | sort%}
{% if method.input.fields[field].proto3_optional %}
if {{ method.input.ident }}.{{ field }} in request:
query_params['{{ field|camel_case }}'] = request.{{ field }}
{% else %}
query_params['{{ field|camel_case }}'] = request.{{ field }}
{% endif %}
{% endfor %}
# Jsonify the query params
query_params = json.loads({{method.input.ident}}.to_json(
{{method.input.ident}}(transcoded_request['query_params']),
including_default_value_fields=False,
use_integers_for_enums=False
))

# Send the request
headers = dict(metadata)
headers['Content-Type'] = 'application/json'
response = self._session.{{ method.http_opt['verb'] }}(
url,
response=getattr(self._session, method)(
uri,
timeout=timeout,
headers=headers,
params=query_params,
{% if 'body' in method.http_opt %}
params=rest_helpers.flatten_query_params(query_params),
{% if body_spec %}
data=body,
{% endif %}
)
Expand All @@ -235,16 +241,50 @@ class {{ service.name }}RestTransport({{ service.name }}Transport):
{% if not method.void %}

# Return the response
return {{ method.output.ident }}.from_json(
return {{method.output.ident}}.from_json(
response.content,
ignore_unknown_fields=True
)
{% endif %}
{% endif %}
{% else %}

def _{{method.name | snake_case}}(self,
request: {{method.input.ident}}, *,
metadata: Sequence[Tuple[str, str]]=(),
) -> {{method.output.ident}}:
r"""Placeholder: Unable to implement over REST
"""
{%- if not method.http_options %}

raise RuntimeError(
"Cannot define a method without a valid 'google.api.http' annotation.")
{%- elif method.lro %}

raise NotImplementedError(
"LRO over REST is not yet defined for python client.")
{%- elif method.server_streaming or method.client_streaming %}

raise NotImplementedError(
"Streaming over REST is not yet defined for python client")
{%- else %}

raise NotImplementedError()
{%- endif %}
{%- endif %}


{% endfor %}
{%- for method in service.methods.values() %}

@ property
def {{method.name | snake_case}}(self) -> Callable[
[{{method.input.ident}}],
{{method.output.ident}}]:
return self._{{method.name | snake_case}}
{%- endfor %}


__all__ = (
__all__=(
'{{ service.name }}RestTransport',
)
{% endblock %}
Loading