Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix deserialization error for LRO which has discriminator #2589

Closed
wants to merge 52 commits into from
Closed
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
52 commits
Select commit Hold shift + click to select a range
4c12c62
code
msyyc May 22, 2024
49ab757
fix for legacy test
msyyc May 22, 2024
3c22f8e
inv and black
msyyc May 22, 2024
7adc1c8
fix mypy
msyyc May 22, 2024
ca98be8
fix pyright error
msyyc May 22, 2024
e142d48
fix pylint
msyyc May 22, 2024
1857677
inv
msyyc May 22, 2024
dccdd16
update
msyyc May 23, 2024
c97c28c
review
msyyc May 23, 2024
2eea3a4
fix
msyyc May 23, 2024
a718d6a
fix
msyyc May 23, 2024
ebf0e16
Fix test
msyyc May 23, 2024
738caab
fix multiapi test
msyyc May 23, 2024
76c8d03
disable deserialize for all initial operation
msyyc May 24, 2024
d163f3d
review
msyyc May 24, 2024
133bf7b
inv
msyyc May 24, 2024
f830afe
Merge branch 'main' of https://github.com/Azure/autorest.python into …
msyyc May 24, 2024
f464d47
update changelog
msyyc May 24, 2024
0647135
inv
msyyc May 24, 2024
c0f0610
Merge branch 'deserialization-fix' of https://github.com/Azure/autore…
msyyc May 24, 2024
148e934
inv
msyyc May 24, 2024
52c13db
Merge branch 'main' into deserialization-fix
msyyc May 28, 2024
8b7f073
Merge branch 'main' into deserialization-fix
msyyc May 29, 2024
24d19a1
force initial operation to return stream
May 29, 2024
fa2c5ce
revert extra changes in builder_serializer
May 29, 2024
41f1d09
regen
May 29, 2024
3a04d08
regen lropaging
May 29, 2024
c92b238
regen with load_body for aiohttp
May 29, 2024
f71903c
fix
msyyc May 30, 2024
49d85c3
inv
msyyc May 30, 2024
1602dc1
use pipeline_response.http_response for legacy
msyyc May 30, 2024
015d844
fix test
msyyc May 30, 2024
bcbc2d3
inv
msyyc May 30, 2024
44017d5
Merge branch 'main' of https://github.com/Azure/autorest.python into …
Jun 4, 2024
f7e5be6
Merge branch 'deserialization-fix' of https://github.com/Azure/autore…
Jun 4, 2024
f0eae62
read in response
Jun 4, 2024
d8ef34f
inv
msyyc Jun 5, 2024
b086c62
fix multiapi test
msyyc Jun 5, 2024
066eb98
inv
msyyc Jun 5, 2024
cfb50f5
fix pyright
msyyc Jun 5, 2024
df5cf4c
simplify code
Jun 5, 2024
cfe3dbd
generate
Jun 5, 2024
e805530
regen
Jun 5, 2024
eb9250a
regenerate
Jun 5, 2024
f6577fb
black
Jun 5, 2024
d70e0e9
Merge branch 'main' of https://github.com/Azure/autorest.python into …
Jun 6, 2024
d71c77b
regen to revert changes
Jun 6, 2024
a3f8f22
revert changes
Jun 6, 2024
91d8643
regen
Jun 6, 2024
6681433
Merge branch 'main' of https://github.com/Azure/autorest.python into …
Jun 7, 2024
01b89e8
regen
Jun 7, 2024
c0e2522
revert tasks change
Jun 7, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions .chronus/changes/deserialization-fix-2024-4-22-17-3-4.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
---
changeKind: fix
packages:
- "@autorest/python"
- "@azure-tools/typespec-python"
---

Fix deserialization error for lro when return type has discriminator and succeed in initial response
4 changes: 4 additions & 0 deletions packages/autorest.python/autorest/codegen/models/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,10 @@ def has_response_body(self) -> bool:
"""Tell if at least one response has a body."""
return any(response.type for response in self.responses)

@property
def has_stream_kwargs(self) -> bool:
return self.expose_stream_keyword and self.has_response_body

