Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Provide a code-fix to add missing stateful marshaller shape methods #73186

Merged
merged 8 commits into from
Aug 5, 2022
2 changes: 1 addition & 1 deletion eng/Versions.props
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@
<FsCheckVersion>2.14.3</FsCheckVersion>
<!-- Uncomment to set a fixed version, else the latest is used -->
<SdkVersionForWorkloadTesting>7.0.100-rc.1.22402.35</SdkVersionForWorkloadTesting>
<CompilerPlatformTestingVersion>1.1.2-beta1.22205.2</CompilerPlatformTestingVersion>
<CompilerPlatformTestingVersion>1.1.2-beta1.22403.2</CompilerPlatformTestingVersion>
<!-- Docs -->
<MicrosoftPrivateIntellisenseVersion>7.0.0-preview-20220721.1</MicrosoftPrivateIntellisenseVersion>
<!-- ILLink -->
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,7 @@ public static class DefaultMarshalModeDiagnostics
description: GetResourceString(nameof(SR.StatelessLinearCollectionRequiresTwoParameterAllocateContainerForManagedElementsDescription)));

/// <inheritdoc cref="CustomMarshallerAttributeAnalyzer.StatefulMarshallerRequiresFromManagedRule" />
public static readonly DiagnosticDescriptor StatefulMarshallerRequiresFromManagedRule =
private static readonly DiagnosticDescriptor StatefulMarshallerRequiresFromManagedRule =
new DiagnosticDescriptor(
Ids.CustomMarshallerTypeMustHaveRequiredShape,
GetResourceString(nameof(SR.CustomMarshallerTypeMustHaveRequiredShapeTitle)),
Expand All @@ -479,7 +479,7 @@ public static class DefaultMarshalModeDiagnostics
description: GetResourceString(nameof(SR.StatefulMarshallerRequiresFromManagedDescription)));

/// <inheritdoc cref="CustomMarshallerAttributeAnalyzer.StatefulMarshallerRequiresToUnmanagedRule" />
public static readonly DiagnosticDescriptor StatefulMarshallerRequiresToUnmanagedRule =
private static readonly DiagnosticDescriptor StatefulMarshallerRequiresToUnmanagedRule =
new DiagnosticDescriptor(
Ids.CustomMarshallerTypeMustHaveRequiredShape,
GetResourceString(nameof(SR.CustomMarshallerTypeMustHaveRequiredShapeTitle)),
Expand All @@ -490,7 +490,7 @@ public static class DefaultMarshalModeDiagnostics
description: GetResourceString(nameof(SR.StatefulMarshallerRequiresToUnmanagedDescription)));

/// <inheritdoc cref="CustomMarshallerAttributeAnalyzer.StatefulMarshallerRequiresToManagedRule" />
public static readonly DiagnosticDescriptor StatefulMarshallerRequiresToManagedRule =
private static readonly DiagnosticDescriptor StatefulMarshallerRequiresToManagedRule =
new DiagnosticDescriptor(
Ids.CustomMarshallerTypeMustHaveRequiredShape,
GetResourceString(nameof(SR.CustomMarshallerTypeMustHaveRequiredShapeTitle)),
Expand All @@ -501,7 +501,7 @@ public static class DefaultMarshalModeDiagnostics
description: GetResourceString(nameof(SR.StatefulMarshallerRequiresToManagedDescription)));

/// <inheritdoc cref="CustomMarshallerAttributeAnalyzer.StatefulMarshallerRequiresFromUnmanagedRule" />
public static readonly DiagnosticDescriptor StatefulMarshallerRequiresFromUnmanagedRule =
private static readonly DiagnosticDescriptor StatefulMarshallerRequiresFromUnmanagedRule =
new DiagnosticDescriptor(
Ids.CustomMarshallerTypeMustHaveRequiredShape,
GetResourceString(nameof(SR.CustomMarshallerTypeMustHaveRequiredShapeTitle)),
Expand All @@ -511,7 +511,7 @@ public static class DefaultMarshalModeDiagnostics
isEnabledByDefault: true,
description: GetResourceString(nameof(SR.StatefulMarshallerRequiresFromUnmanagedDescription)));

