diff --git a/Source/Test/Rewriting/Passes/MSTestRewriter.cs b/Source/Test/Rewriting/Passes/MSTestRewriter.cs index b17b93293..1814dace1 100644 --- a/Source/Test/Rewriting/Passes/MSTestRewriter.cs +++ b/Source/Test/Rewriting/Passes/MSTestRewriter.cs @@ -19,6 +19,23 @@ internal class MSTestRewriter : AssemblyRewriter /// private readonly Configuration Configuration; + /// + /// MethodReference of TestInitialize method within a test class. + /// + private MethodDefinition TestInitMethod; + private const string TestInitAttribute = "Microsoft.VisualStudio.TestTools.UnitTesting.TestInitializeAttribute"; + + /// + /// MethodReference of TestCleanupMethod method within a test class. + /// + private MethodDefinition TestCleanupMethod; + private const string TestCleanupAttribute = "Microsoft.VisualStudio.TestTools.UnitTesting.TestCleanupAttribute"; + + /// + /// TestClass Attribute of MSTests. + /// + private const string TestClassAttribute = "Microsoft.VisualStudio.TestTools.UnitTesting.TestClassAttribute"; + /// /// Initializes a new instance of the class. /// @@ -26,6 +43,40 @@ internal MSTestRewriter(Configuration configuration, ILogger logger) : base(logger) { this.Configuration = configuration; + this.TestInitMethod = null; + this.TestCleanupMethod = null; + } + + /// + internal override void VisitType(TypeDefinition type) + { + if (type.IsAbstract) + { + return; + } + + this.TestInitMethod = null; + this.TestCleanupMethod = null; + + if (type.CustomAttributes.Any(p => p.AttributeType.FullName == TestClassAttribute)) + { + foreach (var method in type.Methods) + { + if (method.CustomAttributes.Count > 0) + { + // Assuming that we can have at max only one TestInit and TestCleanup Methods. + if (method.CustomAttributes.Any(p => p.AttributeType.FullName == TestCleanupAttribute)) + { + this.TestCleanupMethod = method; + } + + if (method.CustomAttributes.Any(p => p.AttributeType.FullName == TestInitAttribute)) + { + this.TestInitMethod = method; + } + } + } + } } /// @@ -170,6 +221,8 @@ internal void RewriteTestMethod(MethodDefinition method, MethodDefinition testMe // The emitted IL corresponds to a method body such as: // Configuration configuration = Configuration.Create(); // TestingEngine engine = TestingEngine.Create(configuration, new Action(Test)); + // [engine.RegisterPerIterationInitMethod(new Action(method));] + // [engine.RegisterPerIterationCallBack(new Action(method));] // engine.Run(); // engine.ThrowIfBugFound(); // @@ -253,6 +306,35 @@ internal void RewriteTestMethod(MethodDefinition method, MethodDefinition testMe processor.Emit(OpCodes.Newobj, actionConstructor); processor.Emit(OpCodes.Call, createEngineMethod); processor.Emit(OpCodes.Dup); + + // Add call to engine.RegisterPerIterationInitMethod(new Action(method)); + if (this.TestInitMethod != null) + { + processor.Emit(OpCodes.Dup); + processor.Emit(OpCodes.Ldarg_0); + processor.Emit(OpCodes.Ldftn, this.TestInitMethod); + processor.Emit(OpCodes.Newobj, actionConstructor); + + MethodReference registerPerIterationInitMethod = this.Module.ImportReference( + FindMatchingMethodInDeclaringType(resolvedEngineType, "RegisterPerIterationInitMethod", actionType)); + + processor.Emit(OpCodes.Call, registerPerIterationInitMethod); + } + + // Add call to engine.RegisterPerIterationCallBack(new Action(method)); + if (this.TestCleanupMethod != null) + { + processor.Emit(OpCodes.Dup); + processor.Emit(OpCodes.Ldarg_0); + processor.Emit(OpCodes.Ldftn, this.TestCleanupMethod); + processor.Emit(OpCodes.Newobj, actionConstructor); + + MethodReference registerPerIterationCallBack = this.Module.ImportReference( + FindMatchingMethodInDeclaringType(resolvedEngineType, "RegisterPerIterationCallBack", actionType)); + + processor.Emit(OpCodes.Call, registerPerIterationCallBack); + } + this.EmitMethodCall(processor, resolvedEngineType, "Run"); this.EmitMethodCall(processor, resolvedEngineType, "ThrowIfBugFound"); processor.Emit(OpCodes.Ret); diff --git a/Source/Test/SystematicTesting/TestingEngine.cs b/Source/Test/SystematicTesting/TestingEngine.cs index 19823a9e1..df9dddab5 100644 --- a/Source/Test/SystematicTesting/TestingEngine.cs +++ b/Source/Test/SystematicTesting/TestingEngine.cs @@ -51,7 +51,13 @@ public sealed class TestingEngine /// Set of callbacks to invoke at the end /// of each iteration. /// - private readonly ISet> PerIterationCallbacks; + private readonly ISet PerIterationCallbacks; + + /// + /// Set of callbacks to invoke at the beginning + /// of each iteration. + /// + private readonly ISet PerIterationInitializationCallbacks; /// /// The scheduler used by the runtime during testing. @@ -247,7 +253,8 @@ private TestingEngine(Configuration configuration, TestMethodInfo testMethodInfo this.Logger = this.DefaultLogger; this.Profiler = new Profiler(); - this.PerIterationCallbacks = new HashSet>(); + this.PerIterationCallbacks = new HashSet(); + this.PerIterationInitializationCallbacks = new HashSet(); this.TestReport = new TestReport(configuration); this.ReadableTrace = string.Empty; @@ -529,6 +536,15 @@ private bool RunNextIteration(uint iteration) this.InitializeCustomActorLogging(runtime.DefaultActorExecutionContext); + // Invoke TestInit methods (if any) for every iteration, except the first one. + if (iteration != 0) + { + foreach (var callback in this.PerIterationInitializationCallbacks) + { + callback(); + } + } + // Runs the test and waits for it to terminate. runtime.RunTest(this.TestMethodInfo.Method, this.TestMethodInfo.Name); runtime.WaitAsync().Wait(); @@ -539,7 +555,7 @@ private bool RunNextIteration(uint iteration) // Invoke the per iteration callbacks, if any. foreach (var callback in this.PerIterationCallbacks) { - callback(iteration); + callback(); } if (!runtime.IsBugFound) @@ -712,14 +728,21 @@ public IEnumerable TryEmitTraces(string directory, string file) } /// - /// Registers a callback to invoke at the end of each iteration. The callback takes as - /// a parameter an integer representing the current iteration. + /// Registers a callback to invoke at the end of each iteration. /// - public void RegisterPerIterationCallBack(Action callback) + public void RegisterPerIterationCallBack(Action callback) { this.PerIterationCallbacks.Add(callback); } + /// + /// Registers a callback to invoke at the end of each iteration. + /// + public void RegisterPerIterationInitMethod(Action callback) + { + this.PerIterationInitializationCallbacks.Add(callback); + } + /// /// Take care of handling the settings for , /// , and by setting up the