diff --git a/TUnit.Assertions.Analyzers.CodeFixers.Tests/Verifiers/CSharpCodeFixVerifier`2.cs b/TUnit.Assertions.Analyzers.CodeFixers.Tests/Verifiers/CSharpCodeFixVerifier`2.cs index 6ad9ff152f..2c8ee11dc2 100644 --- a/TUnit.Assertions.Analyzers.CodeFixers.Tests/Verifiers/CSharpCodeFixVerifier`2.cs +++ b/TUnit.Assertions.Analyzers.CodeFixers.Tests/Verifiers/CSharpCodeFixVerifier`2.cs @@ -4,7 +4,9 @@ using Microsoft.CodeAnalysis.CSharp.Testing; using Microsoft.CodeAnalysis.Diagnostics; using Microsoft.CodeAnalysis.Testing; +using TUnit.Assertions; using TUnit.Assertions.Analyzers.CodeFixers.Tests.Extensions; +using TUnit.Core; namespace TUnit.Assertions.Analyzers.CodeFixers.Tests.Verifiers; @@ -41,12 +43,19 @@ public static async Task VerifyAnalyzerAsync( params DiagnosticResult[] expected ) { + var referenceAssemblies = GetReferenceAssemblies(); + + // Only add xunit package for XUnitAssertionAnalyzer + if (typeof(TAnalyzer).Name == "XUnitAssertionAnalyzer") + { + referenceAssemblies = referenceAssemblies.AddPackages([new PackageIdentity("xunit.v3.assert", "3.2.0")]); + } + var test = new Test { TestCode = source.NormalizeLineEndings(), CodeActionValidationMode = CodeActionValidationMode.SemanticStructure, - ReferenceAssemblies = GetReferenceAssemblies() - .AddPackages([new PackageIdentity("xunit.v3.assert", "2.0.0")]), + ReferenceAssemblies = referenceAssemblies, TestState = { AdditionalReferences = @@ -76,12 +85,19 @@ public static async Task VerifyCodeFixAsync( [StringSyntax("c#-test")] string fixedSource ) { + var referenceAssemblies = GetReferenceAssemblies(); + + // Only add xunit package for XUnitAssertionAnalyzer + if (typeof(TAnalyzer).Name == "XUnitAssertionAnalyzer") + { + referenceAssemblies = referenceAssemblies.AddPackages([new PackageIdentity("xunit.v3.assert", "3.2.0")]); + } + var test = new Test { TestCode = source.NormalizeLineEndings(), FixedCode = fixedSource.NormalizeLineEndings(), - ReferenceAssemblies = GetReferenceAssemblies() - .AddPackages([new PackageIdentity("xunit.v3.assert", "2.0.0")]), + ReferenceAssemblies = referenceAssemblies, TestState = { AdditionalReferences = diff --git a/TUnit.Assertions.Analyzers.CodeFixers.Tests/XUnitAssertionCodeFixProviderTests.cs b/TUnit.Assertions.Analyzers.CodeFixers.Tests/XUnitAssertionCodeFixProviderTests.cs index 4359388627..4914f8aae6 100644 --- a/TUnit.Assertions.Analyzers.CodeFixers.Tests/XUnitAssertionCodeFixProviderTests.cs +++ b/TUnit.Assertions.Analyzers.CodeFixers.Tests/XUnitAssertionCodeFixProviderTests.cs @@ -150,4 +150,73 @@ public void MyTest() """ ); } + + [Test] + public async Task Xunit_All_Converts_To_AssertMultiple_WithForeach() + { + await Verifier + .VerifyCodeFixAsync( + """ + using System.Threading.Tasks; + + public class MyClass + { + public void MyTest() + { + var users = new[] + { + new User { Name = "Alice", Age = 25 }, + new User { Name = "Bob", Age = 30 } + }; + + {|#0:Xunit.Assert.All(users, user => + { + {|#1:Xunit.Assert.NotNull(user.Name)|}; + {|#2:Xunit.Assert.True(user.Age > 18)|}; + })|}; + } + } + + public class User + { + public string Name { get; init; } + public int Age { get; init; } + } + """, + [ + Verifier.Diagnostic(Rules.XUnitAssertion).WithLocation(0), + Verifier.Diagnostic(Rules.XUnitAssertion).WithLocation(1), + Verifier.Diagnostic(Rules.XUnitAssertion).WithLocation(2) + ], + """ + using System.Threading.Tasks; + + public class MyClass + { + public async Task MyTest() + { + var users = new[] + { + new User { Name = "Alice", Age = 25 }, + new User { Name = "Bob", Age = 30 } + }; + using (Assert.Multiple()) + { + foreach (var user in users) + { + await Assert.That(user.Name).IsNotNull(); + await Assert.That(user.Age > 18).IsTrue(); + } + } + } + } + + public class User + { + public string Name { get; init; } + public int Age { get; init; } + } + """ + ); + } } diff --git a/TUnit.Assertions.Analyzers.CodeFixers/XUnitAssertionCodeFixProvider.cs b/TUnit.Assertions.Analyzers.CodeFixers/XUnitAssertionCodeFixProvider.cs index c4f6adaea6..8afa092b20 100644 --- a/TUnit.Assertions.Analyzers.CodeFixers/XUnitAssertionCodeFixProvider.cs +++ b/TUnit.Assertions.Analyzers.CodeFixers/XUnitAssertionCodeFixProvider.cs @@ -5,6 +5,7 @@ using Microsoft.CodeAnalysis.CodeFixes; using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; +using Microsoft.CodeAnalysis.Formatting; using TUnit.Assertions.Analyzers.Extensions; namespace TUnit.Assertions.Analyzers.CodeFixers; @@ -64,7 +65,57 @@ private static async Task ConvertAssertionAsync(CodeFixContext context var genericArgs = GetGenericArguments(memberAccessExpressionSyntax.Name); - var newExpression = await GetNewExpression(context, memberAccessExpressionSyntax, methodName, actual, expected, genericArgs, expressionSyntax.ArgumentList.Arguments); + // Special handling for Assert.All - returns a statement instead of an expression + if (methodName == "All") + { + var newStatement = ConvertAllToStatement(expected, actual); + if (newStatement != null) + { + // Find the parent expression statement and containing method + var parentStatement = expressionSyntax.FirstAncestorOrSelf(); + var methodDeclaration = expressionSyntax.FirstAncestorOrSelf(); + + if (parentStatement != null) + { + // Format the statement using Roslyn's formatter with annotations + newStatement = newStatement.WithAdditionalAnnotations(Formatter.Annotation); + + // Replace the statement + var newRoot = compilationUnit.ReplaceNode(parentStatement, newStatement); + + // Find the method declaration in the new tree if it needs to be modified + if (methodDeclaration != null && !methodDeclaration.Modifiers.Any(SyntaxKind.AsyncKeyword)) + { + // Find the method in the new tree + var newMethodDeclaration = newRoot.DescendantNodes() + .OfType() + .FirstOrDefault(m => m.Identifier.ValueText == methodDeclaration.Identifier.ValueText); + + if (newMethodDeclaration != null) + { + var asyncModifier = SyntaxFactory.Token(SyntaxKind.AsyncKeyword).WithTrailingTrivia(SyntaxFactory.Space); + var newModifiers = newMethodDeclaration.Modifiers.Add(asyncModifier); + var updatedMethodDeclaration = newMethodDeclaration.WithModifiers(newModifiers); + + // Update return type to Task if it's void + if (newMethodDeclaration.ReturnType.ToString() == "void") + { + updatedMethodDeclaration = updatedMethodDeclaration.WithReturnType( + SyntaxFactory.IdentifierName("Task").WithLeadingTrivia(newMethodDeclaration.ReturnType.GetLeadingTrivia()).WithTrailingTrivia(SyntaxFactory.Space)); + } + + newRoot = newRoot.ReplaceNode(newMethodDeclaration, updatedMethodDeclaration); + } + } + + // Format the entire document + var formattedRoot = Formatter.Format(newRoot, Formatter.Annotation, document.Project.Solution.Workspace); + return document.WithSyntaxRoot(formattedRoot); + } + } + } + + var newExpression = await GetNewExpression(context, expressionSyntax, memberAccessExpressionSyntax, methodName, actual, expected, genericArgs, expressionSyntax.ArgumentList.Arguments); if (newExpression != null) { @@ -75,12 +126,16 @@ private static async Task ConvertAssertionAsync(CodeFixContext context } private static async Task GetNewExpression(CodeFixContext context, + InvocationExpressionSyntax expressionSyntax, MemberAccessExpressionSyntax memberAccessExpressionSyntax, string method, ArgumentSyntax? actual, ArgumentSyntax? expected, string genericArgs, SeparatedSyntaxList argumentListArguments) { var isGeneric = !string.IsNullOrEmpty(genericArgs); + // Check if we're inside a .Satisfy() or .Satisfies() lambda + var (isInSatisfy, parameterName) = IsInsideSatisfyLambda(expressionSyntax); + return method switch { "Equal" => await IsEqualTo(context, argumentListArguments, actual, expected), @@ -95,13 +150,21 @@ private static async Task ConvertAssertionAsync(CodeFixContext context "EndsWith" => SyntaxFactory.ParseExpression($"Assert.That({actual}).EndsWith({expected})"), - "NotNull" => SyntaxFactory.ParseExpression($"Assert.That({actual}).IsNotNull()"), + "NotNull" => isInSatisfy && parameterName != null + ? SyntaxFactory.ParseExpression($"{actual}.IsNotNull()") + : SyntaxFactory.ParseExpression($"Assert.That({actual}).IsNotNull()"), - "Null" => SyntaxFactory.ParseExpression($"Assert.That({actual}).IsNull()"), + "Null" => isInSatisfy && parameterName != null + ? SyntaxFactory.ParseExpression($"{actual}.IsNull()") + : SyntaxFactory.ParseExpression($"Assert.That({actual}).IsNull()"), - "True" => SyntaxFactory.ParseExpression($"Assert.That({actual}).IsTrue()"), + "True" => isInSatisfy && parameterName != null + ? SyntaxFactory.ParseExpression($"{actual}.IsTrue()") + : SyntaxFactory.ParseExpression($"Assert.That({actual}).IsTrue()"), - "False" => SyntaxFactory.ParseExpression($"Assert.That({actual}).IsFalse()"), + "False" => isInSatisfy && parameterName != null + ? SyntaxFactory.ParseExpression($"{actual}.IsFalse()") + : SyntaxFactory.ParseExpression($"Assert.That({actual}).IsFalse()"), "Same" => SyntaxFactory.ParseExpression($"Assert.That({actual}).IsSameReferenceAs({expected})"), @@ -123,7 +186,7 @@ private static async Task ConvertAssertionAsync(CodeFixContext context ? SyntaxFactory.ParseExpression($"Assert.That({actual}).IsNotAssignableFrom<{genericArgs}>()") : SyntaxFactory.ParseExpression($"Assert.That({actual}).IsNotAssignableFrom({expected})"), - "All" => SyntaxFactory.ParseExpression($"Assert.That({actual}).All().Satisfy({expected})"), + // "All" is handled separately in ConvertAssertionAsync "Single" => SyntaxFactory.ParseExpression($"Assert.That({actual}).HasSingleItem()"), @@ -278,4 +341,193 @@ public static string GetGenericArguments(ExpressionSyntax expressionSyntax) return string.Empty; } + + private static (bool isInSatisfy, string? parameterName) IsInsideSatisfyLambda(SyntaxNode node) + { + var current = node.Parent; + + while (current != null) + { + // Check if we're in a lambda expression + if (current is SimpleLambdaExpressionSyntax simpleLambda) + { + // Check if the lambda is an argument to a .Satisfy() or .Satisfies() call + if (current.Parent is ArgumentSyntax argument && + argument.Parent is ArgumentListSyntax argumentList && + argumentList.Parent is InvocationExpressionSyntax invocation && + invocation.Expression is MemberAccessExpressionSyntax memberAccess) + { + var methodName = memberAccess.Name.Identifier.ValueText; + if (methodName is "Satisfy" or "Satisfies") + { + return (true, simpleLambda.Parameter.Identifier.ValueText); + } + } + } + else if (current is ParenthesizedLambdaExpressionSyntax parenLambda) + { + // Check if the lambda is an argument to a .Satisfy() or .Satisfies() call + if (current.Parent is ArgumentSyntax argument && + argument.Parent is ArgumentListSyntax argumentList && + argumentList.Parent is InvocationExpressionSyntax invocation && + invocation.Expression is MemberAccessExpressionSyntax memberAccess) + { + var methodName = memberAccess.Name.Identifier.ValueText; + if (methodName is "Satisfy" or "Satisfies") + { + // For parenthesized lambda, get the first parameter + var firstParam = parenLambda.ParameterList.Parameters.FirstOrDefault(); + return (true, firstParam?.Identifier.ValueText); + } + } + } + + current = current.Parent; + } + + return (false, null); + } + + private static StatementSyntax? ConvertAllToStatement(ArgumentSyntax? collection, ArgumentSyntax? lambda) + { + if (lambda?.Expression is not LambdaExpressionSyntax lambdaExpression) + { + return null; + } + + // Extract lambda parameter name + string? paramName = lambdaExpression switch + { + SimpleLambdaExpressionSyntax simple => simple.Parameter.Identifier.ValueText, + ParenthesizedLambdaExpressionSyntax paren => paren.ParameterList.Parameters.FirstOrDefault()?.Identifier.ValueText, + _ => null + }; + + if (paramName == null) + { + return null; + } + + // Get the lambda body + var lambdaBody = lambdaExpression switch + { + SimpleLambdaExpressionSyntax simple => simple.Body, + ParenthesizedLambdaExpressionSyntax paren => paren.Body, + _ => null + }; + + if (lambdaBody == null) + { + return null; + } + + // Convert xUnit assertions in the lambda body to TUnit assertions + var convertedStatements = ConvertLambdaBodyToTUnitAssertions(lambdaBody); + + if (convertedStatements == null || convertedStatements.Count == 0) + { + return null; + } + + // Build the foreach statement with converted assertions + var foreachBlock = SyntaxFactory.Block(convertedStatements); + var foreachStatement = SyntaxFactory.ForEachStatement( + SyntaxFactory.IdentifierName("var"), + SyntaxFactory.Identifier(paramName), + collection!.Expression, + foreachBlock + ); + + // Build the using statement with the foreach inside + var usingBlock = SyntaxFactory.Block(foreachStatement); + var usingStatement = SyntaxFactory.UsingStatement( + declaration: null, + expression: SyntaxFactory.InvocationExpression( + SyntaxFactory.MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + SyntaxFactory.IdentifierName("Assert"), + SyntaxFactory.IdentifierName("Multiple") + ), + SyntaxFactory.ArgumentList() + ), + statement: usingBlock + ).NormalizeWhitespace(); + + return usingStatement; + } + + private static List? ConvertLambdaBodyToTUnitAssertions(SyntaxNode lambdaBody) + { + var statements = new List(); + + // Extract statements from the lambda body + var bodyStatements = lambdaBody switch + { + BlockSyntax block => block.Statements, + ExpressionSyntax expr => SyntaxFactory.SingletonList( + SyntaxFactory.ExpressionStatement(expr) + ), + _ => default + }; + + if (bodyStatements == default) + { + return null; + } + + foreach (var statement in bodyStatements) + { + // Find xUnit assertion invocations in this statement + var invocations = statement.DescendantNodes().OfType() + .Where(inv => inv.Expression is MemberAccessExpressionSyntax memberAccess && + memberAccess.Expression.ToString().Contains("Xunit.Assert")) + .ToList(); + + if (invocations.Count == 0) + { + // Not an assertion statement - keep it as is + statements.Add(statement); + continue; + } + + // Convert each xUnit assertion + foreach (var invocation in invocations) + { + if (invocation.Expression is not MemberAccessExpressionSyntax memberAccess) + { + continue; + } + + var methodName = memberAccess.Name.Identifier.ValueText; + var args = invocation.ArgumentList.Arguments; + + // Convert to TUnit assertion + ExpressionSyntax? tunitAssertion = methodName switch + { + "NotNull" when args.Count >= 1 => + SyntaxFactory.ParseExpression($"Assert.That({args[0]}).IsNotNull()"), + "Null" when args.Count >= 1 => + SyntaxFactory.ParseExpression($"Assert.That({args[0]}).IsNull()"), + "True" when args.Count >= 1 => + SyntaxFactory.ParseExpression($"Assert.That({args[0]}).IsTrue()"), + "False" when args.Count >= 1 => + SyntaxFactory.ParseExpression($"Assert.That({args[0]}).IsFalse()"), + "Equal" when args.Count >= 2 => + SyntaxFactory.ParseExpression($"Assert.That({args[1]}).IsEqualTo({args[0]})"), + "NotEqual" when args.Count >= 2 => + SyntaxFactory.ParseExpression($"Assert.That({args[1]}).IsNotEqualTo({args[0]})"), + _ => null + }; + + if (tunitAssertion != null) + { + // Make it an await expression statement + var awaitExpr = SyntaxFactory.AwaitExpression(tunitAssertion); + statements.Add(SyntaxFactory.ExpressionStatement(awaitExpr)); + } + } + } + + return statements.Count > 0 ? statements : null; + } }