diff --git a/libcst/_nodes/statement.py b/libcst/_nodes/statement.py index 135194989..9493f57c2 100644 --- a/libcst/_nodes/statement.py +++ b/libcst/_nodes/statement.py @@ -57,6 +57,7 @@ from libcst._nodes.whitespace import ( BaseParenthesizableWhitespace, EmptyLine, + ParenthesizedWhitespace, SimpleWhitespace, TrailingWhitespace, ) @@ -2017,24 +2018,47 @@ class With(BaseCompoundStatement): #: Sequence of empty lines appearing before this with statement. leading_lines: Sequence[EmptyLine] = () + #: Optional open parenthesis for multi-line with bindings + lpar: Union[LeftParen, MaybeSentinel] = MaybeSentinel.DEFAULT + + #: Optional close parenthesis for multi-line with bindings + rpar: Union[RightParen, MaybeSentinel] = MaybeSentinel.DEFAULT + #: Whitespace after the ``with`` keyword and before the first item. whitespace_after_with: SimpleWhitespace = SimpleWhitespace.field(" ") #: Whitespace after the last item and before the colon. whitespace_before_colon: SimpleWhitespace = SimpleWhitespace.field("") + def _validate_parens(self) -> None: + if isinstance(self.lpar, MaybeSentinel) and isinstance(self.rpar, RightParen): + raise CSTValidationError( + "Do not mix concrete LeftParen/RightParen with MaybeSentinel." + ) + if isinstance(self.lpar, LeftParen) and isinstance(self.rpar, MaybeSentinel): + raise CSTValidationError( + "Do not mix concrete LeftParen/RightParen with MaybeSentinel." + ) + def _validate(self) -> None: + self._validate_parens() if len(self.items) == 0: raise CSTValidationError( "A With statement must have at least one WithItem." ) - if self.items[-1].comma != MaybeSentinel.DEFAULT: + if ( + isinstance(self.rpar, MaybeSentinel) + and self.items[-1].comma != MaybeSentinel.DEFAULT + ): raise CSTValidationError( - "The last WithItem in a With cannot have a trailing comma." + "The last WithItem in an unparenthesized With cannot have a trailing comma." ) - if self.whitespace_after_with.empty and not self.items[ - 0 - ].item._safe_to_use_with_word_operator(ExpressionPosition.RIGHT): + if self.whitespace_after_with.empty and not ( + isinstance(self.lpar, LeftParen) + or self.items[0].item._safe_to_use_with_word_operator( + ExpressionPosition.RIGHT + ) + ): raise CSTValidationError("Must have at least one space after with keyword.") def _visit_and_replace_children(self, visitor: CSTVisitorT) -> "With": @@ -2048,7 +2072,9 @@ def _visit_and_replace_children(self, visitor: CSTVisitorT) -> "With": whitespace_after_with=visit_required( self, "whitespace_after_with", self.whitespace_after_with, visitor ), + lpar=visit_sentinel(self, "lpar", self.lpar, visitor), items=visit_sequence(self, "items", self.items, visitor), + rpar=visit_sentinel(self, "rpar", self.rpar, visitor), whitespace_before_colon=visit_required( self, "whitespace_before_colon", self.whitespace_before_colon, visitor ), @@ -2060,15 +2086,36 @@ def _codegen_impl(self, state: CodegenState) -> None: ll._codegen(state) state.add_indent_tokens() + needs_paren = False + for item in self.items: + comma = item.comma + if isinstance(comma, Comma): + if isinstance( + comma.whitespace_after, + (EmptyLine, TrailingWhitespace, ParenthesizedWhitespace), + ): + needs_paren = True + break + with state.record_syntactic_position(self, end_node=self.body): asynchronous = self.asynchronous if asynchronous is not None: asynchronous._codegen(state) state.add_token("with") self.whitespace_after_with._codegen(state) + lpar = self.lpar + if isinstance(lpar, LeftParen): + lpar._codegen(state) + elif needs_paren: + state.add_token("(") last_item = len(self.items) - 1 for i, item in enumerate(self.items): item._codegen(state, default_comma=(i != last_item)) + rpar = self.rpar + if isinstance(rpar, RightParen): + rpar._codegen(state) + elif needs_paren: + state.add_token(")") self.whitespace_before_colon._codegen(state) state.add_token(":") self.body._codegen(state) diff --git a/libcst/_nodes/tests/test_with.py b/libcst/_nodes/tests/test_with.py index 2246bc2db..1310b3f88 100644 --- a/libcst/_nodes/tests/test_with.py +++ b/libcst/_nodes/tests/test_with.py @@ -7,6 +7,7 @@ import libcst as cst from libcst import parse_statement, PartialParserConfig +from libcst._maybe_sentinel import MaybeSentinel from libcst._nodes.tests.base import CSTNodeTest, DummyIndentedBlock, parse_statement_as from libcst._parser.entrypoints import is_native from libcst.metadata import CodeRange @@ -14,6 +15,8 @@ class WithTest(CSTNodeTest): + maxDiff: int = 2000 + @data_provider( ( # Simple with block @@ -138,45 +141,83 @@ class WithTest(CSTNodeTest): "parser": parse_statement, "expected_position": CodeRange((2, 0), (2, 24)), }, - # Weird spacing rules + # Whitespace + { + "node": cst.With( + ( + cst.WithItem( + cst.Call(cst.Name("context_mgr")), + cst.AsName( + cst.Name("ctx"), + whitespace_before_as=cst.SimpleWhitespace(" "), + whitespace_after_as=cst.SimpleWhitespace(" "), + ), + ), + ), + cst.SimpleStatementSuite((cst.Pass(),)), + whitespace_after_with=cst.SimpleWhitespace(" "), + whitespace_before_colon=cst.SimpleWhitespace(" "), + ), + "code": "with context_mgr() as ctx : pass\n", + "parser": parse_statement, + "expected_position": CodeRange((1, 0), (1, 36)), + }, + # Weird spacing rules, that parse differently depending on whether + # we are using a grammar that included parenthesized with statements. { "node": cst.With( ( cst.WithItem( cst.Call( cst.Name("context_mgr"), - lpar=(cst.LeftParen(),), - rpar=(cst.RightParen(),), + lpar=() if is_native() else (cst.LeftParen(),), + rpar=() if is_native() else (cst.RightParen(),), ) ), ), cst.SimpleStatementSuite((cst.Pass(),)), + lpar=(cst.LeftParen() if is_native() else MaybeSentinel.DEFAULT), + rpar=(cst.RightParen() if is_native() else MaybeSentinel.DEFAULT), whitespace_after_with=cst.SimpleWhitespace(""), ), "code": "with(context_mgr()): pass\n", "parser": parse_statement, "expected_position": CodeRange((1, 0), (1, 25)), }, - # Whitespace + # Multi-line parenthesized with. { "node": cst.With( ( cst.WithItem( - cst.Call(cst.Name("context_mgr")), - cst.AsName( - cst.Name("ctx"), - whitespace_before_as=cst.SimpleWhitespace(" "), - whitespace_after_as=cst.SimpleWhitespace(" "), + cst.Call(cst.Name("foo")), + comma=cst.Comma( + whitespace_after=cst.ParenthesizedWhitespace( + first_line=cst.TrailingWhitespace( + whitespace=cst.SimpleWhitespace( + value="", + ), + comment=None, + newline=cst.Newline( + value=None, + ), + ), + empty_lines=[], + indent=True, + last_line=cst.SimpleWhitespace( + value=" ", + ), + ) ), ), + cst.WithItem(cst.Call(cst.Name("bar")), comma=cst.Comma()), ), cst.SimpleStatementSuite((cst.Pass(),)), - whitespace_after_with=cst.SimpleWhitespace(" "), - whitespace_before_colon=cst.SimpleWhitespace(" "), + lpar=cst.LeftParen(whitespace_after=cst.SimpleWhitespace(" ")), + rpar=cst.RightParen(whitespace_before=cst.SimpleWhitespace(" ")), ), - "code": "with context_mgr() as ctx : pass\n", - "parser": parse_statement, - "expected_position": CodeRange((1, 0), (1, 36)), + "code": ("with ( foo(),\n" " bar(), ): pass\n"), # noqa + "parser": parse_statement if is_native() else None, + "expected_position": CodeRange((1, 0), (2, 21)), }, ) ) @@ -201,7 +242,8 @@ def test_valid(self, **kwargs: Any) -> None: ), cst.IndentedBlock((cst.SimpleStatementLine((cst.Pass(),)),)), ), - "expected_re": "The last WithItem in a With cannot have a trailing comma", + "expected_re": "The last WithItem in an unparenthesized With cannot " + + "have a trailing comma.", }, { "get_node": lambda: cst.With( @@ -211,6 +253,26 @@ def test_valid(self, **kwargs: Any) -> None: ), "expected_re": "Must have at least one space after with keyword", }, + { + "get_node": lambda: cst.With( + (cst.WithItem(cst.Call(cst.Name("context_mgr"))),), + cst.SimpleStatementSuite((cst.Pass(),)), + whitespace_after_with=cst.SimpleWhitespace(""), + lpar=cst.LeftParen(), + ), + "expected_re": "Do not mix concrete LeftParen/RightParen with " + + "MaybeSentinel", + }, + { + "get_node": lambda: cst.With( + (cst.WithItem(cst.Call(cst.Name("context_mgr"))),), + cst.SimpleStatementSuite((cst.Pass(),)), + whitespace_after_with=cst.SimpleWhitespace(""), + rpar=cst.RightParen(), + ), + "expected_re": "Do not mix concrete LeftParen/RightParen with " + + "MaybeSentinel", + }, ) ) def test_invalid(self, **kwargs: Any) -> None: @@ -234,3 +296,23 @@ def test_versions(self, **kwargs: Any) -> None: if is_native() and not kwargs.get("expect_success", True): self.skipTest("parse errors are disabled for native parser") self.assert_parses(**kwargs) + + def test_adding_parens(self) -> None: + node = cst.With( + ( + cst.WithItem( + cst.Call(cst.Name("foo")), + comma=cst.Comma( + whitespace_after=cst.ParenthesizedWhitespace(), + ), + ), + cst.WithItem(cst.Call(cst.Name("bar")), comma=cst.Comma()), + ), + cst.SimpleStatementSuite((cst.Pass(),)), + lpar=cst.LeftParen(whitespace_after=cst.SimpleWhitespace(" ")), + rpar=cst.RightParen(whitespace_before=cst.SimpleWhitespace(" ")), + ) + module = cst.Module([]) + self.assertEqual( + module.code_for_node(node), ("with ( foo(),\n" "bar(), ): pass\n") # noqa + ) diff --git a/libcst/_typed_visitor.py b/libcst/_typed_visitor.py index cea085d0d..a880bee48 100644 --- a/libcst/_typed_visitor.py +++ b/libcst/_typed_visitor.py @@ -5379,6 +5379,22 @@ def visit_With_leading_lines(self, node: "With") -> None: def leave_With_leading_lines(self, node: "With") -> None: pass + @mark_no_op + def visit_With_lpar(self, node: "With") -> None: + pass + + @mark_no_op + def leave_With_lpar(self, node: "With") -> None: + pass + + @mark_no_op + def visit_With_rpar(self, node: "With") -> None: + pass + + @mark_no_op + def leave_With_rpar(self, node: "With") -> None: + pass + @mark_no_op def visit_With_whitespace_after_with(self, node: "With") -> None: pass diff --git a/libcst/matchers/__init__.py b/libcst/matchers/__init__.py index 655bc9478..9602de41e 100644 --- a/libcst/matchers/__init__.py +++ b/libcst/matchers/__init__.py @@ -15411,6 +15411,18 @@ class With(BaseCompoundStatement, BaseStatement, BaseMatcherNode): ] ], ] = DoNotCare() + lpar: Union[ + LeftParenMatchType, + DoNotCareSentinel, + OneOf[LeftParenMatchType], + AllOf[LeftParenMatchType], + ] = DoNotCare() + rpar: Union[ + RightParenMatchType, + DoNotCareSentinel, + OneOf[RightParenMatchType], + AllOf[RightParenMatchType], + ] = DoNotCare() whitespace_after_with: Union[ SimpleWhitespaceMatchType, DoNotCareSentinel, diff --git a/native/libcst/src/nodes/statement.rs b/native/libcst/src/nodes/statement.rs index 319e6f130..217253431 100644 --- a/native/libcst/src/nodes/statement.rs +++ b/native/libcst/src/nodes/statement.rs @@ -1927,6 +1927,19 @@ pub struct WithItem<'a> { pub comma: Option>, } +impl<'a> WithItem<'a> { + fn inflate_withitem(mut self, config: &Config<'a>, is_last: bool) -> Result { + self.item = self.item.inflate(config)?; + self.asname = self.asname.inflate(config)?; + self.comma = if is_last { + self.comma.map(|c| c.inflate_before(config)).transpose()? + } else { + self.comma.map(|c| c.inflate(config)).transpose()? + }; + Ok(self) + } +} + impl<'a> Codegen<'a> for WithItem<'a> { fn codegen(&self, state: &mut CodegenState<'a>) { self.item.codegen(state); @@ -1948,21 +1961,14 @@ impl<'a> WithComma<'a> for WithItem<'a> { } } -impl<'a> Inflate<'a> for WithItem<'a> { - fn inflate(mut self, config: &Config<'a>) -> Result { - self.item = self.item.inflate(config)?; - self.asname = self.asname.inflate(config)?; - self.comma = self.comma.inflate(config)?; - Ok(self) - } -} - #[derive(Debug, PartialEq, Eq, Clone, IntoPy)] pub struct With<'a> { pub items: Vec>, pub body: Suite<'a>, pub asynchronous: Option>, pub leading_lines: Vec>, + pub lpar: Option>, + pub rpar: Option>, pub whitespace_after_with: SimpleWhitespace<'a>, pub whitespace_before_colon: SimpleWhitespace<'a>, @@ -1983,6 +1989,18 @@ impl<'a> Codegen<'a> for With<'a> { } state.add_token("with"); self.whitespace_after_with.codegen(state); + + // TODO: Force parens whenever there are newlines in + // the commas of self.items. + // + // For now, only the python API does this. + let need_parens = false; + if let Some(lpar) = &self.lpar { + lpar.codegen(state); + } else if need_parens { + state.add_token("("); + } + let len = self.items.len(); for (i, item) in self.items.iter().enumerate() { item.codegen(state); @@ -1990,6 +2008,13 @@ impl<'a> Codegen<'a> for With<'a> { state.add_token(", "); } } + + if let Some(rpar) = &self.rpar { + rpar.codegen(state); + } else if need_parens { + state.add_token(")"); + } + self.whitespace_before_colon.codegen(state); state.add_token(":"); self.body.codegen(state); @@ -2027,7 +2052,18 @@ impl<'a> Inflate<'a> for With<'a> { self.whitespace_after_with = parse_simple_whitespace(config, &mut (*self.with_tok).whitespace_after.borrow_mut())?; - self.items = self.items.inflate(config)?; + self.lpar = self.lpar.map(|lpar| lpar.inflate(config)).transpose()?; + let len = self.items.len(); + self.items = self + .items + .into_iter() + .enumerate() + .map(|(idx, el)| el.inflate_withitem(config, idx + 1 == len)) + .collect::>>()?; + if !self.items.is_empty() { + // rpar only has whitespace if items is non empty + self.rpar = self.rpar.map(|rpar| rpar.inflate(config)).transpose()?; + } self.whitespace_before_colon = parse_simple_whitespace( config, &mut (*self.colon_tok).whitespace_before.borrow_mut(), diff --git a/native/libcst/src/parser/grammar.rs b/native/libcst/src/parser/grammar.rs index c881be572..70d2f968e 100644 --- a/native/libcst/src/parser/grammar.rs +++ b/native/libcst/src/parser/grammar.rs @@ -473,13 +473,21 @@ parser! { // With statement rule with_stmt() -> With<'a> - = kw:lit("with") items:separated(, ) + = kw:lit("with") l:lpar() items:separated_trailer(, ) r:rpar() col:lit(":") b:block() { - make_with(None, kw, comma_separate(items.0, items.1, None), col, b) + make_with(None, kw, Some(l), comma_separate(items.0, items.1, items.2), Some(r), col, b) + } + / kw:lit("with") items:separated(, ) + col:lit(":") b:block() { + make_with(None, kw, None, comma_separate(items.0, items.1, None), None, col, b) + } + / asy:tok(Async, "ASYNC") kw:lit("with") l:lpar() items:separated_trailer(, ) r:rpar() + col:lit(":") b:block() { + make_with(Some(asy), kw, Some(l), comma_separate(items.0, items.1, items.2), Some(r), col, b) } / asy:tok(Async, "ASYNC") kw:lit("with") items:separated(, ) col:lit(":") b:block() { - make_with(Some(asy), kw, comma_separate(items.0, items.1, None), col, b) + make_with(Some(asy), kw, None, comma_separate(items.0, items.1, None), None, col, b) } rule with_item() -> WithItem<'a> @@ -3218,7 +3226,9 @@ fn make_with_item<'a>( fn make_with<'a>( async_tok: Option>, with_tok: TokenRef<'a>, + lpar: Option>, items: Vec>, + rpar: Option>, colon_tok: TokenRef<'a>, body: Suite<'a>, ) -> With<'a> { @@ -3230,6 +3240,8 @@ fn make_with<'a>( body, asynchronous, leading_lines: Default::default(), + lpar, + rpar, whitespace_after_with: Default::default(), whitespace_before_colon: Default::default(), async_tok,