Skip to content

Commit

Permalink
feat(snippetgen): generate mock input for required fields (#941)
Browse files Browse the repository at this point in the history
Generate mock field inputs for all required fields.

* For oneofs, the first option is selected.
* For enums, the last option is selected (as 0 is often unspecified).
* Mock input for string fields that reference resources (`resource_pb2.resource_reference`) is the pattern provided in the protos.

**TODOs:**
- When a resource's pattern is `*`, this results in unhelpful placeholder strings:
    ```proto
    message Asset {
      option (google.api.resource) = {
        type: "cloudasset.googleapis.com/Asset"
        pattern: "*"
      };
    ```

    ```py
    request = asset_v1.BatchGetAssetsHistoryRequest(
       parent="*",
    )
    ```
- Map fields are not yet handled.

**Other changes:**

* Fields are set like `foo.bar` rather than through dict access `foo["bar"]`. This is because the former allows you to set fields nested more than one field deep `foo.bar.baz`.

    **Before**:
    ```py
    output_config = {}
    output_config["gcs_destination"]["uri"] = "uri_value" # fails as there is no nested dict
    ```

    **After**:
    ```py
    output_config = asset_v1.IamPolicyAnalysisOutputConfig()
    output_config.gcs_destination.uri = "uri_value"
    ```
  • Loading branch information
busunkim96 authored Aug 18, 2021
1 parent 49205d9 commit b2149da
Show file tree
Hide file tree
Showing 144 changed files with 1,155 additions and 107 deletions.
78 changes: 67 additions & 11 deletions gapic/samplegen/samplegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def build(
f"Resource {resource_typestr} has no pattern with params: {attr_name_str}"
)

return cls(base=base, body=attrs, single=None, pattern=pattern)
return cls(base=base, body=attrs, single=None, pattern=pattern,)


@dataclasses.dataclass
Expand Down Expand Up @@ -293,17 +293,22 @@ def preprocess_sample(sample, api_schema: api.API, rpc: wrappers.Method):
sample["module_name"] = api_schema.naming.versioned_module_name
sample["module_namespace"] = api_schema.naming.module_namespace

service = api_schema.services[sample["service"]]

# Assume the gRPC transport if the transport is not specified
sample.setdefault("transport", api.TRANSPORT_GRPC)
transport = sample.setdefault("transport", api.TRANSPORT_GRPC)

if sample["transport"] == api.TRANSPORT_GRPC_ASYNC:
sample["client_name"] = api_schema.services[sample["service"]
].async_client_name
else:
sample["client_name"] = api_schema.services[sample["service"]].client_name
is_async = transport == api.TRANSPORT_GRPC_ASYNC
sample["client_name"] = service.async_client_name if is_async else service.client_name

# the type of the request object passed to the rpc e.g, `ListRequest`
sample["request_type"] = rpc.input.ident.name
# the MessageType of the request object passed to the rpc, e.g. `ListRequest`
sample["request_type"] = rpc.input

# If no request was specified in the config
# Add reasonable default values as placeholders
if "request" not in sample:
sample["request"] = generate_request_object(
api_schema, service, rpc.input)

# If no response was specified in the config
# Add reasonable defaults depending on the type of the sample
Expand Down Expand Up @@ -940,6 +945,58 @@ def parse_handwritten_specs(sample_configs: Sequence[str]) -> Generator[Dict[str
yield spec


def generate_request_object(api_schema: api.API, service: wrappers.Service, message: wrappers.MessageType, field_name_prefix: str = "") -> List[Dict[str, Any]]:
    """Generate dummy input for a given message.

    A placeholder is emitted for the first option of every oneof and for
    every required field of ``message``. Message-typed fields are expanded
    recursively, and nested field names are joined with dots.

    Args:
        api_schema (api.API): The schema that defines the API.
        service (wrappers.Service): The service object the message belongs to.
        message (wrappers.MessageType): The message to generate a request object for.
        field_name_prefix (str): A prefix to attach to the field name in the request.

    Returns:
        List[Dict[str, Any]]: A list of dicts that can be turned into TransformedRequests.
    """
    request: List[Dict[str, Any]] = []

    # Choose the first option for each oneof
    selected_oneofs: List[wrappers.Field] = [
        oneof_fields[0] for oneof_fields in message.oneof_fields().values()
    ]

    # A required field may also be the first option of a oneof; keying by
    # name keeps the original ordering while emitting each field only once.
    request_fields: Dict[str, wrappers.Field] = {
        field.name: field
        for field in (*selected_oneofs, *message.required_fields)
    }

    for field in request_fields.values():
        # TransformedRequest expects nested fields to be referenced like
        # `destination.input_config.name`
        if field_name_prefix:
            field_name = f"{field_name_prefix}.{field.name}"
        else:
            field_name = field.name

        # TODO(busunkim): Properly handle map fields
        if field.is_primitive:
            placeholder_value = field.mock_value_original_type
            # If this field identifies a resource use the resource path
            resource_message = service.resource_messages_dict.get(
                field.resource_reference)
            if resource_message:
                placeholder_value = resource_message.resource_path
            request.append({"field": field_name, "value": placeholder_value})
        elif field.enum:
            # Choose the last enum value in the list since index 0 is often "unspecified"
            request.append(
                {"field": field_name, "value": field.enum.values[-1].name})
        else:
            # This is a message type, recurse
            # TODO(busunkim): Some real world APIs have request objects
            # that are recursive. Reference `Field.mock_value` to ensure
            # this always terminates.
            request += generate_request_object(
                api_schema, service, field.type,
                field_name_prefix=field_name,
            )

    return request


def generate_sample_specs(api_schema: api.API, *, opts) -> Generator[Dict[str, Any], None, None]:
"""Given an API, generate basic sample specs for each method.
Expand All @@ -964,8 +1021,7 @@ def generate_sample_specs(api_schema: api.API, *, opts) -> Generator[Dict[str, A
"sample_type": "standalone",
"rpc": rpc_name,
"transport": transport,
"request": [],
# response is populated in `preprocess_sample`
# `request` and `response` are populated in `preprocess_sample`
"service": f"{api_schema.naming.proto_package}.{service_name}",
"region_tag": region_tag,
"description": f"Snippet for {utils.to_snake_case(rpc_name)}"
Expand Down
112 changes: 97 additions & 15 deletions gapic/schema/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
import dataclasses
import re
from itertools import chain
from typing import (cast, Dict, FrozenSet, Iterable, List, Mapping,
from typing import (Any, cast, Dict, FrozenSet, Iterable, List, Mapping,
ClassVar, Optional, Sequence, Set, Tuple, Union)
from google.api import annotations_pb2 # type: ignore
from google.api import client_pb2
Expand Down Expand Up @@ -89,6 +89,17 @@ def map(self) -> bool:
"""Return True if this field is a map, False otherwise."""
return bool(self.repeated and self.message and self.message.map)

@utils.cached_property
def mock_value_original_type(self) -> Union[bool, str, bytes, int, float, List[Any], None]:
    """Return the primitive mock for this field in its Python type.

    A falsy mock is normalized to ``None``. For a repeated field the
    (possibly ``None``) mock is wrapped in a single-element list.
    """
    mock = self.primitive_mock() or None

    # Repeated fields take a list of values rather than a bare value.
    return [mock] if self.repeated else mock

@utils.cached_property
def mock_value(self) -> str:
visited_fields: Set["Field"] = set()
Expand All @@ -100,25 +111,13 @@ def mock_value(self) -> str:

return answer

def inner_mock(self, stack, visited_fields):
def inner_mock(self, stack, visited_fields) -> str:
"""Return a repr of a valid, usually truthy mock value."""
# For primitives, send a truthy value computed from the
# field name.
answer = 'None'
if isinstance(self.type, PrimitiveType):
if self.type.python_type == bool:
answer = 'True'
elif self.type.python_type == str:
answer = f"'{self.name}_value'"
elif self.type.python_type == bytes:
answer = f"b'{self.name}_blob'"
elif self.type.python_type == int:
answer = f'{sum([ord(i) for i in self.name])}'
elif self.type.python_type == float:
answer = f'0.{sum([ord(i) for i in self.name])}'
else: # Impossible; skip coverage checks.
raise TypeError('Unrecognized PrimitiveType. This should '
'never happen; please file an issue.')
answer = self.primitive_mock_as_str()

# If this is an enum, select the first truthy value (or the zero
# value if nothing else exists).
Expand Down Expand Up @@ -158,6 +157,45 @@ def inner_mock(self, stack, visited_fields):
# Done; return the mock value.
return answer

def primitive_mock(self) -> Union[bool, str, bytes, int, float, List[Any], None]:
    """Generate a valid mock for a primitive type.

    The mock is returned in its original (Python) type, derived from the
    field name so that the value is deterministic.

    Returns:
        ``True`` for bool, ``"<name>_value"`` for str, ``b"<name>_blob"``
        for bytes, the sum of the name's code points for int, and that
        same sum written after ``0.`` for float.

    Raises:
        TypeError: If this field's type is not a ``PrimitiveType``, or if
            the primitive's Python type is not one of the types above.
    """
    if not isinstance(self.type, PrimitiveType):
        raise TypeError(f"'primitive_mock' can only be used for "
                        f"PrimitiveType, but type is {self.type}")

    python_type = self.type.python_type

    if python_type == bool:
        return True
    if python_type == str:
        return f"{self.name}_value"
    if python_type == bytes:
        return f"{self.name}_blob".encode("utf-8")
    if python_type == int:
        return sum(ord(char) for char in self.name)
    if python_type == float:
        name_sum = sum(ord(char) for char in self.name)
        # Parse the decimal text instead of multiplying by a power of
        # ten: the multiplication can add binary rounding noise (e.g.
        # 57 * 0.01 == 0.5700000000000001) that shows up in the repr.
        return float(f"0.{name_sum}")

    # Impossible; skip coverage checks.
    raise TypeError('Unrecognized PrimitiveType. This should '
                    'never happen; please file an issue.')

def primitive_mock_as_str(self) -> str:
    """Return the primitive mock rendered as text.

    String mocks are wrapped in single quotes; every other mock is
    converted with ``str()``.
    """
    mock = self.primitive_mock()
    return f"'{mock}'" if isinstance(mock, str) else str(mock)

@property
def proto_type(self) -> str:
"""Return the proto type constant to be used in templates."""
Expand Down Expand Up @@ -186,6 +224,17 @@ def required(self) -> bool:
return (field_behavior_pb2.FieldBehavior.Value('REQUIRED') in
self.options.Extensions[field_behavior_pb2.field_behavior])

@property
def resource_reference(self) -> Optional[str]:
    """Return the resource type this field references, if any.

    The ``type`` of the field's ``resource_reference`` option is
    preferred, then its ``child_type``; ``None`` is returned when both
    are empty.

    This is only applicable for string fields.

    Example: "translate.googleapis.com/Glossary"
    """
    reference = self.options.Extensions[resource_pb2.resource_reference]
    return reference.type or reference.child_type or None

@utils.cached_property
def type(self) -> Union['MessageType', 'EnumType', 'PrimitiveType']:
"""Return the type of this field."""
Expand Down Expand Up @@ -286,6 +335,13 @@ def oneof_fields(self, include_optional=False):

return oneof_fields

@utils.cached_property
def required_fields(self) -> Sequence['Field']:
    """Return the fields of this message whose ``required`` flag is set.

    The result is a list in the iteration order of ``self.fields``.
    """
    return [
        candidate
        for candidate in self.fields.values()
        if candidate.required
    ]

@utils.cached_property
def field_types(self) -> Sequence[Union['MessageType', 'EnumType']]:
answer = tuple(
Expand Down Expand Up @@ -353,6 +409,11 @@ def resource_type(self) -> Optional[str]:
resource = self.options.Extensions[resource_pb2.resource]
return resource.type[resource.type.find('/') + 1:] if resource else None

@property
def resource_type_full_path(self) -> Optional[str]:
    """Return the unabridged ``type`` of this message's resource option.

    ``None`` is returned when the resource option is falsy.
    """
    resource = self.options.Extensions[resource_pb2.resource]
    if not resource:
        return None
    return resource.type

@property
def resource_path_args(self) -> Sequence[str]:
return self.PATH_ARG_RE.findall(self.resource_path or '')
Expand Down Expand Up @@ -1199,6 +1260,27 @@ def gen_indirect_resources_used(message):
)
)

@utils.cached_property
def resource_messages_dict(self) -> Dict[str, MessageType]:
    """Return a dict from resource type to the message type.

    This *includes* the common resource messages, which take precedence
    when a key is present in both sources.

    Returns:
        Dict[str, MessageType]: A mapping from resource type
            string to the corresponding MessageType.
            `{"locations.googleapis.com/Location": MessageType(...)}`
    """
    from_service = {
        message.resource_type_full_path: message
        for message in self.resource_messages
    }
    from_common = {
        resource_type: common.message_type
        for resource_type, common in self.common_resources.items()
    }

    # Later entries win, so common resources override service ones.
    return {**from_service, **from_common}

@utils.cached_property
def any_client_streaming(self) -> bool:
return any(m.client_streaming for m in self.methods.values())
Expand Down
18 changes: 10 additions & 8 deletions gapic/templates/examples/feature_fragments.j2
Original file line number Diff line number Diff line change
Expand Up @@ -126,18 +126,19 @@ with open({{ print_string_formatting(statement["filename"])|trim }}, "wb") as f:
{% macro render_request_attr(base_name, attr) %}
{# Note: python code will have manipulated the value #}
{# to be the correct enum from the right module, if necessary. #}
{# Python is also responsible for verifying that each input parameter is unique,#}
{# Python is also responsible for verifying that each input parameter is unique, #}
{# no parameter is a reserved keyword #}
{% if attr.input_parameter %}

# {{ attr.input_parameter }} = {{ attr.value }}
{% if attr.value_is_file %}
with open({{ attr.input_parameter }}, "rb") as f:
{{ base_name }}["{{ attr.field }}"] = f.read()
{{ base_name }}.{{ attr.field }} = f.read()
{% else %}
{{ base_name }}["{{ attr.field }}"] = {{ attr.input_parameter }}
{{ base_name }}.{{ attr.field }} = {{ attr.input_parameter }}
{% endif %}
{% else %}
{{ base_name }}["{{ attr.field }}"] = {{ attr.value }}
{{ base_name }}.{{ attr.field }} = {{ attr.value }}
{% endif %}
{% endmacro %}

Expand All @@ -159,16 +160,17 @@ client = {{ module_name }}.{{ client_name }}()
{{ parameter_block.base }} = "{{parameter_block.pattern }}".format({{ formals|join(", ") }})
{% endwith %}
{% else %}{# End resource name construction #}
{{ parameter_block.base }} = {}
{{ parameter_block.base }} = {{ module_name }}.{{ request_type.get_field(parameter_block.base).type.name }}()
{% for attr in parameter_block.body %}
{{ render_request_attr(parameter_block.base, attr) }}
{{ render_request_attr(parameter_block.base, attr) -}}
{% endfor %}

{% endif %}
{% endfor %}
{% if not full_request.flattenable %}
request = {{ module_name }}.{{ request_type }}(
request = {{ module_name }}.{{ request_type.ident.name }}(
{% for parameter in full_request.request_list %}
{{ parameter.base }}={{ parameter.base if parameter.body else parameter.single }},
{{ parameter.base }}={{ parameter.base if parameter.body else parameter.single.value }},
{% endfor %}
)
{% endif %}
Expand Down
4 changes: 2 additions & 2 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def unit(session):
"--cov-report=term",
"--cov-fail-under=100",
path.join("tests", "unit"),
]
]
),
)

Expand Down Expand Up @@ -308,7 +308,7 @@ def snippetgen(session):

session.run(
"py.test",
"--quiet",
"-vv",
"tests/snippetgen"
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,11 @@ async def sample_analyze_iam_policy():
client = asset_v1.AssetServiceAsyncClient()

# Initialize request argument(s)
analysis_query = asset_v1.IamPolicyAnalysisQuery()
analysis_query.scope = "scope_value"

request = asset_v1.AnalyzeIamPolicyRequest(
analysis_query=analysis_query,
)

# Make the request
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,15 @@ async def sample_analyze_iam_policy_longrunning():
client = asset_v1.AssetServiceAsyncClient()

# Initialize request argument(s)
analysis_query = asset_v1.IamPolicyAnalysisQuery()
analysis_query.scope = "scope_value"

output_config = asset_v1.IamPolicyAnalysisOutputConfig()
output_config.gcs_destination.uri = "uri_value"

request = asset_v1.AnalyzeIamPolicyLongrunningRequest(
analysis_query=analysis_query,
output_config=output_config,
)

# Make the request
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,15 @@ def sample_analyze_iam_policy_longrunning():
client = asset_v1.AssetServiceClient()

# Initialize request argument(s)
analysis_query = asset_v1.IamPolicyAnalysisQuery()
analysis_query.scope = "scope_value"

output_config = asset_v1.IamPolicyAnalysisOutputConfig()
output_config.gcs_destination.uri = "uri_value"

request = asset_v1.AnalyzeIamPolicyLongrunningRequest(
analysis_query=analysis_query,
output_config=output_config,
)

# Make the request
Expand Down
Loading

0 comments on commit b2149da

Please sign in to comment.