diff --git a/docs/versionhistory.rst b/docs/versionhistory.rst index 3ab59c5..38e07da 100644 --- a/docs/versionhistory.rst +++ b/docs/versionhistory.rst @@ -10,6 +10,8 @@ This library adheres to `Semantic Versioning 2.0 `_; PR by Alex Waygood) - Fixed ``NameError`` when generated type checking code references an imported name from a method (`#362 `_) +- Fixed docstrings disappearing from instrumented functions + (`#359 `_) **4.0.0** (2023-05-12) diff --git a/src/typeguard/_transformer.py b/src/typeguard/_transformer.py index 8ade7a8..4963a36 100644 --- a/src/typeguard/_transformer.py +++ b/src/typeguard/_transformer.py @@ -137,6 +137,7 @@ class TransformMemo: should_instrument: bool = field(init=False, default=True) variable_annotations: dict[str, expr] = field(init=False, default_factory=dict) configuration_overrides: dict[str, Any] = field(init=False, default_factory=dict) + code_inject_index: int = field(init=False, default=0) def __post_init__(self) -> None: elements: list[str] = [] @@ -152,6 +153,18 @@ def __post_init__(self) -> None: self.joined_path = Constant(".".join(elements)) + # Figure out where to insert instrumentation code + if self.node: + for index, child in enumerate(self.node.body): + if isinstance(child, ImportFrom) and child.module == "__future__": + # (module only) __future__ imports must come first + continue + elif isinstance(child, Expr) and isinstance(child.value, Str): + continue # docstring + + self.code_inject_index = index + break + def get_unused_name(self, name: str) -> str: memo: TransformMemo | None = self while memo is not None: @@ -212,20 +225,12 @@ def insert_imports(self, node: Module | FunctionDef | AsyncFunctionDef) -> None: return # Insert imports after any "from __future__ ..." imports and any docstring - for i, child in enumerate(node.body): - if isinstance(child, ImportFrom) and child.module == "__future__": - continue - elif isinstance(child, Expr) and isinstance(child.value, Str): - continue # module docstring - - for modulename, names in self.load_names.items(): - aliases = [ - alias(orig_name, new_name.id if orig_name != new_name.id else None) - for orig_name, new_name in sorted(names.items()) - ] - node.body.insert(i, ImportFrom(modulename, aliases, 0)) - - break + for modulename, names in self.load_names.items(): + aliases = [ + alias(orig_name, new_name.id if orig_name != new_name.id else None) + for orig_name, new_name in sorted(names.items()) + ] + node.body.insert(self.code_inject_index, ImportFrom(modulename, aliases, 0)) def name_matches(self, expression: expr | Expr | None, *names: str) -> bool: if expression is None: @@ -757,7 +762,9 @@ def visit_FunctionDef( annotations_dict, self._memo.get_memo_name(), ] - node.body.insert(0, Expr(Call(func_name, args, []))) + node.body.insert( + self._memo.code_inject_index, Expr(Call(func_name, args, [])) + ) # Add a checked "return None" to the end if there's no explicit return # Skip if the return annotation is None or Any @@ -859,7 +866,7 @@ def visit_FunctionDef( [keyword(key, value) for key, value in memo_kwargs.items()], ) node.body.insert( - 0, + self._memo.code_inject_index, Assign([memo_store_name], memo_expr), ) diff --git a/tests/test_transformer.py b/tests/test_transformer.py index c5615d6..5a5208b 100644 --- a/tests/test_transformer.py +++ b/tests/test_transformer.py @@ -1521,3 +1521,30 @@ def foo(x: Annotated[str, 'foo bar']) -> None: """ ).strip() ) + + +def test_respect_docstring() -> None: + # Regression test for #359 + node = parse( + dedent( + ''' + def foo() -> int: + """This is a docstring.""" + return 1 + ''' + ) + ) + TypeguardTransformer(["foo"]).visit(node) + assert ( + unparse(node) + == dedent( + ''' + def foo() -> int: + """This is a docstring.""" + from typeguard import TypeCheckMemo + from typeguard._functions import check_return_type + memo = TypeCheckMemo(globals(), locals()) + return check_return_type('foo', 1, int, memo) + ''' + ).strip() + )