Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support removal of multiply operator #182

Merged
merged 1 commit into from
Oct 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/integration_tests/algorithmic_style_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def collatz(n):
\If{$n \mathbin{\%} 2 = 0$}
\State $n \gets \left\lfloor\frac{n}{2}\right\rfloor$
\Else
\State $n \gets 3 \cdot n + 1$
\State $n \gets 3 n + 1$
\EndIf
\State $\mathrm{iterations} \gets \mathrm{iterations} + 1$
\EndWhile
Expand All @@ -80,7 +80,7 @@ def collatz(n):
r" \hspace{2em} \mathbf{if} \ n \mathbin{\%} 2 = 0 \\"
r" \hspace{3em} n \gets \left\lfloor\frac{n}{2}\right\rfloor \\"
r" \hspace{2em} \mathbf{else} \\"
r" \hspace{3em} n \gets 3 \cdot n + 1 \\"
r" \hspace{3em} n \gets 3 n + 1 \\"
r" \hspace{2em} \mathbf{end \ if} \\"
r" \hspace{2em}"
r" \mathrm{iterations} \gets \mathrm{iterations} + 1 \\"
Expand Down
21 changes: 9 additions & 12 deletions src/integration_tests/regression_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,7 @@ def test_quadratic_solution() -> None:
def solve(a, b, c):
return (-b + math.sqrt(b**2 - 4 * a * c)) / (2 * a)

latex = (
r"\mathrm{solve}(a, b, c) ="
r" \frac{-b + \sqrt{ b^{2} - 4 \cdot a \cdot c }}{2 \cdot a}"
)
latex = r"\mathrm{solve}(a, b, c) =" r" \frac{-b + \sqrt{ b^{2} - 4 a c }}{2 a}"
integration_utils.check_function(solve, latex)


Expand Down Expand Up @@ -47,7 +44,7 @@ def xtimesbeta(x, beta):
xtimesbeta, latex_without_symbols, use_math_symbols=False
)

latex_with_symbols = r"\mathrm{xtimesbeta}(x, \beta) = x \cdot \beta"
latex_with_symbols = r"\mathrm{xtimesbeta}(x, \beta) = x \beta"
integration_utils.check_function(
xtimesbeta, latex_with_symbols, use_math_symbols=True
)
Expand Down Expand Up @@ -145,7 +142,7 @@ def test_nested_function() -> None:
def nested(x):
return 3 * x

integration_utils.check_function(nested, r"\mathrm{nested}(x) = 3 \cdot x")
integration_utils.check_function(nested, r"\mathrm{nested}(x) = 3 x")


def test_double_nested_function() -> None:
Expand All @@ -155,7 +152,7 @@ def inner(y):

return inner

integration_utils.check_function(nested(3), r"\mathrm{inner}(y) = x \cdot y")
integration_utils.check_function(nested(3), r"\mathrm{inner}(y) = x y")


def test_reduce_assignments() -> None:
Expand All @@ -165,11 +162,11 @@ def f(x):

integration_utils.check_function(
f,
r"\begin{array}{l} a = x + x \\ f(x) = 3 \cdot a \end{array}",
r"\begin{array}{l} a = x + x \\ f(x) = 3 a \end{array}",
)
integration_utils.check_function(
f,
r"f(x) = 3 \cdot \mathopen{}\left( x + x \mathclose{}\right)",
r"f(x) = 3 \mathopen{}\left( x + x \mathclose{}\right)",
reduce_assignments=True,
)

Expand All @@ -184,15 +181,15 @@ def f(x):
r"\begin{array}{l}"
r" a = x^{2} \\"
r" b = a + a \\"
r" f(x) = 3 \cdot b"
r" f(x) = 3 b"
r" \end{array}"
)

integration_utils.check_function(f, latex_without_option)
integration_utils.check_function(f, latex_without_option, reduce_assignments=False)
integration_utils.check_function(
f,
r"f(x) = 3 \cdot \mathopen{}\left( x^{2} + x^{2} \mathclose{}\right)",
r"f(x) = 3 \mathopen{}\left( x^{2} + x^{2} \mathclose{}\right)",
reduce_assignments=True,
)

