From 6f01c6ca21fdd491b6218c5b2f665fb50f3e9ac3 Mon Sep 17 00:00:00 2001 From: Remco de Boer Date: Tue, 3 Aug 2021 15:28:04 +0200 Subject: [PATCH] fix: allow creating UnevaluatedExpr with doit arg Previously, there something like `BlattWeisskopfSquared(...).doit(deep=False)` would crash --- src/ampform/dynamics/__init__.py | 11 ++++++----- src/ampform/dynamics/decorator.py | 19 ++++++++++++++++++- 2 files changed, 24 insertions(+), 6 deletions(-) diff --git a/src/ampform/dynamics/__init__.py b/src/ampform/dynamics/__init__.py index d4e6fd968..ff73f0cbf 100644 --- a/src/ampform/dynamics/__init__.py +++ b/src/ampform/dynamics/__init__.py @@ -11,7 +11,11 @@ import sympy as sp from sympy.printing.latex import LatexPrinter -from .decorator import UnevaluatedExpression, implement_doit_method +from .decorator import ( + UnevaluatedExpression, + create_expression, + implement_doit_method, +) from .math import ComplexSqrt try: @@ -56,10 +60,7 @@ def __new__( # pylint: disable=arguments-differ **hints: Any, ) -> "BlattWeisskopfSquared": args = sp.sympify((angular_momentum, z)) - if evaluate: - # pylint: disable=no-member - return sp.Expr.__new__(cls, *args, **hints).evaluate() - return sp.Expr.__new__(cls, *args, **hints) + return create_expression(cls, evaluate, *args, **hints) def evaluate(self) -> sp.Expr: angular_momentum, z = self.args diff --git a/src/ampform/dynamics/decorator.py b/src/ampform/dynamics/decorator.py index f24b62df7..47933b045 100644 --- a/src/ampform/dynamics/decorator.py +++ b/src/ampform/dynamics/decorator.py @@ -1,7 +1,7 @@ """Tools for defining lineshapes with `sympy`.""" from abc import abstractmethod -from typing import Any, Callable, Type +from typing import Any, Callable, Optional, Type import sympy as sp from sympy.printing.latex import LatexPrinter @@ -99,3 +99,20 @@ def doit_method(self: Any, **hints: Any) -> sp.Expr: return decorated_class return decorator + + +def create_expression( + cls: Type[UnevaluatedExpression], evaluate: bool, *args: Any, **kwargs: Any +) -> sp.Expr: + """Helper function for implementing :code:`Expr.__new__`. + + See e.g. source code of `.BlattWeisskopfSquared`. + """ + # pylint: disable=no-member + deep: Optional[bool] = kwargs.pop("deep", None) + expr = sp.Expr.__new__(cls, *args, **kwargs) + if evaluate: + expr = expr.evaluate() + if deep: + expr = expr.doit(deep=deep) + return expr