Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[release/8.0] Options Source Gen Fixes #91432

Merged
merged 4 commits into from
Sep 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 66 additions & 19 deletions src/libraries/Microsoft.Extensions.Options/gen/Emitter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@

using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using System.Threading;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.DotnetRuntime.Extensions;

namespace Microsoft.Extensions.Options.Generators
{
Expand All @@ -25,6 +27,7 @@ internal sealed class Emitter : EmitterBase
private string _staticValidationAttributeHolderClassFQN;
private string _staticValidatorHolderClassFQN;
private string _modifier;
private string _TryGetValueNullableAnnotation;

private sealed record StaticFieldInfo(string FieldTypeFQN, int FieldOrder, string FieldName, IList<string> InstantiationLines);

Expand All @@ -37,13 +40,14 @@ public Emitter(Compilation compilation, bool emitPreamble = true) : base(emitPre
else
{
_modifier = "internal";
string suffix = $"_{new Random().Next():X8}";
string suffix = $"_{GetNonRandomizedHashCode(compilation.SourceModule.Name):X8}";
_staticValidationAttributeHolderClassName += suffix;
_staticValidatorHolderClassName += suffix;
}

_staticValidationAttributeHolderClassFQN = $"global::{StaticFieldHolderClassesNamespace}.{_staticValidationAttributeHolderClassName}";
_staticValidatorHolderClassFQN = $"global::{StaticFieldHolderClassesNamespace}.{_staticValidatorHolderClassName}";
_TryGetValueNullableAnnotation = GetNullableAnnotationStringForTryValidateValueToUseInGeneratedCode(compilation);
}

public string Emit(
Expand All @@ -65,6 +69,31 @@ public string Emit(
return Capture();
}

/// <summary>
/// Returns the nullable annotation string to use in the code generation according to the first parameter of
/// <see cref="System.ComponentModel.DataAnnotations.Validator.TryValidateValue(object, ValidationContext, ICollection{ValidationResult}, IEnumerable{ValidationAttribute})"/> is nullable annotated.
/// </summary>
/// <param name="compilation">The <see cref="Compilation"/> to consider for analysis.</param>
/// <returns>"!" if the first parameter is not nullable annotated, otherwise an empty string.</returns>
/// <remarks>
/// In .NET 8.0 we have changed the nullable annotation on first parameter of the method cref="System.ComponentModel.DataAnnotations.Validator.TryValidateValue(object, ValidationContext, ICollection{ValidationResult}, IEnumerable{ValidationAttribute})"/>
/// The source generator need to detect if we need to append "!" to the first parameter of the method call when running on down-level versions.
/// </remarks>
private static string GetNullableAnnotationStringForTryValidateValueToUseInGeneratedCode(Compilation compilation)
{
INamedTypeSymbol? validatorTypeSymbol = compilation.GetBestTypeByMetadataName("System.ComponentModel.DataAnnotations.Validator");
if (validatorTypeSymbol is not null)
{
ImmutableArray<ISymbol> members = validatorTypeSymbol.GetMembers("TryValidateValue");
if (members.Length == 1 && members[0] is IMethodSymbol tryValidateValueMethod)
{
return tryValidateValueMethod.Parameters[0].NullableAnnotation == NullableAnnotation.NotAnnotated ? "!" : string.Empty;
}
}

return "!";
}

private void GenValidatorType(ValidatorType vt, ref Dictionary<string, StaticFieldInfo> staticValidationAttributesDict, ref Dictionary<string, StaticFieldInfo> staticValidatorsDict)
{
if (vt.Namespace.Length > 0)
Expand Down Expand Up @@ -161,7 +190,7 @@ private void GenModelSelfValidationIfNecessary(ValidatedModel modelToValidate)
{
if (modelToValidate.SelfValidates)
{
OutLn($"builder.AddResults(((global::System.ComponentModel.DataAnnotations.IValidatableObject)options).Validate(context));");
OutLn($"(builder ??= new()).AddResults(((global::System.ComponentModel.DataAnnotations.IValidatableObject)options).Validate(context));");
OutLn();
}
}
Expand All @@ -182,8 +211,7 @@ private void GenModelValidationMethod(

OutLn($"public {(makeStatic ? "static " : string.Empty)}global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, {modelToValidate.Name} options)");
OutOpenBrace();
OutLn($"var baseName = (string.IsNullOrEmpty(name) ? \"{modelToValidate.SimpleName}\" : name) + \".\";");
OutLn($"var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder();");
OutLn($"global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null;");
OutLn($"var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options);");

int capacity = modelToValidate.MembersToValidate.Max(static vm => vm.ValidationAttributes.Count);
Expand All @@ -199,33 +227,33 @@ private void GenModelValidationMethod(
{
if (vm.ValidationAttributes.Count > 0)
{
GenMemberValidation(vm, ref staticValidationAttributesDict, cleanListsBeforeUse);
GenMemberValidation(vm, modelToValidate.SimpleName, ref staticValidationAttributesDict, cleanListsBeforeUse);
cleanListsBeforeUse = true;
OutLn();
}

if (vm.TransValidatorType is not null)
{
GenTransitiveValidation(vm, ref staticValidatorsDict);
GenTransitiveValidation(vm, modelToValidate.SimpleName, ref staticValidatorsDict);
OutLn();
}

if (vm.EnumerationValidatorType is not null)
{
GenEnumerationValidation(vm, ref staticValidatorsDict);
GenEnumerationValidation(vm, modelToValidate.SimpleName, ref staticValidatorsDict);
OutLn();
}
}

GenModelSelfValidationIfNecessary(modelToValidate);
OutLn($"return builder.Build();");
OutLn($"return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build();");
OutCloseBrace();
}

private void GenMemberValidation(ValidatedMember vm, ref Dictionary<string, StaticFieldInfo> staticValidationAttributesDict, bool cleanListsBeforeUse)
private void GenMemberValidation(ValidatedMember vm, string modelName, ref Dictionary<string, StaticFieldInfo> staticValidationAttributesDict, bool cleanListsBeforeUse)
{
OutLn($"context.MemberName = \"{vm.Name}\";");
OutLn($"context.DisplayName = baseName + \"{vm.Name}\";");
OutLn($"context.DisplayName = string.IsNullOrEmpty(name) ? \"{modelName}.{vm.Name}\" : $\"{{name}}.{vm.Name}\";");

if (cleanListsBeforeUse)
{
Expand All @@ -239,9 +267,9 @@ private void GenMemberValidation(ValidatedMember vm, ref Dictionary<string, Stat
OutLn($"validationAttributes.Add({_staticValidationAttributeHolderClassFQN}.{staticValidationAttributeInstance.FieldName});");
}

OutLn($"if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.{vm.Name}!, context, validationResults, validationAttributes))");
OutLn($"if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.{vm.Name}{_TryGetValueNullableAnnotation}, context, validationResults, validationAttributes))");
OutOpenBrace();
OutLn($"builder.AddResults(validationResults);");
OutLn($"(builder ??= new()).AddResults(validationResults);");
OutCloseBrace();
}

Expand Down Expand Up @@ -305,7 +333,7 @@ private StaticFieldInfo GetOrAddStaticValidationAttribute(ref Dictionary<string,
return staticValidationAttributeInstance;
}

private void GenTransitiveValidation(ValidatedMember vm, ref Dictionary<string, StaticFieldInfo> staticValidatorsDict)
private void GenTransitiveValidation(ValidatedMember vm, string modelName, ref Dictionary<string, StaticFieldInfo> staticValidatorsDict)
{
string callSequence;
if (vm.TransValidateTypeIsSynthetic)
Expand All @@ -321,20 +349,22 @@ private void GenTransitiveValidation(ValidatedMember vm, ref Dictionary<string,

var valueAccess = (vm.IsNullable && vm.IsValueType) ? ".Value" : string.Empty;

var baseName = $"string.IsNullOrEmpty(name) ? \"{modelName}.{vm.Name}\" : $\"{{name}}.{vm.Name}\"";

if (vm.IsNullable)
{
OutLn($"if (options.{vm.Name} is not null)");
OutOpenBrace();
OutLn($"builder.AddResult({callSequence}.Validate(baseName + \"{vm.Name}\", options.{vm.Name}{valueAccess}));");
OutLn($"(builder ??= new()).AddResult({callSequence}.Validate({baseName}, options.{vm.Name}{valueAccess}));");
OutCloseBrace();
}
else
{
OutLn($"builder.AddResult({callSequence}.Validate(baseName + \"{vm.Name}\", options.{vm.Name}{valueAccess}));");
OutLn($"(builder ??= new()).AddResult({callSequence}.Validate({baseName}, options.{vm.Name}{valueAccess}));");
}
}

private void GenEnumerationValidation(ValidatedMember vm, ref Dictionary<string, StaticFieldInfo> staticValidatorsDict)
private void GenEnumerationValidation(ValidatedMember vm, string modelName, ref Dictionary<string, StaticFieldInfo> staticValidatorsDict)
{
var valueAccess = (vm.IsValueType && vm.IsNullable) ? ".Value" : string.Empty;
var enumeratedValueAccess = (vm.EnumeratedIsNullable && vm.EnumeratedIsValueType) ? ".Value" : string.Empty;
Expand Down Expand Up @@ -365,22 +395,25 @@ private void GenEnumerationValidation(ValidatedMember vm, ref Dictionary<string,
{
OutLn($"if (o is not null)");
OutOpenBrace();
OutLn($"builder.AddResult({callSequence}.Validate(baseName + $\"{vm.Name}[{{count}}]\", o{enumeratedValueAccess}));");
var propertyName = $"string.IsNullOrEmpty(name) ? $\"{modelName}.{vm.Name}[{{count}}]\" : $\"{{name}}.{vm.Name}[{{count}}]\"";
OutLn($"(builder ??= new()).AddResult({callSequence}.Validate({propertyName}, o{enumeratedValueAccess}));");
OutCloseBrace();

if (!vm.EnumeratedMayBeNull)
{
OutLn($"else");
OutOpenBrace();
OutLn($"builder.AddError(baseName + $\"{vm.Name}[{{count}}] is null\");");
var error = $"string.IsNullOrEmpty(name) ? $\"{modelName}.{vm.Name}[{{count}}] is null\" : $\"{{name}}.{vm.Name}[{{count}}] is null\"";
OutLn($"(builder ??= new()).AddError({error});");
OutCloseBrace();
}

OutLn($"count++;");
}
else
{
OutLn($"builder.AddResult({callSequence}.Validate(baseName + $\"{vm.Name}[{{count++}}]\", o{enumeratedValueAccess}));");
var propertyName = $"string.IsNullOrEmpty(name) ? $\"{modelName}.{vm.Name}[{{count++}}] is null\" : $\"{{name}}.{vm.Name}[{{count++}}] is null\"";
OutLn($"(builder ??= new()).AddResult({callSequence}.Validate({propertyName}, o{enumeratedValueAccess}));");
}

OutCloseBrace();
Expand All @@ -405,5 +438,19 @@ private StaticFieldInfo GetOrAddStaticValidator(ref Dictionary<string, StaticFie

return staticValidatorInstance;
}

/// <summary>
/// Returns a non-randomized hash code for the given string.
/// We always return a positive value.
/// </summary>
internal static int GetNonRandomizedHashCode(string s)
{
uint result = 2166136261u;
foreach (char c in s)
{
result = (c ^ result) * 16777619;
}
return Math.Abs((int)result);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

<ItemGroup>
<Compile Include="$(CoreLibSharedDir)System\Runtime\CompilerServices\IsExternalInit.cs" Link="Common\System\Runtime\CompilerServices\IsExternalInit.cs" />
<Compile Include="$(CommonPath)\Roslyn\GetBestTypeByMetadataName.cs" Link="Common\Roslyn\GetBestTypeByMetadataName.cs" />
<Compile Include="DiagDescriptors.cs" />
<Compile Include="DiagDescriptorsBase.cs" />
<Compile Include="Emitter.cs" />
Expand Down
7 changes: 7 additions & 0 deletions src/libraries/Microsoft.Extensions.Options/gen/Parser.cs
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,13 @@ private static bool HasOpenGenerics(ITypeSymbol type, out string genericType)
type = ((INamedTypeSymbol)type).TypeArguments[0];
}

// Check first if the type is IEnumerable<T> interface
if (SymbolEqualityComparer.Default.Equals(type.OriginalDefinition, _symbolHolder.GenericIEnumerableSymbol))
{
return ((INamedTypeSymbol)type).TypeArguments[0];
}

// Check first if the type implement IEnumerable<T> interface
foreach (var implementingInterface in type.AllInterfaces)
{
if (SymbolEqualityComparer.Default.Equals(implementingInterface.OriginalDefinition, _compilation.GetSpecialType(SpecialType.System_Collections_Generic_IEnumerable_T)))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ internal sealed record class SymbolHolder(
INamedTypeSymbol DataTypeAttributeSymbol,
INamedTypeSymbol ValidateOptionsSymbol,
INamedTypeSymbol IValidatableObjectSymbol,
INamedTypeSymbol GenericIEnumerableSymbol,
INamedTypeSymbol TypeSymbol,
INamedTypeSymbol? ValidateObjectMembersAttributeSymbol,
INamedTypeSymbol? ValidateEnumeratedItemsAttributeSymbol);
INamedTypeSymbol ValidateObjectMembersAttributeSymbol,
INamedTypeSymbol ValidateEnumeratedItemsAttributeSymbol);
}
27 changes: 12 additions & 15 deletions src/libraries/Microsoft.Extensions.Options/gen/SymbolLoader.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,35 +15,33 @@ internal static class SymbolLoader
internal const string TypeOfType = "System.Type";
internal const string ValidateObjectMembersAttribute = "Microsoft.Extensions.Options.ValidateObjectMembersAttribute";
internal const string ValidateEnumeratedItemsAttribute = "Microsoft.Extensions.Options.ValidateEnumeratedItemsAttribute";
internal const string GenericIEnumerableType = "System.Collections.Generic.IEnumerable`1";

public static bool TryLoad(Compilation compilation, out SymbolHolder? symbolHolder)
{
INamedTypeSymbol? GetSymbol(string metadataName, bool optional = false)
{
var symbol = compilation.GetTypeByMetadataName(metadataName);
if (symbol == null && !optional)
{
return null;
}

return symbol;
}
INamedTypeSymbol? GetSymbol(string metadataName) => compilation.GetTypeByMetadataName(metadataName);

// required
var optionsValidatorSymbol = GetSymbol(OptionsValidatorAttribute);
var validationAttributeSymbol = GetSymbol(ValidationAttribute);
var dataTypeAttributeSymbol = GetSymbol(DataTypeAttribute);
var ivalidatableObjectSymbol = GetSymbol(IValidatableObjectType);
var validateOptionsSymbol = GetSymbol(IValidateOptionsType);
var genericIEnumerableSymbol = GetSymbol(GenericIEnumerableType);
var typeSymbol = GetSymbol(TypeOfType);
var validateObjectMembersAttribute = GetSymbol(ValidateObjectMembersAttribute);
var validateEnumeratedItemsAttribute = GetSymbol(ValidateEnumeratedItemsAttribute);

#pragma warning disable S1067 // Expressions should not be too complex
if (optionsValidatorSymbol == null ||
validationAttributeSymbol == null ||
dataTypeAttributeSymbol == null ||
ivalidatableObjectSymbol == null ||
validateOptionsSymbol == null ||
typeSymbol == null)
genericIEnumerableSymbol == null ||
typeSymbol == null ||
validateObjectMembersAttribute == null ||
validateEnumeratedItemsAttribute == null)
{
symbolHolder = default;
return false;
Expand All @@ -56,11 +54,10 @@ public static bool TryLoad(Compilation compilation, out SymbolHolder? symbolHold
dataTypeAttributeSymbol,
validateOptionsSymbol,
ivalidatableObjectSymbol,
genericIEnumerableSymbol,
typeSymbol,

// optional
GetSymbol(ValidateObjectMembersAttribute, optional: true),
GetSymbol(ValidateEnumeratedItemsAttribute, optional: true));
validateObjectMembersAttribute,
validateEnumeratedItemsAttribute);

return true;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,31 +70,30 @@ partial struct MyOptionsValidator
[global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")]
public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::HelloWorld.MyOptions options)
{
var baseName = (string.IsNullOrEmpty(name) ? "MyOptions" : name) + ".";
var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder();
global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null;
var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options);
var validationResults = new global::System.Collections.Generic.List<global::System.ComponentModel.DataAnnotations.ValidationResult>();
var validationAttributes = new global::System.Collections.Generic.List<global::System.ComponentModel.DataAnnotations.ValidationAttribute>(1);

context.MemberName = "Val1";
context.DisplayName = baseName + "Val1";
context.DisplayName = string.IsNullOrEmpty(name) ? "MyOptions.Val1" : $"{name}.Val1";
validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1);
if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val1!, context, validationResults, validationAttributes))
if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val1, context, validationResults, validationAttributes))
{
builder.AddResults(validationResults);
(builder ??= new()).AddResults(validationResults);
}

context.MemberName = "Val2";
context.DisplayName = baseName + "Val2";
context.DisplayName = string.IsNullOrEmpty(name) ? "MyOptions.Val2" : $"{name}.Val2";
validationResults.Clear();
validationAttributes.Clear();
validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A2);
if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val2!, context, validationResults, validationAttributes))
if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val2, context, validationResults, validationAttributes))
{
builder.AddResults(validationResults);
(builder ??= new()).AddResults(validationResults);
}

return builder.Build();
return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build();
}
}
}
Expand Down
Loading
Loading