internal static DiagnosticDescriptor GetDefaultMarshalModeDiagnostic(DiagnosticDescriptor errorDescriptor)
public static DiagnosticDescriptor GetDefaultMarshalModeDiagnostic(DiagnosticDescriptor errorDescriptor)
{
if (ReferenceEquals(errorDescriptor, CustomMarshallerAttributeAnalyzer.StatelessValueInRequiresConvertToUnmanagedRule))
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,10 @@ private static void AddMissingMembers(
{
AddMissingMembersToStatelessMarshaller(editor, declaringSyntax, marshallerType, managedType, missingMemberNames, isLinearCollectionMarshaller);
}
if (marshallerType.IsValueType)
{
AddMissingMembersToStatefulMarshaller(editor, declaringSyntax, marshallerType, managedType, missingMemberNames, isLinearCollectionMarshaller);
}
}

private static void AddMissingMembersToStatelessMarshaller(DocumentEditor editor, SyntaxNode declaringSyntax, INamedTypeSymbol marshallerType, ITypeSymbol managedType, HashSet<string> missingMemberNames, bool isLinearCollectionMarshaller)
Expand Down Expand Up @@ -398,6 +402,173 @@ ITypeSymbol CreateManagedElementTypeSymbol()
}
}

private static void AddMissingMembersToStatefulMarshaller(DocumentEditor editor, SyntaxNode declaringSyntax, INamedTypeSymbol marshallerType, ITypeSymbol managedType, HashSet<string> missingMemberNames, bool isLinearCollectionMarshaller)
{
SyntaxGenerator gen = editor.Generator;
// Get the methods of the shape so we can use them to determine what types to use in signatures that are not obvious.
var (_, methods) = StatefulMarshallerShapeHelper.GetShapeForType(marshallerType, managedType, isLinearCollectionMarshaller, editor.SemanticModel.Compilation);
INamedTypeSymbol spanOfT = editor.SemanticModel.Compilation.GetBestTypeByMetadataName(TypeNames.System_Span_Metadata)!;
INamedTypeSymbol readOnlySpanOfT = editor.SemanticModel.Compilation.GetBestTypeByMetadataName(TypeNames.System_ReadOnlySpan_Metadata)!;
var (typeParameters, _) = marshallerType.GetAllTypeArgumentsIncludingInContainingTypes();

// Use a lazy factory for the type syntaxes to avoid re-checking the various methods and reconstructing the syntax.
Lazy<SyntaxNode> unmanagedTypeSyntax = new(CreateUnmanagedTypeSyntax, isThreadSafe: false);
Lazy<ITypeSymbol> managedElementTypeSymbol = new(CreateManagedElementTypeSymbol, isThreadSafe: false);

List<SyntaxNode> newMembers = new();

if (missingMemberNames.Contains(ShapeMemberNames.Value.Stateful.FromManaged))
{
newMembers.Add(
gen.MethodDeclaration(
ShapeMemberNames.Value.Stateful.FromManaged,
parameters: new[] { gen.ParameterDeclaration("managed", gen.TypeExpression(managedType)) },
accessibility: Accessibility.Public,
statements: new[] { DefaultMethodStatement(gen, editor.SemanticModel.Compilation) }));
}

if (missingMemberNames.Contains(ShapeMemberNames.Value.Stateful.ToUnmanaged))
{
newMembers.Add(
gen.MethodDeclaration(
ShapeMemberNames.Value.Stateful.ToUnmanaged,
returnType: unmanagedTypeSyntax.Value,
accessibility: Accessibility.Public,
statements: new[] { DefaultMethodStatement(gen, editor.SemanticModel.Compilation) }));
}

if (missingMemberNames.Contains(ShapeMemberNames.Value.Stateful.FromUnmanaged))
{
newMembers.Add(
gen.MethodDeclaration(
ShapeMemberNames.Value.Stateful.FromUnmanaged,
parameters: new[] { gen.ParameterDeclaration("unmanaged", unmanagedTypeSyntax.Value) },
accessibility: Accessibility.Public,
statements: new[] { DefaultMethodStatement(gen, editor.SemanticModel.Compilation) }));
}

if (missingMemberNames.Contains(ShapeMemberNames.Value.Stateful.ToManaged))
{
newMembers.Add(
gen.MethodDeclaration(
ShapeMemberNames.Value.Stateful.ToManaged,
returnType: gen.TypeExpression(managedType),
accessibility: Accessibility.Public,
statements: new[] { DefaultMethodStatement(gen, editor.SemanticModel.Compilation) }));
}

if (missingMemberNames.Contains(ShapeMemberNames.BufferSize))
{
newMembers.Add(
gen.WithAccessorDeclarations(
gen.PropertyDeclaration(ShapeMemberNames.BufferSize,
gen.TypeExpression(editor.SemanticModel.Compilation.GetSpecialType(SpecialType.System_Int32)),
Accessibility.Public,
DeclarationModifiers.Static),
gen.GetAccessorDeclaration(statements: new[] { DefaultMethodStatement(gen, editor.SemanticModel.Compilation) })));
}

if (missingMemberNames.Contains(ShapeMemberNames.LinearCollection.Stateless.GetManagedValuesSource))
jkoritzinsky marked this conversation as resolved.
Show resolved Hide resolved
{
newMembers.Add(
gen.MethodDeclaration(
ShapeMemberNames.LinearCollection.Stateless.GetManagedValuesSource,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Stateless -> Stateful in a bunch of places below. Or maybe we should put the shared names just under ShapeMemeberNames.LinearCollection?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the only ones we have shared are the ones with the same signatures. It might be worthwhile sharing the common names that represent slightly different method signatures, but I don't think that needs to be done in this PR.

returnType: gen.TypeExpression(readOnlySpanOfT.Construct(managedElementTypeSymbol.Value)),
accessibility: Accessibility.Public,
statements: new[] { DefaultMethodStatement(gen, editor.SemanticModel.Compilation) }));
}

if (missingMemberNames.Contains(ShapeMemberNames.LinearCollection.Stateless.GetUnmanagedValuesDestination))
{
newMembers.Add(
gen.MethodDeclaration(
ShapeMemberNames.LinearCollection.Stateless.GetUnmanagedValuesDestination,
returnType: gen.TypeExpression(spanOfT.Construct(typeParameters[typeParameters.Length - 1])),
accessibility: Accessibility.Public,
statements: new[] { DefaultMethodStatement(gen, editor.SemanticModel.Compilation) }));
}

if (missingMemberNames.Contains(ShapeMemberNames.LinearCollection.Stateless.GetUnmanagedValuesSource))
{
newMembers.Add(
gen.MethodDeclaration(
ShapeMemberNames.LinearCollection.Stateless.GetUnmanagedValuesSource,
parameters: new[]
{
gen.ParameterDeclaration("numElements", gen.TypeExpression(SpecialType.System_Int32))
},
returnType: gen.TypeExpression(readOnlySpanOfT.Construct(typeParameters[typeParameters.Length - 1])),
accessibility: Accessibility.Public,
statements: new[] { DefaultMethodStatement(gen, editor.SemanticModel.Compilation) }));
}

if (missingMemberNames.Contains(ShapeMemberNames.LinearCollection.Stateless.GetManagedValuesDestination))
{
newMembers.Add(
gen.MethodDeclaration(
ShapeMemberNames.LinearCollection.Stateless.GetManagedValuesDestination,
parameters: new[]
{
gen.ParameterDeclaration("numElements", gen.TypeExpression(SpecialType.System_Int32))
},
returnType: gen.TypeExpression(spanOfT.Construct(managedElementTypeSymbol.Value)),
accessibility: Accessibility.Public,
statements: new[] { DefaultMethodStatement(gen, editor.SemanticModel.Compilation) }));
}

if (missingMemberNames.Contains(ShapeMemberNames.Free))
{
newMembers.Add(
gen.MethodDeclaration(
ShapeMemberNames.Value.Stateful.Free,
accessibility: Accessibility.Public,
statements: new[] { DefaultMethodStatement(gen, editor.SemanticModel.Compilation) }));
}

editor.ReplaceNode(declaringSyntax, (declaringSyntax, gen) => gen.AddMembers(declaringSyntax, newMembers));

SyntaxNode CreateUnmanagedTypeSyntax()
{
ITypeSymbol? unmanagedType = null;
if (methods.ToUnmanaged is not null)
{
unmanagedType = methods.ToUnmanaged.ReturnType;
}
else if (methods.FromUnmanaged is not null)
{
unmanagedType = methods.FromUnmanaged.Parameters[0].Type;
}
else if (methods.UnmanagedValuesSource is not null)
{
unmanagedType = methods.UnmanagedValuesSource.Parameters[0].Type;
}
else if (methods.UnmanagedValuesDestination is not null)
{
unmanagedType = methods.UnmanagedValuesDestination.Parameters[0].Type;
}

if (unmanagedType is not null)
{
return gen.TypeExpression(unmanagedType);
}
return gen.TypeExpression(editor.SemanticModel.Compilation.GetSpecialType(SpecialType.System_IntPtr));
}

ITypeSymbol CreateManagedElementTypeSymbol()
{
if (methods.ManagedValuesSource is not null)
{
return ((INamedTypeSymbol)methods.ManagedValuesSource.ReturnType).TypeArguments[0];
}
if (methods.ManagedValuesDestination is not null)
{
return ((INamedTypeSymbol)methods.ManagedValuesDestination.ReturnType).TypeArguments[0];
}

return editor.SemanticModel.Compilation.GetSpecialType(SpecialType.System_IntPtr);
}
}

private static SyntaxNode DefaultMethodStatement(SyntaxGenerator generator, Compilation compilation)
{
return generator.ThrowStatement(generator.ObjectCreationExpression(
Expand Down
Loading