Expand Down Expand Up @@ -228,7 +225,7 @@ def solve(a, b):
r"\mathrm{solve}(a, b) ="
r" \frac{a + b - b}{a - b} - \mathopen{}\left("
r" a + b \mathclose{}\right) - \mathopen{}\left("
r" a - b \mathclose{}\right) - a \cdot b"
r" a - b \mathclose{}\right) - a b"
)
integration_utils.check_function(solve, latex)

Expand Down
83 changes: 83 additions & 0 deletions src/latexify/codegen/expression_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import ast
import re

from latexify import analyzers, ast_utils, exceptions
from latexify.codegen import codegen_utils, expression_rules, identifier_converter
Expand Down Expand Up @@ -406,12 +407,94 @@ def _wrap_binop_operand(

return rf"\mathopen{{}}\left( {latex} \mathclose{{}}\right)"

_l_bracket_pattern = re.compile(r"^\\mathopen.*")
_r_bracket_pattern = re.compile(r".*\\mathclose[^ ]+$")
_r_word_pattern = re.compile(r"\\mathrm\{[^ ]+\}$")

def _should_remove_multiply_op(
self, l_latex: str, r_latex: str, l_expr: ast.expr, r_expr: ast.expr
):
"""Determine whether the multiply operator should be removed or not.

See also:
https://github.com/google/latexify_py/issues/89#issuecomment-1344967636

This is an ad-hoc implementation.
This function doesn't fully implements the above requirements, but only
essential ones necessary to release v0.3.
"""

# NOTE(odashi): For compatibility with Python 3.7, we compare the generated
# caracter type directly to determine the "numeric" type.

if isinstance(l_expr, ast.Call):
l_type = "f"
elif self._r_bracket_pattern.match(l_latex):
l_type = "b"
elif self._r_word_pattern.match(l_latex):
l_type = "w"
elif l_latex[-1].isnumeric():
l_type = "n"
else:
le = l_expr
while True:
if isinstance(le, ast.UnaryOp):
le = le.operand
elif isinstance(le, ast.BinOp):
le = le.right
elif isinstance(le, ast.Compare):
le = le.comparators[-1]
elif isinstance(le, ast.BoolOp):
le = le.values[-1]
else:
break
l_type = "a" if isinstance(le, ast.Name) and len(le.id) == 1 else "m"

if isinstance(r_expr, ast.Call):
r_type = "f"
elif self._l_bracket_pattern.match(r_latex):
r_type = "b"
elif r_latex.startswith("\\mathrm"):
r_type = "w"
elif r_latex[0].isnumeric():
r_type = "n"
else:
re = r_expr
while True:
if isinstance(re, ast.UnaryOp):
if isinstance(re.op, ast.USub):
# NOTE(odashi): Unary "-" always require \cdot.
return False
re = re.operand
elif isinstance(re, ast.BinOp):
re = re.left
elif isinstance(re, ast.Compare):
re = re.left
elif isinstance(re, ast.BoolOp):
re = re.values[0]
else:
break
r_type = "a" if isinstance(re, ast.Name) and len(re.id) == 1 else "m"

if r_type == "n":
return False
if l_type in "bn":
return True
if l_type in "am" and r_type in "am":
return True
return False

def visit_BinOp(self, node: ast.BinOp) -> str:
"""Visit a BinOp node."""
prec = expression_rules.get_precedence(node)
rule = self._bin_op_rules[type(node.op)]
lhs = self._wrap_binop_operand(node.left, prec, rule.operand_left)
rhs = self._wrap_binop_operand(node.right, prec, rule.operand_right)

if type(node.op) in [ast.Mult, ast.MatMult]:
if self._should_remove_multiply_op(lhs, rhs, node.left, node.right):
return f"{rule.latex_left}{lhs} {rhs}{rule.latex_right}"

return f"{rule.latex_left}{lhs}{rule.latex_middle}{rhs}{rule.latex_right}"

def visit_UnaryOp(self, node: ast.UnaryOp) -> str:
Expand Down
Loading