diff --git a/libcst/_nodes/tests/test_namedexpr.py b/libcst/_nodes/tests/test_namedexpr.py index 4ba1485b5..9e26bfda2 100644 --- a/libcst/_nodes/tests/test_namedexpr.py +++ b/libcst/_nodes/tests/test_namedexpr.py @@ -101,6 +101,69 @@ class NamedExprTest(CSTNodeTest): "parser": _parse_statement_force_38, "expected_position": None, }, + # Function args + { + "node": cst.Call( + func=cst.Name(value="f"), + args=[ + cst.Arg( + value=cst.NamedExpr( + target=cst.Name(value="y"), + value=cst.Integer(value="1"), + whitespace_before_walrus=cst.SimpleWhitespace(""), + whitespace_after_walrus=cst.SimpleWhitespace(""), + ) + ), + ], + ), + "code": "f(y:=1)", + "parser": _parse_expression_force_38, + "expected_position": None, + }, + # Whitespace handling on args is fragile + { + "node": cst.Call( + func=cst.Name(value="f"), + args=[ + cst.Arg( + value=cst.Name(value="x"), + comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")), + ), + cst.Arg( + value=cst.NamedExpr( + target=cst.Name(value="y"), + value=cst.Integer(value="1"), + whitespace_before_walrus=cst.SimpleWhitespace(" "), + whitespace_after_walrus=cst.SimpleWhitespace(" "), + ), + whitespace_after_arg=cst.SimpleWhitespace(" "), + ), + ], + ), + "code": "f(x, y := 1 )", + "parser": _parse_expression_force_38, + "expected_position": None, + }, + { + "node": cst.Call( + func=cst.Name(value="f"), + args=[ + cst.Arg( + value=cst.NamedExpr( + target=cst.Name(value="y"), + value=cst.Integer(value="1"), + whitespace_before_walrus=cst.SimpleWhitespace(" "), + whitespace_after_walrus=cst.SimpleWhitespace(" "), + ), + whitespace_after_arg=cst.SimpleWhitespace(" "), + ), + ], + whitespace_before_args=cst.SimpleWhitespace(" "), + ), + "code": "f( y := 1 )", + "parser": _parse_expression_force_38, + "expected_position": None, + }, ) ) def test_valid(self, **kwargs: Any) -> None: diff --git a/libcst/_parser/conversions/expression.py b/libcst/_parser/conversions/expression.py index 8edbf262f..a88eb3104 100644 --- a/libcst/_parser/conversions/expression.py +++ b/libcst/_parser/conversions/expression.py @@ -1441,8 +1441,13 @@ def convert_arg_assign_comp_for( elt, for_in = children return Arg(value=GeneratorExp(elt.value, for_in, lpar=(), rpar=())) else: - # "key = value" assignment argument lhs, equal, rhs = children + # "key := value" assignment; positional + if equal.string == ":=": + val = convert_namedexpr_test(config, children) + assert isinstance(val, WithLeadingWhitespace) + return Arg(value=val.value) + # "key = value" assignment; keyword argument return Arg( keyword=lhs.value, equal=AssignEqual(