From 883638bcaa85d5f46f77f6205d5f05175670ad53 Mon Sep 17 00:00:00 2001 From: Nikolay Pianikov Date: Wed, 7 Aug 2024 12:09:26 +0300 Subject: [PATCH] Fix simple factory for generic type markers --- .../Core/ApiInvocationProcessor.cs | 52 ++----- .../Core/Code/FactoryCodeBuilder.cs | 128 +++++++++++++++++- .../Core/DependencyGraphBuilder.cs | 6 +- src/Pure.DI.Core/Core/FactoryTypeRewriter.cs | 95 +++++-------- src/Pure.DI.Core/Core/Models/MdBinding.cs | 3 +- src/Pure.DI.Core/Core/Models/MdFactory.cs | 3 +- src/Pure.DI.Core/Core/Models/MdResolver.cs | 6 +- src/Pure.DI.Core/Core/SetupsBuilder.cs | 106 ++++----------- src/Pure.DI.Core/Pure.DI.Core.csproj | 1 - .../BindAttributeTests.cs | 76 ++++++++++- tests/Pure.DI.IntegrationTests/GraphTests.cs | 12 +- .../SimpleFactoryTests.cs | 82 +++++++++++ 12 files changed, 374 insertions(+), 196 deletions(-) diff --git a/src/Pure.DI.Core/Core/ApiInvocationProcessor.cs b/src/Pure.DI.Core/Core/ApiInvocationProcessor.cs index ad0b0fe4f..463732afe 100644 --- a/src/Pure.DI.Core/Core/ApiInvocationProcessor.cs +++ b/src/Pure.DI.Core/Core/ApiInvocationProcessor.cs @@ -387,18 +387,20 @@ private void VisitSimpleFactory( ParenthesizedLambdaExpressionSyntax lambdaExpression) { CheckNotAsync(lambdaExpression); - var identifiers = lambdaExpression.ParameterList.Parameters.Select(i => i.Identifier).ToList(); - var paramAttributes = lambdaExpression.ParameterList.Parameters.Select(i => i.AttributeLists.SelectMany(j => j.Attributes).ToList()).ToList(); - const string ctxName = "ctx_1182D127"; - var contextParameter = SyntaxFactory.Parameter(SyntaxFactory.Identifier(ctxName)); + var parameters = lambdaExpression.ParameterList.Parameters; + var paramAttributes = parameters.Select(i => i.AttributeLists.SelectMany(j => j.Attributes).ToList()).ToList(); var resolvers = new List(); - var block = new List(); var namespaces = new HashSet(); for (var i = 0; i < argsTypes.Count; i++) { var argTypeSyntax = argsTypes[i]; var argType = semantic.GetTypeSymbol(semanticModel, argTypeSyntax); - namespaces.Add(argType.ContainingNamespace.ToString()); + var argNamespace = argType.ContainingNamespace; + if (argNamespace is not null) + { + namespaces.Add(argNamespace.ToString()); + } + var attributes = paramAttributes[i]; resolvers.Add(new MdResolver { @@ -406,50 +408,20 @@ private void VisitSimpleFactory( Source = argTypeSyntax, ContractType = argType, Tag = new MdTag(0, null), + ArgumentType = argTypeSyntax, + Parameter = parameters[i], Position = i, Attributes = attributes.ToImmutableArray() }); - - var valueDeclaration = SyntaxFactory.DeclarationExpression( - argTypeSyntax, - SyntaxFactory.SingleVariableDesignation(identifiers[i])); - - var valueArg = - SyntaxFactory.Argument(valueDeclaration) - .WithRefOrOutKeyword(SyntaxFactory.Token(SyntaxKind.OutKeyword)); - - var injection = SyntaxFactory.InvocationExpression( - SyntaxFactory.MemberAccessExpression( - SyntaxKind.SimpleMemberAccessExpression, - SyntaxFactory.IdentifierName(ctxName), - SyntaxFactory.IdentifierName(nameof(IContext.Inject)))) - .AddArgumentListArguments(valueArg); - - block.Add(SyntaxFactory.ExpressionStatement(injection)); - } - - if (lambdaExpression.Block is {} lambdaBlock) - { - block.AddRange(lambdaBlock.Statements); - } - else - { - if (lambdaExpression.ExpressionBody is { } body) - { - block.Add(SyntaxFactory.ReturnStatement(body)); - } } - var newLambdaExpression = SyntaxFactory.SimpleLambdaExpression(contextParameter) - .WithBlock(SyntaxFactory.Block(block)); - metadataVisitor.VisitFactory( new MdFactory( semanticModel, source, returnType, - newLambdaExpression, - contextParameter, + lambdaExpression, + SyntaxFactory.Parameter(SyntaxFactory.Identifier("ctx_1182D127")), resolvers.ToImmutableArray(), false)); diff --git a/src/Pure.DI.Core/Core/Code/FactoryCodeBuilder.cs b/src/Pure.DI.Core/Core/Code/FactoryCodeBuilder.cs index 8818ac740..552b19204 100644 --- a/src/Pure.DI.Core/Core/Code/FactoryCodeBuilder.cs +++ b/src/Pure.DI.Core/Core/Code/FactoryCodeBuilder.cs @@ -11,6 +11,9 @@ internal class FactoryCodeBuilder( ICompilations compilations) : ICodeBuilder { + public static readonly ParenthesizedLambdaExpressionSyntax DefaultBindAttrParenthesizedLambda = SyntaxFactory.ParenthesizedLambdaExpression(); + public static readonly ParameterSyntax DefaultCtxParameter = SyntaxFactory.Parameter(SyntaxFactory.Identifier("ctx_1182D127")); + public const string DefaultInstanceValueName = "instance_1182D127"; private static readonly string InjectionStatement = $"{Names.InjectionMarker};"; public void Build(BuildContext ctx, in DpFactory factory) @@ -25,11 +28,134 @@ public void Build(BuildContext ctx, in DpFactory factory) lockIsRequired = default; } + var originalLambda = factory.Source.Factory; + // Simple factory + if (originalLambda is ParenthesizedLambdaExpressionSyntax parenthesizedLambda) + { + var block = new List(); + foreach (var resolver in factory.Source.Resolvers) + { + if (resolver.ArgumentType is not { } argumentType || resolver.Parameter is not {} parameter) + { + continue; + } + + var valueDeclaration = SyntaxFactory.DeclarationExpression( + argumentType, + SyntaxFactory.SingleVariableDesignation(parameter.Identifier)); + + var valueArg = + SyntaxFactory.Argument(valueDeclaration) + .WithRefOrOutKeyword(SyntaxFactory.Token(SyntaxKind.OutKeyword)); + + var injection = SyntaxFactory.InvocationExpression( + SyntaxFactory.MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + SyntaxFactory.IdentifierName(DefaultCtxParameter.Identifier), + SyntaxFactory.IdentifierName(nameof(IContext.Inject)))) + .AddArgumentListArguments(valueArg); + + block.Add(SyntaxFactory.ExpressionStatement(injection)); + } + + if (factory.Source.MemberResolver is {} memberResolver + && memberResolver.Member is {} member + && memberResolver.TypeConstructor is {} typeConstructor) + { + ExpressionSyntax? value = default; + var type = memberResolver.ContractType; + ExpressionSyntax instance = member.IsStatic + ? SyntaxFactory.ParseTypeName(type.ToDisplayString(NullableFlowState.None, SymbolDisplayFormat.FullyQualifiedFormat)) + : SyntaxFactory.IdentifierName(DefaultInstanceValueName); + + switch (member) + { + case IFieldSymbol fieldSymbol: + value = SyntaxFactory.MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + instance, + SyntaxFactory.IdentifierName(member.Name)); + break; + + case IPropertySymbol propertySymbol: + value = SyntaxFactory.MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + instance, + SyntaxFactory.IdentifierName(member.Name)); + break; + + case IMethodSymbol methodSymbol: + var args = methodSymbol.Parameters + .Select(i => SyntaxFactory.Argument(SyntaxFactory.IdentifierName(i.Name))) + .ToArray(); + + if (methodSymbol.IsGenericMethod) + { + var setup = variable.Setup; + var binding = variable.Node.Binding; + var typeArgs = new List(); + // ReSharper disable once ForeachCanBeConvertedToQueryUsingAnotherGetEnumerator + foreach (var typeArg in methodSymbol.TypeArguments) + { + var argType = typeConstructor.ConstructReversed(setup, binding.SemanticModel.Compilation, typeArg); + if (binding.TypeConstructor is { } bindingTypeConstructor) + { + argType = bindingTypeConstructor.Construct(setup, binding.SemanticModel.Compilation, argType); + } + + var typeName = argType.ToDisplayString(NullableFlowState.None, SymbolDisplayFormat.FullyQualifiedFormat); + typeArgs.Add(SyntaxFactory.ParseTypeName(typeName)); + } + + value = SyntaxFactory.MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + instance, + SyntaxFactory.GenericName(member.Name).AddTypeArgumentListArguments(typeArgs.ToArray())); + } + else + { + value = SyntaxFactory.MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + instance, + SyntaxFactory.IdentifierName(member.Name)); + } + + value = SyntaxFactory + .InvocationExpression(value) + .AddArgumentListArguments(args); + + break; + } + + if (value is not null) + { + block.Add(SyntaxFactory.ReturnStatement(value)); + } + } + else + { + if (parenthesizedLambda.Block is {} lambdaBlock) + { + block.AddRange(lambdaBlock.Statements); + } + else + { + if (parenthesizedLambda.ExpressionBody is { } body) + { + block.Add(SyntaxFactory.ReturnStatement(body)); + } + } + } + + originalLambda = SyntaxFactory.SimpleLambdaExpression(DefaultCtxParameter) + .WithBlock(SyntaxFactory.Block(block)); + } + // Rewrites syntax tree var finishLabel = $"{variable.VariableDeclarationName}Finish"; var injections = new List(); var localVariableRenamingRewriter = new LocalVariableRenamingRewriter(idGenerator, factory.Source.SemanticModel); - var factoryExpression = localVariableRenamingRewriter.Rewrite(factory.Source.Factory); + var factoryExpression = localVariableRenamingRewriter.Rewrite(originalLambda); var factoryRewriter = new FactoryRewriter(arguments, compilations, factory, variable, finishLabel, injections); var lambda = factoryRewriter.Rewrite(factoryExpression); new FactoryValidator(factory).Validate(lambda); diff --git a/src/Pure.DI.Core/Core/DependencyGraphBuilder.cs b/src/Pure.DI.Core/Core/DependencyGraphBuilder.cs index 7bc6ea407..d3812a24e 100644 --- a/src/Pure.DI.Core/Core/DependencyGraphBuilder.cs +++ b/src/Pure.DI.Core/Core/DependencyGraphBuilder.cs @@ -458,6 +458,7 @@ private MdBinding CreateGenericBinding( return sourceNode.Binding with { Id = newId, + TypeConstructor = typeConstructor, Contracts = newContracts, Implementation = sourceNode.Binding.Implementation.HasValue ? sourceNode.Binding.Implementation.Value with @@ -518,9 +519,9 @@ private MdBinding CreateAutoBinding( var semanticModel = targetNode.Binding.SemanticModel; var compilation = semanticModel.Compilation; var sourceType = injection.Type; + var typeConstructor = typeConstructorFactory(); if (marker.IsMarkerBased(setup, injection.Type)) { - var typeConstructor = typeConstructorFactory(); typeConstructor.TryBind(setup, injection.Type, injection.Type); sourceType = typeConstructor.Construct(setup, compilation, injection.Type); } @@ -538,7 +539,8 @@ private MdBinding CreateAutoBinding( newContracts, newTags, new MdLifetime(semanticModel, setup.Source, Lifetime.Transient), - new MdImplementation(semanticModel, setup.Source, sourceType)); + new MdImplementation(semanticModel, setup.Source, sourceType), + TypeConstructor: typeConstructor); return newBinding; } diff --git a/src/Pure.DI.Core/Core/FactoryTypeRewriter.cs b/src/Pure.DI.Core/Core/FactoryTypeRewriter.cs index fced29ef0..188d18737 100644 --- a/src/Pure.DI.Core/Core/FactoryTypeRewriter.cs +++ b/src/Pure.DI.Core/Core/FactoryTypeRewriter.cs @@ -12,7 +12,7 @@ public MdFactory Build(RewriterContext context) { _context = context; var factory = context.State; - var newFactory = (LambdaExpressionSyntax)Visit(factory.Factory); + var newFactory = (LambdaExpressionSyntax)Visit(factory.Factory)!; return factory with { Type = context.TypeConstructor.Construct(context.Setup, factory.SemanticModel.Compilation, factory.Type), @@ -42,77 +42,54 @@ public MdFactory Build(RewriterContext context) return default; } - public override SyntaxNode? VisitIdentifierName(IdentifierNameSyntax node) + public override SyntaxNode? VisitIdentifierName(IdentifierNameSyntax node) => + TryCreateTypeSyntax(node) ?? base.VisitIdentifierName(node); + + public override SyntaxNode? VisitGenericName(GenericNameSyntax node) => + TryCreateTypeSyntax(node) ?? base.VisitGenericName(node); + + public override SyntaxNode? VisitQualifiedName(QualifiedNameSyntax node) => + TryCreateTypeSyntax(node) ?? base.VisitQualifiedName(node); + + private SyntaxNode? TryCreateTypeSyntax(SyntaxNode node) => + TryGetNewTypeName(node, out var newTypeName) + ? SyntaxFactory.ParseTypeName(newTypeName) + .WithLeadingTrivia(node.GetLeadingTrivia()) + .WithTrailingTrivia(node.GetTrailingTrivia()) + : default(SyntaxNode?); + + private bool TryGetNewTypeName(SyntaxNode? node, [NotNullWhen(true)] out string? newTypeName) { - var identifier = base.VisitIdentifierName(node) as IdentifierNameSyntax; - if (identifier is null) + newTypeName = default; + if (node is null) { - return identifier; + return false; } var semanticModel = _context.State.SemanticModel; - if (identifier.SyntaxTree != semanticModel.SyntaxTree) + if (semanticModel.GetSymbolInfo(node).Symbol is ITypeSymbol type) { - return identifier; + return TryGetNewTypeName(type, true, out newTypeName); } - var symbol = semanticModel.GetSymbolInfo(identifier).Symbol; - if (symbol is not ITypeSymbol type) - { - return identifier; - } - + return false; + } + + private bool TryGetNewTypeName(ITypeSymbol type, bool inTree, [NotNullWhen(true)] out string? newTypeName) + { + newTypeName = default; if (!marker.IsMarkerBased(_context.Setup, type)) { - return identifier; + return false; } - - var newType = _context.TypeConstructor.Construct(_context.Setup, semanticModel.Compilation, type); - var newTypeName = typeResolver.Resolve(_context.Setup, newType).Name; - return node.WithIdentifier( - SyntaxFactory.Identifier(newTypeName)) - .WithLeadingTrivia(node.Identifier.LeadingTrivia) - .WithTrailingTrivia(node.Identifier.TrailingTrivia); - } - public override SyntaxNode? VisitTypeArgumentList(TypeArgumentListSyntax node) - { - var newArgs = new List(); - var hasMarkerBased = false; - var semanticModel = _context.Setup.SemanticModel; - foreach (var arg in node.Arguments) + var newType = _context.TypeConstructor.Construct(_context.Setup, _context.State.SemanticModel.Compilation, type); + if (!inTree && SymbolEqualityComparer.Default.Equals(newType, type)) { - var typeName = arg.ToString(); - var isFound = false; - foreach (var type in semanticModel.Compilation.GetTypesByMetadataName(typeName)) - { - if (!marker.IsMarkerBased(_context.Setup, type)) - { - newArgs.Add(arg); - isFound = true; - break; - } - - hasMarkerBased = true; - var constructedType = _context.TypeConstructor.Construct(_context.Setup, semanticModel.Compilation, type); - if (SymbolEqualityComparer.Default.Equals(type, constructedType)) - { - continue; - } - - newArgs.Add(SyntaxFactory.ParseTypeName(constructedType.ToString())); - isFound = true; - break; - } - - if (!isFound) - { - return base.VisitTypeArgumentList(node); - } + return false; } - - return hasMarkerBased - ? SyntaxFactory.TypeArgumentList().AddArguments(newArgs.ToArray()) - : base.VisitTypeArgumentList(node); + + newTypeName = typeResolver.Resolve(_context.Setup, newType).Name; + return true; } } \ No newline at end of file diff --git a/src/Pure.DI.Core/Core/Models/MdBinding.cs b/src/Pure.DI.Core/Core/Models/MdBinding.cs index 85e8e95db..7d5daf4fd 100644 --- a/src/Pure.DI.Core/Core/Models/MdBinding.cs +++ b/src/Pure.DI.Core/Core/Models/MdBinding.cs @@ -11,7 +11,8 @@ internal record MdBinding( in MdImplementation? Implementation = default, in MdFactory? Factory = default, in MdArg? Arg = default, - in MdConstruct? Construct = default) + in MdConstruct? Construct = default, + ITypeConstructor? TypeConstructor = default) { public override string ToString() { diff --git a/src/Pure.DI.Core/Core/Models/MdFactory.cs b/src/Pure.DI.Core/Core/Models/MdFactory.cs index 7de8f2c7a..fefabce9f 100644 --- a/src/Pure.DI.Core/Core/Models/MdFactory.cs +++ b/src/Pure.DI.Core/Core/Models/MdFactory.cs @@ -8,7 +8,8 @@ internal readonly record struct MdFactory( LambdaExpressionSyntax Factory, ParameterSyntax Context, in ImmutableArray Resolvers, - bool HasContextTag) + bool HasContextTag, + in MdResolver? MemberResolver = default) { public override string ToString() => $"To<{Type}>({Factory})"; } \ No newline at end of file diff --git a/src/Pure.DI.Core/Core/Models/MdResolver.cs b/src/Pure.DI.Core/Core/Models/MdResolver.cs index 90f8aa8e0..b951f6949 100644 --- a/src/Pure.DI.Core/Core/Models/MdResolver.cs +++ b/src/Pure.DI.Core/Core/Models/MdResolver.cs @@ -9,7 +9,11 @@ internal readonly record struct MdResolver( ITypeSymbol ContractType, MdTag? Tag, ExpressionSyntax TargetValue, - ImmutableArray Attributes = default) + TypeSyntax? ArgumentType= default, + ParameterSyntax? Parameter = default, + ImmutableArray Attributes = default, + ISymbol? Member = default, + ITypeConstructor? TypeConstructor = default) { public override string ToString() => $"<=={ContractType}({Tag?.ToString()})"; } \ No newline at end of file diff --git a/src/Pure.DI.Core/Core/SetupsBuilder.cs b/src/Pure.DI.Core/Core/SetupsBuilder.cs index 9c1355807..6d87db82b 100644 --- a/src/Pure.DI.Core/Core/SetupsBuilder.cs +++ b/src/Pure.DI.Core/Core/SetupsBuilder.cs @@ -6,7 +6,7 @@ internal sealed class SetupsBuilder( ICache, bool> setupCache, Func bindingBuilderFactory, IArguments arguments, - ITypeConstructor typeConstructor) + Func typeConstructorFactory) : IBuilder>, IMetadataVisitor, ISetupFinalizer { private readonly List _setups = []; @@ -193,6 +193,7 @@ from attribute in member.GetAttributes() where attribute.AttributeClass?.ToDisplayString(NullableFlowState.None, SymbolDisplayFormat.FullyQualifiedFormat) == Names.BindAttributeName select (attribute, member); + var typeConstructor = typeConstructorFactory(); foreach (var (attribute, member) in membersToBind) { var values = arguments.GetArgs(attribute.ConstructorArguments, attribute.NamedArguments, "type", "lifetime", "tags"); @@ -202,86 +203,42 @@ from attribute in member.GetAttributes() contractType = newContractType; } - const string ctxName = "ctx_1182D127"; - const string valueName = "value"; - ExpressionSyntax instance = member.IsStatic - ? SyntaxFactory.ParseTypeName(type.ToDisplayString(NullableFlowState.None, SymbolDisplayFormat.FullyQualifiedFormat)) - : SyntaxFactory.IdentifierName(valueName); - - ExpressionSyntax value; - var position = 0; var namespaces = new HashSet(); var resolvers = new List(); - var block = new List(); switch (member) { case IFieldSymbol fieldSymbol: contractType ??= fieldSymbol.Type; - value = SyntaxFactory.MemberAccessExpression( - SyntaxKind.SimpleMemberAccessExpression, - instance, - SyntaxFactory.IdentifierName(member.Name)); break; case IPropertySymbol propertySymbol: contractType ??= propertySymbol.Type; - value = SyntaxFactory.MemberAccessExpression( - SyntaxKind.SimpleMemberAccessExpression, - instance, - SyntaxFactory.IdentifierName(member.Name)); break; case IMethodSymbol methodSymbol: contractType ??= methodSymbol.ReturnType; - - var args = methodSymbol.Parameters - .Select(i => SyntaxFactory.Argument(SyntaxFactory.IdentifierName(i.Name))) - .ToArray(); - if (methodSymbol.IsGenericMethod) { typeConstructor.TryBind(setup, contractType, methodSymbol.ReturnType); + contractType = typeConstructor.ConstructReversed(setup, semanticModel.Compilation, contractType); + // ReSharper disable once ForeachCanBeConvertedToQueryUsingAnotherGetEnumerator foreach (var parameter in methodSymbol.Parameters) { var paramType = typeConstructor.ConstructReversed(setup, binding.SemanticModel.Compilation, parameter.Type); - block.Add(SyntaxFactory.ExpressionStatement(Inject(paramType, parameter.Name, resolvers, MdTag.ContextTag, ref position))); + resolvers.Add(CreateResolver(typeConstructor, parameter.Name, paramType, MdTag.ContextTag, ref position)); } - - var typeArgs = new List(); - // ReSharper disable once ForeachCanBeConvertedToQueryUsingAnotherGetEnumerator - foreach (var typeArg in methodSymbol.TypeArguments) - { - var argType = typeConstructor.ConstructReversed(setup, binding.SemanticModel.Compilation, typeArg); - var typeName = argType.ToString(); - typeArgs.Add(SyntaxFactory.ParseTypeName(typeName)); - } - - value = SyntaxFactory.MemberAccessExpression( - SyntaxKind.SimpleMemberAccessExpression, - instance, - SyntaxFactory.GenericName(member.Name) - .AddTypeArgumentListArguments(typeArgs.ToArray())); } else { // ReSharper disable once ForeachCanBeConvertedToQueryUsingAnotherGetEnumerator foreach (var parameter in methodSymbol.Parameters) { - block.Add(SyntaxFactory.ExpressionStatement(Inject(parameter.Type, parameter.Name, resolvers, MdTag.ContextTag, ref position))); + resolvers.Add(CreateResolver(typeConstructor, parameter.Name, parameter.Type, MdTag.ContextTag, ref position)); } - - value = SyntaxFactory.MemberAccessExpression( - SyntaxKind.SimpleMemberAccessExpression, - instance, - SyntaxFactory.IdentifierName(member.Name)); } - - value = SyntaxFactory - .InvocationExpression(value) - .AddArgumentListArguments(args); - + break; default: @@ -320,22 +277,17 @@ from attribute in member.GetAttributes() tags = []; } - var contextParameter = SyntaxFactory.Parameter(SyntaxFactory.Identifier(ctxName)); object? valueTag = default; if (!contract.Tags.IsDefaultOrEmpty) { valueTag = contract.Tags.First().Value; } - + if (!member.IsStatic) { - block.Add(SyntaxFactory.ExpressionStatement(Inject(contract.ContractType!, valueName, resolvers, valueTag, ref position))); + resolvers.Add(CreateResolver(typeConstructor, FactoryCodeBuilder.DefaultInstanceValueName, contract.ContractType!, valueTag, ref position)); } - block.Add(SyntaxFactory.ReturnStatement(value)); - var lambdaExpression = SyntaxFactory.SimpleLambdaExpression(contextParameter) - .WithBlock(SyntaxFactory.Block(block)); - VisitContract( new MdContract( semanticModel, @@ -360,47 +312,37 @@ from attribute in member.GetAttributes() VisitTag(new MdTag(tagPosition, default)); } + var memberResolver = CreateResolver(typeConstructor, FactoryCodeBuilder.DefaultInstanceValueName, contract.ContractType!, valueTag, ref position); + memberResolver = memberResolver with { Member = member }; VisitFactory( new MdFactory( semanticModel, source, contractType, - lambdaExpression, - contextParameter, + FactoryCodeBuilder.DefaultBindAttrParenthesizedLambda, + FactoryCodeBuilder.DefaultCtxParameter, resolvers.ToImmutableArray(), - false)); - + false, + memberResolver)); + VisitUsingDirectives(new MdUsingDirectives(namespaces.ToImmutableArray(), ImmutableArray.Empty)); continue; - InvocationExpressionSyntax Inject(ITypeSymbol injectedType, string injectedName, ICollection resolversSet, object? tag, ref int curPosition) + MdResolver CreateResolver(ITypeConstructor constructor, string name, ITypeSymbol injectedType, object? tag, ref int curPosition) { + var typeSyntax = SyntaxFactory.ParseTypeName(injectedType.ToDisplayString(NullableFlowState.None, SymbolDisplayFormat.FullyQualifiedFormat)); namespaces.Add(injectedType.ContainingNamespace.ToString()); - resolversSet.Add(new MdResolver + return new MdResolver { SemanticModel = semanticModel, Source = source, ContractType = injectedType, Tag = new MdTag(curPosition, tag), - Position = curPosition - }); - - curPosition++; - - var valueDeclaration = SyntaxFactory.DeclarationExpression( - SyntaxFactory.ParseTypeName(injectedType.ToString()).WithTrailingTrivia(SyntaxFactory.Space), - SyntaxFactory.SingleVariableDesignation(SyntaxFactory.Identifier(injectedName))); - - var valueArg = - SyntaxFactory.Argument(valueDeclaration) - .WithRefOrOutKeyword(SyntaxFactory.Token(SyntaxKind.OutKeyword)); - - return SyntaxFactory.InvocationExpression( - SyntaxFactory.MemberAccessExpression( - SyntaxKind.SimpleMemberAccessExpression, - SyntaxFactory.IdentifierName(ctxName), - SyntaxFactory.IdentifierName(nameof(IContext.Inject)))) - .AddArgumentListArguments(valueArg); + ArgumentType = typeSyntax, + Parameter = SyntaxFactory.Parameter(SyntaxFactory.Identifier(name)).WithType(typeSyntax), + Position = curPosition++, + TypeConstructor = constructor + }; } } } diff --git a/src/Pure.DI.Core/Pure.DI.Core.csproj b/src/Pure.DI.Core/Pure.DI.Core.csproj index 55411c568..b276b1c9b 100644 --- a/src/Pure.DI.Core/Pure.DI.Core.csproj +++ b/src/Pure.DI.Core/Pure.DI.Core.csproj @@ -23,7 +23,6 @@ True GenericTypeArguments.g.tt - diff --git a/tests/Pure.DI.IntegrationTests/BindAttributeTests.cs b/tests/Pure.DI.IntegrationTests/BindAttributeTests.cs index 74f601514..76bec5258 100644 --- a/tests/Pure.DI.IntegrationTests/BindAttributeTests.cs +++ b/tests/Pure.DI.IntegrationTests/BindAttributeTests.cs @@ -643,11 +643,83 @@ public static void Main() result.StdOut.ShouldBe(["123", "Abc"], result); } + [Theory] + [InlineData("Pure.DI.", "Bind")] + [InlineData("", "Bind")] + [InlineData("global::Pure.DI.", "Bind")] + [InlineData("Pure.DI.", "BindAttribute")] + [InlineData("", "BindAttribute")] + [InlineData("global::Pure.DI.", "BindAttribute")] + public async Task ShouldSupportBindAttributeWhenGenericMethodWithArgs(string typeName, string attrName) + { + // Given + + // When + var result = await """ +using System; +using Pure.DI; + +namespace Sample +{ + internal interface IDependency { } + + internal class Dependency : IDependency + { + public Dependency() { } + } + + internal interface IService { } + + internal class Service : IService + { + public Service(IDependency dependency) + { + } + } + + internal class BaseComposition + { + [#TypeName#AttrName(typeof(Sample.IDependency<#TypeNameTT>), #TypeNameLifetime.Transient, null)] + public Sample.IDependency GetDep(int id) + { + Console.WriteLine(id); + return new Dependency(); + } + } + + static class Setup + { + private static void SetupComposition() + { + DI.Setup("Composition") + .Bind().To(_ => 77) + .Bind().To() + .Bind().To() + .Root("Service"); + } + } + + public class Program + { + public static void Main() + { + var composition = new Composition(); + var service = composition.Service; + } + } +} +""".Replace("#TypeName", typeName).Replace("#AttrName", attrName).RunAsync(); + + // Then + result.Success.ShouldBeTrue(result); + result.StdOut.ShouldBe(["77"], result); + } + [Theory] [InlineData("Pure.DI.")] [InlineData("")] [InlineData("global::Pure.DI.")] - public async Task ShouldSupportBindAttributeWhenGenericMethodWithArgs(string typeName) + public async Task ShouldSupportBindAttributeWhenGenericMethodWithGenericArgs(string typeName) { // Given @@ -676,7 +748,7 @@ public Service(IDependency dependency) internal class BaseComposition { - [Bind(typeof(Sample.IDependency<#TypeNameTT>), Lifetime.Singleton, null, 1, "abc")] + [#TypeNameBind(typeof(Sample.IDependency<#TypeNameTT>), #TypeNameLifetime.Singleton, null, 1, "abc")] public Sample.IDependency GetDep(T val, string str) { Console.WriteLine(val); diff --git a/tests/Pure.DI.IntegrationTests/GraphTests.cs b/tests/Pure.DI.IntegrationTests/GraphTests.cs index 2397cfc7b..bd1a52474 100644 --- a/tests/Pure.DI.IntegrationTests/GraphTests.cs +++ b/tests/Pure.DI.IntegrationTests/GraphTests.cs @@ -917,8 +917,8 @@ public class Program { public static void Main() { } } Sample.IService() +[Sample.IService() ]<--[Sample.IService]--[Service(Sample.IDependency dependency<--Sample.IDependency))] Service(Sample.IDependency dependency<--Sample.IDependency)) - +[Service(Sample.IDependency dependency<--Sample.IDependency))]<--[Sample.IDependency]--[new Dependency(new int[1])] -new Dependency(new int[1]) + +[Service(Sample.IDependency dependency<--Sample.IDependency))]<--[Sample.IDependency]--[new Sample.Dependency(new int[1])] +new Sample.Dependency(new int[1]) """.Replace("\r", "")); } @@ -984,23 +984,23 @@ public class Program { public static void Main() { } } +[Service(Sample.IDependency dependency<--Sample.IDependency))]<--[Sample.IDependency]--[{ ctx.Inject(out int[] array); ctx.Inject("MyStr", out var str); - return new Dependency(array, str); + return new Sample.Dependency(array, str); }] new int[] {1, 2, 3} { ctx.Inject(out int[] array); ctx.Inject("MyStr", out var str); - return new Dependency(array, str); + return new Sample.Dependency(array, str); } +[{ ctx.Inject(out int[] array); ctx.Inject("MyStr", out var str); - return new Dependency(array, str); + return new Sample.Dependency(array, str); }]<--[int[]]--[new int[] {1, 2, 3}] +[{ ctx.Inject(out int[] array); ctx.Inject("MyStr", out var str); - return new Dependency(array, str); + return new Sample.Dependency(array, str); }]<--[string("MyStr")]--["Abc"] """.Replace("\r", "")); } diff --git a/tests/Pure.DI.IntegrationTests/SimpleFactoryTests.cs b/tests/Pure.DI.IntegrationTests/SimpleFactoryTests.cs index 6cab9de6d..af3c57176 100644 --- a/tests/Pure.DI.IntegrationTests/SimpleFactoryTests.cs +++ b/tests/Pure.DI.IntegrationTests/SimpleFactoryTests.cs @@ -70,6 +70,88 @@ public static void Main() result.StdOut.ShouldBe(["Sample.Dependency"], result); } + [Theory] + [InlineData("global::System.Collections.Generic.", "Pure.DI.")] + [InlineData("System.Collections.Generic.", "Pure.DI.")] + [InlineData("", "Pure.DI.")] + [InlineData("global::System.Collections.Generic.", "global::Pure.DI.")] + [InlineData("System.Collections.Generic.", "global::Pure.DI.")] + [InlineData("", "global::Pure.DI.")] + [InlineData("global::System.Collections.Generic.", "")] + [InlineData("System.Collections.Generic.", "")] + [InlineData("", "")] + public async Task ShouldSupportSimpleFactoryWhenArrOfT(string typePrefix, string ttPrefix) + { + // Given + + // When + var result = await """ +using System; +using System.Collections.Generic; +using Pure.DI; + +namespace Sample +{ + interface IDependency {} + + class Dependency: IDependency {} + + interface IService + { + IDependency? Dep { get; } + + IService Initialize(IDependency dep); + } + + class Service: IService + { + public IDependency? Dep { get; private set; } + + public IService Initialize(IDependency dep) + { + Dep = dep; + return this; + } + + public override string ToString() + { + return Dep?.ToString() ?? ""; + } + } + + static class Setup + { + private static void SetupComposition() + { + DI.Setup("Composition") + .Bind<#TypePrefixICollection<#ttPrefixTT>>() + .Bind<#TypePrefixIList<#ttPrefixTT>>() + .Bind<#TypePrefixList<#ttPrefixTT>>() + .To((#ttPrefixTT[] arr) => new #TypePrefixList<#ttPrefixTT>(arr)) + .Bind().To() + .Bind().To() + .Bind().To((IService service, #TypePrefixIList dependency) => service.Initialize(dependency[0]).ToString() ?? "") + .Root("DepName"); + } + } + + public class Program + { + public static void Main() + { + var composition = new Composition(); + var depName = composition.DepName; + Console.WriteLine(depName); + } + } +} +""".Replace("#TypePrefix", typePrefix).Replace("#ttPrefix", ttPrefix).RunAsync(); + + // Then + result.Success.ShouldBeTrue(result); + result.StdOut.ShouldBe(["Sample.Dependency"], result); + } + [Fact] public async Task ShouldSupportSimpleFactoryWhenSimpleLambdaWitgGenericParams() {