diff --git a/src/Workspaces/CSharp/Portable/CodeGeneration/CSharpSyntaxGenerator.cs b/src/Workspaces/CSharp/Portable/CodeGeneration/CSharpSyntaxGenerator.cs index 662815b7b6333..66876defa2e0d 100644 --- a/src/Workspaces/CSharp/Portable/CodeGeneration/CSharpSyntaxGenerator.cs +++ b/src/Workspaces/CSharp/Portable/CodeGeneration/CSharpSyntaxGenerator.cs @@ -998,11 +998,11 @@ private SyntaxNode InsertAttributesInternal(SyntaxNode declaration, int index, I var existingAttributes = this.GetAttributes(declaration); if (index >= 0 && index < existingAttributes.Count) { - return this.InsertNodesBefore(declaration, existingAttributes[index], newAttributes); + return this.InsertNodesBefore(declaration, existingAttributes[index], WithRequiredTargetSpecifier(newAttributes, declaration)); } else if (existingAttributes.Count > 0) { - return this.InsertNodesAfter(declaration, existingAttributes[existingAttributes.Count - 1], newAttributes); + return this.InsertNodesAfter(declaration, existingAttributes[existingAttributes.Count - 1], WithRequiredTargetSpecifier(newAttributes, declaration)); } else { @@ -1059,6 +1059,16 @@ private static SyntaxList AsAssemblyAttributes(IEnumerable< attributes.Select(list => list.WithTarget(SyntaxFactory.AttributeTargetSpecifier(SyntaxFactory.Token(SyntaxKind.AssemblyKeyword))))); } + private static SyntaxList WithRequiredTargetSpecifier(SyntaxList attributes, SyntaxNode declaration) + { + if (!declaration.IsKind(SyntaxKind.CompilationUnit)) + { + return attributes; + } + + return AsAssemblyAttributes(attributes); + } + public override IReadOnlyList GetAttributeArguments(SyntaxNode attributeDeclaration) { switch (attributeDeclaration.Kind()) diff --git a/src/Workspaces/CSharpTest/CodeGeneration/SyntaxGeneratorTests.cs b/src/Workspaces/CSharpTest/CodeGeneration/SyntaxGeneratorTests.cs index c3d3e7912249a..5b60a15647e83 100644 --- a/src/Workspaces/CSharpTest/CodeGeneration/SyntaxGeneratorTests.cs +++ b/src/Workspaces/CSharpTest/CodeGeneration/SyntaxGeneratorTests.cs @@ -1588,6 +1588,14 @@ public void TestAddAttributes() Generator.CompilationUnit(Generator.NamespaceDeclaration("n")), Generator.Attribute("a")), "[assembly: a]\r\nnamespace n\r\n{\r\n}"); + + VerifySyntax( + Generator.AddAttributes( + Generator.AddAttributes( + Generator.CompilationUnit(Generator.NamespaceDeclaration("n")), + Generator.Attribute("a")), + Generator.Attribute("b")), + "[assembly: a]\r\n[assembly: b]\r\nnamespace n\r\n{\r\n}"); } [Fact] diff --git a/src/Workspaces/VisualBasic/Portable/CodeGeneration/VisualBasicSyntaxGenerator.vb b/src/Workspaces/VisualBasic/Portable/CodeGeneration/VisualBasicSyntaxGenerator.vb index c168e150cb122..4420d40bbad44 100644 --- a/src/Workspaces/VisualBasic/Portable/CodeGeneration/VisualBasicSyntaxGenerator.vb +++ b/src/Workspaces/VisualBasic/Portable/CodeGeneration/VisualBasicSyntaxGenerator.vb @@ -1637,6 +1637,15 @@ Namespace Microsoft.CodeAnalysis.VisualBasic.CodeGeneration End Function Public Overrides Function GetAttributes(declaration As SyntaxNode) As IReadOnlyList(Of SyntaxNode) + If declaration.IsKind(SyntaxKind.CompilationUnit) Then + ' CompilationUnit syntaxes represent attribute lists in a way that we can't get a single AttributeList for all of the attributes in all cases. + ' However, some consumers of this API assume that all returned values are children of "declaration", so if there's one attribute list, we'll use + ' that value directly if possible. + Dim compilationUnit = DirectCast(declaration, CompilationUnitSyntax) + If compilationUnit.Attributes.Count = 1 Then + Return compilationUnit.Attributes(0).AttributeLists + End If + End If Return Me.Flatten(declaration.GetAttributeLists()) End Function @@ -1649,9 +1658,9 @@ Namespace Microsoft.CodeAnalysis.VisualBasic.CodeGeneration Dim existingAttributes = Me.GetAttributes(declaration) If index >= 0 AndAlso index < existingAttributes.Count Then - Return Me.InsertNodesBefore(declaration, existingAttributes(index), newAttributes) + Return Me.InsertNodesBefore(declaration, existingAttributes(index), WithRequiredTargetSpecifier(newAttributes, declaration)) ElseIf existingAttributes.Count > 0 Then - Return Me.InsertNodesAfter(declaration, existingAttributes(existingAttributes.Count - 1), newAttributes) + Return Me.InsertNodesAfter(declaration, existingAttributes(existingAttributes.Count - 1), WithRequiredTargetSpecifier(newAttributes, declaration)) Else Dim lists = GetAttributeLists(declaration) Return Me.WithAttributeLists(declaration, lists.AddRange(AsAttributeLists(attributes))) @@ -1678,6 +1687,13 @@ Namespace Microsoft.CodeAnalysis.VisualBasic.CodeGeneration End If End Function + Private Function WithRequiredTargetSpecifier(attributes As SyntaxList(Of AttributeListSyntax), declaration As SyntaxNode) As SyntaxList(Of AttributeListSyntax) + If Not declaration.IsKind(SyntaxKind.CompilationUnit) Then + Return attributes + End If + Return SyntaxFactory.List(attributes.Select(AddressOf WithAssemblyTargets)) + End Function + Public Overrides Function GetReturnAttributes(declaration As SyntaxNode) As IReadOnlyList(Of SyntaxNode) Return Me.Flatten(GetReturnAttributeLists(declaration)) End Function diff --git a/src/Workspaces/VisualBasicTest/CodeGeneration/SyntaxGeneratorTests.vb b/src/Workspaces/VisualBasicTest/CodeGeneration/SyntaxGeneratorTests.vb index 4c2f1c33d4d64..36328d79373d0 100644 --- a/src/Workspaces/VisualBasicTest/CodeGeneration/SyntaxGeneratorTests.vb +++ b/src/Workspaces/VisualBasicTest/CodeGeneration/SyntaxGeneratorTests.vb @@ -2181,6 +2181,18 @@ End Class") " Namespace n End Namespace +") + + VerifySyntax(Of CompilationUnitSyntax)( + Generator.AddAttributes( + Generator.AddAttributes( + Generator.CompilationUnit(Generator.NamespaceDeclaration("n")), + Generator.Attribute("a")), + Generator.Attribute("b")), +" + +Namespace n +End Namespace ") VerifySyntax(Of DelegateStatementSyntax)(