@property
def any_response_has_headers(self) -> bool:
return any(response.headers for response in self.responses)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -562,7 +562,7 @@ def make_pipeline_call(self, builder: OperationType) -> List[str]:
type_ignore = self.async_mode and builder.group_name == "" # is in a mixin
stream_value = (
f'kwargs.pop("stream", {builder.has_stream_response})'
if builder.expose_stream_keyword and builder.has_response_body
if builder.has_stream_kwargs
else builder.has_stream_response
)
return [
Expand Down Expand Up @@ -897,11 +897,22 @@ def _call_request_builder_helper( # pylint: disable=too-many-statements
def call_request_builder(self, builder: OperationType, is_paging: bool = False) -> List[str]:
return self._call_request_builder_helper(builder, builder.request_builder, is_paging=is_paging)

@property
def deserialize_for_stream_res(self) -> str:
if self.code_model.options["version_tolerant"]:
return "response.iter_bytes()"
return (
"(await response.load_body()) or response._content # pylint: disable=protected-access"
if self.async_mode
else f"response.stream_download(self._client.{self.pipeline_name})"
)

def response_headers_and_deserialization(
self,
builder: OperationType,
response: Response,
) -> List[str]:
# pylint: disable=too-many-statements
retval: List[str] = [
(
f"response_headers['{response_header.wire_name}']=self._deserialize("
Expand All @@ -918,19 +929,23 @@ def response_headers_and_deserialization(
deserialized = f"{'await ' if self.async_mode else ''}response.read()"
else:
stream_logic = False
if self.code_model.options["version_tolerant"]:
deserialized = "response.iter_bytes()"
else:
deserialized = f"response.stream_download(self._client.{self.pipeline_name})"
deserialized = self.deserialize_for_stream_res
deserialize_code.append(f"deserialized = {deserialized}")
elif response.type:
pylint_disable = ""
if isinstance(response.type, ModelType) and response.type.internal:
pylint_disable = " # pylint: disable=protected-access"
if self.code_model.options["models_mode"] == "msrest":
if hasattr(builder, "initial_operation") and builder.initial_operation.has_stream_kwargs: # type: ignore # pylint: disable=line-too-long
response_name = "_response"
deserialize_code.append(
"_response = pipeline_response if getattr(pipeline_response, 'context', {}) else pipeline_response.http_response" # pylint: disable=line-too-long
)
else:
response_name = "pipeline_response"
deserialize_code.append("deserialized = self._deserialize(")
deserialize_code.append(f" '{response.serialization_type}',{pylint_disable}")
deserialize_code.append(" pipeline_response")
deserialize_code.append(f" {response_name}")
deserialize_code.append(")")
elif self.code_model.options["models_mode"] == "dpg":
if builder.has_stream_response:
Expand Down Expand Up @@ -959,7 +974,7 @@ def response_headers_and_deserialization(
if len(deserialize_code) > 0:
if builder.expose_stream_keyword and stream_logic:
retval.append("if _stream:")
retval.append(" deserialized = response.iter_bytes()")
retval.append(f" deserialized = {self.deserialize_for_stream_res}")
retval.append("else:")
retval.extend([f" {dc}" for dc in deserialize_code])
else:
Expand All @@ -969,11 +984,15 @@ def response_headers_and_deserialization(
def handle_error_response(self, builder: OperationType) -> List[str]:
async_await = "await " if self.async_mode else ""
retval = [f"if response.status_code not in {str(builder.success_status_codes)}:"]
if not self.code_model.need_request_converter:
need_download = (
builder.has_stream_kwargs and self.async_mode and not self.code_model.options["version_tolerant"]
)
if not self.code_model.need_request_converter or need_download:
load_func = "load_body" if need_download else "read"
retval.extend(
[
" if _stream:",
f" {async_await} response.read() # Load the body in memory and close the socket",
f" {async_await} response.{load_func}() # Load the body in memory and close the socket",
]
)
type_ignore = " # type: ignore" if _need_type_ignore(builder) else ""
Expand Down Expand Up @@ -1320,6 +1339,8 @@ def initial_call(self, builder: LROOperationType) -> List[str]:
[f" {parameter.client_name}={parameter.client_name}," for parameter in builder.parameters.method]
)
retval.append(" cls=lambda x,y,z: x,")
if builder.initial_operation.has_stream_kwargs:
msyyc marked this conversation as resolved.
Show resolved Hide resolved
retval.append(" stream=True,")
retval.append(" headers=_headers,")
retval.append(" params=_params,")
retval.append(" **kwargs")
Expand Down
4 changes: 4 additions & 0 deletions packages/autorest.python/autorest/preprocess/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,10 @@ def update_lro_operation(
self._update_lro_operation_helper(overload)
self.update_operation(code_model, overload["initialOperation"], is_overload=True)

# for lro initial reponse, there is no need to make deserialization so we mark it
# as stream operation by default which will not make deserialization by default
yaml_data["initialOperation"]["exposeStreamKeyword"] = True

def update_paging_operation(
self,
code_model: Dict[str, Any],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def _basic_polling_initial(self, product: Optional[Union[JSON, IO[bytes]]] = Non
)
_request.url = self._client.format_url(_request.url)

_stream = False
_stream = kwargs.pop("stream", False)
pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access
_request, stream=_stream, **kwargs
)
Expand All @@ -124,10 +124,13 @@ def _basic_polling_initial(self, product: Optional[Union[JSON, IO[bytes]]] = Non

deserialized = None
if response.status_code == 200:
if response.content:
deserialized = response.json()
if _stream:
deserialized = response.iter_bytes()
else:
deserialized = None
if response.content:
deserialized = response.json()
else:
deserialized = None

if cls:
return cls(pipeline_response, deserialized, {}) # type: ignore
Expand Down Expand Up @@ -241,6 +244,7 @@ def begin_basic_polling(
product=product,
content_type=content_type,
cls=lambda x, y, z: x,
stream=True,
headers=_headers,
params=_params,
**kwargs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ async def _basic_polling_initial(
)
_request.url = self._client.format_url(_request.url)

_stream = False
_stream = kwargs.pop("stream", False)
pipeline_response: PipelineResponse = await self._client._pipeline.run( # type: ignore # pylint: disable=protected-access
_request, stream=_stream, **kwargs
)
Expand All @@ -97,10 +97,13 @@ async def _basic_polling_initial(

deserialized = None
if response.status_code == 200:
if response.content:
deserialized = response.json()
if _stream:
deserialized = response.iter_bytes()
else:
deserialized = None
if response.content:
deserialized = response.json()
else:
deserialized = None

if cls:
return cls(pipeline_response, deserialized, {}) # type: ignore
Expand Down Expand Up @@ -214,6 +217,7 @@ async def begin_basic_polling(
product=product,
content_type=content_type,
cls=lambda x, y, z: x,
stream=True,
headers=_headers,
params=_params,
**kwargs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,21 +148,26 @@ async def _test_lro_initial(
_request = _convert_request(_request)
_request.url = self._client.format_url(_request.url)

_stream = False
_stream = kwargs.pop("stream", False)
pipeline_response: PipelineResponse = await self._client._pipeline.run( # type: ignore # pylint: disable=protected-access
_request, stream=_stream, **kwargs
)

response = pipeline_response.http_response

if response.status_code not in [200, 204]:
if _stream:
await response.load_body() # Load the body in memory and close the socket
msyyc marked this conversation as resolved.
Show resolved Hide resolved
map_error(status_code=response.status_code, response=response, error_map=error_map)
error = self._deserialize.failsafe_deserialize(_models.Error, pipeline_response)
raise HttpResponseError(response=response, model=error, error_format=ARMErrorFormat)

deserialized = None
if response.status_code == 200:
deserialized = self._deserialize("Product", pipeline_response)
if _stream:
deserialized = (await response.load_body()) or response._content # pylint: disable=protected-access
msyyc marked this conversation as resolved.
Show resolved Hide resolved
else:
deserialized = self._deserialize("Product", pipeline_response)

if cls:
return cls(pipeline_response, deserialized, {}) # type: ignore
Expand Down Expand Up @@ -230,14 +235,18 @@ async def begin_test_lro(
product=product,
content_type=content_type,
cls=lambda x, y, z: x,
stream=True,
headers=_headers,
params=_params,
**kwargs
)
kwargs.pop("error_map", None)

def get_long_running_output(pipeline_response):
deserialized = self._deserialize("Product", pipeline_response)
_response = (
pipeline_response if getattr(pipeline_response, "context", {}) else pipeline_response.http_response
)
msyyc marked this conversation as resolved.
Show resolved Hide resolved
deserialized = self._deserialize("Product", _response)
if cls:
return cls(pipeline_response, deserialized, {}) # type: ignore
return deserialized
Expand Down Expand Up @@ -294,18 +303,23 @@ async def _test_lro_and_paging_initial(
_request = _convert_request(_request)
_request.url = self._client.format_url(_request.url)

_stream = False
_stream = kwargs.pop("stream", False)
pipeline_response: PipelineResponse = await self._client._pipeline.run( # type: ignore # pylint: disable=protected-access
_request, stream=_stream, **kwargs
)

response = pipeline_response.http_response

if response.status_code not in [200]:
if _stream:
await response.load_body() # Load the body in memory and close the socket
map_error(status_code=response.status_code, response=response, error_map=error_map)
raise HttpResponseError(response=response, error_format=ARMErrorFormat)

deserialized = self._deserialize("PagingResult", pipeline_response)
if _stream:
deserialized = (await response.load_body()) or response._content # pylint: disable=protected-access
else:
deserialized = self._deserialize("PagingResult", pipeline_response)

if cls:
return cls(pipeline_response, deserialized, {}) # type: ignore
Expand Down Expand Up @@ -411,6 +425,7 @@ async def get_next(next_link=None):
client_request_id=client_request_id,
test_lro_and_paging_options=test_lro_and_paging_options,
cls=lambda x, y, z: x,
stream=True,
headers=_headers,
params=_params,
**kwargs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ def _test_lro_initial(
_request = _convert_request(_request)
_request.url = self._client.format_url(_request.url)

_stream = False
_stream = kwargs.pop("stream", False)
pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access
_request, stream=_stream, **kwargs
)
Expand All @@ -239,7 +239,10 @@ def _test_lro_initial(

deserialized = None
if response.status_code == 200:
deserialized = self._deserialize("Product", pipeline_response)
if _stream:
deserialized = response.stream_download(self._client._pipeline)
else:
deserialized = self._deserialize("Product", pipeline_response)

if cls:
return cls(pipeline_response, deserialized, {}) # type: ignore
Expand Down Expand Up @@ -304,14 +307,18 @@ def begin_test_lro(
product=product,
content_type=content_type,
cls=lambda x, y, z: x,
stream=True,
headers=_headers,
params=_params,
**kwargs
)
kwargs.pop("error_map", None)

def get_long_running_output(pipeline_response):
deserialized = self._deserialize("Product", pipeline_response)
_response = (
pipeline_response if getattr(pipeline_response, "context", {}) else pipeline_response.http_response
)
deserialized = self._deserialize("Product", _response)
if cls:
return cls(pipeline_response, deserialized, {}) # type: ignore
return deserialized
Expand Down Expand Up @@ -368,7 +375,7 @@ def _test_lro_and_paging_initial(
_request = _convert_request(_request)
_request.url = self._client.format_url(_request.url)

_stream = False
_stream = kwargs.pop("stream", False)
pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access
_request, stream=_stream, **kwargs
)
Expand All @@ -379,7 +386,10 @@ def _test_lro_and_paging_initial(
map_error(status_code=response.status_code, response=response, error_map=error_map)
raise HttpResponseError(response=response, error_format=ARMErrorFormat)

deserialized = self._deserialize("PagingResult", pipeline_response)
if _stream:
deserialized = response.stream_download(self._client._pipeline)
else:
deserialized = self._deserialize("PagingResult", pipeline_response)

if cls:
return cls(pipeline_response, deserialized, {}) # type: ignore
Expand Down Expand Up @@ -485,6 +495,7 @@ def get_next(next_link=None):
client_request_id=client_request_id,
test_lro_and_paging_options=test_lro_and_paging_options,
cls=lambda x, y, z: x,
stream=True,
headers=_headers,
params=_params,
**kwargs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1608,18 +1608,23 @@ async def _get_multiple_pages_lro_initial(
_request = _convert_request(_request)
_request.url = self._client.format_url(_request.url)

_stream = False
_stream = kwargs.pop("stream", False)
pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access
_request, stream=_stream, **kwargs
)

response = pipeline_response.http_response

if response.status_code not in [202]:
if _stream:
await response.load_body() # Load the body in memory and close the socket
map_error(status_code=response.status_code, response=response, error_map=error_map)
raise HttpResponseError(response=response, error_format=ARMErrorFormat)

deserialized = self._deserialize("ProductResult", pipeline_response)
if _stream:
deserialized = (await response.load_body()) or response._content # pylint: disable=protected-access
else:
deserialized = self._deserialize("ProductResult", pipeline_response)

if cls:
return cls(pipeline_response, deserialized, {}) # type: ignore
Expand Down Expand Up @@ -1726,6 +1731,7 @@ async def get_next(next_link=None):
client_request_id=client_request_id,
paging_get_multiple_pages_lro_options=paging_get_multiple_pages_lro_options,
cls=lambda x, y, z: x,
stream=True,
headers=_headers,
params=_params,
**kwargs
Expand Down
Loading
Loading